summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-25 15:45:34 -0400
committerGitHub <noreply@github.com>2022-06-25 12:45:34 -0700
commit8da47c460df01fad6f1d0614210a770f4781edb1 (patch)
tree170a5cc100c69e387e8c6d34217588ea00daed53 /tests
parent0229784b93a43e17a088881e6be32b44fc6ce713 (diff)
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/ir/derivative-op-ir-test.slang26
-rw-r--r--tests/ir/derivative-op-ir-test.slang.expected.txt6
2 files changed, 25 insertions, 7 deletions
diff --git a/tests/ir/derivative-op-ir-test.slang b/tests/ir/derivative-op-ir-test.slang
index 209446765..7addccdc2 100644
--- a/tests/ir/derivative-op-ir-test.slang
+++ b/tests/ir/derivative-op-ir-test.slang
@@ -9,13 +9,31 @@ __differentiate_jvp float f(float x)
return x;
}
+__differentiate_jvp float g(float x)
+{
+ return x + x;
+}
+
+__differentiate_jvp float h(float x, float y)
+{
+ float m = x + y;
+ float n = x - y;
+ return m * n + 2 * x * y;
+}
+
+
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
{
- float a = 1.0;
- float b = -2.0;
- outputBuffer[0] = __jvp(f)(a);
- outputBuffer[1] = __jvp(f)(b);
+ float a = 2.0;
+ float b = 1.5;
+ float da = 1.0;
+ float db = 1.0;
+
+ outputBuffer[0] = __jvp(f)(a, da); // Expect: 1
+ outputBuffer[1] = __jvp(f)(a, 0.0); // Expect: 0
+ outputBuffer[2] = __jvp(g)(a, da); // Expect: 2
+ outputBuffer[3] = __jvp(h)(a, b, da, db); // Expect: 8
}
}
diff --git a/tests/ir/derivative-op-ir-test.slang.expected.txt b/tests/ir/derivative-op-ir-test.slang.expected.txt
index f095a0071..0545c08a1 100644
--- a/tests/ir/derivative-op-ir-test.slang.expected.txt
+++ b/tests/ir/derivative-op-ir-test.slang.expected.txt
@@ -1,5 +1,5 @@
type: float
+1.0
0.0
-0.0
-0.0
-0.0 \ No newline at end of file
+2.0
+8.0 \ No newline at end of file