summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-cuda-prelude.h139
-rw-r--r--source/slang/hlsl.meta.slang103
-rw-r--r--source/slang/slang-profile-defs.h2
-rw-r--r--tests/hlsl-intrinsic/wave-multi-prefix.slang26
-rw-r--r--tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt8
5 files changed, 206 insertions, 72 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 6a1d87183..dcc585b9c 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -513,6 +513,18 @@ __forceinline__ __device__ int _getLaneLtMask()
return (int(1) << _getLaneId()) - 1;
}
+// Mask used by the plain 'Prefix' wave ops: the lanes __activemask() reports at the call site.
+__forceinline__ __device__ int _getPrefixMask()
+{
+    const int activeLanes = __activemask();
+    return activeLanes;
+}
+
+// Mask used by the 'MultiPrefix' wave ops: the caller-supplied lane mask, passed through unchanged.
+__forceinline__ __device__ int _getMultiPrefixMask(int mask)
+{
+    const int laneMask = mask;
+    return laneMask;
+}
+
// Note! This will return true if mask is 0, but that's okay, because there must be one
// lane active to execute anything
__inline__ __device__ bool _waveIsSingleLane(int mask)
@@ -671,9 +683,9 @@ __device__ T _waveReduceScalar(T val)
while (remaining)
{
const int laneBit = remaining & -remaining;
- /* Get the sourceLane */
+ // Get the sourceLane
const int srcLane = __ffs(laneBit) - 1;
- /* Broadcast (can also broadcast to self) */
+ // Broadcast (can also broadcast to self)
result = INTF::doOp(result, __shfl_sync(mask, val, srcLane));
remaining &= ~laneBit;
}
@@ -718,9 +730,9 @@ __device__ void _waveReduceMultiple(T* val)
while (remaining)
{
const int laneBit = remaining & -remaining;
- /* Get the sourceLane */
+ // Get the sourceLane
const int srcLane = __ffs(laneBit) - 1;
- /* Broadcast (can also broadcast to self) */
+ // Broadcast (can also broadcast to self)
for (size_t i = 0; i < COUNT; ++i)
{
val[i] = INTF::doOp(val[i], __shfl_sync(mask, originalVal[i], srcLane));
@@ -786,7 +798,7 @@ __inline__ __device__ T _waveMaxMultiple(T val) { typedef typename ElementTypeT
template <typename T>
__inline__ __device__ bool _waveAllEqual(T val)
{
- // __match_all_sync is a synchronises so can use __activemask()
+ // __match_all_sync synchronizes so can use __activemask()
const int mask = __activemask();
int pred;
__match_all_sync(mask, val, &pred);
@@ -798,13 +810,10 @@ __inline__ __device__ bool _waveAllEqualMultiple(T inVal)
{
typedef typename ElementTypeTrait<T>::Type ElemType;
const size_t count = sizeof(T) / sizeof(ElemType);
-
- // __match_all_sync is a synchronises so can use __activemask()
+ // __match_all_sync synchronizes so can use __activemask()
const int mask = __activemask();
int pred;
-
const ElemType* src = (const ElemType*)&inVal;
-
for (size_t i = 0; i < count; ++i)
{
__match_all_sync(mask, src[i], &pred);
@@ -829,20 +838,15 @@ __inline__ __device__ T _waveReadFirstMultiple(T inVal)
{
typedef typename ElementTypeTrait<T>::Type ElemType;
const size_t count = sizeof(T) / sizeof(ElemType);
-
T outVal;
-
const ElemType* src = (const ElemType*)&inVal;
ElemType* dst = (ElemType*)&outVal;
-
const int mask = __activemask();
const int lowestLaneId = __ffs(mask) - 1;
-
for (size_t i = 0; i < count; ++i)
{
dst[i] = __shfl_sync(mask, src[i], lowestLaneId);
}
-
return outVal;
}
@@ -851,19 +855,14 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane)
{
typedef typename ElementTypeTrait<T>::Type ElemType;
const size_t count = sizeof(T) / sizeof(ElemType);
-
T outVal;
-
const ElemType* src = (const ElemType*)&inVal;
ElemType* dst = (ElemType*)&outVal;
-
const int mask = __activemask();
-
for (size_t i = 0; i < count; ++i)
{
dst[i] = __shfl_sync(mask, src[i], lane);
}
-
return outVal;
}
@@ -872,9 +871,8 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane)
// Invertable means that when we get to the end of the reduce, we can remove val (to make exclusive), using
// the inverse of the op.
template <typename INTF, typename T>
-__device__ T _wavePrefixInvertableScalar(T val)
+__device__ T _wavePrefixInvertableScalar(T val, const int mask)
{
- const int mask = __activemask();
const int offsetSize = _waveCalcPow2Offset(mask);
const int laneId = _getLaneId();
@@ -923,9 +921,8 @@ __device__ T _wavePrefixInvertableScalar(T val)
// This implementation separately tracks the value to be propagated, and the value
// that is the final result
template <typename INTF, typename T>
-__device__ T _wavePrefixScalar(T val)
+__device__ T _wavePrefixScalar(T val, const int mask)
{
- const int mask = __activemask();
const int offsetSize = _waveCalcPow2Offset(mask);
const int laneId = _getLaneId();
@@ -971,7 +968,7 @@ __device__ T _wavePrefixScalar(T val)
template <typename INTF, typename T, size_t COUNT>
-__device__ T _copy(T* dst, const T* src)
+__device__ T _waveOpCopy(T* dst, const T* src)
{
for (size_t j = 0; j < COUNT; ++j)
{
@@ -981,7 +978,7 @@ __device__ T _copy(T* dst, const T* src)
template <typename INTF, typename T, size_t COUNT>
-__device__ T _doInverse(T* inOut, const T* val)
+__device__ T _waveOpDoInverse(T* inOut, const T* val)
{
for (size_t j = 0; j < COUNT; ++j)
{
@@ -990,7 +987,7 @@ __device__ T _doInverse(T* inOut, const T* val)
}
template <typename INTF, typename T, size_t COUNT>
-__device__ T _setInitial(T* out, const T* val)
+__device__ T _waveOpSetInitial(T* out, const T* val)
{
for (size_t j = 0; j < COUNT; ++j)
{
@@ -999,14 +996,13 @@ __device__ T _setInitial(T* out, const T* val)
}
template <typename INTF, typename T, size_t COUNT>
-__device__ T _wavePrefixInvertableMultiple(T* val)
+__device__ T _wavePrefixInvertableMultiple(T* val, const int mask)
{
- const int mask = __activemask();
const int offsetSize = _waveCalcPow2Offset(mask);
const int laneId = _getLaneId();
T originalVal[COUNT];
- _copy<INTF, T, COUNT>(originalVal, val);
+ _waveOpCopy<INTF, T, COUNT>(originalVal, val);
if (offsetSize > 0)
{
@@ -1027,11 +1023,11 @@ __device__ T _wavePrefixInvertableMultiple(T* val)
}
}
// Remove originalVal from the result, by applying inverse
- _doInverse<INTF, T, COUNT>(val, originalVal);
+ _waveOpDoInverse<INTF, T, COUNT>(val, originalVal);
}
else
{
- _setInitial<INTF, T, COUNT>(val, val);
+ _waveOpSetInitial<INTF, T, COUNT>(val, val);
if (!_waveIsSingleLane(mask))
{
int remaining = mask;
@@ -1058,16 +1054,15 @@ __device__ T _wavePrefixInvertableMultiple(T* val)
}
template <typename INTF, typename T, size_t COUNT>
-__device__ T _wavePrefixMultiple(T* val)
+__device__ T _wavePrefixMultiple(T* val, const int mask)
{
- const int mask = __activemask();
const int offsetSize = _waveCalcPow2Offset(mask);
const int laneId = _getLaneId();
T work[COUNT];
- _copy<INTF, T, COUNT>(work, val);
- _setInitial<INTF, T, COUNT>(val, val);
+ _waveOpCopy<INTF, T, COUNT>(work, val);
+ _waveOpSetInitial<INTF, T, COUNT>(val, val);
if (offsetSize > 0)
{
@@ -1113,29 +1108,89 @@ __device__ T _wavePrefixMultiple(T* val)
}
}
}
-
+
+// Prefix product of `val` over the lanes in `mask`, computed by _wavePrefixScalar with WaveOpMul.
+// When `mask` is omitted it defaults to _getPrefixMask().
template <typename T>
-__inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar<WaveOpMul<T>, T>(val); }
+__inline__ __device__ T _wavePrefixProduct(T val, const int mask = _getPrefixMask()) { return _wavePrefixScalar<WaveOpMul<T>, T>(val, mask); }
+// Prefix sum of `val` over the lanes in `mask`, computed by _wavePrefixInvertableScalar with WaveOpAdd.
+// When `mask` is omitted it defaults to _getPrefixMask().
template <typename T>
-__inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val); }
+__inline__ __device__ T _wavePrefixSum(T val, const int mask = _getPrefixMask()) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val, mask); }
+
+// Prefix XOR of `val` over the lanes in `mask` (defaults to _getPrefixMask()),
+// computed through the invertable scalar path with WaveOpXor.
+template <typename T>
+__inline__ __device__ T _wavePrefixXor(T val, const int mask = _getPrefixMask())
+{
+    typedef WaveOpXor<T> Op;
+    return _wavePrefixInvertableScalar<Op, T>(val, mask);
+}
+
+// Prefix OR of `val` over the lanes in `mask` (defaults to _getPrefixMask()),
+// computed through _wavePrefixScalar with WaveOpOr.
+template <typename T>
+__inline__ __device__ T _wavePrefixOr(T val, const int mask = _getPrefixMask())
+{
+    typedef WaveOpOr<T> Op;
+    return _wavePrefixScalar<Op, T>(val, mask);
+}
+
+// Prefix AND of `val` over the lanes in `mask` (defaults to _getPrefixMask()),
+// computed through _wavePrefixScalar with WaveOpAnd.
+template <typename T>
+__inline__ __device__ T _wavePrefixAnd(T val, const int mask = _getPrefixMask())
+{
+    typedef WaveOpAnd<T> Op;
+    return _wavePrefixScalar<Op, T>(val, mask);
+}
+
+// Prefix product for vector/matrix values. `val` is handed to the helper as a flat array of
+// sizeof(T) / sizeof(ElemType) elements; `mask` selects the participating lanes and defaults
+// to _getPrefixMask().
+// Uses the non-invertable _wavePrefixMultiple, matching the scalar _wavePrefixProduct. The
+// invertable path removes the lane's own value with the op's inverse (presumably a divide for
+// WaveOpMul - confirm), which is not safe for a product when an element is zero.
template <typename T>
-__inline__ __device__ T _wavePrefixProductMultiple(T val)
+__inline__ __device__ T _wavePrefixProductMultiple(T val, const int mask = _getPrefixMask())
{
typedef typename ElementTypeTrait<T>::Type ElemType;
- _wavePrefixInvertableMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val);
+ _wavePrefixMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask);
return val;
}
+// Prefix sum for vector/matrix values. `val` is handed to _wavePrefixInvertableMultiple as a flat
+// array of sizeof(T) / sizeof(ElemType) elements; `mask` defaults to _getPrefixMask().
template <typename T>
-__inline__ __device__ T _wavePrefixSumMultiple(T val)
+__inline__ __device__ T _wavePrefixSumMultiple(T val, const int mask = _getPrefixMask())
{
typedef typename ElementTypeTrait<T>::Type ElemType;
- _wavePrefixMultiple<WaveOpAdd<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val);
+ _wavePrefixInvertableMultiple<WaveOpAdd<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask);
return val;
}
+// Prefix XOR for vector/matrix values. `val` is passed to _wavePrefixInvertableMultiple as a flat
+// array of its element type; `mask` defaults to _getPrefixMask().
+template <typename T>
+__inline__ __device__ T _wavePrefixXorMultiple(T val, const int mask = _getPrefixMask())
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    const size_t elemCount = sizeof(T) / sizeof(ElemType);
+    ElemType* elems = (ElemType*)&val;
+    _wavePrefixInvertableMultiple<WaveOpXor<ElemType>, ElemType, elemCount>(elems, mask);
+    return val;
+}
+
+// Prefix OR for vector/matrix values. `val` is passed to _wavePrefixMultiple as a flat
+// array of its element type; `mask` defaults to _getPrefixMask().
+template <typename T>
+__inline__ __device__ T _wavePrefixOrMultiple(T val, const int mask = _getPrefixMask())
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    const size_t elemCount = sizeof(T) / sizeof(ElemType);
+    ElemType* elems = (ElemType*)&val;
+    _wavePrefixMultiple<WaveOpOr<ElemType>, ElemType, elemCount>(elems, mask);
+    return val;
+}
+
+// Prefix AND for vector/matrix values. `val` is passed to _wavePrefixMultiple as a flat
+// array of its element type; `mask` defaults to _getPrefixMask().
+template <typename T>
+__inline__ __device__ T _wavePrefixAndMultiple(T val, const int mask = _getPrefixMask())
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    const size_t elemCount = sizeof(T) / sizeof(ElemType);
+    ElemType* elems = (ElemType*)&val;
+    _wavePrefixMultiple<WaveOpAnd<ElemType>, ElemType, elemCount>(elems, mask);
+    return val;
+}
+
+// CUDA implementation of WaveMatch for scalar values.
+// Returns, in .x, the bitmask of lanes (among those reported by __activemask()) whose `val`
+// equals the calling lane's `val`; the remaining components are 0.
+// Requires the warp match intrinsics (compute capability 7.0+).
+template <typename T>
+__inline__ __device__ uint4 _waveMatchScalar(T val)
+{
+    // __match_any_sync synchronizes so can use __activemask()
+    const int mask = __activemask();
+    // Must be match-any, not match-all: __match_all_sync yields 0 as soon as any two lanes
+    // differ, whereas WaveMatch has to report the lanes that agree with this one.
+    return make_uint4(__match_any_sync(mask, val), 0, 0, 0);
+}
+
+// CUDA implementation of WaveMatch for vector/matrix values.
+// `inVal` is viewed as sizeof(T) / sizeof(ElemType) elements. A lane is reported in the result
+// only if it matches the calling lane on every element, so the per-element match masks are
+// ANDed together. The mask is returned in .x; the remaining components are 0.
+// Requires the warp match intrinsics (compute capability 7.0+).
+template <typename T>
+__inline__ __device__ uint4 _waveMatchMultiple(const T& inVal)
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    const size_t count = sizeof(T) / sizeof(ElemType);
+    // __match_any_sync synchronizes so can use __activemask()
+    const int mask = __activemask();
+    const ElemType* src = (const ElemType*)&inVal;
+    uint matchBits = 0xffffffff;
+    // No early-out on matchBits: a lane always matches itself so the mask never reaches 0,
+    // and every lane in `mask` has to execute each __match_any_sync anyway.
+    for (size_t i = 0; i < count; ++i)
+    {
+        matchBits = matchBits & __match_any_sync(mask, src[i]);
+    }
+    return make_uint4(matchBits, 0, 0, 0);
+}
+
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 30c86b3eb..6ac1038f8 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2628,9 +2628,7 @@ __generic<T : __BuiltinType, let N : int, let M : int>
__target_intrinsic(cuda, "_waveAllEqualMultiple($0)")
bool WaveActiveAllEqual(matrix<T,N,M> value);
-__generic<T : __BuiltinType> uint4 WaveMatch(T value);
-__generic<T : __BuiltinType, let N : int> uint4 WaveMatch(vector<T,N> value);
-__generic<T : __BuiltinType, let N : int, let M : int> uint4 WaveMatch(matrix<T,N,M> value);
+
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
@@ -2650,11 +2648,9 @@ __target_intrinsic(glsl, "subgroupBallot($0)")
__target_intrinsic(cuda, "make_uint4(__ballot_sync(__activemask(), $0), 0, 0, 0)")
uint4 WaveActiveBallot(bool condition);
-// TODO(JS):
-// subgroupBallotBitCount seems to take a uint4 parameter.
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBallotBitCount($0)")
+__target_intrinsic(glsl, "bitCount(subgroupBallot($0))")
__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0))")
uint WaveActiveCountBits(bool value);
@@ -2751,58 +2747,105 @@ uint WavePrefixCountBits(bool value);
// https://github.com/microsoft/DirectX-Specs/blob/master/d3d/HLSL_ShaderModel6_5.md
// TODO(JS): Looks like they need a mask parameter
+__generic<T : __BuiltinType>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_waveMatchScalar($0)")
+uint4 WaveMatch(T value);
+__generic<T : __BuiltinType, let N : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_waveMatchMultiple($0)")
+uint4 WaveMatch(vector<T,N> value);
+__generic<T : __BuiltinType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_waveMatchMultiple($0)")
+uint4 WaveMatch(matrix<T,N,M> value);
+
+// Counts the lanes below the calling lane, among those in mask.x, for which `value` is true.
+// On CUDA: ballot over mask.x, keep only the bits of lower lanes, then population-count.
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "__popc(__ballot_sync(($1).x, $0) & _getLaneLtMask())")
+uint WaveMultiPrefixCountBits(bool value, uint4 mask);
+
__generic<T : __BuiltinArithmeticType>
+__target_intrinsic(hlsl)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
-__target_intrinsic(cuda, "_wavePrefixAnd($0)")
-T WaveMultiPrefixBitAnd(T expr);
+//__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
+__target_intrinsic(cuda, "_wavePrefixAnd($0, _getMultiPrefixMask(($1).x))")
+T WaveMultiPrefixBitAnd(T expr, uint4 mask);
+__target_intrinsic(hlsl)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
+__target_intrinsic(cuda, "_wavePrefixAndMultiple($0, _getMultiPrefixMask(($1).x))")
__generic<T : __BuiltinArithmeticType, let N : int>
-vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr);
+vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr, uint4 mask);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr);
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixAndMultiple($0, _getMultiPrefixMask(($1).x))")
+matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr, uint4 mask);
__generic<T : __BuiltinArithmeticType>
+__target_intrinsic(hlsl)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
-__target_intrinsic(cuda, "_wavePrefixOr($0)")
-T WaveMultiPrefixBitOr(T expr);
+//__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
+__target_intrinsic(cuda, "_wavePrefixOr($0, _getMultiPrefixMask(($1).x))")
+T WaveMultiPrefixBitOr(T expr, uint4 mask);
__generic<T : __BuiltinArithmeticType, let N : int>
+__target_intrinsic(hlsl)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
-vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr);
+//__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
+__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, _getMultiPrefixMask(($1).x))")
+vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr, uint4 mask);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr);
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, _getMultiPrefixMask(($1).x))")
+matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr, uint4 mask);
__generic<T : __BuiltinArithmeticType>
+__target_intrinsic(hlsl)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
-__target_intrinsic(cuda, "_wavePrefixXor($0)")
-T WaveMultiPrefixBitXor(T expr);
+__target_intrinsic(cuda, "_wavePrefixXor($0, _getMultiPrefixMask(($1).x))")
+T WaveMultiPrefixBitXor(T expr, uint4 mask);
__generic<T : __BuiltinArithmeticType, let N : int>
+__target_intrinsic(hlsl)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
-vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr);
+__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, _getMultiPrefixMask(($1).x))")
+vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr, uint4 mask);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr);
-
-
-uint WaveMultiPrefixCountBits(bool value, uint4 mask);
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, _getMultiPrefixMask(($1).x))")
+matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr, uint4 mask);
-__generic<T : __BuiltinArithmeticType> T WaveMultiPrefixProduct(T value, uint4 mask);
-__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> WaveMultiPrefixProduct(vector<T,N> value, uint4 mask);
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> WaveMultiPrefixProduct(matrix<T,N,M> value, uint4 mask);
+__generic<T : __BuiltinArithmeticType>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixProduct($0, _getMultiPrefixMask(($1).x))")
+T WaveMultiPrefixProduct(T value, uint4 mask);
+__generic<T : __BuiltinArithmeticType, let N : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, _getMultiPrefixMask(($1).x))")
+vector<T,N> WaveMultiPrefixProduct(vector<T,N> value, uint4 mask);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, _getMultiPrefixMask(($1).x))")
+matrix<T,N,M> WaveMultiPrefixProduct(matrix<T,N,M> value, uint4 mask);
-__generic<T : __BuiltinArithmeticType> T WaveMultiPrefixSum(T value, uint4 mask);
-__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask);
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask);
+__generic<T : __BuiltinArithmeticType>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixSum($0, _getMultiPrefixMask(($1).x))")
+T WaveMultiPrefixSum(T value, uint4 mask);
+__generic<T : __BuiltinArithmeticType, let N : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, _getMultiPrefixMask(($1).x))")
+vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, _getMultiPrefixMask(($1).x))")
+matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask);
// `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points
typedef Texture2D texture2D;
diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h
index fc2722160..7066b5942 100644
--- a/source/slang/slang-profile-defs.h
+++ b/source/slang/slang-profile-defs.h
@@ -129,6 +129,8 @@ PROFILE(DX_Compute_6_0, cs_6_0, Compute, DX_6_0)
PROFILE(DX_Compute_6_1, cs_6_1, Compute, DX_6_1)
PROFILE(DX_Compute_6_2, cs_6_2, Compute, DX_6_2)
PROFILE(DX_Compute_6_3, cs_6_3, Compute, DX_6_3)
+PROFILE(DX_Compute_6_4, cs_6_4, Compute, DX_6_4)
+PROFILE(DX_Compute_6_5, cs_6_5, Compute, DX_6_5)
PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0)
PROFILE(DX_Domain_5_1, ds_5_1, Domain, DX_5_1)
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang b/tests/hlsl-intrinsic/wave-multi-prefix.slang
new file mode 100644
index 000000000..3eee16e31
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang
@@ -0,0 +1,26 @@
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
+// We need SM6.5 for these tests
+// Disable because version of dxc we are currently using doesn't support SM6.5
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5
+// Disabled because we don't have GLSL intrinsics for these it seems
+//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(8, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+    // One output slot per thread in the group.
+    int laneIndex = int(dispatchThreadID.x);
+
+    // Every thread passes the same value to WaveMatch; the result is used as the lane mask below.
+    uint4 laneMask = WaveMatch(true);
+
+    // Scalar: each thread contributes a distinct bit (1 << index) to the multi-prefix sum.
+    int prefix = WaveMultiPrefixSum(1 << laneIndex, laneMask);
+
+    outputBuffer[laneIndex] = prefix;
+} \ No newline at end of file
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt
new file mode 100644
index 000000000..6ec6deeea
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt
@@ -0,0 +1,8 @@
+0
+1
+3
+7
+F
+1F
+3F
+7F