diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-15 22:26:58 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-15 19:26:58 -0700 |
| commit | 71efd949fa5276e2464416fcf237f8fd2c486281 (patch) | |
| tree | a5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source/slang/slang-ir-autodiff-unzip.cpp | |
| parent | 38e62199cc75ce34608491c8dd299eb330bde518 (diff) | |
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 5b59416d4..16862bb19 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -332,7 +332,12 @@ struct ExtractPrimalFuncContext inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc( + IRFunc* unzippedFunc, + IRFunc* originalFunc, + HoistedPrimalsInfo* primalsInfo, + HashSet<IRInst*>& primalParams, + IRInst*& outIntermediateType) { IRBuilder builder(module); @@ -375,17 +380,9 @@ struct ExtractPrimalFuncContext // output intermediary struct. for (auto inst : block->getChildren()) { - if (shouldStoreInst(inst)) + if (primalsInfo->storeSet.Contains(inst)) { - if (as<IRParam>(inst)) - builder.setInsertBefore(block->getFirstOrdinaryInst()); - else - builder.setInsertAfter(inst); - storeInst(builder, inst, outIntermediary); - } - else if (inst->getOp() == kIROp_Var) - { - if (shouldStoreVar(as<IRVar>(inst))) + if (as<IRVar>(inst)) { auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); builder.setInsertBefore(inst); @@ -394,7 +391,14 @@ struct ExtractPrimalFuncContext inst->replaceUsesWith(fieldAddr); builder.addPrimalValueStructKeyDecoration(inst, field->getKey()); } - + else + { + if (as<IRParam>(inst)) + builder.setInsertBefore(block->getFirstOrdinaryInst()); + else + builder.setInsertAfter(inst); + storeInst(builder, inst, outIntermediary); + } } } } @@ -459,6 +463,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE IRFunc* DiffUnzipPass::extractPrimalFunc( IRFunc* func, IRFunc* originalFunc, + HoistedPrimalsInfo* primalsInfo, ParameterBlockTransposeInfo& paramInfo, IRInst*& intermediateType) { @@ -470,6 +475,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( subEnv.parent = &cloneEnv; auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); + auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv); + // Remove [KeepAlive] decorations in clonedFunc. for (auto block : clonedFunc->getBlocks()) for (auto inst : block->getChildren()) @@ -494,7 +501,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( context.init(autodiffContext->moduleInst->getModule(), autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, clonedPrimalsInfo, newPrimalParams, intermediateType); if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { @@ -580,6 +587,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { // The primal calls should be marked as no side effect so they can be DCE'd if possible. // We can only do so if the intermediate context of the callee is stored. + // if (primalCtx->getBackwardDerivativePrimalContextVar() ->findDecoration<IRPrimalValueStructKeyDecoration>()) { |
