summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-08 20:12:25 -0400
committerGitHub <noreply@github.com>2023-09-08 17:12:25 -0700
commit87bb0b503544f1b8c6ec818e25c695b31cda24b7 (patch)
tree61c517aa5eba3292acf803afa321e0f2ec21ce7e
parent26a0b3e04689fee1ec9ec071eacd72faf1efe4eb (diff)
Add check for contiguous tensors (#3199)
Otherwise, this can lead to undetected scenario where the strides are incorrect for non-scalar types (`float2`, `float3`, etc..) Users must call `tensor = tensor.contiguous()` on the inputs to avoid this error.
-rw-r--r--prelude/slang-torch-prelude.h4
1 files changed, 4 insertions, 0 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index dee116261..8d978642d 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -87,6 +87,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
if (val.dtype() != targetScalarType)
throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str());
+ // Check that the tensor is contiguous
+ if (!val.is_contiguous())
+ throw std::runtime_error(std::string(name).append(": tensor is not contiguous.").c_str());
+
TensorView res = {};
res.dimensionCount = val.dim();
res.data = nullptr;