diff options
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 83 |
1 files changed, 79 insertions, 4 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index ac1e1ea63..fa63ba255 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -787,6 +787,22 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) #undef CASE } + if( stage != Stage::Compute ) + { + // Non-compute shaders (currently just OptiX ray tracing kernels) + // require parameter data that is shared across multiple kernels + // (which in our case is the global-scope shader parameters) + // to be passed using a global `__constant__` variable. + // + // The use of `"C"` linkage here is required because the name + // of this symbol must be passed to the OptiX API when creating + // a pipeline that uses this compiled module. The exact name + // used here (`SLANG_globalParams`) is thus a part of the + // binary interface for Slang->OptiX translation. + // + m_writer->emit("extern \"C\" { __constant__ UniformState SLANG_globalParams; }\n"); + } + // As a convenience for anybody reading the generated // CUDA C++ code, we will prefix a compute kernel // with the information from the `[numthreads(...)]` @@ -815,27 +831,86 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) emitEntryPointAttributes(func, entryPointDecor); emitType(resultType, globalSymbolName); - m_writer->emit("(UniformEntryPointParams* params, UniformState* uniformState)"); + if( stage == Stage::Compute ) + { + // CUDA compute shaders take all of their parameters explicitly as + // part of the entry-point parameter list. This means that the + // data representing Slang shader parameters at both the global + // and entry-point scopes needs to be passed as parameters. + // + // At the binary level, our generated CUDA compute kernels will take + // two pointer parameters: the first points to the per-entry-point + // `uniform` parameter data, and the second points to the global-scope + // parameter data (if any). 
+ // + m_writer->emit("(UniformEntryPointParams* entryPointShaderParameters, UniformState* uniformState)"); + } + else + { + // Non-compute shaders (currently just OptiX ray tracing kernels) + // rely on other mechanisms for parameter passing, and thus use + // an empty parameter list on the kernel declaration. + // + m_writer->emit("()"); + } + emitSemantics(func); m_writer->emit("\n{\n"); m_writer->indent(); // Initialize when constructing so that globals are zeroed m_writer->emit("Context context = {};\n"); - m_writer->emit("context.uniformState = uniformState;\n"); + + // The global-scope parameter data got passed in differently depending on whether we have + // a compute shader or a ray-tracing shader, so we need to alter how we initialize + // the pointer in our `context` based on the stage. + // + if( stage == Stage::Compute ) + { + m_writer->emit("context.uniformState = uniformState;\n"); + } + else + { + m_writer->emit("context.uniformState = &SLANG_globalParams;\n"); + } if (entryPointGlobalParams) { auto varDecl = entryPointGlobalParams; auto rawType = varDecl->getDataType(); - auto varType = rawType; m_writer->emit("context."); m_writer->emit(getName(varDecl)); m_writer->emit(" = ("); emitType(varType); - m_writer->emit("*)params; \n"); + m_writer->emit("*)"); + + // Similar to the case for global parameter data above, the entry-point + // uniform parameter data gets passed in differently for compute kernels + // vs. ray-tracing kernels, and we need to handle the two cases here. + // + if( stage == Stage::Compute ) + { + // In the compute case, the entry-point uniform parameters came + // in as an explicit parameter on the CUDA kernel, and we simply + // cast it to the expected type here. + // + m_writer->emit("entryPointShaderParameters"); + } + else + { + // In the ray-tracing case, the entry-point uniform parameters + // implicitly map to the contents of the Shader Binding Table + // (SBT) entry for the entry point instance being invoked. 
+ // + // The OptiX API provides an accessor function to get a pointer + // to the SBT data for the current entry, and we cast the result + // of that to the expected type. + // + m_writer->emit("optixGetSbtDataPointer()"); + } + m_writer->emit(";\n"); } m_writer->emit("context."); |
