summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-18 12:37:27 -0800
committerGitHub <noreply@github.com>2022-11-18 12:37:27 -0800
commitd58e08f8237a1888ceaad53402d534679ea83b1a (patch)
treee66838e0dc31fc12ebd7c1acecbb5060e8808366 /source/slang/slang-check-expr.cpp
parent0a050a439fa91b66f2020421d4fec3e60aed4112 (diff)
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp75
1 files changed, 74 insertions, 1 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b43a03150..2c6899269 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1893,6 +1893,34 @@ namespace Slang
}
}
}
+
+ if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr))
+ {
+ if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke)))
+ {
+ auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl());
+ if (funcDecl)
+ {
+ DifferentiateExpr* forwardDiff = nullptr;
+ DifferentiateExpr* backwardDiff = nullptr;
+ for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction))
+ {
+ if (auto fwd = as<ForwardDifferentiateExpr>(node))
+ forwardDiff = fwd;
+ if (auto bwd = as<BackwardDifferentiateExpr>(node))
+ backwardDiff = bwd;
+ }
+ if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl))
+ {
+ getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward");
+ }
+ if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl))
+ {
+ getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
+ }
+ }
+ }
+ }
}
}
return rs;
@@ -1920,7 +1948,7 @@ namespace Slang
auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr);
- if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>())
+ if (m_parentDifferentiableAttr)
{
if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr))
{
@@ -1929,6 +1957,30 @@ namespace Slang
{
maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
}
+ if (auto calleeExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr))
+ {
+ if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl()))
+ {
+ if (getShared()->isDifferentiableFunc(calleeDecl))
+ {
+ if (!m_treatAsDifferentiableExpr)
+ {
+ auto newFuncExpr =
+ getASTBuilder()->create<TreatAsDifferentiableExpr>();
+ newFuncExpr->type = checkedInvokeExpr->type;
+ newFuncExpr->innerExpr = checkedInvokeExpr;
+ newFuncExpr->loc = checkedInvokeExpr->loc;
+ checkedExpr = newFuncExpr;
+ }
+ else
+ {
+ getSink()->diagnose(
+ m_treatAsDifferentiableExpr,
+ Diagnostics::useOfNoDiffOnDifferentiableFunc);
+ }
+ }
+ }
+ }
}
maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type);
}
@@ -2227,6 +2279,27 @@ namespace Slang
return _checkDifferentiateExpr(this, expr, &actions);
}
+ Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ auto subContext = withTreatAsDifferentiable(expr);
+ expr->innerExpr = dispatchExpr(expr->innerExpr, subContext);
+ expr->type = expr->innerExpr->type;
+ auto innerExpr = expr->innerExpr;
+ while (auto parenExpr = as<ParenExpr>(innerExpr))
+ {
+ innerExpr = parenExpr->base;
+ }
+ if (!as<InvokeExpr>(innerExpr))
+ {
+ getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff);
+ }
+ else if (!m_parentDifferentiableAttr)
+ {
+ getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc);
+ }
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
expr->arrayExpr = CheckTerm(expr->arrayExpr);