summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp61
1 files changed, 41 insertions, 20 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 36a1061c9..4d2839b8d 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1926,23 +1926,33 @@ namespace Slang
requiredMemberDeclRef.getDecl(),
RequirementWitness(satisfyingMemberDeclRef));
- if (hasForwardDerivative)
+ if (hasForwardDerivative || hasBackwardDerivative)
{
- auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>();
- SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
- ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
- }
+ int fwdReqFound = 0;
+ int bwdReqFound = 0;
+ for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ {
+ if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(fwdReq, RequirementWitness(val));
+ fwdReqFound++;
+ }
+ else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(bwdReq, RequirementWitness(val));
+ bwdReqFound++;
+ }
+ }
- if (hasBackwardDerivative)
- {
- auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>();
- SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
- BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
+ SLANG_RELEASE_ASSERT(
+ fwdReqFound == (hasForwardDerivative ? 1 : 0) &&
+ bwdReqFound == (hasBackwardDerivative ? 1 : 0));
}
+
return true;
}
@@ -3706,7 +3716,8 @@ namespace Slang
{
if(isAssociatedTypeDecl(requiredMemberDeclRef))
continue;
-
+ if (requiredMemberDeclRef.as<DerivativeRequirementDecl>())
+ continue;
auto requirementSatisfied = findWitnessForInterfaceRequirement(
context,
subType,
@@ -5617,7 +5628,7 @@ namespace Slang
}
decl->errorType = errorType;
- if (isInterfaceRequirement(decl))
+ if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
if (decl->hasModifier<ForwardDifferentiableAttribute>())
{
@@ -5626,8 +5637,13 @@ namespace Slang
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- decl->members.add(reqDecl);
- reqDecl->parentDecl = decl;
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
@@ -5636,8 +5652,13 @@ namespace Slang
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef));
setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- decl->members.add(reqDecl);
- reqDecl->parentDecl = decl;
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
}
}
}