diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-24 14:33:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-24 14:33:32 -0800 |
| commit | 85c1569308793cc2408088e539a3ed1da5f9d235 (patch) | |
| tree | b5f4109fd50535cbd9c037422e91a70622389b3f /source/slang/slang-check-decl.cpp | |
| parent | 91694dacdb8d3ab7dd9783d7c0c43629bf11f578 (diff) | |
Support dynamic dispatch a backward differentiable function. (#2678)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 69 |
1 files changed, 1 insertions, 68 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 142842e12..a1d5acfb0 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2677,24 +2677,6 @@ namespace Slang val->func = satisfyingMemberDeclRef; witnessTable->add(bwdReq, RequirementWitness(val)); } - else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(primalReq, RequirementWitness(val)); - } - else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(propReq, RequirementWitness(val)); - } - else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(itypeReq, RequirementWitness(val)); - } } witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); } @@ -5920,7 +5902,7 @@ namespace Slang if (auto interfaceDecl = findParentInterfaceDecl(decl)) { bool isDiffFunc = false; - if (decl->hasModifier<ForwardDifferentiableAttribute>()) + if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>()) { auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); cloneModifiers(reqDecl, decl); @@ -5954,55 +5936,6 @@ namespace Slang reqRef->parentDecl = decl; decl->members.add(reqRef); } - // Requirement for backward derivative intermediate type. - auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>(); - auto intermediateType = m_astBuilder->getOrCreateDeclRefType( - intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl)); - { - cloneModifiers(intermediateTypeReqDecl, decl); - interfaceDecl->members.add(intermediateTypeReqDecl); - intermediateTypeReqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = intermediateTypeReqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - // Requirement for backward derivative primal func. - { - auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>(); - cloneModifiers(reqDecl, decl); - FuncType* primalFuncType = m_astBuilder->create<FuncType>(); - primalFuncType->resultType = originalFuncType->resultType; - primalFuncType->paramTypes.addRange(originalFuncType->paramTypes); - auto outType = m_astBuilder->getOutType(intermediateType); - primalFuncType->paramTypes.add(outType); - setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - // Requirement for backward derivative propagate func. - { - auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>(); - cloneModifiers(reqDecl, decl); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - FuncType* propagateFuncType = m_astBuilder->create<FuncType>(); - propagateFuncType->resultType = diffFuncType->resultType; - propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes); - propagateFuncType->paramTypes.add(intermediateType); - setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType); - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - isDiffFunc = true; } if (isDiffFunc) |
