diff options
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 74 |
1 files changed, 44 insertions, 30 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index b79518052..d05c4edfc 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -711,19 +711,9 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) _emitForwardDeclarations(actions); - IRGlobalParam* entryPointGlobalParams = nullptr; - - // Output the global parameters in a 'UniformState' structure - { - m_writer->emit("struct UniformState\n{\n"); - m_writer->indent(); - - // We need these to be prefixed by __device__ - _emitUniformStateMembers(actions, &entryPointGlobalParams); - - m_writer->dedent(); - m_writer->emit("\n};\n\n"); - } + IRGlobalParam* entryPointParams = nullptr; + IRGlobalParam* globalParams = nullptr; + _findShaderParams(&entryPointParams, &globalParams); // Output group shared variables @@ -742,11 +732,13 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) m_writer->emit("struct KernelContext\n{\n"); m_writer->indent(); - m_writer->emit("UniformState* uniformState;\n"); - - if (entryPointGlobalParams) + if (globalParams) { - emitGlobalInst(entryPointGlobalParams); + emitGlobalInst(globalParams); + } + if (entryPointParams) + { + emitGlobalInst(entryPointParams); } // Output all the thread locals @@ -813,7 +805,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) #undef CASE } - if( stage != Stage::Compute ) + if(globalParams && stage != Stage::Compute ) { // Non-compute shaders (currently just OptiX ray tracing kernels) // require parameter data that is shared across multiple kernels @@ -826,7 +818,20 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) // used here (`SLANG_globalParams`) is thus a part of the // binary interface for Slang->OptiX translation. // - m_writer->emit("extern \"C\" { __constant__ UniformState SLANG_globalParams; }\n"); + // TODO: We need to make a decision about how indirected + // the parameter passing for global-scope data is going to + // be for CUDA and OptiX (ideally with an answer that is + // consistent across the two). For now we are deciding to + // make this global `__constant__` variable represent the + // global parameter data directly, rather than indirectly. + // + auto globalParamsPtrType = as<IRPointerLikeType>(globalParams->getDataType()); + SLANG_ASSERT(globalParamsPtrType); + auto gloablParamsElementType = globalParamsPtrType->getElementType(); + // + m_writer->emit("extern \"C\" { __constant__ "); + emitType(gloablParamsElementType, "SLANG_globalParams"); + m_writer->emit("; }\n"); } // As a convenience for anybody reading the generated @@ -869,7 +874,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) // `uniform` parameter data, and the second points to the global-scope // parameter data (if any). // - m_writer->emit("(void* entryPointShaderParameters, void* uniformState)"); + m_writer->emit("(void* entryPointParams, void* globalParams)"); } else { @@ -891,18 +896,27 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) // a compute shader or a ray-tracing shader, so we need to alter how we initialize // the pointer in our `context` based on the stage. // - if( stage == Stage::Compute ) - { - m_writer->emit("context.uniformState = (UniformState*)uniformState;\n"); - } - else + if( globalParams ) { - m_writer->emit("context.uniformState = &SLANG_globalParams;\n"); + if( stage == Stage::Compute ) + { + m_writer->emit("context."); + m_writer->emit(getName(globalParams)); + m_writer->emit(" = ("); + emitType(globalParams->getDataType()); + m_writer->emit(")globalParams;\n"); + } + else + { + m_writer->emit("context."); + m_writer->emit(getName(globalParams)); + m_writer->emit(" = &SLANG_globalParams;\n"); + } } - if (entryPointGlobalParams) + if (entryPointParams) { - auto varDecl = entryPointGlobalParams; + auto varDecl = entryPointParams; auto rawType = varDecl->getDataType(); auto varType = rawType; @@ -910,7 +924,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) m_writer->emit(getName(varDecl)); m_writer->emit(" = ("); emitType(varType); - m_writer->emit("*)"); + m_writer->emit(")"); // Similar to the case for global parameter data above, the entry-point // uniform parameter data gets passed in differently for compute kernels @@ -922,7 +936,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module) // in as an explicit parameter on the CUDA kernel, and we simply // cast it to the expected type here. // - m_writer->emit("entryPointShaderParameters"); + m_writer->emit("entryPointParams"); } else { |
