diff options
| author | Yong He <yonghe@outlook.com> | 2025-09-30 19:08:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-30 19:08:23 -0700 |
| commit | e4611e2e30a3e5969d402f5ed7e72706a0e3b024 (patch) | |
| tree | 0f4240ccf8c4f0786949ab33adb0fcc332890d11 /source/slang/slang-ir-specialize-function-call.cpp | |
| parent | b6422e50cb19f7f790f29678ba22f31b0b305511 (diff) | |
Enhance buffer load specialization pass to specialize past field extracts. (#8547)
This allows us to specialize functions whose argument is a sub element
of a constant buffer, instead of being only applicable to entire buffer
element. Closes #8421.
This change also implements a proper heuristic to determine when to
specialize the calls and defer the buffer loads.
This PR addresses a pathological case exposed in
`slangpy\slangpy\benchmarks\test_benchmark_tensor.py`, which used to
take 27ms to finish, and now takes 1.25ms.
For example, given:
```
struct Bottom
{
float bigArray[1024];
[mutating]
void setVal(int index, float value) { bigArray[index] = value; }
}
struct Root
{
Bottom top[2];
[mutating]
void setTopVal(int x, int y, float value)
{
top[x].setVal(y, value);
}
}
RWStructuredBuffer<Root> sb;
[shader("compute")]
[numthreads(1, 1, 1)]
void compute_main(uint3 tid: SV_DispatchThreadID)
{
sb[0].setTopVal(1, 2, 100.0f);
}
```
We are now able to specialize the call to `setTopVal` into:
```
void compute_main(uint3 tid: SV_DispatchThreadID)
{
setTopVal_specialized(0, 1, 2, 100.0f);
}
void setTopVal_specialized(int sbIdx, int x, int y, float value)
{
Bottom_setVal_specialized(sbIdx, x, y, value);
}
void Bottom_setVal_specialized(int sbIdx, int x, int y, float value)
{
sb[sbIdx].top[x].bigArray[y] = value;
}
```
And get rid of all unnecessary loads. Achieving this requires a
combination of function call specialization and buffer-load-defer pass.
The buffer-load-defer pass has been completely rewritten to be more
correct and avoid introducing redundant loads.
This PR also adds tests to make sure pointers, bindless handles, and
loads from structured buffer or constant buffers works as expected.
Diffstat (limited to 'source/slang/slang-ir-specialize-function-call.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-function-call.cpp | 205 |
1 files changed, 172 insertions, 33 deletions
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index 7c82891a6..aead69258 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -40,6 +40,12 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization( if (as<IRGlobalValueWithCode>(arg)) return true; + if (isUserPointerType(arg->getDataType())) + return true; + + if (as<IRCastDescriptorHandleToResource>(arg)) + return true; + // As we will see later, we can also // specialize a call when the argument // is the result of indexing into an @@ -47,17 +53,29 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization( // of the indexing operation is also // suitable for specialization. // - if (arg->getOp() == kIROp_GetElement || arg->getOp() == kIROp_Load) + switch (arg->getOp()) { - auto base = arg->getOperand(0); - - // We will "recurse" on the base of - // the indexing operation by continuing - // our loop with the `base` as our new - // argument. - // - arg = base; - continue; + case kIROp_GetElement: + case kIROp_StructuredBufferLoad: + case kIROp_ByteAddressBufferLoad: + case kIROp_GetElementPtr: + case kIROp_RWStructuredBufferGetElementPtr: + case kIROp_FieldAddress: + case kIROp_FieldExtract: + case kIROp_Load: + { + auto base = arg->getOperand(0); + + // We will "recurse" on the base of + // the indexing operation by continuing + // our loop with the `base` as our new + // argument. + // + arg = base; + continue; + } + default: + break; } // By default, we will *not* consider an argument @@ -225,7 +243,7 @@ struct FunctionParameterSpecializationContext // If neither the parameter nor the argument wants specialization, // then we need to keep looking. // - auto paramWantSpecialization = doesParamWantSpecialization(param, arg); + auto paramWantSpecialization = doesParamWantSpecialization(param, arg, call); auto paramTypeWantSpecialization = doesParamTypeWantSpecialization(param, arg); if (!paramWantSpecialization && !paramTypeWantSpecialization) continue; @@ -255,9 +273,9 @@ struct FunctionParameterSpecializationContext // Of course, now we need to back-fill the predicates that // the above function used to evaluate prameters and arguments. - bool doesParamWantSpecialization(IRParam* param, IRInst* arg) + bool doesParamWantSpecialization(IRParam* param, IRInst* arg, IRCall* callInst) { - return condition->doesParamWantSpecialization(param, arg); + return condition->doesParamWantSpecialization(param, arg, callInst); } bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) @@ -484,16 +502,20 @@ struct FunctionParameterSpecializationContext UInt oldArgIndex = oldArgCounter++; auto oldArg = oldCall->getArg(oldArgIndex); - getCallInfoForParam(callInfo, oldParam, oldArg); + getCallInfoForParam(callInfo, oldParam, oldArg, oldCall); } } - void getCallInfoForParam(CallSpecializationInfo& ioInfo, IRParam* oldParam, IRInst* oldArg) + void getCallInfoForParam( + CallSpecializationInfo& ioInfo, + IRParam* oldParam, + IRInst* oldArg, + IRCall* callInst) { // We know that the case where the parameter // and argument don't want specialization is easy. // - if (!doesParamWantSpecialization(oldParam, oldArg)) + if (!doesParamWantSpecialization(oldParam, oldArg, callInst)) { // The new call site will use the same argument // value as the old one, and we don't need @@ -546,7 +568,15 @@ struct FunctionParameterSpecializationContext // Similarly for other global constants ioInfo.key.vals.add(globalConstant); } - else if (oldArg->getOp() == kIROp_GetElement) + else if (isUserPointerType(oldArg->getDataType())) + { + // If the arg is a user pointer, we can pass it as an ordinary argument, + // and we won't need further tracing down the access chain. + // + ioInfo.key.vals.add(oldArg->getFullType()); + ioInfo.newArgs.add(oldArg); + } + else if (isElementAccessInst(oldArg)) { // This is the case where the `oldArg` is // in the form `oldBase[oldIndex]` @@ -587,19 +617,45 @@ struct FunctionParameterSpecializationContext ioInfo.newArgs.add(oldIndex); } + else if (isFieldAccessInst(oldArg)) + { + // This is the case where the `oldArg` is + // in the form `oldBase.structKey` + // + auto oldBase = oldArg->getOperand(0); + auto structKey = oldArg->getOperand(1); + + // Similar to the getElement case, we recursively setting up whatever + // `oldBase` needs first. + // + getCallInfoForArg(ioInfo, oldBase); + + // The main difference from the `getElement` case is we actually want + // the structKey to be in the specialization key because it will be baked + // into the specialized function. + // And we won't introduce a new parameter to hold the index. + // + ioInfo.key.vals.add(structKey); + } else if (oldArg->getOp() == kIROp_Load) { auto oldBase = oldArg->getOperand(0); getCallInfoForArg(ioInfo, oldBase); } + else if (oldArg->getOp() == kIROp_CastDescriptorHandleToResource) + { + // We are accessing a resource from a bindless handle. + // We can stop recursion here and just pass in the bindless handle as + // an argument. + auto oldBase = oldArg->getOperand(0); + ioInfo.key.vals.add(oldBase->getFullType()); + ioInfo.newArgs.add(oldBase); + } else { // If we fail to match any of the cases above - // then a precondition was violated in that - // `isArgSuitableForSpecialization` is allowing - // a case that this routine is not covering. - // - SLANG_UNEXPECTED("mising case in 'getCallInfoForArg'"); + // then the `SpecializeCondition` is letting through constructs that we cannot handle. + SLANG_UNEXPECTED("unexpected function call specialization argument form."); } } @@ -641,7 +697,7 @@ struct FunctionParameterSpecializationContext // will stand in for the parameter in the specialized // function. // - auto newVal = getSpecializedValueForParam(funcInfo, oldParam, oldArg); + auto newVal = getSpecializedValueForParam(funcInfo, oldParam, oldArg, oldCall); // We will collect the replacement value to use // for each of the original parameters in an array. @@ -681,12 +737,13 @@ struct FunctionParameterSpecializationContext IRInst* getSpecializedValueForParam( FuncSpecializationInfo& ioInfo, IRParam* oldParam, - IRInst* oldArg) + IRInst* oldArg, + IRCall* callInst) { // As always, the easy case is when the parameter of // the original function doesn't need specialization. // - if (!doesParamWantSpecialization(oldParam, oldArg)) + if (!doesParamWantSpecialization(oldParam, oldArg, callInst)) { // The specialized callee will need a new parameter // that fills the same role as the old one, so we @@ -718,6 +775,36 @@ struct FunctionParameterSpecializationContext } } + // Returns true if `inst` is an instruction that accesses an element from an array or a buffer. + // + static bool isElementAccessInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_GetElement: + case kIROp_RWStructuredBufferGetElementPtr: + case kIROp_StructuredBufferLoad: + case kIROp_ByteAddressBufferLoad: + return true; + } + return false; + } + + // Returns true if `inst` is an instruction that accesses a field from a struct, that is + // either a FieldAddress or FieldExtract. + // + static bool isFieldAccessInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_FieldAddress: + case kIROp_FieldExtract: + return true; + } + return false; + } + IRInst* getSpecializedValueForArg(FuncSpecializationInfo& ioInfo, IRInst* oldArg) { // The logic here parallels `gatherCallInfoForArg`, @@ -735,13 +822,24 @@ struct FunctionParameterSpecializationContext // return globalParam; } + if (isUserPointerType(oldArg->getDataType())) + { + // If argument is a user pointer, we can pass it into the callee + // directly as an oridinary argument without further specializing + // for the access chain beyond the pointer. + // + auto builder = getBuilder(); + auto newParam = builder->createParam(oldArg->getFullType()); + ioInfo.newParams.add(newParam); + return newParam; + } if (auto globalFunc = as<IRGlobalValueWithCode>(oldArg)) { // As above, the identity of the specialized function is sufficient // to resolve the uses return globalFunc; } - else if (oldArg->getOp() == kIROp_GetElement) + else if (isElementAccessInst(oldArg)) { // This is the case where the argument is // in the form `oldBase[oldIndex]`. @@ -801,7 +899,9 @@ struct FunctionParameterSpecializationContext // of things, and then inserted to a more permanent location later. // builder->setInsertLoc(IRInsertLoc()); - auto newVal = builder->emitElementExtract(oldArg->getFullType(), newBase, newIndex); + IRInst* newOperands[] = {newBase, newIndex}; + auto newVal = + builder->emitIntrinsicInst(oldArg->getFullType(), oldArg->getOp(), 2, newOperands); // Because our new instruction wasn't // actually inserted anywhere, we need to @@ -813,6 +913,30 @@ struct FunctionParameterSpecializationContext return newVal; } + else if (isFieldAccessInst(oldArg)) + { + // This is the case where the argument is + // in the form `oldBase.structKey`. + // + auto oldBase = oldArg->getOperand(0); + auto structKey = oldArg->getOperand(1); + + // We handle this case in a similar way as the `oldBase[oldIndex]` + // case, except that we don't need to introduce a new parameter + // for the index, since the struct key is known at compile-time. + auto newBase = getSpecializedValueForArg(ioInfo, oldBase); + + auto builder = getBuilder(); + + builder->setInsertLoc(IRInsertLoc()); + IRInst* newOperands[] = {newBase, structKey}; + auto newVal = + builder->emitIntrinsicInst(oldArg->getFullType(), oldArg->getOp(), 2, newOperands); + + ioInfo.newBodyInsts.add(newVal); + + return newVal; + } else if (auto oldArgLoad = as<IRLoad>(oldArg)) { auto oldPtr = oldArgLoad->getPtr(); @@ -825,15 +949,30 @@ struct FunctionParameterSpecializationContext return newVal; } + else if (auto castHandleToResource = as<IRCastDescriptorHandleToResource>(oldArg)) + { + // We are accessing a resource from a bindless handle. + // We should create a param for the handle, and load the resource from the param. + auto builder = getBuilder(); + auto oldHandle = castHandleToResource->getOperand(0); + auto newHandle = builder->createParam(oldHandle->getFullType()); + ioInfo.newParams.add(newHandle); + + builder->setInsertLoc(IRInsertLoc()); + IRInst* newOperands[] = {newHandle}; + auto newVal = builder->emitIntrinsicInst( + oldArg->getFullType(), + kIROp_CastDescriptorHandleToResource, + 1, + newOperands); + ioInfo.newBodyInsts.add(newVal); + return newVal; + } else { // If we don't match one of the above cases, - // then `isArgSuitableForSpecialization` is - // letting through cases that this function - // hasn't been updated to handle. - // - SLANG_UNEXPECTED("mising case in 'getSpecializedValueForArg'"); - UNREACHABLE_RETURN(nullptr); + // then we are running into an invalid case. + SLANG_UNEXPECTED("unknown argument form for function call specialization."); } } |
