summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-21 21:29:13 -0700
committerGitHub <noreply@github.com>2023-03-21 21:29:13 -0700
commitd8a40abba5223fbcb56c52b04ccb88c02bbaf79f (patch)
tree3207babbce41957fbd01c3c791fe9957c81f6a09 /tests
parent83876733d69582eec6bad26af64a651d40fa43aa (diff)
[TreatAsDifferentiable] functions. (#2720)
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/treat-as-differentiable.slang37
-rw-r--r--tests/autodiff/treat-as-differentiable.slang.expected.txt2
2 files changed, 39 insertions, 0 deletions
diff --git a/tests/autodiff/treat-as-differentiable.slang b/tests/autodiff/treat-as-differentiable.slang
new file mode 100644
index 000000000..95423d978
--- /dev/null
+++ b/tests/autodiff/treat-as-differentiable.slang
@@ -0,0 +1,37 @@
+// Tests automatic synthesis of Differential type and method requirements.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo
+{
+ [BackwardDifferentiable]
+ float f(float v);
+}
+
+struct B : IFoo
+{
+ [TreatAsDifferentiable]
+ float f(float v)
+ {
+ return v * v;
+ }
+}
+
+[BackwardDifferentiable]
+float use(IFoo o, float x)
+{
+ return o.f(x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ B b;
+ var p = diffPair(1.0);
+ __bwd_diff(use)(b, p, 1.0);
+ outputBuffer[0] = p.d;
+}
diff --git a/tests/autodiff/treat-as-differentiable.slang.expected.txt b/tests/autodiff/treat-as-differentiable.slang.expected.txt
new file mode 100644
index 000000000..9d11e5c94
--- /dev/null
+++ b/tests/autodiff/treat-as-differentiable.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+0.0 \ No newline at end of file