summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
author: cheneym2 <acheney@nvidia.com> 2024-12-03 15:24:18 -0500
committer: GitHub <noreply@github.com> 2024-12-03 20:24:18 +0000
commit: 600cce28606ba36b31756bf0422d892d0e242b63 (patch)
tree: efcdddb69ec01f9983c328804325f4d496e8e23e /source
parent: 5d8cf475b352ab517c565ccee59461640da63a2a (diff)
Core lib Metal math function fixes (#5738)
To emit fast target implementations of several Metal-derived functions (fmin(), fmax(), fmin3(), fmax3(), fmedian3()) on all targets, remove part of the specification regarding the handling of NaNs, and also remove the code that enforced that specification. These functions are now documented as having undefined results in the presence of NaN input, so that the common "is a number" case is fast. Also, clarify that powr() is undefined when given a non-positive base input value, which allows us to remove an additional abs() operation that was unnecessarily coercing results to be predictable on non-Metal targets. Closes #5580 Closes #5581 Closes #5587
Diffstat (limited to 'source')
-rw-r--r-- source/slang/hlsl.meta.slang 88
1 files changed, 18 insertions, 70 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 84beb3b4b..7610dd395 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -10275,12 +10275,11 @@ vector<T,N> max3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
}
}
-/// Floating-point maximum considering NaN.
+/// Floating-point maximum.
/// @param x The first value to compare.
/// @param y The second value to compare.
-/// @return The larger of the two values, element-wise if vector typed, considering NaN.
-/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned.
-/// For other targets, if `x` is NaN, `y` is returned, otherwise the larger of `x` and `y` is returned.
+/// @return The larger of the two values, element-wise if vector typed.
+/// @remarks Result is `y` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `x`. Which operand is the result is undefined if one of the operands is a NaN.
/// @category math
__generic<T : __BuiltinFloatingPointType>
[__readNone]
@@ -10291,7 +10290,6 @@ T fmax(T x, T y)
{
case metal: __intrinsic_asm "fmax";
default:
- if (isnan(x)) return y;
return max(x, y);
}
}
@@ -10309,11 +10307,12 @@ vector<T,N> fmax(vector<T,N> x, vector<T,N> y)
}
}
-/// Floating-point maximum of 3 inputs, considering NaN.
+/// Floating-point maximum of 3 inputs.
/// @param x The first value to compare.
/// @param y The second value to compare.
/// @param z The third value to compare.
-/// @return The largest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the largest is returned.
+/// @return The largest of the three values, element-wise if vector typed.
+/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned.
/// @category math
__generic<T : __BuiltinFloatingPointType>
[__readNone]
@@ -10325,25 +10324,6 @@ T fmax3(T x, T y, T z)
case metal: __intrinsic_asm "fmax3";
default:
{
- bool isnanX = isnan(x);
- bool isnanY = isnan(y);
- bool isnanZ = isnan(z);
-
- if (isnanX)
- {
- return isnanY ? z : y;
- }
- else if (isnanY)
- {
- if (isnanZ)
- return x;
- return max(x, z);
- }
- else if (isnanZ)
- {
- return max(x, y);
- }
-
return max(y, max(x, z));
}
}
@@ -10522,12 +10502,11 @@ vector<T,N> min3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
}
}
-/// Floating-point minimum considering NaN.
+/// Floating-point minimum.
/// @param x The first value to compare.
/// @param y The second value to compare.
-/// @return The smaller of the two values, element-wise if vector typed, considering NaN.
-/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned.
-/// For other targets, if `x` is NaN, `y` is returned, otherwise the smaller of `x` and `y` is returned.
+/// @return The smaller of the two values, element-wise if vector typed.
+/// @remarks Result is `x` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `y`. Which operand is the result is undefined if one of the operands is a NaN.
/// @category math
__generic<T : __BuiltinFloatingPointType>
[__readNone]
@@ -10538,7 +10517,6 @@ T fmin(T x, T y)
{
case metal: __intrinsic_asm "fmin";
default:
- if (isnan(x)) return y;
return min(x, y);
}
}
@@ -10556,11 +10534,12 @@ vector<T,N> fmin(vector<T,N> x, vector<T,N> y)
}
}
-/// Floating-point minimum of 3 inputs, considering NaN.
+/// Floating-point minimum of 3 inputs.
/// @param x The first value to compare.
/// @param y The second value to compare.
/// @param z The third value to compare.
-/// @return The smallest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the smallest non-NaN value is returned.
+/// @return The smallest of the three values, element-wise if vector typed.
+/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned.
/// @category math
__generic<T : __BuiltinFloatingPointType>
[__readNone]
@@ -10572,25 +10551,6 @@ T fmin3(T x, T y, T z)
case metal: __intrinsic_asm "fmin3";
default:
{
- bool isnanX = isnan(x);
- bool isnanY = isnan(y);
- bool isnanZ = isnan(z);
-
- if (isnan(x))
- {
- return isnanY ? z : y;
- }
- else if (isnanY)
- {
- if (isnanZ)
- return x;
- return min(x, z);
- }
- else if (isnanZ)
- {
- return min(x, y);
- }
-
return min(x, min(y, z));
}
}
@@ -10664,12 +10624,13 @@ vector<T,N> median3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
}
}
-/// Floating-point median considering NaN.
+/// Floating-point median.
/// @param x The first value to compare.
/// @param y The second value to compare.
/// @param z The third value to compare.
-/// @return The median of the three values, element-wise if vector typed, considering NaN. If no value is NaN, the median is returned. If any value is NaN, one of the non-NaN values is returned.
+/// @return The median of the three values, element-wise if vector typed.
/// @remarks For metal, this is implemented with the fmedian3 intrinsic.
+/// If any value is NaN, it is unspecified which operand is returned.
/// @category math
__generic<T : __BuiltinFloatingPointType>
[__readNone]
@@ -10681,20 +10642,6 @@ T fmedian3(T x, T y, T z)
case metal: __intrinsic_asm "fmedian3";
default:
{
- bool isnanX = isnan(x);
- bool isnanY = isnan(y);
- bool isnanZ = isnan(z);
-
- if (isnanX)
- {
- return isnanY ? z : y;
- }
- else if (isnanY || isnanZ)
- {
- // "the function can return either non-NaN value"
- return x;
- }
-
return median3(x, y, z);
}
}
@@ -11350,6 +11297,7 @@ matrix<T,N,M> pow(matrix<T,N,M> x, matrix<T,N,M> y)
/// @param y The exponent value.
/// @return The value of `x` raised to the power of `y`.
/// @category math
+/// @remarks Return value is undefined for non-positive values of `x`.
__generic<T : __BuiltinFloatingPointType>
[__readNone]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
@@ -11359,7 +11307,7 @@ T powr(T x, T y)
{
case metal: __intrinsic_asm "powr";
default:
- return pow(abs(x), y);
+ return pow(x, y);
}
}
@@ -11372,7 +11320,7 @@ vector<T, N> powr(vector<T, N> x, vector<T, N> y)
{
case metal: __intrinsic_asm "powr";
default:
- return pow(abs(x), y);
+ return pow(x, y);
}
}