diff options
| author | Yong He <yonghe@outlook.com> | 2025-07-11 16:54:43 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-11 23:54:43 +0000 |
| commit | 1e1a49ccf595dcc99bd9792a47199ec89d5b4370 (patch) | |
| tree | 199dca9cc2c5f27466ebd8b6e9e6fcd8328db9fa | |
| parent | d8d0b8969f731990820f25812f3d90ee4dd1ee75 (diff) | |
Fixup address spaces after inlining. (#7731)
* Fixup address spaces after inlining.
* add -O0
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 31 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.cpp | 65 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.h | 8 | ||||
| -rw-r--r-- | tests/spirv/pointer-access.slang | 48 |
5 files changed, 149 insertions, 14 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 376a828cd..b9627e1ee 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -6848,15 +6848,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { baseStructType = as<IRStructType>(ptrType->getValueType()); baseId = getID(ensureInst(base)); + SLANG_ASSERT( + as<IRPtrTypeBase>(fieldAddress->getFullType())->getAddressSpace() == + ptrType->getAddressSpace() && + "field_address requires base and result to have same address space."); } else { - baseStructType = as<IRStructType>(base->getDataType()); - - auto structPtrType = builder.getPtrType(baseStructType); - auto varInst = emitOpVariable(parent, nullptr, structPtrType, SpvStorageClassFunction); - emitOpStore(parent, nullptr, varInst, base); - baseId = getID(varInst); + SLANG_UNEXPECTED("field_address requires base to be an address."); } SLANG_ASSERT(baseStructType && "field_address requires base to be a struct."); auto fieldId = emitIntConstant( diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 9e522081f..f0f940b12 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -2,6 +2,7 @@ #include "slang-ir-inline.h" #include "../core/slang-performance-profiler.h" +#include "slang-ir-specialize-address-space.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" @@ -727,6 +728,16 @@ struct InliningPassBase auto debugInlineInfo = emitCalleeDebugInlinedAt(call, callee, *builder); + // Collect all arguments that are pointers, so we can propagate their address + // spaces to the cloned instructions after inlining. + List<IRInst*> ptrArgList; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (as<IRPtrTypeBase>(arg->getDataType())) + ptrArgList.add(arg); + } + // If the callee consists of a single basic block *and* that block // ends with a `return` instruction, then we can apply a simple approach // to inlining that is compatible with any call site (including those @@ -742,16 +753,20 @@ struct InliningPassBase builder, debugInlineInfo.newDebugInlinedAt, debugInlineInfo.calleeDebugFunc); - return; + } + else + { + // If the callee has multiple blocks, use the more complex inlining approach + inlineMultipleBlockFuncBody( + callSite, + env, + builder, + debugInlineInfo.newDebugInlinedAt, + debugInlineInfo.calleeDebugFunc); } - // If the callee has multiple blocks, use the more complex inlining approach - inlineMultipleBlockFuncBody( - callSite, - env, - builder, - debugInlineInfo.newDebugInlinedAt, - debugInlineInfo.calleeDebugFunc); + // Propagate the address space from the argument to the cloned instructions. + propagateAddressSpaceFromInsts(_Move(ptrArgList)); } // Inline the body of the callee for `callSite`, for a callee that has multiple basic blocks. diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index ae0542734..2bc1de775 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -412,4 +412,69 @@ void specializeAddressSpace(IRModule* module, InitialAddressSpaceAssigner* addrS AddressSpaceContext context(module, addrSpaceAssigner); context.processModule(); } + +void propagateAddressSpaceFromInsts(List<IRInst*>&& workList) +{ + HashSet<IRInst*> visited; + auto addUserToWorkList = [&](IRInst* inst) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (visited.add(user)) + workList.add(user); + } + }; + for (auto item : workList) + { + visited.add(item); + } + for (Index i = 0; i < workList.getCount(); i++) + { + auto inst = workList[i]; + IRBuilder builder(inst); + auto instPtrType = as<IRPtrTypeBase>(inst->getDataType()); + if (!instPtrType) + continue; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + switch (user->getOp()) + { + case kIROp_Loop: + case kIROp_UnconditionalBranch: + { + auto branch = as<IRUnconditionalBranch>(user); + UIndex phiIndex = (UIndex)(use - branch->getArgs()); + auto param = getParamAt(branch->getTargetBlock(), phiIndex); + if (!param) + continue; + user = param; + break; + } + } + switch (user->getOp()) + { + case kIROp_FieldAddress: + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_Param: + { + auto valueType = tryGetPointedToType(&builder, user->getDataType()); + if (!valueType) + continue; + auto newType = builder.getPtrTypeWithAddressSpace(valueType, instPtrType); + if (newType != user->getDataType()) + { + user->setFullType(newType); + addUserToWorkList(user); + } + break; + } + } + } + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h index 6f9269017..7e5f0fd9b 100644 --- a/source/slang/slang-ir-specialize-address-space.h +++ b/source/slang/slang-ir-specialize-address-space.h @@ -1,6 +1,8 @@ // slang-ir-specialize-address-space.h #pragma once +#include "core/slang-basic.h" + #include <cinttypes> namespace Slang @@ -27,4 +29,10 @@ struct InitialAddressSpaceAssigner /// based on the address space of the arguments. /// void specializeAddressSpace(IRModule* module, InitialAddressSpaceAssigner* addrSpaceAssigner); + +/// Traverse the user graph of the initial insts and fix up address spaces to make sure they are +/// consistent. This is needed after inlining a callee, the address space of the callee's +/// instructions should be propagated from the arguments. +void propagateAddressSpaceFromInsts(List<IRInst*>&& initialArgs); + } // namespace Slang diff --git a/tests/spirv/pointer-access.slang b/tests/spirv/pointer-access.slang new file mode 100644 index 000000000..281613640 --- /dev/null +++ b/tests/spirv/pointer-access.slang @@ -0,0 +1,48 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -O0 + +//CHECK: OpEntryPoint + +struct Result +{ + float3 value; +}; + +struct Indirect +{ + float scale; +}; + +struct PushConstant +{ + Indirect *ptr; +}; + +struct Payload +{ + uint seed; +}; + +ConstantBuffer<PushConstant> pushConstants; + +Result f3(Indirect ss, float2 randomSample) +{ + Result result; + result.value = randomSample.x; + return result; +} + +float3 f2(inout uint seed) +{ + return f3(*pushConstants.ptr, float2(seed)).value; +} + +float3 f1(inout Payload payload) +{ + return f2(payload.seed); +} + +[shader("closesthit")] +void main(inout Payload payload, in BuiltInTriangleIntersectionAttributes attr) +{ + f1(payload); +}
\ No newline at end of file |
