summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-30 12:50:02 -0700
committerGitHub <noreply@github.com>2023-03-30 12:50:02 -0700
commit917416f6db7056cddff9d2a0e4e9b4117359157d (patch)
tree9bd6aa89f235e4692cff83cdbe1ce4aae7ea861f /prelude
parente3b701c9f56f4a2fb8c56a65b5c75b49ee72ca73 (diff)
More builtin library support in torch backend. (#2760)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cuda-prelude.h33
-rw-r--r--prelude/slang-torch-prelude.h31
2 files changed, 39 insertions, 25 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index a1c49e108..5d24df455 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -2095,13 +2095,14 @@ __forceinline__ __device__ void *traceOptiXRay(
#endif
+static const int kSlangTorchTensorMaxDim = 5;
// TensorView
struct TensorView
{
uint8_t* data;
- uint32_t* strides;
- uint32_t* sizes;
+ uint32_t strides[kSlangTorchTensorMaxDim];
+ uint32_t sizes[kSlangTorchTensorMaxDim];
uint32_t dimensionCount;
template<typename T>
@@ -2111,6 +2112,34 @@ struct TensorView
}
template<typename T>
+ __device__ T* data_ptr_at(uint32_t index)
+ {
+ uint64_t offset = strides[0] * index;
+ return reinterpret_cast<T*>(data + offset);
+ }
+
+ template<typename T>
+ __device__ T* data_ptr_at(uint2 index)
+ {
+ uint64_t offset = strides[0] * index.x + strides[1] * index.y;
+ return reinterpret_cast<T*>(data + offset);
+ }
+
+ template<typename T>
+ __device__ T* data_ptr_at(uint3 index)
+ {
+ uint64_t offset = strides[0] * index.x + strides[1] * index.y + strides[2] * index.z;
+ return reinterpret_cast<T*>(data + offset);
+ }
+
+ template<typename T>
+ __device__ T* data_ptr_at(uint4 index)
+ {
+ uint64_t offset = strides[0] * index.x + strides[1] * index.y + strides[2] * index.z + strides[3] * index.w;
+ return reinterpret_cast<T*>(data + offset);
+ }
+
+ template<typename T>
__device__ T& load(uint32_t x)
{
return *reinterpret_cast<T*>(data + strides[0] * x);
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 70c516a3a..cf04b75ab 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -41,34 +41,18 @@
#include "slang-cpp-types-core.h"
#include "slang-cpp-scalar-intrinsics.h"
+static const int kSlangTorchTensorMaxDim = 5;
+
struct TensorView
{
uint8_t* data;
- uint32_t* strides;
- uint32_t* sizes;
+ uint32_t strides[kSlangTorchTensorMaxDim];
+ uint32_t sizes[kSlangTorchTensorMaxDim];
uint32_t dimensionCount;
};
-struct CudaTaskMemoryAllocator
-{
- std::vector<void*> allocations;
-
- uint32_t* allocUIntArray(uint32_t size)
- {
- void* ptr = nullptr;
- cudaMallocHost(&ptr, size * sizeof(uint32_t));
- AT_CUDA_CHECK(cudaGetLastError());
- return (uint32_t*)ptr;
- }
-
- ~CudaTaskMemoryAllocator()
- {
- for (auto ptr : allocations)
- cudaFree(ptr);
- }
-};
-TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
+TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
{
// Convert device and scalar types.
if (!val.device().is_cuda())
@@ -78,8 +62,6 @@ TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor va
TensorView res = {};
res.dimensionCount = val.dim();
- res.strides = allocator->allocUIntArray(val.dim());
- res.sizes = allocator->allocUIntArray(val.dim());
res.data = nullptr;
size_t elementSize = 4;
@@ -116,6 +98,9 @@ TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor va
break;
}
+ if (val.dim() > kSlangTorchTensorMaxDim)
+ throw std::runtime_error(std::string(name).append(": number of dimensions exceeds limit (").append(std::to_string(kSlangTorchTensorMaxDim)).append(")").c_str());
+
for (int i = 0; i < val.dim(); ++i)
{
res.strides[i] = val.stride(i) * elementSize;