summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-21 14:28:57 -0700
committerGitHub <noreply@github.com>2023-04-21 14:28:57 -0700
commit957a4d3eb0a14a9d57bbb325ef0e1d458df2d2b9 (patch)
treefabc9317b1595c9f74f5b25ee83d16f4260a19d3 /source/slang/slang-ir-autodiff.cpp
parent69a327a98e3f9504863f9ecb623aa93036ac43db (diff)
Refactor checkpointing policy and availability pass. (#2826)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp48
1 files changed, 38 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 9a7a42619..a8af148d9 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -774,7 +774,9 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_PrimalInstDecoration:
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
- case kIROp_PrimalValueAccessDecoration:
+ case kIROp_RecomputeBlockDecoration:
+ case kIROp_LoopCounterDecoration:
+ case kIROp_LoopCounterUpdateDecoration:
case kIROp_BackwardDerivativeDecoration:
case kIROp_BackwardDerivativeIntermediateTypeDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
@@ -814,6 +816,7 @@ void stripTempDecorations(IRInst* inst)
{
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
+ case kIROp_RecomputeBlockDecoration:
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_BackwardDerivativePrimalReturnDecoration:
case kIROp_PrimalValueStructKeyDecoration:
@@ -902,8 +905,9 @@ bool canTypeBeStored(IRInst* type)
case kIROp_FloatType:
case kIROp_VectorType:
case kIROp_MatrixType:
- case kIROp_AttributedType:
return true;
+ case kIROp_AttributedType:
+ return canTypeBeStored(type->getOperand(0));
default:
return false;
}
@@ -1770,7 +1774,7 @@ IRInst* getInstInBlock(IRInst* inst)
return getInstInBlock(inst->getParent());
}
-UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
+UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg)
{
SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator()));
@@ -1786,16 +1790,22 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
switch (branchInst->getOp())
{
case kIROp_unconditionalBranch:
- builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
+ inoutTerminatorInst = builder->emitBranch(
+ branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
break;
case kIROp_loop:
- builder->emitLoop(
- as<IRLoop>(branchInst)->getTargetBlock(),
- as<IRLoop>(branchInst)->getBreakBlock(),
- as<IRLoop>(branchInst)->getContinueBlock(),
- phiArgs.getCount(),
- phiArgs.getBuffer());
+ {
+ auto newLoop = builder->emitLoop(
+ as<IRLoop>(branchInst)->getTargetBlock(),
+ as<IRLoop>(branchInst)->getBreakBlock(),
+ as<IRLoop>(branchInst)->getContinueBlock(),
+ phiArgs.getCount(),
+ phiArgs.getBuffer());
+ branchInst->transferDecorationsTo(newLoop);
+ branchInst->replaceUsesWith(newLoop);
+ inoutTerminatorInst = newLoop;
+ }
break;
default:
@@ -1806,6 +1816,24 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
return phiArgs.getCount() - 1;
}
+bool isDifferentialOrRecomputeBlock(IRBlock* block)
+{
+ if (!block)
+ return false;
+ for (auto decor : block->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_DifferentialInstDecoration:
+ case kIROp_RecomputeBlockDecoration:
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+}
+
IRUse* findUniqueStoredVal(IRVar* var)
{
if (isDerivativeContextVar(var))