summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-07 11:22:32 -0800
committerGitHub <noreply@github.com>2023-03-07 11:22:32 -0800
commit257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch)
tree87e444746f353d69a365380904f3f8caf15fbfec /source/slang/slang-check-expr.cpp
parent6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff)
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp25
1 files changed, 10 insertions, 15 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2803b5959..bebaa63a2 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1971,31 +1971,26 @@ namespace Slang
if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr))
{
- if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke)))
+ FunctionDifferentiableLevel requiredLevel;
+ if (auto funcDeclExpr = as<DeclRefExpr>(
+ getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel)))
{
auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl());
if (funcDecl)
{
- DifferentiateExpr* forwardDiff = nullptr;
- DifferentiateExpr* backwardDiff = nullptr;
- for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction))
+ if (requiredLevel == FunctionDifferentiableLevel::Forward &&
+ !getShared()->isDifferentiableFunc(funcDecl))
{
- if (auto fwd = as<ForwardDifferentiateExpr>(node))
- forwardDiff = fwd;
- if (auto bwd = as<BackwardDifferentiateExpr>(node))
- backwardDiff = bwd;
+ getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward");
}
- if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl))
+ if (requiredLevel == FunctionDifferentiableLevel::Backward &&
+ !getShared()->isBackwardDifferentiableFunc(funcDecl))
{
- getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward");
- }
- if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl))
- {
- getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
+ getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
}
if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl))
{
- getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl);
+ getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl);
}
}
}