summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-06 13:39:06 -0800
committerGitHub <noreply@github.com>2023-01-06 13:39:06 -0800
commit33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch)
tree318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-check-decl.cpp
parente70cbe76ce74769069b7384f5f05c62da1ca45ed (diff)
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp94
1 files changed, 82 insertions, 12 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 80bf74e53..7c8e320c4 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2665,10 +2665,28 @@ namespace Slang
}
else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
{
- BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
val->func = satisfyingMemberDeclRef;
witnessTable->add(bwdReq, RequirementWitness(val));
}
+ else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(primalReq, RequirementWitness(val));
+ }
+ else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(propReq, RequirementWitness(val));
+ }
+ else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(itypeReq, RequirementWitness(val));
+ }
}
witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef));
}
@@ -5652,18 +5670,70 @@ namespace Slang
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
- auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
- cloneModifiers(reqDecl, decl);
+ // Requirement for backward derivative.
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
- auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef));
- setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
-
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
+ auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)));
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+ // Requirement for backward derivative intermediate type.
+ auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>();
+ auto intermediateType = m_astBuilder->getOrCreateDeclRefType(
+ intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ {
+ cloneModifiers(intermediateTypeReqDecl, decl);
+ interfaceDecl->members.add(intermediateTypeReqDecl);
+ intermediateTypeReqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = intermediateTypeReqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+ // Requirement for backward derivative primal func.
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ FuncType* primalFuncType = m_astBuilder->create<FuncType>();
+ primalFuncType->resultType = diffFuncType->resultType;
+ primalFuncType->paramTypes.addRange(diffFuncType->paramTypes);
+ auto outType = m_astBuilder->getOutType(intermediateType);
+ primalFuncType->paramTypes.add(outType);
+ setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType);
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+ // Requirement for backward derivative propagate func.
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+ FuncType* propagateFuncType = m_astBuilder->create<FuncType>();
+ propagateFuncType->resultType = diffFuncType->resultType;
+ propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes);
+ propagateFuncType->paramTypes.add(intermediateType);
+ setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType);
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+
isDiffFunc = true;
}
if (isDiffFunc)