diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-28 21:24:24 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-28 21:24:24 -0500 |
| commit | 3c32dd951c5d69b5568929e0038e693553efca79 (patch) | |
| tree | 377b4b921e82cfc201a768d88a70f12a16586614 /source | |
| parent | 7eeda30df967671c410de4fd725f91f9078d74c4 (diff) | |
AD: Fixed do-while loops (#2683)
* WIP: Fix for do-while loops
* Added a somewhat hacky fix for do-while loops
* Redid the indexed region map builder step to fix issue with the nested loops test
* rename
* Used managed pointers
Diffstat (limited to 'source')
| -rw-r--r-- | source/core/slang-dictionary.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 516 |
3 files changed, 249 insertions, 313 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h index b11341051..e923832e5 100644 --- a/source/core/slang-dictionary.h +++ b/source/core/slang-dictionary.h @@ -450,7 +450,7 @@ namespace Slang return dict->hashMap[pos.ObjectPosition].Value; } else - SLANG_ASSERT_FAILURE("The key does not exists in dictionary."); + SLANG_ASSERT_FAILURE("The key does not exist in dictionary."); } inline TValue & operator()() const { diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 2199b0771..f3c739894 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -83,11 +83,19 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) // false side goes into the break block. // condBuilder.setInsertInto(condBlock); - condBuilder.emitIfElse( + auto ifElse = as<IRIfElse>(condBuilder.emitIfElse( condBuilder.getBoolValue(true), firstBlock, loopInst->getBreakBlock(), - firstBlock); + firstBlock)); + + // We'll insert a blank block between the condition and the + // break block, since otherwise, we might trip up the later + // parts of this pass. + // + condBuilder.insertBlockAlongEdge( + loopInst->getModule(), + IREdge(&ifElse->falseBlock)); return condBlock; } @@ -232,7 +240,7 @@ struct CFGNormalizationPass breakFlagValue, block, afterSplitAfterBlock, - afterSplitAfterBlock); + afterSplitAfterBlock); // At this point, we need to place afterSplitAfterBlock between // at the _end_ of this region, but we aren't there yet (and @@ -357,6 +365,36 @@ struct CFGNormalizationPass // Do we need to split the after region? if (afterBaseRegion && afterBreakRegion) { + // Before we split the afterBlock, we + // want to make sure the afterBlock is + // firmly _inside_ the current region. + // If it's part of the parent, add a + // dummy block. + // + if (afterBlocks.contains(afterBlock)) + { + auto newAfterBlock = builder.emitBlock(); + + // TODO: This is a hack. Ideally we should be putting + // the new after block 'before' the old after block, + // but if the latter is a loop condition block, it dominates + // the former, which may depend on parameters in the loop + // condition block. (This eventually causes cloneInst to fail, + // since it is currently order-dependent) + // Remove this once cloneInst is order-independent. + // + // newAfterBlock->insertBefore(afterBlock); + newAfterBlock->insertAfter(falseEndPoint.exitBlock); + + builder.emitBranch(afterBlock); + + ifElse->afterBlock.set(newAfterBlock); + as<IRUnconditionalBranch>(trueEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock); + as<IRUnconditionalBranch>(falseEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock); + + afterBlock = newAfterBlock; + } + addBreakBypassBranch(afterBlock); // Update current block. diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 3678bd4b3..f2aa1fd29 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -36,23 +36,147 @@ struct DiffUnzipPass // might run into an issue here? IRBlock* firstDiffBlock; - struct IndexedRegion + struct IndexedRegion : public RefObject { - // Parent indexed region (for nested loops) - IndexedRegion* parent = nullptr; + IRLoop* loop; + IndexedRegion* parent; - // Intializer block for the index. - IRBlock* initBlock = nullptr; + IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent) + { } + + IRBlock* getInitializerBlock() { return as<IRBlock>(loop->getParent()); } + IRBlock* getConditionBlock() + { + auto condBlock = as<IRBlock>(loop->getTargetBlock()); + SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator())); + return condBlock; + } + + IRBlock* getBreakBlock() { return loop->getBreakBlock(); } + + IRBlock* getUpdateBlock() + { + auto initBlock = getInitializerBlock(); + + auto condBlock = getConditionBlock(); + + IRBlock* lastLoopBlock = nullptr; + + for (auto predecessor : condBlock->getPredecessors()) + { + if (predecessor != initBlock) + lastLoopBlock = predecessor; + } + + // Should find atleast one predecessor that is _not_ the + // init block (that contains the loop info). This + // predecessor would be the last block in the loop + // before looping back to the condition. + // + SLANG_RELEASE_ASSERT(lastLoopBlock); + + return lastLoopBlock; + } + }; + + + struct IndexedRegionMap : public RefObject + { + Dictionary<IRBlock*, IndexedRegion*> map; + List<RefPtr<IndexedRegion>> regions; + + IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent) + { + auto region = new IndexedRegion(loop, parent); + regions.add(region); + + return region; + } + + void mapBlock(IRBlock* block, IndexedRegion* region) + { + map.Add(block, region); + } + + bool hasMapping(IRBlock* block) + { + return map.ContainsKey(block); + } + + IndexedRegion* getRegion(IRBlock* block) + { + return map[block]; + } + + List<IndexedRegion*> getAllAncestorRegions(IRBlock* block) + { + List<IndexedRegion*> regionList; + + IndexedRegion* region = getRegion(block); + for (; region; region = region->parent) + regionList.add(region); + + return regionList; + } + }; + + RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func) + { + RefPtr<IndexedRegionMap> regionMap = new IndexedRegionMap; + + List<IRBlock*> workList; + + regionMap->mapBlock(func->getFirstBlock(), nullptr); + workList.add(func->getFirstBlock()); + + while (workList.getCount() > 0) + { + auto currentBlock = workList.getLast(); + workList.removeLast(); - // Index 'starts' at the first loop block (included) - IRBlock* firstBlock = nullptr; + auto terminator = currentBlock->getTerminator(); + auto currentRegion = regionMap->getRegion(currentBlock); + + switch (terminator->getOp()) + { + case kIROp_loop: + { + auto loopRegion = regionMap->newRegion(as<IRLoop>(terminator), currentRegion); + auto condBlock = as<IRLoop>(terminator)->getTargetBlock(); + + regionMap->mapBlock(condBlock, loopRegion); + workList.add(condBlock); + + auto ifElse = as<IRIfElse>(condBlock->getTerminator()); + SLANG_RELEASE_ASSERT(ifElse); + + // TODO: this is one of the places we'll need to change if we support loops that + // loop on either the true or false side. For now, we assume the loop is on the + // true side only. + // + regionMap->mapBlock(ifElse->getFalseBlock(), currentRegion); + workList.add(ifElse->getFalseBlock()); + } + } + + for (auto successor : currentBlock->getSuccessors()) + { + // If already mapped, skip. + if (regionMap->hasMapping(successor)) + continue; + regionMap->mapBlock(successor, currentRegion); + workList.add(successor); + } + } + + return regionMap; + } - // Index stops at the break block (not included) - IRBlock* breakBlock = nullptr; - // Block where index updates happen. - IRBlock* continueBlock = nullptr; + RefPtr<IndexedRegionMap> indexRegionMap; + struct IndexTrackingInfo : public RefObject + { // After lowering, store references to the count // variables associated with this region // @@ -72,41 +196,9 @@ struct DiffUnzipPass // Inferred maximum number of iterations. Count maxIters = -1; - - IndexedRegion() : - parent(nullptr), - initBlock(nullptr), - firstBlock(nullptr), - breakBlock(nullptr), - continueBlock(nullptr), - primalCountParam(nullptr), - diffCountParam(nullptr), - status(CountStatus::Unresolved), - maxIters(-1) - { } - - IndexedRegion( - IndexedRegion* parent, - IRBlock* initBlock, - IRBlock* firstBlock, - IRBlock* breakBlock, - IRBlock* continueBlock) : - parent(parent), - initBlock(initBlock), - firstBlock(firstBlock), - breakBlock(breakBlock), - continueBlock(continueBlock), - primalCountParam(nullptr), - diffCountParam(nullptr), - status(CountStatus::Unresolved), - maxIters(-1) - { } }; - // Keep track of indexed blocks and their corresponding index heirarchy. - Dictionary<IRBlock*, IndexedRegion*> indexRegionMap; - - List<IndexedRegion*> indexRegions; + Dictionary<IndexedRegion*, RefPtr<IndexTrackingInfo>> indexInfoMap; DiffUnzipPass( @@ -128,6 +220,11 @@ struct DiffUnzipPass void unzipDiffInsts(IRFunc* func) { diffTypeContext.setFunc(func); + + // Build a map of blocks to loop regions. + // This will be used later to insert tracking indices + // + indexRegionMap = buildIndexedRegionMap(func); IRBuilder builderStorage(autodiffContext->moduleInst->getModule()); @@ -216,19 +313,6 @@ struct DiffUnzipPass splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block])); } - // Propagate indexed region information. - propagateAllIndexRegions(); - - // Try to infer maximum counts for all regions. - // (only regions whose intermediates are used outside their region - // require a maximum count, so we may see some unresolved regions - // without any issues) - // - for (auto region : indexRegions) - { - tryInferMaxIndex(region); - } - // Emit counter variables and other supporting // instructions for all regions. // @@ -239,7 +323,7 @@ struct DiffUnzipPass // { List<IRBlock*> workList; - for (auto blockRegionPair : indexRegionMap) + for (auto blockRegionPair : indexRegionMap->map) { IRBlock* block = blockRegionPair.Key; workList.add(block); @@ -247,8 +331,11 @@ struct DiffUnzipPass for (auto block : workList) { - indexRegionMap[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap[block]; - indexRegionMap[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap[block]; + if (primalMap.ContainsKey(block)) + indexRegionMap->map[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap->map[block]; + + if (diffMap.ContainsKey(block)) + indexRegionMap->map[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap->map[block]; } } @@ -257,72 +344,31 @@ struct DiffUnzipPass // for (auto block : mixedBlocks) { - if (isBlockIndexed(block)) + if (indexRegionMap->getRegion(block) != nullptr) processIndexedFwdBlock(block); } // Swap the first block's occurences out for the first primal block. firstBlock->replaceUsesWith(firstPrimalBlock); - cleanupIndexRegionInfo(); - for (auto block : mixedBlocks) block->removeAndDeallocate(); } - IRBlock* getInitializerBlock(IndexedRegion* region) - { - return region->initBlock; - } - - IRBlock* getUpdateBlock(IndexedRegion* region) + void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info) { - auto initBlock = getInitializerBlock(region); - - auto condBlock = region->firstBlock; - - IRBlock* lastLoopBlock = nullptr; - - for (auto predecessor : condBlock->getPredecessors()) - { - if (predecessor != initBlock) - lastLoopBlock = predecessor; - } - - // Should find atleast one predecessor that is _not_ the - // init block (that contains the loop info). This - // predecessor would be the last block in the loop - // before looping back to the condition. - // - SLANG_RELEASE_ASSERT(lastLoopBlock); - - return lastLoopBlock; - } - - IRBlock* getFirstLoopBodyBlock(IndexedRegion* region) - { - // Grab the 'condition' block. - auto condBlock = region->firstBlock; - - SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator())); - - return as<IRIfElse>(condBlock->getTerminator())->getTrueBlock(); - } - - void tryInferMaxIndex(IndexedRegion* region) - { - if (region->status != IndexedRegion::CountStatus::Unresolved) + if (info->status != IndexTrackingInfo::CountStatus::Unresolved) return; - auto loop = as<IRLoop>(region->initBlock->getTerminator()); + auto loop = as<IRLoop>(region->getInitializerBlock()->getTerminator()); if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>()) { - region->maxIters = (Count) maxItersDecoration->getMaxIters(); - region->status = IndexedRegion::CountStatus::Static; + info->maxIters = (Count) maxItersDecoration->getMaxIters(); + info->status = IndexTrackingInfo::CountStatus::Static; } - if (region->status == IndexedRegion::CountStatus::Unresolved) + if (info->status == IndexTrackingInfo::CountStatus::Unresolved) { SLANG_UNEXPECTED("Could not resolve max iters \ for loop appearing in reverse-mode"); @@ -406,15 +452,18 @@ struct DiffUnzipPass { IRBuilder builder(autodiffContext->moduleInst->getModule()); - for (auto region : indexRegions) + for (auto region : indexRegionMap->regions) { + RefPtr<IndexTrackingInfo> info = new IndexTrackingInfo(); + indexInfoMap[region] = info; + // Grab first primal block. - IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); + IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->getInitializerBlock()]); builder.setInsertBefore(primalInitBlock->getTerminator()); // Make variable in the top-most block (so it's visible to diff blocks) - region->primalCountLastVar = builder.emitVar(builder.getIntType()); - builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var")); + info->primalCountLastVar = builder.emitVar(builder.getIntType()); + builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var")); { auto primalCondBlock = as<IRUnconditionalBranch>( @@ -426,21 +475,21 @@ struct DiffUnzipPass primalInitBlock, builder.getIntValue(builder.getIntType(), 0)); - region->primalCountParam = addPhiInputParam( + info->primalCountParam = addPhiInputParam( &builder, primalCondBlock, builder.getIntType(), phiCounterArgLoopEntryIndex); - builder.addNameHintDecoration(region->primalCountParam, UnownedStringSlice("_pc")); - builder.addLoopCounterDecoration(region->primalCountParam); - builder.markInstAsPrimal(region->primalCountParam); + builder.addNameHintDecoration(info->primalCountParam, UnownedStringSlice("_pc")); + builder.addLoopCounterDecoration(info->primalCountParam); + builder.markInstAsPrimal(info->primalCountParam); - IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[getUpdateBlock(region)]); + IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[region->getUpdateBlock()]); builder.setInsertBefore(primalUpdateBlock->getTerminator()); auto incCounterVal = builder.emitAdd( builder.getIntType(), - region->primalCountParam, + info->primalCountParam, builder.getIntValue(builder.getIntType(), 1)); builder.markInstAsPrimal(incCounterVal); @@ -448,14 +497,14 @@ struct DiffUnzipPass SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); - IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->breakBlock]); + IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->getBreakBlock()]); builder.setInsertBefore(primalBreakBlock->getTerminator()); - builder.emitStore(region->primalCountLastVar, region->primalCountParam); + builder.emitStore(info->primalCountLastVar, info->primalCountParam); } { - IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->initBlock]); + IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->getInitializerBlock()]); auto diffCondBlock = as<IRUnconditionalBranch>( diffInitBlock->getTerminator())->getTargetBlock(); @@ -466,21 +515,21 @@ struct DiffUnzipPass diffInitBlock, builder.getIntValue(builder.getIntType(), 0)); - region->diffCountParam = addPhiInputParam( + info->diffCountParam = addPhiInputParam( &builder, diffCondBlock, builder.getIntType(), phiCounterArgLoopEntryIndex); - builder.addNameHintDecoration(region->diffCountParam, UnownedStringSlice("_dc")); - builder.addLoopCounterDecoration(region->diffCountParam); - builder.markInstAsPrimal(region->diffCountParam); + builder.addNameHintDecoration(info->diffCountParam, UnownedStringSlice("_dc")); + builder.addLoopCounterDecoration(info->diffCountParam); + builder.markInstAsPrimal(info->diffCountParam); - IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[getUpdateBlock(region)]); + IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[region->getUpdateBlock()]); builder.setInsertBefore(diffUpdateBlock->getTerminator()); auto incCounterVal = builder.emitAdd( builder.getIntType(), - region->diffCountParam, + info->diffCountParam, builder.getIntValue(builder.getIntType(), 1)); builder.markInstAsPrimal(incCounterVal); @@ -492,12 +541,19 @@ struct DiffUnzipPass builder.setInsertBefore(loopInst); - auto primalCounterLastVal = builder.emitLoad(region->primalCountLastVar); + auto primalCounterLastVal = builder.emitLoad(info->primalCountLastVar); builder.markInstAsPrimal(primalCounterLastVal); builder.addPrimalValueAccessDecoration(primalCounterLastVal); - builder.addLoopExitPrimalValueDecoration(loopInst, region->diffCountParam, primalCounterLastVal); + builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal); } + + // Try to infer maximum possible number of iterations. + // (only regions whose intermediates are used outside their region + // require a maximum count, so we may see some unresolved regions + // without any issues) + // + tryInferMaxIndex(region, info); } } @@ -511,11 +567,17 @@ struct DiffUnzipPass } } - void processIndexedFwdBlock(IRBlock* fwdBlock) + List<IndexTrackingInfo*> getIndexInfoList(IRBlock* block) { - if (!isBlockIndexed(fwdBlock)) - return; + List<IndexTrackingInfo*> indices; + for (auto region : indexRegionMap->getAllAncestorRegions(block)) + indices.add((IndexTrackingInfo*) indexInfoMap[region].GetValue()); + return indices; + } + + void processIndexedFwdBlock(IRBlock* fwdBlock) + { // Grab first primal block. IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[fwdBlock->getParent()->getFirstBlock()->getNextBlock()]); @@ -625,12 +687,7 @@ struct DiffUnzipPass } // Build list of indices that the value's block is affected by. - List<IndexedRegion*> regions; - { - IndexedRegion* region = indexRegionMap[valueBlock]; - for (; region; region = region->parent) - regions.add(region); - } + List<IndexTrackingInfo*> indices = getIndexInfoList(valueBlock); // 3. Emit an array to top-level to allocate space. @@ -638,22 +695,22 @@ struct DiffUnzipPass IRType* storageType = valueType; - for (auto region : regions) + for (auto index : indices) { - SLANG_ASSERT(region->status == IndexedRegion::CountStatus::Static); - SLANG_ASSERT(region->maxIters >= 0); + SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static); + SLANG_ASSERT(index->maxIters >= 0); storageType = builder.getArrayType( storageType, builder.getIntValue( builder.getUIntType(), - region->maxIters + 1)); + index->maxIters + 1)); } // Reverse the list since the indices need to be // emitted in reverse order. // - regions.reverse(); + indices.reverse(); auto storageVar = builder.emitVar(storageType); if (isIntermediateContext) @@ -673,14 +730,14 @@ struct DiffUnzipPass IRInst* storeAddr = storageVar; IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); - for (auto region : regions) + for (auto index : indices) { currType = as<IRArrayType>(currType)->getElementType(); storeAddr = builder.emitElementAddress( builder.getPtrType(currType), storeAddr, - region->primalCountParam); + index->primalCountParam); } if (!isIntermediateContext) @@ -730,23 +787,17 @@ struct DiffUnzipPass // TODO: Probably a good idea to do this ahead of time for // all blocks. // - List<IndexedRegion*> useBlockRegions; - { - IndexedRegion* region = indexRegionMap.ContainsKey(useBlock) ? - (IndexedRegion*)indexRegionMap[useBlock] : nullptr; - for (; region; region = region->parent) - useBlockRegions.add(region); - } + List<IndexTrackingInfo*> useBlockIndices = getIndexInfoList(useBlock); - for (auto region : regions) + for (auto index : indices) { currType = as<IRArrayType>(currType)->getElementType(); - if (useBlockRegions.contains(region)) + if (useBlockIndices.contains(index)) { // If the use-block is under the same region, use the // differential counter variable // - auto diffCounterCurrValue = region->diffCountParam; + auto diffCounterCurrValue = index->diffCountParam; loadAddr = builder.emitElementAddress( builder.getPtrType(currType), @@ -758,7 +809,7 @@ struct DiffUnzipPass // If the use-block is outside this region, use the // last available value (by indexing with primal counter minus 1) // - auto primalCounterCurrValue = builder.emitLoad(region->primalCountLastVar); + auto primalCounterCurrValue = builder.emitLoad(index->primalCountLastVar); auto primalCounterLastValue = builder.emitSub( primalCounterCurrValue->getDataType(), primalCounterCurrValue, @@ -1079,156 +1130,6 @@ struct DiffUnzipPass } } - bool isBlockIndexed(IRBlock* block) - { - return indexRegionMap.ContainsKey(block) && indexRegionMap[block] != nullptr; - } - - void addNewIndex(IRLoop* targetLoop) - { - // Create indexed region without a parent for now. - // The parent will be filled in during propagation. - // - IndexedRegion* region = new IndexedRegion( - nullptr, - as<IRBlock>(targetLoop->getParent()), - targetLoop->getTargetBlock(), - targetLoop->getBreakBlock(), - targetLoop->getContinueBlock()); - - indexRegionMap[targetLoop->getTargetBlock()] = region; - indexRegions.add(region); - } - - // Deallocate regions - void cleanupIndexRegionInfo() - { - for (auto region : indexRegions) - { - delete region; - } - - indexRegions.clear(); - indexRegionMap.Clear(); - } - - void propagateAllIndexRegions() - { - - - // Load up the starting block of every region into - // initial worklist. - // - List<IRBlock*> workList; - HashSet<IRBlock*> workSet; - for (auto region : indexRegions) - { - workList.add(region->firstBlock); - workSet.Add(region->firstBlock); - } - - // Keep propagating from initial work list to predecessors - // Add blocks to work list if their region assignment has changed - // Add the beginning blocks for complete regions if region parent has changed. - // - while (workList.getCount() > 0) - { - auto block = workList.getLast(); - workList.removeLast(); - workSet.Remove(block); - - HashSet<IRBlock*> successors; - - for (auto successor : block->getSuccessors()) - { - if (successors.Contains(successor)) - continue; - - if (propagateIndexRegion(block, successor)) - { - if (!workSet.Contains(successor)) - { - workList.add(successor); - workSet.Add(successor); - } - - // Do we have an index region for the successor, which is - // also the starting block of that region? - // Then the change might have been the addition of - // a parent node. Add the break block so the - // change can be propagated further. - // - if (isBlockIndexed(successor)) - { - IndexedRegion* succRegion = indexRegionMap[successor]; - if (succRegion->firstBlock == successor) - { - if (!workSet.Contains(succRegion->breakBlock)) - { - workList.add(succRegion->breakBlock); - workSet.Add(succRegion->breakBlock); - } - } - } - } - - successors.Add(successor); - } - } - } - - bool setIndexRegion(IRBlock* block, IndexedRegion* region) - { - if (!region) return false; - - if (indexRegionMap.ContainsKey(block) - && indexRegionMap[block] == region) - return false; - - indexRegionMap[block] = region; - return true; - } - - bool propagateIndexRegion(IRBlock* srcBlock, IRBlock* nextBlock) - { - // Is the current region indexed? - // If not, there's nothing to propagate - // - if (!isBlockIndexed(srcBlock)) - return false; - - IndexedRegion* region = indexRegionMap[srcBlock]; - - // If the target's index is already resolved, - // check if it's a sub-region. - // - if (isBlockIndexed(nextBlock)) - { - IndexedRegion* nextRegion = indexRegionMap[nextBlock]; - - // If we're at the first block of a region, - // set current region as continue-region's - // parent. - // - if (nextBlock == nextRegion->firstBlock && nextRegion != region) - { - nextRegion->parent = region; - return true; - } - - return false; - } - - // If we're at the break block, move up to the parent index. - if (nextBlock == region->breakBlock) - return setIndexRegion(nextBlock, region->parent); - - // If none of the special cases hit, copy the - // current region to the next block. - // - return setIndexRegion(nextBlock, region); - } - // Splitting a loop is one of the trickiest parts of the unzip pass. // Thus far, we've been dealing with blocks that are only run once, so we // could arbitrarily move intermediate instructions to other blocks since they are @@ -1247,9 +1148,6 @@ struct DiffUnzipPass auto continueBlock = mixedLoop->getContinueBlock(); auto nextBlock = mixedLoop->getTargetBlock(); - // Push a new index. - addNewIndex(mixedLoop); - // Split args. List<IRInst*> primalArgs; List<IRInst*> diffArgs; |
