summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp76
1 files changed, 75 insertions, 1 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 251ce6a69..e4206827f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl(
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
- auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
+
+ auto exprToCheck = attr->funcExpr;
+
+ // If this is a generic, we want to wrap the call to the derivative method
+ // with the generic parameters of the source.
+ //
+ if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr))
+ {
+ auto genericDecl = as<GenericDecl>(funcDecl->parentDecl);
+ auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl);
+ auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>();
+
+ Index count = 0;
+ for (auto member : genericDecl->members)
+ {
+ if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) ||
+ as<GenericTypePackParamDecl>(member))
+ count++;
+ }
+
+ appExpr->functionExpr = attr->funcExpr;
+
+ for (auto arg : substArgs)
+ {
+ if (count == 0)
+ break;
+
+ if (auto declRefType = as<DeclRefType>(arg))
+ {
+ auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = declRefType;
+ auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType);
+ baseTypeExpr->type.type = baseTypeType;
+
+ appExpr->arguments.add(baseTypeExpr);
+ }
+ else if (auto genericValParam = as<GenericParamIntVal>(arg))
+ {
+ auto declRef = genericValParam->getDeclRef();
+ appExpr->arguments.add(
+ subVisitor
+ .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr));
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unhandled substitution arg type");
+ }
+
+ count--;
+ }
+
+ exprToCheck = appExpr;
+ }
+
+ auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx);
attr->funcExpr = checkedFuncExpr;
if (attr->args.getCount())
attr->args[0] = attr->funcExpr;
@@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl(
calleeDeclRef = calleeDeclRefExpr->declRef;
auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
+
+ if (!calleeFunc)
+ {
+ // If we couldn't find a direct function, it might be a generic.
+ if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
+ {
+ calleeFunc = as<FunctionDeclBase>(genericDecl->inner);
+
+ if (as<ErrorType>(resolved->type.type))
+ {
+ // If we can't resolve a type, something went wrong. If we're working with a generic
+ // decl, the most likely cause is a failure of generic argument inference.
+ //
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
+ }
+ }
+ }
+
if (!calleeFunc)
{
visitor->getSink()->diagnose(