diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-29 18:23:21 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-29 18:23:21 -0700 |
| commit | 6fa4edbfbf01ef582a3ddc2fdfdedc79ba60d365 (patch) | |
| tree | 2fc3b6b7adb9e64cadf47fe4a4fdee7df4b4bceb /prelude | |
| parent | af062bff8f670de6a0c4fe7be797487ba124d811 (diff) | |
Convert tensor types in `make_tensor_view`. (#2755)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-torch-prelude.h | 7 |
1 file changed, 5 insertions, 2 deletions
```diff
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 4844e9248..70c516a3a 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -68,11 +68,14 @@ struct CudaTaskMemoryAllocator
     }
 };
 
-TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name)
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
 {
+    // Convert device and scalar types.
     if (!val.device().is_cuda())
         val = val.to(torch::kCUDA);
-
+    if (val.dtype() != targetScalarType)
+        val = val.to(targetScalarType);
+
     TensorView res = {};
     res.dimensionCount = val.dim();
     res.strides = allocator->allocUIntArray(val.dim());
```
