summary refs log tree commit diff stats
path: root/source/slang/slang-emit-cpp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
-rw-r--r-- source/slang/slang-emit-cpp.cpp 59
1 files changed, 23 insertions, 36 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 1628c6770..1f38512b4 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1916,23 +1916,15 @@ void CPPSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointLa
{
case Stage::Compute:
{
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize(
- (SlangReflectionEntryPoint*)entryPointLayout,
- kAxisCount,
- &sizeAlongAxis[0]);
-
+ Int numThreads[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(irFunc, numThreads);
+
// TODO(JS): We might want to store this information such that it can be used to execute
m_writer->emit("// [numthreads(");
- for (int ii = 0; ii < 3; ++ii)
+ for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
{
if (ii != 0) m_writer->emit(", ");
- m_writer->emit(sizeAlongAxis[ii]);
+ m_writer->emit(numThreads[ii]);
}
m_writer->emit(")]\n");
break;
@@ -2504,16 +2496,16 @@ struct AxisWithSize
bool operator<(const ThisType& rhs) const { return size < rhs.size || (size == rhs.size && axis < rhs.axis); }
int axis;
- UInt size;
+ Int size;
};
} // anonymous
-static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<AxisWithSize>& out)
+static void _calcAxisOrder(const Int sizeAlongAxis[CLikeSourceEmitter::kThreadGroupAxisCount], bool allowSingle, List<AxisWithSize>& out)
{
out.clear();
// Add in order z,y,x, so by default (if we don't sort), x will be the inner loop
- for (int i = 3 - 1; i >= 0; --i)
+ for (int i = CLikeSourceEmitter::kThreadGroupAxisCount - 1; i >= 0; --i)
{
if (allowSingle || sizeAlongAxis[i] > 1)
{
@@ -2529,7 +2521,7 @@ static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<A
// axes.sort();
}
-void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName)
{
List<AxisWithSize> axes;
_calcAxisOrder(sizeAlongAxis, false, axes);
@@ -2563,7 +2555,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const S
}
}
-void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName)
+void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName)
{
List<AxisWithSize> axes;
_calcAxisOrder(sizeAlongAxis, true, axes);
@@ -2635,13 +2627,13 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], co
m_writer->emit("}\n");
}
}
-void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName)
+void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName)
{
StringBuilder builder;
m_writer->emit("{\n");
m_writer->indent();
- for (int i = 0; i < 3; ++i)
+ for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
builder.Clear();
const char elem[2] = { s_elemNames[i], 0 };
@@ -2650,7 +2642,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const Un
{
builder << " + " << addName << "." << elem;
}
- if (i < 3 - 1)
+ if (i < kThreadGroupAxisCount - 1)
{
builder << ",";
}
@@ -2821,14 +2813,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
// SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
-
+ Int groupThreadSize[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(func, groupThreadSize);
+
String funcName = getFuncName(func);
{
@@ -2842,7 +2829,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
m_writer->emit("context.groupDispatchThreadID = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
}
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
{
@@ -2851,7 +2838,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
// Emit dispatchThreadID
m_writer->emit("context.dispatchThreadID = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
m_writer->emit("context._");
m_writer->emit(funcName);
@@ -2871,7 +2858,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
@@ -2884,7 +2871,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
}
m_writer->emit("context.dispatchThreadID = start;\n");
- _emitEntryPointGroup(sizeAlongAxis, funcName);
+ _emitEntryPointGroup(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
}
@@ -2893,11 +2880,11 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
m_writer->emit("const uint3 end = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
- _emitEntryPointGroupRange(sizeAlongAxis, funcName);
+ _emitEntryPointGroupRange(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
}
}