summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-transpose.h
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-21 14:28:57 -0700
committerGitHub <noreply@github.com>2023-04-21 14:28:57 -0700
commit957a4d3eb0a14a9d57bbb325ef0e1d458df2d2b9 (patch)
treefabc9317b1595c9f74f5b25ee83d16f4260a19d3 /source/slang/slang-ir-autodiff-transpose.h
parent69a327a98e3f9504863f9ecb623aa93036ac43db (diff)
Refactor checkpointing policy and availability pass. (#2826)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transpose.h')
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h560
1 files changed, 33 insertions, 527 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 8c005a5c6..c7ac8c357 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -80,11 +80,6 @@ struct DiffTransposePass
// of the *output* of the function.
//
IRInst* dOutInst;
-
- // Information from the unzip pass on how primal insts
- // are split across the primal and differential blocks.
- //
- HoistedPrimalsInfo* hoistedPrimalsInfo;
};
struct PendingBlockTerminatorEntry
@@ -235,16 +230,12 @@ struct DiffTransposePass
builder.setInsertInto(revCondBlock);
- //hoistPrimalInst(&builder, ifElse->getCondition());
-
- auto newIfElse = builder.emitIfElse(
+ builder.emitIfElse(
ifElse->getCondition(),
revTrueEntryBlock,
revFalseEntryBlock,
revAfterBlock);
- hoistPrimalOperands(&builder, newIfElse);
-
if (!revTrueRegionInfo.isTrivial)
{
builder.setInsertInto(revTrueExitBlock);
@@ -358,21 +349,21 @@ struct DiffTransposePass
// Emit condition into the new cond block.
builder.setInsertInto(revCondBlock);
- // TODO: Need to defer this until after the CFG reversal is complete.
- //hoistPrimalInst(&builder, ifElse->getCondition());
-
- auto newIfElse = builder.emitIfElse(
+ builder.emitIfElse(
ifElse->getCondition(),
revTrueBlock,
revFalseBlock,
revTrueBlock);
-
- hoistPrimalOperands(&builder, newIfElse);
-
+
+ auto loopParentBlockDiffDecor = loop->getParent()->findDecoration<IRDifferentialInstDecoration>();
+ SLANG_RELEASE_ASSERT(loopParentBlockDiffDecor);
+ auto primalBlock = as<IRBlock>(loopParentBlockDiffDecor->getPrimalInst());
+ auto primalLoop = as<IRLoop>(primalBlock->getTerminator());
+ SLANG_RELEASE_ASSERT(primalLoop);
+
// Old false-side starting block becomes end block
// for the new pre-cond region (which could be empty)
//
-
if (!falseRegionInfo.isTrivial)
{
IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
@@ -384,7 +375,8 @@ struct DiffTransposePass
getPhiGrads(falseBlock).getCount(),
getPhiGrads(falseBlock).getBuffer());
loop->transferDecorationsTo(revLoop);
-
+ builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop);
+
auto revLoopStartBlock = revBlockMap[breakBlock];
builder.setInsertInto(revLoopStartBlock);
builder.emitBranch(
@@ -404,6 +396,7 @@ struct DiffTransposePass
getPhiGrads(breakBlock).getCount(),
getPhiGrads(breakBlock).getBuffer());
loop->transferDecorationsTo(revLoop);
+ builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop);
}
currentBlock = breakBlock;
@@ -478,17 +471,13 @@ struct DiffTransposePass
builder.setInsertInto(revSwitchBlock);
- // hoistPrimalInst(&builder, switchInst->getCondition());
-
- auto newSwitchInst = builder.emitSwitch(
+ builder.emitSwitch(
switchInst->getCondition(),
revBreakBlock,
revDefaultRegionEntry,
reverseSwitchArgs.getCount(),
reverseSwitchArgs.getBuffer());
- hoistPrimalOperands(&builder, newSwitchInst);
-
currentBlock = breakBlock;
break;
}
@@ -525,9 +514,7 @@ struct DiffTransposePass
// (i.e. not store per-func info in 'this')
// since it is reused for every reverse-mode call.
//
-
- hoistedPrimalsInfo = transposeInfo.hoistedPrimalsInfo;
-
+ primalVarsToHoist.clear();
// Grab all differentiable type information.
diffTypeContext.setFunc(revDiffFunc);
@@ -576,8 +563,10 @@ struct DiffTransposePass
// Emit empty rev-mode blocks for every fwd-mode block.
for (auto block : workList)
{
- revBlockMap[block] = builder.emitBlock();
- builder.markInstAsDifferential(revBlockMap[block]);
+ auto revBlock = builder.emitBlock();
+ revBlockMap[block] = revBlock;
+ if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>())
+ builder.markInstAsDifferential(revBlockMap[block], builder.getBasicBlockType(), diffDecor->getPrimalInst());
}
// Keep track of first diff block, since this is where
@@ -637,20 +626,6 @@ struct DiffTransposePass
auto firstFwdDiffBlock = branchInst->getTargetBlock();
reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>());
- // Lower any loop-exit-value decorations into initializations for loop intermediate vals,
- // and convert loop initial values into terminating conditions.
- //
- // TODO: We need a way to confirm that all required vars have an initial value
- // (is there a built-in dataflow tool for this?)
- //
- for (auto block : workList)
- {
- if (auto loopInst = as<IRLoop>(block->getTerminator()))
- {
- invertLoopCondition(&builder, loopInst);
- }
- }
-
// Link the last differential fwd-mode block (which will be the first
// rev-mode block) as the successor to the last primal block.
// We assume that the original function is in single-return form
@@ -688,43 +663,9 @@ struct DiffTransposePass
for (auto block : workList)
block->removeFromParent();
- // Mark all primal operands for hoisting.
- // TODO: Can we just merge this with finishHoistingPrimalInsts?
- // TODO: Some of this logic is replicated in finishHoistingPrimalInsts. Merge it with the
- // maybeAddOperandsToWorkList logic there.
- //
- for (auto block : workList)
- {
- IRBlock* revBlock = revBlockMap[block];
-
- for (auto child = revBlock->getFirstChild(); child; child = child->getNextInst())
- {
- hoistPrimalOperands(&builder, child);
-
- for (auto decoration = child->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration())
- {
- if (auto contextDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration))
- hoistPrimalUse(&builder, &contextDecoration->primalContextVar);
-
- if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
- hoistPrimalUse(&builder, &loopExitDecoration->exitVal);
- }
-
- if (auto instType = child->getDataType())
- if (!as<IRModuleInst>(instType->getParent()))
- hoistPrimalUse(&builder, &child->typeUse);
- }
- }
-
finishHoistingPrimals(revDiffFunc);
- for (auto block : workList)
- {
- auto revBlock = as<IRBlock>(revBlockMap[block]);
- if (auto revLoop = as<IRLoop>(revBlock->getTerminator()))
- lowerLoopExitValues(&builder, revLoop);
- }
-
+
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -793,51 +734,6 @@ struct DiffTransposePass
return tempRevVar;
}
- IRVar* lookupInverseVar(IRInst* inst)
- {
- return inverseVarMap[inst];
- }
-
- IRVar* getOrCreateInverseVar(IRInst* primalInst, IRGlobalValueWithCode* func)
- {
- IRBlock* varBlock = firstRevDiffBlockMap[func];
- return getOrCreateInverseVar(primalInst, varBlock);
- }
-
- IRVar* getOrCreateInverseVar(IRInst* primalInst)
- {
- IRBlock* varBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())];
- return getOrCreateInverseVar(primalInst, varBlock);
- }
-
- IRVar* getOrCreateInverseVar(IRInst* primalInst, IRBlock* varBlock)
- {
- // No need to store inverse values for constants.
- if (as<IRConstant>(primalInst))
- return nullptr;
-
- // Check if we have a var already.
- if (inverseVarMap.ContainsKey(primalInst))
- return inverseVarMap[primalInst];
-
- IRBuilder tempVarBuilder(autodiffContext->moduleInst);
-
- if (auto firstInst = varBlock->getFirstOrdinaryInst())
- tempVarBuilder.setInsertBefore(firstInst);
- else
- tempVarBuilder.setInsertInto(varBlock);
-
- auto primalType = primalInst->getDataType();
-
- // Emit a var in the top-level differential block to hold the inverse,
- // and initialize it.
- auto tempInvVar = tempVarBuilder.emitVar(primalType);
-
- inverseVarMap[primalInst] = tempInvVar;
-
- return tempInvVar;
- }
-
bool isInstUsedOutsideParentBlock(IRInst* inst)
{
auto currBlock = inst->getParent();
@@ -900,37 +796,9 @@ struct DiffTransposePass
revParam,
nullptr));
}
- else if (hasInverse(arg))
- {
- InversionInfo invInfo = this->hoistedPrimalsInfo->invertInfoMap[branchInst];
- if (invInfo.targetInsts.contains(arg))
- {
- SLANG_ASSERT(hasInverse(getParamAt(branchInst->getTargetBlock(), ii)));
-
- // If the output arg is a primal, emit a parameter
- // to accept it as an _input_ for the reverse-mode
- //
- auto primalType = arg->getDataType();
- auto primalInvParam = builder.emitParam(primalType);
-
- invBuilder.setInsertBefore(branchInst);
- setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
- }
- }
else
{
- if (hasInverse(getParamAt(branchInst->getTargetBlock(), ii)))
- {
- auto primalType = arg->getDataType();
- auto primalInvParam = builder.emitParam(primalType);
-
- invBuilder.setInsertBefore(branchInst);
- setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
- }
- else
- {
- SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
- }
+ SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
}
}
}
@@ -989,15 +857,6 @@ struct DiffTransposePass
if (isDifferentialInst(child))
transposeInst(&builder, child);
- else if (shouldInstBeInverted(child))
- {
- // We'll collect inverse insts in an orphaned block,
- // so disable IR validation temporarily.
- //
- disableIRValidationAtInsert();
- invertInst(&invBuilder, child);
- enableIRValidationAtInsert();
- }
}
// After processing the block's instructions, we 'flush' any remaining gradients
@@ -1046,10 +905,6 @@ struct DiffTransposePass
emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
}
}
- else if (hasInverse(param))
- {
- phiParamRevGradInsts.add(param);
- }
else
{
SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion");
@@ -1169,46 +1024,6 @@ struct DiffTransposePass
}
}
- // NOTE: This is a workaround for the fact that we expect inverses to use
- // single-use variables. The loop exit value will add a
- // second store to most inv-variables and mess with the primal hoisting mechanism.
- // Instead of emitting into the orphaned inverse block, we'll directly emit into
- // the reverse-mode block since we'll be running this _after_ the primal hoisting
- // pass.
- //
- // This workaround is fine for inverting loop counters, but when we want to
- // expand to supporting general-purpose adjoints, we would want to use per-region
- // inverse vars based on 'invInfo' (enforcing single-use vars)
- //
- void lowerLoopExitValues(IRBuilder* builder, IRLoop* revLoop)
- {
- List<IRDecoration*> processedDecorations;
- for (auto decoration : revLoop->getDecorations())
- {
- if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
- {
- builder->setInsertBefore(revLoop);
- setInverse(
- builder,
- nullptr,
- builder->getFunc(),
- loopExitValueDecoration->getTargetInst(),
- loopExitValueDecoration->getLoopExitValInst());
-
- processedDecorations.add(loopExitValueDecoration);
- }
- }
-
- for (auto decoration : processedDecorations)
- decoration->removeAndDeallocate();
- }
-
- void lowerLoopExitValues(IRBuilder* builder, IRBlock* block)
- {
- if (auto loopInst = as<IRLoop>(block->getTerminator()))
- lowerLoopExitValues(builder, loopInst);
- }
-
// Go through loop block phi-args, and look for loop counter
// arguments, which for a loop means inserting a check into
// loop condition block.
@@ -1253,41 +1068,9 @@ struct DiffTransposePass
loopCounterParam,
loopCounterInitVal).getBuffer());
- hoistPrimalOperands(builder, paramBoundsCheck);
-
as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck);
}
- List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
- {
- switch (primalInst->getOp())
- {
- case kIROp_Add:
- case kIROp_Sub:
- return invertArithmetic(builder, primalInst, invInfo);
-
- default:
- SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion");
- }
- }
-
- bool hasInverse(IRInst* primalInst)
- {
- return this->hoistedPrimalsInfo->invertSet.Contains(primalInst);
- }
-
- IRInst* loadInverse(IRBuilder* builder, IRInst* primalInst)
- {
- // Note: There are other possible cases here, although not important
- // right now. For example, a value is available to load from the primal block.
- //
-
- if (auto invVar = getOrCreateInverseVar(primalInst, builder->getFunc()))
- return builder->emitLoad(invVar);
-
- return nullptr;
- }
-
IRInst* lookupInstInPrimalBlock(IRInst* invInst)
{
// Lookup the inst in the primal block whose value we can use as an operand
@@ -1296,37 +1079,7 @@ struct DiffTransposePass
// auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst];
return invInst;
}
-
- void setInverse(IRBuilder* builder, IRBlock* defBlock, IRGlobalValueWithCode* func, IRInst* inst, IRInst* invInst)
- {
- auto instBlock = as<IRBlock>(inst->getParent());
- if (!instBlock)
- return;
-
- disableIRValidationAtInsert();
- if (auto invVar = getOrCreateInverseVar(inst, func))
- {
- auto invStore = builder->emitStore(invVar, invInst);
- mapStoreToDefBlock[as<IRStore>(invStore)] = defBlock;
- }
- enableIRValidationAtInsert();
- }
-
- bool shouldInstBeInverted(IRInst* inst)
- {
-
- if (this->hoistedPrimalsInfo->instsToInvert.Contains(inst))
- return true;
-
- return false;
- }
-
- IRInst* hoistPrimalUse(IRBuilder*, IRUse* use)
- {
- primalUsesToHoist.add(use);
- return use->get();
- }
-
+
bool doesInstRequireHoisting(IRInst* inst)
{
if (as<IRModuleInst>(inst->getParent()))
@@ -1336,10 +1089,10 @@ struct DiffTransposePass
as<IRGlobalValueWithCode>(inst) ||
as<IRConstant>(inst))
return false;
-
+
if (as<IRTerminatorInst>(inst))
return false;
-
+
if (as<IRDecoration>(inst))
return doesInstRequireHoisting(getInstInBlock(inst));
@@ -1347,30 +1100,9 @@ struct DiffTransposePass
// that have not yet been moved to the 'active' blocks
// (i.e in diff blocks that do not have parents)
//
- return (!isDifferentialInst(inst) &&
- (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) &&
- getBlock(inst)->getParent() == nullptr);
- }
-
- // Builds a map from inst to a list of uses by primal _inverted_ insts.
- Dictionary<IRInst*, List<IRInst*>> buildInvOperandMap()
- {
- Dictionary<IRInst*, List<IRInst*>> invOperandMap;
- for (auto kvpair : this->hoistedPrimalsInfo->invertInfoMap)
- {
- InversionInfo invInfo = kvpair.Value;
-
- for (auto operand : invInfo.requiredOperands)
- {
- if (!invOperandMap.ContainsKey(operand))
- invOperandMap[operand] = List<IRInst*>();
-
- for (auto target : invInfo.targetInsts)
- invOperandMap[operand].GetValue().add(target);
- }
- }
-
- return invOperandMap;
+ return (!isDifferentialInst(inst) &&
+ (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) &&
+ getBlock(inst)->getParent() == nullptr);
}
IRBlock* walkToEndOfRegion(IRBlock* block)
@@ -1435,186 +1167,13 @@ struct DiffTransposePass
void finishHoistingPrimals(IRGlobalValueWithCode* func)
{
- List<IRInst*> workList;
-
- Dictionary<IRInst*, IRInst*> hoistedInstMap;
-
- RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
-
- Dictionary<IRInst*, List<IRInst*>> invOperandMap = buildInvOperandMap();
-
auto varBlock = func->getFirstBlock()->getNextBlock();
-
- // Load up pending insts into workList.
- for (auto use : primalUsesToHoist)
- workList.add(use->get());
-
- primalUsesToHoist.clear();
-
- auto maybeAddPrimalOperandsToWorkList = [&](IRInst* inst)
+ for (auto inst : primalVarsToHoist)
{
- UIndex opIndex = 0;
- for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
- {
- if (doesInstRequireHoisting(operand->get()) &&
- !hoistedInstMap.ContainsKey(operand->get()))
- {
- workList.add(operand->get());
- }
- }
-
- if (auto instType = inst->getDataType())
- {
- if (doesInstRequireHoisting(instType) &&
- !hoistedInstMap.ContainsKey(instType))
- workList.add(instType);
- }
- };
-
- auto maybeAddUsersToWorkList = [&](IRInst* inst)
- {
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (doesInstRequireHoisting(use->getUser()))
- {
- if (as<IRVar>(inst) && as<IRStore>(use->getUser()))
- continue;
-
- // Uses that haven't already been hoisted into reverse-mode
- // blocks, and are not in the invert-set are pending uses.
- //
- if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser()))
- workList.add(use->getUser());
- }
- }
- };
-
- auto doesInstHavePendingUsers = [&](IRInst* inst)
- {
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (doesInstRequireHoisting(use->getUser()))
- {
- if (as<IRVar>(inst) && as<IRStore>(use->getUser()))
- continue;
-
- // Users that haven't already been hoisted into reverse-mode
- // blocks are pending users.
- //
- if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser()))
- return true;
- }
- }
-
- return false;
- };
-
- auto isInstHoisted = [&](IRInst* inst)
- {
- return getBlock(inst)->getParent() != nullptr && isDifferentialInst(getBlock(inst));
- };
-
- while (workList.getCount() > 0)
- {
- // Pop work item
- auto inst = workList.getLast();
- workList.removeLast();
-
- // Already hoisted to reverse-mode block.
- // replace with mapped inst (in case it's different)
- // and continue on.. (this should actually never be hit)
- //
- if (hoistedInstMap.ContainsKey(inst))
- continue;
-
- if (invOperandMap.ContainsKey(inst))
- {
- List<IRInst*> pendingInvDependencies;
- for (auto dependency : invOperandMap[inst].GetValue())
- {
- if (doesInstRequireHoisting(dependency) &&
- !hoistedInstMap.ContainsKey(dependency))
- pendingInvDependencies.add(dependency);
- }
-
- if (pendingInvDependencies.getCount() > 0)
- {
- workList.add(inst);
- for (auto dependency : pendingInvDependencies)
- workList.add(dependency);
-
- // Skip until all the dependencies have been handled.
- continue;
- }
- }
-
- // Are the uses of this primal inst already hoisted into the reverse-mode
- // blocks? We cannot hoist this inst unless the uses are hoisted.
- //
- if (doesInstHavePendingUsers(inst))
- {
- // Add inst back to work list.
- workList.add(inst);
-
- // Then, add all the pending use to the top of
- // list, ensuring they are processed before we see
- // inst again.
- //
- maybeAddUsersToWorkList(inst);
-
- continue;
- }
-
- // The used inst is marked for inversion, lookup and load
- // an inverse.
- //
- if (this->hoistedPrimalsInfo->invertSet.Contains(inst))
- {
- // Replace with inverse.
- IRBuilder builder(func->getModule());
-
- for (auto use = inst->firstUse; use;)
- {
- auto nextUse = use->nextUse;
-
- if (!isInstHoisted(use->getUser()))
- {
- use = nextUse;
- continue;
- }
-
- // TODO: Hacky workaround to prevent the 'key' being overwritten,
- // avoid this by adding the decoration on the param instead of the loop
- //
- if (auto exitValDecoration = as<IRLoopExitPrimalValueDecoration>(use->getUser()))
- {
- if (&exitValDecoration->target == use)
- {
- use = nextUse;
- continue;
- }
- }
-
-
- builder.setInsertBefore(getInstInBlock(use->getUser()));
- use->set(loadInverse(&builder, inst));
-
- use = nextUse;
- }
-
- // If all uses of the invertible inst have been hoisted,
- // add the inv-var to the worklist.
- //
- workList.add(lookupInverseVar(inst));
- hoistedInstMap[inst] = nullptr;
-
+ if (!doesInstRequireHoisting(inst))
continue;
- }
-
- // Should not see an inst marked for inversion here.
- SLANG_RELEASE_ASSERT(!this->hoistedPrimalsInfo->invertSet.Contains(inst));
-
+
List<IRUse*> relevantUses;
IRBlock* defBlock = nullptr;
@@ -1641,7 +1200,7 @@ struct DiffTransposePass
if (!doesInstRequireHoisting(inst))
continue;
-
+
// Move this inst to after it's diff uses.
//
{
@@ -1662,62 +1221,9 @@ struct DiffTransposePass
inst->insertBefore(currTopBlock->getFirstOrdinaryInst());
enableIRValidationAtInsert();
}
-
- // Finish up..
- hoistedInstMap[inst] = inst;
- maybeAddPrimalOperandsToWorkList(inst);
- }
- }
-
- void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst)
- {
- for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++)
- {
- // For now we'll only hoist primal operands that are
- // generated in differential blocks.
- // Eventually, we also want this method to move primal access
- // insts to the reverse-mode blocks (i.e. this method will
- // make sure all requried primal insts are moved to the right
- // place)
- //
- if (doesInstRequireHoisting(fwdInst->getOperand(ii)))
- {
- hoistPrimalUse(revBuilder, &fwdInst->getOperands()[ii]);
- }
}
}
- void invertInst(IRBuilder* builder, IRInst* primalInst)
- {
- // Look for an available inverse entry for this primalInst's *output*
- if (shouldInstBeInverted(primalInst))
- {
- // This logic is already handled in transposeBlock() so we skip
- // it here.
- //
- if (as<IRTerminatorInst>(primalInst))
- return;
-
- auto invInfo = this->hoistedPrimalsInfo->invertInfoMap[primalInst];
-
- IRBuilder invBuilder(builder->getModule());
- invBuilder.setInsertAfter(primalInst);
-
- auto invEntries = invertInst(&invBuilder, primalInst, invInfo);
-
- for (auto entry : invEntries)
- setInverse(
- &invBuilder,
- getBlock(primalInst),
- as<IRGlobalValueWithCode>(entry.inst->getParent()->getParent()),
- entry.inst,
- entry.invInst);
- }
- else
- {
- SLANG_UNEXPECTED("Could not find value for the output of inst. Unable to invert");
- }
- }
void transposeInst(IRBuilder* builder, IRInst* inst)
{
@@ -1880,7 +1386,8 @@ struct DiffTransposePass
auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
auto tempVar = builder->emitVar(pairType);
auto primalVal = builder->emitLoad(instPair->getPrimal());
- hoistPrimalOperands(builder, primalVal); // TODO(sai): Do we need to hoist other insts here?
+ auto primalVar = instPair->getPrimal();
+ primalVarsToHoist.add(primalVar);
auto diffVal = builder->emitLoad(instPair->getDiff());
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
@@ -1961,7 +1468,6 @@ struct DiffTransposePass
auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar();
auto contextLoad = builder->emitLoad(primalContextVar);
- hoistPrimalOperands(builder, contextLoad);
args.add(contextLoad);
argTypes.add(as<IRPtrTypeBase>(
@@ -3477,7 +2983,7 @@ struct DiffTransposePass
DifferentialPairTypeBuilder pairBuilder;
- HoistedPrimalsInfo* hoistedPrimalsInfo;
+ List<IRInst*> primalVarsToHoist;
IRBlock* tempInvBlock;