From d1cc6a8c1e5b378ea34dc4006045bcbd37e0dfd3 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 27 Apr 2023 14:30:36 -0700 Subject: Prevent storing loads of global parameters. (#2850) Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-primal-hoist.cpp | 34 ++++++++++++++++++++++++- source/slang/slang-ir-autodiff-unzip.cpp | 6 ++++- source/slang/slang-ir-util.cpp | 8 +++--- source/slang/slang-ir-util.h | 2 +- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 906465384..bef96f309 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1426,6 +1426,20 @@ static CheckpointPreference getCheckpointPreference(IRInst* callee) return CheckpointPreference::None; } +static bool isGlobalAddress(IRInst* inst) +{ + auto root = getRootAddr(inst); + if (root) + { + if (as(root->getDataType())) + { + return true; + } + return as(root->getParent()) != nullptr; + } + return false; +} + static bool shouldStoreInst(IRInst* inst) { if (!inst->getDataType()) @@ -1511,6 +1525,12 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_GetTupleElement: return false; + case kIROp_Load: + // Never store a load of a global parameter/variable. + if (isGlobalAddress(as(inst)->getPtr())) + return false; + break; + case kIROp_Call: // If the callee prefers recompute policy, don't store. if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferRecompute) @@ -1533,7 +1553,9 @@ bool canRecompute(IRUse* use) if (auto load = as(use->get())) { // Generally, we cannot recompute a load(ptr), since ptr may be modified - // afterwards. The exceptions are a load of an inout param, since the + // afterwards. + // + // The exceptions are a load of an inout param or global param, since the // propagation function never actually writes to the primal part of the // inout param, and we can always just read the original param. @@ -1545,6 +1567,14 @@ bool canRecompute(IRUse* use) return (block == block->getParent()->getFirstBlock()); } } + else if (ptr->getOp() == kIROp_GlobalParam) + { + return true; + } + else if (as(ptr->getDataType())) + { + return true; + } return false; } auto param = as(use->get()); @@ -1579,7 +1609,9 @@ HoistResult DefaultCheckpointPolicy::classify(IRUse* use) else { if (shouldStoreInst(use->get())) + { return HoistResult::store(use->get()); + } else { // We may not be able to recompute due to limitations of diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 60d829324..1b14856e6 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -408,7 +408,11 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( // in a primal block. while (auto iuse = inst->firstUse) { - builder.setInsertBefore(iuse->getUser()); + auto user = iuse->getUser(); + if (as(user)) + user = user->getParent(); + if (!user) continue; + builder.setInsertBefore(user); auto val = builder.emitFieldExtract( inst->getFullType(), intermediateVar, diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 05d3157a7..a69e13562 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -354,7 +354,7 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type) } } -static IRInst* _getRootAddr(IRInst* addr) +IRInst* getRootAddr(IRInst* addr) { for (;;) { @@ -379,8 +379,8 @@ bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IR return true; // Two variables can never alias. - addr1 = _getRootAddr(addr1); - addr2 = _getRootAddr(addr2); + addr1 = getRootAddr(addr1); + addr2 = getRootAddr(addr2); // Global addresses can alias with anything. if (!isChildInstOf(addr1, func)) @@ -436,7 +436,7 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I // If addr is a global variable, calling a function may change its value. // So we need to return true here to be conservative. - if (!isChildInstOf(_getRootAddr(addr), func)) + if (!isChildInstOf(getRootAddr(addr), func)) { auto callee = call->getCallee(); if (callee && diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 076ae8fd0..075788520 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -162,7 +162,7 @@ inline IRInst* unwrapAttributedType(IRInst* type) void getTypeNameHint(StringBuilder& sb, IRInst* type); void copyNameHintDecoration(IRInst* dest, IRInst* src); - +IRInst* getRootAddr(IRInst* addrInst); bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IRInst* addr2); String dumpIRToString(IRInst* root); -- cgit v1.2.3