summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-28 15:19:03 -0700
committerGitHub <noreply@github.com>2023-03-28 15:19:03 -0700
commita61f089fbc4b944d058e6417d8a0d22d57ca5c92 (patch)
tree4fa1a0c6370b8d34262d297653239f48aa004c71 /prelude
parent8f03af5e5b580170fab3fd2fe6144f92038c7701 (diff)
Add slangpy doc, fix cuda prelude. (#2748)
* Add slangpy doc, fix cuda prelude. * more bug fix. * fix. * fix. * More fix. * fix. * f * fix prelude. * update prelude. * update doc * Update prelude. * add zeros_like * update doc. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cuda-prelude.h22
-rw-r--r--prelude/slang-torch-prelude.h14
2 files changed, 28 insertions, 8 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 240dfb3e7..a1c49e108 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -6,6 +6,14 @@
#define SLANG_CUDA_RTC 0
#endif
+#if SLANG_CUDA_RTC
+
+#else
+
+#include <cstdint>
+
+#endif
+
// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
// For this to work NVRTC needs to have the path to the CUDA SDK.
//
@@ -162,6 +170,8 @@ typedef int2 bool2;
typedef int3 bool3;
typedef int4 bool4;
+#if SLANG_CUDA_RTC
+
typedef signed char int8_t;
typedef short int16_t;
typedef int int32_t;
@@ -172,6 +182,8 @@ typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
+#endif
+
typedef long long longlong;
typedef unsigned long long ulonglong;
@@ -2099,27 +2111,27 @@ struct TensorView
}
template<typename T>
- __device__ T load(uint32_t x)
+ __device__ T& load(uint32_t x)
{
return *reinterpret_cast<T*>(data + strides[0] * x);
}
template<typename T>
- __device__ T load(uint32_t x, uint32_t y)
+ __device__ T& load(uint32_t x, uint32_t y)
{
return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y);
}
template<typename T>
- __device__ T load(uint32_t x, uint32_t y, uint32_t z)
+ __device__ T& load(uint32_t x, uint32_t y, uint32_t z)
{
return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z);
}
template<typename T>
- __device__ T load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
+ __device__ T& load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
{
return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w);
}
template<typename T>
- __device__ T load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
+ __device__ T& load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
{
return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4);
}
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index f2accc149..4844e9248 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -4,6 +4,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAUtils.h>
#include <vector>
+#include <stdexcept>
+#include <string>
#ifndef SLANG_NO_THROW
# define SLANG_NO_THROW
@@ -54,7 +56,7 @@ struct CudaTaskMemoryAllocator
uint32_t* allocUIntArray(uint32_t size)
{
void* ptr = nullptr;
- cudaMallocManaged(&ptr, size * sizeof(uint32_t));
+ cudaMallocHost(&ptr, size * sizeof(uint32_t));
AT_CUDA_CHECK(cudaGetLastError());
return (uint32_t*)ptr;
}
@@ -66,15 +68,18 @@ struct CudaTaskMemoryAllocator
}
};
-TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val)
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name)
{
- val = val.to(torch::kCUDA);
+ if (!val.device().is_cuda())
+ val = val.to(torch::kCUDA);
+
TensorView res = {};
res.dimensionCount = val.dim();
res.strides = allocator->allocUIntArray(val.dim());
res.sizes = allocator->allocUIntArray(val.dim());
res.data = nullptr;
size_t elementSize = 4;
+
switch (val.scalar_type())
{
case torch::kInt8:
@@ -107,11 +112,14 @@ TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor va
res.data = (uint8_t*)val.data_ptr<int64_t>();
break;
}
+
for (int i = 0; i < val.dim(); ++i)
{
res.strides[i] = val.stride(i) * elementSize;
res.sizes[i] = val.size(i);
}
+ if (!res.data)
+ throw std::runtime_error(std::string(name).append(": data pointer is invalid.").c_str());
return res;
}