From 7c3a40cf08091a6cf0ec2de1e9694c979fb5c551 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 11 Apr 2023 15:11:45 -0700 Subject: Small fixes to TorchTensor. (#2790) Co-authored-by: Yong He --- source/slang/slang-emit-torch.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'source/slang/slang-emit-torch.cpp') diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index c198b011d..819a6a136 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -165,6 +165,15 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& auto arg = inst->getOperand(i); emitOperand(arg, getInfo(EmitOp::General)); } + if (as(inst->getDataType())) + { + if (auto vectorType = as(inst->getDataType()->getOperand(0))) + { + // If the element type of the tensor is a vector, we need to add the vector size to the shape. + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + } + } m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); emitTorchScalarTypeName(m_writer, inst->getDataType()); m_writer->emit("))"); -- cgit v1.2.3