diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-06 13:39:06 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-06 13:39:06 -0800 |
| commit | 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch) | |
| tree | 318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-check-decl.cpp | |
| parent | e70cbe76ce74769069b7384f5f05c62da1ca45ed (diff) | |
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func.
* Fix.
* Download swiftshader with github actions instead of curl on linux.
* Fix github action.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 94 |
1 files changed, 82 insertions, 12 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 80bf74e53..7c8e320c4 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2665,10 +2665,28 @@ namespace Slang } else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) { - BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); val->func = satisfyingMemberDeclRef; witnessTable->add(bwdReq, RequirementWitness(val)); } + else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(primalReq, RequirementWitness(val)); + } + else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(propReq, RequirementWitness(val)); + } + else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(itypeReq, RequirementWitness(val)); + } } witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); } @@ -5652,18 +5670,70 @@ namespace Slang } if (decl->hasModifier<BackwardDifferentiableAttribute>()) { - auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); - cloneModifiers(reqDecl, decl); + // Requirement for backward derivative. auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); - auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); - setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef))); + { + auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); + cloneModifiers(reqDecl, decl); + setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative intermediate type. + auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>(); + auto intermediateType = m_astBuilder->getOrCreateDeclRefType( + intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl)); + { + cloneModifiers(intermediateTypeReqDecl, decl); + interfaceDecl->members.add(intermediateTypeReqDecl); + intermediateTypeReqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = intermediateTypeReqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative primal func. + { + auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>(); + cloneModifiers(reqDecl, decl); + FuncType* primalFuncType = m_astBuilder->create<FuncType>(); + primalFuncType->resultType = diffFuncType->resultType; + primalFuncType->paramTypes.addRange(diffFuncType->paramTypes); + auto outType = m_astBuilder->getOutType(intermediateType); + primalFuncType->paramTypes.add(outType); + setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative propagate func. + { + auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>(); + cloneModifiers(reqDecl, decl); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + FuncType* propagateFuncType = m_astBuilder->create<FuncType>(); + propagateFuncType->resultType = diffFuncType->resultType; + propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes); + propagateFuncType->paramTypes.add(intermediateType); + setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType); + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + isDiffFunc = true; } if (isDiffFunc) |
