summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp82
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-rev.h2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h261
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h11
5 files changed, 80 insertions, 282 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 08b946cdd..5d11b7fb3 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -9,7 +9,7 @@ void applyCheckpointSet(
CheckpointSetInfo* checkpointInfo,
IRGlobalValueWithCode* func,
HoistedPrimalsInfo* hoistInfo,
- HashSet<IRUse*> pendingUses,
+ HashSet<IRUse*>& pendingUses,
Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
IROutOfOrderCloneContext* cloneCtx);
@@ -561,7 +561,7 @@ void applyCheckpointSet(
CheckpointSetInfo* checkpointInfo,
IRGlobalValueWithCode* func,
HoistedPrimalsInfo* hoistInfo,
- HashSet<IRUse*> pendingUses,
+ HashSet<IRUse*>& pendingUses,
Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
IROutOfOrderCloneContext* cloneCtx)
{
@@ -847,11 +847,64 @@ static int getInstRegionNestLevel(
return (int)result;
}
+
+/// Legalizes all accesses to primal insts from recompute and diff blocks.
+///
RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
HoistedPrimalsInfo* hoistInfo,
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
{
+ // In general, after checkpointing, we can have a function like the following:
+ // ```
+ // void func()
+ // {
+ // primal:
+ // for (int i = 0; i < 5; i++)
+ // {
+ // float x = g(i);
+ // use(x);
+ // }
+ // recompute:
+ // ...
+ // diff:
+ // for (int i = 5; i >= 0; i--)
+ // {
+ // recompute:
+ // ...
+ // diff:
+ // use_diff(x); // def of x is not dominating this location!
+ // }
+ // }
+ // ```
+ // This function will legalize the access to x in the diff block by creating
+ // a proper local variable and inserting stores/loads, so that the above function
+ // will be transformed to:
+ // ```
+ // void func()
+ // {
+ // primal:
+ // float x_storage[5];
+ //
+ // for (int i = 0; i < 5; i++)
+ // {
+ // float x = g(i);
+ // x_storage[i] = x;
+ // use(x);
+ // }
+ // recompute:
+ // ...
+ // diff:
+ // for (int i = 5; i >= 0; i--)
+ // {
+ // recompute:
+ // ...
+ // diff:
+ // use_diff(x_storage[i]);
+ // }
+ // }
+ // ```
+
RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock();
@@ -1027,7 +1080,6 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return hoistInfo;
}
-
void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
{
if (info->status != IndexTrackingInfo::CountStatus::Unresolved)
@@ -1042,7 +1094,6 @@ void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
}
}
-
IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type)
{
builder->setInsertInto(block);
@@ -1175,6 +1226,13 @@ void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalC
}
}
+// Insert iteration counters for all loops to form indexed regions. For loops in
+// primal blocks, the counter is incremented from 0. For loops in reverse
+// blocks, the counter is decremented from the final value in the primal block
+// down to 0. Returns a mapping from each block to a list of its enclosing loop
+// regions. A loop region records the iteration counter for the corresponding
+// loop in the primal block and the reverse block.
+//
void buildIndexedBlocks(
Dictionary<IRBlock*, List<IndexTrackingInfo>>& info,
IRGlobalValueWithCode* func)
@@ -1218,23 +1276,37 @@ void buildIndexedBlocks(
}
}
+// For each primal inst that is used in reverse blocks, decide if we should recompute or store
+// its value, then make it accessible in reverse blocks based on that decision.
+//
RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
{
sortBlocksInFunc(func);
+ // Insert loop counters and establish loop regions.
+ // Also makes the reverse loops counting downwards from the final iteration count.
+ //
Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
buildIndexedBlocks(indexedBlockInfo, func);
+ // Create recompute blocks for each region following the same control flow structure
+ // as in primal code.
+ //
RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext();
auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo, cloneCtx);
sortBlocksInFunc(func);
+ // Determine the strategy we should use to make a primal inst available.
+ // If we decide to recompute the inst, emit the recompute inst in the corresponding recompute block.
+ //
RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule());
chkPolicy->preparePolicy(func);
-
auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx);
+ // Legalize the primal inst accesses by introducing local variables / arrays and emitting
+ // necessary load/store logic.
+ //
primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
return primalsInfo;
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index ef1bdaf1e..e3575aceb 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -715,8 +715,8 @@ namespace Slang
eliminateDeadCode(diffPropagateFunc);
- // Extracts the primal computations into its own func, and replace the primal insts
- // with the intermediate results computed from the extracted func.
+ // Extracts the primal computations into their own func, and turns all accesses to stored primal
+ // insts into explicit reads and writes of an intermediate data structure.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType);
@@ -779,7 +779,7 @@ namespace Slang
initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getModule(), diffPropagateFunc);
- // insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
+
stripTempDecorations(diffPropagateFunc);
sortBlocksInFunc(diffPropagateFunc);
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 86a6f2846..845372ba7 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -114,8 +114,6 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc);
- void insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc);
-
void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc);
InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index bcc494fa9..8a734446d 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -514,7 +514,6 @@ struct DiffTransposePass
// (i.e. not store per-func info in 'this')
// since it is reused for every reverse-mode call.
//
- primalVarsToHoist.clear();
// Grab all differentiable type information.
diffTypeContext.setFunc(revDiffFunc);
@@ -663,9 +662,6 @@ struct DiffTransposePass
for (auto block : workList)
block->removeFromParent();
- finishHoistingPrimals(revDiffFunc);
-
-
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -972,259 +968,6 @@ struct DiffTransposePass
}
- struct InvInstPair
- {
- IRInst* inst;
- IRInst* invInst;
-
- InvInstPair(IRInst* inst, IRInst* invInst) :
- inst(inst), invInst(invInst)
- { }
-
- InvInstPair() : inst(nullptr), invInst(nullptr)
- { }
- };
-
- List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
- {
- SLANG_RELEASE_ASSERT(invInfo.requiredOperands.getCount() == 1);
- SLANG_RELEASE_ASSERT(invInfo.targetInsts.getCount() == 1);
-
- auto invOutput = invInfo.requiredOperands[0];
-
- auto invTargetInst = invInfo.targetInsts[0];
-
- switch (primalInst->getOp())
- {
- case kIROp_Add:
- {
- SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1)));
- return List<InvInstPair>(
- InvInstPair(
- invTargetInst,
- builder->emitSub(
- primalInst->getOperand(0)->getDataType(),
- invOutput,
- primalInst->getOperand(1))));
- }
- case kIROp_Sub:
- {
- SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1)));
- return List<InvInstPair>(
- InvInstPair(
- invTargetInst,
- builder->emitAdd(
- primalInst->getOperand(0)->getDataType(),
- invOutput,
- primalInst->getOperand(1))));
- }
-
- default:
- SLANG_UNEXPECTED("Unhandled arithmetic inst for inversion");
- }
- }
-
- // Go through loop block phi-args, and look for loop counter
- // arguments, which for a loop means inserting a check into
- // loop condition block.
- // This method also adds logic to skip the first iteration.
- // (a 'do-while' loop)
- //
- void invertLoopCondition(IRBuilder* builder, IRLoop* loopInst)
- {
- auto firstLoopBlock = loopInst->getTargetBlock();
-
- IRBlock* revLoopCondBlock = revBlockMap[firstLoopBlock];
- builder->setInsertBefore(revLoopCondBlock->getTerminator());
-
- // Add a terminating condition based on the loop counter's initial primal value
-
- IRParam* loopCounterParam = nullptr;
- UIndex loopCounterParamIndex = 0;
- for (auto param : firstLoopBlock->getParams())
- {
- if (param->findDecoration<IRLoopCounterDecoration>())
- {
- // There really not should be two (or more) loop counter params.
- SLANG_RELEASE_ASSERT(loopCounterParam == nullptr);
- loopCounterParam = param;
- }
- else
- {
- loopCounterParamIndex++;
- }
- }
-
- // Should see atleast one loop counter parameter on the first loop block.
- SLANG_RELEASE_ASSERT(loopCounterParam);
-
- IRInst* loopCounterInitVal = loopInst->getArg(loopCounterParamIndex);
-
- auto paramBoundsCheck = builder->emitIntrinsicInst(
- builder->getBoolType(),
- kIROp_Neq,
- 2,
- List<IRInst*>(
- loopCounterParam,
- loopCounterInitVal).getBuffer());
-
- as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck);
- }
-
- IRInst* lookupInstInPrimalBlock(IRInst* invInst)
- {
- // Lookup the inst in the primal block whose value we can use as an operand
- // for the inverted inst.
- //
- // auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst];
- return invInst;
- }
-
- bool doesInstRequireHoisting(IRInst* inst)
- {
- if (as<IRModuleInst>(inst->getParent()))
- return false;
-
- if (as<IRBlock>(inst) ||
- as<IRGlobalValueWithCode>(inst) ||
- as<IRConstant>(inst))
- return false;
-
- if (as<IRTerminatorInst>(inst))
- return false;
-
- if (as<IRDecoration>(inst))
- return doesInstRequireHoisting(getInstInBlock(inst));
-
- // We're looking for primal insts in differential blocks
- // that have not yet been moved to the 'active' blocks
- // (i.e in diff blocks that do not have parents)
- //
- return (!isDifferentialInst(inst) &&
- (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) &&
- getBlock(inst)->getParent() == nullptr);
- }
-
- IRBlock* walkToEndOfRegion(IRBlock* block)
- {
- IRBlock* currBlock = block;
-
- bool keepGoing = true;
- while (keepGoing)
- {
- auto terminator = currBlock->getTerminator();
- switch (terminator->getOp())
- {
- case kIROp_Return:
- keepGoing = false;
- break;
-
- case kIROp_unconditionalBranch:
- {
- auto nextBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock();
-
- HashSet<IRBlock*> predecessorSet;
- for (auto predecessor : nextBlock->getPredecessors())
- predecessorSet.add(predecessor);
-
- if (predecessorSet.getCount() > 1)
- {
- keepGoing = false;
- break;
- }
-
- currBlock = nextBlock;
- break;
- }
-
- case kIROp_ifElse:
- {
- for (auto predecessor : currBlock->getPredecessors())
- {
- if (as<IRLoop>(predecessor->getTerminator()))
- {
- keepGoing = false;
- break;
- }
- }
-
- currBlock = as<IRIfElse>(terminator)->getAfterBlock();
- break;
- }
-
- case kIROp_Switch:
- currBlock = as<IRSwitch>(terminator)->getBreakLabel();
- break;
-
- case kIROp_loop:
- currBlock = as<IRLoop>(terminator)->getBreakBlock();
- break;
- }
- }
-
- return currBlock;
- }
-
- void finishHoistingPrimals(IRGlobalValueWithCode* func)
- {
- auto varBlock = func->getFirstBlock()->getNextBlock();
-
- for (auto inst : primalVarsToHoist)
- {
- if (!doesInstRequireHoisting(inst))
- continue;
-
- List<IRUse*> relevantUses;
-
- IRBlock* defBlock = nullptr;
- if (auto varToHoist = as<IRVar>(inst))
- {
- varToHoist->insertBefore(varBlock->getFirstOrdinaryInst());
- auto uniqueStoreUse = findUniqueStoredVal(varToHoist);
- if (uniqueStoreUse)
- {
- inst = uniqueStoreUse->getUser();
- SLANG_ASSERT(inst);
-
- defBlock = getBlock(inst);
- }
- else
- {
- defBlock = getBlock(inst);
- }
- }
- else
- {
- defBlock = getBlock(inst);
- }
-
- if (!doesInstRequireHoisting(inst))
- continue;
-
- // Move this inst to after it's diff uses.
- //
- {
-
- IRBlock* currTopBlock = revBlockMap[walkToEndOfRegion(defBlock)];
-
- SLANG_RELEASE_ASSERT(currTopBlock);
-
- // More consistency checks
- SLANG_RELEASE_ASSERT(currTopBlock->getFirstOrdinaryInst() != nullptr);
- SLANG_RELEASE_ASSERT(currTopBlock->getParent() != nullptr);
- SLANG_RELEASE_ASSERT(isDifferentialInst(currTopBlock));
-
- // Insert at top. (disabling validation since the operands of
- // this inst might not be hoisted to the right place yet)
- //
- disableIRValidationAtInsert();
- inst->insertBefore(currTopBlock->getFirstOrdinaryInst());
- enableIRValidationAtInsert();
- }
- }
- }
-
-
void transposeInst(IRBuilder* builder, IRInst* inst)
{
switch (inst->getOp())
@@ -1386,8 +1129,6 @@ struct DiffTransposePass
auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
auto tempVar = builder->emitVar(pairType);
auto primalVal = builder->emitLoad(instPair->getPrimal());
- auto primalVar = instPair->getPrimal();
- primalVarsToHoist.add(primalVar);
auto diffVal = builder->emitLoad(instPair->getDiff());
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
@@ -3002,8 +2743,6 @@ struct DiffTransposePass
DifferentialPairTypeBuilder pairBuilder;
- List<IRInst*> primalVarsToHoist;
-
IRBlock* tempInvBlock;
Dictionary<IRInst*, List<RevGradient>> gradientsMap;
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index e0723dcdd..34f0f6c9b 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -313,17 +313,6 @@ struct DiffUnzipPass
auto primalArg = lookupPrimalInst(arg);
auto diffArg = lookupDiffInst(arg);
- if (auto primalVar = as<IRVar>(primalArg))
- {
- primalArg = diffBuilder->emitVar(as<IRPtrTypeBase>(primalVar->getDataType())->getValueType());
- if (auto storeUse = findUniqueStoredVal(primalVar))
- {
- auto storeInst = diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal());
- storeInst->insertAfter(storeUse->getUser());
- primalArg->insertBefore(storeInst);
- }
- }
-
// If arg is a mixed differential (pair), it should have already been split.
SLANG_ASSERT(primalArg);
SLANG_ASSERT(diffArg);