From 6fa4edbfbf01ef582a3ddc2fdfdedc79ba60d365 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 29 Mar 2023 18:23:21 -0700 Subject: Convert tensor types in `make_tensor_view`. (#2755) Co-authored-by: Yong He --- prelude/slang-torch-prelude.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'prelude') diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 4844e9248..70c516a3a 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -68,11 +68,14 @@ struct CudaTaskMemoryAllocator } }; -TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name) +TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType) { + // Convert device and scalar types. if (!val.device().is_cuda()) val = val.to(torch::kCUDA); - + if (val.dtype() != targetScalarType) + val = val.to(targetScalarType); + TensorView res = {}; res.dimensionCount = val.dim(); res.strides = allocator->allocUIntArray(val.dim()); -- cgit v1.2.3