diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2022-11-09 09:15:15 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-09 09:15:15 -0500 |
| commit | e743ddd49045284b706cc2cbbb615acc6fe3d882 (patch) | |
| tree | 71aa7e68dc38410abb86defabe7550e940882a7c /prelude | |
| parent | bf67309454032b4f92d0bc9735b608e56b16882f (diff) | |
f32tof16 and f16tof32 support for CPU targets (#2500)
* #include an absolute path didn't work - because paths were taken to always be relative.
* Float16 support for C++/CPU based targets with f16tof32 and f32tof16.
* Small correction around INF/NAN handling for f32tof16
* Small improvement to f16tof32
* Disable CUDA test for now.
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cpp-scalar-intrinsics.h | 96 |
1 files changed, 95 insertions, 1 deletions
diff --git a/prelude/slang-cpp-scalar-intrinsics.h b/prelude/slang-cpp-scalar-intrinsics.h index 60f1dd278..66035260d 100644 --- a/prelude/slang-cpp-scalar-intrinsics.h +++ b/prelude/slang-cpp-scalar-intrinsics.h @@ -18,7 +18,6 @@ namespace SLANG_PRELUDE_NAMESPACE { # define SLANG_PRELUDE_PI 3.14159265358979323846 #endif -// ----------------------------- F32 ----------------------------------------- union Union32 { @@ -34,6 +33,101 @@ union Union64 double d; }; +// 32 bit cast conversions +SLANG_FORCE_INLINE int32_t _bitCastFloatToInt(float f) { Union32 u; u.f = f; return u.i; } +SLANG_FORCE_INLINE float _bitCastIntToFloat(int32_t i) { Union32 u; u.i = i; return u.f; } +SLANG_FORCE_INLINE uint32_t _bitCastFloatToUInt(float f) { Union32 u; u.f = f; return u.u; } +SLANG_FORCE_INLINE float _bitCastUIntToFloat(uint32_t ui) { Union32 u; u.u = ui; return u.f; } + +// ----------------------------- F16 ----------------------------------------- + + +// This impl is based on FloatToHalf that is in Slang codebase +uint32_t f32tof16(const float value) +{ + const uint32_t inBits = _bitCastFloatToUInt(value); + + // bits initially set to just the sign bit + uint32_t bits = (inBits >> 16) & 0x8000; + // Mantissa can't be used as is, as it holds last bit, for rounding. + uint32_t m = (inBits >> 12) & 0x07ff; + uint32_t e = (inBits >> 23) & 0xff; + + if (e < 103) + { + // It's zero + return bits; + } + if (e == 0xff) + { + // Could be a NAN or INF. Is INF if *input* mantissa is 0. + + // Remove last bit for rounding to make output mantissa. + m >>= 1; + + // We *assume* float16/float32 signaling bit and remaining bits + // semantics are the same. (The signalling bit convention is target specific!). + // Non signal bit's usage within mantissa for a NAN are also target specific. + + // If the m is 0, it could be because the result is INF, but it could also be because all the + // bits that made NAN were dropped as we have less mantissa bits in f16. + + // To fix for this we make non zero if m is 0 and the input mantissa was not. + // This will (typically) produce a signalling NAN. + m += uint32_t(m == 0 && (inBits & 0x007fffffu)); + + // Combine for output + return (bits | 0x7c00u | m); + } + if (e > 142) + { + // INF. + return bits | 0x7c00u; + } + if (e < 113) + { + m |= 0x0800u; + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + bits |= ((e - 112) << 10) | (m >> 1); + bits += m & 1; + return bits; +} + +static const float g_f16tof32Magic = _bitCastIntToFloat((127 + (127 - 15)) << 23); + +float f16tof32(const uint32_t value) +{ + const uint32_t sign = (value & 0x8000) << 16; + uint32_t exponent = (value & 0x7c00) >> 10; + uint32_t mantissa = (value & 0x03ff); + + if (exponent == 0) + { + // If mantissa is 0 we are done, as output is 0. + // If it's not zero we must have a denormal. + if (mantissa) + { + // We have a denormal so use the magic to do exponent adjust + return _bitCastIntToFloat(sign | ((value & 0x7fff) << 13)) * g_f16tof32Magic; + } + } + else + { + // If the exponent is NAN or INF exponent is 0x1f on input. + // If that's the case, we just need to set the exponent to 0xff on output + // and the mantissa can just stay the same. If its 0 it's INF, else it is NAN and we just copy the bits + // + // Else we need to correct the exponent in the normalized case. + exponent = (exponent == 0x1F) ? 0xff : (exponent + (-15 + 127)); + } + + return _bitCastUIntToFloat(sign | (exponent << 23) | (mantissa << 13)); +} + +// ----------------------------- F32 ----------------------------------------- + // Helpers SLANG_FORCE_INLINE float F32_calcSafeRadians(float radians); |
