From 7c3a40cf08091a6cf0ec2de1e9694c979fb5c551 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 11 Apr 2023 15:11:45 -0700 Subject: Small fixes to TorchTensor. (#2790) Co-authored-by: Yong He --- source/slang/slang-emit-torch.cpp | 9 +++++++++ source/slang/slang-ir-insts.h | 2 +- source/slang/slang-ir.cpp | 4 ++-- 3 files changed, 12 insertions(+), 3 deletions(-) (limited to 'source') diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index c198b011d..819a6a136 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -165,6 +165,15 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& auto arg = inst->getOperand(i); emitOperand(arg, getInfo(EmitOp::General)); } + if (as(inst->getDataType())) + { + if (auto vectorType = as(inst->getDataType()->getOperand(0))) + { + // If the element type of the tensor is a vector, we need to add the vector size to the shape. + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + } + } m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); emitTorchScalarTypeName(m_writer, inst->getDataType()); m_writer->emit("))"); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 26d00d4df..a19896287 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2807,7 +2807,7 @@ public: IRArrayListType* getArrayListType(IRType* elementType); IRTensorViewType* getTensorViewType(IRType* elementType); - IRTorchTensorType* getTorchTensorType(); + IRTorchTensorType* getTorchTensorType(IRType* elementType); IRDifferentialPairType* getDifferentialPairType( IRType* valueType, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6ca05d2d6..bd2271953 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2837,9 +2837,9 @@ namespace Slang (IRInst**)&elementType); } - IRTorchTensorType* IRBuilder::getTorchTensorType() + IRTorchTensorType* IRBuilder::getTorchTensorType(IRType* elementType) { - return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr); + return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 1, (IRInst**)&elementType); } IRDifferentialPairType* IRBuilder::getDifferentialPairType( -- cgit v1.2.3