diff options
| author | Harsh Aggarwal (NVIDIA) <haaggarwal@nvidia.com> | 2025-08-20 14:41:06 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-20 09:11:06 +0000 |
| commit | e0c20a076f2ec84586b6508664df4f59273c6aaf (patch) | |
| tree | ae629eb56413f1ffd1d269ffe447471c07aa8137 /prelude | |
| parent | e4a7129b84692ddc3c586f0d0dde95e80e173ed8 (diff) | |
Updated support to enable batch3 (#8219)
Enable CUDA support for batch 3 tests
- Enhanced wave operations with exclusive support
- Added proper identity values for min/max operations
- Fixed intrinsic name mapping issues
- Updated test configurations
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 295 |
1 files changed, 293 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 9df2727f6..44afd71b9 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -2661,17 +2661,67 @@ struct WaveOpMul template<typename T> struct WaveOpMax { - __inline__ __device__ static T getInitial(T a) { return a; } + __inline__ __device__ static T getInitial(T a, bool exclusive = false); __inline__ __device__ static T doOp(T a, T b) { return a > b ? a : b; } }; template<typename T> struct WaveOpMin { - __inline__ __device__ static T getInitial(T a) { return a; } + __inline__ __device__ static T getInitial(T a, bool exclusive = false); __inline__ __device__ static T doOp(T a, T b) { return a < b ? a : b; } }; +// Compact specializations using macro for getInitial +#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL) \ + template<> \ + __inline__ __device__ T WaveOpMin<T>::getInitial(T a, bool exclusive) \ + { \ + return exclusive ? (EXCL_VAL) : a; \ + } + +#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL) \ + template<> \ + __inline__ __device__ T WaveOpMax<T>::getInitial(T a, bool exclusive) \ + { \ + return exclusive ? (EXCL_VAL) : a; \ + } + +// Min specializations (exclusive identity = max value) +SLANG_WAVE_MIN_SPEC(float, SLANG_INFINITY) +SLANG_WAVE_MIN_SPEC(double, SLANG_INFINITY) +SLANG_WAVE_MIN_SPEC(int, 0x7FFFFFFF) +SLANG_WAVE_MIN_SPEC(uint, 0xFFFFFFFF) +SLANG_WAVE_MIN_SPEC(char, (char)0x7F) +SLANG_WAVE_MIN_SPEC(int8_t, (int8_t)0x7F) +SLANG_WAVE_MIN_SPEC(uint8_t, (uint8_t)0xFF) +SLANG_WAVE_MIN_SPEC(int16_t, (int16_t)0x7FFF) +SLANG_WAVE_MIN_SPEC(uint16_t, (uint16_t)0xFFFF) +SLANG_WAVE_MIN_SPEC(int64_t, 0x7FFFFFFFFFFFFFFFLL) +SLANG_WAVE_MIN_SPEC(uint64_t, 0xFFFFFFFFFFFFFFFFULL) +#if SLANG_CUDA_ENABLE_HALF +SLANG_WAVE_MIN_SPEC(__half, __ushort_as_half(0x7BFF)) +#endif + +// Max specializations (exclusive identity = min value) +SLANG_WAVE_MAX_SPEC(float, -SLANG_INFINITY) +SLANG_WAVE_MAX_SPEC(double, -SLANG_INFINITY) +SLANG_WAVE_MAX_SPEC(int, (int)0x80000000) +SLANG_WAVE_MAX_SPEC(uint, 0) +SLANG_WAVE_MAX_SPEC(char, (char)0x80) +SLANG_WAVE_MAX_SPEC(int8_t, (int8_t)0x80) +SLANG_WAVE_MAX_SPEC(uint8_t, 0) +SLANG_WAVE_MAX_SPEC(int16_t, (int16_t)0x8000) +SLANG_WAVE_MAX_SPEC(uint16_t, 0) +SLANG_WAVE_MAX_SPEC(int64_t, (int64_t)0x8000000000000000LL) +SLANG_WAVE_MAX_SPEC(uint64_t, 0) +#if SLANG_CUDA_ENABLE_HALF +SLANG_WAVE_MAX_SPEC(__half, __ushort_as_half(0xFBFF)) +#endif + +#undef SLANG_WAVE_MIN_SPEC +#undef SLANG_WAVE_MAX_SPEC + template<typename T> struct ElementTypeTrait; @@ -2706,6 +2756,33 @@ struct ElementTypeTrait<int64_t> { typedef int64_t Type; }; +template<> +struct ElementTypeTrait<char> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<uchar> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<short> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<ushort> +{ + typedef ushort Type; +}; +#if SLANG_CUDA_ENABLE_HALF +template<> +struct ElementTypeTrait<__half> +{ + typedef __half Type; +}; +#endif // Vector template<> @@ -2792,6 +2869,115 @@ struct ElementTypeTrait<double4> typedef double Type; }; +// Additional vector types +template<> +struct ElementTypeTrait<char2> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<char3> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<char4> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<uchar2> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<uchar3> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<uchar4> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<short2> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<short3> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<short4> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<ushort2> +{ + typedef ushort Type; +}; +template<> +struct ElementTypeTrait<ushort3> +{ + typedef ushort Type; +}; +template<> +struct ElementTypeTrait<ushort4> +{ + typedef ushort Type; +}; +template<> +struct ElementTypeTrait<longlong2> +{ + typedef int64_t Type; +}; +template<> +struct ElementTypeTrait<longlong3> +{ + typedef int64_t Type; +}; +template<> +struct ElementTypeTrait<longlong4> +{ + typedef int64_t Type; +}; +template<> +struct ElementTypeTrait<ulonglong2> +{ + typedef uint64_t Type; +}; +template<> +struct ElementTypeTrait<ulonglong3> +{ + typedef uint64_t Type; +}; +template<> +struct ElementTypeTrait<ulonglong4> +{ + typedef uint64_t Type; +}; +#if SLANG_CUDA_ENABLE_HALF +template<> +struct ElementTypeTrait<__half2> +{ + typedef __half Type; +}; +template<> +struct ElementTypeTrait<__half3> +{ + typedef __half Type; +}; +template<> +struct ElementTypeTrait<__half4> +{ + typedef __half Type; +}; +#endif + // Matrix template<typename T, int ROWS, int COLS> struct ElementTypeTrait<Matrix<T, ROWS, COLS>> @@ -3431,6 +3617,111 @@ __inline__ __device__ T _wavePrefixAndMultiple(WarpMask mask, T val) } template<typename T> +__inline__ __device__ T _wavePrefixMin(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpMin<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixMax(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpMax<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixMinMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpMin<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +template<typename T> +__inline__ __device__ T _wavePrefixMaxMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpMax<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +// Wrapper structures for exclusive operations that use the overloaded getInitial method +template<typename T> +struct WaveOpExclusiveMin +{ + __inline__ __device__ static T getInitial(T a) { return WaveOpMin<T>::getInitial(a, true); } + __inline__ __device__ static T doOp(T a, T b) { return WaveOpMin<T>::doOp(a, b); } +}; + +template<typename T> +struct WaveOpExclusiveMax +{ + __inline__ __device__ static T getInitial(T a) { return WaveOpMax<T>::getInitial(a, true); } + __inline__ __device__ static T doOp(T a, T b) { return WaveOpMax<T>::doOp(a, b); } +}; + +// Inclusive prefix min/max functions (for WaveMultiPrefixInclusive*) +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMin(WarpMask mask, T val) +{ + return _wavePrefixMin(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMax(WarpMask mask, T val) +{ + return _wavePrefixMax(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMinMultiple(WarpMask mask, T val) +{ + return _wavePrefixMinMultiple(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMaxMultiple(WarpMask mask, T val) +{ + return _wavePrefixMaxMultiple(mask, val); +} + +// Explicit exclusive prefix min/max functions (for WaveMultiPrefixExclusive*) +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMin(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpExclusiveMin<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMax(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpExclusiveMax<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMinMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpExclusiveMin<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMaxMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpExclusiveMax<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +template<typename T> __inline__ __device__ uint4 _waveMatchScalar(WarpMask mask, T val) { int pred; |
