summaryrefslogtreecommitdiffstats
path: root/tests/diagnostics
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-26 14:01:26 -0700
committerGitHub <noreply@github.com>2023-10-26 14:01:26 -0700
commitbee74b16eafa64ccc33bb386a1dc753cd6c41a82 (patch)
tree199a4575cfe6297bb494d33f47d68ecd1a35776d /tests/diagnostics
parent927d176be9ba03be161375b8695de1f0a37f1785 (diff)
Add more diagnostics around use of custom derivatives. (#3291)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/diagnostics')
-rw-r--r--tests/diagnostics/custom-derivative-generic.slang51
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang
new file mode 100644
index 000000000..5f2cd9951
--- /dev/null
+++ b/tests/diagnostics/custom-derivative-generic.slang
@@ -0,0 +1,51 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+interface IFoo
+{
+ static float bar1(float x);
+
+ // CHECK-DAG: {{.*}}(13): error 31152
+ [PrimalSubstitute(bar1)]
+ static float bar(float x);
+
+ static DifferentialPair<float> dd(DifferentialPair<float> x);
+}
+
+__generic<let N:int>
+float f(float x)
+{
+ return N*x*x;
+}
+
+// CHECK-DAG: {{.*}}(26): error 31153
+[ForwardDerivative(IFoo.dd)]
+float bbb(float x);
+
+// CHECK-DAG: {{.*}}(30): error 31152
+[ForwardDerivativeOf(IFoo.bar)]
+DifferentialPair<float> dd1(DifferentialPair<float> x)
+{
+ return x;
+}
+
+// CHECK-DAG: {{.*}}(37): error 31151
+[BackwardDerivative(f)]
+DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut)
+{
+ var primal = x.p * x.p;
+ var diff = 2 * x.p * x.d * N;
+ return DifferentialPair<float>(primal, diff);
+}
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(3.0, 1.0);
+ outputBuffer[1] = __fwd_diff(f<3>)(dpa).d; // Expect: 6.0
+ }
+}