summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-28 21:24:24 -0500
committerGitHub <noreply@github.com>2023-02-28 21:24:24 -0500
commit3c32dd951c5d69b5568929e0038e693553efca79 (patch)
tree377b4b921e82cfc201a768d88a70f12a16586614 /source
parent7eeda30df967671c410de4fd725f91f9078d74c4 (diff)
AD: Fixed do-while loops (#2683)
* WIP: Fix for do-while loops * Added a somewhat hacky fix for do-while loops * Redid the indexed region map builder step to fix issue with the nested loops test * rename * Used managed pointers
Diffstat (limited to 'source')
-rw-r--r--source/core/slang-dictionary.h2
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp44
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h516
3 files changed, 249 insertions, 313 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h
index b11341051..e923832e5 100644
--- a/source/core/slang-dictionary.h
+++ b/source/core/slang-dictionary.h
@@ -450,7 +450,7 @@ namespace Slang
return dict->hashMap[pos.ObjectPosition].Value;
}
else
- SLANG_ASSERT_FAILURE("The key does not exists in dictionary.");
+ SLANG_ASSERT_FAILURE("The key does not exist in dictionary.");
}
inline TValue & operator()() const
{
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 2199b0771..f3c739894 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -83,11 +83,19 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst)
// false side goes into the break block.
//
condBuilder.setInsertInto(condBlock);
- condBuilder.emitIfElse(
+ auto ifElse = as<IRIfElse>(condBuilder.emitIfElse(
condBuilder.getBoolValue(true),
firstBlock,
loopInst->getBreakBlock(),
- firstBlock);
+ firstBlock));
+
+ // We'll insert a blank block between the condition and the
+ // break block, since otherwise, we might trip up the later
+ // parts of this pass.
+ //
+ condBuilder.insertBlockAlongEdge(
+ loopInst->getModule(),
+ IREdge(&ifElse->falseBlock));
return condBlock;
}
@@ -232,7 +240,7 @@ struct CFGNormalizationPass
breakFlagValue,
block,
afterSplitAfterBlock,
- afterSplitAfterBlock);
+ afterSplitAfterBlock);
// At this point, we need to place afterSplitAfterBlock between
// at the _end_ of this region, but we aren't there yet (and
@@ -357,6 +365,36 @@ struct CFGNormalizationPass
// Do we need to split the after region?
if (afterBaseRegion && afterBreakRegion)
{
+ // Before we split the afterBlock, we
+ // want to make sure the afterBlock is
+ // firmly _inside_ the current region.
+ // If it's part of the parent, add a
+ // dummy block.
+ //
+ if (afterBlocks.contains(afterBlock))
+ {
+ auto newAfterBlock = builder.emitBlock();
+
+ // TODO: This is a hack. Ideally we should be putting
+ // the new after block 'before' the old after block,
+ // but if the latter is a loop condition block, it dominates
+ // the former, which may depend on parameters in the loop
+ // condition block. (This eventually causes cloneInst to fail,
+ // since it is currently order-dependent)
+ // Remove this once cloneInst is order-independent.
+ //
+ // newAfterBlock->insertBefore(afterBlock);
+ newAfterBlock->insertAfter(falseEndPoint.exitBlock);
+
+ builder.emitBranch(afterBlock);
+
+ ifElse->afterBlock.set(newAfterBlock);
+ as<IRUnconditionalBranch>(trueEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock);
+ as<IRUnconditionalBranch>(falseEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock);
+
+ afterBlock = newAfterBlock;
+ }
+
addBreakBypassBranch(afterBlock);
// Update current block.
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 3678bd4b3..f2aa1fd29 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -36,23 +36,147 @@ struct DiffUnzipPass
// might run into an issue here?
IRBlock* firstDiffBlock;
- struct IndexedRegion
+ struct IndexedRegion : public RefObject
{
- // Parent indexed region (for nested loops)
- IndexedRegion* parent = nullptr;
+ IRLoop* loop;
+ IndexedRegion* parent;
- // Intializer block for the index.
- IRBlock* initBlock = nullptr;
+ IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent)
+ { }
+
+ IRBlock* getInitializerBlock() { return as<IRBlock>(loop->getParent()); }
+ IRBlock* getConditionBlock()
+ {
+ auto condBlock = as<IRBlock>(loop->getTargetBlock());
+ SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator()));
+ return condBlock;
+ }
+
+ IRBlock* getBreakBlock() { return loop->getBreakBlock(); }
+
+ IRBlock* getUpdateBlock()
+ {
+ auto initBlock = getInitializerBlock();
+
+ auto condBlock = getConditionBlock();
+
+ IRBlock* lastLoopBlock = nullptr;
+
+ for (auto predecessor : condBlock->getPredecessors())
+ {
+ if (predecessor != initBlock)
+ lastLoopBlock = predecessor;
+ }
+
+ // Should find atleast one predecessor that is _not_ the
+ // init block (that contains the loop info). This
+ // predecessor would be the last block in the loop
+ // before looping back to the condition.
+ //
+ SLANG_RELEASE_ASSERT(lastLoopBlock);
+
+ return lastLoopBlock;
+ }
+ };
+
+
+ struct IndexedRegionMap : public RefObject
+ {
+ Dictionary<IRBlock*, IndexedRegion*> map;
+ List<RefPtr<IndexedRegion>> regions;
+
+ IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent)
+ {
+ auto region = new IndexedRegion(loop, parent);
+ regions.add(region);
+
+ return region;
+ }
+
+ void mapBlock(IRBlock* block, IndexedRegion* region)
+ {
+ map.Add(block, region);
+ }
+
+ bool hasMapping(IRBlock* block)
+ {
+ return map.ContainsKey(block);
+ }
+
+ IndexedRegion* getRegion(IRBlock* block)
+ {
+ return map[block];
+ }
+
+ List<IndexedRegion*> getAllAncestorRegions(IRBlock* block)
+ {
+ List<IndexedRegion*> regionList;
+
+ IndexedRegion* region = getRegion(block);
+ for (; region; region = region->parent)
+ regionList.add(region);
+
+ return regionList;
+ }
+ };
+
+ RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func)
+ {
+ RefPtr<IndexedRegionMap> regionMap = new IndexedRegionMap;
+
+ List<IRBlock*> workList;
+
+ regionMap->mapBlock(func->getFirstBlock(), nullptr);
+ workList.add(func->getFirstBlock());
+
+ while (workList.getCount() > 0)
+ {
+ auto currentBlock = workList.getLast();
+ workList.removeLast();
- // Index 'starts' at the first loop block (included)
- IRBlock* firstBlock = nullptr;
+ auto terminator = currentBlock->getTerminator();
+ auto currentRegion = regionMap->getRegion(currentBlock);
+
+ switch (terminator->getOp())
+ {
+ case kIROp_loop:
+ {
+ auto loopRegion = regionMap->newRegion(as<IRLoop>(terminator), currentRegion);
+ auto condBlock = as<IRLoop>(terminator)->getTargetBlock();
+
+ regionMap->mapBlock(condBlock, loopRegion);
+ workList.add(condBlock);
+
+ auto ifElse = as<IRIfElse>(condBlock->getTerminator());
+ SLANG_RELEASE_ASSERT(ifElse);
+
+ // TODO: this is one of the places we'll need to change if we support loops that
+ // loop on either the true or false side. For now, we assume the loop is on the
+ // true side only.
+ //
+ regionMap->mapBlock(ifElse->getFalseBlock(), currentRegion);
+ workList.add(ifElse->getFalseBlock());
+ }
+ }
+
+ for (auto successor : currentBlock->getSuccessors())
+ {
+ // If already mapped, skip.
+ if (regionMap->hasMapping(successor))
+ continue;
+ regionMap->mapBlock(successor, currentRegion);
+ workList.add(successor);
+ }
+ }
+
+ return regionMap;
+ }
- // Index stops at the break block (not included)
- IRBlock* breakBlock = nullptr;
- // Block where index updates happen.
- IRBlock* continueBlock = nullptr;
+ RefPtr<IndexedRegionMap> indexRegionMap;
+ struct IndexTrackingInfo : public RefObject
+ {
// After lowering, store references to the count
// variables associated with this region
//
@@ -72,41 +196,9 @@ struct DiffUnzipPass
// Inferred maximum number of iterations.
Count maxIters = -1;
-
- IndexedRegion() :
- parent(nullptr),
- initBlock(nullptr),
- firstBlock(nullptr),
- breakBlock(nullptr),
- continueBlock(nullptr),
- primalCountParam(nullptr),
- diffCountParam(nullptr),
- status(CountStatus::Unresolved),
- maxIters(-1)
- { }
-
- IndexedRegion(
- IndexedRegion* parent,
- IRBlock* initBlock,
- IRBlock* firstBlock,
- IRBlock* breakBlock,
- IRBlock* continueBlock) :
- parent(parent),
- initBlock(initBlock),
- firstBlock(firstBlock),
- breakBlock(breakBlock),
- continueBlock(continueBlock),
- primalCountParam(nullptr),
- diffCountParam(nullptr),
- status(CountStatus::Unresolved),
- maxIters(-1)
- { }
};
- // Keep track of indexed blocks and their corresponding index heirarchy.
- Dictionary<IRBlock*, IndexedRegion*> indexRegionMap;
-
- List<IndexedRegion*> indexRegions;
+ Dictionary<IndexedRegion*, RefPtr<IndexTrackingInfo>> indexInfoMap;
DiffUnzipPass(
@@ -128,6 +220,11 @@ struct DiffUnzipPass
void unzipDiffInsts(IRFunc* func)
{
diffTypeContext.setFunc(func);
+
+ // Build a map of blocks to loop regions.
+ // This will be used later to insert tracking indices
+ //
+ indexRegionMap = buildIndexedRegionMap(func);
IRBuilder builderStorage(autodiffContext->moduleInst->getModule());
@@ -216,19 +313,6 @@ struct DiffUnzipPass
splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block]));
}
- // Propagate indexed region information.
- propagateAllIndexRegions();
-
- // Try to infer maximum counts for all regions.
- // (only regions whose intermediates are used outside their region
- // require a maximum count, so we may see some unresolved regions
- // without any issues)
- //
- for (auto region : indexRegions)
- {
- tryInferMaxIndex(region);
- }
-
// Emit counter variables and other supporting
// instructions for all regions.
//
@@ -239,7 +323,7 @@ struct DiffUnzipPass
//
{
List<IRBlock*> workList;
- for (auto blockRegionPair : indexRegionMap)
+ for (auto blockRegionPair : indexRegionMap->map)
{
IRBlock* block = blockRegionPair.Key;
workList.add(block);
@@ -247,8 +331,11 @@ struct DiffUnzipPass
for (auto block : workList)
{
- indexRegionMap[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap[block];
- indexRegionMap[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap[block];
+ if (primalMap.ContainsKey(block))
+ indexRegionMap->map[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap->map[block];
+
+ if (diffMap.ContainsKey(block))
+ indexRegionMap->map[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap->map[block];
}
}
@@ -257,72 +344,31 @@ struct DiffUnzipPass
//
for (auto block : mixedBlocks)
{
- if (isBlockIndexed(block))
+ if (indexRegionMap->getRegion(block) != nullptr)
processIndexedFwdBlock(block);
}
// Swap the first block's occurences out for the first primal block.
firstBlock->replaceUsesWith(firstPrimalBlock);
- cleanupIndexRegionInfo();
-
for (auto block : mixedBlocks)
block->removeAndDeallocate();
}
- IRBlock* getInitializerBlock(IndexedRegion* region)
- {
- return region->initBlock;
- }
-
- IRBlock* getUpdateBlock(IndexedRegion* region)
+ void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
{
- auto initBlock = getInitializerBlock(region);
-
- auto condBlock = region->firstBlock;
-
- IRBlock* lastLoopBlock = nullptr;
-
- for (auto predecessor : condBlock->getPredecessors())
- {
- if (predecessor != initBlock)
- lastLoopBlock = predecessor;
- }
-
- // Should find atleast one predecessor that is _not_ the
- // init block (that contains the loop info). This
- // predecessor would be the last block in the loop
- // before looping back to the condition.
- //
- SLANG_RELEASE_ASSERT(lastLoopBlock);
-
- return lastLoopBlock;
- }
-
- IRBlock* getFirstLoopBodyBlock(IndexedRegion* region)
- {
- // Grab the 'condition' block.
- auto condBlock = region->firstBlock;
-
- SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator()));
-
- return as<IRIfElse>(condBlock->getTerminator())->getTrueBlock();
- }
-
- void tryInferMaxIndex(IndexedRegion* region)
- {
- if (region->status != IndexedRegion::CountStatus::Unresolved)
+ if (info->status != IndexTrackingInfo::CountStatus::Unresolved)
return;
- auto loop = as<IRLoop>(region->initBlock->getTerminator());
+ auto loop = as<IRLoop>(region->getInitializerBlock()->getTerminator());
if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>())
{
- region->maxIters = (Count) maxItersDecoration->getMaxIters();
- region->status = IndexedRegion::CountStatus::Static;
+ info->maxIters = (Count) maxItersDecoration->getMaxIters();
+ info->status = IndexTrackingInfo::CountStatus::Static;
}
- if (region->status == IndexedRegion::CountStatus::Unresolved)
+ if (info->status == IndexTrackingInfo::CountStatus::Unresolved)
{
SLANG_UNEXPECTED("Could not resolve max iters \
for loop appearing in reverse-mode");
@@ -406,15 +452,18 @@ struct DiffUnzipPass
{
IRBuilder builder(autodiffContext->moduleInst->getModule());
- for (auto region : indexRegions)
+ for (auto region : indexRegionMap->regions)
{
+ RefPtr<IndexTrackingInfo> info = new IndexTrackingInfo();
+ indexInfoMap[region] = info;
+
// Grab first primal block.
- IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]);
+ IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->getInitializerBlock()]);
builder.setInsertBefore(primalInitBlock->getTerminator());
// Make variable in the top-most block (so it's visible to diff blocks)
- region->primalCountLastVar = builder.emitVar(builder.getIntType());
- builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
+ info->primalCountLastVar = builder.emitVar(builder.getIntType());
+ builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
{
auto primalCondBlock = as<IRUnconditionalBranch>(
@@ -426,21 +475,21 @@ struct DiffUnzipPass
primalInitBlock,
builder.getIntValue(builder.getIntType(), 0));
- region->primalCountParam = addPhiInputParam(
+ info->primalCountParam = addPhiInputParam(
&builder,
primalCondBlock,
builder.getIntType(),
phiCounterArgLoopEntryIndex);
- builder.addNameHintDecoration(region->primalCountParam, UnownedStringSlice("_pc"));
- builder.addLoopCounterDecoration(region->primalCountParam);
- builder.markInstAsPrimal(region->primalCountParam);
+ builder.addNameHintDecoration(info->primalCountParam, UnownedStringSlice("_pc"));
+ builder.addLoopCounterDecoration(info->primalCountParam);
+ builder.markInstAsPrimal(info->primalCountParam);
- IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[getUpdateBlock(region)]);
+ IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[region->getUpdateBlock()]);
builder.setInsertBefore(primalUpdateBlock->getTerminator());
auto incCounterVal = builder.emitAdd(
builder.getIntType(),
- region->primalCountParam,
+ info->primalCountParam,
builder.getIntValue(builder.getIntType(), 1));
builder.markInstAsPrimal(incCounterVal);
@@ -448,14 +497,14 @@ struct DiffUnzipPass
SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
- IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->breakBlock]);
+ IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->getBreakBlock()]);
builder.setInsertBefore(primalBreakBlock->getTerminator());
- builder.emitStore(region->primalCountLastVar, region->primalCountParam);
+ builder.emitStore(info->primalCountLastVar, info->primalCountParam);
}
{
- IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->initBlock]);
+ IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->getInitializerBlock()]);
auto diffCondBlock = as<IRUnconditionalBranch>(
diffInitBlock->getTerminator())->getTargetBlock();
@@ -466,21 +515,21 @@ struct DiffUnzipPass
diffInitBlock,
builder.getIntValue(builder.getIntType(), 0));
- region->diffCountParam = addPhiInputParam(
+ info->diffCountParam = addPhiInputParam(
&builder,
diffCondBlock,
builder.getIntType(),
phiCounterArgLoopEntryIndex);
- builder.addNameHintDecoration(region->diffCountParam, UnownedStringSlice("_dc"));
- builder.addLoopCounterDecoration(region->diffCountParam);
- builder.markInstAsPrimal(region->diffCountParam);
+ builder.addNameHintDecoration(info->diffCountParam, UnownedStringSlice("_dc"));
+ builder.addLoopCounterDecoration(info->diffCountParam);
+ builder.markInstAsPrimal(info->diffCountParam);
- IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[getUpdateBlock(region)]);
+ IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[region->getUpdateBlock()]);
builder.setInsertBefore(diffUpdateBlock->getTerminator());
auto incCounterVal = builder.emitAdd(
builder.getIntType(),
- region->diffCountParam,
+ info->diffCountParam,
builder.getIntValue(builder.getIntType(), 1));
builder.markInstAsPrimal(incCounterVal);
@@ -492,12 +541,19 @@ struct DiffUnzipPass
builder.setInsertBefore(loopInst);
- auto primalCounterLastVal = builder.emitLoad(region->primalCountLastVar);
+ auto primalCounterLastVal = builder.emitLoad(info->primalCountLastVar);
builder.markInstAsPrimal(primalCounterLastVal);
builder.addPrimalValueAccessDecoration(primalCounterLastVal);
- builder.addLoopExitPrimalValueDecoration(loopInst, region->diffCountParam, primalCounterLastVal);
+ builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal);
}
+
+ // Try to infer maximum possible number of iterations.
+ // (only regions whose intermediates are used outside their region
+ // require a maximum count, so we may see some unresolved regions
+ // without any issues)
+ //
+ tryInferMaxIndex(region, info);
}
}
@@ -511,11 +567,17 @@ struct DiffUnzipPass
}
}
- void processIndexedFwdBlock(IRBlock* fwdBlock)
+ List<IndexTrackingInfo*> getIndexInfoList(IRBlock* block)
{
- if (!isBlockIndexed(fwdBlock))
- return;
+ List<IndexTrackingInfo*> indices;
+ for (auto region : indexRegionMap->getAllAncestorRegions(block))
+ indices.add((IndexTrackingInfo*) indexInfoMap[region].GetValue());
+ return indices;
+ }
+
+ void processIndexedFwdBlock(IRBlock* fwdBlock)
+ {
// Grab first primal block.
IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[fwdBlock->getParent()->getFirstBlock()->getNextBlock()]);
@@ -625,12 +687,7 @@ struct DiffUnzipPass
}
// Build list of indices that the value's block is affected by.
- List<IndexedRegion*> regions;
- {
- IndexedRegion* region = indexRegionMap[valueBlock];
- for (; region; region = region->parent)
- regions.add(region);
- }
+ List<IndexTrackingInfo*> indices = getIndexInfoList(valueBlock);
// 3. Emit an array to top-level to allocate space.
@@ -638,22 +695,22 @@ struct DiffUnzipPass
IRType* storageType = valueType;
- for (auto region : regions)
+ for (auto index : indices)
{
- SLANG_ASSERT(region->status == IndexedRegion::CountStatus::Static);
- SLANG_ASSERT(region->maxIters >= 0);
+ SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static);
+ SLANG_ASSERT(index->maxIters >= 0);
storageType = builder.getArrayType(
storageType,
builder.getIntValue(
builder.getUIntType(),
- region->maxIters + 1));
+ index->maxIters + 1));
}
// Reverse the list since the indices need to be
// emitted in reverse order.
//
- regions.reverse();
+ indices.reverse();
auto storageVar = builder.emitVar(storageType);
if (isIntermediateContext)
@@ -673,14 +730,14 @@ struct DiffUnzipPass
IRInst* storeAddr = storageVar;
IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
- for (auto region : regions)
+ for (auto index : indices)
{
currType = as<IRArrayType>(currType)->getElementType();
storeAddr = builder.emitElementAddress(
builder.getPtrType(currType),
storeAddr,
- region->primalCountParam);
+ index->primalCountParam);
}
if (!isIntermediateContext)
@@ -730,23 +787,17 @@ struct DiffUnzipPass
// TODO: Probably a good idea to do this ahead of time for
// all blocks.
//
- List<IndexedRegion*> useBlockRegions;
- {
- IndexedRegion* region = indexRegionMap.ContainsKey(useBlock) ?
- (IndexedRegion*)indexRegionMap[useBlock] : nullptr;
- for (; region; region = region->parent)
- useBlockRegions.add(region);
- }
+ List<IndexTrackingInfo*> useBlockIndices = getIndexInfoList(useBlock);
- for (auto region : regions)
+ for (auto index : indices)
{
currType = as<IRArrayType>(currType)->getElementType();
- if (useBlockRegions.contains(region))
+ if (useBlockIndices.contains(index))
{
// If the use-block is under the same region, use the
// differential counter variable
//
- auto diffCounterCurrValue = region->diffCountParam;
+ auto diffCounterCurrValue = index->diffCountParam;
loadAddr = builder.emitElementAddress(
builder.getPtrType(currType),
@@ -758,7 +809,7 @@ struct DiffUnzipPass
// If the use-block is outside this region, use the
// last available value (by indexing with primal counter minus 1)
//
- auto primalCounterCurrValue = builder.emitLoad(region->primalCountLastVar);
+ auto primalCounterCurrValue = builder.emitLoad(index->primalCountLastVar);
auto primalCounterLastValue = builder.emitSub(
primalCounterCurrValue->getDataType(),
primalCounterCurrValue,
@@ -1079,156 +1130,6 @@ struct DiffUnzipPass
}
}
- bool isBlockIndexed(IRBlock* block)
- {
- return indexRegionMap.ContainsKey(block) && indexRegionMap[block] != nullptr;
- }
-
- void addNewIndex(IRLoop* targetLoop)
- {
- // Create indexed region without a parent for now.
- // The parent will be filled in during propagation.
- //
- IndexedRegion* region = new IndexedRegion(
- nullptr,
- as<IRBlock>(targetLoop->getParent()),
- targetLoop->getTargetBlock(),
- targetLoop->getBreakBlock(),
- targetLoop->getContinueBlock());
-
- indexRegionMap[targetLoop->getTargetBlock()] = region;
- indexRegions.add(region);
- }
-
- // Deallocate regions
- void cleanupIndexRegionInfo()
- {
- for (auto region : indexRegions)
- {
- delete region;
- }
-
- indexRegions.clear();
- indexRegionMap.Clear();
- }
-
- void propagateAllIndexRegions()
- {
-
-
- // Load up the starting block of every region into
- // initial worklist.
- //
- List<IRBlock*> workList;
- HashSet<IRBlock*> workSet;
- for (auto region : indexRegions)
- {
- workList.add(region->firstBlock);
- workSet.Add(region->firstBlock);
- }
-
- // Keep propagating from initial work list to predecessors
- // Add blocks to work list if their region assignment has changed
- // Add the beginning blocks for complete regions if region parent has changed.
- //
- while (workList.getCount() > 0)
- {
- auto block = workList.getLast();
- workList.removeLast();
- workSet.Remove(block);
-
- HashSet<IRBlock*> successors;
-
- for (auto successor : block->getSuccessors())
- {
- if (successors.Contains(successor))
- continue;
-
- if (propagateIndexRegion(block, successor))
- {
- if (!workSet.Contains(successor))
- {
- workList.add(successor);
- workSet.Add(successor);
- }
-
- // Do we have an index region for the successor, which is
- // also the starting block of that region?
- // Then the change might have been the addition of
- // a parent node. Add the break block so the
- // change can be propagated further.
- //
- if (isBlockIndexed(successor))
- {
- IndexedRegion* succRegion = indexRegionMap[successor];
- if (succRegion->firstBlock == successor)
- {
- if (!workSet.Contains(succRegion->breakBlock))
- {
- workList.add(succRegion->breakBlock);
- workSet.Add(succRegion->breakBlock);
- }
- }
- }
- }
-
- successors.Add(successor);
- }
- }
- }
-
- bool setIndexRegion(IRBlock* block, IndexedRegion* region)
- {
- if (!region) return false;
-
- if (indexRegionMap.ContainsKey(block)
- && indexRegionMap[block] == region)
- return false;
-
- indexRegionMap[block] = region;
- return true;
- }
-
- bool propagateIndexRegion(IRBlock* srcBlock, IRBlock* nextBlock)
- {
- // Is the current region indexed?
- // If not, there's nothing to propagate
- //
- if (!isBlockIndexed(srcBlock))
- return false;
-
- IndexedRegion* region = indexRegionMap[srcBlock];
-
- // If the target's index is already resolved,
- // check if it's a sub-region.
- //
- if (isBlockIndexed(nextBlock))
- {
- IndexedRegion* nextRegion = indexRegionMap[nextBlock];
-
- // If we're at the first block of a region,
- // set current region as continue-region's
- // parent.
- //
- if (nextBlock == nextRegion->firstBlock && nextRegion != region)
- {
- nextRegion->parent = region;
- return true;
- }
-
- return false;
- }
-
- // If we're at the break block, move up to the parent index.
- if (nextBlock == region->breakBlock)
- return setIndexRegion(nextBlock, region->parent);
-
- // If none of the special cases hit, copy the
- // current region to the next block.
- //
- return setIndexRegion(nextBlock, region);
- }
-
// Splitting a loop is one of the trickiest parts of the unzip pass.
// Thus far, we've been dealing with blocks that are only run once, so we
// could arbitrarily move intermediate instructions to other blocks since they are
@@ -1247,9 +1148,6 @@ struct DiffUnzipPass
auto continueBlock = mixedLoop->getContinueBlock();
auto nextBlock = mixedLoop->getTargetBlock();
- // Push a new index.
- addNewIndex(mixedLoop);
-
// Split args.
List<IRInst*> primalArgs;
List<IRInst*> diffArgs;