diff options
| author | Yong He <yonghe@outlook.com> | 2025-10-09 18:26:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-10 01:26:28 +0000 |
| commit | 3cf1f5a616917480c63b76aae906dc36b29e46ce (patch) | |
| tree | abbc4538e17be1163c06c950b4afdacd227fe39c | |
| parent | 4e4aad5a0493defde1e0ef29f27e5d663c1182cd (diff) | |
Small fix to buffer load specialization pass to allow more specialization to happen. (#8653)
This allows us to further clean up unnecessary copies in the target code
we generate.
Part of the effort for #8652.
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-buffer-load-arg.cpp | 2 | ||||
| -rw-r--r-- | tests/optimization/wrapped-array.slang | 56 |
3 files changed, 63 insertions, 1 deletion
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 804a44b81..7d72fb77f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1402,6 +1402,12 @@ Result linkAndOptimizeIR( // function parameters, reults, etc. is invalid. // We clean up the usages of resource values here. specializeResourceUsage(codeGenContext, irModule); + + // Specialize calls to functions with values loaded from an immutable location, + // so that we directly load the value inside the callee, instead of loading the + // value outside of the callee and copy it in. This is necessary to avoid copying + // large values (e.g. arrays) in registers, where most of the elements are not + // actually used. specializeFuncsForBufferLoadArgs(codeGenContext, irModule); // Push `structuredBufferLoad` to the end of access chain to avoid loading unnecessary data. diff --git a/source/slang/slang-ir-specialize-buffer-load-arg.cpp b/source/slang/slang-ir-specialize-buffer-load-arg.cpp index a5a3dd2d9..c473c6047 100644 --- a/source/slang/slang-ir-specialize-buffer-load-arg.cpp +++ b/source/slang/slang-ir-specialize-buffer-load-arg.cpp @@ -89,7 +89,7 @@ struct FuncBufferLoadSpecializationCondition : FunctionCallSpecializeCondition a = argLoad->getPtr(); // We can safely defer a load to the callee if the source dest is immutable. 
- if (isPointerToImmutableLocation(a)) + if (isPointerToImmutableLocation(getRootAddr(a))) continue; // Otherwise, we check if there is no other instructions in between the load and the diff --git a/tests/optimization/wrapped-array.slang b/tests/optimization/wrapped-array.slang new file mode 100644 index 000000000..cf47a228b --- /dev/null +++ b/tests/optimization/wrapped-array.slang @@ -0,0 +1,56 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +// CHECK: OpEntryPoint + +// Make sure we never load the entire TensorList struct to local registers, +// instead, we should specialize the fetch function to directly load the +// element tensor from gTensors. + +// CHECK-NOT: OpLoad %TensorList_std140 + +struct RWTensor<T, let D : int> +{ + int dims[D]; + RWStructuredBuffer<T> buffer; + T getv(vector<uint, D> index) + { + int flat_index = 0; + int stride = 1; + for (int i = D - 1; i >= 0; --i) + { + flat_index += index[i] * stride; + stride *= dims[i]; + } + return buffer[flat_index]; + } +} +struct TensorList<let N : int> +{ + RWTensor<float, 2> tensors[N]; + + float fetch(int tensor_index, uint2 index) + { + return tensors[tensor_index].getv(index); + } +} + +float sum_indirect<let N : int>(uint2 tid, TensorList<N> tensor_list, uint tensor_indices[N]) +{ + float result = 0.0; + for (int i = 0; i < N; i++) + { + result += tensor_list.fetch(tensor_indices[i], tid); + } + return result; +} + +uniform TensorList<32> gTensors; +uniform uint gTensorIndices[32]; + +uniform float* result; + +[numthreads(1,1,1)] +void computeMain(uint2 tid : SV_DispatchThreadID) +{ + *result = sum_indirect(tid, gTensors, gTensorIndices); +} |
