summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-11-12 21:01:47 -0800
committerGitHub <noreply@github.com>2024-11-12 21:01:47 -0800
commit0754abee603c1afa7803a444124acc9d268d2f0a (patch)
tree016bf4c03ff58595c62ec94e902137b1191e5960
parent567c7e09b6df36b535c4ffbccd6a3658d18e04c2 (diff)
Push buffer load to end of access chain. (#5544)
* Push buffer load to end of access chain. * Update test. * Fix. * Fix. * Fix. * Make more robust. * Fix.
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-defer-buffer-load.cpp369
-rw-r--r--source/slang/slang-ir-defer-buffer-load.h26
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--tests/spirv/sb-load-2.slang23
-rw-r--r--tests/spirv/sb-load.slang24
6 files changed, 449 insertions, 1 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1950f251c..05bb12ecc 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -30,6 +30,7 @@
#include "slang-ir-com-interface.h"
#include "slang-ir-composite-reg-to-mem.h"
#include "slang-ir-dce.h"
+#include "slang-ir-defer-buffer-load.h"
#include "slang-ir-defunctionalization.h"
#include "slang-ir-diff-call.h"
#include "slang-ir-dll-export.h"
@@ -951,6 +952,11 @@ Result linkAndOptimizeIR(
// Inline calls to any functions marked with [__unsafeInlineEarly] or [ForceInline].
performForceInlining(irModule);
+ // Push `structuredBufferLoad` to the end of access chain to avoid loading unnecessary data.
+ if (isKhronosTarget(targetRequest) || isMetalTarget(targetRequest) ||
+ isWGPUTarget(targetRequest))
+ deferBufferLoad(irModule);
+
// Specialization can introduce dead code that could trip
// up downstream passes like type legalization, so we
// will run a DCE pass to clean up after the specialization.
diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
new file mode 100644
index 000000000..d1eb4b5e5
--- /dev/null
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -0,0 +1,369 @@
+#include "slang-ir-defer-buffer-load.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-redundancy-removal.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct DeferBufferLoadContext
+{
+ struct AccessChain // Dictionary key: a root value followed by the keys applied to it.
+ {
+ List<IRInst*> chain; // Root first, then each index/field key (built by ensurePtr).
+ mutable HashCode64 hash = 0; // Lazily computed cache; 0 means "not computed yet".
+
+ bool operator==(const AccessChain& rhs) const
+ {
+ ensureHash();
+ rhs.ensureHash();
+ if (hash != rhs.hash) // Cheap reject before the element-wise comparison.
+ return false;
+ if (chain.getCount() != rhs.chain.getCount())
+ return false;
+ for (Index i = 0; i < chain.getCount(); i++)
+ {
+ if (chain[i] != rhs.chain[i]) // Entries are compared by pointer identity.
+ return false;
+ }
+ return true;
+ }
+ void ensureHash() const
+ {
+ if (hash == 0) // Also true for an empty chain or a hash of 0; then this just recomputes.
+ {
+ for (auto inst : chain)
+ {
+ hash = combineHash(hash, Slang::getHashCode(inst));
+ }
+ }
+ }
+ HashCode64 getHashCode() const // NOTE: the cache is never reset; finish building chain first.
+ {
+ ensureHash();
+ return hash;
+ }
+ };
+
+ // Map an access chain, or an original SSA value, to a pointer that can be used to load the value.
+ Dictionary<AccessChain, IRInst*> mapAccessChainToPtr;
+ Dictionary<IRInst*, IRInst*> mapValueToPtr;
+ // Map a ptr to its loaded value.
+ Dictionary<IRInst*, IRInst*> mapPtrToValue;
+
+ IRFunc* currentFunc = nullptr; // Set by deferBufferLoadInFunc.
+ IRDominatorTree* dominatorTree = nullptr; // Dominator tree of currentFunc; set alongside it.
+
+ // Find the block that is dominated by all dependent blocks, and is the earliest block that
+ // dominates the target block (the block containing valueInst).
+ // This is the place where we can insert the load instruction such that all access chain
+ // operands are defined and the load can be made available to the location of valueInst.
+ // Starts at valueInst's block and climbs immediate dominators while all dependent blocks dominate.
+ IRBlock* findEarliestDominatingBlock(IRInst* valueInst, List<IRBlock*>& dependentBlocks)
+ {
+ auto targetBlock = getBlock(valueInst);
+ while (targetBlock)
+ {
+ auto idom = dominatorTree->getImmediateDominator(targetBlock);
+ if (!idom) // No immediate dominator: nothing further up to hoist into.
+ break;
+ bool isValid = true;
+ for (auto block : dependentBlocks)
+ {
+ if (!dominatorTree->dominates(block, idom)) // idom is not dominated by this dependency.
+ {
+ isValid = false;
+ break;
+ }
+ }
+ if (isValid)
+ {
+ targetBlock = idom;
+ }
+ else
+ {
+ break;
+ }
+ }
+ return targetBlock;
+ }
+
+ // Find the earliest instruction before which we can insert the load instruction such that
+ // all dependent instructions for the load address are defined, and the load can reach all
+ // locations where the address is available.
+ // Returns the instruction to insert before.
+ IRInst* findEarliestInsertionPoint(IRInst* valueInst, AccessChain& chain)
+ {
+ List<IRBlock*> dependentBlocks;
+ List<IRInst*> dependentInsts;
+ for (auto inst : chain.chain)
+ {
+ if (auto block = getBlock(inst)) // Entries with no enclosing block add no constraint.
+ {
+ dependentBlocks.add(block);
+ dependentInsts.add(inst);
+ }
+ }
+ auto targetBlock = findEarliestDominatingBlock(valueInst, dependentBlocks);
+ IRInst* insertBeforeInst =
+ targetBlock == getBlock(valueInst) ? valueInst : targetBlock->getTerminator();
+ for (;;) // Walk backwards while every dependency dominates, and is not, the previous inst.
+ {
+ auto prev = insertBeforeInst->getPrevInst();
+ if (!prev) // Reached the first instruction of the block.
+ break;
+ bool valid = true;
+ for (auto inst : dependentInsts)
+ {
+ if (!dominatorTree->dominates(inst, prev) || inst == prev)
+ {
+ valid = false;
+ break;
+ }
+ }
+ if (valid)
+ {
+ insertBeforeInst = prev;
+ }
+ else
+ {
+ break;
+ }
+ }
+ return insertBeforeInst;
+ }
+
+ // Ensure that for an original SSA value, we have formed a pointer that can be used to load the
+ // value. Returns the cached or newly emitted pointer, or nullptr for an unsupported opcode.
+ IRInst* ensurePtr(IRInst* valueInst)
+ {
+ IRInst* result = nullptr;
+ if (mapValueToPtr.tryGetValue(valueInst, result))
+ return result;
+ AccessChain chain;
+ IRInst* current = valueInst;
+ while (current)
+ {
+ bool processed = false;
+ switch (current->getOp())
+ {
+ case kIROp_GetElement:
+ case kIROp_FieldExtract:
+ chain.chain.add(current->getOperand(1)); // Record the index / field key.
+ current = current->getOperand(0); // Step to the base value.
+ processed = true;
+ break;
+ default:
+ break;
+ }
+ if (!processed)
+ break;
+ }
+ chain.chain.add(current); // The root value the chain is applied to.
+ chain.chain.reverse(); // Collected from valueInst back to the root; flip to root-first.
+ if (mapAccessChainToPtr.tryGetValue(chain, result)) // Reuse the ptr of an identical chain.
+ return result;
+
+ // Find the proper place to insert the load instruction.
+ // This is the location where all operands of the access chain are defined.
+ // And is the earliest block so all possible uses of the value at access chain
+ // can be reached.
+ IRBuilder b(valueInst);
+
+ auto insertBeforeInst = findEarliestInsertionPoint(valueInst, chain);
+ b.setInsertBefore(insertBeforeInst);
+
+ switch (valueInst->getOp()) // No default case: any other opcode leaves result null.
+ {
+ case kIROp_StructuredBufferLoad:
+ case kIROp_StructuredBufferLoadStatus:
+ {
+ result = b.emitRWStructuredBufferGetElementPtr(
+ valueInst->getOperand(0),
+ valueInst->getOperand(1));
+ break;
+ }
+ case kIROp_GetElement:
+ {
+ auto ptr = ensurePtr(valueInst->getOperand(0)); // Recursively form the base pointer.
+ if (!ptr)
+ return nullptr;
+ result = b.emitElementAddress(ptr, valueInst->getOperand(1));
+ break;
+ }
+ case kIROp_FieldExtract:
+ {
+ auto ptr = ensurePtr(valueInst->getOperand(0)); // Recursively form the base pointer.
+ if (!ptr)
+ return nullptr;
+ result = b.emitFieldAddress(ptr, valueInst->getOperand(1));
+ break;
+ }
+ }
+ if (result) // Cache under both keys so later lookups by value or by chain hit.
+ {
+ mapAccessChainToPtr[chain] = result;
+ mapValueToPtr[valueInst] = result;
+ }
+ return result;
+ }
+
+ static bool isStructuredBufferLoad(IRInst* inst)
+ {
+ // Returns true for the load opcodes this pass defers. Note: we cannot defer loads from
+ // RWStructuredBuffer because there can be other instructions that modify the buffer.
+ switch (inst->getOp())
+ {
+ case kIROp_StructuredBufferLoad:
+ case kIROp_StructuredBufferLoadStatus:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ // Ensure a load exists for loadInst's pointer; returns it, or nullptr if no pointer can be formed.
+ IRInst* materializePointer(IRBuilder& builder, IRInst* loadInst)
+ {
+ auto ptr = ensurePtr(loadInst);
+ if (!ptr)
+ return nullptr;
+ IRInst* result = nullptr;
+ if (mapPtrToValue.tryGetValue(ptr, result)) // At most one load is emitted per pointer.
+ return result;
+ builder.setInsertAfter(ptr); // Place the load immediately after the address computation.
+ result = builder.emitLoad(ptr);
+ mapPtrToValue[ptr] = result;
+ return result;
+ }
+
+ static bool isSimpleType(IRInst* type) // True for basic, vector, and matrix types.
+ {
+ if (as<IRBasicType>(type))
+ return true;
+ if (as<IRVectorType>(type))
+ return true;
+ if (as<IRMatrixType>(type))
+ return true;
+ return false; // Everything else is a candidate for further deferral.
+ }
+
+ void deferBufferLoadInst(IRBuilder& builder, List<IRInst*>& workList, IRInst* loadInst)
+ {
+ // Don't defer the load anymore if the type is simple.
+ if (isSimpleType(loadInst->getDataType()))
+ {
+ if (!isStructuredBufferLoad(loadInst)) // A direct buffer load of a simple value is kept.
+ {
+ auto materializedVal = materializePointer(builder, loadInst);
+ loadInst->replaceUsesWith(materializedVal); // NOTE(review): inst not removed here, unlike below; presumably left to DCE.
+ }
+ return;
+ }
+
+ // Otherwise, look for all uses and try to defer the load before actual use of the value.
+ ShortList<IRInst*> pendingWorkList;
+ bool needMaterialize = false;
+ traverseUses(
+ loadInst,
+ [&](IRUse* use)
+ {
+ if (needMaterialize) // Already decided to materialize; ignore the remaining uses.
+ return;
+
+ auto user = use->getUser();
+ switch (user->getOp())
+ {
+ case kIROp_GetElement:
+ case kIROp_FieldExtract:
+ {
+ auto basePtr = ensurePtr(loadInst);
+ if (!basePtr) // No pointer for the base: leave this user untouched.
+ return;
+ pendingWorkList.add(user); // Candidate for deferring the load further down.
+ }
+ break;
+ default:
+ if (!isStructuredBufferLoad(loadInst)) // Only intermediate accesses force a load here.
+ {
+ needMaterialize = true;
+ return;
+ }
+ break;
+ }
+ });
+
+ if (needMaterialize)
+ {
+ auto val = materializePointer(builder, loadInst);
+ loadInst->replaceUsesWith(val);
+ loadInst->removeAndDeallocate();
+ }
+ else
+ {
+ // Append to worklist in reverse order so we process the uses in natural appearance
+ // order.
+ for (Index i = pendingWorkList.getCount() - 1; i >= 0; i--)
+ workList.add(pendingWorkList[i]);
+ }
+ }
+
+ void deferBufferLoadInFunc(IRFunc* func)
+ {
+ removeRedundancyInFunc(func); // NOTE(review): presumably merges duplicate accesses first; confirm.
+
+ currentFunc = func;
+ dominatorTree = func->getModule()->findOrCreateDominatorTree(func);
+
+ List<IRInst*> workList;
+
+ for (auto block : func->getBlocks()) // Seed the worklist with every structured buffer load.
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (isStructuredBufferLoad(inst))
+ {
+ workList.add(inst);
+ }
+ }
+ }
+
+ IRBuilder builder(func);
+ for (Index i = 0; i < workList.getCount(); i++) // Index loop: the list grows while iterating.
+ {
+ auto inst = workList[i];
+ deferBufferLoadInst(builder, workList, inst);
+ }
+ }
+
+ void deferBufferLoad(IRGlobalValueWithCode* inst) // Runs the pass on a function or on a generic's inner function.
+ {
+ if (auto func = as<IRFunc>(inst))
+ {
+ deferBufferLoadInFunc(func);
+ }
+ else if (auto generic = as<IRGeneric>(inst))
+ {
+ auto inner = findGenericReturnVal(generic);
+ if (auto innerFunc = as<IRFunc>(inner)) // Generics that do not return a function are skipped.
+ deferBufferLoadInFunc(innerFunc);
+ }
+ }
+};
+
+void deferBufferLoad(IRModule* module) // Entry point: runs the pass over each global inst with code.
+{
+ DeferBufferLoadContext context; // One context (and its caches) is shared by all functions.
+ for (auto childInst : module->getGlobalInsts())
+ {
+ if (auto code = as<IRGlobalValueWithCode>(childInst))
+ {
+ context.deferBufferLoad(code);
+ }
+ }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-defer-buffer-load.h b/source/slang/slang-ir-defer-buffer-load.h
new file mode 100644
index 000000000..b54271883
--- /dev/null
+++ b/source/slang/slang-ir-defer-buffer-load.h
@@ -0,0 +1,26 @@
+#pragma once
+
+namespace Slang
+{
+
+/*
+This pass implements a targeted optimization that defers the loading of structured buffer elements
+to the end of the access chain to avoid loading and repacking unnecessary data.
+For example, if we see:
+ val = StructuredBufferLoad(s, i)
+ val2 = GetElement(val, j)
+ val3 = FieldExtract(val2, field_key_0)
+ call(foo, val3)
+We should rewrite the code into:
+ ptr = RWStructuredBufferGetElementPtr(s, i)
+ ptr2 = ElementAddress(ptr, j)
+ ptr3 = FieldAddress(ptr2, field_key_0)
+ val3 = Load(ptr3)
+ call(foo, val3)
+*/
+
+struct IRModule;
+
+void deferBufferLoad(IRModule* module);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 823b3cd7d..5e4db43b2 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6175,7 +6175,7 @@ IRInst* IRBuilder::emitGenericAsm(UnownedStringSlice asmText)
IRInst* IRBuilder::emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index)
{
- const auto sbt = cast<IRHLSLRWStructuredBufferType>(structuredBuffer->getDataType());
+ const auto sbt = cast<IRHLSLStructuredBufferTypeBase>(structuredBuffer->getDataType());
const auto t = getPtrType(sbt->getElementType());
IRInst* const operands[2] = {structuredBuffer, index};
const auto i = createInst<IRRWStructuredBufferGetElementPtr>(
diff --git a/tests/spirv/sb-load-2.slang b/tests/spirv/sb-load-2.slang
new file mode 100644
index 000000000..b4c10cb4a
--- /dev/null
+++ b/tests/spirv/sb-load-2.slang
@@ -0,0 +1,23 @@
+//TEST:SIMPLE(filecheck=CHECK): -target glsl -entry main -stage compute
+
+struct Test1
+{
+ float2x3 a; // 24B
+ float3x4 b; // 48B
+ float16_t3x2 c; // 12B
+ float16_t2x4 d; // 16B
+};
+
+StructuredBuffer<Test1> dp;
+RWStructuredBuffer<float4> outputBuffer;
+
+// CHECK-COUNT-2: unpackStorage
+// CHECK-NOT: unpackStorage
+[numthreads(4, 4, 1)]
+void main(uint3 GTid : SV_GroupThreadID,
+ uint GI : SV_GroupIndex)
+{
+ var tmp = dp[0];
+ var rs = tmp.a[0][0] + tmp.a[0][1];
+ outputBuffer[GI] = float4(rs);
+} \ No newline at end of file
diff --git a/tests/spirv/sb-load.slang b/tests/spirv/sb-load.slang
new file mode 100644
index 000000000..1b0df0be8
--- /dev/null
+++ b/tests/spirv/sb-load.slang
@@ -0,0 +1,24 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+#define FILL_PATTERN_DIMENSIONS_X 16
+#define FILL_PATTERN_DIMENSIONS_Y 16
+
+struct FillPatternBuffer
+{
+ float4 px[FILL_PATTERN_DIMENSIONS_Y][FILL_PATTERN_DIMENSIONS_X];
+};
+
+StructuredBuffer<FillPatternBuffer> dp;
+RWStructuredBuffer<float4> outputBuffer;
+
+// CHECK-NOT: OpCompositeConstruct
+
+[numthreads(4, 4, 1)]
+void main(uint3 GTid : SV_GroupThreadID,
+ uint GI : SV_GroupIndex)
+{
+ const uint ii = GTid.x;
+ const uint jj = GTid.y;
+ const float4 pmv = dp[0].px[ii][jj];
+ outputBuffer[GI] = pmv;
+} \ No newline at end of file