summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-val.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-04 15:47:39 -0700
committerGitHub <noreply@github.com>2023-08-04 15:47:39 -0700
commita2d90fb275962da84611160f8ddd74d934a68dbd (patch)
tree066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-ast-val.cpp
parent17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff)
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
-rw-r--r--source/slang/slang-ast-val.cpp1477
1 files changed, 677 insertions, 800 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index b45300af8..056577eb0 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -6,9 +6,47 @@
#include "slang-generated-ast-macro.h"
#include "slang-diagnostics.h"
#include "slang-syntax.h"
+#include "slang-ast-val.h"
namespace Slang {
+
+bool ValNodeDesc::operator==(ValNodeDesc const& that) const
+{
+ if (hashCode != that.hashCode) return false;
+ if (type != that.type) return false;
+ if (operands.getCount() != that.operands.getCount()) return false;
+ for (Index i = 0; i < operands.getCount(); ++i)
+ {
+ // Note: we are comparing the operands directly for identity
+ // (pointer equality) rather than doing the `Val`-level
+ // equality check.
+ //
+ // The rationale here is that nodes that will be created
+ // via a `NodeDesc` *should* all be going through the
+ // deduplication path anyway, as should their operands.
+ //
+ if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
+ }
+ return true;
+}
+
+void ValNodeDesc::init()
+{
+ Hasher hasher;
+ hasher.hashValue(Int(type));
+ for (Index i = 0; i < operands.getCount(); ++i)
+ {
+ // Note: we are hashing the raw pointer value rather
+ // than the content of the value node. This is done
+ // to match the semantics implemented for `==` on
+ // `NodeDesc`.
+ //
+ hasher.hashValue(operands[i].values.nodeOperand);
+ }
+ hashCode = hasher.getResult();
+}
+
Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst)
{
if (!subst) return this;
@@ -21,14 +59,103 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD
SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff))
}
-bool Val::equalsVal(Val* val)
+void Val::toText(StringBuilder& out)
{
- SLANG_AST_NODE_VIRTUAL_CALL(Val, equalsVal, (val))
+ SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out))
}
-void Val::toText(StringBuilder& out)
+Val* Val::_resolveImplOverride()
{
- SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out))
+ SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden");
+}
+
+Val* Val::resolveImpl()
+{
+ SLANG_AST_NODE_VIRTUAL_CALL(Val, resolveImpl, ());
+}
+
+Val* Val::resolve()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ // If we are not in a proper checking context, just return the previously resolved val.
+ if (!astBuilder)
+ return m_resolvedVal? m_resolvedVal : this;
+ if (m_resolvedVal && m_resolvedValEpoch == getCurrentASTBuilder()->getEpoch())
+ {
+ SLANG_ASSERT(as<Val>(m_resolvedVal));
+ return m_resolvedVal;
+ }
+
+ // Update epoch now to avoid infinite recursion.
+ m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch();
+ m_resolvedVal = this;
+ m_resolvedVal = resolveImpl();
+
+ // Check if we are resolved to an existing Val in the AST cache.
+ ValNodeDesc newDesc;
+ newDesc.type = m_resolvedVal->astNodeType;
+ for (auto operand : m_resolvedVal->m_operands)
+ {
+ if (operand.kind == ValNodeOperandKind::ValNode)
+ {
+ auto valOperand = as<Val>(operand.values.nodeOperand);
+ if (valOperand)
+ {
+ operand.values.nodeOperand = valOperand->resolve();
+ }
+ }
+ newDesc.operands.add(operand);
+ }
+ newDesc.init();
+
+ NodeBase* existingNode = nullptr;
+ if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
+ m_resolvedVal = as<Val>(existingNode);
+
+#ifdef _DEBUG
+ if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0)
+ {
+ //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking.");
+ }
+#endif
+ return m_resolvedVal;
+}
+
+ValNodeDesc Val::getDesc()
+{
+ ValNodeDesc desc;
+ desc.type = astNodeType;
+ for (auto operand : m_operands)
+ desc.operands.add(operand);
+ desc.init();
+ return desc;
+}
+
+Val* Val::defaultResolveImpl()
+{
+ // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache.
+ ValNodeDesc newDesc;
+ newDesc.type = astNodeType;
+ for (auto operand : m_operands)
+ {
+ if (operand.kind == ValNodeOperandKind::ValNode)
+ {
+ auto valOperand = as<Val>(operand.values.nodeOperand);
+ if (valOperand)
+ {
+ operand.values.nodeOperand = valOperand->resolve();
+ }
+ }
+ newDesc.operands.add(operand);
+ }
+ newDesc.init();
+ auto astBuilder = getCurrentASTBuilder();
+
+ NodeBase* existingNode = nullptr;
+ if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
+ return as<Val>(existingNode);
+ return this;
}
String Val::toString()
@@ -40,7 +167,7 @@ String Val::toString()
HashCode Val::getHashCode()
{
- SLANG_AST_NODE_VIRTUAL_CALL(Val, getHashCode, ())
+ return Slang::getHashCode(resolve());
}
Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
@@ -52,124 +179,84 @@ Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst,
return this;
}
-bool Val::_equalsValOverride(Val* val)
-{
- SLANG_UNUSED(val);
- SLANG_UNEXPECTED("Val::_equalsValOverride not overridden");
- //return false;
-}
-
void Val::_toTextOverride(StringBuilder& out)
{
SLANG_UNUSED(out);
SLANG_UNEXPECTED("Val::_toStringOverride not overridden");
}
-HashCode Val::_getHashCodeOverride()
-{
- SLANG_UNEXPECTED("Val::_getHashCodeOverride not overridden");
- //return HashCode(0);
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConstantIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ConstantIntVal::_equalsValOverride(Val* val)
-{
- if (auto intVal = as<ConstantIntVal>(val))
- return value == intVal->value;
- return false;
-}
-
void ConstantIntVal::_toTextOverride(StringBuilder& out)
{
- out << value;
-}
-
-HashCode ConstantIntVal::_getHashCodeOverride()
-{
- return (HashCode)value;
+ out << getValue();
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool GenericParamIntVal::_equalsValOverride(Val* val)
-{
- if (auto genericParamVal = as<GenericParamIntVal>(val))
- {
- return declRef.equals(genericParamVal->declRef);
- }
- return false;
-}
-
void GenericParamIntVal::_toTextOverride(StringBuilder& out)
{
- Name* name = declRef.getName();
+ Name* name = getDeclRef().getName();
if (name)
{
out << name->text;
}
}
-HashCode GenericParamIntVal::_getHashCodeOverride()
-{
- return declRef.getHashCode() ^ HashCode(0xFFFF);
-}
-
Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff)
{
// search for a substitution that might apply to us
- for (auto s = subst.substitutions; s; s = s->getOuter())
+ auto outerGeneric = as<GenericDecl>(paramDecl->parentDecl);
+ if (!outerGeneric)
+ return paramVal;
+
+ GenericAppDeclRef* genAppArgs = subst.findGenericAppDeclRef(outerGeneric);
+ if (!genAppArgs)
{
- auto genSubst = as<GenericSubstitution>(s);
- if (!genSubst)
- continue;
-
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = genSubst->getGenericDecl();
- if (genericDecl != paramDecl->parentDecl)
- continue;
-
- // In some cases, we construct a `DeclRef` to a `GenericDecl`
- // (or a declaration under one) that only includes argument
- // values for a prefix of the parameters of the generic.
- //
- // If we aren't careful, we could end up indexing into the
- // argument list past the available range.
- //
- Count argCount = genSubst->getArgs().getCount();
+ return paramVal;
+ }
- Count argIndex = 0;
- for (auto m : genericDecl->members)
+ auto args = genAppArgs->getArgs();
+
+ // In some cases, we construct a `DeclRef` to a `GenericDecl`
+ // (or a declaration under one) that only includes argument
+ // values for a prefix of the parameters of the generic.
+ //
+ // If we aren't careful, we could end up indexing into the
+ // argument list past the available range.
+ //
+ Count argCount = args.getCount();
+
+ Count argIndex = 0;
+ for (auto m : outerGeneric->members)
+ {
+ // If we have run out of arguments, then we can stop
+ // iterating over the parameters, because `this`
+ // parameter will not be replaced with anything by
+ // the substituion.
+ //
+ if (argIndex >= argCount)
{
- // If we have run out of arguments, then we can stop
- // iterating over the parameters, because `this`
- // parameter will not be replaced with anything by
- // the substituion.
- //
- if (argIndex >= argCount)
- {
- return paramVal;
- }
+ return paramVal;
+ }
- if (m == paramDecl)
- {
- // We've found it, so return the corresponding specialization argument
- (*ioDiff)++;
- return genSubst->getArgs()[argIndex];
- }
- else if (const auto typeParam = as<GenericTypeParamDecl>(m))
- {
- argIndex++;
- }
- else if (const auto valParam = as<GenericValueParamDecl>(m))
- {
- argIndex++;
- }
- else
- {
- }
+ if (m == paramDecl)
+ {
+ // We've found it, so return the corresponding specialization argument
+ (*ioDiff)++;
+ return args[argIndex];
+ }
+ else if (const auto typeParam = as<GenericTypeParamDecl>(m))
+ {
+ argIndex++;
+ }
+ else if (const auto valParam = as<GenericValueParamDecl>(m))
+ {
+ argIndex++;
+ }
+ else
+ {
}
}
@@ -180,7 +267,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet
Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff)
{
- if (auto result = maybeSubstituteGenericParam(this, declRef.getDecl(), subst, ioDiff))
+ if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff))
return result;
return this;
@@ -188,21 +275,11 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ErrorIntVal::_equalsValOverride(Val* val)
-{
- return as<ErrorIntVal>(val);
-}
-
void ErrorIntVal::_toTextOverride(StringBuilder& out)
{
out << toSlice("<error>");
}
-HashCode ErrorIntVal::_getHashCodeOverride()
-{
- return HashCode(typeid(this).hash_code());
-}
-
Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
SLANG_UNUSED(astBuilder);
@@ -211,97 +288,110 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
return this;
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-// TODO: should really have a `type.cpp` and a `witness.cpp`
-
-bool TypeEqualityWitness::_equalsValOverride(Val* val)
-{
- auto otherWitness = as<TypeEqualityWitness>(val);
- if (!otherWitness)
- return false;
- return sub->equals(otherWitness->sub);
-}
-
Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
- TypeEqualityWitness* rs = astBuilder->create<TypeEqualityWitness>();
- rs->sub = as<Type>(sub->substituteImpl(astBuilder, subst, ioDiff));
- rs->sup = as<Type>(sup->substituteImpl(astBuilder, subst, ioDiff));
+ auto type = as<Type>(getSub()->substituteImpl(astBuilder, subst, ioDiff));
+ TypeEqualityWitness* rs = astBuilder->getOrCreate<TypeEqualityWitness>(type, type);
return rs;
}
void TypeEqualityWitness::_toTextOverride(StringBuilder& out)
{
- out << toSlice("TypeEqualityWitness(") << sub << toSlice(")");
-}
-
-HashCode TypeEqualityWitness::_getHashCodeOverride()
-{
- return sub->getHashCode();
+ out << toSlice("TypeEqualityWitness(") << getSub() << toSlice(")");
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclaredSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool DeclaredSubtypeWitness::_equalsValOverride(Val* val)
+Val* DeclaredSubtypeWitness::_resolveImplOverride()
{
- auto otherWitness = as<DeclaredSubtypeWitness>(val);
- if (!otherWitness)
- return false;
+ auto resolvedDeclRef = getDeclRef().declRefBase->resolve();
+ if (auto resolvedVal = as<SubtypeWitness>(resolvedDeclRef))
+ return resolvedVal;
- return sub->equals(otherWitness->sub)
- && sup->equals(otherWitness->sup)
- && declRef.equals(otherWitness->declRef);
+ auto newSub = as<Type>(getSub()->resolve());
+ auto newSup = as<Type>(getSup()->resolve());
+
+ // If we are trying to lookup for a witness that A<:B from a witness(A<:B), we
+ // can just return the witness itself.
+ if (auto lookupDeclRef = as<LookupDeclRef>(resolvedDeclRef))
+ {
+ auto witnessToLookupFrom = lookupDeclRef->getWitness();
+ if (witnessToLookupFrom->getSub()->equals(newSub) &&
+ witnessToLookupFrom->getSup()->equals(newSup))
+ return witnessToLookupFrom;
+ }
+ auto newDeclRef = as<DeclRefBase>(resolvedDeclRef);
+ if (!newDeclRef)
+ newDeclRef = getDeclRef().declRefBase;
+ if (newSub != getSub() || newSup != getSup() || newDeclRef != getDeclRef())
+ {
+ return getCurrentASTBuilder()->getDeclaredSubtypeWitness(newSub, newSup, newDeclRef);
+ }
+ return this;
}
Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
- if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>())
+ if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>())
{
- auto genConstraintDecl = genConstraintDeclRef.getDecl();
+ auto genericDecl = as<GenericDecl>(getDeclRef().getDecl()->parentDecl);
+ if (!genericDecl)
+ goto breakLabel;
// search for a substitution that might apply to us
- for (auto s = subst.substitutions; s; s = s->getOuter())
+ auto args = tryGetGenericArguments(subst, genericDecl);
+ if (args.getCount() == 0)
+ goto breakLabel;
+
+ bool found = false;
+ Index index = 0;
+ for (auto m : genericDecl->members)
{
- if (auto genericSubst = as<GenericSubstitution>(s))
+ if (auto constraintParam = as<GenericTypeConstraintDecl>(m))
{
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = genericSubst->getGenericDecl();
- if (genericDecl != genConstraintDecl->parentDecl)
- continue;
-
- bool found = false;
- Index index = 0;
- for (auto m : genericDecl->members)
+ if (constraintParam == getDeclRef().getDecl())
{
- if (auto constraintParam = as<GenericTypeConstraintDecl>(m))
- {
- if (constraintParam == declRef.getDecl())
- {
- found = true;
- break;
- }
- index++;
- }
- }
- if (found)
- {
- (*ioDiff)++;
- auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() +
- genericDecl->getMembersOfType<GenericValueParamDecl>().getCount();
- SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount());
- return genericSubst->getArgs()[index + ordinaryParamCount];
+ found = true;
+ break;
}
+ index++;
+ }
+ }
+ if (found)
+ {
+ auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() +
+ genericDecl->getMembersOfType<GenericValueParamDecl>().getCount();
+ if (index + ordinaryParamCount < args.getCount())
+ {
+ (*ioDiff)++;
+ return args[index + ordinaryParamCount];
+ }
+ else
+ {
+ // When the `subst` represents a partial substitution, we may not have a corresponding argument.
+ // In this case we just return the original witness.
+ //
+ goto breakLabel;
}
}
}
+ else if (auto thisTypeConstraintDeclRef = getDeclRef().as<ThisTypeConstraintDecl>())
+ {
+ auto lookupSubst = subst.findLookupDeclRef();
+ if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl())
+ {
+ (*ioDiff)++;
+ return lookupSubst->getWitness();
+ }
+ }
+
+breakLabel:;
// Perform substitution on the constituent elements.
int diff = 0;
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
- auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
+
if (!diff)
return this;
@@ -317,7 +407,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
// so we'll need to change this location in the code if we ever clean
// up the hierarchy.
//
- if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.getDecl()))
+ if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(getDeclRef().getDecl()))
{
if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl))
{
@@ -326,12 +416,12 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
// At this point we have a constraint decl for an associated type,
// and we nee to see if we are dealing with a concrete substitution
// for the interface around that associated type.
- if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.getSubst(), interfaceDecl))
+ if (auto thisTypeWitness = findThisTypeWitness(subst, interfaceDecl))
{
// We need to look up the declaration that satisfies
// the requirement named by the associated type.
Decl* requirementKey = substTypeConstraintDecl;
- RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey);
+ RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey);
switch (requirementWitness.getFlavor())
{
default:
@@ -348,6 +438,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
}
}
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
auto rs = astBuilder->getDeclaredSubtypeWitness(
substSub, substSup, substDeclRef);
return rs;
@@ -355,34 +446,17 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out)
{
- out << toSlice("DeclaredSubtypeWitness(") << sub << toSlice(", ") << sup << toSlice(", ") << declRef << toSlice(")");
-}
-
-HashCode DeclaredSubtypeWitness::_getHashCodeOverride()
-{
- return declRef.getHashCode();
+ out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")");
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool TransitiveSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto otherWitness = as<TransitiveSubtypeWitness>(val);
- if (!otherWitness)
- return false;
-
- return sub->equals(otherWitness->sub)
- && sup->equals(otherWitness->sup)
- && subToMid->equalsVal(otherWitness->subToMid)
- && midToSup->equalsVal(otherWitness->midToSup);
-}
-
Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
int diff = 0;
- SubtypeWitness* substSubToMid = as<SubtypeWitness>(subToMid->substituteImpl(astBuilder, subst, &diff));
- SubtypeWitness* substMidToSup = as<SubtypeWitness>(midToSup->substituteImpl(astBuilder, subst, &diff));
+ SubtypeWitness* substSubToMid = as<SubtypeWitness>(getSubToMid()->substituteImpl(astBuilder, subst, &diff));
+ SubtypeWitness* substMidToSup = as<SubtypeWitness>(getMidToSup()->substituteImpl(astBuilder, subst, &diff));
// If nothing changed, then we can bail out early.
if (!diff)
@@ -407,16 +481,7 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out)
// witnesses, and rely on them to print
// the starting and ending types.
- out << toSlice("TransitiveSubtypeWitness(") << subToMid << toSlice(", ") << midToSup << toSlice(")");
-}
-
-HashCode TransitiveSubtypeWitness::_getHashCodeOverride()
-{
- auto hash = sub->getHashCode();
- hash = combineHash(hash, sup->getHashCode());
- hash = combineHash(hash, subToMid->getHashCode());
- hash = combineHash(hash, midToSup->getHashCode());
- return hash;
+ out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")");
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -425,9 +490,9 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
{
int diff = 0;
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
- auto substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff));
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
+ auto substWitness = as<SubtypeWitness>(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff));
// If nothing changed, then we can bail out early.
if (!diff)
@@ -447,138 +512,34 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
// simplification logic as needed.
//
return astBuilder->getExtractFromConjunctionSubtypeWitness(
- substSub, substSup, substWitness, indexInConjunction);
+ substSub, substSup, substWitness, getIndexInConjunction());
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val)
-{
- if (auto extractWitness = as<ExtractExistentialSubtypeWitness>(val))
- {
- return declRef.equals(extractWitness->declRef);
- }
- return false;
-}
-
void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out)
{
- out << toSlice("extractExistentialValue(") << declRef << toSlice(")");
-}
-
-HashCode ExtractExistentialSubtypeWitness::_getHashCodeOverride()
-{
- return declRef.getHashCode();
+ out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")");
}
Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
if (!diff)
return this;
(*ioDiff)++;
- ExtractExistentialSubtypeWitness* substValue = astBuilder->create<ExtractExistentialSubtypeWitness>();
- substValue->declRef = substDeclRef;
- substValue->sub = substSub;
- substValue->sup = substSup;
+ ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate<ExtractExistentialSubtypeWitness>(
+ substSub, substSup, substDeclRef);
return substValue;
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto taggedUnionWitness = as<TaggedUnionSubtypeWitness>(val);
- if (!taggedUnionWitness)
- return false;
-
- auto caseCount = caseWitnesses.getCount();
- if (caseCount != taggedUnionWitness->caseWitnesses.getCount())
- return false;
-
- for (Index ii = 0; ii < caseCount; ++ii)
- {
- if (!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii]))
- return false;
- }
-
- return true;
-}
-
-void TaggedUnionSubtypeWitness::_toTextOverride(StringBuilder& out)
-{
- out << toSlice("TaggedUnionSubtypeWitness(");
- bool first = true;
- for (auto caseWitness : caseWitnesses)
- {
- if (!first)
- {
- out << toSlice(", ");
- }
- first = false;
-
- out << caseWitness;
- }
- out << toSlice(")");
-}
-
-HashCode TaggedUnionSubtypeWitness::_getHashCodeOverride()
-{
- HashCode hash = 0;
- for (auto caseWitness : caseWitnesses)
- {
- hash = combineHash(hash, caseWitness->getHashCode());
- }
- return hash;
-}
-
-Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
-
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
-
- List<SubtypeWitness*> substCaseWitnesses;
- for (auto caseWitness : caseWitnesses)
- {
- substCaseWitnesses.add(
- as<SubtypeWitness>(caseWitness->substituteImpl(astBuilder, subst, &diff)));
- }
-
- if (!diff)
- return this;
-
- (*ioDiff)++;
-
- TaggedUnionSubtypeWitness* substWitness = astBuilder->create<TaggedUnionSubtypeWitness>();
- substWitness->sub = substSub;
- substWitness->sup = substSup;
- substWitness->caseWitnesses.swapWith(substCaseWitnesses);
- return substWitness;
-}
-
-bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto other = as<ConjunctionSubtypeWitness>(val);
- if (!other)
- return false;
-
- for (Index i = 0; i < kComponentCount; ++i)
- {
- if (!other->componentWitnesses[i]) return false;
- if (!other->componentWitnesses[i]->equalsVal(componentWitnesses[i])) return false;
- }
- return true;
-}
-
void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
out << "ConjunctionSubtypeWitness(";
@@ -586,34 +547,23 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
if (i != 0) out << ",";
- auto w = componentWitnesses[i];
+ auto w = getComponentWitness(i);
if (w) out << w;
}
out << ")";
}
-HashCode ConjunctionSubtypeWitness::_getHashCodeOverride()
-{
- HashCode result = 0;
- for (Index i = 0; i < kComponentCount; ++i)
- {
- auto w = componentWitnesses[i];
- if (w) result = combineHash(result, w->getHashCode());
- }
- return result;
-}
-
Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
Val* substComponentWitnesses[kComponentCount];
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
for (Index i = 0; i < kComponentCount; ++i)
{
- auto w = componentWitnesses[i];
+ auto w = getComponentWitness(i);
substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr;
}
@@ -630,65 +580,25 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
auto result = astBuilder->getConjunctionSubtypeWitness(
substSub,
substSup,
- componentWitnesses[0],
- componentWitnesses[1]);
+ as<SubtypeWitness>(substComponentWitnesses[0]),
+ as<SubtypeWitness>(substComponentWitnesses[1]));
return result;
}
-bool ExtractFromConjunctionSubtypeWitness::_equalsValOverride(Val* val)
-{
- if (auto other = as<ExtractFromConjunctionSubtypeWitness>(val))
- {
- if(!sub->equals(other->sub)) return false;
- if(!sup->equals(other->sup)) return false;
- if(indexInConjunction != other->indexInConjunction) return false;
-
- return true;
- }
- return false;
-}
-
void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
out << "ExtractFromConjunctionSubtypeWitness(";
- if (conjunctionWitness)
- out << conjunctionWitness;
- if (sub)
- out << sub;
+ if (getConjunctionWitness())
+ out << getConjunctionWitness();
+ if (getSub())
+ out << getSub();
out << ",";
- if (sup)
- out << sup;
- out << "," << indexInConjunction;
+ if (getSup())
+ out << getSup();
+ out << "," << getIndexInConjunction();
out << ")";
}
-HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride()
-{
- return combineHash(
- conjunctionWitness ? conjunctionWitness->getHashCode() : 0,
- sub ? sub->getHashCode() : 0,
- sup ? sup->getHashCode() : 0,
- indexInConjunction);
-}
-
-// ModifierVal
-
-bool ModifierVal::_equalsValOverride(Val* val)
-{
- // TODO: This is assuming we can fully deduplicate the values that represent
- // modifiers, which may not actually be the case if there are multiple modules
- // being combined that use different `ASTBuilder`s.
- //
- return this == val;
-}
-
-HashCode ModifierVal::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashValue((void*) this);
- return hasher.getResult();
-}
-
// UNormModifierVal
void UNormModifierVal::_toTextOverride(StringBuilder& out)
@@ -735,48 +645,14 @@ Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitu
// PolynomialIntVal
-bool PolynomialIntVal::_equalsValOverride(Val* val)
-{
- if (auto genericParamVal = as<GenericParamIntVal>(val))
- {
- return constantTerm == 0 && terms.getCount() == 1 &&
- terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 &&
- terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) &&
- terms[0]->paramFactors[0]->power == 1;
- }
- else if (auto otherPolynomial = as<PolynomialIntVal>(val))
- {
- if (constantTerm != otherPolynomial->constantTerm)
- return false;
- if (terms.getCount() != otherPolynomial->terms.getCount())
- return false;
- for (Index i = 0; i < terms.getCount(); i++)
- {
- auto& thisTerm = *(terms[i]);
- auto& thatTerm = *(otherPolynomial->terms[i]);
- if (thisTerm.constFactor != thatTerm.constFactor)
- return false;
- if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount())
- return false;
- for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++)
- {
- if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power)
- return false;
- if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param))
- return false;
- }
- }
- return true;
- }
- return false;
-}
-
void PolynomialIntVal::_toTextOverride(StringBuilder& out)
{
+ auto constantTerm = getConstantTerm();
+ auto terms = getTerms();
for (Index i = 0; i < terms.getCount(); i++)
{
auto& term = *(terms[i]);
- if (term.constFactor > 0)
+ if (term.getConstFactor() > 0)
{
if (i > 0)
out << "+";
@@ -784,14 +660,14 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
else
out << "-";
bool isFirstFactor = true;
- if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0)
+ if (abs(term.getConstFactor()) != 1 || term.getParamFactors().getCount() == 0)
{
- out << abs(term.constFactor);
+ out << abs(term.getConstFactor());
isFirstFactor = false;
}
- for (Index j = 0; j < term.paramFactors.getCount(); j++)
+ for (Index j = 0; j < term.getParamFactors().getCount(); j++)
{
- auto factor = term.paramFactors[j];
+ auto factor = term.getParamFactors()[j];
if (isFirstFactor)
{
isFirstFactor = false;
@@ -800,10 +676,10 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
{
out << "*";
}
- factor->param->toText(out);
- if (factor->power != 1)
+ factor->getParam()->toText(out);
+ if (factor->getPower() != 1)
{
- out << "^^" << factor->power;
+ out << "^^" << factor->getPower();
}
}
}
@@ -821,227 +697,304 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
}
}
-HashCode PolynomialIntVal::_getHashCodeOverride()
+struct PolynomialIntValBuilder
{
- HashCode result = (HashCode)constantTerm;
- for (auto& term : terms)
+ ASTBuilder* astBuilder;
+
+ IntegerLiteralValue constantTerm = 0;
+ List<PolynomialIntValTerm*> terms;
+
+ PolynomialIntValBuilder(ASTBuilder* inAstBuilder)
+ : astBuilder(inAstBuilder)
+ {}
+
+ // compute val += opreand*multiplier;
+ bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier)
{
- if (!term) continue;
- result = combineHash(result, (HashCode)term->constFactor);
- for (auto& factor : term->paramFactors)
+ if (auto c = as<ConstantIntVal>(operand))
{
- result = combineHash(result, factor->param->getHashCode());
- result = combineHash(result, (HashCode)factor->power);
+ constantTerm += c->getValue() * multiplier;
+ return true;
}
+ else if (auto poly = as<PolynomialIntVal>(operand))
+ {
+ constantTerm += poly->getConstantTerm() * multiplier;
+ for (auto term : poly->getTerms())
+ {
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ multiplier * term->getConstFactor(), term->getParamFactors());
+ terms.add(newTerm);
+ }
+ return true;
+ }
+ else if (auto genVal = as<IntVal>(operand))
+ {
+ auto factor = astBuilder->getOrCreate<PolynomialIntValFactor>(genVal, 1);
+ auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(multiplier, makeArrayViewSingle(factor));
+ terms.add(term);
+ return true;
+ }
+ return false;
}
- return result;
-}
+
+ IntVal* canonicalize(Type* type)
+ {
+ List<PolynomialIntValTerm*> newTerms;
+ IntegerLiteralValue newConstantTerm = constantTerm;
+ auto addTerm = [&](PolynomialIntValTerm* newTerm)
+ {
+ for (auto& term : newTerms)
+ {
+ if (term->canCombineWith(*newTerm))
+ {
+ term = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ term->getConstFactor() + newTerm->getConstFactor(),
+ term->getParamFactors());
+ return;
+ }
+ }
+ newTerms.add(newTerm);
+ };
+ for (auto term : terms)
+ {
+ if (term->getConstFactor() == 0)
+ continue;
+ List<PolynomialIntValFactor*> newFactors;
+ List<bool> factorIsDifferent;
+ for (Index i = 0; i < term->getParamFactors().getCount(); i++)
+ {
+ auto factor = term->getParamFactors()[i];
+ bool factorFound = false;
+ for (Index j = 0; j < newFactors.getCount(); j++)
+ {
+ auto& newFactor = newFactors[j];
+ if (factor->getParam()->equals(newFactor->getParam()))
+ {
+ if (!factorIsDifferent[j])
+ {
+ factorIsDifferent[j] = true;
+ auto clonedFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower());
+ newFactor = clonedFactor;
+ }
+ newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower() + factor->getPower());
+ factorFound = true;
+ break;
+ }
+ }
+ if (!factorFound)
+ {
+ newFactors.add(factor);
+ factorIsDifferent.add(false);
+ }
+ }
+ List<PolynomialIntValFactor*> newFactors2;
+ // Remove zero-powered factors.
+ for (auto factor : newFactors)
+ {
+ if (factor->getPower() != 0)
+ newFactors2.add(factor);
+ }
+ if (newFactors2.getCount() == 0)
+ {
+ newConstantTerm += term->getConstFactor();
+ continue;
+ }
+ newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; });
+ bool isDifferent = false;
+ if (newFactors2.getCount() != term->getParamFactors().getCount())
+ isDifferent = true;
+ if (!isDifferent)
+ {
+ for (Index i = 0; i < term->getParamFactors().getCount(); i++)
+ if (term->getParamFactors()[i] != newFactors2[i])
+ {
+ isDifferent = true;
+ break;
+ }
+ }
+ if (!isDifferent)
+ {
+ addTerm(term);
+ }
+ else
+ {
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor(), newFactors2.getArrayView());
+ addTerm(newTerm);
+ }
+ }
+ List<PolynomialIntValTerm*> newTerms2;
+ for (auto term : newTerms)
+ {
+ if (term->getConstFactor() == 0)
+ continue;
+ newTerms2.add(term);
+ }
+ newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
+ terms = _Move(newTerms2);
+ constantTerm = newConstantTerm;
+ if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 &&
+ terms[0]->getParamFactors()[0]->getPower() == 1)
+ {
+ return terms[0]->getParamFactors()[0]->getParam();
+ }
+ if (terms.getCount() == 0)
+ return astBuilder->getIntVal(type, constantTerm);
+ return nullptr;
+ }
+
+ IntVal* getIntVal(Type* type)
+ {
+ if (auto canVal = canonicalize(type))
+ return canVal;
+ return astBuilder->getOrCreate<PolynomialIntVal>(type, constantTerm, terms.getArrayView());
+ }
+};
Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- IntegerLiteralValue evaluatedConstantTerm = constantTerm;
- List<PolynomialIntValTerm*> evaluatedTerms;
- for (auto& term : terms)
+ PolynomialIntValBuilder builder(astBuilder);
+ for (auto& term : getTerms())
{
IntegerLiteralValue evaluatedTermConstFactor;
List<PolynomialIntValFactor*> evaluatedTermParamFactors;
- evaluatedTermConstFactor = term->constFactor;
- for (auto& factor : term->paramFactors)
+ evaluatedTermConstFactor = term->getConstFactor();
+ for (auto& factor : term->getParamFactors())
{
- auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff);
+ auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff);
if (auto constantVal = as<ConstantIntVal>(substResult))
{
- evaluatedTermConstFactor *= constantVal->value;
+ evaluatedTermConstFactor *= constantVal->getValue();
}
else if (auto intResult = as<IntVal>(substResult))
{
- auto newFactor = astBuilder->create<PolynomialIntValFactor>();
- newFactor->param = intResult;
- newFactor->power = factor->power;
+ auto newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(intResult, factor->getPower());
evaluatedTermParamFactors.add(newFactor);
}
}
if (evaluatedTermParamFactors.getCount() == 0)
{
- evaluatedConstantTerm += evaluatedTermConstFactor;
+ builder.constantTerm += evaluatedTermConstFactor;
}
else
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->paramFactors = _Move(evaluatedTermParamFactors);
- newTerm->constFactor = evaluatedTermConstFactor;
- evaluatedTerms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView());
+ builder.terms.add(newTerm);
}
}
*ioDiff += diff;
- if (evaluatedTerms.getCount() == 0)
- return astBuilder->getIntVal(type, evaluatedConstantTerm);
+ if (builder.terms.getCount() == 0)
+ return astBuilder->getIntVal(getType(), builder.constantTerm);
if (diff != 0)
{
- auto newPolynomial = astBuilder->create<PolynomialIntVal>(type);
- newPolynomial->constantTerm = evaluatedConstantTerm;
- newPolynomial->terms = _Move(evaluatedTerms);
- return newPolynomial->canonicalize(astBuilder);
+ return builder.getIntVal(getType());
}
return this;
}
-
-// compute val += opreand*multiplier;
-bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier)
-{
- if (auto c = as<ConstantIntVal>(operand))
- {
- val->constantTerm += c->value * multiplier;
- return true;
- }
- else if (auto poly = as<PolynomialIntVal>(operand))
- {
- val->constantTerm += poly->constantTerm * multiplier;
- for (auto term : poly->terms)
- {
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = multiplier * term->constFactor;
- newTerm->paramFactors = term->paramFactors;
- val->terms.add(newTerm);
- }
- return true;
- }
- else if (auto genVal = as<IntVal>(operand))
- {
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = multiplier;
- auto factor = astBuilder->create<PolynomialIntValFactor>();
- factor->power = 1;
- factor->param = genVal;
- term->paramFactors.add(factor);
- val->terms.add(term);
- return true;
- }
- return false;
-}
-
-PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
+IntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
{
- auto result = astBuilder->create<PolynomialIntVal>(base->type);
- if (!addToPolynomialTerm(astBuilder, result, base, -1))
- return nullptr;
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.addToPolynomialTerm(base, -1);
+ return builder.getIntVal(base->getType());
}
-PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+IntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
- auto result = astBuilder->create<PolynomialIntVal>(op0->type);
- if (!addToPolynomialTerm(astBuilder, result, op0, 1))
- return nullptr;
- if (!addToPolynomialTerm(astBuilder, result, op1, -1))
- return nullptr;
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.addToPolynomialTerm(op0, 1);
+ builder.addToPolynomialTerm(op1, -1);
+ return builder.getIntVal(op0->getType());
}
-PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+IntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
- auto result = astBuilder->create<PolynomialIntVal>(op0->type);
- if (!addToPolynomialTerm(astBuilder, result, op0, 1))
- return nullptr;
- if (!addToPolynomialTerm(astBuilder, result, op1, 1))
- return nullptr;
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.addToPolynomialTerm(op0, 1);
+ builder.addToPolynomialTerm(op1, 1);
+ return builder.getIntVal(op0->getType());
}
-PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
if (auto poly0 = as<PolynomialIntVal>(op0))
{
if (auto poly1 = as<PolynomialIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
+ PolynomialIntValBuilder builder(astBuilder);
// add poly0.constant * poly1.constant
- result->constantTerm = poly0->constantTerm * poly1->constantTerm;
+ builder.constantTerm = poly0->getConstantTerm() * poly1->getConstantTerm();
// add poly0.constant * poly1.terms
- if (poly0->constantTerm != 0)
+ if (poly0->getConstantTerm() != 0)
{
- for (auto term : poly1->terms)
+ for (auto term : poly1->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = poly0->constantTerm * term->constFactor;
- newTerm->paramFactors.addRange(term->paramFactors);
- result->terms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors());
+ builder.terms.add(newTerm);
}
}
// add poly1.constant * poly0.terms
- if (poly1->constantTerm != 0)
+ if (poly1->getConstantTerm() != 0)
{
- for (auto term : poly0->terms)
+ for (auto term : poly0->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = poly1->constantTerm * term->constFactor;
- newTerm->paramFactors.addRange(term->paramFactors);
- result->terms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ poly1->getConstantTerm() * term->getConstFactor(),
+ term->getParamFactors());
+ builder.terms.add(newTerm);
}
}
// add poly1.terms * poly0.terms
- for (auto term0 : poly0->terms)
+ for (auto term0 : poly0->getTerms())
{
- for (auto term1 : poly1->terms)
+ for (auto term1 : poly1->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term0->constFactor * term1->constFactor;
- newTerm->paramFactors.addRange(term0->paramFactors);
- newTerm->paramFactors.addRange(term1->paramFactors);
- result->terms.add(newTerm);
+ List<PolynomialIntValFactor*> newFactors;
+ for (auto f : term0->getParamFactors()) newFactors.add(f);
+ for (auto f : term1->getParamFactors()) newFactors.add(f);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView());
+ builder.terms.add(newTerm);
}
}
- result->canonicalize(astBuilder);
- return result;
+ return builder.getIntVal(op0->getType());
}
else if (auto cVal1 = as<ConstantIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
- result->constantTerm = poly0->constantTerm * cVal1->value;
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- for (auto term : poly0->terms)
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue();
+ for (auto term : poly0->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor * cVal1->value;
- newTerm->paramFactors.addRange(term->paramFactors);
- newTerm->paramFactors.add(factor1);
- result->terms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor() * cVal1->getValue(), term->getParamFactors());
+ builder.terms.add(newTerm);
}
- result->canonicalize(astBuilder);
- return result;
+ return builder.getIntVal(poly0->getType());
}
else if (auto val1 = as<IntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
- result->constantTerm = 0;
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- if (poly0->constantTerm != 0)
+ PolynomialIntValBuilder builder(astBuilder);
+ auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1);
+ if (poly0->getConstantTerm() != 0)
{
- auto term0 = astBuilder->create<PolynomialIntValTerm>();
- term0->constFactor = poly0->constantTerm;
- term0->paramFactors.add(factor1);
- result->terms.add(term0);
+ auto term0 = astBuilder->getOrCreate<PolynomialIntValTerm>(poly0->getConstantTerm(), makeArrayViewSingle(factor1));
+ builder.terms.add(term0);
}
- for (auto term : poly0->terms)
+ for (auto term : poly0->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor;
- newTerm->paramFactors.addRange(term->paramFactors);
- newTerm->paramFactors.add(factor1);
- result->terms.add(newTerm);
+ List<PolynomialIntValFactor*> newFactors;
+ for (auto f: term->getParamFactors())
+ newFactors.add(f);
+ newFactors.add(factor1);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ term->getConstFactor(), newFactors.getArrayView());
+ builder.terms.add(newTerm);
}
- result->canonicalize(astBuilder);
- return result;
+ return builder.getIntVal(poly0->getType());
}
else
return nullptr;
@@ -1058,184 +1011,48 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
}
else if (auto cVal1 = as<ConstantIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(val0->type);
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = cVal1->value;
- auto factor0 = astBuilder->create<PolynomialIntValFactor>();
- factor0->power = 1;
- factor0->param = val0;
- term->paramFactors.add(factor0);
- result->terms.add(term);
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1);
+ auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ cVal1->getValue(), makeArrayView(&factor0, 1));
+ builder.terms.add(term);
+ return builder.getIntVal(val0->getType());
}
else if (auto val1 = as<IntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(val0->type);
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = 1;
- auto factor0 = astBuilder->create<PolynomialIntValFactor>();
- factor0->power = 1;
- factor0->param = val0;
- term->paramFactors.add(factor0);
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- term->paramFactors.add(factor1);
- result->terms.add(term);
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1);
+ auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1);
+ PolynomialIntValFactor* newFactors[] = { factor0, factor1 };
+ auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(1, makeArrayView(newFactors));
+ builder.terms.add(term);
+ return builder.getIntVal(val0->getType());
}
}
return nullptr;
}
-IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
-{
- List<PolynomialIntValTerm*> newTerms;
- IntegerLiteralValue newConstantTerm = constantTerm;
- auto addTerm = [&](PolynomialIntValTerm* newTerm)
- {
- for (auto term : newTerms)
- {
- if (term->canCombineWith(*newTerm))
- {
- term->constFactor += newTerm->constFactor;
- return;
- }
- }
- newTerms.add(newTerm);
- };
- for (auto term : terms)
- {
- if (term->constFactor == 0)
- continue;
- List<PolynomialIntValFactor*> newFactors;
- List<bool> factorIsDifferent;
- for (Index i = 0; i < term->paramFactors.getCount(); i++)
- {
- auto factor = term->paramFactors[i];
- bool factorFound = false;
- for (Index j = 0; j < newFactors.getCount(); j++)
- {
- auto& newFactor = newFactors[j];
- if (factor->param->equalsVal(newFactor->param))
- {
- if (!factorIsDifferent[j])
- {
- factorIsDifferent[j] = true;
- auto clonedFactor = builder->create<PolynomialIntValFactor>();
- clonedFactor->param = newFactor->param;
- clonedFactor->power = newFactor->power;
- newFactor = clonedFactor;
- }
- newFactor->power += factor->power;
- factorFound = true;
- break;
- }
- }
- if (!factorFound)
- {
- newFactors.add(factor);
- factorIsDifferent.add(false);
- }
- }
- List<PolynomialIntValFactor*> newFactors2;
- for (auto factor : newFactors)
- {
- if (factor->power != 0)
- newFactors2.add(factor);
- }
- if (newFactors2.getCount() == 0)
- {
- newConstantTerm += term->constFactor;
- continue;
- }
- newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; });
- bool isDifferent = false;
- if (newFactors2.getCount() != term->paramFactors.getCount())
- isDifferent = true;
- if (!isDifferent)
- {
- for (Index i = 0; i < term->paramFactors.getCount(); i++)
- if (term->paramFactors[i] != newFactors2[i])
- {
- isDifferent = true;
- break;
- }
- }
- if (!isDifferent)
- {
- addTerm(term);
- }
- else
- {
- auto newTerm = builder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor;
- newTerm->paramFactors = _Move(newFactors2);
- addTerm(newTerm);
- }
- }
- List<PolynomialIntValTerm*> newTerms2;
- for (auto term : newTerms)
- {
- if (term->constFactor == 0)
- continue;
- newTerms2.add(term);
- }
- newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
- terms = _Move(newTerms2);
- constantTerm = newConstantTerm;
- if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 &&
- terms[0]->paramFactors[0]->power == 1)
- {
- return terms[0]->paramFactors[0]->param;
- }
- if (terms.getCount() == 0)
- return builder->getIntVal(type, constantTerm);
- return this;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool TypeCastIntVal::_equalsValOverride(Val* val)
-{
- if (auto typeCastIntVal = as<TypeCastIntVal>(val))
- {
- if (!type->equals(typeCastIntVal->type))
- return false;
- if (!base->equalsVal(typeCastIntVal->base))
- return false;
- return true;
- }
- return false;
-}
void TypeCastIntVal::_toTextOverride(StringBuilder& out)
{
- type->toText(out);
+ getType()->toText(out);
out << "(";
- base->toText(out);
+ getBase()->toText(out);
out << ")";
}
-HashCode TypeCastIntVal::_getHashCodeOverride()
-{
- HashCode result = type->getHashCode();
- result = combineHash(result, base->getHashCode());
- return result;
-}
-
Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
if (auto c = as<ConstantIntVal>(base))
{
- IntegerLiteralValue resultValue = c->value;
+ IntegerLiteralValue resultValue = c->getValue();
auto baseType = as<BasicExpressionType>(resultType);
if (baseType)
{
- switch (baseType->baseType)
+ switch (baseType->getBaseType())
{
case BaseType::Int:
resultValue = (int)resultValue;
@@ -1275,11 +1092,11 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val*
Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto substBase = base->substituteImpl(astBuilder, subst, &diff);
- if (substBase != base)
+ auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff);
+ if (substBase != getBase())
diff++;
- auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff));
- if (substType != type)
+ auto substType = as<Type>(getType()->substituteImpl(astBuilder, subst, &diff));
+ if (substType != getType())
diff++;
*ioDiff += diff;
if (diff)
@@ -1289,7 +1106,7 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
return newVal;
else
{
- auto result = astBuilder->create<TypeCastIntVal>(substType, substBase);
+ auto result = astBuilder->getOrCreate<TypeCastIntVal>(substType, substBase);
return result;
}
}
@@ -1297,29 +1114,20 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
return this;
}
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-bool FuncCallIntVal::_equalsValOverride(Val* val)
+Val* TypeCastIntVal::_resolveImplOverride()
{
- if (auto funcCallIntVal = as<FuncCallIntVal>(val))
- {
- if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef))
- return false;
- if (args.getCount() != funcCallIntVal->args.getCount())
- return false;
- for (Index i = 0; i < args.getCount(); i++)
- {
- if (!args[i]->equalsVal(funcCallIntVal->args[i]))
- return false;
- }
- return true;
- }
- return false;
+ if (auto resolved = tryFoldImpl(getCurrentASTBuilder(), getType(), getBase(), nullptr))
+ return resolved;
+ return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
void FuncCallIntVal::_toTextOverride(StringBuilder& out)
{
+ auto args = getArgs();
+ auto funcDeclRef = getFuncDeclRef();
+
auto argToText = [&](int index)
{
if (as<PolynomialIntVal>(args[index]) || as<FuncCallIntVal>(args[index]))
@@ -1369,14 +1177,37 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out)
}
}
-HashCode FuncCallIntVal::_getHashCodeOverride()
+Val* FuncCallIntVal::_resolveImplOverride()
{
- HashCode result = funcDeclRef.getHashCode();
+ auto astBuilder = getCurrentASTBuilder();
+ auto args = getArgs();
+ auto funcDeclRef = getFuncDeclRef();
+ auto funcType = getFuncType();
+
+ Val* resolvedVal = this;
+
+ auto newFuncDeclRef = as<DeclRefBase>(funcDeclRef.declRefBase->resolve());
+ if (!newFuncDeclRef)
+ return this;
+ bool diff = false;
+ List<IntVal*> newArgs;
for (auto arg : args)
{
- result = combineHash(result, arg->getHashCode());
+ auto newArg = as<IntVal>(arg->resolve());
+ if (!newArg)
+ return this;
+ newArgs.add(newArg);
+ if (newArg != arg)
+ diff = true;
}
- return result;
+
+ if (auto resolved = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr))
+ resolvedVal = resolved;
+ else if (diff)
+ {
+ resolvedVal = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, funcType, newArgs.getArrayView());
+ }
+ return resolvedVal;
}
Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
@@ -1413,25 +1244,25 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
#define BINARY_OPERATOR_CASE(op) \
if (opNameSlice == toSlice(#op)) \
{ \
- resultValue = constArgs[0]->value op constArgs[1]->value; \
+ resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \
} else
#define DIV_OPERATOR_CASE(op) \
if (opNameSlice == toSlice(#op)) \
{ \
- if (constArgs[1]->value == 0) \
+ if (constArgs[1]->getValue() == 0) \
{ \
if (sink) \
sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \
return nullptr; \
} \
- resultValue = constArgs[0]->value op constArgs[1]->value; \
+ resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \
} else
#define LOGICAL_OPERATOR_CASE(op) \
if (opNameSlice == toSlice(#op)) \
{ \
- resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \
+ resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \
} else
@@ -1463,9 +1294,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
LOGICAL_OPERATOR_CASE(&&)
LOGICAL_OPERATOR_CASE(||)
// Special cases need their "operator" names quoted.
- SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);)
- SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;)
- SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;)
+ SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);)
+ SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();)
+ SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();)
TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");)
return astBuilder->getIntVal(resultType, resultValue);
@@ -1483,9 +1314,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff);
+ auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff);
List<IntVal*> newArgs;
- for (auto& arg : args)
+ for (auto& arg : getArgs())
{
auto substArg = arg->substituteImpl(astBuilder, subst, &diff);
if (substArg != arg)
@@ -1496,15 +1327,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
if (diff)
{
// TODO: report diagnostics back.
- auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr);
+ auto newVal = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr);
if (newVal)
return newVal;
else
{
- auto result = astBuilder->create<FuncCallIntVal>(type);
- result->args = _Move(newArgs);
- result->funcDeclRef = newFuncDeclRef;
- result->funcType = funcType;
+ auto result = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView());
return result;
}
}
@@ -1514,40 +1342,47 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool WitnessLookupIntVal::_equalsValOverride(Val* val)
-{
- if (auto lookupIntVal = as<WitnessLookupIntVal>(val))
- {
- if (!witness->equalsVal(lookupIntVal->witness))
- return false;
- if (key != lookupIntVal->key)
- return false;
- return true;
- }
- return false;
-}
-
void WitnessLookupIntVal::_toTextOverride(StringBuilder& out)
{
- witness->sub->toText(out);
+ getWitness()->getSub()->toText(out);
out << ".";
- out << (key->getName() ? key->getName()->text : "??");
+ out << (getKey()->getName() ? getKey()->getName()->text : "??");
}
-HashCode WitnessLookupIntVal::_getHashCodeOverride()
+Val* WitnessLookupIntVal::_resolveImplOverride()
{
- HashCode result = witness->getHashCode();
- result = combineHash(result, Slang::getHashCode(key));
- return result;
+ auto astBuilder = getCurrentASTBuilder();
+
+ auto newWitness = as<SubtypeWitness>(getWitness()->resolve());
+ if (!newWitness)
+ return this;
+
+ auto witnessVal = tryLookUpRequirementWitness(astBuilder, newWitness, getKey());
+ if (witnessVal.getFlavor() == RequirementWitness::Flavor::val)
+ {
+ return witnessVal.getVal();
+ }
+
+ auto newType = as<Type>(getType()->resolve());
+ if (!newType)
+ return this;
+
+ if (newWitness != getWitness() || newType != getType())
+ {
+ return astBuilder->getOrCreate<WitnessLookupIntVal>(newType, newWitness, getKey());
+ }
+
+ return this;
}
+
Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto newWitness = witness->substituteImpl(astBuilder, subst, &diff);
+ auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff);
*ioDiff += diff;
if (diff)
{
- auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), key);
+ auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), getKey());
if (witnessEntry)
return witnessEntry;
}
@@ -1573,51 +1408,93 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes
{
if (auto result = tryFoldOrNull(astBuilder, witness, key))
return result;
- auto witnessResult = astBuilder->create<WitnessLookupIntVal>();
- witnessResult->witness = witness;
- witnessResult->key = key;
- witnessResult->type = type;
+ auto witnessResult = astBuilder->getOrCreate<WitnessLookupIntVal>(type, witness, key);
return witnessResult;
}
-
-bool DifferentiateVal::_equalsValOverride(Val* val)
-{
- if (auto other = as<DifferentiateVal>(val))
- {
- return other->astNodeType == astNodeType && other->func == func;
- }
- return false;
-}
-
void DifferentiateVal::_toTextOverride(StringBuilder& out)
{
out << "DifferentiateVal(";
- out << func;
+ out << getFunc();
out << ")";
}
-HashCode DifferentiateVal::_getHashCodeOverride()
-{
- HashCode result = (HashCode)astNodeType;
- result = combineHash(result, func.getHashCode());
- return result;
-}
-
Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto newFunc = func.substituteImpl(astBuilder, subst, &diff);
+ auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff);
*ioDiff += diff;
if (diff)
{
auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType));
- result->func = newFunc;
+ result->getFunc() = newFunc;
return result;
}
// Nothing found: don't substitute.
return this;
}
+Val* DifferentiateVal::_resolveImplOverride()
+{
+ return this;
+}
+
+Val* PolynomialIntValFactor::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ auto newParam = as<IntVal>(getParam()->resolve());
+ if (newParam && newParam != getParam())
+ return astBuilder->getOrCreate<PolynomialIntValFactor>(newParam, getPower());
+
+ return this;
+}
+
+Val* PolynomialIntValTerm::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ bool diff = false;
+ List<PolynomialIntValFactor*> newFactors;
+ for (auto factor : getParamFactors())
+ {
+ auto newFactor = as<PolynomialIntValFactor>(factor->resolve());
+ if (!newFactor)
+ return this;
+
+ if (newFactor != factor)
+ diff = true;
+ newFactors.add(newFactor);
+ }
+
+ if (diff)
+ return astBuilder->getOrCreate<PolynomialIntValTerm>(getConstFactor(), newFactors.getArrayView());
+
+ return this;
+}
+
+Val* PolynomialIntVal::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ bool diff = false;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.constantTerm = getConstantTerm();
+ for (auto term : getTerms())
+ {
+ auto newTerm = as<PolynomialIntValTerm>(term->resolve());
+ if (!newTerm)
+ return this;
+
+ if (newTerm != term)
+ diff = true;
+ builder.terms.add(newTerm);
+ }
+
+ if (diff)
+ return builder.getIntVal(getType());
+
+ return this;
+}
} // namespace Slang