diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-10 16:19:06 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 16:19:06 -0400 |
| commit | 667e50498a226103278d0997528cc76979b2c4ef (patch) | |
| tree | 577cd20ba21744c088b6b9c5bff1720bebef9712 | |
| parent | 28ca743913975e42d9ed12bb824805a16bc52d94 (diff) | |
Add `float16` support to slang-torch (#4584)
| -rw-r--r-- | prelude/slang-torch-prelude.h | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 28548e48e..11ffe3b66 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -107,6 +107,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy elementSize = 2; res.data = (uint8_t*)val.data_ptr<torch::BFloat16>(); break; + case torch::kFloat16: + elementSize = 2; + res.data = (uint8_t*)val.data_ptr<at::Half>(); + break; case torch::kInt16: elementSize = 2; res.data = (uint8_t*)val.data_ptr<int16_t>(); |
