summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-stmt.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-check-stmt.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-check-stmt.cpp')
-rw-r--r--source/slang/slang-check-stmt.cpp1196
1 files changed, 624 insertions, 572 deletions
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 8b0e0b284..d02140d70 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -6,722 +6,774 @@
namespace Slang
{
- namespace
- {
- /// RAII-like type for establishing an "outer" statement during nested checks.
- ///
- /// The `SemanticsStmtVisitor` maintains a linked list of outer statements
- /// using `OuterStmtInfo` records stored on the recursive call stack during
- /// checking. This type creates a sub-`SemanticsStmtVisitor` that has one
- /// additional outer statement added to the stack of outer statements.
- ///
- /// The outer statements are used to validate and resolve things like
- /// the target of `break` or `continue` statements.
- ///
- struct WithOuterStmt : public SemanticsStmtVisitor
- {
- public:
- WithOuterStmt(SemanticsStmtVisitor* visitor, Stmt* outerStmt)
- : SemanticsStmtVisitor(visitor->withOuterStmts(&m_outerStmt))
- {
- m_outerStmt.next = visitor->getOuterStmts();
- m_outerStmt.stmt = outerStmt;
- }
-
- private:
- OuterStmtInfo m_outerStmt;
- };
- }
-
- void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context)
+namespace
+{
+/// RAII-like type for establishing an "outer" statement during nested checks.
+///
+/// The `SemanticsStmtVisitor` maintains a linked list of outer statements
+/// using `OuterStmtInfo` records stored on the recursive call stack during
+/// checking. This type creates a sub-`SemanticsStmtVisitor` that has one
+/// additional outer statement added to the stack of outer statements.
+///
+/// The outer statements are used to validate and resolve things like
+/// the target of `break` or `continue` statements.
+///
+struct WithOuterStmt : public SemanticsStmtVisitor
+{
+public:
+ WithOuterStmt(SemanticsStmtVisitor* visitor, Stmt* outerStmt)
+ : SemanticsStmtVisitor(visitor->withOuterStmts(&m_outerStmt))
{
- if (!stmt) return;
- dispatchStmt(stmt, context);
- checkModifiers(stmt);
+ m_outerStmt.next = visitor->getOuterStmts();
+ m_outerStmt.stmt = outerStmt;
}
- void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt)
- {
- // When we encounter a declaration during statement checking,
- // we expect that it hasn't been checked yet (because otherwise
- // it would be referenced before its declaration point), but
- // we will bottleneck through the `ensureDecl()` path anyway,
- // to unify with the rest of semantic checking.
- //
- // TODO: This logic might not suffice for something like a
- // local `struct` declaration, where it would have members
- // that need to be recursively checked.
- //
- ensureDeclBase(stmt->decl, DeclCheckState::DefinitionChecked, this);
- }
+private:
+ OuterStmtInfo m_outerStmt;
+};
+} // namespace
- void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt)
- {
- // Make sure to fully check all nested agg type decls first.
- if (stmt->scopeDecl)
- {
- for (auto decl : stmt->scopeDecl->members)
- {
- if (as<AggTypeDeclBase>(decl))
- ensureAllDeclsRec(decl, DeclCheckState::DefinitionChecked);
- }
- }
- checkStmt(stmt->body);
- }
+void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context)
+{
+ if (!stmt)
+ return;
+ dispatchStmt(stmt, context);
+ checkModifiers(stmt);
+}
- void SemanticsStmtVisitor::visitSeqStmt(SeqStmt* stmt)
+void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt)
+{
+ // When we encounter a declaration during statement checking,
+ // we expect that it hasn't been checked yet (because otherwise
+ // it would be referenced before its declaration point), but
+ // we will bottleneck through the `ensureDecl()` path anyway,
+ // to unify with the rest of semantic checking.
+ //
+ // TODO: This logic might not suffice for something like a
+ // local `struct` declaration, where it would have members
+ // that need to be recursively checked.
+ //
+ ensureDeclBase(stmt->decl, DeclCheckState::DefinitionChecked, this);
+}
+
+void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt)
+{
+ // Make sure to fully check all nested agg type decls first.
+ if (stmt->scopeDecl)
{
- for(auto ss : stmt->stmts)
+ for (auto decl : stmt->scopeDecl->members)
{
- checkStmt(ss);
+ if (as<AggTypeDeclBase>(decl))
+ ensureAllDeclsRec(decl, DeclCheckState::DefinitionChecked);
}
}
+ checkStmt(stmt->body);
+}
- void SemanticsStmtVisitor::visitLabelStmt(LabelStmt* stmt)
+void SemanticsStmtVisitor::visitSeqStmt(SeqStmt* stmt)
+{
+ for (auto ss : stmt->stmts)
{
- WithOuterStmt subContext(this, stmt);
- subContext.checkStmt(stmt->innerStmt);
+ checkStmt(ss);
}
+}
- void SemanticsStmtVisitor::checkStmt(Stmt* stmt)
- {
- SemanticsVisitor::checkStmt(stmt, *this);
- }
+void SemanticsStmtVisitor::visitLabelStmt(LabelStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
+ subContext.checkStmt(stmt->innerStmt);
+}
- template<typename T>
- T* SemanticsStmtVisitor::FindOuterStmt()
+void SemanticsStmtVisitor::checkStmt(Stmt* stmt)
+{
+ SemanticsVisitor::checkStmt(stmt, *this);
+}
+
+template<typename T>
+T* SemanticsStmtVisitor::FindOuterStmt()
+{
+ for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
{
- for(auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
- {
- auto outerStmt = outerStmtInfo->stmt;
- auto found = as<T>(outerStmt);
- if (found)
- return found;
- }
- return nullptr;
+ auto outerStmt = outerStmtInfo->stmt;
+ auto found = as<T>(outerStmt);
+ if (found)
+ return found;
}
+ return nullptr;
+}
- Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label)
+Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label)
+{
+ for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
{
- for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
+ auto outerStmt = outerStmtInfo->stmt;
+ auto found = as<LabelStmt>(outerStmt);
+ if (found)
{
- auto outerStmt = outerStmtInfo->stmt;
- auto found = as<LabelStmt>(outerStmt);
- if (found)
+ if (found->label.getName() == label)
{
- if (found->label.getName() == label)
- {
- return found->innerStmt;
- }
+ return found->innerStmt;
}
}
- return nullptr;
}
+ return nullptr;
+}
- void SemanticsStmtVisitor::visitBreakStmt(BreakStmt *stmt)
+void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt)
+{
+ Stmt* targetStmt = nullptr;
+ if (stmt->targetLabel.type == TokenType::Identifier)
{
- Stmt* targetStmt = nullptr;
- if (stmt->targetLabel.type == TokenType::Identifier)
+ // This is a break statement with an explicit target label.
+ // Try to find the outer stmt with the label.
+ targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName());
+ if (!targetStmt)
{
- // This is a break statement with an explicit target label.
- // Try to find the outer stmt with the label.
- targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName());
- if (!targetStmt)
- {
- getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName());
- }
- if (!as<BreakableStmt>(targetStmt))
- {
- getSink()->diagnose(stmt, Diagnostics::targetLabelDoesNotMarkBreakableStmt, stmt->targetLabel.getName());
- }
+ getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName());
}
- else
+ if (!as<BreakableStmt>(targetStmt))
{
- // For `break` statements without an explicit target,
- // find the inner most breakable stmt.
- targetStmt = FindOuterStmt<BreakableStmt>();
- if (!targetStmt)
- {
- getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop);
- }
+ getSink()->diagnose(
+ stmt,
+ Diagnostics::targetLabelDoesNotMarkBreakableStmt,
+ stmt->targetLabel.getName());
}
- stmt->parentStmt = targetStmt;
}
-
- void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt *stmt)
+ else
{
- auto outer = FindOuterStmt<LoopStmt>();
- if (!outer)
+ // For `break` statements without an explicit target,
+ // find the inner most breakable stmt.
+ targetStmt = FindOuterStmt<BreakableStmt>();
+ if (!targetStmt)
{
- getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop);
+ getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop);
}
- stmt->parentStmt = outer;
}
+ stmt->parentStmt = targetStmt;
+}
- Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr)
+void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt)
+{
+ auto outer = FindOuterStmt<LoopStmt>();
+ if (!outer)
{
- if (as<AssignExpr>(expr))
- {
- getSink()->diagnose(expr, Diagnostics::assignmentInPredicateExpr);
- }
- Expr* e = expr;
- e = CheckTerm(e);
- e = coerce(CoercionSite::General, m_astBuilder->getBoolType(), e);
- return e;
+ getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop);
}
+ stmt->parentStmt = outer;
+}
- void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt)
+Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr)
+{
+ if (as<AssignExpr>(expr))
{
- checkModifiers(stmt);
- WithOuterStmt subContext(this, stmt);
-
- stmt->predicate = checkPredicateExpr(stmt->predicate);
- subContext.checkStmt(stmt->statement);
- checkLoopInDifferentiableFunc(stmt);
+ getSink()->diagnose(expr, Diagnostics::assignmentInPredicateExpr);
}
+ Expr* e = expr;
+ e = CheckTerm(e);
+ e = coerce(CoercionSite::General, m_astBuilder->getBoolType(), e);
+ return e;
+}
- void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt)
- {
- WithOuterStmt subContext(this, stmt);
- checkModifiers(stmt);
- checkStmt(stmt->initialStatement);
+void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt)
+{
+ checkModifiers(stmt);
+ WithOuterStmt subContext(this, stmt);
- if (stmt->predicateExpression)
- {
- stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression);
- }
- if (stmt->sideEffectExpression)
- {
- stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression);
- }
- subContext.checkStmt(stmt->statement);
-
- tryInferLoopMaxIterations(stmt);
+ stmt->predicate = checkPredicateExpr(stmt->predicate);
+ subContext.checkStmt(stmt->statement);
+ checkLoopInDifferentiableFunc(stmt);
+}
- checkLoopInDifferentiableFunc(stmt);
- }
+void SemanticsStmtVisitor::visitForStmt(ForStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
+ checkModifiers(stmt);
+ checkStmt(stmt->initialStatement);
- Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal, ConstantFoldingKind kind)
+ if (stmt->predicateExpression)
{
- expr = CheckExpr(expr);
- auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, kind);
- if (outIntVal)
- *outIntVal = intVal;
- return expr;
+ stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression);
}
-
- void SemanticsStmtVisitor::visitCompileTimeForStmt(CompileTimeForStmt* stmt)
+ if (stmt->sideEffectExpression)
{
- WithOuterStmt subContext(this, stmt);
+ stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression);
+ }
+ subContext.checkStmt(stmt->statement);
- stmt->varDecl->type.type = m_astBuilder->getIntType();
- addModifier(stmt->varDecl, m_astBuilder->create<ConstModifier>());
- stmt->varDecl->setCheckState(DeclCheckState::DefinitionChecked);
+ tryInferLoopMaxIterations(stmt);
- IntVal* rangeBeginVal = nullptr;
- IntVal* rangeEndVal = nullptr;
+ checkLoopInDifferentiableFunc(stmt);
+}
- if (stmt->rangeBeginExpr)
- {
- stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeBeginExpr, &rangeBeginVal, ConstantFoldingKind::LinkTime);
- }
- else
- {
- ConstantIntVal* rangeBeginConst = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 0);
- rangeBeginVal = rangeBeginConst;
- }
+Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(
+ Expr* expr,
+ IntVal** outIntVal,
+ ConstantFoldingKind kind)
+{
+ expr = CheckExpr(expr);
+ auto intVal = CheckIntegerConstantExpression(
+ expr,
+ IntegerConstantExpressionCoercionType::AnyInteger,
+ nullptr,
+ kind);
+ if (outIntVal)
+ *outIntVal = intVal;
+ return expr;
+}
- stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeEndExpr, &rangeEndVal, ConstantFoldingKind::LinkTime);
+void SemanticsStmtVisitor::visitCompileTimeForStmt(CompileTimeForStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
- stmt->rangeBeginVal = rangeBeginVal;
- stmt->rangeEndVal = rangeEndVal;
+ stmt->varDecl->type.type = m_astBuilder->getIntType();
+ addModifier(stmt->varDecl, m_astBuilder->create<ConstModifier>());
+ stmt->varDecl->setCheckState(DeclCheckState::DefinitionChecked);
- subContext.checkStmt(stmt->body);
- }
+ IntVal* rangeBeginVal = nullptr;
+ IntVal* rangeEndVal = nullptr;
- void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink)
+ if (stmt->rangeBeginExpr)
{
- auto blockStmt = as<BlockStmt>(stmt->body);
- if (!blockStmt)
- return;
+ stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(
+ stmt->rangeBeginExpr,
+ &rangeBeginVal,
+ ConstantFoldingKind::LinkTime);
+ }
+ else
+ {
+ ConstantIntVal* rangeBeginConst = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 0);
+ rangeBeginVal = rangeBeginConst;
+ }
- auto seqStmt = as<SeqStmt>(blockStmt->body);
- if (!seqStmt)
- return;
+ stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(
+ stmt->rangeEndExpr,
+ &rangeEndVal,
+ ConstantFoldingKind::LinkTime);
+
+ stmt->rangeBeginVal = rangeBeginVal;
+ stmt->rangeEndVal = rangeEndVal;
+
+ subContext.checkStmt(stmt->body);
+}
- bool hasDefaultStmt = false;
- HashSet<Val*> caseStmtVals;
- for (auto& sStmt : seqStmt->stmts)
+void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink)
+{
+ auto blockStmt = as<BlockStmt>(stmt->body);
+ if (!blockStmt)
+ return;
+
+ auto seqStmt = as<SeqStmt>(blockStmt->body);
+ if (!seqStmt)
+ return;
+
+ bool hasDefaultStmt = false;
+ HashSet<Val*> caseStmtVals;
+ for (auto& sStmt : seqStmt->stmts)
+ {
+ if (auto caseStmt = as<CaseStmt>(sStmt))
{
- if (auto caseStmt = as<CaseStmt>(sStmt))
+ // check that all case tags are unique
+ if (caseStmt->exprVal)
{
- // check that all case tags are unique
- if (caseStmt->exprVal)
+ // exprVal contains the constant folded expr, that is checked for
+ // uniqueness within the scope of the switch statement.
+ if (!caseStmtVals.add(caseStmt->exprVal))
{
- // exprVal contains the constant folded expr, that is checked for
- // uniqueness within the scope of the switch statement.
- if (!caseStmtVals.add(caseStmt->exprVal))
- {
- sink->diagnose(sStmt, Diagnostics::switchDuplicateCases);
- return;
- }
+ sink->diagnose(sStmt, Diagnostics::switchDuplicateCases);
+ return;
}
}
- else if (as<DefaultStmt>(sStmt))
+ }
+ else if (as<DefaultStmt>(sStmt))
+ {
+ // check that there is at most one `default` clause
+ if (hasDefaultStmt)
{
- // check that there is at most one `default` clause
- if (hasDefaultStmt)
- {
- sink->diagnose(sStmt, Diagnostics::switchMultipleDefault);
- return;
- }
- hasDefaultStmt = true;
+ sink->diagnose(sStmt, Diagnostics::switchMultipleDefault);
+ return;
}
+ hasDefaultStmt = true;
}
}
+}
- void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt)
- {
- WithOuterStmt subContext(this, stmt);
+void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
- // TODO(tfoley): need to coerce condition to an integral type...
- stmt->condition = CheckExpr(stmt->condition);
- subContext.checkStmt(stmt->body);
+ // TODO(tfoley): need to coerce condition to an integral type...
+ stmt->condition = CheckExpr(stmt->condition);
+ subContext.checkStmt(stmt->body);
- // check the case value exits within the switch
- validateCaseStmts(stmt, getSink());
- }
+ // check the case value exits within the switch
+ validateCaseStmts(stmt, getSink());
+}
- void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
+void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
+{
+ auto switchStmt = FindOuterStmt<SwitchStmt>();
+ if (!switchStmt)
{
- auto switchStmt = FindOuterStmt<SwitchStmt>();
- if (!switchStmt)
- {
- getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
- return;
- }
+ getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
+ return;
+ }
- // Check that the type for the `case` is consistent with the type for the `switch`.
- auto expr = CheckExpr(stmt->expr);
- expr = coerce(CoercionSite::Argument, switchStmt->condition->type, expr);
+ // Check that the type for the `case` is consistent with the type for the `switch`.
+ auto expr = CheckExpr(stmt->expr);
+ expr = coerce(CoercionSite::Argument, switchStmt->condition->type, expr);
- // coerce to type being switch on, and ensure that value is a compile-time constant
- // The Vals in the AST are pointer-unique, making them easy to check for duplicates
- // by addeing them to a HashSet.
- auto exprVal = checkConstantIntVal(expr);
+ // coerce to type being switch on, and ensure that value is a compile-time constant
+ // The Vals in the AST are pointer-unique, making them easy to check for duplicates
+ // by addeing them to a HashSet.
+ auto exprVal = checkConstantIntVal(expr);
- stmt->expr = expr;
- stmt->exprVal = exprVal;
- stmt->parentStmt = switchStmt;
- }
+ stmt->expr = expr;
+ stmt->exprVal = exprVal;
+ stmt->parentStmt = switchStmt;
+}
- void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
+void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
+ HashSet<Stmt*> checkedStmt;
+ for (auto caseStmt : stmt->targetCases)
{
- WithOuterStmt subContext(this, stmt);
- HashSet<Stmt*> checkedStmt;
- for (auto caseStmt : stmt->targetCases)
- {
- if (checkedStmt.contains(caseStmt->body))
- continue;
- subContext.checkStmt(caseStmt);
- checkedStmt.add(caseStmt->body);
- }
+ if (checkedStmt.contains(caseStmt->body))
+ continue;
+ subContext.checkStmt(caseStmt);
+ checkedStmt.add(caseStmt->body);
}
+}
- void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
+void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
+{
+ auto switchStmt = FindOuterStmt<TargetSwitchStmt>();
+ CapabilitySet set((CapabilityName)stmt->capability);
+ if (getShared()->isInLanguageServer() &&
+ getShared()->getSession()->getCompletionRequestTokenName() ==
+ stmt->capabilityToken.getName())
{
- auto switchStmt = FindOuterStmt<TargetSwitchStmt>();
- CapabilitySet set((CapabilityName)stmt->capability);
- if (getShared()->isInLanguageServer() && getShared()->getSession()->getCompletionRequestTokenName() == stmt->capabilityToken.getName())
- {
- getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities;
- }
-
- if (stmt->capabilityToken.getContentLength() != 0 &&
- (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty()))
- {
- getSink()->diagnose(
- stmt->capabilityToken.loc,
- Diagnostics::invalidTargetSwitchCase,
- capabilityNameToString((CapabilityName)stmt->capability));
- }
- if (!switchStmt)
- {
- getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
- }
- WithOuterStmt subContext(this, stmt);
- subContext.checkStmt(stmt->body);
+ getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind =
+ CompletionSuggestions::ScopeKind::Capabilities;
}
- void SemanticsStmtVisitor::visitIntrinsicAsmStmt(IntrinsicAsmStmt* stmt)
+ if (stmt->capabilityToken.getContentLength() != 0 &&
+ (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty()))
{
- WithOuterStmt subContext(this, stmt);
- for (auto& arg : stmt->args)
- arg = subContext.CheckExpr(arg);
+ getSink()->diagnose(
+ stmt->capabilityToken.loc,
+ Diagnostics::invalidTargetSwitchCase,
+ capabilityNameToString((CapabilityName)stmt->capability));
}
-
- void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt)
+ if (!switchStmt)
{
- auto switchStmt = FindOuterStmt<SwitchStmt>();
- if (!switchStmt)
- {
- getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch);
- }
- stmt->parentStmt = switchStmt;
+ getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
}
+ WithOuterStmt subContext(this, stmt);
+ subContext.checkStmt(stmt->body);
+}
- void SemanticsStmtVisitor::visitIfStmt(IfStmt *stmt)
- {
- stmt->predicate = checkPredicateExpr(stmt->predicate);
- checkStmt(stmt->positiveStatement);
- checkStmt(stmt->negativeStatement);
- }
+void SemanticsStmtVisitor::visitIntrinsicAsmStmt(IntrinsicAsmStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
+ for (auto& arg : stmt->args)
+ arg = subContext.CheckExpr(arg);
+}
- void SemanticsStmtVisitor::visitUnparsedStmt(UnparsedStmt*)
+void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt)
+{
+ auto switchStmt = FindOuterStmt<SwitchStmt>();
+ if (!switchStmt)
{
- // Nothing to do
+ getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch);
}
+ stmt->parentStmt = switchStmt;
+}
- void SemanticsStmtVisitor::visitEmptyStmt(EmptyStmt*)
- {
- // Nothing to do
- }
+void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt)
+{
+ stmt->predicate = checkPredicateExpr(stmt->predicate);
+ checkStmt(stmt->positiveStatement);
+ checkStmt(stmt->negativeStatement);
+}
- void SemanticsStmtVisitor::visitDiscardStmt(DiscardStmt*)
+void SemanticsStmtVisitor::visitUnparsedStmt(UnparsedStmt*)
+{
+ // Nothing to do
+}
+
+void SemanticsStmtVisitor::visitEmptyStmt(EmptyStmt*)
+{
+ // Nothing to do
+}
+
+void SemanticsStmtVisitor::visitDiscardStmt(DiscardStmt*)
+{
+ // Nothing to do
+}
+
+void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt)
+{
+ auto function = getParentFunc();
+ if (!stmt->expression)
{
- // Nothing to do
+ if (function && !function->returnType.equals(m_astBuilder->getVoidType()) &&
+ !as<ConstructorDecl>(function))
+ {
+ getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression);
+ }
}
-
- void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt *stmt)
+ else
{
- auto function = getParentFunc();
- if (!stmt->expression)
+ stmt->expression = CheckTerm(stmt->expression);
+ if (!stmt->expression->type->equals(m_astBuilder->getErrorType()))
{
- if (function && !function->returnType.equals(m_astBuilder->getVoidType()) && !as<ConstructorDecl>(function))
+ if (function)
{
- getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression);
+ stmt->expression =
+ coerce(CoercionSite::Return, function->returnType.Ptr(), stmt->expression);
}
- }
- else
- {
- stmt->expression = CheckTerm(stmt->expression);
- if (!stmt->expression->type->equals(m_astBuilder->getErrorType()))
+ else
{
- if (function)
- {
- stmt->expression = coerce(CoercionSite::Return, function->returnType.Ptr(), stmt->expression);
- }
- else
- {
- // TODO(tfoley): this case currently gets triggered for member functions,
- // which aren't being checked consistently (because of the whole symbol
- // table idea getting in the way).
+ // TODO(tfoley): this case currently gets triggered for member functions,
+ // which aren't being checked consistently (because of the whole symbol
+ // table idea getting in the way).
-// getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt");
- }
+ // getSink()->diagnose(stmt,
+ // Diagnostics::unimplemented, "case for return stmt");
}
}
}
+}
- void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt)
- {
- checkModifiers(stmt);
- WithOuterStmt subContext(this, stmt);
- stmt->predicate = checkPredicateExpr(stmt->predicate);
- subContext.checkStmt(stmt->statement);
- checkLoopInDifferentiableFunc(stmt);
- }
+void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt)
+{
+ checkModifiers(stmt);
+ WithOuterStmt subContext(this, stmt);
+ stmt->predicate = checkPredicateExpr(stmt->predicate);
+ subContext.checkStmt(stmt->statement);
+ checkLoopInDifferentiableFunc(stmt);
+}
- void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt)
+void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt)
+{
+ stmt->expression = CheckExpr(stmt->expression);
+ if (auto operatorExpr = as<OperatorExpr>(stmt->expression))
{
- stmt->expression = CheckExpr(stmt->expression);
- if (auto operatorExpr = as<OperatorExpr>(stmt->expression))
+ if (auto func = as<VarExpr>(operatorExpr->functionExpr))
{
- if (auto func = as<VarExpr>(operatorExpr->functionExpr))
+ if (func->name && func->name->text == "==")
{
- if (func->name && func->name->text == "==")
- {
- getSink()->diagnose(operatorExpr, Diagnostics::danglingEqualityExpr);
- }
+ getSink()->diagnose(operatorExpr, Diagnostics::danglingEqualityExpr);
}
}
}
+}
- void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt)
- {
- // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)`
- // we will try to constant fold the operands and see if we can statically determine the maximum number of
- // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt.
- //
- // ++, --, +=, -= are supported in side effect expressions.
- // >, <, >=, <= are supported in predicate expressions.
- // induction variable can appear in either side of the expressions.
- //
- // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here.
- // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way.
- //
- DeclRef<Decl> predicateVar = {};
- Expr* initialVal = nullptr;
- DeclRef<Decl> initialVar = {};
- if (auto varStmt = as<DeclStmt>(stmt->initialStatement))
- {
- auto varDecl = as<VarDecl>(varStmt->decl);
- if (!varDecl)
- return;
- initialVar = makeDeclRef<Decl>(varDecl);
- initialVal = varDecl->initExpr;
- }
- else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement))
- {
- auto assignExpr = as<AssignExpr>(exprStmt->expression);
- if (!assignExpr)
- return;
- auto varExpr = as<VarExpr>(assignExpr->left);
- if (!varExpr)
- return;
- initialVar = varExpr->declRef;
- initialVal = assignExpr->right;
- }
- else
- return;
-
- auto initialLitVal =
- as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, ConstantFoldingKind::CompileTime, nullptr));
-
- ConstantIntVal* finalVal = nullptr;
- auto binaryExpr = as<InfixExpr>(stmt->predicateExpression);
- if (!binaryExpr)
- return;
- auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr);
- if (!compareFuncExpr)
- return;
- if (!compareFuncExpr->declRef.getDecl())
+void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt)
+{
+ // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var
+ // sideEffectOp operand)` we will try to constant fold the operands and see if we can statically
+ // determine the maximum number of iterations this loop will run, and insert the inferred result
+ // as a `[MaxIters]` attribute on the stmt.
+ //
+ // ++, --, +=, -= are supported in side effect expressions.
+ // >, <, >=, <= are supported in predicate expressions.
+ // induction variable can appear in either side of the expressions.
+ //
+ // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here.
+ // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along
+ // the way.
+ //
+ DeclRef<Decl> predicateVar = {};
+ Expr* initialVal = nullptr;
+ DeclRef<Decl> initialVar = {};
+ if (auto varStmt = as<DeclStmt>(stmt->initialStatement))
+ {
+ auto varDecl = as<VarDecl>(varStmt->decl);
+ if (!varDecl)
return;
- IROp compareOp = kIROp_Nop;
- if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>())
- {
- compareOp = (IROp)intrinsicOpModifier->op;
- }
- else
- {
+ initialVar = makeDeclRef<Decl>(varDecl);
+ initialVal = varDecl->initExpr;
+ }
+ else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement))
+ {
+ auto assignExpr = as<AssignExpr>(exprStmt->expression);
+ if (!assignExpr)
return;
- }
- if (binaryExpr->arguments.getCount() != 2)
+ auto varExpr = as<VarExpr>(assignExpr->left);
+ if (!varExpr)
return;
- auto leftCompareOperand = binaryExpr->arguments[0];
- auto rightCompareOperand = binaryExpr->arguments[1];
- if (!leftCompareOperand)
+ initialVar = varExpr->declRef;
+ initialVal = assignExpr->right;
+ }
+ else
+ return;
+
+ auto initialLitVal = as<ConstantIntVal>(
+ tryFoldIntegerConstantExpression(initialVal, ConstantFoldingKind::CompileTime, nullptr));
+
+ ConstantIntVal* finalVal = nullptr;
+ auto binaryExpr = as<InfixExpr>(stmt->predicateExpression);
+ if (!binaryExpr)
+ return;
+ auto compareFuncExpr = as<DeclRefExpr>(binaryExpr->functionExpr);
+ if (!compareFuncExpr)
+ return;
+ if (!compareFuncExpr->declRef.getDecl())
+ return;
+ IROp compareOp = kIROp_Nop;
+ if (auto intrinsicOpModifier =
+ compareFuncExpr->declRef.getDecl()->findModifier<IntrinsicOpModifier>())
+ {
+ compareOp = (IROp)intrinsicOpModifier->op;
+ }
+ else
+ {
+ return;
+ }
+ if (binaryExpr->arguments.getCount() != 2)
+ return;
+ auto leftCompareOperand = binaryExpr->arguments[0];
+ auto rightCompareOperand = binaryExpr->arguments[1];
+ if (!leftCompareOperand)
+ return;
+ if (!rightCompareOperand)
+ return;
+ if (auto rightVal = tryFoldIntegerConstantExpression(
+ binaryExpr->arguments[1],
+ ConstantFoldingKind::CompileTime,
+ nullptr))
+ {
+ auto leftVar = as<VarExpr>(leftCompareOperand);
+ if (!leftVar)
return;
- if (!rightCompareOperand)
+ predicateVar = leftVar->declRef;
+ finalVal = as<ConstantIntVal>(rightVal);
+ }
+ else if (
+ auto leftVal = tryFoldIntegerConstantExpression(
+ binaryExpr->arguments[0],
+ ConstantFoldingKind::CompileTime,
+ nullptr))
+ {
+ auto rightVar = as<VarExpr>(rightCompareOperand);
+ if (!rightVar)
return;
- if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], ConstantFoldingKind::CompileTime, nullptr))
+ predicateVar = rightVar->declRef;
+ finalVal = as<ConstantIntVal>(leftVal);
+ compareOp = getSwapSideComparisonOp(compareOp);
+ }
+ else
+ {
+ // If neither left or right is constant, we assume left is variable and continue checking.
+ if (auto leftVar = as<VarExpr>(leftCompareOperand))
{
- auto leftVar = as<VarExpr>(leftCompareOperand);
- if (!leftVar)
- return;
predicateVar = leftVar->declRef;
- finalVal = as<ConstantIntVal>(rightVal);
}
- else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], ConstantFoldingKind::CompileTime, nullptr))
+ if (auto rightVar = as<VarExpr>(rightCompareOperand))
{
- auto rightVar = as<VarExpr>(rightCompareOperand);
- if (!rightVar)
- return;
- predicateVar = rightVar->declRef;
- finalVal = as<ConstantIntVal>(leftVal);
- compareOp = getSwapSideComparisonOp(compareOp);
- }
- else
- {
- // If neither left or right is constant, we assume left is variable and continue checking.
- if (auto leftVar = as<VarExpr>(leftCompareOperand))
- {
- predicateVar = leftVar->declRef;
- }
- if (auto rightVar = as<VarExpr>(rightCompareOperand))
+ if (rightVar->declRef == initialVar)
{
- if (rightVar->declRef == initialVar)
- {
- predicateVar = rightVar->declRef;
- compareOp = getSwapSideComparisonOp(compareOp);
- }
+ predicateVar = rightVar->declRef;
+ compareOp = getSwapSideComparisonOp(compareOp);
}
}
+ }
- switch (compareOp)
- {
- case kIROp_Less:
- case kIROp_Leq:
- case kIROp_Greater:
- case kIROp_Geq:
- break;
- default:
- return;
- }
+ switch (compareOp)
+ {
+ case kIROp_Less:
+ case kIROp_Leq:
+ case kIROp_Greater:
+ case kIROp_Geq: break;
+ default: return;
+ }
- ConstantIntVal* stepSize = nullptr;
- IROp sideEffectFuncOp = kIROp_Nop;
- auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression);
- if (!opSideEffectExpr)
+ ConstantIntVal* stepSize = nullptr;
+ IROp sideEffectFuncOp = kIROp_Nop;
+ auto opSideEffectExpr = as<InvokeExpr>(stmt->sideEffectExpression);
+ if (!opSideEffectExpr)
+ return;
+ auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr);
+ if (!sideEffectFuncExpr)
+ return;
+ auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl();
+ if (!sideEffectFuncDecl)
+ return;
+ if (auto opName = sideEffectFuncDecl->getName())
+ {
+ if (opName->text == "++")
+ sideEffectFuncOp = kIROp_Add;
+ else if (opName->text == "--")
+ sideEffectFuncOp = kIROp_Sub;
+ else if (opName->text == "+=")
+ sideEffectFuncOp = kIROp_Add;
+ else if (opName->text == "-=")
+ sideEffectFuncOp = kIROp_Sub;
+ else
return;
- auto sideEffectFuncExpr = as<DeclRefExpr>(opSideEffectExpr->functionExpr);
- if (!sideEffectFuncExpr)
+ }
+ if (opSideEffectExpr->arguments.getCount())
+ {
+ auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]);
+ if (!varExpr)
return;
- auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl();
- if (!sideEffectFuncDecl)
+ if (varExpr->declRef.getDecl() != initialVar.getDecl())
+ {
+ // If the user writes something like `for (int i = 0; i < 5; j++)`,
+ // it is most likely a bug, so we issue a warning.
+ if (predicateVar == initialVar)
+ getSink()->diagnose(
+ varExpr,
+ Diagnostics::forLoopSideEffectChangingDifferentVar,
+ initialVar,
+ varExpr->declRef);
return;
- if (auto opName = sideEffectFuncDecl->getName())
- {
- if (opName->text == "++")
- sideEffectFuncOp = kIROp_Add;
- else if (opName->text == "--")
- sideEffectFuncOp = kIROp_Sub;
- else if (opName->text == "+=")
- sideEffectFuncOp = kIROp_Add;
- else if (opName->text == "-=")
- sideEffectFuncOp = kIROp_Sub;
- else
- return;
- }
- if (opSideEffectExpr->arguments.getCount())
- {
- auto varExpr = as<VarExpr>(opSideEffectExpr->arguments[0]);
- if (!varExpr)
- return;
- if (varExpr->declRef.getDecl() != initialVar.getDecl())
- {
- // If the user writes something like `for (int i = 0; i < 5; j++)`,
- // it is most likely a bug, so we issue a warning.
- if (predicateVar == initialVar)
- getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef);
- return;
- }
}
- else
+ }
+ else
+ return;
+ if (opSideEffectExpr->arguments.getCount() == 2)
+ {
+ auto stepVal = tryFoldIntegerConstantExpression(
+ opSideEffectExpr->arguments[1],
+ ConstantFoldingKind::CompileTime,
+ nullptr);
+ if (!stepVal)
return;
- if (opSideEffectExpr->arguments.getCount() == 2)
- {
- auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], ConstantFoldingKind::CompileTime, nullptr);
- if (!stepVal)
- return;
- if (auto constantIntVal = as<ConstantIntVal>(stepVal))
- {
- stepSize = constantIntVal;
- }
- }
- else
+ if (auto constantIntVal = as<ConstantIntVal>(stepVal))
{
- stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ stepSize = constantIntVal;
}
+ }
+ else
+ {
+ stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ }
- if (predicateVar.getDecl() != initialVar.getDecl())
+ if (predicateVar.getDecl() != initialVar.getDecl())
+ {
+ if (predicateVar)
+ getSink()->diagnose(
+ stmt->predicateExpression,
+ Diagnostics::forLoopPredicateCheckingDifferentVar,
+ initialVar,
+ predicateVar);
+ return;
+ }
+ if (!stepSize)
+ return;
+ if (stepSize->getValue() > 0)
+ {
+ if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater ||
+ sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less)
{
- if (predicateVar)
- getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar);
- return;
- }
- if (!stepSize)
+ getSink()->diagnose(
+ stmt->sideEffectExpression,
+ Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection,
+ initialVar);
return;
- if (stepSize->getValue() > 0)
- {
- if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater ||
- sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less)
- {
- getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar);
- return;
- }
}
- else if (stepSize->getValue() < 0)
- {
- if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less ||
- sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater)
- {
- getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar);
- return;
- }
- }
- else
+ }
+ else if (stepSize->getValue() < 0)
+ {
+ if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less ||
+ sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater)
{
- getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar);
+ getSink()->diagnose(
+ stmt->sideEffectExpression,
+ Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection,
+ initialVar);
return;
}
-
- if (!initialLitVal || !finalVal)
- return;
-
- auto absStepSize = abs(stepSize->getValue());
- int adjustment = 0;
- if (compareOp == kIROp_Geq || compareOp == kIROp_Leq)
- adjustment = 1;
-
- auto iterations = (Math::Max(finalVal->getValue(), initialLitVal->getValue()) -
- Math::Min(finalVal->getValue(), initialLitVal->getValue()) + absStepSize - 1 + adjustment) /
- absStepSize;
- switch (compareOp)
- {
- case kIROp_Geq:
- case kIROp_Greater:
- // Expect final value to be less than initial value.
- if (finalVal->getValue() > initialLitVal->getValue())
- iterations = 0;
- break;
- case kIROp_Leq:
- case kIROp_Less:
- if (finalVal->getValue() < initialLitVal->getValue())
- iterations = 0;
- break;
- }
- if (iterations == 0)
- {
- getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations);
- }
-
- // Note: the inferred max iterations may not be valid if the loop body
- // also modifies the induction variable.
- // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute`
- // if the loop body modifies the induction variable.
- //
- auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>();
- auto litExpr = m_astBuilder->create<LiteralExpr>();
- litExpr->type.type = m_astBuilder->getIntType();
- litExpr->token.setName(getNamePool()->getName(String(iterations)));
- maxItersAttr->args.add(litExpr);
- maxItersAttr->intArgVals.add(m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations));
- maxItersAttr->value = (int32_t)iterations;
- maxItersAttr->inductionVar = initialVar;
- addModifier(stmt, maxItersAttr);
+ }
+ else
+ {
+ getSink()->diagnose(
+ stmt->sideEffectExpression,
+ Diagnostics::forLoopNotModifyingIterationVariable,
+ initialVar);
return;
}
- void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt)
- {
- SLANG_UNUSED(stmt);
- if (getParentDifferentiableAttribute())
- {
- if (!getParentFunc())
- return;
-
- // If the function is itself a derivative, or has a user defined derivative,
- // then we don't require anything.
+ if (!initialLitVal || !finalVal)
+ return;
- if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>())
- return;
- if (getParentFunc()->findModifier<ForwardDerivativeAttribute>())
- return;
- if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>())
- return;
- if (getParentFunc()->findModifier<BackwardDerivativeAttribute>())
- return;
- }
- }
+ auto absStepSize = abs(stepSize->getValue());
+ int adjustment = 0;
+ if (compareOp == kIROp_Geq || compareOp == kIROp_Leq)
+ adjustment = 1;
+
+ auto iterations = (Math::Max(finalVal->getValue(), initialLitVal->getValue()) -
+ Math::Min(finalVal->getValue(), initialLitVal->getValue()) + absStepSize -
+ 1 + adjustment) /
+ absStepSize;
+ switch (compareOp)
+ {
+ case kIROp_Geq:
+ case kIROp_Greater:
+ // Expect final value to be less than initial value.
+ if (finalVal->getValue() > initialLitVal->getValue())
+ iterations = 0;
+ break;
+ case kIROp_Leq:
+ case kIROp_Less:
+ if (finalVal->getValue() < initialLitVal->getValue())
+ iterations = 0;
+ break;
+ }
+ if (iterations == 0)
+ {
+ getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations);
+ }
+
+ // Note: the inferred max iterations may not be valid if the loop body
+ // also modifies the induction variable.
+ // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute`
+ // if the loop body modifies the induction variable.
+ //
+ auto maxItersAttr = m_astBuilder->create<InferredMaxItersAttribute>();
+ auto litExpr = m_astBuilder->create<LiteralExpr>();
+ litExpr->type.type = m_astBuilder->getIntType();
+ litExpr->token.setName(getNamePool()->getName(String(iterations)));
+ maxItersAttr->args.add(litExpr);
+ maxItersAttr->intArgVals.add(m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations));
+ maxItersAttr->value = (int32_t)iterations;
+ maxItersAttr->inductionVar = initialVar;
+ addModifier(stmt, maxItersAttr);
+ return;
+}
- void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt)
+void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt)
+{
+ SLANG_UNUSED(stmt);
+ if (getParentDifferentiableAttribute())
{
- stmt->device = CheckExpr(stmt->device);
- stmt->gridDims = CheckExpr(stmt->gridDims);
- ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::DefinitionChecked, this);
- WithOuterStmt subContext(this, stmt);
- stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall);
- return;
+ if (!getParentFunc())
+ return;
+
+ // If the function is itself a derivative, or has a user defined derivative,
+ // then we don't require anything.
+
+ if (getParentFunc()->findModifier<ForwardDerivativeOfAttribute>())
+ return;
+ if (getParentFunc()->findModifier<ForwardDerivativeAttribute>())
+ return;
+ if (getParentFunc()->findModifier<BackwardDerivativeOfAttribute>())
+ return;
+ if (getParentFunc()->findModifier<BackwardDerivativeAttribute>())
+ return;
}
}
+
+void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt* stmt)
+{
+ stmt->device = CheckExpr(stmt->device);
+ stmt->gridDims = CheckExpr(stmt->gridDims);
+ ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::DefinitionChecked, this);
+ WithOuterStmt subContext(this, stmt);
+ stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall);
+ return;
+}
+} // namespace Slang