diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-19 08:58:20 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-19 08:58:20 -0800 |
| commit | 6fae15cd1210d8b664243d640e70ca47dccf9752 (patch) | |
| tree | d3235149f587ed18147f7a0d916932e199dce888 /source/slang/slang-check-decl.cpp | |
| parent | 0586f3298fa7d554fa2682103eefba88740d6758 (diff) | |
Add diagnostic for calling non-bwd-diff func from bwd-diff func. (#2602)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 42 |
1 files changed, 19 insertions, 23 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a535ba104..7b5f85b60 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6894,38 +6894,34 @@ namespace Slang bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func) { - // A function is differentiable if it is marked as differentiable, or it - // has an associated derivative function. - if (func->findModifier<DifferentiableAttribute>()) - return true; - for (auto assocDecl : getAssociatedDeclsForDecl(func)) - { - switch (assocDecl.kind) - { - case DeclAssociationKind::ForwardDerivativeFunc: - case DeclAssociationKind::BackwardDerivativeFunc: - return true; - default: - break; - } - } - return false; + return getFuncDifferentiableLevel(func) != FunctionDifferentiableLevel::None; } bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func) { - // A function is differentiable if it is marked as differentiable, or it - // has an associated derivative function. + return getFuncDifferentiableLevel(func) == FunctionDifferentiableLevel::Backward; + } + + FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func) + { if (func->findModifier<BackwardDifferentiableAttribute>()) - return true; + return FunctionDifferentiableLevel::Backward; if (func->findModifier<BackwardDerivativeAttribute>()) - return true; + return FunctionDifferentiableLevel::Backward; + + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; + if (func->findModifier<DifferentiableAttribute>()) + diffLevel = FunctionDifferentiableLevel::Forward; + for (auto assocDecl : getAssociatedDeclsForDecl(func)) { switch (assocDecl.kind) { case DeclAssociationKind::BackwardDerivativeFunc: - return true; + return FunctionDifferentiableLevel::Backward; + case DeclAssociationKind::ForwardDerivativeFunc: + diffLevel = FunctionDifferentiableLevel::Forward; + break; default: break; } @@ -6937,12 +6933,12 @@ namespace Slang case BuiltinRequirementKind::DAddFunc: case BuiltinRequirementKind::DMulFunc: case BuiltinRequirementKind::DZeroFunc: - return true; + return FunctionDifferentiableLevel::Backward; default: break; } } - return false; + return diffLevel; } List<ExtensionDecl*> const& getCandidateExtensions( |
