From ed5940a629ae05e9571bfe355d22f0728347dcb4 Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Tue, 22 Apr 2025 14:04:56 -0600 Subject: Implement shader subgroup rotate intrinsics (#6878) * Initial implementation for SPIRV, GLSL and Metal * test add bool test * Fix and improve subgroup rotate tests * Add proper GLSL extensions and proper Metal type checking * Clean up tests and add diagnostics test for subgroup type for Metal * Update wave-intrinsics docs --- tests/diagnostics/wave-operations-types.slang | 14 +++ .../wave-rotate/wave-rotate-clustered.slang | 133 ++++++++++++++++++++ tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang | 134 +++++++++++++++++++++ 3 files changed, 281 insertions(+) create mode 100644 tests/diagnostics/wave-operations-types.slang create mode 100644 tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang create mode 100644 tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang (limited to 'tests') diff --git a/tests/diagnostics/wave-operations-types.slang b/tests/diagnostics/wave-operations-types.slang new file mode 100644 index 000000000..55a6a8e91 --- /dev/null +++ b/tests/diagnostics/wave-operations-types.slang @@ -0,0 +1,14 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target metal + +RWStructuredBuffer out; + +[shader("compute")] +void computeMain(uint3 dispatchID : SV_DispatchThreadID) +{ + // CHECK: Unsupported type for subgroup operations in Metal. Valid types include + // CHECK: Unsupported type for subgroup operations in Metal. Valid types include + // CHECK: Unsupported type for subgroup operations in Metal. 
Valid types include + out[0] = WaveRotate(true, 1); + out[1] = WaveRotate(uint8_t(dispatchID.x), 1); + out[2] = WaveRotate(uint64_t(dispatchID.x), 1); +} diff --git a/tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang b/tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang new file mode 100644 index 000000000..d52384c15 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-rotate/wave-rotate-clustered.slang @@ -0,0 +1,133 @@ +//TEST_CATEGORY(wave, compute) +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl + +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly -xslang -DUSE_GLSL_SYNTAX -allow-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl -xslang -DUSE_GLSL_SYNTAX -allow-glsl + +#if defined(USE_GLSL_SYNTAX) +#define __clusteredRotate subgroupClusteredRotate +#else +#define __clusteredRotate WaveClusteredRotate +#endif + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +#define SUBGROUP_SIZE 32 +#define DELTA 3 +#define CLUSTER_SIZE 8 + +static uint threadIndex; +static uint clusterIndex; +static uint rotatedValue; + +__generic +bool test1ClusteredRotate() +{ + return __clusteredRotate(T(threadIndex), DELTA, CLUSTER_SIZE) == T(rotatedValue); +} + +__generic +bool testVRClusteredRotate() +{ + typealias gvec = vector; + +#if defined(USE_GLSL_SYNTAX) + return (__clusteredRotate(gvec(T(threadIndex)), DELTA, CLUSTER_SIZE) == gvec(T(rotatedValue))); +#else + return (__clusteredRotate(gvec(T(threadIndex)), DELTA, CLUSTER_SIZE) == gvec(T(rotatedValue)))[0]; +#endif +} + +bool test1ClusteredRotateBool() +{ + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? 
false : true; + return __clusteredRotate(currentValue, DELTA, CLUSTER_SIZE) == rotatedValueBool; +} + +__generic +bool testVRClusteredRotateBool() +{ + typealias gvec = vector; + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? false : true; + +#if defined(USE_GLSL_SYNTAX) + return (__clusteredRotate(gvec(currentValue), DELTA, CLUSTER_SIZE) == gvec(rotatedValueBool)); +#else + return (__clusteredRotate(gvec(currentValue), DELTA, CLUSTER_SIZE) == gvec(rotatedValueBool))[0]; +#endif +} + +bool testClusteredRotate() +{ + return true + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & testVRClusteredRotate() + & test1ClusteredRotateBool() + & testVRClusteredRotateBool<2>() + & testVRClusteredRotateBool<3>() + & testVRClusteredRotateBool<4>() + ; +} + +[shader("compute")] +[numthreads(SUBGROUP_SIZE, 1, 1)] +void computeMain(uint3 dispatchID : SV_DispatchThreadID) +{ + threadIndex = dispatchID.x; 
+ clusterIndex = dispatchID.x % CLUSTER_SIZE; + + // Determine expected value of clustered rotate in current invocation. + // The values passed in are global invocation ids, and we rotate them within a cluster of size `CLUSTER_SIZE`. + uint clusterStart = (threadIndex / CLUSTER_SIZE) * CLUSTER_SIZE; + rotatedValue = clusterStart + ((threadIndex - clusterStart + DELTA) % CLUSTER_SIZE); + + bool result = true + & testClusteredRotate() + ; + + // CHECK: 1 + outputBuffer[0] = uint(result); +} + diff --git a/tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang b/tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang new file mode 100644 index 000000000..4b815c265 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-rotate/wave-rotate.slang @@ -0,0 +1,134 @@ +//TEST_CATEGORY(wave, compute) +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-metal -compute -shaderobj -xslang -DMETAL + + +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-directly -xslang -DUSE_GLSL_SYNTAX -allow-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -emit-spirv-via-glsl -xslang -DUSE_GLSL_SYNTAX -allow-glsl +//TEST:COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-metal -compute -shaderobj -xslang -DMETAL -xslang -DUSE_GLSL_SYNTAX -allow-glsl + + +#if defined(USE_GLSL_SYNTAX) +#define __rotate subgroupRotate +#else +#define __rotate WaveRotate +#endif + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +#define SUBGROUP_SIZE 32 +#define DELTA 3 + +static uint threadIndex; +static uint rotatedValue; + +__generic +bool test1Rotate() +{ + return __rotate(T(threadIndex), DELTA) == T(rotatedValue); +} + +__generic +bool testVRotate() +{ + typealias gvec = vector; + +#if defined(USE_GLSL_SYNTAX) + return 
(__rotate(gvec(T(threadIndex)), DELTA) == gvec(T(rotatedValue))); +#else + return (__rotate(gvec(T(threadIndex)), DELTA) == gvec(T(rotatedValue)))[0]; +#endif +} + +bool test1RotateBool() +{ + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? false : true; + return __rotate(currentValue, DELTA) == rotatedValueBool; +} + +__generic +bool testVRotateBool() +{ + typealias gvec = vector; + bool currentValue = (threadIndex % 2 == 0) ? true : false; + bool rotatedValueBool = (threadIndex % 2 == 0) ? false : true; + +#if defined(USE_GLSL_SYNTAX) + return (__rotate(gvec(currentValue), DELTA) == gvec(rotatedValueBool)); +#else + return (__rotate(gvec(currentValue), DELTA) == gvec(rotatedValueBool))[0]; +#endif +} + +bool testRotate() +{ + return true + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + + // Subgroup rotate operations on these builtin types are not supported on Metal. 
+#if !defined(METAL) + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1Rotate() + & testVRotate() + & testVRotate() + & testVRotate() + & test1RotateBool() + & testVRotateBool<2>() + & testVRotateBool<3>() + & testVRotateBool<4>() +#endif + ; +} + +[shader("compute")] +[numthreads(SUBGROUP_SIZE, 1, 1)] +void computeMain(uint3 dispatchID : SV_DispatchThreadID) +{ + threadIndex = dispatchID.x; + rotatedValue = (threadIndex + DELTA) % SUBGROUP_SIZE; + + bool result = true + & testRotate() + ; + + // CHECK: 1 + outputBuffer[0] = uint(result); +} + -- cgit v1.2.3