From d58e08f8237a1888ceaad53402d534679ea83b1a Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 18 Nov 2022 12:37:27 -0800 Subject: Data flow validation pass for diagnosing derivative loss. (#2523) --- source/slang/slang-check-expr.cpp | 75 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b43a03150..2c6899269 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1893,6 +1893,34 @@ namespace Slang } } } + + if (auto higherOrderInvoke = as(invoke->functionExpr)) + { + if (auto funcDeclExpr = as(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + { + auto funcDecl = as(funcDeclExpr->declRef.getDecl()); + if (funcDecl) + { + DifferentiateExpr* forwardDiff = nullptr; + DifferentiateExpr* backwardDiff = nullptr; + for (auto node = as(invoke->functionExpr); node; node = as(node->baseFunction)) + { + if (auto fwd = as(node)) + forwardDiff = fwd; + if (auto bwd = as(node)) + backwardDiff = bwd; + } + if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); + } + if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + } + } + } + } } } return rs; @@ -1920,7 +1948,7 @@ namespace Slang auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); - if (m_parentFunc && m_parentFunc->hasModifier()) + if (m_parentDifferentiableAttr) { if (auto checkedInvokeExpr = as(checkedExpr)) { @@ -1929,6 +1957,30 @@ namespace Slang { maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); } + if (auto calleeExpr = as(checkedInvokeExpr->functionExpr)) + { + if (auto calleeDecl = as(calleeExpr->declRef.getDecl())) + { + if (getShared()->isDifferentiableFunc(calleeDecl)) + { + if (!m_treatAsDifferentiableExpr) + { + auto newFuncExpr = + getASTBuilder()->create(); + newFuncExpr->type = checkedInvokeExpr->type; + newFuncExpr->innerExpr = checkedInvokeExpr; + newFuncExpr->loc = checkedInvokeExpr->loc; + checkedExpr = newFuncExpr; + } + else + { + getSink()->diagnose( + m_treatAsDifferentiableExpr, + Diagnostics::useOfNoDiffOnDifferentiableFunc); + } + } + } + } } maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); } @@ -2227,6 +2279,27 @@ namespace Slang return _checkDifferentiateExpr(this, expr, &actions); } + Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + auto subContext = withTreatAsDifferentiable(expr); + expr->innerExpr = dispatchExpr(expr->innerExpr, subContext); + expr->type = expr->innerExpr->type; + auto innerExpr = expr->innerExpr; + while (auto parenExpr = as(innerExpr)) + { + innerExpr = parenExpr->base; + } + if (!as(innerExpr)) + { + getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff); + } + else if (!m_parentDifferentiableAttr) + { + getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc); + } + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); -- cgit v1.2.3