summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGangzheng Tong <tonggangzheng@gmail.com>2025-07-29 12:12:40 -0700
committerGitHub <noreply@github.com>2025-07-29 19:12:40 +0000
commit48efc60380aa79e8c4aba13976cc2015f38a659e (patch)
treecc13517501c1b7dcdca55fce61a7d19f59a64b43
parent2db6ac97ad62f28c246e8176df52a104bb7c4be9 (diff)
Fix Metal invalid as_type cast for 64-bit RWByteAddressBuffer.Store values (#7843)
* Fix 64-bit val lowering for metal * Add ByteAddressBuffer load/store 64-bit tests * Handle Store/Load ptr types * Use bitcast for non-pointer typers * format code (#7966) Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --------- Co-authored-by: slangbot <ellieh+slangbot@nvidia.com> Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang/slang-ir-byte-address-legalize.cpp63
-rw-r--r--tests/compute/byte-address-buffer-64bit.slang32
-rw-r--r--tests/compute/byte-address-buffer-64bit.slang.expected.txt8
-rw-r--r--tests/metal/byte-address-buffer.slang22
4 files changed, 111 insertions, 14 deletions
diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp
index 617c8c7c4..baafc7e7f 100644
--- a/source/slang/slang-ir-byte-address-legalize.cpp
+++ b/source/slang/slang-ir-byte-address-legalize.cpp
@@ -41,7 +41,7 @@ struct ByteAddressBufferLegalizationContext
Dictionary<IRInst*, IRType*> byteAddrBufferToReplace;
// Everything starts with a request to process a module,
- // which delegates to the central recrusive walk of the IR.
+ // which delegates to the central recursive walk of the IR.
//
void processModule(IRModule* module)
{
@@ -178,6 +178,20 @@ struct ByteAddressBufferLegalizationContext
return false;
}
+ // For Metal targets, 64-bit integer types need special handling
+ // because Metal doesn't support as_type casts from 64-bit to 32-bit.
+ // These types should be lowered to two 32-bit operations.
+ //
+ if (m_options.lowerBasicTypeOps)
+ {
+ // getSameSizeUIntBaseType should convert any 64-bit types to UInt64
+ auto unsignedBaseType = getSameSizeUIntBaseType(type->getOp());
+ if (unsignedBaseType == BaseType::UInt64)
+ {
+ return false; // Force legalization for 64-bit integer types
+ }
+ }
+
// Otherwise, scalar types are assumed
// legal for load/store.
//
@@ -190,7 +204,7 @@ struct ByteAddressBufferLegalizationContext
{
// If we've been asked to scalarize all
// vector load/store, then we need to
- // tread them as illegal.
+ // treat them as illegal.
//
if (m_options.scalarizeVectorLoadStore)
return false;
@@ -223,7 +237,7 @@ struct ByteAddressBufferLegalizationContext
else if (auto alignInst = as<IRIntLit>(unknownOffsetAlignment))
{
// If the offset is not known during compile time, use the explicit align
- // field of the overloaded `Load` or `Store` operation or vi `LoadAligned`
+ // field of the overloaded `Load` or `Store` operation or via `LoadAligned`
// or `StoreAligned` function.
//
// Unaligned `Load`s or `Store`s are identified with 0 alignment, to prevent
@@ -348,8 +362,8 @@ struct ByteAddressBufferLegalizationContext
}
// Once all the field values have been loaded, we can bind
- // then together to make a singel value of the `struct` type,
- // representing the reuslt of the legalized load.
+ // then together to make a single value of the `struct` type,
+ // representing the result of the legalized load.
//
return m_builder.emitMakeStruct(type, fieldVals);
}
@@ -358,7 +372,7 @@ struct ByteAddressBufferLegalizationContext
// Loading a value of array type amounts to loading each
// of its elements. There is shared logic between the
// array, matrix, and vector cases, so we factor it into
- // a subroutien that we will explain later.
+ // a subroutine that we will explain later.
//
// We need a known constant number of elements in an array
// to be able to emit per-element loads, so we skip
@@ -739,7 +753,18 @@ struct ByteAddressBufferLegalizationContext
hi64,
m_builder.getIntValue(m_builder.getUInt64Type(), 32));
auto fullValue = m_builder.emitBitOr(m_builder.getUInt64Type(), lo64, shift);
- return m_builder.emitBitCast(type, fullValue);
+ // For pointer types, Metal doesn't allow as_type casts from integers to pointers,
+ // so we use proper cast operations instead of bit casts.
+ if (type->getOp() == kIROp_PtrType || type->getOp() == kIROp_RawPointerType)
+ {
+ // Use proper cast operation instead of bit cast for pointers
+ return m_builder.emitCastIntToPtr(type, fullValue);
+ }
+ else
+ {
+ // For non-pointer 64-bit types (including IntPtr/UIntPtr)
+ return m_builder.emitBitCast(type, fullValue);
+ }
}
else if (sizeAlignment.size < 4)
{
@@ -1100,7 +1125,7 @@ struct ByteAddressBufferLegalizationContext
void processStore(IRInst* store)
{
- // Just as for loads, the logic for stores is base don the type
+ // Just as for loads, the logic for stores is base on the type
// being used, but unlike in the load case we don't care about
// the type of the store operation, but instead the operand
// that represents the value to be stored.
@@ -1384,16 +1409,30 @@ struct ByteAddressBufferLegalizationContext
if (m_options.lowerBasicTypeOps)
{
// Some platforms e.g. Metal does not allow storing basic types that are not 4-byte
- // sized. We need to lower such loads.
+ // sized. We need to lower such stores.
IRSizeAndAlignment sizeAlignment;
SLANG_RETURN_ON_FAIL(
getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), type, &sizeAlignment));
if (sizeAlignment.size == 8)
{
// We need to store the value as two 4-byte values.
- auto uint64Val = m_builder.emitBitCast(m_builder.getUInt64Type(), value);
- auto loVal = m_builder.emitCast(m_builder.getUIntType(), uint64Val);
- auto hiVal = m_builder.emitCast(
+ // For pointer types, Metal doesn't allow as_type casts from pointers to integers,
+ // so we use proper cast operations instead of bit casts.
+ IRInst* loVal;
+ IRInst* hiVal;
+ IRInst* uint64Val;
+ if (type->getOp() == kIROp_PtrType || type->getOp() == kIROp_RawPointerType)
+ {
+ // Use proper cast operation instead of bit cast for pointers
+ uint64Val = m_builder.emitCastPtrToInt(value);
+ }
+ else
+ {
+ // For non-pointer 64-bit types (including IntPtr/UIntPtr)
+ uint64Val = m_builder.emitBitCast(m_builder.getUInt64Type(), value);
+ }
+ loVal = m_builder.emitCast(m_builder.getUIntType(), uint64Val);
+ hiVal = m_builder.emitCast(
m_builder.getUIntType(),
m_builder.emitShr(
m_builder.getUInt64Type(),
diff --git a/tests/compute/byte-address-buffer-64bit.slang b/tests/compute/byte-address-buffer-64bit.slang
new file mode 100644
index 000000000..e8d6a42f2
--- /dev/null
+++ b/tests/compute/byte-address-buffer-64bit.slang
@@ -0,0 +1,32 @@
+// byte-address-buffer.slang
+
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX:-d3d12 -compute -shaderobj -profile cs_6_0 -use-dxil
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 1 2 3]):name=inputBuffer
+RWByteAddressBuffer inputBuffer;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0]):out,name=outputBuffer
+RWByteAddressBuffer outputBuffer;
+
+void testInt64(uint val)
+{
+ // Load a 32-bit value from the input buffer
+ uint tmp = inputBuffer.Load(uint(val * 4));
+
+ // Cast to uint64_t
+ uint64_t tmp64 = uint64_t(tmp) + 1;
+
+ // Store the result back as uint64_t
+ outputBuffer.Store(uint(val * 8), tmp64);
+}
+
+[numthreads(4, 1, 1)]
+[shader("compute")]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+
+ testInt64(tid);
+}
diff --git a/tests/compute/byte-address-buffer-64bit.slang.expected.txt b/tests/compute/byte-address-buffer-64bit.slang.expected.txt
new file mode 100644
index 000000000..ac5027985
--- /dev/null
+++ b/tests/compute/byte-address-buffer-64bit.slang.expected.txt
@@ -0,0 +1,8 @@
+1
+0
+2
+0
+3
+0
+4
+0 \ No newline at end of file
diff --git a/tests/metal/byte-address-buffer.slang b/tests/metal/byte-address-buffer.slang
index 53c0e0ac5..8514dafe6 100644
--- a/tests/metal/byte-address-buffer.slang
+++ b/tests/metal/byte-address-buffer.slang
@@ -1,8 +1,6 @@
//TEST:SIMPLE(filecheck=CHECK): -target metal
//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib
-uniform RWStructuredBuffer<float> outputBuffer;
-
RWByteAddressBuffer buffer;
// CHECK-ASM: define void @main_kernel
@@ -27,4 +25,24 @@ void main_kernel(uint3 tid: SV_DispatchThreadID)
// CHECK: {{.*}}[(128U)>>2] = as_type<uint32_t>(({{.*}} & 4294967040U) | (uint([[A]]) << 0U));
// CHECK: {{.*}}[(128U)>>2] = as_type<uint32_t>(({{.*}} & 65535U) | (uint(as_type<ushort>([[H]])) << 16U));
buffer.Store(128, buffer.Load<TestStruct>(0));
+
+ // CHECK: {{.*}}[(256U)>>2] = as_type<uint32_t>(4294967295U);
+ // CHECK: {{.*}}[(260U)>>2] = as_type<uint32_t>(4294967295U);
+ int64_t i64 = -1;
+ buffer.Store(256, i64);
+
+ // CHECK: {{.*}}[(264U)>>2] = as_type<uint32_t>(123U);
+ // CHECK: {{.*}}[(268U)>>2] = as_type<uint32_t>(0U);
+ uint64_t u64 = 123;
+ buffer.Store(264, u64);
+
+ int64_t* ptr = Ptr<int64_t>(0xFF);
+ // CHECK: {{.*}}[(272U)>>2] = as_type<uint32_t>({{.*}});
+ // CHECK: {{.*}}[(276U)>>2] = as_type<uint32_t>({{.*}});
+ buffer.Store(272, ptr);
+
+ // CHECK: {{.*}}[(280U)>>2] = as_type<uint32_t>(4294967295U);
+ // CHECK: {{.*}}[(284U)>>2] = as_type<uint32_t>(4294967295U);
+ uintptr_t uintptr_val = (uintptr_t)-1;
+ buffer.Store(280, uintptr_val);
}