summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp34
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp6
-rw-r--r--source/slang/slang-ir-util.cpp8
-rw-r--r--source/slang/slang-ir-util.h2
4 files changed, 43 insertions, 7 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 906465384..bef96f309 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1426,6 +1426,20 @@ static CheckpointPreference getCheckpointPreference(IRInst* callee)
return CheckpointPreference::None;
}
+static bool isGlobalAddress(IRInst* inst)
+{
+ auto root = getRootAddr(inst);
+ if (root)
+ {
+ if (as<IRParameterGroupType>(root->getDataType()))
+ {
+ return true;
+ }
+ return as<IRModuleInst>(root->getParent()) != nullptr;
+ }
+ return false;
+}
+
static bool shouldStoreInst(IRInst* inst)
{
if (!inst->getDataType())
@@ -1511,6 +1525,12 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_GetTupleElement:
return false;
+ case kIROp_Load:
+ // Never store a load of a global parameter/variable.
+ if (isGlobalAddress(as<IRLoad>(inst)->getPtr()))
+ return false;
+ break;
+
case kIROp_Call:
// If the callee prefers recompute policy, don't store.
if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferRecompute)
@@ -1533,7 +1553,9 @@ bool canRecompute(IRUse* use)
if (auto load = as<IRLoad>(use->get()))
{
// Generally, we cannot recompute a load(ptr), since ptr may be modified
- // afterwards. The exceptions are a load of an inout param, since the
+ // afterwards.
+ //
+ // The exceptions are a load of an inout param or global param, since the
// propagation function never actually writes to the primal part of the
// inout param, and we can always just read the original param.
@@ -1545,6 +1567,14 @@ bool canRecompute(IRUse* use)
return (block == block->getParent()->getFirstBlock());
}
}
+ else if (ptr->getOp() == kIROp_GlobalParam)
+ {
+ return true;
+ }
+ else if (as<IRParameterGroupType>(ptr->getDataType()))
+ {
+ return true;
+ }
return false;
}
auto param = as<IRParam>(use->get());
@@ -1579,7 +1609,9 @@ HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
else
{
if (shouldStoreInst(use->get()))
+ {
return HoistResult::store(use->get());
+ }
else
{
// We may not be able to recompute due to limitations of
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 60d829324..1b14856e6 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -408,7 +408,11 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
// in a primal block.
while (auto iuse = inst->firstUse)
{
- builder.setInsertBefore(iuse->getUser());
+ auto user = iuse->getUser();
+ if (as<IRDecoration>(user))
+ user = user->getParent();
+ if (!user) continue;
+ builder.setInsertBefore(user);
auto val = builder.emitFieldExtract(
inst->getFullType(),
intermediateVar,
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 05d3157a7..a69e13562 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -354,7 +354,7 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type)
}
}
-static IRInst* _getRootAddr(IRInst* addr)
+IRInst* getRootAddr(IRInst* addr)
{
for (;;)
{
@@ -379,8 +379,8 @@ bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IR
return true;
// Two variables can never alias.
- addr1 = _getRootAddr(addr1);
- addr2 = _getRootAddr(addr2);
+ addr1 = getRootAddr(addr1);
+ addr2 = getRootAddr(addr2);
// Global addresses can alias with anything.
if (!isChildInstOf(addr1, func))
@@ -436,7 +436,7 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I
// If addr is a global variable, calling a function may change its value.
// So we need to return true here to be conservative.
- if (!isChildInstOf(_getRootAddr(addr), func))
+ if (!isChildInstOf(getRootAddr(addr), func))
{
auto callee = call->getCallee();
if (callee &&
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 076ae8fd0..075788520 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -162,7 +162,7 @@ inline IRInst* unwrapAttributedType(IRInst* type)
void getTypeNameHint(StringBuilder& sb, IRInst* type);
void copyNameHintDecoration(IRInst* dest, IRInst* src);
-
+IRInst* getRootAddr(IRInst* addrInst);
bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IRInst* addr2);
String dumpIRToString(IRInst* root);