diff options
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-torch-prelude.h | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 4844e9248..70c516a3a 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -68,11 +68,14 @@ struct CudaTaskMemoryAllocator } }; -TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name) +TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType) { + // Convert device and scalar types. if (!val.device().is_cuda()) val = val.to(torch::kCUDA); - + if (val.dtype() != targetScalarType) + val = val.to(targetScalarType); + TensorView res = {}; res.dimensionCount = val.dim(); res.strides = allocator->allocUIntArray(val.dim()); |
