summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp34
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--tests/autodiff/overloaded-custom-deriv.slang41
-rw-r--r--tests/autodiff/overloaded-custom-deriv.slang.expected.txt2
4 files changed, 69 insertions, 11 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 3ffb6c100..63c7d9741 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6923,20 +6923,34 @@ namespace Slang
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
- if (auto derivFuncDeclRef = as<DeclRefExpr>(checkedFuncExpr)->declRef)
+ if (auto declRefExpr = as<DeclRefExpr>(checkedFuncExpr))
{
- visitor->ensureDecl(derivFuncDeclRef, DeclCheckState::TypesFullyResolved);
- auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
- auto resolved = subVisitor.ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::TypesFullyResolved);
+ }
+ else if (auto overloadedExpr = as<OverloadedExpr>(checkedFuncExpr))
+ {
+ for (auto candidate : overloadedExpr->lookupResult2.items)
{
- if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
- {
- attr->funcExpr = calleeDeclRef;
- return;
- }
+ visitor->ensureDecl(candidate.declRef, DeclCheckState::TypesFullyResolved);
}
}
+ else
+ {
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction);
+ return;
+ }
+
+ auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
+ auto resolved = subVisitor.ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+
visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 4a9b83c6c..128142d84 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -356,8 +356,9 @@ DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original defi
DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.")
DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.")
DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
+DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function")
-DIAGNOSTIC(31148, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.")
+DIAGNOSTIC(31149, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.")
// Enums
diff --git a/tests/autodiff/overloaded-custom-deriv.slang b/tests/autodiff/overloaded-custom-deriv.slang
new file mode 100644
index 000000000..81f91974f
--- /dev/null
+++ b/tests/autodiff/overloaded-custom-deriv.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[ForwardDerivative(diff_f)]
+float f(float v)
+{
+ return v * v;
+}
+
+DifferentialPair<float> diff_f(DifferentialPair<float> v)
+{
+ return diffPair(v.p * v.p, 2 * v.d * v.p);
+}
+
+[ForwardDerivative(diff_f)]
+float2 f(float2 v)
+{
+ return v * v;
+}
+
+DifferentialPair<float2> diff_f(DifferentialPair<float2> v)
+{
+ return diffPair(v.p * v.p, float2(2 * v.d.x * v.p.x, 0.0));
+}
+
+[ForwardDifferentiable]
+float test(float v)
+{
+ return f(v) + f(float2(v, v)).x;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var p = diffPair(3.0, 1.0);
+ let rs = __fwd_diff(test)(p);
+ outputBuffer[0] = rs.d;
+}
diff --git a/tests/autodiff/overloaded-custom-deriv.slang.expected.txt b/tests/autodiff/overloaded-custom-deriv.slang.expected.txt
new file mode 100644
index 000000000..7da9c9037
--- /dev/null
+++ b/tests/autodiff/overloaded-custom-deriv.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+12.000000