From 600cce28606ba36b31756bf0422d892d0e242b63 Mon Sep 17 00:00:00 2001 From: cheneym2 Date: Tue, 3 Dec 2024 15:24:18 -0500 Subject: Core lib Metal math function fixes (#5738) In order to emit fast target implementations of some Metal-based functions (fmin(), fmax(), fmin3(), fmax3(), fmedian3()) on all targets, remove some specification regarding the handling of NaNs, and also remove the enforcement of the specification. These functions are now documented as having an undefined result in the presence of NaN input, which keeps the common "is a number" case fast. Also, clarify that powr() is undefined when given a non-positive base input value, allowing us to remove an additional abs() operation that was unnecessarily coercing results to be predictable on non-Metal targets. Closes #5580 Closes #5581 Closes #5587 --- source/slang/hlsl.meta.slang | 88 +++++++++----------------------------------- 1 file changed, 18 insertions(+), 70 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 84beb3b4b..7610dd395 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -10275,12 +10275,11 @@ vector max3(vector x, vector y, vector z) } } -/// Floating-point maximum considering NaN. +/// Floating-point maximum. /// @param x The first value to compare. /// @param y The second value to compare. -/// @return The larger of the two values, element-wise if vector typed, considering NaN. -/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned. -/// For other targets, if `x` is NaN, `y` is returned, otherwise the larger of `x` and `y` is returned. +/// @return The larger of the two values, element-wise if vector typed. +/// @remarks Result is `y` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `x`. Which operand is the result is undefined if one of the operands is a NaN. 
/// @category math __generic [__readNone] @@ -10291,7 +10290,6 @@ T fmax(T x, T y) { case metal: __intrinsic_asm "fmax"; default: - if (isnan(x)) return y; return max(x, y); } } @@ -10309,11 +10307,12 @@ vector fmax(vector x, vector y) } } -/// Floating-point maximum of 3 inputs, considering NaN. +/// Floating-point maximum of 3 inputs. /// @param x The first value to compare. /// @param y The second value to compare. /// @param z The third value to compare. -/// @return The largest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the largest is returned. +/// @return The largest of the three values, element-wise if vector typed. +/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned. /// @category math __generic [__readNone] @@ -10325,25 +10324,6 @@ T fmax3(T x, T y, T z) case metal: __intrinsic_asm "fmax3"; default: { - bool isnanX = isnan(x); - bool isnanY = isnan(y); - bool isnanZ = isnan(z); - - if (isnanX) - { - return isnanY ? z : y; - } - else if (isnanY) - { - if (isnanZ) - return x; - return max(x, z); - } - else if (isnanZ) - { - return max(x, y); - } - return max(y, max(x, z)); } } @@ -10522,12 +10502,11 @@ vector min3(vector x, vector y, vector z) } } -/// Floating-point minimum considering NaN. +/// Floating-point minimum. /// @param x The first value to compare. /// @param y The second value to compare. -/// @return The smaller of the two values, element-wise if vector typed, considering NaN. -/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned. -/// For other targets, if `x` is NaN, `y` is returned, otherwise the smaller of `x` and `y` is returned. +/// @return The smaller of the two values, element-wise if vector typed. +/// @remarks Result is `x` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `y`. 
Which operand is the result is undefined if one of the operands is a NaN. /// @category math __generic [__readNone] @@ -10538,7 +10517,6 @@ T fmin(T x, T y) { case metal: __intrinsic_asm "fmin"; default: - if (isnan(x)) return y; return min(x, y); } } @@ -10556,11 +10534,12 @@ vector fmin(vector x, vector y) } } -/// Floating-point minimum of 3 inputs, considering NaN. +/// Floating-point minimum of 3 inputs. /// @param x The first value to compare. /// @param y The second value to compare. /// @param z The third value to compare. -/// @return The smallest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the smallest non-NaN value is returned. +/// @return The smallest of the three values, element-wise if vector typed. +/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned. /// @category math __generic [__readNone] @@ -10572,25 +10551,6 @@ T fmin3(T x, T y, T z) case metal: __intrinsic_asm "fmin3"; default: { - bool isnanX = isnan(x); - bool isnanY = isnan(y); - bool isnanZ = isnan(z); - - if (isnan(x)) - { - return isnanY ? z : y; - } - else if (isnanY) - { - if (isnanZ) - return x; - return min(x, z); - } - else if (isnanZ) - { - return min(x, y); - } - return min(x, min(y, z)); } } @@ -10664,12 +10624,13 @@ vector median3(vector x, vector y, vector z) } } -/// Floating-point median considering NaN. +/// Floating-point median. /// @param x The first value to compare. /// @param y The second value to compare. /// @param z The third value to compare. -/// @return The median of the three values, element-wise if vector typed, considering NaN. If no value is NaN, the median is returned. If any value is NaN, one of the non-NaN values is returned. +/// @return The median of the three values, element-wise if vector typed. /// @remarks For metal, this is implemented with the fmedian3 intrinsic. 
+/// If any value is NaN, it is unspecified which operand is returned. /// @category math __generic [__readNone] @@ -10681,20 +10642,6 @@ T fmedian3(T x, T y, T z) case metal: __intrinsic_asm "fmedian3"; default: { - bool isnanX = isnan(x); - bool isnanY = isnan(y); - bool isnanZ = isnan(z); - - if (isnanX) - { - return isnanY ? z : y; - } - else if (isnanY || isnanZ) - { - // "the function can return either non-NaN value" - return x; - } - return median3(x, y, z); } } @@ -11350,6 +11297,7 @@ matrix pow(matrix x, matrix y) /// @param y The exponent value. /// @return The value of `x` raised to the power of `y`. /// @category math +/// @remarks Return value is undefined for non-positive values of `x`. __generic [__readNone] [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] @@ -11359,7 +11307,7 @@ T powr(T x, T y) { case metal: __intrinsic_asm "powr"; default: - return pow(abs(x), y); + return pow(x, y); } } @@ -11372,7 +11320,7 @@ vector powr(vector x, vector y) { case metal: __intrinsic_asm "powr"; default: - return pow(abs(x), y); + return pow(x, y); } } -- cgit v1.2.3