diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 102 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 15 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 9 |
4 files changed, 148 insertions, 22 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index c8b9d0bb8..d8886f05b 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1217,6 +1217,108 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +bool TypeCastIntVal::_equalsValOverride(Val* val) +{ + if (auto typeCastIntVal = as<TypeCastIntVal>(val)) + { + if (!type->equals(typeCastIntVal->type)) + return false; + if (!base->equalsVal(typeCastIntVal->base)) + return false; + return true; + } + return false; +} + +void TypeCastIntVal::_toTextOverride(StringBuilder& out) +{ + type->toText(out); + out << "("; + base->toText(out); + out << ")"; +} + +HashCode TypeCastIntVal::_getHashCodeOverride() +{ + HashCode result = type->getHashCode(); + result = combineHash(result, base->getHashCode()); + return result; +} + +Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + + if (auto c = as<ConstantIntVal>(base)) + { + IntegerLiteralValue resultValue = c->value; + auto baseType = as<BasicExpressionType>(resultType); + if (baseType) + { + switch (baseType->baseType) + { + case BaseType::Int: + resultValue = (int)resultValue; + break; + case BaseType::UInt: + resultValue = (unsigned int)resultValue; + break; + case BaseType::Int64: + case BaseType::IntPtr: + resultValue = (Int64)resultValue; + break; + case BaseType::UInt64: + case BaseType::UIntPtr: + resultValue = (UInt64)resultValue; + break; + case BaseType::Int16: + resultValue = (int16_t)resultValue; + break; + case BaseType::UInt16: + resultValue = (uint16_t)resultValue; + break; + case BaseType::Int8: + resultValue = (int8_t)resultValue; + break; + case BaseType::UInt8: + resultValue = (uint8_t)resultValue; + break; + default: + return nullptr; + } + } + return astBuilder->getIntVal(resultType, resultValue); + } + return nullptr; +} + +Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto substBase = base->substituteImpl(astBuilder, subst, &diff); + if (substBase != base) + diff++; + auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff)); + if (substType != type) + diff++; + *ioDiff += diff; + if (diff) + { + auto newVal = tryFoldImpl(astBuilder, substType, substBase, nullptr); + if (newVal) + return newVal; + else + { + auto result = astBuilder->create<TypeCastIntVal>(substType, substBase); + return result; + } + } + // Nothing found: don't substitute. + return this; +} + + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! bool FuncCallIntVal::_equalsValOverride(Val* val) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 0eb77e06e..222ba48d1 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -63,6 +63,21 @@ protected: {} }; +class TypeCastIntVal : public IntVal +{ + SLANG_AST_CLASS(TypeCastIntVal) + + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + Val* base; + TypeCastIntVal(Type* inType, Val* inBase) : IntVal(inType), base(inBase) {} + + static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink); +}; + // An compile time int val as result of some general computation. class FuncCallIntVal : public IntVal { diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a91ec1e98..e0084e08e 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1385,26 +1385,12 @@ namespace Slang auto targetBasicType = as<BasicExpressionType>(invokeExpr.getExpr()->type.type); if (!targetBasicType) return nullptr; - switch (targetBasicType->baseType) - { - case BaseType::Bool: - resultValue = constArgVals[0] != 0; - break; - case BaseType::Int: - case BaseType::UInt: - case BaseType::UInt16: - case BaseType::Int16: - case BaseType::UInt8: - case BaseType::Int8: - case BaseType::UIntPtr: - case BaseType::IntPtr: - case BaseType::Int64: - case BaseType::UInt64: - resultValue = constArgVals[0]; - break; - default: - return nullptr; - } + auto foldVal = as<IntVal>( + TypeCastIntVal::tryFoldImpl(m_astBuilder, targetBasicType, argVals[0], getSink())); + if (foldVal) + return foldVal; + auto result = m_astBuilder->getOrCreate<TypeCastIntVal>(targetBasicType, argVals[0]); + return result; } else { @@ -1619,9 +1605,23 @@ namespace Slang if(auto castExpr = expr.as<TypeCastExpr>()) { + auto substType = getType(m_astBuilder, expr); + if (!substType) + return nullptr; + if (!isScalarIntegerType(substType)) + return nullptr; auto val = tryConstantFoldExpr(getArg(castExpr, 0), circularityInfo); - if(val) - return val; + if (val) + { + if (!castExpr.getExpr()->type) + return nullptr; + auto foldVal = as<IntVal>( + TypeCastIntVal::tryFoldImpl(m_astBuilder, substType, val, getSink())); + if (foldVal) + return foldVal; + auto result = m_astBuilder->getOrCreate<TypeCastIntVal>(substType, val); + return result; + } } else if (auto invokeExpr = expr.as<InvokeExpr>()) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b36f0dc94..144875e8c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1432,6 +1432,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower tryEnv); } + LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val) + { + TryClauseEnvironment tryEnv; + auto baseVal = lowerVal(context, val->base); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + auto type = lowerType(context, val->type); + return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val)); + } + LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val) { auto witnessVal = lowerVal(context, val->witness); |
