summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-stmt.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-20 10:17:00 -0800
committerGitHub <noreply@github.com>2023-02-20 10:17:00 -0800
commit8b05df4187117d61491f2fdbeb7d744146ad73f7 (patch)
treecfb17b26e9db313d0b6ce1a07efe85b35d6d1638 /source/slang/slang-check-stmt.cpp
parenta8da735ca4e0ed49796dda164c39e21aea4a7bc6 (diff)
Add static for loop iteration inference. (#2659)
Diffstat (limited to 'source/slang/slang-check-stmt.cpp')
-rw-r--r--source/slang/slang-check-stmt.cpp289
1 files changed, 288 insertions, 1 deletions
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 8049c1230..519ca91ff 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -1,5 +1,6 @@
// slang-check-stmt.cpp
#include "slang-check-impl.h"
+#include "slang-ir-util.h"
// This file implements semantic checking logic related to statements.
@@ -158,17 +159,20 @@ namespace Slang
void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt)
{
+ checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
subContext.checkStmt(stmt->statement);
+ checkLoopInDifferentiableFunc(stmt);
}
void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt)
{
WithOuterStmt subContext(this, stmt);
-
+ checkModifiers(stmt);
checkStmt(stmt->initialStatement);
+
if (stmt->predicateExpression)
{
stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression);
@@ -178,6 +182,10 @@ namespace Slang
stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression);
}
subContext.checkStmt(stmt->statement);
+
+ tryInferLoopMaxIterations(stmt);
+
+ checkLoopInDifferentiableFunc(stmt);
}
Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal)
@@ -317,9 +325,11 @@ namespace Slang
void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt)
{
+ checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
subContext.checkStmt(stmt->statement);
+ checkLoopInDifferentiableFunc(stmt);
}
void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt)
@@ -327,6 +337,283 @@ namespace Slang
stmt->expression = CheckExpr(stmt->expression);
}
+ void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt)
+ {
+ // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)`
+ // we will try to constant fold the operands and see if we can statically determine the maximum number of
+ // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt.
+ //
+ // ++, --, +=, -= are supported in side effect expressions.
+ // >, <, >=, <= are supported in predicate expressions.
+ // induction variable can appear in either side of the expressions.
+ //
+ // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here.
+ // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way.
+ //
+ DeclRef<Decl> predicateVar = {};
+ Expr* initialVal = nullptr;
+ DeclRef<Decl> initialVar = {};
+ if (auto varStmt = as<DeclStmt>(stmt->initialStatement))
+ {
+ auto varDecl = as<VarDecl>(varStmt->decl);
+ if (!varDecl)
+ return;
+ initialVar.decl = varDecl;
+ initialVal = varDecl->initExpr;
+ }
+ else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement))
+ {
+ auto assignExpr = as<AssignExpr>(exprStmt->expression);
+ if (!assignExpr)
+ return;
+ auto varExpr = as<VarExpr>(assignExpr->left);
+ if (!varExpr)
+ return;
+ initialVar = varExpr->declRef;
+ initialVal = assignExpr->right;
+ }
+ else
+ return;
+
+ auto initialLitVal =
+ as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, nullptr));
+
+ ConstantIntVal* finalVal = nullptr;
+ auto binaryExpr = as<InfixExpr>(stmt->predicateExpression);
+ if (!binaryExpr)
+ return;
+ auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr);
+ if (!compareFuncExpr)
+ return;
+ if (!compareFuncExpr->declRef.getDecl())
+ return;
+ IROp compareOp = kIROp_Nop;
+ if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>())
+ {
+ compareOp = (IROp)intrinsicOpModifier->op;
+ }
+ else
+ {
+ return;
+ }
+ if (binaryExpr->arguments.getCount() != 2)
+ return;
+ auto leftCompareOperand = binaryExpr->arguments[0];
+ auto rightCompareOperand = binaryExpr->arguments[1];
+ if (!leftCompareOperand)
+ return;
+ if (!rightCompareOperand)
+ return;
+ if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], nullptr))
+ {
+ auto leftVar = as<VarExpr>(leftCompareOperand);
+ if (!leftVar)
+ return;
+ predicateVar = leftVar->declRef;
+ finalVal = as<ConstantIntVal>(rightVal);
+ }
+ else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], nullptr))
+ {
+ auto rightVar = as<VarExpr>(rightCompareOperand);
+ if (!rightVar)
+ return;
+ predicateVar = rightVar->declRef;
+ finalVal = as<ConstantIntVal>(leftVal);
+ compareOp = getSwapSideComparisonOp(compareOp);
+ }
+ else
+ {
+ // If neither left or right is constant, we assume left is variable and continue checking.
+ if (auto leftVar = as<VarExpr>(leftCompareOperand))
+ {
+ predicateVar = leftVar->declRef;
+ }
+ if (auto rightVar = as<VarExpr>(rightCompareOperand))
+ {
+ if (rightVar->declRef == initialVar)
+ {
+ predicateVar = rightVar->declRef;
+ compareOp = getSwapSideComparisonOp(compareOp);
+ }
+ }
+ }
+
+ switch (compareOp)
+ {
+ case kIROp_Less:
+ case kIROp_Leq:
+ case kIROp_Greater:
+ case kIROp_Geq:
+ break;
+ default:
+ return;
+ }
+
+ ConstantIntVal* stepSize = nullptr;
+ IROp sideEffectFuncOp = kIROp_Nop;
+ auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression);
+ if (!opSideEffectExpr)
+ return;
+ auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr);
+ if (!sideEffectFuncExpr)
+ return;
+ auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl();
+ if (!sideEffectFuncDecl)
+ return;
+ if (auto opName = sideEffectFuncDecl->getName())
+ {
+ if (opName->text == "++")
+ sideEffectFuncOp = kIROp_Add;
+ else if (opName->text == "--")
+ sideEffectFuncOp = kIROp_Sub;
+ else if (opName->text == "+=")
+ sideEffectFuncOp = kIROp_Add;
+ else if (opName->text == "-=")
+ sideEffectFuncOp = kIROp_Sub;
+ else
+ return;
+ }
+ if (opSideEffectExpr->arguments.getCount())
+ {
+ auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]);
+ if (!varExpr)
+ return;
+ if (varExpr->declRef != initialVar)
+ {
+ // If the user writes something like `for (int i = 0; i < 5; j++)`,
+ // it is most likely a bug, so we issue a warning.
+ if (predicateVar == initialVar)
+ getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef);
+ return;
+ }
+ }
+ else
+ return;
+ if (opSideEffectExpr->arguments.getCount() == 2)
+ {
+ auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], nullptr);
+ if (!stepVal)
+ return;
+ if (auto constantIntVal = as<ConstantIntVal>(stepVal))
+ {
+ stepSize = constantIntVal;
+ }
+ }
+ else
+ {
+ stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ }
+
+ if (predicateVar != initialVar)
+ {
+ if (predicateVar)
+ getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar);
+ return;
+ }
+ if (!stepSize)
+ return;
+ if (stepSize->value > 0)
+ {
+ if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater ||
+ sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less)
+ {
+ getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar);
+ return;
+ }
+ }
+ else if (stepSize->value < 0)
+ {
+ if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less ||
+ sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater)
+ {
+ getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar);
+ return;
+ }
+ }
+ else
+ {
+ getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar);
+ return;
+ }
+
+ if (!initialLitVal || !finalVal)
+ return;
+
+ auto absStepSize = abs(stepSize->value);
+ int adjustment = 0;
+ if (compareOp == kIROp_Geq || compareOp == kIROp_Leq)
+ adjustment = 1;
+
+ auto iterations = (Math::Max(finalVal->value, initialLitVal->value) -
+ Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) /
+ absStepSize;
+ switch (compareOp)
+ {
+ case kIROp_Geq:
+ case kIROp_Greater:
+ // Expect final value to be less than initial value.
+ if (finalVal->value > initialLitVal->value)
+ iterations = 0;
+ break;
+ case kIROp_Leq:
+ case kIROp_Less:
+ if (finalVal->value < initialLitVal->value)
+ iterations = 0;
+ break;
+ }
+ if (iterations == 0)
+ {
+ getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations);
+ }
+
+ // Note: the inferred max iterations may not be valid if the loop body
+ // also modifies the induction variable.
+ // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute`
+ // if the loop body modifies the induction variable.
+ //
+ auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>();
+ auto litExpr = m_astBuilder->create<LiteralExpr>();
+ litExpr->type.type = m_astBuilder->getIntType();
+ litExpr->token.setName(getNamePool()->getName(String(iterations)));
+ maxItersAttr->args.add(litExpr);
+ maxItersAttr->intArgVals.Add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations));
+ maxItersAttr->value = (int32_t)iterations;
+ maxItersAttr->inductionVar = initialVar;
+ addModifier(stmt, maxItersAttr);
+ return;
+ }
+
+ void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt)
+ {
+ if (getParentDifferentiableAttribute())
+ {
+ if (!getParentFunc())
+ return;
+
+ // If the function is itself a derivative, or has a user defined derivative,
+ // then we don't require anything.
+
+ if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>())
+ return;
+ if (getParentFunc()->findModifier<ForwardDerivativeAttribute>())
+ return;
+ if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>())
+ return;
+ if (getParentFunc()->findModifier<BackwardDerivativeAttribute>())
+ return;
+
+ // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute,
+ // or a `[ForceUnroll]` attribet on loops.
+ if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>())
+ {
+ }
+ else
+ {
+ getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters);
+ }
+ }
+ }
+
void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt)
{
stmt->device = CheckExpr(stmt->device);