diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-06-25 15:45:34 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-25 12:45:34 -0700 |
| commit | 8da47c460df01fad6f1d0614210a770f4781edb1 (patch) | |
| tree | 170a5cc100c69e387e8c6d34217588ea00daed53 /tests | |
| parent | 0229784b93a43e17a088881e6be32b44fc6ce713 (diff) | |
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types
* Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly
* Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/ir/derivative-op-ir-test.slang | 26 | ||||
| -rw-r--r-- | tests/ir/derivative-op-ir-test.slang.expected.txt | 6 |
2 files changed, 25 insertions, 7 deletions
diff --git a/tests/ir/derivative-op-ir-test.slang b/tests/ir/derivative-op-ir-test.slang index 209446765..7addccdc2 100644 --- a/tests/ir/derivative-op-ir-test.slang +++ b/tests/ir/derivative-op-ir-test.slang @@ -9,13 +9,31 @@ __differentiate_jvp float f(float x) return x; } +__differentiate_jvp float g(float x) +{ + return x + x; +} + +__differentiate_jvp float h(float x, float y) +{ + float m = x + y; + float n = x - y; + return m * n + 2 * x * y; +} + + [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { { - float a = 1.0; - float b = -2.0; - outputBuffer[0] = __jvp(f)(a); - outputBuffer[1] = __jvp(f)(b); + float a = 2.0; + float b = 1.5; + float da = 1.0; + float db = 1.0; + + outputBuffer[0] = __jvp(f)(a, da); // Expect: 1 + outputBuffer[1] = __jvp(f)(a, 0.0); // Expect: 0 + outputBuffer[2] = __jvp(g)(a, da); // Expect: 2 + outputBuffer[3] = __jvp(h)(a, b, da, db); // Expect: 8 } } diff --git a/tests/ir/derivative-op-ir-test.slang.expected.txt b/tests/ir/derivative-op-ir-test.slang.expected.txt index f095a0071..0545c08a1 100644 --- a/tests/ir/derivative-op-ir-test.slang.expected.txt +++ b/tests/ir/derivative-op-ir-test.slang.expected.txt @@ -1,5 +1,5 @@ type: float +1.0 0.0 -0.0 -0.0 -0.0
\ No newline at end of file +2.0 +8.0
\ No newline at end of file |
