summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2024-07-09 12:45:57 -0700
committerGitHub <noreply@github.com>2024-07-09 12:45:57 -0700
commit1caef5907d0b0f16f686a8fcca479c6afc09f146 (patch)
tree55c30b0957b41c5fe1f2b655e27ab13fd6e54ca5
parentddd14be7a7e807a124a29221d53a5e83f92c570a (diff)
Fix Lexer to recognize swizzling on an integer scalar value (#4515)
* Fix Lexer to recognize swizzling on an integer scalar value Close #4413
-rw-r--r--source/compiler-core/slang-lexer.cpp78
-rw-r--r--tests/hlsl-intrinsic/scalar-swizzling.slang80
2 files changed, 133 insertions, 25 deletions
diff --git a/source/compiler-core/slang-lexer.cpp b/source/compiler-core/slang-lexer.cpp
index 5954dc668..8c428159c 100644
--- a/source/compiler-core/slang-lexer.cpp
+++ b/source/compiler-core/slang-lexer.cpp
@@ -173,38 +173,49 @@ namespace Slang
// Look ahead one code point, dealing with complications like
// escaped newlines.
- static int _peek(Lexer* lexer)
+ static int _peek(Lexer* lexer, int offset = 0)
{
- // Look at the next raw byte, and decide what to do
- int c = _peekRaw(lexer);
+ int pos = 0;
+ int c = kEOF;
- if(c == '\\')
+ do
{
- // We might have a backslash-escaped newline.
- // Look at the next byte (if any) to see.
- //
- // Note(tfoley): We are assuming a null-terminated input here,
- // so that we can safely look at the next byte without issue.
- int d = lexer->m_cursor[1];
- switch (d)
+ if (lexer->m_cursor + pos == lexer->m_end)
+ return kEOF;
+
+ c = lexer->m_cursor[pos++];
+
+ if (c == '\\')
{
- case '\r': case '\n':
+ // We might have a backslash-escaped newline.
+ // Look at the next byte (if any) to see.
+ //
+ // Note(tfoley): We are assuming a null-terminated input here,
+ // so that we can safely look at the next byte without issue.
+ int d = lexer->m_cursor[pos++];
+ switch (d)
+ {
+ case '\r': case '\n':
{
// The newline was escaped, so return the code point after *that*
- int e = lexer->m_cursor[2];
+ int e = lexer->m_cursor[pos++];
if ((d ^ e) == ('\r' ^ '\n'))
- return lexer->m_cursor[3];
- return e;
+ c = lexer->m_cursor[pos++];
+ else
+ c = e;
+ break;
}
- default:
- break;
+ default:
+ break;
+ }
}
- }
- // TODO: handle UTF-8 encoding for non-ASCII code points here
+ // TODO: handle UTF-8 encoding for non-ASCII code points here
+
+ // Default case is to just hand along the byte we read as an ASCII code point.
+ } while (offset--);
- // Default case is to just hand along the byte we read as an ASCII code point.
return c;
}
@@ -494,10 +505,19 @@ namespace Slang
if( _peek(lexer) == '.' )
{
- tokenType = TokenType::FloatingPointLiteral;
+ switch (_peek(lexer, 1))
+ {
+ // 123.xxxx or 123.rrrr
+ case 'x':
+ case 'r':
+ break;
- _advance(lexer);
- _lexDigits(lexer, base);
+ default:
+ tokenType = TokenType::FloatingPointLiteral;
+
+ _advance(lexer);
+ _lexDigits(lexer, base);
+ }
}
if( _maybeLexNumberExponent(lexer, base))
@@ -1089,8 +1109,16 @@ namespace Slang
return _maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral);
case '.':
- _advance(lexer);
- return _lexNumberAfterDecimalPoint(lexer, 10);
+ switch (_peek(lexer, 1))
+ {
+ // 0.xxxx or 0.rrrr
+ case 'x':
+ case 'r':
+ return _maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral);
+ default:
+ _advance(lexer);
+ return _lexNumberAfterDecimalPoint(lexer, 10);
+ }
case 'x': case 'X':
_advance(lexer);
diff --git a/tests/hlsl-intrinsic/scalar-swizzling.slang b/tests/hlsl-intrinsic/scalar-swizzling.slang
new file mode 100644
index 000000000..9ca024755
--- /dev/null
+++ b/tests/hlsl-intrinsic/scalar-swizzling.slang
@@ -0,0 +1,80 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT: ubuffer(data=[0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ bool result = true
+ && (0.x is int)
+ && 0 == 0.x
+ && all(uint2(0) == 0.xx)
+ && all(uint3(0) == 0.xxx)
+ && all(uint4(0) == 0.xxxx)
+ && (0.r is int)
+ && 0 == 0.r
+ && all(uint2(0) == 0.rr)
+ && all(uint3(0) == 0.rrr)
+ && all(uint4(0) == 0.rrrr)
+
+ && (123.x is int)
+ && 123 == 123.x
+ && all(uint2(123) == 123.xx)
+ && all(uint3(123) == 123.xxx)
+ && all(uint4(123) == 123.xxxx)
+ && (123.r is int)
+ && 123 == 123.r
+ && all(uint2(123) == 123.rr)
+ && all(uint3(123) == 123.rrr)
+ && all(uint4(123) == 123.rrrr)
+
+ && (0.f.x is float)
+ && 0.f == 0.f.x
+ && all(float2(0.f) == 0.f.xx)
+ && all(float3(0.f) == 0.f.xxx)
+ && all(float4(0.f) == 0.f.xxxx)
+ && (0.f.r is float)
+ && 0.f == 0.f.r
+ && all(float2(0.f) == 0.f.rr)
+ && all(float3(0.f) == 0.f.rrr)
+ && all(float4(0.f) == 0.f.rrrr)
+
+ && (123.f.x is float)
+ && 123.f == 123.f.x
+ && all(float2(123.f) == 123.f.xx)
+ && all(float3(123.f) == 123.f.xxx)
+ && all(float4(123.f) == 123.f.xxxx)
+ && (123.f.r is float)
+ && 123.f == 123.f.r
+ && all(float2(123.f) == 123.f.rr)
+ && all(float3(123.f) == 123.f.rrr)
+ && all(float4(123.f) == 123.f.rrrr)
+
+ && (0..x is float)
+ && 0.f == 0..x
+ && all(float2(0.f) == 0..xx)
+ && all(float3(0.f) == 0..xxx)
+ && all(float4(0.f) == 0..xxxx)
+ && (0..r is float)
+ && 0.f == 0..r
+ && all(float2(0.f) == 0..rr)
+ && all(float3(0.f) == 0..rrr)
+ && all(float4(0.f) == 0..rrrr)
+
+ && (123..x is float)
+ && 123.f == 123..x
+ && all(float2(123.f) == 123..xx)
+ && all(float3(123.f) == 123..xxx)
+ && all(float4(123.f) == 123..xxxx)
+ && (123..r is float)
+ && 123.f == 123..r
+ && all(float2(123.f) == 123..rr)
+ && all(float3(123.f) == 123..rrr)
+ && all(float4(123.f) == 123..rrrr)
+ ;
+
+ //CHK:1
+ outputBuffer[0] = int(result);
+}
+