From 666af0962b6ab41489a3a3287db83f77c2f6461a Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 24 Mar 2023 20:17:46 -0700 Subject: Switch to short circuiting semantics for scalar `?:` operator. (#2733) --- source/slang/core.meta.slang | 13 +++++++++---- source/slang/slang-check-expr.cpp | 34 +++++++++++++++++++++++++++++++++ source/slang/slang-check-impl.h | 2 ++ source/slang/slang-diagnostic-defs.h | 3 +++ source/slang/slang-lower-to-ir.cpp | 37 ++++++++++++++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 4 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 6581cc605..82a60a612 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -315,13 +315,18 @@ U operator,(T left, U right) return right; } -// The ternary `?:` operator does not short-circuit in HLSL, and Slang continues to -// follow that definition, so that this operator is effectively just an ordinary -// function, rather than a special-case piece of syntax. -// +// The ternary `?:` operator does not short-circuit in HLSL, and Slang no longer +// follow that definition for the scalar condition overload, so this declaration just serves +// for type-checking purpose only. + __generic __intrinsic_op(select) T operator?:(bool condition, T ifTrue, T ifFalse); __generic __intrinsic_op(select) vector operator?:(vector condition, vector ifTrue, vector ifFalse); +// Users are advised to use `select` instead if non-short-circuiting behavior is intended. +__generic __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse); +__generic __intrinsic_op(select) vector select(vector condition, vector ifTrue, vector ifFalse); + + ${{{{ // We are going to use code generation to produce the // declarations for all of our base types. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index bfad1dbfe..cfcb15269 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2022,6 +2022,40 @@ namespace Slang return rs; } + Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr) + { + auto result = visitInvokeExpr(expr); + if (as(result->type.type)) + return result; + auto invokeExpr = as(result); + if (!result) + return result; + if (invokeExpr->arguments.getCount() != 3) + return result; + + if (as(invokeExpr->arguments[0]->type.type)) + { + auto newArgs = invokeExpr->arguments; + expr->arguments.clear(); + expr->arguments = newArgs; + expr->type = invokeExpr->type; + return expr; + } + + if (getParentDifferentiableAttribute()) + { + // If we are in a differentiable func, issue + // a diagnostic on use of non short-circuiting select. + getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperatorInDiffFunc); + } + else + { + // For all other functions, we issue a warning for deprecation of vector-typed ?: operator. + getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperator); + } + return result; + } + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr) { // check the base expression first diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 4181ca43b..50992b00b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1896,6 +1896,8 @@ namespace Slang Expr* visitInvokeExpr(InvokeExpr *expr); + Expr* visitSelectExpr(SelectExpr* expr); + Expr* visitVarExpr(VarExpr *expr); Expr* visitTypeCastExpr(TypeCastExpr * expr); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e3e9cfc44..c3e0adbca 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -275,6 +275,9 @@ DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$ DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") DIAGNOSTIC(30053, Error, breakLabelNotFound, "label '$0' used as break target is not found.") DIAGNOSTIC(30054, Error, targetLabelDoesNotMarkBreakableStmt, "invalid break target: statement labeled '$0' is not breakable.") +DIAGNOSTIC(30055, Error, useOfNonShortCircuitingOperatorInDiffFunc, "non-short-circuiting `?:` operator is not allowed in a differentiable function, use `select` instead.") +DIAGNOSTIC(30056, Warning, useOfNonShortCircuitingOperator, "non-short-circuiting `?:` operator is deprecated, use 'select' instead.") + DIAGNOSTIC(30043, Error, getStringHashRequiresStringLiteral, "getStringHash parameter can only accept a string literal") DIAGNOSTIC(30060, Error, expectedAType, "expected a type, got a '$0'") diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 9c27beb58..7144b3450 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4003,6 +4003,43 @@ struct ExprLoweringVisitorBase : ExprVisitor return e; } + LoweredValInfo visitSelectExpr(SelectExpr* expr) + { + // A vector typed `select` expr will turn into a normal `select` op. + if (!as(expr->arguments[0]->type.type)) + { + return visitInvokeExpr(expr); + } + + // In global scope? This is a constant, and we should emit as `select` inst. + if (!getParentFunc(context->irBuilder->getInsertLoc().getInst())) + { + return visitInvokeExpr(expr); + } + + // A scalar typed `select` expr will turn into an if-else to implement short circuiting + // semantics. + auto builder = context->irBuilder; + auto thenBlock = builder->createBlock(); + auto elseBlock = builder->createBlock(); + auto afterBlock = builder->createBlock(); + auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0])); + builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); + builder->insertBlock(thenBlock); + builder->setInsertInto(thenBlock); + auto trueVal = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])); + builder->emitBranch(afterBlock, 1, &trueVal); + builder->insertBlock(elseBlock); + builder->setInsertInto(elseBlock); + auto falseVal = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[2])); + builder->emitBranch(afterBlock, 1, &falseVal); + builder->insertBlock(afterBlock); + builder->setInsertInto(afterBlock); + auto paramType = lowerType(context, expr->type.type); + auto result = builder->emitParam(paramType); + return LoweredValInfo::simple(result); + } + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { return visitInvokeExprImpl(expr, TryClauseEnvironment()); -- cgit v1.2.3