diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2020-07-10 14:30:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-07-10 14:30:57 -0700 |
| commit | 249f48dbb5e240c713661be969a6939ec57561e5 (patch) | |
| tree | 6e7c3f117931e640b89a63c0083c1f14e5b81ea4 /source/slang/slang-emit-cpp.cpp | |
| parent | 6aad38a43394a60c02c6109199d427d88147e781 (diff) | |
CUDA/CPU varying compute inputs as IR pass (#1438)
The main change here is that the CPU and CUDA C++ emit paths now rely on an earlier IR pass to legalize the varying parameter list of a kernel and translate references to varying parameters with semantics like `SV_DispatchThreadID`. Doing so removes a lot of special-case logic from the emit passes.
This work moves us even closer to being able to eliminate `KernelContext` from the CPU/CUDA emit logic, because it removes the issue of state related to varying inputs being stored in `KernelContext`.
The new pass that handles the legalization is in `slang-ir-legalize-varying-params.cpp`, and it borrows heavily from the existing `slang-ir-glsl-legalize.cpp` pass. The new pass factors out the target-independent and target-dependent logic, so that both CPU and CUDA can share much of the same code despite having very different rules for how the system-value parameters are being provided.
An eventual goal is to have the new pass also handle the GLSL case, but doing so requires copying even more logic out of the GLSL-specific pass, and doing so seemed like a step to far for what was meant to be a stepping-stone change as part of other work. As a result of the incomplete nature of the pass, certain cases don't work for compute shader inputs for CPU/CUDA (e.g., wrapping your varying inputs in a `struct` type parameter), but those were cases that also didn't work in the existing `emit`-based logic.
One major consequence of this change is that the logic for emitting the various different functions that represent an entry point for our CPU back-end has been streamlined and simplified. The original logic had a fair bit of cleverness built in to try and avoid unnecessary math ops when computing the various IDs/indices, while the new logic is much more simplistic (the main dispatch function loops over threadgroups with a triply-nested `for` and then delegates to the group-level function loops over threads with its own nested `for`s).
Longer term, it will be important to simplify the CPU functions we emit further, by eliminating things like the `_Thread` function that should never really be exposed to users (the minimum granularity of invoking a CPU compute kernel should be a single threadgroup). We may eventually decide to synthesize all of the extra code that is being generated in the `emit` pass as IR instead.
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 241 |
1 files changed, 85 insertions, 156 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 54c2257f2..b59611b38 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1955,38 +1955,48 @@ void CPPSourceEmitter::emitSimpleFuncImpl(IRFunc* func) // Deal with decorations that need // to be emitted as attributes - // We are going to ignore the parameters passed and just pass in the Context + // We start by emitting the result type and function name. + // if (IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>()) { + // Note: we currently emit multiple functions to represent an entry point + // on CPU/CUDA, and these all bottleneck through the actual `IRFunc` + // here as a workhorse. + // + // Because the workhorse function is currently emitted as a member of + // `KernelContext`, and doesn't have the right signature to service + // general-purpose calls, it is being emitted with a `_` prefix. + // StringBuilder prefixName; prefixName << "_" << name; emitType(resultType, prefixName); - m_writer->emit("()\n"); } else { emitType(resultType, name); + } - m_writer->emit("("); - auto firstParam = func->getFirstParam(); - for (auto pp = firstParam; pp; pp = pp->getNextParam()) - { - // Ingore TypeType-typed parameters for now. - // In the future we will pass around runtime type info - // for TypeType parameters. - if (as<IRTypeType>(pp->getFullType())) - continue; - - if (pp != firstParam) - m_writer->emit(", "); + // Next we emit the parameter list of the function. + // + m_writer->emit("("); + auto firstParam = func->getFirstParam(); + for (auto pp = firstParam; pp; pp = pp->getNextParam()) + { + // Ingore TypeType-typed parameters for now. + // In the future we will pass around runtime type info + // for TypeType parameters. + if (as<IRTypeType>(pp->getFullType())) + continue; - emitSimpleFuncParamImpl(pp); - } - m_writer->emit(")"); + if (pp != firstParam) + m_writer->emit(", "); - emitSemantics(func); + emitSimpleFuncParamImpl(pp); } + m_writer->emit(")"); + + emitSemantics(func); // TODO: encode declaration vs. definition if (isDefinition(func)) @@ -2431,40 +2441,6 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre switch (inst->op) { - case kIROp_Param: - { - auto varLayout = getVarLayout(inst); - - if (varLayout) - { - if(auto systemValueSemantic = varLayout->findSystemValueSemanticAttr()) - { - String semanticNameSpelling = systemValueSemantic->getName(); - semanticNameSpelling = semanticNameSpelling.toLower(); - - if (semanticNameSpelling == "sv_dispatchthreadid") - { - m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID; - m_writer->emit("dispatchThreadID"); - return; - } - else if (semanticNameSpelling == "sv_groupid") - { - m_semanticUsedFlags |= SemanticUsedFlag::GroupID; - m_writer->emit("groupID"); - return; - } - else if (semanticNameSpelling == "sv_groupthreadid") - { - m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID; - m_writer->emit("calcGroupThreadID()"); - return; - } - } - } - m_writer->emit(getName(inst)); - break; - } case kIROp_Var: case kIROp_GlobalVar: emitVarExpr(inst, outerPrec); @@ -2591,19 +2567,19 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup const auto& axis = axes[i]; builder.Clear(); const char elem[2] = { s_elemNames[axis.axis], 0 }; - builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < start." << elem << " + " << axis.size << "; ++" << elem << ")\n{\n"; + builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n"; m_writer->emit(builder); m_writer->indent(); builder.Clear(); - builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n"; + builder << "threadInput.groupThreadID." << elem << " = " << elem << ";\n"; m_writer->emit(builder); } // just call at inner loop point m_writer->emit("context._"); m_writer->emit(funcName); - m_writer->emit("();\n"); + m_writer->emit("(&threadInput);\n"); // Close all the loops for (Index i = Index(axes.getCount() - 1); i >= 0; --i) @@ -2626,57 +2602,20 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread builder.Clear(); const char elem[2] = { s_elemNames[axis.axis], 0 }; - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n"; - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - builder << "context.groupID." << elem << " += varyingInput->startGroupID." << elem << ";\n"; - } - - builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n"; + builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n"; m_writer->emit(builder); m_writer->indent(); - builder.Clear(); - builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n"; - - if (m_semanticUsedFlags & (SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::GroupID)) - { - if (sizeAlongAxis[axis.axis] > 1) - { - builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<";\n"; - - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n"; - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n"; - } - } - else - { - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n"; - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - builder << "context.groupID." << elem << " = " << elem << ";\n"; - } - } - } - - m_writer->emit(builder); + m_writer->emit("groupVaryingInput.startGroupID."); + m_writer->emit(elem); + m_writer->emit(" = "); + m_writer->emit(elem); + m_writer->emit(";\n"); } // just call at inner loop point - m_writer->emit("context._"); m_writer->emit(funcName); - m_writer->emit("();\n"); + m_writer->emit("_Group(&groupVaryingInput, entryPointParams, globalParams);\n"); // Close all the loops for (Index i = Index(axes.getCount() - 1); i >= 0; --i) @@ -2736,6 +2675,34 @@ void CPPSourceEmitter::_emitForwardDeclarations(const List<EmitAction>& actions) } } +static bool isVaryingResourceKind(LayoutResourceKind kind) +{ + switch(kind) + { + default: + return false; + + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + return true; + } +} + +static bool isVaryingParameter(IRTypeLayout* typeLayout) +{ + for(auto sizeAttr : typeLayout->getSizeAttrs()) + { + if(!isVaryingResourceKind(sizeAttr->getResourceKind())) + return false; + } + return true; +} + +static bool isVaryingParameter(IRVarLayout* varLayout) +{ + return isVaryingParameter(varLayout->getTypeLayout()); +} + void CPPSourceEmitter::_findShaderParams( IRGlobalParam** outEntryPointParam, IRGlobalParam** outGlobalParam) @@ -2752,6 +2719,20 @@ void CPPSourceEmitter::_findShaderParams( if(!param) continue; + if(auto layoutDecor = param->findDecoration<IRLayoutDecoration>()) + { + if(auto varLayout = as<IRVarLayout>(layoutDecor->getLayout())) + { + if(isVaryingParameter(varLayout)) + continue; + auto typeLayout = varLayout->getTypeLayout(); + if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingInput)) + continue; + if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingOutput)) + continue; + } + } + // Currently, the entry-point parameters // are represented as a single parameter // at the global scope, and the same is @@ -2806,28 +2787,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) m_writer->emit("struct KernelContext\n{\n"); m_writer->indent(); - m_writer->emit("uint3 dispatchThreadID;\n"); - - //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - // Note not always set! - m_writer->emit("uint3 groupID;\n"); - } - - //if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - m_writer->emit("uint3 groupDispatchThreadID;\n"); - - m_writer->emit("uint3 calcGroupThreadID() const \n{\n"); - m_writer->indent(); - // groupThreadID = dispatchThreadID - groupDispatchThreadID - m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z };\n"); - m_writer->emit("return v;\n"); - m_writer->dedent(); - m_writer->emit("}\n"); - } - - if (globalParams) { emitGlobalInst(globalParams); @@ -2886,9 +2845,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) if (entryPointDecor && entryPointDecor->getProfile().getStage() == Stage::Compute) { - // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid - // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID. - Int groupThreadSize[kThreadGroupAxisCount]; getComputeThreadGroupSize(func, groupThreadSize); @@ -2902,23 +2858,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointParams, globalParams, threadFuncName, UnownedStringSlice::fromLiteral("ComputeThreadVaryingInput")); - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - m_writer->emit("context.groupDispatchThreadID = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice()); - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - m_writer->emit("context.groupID = varyingInput->groupID;\n"); - } - - // Emit dispatchThreadID - m_writer->emit("context.dispatchThreadID = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID")); - m_writer->emit("context._"); m_writer->emit(funcName); - m_writer->emit("();\n"); + m_writer->emit("(varyingInput);\n"); _emitEntryPointDefinitionEnd(func); } @@ -2933,19 +2875,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointParams, globalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); - m_writer->emit("const uint3 start = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); - - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - m_writer->emit("context.groupDispatchThreadID = start;\n"); - } - - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - m_writer->emit("context.groupID = varyingInput->startGroupID;\n"); - } - m_writer->emit("context.dispatchThreadID = start;\n"); + m_writer->emit("ComputeThreadVaryingInput threadInput = {};\n"); + m_writer->emit("threadInput.groupID = varyingInput->startGroupID;\n"); _emitEntryPointGroup(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); @@ -2955,10 +2886,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) { _emitEntryPointDefinitionStart(func, entryPointParams, globalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); - m_writer->emit("const uint3 start = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); - m_writer->emit("const uint3 end = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice()); + m_writer->emit("ComputeVaryingInput vi = *varyingInput;\n"); + m_writer->emit("ComputeVaryingInput groupVaryingInput = {};\n"); _emitEntryPointGroupRange(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); |
