diff options
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 58 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 32 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-register-allocate.cpp | 62 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 3 | ||||
| -rw-r--r-- | tests/experimental/liveness/liveness-3.slang.expected | 38 | ||||
| -rw-r--r-- | tests/ir/loop-phi-coalesce.slang | 49 |
10 files changed, 231 insertions, 87 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index fc2505d63..b17ccf483 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1473,7 +1473,20 @@ struct SPIRVEmitContext // for( auto irBlock : irFunc->getBlocks() ) { - emitInst(spvFunc, irBlock, SpvOpLabel, kResultID); + auto spvBlock = emitInst(spvFunc, irBlock, SpvOpLabel, kResultID); + if (irBlock == irFunc->getFirstBlock()) + { + // OpVariable + // All variables used in the function must be declared before anything else. + for (auto block : irFunc->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (as<IRVar>(inst)) + emitLocalInst(spvBlock, inst); + } + } + } // In addition to normal basic blocks, // all loops gets a header block. @@ -1517,6 +1530,9 @@ struct SPIRVEmitContext // Any instructions local to the block will be emitted as children // of the block. // + // Skip vars because they are already emitted. + if (as<IRVar>(irInst)) + continue; emitLocalInst(spvBlock, irInst); if (irInst->getOp() == kIROp_loop) pendingLoopInsts.add(as<IRLoop>(irInst)); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 202660682..135c72556 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -266,9 +266,8 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); - List<IRUse*> workList; - HashSet<IRUse*> processedUses; - + List<UseOrPseudoUse> workList; + HashSet<UseOrPseudoUse> processedUses; HashSet<IRUse*> usesToReplace; auto addPrimalOperandsToWorkList = [&](IRInst* inst) @@ -358,8 +357,8 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( SLANG_ASSERT(!checkpointInfo->storeSet.contains(result.instToRecompute)); checkpointInfo->recomputeSet.add(result.instToRecompute); - if (isDifferentialInst(use->getUser())) - usesToReplace.add(use); + if (isDifferentialInst(use.user) && use.irUse) + usesToReplace.add(use.irUse); if (auto param = as<IRParam>(result.instToRecompute)) { @@ -392,7 +391,13 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( { IRUse* storeUse = findLatestUniqueWriteUse(var); if (storeUse) - workList.add(storeUse); + { + // When we have a var and a store/call insts that writes to the var, + // we treat as if there is a pseudo-use of the store/call to compute + // the var inst, i.e. the var depends on the store/call, despite + // the IR's def-use chain doesn't reflect this. + workList.add(UseOrPseudoUse(var, storeUse->getUser())); + } } else { @@ -400,29 +405,6 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( } } } - else if (result.mode == HoistResult::Mode::Invert) - { - auto instToInvert = result.inversionInfo.instToInvert; - - SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser())); - SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser())); - - if (isDifferentialInst(use->getUser())) - usesToReplace.add(use); - - checkpointInfo->invertSet.add(instToInvert); - - if (checkpointInfo->invInfoMap.containsKey(instToInvert)) - { - List<IRInst*> currOperands = checkpointInfo->invInfoMap[instToInvert].getValue().requiredOperands; - for (Index ii = 0; ii < result.inversionInfo.requiredOperands.getCount(); ii++) - { - SLANG_RELEASE_ASSERT(result.inversionInfo.requiredOperands[ii] == currOperands[ii]); - } - } - else - checkpointInfo->invInfoMap[instToInvert] = result.inversionInfo; - } } // If a var or call is in recomputeSet, move any var/calls associated with the same call to @@ -1480,9 +1462,9 @@ static bool shouldStoreVar(IRVar* var) return false; } -bool canRecompute(IRUse* use) +bool canRecompute(UseOrPseudoUse use) { - if (auto load = as<IRLoad>(use->get())) + if (auto load = as<IRLoad>(use.usedVal)) { // Generally, we cannot recompute a load(ptr), since ptr may be modified // afterwards. @@ -1509,7 +1491,7 @@ bool canRecompute(IRUse* use) } return false; } - auto param = as<IRParam>(use->get()); + auto param = as<IRParam>(use.usedVal); if (!param) return true; @@ -1526,12 +1508,12 @@ bool canRecompute(IRUse* use) return true; } -HoistResult DefaultCheckpointPolicy::classify(IRUse* use) +HoistResult DefaultCheckpointPolicy::classify(UseOrPseudoUse use) { // Store all that we can.. by default, classify will only be called on relevant differential // uses (or on uses in a 'recompute' inst) // - if (auto var = as<IRVar>(use->get())) + if (auto var = as<IRVar>(use.usedVal)) { if (shouldStoreVar(var)) return HoistResult::store(var); @@ -1540,19 +1522,19 @@ HoistResult DefaultCheckpointPolicy::classify(IRUse* use) } else { - if (shouldStoreInst(use->get())) + if (shouldStoreInst(use.usedVal)) { - return HoistResult::store(use->get()); + return HoistResult::store(use.usedVal); } else { // We may not be able to recompute due to limitations of // the unzip pass. If so we will store the result. if (canRecompute(use)) - return HoistResult::recompute(use->get()); + return HoistResult::recompute(use.usedVal); // The fallback is to store. - return HoistResult::store(use->get()); + return HoistResult::store(use.usedVal); } } } diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index c0b56126d..e9fc0d4a5 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -236,6 +236,34 @@ namespace Slang Dictionary<IRInst*, InversionInfo> invInfoMap; }; + struct UseOrPseudoUse + { + IRUse* irUse = nullptr; + IRInst* user; + IRInst* usedVal; + UseOrPseudoUse() = default; + UseOrPseudoUse(IRUse* use) + { + user = use->getUser(); + usedVal = use->get(); + irUse = use; + } + UseOrPseudoUse(IRInst* inUser, IRInst* inUsedVal) + { + irUse = nullptr; + user = inUser; + usedVal = inUsedVal;; + } + HashCode getHashCode() const + { + return combineHash(Slang::getHashCode(user), Slang::getHashCode(usedVal)); + } + bool operator==(const UseOrPseudoUse& other) const + { + return user == other.user && usedVal == other.usedVal; + } + }; + // Information on a block after it has been split in the unzip step. // After unzipping, every block in the original function will have // two corresponding blocks in the new function: @@ -269,7 +297,7 @@ namespace Slang // virtual void preparePolicy(IRGlobalValueWithCode* func) = 0; - virtual HoistResult classify(IRUse* diffBlockUse) = 0; + virtual HoistResult classify(UseOrPseudoUse diffBlockUse) = 0; protected: @@ -285,7 +313,7 @@ namespace Slang { } virtual void preparePolicy(IRGlobalValueWithCode* func); - virtual HoistResult classify(IRUse* use); + virtual HoistResult classify(UseOrPseudoUse use); }; RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func); diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 1d6d2d039..d81a33719 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -134,21 +134,6 @@ public: return false; } - int getParamIndexInBlock(IRParam* paramInst) - { - auto block = as<IRBlock>(paramInst->getParent()); - if (!block) - return -1; - int paramIndex = 0; - for (auto param : block->getParams()) - { - if (param == paramInst) - return paramIndex; - paramIndex++; - } - return -1; - } - bool isInstInFunc(IRInst* inst, IRInst* func) { while (inst) diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 32c6abd39..37e8ba5bb 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -345,6 +345,34 @@ bool tryRemoveRedundantStore(IRGlobalValueWithCode* func, IRStore* store) store->removeAndDeallocate(); return true; } + + // A store can be removed if it is a store into the same var, and there are + // no side effects between the load of the var and the store of the var. + if (auto load = as<IRLoad>(store->getVal())) + { + if (load->getPtr() == store->getPtr()) + { + if (load->getParent() == store->getParent()) + { + bool valueMayChange = false; + for (auto inst = load->next; inst; inst = inst->next) + { + if (inst == store) + break; + if (canInstHaveSideEffectAtAddress(func, inst, store->getPtr())) + { + valueMayChange = true; + break; + } + } + if (!valueMayChange) + { + store->removeAndDeallocate(); + return true; + } + } + } + } return false; } diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp index 07eec0c2b..b1d375fcf 100644 --- a/source/slang/slang-ir-ssa-register-allocate.cpp +++ b/source/slang/slang-ir-ssa-register-allocate.cpp @@ -5,7 +5,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-dominators.h" - +#include "slang-ir-util.h" namespace Slang { @@ -91,7 +91,7 @@ struct RegisterAllocateContext auto name2 = inst2->findDecoration<IRNameHintDecoration>(); if (name1 && !name2 || !name1 && name2) - return false; + return true; if (!name1 || !name2) return true; @@ -101,6 +101,45 @@ struct RegisterAllocateContext return true; } + bool isUseOfParamAfterPhiAssignment(IRDominatorTree* dom, IRUse* useToTest, IRInst* phiParam, IRInst* phiArg) + { + IRParam* param = as<IRParam>(phiParam); + if (!param) + return false; + IRUse* branchUse = nullptr; + for (auto use = phiArg->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getOp() == kIROp_unconditionalBranch) + { + if (!branchUse) + branchUse = use; + else + { + // If arg is being used in more than one branch, don't handle it. + return false; + } + } + } + if (!branchUse) + return false; + auto branch = as<IRUnconditionalBranch>(branchUse->getUser()); + auto branchTarget = branch->getTargetBlock(); + + if (param->getParent() != branchTarget) + return false; + auto paramIndex = getParamIndexInBlock(param); + if (paramIndex >= (int)branch->getArgCount() || paramIndex == -1) + return false; + if (branch->getArg(paramIndex) != phiArg) + return false; + + // If we reach here, then phiArg is indeed used as arg for phiParam. + // We will allow any use of phiParam when phiArg isn't live. + if (dom->dominates(phiArg, useToTest->getUser())) + return false; + return true; + } + RegisterAllocationResult allocateRegisters(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom) { ReachabilityContext reachabilityContext; @@ -167,11 +206,12 @@ struct RegisterAllocateContext if (!dominatingInstSet.contains(existingInst)) continue; - // If `existingInst` does dominate `inst`, we need to check all - // its use sites U to see if there is a path from `inst` to U. - // The idea is that is `existingInst` is never used anywhere after - // `inst`, then its lifetime ended before `inst` is defined, so it - // is still fine to place them in the same register. + // In the general case, we need to check all its use + // sites U to see if there is a path from `inst` to U. + // The idea is that is `existingInst` is never used + // anywhere after `inst`, then its lifetime ended before + // `inst` is defined, so it is still fine to place them + // in the same register. for (auto use = existingInst->firstUse; use; use = use->nextUse) { if (use->getUser() == inst) @@ -180,6 +220,14 @@ struct RegisterAllocateContext if (!canCoalesce(existingInst, inst) || reachabilityContext.isInstReachable(inst, use->getUser())) { + // We found a use of `existingInst` (U) where + // there is a path from `inst` to U. + // Generally this means that existingInst and inst interfere. + // However, an exception is that existingInst is a PhiParam, + // and inst is an arg to that param, and use happens after + // the phi assignment. + if (isUseOfParamAfterPhiAssignment(dom, use, existingInst, inst)) + continue; hasInterference = true; goto endRegInstCheck; } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index a69e13562..95240a26c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -759,6 +759,21 @@ void removePhiArgs(IRInst* phiParam) } } +int getParamIndexInBlock(IRParam* paramInst) +{ + auto block = as<IRBlock>(paramInst->getParent()); + if (!block) + return -1; + int paramIndex = 0; + for (auto param : block->getParams()) + { + if (param == paramInst) + return paramIndex; + paramIndex++; + } + return -1; +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 075788520..492a9f312 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -202,6 +202,9 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key); void moveParams(IRBlock* dest, IRBlock* src); void removePhiArgs(IRInst* phiParam); + +int getParamIndexInBlock(IRParam* paramInst); + } #endif diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected index 78e80d7b5..c2191e9b9 100644 --- a/tests/experimental/liveness/liveness-3.slang.expected +++ b/tests/experimental/liveness/liveness-3.slang.expected @@ -78,28 +78,18 @@ int calcThing_0(int offset_0) } int modRange_0 = i_0 % 3; another_0[i_0 & 1] = another_0[i_0 & 1] + modRange_0; - int _S5; if(modRange_0 != 0) { - int _S6 = _S4; - livenessEnd_0(_S4, 0); - int _S7 = _S6 + 1; - livenessStart_1(_S5, 0); - _S5 = _S7; + _S4 = _S4 + 1; } else { - int _S8 = _S4; - livenessEnd_0(_S4, 0); - livenessStart_1(_S5, 0); - _S5 = _S8; } + int _S5 = _S4; + livenessEnd_0(_S4, 0); idx_0[modRange_0] = idx_0[modRange_0] + (_S5 + i_0); i_0 = i_0 + 1; livenessStart_1(_S4, 0); - int _S9 = _S5; - livenessEnd_0(_S5, 0); - _S4 = _S9; } livenessEnd_0(_S2, 0); livenessEnd_0(k_0, 0); @@ -110,34 +100,34 @@ int calcThing_0(int offset_0) livenessEnd_2(another_0, 0); return total_0; } - int _S10 = idx_0[0] + idx_0[1]; - int _S11 = idx_0[2]; + int _S6 = idx_0[0] + idx_0[1]; + int _S7 = idx_0[2]; livenessEnd_1(idx_0, 0); - int _S12 = _S10 + _S11; - int _S13 = total_0; + int _S8 = _S6 + _S7; + int _S9 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S13 + _S12; + int total_1 = _S9 + _S8; livenessStart_1(k_0, 0); k_0 = k_1; livenessStart_1(_S2, 0); - int _S14 = _S4; + int _S10 = _S4; livenessEnd_0(_S4, 0); - _S2 = _S14; + _S2 = _S10; livenessStart_1(total_0, 0); total_0 = total_1; } livenessEnd_2(another_0, 0); - int _S15 = total_0; + int _S11 = total_0; livenessEnd_0(total_0, 0); - return - _S15; + return - _S11; } layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - int _S16 = calcThing_0(index_0); - ((outputBuffer_0)._data[(uint(index_0))]) = _S16; + int _S12 = calcThing_0(index_0); + ((outputBuffer_0)._data[(uint(index_0))]) = _S12; return; } diff --git a/tests/ir/loop-phi-coalesce.slang b/tests/ir/loop-phi-coalesce.slang new file mode 100644 index 000000000..2f1aba472 --- /dev/null +++ b/tests/ir/loop-phi-coalesce.slang @@ -0,0 +1,49 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + + +RWStructuredBuffer<float> outputBuffer; + +int test1() +{ + float t = 0; + for (int i = 0; i < 5; i++) + { + if (i < 3) + t = t + 1; + else + t = t + 2; + // we should coalesce the phi after the `if` the and phi of the `for` loop. + } + outputBuffer[0] = t; + return 0; +} +// CHECK: int test1{{[_0-9]*}}() +// CHECK-NOT: float t_1 +// CHECK: return + +int test2() +{ + float v = 0; + for (int i = 0; i < 5; i++) + { + float ov = v; + if (i < 3) + v = v + 1; + else + v = v + 2; + // use of ot here means we can't coalesce the phis of the `if` and the `for` loop. + outputBuffer[1] = ov; + } + outputBuffer[0] = v; + return 0; +} +// CHECK: int test2{{[_0-9]*}}() +// CHECK: float v_1 +// CHECK: return + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + test1(); + test2(); +} |
