summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-24 19:44:23 -0700
committerGitHub <noreply@github.com>2023-04-24 19:44:23 -0700
commit284cee1f246c072f190c87c8fb60c1d2181e458f (patch)
tree6f8b4ff3d619ad518e733000464daae233890962 /source/slang/slang-ir-autodiff-rev.cpp
parentfbe37ea6d90f7bfe18506b042657c6e533eaf9b2 (diff)
Change AD checkpointing policy to recompute more. (#2836)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 979eb6343..d7abf1d40 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -711,12 +711,9 @@ namespace Slang
// Apply checkpointing policy to legalize cross-scope uses of primal values
// using either recompute or store strategies.
- auto primalsInfo = applyCheckpointPolicy(
- diffPropagateFunc, paramTransposeInfo.propagateFuncSpecificPrimalInsts);
-
+ auto primalsInfo = applyCheckpointPolicy(diffPropagateFunc);
eliminateDeadCode(diffPropagateFunc);
-
// Extracts the primal computations into its own func, and replace the primal insts
// with the intermediate results computed from the extracted func.
@@ -810,10 +807,13 @@ namespace Slang
// Find the 'next' block using the terminator inst of the parameter block.
auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
- auto nextBlock = fwdParamBlockBranch->getTargetBlock();
+ // We create a new block after parameter block to hold insts that translates from transposed parameters
+ // into something that the rest of the function can use.
+ IRBuilder::insertBlockAlongEdge(diffFunc->getModule(), IREdge(&fwdParamBlockBranch->block));
+ auto paramPreludeBlock = fwdParamBlockBranch->getTargetBlock();
auto nextBlockBuilder = *builder;
- nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst());
+ nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst());
IRBlock* firstDiffBlock = nullptr;
for (auto block : diffFunc->getBlocks())