summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-28 23:28:23 -0700
committerGitHub <noreply@github.com>2023-04-28 23:28:23 -0700
commitc571bcb025009f9c662e8d631fa49dbfed560287 (patch)
tree3ade836c28920210b3dc1af5e8447d4804dc03ad
parent5adecbe837d27cf4e0a554ae13a0338743a8cb4b (diff)
SSA Register Allocation improvements. (#2857)
* SSA Register Allocation improvements. * Fix. * Rename `Use`->`UseOrPseudoUse`. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-emit-spirv.cpp18
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp58
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h32
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp15
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp28
-rw-r--r--source/slang/slang-ir-ssa-register-allocate.cpp62
-rw-r--r--source/slang/slang-ir-util.cpp15
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--tests/experimental/liveness/liveness-3.slang.expected38
-rw-r--r--tests/ir/loop-phi-coalesce.slang49
10 files changed, 231 insertions, 87 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index fc2505d63..b17ccf483 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1473,7 +1473,20 @@ struct SPIRVEmitContext
//
for( auto irBlock : irFunc->getBlocks() )
{
- emitInst(spvFunc, irBlock, SpvOpLabel, kResultID);
+ auto spvBlock = emitInst(spvFunc, irBlock, SpvOpLabel, kResultID);
+ if (irBlock == irFunc->getFirstBlock())
+ {
+ // OpVariable
+ // All variables used in the function must be declared before anything else.
+ for (auto block : irFunc->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (as<IRVar>(inst))
+ emitLocalInst(spvBlock, inst);
+ }
+ }
+ }
// In addition to normal basic blocks,
// all loops gets a header block.
@@ -1517,6 +1530,9 @@ struct SPIRVEmitContext
// Any instructions local to the block will be emitted as children
// of the block.
//
+ // Skip vars because they are already emitted.
+ if (as<IRVar>(irInst))
+ continue;
emitLocalInst(spvBlock, irInst);
if (irInst->getOp() == kIROp_loop)
pendingLoopInsts.add(as<IRLoop>(irInst));
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 202660682..135c72556 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -266,9 +266,8 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
- List<IRUse*> workList;
- HashSet<IRUse*> processedUses;
-
+ List<UseOrPseudoUse> workList;
+ HashSet<UseOrPseudoUse> processedUses;
HashSet<IRUse*> usesToReplace;
auto addPrimalOperandsToWorkList = [&](IRInst* inst)
@@ -358,8 +357,8 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
SLANG_ASSERT(!checkpointInfo->storeSet.contains(result.instToRecompute));
checkpointInfo->recomputeSet.add(result.instToRecompute);
- if (isDifferentialInst(use->getUser()))
- usesToReplace.add(use);
+ if (isDifferentialInst(use.user) && use.irUse)
+ usesToReplace.add(use.irUse);
if (auto param = as<IRParam>(result.instToRecompute))
{
@@ -392,7 +391,13 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
{
IRUse* storeUse = findLatestUniqueWriteUse(var);
if (storeUse)
- workList.add(storeUse);
+ {
+ // When we have a var and a store/call insts that writes to the var,
+ // we treat as if there is a pseudo-use of the store/call to compute
+ // the var inst, i.e. the var depends on the store/call, despite
+ // the IR's def-use chain doesn't reflect this.
+ workList.add(UseOrPseudoUse(var, storeUse->getUser()));
+ }
}
else
{
@@ -400,29 +405,6 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
}
}
}
- else if (result.mode == HoistResult::Mode::Invert)
- {
- auto instToInvert = result.inversionInfo.instToInvert;
-
- SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser()));
- SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser()));
-
- if (isDifferentialInst(use->getUser()))
- usesToReplace.add(use);
-
- checkpointInfo->invertSet.add(instToInvert);
-
- if (checkpointInfo->invInfoMap.containsKey(instToInvert))
- {
- List<IRInst*> currOperands = checkpointInfo->invInfoMap[instToInvert].getValue().requiredOperands;
- for (Index ii = 0; ii < result.inversionInfo.requiredOperands.getCount(); ii++)
- {
- SLANG_RELEASE_ASSERT(result.inversionInfo.requiredOperands[ii] == currOperands[ii]);
- }
- }
- else
- checkpointInfo->invInfoMap[instToInvert] = result.inversionInfo;
- }
}
// If a var or call is in recomputeSet, move any var/calls associated with the same call to
@@ -1480,9 +1462,9 @@ static bool shouldStoreVar(IRVar* var)
return false;
}
-bool canRecompute(IRUse* use)
+bool canRecompute(UseOrPseudoUse use)
{
- if (auto load = as<IRLoad>(use->get()))
+ if (auto load = as<IRLoad>(use.usedVal))
{
// Generally, we cannot recompute a load(ptr), since ptr may be modified
// afterwards.
@@ -1509,7 +1491,7 @@ bool canRecompute(IRUse* use)
}
return false;
}
- auto param = as<IRParam>(use->get());
+ auto param = as<IRParam>(use.usedVal);
if (!param)
return true;
@@ -1526,12 +1508,12 @@ bool canRecompute(IRUse* use)
return true;
}
-HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
+HoistResult DefaultCheckpointPolicy::classify(UseOrPseudoUse use)
{
// Store all that we can.. by default, classify will only be called on relevant differential
// uses (or on uses in a 'recompute' inst)
//
- if (auto var = as<IRVar>(use->get()))
+ if (auto var = as<IRVar>(use.usedVal))
{
if (shouldStoreVar(var))
return HoistResult::store(var);
@@ -1540,19 +1522,19 @@ HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
}
else
{
- if (shouldStoreInst(use->get()))
+ if (shouldStoreInst(use.usedVal))
{
- return HoistResult::store(use->get());
+ return HoistResult::store(use.usedVal);
}
else
{
// We may not be able to recompute due to limitations of
// the unzip pass. If so we will store the result.
if (canRecompute(use))
- return HoistResult::recompute(use->get());
+ return HoistResult::recompute(use.usedVal);
// The fallback is to store.
- return HoistResult::store(use->get());
+ return HoistResult::store(use.usedVal);
}
}
}
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index c0b56126d..e9fc0d4a5 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -236,6 +236,34 @@ namespace Slang
Dictionary<IRInst*, InversionInfo> invInfoMap;
};
+ struct UseOrPseudoUse
+ {
+ IRUse* irUse = nullptr;
+ IRInst* user;
+ IRInst* usedVal;
+ UseOrPseudoUse() = default;
+ UseOrPseudoUse(IRUse* use)
+ {
+ user = use->getUser();
+ usedVal = use->get();
+ irUse = use;
+ }
+ UseOrPseudoUse(IRInst* inUser, IRInst* inUsedVal)
+ {
+ irUse = nullptr;
+ user = inUser;
+ usedVal = inUsedVal;;
+ }
+ HashCode getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(user), Slang::getHashCode(usedVal));
+ }
+ bool operator==(const UseOrPseudoUse& other) const
+ {
+ return user == other.user && usedVal == other.usedVal;
+ }
+ };
+
// Information on a block after it has been split in the unzip step.
// After unzipping, every block in the original function will have
// two corresponding blocks in the new function:
@@ -269,7 +297,7 @@ namespace Slang
//
virtual void preparePolicy(IRGlobalValueWithCode* func) = 0;
- virtual HoistResult classify(IRUse* diffBlockUse) = 0;
+ virtual HoistResult classify(UseOrPseudoUse diffBlockUse) = 0;
protected:
@@ -285,7 +313,7 @@ namespace Slang
{ }
virtual void preparePolicy(IRGlobalValueWithCode* func);
- virtual HoistResult classify(IRUse* use);
+ virtual HoistResult classify(UseOrPseudoUse use);
};
RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func);
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 1d6d2d039..d81a33719 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -134,21 +134,6 @@ public:
return false;
}
- int getParamIndexInBlock(IRParam* paramInst)
- {
- auto block = as<IRBlock>(paramInst->getParent());
- if (!block)
- return -1;
- int paramIndex = 0;
- for (auto param : block->getParams())
- {
- if (param == paramInst)
- return paramIndex;
- paramIndex++;
- }
- return -1;
- }
-
bool isInstInFunc(IRInst* inst, IRInst* func)
{
while (inst)
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 32c6abd39..37e8ba5bb 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -345,6 +345,34 @@ bool tryRemoveRedundantStore(IRGlobalValueWithCode* func, IRStore* store)
store->removeAndDeallocate();
return true;
}
+
+ // A store can be removed if it is a store into the same var, and there are
+ // no side effects between the load of the var and the store of the var.
+ if (auto load = as<IRLoad>(store->getVal()))
+ {
+ if (load->getPtr() == store->getPtr())
+ {
+ if (load->getParent() == store->getParent())
+ {
+ bool valueMayChange = false;
+ for (auto inst = load->next; inst; inst = inst->next)
+ {
+ if (inst == store)
+ break;
+ if (canInstHaveSideEffectAtAddress(func, inst, store->getPtr()))
+ {
+ valueMayChange = true;
+ break;
+ }
+ }
+ if (!valueMayChange)
+ {
+ store->removeAndDeallocate();
+ return true;
+ }
+ }
+ }
+ }
return false;
}
diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp
index 07eec0c2b..b1d375fcf 100644
--- a/source/slang/slang-ir-ssa-register-allocate.cpp
+++ b/source/slang/slang-ir-ssa-register-allocate.cpp
@@ -5,7 +5,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-dominators.h"
-
+#include "slang-ir-util.h"
namespace Slang
{
@@ -91,7 +91,7 @@ struct RegisterAllocateContext
auto name2 = inst2->findDecoration<IRNameHintDecoration>();
if (name1 && !name2 || !name1 && name2)
- return false;
+ return true;
if (!name1 || !name2)
return true;
@@ -101,6 +101,45 @@ struct RegisterAllocateContext
return true;
}
+ bool isUseOfParamAfterPhiAssignment(IRDominatorTree* dom, IRUse* useToTest, IRInst* phiParam, IRInst* phiArg)
+ {
+ IRParam* param = as<IRParam>(phiParam);
+ if (!param)
+ return false;
+ IRUse* branchUse = nullptr;
+ for (auto use = phiArg->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getOp() == kIROp_unconditionalBranch)
+ {
+ if (!branchUse)
+ branchUse = use;
+ else
+ {
+ // If arg is being used in more than one branch, don't handle it.
+ return false;
+ }
+ }
+ }
+ if (!branchUse)
+ return false;
+ auto branch = as<IRUnconditionalBranch>(branchUse->getUser());
+ auto branchTarget = branch->getTargetBlock();
+
+ if (param->getParent() != branchTarget)
+ return false;
+ auto paramIndex = getParamIndexInBlock(param);
+ if (paramIndex >= (int)branch->getArgCount() || paramIndex == -1)
+ return false;
+ if (branch->getArg(paramIndex) != phiArg)
+ return false;
+
+ // If we reach here, then phiArg is indeed used as arg for phiParam.
+ // We will allow any use of phiParam when phiArg isn't live.
+ if (dom->dominates(phiArg, useToTest->getUser()))
+ return false;
+ return true;
+ }
+
RegisterAllocationResult allocateRegisters(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom)
{
ReachabilityContext reachabilityContext;
@@ -167,11 +206,12 @@ struct RegisterAllocateContext
if (!dominatingInstSet.contains(existingInst))
continue;
- // If `existingInst` does dominate `inst`, we need to check all
- // its use sites U to see if there is a path from `inst` to U.
- // The idea is that is `existingInst` is never used anywhere after
- // `inst`, then its lifetime ended before `inst` is defined, so it
- // is still fine to place them in the same register.
+ // In the general case, we need to check all its use
+ // sites U to see if there is a path from `inst` to U.
+ // The idea is that is `existingInst` is never used
+ // anywhere after `inst`, then its lifetime ended before
+ // `inst` is defined, so it is still fine to place them
+ // in the same register.
for (auto use = existingInst->firstUse; use; use = use->nextUse)
{
if (use->getUser() == inst)
@@ -180,6 +220,14 @@ struct RegisterAllocateContext
if (!canCoalesce(existingInst, inst) ||
reachabilityContext.isInstReachable(inst, use->getUser()))
{
+ // We found a use of `existingInst` (U) where
+ // there is a path from `inst` to U.
+ // Generally this means that existingInst and inst interfere.
+ // However, an exception is that existingInst is a PhiParam,
+ // and inst is an arg to that param, and use happens after
+ // the phi assignment.
+ if (isUseOfParamAfterPhiAssignment(dom, use, existingInst, inst))
+ continue;
hasInterference = true;
goto endRegInstCheck;
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index a69e13562..95240a26c 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -759,6 +759,21 @@ void removePhiArgs(IRInst* phiParam)
}
}
+int getParamIndexInBlock(IRParam* paramInst)
+{
+ auto block = as<IRBlock>(paramInst->getParent());
+ if (!block)
+ return -1;
+ int paramIndex = 0;
+ for (auto param : block->getParams())
+ {
+ if (param == paramInst)
+ return paramIndex;
+ paramIndex++;
+ }
+ return -1;
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 075788520..492a9f312 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -202,6 +202,9 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key);
void moveParams(IRBlock* dest, IRBlock* src);
void removePhiArgs(IRInst* phiParam);
+
+int getParamIndexInBlock(IRParam* paramInst);
+
}
#endif
diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected
index 78e80d7b5..c2191e9b9 100644
--- a/tests/experimental/liveness/liveness-3.slang.expected
+++ b/tests/experimental/liveness/liveness-3.slang.expected
@@ -78,28 +78,18 @@ int calcThing_0(int offset_0)
}
int modRange_0 = i_0 % 3;
another_0[i_0 & 1] = another_0[i_0 & 1] + modRange_0;
- int _S5;
if(modRange_0 != 0)
{
- int _S6 = _S4;
- livenessEnd_0(_S4, 0);
- int _S7 = _S6 + 1;
- livenessStart_1(_S5, 0);
- _S5 = _S7;
+ _S4 = _S4 + 1;
}
else
{
- int _S8 = _S4;
- livenessEnd_0(_S4, 0);
- livenessStart_1(_S5, 0);
- _S5 = _S8;
}
+ int _S5 = _S4;
+ livenessEnd_0(_S4, 0);
idx_0[modRange_0] = idx_0[modRange_0] + (_S5 + i_0);
i_0 = i_0 + 1;
livenessStart_1(_S4, 0);
- int _S9 = _S5;
- livenessEnd_0(_S5, 0);
- _S4 = _S9;
}
livenessEnd_0(_S2, 0);
livenessEnd_0(k_0, 0);
@@ -110,34 +100,34 @@ int calcThing_0(int offset_0)
livenessEnd_2(another_0, 0);
return total_0;
}
- int _S10 = idx_0[0] + idx_0[1];
- int _S11 = idx_0[2];
+ int _S6 = idx_0[0] + idx_0[1];
+ int _S7 = idx_0[2];
livenessEnd_1(idx_0, 0);
- int _S12 = _S10 + _S11;
- int _S13 = total_0;
+ int _S8 = _S6 + _S7;
+ int _S9 = total_0;
livenessEnd_0(total_0, 0);
- int total_1 = _S13 + _S12;
+ int total_1 = _S9 + _S8;
livenessStart_1(k_0, 0);
k_0 = k_1;
livenessStart_1(_S2, 0);
- int _S14 = _S4;
+ int _S10 = _S4;
livenessEnd_0(_S4, 0);
- _S2 = _S14;
+ _S2 = _S10;
livenessStart_1(total_0, 0);
total_0 = total_1;
}
livenessEnd_2(another_0, 0);
- int _S15 = total_0;
+ int _S11 = total_0;
livenessEnd_0(total_0, 0);
- return - _S15;
+ return - _S11;
}
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
void main()
{
int index_0 = int(gl_GlobalInvocationID.x);
- int _S16 = calcThing_0(index_0);
- ((outputBuffer_0)._data[(uint(index_0))]) = _S16;
+ int _S12 = calcThing_0(index_0);
+ ((outputBuffer_0)._data[(uint(index_0))]) = _S12;
return;
}
diff --git a/tests/ir/loop-phi-coalesce.slang b/tests/ir/loop-phi-coalesce.slang
new file mode 100644
index 000000000..2f1aba472
--- /dev/null
+++ b/tests/ir/loop-phi-coalesce.slang
@@ -0,0 +1,49 @@
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+
+
+RWStructuredBuffer<float> outputBuffer;
+
+int test1()
+{
+ float t = 0;
+ for (int i = 0; i < 5; i++)
+ {
+ if (i < 3)
+ t = t + 1;
+ else
+ t = t + 2;
+ // we should coalesce the phi after the `if` the and phi of the `for` loop.
+ }
+ outputBuffer[0] = t;
+ return 0;
+}
+// CHECK: int test1{{[_0-9]*}}()
+// CHECK-NOT: float t_1
+// CHECK: return
+
+int test2()
+{
+ float v = 0;
+ for (int i = 0; i < 5; i++)
+ {
+ float ov = v;
+ if (i < 3)
+ v = v + 1;
+ else
+ v = v + 2;
+ // use of ot here means we can't coalesce the phis of the `if` and the `for` loop.
+ outputBuffer[1] = ov;
+ }
+ outputBuffer[0] = v;
+ return 0;
+}
+// CHECK: int test2{{[_0-9]*}}()
+// CHECK: float v_1
+// CHECK: return
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ test1();
+ test2();
+}