summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-torch.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-30 12:50:02 -0700
committerGitHub <noreply@github.com>2023-03-30 12:50:02 -0700
commit917416f6db7056cddff9d2a0e4e9b4117359157d (patch)
tree9bd6aa89f235e4692cff83cdbe1ce4aae7ea861f /source/slang/slang-emit-torch.cpp
parente3b701c9f56f4a2fb8c56a65b5c75b49ee72ca73 (diff)
More builtin library support in torch backend. (#2760)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
-rw-r--r--source/slang/slang-emit-torch.cpp95
1 files changed, 50 insertions, 45 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 276164ed5..c198b011d 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -65,6 +65,49 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
}
}
+void TorchCppSourceEmitter::emitInstStmtImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ default:
+ return;
+ case kIROp_CudaKernelLaunch:
+ {
+ m_writer->emit("AT_CUDA_CHECK(cudaLaunchKernel(");
+ // func
+ m_writer->emit("(const void*)(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // gridDim
+ m_writer->emit("slang_bit_cast<dim3>(");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // blockDim
+ m_writer->emit("slang_bit_cast<dim3>(");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // args
+ emitOperand(inst->getOperand(3), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+
+ // shared mem
+ m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")), ");
+
+ // stream
+ m_writer->emit("((cudaStream_t)");
+ emitOperand(inst->getOperand(4), getInfo(EmitOp::General));
+ m_writer->emit(")));\n");
+
+ break;
+ }
+ }
+}
+
bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
@@ -78,47 +121,12 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
m_writer->emit("make_tensor_view(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(", ");
- emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitStringLiteral(getUnmangledName(inst->getOperand(1)));
+ emitStringLiteral(getUnmangledName(inst->getOperand(0)));
m_writer->emit(", ");
- emitTorchScalarTypeName(m_writer, inst->getOperand(1)->getDataType());
+ emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType());
m_writer->emit(")");
return true;
}
- case kIROp_CudaKernelLaunch:
- {
- m_writer->emit("cudaLaunchKernel(");
- // func
- m_writer->emit("(const void*)(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit("), ");
-
- // gridDim
- m_writer->emit("slang_bit_cast<dim3>(");
- emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
- m_writer->emit("), ");
-
- // blockDim
- m_writer->emit("slang_bit_cast<dim3>(");
- emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
- m_writer->emit("), ");
-
- // args
- emitOperand(inst->getOperand(3), getInfo(EmitOp::General));
- m_writer->emit(", ");
-
- // shared mem
- m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(")), ");
-
- // stream
- m_writer->emit("((cudaStream_t)");
- emitOperand(inst->getOperand(4), getInfo(EmitOp::General));
- m_writer->emit("))");
- return true;
- }
case kIROp_TorchGetCudaStream:
{
m_writer->emit("at::cuda::getCurrentCUDAStream()");
@@ -131,12 +139,14 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
/*
Emit something like:
```
- torch::Tensor out = torch::zeros_like(other);
+ torch::Tensor out = torch::empty_like(other);
```
*/
- m_writer->emit("torch::zeros_like(");
+ m_writer->emit("torch::empty_like(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(")");
+ m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype(");
+ emitTorchScalarTypeName(m_writer, inst->getDataType());
+ m_writer->emit("))");
}
else
{
@@ -180,11 +190,6 @@ SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget targ
out << "torch::Tensor";
return SLANG_OK;
}
- case kIROp_TorchKernelMemoryAllocatorType:
- {
- out << "CudaTaskMemoryAllocator";
- return SLANG_OK;
- }
}
}