summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-transform-params-to-constref.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-09-29 17:45:08 -0700
committerGitHub <noreply@github.com>2025-09-30 00:45:08 +0000
commita6deb5ed82cb8fc6b4f4c5c5fee264e09f97ff89 (patch)
tree1c374bd52498cad2e142e3c7f5482fd42dca966f /source/slang/slang-ir-transform-params-to-constref.cpp
parent2827c94de5901cac42a67f73a78ab2548771b28c (diff)
Rewriting the lower-buffer-element-type pass to avoid unnecessary packing/unpacking. (#8526)
Part of the effort to improve the performance of generated SPIRV code. The existing lower-buffer-element-type pass works by loading the entire buffer element content from memory, and translate it to logical type stored in a local variable at the earliest reference of a buffer handle. This means that is can generate inefficient code that reads more than necessary. Consider this example: ``` struct BigStruct { bool values[1024]; } ConstantBuffer<BigStruct> cb; void test(BigStruct v) { if (v.values[0]) { printf("ok"); } } [numthreads(1,1,1)] void computeMain() { test(cb); } ``` In IR, the `computeMain` function before lower-buffer-element-type pass is something like following: ``` func test: %v = param : BigStruct %barr = fieldExtract(%v, "values") %element = elementExtract(%barr, 0) ... // uses %element func computeMain: %v = load(cb) call %test %v ``` The existing lower-buffer-element-type pass will rewrite the bool array in `BigStruct` into `int` array so it is legal in SPIRV. However, it does so by inserting the translation on the first `load` of the constant buffer: ``` struct BigStruct_std430 { int values[1024]; } var cb : ConstantBuffer<BigStruct_std430>; func computeMain: %tmpVar : var<BigStruct> call %unpackStorage(%tmpVar, cb) %v : BigStruct = load %tmpVar call %test %v ``` This means that the entire array will be loaded and translated to int, before calling `test`, which only uses one element. It turns out that the downstream compiler isn't always able to optimize out this inefficient translation/copy. This PR completely rewrites the way buffer-element-type lowering is handled to avoid producing this inefficient code. It works in two parts: first we turn on the `transformParamsToConstRef` pass for SPIRV target as well, so we will translate the `test` function to take the `v` parameter as `constref`. The second part is a redesigned buffer-element-type pass that defers the storage-type to logical-type translation until a value is actually used by a `load` instruction. In this example, after `transformParamsToConstRef`, the IR is: ``` func test: %v = param : ConstRef<BigStruct> %barr = fieldAddr(%v, "values") %elementPtr = elementAddr(%barr, 0) %element = load(%elementPtr) ... // uses %element func computeMain: call %test %cb ``` The new `buffer-element-type-lowering` pass will take this IR, and insert translation at latest possible time across the entire call graph, and translate the IR into: ``` func test: %v = param : ConstRef<BigStruct_std430> %barr = fieldAddr(%v, "values") %elementPtr : ptr<int> = elementAddr(%barr, 0) %element_int = load(%elementPtr) %element = cast(%element_int) : %bool ... // uses %element func computeMain: call %test %cb ``` In this new IR, there is no longer a load and conversion of the entire array. See new comment in `slang-ir-lower-buffer-element-type.cpp` for more details of how the pass works. This PR also address many other issues surfaced by turning on `transformParamsToConstRef` pass on SPIRV backend. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ir-transform-params-to-constref.cpp')
-rw-r--r--source/slang/slang-ir-transform-params-to-constref.cpp142
1 files changed, 116 insertions, 26 deletions
diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp
index 9328a1de1..d34b3d25b 100644
--- a/source/slang/slang-ir-transform-params-to-constref.cpp
+++ b/source/slang/slang-ir-transform-params-to-constref.cpp
@@ -31,10 +31,7 @@ struct TransformParamsToConstRefContext
case kIROp_StructType:
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
- case kIROp_VectorType:
- case kIROp_MatrixType:
case kIROp_TupleType:
- case kIROp_CoopVectorType:
// valid type, continue to check
break;
default:
@@ -44,58 +41,121 @@ struct TransformParamsToConstRefContext
return true;
}
- void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ void rewriteValueUsesToAddrUses(IRInst* newAddrInst)
{
- // Traverse the uses of our updated params to rewrite them.
- // Assume a `in` parameter has been converted to a `constref` parameter.
- for (auto param : updatedParams)
+ HashSet<IRInst*> workListSet;
+ workListSet.add(newAddrInst);
+ List<IRInst*> workList;
+ workList.add(newAddrInst);
+ auto _addToWorkList = [&](IRInst* inst)
+ {
+ if (workListSet.add(inst))
+ workList.add(inst);
+ };
+ for (Index i = 0; i < workList.getCount(); i++)
{
+ auto inst = workList[i];
traverseUses(
- param,
+ inst,
[&](IRUse* use)
{
auto user = use->getUser();
+ if (workListSet.contains(user))
+ return;
switch (user->getOp())
{
case kIROp_FieldExtract:
{
// Transform the IRFieldExtract into a IRFieldAddress
+ if (!isUseBaseAddrOperand(use, user))
+ break;
auto fieldExtract = as<IRFieldExtract>(user);
builder.setInsertBefore(fieldExtract);
auto fieldAddr = builder.emitFieldAddress(
fieldExtract->getBase(),
fieldExtract->getField());
- auto loadInst = builder.emitLoad(fieldAddr);
- fieldExtract->replaceUsesWith(loadInst);
- fieldExtract->removeAndDeallocate();
- break;
+ fieldExtract->replaceUsesWith(fieldAddr);
+ _addToWorkList(fieldAddr);
+ return;
}
case kIROp_GetElement:
{
// Transform the IRGetElement into a IRGetElementPtr
+ if (!isUseBaseAddrOperand(use, user))
+ break;
auto getElement = as<IRGetElement>(user);
builder.setInsertBefore(getElement);
auto elemAddr = builder.emitElementAddress(
getElement->getBase(),
getElement->getIndex());
- auto loadInst = builder.emitLoad(elemAddr);
- getElement->replaceUsesWith(loadInst);
- getElement->removeAndDeallocate();
- break;
+ getElement->replaceUsesWith(elemAddr);
+ _addToWorkList(elemAddr);
+ return;
}
- default:
+ case kIROp_Store:
{
- // Insert a load before the user and replace the user with the load
- builder.setInsertBefore(user);
- auto loadInst = builder.emitLoad(param);
- use->set(loadInst);
+ // If the current value is being stored into a write-once temp var that
+ // is immediately passed into a constref location in a call, we can get
+ // rid of the temp var and replace it with `inst` directly.
+ // (such temp var can be introduced during `updateCallSites` when we
+ // were processing the callee.)
+ //
+ auto dest = as<IRStore>(user)->getPtr();
+ if (dest->findDecorationImpl(kIROp_TempCallArgImmutableVarDecoration))
+ {
+ user->removeAndDeallocate();
+ dest->replaceUsesWith(inst);
+ dest->removeAndDeallocate();
+ return;
+ }
break;
}
}
+ // Insert a load before the user and replace the user with the load
+ builder.setInsertBefore(user);
+ auto loadInst = builder.emitLoad(inst);
+ use->set(loadInst);
});
}
}
+ void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
+ {
+ // Traverse the uses of our updated params to rewrite them.
+ // Assume a `in` parameter has been converted to a `constref` parameter.
+ for (auto param : updatedParams)
+ {
+ rewriteValueUsesToAddrUses(param);
+ }
+ }
+
+ // Check if `load` is an `IRLoad(addr)` where `addr` is a immutable location.
+ IRInst* isLoadFromImmutableAddress(IRInst* load)
+ {
+ if (load->getOp() != kIROp_Load)
+ return nullptr;
+ auto addr = load->getOperand(0);
+ auto root = getRootAddr(addr);
+ if (!root)
+ return nullptr;
+ if (!root->getDataType())
+ return nullptr;
+ switch (root->getDataType()->getOp())
+ {
+ case kIROp_ConstantBufferType:
+ case kIROp_ConstRefType:
+ case kIROp_ParameterBlockType:
+ return addr;
+ default:
+ // Note that we should in general not assume a read-only StructuredBuffer or
+ // a pointer with read-only access as an immutable location due to potential aliasing.
+ // We could introduce a compiler flag to turn on optimizations on these buffer types
+ // assuming there is no aliasing.
+ break;
+ }
+ return nullptr;
+ }
+
// Update call sites to pass an address instead of value for each updated-param
void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams)
{
@@ -119,10 +179,19 @@ struct TransformParamsToConstRefContext
newArgs.add(arg);
continue;
}
-
- auto tempVar = builder.emitVar(arg->getFullType());
- builder.emitStore(tempVar, arg);
- newArgs.add(tempVar);
+ if (auto addr = isLoadFromImmutableAddress(arg))
+ {
+ // If existing argument is a load from an immutable buffer address,
+ // we can pass in the address as is, without making a temporary copy.
+ newArgs.add(addr);
+ }
+ else
+ {
+ auto tempVar = builder.emitVar(arg->getFullType());
+ builder.addDecoration(tempVar, kIROp_TempCallArgImmutableVarDecoration);
+ builder.emitStore(tempVar, arg);
+ newArgs.add(tempVar);
+ }
}
// Create new call with updated arguments
@@ -177,6 +246,25 @@ struct TransformParamsToConstRefContext
{
HashSet<IRParam*> updatedParams;
+ // If the function is used in any way that is not understood by the
+ // compiler, do not modify it.
+ // For example, if the function is used as callback, we must preserve
+ // its signature.
+ for (auto use = func->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (as<IRDecoration>(user))
+ continue;
+ if (auto call = as<IRCall>(user))
+ {
+ if (call->getCalleeUse() == use)
+ continue;
+ }
+ // If we reach here, we encountered a non-call use of the func,
+ // we will stop processing.
+ return;
+ }
+
// First pass: Transform parameter types
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
{
@@ -192,7 +280,7 @@ struct TransformParamsToConstRefContext
// This allows us to pass the address of variables directly into a function,
// giving us the choice to remove copies into a parameter.
auto paramType = param->getDataType();
- auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
+ auto constRefType = builder.getConstRefType(paramType, AddressSpace::Generic);
param->setFullType(constRefType);
changed = true;
@@ -205,6 +293,8 @@ struct TransformParamsToConstRefContext
return;
}
+ fixUpFuncType(func);
+
// Second pass: Update function body according to the new `constref` parameters
rewriteParamUseSitesToSupportConstRefUsage(updatedParams);