summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-07-14 18:12:38 -0400
committerGitHub <noreply@github.com>2023-07-14 15:12:38 -0700
commit1b778811dbc1468ed8d1f6f82117017064de2c96 (patch)
treefcca749feb362b4f2ec30eb648d591e242b1dbd0
parent2de296ca00c0e77526ae1a83b4fb3df30419f70b (diff)
Avoid implicit casts or device transfers. (#2992)
-rw-r--r--prelude/slang-torch-prelude.h13
1 files changed, 10 insertions, 3 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 0de06750c..a10e0070f 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -74,11 +74,18 @@ struct TensorView
TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
{
- // Convert device and scalar types.
+ // We're currently not trying to implicitly cast or transfer to device for two reasons:
+ // 1. There appears to be a bug with .to() where successive calls after the first one fail.
+ // 2. Silent casts like this can cause large memory allocations & unexpected overheads.
+ // It's better to be explicit.
+
+ // Expect tensors to be on CUDA device
if (!val.device().is_cuda())
- val = val.to(torch::kCUDA);
+ throw std::runtime_error(std::string(name).append(": tensor is not on CUDA device.").c_str());
+
+ // Expect tensors to be the right type.
if (val.dtype() != targetScalarType)
- val = val.to(targetScalarType);
+ throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str());
TensorView res = {};
res.dimensionCount = val.dim();