From a2d90fb275962da84611160f8ddd74d934a68dbd Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 4 Aug 2023 15:47:39 -0700 Subject: Redesign `DeclRef` and systematic `Val` deduplication (#3049) * Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He --- source/slang/slang-ast-val.cpp | 1477 ++++++++++++++++++---------------------- 1 file changed, 677 insertions(+), 800 deletions(-) (limited to 'source/slang/slang-ast-val.cpp') diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index b45300af8..056577eb0 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -6,9 +6,47 @@ #include "slang-generated-ast-macro.h" #include "slang-diagnostics.h" #include "slang-syntax.h" +#include "slang-ast-val.h" namespace Slang { + +bool ValNodeDesc::operator==(ValNodeDesc const& that) const +{ + if (hashCode != that.hashCode) return false; + if (type != that.type) return false; + if (operands.getCount() != that.operands.getCount()) return false; + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are comparing the operands directly for identity + // (pointer equality) rather than doing the `Val`-level + // equality check. + // + // The rationale here is that nodes that will be created + // via a `NodeDesc` *should* all be going through the + // deduplication path anyway, as should their operands. + // + if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; + } + return true; +} + +void ValNodeDesc::init() +{ + Hasher hasher; + hasher.hashValue(Int(type)); + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are hashing the raw pointer value rather + // than the content of the value node. This is done + // to match the semantics implemented for `==` on + // `NodeDesc`. + // + hasher.hashValue(operands[i].values.nodeOperand); + } + hashCode = hasher.getResult(); +} + Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) { if (!subst) return this; @@ -21,14 +59,103 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff)) } -bool Val::equalsVal(Val* val) +void Val::toText(StringBuilder& out) { - SLANG_AST_NODE_VIRTUAL_CALL(Val, equalsVal, (val)) + SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) } -void Val::toText(StringBuilder& out) +Val* Val::_resolveImplOverride() { - SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) + SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden"); +} + +Val* Val::resolveImpl() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, resolveImpl, ()); +} + +Val* Val::resolve() +{ + auto astBuilder = getCurrentASTBuilder(); + + // If we are not in a proper checking context, just return the previously resolved val. + if (!astBuilder) + return m_resolvedVal? m_resolvedVal : this; + if (m_resolvedVal && m_resolvedValEpoch == getCurrentASTBuilder()->getEpoch()) + { + SLANG_ASSERT(as(m_resolvedVal)); + return m_resolvedVal; + } + + // Update epoch now to avoid infinite recursion. + m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch(); + m_resolvedVal = this; + m_resolvedVal = resolveImpl(); + + // Check if we are resolved to an existing Val in the AST cache. + ValNodeDesc newDesc; + newDesc.type = m_resolvedVal->astNodeType; + for (auto operand : m_resolvedVal->m_operands) + { + if (operand.kind == ValNodeOperandKind::ValNode) + { + auto valOperand = as(operand.values.nodeOperand); + if (valOperand) + { + operand.values.nodeOperand = valOperand->resolve(); + } + } + newDesc.operands.add(operand); + } + newDesc.init(); + + NodeBase* existingNode = nullptr; + if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) + m_resolvedVal = as(existingNode); + +#ifdef _DEBUG + if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0) + { + //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking."); + } +#endif + return m_resolvedVal; +} + +ValNodeDesc Val::getDesc() +{ + ValNodeDesc desc; + desc.type = astNodeType; + for (auto operand : m_operands) + desc.operands.add(operand); + desc.init(); + return desc; +} + +Val* Val::defaultResolveImpl() +{ + // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache. + ValNodeDesc newDesc; + newDesc.type = astNodeType; + for (auto operand : m_operands) + { + if (operand.kind == ValNodeOperandKind::ValNode) + { + auto valOperand = as(operand.values.nodeOperand); + if (valOperand) + { + operand.values.nodeOperand = valOperand->resolve(); + } + } + newDesc.operands.add(operand); + } + newDesc.init(); + auto astBuilder = getCurrentASTBuilder(); + + NodeBase* existingNode = nullptr; + if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) + return as(existingNode); + return this; } String Val::toString() @@ -40,7 +167,7 @@ String Val::toString() HashCode Val::getHashCode() { - SLANG_AST_NODE_VIRTUAL_CALL(Val, getHashCode, ()) + return Slang::getHashCode(resolve()); } Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) @@ -52,124 +179,84 @@ Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, return this; } -bool Val::_equalsValOverride(Val* val) -{ - SLANG_UNUSED(val); - SLANG_UNEXPECTED("Val::_equalsValOverride not overridden"); - //return false; -} - void Val::_toTextOverride(StringBuilder& out) { SLANG_UNUSED(out); SLANG_UNEXPECTED("Val::_toStringOverride not overridden"); } -HashCode Val::_getHashCodeOverride() -{ - SLANG_UNEXPECTED("Val::_getHashCodeOverride not overridden"); - //return HashCode(0); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConstantIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ConstantIntVal::_equalsValOverride(Val* val) -{ - if (auto intVal = as(val)) - return value == intVal->value; - return false; -} - void ConstantIntVal::_toTextOverride(StringBuilder& out) { - out << value; -} - -HashCode ConstantIntVal::_getHashCodeOverride() -{ - return (HashCode)value; + out << getValue(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool GenericParamIntVal::_equalsValOverride(Val* val) -{ - if (auto genericParamVal = as(val)) - { - return declRef.equals(genericParamVal->declRef); - } - return false; -} - void GenericParamIntVal::_toTextOverride(StringBuilder& out) { - Name* name = declRef.getName(); + Name* name = getDeclRef().getName(); if (name) { out << name->text; } } -HashCode GenericParamIntVal::_getHashCodeOverride() -{ - return declRef.getHashCode() ^ HashCode(0xFFFF); -} - Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff) { // search for a substitution that might apply to us - for (auto s = subst.substitutions; s; s = s->getOuter()) + auto outerGeneric = as(paramDecl->parentDecl); + if (!outerGeneric) + return paramVal; + + GenericAppDeclRef* genAppArgs = subst.findGenericAppDeclRef(outerGeneric); + if (!genAppArgs) { - auto genSubst = as(s); - if (!genSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genSubst->getGenericDecl(); - if (genericDecl != paramDecl->parentDecl) - continue; - - // In some cases, we construct a `DeclRef` to a `GenericDecl` - // (or a declaration under one) that only includes argument - // values for a prefix of the parameters of the generic. - // - // If we aren't careful, we could end up indexing into the - // argument list past the available range. - // - Count argCount = genSubst->getArgs().getCount(); + return paramVal; + } - Count argIndex = 0; - for (auto m : genericDecl->members) + auto args = genAppArgs->getArgs(); + + // In some cases, we construct a `DeclRef` to a `GenericDecl` + // (or a declaration under one) that only includes argument + // values for a prefix of the parameters of the generic. + // + // If we aren't careful, we could end up indexing into the + // argument list past the available range. + // + Count argCount = args.getCount(); + + Count argIndex = 0; + for (auto m : outerGeneric->members) + { + // If we have run out of arguments, then we can stop + // iterating over the parameters, because `this` + // parameter will not be replaced with anything by + // the substituion. + // + if (argIndex >= argCount) { - // If we have run out of arguments, then we can stop - // iterating over the parameters, because `this` - // parameter will not be replaced with anything by - // the substituion. - // - if (argIndex >= argCount) - { - return paramVal; - } + return paramVal; + } - if (m == paramDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genSubst->getArgs()[argIndex]; - } - else if (const auto typeParam = as(m)) - { - argIndex++; - } - else if (const auto valParam = as(m)) - { - argIndex++; - } - else - { - } + if (m == paramDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return args[argIndex]; + } + else if (const auto typeParam = as(m)) + { + argIndex++; + } + else if (const auto valParam = as(m)) + { + argIndex++; + } + else + { } } @@ -180,7 +267,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) { - if (auto result = maybeSubstituteGenericParam(this, declRef.getDecl(), subst, ioDiff)) + if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff)) return result; return this; @@ -188,21 +275,11 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ErrorIntVal::_equalsValOverride(Val* val) -{ - return as(val); -} - void ErrorIntVal::_toTextOverride(StringBuilder& out) { out << toSlice(""); } -HashCode ErrorIntVal::_getHashCodeOverride() -{ - return HashCode(typeid(this).hash_code()); -} - Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { SLANG_UNUSED(astBuilder); @@ -211,97 +288,110 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe return this; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -// TODO: should really have a `type.cpp` and a `witness.cpp` - -bool TypeEqualityWitness::_equalsValOverride(Val* val) -{ - auto otherWitness = as(val); - if (!otherWitness) - return false; - return sub->equals(otherWitness->sub); -} - Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - TypeEqualityWitness* rs = astBuilder->create(); - rs->sub = as(sub->substituteImpl(astBuilder, subst, ioDiff)); - rs->sup = as(sup->substituteImpl(astBuilder, subst, ioDiff)); + auto type = as(getSub()->substituteImpl(astBuilder, subst, ioDiff)); + TypeEqualityWitness* rs = astBuilder->getOrCreate(type, type); return rs; } void TypeEqualityWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("TypeEqualityWitness(") << sub << toSlice(")"); -} - -HashCode TypeEqualityWitness::_getHashCodeOverride() -{ - return sub->getHashCode(); + out << toSlice("TypeEqualityWitness(") << getSub() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclaredSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool DeclaredSubtypeWitness::_equalsValOverride(Val* val) +Val* DeclaredSubtypeWitness::_resolveImplOverride() { - auto otherWitness = as(val); - if (!otherWitness) - return false; + auto resolvedDeclRef = getDeclRef().declRefBase->resolve(); + if (auto resolvedVal = as(resolvedDeclRef)) + return resolvedVal; - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && declRef.equals(otherWitness->declRef); + auto newSub = as(getSub()->resolve()); + auto newSup = as(getSup()->resolve()); + + // If we are trying to lookup for a witness that A<:B from a witness(A<:B), we + // can just return the witness itself. + if (auto lookupDeclRef = as(resolvedDeclRef)) + { + auto witnessToLookupFrom = lookupDeclRef->getWitness(); + if (witnessToLookupFrom->getSub()->equals(newSub) && + witnessToLookupFrom->getSup()->equals(newSup)) + return witnessToLookupFrom; + } + auto newDeclRef = as(resolvedDeclRef); + if (!newDeclRef) + newDeclRef = getDeclRef().declRefBase; + if (newSub != getSub() || newSup != getSup() || newDeclRef != getDeclRef()) + { + return getCurrentASTBuilder()->getDeclaredSubtypeWitness(newSub, newSup, newDeclRef); + } + return this; } Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - if (auto genConstraintDeclRef = declRef.as()) + if (auto genConstraintDeclRef = getDeclRef().as()) { - auto genConstraintDecl = genConstraintDeclRef.getDecl(); + auto genericDecl = as(getDeclRef().getDecl()->parentDecl); + if (!genericDecl) + goto breakLabel; // search for a substitution that might apply to us - for (auto s = subst.substitutions; s; s = s->getOuter()) + auto args = tryGetGenericArguments(subst, genericDecl); + if (args.getCount() == 0) + goto breakLabel; + + bool found = false; + Index index = 0; + for (auto m : genericDecl->members) { - if (auto genericSubst = as(s)) + if (auto constraintParam = as(m)) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->getGenericDecl(); - if (genericDecl != genConstraintDecl->parentDecl) - continue; - - bool found = false; - Index index = 0; - for (auto m : genericDecl->members) + if (constraintParam == getDeclRef().getDecl()) { - if (auto constraintParam = as(m)) - { - if (constraintParam == declRef.getDecl()) - { - found = true; - break; - } - index++; - } - } - if (found) - { - (*ioDiff)++; - auto ordinaryParamCount = genericDecl->getMembersOfType().getCount() + - genericDecl->getMembersOfType().getCount(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount()); - return genericSubst->getArgs()[index + ordinaryParamCount]; + found = true; + break; } + index++; + } + } + if (found) + { + auto ordinaryParamCount = genericDecl->getMembersOfType().getCount() + + genericDecl->getMembersOfType().getCount(); + if (index + ordinaryParamCount < args.getCount()) + { + (*ioDiff)++; + return args[index + ordinaryParamCount]; + } + else + { + // When the `subst` represents a partial substitution, we may not have a corresponding argument. + // In this case we just return the original witness. + // + goto breakLabel; } } } + else if (auto thisTypeConstraintDeclRef = getDeclRef().as()) + { + auto lookupSubst = subst.findLookupDeclRef(); + if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) + { + (*ioDiff)++; + return lookupSubst->getWitness(); + } + } + +breakLabel:; // Perform substitution on the constituent elements. int diff = 0; - auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + auto substSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as(getSup()->substituteImpl(astBuilder, subst, &diff)); + if (!diff) return this; @@ -317,7 +407,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // so we'll need to change this location in the code if we ever clean // up the hierarchy. // - if (auto substTypeConstraintDecl = as(substDeclRef.getDecl())) + if (auto substTypeConstraintDecl = as(getDeclRef().getDecl())) { if (auto substAssocTypeDecl = as(substTypeConstraintDecl->parentDecl)) { @@ -326,12 +416,12 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // At this point we have a constraint decl for an associated type, // and we nee to see if we are dealing with a concrete substitution // for the interface around that associated type. - if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.getSubst(), interfaceDecl)) + if (auto thisTypeWitness = findThisTypeWitness(subst, interfaceDecl)) { // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); switch (requirementWitness.getFlavor()) { default: @@ -348,6 +438,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } } + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); auto rs = astBuilder->getDeclaredSubtypeWitness( substSub, substSup, substDeclRef); return rs; @@ -355,34 +446,17 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("DeclaredSubtypeWitness(") << sub << toSlice(", ") << sup << toSlice(", ") << declRef << toSlice(")"); -} - -HashCode DeclaredSubtypeWitness::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool TransitiveSubtypeWitness::_equalsValOverride(Val* val) -{ - auto otherWitness = as(val); - if (!otherWitness) - return false; - - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && subToMid->equalsVal(otherWitness->subToMid) - && midToSup->equalsVal(otherWitness->midToSup); -} - Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { int diff = 0; - SubtypeWitness* substSubToMid = as(subToMid->substituteImpl(astBuilder, subst, &diff)); - SubtypeWitness* substMidToSup = as(midToSup->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substSubToMid = as(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substMidToSup = as(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -407,16 +481,7 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) // witnesses, and rely on them to print // the starting and ending types. - out << toSlice("TransitiveSubtypeWitness(") << subToMid << toSlice(", ") << midToSup << toSlice(")"); -} - -HashCode TransitiveSubtypeWitness::_getHashCodeOverride() -{ - auto hash = sub->getHashCode(); - hash = combineHash(hash, sup->getHashCode()); - hash = combineHash(hash, subToMid->getHashCode()); - hash = combineHash(hash, midToSup->getHashCode()); - return hash; + out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -425,9 +490,9 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a { int diff = 0; - auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); - auto substWitness = as(conjunctionWitness->substituteImpl(astBuilder, subst, &diff)); + auto substSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as(getSup()->substituteImpl(astBuilder, subst, &diff)); + auto substWitness = as(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -447,138 +512,34 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a // simplification logic as needed. // return astBuilder->getExtractFromConjunctionSubtypeWitness( - substSub, substSup, substWitness, indexInConjunction); + substSub, substSup, substWitness, getIndexInConjunction()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) -{ - if (auto extractWitness = as(val)) - { - return declRef.equals(extractWitness->declRef); - } - return false; -} - void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("extractExistentialValue(") << declRef << toSlice(")"); -} - -HashCode ExtractExistentialSubtypeWitness::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")"); } Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as(getSup()->substituteImpl(astBuilder, subst, &diff)); if (!diff) return this; (*ioDiff)++; - ExtractExistentialSubtypeWitness* substValue = astBuilder->create(); - substValue->declRef = substDeclRef; - substValue->sub = substSub; - substValue->sup = substSup; + ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate( + substSub, substSup, substDeclRef); return substValue; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val) -{ - auto taggedUnionWitness = as(val); - if (!taggedUnionWitness) - return false; - - auto caseCount = caseWitnesses.getCount(); - if (caseCount != taggedUnionWitness->caseWitnesses.getCount()) - return false; - - for (Index ii = 0; ii < caseCount; ++ii) - { - if (!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii])) - return false; - } - - return true; -} - -void TaggedUnionSubtypeWitness::_toTextOverride(StringBuilder& out) -{ - out << toSlice("TaggedUnionSubtypeWitness("); - bool first = true; - for (auto caseWitness : caseWitnesses) - { - if (!first) - { - out << toSlice(", "); - } - first = false; - - out << caseWitness; - } - out << toSlice(")"); -} - -HashCode TaggedUnionSubtypeWitness::_getHashCodeOverride() -{ - HashCode hash = 0; - for (auto caseWitness : caseWitnesses) - { - hash = combineHash(hash, caseWitness->getHashCode()); - } - return hash; -} - -Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); - - List substCaseWitnesses; - for (auto caseWitness : caseWitnesses) - { - substCaseWitnesses.add( - as(caseWitness->substituteImpl(astBuilder, subst, &diff))); - } - - if (!diff) - return this; - - (*ioDiff)++; - - TaggedUnionSubtypeWitness* substWitness = astBuilder->create(); - substWitness->sub = substSub; - substWitness->sup = substSup; - substWitness->caseWitnesses.swapWith(substCaseWitnesses); - return substWitness; -} - -bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val) -{ - auto other = as(val); - if (!other) - return false; - - for (Index i = 0; i < kComponentCount; ++i) - { - if (!other->componentWitnesses[i]) return false; - if (!other->componentWitnesses[i]->equalsVal(componentWitnesses[i])) return false; - } - return true; -} - void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ConjunctionSubtypeWitness("; @@ -586,34 +547,23 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { if (i != 0) out << ","; - auto w = componentWitnesses[i]; + auto w = getComponentWitness(i); if (w) out << w; } out << ")"; } -HashCode ConjunctionSubtypeWitness::_getHashCodeOverride() -{ - HashCode result = 0; - for (Index i = 0; i < kComponentCount; ++i) - { - auto w = componentWitnesses[i]; - if (w) result = combineHash(result, w->getHashCode()); - } - return result; -} - Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; Val* substComponentWitnesses[kComponentCount]; - auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); + auto substSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as(getSup()->substituteImpl(astBuilder, subst, &diff)); for (Index i = 0; i < kComponentCount; ++i) { - auto w = componentWitnesses[i]; + auto w = getComponentWitness(i); substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr; } @@ -630,65 +580,25 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, auto result = astBuilder->getConjunctionSubtypeWitness( substSub, substSup, - componentWitnesses[0], - componentWitnesses[1]); + as(substComponentWitnesses[0]), + as(substComponentWitnesses[1])); return result; } -bool ExtractFromConjunctionSubtypeWitness::_equalsValOverride(Val* val) -{ - if (auto other = as(val)) - { - if(!sub->equals(other->sub)) return false; - if(!sup->equals(other->sup)) return false; - if(indexInConjunction != other->indexInConjunction) return false; - - return true; - } - return false; -} - void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ExtractFromConjunctionSubtypeWitness("; - if (conjunctionWitness) - out << conjunctionWitness; - if (sub) - out << sub; + if (getConjunctionWitness()) + out << getConjunctionWitness(); + if (getSub()) + out << getSub(); out << ","; - if (sup) - out << sup; - out << "," << indexInConjunction; + if (getSup()) + out << getSup(); + out << "," << getIndexInConjunction(); out << ")"; } -HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride() -{ - return combineHash( - conjunctionWitness ? conjunctionWitness->getHashCode() : 0, - sub ? sub->getHashCode() : 0, - sup ? sup->getHashCode() : 0, - indexInConjunction); -} - -// ModifierVal - -bool ModifierVal::_equalsValOverride(Val* val) -{ - // TODO: This is assuming we can fully deduplicate the values that represent - // modifiers, which may not actually be the case if there are multiple modules - // being combined that use different `ASTBuilder`s. - // - return this == val; -} - -HashCode ModifierVal::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashValue((void*) this); - return hasher.getResult(); -} - // UNormModifierVal void UNormModifierVal::_toTextOverride(StringBuilder& out) @@ -735,48 +645,14 @@ Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitu // PolynomialIntVal -bool PolynomialIntVal::_equalsValOverride(Val* val) -{ - if (auto genericParamVal = as(val)) - { - return constantTerm == 0 && terms.getCount() == 1 && - terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 && - terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) && - terms[0]->paramFactors[0]->power == 1; - } - else if (auto otherPolynomial = as(val)) - { - if (constantTerm != otherPolynomial->constantTerm) - return false; - if (terms.getCount() != otherPolynomial->terms.getCount()) - return false; - for (Index i = 0; i < terms.getCount(); i++) - { - auto& thisTerm = *(terms[i]); - auto& thatTerm = *(otherPolynomial->terms[i]); - if (thisTerm.constFactor != thatTerm.constFactor) - return false; - if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount()) - return false; - for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++) - { - if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power) - return false; - if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param)) - return false; - } - } - return true; - } - return false; -} - void PolynomialIntVal::_toTextOverride(StringBuilder& out) { + auto constantTerm = getConstantTerm(); + auto terms = getTerms(); for (Index i = 0; i < terms.getCount(); i++) { auto& term = *(terms[i]); - if (term.constFactor > 0) + if (term.getConstFactor() > 0) { if (i > 0) out << "+"; @@ -784,14 +660,14 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) else out << "-"; bool isFirstFactor = true; - if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0) + if (abs(term.getConstFactor()) != 1 || term.getParamFactors().getCount() == 0) { - out << abs(term.constFactor); + out << abs(term.getConstFactor()); isFirstFactor = false; } - for (Index j = 0; j < term.paramFactors.getCount(); j++) + for (Index j = 0; j < term.getParamFactors().getCount(); j++) { - auto factor = term.paramFactors[j]; + auto factor = term.getParamFactors()[j]; if (isFirstFactor) { isFirstFactor = false; @@ -800,10 +676,10 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) { out << "*"; } - factor->param->toText(out); - if (factor->power != 1) + factor->getParam()->toText(out); + if (factor->getPower() != 1) { - out << "^^" << factor->power; + out << "^^" << factor->getPower(); } } } @@ -821,227 +697,304 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) } } -HashCode PolynomialIntVal::_getHashCodeOverride() +struct PolynomialIntValBuilder { - HashCode result = (HashCode)constantTerm; - for (auto& term : terms) + ASTBuilder* astBuilder; + + IntegerLiteralValue constantTerm = 0; + List terms; + + PolynomialIntValBuilder(ASTBuilder* inAstBuilder) + : astBuilder(inAstBuilder) + {} + + // compute val += opreand*multiplier; + bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier) { - if (!term) continue; - result = combineHash(result, (HashCode)term->constFactor); - for (auto& factor : term->paramFactors) + if (auto c = as(operand)) { - result = combineHash(result, factor->param->getHashCode()); - result = combineHash(result, (HashCode)factor->power); + constantTerm += c->getValue() * multiplier; + return true; } + else if (auto poly = as(operand)) + { + constantTerm += poly->getConstantTerm() * multiplier; + for (auto term : poly->getTerms()) + { + auto newTerm = astBuilder->getOrCreate( + multiplier * term->getConstFactor(), term->getParamFactors()); + terms.add(newTerm); + } + return true; + } + else if (auto genVal = as(operand)) + { + auto factor = astBuilder->getOrCreate(genVal, 1); + auto term = astBuilder->getOrCreate(multiplier, makeArrayViewSingle(factor)); + terms.add(term); + return true; + } + return false; } - return result; -} + + IntVal* canonicalize(Type* type) + { + List newTerms; + IntegerLiteralValue newConstantTerm = constantTerm; + auto addTerm = [&](PolynomialIntValTerm* newTerm) + { + for (auto& term : newTerms) + { + if (term->canCombineWith(*newTerm)) + { + term = astBuilder->getOrCreate( + term->getConstFactor() + newTerm->getConstFactor(), + term->getParamFactors()); + return; + } + } + newTerms.add(newTerm); + }; + for (auto term : terms) + { + if (term->getConstFactor() == 0) + continue; + List newFactors; + List factorIsDifferent; + for (Index i = 0; i < term->getParamFactors().getCount(); i++) + { + auto factor = term->getParamFactors()[i]; + bool factorFound = false; + for (Index j = 0; j < newFactors.getCount(); j++) + { + auto& newFactor = newFactors[j]; + if (factor->getParam()->equals(newFactor->getParam())) + { + if (!factorIsDifferent[j]) + { + factorIsDifferent[j] = true; + auto clonedFactor = astBuilder->getOrCreate(newFactor->getParam(), newFactor->getPower()); + newFactor = clonedFactor; + } + newFactor = astBuilder->getOrCreate(newFactor->getParam(), newFactor->getPower() + factor->getPower()); + factorFound = true; + break; + } + } + if (!factorFound) + { + newFactors.add(factor); + factorIsDifferent.add(false); + } + } + List newFactors2; + // Remove zero-powered factors. + for (auto factor : newFactors) + { + if (factor->getPower() != 0) + newFactors2.add(factor); + } + if (newFactors2.getCount() == 0) + { + newConstantTerm += term->getConstFactor(); + continue; + } + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + bool isDifferent = false; + if (newFactors2.getCount() != term->getParamFactors().getCount()) + isDifferent = true; + if (!isDifferent) + { + for (Index i = 0; i < term->getParamFactors().getCount(); i++) + if (term->getParamFactors()[i] != newFactors2[i]) + { + isDifferent = true; + break; + } + } + if (!isDifferent) + { + addTerm(term); + } + else + { + auto newTerm = astBuilder->getOrCreate(term->getConstFactor(), newFactors2.getArrayView()); + addTerm(newTerm); + } + } + List newTerms2; + for (auto term : newTerms) + { + if (term->getConstFactor() == 0) + continue; + newTerms2.add(term); + } + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + terms = _Move(newTerms2); + constantTerm = newConstantTerm; + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 && + terms[0]->getParamFactors()[0]->getPower() == 1) + { + return terms[0]->getParamFactors()[0]->getParam(); + } + if (terms.getCount() == 0) + return astBuilder->getIntVal(type, constantTerm); + return nullptr; + } + + IntVal* getIntVal(Type* type) + { + if (auto canVal = canonicalize(type)) + return canVal; + return astBuilder->getOrCreate(type, constantTerm, terms.getArrayView()); + } +}; Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - IntegerLiteralValue evaluatedConstantTerm = constantTerm; - List evaluatedTerms; - for (auto& term : terms) + PolynomialIntValBuilder builder(astBuilder); + for (auto& term : getTerms()) { IntegerLiteralValue evaluatedTermConstFactor; List evaluatedTermParamFactors; - evaluatedTermConstFactor = term->constFactor; - for (auto& factor : term->paramFactors) + evaluatedTermConstFactor = term->getConstFactor(); + for (auto& factor : term->getParamFactors()) { - auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); + auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff); if (auto constantVal = as(substResult)) { - evaluatedTermConstFactor *= constantVal->value; + evaluatedTermConstFactor *= constantVal->getValue(); } else if (auto intResult = as(substResult)) { - auto newFactor = astBuilder->create(); - newFactor->param = intResult; - newFactor->power = factor->power; + auto newFactor = astBuilder->getOrCreate(intResult, factor->getPower()); evaluatedTermParamFactors.add(newFactor); } } if (evaluatedTermParamFactors.getCount() == 0) { - evaluatedConstantTerm += evaluatedTermConstFactor; + builder.constantTerm += evaluatedTermConstFactor; } else { - auto newTerm = astBuilder->create(); - newTerm->paramFactors = _Move(evaluatedTermParamFactors); - newTerm->constFactor = evaluatedTermConstFactor; - evaluatedTerms.add(newTerm); + auto newTerm = astBuilder->getOrCreate( + evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView()); + builder.terms.add(newTerm); } } *ioDiff += diff; - if (evaluatedTerms.getCount() == 0) - return astBuilder->getIntVal(type, evaluatedConstantTerm); + if (builder.terms.getCount() == 0) + return astBuilder->getIntVal(getType(), builder.constantTerm); if (diff != 0) { - auto newPolynomial = astBuilder->create(type); - newPolynomial->constantTerm = evaluatedConstantTerm; - newPolynomial->terms = _Move(evaluatedTerms); - return newPolynomial->canonicalize(astBuilder); + return builder.getIntVal(getType()); } return this; } - -// compute val += opreand*multiplier; -bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) -{ - if (auto c = as(operand)) - { - val->constantTerm += c->value * multiplier; - return true; - } - else if (auto poly = as(operand)) - { - val->constantTerm += poly->constantTerm * multiplier; - for (auto term : poly->terms) - { - auto newTerm = astBuilder->create(); - newTerm->constFactor = multiplier * term->constFactor; - newTerm->paramFactors = term->paramFactors; - val->terms.add(newTerm); - } - return true; - } - else if (auto genVal = as(operand)) - { - auto term = astBuilder->create(); - term->constFactor = multiplier; - auto factor = astBuilder->create(); - factor->power = 1; - factor->param = genVal; - term->paramFactors.add(factor); - val->terms.add(term); - return true; - } - return false; -} - -PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) +IntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) { - auto result = astBuilder->create(base->type); - if (!addToPolynomialTerm(astBuilder, result, base, -1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(base, -1); + return builder.getIntVal(base->getType()); } -PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create(op0->type); - if (!addToPolynomialTerm(astBuilder, result, op0, 1)) - return nullptr; - if (!addToPolynomialTerm(astBuilder, result, op1, -1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(op0, 1); + builder.addToPolynomialTerm(op1, -1); + return builder.getIntVal(op0->getType()); } -PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create(op0->type); - if (!addToPolynomialTerm(astBuilder, result, op0, 1)) - return nullptr; - if (!addToPolynomialTerm(astBuilder, result, op1, 1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(op0, 1); + builder.addToPolynomialTerm(op1, 1); + return builder.getIntVal(op0->getType()); } -PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { if (auto poly0 = as(op0)) { if (auto poly1 = as(op1)) { - auto result = astBuilder->create(poly0->type); + PolynomialIntValBuilder builder(astBuilder); // add poly0.constant * poly1.constant - result->constantTerm = poly0->constantTerm * poly1->constantTerm; + builder.constantTerm = poly0->getConstantTerm() * poly1->getConstantTerm(); // add poly0.constant * poly1.terms - if (poly0->constantTerm != 0) + if (poly0->getConstantTerm() != 0) { - for (auto term : poly1->terms) + for (auto term : poly1->getTerms()) { - auto newTerm = astBuilder->create(); - newTerm->constFactor = poly0->constantTerm * term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate( + poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors()); + builder.terms.add(newTerm); } } // add poly1.constant * poly0.terms - if (poly1->constantTerm != 0) + if (poly1->getConstantTerm() != 0) { - for (auto term : poly0->terms) + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create(); - newTerm->constFactor = poly1->constantTerm * term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate( + poly1->getConstantTerm() * term->getConstFactor(), + term->getParamFactors()); + builder.terms.add(newTerm); } } // add poly1.terms * poly0.terms - for (auto term0 : poly0->terms) + for (auto term0 : poly0->getTerms()) { - for (auto term1 : poly1->terms) + for (auto term1 : poly1->getTerms()) { - auto newTerm = astBuilder->create(); - newTerm->constFactor = term0->constFactor * term1->constFactor; - newTerm->paramFactors.addRange(term0->paramFactors); - newTerm->paramFactors.addRange(term1->paramFactors); - result->terms.add(newTerm); + List newFactors; + for (auto f : term0->getParamFactors()) newFactors.add(f); + for (auto f : term1->getParamFactors()) newFactors.add(f); + auto newTerm = astBuilder->getOrCreate( + term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView()); + builder.terms.add(newTerm); } } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(op0->getType()); } else if (auto cVal1 = as(op1)) { - auto result = astBuilder->create(poly0->type); - result->constantTerm = poly0->constantTerm * cVal1->value; - auto factor1 = astBuilder->create(); - for (auto term : poly0->terms) + PolynomialIntValBuilder builder(astBuilder); + builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue(); + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create(); - newTerm->constFactor = term->constFactor * cVal1->value; - newTerm->paramFactors.addRange(term->paramFactors); - newTerm->paramFactors.add(factor1); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate(term->getConstFactor() * cVal1->getValue(), term->getParamFactors()); + builder.terms.add(newTerm); } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(poly0->getType()); } else if (auto val1 = as(op1)) { - auto result = astBuilder->create(poly0->type); - result->constantTerm = 0; - auto factor1 = astBuilder->create(); - factor1->power = 1; - factor1->param = val1; - if (poly0->constantTerm != 0) + PolynomialIntValBuilder builder(astBuilder); + auto factor1 = astBuilder->getOrCreate(val1, 1); + if (poly0->getConstantTerm() != 0) { - auto term0 = astBuilder->create(); - term0->constFactor = poly0->constantTerm; - term0->paramFactors.add(factor1); - result->terms.add(term0); + auto term0 = astBuilder->getOrCreate(poly0->getConstantTerm(), makeArrayViewSingle(factor1)); + builder.terms.add(term0); } - for (auto term : poly0->terms) + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create(); - newTerm->constFactor = term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - newTerm->paramFactors.add(factor1); - result->terms.add(newTerm); + List newFactors; + for (auto f: term->getParamFactors()) + newFactors.add(f); + newFactors.add(factor1); + auto newTerm = astBuilder->getOrCreate( + term->getConstFactor(), newFactors.getArrayView()); + builder.terms.add(newTerm); } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(poly0->getType()); } else return nullptr; @@ -1058,184 +1011,48 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto cVal1 = as(op1)) { - auto result = astBuilder->create(val0->type); - auto term = astBuilder->create(); - term->constFactor = cVal1->value; - auto factor0 = astBuilder->create(); - factor0->power = 1; - factor0->param = val0; - term->paramFactors.add(factor0); - result->terms.add(term); - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + auto factor0 = astBuilder->getOrCreate(val0, 1); + auto term = astBuilder->getOrCreate( + cVal1->getValue(), makeArrayView(&factor0, 1)); + builder.terms.add(term); + return builder.getIntVal(val0->getType()); } else if (auto val1 = as(op1)) { - auto result = astBuilder->create(val0->type); - auto term = astBuilder->create(); - term->constFactor = 1; - auto factor0 = astBuilder->create(); - factor0->power = 1; - factor0->param = val0; - term->paramFactors.add(factor0); - auto factor1 = astBuilder->create(); - factor1->power = 1; - factor1->param = val1; - term->paramFactors.add(factor1); - result->terms.add(term); - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + auto factor0 = astBuilder->getOrCreate(val0, 1); + auto factor1 = astBuilder->getOrCreate(val1, 1); + PolynomialIntValFactor* newFactors[] = { factor0, factor1 }; + auto term = astBuilder->getOrCreate(1, makeArrayView(newFactors)); + builder.terms.add(term); + return builder.getIntVal(val0->getType()); } } return nullptr; } -IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) -{ - List newTerms; - IntegerLiteralValue newConstantTerm = constantTerm; - auto addTerm = [&](PolynomialIntValTerm* newTerm) - { - for (auto term : newTerms) - { - if (term->canCombineWith(*newTerm)) - { - term->constFactor += newTerm->constFactor; - return; - } - } - newTerms.add(newTerm); - }; - for (auto term : terms) - { - if (term->constFactor == 0) - continue; - List newFactors; - List factorIsDifferent; - for (Index i = 0; i < term->paramFactors.getCount(); i++) - { - auto factor = term->paramFactors[i]; - bool factorFound = false; - for (Index j = 0; j < newFactors.getCount(); j++) - { - auto& newFactor = newFactors[j]; - if (factor->param->equalsVal(newFactor->param)) - { - if (!factorIsDifferent[j]) - { - factorIsDifferent[j] = true; - auto clonedFactor = builder->create(); - clonedFactor->param = newFactor->param; - clonedFactor->power = newFactor->power; - newFactor = clonedFactor; - } - newFactor->power += factor->power; - factorFound = true; - break; - } - } - if (!factorFound) - { - newFactors.add(factor); - factorIsDifferent.add(false); - } - } - List newFactors2; - for (auto factor : newFactors) - { - if (factor->power != 0) - newFactors2.add(factor); - } - if (newFactors2.getCount() == 0) - { - newConstantTerm += term->constFactor; - continue; - } - newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); - bool isDifferent = false; - if (newFactors2.getCount() != term->paramFactors.getCount()) - isDifferent = true; - if (!isDifferent) - { - for (Index i = 0; i < term->paramFactors.getCount(); i++) - if (term->paramFactors[i] != newFactors2[i]) - { - isDifferent = true; - break; - } - } - if (!isDifferent) - { - addTerm(term); - } - else - { - auto newTerm = builder->create(); - newTerm->constFactor = term->constFactor; - newTerm->paramFactors = _Move(newFactors2); - addTerm(newTerm); - } - } - List newTerms2; - for (auto term : newTerms) - { - if (term->constFactor == 0) - continue; - newTerms2.add(term); - } - newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); - terms = _Move(newTerms2); - constantTerm = newConstantTerm; - if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 && - terms[0]->paramFactors[0]->power == 1) - { - return terms[0]->paramFactors[0]->param; - } - if (terms.getCount() == 0) - return builder->getIntVal(type, constantTerm); - return this; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool TypeCastIntVal::_equalsValOverride(Val* val) -{ - if (auto typeCastIntVal = as(val)) - { - if (!type->equals(typeCastIntVal->type)) - return false; - if (!base->equalsVal(typeCastIntVal->base)) - return false; - return true; - } - return false; -} void TypeCastIntVal::_toTextOverride(StringBuilder& out) { - type->toText(out); + getType()->toText(out); out << "("; - base->toText(out); + getBase()->toText(out); out << ")"; } -HashCode TypeCastIntVal::_getHashCodeOverride() -{ - HashCode result = type->getHashCode(); - result = combineHash(result, base->getHashCode()); - return result; -} - Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) { SLANG_UNUSED(sink); if (auto c = as(base)) { - IntegerLiteralValue resultValue = c->value; + IntegerLiteralValue resultValue = c->getValue(); auto baseType = as(resultType); if (baseType) { - switch (baseType->baseType) + switch (baseType->getBaseType()) { case BaseType::Int: resultValue = (int)resultValue; @@ -1275,11 +1092,11 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substBase = base->substituteImpl(astBuilder, subst, &diff); - if (substBase != base) + auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff); + if (substBase != getBase()) diff++; - auto substType = as(type->substituteImpl(astBuilder, subst, &diff)); - if (substType != type) + auto substType = as(getType()->substituteImpl(astBuilder, subst, &diff)); + if (substType != getType()) diff++; *ioDiff += diff; if (diff) @@ -1289,7 +1106,7 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return newVal; else { - auto result = astBuilder->create(substType, substBase); + auto result = astBuilder->getOrCreate(substType, substBase); return result; } } @@ -1297,29 +1114,20 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return this; } - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool FuncCallIntVal::_equalsValOverride(Val* val) +Val* TypeCastIntVal::_resolveImplOverride() { - if (auto funcCallIntVal = as(val)) - { - if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef)) - return false; - if (args.getCount() != funcCallIntVal->args.getCount()) - return false; - for (Index i = 0; i < args.getCount(); i++) - { - if (!args[i]->equalsVal(funcCallIntVal->args[i])) - return false; - } - return true; - } - return false; + if (auto resolved = tryFoldImpl(getCurrentASTBuilder(), getType(), getBase(), nullptr)) + return resolved; + return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + void FuncCallIntVal::_toTextOverride(StringBuilder& out) { + auto args = getArgs(); + auto funcDeclRef = getFuncDeclRef(); + auto argToText = [&](int index) { if (as(args[index]) || as(args[index])) @@ -1369,14 +1177,37 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) } } -HashCode FuncCallIntVal::_getHashCodeOverride() +Val* FuncCallIntVal::_resolveImplOverride() { - HashCode result = funcDeclRef.getHashCode(); + auto astBuilder = getCurrentASTBuilder(); + auto args = getArgs(); + auto funcDeclRef = getFuncDeclRef(); + auto funcType = getFuncType(); + + Val* resolvedVal = this; + + auto newFuncDeclRef = as(funcDeclRef.declRefBase->resolve()); + if (!newFuncDeclRef) + return this; + bool diff = false; + List newArgs; for (auto arg : args) { - result = combineHash(result, arg->getHashCode()); + auto newArg = as(arg->resolve()); + if (!newArg) + return this; + newArgs.add(newArg); + if (newArg != arg) + diff = true; } - return result; + + if (auto resolved = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr)) + resolvedVal = resolved; + else if (diff) + { + resolvedVal = astBuilder->getOrCreate(getType(), newFuncDeclRef, funcType, newArgs.getArrayView()); + } + return resolvedVal; } Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef newFuncDecl, List& newArgs, DiagnosticSink* sink) @@ -1413,25 +1244,25 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR #define BINARY_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - resultValue = constArgs[0]->value op constArgs[1]->value; \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ } else #define DIV_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - if (constArgs[1]->value == 0) \ + if (constArgs[1]->getValue() == 0) \ { \ if (sink) \ sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ return nullptr; \ } \ - resultValue = constArgs[0]->value op constArgs[1]->value; \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ } else #define LOGICAL_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \ + resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \ } else @@ -1463,9 +1294,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR LOGICAL_OPERATOR_CASE(&&) LOGICAL_OPERATOR_CASE(||) // Special cases need their "operator" names quoted. - SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);) - SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;) - SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;) + SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);) + SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();) + SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();) TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");) return astBuilder->getIntVal(resultType, resultValue); @@ -1483,9 +1314,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff); + auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff); List newArgs; - for (auto& arg : args) + for (auto& arg : getArgs()) { auto substArg = arg->substituteImpl(astBuilder, subst, &diff); if (substArg != arg) @@ -1496,15 +1327,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio if (diff) { // TODO: report diagnostics back. - auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr); + auto newVal = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr); if (newVal) return newVal; else { - auto result = astBuilder->create(type); - result->args = _Move(newArgs); - result->funcDeclRef = newFuncDeclRef; - result->funcType = funcType; + auto result = astBuilder->getOrCreate(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView()); return result; } } @@ -1514,40 +1342,47 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool WitnessLookupIntVal::_equalsValOverride(Val* val) -{ - if (auto lookupIntVal = as(val)) - { - if (!witness->equalsVal(lookupIntVal->witness)) - return false; - if (key != lookupIntVal->key) - return false; - return true; - } - return false; -} - void WitnessLookupIntVal::_toTextOverride(StringBuilder& out) { - witness->sub->toText(out); + getWitness()->getSub()->toText(out); out << "."; - out << (key->getName() ? key->getName()->text : "??"); + out << (getKey()->getName() ? getKey()->getName()->text : "??"); } -HashCode WitnessLookupIntVal::_getHashCodeOverride() +Val* WitnessLookupIntVal::_resolveImplOverride() { - HashCode result = witness->getHashCode(); - result = combineHash(result, Slang::getHashCode(key)); - return result; + auto astBuilder = getCurrentASTBuilder(); + + auto newWitness = as(getWitness()->resolve()); + if (!newWitness) + return this; + + auto witnessVal = tryLookUpRequirementWitness(astBuilder, newWitness, getKey()); + if (witnessVal.getFlavor() == RequirementWitness::Flavor::val) + { + return witnessVal.getVal(); + } + + auto newType = as(getType()->resolve()); + if (!newType) + return this; + + if (newWitness != getWitness() || newType != getType()) + { + return astBuilder->getOrCreate(newType, newWitness, getKey()); + } + + return this; } + Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newWitness = witness->substituteImpl(astBuilder, subst, &diff); + auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff); *ioDiff += diff; if (diff) { - auto witnessEntry = tryFoldOrNull(astBuilder, as(newWitness), key); + auto witnessEntry = tryFoldOrNull(astBuilder, as(newWitness), getKey()); if (witnessEntry) return witnessEntry; } @@ -1573,51 +1408,93 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes { if (auto result = tryFoldOrNull(astBuilder, witness, key)) return result; - auto witnessResult = astBuilder->create(); - witnessResult->witness = witness; - witnessResult->key = key; - witnessResult->type = type; + auto witnessResult = astBuilder->getOrCreate(type, witness, key); return witnessResult; } - -bool DifferentiateVal::_equalsValOverride(Val* val) -{ - if (auto other = as(val)) - { - return other->astNodeType == astNodeType && other->func == func; - } - return false; -} - void DifferentiateVal::_toTextOverride(StringBuilder& out) { out << "DifferentiateVal("; - out << func; + out << getFunc(); out << ")"; } -HashCode DifferentiateVal::_getHashCodeOverride() -{ - HashCode result = (HashCode)astNodeType; - result = combineHash(result, func.getHashCode()); - return result; -} - Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newFunc = func.substituteImpl(astBuilder, subst, &diff); + auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff); *ioDiff += diff; if (diff) { auto result = as(astBuilder->createByNodeType(astNodeType)); - result->func = newFunc; + result->getFunc() = newFunc; return result; } // Nothing found: don't substitute. return this; } +Val* DifferentiateVal::_resolveImplOverride() +{ + return this; +} + +Val* PolynomialIntValFactor::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + auto newParam = as(getParam()->resolve()); + if (newParam && newParam != getParam()) + return astBuilder->getOrCreate(newParam, getPower()); + + return this; +} + +Val* PolynomialIntValTerm::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + bool diff = false; + List newFactors; + for (auto factor : getParamFactors()) + { + auto newFactor = as(factor->resolve()); + if (!newFactor) + return this; + + if (newFactor != factor) + diff = true; + newFactors.add(newFactor); + } + + if (diff) + return astBuilder->getOrCreate(getConstFactor(), newFactors.getArrayView()); + + return this; +} + +Val* PolynomialIntVal::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + bool diff = false; + PolynomialIntValBuilder builder(astBuilder); + builder.constantTerm = getConstantTerm(); + for (auto term : getTerms()) + { + auto newTerm = as(term->resolve()); + if (!newTerm) + return this; + + if (newTerm != term) + diff = true; + builder.terms.add(newTerm); + } + + if (diff) + return builder.getIntVal(getType()); + + return this; +} } // namespace Slang -- cgit v1.2.3