diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-09-08 20:12:25 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-08 17:12:25 -0700 |
| commit | 87bb0b503544f1b8c6ec818e25c695b31cda24b7 (patch) | |
| tree | 61c517aa5eba3292acf803afa321e0f2ec21ce7e | |
| parent | 26a0b3e04689fee1ec9ec071eacd72faf1efe4eb (diff) | |
Add check for contiguous tensors (#3199)
Otherwise, this can lead to undetected scenario where the strides are incorrect for non-scalar types (`float2`, `float3`, etc..)
Users must call `tensor = tensor.contiguous()` on the inputs to avoid this error.
| -rw-r--r-- | prelude/slang-torch-prelude.h | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index dee116261..8d978642d 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -87,6 +87,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy if (val.dtype() != targetScalarType) throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str()); + // Check that the tensor is contiguous + if (!val.is_contiguous()) + throw std::runtime_error(std::string(name).append(": tensor is not contiguous.").c_str()); + TensorView res = {}; res.dimensionCount = val.dim(); res.data = nullptr; |
