From a61f089fbc4b944d058e6417d8a0d22d57ca5c92 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 28 Mar 2023 15:19:03 -0700 Subject: Add slangpy doc, fix cuda prelude. (#2748) * Add slangpy doc, fix cuda prelude. * more bug fix. * fix. * fix. * More fix. * fix. * f * fix prelude. * update prelude. * update doc * Update prelude. * add zeros_like * update doc. --------- Co-authored-by: Yong He --- source/slang/slang-emit-torch.cpp | 141 +++++++++++++++++++++----------------- 1 file changed, 79 insertions(+), 62 deletions(-) (limited to 'source/slang/slang-emit-torch.cpp') diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index 877c1dc03..4511039e3 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -24,6 +24,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitStringLiteral(getUnmangledName(inst->getOperand(1))); m_writer->emit(")"); return true; } @@ -67,72 +69,87 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& } case kIROp_AllocateTorchTensor: { - /* - Emit something like: - ``` - torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, - torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); - ``` - */ - m_writer->emit("torch::empty({ "); - for (UInt i = 0; i < inst->getOperandCount(); i++) + if (as(inst->getOperand(0)->getDataType())) { - if (i > 0) - m_writer->emit(", "); - auto arg = inst->getOperand(i); - emitOperand(arg, getInfo(EmitOp::General)); + /* + Emit something like: + ``` + torch::Tensor out = torch::zeros_like(other); + ``` + */ + m_writer->emit("torch::zeros_like("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); } - m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); - - // Get the element type of the tensor. - auto instType = as(inst->getDataType())->getOperand(0); - - // If instType is a vector type, then we need to get the element type. - if (auto vectorType = as(instType)) + else { - instType = vectorType->getElementType(); + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); + + // Get the element type of the tensor. + auto instType = as(inst->getDataType())->getOperand(0); + + // If instType is a vector type, then we need to get the element type. + if (auto vectorType = as(instType)) + { + instType = vectorType->getElementType(); + } + + switch (instType->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } + m_writer->emit("))"); } - - switch (instType->getOp()) - { - case kIROp_FloatType: - m_writer->emit("kFloat32"); - break; - case kIROp_HalfType: - m_writer->emit("kFloat16"); - break; - case kIROp_DoubleType: - m_writer->emit("kFloat64"); - break; - case kIROp_UInt8Type: - m_writer->emit("kUInt8"); - break; - case kIROp_UInt16Type: - m_writer->emit("kUInt16"); - break; - case kIROp_UIntType: - m_writer->emit("kUInt32"); - break; - case kIROp_UInt64Type: - m_writer->emit("kUInt64"); - break; - case kIROp_Int8Type: - m_writer->emit("kInt8"); - break; - case kIROp_Int16Type: - m_writer->emit("kInt16"); - break; - case kIROp_IntType: - m_writer->emit("kInt32"); - break; - case kIROp_Int64Type: - m_writer->emit("kInt64"); - break; - default: - SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); - break; - } - m_writer->emit("))"); return true; } } -- cgit v1.2.3