diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-07 11:22:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-07 11:22:32 -0800 |
| commit | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch) | |
| tree | 87e444746f353d69a365380904f3f8caf15fbfec /tests | |
| parent | 6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff) | |
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs.
* Add diff implementation matrix versions of binary and ternary intrinsics.
* Add diff impl for legacy intrinsics.
* Fix diagnostics of using non-differentiable function in a diff operator.
* Add diff implementation for `determinant`.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
4 files changed, 30 insertions, 4 deletions
diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang index 379e2c3ef..53972ac2c 100644 --- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang +++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang @@ -2,7 +2,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; @@ -43,4 +43,14 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) __bwd_diff(diffSin)(dpx, 1.0); outputBuffer[4] = dpx.d; // Expect: -1.000000 } + + { + dpfloat dpx = dpfloat(float.getPi() / 3.0, 1.0); + __bwd_diff(sincos)(dpx, 1.0, 0.0); + outputBuffer[5] = dpx.d; // Expect: 0.5 + __bwd_diff(sincos)(dpx, 0.0, 1.0); + outputBuffer[6] = dpx.d; // Expect: -0.8660254 + __bwd_diff(sincos)(dpx, 1.0, 1.0); + outputBuffer[7] = dpx.d; // Expect: -0.3660254 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt index a4b804cb8..17627df68 100644 --- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt @@ -3,4 +3,7 @@ type: float 7.389056 0.000000 1.000000 --1.00000
\ No newline at end of file +-1.000000 +0.500000 +-0.866025 +-0.366025 diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang index 15573c4ef..d68a2697c 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang @@ -1,7 +1,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; @@ -50,4 +50,13 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[6] = dpx.d[0]; // Expect: 0.158114 outputBuffer[7] = dpx.d[1]; // Expect: 0.577350 } + + { + var dpx = diffPair(float2x2(4.0, 9.0, 16.0, 25.0), float2x2(0.0, 0.0, 0.0, 0.0)); + __bwd_diff(sqrt)(dpx, float2x2(1.0, 2.0, 3.0, 4.0)); + outputBuffer[8] = dpx.d[0][0]; // Expect: 0.25 + outputBuffer[9] = dpx.d[0][1]; // Expect: 0.3333 + outputBuffer[10] = dpx.d[1][0]; // Expect: 0.375 + outputBuffer[11] = dpx.d[1][1]; // Expect: 0.4 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt index fe6487fef..7e0fdf02f 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt @@ -6,4 +6,8 @@ type: float 0.000000 0.000000 0.158114 -0.577350
\ No newline at end of file +0.577350 +0.250000 +0.333333 +0.375000 +0.400000
\ No newline at end of file |
