diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 61 |
1 files changed, 41 insertions, 20 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 36a1061c9..4d2839b8d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1926,23 +1926,33 @@ namespace Slang requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); - if (hasForwardDerivative) + if (hasForwardDerivative || hasBackwardDerivative) { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); - } + int fwdReqFound = 0; + int bwdReqFound = 0; + for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>()) + { + if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(fwdReq, RequirementWitness(val)); + fwdReqFound++; + } + else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(bwdReq, RequirementWitness(val)); + bwdReqFound++; + } + } - if (hasBackwardDerivative) - { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + SLANG_RELEASE_ASSERT( + fwdReqFound == (hasForwardDerivative ? 1 : 0) && + bwdReqFound == (hasBackwardDerivative ? 1 : 0)); } + return true; } @@ -3706,7 +3716,8 @@ namespace Slang { if(isAssociatedTypeDecl(requiredMemberDeclRef)) continue; - + if (requiredMemberDeclRef.as<DerivativeRequirementDecl>()) + continue; auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, @@ -5617,7 +5628,7 @@ namespace Slang } decl->errorType = errorType; - if (isInterfaceRequirement(decl)) + if (auto interfaceDecl = findParentInterfaceDecl(decl)) { if (decl->hasModifier<ForwardDifferentiableAttribute>()) { @@ -5626,8 +5637,13 @@ namespace Slang auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } if (decl->hasModifier<BackwardDifferentiableAttribute>()) { @@ -5636,8 +5652,13 @@ namespace Slang auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } } } |
