From e420f2f980813559b186a6a6bcd5540f74310d02 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 9 Oct 2025 18:30:24 -0700 Subject: Defer `IRCastStorageToLogicalDeref` in lowerBufferElementType pass. (#8668) Fix a regression on metal test. In the `lowerBufferElementTypeToStorageType` pass, we want not only to defer an argument that is `CastStorageToLogical` to the callee, but also to apply the same deferral logic to `CastStorageToLogicalDeref`, because `CastStorageToLogicalDeref` will appear as an argument if `lowerBufferElementTypeToStorageType` is run before we apply the `in->borrow` transformation pass, which is the case for metal parameter block legalization. --- .github/workflows/ci-slang-test.yml | 3 - source/slang/slang-emit-metal.cpp | 3 +- source/slang/slang-emit-wgsl.cpp | 4 +- source/slang/slang-emit.cpp | 4 + source/slang/slang-ir-insts.h | 5 +- .../slang/slang-ir-lower-buffer-element-type.cpp | 47 +++++++++--- source/slang/slang-ir-specialize-arrays.cpp | 15 ++-- source/slang/slang-ir-wgsl-legalize.cpp | 89 ++++++++++++++++++++++ source/slang/slang-ir-wgsl-legalize.h | 3 + source/slang/slang-ir.cpp | 10 +-- tests/metal/sampler-array.slang | 34 +++++++++ 11 files changed, 187 insertions(+), 30 deletions(-) create mode 100644 tests/metal/sampler-array.slang diff --git a/.github/workflows/ci-slang-test.yml b/.github/workflows/ci-slang-test.yml index ec2eb6a8f..391716d2e 100644 --- a/.github/workflows/ci-slang-test.yml +++ b/.github/workflows/ci-slang-test.yml @@ -143,9 +143,6 @@ jobs: - name: Run slang-rhi tests run: | export SLANG_RHI_EXCLUDE_TESTS="md-clear*,cmd-copy*,cmd-upload*,fence*,staging-heap*,texture-create*" - if [[ "${{ inputs.os }}" == "macos" ]]; then - export SLANG_RHI_EXCLUDE_TESTS="sampler-array" - fi "$bin_dir/slang-rhi-tests" -check-devices -tce="$SLANG_RHI_EXCLUDE_TESTS" # Run slangpy tests when: diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index e992d17f5..d614bd4b1 100644 --- 
a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -1278,7 +1278,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit("uint32_t device*"); break; case kIROp_RaytracingAccelerationStructureType: - m_writer->emit("acceleration_structure"); + m_writer->emit( + "metal::raytracing::acceleration_structure"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 3cebae97c..7b7c4ffa3 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -336,10 +336,12 @@ void WGSLSourceEmitter::emit(const AddressSpace addressSpace) break; case AddressSpace::StorageBuffer: + case AddressSpace::Global: m_writer->emit("storage"); break; case AddressSpace::Generic: + case AddressSpace::Function: m_writer->emit("function"); break; @@ -1311,7 +1313,7 @@ bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) void WGSLSourceEmitter::emitCallArg(IRInst* inst) { - if (as(inst->getDataType())) + if (as(inst->getDataType()) || as(inst->getDataType())) { // If we are calling a function with a pointer-typed argument, we need to // explicitly prefix the argument with `&` to pass a pointer. 
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 7d72fb77f..c57851bc2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1881,6 +1881,10 @@ Result linkAndOptimizeIR( { specializeAddressSpaceForMetal(irModule); } + else if (isWGPUTarget(targetRequest)) + { + specializeAddressSpaceForWGSL(irModule); + } performForceInlining(irModule); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 42db9cb44..5c27d5e25 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4584,10 +4584,7 @@ public: IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val); IRInst* emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType); - IRCastStorageToLogicalDeref* emitCastStorageToLogicalDeref( - IRType* type, - IRInst* val, - IRInst* bufferType); + IRInst* emitCastStorageToLogicalDeref(IRType* type, IRInst* val, IRInst* bufferType); IRGlobalConstant* emitGlobalConstant(IRType* type); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index c69592939..4bf36259b 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -12,13 +12,16 @@ /// /// Many of our targets have special restrictions on what is allowed to be used as a /// buffer element. Examples are: -/// - In HLSL and SPIRV, if you have ConstantBuffer, T must be a struct. /// - In SPIRV, `bool` is considered a logical type, meaning it cannot appear inside /// buffers. bool vectors and matrices needs to be lowered into arrays. /// - In SPIRV, if `T` is used to declare a buffer, then every member in `T` must have /// explicit offset. But if it is used to declare a local variable, then it cannot /// have explicit member offset. This means that we cannot use the same `Foo` struct /// inside a `StructuredBuffer` and also use it to declare a local variable. 
+/// - In Metal, if we have a `struct Foo {Texture2D member; }` and +/// `ParameterBlock`, then we should translate it to +/// `struct Foo_pb { Texture2D.Handle member; }` and `ParameterBlock`, so that +/// the resource legalization pass won't hoist the texture out of the parameter block. /// /// We use the terms "physical", "storage", or "lowered" types to refer to types that /// are legal to use as buffer elements. In contrast, the terms "original" or "logical" @@ -187,6 +190,11 @@ /// m_ptr = FieldAddr(ptr, member) /// call f_1, m_ptr /// ``` +/// Note that it is only correct to defer a load/CastStorageToLogicalDeref if the location +/// being loaded from is immutable. Otherwise, we might be changing the order of memory +/// operations and result in a change in application behavior. So this pass will also make sure +/// that we only create `CastStorageToLogicalDeref(x)` such that `x` is an immutable location, +/// such as an immutable temporary variable. /// /// # Trailing Pointer Rewrite /// @@ -1192,10 +1200,7 @@ struct LoweredElementTypeContext // and push the cast to inside the callee. // We will process calls after other gep insts, so for now just add // it into a separate worklist. 
- if (castInst->getOp() == kIROp_CastStorageToLogical) - { - callWorkListSet.add((IRCall*)user); - } + callWorkListSet.add((IRCall*)user); break; } case kIROp_Load: @@ -1261,7 +1266,8 @@ struct LoweredElementTypeContext castInst->getBufferType()); user->replaceUsesWith(newCast); user->removeAndDeallocate(); - castInstWorkList.add(newCast); + if (auto newCastStorage = as(newCast)) + castInstWorkList.add(newCastStorage); break; } case kIROp_FieldExtract: @@ -1353,6 +1359,16 @@ struct LoweredElementTypeContext paramTypes.add(storagePtrType); newArgs.add(castArg->getOperand(0)); } + else if (auto castArgDeref = as(arg)) + { + auto storageValueType = tryGetPointedToOrBufferElementType( + &builder, + castArgDeref->getOperand(0)->getDataType()); + auto storagePtrType = + builder.getBorrowInParamType(storageValueType, AddressSpace::Generic); + paramTypes.add(storagePtrType); + newArgs.add(castArgDeref->getOperand(0)); + } else { paramTypes.add(arg->getDataType()); @@ -1422,7 +1438,7 @@ struct LoweredElementTypeContext auto param = params[i]; SLANG_RELEASE_ASSERT(i < call->getArgCount()); auto arg = call->getArg(i); - auto cast = as(arg); + auto cast = as(arg); if (!cast) continue; auto logicalParamType = param->getFullType(); @@ -1434,8 +1450,21 @@ struct LoweredElementTypeContext uses.clear(); for (auto use = param->firstUse; use; use = use->nextUse) uses.add(use); - auto castedParam = - builder.emitCastStorageToLogical(logicalParamType, param, cast->getBufferType()); + IRInst* castedParam = nullptr; + if (arg->getOp() == kIROp_CastStorageToLogical) + { + castedParam = builder.emitCastStorageToLogical( + logicalParamType, + param, + cast->getBufferType()); + } + else + { + castedParam = builder.emitCastStorageToLogicalDeref( + logicalParamType, + param, + cast->getBufferType()); + } if (auto castStorage = as(castedParam)) outNewCasts.add(castStorage); diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp index 
c2bc4d14e..1f50a7579 100644 --- a/source/slang/slang-ir-specialize-arrays.cpp +++ b/source/slang/slang-ir-specialize-arrays.cpp @@ -29,20 +29,23 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition if (auto outTypeBase = as(paramType)) { paramType = outTypeBase->getValueType(); - SLANG_ASSERT(as(argType)); - argType = as(argType)->getValueType(); + IRBuilder builder(paramType); + argType = tryGetPointedToType(&builder, argType); + SLANG_ASSERT(argType); } else if (auto refType = as(paramType)) { paramType = refType->getValueType(); - SLANG_ASSERT(as(argType)); - argType = as(argType)->getValueType(); + IRBuilder builder(paramType); + argType = tryGetPointedToType(&builder, argType); + SLANG_ASSERT(argType); } else if (auto constRefType = as(paramType)) { paramType = constRefType->getValueType(); - SLANG_ASSERT(as(argType)); - argType = as(argType)->getValueType(); + IRBuilder builder(paramType); + argType = tryGetPointedToType(&builder, argType); + SLANG_ASSERT(argType); } auto arrayType = as(paramType); if (!arrayType) diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 51f16e603..571e6311f 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -4,6 +4,8 @@ #include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-varying-params.h" +#include "slang-ir-specialize-address-space.h" +#include "slang-ir-util.h" #include "slang-ir.h" namespace Slang @@ -225,4 +227,91 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) GlobalInstInliningContext().inlineGlobalValuesAndRemoveIfUnused(module); } +struct WGSLAddressSpaceAssigner : InitialAddressSpaceAssigner +{ + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override + { + switch (inst->getOp()) + { + case kIROp_Var: + if (as(inst->getParent())) + outAddressSpace = AddressSpace::Function; + 
else + outAddressSpace = AddressSpace::ThreadLocal; + return true; + case kIROp_RWStructuredBufferGetElementPtr: + outAddressSpace = AddressSpace::Global; + return true; + case kIROp_Load: + { + auto addrSpace = getAddressSpaceFromVarType(inst->getDataType()); + if (addrSpace != AddressSpace::Generic) + { + outAddressSpace = addrSpace; + return true; + } + } + return false; + default: + return false; + } + } + + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override + { + if (as(type)) + { + return AddressSpace::Uniform; + } + if (as(type)) + { + return AddressSpace::Global; + } + if (as(type)) + { + return AddressSpace::Global; + } + if (as(type)) + { + return AddressSpace::Global; + } + if (auto ptrType = as(type)) + { + if (ptrType->hasAddressSpace()) + return ptrType->getAddressSpace(); + return AddressSpace::Generic; + } + return AddressSpace::Generic; + } + + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override + { + if (as(inst->getRate())) + return AddressSpace::GroupShared; + switch (inst->getOp()) + { + case kIROp_RWStructuredBufferGetElementPtr: + return AddressSpace::Global; + case kIROp_Var: + if (as(inst->getParent())) + return AddressSpace::Function; + else + return AddressSpace::ThreadLocal; + break; + default: + break; + } + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + return getAddressSpaceFromVarType(type); + } +}; + +void specializeAddressSpaceForWGSL(IRModule* module) +{ + WGSLAddressSpaceAssigner wgslAddressSpaceAssigner; + specializeAddressSpace(module, &wgslAddressSpaceAssigner); +} + } // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h index 11e25ea88..b9b270ba3 100644 --- a/source/slang/slang-ir-wgsl-legalize.h +++ b/source/slang/slang-ir-wgsl-legalize.h @@ -7,4 +7,7 @@ namespace Slang class DiagnosticSink; void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink); + +void 
specializeAddressSpaceForWGSL(IRModule* module); + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ebaebcc8a..ba8864684 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6187,14 +6187,12 @@ IRInst* IRBuilder::emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* b return (IRCastStorageToLogical*)emitIntrinsicInst(type, kIROp_CastStorageToLogical, 2, args); } -IRCastStorageToLogicalDeref* IRBuilder::emitCastStorageToLogicalDeref( - IRType* type, - IRInst* val, - IRInst* bufferType) +IRInst* IRBuilder::emitCastStorageToLogicalDeref(IRType* type, IRInst* val, IRInst* bufferType) { IRInst* args[] = {val, bufferType}; - return (IRCastStorageToLogicalDeref*) - emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args); + if (type == tryGetPointedToType(this, val->getDataType())) + return emitLoad(type, val); + return emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args); } IRGlobalConstant* IRBuilder::emitGlobalConstant(IRType* type) diff --git a/tests/metal/sampler-array.slang b/tests/metal/sampler-array.slang new file mode 100644 index 000000000..65476543e --- /dev/null +++ b/tests/metal/sampler-array.slang @@ -0,0 +1,34 @@ +//TEST:SIMPLE(filecheck=MTL): -target metal -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=LIB): -target metallib -stage compute -entry computeMain + +// MTL: float S1_test{{.*}}(const S1_default{{.*}} constant* this{{.*}} +// LIB: computeMain + +struct S1 +{ + Texture2D tex[32]; + SamplerState samplers[32]; + float data; + float test(int i) + { + return tex[i].SampleLevel(samplers[i], float2(0.0, 0.0), 0.0).x + data; + } +} + +struct S0 +{ + float data; + RaytracingAccelerationStructure acc; + ParameterBlock s; +} + +ParameterBlock g; +RWStructuredBuffer buffer; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain( + uint3 sv_dispatchThreadID : SV_DispatchThreadID) +{ + buffer[0] = g.data * g.s.test(sv_dispatchThreadID.x); +} -- 
cgit v1.2.3