diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 30db9ecfa..f568dd8df 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -969,6 +969,14 @@ namespace Slang maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + } + return checkedTerm; } @@ -2027,6 +2035,45 @@ namespace Slang return jvpType; } + Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType) + { + // Resolve backward diff type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + + FuncType* type = m_astBuilder->create<FuncType>(); + + // The backward diff return type is void + // + type->resultType = m_astBuilder->getVoidType(); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); + type->errorType = originalType->errorType; + + for (UInt i = 0; i < originalType->getParamCount(); i++) + { + if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + // Using inout type on all the derivative parameters + if (auto outType = as<OutType>(derivType)) + { + derivType = outType->getValueType(); + } + else if (!as<PtrTypeBase>(derivType)) + { + derivType = m_astBuilder->getInOutType(derivType); + } + type->paramTypes.add(derivType); + } + } + + // Last parameter is the initial derivative of the original return type + type->paramTypes.add(originalType->resultType); + + return type; + } + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -2039,6 +2086,18 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); + + // For now we only support using higher order expr as callee in an invoke expr. + // The actual type of the higher order function will be derived during resolve invoke. + expr->type = m_astBuilder->getBottomType(); + + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); |
