diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c0253fd2c..eaab43ef8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1955,8 +1955,11 @@ namespace Slang bool hasForwardDerivative = false; if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + + if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward) { // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; @@ -1966,12 +1969,12 @@ namespace Slang } else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) { - // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; } hasForwardDerivative = true; @@ -6674,6 +6677,9 @@ namespace Slang if (func->findModifier<BackwardDerivativeAttribute>()) return FunctionDifferentiableLevel::Backward; + if (func->findModifier<TreatAsDifferentiableAttribute>()) + return FunctionDifferentiableLevel::Backward; + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; if (func->findModifier<DifferentiableAttribute>()) diffLevel = FunctionDifferentiableLevel::Forward; |
