summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-unzip.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-03-15 22:26:58 -0400
committerGitHub <noreply@github.com>2023-03-15 19:26:58 -0700
commit71efd949fa5276e2464416fcf237f8fd2c486281 (patch)
treea5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source/slang/slang-ir-autodiff-unzip.cpp
parent38e62199cc75ce34608491c8dd299eb330bde518 (diff)
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp34
1 files changed, 21 insertions, 13 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 5b59416d4..16862bb19 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -332,7 +332,12 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(
+ IRFunc* unzippedFunc,
+ IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
+ HashSet<IRInst*>& primalParams,
+ IRInst*& outIntermediateType)
{
IRBuilder builder(module);
@@ -375,17 +380,9 @@ struct ExtractPrimalFuncContext
// output intermediary struct.
for (auto inst : block->getChildren())
{
- if (shouldStoreInst(inst))
+ if (primalsInfo->storeSet.Contains(inst))
{
- if (as<IRParam>(inst))
- builder.setInsertBefore(block->getFirstOrdinaryInst());
- else
- builder.setInsertAfter(inst);
- storeInst(builder, inst, outIntermediary);
- }
- else if (inst->getOp() == kIROp_Var)
- {
- if (shouldStoreVar(as<IRVar>(inst)))
+ if (as<IRVar>(inst))
{
auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
builder.setInsertBefore(inst);
@@ -394,7 +391,14 @@ struct ExtractPrimalFuncContext
inst->replaceUsesWith(fieldAddr);
builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
}
-
+ else
+ {
+ if (as<IRParam>(inst))
+ builder.setInsertBefore(block->getFirstOrdinaryInst());
+ else
+ builder.setInsertAfter(inst);
+ storeInst(builder, inst, outIntermediary);
+ }
}
}
}
@@ -459,6 +463,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
IRFunc* DiffUnzipPass::extractPrimalFunc(
IRFunc* func,
IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
ParameterBlockTransposeInfo& paramInfo,
IRInst*& intermediateType)
{
@@ -470,6 +475,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
subEnv.parent = &cloneEnv;
auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
+ auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv);
+
// Remove [KeepAlive] decorations in clonedFunc.
for (auto block : clonedFunc->getBlocks())
for (auto inst : block->getChildren())
@@ -494,7 +501,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
context.init(autodiffContext->moduleInst->getModule(), autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, clonedPrimalsInfo, newPrimalParams, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
@@ -580,6 +587,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
// The primal calls should be marked as no side effect so they can be DCE'd if possible.
// We can only do so if the intermediate context of the callee is stored.
+ //
if (primalCtx->getBackwardDerivativePrimalContextVar()
->findDecoration<IRPrimalValueStructKeyDecoration>())
{