summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-17 12:03:59 -0500
committerGitHub <noreply@github.com>2023-02-17 09:03:59 -0800
commitf253d15a3b2681dfa40491451fcb3f21f1dbe412 (patch)
tree589298ff23ea2b2eb89615694f2c06613f1199a1 /source
parent245466d89cfe54b78da486f06d470bc6daaf4625 (diff)
Proper reverse-mode loop handling with splitting + inversion steps (#2656)
* Halfway to loop inversion * More progress towards proper loop inversion * More progress towards inverse insts. Only thing left is adding `counter>=0` at the right place * More fixes for inversion step. * Lots more fixes, added primal inst 'hoisting' mechanism as the central method that ensures primal values are placed in the right spot * Loop inversion is now functional * Cleaned up commented code * rename diffCounterVar -> diffCounterParam * minor update * removed some comments and commented code * Switch `IRBuilder(sharedIRBuilder)` to `IRBuilder(moduleInst)`
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-propagate.h5
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h441
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h328
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h33
-rw-r--r--source/slang/slang-ir-ssa.cpp5
7 files changed, 636 insertions, 179 deletions
diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h
index 4edf20142..8f912ba61 100644
--- a/source/slang/slang-ir-autodiff-propagate.h
+++ b/source/slang/slang-ir-autodiff-propagate.h
@@ -15,6 +15,11 @@ inline bool isDifferentialInst(IRInst* inst)
return inst->findDecoration<IRDifferentialInstDecoration>();
}
+inline bool isPrimalInst(IRInst* inst)
+{
+ return inst->findDecoration<IRPrimalInstDecoration>() || (as<IRConstant>(inst) != nullptr);
+}
+
inline bool isMixedDifferentialInst(IRInst* inst)
{
return inst->findDecoration<IRMixedDifferentialInstDecoration>();
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 8aca31642..b74416b76 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -227,6 +227,9 @@ struct DiffTransposePass
IRBlock* revAfterBlock = revBlockMap[currentBlock];
builder.setInsertInto(revCondBlock);
+
+ hoistPrimalInst(&builder, ifElse->getCondition());
+
builder.emitIfElse(
ifElse->getCondition(),
revTrueEntryBlock,
@@ -357,6 +360,8 @@ struct DiffTransposePass
// Emit condition into the new cond block.
builder.setInsertInto(revCondBlock);
+ hoistPrimalInst(&builder, ifElse->getCondition());
+
builder.emitIfElse(
ifElse->getCondition(),
revTrueBlock,
@@ -442,7 +447,11 @@ struct DiffTransposePass
}
auto revSwitchBlock = revBlockMap[breakBlock];
+
builder.setInsertInto(revSwitchBlock);
+
+ hoistPrimalInst(&builder, switchInst->getCondition());
+
builder.emitSwitch(
switchInst->getCondition(),
revBreakBlock,
@@ -588,6 +597,21 @@ struct DiffTransposePass
auto firstFwdDiffBlock = branchInst->getTargetBlock();
reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>());
+ // Lower any loop-exit-value decorations into initializations for loop intermediate vals,
+ // and convert loop initial values into terminating conditions.
+ //
+ // TODO: We need a way to confirm that all required vars have an initial value
+ // (is there a built-in dataflow tool for this?)
+ //
+ for (auto block : workList)
+ {
+ if (auto loopInst = as<IRLoop>(block->getTerminator()))
+ {
+ lowerLoopExitValues(&builder, loopInst);
+ invertLoopCondition(&builder, loopInst);
+ }
+ }
+
// Link the last differential fwd-mode block (which will be the first
// rev-mode block) as the successor to the last primal block.
// We assume that the original function is in single-return form
@@ -686,6 +710,36 @@ struct DiffTransposePass
return tempRevVar;
}
+ IRVar* getOrCreateInverseVar(IRInst* primalInst)
+ {
+ // No need to store inverse values for constants.
+ if (as<IRConstant>(primalInst))
+ return nullptr;
+
+ // Check if we have a var already.
+ if (inverseVarMap.ContainsKey(primalInst))
+ return inverseVarMap[primalInst];
+
+ IRBuilder tempVarBuilder(autodiffContext->moduleInst);
+
+ IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())];
+
+ if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst())
+ tempVarBuilder.setInsertBefore(firstInst);
+ else
+ tempVarBuilder.setInsertInto(firstDiffBlock);
+
+ auto primalType = primalInst->getDataType();
+
+ // Emit a var in the top-level differential block to hold the inverse,
+ // and initialize it.
+ auto tempInvVar = tempVarBuilder.emitVar(primalType);
+
+ inverseVarMap[primalInst] = tempInvVar;
+
+ return tempInvVar;
+ }
+
bool isInstUsedOutsideParentBlock(IRInst* inst)
{
auto currBlock = inst->getParent();
@@ -707,7 +761,7 @@ struct DiffTransposePass
builder.setInsertInto(revBlock);
// Check if this block has any 'outputs' (in the form of phi args
- // sent to the successor bvock)
+ // sent to the successor block)
//
if (auto branchInst = as<IRUnconditionalBranch>(fwdBlock->getTerminator()))
{
@@ -716,51 +770,48 @@ struct DiffTransposePass
auto arg = branchInst->getArg(ii);
if (isDifferentialInst(arg))
{
+ // If the arg is a differential, emit a parameter
+ // to accept it's reverse-mode differential as an input
+ //
+
auto diffType = arg->getDataType();
auto revParam = builder.emitParam(diffType);
addRevGradientForFwdInst(
- arg,
+ arg,
RevGradient(
RevGradient::Flavor::Simple,
arg,
revParam,
nullptr));
}
- }
- }
-
- // Some special instructions simply need to be copied over.
- // These do not deal with differentials.
- // TODO: This will not work if there are any differential
- // insts that rely on loop counter vars having a specific
- // value.
- // The solution is to have primal insts appearing in
- // differential blocks be in their own special blocks that are
- // ignored entirely, rather than dealing with them one inst
- // at a time.
- //
- for (IRInst* child = fwdBlock->getFirstChild(); child;)
- {
- auto nextChild = child->getNextInst();
+ else if (isPrimalInst(arg))
+ {
+ // If the output arg is a primal, emit a parameter
+ // to accept it as an _input_ for the reverse-mode
+ //
+ auto primalType = arg->getDataType();
+ auto primalInvParam = builder.emitParam(primalType);
- if (child->findDecoration<IRLoopCounterDecoration>())
- {
- // Loop counter insts should not have any gradients.
- SLANG_ASSERT(!hasRevGradients(child));
- child->insertAtEnd(revBlock);
+ setInverse(&builder, arg, primalInvParam);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Encountered inst not marked as primal or differential");
+ }
}
-
- child = nextChild;
}
// Move pointer & reference insts to the top of the reverse-mode block.
List<IRInst*> nonValueInsts;
for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
{
- // If the instruction is pointer typed, it's not actually computing a value.
+ // If the instruction is a variable allocation (or reverse-gradient pair reference),
+ // move to top.
+ // TODO: This is hacky.. Need a more principled way to handle this
+ // (like primal inst hoisting)
//
- if (as<IRPtrTypeBase>(child->getDataType()))
+ if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child))
nonValueInsts.add(child);
// Slang doesn't support function values. So if we see a func-typed inst
@@ -782,11 +833,16 @@ struct DiffTransposePass
//
for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
{
+ if (child->findDecoration<IRPrimalValueAccessDecoration>())
+ continue;
+
if (as<IRDecoration>(child) || as<IRParam>(child))
continue;
-
- transposeInst(&builder, child);
+ if (isDifferentialInst(child))
+ transposeInst(&builder, child);
+ else if (isPrimalInst(child))
+ invertInst(&builder, child);
}
// After processing the block's instructions, we 'flush' any remaining gradients
@@ -806,32 +862,47 @@ struct DiffTransposePass
List<IRInst*> phiParamRevGradInsts;
for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam())
{
- // This param might be used outside this block.
- // If so, add/get an accumulator.
- //
- if (isInstUsedOutsideParentBlock(param))
+ if (isDifferentialInst(param))
{
- auto accGradient = extractAccumulatorVarGradient(&builder, param);
- addRevGradientForFwdInst(
- param,
- RevGradient(param, accGradient, nullptr));
+ // This param might be used outside this block.
+ // If so, add/get an accumulator.
+ //
+ if (isInstUsedOutsideParentBlock(param))
+ {
+ auto accGradient = extractAccumulatorVarGradient(&builder, param);
+ addRevGradientForFwdInst(
+ param,
+ RevGradient(param, accGradient, nullptr));
+ }
+ if (hasRevGradients(param))
+ {
+ auto gradients = popRevGradients(param);
+
+ auto gradInst = emitAggregateValue(
+ &builder,
+ tryGetPrimalTypeFromDiffInst(param),
+ gradients);
+
+ phiParamRevGradInsts.add(gradInst);
+ }
+ else
+ {
+ phiParamRevGradInsts.add(
+ emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
+ }
}
-
- if (hasRevGradients(param))
+ else if (isPrimalInst(param))
{
- auto gradients = popRevGradients(param);
-
- auto gradInst = emitAggregateValue(
- &builder,
- tryGetPrimalTypeFromDiffInst(param),
- gradients);
-
- phiParamRevGradInsts.add(gradInst);
+ if (hasInverse(param))
+ phiParamRevGradInsts.add(getInverse(&builder, param));
+ else
+ {
+ SLANG_UNEXPECTED("param is a primal inst but has no registered inverse");
+ }
}
else
{
- phiParamRevGradInsts.add(
- emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
+ SLANG_UNEXPECTED("param is neither differential nor primal");
}
}
@@ -896,6 +967,266 @@ struct DiffTransposePass
}
+ struct InvInstPair
+ {
+ IRInst* inst;
+ IRInst* invInst;
+
+ InvInstPair(IRInst* inst, IRInst* invInst) :
+ inst(inst), invInst(invInst)
+ { }
+
+ InvInstPair() : inst(nullptr), invInst(nullptr)
+ { }
+ };
+
+ List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)
+ {
+ switch (primalInst->getOp())
+ {
+ case kIROp_Add:
+ {
+ SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1)));
+ return List<InvInstPair>(
+ InvInstPair(
+ primalInst->getOperand(0),
+ builder->emitSub(
+ primalInst->getOperand(0)->getDataType(),
+ invOutput,
+ primalInst->getOperand(1))));
+ }
+ case kIROp_Sub:
+ {
+ SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1)));
+ return List<InvInstPair>(
+ InvInstPair(
+ primalInst->getOperand(0),
+ builder->emitAdd(
+ primalInst->getOperand(0)->getDataType(),
+ invOutput,
+ primalInst->getOperand(1))));
+ }
+
+ default:
+ SLANG_UNEXPECTED("Unhandled arithmetic inst for inversion");
+ }
+ }
+
+ void lowerLoopExitValues(IRBuilder* builder, IRLoop* fwdLoop)
+ {
+ for (auto decoration : fwdLoop->getDecorations())
+ {
+ if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
+ {
+ IRBlock* revLoopInitBlock = revBlockMap[fwdLoop->getBreakBlock()];
+
+ if (auto revLoopInst = revLoopInitBlock->getTerminator())
+ builder->setInsertBefore(revLoopInst);
+ else
+ builder->setInsertInto(revLoopInitBlock);
+
+ hoistPrimalInst(builder, loopExitValueDecoration->getLoopExitValInst());
+
+ setInverse(builder, loopExitValueDecoration->getTargetInst(), loopExitValueDecoration->getLoopExitValInst());
+ }
+ }
+ }
+
+ void lowerLoopExitValues(IRBuilder* builder, IRBlock* block)
+ {
+ if (auto loopInst = as<IRLoop>(block->getTerminator()))
+ lowerLoopExitValues(builder, loopInst);
+ }
+
+ // Go through loop block phi-args, and look for loop counter
+ // arguments, which for a loop means inserting a check into
+ // loop condition block.
+ // This method also adds logic to skip the first iteration.
+ // (a 'do-while' loop)
+ //
+ void invertLoopCondition(IRBuilder* builder, IRLoop* loopInst)
+ {
+ auto firstLoopBlock = loopInst->getTargetBlock();
+
+ IRBlock* revLoopCondBlock = revBlockMap[firstLoopBlock];
+ builder->setInsertBefore(revLoopCondBlock->getTerminator());
+
+ auto loopBaseCondition = as<IRIfElse>(revLoopCondBlock->getTerminator())->getCondition();
+
+ // Convert the loop from a 'for' into a 'do-while' by skipping the first check
+
+ IRBlock* revLoopStartBlock = revBlockMap[as<IRBlock>(loopInst->getBreakBlock())];
+ builder->setInsertBefore(revLoopStartBlock->getTerminator());
+
+ auto firstLoopCheckSkipVar = builder->emitVar(builder->getBoolType());
+ builder->emitStore(firstLoopCheckSkipVar, builder->getBoolValue(true));
+
+ builder->setInsertBefore(revLoopCondBlock->getTerminator());
+ auto firstLoopCheckSkipVal = builder->emitLoad(firstLoopCheckSkipVar);
+
+ builder->emitStore(firstLoopCheckSkipVar, builder->getBoolValue(false));
+
+ loopBaseCondition = builder->emitIntrinsicInst(
+ builder->getBoolType(),
+ kIROp_Or,
+ 2,
+ List<IRInst*>(firstLoopCheckSkipVal, loopBaseCondition).getBuffer());
+
+ // Add a terminating condition based on the loop counter's initial primal value
+
+ IRParam* loopCounterParam = nullptr;
+ UIndex loopCounterParamIndex = 0;
+ for (auto param : firstLoopBlock->getParams())
+ {
+ if (param->findDecoration<IRLoopCounterDecoration>())
+ {
+ // There really should be two (or more) loop counter params.
+ SLANG_RELEASE_ASSERT(loopCounterParam == nullptr);
+ loopCounterParam = param;
+ }
+ else
+ {
+ loopCounterParamIndex++;
+ }
+ }
+
+ // Should see atleast one loop counter parameter on the first loop block.
+ SLANG_RELEASE_ASSERT(loopCounterParam);
+
+ IRInst* loopCounterInitVal = loopInst->getArg(loopCounterParamIndex);
+
+ auto paramBoundsCheck = builder->emitIntrinsicInst(
+ builder->getBoolType(),
+ kIROp_Neq,
+ 2,
+ List<IRInst*>(
+ hoistPrimalInst(builder, loopCounterParam),
+ hoistPrimalInst(builder, loopCounterInitVal)).getBuffer());
+
+ loopBaseCondition = builder->emitIntrinsicInst(
+ builder->getBoolType(),
+ kIROp_And,
+ 2,
+ List<IRInst*>(paramBoundsCheck, loopBaseCondition).getBuffer());
+
+
+ as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(loopBaseCondition);
+ }
+
+ List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)
+ {
+ switch (primalInst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ return invertArithmetic(builder, primalInst, invOutput);
+
+ default:
+ SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion");
+ }
+ }
+
+ bool hasInverse(IRInst* primalInst)
+ {
+ if (getOrCreateInverseVar(primalInst))
+ return true;
+ else
+ return false;
+ }
+
+ IRInst* getInverse(IRBuilder* builder, IRInst* primalInst)
+ {
+ // Note: There are other possible cases here, although not important
+ // right now. For example, a value is available to load from the primal block.
+ //
+ if (auto invVar = getOrCreateInverseVar(primalInst))
+ return builder->emitLoad(invVar);
+
+ return nullptr;
+ }
+
+ void setInverse(IRBuilder* builder, IRInst* inst, IRInst* invInst)
+ {
+ if (auto invVar = getOrCreateInverseVar(inst))
+ builder->emitStore(invVar, invInst);
+ }
+
+ IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst)
+ {
+ SLANG_RELEASE_ASSERT(isPrimalInst(inst));
+
+ // Are the operands of this primal inst also available in the reverse-mode context?
+ // If not, move/load them.
+ //
+ hoistPrimalOperands(revBuilder, inst);
+
+ if (isPrimalInst(inst) &&
+ as<IRBlock>(inst->getParent()) &&
+ isDifferentialInst(as<IRBlock>(inst->getParent())))
+ {
+ if (!inst->findDecoration<IRPrimalValueAccessDecoration>())
+ {
+ return getInverse(revBuilder, inst);
+ }
+ else
+ {
+ auto block = as<IRBlock>(inst->getParent());
+ SLANG_RELEASE_ASSERT(block);
+
+ if (block == revBuilder->getBlock())
+ {
+ // Already in block..
+ return inst;
+ }
+
+ // Otherwise, move our inst to the the current builder location.
+ inst->removeFromParent();
+ revBuilder->addInst(inst);
+
+ return inst;
+ }
+ }
+
+ return inst;
+ }
+
+ void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst)
+ {
+ for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++)
+ {
+ // For now we'll only hoist primal operands that are
+ // generated in differential blocks.
+ // Eventually, we also want this method to move primal access
+ // insts to the reverse-mode blocks (i.e. this method will
+ // make sure all requried primal insts are moved to the right
+ // place)
+ //
+ if (isPrimalInst(fwdInst->getOperand(ii)))
+ {
+ auto hoistedPrimalInst = hoistPrimalInst(revBuilder, fwdInst->getOperand(ii));
+ fwdInst->setOperand(ii, hoistedPrimalInst);
+ }
+ }
+ }
+
+ void invertInst(IRBuilder* builder, IRInst* primalInst)
+ {
+ // Look for an available inverse entry for this primalInst's *output*
+ if (hasInverse(primalInst))
+ {
+ auto invOutput = getInverse(builder, primalInst);
+
+ auto invEntries = invertInst(builder, primalInst, invOutput);
+
+ for (auto entry : invEntries)
+ setInverse(builder, entry.inst, entry.invInst);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Could not find value for the output of inst. Unable to invert");
+ }
+ }
+
void transposeInst(IRBuilder* builder, IRInst* inst)
{
switch (inst->getOp())
@@ -930,7 +1261,7 @@ struct DiffTransposePass
if (auto pairType = as<IRDifferentialPairType>(loadInst->getDataType()))
{
primalType = pairType->getValueType();
- }
+ }
}
}
@@ -948,6 +1279,11 @@ struct DiffTransposePass
SLANG_ASSERT(gradients.getCount() == 0);
}
+ // Ensure primal operands are replaced with insts accessible in the
+ // reverse-mode context.
+ //
+ hoistPrimalOperands(builder, inst);
+
// Is this inst used in another differential block?
// Emit a function-scope accumulator variable, and include it's value.
// Also, we ignore this if it's a load since those are turned into stores
@@ -2457,6 +2793,8 @@ struct DiffTransposePass
Dictionary<IRInst*, IRVar*> revAccumulatorVarMap;
+ Dictionary<IRInst*, IRVar*> inverseVarMap;
+
List<IRInst*> usedPtrs;
Dictionary<IRBlock*, IRBlock*> revBlockMap;
@@ -2468,9 +2806,8 @@ struct DiffTransposePass
List<PendingBlockTerminatorEntry> pendingBlocks;
Dictionary<IRBlock*, List<IRInst*>> phiGradsMap;
-
- Dictionary<IRBlock*, IRBlock*> initializerBlockMap;
+ Dictionary<IRInst*, IRInst*> inverseValueMap;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 68326fd54..c3af52d8a 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -10,6 +10,7 @@
#include "slang-ir-autodiff-propagate.h"
#include "slang-ir-autodiff-transcriber-base.h"
#include "slang-ir-validate.h"
+#include "slang-ir-ssa.h"
namespace Slang
{
@@ -55,8 +56,10 @@ struct DiffUnzipPass
// After lowering, store references to the count
// variables associated with this region
//
- IRVar* primalCountVar = nullptr;
- IRVar* diffCountVar = nullptr;
+ IRInst* primalCountParam = nullptr;
+ IRInst* diffCountParam = nullptr;
+
+ IRVar* primalCountLastVar = nullptr;
enum CountStatus
{
@@ -76,8 +79,8 @@ struct DiffUnzipPass
firstBlock(nullptr),
breakBlock(nullptr),
continueBlock(nullptr),
- primalCountVar(nullptr),
- diffCountVar(nullptr),
+ primalCountParam(nullptr),
+ diffCountParam(nullptr),
status(CountStatus::Unresolved),
maxIters(-1)
{ }
@@ -93,8 +96,8 @@ struct DiffUnzipPass
firstBlock(firstBlock),
breakBlock(breakBlock),
continueBlock(continueBlock),
- primalCountVar(nullptr),
- diffCountVar(nullptr),
+ primalCountParam(nullptr),
+ diffCountParam(nullptr),
status(CountStatus::Unresolved),
maxIters(-1)
{ }
@@ -254,20 +257,15 @@ struct DiffUnzipPass
//
for (auto block : mixedBlocks)
{
- auto primalBlock = primalMap[block];
-
if (isBlockIndexed(block))
- {
processIndexedFwdBlock(block);
- }
}
-
+
// Swap the first block's occurences out for the first primal block.
firstBlock->replaceUsesWith(firstPrimalBlock);
cleanupIndexRegionInfo();
- // Remove old blocks.
for (auto block : mixedBlocks)
block->removeAndDeallocate();
}
@@ -319,17 +317,78 @@ struct DiffUnzipPass
}
}
- // Make a primal value *available* to the differential block.
- // This can get quite involved, and we're going to rely on
- // constructSSA to do most of the heavy-lifting & optimization
- // For now, we'll simply create a variable in the top-most
- // primal block, then load it in the last primal block
- //
- //void hoistValue(IRInst* primalInst)
- //{
- // IRBlock* terminalPrimalBlock = getTerminalPrimalBlock();
- // IRBlock* firstPrimalBlock = getFirstPrimalBlock();
- //}
+ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
+ {
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator()));
+
+ auto branchInst = as<IRUnconditionalBranch>(block->getTerminator());
+ List<IRInst*> phiArgs;
+
+ for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
+ phiArgs.add(branchInst->getArg(ii));
+
+ phiArgs.add(arg);
+
+ builder->setInsertInto(block);
+ switch (branchInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
+ break;
+
+ case kIROp_loop:
+ builder->emitLoop(
+ as<IRLoop>(branchInst)->getTargetBlock(),
+ as<IRLoop>(branchInst)->getBreakBlock(),
+ as<IRLoop>(branchInst)->getContinueBlock(),
+ phiArgs.getCount(),
+ phiArgs.getBuffer());
+ break;
+
+ default:
+ break;
+ }
+
+ branchInst->removeAndDeallocate();
+ return phiArgs.getCount() - 1;
+ }
+
+ IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type)
+ {
+ builder->setInsertInto(block);
+ return builder->emitParam(type);
+ }
+
+ IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type, UIndex index)
+ {
+ List<IRParam*> params;
+ for (auto param : block->getParams())
+ params.add(param);
+
+ SLANG_RELEASE_ASSERT(index == (UCount)params.getCount());
+
+ return addPhiInputParam(builder, block, type);
+ }
+
+ IRBlock* getBlock(IRInst* inst)
+ {
+ SLANG_RELEASE_ASSERT(inst);
+
+ if (auto block = as<IRBlock>(inst))
+ return block;
+
+ return getBlock(inst->getParent());
+ }
+
+ IRInst* getInstInBlock(IRInst* inst)
+ {
+ SLANG_RELEASE_ASSERT(inst);
+
+ if (auto block = as<IRBlock>(inst->getParent()))
+ return inst;
+
+ return getInstInBlock(inst->getParent());
+ }
void lowerIndexedRegions()
{
@@ -337,114 +396,131 @@ struct DiffUnzipPass
for (auto region : indexRegions)
{
-
- //IRBlock* initializerBlock = getInitializerBlock(region);
- IRBlock* breakBlock = region->breakBlock;
-
// Grab first primal block.
IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]);
-
- // Make variable in the top-most block (so it's visible to diff blocks)
builder.setInsertBefore(firstPrimalBlock->getTerminator());
- region->primalCountVar = builder.emitVar(builder.getIntType());
- builder.emitStore(
- region->primalCountVar,
- builder.getIntValue(builder.getIntType(), 0));
-
- // NOTE: This is a hacky shortcut we're taking here.
- // Technically the unzip pass should not affect the
- // correctness (it must still compute the proper fwd-mode derivative)
- // However, we're currently making the loop counter go backwards to
- // make it easier on the transposition pass, so the output from
- // the unzip pass is neither fwd-mode or rev-mode until the transposition
- // step is complete.
- //
- // TODO: Ideally this needs to be replaced with a small inversion step
- // within the transposition pass.
- //
- // Emit the diff counter into the diff *break* block (
- // which we're praying turns into the reverse initializer block)
- // initialized to the final value of the primal counter.
- //
- builder.setInsertBefore(as<IRBlock>(diffMap[breakBlock])->getTerminator());
- //auto primalCounterValue = builder.emitLoad(region->primalCountVar);
- auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar);
- auto primalCounterLastValue = builder.emitSub(
- primalCounterCurrValue->getDataType(),
- primalCounterCurrValue,
- builder.getIntValue(builder.getIntType(), 1));
-
- region->diffCountVar = builder.emitVar(builder.getIntType());
- auto diffCountInit = builder.emitStore(region->diffCountVar, primalCounterLastValue);
-
- builder.addLoopCounterDecoration(diffCountInit);
- builder.addLoopCounterDecoration(region->diffCountVar);
- builder.addLoopCounterDecoration(primalCounterCurrValue);
- builder.addLoopCounterDecoration(primalCounterLastValue);
-
- IRBlock* updateBlock = getUpdateBlock(region);
+
+ // Make variable in the top-most block (so it's visible to diff blocks)
+ region->primalCountLastVar = builder.emitVar(builder.getIntType());
{
- // TODO: Figure out if the counter update needs to go before or after
- // the rest of the update block.
- //
- builder.setInsertBefore(as<IRBlock>(primalMap[updateBlock])->getTerminator());
+ IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]);
+
+ auto primalCondBlock = as<IRUnconditionalBranch>(
+ primalInitBlock->getTerminator())->getTargetBlock();
+ builder.setInsertBefore(primalCondBlock->getTerminator());
+
+ auto phiCounterArgLoopEntryIndex = addPhiOutputArg(
+ &builder,
+ primalInitBlock,
+ builder.getIntValue(builder.getIntType(), 0));
+
+ region->primalCountParam = addPhiInputParam(
+ &builder,
+ primalCondBlock,
+ builder.getIntType(),
+ phiCounterArgLoopEntryIndex);
+ builder.addLoopCounterDecoration(region->primalCountParam);
+ builder.markInstAsPrimal(region->primalCountParam);
+
+ IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[getUpdateBlock(region)]);
+ builder.setInsertBefore(primalUpdateBlock->getTerminator());
- auto counterVal = builder.emitLoad(region->primalCountVar);
auto incCounterVal = builder.emitAdd(
builder.getIntType(),
- counterVal,
+ region->primalCountParam,
builder.getIntValue(builder.getIntType(), 1));
+ builder.markInstAsPrimal(incCounterVal);
+
+ auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, primalUpdateBlock, incCounterVal);
- auto incStore = builder.emitStore(region->primalCountVar, incCounterVal);
+ SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
- builder.addLoopCounterDecoration(counterVal);
- builder.addLoopCounterDecoration(incCounterVal);
- builder.addLoopCounterDecoration(incStore);
+ IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->breakBlock]);
+ builder.setInsertBefore(primalBreakBlock->getTerminator());
+
+ builder.emitStore(region->primalCountLastVar, region->primalCountParam);
}
{
- IRBlock* firstLoopBlock = getFirstLoopBodyBlock(region);
- auto diffFirstLoopBlock = as<IRBlock>(diffMap[firstLoopBlock]);
+ IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->initBlock]);
+
+ auto diffCondBlock = as<IRUnconditionalBranch>(
+ diffInitBlock->getTerminator())->getTargetBlock();
+ builder.setInsertBefore(diffCondBlock->getTerminator());
- builder.setInsertBefore(diffFirstLoopBlock->getTerminator());
+ auto phiCounterArgLoopEntryIndex = addPhiOutputArg(
+ &builder,
+ diffInitBlock,
+ builder.getIntValue(builder.getIntType(), 0));
+
+ region->diffCountParam = addPhiInputParam(
+ &builder,
+ diffCondBlock,
+ builder.getIntType(),
+ phiCounterArgLoopEntryIndex);
+ builder.addLoopCounterDecoration(region->diffCountParam);
+ builder.markInstAsPrimal(region->diffCountParam);
+
+ IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[getUpdateBlock(region)]);
+ builder.setInsertBefore(diffUpdateBlock->getTerminator());
- auto counterVal = builder.emitLoad(region->diffCountVar);
- auto decCounterVal = builder.emitSub(
+ auto incCounterVal = builder.emitAdd(
builder.getIntType(),
- counterVal,
+ region->diffCountParam,
builder.getIntValue(builder.getIntType(), 1));
+ builder.markInstAsPrimal(incCounterVal);
- auto decStore = builder.emitStore(region->diffCountVar, decCounterVal);
+ auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, diffUpdateBlock, incCounterVal);
- // Mark insts as loop counter insts to avoid removing them.
- //
- builder.addLoopCounterDecoration(counterVal);
- builder.addLoopCounterDecoration(decCounterVal);
- builder.addLoopCounterDecoration(decStore);
+ SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
- // TODO:
- // This is another hack here to avoid the counter from going negative
- // (since they are not valid indices)
- //
- IRBlock* diffCondBlock = as<IRBlock>(diffMap[region->firstBlock]);
+ auto loopInst = as<IRLoop>(diffInitBlock->getTerminator());
- builder.setInsertBefore(diffCondBlock->getTerminator());
- IRInst* diffCounterVal = builder.emitLoad(region->diffCountVar);
- IRInst* diffCounterCmp = builder.emitIntrinsicInst(
- builder.getBoolType(),
- kIROp_Geq,
- 2,
- List<IRInst*>(
- diffCounterVal,
- builder.getIntValue(builder.getIntType(), 0)).getBuffer());
-
- as<IRIfElse>(diffCondBlock->getTerminator())->condition.set(diffCounterCmp);
+ builder.setInsertBefore(loopInst);
+
+ auto primalCounterLastVal = builder.emitLoad(region->primalCountLastVar);
+ builder.markInstAsPrimal(primalCounterLastVal);
+ builder.addPrimalValueAccessDecoration(primalCounterLastVal);
- builder.addLoopCounterDecoration(diffCounterVal);
- builder.addLoopCounterDecoration(diffCounterCmp);
+ builder.addLoopExitPrimalValueDecoration(loopInst, region->diffCountParam, primalCounterLastVal);
}
+ }
+ }
+ void tagNewParams(IRBuilder* builder, IRFunc* func)
+ {
+ for (auto block : func->getBlocks())
+ {
+ for (auto param = block->getFirstParam(); param; param = param->getNextParam())
+ if (!param->findDecoration<IRAutodiffInstDecoration>())
+ builder->markInstAsPrimal(param);
+ }
+ }
+
+ void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
+ {
+ if (as<IRParam>(inst))
+ {
+ SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
+ builder->setInsertBefore(as<IRBlock>(inst->getParent())->getFirstOrdinaryInst());
+ }
+ else
+ {
+ builder->setInsertBefore(inst);
+ }
+ }
+
+ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst)
+ {
+ if (as<IRParam>(inst))
+ {
+ SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
+ builder->setInsertBefore(as<IRBlock>(inst->getParent())->getFirstOrdinaryInst());
+ }
+ else
+ {
+ builder->setInsertAfter(inst);
}
}
@@ -520,10 +596,15 @@ struct DiffUnzipPass
auto storageVar = builder.emitVar(arrayType);
+ // TODO(sai) STOPPED HERE: For some reason, we still have a direct param access
+ // when trying to cover up the access to last value of loop counter.
+ // Maybe we need a different way to access this? (use a var)
+ // Special case?
+
// 3. Store current value into the array and replace uses with a load.
// TODO: If an index is missing, use the 'last' value of the primal index.
{
- builder.setInsertAfter(inst);
+ setInsertAfterOrdinaryInst(&builder, inst);
IRInst* storeAddr = storageVar;
IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
@@ -535,12 +616,12 @@ struct DiffUnzipPass
storeAddr = builder.emitElementAddress(
builder.getPtrType(currType),
storeAddr,
- builder.emitLoad(region->primalCountVar));
+ region->primalCountParam);
}
builder.emitStore(storeAddr, inst);
}
-
+
// 4. Replace uses in differential blocks with loads from the array.
List<IRInst*> instsToTag;
{
@@ -548,17 +629,20 @@ struct DiffUnzipPass
for (auto use = inst->firstUse; use; use = use->nextUse)
{
if (as<IRDecoration>(use->getUser()))
- continue;
+ {
+ if (!as<IRLoopExitPrimalValueDecoration>(use->getUser()))
+ continue;
+ }
- IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
+ IRBlock* useBlock = getBlock(use->getUser());
if (useBlock && isDifferentialInst(useBlock))
diffUses.add(use);
}
for (auto use : diffUses)
{
- IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
- builder.setInsertBefore(use->getUser());
+ IRBlock* useBlock = getBlock(use->getUser());
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
IRInst* loadAddr = storageVar;
IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
@@ -583,8 +667,7 @@ struct DiffUnzipPass
// If the use-block is under the same region, use the
// differential counter variable
//
- auto diffCounterCurrValue = builder.emitLoad(region->diffCountVar);
- instsToTag.add(diffCounterCurrValue);
+ auto diffCounterCurrValue = region->diffCountParam;
loadAddr = builder.emitElementAddress(
builder.getPtrType(currType),
@@ -596,7 +679,7 @@ struct DiffUnzipPass
// If the use-block is outside this region, use the
// last available value (by indexing with primal counter minus 1)
//
- auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar);
+ auto primalCounterCurrValue = builder.emitLoad(region->primalCountLastVar);
auto primalCounterLastValue = builder.emitSub(
primalCounterCurrValue->getDataType(),
primalCounterCurrValue,
@@ -621,11 +704,11 @@ struct DiffUnzipPass
}
}
- // TODO: Loop-counter is not really the right decoration..
- // replace with primal-inst when it's ready.
- //
for (auto instToTag : instsToTag)
- builder.addLoopCounterDecoration(instToTag);
+ {
+ builder.addPrimalValueAccessDecoration(instToTag);
+ builder.markInstAsPrimal(instToTag);
+ }
}
}
@@ -1306,11 +1389,6 @@ struct DiffUnzipPass
// Nothing should be left in the original block.
SLANG_ASSERT(block->getFirstChild() == block->getTerminator());
-
- // Branch from primal to differential block.
- // Functionally, the new blocks should produce the same output as the
- // old block.
- // primalBuilder.emitBranch(diffBlock);
}
};
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 1232cf50d..97cdb644e 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -463,6 +463,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_PrimalInstDecoration:
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
+ case kIROp_PrimalValueAccessDecoration:
case kIROp_BackwardDerivativeDecoration:
case kIROp_BackwardDerivativeIntermediateTypeDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 35877d680..f2107aa62 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -598,6 +598,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(LayoutDecoration, layout, 1, 0)
INST(LoopControlDecoration, loopControl, 1, 0)
INST(LoopMaxItersDecoration, loopMaxIters, 1, 0)
+ INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0)
INST(IntrinsicOpDecoration, intrinsicOp, 1, 0)
/* TargetSpecificDecoration */
INST(TargetDecoration, target, 1, 0)
@@ -769,6 +770,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
+ INST(PrimalValueAccessDecoration, primalValueAccessDecoration, 0, 0)
/* Auto-diff inst decorations */
/// Used by the auto-diff pass to mark insts that compute
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 0eef9cb43..fe20f17f5 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -683,6 +683,18 @@ struct IRLoopCounterDecoration : IRDecoration
IR_LEAF_ISA(LoopCounterDecoration)
};
+struct IRLoopExitPrimalValueDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_LoopExitPrimalValueDecoration
+ };
+ IR_LEAF_ISA(LoopExitPrimalValueDecoration)
+
+ IRInst* getTargetInst() { return getOperand(0); }
+ IRInst* getLoopExitValInst() { return getOperand(1); }
+};
+
struct IRAutodiffInstDecoration : IRDecoration
{
IR_PARENT_ISA(AutodiffInstDecoration)
@@ -712,7 +724,6 @@ struct IRPrimalInstDecoration : IRAutodiffInstDecoration
IR_LEAF_ISA(PrimalInstDecoration)
};
-
struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration
{
enum
@@ -726,6 +737,16 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration
IRType* getPairType() { return as<IRType>(getOperand(0)); }
};
+struct IRPrimalValueAccessDecoration : IRAutodiffInstDecoration
+{
+ enum
+ {
+ kOp = kIROp_PrimalValueAccessDecoration
+ };
+
+ IR_LEAF_ISA(PrimalValueAccessDecoration)
+};
+
struct IRPrimalValueStructKeyDecoration : IRDecoration
{
enum
@@ -3613,6 +3634,16 @@ public:
addDecoration(value, kIROp_LoopCounterDecoration);
}
+ void addLoopExitPrimalValueDecoration(IRInst* value, IRInst* primalInst, IRInst* exitValue)
+ {
+ addDecoration(value, kIROp_LoopExitPrimalValueDecoration, primalInst, exitValue);
+ }
+
+ void addPrimalValueAccessDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_PrimalValueAccessDecoration);
+ }
+
void markInstAsPrimal(IRInst* value)
{
addDecoration(value, kIROp_PrimalInstDecoration);
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index d8246edae..20a8d7d13 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -1056,7 +1056,10 @@ bool constructSSA(ConstructSSAContext* context)
// Figure out what variables we can promote to
// SSA temporaries.
- identifyPromotableVars(context);
+ if (!(context->promotableVars.getCount() > 0))
+ {
+ identifyPromotableVars(context);
+ }
// If none of the variables are promote-able,
// then we can exit without making any changes