summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-cpp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
-rw-r--r--source/slang/slang-emit-cpp.cpp221
1 files changed, 167 insertions, 54 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index a5173549a..df6d1bee8 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -2463,6 +2463,103 @@ struct GlobalParamInfo
UInt size;
};
+void CPPSourceEmitter::_emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName)
+{
+ auto resultType = func->getResultType();
+
+ auto entryPointLayout = asEntryPoint(func);
+
+ // Emit the actual function
+ emitEntryPointAttributes(func, entryPointLayout);
+ emitType(resultType, funcName);
+
+ m_writer->emit("(ComputeVaryingInput* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)\n{\n");
+ emitSemantics(func);
+
+ m_writer->indent();
+ // Initialize when constructing so that globals are zeroed
+ m_writer->emit("Context context = {};\n");
+ m_writer->emit("context.uniformState = uniformState;\n");
+ m_writer->emit("context.varyingInput = *varyingInput;\n");
+
+ if (entryPointGlobalParams)
+ {
+ auto varDecl = entryPointGlobalParams;
+ auto rawType = varDecl->getDataType();
+
+ auto varType = rawType;
+
+ m_writer->emit("context.");
+ m_writer->emit(getName(varDecl));
+ m_writer->emit(" = (");
+ emitType(varType);
+ m_writer->emit("*)params; \n");
+ }
+}
+
+void CPPSourceEmitter::_emitEntryPointDefinitionEnd(IRFunc* func)
+{
+ SLANG_UNUSED(func);
+ m_writer->dedent();
+ m_writer->emit("}\n");
+}
+
+// We want to order such that the largest range is the inner loop
+
+void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+{
+ struct AxisWithSize
+ {
+ typedef AxisWithSize ThisType;
+ bool operator<(const ThisType& rhs) const { return size < rhs.size; }
+
+ int axis;
+ UInt size;
+ };
+ List<AxisWithSize> axes;
+
+ for (int i = 0; i < 3; ++i)
+ {
+ if (sizeAlongAxis[i] > 1)
+ {
+ AxisWithSize axisWithSize;
+ axisWithSize.axis = i;
+ axisWithSize.size = sizeAlongAxis[i];
+ axes.add(axisWithSize);
+ }
+ }
+
+ axes.sort();
+
+ // Open all the loops
+ StringBuilder builder;
+ for (Index i = 0; i < axes.getCount(); ++i)
+ {
+ const auto& axis = axes[i];
+ builder.Clear();
+ const char elem[2] = { s_elemNames[axis.axis], 0 };
+ builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < start." << elem << " + " << axis.size << "; ++" << elem << ")\n{\n";
+ m_writer->emit(builder);
+ m_writer->indent();
+
+ builder.Clear();
+ builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
+ m_writer->emit(builder);
+ }
+
+ // just call at inner loop point
+ m_writer->emit("context._");
+ m_writer->emit(funcName);
+ m_writer->emit("();\n");
+
+ // Close all the loops
+ for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
+ {
+ m_writer->dedent();
+ m_writer->emit("}\n");
+ }
+}
+
void CPPSourceEmitter::emitModuleImpl(IRModule* module)
{
List<EmitAction> actions;
@@ -2600,77 +2697,93 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
auto entryPointLayout = asEntryPoint(func);
if (entryPointLayout)
{
- auto resultType = func->getResultType();
- auto name = getFuncName(func);
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
+ // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
- // Emit the actual function
- emitEntryPointAttributes(func, entryPointLayout);
- emitType(resultType, name);
+ static const UInt kAxisCount = 3;
+ UInt sizeAlongAxis[kAxisCount];
- m_writer->emit("(ComputeVaryingInput* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)\n{\n");
- emitSemantics(func);
+ String funcName = getFuncName(func);
- m_writer->indent();
- // Initialize when constructing so that globals are zeroed
- m_writer->emit("Context context = {};\n");
- m_writer->emit("context.uniformState = uniformState;\n");
- m_writer->emit("context.varyingInput = *varyingInput;\n");
+ {
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName);
- if (entryPointGlobalParams)
- {
- auto varDecl = entryPointGlobalParams;
- auto rawType = varDecl->getDataType();
+ // Emit dispatchThreadID
+ if (entryPointLayout->profile.GetStage() == Stage::Compute)
+ {
+ // TODO: this is kind of gross because we are using a public
+ // reflection API function, rather than some kind of internal
+ // utility it forwards to...
+ spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
- auto varType = rawType;
+ m_writer->emit("context.dispatchThreadID = {\n");
+ m_writer->indent();
- m_writer->emit("context.");
- m_writer->emit(getName(varDecl));
- m_writer->emit(" = (");
- emitType(varType);
- m_writer->emit("*)params; \n");
- }
-
- // Emit dispatchThreadID
- if (entryPointLayout->profile.GetStage() == Stage::Compute)
- {
- // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
- // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
+ StringBuilder builder;
+ for (int i = 0; i < kAxisCount; ++i)
+ {
+ builder.Clear();
+ const char elem[2] = {s_elemNames[i], 0};
+ builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i] << " + varyingInput->groupThreadID." << elem;
+ if (i < kAxisCount - 1)
+ {
+ builder << ",";
+ }
+ builder << "\n";
+ m_writer->emit(builder);
+ }
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
+ m_writer->dedent();
+ m_writer->emit("};\n");
+ }
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
+ m_writer->emit("context._");
+ m_writer->emit(funcName);
+ m_writer->emit("();\n");
- m_writer->emit("context.dispatchThreadID = {\n");
- m_writer->indent();
+ _emitEntryPointDefinitionEnd(func);
+ }
+ // Emit the group version which runs for all elements in a thread group
+ {
StringBuilder builder;
-
- for (int i = 0; i < kAxisCount; ++i)
+ builder << getFuncName(func);
+ builder << "_Group";
+
+ String groupFuncName = builder;
+
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName);
+
+ // Emit dispatchThreadID
+ if (entryPointLayout->profile.GetStage() == Stage::Compute)
{
- builder.Clear();
- const char elem[2] = {s_elemNames[i], 0};
- builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i] << " + varyingInput->groupThreadID." << elem;
- if (i < kAxisCount - 1)
+ spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
+
{
- builder << ",";
+ m_writer->emit("const uint3 start = {\n");
+ m_writer->indent();
+ for (int i = 0; i < kAxisCount; ++i)
+ {
+ builder.Clear();
+ const char elem[2] = { s_elemNames[i], 0 };
+ builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i];
+ if (i < kAxisCount - 1)
+ {
+ builder << ",";
+ }
+ builder << "\n";
+ m_writer->emit(builder);
+ }
+ m_writer->dedent();
+ m_writer->emit("};\n");
}
- builder << "\n";
- m_writer->emit(builder);
+ m_writer->emit("context.dispatchThreadID = start;\n");
+
+ _emitEntryPointGroup(sizeAlongAxis, funcName);
}
- m_writer->dedent();
- m_writer->emit("};\n");
+ _emitEntryPointDefinitionEnd(func);
}
-
- m_writer->emit("context._");
- m_writer->emit(name);
- m_writer->emit("();\n");
- m_writer->dedent();
- m_writer->emit("}\n");
}
}
}