summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-constraint.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-04 15:47:39 -0700
committerGitHub <noreply@github.com>2023-08-04 15:47:39 -0700
commita2d90fb275962da84611160f8ddd74d934a68dbd (patch)
tree066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-check-constraint.cpp
parent17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff)
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
-rw-r--r--source/slang/slang-check-constraint.cpp165
1 files changed, 77 insertions, 88 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 22a92bf0a..b9d33a1c1 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -65,14 +65,14 @@ namespace Slang
// That is, the join of a vector and a scalar type is
// a vector type with a joined element type.
auto joinElementType = TryJoinTypes(
- vectorType->elementType,
+ vectorType->getElementType(),
scalarType);
if(!joinElementType)
return nullptr;
return createVectorType(
joinElementType,
- vectorType->elementCount);
+ vectorType->getElementCount());
}
Type* SemanticsVisitor::_tryJoinTypeWithInterface(
@@ -110,11 +110,11 @@ namespace Slang
for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++)
{
// Don't consider `type`, since we already know it doesn't work.
- if(baseTypeFlavorIndex == Int(basicType->baseType))
+ if(baseTypeFlavorIndex == Int(basicType->getBaseType()))
continue;
// Look up the type in our session.
- auto candidateType = type->getASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex));
+ auto candidateType = getCurrentASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex));
if(!candidateType)
continue;
@@ -186,8 +186,8 @@ namespace Slang
{
if (auto rightBasic = as<BasicExpressionType>(right))
{
- auto leftFlavor = leftBasic->baseType;
- auto rightFlavor = rightBasic->baseType;
+ auto leftFlavor = leftBasic->getBaseType();
+ auto rightFlavor = rightBasic->getBaseType();
// TODO(tfoley): Need a special-case rule here that if
// either operand is of type `half`, then we promote
@@ -217,19 +217,19 @@ namespace Slang
if(auto rightVector = as<VectorExpressionType>(right))
{
// Check if the vector sizes match
- if(!leftVector->elementCount->equalsVal(rightVector->elementCount))
+ if(!leftVector->getElementCount()->equals(rightVector->getElementCount()))
return nullptr;
// Try to join the element types
auto joinElementType = TryJoinTypes(
- leftVector->elementType,
- rightVector->elementType);
+ leftVector->getElementType(),
+ rightVector->getElementType());
if(!joinElementType)
return nullptr;
return createVectorType(
joinElementType,
- leftVector->elementCount);
+ leftVector->getElementCount());
}
// We can also join a vector and a scalar
@@ -242,7 +242,7 @@ namespace Slang
// HACK: trying to work trait types in here...
if(auto leftDeclRefType = as<DeclRefType>(left))
{
- if( auto leftInterfaceRef = leftDeclRefType->declRef.as<InterfaceDecl>() )
+ if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as<InterfaceDecl>() )
{
//
return _tryJoinTypeWithInterface(right, left);
@@ -250,7 +250,7 @@ namespace Slang
}
if(auto rightDeclRefType = as<DeclRefType>(right))
{
- if( auto rightInterfaceRef = rightDeclRefType->declRef.as<InterfaceDecl>() )
+ if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as<InterfaceDecl>() )
{
//
return _tryJoinTypeWithInterface(left, right);
@@ -263,10 +263,10 @@ namespace Slang
return nullptr;
}
- SubstitutionSet SemanticsVisitor::trySolveConstraintSystem(
+ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
- GenericSubstitution* substWithKnownGenericArgs)
+ ArrayView<Val*> knownGenericArgs)
{
// For now the "solver" is going to be ridiculously simplistic.
@@ -288,9 +288,8 @@ namespace Slang
for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef) )
{
if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef)))
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
- SubstitutionSet resultSubst = genericDeclRef.getSubst();
// Once have built up the full list of constraints we are trying to satisfy,
// we will attempt to solve for each parameter in a way that satisfies all
@@ -310,10 +309,10 @@ namespace Slang
// or not they are compatible with the constraints).
//
Count knownGenericArgCount = 0;
- if (substWithKnownGenericArgs)
+ if (knownGenericArgs.getCount())
{
- knownGenericArgCount = substWithKnownGenericArgs->getArgs().getCount();
- for (auto arg : substWithKnownGenericArgs->getArgs())
+ knownGenericArgCount = knownGenericArgs.getCount();
+ for (auto arg : knownGenericArgs)
{
args.add(arg);
}
@@ -364,7 +363,7 @@ namespace Slang
if (!joinType)
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
type = joinType;
}
@@ -375,7 +374,7 @@ namespace Slang
if (!type)
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
args.add(type);
}
@@ -417,10 +416,10 @@ namespace Slang
}
else
{
- if(!val->equalsVal(cVal))
+ if(!val->equals(cVal))
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
}
@@ -430,7 +429,7 @@ namespace Slang
if (!val)
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
args.add(val);
}
@@ -456,14 +455,10 @@ namespace Slang
// search for a conformance `Robin : ISidekick`, which involved
// apply the substitutions we already know...
- GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(), genericDeclRef.getDecl(), args.getArrayView());
-
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
- DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef(
- constraintDecl,
- solvedSubst);
+ DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
+ genericDeclRef, args.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>();
// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
@@ -476,7 +471,7 @@ namespace Slang
// not provide an explicit type parameter to specialize a generic
// and the type parameter cannot be inferred from any arguments.
// In this case, we should fail the constraint check.
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
// Search for a witness that shows the constraint is satisfied.
@@ -492,7 +487,7 @@ namespace Slang
//
// TODO: Ideally we should print an error message in
// this case, to let the user know why things failed.
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
// TODO: We may need to mark some constrains in our constraint
@@ -505,13 +500,11 @@ namespace Slang
{
if (!c.satisfied)
{
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
}
- resultSubst = m_astBuilder->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(), genericDeclRef.getDecl(), args);
- return resultSubst;
+ return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView());
}
bool SemanticsVisitor::TryUnifyVals(
@@ -533,7 +526,7 @@ namespace Slang
{
if (auto sndIntVal = as<ConstantIntVal>(snd))
{
- return fstIntVal->value == sndIntVal->value;
+ return fstIntVal->getValue() == sndIntVal->getValue();
}
}
@@ -541,23 +534,23 @@ namespace Slang
if (auto fstInt = as<IntVal>(fst))
{
if (auto tc = as<TypeCastIntVal>(fstInt))
- fstInt = as<IntVal>(tc->base);
+ fstInt = as<IntVal>(tc->getBase());
if (auto sndInt = as<IntVal>(snd))
{
if (auto tc = as<TypeCastIntVal>(sndInt))
- sndInt = as<IntVal>(tc->base);
+ sndInt = as<IntVal>(tc->getBase());
auto fstParam = as<GenericParamIntVal>(fstInt);
auto sndParam = as<GenericParamIntVal>(sndInt);
bool okay = false;
if (fstParam)
{
- if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt))
+ if(TryUnifyIntParam(constraints, fstParam->getDeclRef(), sndInt))
okay = true;
}
if (sndParam)
{
- if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt))
+ if(TryUnifyIntParam(constraints, sndParam->getDeclRef(), fstInt))
okay = true;
}
return okay;
@@ -568,8 +561,8 @@ namespace Slang
{
if (auto sndWit = as<DeclaredSubtypeWitness>(snd))
{
- auto constraintDecl1 = fstWit->declRef.as<TypeConstraintDecl>();
- auto constraintDecl2 = sndWit->declRef.as<TypeConstraintDecl>();
+ auto constraintDecl1 = fstWit->getDeclRef().as<TypeConstraintDecl>();
+ auto constraintDecl2 = sndWit->getDeclRef().as<TypeConstraintDecl>();
SLANG_ASSERT(constraintDecl1);
SLANG_ASSERT(constraintDecl2);
return TryUnifyTypes(constraints,
@@ -586,8 +579,8 @@ namespace Slang
if (auto sndWit = as<SubtypeWitness>(snd))
{
return TryUnifyTypes(constraints,
- fstWit->sup,
- sndWit->sup);
+ fstWit->getSup(),
+ sndWit->getSup());
}
}
@@ -597,35 +590,28 @@ namespace Slang
//return false;
}
- bool SemanticsVisitor::tryUnifySubstitutions(
- ConstraintSystem& constraints,
- Substitutions* fst,
- Substitutions* snd)
+ bool SemanticsVisitor::tryUnifyDeclRef(
+ ConstraintSystem& constraints,
+ DeclRefBase* fst,
+ DeclRefBase* snd)
{
- // They must both be NULL or non-NULL
- if (!fst || !snd)
- return !fst && !snd;
-
- if(auto fstGeneric = as<GenericSubstitution>(fst))
- {
- if(auto sndGeneric = as<GenericSubstitution>(snd))
- {
- return tryUnifyGenericSubstitutions(
- constraints,
- fstGeneric,
- sndGeneric);
- }
- }
-
- // TODO: need to handle other cases here
-
- return false;
+ if (fst == snd)
+ return true;
+ if (fst == nullptr || snd == nullptr)
+ return false;
+ auto fstGen = SubstitutionSet(fst).findGenericAppDeclRef();
+ auto sndGen = SubstitutionSet(snd).findGenericAppDeclRef();
+ if (fstGen == sndGen)
+ return true;
+ if (fstGen == nullptr || sndGen == nullptr)
+ return false;
+ return tryUnifyGenericAppDeclRef(constraints, fstGen, sndGen);
}
- bool SemanticsVisitor::tryUnifyGenericSubstitutions(
+ bool SemanticsVisitor::tryUnifyGenericAppDeclRef(
ConstraintSystem& constraints,
- GenericSubstitution* fst,
- GenericSubstitution* snd)
+ GenericAppDeclRef* fst,
+ GenericAppDeclRef* snd)
{
SLANG_ASSERT(fst);
SLANG_ASSERT(snd);
@@ -649,7 +635,10 @@ namespace Slang
}
// Their "base" specializations must unify
- if (!tryUnifySubstitutions(constraints, fstGen->getOuter(), sndGen->getOuter()))
+ auto fstBase = fst->getBase();
+ auto sndBase = snd->getBase();
+
+ if (!tryUnifyDeclRef(constraints, fstBase, sndBase))
{
okay = false;
}
@@ -718,14 +707,14 @@ namespace Slang
{
if (auto fstDeclRefType = as<DeclRefType>(fst))
{
- auto fstDeclRef = fstDeclRefType->declRef;
+ auto fstDeclRef = fstDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
return TryUnifyTypeParam(constraints, typeParamDecl, snd);
if (auto sndDeclRefType = as<DeclRefType>(snd))
{
- auto sndDeclRef = sndDeclRefType->declRef;
+ auto sndDeclRef = sndDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
return TryUnifyTypeParam(constraints, typeParamDecl, fst);
@@ -735,10 +724,10 @@ namespace Slang
// next we need to unify the substitutions applied
// to each declaration reference.
- if (!tryUnifySubstitutions(
+ if (!tryUnifyDeclRef(
constraints,
- fstDeclRef.getSubst(),
- sndDeclRef.getSubst()))
+ fstDeclRef,
+ sndDeclRef))
{
return false;
}
@@ -749,15 +738,15 @@ namespace Slang
{
if (auto sndFunType = as<FuncType>(snd))
{
- const Index numParams = fstFunType->paramTypes.getCount();
- if(numParams != sndFunType->paramTypes.getCount())
+ const Index numParams = fstFunType->getParamCount();
+ if(numParams != sndFunType->getParamCount())
return false;
for(Index i = 0; i < numParams; ++i)
{
- if(!TryUnifyTypes(constraints, fstFunType->paramTypes[i], sndFunType->paramTypes[i]))
+ if(!TryUnifyTypes(constraints, fstFunType->getParamType(i), sndFunType->getParamType(i)))
return false;
}
- return TryUnifyTypes(constraints, fstFunType->resultType, sndFunType->resultType);
+ return TryUnifyTypes(constraints, fstFunType->getResultType(), sndFunType->getResultType());
}
}
@@ -779,13 +768,13 @@ namespace Slang
//
if (auto fstAndType = as<AndType>(fst))
{
- return TryUnifyTypes(constraints, fstAndType->left, snd)
- && TryUnifyTypes(constraints, fstAndType->right, snd);
+ return TryUnifyTypes(constraints, fstAndType->getLeft(), snd)
+ && TryUnifyTypes(constraints, fstAndType->getRight(), snd);
}
else if (auto sndAndType = as<AndType>(snd))
{
- return TryUnifyTypes(constraints, fst, sndAndType->left)
- || TryUnifyTypes(constraints, fst, sndAndType->right);
+ return TryUnifyTypes(constraints, fst, sndAndType->getLeft())
+ || TryUnifyTypes(constraints, fst, sndAndType->getRight());
}
else
return false;
@@ -828,7 +817,7 @@ namespace Slang
if (auto fstDeclRefType = as<DeclRefType>(fst))
{
- auto fstDeclRef = fstDeclRefType->declRef;
+ auto fstDeclRef = fstDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
{
@@ -839,7 +828,7 @@ namespace Slang
if (auto sndDeclRefType = as<DeclRefType>(snd))
{
- auto sndDeclRef = sndDeclRefType->declRef;
+ auto sndDeclRef = sndDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
{
@@ -863,7 +852,7 @@ namespace Slang
{
return TryUnifyTypes(
constraints,
- fstVectorType->elementType,
+ fstVectorType->getElementType(),
sndScalarType);
}
}
@@ -875,7 +864,7 @@ namespace Slang
return TryUnifyTypes(
constraints,
fstScalarType,
- sndVectorType->elementType);
+ sndVectorType->getElementType());
}
}