From 26a7cf79526b86a3dff4084d42dde8f1a8c9ac1d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 8 Sep 2023 17:03:14 -0400 Subject: Remove unsupported torch types + add bool type. (#3197) Co-authored-by: Yong He --- prelude/slang-torch-prelude.h | 4 ++++ source/slang/slang-emit-torch.cpp | 16 ++++++---------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index a10e0070f..dee116261 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -123,6 +123,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy elementSize = 8; res.data = (uint8_t*)val.data_ptr(); break; + case torch::kBool: + elementSize = 1; + res.data = (uint8_t*)val.data_ptr(); + break; } if (val.dim() > kSlangTorchTensorMaxDim) diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index ef04f33ba..7cd793ec1 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -38,15 +38,6 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type) case kIROp_UInt8Type: m_writer->emit("kUInt8"); break; - case kIROp_UInt16Type: - m_writer->emit("kUInt16"); - break; - case kIROp_UIntType: - m_writer->emit("kUInt32"); - break; - case kIROp_UInt64Type: - m_writer->emit("kUInt64"); - break; case kIROp_Int8Type: m_writer->emit("kInt8"); break; @@ -59,8 +50,13 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type) case kIROp_Int64Type: m_writer->emit("kInt64"); break; + case kIROp_BoolType: + m_writer->emit("kBool"); + break; default: - SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + SLANG_UNEXPECTED(( + std::string("unknown scalar type in allocTorchTensor: ") + + std::string(getIROpInfo(type->getOp()).name)).c_str()); break; } } -- cgit v1.2.3