diff options
| author | Edward Liu <shiqiu1105@gmail.com> | 2022-11-14 12:08:01 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-14 12:08:01 -0800 |
| commit | 368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch) | |
| tree | 3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-check-expr.cpp | |
| parent | 623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff) | |
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend.
* More plumbing.
* Initial reverse autodiff working.
* Bug fixes.
* Misc.
* Remove redundant code.
* More clean up.
* Misc.
* Rebase and add backward diff test.
* Disable test.
* Clean up.
* Minor fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 30db9ecfa..f568dd8df 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -969,6 +969,14 @@ namespace Slang maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + } + return checkedTerm; } @@ -2027,6 +2035,45 @@ namespace Slang return jvpType; } + Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType) + { + // Resolve backward diff type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + FuncType* type = m_astBuilder->create<FuncType>(); + + // The backward diff return type is void + // + type->resultType = m_astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); + type->errorType = originalType->errorType; + + for (UInt i = 0; i < originalType->getParamCount(); i++) + { + if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + // Using inout type on all the derivative parameters + if (auto outType = as<OutType>(derivType)) + { + derivType = outType->getValueType(); + } + else if (!as<PtrTypeBase>(derivType)) + { + derivType = m_astBuilder->getInOutType(derivType); + } + type->paramTypes.add(derivType); + } + } + + // Last parameter is the initial derivative of the original return type + type->paramTypes.add(originalType->resultType); + + return type; + } + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -2039,6 +2086,18 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); + + // For now we only support using higher order expr as callee in an invoke expr. + // The actual type of the higher order function will be derived during resolve invoke. + expr->type = m_astBuilder->getBottomType(); + + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); |
