From 0298a0427bbfe19700169c4e239a1b9e91baa410 Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Fri, 17 Nov 2017 07:09:58 -0800 Subject: IR: support `CompileTimeForStmt` (#286) This statement type is a bit of a hack, to support loops that *must* be unrolled. The AST-to-AST pass handles them by cloning the AST for the loop body N times, and it was easy enough to do the same thing for the IR: emit the instructions for the body N times. The only thing that requires a bit of care is that now we might see the same variable declarations multiple times, so we need to play it safe and overwrite existing entries in our map from declarations to their IR values. Of course a better answer long-term would be to do the actual unrolling in the IR. This is especially true because we might some day want to support compile-time/must-unroll loops in functions, where the loop counter comes in as a parameter (but must still be compile-time-constant at every call site). --- source/slang/lower-to-ir.cpp | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) (limited to 'source') diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 827504122..326d25649 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1637,9 +1637,37 @@ struct StmtLoweringVisitor : StmtVisitor SLANG_UNEXPECTED("`case` or `default` not under `switch`"); } - void visitCompileTimeForStmt(CompileTimeForStmt*) + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) { - SLANG_UNIMPLEMENTED_X("IR lowering of CompileTimeForStmt"); + // The user is asking us to emit code for the loop + // body for each value in the given integer range. + // For now, we will handle this by repeatedly lowering + // the body statement, with the loop variable bound + // to a different integer literal value each time. + // + // TODO: eventually we might handle this as just an + // ordinary loop, with an `[unroll]` attribute on + // it that we would respect. + + auto rangeBeginVal = GetIntVal(stmt->rangeBeginVal); + auto rangeEndVal = GetIntVal(stmt->rangeEndVal); + + if (rangeBeginVal >= rangeEndVal) + return; + + auto varDecl = stmt->varDecl; + auto varType = varDecl->type; + + for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) + { + auto constVal = getBuilder()->getIntValue( + varType, + ii); + + context->shared->declValues[varDecl] = LoweredValInfo::simple(constVal); + + lowerStmt(context, stmt->body); + } } // Create a basic block in the current function, @@ -2590,9 +2618,8 @@ struct DeclLoweringVisitor : DeclVisitor // A global variable's SSA value is a *pointer* to // the underlying storage. auto globalVal = LoweredValInfo::ptr(irGlobal); - context->shared->declValues.Add( - DeclRef(decl, nullptr), - globalVal); + context->shared->declValues[ + DeclRef(decl, nullptr)] = globalVal; if( auto initExpr = decl->initExpr ) { @@ -2667,9 +2694,8 @@ struct DeclLoweringVisitor : DeclVisitor assign(context, varVal, initVal); } - context->shared->declValues.Add( - DeclRef(decl, nullptr), - varVal); + context->shared->declValues[ + DeclRef(decl, nullptr)] = varVal; return varVal; } @@ -3214,7 +3240,7 @@ struct DeclLoweringVisitor : DeclVisitor if( auto paramDecl = paramInfo.decl ) { DeclRef paramDeclRef = makeDeclRef(paramDecl); - subContext->shared->declValues.Add(paramDeclRef, paramVal); + subContext->shared->declValues[paramDeclRef] = paramVal; } if (paramInfo.isThisParam) -- cgit v1.2.3