summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-10 16:19:06 -0400
committerGitHub <noreply@github.com>2024-07-10 16:19:06 -0400
commit667e50498a226103278d0997528cc76979b2c4ef (patch)
tree577cd20ba21744c088b6b9c5bff1720bebef9712 /prelude
parent28ca743913975e42d9ed12bb824805a16bc52d94 (diff)
Add `float16` support to slang-torch (#4584)
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-torch-prelude.h4
1 files changed, 4 insertions, 0 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 28548e48e..11ffe3b66 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -107,6 +107,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
elementSize = 2;
res.data = (uint8_t*)val.data_ptr<torch::BFloat16>();
break;
+ case torch::kFloat16:
+ elementSize = 2;
+ res.data = (uint8_t*)val.data_ptr<at::Half>();
+ break;
case torch::kInt16:
elementSize = 2;
res.data = (uint8_t*)val.data_ptr<int16_t>();