summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
-rw-r--r--source/slang/slang-emit-cuda.cpp60
1 files changed, 48 insertions, 12 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index a728df755..702543fc8 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -606,21 +606,57 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>();
- if (entryPointDecor && entryPointDecor->getProfile().GetStage() == Stage::Compute)
+ if (entryPointDecor)
{
- Int sizeAlongAxis[kThreadGroupAxisCount];
- getComputeThreadGroupSize(func, sizeAlongAxis);
-
- //
- m_writer->emit("// [numthreads(");
- for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
+ // We have an entry-point function in the IR module, which we
+ // will want to emit as a `__global__` function in the generated
+ // CUDA C++.
+ //
+ // The most common case will be a compute kernel, in which case
+ // we will emit the function more or less as-is, including
+ // usingits original name as the name of the global symbol.
+ //
+ String funcName = getName(func);
+ String globalSymbolName = funcName;
+
+ // We also suport emitting ray tracing kernels for use with
+ // OptiX, and in that case the name of the global symbol
+ // must be prefixed to indicate to the OptiX runtime what
+ // stage it is to be compiled for.
+ //
+ auto stage = entryPointDecor->getProfile().GetStage();
+ switch( stage )
{
- if (ii != 0) m_writer->emit(", ");
- m_writer->emit(sizeAlongAxis[ii]);
+ default:
+ break;
+
+ #define CASE(STAGE, PREFIX) \
+ case Stage::STAGE: globalSymbolName = #PREFIX + funcName; break
+
+ CASE(RayGeneration, __raygen__);
+ // TODO: Add the other ray tracing shader stages here.
+ #undef CASE
}
- m_writer->emit(")]\n");
- String funcName = getName(func);
+ // As a convenience for anybody reading the generated
+ // CUDA C++ code, we will prefix a compute kernel
+ // with the information from the `[numthreads(...)]`
+ // attribute in the source.
+ //
+ if(stage == Stage::Compute)
+ {
+ Int sizeAlongAxis[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(func, sizeAlongAxis);
+
+ //
+ m_writer->emit("// [numthreads(");
+ for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
+ {
+ if (ii != 0) m_writer->emit(", ");
+ m_writer->emit(sizeAlongAxis[ii]);
+ }
+ m_writer->emit(")]\n");
+ }
m_writer->emit("extern \"C\" __global__ ");
@@ -628,7 +664,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
// Emit the actual function
emitEntryPointAttributes(func, entryPointDecor);
- emitType(resultType, funcName);
+ emitType(resultType, globalSymbolName);
m_writer->emit("(UniformEntryPointParams* params, UniformState* uniformState)");
emitSemantics(func);