diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 09:39:08 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 09:39:08 -0800 |
| commit | 97cb4851eed7a43f10196971b08d3d311386ce9f (patch) | |
| tree | 99ba81368068b3345fa23b749108265aa753ed2b /source/slang/slang-check-decl.cpp | |
| parent | 6178cb601368e977c4aa82e0ae25b8eb1e875d84 (diff) | |
Autodiff through simple dynamic dispatch. (#2527)
* Autodiff through simple dynamic dispatch.
* Revert changes.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5a1218abe..36a1061c9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -107,6 +107,9 @@ namespace Slang void visitAccessorDecl(AccessorDecl* decl); void visitSetterDecl(SetterDecl* decl); + + void cloneModifiers(Decl* dest, Decl* src); + void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType); }; struct SemanticsDeclRedeclarationVisitor @@ -1866,6 +1869,32 @@ namespace Slang return false; } + bool hasBackwardDerivative = false; + bool hasForwardDerivative = false; + if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) + { + if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + { + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. + return false; + } + hasBackwardDerivative = true; + hasForwardDerivative = true; + } + else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) + { + if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + { + // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + return false; + } + hasForwardDerivative = true; + } + // A signature matches the required one if it has the right number of parameters, // and those parameters have the right types, and also the result/return type // is the required one. @@ -1896,6 +1925,24 @@ namespace Slang witnessTable->add( requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); + + if (hasForwardDerivative) + { + auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>(); + SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); + ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + } + + if (hasBackwardDerivative) + { + auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>(); + SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); + BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + } return true; } @@ -5515,6 +5562,43 @@ namespace Slang } } + void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src) + { + dest->modifiers = src->modifiers; + } + void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType) + { + if (!funcType) + return; + decl->returnType.type = funcType->getResultType(); + decl->errorType.type = funcType->getErrorType(); + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + auto paramType = funcType->getParamType(i); + if (auto dirType = as<ParamDirectionType>(paramType)) + paramType = dirType->getValueType(); + auto param = m_astBuilder->create<ParamDecl>(); + param->type.type = paramType; + auto paramDir = funcType->getParamDirection(i); + switch (paramDir) + { + case ParameterDirection::kParameterDirection_InOut: + addModifier(param, m_astBuilder->create<InOutModifier>()); + break; + case ParameterDirection::kParameterDirection_Out: + addModifier(param, m_astBuilder->create<OutModifier>()); + break; + case ParameterDirection::kParameterDirection_Ref: + addModifier(param, m_astBuilder->create<RefModifier>()); + break; + default: + break; + } + decl->members.add(param); + param->parentDecl = decl; + } + } + void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { for(auto paramDecl : decl->getParameters()) @@ -5532,6 +5616,30 @@ namespace Slang errorType = TypeExp(m_astBuilder->getBottomType()); } decl->errorType = errorType; + + if (isInterfaceRequirement(decl)) + { + if (decl->hasModifier<ForwardDifferentiableAttribute>()) + { + auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); + cloneModifiers(reqDecl, decl); + auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); + decl->members.add(reqDecl); + reqDecl->parentDecl = decl; + } + if (decl->hasModifier<BackwardDifferentiableAttribute>()) + { + auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); + cloneModifiers(reqDecl, decl); + auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); + decl->members.add(reqDecl); + reqDecl->parentDecl = decl; + } + } } void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl) |
