diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-06-25 15:45:34 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-25 12:45:34 -0700 |
| commit | 8da47c460df01fad6f1d0614210a770f4781edb1 (patch) | |
| tree | 170a5cc100c69e387e8c6d34217588ea00daed53 /source/slang/slang-check-expr.cpp | |
| parent | 0229784b93a43e17a088881e6be32b44fc6ce713 (diff) | |
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types
* Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly
* Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 40 |
1 files changed, 35 insertions, 5 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ff469428b..576220c02 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,16 +1509,46 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. - expr->baseFn = CheckTerm(expr->baseFn); + expr->baseFunction = CheckTerm(expr->baseFunction); - if(auto funcType = as<FuncType>(expr->baseFn->type)) + if(auto primalType = as<FuncType>(expr->baseFunction->type)) { // Resolve JVP type here. - // Temporarily resolving to the same type as the original function. - expr->type = expr->baseFn->type; + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = astBuilder->create<FuncType>(); + + // Only float types can be differentiated for now. + + // The JVP return type is float if primal return type is float + // void otherwise. + // + if (primalType->resultType->equals(astBuilder->getFloatType())) + jvpType->resultType = astBuilder->getFloatType(); + else + jvpType->resultType = astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); + jvpType->errorType = primalType->errorType; + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + jvpType->paramTypes.add(primalType->getParamType(i)); + } + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) + jvpType->paramTypes.add(astBuilder->getFloatType()); + } + + expr->type = jvpType; } else { |
