summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-torch-prelude.h13
1 files changed, 10 insertions, 3 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 0de06750c..a10e0070f 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -74,11 +74,18 @@ struct TensorView
TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
{
- // Convert device and scalar types.
+ // We're currently not trying to implicitly cast or transfer to device for two reasons:
+ // 1. There appears to be a bug with .to() where successive calls after the first one fail.
+ // 2. Silent casts like this can cause large memory allocations & unexpected overheads.
+ // It's better to be explicit.
+
+ // Expect tensors to be on CUDA device
if (!val.device().is_cuda())
- val = val.to(torch::kCUDA);
+ throw std::runtime_error(std::string(name).append(": tensor is not on CUDA device.").c_str());
+
+ // Expect tensors to be the right type.
if (val.dtype() != targetScalarType)
- val = val.to(targetScalarType);
+ throw std::runtime_error(std::string(name).append(": tensor is not of the expected type.").c_str());
TensorView res = {};
res.dimensionCount = val.dim();