diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-11 15:11:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-11 15:11:45 -0700 |
| commit | 7c3a40cf08091a6cf0ec2de1e9694c979fb5c551 (patch) | |
| tree | 7866ecc98be4742ec7528c524bc7a43e27f2be85 /source/slang/slang-emit-torch.cpp | |
| parent | 54f112f8074c8ca490195c10db8c518cdc58546a (diff) | |
Small fixes to TorchTensor. (#2790)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index c198b011d..819a6a136 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -165,6 +165,15 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& auto arg = inst->getOperand(i); emitOperand(arg, getInfo(EmitOp::General)); } + if (as<IRTorchTensorType>(inst->getDataType())) + { + if (auto vectorType = as<IRVectorType>(inst->getDataType()->getOperand(0))) + { + // If the element type of the tensor is a vector, we need to add the vector size to the shape. + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + } + } m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); emitTorchScalarTypeName(m_writer, inst->getDataType()); m_writer->emit("))"); |
