From d8a40abba5223fbcb56c52b04ccb88c02bbaf79f Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 21 Mar 2023 21:29:13 -0700 Subject: [TreatAsDifferentiable] functions. (#2720) --- source/slang/slang-check-decl.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c0253fd2c..eaab43ef8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1955,8 +1955,11 @@ namespace Slang bool hasForwardDerivative = false; if (requiredMemberDeclRef.getDecl()->hasModifier()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier()) + auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + + if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward) { // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; @@ -1966,12 +1969,12 @@ namespace Slang } else if (requiredMemberDeclRef.getDecl()->hasModifier()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier()) + auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) { - // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; } hasForwardDerivative = true; @@ -6674,6 +6677,9 @@ namespace Slang if (func->findModifier()) return FunctionDifferentiableLevel::Backward; + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; if (func->findModifier()) diffLevel = FunctionDifferentiableLevel::Forward; -- cgit v1.2.3