summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-25 15:45:34 -0400
committerGitHub <noreply@github.com>2022-06-25 12:45:34 -0700
commit8da47c460df01fad6f1d0614210a770f4781edb1 (patch)
tree170a5cc100c69e387e8c6d34217588ea00daed53 /source/slang/slang-check-expr.cpp
parent0229784b93a43e17a088881e6be32b44fc6ce713 (diff)
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp40
1 files changed, 35 insertions, 5 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index ff469428b..576220c02 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,16 +1509,46 @@ namespace Slang
return expr;
}
- Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
{
// Check/Resolve inner function declaration.
- expr->baseFn = CheckTerm(expr->baseFn);
+ expr->baseFunction = CheckTerm(expr->baseFunction);
- if(auto funcType = as<FuncType>(expr->baseFn->type))
+ if(auto primalType = as<FuncType>(expr->baseFunction->type))
{
// Resolve JVP type here.
- // Temporarily resolving to the same type as the original function.
- expr->type = expr->baseFn->type;
+ // Note that this type checking needs to be in sync with
+ // the auto-generation logic in slang-ir-jvp-diff.cpp
+
+ auto astBuilder = this->getASTBuilder();
+ FuncType* jvpType = astBuilder->create<FuncType>();
+
+ // Only float types can be differentiated for now.
+
+ // The JVP return type is float if primal return type is float
+ // void otherwise.
+ //
+ if (primalType->resultType->equals(astBuilder->getFloatType()))
+ jvpType->resultType = astBuilder->getFloatType();
+ else
+ jvpType->resultType = astBuilder->getVoidType();
+
+ // No support for differentiating function that throw errors, for now.
+ SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
+ jvpType->errorType = primalType->errorType;
+
+ for (UInt i = 0; i < primalType->getParamCount(); i++)
+ {
+ jvpType->paramTypes.add(primalType->getParamType(i));
+ }
+
+ for (UInt i = 0; i < primalType->getParamCount(); i++)
+ {
+ if(primalType->getParamType(i)->equals(astBuilder->getFloatType()))
+ jvpType->paramTypes.add(astBuilder->getFloatType());
+ }
+
+ expr->type = jvpType;
}
else
{