summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-24 14:33:32 -0800
committerGitHub <noreply@github.com>2023-02-24 14:33:32 -0800
commit85c1569308793cc2408088e539a3ed1da5f9d235 (patch)
treeb5f4109fd50535cbd9c037422e91a70622389b3f /source/slang/slang-check-decl.cpp
parent91694dacdb8d3ab7dd9783d7c0c43629bf11f578 (diff)
Support dynamic dispatch a backward differentiable function. (#2678)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp69
1 files changed, 1 insertions, 68 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 142842e12..a1d5acfb0 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2677,24 +2677,6 @@ namespace Slang
val->func = satisfyingMemberDeclRef;
witnessTable->add(bwdReq, RequirementWitness(val));
}
- else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl))
- {
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(primalReq, RequirementWitness(val));
- }
- else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl))
- {
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(propReq, RequirementWitness(val));
- }
- else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl))
- {
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(itypeReq, RequirementWitness(val));
- }
}
witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef));
}
@@ -5920,7 +5902,7 @@ namespace Slang
if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
bool isDiffFunc = false;
- if (decl->hasModifier<ForwardDifferentiableAttribute>())
+ if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
cloneModifiers(reqDecl, decl);
@@ -5954,55 +5936,6 @@ namespace Slang
reqRef->parentDecl = decl;
decl->members.add(reqRef);
}
- // Requirement for backward derivative intermediate type.
- auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>();
- auto intermediateType = m_astBuilder->getOrCreateDeclRefType(
- intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl));
- {
- cloneModifiers(intermediateTypeReqDecl, decl);
- interfaceDecl->members.add(intermediateTypeReqDecl);
- intermediateTypeReqDecl->parentDecl = interfaceDecl;
-
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = intermediateTypeReqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
- }
- // Requirement for backward derivative primal func.
- {
- auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>();
- cloneModifiers(reqDecl, decl);
- FuncType* primalFuncType = m_astBuilder->create<FuncType>();
- primalFuncType->resultType = originalFuncType->resultType;
- primalFuncType->paramTypes.addRange(originalFuncType->paramTypes);
- auto outType = m_astBuilder->getOutType(intermediateType);
- primalFuncType->paramTypes.add(outType);
- setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType);
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
-
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
- }
- // Requirement for backward derivative propagate func.
- {
- auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>();
- cloneModifiers(reqDecl, decl);
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
- FuncType* propagateFuncType = m_astBuilder->create<FuncType>();
- propagateFuncType->resultType = diffFuncType->resultType;
- propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes);
- propagateFuncType->paramTypes.add(intermediateType);
- setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType);
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
- }
-
isDiffFunc = true;
}
if (isDiffFunc)