diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-24 19:44:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-24 19:44:23 -0700 |
| commit | 284cee1f246c072f190c87c8fb60c1d2181e458f (patch) | |
| tree | 6f8b4ff3d619ad518e733000464daae233890962 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | fbe37ea6d90f7bfe18506b042657c6e533eaf9b2 (diff) | |
Change AD checkpointing policy to recompute more. (#2836)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 12 |
1 file changed, 6 insertions, 6 deletions
```diff
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 979eb6343..d7abf1d40 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -711,12 +711,9 @@ namespace Slang
         // Apply checkpointing policy to legalize cross-scope uses of primal values
         // using either recompute or store strategies.
-        auto primalsInfo = applyCheckpointPolicy(
-            diffPropagateFunc, paramTransposeInfo.propagateFuncSpecificPrimalInsts);
-
+        auto primalsInfo = applyCheckpointPolicy(diffPropagateFunc);
         eliminateDeadCode(diffPropagateFunc);
-
         // Extracts the primal computations into its own func, and replace the primal insts
         // with the intermediate results computed from the extracted func.
@@ -810,10 +807,13 @@ namespace Slang
         // Find the 'next' block using the terminator inst of the parameter block.
         auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
-        auto nextBlock = fwdParamBlockBranch->getTargetBlock();
+        // We create a new block after parameter block to hold insts that translates from transposed parameters
+        // into something that the rest of the function can use.
+        IRBuilder::insertBlockAlongEdge(diffFunc->getModule(), IREdge(&fwdParamBlockBranch->block));
+        auto paramPreludeBlock = fwdParamBlockBranch->getTargetBlock();
         auto nextBlockBuilder = *builder;
-        nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst());
+        nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst());
         IRBlock* firstDiffBlock = nullptr;
         for (auto block : diffFunc->getBlocks())
```
