diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 109 |
1 files changed, 59 insertions, 50 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index 4511039e3..276164ed5 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -10,6 +10,61 @@ namespace Slang { + +void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type) +{ + m_writer->emit("torch::"); + + // Get the element type of the tensor. + auto instType = as<IRTorchTensorType>(type)->getOperand(0); + + // If instType is a vector type, then we need to get the element type. + if (auto vectorType = as<IRVectorType>(instType)) + { + instType = vectorType->getElementType(); + } + + switch (instType->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } +} + bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch (inst->getOp()) @@ -26,6 +81,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); m_writer->emit(", "); emitStringLiteral(getUnmangledName(inst->getOperand(1))); + m_writer->emit(", "); + emitTorchScalarTypeName(m_writer, inst->getOperand(1)->getDataType()); m_writer->emit(")"); return true; } @@ -98,56 +155,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& auto arg = inst->getOperand(i); emitOperand(arg, getInfo(EmitOp::General)); } - m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); - - // Get the element type of the tensor. - auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); - - // If instType is a vector type, then we need to get the element type. - if (auto vectorType = as<IRVectorType>(instType)) - { - instType = vectorType->getElementType(); - } - - switch (instType->getOp()) - { - case kIROp_FloatType: - m_writer->emit("kFloat32"); - break; - case kIROp_HalfType: - m_writer->emit("kFloat16"); - break; - case kIROp_DoubleType: - m_writer->emit("kFloat64"); - break; - case kIROp_UInt8Type: - m_writer->emit("kUInt8"); - break; - case kIROp_UInt16Type: - m_writer->emit("kUInt16"); - break; - case kIROp_UIntType: - m_writer->emit("kUInt32"); - break; - case kIROp_UInt64Type: - m_writer->emit("kUInt64"); - break; - case kIROp_Int8Type: - m_writer->emit("kInt8"); - break; - case kIROp_Int16Type: - m_writer->emit("kInt16"); - break; - case kIROp_IntType: - m_writer->emit("kInt32"); - break; - case kIROp_Int64Type: - m_writer->emit("kInt64"); - break; - default: - SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); - break; - } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); + emitTorchScalarTypeName(m_writer, inst->getDataType()); m_writer->emit("))"); } return true; |
