summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
-rw-r--r--source/slang/slang-emit-cuda.cpp74
1 files changed, 44 insertions, 30 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index b79518052..d05c4edfc 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -711,19 +711,9 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
_emitForwardDeclarations(actions);
- IRGlobalParam* entryPointGlobalParams = nullptr;
-
- // Output the global parameters in a 'UniformState' structure
- {
- m_writer->emit("struct UniformState\n{\n");
- m_writer->indent();
-
- // We need these to be prefixed by __device__
- _emitUniformStateMembers(actions, &entryPointGlobalParams);
-
- m_writer->dedent();
- m_writer->emit("\n};\n\n");
- }
+ IRGlobalParam* entryPointParams = nullptr;
+ IRGlobalParam* globalParams = nullptr;
+ _findShaderParams(&entryPointParams, &globalParams);
// Output group shared variables
@@ -742,11 +732,13 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
m_writer->emit("struct KernelContext\n{\n");
m_writer->indent();
- m_writer->emit("UniformState* uniformState;\n");
-
- if (entryPointGlobalParams)
+ if (globalParams)
{
- emitGlobalInst(entryPointGlobalParams);
+ emitGlobalInst(globalParams);
+ }
+ if (entryPointParams)
+ {
+ emitGlobalInst(entryPointParams);
}
// Output all the thread locals
@@ -813,7 +805,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
#undef CASE
}
- if( stage != Stage::Compute )
+ if(globalParams && stage != Stage::Compute )
{
// Non-compute shaders (currently just OptiX ray tracing kernels)
// require parameter data that is shared across multiple kernels
@@ -826,7 +818,20 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
// used here (`SLANG_globalParams`) is thus a part of the
// binary interface for Slang->OptiX translation.
//
- m_writer->emit("extern \"C\" { __constant__ UniformState SLANG_globalParams; }\n");
+ // TODO: We need to make a decision about how indirected
+ // the parameter passing for global-scope data is going to
+ // be for CUDA and OptiX (ideally with an answer that is
+ // consistent across the two). For now we are deciding to
+ // make this global `__constant__` variable represent the
+ // global parameter data directly, rather than indirectly.
+ //
+ auto globalParamsPtrType = as<IRPointerLikeType>(globalParams->getDataType());
+ SLANG_ASSERT(globalParamsPtrType);
+ auto gloablParamsElementType = globalParamsPtrType->getElementType();
+ //
+ m_writer->emit("extern \"C\" { __constant__ ");
+ emitType(gloablParamsElementType, "SLANG_globalParams");
+ m_writer->emit("; }\n");
}
// As a convenience for anybody reading the generated
@@ -869,7 +874,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
// `uniform` parameter data, and the second points to the global-scope
// parameter data (if any).
//
- m_writer->emit("(void* entryPointShaderParameters, void* uniformState)");
+ m_writer->emit("(void* entryPointParams, void* globalParams)");
}
else
{
@@ -891,18 +896,27 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
// a compute shader or a ray-tracing shader, so we need to alter how we initialize
// the pointer in our `context` based on the stage.
//
- if( stage == Stage::Compute )
- {
- m_writer->emit("context.uniformState = (UniformState*)uniformState;\n");
- }
- else
+ if( globalParams )
{
- m_writer->emit("context.uniformState = &SLANG_globalParams;\n");
+ if( stage == Stage::Compute )
+ {
+ m_writer->emit("context.");
+ m_writer->emit(getName(globalParams));
+ m_writer->emit(" = (");
+ emitType(globalParams->getDataType());
+ m_writer->emit(")globalParams;\n");
+ }
+ else
+ {
+ m_writer->emit("context.");
+ m_writer->emit(getName(globalParams));
+ m_writer->emit(" = &SLANG_globalParams;\n");
+ }
}
- if (entryPointGlobalParams)
+ if (entryPointParams)
{
- auto varDecl = entryPointGlobalParams;
+ auto varDecl = entryPointParams;
auto rawType = varDecl->getDataType();
auto varType = rawType;
@@ -910,7 +924,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
m_writer->emit(getName(varDecl));
m_writer->emit(" = (");
emitType(varType);
- m_writer->emit("*)");
+ m_writer->emit(")");
// Similar to the case for global parameter data above, the entry-point
// uniform parameter data gets passed in differently for compute kernels
@@ -922,7 +936,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
// in as an explicit parameter on the CUDA kernel, and we simply
// cast it to the expected type here.
//
- m_writer->emit("entryPointShaderParameters");
+ m_writer->emit("entryPointParams");
}
else
{