diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-25 12:04:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-25 12:04:31 -0800 |
| commit | f7b9745e46db6a7e55f6e0265493350d65ea4615 (patch) | |
| tree | fb74e013a1c57876c7b94299367c6b9b8343784f | |
| parent | a9f2f8a592c4514cd116c947486055788092ea56 (diff) | |
Fix a bug with hoisting 'IRVar' insts that are used outside the loop (#6446)
* Fix a bug with hoisting 'IRVar' insts that are used outside the loop
- We introduce a 'CheckpointObject' inst and use that to split loop state insts into two pieces (one for within-loop uses and one for outside-loop uses.
- This allows the two kinds of uses to be handled separately by the hoisting mechanism
- CheckpointObject is then lowered to a no-op after hoisting is complete.
* Update slang-ir-autodiff-primal-hoist.cpp
* Update slang-ir-autodiff-primal-hoist.cpp
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 211 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 94 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 7 | ||||
| -rw-r--r-- | tests/autodiff/reverse-continue-loop.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-immediate-return.slang | 59 |
7 files changed, 336 insertions, 57 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index b5ac784ce..c2403e53b 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -281,6 +281,142 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( return recomputeBlockMap; } +// Checks if list A is a subset of list B by comparing their primal count parameters. +// +// Parameters: +// indicesA - First list of IndexTrackingInfo to compare +// indicesB - Second list of IndexTrackingInfo to compare +// +// Returns: +// true if all indices in indicesA are present in indicesB, false otherwise +// +bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB) +{ + if (indicesA.getCount() > indicesB.getCount()) + return false; + + for (Index ii = 0; ii < indicesA.getCount(); ii++) + { + if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam) + return false; + } + + return true; +} + +bool canInstBeStored(IRInst* inst) +{ + // Cannot store insts whose value is a type or a witness table, or a function. + // These insts get lowered to target-specific logic, and cannot be + // stored into variables or context structs as normal values. + // + if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) || + as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) || + !inst->getDataType()) + return false; + + return true; +} + +// This is a helper that converts insts in a loop condition block into two if necessary, +// then replaces all uses 'outside' the loop region with the new insts. This is because +// insts in loop condition blocks can be used in two distinct regions (the loop body, and +// after the loop). +// +// We'll use CheckpointObject for the splitting, which is allowed on any value-typed inst. +// +void splitLoopConditionBlockInsts( + IRGlobalValueWithCode* func, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo) +{ + // RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); + + // Collect primal loop condition blocks, and map differential blocks to their primal blocks. + List<IRBlock*> loopConditionBlocks; + Dictionary<IRBlock*, IRBlock*> diffBlockMap; + for (auto block : func->getBlocks()) + { + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto loopConditionBlock = getLoopConditionBlock(loop); + if (isDifferentialBlock(loopConditionBlock)) + { + auto diffDecor = loopConditionBlock->findDecoration<IRDifferentialInstDecoration>(); + diffBlockMap[cast<IRBlock>(diffDecor->getPrimalInst())] = loopConditionBlock; + } + else + loopConditionBlocks.add(loopConditionBlock); + } + } + + // For each loop condition block, split the insts that are used in both the loop body and + // after the loop. + // Use the dominator tree to find uses of insts outside the loop body + // + // Essentially we want to split the uses dominated by the true block and the false block of the + // condition. + // + IRBuilder builder(func->getModule()); + + + List<IRUse*> loopUses; + List<IRUse*> afterLoopUses; + + for (auto condBlock : loopConditionBlocks) + { + // For each inst in the primal condition block, check if it has uses inside the loop body + // as well as outside of it. (Use the indexedBlockInfo to perform the teets) + // + for (auto inst = condBlock->getFirstInst(); inst; inst = inst->getNextInst()) + { + // Skip terminators and insts that can't be stored + if (as<IRTerminatorInst>(inst) || !canInstBeStored(inst)) + continue; + // Shouldn't see any vars. + SLANG_ASSERT(!as<IRVar>(inst)); + + // Get the indices for the condition block + auto& condBlockIndices = indexedBlockInfo[condBlock]; + + loopUses.clear(); + afterLoopUses.clear(); + + // Check all uses of this inst + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto userBlock = getBlock(use->getUser()); + auto& userBlockIndices = indexedBlockInfo[userBlock]; + + // If all of the condBlock's indices are a subset of the userBlock's indices, + // then the userBlock is inside the loop. + // + bool isInLoop = areIndicesSubsetOf(condBlockIndices, userBlockIndices); + + if (isInLoop) + loopUses.add(use); + else + afterLoopUses.add(use); + } + + // If inst has uses both inside and after the loop, create a copy for after-loop uses + if (loopUses.getCount() > 0 && afterLoopUses.getCount() > 0) + { + setInsertAfterOrdinaryInst(&builder, inst); + auto copy = builder.emitCheckpointObject(inst); + + // Copy source location so that checkpoint reporting is accurate + copy->sourceLoc = inst->sourceLoc; + + // Replace after-loop uses with the copy + for (auto use : afterLoopUses) + { + builder.replaceOperand(use, copy); + } + } + } + } +} + RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( IRGlobalValueWithCode* func, Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock, @@ -1297,20 +1433,6 @@ bool areIndicesEqual( return true; } -bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB) -{ - if (indicesA.getCount() > indicesB.getCount()) - return false; - - for (Index ii = 0; ii < indicesA.getCount(); ii++) - { - if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam) - return false; - } - - return true; -} - static int getInstRegionNestLevel( Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo, IRBlock* defBlock, @@ -1510,21 +1632,6 @@ static List<IndexTrackingInfo> maybeTrimIndices( return result; } -bool canInstBeStored(IRInst* inst) -{ - // Cannot store insts whose value is a type or a witness table, or a function. - // These insts get lowered to target-specific logic, and cannot be - // stored into variables or context structs as normal values. - // - if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) || - as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) || - !inst->getDataType()) - return false; - - return true; -} - - /// Legalizes all accesses to primal insts from recompute and diff blocks. /// RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( @@ -2104,6 +2211,39 @@ void buildIndexedBlocks( } } +// This function simply turns all CheckpointObject insts into a 'no-op'. +// i.e. simply replaces all uses of CheckpointObject with the original value. +// +// This operation is 'correct' because if CheckpointObject's operand is visible +// in a block, then it is visible in all dominated blocks. +// +void lowerCheckpointObjectInsts(IRGlobalValueWithCode* func) +{ + // For each block in the function + for (auto block : func->getBlocks()) + { + // For each instruction in the block + for (auto inst = block->getFirstInst(); inst;) + { + // Get next inst before potentially removing current one + auto nextInst = inst->getNextInst(); + + // Check if this is a CheckpointObject instruction + if (auto copyInst = as<IRCheckpointObject>(inst)) + { + // Replace all uses of the copy with the original value + auto originalVal = copyInst->getVal(); + copyInst->replaceUsesWith(originalVal); + + // Remove the now unused copy instruction + inst->removeAndDeallocate(); + } + + inst = nextInst; + } + } +} + // For each primal inst that is used in reverse blocks, decide if we should recompute or store // its value, then make them accessible in reverse blocks based the decision. // @@ -2117,6 +2257,9 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo; buildIndexedBlocks(indexedBlockInfo, func); + // Split loop condition insts into two if necessary. + splitLoopConditionBlockInsts(func, indexedBlockInfo); + // Create recompute blocks for each region following the same control flow structure // as in primal code. // @@ -2136,7 +2279,12 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) // Legalize the primal inst accesses by introducing local variables / arrays and emitting // necessary load/store logic. // - return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); + auto hoistedPrimalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); + + // Lower CheckpointObject insts to a no-op. + lowerCheckpointObjectInsts(func); + + return hoistedPrimalsInfo; } void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) @@ -2312,6 +2460,9 @@ static bool shouldStoreInst(IRInst* inst) break; } + case kIROp_CheckpointObject: + // Special inst for when a value must be stored. + return true; default: break; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 5a1966d00..9ffaeeeb9 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -716,6 +716,8 @@ INST(BitNot, bitnot, 1, 0) INST(Select, select, 3, 0) +INST(CheckpointObject, checkpointObj, 1, 0) + INST(GetStringHash, getStringHash, 1, 0) INST(WaveGetActiveMask, waveGetActiveMask, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d64820aa6..7c975cfcd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2664,6 +2664,22 @@ struct IRDiscard : IRTerminatorInst { }; +// Used for representing a distinct copy of an object. +// This will get lowered into a no-op in the backend, +// but is useful for IR transformations that need to consider +// different uses of an inst separately. +// +// For example, when we hoist primal insts out of a loop, +// we need to make distinct copies of the inst for its uses +// within the loop body and outside of it. +// +struct IRCheckpointObject : IRInst +{ + IR_LEAF_ISA(CheckpointObject); + + IRInst* getVal() { return getOperand(0); } +}; + // Signals that this point in the code should be unreachable. // We can/should emit a dataflow error if we can ever determine // that a block ending in one of these can actually be @@ -4408,6 +4424,8 @@ public: IRInst* emitDiscard(); + IRInst* emitCheckpointObject(IRInst* value); + IRInst* emitUnreachable(); IRInst* emitMissingReturn(); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index bf5b25d9c..39c1c5bb1 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2078,9 +2078,36 @@ Int getSpecializationConstantId(IRGlobalParam* param) return offset->getOffset(); } +IRBlock* getLoopHeaderForConditionBlock(IRBlock* block) +{ + // Go through uses and check if any of them are a loop condition block. + for (auto use = block->firstUse; use; use = use->nextUse) + { + if (auto loop = as<IRLoop>(use->getUser())) + { + if (loop->getTargetBlock() == block) + return cast<IRBlock>(loop->getParent()); + } + } + return nullptr; +} + void legalizeDefUse(IRGlobalValueWithCode* func) { auto dom = computeDominatorTree(func); + + // Make a map of loop condition blocks to their loop header. + // We need this because we'll be treating loop condition blocks as + // special cases (they are the special blocks since they "dominate" themselves, + // in the dominator tree sense) + // + Dictionary<IRBlock*, IRBlock*> loopHeaderBlockMap; + for (auto block : func->getBlocks()) + { + if (auto header = getLoopHeaderForConditionBlock(block)) + loopHeaderBlockMap.add(block, header); + } + for (auto block : func->getBlocks()) { for (auto inst : block->getModifiableChildren()) @@ -2099,16 +2126,22 @@ void legalizeDefUse(IRGlobalValueWithCode* func) } SLANG_ASSERT(commonDominator); - if (commonDominator == block) + // If commonDominator is 'block' and if the inst is not a Var in + // a loop condition block, we can skip the legalization. + // + if (commonDominator == block && + !(as<IRVar>(inst) && loopHeaderBlockMap.containsKey(block))) continue; - // If the common dominator is not `block`, it means we have detected - // uses that is no longer dominated by the current definition, and need - // to be fixed. - - // Normally, we can simply move the definition to the common dominator. + // Normally, if the common dominator is not `block`, we can simply move the definition + // to the common dominator. // An exception is when the common dominator is the target block of a - // loop. Note that after normalization, loops are in the form of: + // loop. + // Another exception is when a var in the loop condition block is accessed both inside + // and outside the loop. It is technically visible, but effects on the 'var' are not + // visible outside the loop, so we'll need to hoist it out of the loop. + // + // Note that after normalization, loops are in the form of: // ``` // loop { if (condition) block; else break; } // ``` @@ -2117,38 +2150,47 @@ void legalizeDefUse(IRGlobalValueWithCode* func) // In this case, we should insert a var/move the inst before the loop // instead of before the `if`. This situation can occur in the IR if // the original code is lowered from a `do-while` loop. - for (auto use = commonDominator->firstUse; use; use = use->nextUse) + // + bool shouldInitializeVar = false; + if (loopHeaderBlockMap.containsKey(commonDominator)) { - if (auto loopUser = as<IRLoop>(use->getUser())) + bool shouldMoveToHeader = false; + + // Check that the break-block dominates any of the uses are past the break + // block + for (auto _use = inst->firstUse; _use; _use = _use->nextUse) { - if (loopUser->getTargetBlock() == commonDominator) + if (dom->dominates( + as<IRLoop>(loopHeaderBlockMap[commonDominator]->getTerminator()) + ->getBreakBlock(), + _use->getUser()->getParent())) { - bool shouldMoveToHeader = false; - // Check that the break-block dominates any of the uses are past the break - // block - for (auto _use = inst->firstUse; _use; _use = _use->nextUse) - { - if (dom->dominates( - loopUser->getBreakBlock(), - _use->getUser()->getParent())) - { - shouldMoveToHeader = true; - break; - } - } - - if (shouldMoveToHeader) - commonDominator = as<IRBlock>(loopUser->getParent()); + shouldMoveToHeader = true; break; } } + if (shouldMoveToHeader) + { + commonDominator = loopHeaderBlockMap[commonDominator]; + shouldInitializeVar = true; + } } + // Now we can legalize uses based on the type of `inst`. if (auto var = as<IRVar>(inst)) { // If inst is an var, this is easy, we just move it to the // common dominator. var->insertBefore(commonDominator->getTerminator()); + if (shouldInitializeVar) + { + IRBuilder builder(func); + builder.setInsertAfter(var); + builder.emitStore( + var, + builder.emitDefaultConstruct( + as<IRPtrTypeBase>(var->getDataType())->getValueType())); + } } else { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index fb274c4a0..3a7ace37d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5664,6 +5664,13 @@ IRInst* IRBuilder::emitDiscard() return inst; } +IRInst* IRBuilder::emitCheckpointObject(IRInst* value) +{ + auto inst = + createInst<IRCheckpointObject>(this, kIROp_CheckpointObject, value->getFullType(), value); + addInst(inst); + return inst; +} IRInst* IRBuilder::emitBranch(IRBlock* pBlock) { diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 2dfad0a61..51f17b611 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -16,7 +16,7 @@ float test_loop_with_continue(float y) //CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item: float t = y; - //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here: + //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item: for (int i = 0; i < 3; i++) { if (t > 4.0) diff --git a/tests/autodiff/reverse-loop-immediate-return.slang b/tests/autodiff/reverse-loop-immediate-return.slang new file mode 100644 index 000000000..121836115 --- /dev/null +++ b/tests/autodiff/reverse-loop-immediate-return.slang @@ -0,0 +1,59 @@ + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + + +[BackwardDerivative(set_bwd)] +void set(uint idx, float x) +{ + outputBuffer[idx] = x; +} + +void set_bwd(uint idx, inout DifferentialPair<float> x) +{ + // For debugging, we'll set the derivative to 1.0 + x = DifferentialPair<float>(x.p, 1.0f); +} + +[Differentiable] +void run( + uint idx, + float x) +{ + if (idx >= 1) return; + + if (idx == 0) + { } + + for (int i = 0; i < 1; i++) + { + if (idx > 0) + { + return; + } + + if (idx == 0) + { + x = x * 2.0f; + } + } + + if (idx == 0) + { } + + set(idx, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // bwd_diff + DifferentialPair<float> dpa = DifferentialPair<float>(1.0, 0.0); + bwd_diff(run)(dispatchThreadID.x, dpa); + outputBuffer[dispatchThreadID.x] = dpa.d; + + // CHECK: type: float + // CHECK: 2.0 +} |
