diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0901d2026..b3470e882 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1427,6 +1427,7 @@ namespace Slang varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); } } + maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType()); } // Fill in default substitutions for the 'subtype' part of a type constraint decl @@ -4738,7 +4739,6 @@ namespace Slang void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); - if (newContext.getParentDifferentiableAttribute()) { // Register additional types outside the function body first. @@ -5638,11 +5638,8 @@ namespace Slang bool isDiffFunc = false; if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>()) { - if (GetOuterGeneric(decl)) - { - getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported); - } auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); + reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); @@ -5664,6 +5661,7 @@ namespace Slang auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType)); { auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); + reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); interfaceDecl->members.add(reqDecl); |
