diff options
| author | Yong He <yonghe@outlook.com> | 2025-09-29 17:45:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-30 00:45:08 +0000 |
| commit | a6deb5ed82cb8fc6b4f4c5c5fee264e09f97ff89 (patch) | |
| tree | 1c374bd52498cad2e142e3c7f5482fd42dca966f /source/slang/slang-ir-transform-params-to-constref.cpp | |
| parent | 2827c94de5901cac42a67f73a78ab2548771b28c (diff) | |
Rewriting the lower-buffer-element-type pass to avoid unnecessary packing/unpacking. (#8526)
Part of the effort to improve the performance of generated SPIRV code.
The existing lower-buffer-element-type pass works by loading the entire
buffer element content from memory, and translate it to logical type
stored in a local variable at the earliest reference of a buffer handle.
This means that is can generate inefficient code that reads more than
necessary.
Consider this example:
```
struct BigStruct { bool values[1024]; }
ConstantBuffer<BigStruct> cb;
void test(BigStruct v)
{
if (v.values[0]) { printf("ok"); }
}
[numthreads(1,1,1)]
void computeMain()
{
test(cb);
}
```
In IR, the `computeMain` function before lower-buffer-element-type pass
is something like following:
```
func test:
%v = param : BigStruct
%barr = fieldExtract(%v, "values")
%element = elementExtract(%barr, 0)
... // uses %element
func computeMain:
%v = load(cb)
call %test %v
```
The existing lower-buffer-element-type pass will rewrite the bool array
in `BigStruct` into `int` array so it is legal in SPIRV. However, it
does so by inserting the translation on the first `load` of the constant
buffer:
```
struct BigStruct_std430 {
int values[1024];
}
var cb : ConstantBuffer<BigStruct_std430>;
func computeMain:
%tmpVar : var<BigStruct>
call %unpackStorage(%tmpVar, cb)
%v : BigStruct = load %tmpVar
call %test %v
```
This means that the entire array will be loaded and translated to int,
before calling `test`, which only uses one element. It turns out that
the downstream compiler isn't always able to optimize out this
inefficient translation/copy.
This PR completely rewrites the way buffer-element-type lowering is
handled to avoid producing this inefficient code. It works in two parts:
first we turn on the `transformParamsToConstRef` pass for SPIRV target
as well, so we will translate the `test` function to take the `v`
parameter as `constref`. The second part is a redesigned
buffer-element-type pass that defers the storage-type to logical-type
translation until a value is actually used by a `load` instruction.
In this example, after `transformParamsToConstRef`, the IR is:
```
func test:
%v = param : ConstRef<BigStruct>
%barr = fieldAddr(%v, "values")
%elementPtr = elementAddr(%barr, 0)
%element = load(%elementPtr)
... // uses %element
func computeMain:
call %test %cb
```
The new `buffer-element-type-lowering` pass will take this IR, and
insert translation at latest possible time across the entire call graph,
and translate the IR into:
```
func test:
%v = param : ConstRef<BigStruct_std430>
%barr = fieldAddr(%v, "values")
%elementPtr : ptr<int> = elementAddr(%barr, 0)
%element_int = load(%elementPtr)
%element = cast(%element_int) : %bool
... // uses %element
func computeMain:
call %test %cb
```
In this new IR, there is no longer a load and conversion of the entire
array.
See new comment in `slang-ir-lower-buffer-element-type.cpp` for more
details of how the pass works.
This PR also address many other issues surfaced by turning on
`transformParamsToConstRef` pass on SPIRV backend.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ir-transform-params-to-constref.cpp')
| -rw-r--r-- | source/slang/slang-ir-transform-params-to-constref.cpp | 142 |
1 files changed, 116 insertions, 26 deletions
diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp index 9328a1de1..d34b3d25b 100644 --- a/source/slang/slang-ir-transform-params-to-constref.cpp +++ b/source/slang/slang-ir-transform-params-to-constref.cpp @@ -31,10 +31,7 @@ struct TransformParamsToConstRefContext case kIROp_StructType: case kIROp_ArrayType: case kIROp_UnsizedArrayType: - case kIROp_VectorType: - case kIROp_MatrixType: case kIROp_TupleType: - case kIROp_CoopVectorType: // valid type, continue to check break; default: @@ -44,58 +41,121 @@ struct TransformParamsToConstRefContext return true; } - void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams) + void rewriteValueUsesToAddrUses(IRInst* newAddrInst) { - // Traverse the uses of our updated params to rewrite them. - // Assume a `in` parameter has been converted to a `constref` parameter. - for (auto param : updatedParams) + HashSet<IRInst*> workListSet; + workListSet.add(newAddrInst); + List<IRInst*> workList; + workList.add(newAddrInst); + auto _addToWorkList = [&](IRInst* inst) + { + if (workListSet.add(inst)) + workList.add(inst); + }; + for (Index i = 0; i < workList.getCount(); i++) { + auto inst = workList[i]; traverseUses( - param, + inst, [&](IRUse* use) { auto user = use->getUser(); + if (workListSet.contains(user)) + return; switch (user->getOp()) { case kIROp_FieldExtract: { // Transform the IRFieldExtract into a IRFieldAddress + if (!isUseBaseAddrOperand(use, user)) + break; auto fieldExtract = as<IRFieldExtract>(user); builder.setInsertBefore(fieldExtract); auto fieldAddr = builder.emitFieldAddress( fieldExtract->getBase(), fieldExtract->getField()); - auto loadInst = builder.emitLoad(fieldAddr); - fieldExtract->replaceUsesWith(loadInst); - fieldExtract->removeAndDeallocate(); - break; + fieldExtract->replaceUsesWith(fieldAddr); + _addToWorkList(fieldAddr); + return; } case kIROp_GetElement: { // Transform the IRGetElement into a IRGetElementPtr + if (!isUseBaseAddrOperand(use, user)) + break; auto getElement = as<IRGetElement>(user); builder.setInsertBefore(getElement); auto elemAddr = builder.emitElementAddress( getElement->getBase(), getElement->getIndex()); - auto loadInst = builder.emitLoad(elemAddr); - getElement->replaceUsesWith(loadInst); - getElement->removeAndDeallocate(); - break; + getElement->replaceUsesWith(elemAddr); + _addToWorkList(elemAddr); + return; } - default: + case kIROp_Store: { - // Insert a load before the user and replace the user with the load - builder.setInsertBefore(user); - auto loadInst = builder.emitLoad(param); - use->set(loadInst); + // If the current value is being stored into a write-once temp var that + // is immediately passed into a constref location in a call, we can get + // rid of the temp var and replace it with `inst` directly. + // (such temp var can be introduced during `updateCallSites` when we + // were processing the callee.) + // + auto dest = as<IRStore>(user)->getPtr(); + if (dest->findDecorationImpl(kIROp_TempCallArgImmutableVarDecoration)) + { + user->removeAndDeallocate(); + dest->replaceUsesWith(inst); + dest->removeAndDeallocate(); + return; + } break; } } + // Insert a load before the user and replace the user with the load + builder.setInsertBefore(user); + auto loadInst = builder.emitLoad(inst); + use->set(loadInst); }); } } + void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams) + { + // Traverse the uses of our updated params to rewrite them. + // Assume a `in` parameter has been converted to a `constref` parameter. + for (auto param : updatedParams) + { + rewriteValueUsesToAddrUses(param); + } + } + + // Check if `load` is an `IRLoad(addr)` where `addr` is a immutable location. + IRInst* isLoadFromImmutableAddress(IRInst* load) + { + if (load->getOp() != kIROp_Load) + return nullptr; + auto addr = load->getOperand(0); + auto root = getRootAddr(addr); + if (!root) + return nullptr; + if (!root->getDataType()) + return nullptr; + switch (root->getDataType()->getOp()) + { + case kIROp_ConstantBufferType: + case kIROp_ConstRefType: + case kIROp_ParameterBlockType: + return addr; + default: + // Note that we should in general not assume a read-only StructuredBuffer or + // a pointer with read-only access as an immutable location due to potential aliasing. + // We could introduce a compiler flag to turn on optimizations on these buffer types + // assuming there is no aliasing. + break; + } + return nullptr; + } + // Update call sites to pass an address instead of value for each updated-param void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams) { @@ -119,10 +179,19 @@ struct TransformParamsToConstRefContext newArgs.add(arg); continue; } - - auto tempVar = builder.emitVar(arg->getFullType()); - builder.emitStore(tempVar, arg); - newArgs.add(tempVar); + if (auto addr = isLoadFromImmutableAddress(arg)) + { + // If existing argument is a load from an immutable buffer address, + // we can pass in the address as is, without making a temporary copy. + newArgs.add(addr); + } + else + { + auto tempVar = builder.emitVar(arg->getFullType()); + builder.addDecoration(tempVar, kIROp_TempCallArgImmutableVarDecoration); + builder.emitStore(tempVar, arg); + newArgs.add(tempVar); + } } // Create new call with updated arguments @@ -177,6 +246,25 @@ struct TransformParamsToConstRefContext { HashSet<IRParam*> updatedParams; + // If the function is used in any way that is not understood by the + // compiler, do not modify it. + // For example, if the function is used as callback, we must preserve + // its signature. + for (auto use = func->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (as<IRDecoration>(user)) + continue; + if (auto call = as<IRCall>(user)) + { + if (call->getCalleeUse() == use) + continue; + } + // If we reach here, we encountered a non-call use of the func, + // we will stop processing. + return; + } + // First pass: Transform parameter types for (auto param = func->getFirstParam(); param; param = param->getNextParam()) { @@ -192,7 +280,7 @@ struct TransformParamsToConstRefContext // This allows us to pass the address of variables directly into a function, // giving us the choice to remove copies into a parameter. auto paramType = param->getDataType(); - auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal); + auto constRefType = builder.getConstRefType(paramType, AddressSpace::Generic); param->setFullType(constRefType); changed = true; @@ -205,6 +293,8 @@ struct TransformParamsToConstRefContext return; } + fixUpFuncType(func); + // Second pass: Update function body according to the new `constref` parameters rewriteParamUseSitesToSupportConstRefUsage(updatedParams); |
