summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-modifier.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-09 19:19:17 -0800
committerGitHub <noreply@github.com>2022-11-09 19:19:17 -0800
commit004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch)
treecbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-check-modifier.cpp
parentcedd93690c63188cf98e452c9d104cf51aad6c4e (diff)
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
-rw-r--r--source/slang/slang-check-modifier.cpp100
1 files changed, 19 insertions, 81 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index d8b05198c..b8ac21e2d 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -617,92 +617,30 @@ namespace Slang
getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
return false;
}
-
- // Either diffExpr has a function type, or it is a reference to a generic.
- if (!as<FuncType>(diffExpr->type) &&
- !(as<DeclRefExpr>(diffExpr) &&
- as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr))
- {
- return false;
- }
-
- auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef;
-
- UCount genericLevels = 0;
- // If we've grabbed the outer generic for some reason,
- // recursively construct GenericAppExpr<...>(generic)
- // and check that to get a specialized func.
- //
- while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr)
- {
- // Forward to the inner decl
- diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner);
-
- // Increment counter.
- genericLevels += 1;
- }
-
- auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl);
- auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl);
- Expr* currentDiffExpr = diffExpr;
-
- // Go back through each level, and use generic declarations in the
- // target's generic scope as arguments for the diff function's generic.
+ // We store the partially checked funcExpr in the attribute, and
+ // rely on `ResolveInvoke` to resolve it to the actual function decl.
+ // The call to `ResolveInvoke` is deferred until we are checking the
+ // body of the function.
//
- for (UIndex ii = 0; ii < genericLevels; ii++)
- {
- // Nest our expression inside a GenericAppExpr
- auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>();
- genericAppExpr->functionExpr = currentDiffExpr;
-
- // Construct references to the generic args in the current scope.
- // TODO: Probably an easier way to do this.
- for (auto member : targetGeneric->members)
- {
- if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
- {
- genericAppExpr->arguments.add(
- ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr));
- }
- else if (auto valueParamDecl = as<GenericValueParamDecl>(member))
- {
- genericAppExpr->arguments.add(
- ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr));
- }
- }
-
- // Set our generic-app-expr as the new expr.
- currentDiffExpr = genericAppExpr;
-
- // Peel the generic layer.
- diffGeneric = as<GenericDecl>(diffGeneric->parentDecl);
- targetGeneric = as<GenericDecl>(targetGeneric->parentDecl);
- }
-
- if ((diffGeneric == nullptr && targetGeneric != nullptr) ||
- (targetGeneric == nullptr && diffGeneric != nullptr))
- {
- //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget);
- SLANG_UNEXPECTED("");
- }
-
- // If we had to change currentDiffExpr, then re-check the expr.
- if (!currentDiffExpr->type)
- {
- currentDiffExpr = CheckTerm(currentDiffExpr);
- }
+ // Set type to null to indicate that this needs expr needs to be further resolved.
+ diffExpr->type.type = nullptr;
+ forwardDerivativeAttr->funcExpr = diffExpr;
+ }
+ else if (auto forwardDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ SLANG_ASSERT(as<Decl>(attrTarget));
// Ensure that the argument is a reference to a function definition or declaration.
- auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr);
- auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef;
-
- if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc())))
+ auto primalFunc = CheckTerm(attr->args[0]);
+ if (primalFunc->type == getASTBuilder()->getErrorType())
{
- getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef);
+ // Could not resolve the term.
+ getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
+ return false;
}
-
- // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr)
- forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
+
+ forwardDerivativeOfAttr->funcExpr = primalFunc;
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{