diff options
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 241 |
1 files changed, 85 insertions, 156 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 54c2257f2..b59611b38 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1955,38 +1955,48 @@ void CPPSourceEmitter::emitSimpleFuncImpl(IRFunc* func) // Deal with decorations that need // to be emitted as attributes - // We are going to ignore the parameters passed and just pass in the Context + // We start by emitting the result type and function name. + // if (IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>()) { + // Note: we currently emit multiple functions to represent an entry point + // on CPU/CUDA, and these all bottleneck through the actual `IRFunc` + // here as a workhorse. + // + // Because the workhorse function is currently emitted as a member of + // `KernelContext`, and doesn't have the right signature to service + // general-purpose calls, it is being emitted with a `_` prefix. + // StringBuilder prefixName; prefixName << "_" << name; emitType(resultType, prefixName); - m_writer->emit("()\n"); } else { emitType(resultType, name); + } - m_writer->emit("("); - auto firstParam = func->getFirstParam(); - for (auto pp = firstParam; pp; pp = pp->getNextParam()) - { - // Ingore TypeType-typed parameters for now. - // In the future we will pass around runtime type info - // for TypeType parameters. - if (as<IRTypeType>(pp->getFullType())) - continue; - - if (pp != firstParam) - m_writer->emit(", "); + // Next we emit the parameter list of the function. + // + m_writer->emit("("); + auto firstParam = func->getFirstParam(); + for (auto pp = firstParam; pp; pp = pp->getNextParam()) + { + // Ingore TypeType-typed parameters for now. + // In the future we will pass around runtime type info + // for TypeType parameters. + if (as<IRTypeType>(pp->getFullType())) + continue; - emitSimpleFuncParamImpl(pp); - } - m_writer->emit(")"); + if (pp != firstParam) + m_writer->emit(", "); - emitSemantics(func); + emitSimpleFuncParamImpl(pp); } + m_writer->emit(")"); + + emitSemantics(func); // TODO: encode declaration vs. definition if (isDefinition(func)) @@ -2431,40 +2441,6 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre switch (inst->op) { - case kIROp_Param: - { - auto varLayout = getVarLayout(inst); - - if (varLayout) - { - if(auto systemValueSemantic = varLayout->findSystemValueSemanticAttr()) - { - String semanticNameSpelling = systemValueSemantic->getName(); - semanticNameSpelling = semanticNameSpelling.toLower(); - - if (semanticNameSpelling == "sv_dispatchthreadid") - { - m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID; - m_writer->emit("dispatchThreadID"); - return; - } - else if (semanticNameSpelling == "sv_groupid") - { - m_semanticUsedFlags |= SemanticUsedFlag::GroupID; - m_writer->emit("groupID"); - return; - } - else if (semanticNameSpelling == "sv_groupthreadid") - { - m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID; - m_writer->emit("calcGroupThreadID()"); - return; - } - } - } - m_writer->emit(getName(inst)); - break; - } case kIROp_Var: case kIROp_GlobalVar: emitVarExpr(inst, outerPrec); @@ -2591,19 +2567,19 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup const auto& axis = axes[i]; builder.Clear(); const char elem[2] = { s_elemNames[axis.axis], 0 }; - builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < start." << elem << " + " << axis.size << "; ++" << elem << ")\n{\n"; + builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n"; m_writer->emit(builder); m_writer->indent(); builder.Clear(); - builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n"; + builder << "threadInput.groupThreadID." << elem << " = " << elem << ";\n"; m_writer->emit(builder); } // just call at inner loop point m_writer->emit("context._"); m_writer->emit(funcName); - m_writer->emit("();\n"); + m_writer->emit("(&threadInput);\n"); // Close all the loops for (Index i = Index(axes.getCount() - 1); i >= 0; --i) @@ -2626,57 +2602,20 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread builder.Clear(); const char elem[2] = { s_elemNames[axis.axis], 0 }; - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n"; - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - builder << "context.groupID." << elem << " += varyingInput->startGroupID." << elem << ";\n"; - } - - builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n"; + builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n"; m_writer->emit(builder); m_writer->indent(); - builder.Clear(); - builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n"; - - if (m_semanticUsedFlags & (SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::GroupID)) - { - if (sizeAlongAxis[axis.axis] > 1) - { - builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<";\n"; - - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n"; - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n"; - } - } - else - { - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n"; - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - builder << "context.groupID." << elem << " = " << elem << ";\n"; - } - } - } - - m_writer->emit(builder); + m_writer->emit("groupVaryingInput.startGroupID."); + m_writer->emit(elem); + m_writer->emit(" = "); + m_writer->emit(elem); + m_writer->emit(";\n"); } // just call at inner loop point - m_writer->emit("context._"); m_writer->emit(funcName); - m_writer->emit("();\n"); + m_writer->emit("_Group(&groupVaryingInput, entryPointParams, globalParams);\n"); // Close all the loops for (Index i = Index(axes.getCount() - 1); i >= 0; --i) @@ -2736,6 +2675,34 @@ void CPPSourceEmitter::_emitForwardDeclarations(const List<EmitAction>& actions) } } +static bool isVaryingResourceKind(LayoutResourceKind kind) +{ + switch(kind) + { + default: + return false; + + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + return true; + } +} + +static bool isVaryingParameter(IRTypeLayout* typeLayout) +{ + for(auto sizeAttr : typeLayout->getSizeAttrs()) + { + if(!isVaryingResourceKind(sizeAttr->getResourceKind())) + return false; + } + return true; +} + +static bool isVaryingParameter(IRVarLayout* varLayout) +{ + return isVaryingParameter(varLayout->getTypeLayout()); +} + void CPPSourceEmitter::_findShaderParams( IRGlobalParam** outEntryPointParam, IRGlobalParam** outGlobalParam) @@ -2752,6 +2719,20 @@ void CPPSourceEmitter::_findShaderParams( if(!param) continue; + if(auto layoutDecor = param->findDecoration<IRLayoutDecoration>()) + { + if(auto varLayout = as<IRVarLayout>(layoutDecor->getLayout())) + { + if(isVaryingParameter(varLayout)) + continue; + auto typeLayout = varLayout->getTypeLayout(); + if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingInput)) + continue; + if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingOutput)) + continue; + } + } + // Currently, the entry-point parameters // are represented as a single parameter // at the global scope, and the same is @@ -2806,28 +2787,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) m_writer->emit("struct KernelContext\n{\n"); m_writer->indent(); - m_writer->emit("uint3 dispatchThreadID;\n"); - - //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - // Note not always set! - m_writer->emit("uint3 groupID;\n"); - } - - //if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - m_writer->emit("uint3 groupDispatchThreadID;\n"); - - m_writer->emit("uint3 calcGroupThreadID() const \n{\n"); - m_writer->indent(); - // groupThreadID = dispatchThreadID - groupDispatchThreadID - m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z };\n"); - m_writer->emit("return v;\n"); - m_writer->dedent(); - m_writer->emit("}\n"); - } - - if (globalParams) { emitGlobalInst(globalParams); @@ -2886,9 +2845,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) if (entryPointDecor && entryPointDecor->getProfile().getStage() == Stage::Compute) { - // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid - // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID. - Int groupThreadSize[kThreadGroupAxisCount]; getComputeThreadGroupSize(func, groupThreadSize); @@ -2902,23 +2858,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointParams, globalParams, threadFuncName, UnownedStringSlice::fromLiteral("ComputeThreadVaryingInput")); - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - m_writer->emit("context.groupDispatchThreadID = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice()); - } - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - m_writer->emit("context.groupID = varyingInput->groupID;\n"); - } - - // Emit dispatchThreadID - m_writer->emit("context.dispatchThreadID = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID")); - m_writer->emit("context._"); m_writer->emit(funcName); - m_writer->emit("();\n"); + m_writer->emit("(varyingInput);\n"); _emitEntryPointDefinitionEnd(func); } @@ -2933,19 +2875,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointParams, globalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); - m_writer->emit("const uint3 start = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); - - if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) - { - m_writer->emit("context.groupDispatchThreadID = start;\n"); - } - - if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) - { - m_writer->emit("context.groupID = varyingInput->startGroupID;\n"); - } - m_writer->emit("context.dispatchThreadID = start;\n"); + m_writer->emit("ComputeThreadVaryingInput threadInput = {};\n"); + m_writer->emit("threadInput.groupID = varyingInput->startGroupID;\n"); _emitEntryPointGroup(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); @@ -2955,10 +2886,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) { _emitEntryPointDefinitionStart(func, entryPointParams, globalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); - m_writer->emit("const uint3 start = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); - m_writer->emit("const uint3 end = "); - _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice()); + m_writer->emit("ComputeVaryingInput vi = *varyingInput;\n"); + m_writer->emit("ComputeVaryingInput groupVaryingInput = {};\n"); _emitEntryPointGroupRange(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); |
