summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-24 20:17:46 -0700
committerGitHub <noreply@github.com>2023-03-24 20:17:46 -0700
commit666af0962b6ab41489a3a3287db83f77c2f6461a (patch)
tree81a1247188ac03f1e8132e58ec31ae0f28c8c530 /source
parent7292edbd3eba3da7e8490ad19169a7d18283057a (diff)
Switch to short circuiting semantics for scalar `?:` operator. (#2733)
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang13
-rw-r--r--source/slang/slang-check-expr.cpp34
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-lower-to-ir.cpp37
5 files changed, 85 insertions, 4 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 6581cc605..82a60a612 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -315,13 +315,18 @@ U operator,(T left, U right)
return right;
}
-// The ternary `?:` operator does not short-circuit in HLSL, and Slang continues to
-// follow that definition, so that this operator is effectively just an ordinary
-// function, rather than a special-case piece of syntax.
-//
+// The ternary `?:` operator does not short-circuit in HLSL, and Slang no longer
+// follow that definition for the scalar condition overload, so this declaration just serves
+// for type-checking purpose only.
+
__generic<T> __intrinsic_op(select) T operator?:(bool condition, T ifTrue, T ifFalse);
__generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+// Users are advised to use `select` instead if non-short-circuiting behavior is intended.
+__generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse);
+__generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+
+
${{{{
// We are going to use code generation to produce the
// declarations for all of our base types.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index bfad1dbfe..cfcb15269 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2022,6 +2022,40 @@ namespace Slang
return rs;
}
+ Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr)
+ {
+ auto result = visitInvokeExpr(expr);
+ if (as<ErrorType>(result->type.type))
+ return result;
+ auto invokeExpr = as<InvokeExpr>(result);
+ if (!result)
+ return result;
+ if (invokeExpr->arguments.getCount() != 3)
+ return result;
+
+ if (as<BasicExpressionType>(invokeExpr->arguments[0]->type.type))
+ {
+ auto newArgs = invokeExpr->arguments;
+ expr->arguments.clear();
+ expr->arguments = newArgs;
+ expr->type = invokeExpr->type;
+ return expr;
+ }
+
+ if (getParentDifferentiableAttribute())
+ {
+ // If we are in a differentiable func, issue
+ // a diagnostic on use of non short-circuiting select.
+ getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperatorInDiffFunc);
+ }
+ else
+ {
+ // For all other functions, we issue a warning for deprecation of vector-typed ?: operator.
+ getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperator);
+ }
+ return result;
+ }
+
Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr)
{
// check the base expression first
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 4181ca43b..50992b00b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1896,6 +1896,8 @@ namespace Slang
Expr* visitInvokeExpr(InvokeExpr *expr);
+ Expr* visitSelectExpr(SelectExpr* expr);
+
Expr* visitVarExpr(VarExpr *expr);
Expr* visitTypeCastExpr(TypeCastExpr * expr);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e3e9cfc44..c3e0adbca 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -275,6 +275,9 @@ DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$
DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'")
DIAGNOSTIC(30053, Error, breakLabelNotFound, "label '$0' used as break target is not found.")
DIAGNOSTIC(30054, Error, targetLabelDoesNotMarkBreakableStmt, "invalid break target: statement labeled '$0' is not breakable.")
+DIAGNOSTIC(30055, Error, useOfNonShortCircuitingOperatorInDiffFunc, "non-short-circuiting `?:` operator is not allowed in a differentiable function, use `select` instead.")
+DIAGNOSTIC(30056, Warning, useOfNonShortCircuitingOperator, "non-short-circuiting `?:` operator is deprecated, use 'select' instead.")
+
DIAGNOSTIC(30043, Error, getStringHashRequiresStringLiteral, "getStringHash parameter can only accept a string literal")
DIAGNOSTIC(30060, Error, expectedAType, "expected a type, got a '$0'")
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 9c27beb58..7144b3450 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4003,6 +4003,43 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return e;
}
+ LoweredValInfo visitSelectExpr(SelectExpr* expr)
+ {
+ // A vector typed `select` expr will turn into a normal `select` op.
+ if (!as<BasicExpressionType>(expr->arguments[0]->type.type))
+ {
+ return visitInvokeExpr(expr);
+ }
+
+ // In global scope? This is a constant, and we should emit as `select` inst.
+ if (!getParentFunc(context->irBuilder->getInsertLoc().getInst()))
+ {
+ return visitInvokeExpr(expr);
+ }
+
+ // A scalar typed `select` expr will turn into an if-else to implement short circuiting
+ // semantics.
+ auto builder = context->irBuilder;
+ auto thenBlock = builder->createBlock();
+ auto elseBlock = builder->createBlock();
+ auto afterBlock = builder->createBlock();
+ auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0]));
+ builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock);
+ builder->insertBlock(thenBlock);
+ builder->setInsertInto(thenBlock);
+ auto trueVal = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1]));
+ builder->emitBranch(afterBlock, 1, &trueVal);
+ builder->insertBlock(elseBlock);
+ builder->setInsertInto(elseBlock);
+ auto falseVal = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[2]));
+ builder->emitBranch(afterBlock, 1, &falseVal);
+ builder->insertBlock(afterBlock);
+ builder->setInsertInto(afterBlock);
+ auto paramType = lowerType(context, expr->type.type);
+ auto result = builder->emitParam(paramType);
+ return LoweredValInfo::simple(result);
+ }
+
LoweredValInfo visitInvokeExpr(InvokeExpr* expr)
{
return visitInvokeExprImpl(expr, TryClauseEnvironment());