From 46a4d98baa1d43b33717b4377aefeeaf46b9c2ff Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 23 Jan 2023 06:59:25 -0800 Subject: Full address insts elimination for backward autodiff. (#2604) Co-authored-by: Yong He --- source/slang/slang-ir-address-analysis.cpp | 173 +++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 source/slang/slang-ir-address-analysis.cpp (limited to 'source/slang/slang-ir-address-analysis.cpp') diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp new file mode 100644 index 000000000..aba59e1de --- /dev/null +++ b/source/slang/slang-ir-address-analysis.cpp @@ -0,0 +1,173 @@ +#include "slang-ir-address-analysis.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ + void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst) + { + if (!as(inst->getParent())) + return; + if (domTree->isUnreachable(as(inst->getParent()))) + return; + + List blocks; + HashSet operandInsts; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + operandInsts.Add(inst->getOperand(i)); + auto parentBlock = as(inst->getOperand(i)->getParent()); + if (parentBlock) + { + if (!domTree->isUnreachable(parentBlock)) + blocks.add(parentBlock); + } + } + { + operandInsts.Add(inst->getFullType()); + auto parentBlock = as(inst->getFullType()->getParent()); + if (parentBlock) + { + if (!domTree->isUnreachable(parentBlock)) + blocks.add(parentBlock); + } + } + // Find earliest block that is dominated by all operand blocks. + IRBlock* earliestBlock = as(inst->getParent()); + for (auto block : func->getBlocks()) + { + bool dominated = true; + for (auto opBlock : blocks) + { + if (!domTree->dominates(opBlock, block)) + { + dominated = false; + break; + } + } + if (dominated) + { + earliestBlock = block; + break; + } + } + + if (!earliestBlock) + return; + + IRInst* latestOperand = nullptr; + for (auto childInst : earliestBlock->getChildren()) + { + if (operandInsts.Contains(childInst)) + { + latestOperand = childInst; + } + } + + if (!latestOperand || as(latestOperand)) + inst->insertBefore(earliestBlock->getFirstOrdinaryInst()); + else + inst->insertAfter(latestOperand); + } + + AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func) + { + DeduplicateContext deduplicateContext; + + AddressAccessInfo info; + + // Deduplicate and move known address insts. + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstChild(); inst;) + { + auto next = inst->getNextInst(); + switch (inst->getOp()) + { + case kIROp_Var: + { + RefPtr addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + addrInfo->isConstant = true; + addrInfo->parentAddress = nullptr; + info.addressInfos[inst] = addrInfo; + } + break; + case kIROp_Param: + if (as(inst->getFullType())) + { + RefPtr addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + addrInfo->isConstant = (block == func->getFirstBlock() ? true : false); + addrInfo->parentAddress = nullptr; + info.addressInfos[inst] = addrInfo; + } + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + moveInstToEarliestPoint(dom, func, inst->getFullType()); + moveInstToEarliestPoint(dom, func, inst); + auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* inst) + { + if (!inst->getParent()) + return false; + if (inst->getParent()->getParent() != func) + return false; + switch (inst->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + return true; + default: + return false; + } + }); + + if (deduplicated != inst) + { + SLANG_RELEASE_ASSERT(dom->dominates( + as(deduplicated->getParent()), + as(inst->getParent()))); + + inst->replaceUsesWith(deduplicated); + inst->removeAndDeallocate(); + } + else + { + RefPtr addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + if (inst->getOp() == kIROp_FieldAddress) + { + addrInfo->isConstant = true; + } + else + { + addrInfo->isConstant = + as(inst->getOperand(1)) ? true : false; + } + info.addressInfos[inst] = addrInfo; + } + } + break; + } + inst = next; + } + } + + // Construct address info tree. + for (auto& addr : info.addressInfos) + { + RefPtr parentInfo; + if (addr.Value->addrInst->getOperandCount() > 1 && + info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo)) + { + addr.Value->parentAddress = parentInfo; + parentInfo->children.add(addr.Value); + if (!parentInfo->isConstant) + addr.Value->isConstant = false; + } + } + return info; + } +} -- cgit v1.2.3