summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-27 16:41:31 -0800
committerGitHub <noreply@github.com>2023-01-27 16:41:31 -0800
commit4a66e9729175a89833e5db784bb64e6a7f60cdf2 (patch)
tree6a3cb0da3a6682ac0f8b06e66cb8e5fcd6dff279
parent93a6b6119b6b65c4f6b00ca12d745e21b679c82f (diff)
Register allocation during phi elimination. (#2613)
* Register allocation during phi elimination. * Enhance the test case. * Cleanup line breaks in test case. * remove unncessary line break changes. * More cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/slang-emit.cpp26
-rw-r--r--source/slang/slang-ir-eliminate-phis.cpp244
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir-ssa-register-allocate.cpp342
-rw-r--r--source/slang/slang-ir-ssa-register-allocate.h24
-rw-r--r--source/slang/slang-ir.cpp63
-rw-r--r--tests/experimental/liveness/liveness-3.slang.expected26
-rw-r--r--tests/experimental/liveness/liveness-5.slang.expected21
-rw-r--r--tests/experimental/liveness/liveness-6.slang.expected21
-rw-r--r--tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected12
-rw-r--r--tests/ir/ssa-reg-alloc.slang68
-rw-r--r--tests/ir/ssa-reg-alloc.slang.expected.txt4
-rw-r--r--tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl14
15 files changed, 788 insertions, 93 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index b1972fc13..9c9a3e4be 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -422,6 +422,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-ssa-register-allocate.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa-simplification.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-string-hash.h" />
@@ -603,6 +604,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-ssa-register-allocate.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa-simplification.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-string-hash.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 4c9e136e9..34d1b2838 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -372,6 +372,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-ssa-register-allocate.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa-simplification.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -911,6 +914,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-ssa-register-allocate.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa-simplification.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 00fa5d3cb..3d923179c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -811,10 +811,17 @@ Result linkAndOptimizeIR(
lowerBitCast(targetRequest, irModule);
simplifyIR(irModule);
+ eliminateMultiLevelBreak(irModule);
+
+ // As a late step, we need to take the SSA-form IR and move things *out*
+ // of SSA form, by eliminating all "phi nodes" (block parameters) and
+ // introducing explicit temporaries instead. Doing this at the IR level
+ // means that subsequent emit logic doesn't need to contend with the
+ // complexities of blocks with parameters.
+ //
{
// Get the liveness mode.
const LivenessMode livenessMode = codeGenContext->shouldTrackLiveness() ? LivenessMode::Enabled : LivenessMode::Disabled;
-
//
// Downstream targets may benefit from having live-range information for
// local variables, and our IR currently encodes a reasonably good version
@@ -830,22 +837,11 @@ Result linkAndOptimizeIR(
LivenessUtil::addVariableRangeStarts(irModule, livenessMode);
}
- eliminateMultiLevelBreak(irModule);
-
- // As a late step, we need to take the SSA-form IR and move things *out*
- // of SSA form, by eliminating all "phi nodes" (block parameters) and
- // introducing explicit temporaries instead. Doing this at the IR level
- // means that subsequent emit logic doesn't need to contend with the
- // complexities of blocks with parameters.
- //
-
- {
- // We only want to accumulate locations if liveness tracking is enabled.
- eliminatePhis(livenessMode, irModule);
+ // We only want to accumulate locations if liveness tracking is enabled.
+ eliminatePhis(livenessMode, irModule);
#if 0
- dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED");
+ dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED");
#endif
- }
// If liveness is enabled add liveness ranges based on the accumulated liveness locations
diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp
index 07d0e7374..818953152 100644
--- a/source/slang/slang-ir-eliminate-phis.cpp
+++ b/source/slang/slang-ir-eliminate-phis.cpp
@@ -1,5 +1,6 @@
// slang-ir-eliminate-phis.cpp
#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-ssa-register-allocate.h"
// This file implements a pass to take code in the Slang IR out out SSA form
// by eliminating all "phi nodes."
@@ -107,8 +108,13 @@ struct PhiEliminationContext
//
void eliminatePhisInFunc(IRGlobalValueWithCode* func)
{
+ // Perform initialization and register allocation
+ // for Phi parameters and other insts that benefit from
+ // converting to memory.
initializePerFuncState(func);
+ // First, we eliminate all the phi instructions (params)
+ // using the result of register allocation.
// The first block in a function is always the entry block,
// and its parameters are different than those of the other blocks;
// they represent the parameters of the *function*. We therefore
@@ -124,6 +130,71 @@ struct PhiEliminationContext
eliminatePhisInBlock(block);
}
+
+ // Next, convert the definition of other ordinary insts to assignments.
+ convertInstDefToRegisterAssignment();
+
+ // Finally, replaces the uses of other ordinary insts to loads from registers.
+ replaceInstUseWithRegisterLoad();
+ }
+
+ void convertInstDefToRegisterAssignment()
+ {
+ IRBuilder builder(m_sharedBuilder);
+
+ for (auto instAlloc : m_registerAllocation.mapInstToRegister)
+ {
+ auto inst = instAlloc.Key;
+ IRInst* registerVar = nullptr;
+ m_mapRegToTempVar.TryGetValue(instAlloc.Value, registerVar);
+ SLANG_RELEASE_ASSERT(registerVar);
+
+ switch (inst->getOp())
+ {
+ case kIROp_Param:
+ continue;
+ case kIROp_UpdateElement:
+ {
+ auto updateInst = as<IRUpdateElement>(inst);
+ builder.setInsertBefore(updateInst);
+ RefPtr<RegisterInfo> oldReg;
+ m_registerAllocation.mapInstToRegister.TryGetValue(updateInst->getOldValue(), oldReg);
+ // If the original value is not assigned to the same register as this inst,
+ // we need to insert a copy.
+ if (instAlloc.Value != oldReg)
+ {
+ builder.emitStore(registerVar, updateInst->getOldValue());
+ }
+ // Perform update on the register var.
+ auto elementAddr = builder.emitElementAddress(registerVar, updateInst->getAccessChain().getArrayView());
+ builder.emitStore(elementAddr, updateInst->getElementValue());
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ void replaceInstUseWithRegisterLoad()
+ {
+ IRBuilder builder(m_sharedBuilder);
+
+ for (auto instAlloc : m_registerAllocation.mapInstToRegister)
+ {
+ auto inst = instAlloc.Key;
+ IRInst* registerVar = nullptr;
+ m_mapRegToTempVar.TryGetValue(instAlloc.Value, registerVar);
+ SLANG_RELEASE_ASSERT(registerVar);
+ while (auto use = inst->firstUse)
+ {
+ auto user = use->getUser();
+ m_builder.setInsertBefore(user);
+ auto newVal = m_builder.emitLoad(registerVar);
+ use->set(newVal);
+ }
+ inst->removeAndDeallocate();
+ }
}
// In order to facilitate breaking things down into subroutines, we use a
@@ -132,6 +203,8 @@ struct PhiEliminationContext
//
IRGlobalValueWithCode* m_func = nullptr;
RefPtr<IRDominatorTree> m_dominatorTree;
+ RegisterAllocationResult m_registerAllocation;
+ Dictionary<RegisterInfo*, IRInst*> m_mapRegToTempVar;
// Because we use the same `PhiEliminationContext` to process all of
// the functions in a module, we need to set up these per-function
@@ -141,6 +214,83 @@ struct PhiEliminationContext
{
m_func = func;
m_dominatorTree = nullptr;
+ m_registerAllocation = allocateRegistersForFunc(func, m_dominatorTree);
+ m_mapRegToTempVar = createTempVarForInsts(func);
+ }
+
+ Dictionary<RegisterInfo*, IRInst*> createTempVarForInsts(IRGlobalValueWithCode* func)
+ {
+ Dictionary<RegisterInfo*, IRInst*> mapRegToVar;
+ for (auto& regList : m_registerAllocation.mapTypeToRegisterList)
+ {
+ auto type = regList.Key;
+ for (auto reg : regList.Value)
+ {
+ // Find the common dominator for all the insts, and determine the latest insertion
+ // point of the tempVar inst.
+ IRBlock* dom = nullptr;
+ IRInst* insertionPoint = nullptr;
+ for (auto inst : reg->insts)
+ {
+ // Determine where the temp register var should be inserted if
+ // it represents only `inst`.
+ IRBlock* thisDom = as<IRBlock>(inst->getParent());
+ IRInst* thisInsertionPoint = inst;
+ if (inst->getOp() == kIROp_Param)
+ {
+ thisDom = getDominatorTree()->getImmediateDominator(thisDom);
+ thisInsertionPoint = thisDom->getTerminator();
+ }
+
+ // Push the insertionPoint early enough to dominate `thisInsertionPoint`.
+ if (dom == nullptr)
+ {
+ dom = thisDom;
+ insertionPoint = thisInsertionPoint;
+ }
+ else
+ {
+ auto domTree = getDominatorTree();
+ while (!domTree->dominates(dom, thisDom) && dom != func->getFirstBlock())
+ {
+ dom = domTree->getImmediateDominator(dom);
+ insertionPoint = dom->getTerminator();
+ }
+ }
+ // Move insertion point to before thisInsertionPoint.
+ if (dom == thisDom)
+ {
+ bool isInsertionPointBeforeCurrentInst = false;
+ for (auto current = insertionPoint; current; current = current->getNextInst())
+ {
+ if (current == thisInsertionPoint)
+ {
+ isInsertionPointBeforeCurrentInst = true;
+ break;
+ }
+ }
+ if (!isInsertionPointBeforeCurrentInst)
+ insertionPoint = thisInsertionPoint;
+ }
+ }
+ SLANG_ASSERT(dom);
+ SLANG_ASSERT(insertionPoint && insertionPoint->getParent() == dom);
+ m_builder.setInsertBefore(insertionPoint);
+
+ // Note that the `emitVar` operation expects to be passed the
+ // type *stored* in the variable, but the IR `var` instruction
+ // itself will have a pointer type. Thus if `param` has type
+ // `T`, then `temp` will have type `T*`.
+ //
+ auto temp = m_builder.emitVar(type);
+ for (auto inst : reg->insts)
+ {
+ inst->transferDecorationsTo(temp);
+ }
+ mapRegToVar[reg] = temp;
+ }
+ }
+ return mapRegToVar;
}
// The dominator tree for the function is computed on demand and
@@ -177,7 +327,7 @@ struct PhiEliminationContext
// 1. Create a temporary corresponding to each of the
// parameters of `block`.
//
- createTempsForParams(block);
+ collectPhiInfoForParams(block);
//
// 2. For each predecessor of `block`, eliminate the arguments
// it passes, by assigning them to the temporaries.
@@ -216,7 +366,7 @@ struct PhiEliminationContext
//
Dictionary<IRInst*, Index> mapParamToIndex;
- void createTempsForParams(IRBlock* block)
+ void collectPhiInfoForParams(IRBlock* block)
{
// The temporaries used to replace the parameters of `block`
// must be read-able any where that the parameters were accessible.
@@ -277,18 +427,30 @@ struct PhiEliminationContext
Index paramIndex = paramCounter++;
mapParamToIndex.Add(param, paramIndex);
- // Note that the `emitVar` operation expects to be passed the
- // type *stored* in the variable, but the IR `var` instruction
- // itself will have a pointer type. Thus if `param` has type
- // `T`, then `temp` will have type `T*`.
- //
- auto temp = m_builder.emitVar(param->getDataType());
- //
- // Because we will be eliminating the paramter, we can transfer
- // any decorations that were added to it (notably any name hint)
- // to the temporary that will replace it.
- //
- param->transferDecorationsTo(temp);
+ IRInst* temp = nullptr;
+
+ // Have we already allocated a register for this inst?
+ // If so we use the var for that register.
+ if (auto registerInfo = m_registerAllocation.mapInstToRegister.TryGetValue(param))
+ {
+ m_mapRegToTempVar.TryGetValue(registerInfo->get(), temp);
+ }
+
+ if (!temp)
+ {
+ // Note that the `emitVar` operation expects to be passed the
+ // type *stored* in the variable, but the IR `var` instruction
+ // itself will have a pointer type. Thus if `param` has type
+ // `T`, then `temp` will have type `T*`.
+ //
+ temp = m_builder.emitVar(param->getDataType());
+ //
+ // Because we will be eliminating the paramter, we can transfer
+ // any decorations that were added to it (notably any name hint)
+ // to the temporary that will replace it.
+ //
+ param->transferDecorationsTo(temp);
+ }
// The other main auxilliary sxtructure is used to track
// both per-parameter information (which we can fill in
@@ -300,7 +462,7 @@ struct PhiEliminationContext
PhiInfo phiInfo;
auto& paramInfo = phiInfo.param;
paramInfo.param = param;
- paramInfo.temp = temp;
+ paramInfo.temp = cast<IRVar>(temp);
phiInfos.add(phiInfo);
}
}
@@ -758,6 +920,45 @@ struct PhiEliminationContext
oldBranch->removeAndDeallocate();
}
+ bool canLoadBeFoldedAtInst(IRInst* load, IRInst* useSite)
+ {
+ if (load->getParent() != useSite->getParent())
+ return false;
+
+ auto addr = load->getOperand(0);
+ switch (addr->getOp())
+ {
+ case kIROp_Var:
+ case kIROp_Param:
+ break;
+ default:
+ return false;
+ }
+ for (auto inst = load; inst; inst = inst->getNextInst())
+ {
+ if (inst == useSite)
+ {
+ return true;
+ }
+ switch (inst->getOp())
+ {
+ case kIROp_Store:
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ if (inst->getOperand(0) == addr)
+ return false;
+ break;
+ default:
+ if (inst->mightHaveSideEffects())
+ return false;
+ break;
+ }
+ }
+ // Should never reach here if useSite appears after inst.
+ // Return false to be safe.
+ return false;
+ }
+
// The most subtle bit of logic, which relies on the data structures
// we have built so far, is the way we attempt to perform assignments
// that have become ready.
@@ -809,7 +1010,18 @@ struct PhiEliminationContext
//
if ((*srcArg.currentValPtr)->getOp() != kIROp_undefined)
{
- m_builder.emitStore(dstParam.temp, *srcArg.currentValPtr);
+ // If we are trying to emit a store directly after a load from the same var,
+ // skip the store.
+ SLANG_ASSERT(m_builder.getInsertLoc().getMode() == IRInsertLoc::Mode::Before);
+ auto srcLoad = as<IRLoad>(*srcArg.currentValPtr);
+ if (srcLoad && srcLoad->getOperand(0) == dstParam.temp &&
+ canLoadBeFoldedAtInst(srcLoad, m_builder.getInsertLoc().getInst()))
+ {
+ }
+ else
+ {
+ m_builder.emitStore(dstParam.temp, *srcArg.currentValPtr);
+ }
}
//
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 405df4073..8b30a02dd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3161,6 +3161,14 @@ public:
IRInst* basePtr,
IRInst* index);
+ IRInst* emitElementAddress(
+ IRInst* basePtr,
+ IRInst* index);
+
+ IRInst* emitElementAddress(
+ IRInst* basePtr,
+ const ArrayView<IRInst*>& accessChain);
+
IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement);
IRInst* emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement);
diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp
new file mode 100644
index 000000000..2f06797fa
--- /dev/null
+++ b/source/slang/slang-ir-ssa-register-allocate.cpp
@@ -0,0 +1,342 @@
+// slang-ir-ssa-register-allocate.cpp
+#include "slang-ir-ssa-register-allocate.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-dominators.h"
+
+
+namespace Slang {
+
+// A context for computing and caching reachability between blocks on the CFG.
+struct ReachabilityContext
+{
+ struct BlockPair
+ {
+ IRBlock* first;
+ IRBlock* second;
+ HashCode getHashCode()
+ {
+ Hasher h;
+ h.hashValue(first);
+ h.hashValue(second);
+ return h.getResult();
+ }
+ bool operator == (const BlockPair& other)
+ {
+ return first == other.first && second == other.second;
+ }
+ };
+ Dictionary<BlockPair, bool> reachabilityResults;
+
+ List<IRBlock*> workList;
+ HashSet<IRBlock*> reachableBlocks;
+
+ // Computes whether block1 can reach block2.
+ // A block is considered not reachable from itself unless there is a backedge in the CFG.
+ bool computeReachability(IRBlock* block1, IRBlock* block2)
+ {
+ workList.clear();
+ reachableBlocks.Clear();
+ workList.add(block1);
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto src = workList[i];
+ for (auto successor : src->getSuccessors())
+ {
+ if (successor == block2)
+ return true;
+ if (reachableBlocks.Add(successor))
+ workList.add(successor);
+ }
+ }
+ return false;
+ }
+
+ bool isBlockReachable(IRBlock* from, IRBlock* to)
+ {
+ BlockPair pair;
+ pair.first = from;
+ pair.second = to;
+ bool result = false;
+ if (reachabilityResults.TryGetValue(pair, result))
+ return result;
+ result = computeReachability(from, to);
+ reachabilityResults[pair] = result;
+ return result;
+ }
+
+ bool isInstReachable(IRInst* inst1, IRInst* inst2)
+ {
+ if (isBlockReachable(as<IRBlock>(inst1->getParent()), as<IRBlock>(inst2->getParent())))
+ return true;
+
+ // If the parent blocks are not reachable, but inst1 and inst2 are in the same block,
+ // we test if inst2 appears after inst1.
+ if (inst1->getParent() == inst2->getParent())
+ {
+ for (auto inst = inst1->getNextInst(); inst; inst = inst->getNextInst())
+ {
+ if (inst == inst2)
+ return true;
+ }
+ }
+
+ return false;
+ }
+};
+
+struct RegisterAllocateContext
+{
+ OrderedDictionary<IRType*, List<RefPtr<RegisterInfo>>> mapTypeToRegisterList;
+ List<RefPtr<RegisterInfo>>& getRegisterListForType(IRType* type)
+ {
+ if (auto list = mapTypeToRegisterList.TryGetValue(type))
+ {
+ return *list;
+ }
+ mapTypeToRegisterList[type] = List<RefPtr<RegisterInfo>>();
+ return mapTypeToRegisterList[type].GetValue();
+ }
+
+ void assignInstToNewRegister(List<RefPtr<RegisterInfo>>& regList, IRInst* inst)
+ {
+ auto reg = new RegisterInfo();
+ reg->type = inst->getFullType();
+ reg->insts.add(inst);
+ regList.add(reg);
+ }
+
+ bool areInstsPreferredToBeCoalescedImpl(IRInst* inst0, IRInst* inst1)
+ {
+ switch (inst1->getOp())
+ {
+ case kIROp_UpdateElement:
+ if (inst0 == inst1->getOperand(0))
+ return true;
+ break;
+ default:
+ break;
+ }
+
+ // If isnts have the same name, prefer to coalesce them.
+ auto name1 = inst0->findDecoration<IRNameHintDecoration>();
+ auto name2 = inst1->findDecoration<IRNameHintDecoration>();
+ if (name1 && name2 && name1->getName() == name2->getName())
+ return true;
+
+ return false;
+ }
+ bool areInstsPreferredToBeCoalesced(IRInst* inst0, IRInst* inst1)
+ {
+ return areInstsPreferredToBeCoalescedImpl(inst0, inst1) ||
+ areInstsPreferredToBeCoalescedImpl(inst1, inst0);
+ }
+
+ bool isRegisterPreferred(RegisterInfo* existingRegister, RegisterInfo* newRegister, IRInst* inst)
+ {
+ int preferredCountExistingReg = 0;
+ int preferredCountNewReg = 0;
+ for (auto existingInst : existingRegister->insts)
+ {
+ if (areInstsPreferredToBeCoalesced(existingInst, inst))
+ preferredCountExistingReg++;
+ }
+ for (auto existingInst : newRegister->insts)
+ {
+ if (areInstsPreferredToBeCoalesced(existingInst, inst))
+ preferredCountNewReg++;
+ }
+ return preferredCountNewReg > preferredCountExistingReg;
+ }
+
+ bool canCoalesce(IRInst* inst1, IRInst* inst2)
+ {
+ // If two insts are Phis from the same block, don't coalesce.
+ // This logic should not be needed in most cases because params from the same block should
+ // always interfere anyways. However if a param is never used for for
+ // some reason not DCE'd out, we don't want it to share the same register as another
+ // param to avoid problems during phi elimination.
+ if (inst1->getParent() == inst2->getParent() && inst1->getOp() == kIROp_Param &&
+ inst2->getOp() == kIROp_Param)
+ return false;
+
+ // If two insts are coming from two separate user defined names, don't coalesce them into
+ // the same register.
+ auto name1 = inst1->findDecoration<IRNameHintDecoration>();
+ auto name2 = inst2->findDecoration<IRNameHintDecoration>();
+
+ if (name1 && !name2 || !name1 && name2)
+ return false;
+
+ if (!name1 || !name2)
+ return true;
+ if (name1->getName() != name2->getName())
+ return false;
+
+ return true;
+ }
+
+ RegisterAllocationResult allocateRegisters(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom)
+ {
+ ReachabilityContext reachabilityContext;
+ mapTypeToRegisterList.Clear();
+
+ auto dom = computeDominatorTree(func);
+ inOutDom = dom;
+
+ // Note that if inst A does not dominate inst B, then A can't be alive at B.
+ // Therefore we only need to test interference against insts that dominates the
+ // current inst.
+ //
+ // We can visit the dominance tree in pre-order and assign insts to registers.
+ // This order allows us to easily track what is dominating the current inst.
+
+ // We track the insts dominating the current location in a stack.
+ List<IRInst*> dominatingInsts;
+ HashSet<IRInst*> dominatingInstSet;
+
+ struct WorkStackItem
+ {
+ IRBlock* block;
+ Index dominatingInstCount;
+ WorkStackItem() = default;
+ WorkStackItem(IRBlock* inBlock, Index inDominatingInstCount)
+ {
+ block = inBlock;
+ dominatingInstCount = inDominatingInstCount;
+ }
+ };
+ List<WorkStackItem> workStack;
+ workStack.add(WorkStackItem(func->getFirstBlock(), 0));
+
+ while (workStack.getCount())
+ {
+ auto item = workStack.getLast();
+ workStack.removeLast();
+
+ // Pop dominatingInst stack to correct location.
+ for (Index i = item.dominatingInstCount; i < dominatingInsts.getCount(); i++)
+ dominatingInstSet.Remove(dominatingInsts[i]);
+ dominatingInsts.setCount(item.dominatingInstCount);
+
+ for (auto inst : item.block->getChildren())
+ {
+ if (!instNeedsProcessing(func, inst))
+ continue;
+ // This is an inst we need to allocate register for.
+ // Find register list for this type.
+ auto& registers = getRegisterListForType(inst->getFullType());
+ RegisterInfo* allocatedReg = nullptr;
+ for (auto reg : registers)
+ {
+ // Can we assign inst to this reg?
+ // We answer this by checking if any insts already assigned
+ // to this register is alive. If none are alive we can assign
+ // the register.
+ bool hasInterference = false;
+ for (auto existingInst : reg->insts)
+ {
+ // If `existingInst` does not dominate `inst`, it
+ // can't be alive here and during the entire life-time of the `inst`.
+ // This means that `inst` and `existingInst` won't interfere.
+ if (!dominatingInstSet.Contains(existingInst))
+ continue;
+
+ // If `existingInst` does dominate `inst`, we need to check all
+ // its use sites U to see if there is a path from `inst` to U.
+ // The idea is that is `existingInst` is never used anywhere after
+ // `inst`, then its lifetime ended before `inst` is defined, so it
+ // is still fine to place them in the same register.
+ for (auto use = existingInst->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser() == inst)
+ continue;
+
+ if (!canCoalesce(existingInst, inst) ||
+ reachabilityContext.isInstReachable(inst, use->getUser()))
+ {
+ hasInterference = true;
+ goto endRegInstCheck;
+ }
+ }
+ }
+ endRegInstCheck:;
+ if (!hasInterference)
+ {
+ if (!allocatedReg || isRegisterPreferred(allocatedReg, reg, inst))
+ {
+ allocatedReg = reg;
+ }
+ }
+ }
+ if (!allocatedReg)
+ {
+ assignInstToNewRegister(registers, inst);
+ }
+ else
+ {
+ allocatedReg->insts.add(inst);
+ }
+ dominatingInsts.add(inst);
+ dominatingInstSet.Add(inst);
+ }
+
+ // Recursively visit idom children.
+ for (auto idomChild : dom->getImmediatelyDominatedBlocks(item.block))
+ {
+ workStack.add(WorkStackItem(idomChild, dominatingInsts.getCount()));
+ }
+ }
+
+ RegisterAllocationResult result;
+ result.mapTypeToRegisterList = _Move(mapTypeToRegisterList);
+ for (auto& regList : result.mapTypeToRegisterList)
+ {
+ for (auto reg : regList.Value)
+ {
+ for (auto inst : reg->insts)
+ {
+ result.mapInstToRegister[inst] = reg;
+ }
+ }
+ }
+ return result;
+ }
+ bool instNeedsProcessing(IRGlobalValueWithCode* func, IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Param:
+ if (inst->getParent() == func->getFirstBlock())
+ return false;
+ return true;
+ case kIROp_UpdateElement:
+ return true;
+ default:
+ return false;
+ }
+ }
+ bool needProcessing(IRGlobalValueWithCode* func)
+ {
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (instNeedsProcessing(func, inst))
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom)
+{
+ RegisterAllocateContext context;
+ if (context.needProcessing(func))
+ return context.allocateRegisters(func, inOutDom);
+ return RegisterAllocationResult();
+}
+
+}
diff --git a/source/slang/slang-ir-ssa-register-allocate.h b/source/slang/slang-ir-ssa-register-allocate.h
new file mode 100644
index 000000000..1e8c586cd
--- /dev/null
+++ b/source/slang/slang-ir-ssa-register-allocate.h
@@ -0,0 +1,24 @@
+// slang-ir-ssa-register-allocate.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct IRDominatorTree;
+
+struct RegisterInfo : RefObject
+{
+ IRType* type;
+ List<IRInst*> insts;
+};
+
+struct RegisterAllocationResult
+{
+ OrderedDictionary<IRType*, List<RefPtr<RegisterInfo>>> mapTypeToRegisterList;
+ Dictionary<IRInst*, RefPtr<RegisterInfo>> mapInstToRegister;
+};
+
+RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom);
+
+}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 0434ff682..845232ae6 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4378,7 +4378,7 @@ namespace Slang
type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
}
SLANG_RELEASE_ASSERT(type);
- auto inst = createInst<IRFieldAddress>(
+ auto inst = createInst<IRGetElement>(
this,
kIROp_GetElement,
type,
@@ -4435,6 +4435,67 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitElementAddress(
+ IRInst* basePtr,
+ IRInst* index)
+ {
+ IRType* type = nullptr;
+ auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType());
+ if (auto arrayType = as<IRArrayType>(basePtrType->getValueType()))
+ {
+ type = arrayType->getElementType();
+ }
+ else if (auto vectorType = as<IRVectorType>(basePtrType->getValueType()))
+ {
+ type = vectorType->getElementType();
+ }
+ else if (auto matrixType = as<IRMatrixType>(basePtrType->getValueType()))
+ {
+ type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
+ }
+ SLANG_RELEASE_ASSERT(type);
+ auto inst = createInst<IRGetElementPtr>(
+ this,
+ kIROp_GetElementPtr,
+ getPtrType(type),
+ basePtr,
+ index);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitElementAddress(
+ IRInst* basePtr,
+ const ArrayView<IRInst*>& accessChain)
+ {
+ for (auto access : accessChain)
+ {
+ auto basePtrType = cast<IRPtrTypeBase>(basePtr->getDataType());
+ IRType* resultType = nullptr;
+ if (auto structKey = as<IRStructKey>(access))
+ {
+ auto structType = as<IRStructType>(basePtrType->getValueType());
+ SLANG_RELEASE_ASSERT(structType);
+ for (auto field : structType->getFields())
+ {
+ if (field->getKey() == structKey)
+ {
+ resultType = field->getFieldType();
+ break;
+ }
+ }
+ SLANG_RELEASE_ASSERT(resultType);
+ basePtr = emitFieldAddress(getPtrType(resultType), basePtr, structKey);
+ }
+ else
+ {
+ basePtr = emitElementAddress(basePtr, access);
+ }
+ }
+ return basePtr;
+ }
+
IRInst* IRBuilder::emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement)
{
auto inst = createInst<IRUpdateElement>(
diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected
index 58f562d86..dac9ff1bd 100644
--- a/tests/experimental/liveness/liveness-3.slang.expected
+++ b/tests/experimental/liveness/liveness-3.slang.expected
@@ -52,13 +52,13 @@ int calcThing_0(int offset_0)
livenessStart_2(idx_0, 0);
const int _S3[3] = { 0, 0, 0 };
idx_0 = _S3;
- int _S4 = _S2;
int i_0;
- int _S5;
+ int _S4;
+ int _S5 = _S2;
livenessStart_1(i_0, 0);
i_0 = 0;
- livenessStart_1(_S5, 0);
- _S5 = _S4;
+ livenessStart_1(_S4, 0);
+ _S4 = _S5;
for(;;)
{
if(i_0 < 17)
@@ -74,32 +74,32 @@ int calcThing_0(int offset_0)
int _S7;
if(_S6 != 0)
{
- int _S8 = _S5;
- livenessEnd_0(_S5, 0);
+ int _S8 = _S4;
+ livenessEnd_0(_S4, 0);
int _S9 = _S8 + 1;
livenessStart_1(_S7, 0);
_S7 = _S9;
}
else
{
- int _S10 = _S5;
- livenessEnd_0(_S5, 0);
+ int _S10 = _S4;
+ livenessEnd_0(_S4, 0);
livenessStart_1(_S7, 0);
_S7 = _S10;
}
idx_0[modRange_0] = idx_0[modRange_0] + (_S7 + i_0);
i_0 = i_0 + 1;
- livenessStart_1(_S5, 0);
+ livenessStart_1(_S4, 0);
int _S11 = _S7;
livenessEnd_0(_S7, 0);
- _S5 = _S11;
+ _S4 = _S11;
}
livenessEnd_0(i_0, 0);
livenessEnd_0(_S2, 0);
int _S12 = (k_0 + 7) % 5;
if(_S12 == 4)
{
- livenessEnd_0(_S5, 0);
+ livenessEnd_0(_S4, 0);
livenessEnd_1(idx_0, 0);
livenessEnd_0(k_0, 0);
livenessEnd_2(another_0, 0);
@@ -114,8 +114,8 @@ int calcThing_0(int offset_0)
int total_1 = _S16 + _S15;
k_0 = k_0 + 1;
livenessStart_1(_S2, 0);
- int _S17 = _S5;
- livenessEnd_0(_S5, 0);
+ int _S17 = _S4;
+ livenessEnd_0(_S4, 0);
_S2 = _S17;
livenessStart_1(total_0, 0);
total_0 = total_1;
diff --git a/tests/experimental/liveness/liveness-5.slang.expected b/tests/experimental/liveness/liveness-5.slang.expected
index ea6e37036..a8a9707d7 100644
--- a/tests/experimental/liveness/liveness-5.slang.expected
+++ b/tests/experimental/liveness/liveness-5.slang.expected
@@ -72,35 +72,26 @@ int calcThing_0(int offset_0)
}
livenessEnd_0(k_0, 0);
livenessEnd_1(another_0, 0);
- int total_2;
if(total_0 > 4)
{
- int _S5 = total_0;
- livenessEnd_0(total_0, 0);
- int _S6 = - _S5;
- livenessStart_1(total_2, 0);
- total_2 = _S6;
+ total_0 = - total_0;
}
else
{
- int _S7 = total_0;
- livenessEnd_0(total_0, 0);
- livenessStart_1(total_2, 0);
- total_2 = _S7;
}
- return total_2;
+ return total_0;
}
-layout(std430, binding = 0) buffer _S8 {
+layout(std430, binding = 0) buffer _S5 {
int _data[];
} outputBuffer_0;
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
void main()
{
int index_0 = int(gl_GlobalInvocationID.x);
- uint _S9 = uint(index_0);
- int _S10 = calcThing_0(index_0);
- ((outputBuffer_0)._data[(_S9)]) = _S10;
+ uint _S6 = uint(index_0);
+ int _S7 = calcThing_0(index_0);
+ ((outputBuffer_0)._data[(_S6)]) = _S7;
return;
}
diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected
index 26a537330..402e19886 100644
--- a/tests/experimental/liveness/liveness-6.slang.expected
+++ b/tests/experimental/liveness/liveness-6.slang.expected
@@ -81,35 +81,26 @@ int calcThing_0(int offset_0)
}
livenessEnd_0(k_0, 0);
livenessEnd_1(another_0, 0);
- int total_3;
if(total_0 > 4)
{
- int _S8 = total_0;
- livenessEnd_0(total_0, 0);
- int _S9 = - _S8;
- livenessStart_1(total_3, 0);
- total_3 = _S9;
+ total_0 = - total_0;
}
else
{
- int _S10 = total_0;
- livenessEnd_0(total_0, 0);
- livenessStart_1(total_3, 0);
- total_3 = _S10;
}
- return total_3;
+ return total_0;
}
-layout(std430, binding = 0) buffer _S11 {
+layout(std430, binding = 0) buffer _S8 {
int _data[];
} outputBuffer_0;
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
void main()
{
int index_0 = int(gl_GlobalInvocationID.x);
- uint _S12 = uint(index_0);
- int _S13 = calcThing_0(index_0);
- ((outputBuffer_0)._data[(_S12)]) = _S13;
+ uint _S9 = uint(index_0);
+ int _S10 = calcThing_0(index_0);
+ ((outputBuffer_0)._data[(_S9)]) = _S10;
return;
}
diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected
index 15221b921..847eab926 100644
--- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected
+++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang.1.expected
@@ -65,16 +65,14 @@ uint calcValue_0(hitObjectNV hit_0)
else
{
bool _S7 = (hitObjectIsMissNV((hit_0)));
- uint r_3;
if(_S7)
{
- r_3 = 1U;
+ r_0 = 1U;
}
else
{
- r_3 = 0U;
+ r_0 = 0U;
}
- r_0 = r_3;
}
return r_0;
}
@@ -96,13 +94,13 @@ void main()
RayDesc_0 _S11 = ray_2;
hitObjectNV hitObj_0;
hitObjectRecordHitWithIndexNV(hitObj_0, scene_0, int(uint(idx_0)), int(uint(idx_0 * 2)), int(uint(idx_0 * 3)), 0U, 0U, _S11.Origin_0, _S11.TMin_0, _S11.Direction_0, _S11.TMax_0, (0));
- uint r_4 = calcValue_0(hitObj_0);
+ uint r_3 = calcValue_0(hitObj_0);
RayDesc_0 _S12 = ray_2;
hitObjectNV hitObj_1;
hitObjectRecordHitNV(hitObj_1, scene_0, int(uint(idx_0)), int(uint(idx_0 * 3)), int(uint(idx_0 * 2)), 0U, 0U, 4U, _S12.Origin_0, _S12.TMin_0, _S12.Direction_0, _S12.TMax_0, (0));
uint _S13 = calcValue_0(hitObj_1);
- uint r_5 = r_4 + _S13;
- ((outputBuffer_0)._data[(uint(idx_0))]) = r_5;
+ uint r_4 = r_3 + _S13;
+ ((outputBuffer_0)._data[(uint(idx_0))]) = r_4;
return;
}
diff --git a/tests/ir/ssa-reg-alloc.slang b/tests/ir/ssa-reg-alloc.slang
new file mode 100644
index 000000000..3bfd795a8
--- /dev/null
+++ b/tests/ir/ssa-reg-alloc.slang
@@ -0,0 +1,68 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+int test1(uint p)
+{
+ int a, b;
+ if (p > 1)
+ {
+ a = 1;
+ b = 2;
+ }
+ else
+ {
+ a = 2;
+ b = 3;
+ }
+ // b is not used and should not interfere the result of a.
+ return a;
+}
+
+int test2(uint p)
+{
+ int a, b;
+ if (p > 1)
+ {
+ a = 1;
+ b = 2;
+ }
+ else
+ {
+ a = 2;
+ b = 3;
+ }
+ // a is not used and should not interfere the result of b.
+ return b;
+}
+
+int test3(uint p)
+{
+ int a = 1;
+ int b = 5;
+
+ if (p > 0) a = 2;
+ if (p > 0) b = 3;
+
+ // a and b are now register allocated.
+ // The first block of the loop will have IRParams in the form of (a, b)
+ for (int i = 0; i <= p + 2; i++)
+ {
+ let tmp = a;
+ a = b;
+ b = tmp;
+ // The branch back to the loop header will have phi args: (b, a)
+ // Phi-elmination must handle this case of concurrent assignment correctly.
+ }
+ return a - b; // should be 4 when p == 0.
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ let rs1 = test1(dispatchThreadID.x) + test2(dispatchThreadID.x);
+ outputBuffer[0] = rs1;
+ outputBuffer[1] = test3(0);
+}
diff --git a/tests/ir/ssa-reg-alloc.slang.expected.txt b/tests/ir/ssa-reg-alloc.slang.expected.txt
new file mode 100644
index 000000000..e43a2b945
--- /dev/null
+++ b/tests/ir/ssa-reg-alloc.slang.expected.txt
@@ -0,0 +1,4 @@
+5
+4
+0
+0
diff --git a/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl b/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl
index f9e94bb30..389dae05a 100644
--- a/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl
+++ b/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl
@@ -106,14 +106,10 @@ void main()
tHit_1 = 0.00000000000000000000;
bool _S6 = myProceduralIntersection_0(tHit_1, candidateProceduralAttrs_0);
- MyProceduralHitAttrs_0 committedProceduralAttrs_2;
-
if(_S6)
{
bool _S7 = myProceduralAnyHit_0(payload_5);
- MyProceduralHitAttrs_0 committedProceduralAttrs_3;
-
if(_S7)
{
rayQueryGenerateIntersectionEXT(query_0, tHit_1);
@@ -126,28 +122,24 @@ void main()
{
}
- committedProceduralAttrs_3 = _S8;
+ committedProceduralAttrs_1 = _S8;
}
else
{
- committedProceduralAttrs_3 = committedProceduralAttrs_0;
+ committedProceduralAttrs_1 = committedProceduralAttrs_0;
}
- committedProceduralAttrs_2 = committedProceduralAttrs_3;
-
}
else
{
- committedProceduralAttrs_2 = committedProceduralAttrs_0;
+ committedProceduralAttrs_1 = committedProceduralAttrs_0;
}
- committedProceduralAttrs_1 = committedProceduralAttrs_2;
-
break;
}
case 0U: