summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-torch.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-29 18:23:21 -0700
committerGitHub <noreply@github.com>2023-03-29 18:23:21 -0700
commit6fa4edbfbf01ef582a3ddc2fdfdedc79ba60d365 (patch)
tree2fc3b6b7adb9e64cadf47fe4a4fdee7df4b4bceb /source/slang/slang-emit-torch.cpp
parentaf062bff8f670de6a0c4fe7be797487ba124d811 (diff)
Convert tensor types in `make_tensor_view`. (#2755)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
-rw-r--r--source/slang/slang-emit-torch.cpp109
1 files changed, 59 insertions, 50 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 4511039e3..276164ed5 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -10,6 +10,61 @@
namespace Slang
{
+
+void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
+{
+ m_writer->emit("torch::");
+
+ // Get the element type of the tensor.
+ auto instType = as<IRTorchTensorType>(type)->getOperand(0);
+
+ // If instType is a vector type, then we need to get the element type.
+ if (auto vectorType = as<IRVectorType>(instType))
+ {
+ instType = vectorType->getElementType();
+ }
+
+ switch (instType->getOp())
+ {
+ case kIROp_FloatType:
+ m_writer->emit("kFloat32");
+ break;
+ case kIROp_HalfType:
+ m_writer->emit("kFloat16");
+ break;
+ case kIROp_DoubleType:
+ m_writer->emit("kFloat64");
+ break;
+ case kIROp_UInt8Type:
+ m_writer->emit("kUInt8");
+ break;
+ case kIROp_UInt16Type:
+ m_writer->emit("kUInt16");
+ break;
+ case kIROp_UIntType:
+ m_writer->emit("kUInt32");
+ break;
+ case kIROp_UInt64Type:
+ m_writer->emit("kUInt64");
+ break;
+ case kIROp_Int8Type:
+ m_writer->emit("kInt8");
+ break;
+ case kIROp_Int16Type:
+ m_writer->emit("kInt16");
+ break;
+ case kIROp_IntType:
+ m_writer->emit("kInt32");
+ break;
+ case kIROp_Int64Type:
+ m_writer->emit("kInt64");
+ break;
+ default:
+ SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
+ break;
+ }
+}
+
bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
@@ -26,6 +81,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", ");
emitStringLiteral(getUnmangledName(inst->getOperand(1)));
+ m_writer->emit(", ");
+ emitTorchScalarTypeName(m_writer, inst->getOperand(1)->getDataType());
m_writer->emit(")");
return true;
}
@@ -98,56 +155,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
auto arg = inst->getOperand(i);
emitOperand(arg, getInfo(EmitOp::General));
}
- m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
-
- // Get the element type of the tensor.
- auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0);
-
- // If instType is a vector type, then we need to get the element type.
- if (auto vectorType = as<IRVectorType>(instType))
- {
- instType = vectorType->getElementType();
- }
-
- switch (instType->getOp())
- {
- case kIROp_FloatType:
- m_writer->emit("kFloat32");
- break;
- case kIROp_HalfType:
- m_writer->emit("kFloat16");
- break;
- case kIROp_DoubleType:
- m_writer->emit("kFloat64");
- break;
- case kIROp_UInt8Type:
- m_writer->emit("kUInt8");
- break;
- case kIROp_UInt16Type:
- m_writer->emit("kUInt16");
- break;
- case kIROp_UIntType:
- m_writer->emit("kUInt32");
- break;
- case kIROp_UInt64Type:
- m_writer->emit("kUInt64");
- break;
- case kIROp_Int8Type:
- m_writer->emit("kInt8");
- break;
- case kIROp_Int16Type:
- m_writer->emit("kInt16");
- break;
- case kIROp_IntType:
- m_writer->emit("kInt32");
- break;
- case kIROp_Int64Type:
- m_writer->emit("kInt64");
- break;
- default:
- SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor");
- break;
- }
+ m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(");
+ emitTorchScalarTypeName(m_writer, inst->getDataType());
m_writer->emit("))");
}
return true;