diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-26 13:59:11 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-26 13:59:11 -0700 |
| commit | d64ee86a3130f8eeb75d09193c38c621d7565eba (patch) | |
| tree | fed25a0cc2a7372d26175774f5983bed693e6b64 /source/slang/slang-emit-torch.cpp | |
| parent | 666af0962b6ab41489a3a3287db83f77c2f6461a (diff) | |
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation.
* fix
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp new file mode 100644 index 000000000..ef67c520a --- /dev/null +++ b/source/slang/slang-emit-torch.cpp @@ -0,0 +1,181 @@ +// slang-emit-torch.cpp +#include "slang-emit-torch.h" + +#include "../core/slang-writer.h" + +#include "slang-emit-source-writer.h" +#include "slang-mangled-lexer.h" + +#include <assert.h> + +namespace Slang +{ +bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + switch (inst->getOp()) + { + default: + { + return Super::tryEmitInstExprImpl(inst, inOuterPrec); + } + case kIROp_MakeTensorView: + { + m_writer->emit("make_tensor_view("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_CudaKernelLaunch: + { + m_writer->emit("cudaLaunchKernel("); + // func + m_writer->emit("(const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // gridDim + m_writer->emit("slang_bit_cast<dim3>("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // blockDim + m_writer->emit("slang_bit_cast<dim3>("); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // args + emitOperand(inst->getOperand(3), getInfo(EmitOp::General)); + m_writer->emit(", "); + + // shared mem + m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")), "); + + // stream + m_writer->emit("((cudaStream_t)"); + emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); + m_writer->emit("))"); + return true; + } + case kIROp_TorchGetCudaStream: + { + m_writer->emit("at::cuda::getCurrentCUDAStream()"); + return true; + } + case kIROp_AllocateTorchTensor: + { + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); + switch (inst->getDataType()->getOperand(0)->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } + m_writer->emit("))"); + return true; + } + } +} + +SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) +{ + switch (type->getOp()) + { + default: + return Super::calcTypeName(type, target, out); + case kIROp_TensorViewType: + { + out << "TensorView"; + return SLANG_OK; + } + case kIROp_TorchTensorType: + { + out << "torch::Tensor"; + return SLANG_OK; + } + case kIROp_TorchKernelMemoryAllocatorType: + { + out << "CudaTaskMemoryAllocator"; + return SLANG_OK; + } + } +} + +void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) +{ + Super::emitModuleImpl(module, sink); + + // Emit PyBind declarations. + m_writer->emit("PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n"); + m_writer->indent(); + for (auto inst : module->getGlobalInsts()) + { + auto func = as<IRFunc>(inst); + if (!func) continue; + auto decor = func->findDecoration<IRTorchEntryPointDecoration>(); + if (!decor) continue; + m_writer->emit("m.def("); + emitStringLiteral(decor->getFunctionName()); + m_writer->emit(", &"); + m_writer->emit(decor->getFunctionName()); + m_writer->emit(", "); + emitStringLiteral(decor->getFunctionName()); + m_writer->emit(");\n"); + } + m_writer->dedent(); + m_writer->emit("}\n"); + +} + +} // namespace Slang |
