diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-15 22:26:58 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-15 19:26:58 -0700 |
| commit | 71efd949fa5276e2464416fcf237f8fd2c486281 (patch) | |
| tree | a5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | 38e62199cc75ce34608491c8dd299eb330bde518 (diff) | |
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 110 |
1 files changed, 5 insertions, 105 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 328af4867..157011b7c 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -606,106 +606,6 @@ namespace Slang return fwdDiffFunc; } - void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc) - { - RefPtr<IRDominatorTree> domTree = computeDominatorTree(diffPropFunc); - auto firstBlock = diffPropFunc->getFirstBlock(); - if (!firstBlock) - return; - Dictionary<IRInst*, IRVar*> instVars; - Dictionary<IRBlock*, IRCloneEnv> cloneEnvs; - auto storeInstAsLocalVar = [&](IRInst* inst) - { - IRVar* var = nullptr; - if (instVars.TryGetValue(inst, var)) - return var; - IRBuilder builder(diffPropFunc); - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - var = builder.emitVar(inst->getDataType()); - builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType())); - - setInsertAfterOrdinaryInst(&builder, inst); - builder.emitStore(var, inst); - instVars[inst] = var; - return var; - }; - - IRBuilder builder(diffPropFunc); - List<IRInst*> workList; - for (auto block : diffPropFunc->getBlocks()) - { - if (!block->findDecoration<IRDifferentialInstDecoration>()) - continue; - cloneEnvs[block] = IRCloneEnv(); - for (auto inst : block->getChildren()) - { - workList.add(inst); - } - } - - for (Index i = 0; i < workList.getCount(); i++) - { - auto inst = workList[i]; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto operand = inst->getOperand(j); - if (operand->getOp() == kIROp_Block) - continue; - auto operandParent = inst->getOperand(j)->getParent(); - if (!operandParent) - continue; - if (operandParent->parent != diffPropFunc) - continue; - if (domTree->dominates(operandParent, inst->parent)) - continue; - - // The def site of the operand does not dominate the use. - // We need to insert a local variable to store this var. - - IRInst* operandReplacement = nullptr; - if (canTypeBeStored(operand->getDataType())) - { - auto var = storeInstAsLocalVar(operand); - builder.setInsertBefore(inst); - operandReplacement = builder.emitLoad(var); - } - else if (operand->getOp() == kIROp_Var) - { - // Var can just be hoisted to first block. - operand->insertBefore(firstBlock->getFirstOrdinaryInst()); - } - else - { - // For all other insts, we need to copy it to right before this inst. - // Before actually copying it, check if we have already copied it to - // any blocks that dominates this block. - auto dom = as<IRBlock>(inst->getParent()); - while (dom) - { - auto subCloneEnv = cloneEnvs.TryGetValue(dom); - if (!subCloneEnv) break; - if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement)) - { - break; - } - dom = domTree->getImmediateDominator(dom); - } - // We have not found an existing clone in dominators, so we need to copy it - // to this block. - if (!operandReplacement) - { - auto subCloneEnv = cloneEnvs.TryGetValue(as<IRBlock>(inst->getParent())); - builder.setInsertBefore(inst); - operandReplacement = cloneInst(subCloneEnv, &builder, operand); - workList.add(operandReplacement); - } - } - if (operandReplacement) - builder.replaceOperand(inst->getOperands() + j, operandReplacement); - } - } - } - InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) { SLANG_UNUSED(primalType); @@ -774,7 +674,7 @@ namespace Slang // Copy primal insts to the first block of the unzipped function, copy diff insts to the // second block of the unzipped function. // - diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + RefPtr<HoistedPrimalsInfo> primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); IRFunc* unzippedFwdDiffFunc = fwdDiffFunc; // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell. @@ -801,8 +701,8 @@ namespace Slang // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the // derivative of the return value. - DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam }; - diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); + DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo }; + diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo); eliminateDeadCode(diffPropagateFunc); @@ -810,7 +710,7 @@ namespace Slang // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( - diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType); + diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType); // At this point the unzipped func is just an empty shell // and we can simply remove it. @@ -870,7 +770,7 @@ namespace Slang initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getModule(), diffPropagateFunc); - insertVariableForRecomputedPrimalInsts(diffPropagateFunc); + // insertVariableForRecomputedPrimalInsts(diffPropagateFunc); stripTempDecorations(diffPropagateFunc); } |
