diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-expr.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 46 |
6 files changed, 175 insertions, 9 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 9a7993937..fa1b27a04 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -431,6 +431,19 @@ class SelectExpr: public OperatorExpr SLANG_AST_CLASS(SelectExpr) }; +class LogicOperatorShortCircuitExpr: public OperatorExpr +{ + SLANG_AST_CLASS(LogicOperatorShortCircuitExpr) +public: + enum Flavor + { + And, // && + Or, // || + }; + Flavor flavor; +}; + + class GenericAppExpr: public AppExprBase { SLANG_AST_CLASS(GenericAppExpr) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4dfffc414..3596b7045 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -310,6 +310,7 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + }; template<typename VisitorType> @@ -1814,11 +1815,18 @@ namespace Slang { if (auto initExpr = varDecl->initExpr) { - // If the variable has an explicit initial-value expression, - // then we simply need to check that expression and coerce - // it to the type of the variable. - // - initExpr = CheckTerm(initExpr); + // Disable the short-circuiting for static const variable init expression + bool isStaticConst = varDecl->hasModifier<HLSLStaticModifier>() && + varDecl->hasModifier<ConstModifier>(); + + auto subVisitor = isStaticConst? + SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + // If the variable has an explicit initial-value expression, + // then we simply need to check that expression and coerce + // it to the type of the variable. + // + initExpr = subVisitor.CheckTerm(initExpr); + initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); varDecl->initExpr = initExpr; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 55ff90759..89a7373ee 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1906,9 +1906,17 @@ namespace Slang { auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression); + // If the base expression is a type, it means that this is an array declaration, + // then we should disable short-circuit in case there is logical expression in + // the subscript + auto baseType = baseExpr->type.Ptr(); + auto baseTypeType = as<TypeType>(baseType); + auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)? + SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + for (auto& arg : subscriptExpr->indexExprs) { - arg = CheckTerm(arg); + arg = subVisitor.CheckTerm(arg); } // If anything went wrong in the base expression, @@ -1920,8 +1928,7 @@ namespace Slang // Otherwise, we need to look at the type of the base expression, // to figure out how subscripting should work. - auto baseType = baseExpr->type.Ptr(); - if (auto baseTypeType = as<TypeType>(baseType)) + if (baseTypeType) { // We are trying to "index" into a type, so we have an expression like `float[2]` // which should be interpreted as resolving to an array type. @@ -2371,12 +2378,69 @@ namespace Slang return result; } + Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr) + { + LogicOperatorShortCircuitExpr* newExpr = nullptr; + + // If the logic expression is inside the generic parameter list, it cannot support short-circuit + // which will generate the ifelse branch. + if (!m_shouldShortCircuitLogicExpr) + { + return nullptr; + } + + if (auto varExpr = as<VarExpr>(expr->functionExpr)) + { + if ((varExpr->name->text == "&&") || (varExpr->name->text == "||")) + { + // We only use short-circuiting in scalar input, will fall back + // to non-short-circuiting in vector input. + bool shortCircuitSupport = true; + for (auto & arg : expr->arguments) + { + if(!as<BasicExpressionType>(arg->type.type)) + { + shortCircuitSupport = false; + } + } + + if (!shortCircuitSupport) + { + return nullptr; + } + + // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr' + // to handle if this expression doesn't support short-circuiting. + for (auto & arg : expr->arguments) + { + arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg); + } + + expr->functionExpr = CheckTerm(expr->functionExpr); + newExpr = m_astBuilder->create<LogicOperatorShortCircuitExpr>(); + if (varExpr->name->text == "&&") + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And; + } + else + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or; + } + newExpr->loc = expr->loc; + newExpr->functionExpr = expr->functionExpr; + newExpr->type = m_astBuilder->getBoolType(); + newExpr->arguments = expr->arguments; + } + } + + return newExpr; + } + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) { // check the base expression first if (!expr->originalFunctionExpr) expr->originalFunctionExpr = expr->functionExpr; - expr->functionExpr = CheckTerm(expr->functionExpr); auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; m_treatAsDifferentiableExpr = nullptr; // Next check the argument expressions @@ -2384,6 +2448,13 @@ namespace Slang { arg = CheckTerm(arg); } + + // if the expression is '&&' or '||', we will convert it + // to use short-circuit evaluation. + if (auto newExpr = convertToLogicOperatorExpr(expr)) + return newExpr; + + expr->functionExpr = CheckTerm(expr->functionExpr); m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index eb3f9486c..a209d96d9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -848,6 +848,15 @@ namespace Slang return result; } + // Setup the flag to indicate disabling the short-circuiting evaluation + // for the logical expressions associted with the subcontext + SemanticsContext disableShortCircuitLogicalExpr() + { + SemanticsContext result(*this); + result.m_shouldShortCircuitLogicExpr = false; + return result; + } + TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; } SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType) @@ -945,6 +954,13 @@ namespace Slang ASTBuilder* m_astBuilder = nullptr; Scope* m_outerScope = nullptr; + + // By default, we will support short-circuit evaluation for the logic expression. + // However, there are few exceptions where we will disable it: + // 1. the logic expression is inside the generic parameter list. + // 2. the logic expression is in the init expression of a static const variable. + // 3. the logic expression is in an array size declaration. + bool m_shouldShortCircuitLogicExpr = true; }; struct OuterScopeContextRAII @@ -2589,6 +2605,9 @@ namespace Slang /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); + private: + // Convert the logic operator expression to not use 'InvokeExpr' type + Expr* convertToLogicOperatorExpr(InvokeExpr* expr); }; diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index c6e6677b3..84c005f28 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2113,6 +2113,15 @@ namespace Slang Expr* SemanticsExprVisitor::visitGenericAppExpr(GenericAppExpr* genericAppExpr) { // Start by checking the base expression and arguments. + + // Disable the short-circuiting logic expression when the experssion is in + // the generic parameter. + if (this->m_shouldShortCircuitLogicExpr) + { + auto subContext = disableShortCircuitLogicalExpr(); + return dispatchExpr(genericAppExpr, subContext); + } + auto& baseExpr = genericAppExpr->functionExpr; baseExpr = CheckTerm(baseExpr); auto& args = genericAppExpr->arguments; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5e35bf165..409cd65ee 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4575,6 +4575,46 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple(result); } + LoweredValInfo visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr) + { + auto builder = context->irBuilder; + auto thenBlock = builder->createBlock(); + auto elseBlock = builder->createBlock(); + auto afterBlock = builder->createBlock(); + auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0])); + + // ifElse(<first param>, %true-block, %false-block, %after-block) + builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); + + // true-block: nonconditionalBranch(%after-block, <second param> : Bool) + // true-block: nonconditionalBranch(%after-block, true) for || + builder->insertBlock(thenBlock); + builder->setInsertInto(thenBlock); + auto trueVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ? + getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])) : + LoweredValInfo::simple(context->irBuilder->getBoolValue(true)).val; + + builder->emitBranch(afterBlock, 1, &trueVal); + + // false-block: nonconditionalBranch(%after-block, false) for && + // false-block: nonconditionalBranch(%after-block, <second param>: Bool) for || + builder->insertBlock(elseBlock); + builder->setInsertInto(elseBlock); + auto falseVal = expr->flavor == LogicOperatorShortCircuitExpr::Flavor::And ? + LoweredValInfo::simple(context->irBuilder->getBoolValue(false)).val : + getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])); + + builder->emitBranch(afterBlock, 1, &falseVal); + + // after-block: return input parameter + builder->insertBlock(afterBlock); + builder->setInsertInto(afterBlock); + auto paramType = lowerType(context, expr->type.type); + auto result = builder->emitParam(paramType); + + return LoweredValInfo::simple(result); + } + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { return sharedLoweringContext.visitInvokeExprImpl(expr, LoweredValInfo(), TryClauseEnvironment()); @@ -5197,6 +5237,12 @@ struct DestinationDrivenRValueExprLoweringVisitor assign(context, destination, rValue); } + void visitLogicOperatorShortCircuitExpr(LogicOperatorShortCircuitExpr* expr) + { + auto rValue = lowerRValueExpr(context, expr); + assign(context, destination, rValue); + } + void visitInvokeExpr(InvokeExpr* expr) { LoweredValInfo resultRVal; |
