diff options
| author | cheneym2 <acheney@nvidia.com> | 2024-12-03 15:24:18 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-03 20:24:18 +0000 |
| commit | 600cce28606ba36b31756bf0422d892d0e242b63 (patch) | |
| tree | efcdddb69ec01f9983c328804325f4d496e8e23e | |
| parent | 5d8cf475b352ab517c565ccee59461640da63a2a (diff) | |
Core lib Metal math function fixes (#5738)
In order to emit fast target implementations of some Metal-based
functions (fmin(), fmax(), fmin3(), fmax3(), fmedian3()) on all
targets, remove some specification regarding the handling of NaNs,
and also remove the enforcement of the specification.
The results of these functions are now documented as undefined
in the presence of NaN input, to make the common "is a number"
case fast.
Also, clarify that powr() is undefined when given a non-positive
base input value, allowing us to remove an additional abs()
operation that was unnecessarily coercing results to be predictable
on non-Metal targets.
Closes #5580
Closes #5581
Closes #5587
| -rw-r--r-- | source/slang/hlsl.meta.slang | 88 |
1 files changed, 18 insertions, 70 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 84beb3b4b..7610dd395 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -10275,12 +10275,11 @@ vector<T,N> max3(vector<T,N> x, vector<T,N> y, vector<T,N> z) } } -/// Floating-point maximum considering NaN. +/// Floating-point maximum. /// @param x The first value to compare. /// @param y The second value to compare. -/// @return The larger of the two values, element-wise if vector typed, considering NaN. -/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned. -/// For other targets, if `x` is NaN, `y` is returned, otherwise the larger of `x` and `y` is returned. +/// @return The larger of the two values, element-wise if vector typed. +/// @remarks Result is `y` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `x`. Which operand is the result is undefined if one of the operands is a NaN. /// @category math __generic<T : __BuiltinFloatingPointType> [__readNone] @@ -10291,7 +10290,6 @@ T fmax(T x, T y) { case metal: __intrinsic_asm "fmax"; default: - if (isnan(x)) return y; return max(x, y); } } @@ -10309,11 +10307,12 @@ vector<T,N> fmax(vector<T,N> x, vector<T,N> y) } } -/// Floating-point maximum of 3 inputs, considering NaN. +/// Floating-point maximum of 3 inputs. /// @param x The first value to compare. /// @param y The second value to compare. /// @param z The third value to compare. -/// @return The largest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the largest is returned. +/// @return The largest of the three values, element-wise if vector typed. +/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned. 
/// @category math __generic<T : __BuiltinFloatingPointType> [__readNone] @@ -10325,25 +10324,6 @@ T fmax3(T x, T y, T z) case metal: __intrinsic_asm "fmax3"; default: { - bool isnanX = isnan(x); - bool isnanY = isnan(y); - bool isnanZ = isnan(z); - - if (isnanX) - { - return isnanY ? z : y; - } - else if (isnanY) - { - if (isnanZ) - return x; - return max(x, z); - } - else if (isnanZ) - { - return max(x, y); - } - return max(y, max(x, z)); } } @@ -10522,12 +10502,11 @@ vector<T,N> min3(vector<T,N> x, vector<T,N> y, vector<T,N> z) } } -/// Floating-point minimum considering NaN. +/// Floating-point minimum. /// @param x The first value to compare. /// @param y The second value to compare. -/// @return The smaller of the two values, element-wise if vector typed, considering NaN. -/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned. -/// For other targets, if `x` is NaN, `y` is returned, otherwise the smaller of `x` and `y` is returned. +/// @return The smaller of the two values, element-wise if vector typed. +/// @remarks Result is `x` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `y`. Which operand is the result is undefined if one of the operands is a NaN. /// @category math __generic<T : __BuiltinFloatingPointType> [__readNone] @@ -10538,7 +10517,6 @@ T fmin(T x, T y) { case metal: __intrinsic_asm "fmin"; default: - if (isnan(x)) return y; return min(x, y); } } @@ -10556,11 +10534,12 @@ vector<T,N> fmin(vector<T,N> x, vector<T,N> y) } } -/// Floating-point minimum of 3 inputs, considering NaN. +/// Floating-point minimum of 3 inputs. /// @param x The first value to compare. /// @param y The second value to compare. /// @param z The third value to compare. -/// @return The smallest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the smallest non-NaN value is returned. 
+/// @return The smallest of the three values, element-wise if vector typed. +/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned. /// @category math __generic<T : __BuiltinFloatingPointType> [__readNone] @@ -10572,25 +10551,6 @@ T fmin3(T x, T y, T z) case metal: __intrinsic_asm "fmin3"; default: { - bool isnanX = isnan(x); - bool isnanY = isnan(y); - bool isnanZ = isnan(z); - - if (isnan(x)) - { - return isnanY ? z : y; - } - else if (isnanY) - { - if (isnanZ) - return x; - return min(x, z); - } - else if (isnanZ) - { - return min(x, y); - } - return min(x, min(y, z)); } } @@ -10664,12 +10624,13 @@ vector<T,N> median3(vector<T,N> x, vector<T,N> y, vector<T,N> z) } } -/// Floating-point median considering NaN. +/// Floating-point median. /// @param x The first value to compare. /// @param y The second value to compare. /// @param z The third value to compare. -/// @return The median of the three values, element-wise if vector typed, considering NaN. If no value is NaN, the median is returned. If any value is NaN, one of the non-NaN values is returned. +/// @return The median of the three values, element-wise if vector typed. /// @remarks For metal, this is implemented with the fmedian3 intrinsic. +/// If any value is NaN, it is unspecified which operand is returned. /// @category math __generic<T : __BuiltinFloatingPointType> [__readNone] @@ -10681,20 +10642,6 @@ T fmedian3(T x, T y, T z) case metal: __intrinsic_asm "fmedian3"; default: { - bool isnanX = isnan(x); - bool isnanY = isnan(y); - bool isnanZ = isnan(z); - - if (isnanX) - { - return isnanY ? z : y; - } - else if (isnanY || isnanZ) - { - // "the function can return either non-NaN value" - return x; - } - return median3(x, y, z); } } @@ -11350,6 +11297,7 @@ matrix<T,N,M> pow(matrix<T,N,M> x, matrix<T,N,M> y) /// @param y The exponent value. /// @return The value of `x` raised to the power of `y`. 
/// @category math +/// @remarks Return value is undefined for non-positive values of `x`. __generic<T : __BuiltinFloatingPointType> [__readNone] [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] @@ -11359,7 +11307,7 @@ T powr(T x, T y) { case metal: __intrinsic_asm "powr"; default: - return pow(abs(x), y); + return pow(x, y); } } @@ -11372,7 +11320,7 @@ vector<T, N> powr(vector<T, N> x, vector<T, N> y) { case metal: __intrinsic_asm "powr"; default: - return pow(abs(x), y); + return pow(x, y); } } |
