summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-23 09:39:08 -0800
committerGitHub <noreply@github.com>2022-11-23 09:39:08 -0800
commit97cb4851eed7a43f10196971b08d3d311386ce9f (patch)
tree99ba81368068b3345fa23b749108265aa753ed2b /source/slang/slang-check-decl.cpp
parent6178cb601368e977c4aa82e0ae25b8eb1e875d84 (diff)
Autodiff through simple dynamic dispatch. (#2527)
* Autodiff through simple dynamic dispatch. * Revert changes. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp108
1 files changed, 108 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5a1218abe..36a1061c9 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -107,6 +107,9 @@ namespace Slang
void visitAccessorDecl(AccessorDecl* decl);
void visitSetterDecl(SetterDecl* decl);
+
+ void cloneModifiers(Decl* dest, Decl* src);
+ void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType);
};
struct SemanticsDeclRedeclarationVisitor
@@ -1866,6 +1869,32 @@ namespace Slang
return false;
}
+ bool hasBackwardDerivative = false;
+ bool hasForwardDerivative = false;
+ if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ {
+ // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa.
+ return false;
+ }
+ hasBackwardDerivative = true;
+ hasForwardDerivative = true;
+ }
+ else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ {
+ // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa.
+ return false;
+ }
+ hasForwardDerivative = true;
+ }
+
// A signature matches the required one if it has the right number of parameters,
// and those parameters have the right types, and also the result/return type
// is the required one.
@@ -1896,6 +1925,24 @@ namespace Slang
witnessTable->add(
requiredMemberDeclRef.getDecl(),
RequirementWitness(satisfyingMemberDeclRef));
+
+ if (hasForwardDerivative)
+ {
+ auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>();
+ SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
+ ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
+ }
+
+ if (hasBackwardDerivative)
+ {
+ auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>();
+ SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
+ BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
+ }
return true;
}
@@ -5515,6 +5562,43 @@ namespace Slang
}
}
+ void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src)
+ {
+ dest->modifiers = src->modifiers;
+ }
+ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType)
+ {
+ if (!funcType)
+ return;
+ decl->returnType.type = funcType->getResultType();
+ decl->errorType.type = funcType->getErrorType();
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto paramType = funcType->getParamType(i);
+ if (auto dirType = as<ParamDirectionType>(paramType))
+ paramType = dirType->getValueType();
+ auto param = m_astBuilder->create<ParamDecl>();
+ param->type.type = paramType;
+ auto paramDir = funcType->getParamDirection(i);
+ switch (paramDir)
+ {
+ case ParameterDirection::kParameterDirection_InOut:
+ addModifier(param, m_astBuilder->create<InOutModifier>());
+ break;
+ case ParameterDirection::kParameterDirection_Out:
+ addModifier(param, m_astBuilder->create<OutModifier>());
+ break;
+ case ParameterDirection::kParameterDirection_Ref:
+ addModifier(param, m_astBuilder->create<RefModifier>());
+ break;
+ default:
+ break;
+ }
+ decl->members.add(param);
+ param->parentDecl = decl;
+ }
+ }
+
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
for(auto paramDecl : decl->getParameters())
@@ -5532,6 +5616,30 @@ namespace Slang
errorType = TypeExp(m_astBuilder->getBottomType());
}
decl->errorType = errorType;
+
+ if (isInterfaceRequirement(decl))
+ {
+ if (decl->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
+ setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
+ decl->members.add(reqDecl);
+ reqDecl->parentDecl = decl;
+ }
+ if (decl->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef));
+ setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
+ decl->members.add(reqDecl);
+ reqDecl->parentDecl = decl;
+ }
+ }
}
void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl)