diff options
| author | Yong He <yonghe@outlook.com> | 2025-10-09 18:26:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-10 01:26:28 +0000 |
| commit | 3cf1f5a616917480c63b76aae906dc36b29e46ce (patch) | |
| tree | abbc4538e17be1163c06c950b4afdacd227fe39c /tests/optimization | |
| parent | 4e4aad5a0493defde1e0ef29f27e5d663c1182cd (diff) | |
Small fix to buffer load specialization pass to allow more specialization to happen. (#8653)
This allows us to further clean up unnecessary copies in the target code
we generate.
Part of the effort for #8652.
Diffstat (limited to 'tests/optimization')
| -rw-r--r-- | tests/optimization/wrapped-array.slang | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tests/optimization/wrapped-array.slang b/tests/optimization/wrapped-array.slang
new file mode 100644
index 000000000..cf47a228b
--- /dev/null
+++ b/tests/optimization/wrapped-array.slang
@@ -0,0 +1,56 @@
+//TEST:SIMPLE(filecheck=CHECK):-target spirv
+
+// CHECK: OpEntryPoint
+
+// Make sure we never load the entire TensorList struct into local registers;
+// instead, the fetch function should be specialized to load the element
+// tensor directly from gTensors.
+
+// CHECK-NOT: OpLoad %TensorList_std140
+
+struct RWTensor<T, let D : int>
+{
+    int dims[D]; // per-dimension extents; getv multiplies them into strides
+    RWStructuredBuffer<T> buffer; // element storage, addressed by flattened index
+    T getv(vector<uint, D> index) // reads the element at a D-dimensional index
+    {
+        int flat_index = 0;
+        int stride = 1;
+        for (int i = D - 1; i >= 0; --i) // walk from the last dimension (stride 1) to the first
+        {
+            flat_index += index[i] * stride;
+            stride *= dims[i];
+        }
+        return buffer[flat_index];
+    }
+}
+struct TensorList<let N : int>
+{
+    RWTensor<float, 2> tensors[N]; // the wrapped array this test is about
+
+    float fetch(int tensor_index, uint2 index) // reads one element of tensors[tensor_index]
+    {
+        return tensors[tensor_index].getv(index);
+    }
+}
+
+float sum_indirect<let N : int>(uint2 tid, TensorList<N> tensor_list, uint tensor_indices[N])
+{
+    float result = 0.0;
+    for (int i = 0; i < N; i++)
+    {
+        result += tensor_list.fetch(tensor_indices[i], tid); // tensor chosen indirectly via tensor_indices
+    }
+    return result;
+}
+
+uniform TensorList<32> gTensors;
+uniform uint gTensorIndices[32];
+
+uniform float* result; // destination written by computeMain
+
+[numthreads(1,1,1)]
+void computeMain(uint2 tid : SV_DispatchThreadID)
+{
+    *result = sum_indirect(tid, gTensors, gTensorIndices);
+}
 |
