From 8da47c460df01fad6f1d0614210a770f4781edb1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Sat, 25 Jun 2022 15:45:34 -0400 Subject: Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303) * Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He --- source/slang/slang-check-expr.cpp | 40 ++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ff469428b..576220c02 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,16 +1509,46 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. - expr->baseFn = CheckTerm(expr->baseFn); + expr->baseFunction = CheckTerm(expr->baseFunction); - if(auto funcType = as(expr->baseFn->type)) + if(auto primalType = as(expr->baseFunction->type)) { // Resolve JVP type here. - // Temporarily resolving to the same type as the original function. - expr->type = expr->baseFn->type; + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = astBuilder->create(); + + // Only float types can be differentiated for now. + + // The JVP return type is float if primal return type is float + // void otherwise. + // + if (primalType->resultType->equals(astBuilder->getFloatType())) + jvpType->resultType = astBuilder->getFloatType(); + else + jvpType->resultType = astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); + jvpType->errorType = primalType->errorType; + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + jvpType->paramTypes.add(primalType->getParamType(i)); + } + + for (UInt i = 0; i < primalType->getParamCount(); i++) + { + if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) + jvpType->paramTypes.add(astBuilder->getFloatType()); + } + + expr->type = jvpType; } else { -- cgit v1.2.3