summary refs log tree commit diff stats
path: root/source/slang/slang-emit-cpp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
-rw-r--r-- source/slang/slang-emit-cpp.cpp 307
1 files changed, 218 insertions, 89 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index df6d1bee8..1a6a46fc5 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -57,20 +57,6 @@ ComputeVaryingInput - Fixed because we are doing compute shader
Uniform - All the uniform data in a big blob, both from uniform entry point parameters, and uniform globals
When called we can have a structure that holds the thread local variables, and these two pointers.
-
-
-We can stick pointers to these in a structure lets call it 'Context'. On C++ we could make all the functions 'methods', and then
-we don't need to pass around the context as a parameter. For C this doesn't work, so it might be worth just biting the bullet and
-just adding the context to the output.
-
-Issues:
-
-* How does this work with layout? The layout if it's going to specify offsets will need to know that they will be allocated into each
-of these structs AND that the order they are placed needs to be consistent.
-
-* When variables access one of these sources, we will now need code that will add the dereferencing. Hopefully this can be done by looking
-at the type of the variable, and then adding the appropriate access via part of emit.
-
*/
namespace Slang {
@@ -1790,6 +1776,9 @@ void CPPSourceEmitter::emitOperationCall(IntrinsicOp op, IRInst* inst, IRUse* op
CPPSourceEmitter::CPPSourceEmitter(const Desc& desc):
Super(desc)
{
+ m_semanticUsedFlags = 0;
+ //m_semanticUsedFlags = SemanticUsedFlag::GroupID | SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::DispatchThreadID;
+
m_sharedIRBuilder.module = nullptr;
m_sharedIRBuilder.session = desc.compileRequest->getSession();
@@ -2405,18 +2394,20 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre
if (semanticNameSpelling == "sv_dispatchthreadid")
{
-
+ m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
m_writer->emit("dispatchThreadID");
return;
}
else if (semanticNameSpelling == "sv_groupid")
{
- m_writer->emit("varyingInput.groupID");
+ m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
+ m_writer->emit("groupID");
return;
}
else if (semanticNameSpelling == "sv_groupthreadid")
{
- m_writer->emit("varyingInput.groupThreadID");
+ m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
+ m_writer->emit("calcGroupThreadID()");
return;
}
}
@@ -2463,25 +2454,25 @@ struct GlobalParamInfo
UInt size;
};
-void CPPSourceEmitter::_emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName)
+void CPPSourceEmitter::_emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName, const UnownedStringSlice& varyingTypeName)
{
auto resultType = func->getResultType();
-
auto entryPointLayout = asEntryPoint(func);
// Emit the actual function
emitEntryPointAttributes(func, entryPointLayout);
emitType(resultType, funcName);
- m_writer->emit("(ComputeVaryingInput* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)\n{\n");
+ m_writer->emit("(");
+ m_writer->emit(varyingTypeName);
+ m_writer->emit("* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)\n{\n");
emitSemantics(func);
m_writer->indent();
// Initialize when constructing so that globals are zeroed
m_writer->emit("Context context = {};\n");
m_writer->emit("context.uniformState = uniformState;\n");
- m_writer->emit("context.varyingInput = *varyingInput;\n");
-
+
if (entryPointGlobalParams)
{
auto varDecl = entryPointGlobalParams;
@@ -2504,32 +2495,43 @@ void CPPSourceEmitter::_emitEntryPointDefinitionEnd(IRFunc* func)
m_writer->emit("}\n");
}
-// We want to order such that the largest range is the inner loop
+namespace { // anonymous
-void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+struct AxisWithSize // An axis index paired with the size along that axis; filled in by _calcAxisOrder.
{
- struct AxisWithSize
- {
- typedef AxisWithSize ThisType;
- bool operator<(const ThisType& rhs) const { return size < rhs.size; }
+ typedef AxisWithSize ThisType;
+ bool operator<(const ThisType& rhs) const { return size < rhs.size || (size == rhs.size && axis < rhs.axis); } // Orders by size first; equal sizes are ordered by axis index so the ordering is total.
- int axis;
- UInt size;
- };
- List<AxisWithSize> axes;
+ int axis; // Index into the 3-element sizeAlongAxis array (also used to index s_elemNames)
+ UInt size; // Value of sizeAlongAxis[axis]
+};
- for (int i = 0; i < 3; ++i)
+} // anonymous
+
+static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<AxisWithSize>& out)
+{ // Replaces the contents of 'out' with one AxisWithSize entry per axis that should get a loop.
+ out.clear();
+ // Entries are added in the order z, y, x. Callers open one loop per entry in list order, so (as long as we don't sort) x ends up as the innermost loop.
+ for (int i = 3 - 1; i >= 0; --i)
{
- if (sizeAlongAxis[i] > 1)
+ if (allowSingle || sizeAlongAxis[i] > 1) // Without allowSingle, axes whose size is 0 or 1 are left out
{
AxisWithSize axisWithSize;
axisWithSize.axis = i;
axisWithSize.size = sizeAlongAxis[i];
- axes.add(axisWithSize);
+ out.add(axisWithSize);
}
}
- axes.sort();
+ // Sorting here (via AxisWithSize::operator<) would make the axis with the largest size the innermost loop.
+ // It is disabled for now so that the loop order is well defined: x innermost, then y, then z.
+ // out.sort();
+}
+
+void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+{
+ List<AxisWithSize> axes;
+ _calcAxisOrder(sizeAlongAxis, false, axes);
// Open all the loops
StringBuilder builder;
@@ -2560,6 +2562,104 @@ void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const S
}
}
+// Emits the nested loops of the "_GroupRange" entry point: one loop per axis over the dispatch
+// thread IDs in [start, end), with a call to the per-thread function `context._<funcName>()`
+// in the innermost loop body.
+//
+// The emitted code refers to `start`, `end`, `varyingInput` and `context`, which must already
+// be in scope where the loops are emitted. For each axis it maintains these Context members:
+//   dispatchThreadID      - the loop variable for that axis
+//   groupDispatchThreadID - dispatch thread ID of the first thread of the current group
+//   groupID               - index of the current group
+// The last two are only maintained when SV_GroupID and/or SV_GroupThreadID has been seen
+// (recorded in m_semanticUsedFlags).
+//
+// sizeAlongAxis holds the thread group size for x, y, z. funcName is the name of the per-thread
+// function (emitted as `context._<funcName>()`).
+void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName)
+{
+    List<AxisWithSize> axes;
+    // allowSingle = true: every axis gets a loop, even when its group size is 1
+    _calcAxisOrder(sizeAlongAxis, true, axes);
+
+    const bool usesGroupID = (m_semanticUsedFlags & SemanticUsedFlag::GroupID) != 0;
+    const bool usesGroupThreadID = (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) != 0;
+    // The groupID increment is derived from groupDispatchThreadID, so groupDispatchThreadID has
+    // to be kept up to date when *either* semantic is in use, not just for SV_GroupThreadID.
+    const bool tracksGroup = usesGroupID || usesGroupThreadID;
+
+    // Open all the loops
+    StringBuilder builder;
+    for (Index i = 0; i < axes.getCount(); ++i)
+    {
+        const auto& axis = axes[i];
+        const char elem[2] = { s_elemNames[axis.axis], 0 };
+
+        // Per-axis state is (re)initialized immediately before the loop for that axis. For the
+        // inner axes this code sits inside the enclosing loops and runs once per outer iteration,
+        // so it must assign (not accumulate) to restart from the beginning of the range.
+        builder.Clear();
+        if (tracksGroup)
+        {
+            builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n";
+        }
+        if (usesGroupID)
+        {
+            builder << "context.groupID." << elem << " = varyingInput->startGroupID." << elem << ";\n";
+        }
+
+        builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n";
+        m_writer->emit(builder);
+        m_writer->indent();
+
+        builder.Clear();
+        builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
+
+        if (tracksGroup)
+        {
+            if (axis.size > 1)
+            {
+                // `next` is the dispatch thread ID at which the following group begins. When the
+                // loop variable reaches it we have stepped into the next group along this axis.
+                builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size << ";\n";
+
+                if (usesGroupID)
+                {
+                    builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n";
+                }
+                builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n";
+            }
+            else
+            {
+                // With a single thread per group along this axis the dispatch thread ID, the
+                // group's first dispatch thread ID and the group ID all coincide.
+                builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n";
+                if (usesGroupID)
+                {
+                    builder << "context.groupID." << elem << " = " << elem << ";\n";
+                }
+            }
+        }
+
+        m_writer->emit(builder);
+    }
+
+    // just call at inner loop point
+    m_writer->emit("context._");
+    m_writer->emit(funcName);
+    m_writer->emit("();\n");
+
+    // Close all the loops (one closing brace per loop opened above)
+    for (Index i = 0; i < axes.getCount(); ++i)
+    {
+        m_writer->dedent();
+        m_writer->emit("}\n");
+    }
+}
+// Emits a brace-enclosed three-component initializer terminated by ";", one component per line:
+//     <mulName>.<c> * <sizeAlongAxis[c]> [+ <addName>.<c>]
+// for c = x, y, z (component names come from s_elemNames). The "+ <addName>.<c>" term is left
+// out when addName is empty. The caller is expected to have emitted the left hand side
+// (e.g. "const uint3 start = ") beforehand.
+void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName)
+{
+    const int axisCount = 3;
+    const bool hasAddTerm = addName.size() > 0;
+
+    m_writer->emit("{\n");
+    m_writer->indent();
+
+    StringBuilder line;
+    int axis = 0;
+    while (axis < axisCount)
+    {
+        const char component[2] = { s_elemNames[axis], 0 };
+        const bool isLast = (axis == axisCount - 1);
+
+        line.Clear();
+        line << mulName << "." << component << " * " << sizeAlongAxis[axis];
+        if (hasAddTerm)
+        {
+            line << " + " << addName << "." << component;
+        }
+        // Every component except the last is followed by a comma
+        line << (isLast ? "\n" : ",\n");
+
+        // Emit line by line so the writer can apply the current indentation to each one
+        m_writer->emit(line);
+        ++axis;
+    }
+
+    m_writer->dedent();
+    m_writer->emit("};\n");
+}
+
+
void CPPSourceEmitter::emitModuleImpl(IRModule* module)
{
List<EmitAction> actions;
@@ -2656,9 +2756,29 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
m_writer->indent();
m_writer->emit("UniformState* uniformState;\n");
- m_writer->emit("ComputeVaryingInput varyingInput;\n");
+
+
m_writer->emit("uint3 dispatchThreadID;\n");
+ //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ // Note not always set!
+ m_writer->emit("uint3 groupID;\n");
+ }
+
+ //if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
+ {
+ m_writer->emit("uint3 groupDispatchThreadID;\n");
+
+ m_writer->emit("uint3 calcGroupThreadID() const \n{\n");
+ m_writer->indent();
+ // groupThreadID = dispatchThreadID - groupDispatchThreadID
+ m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z }; ");
+ m_writer->emit("return v;\n");
+ m_writer->dedent();
+ m_writer->emit("}\n");
+ }
+
if (entryPointGlobalParams)
{
emitGlobalInst(entryPointGlobalParams);
@@ -2695,7 +2815,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
IRFunc* func = as<IRFunc>(action.inst);
auto entryPointLayout = asEntryPoint(func);
- if (entryPointLayout)
+ if (entryPointLayout && entryPointLayout->profile.GetStage() == Stage::Compute)
{
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
// SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
@@ -2703,40 +2823,30 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
static const UInt kAxisCount = 3;
UInt sizeAlongAxis[kAxisCount];
+ // TODO: this is kind of gross because we are using a public
+ // reflection API function, rather than some kind of internal
+ // utility it forwards to...
+ spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
+
String funcName = getFuncName(func);
{
- _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName);
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- // Emit dispatchThreadID
- if (entryPointLayout->profile.GetStage() == Stage::Compute)
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
-
- m_writer->emit("context.dispatchThreadID = {\n");
- m_writer->indent();
-
- StringBuilder builder;
- for (int i = 0; i < kAxisCount; ++i)
- {
- builder.Clear();
- const char elem[2] = {s_elemNames[i], 0};
- builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i] << " + varyingInput->groupThreadID." << elem;
- if (i < kAxisCount - 1)
- {
- builder << ",";
- }
- builder << "\n";
- m_writer->emit(builder);
- }
-
- m_writer->dedent();
- m_writer->emit("};\n");
+ m_writer->emit("context.groupDispatchThreadID = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
+ }
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ m_writer->emit("context.groupID = varyingInput->groupID;\n");
}
+ // Emit dispatchThreadID
+ m_writer->emit("context.dispatchThreadID = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
+
m_writer->emit("context._");
m_writer->emit(funcName);
m_writer->emit("();\n");
@@ -2752,36 +2862,55 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
String groupFuncName = builder;
- _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName);
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- // Emit dispatchThreadID
- if (entryPointLayout->profile.GetStage() == Stage::Compute)
+ m_writer->emit("const uint3 start = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
+
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
+ m_writer->emit("context.groupDispatchThreadID = start;\n");
+ }
- {
- m_writer->emit("const uint3 start = {\n");
- m_writer->indent();
- for (int i = 0; i < kAxisCount; ++i)
- {
- builder.Clear();
- const char elem[2] = { s_elemNames[i], 0 };
- builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i];
- if (i < kAxisCount - 1)
- {
- builder << ",";
- }
- builder << "\n";
- m_writer->emit(builder);
- }
- m_writer->dedent();
- m_writer->emit("};\n");
- }
- m_writer->emit("context.dispatchThreadID = start;\n");
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ m_writer->emit("context.groupID = varyingInput->groupID;\n");
+ }
+ m_writer->emit("context.dispatchThreadID = start;\n");
+
+ _emitEntryPointGroup(sizeAlongAxis, funcName);
+ _emitEntryPointDefinitionEnd(func);
+ }
+
+ // Emit the group version which runs for all elements in a thread group
+ {
+ StringBuilder builder;
+ builder << getFuncName(func);
+ builder << "_GroupRange";
- _emitEntryPointGroup(sizeAlongAxis, funcName);
+ String groupRangeFuncName = builder;
+
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupRangeFuncName, UnownedStringSlice::fromLiteral("GroupComputeVaryingInput"));
+
+ m_writer->emit("const uint3 start = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
+ m_writer->emit("const uint3 end = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
+
+#if 0
+ // Not needed as will be emitted as part of the loop
+ m_writer->emit("context.dispatchThreadID = start;\n");
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
+ {
+ m_writer->emit("context.groupDispatchThreadID = start;");
+ }
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ m_writer->emit("context.groupID = varyingInput->startGroupID;\n");
}
+#endif
+ _emitEntryPointGroupRange(sizeAlongAxis, funcName);
_emitEntryPointDefinitionEnd(func);
}
}