summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-03-15 22:26:58 -0400
committerGitHub <noreply@github.com>2023-03-15 19:26:58 -0700
commit71efd949fa5276e2464416fcf237f8fd2c486281 (patch)
treea5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source/slang/slang-ir-autodiff-rev.cpp
parent38e62199cc75ce34608491c8dd299eb330bde518 (diff)
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp110
1 files changed, 5 insertions, 105 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 328af4867..157011b7c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -606,106 +606,6 @@ namespace Slang
return fwdDiffFunc;
}
- void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc)
- {
- RefPtr<IRDominatorTree> domTree = computeDominatorTree(diffPropFunc);
- auto firstBlock = diffPropFunc->getFirstBlock();
- if (!firstBlock)
- return;
- Dictionary<IRInst*, IRVar*> instVars;
- Dictionary<IRBlock*, IRCloneEnv> cloneEnvs;
- auto storeInstAsLocalVar = [&](IRInst* inst)
- {
- IRVar* var = nullptr;
- if (instVars.TryGetValue(inst, var))
- return var;
- IRBuilder builder(diffPropFunc);
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
- var = builder.emitVar(inst->getDataType());
- builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType()));
-
- setInsertAfterOrdinaryInst(&builder, inst);
- builder.emitStore(var, inst);
- instVars[inst] = var;
- return var;
- };
-
- IRBuilder builder(diffPropFunc);
- List<IRInst*> workList;
- for (auto block : diffPropFunc->getBlocks())
- {
- if (!block->findDecoration<IRDifferentialInstDecoration>())
- continue;
- cloneEnvs[block] = IRCloneEnv();
- for (auto inst : block->getChildren())
- {
- workList.add(inst);
- }
- }
-
- for (Index i = 0; i < workList.getCount(); i++)
- {
- auto inst = workList[i];
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto operand = inst->getOperand(j);
- if (operand->getOp() == kIROp_Block)
- continue;
- auto operandParent = inst->getOperand(j)->getParent();
- if (!operandParent)
- continue;
- if (operandParent->parent != diffPropFunc)
- continue;
- if (domTree->dominates(operandParent, inst->parent))
- continue;
-
- // The def site of the operand does not dominate the use.
- // We need to insert a local variable to store this var.
-
- IRInst* operandReplacement = nullptr;
- if (canTypeBeStored(operand->getDataType()))
- {
- auto var = storeInstAsLocalVar(operand);
- builder.setInsertBefore(inst);
- operandReplacement = builder.emitLoad(var);
- }
- else if (operand->getOp() == kIROp_Var)
- {
- // Var can just be hoisted to first block.
- operand->insertBefore(firstBlock->getFirstOrdinaryInst());
- }
- else
- {
- // For all other insts, we need to copy it to right before this inst.
- // Before actually copying it, check if we have already copied it to
- // any blocks that dominates this block.
- auto dom = as<IRBlock>(inst->getParent());
- while (dom)
- {
- auto subCloneEnv = cloneEnvs.TryGetValue(dom);
- if (!subCloneEnv) break;
- if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement))
- {
- break;
- }
- dom = domTree->getImmediateDominator(dom);
- }
- // We have not found an existing clone in dominators, so we need to copy it
- // to this block.
- if (!operandReplacement)
- {
- auto subCloneEnv = cloneEnvs.TryGetValue(as<IRBlock>(inst->getParent()));
- builder.setInsertBefore(inst);
- operandReplacement = cloneInst(subCloneEnv, &builder, operand);
- workList.add(operandReplacement);
- }
- }
- if (operandReplacement)
- builder.replaceOperand(inst->getOperands() + j, operandReplacement);
- }
- }
- }
-
InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
{
SLANG_UNUSED(primalType);
@@ -774,7 +674,7 @@ namespace Slang
// Copy primal insts to the first block of the unzipped function, copy diff insts to the
// second block of the unzipped function.
//
- diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ RefPtr<HoistedPrimalsInfo> primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;
// Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
@@ -801,8 +701,8 @@ namespace Slang
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
// derivative of the return value.
- DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam };
- diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
+ DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo };
+ diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo);
eliminateDeadCode(diffPropagateFunc);
@@ -810,7 +710,7 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType);
+ diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType);
// At this point the unzipped func is just an empty shell
// and we can simply remove it.
@@ -870,7 +770,7 @@ namespace Slang
initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getModule(), diffPropagateFunc);
- insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
+ // insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
stripTempDecorations(diffPropagateFunc);
}