diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-04 15:47:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-04 15:47:39 -0700 |
| commit | a2d90fb275962da84611160f8ddd74d934a68dbd (patch) | |
| tree | 066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-ast-val.cpp | |
| parent | 17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff) | |
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val.
* Update project files
* Fix warning.
* Fix.
* Fix.
* Remove `Val::_equalsImplOverride`.
* Rmove `Val::_getHashCodeOverride`.
* Remove `semanticVisitor` param from `resolve`.
* Cleanups.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 1477 |
1 files changed, 677 insertions, 800 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index b45300af8..056577eb0 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -6,9 +6,47 @@ #include "slang-generated-ast-macro.h" #include "slang-diagnostics.h" #include "slang-syntax.h" +#include "slang-ast-val.h" namespace Slang { + +bool ValNodeDesc::operator==(ValNodeDesc const& that) const +{ + if (hashCode != that.hashCode) return false; + if (type != that.type) return false; + if (operands.getCount() != that.operands.getCount()) return false; + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are comparing the operands directly for identity + // (pointer equality) rather than doing the `Val`-level + // equality check. + // + // The rationale here is that nodes that will be created + // via a `NodeDesc` *should* all be going through the + // deduplication path anyway, as should their operands. + // + if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; + } + return true; +} + +void ValNodeDesc::init() +{ + Hasher hasher; + hasher.hashValue(Int(type)); + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are hashing the raw pointer value rather + // than the content of the value node. This is done + // to match the semantics implemented for `==` on + // `NodeDesc`. + // + hasher.hashValue(operands[i].values.nodeOperand); + } + hashCode = hasher.getResult(); +} + Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) { if (!subst) return this; @@ -21,14 +59,103 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff)) } -bool Val::equalsVal(Val* val) +void Val::toText(StringBuilder& out) { - SLANG_AST_NODE_VIRTUAL_CALL(Val, equalsVal, (val)) + SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) } -void Val::toText(StringBuilder& out) +Val* Val::_resolveImplOverride() { - SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) + SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden"); +} + +Val* Val::resolveImpl() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, resolveImpl, ()); +} + +Val* Val::resolve() +{ + auto astBuilder = getCurrentASTBuilder(); + + // If we are not in a proper checking context, just return the previously resolved val. + if (!astBuilder) + return m_resolvedVal? m_resolvedVal : this; + if (m_resolvedVal && m_resolvedValEpoch == getCurrentASTBuilder()->getEpoch()) + { + SLANG_ASSERT(as<Val>(m_resolvedVal)); + return m_resolvedVal; + } + + // Update epoch now to avoid infinite recursion. + m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch(); + m_resolvedVal = this; + m_resolvedVal = resolveImpl(); + + // Check if we are resolved to an existing Val in the AST cache. + ValNodeDesc newDesc; + newDesc.type = m_resolvedVal->astNodeType; + for (auto operand : m_resolvedVal->m_operands) + { + if (operand.kind == ValNodeOperandKind::ValNode) + { + auto valOperand = as<Val>(operand.values.nodeOperand); + if (valOperand) + { + operand.values.nodeOperand = valOperand->resolve(); + } + } + newDesc.operands.add(operand); + } + newDesc.init(); + + NodeBase* existingNode = nullptr; + if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) + m_resolvedVal = as<Val>(existingNode); + +#ifdef _DEBUG + if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0) + { + //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking."); + } +#endif + return m_resolvedVal; +} + +ValNodeDesc Val::getDesc() +{ + ValNodeDesc desc; + desc.type = astNodeType; + for (auto operand : m_operands) + desc.operands.add(operand); + desc.init(); + return desc; +} + +Val* Val::defaultResolveImpl() +{ + // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache. + ValNodeDesc newDesc; + newDesc.type = astNodeType; + for (auto operand : m_operands) + { + if (operand.kind == ValNodeOperandKind::ValNode) + { + auto valOperand = as<Val>(operand.values.nodeOperand); + if (valOperand) + { + operand.values.nodeOperand = valOperand->resolve(); + } + } + newDesc.operands.add(operand); + } + newDesc.init(); + auto astBuilder = getCurrentASTBuilder(); + + NodeBase* existingNode = nullptr; + if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) + return as<Val>(existingNode); + return this; } String Val::toString() @@ -40,7 +167,7 @@ String Val::toString() HashCode Val::getHashCode() { - SLANG_AST_NODE_VIRTUAL_CALL(Val, getHashCode, ()) + return Slang::getHashCode(resolve()); } Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) @@ -52,124 +179,84 @@ Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, return this; } -bool Val::_equalsValOverride(Val* val) -{ - SLANG_UNUSED(val); - SLANG_UNEXPECTED("Val::_equalsValOverride not overridden"); - //return false; -} - void Val::_toTextOverride(StringBuilder& out) { SLANG_UNUSED(out); SLANG_UNEXPECTED("Val::_toStringOverride not overridden"); } -HashCode Val::_getHashCodeOverride() -{ - SLANG_UNEXPECTED("Val::_getHashCodeOverride not overridden"); - //return HashCode(0); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConstantIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ConstantIntVal::_equalsValOverride(Val* val) -{ - if (auto intVal = as<ConstantIntVal>(val)) - return value == intVal->value; - return false; -} - void ConstantIntVal::_toTextOverride(StringBuilder& out) { - out << value; -} - -HashCode ConstantIntVal::_getHashCodeOverride() -{ - return (HashCode)value; + out << getValue(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool GenericParamIntVal::_equalsValOverride(Val* val) -{ - if (auto genericParamVal = as<GenericParamIntVal>(val)) - { - return declRef.equals(genericParamVal->declRef); - } - return false; -} - void GenericParamIntVal::_toTextOverride(StringBuilder& out) { - Name* name = declRef.getName(); + Name* name = getDeclRef().getName(); if (name) { out << name->text; } } -HashCode GenericParamIntVal::_getHashCodeOverride() -{ - return declRef.getHashCode() ^ HashCode(0xFFFF); -} - Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff) { // search for a substitution that might apply to us - for (auto s = subst.substitutions; s; s = s->getOuter()) + auto outerGeneric = as<GenericDecl>(paramDecl->parentDecl); + if (!outerGeneric) + return paramVal; + + GenericAppDeclRef* genAppArgs = subst.findGenericAppDeclRef(outerGeneric); + if (!genAppArgs) { - auto genSubst = as<GenericSubstitution>(s); - if (!genSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genSubst->getGenericDecl(); - if (genericDecl != paramDecl->parentDecl) - continue; - - // In some cases, we construct a `DeclRef` to a `GenericDecl` - // (or a declaration under one) that only includes argument - // values for a prefix of the parameters of the generic. - // - // If we aren't careful, we could end up indexing into the - // argument list past the available range. - // - Count argCount = genSubst->getArgs().getCount(); + return paramVal; + } - Count argIndex = 0; - for (auto m : genericDecl->members) + auto args = genAppArgs->getArgs(); + + // In some cases, we construct a `DeclRef` to a `GenericDecl` + // (or a declaration under one) that only includes argument + // values for a prefix of the parameters of the generic. + // + // If we aren't careful, we could end up indexing into the + // argument list past the available range. + // + Count argCount = args.getCount(); + + Count argIndex = 0; + for (auto m : outerGeneric->members) + { + // If we have run out of arguments, then we can stop + // iterating over the parameters, because `this` + // parameter will not be replaced with anything by + // the substituion. + // + if (argIndex >= argCount) { - // If we have run out of arguments, then we can stop - // iterating over the parameters, because `this` - // parameter will not be replaced with anything by - // the substituion. - // - if (argIndex >= argCount) - { - return paramVal; - } + return paramVal; + } - if (m == paramDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genSubst->getArgs()[argIndex]; - } - else if (const auto typeParam = as<GenericTypeParamDecl>(m)) - { - argIndex++; - } - else if (const auto valParam = as<GenericValueParamDecl>(m)) - { - argIndex++; - } - else - { - } + if (m == paramDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return args[argIndex]; + } + else if (const auto typeParam = as<GenericTypeParamDecl>(m)) + { + argIndex++; + } + else if (const auto valParam = as<GenericValueParamDecl>(m)) + { + argIndex++; + } + else + { } } @@ -180,7 +267,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) { - if (auto result = maybeSubstituteGenericParam(this, declRef.getDecl(), subst, ioDiff)) + if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff)) return result; return this; @@ -188,21 +275,11 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ErrorIntVal::_equalsValOverride(Val* val) -{ - return as<ErrorIntVal>(val); -} - void ErrorIntVal::_toTextOverride(StringBuilder& out) { out << toSlice("<error>"); } -HashCode ErrorIntVal::_getHashCodeOverride() -{ - return HashCode(typeid(this).hash_code()); -} - Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { SLANG_UNUSED(astBuilder); @@ -211,97 +288,110 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe return this; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -// TODO: should really have a `type.cpp` and a `witness.cpp` - -bool TypeEqualityWitness::_equalsValOverride(Val* val) -{ - auto otherWitness = as<TypeEqualityWitness>(val); - if (!otherWitness) - return false; - return sub->equals(otherWitness->sub); -} - Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - TypeEqualityWitness* rs = astBuilder->create<TypeEqualityWitness>(); - rs->sub = as<Type>(sub->substituteImpl(astBuilder, subst, ioDiff)); - rs->sup = as<Type>(sup->substituteImpl(astBuilder, subst, ioDiff)); + auto type = as<Type>(getSub()->substituteImpl(astBuilder, subst, ioDiff)); + TypeEqualityWitness* rs = astBuilder->getOrCreate<TypeEqualityWitness>(type, type); return rs; } void TypeEqualityWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("TypeEqualityWitness(") << sub << toSlice(")"); -} - -HashCode TypeEqualityWitness::_getHashCodeOverride() -{ - return sub->getHashCode(); + out << toSlice("TypeEqualityWitness(") << getSub() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclaredSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool DeclaredSubtypeWitness::_equalsValOverride(Val* val) +Val* DeclaredSubtypeWitness::_resolveImplOverride() { - auto otherWitness = as<DeclaredSubtypeWitness>(val); - if (!otherWitness) - return false; + auto resolvedDeclRef = getDeclRef().declRefBase->resolve(); + if (auto resolvedVal = as<SubtypeWitness>(resolvedDeclRef)) + return resolvedVal; - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && declRef.equals(otherWitness->declRef); + auto newSub = as<Type>(getSub()->resolve()); + auto newSup = as<Type>(getSup()->resolve()); + + // If we are trying to lookup for a witness that A<:B from a witness(A<:B), we + // can just return the witness itself. + if (auto lookupDeclRef = as<LookupDeclRef>(resolvedDeclRef)) + { + auto witnessToLookupFrom = lookupDeclRef->getWitness(); + if (witnessToLookupFrom->getSub()->equals(newSub) && + witnessToLookupFrom->getSup()->equals(newSup)) + return witnessToLookupFrom; + } + auto newDeclRef = as<DeclRefBase>(resolvedDeclRef); + if (!newDeclRef) + newDeclRef = getDeclRef().declRefBase; + if (newSub != getSub() || newSup != getSup() || newDeclRef != getDeclRef()) + { + return getCurrentASTBuilder()->getDeclaredSubtypeWitness(newSub, newSup, newDeclRef); + } + return this; } Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>()) + if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>()) { - auto genConstraintDecl = genConstraintDeclRef.getDecl(); + auto genericDecl = as<GenericDecl>(getDeclRef().getDecl()->parentDecl); + if (!genericDecl) + goto breakLabel; // search for a substitution that might apply to us - for (auto s = subst.substitutions; s; s = s->getOuter()) + auto args = tryGetGenericArguments(subst, genericDecl); + if (args.getCount() == 0) + goto breakLabel; + + bool found = false; + Index index = 0; + for (auto m : genericDecl->members) { - if (auto genericSubst = as<GenericSubstitution>(s)) + if (auto constraintParam = as<GenericTypeConstraintDecl>(m)) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->getGenericDecl(); - if (genericDecl != genConstraintDecl->parentDecl) - continue; - - bool found = false; - Index index = 0; - for (auto m : genericDecl->members) + if (constraintParam == getDeclRef().getDecl()) { - if (auto constraintParam = as<GenericTypeConstraintDecl>(m)) - { - if (constraintParam == declRef.getDecl()) - { - found = true; - break; - } - index++; - } - } - if (found) - { - (*ioDiff)++; - auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + - genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount()); - return genericSubst->getArgs()[index + ordinaryParamCount]; + found = true; + break; } + index++; + } + } + if (found) + { + auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + + genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); + if (index + ordinaryParamCount < args.getCount()) + { + (*ioDiff)++; + return args[index + ordinaryParamCount]; + } + else + { + // When the `subst` represents a partial substitution, we may not have a corresponding argument. + // In this case we just return the original witness. + // + goto breakLabel; } } } + else if (auto thisTypeConstraintDeclRef = getDeclRef().as<ThisTypeConstraintDecl>()) + { + auto lookupSubst = subst.findLookupDeclRef(); + if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) + { + (*ioDiff)++; + return lookupSubst->getWitness(); + } + } + +breakLabel:; // Perform substitution on the constituent elements. int diff = 0; - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); + if (!diff) return this; @@ -317,7 +407,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // so we'll need to change this location in the code if we ever clean // up the hierarchy. // - if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.getDecl())) + if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(getDeclRef().getDecl())) { if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl)) { @@ -326,12 +416,12 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // At this point we have a constraint decl for an associated type, // and we nee to see if we are dealing with a concrete substitution // for the interface around that associated type. - if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.getSubst(), interfaceDecl)) + if (auto thisTypeWitness = findThisTypeWitness(subst, interfaceDecl)) { // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); switch (requirementWitness.getFlavor()) { default: @@ -348,6 +438,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } } + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); auto rs = astBuilder->getDeclaredSubtypeWitness( substSub, substSup, substDeclRef); return rs; @@ -355,34 +446,17 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("DeclaredSubtypeWitness(") << sub << toSlice(", ") << sup << toSlice(", ") << declRef << toSlice(")"); -} - -HashCode DeclaredSubtypeWitness::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool TransitiveSubtypeWitness::_equalsValOverride(Val* val) -{ - auto otherWitness = as<TransitiveSubtypeWitness>(val); - if (!otherWitness) - return false; - - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && subToMid->equalsVal(otherWitness->subToMid) - && midToSup->equalsVal(otherWitness->midToSup); -} - Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { int diff = 0; - SubtypeWitness* substSubToMid = as<SubtypeWitness>(subToMid->substituteImpl(astBuilder, subst, &diff)); - SubtypeWitness* substMidToSup = as<SubtypeWitness>(midToSup->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substSubToMid = as<SubtypeWitness>(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substMidToSup = as<SubtypeWitness>(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -407,16 +481,7 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) // witnesses, and rely on them to print // the starting and ending types. - out << toSlice("TransitiveSubtypeWitness(") << subToMid << toSlice(", ") << midToSup << toSlice(")"); -} - -HashCode TransitiveSubtypeWitness::_getHashCodeOverride() -{ - auto hash = sub->getHashCode(); - hash = combineHash(hash, sup->getHashCode()); - hash = combineHash(hash, subToMid->getHashCode()); - hash = combineHash(hash, midToSup->getHashCode()); - return hash; + out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -425,9 +490,9 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a { int diff = 0; - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - auto substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff)); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); + auto substWitness = as<SubtypeWitness>(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -447,138 +512,34 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a // simplification logic as needed. // return astBuilder->getExtractFromConjunctionSubtypeWitness( - substSub, substSup, substWitness, indexInConjunction); + substSub, substSup, substWitness, getIndexInConjunction()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) -{ - if (auto extractWitness = as<ExtractExistentialSubtypeWitness>(val)) - { - return declRef.equals(extractWitness->declRef); - } - return false; -} - void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("extractExistentialValue(") << declRef << toSlice(")"); -} - -HashCode ExtractExistentialSubtypeWitness::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")"); } Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); if (!diff) return this; (*ioDiff)++; - ExtractExistentialSubtypeWitness* substValue = astBuilder->create<ExtractExistentialSubtypeWitness>(); - substValue->declRef = substDeclRef; - substValue->sub = substSub; - substValue->sup = substSup; + ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate<ExtractExistentialSubtypeWitness>( + substSub, substSup, substDeclRef); return substValue; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val) -{ - auto taggedUnionWitness = as<TaggedUnionSubtypeWitness>(val); - if (!taggedUnionWitness) - return false; - - auto caseCount = caseWitnesses.getCount(); - if (caseCount != taggedUnionWitness->caseWitnesses.getCount()) - return false; - - for (Index ii = 0; ii < caseCount; ++ii) - { - if (!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii])) - return false; - } - - return true; -} - -void TaggedUnionSubtypeWitness::_toTextOverride(StringBuilder& out) -{ - out << toSlice("TaggedUnionSubtypeWitness("); - bool first = true; - for (auto caseWitness : caseWitnesses) - { - if (!first) - { - out << toSlice(", "); - } - first = false; - - out << caseWitness; - } - out << toSlice(")"); -} - -HashCode TaggedUnionSubtypeWitness::_getHashCodeOverride() -{ - HashCode hash = 0; - for (auto caseWitness : caseWitnesses) - { - hash = combineHash(hash, caseWitness->getHashCode()); - } - return hash; -} - -Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - - List<SubtypeWitness*> substCaseWitnesses; - for (auto caseWitness : caseWitnesses) - { - substCaseWitnesses.add( - as<SubtypeWitness>(caseWitness->substituteImpl(astBuilder, subst, &diff))); - } - - if (!diff) - return this; - - (*ioDiff)++; - - TaggedUnionSubtypeWitness* substWitness = astBuilder->create<TaggedUnionSubtypeWitness>(); - substWitness->sub = substSub; - substWitness->sup = substSup; - substWitness->caseWitnesses.swapWith(substCaseWitnesses); - return substWitness; -} - -bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val) -{ - auto other = as<ConjunctionSubtypeWitness>(val); - if (!other) - return false; - - for (Index i = 0; i < kComponentCount; ++i) - { - if (!other->componentWitnesses[i]) return false; - if (!other->componentWitnesses[i]->equalsVal(componentWitnesses[i])) return false; - } - return true; -} - void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ConjunctionSubtypeWitness("; @@ -586,34 +547,23 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { if (i != 0) out << ","; - auto w = componentWitnesses[i]; + auto w = getComponentWitness(i); if (w) out << w; } out << ")"; } -HashCode ConjunctionSubtypeWitness::_getHashCodeOverride() -{ - HashCode result = 0; - for (Index i = 0; i < kComponentCount; ++i) - { - auto w = componentWitnesses[i]; - if (w) result = combineHash(result, w->getHashCode()); - } - return result; -} - Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; Val* substComponentWitnesses[kComponentCount]; - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); for (Index i = 0; i < kComponentCount; ++i) { - auto w = componentWitnesses[i]; + auto w = getComponentWitness(i); substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr; } @@ -630,65 +580,25 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, auto result = astBuilder->getConjunctionSubtypeWitness( substSub, substSup, - componentWitnesses[0], - componentWitnesses[1]); + as<SubtypeWitness>(substComponentWitnesses[0]), + as<SubtypeWitness>(substComponentWitnesses[1])); return result; } -bool ExtractFromConjunctionSubtypeWitness::_equalsValOverride(Val* val) -{ - if (auto other = as<ExtractFromConjunctionSubtypeWitness>(val)) - { - if(!sub->equals(other->sub)) return false; - if(!sup->equals(other->sup)) return false; - if(indexInConjunction != other->indexInConjunction) return false; - - return true; - } - return false; -} - void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ExtractFromConjunctionSubtypeWitness("; - if (conjunctionWitness) - out << conjunctionWitness; - if (sub) - out << sub; + if (getConjunctionWitness()) + out << getConjunctionWitness(); + if (getSub()) + out << getSub(); out << ","; - if (sup) - out << sup; - out << "," << indexInConjunction; + if (getSup()) + out << getSup(); + out << "," << getIndexInConjunction(); out << ")"; } -HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride() -{ - return combineHash( - conjunctionWitness ? conjunctionWitness->getHashCode() : 0, - sub ? sub->getHashCode() : 0, - sup ? sup->getHashCode() : 0, - indexInConjunction); -} - -// ModifierVal - -bool ModifierVal::_equalsValOverride(Val* val) -{ - // TODO: This is assuming we can fully deduplicate the values that represent - // modifiers, which may not actually be the case if there are multiple modules - // being combined that use different `ASTBuilder`s. - // - return this == val; -} - -HashCode ModifierVal::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashValue((void*) this); - return hasher.getResult(); -} - // UNormModifierVal void UNormModifierVal::_toTextOverride(StringBuilder& out) @@ -735,48 +645,14 @@ Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitu // PolynomialIntVal -bool PolynomialIntVal::_equalsValOverride(Val* val) -{ - if (auto genericParamVal = as<GenericParamIntVal>(val)) - { - return constantTerm == 0 && terms.getCount() == 1 && - terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 && - terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) && - terms[0]->paramFactors[0]->power == 1; - } - else if (auto otherPolynomial = as<PolynomialIntVal>(val)) - { - if (constantTerm != otherPolynomial->constantTerm) - return false; - if (terms.getCount() != otherPolynomial->terms.getCount()) - return false; - for (Index i = 0; i < terms.getCount(); i++) - { - auto& thisTerm = *(terms[i]); - auto& thatTerm = *(otherPolynomial->terms[i]); - if (thisTerm.constFactor != thatTerm.constFactor) - return false; - if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount()) - return false; - for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++) - { - if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power) - return false; - if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param)) - return false; - } - } - return true; - } - return false; -} - void PolynomialIntVal::_toTextOverride(StringBuilder& out) { + auto constantTerm = getConstantTerm(); + auto terms = getTerms(); for (Index i = 0; i < terms.getCount(); i++) { auto& term = *(terms[i]); - if (term.constFactor > 0) + if (term.getConstFactor() > 0) { if (i > 0) out << "+"; @@ -784,14 +660,14 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) else out << "-"; bool isFirstFactor = true; - if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0) + if (abs(term.getConstFactor()) != 1 || term.getParamFactors().getCount() == 0) { - out << abs(term.constFactor); + out << abs(term.getConstFactor()); isFirstFactor = false; } - for (Index j = 0; j < term.paramFactors.getCount(); j++) + for (Index j = 0; j < term.getParamFactors().getCount(); j++) { - auto factor = term.paramFactors[j]; + auto factor = term.getParamFactors()[j]; if (isFirstFactor) { isFirstFactor = false; @@ -800,10 +676,10 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) { out << "*"; } - factor->param->toText(out); - if (factor->power != 1) + factor->getParam()->toText(out); + if (factor->getPower() != 1) { - out << "^^" << factor->power; + out << "^^" << factor->getPower(); } } } @@ -821,227 +697,304 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) } } -HashCode PolynomialIntVal::_getHashCodeOverride() +struct PolynomialIntValBuilder { - HashCode result = (HashCode)constantTerm; - for (auto& term : terms) + ASTBuilder* astBuilder; + + IntegerLiteralValue constantTerm = 0; + List<PolynomialIntValTerm*> terms; + + PolynomialIntValBuilder(ASTBuilder* inAstBuilder) + : astBuilder(inAstBuilder) + {} + + // compute val += opreand*multiplier; + bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier) { - if (!term) continue; - result = combineHash(result, (HashCode)term->constFactor); - for (auto& factor : term->paramFactors) + if (auto c = as<ConstantIntVal>(operand)) { - result = combineHash(result, factor->param->getHashCode()); - result = combineHash(result, (HashCode)factor->power); + constantTerm += c->getValue() * multiplier; + return true; } + else if (auto poly = as<PolynomialIntVal>(operand)) + { + constantTerm += poly->getConstantTerm() * multiplier; + for (auto term : poly->getTerms()) + { + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + multiplier * term->getConstFactor(), term->getParamFactors()); + terms.add(newTerm); + } + return true; + } + else if (auto genVal = as<IntVal>(operand)) + { + auto factor = astBuilder->getOrCreate<PolynomialIntValFactor>(genVal, 1); + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(multiplier, makeArrayViewSingle(factor)); + terms.add(term); + return true; + } + return false; } - return result; -} + + IntVal* canonicalize(Type* type) + { + List<PolynomialIntValTerm*> newTerms; + IntegerLiteralValue newConstantTerm = constantTerm; + auto addTerm = [&](PolynomialIntValTerm* newTerm) + { + for (auto& term : newTerms) + { + if (term->canCombineWith(*newTerm)) + { + term = astBuilder->getOrCreate<PolynomialIntValTerm>( + term->getConstFactor() + newTerm->getConstFactor(), + term->getParamFactors()); + return; + } + } + newTerms.add(newTerm); + }; + for (auto term : terms) + { + if (term->getConstFactor() == 0) + continue; + List<PolynomialIntValFactor*> newFactors; + List<bool> factorIsDifferent; + for (Index i = 0; i < term->getParamFactors().getCount(); i++) + { + auto factor = term->getParamFactors()[i]; + bool factorFound = false; + for (Index j = 0; j < newFactors.getCount(); j++) + { + auto& newFactor = newFactors[j]; + if (factor->getParam()->equals(newFactor->getParam())) + { + if (!factorIsDifferent[j]) + { + factorIsDifferent[j] = true; + auto clonedFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower()); + newFactor = clonedFactor; + } + newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower() + factor->getPower()); + factorFound = true; + break; + } + } + if (!factorFound) + { + newFactors.add(factor); + factorIsDifferent.add(false); + } + } + List<PolynomialIntValFactor*> newFactors2; + // Remove zero-powered factors. + for (auto factor : newFactors) + { + if (factor->getPower() != 0) + newFactors2.add(factor); + } + if (newFactors2.getCount() == 0) + { + newConstantTerm += term->getConstFactor(); + continue; + } + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + bool isDifferent = false; + if (newFactors2.getCount() != term->getParamFactors().getCount()) + isDifferent = true; + if (!isDifferent) + { + for (Index i = 0; i < term->getParamFactors().getCount(); i++) + if (term->getParamFactors()[i] != newFactors2[i]) + { + isDifferent = true; + break; + } + } + if (!isDifferent) + { + addTerm(term); + } + else + { + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor(), newFactors2.getArrayView()); + addTerm(newTerm); + } + } + List<PolynomialIntValTerm*> newTerms2; + for (auto term : newTerms) + { + if (term->getConstFactor() == 0) + continue; + newTerms2.add(term); + } + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + terms = _Move(newTerms2); + constantTerm = newConstantTerm; + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 && + terms[0]->getParamFactors()[0]->getPower() == 1) + { + return terms[0]->getParamFactors()[0]->getParam(); + } + if (terms.getCount() == 0) + return astBuilder->getIntVal(type, constantTerm); + return nullptr; + } + + IntVal* getIntVal(Type* type) + { + if (auto canVal = canonicalize(type)) + return canVal; + return astBuilder->getOrCreate<PolynomialIntVal>(type, constantTerm, terms.getArrayView()); + } +}; Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - IntegerLiteralValue evaluatedConstantTerm = constantTerm; - List<PolynomialIntValTerm*> evaluatedTerms; - for (auto& term : terms) + PolynomialIntValBuilder builder(astBuilder); + for (auto& term : getTerms()) { IntegerLiteralValue evaluatedTermConstFactor; List<PolynomialIntValFactor*> evaluatedTermParamFactors; - evaluatedTermConstFactor = term->constFactor; - for (auto& factor : term->paramFactors) + evaluatedTermConstFactor = term->getConstFactor(); + for (auto& factor : term->getParamFactors()) { - auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); + auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff); if (auto constantVal = as<ConstantIntVal>(substResult)) { - evaluatedTermConstFactor *= constantVal->value; + evaluatedTermConstFactor *= constantVal->getValue(); } else if (auto intResult = as<IntVal>(substResult)) { - auto newFactor = astBuilder->create<PolynomialIntValFactor>(); - newFactor->param = intResult; - newFactor->power = factor->power; + auto newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(intResult, factor->getPower()); evaluatedTermParamFactors.add(newFactor); } } if (evaluatedTermParamFactors.getCount() == 0) { - evaluatedConstantTerm += evaluatedTermConstFactor; + builder.constantTerm += evaluatedTermConstFactor; } else { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->paramFactors = _Move(evaluatedTermParamFactors); - newTerm->constFactor = evaluatedTermConstFactor; - evaluatedTerms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView()); + builder.terms.add(newTerm); } } *ioDiff += diff; - if (evaluatedTerms.getCount() == 0) - return astBuilder->getIntVal(type, evaluatedConstantTerm); + if (builder.terms.getCount() == 0) + return astBuilder->getIntVal(getType(), builder.constantTerm); if (diff != 0) { - auto newPolynomial = astBuilder->create<PolynomialIntVal>(type); - newPolynomial->constantTerm = evaluatedConstantTerm; - newPolynomial->terms = _Move(evaluatedTerms); - return newPolynomial->canonicalize(astBuilder); + return builder.getIntVal(getType()); } return this; } - -// compute val += opreand*multiplier; -bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) -{ - if (auto c = as<ConstantIntVal>(operand)) - { - val->constantTerm += c->value * multiplier; - return true; - } - else if (auto poly = as<PolynomialIntVal>(operand)) - { - val->constantTerm += poly->constantTerm * multiplier; - for (auto term : poly->terms) - { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = multiplier * term->constFactor; - newTerm->paramFactors = term->paramFactors; - val->terms.add(newTerm); - } - return true; - } - else if (auto genVal = as<IntVal>(operand)) - { - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = multiplier; - auto factor = astBuilder->create<PolynomialIntValFactor>(); - factor->power = 1; - factor->param = genVal; - term->paramFactors.add(factor); - val->terms.add(term); - return true; - } - return false; -} - -PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) +IntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) { - auto result = astBuilder->create<PolynomialIntVal>(base->type); - if (!addToPolynomialTerm(astBuilder, result, base, -1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(base, -1); + return builder.getIntVal(base->getType()); } -PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create<PolynomialIntVal>(op0->type); - if (!addToPolynomialTerm(astBuilder, result, op0, 1)) - return nullptr; - if (!addToPolynomialTerm(astBuilder, result, op1, -1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(op0, 1); + builder.addToPolynomialTerm(op1, -1); + return builder.getIntVal(op0->getType()); } -PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create<PolynomialIntVal>(op0->type); - if (!addToPolynomialTerm(astBuilder, result, op0, 1)) - return nullptr; - if (!addToPolynomialTerm(astBuilder, result, op1, 1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(op0, 1); + builder.addToPolynomialTerm(op1, 1); + return builder.getIntVal(op0->getType()); } -PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { if (auto poly0 = as<PolynomialIntVal>(op0)) { if (auto poly1 = as<PolynomialIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(poly0->type); + PolynomialIntValBuilder builder(astBuilder); // add poly0.constant * poly1.constant - result->constantTerm = poly0->constantTerm * poly1->constantTerm; + builder.constantTerm = poly0->getConstantTerm() * poly1->getConstantTerm(); // add poly0.constant * poly1.terms - if (poly0->constantTerm != 0) + if (poly0->getConstantTerm() != 0) { - for (auto term : poly1->terms) + for (auto term : poly1->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = poly0->constantTerm * term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors()); + builder.terms.add(newTerm); } } // add poly1.constant * poly0.terms - if (poly1->constantTerm != 0) + if (poly1->getConstantTerm() != 0) { - for (auto term : poly0->terms) + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = poly1->constantTerm * term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + poly1->getConstantTerm() * term->getConstFactor(), + term->getParamFactors()); + builder.terms.add(newTerm); } } // add poly1.terms * poly0.terms - for (auto term0 : poly0->terms) + for (auto term0 : poly0->getTerms()) { - for (auto term1 : poly1->terms) + for (auto term1 : poly1->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term0->constFactor * term1->constFactor; - newTerm->paramFactors.addRange(term0->paramFactors); - newTerm->paramFactors.addRange(term1->paramFactors); - result->terms.add(newTerm); + List<PolynomialIntValFactor*> newFactors; + for (auto f : term0->getParamFactors()) newFactors.add(f); + for (auto f : term1->getParamFactors()) newFactors.add(f); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView()); + builder.terms.add(newTerm); } } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(op0->getType()); } else if (auto cVal1 = as<ConstantIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(poly0->type); - result->constantTerm = poly0->constantTerm * cVal1->value; - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - for (auto term : poly0->terms) + PolynomialIntValBuilder builder(astBuilder); + builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue(); + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor * cVal1->value; - newTerm->paramFactors.addRange(term->paramFactors); - newTerm->paramFactors.add(factor1); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor() * cVal1->getValue(), term->getParamFactors()); + builder.terms.add(newTerm); } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(poly0->getType()); } else if (auto val1 = as<IntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(poly0->type); - result->constantTerm = 0; - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - if (poly0->constantTerm != 0) + PolynomialIntValBuilder builder(astBuilder); + auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1); + if (poly0->getConstantTerm() != 0) { - auto term0 = astBuilder->create<PolynomialIntValTerm>(); - term0->constFactor = poly0->constantTerm; - term0->paramFactors.add(factor1); - result->terms.add(term0); + auto term0 = astBuilder->getOrCreate<PolynomialIntValTerm>(poly0->getConstantTerm(), makeArrayViewSingle(factor1)); + builder.terms.add(term0); } - for (auto term : poly0->terms) + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - newTerm->paramFactors.add(factor1); - result->terms.add(newTerm); + List<PolynomialIntValFactor*> newFactors; + for (auto f: term->getParamFactors()) + newFactors.add(f); + newFactors.add(factor1); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + term->getConstFactor(), newFactors.getArrayView()); + builder.terms.add(newTerm); } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(poly0->getType()); } else return nullptr; @@ -1058,184 +1011,48 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto cVal1 = as<ConstantIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(val0->type); - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = cVal1->value; - auto factor0 = astBuilder->create<PolynomialIntValFactor>(); - factor0->power = 1; - factor0->param = val0; - term->paramFactors.add(factor0); - result->terms.add(term); - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1); + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>( + cVal1->getValue(), makeArrayView(&factor0, 1)); + builder.terms.add(term); + return builder.getIntVal(val0->getType()); } else if (auto val1 = as<IntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(val0->type); - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = 1; - auto factor0 = astBuilder->create<PolynomialIntValFactor>(); - factor0->power = 1; - factor0->param = val0; - term->paramFactors.add(factor0); - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - term->paramFactors.add(factor1); - result->terms.add(term); - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1); + auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1); + PolynomialIntValFactor* newFactors[] = { factor0, factor1 }; + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(1, makeArrayView(newFactors)); + builder.terms.add(term); + return builder.getIntVal(val0->getType()); } } return nullptr; } -IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) -{ - List<PolynomialIntValTerm*> newTerms; - IntegerLiteralValue newConstantTerm = constantTerm; - auto addTerm = [&](PolynomialIntValTerm* newTerm) - { - for (auto term : newTerms) - { - if (term->canCombineWith(*newTerm)) - { - term->constFactor += newTerm->constFactor; - return; - } - } - newTerms.add(newTerm); - }; - for (auto term : terms) - { - if (term->constFactor == 0) - continue; - List<PolynomialIntValFactor*> newFactors; - List<bool> factorIsDifferent; - for (Index i = 0; i < term->paramFactors.getCount(); i++) - { - auto factor = term->paramFactors[i]; - bool factorFound = false; - for (Index j = 0; j < newFactors.getCount(); j++) - { - auto& newFactor = newFactors[j]; - if (factor->param->equalsVal(newFactor->param)) - { - if (!factorIsDifferent[j]) - { - factorIsDifferent[j] = true; - auto clonedFactor = builder->create<PolynomialIntValFactor>(); - clonedFactor->param = newFactor->param; - clonedFactor->power = newFactor->power; - newFactor = clonedFactor; - } - newFactor->power += factor->power; - factorFound = true; - break; - } - } - if (!factorFound) - { - newFactors.add(factor); - factorIsDifferent.add(false); - } - } - List<PolynomialIntValFactor*> newFactors2; - for (auto factor : newFactors) - { - if (factor->power != 0) - newFactors2.add(factor); - } - if (newFactors2.getCount() == 0) - { - newConstantTerm += term->constFactor; - continue; - } - newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); - bool isDifferent = false; - if (newFactors2.getCount() != term->paramFactors.getCount()) - isDifferent = true; - if (!isDifferent) - { - for (Index i = 0; i < term->paramFactors.getCount(); i++) - if (term->paramFactors[i] != newFactors2[i]) - { - isDifferent = true; - break; - } - } - if (!isDifferent) - { - addTerm(term); - } - else - { - auto newTerm = builder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor; - newTerm->paramFactors = _Move(newFactors2); - addTerm(newTerm); - } - } - List<PolynomialIntValTerm*> newTerms2; - for (auto term : newTerms) - { - if (term->constFactor == 0) - continue; - newTerms2.add(term); - } - newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); - terms = _Move(newTerms2); - constantTerm = newConstantTerm; - if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 && - terms[0]->paramFactors[0]->power == 1) - { - return terms[0]->paramFactors[0]->param; - } - if (terms.getCount() == 0) - return builder->getIntVal(type, constantTerm); - return this; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool TypeCastIntVal::_equalsValOverride(Val* val) -{ - if (auto typeCastIntVal = as<TypeCastIntVal>(val)) - { - if (!type->equals(typeCastIntVal->type)) - return false; - if (!base->equalsVal(typeCastIntVal->base)) - return false; - return true; - } - return false; -} void TypeCastIntVal::_toTextOverride(StringBuilder& out) { - type->toText(out); + getType()->toText(out); out << "("; - base->toText(out); + getBase()->toText(out); out << ")"; } -HashCode TypeCastIntVal::_getHashCodeOverride() -{ - HashCode result = type->getHashCode(); - result = combineHash(result, base->getHashCode()); - return result; -} - Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) { SLANG_UNUSED(sink); if (auto c = as<ConstantIntVal>(base)) { - IntegerLiteralValue resultValue = c->value; + IntegerLiteralValue resultValue = c->getValue(); auto baseType = as<BasicExpressionType>(resultType); if (baseType) { - switch (baseType->baseType) + switch (baseType->getBaseType()) { case BaseType::Int: resultValue = (int)resultValue; @@ -1275,11 +1092,11 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substBase = base->substituteImpl(astBuilder, subst, &diff); - if (substBase != base) + auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff); + if (substBase != getBase()) diff++; - auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff)); - if (substType != type) + auto substType = as<Type>(getType()->substituteImpl(astBuilder, subst, &diff)); + if (substType != getType()) diff++; *ioDiff += diff; if (diff) @@ -1289,7 +1106,7 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return newVal; else { - auto result = astBuilder->create<TypeCastIntVal>(substType, substBase); + auto result = astBuilder->getOrCreate<TypeCastIntVal>(substType, substBase); return result; } } @@ -1297,29 +1114,20 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return this; } - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool FuncCallIntVal::_equalsValOverride(Val* val) +Val* TypeCastIntVal::_resolveImplOverride() { - if (auto funcCallIntVal = as<FuncCallIntVal>(val)) - { - if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef)) - return false; - if (args.getCount() != funcCallIntVal->args.getCount()) - return false; - for (Index i = 0; i < args.getCount(); i++) - { - if (!args[i]->equalsVal(funcCallIntVal->args[i])) - return false; - } - return true; - } - return false; + if (auto resolved = tryFoldImpl(getCurrentASTBuilder(), getType(), getBase(), nullptr)) + return resolved; + return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + void FuncCallIntVal::_toTextOverride(StringBuilder& out) { + auto args = getArgs(); + auto funcDeclRef = getFuncDeclRef(); + auto argToText = [&](int index) { if (as<PolynomialIntVal>(args[index]) || as<FuncCallIntVal>(args[index])) @@ -1369,14 +1177,37 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) } } -HashCode FuncCallIntVal::_getHashCodeOverride() +Val* FuncCallIntVal::_resolveImplOverride() { - HashCode result = funcDeclRef.getHashCode(); + auto astBuilder = getCurrentASTBuilder(); + auto args = getArgs(); + auto funcDeclRef = getFuncDeclRef(); + auto funcType = getFuncType(); + + Val* resolvedVal = this; + + auto newFuncDeclRef = as<DeclRefBase>(funcDeclRef.declRefBase->resolve()); + if (!newFuncDeclRef) + return this; + bool diff = false; + List<IntVal*> newArgs; for (auto arg : args) { - result = combineHash(result, arg->getHashCode()); + auto newArg = as<IntVal>(arg->resolve()); + if (!newArg) + return this; + newArgs.add(newArg); + if (newArg != arg) + diff = true; } - return result; + + if (auto resolved = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr)) + resolvedVal = resolved; + else if (diff) + { + resolvedVal = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, funcType, newArgs.getArrayView()); + } + return resolvedVal; } Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) @@ -1413,25 +1244,25 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR #define BINARY_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - resultValue = constArgs[0]->value op constArgs[1]->value; \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ } else #define DIV_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - if (constArgs[1]->value == 0) \ + if (constArgs[1]->getValue() == 0) \ { \ if (sink) \ sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ return nullptr; \ } \ - resultValue = constArgs[0]->value op constArgs[1]->value; \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ } else #define LOGICAL_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \ + resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \ } else @@ -1463,9 +1294,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR LOGICAL_OPERATOR_CASE(&&) LOGICAL_OPERATOR_CASE(||) // Special cases need their "operator" names quoted. - SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);) - SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;) - SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;) + SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);) + SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();) + SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();) TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");) return astBuilder->getIntVal(resultType, resultValue); @@ -1483,9 +1314,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff); + auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff); List<IntVal*> newArgs; - for (auto& arg : args) + for (auto& arg : getArgs()) { auto substArg = arg->substituteImpl(astBuilder, subst, &diff); if (substArg != arg) @@ -1496,15 +1327,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio if (diff) { // TODO: report diagnostics back. - auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr); + auto newVal = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr); if (newVal) return newVal; else { - auto result = astBuilder->create<FuncCallIntVal>(type); - result->args = _Move(newArgs); - result->funcDeclRef = newFuncDeclRef; - result->funcType = funcType; + auto result = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView()); return result; } } @@ -1514,40 +1342,47 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool WitnessLookupIntVal::_equalsValOverride(Val* val) -{ - if (auto lookupIntVal = as<WitnessLookupIntVal>(val)) - { - if (!witness->equalsVal(lookupIntVal->witness)) - return false; - if (key != lookupIntVal->key) - return false; - return true; - } - return false; -} - void WitnessLookupIntVal::_toTextOverride(StringBuilder& out) { - witness->sub->toText(out); + getWitness()->getSub()->toText(out); out << "."; - out << (key->getName() ? key->getName()->text : "??"); + out << (getKey()->getName() ? getKey()->getName()->text : "??"); } -HashCode WitnessLookupIntVal::_getHashCodeOverride() +Val* WitnessLookupIntVal::_resolveImplOverride() { - HashCode result = witness->getHashCode(); - result = combineHash(result, Slang::getHashCode(key)); - return result; + auto astBuilder = getCurrentASTBuilder(); + + auto newWitness = as<SubtypeWitness>(getWitness()->resolve()); + if (!newWitness) + return this; + + auto witnessVal = tryLookUpRequirementWitness(astBuilder, newWitness, getKey()); + if (witnessVal.getFlavor() == RequirementWitness::Flavor::val) + { + return witnessVal.getVal(); + } + + auto newType = as<Type>(getType()->resolve()); + if (!newType) + return this; + + if (newWitness != getWitness() || newType != getType()) + { + return astBuilder->getOrCreate<WitnessLookupIntVal>(newType, newWitness, getKey()); + } + + return this; } + Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newWitness = witness->substituteImpl(astBuilder, subst, &diff); + auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff); *ioDiff += diff; if (diff) { - auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), key); + auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), getKey()); if (witnessEntry) return witnessEntry; } @@ -1573,51 +1408,93 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes { if (auto result = tryFoldOrNull(astBuilder, witness, key)) return result; - auto witnessResult = astBuilder->create<WitnessLookupIntVal>(); - witnessResult->witness = witness; - witnessResult->key = key; - witnessResult->type = type; + auto witnessResult = astBuilder->getOrCreate<WitnessLookupIntVal>(type, witness, key); return witnessResult; } - -bool DifferentiateVal::_equalsValOverride(Val* val) -{ - if (auto other = as<DifferentiateVal>(val)) - { - return other->astNodeType == astNodeType && other->func == func; - } - return false; -} - void DifferentiateVal::_toTextOverride(StringBuilder& out) { out << "DifferentiateVal("; - out << func; + out << getFunc(); out << ")"; } -HashCode DifferentiateVal::_getHashCodeOverride() -{ - HashCode result = (HashCode)astNodeType; - result = combineHash(result, func.getHashCode()); - return result; -} - Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newFunc = func.substituteImpl(astBuilder, subst, &diff); + auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff); *ioDiff += diff; if (diff) { auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType)); - result->func = newFunc; + result->getFunc() = newFunc; return result; } // Nothing found: don't substitute. return this; } +Val* DifferentiateVal::_resolveImplOverride() +{ + return this; +} + +Val* PolynomialIntValFactor::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + auto newParam = as<IntVal>(getParam()->resolve()); + if (newParam && newParam != getParam()) + return astBuilder->getOrCreate<PolynomialIntValFactor>(newParam, getPower()); + + return this; +} + +Val* PolynomialIntValTerm::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + bool diff = false; + List<PolynomialIntValFactor*> newFactors; + for (auto factor : getParamFactors()) + { + auto newFactor = as<PolynomialIntValFactor>(factor->resolve()); + if (!newFactor) + return this; + + if (newFactor != factor) + diff = true; + newFactors.add(newFactor); + } + + if (diff) + return astBuilder->getOrCreate<PolynomialIntValTerm>(getConstFactor(), newFactors.getArrayView()); + + return this; +} + +Val* PolynomialIntVal::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + bool diff = false; + PolynomialIntValBuilder builder(astBuilder); + builder.constantTerm = getConstantTerm(); + for (auto term : getTerms()) + { + auto newTerm = as<PolynomialIntValTerm>(term->resolve()); + if (!newTerm) + return this; + + if (newTerm != term) + diff = true; + builder.terms.add(newTerm); + } + + if (diff) + return builder.getIntVal(getType()); + + return this; +} } // namespace Slang |
