diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-25 10:42:19 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-25 10:42:19 -0800 |
| commit | a9f2f8a592c4514cd116c947486055788092ea56 (patch) | |
| tree | e7bb9fba9d80631254ea1b42a96fe2c201573979 /source | |
| parent | 19083925690f6180cb081ce2be4fbbdb64010b37 (diff) | |
Fix `UseGraph::replace` (#6395)
* Fix `UseGraph::isTrivial()` test.
* Fix.
* Fix.
* Refactor `UseGraph` and `UseChain`
* Update slang-ir-autodiff-primal-hoist.cpp
* Update all auto-diff locations that handle pointers to treat user pointers as regular values
* Update test to use direct-SPIRV only
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 88 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 4 |
11 files changed, 74 insertions, 74 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index d8500a694..0302d9ce7 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1777,7 +1777,7 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func) for (auto param : params) { - auto ptrType = as<IRPtrTypeBase>(param->getDataType()); + auto ptrType = asRelevantPtrType(param->getDataType()); auto tempVar = builder.emitVar(ptrType->getValueType()); param->replaceUsesWith(tempVar); mapParamToTempVar[param] = tempVar; @@ -2245,7 +2245,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam( builder->emitDifferentialPairGetPrimal(diffPairParam), builder->emitDifferentialPairGetDifferential(diffType, diffPairParam)); } - else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) + else if (auto pairPtrType = asRelevantPtrType(diffPairType)) { auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType()); // Make a local copy of the parameter for primal and diff parts. diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 06e3f409d..b5ac784ce 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1174,7 +1174,7 @@ IRVar* emitIndexedLocalVar( SourceLoc location) { // Cannot store pointers. Case should have been handled by now. - SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType)); + SLANG_RELEASE_ASSERT(!asRelevantPtrType(baseType)); // Cannot store types. Case should have been handled by now. SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType)); @@ -1326,7 +1326,11 @@ static int getInstRegionNestLevel( struct UseChain { + // The chain of uses from the base use to the relevant use. + // However, this is stored in reverse order (so that the last use is the 'base use') + // List<IRUse*> chain; + static List<UseChain> from( IRUse* baseUse, Func<bool, IRUse*> isRelevantUse, @@ -1366,41 +1370,20 @@ struct UseChain return result; } - void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst) + // This function only replaces the inner links, not the base use. + void replaceInnerLinks(IROutOfOrderCloneContext* ctx, IRBuilder* builder) { SLANG_ASSERT(chain.getCount() > 0); - // Simple case: if there is only one use, then we can just replace it. - if (chain.getCount() == 1) - { - builder->replaceOperand(chain.getLast(), inst); - chain.clear(); - return; - } - - // Pop the last use, which is the base use that needs to be replaced. - auto baseUse = chain.getLast(); - chain.removeLast(); + const UIndex count = chain.getCount(); - // Ensure that replacement inst is set as mapping for the baseUse. - ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst; - - IRBuilder chainBuilder(builder->getModule()); - setInsertAfterOrdinaryInst(&chainBuilder, inst); - - chain.reverse(); - chain.removeLast(); - - // Clone the rest of the chain. - for (auto& use : chain) + // Process the chain in reverse order (excluding the first and last elements). + // That is, iterate from count - 2 down to 1 (inclusive). + for (int i = ((int)count) - 2; i >= 1; i--) { - ctx->cloneInstOutOfOrder(&chainBuilder, use->get()); + IRUse* use = chain[i]; + ctx->cloneInstOutOfOrder(builder, use->get()); } - - // We won't actually replace the final use, because if there are multiple chains - // it can cause problems. The parent UseGraph will handle that. - - chain.clear(); } IRInst* getUser() const @@ -1417,6 +1400,14 @@ struct UseGraph // OrderedDictionary<IRUse*, List<UseChain>> chainSets; + // Create a UseGraph from a base inst. + // + // `isRelevantUse` is a predicate that determines if a use is relevant. Traversal will stop at + // this use, and all chains to this use will be grouped together. + // + // `passthroughInst` is a predicate that determines if an inst should be looked through + // for uses. + // static UseGraph from( IRInst* baseInst, Func<bool, IRUse*> isRelevantUse, @@ -1445,36 +1436,33 @@ struct UseGraph return result; } - void replace(IRBuilder* builder, IRUse* use, IRInst* inst) + void replace(IRBuilder* builder, IRUse* relevantUse, IRInst* inst) { // Since we may have common nodes, we will use an out-of-order cloning context // that can retroactively correct the uses as needed. // IROutOfOrderCloneContext ctx; - List<UseChain> chains = chainSets[use]; - for (auto chain : chains) - { - chain.replace(&ctx, builder, inst); - } + List<UseChain> chains = chainSets[relevantUse]; - if (!isTrivial()) + // Link the first use of each chain to inst. + for (auto& chain : chains) + ctx.cloneEnv.mapOldValToNew[chain.chain.getLast()->get()] = inst; + + // Process the inner links of each chain using the replacement. + for (auto& chain : chains) { - builder->setInsertBefore(use->getUser()); - auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get()); + IRBuilder chainBuilder(builder->getModule()); + setInsertAfterOrdinaryInst(&chainBuilder, inst); - // Replace the base use. - builder->replaceOperand(use, lastInstInChain); + chain.replaceInnerLinks(&ctx, builder); } - } - bool isTrivial() - { - // We're trivial if there's only one chain, and it has only one use. - if (chainSets.getCount() != 1) - return false; + // Finally, replace the relevant use (i.e, "final use") with the new replacement inst. + builder->setInsertBefore(relevantUse->getUser()); + auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, relevantUse->get()); - auto& chain = chainSets.getFirst().value; - return chain.getCount() == 1; + // Replace the base use. + builder->replaceOperand(relevantUse, lastInstInChain); } List<IRUse*> getUniqueUses() const @@ -1668,7 +1656,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( return true; } else if ( - as<IRPtrTypeBase>(instToStore->getDataType()) && + asRelevantPtrType(instToStore->getDataType()) && !isDifferentialOrRecomputeBlock(defBlock)) { return true; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3237ba3b2..519f796b4 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -370,7 +370,7 @@ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc( auto diffPairType = tryGetDiffPairType(builder, paramType); if (diffPairType) { - if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType)) + if (!asRelevantPtrType(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType)) return builder->getInOutType(diffPairType); return diffPairType; } @@ -514,7 +514,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF { // As long as the primal parameter is not an out or constref type, // we need to fetch the primal value from the parameter. - if (as<IRPtrTypeBase>(propagateParamType)) + if (asRelevantPtrType(propagateParamType)) { primalArg = builder.emitLoad(param); } @@ -544,7 +544,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF } else { - auto primalPtrType = as<IRPtrTypeBase>(primalParamType); + auto primalPtrType = asRelevantPtrType(primalParamType); SLANG_RELEASE_ASSERT(primalPtrType); auto primalValueType = primalPtrType->getValueType(); auto var = builder.emitVar(primalValueType); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 38a7a18bb..8356e5f81 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -291,7 +291,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy if (isNoDiffType(origType)) return nullptr; - if (auto ptrType = as<IRPtrTypeBase>(origType)) + if (auto ptrType = asRelevantPtrType(origType)) return builder->getPtrType( origType->getOp(), differentiateType(builder, ptrType->getValueType())); @@ -556,7 +556,7 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* if (isNoDiffType(originalType)) return nullptr; - if (auto origPtrType = as<IRPtrTypeBase>(originalType)) + if (auto origPtrType = asRelevantPtrType(originalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) return builder->getPtrType(originalType->getOp(), diffPairValueType); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 5e96c4e0f..282cc9685 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -619,7 +619,7 @@ struct DiffTransposePass if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst)) { if (auto ptrPrimalType = - as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst))) + asRelevantPtrType(tryGetPrimalTypeFromDiffInst(varInst))) { varInst->insertAtEnd(firstRevDiffBlock); @@ -1119,7 +1119,7 @@ struct DiffTransposePass auto getDiffPairType = [](IRType* type) { - if (auto ptrType = as<IRPtrTypeBase>(type)) + if (auto ptrType = asRelevantPtrType(type)) type = ptrType->getValueType(); return as<IRDifferentialPairType>(type); }; @@ -1168,7 +1168,7 @@ struct DiffTransposePass argRequiresLoad.add(false); writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar}); } - else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType())) + else if (!asRelevantPtrType(arg->getDataType()) && getDiffPairType(arg->getDataType())) { // Normal differentiable input parameter will become an inout DiffPair parameter // in the propagate func. The split logic has already prepared the initial value @@ -1241,7 +1241,6 @@ struct DiffTransposePass argRequiresLoad.add(false); } - auto revFnType = this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType( builder, diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 6bc428ad6..4d5903ab1 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -332,8 +332,8 @@ bool isIntermediateContextType(IRInst* type) case kIROp_Specialize: return isIntermediateContextType(as<IRSpecialize>(type)->getBase()); default: - if (as<IRPtrTypeBase>(type)) - return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType()); + if (auto ptrType = asRelevantPtrType(type)) + return isIntermediateContextType(ptrType->getValueType()); return false; } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 556fb58a8..ec435ee87 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -75,16 +75,15 @@ struct DiffUnzipPass primalParam = primalParam->getNextParam()) { auto type = primalParam->getFullType(); - if (auto ptrType = as<IRPtrTypeBase>(type)) + if (auto ptrType = asRelevantPtrType(type)) { type = ptrType->getValueType(); } if (auto pairType = as<IRDifferentialPairType>(type)) { IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType); - if (as<IRPtrTypeBase>(primalParam->getFullType())) - diffType = - builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType); + if (auto ptrType = asRelevantPtrType(primalParam->getFullType())) + diffType = builder->getPtrType(ptrType->getOp(), (IRType*)diffType); auto primalRef = builder->emitPrimalParamRef(primalParam); auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam); builder->markInstAsDifferential(diffRef, pairType->getValueType()); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 40dcb1b51..df657476a 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType) paramType = attrType->getBaseType(); } - else if (auto ptrType = as<IRPtrTypeBase>(paramType)) + else if (auto ptrType = asRelevantPtrType(paramType)) { paramType = ptrType->getValueType(); } @@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( IRStructKey* key) { IRInst* pairType = nullptr; - if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType())) + if (auto basePtrType = asRelevantPtrType(baseInst->getDataType())) { auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType()); @@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( baseInst, key)); } - else if (auto ptrType = as<IRPtrTypeBase>(pairType)) + else if (auto ptrType = asRelevantPtrType(pairType)) { if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) { @@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( baseInst, key)); } - else if (auto genericPtrType = as<IRPtrTypeBase>(genericType)) + else if (auto genericPtrType = asRelevantPtrType(genericType)) { if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType())) { @@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType( IRBuilder* builder, IRInst* primalType) { - if (auto ptrType = as<IRPtrTypeBase>(primalType)) + if (auto ptrType = asRelevantPtrType(primalType)) return builder->getPtrType( primalType->getOp(), differentiateType(builder, ptrType->getValueType())); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 433b6093f..4698408e3 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -604,7 +604,7 @@ inline bool isRelevantDifferentialPair(IRType* type) { return true; } - else if (auto argPtrType = as<IRPtrTypeBase>(type)) + else if (auto argPtrType = asRelevantPtrType(type)) { if (as<IRDifferentialPairType>(argPtrType->getValueType())) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index dbd6ac099..bf5b25d9c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1528,6 +1528,16 @@ bool isOne(IRInst* inst) } } +IRPtrTypeBase* asRelevantPtrType(IRInst* inst) +{ + if (auto ptrType = as<IRPtrTypeBase>(inst)) + { + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return ptrType; + } + return nullptr; +} + IRPtrTypeBase* isMutablePointerType(IRInst* inst) { switch (inst->getOp()) @@ -1535,7 +1545,7 @@ IRPtrTypeBase* isMutablePointerType(IRInst* inst) case kIROp_ConstRefType: return nullptr; default: - return as<IRPtrTypeBase>(inst); + return asRelevantPtrType(inst); } } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 610524754..aed63da47 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -271,6 +271,10 @@ bool isZero(IRInst* inst); bool isOne(IRInst* inst); +// Casts inst to IRPtrTypeBase, excluding UserPointer address space. +IRPtrTypeBase* asRelevantPtrType(IRInst* inst); + +// Returns the pointer type if it is pointer type that is not a const ref or a user pointer. IRPtrTypeBase* isMutablePointerType(IRInst* inst); void initializeScratchData(IRInst* inst); |
