diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-07-11 10:12:17 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-11 17:12:17 +0000 |
| commit | 57567778b7d91afe7e6325c731d54b313b8b16e9 (patch) | |
| tree | 7fbd665e6ed71680ab68b21571e84f40bab16fb5 | |
| parent | b20b9297ed20f85dec6212cad83eeacaecbaccf3 (diff) | |
Fix IEEE 754 NaN comparisons in constant folding (#7721)
* Fix IEEE 754 NaN comparisons in constant folding
Added proper NaN handling in SCCP optimization pass to follow IEEE 754 standard:
- NaN \!= any value returns true
- All other NaN comparisons return false
- Added double precision NaN detection support
- Fixed type detection to check operands instead of result type
* Avoid differentiating NaN and non-NaN cases
* format code (#76)
---------
Co-authored-by: slangbot <ellieh+slangbot@nvidia.com>
| -rw-r--r-- | source/core/slang-math.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-sccp.cpp | 44 | ||||
| -rw-r--r-- | tests/compute/ieee754-mixed-type-nan-comparisons.slang | 79 | ||||
| -rw-r--r-- | tests/compute/ieee754-nan-comparisons.slang | 150 |
4 files changed, 254 insertions, 21 deletions
diff --git a/source/core/slang-math.h b/source/core/slang-math.h index e977dc37d..c7e77fce5 100644 --- a/source/core/slang-math.h +++ b/source/core/slang-math.h @@ -106,8 +106,10 @@ public: } static inline int IsNaN(float x) { return std::isnan(x); } + static inline int IsNaN(double x) { return std::isnan(x); } static inline int IsInf(float x) { return std::isinf(x); } + static inline int IsInf(double x) { return std::isinf(x); } static inline unsigned int Ones32(unsigned int x) { diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index 62584040a..f1b647045 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -511,7 +511,7 @@ struct SCCPContext template<typename TIntFunc, typename TFloatFunc> LatticeVal evalComparisonImpl( - IRType* type, + IRType*, LatticeVal v0, LatticeVal v1, const TIntFunc& intFunc, @@ -522,29 +522,31 @@ struct SCCPContext SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1) auto c1 = as<IRConstant>(v1.value); IRInst* resultVal = nullptr; - switch (type->getOp()) + + // Check the operand types, not the result type (which is always bool for comparisons) + // For mixed-type comparisons, use floating-point path if either operand is floating-point + IRType* operandType0 = c0->getDataType(); + IRType* operandType1 = c1->getDataType(); + + // Helper function to check if a type is floating-point + auto isFloatingPointType = [](IROp op) -> bool + { return op == kIROp_FloatType || op == kIROp_DoubleType || op == kIROp_HalfType; }; + + IROp op0 = operandType0->getOp(); + IROp op1 = operandType1->getOp(); + + // Use floating-point path if either operand is floating-point + // Otherwise use integer path + if (isFloatingPointType(op0) || isFloatingPointType(op1)) { - case kIROp_Int8Type: - case kIROp_Int16Type: - case kIROp_IntType: - case kIROp_Int64Type: - case kIROp_IntPtrType: - case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_UIntType: - case kIROp_UInt64Type: - case kIROp_UIntPtrType: - case kIROp_BoolType: - resultVal = getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal)); - break; - case kIROp_FloatType: - case kIROp_DoubleType: - case kIROp_HalfType: + // Floating-point path - C++ operators follow IEEE 754 automatically resultVal = getBuilder()->getBoolValue(floatFunc(c0->value.floatVal, c1->value.floatVal)); - break; - default: - break; + } + else + { + // Integer path - all integer types + resultVal = getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal)); } if (!resultVal) return LatticeVal::getAny(); diff --git a/tests/compute/ieee754-mixed-type-nan-comparisons.slang b/tests/compute/ieee754-mixed-type-nan-comparisons.slang new file mode 100644 index 000000000..31ce6b05b --- /dev/null +++ b/tests/compute/ieee754-mixed-type-nan-comparisons.slang @@ -0,0 +1,79 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type + +// Test IEEE 754 NaN comparison behavior with mixed int/float types +// Tests the type promotion logic across integer and floating-point categories +// Also tests both operand orders since implementation bugs could affect operand handling + +static const float fNAN = 0.0f / 0.0f; +static const float fONE = 1.0f; +static const int iONE = 1; +static const int iZERO = 0; +static const uint uONE = 1u; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint testIndex = 0; + + // Test int compared with float NaN - all should follow IEEE 754 + // CHECK: 0 + // CHECK: 1 + // CHECK: 0 + // CHECK: 0 + // CHECK: 0 + // CHECK: 0 + outputBuffer[testIndex++] = (iONE == fNAN) ? 1u : 0u; // int 1 == float NaN -> false + outputBuffer[testIndex++] = (iONE != fNAN) ? 1u : 0u; // int 1 != float NaN -> true + outputBuffer[testIndex++] = (iONE < fNAN) ? 1u : 0u; // int 1 < float NaN -> false + outputBuffer[testIndex++] = (iONE > fNAN) ? 1u : 0u; // int 1 > float NaN -> false + outputBuffer[testIndex++] = (iONE <= fNAN) ? 1u : 0u; // int 1 <= float NaN -> false + outputBuffer[testIndex++] = (iONE >= fNAN) ? 1u : 0u; // int 1 >= float NaN -> false + + // Test float NaN compared with int - same results but different operand order + // CHECK: 0 + // CHECK: 1 + // CHECK: 0 + // CHECK: 0 + // CHECK: 0 + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN == iONE) ? 1u : 0u; // float NaN == int 1 -> false + outputBuffer[testIndex++] = (fNAN != iONE) ? 1u : 0u; // float NaN != int 1 -> true + outputBuffer[testIndex++] = (fNAN < iONE) ? 1u : 0u; // float NaN < int 1 -> false + outputBuffer[testIndex++] = (fNAN > iONE) ? 1u : 0u; // float NaN > int 1 -> false + outputBuffer[testIndex++] = (fNAN <= iONE) ? 1u : 0u; // float NaN <= int 1 -> false + outputBuffer[testIndex++] = (fNAN >= iONE) ? 1u : 0u; // float NaN >= int 1 -> false + + // Test with different int values to ensure consistent behavior + // CHECK: 0 + // CHECK: 1 + // CHECK: 0 + // CHECK: 1 + outputBuffer[testIndex++] = (iZERO == fNAN) ? 1u : 0u; // int 0 == float NaN -> false + outputBuffer[testIndex++] = (iZERO != fNAN) ? 1u : 0u; // int 0 != float NaN -> true + outputBuffer[testIndex++] = (fNAN == iZERO) ? 1u : 0u; // float NaN == int 0 -> false + outputBuffer[testIndex++] = (fNAN != iZERO) ? 1u : 0u; // float NaN != int 0 -> true + + // Test unsigned int with float NaN + // CHECK: 0 + // CHECK: 1 + // CHECK: 0 + // CHECK: 1 + outputBuffer[testIndex++] = (uONE == fNAN) ? 1u : 0u; // uint 1 == float NaN -> false + outputBuffer[testIndex++] = (uONE != fNAN) ? 1u : 0u; // uint 1 != float NaN -> true + outputBuffer[testIndex++] = (fNAN == uONE) ? 1u : 0u; // float NaN == uint 1 -> false + outputBuffer[testIndex++] = (fNAN != uONE) ? 1u : 0u; // float NaN != uint 1 -> true + + // Test normal int vs float comparisons (no NaN) to ensure type promotion works + // CHECK: 1 + // CHECK: 1 + // CHECK: 1 + // CHECK: 1 + outputBuffer[testIndex++] = (iONE == fONE) ? 1u : 0u; // int 1 == float 1.0 -> true + outputBuffer[testIndex++] = (fONE == iONE) ? 1u : 0u; // float 1.0 == int 1 -> true + outputBuffer[testIndex++] = (iZERO < fONE) ? 1u : 0u; // int 0 < float 1.0 -> true + outputBuffer[testIndex++] = (fONE > iZERO) ? 1u : 0u; // float 1.0 > int 0 -> true +} + diff --git a/tests/compute/ieee754-nan-comparisons.slang b/tests/compute/ieee754-nan-comparisons.slang new file mode 100644 index 000000000..70a252cb1 --- /dev/null +++ b/tests/compute/ieee754-nan-comparisons.slang @@ -0,0 +1,150 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type -compute + +// Test IEEE 754 NaN comparison behavior +// According to IEEE 754 standard: +// - Any comparison with NaN (except !=) should return false +// - The != comparison with NaN should return true + +static const float fNAN = 0.0f / 0.0f; +static const float fPOSITIVE_INFINITY = 1.0f / 0.0f; +static const float fNEGATIVE_INFINITY = -1.0f / 0.0f; +static const float fZERO = 0.0f; +static const float fONE = 1.0f; +static const float fNEG_ONE = -1.0f; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint testIndex = 0; + + // Test 1: NaN == NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN == fNAN) ? 1u : 0u; + + // Test 2: NaN != NaN should be true + // CHECK: 1 + outputBuffer[testIndex++] = (fNAN != fNAN) ? 1u : 0u; + + // Test 3: NaN > NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN > fNAN) ? 1u : 0u; + + // Test 4: NaN < NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN < fNAN) ? 1u : 0u; + + // Test 5: NaN >= NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN >= fNAN) ? 1u : 0u; + + // Test 6: NaN <= NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN <= fNAN) ? 1u : 0u; + + // Test 7: NaN == 0.0 should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN == fZERO) ? 1u : 0u; + + // Test 8: NaN != 0.0 should be true + // CHECK: 1 + outputBuffer[testIndex++] = (fNAN != fZERO) ? 1u : 0u; + + // Test 9: NaN > 0.0 should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN > fZERO) ? 1u : 0u; + + // Test 10: NaN < 0.0 should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN < fZERO) ? 1u : 0u; + + // Test 11: NaN >= 0.0 should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN >= fZERO) ? 1u : 0u; + + // Test 12: NaN <= 0.0 should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN <= fZERO) ? 1u : 0u; + + // Test 13: 0.0 == NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fZERO == fNAN) ? 1u : 0u; + + // Test 14: 0.0 != NaN should be true + // CHECK: 1 + outputBuffer[testIndex++] = (fZERO != fNAN) ? 1u : 0u; + + // Test 15: 0.0 > NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fZERO > fNAN) ? 1u : 0u; + + // Test 16: 0.0 < NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fZERO < fNAN) ? 1u : 0u; + + // Test 17: 0.0 >= NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fZERO >= fNAN) ? 1u : 0u; + + // Test 18: 0.0 <= NaN should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fZERO <= fNAN) ? 1u : 0u; + + // Test 19: NaN == +infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN == fPOSITIVE_INFINITY) ? 1u : 0u; + + // Test 20: NaN != +infinity should be true + // CHECK: 1 + outputBuffer[testIndex++] = (fNAN != fPOSITIVE_INFINITY) ? 1u : 0u; + + // Test 21: NaN > +infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN > fPOSITIVE_INFINITY) ? 1u : 0u; + + // Test 22: NaN < +infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN < fPOSITIVE_INFINITY) ? 1u : 0u; + + // Test 23: NaN >= +infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN >= fPOSITIVE_INFINITY) ? 1u : 0u; + + // Test 24: NaN <= +infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN <= fPOSITIVE_INFINITY) ? 1u : 0u; + + // Test 25: NaN == -infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN == fNEGATIVE_INFINITY) ? 1u : 0u; + + // Test 26: NaN != -infinity should be true + // CHECK: 1 + outputBuffer[testIndex++] = (fNAN != fNEGATIVE_INFINITY) ? 1u : 0u; + + // Test 27: NaN > -infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN > fNEGATIVE_INFINITY) ? 1u : 0u; + + // Test 28: NaN < -infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN < fNEGATIVE_INFINITY) ? 1u : 0u; + + // Test 29: NaN >= -infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN >= fNEGATIVE_INFINITY) ? 1u : 0u; + + // Test 30: NaN <= -infinity should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN <= fNEGATIVE_INFINITY) ? 1u : 0u; + + // Test 31: NaN == 1.0 should be false + // CHECK: 0 + outputBuffer[testIndex++] = (fNAN == fONE) ? 1u : 0u; + + // Test 32: NaN != 1.0 should be true + // CHECK: 1 + outputBuffer[testIndex++] = (fNAN != fONE) ? 1u : 0u; +} |
