From 1b778811dbc1468ed8d1f6f82117017064de2c96 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 14 Jul 2023 18:12:38 -0400 Subject: Avoid implicit casts or device transfers. (#2992) --- prelude/slang-torch-prelude.h | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 0de06750c..a10e0070f 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -74,11 +74,18 @@ struct TensorView TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType) { - // Convert device and scalar types. + // We're currently not trying to implicitly cast or transfer to device for two reasons: + // 1. There appears to be a bug with .to() where successive calls after the first one fail. + // 2. Silent casts like this can cause large memory allocations & unexpected overheads. + // It's better to be explicit. + + // Expect tensors to be on CUDA device if (!val.device().is_cuda()) - val = val.to(torch::kCUDA); + throw std::runtime_error(std::string(name).append(": tensor is not on CUDA device.").c_str()); + + // Expect tensors to be the right type. if (val.dtype() != targetScalarType) - val = val.to(targetScalarType); + throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str()); TensorView res = {}; res.dimensionCount = val.dim(); -- cgit v1.2.3