summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h11
-rw-r--r--source/slang/slang-check-modifier.cpp20
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit.cpp11
-rw-r--r--source/slang/slang-ir-dce.cpp20
-rw-r--r--source/slang/slang-ir-dominators.cpp38
-rw-r--r--source/slang/slang-ir-dominators.h33
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp2
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp520
-rw-r--r--source/slang/slang-ir-loop-unroll.h16
-rw-r--r--source/slang/slang-ir-peephole.cpp156
-rw-r--r--source/slang/slang-ir-peephole.h2
-rw-r--r--source/slang/slang-ir-sccp.cpp32
-rw-r--r--source/slang/slang-ir-sccp.h3
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp176
-rw-r--r--source/slang/slang-ir-util.cpp21
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
-rw-r--r--tests/ir/loop-unroll-0.slang19
-rw-r--r--tests/ir/loop-unroll-0.slang.expected.txt4
-rw-r--r--tests/ir/loop-unroll-1.slang25
-rw-r--r--tests/ir/loop-unroll-1.slang.expected.txt1
-rw-r--r--tests/ir/string-literal.slang.expected2
-rw-r--r--tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl16
29 files changed, 1021 insertions, 137 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 979ee4e96..5f1f62511 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -391,6 +391,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-legalize-varying-params.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-link.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-liveness.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-loop-unroll.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-com-methods.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-error-handling.h" />
@@ -577,6 +578,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-legalize-varying-params.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-liveness.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-loop-unroll.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-com-methods.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-error-handling.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index e151f6c4f..f484b92e3 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -279,6 +279,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-liveness.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-loop-unroll.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -833,6 +836,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-liveness.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-loop-unroll.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 533713016..2a8344e3a 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2815,6 +2815,9 @@ __attributeTarget(LoopStmt)
attribute_syntax [unroll(count: int = 0)] : UnrollAttribute;
__attributeTarget(LoopStmt)
+attribute_syntax [ForceUnroll(count: int = 0)] : ForceUnrollAttribute;
+
+__attributeTarget(LoopStmt)
attribute_syntax [loop] : LoopAttribute;
__attributeTarget(LoopStmt)
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 42b79ca4a..4ab295da6 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -616,8 +616,15 @@ class AttributeUsageAttribute : public Attribute
class UnrollAttribute : public Attribute
{
SLANG_AST_CLASS(UnrollAttribute)
-
- IntegerLiteralValue getCount();
+
+};
+
+// An `[unroll]` or `[unroll(count)]` attribute
+class ForceUnrollAttribute : public Attribute
+{
+ SLANG_AST_CLASS(ForceUnrollAttribute)
+
+ int32_t maxIterations = 0;
};
// An `[maxiters(count)]`
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 9f3e79978..520d85971 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -507,15 +507,29 @@ namespace Slang
// as 1 arg if nothing is specified)
SLANG_ASSERT(attr->args.getCount() == 1);
}
+ else if (auto forceUnrollAttr = as<ForceUnrollAttribute>(attr))
+ {
+ if (forceUnrollAttr->args.getCount() < 1)
+ {
+ getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1);
+ }
+ auto cint = checkConstantIntVal(attr->args[0]);
+ if (cint)
+ forceUnrollAttr->maxIterations = (int32_t)cint->value;
+ }
else if (auto maxItersAttrs = as<MaxItersAttribute>(attr))
{
- if (auto cint = checkConstantIntVal(attr->args[0]))
+ if (attr->args.getCount() < 1)
{
- maxItersAttrs->value = (int32_t) cint->value;
+ getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1);
}
else
{
- getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1);
+ auto cint = checkConstantIntVal(attr->args[0]);
+ if (cint)
+ {
+ maxItersAttrs->value = (int32_t) cint->value;
+ }
}
}
else if (auto userDefAttr = as<UserDefinedAttribute>(attr))
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d3731756a..e29d7eeac 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -567,6 +567,7 @@ DIAGNOSTIC(40010, Note, seeInterfaceUsage, "see usage of interface '$0'.")
DIAGNOSTIC(40011, Error, unconstrainedGenericParameterNotAllowedInDynamicFunction, "unconstrained generic paramter '$0' is not allowed in a dynamic function.")
+DIAGNOSTIC(40020, Error, cannotUnrollLoop, "loop does not terminate within the limited number of iterations, unrolling is aborted.")
// 41000 - IR-level validation issues
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 2ef0a5647..c49265fe7 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -34,6 +34,7 @@
#include "slang-ir-lower-optional-type.h"
#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-lower-reinterpret.h"
+#include "slang-ir-loop-unroll.h"
#include "slang-ir-metadata.h"
#include "slang-ir-optix-entry-point-uniforms.h"
#include "slang-ir-restructure.h"
@@ -377,6 +378,16 @@ Result linkAndOptimizeIR(
// since we may be missing out cases prevented by the functions that we just specialzied.
performMandatoryEarlyInlining(irModule);
+ // Unroll loops.
+ if (codeGenContext->getSink()->getErrorCount() == 0)
+ {
+ SharedIRBuilder sharedBuilder(irModule);
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+ if (!unrollLoopsInModule(&sharedBuilder, irModule, codeGenContext->getSink()))
+ return SLANG_FAIL;
+ }
+
+
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
enableIRValidationAtInsert();
changed |= processAutodiffCalls(irModule, sink);
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 337caa246..05c10b317 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -86,21 +86,13 @@ struct DeadCodeEliminationContext
{
if (!undefInst)
{
- for (auto inst : module->getModuleInst()->getChildren())
- {
- if (inst->getOp() == kIROp_undefined && inst->getDataType() && inst->getDataType()->getOp() == kIROp_VoidType)
- {
- undefInst = inst;
- break;
- }
- }
- if (!undefInst)
- {
- SharedIRBuilder builderStorage(module);
- IRBuilder builder(&builderStorage);
+ SharedIRBuilder builderStorage(module);
+ IRBuilder builder(&builderStorage);
+ if (auto firstChild = module->getModuleInst()->getFirstChild())
+ builder.setInsertBefore(firstChild);
+ else
builder.setInsertInto(module->getModuleInst());
- undefInst = builder.emitUndefined(builder.getVoidType());
- }
+ undefInst = Slang::getUndefInst(builder, module);
}
return undefInst;
}
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp
index 1ffa7ba5d..5f606092b 100644
--- a/source/slang/slang-ir-dominators.cpp
+++ b/source/slang/slang-ir-dominators.cpp
@@ -276,29 +276,21 @@ struct DepthFirstSearchContext
/// then recursively visit its (unvisited) successors, and
/// then perform any post-actions.
///
- void walk(IRBlock* block)
+ template<typename SuccessorFunc>
+ void walk(IRBlock* block, const SuccessorFunc& getSuccessors)
{
visited.Add(block);
preVisit(block);
- for(auto succ : block->getSuccessors())
+ for(auto succ : getSuccessors(block))
{
if(!visited.Contains(succ))
{
- walk(succ);
+ walk(succ, getSuccessors);
}
}
postVisit(block);
}
- /// Walk the blocks in a function (or other code-bearing value).
- void walk(IRGlobalValueWithCode* code)
- {
- auto root = code->getFirstBlock();
- if(!root)
- return;
- walk(root);
- }
-
/// Overridable action to perform on first entering a CFG node.
virtual void preVisit(IRBlock* /*block*/) {}
@@ -329,7 +321,8 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
{
PostorderComputationContext context;
context.order = &outOrder;
- context.walk(code);
+ if (code->getFirstBlock())
+ context.walk(code->getFirstBlock(), [](IRBlock* block) {return block->getSuccessors(); });
// Append unvisited blocks (unreachable blocks) to the begining of postOrder.
List<IRBlock*> prefix;
@@ -344,6 +337,25 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
outOrder = _Move(prefix);
}
+void computePostorderOnReverseCFG(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
+{
+ PostorderComputationContext context;
+ context.order = &outOrder;
+ for (auto block = code->getLastBlock(); block; block = block->getPrevBlock())
+ {
+ auto terminator = block->getTerminator();
+ switch (terminator->getOp())
+ {
+ case kIROp_Return:
+ case kIROp_MissingReturn:
+ case kIROp_Unreachable:
+ context.walk(block, [](IRBlock* b) {return b->getPredecessors(); });
+ break;
+ }
+ }
+ return;
+}
+
//
// With the preliminaries out of the way, we are ready to implement
// the dominator tree construction algorithm as described by Cooper, Harvey, and Kennedy.
diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h
index 1fb12c89e..14e84eac6 100644
--- a/source/slang/slang-ir-dominators.h
+++ b/source/slang/slang-ir-dominators.h
@@ -168,4 +168,37 @@ namespace Slang
};
RefPtr<IRDominatorTree> computeDominatorTree(IRGlobalValueWithCode* code);
+
+ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder);
+ void computePostorderOnReverseCFG(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder);
+
+ inline List<IRBlock*> getPostorder(IRGlobalValueWithCode* code)
+ {
+ List<IRBlock*> result;
+ computePostorder(code, result);
+ return result;
+ }
+
+ inline List<IRBlock*> getPostorderOnReverseCFG(IRGlobalValueWithCode* code)
+ {
+ List<IRBlock*> result;
+ computePostorderOnReverseCFG(code, result);
+ return result;
+ }
+
+ inline List<IRBlock*> getReversePostorder(IRGlobalValueWithCode* code)
+ {
+ List<IRBlock*> result;
+ computePostorder(code, result);
+ result.reverse();
+ return result;
+ }
+
+ inline List<IRBlock*> getReversePostorderOnReverseCFG(IRGlobalValueWithCode* code)
+ {
+ List<IRBlock*> result;
+ computePostorderOnReverseCFG(code, result);
+ result.reverse();
+ return result;
+ }
}
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index 35412cc07..bf3f32217 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -253,7 +253,7 @@ struct EliminateMultiLevelBreakContext
builder.emitBranch(newBreakBodyBlock);
builder.setInsertInto(newBreakBodyBlock);
auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedRegion->level));
- builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, unreachableBlock);
+ builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, breakBlock);
builder.setInsertInto(jumpToOuterBlock);
if (skippedOverRegions.Contains(skippedRegion->parent))
{
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index e627c575d..788e02c90 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -678,6 +678,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// A `[ForceInline]` decoration indicates the callee should be inlined by the Slang compiler.
INST(ForceInlineDecoration, ForceInline, 0, 0)
+ /// A `[ForceUnroll]` decoration indicates the loop should be unrolled by the Slang compiler.
+ INST(ForceUnrollDecoration, ForceUnroll, 0, 0)
+
/// A `[naturalSizeAndAlignment(s,a)]` decoration is attached to a type to indicate that is has natural size `s` and alignment `a`
INST(NaturalSizeAndAlignmentDecoration, naturalSizeAndAlignment, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2453b56a7..7bc711f97 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -531,6 +531,9 @@ IR_SIMPLE_DECORATION(UnsafeForceInlineEarlyDecoration)
IR_SIMPLE_DECORATION(ForceInlineDecoration)
+IR_SIMPLE_DECORATION(ForceUnrollDecoration)
+
+
struct IRNaturalSizeAndAlignmentDecoration : IRDecoration
{
enum { kOp = kIROp_NaturalSizeAndAlignmentDecoration };
@@ -3548,6 +3551,11 @@ public:
addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters));
}
+ void addLoopForceUnrollDecoration(IRInst* value, IntegerLiteralValue iters)
+ {
+ addDecoration(value, kIROp_ForceUnrollDecoration, getIntValue(getIntType(), iters));
+ }
+
void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
{
addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index));
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
new file mode 100644
index 000000000..725f20902
--- /dev/null
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -0,0 +1,520 @@
+#include "slang-ir-loop-unroll.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-peephole.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-util.h"
+#include "slang-ir-simplify-cfg.h"
+
+namespace Slang
+{
+
+static bool _eliminateDeadBlocks(List<IRBlock*>& blocks, IRBlock* unreachableBlock)
+{
+ if (blocks.getCount() == 0)
+ return false;
+ bool changed = false;
+ HashSet<IRBlock*> aliveBlocks;
+ aliveBlocks.Add(blocks[0]);
+ List<IRBlock*> workList;
+ workList.add(blocks[0]);
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto block = workList[i];
+ for (auto succ : block->getSuccessors())
+ {
+ if (aliveBlocks.Add(succ))
+ {
+ workList.add(succ);
+ }
+ }
+ }
+ for (auto& b : blocks)
+ {
+ if (!aliveBlocks.Contains(b))
+ {
+ if (b->hasUses())
+ {
+ b->replaceUsesWith(unreachableBlock);
+ }
+ b->removeAndDeallocate();
+ b = nullptr;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+List<IRBlock*> _collectBlocksInLoop(Dictionary<IRBlock*, int>& blockOrdering, IRLoop* loopInst)
+{
+ List<IRBlock*> loopBlocks;
+ HashSet<IRBlock*> loopBlocksSet;
+ auto addBlock = [&](IRBlock* block)
+ {
+ if (loopBlocksSet.Add(block))
+ loopBlocks.add(block);
+ };
+ auto firstBlock = as<IRBlock>(loopInst->block.get());
+ auto breakBlock = as<IRBlock>(loopInst->breakBlock.get());
+ auto breakBlockOrdering = blockOrdering[breakBlock].GetValue();
+
+ addBlock(firstBlock);
+ for (Index i = 0; i < loopBlocks.getCount(); i++)
+ {
+ auto block = loopBlocks[i];
+ for (auto succ : block->getSuccessors())
+ {
+ if (succ == breakBlock)
+ continue;
+ auto successorOrdering = blockOrdering[block].GetValue();
+ // The target must be post-dominated by the break block in order to be considered
+ // the body of the loop.
+ // Since we don't support arbitrary goto or multi-level continue, the simple
+ // ordering comparison is sufficient to serve as a post-dominance check.
+ if (successorOrdering < breakBlockOrdering)
+ addBlock(succ);
+ }
+ }
+ return loopBlocks;
+}
+
+static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst)
+{
+ static constexpr int kMaxIterationsToAttempt = 100;
+
+ auto forceUnrollDecor = loopInst->findDecoration<IRForceUnrollDecoration>();
+ if (!forceUnrollDecor)
+ return -1;
+
+ int maxIterations = kMaxIterationsToAttempt;
+ auto maxIterCount = as<IRIntLit>(forceUnrollDecor->getOperand(0));
+ if (maxIterCount && maxIterCount->getValue() != 0)
+ {
+ maxIterations =
+ Math::Clamp(maxIterations, (int)maxIterCount->getValue() + 1, kMaxIterationsToAttempt);
+ }
+ return maxIterations;
+}
+
+static void _foldAndSimplifyLoopIteration(
+ IRBuilder& builder,
+ List<IRBlock*>& clonedBlocks,
+ IRBlock* firstIterationBreakBlock,
+ IRBlock* unreachableBlock)
+{
+ for (;;)
+ {
+ // Try to simplify and evaluate each inst in `firstIterationBreakBlock` and in
+ // cloned loop body.
+ for (auto b : clonedBlocks)
+ {
+ for (auto inst : b->getChildren())
+ {
+ tryReplaceInstUsesWithSimplifiedValue(builder.getSharedBuilder(), inst);
+ }
+ }
+
+ // It is important to also evaluate `firstIterationBreakBlock` because we need to have
+ // the phi arguments for next iteration evaluated (args in the new loop inst).
+ for (auto inst : firstIterationBreakBlock->getChildren())
+ {
+ tryReplaceInstUsesWithSimplifiedValue(builder.getSharedBuilder(), inst);
+ }
+
+ // Fold conditional branches into unconditional branches if the condition is known.
+ for (auto b : clonedBlocks)
+ {
+ auto terminator = b->getTerminator();
+ if (auto cbranch = as<IRConditionalBranch>(terminator))
+ {
+ if (auto constCondition = as<IRConstant>(cbranch->getCondition()))
+ {
+ auto targetBlock = (constCondition->value.intVal != 0) ? cbranch->getTrueBlock() : cbranch->getFalseBlock();
+ builder.setInsertBefore(cbranch);
+ builder.emitBranch(targetBlock);
+ cbranch->removeAndDeallocate();
+ }
+ }
+ else if (auto switchInst = as<IRSwitch>(terminator))
+ {
+ if (auto constCondition = as<IRConstant>(switchInst->condition.get()))
+ {
+ for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ {
+ if (constCondition == switchInst->getCaseValue(i))
+ {
+ builder.setInsertBefore(switchInst);
+ builder.emitBranch(switchInst->getCaseLabel(i));
+ switchInst->removeAndDeallocate();
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ // DCE on CFG.
+ bool hasChanges = _eliminateDeadBlocks(clonedBlocks, unreachableBlock);
+ if (!hasChanges)
+ break;
+
+ // Delete removed blocks from clonedBlocks.
+ Index insertIndex = 0;
+ for (Index i = 0; i < clonedBlocks.getCount(); i++)
+ {
+ auto b = clonedBlocks[i];
+ if (b)
+ {
+ if (i != insertIndex)
+ {
+ clonedBlocks[insertIndex] = b;
+ insertIndex++;
+ }
+ }
+ }
+ clonedBlocks.setCount(insertIndex);
+ }
+}
+
+// Unroll loop up to a predefined maximum number of iterations.
+// Returns true if we can statically determine that the loop terminated within the iteration limit.
+// This operation assumes the loop does not have `continue` jumps, i.e. continueBlock == targetBlock.
+static bool _unrollLoop(
+ SharedIRBuilder* sharedBuilder,
+ IRLoop* loopInst,
+ List<IRBlock*>& blocks)
+{
+ if (blocks.getCount() == 0)
+ {
+ IRBuilder subBuilder(sharedBuilder);
+ subBuilder.setInsertBefore(loopInst);
+ subBuilder.emitBranch(loopInst->getBreakBlock());
+ loopInst->removeAndDeallocate();
+ return true;
+ }
+
+ auto maxIterations = _getLoopMaxIterationsToUnroll(loopInst);
+ if (maxIterations < 0)
+ return true;
+
+ // We assume all `continue`s are eliminated and turned into multi-level breaks
+ // before this operation.
+ SLANG_RELEASE_ASSERT(loopInst->getContinueBlock() == loopInst->getTargetBlock());
+
+ // Insert an outer breakable region so we have a break label to use as the target for
+ // any `break` jumps in the unrolled loop.
+ // Transform CFG from [..., loopInst] -> [loopTarget] ->... [originalLoopBreakBlock]
+ // Into: [..., loop] -> [outerBreakableRegionHeader, loopInst(phi_arg)] -> [(phi_param) loopTarget] -> ... ->
+ // [newLoopBreakBlock] -> [originalLoopBreakBlock/outerBreakableRegionBreakBlock]
+ // After this transform, the original break block of the loop will serve as the break block for the
+ // outer breakable region.
+
+ IRBuilder builder(sharedBuilder);
+
+ auto unreachableBlock = builder.createBlock();
+ builder.setInsertInto(unreachableBlock);
+ builder.emitUnreachable();
+ unreachableBlock->insertAtEnd(loopInst->parent->parent);
+
+ auto outerBreakableRegionHeader = builder.createBlock();
+ outerBreakableRegionHeader->insertBefore(loopInst->getTargetBlock());
+
+ auto newLoopBreakableRegionBreakBlock = builder.createBlock();
+ newLoopBreakableRegionBreakBlock->insertBefore(loopInst->getBreakBlock());
+
+ IRBlock* outerBreakableRegionBreakBlock = nullptr;
+ {
+ auto originalBreakBlock = loopInst->getBreakBlock();
+
+ // Since all `break`s in the original loop body will become jumps into
+ // `newLoopBreakableRegionBreakBlock` after unrolling, we need to make sure
+ // `newLoopBreakableRegionBreakBlock` contains exactly the same set of
+ // phi parameters as the original break block.
+
+ IRCloneEnv cloneEnv;
+ builder.setInsertInto(newLoopBreakableRegionBreakBlock);
+ List<IRInst*> newParams;
+ for (auto param : originalBreakBlock->getParams())
+ {
+ auto clonedParam = cloneInst(&cloneEnv, &builder, param);
+ newParams.add(clonedParam);
+ }
+
+ // Make the existing code in the loop body to jump into `newLoopBreakableRegionBreakBlock`
+ // instead, because we are going to make `originalBreakBlock` the new break block for
+ // the outer breakable region.
+
+ originalBreakBlock->replaceUsesWith(newLoopBreakableRegionBreakBlock);
+ builder.emitBranch(originalBreakBlock, newParams.getCount(), newParams.getBuffer());
+
+ // Use the original break block as the break block for the new outer loop.
+ outerBreakableRegionBreakBlock = originalBreakBlock;
+
+ // Use a loop inst to enter the breakable region. (This isn't a real loop).
+ builder.setInsertBefore(loopInst);
+ builder.emitLoop(
+ outerBreakableRegionHeader,
+ outerBreakableRegionBreakBlock,
+ outerBreakableRegionHeader);
+
+ // The original loop inst should now be moved into `outerBreakableRegionHeader`.
+ loopInst->insertAtEnd(outerBreakableRegionHeader);
+ }
+
+ bool loopTerminated = false;
+ for (int attempedIterations = 0; attempedIterations < maxIterations; attempedIterations++)
+ {
+ // Our task is to peel off the first iteration and put it in front of the
+ // loop.
+ // We will create a breakable region (via single iteration loop), and clone the loop body
+ // into this region. This region is defined by the header block `firstIterationLoopHeader`,
+ // and the converge block `firstIterationBreakBlock`.
+
+ IRCloneEnv cloneEnv;
+
+ auto loopTargetBlock = loopInst->getTargetBlock();
+ auto firstIterationLoopHeader = builder.createBlock();
+ firstIterationLoopHeader->insertBefore(loopTargetBlock);
+ auto firstIterationBreakBlock = builder.createBlock();
+ firstIterationBreakBlock->insertBefore(loopTargetBlock);
+
+ // Map loop params for first iteration to arguments, so that
+ // when we clone the blocks, these parameters will get replaced
+ // with the actual arguments.
+ UInt argId = 0;
+ for (auto param : loopTargetBlock->getParams())
+ {
+ cloneEnv.mapOldValToNew[param] = loopInst->getArg(argId);
+ argId++;
+ }
+
+ // While cloning the loop body, if we see any `break`s, we replace it with a branch
+ // into outerBreakableRegionBreakBlock.
+ // We replace the back edge with a jump into firstIterationBreakBlock.
+ // The original loop will start from firstIterationBreakBlock.
+ cloneEnv.mapOldValToNew[loopInst->getBreakBlock()] = outerBreakableRegionBreakBlock;
+ cloneEnv.mapOldValToNew[loopInst->getTargetBlock()] = firstIterationBreakBlock;
+
+ // Wire up the breakable region blocks.
+ // Note that the breakable region header will never have any phi params because there will never
+ // be back jumps into the header (it is a single iteration loop just for the break label).
+
+ builder.setInsertBefore(loopInst);
+ builder.emitLoop(firstIterationLoopHeader, firstIterationBreakBlock, firstIterationLoopHeader);
+
+ // The `firstIterationBreakBlock` is supposed to act as the `targetBlock` for the back-jump in the
+ // loop body. Therefore, if the original loop target block has any phi params, we will need the
+ // same set of phi params in `firstIterationBreakBlock` so keep those branches valid.
+
+ builder.setInsertInto(firstIterationBreakBlock);
+ {
+ IRCloneEnv paramCloneEnv;
+ List<IRInst*> newParams;
+ for (auto param : loopTargetBlock->getParams())
+ {
+ newParams.add(cloneInst(&paramCloneEnv, &builder, param));
+ }
+
+ // In `firstIterationBreakBlock`, we emit a new loop inst
+ // to start a loop for the remaining iterations.
+ auto newLoopInst = as<IRLoop>(builder.emitLoop(
+ loopTargetBlock,
+ loopInst->getBreakBlock(),
+ loopInst->getContinueBlock(),
+ newParams.getCount(),
+ newParams.getBuffer()));
+ loopInst->removeAndDeallocate();
+
+ // Update `loopInst` to represent the remaining loop iterations that are yet to be unrolled.
+ loopInst = newLoopInst;
+ }
+
+ // With the break region set up and wired, we can now clone the loop body into the break region.
+ // We create all the blocks first, and setup the clone mapping for the blocks so when we
+ // clone the insts later, the branch targets will automatically set to their clones.
+
+ List<IRBlock*> clonedBlocks;
+ for (auto b : blocks)
+ {
+ builder.setInsertBefore(firstIterationBreakBlock);
+ auto clonedBlock = builder.createBlock();
+ clonedBlock->insertBefore(firstIterationBreakBlock);
+ cloneEnv.mapOldValToNew.AddIfNotExists(b, clonedBlock);
+ clonedBlocks.add(clonedBlock);
+ }
+
+ // Now clone the insts inside each block.
+
+ for (Index i = 0; i < blocks.getCount(); i++)
+ {
+ auto originalBlock = blocks[i];
+ auto clonedBlock = clonedBlocks[i];
+ builder.setInsertInto(clonedBlock);
+ for (auto inst : originalBlock->getChildren())
+ {
+ cloneInst(&cloneEnv, &builder, inst);
+ }
+ }
+
+ // Wire the break region header to jump to the first loop body block.
+
+ builder.setInsertInto(firstIterationLoopHeader);
+ builder.emitBranch(clonedBlocks[0]);
+
+ // Cloned first block of the iteration should not have any params,
+ // they must have been replaced with actual arguments since we have set up
+ // the mappings for them before the clone.
+
+ SLANG_RELEASE_ASSERT(clonedBlocks[0]->getFirstParam() == nullptr);
+
+ // With all the insts for the first iteration in place, we now iteratively run
+ // SCCP and simplification for the cloned blocks, in hope that some
+ // conditional jumps can be folded into unconditional jumps.
+
+ _foldAndSimplifyLoopIteration(
+ builder, clonedBlocks, firstIterationBreakBlock, unreachableBlock);
+
+ // Now we have peeled off one iteration from the loop, we check if there are any
+ // branches into next iteration, if not, the loop terminates and we are done.
+
+ bool hasJumpsToRemainingLoop = false;
+ for (auto b : clonedBlocks)
+ {
+ for (auto succ : b->getSuccessors())
+ {
+ if (succ == firstIterationBreakBlock)
+ {
+ hasJumpsToRemainingLoop = true;
+ break;
+ }
+ }
+ }
+ if (!hasJumpsToRemainingLoop)
+ {
+ loopTerminated = true;
+
+ // Now we know the loop terminates and we have just emitted the last iteration.
+ // We need to replace all uses of the insts defined within the loop body with their
+ // clones in the last iteration.
+
+ HashSet<IRBlock*> blockSet;
+ for (auto block : blocks)
+ {
+ blockSet.Add(block);
+ }
+ for (auto block : blocks)
+ {
+ for (auto inst : block->getChildren())
+ {
+ IRInst* newInst = nullptr;
+ if (!cloneEnv.mapOldValToNew.TryGetValue(inst, newInst))
+ continue;
+ for (auto use = inst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ if (!blockSet.Contains(as<IRBlock>(use->getUser()->getParent())))
+ {
+ use->set(newInst);
+ }
+ use = nextUse;
+ }
+ }
+ }
+
+ // Now we can safely delete the original loop blocks.
+
+ for (auto block : blocks)
+ {
+ block->replaceUsesWith(unreachableBlock);
+ block->removeAndDeallocate();
+ }
+
+ // firstIterationBreakBlock is no longer reachable, so we can delete its children
+ // and turn it into an unreachable block.
+
+ firstIterationBreakBlock->removeAndDeallocateAllDecorationsAndChildren();
+ builder.setInsertInto(firstIterationBreakBlock);
+ builder.emitBranch(unreachableBlock);
+
+ break;
+ }
+ }
+
+ return loopTerminated;
+}
+
+bool unrollLoopsInFunc(
+ SharedIRBuilder* sharedBuilder,
+ IRGlobalValueWithCode* func,
+ DiagnosticSink* sink)
+{
+ List<IRLoop*> loops;
+
+ // Post order processing allows us to process inner loops first.
+ auto postOrder = getPostorder(func);
+
+ for (auto block : postOrder)
+ {
+ if (auto loop = as<IRLoop>(block->getTerminator()))
+ {
+ if (loop->findDecoration<IRForceUnrollDecoration>())
+ {
+ loops.add(loop);
+ }
+ }
+ }
+
+ if (loops.getCount() == 0)
+ return true;
+
+ for (auto loop : loops)
+ {
+ auto postOrderReverseCFG = getPostorderOnReverseCFG(func);
+ Dictionary<IRBlock*, int> blockOrdering;
+
+ for (Index i = 0; i < postOrderReverseCFG.getCount(); i++)
+ {
+ blockOrdering[postOrderReverseCFG[i]] = (int)i;
+ }
+
+ auto blocks = _collectBlocksInLoop(blockOrdering, loop);
+ auto loopLoc = loop->sourceLoc;
+ if (!_unrollLoop(sharedBuilder, loop, blocks))
+ {
+ if (sink)
+ sink->diagnose(loopLoc, Diagnostics::cannotUnrollLoop);
+ return false;
+ }
+
+ // Make sure we simplify things as much as possible before
+ // attempting to potentially unroll outer loop.
+ simplifyCFG(func);
+ }
+ return true;
+}
+
+bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink)
+{
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto genFunc = as<IRGeneric>(inst))
+ {
+ if (auto func = as<IRGlobalValueWithCode>(findGenericReturnVal(genFunc)))
+ {
+ bool result = unrollLoopsInFunc(sharedBuilder, func, sink);
+ if (!result)
+ return false;
+ }
+ }
+ else if (auto func = as<IRGlobalValueWithCode>(inst))
+ {
+ bool result = unrollLoopsInFunc(sharedBuilder, func, sink);
+ if (!result)
+ return false;
+ }
+ }
+ return true;
+}
+
+}
diff --git a/source/slang/slang-ir-loop-unroll.h b/source/slang/slang-ir-loop-unroll.h
new file mode 100644
index 000000000..a63625285
--- /dev/null
+++ b/source/slang/slang-ir-loop-unroll.h
@@ -0,0 +1,16 @@
+// slang-ir-loop-unroll.h
+#pragma once
+
+namespace Slang
+{
+ struct IRLoop;
+ struct IRGlobalValueWithCode;
+ struct SharedIRBuilder;
+ class DiagnosticSink;
+ struct IRModule;
+
+ // Return true if successfull, false if errors occurred.
+ bool unrollLoopsInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, DiagnosticSink* sink);
+
+ bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink);
+}
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 65b4adcac..4dbe6d2cb 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-peephole.h"
#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-sccp.h"
namespace Slang
{
@@ -11,6 +12,13 @@ struct PeepholeContext : InstPassBase
bool changed = false;
FloatingPointMode floatingPointMode = FloatingPointMode::Precise;
+ bool removeOldInst = true;
+
+ void maybeRemoveOldInst(IRInst* inst)
+ {
+ if (removeOldInst)
+ inst->removeAndDeallocate();
+ }
bool tryFoldElementExtractFromUpdateInst(IRInst* inst)
{
@@ -71,7 +79,7 @@ struct PeepholeContext : InstPassBase
if (remainingKeys.getCount() == 0)
{
inst->replaceUsesWith(updateInst->getElementValue());
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
return true;
}
else if (remainingKeys.getCount() > 0)
@@ -80,7 +88,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto newValue = builder.emitElementExtract(updateInst->getElementValue(), remainingKeys);
inst->replaceUsesWith(newValue);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
return true;
}
}
@@ -90,7 +98,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView());
inst->replaceUsesWith(newInst);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
return true;
}
}
@@ -105,6 +113,8 @@ struct PeepholeContext : InstPassBase
return as<IRIntLit>(inst)->getValue() == 0;
case kIROp_FloatLit:
return as<IRFloatLit>(inst)->getValue() == 0.0;
+ case kIROp_BoolLit:
+ return as<IRBoolLit>(inst)->getValue() == false;
case kIROp_MakeVector:
case kIROp_MakeVectorFromScalar:
case kIROp_MakeMatrix:
@@ -137,6 +147,8 @@ struct PeepholeContext : InstPassBase
return as<IRIntLit>(inst)->getValue() == 1;
case kIROp_FloatLit:
return as<IRFloatLit>(inst)->getValue() == 1.0;
+ case kIROp_BoolLit:
+ return as<IRBoolLit>(inst)->getValue();
case kIROp_MakeVector:
case kIROp_MakeVectorFromScalar:
case kIROp_MakeMatrix:
@@ -188,7 +200,7 @@ struct PeepholeContext : InstPassBase
}
inst->replaceUsesWith(replacement);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
return true;
};
@@ -237,6 +249,43 @@ struct PeepholeContext : InstPassBase
{
return tryReplace(inst->getOperand(0));
}
+ break;
+ case kIROp_And:
+ if (isZero(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ else if (isZero(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ else if (isOne(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ else if (isOne(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ break;
+ case kIROp_Or:
+ if (isZero(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ else if (isZero(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ else if (isOne(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ else if (isOne(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ break;
}
return false;
}
@@ -255,6 +304,7 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_MakeResultError)
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
+ maybeRemoveOldInst(inst);
changed = true;
}
break;
@@ -262,7 +312,7 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue)
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
break;
@@ -271,14 +321,14 @@ struct PeepholeContext : InstPassBase
{
IRBuilder builder(&sharedBuilderStorage);
inst->replaceUsesWith(builder.getBoolValue(true));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
else if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue)
{
IRBuilder builder(&sharedBuilderStorage);
inst->replaceUsesWith(builder.getBoolValue(false));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
break;
@@ -289,7 +339,7 @@ struct PeepholeContext : InstPassBase
if (auto intLit = as<IRIntLit>(element))
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)intLit->value.intVal));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -315,7 +365,7 @@ struct PeepholeContext : InstPassBase
if (fieldIndex != -1 && fieldIndex < (Index)inst->getOperand(0)->getOperandCount())
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)fieldIndex));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -335,14 +385,14 @@ struct PeepholeContext : InstPassBase
if ((UInt)index->getValue() < opCount)
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)index->getValue()));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
else if (inst->getOperand(0)->getOp() == kIROp_MakeArrayFromElement)
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
else
@@ -386,7 +436,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
inst->replaceUsesWith(makeArray);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -425,7 +475,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
inst->replaceUsesWith(makeStruct);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -439,7 +489,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto neq = builder.emitNeq(ptr, builder.getNullVoidPtrValue());
inst->replaceUsesWith(neq);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
break;
@@ -453,7 +503,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto trueVal = builder.getBoolValue(true);
inst->replaceUsesWith(trueVal);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -466,7 +516,7 @@ struct PeepholeContext : InstPassBase
if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType()))
{
inst->replaceUsesWith(inst->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -478,7 +528,7 @@ struct PeepholeContext : InstPassBase
if (isTypeEqual(inst->getOperand(0)->getOperand(0)->getDataType(), inst->getDataType()))
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -490,7 +540,7 @@ struct PeepholeContext : InstPassBase
if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType()))
{
inst->replaceUsesWith(inst->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -500,7 +550,7 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -513,7 +563,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto trueVal = builder.getBoolValue(true);
inst->replaceUsesWith(trueVal);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone)
@@ -522,7 +572,7 @@ struct PeepholeContext : InstPassBase
builder.setInsertBefore(inst);
auto falseVal = builder.getBoolValue(false);
inst->replaceUsesWith(falseVal);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -532,7 +582,7 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_PtrLit)
{
inst->replaceUsesWith(inst->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -542,7 +592,7 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_ExtractExistentialValue)
{
inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -578,7 +628,7 @@ struct PeepholeContext : InstPassBase
if (auto newCtor = builder.emitDefaultConstruct(inst->getFullType(), false))
{
inst->replaceUsesWith(newCtor);
- inst->removeAndDeallocate();
+ maybeRemoveOldInst(inst);
changed = true;
}
}
@@ -587,8 +637,55 @@ struct PeepholeContext : InstPassBase
case kIROp_Mul:
case kIROp_Sub:
case kIROp_Div:
+ case kIROp_And:
+ case kIROp_Or:
changed = tryOptimizeArithmeticInst(inst);
break;
+
+ case kIROp_Param:
+ {
+ auto block = as<IRBlock>(inst->parent);
+ if (!block)
+ break;
+ UInt paramIndex = 0;
+ auto prevParam = inst->getPrevInst();
+ while (as<IRParam>(prevParam))
+ {
+ prevParam = prevParam->getPrevInst();
+ paramIndex++;
+ }
+ IRInst* argValue = nullptr;
+ for (auto pred : block->getPredecessors())
+ {
+ auto terminator = as<IRUnconditionalBranch>(pred->getTerminator());
+ if (!terminator)
+ continue;
+ SLANG_ASSERT(terminator->getArgCount() > paramIndex);
+ auto arg = terminator->getArg(paramIndex);
+ if (arg->getOp() == kIROp_undefined)
+ continue;
+ if (argValue == nullptr)
+ argValue = arg;
+ else if (argValue == arg)
+ {
+ }
+ else
+ {
+ argValue = nullptr;
+ break;
+ }
+ }
+ if (argValue)
+ {
+ if (inst->hasUses())
+ {
+ inst->replaceUsesWith(argValue);
+ // Never remove param inst.
+ changed = true;
+ }
+ }
+ }
+ break;
default:
break;
}
@@ -631,4 +728,15 @@ bool peepholeOptimize(IRInst* func)
return context.processFunc(func);
}
+bool tryReplaceInstUsesWithSimplifiedValue(SharedIRBuilder* sharedBuilder, IRInst* inst)
+{
+ if (inst != tryConstantFoldInst(sharedBuilder, inst))
+ return true;
+
+ PeepholeContext context = PeepholeContext(inst->getModule());
+ context.removeOldInst = false;
+ context.processInst(inst);
+ return context.changed;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-peephole.h b/source/slang/slang-ir-peephole.h
index dc1b5527a..46ef9c80c 100644
--- a/source/slang/slang-ir-peephole.h
+++ b/source/slang/slang-ir-peephole.h
@@ -6,8 +6,10 @@ namespace Slang
struct IRModule;
struct IRCall;
struct IRInst;
+ struct SharedIRBuilder;
/// Apply peephole optimizations.
bool peepholeOptimize(IRModule* module);
bool peepholeOptimize(IRInst* func);
+ bool tryReplaceInstUsesWithSimplifiedValue(SharedIRBuilder* sharedBuilder, IRInst* inst);
}
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index c03eee695..e60d2576f 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -15,7 +15,7 @@ namespace Slang {
struct SharedSCCPContext
{
IRModule* module;
- SharedIRBuilder sharedBuilder;
+ SharedIRBuilder* sharedBuilder;
};
//
// Next we have a context struct that will be applied for each function (or other
@@ -1663,8 +1663,10 @@ bool applySparseConditionalConstantPropagation(
{
SharedSCCPContext shared;
shared.module = module;
- shared.sharedBuilder.init(module);
- shared.sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+ SharedIRBuilder sharedBuilderStorage;
+ shared.sharedBuilder = &sharedBuilderStorage;
+ sharedBuilderStorage.init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
// First we fold constants at global scope.
SCCPContext globalContext;
@@ -1683,8 +1685,10 @@ bool applySparseConditionalConstantPropagation(IRInst* func)
{
SharedSCCPContext shared;
shared.module = func->getModule();
- shared.sharedBuilder.init(shared.module);
- shared.sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+ SharedIRBuilder sharedBuilderStorage;
+ shared.sharedBuilder = &sharedBuilderStorage;
+ sharedBuilderStorage.init(shared.module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
SCCPContext globalContext;
globalContext.shared = &shared;
@@ -1694,5 +1698,23 @@ bool applySparseConditionalConstantPropagation(IRInst* func)
return applySparseConditionalConstantPropagationRec(globalContext, func);
}
+IRInst* tryConstantFoldInst(SharedIRBuilder* sharedBuilder, IRInst* inst)
+{
+ SharedSCCPContext shared;
+ shared.module = inst->getModule();
+ shared.sharedBuilder = sharedBuilder;
+ SCCPContext instContext;
+ instContext.shared = &shared;
+ instContext.code = nullptr;
+ instContext.builderStorage.init(sharedBuilder);
+ auto foldResult = instContext.interpretOverLattice(inst);
+ if (!foldResult.value)
+ {
+ return inst;
+ }
+ inst->replaceUsesWith(foldResult.value);
+ return foldResult.value;
+}
+
}
diff --git a/source/slang/slang-ir-sccp.h b/source/slang/slang-ir-sccp.h
index 23c903eeb..80b21fbbb 100644
--- a/source/slang/slang-ir-sccp.h
+++ b/source/slang/slang-ir-sccp.h
@@ -5,6 +5,7 @@ namespace Slang
{
struct IRModule;
struct IRInst;
+ struct SharedIRBuilder;
/// Apply Sparse Conditional Constant Propagation (SCCP) to a module.
///
@@ -18,5 +19,7 @@ namespace Slang
IRModule* module);
bool applySparseConditionalConstantPropagation(IRInst* func);
+
+ IRInst* tryConstantFoldInst(SharedIRBuilder* sharedBuilder, IRInst* inst);
}
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 54a1f7e08..ed6ad9089 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -96,6 +96,52 @@ static bool isTrivialSingleIterationLoop(
return true;
}
+static bool removeDeadBlocks(IRGlobalValueWithCode* func)
+{
+ bool changed = false;
+ List<IRBlock*> workList;
+ auto firstBlock = func->getFirstBlock();
+ if (!firstBlock)
+ return false;
+
+ for (auto block = firstBlock->getNextBlock(); block; block = block->getNextBlock())
+ {
+ workList.add(block);
+ }
+
+ HashSet<IRBlock*> workListSet;
+ List<IRBlock*> nextWorkList;
+ for (;;)
+ {
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto block = workList[i];
+ if (!block->hasUses() && as<IRTerminatorInst>(block->getFirstInst()))
+ {
+ for (auto succ : block->getSuccessors())
+ {
+ if (workListSet.Add(succ))
+ {
+ nextWorkList.add(succ);
+ }
+ }
+ block->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ if (nextWorkList.getCount())
+ {
+ workList = _Move(nextWorkList);
+ workListSet.Clear();
+ }
+ else
+ {
+ break;
+ }
+ }
+ return changed;
+}
+
static bool processFunc(IRGlobalValueWithCode* func)
{
auto firstBlock = func->getFirstBlock();
@@ -109,84 +155,90 @@ static bool processFunc(IRGlobalValueWithCode* func)
IRBuilder builder(&sharedBuilder);
bool changed = false;
-
- List<IRBlock*> workList;
- HashSet<IRBlock*> processedBlock;
- workList.add(func->getFirstBlock());
- while (workList.getCount())
+ for (;;)
{
- auto block = workList.getFirst();
- workList.fastRemoveAt(0);
- while (block)
+ List<IRBlock*> workList;
+ HashSet<IRBlock*> processedBlock;
+ workList.add(func->getFirstBlock());
+ while (workList.getCount())
{
- if (auto loop = as<IRLoop>(block->getTerminator()))
+ auto block = workList.getFirst();
+ workList.fastRemoveAt(0);
+ while (block)
{
- // If continue block is unreachable, remove it.
- auto continueBlock = loop->getContinueBlock();
- if (continueBlock && !continueBlock->hasMoreThanOneUse())
+ if (auto loop = as<IRLoop>(block->getTerminator()))
{
- loop->continueBlock.set(loop->getTargetBlock());
- continueBlock->removeAndDeallocate();
+ // If continue block is unreachable, remove it.
+ auto continueBlock = loop->getContinueBlock();
+ if (continueBlock && !continueBlock->hasMoreThanOneUse())
+ {
+ loop->continueBlock.set(loop->getTargetBlock());
+ continueBlock->removeAndDeallocate();
+ }
+
+ // If there isn't any actual back jumps into loop target and there is a trivial
+ // break at the end of the loop, we can remove the header and turn it into
+ // a normal branch.
+ auto targetBlock = loop->getTargetBlock();
+ if (isTrivialSingleIterationLoop(func, loop, simplificationContext))
+ {
+ builder.setInsertBefore(loop);
+ List<IRInst*> args;
+ for (UInt i = 0; i < loop->getArgCount(); i++)
+ {
+ args.add(loop->getArg(i));
+ }
+ builder.emitBranch(targetBlock, args.getCount(), args.getBuffer());
+ loop->removeAndDeallocate();
+ }
}
- // If there isn't any actual back jumps into loop target and there is a trivial
- // break at the end of the loop, we can remove the header and turn it into
- // a normal branch.
- auto targetBlock = loop->getTargetBlock();
- if (isTrivialSingleIterationLoop(func, loop, simplificationContext))
+ // If `block` does not end with an unconditional branch, bail.
+ if (block->getTerminator()->getOp() != kIROp_unconditionalBranch)
+ break;
+ auto branch = as<IRUnconditionalBranch>(block->getTerminator());
+ auto successor = branch->getTargetBlock();
+ // Only perform the merge if `block` is the only predecessor of `successor`.
+ // We also need to make sure not to merge a block that serves as the
+ // merge point in CFG. Such blocks will have more than one use.
+ if (successor->hasMoreThanOneUse())
+ break;
+ if (block->hasMoreThanOneUse())
+ break;
+ changed = true;
+ Index paramIndex = 0;
+ auto inst = successor->getFirstDecorationOrChild();
+ while (inst)
{
- builder.setInsertBefore(loop);
- List<IRInst*> args;
- for (UInt i = 0; i < loop->getArgCount(); i++)
+ auto next = inst->getNextInst();
+ if (inst->getOp() == kIROp_Param)
+ {
+ inst->replaceUsesWith(branch->getArg(paramIndex));
+ paramIndex++;
+ }
+ else
{
- args.add(loop->getArg(i));
+ inst->removeFromParent();
+ inst->insertAtEnd(block);
}
- builder.emitBranch(targetBlock, args.getCount(), args.getBuffer());
- loop->removeAndDeallocate();
+ inst = next;
}
+ branch->removeAndDeallocate();
+ assert(!successor->hasUses());
+ successor->removeAndDeallocate();
}
-
- // If `block` does not end with an unconditional branch, bail.
- if (block->getTerminator()->getOp() != kIROp_unconditionalBranch)
- break;
- auto branch = as<IRUnconditionalBranch>(block->getTerminator());
- auto successor = branch->getTargetBlock();
- // Only perform the merge if `block` is the only predecessor of `successor`.
- // We also need to make sure not to merge a block that serves as the
- // merge point in CFG. Such blocks will have more than one use.
- if (successor->hasMoreThanOneUse())
- break;
- if (block->hasMoreThanOneUse())
- break;
- changed = true;
- Index paramIndex = 0;
- auto inst = successor->getFirstDecorationOrChild();
- while (inst)
+ for (auto successor : block->getSuccessors())
{
- auto next = inst->getNextInst();
- if (inst->getOp() == kIROp_Param)
+ if (processedBlock.Add(successor))
{
- inst->replaceUsesWith(branch->getArg(paramIndex));
- paramIndex++;
+ workList.add(successor);
}
- else
- {
- inst->removeFromParent();
- inst->insertAtEnd(block);
- }
- inst = next;
- }
- branch->removeAndDeallocate();
- assert(!successor->hasUses());
- successor->removeAndDeallocate();
- }
- for (auto successor : block->getSuccessors())
- {
- if (processedBlock.Add(successor))
- {
- workList.add(successor);
}
}
+ bool blocksRemoved = removeDeadBlocks(func);
+ changed |= blocksRemoved;
+ if (!blocksRemoved)
+ break;
}
return changed;
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 942c8f2f8..0448eb649 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -443,6 +443,27 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I
return false;
}
+IRInst* getUndefInst(IRBuilder builder, IRModule* module)
+{
+ IRInst* undefInst = nullptr;
+
+ for (auto inst : module->getModuleInst()->getChildren())
+ {
+ if (inst->getOp() == kIROp_undefined && inst->getDataType() && inst->getDataType()->getOp() == kIROp_VoidType)
+ {
+ undefInst = inst;
+ break;
+ }
+ }
+ if (!undefInst)
+ {
+ auto voidType = builder.getVoidType();
+ builder.setInsertAfter(voidType);
+ undefInst = builder.emitUndefined(voidType);
+ }
+ return undefInst;
+}
+
bool isPureFunctionalCall(IRCall* call)
{
auto callee = getResolvedInstForDecorations(call->getCallee());
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index c067bde44..efd38f7b7 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -167,6 +167,9 @@ bool isPureFunctionalCall(IRCall* callInst);
bool isPtrLikeOrHandleType(IRInst* type);
bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, IRInst* addr);
+
+IRInst* getUndefInst(IRBuilder builder, IRModule* module);
+
}
#endif
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 74f06557d..9cf6aedf4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4849,6 +4849,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
{
getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value);
}
+ else if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>())
+ {
+ getBuilder()->addLoopForceUnrollDecoration(inst, forceUnrollAttr->maxIterations);
+ }
// TODO: handle other cases here
}
diff --git a/tests/ir/loop-unroll-0.slang b/tests/ir/loop-unroll-0.slang
new file mode 100644
index 000000000..d97acac0e
--- /dev/null
+++ b/tests/ir/loop-unroll-0.slang
@@ -0,0 +1,19 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int i = 0;
+ [ForceUnroll]
+ while (i < 5 && dispatchThreadID.x == 0)
+ {
+ if (i >= 3)
+ break;
+ outputBuffer[i] = i;
+ i++;
+ }
+}
diff --git a/tests/ir/loop-unroll-0.slang.expected.txt b/tests/ir/loop-unroll-0.slang.expected.txt
new file mode 100644
index 000000000..d2a5eee71
--- /dev/null
+++ b/tests/ir/loop-unroll-0.slang.expected.txt
@@ -0,0 +1,4 @@
+0
+1
+2
+0
diff --git a/tests/ir/loop-unroll-1.slang b/tests/ir/loop-unroll-1.slang
new file mode 100644
index 000000000..562cd3713
--- /dev/null
+++ b/tests/ir/loop-unroll-1.slang
@@ -0,0 +1,25 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int i = 0;
+ int sum = 0;
+ [ForceUnroll]
+ while (i < 2)
+ {
+ int j = 1;
+ [ForceUnroll(2)]
+ while (j < 3)
+ {
+ sum += (i+j);
+ j++;
+ }
+ i++;
+ }
+ outputBuffer[0] = sum;
+}
diff --git a/tests/ir/loop-unroll-1.slang.expected.txt b/tests/ir/loop-unroll-1.slang.expected.txt
new file mode 100644
index 000000000..45a4fb75d
--- /dev/null
+++ b/tests/ir/loop-unroll-1.slang.expected.txt
@@ -0,0 +1 @@
+8
diff --git a/tests/ir/string-literal.slang.expected b/tests/ir/string-literal.slang.expected
index 48836cae6..6f414b0da 100644
--- a/tests/ir/string-literal.slang.expected
+++ b/tests/ir/string-literal.slang.expected
@@ -1,6 +1,7 @@
result code = 0
standard error = {
### LOWER-TO-IR:
+undefined
[entryPoint(6 : Int, "main", "string-literal")]
[numThreads(1 : Int, 1 : Int, 1 : Int)]
[export("_S3tu04mainp1puV")]
@@ -13,7 +14,6 @@ block %1(
return_val(void_constant)
}
global_hashed_string_literals("Hello \t\n\0x083 World")
-undefined
###
}
standard output = {
diff --git a/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl b/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl
index 04ef5a6fe..0364d2513 100644
--- a/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl
+++ b/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl
@@ -90,8 +90,6 @@ void main()
}
uint _S3 = (rayQueryGetIntersectionTypeEXT((query_0), false));
- MyProceduralHitAttrs_0 committedProceduralAttrs_1;
-
switch(_S3)
{
case 1U:
@@ -118,13 +116,13 @@ void main()
{
}
- committedProceduralAttrs_1 = _S6;
+ committedProceduralAttrs_0 = _S6;
}
else
{
- committedProceduralAttrs_1 = committedProceduralAttrs_0;
+ committedProceduralAttrs_0 = committedProceduralAttrs_0;
}
@@ -132,7 +130,7 @@ void main()
else
{
- committedProceduralAttrs_1 = committedProceduralAttrs_0;
+ committedProceduralAttrs_0 = committedProceduralAttrs_0;
}
@@ -157,21 +155,15 @@ void main()
else
{
}
-
- committedProceduralAttrs_1 = committedProceduralAttrs_0;
-
break;
}
default:
{
-
- committedProceduralAttrs_1 = committedProceduralAttrs_0;
-
break;
}
}
- committedProceduralAttrs_0 = committedProceduralAttrs_1;
+ committedProceduralAttrs_0 = committedProceduralAttrs_0;
}