summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-10-09 18:26:28 -0700
committerGitHub <noreply@github.com>2025-10-10 01:26:28 +0000
commit3cf1f5a616917480c63b76aae906dc36b29e46ce (patch)
treeabbc4538e17be1163c06c950b4afdacd227fe39c
parent4e4aad5a0493defde1e0ef29f27e5d663c1182cd (diff)
Small fix to buffer load specialization pass to allow more specialization to happen. (#8653)
This allows us to further cleanup unnecessary copies in the target code we generate. Part of effort of #8652.
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-specialize-buffer-load-arg.cpp2
-rw-r--r--tests/optimization/wrapped-array.slang56
3 files changed, 63 insertions, 1 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 804a44b81..7d72fb77f 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1402,6 +1402,12 @@ Result linkAndOptimizeIR(
// function parameters, reults, etc. is invalid.
// We clean up the usages of resource values here.
specializeResourceUsage(codeGenContext, irModule);
+
+ // Specialize calls to functions with values loaded from an immutable location,
+ // so that we directly load the value inside the callee, instead of loading the
+ // value outside of the callee and copy it in. This is necessary to avoid copying
+ // large values (e.g. arrays) in registers, where most of the elements are not
+ // actually used.
specializeFuncsForBufferLoadArgs(codeGenContext, irModule);
// Push `structuredBufferLoad` to the end of access chain to avoid loading unnecessary data.
diff --git a/source/slang/slang-ir-specialize-buffer-load-arg.cpp b/source/slang/slang-ir-specialize-buffer-load-arg.cpp
index a5a3dd2d9..c473c6047 100644
--- a/source/slang/slang-ir-specialize-buffer-load-arg.cpp
+++ b/source/slang/slang-ir-specialize-buffer-load-arg.cpp
@@ -89,7 +89,7 @@ struct FuncBufferLoadSpecializationCondition : FunctionCallSpecializeCondition
a = argLoad->getPtr();
// We can safely defer a load to the callee if the source dest is immutable.
- if (isPointerToImmutableLocation(a))
+ if (isPointerToImmutableLocation(getRootAddr(a)))
continue;
// Otherwise, we check if there is no other instructions in between the load and the
diff --git a/tests/optimization/wrapped-array.slang b/tests/optimization/wrapped-array.slang
new file mode 100644
index 000000000..cf47a228b
--- /dev/null
+++ b/tests/optimization/wrapped-array.slang
@@ -0,0 +1,56 @@
+//TEST:SIMPLE(filecheck=CHECK):-target spirv
+
+// CHECK: OpEntryPoint
+
+// Make sure we never load the entire TensorList struct to local registers,
+// instead, we should specialize the fetch function to directly load the
+// element tensor from gTensors.
+
+// CHECK-NOT: OpLoad %TensorList_std140
+
+struct RWTensor<T, let D : int>
+{
+ int dims[D];
+ RWStructuredBuffer<T> buffer;
+ T getv(vector<uint, D> index)
+ {
+ int flat_index = 0;
+ int stride = 1;
+ for (int i = D - 1; i >= 0; --i)
+ {
+ flat_index += index[i] * stride;
+ stride *= dims[i];
+ }
+ return buffer[flat_index];
+ }
+}
+struct TensorList<let N : int>
+{
+ RWTensor<float, 2> tensors[N];
+
+ float fetch(int tensor_index, uint2 index)
+ {
+ return tensors[tensor_index].getv(index);
+ }
+}
+
+float sum_indirect<let N : int>(uint2 tid, TensorList<N> tensor_list, uint tensor_indices[N])
+{
+ float result = 0.0;
+ for (int i = 0; i < N; i++)
+ {
+ result += tensor_list.fetch(tensor_indices[i], tid);
+ }
+ return result;
+}
+
+uniform TensorList<32> gTensors;
+uniform uint gTensorIndices[32];
+
+uniform float* result;
+
+[numthreads(1,1,1)]
+void computeMain(uint2 tid : SV_DispatchThreadID)
+{
+ *result = sum_indirect(tid, gTensors, gTensorIndices);
+}