summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-09 19:19:17 -0800
committerGitHub <noreply@github.com>2022-11-09 19:19:17 -0800
commit004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch)
treecbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-check-decl.cpp
parentcedd93690c63188cf98e452c9d104cf51aad6c4e (diff)
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp95
1 files changed, 95 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 333e9d973..b33c33e7a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -256,6 +256,11 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context);
+
+ void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl);
+
+ void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr);
+
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -4582,11 +4587,101 @@ namespace Slang
}
}
+ void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ {
+ auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
+ if (!attr)
+ return;
+
+ List<Expr*> imaginaryArguments;
+ for (auto param : funcDecl->getParameters())
+ {
+ auto arg = m_astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = attr->loc;
+ if (auto pairType = as<DifferentialPairType>(param->getType()))
+ {
+ arg->type.type = pairType->getPrimalType();
+ }
+ imaginaryArguments.add(arg);
+ }
+ auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto resolved = ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>())
+ {
+ // The primal function already has a `[ForwardDerivative]` attribute, this is invalid.
+ getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]");
+ getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ }
+ attr->funcExpr = calleeDeclRef;
+ auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>();
+ fwdDerivativeAttr->loc = attr->loc;
+ auto outterGeneric = GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ fwdDerivativeAttr->args.add(declRefExpr);
+ fwdDerivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr);
+ attr->backDeclRef = fwdDerivativeAttr->funcExpr;
+ fwdDerivativeAttr->funcExpr = nullptr;
+ return;
+ }
+ }
+ getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
+ void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments;
+ for (auto param : funcDecl->getParameters())
+ {
+ auto arg = m_astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = attr->loc;
+ if (auto pairType = getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
+ imaginaryArguments.add(arg);
+ }
+ auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto resolved = ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+ getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
_maybeRegisterDifferentialBottomTypeConformance(newContext);
+ // Run checking on attributes that can't be fully checked in header checking stage.
+ checkDerivativeOfAttribute(decl);
+ if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
+ checkDerivativeAttribute(decl, derivativeAttr);
+
if (auto body = decl->body)
{
checkStmt(decl->body, newContext);