diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2024-04-17 23:23:15 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-17 23:23:15 -0700 |
| commit | d3fd7470e6b71aa080415a3a7c207faebe21b00f (patch) | |
| tree | de3b9c1642fbea11392bfaf170de31a6d80b2826 | |
| parent | 5dd27a26da9b6b6191f3b1eba0f38f85714c1ae3 (diff) | |
Implement if(let ...) syntax (#3673) (#3958)
| -rw-r--r-- | docs/user-guide/06-interfaces-generics.md | 58 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 109 | ||||
| -rw-r--r-- | tests/language-feature/if-let/if-let-1.slang | 89 | ||||
| -rw-r--r-- | tests/language-feature/if-let/if-let-diagnose-1.slang | 53 | ||||
| -rw-r--r-- | tests/language-feature/if-let/if-let-diagnose.slang | 42 | ||||
| -rw-r--r-- | tests/language-feature/if-let/if-let.slang | 38 |
6 files changed, 388 insertions, 1 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index ba543df56..cb6a5070e 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -677,6 +677,64 @@ T compute<T>(T a1, T a2) // compute(3, 1) == 2 ``` +`as` operator can also be used in the `if` predicate to test if an object can be casted to a specific type, once the cast test is successful, +the object can be used in the `if` block as the casted type without the need to retrieve the `Optional<T>::value` property: +```csharp +interface IFoo +{ + void foo(); +} + +struct MyImpl1 : IFoo +{ + void foo() { printf("MyImpl1");} +} + +struct MyImpl2 : IFoo +{ + void foo() { printf("MyImpl2");} +} + +struct MyImpl3 : IFoo +{ + void foo() { printf("MyImpl3");} +} + +void test(IFoo foo) +{ + // This syntax will be desugared to the following: + // { + // Optional<MyImpl1> $OptVar = foo as MyImpl1; + // if ($OptVar.hasValue) + // { + // MyImpl1 t = $OptVar.value; + // t.foo(); + // } + // else if ... + // } + if (let t = foo as MyImpl1) // t is of type MyImpl1 + { + t.foo(); + } + else if (let t = foo as MyImpl2) // t is of type MyImpl2 + { + t.foo(); + } + else + printf("fail"); +} + +void main() +{ + MyImpl1 v1; + test(v1); + + MyImpl2 v2; + test(v2); +} + +``` + Extensions to Interfaces ----------------------------- diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index e564b4956..2c304a663 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -199,6 +199,7 @@ namespace Slang Stmt* parseLabelStatement(); DeclStmt* parseVarDeclrStatement(Modifiers modifiers); IfStmt* parseIfStatement(); + Stmt* parseIfLetStatement(); ForStmt* ParseForStatement(); WhileStmt* ParseWhileStatement(); DoWhileStmt* ParseDoWhileStatement(); @@ -5276,7 +5277,16 @@ namespace Slang if (LookAheadToken(TokenType::LBrace)) statement = parseBlockStatement(); else if (LookAheadToken("if")) - statement = parseIfStatement(); + { + if(LookAheadToken("let", 2)) + { + statement = parseIfLetStatement(); + } + else + { + statement = parseIfStatement(); + } + } else if (LookAheadToken("for")) statement = ParseForStatement(); else if (LookAheadToken("while")) @@ -5579,6 +5589,103 @@ namespace Slang return varDeclrStatement; } + static Expr* constructIfLetPredicate(Parser* parser, VarExpr* varExpr) + { + // create a "var.hasValue" expression + MemberExpr* memberExpr = parser->astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = varExpr; + parser->FillPosition(memberExpr); + memberExpr->name = getName(parser, "hasValue"); + + return memberExpr; + } + + // Parse the syntax 'if (let var = X as Y)' + Stmt* Parser::parseIfLetStatement() + { + ScopeDecl* scopeDecl = astBuilder->create<ScopeDecl>(); + pushScopeAndSetParent(scopeDecl); + + SeqStmt* newBody = astBuilder->create<SeqStmt>(); + + IfStmt* ifStatement = astBuilder->create<IfStmt>(); + FillPosition(ifStatement); + ReadToken("if"); + ReadToken(TokenType::LParent); + + // parse 'let var = X as Y' + ReadToken("let"); + auto identifierToken = ReadToken(TokenType::Identifier); + ReadToken(TokenType::OpAssign); + auto initExpr = ParseInitExpr(); + + // insert 'let tempVarDecl = X as Y;' + auto tempVarDecl = astBuilder->create<LetDecl>(); + tempVarDecl->nameAndLoc = NameLoc(getName(this, "$OptVar"), identifierToken.loc); + tempVarDecl->initExpr = initExpr; + AddMember(currentScope->containerDecl, tempVarDecl); + + DeclStmt* tmpVarDeclStmt = astBuilder->create<DeclStmt>(); + FillPosition(tmpVarDeclStmt); + tmpVarDeclStmt->decl = tempVarDecl; + newBody->stmts.add(tmpVarDeclStmt); + + // construct 'if (tempVarDecl.hasValue == true)' + VarExpr* tempVarExpr = astBuilder->create<VarExpr>(); + tempVarExpr->scope = currentScope; + FillPosition(tempVarExpr); + tempVarExpr->name = tempVarDecl->getName(); + ifStatement->predicate = constructIfLetPredicate(this, tempVarExpr); + + ReadToken(TokenType::RParent); + + // Create a new scope surrounding the positive statement, will be used for + // the variable declared in the if_let syntax + ScopeDecl* positiveScopeDecl = astBuilder->create<ScopeDecl>(); + pushScopeAndSetParent(positiveScopeDecl); + ifStatement->positiveStatement = ParseStatement(ifStatement); + PopScope(); + + if (LookAheadToken("else")) + { + ReadToken("else"); + ifStatement->negativeStatement = ParseStatement(ifStatement); + } + + if (ifStatement->positiveStatement) + { + auto seqPositiveStmt = as<SeqStmt>(ifStatement->positiveStatement); + if (!seqPositiveStmt) + { + seqPositiveStmt = astBuilder->create<SeqStmt>(); + } + + MemberExpr* memberExpr = astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = tempVarExpr; + memberExpr->name = getName(this, "value"); + + auto varDecl = astBuilder->create<LetDecl>(); + varDecl->nameAndLoc = NameLoc(identifierToken.getName(), identifierToken.loc); + varDecl->initExpr = memberExpr; + + DeclStmt* varDeclrStatement = astBuilder->create<DeclStmt>(); + varDeclrStatement->decl = varDecl; + + // Add scope to the variable declared in the if_let syntax such + // that this variable cannot be used outside the positive statement + AddMember(positiveScopeDecl, varDecl); + + seqPositiveStmt->stmts.add(varDeclrStatement); + seqPositiveStmt->stmts.add(ifStatement->positiveStatement); + ifStatement->positiveStatement = seqPositiveStmt; + } + + newBody->stmts.add(ifStatement); + PopScope(); + + return newBody; + } + IfStmt* Parser::parseIfStatement() { IfStmt* ifStatement = astBuilder->create<IfStmt>(); diff --git a/tests/language-feature/if-let/if-let-1.slang b/tests/language-feature/if-let/if-let-1.slang new file mode 100644 index 000000000..c197f26ea --- /dev/null +++ b/tests/language-feature/if-let/if-let-1.slang @@ -0,0 +1,89 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cuda -compute -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + + +interface IFoo +{ + int foo(int a); +} + +struct MyImpl1 : IFoo +{ + int foo(int a) { return a; } +} + +struct MyImpl2 : IFoo +{ + int foo(int a) { return a + 5; } +} + +int test(IFoo foo, int idx) +{ + int val = 0; + if (let a = foo as MyImpl1) + { + val = a.foo(idx); + } + else if (let a = foo as MyImpl2) + { + val = a.foo(idx); + } + return (val); +} + +int test1<T>(T t) +{ + if (let a = t as uint) + { + return 1; + } + else if(let a = t as float) + { + return 2; + } + else if (let a = t as double) + { + return 3; + } + else if (let a = t as int) + { + return 4; + } + else if (let a = t as uint64_t) + { + return 5; + } + else + { + return 6; + } +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + MyImpl1 impl1; + MyImpl2 impl2; + // CHECK: 1 + // CHECK: 7 + outputBuffer[0] = test(impl1, 1); + outputBuffer[1] = test(impl2, 2); + + // CHECK: 1 + outputBuffer[2] = test1(2U); + // CHECK: 2 + outputBuffer[3] = test1(2.0f); + // CHECK: 3 + outputBuffer[4] = test1(2.0lf); + // CHECK: 4 + outputBuffer[5] = test1(2); + // CHECK: 5 + outputBuffer[6] = test1(2LLU); + // CHECK: 6 + outputBuffer[7] = test1(impl1); +} diff --git a/tests/language-feature/if-let/if-let-diagnose-1.slang b/tests/language-feature/if-let/if-let-diagnose-1.slang new file mode 100644 index 000000000..2322abb1a --- /dev/null +++ b/tests/language-feature/if-let/if-let-diagnose-1.slang @@ -0,0 +1,53 @@ +//TEST:SIMPLE(filecheck=CHECK): -target glsl -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=CHECK): -target cuda -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=CHECK): -target cpp -stage compute -entry computeMain + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + + +interface IFoo +{ + int foo(int a); +} + +struct MyImpl : IFoo +{ + int foo(int a) { return a; } +} + +struct MyImpl1 : IFoo +{ + int foo(int a) { return a; } +} + +int test(IFoo foo, int idx) +{ + int val = 0; + if (let a = foo as MyImpl) + { + val = a.foo(idx); + } + // CHECK: error 30015: undefined identifier 'a'. + else if(a == none) + { + val = -1; + } + else + { + // CHECK: error 30015: undefined identifier 'a'. + if (a == none) + { + val = -1; + } + } + return (val); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + MyImpl1 impl; + outputBuffer[dispatchThreadID.x] = test(impl, dispatchThreadID.x); +} diff --git a/tests/language-feature/if-let/if-let-diagnose.slang b/tests/language-feature/if-let/if-let-diagnose.slang new file mode 100644 index 000000000..4e9aa69c7 --- /dev/null +++ b/tests/language-feature/if-let/if-let-diagnose.slang @@ -0,0 +1,42 @@ +//TEST:SIMPLE(filecheck=CHECK): -target glsl -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=CHECK): -target cuda -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=CHECK): -target cpp -stage compute -entry computeMain + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + + +interface IFoo +{ + int foo(int a); +} + +struct MyImpl : IFoo +{ + int foo(int a) { return a; } +} + +struct MyImpl1 : IFoo +{ + int foo(int a) { return a; } +} + +int test(IFoo foo, int idx) +{ + int val = 0; + // CHECK: error 20002: syntax error. + if ((let a = foo as MyImpl)) + { + val = a.foo(idx); + } + return (val); +} + + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + MyImpl impl; + outputBuffer[dispatchThreadID.x] = test(impl, dispatchThreadID.x); +} diff --git a/tests/language-feature/if-let/if-let.slang b/tests/language-feature/if-let/if-let.slang new file mode 100644 index 000000000..5bb31030f --- /dev/null +++ b/tests/language-feature/if-let/if-let.slang @@ -0,0 +1,38 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cuda -compute -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + + +interface IFoo +{ + int foo(int a); +} + +struct MyImpl : IFoo +{ + int foo(int a) { return a; } +} + +int test(IFoo foo, int idx) +{ + int val = 0; + if (let a = foo as MyImpl) + { + val = a.foo(idx); + } + return (val); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + MyImpl impl; + // CHECK: 0 + // CHECK: 1 + // CHECK: 2 + // CHECK: 3 + outputBuffer[dispatchThreadID.x] = test(impl, dispatchThreadID.x); +} |
