summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-07-11 10:12:17 -0700
committerGitHub <noreply@github.com>2025-07-11 17:12:17 +0000
commit57567778b7d91afe7e6325c731d54b313b8b16e9 (patch)
tree7fbd665e6ed71680ab68b21571e84f40bab16fb5
parentb20b9297ed20f85dec6212cad83eeacaecbaccf3 (diff)
Fix IEEE 754 NaN comparisons in constant folding (#7721)
* Fix IEEE 754 NaN comparisons in constant folding Added proper NaN handling in SCCP optimization pass to follow IEEE 754 standard: - NaN \!= any value returns true - All other NaN comparisons return false - Added double precision NaN detection support - Fixed type detection to check operands instead of result type * Avoid differentiating NaN and non-NaN cases * format code (#76) --------- Co-authored-by: slangbot <ellieh+slangbot@nvidia.com>
-rw-r--r--source/core/slang-math.h2
-rw-r--r--source/slang/slang-ir-sccp.cpp44
-rw-r--r--tests/compute/ieee754-mixed-type-nan-comparisons.slang79
-rw-r--r--tests/compute/ieee754-nan-comparisons.slang150
4 files changed, 254 insertions, 21 deletions
diff --git a/source/core/slang-math.h b/source/core/slang-math.h
index e977dc37d..c7e77fce5 100644
--- a/source/core/slang-math.h
+++ b/source/core/slang-math.h
@@ -106,8 +106,10 @@ public:
}
static inline int IsNaN(float x) { return std::isnan(x); }
+ static inline int IsNaN(double x) { return std::isnan(x); }
static inline int IsInf(float x) { return std::isinf(x); }
+ static inline int IsInf(double x) { return std::isinf(x); }
static inline unsigned int Ones32(unsigned int x)
{
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index 62584040a..f1b647045 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -511,7 +511,7 @@ struct SCCPContext
template<typename TIntFunc, typename TFloatFunc>
LatticeVal evalComparisonImpl(
- IRType* type,
+ IRType*,
LatticeVal v0,
LatticeVal v1,
const TIntFunc& intFunc,
@@ -522,29 +522,31 @@ struct SCCPContext
SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
auto c1 = as<IRConstant>(v1.value);
IRInst* resultVal = nullptr;
- switch (type->getOp())
+
+ // Check the operand types, not the result type (which is always bool for comparisons)
+ // For mixed-type comparisons, use floating-point path if either operand is floating-point
+ IRType* operandType0 = c0->getDataType();
+ IRType* operandType1 = c1->getDataType();
+
+ // Helper function to check if a type is floating-point
+ auto isFloatingPointType = [](IROp op) -> bool
+ { return op == kIROp_FloatType || op == kIROp_DoubleType || op == kIROp_HalfType; };
+
+ IROp op0 = operandType0->getOp();
+ IROp op1 = operandType1->getOp();
+
+ // Use floating-point path if either operand is floating-point
+ // Otherwise use integer path
+ if (isFloatingPointType(op0) || isFloatingPointType(op1))
{
- case kIROp_Int8Type:
- case kIROp_Int16Type:
- case kIROp_IntType:
- case kIROp_Int64Type:
- case kIROp_IntPtrType:
- case kIROp_UInt8Type:
- case kIROp_UInt16Type:
- case kIROp_UIntType:
- case kIROp_UInt64Type:
- case kIROp_UIntPtrType:
- case kIROp_BoolType:
- resultVal = getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal));
- break;
- case kIROp_FloatType:
- case kIROp_DoubleType:
- case kIROp_HalfType:
+ // Floating-point path - C++ operators follow IEEE 754 automatically
resultVal =
getBuilder()->getBoolValue(floatFunc(c0->value.floatVal, c1->value.floatVal));
- break;
- default:
- break;
+ }
+ else
+ {
+ // Integer path - all integer types
+ resultVal = getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal));
}
if (!resultVal)
return LatticeVal::getAny();
diff --git a/tests/compute/ieee754-mixed-type-nan-comparisons.slang b/tests/compute/ieee754-mixed-type-nan-comparisons.slang
new file mode 100644
index 000000000..31ce6b05b
--- /dev/null
+++ b/tests/compute/ieee754-mixed-type-nan-comparisons.slang
@@ -0,0 +1,79 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
+
+// Test IEEE 754 NaN comparison behavior with mixed int/float types
+// Tests the type promotion logic across integer and floating-point categories
+// Also tests both operand orders since implementation bugs could affect operand handling
+
+static const float fNAN = 0.0f / 0.0f;
+static const float fONE = 1.0f;
+static const int iONE = 1;
+static const int iZERO = 0;
+static const uint uONE = 1u;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint testIndex = 0;
+
+ // Test int compared with float NaN - all should follow IEEE 754
+ // CHECK: 0
+ // CHECK: 1
+ // CHECK: 0
+ // CHECK: 0
+ // CHECK: 0
+ // CHECK: 0
+ outputBuffer[testIndex++] = (iONE == fNAN) ? 1u : 0u; // int 1 == float NaN -> false
+ outputBuffer[testIndex++] = (iONE != fNAN) ? 1u : 0u; // int 1 != float NaN -> true
+ outputBuffer[testIndex++] = (iONE < fNAN) ? 1u : 0u; // int 1 < float NaN -> false
+ outputBuffer[testIndex++] = (iONE > fNAN) ? 1u : 0u; // int 1 > float NaN -> false
+ outputBuffer[testIndex++] = (iONE <= fNAN) ? 1u : 0u; // int 1 <= float NaN -> false
+ outputBuffer[testIndex++] = (iONE >= fNAN) ? 1u : 0u; // int 1 >= float NaN -> false
+
+ // Test float NaN compared with int - same results but different operand order
+ // CHECK: 0
+ // CHECK: 1
+ // CHECK: 0
+ // CHECK: 0
+ // CHECK: 0
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN == iONE) ? 1u : 0u; // float NaN == int 1 -> false
+ outputBuffer[testIndex++] = (fNAN != iONE) ? 1u : 0u; // float NaN != int 1 -> true
+ outputBuffer[testIndex++] = (fNAN < iONE) ? 1u : 0u; // float NaN < int 1 -> false
+ outputBuffer[testIndex++] = (fNAN > iONE) ? 1u : 0u; // float NaN > int 1 -> false
+ outputBuffer[testIndex++] = (fNAN <= iONE) ? 1u : 0u; // float NaN <= int 1 -> false
+ outputBuffer[testIndex++] = (fNAN >= iONE) ? 1u : 0u; // float NaN >= int 1 -> false
+
+ // Test with different int values to ensure consistent behavior
+ // CHECK: 0
+ // CHECK: 1
+ // CHECK: 0
+ // CHECK: 1
+ outputBuffer[testIndex++] = (iZERO == fNAN) ? 1u : 0u; // int 0 == float NaN -> false
+ outputBuffer[testIndex++] = (iZERO != fNAN) ? 1u : 0u; // int 0 != float NaN -> true
+ outputBuffer[testIndex++] = (fNAN == iZERO) ? 1u : 0u; // float NaN == int 0 -> false
+ outputBuffer[testIndex++] = (fNAN != iZERO) ? 1u : 0u; // float NaN != int 0 -> true
+
+ // Test unsigned int with float NaN
+ // CHECK: 0
+ // CHECK: 1
+ // CHECK: 0
+ // CHECK: 1
+ outputBuffer[testIndex++] = (uONE == fNAN) ? 1u : 0u; // uint 1 == float NaN -> false
+ outputBuffer[testIndex++] = (uONE != fNAN) ? 1u : 0u; // uint 1 != float NaN -> true
+ outputBuffer[testIndex++] = (fNAN == uONE) ? 1u : 0u; // float NaN == uint 1 -> false
+ outputBuffer[testIndex++] = (fNAN != uONE) ? 1u : 0u; // float NaN != uint 1 -> true
+
+ // Test normal int vs float comparisons (no NaN) to ensure type promotion works
+ // CHECK: 1
+ // CHECK: 1
+ // CHECK: 1
+ // CHECK: 1
+ outputBuffer[testIndex++] = (iONE == fONE) ? 1u : 0u; // int 1 == float 1.0 -> true
+ outputBuffer[testIndex++] = (fONE == iONE) ? 1u : 0u; // float 1.0 == int 1 -> true
+ outputBuffer[testIndex++] = (iZERO < fONE) ? 1u : 0u; // int 0 < float 1.0 -> true
+ outputBuffer[testIndex++] = (fONE > iZERO) ? 1u : 0u; // float 1.0 > int 0 -> true
+}
+
diff --git a/tests/compute/ieee754-nan-comparisons.slang b/tests/compute/ieee754-nan-comparisons.slang
new file mode 100644
index 000000000..70a252cb1
--- /dev/null
+++ b/tests/compute/ieee754-nan-comparisons.slang
@@ -0,0 +1,150 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type -compute
+
+// Test IEEE 754 NaN comparison behavior
+// According to IEEE 754 standard:
+// - Any comparison with NaN (except !=) should return false
+// - The != comparison with NaN should return true
+
+static const float fNAN = 0.0f / 0.0f;
+static const float fPOSITIVE_INFINITY = 1.0f / 0.0f;
+static const float fNEGATIVE_INFINITY = -1.0f / 0.0f;
+static const float fZERO = 0.0f;
+static const float fONE = 1.0f;
+static const float fNEG_ONE = -1.0f;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint testIndex = 0;
+
+ // Test 1: NaN == NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN == fNAN) ? 1u : 0u;
+
+ // Test 2: NaN != NaN should be true
+ // CHECK: 1
+ outputBuffer[testIndex++] = (fNAN != fNAN) ? 1u : 0u;
+
+ // Test 3: NaN > NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN > fNAN) ? 1u : 0u;
+
+ // Test 4: NaN < NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN < fNAN) ? 1u : 0u;
+
+ // Test 5: NaN >= NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN >= fNAN) ? 1u : 0u;
+
+ // Test 6: NaN <= NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN <= fNAN) ? 1u : 0u;
+
+ // Test 7: NaN == 0.0 should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN == fZERO) ? 1u : 0u;
+
+ // Test 8: NaN != 0.0 should be true
+ // CHECK: 1
+ outputBuffer[testIndex++] = (fNAN != fZERO) ? 1u : 0u;
+
+ // Test 9: NaN > 0.0 should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN > fZERO) ? 1u : 0u;
+
+ // Test 10: NaN < 0.0 should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN < fZERO) ? 1u : 0u;
+
+ // Test 11: NaN >= 0.0 should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN >= fZERO) ? 1u : 0u;
+
+ // Test 12: NaN <= 0.0 should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN <= fZERO) ? 1u : 0u;
+
+ // Test 13: 0.0 == NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fZERO == fNAN) ? 1u : 0u;
+
+ // Test 14: 0.0 != NaN should be true
+ // CHECK: 1
+ outputBuffer[testIndex++] = (fZERO != fNAN) ? 1u : 0u;
+
+ // Test 15: 0.0 > NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fZERO > fNAN) ? 1u : 0u;
+
+ // Test 16: 0.0 < NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fZERO < fNAN) ? 1u : 0u;
+
+ // Test 17: 0.0 >= NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fZERO >= fNAN) ? 1u : 0u;
+
+ // Test 18: 0.0 <= NaN should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fZERO <= fNAN) ? 1u : 0u;
+
+ // Test 19: NaN == +infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN == fPOSITIVE_INFINITY) ? 1u : 0u;
+
+ // Test 20: NaN != +infinity should be true
+ // CHECK: 1
+ outputBuffer[testIndex++] = (fNAN != fPOSITIVE_INFINITY) ? 1u : 0u;
+
+ // Test 21: NaN > +infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN > fPOSITIVE_INFINITY) ? 1u : 0u;
+
+ // Test 22: NaN < +infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN < fPOSITIVE_INFINITY) ? 1u : 0u;
+
+ // Test 23: NaN >= +infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN >= fPOSITIVE_INFINITY) ? 1u : 0u;
+
+ // Test 24: NaN <= +infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN <= fPOSITIVE_INFINITY) ? 1u : 0u;
+
+ // Test 25: NaN == -infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN == fNEGATIVE_INFINITY) ? 1u : 0u;
+
+ // Test 26: NaN != -infinity should be true
+ // CHECK: 1
+ outputBuffer[testIndex++] = (fNAN != fNEGATIVE_INFINITY) ? 1u : 0u;
+
+ // Test 27: NaN > -infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN > fNEGATIVE_INFINITY) ? 1u : 0u;
+
+ // Test 28: NaN < -infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN < fNEGATIVE_INFINITY) ? 1u : 0u;
+
+ // Test 29: NaN >= -infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN >= fNEGATIVE_INFINITY) ? 1u : 0u;
+
+ // Test 30: NaN <= -infinity should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN <= fNEGATIVE_INFINITY) ? 1u : 0u;
+
+ // Test 31: NaN == 1.0 should be false
+ // CHECK: 0
+ outputBuffer[testIndex++] = (fNAN == fONE) ? 1u : 0u;
+
+ // Test 32: NaN != 1.0 should be true
+ // CHECK: 1
+ outputBuffer[testIndex++] = (fNAN != fONE) ? 1u : 0u;
+}