diff options
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 60 |
1 files changed, 48 insertions, 12 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index a728df755..702543fc8 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -606,21 +606,57 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>(); - if (entryPointDecor && entryPointDecor->getProfile().GetStage() == Stage::Compute) + if (entryPointDecor) { - Int sizeAlongAxis[kThreadGroupAxisCount]; - getComputeThreadGroupSize(func, sizeAlongAxis); - - // - m_writer->emit("// [numthreads("); - for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) + // We have an entry-point function in the IR module, which we + // will want to emit as a `__global__` function in the generated + // CUDA C++. + // + // The most common case will be a compute kernel, in which case + // we will emit the function more or less as-is, including + // usingits original name as the name of the global symbol. + // + String funcName = getName(func); + String globalSymbolName = funcName; + + // We also suport emitting ray tracing kernels for use with + // OptiX, and in that case the name of the global symbol + // must be prefixed to indicate to the OptiX runtime what + // stage it is to be compiled for. + // + auto stage = entryPointDecor->getProfile().GetStage(); + switch( stage ) { - if (ii != 0) m_writer->emit(", "); - m_writer->emit(sizeAlongAxis[ii]); + default: + break; + + #define CASE(STAGE, PREFIX) \ + case Stage::STAGE: globalSymbolName = #PREFIX + funcName; break + + CASE(RayGeneration, __raygen__); + // TODO: Add the other ray tracing shader stages here. + #undef CASE } - m_writer->emit(")]\n"); - String funcName = getName(func); + // As a convenience for anybody reading the generated + // CUDA C++ code, we will prefix a compute kernel + // with the information from the `[numthreads(...)]` + // attribute in the source. + // + if(stage == Stage::Compute) + { + Int sizeAlongAxis[kThreadGroupAxisCount]; + getComputeThreadGroupSize(func, sizeAlongAxis); + + // + m_writer->emit("// [numthreads("); + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) + { + if (ii != 0) m_writer->emit(", "); + m_writer->emit(sizeAlongAxis[ii]); + } + m_writer->emit(")]\n"); + } m_writer->emit("extern \"C\" __global__ "); @@ -628,7 +664,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) // Emit the actual function emitEntryPointAttributes(func, entryPointDecor); - emitType(resultType, funcName); + emitType(resultType, globalSymbolName); m_writer->emit("(UniformEntryPointParams* params, UniformState* uniformState)"); emitSemantics(func); |
