summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-cpp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
-rw-r--r--source/slang/slang-emit-cpp.cpp241
1 files changed, 85 insertions, 156 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 54c2257f2..b59611b38 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1955,38 +1955,48 @@ void CPPSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
// Deal with decorations that need
// to be emitted as attributes
- // We are going to ignore the parameters passed and just pass in the Context
+ // We start by emitting the result type and function name.
+ //
if (IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>())
{
+ // Note: we currently emit multiple functions to represent an entry point
+ // on CPU/CUDA, and these all bottleneck through the actual `IRFunc`
+ // here as a workhorse.
+ //
+ // Because the workhorse function is currently emitted as a member of
+ // `KernelContext`, and doesn't have the right signature to service
+ // general-purpose calls, it is being emitted with a `_` prefix.
+ //
StringBuilder prefixName;
prefixName << "_" << name;
emitType(resultType, prefixName);
- m_writer->emit("()\n");
}
else
{
emitType(resultType, name);
+ }
- m_writer->emit("(");
- auto firstParam = func->getFirstParam();
- for (auto pp = firstParam; pp; pp = pp->getNextParam())
- {
- // Ingore TypeType-typed parameters for now.
- // In the future we will pass around runtime type info
- // for TypeType parameters.
- if (as<IRTypeType>(pp->getFullType()))
- continue;
-
- if (pp != firstParam)
- m_writer->emit(", ");
+ // Next we emit the parameter list of the function.
+ //
+ m_writer->emit("(");
+ auto firstParam = func->getFirstParam();
+ for (auto pp = firstParam; pp; pp = pp->getNextParam())
+ {
+ // Ingore TypeType-typed parameters for now.
+ // In the future we will pass around runtime type info
+ // for TypeType parameters.
+ if (as<IRTypeType>(pp->getFullType()))
+ continue;
- emitSimpleFuncParamImpl(pp);
- }
- m_writer->emit(")");
+ if (pp != firstParam)
+ m_writer->emit(", ");
- emitSemantics(func);
+ emitSimpleFuncParamImpl(pp);
}
+ m_writer->emit(")");
+
+ emitSemantics(func);
// TODO: encode declaration vs. definition
if (isDefinition(func))
@@ -2431,40 +2441,6 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre
switch (inst->op)
{
- case kIROp_Param:
- {
- auto varLayout = getVarLayout(inst);
-
- if (varLayout)
- {
- if(auto systemValueSemantic = varLayout->findSystemValueSemanticAttr())
- {
- String semanticNameSpelling = systemValueSemantic->getName();
- semanticNameSpelling = semanticNameSpelling.toLower();
-
- if (semanticNameSpelling == "sv_dispatchthreadid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
- m_writer->emit("dispatchThreadID");
- return;
- }
- else if (semanticNameSpelling == "sv_groupid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
- m_writer->emit("groupID");
- return;
- }
- else if (semanticNameSpelling == "sv_groupthreadid")
- {
- m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
- m_writer->emit("calcGroupThreadID()");
- return;
- }
- }
- }
- m_writer->emit(getName(inst));
- break;
- }
case kIROp_Var:
case kIROp_GlobalVar:
emitVarExpr(inst, outerPrec);
@@ -2591,19 +2567,19 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
const auto& axis = axes[i];
builder.Clear();
const char elem[2] = { s_elemNames[axis.axis], 0 };
- builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < start." << elem << " + " << axis.size << "; ++" << elem << ")\n{\n";
+ builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
m_writer->indent();
builder.Clear();
- builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
+ builder << "threadInput.groupThreadID." << elem << " = " << elem << ";\n";
m_writer->emit(builder);
}
// just call at inner loop point
m_writer->emit("context._");
m_writer->emit(funcName);
- m_writer->emit("();\n");
+ m_writer->emit("(&threadInput);\n");
// Close all the loops
for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
@@ -2626,57 +2602,20 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
builder.Clear();
const char elem[2] = { s_elemNames[axis.axis], 0 };
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n";
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- builder << "context.groupID." << elem << " += varyingInput->startGroupID." << elem << ";\n";
- }
-
- builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n";
+ builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
m_writer->indent();
- builder.Clear();
- builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
-
- if (m_semanticUsedFlags & (SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::GroupID))
- {
- if (sizeAlongAxis[axis.axis] > 1)
- {
- builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<";\n";
-
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n";
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n";
- }
- }
- else
- {
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n";
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- builder << "context.groupID." << elem << " = " << elem << ";\n";
- }
- }
- }
-
- m_writer->emit(builder);
+ m_writer->emit("groupVaryingInput.startGroupID.");
+ m_writer->emit(elem);
+ m_writer->emit(" = ");
+ m_writer->emit(elem);
+ m_writer->emit(";\n");
}
// just call at inner loop point
- m_writer->emit("context._");
m_writer->emit(funcName);
- m_writer->emit("();\n");
+ m_writer->emit("_Group(&groupVaryingInput, entryPointParams, globalParams);\n");
// Close all the loops
for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
@@ -2736,6 +2675,34 @@ void CPPSourceEmitter::_emitForwardDeclarations(const List<EmitAction>& actions)
}
}
+static bool isVaryingResourceKind(LayoutResourceKind kind)
+{
+ switch(kind)
+ {
+ default:
+ return false;
+
+ case LayoutResourceKind::VaryingInput:
+ case LayoutResourceKind::VaryingOutput:
+ return true;
+ }
+}
+
+static bool isVaryingParameter(IRTypeLayout* typeLayout)
+{
+ for(auto sizeAttr : typeLayout->getSizeAttrs())
+ {
+ if(!isVaryingResourceKind(sizeAttr->getResourceKind()))
+ return false;
+ }
+ return true;
+}
+
+static bool isVaryingParameter(IRVarLayout* varLayout)
+{
+ return isVaryingParameter(varLayout->getTypeLayout());
+}
+
void CPPSourceEmitter::_findShaderParams(
IRGlobalParam** outEntryPointParam,
IRGlobalParam** outGlobalParam)
@@ -2752,6 +2719,20 @@ void CPPSourceEmitter::_findShaderParams(
if(!param)
continue;
+ if(auto layoutDecor = param->findDecoration<IRLayoutDecoration>())
+ {
+ if(auto varLayout = as<IRVarLayout>(layoutDecor->getLayout()))
+ {
+ if(isVaryingParameter(varLayout))
+ continue;
+ auto typeLayout = varLayout->getTypeLayout();
+ if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingOutput))
+ continue;
+ }
+ }
+
// Currently, the entry-point parameters
// are represented as a single parameter
// at the global scope, and the same is
@@ -2806,28 +2787,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
m_writer->emit("struct KernelContext\n{\n");
m_writer->indent();
- m_writer->emit("uint3 dispatchThreadID;\n");
-
- //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- // Note not always set!
- m_writer->emit("uint3 groupID;\n");
- }
-
- //if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- m_writer->emit("uint3 groupDispatchThreadID;\n");
-
- m_writer->emit("uint3 calcGroupThreadID() const \n{\n");
- m_writer->indent();
- // groupThreadID = dispatchThreadID - groupDispatchThreadID
- m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z };\n");
- m_writer->emit("return v;\n");
- m_writer->dedent();
- m_writer->emit("}\n");
- }
-
-
if (globalParams)
{
emitGlobalInst(globalParams);
@@ -2886,9 +2845,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
if (entryPointDecor && entryPointDecor->getProfile().getStage() == Stage::Compute)
{
- // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
- // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
-
Int groupThreadSize[kThreadGroupAxisCount];
getComputeThreadGroupSize(func, groupThreadSize);
@@ -2902,23 +2858,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, threadFuncName, UnownedStringSlice::fromLiteral("ComputeThreadVaryingInput"));
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- m_writer->emit("context.groupDispatchThreadID = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
- }
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- m_writer->emit("context.groupID = varyingInput->groupID;\n");
- }
-
- // Emit dispatchThreadID
- m_writer->emit("context.dispatchThreadID = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
-
m_writer->emit("context._");
m_writer->emit(funcName);
- m_writer->emit("();\n");
+ m_writer->emit("(varyingInput);\n");
_emitEntryPointDefinitionEnd(func);
}
@@ -2933,19 +2875,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
-
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
- {
- m_writer->emit("context.groupDispatchThreadID = start;\n");
- }
-
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
- {
- m_writer->emit("context.groupID = varyingInput->startGroupID;\n");
- }
- m_writer->emit("context.dispatchThreadID = start;\n");
+ m_writer->emit("ComputeThreadVaryingInput threadInput = {};\n");
+ m_writer->emit("threadInput.groupID = varyingInput->startGroupID;\n");
_emitEntryPointGroup(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
@@ -2955,10 +2886,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
{
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
- m_writer->emit("const uint3 end = ");
- _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
+ m_writer->emit("ComputeVaryingInput vi = *varyingInput;\n");
+ m_writer->emit("ComputeVaryingInput groupVaryingInput = {};\n");
_emitEntryPointGroupRange(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);