diff options
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 59 |
1 files changed, 23 insertions, 36 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 1628c6770..1f38512b4 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1916,23 +1916,15 @@ void CPPSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointLa { case Stage::Compute: { - static const UInt kAxisCount = 3; - UInt sizeAlongAxis[kAxisCount]; - - // TODO: this is kind of gross because we are using a public - // reflection API function, rather than some kind of internal - // utility it forwards to... - spReflectionEntryPoint_getComputeThreadGroupSize( - (SlangReflectionEntryPoint*)entryPointLayout, - kAxisCount, - &sizeAlongAxis[0]); - + Int numThreads[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, numThreads); + // TODO(JS): We might want to store this information such that it can be used to execute m_writer->emit("// [numthreads("); - for (int ii = 0; ii < 3; ++ii) + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) { if (ii != 0) m_writer->emit(", "); - m_writer->emit(sizeAlongAxis[ii]); + m_writer->emit(numThreads[ii]); } m_writer->emit(")]\n"); break; @@ -2504,16 +2496,16 @@ struct AxisWithSize bool operator<(const ThisType& rhs) const { return size < rhs.size || (size == rhs.size && axis < rhs.axis); } int axis; - UInt size; + Int size; }; } // anonymous -static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<AxisWithSize>& out) +static void _calcAxisOrder(const Int sizeAlongAxis[CLikeSourceEmitter::kThreadGroupAxisCount], bool allowSingle, List<AxisWithSize>& out) { out.clear(); // Add in order z,y,x, so by default (if we don't sort), x will be the inner loop - for (int i = 3 - 1; i >= 0; --i) + for (int i = CLikeSourceEmitter::kThreadGroupAxisCount - 1; i >= 0; --i) { if (allowSingle || sizeAlongAxis[i] > 1) { @@ -2529,7 +2521,7 @@ static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<A // axes.sort(); } -void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName) +void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName) { List<AxisWithSize> axes; _calcAxisOrder(sizeAlongAxis, false, axes); @@ -2563,7 +2555,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const S } } -void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName) +void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName) { List<AxisWithSize> axes; _calcAxisOrder(sizeAlongAxis, true, axes); @@ -2635,13 +2627,13 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], co m_writer->emit("}\n"); } } -void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName) +void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName) { StringBuilder builder; m_writer->emit("{\n"); m_writer->indent(); - for (int i = 0; i < 3; ++i) + for (int i = 0; i < kThreadGroupAxisCount; ++i) { builder.Clear(); const char elem[2] = { s_elemNames[i], 0 }; @@ -2650,7 +2642,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const Un { builder << " + " << addName << "." << elem; } - if (i < 3 - 1) + if (i < kThreadGroupAxisCount - 1) { builder << ","; } @@ -2821,14 +2813,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID. - static const UInt kAxisCount = 3; - UInt sizeAlongAxis[kAxisCount]; - - // TODO: this is kind of gross because we are using a public - // reflection API function, rather than some kind of internal - // utility it forwards to... - spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]); - + Int groupThreadSize[kThreadGroupAxisCount]; + getComputeThreadGroupSize(func, groupThreadSize); + String funcName = getFuncName(func); { @@ -2842,7 +2829,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) { m_writer->emit("context.groupDispatchThreadID = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice()); } if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) { @@ -2851,7 +2838,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) // Emit dispatchThreadID m_writer->emit("context.dispatchThreadID = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID")); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID")); m_writer->emit("context._"); m_writer->emit(funcName); @@ -2871,7 +2858,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); m_writer->emit("const uint3 start = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) { @@ -2884,7 +2871,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) } m_writer->emit("context.dispatchThreadID = start;\n"); - _emitEntryPointGroup(sizeAlongAxis, funcName); + _emitEntryPointGroup(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); } @@ -2893,11 +2880,11 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); m_writer->emit("const uint3 start = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); m_writer->emit("const uint3 end = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice()); - _emitEntryPointGroupRange(sizeAlongAxis, funcName); + _emitEntryPointGroupRange(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); } } |
