summary | refs | log | tree | commit | diff | stats
path: root/prelude
diff options
context:
space:
mode:
author: Yong He <yonghe@outlook.com> 2023-03-17 15:57:22 -0700
committer: GitHub <noreply@github.com> 2023-03-17 15:57:22 -0700
commit: 7f11f883d0781952f002b3aa3222a3aa0040f18a (patch)
tree: 08eaf10fef39211fbc3f124679bfe8a35775a5a7 /prelude
parent: 4b55bf6d75bdeed087728505a1c9b43d3a99af8d (diff)
Add support for emitting cuda kernel and host functions. (#2712)
* Add support for emitting cuda kernel and host functions.
* Update test.
* Fix cuda preamble emit.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r-- prelude/slang-cuda-prelude.h 20
1 files changed, 20 insertions, 0 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index cb1bb188b..7a4c5a918 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -1,3 +1,11 @@
+#define SLANG_PRELUDE_EXPORT
+
+#ifdef __CUDACC_RTC__
+#define SLANG_CUDA_RTC 1
+#else
+#define SLANG_CUDA_RTC 0
+#endif
+
// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
// For this to work NVRTC needs to have the path to the CUDA SDK.
//
@@ -341,6 +349,7 @@ SLANG_CUDA_VECTOR_FLOAT_OPS(__half)
SLANG_CUDA_FLOAT_VECTOR_MOD(float)
SLANG_CUDA_FLOAT_VECTOR_MOD(double)
+#if SLANG_CUDA_RTC
#define SLANG_MAKE_VECTOR(T) \
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x, T y) { return T##2{x, y}; }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x, T y, T z) { return T##3{ x, y, z }; }\
@@ -355,15 +364,24 @@ SLANG_MAKE_VECTOR(float)
SLANG_MAKE_VECTOR(double)
SLANG_MAKE_VECTOR(longlong)
SLANG_MAKE_VECTOR(ulonglong)
+#endif
+
#if SLANG_CUDA_ENABLE_HALF
SLANG_MAKE_VECTOR(__half)
#endif
+#if SLANG_CUDA_RTC
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\
SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); }
+#else
+#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); }
+#endif
SLANG_MAKE_VECTOR_FROM_SCALAR(int)
SLANG_MAKE_VECTOR_FROM_SCALAR(uint)
SLANG_MAKE_VECTOR_FROM_SCALAR(short)
@@ -378,10 +396,12 @@ SLANG_MAKE_VECTOR_FROM_SCALAR(double)
SLANG_MAKE_VECTOR_FROM_SCALAR(__half)
#endif
+
template<typename T, int n>
struct GetVectorTypeImpl {};
#define GET_VECTOR_TYPE_IMPL(T, n)\
+template<>\
struct GetVectorTypeImpl<T,n>\
{\
typedef T##n type;\