diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/emit.cpp | 136 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 26 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 23 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 269 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 6 |
6 files changed, 457 insertions, 5 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index a8cd11af4..84d8f113e 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -2896,7 +2896,7 @@ struct EmitVisitor } else if (auto defaultStmt = stmt.As<DefaultStmt>()) { - Emit("default:{}\n"); + Emit("default:\n"); return; } else if (auto breakStmt = stmt.As<BreakStmt>()) @@ -5643,8 +5643,142 @@ emitDeclImpl(decl, nullptr); break; case kIROp_conditionalBranch: + // Note: We currently do not generate any plain + // `conditionalBranch` instructions when lowering + // to IR, because these would not have the annotations + // needed to be able to emit high-level control + // flow from them. SLANG_UNEXPECTED("terminator inst"); return; + + + case kIROp_switch: + { + // A `switch` instruction will always translate + // to a `switch` statement, but we need to + // take some care to emit the `case`s in ways + // that avoid code duplication. + + // TODO: Eventually, the "right" way to handle `switch` + // statements while being more robust about Duff's Device, etc. + // would be to register each of the case labels in a lookup + // table, and then walk the blocks in the region between + // the `switch` and the `break` and then whenever we see a block + // that matches one of the registered labels, emit the appropriate + // `case ...:` or `default:` label. + + auto t = (IRSwitch*) terminator; + + // Extract the fixed arguments. + auto conditionVal = t->getCondition(); + auto breakLabel = t->getBreakLabel(); + auto defaultLabel = t->getDefaultLabel(); + + // We need to track whether we've dealt with + // the `default` case already. + bool defaultLabelHandled = false; + + // If the `default` case just branches to + // the join point, then we don't need to + // do anything with it. + if(defaultLabel == breakLabel) + defaultLabelHandled = true; + + // Emit the start of our statement. + emit("switch("); + emitIROperand(ctx, conditionVal); + emit(")\n{\n"); + + // Now iterate over the `case`s of the branch + UInt caseIndex = 0; + UInt caseCount = t->getCaseCount(); + while(caseIndex < caseCount) + { + // We are going to extract one case here, + // but we might need to fold additional + // cases into it, if they share the + // same label. + // + // Note: this makes assumptions that the + // IR code generator orders cases such + // that: (1) cases with the same label + // are consecutive, and (2) any case + // that "falls through" to another must + // come right before it in the list. + auto caseVal = t->getCaseValue(caseIndex); + auto caseLabel = t->getCaseLabel(caseIndex); + caseIndex++; + + // Emit the `case ...:` for this case, and any + // others that share the same label + for(;;) + { + emit("case "); + emitIROperand(ctx, caseVal); + emit(":\n"); + + if(caseIndex >= caseCount) + break; + + auto nextCaseLabel = t->getCaseLabel(caseIndex); + if(nextCaseLabel != caseLabel) + break; + + caseVal = t->getCaseValue(caseIndex); + caseIndex++; + } + + // The label for the current `case` might also + // be the label used by the `default` case, so + // check for that here. + if(caseLabel == defaultLabel) + { + emit("default:\n"); + defaultLabelHandled = true; + } + + // Now we need to emit the statements that make + // up this case. The 99% case will be that it + // will terminate with a `break` (or a `return`, + // `continue`, etc.) and so we can pass in + // `nullptr` for the ending block. + IRBlock* caseEndLabel = nullptr; + + // However, there is also the possibility that + // this case will fall through to the next, and + // so we need to prepare for that possibility here. + // + // If there is a next case, then we will set its + // label up as the "end" label when emitting + // the statements inside the block. + if(caseIndex < caseCount) + { + caseEndLabel = t->getCaseLabel(caseIndex); + } + + // Now emit the statements for this case. + emit("{\n"); + emitIRStmtsForBlocks(ctx, caseLabel, caseEndLabel); + emit("}\n"); + } + + // If we've gone through all the cases and haven't + // managed to encounter the `default:` label, + // then assume it is a distinct case and handle it here. + if(!defaultLabelHandled) + { + emit("default:\n"); + emit("{\n"); + emitIRStmtsForBlocks(ctx, defaultLabel, breakLabel); + emit("break;\n"); + emit("}\n"); + } + + emit("}\n"); + block = breakLabel; + + } + break; } // If we reach this point, then we've emitted diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index a8118f714..dbdb697a4 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -181,6 +181,8 @@ INST(if, if, 3, 0) INST(ifElse, ifElse, 4, 0) INST(loopTest, loopTest, 3, 0) +INST(switch, switch, 3, 0) + INST(discard, discard, 0, 0) INST(Add, add, 2, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 06ccf2921..b4335656e 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -209,6 +209,24 @@ struct IRIfElse : IRConditionalBranch IRBlock* getAfterBlock() { return (IRBlock*)afterBlock.usedValue; } }; +// A multi-way branch that represents a source-level `switch` +struct IRSwitch : IRTerminatorInst +{ + IRUse condition; + IRUse breakLabel; + IRUse defaultLabel; + + IRValue* getCondition() { return condition.usedValue; } + IRBlock* getBreakLabel() { return (IRBlock*) breakLabel.usedValue; } + IRBlock* getDefaultLabel() { return (IRBlock*) defaultLabel.usedValue; } + + // remaining args are: caseVal, caseLabel, ... + + UInt getCaseCount() { return (getArgCount() - 3) / 2; } + IRValue* getCaseValue(UInt index) { return getArg(3 + index*2 + 0); } + IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getArg(3 + index*2 + 1); } +}; + struct IRSwizzle : IRReturn { IRUse base; @@ -504,6 +522,14 @@ struct IRBuilder IRBlock* bodyBlock, IRBlock* breakBlock); + IRInst* emitSwitch( + IRValue* val, + IRBlock* breakLabel, + IRBlock* defaultLabel, + UInt caseArgCount, + IRValue* const* caseArgs); + + IRDecoration* addDecorationImpl( IRValue* value, UInt decorationSize, diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 8fbe20aa6..313fd258b 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -167,6 +167,7 @@ namespace Slang case kIROp_ifElse: case kIROp_loopTest: case kIROp_discard: + case kIROp_switch: return true; } } @@ -1289,6 +1290,28 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitSwitch( + IRValue* val, + IRBlock* breakLabel, + IRBlock* defaultLabel, + UInt caseArgCount, + IRValue* const* caseArgs) + { + IRValue* fixedArgs[] = { val, breakLabel, defaultLabel }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + auto inst = createInstWithTrailingArgs<IRSwitch>( + this, + kIROp_switch, + nullptr, + fixedArgCount, + fixedArgs, + caseArgCount, + caseArgs); + addInst(inst); + return inst; + } + IRDecoration* IRBuilder::addDecorationImpl( IRValue* inst, UInt decorationSize, diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index bbbc1812b..f642b316e 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1622,9 +1622,24 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> IRBuilder* getBuilder() { return context->irBuilder; } - void visitStmt(Stmt* /*stmt*/) + void visitEmptyStmt(EmptyStmt*) { - SLANG_UNIMPLEMENTED_X("stmt catch-all"); + // Nothing to do. + } + + void visitUnparsedStmt(UnparsedStmt*) + { + SLANG_UNEXPECTED("UnparsedStmt not supported by IR"); + } + + void visitCaseStmtBase(CaseStmtBase*) + { + SLANG_UNEXPECTED("`case` or `default` not under `switch`"); + } + + void visitCompileTimeForStmt(CompileTimeForStmt*) + { + SLANG_UNIMPLEMENTED_X("IR lowering of CompileTimeForStmt"); } // Create a basic block in the current function, @@ -1642,12 +1657,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> auto builder = getBuilder(); auto prevBlock = builder->curBlock; - auto parentFunc = prevBlock->parentFunc; + auto parentFunc = prevBlock ? prevBlock->parentFunc : builder->curFunc; // If the previous block doesn't already have // a terminator instruction, then be sure to // emit a branch to the new block. - if (!isTerminatorInst(prevBlock->lastInst)) + if (prevBlock && !isTerminatorInst(prevBlock->lastInst)) { builder->emitBranch(block); } @@ -2020,6 +2035,252 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> SLANG_ASSERT(targetBlock); getBuilder()->emitContinue(targetBlock); } + + // Lowering a `switch` statement can get pretty involved, + // so we need to track a bit of extra data: + struct SwitchStmtInfo + { + // The label for the `default` case, if any. + IRBlock* defaultLabel = nullptr; + + // The label of the current "active" case block. + IRBlock* currentCaseLabel = nullptr; + + // Has anything been emitted to the current "active" case block? + bool anythingEmittedToCurrentCaseBlock = false; + + // The collected (value, label) pairs for + // all the `case` statements. + List<IRValue*> cases; + }; + + // We need a label to use for a `case` or `default` statement, + // so either create one here, or re-use the current one if + // that is okay. + IRBlock* getLabelForCase(SwitchStmtInfo* info) + { + // Look at the "current" label we are working with. + auto currentCaseLabel = info->currentCaseLabel; + + // If there is a current block, and it is empty, + // then it is still a viable target (we are in + // a case of "trivial fall-through" from the previous + // block). + if(currentCaseLabel && !info->anythingEmittedToCurrentCaseBlock) + { + return currentCaseLabel; + } + + // Othwerise, we need to start a new block and use that. + IRBlock* newCaseLabel = createBlock(); + + // Note: if the previous block failed + // to end with a `break`, then inserting + // this block will append an unconditional + // branch to the end of it that will target + // this block. + insertBlock(newCaseLabel); + + info->currentCaseLabel = newCaseLabel; + info->anythingEmittedToCurrentCaseBlock = false; + return newCaseLabel; + } + + // Given a statement that appears as (or in) the body + // of a `switch` statement + void lowerSwitchCases(Stmt* inStmt, SwitchStmtInfo* info) + { + // TODO: in the general case (e.g., if we were going + // to eventual lower to an unstructured format like LLVM), + // the Right Way to handle C-style `switch` statements + // is just to emit the body directly as "normal" statements, + // and then treat `case` and `default` as special statements + // that start a new block and register a label with the + // enclosing `switch`. + // + // For now we will assume that any `case` and `default` + // statements need to be direclty nested under the `switch`, + // and so we can find them with a simpler walk. + + Stmt* stmt = inStmt; + + // Unwrap any surrounding `{ ... }` so we can look + // at the statement inside. + while(auto blockStmt = dynamic_cast<BlockStmt*>(stmt)) + { + stmt = blockStmt->body; + continue; + } + + if(auto seqStmt = dynamic_cast<SeqStmt*>(stmt)) + { + // Walk through teh children and process each. + for(auto childStmt : seqStmt->stmts) + { + lowerSwitchCases(childStmt, info); + } + } + else if(auto caseStmt = dynamic_cast<CaseStmt*>(stmt)) + { + // A full `case` statement has a value we need + // to test against. It is expected to be a + // compile-time constant, so we will emit + // it like an expression here, and then hope + // for the best. + // + // TODO: figure out something cleaner. + auto caseVal = getSimpleVal(context, lowerRValueExpr(context, caseStmt->expr)); + + // Figure out where we are branching to. + auto label = getLabelForCase(info); + + + // Add this `case` to the list for the enclosing `switch`. + info->cases.Add(caseVal); + info->cases.Add(label); + } + else if(auto defaultStmt = dynamic_cast<DefaultStmt*>(stmt)) + { + auto label = getLabelForCase(info); + + // We expect to only find a single `default` stmt. + SLANG_ASSERT(!info->defaultLabel); + + info->defaultLabel = label; + } + else if(auto emptyStmt = dynamic_cast<EmptyStmt*>(stmt)) + { + // Special-case empty statements so they don't + // mess up our "trivial fall-through" optimization. + } + else + { + // We have an ordinary statement, that needs to get + // emitted to the currrent case block. + if(!info->currentCaseLabel) + { + // It possible in full C/C++ to have statements + // before the first `case`. Usually these are + // unreachable, unless they start with a label. + // + // We'll ignore them here, figuring they are + // dead. If we ever add `LabelStmt` then we'd + // need to emit these statements to a dummy + // block just in case. + } + else + { + // Emit the code to our current case block, + // and record that we've done so. + lowerStmt(context, stmt); + info->anythingEmittedToCurrentCaseBlock = true; + } + } + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + auto builder = getBuilder(); + + // Given a statement: + // + // switch( CONDITION ) + // { + // case V0: + // S0; + // break; + // + // case V1: + // default: + // S1; + // break; + // } + // + // we want to generate IR like: + // + // let %c = <CONDITION>; + // switch %c, // value to switch on + // %breakLabel, // join point (and break target) + // %s1, // default label + // %v0, // first case value + // %s0, // first case label + // %v1, // second case value + // %s1 // second case label + // s0: + // <S0> + // break %breakLabel + // s1: + // <S1> + // break %breakLabel + // breakLabel: + // + + // First emit code to compute the condition: + auto conditionVal = getSimpleVal(context, lowerRValueExpr(context, stmt->condition)); + + // Remember the initial block so that we can add to it + // after we've collected all the `case`s + auto initialBlock = builder->curBlock; + + // Next, create a block to use as the target for any `break` statements + auto breakLabel = createBlock(); + + // Register the `break` label so + // that we can find it for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + + builder->curFunc = initialBlock->parentFunc; + builder->curBlock = nullptr; + + // Iterate over the body of the statement, looking + // for `case` or `default` statements: + SwitchStmtInfo info; + info.defaultLabel = nullptr; + lowerSwitchCases(stmt->body, &info); + + // TODO: once we've discovered the cases, we should + // be able to make a quick pass over the list and eliminate + // any cases that have the exact same label as the `default` + // case, since these don't actually need to be represented. + + // If the current block (the end of the last + // `case`) is not terminated, then terminate with a + // `break` operation. + // + // Double check that we aren't in the initial + // block, so we don't get tripped up on an + // empty `switch`. + if(builder->curBlock != initialBlock) + { + // Is the block already terminated? + auto lastInst = builder->curBlock->lastInst; + if(!lastInst || !isTerminatorInst(lastInst)) + { + // Not terminated, so add one. + builder->emitBreak(breakLabel); + } + } + + // If there was no `default` statement, then the + // default case will just branch directly to the end. + auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel; + + // Now that we've collected the cases, we are + // prepared to emit the `switch` instruction + // itself. + builder->curBlock = initialBlock; + builder->emitSwitch( + conditionVal, + breakLabel, + defaultLabel, + info.cases.Count(), + info.cases.Buffer()); + + // Finally we insert the label that a `break` will jump to + // (and that control flow will fall through to otherwise). + // This is the block that subsequent code will go into. + insertBlock(breakLabel); + } }; void lowerStmt( diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index eb02d98c5..42c763099 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2973,6 +2973,12 @@ namespace Slang } PopScope(); + if(!body) + { + body = new EmptyStmt(); + body->loc = blockStatement->loc; + } + blockStatement->body = body; return blockStatement; } |
