diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2025-05-07 00:46:42 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-06 22:46:42 -0700 |
| commit | ccdb2e39da37753961f3694d0f90e676bf859006 (patch) | |
| tree | e4dd8cea8e54083283c7728df8654fa5ad4516b2 | |
| parent | 90ecf185a742efffc7e1fcf399961289b3e00d08 (diff) | |
bitcast requires the input to have the same width as the result type (#7018)
bitcast requires its input to have the same width as the result type. This PR ensures that every bitcast IR instruction we lower satisfies this requirement.
Close #7017.
| -rw-r--r-- | source/slang/slang-ir-extract-value-from-type.cpp | 75 | ||||
| -rw-r--r-- | tests/glsl/interger_pack.slang | 47 | ||||
| -rw-r--r-- | tests/language-feature/bit-cast/float-bit-cast.slang | 32 |
3 files changed, 129 insertions, 25 deletions
diff --git a/source/slang/slang-ir-extract-value-from-type.cpp b/source/slang/slang-ir-extract-value-from-type.cpp index 621532a38..8f72c1623 100644 --- a/source/slang/slang-ir-extract-value-from-type.cpp +++ b/source/slang/slang-ir-extract-value-from-type.cpp @@ -17,6 +17,50 @@ struct FindLeafValueResult 0; // The offset in bytes within `leafValue` that contains the requested value. }; +// bitcast the leaf value to the same size as leaf value's type. +// For type that has size smaller than 4 bytes, we will need to cast them +// to 32-bit unsigned int first, and then cast to the target type. +IRInst* bitCastLeafValue(IRBuilder& builder, FindLeafValueResult& leaf) +{ + auto resultValue = leaf.leafValue; + + IRType* intermediateUintType = nullptr; + IRType* targetUintType = nullptr; + switch (leaf.valueSize) + { + case 1: + intermediateUintType = builder.getUInt8Type(); + targetUintType = builder.getUIntType(); + break; + case 2: + intermediateUintType = builder.getUInt16Type(); + targetUintType = builder.getUIntType(); + break; + case 4: + intermediateUintType = builder.getUIntType(); + targetUintType = intermediateUintType; + break; + case 8: + intermediateUintType = builder.getUInt64Type(); + targetUintType = intermediateUintType; + break; + default: + SLANG_UNEXPECTED("Unsupported value size"); + break; + } + resultValue = builder.emitBitCast(intermediateUintType, resultValue); + + // In case of 1-byte or 2-byte value, we need to cast it to 32-bit unsigned int first + // because we don't allow bitCast from 1-byte or 2-byte type to 32-bit type. 
+ if (intermediateUintType != targetUintType) + { + resultValue = builder.emitCast(targetUintType, resultValue); + resultValue = builder.emitBitCast(targetUintType, resultValue); + } + + return resultValue; +} + FindLeafValueResult findLeafValueAtOffset( TargetProgram* targetProgram, IRBuilder& builder, @@ -181,16 +225,9 @@ IRInst* extractByteAtOffset( uint32_t offset) { auto leaf = findLeafValueAtOffset(targetProgram, builder, dataType, layout, src, offset); - IRType* uintType = nullptr; - if (leaf.valueSize <= 4) - { - uintType = builder.getUIntType(); - } - else - { - uintType = builder.getUInt64Type(); - } - auto resultValue = builder.emitBitCast(uintType, leaf.leafValue); + auto resultValue = bitCastLeafValue(builder, leaf); + auto uintType = resultValue->getDataType(); + if (leaf.offsetInValue != 0) { uint32_t shift = leaf.offsetInValue * 8; @@ -217,21 +254,13 @@ IRInst* extractMultiByteValueAtOffset( return extractByteAtOffset(builder, targetProgram, dataType, layout, src, offset); auto leaf = findLeafValueAtOffset(targetProgram, builder, dataType, layout, src, offset); - auto resultValue = leaf.leafValue; - IRType* uintType = nullptr; - if (leaf.valueSize <= 4) - { - uintType = builder.getUIntType(); - } - else - { - uintType = builder.getUInt64Type(); - } if (leaf.valueSize - leaf.offsetInValue >= size) { // The request value is fully contained in the found leaf element. // We can proceed to extract the requested bits from the element. 
- resultValue = builder.emitBitCast(uintType, leaf.leafValue); + auto resultValue = bitCastLeafValue(builder, leaf); + auto uintType = resultValue->getDataType(); + uint32_t shift = leaf.offsetInValue * 8; if (shift > 0) resultValue = @@ -274,6 +303,8 @@ IRInst* extractMultiByteValueAtOffset( src, firstHalfSize, offset); + + auto uintType = firstHalf->getDataType(); switch (firstHalfSize) { case 1: @@ -301,7 +332,7 @@ IRInst* extractMultiByteValueAtOffset( restSize, offset + firstHalfSize); uint32_t shift = firstHalfSize * 8; - resultValue = builder.emitBitOr( + auto resultValue = builder.emitBitOr( builder.getUIntType(), firstHalf, builder.emitShl( diff --git a/tests/glsl/interger_pack.slang b/tests/glsl/interger_pack.slang index 7bf414c4d..cf2c49f9c 100644 --- a/tests/glsl/interger_pack.slang +++ b/tests/glsl/interger_pack.slang @@ -3,8 +3,26 @@ #version 450 -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer -layout(scalar) buffer MyBlockName2 +//TEST_INPUT:ubuffer(data=[0xA802 0x1349 0xC2 0x91 0xB2 0x72], stride=4):out,name=inputBuffer +layout(scalar) buffer MyBlock1 +{ + uint32_t a; + uint32_t b; + + uint32_t c; + uint32_t d; + uint32_t e; + uint32_t f; +} inputBuffer; +// BUF: A802 +// BUF-NEXT: 1349 +// BUF-NEXT: C2 +// BUF-NEXT: 91 +// BUF-NEXT: B2 +// BUF-NEXT: 72 + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +layout(scalar) buffer MyBlock2 { uvec4 a; ivec4 b; @@ -20,6 +38,12 @@ layout(scalar) buffer MyBlockName2 uint32_t i; int32_t j; + + uint32_t k; + int32_t l; + + uint32_t m; + uint32_t n; } outputBuffer; layout(local_size_x = 1) in; @@ -27,7 +51,7 @@ void computeMain() { uint32_t a = 0xF2845678; outputBuffer.a = unpack8(a); - // BUF: 78 + // BUF-NEXT: 78 // BUF-NEXT: 56 // BUF-NEXT: 84 // BUF-NEXT: F2 @@ -74,4 +98,21 @@ void computeMain() i8vec4 j = {0x82, 0x56, 0x12, 0x80}; outputBuffer.j = pack32(j); // BUF-NEXT: 80125682 + + // 
Note: Below tests are mainly to verify that we don't emit invalid spirv code + u16vec2 k = {inputBuffer.a, inputBuffer.b}; + outputBuffer.k = pack32(k); + // BUF-NEXT: 1349A802 + + i16vec2 l = {inputBuffer.a, inputBuffer.b}; + outputBuffer.l = pack32(l); + // BUF-NEXT: 1349A802 + + u8vec4 m = {inputBuffer.c, inputBuffer.d, inputBuffer.e, inputBuffer.f}; + outputBuffer.m = pack32(m); + // BUF-NEXT: 72B291C2 + + u8vec4 n = {inputBuffer.c, inputBuffer.d, inputBuffer.e, inputBuffer.f}; + outputBuffer.n = pack32(n); + // BUF-NEXT: 72B291C2 } diff --git a/tests/language-feature/bit-cast/float-bit-cast.slang b/tests/language-feature/bit-cast/float-bit-cast.slang new file mode 100644 index 000000000..66b6812c4 --- /dev/null +++ b/tests/language-feature/bit-cast/float-bit-cast.slang @@ -0,0 +1,32 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk + +struct MyStruct +{ + half a; + half b; +} + + + +//TEST_INPUT:set inputBuffer = ubuffer(data=[1.0 3.0 4.0], stride=4) +RWStructuredBuffer<float> inputBuffer; + + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +[shader("compute")] +void computeMain() +{ + vector<float, 1> a = {inputBuffer[0]}; + + outputBuffer[0] = bit_cast<uint>(a); + // BUF: 3F800000 + + MyStruct s = MyStruct(half(inputBuffer[1]), half(inputBuffer[2])); + outputBuffer[1] = bit_cast<uint>(s); + // 2.0 : 0x4000 + // 3.0 : 0x4200 + // BUF-NEXT: 44004200 + +} |
