diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-21 14:28:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-21 14:28:57 -0700 |
| commit | 957a4d3eb0a14a9d57bbb325ef0e1d458df2d2b9 (patch) | |
| tree | fabc9317b1595c9f74f5b25ee83d16f4260a19d3 /source/slang/slang-ir-autodiff.cpp | |
| parent | 69a327a98e3f9504863f9ecb623aa93036ac43db (diff) | |
Refactor checkpointing policy and availability pass. (#2826)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 48 |
1 files changed, 38 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 9a7a42619..a8af148d9 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -774,7 +774,9 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_PrimalInstDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: - case kIROp_PrimalValueAccessDecoration: + case kIROp_RecomputeBlockDecoration: + case kIROp_LoopCounterDecoration: + case kIROp_LoopCounterUpdateDecoration: case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDerivativeIntermediateTypeDecoration: case kIROp_BackwardDerivativePropagateDecoration: @@ -814,6 +816,7 @@ void stripTempDecorations(IRInst* inst) { case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: + case kIROp_RecomputeBlockDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: case kIROp_PrimalValueStructKeyDecoration: @@ -902,8 +905,9 @@ bool canTypeBeStored(IRInst* type) case kIROp_FloatType: case kIROp_VectorType: case kIROp_MatrixType: - case kIROp_AttributedType: return true; + case kIROp_AttributedType: + return canTypeBeStored(type->getOperand(0)); default: return false; } @@ -1770,7 +1774,7 @@ IRInst* getInstInBlock(IRInst* inst) return getInstInBlock(inst->getParent()); } -UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg) { SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); @@ -1786,16 +1790,22 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) switch (branchInst->getOp()) { case kIROp_unconditionalBranch: - builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); + inoutTerminatorInst = builder->emitBranch( + branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); break; case kIROp_loop: - builder->emitLoop( - as<IRLoop>(branchInst)->getTargetBlock(), - as<IRLoop>(branchInst)->getBreakBlock(), - as<IRLoop>(branchInst)->getContinueBlock(), - phiArgs.getCount(), - phiArgs.getBuffer()); + { + auto newLoop = builder->emitLoop( + as<IRLoop>(branchInst)->getTargetBlock(), + as<IRLoop>(branchInst)->getBreakBlock(), + as<IRLoop>(branchInst)->getContinueBlock(), + phiArgs.getCount(), + phiArgs.getBuffer()); + branchInst->transferDecorationsTo(newLoop); + branchInst->replaceUsesWith(newLoop); + inoutTerminatorInst = newLoop; + } break; default: @@ -1806,6 +1816,24 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) return phiArgs.getCount() - 1; } +bool isDifferentialOrRecomputeBlock(IRBlock* block) +{ + if (!block) + return false; + for (auto decor : block->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_DifferentialInstDecoration: + case kIROp_RecomputeBlockDecoration: + return true; + default: + break; + } + } + return false; +} + IRUse* findUniqueStoredVal(IRVar* var) { if (isDerivativeContextVar(var)) |
