summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-07-11 10:12:17 -0700
committerGitHub <noreply@github.com>2025-07-11 17:12:17 +0000
commit57567778b7d91afe7e6325c731d54b313b8b16e9 (patch)
tree7fbd665e6ed71680ab68b21571e84f40bab16fb5 /source
parentb20b9297ed20f85dec6212cad83eeacaecbaccf3 (diff)
Fix IEEE 754 NaN comparisons in constant folding (#7721)
* Fix IEEE 754 NaN comparisons in constant folding Added proper NaN handling in SCCP optimization pass to follow IEEE 754 standard: - NaN \!= any value returns true - All other NaN comparisons return false - Added double precision NaN detection support - Fixed type detection to check operands instead of result type * Avoid differentiating NaN and non-NaN cases * format code (#76) --------- Co-authored-by: slangbot <ellieh+slangbot@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/core/slang-math.h2
-rw-r--r--source/slang/slang-ir-sccp.cpp44
2 files changed, 25 insertions, 21 deletions
diff --git a/source/core/slang-math.h b/source/core/slang-math.h
index e977dc37d..c7e77fce5 100644
--- a/source/core/slang-math.h
+++ b/source/core/slang-math.h
@@ -106,8 +106,10 @@ public:
}
static inline int IsNaN(float x) { return std::isnan(x); }
+ static inline int IsNaN(double x) { return std::isnan(x); }
static inline int IsInf(float x) { return std::isinf(x); }
+ static inline int IsInf(double x) { return std::isinf(x); }
static inline unsigned int Ones32(unsigned int x)
{
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index 62584040a..f1b647045 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -511,7 +511,7 @@ struct SCCPContext
template<typename TIntFunc, typename TFloatFunc>
LatticeVal evalComparisonImpl(
- IRType* type,
+ IRType*,
LatticeVal v0,
LatticeVal v1,
const TIntFunc& intFunc,
@@ -522,29 +522,31 @@ struct SCCPContext
SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
auto c1 = as<IRConstant>(v1.value);
IRInst* resultVal = nullptr;
- switch (type->getOp())
+
+ // Check the operand types, not the result type (which is always bool for comparisons)
+ // For mixed-type comparisons, use floating-point path if either operand is floating-point
+ IRType* operandType0 = c0->getDataType();
+ IRType* operandType1 = c1->getDataType();
+
+ // Helper function to check if a type is floating-point
+ auto isFloatingPointType = [](IROp op) -> bool
+ { return op == kIROp_FloatType || op == kIROp_DoubleType || op == kIROp_HalfType; };
+
+ IROp op0 = operandType0->getOp();
+ IROp op1 = operandType1->getOp();
+
+ // Use floating-point path if either operand is floating-point
+ // Otherwise use integer path
+ if (isFloatingPointType(op0) || isFloatingPointType(op1))
{
- case kIROp_Int8Type:
- case kIROp_Int16Type:
- case kIROp_IntType:
- case kIROp_Int64Type:
- case kIROp_IntPtrType:
- case kIROp_UInt8Type:
- case kIROp_UInt16Type:
- case kIROp_UIntType:
- case kIROp_UInt64Type:
- case kIROp_UIntPtrType:
- case kIROp_BoolType:
- resultVal = getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal));
- break;
- case kIROp_FloatType:
- case kIROp_DoubleType:
- case kIROp_HalfType:
+ // Floating-point path - C++ operators follow IEEE 754 automatically
resultVal =
getBuilder()->getBoolValue(floatFunc(c0->value.floatVal, c1->value.floatVal));
- break;
- default:
- break;
+ }
+ else
+ {
+ // Integer path - all integer types
+ resultVal = getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal));
}
if (!resultVal)
return LatticeVal::getAny();