diff options
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 95 |
1 files changed, 50 insertions, 45 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index 276164ed5..c198b011d 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -65,6 +65,49 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type) } } +void TorchCppSourceEmitter::emitInstStmtImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + default: + return; + case kIROp_CudaKernelLaunch: + { + m_writer->emit("AT_CUDA_CHECK(cudaLaunchKernel("); + // func + m_writer->emit("(const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // gridDim + m_writer->emit("slang_bit_cast<dim3>("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // blockDim + m_writer->emit("slang_bit_cast<dim3>("); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // args + emitOperand(inst->getOperand(3), getInfo(EmitOp::General)); + m_writer->emit(", "); + + // shared mem + m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")), "); + + // stream + m_writer->emit("((cudaStream_t)"); + emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); + m_writer->emit(")));\n"); + + break; + } + } +} + bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch (inst->getOp()) @@ -78,47 +121,12 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& m_writer->emit("make_tensor_view("); emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); - emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); - m_writer->emit(", "); - emitStringLiteral(getUnmangledName(inst->getOperand(1))); + emitStringLiteral(getUnmangledName(inst->getOperand(0))); m_writer->emit(", "); - emitTorchScalarTypeName(m_writer, inst->getOperand(1)->getDataType()); + emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType()); m_writer->emit(")"); return true; } - case kIROp_CudaKernelLaunch: - { - m_writer->emit("cudaLaunchKernel("); - // func - m_writer->emit("(const void*)("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit("), "); - - // gridDim - m_writer->emit("slang_bit_cast<dim3>("); - emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); - m_writer->emit("), "); - - // blockDim - m_writer->emit("slang_bit_cast<dim3>("); - emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); - m_writer->emit("), "); - - // args - emitOperand(inst->getOperand(3), getInfo(EmitOp::General)); - m_writer->emit(", "); - - // shared mem - m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(")), "); - - // stream - m_writer->emit("((cudaStream_t)"); - emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); - m_writer->emit("))"); - return true; - } case kIROp_TorchGetCudaStream: { m_writer->emit("at::cuda::getCurrentCUDAStream()"); @@ -131,12 +139,14 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& /* Emit something like: ``` - torch::Tensor out = torch::zeros_like(other); + torch::Tensor out = torch::empty_like(other); ``` */ - m_writer->emit("torch::zeros_like("); + m_writer->emit("torch::empty_like("); emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(")"); + m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype("); + emitTorchScalarTypeName(m_writer, inst->getDataType()); + m_writer->emit("))"); } else { @@ -180,11 +190,6 @@ SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget targ out << "torch::Tensor"; return SLANG_OK; } - case kIROp_TorchKernelMemoryAllocatorType: - { - out << "CudaTaskMemoryAllocator"; - return SLANG_OK; - } } } |
