diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-07-11 23:18:06 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-07-11 23:18:06 -0400 |
| commit | b513d0deef521318ad943d820dd37029075a33c4 (patch) | |
| tree | cc6dc625ae381e0461724c5b137e1a034b03e636 /source/slang/slang-check-expr.cpp | |
| parent | 9261c7a23ddf061fe9f5bfc3376f09f3c0513bff (diff) | |
Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type (#2313)
* Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type
* Added test expected result
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index df58b11ed..a3fec4802 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,6 +1509,29 @@ namespace Slang return expr; } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) + { + // Only float and float3 types can be differentiated for now. + + if(primalType->equals(builder->getFloatType())) + return primalType; + else if(auto primalVectorType = as<VectorExpressionType>(primalType)) + { + // TODO(sai): There's probably a more elegant way to check if a type is a float3? + if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) + return primalVectorType; + } + return nullptr; + } + + Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType) + { + if(auto jvpType = primalToJVPParamType(builder, primalType)) + return jvpType; + else + return builder->getVoidType(); + } + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -1524,18 +1547,10 @@ namespace Slang FuncType* jvpType = astBuilder->create<FuncType>(); - // Only float types can be differentiated for now. - // The JVP return type is float if primal return type is float // void otherwise. // - if (primalType->resultType->equals(astBuilder->getFloatType())) - jvpType->resultType = astBuilder->getFloatType(); - else - { - //TODO(yong): issue proper diagnostic here. - jvpType->resultType = astBuilder->getVoidType(); - } + jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType()); // No support for differentiating function that throw errors, for now. SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); @@ -1548,8 +1563,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) - jvpType->paramTypes.add(astBuilder->getFloatType()); + if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(jvpParamType); } expr->type = jvpType; |
