summaryrefslogtreecommitdiff
path: root/tests/language-feature/ifunc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/language-feature/ifunc')
-rw-r--r--tests/language-feature/ifunc/diff-functor.slang44
-rw-r--r--tests/language-feature/ifunc/ifunc.slang40
2 files changed, 84 insertions, 0 deletions
diff --git a/tests/language-feature/ifunc/diff-functor.slang b/tests/language-feature/ifunc/diff-functor.slang
new file mode 100644
index 000000000..04b0be44f
--- /dev/null
+++ b/tests/language-feature/ifunc/diff-functor.slang
@@ -0,0 +1,44 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+struct DiffFunctor : IDifferentiableFunc<float, float>
+{
+ [Differentiable]
+ float __call(float p)
+ {
+ return p + 1;
+ }
+}
+
+float apply(IMutatingFunc<float, float> f, float p)
+{
+ return f.__call(p);
+}
+
+[Differentiable]
+float applyDiff(IDifferentiableFunc<float, float> f, float p)
+{
+ return f.__call(p);
+}
+
+[Differentiable]
+TR applyDiffGen<TR : IDifferentiable, each TP : IDifferentiable>(IDifferentiableFunc<TR, TP> f, expand each TP p)
+{
+ return f.__call(expand each p);
+}
+
+//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID)
+{
+ // CHECK: 4
+ outputBuffer[0] = (uint)apply(DiffFunctor(), 3.0);
+ // CHECK: 1
+ outputBuffer[1] = (uint)fwd_diff(applyDiff)(DiffFunctor(), diffPair(2.0, 1.0)).d;
+ // CHECK: 1
+ outputBuffer[2] = (uint)fwd_diff(applyDiffGen<float, float>)(DiffFunctor(), diffPair(2.0, 1.0)).d;
+}
diff --git a/tests/language-feature/ifunc/ifunc.slang b/tests/language-feature/ifunc/ifunc.slang
new file mode 100644
index 000000000..f270299b3
--- /dev/null
+++ b/tests/language-feature/ifunc/ifunc.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+struct Functor : IFunc<int, int, bool>
+{
+ int __call(int p, bool t)
+ {
+ return p + 1;
+ }
+}
+
+struct MutatingFunctor : IMutatingFunc<int, int, bool>
+{
+ int data = 0;
+ [mutating]
+ int __call(int p, bool t)
+ {
+ data++;
+ return p + 1;
+ }
+}
+
+int apply(IMutatingFunc<int, int, bool> f, int p)
+{
+ return f.__call(p, true);
+}
+
+//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID)
+{
+ // CHECK: 2
+ outputBuffer[0] = apply(MutatingFunctor(), 1);
+ // CHECK: 3
+ outputBuffer[1] = apply(Functor(), 2);
+}