summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-11-06 18:58:26 -0800
committerGitHub <noreply@github.com>2024-11-07 10:58:26 +0800
commit7c195d3b31c85d8b53ad5848b7730bb2be6c6a89 (patch)
tree118b6ab79c2af45281d40b2ada9dcc85c98d8e65
parent65de5452b71a311d66169ea16334e84d7e6465c1 (diff)
Fix CUDA prelude for makeMatrix (#5509)
* Fix CUDA prelude for makeMatrix * Add regression test.
-rw-r--r--prelude/slang-cuda-prelude.h52
-rw-r--r--tests/cuda/make-matrix.slang14
2 files changed, 40 insertions, 26 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 9ac903955..46c6a4394 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -733,12 +733,12 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(
Matrix<T, ROWS, COLS> rs;
if (COLS == 3)
{
- rs.rows[0].x = v0;
- rs.rows[0].y = v1;
- rs.rows[0].z = v2;
- rs.rows[1].x = v3;
- rs.rows[1].y = v4;
- rs.rows[1].z = v5;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 0) = v3;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 1) = v4;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 2) = v5;
}
else
{
@@ -766,14 +766,14 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(
Matrix<T, ROWS, COLS> rs;
if (COLS == 4)
{
- rs.rows[0].x = v0;
- rs.rows[0].y = v1;
- rs.rows[0].z = v2;
- rs.rows[0].w = v3;
- rs.rows[1].x = v4;
- rs.rows[1].y = v5;
- rs.rows[1].z = v6;
- rs.rows[1].w = v7;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 3) = v3;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 0) = v4;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 1) = v5;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 2) = v6;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 3) = v7;
}
else
{
@@ -832,18 +832,18 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(
Matrix<T, ROWS, COLS> rs;
if (COLS == 4)
{
- rs.rows[0].x = v0;
- rs.rows[0].y = v1;
- rs.rows[0].z = v2;
- rs.rows[0].w = v3;
- rs.rows[1].x = v4;
- rs.rows[1].y = v5;
- rs.rows[1].z = v6;
- rs.rows[1].w = v7;
- rs.rows[2].x = v8;
- rs.rows[2].y = v9;
- rs.rows[2].z = v10;
- rs.rows[2].w = v11;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
+ *_slang_vector_get_element_ptr(&rs.rows[0], 3) = v3;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 0) = v4;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 1) = v5;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 2) = v6;
+ *_slang_vector_get_element_ptr(&rs.rows[1], 3) = v7;
+ *_slang_vector_get_element_ptr(&rs.rows[2], 0) = v8;
+ *_slang_vector_get_element_ptr(&rs.rows[2], 1) = v9;
+ *_slang_vector_get_element_ptr(&rs.rows[2], 2) = v10;
+ *_slang_vector_get_element_ptr(&rs.rows[2], 3) = v11;
}
else
{
diff --git a/tests/cuda/make-matrix.slang b/tests/cuda/make-matrix.slang
new file mode 100644
index 000000000..ef6358620
--- /dev/null
+++ b/tests/cuda/make-matrix.slang
@@ -0,0 +1,14 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -compute
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda -compute
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint4x3> outputBuffer : register(u0);
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint idx = dispatchThreadID.x + 1;
+ uint4x3 mat1 = uint4x3(idx, idx, idx, idx, idx, idx, idx, idx, idx, idx, idx, idx);
+ outputBuffer[0] = mat1;
+ // CHECK: 1
+} \ No newline at end of file