diff options
| -rw-r--r-- | source/slang/emit.cpp | 33 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 112 | ||||
| -rw-r--r-- | tests/compute/break-stmt.slang | 27 | ||||
| -rw-r--r-- | tests/compute/break-stmt.slang.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/continue-stmt.slang | 32 | ||||
| -rw-r--r-- | tests/compute/continue-stmt.slang.expected.txt | 4 |
6 files changed, 203 insertions, 9 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 2b7030897..039aea27d 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -97,6 +97,8 @@ struct SharedEmitContext Dictionary<IRValue*, UInt> mapIRValueToID; HashSet<Decl*> irDeclsVisited; + + Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead; }; struct EmitContext @@ -5555,10 +5557,10 @@ emitDeclImpl(decl, nullptr); emit("for(;;)\n{\n"); - // TODO: Okay, we *said* we'd do this special - // handling of the `continue` sites, but - // we aren't actually setting anything up here... - // + // Register information so that `continue` sites + // can do the right thing: + ctx->shared->irMapContinueTargetToLoopHead.Add(continueBlock, targetBlock); + emitIRStmtsForBlocks( ctx, @@ -5579,7 +5581,28 @@ emitDeclImpl(decl, nullptr); return; case kIROp_continue: - emit("continue;\n"); + // With out current strategy for outputting loops, + // just outputting an AST-level `continue` here won't + // actually execute the statements in the continue block. + // + // Instead, we have to manually output those statements + // directly here, and *then* do an AST-level `continue`. + // + // This leads to code duplication when we have multiple + // `continue` sites in the original program, but it avoids + // introducing additional temporaries for control flow. + { + auto continueInst = (IRContinue*) terminator; + auto targetBlock = continueInst->getTargetBlock(); + IRBlock* loopHead = nullptr; + ctx->shared->irMapContinueTargetToLoopHead.TryGetValue(targetBlock, loopHead); + SLANG_ASSERT(loopHead); + emitIRStmtsForBlocks( + ctx, + targetBlock, + loopHead); + emit("continue;\n"); + } return; case kIROp_loopTest: diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 86f037435..bbbc1812b 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -279,6 +279,12 @@ struct SharedIRGenContext // to reference-count these along the way because // they need to get stored into a `union` inside `LoweredValInfo` List<RefPtr<ExtendedValueInfo>> extValues; + + // Map from an AST-level statement that can be + // used as the target of a `break` or `continue` + // to the appropriate basic block to jump to. + Dictionary<Stmt*, IRBlock*> breakLabels; + Dictionary<Stmt*, IRBlock*> continueLabels; }; @@ -1741,8 +1747,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> auto breakLabel = createBlock(); auto continueLabel = createBlock(); - // TODO: register `loopHead` as the target for a - // `continue` statement. + // Register the `break` and `continue` labels so + // that we can find them for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + context->shared->continueLabels.Add(stmt, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -1807,8 +1815,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // jumps to the head of hte loop. auto continueLabel = loopHead; - // TODO: register appropriate targets for - // break/continue statements. + // Register the `break` and `continue` labels so + // that we can find them for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + context->shared->continueLabels.Add(stmt, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -1847,6 +1857,66 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> insertBlock(breakLabel); } + void visitDoWhileStmt(DoWhileStmt* stmt) + { + // Generating IR for `do {...} while` statement is similar to a + // `while` statement, just with the test in a different place + + auto builder = getBuilder(); + + // We will create blocks for the various places + // we need to jump to inside the control flow, + // including the blocks that will be referenced + // by `continue` or `break` statements. + auto loopHead = createBlock(); + auto testLabel = createBlock(); + auto breakLabel = createBlock(); + + // A `continue` inside a `do { ... } while ( ... )` loop always + // jumps to the loop test. + auto continueLabel = testLabel; + + // Register the `break` and `continue` labels so + // that we can find them for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + context->shared->continueLabels.Add(stmt, continueLabel); + + // Emit the branch that will start out loop, + // and then insert the block for the head. + + auto loopInst = builder->emitLoop( + loopHead, + breakLabel, + continueLabel); + + addLoopDecorations(loopInst, stmt); + + insertBlock(loopHead); + + // Emit the body of the loop + lowerStmt(context, stmt->Statement); + + insertBlock(testLabel); + + // Now that we are within the header block, we + // want to emit the expression for the loop condition: + if (auto condExpr = stmt->Predicate) + { + auto irCondition = getSimpleVal(context, + lowerRValueExpr(context, condExpr)); + + // Now we want to `break` if the loop condition is false, + // otherwise we will jump back to the head of the loop. + builder->emitLoopTest( + irCondition, + loopHead, + breakLabel); + } + + // Finally we insert the label that a `break` will jump to + insertBlock(breakLabel); + } + void visitExpressionStmt(ExpressionStmt* stmt) { // The statement evaluates an expression @@ -1917,6 +1987,39 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> { getBuilder()->emitDiscard(); } + + void visitBreakStmt(BreakStmt* stmt) + { + // Semantic checking is responsible for finding + // the statement taht this `break` breaks out of + auto parentStmt = stmt->parentStmt; + SLANG_ASSERT(parentStmt); + + // We just need to look up the basic block that + // corresponds to the break label for that statement, + // and then emit an instruction to jump to it. + IRBlock* targetBlock; + context->shared->breakLabels.TryGetValue(parentStmt, targetBlock); + SLANG_ASSERT(targetBlock); + getBuilder()->emitBreak(targetBlock); + } + + void visitContinueStmt(ContinueStmt* stmt) + { + // Semantic checking is responsible for finding + // the loop that this `continue` statement continues + auto parentStmt = stmt->parentStmt; + SLANG_ASSERT(parentStmt); + + + // We just need to look up the basic block that + // corresponds to the continue label for that statement, + // and then emit an instruction to jump to it. + IRBlock* targetBlock; + context->shared->continueLabels.TryGetValue(parentStmt, targetBlock); + SLANG_ASSERT(targetBlock); + getBuilder()->emitContinue(targetBlock); + } }; void lowerStmt( @@ -1947,6 +2050,7 @@ top: case LoweredValInfo::Flavor::Simple: case LoweredValInfo::Flavor::Ptr: case LoweredValInfo::Flavor::SwizzledLValue: + case LoweredValInfo::Flavor::BoundSubscript: { builder->emitStore( left.val, diff --git a/tests/compute/break-stmt.slang b/tests/compute/break-stmt.slang new file mode 100644 index 000000000..02f5f9fa9 --- /dev/null +++ b/tests/compute/break-stmt.slang @@ -0,0 +1,27 @@ +//TEST(compute):COMPARE_COMPUTE: +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out + +// Test that `break` from a loop works. + +int test(int inVal) +{ + int ii = 0; + for(;;) + { + if(ii >= inVal) + break; + ii++; + } + return -ii; +} + +RWStructuredBuffer<int> outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/break-stmt.slang.expected.txt b/tests/compute/break-stmt.slang.expected.txt new file mode 100644 index 000000000..3ef7f3e49 --- /dev/null +++ b/tests/compute/break-stmt.slang.expected.txt @@ -0,0 +1,4 @@ +0 +FFFFFFFF +FFFFFFFE +FFFFFFFD diff --git a/tests/compute/continue-stmt.slang b/tests/compute/continue-stmt.slang new file mode 100644 index 000000000..9adb5a4a6 --- /dev/null +++ b/tests/compute/continue-stmt.slang @@ -0,0 +1,32 @@ +//TEST(compute):COMPARE_COMPUTE: +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out + +// Test that `break` from a loop works. + +int test(int inVal) +{ + int ii = 0; + do + { + if(ii < inVal) + { + ii++; + continue; + } + break; + } + while(true); + + return -ii; +} + +RWStructuredBuffer<int> outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/continue-stmt.slang.expected.txt b/tests/compute/continue-stmt.slang.expected.txt new file mode 100644 index 000000000..3ef7f3e49 --- /dev/null +++ b/tests/compute/continue-stmt.slang.expected.txt @@ -0,0 +1,4 @@ +0 +FFFFFFFF +FFFFFFFE +FFFFFFFD |
