summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-11-22 18:55:47 -0500
committerGitHub <noreply@github.com>2024-11-22 23:55:47 +0000
commit9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 (patch)
tree735743bce54f0a4faf99925bc4582e3d0de8d5ea
parent95125f280a3ee6cad08866baedc41fee8585b91e (diff)
[AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630)
* [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred * Fix failing tests * Update custom-derivative-generic.slang
-rw-r--r--source/slang/slang-check-decl.cpp76
-rw-r--r--tests/autodiff/custom-derivative-enum-param.slang57
-rw-r--r--tests/autodiff/custom-intrinsic-1.slang (renamed from tests/autodiff/custom-intrinsic.slang)0
-rw-r--r--tests/autodiff/custom-intrinsic-1.slang.expected.txt (renamed from tests/autodiff/custom-intrinsic.slang.expected.txt)0
-rw-r--r--tests/diagnostics/custom-derivative-generic.slang2
5 files changed, 133 insertions, 2 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 251ce6a69..e4206827f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl(
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
- auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
+
+ auto exprToCheck = attr->funcExpr;
+
+ // If this is a generic, we want to wrap the call to the derivative method
+ // with the generic parameters of the source.
+ //
+ if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr))
+ {
+ auto genericDecl = as<GenericDecl>(funcDecl->parentDecl);
+ auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl);
+ auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>();
+
+ Index count = 0;
+ for (auto member : genericDecl->members)
+ {
+ if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) ||
+ as<GenericTypePackParamDecl>(member))
+ count++;
+ }
+
+ appExpr->functionExpr = attr->funcExpr;
+
+ for (auto arg : substArgs)
+ {
+ if (count == 0)
+ break;
+
+ if (auto declRefType = as<DeclRefType>(arg))
+ {
+ auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = declRefType;
+ auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType);
+ baseTypeExpr->type.type = baseTypeType;
+
+ appExpr->arguments.add(baseTypeExpr);
+ }
+ else if (auto genericValParam = as<GenericParamIntVal>(arg))
+ {
+ auto declRef = genericValParam->getDeclRef();
+ appExpr->arguments.add(
+ subVisitor
+ .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr));
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unhandled substitution arg type");
+ }
+
+ count--;
+ }
+
+ exprToCheck = appExpr;
+ }
+
+ auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx);
attr->funcExpr = checkedFuncExpr;
if (attr->args.getCount())
attr->args[0] = attr->funcExpr;
@@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl(
calleeDeclRef = calleeDeclRefExpr->declRef;
auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
+
+ if (!calleeFunc)
+ {
+ // If we couldn't find a direct function, it might be a generic.
+ if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
+ {
+ calleeFunc = as<FunctionDeclBase>(genericDecl->inner);
+
+ if (as<ErrorType>(resolved->type.type))
+ {
+ // If we can't resolve a type, something went wrong. If we're working with a generic
+ // decl, the most likely cause is a failure of generic argument inference.
+ //
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
+ }
+ }
+ }
+
if (!calleeFunc)
{
visitor->getSink()->diagnose(
diff --git a/tests/autodiff/custom-derivative-enum-param.slang b/tests/autodiff/custom-derivative-enum-param.slang
new file mode 100644
index 000000000..aa6733873
--- /dev/null
+++ b/tests/autodiff/custom-derivative-enum-param.slang
@@ -0,0 +1,57 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+
+enum MyEnum { A, B, C };
+
+[BackwardDerivative(mDiff)]
+float m<let M : MyEnum>(float x)
+{
+ switch (M)
+ {
+ case MyEnum.A:
+ return x * x;
+ case MyEnum.B:
+ return x;
+ case MyEnum.C:
+ return 3 * x;
+ default:
+ return 0;
+ }
+}
+
+void mDiff<let M : MyEnum>(inout DifferentialPair<float> x, float dResult)
+{
+ switch (M)
+ {
+ case MyEnum.A:
+ updateDiff(x, 2 * dResult * x.p);
+ break;
+ case MyEnum.B:
+ updateDiff(x, dResult);
+ break;
+ case MyEnum.C:
+ updateDiff(x, 3 * dResult);
+ break;
+ default:
+ updateDiff(x, 0);
+ break;
+ }
+}
+
+[Differentiable]
+float test(float x)
+{
+ return m<MyEnum.A>(x);
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var a = diffPair(3.0);
+ __bwd_diff(test)(a, 1.0);
+ outputBuffer[dispatchThreadID.x] = a.d;
+ // CHECK: 6.0
+}
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic-1.slang
index 1fe204b58..1fe204b58 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic-1.slang
diff --git a/tests/autodiff/custom-intrinsic.slang.expected.txt b/tests/autodiff/custom-intrinsic-1.slang.expected.txt
index ce22a5b95..ce22a5b95 100644
--- a/tests/autodiff/custom-intrinsic.slang.expected.txt
+++ b/tests/autodiff/custom-intrinsic-1.slang.expected.txt
diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang
index 5f2cd9951..fb65dd2cc 100644
--- a/tests/diagnostics/custom-derivative-generic.slang
+++ b/tests/diagnostics/custom-derivative-generic.slang
@@ -34,7 +34,7 @@ DifferentialPair<float> dd1(DifferentialPair<float> x)
}
// CHECK-DAG: {{.*}}(37): error 31151
-[BackwardDerivative(f)]
+[BackwardDerivativeOf(f)]
DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut)
{
var primal = x.p * x.p;