diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-11-14 11:33:36 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-11-14 11:33:36 -0800 |
| commit | 59a4c0caed15b607990fdc1990992fb1b944ae96 (patch) | |
| tree | 55d418af6456b7f3a15fdfea5df3c7aaa37b545a /source/slang | |
| parent | 7c4ad877dcd353b9431011f0fc98aad315366c6c (diff) | |
IR: add support for `switch` statements (#278)
* IR: add support for `switch` statements
Fixes #273
This is just something we hadn't gotten to yet on the IR. The actual design of the instruction is unsurprising (once you take into consideration the requirement for structured control flow). A `switch` instruction takes the form:
switch <condition> <breakLabel> <defaultLabel> [<caseVal> <caseLabel>]*
Where `condition` is the value to switch on, `breakLabel` is the "join point" after the original `switch` statement, `defaultLabel` is where to go if the value doesn't match any case, and each pair of `caseVal` and `caseLabel` is what to do on a particular value. It is required that `caseVal` be a literal, but this isn't currently being enforced in the IR (the front-end should be making a check and constant-folding the case labels).
For structured control flow, we also make the assumption that the cases are in order: cases with the same label must be grouped together, and any case that falls through to another must come right before it.
Given this representation, the emit logic can reconstruct a `switch` statement with relative ease, given the machinery we already have. It makes sure to group together case values with the same label (again, assuming they are contiguous), and will insert the `default:` label in with whatever group it belongs to.
Actually emitting code for a `switch` statement seems superficially simple, until you realize that a complete implementation needs to handle stuff like "Duff's Device." The current implementation makes the assumption that all `case` and `default` statements are directly nested under a `switch`, and that there is no way for control flow to enter a case except by the `switch` itself, or fall-through.
In order to facilitate the grouping of cases in the IR-to-HLSL emit logic, the AST-to-IR lowering logic tries to detect cases where there are multiple `case`s in a row, and emit only a single label for them.
One big/annoying gotcha is that we don't properly handle the case where a `default:` case has a non-trivial fall-throguh to another case. That seems fine for now since HLSL doesn't support fall-through anyway, but it probably needs to get detected somewhere in the Slang compiler (e.g., maybe we should add a diagnostic pass over the IR that detects target-specific problems like that and emits errors).
* IR: Add support for empty statements.
- Add empty statement in `lower-to-ir.cpp`
- Go ahead and eliminate the statement catch-all and explicitly enumerate the cases we don't support
- Fix up parser for block statements so that it doesn't leave a null statement as the body of a `{}`
- Add an empty statement to one of the cases for the `switch` test, to ensure we are testing empty statements
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/emit.cpp | 136 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 26 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 23 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 269 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 6 |
6 files changed, 457 insertions, 5 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index a8cd11af4..84d8f113e 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -2896,7 +2896,7 @@ struct EmitVisitor } else if (auto defaultStmt = stmt.As<DefaultStmt>()) { - Emit("default:{}\n"); + Emit("default:\n"); return; } else if (auto breakStmt = stmt.As<BreakStmt>()) @@ -5643,8 +5643,142 @@ emitDeclImpl(decl, nullptr); break; case kIROp_conditionalBranch: + // Note: We currently do not generate any plain + // `conditionalBranch` instructions when lowering + // to IR, because these would not have the annotations + // needed to be able to emit high-level control + // flow from them. SLANG_UNEXPECTED("terminator inst"); return; + + + case kIROp_switch: + { + // A `switch` instruction will always translate + // to a `switch` statement, but we need to + // take some care to emit the `case`s in ways + // that avoid code duplication. + + // TODO: Eventually, the "right" way to handle `switch` + // statements while being more robust about Duff's Device, etc. + // would be to register each of the case labels in a lookup + // table, and then walk the blocks in the region between + // the `switch` and the `break` and then whenever we see a block + // that matches one of the registered labels, emit the appropriate + // `case ...:` or `default:` label. + + auto t = (IRSwitch*) terminator; + + // Extract the fixed arguments. + auto conditionVal = t->getCondition(); + auto breakLabel = t->getBreakLabel(); + auto defaultLabel = t->getDefaultLabel(); + + // We need to track whether we've dealt with + // the `default` case already. + bool defaultLabelHandled = false; + + // If the `default` case just branches to + // the join point, then we don't need to + // do anything with it. + if(defaultLabel == breakLabel) + defaultLabelHandled = true; + + // Emit the start of our statement. + emit("switch("); + emitIROperand(ctx, conditionVal); + emit(")\n{\n"); + + // Now iterate over the `case`s of the branch + UInt caseIndex = 0; + UInt caseCount = t->getCaseCount(); + while(caseIndex < caseCount) + { + // We are going to extract one case here, + // but we might need to fold additional + // cases into it, if they share the + // same label. + // + // Note: this makes assumptions that the + // IR code generator orders cases such + // that: (1) cases with the same label + // are consecutive, and (2) any case + // that "falls through" to another must + // come right before it in the list. + auto caseVal = t->getCaseValue(caseIndex); + auto caseLabel = t->getCaseLabel(caseIndex); + caseIndex++; + + // Emit the `case ...:` for this case, and any + // others that share the same label + for(;;) + { + emit("case "); + emitIROperand(ctx, caseVal); + emit(":\n"); + + if(caseIndex >= caseCount) + break; + + auto nextCaseLabel = t->getCaseLabel(caseIndex); + if(nextCaseLabel != caseLabel) + break; + + caseVal = t->getCaseValue(caseIndex); + caseIndex++; + } + + // The label for the current `case` might also + // be the label used by the `default` case, so + // check for that here. + if(caseLabel == defaultLabel) + { + emit("default:\n"); + defaultLabelHandled = true; + } + + // Now we need to emit the statements that make + // up this case. The 99% case will be that it + // will terminate with a `break` (or a `return`, + // `continue`, etc.) and so we can pass in + // `nullptr` for the ending block. + IRBlock* caseEndLabel = nullptr; + + // However, there is also the possibility that + // this case will fall through to the next, and + // so we need to prepare for that possibility here. + // + // If there is a next case, then we will set its + // label up as the "end" label when emitting + // the statements inside the block. + if(caseIndex < caseCount) + { + caseEndLabel = t->getCaseLabel(caseIndex); + } + + // Now emit the statements for this case. + emit("{\n"); + emitIRStmtsForBlocks(ctx, caseLabel, caseEndLabel); + emit("}\n"); + } + + // If we've gone through all the cases and haven't + // managed to encounter the `default:` label, + // then assume it is a distinct case and handle it here. + if(!defaultLabelHandled) + { + emit("default:\n"); + emit("{\n"); + emitIRStmtsForBlocks(ctx, defaultLabel, breakLabel); + emit("break;\n"); + emit("}\n"); + } + + emit("}\n"); + block = breakLabel; + + } + break; } // If we reach this point, then we've emitted diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index a8118f714..dbdb697a4 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -181,6 +181,8 @@ INST(if, if, 3, 0) INST(ifElse, ifElse, 4, 0) INST(loopTest, loopTest, 3, 0) +INST(switch, switch, 3, 0) + INST(discard, discard, 0, 0) INST(Add, add, 2, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 06ccf2921..b4335656e 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -209,6 +209,24 @@ struct IRIfElse : IRConditionalBranch IRBlock* getAfterBlock() { return (IRBlock*)afterBlock.usedValue; } }; +// A multi-way branch that represents a source-level `switch` +struct IRSwitch : IRTerminatorInst +{ + IRUse condition; + IRUse breakLabel; + IRUse defaultLabel; + + IRValue* getCondition() { return condition.usedValue; } + IRBlock* getBreakLabel() { return (IRBlock*) breakLabel.usedValue; } + IRBlock* getDefaultLabel() { return (IRBlock*) defaultLabel.usedValue; } + + // remaining args are: caseVal, caseLabel, ... + + UInt getCaseCount() { return (getArgCount() - 3) / 2; } + IRValue* getCaseValue(UInt index) { return getArg(3 + index*2 + 0); } + IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getArg(3 + index*2 + 1); } +}; + struct IRSwizzle : IRReturn { IRUse base; @@ -504,6 +522,14 @@ struct IRBuilder IRBlock* bodyBlock, IRBlock* breakBlock); + IRInst* emitSwitch( + IRValue* val, + IRBlock* breakLabel, + IRBlock* defaultLabel, + UInt caseArgCount, + IRValue* const* caseArgs); + + IRDecoration* addDecorationImpl( IRValue* value, UInt decorationSize, diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 8fbe20aa6..313fd258b 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -167,6 +167,7 @@ namespace Slang case kIROp_ifElse: case kIROp_loopTest: case kIROp_discard: + case kIROp_switch: return true; } } @@ -1289,6 +1290,28 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitSwitch( + IRValue* val, + IRBlock* breakLabel, + IRBlock* defaultLabel, + UInt caseArgCount, + IRValue* const* caseArgs) + { + IRValue* fixedArgs[] = { val, breakLabel, defaultLabel }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + auto inst = createInstWithTrailingArgs<IRSwitch>( + this, + kIROp_switch, + nullptr, + fixedArgCount, + fixedArgs, + caseArgCount, + caseArgs); + addInst(inst); + return inst; + } + IRDecoration* IRBuilder::addDecorationImpl( IRValue* inst, UInt decorationSize, diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index bbbc1812b..f642b316e 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1622,9 +1622,24 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> IRBuilder* getBuilder() { return context->irBuilder; } - void visitStmt(Stmt* /*stmt*/) + void visitEmptyStmt(EmptyStmt*) { - SLANG_UNIMPLEMENTED_X("stmt catch-all"); + // Nothing to do. + } + + void visitUnparsedStmt(UnparsedStmt*) + { + SLANG_UNEXPECTED("UnparsedStmt not supported by IR"); + } + + void visitCaseStmtBase(CaseStmtBase*) + { + SLANG_UNEXPECTED("`case` or `default` not under `switch`"); + } + + void visitCompileTimeForStmt(CompileTimeForStmt*) + { + SLANG_UNIMPLEMENTED_X("IR lowering of CompileTimeForStmt"); } // Create a basic block in the current function, @@ -1642,12 +1657,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> auto builder = getBuilder(); auto prevBlock = builder->curBlock; - auto parentFunc = prevBlock->parentFunc; + auto parentFunc = prevBlock ? prevBlock->parentFunc : builder->curFunc; // If the previous block doesn't already have // a terminator instruction, then be sure to // emit a branch to the new block. - if (!isTerminatorInst(prevBlock->lastInst)) + if (prevBlock && !isTerminatorInst(prevBlock->lastInst)) { builder->emitBranch(block); } @@ -2020,6 +2035,252 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> SLANG_ASSERT(targetBlock); getBuilder()->emitContinue(targetBlock); } + + // Lowering a `switch` statement can get pretty involved, + // so we need to track a bit of extra data: + struct SwitchStmtInfo + { + // The label for the `default` case, if any. + IRBlock* defaultLabel = nullptr; + + // The label of the current "active" case block. + IRBlock* currentCaseLabel = nullptr; + + // Has anything been emitted to the current "active" case block? + bool anythingEmittedToCurrentCaseBlock = false; + + // The collected (value, label) pairs for + // all the `case` statements. + List<IRValue*> cases; + }; + + // We need a label to use for a `case` or `default` statement, + // so either create one here, or re-use the current one if + // that is okay. + IRBlock* getLabelForCase(SwitchStmtInfo* info) + { + // Look at the "current" label we are working with. + auto currentCaseLabel = info->currentCaseLabel; + + // If there is a current block, and it is empty, + // then it is still a viable target (we are in + // a case of "trivial fall-through" from the previous + // block). + if(currentCaseLabel && !info->anythingEmittedToCurrentCaseBlock) + { + return currentCaseLabel; + } + + // Othwerise, we need to start a new block and use that. + IRBlock* newCaseLabel = createBlock(); + + // Note: if the previous block failed + // to end with a `break`, then inserting + // this block will append an unconditional + // branch to the end of it that will target + // this block. + insertBlock(newCaseLabel); + + info->currentCaseLabel = newCaseLabel; + info->anythingEmittedToCurrentCaseBlock = false; + return newCaseLabel; + } + + // Given a statement that appears as (or in) the body + // of a `switch` statement + void lowerSwitchCases(Stmt* inStmt, SwitchStmtInfo* info) + { + // TODO: in the general case (e.g., if we were going + // to eventual lower to an unstructured format like LLVM), + // the Right Way to handle C-style `switch` statements + // is just to emit the body directly as "normal" statements, + // and then treat `case` and `default` as special statements + // that start a new block and register a label with the + // enclosing `switch`. + // + // For now we will assume that any `case` and `default` + // statements need to be direclty nested under the `switch`, + // and so we can find them with a simpler walk. + + Stmt* stmt = inStmt; + + // Unwrap any surrounding `{ ... }` so we can look + // at the statement inside. + while(auto blockStmt = dynamic_cast<BlockStmt*>(stmt)) + { + stmt = blockStmt->body; + continue; + } + + if(auto seqStmt = dynamic_cast<SeqStmt*>(stmt)) + { + // Walk through teh children and process each. + for(auto childStmt : seqStmt->stmts) + { + lowerSwitchCases(childStmt, info); + } + } + else if(auto caseStmt = dynamic_cast<CaseStmt*>(stmt)) + { + // A full `case` statement has a value we need + // to test against. It is expected to be a + // compile-time constant, so we will emit + // it like an expression here, and then hope + // for the best. + // + // TODO: figure out something cleaner. + auto caseVal = getSimpleVal(context, lowerRValueExpr(context, caseStmt->expr)); + + // Figure out where we are branching to. + auto label = getLabelForCase(info); + + + // Add this `case` to the list for the enclosing `switch`. + info->cases.Add(caseVal); + info->cases.Add(label); + } + else if(auto defaultStmt = dynamic_cast<DefaultStmt*>(stmt)) + { + auto label = getLabelForCase(info); + + // We expect to only find a single `default` stmt. + SLANG_ASSERT(!info->defaultLabel); + + info->defaultLabel = label; + } + else if(auto emptyStmt = dynamic_cast<EmptyStmt*>(stmt)) + { + // Special-case empty statements so they don't + // mess up our "trivial fall-through" optimization. + } + else + { + // We have an ordinary statement, that needs to get + // emitted to the currrent case block. + if(!info->currentCaseLabel) + { + // It possible in full C/C++ to have statements + // before the first `case`. Usually these are + // unreachable, unless they start with a label. + // + // We'll ignore them here, figuring they are + // dead. If we ever add `LabelStmt` then we'd + // need to emit these statements to a dummy + // block just in case. + } + else + { + // Emit the code to our current case block, + // and record that we've done so. + lowerStmt(context, stmt); + info->anythingEmittedToCurrentCaseBlock = true; + } + } + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + auto builder = getBuilder(); + + // Given a statement: + // + // switch( CONDITION ) + // { + // case V0: + // S0; + // break; + // + // case V1: + // default: + // S1; + // break; + // } + // + // we want to generate IR like: + // + // let %c = <CONDITION>; + // switch %c, // value to switch on + // %breakLabel, // join point (and break target) + // %s1, // default label + // %v0, // first case value + // %s0, // first case label + // %v1, // second case value + // %s1 // second case label + // s0: + // <S0> + // break %breakLabel + // s1: + // <S1> + // break %breakLabel + // breakLabel: + // + + // First emit code to compute the condition: + auto conditionVal = getSimpleVal(context, lowerRValueExpr(context, stmt->condition)); + + // Remember the initial block so that we can add to it + // after we've collected all the `case`s + auto initialBlock = builder->curBlock; + + // Next, create a block to use as the target for any `break` statements + auto breakLabel = createBlock(); + + // Register the `break` label so + // that we can find it for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + + builder->curFunc = initialBlock->parentFunc; + builder->curBlock = nullptr; + + // Iterate over the body of the statement, looking + // for `case` or `default` statements: + SwitchStmtInfo info; + info.defaultLabel = nullptr; + lowerSwitchCases(stmt->body, &info); + + // TODO: once we've discovered the cases, we should + // be able to make a quick pass over the list and eliminate + // any cases that have the exact same label as the `default` + // case, since these don't actually need to be represented. + + // If the current block (the end of the last + // `case`) is not terminated, then terminate with a + // `break` operation. + // + // Double check that we aren't in the initial + // block, so we don't get tripped up on an + // empty `switch`. + if(builder->curBlock != initialBlock) + { + // Is the block already terminated? + auto lastInst = builder->curBlock->lastInst; + if(!lastInst || !isTerminatorInst(lastInst)) + { + // Not terminated, so add one. + builder->emitBreak(breakLabel); + } + } + + // If there was no `default` statement, then the + // default case will just branch directly to the end. + auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel; + + // Now that we've collected the cases, we are + // prepared to emit the `switch` instruction + // itself. + builder->curBlock = initialBlock; + builder->emitSwitch( + conditionVal, + breakLabel, + defaultLabel, + info.cases.Count(), + info.cases.Buffer()); + + // Finally we insert the label that a `break` will jump to + // (and that control flow will fall through to otherwise). + // This is the block that subsequent code will go into. + insertBlock(breakLabel); + } }; void lowerStmt( diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index eb02d98c5..42c763099 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2973,6 +2973,12 @@ namespace Slang } PopScope(); + if(!body) + { + body = new EmptyStmt(); + body->loc = blockStatement->loc; + } + blockStatement->body = body; return blockStatement; } |
