diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-04-22 14:04:56 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-22 20:04:56 +0000 |
| commit | ed5940a629ae05e9571bfe355d22f0728347dcb4 (patch) | |
| tree | 90a36c6543f0ee3748b80112a478897b027dddab /tests/hlsl-intrinsic | |
| parent | d5220b327632a8aeeb9a89494bb37bd82fec30cb (diff) | |
Implement shader subgroup rotate intrinsics (#6878)
* Initial implementation for SPIRV, GLSL and Metal
* test add bool test
* Fix and improve subgroup rotate tests
* Add proper GLSL extensions and proper Metal type checking
* Clean up tests and add diagnostics test for subgroup type for Metal
* Update wave-intrinsics docs
Diffstat (limited to 'tests/hlsl-intrinsic')
| -rw-r--r-- | tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang | 133 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang | 134 |
2 files changed, 267 insertions, 0 deletions
diff --git a/tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang b/tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang new file mode 100644 index 000000000..d52384c15 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang @@ -0,0 +1,133 @@ +//TEST_CATEGORY(wave, compute) +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl + +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly -xslang -DUSE_GLSL_SYNTAX -allow-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl -xslang -DUSE_GLSL_SYNTAX -allow-glsl + +#if defined(USE_GLSL_SYNTAX) +#define __clusteredRotate subgroupClusteredRotate +#else +#define __clusteredRotate WaveClusteredRotate +#endif + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +#define SUBGROUP_SIZE 32 +#define DELTA 3 +#define CLUSTER_SIZE 8 + +static uint threadIndex; +static uint clusterIndex; +static uint rotatedValue; + +__generic<T : __BuiltinArithmeticType> +bool test1ClusteredRotate() +{ + return __clusteredRotate(T(threadIndex), DELTA, CLUSTER_SIZE) == T(rotatedValue); +} + +__generic<T : __BuiltinArithmeticType, let N : int> +bool testVRClusteredRotate() +{ + typealias gvec = vector<T, N>; + +#if defined(USE_GLSL_SYNTAX) + return (__clusteredRotate(gvec(T(threadIndex)), DELTA, CLUSTER_SIZE) == gvec(T(rotatedValue))); +#else + return (__clusteredRotate(gvec(T(threadIndex)), DELTA, CLUSTER_SIZE) == gvec(T(rotatedValue)))[0]; +#endif +} + +bool test1ClusteredRotateBool() +{ + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? 
false : true; + return __clusteredRotate(currentValue, DELTA, CLUSTER_SIZE) == rotatedValueBool; +} + +__generic<let N : int> +bool testVRClusteredRotateBool() +{ + typealias gvec = vector<bool, N>; + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? false : true; + +#if defined(USE_GLSL_SYNTAX) + return (__clusteredRotate(gvec(currentValue), DELTA, CLUSTER_SIZE) == gvec(rotatedValueBool)); +#else + return (__clusteredRotate(gvec(currentValue), DELTA, CLUSTER_SIZE) == gvec(rotatedValueBool))[0]; +#endif +} + +bool testClusteredRotate() +{ + return true + & test1ClusteredRotate<float>() + & testVRClusteredRotate<float, 2>() + & testVRClusteredRotate<float, 3>() + & testVRClusteredRotate<float, 4>() + & test1ClusteredRotate<half>() + & testVRClusteredRotate<half, 2>() + & testVRClusteredRotate<half, 3>() + & testVRClusteredRotate<half, 4>() + & test1ClusteredRotate<uint>() + & testVRClusteredRotate<uint, 2>() + & testVRClusteredRotate<uint, 3>() + & testVRClusteredRotate<uint, 4>() + & test1ClusteredRotate<uint16_t>() + & testVRClusteredRotate<uint16_t, 2>() + & testVRClusteredRotate<uint16_t, 3>() + & testVRClusteredRotate<uint16_t, 4>() + & test1ClusteredRotate<int>() + & testVRClusteredRotate<int, 2>() + & testVRClusteredRotate<int, 3>() + & testVRClusteredRotate<int, 4>() + & test1ClusteredRotate<int16_t>() + & testVRClusteredRotate<int16_t, 2>() + & testVRClusteredRotate<int16_t, 3>() + & testVRClusteredRotate<int16_t, 4>() + & test1ClusteredRotate<uint8_t>() + & testVRClusteredRotate<uint8_t, 2>() + & testVRClusteredRotate<uint8_t, 3>() + & testVRClusteredRotate<uint8_t, 4>() + & test1ClusteredRotate<uint64_t>() + & testVRClusteredRotate<uint64_t, 2>() + & testVRClusteredRotate<uint64_t, 3>() + & testVRClusteredRotate<uint64_t, 4>() + & test1ClusteredRotate<int8_t>() + & testVRClusteredRotate<int8_t, 2>() + & testVRClusteredRotate<int8_t, 3>() + & testVRClusteredRotate<int8_t, 4>() + & 
test1ClusteredRotate<int64_t>() + & testVRClusteredRotate<int64_t, 2>() + & testVRClusteredRotate<int64_t, 3>() + & testVRClusteredRotate<int64_t, 4>() + & test1ClusteredRotateBool() + & testVRClusteredRotateBool<2>() + & testVRClusteredRotateBool<3>() + & testVRClusteredRotateBool<4>() + ; +} + +[shader("compute")] +[numthreads(SUBGROUP_SIZE, 1, 1)] +void computeMain(uint3 dispatchID : SV_DispatchThreadID) +{ + threadIndex = dispatchID.x; + clusterIndex = dispatchID.x % CLUSTER_SIZE; + + // Determine expected value of clustered rotate in current invocation. + // The values passed in are global invocation ids, and we rotate them within a cluster of size `CLUSTER_SIZE`. + uint clusterStart = (threadIndex / CLUSTER_SIZE) * CLUSTER_SIZE; + rotatedValue = clusterStart + ((threadIndex - clusterStart + DELTA) % CLUSTER_SIZE); + + bool result = true + & testClusteredRotate() + ; + + // CHECK: 1 + outputBuffer[0] = uint(result); +} + diff --git a/tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang b/tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang new file mode 100644 index 000000000..4b815c265 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang @@ -0,0 +1,134 @@ +//TEST_CATEGORY(wave, compute) +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-metal -compute -shaderobj -xslang -DMETAL + + +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly -xslang -DUSE_GLSL_SYNTAX -allow-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl -xslang -DUSE_GLSL_SYNTAX -allow-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-metal -compute -shaderobj -xslang -DMETAL -xslang -DUSE_GLSL_SYNTAX -allow-glsl + + +#if defined(USE_GLSL_SYNTAX) +#define __rotate subgroupRotate +#else +#define 
__rotate WaveRotate +#endif + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +#define SUBGROUP_SIZE 32 +#define DELTA 3 + +static uint threadIndex; +static uint rotatedValue; + +__generic<T : __BuiltinArithmeticType> +bool test1Rotate() +{ + return __rotate(T(threadIndex), DELTA) == T(rotatedValue); +} + +__generic<T : __BuiltinArithmeticType, let N : int> +bool testVRotate() +{ + typealias gvec = vector<T, N>; + +#if defined(USE_GLSL_SYNTAX) + return (__rotate(gvec(T(threadIndex)), DELTA) == gvec(T(rotatedValue))); +#else + return (__rotate(gvec(T(threadIndex)), DELTA) == gvec(T(rotatedValue)))[0]; +#endif +} + +bool test1RotateBool() +{ + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? false : true; + return __rotate(currentValue, DELTA) == rotatedValueBool; +} + +__generic<let N : int> +bool testVRotateBool() +{ + typealias gvec = vector<bool, N>; + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? 
false : true; + +#if defined(USE_GLSL_SYNTAX) + return (__rotate(gvec(currentValue), DELTA) == gvec(rotatedValueBool)); +#else + return (__rotate(gvec(currentValue), DELTA) == gvec(rotatedValueBool))[0]; +#endif +} + +bool testRotate() +{ + return true + & test1Rotate<float>() + & testVRotate<float, 2>() + & testVRotate<float, 3>() + & testVRotate<float, 4>() + & test1Rotate<half>() + & testVRotate<half, 2>() + & testVRotate<half, 3>() + & testVRotate<half, 4>() + & test1Rotate<uint>() + & testVRotate<uint, 2>() + & testVRotate<uint, 3>() + & testVRotate<uint, 4>() + & test1Rotate<uint16_t>() + & testVRotate<uint16_t, 2>() + & testVRotate<uint16_t, 3>() + & testVRotate<uint16_t, 4>() + & test1Rotate<int>() + & testVRotate<int, 2>() + & testVRotate<int, 3>() + & testVRotate<int, 4>() + & test1Rotate<int16_t>() + & testVRotate<int16_t, 2>() + & testVRotate<int16_t, 3>() + & testVRotate<int16_t, 4>() + + // Subgroup rotate operations on these builtin types are not supported on Metal. +#if !defined(METAL) + & test1Rotate<uint8_t>() + & testVRotate<uint8_t, 2>() + & testVRotate<uint8_t, 3>() + & testVRotate<uint8_t, 4>() + & test1Rotate<uint64_t>() + & testVRotate<uint64_t, 2>() + & testVRotate<uint64_t, 3>() + & testVRotate<uint64_t, 4>() + & test1Rotate<int8_t>() + & testVRotate<int8_t, 2>() + & testVRotate<int8_t, 3>() + & testVRotate<int8_t, 4>() + & test1Rotate<int64_t>() + & testVRotate<int64_t, 2>() + & testVRotate<int64_t, 3>() + & testVRotate<int64_t, 4>() + & test1RotateBool() + & testVRotateBool<2>() + & testVRotateBool<3>() + & testVRotateBool<4>() +#endif + ; +} + +[shader("compute")] +[numthreads(SUBGROUP_SIZE, 1, 1)] +void computeMain(uint3 dispatchID : SV_DispatchThreadID) +{ + threadIndex = dispatchID.x; + rotatedValue = (threadIndex + DELTA) % SUBGROUP_SIZE; + + bool result = true + & testRotate() + ; + + // CHECK: 1 + outputBuffer[0] = uint(result); +} + |
