summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h13
-rw-r--r--source/slang/slang-check-decl.cpp18
-rw-r--r--source/slang/slang-check-expr.cpp79
-rw-r--r--source/slang/slang-check-impl.h19
-rw-r--r--source/slang/slang-check-overload.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp46
6 files changed, 175 insertions, 9 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 9a7993937..fa1b27a04 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -431,6 +431,19 @@ class SelectExpr: public OperatorExpr
SLANG_AST_CLASS(SelectExpr)
};
+class LogicOperatorShortCircuitExpr: public OperatorExpr
+{
+ SLANG_AST_CLASS(LogicOperatorShortCircuitExpr)
+public:
+ enum Flavor
+ {
+ And, // &&
+ Or, // ||
+ };
+ Flavor flavor;
+};
+
+
class GenericAppExpr: public AppExprBase
{
SLANG_AST_CLASS(GenericAppExpr)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4dfffc414..3596b7045 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -310,6 +310,7 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl);
+
};
template<typename VisitorType>
@@ -1814,11 +1815,18 @@ namespace Slang
{
if (auto initExpr = varDecl->initExpr)
{
- // If the variable has an explicit initial-value expression,
- // then we simply need to check that expression and coerce
- // it to the type of the variable.
- //
- initExpr = CheckTerm(initExpr);
+ // Disable the short-circuiting for static const variable init expression
+ bool isStaticConst = varDecl->hasModifier<HLSLStaticModifier>() &&
+ varDecl->hasModifier<ConstModifier>();
+
+ auto subVisitor = isStaticConst?
+ SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this;
+ // If the variable has an explicit initial-value expression,
+ // then we simply need to check that expression and coerce
+ // it to the type of the variable.
+ //
+ initExpr = subVisitor.CheckTerm(initExpr);
+
initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr);
varDecl->initExpr = initExpr;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 55ff90759..89a7373ee 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1906,9 +1906,17 @@ namespace Slang
{
auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression);
+ // If the base expression is a type, it means that this is an array declaration,
+ // then we should disable short-circuit in case there is logical expression in
+ // the subscript
+ auto baseType = baseExpr->type.Ptr();
+ auto baseTypeType = as<TypeType>(baseType);
+ auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)?
+ SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this;
+
for (auto& arg : subscriptExpr->indexExprs)
{
- arg = CheckTerm(arg);
+ arg = subVisitor.CheckTerm(arg);
}
// If anything went wrong in the base expression,
@@ -1920,8 +1928,7 @@ namespace Slang
// Otherwise, we need to look at the type of the base expression,
// to figure out how subscripting should work.
- auto baseType = baseExpr->type.Ptr();
- if (auto baseTypeType = as<TypeType>(baseType))
+ if (baseTypeType)
{
// We are trying to "index" into a type, so we have an expression like `float[2]`
// which should be interpreted as resolving to an array type.
@@ -2371,12 +2378,69 @@ namespace Slang
return result;
}
+ Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr)
+ {
+ LogicOperatorShortCircuitExpr* newExpr = nullptr;
+
+ // If the logic expression is inside the generic parameter list, it cannot support short-circuit
+ // which will generate the ifelse branch.
+ if (!m_shouldShortCircuitLogicExpr)
+ {
+ return nullptr;
+ }
+
+ if (auto varExpr = as<VarExpr>(expr->functionExpr))
+ {
+ if ((varExpr->name->text == "&&") || (varExpr->name->text == "||"))
+ {
+ // We only use short-circuiting in scalar input, will fall back
+ // to non-short-circuiting in vector input.
+ bool shortCircuitSupport = true;
+ for (auto & arg : expr->arguments)
+ {
+ if(!as<BasicExpressionType>(arg->type.type))
+ {
+ shortCircuitSupport = false;
+ }
+ }
+
+ if (!shortCircuitSupport)
+ {
+ return nullptr;
+ }
+
+ // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr'
+ // to handle if this expression doesn't support short-circuiting.
+ for (auto & arg : expr->arguments)
+ {
+ arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg);
+ }
+
+ expr->functionExpr = CheckTerm(expr->functionExpr);
+ newExpr = m_astBuilder->create<LogicOperatorShortCircuitExpr>();
+ if (varExpr->name->text == "&&")
+ {
+ newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And;
+ }
+ else
+ {
+ newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or;
+ }
+ newExpr->loc = expr->loc;
+ newExpr->functionExpr = expr->functionExpr;
+ newExpr->type = m_astBuilder->getBoolType();
+ newExpr->arguments = expr->arguments;
+ }
+ }
+
+ return newExpr;
+ }
+
Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
{
// check the base expression first
if (!expr->originalFunctionExpr)
expr->originalFunctionExpr = expr->functionExpr;
- expr->functionExpr = CheckTerm(expr->functionExpr);
auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr;
m_treatAsDifferentiableExpr = nullptr;
// Next check the argument expressions
@@ -2384,6 +2448,13 @@ namespace Slang
{
arg = CheckTerm(arg);
}
+
+ // if the expression is '&&' or '||', we will convert it
+ // to use short-circuit evaluation.
+ if (auto newExpr = convertToLogicOperatorExpr(expr))
+ return newExpr;
+
+ expr->functionExpr = CheckTerm(expr->functionExpr);
m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
// If we are in a differentiable function, register differential witness tables involved in
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index eb3f9486c..a209d96d9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -848,6 +848,15 @@ namespace Slang
return result;
}
+ // Setup the flag to indicate disabling the short-circuiting evaluation
+ // for the logical expressions associted with the subcontext
+ SemanticsContext disableShortCircuitLogicalExpr()
+ {
+ SemanticsContext result(*this);
+ result.m_shouldShortCircuitLogicExpr = false;
+ return result;
+ }
+
TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; }
SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType)
@@ -945,6 +954,13 @@ namespace Slang
ASTBuilder* m_astBuilder = nullptr;
Scope* m_outerScope = nullptr;
+
+ // By default, we will support short-circuit evaluation for the logic expression.
+ // However, there are few exceptions where we will disable it:
+ // 1. the logic expression is inside the generic parameter list.
+ // 2. the logic expression is in the init expression of a static const variable.
+ // 3. the logic expression is in an array size declaration.
+ bool m_shouldShortCircuitLogicExpr = true;
};
struct OuterScopeContextRAII
@@ -2589,6 +2605,9 @@ namespace Slang
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
+ private:
+ // Convert the logic operator expression to not use 'InvokeExpr' type
+ Expr* convertToLogicOperatorExpr(InvokeExpr* expr);
};
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index c6e6677b3..84c005f28 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -2113,6 +2113,15 @@ namespace Slang
Expr* SemanticsExprVisitor::visitGenericAppExpr(GenericAppExpr* genericAppExpr)
{
// Start by checking the base expression and arguments.
+
+ // Disable the short-circuiting logic expression when the experssion is in
+ // the generic parameter.
+ if (this->m_shouldShortCircuitLogicExpr)
+ {
+ auto subContext = disableShortCircuitLogicalExpr();
+ return dispatchExpr(genericAppExpr, subContext);
+ }
+
auto& baseExpr = genericAppExpr->functionExpr;
baseExpr = CheckTerm(baseExpr);
auto& args = genericAppExpr->arguments;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5e35bf165..409cd65ee 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4575,6 +4575,46 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return LoweredValInfo::simple(result);
}
+ LoweredValInfo visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr)
+ {
+ auto builder = context->irBuilder;
+ auto thenBlock = builder->createBlock();
+ auto elseBlock = builder->createBlock();
+ auto afterBlock = builder->createBlock();
+ auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0]));
+
+ // ifElse(<first param>, %true-block, %false-block, %after-block)
+ builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock);
+
+ // true-block: nonconditionalBranch(%after-block, <second param> : Bool)
+ // true-block: nonconditionalBranch(%after-block, true) for ||
+ builder->insertBlock(thenBlock);
+ builder->setInsertInto(thenBlock);
+ auto trueVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ?
+ getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])) :
+ LoweredValInfo::simple(context->irBuilder->getBoolValue(true)).val;
+
+ builder->emitBranch(afterBlock, 1, &trueVal);
+
+ // false-block: nonconditionalBranch(%after-block, false) for &&
+ // false-block: nonconditionalBranch(%after-block, <second param>: Bool) for ||
+ builder->insertBlock(elseBlock);
+ builder->setInsertInto(elseBlock);
+ auto falseVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ?
+ LoweredValInfo::simple(context->irBuilder->getBoolValue(false)).val :
+ getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1]));
+
+ builder->emitBranch(afterBlock, 1, &falseVal);
+
+ // after-block: return input parameter
+ builder->insertBlock(afterBlock);
+ builder->setInsertInto(afterBlock);
+ auto paramType = lowerType(context, expr->type.type);
+ auto result = builder->emitParam(paramType);
+
+ return LoweredValInfo::simple(result);
+ }
+
LoweredValInfo visitInvokeExpr(InvokeExpr* expr)
{
return sharedLoweringContext.visitInvokeExprImpl(expr, LoweredValInfo(), TryClauseEnvironment());
@@ -5197,6 +5237,12 @@ struct DestinationDrivenRValueExprLoweringVisitor
assign(context, destination, rValue);
}
+ void visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr)
+ {
+ auto rValue = lowerRValueExpr(context, expr);
+ assign(context, destination, rValue);
+ }
+
void visitInvokeExpr(InvokeExpr* expr)
{
LoweredValInfo resultRVal;