summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-01-30 11:46:36 -0500
committerGitHub <noreply@github.com>2023-01-30 08:46:36 -0800
commit134dd7eb26fc7988ae13559d276cbf337b4b9d27 (patch)
tree35bd06e6bebb4518bca805e14e85f8f9ef4341c6 /source
parent4a66e9729175a89833e5db784bb64e6a7f60cdf2 (diff)
Overhauled reverse-mode control flow handling (#2608)
* Added switch-case support; fixed non-diff parameter transposition * Made region propagation much more robust. Partial loop unzip implementation * WIP: Added most loop handling code, and a test. Still untested * Added CFG Normalization pass + CFG Reversal Pass + Loop Unzipping + most loop transcription * Add single-iter-loop test. * proj files * removed comments * Update reverse-loop.slang * Removed out-of-date code * Disabled IR validation during constructSSA phase of normalizeCFG. constructSSA now reuses sharedBuilder * Moved normalizeCFG() call to prepareFuncForBackwardDiff()
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp629
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.h26
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h694
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h532
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h21
-rw-r--r--source/slang/slang-ir-ssa.cpp23
-rw-r--r--source/slang/slang-ir-ssa.h2
-rw-r--r--source/slang/slang-ir-validate.cpp7
-rw-r--r--source/slang/slang-ir.cpp26
12 files changed, 1683 insertions, 287 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
new file mode 100644
index 000000000..4e0a413db
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -0,0 +1,629 @@
+// slang-ir-autodiff-cfg-norm.cpp
+#include "slang-ir-autodiff-cfg-norm.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-ssa.h"
+
+#include "slang-ir-validate.h"
+
+namespace Slang
+{
+
+struct RegionEndpoint
+{
+ bool inBreakRegion = false;
+ bool inBaseRegion = false;
+
+ IRBlock* exitBlock = nullptr;
+
+ bool isRegionEmpty = false;
+
+ RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion) :
+ exitBlock(exitBlock),
+ inBreakRegion(inBreakRegion),
+ inBaseRegion(inBaseRegion),
+ isRegionEmpty(false)
+ { }
+
+ RegionEndpoint(
+ IRBlock* exitBlock,
+ bool inBreakRegion,
+ bool inBaseRegion,
+ bool isRegionEmpty) :
+ exitBlock(exitBlock),
+ inBreakRegion(inBreakRegion),
+ inBaseRegion(inBaseRegion),
+ isRegionEmpty(isRegionEmpty)
+ { }
+
+ RegionEndpoint()
+ { }
+};
+
+struct BreakableRegionInfo
+{
+ IRVar* breakVar;
+ IRBlock* breakBlock;
+};
+
+struct CFGNormalizationContext
+{
+ SharedIRBuilder* sharedBuilder;
+ DiagnosticSink* sink;
+};
+
+
+IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst)
+{
+ // For now, we're going to naively assume the next block is the condition block.
+ // Add in more support for more cases as necessary.
+ //
+
+ auto firstBlock = loopInst->getTargetBlock();
+
+ auto ifElse = as<IRIfElse>(firstBlock->getTerminator());
+ SLANG_RELEASE_ASSERT(ifElse);
+
+ return firstBlock;
+}
+
+struct CFGNormalizationPass
+{
+ CFGNormalizationContext cfgContext;
+
+ CFGNormalizationPass(CFGNormalizationContext ctx) :
+ cfgContext(ctx)
+ { }
+
+ void replaceBreakWithAfterBlock(
+ IRBuilder* builder,
+ BreakableRegionInfo* info,
+ IRBlock* currBlock,
+ IRBlock* afterBlock,
+ IRBlock* parentAfterBlock)
+ {
+ SLANG_ASSERT(as<IRUnconditionalBranch>(currBlock->getTerminator()));
+
+ currBlock->getTerminator()->removeAndDeallocate();
+
+ builder->setInsertInto(currBlock);
+
+ builder->emitStore(info->breakVar, builder->getBoolValue(false));
+ builder->emitBranch(afterBlock);
+
+ // Is after-block unreachable?
+ if (auto unreachInst = as<IRUnreachable>(afterBlock->getFirstOrdinaryInst()))
+ {
+ // Link it to the parentAfterBlock.
+ builder->setInsertInto(afterBlock);
+ unreachInst->removeAndDeallocate();
+
+ /*
+ HashSet<IRBlock*> predecessorSet;
+ for (auto predecessor : parentAfterBlock->getPredecessors())
+ predecessorSet.Add(predecessor);
+
+ SLANG_ASSERT(predecessorSet.Count() <= 1);
+ */
+
+ builder->emitBranch(parentAfterBlock);
+ }
+ }
+
+ IRBlock* getUnconditionalTarget(RegionEndpoint endpoint)
+ {
+ if (!endpoint.isRegionEmpty)
+ {
+ auto branchInst = as<IRUnconditionalBranch>(endpoint.exitBlock->getTerminator());
+ SLANG_ASSERT(branchInst);
+
+ return branchInst->getTargetBlock();
+ }
+ else
+ {
+ return endpoint.exitBlock;
+ }
+ }
+
+ IRBlock* maybeGetUnconditionalTarget(IRBlock* block)
+ {
+ auto branchInst = as<IRUnconditionalBranch>(block->getTerminator());
+
+ return branchInst ? branchInst->getTargetBlock() : nullptr;
+ }
+
+
+ bool isSuccessorBlock(IRBlock* baseBlock, IRBlock* succBlock)
+ {
+ for (auto successor : baseBlock->getSuccessors())
+ if (successor == succBlock)
+ return true;
+
+ return false;
+ }
+
+
+ RegionEndpoint getNormalizedRegionEndpoint(
+ BreakableRegionInfo* parentRegion,
+ IRBlock* entryBlock,
+ List<IRBlock*> afterBlocks)
+ {
+ IRBlock* currentBlock = entryBlock;
+
+ // By default a region starts off with the 'base' control flow
+ // and not in the 'break' control flow
+ // It is the job of the *caller* to make sure the break flow
+ // does not reach this point.
+ //
+ bool currBreakRegion = false;
+ bool currBaseRegion = true;
+
+ // Detect the trivial case. The current block is alredy
+ // in the next region => this region is empty.
+ //
+ if (afterBlocks.contains(currentBlock))
+ return RegionEndpoint(currentBlock, currBreakRegion, currBaseRegion, true);
+
+ IRBuilder builder(cfgContext.sharedBuilder);
+
+ List<IRBlock*> pendingAfterBlocks;
+
+ IRBlock* parentAfterBlock = afterBlocks[0];
+
+ // Follow this thread of execution till we hit an
+ // acceptable after block.
+ //
+ while (!afterBlocks.contains(maybeGetUnconditionalTarget(currentBlock)))
+ {
+ // Check the terminator.
+ auto terminator = currentBlock->getTerminator();
+ switch (terminator->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ {
+ auto targetBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock();
+ currentBlock = targetBlock;
+ break;
+ }
+
+ case kIROp_ifElse:
+ {
+ auto ifElse = as<IRIfElse>(terminator);
+
+ // Special case. One of the branches will
+ // lead back to the condition.
+ //
+ SLANG_ASSERT(ifElse->getAfterBlock() != parentRegion->breakBlock);
+
+ auto trueEndPoint = getNormalizedRegionEndpoint(
+ parentRegion,
+ ifElse->getTrueBlock(),
+ List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock));
+
+ auto falseEndPoint = getNormalizedRegionEndpoint(
+ parentRegion,
+ ifElse->getFalseBlock(),
+ List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock));
+
+ auto trueTargetBlock = getUnconditionalTarget(trueEndPoint);
+ auto falseTargetBlock = getUnconditionalTarget(falseEndPoint);
+
+ auto afterBlock = ifElse->getAfterBlock();
+
+ // Trivial case, both end-points branch into the after block
+ if (trueTargetBlock == afterBlock &&
+ falseTargetBlock == afterBlock)
+ {
+ currentBlock = afterBlock;
+ break;
+ }
+
+ auto afterBreakRegion = false;
+ auto afterBaseRegion = false;
+
+ if (trueTargetBlock == parentRegion->breakBlock)
+ {
+ // Branch into after block (and set break variable)
+ replaceBreakWithAfterBlock(
+ &builder,
+ parentRegion,
+ trueEndPoint.exitBlock,
+ afterBlock,
+ parentAfterBlock);
+
+ // If this branch breaks, then the after-block
+ // definitely has break-flow.
+ //
+ afterBreakRegion = true;
+ }
+ else
+ {
+ // If this branch naturally branches into our
+ // after-block, copy whatever flags the endpoints
+ // have.
+ //
+ afterBreakRegion = afterBreakRegion || trueEndPoint.inBreakRegion;
+ afterBaseRegion = afterBaseRegion || trueEndPoint.inBaseRegion;
+ }
+
+ if (falseTargetBlock == parentRegion->breakBlock)
+ {
+ // Branch into after block (and set break variable)
+ replaceBreakWithAfterBlock(
+ &builder,
+ parentRegion,
+ falseEndPoint.exitBlock,
+ afterBlock,
+ parentAfterBlock);
+
+ // If this branch breaks, then the after-block
+ // definitely has break-flow.
+ //
+ afterBreakRegion = true;
+ }
+ else
+ {
+ // If this branch naturally branches into our
+ // after-block, copy whatever flags the endpoints
+ // have.
+ //
+ afterBreakRegion = afterBreakRegion || falseEndPoint.inBreakRegion;
+ afterBaseRegion = afterBaseRegion || falseEndPoint.inBaseRegion;
+ }
+
+ // TODO: For now, we're being overly cautious and assuming
+ // the after region might have something to execute.
+ // Ideally, we should check if the block is empty, and
+ // hold off on splitting until we encounter non-empty
+ // blocks.
+ //
+ afterBaseRegion = true;
+
+ // Do we need to split the after region?
+ if (afterBaseRegion && afterBreakRegion)
+ {
+ // We could arrive at the after-block before or
+ // after encountering a break statement.
+ // To handle this, we'll split the flow by checking the break flag
+ //
+ builder.setInsertAfter(afterBlock);
+
+ auto preAfterSplitBlock = builder.emitBlock();
+ preAfterSplitBlock->insertBefore(afterBlock);
+
+ auto afterSplitBlock = builder.emitBlock();
+ afterSplitBlock->insertBefore(afterBlock);
+
+ afterBlock->replaceUsesWith(preAfterSplitBlock);
+
+ builder.setInsertInto(preAfterSplitBlock);
+ builder.emitBranch(afterSplitBlock);
+
+ // Converging block for the split that we're making.
+ auto afterSplitAfterBlock = builder.emitBlock();
+
+ builder.setInsertInto(afterSplitBlock);
+ auto breakFlagValue = builder.emitLoad(parentRegion->breakVar);
+
+ builder.emitIfElse(
+ breakFlagValue,
+ afterBlock,
+ afterSplitAfterBlock,
+ afterSplitAfterBlock);
+
+ // At this point, we need to place afterSplitAfterBlock between
+ // at the _end_ of this region, but we aren't there yet (and
+ // don't know which block is the end of this region)
+ // Therefore, we'll defer this step and add it to a list for later.
+ //
+ pendingAfterBlocks.add(afterSplitAfterBlock);
+
+ // Update current block.
+ currentBlock = afterBlock;
+ afterBreakRegion = false;
+ afterBaseRegion = true;
+ }
+
+ currentBlock = afterBlock;
+ currBreakRegion = afterBreakRegion;
+ currBaseRegion = afterBaseRegion;
+ break;
+ }
+
+ case kIROp_loop:
+ {
+ auto breakBlock = normalizeBreakableRegion(terminator);
+
+ // Advance to the break block (no updates to the control flags)
+ currentBlock = breakBlock;
+ break;
+ }
+
+ default:
+ // Do proper diagnosing
+ SLANG_UNEXPECTED("Unhandled control flow inst");
+ break;
+ }
+ }
+
+ // Resolve all intermediate after-blocks
+ pendingAfterBlocks.reverse();
+
+ for (auto block : pendingAfterBlocks)
+ {
+ builder.setInsertInto(block);
+ auto nextRegionBlock = maybeGetUnconditionalTarget(currentBlock);
+ SLANG_ASSERT(nextRegionBlock);
+
+ builder.emitBranch(nextRegionBlock);
+
+ builder.setInsertInto(currentBlock);
+ currentBlock->getTerminator()->removeAndDeallocate();
+ builder.emitBranch(block);
+
+ block->insertAfter(currentBlock);
+
+ currentBlock = block;
+ currBaseRegion = true;
+ currBreakRegion = true;
+ }
+
+ return RegionEndpoint(currentBlock, currBreakRegion, currBaseRegion);
+ }
+
+ HashSet<IRBlock*> getPredecessorSet(IRBlock* block)
+ {
+ HashSet<IRBlock*> predecessorSet;
+ for (auto predecessor : block->getPredecessors())
+ predecessorSet.Add(predecessor);
+
+ return predecessorSet;
+ }
+
+ bool isLoopTrivial(IRLoop* loop)
+ {
+ // Get 'looping' block (first block in loop)
+ auto firstLoopBlock = loop->getTargetBlock();
+
+ // If we only have one predecessor, the loop is trivial.
+ return (getPredecessorSet(firstLoopBlock).Count() == 1);
+ }
+
+ IRBlock* normalizeBreakableRegion(
+ IRInst* branchInst)
+ {
+ IRBuilder builder(cfgContext.sharedBuilder);
+
+ switch (branchInst->getOp())
+ {
+ case kIROp_loop:
+ {
+ BreakableRegionInfo info;
+ info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock();
+
+ // Emit var into parent block.
+ builder.setInsertBefore(
+ as<IRBlock>(branchInst->getParent())->getTerminator());
+
+ // Create and initialize break var to true
+ // true -> no break yet.
+ // false -> atleast one break statement hit.
+ //
+ info.breakVar = builder.emitVar(builder.getBoolType());
+ builder.emitStore(info.breakVar, builder.getBoolValue(true));
+
+ // If the loop is trivial (i.e. single iteration, with no
+ // edges actually in a loop), we're just going to remove
+ // it.. (we can do this, because the normalization pass
+ // will transform any break and continue statements)
+ //
+ if (isLoopTrivial(as<IRLoop>(branchInst)))
+ {
+ auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock();
+ auto terminator = firstLoopBlock->getTerminator();
+
+ // We really shouldn't see a conditional branch on a trivial loop
+ // but if we hit this assert, handle this case.
+ //
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(terminator));
+
+ // Normalize the region from the first loop block till break.
+ auto preBreakEndPoint = getNormalizedRegionEndpoint(
+ &info,
+ firstLoopBlock,
+ List<IRBlock*>(info.breakBlock));
+
+ // Should not be empty.. but check anyway
+ SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty);
+
+ // Quick consistency check.. preBreakEndPoint should be
+ // branching into break block.
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
+ preBreakEndPoint.exitBlock->getTerminator())->getTargetBlock() == info.breakBlock);
+
+ auto currentBlock = branchInst->getParent();
+
+ // Now get rid of the loop inst and replace with unconditional branch.
+ branchInst->removeAndDeallocate();
+ builder.setInsertInto(currentBlock);
+ builder.emitBranch(firstLoopBlock);
+
+ return info.breakBlock;
+ }
+
+ auto condBlock = getOrCreateTopLevelCondition(as<IRLoop>(branchInst));
+
+ auto ifElse = as<IRIfElse>(condBlock->getTerminator());
+
+ auto trueEndPoint = getNormalizedRegionEndpoint(
+ &info,
+ ifElse->getTrueBlock(),
+ List<IRBlock*>(condBlock, info.breakBlock));
+
+ auto falseEndPoint = getNormalizedRegionEndpoint(
+ &info,
+ ifElse->getFalseBlock(),
+ List<IRBlock*>(condBlock, info.breakBlock));
+
+ RegionEndpoint loopEndPoint;
+ bool isLoopOnTrueSide = true;
+
+ // First figure out which side belongs to the loop body.
+ if (isSuccessorBlock(trueEndPoint.exitBlock, condBlock))
+ {
+ loopEndPoint = trueEndPoint;
+ isLoopOnTrueSide = true;
+ }
+
+ if (isSuccessorBlock(falseEndPoint.exitBlock, condBlock))
+ {
+ loopEndPoint = falseEndPoint;
+ isLoopOnTrueSide = false;
+ }
+
+ SLANG_RELEASE_ASSERT(loopEndPoint.exitBlock);
+
+ // Special case.. the if-else of a loop needs it's
+ // after block to be pointing at the last block before
+ // it loops back to the if-else.
+ //
+ // ifElse->afterBlock.set(loopEndPoint.exitBlock);
+
+ // Does the loop endpoint have both 'break' and 'base'
+ // control flows?
+ //
+ if (loopEndPoint.inBaseRegion && loopEndPoint.inBreakRegion)
+ {
+ // Add a test for the break variable into the condition.
+ auto cond = ifElse->getCondition();
+
+ builder.setInsertAfter(cond);
+ auto breakFlagVal = builder.emitLoad(info.breakVar);
+
+ // Need to invert the break flag if the loop is
+ // on the false side.
+ //
+ if (!isLoopOnTrueSide)
+ {
+ IRInst* args[1] = {breakFlagVal};
+ breakFlagVal = builder.emitIntrinsicInst(
+ builder.getBoolType(),
+ kIROp_Not,
+ 1,
+ args);
+ }
+
+ IRInst* args[2] = {cond, breakFlagVal};
+
+ // If break-var = true, direct flow to the loop
+ // otherwise, direct flow to break
+ //
+ auto complexCond = builder.emitIntrinsicInst(
+ builder.getBoolType(),
+ kIROp_And,
+ 2,
+ args);
+
+ ifElse->condition.set(complexCond);
+ }
+
+ return info.breakBlock;
+ }
+ case kIROp_Switch:
+ {
+ auto switchInst = as<IRSwitch>(branchInst);
+
+ // SLANG_UNEXPECTED("Switch-case normalization not implemented yet.");
+ BreakableRegionInfo info;
+ info.breakBlock = as<IRSwitch>(branchInst)->getBreakLabel();
+
+ // Emit var into parent block.
+ builder.setInsertBefore(
+ as<IRBlock>(branchInst->getParent())->getTerminator());
+
+ // Create and initialize break var to true
+ // true -> no break yet.
+ // false -> atleast one break statement hit.
+ //
+ info.breakVar = builder.emitVar(builder.getBoolType());
+ builder.emitStore(info.breakVar, builder.getBoolValue(true));
+
+ // Go over case labels and normalize all sub-regions.
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++)
+ {
+ auto caseBlock = switchInst->getCaseLabel(ii);
+ auto caseEndPoint = getNormalizedRegionEndpoint(
+ &info,
+ caseBlock,
+ List<IRBlock*>(info.breakBlock)).exitBlock;
+
+ // Consistency check (if this case hits, it's probably
+ // because the switch has fall-through, which we don't support)
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
+ caseEndPoint->getTerminator())->getTargetBlock() == info.breakBlock);
+ }
+
+ auto defaultEndPoint = getNormalizedRegionEndpoint(
+ &info,
+ switchInst->getDefaultLabel(),
+ List<IRBlock*>(info.breakBlock)).exitBlock;
+
+ // Consistency check (if this case hits, it's probably
+ // because the switch has fall-through, which we don't support)
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
+ defaultEndPoint->getTerminator())->getTargetBlock() == info.breakBlock);
+
+ return info.breakBlock;
+ }
+ default:
+ break;
+ }
+
+ SLANG_UNEXPECTED("Unhandled control-flow inst");
+ }
+};
+
+void normalizeCFG(
+ IRGlobalValueWithCode* func,
+ IRCFGNormalizationPass const& options)
+{
+ // Remove phis to simplify our pass. We'll add them back in later
+ // with constructSSA.
+ //
+ eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func);
+
+ SharedIRBuilder sharedBuilder(func->getModule());
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+ CFGNormalizationContext context = {&sharedBuilder, options.sink};
+ CFGNormalizationPass cfgPass(context);
+
+ List<IRBlock*> workList;
+ workList.add(func->getFirstBlock());
+
+ while (workList.getCount() > 0)
+ {
+ auto block = workList.getLast();
+ workList.removeLast();
+
+ if (auto loop = as<IRLoop>(block->getTerminator()))
+ {
+ auto breakBlock = cfgPass.normalizeBreakableRegion(loop);
+ workList.add(breakBlock);
+ }
+ else if (auto switchCase = as<IRSwitch>(block->getTerminator()))
+ {
+ auto breakBlock = cfgPass.normalizeBreakableRegion(switchCase);
+ workList.add(breakBlock);
+ }
+ else
+ {
+ for (auto successor : block->getSuccessors())
+ workList.add(successor);
+ }
+ }
+
+ disableIRValidationAtInsert();
+ constructSSA(&sharedBuilder, func);
+ enableIRValidationAtInsert();
+}
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.h b/source/slang/slang-ir-autodiff-cfg-norm.h
new file mode 100644
index 000000000..2a39f7695
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-cfg-norm.h
@@ -0,0 +1,26 @@
+// slang-ir-autodiff-cfg-norm.h
+#pragma once
+
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+ struct IRModule;
+
+ struct IRCFGNormalizationPass
+ {
+ DiagnosticSink* sink;
+ };
+
+ /// Eliminate "break" statements from breakable regions
+ /// (loops, switch-case). This will use temporary booleans
+ /// instead of a break statement, in order to ensure all
+ /// branches inside the breakable region always have a valid
+ /// "after" block.
+ ///
+ void normalizeCFG(
+ IRGlobalValueWithCode* func,
+ IRCFGNormalizationPass const& options = IRCFGNormalizationPass());
+
+ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst);
+}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index fce2043eb..6f18a3d8a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-autodiff-cfg-norm.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-ssa-simplification.h"
@@ -16,7 +17,7 @@ namespace Slang
IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType)
{
List<IRType*> newParameterTypes;
- IRType* diffReturnType;
+ IRType* diffReturnType;
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
@@ -509,6 +510,9 @@ namespace Slang
}
eliminateMultiLevelBreakForFunc(func->getModule(), func);
+ IRCFGNormalizationPass cfgPass = {this->getSink()};
+ normalizeCFG(func);
+
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index f43206333..05a5f8f56 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -960,8 +960,8 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori
}
else
{
- auto diffType = _differentiateTypeImpl(builder, origType);
IRInst* primal = maybeCloneForPrimalInst(builder, origType);
+ auto diffType = _differentiateTypeImpl(builder, origType);
result = InstPair(primal, diffType);
}
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 0d45c6a84..901649f3c 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -7,6 +7,7 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-cfg-norm.h"
namespace Slang
{
@@ -96,37 +97,384 @@ struct DiffTransposePass
fwdBlock(fwdBlock), phiGrads(phiGrads)
{}
};
-
- struct Region
+
+ bool isBlockLastInRegion(IRBlock* block, List<IRBlock*> endBlocks)
{
- IRBlock* exitBlock;
- IRBlock* originBlock;
+ if (auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()))
+ {
+ if (endBlocks.contains(branchInst->getTargetBlock()))
+ return true;
+ else
+ return false;
+ }
+ else if (as<IRReturn>(block->getTerminator()))
+ {
+ return true;
+ }
- Region* parent;
+ return false;
+ }
+
+ List<IRInst*> getPhiGrads(IRBlock* block)
+ {
+ if (!phiGradsMap.ContainsKey(block))
+ return List<IRInst*>();
+
+ return phiGradsMap[block];
+ }
- Region() :
- exitBlock(nullptr),
- originBlock(nullptr),
- parent(nullptr)
+ struct RegionEntryPoint
+ {
+ IRBlock* revEntry;
+ IRBlock* fwdEndPoint;
+ bool isTrivial;
+
+ RegionEntryPoint(IRBlock* revEntry, IRBlock* fwdEndPoint) :
+ revEntry(revEntry),
+ fwdEndPoint(fwdEndPoint),
+ isTrivial(false)
{ }
- Region(IRBlock* exitBlock, Region* parent) :
- exitBlock(exitBlock),
- originBlock(nullptr),
- parent(parent)
+ RegionEntryPoint(IRBlock* revEntry, IRBlock* fwdEndPoint, bool isTrivial) :
+ revEntry(revEntry),
+ fwdEndPoint(fwdEndPoint),
+ isTrivial(isTrivial)
{ }
+ };
+
+ IRBlock* getUniquePredecessor(IRBlock* block)
+ {
+ HashSet<IRBlock*> predecessorSet;
+ for (auto predecessor : block->getPredecessors())
+ predecessorSet.Add(predecessor);
+
+ SLANG_ASSERT(predecessorSet.Count() == 1);
+
+ return (*predecessorSet.begin());
+ }
+
+ RegionEntryPoint reverseCFGRegion(IRBlock* block, List<IRBlock*> endBlocks)
+ {
+ IRBlock* revBlock = revBlockMap[block];
- void finish(IRBlock* block)
+ if (endBlocks.contains(block))
{
- SLANG_ASSERT(!this->originBlock);
- this->originBlock = block;
+ return RegionEntryPoint(revBlock, block, true);
+ }
+
+ // We shouldn't already have a terminator for this block
+ SLANG_ASSERT(revBlock->getTerminator() == nullptr);
+
+ IRBuilder builder(autodiffContext->sharedBuilder);
+
+ auto currentBlock = block;
+ while (!isBlockLastInRegion(currentBlock, endBlocks))
+ {
+ auto terminator = currentBlock->getTerminator();
+ switch(terminator->getOp())
+ {
+ case kIROp_Return:
+ return RegionEntryPoint(revBlockMap[currentBlock], nullptr);
+
+ case kIROp_unconditionalBranch:
+ {
+ auto branchInst = as<IRUnconditionalBranch>(terminator);
+ auto nextBlock = as<IRBlock>(branchInst->getTargetBlock());
+ IRBlock* nextRevBlock = revBlockMap[nextBlock];
+ IRBlock* currRevBlock = revBlockMap[currentBlock];
+
+ SLANG_ASSERT(nextRevBlock->getTerminator() == nullptr);
+ builder.setInsertInto(nextRevBlock);
+
+ builder.emitBranch(currRevBlock,
+ getPhiGrads(nextBlock).getCount(),
+ getPhiGrads(nextBlock).getBuffer());
+
+
+ currentBlock = nextBlock;
+ break;
+ }
+
+ case kIROp_ifElse:
+ {
+ auto ifElse = as<IRIfElse>(terminator);
+
+ auto trueBlock = ifElse->getTrueBlock();
+ auto falseBlock = ifElse->getFalseBlock();
+ auto afterBlock = ifElse->getAfterBlock();
+
+ auto revTrueRegionInfo = reverseCFGRegion(
+ trueBlock,
+ List<IRBlock*>(afterBlock));
+ auto revFalseRegionInfo = reverseCFGRegion(
+ falseBlock,
+ List<IRBlock*>(afterBlock));
+ //bool isTrueTrivial = (trueBlock == afterBlock);
+ //bool isFalseTrivial = (falseBlock == afterBlock);
+
+ IRBlock* revCondBlock = revBlockMap[afterBlock];
+ SLANG_ASSERT(revCondBlock->getTerminator() == nullptr);
+
+
+ IRBlock* revTrueEntryBlock = revTrueRegionInfo.revEntry;
+ IRBlock* revFalseEntryBlock = revFalseRegionInfo.revEntry;
+
+ IRBlock* revTrueExitBlock = revBlockMap[trueBlock];
+ IRBlock* revFalseExitBlock = revBlockMap[falseBlock];
+
+ auto phiGrads = getPhiGrads(afterBlock);
+ if (phiGrads.getCount() > 0)
+ {
+ revTrueEntryBlock = insertPhiBlockBefore(revTrueEntryBlock, phiGrads);
+ revFalseEntryBlock = insertPhiBlockBefore(revFalseEntryBlock, phiGrads);
+ }
+
+ IRBlock* revAfterBlock = revBlockMap[currentBlock];
+
+ builder.setInsertInto(revCondBlock);
+ builder.emitIfElse(
+ ifElse->getCondition(),
+ revTrueEntryBlock,
+ revFalseEntryBlock,
+ revAfterBlock);
+
+ if (!revTrueRegionInfo.isTrivial)
+ {
+ builder.setInsertInto(revTrueExitBlock);
+ SLANG_ASSERT(revTrueExitBlock->getTerminator() == nullptr);
+ builder.emitBranch(
+ revAfterBlock,
+ getPhiGrads(trueBlock).getCount(),
+ getPhiGrads(trueBlock).getBuffer());
+ }
+
+ if (!revFalseRegionInfo.isTrivial)
+ {
+ builder.setInsertInto(revFalseExitBlock);
+ SLANG_ASSERT(revFalseExitBlock->getTerminator() == nullptr);
+ builder.emitBranch(
+ revAfterBlock,
+ getPhiGrads(falseBlock).getCount(),
+ getPhiGrads(falseBlock).getBuffer());
+ }
+
+ currentBlock = afterBlock;
+ break;
+ }
+
+ case kIROp_loop:
+ {
+ auto loop = as<IRLoop>(terminator);
+
+ auto firstLoopBlock = loop->getTargetBlock();
+ auto breakBlock = loop->getBreakBlock();
+
+ auto condBlock = getOrCreateTopLevelCondition(loop);
+
+ auto ifElse = as<IRIfElse>(condBlock->getTerminator());
+
+ auto trueBlock = ifElse->getTrueBlock();
+ auto falseBlock = ifElse->getFalseBlock();
+
+ auto trueRegionInfo = reverseCFGRegion(
+ trueBlock,
+ List<IRBlock*>(breakBlock, condBlock));
+
+ auto falseRegionInfo = reverseCFGRegion(
+ falseBlock,
+ List<IRBlock*>(breakBlock, condBlock));
+
+ auto preCondRegionInfo = reverseCFGRegion(
+ firstLoopBlock,
+ List<IRBlock*>(condBlock));
+
+ // assume loop[next] -> cond can be a region and reverse it.
+ // assume cond[false] -> break can be a region and reverse it.
+ // assume cond[true] -> cond can be a region and reverse it.
+ // rev-loop = rev[break]
+ // rev-cond = rev[cond]
+ // rev-cond[true] -> entry of (cond[true] -> cond)
+ // rev-cond[false] -> entry of (loop[next] -> cond)
+ // exit of (cond[false]->break) branches into rev-cond
+ // rev-loop[next] -> entry of (cond[false] -> break)
+ // exit of (cond[true] -> cond) branches into rev-cond
+ // exit of (loop[next] -> cond) branches into rev[loop] (rev-break)
+
+ // For now, we'll assume the loop is always on the 'true' side
+ // If this assert fails, add in the case where the loop
+ // may be on the 'false' side.
+ //
+ SLANG_RELEASE_ASSERT(trueRegionInfo.fwdEndPoint == condBlock);
+
+ auto revTrueBlock = trueRegionInfo.revEntry;
+ auto revFalseBlock = (preCondRegionInfo.isTrivial) ?
+ revBlockMap[currentBlock] : preCondRegionInfo.revEntry;
+
+ // The block that will become target of the new loop inst
+ // (the old false-region) This _could_ be the condition itself
+ //
+ IRBlock* revPreCondBlock = (falseRegionInfo.isTrivial) ?
+ revBlockMap[condBlock] : falseRegionInfo.revEntry;
+
+ // Old cond block remains new cond block.
+ IRBlock* revCondBlock = revBlockMap[condBlock];
+
+ // Old cond block becomes new pre-break block.
+ IRBlock* revBreakBlock = revBlockMap[currentBlock];
+
+ // Old true-side starting block becomes loop end block.
+ IRBlock* revLoopEndBlock = revBlockMap[trueBlock];
+ builder.setInsertInto(revLoopEndBlock);
+ builder.emitBranch(
+ revCondBlock,
+ getPhiGrads(trueBlock).getCount(),
+ getPhiGrads(trueBlock).getBuffer());
+
+ // Old false-side starting block becomes end block
+ // for the new pre-cond region (which could be empty)
+ //
+ IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
+ if (!falseRegionInfo.isTrivial)
+ {
+ builder.setInsertInto(revPreCondEndBlock);
+ builder.emitBranch(
+ revCondBlock,
+ getPhiGrads(falseBlock).getCount(),
+ getPhiGrads(falseBlock).getBuffer());
+ }
+
+ IRBlock* revBreakRegionExitBlock = revBlockMap[firstLoopBlock];
+ if (!preCondRegionInfo.isTrivial)
+ {
+ builder.setInsertInto(revBreakRegionExitBlock);
+ builder.emitBranch(
+ revBreakBlock,
+ getPhiGrads(firstLoopBlock).getCount(),
+ getPhiGrads(firstLoopBlock).getBuffer());
+ }
+
+ // Emit condition into the new cond block.
+ builder.setInsertInto(revCondBlock);
+ builder.emitIfElse(
+ ifElse->getCondition(),
+ revTrueBlock,
+ revFalseBlock,
+ revLoopEndBlock);
+
+ // Emit loop into rev-version of the break block.
+ auto revLoopBlock = revBlockMap[breakBlock];
+ builder.setInsertInto(revLoopBlock);
+ builder.emitLoop(
+ revPreCondBlock,
+ revBreakBlock,
+ revLoopEndBlock,
+ getPhiGrads(breakBlock).getCount(),
+ getPhiGrads(breakBlock).getBuffer());
+
+ currentBlock = breakBlock;
+ break;
+ }
+
+ case kIROp_Switch:
+ {
+ auto switchInst = as<IRSwitch>(terminator);
+
+ auto breakBlock = switchInst->getBreakLabel();
+
+ IRBlock* revBreakBlock = revBlockMap[currentBlock];
+
+ // Reverse each case label
+ List<IRInst*> reverseSwitchArgs;
+ Dictionary<IRBlock*, IRBlock*> reverseLabelEntryBlocks;
+
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++)
+ {
+ reverseSwitchArgs.add(switchInst->getCaseValue(ii));
+
+ auto caseLabel = switchInst->getCaseLabel(ii);
+ if (!reverseLabelEntryBlocks.ContainsKey(caseLabel))
+ {
+ auto labelRegionInfo = reverseCFGRegion(
+ caseLabel,
+ List<IRBlock*>(breakBlock));
+
+ // Handle this case eventually.
+ SLANG_ASSERT(!labelRegionInfo.isTrivial);
+
+ // Wire the exit to the break block
+ IRBlock* revLabelExit = revBlockMap[caseLabel];
+ SLANG_ASSERT(revLabelExit->getTerminator() == nullptr);
+
+ builder.setInsertInto(revLabelExit);
+ builder.emitBranch(revBreakBlock);
+
+ reverseLabelEntryBlocks[caseLabel] = labelRegionInfo.revEntry;
+ reverseSwitchArgs.add(labelRegionInfo.revEntry);
+ }
+ else
+ {
+ reverseSwitchArgs.add(reverseLabelEntryBlocks[caseLabel]);
+ }
+ }
+
+ auto defaultRegionInfo = reverseCFGRegion(
+ switchInst->getDefaultLabel(),
+ List<IRBlock*>(breakBlock));
+ SLANG_ASSERT(!defaultRegionInfo.isTrivial);
+
+ auto revDefaultRegionEntry = defaultRegionInfo.revEntry;
+
+ builder.setInsertInto(revBlockMap[switchInst->getDefaultLabel()]);
+ builder.emitBranch(revBreakBlock);
+
+ auto phiGrads = getPhiGrads(breakBlock);
+ if (phiGrads.getCount() > 0)
+ {
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++)
+ {
+ reverseSwitchArgs[ii * 2 + 1] =
+ insertPhiBlockBefore(as<IRBlock>(reverseSwitchArgs[ii * 2 + 1]), phiGrads);
+ }
+ revDefaultRegionEntry =
+ insertPhiBlockBefore(as<IRBlock>(revDefaultRegionEntry), phiGrads);
+ }
+
+ auto revSwitchBlock = revBlockMap[breakBlock];
+ builder.setInsertInto(revSwitchBlock);
+ builder.emitSwitch(
+ switchInst->getCondition(),
+ revBreakBlock,
+ revDefaultRegionEntry,
+ reverseSwitchArgs.getCount(),
+ reverseSwitchArgs.getBuffer());
+
+ currentBlock = breakBlock;
+ break;
+ }
+
+ }
}
- bool isComplete()
+ if (auto branchInst = as<IRUnconditionalBranch>(currentBlock->getTerminator()))
{
- return (this->originBlock != nullptr);
+ return RegionEntryPoint(
+ revBlockMap[currentBlock],
+ branchInst->getTargetBlock(),
+ false);
}
- };
+ else if (auto returnInst = as<IRReturn>(currentBlock->getTerminator()))
+ {
+ return RegionEntryPoint(
+ revBlockMap[currentBlock],
+ nullptr,
+ true);
+ }
+ else
+ {
+ // Regions should _really_ not end on a conditional branch (I think)
+ SLANG_UNEXPECTED("Unexpected: Region ended on a conditional branch");
+ }
+ }
void transposeDiffBlocksInFunc(
IRFunc* revDiffFunc,
@@ -140,11 +488,6 @@ struct DiffTransposePass
auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc);
auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc);
- // Add a top-level null region entry for the terminal diff block.
- regionMap[terminalDiffBlocks[0]] = nullptr;
-
- buildAfterBlockMap(revDiffFunc);
-
// Traverse all instructions/blocks in reverse (starting from the terminator inst)
// look for insts/blocks marked with IRDifferentialInstDecoration,
// and transpose them in the revDiffFunc.
@@ -184,7 +527,7 @@ struct DiffTransposePass
// Keep track of first diff block, since this is where
// we'll emit temporary vars to hold per-block derivatives.
//
- firstRevDiffBlockMap[revDiffFunc] = revBlockMap[workList[0]];
+ firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]];
IRInst* retVal = nullptr;
@@ -201,17 +544,14 @@ struct DiffTransposePass
this->transposeBlock(block, revBlock);
}
- // Some blocks may not have their control flow
- // insts completed. Do them now that we have
- // more information.
+ // At this point all insts have been transposed, but the blocks
+ // have no control flow.
+ // reverseCFG will use fwd-mode blocks as reference, and
+ // wire the corresponding rev-mode blocks in reverse.
//
- for (auto pendingBlockInfo : pendingBlocks)
- {
- builder.setInsertInto(revBlockMap[pendingBlockInfo.fwdBlock]);
- completeEmitTerminator(&builder, pendingBlockInfo.fwdBlock, pendingBlockInfo.phiGrads);
- }
-
- pendingBlocks.clear();
+ auto branchInst = as<IRUnconditionalBranch>(terminalPrimalBlocks[0]->getTerminator());
+ auto firstFwdDiffBlock = branchInst->getTargetBlock();
+ reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>());
// Link the last differential fwd-mode block (which will be the first
// rev-mode block) as the successor to the last primal block.
@@ -223,7 +563,7 @@ struct DiffTransposePass
SLANG_ASSERT(terminalDiffBlocks.getCount() == 1);
auto terminalPrimalBlock = terminalPrimalBlocks[0];
- auto terminalRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]);
+ auto firstRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]);
terminalPrimalBlock->getTerminator()->removeAndDeallocate();
@@ -231,9 +571,9 @@ struct DiffTransposePass
subBuilder.setInsertInto(terminalPrimalBlock);
// There should be no parameters in the first reverse-mode block.
- SLANG_ASSERT(terminalRevBlock->getFirstParam() == nullptr);
+ SLANG_ASSERT(firstRevBlock->getFirstParam() == nullptr);
- auto branch = subBuilder.emitBranch(terminalRevBlock);
+ auto branch = subBuilder.emitBranch(firstRevBlock);
if (!retVal)
{
@@ -247,13 +587,20 @@ struct DiffTransposePass
subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
+ // At this point, the only block left without terminator insts
+ // should be the last one. Add a void return to complete it.
+ //
+ IRBlock* lastRevBlock = revBlockMap[firstFwdDiffBlock];
+ SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr);
+
+ builder.setInsertInto(lastRevBlock);
+ builder.emitReturn();
+
// Remove fwd-mode blocks.
for (auto block : workList)
{
block->removeAndDeallocate();
}
-
- cleanupRegionInfo();
}
// Fetch or create a gradient accumulator var
@@ -385,6 +732,17 @@ struct DiffTransposePass
List<IRInst*> phiParamRevGradInsts;
for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam())
{
+ // This param might be used outside this block.
+ // If so, add/get an accumulator.
+ //
+ if (isInstUsedOutsideParentBlock(param))
+ {
+ auto accVar = getOrCreateAccumulatorVar(param);
+ addRevGradientForFwdInst(
+ param,
+ RevGradient(param, builder.emitLoad(accVar), nullptr));
+ }
+
if (hasRevGradients(param))
{
auto gradients = popRevGradients(param);
@@ -396,6 +754,11 @@ struct DiffTransposePass
phiParamRevGradInsts.add(gradInst);
}
+ else
+ {
+ phiParamRevGradInsts.add(
+ emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
+ }
}
// Also handle any remaining gradients for insts that appear in prior blocks.
@@ -448,13 +811,9 @@ struct DiffTransposePass
// We _should_ be completely out of gradients to process at this point.
SLANG_ASSERT(gradientsMap.Count() == 0);
- if (!tryEmitTerminator(&builder, fwdBlock, phiParamRevGradInsts))
- {
- // If we couldn't emit a terminator right away, defer for later.
- pendingBlocks.add(PendingBlockTerminatorEntry(
- fwdBlock,
- phiParamRevGradInsts));
- }
+ // Record any phi gradients for the CFG reversal pass.
+ phiGradsMap[fwdBlock] = phiParamRevGradInsts;
+
}
void transposeInst(IRBuilder* builder, IRInst* inst)
@@ -467,6 +826,15 @@ struct DiffTransposePass
break;
}
+ // Some special instructions simply need to be copied over.
+ // These do not deal with differentials.
+ //
+ if (inst->findDecoration<IRLoopCounterDecoration>())
+ {
+ inst->insertAtEnd(builder->getBlock());
+ return;
+ }
+
// Look for gradient entries for this inst.
List<RevGradient> gradients;
if (hasRevGradients(inst))
@@ -787,234 +1155,6 @@ struct DiffTransposePass
return phiBlock;
}
-
- // Create a region to track control flow from the
- // the point of convergence (fwdConvBlock) back to the point of
- // divergence, along one specific path (fwdExitBlock)
- //
- void pushRegion(IRBlock* fwdConvBlock, IRBlock* fwdExitBlock)
- {
- SLANG_ASSERT(!regionMap.ContainsKey(fwdExitBlock));
- SLANG_ASSERT(regionMap.ContainsKey(fwdConvBlock));
-
- Region* newRegion = new Region(fwdExitBlock, regionMap[fwdConvBlock]);
- regions.add(newRegion);
-
- regionMap[fwdExitBlock] = newRegion;
- }
-
- // If we have a conditional-branch from fwdBlock to fwdNextBlock
- // complete the region, and remove from stack
- // otherwise, copy the region over.
- //
- void propagateRegion(IRBlock* fwdNextBlock, IRBlock* fwdBlock)
- {
- if (as<IRConditionalBranch>(fwdBlock->getTerminator()))
- {
- Region* currentRegion = regionMap[fwdNextBlock];
- currentRegion->finish(fwdNextBlock);
-
- regionMap[fwdBlock] = currentRegion->parent;
- }
- else if (as<IRUnconditionalBranch>(fwdBlock->getTerminator()) ||
- as<IRReturn>(fwdBlock->getTerminator()))
- {
- regionMap[fwdBlock] = regionMap[fwdNextBlock];
- }
- }
-
- // Deallocate regions
- void cleanupRegionInfo()
- {
- for (auto region : regions)
- {
- delete region;
- }
-
- regions.clear();
- regionMap.Clear();
- }
-
- bool tryEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads)
- {
- // If this block has no differential predecessors, add a return statement.
- if (!doesBlockHaveDifferentialPredecessors(fwdBlockInst))
- {
- // Emit a void return.
- builder->emitReturn();
- return true;
- }
-
- List<IRBlock*> fwdPredecesorBlocks;
- // Check for predecessors count.
- for (auto predecessor : fwdBlockInst->getPredecessors())
- {
- if (!fwdPredecesorBlocks.contains(predecessor))
- fwdPredecesorBlocks.add(predecessor);
- }
-
- SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0);
-
- // If we have just one, we simply need the reverse-mode block to
- // branch into the reverse-mode version of the predecessor block.
- // (along with the appropriate phi args)
- //
- if (fwdPredecesorBlocks.getCount() == 1)
- {
- builder->emitBranch(
- revBlockMap[fwdPredecesorBlocks[0]],
- phiParamGrads.getCount(),
- phiParamGrads.getBuffer());
-
- propagateRegion(fwdBlockInst, fwdPredecesorBlocks[0]);
- return true;
- }
-
- // If we have more than one, then control flow 'converges' at this point.
- // By convention, this block must be the after block for _some_ conditional
- // control flow statement.
- // If not, we are dealing with an inconsistent graph.
- //
- // Rather than actually emitting the terminator here, we're going to
- // defer to a pass after all the blocks have been transposed.
- // This is because, while we know that this block is the point of convergence
- // we don't know which predecessor belong to which side of the branch.
- // We will instead create 'regions' to track each predecessor for every
- // branch, and by the time all blocks are seen at-least once, we should have
- // resolved the 'start' points for every predecessor.
- //
-
- if (fwdPredecesorBlocks.getCount() > 1)
- {
- SLANG_ASSERT(afterBlockMap.ContainsKey(fwdBlockInst));
-
- for (auto predecessor : fwdPredecesorBlocks)
- {
- // Trivial case when the predecessor itself is the point
- // of divergence.
- //
- if (getAfterBlock(predecessor) == fwdBlockInst)
- continue;
-
- pushRegion(fwdBlockInst, predecessor);
- }
- }
-
- return false;
- }
-
- bool completeEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads)
- {
- IRBlock* revBlock = revBlockMap[fwdBlockInst];
-
- // If we already have a terminator, we've probably resolved it during
- // tryEmitTerminator()
- //
- if (revBlock->getTerminator() != nullptr)
- return true;
-
- auto terminatorInst = as<IRInst>(afterBlockMap[fwdBlockInst]);
- switch (terminatorInst->getOp())
- {
- case kIROp_ifElse:
- {
- auto ifElseInst = as<IRIfElse>(terminatorInst);
-
- auto condition = ifElseInst->getCondition();
- SLANG_ASSERT(!isDifferentialInst(condition));
-
- // fwd origin block is the reverse 'after' block.
- auto revAfterBlock = as<IRBlock>(
- revBlockMap[as<IRBlock>(ifElseInst->getParent())]);
-
- // Find region, and find the reverse-mode version of the
- // exit block.
- Region* trueRegion = regionMap[ifElseInst->getTrueBlock()];
- IRBlock* revTrueBlock = revBlockMap[trueRegion->exitBlock];
-
- Region* falseRegion = regionMap[ifElseInst->getFalseBlock()];
- IRBlock* revFalseBlock = revBlockMap[falseRegion->exitBlock];
-
- // If we have phi derivatives to pass on,
- // we need to add dummy blocks to pass them using
- // an unconditional branch.
- //
- if (phiParamGrads.getCount() > 0)
- {
- revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiParamGrads);
- revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiParamGrads);
-
- // Putting the phi blocks just after our current reverse-mode block
- // is not necessary. Just to make intermediate IR easier to follow.
- //
- revTrueBlock->insertAfter(revBlock);
- revFalseBlock->insertAfter(revBlock);
- }
-
- builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock);
- return true;
- }
- case kIROp_Switch:
- {
- auto switchInst = as<IRSwitch>(terminatorInst);
-
- auto condition = switchInst->getCondition();
- SLANG_ASSERT(!isDifferentialInst(condition));
-
- // fwd origin block is the reverse 'break' block.
- auto revAfterBlock = as<IRBlock>(
- revBlockMap[as<IRBlock>(switchInst->getParent())]);
-
- // Find regions for every branch, and find the reverse-mode
- // version of the each exit block.
- Region* defaultRegion = regionMap[switchInst->getDefaultLabel()];
- IRBlock* revDefaultBlock = revBlockMap[defaultRegion->exitBlock];
-
- List<IRBlock*> revCaseBlocks;
- for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
- {
- Region* caseRegion = regionMap[switchInst->getCaseLabel(ii)];
- IRBlock* revCaseBlock = revBlockMap[caseRegion->exitBlock];
- revCaseBlocks.add(revCaseBlock);
- }
-
- // If we have phi derivatives to pass on,
- // we need to add dummy blocks to pass them using
- // an unconditional branch.
- //
- if (phiParamGrads.getCount() > 0)
- {
- revDefaultBlock = insertPhiBlockBefore(revDefaultBlock, phiParamGrads);
- revDefaultBlock->insertAfter(revBlock);
-
- for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
- {
- revCaseBlocks[ii] = insertPhiBlockBefore(revCaseBlocks[ii], phiParamGrads);
- revCaseBlocks[ii]->insertAfter(revBlock);
- }
- }
-
- List<IRInst*> revCaseArgs;
- for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
- {
- revCaseArgs.add(switchInst->getCaseValue(ii));
- revCaseArgs.add(revCaseBlocks[ii]);
- }
-
- builder->emitSwitch(
- condition,
- revAfterBlock,
- revDefaultBlock,
- revCaseArgs.getCount(),
- revCaseArgs.getBuffer());
-
- return true;
- }
- default:
- SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition");
- }
- return false;
- }
TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
{
@@ -1972,9 +2112,7 @@ struct DiffTransposePass
List<PendingBlockTerminatorEntry> pendingBlocks;
- Dictionary<IRBlock*, Region*> regionMap;
-
- List<Region*> regions;
+ Dictionary<IRBlock*, List<IRInst*>> phiGradsMap;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index c525191a3..d808cbb5e 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -33,6 +33,77 @@ struct DiffUnzipPass
// might run into an issue here?
IRBlock* firstDiffBlock;
+ struct IndexedRegion
+ {
+ // Parent indexed region (for nested loops)
+ IndexedRegion* parent = nullptr;
+
+ // Intializer block for the index.
+ IRBlock* initBlock = nullptr;
+
+ // Index 'starts' at the first loop block (included)
+ IRBlock* firstBlock = nullptr;
+
+ // Index stops at the break block (not included)
+ IRBlock* breakBlock = nullptr;
+
+ // Block where index updates happen.
+ IRBlock* continueBlock = nullptr;
+
+ // After lowering, store references to the count
+ // variables associated with this region
+ //
+ IRVar* primalCountVar = nullptr;
+ IRVar* diffCountVar = nullptr;
+
+ enum CountStatus
+ {
+ Unresolved,
+ Dynamic,
+ Static
+ };
+
+ CountStatus status = CountStatus::Unresolved;
+
+ // Inferred maximum number of iterations.
+ Count maxIters = -1;
+
+ IndexedRegion() :
+ parent(nullptr),
+ initBlock(nullptr),
+ firstBlock(nullptr),
+ breakBlock(nullptr),
+ continueBlock(nullptr),
+ primalCountVar(nullptr),
+ diffCountVar(nullptr),
+ status(CountStatus::Unresolved),
+ maxIters(-1)
+ { }
+
+ IndexedRegion(
+ IndexedRegion* parent,
+ IRBlock* initBlock,
+ IRBlock* firstBlock,
+ IRBlock* breakBlock,
+ IRBlock* continueBlock) :
+ parent(parent),
+ initBlock(initBlock),
+ firstBlock(firstBlock),
+ breakBlock(breakBlock),
+ continueBlock(continueBlock),
+ primalCountVar(nullptr),
+ diffCountVar(nullptr),
+ status(CountStatus::Unresolved),
+ maxIters(-1)
+ { }
+ };
+
+ // Keep track of indexed blocks and their corresponding index heirarchy.
+ Dictionary<IRBlock*, IndexedRegion*> indexRegionMap;
+
+ List<IndexedRegion*> indexRegions;
+
+
DiffUnzipPass(
AutoDiffSharedContext* autodiffContext)
: autodiffContext(autodiffContext)
@@ -73,8 +144,8 @@ struct DiffUnzipPass
//
SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr);
SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr);
-
- IRBlock* firstBlock = unzippedFunc->getFirstBlock()->getNextBlock();
+
+ IRBlock* firstBlock = as<IRUnconditionalBranch>(unzippedFunc->getFirstBlock()->getTerminator())->getTargetBlock();
List<IRBlock*> mixedBlocks;
for (IRBlock* block = firstBlock; block; block = block->getNextBlock())
@@ -122,9 +193,42 @@ struct DiffUnzipPass
splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block]));
}
+ // Propagate indexed region information.
+ propagateAllIndexRegions();
+
+ // Try to infer maximum counts for all regions.
+ // (only regions whose intermediates are used outside their region
+ // require a maximum count, so we may see some unresolved regions
+ // without any issues)
+ //
+ for (auto region : indexRegions)
+ {
+ tryInferMaxIndex(region);
+ }
+
+ // Emit counter variables and other supporting
+ // instructions for all regions.
+ //
+ lowerIndexedRegions();
+
+ // Process intermediate insts in indexed blocks
+ // into array loads/stores.
+ //
+ for (auto block : mixedBlocks)
+ {
+ auto primalBlock = primalMap[block];
+
+ if (isBlockIndexed(block))
+ {
+ processIndexedFwdBlock(block);
+ }
+ }
+
// Swap the first block's occurences out for the first primal block.
firstBlock->replaceUsesWith(firstPrimalBlock);
+ cleanupIndexRegionInfo();
+
// Remove old blocks.
for (auto block : mixedBlocks)
block->removeAndDeallocate();
@@ -132,6 +236,239 @@ struct DiffUnzipPass
return unzippedFunc;
}
+ IRBlock* getInitializerBlock(IndexedRegion* region)
+ {
+ return region->initBlock;
+ }
+
+ IRBlock* getUpdateBlock(IndexedRegion* region)
+ {
+ return region->continueBlock;
+ }
+
+ void tryInferMaxIndex(IndexedRegion* region)
+ {
+ if (region->status != IndexedRegion::CountStatus::Unresolved)
+ return;
+
+ // We're going to fix this at a some random number
+ // for now, and then add some basic inference + user-defined decoration
+ //
+ region->maxIters = 5;
+ region->status = IndexedRegion::CountStatus::Static;
+ }
+
+ // Make a primal value *available* to the differential block.
+ // This can get quite involved, and we're going to rely on
+ // constructSSA to do most of the heavy-lifting & optimization
+ // For now, we'll simply create a variable in the top-most
+ // primal block, then load it in the last primal block
+ //
+ //void hoistValue(IRInst* primalInst)
+ //{
+ // IRBlock* terminalPrimalBlock = getTerminalPrimalBlock();
+ // IRBlock* firstPrimalBlock = getFirstPrimalBlock();
+ //}
+
+ void lowerIndexedRegions()
+ {
+ IRBuilder builder(autodiffContext->sharedBuilder);
+
+
+ for (auto region : indexRegions)
+ {
+
+ IRBlock* initializerBlock = getInitializerBlock(region);
+
+ // Grab first primal block.
+ auto firstPrimalBlock = primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()];
+
+ // Make variable in the top-most block (so it's visible to diff blocks)
+ builder.setInsertInto(firstPrimalBlock);
+ region->primalCountVar = builder.emitVar(builder.getUIntType());
+
+ // Make another variable in the diff block initialized to the
+ // final value of the primal counter.
+ //
+ builder.setInsertInto(diffMap[initializerBlock]);
+ auto primalCounterValue = builder.emitLoad(region->primalCountVar);
+ region->diffCountVar = builder.emitVar(builder.getUIntType());
+ builder.emitStore(region->diffCountVar, primalCounterValue);
+
+ IRBlock* updateBlock = getUpdateBlock(region);
+
+ {
+ // TODO: Figure out if the counter update needs to go before or after
+ // the rest of the update block.
+ //
+ builder.setInsertBefore(as<IRBlock>(primalMap[updateBlock])->getTerminator());
+
+ auto counterVal = builder.emitLoad(region->primalCountVar);
+ auto incCounterVal = builder.emitAdd(
+ builder.getUIntType(),
+ counterVal,
+ builder.getIntValue(builder.getUIntType(), 1));
+
+ auto incStore = builder.emitStore(region->primalCountVar, incCounterVal);
+
+ builder.addLoopCounterDecoration(counterVal);
+ builder.addLoopCounterDecoration(incCounterVal);
+ builder.addLoopCounterDecoration(incStore);
+ }
+
+ {
+ // NOTE: This is a hacky shortcut we're taking here.
+ // Technically the unzip pass should not affect the
+ // correctness (it must still compute the proper fwd-mode derivative)
+ // However, we're currently making the loop counter go backwards to
+ // make it easier on the transposition pass, so the output from
+ // the unzip pass is neither fwd-mode or rev-mode until the transposition
+ // step is complete.
+ //
+ // TODO: Ideally this needs to be replaced with a small inversion step
+ // within the transposition pass.
+ //
+
+ builder.setInsertBefore(as<IRBlock>(diffMap[updateBlock])->getTerminator());
+
+ auto counterVal = builder.emitLoad(region->diffCountVar);
+ auto decCounterVal = builder.emitSub(
+ builder.getUIntType(),
+ counterVal,
+ builder.getIntValue(builder.getUIntType(), 0));
+
+ auto decStore = builder.emitStore(region->diffCountVar, decCounterVal);
+
+ // Mark insts as loop counter insts to avoid removing them.
+ //
+ builder.addLoopCounterDecoration(counterVal);
+ builder.addLoopCounterDecoration(decCounterVal);
+ builder.addLoopCounterDecoration(decStore);
+ }
+
+ }
+ }
+
+ void processIndexedFwdBlock(IRBlock* fwdBlock)
+ {
+ if (!isBlockIndexed(fwdBlock))
+ return;
+
+ // Grab first primal block.
+ IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[fwdBlock->getParent()->getFirstBlock()->getNextBlock()]);
+
+ // Scan through instructions and identify those that are used
+ // outside the local block.
+ //
+ IRBlock* primalBlock = as<IRBlock>(primalMap[fwdBlock]);
+
+ List<IRInst*> primalInsts;
+ for (auto child = primalBlock->getFirstChild(); child; child = child->getNextInst())
+ primalInsts.add(child);
+
+ IRBuilder builder(autodiffContext->sharedBuilder);
+
+ // Build list of indices that this block is affected by.
+ List<IndexedRegion*> regions;
+ {
+ IndexedRegion* region = indexRegionMap[fwdBlock];
+ for (; region; region = region->parent)
+ regions.add(region);
+ }
+
+ for (auto inst : primalInsts)
+ {
+ // 1. Check if we need to store inst (is it used in a differential block?)
+
+ bool shouldStore = false;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
+
+ if (isDifferentialInst(useBlock))
+ {
+ shouldStore = true;
+ }
+ }
+
+ if (!shouldStore) continue;
+
+ // 2. Emit an array to top-level to allocate space.
+
+ builder.setInsertBefore(firstPrimalBlock->getTerminator());
+
+ IRType* arrayType = inst->getDataType();
+ SLANG_ASSERT(!as<IRPtrTypeBase>(arrayType)); // can't store pointers.
+
+ for (auto region : regions)
+ {
+ SLANG_ASSERT(region->status == IndexedRegion::CountStatus::Static);
+ SLANG_ASSERT(region->maxIters >= 0);
+
+ arrayType = builder.getArrayType(
+ arrayType,
+ builder.getIntValue(
+ builder.getUIntType(),
+ region->maxIters));
+ }
+
+ // Reverse the list since the indices needs to be
+ // emitted in reverse order.
+ //
+ regions.reverse();
+
+ auto storageVar = builder.emitVar(arrayType);
+
+ // 3. Store current value into the array and replace uses with a load.
+ {
+ builder.setInsertAfter(inst);
+
+ IRInst* storeAddr = storageVar;
+ IRType* currType = storageVar->getDataType();
+
+ for (auto region : regions)
+ {
+ currType = as<IRArrayType>(currType)->getElementType();
+
+ storeAddr = builder.emitElementAddress(
+ currType,
+ storeAddr,
+ region->primalCountVar);
+ }
+
+ builder.emitStore(storeAddr, inst);
+ }
+
+ // 4. Replace uses in differential blocks with loads from the array.
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
+
+ if (isDifferentialInst(useBlock))
+ {
+ builder.setInsertBefore(use->getUser());
+
+ IRInst* loadAddr = storageVar;
+ IRType* currType = storageVar->getDataType();
+
+ for (auto region : regions)
+ {
+ currType = as<IRArrayType>(currType)->getElementType();
+
+ loadAddr = builder.emitElementAddress(
+ currType,
+ loadAddr,
+ region->diffCountVar);
+ }
+
+ use->set(builder.emitLoad(loadAddr));
+ }
+ }
+ }
+ }
+ }
+
IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType);
bool isRelevantDifferentialPair(IRType* type)
@@ -327,6 +664,188 @@ struct DiffUnzipPass
return InstPair(primalBranch, returnInst);
}
+ bool isBlockIndexed(IRBlock* block)
+ {
+ return indexRegionMap.ContainsKey(block) && indexRegionMap[block] != nullptr;
+ }
+
+ void addNewIndex(IRLoop* targetLoop)
+ {
+ // Create indexed region without a parent for now.
+ // The parent will be filled in during propagation.
+ //
+ IndexedRegion* region = new IndexedRegion(
+ nullptr,
+ as<IRBlock>(targetLoop->getParent()),
+ targetLoop->getTargetBlock(),
+ targetLoop->getBreakBlock(),
+ targetLoop->getContinueBlock());
+
+ indexRegionMap[targetLoop->getTargetBlock()] = region;
+ indexRegions.add(region);
+ }
+
+ // Deallocate regions
+ void cleanupIndexRegionInfo()
+ {
+ for (auto region : indexRegions)
+ {
+ delete region;
+ }
+
+ indexRegions.clear();
+ indexRegionMap.Clear();
+ }
+
+ void propagateAllIndexRegions()
+ {
+
+
+ // Load up the starting block of every region into
+ // initial worklist.
+ //
+ List<IRBlock*> workList;
+ HashSet<IRBlock*> workSet;
+ for (auto region : indexRegions)
+ {
+ workList.add(region->firstBlock);
+ workSet.Add(region->firstBlock);
+ }
+
+ // Keep propagating from initial work list to predecessors
+ // Add blocks to work list if their region assignment has changed
+ // Add the beginning blocks for complete regions if region parent has changed.
+ //
+ while (workList.getCount() > 0)
+ {
+ auto block = workList.getLast();
+ workList.removeLast();
+ workSet.Remove(block);
+
+ HashSet<IRBlock*> successors;
+
+ for (auto successor : block->getSuccessors())
+ {
+ if (successors.Contains(successor))
+ continue;
+
+ if (propagateIndexRegion(block, successor))
+ {
+ if (!workSet.Contains(successor))
+ {
+ workList.add(successor);
+ workSet.Add(successor);
+ }
+
+ // Do we have an index region for the successor, which is
+ // also the starting block of that region?
+ // Then the change might have been the addition of
+ // a parent node. Add the break block so the
+ // change can be propagated further.
+ //
+ if (isBlockIndexed(successor))
+ {
+ IndexedRegion* succRegion = indexRegionMap[successor];
+ if (succRegion->firstBlock == successor)
+ {
+ if (!workSet.Contains(succRegion->breakBlock))
+ {
+ workList.add(succRegion->breakBlock);
+ workSet.Add(succRegion->breakBlock);
+ }
+ }
+ }
+ }
+
+ successors.Add(successor);
+ }
+ }
+ }
+
+ bool setIndexRegion(IRBlock* block, IndexedRegion* region)
+ {
+ if (!region) return false;
+
+ if (indexRegionMap.ContainsKey(block)
+ && indexRegionMap[block] == region)
+ return false;
+
+ indexRegionMap[block] = region;
+ return true;
+ }
+
+ bool propagateIndexRegion(IRBlock* srcBlock, IRBlock* nextBlock)
+ {
+ // Is the current region indexed?
+ // If not, there's nothing to propagate
+ //
+ if (!isBlockIndexed(srcBlock))
+ return false;
+
+ IndexedRegion* region = indexRegionMap[srcBlock];
+
+ // If the target's index is already resolved,
+ // check if it's a sub-region.
+ //
+ if (isBlockIndexed(nextBlock))
+ {
+ IndexedRegion* nextRegion = indexRegionMap[nextBlock];
+
+ // If we're at the first block of a region,
+ // set current region as continue-region's
+ // parent.
+ //
+ if (nextBlock == nextRegion->firstBlock && nextRegion != region)
+ {
+ nextRegion->parent = region;
+ return true;
+ }
+
+ return false;
+ }
+
+ // If we're at the break block, move up to the parent index.
+ if (nextBlock == region->breakBlock)
+ return setIndexRegion(nextBlock, region->parent);
+
+ // If none of the special cases hit, copy the
+ // current region to the next block.
+ //
+ return setIndexRegion(nextBlock, region);
+ }
+
+ // Splitting a loop is one of the trickiest parts of the unzip pass.
+ // Thus far, we've been dealing with blocks that are only run once, so we
+ // could arbitrarily move intermediate instructions to other blocks since they are
+ // generated and consumed at-most one time.
+ //
+ // Intermediate instructions in a loop can take on a different value each iteration
+ // and thus need to be stored explicitly to an array.
+ //
+ // We also need to ascertain an upper limit on the iteration count.
+ // With very few exceptions, this is a fundamental requirement.
+ //
+ InstPair splitLoop(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoop* mixedLoop)
+ {
+
+ auto breakBlock = mixedLoop->getBreakBlock();
+ auto continueBlock = mixedLoop->getContinueBlock();
+ auto nextBlock = mixedLoop->getTargetBlock();
+
+ // Push a new index.
+ addNewIndex(mixedLoop);
+
+ return InstPair(
+ primalBuilder->emitLoop(
+ as<IRBlock>(primalMap[nextBlock]),
+ as<IRBlock>(primalMap[breakBlock]),
+ as<IRBlock>(primalMap[continueBlock])),
+ diffBuilder->emitLoop(
+ as<IRBlock>(diffMap[nextBlock]),
+ as<IRBlock>(diffMap[breakBlock]),
+ as<IRBlock>(diffMap[continueBlock])));
+ }
+
InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst)
{
switch (branchInst->getOp())
@@ -430,6 +949,9 @@ struct DiffUnzipPass
diffCaseArgs.getBuffer()));
}
+ case kIROp_loop:
+ return splitLoop(primalBuilder, diffBuilder, as<IRLoop>(branchInst));
+
default:
SLANG_UNEXPECTED("Unhandled instruction");
}
@@ -544,11 +1066,13 @@ struct DiffUnzipPass
(use->getUser()->getParent() != diffBlock));
}
- inst->removeAndDeallocate();
+ // Leave terminator in to keep CFG info.
+ if (!as<IRTerminatorInst>(inst))
+ inst->removeAndDeallocate();
}
// Nothing should be left in the original block.
- SLANG_ASSERT(block->getFirstChild() == nullptr);
+ SLANG_ASSERT(block->getFirstChild() == block->getTerminator());
// Branch from primal to differential block.
// Functionally, the new blocks should produce the same output as the
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 6b6b3924a..f2294671e 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -749,6 +749,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
+ INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
+
/// Used by the auto-diff pass to mark insts that compute
/// a differential value.
INST(DifferentialInstDecoration, diffInstDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 8b30a02dd..5669a12d7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -663,6 +663,15 @@ struct IRBackwardDerivativeDecoration : IRDecoration
IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
};
+struct IRLoopCounterDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_LoopCounterDecoration
+ };
+ IR_LEAF_ISA(LoopCounterDecoration)
+};
+
struct IRDifferentialInstDecoration : IRDecoration
{
enum
@@ -3243,6 +3252,13 @@ public:
IRBlock* target,
IRBlock* breakBlock,
IRBlock* continueBlock);
+
+ IRInst* emitLoop(
+ IRBlock* target,
+ IRBlock* breakBlock,
+ IRBlock* continueBlock,
+ Int argCount,
+ IRInst*const* args);
IRInst* emitBranch(
IRInst* val,
@@ -3590,6 +3606,11 @@ public:
addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx);
}
+ void addLoopCounterDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_LoopCounterDecoration);
+ }
+
void markInstAsDifferential(IRInst* value)
{
addDecoration(value, kIROp_DifferentialInstDecoration, nullptr);
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 0bd5c6e9f..ee55a6546 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -84,7 +84,7 @@ struct ConstructSSAContext
Dictionary<IRBlock*, RefPtr<SSABlockInfo>> blockInfos;
// IR building state to use during the operation
- SharedIRBuilder sharedBuilder;
+ SharedIRBuilder* sharedBuilder;
// Instructions to remove during cleanup
List<IRInst*> instsToRemove;
@@ -1043,7 +1043,7 @@ static void breakCriticalEdges(
for (auto edge : criticalEdges)
{
- context->sharedBuilder.insertBlockAlongEdge(edge);
+ context->sharedBuilder->insertBlockAlongEdge(edge);
}
}
@@ -1205,7 +1205,8 @@ bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal)
ConstructSSAContext context;
context.globalVal = globalVal;
- context.sharedBuilder.init(module);
+ SharedIRBuilder sharedBuilder(module);
+ context.sharedBuilder = &sharedBuilder;
context.builder.init(context.sharedBuilder);
context.builder.setInsertInto(module);
@@ -1213,6 +1214,22 @@ bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal)
return constructSSA(&context);
}
+// Construct SSA form for a global value with code and reuse
+// an existing sharedBuilder
+//
+bool constructSSA(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* globalVal)
+{
+ ConstructSSAContext context;
+ context.globalVal = globalVal;
+
+ context.sharedBuilder = sharedBuilder;
+
+ context.builder.init(sharedBuilder);
+ context.builder.setInsertInto(sharedBuilder->getModule());
+
+ return constructSSA(&context);
+}
+
bool constructSSA(IRModule* module, IRInst* globalVal)
{
switch (globalVal->getOp())
diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h
index d455439df..02c9c4831 100644
--- a/source/slang/slang-ir-ssa.h
+++ b/source/slang/slang-ir-ssa.h
@@ -6,7 +6,9 @@ namespace Slang
struct IRModule;
struct IRGlobalValueWithCode;
struct IRInst;
+ struct SharedIRBuilder;
bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal);
+ bool constructSSA(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* globalVal);
bool constructSSA(IRModule* module);
bool constructSSA(IRInst* globalVal);
}
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index a49eda322..03db96ac5 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -199,6 +199,13 @@ namespace Slang
if(inst->getFullType())
validateIRInstOperand(context, inst, &inst->typeUse);
+ // Avoid validating decoration operands
+ // since they don't have to conform to inst visibility
+ // constraints.
+ //
+ if (as<IRDecoration>(inst))
+ return;
+
UInt operandCount = inst->getOperandCount();
for (UInt ii = 0; ii < operandCount; ++ii)
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 845232ae6..e72ba8c9f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4764,6 +4764,32 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitLoop(
+ IRBlock* target,
+ IRBlock* breakBlock,
+ IRBlock* continueBlock,
+ Int argCount,
+ IRInst*const* args)
+ {
+ List<IRInst*> argList;
+
+ argList.add(target);
+ argList.add(breakBlock);
+ argList.add(continueBlock);
+
+ for (Count ii = 0; ii < argCount; ii++)
+ argList.add(args[ii]);
+
+ auto inst = createInst<IRLoop>(
+ this,
+ kIROp_loop,
+ nullptr,
+ argList.getCount(),
+ argList.getBuffer());
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitBranch(
IRInst* val,
IRBlock* trueBlock,