summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-cuda.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-17 15:57:22 -0700
committerGitHub <noreply@github.com>2023-03-17 15:57:22 -0700
commit7f11f883d0781952f002b3aa3222a3aa0040f18a (patch)
tree08eaf10fef39211fbc3f124679bfe8a35775a5a7 /source/slang/slang-emit-cuda.cpp
parent4b55bf6d75bdeed087728505a1c9b43d3a99af8d (diff)
Add support for emitting cuda kernel and host functions. (#2712)
* Add support for emitting cuda kernel and host functions. * Update test. * Fix cuda preamble emit. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
-rw-r--r--source/slang/slang-emit-cuda.cpp32
1 files changed, 31 insertions, 1 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index a151ab0e2..846b3b1f2 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -223,9 +223,21 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
void CUDASourceEmitter::emitFunctionPreambleImpl(IRInst* inst)
{
- if(inst && inst->findDecoration<IREntryPointDecoration>())
+ if (!inst)
+ return;
+ if (inst->findDecoration<IREntryPointDecoration>())
{
m_writer->emit("extern \"C\" __global__ ");
+ return;
+ }
+
+ if (inst->findDecoration<IRCudaKernelDecoration>())
+ {
+ m_writer->emit("__global__ ");
+ }
+ else if (inst->findDecoration<IRCudaHostDecoration>())
+ {
+ m_writer->emit("__host__ ");
}
else
{
@@ -608,6 +620,24 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit(")optixGetSbtDataPointer())");
return true;
}
+ case kIROp_DispatchKernel:
+ {
+ auto dispatchInst = as<IRDispatchKernel>(inst);
+ emitOperand(dispatchInst->getBaseFn(), getInfo(EmitOp::Atomic));
+ m_writer->emit("<<<");
+ emitOperand(dispatchInst->getThreadGroupSize(), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(dispatchInst->getDispatchSize(), getInfo(EmitOp::General));
+ m_writer->emit(">>>(");
+ for (UInt i = 0; i < dispatchInst->getArgCount(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ emitOperand(dispatchInst->getArg(i), getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ return true;
+ }
default: break;
}