summaryrefslogtreecommitdiff
path: root/prelude
diff options
context:
space:
mode:
authorHarsh Aggarwal (NVIDIA) <haaggarwal@nvidia.com>2025-08-20 14:41:06 +0530
committerGitHub <noreply@github.com>2025-08-20 09:11:06 +0000
commite0c20a076f2ec84586b6508664df4f59273c6aaf (patch)
treeae629eb56413f1ffd1d269ffe447471c07aa8137 /prelude
parente4a7129b84692ddc3c586f0d0dde95e80e173ed8 (diff)
Updated support to enable batch3 (#8219)
Enable CUDA support for batch 3 tests - Enhanced wave operations with exclusive support - Added proper identity values for min/max operations - Fixed intrinsic name mapping issues - Updated test configurations Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cuda-prelude.h295
1 files changed, 293 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 9df2727f6..44afd71b9 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -2661,17 +2661,67 @@ struct WaveOpMul
template<typename T>
struct WaveOpMax
{
- __inline__ __device__ static T getInitial(T a) { return a; }
+ __inline__ __device__ static T getInitial(T a, bool exclusive = false);
__inline__ __device__ static T doOp(T a, T b) { return a > b ? a : b; }
};
template<typename T>
struct WaveOpMin
{
- __inline__ __device__ static T getInitial(T a) { return a; }
+ __inline__ __device__ static T getInitial(T a, bool exclusive = false);
__inline__ __device__ static T doOp(T a, T b) { return a < b ? a : b; }
};
+// Compact specializations using macro for getInitial
+#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL) \
+ template<> \
+ __inline__ __device__ T WaveOpMin<T>::getInitial(T a, bool exclusive) \
+ { \
+ return exclusive ? (EXCL_VAL) : a; \
+ }
+
+#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL) \
+ template<> \
+ __inline__ __device__ T WaveOpMax<T>::getInitial(T a, bool exclusive) \
+ { \
+ return exclusive ? (EXCL_VAL) : a; \
+ }
+
+// Min specializations (exclusive identity = max value)
+SLANG_WAVE_MIN_SPEC(float, SLANG_INFINITY)
+SLANG_WAVE_MIN_SPEC(double, SLANG_INFINITY)
+SLANG_WAVE_MIN_SPEC(int, 0x7FFFFFFF)
+SLANG_WAVE_MIN_SPEC(uint, 0xFFFFFFFF)
+SLANG_WAVE_MIN_SPEC(char, (char)0x7F)
+SLANG_WAVE_MIN_SPEC(int8_t, (int8_t)0x7F)
+SLANG_WAVE_MIN_SPEC(uint8_t, (uint8_t)0xFF)
+SLANG_WAVE_MIN_SPEC(int16_t, (int16_t)0x7FFF)
+SLANG_WAVE_MIN_SPEC(uint16_t, (uint16_t)0xFFFF)
+SLANG_WAVE_MIN_SPEC(int64_t, 0x7FFFFFFFFFFFFFFFLL)
+SLANG_WAVE_MIN_SPEC(uint64_t, 0xFFFFFFFFFFFFFFFFULL)
+#if SLANG_CUDA_ENABLE_HALF
+SLANG_WAVE_MIN_SPEC(__half, __ushort_as_half(0x7BFF))
+#endif
+
+// Max specializations (exclusive identity = min value)
+SLANG_WAVE_MAX_SPEC(float, -SLANG_INFINITY)
+SLANG_WAVE_MAX_SPEC(double, -SLANG_INFINITY)
+SLANG_WAVE_MAX_SPEC(int, (int)0x80000000)
+SLANG_WAVE_MAX_SPEC(uint, 0)
+SLANG_WAVE_MAX_SPEC(char, (char)0x80)
+SLANG_WAVE_MAX_SPEC(int8_t, (int8_t)0x80)
+SLANG_WAVE_MAX_SPEC(uint8_t, 0)
+SLANG_WAVE_MAX_SPEC(int16_t, (int16_t)0x8000)
+SLANG_WAVE_MAX_SPEC(uint16_t, 0)
+SLANG_WAVE_MAX_SPEC(int64_t, (int64_t)0x8000000000000000LL)
+SLANG_WAVE_MAX_SPEC(uint64_t, 0)
+#if SLANG_CUDA_ENABLE_HALF
+SLANG_WAVE_MAX_SPEC(__half, __ushort_as_half(0xFBFF))
+#endif
+
+#undef SLANG_WAVE_MIN_SPEC
+#undef SLANG_WAVE_MAX_SPEC
+
template<typename T>
struct ElementTypeTrait;
@@ -2706,6 +2756,33 @@ struct ElementTypeTrait<int64_t>
{
typedef int64_t Type;
};
+template<>
+struct ElementTypeTrait<char>
+{
+ typedef char Type;
+};
+template<>
+struct ElementTypeTrait<uchar>
+{
+ typedef uchar Type;
+};
+template<>
+struct ElementTypeTrait<short>
+{
+ typedef short Type;
+};
+template<>
+struct ElementTypeTrait<ushort>
+{
+ typedef ushort Type;
+};
+#if SLANG_CUDA_ENABLE_HALF
+template<>
+struct ElementTypeTrait<__half>
+{
+ typedef __half Type;
+};
+#endif
// Vector
template<>
@@ -2792,6 +2869,115 @@ struct ElementTypeTrait<double4>
typedef double Type;
};
+// Additional vector types
+template<>
+struct ElementTypeTrait<char2>
+{
+ typedef char Type;
+};
+template<>
+struct ElementTypeTrait<char3>
+{
+ typedef char Type;
+};
+template<>
+struct ElementTypeTrait<char4>
+{
+ typedef char Type;
+};
+template<>
+struct ElementTypeTrait<uchar2>
+{
+ typedef uchar Type;
+};
+template<>
+struct ElementTypeTrait<uchar3>
+{
+ typedef uchar Type;
+};
+template<>
+struct ElementTypeTrait<uchar4>
+{
+ typedef uchar Type;
+};
+template<>
+struct ElementTypeTrait<short2>
+{
+ typedef short Type;
+};
+template<>
+struct ElementTypeTrait<short3>
+{
+ typedef short Type;
+};
+template<>
+struct ElementTypeTrait<short4>
+{
+ typedef short Type;
+};
+template<>
+struct ElementTypeTrait<ushort2>
+{
+ typedef ushort Type;
+};
+template<>
+struct ElementTypeTrait<ushort3>
+{
+ typedef ushort Type;
+};
+template<>
+struct ElementTypeTrait<ushort4>
+{
+ typedef ushort Type;
+};
+template<>
+struct ElementTypeTrait<longlong2>
+{
+ typedef int64_t Type;
+};
+template<>
+struct ElementTypeTrait<longlong3>
+{
+ typedef int64_t Type;
+};
+template<>
+struct ElementTypeTrait<longlong4>
+{
+ typedef int64_t Type;
+};
+template<>
+struct ElementTypeTrait<ulonglong2>
+{
+ typedef uint64_t Type;
+};
+template<>
+struct ElementTypeTrait<ulonglong3>
+{
+ typedef uint64_t Type;
+};
+template<>
+struct ElementTypeTrait<ulonglong4>
+{
+ typedef uint64_t Type;
+};
+#if SLANG_CUDA_ENABLE_HALF
+template<>
+struct ElementTypeTrait<__half2>
+{
+ typedef __half Type;
+};
+template<>
+struct ElementTypeTrait<__half3>
+{
+ typedef __half Type;
+};
+template<>
+struct ElementTypeTrait<__half4>
+{
+ typedef __half Type;
+};
+#endif
+
// Matrix
template<typename T, int ROWS, int COLS>
struct ElementTypeTrait<Matrix<T, ROWS, COLS>>
@@ -3431,6 +3617,111 @@ __inline__ __device__ T _wavePrefixAndMultiple(WarpMask mask, T val)
}
template<typename T>
+__inline__ __device__ T _wavePrefixMin(WarpMask mask, T val)
+{
+ return _wavePrefixScalar<WaveOpMin<T>, T>(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixMax(WarpMask mask, T val)
+{
+ return _wavePrefixScalar<WaveOpMax<T>, T>(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixMinMultiple(WarpMask mask, T val)
+{
+ typedef typename ElementTypeTrait<T>::Type ElemType;
+ _wavePrefixMultiple<WaveOpMin<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>(
+ mask,
+ (ElemType*)&val);
+ return val;
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixMaxMultiple(WarpMask mask, T val)
+{
+ typedef typename ElementTypeTrait<T>::Type ElemType;
+ _wavePrefixMultiple<WaveOpMax<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>(
+ mask,
+ (ElemType*)&val);
+ return val;
+}
+
+// Wrapper structures for exclusive operations that use the overloaded getInitial method
+template<typename T>
+struct WaveOpExclusiveMin
+{
+ __inline__ __device__ static T getInitial(T a) { return WaveOpMin<T>::getInitial(a, true); }
+ __inline__ __device__ static T doOp(T a, T b) { return WaveOpMin<T>::doOp(a, b); }
+};
+
+template<typename T>
+struct WaveOpExclusiveMax
+{
+ __inline__ __device__ static T getInitial(T a) { return WaveOpMax<T>::getInitial(a, true); }
+ __inline__ __device__ static T doOp(T a, T b) { return WaveOpMax<T>::doOp(a, b); }
+};
+
+// Inclusive prefix min/max functions (for WaveMultiPrefixInclusive*)
+template<typename T>
+__inline__ __device__ T _wavePrefixInclusiveMin(WarpMask mask, T val)
+{
+ return _wavePrefixMin(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixInclusiveMax(WarpMask mask, T val)
+{
+ return _wavePrefixMax(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixInclusiveMinMultiple(WarpMask mask, T val)
+{
+ return _wavePrefixMinMultiple(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixInclusiveMaxMultiple(WarpMask mask, T val)
+{
+ return _wavePrefixMaxMultiple(mask, val);
+}
+
+// Explicit exclusive prefix min/max functions (for WaveMultiPrefixExclusive*)
+template<typename T>
+__inline__ __device__ T _wavePrefixExclusiveMin(WarpMask mask, T val)
+{
+ return _wavePrefixScalar<WaveOpExclusiveMin<T>, T>(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixExclusiveMax(WarpMask mask, T val)
+{
+ return _wavePrefixScalar<WaveOpExclusiveMax<T>, T>(mask, val);
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixExclusiveMinMultiple(WarpMask mask, T val)
+{
+ typedef typename ElementTypeTrait<T>::Type ElemType;
+ _wavePrefixMultiple<WaveOpExclusiveMin<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>(
+ mask,
+ (ElemType*)&val);
+ return val;
+}
+
+template<typename T>
+__inline__ __device__ T _wavePrefixExclusiveMaxMultiple(WarpMask mask, T val)
+{
+ typedef typename ElementTypeTrait<T>::Type ElemType;
+ _wavePrefixMultiple<WaveOpExclusiveMax<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>(
+ mask,
+ (ElemType*)&val);
+ return val;
+}
+
+template<typename T>
__inline__ __device__ uint4 _waveMatchScalar(WarpMask mask, T val)
{
int pred;