summaryrefslogtreecommitdiff
path: root/prelude
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-23 12:11:45 -0400
committerGitHub <noreply@github.com>2023-09-23 12:11:45 -0400
commitab04bd0dd7dd6a818bbac8c5fef9372c4f597352 (patch)
treed37f49273bc48c55ea3e16a243817907af0ebcbc /prelude
parent263f807285c93272abb0c0352be6f8553f01a373 (diff)
More `slangpy` features + polishing (#3233)
* Update user-guide with new slangpy features * More polishing of new slangpy docs * Update a1-02-slangpy.md * Only require contiguity for vector element types * Added `loadOnce/storeOnce` and subscript operations * Added docs, `DiffTensorView.dims()` & `DiffTensorView.stride(uint)` * Add constructors, remove storeOnce/loadOnce test * Adjusted intrinsic definitions
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cuda-prelude.h66
-rw-r--r--prelude/slang-torch-prelude.h4
2 files changed, 68 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 77ed2d51f..9075ed3d3 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -2204,6 +2204,17 @@ struct TensorView
return reinterpret_cast<T*>(data + offset);
}
+ template<typename T, unsigned int N>
+ __device__ T* data_ptr_at(uint index[N])
+ {
+ uint64_t offset = 0;
+ for (unsigned int i = 0; i < N; ++i)
+ {
+ offset += strides[i] * index[i];
+ }
+ return reinterpret_cast<T*>(data + offset);
+ }
+
template<typename T>
__device__ T& load(uint32_t x)
{
@@ -2215,20 +2226,48 @@ struct TensorView
return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y);
}
template<typename T>
+ __device__ T& load(uint2 index)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y);
+ }
+ template<typename T>
__device__ T& load(uint32_t x, uint32_t y, uint32_t z)
{
return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z);
}
template<typename T>
+ __device__ T& load(uint3 index)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z);
+ }
+ template<typename T>
__device__ T& load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
{
return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w);
}
template<typename T>
+ __device__ T& load(uint4 index)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z + strides[3] * index.w);
+ }
+ template<typename T>
__device__ T& load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
{
return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4);
}
+
+ // Generic version of load
+ template<typename T, unsigned int N>
+ __device__ T& load(uint index[N])
+ {
+ uint64_t offset = 0;
+ for (unsigned int i = 0; i < N; ++i)
+ {
+ offset += strides[i] * index[i];
+ }
+ return *reinterpret_cast<T*>(data + offset);
+ }
+
template<typename T>
__device__ void store(uint32_t x, T val)
{
@@ -2240,19 +2279,46 @@ struct TensorView
*reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y) = val;
}
template<typename T>
+ __device__ void store(uint2 index, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y) = val;
+ }
+ template<typename T>
__device__ void store(uint32_t x, uint32_t y, uint32_t z, T val)
{
*reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z) = val;
}
template<typename T>
+ __device__ void store(uint3 index, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z) = val;
+ }
+ template<typename T>
__device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
{
*reinterpret_cast<T*>(
data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val;
}
template<typename T>
+ __device__ void store(uint4 index, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z + strides[3] * index.w) = val;
+ }
+ template<typename T>
__device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
{
*reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4) = val;
}
+
+ // Generic version
+ template<typename T, unsigned int N>
+ __device__ void store(uint index[N], T val)
+ {
+ uint64_t offset = 0;
+ for (unsigned int i = 0; i < N; ++i)
+ {
+ offset += strides[i] * index[i];
+ }
+ *reinterpret_cast<T*>(data + offset) = val;
+ }
};
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 8d978642d..a2e4a1980 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -72,7 +72,7 @@ struct TensorView
};
-TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
+TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType, bool requireContiguous)
{
// We're currently not trying to implicitly cast or transfer to device for two reasons:
// 1. There appears to be a bug with .to() where successive calls after the first one fail.
@@ -88,7 +88,7 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str());
// Check that the tensor is contiguous
- if (!val.is_contiguous())
+ if (requireContiguous && !val.is_contiguous())
throw std::runtime_error(std::string(name).append(": tensor is not contiguous.").c_str());
TensorView res = {};