diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-11-22 18:55:47 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-22 23:55:47 +0000 |
| commit | 9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 (patch) | |
| tree | 735743bce54f0a4faf99925bc4582e3d0de8d5ea /source/slang/slang-check-decl.cpp | |
| parent | 95125f280a3ee6cad08866baedc41fee8585b91e (diff) | |
[AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630)
* [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred
* Fix failing tests
* Update custom-derivative-generic.slang
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 76 |
1 files changed, 75 insertions, 1 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 251ce6a69..e4206827f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl( SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + + auto exprToCheck = attr->funcExpr; + + // If this is a generic, we want to wrap the call to the derivative method + // with the generic parameters of the source. + // + if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr)) + { + auto genericDecl = as<GenericDecl>(funcDecl->parentDecl); + auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl); + auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>(); + + Index count = 0; + for (auto member : genericDecl->members) + { + if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) || + as<GenericTypePackParamDecl>(member)) + count++; + } + + appExpr->functionExpr = attr->funcExpr; + + for (auto arg : substArgs) + { + if (count == 0) + break; + + if (auto declRefType = as<DeclRefType>(arg)) + { + auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>(); + baseTypeExpr->base.type = declRefType; + auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType); + baseTypeExpr->type.type = baseTypeType; + + appExpr->arguments.add(baseTypeExpr); + } + else if (auto genericValParam = as<GenericParamIntVal>(arg)) + { + auto declRef = genericValParam->getDeclRef(); + appExpr->arguments.add( + subVisitor + .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr)); + } + else + { + SLANG_UNEXPECTED("Unhandled substitution arg type"); + } + + count--; + } + + exprToCheck = appExpr; + } + + auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx); attr->funcExpr = checkedFuncExpr; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl( calleeDeclRef = calleeDeclRefExpr->declRef; auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); + + if (!calleeFunc) + { + // If we couldn't find a direct function, it might be a generic. + if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl())) + { + calleeFunc = as<FunctionDeclBase>(genericDecl->inner); + + if (as<ErrorType>(resolved->type.type)) + { + // If we can't resolve a type, something went wrong. If we're working with a generic + // decl, the most likely cause is a failure of generic argument inference. + // + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + } + } + } + if (!calleeFunc) { visitor->getSink()->diagnose( |
