summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-address-analysis.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-23 06:59:25 -0800
committerGitHub <noreply@github.com>2023-01-23 06:59:25 -0800
commit46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch)
treec89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-address-analysis.cpp
parent263ca18ea516cfce43fda703c0a411aaf1938e42 (diff)
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-address-analysis.cpp')
-rw-r--r--source/slang/slang-ir-address-analysis.cpp173
1 files changed, 173 insertions, 0 deletions
diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp
new file mode 100644
index 000000000..aba59e1de
--- /dev/null
+++ b/source/slang/slang-ir-address-analysis.cpp
@@ -0,0 +1,173 @@
+#include "slang-ir-address-analysis.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+ void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst)
+ {
+ if (!as<IRBlock>(inst->getParent()))
+ return;
+ if (domTree->isUnreachable(as<IRBlock>(inst->getParent())))
+ return;
+
+ List<IRBlock*> blocks;
+ HashSet<IRInst*> operandInsts;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ operandInsts.Add(inst->getOperand(i));
+ auto parentBlock = as<IRBlock>(inst->getOperand(i)->getParent());
+ if (parentBlock)
+ {
+ if (!domTree->isUnreachable(parentBlock))
+ blocks.add(parentBlock);
+ }
+ }
+ {
+ operandInsts.Add(inst->getFullType());
+ auto parentBlock = as<IRBlock>(inst->getFullType()->getParent());
+ if (parentBlock)
+ {
+ if (!domTree->isUnreachable(parentBlock))
+ blocks.add(parentBlock);
+ }
+ }
+ // Find earliest block that is dominated by all operand blocks.
+ IRBlock* earliestBlock = as<IRBlock>(inst->getParent());
+ for (auto block : func->getBlocks())
+ {
+ bool dominated = true;
+ for (auto opBlock : blocks)
+ {
+ if (!domTree->dominates(opBlock, block))
+ {
+ dominated = false;
+ break;
+ }
+ }
+ if (dominated)
+ {
+ earliestBlock = block;
+ break;
+ }
+ }
+
+ if (!earliestBlock)
+ return;
+
+ IRInst* latestOperand = nullptr;
+ for (auto childInst : earliestBlock->getChildren())
+ {
+ if (operandInsts.Contains(childInst))
+ {
+ latestOperand = childInst;
+ }
+ }
+
+ if (!latestOperand || as<IRParam>(latestOperand))
+ inst->insertBefore(earliestBlock->getFirstOrdinaryInst());
+ else
+ inst->insertAfter(latestOperand);
+ }
+
+ AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func)
+ {
+ DeduplicateContext deduplicateContext;
+
+ AddressAccessInfo info;
+
+ // Deduplicate and move known address insts.
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstChild(); inst;)
+ {
+ auto next = inst->getNextInst();
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ addrInfo->isConstant = true;
+ addrInfo->parentAddress = nullptr;
+ info.addressInfos[inst] = addrInfo;
+ }
+ break;
+ case kIROp_Param:
+ if (as<IRPtrTypeBase>(inst->getFullType()))
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ addrInfo->isConstant = (block == func->getFirstBlock() ? true : false);
+ addrInfo->parentAddress = nullptr;
+ info.addressInfos[inst] = addrInfo;
+ }
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ moveInstToEarliestPoint(dom, func, inst->getFullType());
+ moveInstToEarliestPoint(dom, func, inst);
+ auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* inst)
+ {
+ if (!inst->getParent())
+ return false;
+ if (inst->getParent()->getParent() != func)
+ return false;
+ switch (inst->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ return true;
+ default:
+ return false;
+ }
+ });
+
+ if (deduplicated != inst)
+ {
+ SLANG_RELEASE_ASSERT(dom->dominates(
+ as<IRBlock>(deduplicated->getParent()),
+ as<IRBlock>(inst->getParent())));
+
+ inst->replaceUsesWith(deduplicated);
+ inst->removeAndDeallocate();
+ }
+ else
+ {
+ RefPtr<AddressInfo> addrInfo = new AddressInfo();
+ addrInfo->addrInst = inst;
+ if (inst->getOp() == kIROp_FieldAddress)
+ {
+ addrInfo->isConstant = true;
+ }
+ else
+ {
+ addrInfo->isConstant =
+ as<IRConstant>(inst->getOperand(1)) ? true : false;
+ }
+ info.addressInfos[inst] = addrInfo;
+ }
+ }
+ break;
+ }
+ inst = next;
+ }
+ }
+
+ // Construct address info tree.
+ for (auto& addr : info.addressInfos)
+ {
+ RefPtr<AddressInfo> parentInfo;
+ if (addr.Value->addrInst->getOperandCount() > 1 &&
+ info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo))
+ {
+ addr.Value->parentAddress = parentInfo;
+ parentInfo->children.add(addr.Value);
+ if (!parentInfo->isConstant)
+ addr.Value->isConstant = false;
+ }
+ }
+ return info;
+ }
+}