diff options
| author | Yong He <yonghe@outlook.com> | 2025-09-30 19:08:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-30 19:08:23 -0700 |
| commit | e4611e2e30a3e5969d402f5ed7e72706a0e3b024 (patch) | |
| tree | 0f4240ccf8c4f0786949ab33adb0fcc332890d11 /source/slang/slang-ir-defer-buffer-load.cpp | |
| parent | b6422e50cb19f7f790f29678ba22f31b0b305511 (diff) | |
Enhance buffer load specialization pass to specialize past field extracts. (#8547)
This allows us to specialize functions whose argument is a sub element
of a constant buffer, instead of being only applicable to entire buffer
element. Closes #8421.
This change also implements a proper heuristic to determine when to
specialize the calls and defer the buffer loads.
This PR addresses a pathological case exposed in
`slangpy\slangpy\benchmarks\test_benchmark_tensor.py`, which used to
take 27ms to finish, and now takes 1.25ms.
For example, given:
```
struct Bottom
{
float bigArray[1024];
[mutating]
void setVal(int index, float value) { bigArray[index] = value; }
}
struct Root
{
Bottom top[2];
[mutating]
void setTopVal(int x, int y, float value)
{
top[x].setVal(y, value);
}
}
RWStructuredBuffer<Root> sb;
[shader("compute")]
[numthreads(1, 1, 1)]
void compute_main(uint3 tid: SV_DispatchThreadID)
{
sb[0].setTopVal(1, 2, 100.0f);
}
```
We are now able to specialize the call to `setTopVal` into:
```
void compute_main(uint3 tid: SV_DispatchThreadID)
{
setTopVal_specialized(0, 1, 2, 100.0f);
}
void setTopVal_specialized(int sbIdx, int x, int y, float value)
{
Bottom_setVal_specialized(sbIdx, x, y, value);
}
void Bottom_setVal_specialized(int sbIdx, int x, int y, float value)
{
sb[sbIdx].top[x].bigArray[y] = value;
}
```
And get rid of all unnecessary loads. Achieving this requires a
combination of function call specialization and buffer-load-defer pass.
The buffer-load-defer pass has been completely rewritten to be more
correct and avoid introducing redundant loads.
This PR also adds tests to make sure pointers, bindless handles, and
loads from structured buffer or constant buffers works as expected.
Diffstat (limited to 'source/slang/slang-ir-defer-buffer-load.cpp')
| -rw-r--r-- | source/slang/slang-ir-defer-buffer-load.cpp | 326 |
1 files changed, 201 insertions, 125 deletions
diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp index 51c6a161b..ccdfe4538 100644 --- a/source/slang/slang-ir-defer-buffer-load.cpp +++ b/source/slang/slang-ir-defer-buffer-load.cpp @@ -3,142 +3,211 @@ #include "slang-ir-clone.h" #include "slang-ir-dominators.h" #include "slang-ir-insts.h" +#include "slang-ir-layout.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-util.h" #include "slang-ir.h" namespace Slang { -struct DeferBufferLoadContext -{ - // Map an original SSA value to a pointer that can be used to load the value. - Dictionary<IRInst*, IRInst*> mapValueToPtr; - // Map an ptr to its loaded value. - Dictionary<IRInst*, IRInst*> mapPtrToValue; +// Generally, we want to specialize arguments that are large in size, or arguments that +// are arrays or composite type that contains arrays. +// This is because: +// 1. Struct types without arrays will eventually be SROA's into registers and then effectively +// DCE'd, so they usually won't cause performance issues. In fact, front loading structs +// and reusing the loaded value instead of repetitively loading from constant memory is +// usually beneficial to performance. However large struct values can be SROA'd into a large +// number of registers, causing slow downstream compilation. Therefore we should avoid/defer +// loading them into registers if we can. +// 2. Arrays usually cannot be SROA'd into individual registers, which usually leads to +// large register consumption if they ever get loaded, so we want to defer loading array +// typed values as much as possible. - IRFunc* currentFunc = nullptr; +// If the argument data is bigger than this threshold, it is considered a large object +// and we will try to specialize it even if it doesn't contain arrays. +static const int kBufferLoadElementSizeSpecializationThreshold = 128; - // Ensure that for an original SSA value, we have formed a pointer that can be used to load the - // value. - IRInst* ensurePtr(IRInst* valueInst) - { - IRInst* result = nullptr; - if (mapValueToPtr.tryGetValue(valueInst, result)) - return result; +// If the argument data is smaller than this threshold, it is considered a tiny object +// and we will not consider specializing it, even if it contains arrays. +static const int kBufferLoadElementSizeSpecializationMinThreshold = 16; - IRBuilder b(valueInst); - b.setInsertBefore(valueInst); - - switch (valueInst->getOp()) +static bool isCompositeTypeContainingArrays(IRType* type) +{ + if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) { - case kIROp_StructuredBufferLoad: - case kIROp_StructuredBufferLoadStatus: - { - result = b.emitRWStructuredBufferGetElementPtr( - valueInst->getOperand(0), - valueInst->getOperand(1)); - break; - } - case kIROp_GetElement: + if (const auto arrayType = as<IRArrayTypeBase>(field->getFieldType())) { - auto ptr = ensurePtr(valueInst->getOperand(0)); - if (!ptr) - return nullptr; - result = b.emitElementAddress(ptr, valueInst->getOperand(1)); - break; + return true; } - case kIROp_FieldExtract: + if (auto subStructType = as<IRStructType>(field->getFieldType())) { - auto ptr = ensurePtr(valueInst->getOperand(0)); - if (!ptr) - return nullptr; - result = b.emitFieldAddress(ptr, valueInst->getOperand(1)); - break; + if (isCompositeTypeContainingArrays(subStructType)) + return true; } - case kIROp_Load: - result = valueInst->getOperand(0); - break; - } - if (result) - { - mapValueToPtr[valueInst] = result; } - return result; } + else if (as<IRArrayTypeBase>(type)) + { + return true; + } + return false; +} - static bool isImmutableBufferLoad(IRInst* inst) +bool isTypePreferrableToDeferLoad(CodeGenContext* codeGenContext, IRType* type) +{ + // If parameter is a pointer/reference, we should consider specialize it. + if (as<IROutTypeBase>(type) || as<IRRefType>(type) || as<IRConstRefType>(type)) + return true; + + // We only want to defer loading values that are "large enough" that + // we expect them to be expensive to pass by value. + // + IRSizeAndAlignment sizeAlignment = {}; + if (SLANG_FAILED(getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + type, + &sizeAlignment))) { - // Note: we cannot defer loads from RWStructuredBuffer because there can be other - // instructions that modify the buffer. + // If type contains fields that we don't know how to compute natural size + // for, default to specialize if it contains arrays. + return isCompositeTypeContainingArrays(type); + } + + // If the argument is very small, don't bother specializing. + if (sizeAlignment.size <= kBufferLoadElementSizeSpecializationMinThreshold) + return false; + + // If the argument is somewhat small, don't specialize, unless it contains + // arrays. + if (sizeAlignment.size <= kBufferLoadElementSizeSpecializationThreshold) + { + // We generally do not specialize for small values, except it contains + // arrays that usually present a challenge for the SROA pass to eliminate + // unnecessary loads. + if (!isCompositeTypeContainingArrays(type)) + return false; + } + return true; +} + +// Returns true if memory loaded by `loadInst` is not modified before `userInst` after it is +// loaded. +// This method is currently implementing a very conservative analysis that only allows +// `loadInst` to be in the same block as `userInst`, with basic aliasing analysis for any +// stores in between. All other cases are conservatively treated as the memory location may be +// modified. +bool isMemoryLocationUnmodifiedBetweenLoadAndUser( + TargetRequest* target, + IRInst* loadInst, + IRInst* userInst) +{ + auto func = getParentFunc(loadInst); + if (!func) + return false; + + // For now we only check if loadInst and userInst are in the same block. + if (loadInst->getParent() != userInst->getParent()) + return false; + + for (IRInst* inst = loadInst->getNextInst(); inst; inst = inst->getNextInst()) + { + // We found callInst before hitting any instruction that may modify the memory. + if (inst == userInst) + return true; + + if (!inst->mightHaveSideEffects()) + continue; + + // If we see any inst that has side effect, check if it is simple case that we can rule + // out the possibility of modifying the memory location. switch (inst->getOp()) { - case kIROp_StructuredBufferLoad: - case kIROp_StructuredBufferLoadStatus: - return true; - case kIROp_Load: + case kIROp_Store: { - auto rootAddr = getRootAddr(inst->getOperand(0)); - return isPointerToImmutableLocation(rootAddr); + auto storedDest = inst->getOperand(0); + if (canAddressesPotentiallyAlias(target, func, loadInst->getOperand(0), storedDest)) + return false; + continue; } default: + // For any other case, conservatively assume the memory location may be modified. return false; } } + // We didn't found callInst after loadInst within the same basic block. + // We conservatively assume the memory location may be modified. + // This check can be extended to use the dominator tree to allow + // loadInst and userInst to be in different blocks. + return false; +} - // Ensure that for a pointer value, we have created a load instruction to materialize the value. - IRInst* materializePointer(IRBuilder& builder, IRInst* loadInst) +struct DeferBufferLoadContext +{ + CodeGenContext* codeGenContext; + + + void deferBufferLoadInst(IRBuilder& builder, List<IRInst*>& workList, IRInst* loadInst) { - auto ptr = ensurePtr(loadInst); - if (!ptr) - return nullptr; - IRInst* result = nullptr; - if (mapPtrToValue.tryGetValue(ptr, result)) - return result; - IRAlignedAttr* align = nullptr; - if (auto load = as<IRLoad>(loadInst)) - align = load->findAttr<IRAlignedAttr>(); - if (!as<IRModuleInst>(ptr->getParent())) + // Don't defer the load anymore if the type is simple. + if (!isTypePreferrableToDeferLoad(codeGenContext, loadInst->getDataType()) || + loadInst->findAttr<IRAlignedAttr>()) { - setInsertAfterOrdinaryInst(&builder, ptr); - IRType* valueType = tryGetPointedToType(&builder, ptr->getFullType()); - result = builder.emitLoad(valueType, ptr, align); - mapPtrToValue[ptr] = result; + return; } - else + + auto rootAddr = getRootAddr(loadInst->getOperand(0)); + bool isImmutableBufferLoad = isPointerToImmutableLocation(rootAddr); + + // Don't defer the load if there are uses that are not getElement or fieldExtract. + // Because in this case we need to use the entire loaded value, and further deferring + // the load down any access chain will introduce redundant loads. + for (auto use = loadInst->firstUse; use; use = use->nextUse) { - setInsertBeforeOrdinaryInst(&builder, loadInst); - IRType* valueType = tryGetPointedToType(&builder, ptr->getFullType()); - result = builder.emitLoad(valueType, ptr, align); - // Since we are inserting the load in a local scope, we can't register - // the mapping to the pointer, since the global pointer needs to be - // loaded once per function. + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_GetElement: + case kIROp_FieldExtract: + // Can we defer the load to load only the requested element right before + // the element extract inst? + // If the buffer is immutable, we can always do that. + // If it is not, we need to make sure there is no other instructions that can modify + // the buffer between the load and the use. + // + if (isImmutableBufferLoad) + continue; + if (isMemoryLocationUnmodifiedBetweenLoadAndUser( + codeGenContext->getTargetReq(), + loadInst, + user)) + continue; + return; + default: + // If we see any other use the laod instruction, we assume the entire loaded value + // is needed, and we can't defer the load anymore. + return; + } } - return result; - } - static bool isSimpleType(IRInst* type) - { - if (auto modType = as<IRRateQualifiedType>(type)) - type = modType->getValueType(); - if (as<IRStructType>(type)) - return false; - if (as<IRTupleType>(type)) - return false; - if (as<IRArrayTypeBase>(type)) - return false; - return true; - } + // If we reach here, it means all uses are getElement or fieldExtract, and + // it is safe to defer the load down the access chain. - void deferBufferLoadInst(IRBuilder& builder, List<IRInst*>& workList, IRInst* loadInst) - { - // Don't defer the load anymore if the type is simple. - if (isSimpleType(loadInst->getDataType()) || loadInst->findAttr<IRAlignedAttr>()) + if (loadInst->getOp() == kIROp_StructuredBufferLoad) { - auto materializedVal = materializePointer(builder, loadInst); - loadInst->transferDecorationsTo(materializedVal); - loadInst->replaceUsesWith(materializedVal); - return; + // Convert the structuredBufferLoad to a regular load to reuse + // the same logic for deferring regular loads. + builder.setInsertBefore(loadInst); + auto bufferPtr = builder.emitRWStructuredBufferGetElementPtr( + loadInst->getOperand(0), + loadInst->getOperand(1)); + auto sbLoad = builder.emitLoad(bufferPtr); + loadInst->transferDecorationsTo(sbLoad); + loadInst->replaceUsesWith(sbLoad); + loadInst->removeAndDeallocate(); + loadInst = sbLoad; } // Otherwise, look for all uses and try to defer the load before actual use of the value. @@ -148,19 +217,29 @@ struct DeferBufferLoadContext loadInst, [&](IRUse* use) { - if (needMaterialize) - return; - auto user = use->getUser(); + switch (user->getOp()) { case kIROp_GetElement: case kIROp_FieldExtract: { - auto basePtr = ensurePtr(loadInst); - if (!basePtr) - return; - pendingWorkList.add(user); + // If we see a getElement or fieldExtract, we defer the load by + // replacing the getElement/fieldExtract with a load of the + // elementAddr/fieldAddr. + builder.setInsertBefore(user); + auto basePtr = loadInst->getOperand(0); + IRInst* gepArg = user->getOperand(1); + auto elementPtr = builder.emitElementAddress( + basePtr, + makeArrayViewSingle<IRInst*>(gepArg)); + auto newLoad = builder.emitLoad(elementPtr); + user->transferDecorationsTo(newLoad); + user->replaceUsesWith(newLoad); + user->removeAndDeallocate(); + + // Now add the new load to work list to try to defer it further. + pendingWorkList.add(newLoad); } break; default: @@ -169,41 +248,37 @@ struct DeferBufferLoadContext } }); - if (needMaterialize) - { - auto val = materializePointer(builder, loadInst); - loadInst->transferDecorationsTo(val); - loadInst->replaceUsesWith(val); - loadInst->removeAndDeallocate(); - } - else - { - // Append to worklist in reverse order so we process the uses in natural appearance - // order. - for (Index i = pendingWorkList.getCount() - 1; i >= 0; i--) - workList.add(pendingWorkList[i]); - } + // Append to worklist in reverse order so we process the uses in natural appearance + // order. + for (Index i = pendingWorkList.getCount() - 1; i >= 0; i--) + workList.add(pendingWorkList[i]); } void deferBufferLoadInFunc(IRFunc* func) { removeRedundancyInFunc(func, false); - currentFunc = func; - List<IRInst*> workList; + // Discover all load instructions and add to work list. + for (auto block : func->getBlocks()) { for (auto inst : block->getChildren()) { - if (isImmutableBufferLoad(inst)) + switch (inst->getOp()) { + case kIROp_Load: + case kIROp_StructuredBufferLoad: + // Note: We don't handle `kIROp_StructuredBufferLoadStatus` here because + // it also writes to the status code out parameter, which we can't defer. workList.add(inst); + break; } } } + // Iteratively process the work list until it is empty. IRBuilder builder(func); for (Index i = 0; i < workList.getCount(); i++) { @@ -227,9 +302,10 @@ struct DeferBufferLoadContext } }; -void deferBufferLoad(IRModule* module) +void deferBufferLoad(CodeGenContext* codeGenContext, IRModule* module) { DeferBufferLoadContext context; + context.codeGenContext = codeGenContext; for (auto childInst : module->getGlobalInsts()) { if (auto code = as<IRGlobalValueWithCode>(childInst)) |
