diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-27 23:42:06 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-27 23:42:06 -0500 |
| commit | 10e2d9c7c532c204f26bb2c9f383f21b121b2ff2 (patch) | |
| tree | 9ae0dd84b505a7ecd3fb45de9dbde74f8dd1ebe9 /tests | |
| parent | a3ba22b51c371d5a20d61aa4e35233ba4f4f68db (diff) | |
More fixes for reverse-mode on complicated loops (#2675)
* Multiple fixes to get various loop tests to pass.
* Create reverse-nested-loop.slang
* Fix for variables becoming inaccessible during cfg normalization
* Removed comments and moved break-branch-normalization to eliminateMultiLevelBreaks
* Fix.
* Override liveness tests
Diffstat (limited to 'tests')
12 files changed, 171 insertions, 20 deletions
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang new file mode 100644 index 000000000..0f9502673 --- /dev/null +++ b/tests/autodiff/reverse-continue-loop.slang @@ -0,0 +1,43 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_loop_with_continue(float y) +{ + float t = y; + + for (int i = 0; i < 3; i++) + { + if (t > 4.0) + continue; + + t = t * t; + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 0.0); + + __bwd_diff(test_loop_with_continue)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 32.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_loop_with_continue)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 0.0131072 + } +} diff --git a/tests/autodiff/reverse-continue-loop.slang.expected.txt b/tests/autodiff/reverse-continue-loop.slang.expected.txt new file mode 100644 index 000000000..17dbb061d --- /dev/null +++ b/tests/autodiff/reverse-continue-loop.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +32.000000 +0.013107 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/reverse-hybrid-control-flow.slang b/tests/autodiff/reverse-hybrid-control-flow.slang new file mode 100644 index 000000000..9379df4cf --- /dev/null +++ b/tests/autodiff/reverse-hybrid-control-flow.slang @@ -0,0 +1,47 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_simple_loop(float y) +{ + float t = y; + + if (y > 0.5) + { + for (int i = 0; i < 3; i++) + { + t = t * t; + } + } + else + { + t = t * 10.f; + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_loop)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 8.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_simple_loop)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 10.0 + } +} diff --git a/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt b/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt new file mode 100644 index 000000000..3cb76c394 --- /dev/null +++ b/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +8.000000 +10.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/reverse-nested-loop.slang b/tests/autodiff/reverse-nested-loop.slang new file mode 100644 index 000000000..08cde5230 --- /dev/null +++ b/tests/autodiff/reverse-nested-loop.slang @@ -0,0 +1,43 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_simple_nested_loop(float y) +{ + float t = y; + + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 2; j++) + { + t = t * (i + j + 1); + } + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_nested_loop)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 12.0 * 1 + } + + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_nested_loop)(dpa, 0.4f); + outputBuffer[1] = dpa.d; // Expect: 12 * 0.4 = 4.8 + } +} diff --git a/tests/autodiff/reverse-nested-loop.slang.expected.txt b/tests/autodiff/reverse-nested-loop.slang.expected.txt new file mode 100644 index 000000000..59c14cf1d --- /dev/null +++ b/tests/autodiff/reverse-nested-loop.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +4.800000 +0.000000 +0.000000 +0.000000 diff --git a/tests/experimental/liveness/liveness-2.slang.expected b/tests/experimental/liveness/liveness-2.slang.expected index 16883c1fd..e0486f8da 100644 --- a/tests/experimental/liveness/liveness-2.slang.expected +++ b/tests/experimental/liveness/liveness-2.slang.expected @@ -32,12 +32,12 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } idx_0[i_0] = offset_0 + i_0; i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); int _S1 = idx_0[0] + idx_0[1]; int _S2 = idx_0[2]; livenessEnd_1(idx_0, 0); diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected index cb093a640..d4b417082 100644 --- a/tests/experimental/liveness/liveness-3.slang.expected +++ b/tests/experimental/liveness/liveness-3.slang.expected @@ -12,12 +12,12 @@ void livenessStart_0(spirv_by_reference int _0[2], spirv_literal int _1); spirv_instruction(id = 256) void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1); -spirv_instruction(id = 256) -void livenessStart_2(spirv_by_reference int _0[3], spirv_literal int _1); - spirv_instruction(id = 257) void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1); +spirv_instruction(id = 256) +void livenessStart_2(spirv_by_reference int _0[3], spirv_literal int _1); + spirv_instruction(id = 257) void livenessEnd_1(spirv_by_reference int _0[3], spirv_literal int _1); @@ -46,6 +46,8 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(_S1, 0); + livenessEnd_0(k_0, 0); break; } int idx_0[3]; @@ -69,6 +71,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } int modRange_0 = i_0 % 3; @@ -97,7 +100,6 @@ int calcThing_0(int offset_0) livenessEnd_0(_S6, 0); _S4 = _S10; } - livenessEnd_0(i_0, 0); livenessEnd_0(_S1, 0); livenessEnd_0(k_0, 0); if(_S3) @@ -123,8 +125,6 @@ int calcThing_0(int offset_0) livenessStart_1(total_0, 0); total_0 = total_1; } - livenessEnd_0(_S1, 0); - livenessEnd_0(k_0, 0); livenessEnd_2(another_0, 0); int _S16 = total_0; livenessEnd_0(total_0, 0); diff --git a/tests/experimental/liveness/liveness-4.slang.expected b/tests/experimental/liveness/liveness-4.slang.expected index efc2e3846..483247ecd 100644 --- a/tests/experimental/liveness/liveness-4.slang.expected +++ b/tests/experimental/liveness/liveness-4.slang.expected @@ -34,6 +34,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(k_0, 0); break; } int _S1 = (k_0 + 7) % 5; @@ -49,12 +50,12 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); livenessEnd_0(k_0, 0); if(_S2) { @@ -64,7 +65,6 @@ int calcThing_0(int offset_0) livenessStart_1(k_0, 0); k_0 = k_1; } - livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); return -2; } diff --git a/tests/experimental/liveness/liveness-5.slang.expected b/tests/experimental/liveness/liveness-5.slang.expected index e9fe9d652..5e144a095 100644 --- a/tests/experimental/liveness/liveness-5.slang.expected +++ b/tests/experimental/liveness/liveness-5.slang.expected @@ -37,6 +37,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(k_0, 0); break; } int _S1 = (k_0 + 7) % 5; @@ -52,12 +53,12 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); livenessEnd_0(k_0, 0); int _S3 = total_0; livenessEnd_0(total_0, 0); @@ -72,7 +73,6 @@ int calcThing_0(int offset_0) livenessStart_1(total_0, 0); total_0 = total_1; } - livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); if(total_0 > 4) { diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected index b661c09bf..0d2e997b2 100644 --- a/tests/experimental/liveness/liveness-6.slang.expected +++ b/tests/experimental/liveness/liveness-6.slang.expected @@ -37,6 +37,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(k_0, 0); break; } int arr_0[2]; @@ -57,13 +58,13 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); arr_0[_S1] = arr_0[_S1] + i_0; i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); livenessEnd_0(k_0, 0); int _S4 = total_0; livenessEnd_0(total_0, 0); @@ -81,7 +82,6 @@ int calcThing_0(int offset_0) livenessStart_1(total_0, 0); total_0 = total_2; } - livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); if(total_0 > 4) { diff --git a/tests/experimental/liveness/liveness.slang.expected b/tests/experimental/liveness/liveness.slang.expected index 06809ffc3..b0017ea9d 100644 --- a/tests/experimental/liveness/liveness.slang.expected +++ b/tests/experimental/liveness/liveness.slang.expected @@ -13,10 +13,10 @@ spirv_instruction(id = 256) void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1); spirv_instruction(id = 257) -void livenessEnd_0(spirv_by_reference uint _0, spirv_literal int _1); +void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1); spirv_instruction(id = 257) -void livenessEnd_1(spirv_by_reference int _0, spirv_literal int _1); +void livenessEnd_1(spirv_by_reference uint _0, spirv_literal int _1); int someSlowFunc_0(int a_0) { @@ -35,18 +35,18 @@ int someSlowFunc_0(int a_0) } else { + livenessEnd_0(i_0, 0); break; } uint _S3 = v_0 >> 1; uint _S4 = v_0; - livenessEnd_0(v_0, 0); + livenessEnd_1(v_0, 0); uint _S5 = (_S3 | _S4 << 31) * uint(i_0); int i_1 = i_0 + 1; livenessStart_0(v_0, 0); v_0 = _S5; i_0 = i_1; } - livenessEnd_1(i_0, 0); return int(v_0); } @@ -111,6 +111,7 @@ void main() } else { + livenessEnd_0(i_2, 0); break; } SomeStruct_0 s_3; @@ -153,15 +154,14 @@ void main() livenessEnd_2(s_3, 0); int _S22 = _S20 + _S21; int _S23 = res_0; - livenessEnd_1(res_0, 0); + livenessEnd_0(res_0, 0); int res_1 = _S23 + _S22; i_2 = i_2 + 1; livenessStart_1(res_0, 0); res_0 = res_1; } - livenessEnd_1(i_2, 0); int _S24 = res_0; - livenessEnd_1(res_0, 0); + livenessEnd_0(res_0, 0); ((outputBuffer_0)._data[(uint(index_0))]) = _S24; return; } |
