summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-01-04 23:40:13 +0530
committerGitHub <noreply@github.com>2023-01-04 10:10:13 -0800
commit7f64b2a9e3eb7aea13de550bd24c1aea7787c94b (patch)
tree40afc50c9fb227b8728487403d3f9b712a1509b2 /source
parente8f977a00f5d131ec2d51d2a026d6452e8f762f0 (diff)
Multi-block reverse-mode autodiff (#2576)
* Initial multi-block implementation * Implemented multi-block reverse-mode (without loops) * Added logic to remove block-level decorations to avoid confusing IR simplification passes * Fixed issues with block-level decorations during IR simplification by removing them prior to simplification. Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp23
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp12
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h631
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp95
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h181
-rw-r--r--source/slang/slang-ir-autodiff.cpp51
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-ir.cpp4
9 files changed, 847 insertions, 159 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index dbf79b5f8..c245701df 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -519,6 +519,11 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
// block to compute *both* primals and derivatives (i.e linearized block)
SLANG_ASSERT(diffBranch);
+ // Since blocks always compute both primals and differentials, the branch
+ // instructions are also always mixed.
+ //
+ builder->markInstAsMixedDifferential(diffBranch);
+
return InstPair(diffBranch, diffBranch);
}
@@ -740,6 +745,7 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig
kIROp_loop,
diffLoopOperands.getCount(),
diffLoopOperands.getBuffer());
+ builder->markInstAsMixedDifferential(diffLoop);
return InstPair(diffLoop, diffLoop);
}
@@ -779,13 +785,14 @@ InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse*
diffIfElseArgs.add(primalOperand);
}
- IRInst* diffLoop = builder->emitIntrinsicInst(
+ IRInst* diffIfElse = builder->emitIntrinsicInst(
nullptr,
kIROp_ifElse,
diffIfElseArgs.getCount(),
diffIfElseArgs.getBuffer());
+ builder->markInstAsMixedDifferential(diffIfElse);
- return InstPair(diffLoop, diffLoop);
+ return InstPair(diffIfElse, diffIfElse);
}
InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
@@ -963,10 +970,16 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
builder.setInsertInto(diffFunc);
differentiableTypeConformanceContext.setFunc(primalFunc);
+
// Transcribe children from origFunc into diffFunc
for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
this->transcribe(&builder, block);
+ // Some of the transcribed blocks can appear 'out-of-order'. Although this
+ // shouldn't be an issue, for consistency, we put them back in order.
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ as<IRBlock>(lookupDiffInst(block))->insertAtEnd(diffFunc);
+
return InstPair(primalFunc, diffFunc);
}
@@ -1124,6 +1137,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return trascribeNonDiffInst(builder, origInst);
case kIROp_StructKey:
return InstPair(origInst, nullptr);
+ case kIROp_Unreachable:
+ {
+ auto unreachInst = builder->emitUnreachable();
+ builder->markInstAsMixedDifferential(unreachInst);
+ return InstPair(unreachInst, nullptr);
+ }
case kIROp_MakeExistentialWithRTTI:
SLANG_UNEXPECTED("MakeExistentialWithRTTI inst is not expected in autodiff pass.");
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index cfee49eb1..ae9b69f61 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -429,11 +429,6 @@ namespace Slang
block->insertAtEnd(diffFunc);
}
- // Extracts the primal computations into its own func, and replace the primal insts
- // with the intermediate results computed from the extracted func.
- IRInst* intermediateType = nullptr;
- auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
-
// Transpose the first block (parameter block)
transposeParameterBlock(builder, diffFunc);
@@ -445,7 +440,12 @@ namespace Slang
DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
diffTransposePass->transposeDiffBlocksInFunc(diffFunc, info);
- // Clean up by deallocating intermediate steps.
+ // Extracts the primal computations into its own func, and replace the primal insts
+ // with the intermediate results computed from the extracted func.
+ IRInst* intermediateType = nullptr;
+ auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
+
+ // Clean up by deallocating intermediate versions.
tempDiffFunc->removeAndDeallocate();
unzippedFwdDiffFunc->removeAndDeallocate();
fwdDiffFunc->removeAndDeallocate();
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index da7762908..69cef941c 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -598,8 +598,9 @@ InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* o
{
IRBuilder subBuilder(builder->getSharedBuilder());
subBuilder.setInsertLoc(builder->getInsertLoc());
-
+
IRInst* diffBlock = subBuilder.emitBlock();
+ subBuilder.markInstAsMixedDifferential(diffBlock);
// Note: for blocks, we setup the mapping _before_
// processing the children since we could encounter
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index a14ecad84..436a17a7f 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -54,6 +54,24 @@ struct DiffTransposePass
Flavor flavor;
};
+ struct BlockPredecessorEntry
+ {
+ // Previous block.
+ IRBlock* prevBlock;
+
+ // Integer value corresponding to this predecessor.
+ IRIntegerValue* indexVal;
+ };
+
+ struct ControlFlowTranspositionInfo
+ {
+ // Variable used for recording control flow.
+ IRVar* controlVar;
+
+ // Info about all possible predecessor blocks.
+ Dictionary<IRBlock*, IRInst*> predEntries;
+ };
+
DiffTransposePass(AutoDiffSharedContext* autodiffContext) :
autodiffContext(autodiffContext), pairBuilder(autodiffContext), diffTypeContext(autodiffContext)
{ }
@@ -90,6 +108,11 @@ struct DiffTransposePass
{
// Grab all differentiable type information.
diffTypeContext.setFunc(revDiffFunc);
+
+ // Note down terminal primal and terminal differential blocks
+ // since we need to link them up at the end.
+ auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc);
+ auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc);
// Traverse all instructions/blocks in reverse (starting from the terminator inst)
// look for insts/blocks marked with IRDifferentialInstDecoration,
@@ -117,9 +140,20 @@ struct DiffTransposePass
workList.add(block);
}
- // TODO: We *might* need a step here that 'sorts' the work list in reverse order starting with 'leaf'
- // differential blocks, and following the branches backwards.
- // The alternative is to make phi nodes and treat all intermediaries & their gradients as arguments.
+ // Reverse the order of the blocks.
+ workList.reverse();
+
+ // Emit empty rev-mode blocks for every fwd-mode block.
+ for (auto block : workList)
+ {
+ revBlockMap[block] = builder.emitBlock();
+ builder.markInstAsDifferential(revBlockMap[block]);
+ }
+
+ // Keep track of first diff block, since this is where
+ // we'll emit temporary vars to hold per-block derivatives.
+ //
+ firstRevDiffBlockMap[revDiffFunc] = revBlockMap[workList[0]];
for (auto block : workList)
{
@@ -129,27 +163,123 @@ struct DiffTransposePass
this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
}
- IRBlock* revBlock = builder.emitBlock();
+ IRBlock* revBlock = revBlockMap[block];
this->transposeBlock(block, revBlock);
+ }
+
+ // Link the last differential fwd-mode block (which will be the first
+ // rev-mode block) as the successor to the last primal block.
+ // We assume that the original function is in single-return form
+ // So, there should be exactly 1 'last' block of each type.
+ //
+ {
+ SLANG_ASSERT(terminalPrimalBlocks.getCount() == 1);
+ SLANG_ASSERT(terminalDiffBlocks.getCount() == 1);
+
+ auto terminalPrimalBlock = terminalPrimalBlocks[0];
+ auto terminalRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]);
+
+ terminalPrimalBlock->getTerminator()->removeAndDeallocate();
+
+ IRBuilder subBuilder(builder.getSharedBuilder());
+ subBuilder.setInsertInto(terminalPrimalBlock);
+
+ // There should be no parameters in the first reverse-mode block.
+ SLANG_ASSERT(terminalRevBlock->getFirstParam() == nullptr);
- // TODO: This should only really be used for the transition from
- // the 'last' primal block(s) to the first differential block.
- // Transitions from differential blocks to
- block->replaceUsesWith(revBlock);
+ subBuilder.emitBranch(terminalRevBlock);
+ }
+
+ // Remove fwd-mode blocks.
+ for (auto block : workList)
+ {
block->removeAndDeallocate();
}
}
- // A[cond_inst] -> (B or C) -> D => D[cond_inst] -> (B_T -> C_T) -> A_T
+ // Fetch or create a gradient accumulator var
+ // corresponding to a inst. These are used to
+ // accumulate gradients across blocks.
+ //
+ IRVar* getOrCreateAccumulatorVar(IRInst* fwdInst)
+ {
+ // Check if we have a var already.
+ if (revAccumulatorVarMap.ContainsKey(fwdInst))
+ return revAccumulatorVarMap[fwdInst];
+
+ IRBuilder tempVarBuilder(autodiffContext->sharedBuilder);
+
+ IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())];
+ tempVarBuilder.setInsertBefore(firstDiffBlock->getTerminator());
+
+ auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
+ auto diffType = fwdInst->getDataType();
+
+ auto zeroMethod = diffTypeContext.getZeroMethodForType(
+ &tempVarBuilder,
+ primalType);
+
+ SLANG_ASSERT(zeroMethod);
+
+ // Emit a var in the top-level differential block to hold the gradient,
+ // and initialize it.
+ auto tempRevVar = tempVarBuilder.emitVar(diffType);
+ auto diffZero = tempVarBuilder.emitCallInst(
+ diffType,
+ zeroMethod,
+ List<IRInst*>());
+ tempVarBuilder.emitStore(tempRevVar, diffZero);
+
+ revAccumulatorVarMap[fwdInst] = tempRevVar;
+
+ return tempRevVar;
+ }
+
+ bool isInstUsedOutsideParentBlock(IRInst* inst)
+ {
+ auto currBlock = inst->getParent();
+
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getParent() != currBlock)
+ return true;
+ }
+ return false;
+ }
+
void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
{
IRBuilder builder;
builder.init(autodiffContext->sharedBuilder);
- // Insert after the last block.
+ // Insert into our reverse block.
builder.setInsertInto(revBlock);
+ // Check if this block has any 'outputs' (in the form of phi args
+ // sent to the successor bvock)
+ //
+ if (auto branchInst = as<IRUnconditionalBranch>(fwdBlock->getTerminator()))
+ {
+ for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
+ {
+ auto arg = branchInst->getArg(ii);
+ if (isDifferentialInst(arg))
+ {
+ auto diffType = arg->getDataType();
+ auto revParam = builder.emitParam(diffType);
+
+ addRevGradientForFwdInst(
+ arg,
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ arg,
+ revParam,
+ nullptr));
+ }
+ }
+ }
+
// Move pointer & reference insts to the top of the reverse-mode block.
List<IRInst*> nonValueInsts;
for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
@@ -178,7 +308,7 @@ struct DiffTransposePass
//
for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
{
- if (as<IRDecoration>(child))
+ if (as<IRDecoration>(child) || as<IRParam>(child))
continue;
transposeInst(&builder, child);
@@ -193,22 +323,78 @@ struct DiffTransposePass
//
for (auto pair : gradientsMap)
{
- if (auto param = as<IRLoad>(pair.Key))
- accumulateGradientsForLoad(&builder, param);
+ if (auto loadInst = as<IRLoad>(pair.Key))
+ accumulateGradientsForLoad(&builder, loadInst);
}
- // Emit a terminator inst.
- // TODO: need a be a lot smarter here. For now, we assume a single differential
- // block, so it should end in a return statement.
- if (as<IRReturn>(fwdBlock->getTerminator()))
+ // Do the same thing with the phi parameters if the block.
+ List<IRInst*> phiParamRevGradInsts;
+ for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam())
{
- // Emit a void return.
- builder.emitReturn();
+ if (hasRevGradients(param))
+ {
+ auto gradients = popRevGradients(param);
+
+ auto gradInst = emitAggregateValue(
+ &builder,
+ tryGetPrimalTypeFromDiffInst(param),
+ gradients);
+
+ phiParamRevGradInsts.add(gradInst);
+ }
}
- else
+
+ // Also handle any remaining gradients for insts that appear in prior blocks.
+ List<IRInst*> externInsts; // Holds insts in a different block, same function.
+ List<IRInst*> globalInsts; // Holds insts in the global scope.
+ for (auto pair : gradientsMap)
{
- SLANG_UNEXPECTED("Unhandled block terminator");
+ auto instParent = pair.Key->getParent();
+ if (instParent != fwdBlock)
+ {
+ if (instParent->getParent() == fwdBlock->getParent())
+ externInsts.add(pair.Key);
+
+ if (as<IRModuleInst>(instParent))
+ globalInsts.add(pair.Key);
+ }
}
+
+ for (auto externInst : externInsts)
+ {
+ auto primalType = tryGetPrimalTypeFromDiffInst(externInst);
+ SLANG_ASSERT(primalType);
+
+ if (auto accVar = getOrCreateAccumulatorVar(externInst))
+ {
+ // Accumulate all gradients, including our accumulator variable,
+ // into one inst.
+ //
+ auto gradients = popRevGradients(externInst);
+ gradients.add(RevGradient(externInst, builder.emitLoad(accVar), nullptr));
+
+ auto gradInst = emitAggregateValue(
+ &builder,
+ primalType,
+ gradients);
+
+ builder.emitStore(accVar, gradInst);
+ }
+ }
+
+ // For now, we're not going to handle global insts, and simply ignore them
+ // Eventually, we want to turn these into global writes.
+ //
+ for (auto globalInst : globalInsts)
+ {
+ if (hasRevGradients(globalInst))
+ popRevGradients(globalInst);
+ }
+
+ // We _should_ be completely out of gradients to process at this point.
+ SLANG_ASSERT(gradientsMap.Count() == 0);
+
+ emitTerminator(&builder, fwdBlock, phiParamRevGradInsts);
}
void transposeInst(IRBuilder* builder, IRInst* inst)
@@ -242,13 +428,32 @@ struct DiffTransposePass
if (!primalType)
{
// Check for special insts for which a reverse-mode gradient doesn't apply.
- if(!as<IRStore>(inst))
+ if(!as<IRStore>(inst) && !as<IRTerminatorInst>(inst))
{
SLANG_UNEXPECTED("Could not resolve primal type for diff inst");
}
+
+ // If we still can't resolve a differential type, there shouldn't
+ // be any gradients to aggregate.
+ //
+ SLANG_ASSERT(gradients.getCount() == 0);
}
- // Emit the aggregate of all the gradients here. This will form the total derivative for this inst.
+ // Is this inst used in another differential block?
+ // Emit a function-scope accumulator variable, and include it's value.
+ // Also, we ignore this if it's a load since those are turned into stores
+ // on a per-block basis. (We should change this behaviour to treat loads like
+ // any other inst)
+ //
+ if (isInstUsedOutsideParentBlock(inst) && !as<IRLoad>(inst))
+ {
+ auto accVar = getOrCreateAccumulatorVar(inst);
+ gradients.add(
+ RevGradient(inst, builder->emitLoad(accVar), nullptr));
+ }
+
+ // Emit the aggregate of all the gradients here.
+ // This will form the total derivative for this inst.
auto revValue = emitAggregateValue(builder, primalType, gradients);
auto transposeResult = transposeInst(builder, inst, revValue);
@@ -376,15 +581,297 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+
+ IRBlock* getPrimalBlock(IRBlock* fwdBlock)
+ {
+ if (auto fwdDiffDecoration = fwdBlock->findDecoration<IRDifferentialInstDecoration>())
+ {
+ return as<IRBlock>(fwdDiffDecoration->getPrimalInst());
+ }
+
+ return nullptr;
+ }
+
+ IRBlock* getFirstCodeBlock(IRGlobalValueWithCode* func)
+ {
+ return func->getFirstBlock()->getNextBlock();
+ }
+
+ List<IRBlock*> getTerminalPrimalBlocks(IRGlobalValueWithCode* func)
+ {
+ // 'Terminal' primal blocks are those that branch into a differential block.
+ List<IRBlock*> terminalPrimalBlocks;
+ for (auto block : func->getBlocks())
+ for (auto successor : block->getSuccessors())
+ if (!isDifferentialInst(block) && isDifferentialInst(successor))
+ terminalPrimalBlocks.add(block);
+
+ return terminalPrimalBlocks;
+ }
+
+ List<IRBlock*> getTerminalDiffBlocks(IRGlobalValueWithCode* func)
+ {
+ // Terminal differential blocks are those with a return statement.
+ // Note that this method is designed to work with Fwd-Mode blocks,
+ // and this logic will be different for Rev-Mode blocks.
+ //
+ List<IRBlock*> terminalDiffBlocks;
+ for (auto block : func->getBlocks())
+ if (as<IRReturn>(block->getTerminator()))
+ terminalDiffBlocks.add(block);
+
+ return terminalDiffBlocks;
+ }
+
+ IRInst* addPredecessorForBlock(IRBlock* block, IRBlock* predBlock)
+ {
+ if (!this->blockEntries.ContainsKey(block))
+ {
+ // We haven't encountered this block yet, create a var for this in the
+ // first code block.
+ auto firstCodeBlock = getFirstCodeBlock(block->getParent());
+
+ IRBuilder subBuilder(this->autodiffContext->sharedBuilder);
+ subBuilder.setInsertBefore(firstCodeBlock->getTerminator());
+ auto controlVar = subBuilder.emitVar(subBuilder.getUIntType());
+
+ ControlFlowTranspositionInfo info;
+ info.controlVar = controlVar;
+
+ this->blockEntries[block] = info;
+ }
+
+ auto info = this->blockEntries[block];
+
+ // Does precessor block already exist?
+ if (info.GetValue().predEntries.ContainsKey(predBlock))
+ {
+ return info.GetValue().predEntries[predBlock];
+ }
+
+ // Otherwise, create an entry..
+ auto uniqueIndex = info.GetValue().predEntries.Count();
+
+ IRBuilder builder(this->autodiffContext->sharedBuilder);
+ auto uniqueIndexLiteral = builder.getIntValue(builder.getUIntType(), uniqueIndex);
+
+ info.GetValue().predEntries[predBlock] = uniqueIndexLiteral;
+
+ return uniqueIndexLiteral;
+ }
+
+ IRVar* getControlVar(IRBlock* block)
+ {
+ return this->blockEntries[block].GetValue().controlVar;
+ }
+
+ // Inserts a block between the branch from fwdPredecessorBlock to fwdBlock, which sets a control
+ // variable to a unique index.
+ //
+ IRInst* insertPreludeForPredecessor(IRBlock* fwdBlock, IRBlock* fwdPredecessorBlock)
+ {
+ // Get associated primal blocks for both the differential blocks.
+ auto primalPredecessorBlock = getPrimalBlock(fwdPredecessorBlock);
+ SLANG_ASSERT(primalPredecessorBlock);
+
+ auto primalBlock = getPrimalBlock(fwdBlock);
+ SLANG_ASSERT(primalBlock);
+
+ // Add this block as a predecessor, and get an unique index (as an integer literal)
+ auto indexVal = addPredecessorForBlock(fwdBlock, fwdPredecessorBlock);
+
+ IRBuilder subBuilder(this->autodiffContext->sharedBuilder);
+ subBuilder.setInsertInto(primalPredecessorBlock->getParent());
+
+ IRInst* preludeBlock = subBuilder.emitBlock();
+ preludeBlock->insertAfter(primalPredecessorBlock);
+
+ // Copy over phi parameters.
+ List<IRInst*> phiParams;
+ for (auto param = primalBlock->getFirstParam(); param; param = param->getNextParam())
+ {
+ phiParams.add(subBuilder.emitParam(param->getDataType()));
+ }
+
+ auto controlVar = getControlVar(fwdBlock);
+ subBuilder.emitStore(controlVar, indexVal);
+
+ // Branch into the successor block using all the same phi parameters.
+ subBuilder.emitBranch(primalBlock, phiParams.getCount(), phiParams.getBuffer());
+
+ // Scan through uses of primalBlock to find the ones that are in
+ // primalPredecessorBlock, and replace them with branches to
+ // preludeBlock.
+ //
+ List<IRUse*> relevantUses;
+ for (auto use = primalBlock->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getParent() == primalPredecessorBlock)
+ relevantUses.add(use);
+ }
+
+ for (auto use : relevantUses)
+ use->set(preludeBlock);
+
+ return indexVal;
+ }
+
+ bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock)
+ {
+ for (auto block : fwdBlock->getPredecessors())
+ {
+ if (isDifferentialInst(block))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ void emitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads)
+ {
+ // If this block has no differential predecessors, add a return statement.
+ if (!doesBlockHaveDifferentialPredecessors(fwdBlockInst))
+ {
+ // Emit a void return.
+ builder->emitReturn();
+ return;
+ }
+
+ for (auto predecessor : fwdBlockInst->getPredecessors())
+ {
+ // Insert code into the *primal* version of the predecessor block
+ // to set the control variable to indexVal before branching.
+ //
+ insertPreludeForPredecessor(fwdBlockInst, predecessor);
+ }
+
+ List<IRBlock*> revPredecessorBlocks;
+ List<IRInst*> indexVals;
+
+ for (auto blockEntry : this->blockEntries[fwdBlockInst].GetValue().predEntries)
+ {
+ revPredecessorBlocks.add(revBlockMap[blockEntry.Key]);
+ indexVals.add(blockEntry.Value);
+ }
+
+ auto predCount = revPredecessorBlocks.getCount();
+
+ SLANG_ASSERT(predCount > 0);
+
+ List<IRBlock*> intermediateBranchBlocks;
+
+ IRBuilder branchBlockBuilder(builder->getSharedBuilder());
+
+ branchBlockBuilder.setInsertInto(builder->getFunc());
+
+ // Make a block to unconditionally branch into predecessor-0 with the
+ // appropriate phi gradients.
+ //
+ auto firstBranchBlock = branchBlockBuilder.emitBlock();
+ intermediateBranchBlocks.add(firstBranchBlock);
+
+ branchBlockBuilder.markInstAsDifferential(firstBranchBlock);
+ branchBlockBuilder.emitBranch(
+ revPredecessorBlocks[0],
+ phiParamGrads.getCount(),
+ phiParamGrads.getBuffer());
+
+ // Create a builder to insert loads and comparison insts to figure
+ // out which block to branch into based on the control vars.
+ // This builder is set up to emit into the last _primal_ block.
+ //
+ IRBuilder booleanIndicatorBuilder(builder->getSharedBuilder());
+ auto terminalPrimalBlock = getTerminalPrimalBlocks(builder->getFunc())[0];
+
+ booleanIndicatorBuilder.setInsertBefore(terminalPrimalBlock->getTerminator());
+
+ if (predCount == 1)
+ {
+ builder->emitBranch(firstBranchBlock);
+ }
+ else
+ {
+ IRBuilder ladderBlockBuilder(builder->getSharedBuilder());
+ ladderBlockBuilder.setInsertInto(builder->getFunc());
+
+ // TODO: For now, we're trivially setting 'afterBlock' to
+ // the first reverse block. This is not really optimal for the
+ // restructuring passes since the 'then' and 'else' regions
+ // can have significant overlap.
+ //
+ auto firstFwdDiffBlock = (*terminalPrimalBlock->getSuccessors().begin());
+ SLANG_ASSERT(firstFwdDiffBlock);
+
+ auto defaultAfterBlock = revBlockMap[firstFwdDiffBlock];
+
+ auto nextLadderBlock = firstBranchBlock;
+ for (Index ii = 0; ii < predCount - 1; ii++)
+ {
+ // Make the 'leaf' block. This just branches into
+ // predecessor-i+1 with the appropriate phi args.
+ //
+ branchBlockBuilder.setInsertInto(branchBlockBuilder.getFunc());
+
+ auto thisIndexBlock = branchBlockBuilder.emitBlock();
+ intermediateBranchBlocks.add(thisIndexBlock);
+
+ branchBlockBuilder.markInstAsDifferential(thisIndexBlock);
+ branchBlockBuilder.emitBranch(
+ revPredecessorBlocks[ii+1],
+ phiParamGrads.getCount(),
+ phiParamGrads.getBuffer());
+
+ // Emit a boolean inst to represent whether we need to branch into
+ // block ii.
+ auto blockIndicatorInst = booleanIndicatorBuilder.emitEql(
+ booleanIndicatorBuilder.emitLoad(getControlVar(fwdBlockInst)),
+ indexVals[ii+1]);
+
+ // Create a block to branch between i+1 and the rest of the ladder so far
+ // (0 ... i)
+ //
+ auto upperLadderBlock = ladderBlockBuilder.emitBlock();
+ intermediateBranchBlocks.add(upperLadderBlock);
+
+ ladderBlockBuilder.markInstAsDifferential(upperLadderBlock);
+ ladderBlockBuilder.emitIfElse(
+ blockIndicatorInst,
+ thisIndexBlock,
+ nextLadderBlock,
+ defaultAfterBlock);
+
+ nextLadderBlock = upperLadderBlock;
+ }
+
+ // Branch into the last ladder block.
+ builder->emitBranch(nextLadderBlock);
+ }
+
+ // Insert all intermediate blocks in the order they were created, right after
+ // the current reverse block.
+
+ auto revBlock = revBlockMap[fwdBlockInst];
+ SLANG_ASSERT(revBlock);
+
+ for (auto block : intermediateBranchBlocks)
+ {
+ block->insertAfter(revBlock);
+ }
+
+ return;
+ }
TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
{
+
// Dispatch logic.
switch(fwdInst->getOp())
{
case kIROp_Add:
case kIROp_Mul:
- case kIROp_Sub:
+ case kIROp_Sub:
return transposeArithmetic(builder, fwdInst, revValue);
case kIROp_Call:
@@ -413,6 +900,16 @@ struct DiffTransposePass
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
+
+ case kIROp_unconditionalBranch:
+ case kIROp_conditionalBranch:
+ case kIROp_ifElse:
+ case kIROp_loop:
+ {
+ // Ignore. transposeBlock() should take care of adding the
+ // appropriate branch instruction.
+ return TranspositionResult();
+ }
default:
SLANG_ASSERT_FAILURE("Unhandled instruction");
@@ -470,7 +967,6 @@ struct DiffTransposePass
return TranspositionResult(List<RevGradient>());
}
-
TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*)
{
@@ -935,71 +1431,6 @@ struct DiffTransposePass
nullptr);
}
- IRInst* emitAggregateDifferentialPair(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> pairGradients)
- {
- SLANG_UNEXPECTED("Should not run.");
-
- auto aggPairType = as<IRDifferentialPairType>(aggPrimalType);
- SLANG_ASSERT(aggPairType);
-
- IRType* diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, aggPairType);
-
- IRInst* primalInst = nullptr;
- IRInst* diffInst = nullptr;
-
- List<RevGradient> gradients;
- for (auto gradient : pairGradients)
- {
- switch (gradient.flavor)
- {
- case RevGradient::Flavor::Simple:
- {
- // In this case, the gradient is a 'pair' already, but we need to treat the primal element
- // as if it didn't exist (we simply copy it over)
- // If we already saw a pair, throw an error since we don't know how to combine to primals.
- // (i.e. something went wrong prior to this step.)
- //
- if (primalInst)
- {
- SLANG_UNEXPECTED("Encountered multiple pair types in emitAggregateDifferentialPair");
- }
-
- primalInst = builder->emitDifferentialPairGetPrimal(gradient.revGradInst);
- gradients.add(
- RevGradient(
- RevGradient::Flavor::Simple,
- gradient.targetInst,
- builder->emitDifferentialPairGetDifferential(
- diffType,
- gradient.revGradInst),
- gradient.fwdGradInst));
- break;
- }
-
- case RevGradient::Flavor::GetDifferential:
- {
- // In this case, the gradient is the result of transposing a GetDifferential
- // so we have only the gradient part. Just add it to the list of gradients to aggregate
- gradients.add(
- RevGradient(
- RevGradient::Flavor::Simple,
- gradient.targetInst,
- gradient.revGradInst,
- gradient.fwdGradInst));
- break;
- }
- default:
- SLANG_UNEXPECTED("Unexpected gradient flavor in emitAggregateDifferentialPair");
- }
- }
-
- // Aggregate only the differentials
- diffInst = emitAggregateValue(builder, aggPairType->getValueType(), gradients);
-
- // Pack them back together.
- return builder->emitMakeDifferentialPair(aggPrimalType, primalInst, diffInst);
- }
-
IRInst* emitAggregateValue(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
{
// If we're dealing with the differential-pair types, we need to use a different aggregation method, since
@@ -1138,17 +1569,25 @@ struct DiffTransposePass
return gradientsMap.ContainsKey(fwdInst);
}
- AutoDiffSharedContext* autodiffContext;
+ AutoDiffSharedContext* autodiffContext;
+
+ DifferentiableTypeConformanceContext diffTypeContext;
- DifferentiableTypeConformanceContext diffTypeContext;
+ DifferentialPairTypeBuilder pairBuilder;
- DifferentialPairTypeBuilder pairBuilder;
+ Dictionary<IRInst*, List<RevGradient>> gradientsMap;
- Dictionary<IRInst*, List<RevGradient>> gradientsMap;
+ Dictionary<IRInst*, IRVar*> revAccumulatorVarMap;
+
+ Dictionary<IRInst*, IRInst*>* primalsMap;
+
+ List<IRInst*> usedPtrs;
+
+ Dictionary<IRBlock*, ControlFlowTranspositionInfo> blockEntries;
- Dictionary<IRInst*, IRInst*>* primalsMap;
+ Dictionary<IRBlock*, IRBlock*> revBlockMap;
- List<IRInst*> usedPtrs;
+ Dictionary<IRGlobalValueWithCode*, IRBlock*> firstRevDiffBlockMap;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 546d5a6ec..2fd53dbd0 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -91,7 +91,7 @@ struct ExtractPrimalFuncContext
for (UInt i = 0; i < originalFuncType->getParamCount(); i++)
paramTypes.add(originalFuncType->getParamType(i));
paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
- auto newFuncType = builder.getFuncType(paramTypes, originalFuncType->getResultType());
+ auto newFuncType = builder.getFuncType(paramTypes, builder.getVoidType());
return newFuncType;
}
@@ -100,6 +100,10 @@ struct ExtractPrimalFuncContext
if (inst->findDecoration<IRDifferentialInstDecoration>() ||
inst->findDecoration<IRMixedDifferentialInstDecoration>())
return true;
+
+ if (auto block = as<IRBlock>(inst->getParent()))
+ return isDiffInst(block);
+
return false;
}
@@ -161,6 +165,7 @@ struct ExtractPrimalFuncContext
case kIROp_DoubleType:
case kIROp_VectorType:
case kIROp_MatrixType:
+ case kIROp_BoolType:
case kIROp_Param:
case kIROp_Specialize:
case kIROp_LookupWitness:
@@ -383,47 +388,76 @@ struct ExtractPrimalFuncContext
genericMigrationContext.init(gen, as<IRGeneric>(spec->getBase()));
}
+ List<IRBlock*> diffBlocksList;
+ List<IRBlock*> primalBlocksList;
+
for (auto block : func->getBlocks())
{
if (block == paramBlock)
continue;
- if (block->findDecoration<IRDifferentialInstDecoration>() ||
- block->findDecoration<IRMixedDifferentialInstDecoration>())
+
+ if (isDiffInst(block))
+ diffBlocksList.add(block);
+ else
+ primalBlocksList.add(block);
+ }
+
+ // Go over primal blocks and store insts.
+ for (auto block : primalBlocksList)
+ {
+ // For primal insts, decide whether or not to store its result in
+ // output intermediary struct.
+ for (auto inst : block->getChildren())
{
- if (block->getFirstParam() == nullptr)
+ if (shouldStoreInst(inst))
{
- // If the block does not have any PHI nodes, just remove it and
- // replace all its uses with returnBlock.
- block->replaceUsesWith(returnBlock);
- block->removeAndDeallocate();
- }
- else
- {
- // If the block has Phi nodes, we can't directly replace it with
- // `returnBlock`, but we can turn the block into a trivial branch
- // into `returnBlock` to safely preserve the invariants of Phi nodes.
- auto inst = block->getLastParam()->getNextInst();
- for (; inst; inst = inst->getNextInst())
- inst->removeAndDeallocate();
- builder.setInsertInto(block);
- builder.emitBranch(returnBlock);
+ builder.setInsertAfter(inst);
+ storeInst(builder, inst, genericMigrationContext, outIntermediary);
}
}
+ }
+
+ // Go over differential blocks and complete
+ for (auto block : diffBlocksList)
+ {
+
+ if (block->getFirstParam() == nullptr)
+ {
+ // If the block does not have any PHI nodes, just remove it and
+ // replace all its uses with returnBlock.
+
+ // TODO: This invalides the next block in the chain. Make a list first.
+ block->replaceUsesWith(returnBlock);
+ block->removeAndDeallocate();
+ }
else
{
- // For primal insts, decide whether or not to store its result in
- // output intermediary struct.
- for (auto inst : block->getChildren())
+ // If the block has Phi nodes, we can't directly replace it with
+ // `returnBlock`, but we can turn the block into a trivial branch
+ // into `returnBlock` to safely preserve the invariants of Phi nodes.
+ auto inst = block->getLastParam()->getNextInst();
+ for (; inst;)
{
- if (shouldStoreInst(inst))
- {
- builder.setInsertAfter(inst);
- storeInst(builder, inst, genericMigrationContext, outIntermediary);
- }
+ auto nextInst = inst->getNextInst();
+ inst->removeAndDeallocate();
+ inst = nextInst;
}
+
+ builder.setInsertInto(block);
+ builder.emitBranch(returnBlock);
}
}
+ List<IRBlock*> unusedBlocks;
+ for (auto block : func->getBlocks())
+ {
+ if (!block->hasUses() && isDiffInst(block))
+ unusedBlocks.add(block);
+ }
+
+ for (auto block : unusedBlocks)
+ block->removeAndDeallocate();
+
builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType);
builder.emitStore(outIntermediary, defVal);
@@ -503,13 +537,17 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>())
{
builder.setInsertBefore(inst);
- auto addr = builder.emitFieldAddress(builder.getPtrType(inst->getDataType()), intermediateVar, structKeyDecor->getStructKey());
+ auto addr = builder.emitFieldAddress(
+ builder.getPtrType(inst->getDataType()),
+ intermediateVar,
+ structKeyDecor->getStructKey());
auto val = builder.emitLoad(addr);
inst->replaceUsesWith(val);
instsToRemove.add(inst);
}
}
}
+
for (auto inst : instsToRemove)
{
inst->removeAndDeallocate();
@@ -517,6 +555,7 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
// Run simplification to DCE unnecessary insts.
eliminateDeadCode(innerFunc);
+ eliminateDeadCode(specializedPrimalFunc);
return primalFunc;
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 35aa55dd3..2c55b390b 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -26,6 +26,13 @@ struct DiffUnzipPass
Dictionary<IRInst*, IRInst*> primalMap;
Dictionary<IRInst*, IRInst*> diffMap;
+ // First diff block.
+ // TODO: Can the same pass object can be used for multiple functions?
+ // might run into an issue here?
+ IRBlock* firstDiffBlock;
+
+ // Dictionary<IRBlock*, List<IRBlock*>>
+
DiffUnzipPass(AutoDiffSharedContext* autodiffContext) :
autodiffContext(autodiffContext), diffTypeContext(autodiffContext)
{ }
@@ -58,34 +65,70 @@ struct DiffUnzipPass
builder->setInsertInto(unzippedFunc);
- // Work *only* with two-block functions for now.
+ // Functions need to have at least two blocks at this point (one for parameters,
+ // and atleast one for code)
+ //
SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr);
SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr);
- SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock()->getNextBlock() == nullptr);
// Ignore the first block (this is reserved for parameters), start
// at the second block. (For now, we work with only a single block of insts)
// TODO: expand to handle multi-block functions later.
+ IRBlock* firstBlock = unzippedFunc->getFirstBlock()->getNextBlock();
- IRBlock* mainBlock = unzippedFunc->getFirstBlock()->getNextBlock();
+ List<IRBlock*> mixedBlocks;
+ for (IRBlock* block = firstBlock; block; block = block->getNextBlock())
+ {
+ // Only need to unzip blocks with both differential and primal instructions.
+ if (block->findDecoration<IRMixedDifferentialInstDecoration>())
+ {
+ mixedBlocks.add(block);
+ }
+ }
+
+ IRBlock* firstPrimalBlock = nullptr;
- // Emit new blocks for split vesions of mainblock.
- IRBlock* primalBlock = builder->emitBlock();
- IRBlock* diffBlock = builder->emitBlock();
+ // Emit an empty primal block for every mixed block.
+ for (auto block : mixedBlocks)
+ {
+ IRBlock* primalBlock = builder->emitBlock();
+ primalMap[block] = primalBlock;
- // Mark the differential block as a differential inst.
- builder->markInstAsDifferential(diffBlock);
+ if (block == firstBlock)
+ firstPrimalBlock = primalBlock;
+ }
- // Split the main block into two. This method should also emit
- // a branch statement from primalBlock to diffBlock.
- // TODO: extend this code to split multiple blocks
- //
- splitBlock(mainBlock, primalBlock, diffBlock);
+ // Emit an empty differential block for every mixed block.
+ for (auto block : mixedBlocks)
+ {
+ IRBlock* diffBlock = builder->emitBlock();
+ diffMap[block] = diffBlock;
+
+ // Mark the differential block as a differential inst
+ // (and add a reference to the primal block)
+ builder->markInstAsDifferential(diffBlock, nullptr, primalMap[block]);
+
+ // Record the first differential (code) block,
+ // since we want all 'return' insts in primal blocks
+ // to be replaced with a brahcn into this block.
+ //
+ if (block == firstBlock)
+ this->firstDiffBlock = diffBlock;
+ }
+
+ // Split each block into two.
+ for (auto block : mixedBlocks)
+ {
+ splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block]));
+ }
+
+ // Swap the first block's occurences out for the first primal block.
+ firstBlock->replaceUsesWith(firstPrimalBlock);
+
+ // Remove old blocks.
+ for (auto block : mixedBlocks)
+ block->removeAndDeallocate();
- // Replace occurences of mainBlock with primalBlock
- mainBlock->replaceUsesWith(primalBlock);
- mainBlock->removeAndDeallocate();
-
return unzippedFunc;
}
@@ -221,10 +264,14 @@ struct DiffUnzipPass
return InstPair(primalBuilder->emitVar(primalType), diffBuilder->emitVar(diffType));
}
- InstPair splitReturn(IRBuilder*, IRBuilder* diffBuilder, IRReturn* mixedReturn)
+ InstPair splitReturn(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRReturn* mixedReturn)
{
auto pairType = as<IRDifferentialPairType>(mixedReturn->getVal()->getDataType());
auto primalType = pairType->getValueType();
+
+ // Check that we have an unambiguous 'first' differential block.
+ SLANG_ASSERT(firstDiffBlock);
+ auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
auto pairVal = diffBuilder->emitMakeDifferentialPair(
pairType,
@@ -235,7 +282,81 @@ struct DiffUnzipPass
auto returnInst = diffBuilder->emitReturn(pairVal);
diffBuilder->markInstAsDifferential(returnInst, primalType);
- return InstPair(nullptr, returnInst);
+ return InstPair(primalBranch, returnInst);
+ }
+
+ InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst)
+ {
+ switch (branchInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ {
+ auto uncondBranchInst = as<IRUnconditionalBranch>(branchInst);
+ auto targetBlock = uncondBranchInst->getTargetBlock();
+
+ // Split args.
+ List<IRInst*> primalArgs;
+ List<IRInst*> diffArgs;
+ for (UIndex ii = 0; ii < uncondBranchInst->getArgCount(); ii++)
+ {
+ if (isDifferentialInst(uncondBranchInst->getArg(ii)))
+ diffArgs.add(uncondBranchInst->getArg(ii));
+ else
+ primalArgs.add(uncondBranchInst->getArg(ii));
+ }
+
+ return InstPair(
+ primalBuilder->emitBranch(
+ as<IRBlock>(primalMap[targetBlock]),
+ primalArgs.getCount(),
+ primalArgs.getBuffer()),
+ diffBuilder->emitBranch(
+ as<IRBlock>(diffMap[targetBlock]),
+ diffArgs.getCount(),
+ diffArgs.getBuffer()));
+
+ }
+
+ case kIROp_conditionalBranch:
+ {
+ auto trueBlock = as<IRConditionalBranch>(branchInst)->getTrueBlock();
+ auto falseBlock = as<IRConditionalBranch>(branchInst)->getFalseBlock();
+ auto condInst = as<IRConditionalBranch>(branchInst)->getCondition();
+
+ return InstPair(
+ primalBuilder->emitBranch(
+ condInst,
+ as<IRBlock>(primalMap[trueBlock]),
+ as<IRBlock>(primalMap[falseBlock])),
+ diffBuilder->emitBranch(
+ condInst,
+ as<IRBlock>(diffMap[trueBlock]),
+ as<IRBlock>(diffMap[falseBlock])));
+ }
+
+ case kIROp_ifElse:
+ {
+ auto trueBlock = as<IRIfElse>(branchInst)->getTrueBlock();
+ auto falseBlock = as<IRIfElse>(branchInst)->getFalseBlock();
+ auto afterBlock = as<IRIfElse>(branchInst)->getAfterBlock();
+ auto condInst = as<IRIfElse>(branchInst)->getCondition();
+
+ return InstPair(
+ primalBuilder->emitIfElse(
+ condInst,
+ as<IRBlock>(primalMap[trueBlock]),
+ as<IRBlock>(primalMap[falseBlock]),
+ as<IRBlock>(primalMap[afterBlock])),
+ diffBuilder->emitIfElse(
+ condInst,
+ as<IRBlock>(diffMap[trueBlock]),
+ as<IRBlock>(diffMap[falseBlock]),
+ as<IRBlock>(diffMap[afterBlock])));
+ }
+
+ default:
+ SLANG_UNEXPECTED("Unhandled instruction");
+ }
}
InstPair _splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst)
@@ -257,6 +378,15 @@ struct DiffUnzipPass
case kIROp_Return:
return splitReturn(primalBuilder, diffBuilder, as<IRReturn>(inst));
+ case kIROp_unconditionalBranch:
+ case kIROp_conditionalBranch:
+ case kIROp_ifElse:
+ return splitControlFlow(primalBuilder, diffBuilder, inst);
+
+ case kIROp_Unreachable:
+ return InstPair(primalBuilder->emitUnreachable(),
+ diffBuilder->emitUnreachable());
+
default:
SLANG_ASSERT_FAILURE("Unhandled mixed diff inst");
}
@@ -270,7 +400,7 @@ struct DiffUnzipPass
diffMap[inst] = instPair.differential;
}
- void splitBlock(IRBlock* mainBlock, IRBlock* primalBlock, IRBlock* diffBlock)
+ void splitBlock(IRBlock* block, IRBlock* primalBlock, IRBlock* diffBlock)
{
// Make two builders for primal and differential blocks.
IRBuilder primalBuilder;
@@ -282,12 +412,13 @@ struct DiffUnzipPass
diffBuilder.setInsertInto(diffBlock);
List<IRInst*> splitInsts;
- for (auto child = mainBlock->getFirstChild(); child;)
+ for (auto child = block->getFirstChild(); child;)
{
IRInst* nextChild = child->getNextInst();
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child))
{
+ // Replace GetDiff(A) with A.d
if (diffMap.ContainsKey(getDiffInst->getBase()))
{
getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase()));
@@ -296,9 +427,9 @@ struct DiffUnzipPass
continue;
}
}
-
- if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(child))
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(child))
{
+ // Replace GetPrimal(A) with A.p
if (primalMap.ContainsKey(getPrimalInst->getBase()))
{
getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase()));
@@ -339,12 +470,12 @@ struct DiffUnzipPass
}
// Nothing should be left in the original block.
- SLANG_ASSERT(mainBlock->getFirstChild() == nullptr);
+ SLANG_ASSERT(block->getFirstChild() == nullptr);
// Branch from primal to differential block.
// Functionally, the new blocks should produce the same output as the
// old block.
- primalBuilder.emitBranch(diffBlock);
+ // primalBuilder.emitBranch(diffBlock);
}
};
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f0ec1542e..40c24d11d 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -421,6 +421,32 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
+
+void stripBlockTypeDecorations(IRFunc* func)
+{
+ for (auto child : func->getChildren())
+ {
+ if (auto block = as<IRBlock>(child))
+ {
+ for (auto decor = block->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_DifferentialInstDecoration:
+ case kIROp_MixedDifferentialInstDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+ }
+ }
+}
+
+
struct StripNoDiffTypeAttributePass : InstPassBase
{
StripNoDiffTypeAttributePass(IRModule* module) :
@@ -484,7 +510,7 @@ struct AutoDiffPass : public InstPassBase
{
bool changed = false;
List<IRInst*> autoDiffWorkList;
- // Collect all `ForwardDifferentiate` insts from the module.
+ // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the module.
autoDiffWorkList.clear();
processAllInsts([&](IRInst* inst)
{
@@ -541,6 +567,7 @@ struct AutoDiffPass : public InstPassBase
// Run transcription logic to generate the body of forward/backward derivatives functions.
// While doing so, we may discover new functions to differentiate, so we keep running until
// the worklist goes dry.
+ List<IRFunc*> autodiffCleanupList;
while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0)
{
changed = true;
@@ -549,6 +576,14 @@ struct AutoDiffPass : public InstPassBase
{
auto diffFunc = as<IRFunc>(task.resultFunc);
SLANG_ASSERT(diffFunc);
+
+ // We're running in to some situations where the follow-up task
+ // has already been completed (diffFunc has been generated, processed,
+ // and deallocated). Skip over these for now.
+ //
+ if (!diffFunc->getDataType())
+ continue;
+
auto primalFunc = as<IRFunc>(task.originalFunc);
SLANG_ASSERT(primalFunc);
switch (task.type)
@@ -562,12 +597,26 @@ struct AutoDiffPass : public InstPassBase
default:
break;
}
+
+ autodiffCleanupList.add(diffFunc);
}
}
+
+ // Get rid of block-level decorations that are used to keep track of
+ // different block types. These don't work well with the IR simplification
+ // passes since they don't expect decorations in blocks.
+ //
+ for (auto diffFunc : autodiffCleanupList)
+ stripBlockTypeDecorations(diffFunc);
+
+ autodiffCleanupList.clear();
+
if (!changed)
break;
hasChanges |= changed;
}
+
+
return hasChanges;
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 6373334bf..03a3fb063 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -608,6 +608,7 @@ struct IRDifferentialInstDecoration : IRDecoration
IR_LEAF_ISA(DifferentialInstDecoration)
IRType* getPrimalType() { return as<IRType>(getOperand(0)); }
+ IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); }
};
struct IRPrimalValueStructKeyDecoration : IRDecoration
@@ -3423,6 +3424,11 @@ public:
addDecoration(value, kIROp_DifferentialInstDecoration, primalType);
}
+ void markInstAsDifferential(IRInst* value, IRType* primalType, IRInst* primalInst)
+ {
+ addDecoration(value, kIROp_DifferentialInstDecoration, primalType, primalInst);
+ }
+
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
{
addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 33130cfb3..d8a8fb7c4 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6623,6 +6623,10 @@ namespace Slang
case kIROp_Reinterpret:
case kIROp_GetNativePtr:
return false;
+
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return false;
}
}