summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-11 15:11:45 -0700
committerGitHub <noreply@github.com>2023-04-11 15:11:45 -0700
commit7c3a40cf08091a6cf0ec2de1e9694c979fb5c551 (patch)
tree7866ecc98be4742ec7528c524bc7a43e27f2be85 /source
parent54f112f8074c8ca490195c10db8c518cdc58546a (diff)
Small fixes to TorchTensor. (#2790)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-torch.cpp9
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir.cpp4
3 files changed, 12 insertions, 3 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index c198b011d..819a6a136 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -165,6 +165,15 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
auto arg = inst->getOperand(i);
emitOperand(arg, getInfo(EmitOp::General));
}
+ if (as<IRTorchTensorType>(inst->getDataType()))
+ {
+ if (auto vectorType = as<IRVectorType>(inst->getDataType()->getOperand(0)))
+ {
+ // If the element type of the tensor is a vector, we need to add the vector size to the shape.
+ m_writer->emit(", ");
+ emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General));
+ }
+ }
m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(");
emitTorchScalarTypeName(m_writer, inst->getDataType());
m_writer->emit("))");
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 26d00d4df..a19896287 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2807,7 +2807,7 @@ public:
IRArrayListType* getArrayListType(IRType* elementType);
IRTensorViewType* getTensorViewType(IRType* elementType);
- IRTorchTensorType* getTorchTensorType();
+ IRTorchTensorType* getTorchTensorType(IRType* elementType);
IRDifferentialPairType* getDifferentialPairType(
IRType* valueType,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6ca05d2d6..bd2271953 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2837,9 +2837,9 @@ namespace Slang
(IRInst**)&elementType);
}
- IRTorchTensorType* IRBuilder::getTorchTensorType()
+ IRTorchTensorType* IRBuilder::getTorchTensorType(IRType* elementType)
{
- return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr);
+ return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 1, (IRInst**)&elementType);
}
IRDifferentialPairType* IRBuilder::getDifferentialPairType(