diff options
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 12 | ||||
| -rw-r--r-- | tests/bugs/simplify-if-else.slang | 26 |
2 files changed, 36 insertions, 2 deletions
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 90d30dcc7..68d79617a 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -490,11 +490,19 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) bool isFalseBranchTrivial = false; if (isTrivialIfElse(ifElseInst, isTrueBranchTrivial, isFalseBranchTrivial)) { - // If both branches of `if-else` are trivial jumps into after block, + // If either branch of `if-else` is a trivial jump into after block, // we can get rid of the entire conditional branch and replace it // with a jump into the after block. - if (auto termInst = as<IRUnconditionalBranch>(ifElseInst->getTrueBlock()->getTerminator())) + IRUnconditionalBranch* termInst = + as<IRUnconditionalBranch>(ifElseInst->getTrueBlock()->getTerminator()); + if (!termInst || (termInst->getTargetBlock() != ifElseInst->getAfterBlock())) { + termInst = as<IRUnconditionalBranch>(ifElseInst->getFalseBlock()->getTerminator()); + } + + if (termInst) + { + SLANG_ASSERT(termInst->getTargetBlock() == ifElseInst->getAfterBlock()); List<IRInst*> args; for (UInt i = 0; i < termInst->getArgCount(); i++) args.add(termInst->getArg(i)); diff --git a/tests/bugs/simplify-if-else.slang b/tests/bugs/simplify-if-else.slang new file mode 100644 index 000000000..8719a1599 --- /dev/null +++ b/tests/bugs/simplify-if-else.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target hlsl +//CHECK: computeMain + +//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + vector<float32_t, 4> vvv = vector<float32_t, 4>(0); + float32_t ret = 0.0f; + if (vvv.y < 1.0f) + { + ret = 1.0f; + } + else + { + if (vvv.y > 1.0f && outputBuffer[3] == 3) + { + ret = 0.0f; + } else { + if (true) {} + } + } + outputBuffer[int(dispatchThreadID.x)] = int(ret); +} |
