summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-20 10:17:00 -0800
committerGitHub <noreply@github.com>2023-02-20 10:17:00 -0800
commit8b05df4187117d61491f2fdbeb7d744146ad73f7 (patch)
treecfb17b26e9db313d0b6ce1a07efe85b35d6d1638 /source
parenta8da735ca4e0ed49796dda164c39e21aea4a7bc6 (diff)
Add static for loop iteration inference. (#2659)
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang6
-rw-r--r--source/slang/slang-ast-modifier.h8
-rw-r--r--source/slang/slang-check-impl.h5
-rw-r--r--source/slang/slang-check-stmt.cpp289
-rw-r--r--source/slang/slang-diagnostic-defs.h9
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp16
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-util.cpp21
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-lower-to-ir.cpp43
10 files changed, 396 insertions, 5 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index af06e6bac..a60a77cc3 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -270,6 +270,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
vector<TYPE, COUNT> result; \
vector<TYPE, COUNT>.Differential d_result; \
+ [ForceUnroll]\
for (int i = 0; i < N; ++i) \
{ \
DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \
@@ -281,6 +282,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \
vector<TYPE, COUNT> result; \
vector<TYPE, COUNT>.Differential d_result; \
+ [ForceUnroll] \
for (int i = 0; i < N; ++i) \
{ \
DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \
@@ -292,6 +294,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \
vector<TYPE, COUNT>.Differential d_result; \
+ [ForceUnroll] \
for (int i = 0; i < N; ++i) \
{ \
DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \
@@ -302,6 +305,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \
vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \
+ [ForceUnroll] \
for (int i = 0; i < N; ++i) \
{ \
DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \
@@ -705,6 +709,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
{
T result = T(0);
T.Differential d_result = T.dzero();
+ [ForceUnroll]
for (int i = 0; i < N; ++i)
{
result = result + dpx.p[i] * dpy.p[i];
@@ -719,6 +724,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int>
void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
{
vector<T, N>.Differential x_d_result, y_d_result;
+ [ForceUnroll]
for (int i = 0; i < N; ++i)
{
x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut);
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 4ab295da6..99e221b1e 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -635,6 +635,14 @@ class MaxItersAttribute : public Attribute
int32_t value = 0;
};
+// An inferred max iteration count on a loop.
+class InferredMaxItersAttribute : public Attribute
+{
+ SLANG_AST_CLASS(InferredMaxItersAttribute)
+ DeclRef<Decl> inductionVar;
+ int32_t value = 0;
+};
+
class LoopAttribute : public Attribute
{
SLANG_AST_CLASS(LoopAttribute)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ccc739da3..165c84192 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2023,6 +2023,11 @@ namespace Slang
void visitGpuForeachStmt(GpuForeachStmt *stmt);
void visitExpressionStmt(ExpressionStmt *stmt);
+
+ // Try to infer the max number of iterations the loop will run.
+ void tryInferLoopMaxIterations(ForStmt* stmt);
+
+ void checkLoopInDifferentiableFunc(Stmt* stmt);
};
struct SemanticsDeclVisitorBase
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 8049c1230..519ca91ff 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -1,5 +1,6 @@
// slang-check-stmt.cpp
#include "slang-check-impl.h"
+#include "slang-ir-util.h"
// This file implements semantic checking logic related to statements.
@@ -158,17 +159,20 @@ namespace Slang
void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt)
{
+ checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
subContext.checkStmt(stmt->statement);
+ checkLoopInDifferentiableFunc(stmt);
}
void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt)
{
WithOuterStmt subContext(this, stmt);
-
+ checkModifiers(stmt);
checkStmt(stmt->initialStatement);
+
if (stmt->predicateExpression)
{
stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression);
@@ -178,6 +182,10 @@ namespace Slang
stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression);
}
subContext.checkStmt(stmt->statement);
+
+ tryInferLoopMaxIterations(stmt);
+
+ checkLoopInDifferentiableFunc(stmt);
}
Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal)
@@ -317,9 +325,11 @@ namespace Slang
void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt)
{
+ checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
subContext.checkStmt(stmt->statement);
+ checkLoopInDifferentiableFunc(stmt);
}
void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt)
@@ -327,6 +337,283 @@ namespace Slang
stmt->expression = CheckExpr(stmt->expression);
}
+ void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt)
+ {
+ // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)`
+ // we will try to constant fold the operands and see if we can statically determine the maximum number of
+ // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt.
+ //
+ // ++, --, +=, -= are supported in side effect expressions.
+ // >, <, >=, <= are supported in predicate expressions.
+ // induction variable can appear in either side of the expressions.
+ //
+ // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here.
+ // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way.
+ //
+ DeclRef<Decl> predicateVar = {};
+ Expr* initialVal = nullptr;
+ DeclRef<Decl> initialVar = {};
+ if (auto varStmt = as<DeclStmt>(stmt->initialStatement))
+ {
+ auto varDecl = as<VarDecl>(varStmt->decl);
+ if (!varDecl)
+ return;
+ initialVar.decl = varDecl;
+ initialVal = varDecl->initExpr;
+ }
+ else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement))
+ {
+ auto assignExpr = as<AssignExpr>(exprStmt->expression);
+ if (!assignExpr)
+ return;
+ auto varExpr = as<VarExpr>(assignExpr->left);
+ if (!varExpr)
+ return;
+ initialVar = varExpr->declRef;
+ initialVal = assignExpr->right;
+ }
+ else
+ return;
+
+ auto initialLitVal =
+ as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, nullptr));
+
+ ConstantIntVal* finalVal = nullptr;
+ auto binaryExpr = as<InfixExpr>(stmt->predicateExpression);
+ if (!binaryExpr)
+ return;
+ auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr);
+ if (!compareFuncExpr)
+ return;
+ if (!compareFuncExpr->declRef.getDecl())
+ return;
+ IROp compareOp = kIROp_Nop;
+ if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>())
+ {
+ compareOp = (IROp)intrinsicOpModifier->op;
+ }
+ else
+ {
+ return;
+ }
+ if (binaryExpr->arguments.getCount() != 2)
+ return;
+ auto leftCompareOperand = binaryExpr->arguments[0];
+ auto rightCompareOperand = binaryExpr->arguments[1];
+ if (!leftCompareOperand)
+ return;
+ if (!rightCompareOperand)
+ return;
+ if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], nullptr))
+ {
+ auto leftVar = as<VarExpr>(leftCompareOperand);
+ if (!leftVar)
+ return;
+ predicateVar = leftVar->declRef;
+ finalVal = as<ConstantIntVal>(rightVal);
+ }
+ else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], nullptr))
+ {
+ auto rightVar = as<VarExpr>(rightCompareOperand);
+ if (!rightVar)
+ return;
+ predicateVar = rightVar->declRef;
+ finalVal = as<ConstantIntVal>(leftVal);
+ compareOp = getSwapSideComparisonOp(compareOp);
+ }
+ else
+ {
+ // If neither left or right is constant, we assume left is variable and continue checking.
+ if (auto leftVar = as<VarExpr>(leftCompareOperand))
+ {
+ predicateVar = leftVar->declRef;
+ }
+ if (auto rightVar = as<VarExpr>(rightCompareOperand))
+ {
+ if (rightVar->declRef == initialVar)
+ {
+ predicateVar = rightVar->declRef;
+ compareOp = getSwapSideComparisonOp(compareOp);
+ }
+ }
+ }
+
+ switch (compareOp)
+ {
+ case kIROp_Less:
+ case kIROp_Leq:
+ case kIROp_Greater:
+ case kIROp_Geq:
+ break;
+ default:
+ return;
+ }
+
+ ConstantIntVal* stepSize = nullptr;
+ IROp sideEffectFuncOp = kIROp_Nop;
+ auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression);
+ if (!opSideEffectExpr)
+ return;
+ auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr);
+ if (!sideEffectFuncExpr)
+ return;
+ auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl();
+ if (!sideEffectFuncDecl)
+ return;
+ if (auto opName = sideEffectFuncDecl->getName())
+ {
+ if (opName->text == "++")
+ sideEffectFuncOp = kIROp_Add;
+ else if (opName->text == "--")
+ sideEffectFuncOp = kIROp_Sub;
+ else if (opName->text == "+=")
+ sideEffectFuncOp = kIROp_Add;
+ else if (opName->text == "-=")
+ sideEffectFuncOp = kIROp_Sub;
+ else
+ return;
+ }
+ if (opSideEffectExpr->arguments.getCount())
+ {
+ auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]);
+ if (!varExpr)
+ return;
+ if (varExpr->declRef != initialVar)
+ {
+ // If the user writes something like `for (int i = 0; i < 5; j++)`,
+ // it is most likely a bug, so we issue a warning.
+ if (predicateVar == initialVar)
+ getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef);
+ return;
+ }
+ }
+ else
+ return;
+ if (opSideEffectExpr->arguments.getCount() == 2)
+ {
+ auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], nullptr);
+ if (!stepVal)
+ return;
+ if (auto constantIntVal = as<ConstantIntVal>(stepVal))
+ {
+ stepSize = constantIntVal;
+ }
+ }
+ else
+ {
+ stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ }
+
+ if (predicateVar != initialVar)
+ {
+ if (predicateVar)
+ getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar);
+ return;
+ }
+ if (!stepSize)
+ return;
+ if (stepSize->value > 0)
+ {
+ if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater ||
+ sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less)
+ {
+ getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar);
+ return;
+ }
+ }
+ else if (stepSize->value < 0)
+ {
+ if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less ||
+ sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater)
+ {
+ getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar);
+ return;
+ }
+ }
+ else
+ {
+ getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar);
+ return;
+ }
+
+ if (!initialLitVal || !finalVal)
+ return;
+
+ auto absStepSize = abs(stepSize->value);
+ int adjustment = 0;
+ if (compareOp == kIROp_Geq || compareOp == kIROp_Leq)
+ adjustment = 1;
+
+ auto iterations = (Math::Max(finalVal->value, initialLitVal->value) -
+ Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) /
+ absStepSize;
+ switch (compareOp)
+ {
+ case kIROp_Geq:
+ case kIROp_Greater:
+ // Expect final value to be less than initial value.
+ if (finalVal->value > initialLitVal->value)
+ iterations = 0;
+ break;
+ case kIROp_Leq:
+ case kIROp_Less:
+ if (finalVal->value < initialLitVal->value)
+ iterations = 0;
+ break;
+ }
+ if (iterations == 0)
+ {
+ getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations);
+ }
+
+ // Note: the inferred max iterations may not be valid if the loop body
+ // also modifies the induction variable.
+ // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute`
+ // if the loop body modifies the induction variable.
+ //
+ auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>();
+ auto litExpr = m_astBuilder->create<LiteralExpr>();
+ litExpr->type.type = m_astBuilder->getIntType();
+ litExpr->token.setName(getNamePool()->getName(String(iterations)));
+ maxItersAttr->args.add(litExpr);
+ maxItersAttr->intArgVals.Add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations));
+ maxItersAttr->value = (int32_t)iterations;
+ maxItersAttr->inductionVar = initialVar;
+ addModifier(stmt, maxItersAttr);
+ return;
+ }
+
+ void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt)
+ {
+ if (getParentDifferentiableAttribute())
+ {
+ if (!getParentFunc())
+ return;
+
+ // If the function is itself a derivative, or has a user defined derivative,
+ // then we don't require anything.
+
+ if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>())
+ return;
+ if (getParentFunc()->findModifier<ForwardDerivativeAttribute>())
+ return;
+ if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>())
+ return;
+ if (getParentFunc()->findModifier<BackwardDerivativeAttribute>())
+ return;
+
+ // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute,
+ // or a `[ForceUnroll]` attribet on loops.
+ if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>())
+ {
+ }
+ else
+ {
+ getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters);
+ }
+ }
+ }
+
void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt)
{
stmt->device = CheckExpr(stmt->device);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e29d7eeac..c5f7e6cbe 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -366,6 +366,15 @@ DIAGNOSTIC(30310, Error, typeIsNotDifferentiable, "type '$0' is not differentiab
// Interop
DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead")
+// Control flow
+DIAGNOSTIC(30500, Warning, forLoopSideEffectChangingDifferentVar, "the for loop initializes and checks variable '$0' but the side effect expression is modifying '$1'.")
+DIAGNOSTIC(30501, Warning, forLoopPredicateCheckingDifferentVar, "the for loop initializes and modifies variable '$0' but the predicate expression is checking '$1'.")
+DIAGNOSTIC(30502, Warning, forLoopChangingIterationVariableInOppsoiteDirection, "the for loop is modifiying variable '$0' in the opposite direction from loop exit condition.")
+DIAGNOSTIC(30503, Warning, forLoopNotModifyingIterationVariable, "the for loop is not modifiying variable '$0' because the step size evaluates to 0.")
+DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the for loop is statically determined to terminate within $0 iterations, which is less than what [MaxIters] specifies.")
+DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.")
+DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.")
+
// TODO: need to assign numbers to all these extra diagnostics...
DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.")
DIAGNOSTIC(39999, Error, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.")
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 98f2b2c34..c750b2d3d 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -346,6 +346,22 @@ public:
}
}
}
+
+ // Make sure all loops are marked with either [MaxIters] or [ForceUnroll].
+ for (auto block : funcInst->getBlocks())
+ {
+ auto loop = as<IRLoop>(block->getTerminator());
+ if (!loop)
+ continue;
+ if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>())
+ {
+ // We are good.
+ }
+ else
+ {
+ sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters);
+ }
+ }
}
void processModule()
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f2107aa62..4dea3985a 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -598,6 +598,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(LayoutDecoration, layout, 1, 0)
INST(LoopControlDecoration, loopControl, 1, 0)
INST(LoopMaxItersDecoration, loopMaxIters, 1, 0)
+ INST(LoopInferredMaxItersDecoration, loopInferredMaxIters, 2, 0)
INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0)
INST(IntrinsicOpDecoration, intrinsicOp, 1, 0)
/* TargetSpecificDecoration */
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 253686aa5..f3c4c2c82 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -466,6 +466,27 @@ IRInst* getUndefInst(IRBuilder builder, IRModule* module)
return undefInst;
}
+IROp getSwapSideComparisonOp(IROp op)
+{
+ switch (op)
+ {
+ case kIROp_Eql:
+ return kIROp_Eql;
+ case kIROp_Neq:
+ return kIROp_Neq;
+ case kIROp_Leq:
+ return kIROp_Geq;
+ case kIROp_Geq:
+ return kIROp_Leq;
+ case kIROp_Less:
+ return kIROp_Greater;
+ case kIROp_Greater:
+ return kIROp_Less;
+ default:
+ return kIROp_Nop;
+ }
+}
+
bool isPureFunctionalCall(IRCall* call)
{
auto callee = getResolvedInstForDecorations(call->getCallee());
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 2f1ac2d1a..0fb26f791 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -170,6 +170,9 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I
IRInst* getUndefInst(IRBuilder builder, IRModule* module);
+// The the equivalent op of (a op b) in (b op' a). For example, a > b is equivalent to b < a. So (<) ==> (>).
+IROp getSwapSideComparisonOp(IROp op);
+
}
#endif
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index aa2dc4efb..681871b6c 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4856,11 +4856,17 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
{
getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Loop);
}
- else if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() )
+
+ if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() )
{
getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value);
}
- else if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>())
+ else if (auto inferredMaxItersAttr = stmt->findModifier<InferredMaxItersAttribute>())
+ {
+ getBuilder()->addLoopMaxItersDecoration(inst, inferredMaxItersAttr->value);
+ }
+
+ if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>())
{
getBuilder()->addLoopForceUnrollDecoration(inst, forceUnrollAttr->maxIterations);
}
@@ -4901,8 +4907,6 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
breakLabel,
continueLabel);
- addLoopDecorations(loopInst, stmt);
-
insertBlock(loopHead);
// Now that we are within the header block, we
@@ -4923,6 +4927,37 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
insertBlock(bodyLabel);
lowerStmt(context, stmt->statement);
+ if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>())
+ {
+ // We only use inferred max iters attribute when the loop body
+ // does not modify induction var.
+ auto inductionVar = emitDeclRef(context, inferredMaxIters->inductionVar, builder->getIntType());
+ if (inductionVar.val)
+ {
+ int writes = 0;
+ traverseUsers(inductionVar.val, [&](IRInst* user) {if (user->getOp() != kIROp_Load) writes++; });
+ if (writes > 1)
+ {
+ removeModifier(stmt, inferredMaxIters);
+ }
+ }
+ }
+ if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>())
+ {
+ if (auto maxIters = stmt->findModifier<MaxItersAttribute>())
+ {
+ if (inferredMaxIters->value < maxIters->value)
+ {
+ context->getSink()->diagnose(
+ maxIters,
+ Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
+ inferredMaxIters->value);
+ }
+ }
+ }
+ addLoopDecorations(loopInst, stmt);
+
+
// Insert the `continue` block
insertBlock(continueLabel);
if (auto incrExpr = stmt->sideEffectExpression)