summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp20
1 files changed, 13 insertions, 7 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index c0253fd2c..eaab43ef8 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1955,8 +1955,11 @@ namespace Slang
bool hasForwardDerivative = false;
if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
{
- if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl());
+ if (!funcDecl)
+ return false;
+
+ if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward)
{
// A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa.
return false;
@@ -1966,12 +1969,12 @@ namespace Slang
}
else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
{
- if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl());
+ if (!funcDecl)
+ return false;
+ if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None)
{
- // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa.
+ // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa.
return false;
}
hasForwardDerivative = true;
@@ -6674,6 +6677,9 @@ namespace Slang
if (func->findModifier<BackwardDerivativeAttribute>())
return FunctionDifferentiableLevel::Backward;
+ if (func->findModifier<TreatAsDifferentiableAttribute>())
+ return FunctionDifferentiableLevel::Backward;
+
FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None;
if (func->findModifier<DifferentiableAttribute>())
diffLevel = FunctionDifferentiableLevel::Forward;