summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorEdward Liu <shiqiu1105@gmail.com>2022-11-14 12:08:01 -0800
committerGitHub <noreply@github.com>2022-11-14 12:08:01 -0800
commit368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch)
tree3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-check-expr.cpp
parent623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff)
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend. * More plumbing. * Initial reverse autodiff working. * Bug fixes. * Misc. * Remove redundant code. * More clean up. * Misc. * Rebase and add backward diff test. * Disable test. * Clean up. * Minor fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp59
1 files changed, 59 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 30db9ecfa..f568dd8df 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -969,6 +969,14 @@ namespace Slang
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
}
+ // Differentiable type checking.
+ // TODO: This can be super slow.
+ if (this->m_parentFunc &&
+ this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>())
+ {
+ maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
+ }
+
return checkedTerm;
}
@@ -2027,6 +2035,45 @@ namespace Slang
return jvpType;
}
+ Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType)
+ {
+ // Resolve backward diff type here.
+ // Note that this type checking needs to be in sync with
+ // the auto-generation logic in slang-ir-jvp-diff.cpp
+
+ FuncType* type = m_astBuilder->create<FuncType>();
+
+ // The backward diff return type is void
+ //
+ type->resultType = m_astBuilder->getVoidType();
+
+ // No support for differentiating function that throw errors, for now.
+ SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
+ type->errorType = originalType->errorType;
+
+ for (UInt i = 0; i < originalType->getParamCount(); i++)
+ {
+ if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ {
+ // Using inout type on all the derivative parameters
+ if (auto outType = as<OutType>(derivType))
+ {
+ derivType = outType->getValueType();
+ }
+ else if (!as<PtrTypeBase>(derivType))
+ {
+ derivType = m_astBuilder->getInOutType(derivType);
+ }
+ type->paramTypes.add(derivType);
+ }
+ }
+
+ // Last parameter is the initial derivative of the original return type
+ type->paramTypes.add(originalType->resultType);
+
+ return type;
+ }
+
Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
// Check/Resolve inner function declaration.
@@ -2039,6 +2086,18 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
+ {
+ // Check/Resolve inner function declaration.
+ expr->baseFunction = CheckTerm(expr->baseFunction);
+
+ // For now we only support using higher order expr as callee in an invoke expr.
+ // The actual type of the higher order function will be derived during resolve invoke.
+ expr->type = m_astBuilder->getBottomType();
+
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
expr->arrayExpr = CheckTerm(expr->arrayExpr);