From 9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 22 Nov 2024 18:55:47 -0500 Subject: [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630) * [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred * Fix failing tests * Update custom-derivative-generic.slang --- source/slang/slang-check-decl.cpp | 76 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) (limited to 'source') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 251ce6a69..e4206827f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl( SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + + auto exprToCheck = attr->funcExpr; + + // If this is a generic, we want to wrap the call to the derivative method + // with the generic parameters of the source. + // + if (as(funcDecl->parentDecl) && !as(attr->funcExpr)) + { + auto genericDecl = as(funcDecl->parentDecl); + auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl); + auto appExpr = ctx.getASTBuilder()->create(); + + Index count = 0; + for (auto member : genericDecl->members) + { + if (as(member) || as(member) || + as(member)) + count++; + } + + appExpr->functionExpr = attr->funcExpr; + + for (auto arg : substArgs) + { + if (count == 0) + break; + + if (auto declRefType = as(arg)) + { + auto baseTypeExpr = ctx.getASTBuilder()->create(); + baseTypeExpr->base.type = declRefType; + auto baseTypeType = ctx.getASTBuilder()->getOrCreate(declRefType); + baseTypeExpr->type.type = baseTypeType; + + appExpr->arguments.add(baseTypeExpr); + } + else if (auto genericValParam = as(arg)) + { + auto declRef = genericValParam->getDeclRef(); + appExpr->arguments.add( + subVisitor + .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr)); + } + else + { + SLANG_UNEXPECTED("Unhandled substitution arg type"); + } + + count--; + } + + exprToCheck = appExpr; + } + + auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx); attr->funcExpr = checkedFuncExpr; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl( calleeDeclRef = calleeDeclRefExpr->declRef; auto calleeFunc = as(calleeDeclRef.getDecl()); + + if (!calleeFunc) + { + // If we couldn't find a direct function, it might be a generic. + if (auto genericDecl = as(calleeDeclRef.getDecl())) + { + calleeFunc = as(genericDecl->inner); + + if (as(resolved->type.type)) + { + // If we can't resolve a type, something went wrong. If we're working with a generic + // decl, the most likely cause is a failure of generic argument inference. + // + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + } + } + } + if (!calleeFunc) { visitor->getSink()->diagnose( -- cgit v1.2.3