summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /tests
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/primal-substitute-2.slang34
-rw-r--r--tests/autodiff/primal-substitute-2.slang.expected.txt6
-rw-r--r--tests/autodiff/primal-substitute-3.slang52
-rw-r--r--tests/autodiff/primal-substitute-3.slang.expected.txt6
-rw-r--r--tests/autodiff/primal-substitute.slang27
-rw-r--r--tests/autodiff/primal-substitute.slang.expected.txt3
6 files changed, 128 insertions, 0 deletions
diff --git a/tests/autodiff/primal-substitute-2.slang b/tests/autodiff/primal-substitute-2.slang
new file mode 100644
index 000000000..6c53f84a6
--- /dev/null
+++ b/tests/autodiff/primal-substitute-2.slang
@@ -0,0 +1,34 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+float original(float x)
+{
+ return x * x;
+}
+
+[PrimalSubstituteOf(original)]
+[BackwardDifferentiable]
+float primalSubst(float x)
+{
+ return 2.0f * x * x;
+}
+
+[BackwardDifferentiable]
+float caller(float x)
+{
+ return original(x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var a = diffPair(3.0, 1.0);
+ __bwd_diff(caller)(a, 1.0);
+ outputBuffer[0] = a.d; // Expect: 12.0
+ outputBuffer[1] = __fwd_diff(caller)(diffPair(3.0, 1.0)).p; // Expect: 18.0
+ outputBuffer[2] = caller(3.0); // Expect: 9.0
+}
diff --git a/tests/autodiff/primal-substitute-2.slang.expected.txt b/tests/autodiff/primal-substitute-2.slang.expected.txt
new file mode 100644
index 000000000..ee60dfa22
--- /dev/null
+++ b/tests/autodiff/primal-substitute-2.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+12.000000
+18.000000
+9.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/primal-substitute-3.slang b/tests/autodiff/primal-substitute-3.slang
new file mode 100644
index 000000000..ab2899bdc
--- /dev/null
+++ b/tests/autodiff/primal-substitute-3.slang
@@ -0,0 +1,52 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo
+{
+ float doSomething();
+}
+
+struct A : IFoo
+{
+ float doSomething()
+ {
+ return 0.0f;
+ }
+}
+
+float original<T : IFoo>(T p, float x)
+{
+ p.doSomething();
+ return x * x;
+}
+
+[PrimalSubstituteOf(original)]
+[BackwardDifferentiable]
+float primalSubst<T : IFoo>(T p, float x)
+{
+ return 2.0f * x * x;
+}
+
+[BackwardDifferentiable]
+float caller(IFoo d, float x)
+{
+ return original(d, x);
+}
+
+//TEST_INPUT: type_conformance A:IFoo = 0
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IFoo>(dispatchThreadID.x, 0); // A
+
+ var a = diffPair(3.0, 1.0);
+ __bwd_diff(caller)(obj, a, 1.0);
+ outputBuffer[0] = a.d; // Expect: 12.0
+ outputBuffer[1] = __fwd_diff(caller)(obj, diffPair(3.0, 1.0)).p; // Expect: 18.0
+ outputBuffer[2] = caller(obj, 3.0); // Expect: 9.0
+}
diff --git a/tests/autodiff/primal-substitute-3.slang.expected.txt b/tests/autodiff/primal-substitute-3.slang.expected.txt
new file mode 100644
index 000000000..ee60dfa22
--- /dev/null
+++ b/tests/autodiff/primal-substitute-3.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+12.000000
+18.000000
+9.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/primal-substitute.slang b/tests/autodiff/primal-substitute.slang
new file mode 100644
index 000000000..01f221f2a
--- /dev/null
+++ b/tests/autodiff/primal-substitute.slang
@@ -0,0 +1,27 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+float original(float x)
+{
+ return x * x;
+}
+
+[PrimalSubstituteOf(original)]
+[BackwardDifferentiable]
+float primalSubst(float x)
+{
+ return 2.0f * x * x;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var a = diffPair(3.0, 1.0);
+ __bwd_diff(original)(a, 1.0);
+ outputBuffer[0] = a.d; // Expect: 12.0
+ outputBuffer[1] = __fwd_diff(original)(diffPair(3.0, 1.0)).p; // Expect: 18.0
+}
diff --git a/tests/autodiff/primal-substitute.slang.expected.txt b/tests/autodiff/primal-substitute.slang.expected.txt
new file mode 100644
index 000000000..af1b9f528
--- /dev/null
+++ b/tests/autodiff/primal-substitute.slang.expected.txt
@@ -0,0 +1,3 @@
+type: float
+12.0
+18.0