From f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Tue, 29 Oct 2024 14:49:26 +0800 Subject: format * format * Minor test fixes * enable checking cpp format in ci --- source/slang/slang-emit-torch.cpp | 204 ++++++++++++++++++-------------------- 1 file changed, 94 insertions(+), 110 deletions(-) (limited to 'source/slang/slang-emit-torch.cpp') diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index bacc9d030..2da2cf79c 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -2,7 +2,6 @@ #include "slang-emit-torch.h" #include "../core/slang-writer.h" - #include "slang-emit-source-writer.h" #include "slang-mangled-lexer.h" @@ -26,37 +25,19 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type) switch (instType->getOp()) { - case kIROp_FloatType: - m_writer->emit("kFloat32"); - break; - case kIROp_HalfType: - m_writer->emit("kFloat16"); - break; - case kIROp_DoubleType: - m_writer->emit("kFloat64"); - break; - case kIROp_UInt8Type: - m_writer->emit("kUInt8"); - break; - case kIROp_Int8Type: - m_writer->emit("kInt8"); - break; - case kIROp_Int16Type: - m_writer->emit("kInt16"); - break; - case kIROp_IntType: - m_writer->emit("kInt32"); - break; - case kIROp_Int64Type: - m_writer->emit("kInt64"); - break; - case kIROp_BoolType: - m_writer->emit("kBool"); - break; + case kIROp_FloatType: m_writer->emit("kFloat32"); break; + case kIROp_HalfType: m_writer->emit("kFloat16"); break; + case kIROp_DoubleType: m_writer->emit("kFloat64"); break; + case kIROp_UInt8Type: m_writer->emit("kUInt8"); break; + case kIROp_Int8Type: m_writer->emit("kInt8"); break; + case kIROp_Int16Type: m_writer->emit("kInt16"); break; + case kIROp_IntType: m_writer->emit("kInt32"); break; + case kIROp_Int64Type: m_writer->emit("kInt64"); break; + case kIROp_BoolType: m_writer->emit("kBool"); break; default: - SLANG_UNEXPECTED(( - std::string("unknown scalar type in allocTorchTensor: ") + - std::string(getIROpInfo(type->getOp()).name)).c_str()); + SLANG_UNEXPECTED((std::string("unknown scalar type in allocTorchTensor: ") + + std::string(getIROpInfo(type->getOp()).name)) + .c_str()); break; } } @@ -65,8 +46,7 @@ bool TorchCppSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) { switch (inst->getOp()) { - default: - return false; + default: return false; case kIROp_CudaKernelLaunch: { m_writer->emit("AT_CUDA_CHECK(cudaLaunchKernel("); @@ -96,7 +76,7 @@ bool TorchCppSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) m_writer->emit("((cudaStream_t)"); emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); m_writer->emit(")));\n"); - + return true; } } @@ -107,100 +87,103 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& switch (inst->getOp()) { default: - { - return Super::tryEmitInstExprImpl(inst, inOuterPrec); - } + { + return Super::tryEmitInstExprImpl(inst, inOuterPrec); + } case kIROp_MakeTensorView: - { - m_writer->emit("make_tensor_view("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", "); - emitStringLiteral(getUnmangledName(inst->getOperand(0))); - m_writer->emit(", "); - emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType()); - m_writer->emit(", "); + { + m_writer->emit("make_tensor_view("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitStringLiteral(getUnmangledName(inst->getOperand(0))); + m_writer->emit(", "); + emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType()); + m_writer->emit(", "); - auto tensorViewType = as(inst->getDataType()); - if (as(tensorViewType->getElementType())) - m_writer->emit("true"); - else - m_writer->emit("false"); - - m_writer->emit(")"); - return true; - } + auto tensorViewType = as(inst->getDataType()); + if (as(tensorViewType->getElementType())) + m_writer->emit("true"); + else + m_writer->emit("false"); + + m_writer->emit(")"); + return true; + } case kIROp_TorchGetCudaStream: - { - m_writer->emit("at::cuda::getCurrentCUDAStream()"); - return true; - } - case kIROp_AllocateTorchTensor: - { - if (as(inst->getOperand(0)->getDataType())) { - /* - Emit something like: - ``` - torch::Tensor out = torch::empty_like(other); - ``` - */ - m_writer->emit("torch::empty_like("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype("); - emitTorchScalarTypeName(m_writer, inst->getDataType()); - m_writer->emit("))"); + m_writer->emit("at::cuda::getCurrentCUDAStream()"); + return true; } - else + case kIROp_AllocateTorchTensor: { - /* - Emit something like: - ``` - torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, - torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); - ``` - */ - m_writer->emit("torch::empty({ "); - for (UInt i = 0; i < inst->getOperandCount(); i++) + if (as(inst->getOperand(0)->getDataType())) { - if (i > 0) - m_writer->emit(", "); - auto arg = inst->getOperand(i); - emitOperand(arg, getInfo(EmitOp::General)); + /* + Emit something like: + ``` + torch::Tensor out = torch::empty_like(other); + ``` + */ + m_writer->emit("torch::empty_like("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype("); + emitTorchScalarTypeName(m_writer, inst->getDataType()); + m_writer->emit("))"); } - if (as(inst->getDataType())) + else { - if (auto vectorType = as(inst->getDataType()->getOperand(0))) + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + if (as(inst->getDataType())) { - // If the element type of the tensor is a vector, we need to add the vector size to the shape. - m_writer->emit(", "); - emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + if (auto vectorType = as(inst->getDataType()->getOperand(0))) + { + // If the element type of the tensor is a vector, we need to add the vector + // size to the shape. + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + } } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); + emitTorchScalarTypeName(m_writer, inst->getDataType()); + m_writer->emit("))"); } - m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype("); - emitTorchScalarTypeName(m_writer, inst->getDataType()); - m_writer->emit("))"); + return true; } - return true; - } } } -SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) +SlangResult TorchCppSourceEmitter::calcTypeName( + IRType* type, + CodeGenTarget target, + StringBuilder& out) { switch (type->getOp()) { - default: - return Super::calcTypeName(type, target, out); + default: return Super::calcTypeName(type, target, out); case kIROp_TensorViewType: - { - out << "TensorView"; - return SLANG_OK; - } + { + out << "TensorView"; + return SLANG_OK; + } case kIROp_TorchTensorType: - { - out << "torch::Tensor"; - return SLANG_OK; - } + { + out << "torch::Tensor"; + return SLANG_OK; + } } } @@ -214,9 +197,11 @@ void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sin for (auto inst : module->getGlobalInsts()) { auto func = as(inst); - if (!func) continue; + if (!func) + continue; auto decor = func->findDecoration(); - if (!decor) continue; + if (!decor) + continue; m_writer->emit("m.def("); emitStringLiteral(decor->getFunctionName()); m_writer->emit(", &"); @@ -227,7 +212,6 @@ void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sin } m_writer->dedent(); m_writer->emit("}\n"); - } } // namespace Slang -- cgit v1.2.3