diff options
| -rw-r--r-- | source/slang/diff.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 289 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 43 | ||||
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 10 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop.slang | 1 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow.slang | 5 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow.slang.expected | 3 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff.slang | 15 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff.slang.expected | 9 | ||||
| -rw-r--r-- | tests/diagnostics/for-loop-warning.slang | 60 | ||||
| -rw-r--r-- | tests/diagnostics/for-loop-warning.slang.expected | 35 |
18 files changed, 529 insertions, 10 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index af06e6bac..a60a77cc3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -270,6 +270,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll]\ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \ @@ -281,6 +282,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \ @@ -292,6 +294,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \ @@ -302,6 +305,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \ vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \ @@ -705,6 +709,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair { T result = T(0); T.Differential d_result = T.dzero(); + [ForceUnroll] for (int i = 0; i < N; ++i) { result = result + dpx.p[i] * dpy.p[i]; @@ -719,6 +724,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int> void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) { vector<T, N>.Differential x_d_result, y_d_result; + [ForceUnroll] for (int i = 0; i < N; ++i) { x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 4ab295da6..99e221b1e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -635,6 +635,14 @@ class MaxItersAttribute : public Attribute int32_t value = 0; }; +// An inferred max iteration count on a loop. +class InferredMaxItersAttribute : public Attribute +{ + SLANG_AST_CLASS(InferredMaxItersAttribute) + DeclRef<Decl> inductionVar; + int32_t value = 0; +}; + class LoopAttribute : public Attribute { SLANG_AST_CLASS(LoopAttribute) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ccc739da3..165c84192 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2023,6 +2023,11 @@ namespace Slang void visitGpuForeachStmt(GpuForeachStmt *stmt); void visitExpressionStmt(ExpressionStmt *stmt); + + // Try to infer the max number of iterations the loop will run. + void tryInferLoopMaxIterations(ForStmt* stmt); + + void checkLoopInDifferentiableFunc(Stmt* stmt); }; struct SemanticsDeclVisitorBase diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 8049c1230..519ca91ff 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -1,5 +1,6 @@ // slang-check-stmt.cpp #include "slang-check-impl.h" +#include "slang-ir-util.h" // This file implements semantic checking logic related to statements. @@ -158,17 +159,20 @@ namespace Slang void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt) { + checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); } void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt) { WithOuterStmt subContext(this, stmt); - + checkModifiers(stmt); checkStmt(stmt->initialStatement); + if (stmt->predicateExpression) { stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression); @@ -178,6 +182,10 @@ namespace Slang stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression); } subContext.checkStmt(stmt->statement); + + tryInferLoopMaxIterations(stmt); + + checkLoopInDifferentiableFunc(stmt); } Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal) @@ -317,9 +325,11 @@ namespace Slang void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt) { + checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); } void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt) @@ -327,6 +337,283 @@ namespace Slang stmt->expression = CheckExpr(stmt->expression); } + void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt) + { + // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)` + // we will try to constant fold the operands and see if we can statically determine the maximum number of + // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt. + // + // ++, --, +=, -= are supported in side effect expressions. + // >, <, >=, <= are supported in predicate expressions. + // induction variable can appear in either side of the expressions. + // + // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here. + // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way. + // + DeclRef<Decl> predicateVar = {}; + Expr* initialVal = nullptr; + DeclRef<Decl> initialVar = {}; + if (auto varStmt = as<DeclStmt>(stmt->initialStatement)) + { + auto varDecl = as<VarDecl>(varStmt->decl); + if (!varDecl) + return; + initialVar.decl = varDecl; + initialVal = varDecl->initExpr; + } + else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement)) + { + auto assignExpr = as<AssignExpr>(exprStmt->expression); + if (!assignExpr) + return; + auto varExpr = as<VarExpr>(assignExpr->left); + if (!varExpr) + return; + initialVar = varExpr->declRef; + initialVal = assignExpr->right; + } + else + return; + + auto initialLitVal = + as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, nullptr)); + + ConstantIntVal* finalVal = nullptr; + auto binaryExpr = as<InfixExpr>(stmt->predicateExpression); + if (!binaryExpr) + return; + auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr); + if (!compareFuncExpr) + return; + if (!compareFuncExpr->declRef.getDecl()) + return; + IROp compareOp = kIROp_Nop; + if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>()) + { + compareOp = (IROp)intrinsicOpModifier->op; + } + else + { + return; + } + if (binaryExpr->arguments.getCount() != 2) + return; + auto leftCompareOperand = binaryExpr->arguments[0]; + auto rightCompareOperand = binaryExpr->arguments[1]; + if (!leftCompareOperand) + return; + if (!rightCompareOperand) + return; + if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], nullptr)) + { + auto leftVar = as<VarExpr>(leftCompareOperand); + if (!leftVar) + return; + predicateVar = leftVar->declRef; + finalVal = as<ConstantIntVal>(rightVal); + } + else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], nullptr)) + { + auto rightVar = as<VarExpr>(rightCompareOperand); + if (!rightVar) + return; + predicateVar = rightVar->declRef; + finalVal = as<ConstantIntVal>(leftVal); + compareOp = getSwapSideComparisonOp(compareOp); + } + else + { + // If neither left or right is constant, we assume left is variable and continue checking. + if (auto leftVar = as<VarExpr>(leftCompareOperand)) + { + predicateVar = leftVar->declRef; + } + if (auto rightVar = as<VarExpr>(rightCompareOperand)) + { + if (rightVar->declRef == initialVar) + { + predicateVar = rightVar->declRef; + compareOp = getSwapSideComparisonOp(compareOp); + } + } + } + + switch (compareOp) + { + case kIROp_Less: + case kIROp_Leq: + case kIROp_Greater: + case kIROp_Geq: + break; + default: + return; + } + + ConstantIntVal* stepSize = nullptr; + IROp sideEffectFuncOp = kIROp_Nop; + auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression); + if (!opSideEffectExpr) + return; + auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr); + if (!sideEffectFuncExpr) + return; + auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl(); + if (!sideEffectFuncDecl) + return; + if (auto opName = sideEffectFuncDecl->getName()) + { + if (opName->text == "++") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "--") + sideEffectFuncOp = kIROp_Sub; + else if (opName->text == "+=") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "-=") + sideEffectFuncOp = kIROp_Sub; + else + return; + } + if (opSideEffectExpr->arguments.getCount()) + { + auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]); + if (!varExpr) + return; + if (varExpr->declRef != initialVar) + { + // If the user writes something like `for (int i = 0; i < 5; j++)`, + // it is most likely a bug, so we issue a warning. + if (predicateVar == initialVar) + getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef); + return; + } + } + else + return; + if (opSideEffectExpr->arguments.getCount() == 2) + { + auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], nullptr); + if (!stepVal) + return; + if (auto constantIntVal = as<ConstantIntVal>(stepVal)) + { + stepSize = constantIntVal; + } + } + else + { + stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + } + + if (predicateVar != initialVar) + { + if (predicateVar) + getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar); + return; + } + if (!stepSize) + return; + if (stepSize->value > 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less) + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); + return; + } + } + else if (stepSize->value < 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater) + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); + return; + } + } + else + { + getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar); + return; + } + + if (!initialLitVal || !finalVal) + return; + + auto absStepSize = abs(stepSize->value); + int adjustment = 0; + if (compareOp == kIROp_Geq || compareOp == kIROp_Leq) + adjustment = 1; + + auto iterations = (Math::Max(finalVal->value, initialLitVal->value) - + Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) / + absStepSize; + switch (compareOp) + { + case kIROp_Geq: + case kIROp_Greater: + // Expect final value to be less than initial value. + if (finalVal->value > initialLitVal->value) + iterations = 0; + break; + case kIROp_Leq: + case kIROp_Less: + if (finalVal->value < initialLitVal->value) + iterations = 0; + break; + } + if (iterations == 0) + { + getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations); + } + + // Note: the inferred max iterations may not be valid if the loop body + // also modifies the induction variable. + // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute` + // if the loop body modifies the induction variable. + // + auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>(); + auto litExpr = m_astBuilder->create<LiteralExpr>(); + litExpr->type.type = m_astBuilder->getIntType(); + litExpr->token.setName(getNamePool()->getName(String(iterations))); + maxItersAttr->args.add(litExpr); + maxItersAttr->intArgVals.Add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); + maxItersAttr->value = (int32_t)iterations; + maxItersAttr->inductionVar = initialVar; + addModifier(stmt, maxItersAttr); + return; + } + + void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) + { + if (getParentDifferentiableAttribute()) + { + if (!getParentFunc()) + return; + + // If the function is itself a derivative, or has a user defined derivative, + // then we don't require anything. + + if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>()) + return; + if (getParentFunc()->findModifier<ForwardDerivativeAttribute>()) + return; + if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>()) + return; + if (getParentFunc()->findModifier<BackwardDerivativeAttribute>()) + return; + + // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute, + // or a `[ForceUnroll]` attribet on loops. + if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>()) + { + } + else + { + getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); + } + } + } + void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt) { stmt->device = CheckExpr(stmt->device); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e29d7eeac..c5f7e6cbe 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -366,6 +366,15 @@ DIAGNOSTIC(30310, Error, typeIsNotDifferentiable, "type '$0' is not differentiab // Interop DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead") +// Control flow +DIAGNOSTIC(30500, Warning, forLoopSideEffectChangingDifferentVar, "the for loop initializes and checks variable '$0' but the side effect expression is modifying '$1'.") +DIAGNOSTIC(30501, Warning, forLoopPredicateCheckingDifferentVar, "the for loop initializes and modifies variable '$0' but the predicate expression is checking '$1'.") +DIAGNOSTIC(30502, Warning, forLoopChangingIterationVariableInOppsoiteDirection, "the for loop is modifiying variable '$0' in the opposite direction from loop exit condition.") +DIAGNOSTIC(30503, Warning, forLoopNotModifyingIterationVariable, "the for loop is not modifiying variable '$0' because the step size evaluates to 0.") +DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the for loop is statically determined to terminate within $0 iterations, which is less than what [MaxIters] specifies.") +DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.") +DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.") + // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") DIAGNOSTIC(39999, Error, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 98f2b2c34..c750b2d3d 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -346,6 +346,22 @@ public: } } } + + // Make sure all loops are marked with either [MaxIters] or [ForceUnroll]. + for (auto block : funcInst->getBlocks()) + { + auto loop = as<IRLoop>(block->getTerminator()); + if (!loop) + continue; + if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>()) + { + // We are good. + } + else + { + sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); + } + } } void processModule() diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f2107aa62..4dea3985a 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -598,6 +598,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LayoutDecoration, layout, 1, 0) INST(LoopControlDecoration, loopControl, 1, 0) INST(LoopMaxItersDecoration, loopMaxIters, 1, 0) + INST(LoopInferredMaxItersDecoration, loopInferredMaxIters, 2, 0) INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0) INST(IntrinsicOpDecoration, intrinsicOp, 1, 0) /* TargetSpecificDecoration */ diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 253686aa5..f3c4c2c82 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -466,6 +466,27 @@ IRInst* getUndefInst(IRBuilder builder, IRModule* module) return undefInst; } +IROp getSwapSideComparisonOp(IROp op) +{ + switch (op) + { + case kIROp_Eql: + return kIROp_Eql; + case kIROp_Neq: + return kIROp_Neq; + case kIROp_Leq: + return kIROp_Geq; + case kIROp_Geq: + return kIROp_Leq; + case kIROp_Less: + return kIROp_Greater; + case kIROp_Greater: + return kIROp_Less; + default: + return kIROp_Nop; + } +} + bool isPureFunctionalCall(IRCall* call) { auto callee = getResolvedInstForDecorations(call->getCallee()); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 2f1ac2d1a..0fb26f791 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -170,6 +170,9 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I IRInst* getUndefInst(IRBuilder builder, IRModule* module); +// The the equivalent op of (a op b) in (b op' a). For example, a > b is equivalent to b < a. So (<) ==> (>). +IROp getSwapSideComparisonOp(IROp op); + } #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index aa2dc4efb..681871b6c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4856,11 +4856,17 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> { getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Loop); } - else if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() ) + + if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() ) { getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value); } - else if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>()) + else if (auto inferredMaxItersAttr = stmt->findModifier<InferredMaxItersAttribute>()) + { + getBuilder()->addLoopMaxItersDecoration(inst, inferredMaxItersAttr->value); + } + + if (auto forceUnrollAttr = stmt->findModifier<ForceUnrollAttribute>()) { getBuilder()->addLoopForceUnrollDecoration(inst, forceUnrollAttr->maxIterations); } @@ -4901,8 +4907,6 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> breakLabel, continueLabel); - addLoopDecorations(loopInst, stmt); - insertBlock(loopHead); // Now that we are within the header block, we @@ -4923,6 +4927,37 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> insertBlock(bodyLabel); lowerStmt(context, stmt->statement); + if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>()) + { + // We only use inferred max iters attribute when the loop body + // does not modify induction var. + auto inductionVar = emitDeclRef(context, inferredMaxIters->inductionVar, builder->getIntType()); + if (inductionVar.val) + { + int writes = 0; + traverseUsers(inductionVar.val, [&](IRInst* user) {if (user->getOp() != kIROp_Load) writes++; }); + if (writes > 1) + { + removeModifier(stmt, inferredMaxIters); + } + } + } + if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>()) + { + if (auto maxIters = stmt->findModifier<MaxItersAttribute>()) + { + if (inferredMaxIters->value < maxIters->value) + { + context->getSink()->diagnose( + maxIters, + Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters, + inferredMaxIters->value); + } + } + } + addLoopDecorations(loopInst, stmt); + + // Insert the `continue` block insertBlock(continueLabel); if (auto incrExpr = stmt->sideEffectExpression) diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 332833fff..98adc4a7c 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -24,6 +24,7 @@ struct myvector : IDifferentiable __init(T c) { + [ForceUnroll] for (int i = 0; i < N; i++) { values[i] = c; @@ -46,7 +47,7 @@ struct myvector : IDifferentiable static Differential dmul(This a, Differential b) { Differential output; - + for (int i = 0; i < N; i++) { output.values[i] = T.dmul(a.values[i], b.values[i]); @@ -73,6 +74,7 @@ __generic<T : IDFloat, let N : int> myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b) { myvector<T, N> output; + [ForceUnroll] for (int i = 0; i < N; i++) { output.values[i] = a.values[i] + b.values[i]; @@ -85,6 +87,7 @@ __generic<T : IDFloat, let N : int> myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b) { myvector<T, N> output; + [ForceUnroll] for (int i = 0; i < N; i++) { output.values[i] = a.values[i] * b.values[i]; @@ -97,6 +100,7 @@ __generic<T : IDFloat, let N : int> myvector<T, N> operator *(T a, myvector<T, N> b) { myvector<T, N> output; + [ForceUnroll] for (int i = 0; i < N; i++) { output.values[i] = a * b.values[i]; @@ -109,6 +113,7 @@ __generic<T : IDFloat, let N : int> T dot(myvector<T, N> a, myvector<T, N> b) { T curr = (T)0.0; + [ForceUnroll] for (int i = 0; i < N; i++) { curr = curr + (a.values[i] * b.values[i]); @@ -125,6 +130,7 @@ DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) { T.Differential curr_d = (T.dzero()); T curr_p = (T)0.0; + [ForceUnroll] for (int i = 0; i < N; i++) { curr_p = curr_p + (a.p.values[i] * b.p.values[i]); @@ -145,6 +151,7 @@ struct lineardvector : IDifferentiable __init(vector<Real.Differential, N> a) { + [ForceUnroll] for (int i = 0; i < N; i++) { val.values[i] = a[i]; @@ -204,6 +211,7 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable [ForwardDifferentiable] __init(vector<Real, N> a) { + [ForceUnroll] for (int i = 0; i < N; i++) { val.values[i] = a[i]; diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang index 828a06185..5598f6b71 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop.slang @@ -13,7 +13,6 @@ float test_simple_loop(float y) { float t = y; - [MaxIters(3)] for (int i = 0; i < 3; i++) { t = t * t; diff --git a/tests/diagnostics/autodiff-data-flow.slang b/tests/diagnostics/autodiff-data-flow.slang index 93c76c07e..e8d9502e4 100644 --- a/tests/diagnostics/autodiff-data-flow.slang +++ b/tests/diagnostics/autodiff-data-flow.slang @@ -24,6 +24,11 @@ void g(float x) float val = 0; if (x > 5) val = x + 1; + + for (int i = 0; i < 5; i++) // Not ok, we can't infer the loop iterations because the body modifies induction var. + { + i = (int)x; + } return; } diff --git a/tests/diagnostics/autodiff-data-flow.slang.expected b/tests/diagnostics/autodiff-data-flow.slang.expected index 290ef974b..301f84985 100644 --- a/tests/diagnostics/autodiff-data-flow.slang.expected +++ b/tests/diagnostics/autodiff-data-flow.slang.expected @@ -6,6 +6,9 @@ tests/diagnostics/autodiff-data-flow.slang(15): error 41020: derivative cannot b tests/diagnostics/autodiff-data-flow.slang(22): error 41021: a differentiable function must have at least one differentiable output. void g(float x) ^ +tests/diagnostics/autodiff-data-flow.slang(28): error 30510: loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute. + for (int i = 0; i < 5; i++) // Not ok, we can't infer the loop iterations because the body modifies induction var. + ^~~ } standard output = { } diff --git a/tests/diagnostics/autodiff.slang b/tests/diagnostics/autodiff.slang index f9fed6753..935ef07cb 100644 --- a/tests/diagnostics/autodiff.slang +++ b/tests/diagnostics/autodiff.slang @@ -11,6 +11,21 @@ float f(float x) float val = 0; if (x > 5) val = x + 1; + + // warning: dynamic loop without [MaxIters] or [ForceUnroll] + for (int i = 0; i < (int)x; i++) + { + } + + [MaxIters(2)] + for (int i = 0; i < (int)x; i++) // OK + { + } + + for (int i = 0; i < 5; i++) // OK + { + } + return val; } diff --git a/tests/diagnostics/autodiff.slang.expected b/tests/diagnostics/autodiff.slang.expected index cd97bce76..952503d1c 100644 --- a/tests/diagnostics/autodiff.slang.expected +++ b/tests/diagnostics/autodiff.slang.expected @@ -1,12 +1,15 @@ result code = -1 standard error = { -tests/diagnostics/autodiff.slang(20): error 38031: 'no_diff' can only be used to decorate a call. +tests/diagnostics/autodiff.slang(16): error 30510: loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute. + for (int i = 0; i < (int)x; i++) + ^~~ +tests/diagnostics/autodiff.slang(35): error 38031: 'no_diff' can only be used to decorate a call. float x1 = no_diff x; // invalid use of no_diff here. ^~~~~~~ -tests/diagnostics/autodiff.slang(21): error 38032: use 'no_diff' on a call to a differentiable function has no meaning. +tests/diagnostics/autodiff.slang(36): error 38032: use 'no_diff' on a call to a differentiable function has no meaning. return no_diff f(x); // no_diff on a differentiable call has no meaning. ^~~~~~~ -tests/diagnostics/autodiff.slang(26): error 38033: cannot use 'no_diff' in a non-differentiable function. +tests/diagnostics/autodiff.slang(41): error 38033: cannot use 'no_diff' in a non-differentiable function. return no_diff nonDiff(x); // no_diff in a non-differentiable function ^~~~~~~ } diff --git a/tests/diagnostics/for-loop-warning.slang b/tests/diagnostics/for-loop-warning.slang new file mode 100644 index 000000000..226af46f5 --- /dev/null +++ b/tests/diagnostics/for-loop-warning.slang @@ -0,0 +1,60 @@ +//DIAGNOSTIC_TEST:SIMPLE: + + +float doSomething(int x) +{ + for (int i = 0; i < x; i--) // warn. + {} + for (int i = 0; i < 5; i-=-2) // ok. + {} + for (int j = 0; j < 3; j += 0) // warn. + {} + for (int i = 0; i < 5; i++) // ok + { + for (int j = 0; j < 3; i++) // warn. + {} + } + for (int i = 0; i < 5; i++) // ok + { + for (int j = 0; i < 4; j++) // warn. + {} + } + + [MaxIters(6)] // warn + for (int i = 0; i <= 6; i+=3) + { + } + + [MaxIters(6)] // warn + for (int i = 5; i >= 0; i -= 3) + { + } + [MaxIters(6)] // warn + for (int i = 5; i > 0; i--) + { + } + + [MaxIters(5)] // ok + for (int i = 0; i < 5; i++) // ok + { + } + + for (int i = 1; i < 0; i++) // warn + { + } + for (int i = 1; i >= 2; i--) // warn + { + } + for (int i = 1; i >= 1; i--) // ok + { + } + for (int i = 1; i > 1; i--) // warn + { + } + [MaxIters(5)] // ok, because the loop body modifies i so we can't infer the iterations. + for (int i = 0; i < 5; i+=2) + { + i--; + } + return 0.0; +} diff --git a/tests/diagnostics/for-loop-warning.slang.expected b/tests/diagnostics/for-loop-warning.slang.expected new file mode 100644 index 000000000..e37abb035 --- /dev/null +++ b/tests/diagnostics/for-loop-warning.slang.expected @@ -0,0 +1,35 @@ +result code = 0 +standard error = { +tests/diagnostics/for-loop-warning.slang(6): warning 30502: the for loop is modifiying variable 'i' in the opposite direction from loop exit condition. + for (int i = 0; i < x; i--) // warn. + ^~ +tests/diagnostics/for-loop-warning.slang(10): warning 30503: the for loop is not modifiying variable 'j' because the step size evaluates to 0. + for (int j = 0; j < 3; j += 0) // warn. + ^~ +tests/diagnostics/for-loop-warning.slang(14): warning 30500: the for loop initializes and checks variable 'j' but the side effect expression is modifying 'i'. + for (int j = 0; j < 3; i++) // warn. + ^ +tests/diagnostics/for-loop-warning.slang(19): warning 30501: the for loop initializes and modifies variable 'j' but the predicate expression is checking 'i'. + for (int j = 0; i < 4; j++) // warn. + ^ +tests/diagnostics/for-loop-warning.slang(42): warning 30505: the loop runs for 0 iterations and will be removed. + for (int i = 1; i < 0; i++) // warn + ^~~ +tests/diagnostics/for-loop-warning.slang(45): warning 30505: the loop runs for 0 iterations and will be removed. + for (int i = 1; i >= 2; i--) // warn + ^~~ +tests/diagnostics/for-loop-warning.slang(51): warning 30505: the loop runs for 0 iterations and will be removed. + for (int i = 1; i > 1; i--) // warn + ^~~ +tests/diagnostics/for-loop-warning.slang(23): warning 30504: the for loop is statically determined to terminate within 3 iterations, which is less than what [MaxIters] specifies. + [MaxIters(6)] // warn + ^~~~~~~~ +tests/diagnostics/for-loop-warning.slang(28): warning 30504: the for loop is statically determined to terminate within 2 iterations, which is less than what [MaxIters] specifies. + [MaxIters(6)] // warn + ^~~~~~~~ +tests/diagnostics/for-loop-warning.slang(32): warning 30504: the for loop is statically determined to terminate within 5 iterations, which is less than what [MaxIters] specifies. + [MaxIters(6)] // warn + ^~~~~~~~ +} +standard output = { +} |
