diff options
| author | Yong He <yonghe@outlook.com> | 2025-10-09 18:30:24 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-10 01:30:24 +0000 |
| commit | e420f2f980813559b186a6a6bcd5540f74310d02 (patch) | |
| tree | 8f7a833ed86e8ce2a7b40bd1e9e7da5cb95d66ab /source | |
| parent | 3cf1f5a616917480c63b76aae906dc36b29e46ce (diff) | |
Defer `IRCastStorageToLogicalDeref` in lowerBufferElementType pass. (#8668)
Fix a regression on metal test.
In the `lowerBufferElementTypeToStorageType` pass, we want not only to defer
an argument that is `CastStorageToLogical` to the callee, but also to apply
the same deferral logic to `CastStorageToLogicalDeref`.
This is needed because `CastStorageToLogicalDeref` will appear as an argument if
`lowerBufferElementTypeToStorageType` is run before we apply the
`in->borrow` transformation pass, which is the case for Metal parameter
block legalization.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-arrays.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-wgsl-legalize.cpp | 89 | ||||
| -rw-r--r-- | source/slang/slang-ir-wgsl-legalize.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 10 |
9 files changed, 153 insertions, 27 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index e992d17f5..d614bd4b1 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -1278,7 +1278,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit("uint32_t device*"); break; case kIROp_RaytracingAccelerationStructureType: - m_writer->emit("acceleration_structure<instancing>"); + m_writer->emit( + "metal::raytracing::acceleration_structure<metal::raytracing::instancing>"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 3cebae97c..7b7c4ffa3 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -336,10 +336,12 @@ void WGSLSourceEmitter::emit(const AddressSpace addressSpace) break; case AddressSpace::StorageBuffer: + case AddressSpace::Global: m_writer->emit("storage"); break; case AddressSpace::Generic: + case AddressSpace::Function: m_writer->emit("function"); break; @@ -1311,7 +1313,7 @@ bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) void WGSLSourceEmitter::emitCallArg(IRInst* inst) { - if (as<IRPtrTypeBase>(inst->getDataType())) + if (as<IRPointerLikeType>(inst->getDataType()) || as<IRPtrTypeBase>(inst->getDataType())) { // If we are calling a function with a pointer-typed argument, we need to // explicitly prefix the argument with `&` to pass a pointer. 
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 7d72fb77f..c57851bc2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1881,6 +1881,10 @@ Result linkAndOptimizeIR( { specializeAddressSpaceForMetal(irModule); } + else if (isWGPUTarget(targetRequest)) + { + specializeAddressSpaceForWGSL(irModule); + } performForceInlining(irModule); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 42db9cb44..5c27d5e25 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4584,10 +4584,7 @@ public: IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val); IRInst* emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType); - IRCastStorageToLogicalDeref* emitCastStorageToLogicalDeref( - IRType* type, - IRInst* val, - IRInst* bufferType); + IRInst* emitCastStorageToLogicalDeref(IRType* type, IRInst* val, IRInst* bufferType); IRGlobalConstant* emitGlobalConstant(IRType* type); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index c69592939..4bf36259b 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -12,13 +12,16 @@ /// /// Many of our targets have special restrictions on what is allowed to be used as a /// buffer element. Examples are: -/// - In HLSL and SPIRV, if you have ConstantBuffer<T>, T must be a struct. /// - In SPIRV, `bool` is considered a logical type, meaning it cannot appear inside /// buffers. bool vectors and matrices needs to be lowered into arrays. /// - In SPIRV, if `T` is used to declare a buffer, then every member in `T` must have /// explicit offset. But if it is used to declare a local variable, then it cannot /// have explicit member offset. This means that we cannot use the same `Foo` struct /// inside a `StructuredBuffer<Foo>` and also use it to declare a local variable. 
+/// - In Metal, if we have a `struct Foo {Texture2D member; }` and +/// `ParameterBlock<Foo>`, then we should translate it to +/// `struct Foo_pb { Texture2D.Handle member; }` and `ParameterBlock<Foo_pb>`, so that +/// the resource legalization pass won't hoist the texture out of the parameter block. /// /// We use the terms "physical", "storage", or "lowered" types to refer to types that /// are legal to use as buffer elements. In contrast, the terms "original" or "logical" @@ -187,6 +190,11 @@ /// m_ptr = FieldAddr(ptr, member) /// call f_1, m_ptr /// ``` +/// Note that it is only correct to defer a load/CastStorageToLogicalDeref if the location +/// being loaded from is immutable. Otherwise, we might be changing the order of memory +/// operations and result in a change in application behavior. So this pass will also make sure +/// that we only create `CastStorageToLogicalDeref(x)` such that `x` is an immutable location, +/// such as an immutable temporary variable. /// /// # Trailing Pointer Rewrite /// @@ -1192,10 +1200,7 @@ struct LoweredElementTypeContext // and push the cast to inside the callee. // We will process calls after other gep insts, so for now just add // it into a separate worklist. 
- if (castInst->getOp() == kIROp_CastStorageToLogical) - { - callWorkListSet.add((IRCall*)user); - } + callWorkListSet.add((IRCall*)user); break; } case kIROp_Load: @@ -1261,7 +1266,8 @@ struct LoweredElementTypeContext castInst->getBufferType()); user->replaceUsesWith(newCast); user->removeAndDeallocate(); - castInstWorkList.add(newCast); + if (auto newCastStorage = as<IRCastStorageToLogicalBase>(newCast)) + castInstWorkList.add(newCastStorage); break; } case kIROp_FieldExtract: @@ -1353,6 +1359,16 @@ struct LoweredElementTypeContext paramTypes.add(storagePtrType); newArgs.add(castArg->getOperand(0)); } + else if (auto castArgDeref = as<IRCastStorageToLogicalDeref>(arg)) + { + auto storageValueType = tryGetPointedToOrBufferElementType( + &builder, + castArgDeref->getOperand(0)->getDataType()); + auto storagePtrType = + builder.getBorrowInParamType(storageValueType, AddressSpace::Generic); + paramTypes.add(storagePtrType); + newArgs.add(castArgDeref->getOperand(0)); + } else { paramTypes.add(arg->getDataType()); @@ -1422,7 +1438,7 @@ struct LoweredElementTypeContext auto param = params[i]; SLANG_RELEASE_ASSERT(i < call->getArgCount()); auto arg = call->getArg(i); - auto cast = as<IRCastStorageToLogical>(arg); + auto cast = as<IRCastStorageToLogicalBase>(arg); if (!cast) continue; auto logicalParamType = param->getFullType(); @@ -1434,8 +1450,21 @@ struct LoweredElementTypeContext uses.clear(); for (auto use = param->firstUse; use; use = use->nextUse) uses.add(use); - auto castedParam = - builder.emitCastStorageToLogical(logicalParamType, param, cast->getBufferType()); + IRInst* castedParam = nullptr; + if (arg->getOp() == kIROp_CastStorageToLogical) + { + castedParam = builder.emitCastStorageToLogical( + logicalParamType, + param, + cast->getBufferType()); + } + else + { + castedParam = builder.emitCastStorageToLogicalDeref( + logicalParamType, + param, + cast->getBufferType()); + } if (auto castStorage = as<IRCastStorageToLogicalBase>(castedParam)) 
outNewCasts.add(castStorage); diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp index c2bc4d14e..1f50a7579 100644 --- a/source/slang/slang-ir-specialize-arrays.cpp +++ b/source/slang/slang-ir-specialize-arrays.cpp @@ -29,20 +29,23 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition if (auto outTypeBase = as<IROutParamTypeBase>(paramType)) { paramType = outTypeBase->getValueType(); - SLANG_ASSERT(as<IRPtrTypeBase>(argType)); - argType = as<IRPtrTypeBase>(argType)->getValueType(); + IRBuilder builder(paramType); + argType = tryGetPointedToType(&builder, argType); + SLANG_ASSERT(argType); } else if (auto refType = as<IRRefParamType>(paramType)) { paramType = refType->getValueType(); - SLANG_ASSERT(as<IRPtrTypeBase>(argType)); - argType = as<IRPtrTypeBase>(argType)->getValueType(); + IRBuilder builder(paramType); + argType = tryGetPointedToType(&builder, argType); + SLANG_ASSERT(argType); } else if (auto constRefType = as<IRBorrowInParamType>(paramType)) { paramType = constRefType->getValueType(); - SLANG_ASSERT(as<IRPtrTypeBase>(argType)); - argType = as<IRPtrTypeBase>(argType)->getValueType(); + IRBuilder builder(paramType); + argType = tryGetPointedToType(&builder, argType); + SLANG_ASSERT(argType); } auto arrayType = as<IRUnsizedArrayType>(paramType); if (!arrayType) diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 51f16e603..571e6311f 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -4,6 +4,8 @@ #include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-varying-params.h" +#include "slang-ir-specialize-address-space.h" +#include "slang-ir-util.h" #include "slang-ir.h" namespace Slang @@ -225,4 +227,91 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) 
GlobalInstInliningContext().inlineGlobalValuesAndRemoveIfUnused(module); } +struct WGSLAddressSpaceAssigner : InitialAddressSpaceAssigner +{ + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override + { + switch (inst->getOp()) + { + case kIROp_Var: + if (as<IRBlock>(inst->getParent())) + outAddressSpace = AddressSpace::Function; + else + outAddressSpace = AddressSpace::ThreadLocal; + return true; + case kIROp_RWStructuredBufferGetElementPtr: + outAddressSpace = AddressSpace::Global; + return true; + case kIROp_Load: + { + auto addrSpace = getAddressSpaceFromVarType(inst->getDataType()); + if (addrSpace != AddressSpace::Generic) + { + outAddressSpace = addrSpace; + return true; + } + } + return false; + default: + return false; + } + } + + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override + { + if (as<IRUniformParameterGroupType>(type)) + { + return AddressSpace::Uniform; + } + if (as<IRByteAddressBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRHLSLStructuredBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRGLSLShaderStorageBufferType>(type)) + { + return AddressSpace::Global; + } + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return ptrType->getAddressSpace(); + return AddressSpace::Generic; + } + return AddressSpace::Generic; + } + + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override + { + if (as<IRGroupSharedRate>(inst->getRate())) + return AddressSpace::GroupShared; + switch (inst->getOp()) + { + case kIROp_RWStructuredBufferGetElementPtr: + return AddressSpace::Global; + case kIROp_Var: + if (as<IRBlock>(inst->getParent())) + return AddressSpace::Function; + else + return AddressSpace::ThreadLocal; + break; + default: + break; + } + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + return getAddressSpaceFromVarType(type); + } +}; + +void 
specializeAddressSpaceForWGSL(IRModule* module) +{ + WGSLAddressSpaceAssigner wgslAddressSpaceAssigner; + specializeAddressSpace(module, &wgslAddressSpaceAssigner); +} + } // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h index 11e25ea88..b9b270ba3 100644 --- a/source/slang/slang-ir-wgsl-legalize.h +++ b/source/slang/slang-ir-wgsl-legalize.h @@ -7,4 +7,7 @@ namespace Slang class DiagnosticSink; void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink); + +void specializeAddressSpaceForWGSL(IRModule* module); + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ebaebcc8a..ba8864684 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6187,14 +6187,12 @@ IRInst* IRBuilder::emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* b return (IRCastStorageToLogical*)emitIntrinsicInst(type, kIROp_CastStorageToLogical, 2, args); } -IRCastStorageToLogicalDeref* IRBuilder::emitCastStorageToLogicalDeref( - IRType* type, - IRInst* val, - IRInst* bufferType) +IRInst* IRBuilder::emitCastStorageToLogicalDeref(IRType* type, IRInst* val, IRInst* bufferType) { IRInst* args[] = {val, bufferType}; - return (IRCastStorageToLogicalDeref*) - emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args); + if (type == tryGetPointedToType(this, val->getDataType())) + return emitLoad(type, val); + return emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args); } IRGlobalConstant* IRBuilder::emitGlobalConstant(IRType* type) |
