summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-08 17:03:14 -0400
committerGitHub <noreply@github.com>2023-09-08 14:03:14 -0700
commit26a7cf79526b86a3dff4084d42dde8f1a8c9ac1d (patch)
tree3f928dc9eb5294e72a9874e66e90db326768ae9b
parente8a1dd11eab4c07366b29aca775eb927a465e133 (diff)
Remove unsupported torch types + add bool type. (#3197)
Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--prelude/slang-torch-prelude.h4
-rw-r--r--source/slang/slang-emit-torch.cpp16
2 files changed, 10 insertions, 10 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index a10e0070f..dee116261 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -123,6 +123,10 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
elementSize = 8;
res.data = (uint8_t*)val.data_ptr<int64_t>();
break;
+ case torch::kBool:
+ elementSize = 1;
+ res.data = (uint8_t*)val.data_ptr<bool>();
+ break;
}
if (val.dim() > kSlangTorchTensorMaxDim)
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index ef04f33ba..7cd793ec1 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -38,15 +38,6 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
case kIROp_UInt8Type:
m_writer->emit("kUInt8");
break;
- case kIROp_UInt16Type:
- m_writer->emit("kUInt16");
- break;
- case kIROp_UIntType:
- m_writer->emit("kUInt32");
- break;
- case kIROp_UInt64Type:
- m_writer->emit("kUInt64");
- break;
case kIROp_Int8Type:
m_writer->emit("kInt8");
break;
@@ -59,8 +50,13 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
case kIROp_Int64Type:
m_writer->emit("kInt64");
break;
+ case kIROp_BoolType:
+ m_writer->emit("kBool");
+ break;
default:
- SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
+ SLANG_UNEXPECTED((
+ std::string("unknown scalar type in allocTorchTensor: ") +
+ std::string(getIROpInfo(type->getOp()).name)).c_str());
break;
}
}