diff options
| author | Yong He <yonghe@outlook.com> | 2023-10-26 14:01:26 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-26 14:01:26 -0700 |
| commit | bee74b16eafa64ccc33bb386a1dc753cd6c41a82 (patch) | |
| tree | 199a4575cfe6297bb494d33f47d68ecd1a35776d /tests/diagnostics | |
| parent | 927d176be9ba03be161375b8695de1f0a37f1785 (diff) | |
Add more diagnostics around use of custom derivatives. (#3291)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/diagnostics')
| -rw-r--r-- | tests/diagnostics/custom-derivative-generic.slang | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang new file mode 100644 index 000000000..5f2cd9951 --- /dev/null +++ b/tests/diagnostics/custom-derivative-generic.slang @@ -0,0 +1,51 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +interface IFoo +{ + static float bar1(float x); + + // CHECK-DAG: {{.*}}(13): error 31152 + [PrimalSubstitute(bar1)] + static float bar(float x); + + static DifferentialPair<float> dd(DifferentialPair<float> x); +} + +__generic<let N:int> +float f(float x) +{ + return N*x*x; +} + +// CHECK-DAG: {{.*}}(26): error 31153 +[ForwardDerivative(IFoo.dd)] +float bbb(float x); + +// CHECK-DAG: {{.*}}(30): error 31152 +[ForwardDerivativeOf(IFoo.bar)] +DifferentialPair<float> dd1(DifferentialPair<float> x) +{ + return x; +} + +// CHECK-DAG: {{.*}}(37): error 31151 +[BackwardDerivative(f)] +DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut) +{ + var primal = x.p * x.p; + var diff = 2 * x.p * x.d * N; + return DifferentialPair<float>(primal, diff); +} +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(3.0, 1.0); + outputBuffer[1] = __fwd_diff(f<3>)(dpa).d; // Expect: 6.0 + } +} |
