diff options
| author | Gangzheng Tong <tonggangzheng@gmail.com> | 2025-07-29 12:12:40 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-29 19:12:40 +0000 |
| commit | 48efc60380aa79e8c4aba13976cc2015f38a659e (patch) | |
| tree | cc13517501c1b7dcdca55fce61a7d19f59a64b43 | |
| parent | 2db6ac97ad62f28c246e8176df52a104bb7c4be9 (diff) | |
Fix Metal invalid as_type cast for 64-bit RWByteAddressBuffer.Store values (#7843)
* Fix 64-bit val lowering for metal
* Add ByteAddressBuffer load/store 64-bit tests
* Handle Store/Load ptr types
* Use bitcast for non-pointer types
* format code (#7966)
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
---------
Co-authored-by: slangbot <ellieh+slangbot@nvidia.com>
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | source/slang/slang-ir-byte-address-legalize.cpp | 63 | ||||
| -rw-r--r-- | tests/compute/byte-address-buffer-64bit.slang | 32 | ||||
| -rw-r--r-- | tests/compute/byte-address-buffer-64bit.slang.expected.txt | 8 | ||||
| -rw-r--r-- | tests/metal/byte-address-buffer.slang | 22 |
4 files changed, 111 insertions, 14 deletions
diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp index 617c8c7c4..baafc7e7f 100644 --- a/source/slang/slang-ir-byte-address-legalize.cpp +++ b/source/slang/slang-ir-byte-address-legalize.cpp @@ -41,7 +41,7 @@ struct ByteAddressBufferLegalizationContext Dictionary<IRInst*, IRType*> byteAddrBufferToReplace; // Everything starts with a request to process a module, - // which delegates to the central recrusive walk of the IR. + // which delegates to the central recursive walk of the IR. // void processModule(IRModule* module) { @@ -178,6 +178,20 @@ struct ByteAddressBufferLegalizationContext return false; } + // For Metal targets, 64-bit integer types need special handling + // because Metal doesn't support as_type casts from 64-bit to 32-bit. + // These types should be lowered to two 32-bit operations. + // + if (m_options.lowerBasicTypeOps) + { + // getSameSizeUIntBaseType should convert any 64-bit types to UInt64 + auto unsignedBaseType = getSameSizeUIntBaseType(type->getOp()); + if (unsignedBaseType == BaseType::UInt64) + { + return false; // Force legalization for 64-bit integer types + } + } + // Otherwise, scalar types are assumed // legal for load/store. // @@ -190,7 +204,7 @@ struct ByteAddressBufferLegalizationContext { // If we've been asked to scalarize all // vector load/store, then we need to - // tread them as illegal. + // treat them as illegal. // if (m_options.scalarizeVectorLoadStore) return false; @@ -223,7 +237,7 @@ struct ByteAddressBufferLegalizationContext else if (auto alignInst = as<IRIntLit>(unknownOffsetAlignment)) { // If the offset is not known during compile time, use the explicit align - // field of the overloaded `Load` or `Store` operation or vi `LoadAligned` + // field of the overloaded `Load` or `Store` operation or via `LoadAligned` // or `StoreAligned` function. 
// // Unaligned `Load`s or `Store`s are identified with 0 alignment, to prevent @@ -348,8 +362,8 @@ struct ByteAddressBufferLegalizationContext } // Once all the field values have been loaded, we can bind - // then together to make a singel value of the `struct` type, - // representing the reuslt of the legalized load. + // then together to make a single value of the `struct` type, + // representing the result of the legalized load. // return m_builder.emitMakeStruct(type, fieldVals); } @@ -358,7 +372,7 @@ struct ByteAddressBufferLegalizationContext // Loading a value of array type amounts to loading each // of its elements. There is shared logic between the // array, matrix, and vector cases, so we factor it into - // a subroutien that we will explain later. + // a subroutine that we will explain later. // // We need a known constant number of elements in an array // to be able to emit per-element loads, so we skip @@ -739,7 +753,18 @@ struct ByteAddressBufferLegalizationContext hi64, m_builder.getIntValue(m_builder.getUInt64Type(), 32)); auto fullValue = m_builder.emitBitOr(m_builder.getUInt64Type(), lo64, shift); - return m_builder.emitBitCast(type, fullValue); + // For pointer types, Metal doesn't allow as_type casts from integers to pointers, + // so we use proper cast operations instead of bit casts. 
+ if (type->getOp() == kIROp_PtrType || type->getOp() == kIROp_RawPointerType) + { + // Use proper cast operation instead of bit cast for pointers + return m_builder.emitCastIntToPtr(type, fullValue); + } + else + { + // For non-pointer 64-bit types (including IntPtr/UIntPtr) + return m_builder.emitBitCast(type, fullValue); + } } else if (sizeAlignment.size < 4) { @@ -1100,7 +1125,7 @@ struct ByteAddressBufferLegalizationContext void processStore(IRInst* store) { - // Just as for loads, the logic for stores is base don the type + // Just as for loads, the logic for stores is base on the type // being used, but unlike in the load case we don't care about // the type of the store operation, but instead the operand // that represents the value to be stored. @@ -1384,16 +1409,30 @@ struct ByteAddressBufferLegalizationContext if (m_options.lowerBasicTypeOps) { // Some platforms e.g. Metal does not allow storing basic types that are not 4-byte - // sized. We need to lower such loads. + // sized. We need to lower such stores. IRSizeAndAlignment sizeAlignment; SLANG_RETURN_ON_FAIL( getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), type, &sizeAlignment)); if (sizeAlignment.size == 8) { // We need to store the value as two 4-byte values. - auto uint64Val = m_builder.emitBitCast(m_builder.getUInt64Type(), value); - auto loVal = m_builder.emitCast(m_builder.getUIntType(), uint64Val); - auto hiVal = m_builder.emitCast( + // For pointer types, Metal doesn't allow as_type casts from pointers to integers, + // so we use proper cast operations instead of bit casts. 
+ IRInst* loVal; + IRInst* hiVal; + IRInst* uint64Val; + if (type->getOp() == kIROp_PtrType || type->getOp() == kIROp_RawPointerType) + { + // Use proper cast operation instead of bit cast for pointers + uint64Val = m_builder.emitCastPtrToInt(value); + } + else + { + // For non-pointer 64-bit types (including IntPtr/UIntPtr) + uint64Val = m_builder.emitBitCast(m_builder.getUInt64Type(), value); + } + loVal = m_builder.emitCast(m_builder.getUIntType(), uint64Val); + hiVal = m_builder.emitCast( m_builder.getUIntType(), m_builder.emitShr( m_builder.getUInt64Type(), diff --git a/tests/compute/byte-address-buffer-64bit.slang b/tests/compute/byte-address-buffer-64bit.slang new file mode 100644 index 000000000..e8d6a42f2 --- /dev/null +++ b/tests/compute/byte-address-buffer-64bit.slang @@ -0,0 +1,32 @@ +// byte-address-buffer.slang + +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX:-d3d12 -compute -shaderobj -profile cs_6_0 -use-dxil +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 1 2 3]):name=inputBuffer +RWByteAddressBuffer inputBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0]):out,name=outputBuffer +RWByteAddressBuffer outputBuffer; + +void testInt64(uint val) +{ + // Load a 32-bit value from the input buffer + uint tmp = inputBuffer.Load(uint(val * 4)); + + // Cast to uint64_t + uint64_t tmp64 = uint64_t(tmp) + 1; + + // Store the result back as uint64_t + outputBuffer.Store(uint(val * 8), tmp64); +} + +[numthreads(4, 1, 1)] +[shader("compute")] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + + testInt64(tid); +} diff --git a/tests/compute/byte-address-buffer-64bit.slang.expected.txt b/tests/compute/byte-address-buffer-64bit.slang.expected.txt new file mode 100644 index 000000000..ac5027985 --- /dev/null +++ b/tests/compute/byte-address-buffer-64bit.slang.expected.txt @@ -0,0 +1,8 @@ +1 +0 +2 +0 +3 +0 +4 +0
\ No newline at end of file diff --git a/tests/metal/byte-address-buffer.slang b/tests/metal/byte-address-buffer.slang index 53c0e0ac5..8514dafe6 100644 --- a/tests/metal/byte-address-buffer.slang +++ b/tests/metal/byte-address-buffer.slang @@ -1,8 +1,6 @@ //TEST:SIMPLE(filecheck=CHECK): -target metal //TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib -uniform RWStructuredBuffer<float> outputBuffer; - RWByteAddressBuffer buffer; // CHECK-ASM: define void @main_kernel @@ -27,4 +25,24 @@ void main_kernel(uint3 tid: SV_DispatchThreadID) // CHECK: {{.*}}[(128U)>>2] = as_type<uint32_t>(({{.*}} & 4294967040U) | (uint([[A]]) << 0U)); // CHECK: {{.*}}[(128U)>>2] = as_type<uint32_t>(({{.*}} & 65535U) | (uint(as_type<ushort>([[H]])) << 16U)); buffer.Store(128, buffer.Load<TestStruct>(0)); + + // CHECK: {{.*}}[(256U)>>2] = as_type<uint32_t>(4294967295U); + // CHECK: {{.*}}[(260U)>>2] = as_type<uint32_t>(4294967295U); + int64_t i64 = -1; + buffer.Store(256, i64); + + // CHECK: {{.*}}[(264U)>>2] = as_type<uint32_t>(123U); + // CHECK: {{.*}}[(268U)>>2] = as_type<uint32_t>(0U); + uint64_t u64 = 123; + buffer.Store(264, u64); + + int64_t* ptr = Ptr<int64_t>(0xFF); + // CHECK: {{.*}}[(272U)>>2] = as_type<uint32_t>({{.*}}); + // CHECK: {{.*}}[(276U)>>2] = as_type<uint32_t>({{.*}}); + buffer.Store(272, ptr); + + // CHECK: {{.*}}[(280U)>>2] = as_type<uint32_t>(4294967295U); + // CHECK: {{.*}}[(284U)>>2] = as_type<uint32_t>(4294967295U); + uintptr_t uintptr_val = (uintptr_t)-1; + buffer.Store(280, uintptr_val); } |
