diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-31 14:34:40 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-31 14:34:40 -0700 |
| commit | a7ed48b2e6da9bf952aa11ec0d57acf9688bbb0e (patch) | |
| tree | 920c3407b7401103a36f8e5d41e911c3ba934aaa /source | |
| parent | 02bb741a8d1b4ed31a65c46b7e43d153b42a7b73 (diff) | |
Fix def-use legalization in CFG normalization. (#2909)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 177 |
1 files changed, 97 insertions, 80 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 30c8a934e..0a22e112e 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -5,6 +5,7 @@ #include "slang-ir-validate.h" #include "slang-ir-util.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -638,6 +639,95 @@ struct CFGNormalizationPass } }; +static void legalizeDefUse(IRGlobalValueWithCode* func) +{ + auto dom = computeDominatorTree(func); + for (auto block : func->getBlocks()) + { + for (auto inst : block->getModifiableChildren()) + { + // Inspect all uses of `inst` and find the common dominator of all use sites. + IRBlock* commonDominator = block; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto userBlock = as<IRBlock>(use->getUser()->getParent()); + if (!userBlock) + continue; + while (commonDominator && !dom->dominates(commonDominator, userBlock)) + { + commonDominator = dom->getImmediateDominator(commonDominator); + } + } + SLANG_ASSERT(commonDominator); + + if (commonDominator == block) + continue; + + // If the common dominator is not `block`, it means we have detected + // uses that is no longer dominated by the current definition, and need + // to be fixed. + + // Normally, we can simply move the definition to the common dominator. + // An exception is when the common dominator is the target block of a + // loop. Note that after normalization, loops are in the form of: + // ``` + // loop { if (condition) block; else break; } + // ``` + // If we find ourselves needing to make the inst available right before + // the `if`, it means we are seeing uses of the inst outside the loop. + // In this case, we should insert a var/move the inst before the loop + // instead of before the `if`. This situation can occur in the IR if + // the original code is lowered from a `do-while` loop. + for (auto use = commonDominator->firstUse; use; use = use->nextUse) + { + if (auto loopUser = as<IRLoop>(use->getUser())) + { + if (loopUser->getTargetBlock() == commonDominator) + { + commonDominator = as<IRBlock>(loopUser->getParent()); + break; + } + } + } + // Now we can legalize uses based on the type of `inst`. + if (auto var = as<IRVar>(inst)) + { + // If inst is an var, this is easy, we just move it to the + // common dominator. + var->insertBefore(commonDominator->getTerminator()); + } + else + { + // For all other insts, we need to create a local var for it, + // and replace all uses with a load from the local var. + IRBuilder builder(func); + builder.setInsertBefore(commonDominator->getTerminator()); + IRVar* tempVar = builder.emitVar(inst->getFullType()); + auto defaultVal = builder.emitDefaultConstruct(inst->getFullType()); + builder.emitStore(tempVar, defaultVal); + + builder.setInsertAfter(inst); + builder.emitStore(tempVar, inst); + + traverseUses(inst, [&](IRUse* use) + { + auto userBlock = as<IRBlock>(use->getUser()->getParent()); + if (!userBlock) return; + // Only fix the use of the current definition of `inst` does not + // dominate it. + if (!dom->dominates(block, userBlock)) + { + // Replace the use with a load of tempVar. + builder.setInsertBefore(use->getUser()); + auto load = builder.emitLoad(tempVar); + builder.replaceOperand(use, load); + } + }); + } + } + } +} + void normalizeCFG( IRModule* module, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options) { @@ -674,90 +764,17 @@ void normalizeCFG( } } - // If we created a new condition block for a loop, the local vars defined in - // the original loop body will no longer dominate the exit block of the - // loop. If there are any uses of these variables outside the loop, they - // will become invalid. Therefore we need to hoist the local variables to - // the loop header block. - HashSet<IRBlock*> workListSet; - for (auto block : func->getBlocks()) - { - if (auto loop = as<IRLoop>(block->getTerminator())) - { - auto condBlock = loop->getTargetBlock(); - auto ifElse = as<IRIfElse>(condBlock->getTerminator()); - auto bodyBlock = ifElse->getTrueBlock(); - - // Collect loop body blocks. - workList.clear(); - workListSet.clear(); - workList.add(bodyBlock); - workListSet.add(bodyBlock); - for (Index i = 0; i < workList.getCount(); i++) - { - auto b = workList[i]; - for (auto succ : b->getSuccessors()) - { - if (succ != loop->getTargetBlock() && succ != loop->getBreakBlock()) - { - if (workListSet.add(succ)) - workList.add(succ); - } - } - } - auto insertionPoint = loop; - IRBuilder builder(func); - for (auto b : workList) - { - for (auto inst : b->getModifiableChildren()) - { - // If inst has uses outside the loop body, we need to hoist it. - IRVar* tempVar = nullptr; - if (auto var = as<IRVar>(inst)) - { - for (auto use = inst->firstUse; use; use = use->nextUse) - { - // If inst is an var, this is easy, we just move it to the - // loop header. - auto userBlock = as<IRBlock>(use->getUser()->getParent()); - if (userBlock && !workListSet.contains(userBlock)) - { - var->insertBefore(insertionPoint); - break; - } - } - } - else - { - traverseUses(inst, [&](IRUse* use) - { - auto userBlock = as<IRBlock>(use->getUser()->getParent()); - if (userBlock && !workListSet.contains(userBlock)) - { - // For all other insts, we need to create a local var for it. - if (!tempVar) - { - builder.setInsertBefore(insertionPoint); - tempVar = builder.emitVar(inst->getFullType()); - builder.setInsertAfter(inst); - builder.emitStore(tempVar, inst); - } - // Replace the use with a load of tempVar. - builder.setInsertBefore(use->getUser()); - auto load = builder.emitLoad(tempVar); - builder.replaceOperand(use, load); - } - }); - } - } - } - } - } + // After CFG normalization, there may be invalid uses of var/ssa values where the def + // no longer dominate the use. We fix these up by going through the IR and create temp + // vars for such uses. + sortBlocksInFunc(func); + legalizeDefUse(func); + disableIRValidationAtInsert(); constructSSA(module, func); enableIRValidationAtInsert(); #if _DEBUG - validateIRInst(func); + validateIRInst(maybeFindOuterGeneric(func)); #endif } |
