summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-19 08:58:20 -0800
committerGitHub <noreply@github.com>2023-01-19 08:58:20 -0800
commit6fae15cd1210d8b664243d640e70ca47dccf9752 (patch)
treed3235149f587ed18147f7a0d916932e199dce888 /source/slang/slang-check-decl.cpp
parent0586f3298fa7d554fa2682103eefba88740d6758 (diff)
Add diagnostic for calling non-bwd-diff func from bwd-diff func. (#2602)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp42
1 files changed, 19 insertions, 23 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index a535ba104..7b5f85b60 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6894,38 +6894,34 @@ namespace Slang
bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func)
{
- // A function is differentiable if it is marked as differentiable, or it
- // has an associated derivative function.
- if (func->findModifier<DifferentiableAttribute>())
- return true;
- for (auto assocDecl : getAssociatedDeclsForDecl(func))
- {
- switch (assocDecl.kind)
- {
- case DeclAssociationKind::ForwardDerivativeFunc:
- case DeclAssociationKind::BackwardDerivativeFunc:
- return true;
- default:
- break;
- }
- }
- return false;
+ return getFuncDifferentiableLevel(func) != FunctionDifferentiableLevel::None;
}
bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func)
{
- // A function is differentiable if it is marked as differentiable, or it
- // has an associated derivative function.
+ return getFuncDifferentiableLevel(func) == FunctionDifferentiableLevel::Backward;
+ }
+
+ FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func)
+ {
if (func->findModifier<BackwardDifferentiableAttribute>())
- return true;
+ return FunctionDifferentiableLevel::Backward;
if (func->findModifier<BackwardDerivativeAttribute>())
- return true;
+ return FunctionDifferentiableLevel::Backward;
+
+ FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None;
+ if (func->findModifier<DifferentiableAttribute>())
+ diffLevel = FunctionDifferentiableLevel::Forward;
+
for (auto assocDecl : getAssociatedDeclsForDecl(func))
{
switch (assocDecl.kind)
{
case DeclAssociationKind::BackwardDerivativeFunc:
- return true;
+ return FunctionDifferentiableLevel::Backward;
+ case DeclAssociationKind::ForwardDerivativeFunc:
+ diffLevel = FunctionDifferentiableLevel::Forward;
+ break;
default:
break;
}
@@ -6937,12 +6933,12 @@ namespace Slang
case BuiltinRequirementKind::DAddFunc:
case BuiltinRequirementKind::DMulFunc:
case BuiltinRequirementKind::DZeroFunc:
- return true;
+ return FunctionDifferentiableLevel::Backward;
default:
break;
}
}
- return false;
+ return diffLevel;
}
List<ExtensionDecl*> const& getCandidateExtensions(