diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-14 18:12:38 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-14 15:12:38 -0700 |
| commit | 1b778811dbc1468ed8d1f6f82117017064de2c96 (patch) | |
| tree | fcca749feb362b4f2ec30eb648d591e242b1dbd0 | |
| parent | 2de296ca00c0e77526ae1a83b4fb3df30419f70b (diff) | |
Avoid implicit casts or device transfers. (#2992)
| -rw-r--r-- | prelude/slang-torch-prelude.h | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 0de06750c..a10e0070f 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -74,11 +74,18 @@ struct TensorView TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType) { - // Convert device and scalar types. + // We're currently not trying to implicitly cast or transfer to device for two reasons: + // 1. There appears to be a bug with .to() where successive calls after the first one fail. + // 2. Silent casts like this can cause large memory allocations & unexpected overheads. + // It's better to be explicit. + + // Expect tensors to be on CUDA device if (!val.device().is_cuda()) - val = val.to(torch::kCUDA); + throw std::runtime_error(std::string(name).append(": tensor is not on CUDA device.").c_str()); + + // Expect tensors to be the right type. if (val.dtype() != targetScalarType) - val = val.to(targetScalarType); + throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str()); TensorView res = {}; res.dimensionCount = val.dim(); |
