From 368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 Mon Sep 17 00:00:00 2001 From: Edward Liu Date: Mon, 14 Nov 2022 12:08:01 -0800 Subject: Minimum binary arithmetic reverse autodiff working. (#2514) * Initial plumbing of backward autodiff in the frontend. * More plumbing. * Initial reverse autodiff working. * Bug fixes. * Misc. * Remove redundant code. * More clean up. * Misc. * Rebase and add backward diff test. * Disable test. * Clean up. * Minor fix. Co-authored-by: Yong He --- source/slang/slang-check-expr.cpp | 59 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 30db9ecfa..f568dd8df 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -969,6 +969,14 @@ namespace Slang maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + } + return checkedTerm; } @@ -2027,6 +2035,45 @@ namespace Slang return jvpType; } + Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType) + { + // Resolve backward diff type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + FuncType* type = m_astBuilder->create(); + + // The backward diff return type is void + // + type->resultType = m_astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); + type->errorType = originalType->errorType; + + for (UInt i = 0; i < originalType->getParamCount(); i++) + { + if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + // Using inout type on all the derivative parameters + if (auto outType = as(derivType)) + { + derivType = outType->getValueType(); + } + else if (!as(derivType)) + { + derivType = m_astBuilder->getInOutType(derivType); + } + type->paramTypes.add(derivType); + } + } + + // Last parameter is the initial derivative of the original return type + type->paramTypes.add(originalType->resultType); + + return type; + } + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -2039,6 +2086,18 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); + + // For now we only support using higher order expr as callee in an invoke expr. + // The actual type of the higher order function will be derived during resolve invoke. + expr->type = m_astBuilder->getBottomType(); + + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); -- cgit v1.2.3