summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-modifier.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-26 22:21:29 -0400
committerGitHub <noreply@github.com>2022-10-26 19:21:29 -0700
commitf7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch)
tree574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-check-modifier.cpp
parent939be44ca23476e622dfb24a592383fe2a1da61f (diff)
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
-rw-r--r--source/slang/slang-check-modifier.cpp97
1 files changed, 93 insertions, 4 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 7e11ee3ca..20e5d5378 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -628,13 +628,102 @@ namespace Slang
else if (auto customJVPAttr = as<CustomJVPAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
-
+ SLANG_ASSERT(as<Decl>(attrTarget));
+
// Ensure that the argument is a reference to a function definition or declaration.
- auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0]));
- if (!as<FuncType>(funcExpr->type))
+ auto diffExpr = CheckTerm(attr->args[0]);
+ if (diffExpr->type == getASTBuilder()->getErrorType())
+ {
+ // Could not resolve the term.
+ getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
+ return false;
+ }
+
+ // Either diffExpr has a function type, or it is a reference to a generic.
+ if (!as<FuncType>(diffExpr->type) &&
+ !(as<DeclRefExpr>(diffExpr) &&
+ as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr))
+ {
return false;
+ }
- customJVPAttr->funcDeclRef = funcExpr;
+ auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef;
+
+ UCount genericLevels = 0;
+ // If we've grabbed the outer generic for some reason,
+ // recursively construct GenericAppExpr<...>(generic)
+ // and check that to get a specialized func.
+ //
+ while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr)
+ {
+ // Forward to the inner decl
+ diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner);
+
+ // Increment counter.
+ genericLevels += 1;
+ }
+
+ auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl);
+ auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl);
+ Expr* currentDiffExpr = diffExpr;
+
+ // Go back through each level, and use generic declarations in the
+ // target's generic scope as arguments for the diff function's generic.
+ //
+ for (UIndex ii = 0; ii < genericLevels; ii++)
+ {
+ // Nest our expression inside a GenericAppExpr
+ auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>();
+ genericAppExpr->functionExpr = currentDiffExpr;
+
+ // Construct references to the generic args in the current scope.
+ // TODO: Probably an easier way to do this.
+ for (auto member : targetGeneric->members)
+ {
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
+ {
+ genericAppExpr->arguments.add(
+ ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr));
+ }
+ else if (auto valueParamDecl = as<GenericValueParamDecl>(member))
+ {
+ genericAppExpr->arguments.add(
+ ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr));
+ }
+ }
+
+ // Set our generic-app-expr as the new expr.
+ currentDiffExpr = genericAppExpr;
+
+ // Peel the generic layer.
+ diffGeneric = as<GenericDecl>(diffGeneric->parentDecl);
+ targetGeneric = as<GenericDecl>(targetGeneric->parentDecl);
+ }
+
+ if ((diffGeneric == nullptr && targetGeneric != nullptr) ||
+ (targetGeneric == nullptr && diffGeneric != nullptr))
+ {
+ //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget);
+ SLANG_UNEXPECTED("");
+ }
+
+ // If we had to change currentDiffExpr, then re-check the expr.
+ if (!currentDiffExpr->type)
+ {
+ currentDiffExpr = CheckTerm(currentDiffExpr);
+ }
+
+ // Ensure that the argument is a reference to a function definition or declaration.
+ auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr);
+ auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef;
+
+ if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc())))
+ {
+ getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef);
+ }
+
+ // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr)
+ customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{