diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-05-06 03:03:25 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-06 00:03:25 -0700 |
| commit | 271dc1b98d3887b6297c5407dc67692716687f4d (patch) | |
| tree | a714a41f6a490000545e82cadd20561a020b0a1e /source | |
| parent | 0602eaaba32bdbaf3f99ab8987e97419cba395aa (diff) | |
Don't store loop induction values + fix minor issue (#2872)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 381 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 10 |
3 files changed, 357 insertions, 72 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 135c72556..ab23aeb40 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -11,7 +11,8 @@ void applyCheckpointSet( HoistedPrimalsInfo* hoistInfo, HashSet<IRUse*>& pendingUses, Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, - IROutOfOrderCloneContext* cloneCtx); + IROutOfOrderCloneContext* cloneCtx, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo); bool containsOperand(IRInst* inst, IRInst* operand) { @@ -260,8 +261,11 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( IRGlobalValueWithCode* func, Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock, - IROutOfOrderCloneContext* cloneCtx) + IROutOfOrderCloneContext* cloneCtx, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo) { + collectInductionValues(func); + RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo(); RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); @@ -362,6 +366,12 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( if (auto param = as<IRParam>(result.instToRecompute)) { + if (auto inductionInfo = inductionValueInsts.tryGetValue(param)) + { + checkpointInfo->loopInductionInfo.addIfNotExists(param, *inductionInfo); + continue; + } + // Add in the branch-args of every predecessor block. auto paramBlock = as<IRBlock>(param->getParent()); UIndex paramIndex = 0; @@ -389,14 +399,19 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( { if (auto var = as<IRVar>(result.instToRecompute)) { - IRUse* storeUse = findLatestUniqueWriteUse(var); - if (storeUse) + for (auto varUse = var->firstUse; varUse; varUse = varUse->nextUse) { - // When we have a var and a store/call insts that writes to the var, - // we treat as if there is a pseudo-use of the store/call to compute - // the var inst, i.e. the var depends on the store/call, despite - // the IR's def-use chain doesn't reflect this. - workList.add(UseOrPseudoUse(var, storeUse->getUser())); + switch (varUse->getUser()->getOp()) + { + case kIROp_Store: + case kIROp_Call: + // When we have a var and a store/call insts that writes to the var, + // we treat as if there is a pseudo-use of the store/call to compute + // the var inst, i.e. the var depends on the store/call, despite + // the IR's def-use chain doesn't reflect this. + workList.add(UseOrPseudoUse(var, varUse->getUser())); + break; + } } } else @@ -429,13 +444,20 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( { for (auto use = var->firstUse; use; use = use->nextUse) { - auto callUser = as<IRCall>(use->getUser()); - if (!callUser) - continue; - checkpointInfo->recomputeSet.add(callUser); - checkpointInfo->storeSet.remove(callUser); - if (instWorkListSet.add(callUser)) - instWorkList.add(callUser); + if (auto callUser = as<IRCall>(use->getUser())) + { + checkpointInfo->recomputeSet.add(callUser); + checkpointInfo->storeSet.remove(callUser); + if (instWorkListSet.add(callUser)) + instWorkList.add(callUser); + } + else if (auto storeUser = as<IRStore>(use->getUser())) + { + checkpointInfo->recomputeSet.add(storeUser); + checkpointInfo->storeSet.remove(storeUser); + if (instWorkListSet.add(callUser)) + instWorkList.add(callUser); + } } } else if (auto call = as<IRCall>(inst)) @@ -454,15 +476,198 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( } RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo(); - applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock, cloneCtx); + applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock, cloneCtx, blockIndexInfo); return hoistInfo; } +void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* func) +{ + // Collect loop induction values. + // There are two special phi insts we want to handle differently in our + // checkpointing policy: + // 1. a bool execution flag inserted as the result of CFG normalization, + // that is always true as long as the loop is still active. + // 2. the original induction variable that can be replaced with the loop + // counter we inserted during createPrimalRecomputeBlocks(). + + for (auto block : func->getBlocks()) + { + auto loopInst = as<IRLoop>(block->getTerminator()); + if (!loopInst) + continue; + auto targetBlock = loopInst->getTargetBlock(); + auto ifElse = as<IRIfElse>(targetBlock->getTerminator()); + Int paramIndex = -1; + Int conditionParamIndex = -1; + // First, we are going to collect all the bool execution flags from loops. + // These are very easy to identify: they are a phi param defined in + // targetBlock, and used as the condition value in the condtion block. + for (auto param : targetBlock->getParams()) + { + paramIndex++; + if (!param->getDataType()) + continue; + if (param->getDataType()->getOp() == kIROp_BoolType) + { + if (ifElse->getCondition() == param) + { + // The bool param is used as the condition of the if-else inside the loop, + // this param will always be true during the loop, and we don't need to store it. + LoopInductionValueInfo info; + info.kind = LoopInductionValueInfo::Kind::AlwaysTrue; + inductionValueInsts[param] = info; + conditionParamIndex = paramIndex; + } + } + } + if (conditionParamIndex == -1) + continue; + + // Next, we try to identify the original induction variables, if they exist. + // These are trickier, and we have to hard code the complex pattern that + // we can recognize. + // This pattern matching logic is ugly and fragile against changes to cfg + // normalization, but it is the easiest way to do it right now. + // Basically, we are looking for this pattern: + // loop(..., i=initVal) + // { + // targetBlock: + // ... + // param int i; + // param bool condition; + // ... + // branch condtionBlock; + // conditionBlock: + // if (condition) + // { + // } + // else + // { + // break; + // } + // // ... + // someBodyBlock: + // ... + // if (condition) + // { + // ... + // // Check condition 1: i is used by an `add` + // // Check condition 2: parent of (i+1) is a branch target of if(condition) + // // Check condition 3: branches to parentBlock with i1 = i + 1. + // goto parentBlock(i + 1); + // } + // else + // goto parentBlock(other); + // parentBlock: + // // Check condition 4: parentBlock branches to finalBlock. + // param int i1; + // goto finalBlock; + // finalBlock: + // // Check condition 5: finalBlock branches to targetBlock with new i = i1. + // goto loopHeader(i1); + // } + // + paramIndex = -1; + for (auto param : targetBlock->getParams()) + { + paramIndex++; + if (!param->getDataType()) + continue; + if (isScalarIntegerType(param->getDataType())) + { + // If the param is always equal to the loop index, we don't need to store it. + IRInst* addUse = nullptr; + for (auto use = param->firstUse; use && !addUse; use = use->nextUse) + { + auto user = use->getUser(); + if (user->getOp() != kIROp_Add) + continue; + auto intLit = as<IRIntLit>(use->getUser()->getOperand(1)); + if (!intLit) + continue; + if (intLit->getValue() != 1) + continue; + + // The add inst's parent block is behind a `ifelse(loopCondition)`. + auto addInstBlock = as<IRBlock>(user->getParent()); + if (!addInstBlock) + continue; + auto predecessors = addInstBlock->getPredecessors(); + if (predecessors.getCount() != 1) + continue; + auto parentIfElse = as<IRIfElse>(predecessors.b->getUser()); + if (!parentIfElse) + continue; + auto parentCondition = parentIfElse->getCondition(); + + auto branch = as<IRUnconditionalBranch>(addInstBlock->getTerminator()); + if (!branch) + continue; + + // The add inst should be used as a branchArg. + UInt argIndex = 0; + for (UInt i = 0; i < branch->getArgCount(); i++) + { + if (branch->getArg(i) == user) + { + addUse = user; + argIndex = i; + break; + } + } + if (!addUse) + continue; + auto branchTarget1 = branch->getTargetBlock(); + auto branchParam = branchTarget1->getFirstParam(); + for (UInt i = 0; i < argIndex; i++) + if (branchParam) + branchParam = branchParam->getNextParam(); + if (!branchParam) + continue; + + // The branchParam is used as argument to branch back to loop header. + auto branch2 = as<IRUnconditionalBranch>(branchTarget1->getTerminator()); + if (!branch2) + continue; + if (branch2->getTargetBlock() != targetBlock) + continue; + argIndex = 0; + for (UInt i = 0; i < branch2->getArgCount(); i++) + { + if (branch2->getArg(i) == branchParam) + { + argIndex = i; + break; + } + } + if (argIndex != (UInt)paramIndex) + continue; + + // parentCondition is also used as the new condition in the back jump. + if (conditionParamIndex < 0 || (UInt)conditionParamIndex >= branch2->getArgCount() || + branch2->getArg((UInt)conditionParamIndex) != parentCondition) + continue; + + // The use of the add inst matches all of our conditions as an induction value + // that is equivalent to loop counter. + LoopInductionValueInfo info; + info.kind = LoopInductionValueInfo::Kind::EqualsToCounter; + info.loopInst = loopInst; + info.counterOffset = loopInst->getArg(paramIndex); + inductionValueInsts[param] = info; + break; + } + } + } + } +} + void applyToInst( IRBuilder* builder, CheckpointSetInfo* checkpointInfo, HoistedPrimalsInfo* hoistInfo, IROutOfOrderCloneContext* cloneCtx, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo, IRInst* inst) { // Early-out.. @@ -483,6 +688,35 @@ void applyToInst( { return; } + // If this is loop condition, it is always true in reverse blocks. + LoopInductionValueInfo inductionValueInfo; + if (checkpointInfo->loopInductionInfo.tryGetValue(inst, inductionValueInfo)) + { + IRInst* replacement = nullptr; + if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::AlwaysTrue) + { + replacement = builder->getBoolValue(true); + } + else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::EqualsToCounter) + { + auto indexInfo = blockIndexInfo.tryGetValue(inductionValueInfo.loopInst->getTargetBlock()); + SLANG_ASSERT(indexInfo); + SLANG_ASSERT(indexInfo->getCount() != 0); + replacement = indexInfo->getFirst().diffCountParam; + if (inductionValueInfo.counterOffset) + { + setInsertAfterOrdinaryInst(builder, replacement); + replacement = builder->emitAdd( + replacement->getDataType(), + replacement, + inductionValueInfo.counterOffset); + } + } + SLANG_ASSERT(replacement); + cloneCtx->cloneEnv.mapOldValToNew[inst] = replacement; + cloneCtx->registerClonedInst(builder, inst, replacement); + return; + } } auto recomputeInst = cloneCtx->cloneInstOutOfOrder(builder, inst); @@ -524,7 +758,8 @@ void applyCheckpointSet( HoistedPrimalsInfo* hoistInfo, HashSet<IRUse*>& pendingUses, Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, - IROutOfOrderCloneContext* cloneCtx) + IROutOfOrderCloneContext* cloneCtx, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo) { for (auto use : pendingUses) cloneCtx->pendingUses.add(use); @@ -554,16 +789,22 @@ void applyCheckpointSet( builder.setInsertBefore(recomputeInsertBeforeInst); bool isRecomputed = checkpointInfo->recomputeSet.contains(param); bool isInverted = checkpointInfo->invertSet.contains(param); - + bool loopInductionInfo = checkpointInfo->loopInductionInfo.tryGetValue(param); if (!isRecomputed && !isInverted) continue; - SLANG_RELEASE_ASSERT( - recomputeBlock != block && - "recomputed param should belong to block that has recompute block."); + if (!loopInductionInfo) + { + SLANG_RELEASE_ASSERT( + recomputeBlock != block && + "recomputed param should belong to block that has recompute block."); + } // Apply checkpoint rule to the parameter itself. - applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param); + applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, param); + + if (loopInductionInfo) + continue; // Copy primal branch-arg for predecessor blocks. HashSet<IRBlock*> predecessorSet; @@ -620,7 +861,7 @@ void applyCheckpointSet( builder.setInsertBefore(getParamPreludeBlock(func)->getTerminator()); } } - applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child); + applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, child); } } @@ -1267,7 +1508,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) // RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule()); chkPolicy->preparePolicy(func); - auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx); + auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx, indexedBlockInfo); // Legalize the primal inst accesses by introducing local variables / arrays and emitting // necessary load/store logic. @@ -1306,20 +1547,37 @@ static CheckpointPreference getCheckpointPreference(IRInst* callee) return CheckpointPreference::None; } -static bool isGlobalAddress(IRInst* inst) +static bool isGlobalMutableAddress(IRInst* inst) { auto root = getRootAddr(inst); if (root) { if (as<IRParameterGroupType>(root->getDataType())) { - return true; + return false; } return as<IRModuleInst>(root->getParent()) != nullptr; } return false; } +static bool isInstInPrimalOrTransposedParameterBlocks(IRInst* inst) +{ + auto func = getParentFunc(inst); + if (!func) + return false; + auto firstBlock = func->getFirstBlock(); + if (inst->getParent() == firstBlock) + return true; + auto branch = as<IRUnconditionalBranch>(firstBlock->getTerminator()); + if (!branch) + return false; + auto secondBlock = branch->getTargetBlock(); + if (inst->getParent() == secondBlock) + return true; + return false; +} + static bool shouldStoreInst(IRInst* inst) { if (!inst->getDataType()) @@ -1406,10 +1664,16 @@ static bool shouldStoreInst(IRInst* inst) return false; case kIROp_Load: - // Never store a load of a global parameter/variable. - if (isGlobalAddress(as<IRLoad>(inst)->getPtr())) - return false; - break; + // In general, don't store loads, because: + // - Loads to constant data can just be reloaded. + // - Loads to local variables can only exist for the temp variables used for calls, + // those variables are written only once so we can always load them anytime. + // - Loads to global mutable variables are now allowed, but we will capture that + // case in canRecompute(). + // - The only exception is the load of an inout param, in which case we do need + // to store it because the param may be modified by the func at exit. Similarly, + // this will be handled in canRecompute(). + return false; case kIROp_Call: // If the callee prefers recompute policy, don't store. @@ -1462,47 +1726,38 @@ static bool shouldStoreVar(IRVar* var) return false; } -bool canRecompute(UseOrPseudoUse use) +bool DefaultCheckpointPolicy::canRecompute(UseOrPseudoUse use) { if (auto load = as<IRLoad>(use.usedVal)) { - // Generally, we cannot recompute a load(ptr), since ptr may be modified - // afterwards. - // - // The exceptions are a load of an inout param or global param, since the - // propagation function never actually writes to the primal part of the - // inout param, and we can always just read the original param. - auto ptr = load->getPtr(); - if (ptr->getOp() == kIROp_Param) - { - if (auto block = as<IRBlock>(ptr->getParent())) - { - return (block == block->getParent()->getFirstBlock()); - } - } - else if (ptr->getOp() == kIROp_GlobalParam) - { - return true; - } - else if (as<IRParameterGroupType>(ptr->getDataType())) + + // We can't recompute a `load` is if it is a load from a global mutable + // variable. + if (isGlobalMutableAddress(ptr)) + return false; + + // We can't recompute a 'load' from a mutable function parameter. + if (as<IRParam>(ptr) || as<IRVar>(ptr)) { - return true; + if (isInstInPrimalOrTransposedParameterBlocks(ptr)) + return false; } - return false; } - auto param = as<IRParam>(use.usedVal); - if (!param) - return true; - - // We can recompute a phi param if it is not in a loop start block. - auto parentBlock = as<IRBlock>(param->getParent()); - for (auto pred : parentBlock->getPredecessors()) + else if (auto param = as<IRParam>(use.usedVal)) { - if (auto loop = as<IRLoop>(pred->getTerminator())) + if (inductionValueInsts.containsKey(param)) + return true; + + // We can recompute a phi param if it is not in a loop start block. + auto parentBlock = as<IRBlock>(param->getParent()); + for (auto pred : parentBlock->getPredecessors()) { - if (loop->getTargetBlock() == parentBlock) - return false; + if (auto loop = as<IRLoop>(pred->getTerminator())) + { + if (loop->getTargetBlock() == parentBlock) + return false; + } } } return true; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index e9fc0d4a5..c9377d56b 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -19,11 +19,18 @@ namespace Slang UInt operandCount = clonedInst->getOperandCount(); for (UInt ii = 0; ii < operandCount; ++ii) { - auto oldOperand = inst->getOperand(ii); auto newOperand = clonedInst->getOperand(ii); - - if (oldOperand == newOperand) - pendingUses.add(&clonedInst->getOperands()[ii]); + // If operand is in a differential or recompute block, it means it has already + // been cloned, so we don't add it to pending uses. + if (auto operandParent = as<IRBlock>(newOperand->getParent())) + { + if (isDifferentialOrRecomputeBlock(operandParent)) + { + continue; + } + } + // Otherwise, add it to pending uses. + pendingUses.add(&clonedInst->getOperands()[ii]); } for (auto use = inst->firstUse; use;) @@ -221,6 +228,18 @@ namespace Slang return primalCountParam == other.primalCountParam; } }; + + struct LoopInductionValueInfo + { + enum Kind + { + AlwaysTrue, + EqualsToCounter, + }; + Kind kind; + IRLoop* loopInst = nullptr; + IRInst* counterOffset = nullptr; + }; // Information on which insts are to be stored, recomputed // and inverted within a single function. @@ -232,7 +251,7 @@ namespace Slang HashSet<IRInst*> storeSet; HashSet<IRInst*> recomputeSet; HashSet<IRInst*> invertSet; - + Dictionary<IRInst*, LoopInductionValueInfo> loopInductionInfo; Dictionary<IRInst*, InversionInfo> invInfoMap; }; @@ -289,7 +308,8 @@ namespace Slang RefPtr<HoistedPrimalsInfo> processFunc( IRGlobalValueWithCode* func, Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock, - IROutOfOrderCloneContext* cloneCtx); + IROutOfOrderCloneContext* cloneCtx, + Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo); // Do pre-processing on the function (mainly for // 'global' checkpointing methods that consider the entire @@ -302,6 +322,8 @@ namespace Slang protected: IRModule* module; + Dictionary<IRInst*, LoopInductionValueInfo> inductionValueInsts; + void collectInductionValues(IRGlobalValueWithCode* func); }; class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase @@ -314,6 +336,10 @@ namespace Slang virtual void preparePolicy(IRGlobalValueWithCode* func); virtual HoistResult classify(UseOrPseudoUse use); + + private: + bool canRecompute(UseOrPseudoUse use); + }; RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f6a977994..6d56736ad 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1909,7 +1909,7 @@ IRUse* findUniqueStoredVal(IRVar* var) // the final value to it, this method will return the call inst for this case. IRUse* findLatestUniqueWriteUse(IRVar* var) { - IRUse* storeUse = nullptr; + IRUse* callUse = nullptr; for (auto use = var->firstUse; use; use = use->nextUse) { if (const auto callInst = as<IRCall>(use->getUser())) @@ -1917,10 +1917,14 @@ IRUse* findLatestUniqueWriteUse(IRVar* var) // Ignore uses from differential blocks. if (callInst->getParent()->findDecoration<IRDifferentialInstDecoration>()) continue; - SLANG_RELEASE_ASSERT(!storeUse); - storeUse = use; + SLANG_RELEASE_ASSERT(!callUse); + callUse = use; } } + + if (callUse) + return callUse; + // If no unique call found, try to look for a store. return findUniqueStoredVal(var); } |
