diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-20 10:17:00 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-20 10:17:00 -0800 |
| commit | 8b05df4187117d61491f2fdbeb7d744146ad73f7 (patch) | |
| tree | cfb17b26e9db313d0b6ce1a07efe85b35d6d1638 /source | |
| parent | a8da735ca4e0ed49796dda164c39e21aea4a7bc6 (diff) | |
Add static for loop iteration inference. (#2659)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 289 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 43 |
10 files changed, 396 insertions, 5 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index af06e6bac..a60a77cc3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -270,6 +270,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll]\ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \ @@ -281,6 +282,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \ @@ -292,6 +294,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \ @@ -302,6 +305,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \ vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \ @@ -705,6 +709,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair { T result = T(0); T.Differential d_result = T.dzero(); + [ForceUnroll] for (int i = 0; i < N; ++i) { result = result + dpx.p[i] * dpy.p[i]; @@ -719,6 +724,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int> void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) { vector<T, N>.Differential x_d_result, y_d_result; + [ForceUnroll] for (int i = 0; i < N; ++i) { x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 4ab295da6..99e221b1e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -635,6 +635,14 @@ class MaxItersAttribute : public Attribute int32_t value = 0; }; +// An inferred max iteration count on a loop. +class InferredMaxItersAttribute : public Attribute +{ + SLANG_AST_CLASS(InferredMaxItersAttribute) + DeclRef<Decl> inductionVar; + int32_t value = 0; +}; + class LoopAttribute : public Attribute { SLANG_AST_CLASS(LoopAttribute) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ccc739da3..165c84192 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2023,6 +2023,11 @@ namespace Slang void visitGpuForeachStmt(GpuForeachStmt *stmt); void visitExpressionStmt(ExpressionStmt *stmt); + + // Try to infer the max number of iterations the loop will run. + void tryInferLoopMaxIterations(ForStmt* stmt); + + void checkLoopInDifferentiableFunc(Stmt* stmt); }; struct SemanticsDeclVisitorBase diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 8049c1230..519ca91ff 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -1,5 +1,6 @@ // slang-check-stmt.cpp #include "slang-check-impl.h" +#include "slang-ir-util.h" // This file implements semantic checking logic related to statements. @@ -158,17 +159,20 @@ namespace Slang void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt) { + checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); } void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt) { WithOuterStmt subContext(this, stmt); - + checkModifiers(stmt); checkStmt(stmt->initialStatement); + if (stmt->predicateExpression) { stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression); @@ -178,6 +182,10 @@ namespace Slang stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression); } subContext.checkStmt(stmt->statement); + + tryInferLoopMaxIterations(stmt); + + checkLoopInDifferentiableFunc(stmt); } Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal) @@ -317,9 +325,11 @@ namespace Slang void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt) { + checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); } void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt) @@ -327,6 +337,283 @@ namespace Slang stmt->expression = CheckExpr(stmt->expression); } + void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt) + { + // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)` + // we will try to constant fold the operands and see if we can statically determine the maximum number of + // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt. + // + // ++, --, +=, -= are supported in side effect expressions. + // >, <, >=, <= are supported in predicate expressions. + // induction variable can appear in either side of the expressions. + // + // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here. + // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way. + // + DeclRef<Decl> predicateVar = {}; + Expr* initialVal = nullptr; + DeclRef<Decl> initialVar = {}; + if (auto varStmt = as<DeclStmt>(stmt->initialStatement)) + { + auto varDecl = as<VarDecl>(varStmt->decl); + if (!varDecl) + return; + initialVar.decl = varDecl; + initialVal = varDecl->initExpr; + } + else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement)) + { + auto assignExpr = as<AssignExpr>(exprStmt->expression); + if (!assignExpr) + return; + auto varExpr = as<VarExpr>(assignExpr->left); + if (!varExpr) + return; + initialVar = varExpr->declRef; + initialVal = assignExpr->right; + } + else + return; + + auto initialLitVal = + as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, nullptr)); + + ConstantIntVal* finalVal = nullptr; + auto binaryExpr = as<InfixExpr>(stmt->predicateExpression); + if (!binaryExpr) + return; + auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr); + if (!compareFuncExpr) + return; + if (!compareFuncExpr->declRef.getDecl()) + return; + IROp compareOp = kIROp_Nop; + if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>()) + { + compareOp = (IROp)intrinsicOpModifier->op; + } + else + { + return; + } + if (binaryExpr->arguments.getCount() != 2) + return; + auto leftCompareOperand = binaryExpr->arguments[0]; + auto rightCompareOperand = binaryExpr->arguments[1]; + if (!leftCompareOperand) + return; + if (!rightCompareOperand) + return; + if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], nullptr)) + { + auto leftVar = as<VarExpr>(leftCompareOperand); + if (!leftVar) + return; + predicateVar = leftVar->declRef; + finalVal = as<ConstantIntVal>(rightVal); + } + else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], nullptr)) + { + auto rightVar = as<VarExpr>(rightCompareOperand); + if (!rightVar) + return; + predicateVar = rightVar->declRef; + finalVal = as<ConstantIntVal>(leftVal); + compareOp = getSwapSideComparisonOp(compareOp); + } + else + { + // If neither left or right is constant, we assume left is variable and continue checking. + if (auto leftVar = as<VarExpr>(leftCompareOperand)) + { + predicateVar = leftVar->declRef; + } + if (auto rightVar = as<VarExpr>(rightCompareOperand)) + { + if (rightVar->declRef == initialVar) + { + predicateVar = rightVar->declRef; + compareOp = getSwapSideComparisonOp(compareOp); + } + } + } + + switch (compareOp) + { + case kIROp_Less: + case kIROp_Leq: + case kIROp_Greater: + case kIROp_Geq: + break; + default: + return; + } + + ConstantIntVal* stepSize = nullptr; + IROp sideEffectFuncOp = kIROp_Nop; + auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression); + if (!opSideEffectExpr) + return; + auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr); + if (!sideEffectFuncExpr) + return; + auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl(); + if (!sideEffectFuncDecl) + return; + if (auto opName = sideEffectFuncDecl->getName()) + { + if (opName->text == "++") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "--") + sideEffectFuncOp = kIROp_Sub; + else if (opName->text == "+=") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "-=") + sideEffectFuncOp = kIROp_Sub; + else + return; + } + if (opSideEffectExpr->arguments.getCount()) + { + auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]); + if (!varExpr) + return; + if (varExpr->declRef != initialVar) + { + // If the user writes something like `for (int i = 0; i < 5; j++)`, + // it is most likely a bug, so we issue a warning. + if (predicateVar == initialVar) + getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef); + return; + } + } + else + return; + if (opSideEffectExpr->arguments.getCount() == 2) + { + auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], nullptr); + if (!stepVal) + return; + if (auto constantIntVal = as<ConstantIntVal>(stepVal)) + { + stepSize = constantIntVal; + } + } + else + { + stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + } + + if (predicateVar != initialVar) + { + if (predicateVar) + getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar); + return; + } + if (!stepSize) + return; + if (stepSize->value > 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less) + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); + return; + } + } + else if (stepSize->value < 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater) + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); + return; + } + } + else + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar); + return; + } + + if (!initialLitVal || !finalVal) + return; + + auto absStepSize = abs(stepSize->value); + int adjustment = 0; + if (compareOp == kIROp_Geq || compareOp == kIROp_Leq) + adjustment = 1; + + auto iterations = (Math::Max(finalVal->value, initialLitVal->value) - + Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) / + absStepSize; + switch (compareOp) + { + case kIROp_Geq: + case kIROp_Greater: + // Expect final value to be less than initial value. + if (finalVal->value > initialLitVal->value) + iterations = 0; + break; + case kIROp_Leq: + case kIROp_Less: + if (finalVal->value < initialLitVal->value) + iterations = 0; + break; + } + if (iterations == 0) + { + getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations); + } + + // Note: the inferred max iterations may not be valid if the loop body + // also modifies the induction variable. + // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute` + // if the loop body modifies the induction variable. + // + auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>(); + auto litExpr = m_astBuilder->create<LiteralExpr>(); + litExpr->type.type = m_astBuilder->getIntType(); + litExpr->token.setName(getNamePool()->getName(String(iterations))); + maxItersAttr->args.add(litExpr); + maxItersAttr->intArgVals.Add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); + maxItersAttr->value = (int32_t)iterations; + maxItersAttr->inductionVar = initialVar; + addModifier(stmt, maxItersAttr); + return; + } + + void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) + { + if (getParentDifferentiableAttribute()) + { + if (!getParentFunc()) + return; + + // If the function is itself a derivative, or has a user defined derivative, + // then we don't require anything. + + if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>()) + return; + if (getParentFunc()->findModifier<ForwardDerivativeAttribute>()) + return; + if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>()) + return; + if (getParentFunc()->findModifier<BackwardDerivativeAttribute>()) + return; + + // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute, + // or a `[ForceUnroll]` attribet on loops. + if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>()) + { + } + else + { + getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); + } + } + } + void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt) { stmt->device = CheckExpr(stmt->device); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e29d7eeac..c5f7e6cbe 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -366,6 +366,15 @@ DIAGNOSTIC(30310, Error, typeIsNotDifferentiable, "type '$0' is not differentiab // Interop DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead") +// Control flow +DIAGNOSTIC(30500, Warning, forLoopSideEffectChangingDifferentVar, "the for loop initializes and checks variable '$0' but the side effect expression is modifying '$1'.") +DIAGNOSTIC(30501, Warning, forLoopPredicateCheckingDifferentVar, "the for loop initializes and modifies variable '$0' but the predicate expression is checking '$1'.") +DIAGNOSTIC(30502, Warning, forLoopChangingIterationVariableInOppsoiteDirection, "the for loop is modifiying variable '$0' in the opposite direction from loop exit condition.") +DIAGNOSTIC(30503, Warning, forLoopNotModifyingIterationVariable, "the for loop is not modifiying variable '$0' because the step size evaluates to 0.") +DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the for loop is statically determined to terminate within $0 iterations, which is less than what [MaxIters] specifies.") +DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.") +DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.") + // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") DIAGNOSTIC(39999, Error, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 98f2b2c34..c750b2d3d 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -346,6 +346,22 @@ public: } } } + + // Make sure all loops are marked with either [MaxIters] or [ForceUnroll]. + for (auto block : funcInst->getBlocks()) + { + auto loop = as<IRLoop>(block->getTerminator()); + if (!loop) + continue; + if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>()) + { + // We are good. + } + else + { + sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); + } + } } void processModule() diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f2107aa62..4dea3985a 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -598,6 +598,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LayoutDecoration, layout, 1, 0) INST(LoopControlDecoration, loopControl, 1, 0) INST(LoopMaxItersDecoration, loopMaxIters, 1, 0) + INST(LoopInferredMaxItersDecoration, loopInferredMaxIters, 2, 0) INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0) INST(IntrinsicOpDecoration, intrinsicOp, 1, 0) /* TargetSpecificDecoration */ diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 253686aa5..f3c4c2c82 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -466,6 +466,27 @@ IRInst* getUndefInst(IRBuilder builder, IRModule* module) return undefInst; } +IROp getSwapSideComparisonOp(IROp op) +{ + switch (op) + { + case kIROp_Eql: + return kIROp_Eql; + case kIROp_Neq: + return kIROp_Neq; + case kIROp_Leq: + return kIROp_Geq; + case kIROp_Geq: + return kIROp_Leq; + case kIROp_Less: + return kIROp_Greater; + case kIROp_Greater: + return kIROp_Less; + default: + return kIROp_Nop; + } +} + bool isPureFunctionalCall(IRCall* call) { auto callee = getResolvedInstForDecorations(call->getCallee()); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 2f1ac2d1a..0fb26f791 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -170,6 +170,9 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I IRInst* getUndefInst(IRBuilder builder, IRModule* module); +// The the equivalent op of (a op b) in (b op' a). For example, a > b is equivalent to b < a. So (<) ==> (>). +IROp getSwapSideComparisonOp(IROp op); + } #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index aa2dc4efb..681871b6c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4856,11 +4856,17 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> { getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Loop); } - else if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() ) + + if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() ) { getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value); } - else if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>()) + else if (auto inferredMaxItersAttr = stmt->findModifier<InferredMaxItersAttribute>()) + { + getBuilder()->addLoopMaxItersDecoration(inst, inferredMaxItersAttr->value); + } + + if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>()) { getBuilder()->addLoopForceUnrollDecoration(inst, forceUnrollAttr->maxIterations); } @@ -4901,8 +4907,6 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> breakLabel, continueLabel); - addLoopDecorations(loopInst, stmt); - insertBlock(loopHead); // Now that we are within the header block, we @@ -4923,6 +4927,37 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> insertBlock(bodyLabel); lowerStmt(context, stmt->statement); + if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>()) + { + // We only use inferred max iters attribute when the loop body + // does not modify induction var. + auto inductionVar = emitDeclRef(context, inferredMaxIters->inductionVar, builder->getIntType()); + if (inductionVar.val) + { + int writes = 0; + traverseUsers(inductionVar.val, [&](IRInst* user) {if (user->getOp() != kIROp_Load) writes++; }); + if (writes > 1) + { + removeModifier(stmt, inferredMaxIters); + } + } + } + if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>()) + { + if (auto maxIters = stmt->findModifier<MaxItersAttribute>()) + { + if (inferredMaxIters->value < maxIters->value) + { + context->getSink()->diagnose( + maxIters, + Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters, + inferredMaxIters->value); + } + } + } + addLoopDecorations(loopInst, stmt); + + // Insert the `continue` block insertBlock(continueLabel); if (auto incrExpr = stmt->sideEffectExpression) |
