diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-18 12:37:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-18 12:37:27 -0800 |
| commit | d58e08f8237a1888ceaad53402d534679ea83b1a (patch) | |
| tree | e66838e0dc31fc12ebd7c1acecbb5060e8808366 /source/slang/slang-check-expr.cpp | |
| parent | 0a050a439fa91b66f2020421d4fec3e60aed4112 (diff) | |
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 75 |
1 files changed, 74 insertions, 1 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b43a03150..2c6899269 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1893,6 +1893,34 @@ namespace Slang } } } + + if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr)) + { + if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + { + auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl()); + if (funcDecl) + { + DifferentiateExpr* forwardDiff = nullptr; + DifferentiateExpr* backwardDiff = nullptr; + for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction)) + { + if (auto fwd = as<ForwardDifferentiateExpr>(node)) + forwardDiff = fwd; + if (auto bwd = as<BackwardDifferentiateExpr>(node)) + backwardDiff = bwd; + } + if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); + } + if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + } + } + } + } } } return rs; @@ -1920,7 +1948,7 @@ namespace Slang auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); - if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>()) + if (m_parentDifferentiableAttr) { if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr)) { @@ -1929,6 +1957,30 @@ namespace Slang { maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); } + if (auto calleeExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr)) + { + if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl())) + { + if (getShared()->isDifferentiableFunc(calleeDecl)) + { + if (!m_treatAsDifferentiableExpr) + { + auto newFuncExpr = + getASTBuilder()->create<TreatAsDifferentiableExpr>(); + newFuncExpr->type = checkedInvokeExpr->type; + newFuncExpr->innerExpr = checkedInvokeExpr; + newFuncExpr->loc = checkedInvokeExpr->loc; + checkedExpr = newFuncExpr; + } + else + { + getSink()->diagnose( + m_treatAsDifferentiableExpr, + Diagnostics::useOfNoDiffOnDifferentiableFunc); + } + } + } + } } maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); } @@ -2227,6 +2279,27 @@ namespace Slang return _checkDifferentiateExpr(this, expr, &actions); } + Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + auto subContext = withTreatAsDifferentiable(expr); + expr->innerExpr = dispatchExpr(expr->innerExpr, subContext); + expr->type = expr->innerExpr->type; + auto innerExpr = expr->innerExpr; + while (auto parenExpr = as<ParenExpr>(innerExpr)) + { + innerExpr = parenExpr->base; + } + if (!as<InvokeExpr>(innerExpr)) + { + getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff); + } + else if (!m_parentDifferentiableAttr) + { + getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc); + } + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); |
