summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-cpp-types.h6
-rw-r--r--source/slang/slang-compiler.cpp14
-rw-r--r--source/slang/slang-emit-cpp.cpp307
-rw-r--r--source/slang/slang-emit-cpp.h18
-rw-r--r--tests/compute/semantic.slang13
-rw-r--r--tests/compute/semantic.slang.expected.txt12
-rw-r--r--tools/render-test/cpu-compute-util.cpp75
-rw-r--r--tools/render-test/cpu-compute-util.h2
-rw-r--r--tools/render-test/options.cpp28
-rw-r--r--tools/render-test/options.h2
-rw-r--r--tools/render-test/render-test-main.cpp4
11 files changed, 367 insertions, 114 deletions
diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h
index d5d88d7b2..c79465032 100644
--- a/prelude/slang-cpp-types.h
+++ b/prelude/slang-cpp-types.h
@@ -232,6 +232,12 @@ struct ComputeVaryingInput
uint3 groupThreadID;
};
+/* Varying input for the generated '<entryPoint>_GroupRange' function. It describes a range of thread group IDs;
+the generated function scales both bounds by the thread group size to obtain the range of threads it runs. */
+struct GroupComputeVaryingInput
+{
+ uint3 startGroupID; ///< First groupID of the range (inclusive)
+ uint3 endGroupID; ///< Non inclusive end groupID
+};
+
/* Type that defines the uniform entry point params. The actual content of this type is dependent on the entry point parameters, and can be
found via reflection or defined such that it matches the shader appropriately. */
struct UniformEntryPointParams;
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index dbb900ab6..ce22186ee 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -789,11 +789,6 @@ namespace Slang
builder << compilerName << ": ";
}
- if (diagnostic.size() > 0)
- {
- builder.Append(diagnostic);
- }
-
if (SLANG_FAILED(res) && res != SLANG_FAIL)
{
{
@@ -805,6 +800,15 @@ namespace Slang
PlatformUtil::appendResult(res, builder);
}
+ if (diagnostic.size() > 0)
+ {
+ builder.Append(diagnostic);
+ if (!diagnostic.endsWith("\n"))
+ {
+ builder.Append("\n");
+ }
+ }
+
sink->diagnoseRaw(severity, builder.getUnownedSlice());
}
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index df6d1bee8..1a6a46fc5 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -57,20 +57,6 @@ ComputeVaryingInput - Fixed because we are doing compute shader
Uniform - All the uniform data in a big blob, both from uniform entry point parameters, and uniform globals
When called we can have a structure that holds the thread local variables, and these two pointers.
-
-
-We can stick pointers to these in a structure lets call it 'Context'. On C++ we could make all the functions 'methods', and then
-we don't need to pass around the context as a parameter. For C this doesn't work, so it might be worth just biting the bullet and
-just adding the context to the output.
-
-Issues:
-
-* How does this work with layout? The layout if it's going to specify offsets will need to know that they will be allocated into each
-of these structs AND that the order they are placed needs to be consistent.
-
-* When variables access one of these sources, we will now need code that will add the dereferencing. Hopefully this can be done by looking
-at the type of the variable, and then adding the appropriate access via part of emit.
-
*/
namespace Slang {
@@ -1790,6 +1776,9 @@ void CPPSourceEmitter::emitOperationCall(IntrinsicOp op, IRInst* inst, IRUse* op
CPPSourceEmitter::CPPSourceEmitter(const Desc& desc):
Super(desc)
{
+ m_semanticUsedFlags = 0;
+ //m_semanticUsedFlags = SemanticUsedFlag::GroupID | SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::DispatchThreadID;
+
m_sharedIRBuilder.module = nullptr;
m_sharedIRBuilder.session = desc.compileRequest->getSession();
@@ -2405,18 +2394,20 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre
if (semanticNameSpelling == "sv_dispatchthreadid")
{
-
+ m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
m_writer->emit("dispatchThreadID");
return;
}
else if (semanticNameSpelling == "sv_groupid")
{
- m_writer->emit("varyingInput.groupID");
+ m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
+ m_writer->emit("groupID");
return;
}
else if (semanticNameSpelling == "sv_groupthreadid")
{
- m_writer->emit("varyingInput.groupThreadID");
+ m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
+ m_writer->emit("calcGroupThreadID()");
return;
}
}
@@ -2463,25 +2454,25 @@ struct GlobalParamInfo
UInt size;
};
-void CPPSourceEmitter::_emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName)
+void CPPSourceEmitter::_emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName, const UnownedStringSlice& varyingTypeName)
{
auto resultType = func->getResultType();
-
auto entryPointLayout = asEntryPoint(func);
// Emit the actual function
emitEntryPointAttributes(func, entryPointLayout);
emitType(resultType, funcName);
- m_writer->emit("(ComputeVaryingInput* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)\n{\n");
+ m_writer->emit("(");
+ m_writer->emit(varyingTypeName);
+ m_writer->emit("* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)\n{\n");
emitSemantics(func);
m_writer->indent();
// Initialize when constructing so that globals are zeroed
m_writer->emit("Context context = {};\n");
m_writer->emit("context.uniformState = uniformState;\n");
- m_writer->emit("context.varyingInput = *varyingInput;\n");
-
+
if (entryPointGlobalParams)
{
auto varDecl = entryPointGlobalParams;
@@ -2504,32 +2495,43 @@ void CPPSourceEmitter::_emitEntryPointDefinitionEnd(IRFunc* func)
m_writer->emit("}\n");
}
-// We want to order such that the largest range is the inner loop
+namespace { // anonymous
-void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+struct AxisWithSize
{
- struct AxisWithSize
- {
- typedef AxisWithSize ThisType;
- bool operator<(const ThisType& rhs) const { return size < rhs.size; }
+ typedef AxisWithSize ThisType;
+ bool operator<(const ThisType& rhs) const { return size < rhs.size || (size == rhs.size && axis < rhs.axis); }
- int axis;
- UInt size;
- };
- List<AxisWithSize> axes;
+ int axis;
+ UInt size;
+};
- for (int i = 0; i < 3; ++i)
+} // anonymous
+
+static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<AxisWithSize>& out)
+{
+ out.clear();
+ // Add in order z,y,x, so by default (if we don't sort), x will be the inner loop
+ for (int i = 3 - 1; i >= 0; --i)
{
- if (sizeAlongAxis[i] > 1)
+ if (allowSingle || sizeAlongAxis[i] > 1)
{
AxisWithSize axisWithSize;
axisWithSize.axis = i;
axisWithSize.size = sizeAlongAxis[i];
- axes.add(axisWithSize);
+ out.add(axisWithSize);
}
}
- axes.sort();
+ // Sorting would instead make the axis with the largest size the innermost loop.
+ // Disabled for now so that the nesting is well defined: z outermost, then y, with x innermost.
+ // out.sort();
+}
+
+void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+{
+ List<AxisWithSize> axes;
+ _calcAxisOrder(sizeAlongAxis, false, axes);
// Open all the loops
StringBuilder builder;
@@ -2560,6 +2562,104 @@ void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const S
}
}
+/// Emits the loop nest of the '<entryPoint>_GroupRange' function, which calls 'context._<funcName>()' once
+/// for every thread in the range [start, end) on each axis.
+///
+/// The emitted code refers to 'start', 'end', 'varyingInput' and 'context', so the caller must already have
+/// emitted those (start/end are the start/end group IDs multiplied by the thread group size).
+///
+/// @param sizeAlongAxis Thread group size (numthreads) for x, y, z
+/// @param funcName Name of the entry point function, invoked in the innermost loop
+void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName)
+{
+    List<AxisWithSize> axes;
+    // allowSingle = true: a loop is needed on every axis, even if the thread group size on that axis is 1,
+    // because the range can still cover more than one group along it.
+    _calcAxisOrder(sizeAlongAxis, true, axes);
+
+    const bool usesGroupThreadID = (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) != 0;
+    const bool usesGroupID = (m_semanticUsedFlags & SemanticUsedFlag::GroupID) != 0;
+    // groupID is advanced by comparing the loop index against groupDispatchThreadID + group size, so
+    // groupDispatchThreadID has to be tracked when *either* semantic is used, not only for GroupThreadID.
+    const bool tracksGroupStart = usesGroupThreadID || usesGroupID;
+
+    // Open all the loops
+    StringBuilder builder;
+    for (Index i = 0; i < axes.getCount(); ++i)
+    {
+        const auto& axis = axes[i];
+        builder.Clear();
+        const char elem[2] = { s_elemNames[axis.axis], 0 };
+
+        // Per-axis state is (re)initialized immediately before the axis' loop. For inner axes this code runs
+        // once per iteration of the enclosing loops, so it must assign rather than accumulate.
+        if (tracksGroupStart)
+        {
+            builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n";
+        }
+        if (usesGroupID)
+        {
+            builder << "context.groupID." << elem << " = varyingInput->startGroupID." << elem << ";\n";
+        }
+
+        builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n";
+        m_writer->emit(builder);
+        m_writer->indent();
+
+        builder.Clear();
+        builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
+
+        if (tracksGroupStart)
+        {
+            if (sizeAlongAxis[axis.axis] > 1)
+            {
+                // 'next' is the dispatch thread index at which the next group starts on this axis
+                builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<";\n";
+
+                if (usesGroupID)
+                {
+                    builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n";
+                }
+                // Must come after the groupID update, which compares against the old group start
+                builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n";
+            }
+            else
+            {
+                // With a group size of 1 every thread starts a new group, so both values equal the loop index
+                if (usesGroupThreadID)
+                {
+                    builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n";
+                }
+                if (usesGroupID)
+                {
+                    builder << "context.groupID." << elem << " = " << elem << ";\n";
+                }
+            }
+        }
+
+        m_writer->emit(builder);
+    }
+
+    // just call at inner loop point
+    m_writer->emit("context._");
+    m_writer->emit(funcName);
+    m_writer->emit("();\n");
+
+    // Close all the loops
+    for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
+    {
+        m_writer->dedent();
+        m_writer->emit("}\n");
+    }
+}
+/// Emits a three element brace initializer terminated by "};", one line per axis, of the form
+/// '<mulName>.<axis> * <sizeAlongAxis[axis]>' optionally followed by ' + <addName>.<axis>'.
+///
+/// @param sizeAlongAxis Constant multiplier for x, y, z
+/// @param mulName Expression whose x/y/z members are multiplied
+/// @param addName Expression whose x/y/z members are added; nothing is added if the slice is empty
+void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName)
+{
+    const int axisCount = 3;
+    const bool hasAddend = addName.size() > 0;
+
+    m_writer->emit("{\n");
+    m_writer->indent();
+
+    StringBuilder line;
+    for (int axisIndex = 0; axisIndex < axisCount; ++axisIndex)
+    {
+        const char elem[2] = { s_elemNames[axisIndex], 0 };
+        const bool isLast = (axisIndex == axisCount - 1);
+
+        line.Clear();
+        line << mulName << "." << elem << " * " << sizeAlongAxis[axisIndex];
+        if (hasAddend)
+        {
+            line << " + " << addName << "." << elem;
+        }
+        // All but the last element are followed by a separator
+        if (!isLast)
+        {
+            line << ",";
+        }
+        line << "\n";
+
+        m_writer->emit(line);
+    }
+
+    m_writer->dedent();
+    m_writer->emit("};\n");
+}
+
void CPPSourceEmitter::emitModuleImpl(IRModule* module)
{
List<EmitAction> actions;
@@ -2656,9 +2756,29 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
m_writer->indent();
m_writer->emit("UniformState* uniformState;\n");
- m_writer->emit("ComputeVaryingInput varyingInput;\n");
+
+
m_writer->emit("uint3 dispatchThreadID;\n");
+ //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ // Note not always set!
+ m_writer->emit("uint3 groupID;\n");
+ }
+
+ //if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
+ {
+ m_writer->emit("uint3 groupDispatchThreadID;\n");
+
+ m_writer->emit("uint3 calcGroupThreadID() const \n{\n");
+ m_writer->indent();
+ // groupThreadID = dispatchThreadID - groupDispatchThreadID
+ m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z }; ");
+ m_writer->emit("return v;\n");
+ m_writer->dedent();
+ m_writer->emit("}\n");
+ }
+
if (entryPointGlobalParams)
{
emitGlobalInst(entryPointGlobalParams);
@@ -2695,7 +2815,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
IRFunc* func = as<IRFunc>(action.inst);
auto entryPointLayout = asEntryPoint(func);
- if (entryPointLayout)
+ if (entryPointLayout && entryPointLayout->profile.GetStage() == Stage::Compute)
{
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
// SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
@@ -2703,40 +2823,30 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
static const UInt kAxisCount = 3;
UInt sizeAlongAxis[kAxisCount];
+ // TODO: this is kind of gross because we are using a public
+ // reflection API function, rather than some kind of internal
+ // utility it forwards to...
+ spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
+
String funcName = getFuncName(func);
{
- _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName);
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- // Emit dispatchThreadID
- if (entryPointLayout->profile.GetStage() == Stage::Compute)
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
-
- m_writer->emit("context.dispatchThreadID = {\n");
- m_writer->indent();
-
- StringBuilder builder;
- for (int i = 0; i < kAxisCount; ++i)
- {
- builder.Clear();
- const char elem[2] = {s_elemNames[i], 0};
- builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i] << " + varyingInput->groupThreadID." << elem;
- if (i < kAxisCount - 1)
- {
- builder << ",";
- }
- builder << "\n";
- m_writer->emit(builder);
- }
-
- m_writer->dedent();
- m_writer->emit("};\n");
+ m_writer->emit("context.groupDispatchThreadID = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
+ }
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ m_writer->emit("context.groupID = varyingInput->groupID;\n");
}
+ // Emit dispatchThreadID
+ m_writer->emit("context.dispatchThreadID = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
+
m_writer->emit("context._");
m_writer->emit(funcName);
m_writer->emit("();\n");
@@ -2752,36 +2862,55 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
String groupFuncName = builder;
- _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName);
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
- // Emit dispatchThreadID
- if (entryPointLayout->profile.GetStage() == Stage::Compute)
+ m_writer->emit("const uint3 start = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
+
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
+ m_writer->emit("context.groupDispatchThreadID = start;\n");
+ }
- {
- m_writer->emit("const uint3 start = {\n");
- m_writer->indent();
- for (int i = 0; i < kAxisCount; ++i)
- {
- builder.Clear();
- const char elem[2] = { s_elemNames[i], 0 };
- builder << "varyingInput->groupID." << elem << " * " << sizeAlongAxis[i];
- if (i < kAxisCount - 1)
- {
- builder << ",";
- }
- builder << "\n";
- m_writer->emit(builder);
- }
- m_writer->dedent();
- m_writer->emit("};\n");
- }
- m_writer->emit("context.dispatchThreadID = start;\n");
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ m_writer->emit("context.groupID = varyingInput->groupID;\n");
+ }
+ m_writer->emit("context.dispatchThreadID = start;\n");
+
+ _emitEntryPointGroup(sizeAlongAxis, funcName);
+ _emitEntryPointDefinitionEnd(func);
+ }
+
+ // Emit the group range version, which runs all threads of every thread group from startGroupID up to (but not including) endGroupID
+ {
+ StringBuilder builder;
+ builder << getFuncName(func);
+ builder << "_GroupRange";
- _emitEntryPointGroup(sizeAlongAxis, funcName);
+ String groupRangeFuncName = builder;
+
+ _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupRangeFuncName, UnownedStringSlice::fromLiteral("GroupComputeVaryingInput"));
+
+ m_writer->emit("const uint3 start = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
+ m_writer->emit("const uint3 end = ");
+ _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
+
+#if 0
+ // Not needed as will be emitted as part of the loop
+ m_writer->emit("context.dispatchThreadID = start;\n");
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
+ {
+ m_writer->emit("context.groupDispatchThreadID = start;");
+ }
+ if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
+ {
+ m_writer->emit("context.groupID = varyingInput->startGroupID;\n");
}
+#endif
+ _emitEntryPointGroupRange(sizeAlongAxis, funcName);
_emitEntryPointDefinitionEnd(func);
}
}
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 906086d71..f15a302e5 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -132,6 +132,17 @@ public:
SLANG_CPP_INTRINSIC_OP(SLANG_CPP_INTRINSIC_OP_ENUM)
};
+ /// Bit combination of SemanticUsedFlag::Enum values
+ typedef uint32_t SemanticUsedFlags;
+ /// Flags recording which compute system value semantics have been emitted as operands.
+ /// The entry point emit code tests them to decide which thread/group state to set up.
+ struct SemanticUsedFlag
+ {
+ enum Enum : SemanticUsedFlags
+ {
+ DispatchThreadID = 0x01, ///< Set when an "sv_dispatchthreadid" operand is emitted
+ GroupThreadID = 0x02, ///< Set when an "sv_groupthreadid" operand is emitted
+ GroupID = 0x04, ///< Set when an "sv_groupid" operand is emitted
+ };
+ };
+
struct OperationInfo
{
UnownedStringSlice name;
@@ -257,9 +268,12 @@ protected:
SlangResult _calcTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName);
- void _emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName);
+ void _emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName, const UnownedStringSlice& varyingTypeName);
void _emitEntryPointDefinitionEnd(IRFunc* func);
void _emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName);
+ void _emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName);
+
+ void _emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName);
Dictionary<SpecializedIntrinsic, StringSlicePool::Handle> m_intrinsicNameMap;
Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap;
@@ -295,6 +309,8 @@ protected:
List<IntrinsicOp> m_intrinsicOpMap;
StringSlicePool m_slicePool;
+
+ SemanticUsedFlags m_semanticUsedFlags;
};
}
diff --git a/tests/compute/semantic.slang b/tests/compute/semantic.slang
new file mode 100644
index 000000000..12b0ca853
--- /dev/null
+++ b/tests/compute/semantic.slang
@@ -0,0 +1,13 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -compile-arg -O3 -compute-dispatch 3,1,1
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -compute-dispatch 3,1,1
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -compute-dispatch 3,1,1
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -compute-dispatch 3,1,1
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):dxbinding(0),glbinding(0),out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+// Each thread writes a single value combining all three semantics, so that each can be checked in the output:
+// dispatchThreadID.x shifted left by 8, groupID.x shifted left by 4, and groupThreadID.x in the low bits.
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadId)
+{
+ outputBuffer[dispatchThreadID.x] = (dispatchThreadID.x << 8) | (groupID.x << 4) | (groupThreadID.x);
+} \ No newline at end of file
diff --git a/tests/compute/semantic.slang.expected.txt b/tests/compute/semantic.slang.expected.txt
new file mode 100644
index 000000000..4c57d34ed
--- /dev/null
+++ b/tests/compute/semantic.slang.expected.txt
@@ -0,0 +1,12 @@
+0
+101
+202
+303
+410
+511
+612
+713
+820
+921
+A22
+B23
diff --git a/tools/render-test/cpu-compute-util.cpp b/tools/render-test/cpu-compute-util.cpp
index 4294ad539..1b1adef82 100644
--- a/tools/render-test/cpu-compute-util.cpp
+++ b/tools/render-test/cpu-compute-util.cpp
@@ -301,7 +301,7 @@ static CPUComputeUtil::Resource* _newOneTexture2D(int elemCount)
return SLANG_OK;
}
-/* static */SlangResult CPUComputeUtil::execute(const ShaderCompilerUtil::OutputAndLayout& compilationAndLayout, Context& context)
+/* static */SlangResult CPUComputeUtil::execute(const uint32_t dispatchSize[3], const ShaderCompilerUtil::OutputAndLayout& compilationAndLayout, Context& context)
{
auto request = compilationAndLayout.output.request;
auto reflection = (slang::ShaderReflection*) spGetReflection(request);
@@ -313,10 +313,12 @@ static CPUComputeUtil::Resource* _newOneTexture2D(int elemCount)
struct UniformState;
typedef void(*Func)(CPPPrelude::ComputeVaryingInput* varyingInput, CPPPrelude::UniformEntryPointParams* uniformEntryPointParams, UniformState* uniformState);
+ typedef void(*GroupRangeFunc)(CPPPrelude::GroupComputeVaryingInput* varyingInput, CPPPrelude::UniformEntryPointParams* uniformEntryPointParams, UniformState* uniformState);
slang::EntryPointReflection* entryPoint = nullptr;
Func func = nullptr;
Func groupFunc = nullptr;
+ GroupRangeFunc groupRangeFunc = nullptr;
{
auto entryPointCount = reflection->getEntryPointCount();
SLANG_ASSERT(entryPointCount == 1);
@@ -326,27 +328,58 @@ static CPUComputeUtil::Resource* _newOneTexture2D(int elemCount)
const char* entryPointName = entryPoint->getName();
func = (Func)sharedLibrary->findFuncByName(entryPointName);
- StringBuilder groupEntryPointName;
- groupEntryPointName << entryPointName << "_Group";
+ {
+ StringBuilder groupEntryPointName;
+ groupEntryPointName << entryPointName << "_Group";
+
+ groupFunc = (Func)sharedLibrary->findFuncByName(groupEntryPointName.getBuffer());
+ }
- groupFunc = (Func)sharedLibrary->findFuncByName(groupEntryPointName.getBuffer());
+ {
+ StringBuilder groupRangeEntryPointName;
+ groupRangeEntryPointName << entryPointName << "_GroupRange";
+
+ groupRangeFunc = (GroupRangeFunc)sharedLibrary->findFuncByName(groupRangeEntryPointName.getBuffer());
+ }
- if (func == nullptr && groupFunc == nullptr)
+ if (func == nullptr && groupFunc == nullptr && groupRangeFunc == nullptr)
{
return SLANG_FAIL;
}
}
// If we have the group function, that's the faster way to execute all threads in group...
- if (groupFunc)
+ if (groupRangeFunc)
{
UniformState* uniformState = (UniformState*)context.binding.m_rootBuffer.m_data;
CPPPrelude::UniformEntryPointParams* uniformEntryPointParams = (CPPPrelude::UniformEntryPointParams*)context.binding.m_entryPointBuffer.m_data;
+ CPPPrelude::GroupComputeVaryingInput varying;
+ varying.startGroupID = {};
+ varying.endGroupID = { dispatchSize[0], dispatchSize[1], dispatchSize[2] };
+
+ groupRangeFunc(&varying, uniformEntryPointParams, uniformState);
+ }
+ else if (groupFunc)
+ {
CPPPrelude::ComputeVaryingInput varying;
- varying.groupID = {};
- groupFunc(&varying, uniformEntryPointParams, uniformState);
+ for (uint32_t groupZ = 0; groupZ < dispatchSize[2]; ++groupZ)
+ {
+ for (uint32_t groupY = 0; groupY < dispatchSize[1]; ++groupY)
+ {
+ for (uint32_t groupX = 0; groupX < dispatchSize[0]; ++groupX)
+ {
+ UniformState* uniformState = (UniformState*)context.binding.m_rootBuffer.m_data;
+ CPPPrelude::UniformEntryPointParams* uniformEntryPointParams = (CPPPrelude::UniformEntryPointParams*)context.binding.m_entryPointBuffer.m_data;
+
+ varying.groupID = {groupX, groupY, groupZ};
+
+ groupFunc(&varying, uniformEntryPointParams, uniformState);
+ }
+ }
+ }
+
}
else
{
@@ -359,19 +392,29 @@ static CPUComputeUtil::Resource* _newOneTexture2D(int elemCount)
CPPPrelude::UniformEntryPointParams* uniformEntryPointParams = (CPPPrelude::UniformEntryPointParams*)context.binding.m_entryPointBuffer.m_data;
CPPPrelude::ComputeVaryingInput varying;
- varying.groupID = {};
- for (int z = 0; z < int(numThreadsPerAxis[2]); ++z)
+ for (uint32_t groupZ = 0; groupZ < dispatchSize[2]; ++groupZ)
{
- varying.groupThreadID.z = z;
- for (int y = 0; y < int(numThreadsPerAxis[1]); ++y)
+ for (uint32_t groupY = 0; groupY < dispatchSize[1]; ++groupY)
{
- varying.groupThreadID.y = y;
- for (int x = 0; x < int(numThreadsPerAxis[0]); ++x)
+ for (uint32_t groupX = 0; groupX < dispatchSize[0]; ++groupX)
{
- varying.groupThreadID.x = x;
+ varying.groupID = {groupX, groupY, groupZ};
- func(&varying, uniformEntryPointParams, uniformState);
+ for (int z = 0; z < int(numThreadsPerAxis[2]); ++z)
+ {
+ varying.groupThreadID.z = z;
+ for (int y = 0; y < int(numThreadsPerAxis[1]); ++y)
+ {
+ varying.groupThreadID.y = y;
+ for (int x = 0; x < int(numThreadsPerAxis[0]); ++x)
+ {
+ varying.groupThreadID.x = x;
+
+ func(&varying, uniformEntryPointParams, uniformState);
+ }
+ }
+ }
}
}
}
diff --git a/tools/render-test/cpu-compute-util.h b/tools/render-test/cpu-compute-util.h
index cbc4e6e58..b30ef146b 100644
--- a/tools/render-test/cpu-compute-util.h
+++ b/tools/render-test/cpu-compute-util.h
@@ -29,7 +29,7 @@ struct CPUComputeUtil
static SlangResult calcBindings(const ShaderCompilerUtil::OutputAndLayout& compilationAndLayout, Context& outContext);
- static SlangResult execute(const ShaderCompilerUtil::OutputAndLayout& compilationAndLayout, Context& outContext);
+ static SlangResult execute(const uint32_t dispatchSize[3], const ShaderCompilerUtil::OutputAndLayout& compilationAndLayout, Context& outContext);
static SlangResult writeBindings(const ShaderInputLayout& layout, const List<CPUMemoryBinding::Buffer>& buffers, const Slang::String& fileName);
};
diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp
index 1cf0ffbe8..d2f21a5d9 100644
--- a/tools/render-test/options.cpp
+++ b/tools/render-test/options.cpp
@@ -179,6 +179,34 @@ SlangResult parseOptions(int argc, const char*const* argv, Slang::WriterHelper s
gOptions.adapter = *argCursor++;
}
+ else if (strcmp(arg, "-compute-dispatch") == 0)
+ {
+ if (argCursor == argEnd)
+ {
+ stdError.print("error: comma separated compute dispatch size for '%s'\n", arg);
+ return SLANG_FAIL;
+ }
+ List<UnownedStringSlice> slices;
+ StringUtil::split(UnownedStringSlice(*argCursor++), ',', slices);
+ if (slices.getCount() != 3)
+ {
+ stdError.print("error: expected 3 comma separated integers for compute dispatch size for '%s'\n", arg);
+ return SLANG_FAIL;
+ }
+
+ String string;
+ for (Index i = 0; i < 3; ++i)
+ {
+ string = slices[i];
+ int v = StringToInt(string);
+ if (v < 1)
+ {
+ stdError.print("error: expected 3 comma positive integers for compute dispatch size for '%s'\n", arg);
+ return SLANG_FAIL;
+ }
+ gOptions.computeDispatchSize[i] = v;
+ }
+ }
else
{
// Lookup
diff --git a/tools/render-test/options.h b/tools/render-test/options.h
index a57c94ed0..67eae6603 100644
--- a/tools/render-test/options.h
+++ b/tools/render-test/options.h
@@ -64,6 +64,8 @@ struct Options
Slang::List<Slang::CommandLine::Arg> compileArgs;
Slang::String adapter; ///< The adapter to use either name or index
+
+ uint32_t computeDispatchSize[3] = { 1, 1, 1 };
};
extern Options gOptions;
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index 0e457f9e4..2a0b9a6c9 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -232,7 +232,7 @@ void RenderTestApp::runCompute()
auto pipelineType = PipelineType::Compute;
m_renderer->setPipelineState(pipelineType, m_pipelineState);
m_bindingState->apply(m_renderer, pipelineType);
- m_renderer->dispatchCompute(1, 1, 1);
+ m_renderer->dispatchCompute(m_options.computeDispatchSize[0], m_options.computeDispatchSize[1], m_options.computeDispatchSize[2]);
}
void RenderTestApp::finalize()
@@ -461,7 +461,7 @@ SLANG_TEST_TOOL_API SlangResult innerMain(Slang::StdWriters* stdWriters, SlangSe
CPUComputeUtil::Context context;
SLANG_RETURN_ON_FAIL(CPUComputeUtil::calcBindings(compilationAndLayout, context));
- SLANG_RETURN_ON_FAIL(CPUComputeUtil::execute(compilationAndLayout, context));
+ SLANG_RETURN_ON_FAIL(CPUComputeUtil::execute(gOptions.computeDispatchSize, compilationAndLayout, context));
// Dump everything out that was written
return CPUComputeUtil::writeBindings(compilationAndLayout.layout, context.buffers, gOptions.outputPath);