summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-09-30 19:08:23 -0700
committerGitHub <noreply@github.com>2025-09-30 19:08:23 -0700
commite4611e2e30a3e5969d402f5ed7e72706a0e3b024 (patch)
tree0f4240ccf8c4f0786949ab33adb0fcc332890d11 /prelude
parentb6422e50cb19f7f790f29678ba22f31b0b305511 (diff)
Enhance buffer load specialization pass to specialize past field extracts. (#8547)
This allows us to specialize functions whose argument is a sub element of a constant buffer, instead of being only applicable to entire buffer element. Closes #8421. This change also implements a proper heuristic to determine when to specialize the calls and defer the buffer loads. This PR addresses a pathological case exposed in `slangpy\slangpy\benchmarks\test_benchmark_tensor.py`, which used to take 27ms to finish, and now takes 1.25ms. For example, given: ``` struct Bottom { float bigArray[1024]; [mutating] void setVal(int index, float value) { bigArray[index] = value; } } struct Root { Bottom top[2]; [mutating] void setTopVal(int x, int y, float value) { top[x].setVal(y, value); } } RWStructuredBuffer<Root> sb; [shader("compute")] [numthreads(1, 1, 1)] void compute_main(uint3 tid: SV_DispatchThreadID) { sb[0].setTopVal(1, 2, 100.0f); } ``` We are now able to specialize the call to `setTopVal` into: ``` void compute_main(uint3 tid: SV_DispatchThreadID) { setTopVal_specialized(0, 1, 2, 100.0f); } void setTopVal_specialized(int sbIdx, int x, int y, float value) { Bottom_setVal_specialized(sbIdx, x, y, value); } void Bottom_setVal_specialized(int sbIdx, int x, int y, float value) { sb[sbIdx].top[x].bigArray[y] = value; } ``` And get rid of all unnecessary loads. Achieving this requires a combination of function call specialization and buffer-load-defer pass. The buffer-load-defer pass has been completely rewritten to be more correct and avoid introducing redundant loads. This PR also adds tests to make sure pointers, bindless handles, and loads from structured buffer or constant buffers works as expected.
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cpp-types.h4
-rw-r--r--prelude/slang-cuda-prelude.h4
2 files changed, 4 insertions, 4 deletions
diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h
index 491438c80..26b45d53f 100644
--- a/prelude/slang-cpp-types.h
+++ b/prelude/slang-cpp-types.h
@@ -59,12 +59,12 @@ struct RWStructuredBuffer
template<typename T>
struct StructuredBuffer
{
- SLANG_FORCE_INLINE const T& operator[](size_t index) const
+ SLANG_FORCE_INLINE T& operator[](size_t index) const
{
SLANG_BOUND_CHECK(index, count);
return data[index];
}
- const T& Load(size_t index) const
+ T& Load(size_t index) const
{
SLANG_BOUND_CHECK(index, count);
return data[index];
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 5c5335ac5..6c68cdb71 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -2312,7 +2312,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_max(uintptr_t a, uintptr_t b)
template<typename T>
struct StructuredBuffer
{
- SLANG_CUDA_CALL const T& operator[](size_t index) const
+ SLANG_CUDA_CALL T& operator[](size_t index) const
{
#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
SLANG_BOUND_CHECK(index, count);
@@ -2320,7 +2320,7 @@ struct StructuredBuffer
return data[index];
}
- SLANG_CUDA_CALL const T& Load(size_t index) const
+ SLANG_CUDA_CALL T& Load(size_t index) const
{
#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
SLANG_BOUND_CHECK(index, count);