summaryrefslogtreecommitdiffstats
path: root/tests/language-feature/ifunc/diff-functor.slang
diff options
context:
space:
mode:
Diffstat (limited to 'tests/language-feature/ifunc/diff-functor.slang')
-rw-r--r--tests/language-feature/ifunc/diff-functor.slang44
1 files changed, 44 insertions, 0 deletions
diff --git a/tests/language-feature/ifunc/diff-functor.slang b/tests/language-feature/ifunc/diff-functor.slang
new file mode 100644
index 000000000..04b0be44f
--- /dev/null
+++ b/tests/language-feature/ifunc/diff-functor.slang
@@ -0,0 +1,44 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+struct DiffFunctor : IDifferentiableFunc<float, float>
+{
+ [Differentiable]
+ float __call(float p)
+ {
+ return p + 1;
+ }
+}
+
+float apply(IMutatingFunc<float, float> f, float p)
+{
+ return f.__call(p);
+}
+
+[Differentiable]
+float applyDiff(IDifferentiableFunc<float, float> f, float p)
+{
+ return f.__call(p);
+}
+
+[Differentiable]
+TR applyDiffGen<TR : IDifferentiable, each TP : IDifferentiable>(IDifferentiableFunc<TR, TP> f, expand each TP p)
+{
+ return f.__call(expand each p);
+}
+
+//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID)
+{
+ // CHECK: 4
+ outputBuffer[0] = (uint)apply(DiffFunctor(), 3.0);
+ // CHECK: 1
+ outputBuffer[1] = (uint)fwd_diff(applyDiff)(DiffFunctor(), diffPair(2.0, 1.0)).d;
+ // CHECK: 1
+ outputBuffer[2] = (uint)fwd_diff(applyDiffGen<float, float>)(DiffFunctor(), diffPair(2.0, 1.0)).d;
+}