From c787c4b82ba76f87069911f203eb192060b5264f Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 28 Aug 2023 21:24:49 -0700 Subject: Add `target_switch` and `intrinsic_asm` statement. (#3154) * Add `target_switch` and `__intrinsic_asm` statement. * Cleanup. * WaveGetActiveMask, WaveGetActiveMask, WaveCountBits. * WaveIsFirstLane. * More wave intrinsics. * wave intrinsics. * merge fix. * Fix. * Fix. * Update test. * update test. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-parser.cpp | 126 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 5 deletions(-) (limited to 'source/slang/slang-parser.cpp') diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 580215fc7..57a21a90a 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -4202,6 +4202,108 @@ namespace Slang return stmt; } + static Stmt* parseTargetSwitchStmt(Parser* parser) + { + TargetSwitchStmt* stmt = parser->astBuilder->create(); + parser->FillPosition(stmt); + parser->ReadToken(); + if (!beginMatch(parser, MatchedTokenType::CurlyBraces)) + { + return stmt; + } + Token closingBraceToken; + while (!AdvanceIfMatch(parser, MatchedTokenType::CurlyBraces, &closingBraceToken)) + { + List caseNames; + for (;;) + { + if (parser->LookAheadToken("case")) + { + parser->ReadToken(); + caseNames.add(parser->ReadToken()); + parser->ReadToken(TokenType::Colon); + } + else if (parser->LookAheadToken("default")) + { + auto token = parser->ReadToken(); + parser->ReadToken(TokenType::Colon); + token.setContent(UnownedStringSlice("")); + caseNames.add(token); + } + else + break; + } + if (caseNames.getCount() == 0) + { + parser->sink->diagnose( + parser->tokenReader.peekLoc(), + Diagnostics::unexpectedTokenExpectedTokenType, + parser->tokenReader.peekToken(), + "'case' or 'default'"); + parser->isRecovering = true; + goto recover; + } + else + { + Stmt* bodyStmt = nullptr; + for (;;) + { + if (parser->LookAheadToken("case") || parser->LookAheadToken("default") || parser->LookAheadToken(TokenType::RBrace) || + parser->LookAheadToken(TokenType::EndOfFile)) + break; + auto nextStmt = parser->ParseStatement(stmt); + if (nextStmt) + { + if (!bodyStmt) + { + bodyStmt = nextStmt; + } + else if (auto seqStmt = as(bodyStmt)) + { + seqStmt->stmts.add(nextStmt); + } + else + { + SeqStmt* newBody = parser->astBuilder->create(); + newBody->loc = bodyStmt->loc; + newBody->stmts.add(bodyStmt); + newBody->stmts.add(nextStmt); + bodyStmt = newBody; + } + } + } + + for (auto caseName : caseNames) + { + TargetCaseStmt* targetCase = parser->astBuilder->create(); + auto cap = findCapabilityAtom(caseName.getContent()); + if (caseName.getContent().getLength() && cap == CapabilityAtom::Invalid) + { + parser->sink->diagnose(caseName.loc, Diagnostics::unknownTargetName, caseName.getContent()); + } + targetCase->capability = int32_t(cap); + targetCase->loc = caseName.loc; + targetCase->body = bodyStmt; + stmt->targetCases.add(targetCase); + } + } + recover:; + TryRecover(parser); + } + return stmt; + } + + static Stmt* parseIntrinsicAsmStmt(Parser* parser) + { + IntrinsicAsmStmt* stmt = parser->astBuilder->create(); + parser->FillPosition(stmt); + parser->ReadToken(); + + stmt->asmText = getStringLiteralTokenValue(parser->ReadToken(TokenType::StringLiteral)); + parser->ReadToken(TokenType::Semicolon); + return stmt; + } + GpuForeachStmt* ParseGpuForeachStmt(Parser* parser) { // Hard-coding parsing of the following: @@ -4421,6 +4523,10 @@ namespace Slang } else if (LookAheadToken("switch")) statement = ParseSwitchStmt(this); + else if (LookAheadToken("__target_switch")) + statement = parseTargetSwitchStmt(this); + else if (LookAheadToken("__intrinsic_asm")) + statement = parseIntrinsicAsmStmt(this); else if (LookAheadToken("case")) statement = ParseCaseStmt(this); else if (LookAheadToken("default")) @@ -6160,11 +6266,18 @@ namespace Slang return SPIRVAsmOperand{flavor, tok, varExpr}; }; + const auto slangTypeExprOperand = [&](auto flavor) { + auto tok = parser->tokenReader.peekToken(); + const auto typeExpr = parser->ParseType(); + return SPIRVAsmOperand{ flavor, tok, typeExpr }; + }; + // The result marker if(parser->LookAheadToken("result")) { return SPIRVAsmOperand{SPIRVAsmOperand::ResultMarker, parser->ReadToken()}; } + // A regular identifier else if(parser->LookAheadToken(TokenType::Identifier)) { @@ -6206,7 +6319,7 @@ namespace Slang // A $$foo type else if(AdvanceIf(parser, TokenType::DollarDollar)) { - return slangIdentOperand(SPIRVAsmOperand::SlangType); + return slangTypeExprOperand(SPIRVAsmOperand::SlangType); } Unexpected(parser); @@ -6756,8 +6869,8 @@ namespace Slang static NodeBase* parseSPIRVCapabilityModifier(Parser* parser, void*) { - Token token; - token = parser->ReadToken(); + parser->ReadToken(TokenType::LParent); + Token token = parser->ReadToken(TokenType::Identifier); auto modifier = parser->astBuilder->create(); const SPIRVCoreGrammarInfo& spirvInfo = parser->astBuilder->getGlobalSession()->getSPIRVCoreGrammarInfo(); @@ -6765,9 +6878,12 @@ namespace Slang if (!cap) { parser->sink->diagnose(token, Diagnostics::unknownSPIRVCapability, token); - return nullptr; } - modifier->capability = int32_t(*cap); + else + { + modifier->capability = (int32_t)cap.value(); + } + parser->ReadToken(TokenType::RParent); return modifier; } -- cgit v1.2.3