summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp8
1 files changed, 3 insertions, 5 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0901d2026..b3470e882 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1427,6 +1427,7 @@ namespace Slang
varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate);
}
}
+ maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType());
}
// Fill in default substitutions for the 'subtype' part of a type constraint decl
@@ -4738,7 +4739,6 @@ namespace Slang
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
-
if (newContext.getParentDifferentiableAttribute())
{
// Register additional types outside the function body first.
@@ -5638,11 +5638,8 @@ namespace Slang
bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
- if (GetOuterGeneric(decl))
- {
- getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported);
- }
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
+ reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
@@ -5664,6 +5661,7 @@ namespace Slang
auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType));
{
auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
interfaceDecl->members.add(reqDecl);