From 3cf1f5a616917480c63b76aae906dc36b29e46ce Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 9 Oct 2025 18:26:28 -0700 Subject: Small fix to buffer load specialization pass to allow more specialization to happen. (#8653) This allows us to further cleanup unnecessary copies in the target code we generate. Part of effort of #8652. --- tests/optimization/wrapped-array.slang | 56 ++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/optimization/wrapped-array.slang (limited to 'tests/optimization') diff --git a/tests/optimization/wrapped-array.slang b/tests/optimization/wrapped-array.slang new file mode 100644 index 000000000..cf47a228b --- /dev/null +++ b/tests/optimization/wrapped-array.slang @@ -0,0 +1,56 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +// CHECK: OpEntryPoint + +// Make sure we never load the entire TensorList struct to local registers, +// instead, we should specialize the fetch function to directly load the +// element tensor from gTensors. + +// CHECK-NOT: OpLoad %TensorList_std140 + +struct RWTensor +{ + int dims[D]; + RWStructuredBuffer buffer; + T getv(vector index) + { + int flat_index = 0; + int stride = 1; + for (int i = D - 1; i >= 0; --i) + { + flat_index += index[i] * stride; + stride *= dims[i]; + } + return buffer[flat_index]; + } +} +struct TensorList +{ + RWTensor tensors[N]; + + float fetch(int tensor_index, uint2 index) + { + return tensors[tensor_index].getv(index); + } +} + +float sum_indirect(uint2 tid, TensorList tensor_list, uint tensor_indices[N]) +{ + float result = 0.0; + for (int i = 0; i < N; i++) + { + result += tensor_list.fetch(tensor_indices[i], tid); + } + return result; +} + +uniform TensorList<32> gTensors; +uniform uint gTensorIndices[32]; + +uniform float* result; + +[numthreads(1,1,1)] +void computeMain(uint2 tid : SV_DispatchThreadID) +{ + *result = sum_indirect(tid, gTensors, gTensorIndices); +} -- cgit v1.2.3