summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-08-17 13:41:49 +0800
committerGitHub <noreply@github.com>2023-08-17 13:41:49 +0800
commita0ee2bf671d61d1e2b561db3966e57ffc802040f (patch)
treec82bbfeb75cf5fa8d630322a2a62010705579056
parent3e41d698714a3ab6235e9275d5e0687a1c5db9c9 (diff)
Add loop inversion pass (#2899)
* Generalize collectInductionValues * Support affine transformations of loop index as induction variables * Test for generalized induction value collection * Neaten inductive variable finding * Make types more specific * Add loop inversion pass * Test output changes after loop inversion * Store the type of implication success when finding inductive variables * Test that loop induction finding does not alway succeed * Support chains of additions and branches of additions in induction variable finding * Use c++17 for downstream compilers * Wiggle expected output for cross compile test after loop inversion * Add loop inversion test * Simplify IfElse instructions with a single trivial block * Invert loops with a user inserted break * Limit loop inversion to loops with a 4 instruction or less comparison block * regenerate vs projects
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/slang-ir-fuse-satcoop.cpp17
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-loop-inversion.cpp305
-rw-r--r--source/slang/slang-ir-loop-inversion.h10
-rw-r--r--source/slang/slang-ir-util.h16
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp19
-rw-r--r--tests/cross-compile/geometry-shader.slang.glsl21
-rw-r--r--tests/cross-compile/loop-attribs.slang.hlsl56
-rw-r--r--tests/ir/loop-inversion.slang107
12 files changed, 502 insertions, 61 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 283221612..29437c152 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -406,6 +406,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-legalize-varying-params.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-link.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-liveness.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-loop-inversion.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-loop-unroll.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-binding-query.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h" />
@@ -615,6 +616,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-legalize-varying-params.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-liveness.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-loop-inversion.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-loop-unroll.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-binding-query.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index a579aa7e7..00c3bf74d 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -306,6 +306,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-liveness.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-loop-inversion.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-loop-unroll.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -929,6 +932,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-liveness.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-loop-inversion.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-loop-unroll.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ir-fuse-satcoop.cpp b/source/slang/slang-ir-fuse-satcoop.cpp
index 3c827ef25..b672f3f7c 100644
--- a/source/slang/slang-ir-fuse-satcoop.cpp
+++ b/source/slang/slang-ir-fuse-satcoop.cpp
@@ -4,6 +4,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-specialize-function-call.h"
#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-util.h"
#include "slang-ir.h"
namespace Slang
@@ -13,22 +14,6 @@ namespace Slang
// Some helpers
//
-// Run an operation over every block in a module
-template<typename F>
-static void overAllBlocks(IRModule* module, F f)
-{
- for (auto globalInst : module->getGlobalInsts())
- {
- if (auto func = as<IRGlobalValueWithCode>(globalInst))
- {
- for (auto block : func->getBlocks())
- {
- f(block);
- }
- }
- }
-}
-
static bool uses(IRInst* used, IRInst* user)
{
for(auto use = used->firstUse; use; use = use->nextUse)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 3503ece79..f4489611d 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3751,7 +3751,7 @@ public:
IRBlock* trueBlock,
IRBlock* afterBlock);
- IRInst* emitIfElse(
+ IRIfElse* emitIfElse(
IRInst* val,
IRBlock* trueBlock,
IRBlock* falseBlock,
diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp
new file mode 100644
index 000000000..5050ce9d6
--- /dev/null
+++ b/source/slang/slang-ir-loop-inversion.cpp
@@ -0,0 +1,305 @@
+#include "slang-ir-loop-inversion.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-lower-witness-lookup.h"
+#include "slang-ir-reachability.h"
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+static bool isSameBlockOrTrivialBranch(IRBlock* target, IRBlock* scrutinee)
+{
+ if(target == scrutinee)
+ return true;
+ const auto br = as<IRUnconditionalBranch>(scrutinee->getFirstOrdinaryInst());
+ return br && br->getTargetBlock() == target && br->getArgCount() == 0 && !scrutinee->hasMoreThanOneUse();
+};
+
+static bool isSmallBlock(IRBlock* c)
+{
+ // Somewhat arbitrarily, 4 instructions, enough for:
+ // - Arith
+ // - Comparison
+ // - Negation
+ // - Terminator
+ Int n = 0;
+ for([[maybe_unused]] const auto i : c->getOrdinaryInsts())
+ if(++n > 4)
+ return false;
+ return true;
+}
+
+// Loops are suitable for inversion if:
+// - The loop jumps to a conditional branch which has the break block as one of
+// its successors (or a trivial break block which we erase) and the other
+// successor is empty
+// - The conditional block is "small", because we will be duplicating it
+static bool isSuitableForInversion(IRLoop* loop)
+{
+ const auto nextBlock = loop->getTargetBlock();
+ const auto breakBlock = loop->getBreakBlock();
+
+ // The first thing a loop does must be a conditional
+ const auto branch = as<IRIfElse>(nextBlock->getTerminator());
+ if(!branch)
+ return false;
+
+ if(!isSmallBlock(nextBlock))
+ return false;
+
+ const auto t = branch->getTrueBlock();
+ const auto f = branch->getFalseBlock();
+ const auto a = branch->getAfterBlock();
+
+ //
+ // In principle we could perform this simplification in the cfg simplifier,
+ // however it relies on slightly more context than is simple to insert
+ // there, namely that the removed trivial branching block is branching to a
+ // loop break block.
+ //
+
+ // Do we break on the 'true' side?
+ if(isSameBlockOrTrivialBranch(breakBlock, t) && f == a)
+ {
+ if(t != breakBlock)
+ {
+ branch->trueBlock.set(breakBlock);
+ t->removeAndDeallocate();
+ }
+ return true;
+ }
+
+ // ... or the false side
+ if(isSameBlockOrTrivialBranch(breakBlock, f) && t == a)
+ {
+ if(f != breakBlock)
+ {
+ branch->falseBlock.set(breakBlock);
+ f->removeAndDeallocate();
+ }
+ return true;
+ }
+
+ return false;
+}
+
+static IRParam* duplicateToParamWithDecorations(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* i)
+{
+ const auto p = builder.emitParam(i->getFullType());
+ for(const auto dec : i->getDecorations())
+ cloneDecoration(&cloneEnv, dec, p, builder.getModule());
+ return p;
+}
+
+// Given
+// s: ...1 loop break=b next=c1
+// c1: if x then goto b else goto d
+// d: goto c1
+// b: ...2
+//
+// Produce:
+// s: ...1 goto c1
+// c1: if x then goto e1 else goto l
+// e1: goto b
+// l: loop break=b next=d
+// d: goto c2:
+// c2: if x then goto e2 else goto e3
+// e3: goto d
+// e2: goto b
+// b: ...2
+//
+// s is the Start block
+// c1, c2 are the Condition blocks
+// e1, e2, e3 are the critical Edge breakers
+// l is the Loop entering block
+// d is the loop boDy
+// b is the Break block
+static void invertLoop(IRBuilder& builder, IRLoop* loop)
+{
+ IRBuilderInsertLocScope builderScope(&builder);
+ const auto s = as<IRBlock>(loop->getParent());
+ auto domTree = computeDominatorTree(s->getParent());
+ SLANG_ASSERT(s);
+ const auto c1 = loop->getTargetBlock();
+ const auto c1Terminator = as<IRConditionalBranch>(c1->getTerminator());
+ SLANG_ASSERT(c1Terminator);
+ const auto b = loop->getBreakBlock();
+ auto& c1dUse = c1Terminator->getTrueBlock() == b ? c1Terminator->falseBlock : c1Terminator->trueBlock;
+ auto& c1bUse = c1Terminator->getTrueBlock() == b ? c1Terminator->trueBlock : c1Terminator->falseBlock;
+ const auto d = as<IRBlock>(c1dUse.get());
+ SLANG_ASSERT(d);
+
+ IRCloneEnv cloneEnv;
+ cloneEnv.squashChildrenMapping = true;
+
+ // Since we are duplicating the loop break condition block (c1) we must
+ // introduce phi values for anything in it upon which the rest of the
+ // program (b onwards) uses. Lift the values fron c1 used in b (and
+ // onwards) to parameters. To avoid a critical edge, pass these via a new
+ // block, e1.
+ builder.setInsertInto(b);
+ List<IRInst*> c1Params;
+ for(auto i : IRInstList<IRInst>(c1->getFirstInst(), c1->getLastInst()))
+ {
+ IRParam* p = nullptr;
+ traverseUses(i, [&](IRUse* u){
+ auto userBlock = u->getUser()->getParent();
+ if(domTree->dominates(b, userBlock))
+ {
+ // A new parameter to replace this 'i'
+ if(!p)
+ p = duplicateToParamWithDecorations(builder, cloneEnv, i);
+ u->set(p);
+ }
+ });
+ if(p)
+ c1Params.add(i);
+ }
+ auto e1 = builder.emitBlock();
+ e1->insertAfter(c1);
+ builder.emitBranch(b, c1Params.getCount(), c1Params.getBuffer());
+ c1bUse.set(e1);
+ // Similarly, we have to replace any existing 'break's to break via e1
+ traverseUses(b, [&](IRUse* u){
+ auto userBlock = u->getUser()->getParent();
+ // Restrict this to just those blocks within this loop
+ if(userBlock != e1 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock))
+ u->set(e1);
+ });
+ // We now have
+ // s: ...1 loop break=b next=c1
+ // c1: if x then goto e1 else goto d
+ // e1: goto b
+ // d: goto c1
+ // b: ...2
+
+ // Duplicate c1 into c2, and using the same cloneEnv, duplicate e1 into e2
+ builder.setInsertInto(builder.getFunc());
+ const auto c2 = as<IRBlock>(cloneInst(&cloneEnv, &builder, c1));
+ c2->insertBefore(c1);
+ const auto e2 = as<IRBlock>(cloneInst(&cloneEnv, &builder, e1));
+ e2->insertAfter(c2);
+ const auto c2Terminator = as<IRConditionalBranch>(c2->getTerminator());
+ auto& c2eUse = c2Terminator->getTrueBlock() == e1 ? c2Terminator->trueBlock : c2Terminator->falseBlock;
+ c2eUse.set(e2);
+ builder.setInsertAfter(c2Terminator);
+ const auto newC2Terminator = builder.emitIfElse(c2Terminator->getCondition(), c2Terminator->getTrueBlock(), c2Terminator->getFalseBlock(), b);
+ c2Terminator->removeAndDeallocate();
+ // We now have
+ // s: ...1 loop break=b next=c1
+ // c2: if x then goto e2 else goto d
+ // e2: goto b
+ // c1: if x then goto e1 else goto d
+ // e1: goto b
+ // d: goto c1
+ // b: ...2
+
+ // move the loop instruction to its own block, l
+ const auto l = builder.emitBlock();
+ l->insertAfter(e2);
+ loop->insertAtEnd(l);
+ // We now have
+ // s: ...1 no-termiator
+ // c2: if x then goto e2 else goto d
+ // e2: goto b
+ // l: loop break=b next=c1
+ // c1: if x then goto e1 else goto d
+ // e1: goto b
+ // d: goto c1
+ // b: ...2
+
+ // add a new terminator to s. A jump to c2, our outer conditional. retain
+ // any parameters the loop instruction passed to c1
+ builder.setInsertInto(s);
+ List<IRInst*> as;
+ for(UInt i = 0; i < loop->getArgCount(); ++i)
+ as.add(loop->getArg(i));
+ builder.emitBranch(c2, as.getCount(), as.getBuffer());
+ // We now have
+ // s: ...1, goto c2
+ // c2: if x then goto e2 else goto d
+ // e2: goto b
+ // l: loop break=b next=c1
+ // c1: if x then goto e1 else goto d
+ // e1: goto b
+ // d: goto c1
+ // b: ...2
+
+ // modify c2 to jump to the new loop
+ auto& c2dUse = newC2Terminator->getTrueBlock() == e2 ? newC2Terminator->falseBlock : newC2Terminator->trueBlock;
+ c2dUse.set(l);
+ // We now have
+ // s: ...1, goto c2
+ // c2: if x then goto e2 else goto l
+ // e2: goto b
+ // l: loop break=b next=c1
+ // c1: if x then goto e1 else goto d
+ // e1: goto b
+ // d: goto c1
+ // b: ...2
+
+ //
+ // Now we can modify the loop to jump to the block after the first
+ // conditional, d, as we know that it won't break out of the loop on the
+ // first iteration
+ //
+ // Beyond just retargeting the loop instruction, we need to make sure any
+ // parameters the loop instruction is passing to c1 are instead passed to
+ // 'd', and because we've added parameters to 'd' we need to forward them
+ // from c1 also which we will accomplish using a new block, e3,
+ loop->block.set(d);
+ loop->breakBlock.set(e1);
+ SLANG_ASSERT(d->getFirstParam() == nullptr);
+ c1->insertBefore(b);
+ e1->insertAfter(c1);
+ List<IRInst*> ps;
+ for(const auto p : c1->getParams())
+ ps.add(p);
+ builder.setInsertInto(d);
+ for(const auto p : ps)
+ {
+ const auto q = duplicateToParamWithDecorations(builder, cloneEnv, p);
+ // Replace all uses, except for those in c1 and e1
+ List<IRUse*> uses;
+ traverseUses(p, [&](IRUse* u){if(u->user->getParent() != c1 && u->user->getParent() != e1) uses.add(u);});
+ for(auto u : uses)
+ u->set(q);
+ }
+ const auto e3 = builder.emitBlock();
+ e3->insertAfter(c1);
+ builder.emitBranch(d, ps.getCount(), ps.getBuffer());
+ c1dUse.set(e3);
+ // We now have the desired output
+ // s: ...1, goto c2
+ // c2: if x then goto e2 else goto l
+ // e2: goto b
+ // l: loop break=e1 next=d
+ // d: goto c1
+ // c1: if x then goto e1 else goto e3
+ // e3: goto d
+ // e1: goto b
+ // b: ...2
+}
+
+bool invertLoops(IRModule* module)
+{
+ IRBuilder builder(module);
+ List<IRLoop*> toInvert;
+ overAllBlocks(module, [&](auto b){
+ if(auto loop = as<IRLoop>(b->getTerminator()))
+ {
+ if(isSuitableForInversion(loop))
+ toInvert.add(loop);
+ }
+ });
+ for(const auto loop : toInvert)
+ invertLoop(builder, loop);
+ return toInvert.getCount() > 0;
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-loop-inversion.h b/source/slang/slang-ir-loop-inversion.h
new file mode 100644
index 000000000..66c386d57
--- /dev/null
+++ b/source/slang/slang-ir-loop-inversion.h
@@ -0,0 +1,10 @@
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRInst;
+
+ bool invertLoops(IRModule* module);
+}
+
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index a0336e1c2..57a6c7c92 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -224,6 +224,22 @@ bool isOne(IRInst* inst);
void initializeScratchData(IRInst* inst);
void resetScratchDataBit(IRInst* inst, int bitIndex);
+// Run an operation over every block in a module
+template<typename F>
+static void overAllBlocks(IRModule* module, F f)
+{
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRGlobalValueWithCode>(globalInst))
+ {
+ for (auto block : func->getBlocks())
+ {
+ f(block);
+ }
+ }
+ }
+}
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 421c60d1d..1a499842d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5134,7 +5134,7 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitIfElse(
+ IRIfElse* IRBuilder::emitIfElse(
IRInst* val,
IRBlock* trueBlock,
IRBlock* falseBlock,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 74ece5bc9..1ed79fbe3 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9,6 +9,7 @@
#include "../core/slang-performance-profiler.h"
#include "slang-check.h"
+#include "slang-ir-loop-inversion.h"
#include "slang-ir.h"
#include "slang-ir-constexpr.h"
#include "slang-ir-dce.h"
@@ -9702,6 +9703,24 @@ RefPtr<IRModule> generateIRForTranslationUnit(
//
performMandatoryEarlyInlining(module);
+ // Where possible, move loop condition checks to the end of loops, and wrap
+ // the loop in an 'if(condition)'.
+ // This makes it so that if sccp can see that the loop will always loop
+ // at least once it can record this information by removing the outer
+ // conditional.
+ // This has advantages:
+ // - Uninitialized variable usage detection doesn't have to
+ // worry about a loop never being executed.
+ // - The loop condition is evaluated one fewer times.
+ // - Allegedly better performance on pipelined processors:
+ // https://en.wikipedia.org/wiki/Loop_inversion
+ //
+ // And disadvantages
+ // - If sccp is unable to eliminate the outer 'if' then we end up with
+ // duplicated code the the conditional value. Users don't tend to put
+ // huge gobs of code in the conditional expression in loops however.
+ invertLoops(module);
+
// Next, attempt to promote local variables to SSA
// temporaries and do basic simplifications.
//
diff --git a/tests/cross-compile/geometry-shader.slang.glsl b/tests/cross-compile/geometry-shader.slang.glsl
index 38dbd72ba..3b7ecca43 100644
--- a/tests/cross-compile/geometry-shader.slang.glsl
+++ b/tests/cross-compile/geometry-shader.slang.glsl
@@ -68,13 +68,6 @@ void main()
for(;;)
{
- if(ii_0 < 3)
- {}
- else
- {
- break;
- }
-
RasterVertex_0 rasterVertex_0;
rasterVertex_0.position_0 = _S10[ii_0].position_1;
rasterVertex_0.color_0 = _S10[ii_0].color_1;
@@ -82,13 +75,17 @@ void main()
RasterVertex_0 _S11 = rasterVertex_0;
_S4 = rasterVertex_0.position_0;
_S5 = _S11.color_0;
-
gl_Layer = int(_S11.id_0);
-
EmitVertex();
-
- ii_0 = ii_0 + 1;
+ int ii_1 = ii_0 + 1;
+ if(ii_1 < 3)
+ {
+ ii_0 = ii_1;
+ }
+ else
+ {
+ break;
+ }
}
-
return;
}
diff --git a/tests/cross-compile/loop-attribs.slang.hlsl b/tests/cross-compile/loop-attribs.slang.hlsl
index 5d53f51e0..2c92d16f3 100644
--- a/tests/cross-compile/loop-attribs.slang.hlsl
+++ b/tests/cross-compile/loop-attribs.slang.hlsl
@@ -1,55 +1,49 @@
#pragma pack_matrix(column_major)
+#ifdef SLANG_HLSL_ENABLE_NVAPI
+#include "nvHLSLExtns.h"
+#endif
+#pragma warning(disable: 3557)
-#line 6 "tests/cross-compile/loop-attribs.slang"
-vector<float,4> main() : SV_TARGET
+float4 main() : SV_TARGET
{
- int i_0;
- float sum_0;
- int j_0;
- float sum_1;
- i_0 = int(0);
- sum_0 = 0.00000000000000000000;
+ float _S1 = 0.0;
+ int i_0 = int(0);
+ float sum_0 = 0.0;
[loop]
for(;;)
{
-
-#line 11
- if(i_0 < int(100))
+ float sum_1 = sum_0 + float(i_0);
+ _S1 = sum_1;
+ int i_1 = i_0 + int(1);
+ if(i_1 < int(100))
{
+ i_0 = i_1;
+ sum_0 = sum_1;
}
else
{
break;
}
- float _S1 = sum_0 + (float) i_0;
-
-#line 11
- int _S2 = i_0 + (int) int(1);
- i_0 = _S2;
- sum_0 = _S1;
}
- j_0 = int(0);
- sum_1 = sum_0;
+ float _S2 = 0.0;
+ int j_0 = int(0);
+ sum_0 = _S1;
[unroll]
for(;;)
{
-
-#line 15
- if(j_0 < int(100))
+ float sum_2 = sum_0 + float(j_0);
+ _S2 = sum_2;
+ int j_1 = j_0 + int(1);
+ if(j_1 < int(100))
{
+ j_0 = j_1;
+ sum_0 = sum_2;
}
else
{
break;
}
- float _S3 = sum_1 + (float) j_0;
-
-#line 15
- int _S4 = j_0 + (int) int(1);
- j_0 = _S4;
- sum_1 = _S3;
}
+ return float4(_S2, 0.0, 0.0, 0.0);
+}
-#line 18
- return vector<float,4>(sum_1, (float) int(0), (float) int(0), (float) int(0));
-} \ No newline at end of file
diff --git a/tests/ir/loop-inversion.slang b/tests/ir/loop-inversion.slang
new file mode 100644
index 000000000..03bdcc340
--- /dev/null
+++ b/tests/ir/loop-inversion.slang
@@ -0,0 +1,107 @@
+//TEST():SIMPLE(filecheck=CHECK):-entry computeMain -stage compute -line-directive-mode none -target hlsl
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-vk -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=OUT):-cpu -shaderobj -output-using-type
+
+// Check that all the backends cope with the slightly unusual IR the loop inversion generated
+
+// OUT: 180
+
+// For all the below functions, verify that the body (adding to j and
+// incrementing i) comes before any break. This verifies that the `break` has
+// been moved to the end of the loop.
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+// A standard loop
+// CHECK-LABEL: int a_{{.*}}()
+// CHECK-NOT: break;
+// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: [[i]] + int(1);
+// CHECK: if(
+// CHECK: break;
+// CHECK: return
+int a()
+{
+ int j = 0;
+ for(int i = 0; i < 10; ++i)
+ j += i;
+ return j;
+}
+
+// A vanilla while loop
+// CHECK-LABEL: int b_{{.*}}()
+// CHECK-NOT: break;
+// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: [[i]] + int(1);
+// CHECK: if(
+// CHECK: break;
+// CHECK: return
+int b()
+{
+ int j = 0;
+ int i = 0;
+ while(i < 10)
+ {
+ j += i;
+ i++;
+ }
+ return j;
+}
+
+// A while loop with a break on the false branch
+// CHECK-LABEL: int c_{{.*}}()
+// CHECK-NOT: break;
+// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: [[i]] + int(1);
+// CHECK: if(
+// CHECK: break;
+// CHECK: return
+int c()
+{
+ int j = 0;
+ int i = 0;
+ do
+ {
+ if(i < 10)
+ {}
+ else
+ break;
+ j += i;
+ i++;
+ } while(true);
+ return j;
+}
+
+// A while loop with a break on the true branch
+// CHECK-LABEL: int d_{{.*}}()
+// CHECK-NOT: break;
+// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: [[i]] + int(1);
+// CHECK: if(
+// CHECK: break;
+// CHECK: return
+int d()
+{
+ int j = 0;
+ int i = 0;
+ do
+ {
+ if(i >= 10)
+ break;
+ else
+ {}
+ j += i;
+ i++;
+ } while(true);
+ return j;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ outputBuffer[dispatchThreadID.x] = a() + b() + c() + d();
+}