diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 40 |
1 files changed, 35 insertions, 5 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ff469428b..576220c02 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,16 +1509,46 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. - expr->baseFn = CheckTerm(expr->baseFn); + expr->baseFunction = CheckTerm(expr->baseFunction); - if(auto funcType = as<FuncType>(expr->baseFn->type)) + if(auto primalType = as<FuncType>(expr->baseFunction->type)) { // Resolve JVP type here. - // Temporarily resolving to the same type as the original function. - expr->type = expr->baseFn->type; + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = astBuilder->create<FuncType>(); + + // Only float types can be differentiated for now. + + // The JVP return type is float if primal return type is float + // void otherwise. + // + if (primalType->resultType->equals(astBuilder->getFloatType())) + jvpType->resultType = astBuilder->getFloatType(); + else + jvpType->resultType = astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); + jvpType->errorType = primalType->errorType; + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + jvpType->paramTypes.add(primalType->getParamType(i)); + } + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) + jvpType->paramTypes.add(astBuilder->getFloatType()); + } + + expr->type = jvpType; } else { |
