summaryrefslogtreecommitdiff
path: root/tests/hlsl-intrinsic/wave-multi/wave-multi-sum-product.slang
blob: b40b014f4e19d50dfdc6fd22e5b37a08c2dcfb70 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//TEST_CATEGORY(wave, compute)
//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly
//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl
//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cuda -compute  -shaderobj -xslang -DCUDA

//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly -xslang -DUSE_GLSL_SYNTAX -allow-glsl

// Result buffer: computeMain stores one uint per thread at index
// dispatchThreadID.x, holding uint(result) of the combined pass/fail boolean.
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<uint> outputBuffer;

// Select which spelling of the partitioned sum/product intrinsics the checks
// below call: the HLSL-style names by default, or the GLSL-style names when
// the file is compiled with USE_GLSL_SYNTAX defined.
#if !defined(USE_GLSL_SYNTAX)
#define __partitionedSum WaveMultiSum
#define __partitionedProduct WaveMultiProduct
#else
#define __partitionedSum subgroupPartitionedAddNV
#define __partitionedProduct subgroupPartitionedMulNV
#endif

// Expected value of the partitioned sum for the current thread. computeMain
// assigns it (15 or 17, depending on the thread's group) before running the
// checks; test1SumProduct and testVSumProduct convert it to the type under test.
static uint gSumResult = 0;

// Scalar check for element type T.
//
// Every participating lane contributes T(1), so the partitioned sum is compared
// against gSumResult converted to T and the partitioned product against T(1).
// Returns true only when both comparisons hold.
__generic<T : __BuiltinArithmeticType>
bool test1SumProduct(uint4 mask)
{
    let one = T(1);
    let expectedSum = T(gSumResult);

    // Both wave operations are issued unconditionally before their outcomes
    // are combined, so neither is ever skipped.
    let sumMatches = __partitionedSum(one, mask) == expectedSum;
    let productMatches = __partitionedProduct(one, mask) == one;

    return sumMatches & productMatches;
}

// Vector check for vector<T, N>.
//
// Same idea as the scalar check, applied component-wise: each lane contributes
// a vector of ones, every component of the partitioned sum must equal
// gSumResult converted to T, and every component of the partitioned product
// must equal T(1). Returns true only when all components of both results match.
__generic<T : __BuiltinArithmeticType, let N : int>
bool testVSumProduct(uint4 mask)
{
    typealias GVec = vector<T, N>;

    let ones = GVec(T(1));
    let expectedSum = GVec(T(gSumResult));

    // Both wave operations are issued unconditionally before their outcomes
    // are combined, so neither is ever skipped.
    let sumMatches = all(__partitionedSum(ones, mask) == expectedSum);
    let productMatches = all(__partitionedProduct(ones, mask) == ones);

    return sumMatches & productMatches;
}

// Runs the scalar check and the 2-, 3- and 4-component vector checks for every
// element type under test and returns true only if all of them pass.
//
// int, uint, float and double are always exercised. The 8/16/64-bit integer
// types and half are compiled out when CUDA is defined.
//
// Results are folded with the non-short-circuiting '&' so that every check is
// executed even after an earlier one has failed.
bool testSumProduct(uint4 mask)
{
    bool allPassed = test1SumProduct<int>(mask);
    allPassed = allPassed & testVSumProduct<int, 2>(mask);
    allPassed = allPassed & testVSumProduct<int, 3>(mask);
    allPassed = allPassed & testVSumProduct<int, 4>(mask);

    allPassed = allPassed & test1SumProduct<uint>(mask);
    allPassed = allPassed & testVSumProduct<uint, 2>(mask);
    allPassed = allPassed & testVSumProduct<uint, 3>(mask);
    allPassed = allPassed & testVSumProduct<uint, 4>(mask);

    allPassed = allPassed & test1SumProduct<float>(mask);
    allPassed = allPassed & testVSumProduct<float, 2>(mask);
    allPassed = allPassed & testVSumProduct<float, 3>(mask);
    allPassed = allPassed & testVSumProduct<float, 4>(mask);

    allPassed = allPassed & test1SumProduct<double>(mask);
    allPassed = allPassed & testVSumProduct<double, 2>(mask);
    allPassed = allPassed & testVSumProduct<double, 3>(mask);
    allPassed = allPassed & testVSumProduct<double, 4>(mask);

#if !defined(CUDA)
    allPassed = allPassed & test1SumProduct<int8_t>(mask);
    allPassed = allPassed & testVSumProduct<int8_t, 2>(mask);
    allPassed = allPassed & testVSumProduct<int8_t, 3>(mask);
    allPassed = allPassed & testVSumProduct<int8_t, 4>(mask);

    allPassed = allPassed & test1SumProduct<int16_t>(mask);
    allPassed = allPassed & testVSumProduct<int16_t, 2>(mask);
    allPassed = allPassed & testVSumProduct<int16_t, 3>(mask);
    allPassed = allPassed & testVSumProduct<int16_t, 4>(mask);

    allPassed = allPassed & test1SumProduct<int64_t>(mask);
    allPassed = allPassed & testVSumProduct<int64_t, 2>(mask);
    allPassed = allPassed & testVSumProduct<int64_t, 3>(mask);
    allPassed = allPassed & testVSumProduct<int64_t, 4>(mask);

    allPassed = allPassed & test1SumProduct<uint8_t>(mask);
    allPassed = allPassed & testVSumProduct<uint8_t, 2>(mask);
    allPassed = allPassed & testVSumProduct<uint8_t, 3>(mask);
    allPassed = allPassed & testVSumProduct<uint8_t, 4>(mask);

    allPassed = allPassed & test1SumProduct<uint16_t>(mask);
    allPassed = allPassed & testVSumProduct<uint16_t, 2>(mask);
    allPassed = allPassed & testVSumProduct<uint16_t, 3>(mask);
    allPassed = allPassed & testVSumProduct<uint16_t, 4>(mask);

    allPassed = allPassed & test1SumProduct<uint64_t>(mask);
    allPassed = allPassed & testVSumProduct<uint64_t, 2>(mask);
    allPassed = allPassed & testVSumProduct<uint64_t, 3>(mask);
    allPassed = allPassed & testVSumProduct<uint64_t, 4>(mask);

    allPassed = allPassed & test1SumProduct<half>(mask);
    allPassed = allPassed & testVSumProduct<half, 2>(mask);
    allPassed = allPassed & testVSumProduct<half, 3>(mask);
    allPassed = allPassed & testVSumProduct<half, 4>(mask);
#endif

    return allPassed;
}

// Entry point. Each of the 32 threads in the group picks the partition mask
// for its half of the split, records the sum it expects, runs every check and
// writes 1 (all passed) or 0 to its own slot of outputBuffer.
[numthreads(32, 1, 1)]
[shader("compute")]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    let slot = dispatchThreadID.x;

    // Threads 0..14 form the lower partition (15 lanes, mask bits 0-14);
    // threads 15..31 form the upper partition (17 lanes, mask bits 15-31).
    let inUpperPartition = slot >= 15;
    let lowerMask = uint4(0x00007FFF, 0, 0, 0);
    let upperMask = uint4(0xFFFF8000, 0, 0, 0);
    uint4 mask = inUpperPartition ? upperMask : lowerMask;

    // Every lane contributes 1, so the expected sum is the partition's lane count.
    gSumResult = inUpperPartition ? uint(17) : uint(15);

    let passed = testSumProduct(mask);

    // CHECK-COUNT-32: 1
    outputBuffer[slot] = uint(passed);
}