diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-23 06:59:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-23 06:59:25 -0800 |
| commit | 46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch) | |
| tree | c89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-address-analysis.cpp | |
| parent | 263ca18ea516cfce43fda703c0a411aaf1938e42 (diff) | |
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-address-analysis.cpp')
| -rw-r--r-- | source/slang/slang-ir-address-analysis.cpp | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp new file mode 100644 index 000000000..aba59e1de --- /dev/null +++ b/source/slang/slang-ir-address-analysis.cpp @@ -0,0 +1,173 @@ +#include "slang-ir-address-analysis.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ + void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst) + { + if (!as<IRBlock>(inst->getParent())) + return; + if (domTree->isUnreachable(as<IRBlock>(inst->getParent()))) + return; + + List<IRBlock*> blocks; + HashSet<IRInst*> operandInsts; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + operandInsts.Add(inst->getOperand(i)); + auto parentBlock = as<IRBlock>(inst->getOperand(i)->getParent()); + if (parentBlock) + { + if (!domTree->isUnreachable(parentBlock)) + blocks.add(parentBlock); + } + } + { + operandInsts.Add(inst->getFullType()); + auto parentBlock = as<IRBlock>(inst->getFullType()->getParent()); + if (parentBlock) + { + if (!domTree->isUnreachable(parentBlock)) + blocks.add(parentBlock); + } + } + // Find earliest block that is dominated by all operand blocks. + IRBlock* earliestBlock = as<IRBlock>(inst->getParent()); + for (auto block : func->getBlocks()) + { + bool dominated = true; + for (auto opBlock : blocks) + { + if (!domTree->dominates(opBlock, block)) + { + dominated = false; + break; + } + } + if (dominated) + { + earliestBlock = block; + break; + } + } + + if (!earliestBlock) + return; + + IRInst* latestOperand = nullptr; + for (auto childInst : earliestBlock->getChildren()) + { + if (operandInsts.Contains(childInst)) + { + latestOperand = childInst; + } + } + + if (!latestOperand || as<IRParam>(latestOperand)) + inst->insertBefore(earliestBlock->getFirstOrdinaryInst()); + else + inst->insertAfter(latestOperand); + } + + AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func) + { + DeduplicateContext deduplicateContext; + + AddressAccessInfo info; + + // Deduplicate and move known address insts. + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstChild(); inst;) + { + auto next = inst->getNextInst(); + switch (inst->getOp()) + { + case kIROp_Var: + { + RefPtr<AddressInfo> addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + addrInfo->isConstant = true; + addrInfo->parentAddress = nullptr; + info.addressInfos[inst] = addrInfo; + } + break; + case kIROp_Param: + if (as<IRPtrTypeBase>(inst->getFullType())) + { + RefPtr<AddressInfo> addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + addrInfo->isConstant = (block == func->getFirstBlock() ? true : false); + addrInfo->parentAddress = nullptr; + info.addressInfos[inst] = addrInfo; + } + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + moveInstToEarliestPoint(dom, func, inst->getFullType()); + moveInstToEarliestPoint(dom, func, inst); + auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* inst) + { + if (!inst->getParent()) + return false; + if (inst->getParent()->getParent() != func) + return false; + switch (inst->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + return true; + default: + return false; + } + }); + + if (deduplicated != inst) + { + SLANG_RELEASE_ASSERT(dom->dominates( + as<IRBlock>(deduplicated->getParent()), + as<IRBlock>(inst->getParent()))); + + inst->replaceUsesWith(deduplicated); + inst->removeAndDeallocate(); + } + else + { + RefPtr<AddressInfo> addrInfo = new AddressInfo(); + addrInfo->addrInst = inst; + if (inst->getOp() == kIROp_FieldAddress) + { + addrInfo->isConstant = true; + } + else + { + addrInfo->isConstant = + as<IRConstant>(inst->getOperand(1)) ? true : false; + } + info.addressInfos[inst] = addrInfo; + } + } + break; + } + inst = next; + } + } + + // Construct address info tree. + for (auto& addr : info.addressInfos) + { + RefPtr<AddressInfo> parentInfo; + if (addr.Value->addrInst->getOperandCount() > 1 && + info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo)) + { + addr.Value->parentAddress = parentInfo; + parentInfo->children.add(addr.Value); + if (!parentInfo->isConstant) + addr.Value->isConstant = false; + } + } + return info; + } +} |
