diff options
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 295 |
1 files changed, 293 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 9df2727f6..44afd71b9 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -2661,17 +2661,67 @@ struct WaveOpMul template<typename T> struct WaveOpMax { - __inline__ __device__ static T getInitial(T a) { return a; } + __inline__ __device__ static T getInitial(T a, bool exclusive = false); __inline__ __device__ static T doOp(T a, T b) { return a > b ? a : b; } }; template<typename T> struct WaveOpMin { - __inline__ __device__ static T getInitial(T a) { return a; } + __inline__ __device__ static T getInitial(T a, bool exclusive = false); __inline__ __device__ static T doOp(T a, T b) { return a < b ? a : b; } }; +// Compact specializations using macro for getInitial +#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL) \ + template<> \ + __inline__ __device__ T WaveOpMin<T>::getInitial(T a, bool exclusive) \ + { \ + return exclusive ? (EXCL_VAL) : a; \ + } + +#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL) \ + template<> \ + __inline__ __device__ T WaveOpMax<T>::getInitial(T a, bool exclusive) \ + { \ + return exclusive ? (EXCL_VAL) : a; \ + } + +// Min specializations (exclusive identity = max value) +SLANG_WAVE_MIN_SPEC(float, SLANG_INFINITY) +SLANG_WAVE_MIN_SPEC(double, SLANG_INFINITY) +SLANG_WAVE_MIN_SPEC(int, 0x7FFFFFFF) +SLANG_WAVE_MIN_SPEC(uint, 0xFFFFFFFF) +SLANG_WAVE_MIN_SPEC(char, (char)0x7F) +SLANG_WAVE_MIN_SPEC(int8_t, (int8_t)0x7F) +SLANG_WAVE_MIN_SPEC(uint8_t, (uint8_t)0xFF) +SLANG_WAVE_MIN_SPEC(int16_t, (int16_t)0x7FFF) +SLANG_WAVE_MIN_SPEC(uint16_t, (uint16_t)0xFFFF) +SLANG_WAVE_MIN_SPEC(int64_t, 0x7FFFFFFFFFFFFFFFLL) +SLANG_WAVE_MIN_SPEC(uint64_t, 0xFFFFFFFFFFFFFFFFULL) +#if SLANG_CUDA_ENABLE_HALF +SLANG_WAVE_MIN_SPEC(__half, __ushort_as_half(0x7BFF)) +#endif + +// Max specializations (exclusive identity = min value) +SLANG_WAVE_MAX_SPEC(float, -SLANG_INFINITY) +SLANG_WAVE_MAX_SPEC(double, -SLANG_INFINITY) +SLANG_WAVE_MAX_SPEC(int, (int)0x80000000) +SLANG_WAVE_MAX_SPEC(uint, 0) +SLANG_WAVE_MAX_SPEC(char, (char)0x80) +SLANG_WAVE_MAX_SPEC(int8_t, (int8_t)0x80) +SLANG_WAVE_MAX_SPEC(uint8_t, 0) +SLANG_WAVE_MAX_SPEC(int16_t, (int16_t)0x8000) +SLANG_WAVE_MAX_SPEC(uint16_t, 0) +SLANG_WAVE_MAX_SPEC(int64_t, (int64_t)0x8000000000000000LL) +SLANG_WAVE_MAX_SPEC(uint64_t, 0) +#if SLANG_CUDA_ENABLE_HALF +SLANG_WAVE_MAX_SPEC(__half, __ushort_as_half(0xFBFF)) +#endif + +#undef SLANG_WAVE_MIN_SPEC +#undef SLANG_WAVE_MAX_SPEC + template<typename T> struct ElementTypeTrait; @@ -2706,6 +2756,33 @@ struct ElementTypeTrait<int64_t> { typedef int64_t Type; }; +template<> +struct ElementTypeTrait<char> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<uchar> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<short> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<ushort> +{ + typedef ushort Type; +}; +#if SLANG_CUDA_ENABLE_HALF +template<> +struct ElementTypeTrait<__half> +{ + typedef __half Type; +}; +#endif // Vector template<> @@ -2792,6 +2869,115 @@ struct ElementTypeTrait<double4> typedef double Type; }; +// Additional vector types +template<> +struct ElementTypeTrait<char2> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<char3> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<char4> +{ + typedef char Type; +}; +template<> +struct ElementTypeTrait<uchar2> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<uchar3> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<uchar4> +{ + typedef uchar Type; +}; +template<> +struct ElementTypeTrait<short2> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<short3> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<short4> +{ + typedef short Type; +}; +template<> +struct ElementTypeTrait<ushort2> +{ + typedef ushort Type; +}; +template<> +struct ElementTypeTrait<ushort3> +{ + typedef ushort Type; +}; +template<> +struct ElementTypeTrait<ushort4> +{ + typedef ushort Type; +}; +template<> +struct ElementTypeTrait<longlong2> +{ + typedef int64_t Type; +}; +template<> +struct ElementTypeTrait<longlong3> +{ + typedef int64_t Type; +}; +template<> +struct ElementTypeTrait<longlong4> +{ + typedef int64_t Type; +}; +template<> +struct ElementTypeTrait<ulonglong2> +{ + typedef uint64_t Type; +}; +template<> +struct ElementTypeTrait<ulonglong3> +{ + typedef uint64_t Type; +}; +template<> +struct ElementTypeTrait<ulonglong4> +{ + typedef uint64_t Type; +}; +#if SLANG_CUDA_ENABLE_HALF +template<> +struct ElementTypeTrait<__half2> +{ + typedef __half Type; +}; +template<> +struct ElementTypeTrait<__half3> +{ + typedef __half Type; +}; +template<> +struct ElementTypeTrait<__half4> +{ + typedef __half Type; +}; +#endif + // Matrix template<typename T, int ROWS, int COLS> struct ElementTypeTrait<Matrix<T, ROWS, COLS>> @@ -3431,6 +3617,111 @@ __inline__ __device__ T _wavePrefixAndMultiple(WarpMask mask, T val) } template<typename T> +__inline__ __device__ T _wavePrefixMin(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpMin<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixMax(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpMax<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixMinMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpMin<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +template<typename T> +__inline__ __device__ T _wavePrefixMaxMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpMax<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +// Wrapper structures for exclusive operations that use the overloaded getInitial method +template<typename T> +struct WaveOpExclusiveMin +{ + __inline__ __device__ static T getInitial(T a) { return WaveOpMin<T>::getInitial(a, true); } + __inline__ __device__ static T doOp(T a, T b) { return WaveOpMin<T>::doOp(a, b); } +}; + +template<typename T> +struct WaveOpExclusiveMax +{ + __inline__ __device__ static T getInitial(T a) { return WaveOpMax<T>::getInitial(a, true); } + __inline__ __device__ static T doOp(T a, T b) { return WaveOpMax<T>::doOp(a, b); } +}; + +// Inclusive prefix min/max functions (for WaveMultiPrefixInclusive*) +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMin(WarpMask mask, T val) +{ + return _wavePrefixMin(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMax(WarpMask mask, T val) +{ + return _wavePrefixMax(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMinMultiple(WarpMask mask, T val) +{ + return _wavePrefixMinMultiple(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixInclusiveMaxMultiple(WarpMask mask, T val) +{ + return _wavePrefixMaxMultiple(mask, val); +} + +// Explicit exclusive prefix min/max functions (for WaveMultiPrefixExclusive*) +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMin(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpExclusiveMin<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMax(WarpMask mask, T val) +{ + return _wavePrefixScalar<WaveOpExclusiveMax<T>, T>(mask, val); +} + +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMinMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpExclusiveMin<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +template<typename T> +__inline__ __device__ T _wavePrefixExclusiveMaxMultiple(WarpMask mask, T val) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpExclusiveMax<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>( + mask, + (ElemType*)&val); + return val; +} + +template<typename T> __inline__ __device__ uint4 _waveMatchScalar(WarpMask mask, T val) { int pred; |
