summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
-rw-r--r--source/slang/slang-emit-cuda.cpp83
1 files changed, 79 insertions, 4 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index ac1e1ea63..fa63ba255 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -787,6 +787,22 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
#undef CASE
}
+ if( stage != Stage::Compute )
+ {
+ // Non-compute shaders (currently just OptiX ray tracing kernels)
+ // require parameter data that is shared across multiple kernels
+ // (which in our case is the global-scope shader parameters)
+ // to be passed using a global `__constant__` variable.
+ //
+ // The use of `"C"` linkage here is required because the name
+ // of this symbol must be passed to the OptiX API when creating
+ // a pipeline that uses this compiled module. The exact name
+ // used here (`SLANG_globalParams`) is thus a part of the
+ // binary interface for Slang->OptiX translation.
+ //
+ m_writer->emit("extern \"C\" { __constant__ UniformState SLANG_globalParams; }\n");
+ }
+
// As a convenience for anybody reading the generated
// CUDA C++ code, we will prefix a compute kernel
// with the information from the `[numthreads(...)]`
@@ -815,27 +831,86 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
emitEntryPointAttributes(func, entryPointDecor);
emitType(resultType, globalSymbolName);
- m_writer->emit("(UniformEntryPointParams* params, UniformState* uniformState)");
+ if( stage == Stage::Compute )
+ {
+ // CUDA compute shaders take all of their parameters explicitly as
+ // part of the entry-point parameter list. This means that the
+ // data representing Slang shader parameters at both the global
+ // and entry-point scopes needs to be passed as parameters.
+ //
+ // At the binary level, our generated CUDA compute kernels will take
+ // two pointer parameters: the first points to the per-entry-point
+ // `uniform` parameter data, and the second points to the global-scope
+ // parameter data (if any).
+ //
+ m_writer->emit("(UniformEntryPointParams* entryPointShaderParameters, UniformState* uniformState)");
+ }
+ else
+ {
+ // Non-compute shaders (currently just OptiX ray tracing kernels)
+ // rely on other mechanisms for parameter passing, and thus use
+ // an empty parameter list on the kernel declaration.
+ //
+ m_writer->emit("()");
+ }
+
emitSemantics(func);
m_writer->emit("\n{\n");
m_writer->indent();
// Initialize when constructing so that globals are zeroed
m_writer->emit("Context context = {};\n");
- m_writer->emit("context.uniformState = uniformState;\n");
+
+ // The global-scope parameter data got passed in differently depending on whether we have
+ // a compute shader or a ray-tracing shader, so we need to alter how we initialize
+ // the pointer in our `context` based on the stage.
+ //
+ if( stage == Stage::Compute )
+ {
+ m_writer->emit("context.uniformState = uniformState;\n");
+ }
+ else
+ {
+ m_writer->emit("context.uniformState = &SLANG_globalParams;\n");
+ }
if (entryPointGlobalParams)
{
auto varDecl = entryPointGlobalParams;
auto rawType = varDecl->getDataType();
-
auto varType = rawType;
m_writer->emit("context.");
m_writer->emit(getName(varDecl));
m_writer->emit(" = (");
emitType(varType);
- m_writer->emit("*)params; \n");
+ m_writer->emit("*)");
+
+ // Similar to the case for global parameter data above, the entry-point
+ // uniform parameter data gets passed in differently for compute kernels
+ // vs. ray-tracing kernels, and we need to handle the two cases here.
+ //
+ if( stage == Stage::Compute )
+ {
+ // In the compute case, the entry-point uniform parameters came
+ // in as an explicit parameter on the CUDA kernel, and we simply
+ // cast it to the expected type here.
+ //
+ m_writer->emit("entryPointShaderParameters");
+ }
+ else
+ {
+ // In the ray-tracing case, the entry-point uniform parameters
+ // implicitly map to the contents of the Shader Binding Table
+ // (SBT) entry for the entry point instance being invoked.
+ //
+ // The OptiX API provides an accessor function to get a pointer
+ // to the SBT data for the current entry, and we cast the result
+ // of that to the expected type.
+ //
+ m_writer->emit("optixGetSbtDataPointer()");
+ }
+ m_writer->emit(";\n");
}
m_writer->emit("context.");