summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-torch.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-27 23:00:42 -0700
committerGitHub <noreply@github.com>2023-03-27 23:00:42 -0700
commit0a6926003fd2300858e3089fe82f421543852395 (patch)
tree19865fa9eb69373f0c0c16b7fac4993f67aa2b20 /source/slang/slang-emit-torch.cpp
parentd120fec7e81bbd5e8cf2c551b573feaf6678b43d (diff)
Translate all composed types into tuple types in pyBind. (#2744)
* Translate all composed types into tuple types in pyBind. * Delete temp file. * Fix get tuple element code emit logic. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
-rw-r--r--source/slang/slang-emit-torch.cpp12
1 files changed, 11 insertions, 1 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index ef67c520a..877c1dc03 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -83,7 +83,17 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
emitOperand(arg, getInfo(EmitOp::General));
}
m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
- switch (inst->getDataType()->getOperand(0)->getOp())
+
+ // Get the element type of the tensor.
+ auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0);
+
+ // If instType is a vector type, then we need to get the element type.
+ if (auto vectorType = as<IRVectorType>(instType))
+ {
+ instType = vectorType->getElementType();
+ }
+
+ switch (instType->getOp())
{
case kIROp_FloatType:
m_writer->emit("kFloat32");