summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp79
1 files changed, 75 insertions, 4 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 55ff90759..89a7373ee 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1906,9 +1906,17 @@ namespace Slang
{
auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression);
+ // If the base expression is a type, it means that this is an array declaration,
+ // then we should disable short-circuit in case there is logical expression in
+ // the subscript
+ auto baseType = baseExpr->type.Ptr();
+ auto baseTypeType = as<TypeType>(baseType);
+ auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)?
+ SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this;
+
for (auto& arg : subscriptExpr->indexExprs)
{
- arg = CheckTerm(arg);
+ arg = subVisitor.CheckTerm(arg);
}
// If anything went wrong in the base expression,
@@ -1920,8 +1928,7 @@ namespace Slang
// Otherwise, we need to look at the type of the base expression,
// to figure out how subscripting should work.
- auto baseType = baseExpr->type.Ptr();
- if (auto baseTypeType = as<TypeType>(baseType))
+ if (baseTypeType)
{
// We are trying to "index" into a type, so we have an expression like `float[2]`
// which should be interpreted as resolving to an array type.
@@ -2371,12 +2378,69 @@ namespace Slang
return result;
}
+ Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr)
+ {
+ LogicOperatorShortCircuitExpr* newExpr = nullptr;
+
+ // If the logic expression is inside the generic parameter list, it cannot support short-circuit
+ // which will generate the ifelse branch.
+ if (!m_shouldShortCircuitLogicExpr)
+ {
+ return nullptr;
+ }
+
+ if (auto varExpr = as<VarExpr>(expr->functionExpr))
+ {
+ if ((varExpr->name->text == "&&") || (varExpr->name->text == "||"))
+ {
+ // We only use short-circuiting in scalar input, will fall back
+ // to non-short-circuiting in vector input.
+ bool shortCircuitSupport = true;
+ for (auto & arg : expr->arguments)
+ {
+ if(!as<BasicExpressionType>(arg->type.type))
+ {
+ shortCircuitSupport = false;
+ }
+ }
+
+ if (!shortCircuitSupport)
+ {
+ return nullptr;
+ }
+
+ // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr'
+ // to handle if this expression doesn't support short-circuiting.
+ for (auto & arg : expr->arguments)
+ {
+ arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg);
+ }
+
+ expr->functionExpr = CheckTerm(expr->functionExpr);
+ newExpr = m_astBuilder->create<LogicOperatorShortCircuitExpr>();
+ if (varExpr->name->text == "&&")
+ {
+ newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And;
+ }
+ else
+ {
+ newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or;
+ }
+ newExpr->loc = expr->loc;
+ newExpr->functionExpr = expr->functionExpr;
+ newExpr->type = m_astBuilder->getBoolType();
+ newExpr->arguments = expr->arguments;
+ }
+ }
+
+ return newExpr;
+ }
+
Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
{
// check the base expression first
if (!expr->originalFunctionExpr)
expr->originalFunctionExpr = expr->functionExpr;
- expr->functionExpr = CheckTerm(expr->functionExpr);
auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr;
m_treatAsDifferentiableExpr = nullptr;
// Next check the argument expressions
@@ -2384,6 +2448,13 @@ namespace Slang
{
arg = CheckTerm(arg);
}
+
+ // if the expression is '&&' or '||', we will convert it
+ // to use short-circuit evaluation.
+ if (auto newExpr = convertToLogicOperatorExpr(expr))
+ return newExpr;
+
+ expr->functionExpr = CheckTerm(expr->functionExpr);
m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
// If we are in a differentiable function, register differential witness tables involved in