summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/emit.cpp33
-rw-r--r--source/slang/lower-to-ir.cpp112
-rw-r--r--tests/compute/break-stmt.slang27
-rw-r--r--tests/compute/break-stmt.slang.expected.txt4
-rw-r--r--tests/compute/continue-stmt.slang32
-rw-r--r--tests/compute/continue-stmt.slang.expected.txt4
6 files changed, 203 insertions, 9 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 2b7030897..039aea27d 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -97,6 +97,8 @@ struct SharedEmitContext
Dictionary<IRValue*, UInt> mapIRValueToID;
HashSet<Decl*> irDeclsVisited;
+
+ Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead;
};
struct EmitContext
@@ -5555,10 +5557,10 @@ emitDeclImpl(decl, nullptr);
emit("for(;;)\n{\n");
- // TODO: Okay, we *said* we'd do this special
- // handling of the `continue` sites, but
- // we aren't actually setting anything up here...
- //
+ // Register information so that `continue` sites
+ // can do the right thing:
+ ctx->shared->irMapContinueTargetToLoopHead.Add(continueBlock, targetBlock);
+
emitIRStmtsForBlocks(
ctx,
@@ -5579,7 +5581,28 @@ emitDeclImpl(decl, nullptr);
return;
case kIROp_continue:
- emit("continue;\n");
+ // With out current strategy for outputting loops,
+ // just outputting an AST-level `continue` here won't
+ // actually execute the statements in the continue block.
+ //
+ // Instead, we have to manually output those statements
+ // directly here, and *then* do an AST-level `continue`.
+ //
+ // This leads to code duplication when we have multiple
+ // `continue` sites in the original program, but it avoids
+ // introducing additional temporaries for control flow.
+ {
+ auto continueInst = (IRContinue*) terminator;
+ auto targetBlock = continueInst->getTargetBlock();
+ IRBlock* loopHead = nullptr;
+ ctx->shared->irMapContinueTargetToLoopHead.TryGetValue(targetBlock, loopHead);
+ SLANG_ASSERT(loopHead);
+ emitIRStmtsForBlocks(
+ ctx,
+ targetBlock,
+ loopHead);
+ emit("continue;\n");
+ }
return;
case kIROp_loopTest:
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 86f037435..bbbc1812b 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -279,6 +279,12 @@ struct SharedIRGenContext
// to reference-count these along the way because
// they need to get stored into a `union` inside `LoweredValInfo`
List<RefPtr<ExtendedValueInfo>> extValues;
+
+ // Map from an AST-level statement that can be
+ // used as the target of a `break` or `continue`
+ // to the appropriate basic block to jump to.
+ Dictionary<Stmt*, IRBlock*> breakLabels;
+ Dictionary<Stmt*, IRBlock*> continueLabels;
};
@@ -1741,8 +1747,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
auto breakLabel = createBlock();
auto continueLabel = createBlock();
- // TODO: register `loopHead` as the target for a
- // `continue` statement.
+ // Register the `break` and `continue` labels so
+ // that we can find them for nested statements.
+ context->shared->breakLabels.Add(stmt, breakLabel);
+ context->shared->continueLabels.Add(stmt, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -1807,8 +1815,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// jumps to the head of hte loop.
auto continueLabel = loopHead;
- // TODO: register appropriate targets for
- // break/continue statements.
+ // Register the `break` and `continue` labels so
+ // that we can find them for nested statements.
+ context->shared->breakLabels.Add(stmt, breakLabel);
+ context->shared->continueLabels.Add(stmt, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -1847,6 +1857,66 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
insertBlock(breakLabel);
}
+ void visitDoWhileStmt(DoWhileStmt* stmt)
+ {
+ // Generating IR for `do {...} while` statement is similar to a
+ // `while` statement, just with the test in a different place
+
+ auto builder = getBuilder();
+
+ // We will create blocks for the various places
+ // we need to jump to inside the control flow,
+ // including the blocks that will be referenced
+ // by `continue` or `break` statements.
+ auto loopHead = createBlock();
+ auto testLabel = createBlock();
+ auto breakLabel = createBlock();
+
+ // A `continue` inside a `do { ... } while ( ... )` loop always
+ // jumps to the loop test.
+ auto continueLabel = testLabel;
+
+ // Register the `break` and `continue` labels so
+ // that we can find them for nested statements.
+ context->shared->breakLabels.Add(stmt, breakLabel);
+ context->shared->continueLabels.Add(stmt, continueLabel);
+
+ // Emit the branch that will start out loop,
+ // and then insert the block for the head.
+
+ auto loopInst = builder->emitLoop(
+ loopHead,
+ breakLabel,
+ continueLabel);
+
+ addLoopDecorations(loopInst, stmt);
+
+ insertBlock(loopHead);
+
+ // Emit the body of the loop
+ lowerStmt(context, stmt->Statement);
+
+ insertBlock(testLabel);
+
+ // Now that we are within the header block, we
+ // want to emit the expression for the loop condition:
+ if (auto condExpr = stmt->Predicate)
+ {
+ auto irCondition = getSimpleVal(context,
+ lowerRValueExpr(context, condExpr));
+
+ // Now we want to `break` if the loop condition is false,
+ // otherwise we will jump back to the head of the loop.
+ builder->emitLoopTest(
+ irCondition,
+ loopHead,
+ breakLabel);
+ }
+
+ // Finally we insert the label that a `break` will jump to
+ insertBlock(breakLabel);
+ }
+
void visitExpressionStmt(ExpressionStmt* stmt)
{
// The statement evaluates an expression
@@ -1917,6 +1987,39 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
{
getBuilder()->emitDiscard();
}
+
+ void visitBreakStmt(BreakStmt* stmt)
+ {
+ // Semantic checking is responsible for finding
+ // the statement taht this `break` breaks out of
+ auto parentStmt = stmt->parentStmt;
+ SLANG_ASSERT(parentStmt);
+
+ // We just need to look up the basic block that
+ // corresponds to the break label for that statement,
+ // and then emit an instruction to jump to it.
+ IRBlock* targetBlock;
+ context->shared->breakLabels.TryGetValue(parentStmt, targetBlock);
+ SLANG_ASSERT(targetBlock);
+ getBuilder()->emitBreak(targetBlock);
+ }
+
+ void visitContinueStmt(ContinueStmt* stmt)
+ {
+ // Semantic checking is responsible for finding
+ // the loop that this `continue` statement continues
+ auto parentStmt = stmt->parentStmt;
+ SLANG_ASSERT(parentStmt);
+
+
+ // We just need to look up the basic block that
+ // corresponds to the continue label for that statement,
+ // and then emit an instruction to jump to it.
+ IRBlock* targetBlock;
+ context->shared->continueLabels.TryGetValue(parentStmt, targetBlock);
+ SLANG_ASSERT(targetBlock);
+ getBuilder()->emitContinue(targetBlock);
+ }
};
void lowerStmt(
@@ -1947,6 +2050,7 @@ top:
case LoweredValInfo::Flavor::Simple:
case LoweredValInfo::Flavor::Ptr:
case LoweredValInfo::Flavor::SwizzledLValue:
+ case LoweredValInfo::Flavor::BoundSubscript:
{
builder->emitStore(
left.val,
diff --git a/tests/compute/break-stmt.slang b/tests/compute/break-stmt.slang
new file mode 100644
index 000000000..02f5f9fa9
--- /dev/null
+++ b/tests/compute/break-stmt.slang
@@ -0,0 +1,27 @@
+//TEST(compute):COMPARE_COMPUTE:
+//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out
+
+// Test that `break` from a loop works.
+
+int test(int inVal)
+{
+ int ii = 0;
+ for(;;)
+ {
+ if(ii >= inVal)
+ break;
+ ii++;
+ }
+ return -ii;
+}
+
+RWStructuredBuffer<int> outputBuffer : register(u0);
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = outputBuffer[tid];
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/break-stmt.slang.expected.txt b/tests/compute/break-stmt.slang.expected.txt
new file mode 100644
index 000000000..3ef7f3e49
--- /dev/null
+++ b/tests/compute/break-stmt.slang.expected.txt
@@ -0,0 +1,4 @@
+0
+FFFFFFFF
+FFFFFFFE
+FFFFFFFD
diff --git a/tests/compute/continue-stmt.slang b/tests/compute/continue-stmt.slang
new file mode 100644
index 000000000..9adb5a4a6
--- /dev/null
+++ b/tests/compute/continue-stmt.slang
@@ -0,0 +1,32 @@
+//TEST(compute):COMPARE_COMPUTE:
+//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out
+
+// Test that `break` from a loop works.
+
+int test(int inVal)
+{
+ int ii = 0;
+ do
+ {
+ if(ii < inVal)
+ {
+ ii++;
+ continue;
+ }
+ break;
+ }
+ while(true);
+
+ return -ii;
+}
+
+RWStructuredBuffer<int> outputBuffer : register(u0);
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = outputBuffer[tid];
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/continue-stmt.slang.expected.txt b/tests/compute/continue-stmt.slang.expected.txt
new file mode 100644
index 000000000..3ef7f3e49
--- /dev/null
+++ b/tests/compute/continue-stmt.slang.expected.txt
@@ -0,0 +1,4 @@
+0
+FFFFFFFF
+FFFFFFFE
+FFFFFFFD