diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ast-val.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 482 |
1 files changed, 304 insertions, 178 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index b8c5e6ee1..2a2f275ee 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1,16 +1,18 @@ // slang-ast-type.cpp -#include "slang-ast-builder.h" -#include <assert.h> -#include <typeinfo> +#include "slang-ast-val.h" -#include "slang-generated-ast-macro.h" +#include "slang-ast-builder.h" +#include "slang-check-impl.h" #include "slang-diagnostics.h" -#include "slang-syntax.h" -#include "slang-ast-val.h" +#include "slang-generated-ast-macro.h" #include "slang-mangle.h" -#include "slang-check-impl.h" +#include "slang-syntax.h" + +#include <assert.h> +#include <typeinfo> -namespace Slang { +namespace Slang +{ void ValNodeDesc::init() { @@ -30,7 +32,8 @@ void ValNodeDesc::init() Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) { - if (!subst) return this; + if (!subst) + return this; int diff = 0; return substituteImpl(astBuilder, subst, &diff); } @@ -40,12 +43,9 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff)) } -void Val::toText(StringBuilder& out) -{ - SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) -} +void Val::toText(StringBuilder& out){SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out))} -Val* Val::_resolveImplOverride() +Val* Val::_resolveImplOverride() { SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden"); } @@ -60,7 +60,7 @@ Val* Val::resolve() auto astBuilder = getCurrentASTBuilder(); // If we are not in a proper checking context, just return the previously resolved val. if (!astBuilder) - return m_resolvedVal? m_resolvedVal : this; + return m_resolvedVal ? m_resolvedVal : this; if (m_resolvedVal && m_resolvedValEpoch == astBuilder->getEpoch()) { SLANG_ASSERT(as<Val>(m_resolvedVal)); @@ -72,7 +72,8 @@ Val* Val::resolve() #ifdef _DEBUG if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0) { - SLANG_ASSERT_FAILURE("should not be modifying the core module vals outside of the core module checking."); + SLANG_ASSERT_FAILURE( + "should not be modifying the core module vals outside of the core module checking."); } #endif return m_resolvedVal; @@ -86,7 +87,8 @@ void Val::_setUnique() Val* Val::defaultResolveImpl() { - // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache. + // Default resolve implementation is to recursively resolve all operands, and lookup in + // deduplication cache. ValNodeDesc newDesc; newDesc.type = astNodeType; bool diff = false; @@ -107,7 +109,7 @@ Val* Val::defaultResolveImpl() } newDesc.operands.add(operand); } - + if (!diff) return this; @@ -220,10 +222,12 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet // Nothing found: don't substitute. return paramVal; - } -Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) +Val* GenericParamIntVal::_substituteImplOverride( + ASTBuilder* /* astBuilder */, + SubstitutionSet subst, + int* ioDiff) { if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff)) return result; @@ -252,7 +256,10 @@ void ErrorIntVal::_toTextOverride(StringBuilder& out) out << toSlice("<error>"); } -Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ErrorIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -260,7 +267,10 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe return this; } -Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* TypeEqualityWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { auto type = as<Type>(getSub()->substituteImpl(astBuilder, subst, ioDiff)); TypeEqualityWitness* rs = astBuilder->getOrCreate<TypeEqualityWitness>(type, type); @@ -274,7 +284,10 @@ void TypeEqualityWitness::_toTextOverride(StringBuilder& out) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypePackSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* TypePackSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* TypePackSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; ShortList<SubtypeWitness*> newWitnesses; @@ -289,7 +302,10 @@ Val* TypePackSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub if (!diff) return this; (*ioDiff)++; - return getCurrentASTBuilder()->getSubtypeWitnessPack(newSub, newSup, newWitnesses.getArrayView().arrayView); + return getCurrentASTBuilder()->getSubtypeWitnessPack( + newSub, + newSup, + newWitnesses.getArrayView().arrayView); } Val* TypePackSubtypeWitness::_resolveImplOverride() @@ -313,7 +329,10 @@ Val* TypePackSubtypeWitness::_resolveImplOverride() if (!diff) return this; - return getCurrentASTBuilder()->getSubtypeWitnessPack(newSub, newSup, newWitnesses.getArrayView().arrayView); + return getCurrentASTBuilder()->getSubtypeWitnessPack( + newSub, + newSup, + newWitnesses.getArrayView().arrayView); } void TypePackSubtypeWitness::_toTextOverride(StringBuilder& out) @@ -330,7 +349,10 @@ void TypePackSubtypeWitness::_toTextOverride(StringBuilder& out) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExpandSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* ExpandSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ExpandSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); @@ -346,16 +368,24 @@ Val* ExpandSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Subst { auto elementType = subTypePack->getElementType(i); subst.packExpansionIndex = i; - auto elementWitness = as<SubtypeWitness>(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); - auto newWitness = getCurrentASTBuilder()->getExpandSubtypeWitness(elementType, newSup, elementWitness); + auto elementWitness = as<SubtypeWitness>( + getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); + auto newWitness = getCurrentASTBuilder()->getExpandSubtypeWitness( + elementType, + newSup, + elementWitness); newWitnesses.add(as<SubtypeWitness>(newWitness)); } (*ioDiff)++; - return getCurrentASTBuilder()->getSubtypeWitnessPack(newSub, newSup, newWitnesses.getArrayView().arrayView); + return getCurrentASTBuilder()->getSubtypeWitnessPack( + newSub, + newSup, + newWitnesses.getArrayView().arrayView); } (*ioDiff)++; - auto newPatternWitness = as<SubtypeWitness>(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); + auto newPatternWitness = + as<SubtypeWitness>(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); return getCurrentASTBuilder()->getExpandSubtypeWitness(newSub, newSup, newPatternWitness); } @@ -385,10 +415,14 @@ void ExpandSubtypeWitness::_toTextOverride(StringBuilder& out) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! EachSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* EachSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* EachSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; - auto newPatternWitness = as<SubtypeWitness>(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); + auto newPatternWitness = + as<SubtypeWitness>(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); if (auto witnessPack = as<TypePackSubtypeWitness>(newPatternWitness)) { if (subst.packExpansionIndex >= 0 && subst.packExpansionIndex < witnessPack->getCount()) @@ -464,7 +498,10 @@ ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride() return kConversionCost_None; } -Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* DeclaredSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>()) { @@ -493,7 +530,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } if (found) { - auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDeclBase>().getCount() + + auto ordinaryParamCount = + genericDecl->getMembersOfType<GenericTypeParamDeclBase>().getCount() + genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); if (index + ordinaryParamCount < args.getCount()) { @@ -502,8 +540,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } else { - // When the `subst` represents a partial substitution, we may not have a corresponding argument. - // In this case we just return the original witness. + // When the `subst` represents a partial substitution, we may not have a + // corresponding argument. In this case we just return the original witness. // goto breakLabel; } @@ -512,7 +550,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub else if (auto thisTypeConstraintDeclRef = getDeclRef().as<ThisTypeConstraintDecl>()) { auto lookupSubst = subst.findLookupDeclRef(); - if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) + if (lookupSubst && + lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) { (*ioDiff)++; return lookupSubst->getWitness(); @@ -525,7 +564,7 @@ breakLabel:; int diff = 0; auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); - + if (!diff) return this; @@ -555,13 +594,13 @@ breakLabel:; // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); + RequirementWitness requirementWitness = + tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); switch (requirementWitness.getFlavor()) { - default: - break; + default: break; - case RequirementWitness::Flavor::val: + case RequirementWitness::Flavor::val: { auto satisfyingVal = requirementWitness.getVal(); return satisfyingVal; @@ -573,24 +612,29 @@ breakLabel:; } auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); - auto rs = astBuilder->getDeclaredSubtypeWitness( - substSub, substSup, substDeclRef); + auto rs = astBuilder->getDeclaredSubtypeWitness(substSub, substSup, substDeclRef); return rs; } void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")"); + out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() + << toSlice(", ") << getDeclRef() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* TransitiveSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; - SubtypeWitness* substSubToMid = as<SubtypeWitness>(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); - SubtypeWitness* substMidToSup = as<SubtypeWitness>(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substSubToMid = + as<SubtypeWitness>(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substMidToSup = + as<SubtypeWitness>(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -611,7 +655,8 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride() { - return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost() + kConversionCost_GenericParamUpcast; + return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost() + + kConversionCost_GenericParamUpcast; } void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) @@ -619,19 +664,25 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) // Note: we only print the constituent // witnesses, and rely on them to print // the starting and ending types. - - out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")"); + + out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() + << toSlice(")"); } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); - auto substWitness = as<SubtypeWitness>(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); + auto substWitness = + as<SubtypeWitness>(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -651,7 +702,10 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a // simplification logic as needed. // return astBuilder->getExtractFromConjunctionSubtypeWitness( - substSub, substSup, substWitness, getIndexInConjunction()); + substSub, + substSup, + substWitness, + getIndexInConjunction()); } ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostOverride() @@ -665,14 +719,18 @@ ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostO return kConversionCost_None; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out) { out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")"); } -Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ExtractExistentialSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; @@ -685,8 +743,8 @@ Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBu (*ioDiff)++; - ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate<ExtractExistentialSubtypeWitness>( - substSub, substSup, substDeclRef); + ExtractExistentialSubtypeWitness* substValue = + astBuilder->getOrCreate<ExtractExistentialSubtypeWitness>(substSub, substSup, substDeclRef); return substValue; } @@ -695,15 +753,20 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) out << "ConjunctionSubtypeWitness("; for (Index i = 0; i < kComponentCount; ++i) { - if (i != 0) out << ","; + if (i != 0) + out << ","; auto w = getComponentWitness(i); - if (w) out << w; + if (w) + out << w; } out << ")"; } -Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ConjunctionSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; Val* substComponentWitnesses[kComponentCount]; @@ -717,7 +780,7 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr; } - if(!diff) + if (!diff) return this; *ioDiff += diff; @@ -764,7 +827,10 @@ void UNormModifierVal::_toTextOverride(StringBuilder& out) out.append("unorm"); } -Val* UNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* UNormModifierVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -779,7 +845,10 @@ void SNormModifierVal::_toTextOverride(StringBuilder& out) out.append("snorm"); } -Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* SNormModifierVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -793,7 +862,10 @@ void NoDiffModifierVal::_toTextOverride(StringBuilder& out) out.append("no_diff"); } -Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* NoDiffModifierVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -864,7 +936,8 @@ struct PolynomialIntValBuilder PolynomialIntValBuilder(ASTBuilder* inAstBuilder) : astBuilder(inAstBuilder) - {} + { + } // compute val += opreand*multiplier; bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier) @@ -880,7 +953,8 @@ struct PolynomialIntValBuilder for (auto term : poly->getTerms()) { auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( - multiplier * term->getConstFactor(), term->getParamFactors()); + multiplier * term->getConstFactor(), + term->getParamFactors()); terms.add(newTerm); } return true; @@ -888,7 +962,9 @@ struct PolynomialIntValBuilder else if (auto genVal = as<IntVal>(operand)) { auto factor = astBuilder->getOrCreate<PolynomialIntValFactor>(genVal, 1); - auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(multiplier, makeArrayViewSingle(factor)); + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>( + multiplier, + makeArrayViewSingle(factor)); terms.add(term); return true; } @@ -931,10 +1007,14 @@ struct PolynomialIntValBuilder if (!factorIsDifferent[j]) { factorIsDifferent[j] = true; - auto clonedFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower()); + auto clonedFactor = astBuilder->getOrCreate<PolynomialIntValFactor>( + newFactor->getParam(), + newFactor->getPower()); newFactor = clonedFactor; } - newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower() + factor->getPower()); + newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>( + newFactor->getParam(), + newFactor->getPower() + factor->getPower()); factorFound = true; break; } @@ -957,7 +1037,8 @@ struct PolynomialIntValBuilder newConstantTerm += term->getConstFactor(); continue; } - newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) + { return *t1 < *t2; }); bool isDifferent = false; if (newFactors2.getCount() != term->getParamFactors().getCount()) isDifferent = true; @@ -976,7 +1057,9 @@ struct PolynomialIntValBuilder } else { - auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor(), newFactors2.getArrayView()); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + term->getConstFactor(), + newFactors2.getArrayView()); addTerm(newTerm); } } @@ -987,10 +1070,12 @@ struct PolynomialIntValBuilder continue; newTerms2.add(term); } - newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) + { return *t1 < *t2; }); terms = _Move(newTerms2); constantTerm = newConstantTerm; - if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 && + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && + terms[0]->getParamFactors().getCount() == 1 && terms[0]->getParamFactors()[0]->getPower() == 1) { return terms[0]->getParamFactors()[0]->getParam(); @@ -1008,7 +1093,10 @@ struct PolynomialIntValBuilder } }; -Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* PolynomialIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; PolynomialIntValBuilder builder(astBuilder); @@ -1021,14 +1109,15 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut for (auto& factor : term->getParamFactors()) { auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff); - + if (auto constantVal = as<ConstantIntVal>(substResult)) { evaluatedTermConstFactor *= constantVal->getValue(); } else if (auto intResult = as<IntVal>(substResult)) { - auto newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(intResult, factor->getPower()); + auto newFactor = + astBuilder->getOrCreate<PolynomialIntValFactor>(intResult, factor->getPower()); evaluatedTermParamFactors.add(newFactor); } } @@ -1038,7 +1127,8 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut } else { - if (evaluatedTermParamFactors.getCount() == 1 && evaluatedTermParamFactors[0]->getPower() == 1) + if (evaluatedTermParamFactors.getCount() == 1 && + evaluatedTermParamFactors[0]->getPower() == 1) { if (auto polyTerm = as<PolynomialIntVal>(evaluatedTermParamFactors[0]->getParam())) { @@ -1047,7 +1137,8 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut } } auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( - evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView()); + evaluatedTermConstFactor, + evaluatedTermParamFactors.getArrayView()); builder.terms.add(newTerm); } } @@ -1101,7 +1192,8 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) for (auto term : poly1->getTerms()) { auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( - poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors()); + poly0->getConstantTerm() * term->getConstFactor(), + term->getParamFactors()); builder.terms.add(newTerm); } } @@ -1122,10 +1214,13 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) for (auto term1 : poly1->getTerms()) { List<PolynomialIntValFactor*> newFactors; - for (auto f : term0->getParamFactors()) newFactors.add(f); - for (auto f : term1->getParamFactors()) newFactors.add(f); + for (auto f : term0->getParamFactors()) + newFactors.add(f); + for (auto f : term1->getParamFactors()) + newFactors.add(f); auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( - term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView()); + term0->getConstFactor() * term1->getConstFactor(), + newFactors.getArrayView()); builder.terms.add(newTerm); } } @@ -1137,7 +1232,9 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue(); for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor() * cVal1->getValue(), term->getParamFactors()); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + term->getConstFactor() * cVal1->getValue(), + term->getParamFactors()); builder.terms.add(newTerm); } return builder.getIntVal(poly0->getType()); @@ -1148,17 +1245,20 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1); if (poly0->getConstantTerm() != 0) { - auto term0 = astBuilder->getOrCreate<PolynomialIntValTerm>(poly0->getConstantTerm(), makeArrayViewSingle(factor1)); + auto term0 = astBuilder->getOrCreate<PolynomialIntValTerm>( + poly0->getConstantTerm(), + makeArrayViewSingle(factor1)); builder.terms.add(term0); } for (auto term : poly0->getTerms()) { List<PolynomialIntValFactor*> newFactors; - for (auto f: term->getParamFactors()) + for (auto f : term->getParamFactors()) newFactors.add(f); newFactors.add(factor1); auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( - term->getConstFactor(), newFactors.getArrayView()); + term->getConstFactor(), + newFactors.getArrayView()); builder.terms.add(newTerm); } return builder.getIntVal(poly0->getType()); @@ -1181,7 +1281,8 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) PolynomialIntValBuilder builder(astBuilder); auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1); auto term = astBuilder->getOrCreate<PolynomialIntValTerm>( - cVal1->getValue(), makeArrayView(&factor0, 1)); + cVal1->getValue(), + makeArrayView(&factor0, 1)); builder.terms.add(term); return builder.getIntVal(val0->getType()); } @@ -1190,7 +1291,7 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) PolynomialIntValBuilder builder(astBuilder); auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1); auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1); - PolynomialIntValFactor* newFactors[] = { factor0, factor1 }; + PolynomialIntValFactor* newFactors[] = {factor0, factor1}; auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(1, makeArrayView(newFactors)); builder.terms.add(term); return builder.getIntVal(val0->getType()); @@ -1209,43 +1310,30 @@ void TypeCastIntVal::_toTextOverride(StringBuilder& out) out << ")"; } -Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) +Val* TypeCastIntVal::tryFoldImpl( + ASTBuilder* astBuilder, + Type* resultType, + Val* base, + DiagnosticSink* sink) { SLANG_UNUSED(sink); auto convertValue = [&](BasicExpressionType* baseType, IntegerLiteralValue& resultValue) -> bool + { + switch (baseType->getBaseType()) { - switch (baseType->getBaseType()) - { - case BaseType::Int: - resultValue = (int)resultValue; - return true; - case BaseType::UInt: - resultValue = (unsigned int)resultValue; - return true; - case BaseType::Int64: - case BaseType::IntPtr: - resultValue = (Int64)resultValue; - return true; - case BaseType::UInt64: - case BaseType::UIntPtr: - resultValue = (UInt64)resultValue; - return true; - case BaseType::Int16: - resultValue = (int16_t)resultValue; - return true; - case BaseType::UInt16: - resultValue = (uint16_t)resultValue; - return true; - case BaseType::Int8: - resultValue = (int8_t)resultValue; - return true; - case BaseType::UInt8: - resultValue = (uint8_t)resultValue; - return true; - default: - return false; - } - }; + case BaseType::Int: resultValue = (int)resultValue; return true; + case BaseType::UInt: resultValue = (unsigned int)resultValue; return true; + case BaseType::Int64: + case BaseType::IntPtr: resultValue = (Int64)resultValue; return true; + case BaseType::UInt64: + case BaseType::UIntPtr: resultValue = (UInt64)resultValue; return true; + case BaseType::Int16: resultValue = (int16_t)resultValue; return true; + case BaseType::UInt16: resultValue = (uint16_t)resultValue; return true; + case BaseType::Int8: resultValue = (int8_t)resultValue; return true; + case BaseType::UInt8: resultValue = (uint8_t)resultValue; return true; + default: return false; + } + }; if (auto c = as<ConstantIntVal>(base)) { IntegerLiteralValue resultValue = c->getValue(); @@ -1275,7 +1363,10 @@ Val* TypeCastIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) return tryFoldImpl(getCurrentASTBuilder(), getType(), resolvedBase, nullptr); } -Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* TypeCastIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff); @@ -1332,7 +1423,8 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) { argToText(0); out << (name ? name->text : ""); - argToText(1);; + argToText(1); + ; } else if (args.getCount() == 1) { @@ -1356,7 +1448,8 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) out << "("; for (Index i = 0; i < args.getCount(); i++) { - if (i > 0) out << ", "; + if (i > 0) + out << ", "; args[i]->toText(out); } out << ")"; @@ -1371,7 +1464,7 @@ Val* FuncCallIntVal::_resolveImplOverride() auto funcType = getFuncType(); Val* resolvedVal = this; - + auto newFuncDeclRef = as<DeclRefBase>(funcDeclRef.declRefBase->resolve()); if (!newFuncDeclRef) return this; @@ -1391,12 +1484,21 @@ Val* FuncCallIntVal::_resolveImplOverride() resolvedVal = resolved; else if (diff) { - resolvedVal = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, funcType, newArgs.getArrayView()); + resolvedVal = astBuilder->getOrCreate<FuncCallIntVal>( + getType(), + newFuncDeclRef, + funcType, + newArgs.getArrayView()); } return resolvedVal; } -Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) +Val* FuncCallIntVal::tryFoldImpl( + ASTBuilder* astBuilder, + Type* resultType, + DeclRef<Decl> newFuncDecl, + List<IntVal*>& newArgs, + DiagnosticSink* sink) { // Are all args const now? List<ConstantIntVal*> constArgs; @@ -1422,46 +1524,48 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR const auto opNameSlice = opName->text.getUnownedSlice(); IntegerLiteralValue resultValue = 0; - - // Define convenience macros. + + // Define convenience macros. // The last macro used in the list *must* be // TERMINATING_CASE, as this handles the closing else, and matches if nothing else does. -#define BINARY_OPERATOR_CASE(op) \ - if (opNameSlice == toSlice(#op)) \ - { \ - resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ - } else - -#define DIV_OPERATOR_CASE(op) \ - if (opNameSlice == toSlice(#op)) \ - { \ - if (constArgs[1]->getValue() == 0) \ - { \ - if (sink) \ - sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ - return nullptr; \ - } \ - resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ - } else - -#define LOGICAL_OPERATOR_CASE(op) \ - if (opNameSlice == toSlice(#op)) \ - { \ - resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \ - } else +#define BINARY_OPERATOR_CASE(op) \ + if (opNameSlice == toSlice(#op)) \ + { \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ + } \ + else + +#define DIV_OPERATOR_CASE(op) \ + if (opNameSlice == toSlice(#op)) \ + { \ + if (constArgs[1]->getValue() == 0) \ + { \ + if (sink) \ + sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ + return nullptr; \ + } \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ + } \ + else + +#define LOGICAL_OPERATOR_CASE(op) \ + if (opNameSlice == toSlice(#op)) \ + { \ + resultValue = \ + (((constArgs[0]->getValue() != 0) op(constArgs[1]->getValue() != 0)) ? 1 : 0); \ + } \ + else #define SPECIAL_OPERATOR_CASE(op, IF_MATCH) \ - if (opNameSlice == toSlice(op)) \ - { \ - IF_MATCH \ - } else - -#define TERMINATING_CASE(MATCH) \ - { \ - MATCH \ - } + if (opNameSlice == toSlice(op)) \ + { \ + IF_MATCH \ + } \ + else + +#define TERMINATING_CASE(MATCH) {MATCH} // Handle the cases using the macros BINARY_OPERATOR_CASE(>=) @@ -1482,16 +1586,19 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR // Special cases need their "operator" names quoted. SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);) SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();) - SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();) + SPECIAL_OPERATOR_CASE("?:", + resultValue = constArgs[0]->getValue() != 0 + ? constArgs[1]->getValue() + : constArgs[2]->getValue();) TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");) return astBuilder->getIntVal(resultType, resultValue); // The macros for the cases are no longer needed so undef them all. #undef BINARY_OPERATOR_CASE -#undef DIV_OPERATOR_CASE +#undef DIV_OPERATOR_CASE #undef LOGICAL_OPERATOR_CASE -#undef SPECIAL_OPERATOR_CASE +#undef SPECIAL_OPERATOR_CASE #undef TERMINATING_CASE } return nullptr; @@ -1505,7 +1612,10 @@ Val* FuncCallIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) return tryFoldImpl(getCurrentASTBuilder(), getType(), getFuncDeclRef(), newArgs, nullptr); } -Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* FuncCallIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff); @@ -1526,7 +1636,11 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return newVal; else { - auto result = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView()); + auto result = astBuilder->getOrCreate<FuncCallIntVal>( + getType(), + newFuncDeclRef, + getFuncType(), + newArgs.getArrayView()); return result; } } @@ -1590,7 +1704,10 @@ Val* CountOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType return result; } -Val* CountOfIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* CountOfIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newType = as<Type>(getTypeArg()->substituteImpl(astBuilder, subst, &diff)); @@ -1644,7 +1761,10 @@ Val* WitnessLookupIntVal::_resolveImplOverride() return this; } -Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* WitnessLookupIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff); @@ -1669,16 +1789,17 @@ Val* WitnessLookupIntVal::tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* auto witnessEntry = tryLookUpRequirementWitness(astBuilder, witness, key); switch (witnessEntry.getFlavor()) { - case RequirementWitness::Flavor::val: - return witnessEntry.getVal(); - break; - default: - break; + case RequirementWitness::Flavor::val: return witnessEntry.getVal(); break; + default: break; } return nullptr; } -Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type) +Val* WitnessLookupIntVal::tryFold( + ASTBuilder* astBuilder, + SubtypeWitness* witness, + Decl* key, + Type* type) { if (auto result = tryFoldOrNull(astBuilder, witness, key)) return result; @@ -1693,7 +1814,10 @@ void DifferentiateVal::_toTextOverride(StringBuilder& out) out << ")"; } -Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* DifferentiateVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff); @@ -1742,7 +1866,9 @@ Val* PolynomialIntValTerm::_resolveImplOverride() } if (diff) - return astBuilder->getOrCreate<PolynomialIntValTerm>(getConstFactor(), newFactors.getArrayView()); + return astBuilder->getOrCreate<PolynomialIntValTerm>( + getConstFactor(), + newFactors.getArrayView()); return this; } |
