summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Yong He <yonghe@outlook.com>, 2024-05-01 16:44:22 -0700
committer: GitHub <noreply@github.com>, 2024-05-01 16:44:22 -0700
commit: 0bb826f8b92aec330875d0b966c1f4a6b99988bf (patch)
tree: f0d086d4bfb93e302fcb8232816842ccfc182480
parent: 4533c825fe628e08228037b846ee9d10004fd56f (diff)
SPIRV: Fix performance issue when handling large arrays. (#4064)
* SPIRV: Fix performance issue when handling large arrays.
* Add test for packing.
* Fix clang.
-rw-r--r-- source/slang/slang-ir-lower-buffer-element-type.cpp | 76
-rw-r--r-- source/slang/slang-ir-util.cpp | 5
-rw-r--r-- tests/spirv/large-struct-pack.slang | 29
-rw-r--r-- tests/spirv/large-struct-ptr.slang | 20
-rw-r--r-- tests/spirv/large-struct.slang | 32
-rw-r--r-- tools/render-test/render-test-main.cpp | 6
-rw-r--r-- tools/render-test/shader-input-layout.cpp | 5
-rw-r--r-- tools/render-test/shader-input-layout.h | 1
8 files changed, 160 insertions, 14 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index e9fbfc0d1..360596741 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -9,6 +9,8 @@ namespace Slang
{
struct LoweredElementTypeContext
{
+ static const IRIntegerValue kMaxArraySizeToUnroll = 32;
+
struct LoweredElementTypeInfo
{
IRType* originalType;
@@ -161,17 +163,42 @@ namespace Slang
auto packedParam = builder.emitParam(structType);
auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey);
auto count = getIntVal(arrayType->getElementCount());
- List<IRInst*> args;
- args.setCount((Index)count);
- for (IRIntegerValue ii = 0; ii < count; ++ii)
+ IRInst* result = nullptr;
+ if (count <= kMaxArraySizeToUnroll)
+ {
+ // If the array is small enough, just process each element directly.
+ List<IRInst*> args;
+ args.setCount((Index)count);
+ for (IRIntegerValue ii = 0; ii < count; ++ii)
+ {
+ auto packedElement = builder.emitElementExtract(packedArray, ii);
+ auto originalElement = innerTypeInfo.convertLoweredToOriginal
+ ? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement)
+ : packedElement;
+ args[(Index)ii] = originalElement;
+ }
+ result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
+
+ }
+ else
{
- auto packedElement = builder.emitElementExtract(packedArray, ii);
+ // The general case for large arrays is to emit a loop through the elements.
+ IRVar* resultVar = builder.emitVar(arrayType);
+ IRBlock* loopBodyBlock;
+ IRBlock* loopBreakBlock;
+ auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count),
+ loopBodyBlock, loopBreakBlock);
+
+ builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
+ auto packedElement = builder.emitElementExtract(packedArray, loopParam);
auto originalElement = innerTypeInfo.convertLoweredToOriginal
? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement)
: packedElement;
- args[(Index)ii] = originalElement;
+ auto varPtr = builder.emitElementAddress(resultVar, loopParam);
+ builder.emitStore(varPtr, originalElement);
+ builder.setInsertInto(loopBreakBlock);
+ result = builder.emitLoad(resultVar);
}
- auto result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
builder.emitReturn(result);
return func;
}
@@ -191,18 +218,43 @@ namespace Slang
builder.setInsertInto(func);
builder.emitBlock();
auto originalParam = builder.emitParam(arrayType);
+ IRInst* packedArray = nullptr;
auto count = getIntVal(arrayType->getElementCount());
- List<IRInst*> args;
- args.setCount((Index)count);
- for (IRIntegerValue ii = 0; ii < count; ++ii)
+ if (count <= kMaxArraySizeToUnroll)
+ {
+ // If the array is small enough, just process each element directly.
+ List<IRInst*> args;
+ args.setCount((Index)count);
+ for (IRIntegerValue ii = 0; ii < count; ++ii)
+ {
+ auto originalElement = builder.emitElementExtract(originalParam, ii);
+ auto packedElement = innerTypeInfo.convertOriginalToLowered
+ ? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement)
+ : originalElement;
+ args[(Index)ii] = packedElement;
+ }
+ packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer());
+ }
+ else
{
- auto originalElement = builder.emitElementExtract(originalParam, ii);
+ // The general case for large arrays is to emit a loop through the elements.
+ IRVar* packedArrayVar = builder.emitVar(innerArrayType);
+ IRBlock* loopBodyBlock;
+ IRBlock* loopBreakBlock;
+ auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count),
+ loopBodyBlock, loopBreakBlock);
+
+ builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
+ auto originalElement = builder.emitElementExtract(originalParam, loopParam);
auto packedElement = innerTypeInfo.convertOriginalToLowered
? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement)
: originalElement;
- args[(Index)ii] = packedElement;
+ auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam);
+ builder.emitStore(varPtr, packedElement);
+ builder.setInsertInto(loopBreakBlock);
+ packedArray = builder.emitLoad(packedArrayVar);
}
- auto packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer());
+
auto result = builder.emitMakeStruct(structType, 1, &packedArray);
builder.emitReturn(result);
return func;
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index f6b0acaed..bcb9439fb 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -872,18 +872,21 @@ IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IR
IRBuilder loopBuilder = *builder;
auto loopHeadBlock = loopBuilder.emitBlock();
loopBodyBlock = loopBuilder.emitBlock();
+ auto ifBreakBlock = loopBuilder.emitBlock();
loopBreakBlock = loopBuilder.emitBlock();
auto loopContinueBlock = loopBuilder.emitBlock();
builder->emitLoop(loopHeadBlock, loopBreakBlock, loopHeadBlock, 1, &initVal);
loopBuilder.setInsertInto(loopHeadBlock);
auto loopParam = loopBuilder.emitParam(initVal->getFullType());
auto cmpResult = loopBuilder.emitLess(loopParam, finalVal);
- loopBuilder.emitIfElse(cmpResult, loopBodyBlock, loopBreakBlock, loopBreakBlock);
+ loopBuilder.emitIfElse(cmpResult, loopBodyBlock, ifBreakBlock, ifBreakBlock);
loopBuilder.setInsertInto(loopBodyBlock);
loopBuilder.emitBranch(loopContinueBlock);
loopBuilder.setInsertInto(loopContinueBlock);
auto newParam = loopBuilder.emitAdd(loopParam->getFullType(), loopParam, loopBuilder.getIntValue(loopBuilder.getIntType(), 1));
loopBuilder.emitBranch(loopHeadBlock, 1, &newParam);
+ loopBuilder.setInsertInto(ifBreakBlock);
+ loopBuilder.emitBranch(loopBreakBlock);
return loopParam;
}
diff --git a/tests/spirv/large-struct-pack.slang b/tests/spirv/large-struct-pack.slang
new file mode 100644
index 000000000..df15ac67c
--- /dev/null
+++ b/tests/spirv/large-struct-pack.slang
@@ -0,0 +1,29 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute
+
+// Check that when generating spirv directly, we use a loop
+// to copy large arrays in a local variable to a buffer, instead of emitting
+// unrolled code that reads each element of the array individually.
+
+struct WorkData
+{
+ int B[1024];
+};
+
+//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4, count=1024)
+RWStructuredBuffer<WorkData> resultBuffer;
+
+// CHECK: OpLoopMerge
+// CHECK: OpLoopMerge
+
+// BUF: 0
+// BUF: 1
+[numthreads(1, 1, 1)]
+void computeMain(uint3 tid: SV_DispatchThreadID)
+{
+ WorkData wd;
+ for (int i = 0; i < 1024; i++)
+ wd.B[i] = i;
+ resultBuffer[0] = wd;
+}
diff --git a/tests/spirv/large-struct-ptr.slang b/tests/spirv/large-struct-ptr.slang
new file mode 100644
index 000000000..131ccdbb9
--- /dev/null
+++ b/tests/spirv/large-struct-ptr.slang
@@ -0,0 +1,20 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460
+
+struct WorkData {
+ float A[2048 * 2048];
+ float B[2048 * 2048];
+};
+struct PushData {
+ WorkData* Input;
+ float* Dest;
+};
+
+[vk::push_constant] ConstantBuffer<PushData> cb;
+
+// CHECK: OpEntryPoint
+
+[numthreads(64, 1, 1)]
+void ComputeMain(uint tid: SV_DispatchThreadID)
+{
+ cb.Dest[tid] = cb.Input->A[tid] * cb.Input->B[tid];
+}
\ No newline at end of file
diff --git a/tests/spirv/large-struct.slang b/tests/spirv/large-struct.slang
new file mode 100644
index 000000000..7738a5fcf
--- /dev/null
+++ b/tests/spirv/large-struct.slang
@@ -0,0 +1,32 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-d3d12 -compute -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -output-using-type
+
+// Check that when generating spirv directly, we use a loop
+// to copy large arrays in input data out into a local variable, instead of emitting
+// unrolled code that reads each element of the array individually.
+
+struct WorkData
+{
+ float A[2*2];
+ float B[1024];
+
+ float Foo(uint i) { return A[i] * B[i]; }
+};
+
+//TEST_INPUT:set input = new WorkData{[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]}
+ConstantBuffer<WorkData> input;
+
+//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float> resultBuffer;
+
+// CHECK: OpLoopMerge
+
+[numthreads(2, 1, 1)]
+void computeMain(uint3 tid: SV_DispatchThreadID)
+{
+ // BUF: 10.0
+ // BUF: 40.0
+ resultBuffer[tid.x] = input.Foo(tid.x);
+}
\ No newline at end of file
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index 02c0ea86a..fcdd4b54d 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -207,7 +207,10 @@ struct AssignValsFromLayoutContext
{
const InputBufferDesc& srcBuffer = srcVal->bufferDesc;
auto& bufferData = srcVal->bufferData;
- const size_t bufferSize = bufferData.getCount() * sizeof(uint32_t);
+ const size_t bufferSize = Math::Max((size_t)bufferData.getCount() * sizeof(uint32_t), (size_t)(srcBuffer.elementCount * srcBuffer.stride));
+ bufferData.reserve(bufferSize / sizeof(uint32_t));
+ for (size_t i = bufferData.getCount(); i < bufferSize / sizeof(uint32_t); i++)
+ bufferData.add(0);
ComPtr<IBufferResource> bufferResource;
SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource(srcBuffer, /*entry.isOutput,*/ bufferSize, bufferData.getBuffer(), device, bufferResource));
@@ -232,6 +235,7 @@ struct AssignValsFromLayoutContext
const InputBufferDesc& counterBufferDesc{
InputBufferType::StorageBuffer,
sizeof(uint32_t),
+ 1,
Format::Unknown,
};
SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource(
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index 96f5db6e0..3012d45a4 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -222,6 +222,11 @@ namespace renderer_test
parser.Read("=");
val->bufferDesc.stride = parser.ReadInt();
}
+ else if (word == "count")
+ {
+ parser.Read("=");
+ val->bufferDesc.elementCount = parser.ReadInt();
+ }
else if (word == "counter")
{
parser.Read("=");
diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h
index de1da3da9..996635b94 100644
--- a/tools/render-test/shader-input-layout.h
+++ b/tools/render-test/shader-input-layout.h
@@ -68,6 +68,7 @@ struct InputBufferDesc
{
InputBufferType type = InputBufferType::StorageBuffer;
int stride = 0; // stride == 0 indicates an unstructured buffer.
+ int elementCount = 1;
Format format = Format::Unknown;
// For RWStructuredBuffer, AppendStructuredBuffer, ConsumeStructuredBuffer
// the default value of 0xffffffff indicates that a counter buffer should