diff options
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 41 | ||||
| -rw-r--r-- | source/compiler-core/slang-nvrtc-compiler.cpp | 16 |
2 files changed, 50 insertions, 7 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 24e400d2d..38f8a721a 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -53,7 +53,8 @@ # define SLANG_CUDA_WARP_SIZE 32 #endif -#define SLANG_CUDA_WARP_MASK (SLANG_CUDA_WARP_SIZE - 1) +#define SLANG_CUDA_WARP_MASK (SLANG_CUDA_WARP_SIZE - 1) // Used for masking threadIdx.x to the warp lane index +#define SLANG_CUDA_WARP_BITMASK (~int(0)) // #define SLANG_FORCE_INLINE inline @@ -1418,10 +1419,16 @@ __inline__ __device__ bool _waveIsSingleLane(WarpMask mask) } // Returns the power of 2 size of run of set bits. Returns 0 if not a suitable run. +// Examples: +// 0b00000000'00000000'00000000'11111111 -> 8 +// 0b11111111'11111111'11111111'11111111 -> 32 +// 0b00000000'00000000'00000000'00011111 -> 0 (since 5 is not a power of 2) +// 0b00000000'00000000'00000000'11110000 -> 0 (since the run of bits does not start at the LSB) +// 0b00000000'00000000'00000000'00100111 -> 0 (since it is not a single contiguous run) __inline__ __device__ int _waveCalcPow2Offset(WarpMask mask) { // This should be the most common case, so fast path it - if (mask == SLANG_CUDA_WARP_MASK) + if (mask == SLANG_CUDA_WARP_BITMASK) { return SLANG_CUDA_WARP_SIZE; } @@ -1647,6 +1654,36 @@ __inline__ __device__ T _waveMin(WarpMask mask, T val) { return _waveReduceScala template <typename T> __inline__ __device__ T _waveMax(WarpMask mask, T val) { return _waveReduceScalar<WaveOpMax<T>, T>(mask, val); } +// Fast-path specializations when CUDA warp reduce operators are available +#if __CUDA_ARCH__ >= 800 // 8.x or higher +template<> +__inline__ __device__ unsigned _waveOr<unsigned>(WarpMask mask, unsigned val) { return __reduce_or_sync(mask, val); } + +template<> +__inline__ __device__ unsigned _waveAnd<unsigned>(WarpMask mask, unsigned val) { return __reduce_and_sync(mask, val); } + +template<> +__inline__ __device__ unsigned _waveXor<unsigned>(WarpMask mask, unsigned val) { return __reduce_xor_sync(mask, val); } + +template<> +__inline__ __device__ unsigned _waveSum<unsigned>(WarpMask mask, unsigned val) { return __reduce_add_sync(mask, val); } + +template<> +__inline__ __device__ int _waveSum<int>(WarpMask mask, int val) { return __reduce_add_sync(mask, val); } + +template<> +__inline__ __device__ unsigned _waveMin<unsigned>(WarpMask mask, unsigned val) { return __reduce_min_sync(mask, val); } + +template<> +__inline__ __device__ int _waveMin<int>(WarpMask mask, int val) { return __reduce_min_sync(mask, val); } + +template<> +__inline__ __device__ unsigned _waveMax<unsigned>(WarpMask mask, unsigned val) { return __reduce_max_sync(mask, val); } + +template<> +__inline__ __device__ int _waveMax<int>(WarpMask mask, int val) { return __reduce_max_sync(mask, val); } +#endif + // Multiple diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp index daa392120..b84b1a403 100644 --- a/source/compiler-core/slang-nvrtc-compiler.cpp +++ b/source/compiler-core/slang-nvrtc-compiler.cpp @@ -820,15 +820,21 @@ SlangResult NVRTCDownstreamCompiler::compile(const DownstreamCompileOptions& inO { // The lowest supported CUDA architecture version supported - // by NVRTC is `compute_30`. + // by any version of NVRTC we support is `compute_30`. // SemanticVersion version(3); - // Newer releases of NVRTC only support `compute_35` and up - // (with everything before `compute_52` being deprecated). - // - if( m_desc.version.m_major >= 11 ) + // Newer releases of NVRTC only support newer CUDA architectures. + if ( m_desc.version.m_major >= 12 ) + { + // NVRTC in CUDA 12 only supports `compute_50` and up + // (with everything before `compute_52` being deprecated). + version = SemanticVersion(5, 0); + } + else if ( m_desc.version.m_major == 11 ) { + // NVRTC in CUDA 11 only supports `compute_35` and up + // (with everything before `compute_52` being deprecated). version = SemanticVersion(3, 5); } |
