summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-27 23:42:06 -0500
committerGitHub <noreply@github.com>2023-02-27 23:42:06 -0500
commit10e2d9c7c532c204f26bb2c9f383f21b121b2ff2 (patch)
tree9ae0dd84b505a7ecd3fb45de9dbde74f8dd1ebe9 /source
parenta3ba22b51c371d5a20d61aa4e35233ba4f4f68db (diff)
More fixes for reverse-mode on complicated loops (#2675)
* Multiple fixes to get various loop tests to pass. * Create reverse-nested-loop.slang * Fix for variables becoming inaccessible during cfg normalization * Removed comments and moved break-branch-normalization to eliminateMultiLevelBreaks * Fix. * Override liveness tests
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp57
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h12
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp54
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp10
5 files changed, 124 insertions, 13 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 9116f67e9..2199b0771 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -43,6 +43,7 @@ struct BreakableRegionInfo
{
IRVar* breakVar;
IRBlock* breakBlock;
+ IRBlock* headerBlock;
};
struct CFGNormalizationContext
@@ -57,13 +58,39 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst)
// For now, we're going to naively assume the next block is the condition block.
// Add in more support for more cases as necessary.
//
-
+
auto firstBlock = loopInst->getTargetBlock();
- auto ifElse = as<IRIfElse>(firstBlock->getTerminator());
- SLANG_RELEASE_ASSERT(ifElse);
+ if (as<IRIfElse>(firstBlock->getTerminator()))
+ {
+ return firstBlock;
+ }
+ else
+ {
+ // If there isn't a condition we need to make one with a dummy condition that
+ // always evaluates to true
+ //
- return firstBlock;
+ IRBuilder condBuilder(loopInst->getModule());
+
+ auto condBlock = condBuilder.emitBlock();
+ condBlock->insertAfter(as<IRBlock>(loopInst->getParent()));
+
+ // Make loop go into the condition block
+ firstBlock->replaceUsesWith(condBlock);
+
+ // Emit a condition: true side goes to the loop body, and
+ // false side goes into the break block.
+ //
+ condBuilder.setInsertInto(condBlock);
+ condBuilder.emitIfElse(
+ condBuilder.getBoolValue(true),
+ firstBlock,
+ loopInst->getBreakBlock(),
+ firstBlock);
+
+ return condBlock;
+ }
}
struct CFGNormalizationPass
@@ -133,6 +160,20 @@ struct CFGNormalizationPass
return false;
}
+ void _moveVarsToRegionHeader(BreakableRegionInfo* region, IRBlock* block)
+ {
+ for (auto child = block->getFirstChild(); child;)
+ {
+ auto nextChild = child->getNextInst();
+
+ if (as<IRVar>(child))
+ {
+ child->insertBefore(region->headerBlock->getTerminator());
+ }
+
+ child = nextChild;
+ }
+ }
RegionEndpoint getNormalizedRegionEndpoint(
BreakableRegionInfo* parentRegion,
@@ -140,6 +181,7 @@ struct CFGNormalizationPass
List<IRBlock*> afterBlocks)
{
IRBlock* currentBlock = entryBlock;
+ _moveVarsToRegionHeader(parentRegion, currentBlock);
// By default a region starts off with the 'base' control flow
// and not in the 'break' control flow
@@ -343,6 +385,8 @@ struct CFGNormalizationPass
SLANG_UNEXPECTED("Unhandled control flow inst");
break;
}
+
+ _moveVarsToRegionHeader(parentRegion, currentBlock);
}
// Resolve all intermediate after-blocks
@@ -399,6 +443,7 @@ struct CFGNormalizationPass
{
BreakableRegionInfo info;
info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock();
+ info.headerBlock = as<IRBlock>(branchInst->getParent());
// Emit var into parent block.
builder.setInsertBefore(
@@ -426,7 +471,7 @@ struct CFGNormalizationPass
&info,
firstLoopBlock,
List<IRBlock*>(info.breakBlock));
-
+
// Should not be empty.. but check anyway
SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty);
@@ -495,7 +540,7 @@ struct CFGNormalizationPass
// Add a test for the break variable into the condition.
auto cond = ifElse->getCondition();
- builder.setInsertAfter(cond);
+ builder.setInsertBefore(ifElse);
auto breakFlagVal = builder.emitLoad(info.breakVar);
// Need to invert the break flag if the loop is
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 640f516ed..709968f77 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -14,6 +14,7 @@
#include "slang-ir-init-local-var.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-dominators.h"
+#include "slang-ir-loop-unroll.h"
namespace Slang
{
@@ -583,6 +584,9 @@ namespace Slang
{
convertFuncToSingleReturnForm(func->getModule(), func);
}
+
+ eliminateContinueBlocksInFunc(func->getModule(), func);
+
eliminateMultiLevelBreakForFunc(func->getModule(), func);
IRCFGNormalizationPass cfgPass = {this->getSink()};
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index a30826370..3678bd4b3 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -409,16 +409,14 @@ struct DiffUnzipPass
for (auto region : indexRegions)
{
// Grab first primal block.
- IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]);
- builder.setInsertBefore(firstPrimalBlock->getTerminator());
+ IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]);
+ builder.setInsertBefore(primalInitBlock->getTerminator());
// Make variable in the top-most block (so it's visible to diff blocks)
region->primalCountLastVar = builder.emitVar(builder.getIntType());
builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
- {
- IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]);
-
+ {
auto primalCondBlock = as<IRUnconditionalBranch>(
primalInitBlock->getTerminator())->getTargetBlock();
builder.setInsertBefore(primalCondBlock->getTerminator());
@@ -664,8 +662,8 @@ struct DiffUnzipPass
storageVar);
// 4. Store current value into the array and replace uses with a load.
- // TODO: If an index is missing, use the 'last' value of the primal index.
-
+ // If an index is missing, use the 'last' value of the primal index.
+
{
if (!isIntermediateContext)
setInsertAfterOrdinaryInst(&builder, valueToStore);
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index 3618e1326..e73fae982 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -175,8 +175,62 @@ struct EliminateMultiLevelBreakContext
}
};
+
+ void insertBlockBetween(IRBlock* block, IRBlock* successor)
+ {
+ IRBuilder builder(block->getModule());
+
+ List<IRUse*> relevantUses;
+ for (auto use = successor->firstUse; use; use = use->nextUse)
+ {
+ if (auto terminator = as<IRTerminatorInst>(use->getUser()))
+ {
+ if (as<IRBlock>(terminator->getParent()) == block)
+ {
+ relevantUses.add(use);
+ }
+ }
+ }
+
+ SLANG_RELEASE_ASSERT(relevantUses.getCount() == 1);
+
+ builder.insertBlockAlongEdge(block->getModule(), IREdge(relevantUses[0]));
+ }
+
+ bool normalizeBranchesIntoBreakBlocks(IRGlobalValueWithCode* func)
+ {
+ bool changed = false;
+
+ List<IRBlock*> workList;
+
+ for (auto block : func->getBlocks())
+ workList.add(block);
+
+ for (auto block : workList)
+ {
+ if (auto loop = as<IRLoop>(block->getTerminator()))
+ {
+ auto breakBlock = loop->getBreakBlock();
+
+ for (auto predecessor : breakBlock->getPredecessors())
+ {
+ if (!as<IRUnconditionalBranch>(predecessor->getTerminator()))
+ {
+ insertBlockBetween(predecessor, breakBlock);
+ changed = true;
+ }
+ }
+ }
+ }
+
+ return changed;
+ }
+
void processFunc(IRGlobalValueWithCode* func)
{
+
+ normalizeBranchesIntoBreakBlocks(func);
+
// If func does not have any multi-level breaks, return.
{
FuncContext funcInfo;
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index 2f689ebde..4f9b8d272 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -603,6 +603,16 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst)
builder.setInsertInto(innerBreakableRegionBreakBlock);
_moveParams(innerBreakableRegionBreakBlock, continueBlock);
builder.emitBranch(continueBlock);
+
+ // If the original loop can be executed up to N times, the new loop may be executed
+ // upto N+1 times (although most insts are skipped in the last traversal)
+ //
+ if (auto maxItersDecoration = loopInst->findDecoration<IRLoopMaxItersDecoration>())
+ {
+ auto maxIters = maxItersDecoration->getMaxIters();
+ maxItersDecoration->removeAndDeallocate();
+ builder.addLoopMaxItersDecoration(loopInst, maxIters + 1);
+ }
}
void eliminateContinueBlocksInFunc(IRModule* module, IRGlobalValueWithCode* func)