diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-27 23:00:42 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-27 23:00:42 -0700 |
| commit | 0a6926003fd2300858e3089fe82f421543852395 (patch) | |
| tree | 19865fa9eb69373f0c0c16b7fac4993f67aa2b20 /source/slang/slang-emit-torch.cpp | |
| parent | d120fec7e81bbd5e8cf2c551b573feaf6678b43d (diff) | |
Translate all composed types into tuple types in pyBind. (#2744)
* Translate all composed types into tuple types in pyBind.
* Delete temp file.
* Fix get tuple element code emit logic.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index ef67c520a..877c1dc03 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -83,7 +83,17 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitOperand(arg, getInfo(EmitOp::General)); } m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); - switch (inst->getDataType()->getOperand(0)->getOp()) + + // Get the element type of the tensor. + auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); + + // If instType is a vector type, then we need to get the element type. + if (auto vectorType = as<IRVectorType>(instType)) + { + instType = vectorType->getElementType(); + } + + switch (instType->getOp()) { case kIROp_FloatType: m_writer->emit("kFloat32"); |
