summaryrefslogtreecommitdiff
path: root/prelude
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-08 17:03:14 -0400
committerGitHub <noreply@github.com>2023-09-08 14:03:14 -0700
commit26a7cf79526b86a3dff4084d42dde8f1a8c9ac1d (patch)
tree3f928dc9eb5294e72a9874e66e90db326768ae9b /prelude
parente8a1dd11eab4c07366b29aca775eb927a465e133 (diff)
Remove unsupported torch types + add bool type. (#3197)
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-torch-prelude.h4
1 files changed, 4 insertions, 0 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index a10e0070f..dee116261 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -123,6 +123,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
elementSize = 8;
res.data = (uint8_t*)val.data_ptr<int64_t>();
break;
+ case torch::kBool:
+ elementSize = 1;
+ res.data = (uint8_t*)val.data_ptr<bool>();
+ break;
}
if (val.dim() > kSlangTorchTensorMaxDim)