summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-val.cpp102
-rw-r--r--source/slang/slang-ast-val.h15
-rw-r--r--source/slang/slang-check-expr.cpp44
-rw-r--r--source/slang/slang-lower-to-ir.cpp9
4 files changed, 148 insertions, 22 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index c8b9d0bb8..d8886f05b 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -1217,6 +1217,108 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+bool TypeCastIntVal::_equalsValOverride(Val* val)
+{
+ if (auto typeCastIntVal = as<TypeCastIntVal>(val))
+ {
+ if (!type->equals(typeCastIntVal->type))
+ return false;
+ if (!base->equalsVal(typeCastIntVal->base))
+ return false;
+ return true;
+ }
+ return false;
+}
+
+void TypeCastIntVal::_toTextOverride(StringBuilder& out)
+{
+ type->toText(out);
+ out << "(";
+ base->toText(out);
+ out << ")";
+}
+
+HashCode TypeCastIntVal::_getHashCodeOverride()
+{
+ HashCode result = type->getHashCode();
+ result = combineHash(result, base->getHashCode());
+ return result;
+}
+
+Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink)
+{
+ SLANG_UNUSED(sink);
+
+ if (auto c = as<ConstantIntVal>(base))
+ {
+ IntegerLiteralValue resultValue = c->value;
+ auto baseType = as<BasicExpressionType>(resultType);
+ if (baseType)
+ {
+ switch (baseType->baseType)
+ {
+ case BaseType::Int:
+ resultValue = (int)resultValue;
+ break;
+ case BaseType::UInt:
+ resultValue = (unsigned int)resultValue;
+ break;
+ case BaseType::Int64:
+ case BaseType::IntPtr:
+ resultValue = (Int64)resultValue;
+ break;
+ case BaseType::UInt64:
+ case BaseType::UIntPtr:
+ resultValue = (UInt64)resultValue;
+ break;
+ case BaseType::Int16:
+ resultValue = (int16_t)resultValue;
+ break;
+ case BaseType::UInt16:
+ resultValue = (uint16_t)resultValue;
+ break;
+ case BaseType::Int8:
+ resultValue = (int8_t)resultValue;
+ break;
+ case BaseType::UInt8:
+ resultValue = (uint8_t)resultValue;
+ break;
+ default:
+ return nullptr;
+ }
+ }
+ return astBuilder->getIntVal(resultType, resultValue);
+ }
+ return nullptr;
+}
+
+Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto substBase = base->substituteImpl(astBuilder, subst, &diff);
+ if (substBase != base)
+ diff++;
+ auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff));
+ if (substType != type)
+ diff++;
+ *ioDiff += diff;
+ if (diff)
+ {
+ auto newVal = tryFoldImpl(astBuilder, substType, substBase, nullptr);
+ if (newVal)
+ return newVal;
+ else
+ {
+ auto result = astBuilder->create<TypeCastIntVal>(substType, substBase);
+ return result;
+ }
+ }
+ // Nothing found: don't substitute.
+ return this;
+}
+
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
bool FuncCallIntVal::_equalsValOverride(Val* val)
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 0eb77e06e..222ba48d1 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -63,6 +63,21 @@ protected:
{}
};
+class TypeCastIntVal : public IntVal
+{
+ SLANG_AST_CLASS(TypeCastIntVal)
+
+ bool _equalsValOverride(Val* val);
+ void _toTextOverride(StringBuilder& out);
+ HashCode _getHashCodeOverride();
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ Val* base;
+ TypeCastIntVal(Type* inType, Val* inBase) : IntVal(inType), base(inBase) {}
+
+ static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink);
+};
+
// An compile time int val as result of some general computation.
class FuncCallIntVal : public IntVal
{
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a91ec1e98..e0084e08e 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1385,26 +1385,12 @@ namespace Slang
auto targetBasicType = as<BasicExpressionType>(invokeExpr.getExpr()->type.type);
if (!targetBasicType)
return nullptr;
- switch (targetBasicType->baseType)
- {
- case BaseType::Bool:
- resultValue = constArgVals[0] != 0;
- break;
- case BaseType::Int:
- case BaseType::UInt:
- case BaseType::UInt16:
- case BaseType::Int16:
- case BaseType::UInt8:
- case BaseType::Int8:
- case BaseType::UIntPtr:
- case BaseType::IntPtr:
- case BaseType::Int64:
- case BaseType::UInt64:
- resultValue = constArgVals[0];
- break;
- default:
- return nullptr;
- }
+ auto foldVal = as<IntVal>(
+ TypeCastIntVal::tryFoldImpl(m_astBuilder, targetBasicType, argVals[0], getSink()));
+ if (foldVal)
+ return foldVal;
+ auto result = m_astBuilder->getOrCreate<TypeCastIntVal>(targetBasicType, argVals[0]);
+ return result;
}
else
{
@@ -1619,9 +1605,23 @@ namespace Slang
if(auto castExpr = expr.as<TypeCastExpr>())
{
+ auto substType = getType(m_astBuilder, expr);
+ if (!substType)
+ return nullptr;
+ if (!isScalarIntegerType(substType))
+ return nullptr;
auto val = tryConstantFoldExpr(getArg(castExpr, 0), circularityInfo);
- if(val)
- return val;
+ if (val)
+ {
+ if (!castExpr.getExpr()->type)
+ return nullptr;
+ auto foldVal = as<IntVal>(
+ TypeCastIntVal::tryFoldImpl(m_astBuilder, substType, val, getSink()));
+ if (foldVal)
+ return foldVal;
+ auto result = m_astBuilder->getOrCreate<TypeCastIntVal>(substType, val);
+ return result;
+ }
}
else if (auto invokeExpr = expr.as<InvokeExpr>())
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b36f0dc94..144875e8c 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1432,6 +1432,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
tryEnv);
}
+ LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val)
+ {
+ TryClauseEnvironment tryEnv;
+ auto baseVal = lowerVal(context, val->base);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+ auto type = lowerType(context, val->type);
+ return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
+ }
+
LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val)
{
auto witnessVal = lowerVal(context, val->witness);