summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-05-06 03:03:25 -0400
committerGitHub <noreply@github.com>2023-05-06 00:03:25 -0700
commit271dc1b98d3887b6297c5407dc67692716687f4d (patch)
treea714a41f6a490000545e82cadd20561a020b0a1e /source
parent0602eaaba32bdbaf3f99ab8987e97419cba395aa (diff)
Don't store loop induction values + fix minor issue (#2872)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp381
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h38
-rw-r--r--source/slang/slang-ir-autodiff.cpp10
3 files changed, 357 insertions, 72 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 135c72556..ab23aeb40 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -11,7 +11,8 @@ void applyCheckpointSet(
HoistedPrimalsInfo* hoistInfo,
HashSet<IRUse*>& pendingUses,
Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
- IROutOfOrderCloneContext* cloneCtx);
+ IROutOfOrderCloneContext* cloneCtx,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo);
bool containsOperand(IRInst* inst, IRInst* operand)
{
@@ -260,8 +261,11 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
- IROutOfOrderCloneContext* cloneCtx)
+ IROutOfOrderCloneContext* cloneCtx,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo)
{
+ collectInductionValues(func);
+
RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo();
RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
@@ -362,6 +366,12 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
if (auto param = as<IRParam>(result.instToRecompute))
{
+ if (auto inductionInfo = inductionValueInsts.tryGetValue(param))
+ {
+ checkpointInfo->loopInductionInfo.addIfNotExists(param, *inductionInfo);
+ continue;
+ }
+
// Add in the branch-args of every predecessor block.
auto paramBlock = as<IRBlock>(param->getParent());
UIndex paramIndex = 0;
@@ -389,14 +399,19 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
{
if (auto var = as<IRVar>(result.instToRecompute))
{
- IRUse* storeUse = findLatestUniqueWriteUse(var);
- if (storeUse)
+ for (auto varUse = var->firstUse; varUse; varUse = varUse->nextUse)
{
- // When we have a var and a store/call insts that writes to the var,
- // we treat as if there is a pseudo-use of the store/call to compute
- // the var inst, i.e. the var depends on the store/call, despite
- // the IR's def-use chain doesn't reflect this.
- workList.add(UseOrPseudoUse(var, storeUse->getUser()));
+ switch (varUse->getUser()->getOp())
+ {
+ case kIROp_Store:
+ case kIROp_Call:
+ // When we have a var and a store/call insts that writes to the var,
+ // we treat as if there is a pseudo-use of the store/call to compute
+ // the var inst, i.e. the var depends on the store/call, despite
+ // the IR's def-use chain doesn't reflect this.
+ workList.add(UseOrPseudoUse(var, varUse->getUser()));
+ break;
+ }
}
}
else
@@ -429,13 +444,20 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
{
for (auto use = var->firstUse; use; use = use->nextUse)
{
- auto callUser = as<IRCall>(use->getUser());
- if (!callUser)
- continue;
- checkpointInfo->recomputeSet.add(callUser);
- checkpointInfo->storeSet.remove(callUser);
- if (instWorkListSet.add(callUser))
- instWorkList.add(callUser);
+ if (auto callUser = as<IRCall>(use->getUser()))
+ {
+ checkpointInfo->recomputeSet.add(callUser);
+ checkpointInfo->storeSet.remove(callUser);
+ if (instWorkListSet.add(callUser))
+ instWorkList.add(callUser);
+ }
+ else if (auto storeUser = as<IRStore>(use->getUser()))
+ {
+ checkpointInfo->recomputeSet.add(storeUser);
+ checkpointInfo->storeSet.remove(storeUser);
+ if (instWorkListSet.add(callUser))
+ instWorkList.add(callUser);
+ }
}
}
else if (auto call = as<IRCall>(inst))
@@ -454,15 +476,198 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
}
RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo();
- applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock, cloneCtx);
+ applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock, cloneCtx, blockIndexInfo);
return hoistInfo;
}
+void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* func)
+{
+ // Collect loop induction values.
+ // There are two special phi insts we want to handle differently in our
+ // checkpointing policy:
+ // 1. a bool execution flag inserted as the result of CFG normalization,
+ // that is always true as long as the loop is still active.
+ // 2. the original induction variable that can be replaced with the loop
+ // counter we inserted during createPrimalRecomputeBlocks().
+
+ for (auto block : func->getBlocks())
+ {
+ auto loopInst = as<IRLoop>(block->getTerminator());
+ if (!loopInst)
+ continue;
+ auto targetBlock = loopInst->getTargetBlock();
+ auto ifElse = as<IRIfElse>(targetBlock->getTerminator());
+ Int paramIndex = -1;
+ Int conditionParamIndex = -1;
+ // First, we are going to collect all the bool execution flags from loops.
+ // These are very easy to identify: they are a phi param defined in
+ // targetBlock, and used as the condition value in the condtion block.
+ for (auto param : targetBlock->getParams())
+ {
+ paramIndex++;
+ if (!param->getDataType())
+ continue;
+ if (param->getDataType()->getOp() == kIROp_BoolType)
+ {
+ if (ifElse->getCondition() == param)
+ {
+ // The bool param is used as the condition of the if-else inside the loop,
+ // this param will always be true during the loop, and we don't need to store it.
+ LoopInductionValueInfo info;
+ info.kind = LoopInductionValueInfo::Kind::AlwaysTrue;
+ inductionValueInsts[param] = info;
+ conditionParamIndex = paramIndex;
+ }
+ }
+ }
+ if (conditionParamIndex == -1)
+ continue;
+
+ // Next, we try to identify the original induction variables, if they exist.
+ // These are trickier, and we have to hard code the complex pattern that
+ // we can recognize.
+ // This pattern matching logic is ugly and fragile against changes to cfg
+ // normalization, but it is the easiest way to do it right now.
+ // Basically, we are looking for this pattern:
+ // loop(..., i=initVal)
+ // {
+ // targetBlock:
+ // ...
+ // param int i;
+ // param bool condition;
+ // ...
+ // branch condtionBlock;
+ // conditionBlock:
+ // if (condition)
+ // {
+ // }
+ // else
+ // {
+ // break;
+ // }
+ // // ...
+ // someBodyBlock:
+ // ...
+ // if (condition)
+ // {
+ // ...
+ // // Check condition 1: i is used by an `add`
+ // // Check condition 2: parent of (i+1) is a branch target of if(condition)
+ // // Check condition 3: branches to parentBlock with i1 = i + 1.
+ // goto parentBlock(i + 1);
+ // }
+ // else
+ // goto parentBlock(other);
+ // parentBlock:
+ // // Check condition 4: parentBlock branches to finalBlock.
+ // param int i1;
+ // goto finalBlock;
+ // finalBlock:
+ // // Check condition 5: finalBlock branches to targetBlock with new i = i1.
+ // goto loopHeader(i1);
+ // }
+ //
+ paramIndex = -1;
+ for (auto param : targetBlock->getParams())
+ {
+ paramIndex++;
+ if (!param->getDataType())
+ continue;
+ if (isScalarIntegerType(param->getDataType()))
+ {
+ // If the param is always equal to the loop index, we don't need to store it.
+ IRInst* addUse = nullptr;
+ for (auto use = param->firstUse; use && !addUse; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (user->getOp() != kIROp_Add)
+ continue;
+ auto intLit = as<IRIntLit>(use->getUser()->getOperand(1));
+ if (!intLit)
+ continue;
+ if (intLit->getValue() != 1)
+ continue;
+
+ // The add inst's parent block is behind a `ifelse(loopCondition)`.
+ auto addInstBlock = as<IRBlock>(user->getParent());
+ if (!addInstBlock)
+ continue;
+ auto predecessors = addInstBlock->getPredecessors();
+ if (predecessors.getCount() != 1)
+ continue;
+ auto parentIfElse = as<IRIfElse>(predecessors.b->getUser());
+ if (!parentIfElse)
+ continue;
+ auto parentCondition = parentIfElse->getCondition();
+
+ auto branch = as<IRUnconditionalBranch>(addInstBlock->getTerminator());
+ if (!branch)
+ continue;
+
+ // The add inst should be used as a branchArg.
+ UInt argIndex = 0;
+ for (UInt i = 0; i < branch->getArgCount(); i++)
+ {
+ if (branch->getArg(i) == user)
+ {
+ addUse = user;
+ argIndex = i;
+ break;
+ }
+ }
+ if (!addUse)
+ continue;
+ auto branchTarget1 = branch->getTargetBlock();
+ auto branchParam = branchTarget1->getFirstParam();
+ for (UInt i = 0; i < argIndex; i++)
+ if (branchParam)
+ branchParam = branchParam->getNextParam();
+ if (!branchParam)
+ continue;
+
+ // The branchParam is used as argument to branch back to loop header.
+ auto branch2 = as<IRUnconditionalBranch>(branchTarget1->getTerminator());
+ if (!branch2)
+ continue;
+ if (branch2->getTargetBlock() != targetBlock)
+ continue;
+ argIndex = 0;
+ for (UInt i = 0; i < branch2->getArgCount(); i++)
+ {
+ if (branch2->getArg(i) == branchParam)
+ {
+ argIndex = i;
+ break;
+ }
+ }
+ if (argIndex != (UInt)paramIndex)
+ continue;
+
+ // parentCondition is also used as the new condition in the back jump.
+ if (conditionParamIndex < 0 || (UInt)conditionParamIndex >= branch2->getArgCount() ||
+ branch2->getArg((UInt)conditionParamIndex) != parentCondition)
+ continue;
+
+ // The use of the add inst matches all of our conditions as an induction value
+ // that is equivalent to loop counter.
+ LoopInductionValueInfo info;
+ info.kind = LoopInductionValueInfo::Kind::EqualsToCounter;
+ info.loopInst = loopInst;
+ info.counterOffset = loopInst->getArg(paramIndex);
+ inductionValueInsts[param] = info;
+ break;
+ }
+ }
+ }
+ }
+}
+
void applyToInst(
IRBuilder* builder,
CheckpointSetInfo* checkpointInfo,
HoistedPrimalsInfo* hoistInfo,
IROutOfOrderCloneContext* cloneCtx,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo,
IRInst* inst)
{
// Early-out..
@@ -483,6 +688,35 @@ void applyToInst(
{
return;
}
+ // If this is loop condition, it is always true in reverse blocks.
+ LoopInductionValueInfo inductionValueInfo;
+ if (checkpointInfo->loopInductionInfo.tryGetValue(inst, inductionValueInfo))
+ {
+ IRInst* replacement = nullptr;
+ if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::AlwaysTrue)
+ {
+ replacement = builder->getBoolValue(true);
+ }
+ else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::EqualsToCounter)
+ {
+ auto indexInfo = blockIndexInfo.tryGetValue(inductionValueInfo.loopInst->getTargetBlock());
+ SLANG_ASSERT(indexInfo);
+ SLANG_ASSERT(indexInfo->getCount() != 0);
+ replacement = indexInfo->getFirst().diffCountParam;
+ if (inductionValueInfo.counterOffset)
+ {
+ setInsertAfterOrdinaryInst(builder, replacement);
+ replacement = builder->emitAdd(
+ replacement->getDataType(),
+ replacement,
+ inductionValueInfo.counterOffset);
+ }
+ }
+ SLANG_ASSERT(replacement);
+ cloneCtx->cloneEnv.mapOldValToNew[inst] = replacement;
+ cloneCtx->registerClonedInst(builder, inst, replacement);
+ return;
+ }
}
auto recomputeInst = cloneCtx->cloneInstOutOfOrder(builder, inst);
@@ -524,7 +758,8 @@ void applyCheckpointSet(
HoistedPrimalsInfo* hoistInfo,
HashSet<IRUse*>& pendingUses,
Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
- IROutOfOrderCloneContext* cloneCtx)
+ IROutOfOrderCloneContext* cloneCtx,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo)
{
for (auto use : pendingUses)
cloneCtx->pendingUses.add(use);
@@ -554,16 +789,22 @@ void applyCheckpointSet(
builder.setInsertBefore(recomputeInsertBeforeInst);
bool isRecomputed = checkpointInfo->recomputeSet.contains(param);
bool isInverted = checkpointInfo->invertSet.contains(param);
-
+ bool loopInductionInfo = checkpointInfo->loopInductionInfo.tryGetValue(param);
if (!isRecomputed && !isInverted)
continue;
- SLANG_RELEASE_ASSERT(
- recomputeBlock != block &&
- "recomputed param should belong to block that has recompute block.");
+ if (!loopInductionInfo)
+ {
+ SLANG_RELEASE_ASSERT(
+ recomputeBlock != block &&
+ "recomputed param should belong to block that has recompute block.");
+ }
// Apply checkpoint rule to the parameter itself.
- applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param);
+ applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, param);
+
+ if (loopInductionInfo)
+ continue;
// Copy primal branch-arg for predecessor blocks.
HashSet<IRBlock*> predecessorSet;
@@ -620,7 +861,7 @@ void applyCheckpointSet(
builder.setInsertBefore(getParamPreludeBlock(func)->getTerminator());
}
}
- applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
+ applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, child);
}
}
@@ -1267,7 +1508,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
//
RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule());
chkPolicy->preparePolicy(func);
- auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx);
+ auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx, indexedBlockInfo);
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
// necessary load/store logic.
@@ -1306,20 +1547,37 @@ static CheckpointPreference getCheckpointPreference(IRInst* callee)
return CheckpointPreference::None;
}
-static bool isGlobalAddress(IRInst* inst)
+static bool isGlobalMutableAddress(IRInst* inst)
{
auto root = getRootAddr(inst);
if (root)
{
if (as<IRParameterGroupType>(root->getDataType()))
{
- return true;
+ return false;
}
return as<IRModuleInst>(root->getParent()) != nullptr;
}
return false;
}
+static bool isInstInPrimalOrTransposedParameterBlocks(IRInst* inst)
+{
+ auto func = getParentFunc(inst);
+ if (!func)
+ return false;
+ auto firstBlock = func->getFirstBlock();
+ if (inst->getParent() == firstBlock)
+ return true;
+ auto branch = as<IRUnconditionalBranch>(firstBlock->getTerminator());
+ if (!branch)
+ return false;
+ auto secondBlock = branch->getTargetBlock();
+ if (inst->getParent() == secondBlock)
+ return true;
+ return false;
+}
+
static bool shouldStoreInst(IRInst* inst)
{
if (!inst->getDataType())
@@ -1406,10 +1664,16 @@ static bool shouldStoreInst(IRInst* inst)
return false;
case kIROp_Load:
- // Never store a load of a global parameter/variable.
- if (isGlobalAddress(as<IRLoad>(inst)->getPtr()))
- return false;
- break;
+ // In general, don't store loads, because:
+ // - Loads to constant data can just be reloaded.
+ // - Loads to local variables can only exist for the temp variables used for calls,
+ // those variables are written only once so we can always load them anytime.
+ // - Loads to global mutable variables are now allowed, but we will capture that
+ // case in canRecompute().
+ // - The only exception is the load of an inout param, in which case we do need
+ // to store it because the param may be modified by the func at exit. Similarly,
+ // this will be handled in canRecompute().
+ return false;
case kIROp_Call:
// If the callee prefers recompute policy, don't store.
@@ -1462,47 +1726,38 @@ static bool shouldStoreVar(IRVar* var)
return false;
}
-bool canRecompute(UseOrPseudoUse use)
+bool DefaultCheckpointPolicy::canRecompute(UseOrPseudoUse use)
{
if (auto load = as<IRLoad>(use.usedVal))
{
- // Generally, we cannot recompute a load(ptr), since ptr may be modified
- // afterwards.
- //
- // The exceptions are a load of an inout param or global param, since the
- // propagation function never actually writes to the primal part of the
- // inout param, and we can always just read the original param.
-
auto ptr = load->getPtr();
- if (ptr->getOp() == kIROp_Param)
- {
- if (auto block = as<IRBlock>(ptr->getParent()))
- {
- return (block == block->getParent()->getFirstBlock());
- }
- }
- else if (ptr->getOp() == kIROp_GlobalParam)
- {
- return true;
- }
- else if (as<IRParameterGroupType>(ptr->getDataType()))
+
+ // We can't recompute a `load` is if it is a load from a global mutable
+ // variable.
+ if (isGlobalMutableAddress(ptr))
+ return false;
+
+ // We can't recompute a 'load' from a mutable function parameter.
+ if (as<IRParam>(ptr) || as<IRVar>(ptr))
{
- return true;
+ if (isInstInPrimalOrTransposedParameterBlocks(ptr))
+ return false;
}
- return false;
}
- auto param = as<IRParam>(use.usedVal);
- if (!param)
- return true;
-
- // We can recompute a phi param if it is not in a loop start block.
- auto parentBlock = as<IRBlock>(param->getParent());
- for (auto pred : parentBlock->getPredecessors())
+ else if (auto param = as<IRParam>(use.usedVal))
{
- if (auto loop = as<IRLoop>(pred->getTerminator()))
+ if (inductionValueInsts.containsKey(param))
+ return true;
+
+ // We can recompute a phi param if it is not in a loop start block.
+ auto parentBlock = as<IRBlock>(param->getParent());
+ for (auto pred : parentBlock->getPredecessors())
{
- if (loop->getTargetBlock() == parentBlock)
- return false;
+ if (auto loop = as<IRLoop>(pred->getTerminator()))
+ {
+ if (loop->getTargetBlock() == parentBlock)
+ return false;
+ }
}
}
return true;
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index e9fc0d4a5..c9377d56b 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -19,11 +19,18 @@ namespace Slang
UInt operandCount = clonedInst->getOperandCount();
for (UInt ii = 0; ii < operandCount; ++ii)
{
- auto oldOperand = inst->getOperand(ii);
auto newOperand = clonedInst->getOperand(ii);
-
- if (oldOperand == newOperand)
- pendingUses.add(&clonedInst->getOperands()[ii]);
+ // If operand is in a differential or recompute block, it means it has already
+ // been cloned, so we don't add it to pending uses.
+ if (auto operandParent = as<IRBlock>(newOperand->getParent()))
+ {
+ if (isDifferentialOrRecomputeBlock(operandParent))
+ {
+ continue;
+ }
+ }
+ // Otherwise, add it to pending uses.
+ pendingUses.add(&clonedInst->getOperands()[ii]);
}
for (auto use = inst->firstUse; use;)
@@ -221,6 +228,18 @@ namespace Slang
return primalCountParam == other.primalCountParam;
}
};
+
+ struct LoopInductionValueInfo
+ {
+ enum Kind
+ {
+ AlwaysTrue,
+ EqualsToCounter,
+ };
+ Kind kind;
+ IRLoop* loopInst = nullptr;
+ IRInst* counterOffset = nullptr;
+ };
// Information on which insts are to be stored, recomputed
// and inverted within a single function.
@@ -232,7 +251,7 @@ namespace Slang
HashSet<IRInst*> storeSet;
HashSet<IRInst*> recomputeSet;
HashSet<IRInst*> invertSet;
-
+ Dictionary<IRInst*, LoopInductionValueInfo> loopInductionInfo;
Dictionary<IRInst*, InversionInfo> invInfoMap;
};
@@ -289,7 +308,8 @@ namespace Slang
RefPtr<HoistedPrimalsInfo> processFunc(
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
- IROutOfOrderCloneContext* cloneCtx);
+ IROutOfOrderCloneContext* cloneCtx,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& blockIndexInfo);
// Do pre-processing on the function (mainly for
// 'global' checkpointing methods that consider the entire
@@ -302,6 +322,8 @@ namespace Slang
protected:
IRModule* module;
+ Dictionary<IRInst*, LoopInductionValueInfo> inductionValueInsts;
+ void collectInductionValues(IRGlobalValueWithCode* func);
};
class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase
@@ -314,6 +336,10 @@ namespace Slang
virtual void preparePolicy(IRGlobalValueWithCode* func);
virtual HoistResult classify(UseOrPseudoUse use);
+
+ private:
+ bool canRecompute(UseOrPseudoUse use);
+
};
RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f6a977994..6d56736ad 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1909,7 +1909,7 @@ IRUse* findUniqueStoredVal(IRVar* var)
// the final value to it, this method will return the call inst for this case.
IRUse* findLatestUniqueWriteUse(IRVar* var)
{
- IRUse* storeUse = nullptr;
+ IRUse* callUse = nullptr;
for (auto use = var->firstUse; use; use = use->nextUse)
{
if (const auto callInst = as<IRCall>(use->getUser()))
@@ -1917,10 +1917,14 @@ IRUse* findLatestUniqueWriteUse(IRVar* var)
// Ignore uses from differential blocks.
if (callInst->getParent()->findDecoration<IRDifferentialInstDecoration>())
continue;
- SLANG_RELEASE_ASSERT(!storeUse);
- storeUse = use;
+ SLANG_RELEASE_ASSERT(!callUse);
+ callUse = use;
}
}
+
+ if (callUse)
+ return callUse;
+
// If no unique call found, try to look for a store.
return findUniqueStoredVal(var);
}