summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-24 19:44:23 -0700
committerGitHub <noreply@github.com>2023-04-24 19:44:23 -0700
commit284cee1f246c072f190c87c8fb60c1d2181e458f (patch)
tree6f8b4ff3d619ad518e733000464daae233890962 /source
parentfbe37ea6d90f7bfe18506b042657c6e533eaf9b2 (diff)
Change AD checkpointing policy to recompute more. (#2836)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp155
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h26
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp12
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp19
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h8
-rw-r--r--source/slang/slang-ir-autodiff.cpp42
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-dce.cpp20
-rw-r--r--source/slang/slang-ir-util.cpp18
-rw-r--r--source/slang/slang-ir-util.h2
10 files changed, 197 insertions, 107 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 6a9b504a6..1bc3caaba 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-autodiff-primal-hoist.h"
#include "slang-ir-autodiff-region.h"
+#include "slang-ir-simplify-cfg.h"
namespace Slang
{
@@ -9,7 +10,8 @@ void applyCheckpointSet(
IRGlobalValueWithCode* func,
HoistedPrimalsInfo* hoistInfo,
HashSet<IRUse*> pendingUses,
- Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock);
+ Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
+ IROutOfOrderCloneContext* cloneCtx);
bool containsOperand(IRInst* inst, IRInst* operand)
{
@@ -68,7 +70,8 @@ static IRBlock* tryGetSubRegionEndBlock(IRInst* terminator)
static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
IRGlobalValueWithCode* func,
- Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
+ IROutOfOrderCloneContext* cloneCtx)
{
IRBlock* firstDiffBlock = nullptr;
for (auto block : func->getBlocks())
@@ -136,7 +139,6 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
WorkItem firstWorkItem = { func->getFirstBlock(), firstRecomputeBlock, firstRecomputeBlock, firstDiffBlock };
workList.add(firstWorkItem);
- IRCloneEnv recomputeCloneEnv;
recomputeBlockMap[func->getFirstBlock()] = firstRecomputeBlock;
for (Index i = 0; i < workList.getCount(); i++)
@@ -216,7 +218,7 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
{
case kIROp_Switch:
case kIROp_ifElse:
- newTerminator = cloneInst(&recomputeCloneEnv, &builder, primalBlock->getTerminator());
+ newTerminator = cloneCtx->cloneInstOutOfOrder(&builder, primalBlock->getTerminator());
break;
case kIROp_unconditionalBranch:
newTerminator = builder.emitBranch(as<IRUnconditionalBranch>(terminator)->getTargetBlock());
@@ -271,7 +273,8 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
IRGlobalValueWithCode* func,
- Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock)
+ Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
+ IROutOfOrderCloneContext* cloneCtx)
{
RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo();
@@ -483,7 +486,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
}
RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo();
- applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock);
+ applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock, cloneCtx);
return hoistInfo;
}
@@ -501,11 +504,6 @@ void applyToInst(
return;
}
- if (hoistInfo->ignoreSet.Contains(inst))
- {
- return;
- }
-
bool isInstRecomputed = checkpointInfo->recomputeSet.Contains(inst);
if (isInstRecomputed)
{
@@ -522,11 +520,10 @@ void applyToInst(
//
SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported");
}
+ return;
}
- else
- {
- hoistInfo->recomputeSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
- }
+ auto recomputeInst = cloneCtx->cloneInstOutOfOrder(builder, inst);
+ hoistInfo->recomputeSet.Add(recomputeInst);
}
bool isInstInverted = checkpointInfo->invertSet.Contains(inst);
@@ -553,17 +550,22 @@ void applyToInst(
}
}
+static IRBlock* getParamPreludeBlock(IRGlobalValueWithCode* func)
+{
+ return func->getFirstBlock()->getNextBlock();
+}
+
void applyCheckpointSet(
CheckpointSetInfo* checkpointInfo,
IRGlobalValueWithCode* func,
HoistedPrimalsInfo* hoistInfo,
HashSet<IRUse*> pendingUses,
- Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock)
+ Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
+ IROutOfOrderCloneContext* cloneCtx)
{
// Reconstruct diff block map.
Dictionary<IRBlock*, IRBlock*> diffBlockMap = reconstructDiffBlockMap(func);
- RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext();
for (auto use : pendingUses)
cloneCtx->pendingUses.Add(use);
@@ -583,10 +585,11 @@ void applyCheckpointSet(
};
// Go back over the insts and move/clone them accoridngly.
+ auto paramPreludeBlock = getParamPreludeBlock(func);
for (auto block : func->getBlocks())
{
- // Skip parameter block.
- if (block == func->getFirstBlock())
+ // Skip parameter block and the param prelude block.
+ if (block == func->getFirstBlock() || block == paramPreludeBlock)
continue;
if (isDifferentialBlock(block))
@@ -646,7 +649,22 @@ void applyCheckpointSet(
for (auto child : block->getChildren())
{
+ // Determine the insertion point for the recomputeInst.
+ // Normally we insert recomputeInst into the block's corresponding recomputeBlock.
+ // The exception is a load(inoutParam), in which case we insert the recomputed load
+ // at the right beginning of the function to correctly receive the initial parameter
+ // value. We can't just insert the load at recomputeBlock because at that point the
+ // primal logic may have already updated the param with a new value, and instead we
+ // want the original value.
builder.setInsertBefore(recomputeInsertBeforeInst);
+ if (auto load = as<IRLoad>(child))
+ {
+ if (load->getPtr()->getOp() == kIROp_Param &&
+ load->getPtr()->getParent() == func->getFirstBlock())
+ {
+ builder.setInsertBefore(getParamPreludeBlock(func)->getTerminator());
+ }
+ }
applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
}
}
@@ -833,28 +851,33 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
{
RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+ IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock();
IRBuilder builder(func->getModule());
- IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock();
- SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
+ IRBlock* defaultRecomptueVarBlock = nullptr;
+ for (auto block : func->getBlocks())
+ if (isDifferentialOrRecomputeBlock(block))
+ {
+ defaultRecomptueVarBlock = block;
+ break;
+ }
+ SLANG_RELEASE_ASSERT(defaultRecomptueVarBlock);
OrderedHashSet<IRInst*> processedStoreSet;
- auto ensureInstAvailable = [&](OrderedHashSet<IRInst*>& instSet)
+ auto ensureInstAvailable = [&](OrderedHashSet<IRInst*>& instSet, bool isRecomputeInst)
{
+ SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
+
for (auto instToStore : instSet)
{
- if (!instSet.Contains(instToStore))
- continue;
-
- if (hoistInfo->ignoreSet.Contains(instToStore))
- continue;
IRBlock* defBlock = nullptr;
if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
{
auto varInst = as<IRVar>(instToStore);
- auto storeUse = findUniqueStoredVal(varInst);
+ auto storeUse = findEarliestUniqueWriteUse(varInst);
defBlock = getBlock(storeUse->getUser());
}
@@ -899,19 +922,28 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
if (outOfScopeUses.getCount() == 0)
{
- processedStoreSet.Add(instToStore);
+ if (!isRecomputeInst)
+ processedStoreSet.Add(instToStore);
continue;
}
+ auto defBlockIndices = indexedBlockInfo[defBlock].GetValue();
+ IRBlock* varBlock = defaultVarBlock;
+ if (isRecomputeInst)
+ {
+ varBlock = defaultRecomptueVarBlock;
+ if (defBlockIndices.getCount())
+ {
+ varBlock = as<IRBlock>(defBlockIndices[0].diffCountParam->getParent());
+ defBlockIndices.clear();
+ }
+ }
if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
{
-
IRVar* varToStore = as<IRVar>(instToStore);
SLANG_RELEASE_ASSERT(varToStore);
- auto storeUse = findUniqueStoredVal(varToStore);
-
- List<IndexTrackingInfo>& defBlockIndices = indexedBlockInfo[defBlock];
+ auto storeUse = findLatestUniqueWriteUse(varToStore);
bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0);
@@ -921,7 +953,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
if (!isIndexedStore && isDerivativeContextVar(varToStore))
{
varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst());
- processedStoreSet.Add(varToStore);
+ if (!isRecomputeInst)
+ processedStoreSet.Add(varToStore);
continue;
}
@@ -929,7 +962,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
IRVar* localVar = storeIndexedValue(
&builder,
- defaultVarBlock,
+ varBlock,
builder.emitLoad(varToStore),
defBlockIndices);
@@ -942,8 +975,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices);
builder.replaceOperand(use, loadAddr);
}
-
- processedStoreSet.Add(localVar);
+ if (!isRecomputeInst)
+ processedStoreSet.Add(localVar);
}
else
{
@@ -951,7 +984,6 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
// The only case where there will be a reference of primal loop counter from rev blocks
// is the start of a loop in the reverse code. Since loop counters are not considered a
// part of their loop region, so we remove the first index info.
- List<IndexTrackingInfo> defBlockIndices = indexedBlockInfo[defBlock];
bool isLoopCounter = (instToStore->findDecoration<IRLoopCounterDecoration>() != nullptr);
if (isLoopCounter)
{
@@ -959,7 +991,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
}
setInsertAfterOrdinaryInst(&builder, instToStore);
- auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices);
+ auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices);
for (auto use : outOfScopeUses)
{
@@ -974,14 +1006,15 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices));
}
-
- processedStoreSet.Add(localVar);
+ if (!isRecomputeInst)
+ processedStoreSet.Add(localVar);
}
}
};
- ensureInstAvailable(hoistInfo->storeSet);
-
+ ensureInstAvailable(hoistInfo->storeSet, false);
+ ensureInstAvailable(hoistInfo->recomputeSet, true);
+
// Replace the old store set with the processed one.
hoistInfo->storeSet = processedStoreSet;
@@ -1179,27 +1212,23 @@ void buildIndexedBlocks(
}
}
-RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(
- IRGlobalValueWithCode* func, const List<IRInst*>& instsToIgnore)
+RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
{
sortBlocksInFunc(func);
Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
buildIndexedBlocks(indexedBlockInfo, func);
- auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo);
+ RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext();
+ auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo, cloneCtx);
sortBlocksInFunc(func);
RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule());
chkPolicy->preparePolicy(func);
- auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap);
+ auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx);
- for (auto propagateFuncSpecificInst : instsToIgnore)
- {
- primalsInfo->ignoreSet.add(propagateFuncSpecificInst);
- }
primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
return primalsInfo;
}
@@ -1343,7 +1372,6 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_GetSequentialID:
case kIROp_Specialize:
case kIROp_LookupWitness:
-#if 0
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
@@ -1364,7 +1392,6 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_BitXor:
case kIROp_Lsh:
case kIROp_Rsh:
-#endif
return false;
case kIROp_GetElement:
case kIROp_FieldExtract:
@@ -1387,17 +1414,29 @@ static bool shouldStoreInst(IRInst* inst)
if (as<IRType>(inst))
return false;
- // Only store if the inst has differential inst user.
- bool hasDiffUser = doesInstHaveDiffUse(inst);
- if (!hasDiffUser)
- return false;
-
return true;
}
bool canRecompute(IRDominatorTree* domTree, IRUse* use)
{
SLANG_UNUSED(domTree);
+ if (auto load = as<IRLoad>(use->get()))
+ {
+ // Generally, we cannot recompute a load(ptr), since ptr may be modified
+ // afterwards. The exceptions are a load of an inout param, since the
+ // propagation function never actually writes to the primal part of the
+ // inout param, and we can always just read the original param.
+
+ auto ptr = load->getPtr();
+ if (ptr->getOp() == kIROp_Param)
+ {
+ if (auto block = as<IRBlock>(ptr->getParent()))
+ {
+ return (block == block->getParent()->getFirstBlock());
+ }
+ }
+ return false;
+ }
auto param = as<IRParam>(use->get());
if (!param)
return true;
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index 3b3fb82b1..6e861bc5b 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -14,10 +14,8 @@ namespace Slang
IRCloneEnv cloneEnv;
HashSet<IRUse*> pendingUses;
- IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst)
+ void registerClonedInst(IRBuilder* builder, IRInst* inst, IRInst* clonedInst)
{
- IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst);
-
UInt operandCount = clonedInst->getOperandCount();
for (UInt ii = 0; ii < operandCount; ++ii)
{
@@ -31,16 +29,21 @@ namespace Slang
for (auto use = inst->firstUse; use;)
{
auto nextUse = use->nextUse;
-
+
if (pendingUses.Contains(use))
{
pendingUses.Remove(use);
builder->replaceOperand(use, clonedInst);
}
-
+
use = nextUse;
}
+ }
+ IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst)
+ {
+ IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst);
+ registerClonedInst(builder, inst, clonedInst);
return clonedInst;
}
};
@@ -86,7 +89,6 @@ namespace Slang
OrderedHashSet<IRInst*> storeSet;
OrderedHashSet<IRInst*> recomputeSet;
OrderedHashSet<IRInst*> invertSet;
- OrderedHashSet<IRInst*> ignoreSet;
OrderedHashSet<IRInst*> instsToInvert;
Dictionary<IRInst*, InversionInfo> invertInfoMap;
@@ -129,9 +131,6 @@ namespace Slang
for (auto inst : info->invertSet)
invertSet.Add(inst);
- for (auto inst : info->ignoreSet)
- ignoreSet.add(inst);
-
for (auto inst : info->instsToInvert)
instsToInvert.Add(inst);
@@ -261,7 +260,8 @@ namespace Slang
RefPtr<HoistedPrimalsInfo> processFunc(
IRGlobalValueWithCode* func,
- Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock);
+ Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
+ IROutOfOrderCloneContext* cloneCtx);
// Do pre-processing on the function (mainly for
// 'global' checkpointing methods that consider the entire
@@ -290,9 +290,5 @@ namespace Slang
RefPtr<IRDominatorTree> domTree;
};
- RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(
- IRGlobalValueWithCode* func,
- const List<IRInst*>& instsToIgnore);
-
-
+ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func);
};
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 979eb6343..d7abf1d40 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -711,12 +711,9 @@ namespace Slang
// Apply checkpointing policy to legalize cross-scope uses of primal values
// using either recompute or store strategies.
- auto primalsInfo = applyCheckpointPolicy(
- diffPropagateFunc, paramTransposeInfo.propagateFuncSpecificPrimalInsts);
-
+ auto primalsInfo = applyCheckpointPolicy(diffPropagateFunc);
eliminateDeadCode(diffPropagateFunc);
-
// Extracts the primal computations into its own func, and replace the primal insts
// with the intermediate results computed from the extracted func.
@@ -810,10 +807,13 @@ namespace Slang
// Find the 'next' block using the terminator inst of the parameter block.
auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
- auto nextBlock = fwdParamBlockBranch->getTargetBlock();
+ // We create a new block after parameter block to hold insts that translates from transposed parameters
+ // into something that the rest of the function can use.
+ IRBuilder::insertBlockAlongEdge(diffFunc->getModule(), IREdge(&fwdParamBlockBranch->block));
+ auto paramPreludeBlock = fwdParamBlockBranch->getTargetBlock();
auto nextBlockBuilder = *builder;
- nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst());
+ nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst());
IRBlock* firstDiffBlock = nullptr;
for (auto block : diffFunc->getBlocks())
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 44e981404..a864a74b2 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -216,12 +216,15 @@ struct ExtractPrimalFuncContext
{
if (as<IRVar>(inst))
{
- auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
- builder.setInsertBefore(inst);
- auto fieldAddr = builder.emitFieldAddress(
- inst->getFullType(), outIntermediary, field->getKey());
- inst->replaceUsesWith(fieldAddr);
- builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
+ if (inst->hasUses())
+ {
+ auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
+ builder.setInsertBefore(inst);
+ auto fieldAddr = builder.emitFieldAddress(
+ inst->getFullType(), outIntermediary, field->getKey());
+ inst->replaceUsesWith(fieldAddr);
+ builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
+ }
}
else
{
@@ -359,7 +362,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
List<IRInst*> instsToRemove;
for (auto block : func->getBlocks())
{
- for (auto inst : block->getOrdinaryInsts())
+ for (auto inst : block->getChildren())
{
if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>())
{
@@ -420,6 +423,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
for (auto inst : instsToRemove)
{
+ if (as<IRParam>(inst))
+ removePhiArgs(inst);
inst->removeAndDeallocate();
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 65f45ece8..532e63b42 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -316,8 +316,12 @@ struct DiffUnzipPass
if (auto primalVar = as<IRVar>(primalArg))
{
primalArg = diffBuilder->emitVar(as<IRPtrTypeBase>(primalVar->getDataType())->getValueType());
- if (auto storeUse = findUniqueStoredVal(primalVar))
- diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal());
+ if (auto storeUse = findUniqueStoredVal(primalVar))
+ {
+ auto storeInst = diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal());
+ storeInst->insertAfter(storeUse->getUser());
+ primalArg->insertBefore(storeInst);
+ }
}
// If arg is a mixed differential (pair), it should have already been split.
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index a8af148d9..656b0e11b 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1868,6 +1868,48 @@ IRUse* findUniqueStoredVal(IRVar* var)
}
}
+// Given a local var that is supposed to have a unique write, find the last inst
+// that writes to it. Note: if var is intended for an inout argument, it will
+// have exactly one store that sets its initial value and one call that writes
+// the final value to it, this method will return the call inst for this case.
+IRUse* findLatestUniqueWriteUse(IRVar* var)
+{
+ IRUse* storeUse = nullptr;
+ // If no unique store found, try to look for a call.
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (auto callInst = as<IRCall>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(!storeUse);
+ storeUse = use;
+ }
+ }
+ return findUniqueStoredVal(var);
+}
+
+// Given a local var that is supposed to have a unique write, find the last inst
+// that writes to it. Note: if var is intended for an inout argument, it will
+// have exactly one store that sets its initial value and one call that writes
+// the final value to it, this method will return the store inst for this case.
+IRUse* findEarliestUniqueWriteUse(IRVar* var)
+{
+ IRUse* storeUse = findUniqueStoredVal(var);
+ if (storeUse)
+ return storeUse;
+
+ // If no unique store found, try to look for a call.
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (auto callInst = as<IRCall>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(!storeUse);
+ storeUse = use;
+ }
+ }
+ return storeUse;
+}
+
+
bool isDerivativeContextVar(IRVar* var)
{
return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>();
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index d7d6119d4..52cf346b3 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -343,6 +343,8 @@ IRInst* getInstInBlock(IRInst* inst);
UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg);
IRUse* findUniqueStoredVal(IRVar* var);
+IRUse* findLatestUniqueWriteUse(IRVar* var);
+IRUse* findEarliestUniqueWriteUse(IRVar* var);
bool isDerivativeContextVar(IRVar* var);
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 364abe68c..1b0ecf521 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -223,25 +223,6 @@ struct DeadCodeEliminationContext
return processInst(module->getModuleInst());
}
- void removePhiArgs(IRInst* phiParam)
- {
- auto block = cast<IRBlock>(phiParam->getParent());
- UInt paramIndex = 0;
- for (auto p = block->getFirstParam(); p; p = p->getNextParam())
- {
- if (p == phiParam)
- break;
- paramIndex++;
- }
- for (auto predBlock : block->getPredecessors())
- {
- auto termInst = as<IRUnconditionalBranch>(predBlock->getTerminator());
- SLANG_ASSERT(paramIndex < termInst->getArgCount());
- termInst->removeArgument(paramIndex);
- }
- phiRemoved = true;
- }
-
bool eliminateDeadInstsRec(IRInst* inst)
{
bool changed = false;
@@ -266,6 +247,7 @@ struct DeadCodeEliminationContext
{
// For Phi parameters, we need to update all branch arguments.
removePhiArgs(inst);
+ phiRemoved = true;
}
inst->removeAndDeallocate();
changed = true;
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 03b74b36a..9348dfe8a 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -741,6 +741,24 @@ void moveParams(IRBlock* dest, IRBlock* src)
}
}
+void removePhiArgs(IRInst* phiParam)
+{
+ auto block = cast<IRBlock>(phiParam->getParent());
+ UInt paramIndex = 0;
+ for (auto p = block->getFirstParam(); p; p = p->getNextParam())
+ {
+ if (p == phiParam)
+ break;
+ paramIndex++;
+ }
+ for (auto predBlock : block->getPredecessors())
+ {
+ auto termInst = as<IRUnconditionalBranch>(predBlock->getTerminator());
+ SLANG_ASSERT(paramIndex < termInst->getArgCount());
+ termInst->removeArgument(paramIndex);
+ }
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index e7d182604..9405771b1 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -200,6 +200,8 @@ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key);
void moveParams(IRBlock* dest, IRBlock* src);
+
+void removePhiArgs(IRInst* phiParam);
}
#endif