summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-11-22 18:55:47 -0500
committerGitHub <noreply@github.com>2024-11-22 23:55:47 +0000
commit9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 (patch)
tree735743bce54f0a4faf99925bc4582e3d0de8d5ea /tests
parent95125f280a3ee6cad08866baedc41fee8585b91e (diff)
[AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630)
* [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred * Fix failing tests * Update custom-derivative-generic.slang
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/custom-derivative-enum-param.slang57
-rw-r--r--tests/autodiff/custom-intrinsic-1.slang (renamed from tests/autodiff/custom-intrinsic.slang)0
-rw-r--r--tests/autodiff/custom-intrinsic-1.slang.expected.txt (renamed from tests/autodiff/custom-intrinsic.slang.expected.txt)0
-rw-r--r--tests/diagnostics/custom-derivative-generic.slang2
4 files changed, 58 insertions, 1 deletions
diff --git a/tests/autodiff/custom-derivative-enum-param.slang b/tests/autodiff/custom-derivative-enum-param.slang
new file mode 100644
index 000000000..aa6733873
--- /dev/null
+++ b/tests/autodiff/custom-derivative-enum-param.slang
@@ -0,0 +1,57 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+
+enum MyEnum { A, B, C };
+
+[BackwardDerivative(mDiff)]
+float m<let M : MyEnum>(float x)
+{
+ switch (M)
+ {
+ case MyEnum.A:
+ return x * x;
+ case MyEnum.B:
+ return x;
+ case MyEnum.C:
+ return 3 * x;
+ default:
+ return 0;
+ }
+}
+
+void mDiff<let M : MyEnum>(inout DifferentialPair<float> x, float dResult)
+{
+ switch (M)
+ {
+ case MyEnum.A:
+ updateDiff(x, 2 * dResult * x.p);
+ break;
+ case MyEnum.B:
+ updateDiff(x, dResult);
+ break;
+ case MyEnum.C:
+ updateDiff(x, 3 * dResult);
+ break;
+ default:
+ updateDiff(x, 0);
+ break;
+ }
+}
+
+[Differentiable]
+float test(float x)
+{
+ return m<MyEnum.A>(x);
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var a = diffPair(3.0);
+ __bwd_diff(test)(a, 1.0);
+ outputBuffer[dispatchThreadID.x] = a.d;
+ // CHECK: 6.0
+}
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic-1.slang
index 1fe204b58..1fe204b58 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic-1.slang
diff --git a/tests/autodiff/custom-intrinsic.slang.expected.txt b/tests/autodiff/custom-intrinsic-1.slang.expected.txt
index ce22a5b95..ce22a5b95 100644
--- a/tests/autodiff/custom-intrinsic.slang.expected.txt
+++ b/tests/autodiff/custom-intrinsic-1.slang.expected.txt
diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang
index 5f2cd9951..fb65dd2cc 100644
--- a/tests/diagnostics/custom-derivative-generic.slang
+++ b/tests/diagnostics/custom-derivative-generic.slang
@@ -34,7 +34,7 @@ DifferentialPair<float> dd1(DifferentialPair<float> x)
}
// CHECK-DAG: {{.*}}(37): error 31151
-[BackwardDerivative(f)]
+[BackwardDerivativeOf(f)]
DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut)
{
var primal = x.p * x.p;