diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-09-23 12:11:45 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-23 12:11:45 -0400 |
| commit | ab04bd0dd7dd6a818bbac8c5fef9372c4f597352 (patch) | |
| tree | d37f49273bc48c55ea3e16a243817907af0ebcbc /prelude | |
| parent | 263f807285c93272abb0c0352be6f8553f01a373 (diff) | |
More `slangpy` features + polishing (#3233)
* Update user-guide with new slangpy features
* More polishing of new slangpy docs
* Update a1-02-slangpy.md
* Only require contiguity for vector element types
* Added `loadOnce/storeOnce` and subscript operations
* Added docs, `DiffTensorView.dims()` & `DiffTensorView.stride(uint)`
* Add constructors, remove storeOnce/loadOnce test
* Adjusted intrinsic definitions
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 66 | ||||
| -rw-r--r-- | prelude/slang-torch-prelude.h | 4 |
2 files changed, 68 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 77ed2d51f..9075ed3d3 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -2204,6 +2204,17 @@ struct TensorView return reinterpret_cast<T*>(data + offset); } + template<typename T, unsigned int N> + __device__ T* data_ptr_at(uint index[N]) + { + uint64_t offset = 0; + for (unsigned int i = 0; i < N; ++i) + { + offset += strides[i] * index[i]; + } + return reinterpret_cast<T*>(data + offset); + } + template<typename T> __device__ T& load(uint32_t x) { @@ -2215,20 +2226,48 @@ struct TensorView return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y); } template<typename T> + __device__ T& load(uint2 index) + { + return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y); + } + template<typename T> __device__ T& load(uint32_t x, uint32_t y, uint32_t z) { return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z); } template<typename T> + __device__ T& load(uint3 index) + { + return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z); + } + template<typename T> __device__ T& load(uint32_t x, uint32_t y, uint32_t z, uint32_t w) { return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w); } template<typename T> + __device__ T& load(uint4 index) + { + return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z + strides[3] * index.w); + } + template<typename T> __device__ T& load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4) { return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4); } + + // Generic version of load + template<typename T, unsigned int N> + __device__ T& load(uint index[N]) + { + uint64_t offset = 0; + for (unsigned int i = 0; i < N; ++i) + { + offset += strides[i] * index[i]; + } + return *reinterpret_cast<T*>(data + offset); + } + template<typename T> __device__ void store(uint32_t x, T val) { @@ -2240,19 +2279,46 @@ struct TensorView *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y) = val; } template<typename T> + __device__ void store(uint2 index, T val) + { + *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y) = val; + } + template<typename T> __device__ void store(uint32_t x, uint32_t y, uint32_t z, T val) { *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z) = val; } template<typename T> + __device__ void store(uint3 index, T val) + { + *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z) = val; + } + template<typename T> __device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val) { *reinterpret_cast<T*>( data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val; } template<typename T> + __device__ void store(uint4 index, T val) + { + *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z + strides[3] * index.w) = val; + } + template<typename T> __device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val) { *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4) = val; } + + // Generic version + template<typename T, unsigned int N> + __device__ void store(uint index[N], T val) + { + uint64_t offset = 0; + for (unsigned int i = 0; i < N; ++i) + { + offset += strides[i] * index[i]; + } + *reinterpret_cast<T*>(data + offset) = val; + } }; diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 8d978642d..a2e4a1980 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -72,7 +72,7 @@ struct TensorView }; -TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType) +TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType, bool requireContiguous) { // We're currently not trying to implicitly cast or transfer to device for two reasons: // 1. There appears to be a bug with .to() where successive calls after the first one fail. @@ -88,7 +88,7 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str()); // Check that the tensor is contiguous - if (!val.is_contiguous()) + if (requireContiguous && !val.is_contiguous()) throw std::runtime_error(std::string(name).append(": tensor is not contiguous.").c_str()); TensorView res = {}; |
