From 5fa35fcce532267a2ae5779dee9ff4d07fab6bf4 Mon Sep 17 00:00:00 2001 From: Anders Leino Date: Fri, 11 Oct 2024 10:05:15 +0300 Subject: WGSL: Enable load & store from byte-addressible buffers (#5252) --- source/slang/hlsl.meta.slang | 10 ++++----- source/slang/slang-emit-wgsl.cpp | 40 +++++++++++++++++++++++++++++++-- tests/compute/byte-address-buffer.slang | 1 - 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 907822060..cf3f25e8f 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -107,7 +107,7 @@ struct AppendStructuredBuffer /// @category buffer_types __magic_type(HLSLByteAddressBufferType) __intrinsic_type($(kIROp_HLSLByteAddressBufferType)) -[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer)] struct ByteAddressBuffer { [__readNone] @@ -4388,15 +4388,15 @@ uint64_t __asuint64(uint2 i) // __intrinsic_op($(kIROp_ByteAddressBufferLoad)) -[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer)] T __byteAddressBufferLoad(ByteAddressBuffer buffer, int offset, int alignment); __intrinsic_op($(kIROp_ByteAddressBufferLoad)) -[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)] T __byteAddressBufferLoad(RWByteAddressBuffer buffer, int offset, int alignment); __intrinsic_op($(kIROp_ByteAddressBufferLoad)) -[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)] T __byteAddressBufferLoad(RasterizerOrderedByteAddressBuffer buffer, int offset, int alignment); __intrinsic_op($(kIROp_ByteAddressBufferStore)) @@ -4583,7 +4583,7 @@ struct $(item.name) [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)] uint Load(int location) { __target_switch diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index a05e03afc..b1a723dc5 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -348,6 +348,13 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type) } break; + case kIROp_HLSLByteAddressBufferType: + case kIROp_HLSLRWByteAddressBufferType: + { + m_writer->emit("array"); + } + break; + case kIROp_VoidType: { // There is no void type in WGSL. @@ -590,13 +597,15 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl) m_writer->emit(""); } else if (type->getOp() == kIROp_HLSLRWStructuredBufferType || - type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType) + type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType || + type->getOp() == kIROp_HLSLRWByteAddressBufferType) { m_writer->emit("<"); m_writer->emit("storage, read_write"); m_writer->emit(">"); } - else if (type->getOp() == kIROp_HLSLStructuredBufferType) + else if (type->getOp() == kIROp_HLSLStructuredBufferType || + type->getOp() == kIROp_HLSLByteAddressBufferType) { m_writer->emit("<"); m_writer->emit("storage, read"); @@ -1178,6 +1187,33 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } } break; + + case kIROp_ByteAddressBufferLoad: + { + // Indices in Slang code count bytes, but in WASM they count u32's since + // byte address buffers translate to array in WASM, so divide by 4. + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("[("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")/4]"); + return true; + } + break; + + case kIROp_ByteAddressBufferStore: + { + // Indices in Slang code count bytes, but in WASM they count u32's since + // byte address buffers translate to array in WASM, so divide by 4. + auto base = inst->getOperand(0); + emitOperand(base, EmitOpInfo()); + m_writer->emit("[("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")/4] = "); + emitOperand(inst->getOperand(inst->getOperandCount() - 1), getInfo(EmitOp::General)); + return true; + } + break; + } return false; diff --git a/tests/compute/byte-address-buffer.slang b/tests/compute/byte-address-buffer.slang index 77748eb60..65356ec22 100644 --- a/tests/compute/byte-address-buffer.slang +++ b/tests/compute/byte-address-buffer.slang @@ -4,7 +4,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX:-d3d12 -compute -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -wgpu // Confirm cross-compilation of `(RW)ByteAddressBuffer` // -- cgit v1.2.3