diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 79 |
1 files changed, 75 insertions, 4 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 55ff90759..89a7373ee 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1906,9 +1906,17 @@ namespace Slang { auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression); + // If the base expression is a type, it means that this is an array declaration, + // then we should disable short-circuit in case there is logical expression in + // the subscript + auto baseType = baseExpr->type.Ptr(); + auto baseTypeType = as<TypeType>(baseType); + auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)? + SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + for (auto& arg : subscriptExpr->indexExprs) { - arg = CheckTerm(arg); + arg = subVisitor.CheckTerm(arg); } // If anything went wrong in the base expression, @@ -1920,8 +1928,7 @@ namespace Slang // Otherwise, we need to look at the type of the base expression, // to figure out how subscripting should work. - auto baseType = baseExpr->type.Ptr(); - if (auto baseTypeType = as<TypeType>(baseType)) + if (baseTypeType) { // We are trying to "index" into a type, so we have an expression like `float[2]` // which should be interpreted as resolving to an array type. @@ -2371,12 +2378,69 @@ namespace Slang return result; } + Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr) + { + LogicOperatorShortCircuitExpr* newExpr = nullptr; + + // If the logic expression is inside the generic parameter list, it cannot support short-circuit + // which will generate the ifelse branch. + if (!m_shouldShortCircuitLogicExpr) + { + return nullptr; + } + + if (auto varExpr = as<VarExpr>(expr->functionExpr)) + { + if ((varExpr->name->text == "&&") || (varExpr->name->text == "||")) + { + // We only use short-circuiting in scalar input, will fall back + // to non-short-circuiting in vector input. + bool shortCircuitSupport = true; + for (auto & arg : expr->arguments) + { + if(!as<BasicExpressionType>(arg->type.type)) + { + shortCircuitSupport = false; + } + } + + if (!shortCircuitSupport) + { + return nullptr; + } + + // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr' + // to handle if this expression doesn't support short-circuiting. + for (auto & arg : expr->arguments) + { + arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg); + } + + expr->functionExpr = CheckTerm(expr->functionExpr); + newExpr = m_astBuilder->create<LogicOperatorShortCircuitExpr>(); + if (varExpr->name->text == "&&") + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And; + } + else + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or; + } + newExpr->loc = expr->loc; + newExpr->functionExpr = expr->functionExpr; + newExpr->type = m_astBuilder->getBoolType(); + newExpr->arguments = expr->arguments; + } + } + + return newExpr; + } + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) { // check the base expression first if (!expr->originalFunctionExpr) expr->originalFunctionExpr = expr->functionExpr; - expr->functionExpr = CheckTerm(expr->functionExpr); auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; m_treatAsDifferentiableExpr = nullptr; // Next check the argument expressions @@ -2384,6 +2448,13 @@ namespace Slang { arg = CheckTerm(arg); } + + // if the expression is '&&' or '||', we will convert it + // to use short-circuit evaluation. + if (auto newExpr = convertToLogicOperatorExpr(expr)) + return newExpr; + + expr->functionExpr = CheckTerm(expr->functionExpr); m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in |
