diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2024-07-09 12:45:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-09 12:45:57 -0700 |
| commit | 1caef5907d0b0f16f686a8fcca479c6afc09f146 (patch) | |
| tree | 55c30b0957b41c5fe1f2b655e27ab13fd6e54ca5 | |
| parent | ddd14be7a7e807a124a29221d53a5e83f92c570a (diff) | |
Fix Lexer to recognize swizzling on an integer scalar value (#4515)
* Fix Lexer to recognize swizzling on an integer scalar value
Close #4413
| -rw-r--r-- | source/compiler-core/slang-lexer.cpp | 78 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/scalar-swizzling.slang | 80 |
2 files changed, 133 insertions, 25 deletions
diff --git a/source/compiler-core/slang-lexer.cpp b/source/compiler-core/slang-lexer.cpp index 5954dc668..8c428159c 100644 --- a/source/compiler-core/slang-lexer.cpp +++ b/source/compiler-core/slang-lexer.cpp @@ -173,38 +173,49 @@ namespace Slang // Look ahead one code point, dealing with complications like // escaped newlines. - static int _peek(Lexer* lexer) + static int _peek(Lexer* lexer, int offset = 0) { - // Look at the next raw byte, and decide what to do - int c = _peekRaw(lexer); + int pos = 0; + int c = kEOF; - if(c == '\\') + do { - // We might have a backslash-escaped newline. - // Look at the next byte (if any) to see. - // - // Note(tfoley): We are assuming a null-terminated input here, - // so that we can safely look at the next byte without issue. - int d = lexer->m_cursor[1]; - switch (d) + if (lexer->m_cursor + pos == lexer->m_end) + return kEOF; + + c = lexer->m_cursor[pos++]; + + if (c == '\\') { - case '\r': case '\n': + // We might have a backslash-escaped newline. + // Look at the next byte (if any) to see. + // + // Note(tfoley): We are assuming a null-terminated input here, + // so that we can safely look at the next byte without issue. + int d = lexer->m_cursor[pos++]; + switch (d) + { + case '\r': case '\n': { // The newline was escaped, so return the code point after *that* - int e = lexer->m_cursor[2]; + int e = lexer->m_cursor[pos++]; if ((d ^ e) == ('\r' ^ '\n')) - return lexer->m_cursor[3]; - return e; + c = lexer->m_cursor[pos++]; + else + c = e; + break; } - default: - break; + default: + break; + } } - } - // TODO: handle UTF-8 encoding for non-ASCII code points here + // TODO: handle UTF-8 encoding for non-ASCII code points here + + // Default case is to just hand along the byte we read as an ASCII code point. + } while (offset--); - // Default case is to just hand along the byte we read as an ASCII code point. return c; } @@ -494,10 +505,19 @@ namespace Slang if( _peek(lexer) == '.' ) { - tokenType = TokenType::FloatingPointLiteral; + switch (_peek(lexer, 1)) + { + // 123.xxxx or 123.rrrr + case 'x': + case 'r': + break; - _advance(lexer); - _lexDigits(lexer, base); + default: + tokenType = TokenType::FloatingPointLiteral; + + _advance(lexer); + _lexDigits(lexer, base); + } } if( _maybeLexNumberExponent(lexer, base)) @@ -1089,8 +1109,16 @@ namespace Slang return _maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral); case '.': - _advance(lexer); - return _lexNumberAfterDecimalPoint(lexer, 10); + switch (_peek(lexer, 1)) + { + // 0.xxxx or 0.rrrr + case 'x': + case 'r': + return _maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral); + default: + _advance(lexer); + return _lexNumberAfterDecimalPoint(lexer, 10); + } case 'x': case 'X': _advance(lexer); diff --git a/tests/hlsl-intrinsic/scalar-swizzling.slang b/tests/hlsl-intrinsic/scalar-swizzling.slang new file mode 100644 index 000000000..9ca024755 --- /dev/null +++ b/tests/hlsl-intrinsic/scalar-swizzling.slang @@ -0,0 +1,80 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT: ubuffer(data=[0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + bool result = true + && (0.x is int) + && 0 == 0.x + && all(uint2(0) == 0.xx) + && all(uint3(0) == 0.xxx) + && all(uint4(0) == 0.xxxx) + && (0.r is int) + && 0 == 0.r + && all(uint2(0) == 0.rr) + && all(uint3(0) == 0.rrr) + && all(uint4(0) == 0.rrrr) + + && (123.x is int) + && 123 == 123.x + && all(uint2(123) == 123.xx) + && all(uint3(123) == 123.xxx) + && all(uint4(123) == 123.xxxx) + && (123.r is int) + && 123 == 123.r + && all(uint2(123) == 123.rr) + && all(uint3(123) == 123.rrr) + && all(uint4(123) == 123.rrrr) + + && (0.f.x is float) + && 0.f == 0.f.x + && all(float2(0.f) == 0.f.xx) + && all(float3(0.f) == 0.f.xxx) + && all(float4(0.f) == 0.f.xxxx) + && (0.f.r is float) + && 0.f == 0.f.r + && all(float2(0.f) == 0.f.rr) + && all(float3(0.f) == 0.f.rrr) + && all(float4(0.f) == 0.f.rrrr) + + && (123.f.x is float) + && 123.f == 123.f.x + && all(float2(123.f) == 123.f.xx) + && all(float3(123.f) == 123.f.xxx) + && all(float4(123.f) == 123.f.xxxx) + && (123.f.r is float) + && 123.f == 123.f.r + && all(float2(123.f) == 123.f.rr) + && all(float3(123.f) == 123.f.rrr) + && all(float4(123.f) == 123.f.rrrr) + + && (0..x is float) + && 0.f == 0..x + && all(float2(0.f) == 0..xx) + && all(float3(0.f) == 0..xxx) + && all(float4(0.f) == 0..xxxx) + && (0..r is float) + && 0.f == 0..r + && all(float2(0.f) == 0..rr) + && all(float3(0.f) == 0..rrr) + && all(float4(0.f) == 0..rrrr) + + && (123..x is float) + && 123.f == 123..x + && all(float2(123.f) == 123..xx) + && all(float3(123.f) == 123..xxx) + && all(float4(123.f) == 123..xxxx) + && (123..r is float) + && 123.f == 123..r + && all(float2(123.f) == 123..rr) + && all(float3(123.f) == 123..rrr) + && all(float4(123.f) == 123..rrrr) + ; + + //CHK:1 + outputBuffer[0] = int(result); +} + |
