diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 25 |
1 files changed, 10 insertions, 15 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2803b5959..bebaa63a2 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,31 +1971,26 @@ namespace Slang if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr)) { - if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + FunctionDifferentiableLevel requiredLevel; + if (auto funcDeclExpr = as<DeclRefExpr>( + getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) { auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl()); if (funcDecl) { - DifferentiateExpr* forwardDiff = nullptr; - DifferentiateExpr* backwardDiff = nullptr; - for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction)) + if (requiredLevel == FunctionDifferentiableLevel::Forward && + !getShared()->isDifferentiableFunc(funcDecl)) { - if (auto fwd = as<ForwardDifferentiateExpr>(node)) - forwardDiff = fwd; - if (auto bwd = as<BackwardDifferentiateExpr>(node)) - backwardDiff = bwd; + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); } - if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + if (requiredLevel == FunctionDifferentiableLevel::Backward && + !getShared()->isBackwardDifferentiableFunc(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); - } - if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); } if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); + getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); } } } |
