diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 52 |
5 files changed, 80 insertions, 0 deletions
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index e9e5ea4f3..2fb10db1f 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -117,6 +117,17 @@ class DefaultStmt : public CaseStmtBase SLANG_CLASS(DefaultStmt) }; +// a `default` statement inside a `switch` +class GpuForeachStmt : public ScopeStmt +{ + SLANG_CLASS(GpuForeachStmt) + + Expr* renderer = nullptr; + Expr* gridDims = nullptr; + VarDecl* dispatchThreadID = nullptr; + Expr* kernelCall = nullptr; +}; + // A statement that represents a loop, and can thus be escaped with a `continue` class LoopStmt : public BreakableStmt { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 97e77ec3e..19f5553c7 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1596,6 +1596,8 @@ namespace Slang void visitReturnStmt(ReturnStmt *stmt); void visitWhileStmt(WhileStmt *stmt); + + void visitGpuForeachStmt(GpuForeachStmt *stmt); void visitExpressionStmt(ExpressionStmt *stmt); }; diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 2d01086f1..9a5aee15c 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -288,4 +288,13 @@ namespace Slang stmt->expression = CheckExpr(stmt->expression); } + void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt) + { + stmt->renderer = CheckExpr(stmt->renderer); + stmt->gridDims = CheckExpr(stmt->gridDims); + ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked); + WithOuterStmt subContext(this, stmt); + stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall); + return; + } } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a4542647a..9c4808f31 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4114,6 +4114,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> insertBlock(breakLabel); } + void visitGpuForeachStmt(GpuForeachStmt* stmt) + { + startBlockIfNeeded(stmt); + return; + } + void visitExpressionStmt(ExpressionStmt* stmt) { startBlockIfNeeded(stmt); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 460f781fe..179587550 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3380,6 +3380,56 @@ namespace Slang return stmt; } + GpuForeachStmt* ParseGpuForeachStmt(Parser* parser) + { + // Hard-coding parsing of the following: + // __GPU_FOREACH(renderer, gridDims, LAMBDA(uint3 dispatchThreadID) { + // kernelCall(args, ...); }); + + // Setup the scope so that dispatchThreadID is in scope for kernelCall + ScopeDecl* scopeDecl = parser->astBuilder->create<ScopeDecl>(); + GpuForeachStmt* stmt = parser->astBuilder->create<GpuForeachStmt>(); + stmt->scopeDecl = scopeDecl; + + parser->FillPosition(stmt); + parser->ReadToken("__GPU_FOREACH"); + parser->ReadToken(TokenType::LParent); + stmt->renderer = parser->ParseArgExpr(); + parser->ReadToken(TokenType::Comma); + stmt->gridDims = parser->ParseArgExpr(); + + parser->ReadToken(TokenType::Comma); + parser->ReadToken("LAMBDA"); + parser->ReadToken(TokenType::LParent); + + auto idType = parser->ParseTypeExp(); + NameLoc varNameAndLoc = expectIdentifier(parser); + VarDecl* varDecl = parser->astBuilder->create<VarDecl>(); + varDecl->nameAndLoc = varNameAndLoc; + varDecl->loc = varNameAndLoc.loc; + varDecl->type = idType; + stmt->dispatchThreadID = varDecl; + + parser->ReadToken(TokenType::RParent); + parser->ReadToken(TokenType::LBrace); + + parser->pushScopeAndSetParent(scopeDecl); + AddMember(parser->currentScope, varDecl); + + stmt->kernelCall = parser->ParseExpression(); + + parser->PopScope(); + + parser->ReadToken(TokenType::Semicolon); + parser->ReadToken(TokenType::RBrace); + + parser->ReadToken(TokenType::RParent); + + parser->ReadToken(TokenType::Semicolon); + + return stmt; + } + static bool _isType(Decl* decl) { return decl && (as<AggTypeDecl>(decl) || as<SimpleTypeDecl>(decl)); @@ -3552,6 +3602,8 @@ namespace Slang statement = ParseCaseStmt(this); else if (LookAheadToken("default")) statement = ParseDefaultStmt(this); + else if (LookAheadToken("__GPU_FOREACH")) + statement = ParseGpuForeachStmt(this); else if (LookAheadToken(TokenType::Dollar)) { statement = parseCompileTimeStmt(this); |
