diff options
| -rw-r--r-- | prelude/slang-torch-prelude.h | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 0de06750c..a10e0070f 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -74,11 +74,18 @@ struct TensorView TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType) { - // Convert device and scalar types. + // We're currently not trying to implicitly cast or transfer to device for two reasons: + // 1. There appears to be a bug with .to() where successive calls after the first one fail. + // 2. Silent casts like this can cause large memory allocations & unexpected overheads. + // It's better to be explicit. + + // Expect tensors to be on CUDA device if (!val.device().is_cuda()) - val = val.to(torch::kCUDA); + throw std::runtime_error(std::string(name).append(": tensor is not on CUDA device.").c_str()); + + // Expect tensors to be the right type. if (val.dtype() != targetScalarType) - val = val.to(targetScalarType); + throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str()); TensorView res = {}; res.dimensionCount = val.dim(); |
