diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2024-03-05 12:55:50 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-04 20:55:50 -0800 |
| commit | 2297623aad4c249bccae3fe363ada31e308131ac (patch) | |
| tree | a97b1f7d63ea207d3123a4d5c51ec7c0acfc25be /source | |
| parent | 0371deef52c2ef9ffda3c5ec11f5b1082c0b96e8 (diff) | |
Implement short-circuit logic operator (#3635)
* Implement short-circuit logic operator
Implement short-circuit evaluation for logic && and ||
operator.
The short-circuit behavior is only used when the operands
involved are scalar and the parent function is non-differentiable.
In implementation, we define a new class 'LogicOperatorShortCircuitExpr'
derived from 'OperatorExpr'. In the visitInvoke() call, we will create
a new expression object 'LogicOperatorShortCircuitExpr' if the
expression is logic && or ||. So that we can generate new IR code in the
new visit function 'visitLogicOperatorShortCircuitExpr' to implement the
short-circuit behavior.
Add new test to test the short-circuit behavior.
* Fix an compile issue occurred in Falcon test
Previously, we early return when at least one of the operands of
"&&" or "||" is vector in convertToLogicOperatorExpr call. However,
in that case the arguments involved in the expression have already been
type checked. When it falls-back to 'visitInvokeExpr', it will check
the arguments again, and some unexpected behavior could occur
which could in turn cause some internal error.
So we add a check in the 'visitInvokeExpr' to avoid double type checking
of arguments.
* Update glsl subgroup test to not use short-circuit
Since the short-circuit evaluation could cause the threads
diverging in subgroup intrinsics. So change the test to not
using "&&" to chain those subgroup intrinsics together. Instead,
using "&" to chain them together because those test functions have
the return value as bool.
* Disable short-circuit in few situations
Disable short-circuit in following situations:
1. generic parameter list
2. static const varible initialization
* Use a flag to indicate the enablement of short-circuit
Instead of using a struct to indicate the state of the outer
environment of current expression, use a simple bool flag to
indicate whether or not apply the short-circuit to current
expression because there few situations where we will disable
short-circuiting and in those circumstances, there is no nested.
Therefore, a flag is good enough to indicate the case.
* Disable short-circuit in index expression
Also fix the build issue. (A cleanup for the last change.)
* check both 'static' and 'const' modifiers
Previously we only check HLSLStaticModifier to decide whether or
not using short-circuit, but we really should check both 'static'
and 'const' modifiers together, because we only want to disable
the short circuit for init expression for 'static const' variable.
* relax the restriction of short-circuit for index expression
Disable the short-circuit for index expression only when declare
an array.
* Simplify the logic by creating subVisitor
Simplify the logic by create a sub expression visitor so
that we don't need to introduce extra recursion.
* Call convertToLogicOperatorExpr after args check
Change to call convertToLogicOperatorExpr after arguments
check in visitInvokeExpr such that we don't have to check
whether the arguments checked to avoid the double checking
issue.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-expr.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 46 |
6 files changed, 175 insertions, 9 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 9a7993937..fa1b27a04 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -431,6 +431,19 @@ class SelectExpr: public OperatorExpr SLANG_AST_CLASS(SelectExpr) }; +class LogicOperatorShortCircuitExpr: public OperatorExpr +{ + SLANG_AST_CLASS(LogicOperatorShortCircuitExpr) +public: + enum Flavor + { + And, // && + Or, // || + }; + Flavor flavor; +}; + + class GenericAppExpr: public AppExprBase { SLANG_AST_CLASS(GenericAppExpr) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4dfffc414..3596b7045 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -310,6 +310,7 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + }; template<typename VisitorType> @@ -1814,11 +1815,18 @@ namespace Slang { if (auto initExpr = varDecl->initExpr) { - // If the variable has an explicit initial-value expression, - // then we simply need to check that expression and coerce - // it to the type of the variable. - // - initExpr = CheckTerm(initExpr); + // Disable the short-circuiting for static const variable init expression + bool isStaticConst = varDecl->hasModifier<HLSLStaticModifier>() && + varDecl->hasModifier<ConstModifier>(); + + auto subVisitor = isStaticConst? + SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + // If the variable has an explicit initial-value expression, + // then we simply need to check that expression and coerce + // it to the type of the variable. + // + initExpr = subVisitor.CheckTerm(initExpr); + initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); varDecl->initExpr = initExpr; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 55ff90759..89a7373ee 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1906,9 +1906,17 @@ namespace Slang { auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression); + // If the base expression is a type, it means that this is an array declaration, + // then we should disable short-circuit in case there is logical expression in + // the subscript + auto baseType = baseExpr->type.Ptr(); + auto baseTypeType = as<TypeType>(baseType); + auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)? + SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + for (auto& arg : subscriptExpr->indexExprs) { - arg = CheckTerm(arg); + arg = subVisitor.CheckTerm(arg); } // If anything went wrong in the base expression, @@ -1920,8 +1928,7 @@ namespace Slang // Otherwise, we need to look at the type of the base expression, // to figure out how subscripting should work. - auto baseType = baseExpr->type.Ptr(); - if (auto baseTypeType = as<TypeType>(baseType)) + if (baseTypeType) { // We are trying to "index" into a type, so we have an expression like `float[2]` // which should be interpreted as resolving to an array type. @@ -2371,12 +2378,69 @@ namespace Slang return result; } + Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr) + { + LogicOperatorShortCircuitExpr* newExpr = nullptr; + + // If the logic expression is inside the generic parameter list, it cannot support short-circuit + // which will generate the ifelse branch. + if (!m_shouldShortCircuitLogicExpr) + { + return nullptr; + } + + if (auto varExpr = as<VarExpr>(expr->functionExpr)) + { + if ((varExpr->name->text == "&&") || (varExpr->name->text == "||")) + { + // We only use short-circuiting in scalar input, will fall back + // to non-short-circuiting in vector input. + bool shortCircuitSupport = true; + for (auto & arg : expr->arguments) + { + if(!as<BasicExpressionType>(arg->type.type)) + { + shortCircuitSupport = false; + } + } + + if (!shortCircuitSupport) + { + return nullptr; + } + + // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr' + // to handle if this expression doesn't support short-circuiting. + for (auto & arg : expr->arguments) + { + arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg); + } + + expr->functionExpr = CheckTerm(expr->functionExpr); + newExpr = m_astBuilder->create<LogicOperatorShortCircuitExpr>(); + if (varExpr->name->text == "&&") + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And; + } + else + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or; + } + newExpr->loc = expr->loc; + newExpr->functionExpr = expr->functionExpr; + newExpr->type = m_astBuilder->getBoolType(); + newExpr->arguments = expr->arguments; + } + } + + return newExpr; + } + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) { // check the base expression first if (!expr->originalFunctionExpr) expr->originalFunctionExpr = expr->functionExpr; - expr->functionExpr = CheckTerm(expr->functionExpr); auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; m_treatAsDifferentiableExpr = nullptr; // Next check the argument expressions @@ -2384,6 +2448,13 @@ namespace Slang { arg = CheckTerm(arg); } + + // if the expression is '&&' or '||', we will convert it + // to use short-circuit evaluation. + if (auto newExpr = convertToLogicOperatorExpr(expr)) + return newExpr; + + expr->functionExpr = CheckTerm(expr->functionExpr); m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index eb3f9486c..a209d96d9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -848,6 +848,15 @@ namespace Slang return result; } + // Setup the flag to indicate disabling the short-circuiting evaluation + // for the logical expressions associted with the subcontext + SemanticsContext disableShortCircuitLogicalExpr() + { + SemanticsContext result(*this); + result.m_shouldShortCircuitLogicExpr = false; + return result; + } + TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; } SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType) @@ -945,6 +954,13 @@ namespace Slang ASTBuilder* m_astBuilder = nullptr; Scope* m_outerScope = nullptr; + + // By default, we will support short-circuit evaluation for the logic expression. + // However, there are few exceptions where we will disable it: + // 1. the logic expression is inside the generic parameter list. + // 2. the logic expression is in the init expression of a static const variable. + // 3. the logic expression is in an array size declaration. + bool m_shouldShortCircuitLogicExpr = true; }; struct OuterScopeContextRAII @@ -2589,6 +2605,9 @@ namespace Slang /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); + private: + // Convert the logic operator expression to not use 'InvokeExpr' type + Expr* convertToLogicOperatorExpr(InvokeExpr* expr); }; diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index c6e6677b3..84c005f28 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2113,6 +2113,15 @@ namespace Slang Expr* SemanticsExprVisitor::visitGenericAppExpr(GenericAppExpr* genericAppExpr) { // Start by checking the base expression and arguments. + + // Disable the short-circuiting logic expression when the experssion is in + // the generic parameter. + if (this->m_shouldShortCircuitLogicExpr) + { + auto subContext = disableShortCircuitLogicalExpr(); + return dispatchExpr(genericAppExpr, subContext); + } + auto& baseExpr = genericAppExpr->functionExpr; baseExpr = CheckTerm(baseExpr); auto& args = genericAppExpr->arguments; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5e35bf165..409cd65ee 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4575,6 +4575,46 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple(result); } + LoweredValInfo visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr) + { + auto builder = context->irBuilder; + auto thenBlock = builder->createBlock(); + auto elseBlock = builder->createBlock(); + auto afterBlock = builder->createBlock(); + auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0])); + + // ifElse(<first param>, %true-block, %false-block, %after-block) + builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); + + // true-block: nonconditionalBranch(%after-block, <second param> : Bool) + // true-block: nonconditionalBranch(%after-block, true) for || + builder->insertBlock(thenBlock); + builder->setInsertInto(thenBlock); + auto trueVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ? + getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])) : + LoweredValInfo::simple(context->irBuilder->getBoolValue(true)).val; + + builder->emitBranch(afterBlock, 1, &trueVal); + + // false-block: nonconditionalBranch(%after-block, false) for && + // false-block: nonconditionalBranch(%after-block, <second param>: Bool) for || + builder->insertBlock(elseBlock); + builder->setInsertInto(elseBlock); + auto falseVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ? + LoweredValInfo::simple(context->irBuilder->getBoolValue(false)).val : + getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])); + + builder->emitBranch(afterBlock, 1, &falseVal); + + // after-block: return input parameter + builder->insertBlock(afterBlock); + builder->setInsertInto(afterBlock); + auto paramType = lowerType(context, expr->type.type); + auto result = builder->emitParam(paramType); + + return LoweredValInfo::simple(result); + } + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { return sharedLoweringContext.visitInvokeExprImpl(expr, LoweredValInfo(), TryClauseEnvironment()); @@ -5197,6 +5237,12 @@ struct DestinationDrivenRValueExprLoweringVisitor assign(context, destination, rValue); } + void visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr) + { + auto rValue = lowerRValueExpr(context, expr); + assign(context, destination, rValue); + } + void visitInvokeExpr(InvokeExpr* expr) { LoweredValInfo resultRVal; |
