summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2024-03-05 12:55:50 +0800
committerGitHub <noreply@github.com>2024-03-04 20:55:50 -0800
commit2297623aad4c249bccae3fe363ada31e308131ac (patch)
treea97b1f7d63ea207d3123a4d5c51ec7c0acfc25be /source
parent0371deef52c2ef9ffda3c5ec11f5b1082c0b96e8 (diff)
Implement short-circuit logic operator (#3635)
* Implement short-circuit logic operator Implement short-circuit evaluation for logic && and || operator. The short-circuit behavior is only used when the operands involved are scalar and the parent function is non-differentiable. In implementation, we define a new class 'LogicOperatorShortCircuitExpr' derived from 'OperatorExpr'. In the visitInvoke() call, we will create a new expression object 'LogicOperatorShortCircuitExpr' if the expression is logic && or ||. So that we can generate new IR code in the new visit function 'visitLogicOperatorShortCircuitExpr' to implement the short-circuit behavior. Add new test to test the short-circuit behavior. * Fix an compile issue occurred in Falcon test Previously, we early return when at least one of the operands of "&&" or "||" is vector in convertToLogicOperatorExpr call. However, in that case the arguments involved in the expression have already been type checked. When it falls-back to 'visitInvokeExpr', it will check the arguments again, and some unexpected behavior could occur which could in turn cause some internal error. So we add a check in the 'visitInvokeExpr' to avoid double type checking of arguments. * Update glsl subgroup test to not use short-circuit Since the short-circuit evaluation could cause the threads diverging in subgroup intrinsics. So change the test to not using "&&" to chain those subgroup intrinsics together. Instead, using "&" to chain them together because those test functions have the return value as bool. * Disable short-circuit in few situations Disable short-circuit in following situations: 1. generic parameter list 2. static const varible initialization * Use a flag to indicate the enablement of short-circuit Instead of using a struct to indicate the state of the outer environment of current expression, use a simple bool flag to indicate whether or not apply the short-circuit to current expression because there few situations where we will disable short-circuiting and in those circumstances, there is no nested. Therefore, a flag is good enough to indicate the case. * Disable short-circuit in index expression Also fix the build issue. (A cleanup for the last change.) * check both 'static' and 'const' modifiers Previously we only check HLSLStaticModifier to decide whether or not using short-circuit, but we really should check both 'static' and 'const' modifiers together, because we only want to disable the short circuit for init expression for 'static const' variable. * relax the restriction of short-circuit for index expression Disable the short-circuit for index expression only when declare an array. * Simplify the logic by creating subVisitor Simplify the logic by create a sub expression visitor so that we don't need to introduce extra recursion. * Call convertToLogicOperatorExpr after args check Change to call convertToLogicOperatorExpr after arguments check in visitInvokeExpr such that we don't have to check whether the arguments checked to avoid the double checking issue.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h13
-rw-r--r--source/slang/slang-check-decl.cpp18
-rw-r--r--source/slang/slang-check-expr.cpp79
-rw-r--r--source/slang/slang-check-impl.h19
-rw-r--r--source/slang/slang-check-overload.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp46
6 files changed, 175 insertions, 9 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 9a7993937..fa1b27a04 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -431,6 +431,19 @@ class SelectExpr: public OperatorExpr
SLANG_AST_CLASS(SelectExpr)
};
+class LogicOperatorShortCircuitExpr: public OperatorExpr
+{
+ SLANG_AST_CLASS(LogicOperatorShortCircuitExpr)
+public:
+ enum Flavor
+ {
+ And, // &&
+ Or, // ||
+ };
+ Flavor flavor;
+};
+
+
class GenericAppExpr: public AppExprBase
{
SLANG_AST_CLASS(GenericAppExpr)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4dfffc414..3596b7045 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -310,6 +310,7 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl);
+
};
template<typename VisitorType>
@@ -1814,11 +1815,18 @@ namespace Slang
{
if (auto initExpr = varDecl->initExpr)
{
- // If the variable has an explicit initial-value expression,
- // then we simply need to check that expression and coerce
- // it to the type of the variable.
- //
- initExpr = CheckTerm(initExpr);
+ // Disable the short-circuiting for static const variable init expression
+ bool isStaticConst = varDecl->hasModifier<HLSLStaticModifier>() &&
+ varDecl->hasModifier<ConstModifier>();
+
+ auto subVisitor = isStaticConst?
+ SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this;
+ // If the variable has an explicit initial-value expression,
+ // then we simply need to check that expression and coerce
+ // it to the type of the variable.
+ //
+ initExpr = subVisitor.CheckTerm(initExpr);
+
initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr);
varDecl->initExpr = initExpr;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 55ff90759..89a7373ee 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1906,9 +1906,17 @@ namespace Slang
{
auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression);
+ // If the base expression is a type, it means that this is an array declaration,
+ // then we should disable short-circuit in case there is logical expression in
+ // the subscript
+ auto baseType = baseExpr->type.Ptr();
+ auto baseTypeType = as<TypeType>(baseType);
+ auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)?
+ SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this;
+
for (auto& arg : subscriptExpr->indexExprs)
{
- arg = CheckTerm(arg);
+ arg = subVisitor.CheckTerm(arg);
}
// If anything went wrong in the base expression,
@@ -1920,8 +1928,7 @@ namespace Slang
// Otherwise, we need to look at the type of the base expression,
// to figure out how subscripting should work.
- auto baseType = baseExpr->type.Ptr();
- if (auto baseTypeType = as<TypeType>(baseType))
+ if (baseTypeType)
{
// We are trying to "index" into a type, so we have an expression like `float[2]`
// which should be interpreted as resolving to an array type.
@@ -2371,12 +2378,69 @@ namespace Slang
return result;
}
+ Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr)
+ {
+ LogicOperatorShortCircuitExpr* newExpr = nullptr;
+
+ // If the logic expression is inside the generic parameter list, it cannot support short-circuit
+ // which will generate the ifelse branch.
+ if (!m_shouldShortCircuitLogicExpr)
+ {
+ return nullptr;
+ }
+
+ if (auto varExpr = as<VarExpr>(expr->functionExpr))
+ {
+ if ((varExpr->name->text == "&&") || (varExpr->name->text == "||"))
+ {
+ // We only use short-circuiting in scalar input, will fall back
+ // to non-short-circuiting in vector input.
+ bool shortCircuitSupport = true;
+ for (auto & arg : expr->arguments)
+ {
+ if(!as<BasicExpressionType>(arg->type.type))
+ {
+ shortCircuitSupport = false;
+ }
+ }
+
+ if (!shortCircuitSupport)
+ {
+ return nullptr;
+ }
+
+ // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr'
+ // to handle if this expression doesn't support short-circuiting.
+ for (auto & arg : expr->arguments)
+ {
+ arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg);
+ }
+
+ expr->functionExpr = CheckTerm(expr->functionExpr);
+ newExpr = m_astBuilder->create<LogicOperatorShortCircuitExpr>();
+ if (varExpr->name->text == "&&")
+ {
+ newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And;
+ }
+ else
+ {
+ newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or;
+ }
+ newExpr->loc = expr->loc;
+ newExpr->functionExpr = expr->functionExpr;
+ newExpr->type = m_astBuilder->getBoolType();
+ newExpr->arguments = expr->arguments;
+ }
+ }
+
+ return newExpr;
+ }
+
Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
{
// check the base expression first
if (!expr->originalFunctionExpr)
expr->originalFunctionExpr = expr->functionExpr;
- expr->functionExpr = CheckTerm(expr->functionExpr);
auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr;
m_treatAsDifferentiableExpr = nullptr;
// Next check the argument expressions
@@ -2384,6 +2448,13 @@ namespace Slang
{
arg = CheckTerm(arg);
}
+
+ // if the expression is '&&' or '||', we will convert it
+ // to use short-circuit evaluation.
+ if (auto newExpr = convertToLogicOperatorExpr(expr))
+ return newExpr;
+
+ expr->functionExpr = CheckTerm(expr->functionExpr);
m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
// If we are in a differentiable function, register differential witness tables involved in
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index eb3f9486c..a209d96d9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -848,6 +848,15 @@ namespace Slang
return result;
}
+ // Setup the flag to indicate disabling the short-circuiting evaluation
+ // for the logical expressions associted with the subcontext
+ SemanticsContext disableShortCircuitLogicalExpr()
+ {
+ SemanticsContext result(*this);
+ result.m_shouldShortCircuitLogicExpr = false;
+ return result;
+ }
+
TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; }
SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType)
@@ -945,6 +954,13 @@ namespace Slang
ASTBuilder* m_astBuilder = nullptr;
Scope* m_outerScope = nullptr;
+
+ // By default, we will support short-circuit evaluation for the logic expression.
+ // However, there are few exceptions where we will disable it:
+ // 1. the logic expression is inside the generic parameter list.
+ // 2. the logic expression is in the init expression of a static const variable.
+ // 3. the logic expression is in an array size declaration.
+ bool m_shouldShortCircuitLogicExpr = true;
};
struct OuterScopeContextRAII
@@ -2589,6 +2605,9 @@ namespace Slang
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
+ private:
+ // Convert the logic operator expression to not use 'InvokeExpr' type
+ Expr* convertToLogicOperatorExpr(InvokeExpr* expr);
};
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index c6e6677b3..84c005f28 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -2113,6 +2113,15 @@ namespace Slang
Expr* SemanticsExprVisitor::visitGenericAppExpr(GenericAppExpr* genericAppExpr)
{
// Start by checking the base expression and arguments.
+
+ // Disable the short-circuiting logic expression when the experssion is in
+ // the generic parameter.
+ if (this->m_shouldShortCircuitLogicExpr)
+ {
+ auto subContext = disableShortCircuitLogicalExpr();
+ return dispatchExpr(genericAppExpr, subContext);
+ }
+
auto& baseExpr = genericAppExpr->functionExpr;
baseExpr = CheckTerm(baseExpr);
auto& args = genericAppExpr->arguments;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5e35bf165..409cd65ee 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4575,6 +4575,46 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return LoweredValInfo::simple(result);
}
+ LoweredValInfo visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr)
+ {
+ auto builder = context->irBuilder;
+ auto thenBlock = builder->createBlock();
+ auto elseBlock = builder->createBlock();
+ auto afterBlock = builder->createBlock();
+ auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0]));
+
+ // ifElse(<first param>, %true-block, %false-block, %after-block)
+ builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock);
+
+ // true-block: nonconditionalBranch(%after-block, <second param> : Bool)
+ // true-block: nonconditionalBranch(%after-block, true) for ||
+ builder->insertBlock(thenBlock);
+ builder->setInsertInto(thenBlock);
+ auto trueVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ?
+ getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])) :
+ LoweredValInfo::simple(context->irBuilder->getBoolValue(true)).val;
+
+ builder->emitBranch(afterBlock, 1, &trueVal);
+
+ // false-block: nonconditionalBranch(%after-block, false) for &&
+ // false-block: nonconditionalBranch(%after-block, <second param>: Bool) for ||
+ builder->insertBlock(elseBlock);
+ builder->setInsertInto(elseBlock);
+ auto falseVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ?
+ LoweredValInfo::simple(context->irBuilder->getBoolValue(false)).val :
+ getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1]));
+
+ builder->emitBranch(afterBlock, 1, &falseVal);
+
+ // after-block: return input parameter
+ builder->insertBlock(afterBlock);
+ builder->setInsertInto(afterBlock);
+ auto paramType = lowerType(context, expr->type.type);
+ auto result = builder->emitParam(paramType);
+
+ return LoweredValInfo::simple(result);
+ }
+
LoweredValInfo visitInvokeExpr(InvokeExpr* expr)
{
return sharedLoweringContext.visitInvokeExprImpl(expr, LoweredValInfo(), TryClauseEnvironment());
@@ -5197,6 +5237,12 @@ struct DestinationDrivenRValueExprLoweringVisitor
assign(context, destination, rValue);
}
+ void visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr)
+ {
+ auto rValue = lowerRValueExpr(context, expr);
+ assign(context, destination, rValue);
+ }
+
void visitInvokeExpr(InvokeExpr* expr)
{
LoweredValInfo resultRVal;