diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-21 14:28:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-21 14:28:57 -0700 |
| commit | 957a4d3eb0a14a9d57bbb325ef0e1d458df2d2b9 (patch) | |
| tree | fabc9317b1595c9f74f5b25ee83d16f4260a19d3 /source | |
| parent | 69a327a98e3f9504863f9ecb623aa93036ac43db (diff) | |
Refactor checkpointing policy and availability pass. (#2826)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 357 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 895 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 62 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-region.h | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 560 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 224 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 48 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-dce.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 7 |
18 files changed, 1150 insertions, 1164 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 80ee37988..a67b7f167 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -4,6 +4,7 @@ #include "slang-ir-ssa.h" #include "slang-ir-validate.h" +#include "slang-ir-util.h" namespace Slang { @@ -17,31 +18,26 @@ struct RegionEndpoint bool isRegionEmpty = false; - RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion) : - exitBlock(exitBlock), - inBreakRegion(inBreakRegion), - inBaseRegion(inBaseRegion), - isRegionEmpty(false) - { } - - RegionEndpoint( - IRBlock* exitBlock, - bool inBreakRegion, - bool inBaseRegion, - bool isRegionEmpty) : - exitBlock(exitBlock), - inBreakRegion(inBreakRegion), - inBaseRegion(inBaseRegion), - isRegionEmpty(isRegionEmpty) - { } - - RegionEndpoint() - { } + RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion) + : exitBlock(exitBlock) + , inBreakRegion(inBreakRegion) + , inBaseRegion(inBaseRegion) + , isRegionEmpty(false) + {} + + RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion, bool isRegionEmpty) + : exitBlock(exitBlock) + , inBreakRegion(inBreakRegion) + , inBaseRegion(inBaseRegion) + , isRegionEmpty(isRegionEmpty) + {} + + RegionEndpoint() {} }; struct BreakableRegionInfo { - IRVar* breakVar; + IRVar* breakVar; IRBlock* breakBlock; IRBlock* headerBlock; }; @@ -49,16 +45,15 @@ struct BreakableRegionInfo struct CFGNormalizationContext { IRModule* module; - DiagnosticSink* sink; + DiagnosticSink* sink; }; - IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) { // For now, we're going to naively assume the next block is the condition block. // Add in more support for more cases as necessary. - // - + // + auto firstBlock = loopInst->getTargetBlock(); if (as<IRIfElse>(firstBlock->getTerminator())) @@ -72,7 +67,7 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) // IRBuilder condBuilder(loopInst->getModule()); - + auto condBlock = condBuilder.emitBlock(); condBlock->insertAfter(as<IRBlock>(loopInst->getParent())); @@ -81,22 +76,17 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) // Emit a condition: true side goes to the loop body, and // false side goes into the break block. - // + // condBuilder.setInsertInto(condBlock); auto ifElse = as<IRIfElse>(condBuilder.emitIfElse( - condBuilder.getBoolValue(true), - firstBlock, - loopInst->getBreakBlock(), - firstBlock)); - + condBuilder.getBoolValue(true), firstBlock, loopInst->getBreakBlock(), firstBlock)); + // We'll insert a blank block between the condition and the // break block, since otherwise, we might trip up the later // parts of this pass. // - condBuilder.insertBlockAlongEdge( - loopInst->getModule(), - IREdge(&ifElse->falseBlock)); - + condBuilder.insertBlockAlongEdge(loopInst->getModule(), IREdge(&ifElse->falseBlock)); + return condBlock; } } @@ -105,9 +95,9 @@ struct CFGNormalizationPass { CFGNormalizationContext cfgContext; - CFGNormalizationPass(CFGNormalizationContext ctx) : - cfgContext(ctx) - { } + CFGNormalizationPass(CFGNormalizationContext ctx) + : cfgContext(ctx) + {} void replaceBreakWithAfterBlock( IRBuilder* builder, @@ -158,13 +148,12 @@ struct CFGNormalizationPass return branchInst ? branchInst->getTargetBlock() : nullptr; } - bool isSuccessorBlock(IRBlock* baseBlock, IRBlock* succBlock) { for (auto successor : baseBlock->getSuccessors()) if (successor == succBlock) return true; - + return false; } @@ -184,9 +173,7 @@ struct CFGNormalizationPass } RegionEndpoint getNormalizedRegionEndpoint( - BreakableRegionInfo* parentRegion, - IRBlock* entryBlock, - List<IRBlock*> afterBlocks) + BreakableRegionInfo* parentRegion, IRBlock* entryBlock, List<IRBlock*> afterBlocks) { IRBlock* currentBlock = entryBlock; _moveVarsToRegionHeader(parentRegion, currentBlock); @@ -195,7 +182,7 @@ struct CFGNormalizationPass // and not in the 'break' control flow // It is the job of the *caller* to make sure the break flow // does not reach this point. - // + // bool currBreakRegion = false; bool currBaseRegion = true; @@ -204,7 +191,7 @@ struct CFGNormalizationPass // if (afterBlocks.contains(currentBlock)) return RegionEndpoint(currentBlock, currBreakRegion, currBaseRegion, true); - + IRBuilder builder(cfgContext.module); List<IRBlock*> pendingAfterBlocks; @@ -216,7 +203,7 @@ struct CFGNormalizationPass // We could arrive at the after-block before or // after encountering a break statement. // To handle this, we'll split the flow by checking the break flag - // + // builder.setInsertAfter(block); auto preAfterSplitBlock = builder.emitBlock(); @@ -229,28 +216,24 @@ struct CFGNormalizationPass builder.setInsertInto(preAfterSplitBlock); builder.emitBranch(afterSplitBlock); - + // Converging block for the split that we're making. auto afterSplitAfterBlock = builder.emitBlock(); builder.setInsertInto(afterSplitBlock); auto breakFlagValue = builder.emitLoad(parentRegion->breakVar); - builder.emitIfElse( - breakFlagValue, - block, - afterSplitAfterBlock, - afterSplitAfterBlock); + builder.emitIfElse(breakFlagValue, block, afterSplitAfterBlock, afterSplitAfterBlock); // At this point, we need to place afterSplitAfterBlock between - // at the _end_ of this region, but we aren't there yet (and + // at the _end_ of this region, but we aren't there yet (and // don't know which block is the end of this region) // Therefore, we'll defer this step and add it to a list for later. - // + // pendingAfterBlocks.add(afterSplitAfterBlock); }; - // Follow this thread of execution till we hit an + // Follow this thread of execution till we hit an // acceptable after block. // while (!afterBlocks.contains(maybeGetUnconditionalTarget(currentBlock))) @@ -259,14 +242,14 @@ struct CFGNormalizationPass auto terminator = currentBlock->getTerminator(); switch (terminator->getOp()) { - case kIROp_unconditionalBranch: + case kIROp_unconditionalBranch: { auto targetBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock(); currentBlock = targetBlock; break; } - - case kIROp_ifElse: + + case kIROp_ifElse: { auto ifElse = as<IRIfElse>(terminator); @@ -274,24 +257,24 @@ struct CFGNormalizationPass // lead back to the condition. // SLANG_ASSERT(ifElse->getAfterBlock() != parentRegion->breakBlock); - + auto trueEndPoint = getNormalizedRegionEndpoint( parentRegion, ifElse->getTrueBlock(), List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock)); - + auto falseEndPoint = getNormalizedRegionEndpoint( parentRegion, ifElse->getFalseBlock(), List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock)); - + auto trueTargetBlock = getUnconditionalTarget(trueEndPoint); auto falseTargetBlock = getUnconditionalTarget(falseEndPoint); - + auto afterBlock = ifElse->getAfterBlock(); // Trivial case, both end-points branch into the after block - /*if (trueTargetBlock == afterBlock && + /*if (trueTargetBlock == afterBlock && falseTargetBlock == afterBlock) { if () @@ -308,7 +291,7 @@ struct CFGNormalizationPass { // Branch into after block (and set break variable) replaceBreakWithAfterBlock( - &builder, + &builder, parentRegion, trueEndPoint.exitBlock, afterBlock, @@ -321,10 +304,10 @@ struct CFGNormalizationPass } else { - // If this branch naturally branches into our + // If this branch naturally branches into our // after-block, copy whatever flags the endpoints // have. - // + // afterBreakRegion = afterBreakRegion || trueEndPoint.inBreakRegion; afterBaseRegion = afterBaseRegion || trueEndPoint.inBaseRegion; } @@ -346,10 +329,10 @@ struct CFGNormalizationPass } else { - // If this branch naturally branches into our + // If this branch naturally branches into our // after-block, copy whatever flags the endpoints // have. - // + // afterBreakRegion = afterBreakRegion || falseEndPoint.inBreakRegion; afterBaseRegion = afterBaseRegion || falseEndPoint.inBaseRegion; } @@ -365,12 +348,12 @@ struct CFGNormalizationPass // Do we need to split the after region? if (afterBaseRegion && afterBreakRegion) { - // Before we split the afterBlock, we + // Before we split the afterBlock, we // want to make sure the afterBlock is // firmly _inside_ the current region. - // If it's part of the parent, add a + // If it's part of the parent, add a // dummy block. - // + // if (afterBlocks.contains(afterBlock)) { auto newAfterBlock = builder.emitBlock(); @@ -382,15 +365,17 @@ struct CFGNormalizationPass // condition block. (This eventually causes cloneInst to fail, // since it is currently order-dependent) // Remove this once cloneInst is order-independent. - // + // // newAfterBlock->insertBefore(afterBlock); newAfterBlock->insertAfter(falseEndPoint.exitBlock); builder.emitBranch(afterBlock); - + ifElse->afterBlock.set(newAfterBlock); - as<IRUnconditionalBranch>(trueEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock); - as<IRUnconditionalBranch>(falseEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock); + as<IRUnconditionalBranch>(trueEndPoint.exitBlock->getTerminator()) + ->block.set(newAfterBlock); + as<IRUnconditionalBranch>(falseEndPoint.exitBlock->getTerminator()) + ->block.set(newAfterBlock); afterBlock = newAfterBlock; } @@ -402,15 +387,15 @@ struct CFGNormalizationPass afterBreakRegion = false; afterBaseRegion = true; } - + currentBlock = afterBlock; currBreakRegion = afterBreakRegion; currBaseRegion = afterBaseRegion; break; } - case kIROp_loop: - case kIROp_Switch: + case kIROp_loop: + case kIROp_Switch: { auto breakBlock = normalizeBreakableRegion(terminator); @@ -419,10 +404,10 @@ struct CFGNormalizationPass break; } - default: - // Do proper diagnosing - SLANG_UNEXPECTED("Unhandled control flow inst"); - break; + default: + // Do proper diagnosing + SLANG_UNEXPECTED("Unhandled control flow inst"); + break; } _moveVarsToRegionHeader(parentRegion, currentBlock); @@ -438,7 +423,7 @@ struct CFGNormalizationPass SLANG_ASSERT(nextRegionBlock); builder.emitBranch(nextRegionBlock); - + builder.setInsertInto(currentBlock); currentBlock->getTerminator()->removeAndDeallocate(); builder.emitBranch(block); @@ -458,7 +443,7 @@ struct CFGNormalizationPass HashSet<IRBlock*> predecessorSet; for (auto predecessor : block->getPredecessors()) predecessorSet.Add(predecessor); - + return predecessorSet; } @@ -466,29 +451,27 @@ struct CFGNormalizationPass { // Get 'looping' block (first block in loop) auto firstLoopBlock = loop->getTargetBlock(); - + // If we only have one predecessor, the loop is trivial. return (getPredecessorSet(firstLoopBlock).Count() == 1); } - IRBlock* normalizeBreakableRegion( - IRInst* branchInst) + IRBlock* normalizeBreakableRegion(IRInst* branchInst) { IRBuilder builder(cfgContext.module); switch (branchInst->getOp()) { - case kIROp_loop: + case kIROp_loop: { BreakableRegionInfo info; info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock(); info.headerBlock = as<IRBlock>(branchInst->getParent()); // Emit var into parent block. - builder.setInsertBefore( - as<IRBlock>(branchInst->getParent())->getTerminator()); - - // Create and initialize break var to true + builder.setInsertBefore(as<IRBlock>(branchInst->getParent())->getTerminator()); + + // Create and initialize break var to true // true -> no break yet. // false -> atleast one break statement hit. // @@ -500,24 +483,23 @@ struct CFGNormalizationPass // edges actually in a loop), we're just going to remove // it.. (we can do this, because the normalization pass // will transform any break and continue statements) - // + // if (isLoopTrivial(as<IRLoop>(branchInst))) { auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock(); - + // Normalize the region from the first loop block till break. auto preBreakEndPoint = getNormalizedRegionEndpoint( - &info, - firstLoopBlock, - List<IRBlock*>(info.breakBlock)); - + &info, firstLoopBlock, List<IRBlock*>(info.breakBlock)); + // Should not be empty.. but check anyway SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty); - // Quick consistency check.. preBreakEndPoint should be + // Quick consistency check.. preBreakEndPoint should be // branching into break block. - SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( - preBreakEndPoint.exitBlock->getTerminator())->getTargetBlock() == info.breakBlock); + SLANG_RELEASE_ASSERT( + as<IRUnconditionalBranch>(preBreakEndPoint.exitBlock->getTerminator()) + ->getTargetBlock() == info.breakBlock); auto currentBlock = branchInst->getParent(); @@ -529,30 +511,27 @@ struct CFGNormalizationPass return info.breakBlock; } - auto condBlock = getOrCreateTopLevelCondition(as<IRLoop>(branchInst)); + auto condBlock = + getOrCreateTopLevelCondition(as<IRLoop>(branchInst)); auto ifElse = as<IRIfElse>(condBlock->getTerminator()); auto trueEndPoint = getNormalizedRegionEndpoint( - &info, - ifElse->getTrueBlock(), - List<IRBlock*>(condBlock, info.breakBlock)); - + &info, ifElse->getTrueBlock(), List<IRBlock*>(condBlock, info.breakBlock)); + auto falseEndPoint = getNormalizedRegionEndpoint( - &info, - ifElse->getFalseBlock(), - List<IRBlock*>(condBlock, info.breakBlock)); - + &info, ifElse->getFalseBlock(), List<IRBlock*>(condBlock, info.breakBlock)); + RegionEndpoint loopEndPoint; bool isLoopOnTrueSide = true; - + // First figure out which side belongs to the loop body. if (isSuccessorBlock(trueEndPoint.exitBlock, condBlock)) { loopEndPoint = trueEndPoint; isLoopOnTrueSide = true; } - + if (isSuccessorBlock(falseEndPoint.exitBlock, condBlock)) { loopEndPoint = falseEndPoint; @@ -560,11 +539,11 @@ struct CFGNormalizationPass } // Right now, we only support loops where the loop is on the true side of - // the condition. If we ever encounter the other case, fill in logic to + // the condition. If we ever encounter the other case, fill in logic to // flip the condition. // SLANG_RELEASE_ASSERT(isLoopOnTrueSide); - + // Expect atleast one basic block (other than the condition block), in // the loop. // @@ -573,7 +552,7 @@ struct CFGNormalizationPass // Does the loop endpoint have both 'break' and 'base' // control flows? - // + // if (loopEndPoint.inBaseRegion && loopEndPoint.inBreakRegion) { // Add a test for the break variable into the condition. @@ -582,36 +561,30 @@ struct CFGNormalizationPass builder.setInsertBefore(ifElse); auto breakFlagVal = builder.emitLoad(info.breakVar); - // Need to invert the break flag if the loop is + // Need to invert the break flag if the loop is // on the false side. - // + // if (!isLoopOnTrueSide) { IRInst* args[1] = {breakFlagVal}; - breakFlagVal = builder.emitIntrinsicInst( - builder.getBoolType(), - kIROp_Not, - 1, - args); + breakFlagVal = + builder.emitIntrinsicInst(builder.getBoolType(), kIROp_Not, 1, args); } IRInst* args[2] = {cond, breakFlagVal}; // If break-var = true, direct flow to the loop // otherwise, direct flow to break - // - auto complexCond = builder.emitIntrinsicInst( - builder.getBoolType(), - kIROp_And, - 2, - args); - + // + auto complexCond = + builder.emitIntrinsicInst(builder.getBoolType(), kIROp_And, 2, args); + ifElse->condition.set(complexCond); } - + return info.breakBlock; } - case kIROp_Switch: + case kIROp_Switch: { auto switchInst = as<IRSwitch>(branchInst); @@ -620,10 +593,9 @@ struct CFGNormalizationPass info.breakBlock = as<IRSwitch>(branchInst)->getBreakLabel(); // Emit var into parent block. - builder.setInsertBefore( - as<IRBlock>(branchInst->getParent())->getTerminator()); - - // Create and initialize break var to true + builder.setInsertBefore(as<IRBlock>(branchInst->getParent())->getTerminator()); + + // Create and initialize break var to true // true -> no break yet. // false -> atleast one break statement hit. // @@ -635,30 +607,31 @@ struct CFGNormalizationPass { auto caseBlock = switchInst->getCaseLabel(ii); auto caseEndPoint = getNormalizedRegionEndpoint( - &info, - caseBlock, - List<IRBlock*>(info.breakBlock)).exitBlock; + &info, caseBlock, List<IRBlock*>(info.breakBlock)) + .exitBlock; // Consistency check (if this case hits, it's probably // because the switch has fall-through, which we don't support) - SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( - caseEndPoint->getTerminator())->getTargetBlock() == info.breakBlock); + SLANG_RELEASE_ASSERT( + as<IRUnconditionalBranch>(caseEndPoint->getTerminator()) + ->getTargetBlock() == info.breakBlock); } - auto defaultEndPoint = getNormalizedRegionEndpoint( - &info, - switchInst->getDefaultLabel(), - List<IRBlock*>(info.breakBlock)).exitBlock; + auto defaultEndPoint = + getNormalizedRegionEndpoint( + &info, switchInst->getDefaultLabel(), List<IRBlock*>(info.breakBlock)) + .exitBlock; // Consistency check (if this case hits, it's probably // because the switch has fall-through, which we don't support) - SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( - defaultEndPoint->getTerminator())->getTargetBlock() == info.breakBlock); + SLANG_RELEASE_ASSERT( + as<IRUnconditionalBranch>(defaultEndPoint->getTerminator())->getTargetBlock() == + info.breakBlock); return info.breakBlock; } - default: - break; + default: + break; } SLANG_UNEXPECTED("Unhandled control-flow inst"); @@ -666,18 +639,16 @@ struct CFGNormalizationPass }; void normalizeCFG( - IRModule* module, - IRGlobalValueWithCode* func, - IRCFGNormalizationPass const& options) + IRModule* module, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options) { // Remove phis to simplify our pass. We'll add them back in later // with constructSSA. - // + // eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func); - CFGNormalizationContext context = {module, options.sink}; + CFGNormalizationContext context = {module, options.sink}; CFGNormalizationPass cfgPass(context); - + List<IRBlock*> workList; workList.add(func->getFirstBlock()); @@ -703,9 +674,83 @@ void normalizeCFG( } } + // If we created a new condition block for a loop, the local vars defined in + // the original loop body will no longer dominate the exit block of the + // loop. If there are any uses of these variables outside the loop, they + // will become invalid. Therefore we need to hoist the local variables to + // the loop header block. + HashSet<IRBlock*> workListSet; + for (auto block : func->getBlocks()) + { + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto condBlock = loop->getTargetBlock(); + auto ifElse = as<IRIfElse>(condBlock->getTerminator()); + auto bodyBlock = ifElse->getTrueBlock(); + + // Collect loop body blocks. + workList.clear(); + workListSet.Clear(); + workList.add(bodyBlock); + workListSet.add(bodyBlock); + for (Index i = 0; i < workList.getCount(); i++) + { + auto b = workList[i]; + for (auto succ : b->getSuccessors()) + { + if (succ != loop->getTargetBlock() && succ != loop->getBreakBlock()) + { + if (workListSet.add(succ)) + workList.add(succ); + } + } + } + auto insertionPoint = loop; + IRBuilder builder(func); + for (auto b : workList) + { + for (auto inst : b->getChildren()) + { + // If inst has uses outside the loop body, we need to hoist it. + IRVar* tempVar = nullptr; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto userBlock = as<IRBlock>(use->getUser()->getParent()); + if (userBlock && !workListSet.Contains(userBlock)) + { + // Hoist the inst. + if (auto var = as<IRVar>(inst)) + { + // If inst is an var, this is easy, we just move it to the + // loop header. + var->insertBefore(insertionPoint); + break; + } + else + { + // For all other insts, we need to create a local var for it. + if (!tempVar) + { + builder.setInsertBefore(insertionPoint); + tempVar = builder.emitVar(inst->getFullType()); + builder.setInsertAfter(inst); + builder.emitStore(tempVar, inst); + } + // Replace the use with a load of tempVar. + builder.setInsertBefore(use->getUser()); + auto load = builder.emitLoad(tempVar); + builder.replaceOperand(use, load); + } + break; + } + } + } + } + } + } disableIRValidationAtInsert(); constructSSA(module, func); enableIRValidationAtInsert(); } -} +} // namespace Slang diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 363572c86..6a9b504a6 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1,19 +1,277 @@ #include "slang-ir-autodiff-primal-hoist.h" #include "slang-ir-autodiff-region.h" -namespace Slang +namespace Slang { +void applyCheckpointSet( + CheckpointSetInfo* checkpointInfo, + IRGlobalValueWithCode* func, + HoistedPrimalsInfo* hoistInfo, + HashSet<IRUse*> pendingUses, + Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock); + bool containsOperand(IRInst* inst, IRInst* operand) { for (UIndex ii = 0; ii < inst->getOperandCount(); ii++) if (inst->getOperand(ii) == operand) return true; - + return false; } -RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* splitInfo) +static bool isDifferentialInst(IRInst* inst) +{ + auto parent = inst->getParent(); + if (parent->findDecoration<IRDifferentialInstDecoration>()) + return true; + return inst->findDecoration<IRDifferentialInstDecoration>() != nullptr; +} + +static bool isDifferentialBlock(IRBlock* block) +{ + return block->findDecoration<IRDifferentialInstDecoration>(); +} + +static Dictionary<IRBlock*, IRBlock*> reconstructDiffBlockMap(IRGlobalValueWithCode* func) +{ + Dictionary<IRBlock*, IRBlock*> diffBlockMap; + for (auto block : func->getBlocks()) + { + if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>()) + { + if (diffDecor->getPrimalType()) + diffBlockMap[as<IRBlock>(diffDecor->getPrimalInst())] = block; + } + } + return diffBlockMap; +} + +static IRBlock* getLoopRegionBodyBlock(IRLoop* loop) +{ + auto condBlock = as<IRBlock>(loop->getTargetBlock()); + // We assume the loop body always sit at the true side of the if-else. + if (auto ifElse = as<IRIfElse>(condBlock->getTerminator())) + { + return ifElse->getTrueBlock(); + } + return nullptr; +} + +static IRBlock* tryGetSubRegionEndBlock(IRInst* terminator) +{ + auto loop = as<IRLoop>(terminator); + if (!loop) + return nullptr; + return loop->getBreakBlock(); +} + +static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( + IRGlobalValueWithCode* func, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo) +{ + IRBlock* firstDiffBlock = nullptr; + for (auto block : func->getBlocks()) + { + if (isDifferentialBlock(block)) + { + firstDiffBlock = block; + break; + } + } + if (!firstDiffBlock) + return Dictionary<IRBlock*, IRBlock*>(); + + Dictionary<IRLoop*, IRLoop*> mapPrimalLoopToDiffLoop; + for (auto block : func->getBlocks()) + { + if (isDifferentialBlock(block)) + { + if (auto diffLoop = as<IRLoop>(block->getTerminator())) + { + if (auto diffDecor = diffLoop->findDecoration<IRDifferentialInstDecoration>()) + { + mapPrimalLoopToDiffLoop[as<IRLoop>(diffDecor->getPrimalInst())] = diffLoop; + } + } + } + } + + IRBuilder builder(func); + Dictionary<IRBlock*, IRBlock*> recomputeBlockMap; + + // Create the first recompute block right before the first diff block, + // and change all jumps into the diff block to the recompute block instead. + auto createRecomputeBlock = [&](IRBlock* primalBlock) + { + auto recomputeBlock = builder.createBlock(); + recomputeBlock->insertAtEnd(func); + builder.addDecoration(recomputeBlock, kIROp_RecomputeBlockDecoration); + recomputeBlockMap.Add(primalBlock, recomputeBlock); + indexedBlockInfo[recomputeBlock] = indexedBlockInfo[primalBlock].GetValue(); + return recomputeBlock; + }; + + auto firstRecomputeBlock = createRecomputeBlock(func->getFirstBlock()); + firstRecomputeBlock->insertBefore(firstDiffBlock); + moveParams(firstRecomputeBlock, firstDiffBlock); + firstDiffBlock->replaceUsesWith(firstRecomputeBlock); + + struct WorkItem + { + // The first primal block in this region. + IRBlock* primalBlock; + + // The recompute block created for the first primal block in this region. + IRBlock* recomptueBlock; + + // The end of primal block in tihs region. + IRBlock* regionEndBlock; + + // The first diff block in this region. + IRBlock* firstDiffBlock; + }; + + List<WorkItem> workList; + WorkItem firstWorkItem = { func->getFirstBlock(), firstRecomputeBlock, firstRecomputeBlock, firstDiffBlock }; + workList.add(firstWorkItem); + + IRCloneEnv recomputeCloneEnv; + recomputeBlockMap[func->getFirstBlock()] = firstRecomputeBlock; + + for (Index i = 0; i < workList.getCount(); i++) + { + auto workItem = workList[i]; + auto primalBlock = workItem.primalBlock; + auto recomputeBlock = workItem.recomptueBlock; + + List<IndexTrackingInfo>* thisBlockIndexInfo = indexedBlockInfo.TryGetValue(primalBlock); + if (!thisBlockIndexInfo) + continue; + + builder.setInsertInto(recomputeBlock); + if (auto subRegionEndBlock = tryGetSubRegionEndBlock(primalBlock->getTerminator())) + { + // The terminal inst of primalBlock marks the start of a sub loop region? + // We need to queue work for both the next region after the loop at the current level, + // and for the sub region for the next level. + if (subRegionEndBlock == workItem.regionEndBlock) + { + // We have reached the end of top-level region, jump to first diff block. + builder.emitBranch(workItem.firstDiffBlock); + } + else + { + // Have we already created a recompute block for this target? + // If so, use it. + IRBlock* existingRecomputeBlock = nullptr; + if (recomputeBlockMap.TryGetValue(subRegionEndBlock, existingRecomputeBlock)) + { + builder.emitBranch(existingRecomputeBlock); + } + else + { + // Queue work for the next region after the subregion at this level. + auto nextRegionRecomputeBlock = createRecomputeBlock(subRegionEndBlock); + nextRegionRecomputeBlock->insertAfter(recomputeBlock); + builder.emitBranch(nextRegionRecomputeBlock); + + { + WorkItem newWorkItem = { + subRegionEndBlock, + nextRegionRecomputeBlock, + workItem.regionEndBlock, + workItem.firstDiffBlock }; + workList.add(newWorkItem); + } + } + } + // Queue work for the subregion. + auto loop = as<IRLoop>(primalBlock->getTerminator()); + auto bodyBlock = getLoopRegionBodyBlock(loop); + auto diffLoop = mapPrimalLoopToDiffLoop[loop].GetValue(); + auto diffBodyBlock = getLoopRegionBodyBlock(diffLoop); + auto bodyRecomputeBlock = createRecomputeBlock(bodyBlock); + bodyRecomputeBlock->insertBefore(diffBodyBlock); + diffBodyBlock->replaceUsesWith(bodyRecomputeBlock); + moveParams(bodyRecomputeBlock, diffBodyBlock); + { + // After CFG normalization, the loop body will contain only jumps to the + // beginning of the loop. + // If we see such a jump, it means we have reached the end of current + // region in the loop. + // Therefore, we set the regionEndBlock for the sub-region as loop's target + // block. + WorkItem newWorkItem = { + bodyBlock, bodyRecomputeBlock, loop->getTargetBlock(), diffBodyBlock}; + workList.add(newWorkItem); + } + } + else + { + // This is a normal control flow, just copy the CFG structure. + auto terminator = primalBlock->getTerminator(); + IRInst* newTerminator = nullptr; + switch (terminator->getOp()) + { + case kIROp_Switch: + case kIROp_ifElse: + newTerminator = cloneInst(&recomputeCloneEnv, &builder, primalBlock->getTerminator()); + break; + case kIROp_unconditionalBranch: + newTerminator = builder.emitBranch(as<IRUnconditionalBranch>(terminator)->getTargetBlock()); + break; + default: + SLANG_UNREACHABLE("terminator type"); + } + + // Modify jump targets in newTerminator to point to the right recompute block or firstDiffBlock. + for (UInt op = 0; op < newTerminator->getOperandCount(); op++) + { + auto target = as<IRBlock>(newTerminator->getOperand(op)); + if (!target) + continue; + if (target == workItem.regionEndBlock) + { + // This jump target is the end of the current region, we will jump to + // firstDiffBlock instead. + newTerminator->setOperand(op, workItem.firstDiffBlock); + continue; + } + + // Have we already created a recompute block for this target? + // If so, use it. + IRBlock* existingRecomputeBlock = nullptr; + if (recomputeBlockMap.TryGetValue(target, existingRecomputeBlock)) + { + newTerminator->setOperand(op, existingRecomputeBlock); + continue; + } + + // This jump target is a normal part of control flow, clone the next block. + auto targetRecomputeBlock = createRecomputeBlock(target); + targetRecomputeBlock->insertBefore(workItem.firstDiffBlock); + + newTerminator->setOperand(op, targetRecomputeBlock); + + // Queue work for the successor. + WorkItem newWorkItem = { + target, + targetRecomputeBlock, + workItem.regionEndBlock, + workItem.firstDiffBlock}; + workList.add(newWorkItem); + } + } + } + // After this pass, all primal blocks except the condition block and the false block of a loop + // will have a corresponding recomputeBlock. + return recomputeBlockMap; +} + +RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( + IRGlobalValueWithCode* func, + Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock) { RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo(); @@ -29,11 +287,11 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal UIndex opIndex = 0; for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++) { - if (!operand->get()->findDecoration<IRDifferentialInstDecoration>() && + if (!isDifferentialInst(operand->get()) && !as<IRFunc>(operand->get()) && !as<IRBlock>(operand->get()) && !(as<IRModuleInst>(operand->get()->getParent())) && - !getBlock(operand->get())->findDecoration<IRDifferentialInstDecoration>()) + !isDifferentialBlock(getBlock(operand->get()))) workList.add(operand); } @@ -44,7 +302,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal // if (inst->getDataType() && (getParentFunc(inst->getDataType()) == func)) { - if (!getBlock(inst->getDataType())->findDecoration<IRDifferentialInstDecoration>()) + if (!isDifferentialBlock(getBlock(inst->getDataType()))) workList.add(&inst->typeUse); } }; @@ -58,7 +316,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal if (block == func->getFirstBlock()) continue; - if (!block->findDecoration<IRDifferentialInstDecoration>()) + if (!isDifferentialBlock(block)) continue; for (auto child : block->getChildren()) @@ -111,7 +369,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal SLANG_ASSERT(!checkpointInfo->storeSet.Contains(result.instToRecompute)); checkpointInfo->recomputeSet.Add(result.instToRecompute); - if (use->getUser()->findDecoration<IRDifferentialInstDecoration>()) + if (isDifferentialInst(use->getUser())) usesToReplace.Add(use); if (auto param = as<IRParam>(result.instToRecompute)) @@ -160,7 +418,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser())); SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser())); - if (use->getUser()->findDecoration<IRDifferentialInstDecoration>()) + if (isDifferentialInst(use->getUser())) usesToReplace.Add(use); checkpointInfo->invertSet.Add(instToInvert); @@ -178,7 +436,55 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal } } - return applyCheckpointSet(checkpointInfo, func, splitInfo, usesToReplace); + // If a var or call is in recomputeSet, move any var/calls associated with the same call to + // recomputeSet. + List<IRInst*> instWorkList; + HashSet<IRInst*> instWorkListSet; + for (auto inst : checkpointInfo->recomputeSet) + { + switch (inst->getOp()) + { + case kIROp_Call: + case kIROp_Var: + instWorkList.add(inst); + instWorkListSet.add(inst); + break; + } + } + for (Index i = 0; i < instWorkList.getCount(); i++) + { + auto inst = instWorkList[i]; + if (auto var = as<IRVar>(inst)) + { + for (auto use = var->firstUse; use; use = use->nextUse) + { + auto callUser = as<IRCall>(use->getUser()); + if (!callUser) + continue; + checkpointInfo->recomputeSet.add(callUser); + checkpointInfo->storeSet.Remove(callUser); + if (instWorkListSet.add(callUser)) + instWorkList.add(callUser); + } + } + else if (auto call = as<IRCall>(inst)) + { + for (UInt j = 0; j < call->getArgCount(); j++) + { + if (auto varArg = as<IRVar>(call->getArg(j))) + { + checkpointInfo->recomputeSet.add(varArg); + checkpointInfo->storeSet.Remove(varArg); + if (instWorkListSet.add(varArg)) + instWorkList.add(varArg); + } + } + } + } + + RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo(); + applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock); + return hoistInfo; } void applyToInst( @@ -195,6 +501,11 @@ void applyToInst( return; } + if (hoistInfo->ignoreSet.Contains(inst)) + { + return; + } + bool isInstRecomputed = checkpointInfo->recomputeSet.Contains(inst); if (isInstRecomputed) { @@ -242,13 +553,15 @@ void applyToInst( } } -RefPtr<HoistedPrimalsInfo> applyCheckpointSet( +void applyCheckpointSet( CheckpointSetInfo* checkpointInfo, IRGlobalValueWithCode* func, - BlockSplitInfo* splitInfo, - HashSet<IRUse*> pendingUses) + HoistedPrimalsInfo* hoistInfo, + HashSet<IRUse*> pendingUses, + Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock) { - RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo(); + // Reconstruct diff block map. + Dictionary<IRBlock*, IRBlock*> diffBlockMap = reconstructDiffBlockMap(func); RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext(); @@ -264,7 +577,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointSet( UIndex opIndex = 0; for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++) { - if (!operand->get()->findDecoration<IRDifferentialInstDecoration>()) + if (!isDifferentialInst(operand->get())) cloneCtx->pendingUses.Add(operand); } }; @@ -276,15 +589,15 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointSet( if (block == func->getFirstBlock()) continue; - if (block->findDecoration<IRDifferentialInstDecoration>()) + if (isDifferentialBlock(block)) + continue; + + if (block->findDecoration<IRRecomputeBlockDecoration>()) continue; - auto diffBlock = as<IRBlock>(splitInfo->diffBlockMap[block]); - - auto firstDiffInst = as<IRBlock>(splitInfo->diffBlockMap[block])->getFirstOrdinaryInst(); + auto diffBlock = as<IRBlock>(diffBlockMap[block]); IRBuilder builder(func->getModule()); - UIndex ii = 0; for (auto param : block->getParams()) { @@ -302,48 +615,58 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointSet( predecessorSet.Add(predecessor); - auto diffPredecessor = as<IRBlock>(splitInfo->diffBlockMap[block]); + auto diffPredecessor = as<IRBlock>(diffBlockMap[block]); if (checkpointInfo->recomputeSet.Contains(param)) + { + IRInst* terminator = diffPredecessor->getTerminator(); addPhiOutputArg(&builder, diffPredecessor, + terminator, as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii)); + } if (checkpointInfo->invertSet.Contains(param)) + { + IRInst* terminator = diffPredecessor->getTerminator(); + addPhiOutputArg(&builder, diffPredecessor, + terminator, as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii)); + } } ii++; } + IRBlock* recomputeBlock = block; + mapPrimalBlockToRecomputeBlock.TryGetValue(block, recomputeBlock); + auto recomputeInsertBeforeInst = recomputeBlock->getFirstOrdinaryInst(); + for (auto child : block->getChildren()) { - builder.setInsertBefore(firstDiffInst); - + builder.setInsertBefore(recomputeInsertBeforeInst); applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child); } } - - return hoistInfo; } IRType* getTypeForLocalStorage( IRBuilder* builder, IRType* storageType, - List<IndexTrackingInfo*> defBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices) { - for (auto index : defBlockIndices) + for (auto& index : defBlockIndices) { - SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static); - SLANG_ASSERT(index->maxIters >= 0); + SLANG_ASSERT(index.status == IndexTrackingInfo::CountStatus::Static); + SLANG_ASSERT(index.maxIters >= 0); storageType = builder->getArrayType( storageType, builder->getIntValue( builder->getUIntType(), - index->maxIters + 1)); + index.maxIters + 1)); } return storageType; @@ -352,7 +675,7 @@ IRType* getTypeForLocalStorage( IRVar* emitIndexedLocalVar( IRBlock* varBlock, IRType* baseType, - List<IndexTrackingInfo*> defBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices) { SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType)); @@ -370,19 +693,19 @@ IRVar* emitIndexedLocalVar( IRInst* emitIndexedStoreAddressForVar( IRBuilder* builder, IRVar* localVar, - List<IndexTrackingInfo*> defBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices) { IRInst* storeAddr = localVar; IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType(); - for (auto index : defBlockIndices) + for (auto& index : defBlockIndices) { currType = as<IRArrayType>(currType)->getElementType(); storeAddr = builder->emitElementAddress( builder->getPtrType(currType), storeAddr, - index->primalCountParam); + index.primalCountParam); } return storeAddr; @@ -392,8 +715,8 @@ IRInst* emitIndexedStoreAddressForVar( IRInst* emitIndexedLoadAddressForVar( IRBuilder* builder, IRVar* localVar, - List<IndexTrackingInfo*> defBlockIndices, - List<IndexTrackingInfo*> useBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices, + const List<IndexTrackingInfo>& useBlockIndices) { IRInst* loadAddr = localVar; IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType(); @@ -406,7 +729,7 @@ IRInst* emitIndexedLoadAddressForVar( // If the use-block is under the same region, use the // differential counter variable // - auto diffCounterCurrValue = index->diffCountParam; + auto diffCounterCurrValue = index.diffCountParam; loadAddr = builder->emitElementAddress( builder->getPtrType(currType), @@ -418,7 +741,7 @@ IRInst* emitIndexedLoadAddressForVar( // If the use-block is outside this region, use the // last available value (by indexing with primal counter minus 1) // - auto primalCounterCurrValue = builder->emitLoad(index->primalCountLastVar); + auto primalCounterCurrValue = index.primalCountParam; auto primalCounterLastValue = builder->emitSub( primalCounterCurrValue->getDataType(), primalCounterCurrValue, @@ -438,7 +761,7 @@ IRVar* storeIndexedValue( IRBuilder* builder, IRBlock* defaultVarBlock, IRInst* instToStore, - List<IndexTrackingInfo*> defBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices) { IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices); @@ -452,8 +775,8 @@ IRVar* storeIndexedValue( IRInst* loadIndexedValue( IRBuilder* builder, IRVar* localVar, - List<IndexTrackingInfo*> defBlockIndices, - List<IndexTrackingInfo*> useBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices, + const List<IndexTrackingInfo>& useBlockIndices) { IRInst* addr = emitIndexedLoadAddressForVar(builder, localVar, defBlockIndices, useBlockIndices); @@ -461,15 +784,15 @@ IRInst* loadIndexedValue( } bool areIndicesEqual( - List<IndexTrackingInfo*> indicesA, - List<IndexTrackingInfo*> indicesB) + const List<IndexTrackingInfo>& indicesA, + const List<IndexTrackingInfo>& indicesB) { if (indicesA.getCount() != indicesB.getCount()) return false; for (Index ii = 0; ii < indicesA.getCount(); ii++) { - if (indicesA[ii] != indicesB[ii]) + if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam) return false; } @@ -477,31 +800,37 @@ bool areIndicesEqual( } bool areIndicesSubsetOf( - List<IndexTrackingInfo*> indicesA, - List<IndexTrackingInfo*> indicesB) + List<IndexTrackingInfo>& indicesA, + List<IndexTrackingInfo>& indicesB) { if (indicesA.getCount() > indicesB.getCount()) return false; for (Index ii = 0; ii < indicesA.getCount(); ii++) { - if (indicesA[ii] != indicesB[ii]) + if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam) return false; } return true; } - -bool isDifferentialBlock(IRBlock* block) +static int getInstRegionNestLevel( + Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo, + IRBlock* defBlock, + IRInst* inst) { - return block->findDecoration<IRDifferentialInstDecoration>(); + auto result = indexedBlockInfo[defBlock].GetValue().getCount(); + // Loop counters are considered to not belong to the region started by the its loop. + if (result > 0 && inst->findDecoration<IRLoopCounterDecoration>()) + result--; + return (int)result; } RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( HoistedPrimalsInfo* hoistInfo, IRGlobalValueWithCode* func, - Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo) + Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo) { RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); @@ -510,129 +839,369 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock)); - HashSet<IRInst*> processedStoreSet; + OrderedHashSet<IRInst*> processedStoreSet; - // TODO: Also ensure availability of everything in the recompute set (for proper recompute support) - for (auto instToStore : hoistInfo->storeSet) + auto ensureInstAvailable = [&](OrderedHashSet<IRInst*>& instSet) { - IRBlock* defBlock = nullptr; - if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) + for (auto instToStore : instSet) { - auto varInst = as<IRVar>(instToStore); - auto storeUse = findUniqueStoredVal(varInst); + if (!instSet.Contains(instToStore)) + continue; - defBlock = getBlock(storeUse->getUser()); - } - else - defBlock = getBlock(instToStore); + if (hoistInfo->ignoreSet.Contains(instToStore)) + continue; + IRBlock* defBlock = nullptr; + if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) + { + auto varInst = as<IRVar>(instToStore); + auto storeUse = findUniqueStoredVal(varInst); - SLANG_RELEASE_ASSERT(defBlock); + defBlock = getBlock(storeUse->getUser()); + } + else + defBlock = getBlock(instToStore); - List<IRUse*> outOfScopeUses; - for (auto use = instToStore->firstUse; use;) - { - auto nextUse = use->nextUse; - - // Only consider uses in differential blocks. - // This method is not responsible for other blocks. - // - IRBlock* userBlock = getBlock(use->getUser()); - if (userBlock->findDecoration<IRDifferentialInstDecoration>()) + SLANG_RELEASE_ASSERT(defBlock); + + List<IRUse*> outOfScopeUses; + for (auto use = instToStore->firstUse; use;) { - if (!domTree->dominates(defBlock, userBlock)) + auto nextUse = use->nextUse; + + // Only consider uses in differential blocks. + // This method is not responsible for other blocks. + // + IRBlock* userBlock = getBlock(use->getUser()); + if (isDifferentialOrRecomputeBlock(userBlock)) + { + if (!domTree->dominates(defBlock, userBlock)) + { + outOfScopeUses.add(use); + } + else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock])) + { + outOfScopeUses.add(use); + } + else if (getInstRegionNestLevel(indexedBlockInfo, defBlock, instToStore) > 0 && + !isDifferentialOrRecomputeBlock(defBlock)) + { + outOfScopeUses.add(use); + } + else if (as<IRPtrTypeBase>(instToStore->getDataType()) && + !isDifferentialOrRecomputeBlock(defBlock)) + { + outOfScopeUses.add(use); + } + } + + use = nextUse; + } + + if (outOfScopeUses.getCount() == 0) + { + processedStoreSet.Add(instToStore); + continue; + } + + if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) + { + + IRVar* varToStore = as<IRVar>(instToStore); + SLANG_RELEASE_ASSERT(varToStore); + + auto storeUse = findUniqueStoredVal(varToStore); + + List<IndexTrackingInfo>& defBlockIndices = indexedBlockInfo[defBlock]; + + bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0); + + // TODO: There's a slight hackiness here. (Ideally we might just want to emit + // additional vars when splitting a call) + // + if (!isIndexedStore && isDerivativeContextVar(varToStore)) { - outOfScopeUses.add(use); + varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst()); + processedStoreSet.Add(varToStore); + continue; } - else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock])) + + setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser())); + + IRVar* localVar = storeIndexedValue( + &builder, + defaultVarBlock, + builder.emitLoad(varToStore), + defBlockIndices); + + for (auto use : outOfScopeUses) { - outOfScopeUses.add(use); + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); + + List<IndexTrackingInfo>& useBlockIndices = indexedBlockInfo[getBlock(use->getUser())]; + + IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices); + builder.replaceOperand(use, loadAddr); } - else if (indexedBlockInfo[defBlock].GetValue().getCount() > 0 && - !isDifferentialBlock(defBlock)) + + processedStoreSet.Add(localVar); + } + else + { + // Handle the special case of loop counters. + // The only case where there will be a reference of primal loop counter from rev blocks + // is the start of a loop in the reverse code. Since loop counters are not considered a + // part of their loop region, so we remove the first index info. + List<IndexTrackingInfo> defBlockIndices = indexedBlockInfo[defBlock]; + bool isLoopCounter = (instToStore->findDecoration<IRLoopCounterDecoration>() != nullptr); + if (isLoopCounter) { - outOfScopeUses.add(use); + defBlockIndices.removeAt(0); } - else if (as<IRPtrTypeBase>(instToStore->getDataType()) && - !isDifferentialBlock(defBlock)) + + setInsertAfterOrdinaryInst(&builder, instToStore); + auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices); + + for (auto use : outOfScopeUses) { - outOfScopeUses.add(use); + List<IndexTrackingInfo> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())]; + if (isLoopCounter) + { + // The use site of a primal loop counter should be right before we enter the + // loop, and therefore its index count should equal to defBlockIndices.getCount() + // after we remove the first index from defBlockIndices. + SLANG_RELEASE_ASSERT(useBlockIndices.getCount() == defBlockIndices.getCount()); + } + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); + builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices)); } - } - use = nextUse; + processedStoreSet.Add(localVar); + } } + }; - if (outOfScopeUses.getCount() == 0) - { - processedStoreSet.Add(instToStore); - continue; - } + ensureInstAvailable(hoistInfo->storeSet); + + // Replace the old store set with the processed one. + hoistInfo->storeSet = processedStoreSet; - if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) - { + return hoistInfo; +} - IRVar* varToStore = as<IRVar>(instToStore); - SLANG_RELEASE_ASSERT(varToStore); - - auto storeUse = findUniqueStoredVal(varToStore); - - List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock]; - bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0); +void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info) +{ + if (info->status != IndexTrackingInfo::CountStatus::Unresolved) + return; - // TODO: There's a slight hackiness here. (Ideally we might just want to emit - // additional vars when splitting a call) - // - if (!isIndexedStore && isDerivativeContextVar(varToStore)) - { - varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst()); - processedStoreSet.Add(varToStore); - continue; - } + auto loop = as<IRLoop>(region->getInitializerBlock()->getTerminator()); + + if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>()) + { + info->maxIters = (Count)maxItersDecoration->getMaxIters(); + info->status = IndexTrackingInfo::CountStatus::Static; + } +} - setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser())); - IRVar* localVar = storeIndexedValue( - &builder, - defaultVarBlock, - builder.emitLoad(varToStore), - defBlockIndices); +IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type) +{ + builder->setInsertInto(block); + return builder->emitParam(type); +} - for (auto use : outOfScopeUses) - { - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); - - List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())]; +IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type, UIndex index) +{ + List<IRParam*> params; + for (auto param : block->getParams()) + params.add(param); - IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices); - builder.replaceOperand(use, loadAddr); - } + SLANG_RELEASE_ASSERT(index == (UCount)params.getCount()); - processedStoreSet.Add(localVar); - } - else - { - setInsertAfterOrdinaryInst(&builder, instToStore); + return addPhiInputParam(builder, block, type); +} - List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock]; - auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices); - - for (auto use : outOfScopeUses) - { - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); +static IRBlock* getUpdateBlock(IRLoop* loop) +{ + auto initBlock = cast<IRBlock>(loop->getParent()); - List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())]; - builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices)); - } + auto condBlock = loop->getTargetBlock(); + + IRBlock* lastLoopBlock = nullptr; + + for (auto predecessor : condBlock->getPredecessors()) + { + if (predecessor != initBlock) + lastLoopBlock = predecessor; + } + + // Should find atleast one predecessor that is _not_ the + // init block (that contains the loop info). This + // predecessor would be the last block in the loop + // before looping back to the condition. + // + SLANG_RELEASE_ASSERT(lastLoopBlock); + + return lastLoopBlock; +} + +void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam) +{ + IRBuilder builder(primalLoop); + primalCountParam = nullptr; + + // Grab first primal block. + IRBlock* primalInitBlock = as<IRBlock>(primalLoop->getParent()); + builder.setInsertBefore(primalInitBlock->getTerminator()); + { + auto primalCondBlock = as<IRUnconditionalBranch>( + primalInitBlock->getTerminator())->getTargetBlock(); + builder.setInsertBefore(primalInitBlock->getTerminator()); + + auto phiCounterArgLoopEntryIndex = addPhiOutputArg( + &builder, + primalInitBlock, + *(IRInst**)&primalLoop, + builder.getIntValue(builder.getIntType(), 0)); + + builder.setInsertBefore(primalCondBlock->getTerminator()); + primalCountParam = addPhiInputParam( + &builder, + primalCondBlock, + builder.getIntType(), + phiCounterArgLoopEntryIndex); + builder.addLoopCounterDecoration(primalCountParam); + builder.addNameHintDecoration(primalCountParam, UnownedStringSlice("_pc")); + builder.markInstAsPrimal(primalCountParam); + + IRBlock* primalUpdateBlock = getUpdateBlock(primalLoop); + IRInst* terminator = primalUpdateBlock->getTerminator(); + builder.setInsertBefore(primalUpdateBlock->getTerminator()); + + auto incCounterVal = builder.emitAdd( + builder.getIntType(), + primalCountParam, + builder.getIntValue(builder.getIntType(), 1)); + builder.markInstAsPrimal(incCounterVal); + + auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, primalUpdateBlock, terminator, incCounterVal); + + SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); + } + + { + IRBlock* diffInitBlock = as<IRBlock>(diffLoop->getParent()); + + auto diffCondBlock = as<IRUnconditionalBranch>( + diffInitBlock->getTerminator())->getTargetBlock(); + builder.setInsertBefore(diffInitBlock->getTerminator()); + auto revCounterInitVal = builder.emitSub( + builder.getIntType(), + primalCountParam, + builder.getIntValue(builder.getIntType(), 1)); + auto phiCounterArgLoopEntryIndex = addPhiOutputArg( + &builder, + diffInitBlock, + *(IRInst**)&diffLoop, + revCounterInitVal); + + builder.setInsertBefore(diffCondBlock->getTerminator()); + + diffCountParam = addPhiInputParam( + &builder, + diffCondBlock, + builder.getIntType(), + phiCounterArgLoopEntryIndex); + builder.addNameHintDecoration(diffCountParam, UnownedStringSlice("_dc")); + builder.markInstAsPrimal(diffCountParam); + + IRBlock* diffUpdateBlock = getUpdateBlock(diffLoop); + builder.setInsertBefore(diffUpdateBlock->getTerminator()); + IRInst* terminator = diffUpdateBlock->getTerminator(); + + auto decCounterVal = builder.emitSub( + builder.getIntType(), + diffCountParam, + builder.getIntValue(builder.getIntType(), 1)); + builder.markInstAsPrimal(decCounterVal); + + auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, diffUpdateBlock, terminator, decCounterVal); + + auto ifElse = as<IRIfElse>(diffCondBlock->getTerminator()); + builder.setInsertBefore(ifElse); + auto exitCondition = builder.emitGeq(diffCountParam, builder.getIntValue(builder.getIntType(), 0)); + ifElse->condition.set(exitCondition); + + SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); + } +} + +void buildIndexedBlocks( + Dictionary<IRBlock*, List<IndexTrackingInfo>>& info, + IRGlobalValueWithCode* func) +{ + Dictionary<IRLoop*, IndexTrackingInfo> mapLoopToTrackingInfo; - processedStoreSet.Add(localVar); + for (auto block : func->getBlocks()) + { + auto loop = as<IRLoop>(block->getTerminator()); + if (!loop) continue; + auto diffDecor = loop->findDecoration<IRDifferentialInstDecoration>(); + if (!diffDecor) continue; + auto primalLoop = as<IRLoop>(diffDecor->getPrimalInst()); + if (!primalLoop) continue; + + IndexTrackingInfo indexInfo = {}; + lowerIndexedRegion(primalLoop, loop, indexInfo.primalCountParam, indexInfo.diffCountParam); + + SLANG_RELEASE_ASSERT(indexInfo.primalCountParam); + SLANG_RELEASE_ASSERT(indexInfo.diffCountParam); + + mapLoopToTrackingInfo[loop] = indexInfo; + mapLoopToTrackingInfo[primalLoop] = indexInfo; + } + + auto regionMap = buildIndexedRegionMap(func); + + for (auto block : func->getBlocks()) + { + List<IndexTrackingInfo> trackingInfos; + for (auto region : regionMap->getAllAncestorRegions(block)) + { + IndexTrackingInfo trackingInfo; + if (mapLoopToTrackingInfo.TryGetValue(region->loop, trackingInfo)) + { + tryInferMaxIndex(region, &trackingInfo); + trackingInfos.add(trackingInfo); + } } + info[block] = trackingInfos; } - - // Replace the old store set with the processed onne one. - hoistInfo->storeSet = processedStoreSet; +} - return hoistInfo; +RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy( + IRGlobalValueWithCode* func, const List<IRInst*>& instsToIgnore) +{ + sortBlocksInFunc(func); + + Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo; + buildIndexedBlocks(indexedBlockInfo, func); + + auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo); + + sortBlocksInFunc(func); + + RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule()); + chkPolicy->preparePolicy(func); + + auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap); + + for (auto propagateFuncSpecificInst : instsToIgnore) + { + primalsInfo->ignoreSet.add(propagateFuncSpecificInst); + } + primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); + return primalsInfo; } void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) @@ -716,7 +1285,7 @@ static bool shouldStoreVar(IRVar* var) { for (UInt i = 0; i < spec->getArgCount(); i++) { - if (!canTypeBeStored(spec->getArg(i)->getDataType())) + if (!canTypeBeStored(spec->getArg(i))) return false; } } @@ -772,6 +1341,30 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_ExtractExistentialWitnessTable: case kIROp_undefined: case kIROp_GetSequentialID: + case kIROp_Specialize: + case kIROp_LookupWitness: +#if 0 + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Neg: + case kIROp_Geq: + case kIROp_Leq: + case kIROp_Neq: + case kIROp_Eql: + case kIROp_Greater: + case kIROp_Less: + case kIROp_And: + case kIROp_Or: + case kIROp_Not: + case kIROp_BitNot: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: +#endif return false; case kIROp_GetElement: case kIROp_FieldExtract: @@ -791,6 +1384,9 @@ static bool shouldStoreInst(IRInst* inst) break; } + if (as<IRType>(inst)) + return false; + // Only store if the inst has differential inst user. bool hasDiffUser = doesInstHaveDiffUse(inst); if (!hasDiffUser) @@ -801,22 +1397,11 @@ static bool shouldStoreInst(IRInst* inst) bool canRecompute(IRDominatorTree* domTree, IRUse* use) { + SLANG_UNUSED(domTree); auto param = as<IRParam>(use->get()); if (!param) return true; - auto paramBlock = as<IRBlock>(param->getParent()); - for (auto predecessor : paramBlock->getPredecessors()) - { - // If we hit this, the checkpoint policy is trying to recompute - // values across a loop region boundary (we don't currently support this, - // and in general this is quite inefficient in both compute & memory) - // - if (domTree->dominates(paramBlock, predecessor)) - { - return false; - } - } - return true; + return false; } HoistResult DefaultCheckpointPolicy::classify(IRUse* use) diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index bd2575172..3b3fb82b1 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -7,7 +7,6 @@ #include "slang-ir-autodiff-region.h" #include "slang-ir-dominators.h" - namespace Slang { struct IROutOfOrderCloneContext : public RefObject @@ -84,11 +83,11 @@ namespace Slang struct HoistedPrimalsInfo : public RefObject { - HashSet<IRInst*> storeSet; - HashSet<IRInst*> recomputeSet; - HashSet<IRInst*> invertSet; - - HashSet<IRInst*> instsToInvert; + OrderedHashSet<IRInst*> storeSet; + OrderedHashSet<IRInst*> recomputeSet; + OrderedHashSet<IRInst*> invertSet; + OrderedHashSet<IRInst*> ignoreSet; + OrderedHashSet<IRInst*> instsToInvert; Dictionary<IRInst*, InversionInfo> invertInfoMap; @@ -130,6 +129,9 @@ namespace Slang for (auto inst : info->invertSet) invertSet.Add(inst); + for (auto inst : info->ignoreSet) + ignoreSet.add(inst); + for (auto inst : info->instsToInvert) instsToInvert.Add(inst); @@ -195,6 +197,31 @@ namespace Slang } }; + struct IndexTrackingInfo : public RefObject + { + // After lowering, store references to the count + // variables associated with this region + // + IRInst* primalCountParam = nullptr; + IRInst* diffCountParam = nullptr; + + enum CountStatus + { + Unresolved, + Dynamic, + Static + }; + + CountStatus status = CountStatus::Unresolved; + + // Inferred maximum number of iterations. + Count maxIters = -1; + + bool operator==(const IndexTrackingInfo& other) const + { + return primalCountParam == other.primalCountParam; + } + }; // Information on which insts are to be stored, recomputed // and inverted within a single function. @@ -210,6 +237,15 @@ namespace Slang Dictionary<IRInst*, InversionInfo> invInfoMap; }; + // Information on a block after it has been split in the unzip step. + // After unzipping, every block in the original function will have + // two corresponding blocks in the new function: + // - A 'primal-recompute' block, which contains the original instructions + // from the original block, but located in the corresponding the reverse + // diff region so their results are accessible in the diff block for + // derivative computation. + // - A 'diff' block, which contains the transcribed instructions from the + // original block. struct BlockSplitInfo : public RefObject { // Maps primal to differential blocks from the unzip step. @@ -223,7 +259,9 @@ namespace Slang AutodiffCheckpointPolicyBase(IRModule* module) : module(module) { } - RefPtr<HoistedPrimalsInfo> processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* info); + RefPtr<HoistedPrimalsInfo> processFunc( + IRGlobalValueWithCode* func, + Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock); // Do pre-processing on the function (mainly for // 'global' checkpointing methods that consider the entire @@ -252,15 +290,9 @@ namespace Slang RefPtr<IRDominatorTree> domTree; }; - RefPtr<HoistedPrimalsInfo> applyCheckpointSet( - CheckpointSetInfo* checkpointInfo, + RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy( IRGlobalValueWithCode* func, - BlockSplitInfo* splitInfo, - HashSet<IRUse*> pendingUses); + const List<IRInst*>& instsToIgnore); - RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( - HoistedPrimalsInfo* hoistInfo, - IRGlobalValueWithCode* func, - Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo); }; diff --git a/source/slang/slang-ir-autodiff-region.h b/source/slang/slang-ir-autodiff-region.h index a4618e257..59a977619 100644 --- a/source/slang/slang-ir-autodiff-region.h +++ b/source/slang/slang-ir-autodiff-region.h @@ -50,29 +50,6 @@ struct IndexedRegion : public RefObject } }; -struct IndexTrackingInfo : public RefObject -{ - // After lowering, store references to the count - // variables associated with this region - // - IRInst* primalCountParam = nullptr; - IRInst* diffCountParam = nullptr; - - IRVar* primalCountLastVar = nullptr; - - enum CountStatus - { - Unresolved, - Dynamic, - Static - }; - - CountStatus status = CountStatus::Unresolved; - - // Inferred maximum number of iterations. - Count maxIters = -1; -}; - struct IndexedRegionMap : public RefObject { Dictionary<IRBlock*, IndexedRegion*> map; @@ -116,4 +93,4 @@ struct IndexedRegionMap : public RefObject RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func); -};
\ No newline at end of file +}; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 0bdc4a935..979eb6343 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -679,10 +679,7 @@ namespace Slang // // diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc); - // Copy primal insts to the first block of the unzipped function, copy diff insts to the - // second block of the unzipped function. - // - RefPtr<HoistedPrimalsInfo> primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + diffUnzipPass->unzipDiffInsts(fwdDiffFunc); IRFunc* unzippedFwdDiffFunc = fwdDiffFunc; // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell. @@ -709,10 +706,17 @@ namespace Slang // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the // derivative of the return value. - DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo }; + DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam }; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo); + // Apply checkpointing policy to legalize cross-scope uses of primal values + // using either recompute or store strategies. + auto primalsInfo = applyCheckpointPolicy( + diffPropagateFunc, paramTransposeInfo.propagateFuncSpecificPrimalInsts); + + eliminateDeadCode(diffPropagateFunc); + // Extracts the primal computations into its own func, and replace the primal insts // with the intermediate results computed from the extracted func. @@ -907,6 +911,7 @@ namespace Slang // after transposition. auto tempVar = nextBlockBuilder.emitVar(diffType); copyNameHintDecoration(tempVar, fwdParam); + result.propagateFuncSpecificPrimalInsts.add(tempVar); // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. @@ -993,9 +998,11 @@ namespace Slang // of the differential component of the pair. auto newParamLoad = diffBuilder.emitLoad(propParam); diffBuilder.markInstAsDifferential(newParamLoad, primalType); + result.propagateFuncSpecificPrimalInsts.add(newParamLoad); diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad); diffBuilder.markInstAsDifferential(diffRefReplacement, primalType); + result.propagateFuncSpecificPrimalInsts.add(diffRefReplacement); // Load the primal component from the prop param and use it as replacement for the // primal param in the primal part of the prop func. @@ -1031,7 +1038,10 @@ namespace Slang // Load the inital diff value. auto loadedParam = nextBlockBuilder.emitLoad(diffParam); + result.propagateFuncSpecificPrimalInsts.add(loadedParam); + auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam); + result.propagateFuncSpecificPrimalInsts.add(initDiff); // Create a local var for diff read access. auto diffVar = nextBlockBuilder.emitVar(diffType); @@ -1047,6 +1057,7 @@ namespace Slang // Create a local var for diff write access. auto diffWriteVar = nextBlockBuilder.emitVar(diffType); + result.propagateFuncSpecificPrimalInsts.add(diffWriteVar); copyNameHintDecoration(diffWriteVar, fwdParam); // Initialize write var to 0. diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 8c005a5c6..c7ac8c357 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -80,11 +80,6 @@ struct DiffTransposePass // of the *output* of the function. // IRInst* dOutInst; - - // Information from the unzip pass on how primal insts - // are split across the primal and differential blocks. - // - HoistedPrimalsInfo* hoistedPrimalsInfo; }; struct PendingBlockTerminatorEntry @@ -235,16 +230,12 @@ struct DiffTransposePass builder.setInsertInto(revCondBlock); - //hoistPrimalInst(&builder, ifElse->getCondition()); - - auto newIfElse = builder.emitIfElse( + builder.emitIfElse( ifElse->getCondition(), revTrueEntryBlock, revFalseEntryBlock, revAfterBlock); - hoistPrimalOperands(&builder, newIfElse); - if (!revTrueRegionInfo.isTrivial) { builder.setInsertInto(revTrueExitBlock); @@ -358,21 +349,21 @@ struct DiffTransposePass // Emit condition into the new cond block. builder.setInsertInto(revCondBlock); - // TODO: Need to defer this until after the CFG reversal is complete. - //hoistPrimalInst(&builder, ifElse->getCondition()); - - auto newIfElse = builder.emitIfElse( + builder.emitIfElse( ifElse->getCondition(), revTrueBlock, revFalseBlock, revTrueBlock); - - hoistPrimalOperands(&builder, newIfElse); - + + auto loopParentBlockDiffDecor = loop->getParent()->findDecoration<IRDifferentialInstDecoration>(); + SLANG_RELEASE_ASSERT(loopParentBlockDiffDecor); + auto primalBlock = as<IRBlock>(loopParentBlockDiffDecor->getPrimalInst()); + auto primalLoop = as<IRLoop>(primalBlock->getTerminator()); + SLANG_RELEASE_ASSERT(primalLoop); + // Old false-side starting block becomes end block // for the new pre-cond region (which could be empty) // - if (!falseRegionInfo.isTrivial) { IRBlock* revPreCondEndBlock = revBlockMap[falseBlock]; @@ -384,7 +375,8 @@ struct DiffTransposePass getPhiGrads(falseBlock).getCount(), getPhiGrads(falseBlock).getBuffer()); loop->transferDecorationsTo(revLoop); - + builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop); + auto revLoopStartBlock = revBlockMap[breakBlock]; builder.setInsertInto(revLoopStartBlock); builder.emitBranch( @@ -404,6 +396,7 @@ struct DiffTransposePass getPhiGrads(breakBlock).getCount(), getPhiGrads(breakBlock).getBuffer()); loop->transferDecorationsTo(revLoop); + builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop); } currentBlock = breakBlock; @@ -478,17 +471,13 @@ struct DiffTransposePass builder.setInsertInto(revSwitchBlock); - // hoistPrimalInst(&builder, switchInst->getCondition()); - - auto newSwitchInst = builder.emitSwitch( + builder.emitSwitch( switchInst->getCondition(), revBreakBlock, revDefaultRegionEntry, reverseSwitchArgs.getCount(), reverseSwitchArgs.getBuffer()); - hoistPrimalOperands(&builder, newSwitchInst); - currentBlock = breakBlock; break; } @@ -525,9 +514,7 @@ struct DiffTransposePass // (i.e. not store per-func info in 'this') // since it is reused for every reverse-mode call. // - - hoistedPrimalsInfo = transposeInfo.hoistedPrimalsInfo; - + primalVarsToHoist.clear(); // Grab all differentiable type information. diffTypeContext.setFunc(revDiffFunc); @@ -576,8 +563,10 @@ struct DiffTransposePass // Emit empty rev-mode blocks for every fwd-mode block. for (auto block : workList) { - revBlockMap[block] = builder.emitBlock(); - builder.markInstAsDifferential(revBlockMap[block]); + auto revBlock = builder.emitBlock(); + revBlockMap[block] = revBlock; + if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>()) + builder.markInstAsDifferential(revBlockMap[block], builder.getBasicBlockType(), diffDecor->getPrimalInst()); } // Keep track of first diff block, since this is where @@ -637,20 +626,6 @@ struct DiffTransposePass auto firstFwdDiffBlock = branchInst->getTargetBlock(); reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>()); - // Lower any loop-exit-value decorations into initializations for loop intermediate vals, - // and convert loop initial values into terminating conditions. - // - // TODO: We need a way to confirm that all required vars have an initial value - // (is there a built-in dataflow tool for this?) - // - for (auto block : workList) - { - if (auto loopInst = as<IRLoop>(block->getTerminator())) - { - invertLoopCondition(&builder, loopInst); - } - } - // Link the last differential fwd-mode block (which will be the first // rev-mode block) as the successor to the last primal block. // We assume that the original function is in single-return form @@ -688,43 +663,9 @@ struct DiffTransposePass for (auto block : workList) block->removeFromParent(); - // Mark all primal operands for hoisting. - // TODO: Can we just merge this with finishHoistingPrimalInsts? - // TODO: Some of this logic is replicated in finishHoistingPrimalInsts. Merge it with the - // maybeAddOperandsToWorkList logic there. - // - for (auto block : workList) - { - IRBlock* revBlock = revBlockMap[block]; - - for (auto child = revBlock->getFirstChild(); child; child = child->getNextInst()) - { - hoistPrimalOperands(&builder, child); - - for (auto decoration = child->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration()) - { - if (auto contextDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration)) - hoistPrimalUse(&builder, &contextDecoration->primalContextVar); - - if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration)) - hoistPrimalUse(&builder, &loopExitDecoration->exitVal); - } - - if (auto instType = child->getDataType()) - if (!as<IRModuleInst>(instType->getParent())) - hoistPrimalUse(&builder, &child->typeUse); - } - } - finishHoistingPrimals(revDiffFunc); - for (auto block : workList) - { - auto revBlock = as<IRBlock>(revBlockMap[block]); - if (auto revLoop = as<IRLoop>(revBlock->getTerminator())) - lowerLoopExitValues(&builder, revLoop); - } - + // At this point, the only block left without terminator insts // should be the last one. Add a void return to complete it. // @@ -793,51 +734,6 @@ struct DiffTransposePass return tempRevVar; } - IRVar* lookupInverseVar(IRInst* inst) - { - return inverseVarMap[inst]; - } - - IRVar* getOrCreateInverseVar(IRInst* primalInst, IRGlobalValueWithCode* func) - { - IRBlock* varBlock = firstRevDiffBlockMap[func]; - return getOrCreateInverseVar(primalInst, varBlock); - } - - IRVar* getOrCreateInverseVar(IRInst* primalInst) - { - IRBlock* varBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())]; - return getOrCreateInverseVar(primalInst, varBlock); - } - - IRVar* getOrCreateInverseVar(IRInst* primalInst, IRBlock* varBlock) - { - // No need to store inverse values for constants. - if (as<IRConstant>(primalInst)) - return nullptr; - - // Check if we have a var already. - if (inverseVarMap.ContainsKey(primalInst)) - return inverseVarMap[primalInst]; - - IRBuilder tempVarBuilder(autodiffContext->moduleInst); - - if (auto firstInst = varBlock->getFirstOrdinaryInst()) - tempVarBuilder.setInsertBefore(firstInst); - else - tempVarBuilder.setInsertInto(varBlock); - - auto primalType = primalInst->getDataType(); - - // Emit a var in the top-level differential block to hold the inverse, - // and initialize it. - auto tempInvVar = tempVarBuilder.emitVar(primalType); - - inverseVarMap[primalInst] = tempInvVar; - - return tempInvVar; - } - bool isInstUsedOutsideParentBlock(IRInst* inst) { auto currBlock = inst->getParent(); @@ -900,37 +796,9 @@ struct DiffTransposePass revParam, nullptr)); } - else if (hasInverse(arg)) - { - InversionInfo invInfo = this->hoistedPrimalsInfo->invertInfoMap[branchInst]; - if (invInfo.targetInsts.contains(arg)) - { - SLANG_ASSERT(hasInverse(getParamAt(branchInst->getTargetBlock(), ii))); - - // If the output arg is a primal, emit a parameter - // to accept it as an _input_ for the reverse-mode - // - auto primalType = arg->getDataType(); - auto primalInvParam = builder.emitParam(primalType); - - invBuilder.setInsertBefore(branchInst); - setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam); - } - } else { - if (hasInverse(getParamAt(branchInst->getTargetBlock(), ii))) - { - auto primalType = arg->getDataType(); - auto primalInvParam = builder.emitParam(primalType); - - invBuilder.setInsertBefore(branchInst); - setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam); - } - else - { - SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion"); - } + SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion"); } } } @@ -989,15 +857,6 @@ struct DiffTransposePass if (isDifferentialInst(child)) transposeInst(&builder, child); - else if (shouldInstBeInverted(child)) - { - // We'll collect inverse insts in an orphaned block, - // so disable IR validation temporarily. - // - disableIRValidationAtInsert(); - invertInst(&invBuilder, child); - enableIRValidationAtInsert(); - } } // After processing the block's instructions, we 'flush' any remaining gradients @@ -1046,10 +905,6 @@ struct DiffTransposePass emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param))); } } - else if (hasInverse(param)) - { - phiParamRevGradInsts.add(param); - } else { SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion"); @@ -1169,46 +1024,6 @@ struct DiffTransposePass } } - // NOTE: This is a workaround for the fact that we expect inverses to use - // single-use variables. The loop exit value will add a - // second store to most inv-variables and mess with the primal hoisting mechanism. - // Instead of emitting into the orphaned inverse block, we'll directly emit into - // the reverse-mode block since we'll be running this _after_ the primal hoisting - // pass. - // - // This workaround is fine for inverting loop counters, but when we want to - // expand to supporting general-purpose adjoints, we would want to use per-region - // inverse vars based on 'invInfo' (enforcing single-use vars) - // - void lowerLoopExitValues(IRBuilder* builder, IRLoop* revLoop) - { - List<IRDecoration*> processedDecorations; - for (auto decoration : revLoop->getDecorations()) - { - if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration)) - { - builder->setInsertBefore(revLoop); - setInverse( - builder, - nullptr, - builder->getFunc(), - loopExitValueDecoration->getTargetInst(), - loopExitValueDecoration->getLoopExitValInst()); - - processedDecorations.add(loopExitValueDecoration); - } - } - - for (auto decoration : processedDecorations) - decoration->removeAndDeallocate(); - } - - void lowerLoopExitValues(IRBuilder* builder, IRBlock* block) - { - if (auto loopInst = as<IRLoop>(block->getTerminator())) - lowerLoopExitValues(builder, loopInst); - } - // Go through loop block phi-args, and look for loop counter // arguments, which for a loop means inserting a check into // loop condition block. @@ -1253,41 +1068,9 @@ struct DiffTransposePass loopCounterParam, loopCounterInitVal).getBuffer()); - hoistPrimalOperands(builder, paramBoundsCheck); - as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck); } - List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo) - { - switch (primalInst->getOp()) - { - case kIROp_Add: - case kIROp_Sub: - return invertArithmetic(builder, primalInst, invInfo); - - default: - SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion"); - } - } - - bool hasInverse(IRInst* primalInst) - { - return this->hoistedPrimalsInfo->invertSet.Contains(primalInst); - } - - IRInst* loadInverse(IRBuilder* builder, IRInst* primalInst) - { - // Note: There are other possible cases here, although not important - // right now. For example, a value is available to load from the primal block. - // - - if (auto invVar = getOrCreateInverseVar(primalInst, builder->getFunc())) - return builder->emitLoad(invVar); - - return nullptr; - } - IRInst* lookupInstInPrimalBlock(IRInst* invInst) { // Lookup the inst in the primal block whose value we can use as an operand @@ -1296,37 +1079,7 @@ struct DiffTransposePass // auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst]; return invInst; } - - void setInverse(IRBuilder* builder, IRBlock* defBlock, IRGlobalValueWithCode* func, IRInst* inst, IRInst* invInst) - { - auto instBlock = as<IRBlock>(inst->getParent()); - if (!instBlock) - return; - - disableIRValidationAtInsert(); - if (auto invVar = getOrCreateInverseVar(inst, func)) - { - auto invStore = builder->emitStore(invVar, invInst); - mapStoreToDefBlock[as<IRStore>(invStore)] = defBlock; - } - enableIRValidationAtInsert(); - } - - bool shouldInstBeInverted(IRInst* inst) - { - - if (this->hoistedPrimalsInfo->instsToInvert.Contains(inst)) - return true; - - return false; - } - - IRInst* hoistPrimalUse(IRBuilder*, IRUse* use) - { - primalUsesToHoist.add(use); - return use->get(); - } - + bool doesInstRequireHoisting(IRInst* inst) { if (as<IRModuleInst>(inst->getParent())) @@ -1336,10 +1089,10 @@ struct DiffTransposePass as<IRGlobalValueWithCode>(inst) || as<IRConstant>(inst)) return false; - + if (as<IRTerminatorInst>(inst)) return false; - + if (as<IRDecoration>(inst)) return doesInstRequireHoisting(getInstInBlock(inst)); @@ -1347,30 +1100,9 @@ struct DiffTransposePass // that have not yet been moved to the 'active' blocks // (i.e in diff blocks that do not have parents) // - return (!isDifferentialInst(inst) && - (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) && - getBlock(inst)->getParent() == nullptr); - } - - // Builds a map from inst to a list of uses by primal _inverted_ insts. - Dictionary<IRInst*, List<IRInst*>> buildInvOperandMap() - { - Dictionary<IRInst*, List<IRInst*>> invOperandMap; - for (auto kvpair : this->hoistedPrimalsInfo->invertInfoMap) - { - InversionInfo invInfo = kvpair.Value; - - for (auto operand : invInfo.requiredOperands) - { - if (!invOperandMap.ContainsKey(operand)) - invOperandMap[operand] = List<IRInst*>(); - - for (auto target : invInfo.targetInsts) - invOperandMap[operand].GetValue().add(target); - } - } - - return invOperandMap; + return (!isDifferentialInst(inst) && + (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) && + getBlock(inst)->getParent() == nullptr); } IRBlock* walkToEndOfRegion(IRBlock* block) @@ -1435,186 +1167,13 @@ struct DiffTransposePass void finishHoistingPrimals(IRGlobalValueWithCode* func) { - List<IRInst*> workList; - - Dictionary<IRInst*, IRInst*> hoistedInstMap; - - RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); - - Dictionary<IRInst*, List<IRInst*>> invOperandMap = buildInvOperandMap(); - auto varBlock = func->getFirstBlock()->getNextBlock(); - - // Load up pending insts into workList. - for (auto use : primalUsesToHoist) - workList.add(use->get()); - - primalUsesToHoist.clear(); - - auto maybeAddPrimalOperandsToWorkList = [&](IRInst* inst) + for (auto inst : primalVarsToHoist) { - UIndex opIndex = 0; - for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++) - { - if (doesInstRequireHoisting(operand->get()) && - !hoistedInstMap.ContainsKey(operand->get())) - { - workList.add(operand->get()); - } - } - - if (auto instType = inst->getDataType()) - { - if (doesInstRequireHoisting(instType) && - !hoistedInstMap.ContainsKey(instType)) - workList.add(instType); - } - }; - - auto maybeAddUsersToWorkList = [&](IRInst* inst) - { - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (doesInstRequireHoisting(use->getUser())) - { - if (as<IRVar>(inst) && as<IRStore>(use->getUser())) - continue; - - // Uses that haven't already been hoisted into reverse-mode - // blocks, and are not in the invert-set are pending uses. - // - if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser())) - workList.add(use->getUser()); - } - } - }; - - auto doesInstHavePendingUsers = [&](IRInst* inst) - { - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (doesInstRequireHoisting(use->getUser())) - { - if (as<IRVar>(inst) && as<IRStore>(use->getUser())) - continue; - - // Users that haven't already been hoisted into reverse-mode - // blocks are pending users. - // - if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser())) - return true; - } - } - - return false; - }; - - auto isInstHoisted = [&](IRInst* inst) - { - return getBlock(inst)->getParent() != nullptr && isDifferentialInst(getBlock(inst)); - }; - - while (workList.getCount() > 0) - { - // Pop work item - auto inst = workList.getLast(); - workList.removeLast(); - - // Already hoisted to reverse-mode block. - // replace with mapped inst (in case it's different) - // and continue on.. (this should actually never be hit) - // - if (hoistedInstMap.ContainsKey(inst)) - continue; - - if (invOperandMap.ContainsKey(inst)) - { - List<IRInst*> pendingInvDependencies; - for (auto dependency : invOperandMap[inst].GetValue()) - { - if (doesInstRequireHoisting(dependency) && - !hoistedInstMap.ContainsKey(dependency)) - pendingInvDependencies.add(dependency); - } - - if (pendingInvDependencies.getCount() > 0) - { - workList.add(inst); - for (auto dependency : pendingInvDependencies) - workList.add(dependency); - - // Skip until all the dependencies have been handled. - continue; - } - } - - // Are the uses of this primal inst already hoisted into the reverse-mode - // blocks? We cannot hoist this inst unless the uses are hoisted. - // - if (doesInstHavePendingUsers(inst)) - { - // Add inst back to work list. - workList.add(inst); - - // Then, add all the pending use to the top of - // list, ensuring they are processed before we see - // inst again. - // - maybeAddUsersToWorkList(inst); - - continue; - } - - // The used inst is marked for inversion, lookup and load - // an inverse. - // - if (this->hoistedPrimalsInfo->invertSet.Contains(inst)) - { - // Replace with inverse. - IRBuilder builder(func->getModule()); - - for (auto use = inst->firstUse; use;) - { - auto nextUse = use->nextUse; - - if (!isInstHoisted(use->getUser())) - { - use = nextUse; - continue; - } - - // TODO: Hacky workaround to prevent the 'key' being overwritten, - // avoid this by adding the decoration on the param instead of the loop - // - if (auto exitValDecoration = as<IRLoopExitPrimalValueDecoration>(use->getUser())) - { - if (&exitValDecoration->target == use) - { - use = nextUse; - continue; - } - } - - - builder.setInsertBefore(getInstInBlock(use->getUser())); - use->set(loadInverse(&builder, inst)); - - use = nextUse; - } - - // If all uses of the invertible inst have been hoisted, - // add the inv-var to the worklist. - // - workList.add(lookupInverseVar(inst)); - hoistedInstMap[inst] = nullptr; - + if (!doesInstRequireHoisting(inst)) continue; - } - - // Should not see an inst marked for inversion here. - SLANG_RELEASE_ASSERT(!this->hoistedPrimalsInfo->invertSet.Contains(inst)); - + List<IRUse*> relevantUses; IRBlock* defBlock = nullptr; @@ -1641,7 +1200,7 @@ struct DiffTransposePass if (!doesInstRequireHoisting(inst)) continue; - + // Move this inst to after it's diff uses. // { @@ -1662,62 +1221,9 @@ struct DiffTransposePass inst->insertBefore(currTopBlock->getFirstOrdinaryInst()); enableIRValidationAtInsert(); } - - // Finish up.. - hoistedInstMap[inst] = inst; - maybeAddPrimalOperandsToWorkList(inst); - } - } - - void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst) - { - for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++) - { - // For now we'll only hoist primal operands that are - // generated in differential blocks. - // Eventually, we also want this method to move primal access - // insts to the reverse-mode blocks (i.e. this method will - // make sure all requried primal insts are moved to the right - // place) - // - if (doesInstRequireHoisting(fwdInst->getOperand(ii))) - { - hoistPrimalUse(revBuilder, &fwdInst->getOperands()[ii]); - } } } - void invertInst(IRBuilder* builder, IRInst* primalInst) - { - // Look for an available inverse entry for this primalInst's *output* - if (shouldInstBeInverted(primalInst)) - { - // This logic is already handled in transposeBlock() so we skip - // it here. - // - if (as<IRTerminatorInst>(primalInst)) - return; - - auto invInfo = this->hoistedPrimalsInfo->invertInfoMap[primalInst]; - - IRBuilder invBuilder(builder->getModule()); - invBuilder.setInsertAfter(primalInst); - - auto invEntries = invertInst(&invBuilder, primalInst, invInfo); - - for (auto entry : invEntries) - setInverse( - &invBuilder, - getBlock(primalInst), - as<IRGlobalValueWithCode>(entry.inst->getParent()->getParent()), - entry.inst, - entry.invInst); - } - else - { - SLANG_UNEXPECTED("Could not find value for the output of inst. Unable to invert"); - } - } void transposeInst(IRBuilder* builder, IRInst* inst) { @@ -1880,7 +1386,8 @@ struct DiffTransposePass auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType(); auto tempVar = builder->emitVar(pairType); auto primalVal = builder->emitLoad(instPair->getPrimal()); - hoistPrimalOperands(builder, primalVal); // TODO(sai): Do we need to hoist other insts here? + auto primalVar = instPair->getPrimal(); + primalVarsToHoist.add(primalVar); auto diffVal = builder->emitLoad(instPair->getDiff()); auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); @@ -1961,7 +1468,6 @@ struct DiffTransposePass auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar(); auto contextLoad = builder->emitLoad(primalContextVar); - hoistPrimalOperands(builder, contextLoad); args.add(contextLoad); argTypes.add(as<IRPtrTypeBase>( @@ -3477,7 +2983,7 @@ struct DiffTransposePass DifferentialPairTypeBuilder pairBuilder; - HoistedPrimalsInfo* hoistedPrimalsInfo; + List<IRInst*> primalVarsToHoist; IRBlock* tempInvBlock; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index c3ce32540..44e981404 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -249,7 +249,7 @@ struct ExtractPrimalFuncContext List<IRBlock*> unusedBlocks; for (auto block : func->getBlocks()) { - if (isDiffInst(block)) + if (isDiffInst(block) || block->findDecoration<IRRecomputeBlockDecoration>()) unusedBlocks.add(block); } for (auto block : unusedBlocks) @@ -317,8 +317,11 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( // Remove propagate func specific primal insts from cloned func. for (auto inst : paramInfo.propagateFuncSpecificPrimalInsts) { - auto newInst = subEnv.mapOldValToNew[inst].GetValue(); - newInst->removeAndDeallocate(); + IRInst* newInst = nullptr; + if (subEnv.mapOldValToNew.TryGetValue(inst, newInst)) + { + newInst->removeAndDeallocate(); + } } HashSet<IRInst*> newPrimalParams; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 8b24b122e..65f45ece8 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -32,6 +32,7 @@ struct DiffUnzipPass // Dictionary<IRInst*, IRInst*> primalMap; Dictionary<IRInst*, IRInst*> diffMap; + Dictionary<IRBlock*, IRBlock*> recomputeBlockMap; // First diff block. // TODO: Can the same pass object can be used for multiple functions? @@ -40,8 +41,6 @@ struct DiffUnzipPass RefPtr<IndexedRegionMap> indexRegionMap; - Dictionary<IndexedRegion*, RefPtr<IndexTrackingInfo>> indexInfoMap; - DiffUnzipPass( AutoDiffSharedContext* autodiffContext) : autodiffContext(autodiffContext) @@ -58,7 +57,7 @@ struct DiffUnzipPass return diffMap[inst]; } - RefPtr<HoistedPrimalsInfo> unzipDiffInsts(IRFunc* func) + void unzipDiffInsts(IRFunc* func) { diffTypeContext.setFunc(func); @@ -138,7 +137,8 @@ struct DiffUnzipPass // Mark the differential block as a differential inst // (and add a reference to the primal block) - builder->markInstAsDifferential(diffBlock, nullptr, primalMap[block]); + builder->markInstAsDifferential( + diffBlock, builder->getBasicBlockType(), primalMap[block]); // Record the first differential (code) block, // since we want all 'return' insts in primal blocks @@ -154,16 +154,6 @@ struct DiffUnzipPass splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block])); } - // Emit counter variables and other supporting - // instructions for all regions. - // - // TODO: Need to have maxIndex in _both_ IndexTrackingInfo & IndexedRegionInfo. - // That way, we can do the various passes _before_ lowerIndexedRegions() - // TODO: Remove the call to lowerIndexedRegions() once checkpointing works properly. - // - RefPtr<HoistedPrimalsInfo> primalsInfo = new HoistedPrimalsInfo(); - lowerIndexedRegions(primalsInfo); - // Copy regions from fwd-block to their split blocks // to make it easier to do lookups. // @@ -189,217 +179,13 @@ struct DiffUnzipPass firstBlock->replaceUsesWith(firstPrimalBlock); RefPtr<BlockSplitInfo> splitInfo = new BlockSplitInfo(); + for (auto block : mixedBlocks) if (primalMap.ContainsKey(block)) splitInfo->diffBlockMap[as<IRBlock>(primalMap[block])] = as<IRBlock>(diffMap[block]); - - Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlocksInfo; - for (auto block : mixedBlocks) - { - indexedBlocksInfo[as<IRBlock>(diffMap[block])] = getIndexInfoList(as<IRBlock>(diffMap[block])); - indexedBlocksInfo[as<IRBlock>(primalMap[block])] = getIndexInfoList(as<IRBlock>(primalMap[block])); - } for (auto block : mixedBlocks) block->removeAndDeallocate(); - - // Run the three checkpointing passes to hoist/clone primal insts - // to the right spots. - // - { - RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(unzippedFunc->getModule()); - chkPolicy->preparePolicy(func); - - auto chkPrimalsInfo = chkPolicy->processFunc(func, splitInfo); - primalsInfo->merge(chkPrimalsInfo); - - primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlocksInfo); - } - - return primalsInfo; - } - - void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info) - { - if (info->status != IndexTrackingInfo::CountStatus::Unresolved) - return; - - auto loop = as<IRLoop>(region->getInitializerBlock()->getTerminator()); - - if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>()) - { - info->maxIters = (Count) maxItersDecoration->getMaxIters(); - info->status = IndexTrackingInfo::CountStatus::Static; - } - - if (info->status == IndexTrackingInfo::CountStatus::Unresolved) - { - SLANG_UNEXPECTED("Could not resolve max iters \ - for loop appearing in reverse-mode"); - } - } - - IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type) - { - builder->setInsertInto(block); - return builder->emitParam(type); - } - - IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type, UIndex index) - { - List<IRParam*> params; - for (auto param : block->getParams()) - params.add(param); - - SLANG_RELEASE_ASSERT(index == (UCount)params.getCount()); - - return addPhiInputParam(builder, block, type); - } - - void lowerIndexedRegions(HoistedPrimalsInfo* primalsInfo) - { - IRBuilder builder(autodiffContext->moduleInst->getModule()); - - for (auto region : indexRegionMap->regions) - { - RefPtr<IndexTrackingInfo> info = new IndexTrackingInfo(); - indexInfoMap[region] = info; - - // Grab first primal block. - IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->getInitializerBlock()]); - builder.setInsertBefore(primalInitBlock->getTerminator()); - - // Make variable in the top-most block (so it's visible to diff blocks) - info->primalCountLastVar = builder.emitVar(builder.getIntType()); - builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var")); - primalsInfo->storeSet.Add(info->primalCountLastVar); - - { - auto primalCondBlock = as<IRUnconditionalBranch>( - primalInitBlock->getTerminator())->getTargetBlock(); - builder.setInsertBefore(primalCondBlock->getTerminator()); - - auto phiCounterArgLoopEntryIndex = addPhiOutputArg( - &builder, - primalInitBlock, - builder.getIntValue(builder.getIntType(), 0)); - - info->primalCountParam = addPhiInputParam( - &builder, - primalCondBlock, - builder.getIntType(), - phiCounterArgLoopEntryIndex); - builder.addNameHintDecoration(info->primalCountParam, UnownedStringSlice("_pc")); - builder.addLoopCounterDecoration(info->primalCountParam); - builder.markInstAsPrimal(info->primalCountParam); - - IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[region->getUpdateBlock()]); - builder.setInsertBefore(primalUpdateBlock->getTerminator()); - - auto incCounterVal = builder.emitAdd( - builder.getIntType(), - info->primalCountParam, - builder.getIntValue(builder.getIntType(), 1)); - builder.markInstAsPrimal(incCounterVal); - - auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, primalUpdateBlock, incCounterVal); - - SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); - - IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->getBreakBlock()]); - builder.setInsertBefore(primalBreakBlock->getTerminator()); - - builder.emitStore(info->primalCountLastVar, info->primalCountParam); - } - - { - IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->getInitializerBlock()]); - - auto diffCondBlock = as<IRUnconditionalBranch>( - diffInitBlock->getTerminator())->getTargetBlock(); - builder.setInsertBefore(diffCondBlock->getTerminator()); - - auto phiCounterArgLoopEntryIndex = addPhiOutputArg( - &builder, - diffInitBlock, - builder.getIntValue(builder.getIntType(), 0)); - - info->diffCountParam = addPhiInputParam( - &builder, - diffCondBlock, - builder.getIntType(), - phiCounterArgLoopEntryIndex); - builder.addNameHintDecoration(info->diffCountParam, UnownedStringSlice("_dc")); - builder.addLoopCounterDecoration(info->diffCountParam); - builder.markInstAsPrimal(info->diffCountParam); - - IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[region->getUpdateBlock()]); - builder.setInsertBefore(diffUpdateBlock->getTerminator()); - - auto incCounterVal = builder.emitAdd( - builder.getIntType(), - info->diffCountParam, - builder.getIntValue(builder.getIntType(), 1)); - builder.markInstAsPrimal(incCounterVal); - - auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, diffUpdateBlock, incCounterVal); - - SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); - - auto loopInst = as<IRLoop>(diffInitBlock->getTerminator()); - - builder.setInsertBefore(loopInst); - - auto primalCounterLastVal = builder.emitLoad(info->primalCountLastVar); - builder.markInstAsPrimal(primalCounterLastVal); - builder.addPrimalValueAccessDecoration(primalCounterLastVal); - - builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal); - - // We'll be manually creating the inversion entries for the counters - // TODO: This logic can be moved to the checkpointing alg. - // - primalsInfo->invertSet.Add(info->diffCountParam); - primalsInfo->instsToInvert.Add(incCounterVal); - primalsInfo->invertInfoMap[incCounterVal] = InversionInfo( - incCounterVal, - List<IRInst*>(incCounterVal), - List<IRInst*>(info->diffCountParam)); - - primalsInfo->invertSet.Add(incCounterVal); - primalsInfo->instsToInvert.Add(diffUpdateBlock->getTerminator()); - primalsInfo->invertInfoMap[diffUpdateBlock->getTerminator()] = InversionInfo( - diffUpdateBlock->getTerminator(), - List<IRInst*>(diffUpdateBlock->getTerminator()), - List<IRInst*>(incCounterVal)); - } - - // Try to infer maximum possible number of iterations. - // (only regions whose intermediates are used outside their region - // require a maximum count, so we may see some unresolved regions - // without any issues) - // - tryInferMaxIndex(region, info); - } - } - - void tagNewParams(IRBuilder* builder, IRFunc* func) - { - for (auto block : func->getBlocks()) - { - for (auto param = block->getFirstParam(); param; param = param->getNextParam()) - if (!param->findDecoration<IRAutodiffInstDecoration>()) - builder->markInstAsPrimal(param); - } - } - - List<IndexTrackingInfo*> getIndexInfoList(IRBlock* block) - { - List<IndexTrackingInfo*> indices; - for (auto region : indexRegionMap->getAllAncestorRegions(block)) - indices.add((IndexTrackingInfo*) indexInfoMap[region].GetValue()); - - return indices; } IRFunc* extractPrimalFunc( diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 9a7a42619..a8af148d9 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -774,7 +774,9 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_PrimalInstDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: - case kIROp_PrimalValueAccessDecoration: + case kIROp_RecomputeBlockDecoration: + case kIROp_LoopCounterDecoration: + case kIROp_LoopCounterUpdateDecoration: case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDerivativeIntermediateTypeDecoration: case kIROp_BackwardDerivativePropagateDecoration: @@ -814,6 +816,7 @@ void stripTempDecorations(IRInst* inst) { case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: + case kIROp_RecomputeBlockDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: case kIROp_PrimalValueStructKeyDecoration: @@ -902,8 +905,9 @@ bool canTypeBeStored(IRInst* type) case kIROp_FloatType: case kIROp_VectorType: case kIROp_MatrixType: - case kIROp_AttributedType: return true; + case kIROp_AttributedType: + return canTypeBeStored(type->getOperand(0)); default: return false; } @@ -1770,7 +1774,7 @@ IRInst* getInstInBlock(IRInst* inst) return getInstInBlock(inst->getParent()); } -UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg) { SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); @@ -1786,16 +1790,22 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) switch (branchInst->getOp()) { case kIROp_unconditionalBranch: - builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); + inoutTerminatorInst = builder->emitBranch( + branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); break; case kIROp_loop: - builder->emitLoop( - as<IRLoop>(branchInst)->getTargetBlock(), - as<IRLoop>(branchInst)->getBreakBlock(), - as<IRLoop>(branchInst)->getContinueBlock(), - phiArgs.getCount(), - phiArgs.getBuffer()); + { + auto newLoop = builder->emitLoop( + as<IRLoop>(branchInst)->getTargetBlock(), + as<IRLoop>(branchInst)->getBreakBlock(), + as<IRLoop>(branchInst)->getContinueBlock(), + phiArgs.getCount(), + phiArgs.getBuffer()); + branchInst->transferDecorationsTo(newLoop); + branchInst->replaceUsesWith(newLoop); + inoutTerminatorInst = newLoop; + } break; default: @@ -1806,6 +1816,24 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) return phiArgs.getCount() - 1; } +bool isDifferentialOrRecomputeBlock(IRBlock* block) +{ + if (!block) + return false; + for (auto decor : block->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_DifferentialInstDecoration: + case kIROp_RecomputeBlockDecoration: + return true; + default: + break; + } + } + return false; +} + IRUse* findUniqueStoredVal(IRVar* var) { if (isDerivativeContextVar(var)) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 167aa2357..d7d6119d4 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -340,7 +340,7 @@ IRBlock* getBlock(IRInst* inst); IRInst* getInstInBlock(IRInst* inst); -UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg); +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg); IRUse* findUniqueStoredVal(IRVar* var); @@ -348,4 +348,6 @@ bool isDerivativeContextVar(IRVar* var); bool isDiffInst(IRInst* inst); +bool isDifferentialOrRecomputeBlock(IRBlock* block); + }; diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 1fe88e780..364abe68c 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -327,29 +327,7 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o // if (inst->mightHaveSideEffects()) { - // If the inst has side effect, we should keep it alive. - // An exception is if we have a call to a pure function - // that writes its output to a local variable, but we - // don't have any uses of that local variable. - auto call = as<IRCall>(inst); - if (!call) - return true; - if (!getResolvedInstForDecorations(call->getCallee())->findDecoration<IRReadNoneDecoration>()) - return true; - auto parentFunc = getParentFunc(inst); - if (!parentFunc) - return true; - for (UInt i = 0; i < call->getArgCount(); i++) - { - auto arg = call->getArg(i); - if (getParentFunc(arg) != parentFunc) - return true; - if (arg->getOp() != kIROp_Var) - return true; - if (arg->hasMoreThanOneUse()) - return true; - } - return false; + return true; } // // The `mightHaveSideEffects` query is conservative, and will diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index a8ec5a66f..11143cebb 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -810,7 +810,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) - INST(PrimalValueAccessDecoration, primalValueAccessDecoration, 0, 0) + INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0) /* Auto-diff inst decorations */ /// Used by the auto-diff pass to mark insts that compute @@ -824,7 +824,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Used by the auto-diff pass to mark insts that compute /// BOTH a differential and a primal value. INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) - INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, MixedDifferentialInstDecoration) + + INST(RecomputeBlockDecoration, RecomputeBlockDecoration, 0, 0) + INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, RecomputeBlockDecoration) /// Used by the auto-diff pass to mark insts whose result is stored /// in an intermediary struct for reuse in backward propagation phase. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 356ccf4d6..f515baf8d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -721,6 +721,16 @@ struct IRLoopCounterDecoration : IRDecoration IR_LEAF_ISA(LoopCounterDecoration) }; +struct IRLoopCounterUpdateDecoration : IRDecoration +{ + enum + { + kOp = kIROp_LoopCounterUpdateDecoration + }; + IR_LEAF_ISA(LoopCounterUpdateDecoration) +}; + + struct IRLoopExitPrimalValueDecoration : IRDecoration { enum @@ -777,14 +787,14 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration IRType* getPairType() { return as<IRType>(getOperand(0)); } }; -struct IRPrimalValueAccessDecoration : IRAutodiffInstDecoration +struct IRRecomputeBlockDecoration : IRAutodiffInstDecoration { enum { - kOp = kIROp_PrimalValueAccessDecoration + kOp = kIROp_RecomputeBlockDecoration }; - IR_LEAF_ISA(PrimalValueAccessDecoration) + IR_LEAF_ISA(RecomputeBlockDecoration) }; struct IRPrimalValueStructKeyDecoration : IRDecoration @@ -3532,6 +3542,7 @@ public: IRInst* emitEql(IRInst* left, IRInst* right); IRInst* emitNeq(IRInst* left, IRInst* right); IRInst* emitLess(IRInst* left, IRInst* right); + IRInst* emitGeq(IRInst* left, IRInst* right); IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1); IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1); @@ -3807,9 +3818,9 @@ public: addDecoration(value, kIROp_LoopExitPrimalValueDecoration, primalInst, exitValue); } - void addPrimalValueAccessDecoration(IRInst* value) + void addLoopCounterUpdateDecoration(IRInst* value) { - addDecoration(value, kIROp_PrimalValueAccessDecoration); + addDecoration(value, kIROp_LoopCounterUpdateDecoration); } void markInstAsPrimal(IRInst* value) diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 121665c85..a368ff8c8 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -529,23 +529,6 @@ bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink) return true; } -static void _moveParams(IRBlock* dest, IRBlock* src) -{ - for (auto param = src->getFirstChild(); param;) - { - auto nextInst = param->getNextInst(); - if (as<IRDecoration>(param) || as<IRParam>(param)) - { - param->insertAtEnd(dest); - } - else - { - break; - } - param = nextInst; - } -} - void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) { // Eliminate the continue jumps by turning a loop in the form of: @@ -599,7 +582,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) targetBlock->replaceUsesWith(innerBreakableRegionHeader); // Move decorations and params from original targetBlock to innerBreakableRegionHeader. - _moveParams(innerBreakableRegionHeader, targetBlock); + moveParams(innerBreakableRegionHeader, targetBlock); builder.setInsertInto(innerBreakableRegionHeader); builder.emitLoop(targetBlock, innerBreakableRegionBreakBlock, targetBlock); @@ -607,7 +590,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) continueBlock->replaceUsesWith(innerBreakableRegionBreakBlock); builder.setInsertInto(innerBreakableRegionBreakBlock); - _moveParams(innerBreakableRegionBreakBlock, continueBlock); + moveParams(innerBreakableRegionBreakBlock, continueBlock); builder.emitBranch(continueBlock); // If the original loop can be executed up to N times, the new loop may be executed diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index dd92630b3..99cae22f0 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -66,6 +66,14 @@ struct RedundancyRemovalContext case kIROp_Leq: case kIROp_Neq: case kIROp_Eql: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + case kIROp_PtrType: + case kIROp_ArrayType: + case kIROp_FuncType: + case kIROp_InOutType: + case kIROp_OutType: return true; case kIROp_Call: return isPureFunctionalCall(as<IRCall>(inst)); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 83f6735bd..03b74b36a 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -681,6 +681,9 @@ bool isPureFunctionalCall(IRCall* call) // are not dependent on whatever we do in the call here. continue; default: + // Skip the call itself, since we are checking if the call has side effect. + if (use->getUser() == call) + continue; // We have some other unknown use of the variable address, they can // be loads, or calls using addresses derived from the variable, // we will treat the call as having side effect to be safe. @@ -721,6 +724,23 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) return nullptr; } +void moveParams(IRBlock* dest, IRBlock* src) +{ + for (auto param = src->getFirstChild(); param;) + { + auto nextInst = param->getNextInst(); + if (as<IRDecoration>(param) || as<IRParam>(param)) + { + param->insertAtEnd(dest); + } + else + { + break; + } + param = nextInst; + } +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index ef7ff47bb..e7d182604 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -198,6 +198,8 @@ void removeLinkageDecorations(IRGlobalValueWithCode* func); IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key); + +void moveParams(IRBlock* dest, IRBlock* src); } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 97109274f..558fd7796 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5405,6 +5405,13 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitGeq(IRInst* left, IRInst* right) + { + auto inst = createInst<IRInst>(this, kIROp_Geq, getBoolType(), left, right); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMul(IRType* type, IRInst* left, IRInst* right) { auto inst = createInst<IRInst>( |
