diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2020-06-12 13:30:32 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-06-12 13:30:32 -0700 |
| commit | 36a06f1289c9a68a261920ef5d34f075f2a43219 (patch) | |
| tree | dd11eba962d87da0d437a752b818ddc68f5b6603 /source/slang/slang-check-expr.cpp | |
| parent | 2359921bb7aba569b36ce3c1904b2dccbde5ffec (diff) | |
Diagnose circularly-defined constants (#1384)
* Diagnose circularly-defined constants
Work on #1374
This change diagnoses cases like the following:
```hlsl
static const int kCircular = kCircular;
static const int kInfinite = kInfinite + 1;
static const int kHere = kThere;
static const int kThere = kHere;
```
By diagnosing these as errors in the front-end we protect against infinite recursion leading to stack overflow crashes.
The basic approach is to have front-end constant folding track variables that are in use when folding a sub-expression, and then diagnosing an error if the same variable is encountered again while it is in use. In order to make sure the error occurs whether or not the constant is referenced, we invoke constant folding on all `static const` integer variables.
Limitations:
* This only works for integers, since that is all front-end constant folding applies to. A future change can/should catch circularity in constants at the IR level (and handle more types).
* This only works for constants. Circular references in the definition of a global variable are harder to diagnose, but at least shouldn't result in compiler crashes.
* This doesn't work across modules, or through generic specialization: anything that requires global knowledge won't be checked
* fixup: missing files
* fixup: review feedback
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 103 |
1 files changed, 69 insertions, 34 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 17bd5e263..d2e5afd85 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -669,8 +669,9 @@ namespace Slang return m_astBuilder->create<ConstantIntVal>(expr->value); } - IntVal* SemanticsVisitor::TryConstantFoldExpr( - InvokeExpr* invokeExpr) + IntVal* SemanticsVisitor::tryConstantFoldExpr( + InvokeExpr* invokeExpr, + ConstantFoldingCircularityInfo* circularityInfo) { // We need all the operands to the expression @@ -707,7 +708,7 @@ namespace Slang bool allConst = true; for (auto argExpr : invokeExpr->arguments) { - auto argVal = TryCheckIntegerConstantExpression(argExpr); + auto argVal = tryFoldIntegerConstantExpression(argExpr, circularityInfo); if (!argVal) return nullptr; @@ -795,8 +796,53 @@ namespace Slang return result; } - IntVal* SemanticsVisitor::TryConstantFoldExpr( - Expr* expr) + bool SemanticsVisitor::_checkForCircularityInConstantFolding( + Decl* decl, + ConstantFoldingCircularityInfo* circularityInfo) + { + // TODO: If the `decl` is already on the chain of `circularityInfo`, + // then we know that we are trying to recursively fold the + // same declaration as part of its own definition, and we need + // to diagnose that as an error. + // + for( auto info = circularityInfo; info; info = info->next ) + { + if(decl == info->decl) + { + getSink()->diagnose(decl, Diagnostics::variableUsedInItsOwnDefinition, decl); + return true; + } + } + + return false; + } + + IntVal* SemanticsVisitor::tryConstantFoldDeclRef( + DeclRef<VarDeclBase> const& declRef, + ConstantFoldingCircularityInfo* circularityInfo) + { + auto decl = declRef.getDecl(); + + if(_checkForCircularityInConstantFolding(decl, circularityInfo)) + return nullptr; + + // In HLSL, `static const` is used to mark compile-time constant expressions + if(!decl->hasModifier<HLSLStaticModifier>()) + return nullptr; + if(!decl->hasModifier<ConstModifier>()) + return nullptr; + + auto initExpr = getInitExpr(m_astBuilder, declRef); + if(!initExpr) + return nullptr; + + ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo); + return tryConstantFoldExpr(initExpr, &newCircularityInfo); + } + + IntVal* SemanticsVisitor::tryConstantFoldExpr( + Expr* expr, + ConstantFoldingCircularityInfo* circularityInfo) { // Unwrap any "identity" expressions while (auto parenExpr = as<ParenExpr>(expr)) @@ -825,40 +871,32 @@ namespace Slang // are defined in a way that can be used as a constant expression: if(auto varRef = declRef.as<VarDeclBase>()) { - auto varDecl = varRef.getDecl(); - - // In HLSL, `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->findModifier<HLSLStaticModifier>()) - { - if(auto constAttr = varDecl->findModifier<ConstModifier>()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = getInitExpr(m_astBuilder, varRef)) - { - return TryConstantFoldExpr(initExpr); - } - } - } + return tryConstantFoldDeclRef(varRef, circularityInfo); } else if(auto enumRef = declRef.as<EnumCaseDecl>()) { // The cases in an `enum` declaration can also be used as constant expressions, if(auto tagExpr = getTagExpr(m_astBuilder, enumRef)) { - return TryConstantFoldExpr(tagExpr); + auto enumCaseDecl = enumRef.getDecl(); + if(_checkForCircularityInConstantFolding(enumCaseDecl, circularityInfo)) + return nullptr; + + ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo); + return tryConstantFoldExpr(tagExpr, &newCircularityInfo); } } } if(auto castExpr = as<TypeCastExpr>(expr)) { - auto val = TryConstantFoldExpr(castExpr->arguments[0]); + auto val = tryConstantFoldExpr(castExpr->arguments[0], circularityInfo); if(val) return val; } else if (auto invokeExpr = as<InvokeExpr>(expr)) { - auto val = TryConstantFoldExpr(invokeExpr); + auto val = tryConstantFoldExpr(invokeExpr, circularityInfo); if (val) return val; } @@ -866,21 +904,18 @@ namespace Slang return nullptr; } - IntVal* SemanticsVisitor::TryCheckIntegerConstantExpression(Expr* exp) + IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression( + Expr* expr, + ConstantFoldingCircularityInfo* circularityInfo) { // Check if type is acceptable for an integer constant expression - if(auto basicType = as<BasicExpressionType>(exp->type.type)) - { - if(!isIntegerBaseType(basicType->baseType)) - return nullptr; - } - else - { + // + if(!isScalarIntegerType(expr->type)) return nullptr; - } // Consider operations that we might be able to constant-fold... - return TryConstantFoldExpr(exp); + // + return tryConstantFoldExpr(expr, circularityInfo); } IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, DiagnosticSink* sink) @@ -894,7 +929,7 @@ namespace Slang // No need to issue further errors if the type coercion failed. if(IsErrorExpr(expr)) return nullptr; - auto result = TryCheckIntegerConstantExpression(expr); + auto result = tryFoldIntegerConstantExpression(expr, nullptr); if (!result && sink) { sink->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); @@ -915,7 +950,7 @@ namespace Slang // No need to issue further errors if the type coercion failed. if(IsErrorExpr(expr)) return nullptr; - auto result = TryConstantFoldExpr(expr); + auto result = tryConstantFoldExpr(expr, nullptr); if (!result) { getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); |
