diff options
| author | Yong He <yonghe@outlook.com> | 2025-09-30 19:08:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-30 19:08:23 -0700 |
| commit | e4611e2e30a3e5969d402f5ed7e72706a0e3b024 (patch) | |
| tree | 0f4240ccf8c4f0786949ab33adb0fcc332890d11 /source/slang/slang-ir-specialize-address-space.cpp | |
| parent | b6422e50cb19f7f790f29678ba22f31b0b305511 (diff) | |
Enhance buffer load specialization pass to specialize past field extracts. (#8547)
This allows us to specialize functions whose argument is a sub-element
of a constant buffer, instead of only functions whose argument is an
entire buffer element. Closes #8421.
This change also implements a proper heuristic to determine when to
specialize the calls and defer the buffer loads.
This PR addresses a pathological case exposed in
`slangpy\slangpy\benchmarks\test_benchmark_tensor.py`, which used to
take 27ms to finish, and now takes 1.25ms.
For example, given:
```
struct Bottom
{
float bigArray[1024];
[mutating]
void setVal(int index, float value) { bigArray[index] = value; }
}
struct Root
{
Bottom top[2];
[mutating]
void setTopVal(int x, int y, float value)
{
top[x].setVal(y, value);
}
}
RWStructuredBuffer<Root> sb;
[shader("compute")]
[numthreads(1, 1, 1)]
void compute_main(uint3 tid: SV_DispatchThreadID)
{
sb[0].setTopVal(1, 2, 100.0f);
}
```
We are now able to specialize the call to `setTopVal` into:
```
void compute_main(uint3 tid: SV_DispatchThreadID)
{
setTopVal_specialized(0, 1, 2, 100.0f);
}
void setTopVal_specialized(int sbIdx, int x, int y, float value)
{
Bottom_setVal_specialized(sbIdx, x, y, value);
}
void Bottom_setVal_specialized(int sbIdx, int x, int y, float value)
{
sb[sbIdx].top[x].bigArray[y] = value;
}
```
This gets rid of all unnecessary loads. Achieving this requires a
combination of function call specialization and the buffer-load-defer pass.
The buffer-load-defer pass has been completely rewritten to be more
correct and avoid introducing redundant loads.
This PR also adds tests to make sure pointers, bindless handles, and
loads from structured buffers or constant buffers work as expected.
Diffstat (limited to 'source/slang/slang-ir-specialize-address-space.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.cpp | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index c4a155eec..04792bd8b 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -131,7 +131,6 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext bool processFunction(IRFunc* func) { bool retValAddrSpaceChanged = false; - Dictionary<IRInst*, AddressSpace> mapVarValueToAddrSpace; bool changed = true; while (changed) { @@ -152,18 +151,23 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext continue; } - // If the inst already has a pointer type with explicit address space, then use - // it. - if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType())) + // If the inst already has a pointer/pointer-like type with explicit address + // space, then use it. + auto addrSpaceFromType = + addrSpaceAssigner->getAddressSpaceFromVarType(inst->getDataType()); + if (addrSpaceFromType != AddressSpace::Generic) { - if (ptrType->hasAddressSpace()) - { - mapInstToAddrSpace[inst] = ptrType->getAddressSpace(); + mapInstToAddrSpace[inst] = addrSpaceFromType; + changed = true; + + // Don't return early if the inst itself is a call, as we may still need to + // specialize it down below. + if (inst->getOp() != kIROp_Call) continue; - } } - // Otherwise, try to assign an address space based on the instruction type. + // Try to assign an address space based on the instruction type, and specialize + // calls. 
switch (inst->getOp()) { case kIROp_Var: @@ -195,15 +199,6 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext } break; case kIROp_Store: - { - auto addrSpace = getAddrSpace(inst->getOperand(1)); - if (addrSpace != AddressSpace::Generic) - { - mapVarValueToAddrSpace[inst->getOperand(0)] = addrSpace; - mapInstToAddrSpace[inst] = addrSpace; - changed = true; - } - } break; case kIROp_Param: if (!isFirstBlock) @@ -243,8 +238,9 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext for (UInt i = 0; i < callInst->getArgCount(); i++) { auto arg = callInst->getArg(i); - argAddrSpaces.add(getAddrSpace(arg)); - if (as<IRPtrTypeBase>(arg->getDataType())) + auto addrSpace = getAddrSpace(arg); + argAddrSpaces.add(addrSpace); + if (addrSpace != AddressSpace::Generic) { hasSpecializableArg = true; } @@ -477,8 +473,13 @@ void propagateAddressSpaceFromInsts(List<IRInst*>&& workList) } } -AddressSpace NoOpInitialAddressSpaceAssigner::getAddressSpaceFromVarType(IRInst*) +AddressSpace NoOpInitialAddressSpaceAssigner::getAddressSpaceFromVarType(IRInst* type) { + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return ptrType->getAddressSpace(); + } return AddressSpace::Generic; } |
