From 257733f328f38a763c8b0c8830ff4c0d34ec9491 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 7 Mar 2023 11:22:32 -0800 Subject: Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688) * Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He --- source/slang/slang-check-expr.cpp | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2803b5959..bebaa63a2 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,31 +1971,26 @@ namespace Slang if (auto higherOrderInvoke = as(invoke->functionExpr)) { - if (auto funcDeclExpr = as(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + FunctionDifferentiableLevel requiredLevel; + if (auto funcDeclExpr = as( + getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) { auto funcDecl = as(funcDeclExpr->declRef.getDecl()); if (funcDecl) { - DifferentiateExpr* forwardDiff = nullptr; - DifferentiateExpr* backwardDiff = nullptr; - for (auto node = as(invoke->functionExpr); node; node = as(node->baseFunction)) + if (requiredLevel == FunctionDifferentiableLevel::Forward && + !getShared()->isDifferentiableFunc(funcDecl)) { - if (auto fwd = as(node)) - forwardDiff = fwd; - if (auto bwd = as(node)) - backwardDiff = bwd; + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); } - if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + if (requiredLevel == FunctionDifferentiableLevel::Backward && + !getShared()->isBackwardDifferentiableFunc(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); - } - if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); } if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); + getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); } } } -- cgit v1.2.3