diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-04 15:47:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-04 15:47:39 -0700 |
| commit | a2d90fb275962da84611160f8ddd74d934a68dbd (patch) | |
| tree | 066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-check-constraint.cpp | |
| parent | 17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff) | |
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val.
* Update project files
* Fix warning.
* Fix.
* Fix.
* Remove `Val::_equalsImplOverride`.
* Rmove `Val::_getHashCodeOverride`.
* Remove `semanticVisitor` param from `resolve`.
* Cleanups.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 165 |
1 files changed, 77 insertions, 88 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 22a92bf0a..b9d33a1c1 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -65,14 +65,14 @@ namespace Slang // That is, the join of a vector and a scalar type is // a vector type with a joined element type. auto joinElementType = TryJoinTypes( - vectorType->elementType, + vectorType->getElementType(), scalarType); if(!joinElementType) return nullptr; return createVectorType( joinElementType, - vectorType->elementCount); + vectorType->getElementCount()); } Type* SemanticsVisitor::_tryJoinTypeWithInterface( @@ -110,11 +110,11 @@ namespace Slang for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++) { // Don't consider `type`, since we already know it doesn't work. - if(baseTypeFlavorIndex == Int(basicType->baseType)) + if(baseTypeFlavorIndex == Int(basicType->getBaseType())) continue; // Look up the type in our session. - auto candidateType = type->getASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); + auto candidateType = getCurrentASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); if(!candidateType) continue; @@ -186,8 +186,8 @@ namespace Slang { if (auto rightBasic = as<BasicExpressionType>(right)) { - auto leftFlavor = leftBasic->baseType; - auto rightFlavor = rightBasic->baseType; + auto leftFlavor = leftBasic->getBaseType(); + auto rightFlavor = rightBasic->getBaseType(); // TODO(tfoley): Need a special-case rule here that if // either operand is of type `half`, then we promote @@ -217,19 +217,19 @@ namespace Slang if(auto rightVector = as<VectorExpressionType>(right)) { // Check if the vector sizes match - if(!leftVector->elementCount->equalsVal(rightVector->elementCount)) + if(!leftVector->getElementCount()->equals(rightVector->getElementCount())) return nullptr; // Try to join the element types auto joinElementType = TryJoinTypes( - leftVector->elementType, - rightVector->elementType); + leftVector->getElementType(), + rightVector->getElementType()); if(!joinElementType) return nullptr; return createVectorType( joinElementType, - leftVector->elementCount); + leftVector->getElementCount()); } // We can also join a vector and a scalar @@ -242,7 +242,7 @@ namespace Slang // HACK: trying to work trait types in here... if(auto leftDeclRefType = as<DeclRefType>(left)) { - if( auto leftInterfaceRef = leftDeclRefType->declRef.as<InterfaceDecl>() ) + if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as<InterfaceDecl>() ) { // return _tryJoinTypeWithInterface(right, left); @@ -250,7 +250,7 @@ namespace Slang } if(auto rightDeclRefType = as<DeclRefType>(right)) { - if( auto rightInterfaceRef = rightDeclRefType->declRef.as<InterfaceDecl>() ) + if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as<InterfaceDecl>() ) { // return _tryJoinTypeWithInterface(left, right); @@ -263,10 +263,10 @@ namespace Slang return nullptr; } - SubstitutionSet SemanticsVisitor::trySolveConstraintSystem( + DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( ConstraintSystem* system, DeclRef<GenericDecl> genericDeclRef, - GenericSubstitution* substWithKnownGenericArgs) + ArrayView<Val*> knownGenericArgs) { // For now the "solver" is going to be ridiculously simplistic. @@ -288,9 +288,8 @@ namespace Slang for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef) ) { if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) - return SubstitutionSet(); + return DeclRef<Decl>(); } - SubstitutionSet resultSubst = genericDeclRef.getSubst(); // Once have built up the full list of constraints we are trying to satisfy, // we will attempt to solve for each parameter in a way that satisfies all @@ -310,10 +309,10 @@ namespace Slang // or not they are compatible with the constraints). // Count knownGenericArgCount = 0; - if (substWithKnownGenericArgs) + if (knownGenericArgs.getCount()) { - knownGenericArgCount = substWithKnownGenericArgs->getArgs().getCount(); - for (auto arg : substWithKnownGenericArgs->getArgs()) + knownGenericArgCount = knownGenericArgs.getCount(); + for (auto arg : knownGenericArgs) { args.add(arg); } @@ -364,7 +363,7 @@ namespace Slang if (!joinType) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } type = joinType; } @@ -375,7 +374,7 @@ namespace Slang if (!type) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } args.add(type); } @@ -417,10 +416,10 @@ namespace Slang } else { - if(!val->equalsVal(cVal)) + if(!val->equals(cVal)) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } } @@ -430,7 +429,7 @@ namespace Slang if (!val) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } args.add(val); } @@ -456,14 +455,10 @@ namespace Slang // search for a conformance `Robin : ISidekick`, which involved // apply the substitutions we already know... - GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), genericDeclRef.getDecl(), args.getArrayView()); - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef( - constraintDecl, - solvedSubst); + DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef( + genericDeclRef, args.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>(); // Extract the (substituted) sub- and super-type from the constraint. auto sub = getSub(m_astBuilder, constraintDeclRef); @@ -476,7 +471,7 @@ namespace Slang // not provide an explicit type parameter to specialize a generic // and the type parameter cannot be inferred from any arguments. // In this case, we should fail the constraint check. - return SubstitutionSet(); + return DeclRef<Decl>(); } // Search for a witness that shows the constraint is satisfied. @@ -492,7 +487,7 @@ namespace Slang // // TODO: Ideally we should print an error message in // this case, to let the user know why things failed. - return SubstitutionSet(); + return DeclRef<Decl>(); } // TODO: We may need to mark some constrains in our constraint @@ -505,13 +500,11 @@ namespace Slang { if (!c.satisfied) { - return SubstitutionSet(); + return DeclRef<Decl>(); } } - resultSubst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), genericDeclRef.getDecl(), args); - return resultSubst; + return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView()); } bool SemanticsVisitor::TryUnifyVals( @@ -533,7 +526,7 @@ namespace Slang { if (auto sndIntVal = as<ConstantIntVal>(snd)) { - return fstIntVal->value == sndIntVal->value; + return fstIntVal->getValue() == sndIntVal->getValue(); } } @@ -541,23 +534,23 @@ namespace Slang if (auto fstInt = as<IntVal>(fst)) { if (auto tc = as<TypeCastIntVal>(fstInt)) - fstInt = as<IntVal>(tc->base); + fstInt = as<IntVal>(tc->getBase()); if (auto sndInt = as<IntVal>(snd)) { if (auto tc = as<TypeCastIntVal>(sndInt)) - sndInt = as<IntVal>(tc->base); + sndInt = as<IntVal>(tc->getBase()); auto fstParam = as<GenericParamIntVal>(fstInt); auto sndParam = as<GenericParamIntVal>(sndInt); bool okay = false; if (fstParam) { - if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt)) + if(TryUnifyIntParam(constraints, fstParam->getDeclRef(), sndInt)) okay = true; } if (sndParam) { - if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt)) + if(TryUnifyIntParam(constraints, sndParam->getDeclRef(), fstInt)) okay = true; } return okay; @@ -568,8 +561,8 @@ namespace Slang { if (auto sndWit = as<DeclaredSubtypeWitness>(snd)) { - auto constraintDecl1 = fstWit->declRef.as<TypeConstraintDecl>(); - auto constraintDecl2 = sndWit->declRef.as<TypeConstraintDecl>(); + auto constraintDecl1 = fstWit->getDeclRef().as<TypeConstraintDecl>(); + auto constraintDecl2 = sndWit->getDeclRef().as<TypeConstraintDecl>(); SLANG_ASSERT(constraintDecl1); SLANG_ASSERT(constraintDecl2); return TryUnifyTypes(constraints, @@ -586,8 +579,8 @@ namespace Slang if (auto sndWit = as<SubtypeWitness>(snd)) { return TryUnifyTypes(constraints, - fstWit->sup, - sndWit->sup); + fstWit->getSup(), + sndWit->getSup()); } } @@ -597,35 +590,28 @@ namespace Slang //return false; } - bool SemanticsVisitor::tryUnifySubstitutions( - ConstraintSystem& constraints, - Substitutions* fst, - Substitutions* snd) + bool SemanticsVisitor::tryUnifyDeclRef( + ConstraintSystem& constraints, + DeclRefBase* fst, + DeclRefBase* snd) { - // They must both be NULL or non-NULL - if (!fst || !snd) - return !fst && !snd; - - if(auto fstGeneric = as<GenericSubstitution>(fst)) - { - if(auto sndGeneric = as<GenericSubstitution>(snd)) - { - return tryUnifyGenericSubstitutions( - constraints, - fstGeneric, - sndGeneric); - } - } - - // TODO: need to handle other cases here - - return false; + if (fst == snd) + return true; + if (fst == nullptr || snd == nullptr) + return false; + auto fstGen = SubstitutionSet(fst).findGenericAppDeclRef(); + auto sndGen = SubstitutionSet(snd).findGenericAppDeclRef(); + if (fstGen == sndGen) + return true; + if (fstGen == nullptr || sndGen == nullptr) + return false; + return tryUnifyGenericAppDeclRef(constraints, fstGen, sndGen); } - bool SemanticsVisitor::tryUnifyGenericSubstitutions( + bool SemanticsVisitor::tryUnifyGenericAppDeclRef( ConstraintSystem& constraints, - GenericSubstitution* fst, - GenericSubstitution* snd) + GenericAppDeclRef* fst, + GenericAppDeclRef* snd) { SLANG_ASSERT(fst); SLANG_ASSERT(snd); @@ -649,7 +635,10 @@ namespace Slang } // Their "base" specializations must unify - if (!tryUnifySubstitutions(constraints, fstGen->getOuter(), sndGen->getOuter())) + auto fstBase = fst->getBase(); + auto sndBase = snd->getBase(); + + if (!tryUnifyDeclRef(constraints, fstBase, sndBase)) { okay = false; } @@ -718,14 +707,14 @@ namespace Slang { if (auto fstDeclRefType = as<DeclRefType>(fst)) { - auto fstDeclRef = fstDeclRefType->declRef; + auto fstDeclRef = fstDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, snd); if (auto sndDeclRefType = as<DeclRefType>(snd)) { - auto sndDeclRef = sndDeclRefType->declRef; + auto sndDeclRef = sndDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, fst); @@ -735,10 +724,10 @@ namespace Slang // next we need to unify the substitutions applied // to each declaration reference. - if (!tryUnifySubstitutions( + if (!tryUnifyDeclRef( constraints, - fstDeclRef.getSubst(), - sndDeclRef.getSubst())) + fstDeclRef, + sndDeclRef)) { return false; } @@ -749,15 +738,15 @@ namespace Slang { if (auto sndFunType = as<FuncType>(snd)) { - const Index numParams = fstFunType->paramTypes.getCount(); - if(numParams != sndFunType->paramTypes.getCount()) + const Index numParams = fstFunType->getParamCount(); + if(numParams != sndFunType->getParamCount()) return false; for(Index i = 0; i < numParams; ++i) { - if(!TryUnifyTypes(constraints, fstFunType->paramTypes[i], sndFunType->paramTypes[i])) + if(!TryUnifyTypes(constraints, fstFunType->getParamType(i), sndFunType->getParamType(i))) return false; } - return TryUnifyTypes(constraints, fstFunType->resultType, sndFunType->resultType); + return TryUnifyTypes(constraints, fstFunType->getResultType(), sndFunType->getResultType()); } } @@ -779,13 +768,13 @@ namespace Slang // if (auto fstAndType = as<AndType>(fst)) { - return TryUnifyTypes(constraints, fstAndType->left, snd) - && TryUnifyTypes(constraints, fstAndType->right, snd); + return TryUnifyTypes(constraints, fstAndType->getLeft(), snd) + && TryUnifyTypes(constraints, fstAndType->getRight(), snd); } else if (auto sndAndType = as<AndType>(snd)) { - return TryUnifyTypes(constraints, fst, sndAndType->left) - || TryUnifyTypes(constraints, fst, sndAndType->right); + return TryUnifyTypes(constraints, fst, sndAndType->getLeft()) + || TryUnifyTypes(constraints, fst, sndAndType->getRight()); } else return false; @@ -828,7 +817,7 @@ namespace Slang if (auto fstDeclRefType = as<DeclRefType>(fst)) { - auto fstDeclRef = fstDeclRefType->declRef; + auto fstDeclRef = fstDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl())) { @@ -839,7 +828,7 @@ namespace Slang if (auto sndDeclRefType = as<DeclRefType>(snd)) { - auto sndDeclRef = sndDeclRefType->declRef; + auto sndDeclRef = sndDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl())) { @@ -863,7 +852,7 @@ namespace Slang { return TryUnifyTypes( constraints, - fstVectorType->elementType, + fstVectorType->getElementType(), sndScalarType); } } @@ -875,7 +864,7 @@ namespace Slang return TryUnifyTypes( constraints, fstScalarType, - sndVectorType->elementType); + sndVectorType->getElementType()); } } |
