diff options
| author | Yong He <yonghe@outlook.com> | 2025-09-30 19:08:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-30 19:08:23 -0700 |
| commit | e4611e2e30a3e5969d402f5ed7e72706a0e3b024 (patch) | |
| tree | 0f4240ccf8c4f0786949ab33adb0fcc332890d11 /source/slang/slang-ir-specialize-buffer-load-arg.cpp | |
| parent | b6422e50cb19f7f790f29678ba22f31b0b305511 (diff) | |
Enhance buffer load specialization pass to specialize past field extracts. (#8547)
This allows us to specialize functions whose argument is a sub-element
of a constant buffer, instead of only functions whose argument is the
entire buffer element. Closes #8421.
This change also implements a proper heuristic to determine when to
specialize the calls and defer the buffer loads.
This PR addresses a pathological case exposed in
`slangpy\slangpy\benchmarks\test_benchmark_tensor.py`, which used to
take 27ms to finish, and now takes 1.25ms.
For example, given:
```
struct Bottom
{
float bigArray[1024];
[mutating]
void setVal(int index, float value) { bigArray[index] = value; }
}
struct Root
{
Bottom top[2];
[mutating]
void setTopVal(int x, int y, float value)
{
top[x].setVal(y, value);
}
}
RWStructuredBuffer<Root> sb;
[shader("compute")]
[numthreads(1, 1, 1)]
void compute_main(uint3 tid: SV_DispatchThreadID)
{
sb[0].setTopVal(1, 2, 100.0f);
}
```
We are now able to specialize the call to `setTopVal` into:
```
void compute_main(uint3 tid: SV_DispatchThreadID)
{
setTopVal_specialized(0, 1, 2, 100.0f);
}
void setTopVal_specialized(int sbIdx, int x, int y, float value)
{
Bottom_setVal_specialized(sbIdx, x, y, value);
}
void Bottom_setVal_specialized(int sbIdx, int x, int y, float value)
{
sb[sbIdx].top[x].bigArray[y] = value;
}
```
This gets rid of all unnecessary loads. Achieving this requires a
combination of function call specialization and the buffer-load-defer pass.
The buffer-load-defer pass has been completely rewritten to be more
correct and avoid introducing redundant loads.
This PR also adds tests to make sure pointers, bindless handles, and
loads from structured buffers or constant buffers work as expected.
Diffstat (limited to 'source/slang/slang-ir-specialize-buffer-load-arg.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-buffer-load-arg.cpp | 124 |
1 files changed, 83 insertions, 41 deletions
diff --git a/source/slang/slang-ir-specialize-buffer-load-arg.cpp b/source/slang/slang-ir-specialize-buffer-load-arg.cpp index 905f2e058..a5a3dd2d9 100644 --- a/source/slang/slang-ir-specialize-buffer-load-arg.cpp +++ b/source/slang/slang-ir-specialize-buffer-load-arg.cpp @@ -1,8 +1,11 @@ // slang-ir-specialize-buffer-load-arg.cpp #include "slang-ir-specialize-buffer-load-arg.h" +#include "slang-ir-defer-buffer-load.h" #include "slang-ir-insts.h" +#include "slang-ir-layout.h" #include "slang-ir-specialize-function-call.h" +#include "slang-ir-util.h" #include "slang-ir.h" namespace Slang @@ -17,76 +20,115 @@ namespace Slang // As swith most of our IR passes, we encapsulate the logic here in a context // type so that the data that needs to be shared throughout the pass can // be conveniently scoped. +// + +// Note that this pass also ensures other more contrived cases are properly +// handled. For example: +// +// * A load of a large structure from field in a constant buffer, so that +// the value loaded is not the entire buffer contents. +// +// * A load of a large structure from a structured buffer, or any other kind +// of buffer that requires an index. +// struct FuncBufferLoadSpecializationCondition : FunctionCallSpecializeCondition { typedef FunctionCallSpecializeCondition Super; - virtual bool doesParamWantSpecialization(IRParam* param, IRInst* arg) + CodeGenContext* codegenContext; + + virtual bool doesParamWantSpecialization(IRParam* param, IRInst* arg, IRCall* callInst) { // We only want to specialize for `struct` types and not base types. // - // TODO: We might want to consider some criteria here for the "large-ness" - // of a structure (in terms of bytes and/or fields), so that we don't - // eliminate loads of sufficiently small types (which are cheap to pass - // by value). 
- // - auto paramType = param->getDataType(); - if (!as<IRStructType>(paramType)) + auto paramType = (IRType*)unwrapAttributedType(param->getDataType()); + if (!isTypePreferrableToDeferLoad(codegenContext, paramType)) return false; - // We also only want to specialize for arguments that are a load - // from some kind of global shader parameter. + // We want to handle loads from arbitrary access chains rooting from a shader parameter. // IRInst* a = arg; - if (auto argLoad = as<IRLoad>(arg)) - { - a = argLoad->getPtr(); - } - else + for (;;) { - return false; - } + // A user pointer can be directly passed into the function, so we no + // longer need to trace up further. + if (isUserPointerType(a->getDataType())) + break; - // We want to handle loads from a shader parameter that is an array - // of buffers, and not just a single global buffer. - // - while (auto argGetElement = as<IRGetElement>(a)) - { - a = argGetElement->getBase(); + if (auto argGetElement = as<IRGetElement>(a)) + { + a = argGetElement->getBase(); + } + else if (auto argSbLoad = as<IRStructuredBufferLoad>(a)) + { + a = argSbLoad->getOperand(0); + } + else if (auto argBbLoad = as<IRByteAddressBufferLoad>(a)) + { + a = argBbLoad->getOperand(0); + } + else if (auto argFieldExtract = as<IRFieldExtract>(a)) + { + a = argFieldExtract->getBase(); + } + else if (auto argGetElementPtr = as<IRGetElementPtr>(a)) + { + a = argGetElementPtr->getBase(); + } + else if (auto argSBGetElementPtr = as<IRRWStructuredBufferGetElementPtr>(a)) + { + a = argSBGetElementPtr->getBase(); + } + else if (auto argFieldAddr = as<IRFieldAddress>(a)) + { + a = argFieldAddr->getBase(); + } + else if (auto argLoad = as<IRLoad>(a)) + { + a = argLoad->getPtr(); + + // We can safely defer a load to the callee if the source dest is immutable. + if (isPointerToImmutableLocation(a)) + continue; + + // Otherwise, we check if there is no other instructions in between the load and the + // call that can modify the memory location. 
If so, we can still safely defer the + // load to the callee. + if (!isMemoryLocationUnmodifiedBetweenLoadAndUser( + codegenContext->getTargetReq(), + argLoad, + callInst)) + return false; + } + else + { + break; + } } - // The "root" of the parameter must be a reference to a global-scope - // shader parameter, so that we know we can substitute it into the callee. + // The "root" of the parameter must be one of the following: + // 1. A reference to a global-scope shader parameter that can be referenced directly from + // the callee. + // 2. A user pointer or bindless resource handle that can be passed to the callee as + // ordinary argument. // if (const auto argGlobalParam = as<IRGlobalParam>(a)) { return true; } - else + else if (isUserPointerType(a->getDataType()) || as<IRCastDescriptorHandleToResource>(a)) { - return false; + return true; } - - // TODO: There are other patterns that we could attempt to optimize here. - // For example, this logic only handles loads of the *entire* contents of - // a buffer, so it would miss: - // - // * A load of a large structure from field in a constant buffer, so that - // the value loaded is not the entire buffer contents. - // - // * A load of a large structure from a structured buffer, or any other kind - // of buffer that requires an index. - // - // * Any resource load that is not expressed at the IR level with a `load` - // instruction (e.g., those that might use an intrinsic function). - // + return false; } }; void specializeFuncsForBufferLoadArgs(CodeGenContext* codegenContext, IRModule* module) { FuncBufferLoadSpecializationCondition condition; + condition.codegenContext = codegenContext; specializeFunctionCalls(codegenContext, module, &condition); } |
