From 4ad0470025da4e808c46023f9a2525febcf973a2 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 23 Nov 2022 16:02:56 -0800 Subject: Fix issues around dynamic generic function and autodiff. (#2528) * Fix issues around dynamic generic function and autodiff. * Fix return type issue. * Fix type unification for generic `inout` parameter. * Fix. Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 61 ++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 20 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 36a1061c9..4d2839b8d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1926,23 +1926,33 @@ namespace Slang requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); - if (hasForwardDerivative) + if (hasForwardDerivative || hasBackwardDerivative) { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - ForwardDifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); - } + int fwdReqFound = 0; + int bwdReqFound = 0; + for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType()) + { + if (auto fwdReq = as(reqRefDecl->referencedDecl)) + { + ForwardDifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(fwdReq, RequirementWitness(val)); + fwdReqFound++; + } + else if (auto bwdReq = as(reqRefDecl->referencedDecl)) + { + BackwardDifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(bwdReq, RequirementWitness(val)); + bwdReqFound++; + } + } - if (hasBackwardDerivative) - { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - BackwardDifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + SLANG_RELEASE_ASSERT( + fwdReqFound == (hasForwardDerivative ? 1 : 0) && + bwdReqFound == (hasBackwardDerivative ? 1 : 0)); } + return true; } @@ -3706,7 +3716,8 @@ namespace Slang { if(isAssociatedTypeDecl(requiredMemberDeclRef)) continue; - + if (requiredMemberDeclRef.as()) + continue; auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, @@ -5617,7 +5628,7 @@ namespace Slang } decl->errorType = errorType; - if (isInterfaceRequirement(decl)) + if (auto interfaceDecl = findParentInterfaceDecl(decl)) { if (decl->hasModifier()) { @@ -5626,8 +5637,13 @@ namespace Slang auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } if (decl->hasModifier()) { @@ -5636,8 +5652,13 @@ namespace Slang auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } } } -- cgit v1.2.3