From a0ee2bf671d61d1e2b561db3966e57ffc802040f Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Thu, 17 Aug 2023 13:41:49 +0800 Subject: Add loop inversion pass (#2899) * Generalize collectInductionValues * Support affine transformations of loop index as induction variables * Test for generalized induction value collection * Neaten inductive variable finding * Make types more specific * Add loop inversion pass * Test output changes after loop inversion * Store the type of implication success when finding inductive variables * Test that loop induction finding does not alway succeed * Support chains of additions and branches of additions in induction variable finding * Use c++17 for downstream compilers * Wiggle expected output for cross compile test after loop inversion * Add loop inversion test * Simplify IfElse instructions with a single trivial block * Invert loops with a user inserted break * Limit loop inversion to loops with a 4 instruction or less comparison block * regenerate vs projects --- tests/cross-compile/geometry-shader.slang.glsl | 21 +++-- tests/cross-compile/loop-attribs.slang.hlsl | 56 ++++++------- tests/ir/loop-inversion.slang | 107 +++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 43 deletions(-) create mode 100644 tests/ir/loop-inversion.slang (limited to 'tests') diff --git a/tests/cross-compile/geometry-shader.slang.glsl b/tests/cross-compile/geometry-shader.slang.glsl index 38dbd72ba..3b7ecca43 100644 --- a/tests/cross-compile/geometry-shader.slang.glsl +++ b/tests/cross-compile/geometry-shader.slang.glsl @@ -68,13 +68,6 @@ void main() for(;;) { - if(ii_0 < 3) - {} - else - { - break; - } - RasterVertex_0 rasterVertex_0; rasterVertex_0.position_0 = _S10[ii_0].position_1; rasterVertex_0.color_0 = _S10[ii_0].color_1; @@ -82,13 +75,17 @@ void main() RasterVertex_0 _S11 = rasterVertex_0; _S4 = rasterVertex_0.position_0; _S5 = _S11.color_0; - gl_Layer = int(_S11.id_0); - EmitVertex(); - - ii_0 = ii_0 + 1; + int ii_1 = ii_0 + 1; + if(ii_1 < 3) + { + ii_0 = ii_1; + } + else + { + break; + } } - return; } diff --git a/tests/cross-compile/loop-attribs.slang.hlsl b/tests/cross-compile/loop-attribs.slang.hlsl index 5d53f51e0..2c92d16f3 100644 --- a/tests/cross-compile/loop-attribs.slang.hlsl +++ b/tests/cross-compile/loop-attribs.slang.hlsl @@ -1,55 +1,49 @@ #pragma pack_matrix(column_major) +#ifdef SLANG_HLSL_ENABLE_NVAPI +#include "nvHLSLExtns.h" +#endif +#pragma warning(disable: 3557) -#line 6 "tests/cross-compile/loop-attribs.slang" -vector main() : SV_TARGET +float4 main() : SV_TARGET { - int i_0; - float sum_0; - int j_0; - float sum_1; - i_0 = int(0); - sum_0 = 0.00000000000000000000; + float _S1 = 0.0; + int i_0 = int(0); + float sum_0 = 0.0; [loop] for(;;) { - -#line 11 - if(i_0 < int(100)) + float sum_1 = sum_0 + float(i_0); + _S1 = sum_1; + int i_1 = i_0 + int(1); + if(i_1 < int(100)) { + i_0 = i_1; + sum_0 = sum_1; } else { break; } - float _S1 = sum_0 + (float) i_0; - -#line 11 - int _S2 = i_0 + (int) int(1); - i_0 = _S2; - sum_0 = _S1; } - j_0 = int(0); - sum_1 = sum_0; + float _S2 = 0.0; + int j_0 = int(0); + sum_0 = _S1; [unroll] for(;;) { - -#line 15 - if(j_0 < int(100)) + float sum_2 = sum_0 + float(j_0); + _S2 = sum_2; + int j_1 = j_0 + int(1); + if(j_1 < int(100)) { + j_0 = j_1; + sum_0 = sum_2; } else { break; } - float _S3 = sum_1 + (float) j_0; - -#line 15 - int _S4 = j_0 + (int) int(1); - j_0 = _S4; - sum_1 = _S3; } + return float4(_S2, 0.0, 0.0, 0.0); +} -#line 18 - return vector(sum_1, (float) int(0), (float) int(0), (float) int(0)); -} \ No newline at end of file diff --git a/tests/ir/loop-inversion.slang b/tests/ir/loop-inversion.slang new file mode 100644 index 000000000..03bdcc340 --- /dev/null +++ b/tests/ir/loop-inversion.slang @@ -0,0 +1,107 @@ +//TEST():SIMPLE(filecheck=CHECK):-entry computeMain -stage compute -line-directive-mode none -target hlsl +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-dx12 -use-dxil -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-cpu -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-vk -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-cpu -shaderobj -output-using-type + +// Check that all the backends cope with the slightly unusual IR the loop inversion generated + +// OUT: 180 + +// For all the below functions, verify that the body (adding to j and +// incrementing i) comes before any break. This verifies that the `break` has +// been moved to the end of the loop. + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// A standard loop +// CHECK-LABEL: int a_{{.*}}() +// CHECK-NOT: break; +// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: [[i]] + int(1); +// CHECK: if( +// CHECK: break; +// CHECK: return +int a() +{ + int j = 0; + for(int i = 0; i < 10; ++i) + j += i; + return j; +} + +// A vanilla while loop +// CHECK-LABEL: int b_{{.*}}() +// CHECK-NOT: break; +// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: [[i]] + int(1); +// CHECK: if( +// CHECK: break; +// CHECK: return +int b() +{ + int j = 0; + int i = 0; + while(i < 10) + { + j += i; + i++; + } + return j; +} + +// A while loop with a break on the false branch +// CHECK-LABEL: int c_{{.*}}() +// CHECK-NOT: break; +// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: [[i]] + int(1); +// CHECK: if( +// CHECK: break; +// CHECK: return +int c() +{ + int j = 0; + int i = 0; + do + { + if(i < 10) + {} + else + break; + j += i; + i++; + } while(true); + return j; +} + +// A while loop with a break on the true branch +// CHECK-LABEL: int d_{{.*}}() +// CHECK-NOT: break; +// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: [[i]] + int(1); +// CHECK: if( +// CHECK: break; +// CHECK: return +int d() +{ + int j = 0; + int i = 0; + do + { + if(i >= 10) + break; + else + {} + j += i; + i++; + } while(true); + return j; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[dispatchThreadID.x] = a() + b() + c() + d(); +} -- cgit v1.2.3