summaryrefslogtreecommitdiffstats
path: root/tests/optimization/wrapped-array.slang
blob: cf47a228b0432e411bef0f2ef92e3adbd67921f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
//TEST:SIMPLE(filecheck=CHECK):-target spirv

// CHECK: OpEntryPoint

// Make sure we never load the entire TensorList struct into local registers;
// instead, the fetch function should be specialized to load only the selected
// element tensor directly from gTensors.

// CHECK-NOT: OpLoad %TensorList_std140

// Minimal D-dimensional tensor view: the extent of each dimension plus a
// structured buffer holding the elements. In this file it is only used as the
// element type of TensorList's `tensors` array.
struct RWTensor<T, let D : int>
{
    // Extent of each of the D dimensions, used as strides in getv.
    int dims[D];
    // Backing storage, addressed by the flat index computed in getv.
    RWStructuredBuffer<T> buffer;
    // Returns the element at D-dimensional coordinate `index`.
    // The flat index is built with the last dimension varying fastest:
    // stride starts at 1 and is multiplied by dims[i] after visiting each
    // dimension, walking from i = D-1 down to 0. No bounds checking is done.
    T getv(vector<uint, D> index)
    {
        int flat_index = 0;
        int stride = 1;
        for (int i = D - 1; i >= 0; --i)
        {
            // NOTE(review): index[i] is uint while stride/flat_index are int,
            // so this mixes unsigned and signed arithmetic. It is left as-is
            // because the exact source shape is the input to this test.
            flat_index += index[i] * stride;
            stride *= dims[i];
        }
        return buffer[flat_index];
    }
}
// Fixed-size collection of N two-dimensional float tensors.
// Per the comment at the top of this file, the compiler is expected to
// specialize fetch() so that only tensors[tensor_index] is read from the
// uniform gTensors, instead of copying the whole struct into locals.
struct TensorList<let N : int>
{
    RWTensor<float, 2> tensors[N];

    // Returns the element at coordinate `index` of tensor number
    // `tensor_index`. tensor_index is not range-checked against N.
    float fetch(int tensor_index, uint2 index)
    {
        return tensors[tensor_index].getv(index);
    }
}

// Sums, for i in [0, N), the element at coordinate `tid` of tensor
// tensor_indices[i] in `tensor_list`.
// tensor_list is taken by value; computeMain passes the uniform gTensors
// here, which is the pattern whose whole-struct load this test guards against.
float sum_indirect<let N : int>(uint2 tid, TensorList<N> tensor_list, uint tensor_indices[N])
{
    float result = 0.0;
    for (int i = 0; i < N; i++)
    {
        // tensor_indices[i] is uint; it is implicitly converted to fetch's
        // int tensor_index parameter.
        result += tensor_list.fetch(tensor_indices[i], tid);
    }
    return result;
}

// Shader parameters: the list of 32 tensors that must not be loaded as a whole,
// and 32 indices selecting which tensor each iteration of sum_indirect reads.
uniform TensorList<32> gTensors;
uniform uint gTensorIndices[32];

// Output location written by computeMain.
uniform float* result;

// Compute entry point (1x1x1 thread groups). Sums the element at this thread's
// 2-D dispatch ID across the tensors chosen by gTensorIndices and stores the
// sum through `result`. The pointer is not offset by tid, so every invocation
// writes the same address.
[numthreads(1,1,1)]
void computeMain(uint2 tid : SV_DispatchThreadID)
{
    *result = sum_indirect(tid, gTensors, gTensorIndices);
}