From a6deb5ed82cb8fc6b4f4c5c5fee264e09f97ff89 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 29 Sep 2025 17:45:08 -0700 Subject: Rewriting the lower-buffer-element-type pass to avoid unnecessary packing/unpacking. (#8526) Part of the effort to improve the performance of generated SPIRV code. The existing lower-buffer-element-type pass works by loading the entire buffer element content from memory, and translate it to logical type stored in a local variable at the earliest reference of a buffer handle. This means that is can generate inefficient code that reads more than necessary. Consider this example: ``` struct BigStruct { bool values[1024]; } ConstantBuffer cb; void test(BigStruct v) { if (v.values[0]) { printf("ok"); } } [numthreads(1,1,1)] void computeMain() { test(cb); } ``` In IR, the `computeMain` function before lower-buffer-element-type pass is something like following: ``` func test: %v = param : BigStruct %barr = fieldExtract(%v, "values") %element = elementExtract(%barr, 0) ... // uses %element func computeMain: %v = load(cb) call %test %v ``` The existing lower-buffer-element-type pass will rewrite the bool array in `BigStruct` into `int` array so it is legal in SPIRV. However, it does so by inserting the translation on the first `load` of the constant buffer: ``` struct BigStruct_std430 { int values[1024]; } var cb : ConstantBuffer; func computeMain: %tmpVar : var call %unpackStorage(%tmpVar, cb) %v : BigStruct = load %tmpVar call %test %v ``` This means that the entire array will be loaded and translated to int, before calling `test`, which only uses one element. It turns out that the downstream compiler isn't always able to optimize out this inefficient translation/copy. This PR completely rewrites the way buffer-element-type lowering is handled to avoid producing this inefficient code. It works in two parts: first we turn on the `transformParamsToConstRef` pass for SPIRV target as well, so we will translate the `test` function to take the `v` parameter as `constref`. The second part is a redesigned buffer-element-type pass that defers the storage-type to logical-type translation until a value is actually used by a `load` instruction. In this example, after `transformParamsToConstRef`, the IR is: ``` func test: %v = param : ConstRef %barr = fieldAddr(%v, "values") %elementPtr = elementAddr(%barr, 0) %element = load(%elementPtr) ... // uses %element func computeMain: call %test %cb ``` The new `buffer-element-type-lowering` pass will take this IR, and insert translation at latest possible time across the entire call graph, and translate the IR into: ``` func test: %v = param : ConstRef %barr = fieldAddr(%v, "values") %elementPtr : ptr = elementAddr(%barr, 0) %element_int = load(%elementPtr) %element = cast(%element_int) : %bool ... // uses %element func computeMain: call %test %cb ``` In this new IR, there is no longer a load and conversion of the entire array. See new comment in `slang-ir-lower-buffer-element-type.cpp` for more details of how the pass works. This PR also address many other issues surfaced by turning on `transformParamsToConstRef` pass on SPIRV backend. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang/slang-ir.cpp | 69 ++++++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 28 deletions(-) (limited to 'source/slang/slang-ir.cpp') diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e9d1a1199..92543c952 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -386,6 +386,20 @@ IRType* tryGetPointedToType(IRBuilder* builder, IRType* type) return nullptr; } +IRType* tryGetPointedToOrBufferElementType(IRBuilder* builder, IRType* type) +{ + if (auto rateQualType = as(type)) + { + type = rateQualType->getValueType(); + } + auto resultType = tryGetPointedToType(builder, type); + if (resultType) + return resultType; + if (auto structuredBufferType = as(type)) + return structuredBufferType->getElementType(); + return nullptr; +} + // IRBlock @@ -5480,39 +5494,19 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) } IRType* type = nullptr; valueType = unwrapAttributedType(valueType); - if (auto arrayType = as(valueType)) - { - type = arrayType->getElementType(); - } - else if (auto vectorType = as(valueType)) - { - type = vectorType->getElementType(); - } - else if (auto coopVecType = as(valueType)) - { - type = coopVecType->getElementType(); - } - else if (auto matrixType = as(valueType)) - { - type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); - } - else if (auto coopMatType = as(valueType)) - { - type = coopMatType->getElementType(); - } - else if (const auto basicType = as(valueType)) + if (as(valueType)) { // HLSL support things like float.x, in which case we just return the base pointer. return basePtr; } - else if (const auto tupleType = as(valueType)) - { - SLANG_ASSERT(as(index)); - type = (IRType*)tupleType->getOperand(getIntVal(index)); - } - else if (auto hlslInputPatchType = as(valueType)) + type = getElementType(*this, (IRType*)valueType); + if (!type) { - type = hlslInputPatchType->getElementType(); + if (const auto tupleType = as(valueType)) + { + SLANG_ASSERT(as(index)); + type = (IRType*)tupleType->getOperand(getIntVal(index)); + } } SLANG_RELEASE_ASSERT(type); @@ -6179,6 +6173,24 @@ IRInst* IRBuilder::emitCastIntToPtr(IRType* ptrType, IRInst* val) return inst; } +IRInst* IRBuilder::emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType) +{ + if (type == val->getDataType()) + return val; + IRInst* args[] = {val, bufferType}; + return (IRCastStorageToLogical*)emitIntrinsicInst(type, kIROp_CastStorageToLogical, 2, args); +} + +IRCastStorageToLogicalDeref* IRBuilder::emitCastStorageToLogicalDeref( + IRType* type, + IRInst* val, + IRInst* bufferType) +{ + IRInst* args[] = {val, bufferType}; + return (IRCastStorageToLogicalDeref*) + emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args); +} + IRGlobalConstant* IRBuilder::emitGlobalConstant(IRType* type) { auto inst = createInst(this, kIROp_GlobalConstant, type); @@ -6627,6 +6639,7 @@ IRInst* IRBuilder::emitGenericAsm(UnownedStringSlice asmText) IRInst* IRBuilder::emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index) { const auto sbt = cast(structuredBuffer->getDataType()); + SLANG_ASSERT(sbt); const auto t = getPtrType(sbt->getElementType()); IRInst* const operands[2] = {structuredBuffer, index}; const auto i = createInst( -- cgit v1.2.3