summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-torch-prelude.h7
1 files changed, 5 insertions, 2 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 4844e9248..70c516a3a 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -68,11 +68,14 @@ struct CudaTaskMemoryAllocator
}
};
-TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name)
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name, torch::ScalarType targetScalarType)
{
+ // Convert device and scalar types.
if (!val.device().is_cuda())
val = val.to(torch::kCUDA);
-
+ if (val.dtype() != targetScalarType)
+ val = val.to(targetScalarType);
+
TensorView res = {};
res.dimensionCount = val.dim();
res.strides = allocator->allocUIntArray(val.dim());