diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-24 20:17:46 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-24 20:17:46 -0700 |
| commit | 666af0962b6ab41489a3a3287db83f77c2f6461a (patch) | |
| tree | 81a1247188ac03f1e8132e58ec31ae0f28c8c530 | |
| parent | 7292edbd3eba3da7e8490ad19169a7d18283057a (diff) | |
Switch to short circuiting semantics for scalar `?:` operator. (#2733)
| -rw-r--r-- | source/slang/core.meta.slang | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 37 | ||||
| -rw-r--r-- | tests/autodiff/select.slang | 4 | ||||
| -rw-r--r-- | tests/cross-compile/vector-comparison.slang | 2 | ||||
| -rw-r--r-- | tests/cross-compile/vector-comparison.slang.glsl | 7 |
8 files changed, 92 insertions, 10 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 6581cc605..82a60a612 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -315,13 +315,18 @@ U operator,(T left, U right) return right; } -// The ternary `?:` operator does not short-circuit in HLSL, and Slang continues to -// follow that definition, so that this operator is effectively just an ordinary -// function, rather than a special-case piece of syntax. -// +// The ternary `?:` operator does not short-circuit in HLSL, and Slang no longer +// follow that definition for the scalar condition overload, so this declaration just serves +// for type-checking purpose only. + __generic<T> __intrinsic_op(select) T operator?:(bool condition, T ifTrue, T ifFalse); __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse); +// Users are advised to use `select` instead if non-short-circuiting behavior is intended. +__generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse); +__generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse); + + ${{{{ // We are going to use code generation to produce the // declarations for all of our base types. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index bfad1dbfe..cfcb15269 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2022,6 +2022,40 @@ namespace Slang return rs; } + Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr) + { + auto result = visitInvokeExpr(expr); + if (as<ErrorType>(result->type.type)) + return result; + auto invokeExpr = as<InvokeExpr>(result); + if (!result) + return result; + if (invokeExpr->arguments.getCount() != 3) + return result; + + if (as<BasicExpressionType>(invokeExpr->arguments[0]->type.type)) + { + auto newArgs = invokeExpr->arguments; + expr->arguments.clear(); + expr->arguments = newArgs; + expr->type = invokeExpr->type; + return expr; + } + + if (getParentDifferentiableAttribute()) + { + // If we are in a differentiable func, issue + // a diagnostic on use of non short-circuiting select. + getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperatorInDiffFunc); + } + else + { + // For all other functions, we issue a warning for deprecation of vector-typed ?: operator. + getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperator); + } + return result; + } + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr) { // check the base expression first diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 4181ca43b..50992b00b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1896,6 +1896,8 @@ namespace Slang Expr* visitInvokeExpr(InvokeExpr *expr); + Expr* visitSelectExpr(SelectExpr* expr); + Expr* visitVarExpr(VarExpr *expr); Expr* visitTypeCastExpr(TypeCastExpr * expr); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e3e9cfc44..c3e0adbca 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -275,6 +275,9 @@ DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$ DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") DIAGNOSTIC(30053, Error, breakLabelNotFound, "label '$0' used as break target is not found.") DIAGNOSTIC(30054, Error, targetLabelDoesNotMarkBreakableStmt, "invalid break target: statement labeled '$0' is not breakable.") +DIAGNOSTIC(30055, Error, useOfNonShortCircuitingOperatorInDiffFunc, "non-short-circuiting `?:` operator is not allowed in a differentiable function, use `select` instead.") +DIAGNOSTIC(30056, Warning, useOfNonShortCircuitingOperator, "non-short-circuiting `?:` operator is deprecated, use 'select' instead.") + DIAGNOSTIC(30043, Error, getStringHashRequiresStringLiteral, "getStringHash parameter can only accept a string literal") DIAGNOSTIC(30060, Error, expectedAType, "expected a type, got a '$0'") diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 9c27beb58..7144b3450 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4003,6 +4003,43 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return e; } + LoweredValInfo visitSelectExpr(SelectExpr* expr) + { + // A vector typed `select` expr will turn into a normal `select` op. + if (!as<BasicExpressionType>(expr->arguments[0]->type.type)) + { + return visitInvokeExpr(expr); + } + + // In global scope? This is a constant, and we should emit as `select` inst. + if (!getParentFunc(context->irBuilder->getInsertLoc().getInst())) + { + return visitInvokeExpr(expr); + } + + // A scalar typed `select` expr will turn into an if-else to implement short circuiting + // semantics. + auto builder = context->irBuilder; + auto thenBlock = builder->createBlock(); + auto elseBlock = builder->createBlock(); + auto afterBlock = builder->createBlock(); + auto irCond = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[0])); + builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); + builder->insertBlock(thenBlock); + builder->setInsertInto(thenBlock); + auto trueVal = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[1])); + builder->emitBranch(afterBlock, 1, &trueVal); + builder->insertBlock(elseBlock); + builder->setInsertInto(elseBlock); + auto falseVal = getSimpleVal(context, lowerRValueExpr(context, expr->arguments[2])); + builder->emitBranch(afterBlock, 1, &falseVal); + builder->insertBlock(afterBlock); + builder->setInsertInto(afterBlock); + auto paramType = lowerType(context, expr->type.type); + auto result = builder->emitParam(paramType); + return LoweredValInfo::simple(result); + } + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { return visitInvokeExprImpl(expr, TryClauseEnvironment()); diff --git a/tests/autodiff/select.slang b/tests/autodiff/select.slang index 261c170db..20abab977 100644 --- a/tests/autodiff/select.slang +++ b/tests/autodiff/select.slang @@ -10,7 +10,7 @@ typedef float.Differential dfloat; [BackwardDifferentiable] float f(float x, float y) { - return x > 0.5 ? x : y; + return x > 0.0 ? sqrt(x)*sqrt(x) : y; } [numthreads(1, 1, 1)] @@ -27,7 +27,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) } { - dpfloat dpa = dpfloat(0.3, 1.0); + dpfloat dpa = dpfloat(-0.3, 1.0); dpfloat dpb = dpfloat(0.3, 1.0); __bwd_diff(f)(dpa, dpb, 1.0); diff --git a/tests/cross-compile/vector-comparison.slang b/tests/cross-compile/vector-comparison.slang index d1fdcfd4a..f363eb40c 100644 --- a/tests/cross-compile/vector-comparison.slang +++ b/tests/cross-compile/vector-comparison.slang @@ -1,6 +1,6 @@ // vector-comparison.slang -//TEST:CROSS_COMPILE:-target spirv-assembly -entry main -stage fragment +//TEST:CROSS_COMPILE:-target spirv-assembly -entry main -stage fragment -Wno-use-of-non-short-circuiting-operator // This test ensures that we cross-compile vector comparison operators // correctly to GLSL diff --git a/tests/cross-compile/vector-comparison.slang.glsl b/tests/cross-compile/vector-comparison.slang.glsl index 2497055a0..3e6f7b9c2 100644 --- a/tests/cross-compile/vector-comparison.slang.glsl +++ b/tests/cross-compile/vector-comparison.slang.glsl @@ -1,8 +1,6 @@ -//TEST_IGNORE_FILE #version 450 layout(row_major) uniform; layout(row_major) buffer; - struct Param_0 { vec4 a_0; @@ -19,6 +17,9 @@ out vec4 _S2; void main() { - _S2 = mix(vec4(3.0), vec4(2.0), (equal(params_0._data.a_0,params_0._data.b_0))) + mix(vec4(3.0), vec4(2.0), (lessThan(params_0._data.a_0,params_0._data.b_0))) + mix(vec4(3.0), vec4(2.0), (greaterThan(params_0._data.a_0,params_0._data.b_0))) + mix(vec4(3.0), vec4(2.0), (lessThanEqual(params_0._data.a_0,params_0._data.b_0))) + mix(vec4(3.0), vec4(2.0), (greaterThanEqual(params_0._data.a_0,params_0._data.b_0))) + mix(vec4(3.0), vec4(2.0), (notEqual(params_0._data.a_0,params_0._data.b_0))); + + const vec4 _S3 = vec4(2.0); + const vec4 _S4 = vec4(3.0); + _S2 = mix(_S4, _S3, (equal(params_0._data.a_0,params_0._data.b_0))) + mix(_S4, _S3, (lessThan(params_0._data.a_0,params_0._data.b_0))) + mix(_S4, _S3, (greaterThan(params_0._data.a_0,params_0._data.b_0))) + mix(_S4, _S3, (lessThanEqual(params_0._data.a_0,params_0._data.b_0))) + mix(_S4, _S3, (greaterThanEqual(params_0._data.a_0,params_0._data.b_0))) + mix(_S4, _S3, (notEqual(params_0._data.a_0,params_0._data.b_0))); return; } |
