diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-07 11:22:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-07 11:22:32 -0800 |
| commit | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch) | |
| tree | 87e444746f353d69a365380904f3f8caf15fbfec /source/slang/slang-check-expr.cpp | |
| parent | 6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff) | |
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs.
* Add diff implementation matrix versions of binary and ternary intrinsics.
* Add diff impl for legacy intrinsics.
* Fix diagnostics of using non-differentiable function in a diff operator.
* Add diff implementation for `determinant`.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 25 |
1 files changed, 10 insertions, 15 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2803b5959..bebaa63a2 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,31 +1971,26 @@ namespace Slang if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr)) { - if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + FunctionDifferentiableLevel requiredLevel; + if (auto funcDeclExpr = as<DeclRefExpr>( + getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) { auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl()); if (funcDecl) { - DifferentiateExpr* forwardDiff = nullptr; - DifferentiateExpr* backwardDiff = nullptr; - for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction)) + if (requiredLevel == FunctionDifferentiableLevel::Forward && + !getShared()->isDifferentiableFunc(funcDecl)) { - if (auto fwd = as<ForwardDifferentiateExpr>(node)) - forwardDiff = fwd; - if (auto bwd = as<BackwardDifferentiateExpr>(node)) - backwardDiff = bwd; + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); } - if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + if (requiredLevel == FunctionDifferentiableLevel::Backward && + !getShared()->isBackwardDifferentiableFunc(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); - } - if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); } if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); + getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); } } } |
