diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-20 10:17:00 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-20 10:17:00 -0800 |
| commit | 8b05df4187117d61491f2fdbeb7d744146ad73f7 (patch) | |
| tree | cfb17b26e9db313d0b6ce1a07efe85b35d6d1638 /source/slang/slang-check-stmt.cpp | |
| parent | a8da735ca4e0ed49796dda164c39e21aea4a7bc6 (diff) | |
Add static for loop iteration inference. (#2659)
Diffstat (limited to 'source/slang/slang-check-stmt.cpp')
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 289 |
1 files changed, 288 insertions, 1 deletions
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 8049c1230..519ca91ff 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -1,5 +1,6 @@ // slang-check-stmt.cpp #include "slang-check-impl.h" +#include "slang-ir-util.h" // This file implements semantic checking logic related to statements. @@ -158,17 +159,20 @@ namespace Slang void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt) { + checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); } void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt) { WithOuterStmt subContext(this, stmt); - + checkModifiers(stmt); checkStmt(stmt->initialStatement); + if (stmt->predicateExpression) { stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression); @@ -178,6 +182,10 @@ namespace Slang stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression); } subContext.checkStmt(stmt->statement); + + tryInferLoopMaxIterations(stmt); + + checkLoopInDifferentiableFunc(stmt); } Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal) @@ -317,9 +325,11 @@ namespace Slang void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt) { + checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); } void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt) @@ -327,6 +337,283 @@ namespace Slang stmt->expression = CheckExpr(stmt->expression); } + void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt) + { + // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)` + // we will try to constant fold the operands and see if we can statically determine the maximum number of + // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt. + // + // ++, --, +=, -= are supported in side effect expressions. + // >, <, >=, <= are supported in predicate expressions. + // induction variable can appear in either side of the expressions. + // + // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here. + // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way. + // + DeclRef<Decl> predicateVar = {}; + Expr* initialVal = nullptr; + DeclRef<Decl> initialVar = {}; + if (auto varStmt = as<DeclStmt>(stmt->initialStatement)) + { + auto varDecl = as<VarDecl>(varStmt->decl); + if (!varDecl) + return; + initialVar.decl = varDecl; + initialVal = varDecl->initExpr; + } + else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement)) + { + auto assignExpr = as<AssignExpr>(exprStmt->expression); + if (!assignExpr) + return; + auto varExpr = as<VarExpr>(assignExpr->left); + if (!varExpr) + return; + initialVar = varExpr->declRef; + initialVal = assignExpr->right; + } + else + return; + + auto initialLitVal = + as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, nullptr)); + + ConstantIntVal* finalVal = nullptr; + auto binaryExpr = as<InfixExpr>(stmt->predicateExpression); + if (!binaryExpr) + return; + auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr); + if (!compareFuncExpr) + return; + if (!compareFuncExpr->declRef.getDecl()) + return; + IROp compareOp = kIROp_Nop; + if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>()) + { + compareOp = (IROp)intrinsicOpModifier->op; + } + else + { + return; + } + if (binaryExpr->arguments.getCount() != 2) + return; + auto leftCompareOperand = binaryExpr->arguments[0]; + auto rightCompareOperand = binaryExpr->arguments[1]; + if (!leftCompareOperand) + return; + if (!rightCompareOperand) + return; + if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], nullptr)) + { + auto leftVar = as<VarExpr>(leftCompareOperand); + if (!leftVar) + return; + predicateVar = leftVar->declRef; + finalVal = as<ConstantIntVal>(rightVal); + } + else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], nullptr)) + { + auto rightVar = as<VarExpr>(rightCompareOperand); + if (!rightVar) + return; + predicateVar = rightVar->declRef; + finalVal = as<ConstantIntVal>(leftVal); + compareOp = getSwapSideComparisonOp(compareOp); + } + else + { + // If neither left or right is constant, we assume left is variable and continue checking. + if (auto leftVar = as<VarExpr>(leftCompareOperand)) + { + predicateVar = leftVar->declRef; + } + if (auto rightVar = as<VarExpr>(rightCompareOperand)) + { + if (rightVar->declRef == initialVar) + { + predicateVar = rightVar->declRef; + compareOp = getSwapSideComparisonOp(compareOp); + } + } + } + + switch (compareOp) + { + case kIROp_Less: + case kIROp_Leq: + case kIROp_Greater: + case kIROp_Geq: + break; + default: + return; + } + + ConstantIntVal* stepSize = nullptr; + IROp sideEffectFuncOp = kIROp_Nop; + auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression); + if (!opSideEffectExpr) + return; + auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr); + if (!sideEffectFuncExpr) + return; + auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl(); + if (!sideEffectFuncDecl) + return; + if (auto opName = sideEffectFuncDecl->getName()) + { + if (opName->text == "++") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "--") + sideEffectFuncOp = kIROp_Sub; + else if (opName->text == "+=") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "-=") + sideEffectFuncOp = kIROp_Sub; + else + return; + } + if (opSideEffectExpr->arguments.getCount()) + { + auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]); + if (!varExpr) + return; + if (varExpr->declRef != initialVar) + { + // If the user writes something like `for (int i = 0; i < 5; j++)`, + // it is most likely a bug, so we issue a warning. + if (predicateVar == initialVar) + getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef); + return; + } + } + else + return; + if (opSideEffectExpr->arguments.getCount() == 2) + { + auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], nullptr); + if (!stepVal) + return; + if (auto constantIntVal = as<ConstantIntVal>(stepVal)) + { + stepSize = constantIntVal; + } + } + else + { + stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + } + + if (predicateVar != initialVar) + { + if (predicateVar) + getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar); + return; + } + if (!stepSize) + return; + if (stepSize->value > 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less) + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); + return; + } + } + else if (stepSize->value < 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater) + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); + return; + } + } + else + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar); + return; + } + + if (!initialLitVal || !finalVal) + return; + + auto absStepSize = abs(stepSize->value); + int adjustment = 0; + if (compareOp == kIROp_Geq || compareOp == kIROp_Leq) + adjustment = 1; + + auto iterations = (Math::Max(finalVal->value, initialLitVal->value) - + Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) / + absStepSize; + switch (compareOp) + { + case kIROp_Geq: + case kIROp_Greater: + // Expect final value to be less than initial value. + if (finalVal->value > initialLitVal->value) + iterations = 0; + break; + case kIROp_Leq: + case kIROp_Less: + if (finalVal->value < initialLitVal->value) + iterations = 0; + break; + } + if (iterations == 0) + { + getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations); + } + + // Note: the inferred max iterations may not be valid if the loop body + // also modifies the induction variable. + // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute` + // if the loop body modifies the induction variable. + // + auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>(); + auto litExpr = m_astBuilder->create<LiteralExpr>(); + litExpr->type.type = m_astBuilder->getIntType(); + litExpr->token.setName(getNamePool()->getName(String(iterations))); + maxItersAttr->args.add(litExpr); + maxItersAttr->intArgVals.Add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); + maxItersAttr->value = (int32_t)iterations; + maxItersAttr->inductionVar = initialVar; + addModifier(stmt, maxItersAttr); + return; + } + + void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) + { + if (getParentDifferentiableAttribute()) + { + if (!getParentFunc()) + return; + + // If the function is itself a derivative, or has a user defined derivative, + // then we don't require anything. + + if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>()) + return; + if (getParentFunc()->findModifier<ForwardDerivativeAttribute>()) + return; + if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>()) + return; + if (getParentFunc()->findModifier<BackwardDerivativeAttribute>()) + return; + + // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute, + // or a `[ForceUnroll]` attribet on loops. + if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>()) + { + } + else + { + getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); + } + } + } + void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt) { stmt->device = CheckExpr(stmt->device); |
