From 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 6 Jan 2023 13:39:06 -0800 Subject: Split bwd_diff op into separate ops for primal and propagate func. (#2582) * Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 94 ++++++++++++++++++++++++++++++++++----- 1 file changed, 82 insertions(+), 12 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 80bf74e53..7c8e320c4 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2665,10 +2665,28 @@ namespace Slang } else if (auto bwdReq = as(reqRefDecl->referencedDecl)) { - BackwardDifferentiateVal* val = m_astBuilder->create(); + DifferentiateVal* val = m_astBuilder->create(); val->func = satisfyingMemberDeclRef; witnessTable->add(bwdReq, RequirementWitness(val)); } + else if (auto primalReq = as(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(primalReq, RequirementWitness(val)); + } + else if (auto propReq = as(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(propReq, RequirementWitness(val)); + } + else if (auto itypeReq = as(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(itypeReq, RequirementWitness(val)); + } } witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); } @@ -5652,18 +5670,70 @@ namespace Slang } if (decl->hasModifier()) { - auto reqDecl = m_astBuilder->create(); - cloneModifiers(reqDecl, decl); + // Requirement for backward derivative. auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); - auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); - setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + auto diffFuncType = as(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef))); + { + auto reqDecl = m_astBuilder->create(); + cloneModifiers(reqDecl, decl); + setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative intermediate type. + auto intermediateTypeReqDecl = m_astBuilder->create(); + auto intermediateType = m_astBuilder->getOrCreateDeclRefType( + intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl)); + { + cloneModifiers(intermediateTypeReqDecl, decl); + interfaceDecl->members.add(intermediateTypeReqDecl); + intermediateTypeReqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = intermediateTypeReqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative primal func. + { + auto reqDecl = m_astBuilder->create(); + cloneModifiers(reqDecl, decl); + FuncType* primalFuncType = m_astBuilder->create(); + primalFuncType->resultType = diffFuncType->resultType; + primalFuncType->paramTypes.addRange(diffFuncType->paramTypes); + auto outType = m_astBuilder->getOutType(intermediateType); + primalFuncType->paramTypes.add(outType); + setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative propagate func. + { + auto reqDecl = m_astBuilder->create(); + cloneModifiers(reqDecl, decl); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + FuncType* propagateFuncType = m_astBuilder->create(); + propagateFuncType->resultType = diffFuncType->resultType; + propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes); + propagateFuncType->paramTypes.add(intermediateType); + setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType); + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + isDiffFunc = true; } if (isDiffFunc) -- cgit v1.2.3