diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-11 15:11:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-11 15:11:45 -0700 |
| commit | 7c3a40cf08091a6cf0ec2de1e9694c979fb5c551 (patch) | |
| tree | 7866ecc98be4742ec7528c524bc7a43e27f2be85 /prelude | |
| parent | 54f112f8074c8ca490195c10db8c518cdc58546a (diff) | |
Small fixes to TorchTensor. (#2790)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 1bbd42168..0e0349bd7 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -382,6 +382,11 @@ SLANG_MAKE_VECTOR(ulonglong) SLANG_MAKE_VECTOR(__half) #endif +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool1 make_bool1(bool x) { return bool1{ x }; } +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool2 make_bool2(bool x, bool y) { return bool2{ x, y }; } +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool3 make_bool3(bool x, bool y, bool z) { return bool3{ x, y, z }; } +SLANG_FORCE_INLINE SLANG_CUDA_CALL bool4 make_bool4(bool x, bool y, bool z, bool w) { return bool4{ x, y, z, w }; } + #if SLANG_CUDA_RTC #define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\ @@ -408,6 +413,27 @@ SLANG_MAKE_VECTOR_FROM_SCALAR(double) SLANG_MAKE_VECTOR_FROM_SCALAR(__half) #endif +#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn,T,N) \ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T##N Fn(T##N* address, T##N val) \ + {\ + T##N result; \ + for (int i = 0; i < N; i++) \ + *_slang_vector_get_element_ptr(&result, i) = Fn(_slang_vector_get_element_ptr(address, i), _slang_vector_get_element(val, i)); \ + return result; \ + }\ + +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, float, 4) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, int, 4) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, uint, 4) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 2) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 3) +SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(atomicAdd, ulonglong, 4) template<typename T, int n> struct GetVectorTypeImpl {}; |
