summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-13 11:48:54 -0800
committerGitHub <noreply@github.com>2023-01-13 11:48:54 -0800
commit4adc64f2a033ec141df6a16e65131612b30fb23b (patch)
tree31e4fabbfcac5e59ee334acb2be0f1df2542d679 /source/slang/slang-check-decl.cpp
parent63b874dab2df8950a37e0861d24f322e0ab9bfda (diff)
Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. (#2589)
* Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. * Fix clang issue. * Fix. * fix gcc issue * fix formatting. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp287
1 files changed, 236 insertions, 51 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index b8732a67f..f016ae3d8 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -259,10 +259,9 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
- void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl);
-
- void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr);
+ void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
+ void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -4668,90 +4667,273 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
- void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ template<typename TDerivativeAttr>
+ void checkDerivativeAttributeImpl(
+ SemanticsVisitor* visitor,
+ TDerivativeAttr* attr,
+ const List<Expr*>& imaginaryArguments)
{
- auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
- if (!attr)
- return;
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+ visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
+ template<typename TDerivativeAttr>
+ const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); }
+
+ template<>
+ const char* getDerivativeAttrName<ForwardDerivativeAttribute>()
+ {
+ return "ForwardDerivative";
+ }
+ template<>
+ const char* getDerivativeAttrName<BackwardDerivativeAttribute>()
+ {
+ return "BackwardDerivative";
+ }
+ List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
List<Expr*> imaginaryArguments;
- for (auto param : funcDecl->getParameters())
+ for (auto param : originalFuncDecl->getParameters())
{
- auto arg = m_astBuilder->create<VarExpr>();
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
- arg->loc = attr->loc;
+ arg->loc = loc;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : fwdDiffFunc->getParameters())
+ {
+ auto arg = astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
if (auto pairType = as<DifferentialPairType>(param->getType()))
{
arg->type.type = pairType->getPrimalType();
}
imaginaryArguments.add(arg);
}
- auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
- auto resolved = ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : originalFuncDecl->getParameters())
{
- if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
{
- if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>())
+ arg->type.type = pairType;
+ if (auto diffPairType = as<DifferentialPairType>(pairType))
{
- // The primal function already has a `[ForwardDerivative]` attribute, this is invalid.
- getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]");
- getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
+ {
+ arg->type.isLeftValue = false;
+ arg->type.type = diffPairType->getPrimalType();
+ }
}
- attr->funcExpr = calleeDeclRef;
- auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>();
- fwdDerivativeAttr->loc = attr->loc;
- auto outterGeneric = GetOuterGeneric(funcDecl);
- auto declRef =
- DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
- auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr);
- declRefExpr->type.type = nullptr;
- fwdDerivativeAttr->args.add(declRefExpr);
- fwdDerivativeAttr->funcExpr = declRefExpr;
- checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr);
- attr->backDeclRef = fwdDerivativeAttr->funcExpr;
- fwdDerivativeAttr->funcExpr = nullptr;
- getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl);
- return;
}
+ imaginaryArguments.add(arg);
+ }
+ if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
+ {
+ auto arg = visitor->getASTBuilder()->create<InitializerListExpr>();
+ arg->type.isLeftValue = false;
+ arg->type.type = diffReturnType;
+ arg->loc = loc;
+ imaginaryArguments.add(arg);
}
- getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ return imaginaryArguments;
}
- void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc)
{
- if (!attr->funcExpr)
- return;
- if (attr->funcExpr->type.type)
- return;
-
List<Expr*> imaginaryArguments;
- for (auto param : funcDecl->getParameters())
+ for (auto param : bwdDiffFunc->getParameters())
{
- auto arg = m_astBuilder->create<VarExpr>();
+ auto arg = astBuilder->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
- arg->loc = attr->loc;
- if (auto pairType = getDifferentialPairType(param->getType()))
+ arg->loc = loc;
+ if (auto pairType = as<DifferentialPairType>(param->getType()))
{
- arg->type.type = pairType;
+ if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
+ {
+ arg->type.isLeftValue = false;
+ }
+ arg->type.type = pairType->getPrimalType();
}
imaginaryArguments.add(arg);
}
- auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
- auto resolved = ResolveInvoke(invokeExpr);
+ return imaginaryArguments;
+ }
+
+ // This helper function is needed to workaround a gcc bug.
+ // Remove when we upgrade to a newer version of gcc.
+ template <typename T>
+ static T* _findModifier(Decl* decl)
+ {
+ return decl->findModifier<T>();
+ }
+
+ template <typename TDerivativeAttr, typename TDerivativeOfAttr>
+ void checkDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind,
+ const List<Expr*>& imaginaryArgsToOriginal)
+ {
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
{
if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
{
- attr->funcExpr = calleeDeclRef;
+ auto calleeDecl = calleeDeclRef->declRef.getDecl();
+ if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeDecl))
+ {
+ // The primal function already has a `[*Derivative]` attribute, this is invalid.
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::declAlreadyHasAttribute,
+ calleeDeclRef->declRef,
+ getDerivativeAttrName<TDerivativeAttr>());
+ visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ }
+ derivativeOfAttr->funcExpr = calleeDeclRef;
+ auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
+ derivativeAttr->loc = derivativeOfAttr->loc;
+ auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ derivativeAttr->args.add(declRefExpr);
+ derivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(visitor, as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), derivativeAttr);
+ derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
+ derivativeAttr->funcExpr = nullptr;
+ visitor->getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), assocKind, funcDecl);
return;
}
}
- getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::invalidCustomDerivative);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ template<typename TDerivativeAttr, typename TDerivativeOfAttr>
+ bool tryCheckDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind,
+ const List<Expr*>& imaginaryArgsToOriginal)
+ {
+ DiagnosticSink tempSink(visitor->getSourceManager(), nullptr);
+ SemanticsVisitor subVisitor(visitor->withSink(&tempSink));
+ checkDerivativeOfAttributeImpl<TDerivativeAttr>(
+ &subVisitor,
+ funcDecl,
+ derivativeOfAttr,
+ assocKind,
+ imaginaryArgsToOriginal);
+ return tempSink.getErrorCount() == 0;
+ }
+
+ void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ {
+ auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
+ if (!attr)
+ return;
+
+ List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
+ checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>(
+ this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal);
+ }
+
+ void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ {
+ auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>();
+ if (!attr)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
+
+ // The tricky part here is that we can't easily derive the arguments to original func just
+ // from the definition of a backward derivative function, because we don't know if the last
+ // parameter is just a normal parameter of the original func, or if it is the additional
+ // derivative of the return value. The solution here is to try to resolve the original
+ // function with or without the last argument. However if the type of the last argument
+ // isn't differentiable, we know that it can't possibly be the result derivative.
+
+ if (imaginaryArguments.getCount() == 0 ||
+ !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type))
+ {
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
+ return;
+ }
+
+ // Otherwise, try resolve with all the arguments, if failed, resolve without the last
+ // argument.
+ if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments))
+ {
+ return;
+ }
+
+ imaginaryArguments.removeLast();
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
}
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
@@ -4759,9 +4941,12 @@ namespace Slang
auto newContext = withParentFunc(decl);
// Run checking on attributes that can't be fully checked in header checking stage.
- checkDerivativeOfAttribute(decl);
+ checkForwardDerivativeOfAttribute(decl);
if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
- checkDerivativeAttribute(decl, derivativeAttr);
+ checkDerivativeAttribute(this, decl, derivativeAttr);
+ checkBackwardDerivativeOfAttribute(decl);
+ if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>())
+ checkDerivativeAttribute(this, decl, derivativeAttr);
if (newContext.getParentDifferentiableAttribute())
{