summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-02-25 10:42:19 -0800
committerGitHub <noreply@github.com>2025-02-25 10:42:19 -0800
commita9f2f8a592c4514cd116c947486055788092ea56 (patch)
treee7bb9fba9d80631254ea1b42a96fe2c201573979 /source
parent19083925690f6180cb081ce2be4fbbdb64010b37 (diff)
Fix `UseGraph::replace` (#6395)
* Fix `UseGraph::isTrivial()` test. * Fix. * Fix. * Refactor `UseGraph` and `UseChain` * Update slang-ir-autodiff-primal-hoist.cpp * Update all auto-diff locations that handle pointers to treat user pointers as regular values * Update test to use direct-SPIRV only --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp88
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h7
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h7
-rw-r--r--source/slang/slang-ir-autodiff.cpp10
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-util.cpp12
-rw-r--r--source/slang/slang-ir-util.h4
11 files changed, 74 insertions, 74 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index d8500a694..0302d9ce7 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1777,7 +1777,7 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
for (auto param : params)
{
- auto ptrType = as<IRPtrTypeBase>(param->getDataType());
+ auto ptrType = asRelevantPtrType(param->getDataType());
auto tempVar = builder.emitVar(ptrType->getValueType());
param->replaceUsesWith(tempVar);
mapParamToTempVar[param] = tempVar;
@@ -2245,7 +2245,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(
builder->emitDifferentialPairGetPrimal(diffPairParam),
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
}
- else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
+ else if (auto pairPtrType = asRelevantPtrType(diffPairType))
{
auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType());
// Make a local copy of the parameter for primal and diff parts.
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 06e3f409d..b5ac784ce 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1174,7 +1174,7 @@ IRVar* emitIndexedLocalVar(
SourceLoc location)
{
// Cannot store pointers. Case should have been handled by now.
- SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
+ SLANG_RELEASE_ASSERT(!asRelevantPtrType(baseType));
// Cannot store types. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));
@@ -1326,7 +1326,11 @@ static int getInstRegionNestLevel(
struct UseChain
{
+ // The chain of uses from the base use to the relevant use.
+ // However, this is stored in reverse order (so that the last use is the 'base use')
+ //
List<IRUse*> chain;
+
static List<UseChain> from(
IRUse* baseUse,
Func<bool, IRUse*> isRelevantUse,
@@ -1366,41 +1370,20 @@ struct UseChain
return result;
}
- void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst)
+ // This function only replaces the inner links, not the base use.
+ void replaceInnerLinks(IROutOfOrderCloneContext* ctx, IRBuilder* builder)
{
SLANG_ASSERT(chain.getCount() > 0);
- // Simple case: if there is only one use, then we can just replace it.
- if (chain.getCount() == 1)
- {
- builder->replaceOperand(chain.getLast(), inst);
- chain.clear();
- return;
- }
-
- // Pop the last use, which is the base use that needs to be replaced.
- auto baseUse = chain.getLast();
- chain.removeLast();
+ const UIndex count = chain.getCount();
- // Ensure that replacement inst is set as mapping for the baseUse.
- ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst;
-
- IRBuilder chainBuilder(builder->getModule());
- setInsertAfterOrdinaryInst(&chainBuilder, inst);
-
- chain.reverse();
- chain.removeLast();
-
- // Clone the rest of the chain.
- for (auto& use : chain)
+ // Process the chain in reverse order (excluding the first and last elements).
+ // That is, iterate from count - 2 down to 1 (inclusive).
+ for (int i = ((int)count) - 2; i >= 1; i--)
{
- ctx->cloneInstOutOfOrder(&chainBuilder, use->get());
+ IRUse* use = chain[i];
+ ctx->cloneInstOutOfOrder(builder, use->get());
}
-
- // We won't actually replace the final use, because if there are multiple chains
- // it can cause problems. The parent UseGraph will handle that.
-
- chain.clear();
}
IRInst* getUser() const
@@ -1417,6 +1400,14 @@ struct UseGraph
//
OrderedDictionary<IRUse*, List<UseChain>> chainSets;
+ // Create a UseGraph from a base inst.
+ //
+ // `isRelevantUse` is a predicate that determines if a use is relevant. Traversal will stop at
+ // this use, and all chains to this use will be grouped together.
+ //
+ // `passthroughInst` is a predicate that determines if an inst should be looked through
+ // for uses.
+ //
static UseGraph from(
IRInst* baseInst,
Func<bool, IRUse*> isRelevantUse,
@@ -1445,36 +1436,33 @@ struct UseGraph
return result;
}
- void replace(IRBuilder* builder, IRUse* use, IRInst* inst)
+ void replace(IRBuilder* builder, IRUse* relevantUse, IRInst* inst)
{
// Since we may have common nodes, we will use an out-of-order cloning context
// that can retroactively correct the uses as needed.
//
IROutOfOrderCloneContext ctx;
- List<UseChain> chains = chainSets[use];
- for (auto chain : chains)
- {
- chain.replace(&ctx, builder, inst);
- }
+ List<UseChain> chains = chainSets[relevantUse];
- if (!isTrivial())
+ // Link the first use of each chain to inst.
+ for (auto& chain : chains)
+ ctx.cloneEnv.mapOldValToNew[chain.chain.getLast()->get()] = inst;
+
+ // Process the inner links of each chain using the replacement.
+ for (auto& chain : chains)
{
- builder->setInsertBefore(use->getUser());
- auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get());
+ IRBuilder chainBuilder(builder->getModule());
+ setInsertAfterOrdinaryInst(&chainBuilder, inst);
- // Replace the base use.
- builder->replaceOperand(use, lastInstInChain);
+ chain.replaceInnerLinks(&ctx, builder);
}
- }
- bool isTrivial()
- {
- // We're trivial if there's only one chain, and it has only one use.
- if (chainSets.getCount() != 1)
- return false;
+ // Finally, replace the relevant use (i.e, "final use") with the new replacement inst.
+ builder->setInsertBefore(relevantUse->getUser());
+ auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, relevantUse->get());
- auto& chain = chainSets.getFirst().value;
- return chain.getCount() == 1;
+ // Replace the base use.
+ builder->replaceOperand(relevantUse, lastInstInChain);
}
List<IRUse*> getUniqueUses() const
@@ -1668,7 +1656,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return true;
}
else if (
- as<IRPtrTypeBase>(instToStore->getDataType()) &&
+ asRelevantPtrType(instToStore->getDataType()) &&
!isDifferentialOrRecomputeBlock(defBlock))
{
return true;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 3237ba3b2..519f796b4 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -370,7 +370,7 @@ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(
auto diffPairType = tryGetDiffPairType(builder, paramType);
if (diffPairType)
{
- if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
+ if (!asRelevantPtrType(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
return builder->getInOutType(diffPairType);
return diffPairType;
}
@@ -514,7 +514,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
{
// As long as the primal parameter is not an out or constref type,
// we need to fetch the primal value from the parameter.
- if (as<IRPtrTypeBase>(propagateParamType))
+ if (asRelevantPtrType(propagateParamType))
{
primalArg = builder.emitLoad(param);
}
@@ -544,7 +544,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
}
else
{
- auto primalPtrType = as<IRPtrTypeBase>(primalParamType);
+ auto primalPtrType = asRelevantPtrType(primalParamType);
SLANG_RELEASE_ASSERT(primalPtrType);
auto primalValueType = primalPtrType->getValueType();
auto var = builder.emitVar(primalValueType);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 38a7a18bb..8356e5f81 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -291,7 +291,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
if (isNoDiffType(origType))
return nullptr;
- if (auto ptrType = as<IRPtrTypeBase>(origType))
+ if (auto ptrType = asRelevantPtrType(origType))
return builder->getPtrType(
origType->getOp(),
differentiateType(builder, ptrType->getValueType()));
@@ -556,7 +556,7 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
if (isNoDiffType(originalType))
return nullptr;
- if (auto origPtrType = as<IRPtrTypeBase>(originalType))
+ if (auto origPtrType = asRelevantPtrType(originalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
return builder->getPtrType(originalType->getOp(), diffPairValueType);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 5e96c4e0f..282cc9685 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -619,7 +619,7 @@ struct DiffTransposePass
if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
{
if (auto ptrPrimalType =
- as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
+ asRelevantPtrType(tryGetPrimalTypeFromDiffInst(varInst)))
{
varInst->insertAtEnd(firstRevDiffBlock);
@@ -1119,7 +1119,7 @@ struct DiffTransposePass
auto getDiffPairType = [](IRType* type)
{
- if (auto ptrType = as<IRPtrTypeBase>(type))
+ if (auto ptrType = asRelevantPtrType(type))
type = ptrType->getValueType();
return as<IRDifferentialPairType>(type);
};
@@ -1168,7 +1168,7 @@ struct DiffTransposePass
argRequiresLoad.add(false);
writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
}
- else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
+ else if (!asRelevantPtrType(arg->getDataType()) && getDiffPairType(arg->getDataType()))
{
// Normal differentiable input parameter will become an inout DiffPair parameter
// in the propagate func. The split logic has already prepared the initial value
@@ -1241,7 +1241,6 @@ struct DiffTransposePass
argRequiresLoad.add(false);
}
-
auto revFnType =
this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType(
builder,
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 6bc428ad6..4d5903ab1 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -332,8 +332,8 @@ bool isIntermediateContextType(IRInst* type)
case kIROp_Specialize:
return isIntermediateContextType(as<IRSpecialize>(type)->getBase());
default:
- if (as<IRPtrTypeBase>(type))
- return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
+ if (auto ptrType = asRelevantPtrType(type))
+ return isIntermediateContextType(ptrType->getValueType());
return false;
}
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 556fb58a8..ec435ee87 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -75,16 +75,15 @@ struct DiffUnzipPass
primalParam = primalParam->getNextParam())
{
auto type = primalParam->getFullType();
- if (auto ptrType = as<IRPtrTypeBase>(type))
+ if (auto ptrType = asRelevantPtrType(type))
{
type = ptrType->getValueType();
}
if (auto pairType = as<IRDifferentialPairType>(type))
{
IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType);
- if (as<IRPtrTypeBase>(primalParam->getFullType()))
- diffType =
- builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType);
+ if (auto ptrType = asRelevantPtrType(primalParam->getFullType()))
+ diffType = builder->getPtrType(ptrType->getOp(), (IRType*)diffType);
auto primalRef = builder->emitPrimalParamRef(primalParam);
auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam);
builder->markInstAsDifferential(diffRef, pairType->getValueType());
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 40dcb1b51..df657476a 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType)
paramType = attrType->getBaseType();
}
- else if (auto ptrType = as<IRPtrTypeBase>(paramType))
+ else if (auto ptrType = asRelevantPtrType(paramType))
{
paramType = ptrType->getValueType();
}
@@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
IRStructKey* key)
{
IRInst* pairType = nullptr;
- if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ if (auto basePtrType = asRelevantPtrType(baseInst->getDataType()))
{
auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType());
@@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
baseInst,
key));
}
- else if (auto ptrType = as<IRPtrTypeBase>(pairType))
+ else if (auto ptrType = asRelevantPtrType(pairType))
{
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
@@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
baseInst,
key));
}
- else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
+ else if (auto genericPtrType = asRelevantPtrType(genericType))
{
if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
{
@@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
IRBuilder* builder,
IRInst* primalType)
{
- if (auto ptrType = as<IRPtrTypeBase>(primalType))
+ if (auto ptrType = asRelevantPtrType(primalType))
return builder->getPtrType(
primalType->getOp(),
differentiateType(builder, ptrType->getValueType()));
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 433b6093f..4698408e3 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -604,7 +604,7 @@ inline bool isRelevantDifferentialPair(IRType* type)
{
return true;
}
- else if (auto argPtrType = as<IRPtrTypeBase>(type))
+ else if (auto argPtrType = asRelevantPtrType(type))
{
if (as<IRDifferentialPairType>(argPtrType->getValueType()))
{
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index dbd6ac099..bf5b25d9c 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1528,6 +1528,16 @@ bool isOne(IRInst* inst)
}
}
+IRPtrTypeBase* asRelevantPtrType(IRInst* inst)
+{
+ if (auto ptrType = as<IRPtrTypeBase>(inst))
+ {
+ if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
+ return ptrType;
+ }
+ return nullptr;
+}
+
IRPtrTypeBase* isMutablePointerType(IRInst* inst)
{
switch (inst->getOp())
@@ -1535,7 +1545,7 @@ IRPtrTypeBase* isMutablePointerType(IRInst* inst)
case kIROp_ConstRefType:
return nullptr;
default:
- return as<IRPtrTypeBase>(inst);
+ return asRelevantPtrType(inst);
}
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 610524754..aed63da47 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -271,6 +271,10 @@ bool isZero(IRInst* inst);
bool isOne(IRInst* inst);
+// Casts inst to IRPtrTypeBase, excluding UserPointer address space.
+IRPtrTypeBase* asRelevantPtrType(IRInst* inst);
+
+// Returns the pointer type if it is pointer type that is not a const ref or a user pointer.
IRPtrTypeBase* isMutablePointerType(IRInst* inst);
void initializeScratchData(IRInst* inst);