diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-08 14:56:20 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-08 14:56:20 -0800 |
| commit | 41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (patch) | |
| tree | c6cde57da4d3415d86d09213936a48d3d26e07e1 /tests/autodiff | |
| parent | 468bb7ecf65c000c308adae511bf65a1ca4cc412 (diff) | |
Auto-diff for matrix operations. (#2559)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/matrix-arithmetic-fwd.slang | 41 | ||||
| -rw-r--r-- | tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt | 5 |
2 files changed, 46 insertions, 0 deletions
diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang new file mode 100644 index 000000000..7a953cef8 --- /dev/null +++ b/tests/autodiff/matrix-arithmetic-fwd.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[ForwardDifferentiable] +float3x3 g(float3x3 x, float3x3 y) +{ + float3x3 a = x + y; + float3x3 b = x - y; + return a * b + 2 * x * y; +} + +[ForwardDifferentiable] +float h(float2x2 x, float2x2 y) +{ + let t = mul(x, y); + return t[0][0] + t[0][1] + t[1][0] + t[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float3x3 a = float3x3(2.0); + float3x3 b = float3x3(1.5); + float3x3 da = float3x3(1.0); + + outputBuffer[0] = __fwd_diff(g)( + DifferentialPair<float3x3>(a, da), + DifferentialPair<float3x3>(b, da)).d._11; // Expect: 8 + + float2x2 l = float2x2(1.0, 2.0, 3.0, 4.0); + float2x2 r = float2x2(10.0, 11.0, 12.0, 13.0); + float2x2 d = float2x2(1.0, 0.0, 1.0, 1.0); + + //float2x2 epsilon = d * 0.001f; + //outputBuffer[1] = (h(l + epsilon, r + epsilon) - h(l - epsilon, r - epsilon)) / (epsilon[0][0] * 2.0)); + + outputBuffer[1] = __fwd_diff(h)(DifferentialPair<float2x2>(l, d), DifferentialPair<float2x2>(r, d)).d; // Expect 83.0 +} diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt new file mode 100644 index 000000000..c595048c3 --- /dev/null +++ b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +8.0 +83.0 +0.0 +0.0
\ No newline at end of file |
