From da951e06e7eb8ad1b9c91d6176be8165ea4f2b45 Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Fri, 16 May 2025 15:01:24 -0400 Subject: Address structured buffer `GetDimensions` issues for WGSL, GLSL and SPIRV (#7010) * Fix structured buffer get dimensions * Further fixes and added tests * Remove unnecessary include * Fix test issues * attempt to fix wgpu crash * test remove half usage in test * attempt to fix WGPU test issue * Another attempt to fix WGSL test - make test similar to the existing GetDimensions test --------- Co-authored-by: Yong He --- source/slang/slang-emit-glsl.cpp | 11 ++--- source/slang/slang-emit-spirv.cpp | 21 ++++++--- .../slang/slang-ir-lower-buffer-element-type.cpp | 14 ++++++ .../get-dimensions-stride-struct.slang | 31 ++++++++++++++ tests/cross-compile/get-dimensions-stride.slang | 50 ++++++++++++++++++++++ 5 files changed, 116 insertions(+), 11 deletions(-) create mode 100644 tests/cross-compile/get-dimensions-stride-struct.slang create mode 100644 tests/cross-compile/get-dimensions-stride.slang diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 116d9b1d6..a8ce6564c 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2790,11 +2790,12 @@ bool GLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) auto elementType = as(inst->getOperand(0)->getDataType()) ->getElementType(); - IRIntegerValue stride = 0; - if (auto sizeDecor = elementType->findDecoration()) - { - stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment()); - } + + // The element type should have a `SizeAndAlignment` decoration created during lowering. + auto sizeDecor = elementType->findDecoration(); + SLANG_ASSERT(sizeDecor); + const auto stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment()); + m_writer->emit(stride); m_writer->emit(");\n"); return true; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 7d202c7c1..54417899c 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -6878,12 +6878,21 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex kResultID, inst->getOperand(0), SpvLiteralInteger::from32(0)); - auto elementType = as(inst->getOperand(0)->getDataType())->getValueType(); - IRIntegerValue stride = 0; - if (auto sizeDecor = elementType->findDecoration()) - { - stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment()); - } + + // The buffer is a global parameter, so it's a pointer + auto bufPtrType = cast(inst->getOperand(0)->getDataType()); + // It's lowered to a struct type.. + auto bufType = cast(bufPtrType->getValueType()); + // containing an unsized array, specifically one with an explicit + // stride, which is not expressible in spirv_asm blocks + auto arrayType = cast(bufType->getFields().getFirst()->getFieldType()); + + // The element type should have a `SizeAndAlignment` decoration created during lowering. + auto sizeDecor = + arrayType->getElementType()->findDecoration(); + SLANG_ASSERT(sizeDecor); + const auto stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment()); + auto strideOperand = emitIntConstant(stride, builder.getUIntType()); auto result = emitOpCompositeConstruct(parent, inst, inst->getDataType(), arrayLength, strideOperand); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 1294b400d..ed0a3b309 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -939,11 +939,25 @@ struct LoweredElementTypeContext } } if (auto structBuffer = as(globalInst)) + { elementType = structBuffer->getElementType(); + auto config = getTypeLoweringConfigForBuffer(target, structBuffer); + + // Create size and alignment decoration for potential use + // in`StructuredBufferGetDimensions`. + IRSizeAndAlignment sizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + elementType, + &sizeAlignment); + SLANG_UNUSED(sizeAlignment); + } else if (auto constBuffer = as(globalInst)) elementType = constBuffer->getElementType(); else if (auto storageBuffer = as(globalInst)) elementType = storageBuffer->getElementType(); + if (as(globalInst)) continue; if (!as(elementType) && !as(elementType) && diff --git a/tests/cross-compile/get-dimensions-stride-struct.slang b/tests/cross-compile/get-dimensions-stride-struct.slang new file mode 100644 index 000000000..3af888589 --- /dev/null +++ b/tests/cross-compile/get-dimensions-stride-struct.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-d3d12 -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-d3d11 -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-wgpu -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +struct MS +{ + uint a; + uint b; +} + +//TEST_INPUT:ubuffer(data=[7 2 9 53], stride=8):name buffer0 +RWStructuredBuffer buffer0; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int index = int(dispatchThreadID.x); + uint count = 0; + uint stride = 0; + + // CHECK: 8 + buffer0.GetDimensions(count, stride); + + outputBuffer[index] = int(stride); +} diff --git a/tests/cross-compile/get-dimensions-stride.slang b/tests/cross-compile/get-dimensions-stride.slang new file mode 100644 index 000000000..4aa57a325 --- /dev/null +++ b/tests/cross-compile/get-dimensions-stride.slang @@ -0,0 +1,50 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-d3d12 -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-wgpu -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +//TEST_INPUT:ubuffer(data=[7 2 9 53], stride=4):name buffer0 +RWStructuredBuffer buffer0; + +//TEST_INPUT:ubuffer(data=[23 2], stride=4):name buffer1 +RWStructuredBuffer buffer1; + +//TEST_INPUT:ubuffer(data=[-10 17 9 4 2 0], stride=4):name buffer2 +RWStructuredBuffer buffer2; + +//TEST_INPUT:ubuffer(data=[-10 17 9 4 2 0], stride=4):name buffer3 +RWStructuredBuffer buffer3; + +[shader("compute")] +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int index = int(dispatchThreadID.x); + uint count = 0; + uint stride = 0; + + if (index == 0) + { + // CHECK: 4 + buffer0.GetDimensions(count, stride); + } + else if (index == 1) + { + // CHECK: 4 + buffer1.GetDimensions(count, stride); + } + else if (index == 2) + { + // CHECK: 4 + buffer2.GetDimensions(count, stride); + } + else if (index == 3) + { + // CHECK: 4 + buffer3.GetDimensions(count, stride); + } + + outputBuffer[index] = int(stride); +} -- cgit v1.2.3