summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
-rw-r--r--source/slang/slang-emit-cuda.cpp234
1 files changed, 13 insertions, 221 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index c7dee9f9d..6f24d5b74 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -248,6 +248,19 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
SLANG_UNUSED(entryPointDecor);
}
+void CUDASourceEmitter::emitFunctionPreambleImpl(IRInst* inst)
+{
+ if(inst && inst->findDecoration<IREntryPointDecoration>())
+ {
+ m_writer->emit("extern \"C\" __global__ ");
+ }
+ else
+ {
+ m_writer->emit("__device__ ");
+ }
+}
+
+
void CUDASourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec)
{
switch (specOp->op)
@@ -661,10 +674,6 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
_emitForwardDeclarations(actions);
- IRGlobalParam* entryPointParams = nullptr;
- IRGlobalParam* globalParams = nullptr;
- _findShaderParams(&entryPointParams, &globalParams);
-
// Output group shared variables
{
@@ -677,20 +686,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
}
}
- // Output the 'Context' which will be used for execution
{
- m_writer->emit("struct KernelContext\n{\n");
- m_writer->indent();
-
- if (globalParams)
- {
- emitGlobalInst(globalParams);
- }
- if (entryPointParams)
- {
- emitGlobalInst(entryPointParams);
- }
-
// Output all the thread locals
for (auto action : actions)
{
@@ -708,211 +704,7 @@ void CUDASourceEmitter::emitModuleImpl(IRModule* module)
emitGlobalInst(action.inst);
}
}
-
- m_writer->dedent();
- m_writer->emit("};\n\n");
}
-
- // Finally we need to output dll entry points
-
- for (auto action : actions)
- {
- if (action.level == EmitAction::Level::Definition && as<IRFunc>(action.inst))
- {
- IRFunc* func = as<IRFunc>(action.inst);
-
- IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>();
-
- if (entryPointDecor)
- {
- // We have an entry-point function in the IR module, which we
- // will want to emit as a `__global__` function in the generated
- // CUDA C++.
- //
- // The most common case will be a compute kernel, in which case
- // we will emit the function more or less as-is, including
- // usingits original name as the name of the global symbol.
- //
- String funcName = getName(func);
- String globalSymbolName = funcName;
-
- // We also suport emitting ray tracing kernels for use with
- // OptiX, and in that case the name of the global symbol
- // must be prefixed to indicate to the OptiX runtime what
- // stage it is to be compiled for.
- //
- auto stage = entryPointDecor->getProfile().getStage();
- switch( stage )
- {
- default:
- break;
-
- #define CASE(STAGE, PREFIX) \
- case Stage::STAGE: globalSymbolName = #PREFIX + funcName; break
-
- CASE(RayGeneration, __raygen__);
- // TODO: Add the other ray tracing shader stages here.
- #undef CASE
- }
-
- if(globalParams && stage != Stage::Compute )
- {
- // Non-compute shaders (currently just OptiX ray tracing kernels)
- // require parameter data that is shared across multiple kernels
- // (which in our case is the global-scope shader parameters)
- // to be passed using a global `__constant__` variable.
- //
- // The use of `"C"` linkage here is required because the name
- // of this symbol must be passed to the OptiX API when creating
- // a pipeline that uses this compiled module. The exact name
- // used here (`SLANG_globalParams`) is thus a part of the
- // binary interface for Slang->OptiX translation.
- //
- // TODO: We need to make a decision about how indirected
- // the parameter passing for global-scope data is going to
- // be for CUDA and OptiX (ideally with an answer that is
- // consistent across the two). For now we are deciding to
- // make this global `__constant__` variable represent the
- // global parameter data directly, rather than indirectly.
- //
- auto globalParamsPtrType = as<IRPointerLikeType>(globalParams->getDataType());
- SLANG_ASSERT(globalParamsPtrType);
- auto gloablParamsElementType = globalParamsPtrType->getElementType();
- //
- m_writer->emit("extern \"C\" { __constant__ ");
- emitType(gloablParamsElementType, "SLANG_globalParams");
- m_writer->emit("; }\n");
- }
-
- // As a convenience for anybody reading the generated
- // CUDA C++ code, we will prefix a compute kernel
- // with the information from the `[numthreads(...)]`
- // attribute in the source.
- //
- if(stage == Stage::Compute)
- {
- Int sizeAlongAxis[kThreadGroupAxisCount];
- getComputeThreadGroupSize(func, sizeAlongAxis);
-
- //
- m_writer->emit("// [numthreads(");
- for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
- {
- if (ii != 0) m_writer->emit(", ");
- m_writer->emit(sizeAlongAxis[ii]);
- }
- m_writer->emit(")]\n");
- }
-
- m_writer->emit("extern \"C\" __global__ ");
-
- auto resultType = func->getResultType();
-
- // Emit the actual function
- emitEntryPointAttributes(func, entryPointDecor);
- emitType(resultType, globalSymbolName);
-
- if( stage == Stage::Compute )
- {
- // CUDA compute shaders take all of their parameters explicitly as
- // part of the entry-point parameter list. This means that the
- // data representing Slang shader parameters at both the global
- // and entry-point scopes needs to be passed as parameters.
- //
- // At the binary level, our generated CUDA compute kernels will take
- // two pointer parameters: the first points to the per-entry-point
- // `uniform` parameter data, and the second points to the global-scope
- // parameter data (if any).
- //
- m_writer->emit("(void* entryPointParams, void* globalParams)");
- }
- else
- {
- // Non-compute shaders (currently just OptiX ray tracing kernels)
- // rely on other mechanisms for parameter passing, and thus use
- // an empty parameter list on the kernel declaration.
- //
- m_writer->emit("()");
- }
-
- emitSemantics(func);
- m_writer->emit("\n{\n");
- m_writer->indent();
-
- // Initialize when constructing so that globals are zeroed
- m_writer->emit("KernelContext context = {};\n");
-
- // The global-scope parameter data got passed in differently depending on whether we have
- // a compute shader or a ray-tracing shader, so we need to alter how we initialize
- // the pointer in our `context` based on the stage.
- //
- if( globalParams )
- {
- if( stage == Stage::Compute )
- {
- m_writer->emit("context.");
- m_writer->emit(getName(globalParams));
- m_writer->emit(" = (");
- emitType(globalParams->getDataType());
- m_writer->emit(")globalParams;\n");
- }
- else
- {
- m_writer->emit("context.");
- m_writer->emit(getName(globalParams));
- m_writer->emit(" = &SLANG_globalParams;\n");
- }
- }
-
- if (entryPointParams)
- {
- auto varDecl = entryPointParams;
- auto rawType = varDecl->getDataType();
- auto varType = rawType;
-
- m_writer->emit("context.");
- m_writer->emit(getName(varDecl));
- m_writer->emit(" = (");
- emitType(varType);
- m_writer->emit(")");
-
- // Similar to the case for global parameter data above, the entry-point
- // uniform parameter data gets passed in differently for compute kernels
- // vs. ray-tracing kernels, and we need to handle the two cases here.
- //
- if( stage == Stage::Compute )
- {
- // In the compute case, the entry-point uniform parameters came
- // in as an explicit parameter on the CUDA kernel, and we simply
- // cast it to the expected type here.
- //
- m_writer->emit("entryPointParams");
- }
- else
- {
- // In the ray-tracing case, the entry-point uniform parameters
- // implicitly map to the contents of the Shader Binding Table
- // (SBT) entry for the entry point instance being invoked.
- //
- // The OptiX API provides an accessor function to get a pointer
- // to the SBT data for the current entry, and we cast the result
- // of that to the expected type.
- //
- m_writer->emit("optixGetSbtDataPointer()");
- }
- m_writer->emit(";\n");
- }
-
- m_writer->emit("context.");
- m_writer->emit(funcName);
- m_writer->emit("();\n");
-
- m_writer->dedent();
- m_writer->emit("}\n");
- }
- }
- }
-
}