summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-29 18:23:21 -0700
committerGitHub <noreply@github.com>2023-03-29 18:23:21 -0700
commit6fa4edbfbf01ef582a3ddc2fdfdedc79ba60d365 (patch)
tree2fc3b6b7adb9e64cadf47fe4a4fdee7df4b4bceb /prelude
parentaf062bff8f670de6a0c4fe7be797487ba124d811 (diff)
Convert tensor types in `make_tensor_view`. (#2755)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-torch-prelude.h7
1 files changed, 5 insertions, 2 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 4844e9248..70c516a3a 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -68,11 +68,14 @@ struct CudaTaskMemoryAllocator
}
};
-TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name)
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
{
+ // Convert device and scalar types.
if (!val.device().is_cuda())
val = val.to(torch::kCUDA);
-
+ if (val.dtype() != targetScalarType)
+ val = val.to(targetScalarType);
+
TensorView res = {};
res.dimensionCount = val.dim();
res.strides = allocator->allocUIntArray(val.dim());