summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-torch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-torch.cpp')
-rw-r--r--source/slang/slang-emit-torch.cpp95
1 files changed, 50 insertions, 45 deletions
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 276164ed5..c198b011d 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -65,6 +65,49 @@ void emitTorchScalarTypeName(SourceWriter* m_writer, IRInst* type)
}
}
+void TorchCppSourceEmitter::emitInstStmtImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ default:
+ return;
+ case kIROp_CudaKernelLaunch:
+ {
+ m_writer->emit("AT_CUDA_CHECK(cudaLaunchKernel(");
+ // func
+ m_writer->emit("(const void*)(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // gridDim
+ m_writer->emit("slang_bit_cast<dim3>(");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // blockDim
+ m_writer->emit("slang_bit_cast<dim3>(");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // args
+ emitOperand(inst->getOperand(3), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+
+ // shared mem
+ m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")), ");
+
+ // stream
+ m_writer->emit("((cudaStream_t)");
+ emitOperand(inst->getOperand(4), getInfo(EmitOp::General));
+ m_writer->emit(")));\n");
+
+ break;
+ }
+ }
+}
+
bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
@@ -78,47 +121,12 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
m_writer->emit("make_tensor_view(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(", ");
- emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitStringLiteral(getUnmangledName(inst->getOperand(1)));
+ emitStringLiteral(getUnmangledName(inst->getOperand(0)));
m_writer->emit(", ");
- emitTorchScalarTypeName(m_writer, inst->getOperand(1)->getDataType());
+ emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType());
m_writer->emit(")");
return true;
}
- case kIROp_CudaKernelLaunch:
- {
- m_writer->emit("cudaLaunchKernel(");
- // func
- m_writer->emit("(const void*)(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit("), ");
-
- // gridDim
- m_writer->emit("slang_bit_cast<dim3>(");
- emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
- m_writer->emit("), ");
-
- // blockDim
- m_writer->emit("slang_bit_cast<dim3>(");
- emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
- m_writer->emit("), ");
-
- // args
- emitOperand(inst->getOperand(3), getInfo(EmitOp::General));
- m_writer->emit(", ");
-
- // shared mem
- m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(")), ");
-
- // stream
- m_writer->emit("((cudaStream_t)");
- emitOperand(inst->getOperand(4), getInfo(EmitOp::General));
- m_writer->emit("))");
- return true;
- }
case kIROp_TorchGetCudaStream:
{
m_writer->emit("at::cuda::getCurrentCUDAStream()");
@@ -131,12 +139,14 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
/*
Emit something like:
```
- torch::Tensor out = torch::zeros_like(other);
+ torch::Tensor out = torch::empty_like(other);
```
*/
- m_writer->emit("torch::zeros_like(");
+ m_writer->emit("torch::empty_like(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(")");
+ m_writer->emit(", torch::TensorOptions().device(torch::kCUDA).dtype(");
+ emitTorchScalarTypeName(m_writer, inst->getDataType());
+ m_writer->emit("))");
}
else
{
@@ -180,11 +190,6 @@ SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget targ
out << "torch::Tensor";
return SLANG_OK;
}
- case kIROp_TorchKernelMemoryAllocatorType:
- {
- out << "CudaTaskMemoryAllocator";
- return SLANG_OK;
- }
}
}