summary | refs | log | tree | commit | diff | stats
path: root/source
diff options
context:
space:
mode:
author: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> 2023-02-09 17:40:20 -0500
committer: GitHub <noreply@github.com> 2023-02-09 17:40:20 -0500
commit: df02f3f50f977112ca1fbb148cd48ee41d560f41 (patch)
tree: 7732e8fec9f33aff9666b3710c7adb899788c4be /source
parent: d911e1bed9572664b1d0554feb3c7d1a2a880518 (diff)
Reverse-mode Loop Support (#2635)
* Full loop support now working. MaxItersAttr in progress
* Lookup table updates?
* Fixed the max iters decoration
* Minor fixes & remove superfluous code
* fixup warnings
* Revert "Lookup table updates?" This reverts commit 7d9b0793fb5239f31d1155776e846dcf1892d8d9.
* Update 07-autodiff.md
* Change maxiters to MaxIters
* Added asserts
* Update 07-autodiff.md
Diffstat (limited to 'source')
-rw-r--r-- source/slang/core.meta.slang 3
-rw-r--r-- source/slang/slang-ast-modifier.h 8
-rw-r--r-- source/slang/slang-check-modifier.cpp 11
-rw-r--r-- source/slang/slang-ir-autodiff-fwd.cpp 67
-rw-r--r-- source/slang/slang-ir-autodiff-fwd.h 4
-rw-r--r-- source/slang/slang-ir-autodiff-rev.cpp 2
-rw-r--r-- source/slang/slang-ir-autodiff-transcriber-base.cpp 38
-rw-r--r-- source/slang/slang-ir-autodiff-transpose.h 94
-rw-r--r-- source/slang/slang-ir-autodiff-unzip.h 279
-rw-r--r-- source/slang/slang-ir-autodiff.cpp 1
-rw-r--r-- source/slang/slang-ir-inst-defs.h 11
-rw-r--r-- source/slang/slang-ir-insts.h 62
-rw-r--r-- source/slang/slang-lower-to-ir.cpp 4
13 files changed, 466 insertions, 118 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 31dd5ed29..533713016 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2823,6 +2823,9 @@ attribute_syntax [fastopt] : FastOptAttribute;
__attributeTarget(LoopStmt)
attribute_syntax [allow_uav_condition] : AllowUAVConditionAttribute;
+__attributeTarget(LoopStmt)
+attribute_syntax [MaxIters(count)] : MaxItersAttribute;
+
__attributeTarget(IfStmt)
attribute_syntax [flatten] : FlattenAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 666ca77ea..42b79ca4a 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -620,6 +620,14 @@ class UnrollAttribute : public Attribute
IntegerLiteralValue getCount();
};
+// A `[MaxIters(count)]` attribute on a loop statement; `value` holds the constant `count`.
+class MaxItersAttribute : public Attribute
+{
+ SLANG_AST_CLASS(MaxItersAttribute)
+
+ int32_t value = 0;
+};
+
class LoopAttribute : public Attribute
{
SLANG_AST_CLASS(LoopAttribute)
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index e73f04301..9f3e79978 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -507,6 +507,17 @@ namespace Slang
// as 1 arg if nothing is specified)
SLANG_ASSERT(attr->args.getCount() == 1);
}
+ else if (auto maxItersAttrs = as<MaxItersAttribute>(attr))
+ {
+ if (auto cint = checkConstantIntVal(attr->args[0]))
+ {
+ maxItersAttrs->value = (int32_t) cint->value;
+ }
+ else
+ {
+ getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1);
+ }
+ }
else if (auto userDefAttr = as<UserDefinedAttribute>(attr))
{
// check arguments against attribute parameters defined in attribClassDecl
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 04acad435..fca34f9a2 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -149,6 +149,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
builder->markInstAsDifferential(diffSub, resultType);
auto diffMul = builder->emitMul(resultType, primalRight, primalRight);
+ builder->markInstAsPrimal(diffMul);
auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
builder->markInstAsDifferential(diffDiv, resultType);
@@ -881,6 +882,29 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
return InstPair(primalUpdateField, diffUpdateElement);
}
+List<IRInst*> ForwardDiffTranscriber::transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs)
+{
+ // Grab the differentials for any phi nodes.
+ List<IRInst*> newArgs;
+ for (auto origArg : origPhiArgs)
+ {
+ auto primalArg = lookupPrimalInst(builder, origArg);
+ newArgs.add(primalArg);
+
+ if (differentiateType(builder, origArg->getDataType()))
+ {
+ auto diffArg = lookupDiffInst(origArg, nullptr);
+ if (diffArg)
+ newArgs.add(diffArg);
+ else
+ newArgs.add(
+ getDifferentialZeroOfType(builder, origArg->getDataType()));
+ }
+ }
+
+ return newArgs;
+}
+
InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
{
// The loop comes with three blocks.. we just need to transcribe each one
@@ -902,13 +926,14 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig
diffLoopOperands.add(diffTargetBlock);
diffLoopOperands.add(diffBreakBlock);
diffLoopOperands.add(diffContinueBlock);
-
- // If there are any other operands, use their primal versions.
+
+ List<IRInst*> phiArgs;
for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
- {
- auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii));
- diffLoopOperands.add(primalOperand);
- }
+ phiArgs.add(origLoop->getOperand(ii));
+
+ auto newPhiArgs = transcribePhiArgs(builder, phiArgs);
+ for (auto newArg : newPhiArgs)
+ diffLoopOperands.add(newArg);
IRInst* diffLoop = builder->emitIntrinsicInst(
nullptr,
@@ -917,6 +942,9 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig
diffLoopOperands.getBuffer());
builder->markInstAsMixedDifferential(diffLoop);
+ if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>())
+ builder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters());
+
return InstPair(diffLoop, diffLoop);
}
@@ -1211,6 +1239,28 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
return diffFunc;
}
+void ForwardDiffTranscriber::checkAutodiffInstDecorations(IRFunc* fwdFunc)
+{
+ for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ for (auto inst = block->getFirstOrdinaryInst(); inst; inst = inst->getNextInst())
+ {
+ // TODO: Special case, not sure why these insts show up
+ if (as<IRUndefined>(inst)) continue;
+
+ List<IRDecoration*> decorations;
+ for (auto decoration : inst->getDecorations())
+ {
+ if (as<IRAutodiffInstDecoration>(decoration))
+ decorations.add(decoration);
+ }
+
+ // Must have _exactly_ one autodiff tag.
+ SLANG_ASSERT(decorations.getCount() == 1);
+ }
+ }
+}
+
// Transcribe a function definition.
InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
{
@@ -1266,6 +1316,10 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
}
}
}
+
+#if _DEBUG
+ checkAutodiffInstDecorations(diffFunc);
+#endif
return InstPair(primalFunc, diffFunc);
}
@@ -1310,7 +1364,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
case kIROp_MatrixReshape:
- case kIROp_VectorReshape:
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_MakeVectorFromScalar:
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 260b0a433..e80b25754 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -94,6 +94,10 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
// Transcribe a function without marking the result as a decoration.
IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
+ List<IRInst*> transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs);
+
+ void checkAutodiffInstDecorations(IRFunc* fwdFunc);
+
// Create an empty func to represent the transcribed func of `origFunc`.
virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 702f9819a..20090ca42 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -765,7 +765,7 @@ namespace Slang
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
// derivative of the return value.
- DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr};
+ DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam };
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
eliminateDeadCode(diffPropagateFunc);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 31a3072c0..10a734d65 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
{
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
+
+
+ if (pair.primal != pair.differential &&
+ !pair.primal->findDecoration<IRAutodiffInstDecoration>() &&
+ !as<IRConstant>(pair.primal))
+ {
+ builder->markInstAsPrimal(pair.primal);
+ }
+
if (pair.differential)
{
switch (pair.differential->getOp())
@@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
}
- // Tag the differential inst using a decoration (if it doesn't have one)
- if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
- !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() &&
- !as<IRConstant>(pair.differential))
+ // Automatically tag the primal and differential results
+ // if they haven't already been handled by the
+ // code.
+ //
+ if (pair.primal != pair.differential)
+ {
+ if (!pair.differential->findDecoration<IRAutodiffInstDecoration>()
+ && !as<IRConstant>(pair.differential))
+ {
+ auto primalType = as<IRType>(pair.primal->getDataType());
+ builder->markInstAsDifferential(pair.differential, primalType);
+ }
+ }
+ else
{
- // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
- // instead.
- //
- auto primalType = as<IRType>(pair.primal->getDataType());
- builder->markInstAsDifferential(pair.differential, primalType);
+ if (!pair.primal->findDecoration<IRAutodiffInstDecoration>()
+ && !as<IRConstant>(pair.differential))
+ {
+ auto mixedType = as<IRType>(pair.primal->getDataType());
+ builder->markInstAsMixedDifferential(pair.primal, mixedType);
+ }
}
break;
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index d9b28ea3c..2953c6206 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -78,11 +78,6 @@ struct DiffTransposePass
// of the *output* of the function.
//
IRInst* dOutInst;
-
- // Mapping between *primal* insts in the forward-mode function, and the
- // reverse-mode function
- //
- Dictionary<IRInst*, IRInst*>* primalsMap;
};
struct PendingBlockTerminatorEntry
@@ -353,6 +348,13 @@ struct DiffTransposePass
getPhiGrads(firstLoopBlock).getBuffer());
}
+ auto phiGrads = getPhiGrads(condBlock);
+ if (phiGrads.getCount() > 0)
+ {
+ revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiGrads);
+ revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiGrads);
+ }
+
// Emit condition into the new cond block.
builder.setInsertInto(revCondBlock);
builder.emitIfElse(
@@ -533,8 +535,6 @@ struct DiffTransposePass
//
firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]];
- IRInst* retVal = nullptr;
-
for (auto block : workList)
{
// Set dOutParameter as the transpose gradient for the return inst, if any.
@@ -543,7 +543,6 @@ struct DiffTransposePass
if (auto returnInst = as<IRReturn>(block->getTerminator()))
{
this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
- retVal = returnInst->getVal();
}
}
@@ -572,6 +571,11 @@ struct DiffTransposePass
auto terminalPrimalBlock = terminalPrimalBlocks[0];
auto firstRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]);
+ auto returnDecoration =
+ terminalPrimalBlock->getTerminator()->findDecoration<IRBackwardDerivativePrimalReturnDecoration>();
+ SLANG_ASSERT(returnDecoration);
+ auto retVal = returnDecoration->getBackwardDerivativePrimalReturnValue();
+
terminalPrimalBlock->getTerminator()->removeAndDeallocate();
IRBuilder subBuilder(builder.getSharedBuilder());
@@ -582,15 +586,6 @@ struct DiffTransposePass
auto branch = subBuilder.emitBranch(firstRevBlock);
- if (!retVal || retVal->getOp() == kIROp_VoidLit)
- {
- retVal = subBuilder.getVoidValue();
- }
- else
- {
- auto makePair = cast<IRMakeDifferentialPair>(retVal);
- retVal = makePair->getPrimalValue();
- }
subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
@@ -610,6 +605,25 @@ struct DiffTransposePass
}
}
+ IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst)
+ {
+ if (auto accVar = getOrCreateAccumulatorVar(fwdInst))
+ {
+ auto gradValue = builder->emitLoad(accVar);
+ builder->emitStore(
+ accVar,
+ emitDZeroOfDiffInstType(
+ builder,
+ tryGetPrimalTypeFromDiffInst(fwdInst)));
+
+ return gradValue;
+ }
+ else
+ {
+ return nullptr;
+ }
+ }
+
// Fetch or create a gradient accumulator var
// corresponding to a inst. These are used to
// accumulate gradients across blocks.
@@ -688,6 +702,30 @@ struct DiffTransposePass
}
}
+ // Some special instructions simply need to be copied over.
+ // These do not deal with differentials.
+ // TODO: This will not work if there are any differential
+ // insts that rely on loop counter vars having a specific
+ // value.
+ // The solution is to have primal insts appearing in
+ // differential blocks be in their own special blocks that are
+ // ignored entirely, rather than dealing with them one inst
+ // at a time.
+ //
+ for (IRInst* child = fwdBlock->getFirstChild(); child;)
+ {
+ auto nextChild = child->getNextInst();
+
+ if (child->findDecoration<IRLoopCounterDecoration>())
+ {
+ // Loop counter insts should not have any gradients.
+ SLANG_ASSERT(!hasRevGradients(child));
+ child->insertAtEnd(revBlock);
+ }
+
+ child = nextChild;
+ }
+
// Move pointer & reference insts to the top of the reverse-mode block.
List<IRInst*> nonValueInsts;
for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
@@ -719,6 +757,7 @@ struct DiffTransposePass
if (as<IRDecoration>(child) || as<IRParam>(child))
continue;
+
transposeInst(&builder, child);
}
@@ -744,10 +783,10 @@ struct DiffTransposePass
//
if (isInstUsedOutsideParentBlock(param))
{
- auto accVar = getOrCreateAccumulatorVar(param);
+ auto accGradient = extractAccumulatorVarGradient(&builder, param);
addRevGradientForFwdInst(
param,
- RevGradient(param, builder.emitLoad(accVar), nullptr));
+ RevGradient(param, accGradient, nullptr));
}
if (hasRevGradients(param))
@@ -839,15 +878,6 @@ struct DiffTransposePass
break;
}
- // Some special instructions simply need to be copied over.
- // These do not deal with differentials.
- //
- if (inst->findDecoration<IRLoopCounterDecoration>())
- {
- inst->insertAtEnd(builder->getBlock());
- return;
- }
-
// Look for gradient entries for this inst.
List<RevGradient> gradients;
if (hasRevGradients(inst))
@@ -898,9 +928,9 @@ struct DiffTransposePass
//
if (isInstUsedOutsideParentBlock(inst) && !as<IRLoad>(inst))
{
- auto accVar = getOrCreateAccumulatorVar(inst);
+ auto accGradient = extractAccumulatorVarGradient(builder, inst);
gradients.add(
- RevGradient(inst, builder->emitLoad(accVar), nullptr));
+ RevGradient(inst, accGradient, nullptr));
}
// Emit the aggregate of all the gradients here.
@@ -2399,8 +2429,6 @@ struct DiffTransposePass
Dictionary<IRInst*, IRVar*> revAccumulatorVarMap;
- Dictionary<IRInst*, IRInst*>* primalsMap;
-
List<IRInst*> usedPtrs;
Dictionary<IRBlock*, IRBlock*> revBlockMap;
@@ -2412,6 +2440,8 @@ struct DiffTransposePass
List<PendingBlockTerminatorEntry> pendingBlocks;
Dictionary<IRBlock*, List<IRInst*>> phiGradsMap;
+
+ Dictionary<IRBlock*, IRBlock*> initializerBlockMap;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 057ff53c4..1a85ea6a4 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -232,6 +232,24 @@ struct DiffUnzipPass
//
lowerIndexedRegions();
+ // Copy regions from fwd-block to their split blocks
+ // to make it easier to do lookups.
+ //
+ {
+ List<IRBlock*> workList;
+ for (auto blockRegionPair : indexRegionMap)
+ {
+ IRBlock* block = blockRegionPair.Key;
+ workList.add(block);
+ }
+
+ for (auto block : workList)
+ {
+ indexRegionMap[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap[block];
+ indexRegionMap[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap[block];
+ }
+ }
+
// Process intermediate insts in indexed blocks
// into array loads/stores.
//
@@ -262,19 +280,44 @@ struct DiffUnzipPass
IRBlock* getUpdateBlock(IndexedRegion* region)
{
+ // TODO: What if the 'continue' region has multiple
+ // blocks?
+ // We ideally want the _last_ block before control loops back.
+ //
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
+ region->continueBlock->getTerminator())->getTargetBlock() == region->firstBlock);
+
return region->continueBlock;
}
+
+ IRBlock* getFirstLoopBodyBlock(IndexedRegion* region)
+ {
+ // Grab the 'condition' block.
+ auto condBlock = region->firstBlock;
+
+ SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator()));
+
+ return as<IRIfElse>(condBlock->getTerminator())->getTrueBlock();
+ }
void tryInferMaxIndex(IndexedRegion* region)
{
if (region->status != IndexedRegion::CountStatus::Unresolved)
return;
+
+ auto loop = as<IRLoop>(region->initBlock->getTerminator());
- // We're going to fix this at a some random number
- // for now, and then add some basic inference + user-defined decoration
- //
- region->maxIters = 5;
- region->status = IndexedRegion::CountStatus::Static;
+ if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>())
+ {
+ region->maxIters = (Count) maxItersDecoration->getMaxIters();
+ region->status = IndexedRegion::CountStatus::Static;
+ }
+
+ if (region->status == IndexedRegion::CountStatus::Unresolved)
+ {
+ SLANG_UNEXPECTED("Could not resolve max iters \
+ for loop appearing in reverse-mode");
+ }
}
// Make a primal value *available* to the differential block.
@@ -297,22 +340,49 @@ struct DiffUnzipPass
for (auto region : indexRegions)
{
- IRBlock* initializerBlock = getInitializerBlock(region);
+ //IRBlock* initializerBlock = getInitializerBlock(region);
+ IRBlock* breakBlock = region->breakBlock;
// Grab first primal block.
- auto firstPrimalBlock = primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()];
+ IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]);
// Make variable in the top-most block (so it's visible to diff blocks)
- builder.setInsertInto(firstPrimalBlock);
- region->primalCountVar = builder.emitVar(builder.getUIntType());
-
- // Make another variable in the diff block initialized to the
- // final value of the primal counter.
+ builder.setInsertBefore(firstPrimalBlock->getTerminator());
+ region->primalCountVar = builder.emitVar(builder.getIntType());
+ builder.emitStore(
+ region->primalCountVar,
+ builder.getIntValue(builder.getIntType(), 0));
+
+ // NOTE: This is a hacky shortcut we're taking here.
+ // Technically the unzip pass should not affect the
+ // correctness (it must still compute the proper fwd-mode derivative)
+ // However, we're currently making the loop counter go backwards to
+ // make it easier on the transposition pass, so the output from
+ // the unzip pass is neither fwd-mode nor rev-mode until the transposition
+ // step is complete.
+ //
+ // TODO: Ideally this needs to be replaced with a small inversion step
+ // within the transposition pass.
+ //
+ // Emit the diff counter into the diff *break* block (
+ // which we're praying turns into the reverse initializer block)
+ // initialized to the final value of the primal counter.
//
- builder.setInsertInto(diffMap[initializerBlock]);
- auto primalCounterValue = builder.emitLoad(region->primalCountVar);
- region->diffCountVar = builder.emitVar(builder.getUIntType());
- builder.emitStore(region->diffCountVar, primalCounterValue);
+ builder.setInsertBefore(as<IRBlock>(diffMap[breakBlock])->getTerminator());
+ //auto primalCounterValue = builder.emitLoad(region->primalCountVar);
+ auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar);
+ auto primalCounterLastValue = builder.emitSub(
+ primalCounterCurrValue->getDataType(),
+ primalCounterCurrValue,
+ builder.getIntValue(builder.getIntType(), 1));
+
+ region->diffCountVar = builder.emitVar(builder.getIntType());
+ auto diffCountInit = builder.emitStore(region->diffCountVar, primalCounterLastValue);
+
+ builder.addLoopCounterDecoration(diffCountInit);
+ builder.addLoopCounterDecoration(region->diffCountVar);
+ builder.addLoopCounterDecoration(primalCounterCurrValue);
+ builder.addLoopCounterDecoration(primalCounterLastValue);
IRBlock* updateBlock = getUpdateBlock(region);
@@ -324,9 +394,9 @@ struct DiffUnzipPass
auto counterVal = builder.emitLoad(region->primalCountVar);
auto incCounterVal = builder.emitAdd(
- builder.getUIntType(),
+ builder.getIntType(),
counterVal,
- builder.getIntValue(builder.getUIntType(), 1));
+ builder.getIntValue(builder.getIntType(), 1));
auto incStore = builder.emitStore(region->primalCountVar, incCounterVal);
@@ -336,25 +406,16 @@ struct DiffUnzipPass
}
{
- // NOTE: This is a hacky shortcut we're taking here.
- // Technically the unzip pass should not affect the
- // correctness (it must still compute the proper fwd-mode derivative)
- // However, we're currently making the loop counter go backwards to
- // make it easier on the transposition pass, so the output from
- // the unzip pass is neither fwd-mode or rev-mode until the transposition
- // step is complete.
- //
- // TODO: Ideally this needs to be replaced with a small inversion step
- // within the transposition pass.
- //
+ IRBlock* firstLoopBlock = getFirstLoopBodyBlock(region);
+ auto diffFirstLoopBlock = as<IRBlock>(diffMap[firstLoopBlock]);
- builder.setInsertBefore(as<IRBlock>(diffMap[updateBlock])->getTerminator());
+ builder.setInsertBefore(diffFirstLoopBlock->getTerminator());
auto counterVal = builder.emitLoad(region->diffCountVar);
auto decCounterVal = builder.emitSub(
- builder.getUIntType(),
+ builder.getIntType(),
counterVal,
- builder.getIntValue(builder.getUIntType(), 0));
+ builder.getIntValue(builder.getIntType(), 1));
auto decStore = builder.emitStore(region->diffCountVar, decCounterVal);
@@ -363,6 +424,27 @@ struct DiffUnzipPass
builder.addLoopCounterDecoration(counterVal);
builder.addLoopCounterDecoration(decCounterVal);
builder.addLoopCounterDecoration(decStore);
+
+ // TODO:
+ // This is another hack here to avoid the counter from going negative
+ // (since they are not valid indices)
+ //
+ IRBlock* diffCondBlock = as<IRBlock>(diffMap[region->firstBlock]);
+
+ builder.setInsertBefore(diffCondBlock->getTerminator());
+ IRInst* diffCounterVal = builder.emitLoad(region->diffCountVar);
+ IRInst* diffCounterCmp = builder.emitIntrinsicInst(
+ builder.getBoolType(),
+ kIROp_Geq,
+ 2,
+ List<IRInst*>(
+ diffCounterVal,
+ builder.getIntValue(builder.getIntType(), 0)).getBuffer());
+
+ as<IRIfElse>(diffCondBlock->getTerminator())->condition.set(diffCounterCmp);
+
+ builder.addLoopCounterDecoration(diffCounterVal);
+ builder.addLoopCounterDecoration(diffCounterCmp);
}
}
@@ -394,6 +476,7 @@ struct DiffUnzipPass
for (; region; region = region->parent)
regions.add(region);
}
+
for (auto inst : primalInsts)
{
@@ -407,6 +490,7 @@ struct DiffUnzipPass
if (isDifferentialInst(useBlock))
{
shouldStore = true;
+ break;
}
}
@@ -439,52 +523,111 @@ struct DiffUnzipPass
auto storageVar = builder.emitVar(arrayType);
// 3. Store current value into the array and replace uses with a load.
+ // TODO: If an index is missing, use the 'last' value of the primal index.
{
builder.setInsertAfter(inst);
IRInst* storeAddr = storageVar;
- IRType* currType = storageVar->getDataType();
+ IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
for (auto region : regions)
{
currType = as<IRArrayType>(currType)->getElementType();
storeAddr = builder.emitElementAddress(
- currType,
+ builder.getPtrType(currType),
storeAddr,
- region->primalCountVar);
+ builder.emitLoad(region->primalCountVar));
}
builder.emitStore(storeAddr, inst);
}
// 4. Replace uses in differential blocks with loads from the array.
+ List<IRInst*> instsToTag;
{
+ List<IRUse*> diffUses;
for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRDecoration>(use->getUser()))
+ continue;
+
+ IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
+ if (useBlock && isDifferentialInst(useBlock))
+ diffUses.add(use);
+ }
+
+ for (auto use : diffUses)
{
IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
+ builder.setInsertBefore(use->getUser());
+
+ IRInst* loadAddr = storageVar;
+ IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
- if (isDifferentialInst(useBlock))
+ // Enumerate use block regions.
+ // TODO: Probably a good idea to do this ahead of time for
+ // all blocks.
+ //
+ List<IndexedRegion*> useBlockRegions;
{
- builder.setInsertBefore(use->getUser());
+ IndexedRegion* region = indexRegionMap.ContainsKey(useBlock) ?
+ (IndexedRegion*)indexRegionMap[useBlock] : nullptr;
+ for (; region; region = region->parent)
+ useBlockRegions.add(region);
+ }
- IRInst* loadAddr = storageVar;
- IRType* currType = storageVar->getDataType();
+ for (auto region : regions)
+ {
+ currType = as<IRArrayType>(currType)->getElementType();
+ if (useBlockRegions.contains(region))
+ {
+ // If the use-block is under the same region, use the
+ // differential counter variable
+ //
+ auto diffCounterCurrValue = builder.emitLoad(region->diffCountVar);
+ instsToTag.add(diffCounterCurrValue);
- for (auto region : regions)
+ loadAddr = builder.emitElementAddress(
+ builder.getPtrType(currType),
+ loadAddr,
+ diffCounterCurrValue);
+ }
+ else
{
- currType = as<IRArrayType>(currType)->getElementType();
+ // If the use-block is outside this region, use the
+ // last available value (by indexing with primal counter minus 1)
+ //
+ auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar);
+ auto primalCounterLastValue = builder.emitSub(
+ primalCounterCurrValue->getDataType(),
+ primalCounterCurrValue,
+ builder.getIntValue(builder.getIntType(), 1));
+
+ instsToTag.add(primalCounterCurrValue);
+ instsToTag.add(primalCounterLastValue);
loadAddr = builder.emitElementAddress(
- currType,
+ builder.getPtrType(currType),
loadAddr,
- region->diffCountVar);
+ primalCounterLastValue);
}
- use->set(builder.emitLoad(loadAddr));
+ instsToTag.add(loadAddr);
}
+
+ auto loadedValue = builder.emitLoad(loadAddr);
+ instsToTag.add(loadedValue);
+
+ use->set(loadedValue);
}
}
+
+ // TODO: Loop-counter is not really the right decoration..
+ // replace with primal-inst when it's ready.
+ //
+ for (auto instToTag : instsToTag)
+ builder.addLoopCounterDecoration(instToTag);
}
}
@@ -710,7 +853,11 @@ struct DiffUnzipPass
// Check that we have an unambiguous 'first' differential block.
SLANG_ASSERT(firstDiffBlock);
+
auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
+ primalBuilder->addBackwardDerivativePrimalReturnDecoration(
+ primalBranch, lookupPrimalInst(mixedReturn->getVal()));
+
auto pairVal = diffBuilder->emitMakeDifferentialPair(
pairType,
lookupPrimalInst(mixedReturn->getVal()),
@@ -726,6 +873,9 @@ struct DiffUnzipPass
{
// If return value is not differentiable, just turn it into a trivial branch.
auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
+ primalBuilder->addBackwardDerivativePrimalReturnDecoration(
+ primalBranch, primalBuilder->getVoidValue());
+
auto returnInst = diffBuilder->emitReturn();
diffBuilder->markInstAsDifferential(returnInst, nullptr);
return InstPair(primalBranch, returnInst);
@@ -903,15 +1053,38 @@ struct DiffUnzipPass
// Push a new index.
addNewIndex(mixedLoop);
- return InstPair(
- primalBuilder->emitLoop(
- as<IRBlock>(primalMap[nextBlock]),
- as<IRBlock>(primalMap[breakBlock]),
- as<IRBlock>(primalMap[continueBlock])),
- diffBuilder->emitLoop(
- as<IRBlock>(diffMap[nextBlock]),
- as<IRBlock>(diffMap[breakBlock]),
- as<IRBlock>(diffMap[continueBlock])));
+ // Split args.
+ List<IRInst*> primalArgs;
+ List<IRInst*> diffArgs;
+ for (UIndex ii = 0; ii < mixedLoop->getArgCount(); ii++)
+ {
+ if (isDifferentialInst(mixedLoop->getArg(ii)))
+ diffArgs.add(mixedLoop->getArg(ii));
+ else
+ primalArgs.add(mixedLoop->getArg(ii));
+ }
+
+ auto primalLoop = primalBuilder->emitLoop(
+ as<IRBlock>(primalMap[nextBlock]),
+ as<IRBlock>(primalMap[breakBlock]),
+ as<IRBlock>(primalMap[continueBlock]),
+ primalArgs.getCount(),
+ primalArgs.getBuffer());
+
+ auto diffLoop = diffBuilder->emitLoop(
+ as<IRBlock>(diffMap[nextBlock]),
+ as<IRBlock>(diffMap[breakBlock]),
+ as<IRBlock>(diffMap[continueBlock]),
+ diffArgs.getCount(),
+ diffArgs.getBuffer());
+
+ if (auto maxItersDecoration = mixedLoop->findDecoration<IRLoopMaxItersDecoration>())
+ {
+ primalBuilder->addLoopMaxItersDecoration(primalLoop, maxItersDecoration->getMaxIters());
+ diffBuilder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters());
+ }
+
+ return InstPair(primalLoop, diffLoop);
}
InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst)
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f38bdfdbd..2ce5a48f7 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -460,6 +460,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_ForwardDerivativeDecoration:
case kIROp_DerivativeMemberDecoration:
case kIROp_DifferentiableTypeDictionaryDecoration:
+ case kIROp_PrimalInstDecoration:
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
case kIROp_BackwardDerivativeDecoration:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 26a92a17a..e627c575d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -598,6 +598,7 @@ INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0)
INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(LayoutDecoration, layout, 1, 0)
INST(LoopControlDecoration, loopControl, 1, 0)
+ INST(LoopMaxItersDecoration, loopMaxIters, 1, 0)
INST(IntrinsicOpDecoration, intrinsicOp, 1, 0)
/* TargetSpecificDecoration */
INST(TargetDecoration, target, 1, 0)
@@ -767,13 +768,19 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
+ /* Auto-diff inst decorations */
+ /// Used by the auto-diff pass to mark insts that compute
+ /// a primal value.
+ INST(PrimalInstDecoration, primalInstDecoration, 0, 0)
+
/// Used by the auto-diff pass to mark insts that compute
/// a differential value.
- INST(DifferentialInstDecoration, diffInstDecoration, 1, 0)
+ INST(DifferentialInstDecoration, diffInstDecoration, 1, 0)
/// Used by the auto-diff pass to mark insts that compute
/// BOTH a differential and a primal value.
- INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0)
+ INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0)
+ INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, MixedDifferentialInstDecoration)
/// Used by the auto-diff pass to mark insts whose result is stored
/// in an intermediary struct for reuse in backward propagation phase.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fad20e900..2453b56a7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -66,6 +66,14 @@ struct IRLoopControlDecoration : IRDecoration
}
};
+/// Decoration whose single operand is the iteration bound supplied by a
+/// `[MaxIters(count)]` loop attribute (the lowering code passes the
+/// attribute's value when it attaches this decoration).
+struct IRLoopMaxItersDecoration : IRDecoration
+{
+ enum { kOp = kIROp_LoopMaxItersDecoration };
+ IR_LEAF_ISA(LoopMaxItersDecoration)
+
+ /// Operand 0, viewed as a constant inst.
+ IRConstant* getMaxItersInst() { return cast<IRConstant>(getOperand(0)); }
+
+ /// Integer value stored in operand 0.
+ /// Uses `cast` (as getMaxItersInst() does) instead of `as`, so the result
+ /// of a failed `as` lookup is never dereferenced unchecked.
+ /// NOTE(review): assumes `as<>` yields null on a type mismatch while
+ /// `cast<>` asserts the type -- confirm against the IR casting helpers.
+ IRIntegerValue getMaxIters() { return cast<IRIntLit>(getOperand(0))->getValue(); }
+};
struct IRTargetSpecificDecoration : IRDecoration
{
@@ -672,7 +680,12 @@ struct IRLoopCounterDecoration : IRDecoration
IR_LEAF_ISA(LoopCounterDecoration)
};
-struct IRDifferentialInstDecoration : IRDecoration
+/// Common parent of the decorations the auto-diff pass uses to tag an inst as
+/// primal, differential, or mixed. The op range it matches is declared by
+/// INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration,
+/// MixedDifferentialInstDecoration) in slang-ir-inst-defs.h.
+struct IRAutodiffInstDecoration : IRDecoration
+{
+ IR_PARENT_ISA(AutodiffInstDecoration)
+};
+
+struct IRDifferentialInstDecoration : IRAutodiffInstDecoration
{
enum
{
@@ -686,41 +699,52 @@ struct IRDifferentialInstDecoration : IRDecoration
IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); }
};
-struct IRPrimalValueStructKeyDecoration : IRDecoration
+struct IRPrimalInstDecoration : IRAutodiffInstDecoration
{
enum
{
- kOp = kIROp_PrimalValueStructKeyDecoration
+ kOp = kIROp_PrimalInstDecoration
};
- IR_LEAF_ISA(PrimalValueStructKeyDecoration)
+ IR_LEAF_ISA(PrimalInstDecoration)
+};
- IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); }
+
+struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration
+{
+ enum
+ {
+ kOp = kIROp_MixedDifferentialInstDecoration
+ };
+
+ IRUse pairType;
+ IR_LEAF_ISA(MixedDifferentialInstDecoration)
+
+ IRType* getPairType() { return as<IRType>(getOperand(0)); }
};
-struct IRPrimalElementTypeDecoration : IRDecoration
+struct IRPrimalValueStructKeyDecoration : IRDecoration
{
enum
{
- kOp = kIROp_PrimalElementTypeDecoration
+ kOp = kIROp_PrimalValueStructKeyDecoration
};
- IR_LEAF_ISA(PrimalElementTypeDecoration)
+ IR_LEAF_ISA(PrimalValueStructKeyDecoration)
- IRInst* getPrimalElementType() { return getOperand(0); }
+ IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); }
};
-struct IRMixedDifferentialInstDecoration : IRDecoration
+struct IRPrimalElementTypeDecoration : IRDecoration
{
enum
{
- kOp = kIROp_MixedDifferentialInstDecoration
+ kOp = kIROp_PrimalElementTypeDecoration
};
- IRUse pairType;
- IR_LEAF_ISA(MixedDifferentialInstDecoration)
+ IR_LEAF_ISA(PrimalElementTypeDecoration)
- IRType* getPairType() { return as<IRType>(getOperand(0)); }
+ IRInst* getPrimalElementType() { return getOperand(0); }
};
struct IRBackwardDifferentiableDecoration : IRDecoration
@@ -3519,6 +3543,11 @@ public:
addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode)));
}
+ /// Attach a LoopMaxItersDecoration to `value`; `iters` is stored as an
+ /// integer constant of the builder's int type and becomes the
+ /// decoration's only operand.
+ void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters)
+ {
+ auto maxItersOperand = getIntValue(getIntType(), iters);
+ addDecoration(value, kIROp_LoopMaxItersDecoration, maxItersOperand);
+ }
+
void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
{
addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index));
@@ -3651,6 +3680,11 @@ public:
addDecoration(value, kIROp_LoopCounterDecoration);
}
+ /// Tag `value` with an operand-less PrimalInstDecoration, the marker the
+ /// auto-diff pass uses for insts that compute a primal value.
+ void markInstAsPrimal(IRInst* value)
+ {
+ addDecoration(value, kIROp_PrimalInstDecoration);
+ }
+
void markInstAsDifferential(IRInst* value)
{
addDecoration(value, kIROp_DifferentialInstDecoration, nullptr);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 8377246fb..74f06557d 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4845,6 +4845,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
{
getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Loop);
}
+ else if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() )
+ {
+ getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value);
+ }
// TODO: handle other cases here
}