summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-constraint.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
-rw-r--r--source/slang/slang-check-constraint.cpp740
1 files changed, 740 insertions, 0 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
new file mode 100644
index 000000000..e12997904
--- /dev/null
+++ b/source/slang/slang-check-constraint.cpp
@@ -0,0 +1,740 @@
+// slang-check-constraint.cpp
+#include "slang-check-impl.h"
+
+// This file provides the core services for creating
+// and solving constraint systems during semantic checking.
+//
+// We currently use constraint systems primarily to solve
+// for the implied values to use for generic parameters when a
+// generic declaration is being applied without explicit
+// generic arguments.
+//
+// Conceptually, our constraint-solving strategy starts by
+// trying to "unify" the actual argument types to a call
+// with the parameter types of the callee (which may mention
+// generic parameters). E.g., if we have a situation like:
+//
+// void doIt<T>(T a, vector<T,3> b);
+//
+// int x, y;
+// ...
+// doIt(x, y);
+//
+// then an we would try to unify the type of the argument
+// `x` (which is `int`) with the type of the parameter `a`
+// (which is `T`). Attempting to unify a concrete type
+// and a generic type parameter would (in the simplest case)
+// give rise to a constraint that, e.g., `T` must be `int`.
+//
+// In our example, unifying `y` and `b` creates a more complex
+// scenario, because we cannot ever unify `int` with `vector<T,3>`;
+// there is no possible value of `T` for which those two types
+// are equivalent.
+//
+// So instead of the simpler approach to unification (which
+// works well for languages without implicit type conversion),
+// our approach to unification recognizes that scalar types
+// can be promoted to vectors, and thus tries to unify the
+// type of `y` with the element type of `b`.
+//
+// When it comes time to actually solve the constraints, we
+// might have seemingly conflicting constraints:
+//
+// void another<U>(U a, U b);
+//
+// float x; int y;
+// another(x, y);
+//
+// In this case we'd have constraints that `U` must be `int`,
+// *and* that `U` must be `float`, which is clearly impossible
+// to satisfy. Instead, our constraints are treated as a kind
+// of "lower bound" on the type variable, and we combine
+// those lower bounds using the "join" operation (in the
+// sense of "meet" and "join" on lattices), which ideally
+// gives us a type for `U` that all the argument types can
+// convert to.
+
+namespace Slang
+{
+ RefPtr<Type> SemanticsVisitor::TryJoinVectorAndScalarType(
+ RefPtr<VectorExpressionType> vectorType,
+ RefPtr<BasicExpressionType> scalarType)
+ {
+ // Join( vector<T,N>, S ) -> vetor<Join(T,S), N>
+ //
+ // That is, the join of a vector and a scalar type is
+ // a vector type with a joined element type.
+ auto joinElementType = TryJoinTypes(
+ vectorType->elementType,
+ scalarType);
+ if(!joinElementType)
+ return nullptr;
+
+ return createVectorType(
+ joinElementType,
+ vectorType->elementCount);
+ }
+
+ RefPtr<Type> SemanticsVisitor::TryJoinTypeWithInterface(
+ RefPtr<Type> type,
+ DeclRef<InterfaceDecl> interfaceDeclRef)
+ {
+ // The most basic test here should be: does the type declare conformance to the trait.
+ if(DoesTypeConformToInterface(type, interfaceDeclRef))
+ return type;
+
+ // Just because `type` doesn't conform to the given `interfaceDeclRef`, that
+ // doesn't necessarily indicate a failure. It is possible that we have a call
+ // like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is
+ // `__BuiltinFloatingPointType`. The "obvious" answer is that we should infer
+ // the type `float`, but it seems like the compiler would have to synthesize
+ // that answer from thin air.
+ //
+ // A robsut/correct solution here might be to enumerate set of types types `S`
+ // such that for each type `X` in `S`:
+ //
+ // * `type` is implicitly convertible to `X`
+ // * `X` conforms to the interface named by `interfaceDeclRef`
+ //
+ // If the set `S` is non-empty then we would try to pick the "best" type from `S`.
+ // The "best" type would be a type `Y` such that `Y` is implicitly convertible to
+ // every other type in `S`.
+ //
+ // We are going to implement a much simpler strategy for now, where we only apply
+ // the search process if `type` is a builtin scalar type, and then we only search
+ // through types `X` that are also builtin scalar types.
+ //
+ RefPtr<Type> bestType;
+ if(auto basicType = type.dynamicCast<BasicExpressionType>())
+ {
+ for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++)
+ {
+ // Don't consider `type`, since we already know it doesn't work.
+ if(baseTypeFlavorIndex == Int(basicType->baseType))
+ continue;
+
+ // Look up the type in our session.
+ auto candidateType = type->getSession()->getBuiltinType(BaseType(baseTypeFlavorIndex));
+ if(!candidateType)
+ continue;
+
+ // We only want to consider types that implement the target interface.
+ if(!DoesTypeConformToInterface(candidateType, interfaceDeclRef))
+ continue;
+
+ // We only want to consider types where we can implicitly convert from `type`
+ if(!canConvertImplicitly(candidateType, type))
+ continue;
+
+ // At this point, we have a candidate type that is usable.
+ //
+ // If this is our first viable candidate, then it is our best one:
+ //
+ if(!bestType)
+ {
+ bestType = candidateType;
+ }
+ else
+ {
+ // Otherwise, we want to pick the "better" type between `candidateType`
+ // and `bestType`.
+ //
+ // We are going to be a bit loose here, and not worry about the
+ // case where conversion is allowed in both directions.
+ //
+ // TODO: make this completely robust.
+ //
+ if(canConvertImplicitly(bestType, candidateType))
+ {
+ // Our candidate can convert to the current "best" type, so
+ // it is logically a more specific type that satisfies our
+ // constraints, therefore we should keep it.
+ //
+ bestType = candidateType;
+ }
+ }
+ }
+ if(bestType)
+ return bestType;
+ }
+
+ // For all other cases, we will just bail out for now.
+ //
+ // TODO: In the future we should build some kind of side data structure
+ // to accelerate either one or both of these queries:
+ //
+ // * Given a type `T`, what types `U` can it convert to implicitly?
+ //
+ // * Given an interface `I`, what types `U` conform to it?
+ //
+ // The intersection of the sets returned by these two queries is
+ // the set of candidates we would like to consider here.
+
+ return nullptr;
+ }
+
+ RefPtr<Type> SemanticsVisitor::TryJoinTypes(
+ RefPtr<Type> left,
+ RefPtr<Type> right)
+ {
+ // Easy case: they are the same type!
+ if (left->Equals(right))
+ return left;
+
+ // We can join two basic types by picking the "better" of the two
+ if (auto leftBasic = as<BasicExpressionType>(left))
+ {
+ if (auto rightBasic = as<BasicExpressionType>(right))
+ {
+ auto leftFlavor = leftBasic->baseType;
+ auto rightFlavor = rightBasic->baseType;
+
+ // TODO(tfoley): Need a special-case rule here that if
+ // either operand is of type `half`, then we promote
+ // to at least `float`
+
+ // Return the one that had higher rank...
+ if (leftFlavor > rightFlavor)
+ return left;
+ else
+ {
+ SLANG_ASSERT(rightFlavor > leftFlavor); // equality was handles at the top of this function
+ return right;
+ }
+ }
+
+ // We can also join a vector and a scalar
+ if(auto rightVector = as<VectorExpressionType>(right))
+ {
+ return TryJoinVectorAndScalarType(rightVector, leftBasic);
+ }
+ }
+
+ // We can join two vector types by joining their element types
+ // (and also their sizes...)
+ if( auto leftVector = as<VectorExpressionType>(left))
+ {
+ if(auto rightVector = as<VectorExpressionType>(right))
+ {
+ // Check if the vector sizes match
+ if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr()))
+ return nullptr;
+
+ // Try to join the element types
+ auto joinElementType = TryJoinTypes(
+ leftVector->elementType,
+ rightVector->elementType);
+ if(!joinElementType)
+ return nullptr;
+
+ return createVectorType(
+ joinElementType,
+ leftVector->elementCount);
+ }
+
+ // We can also join a vector and a scalar
+ if(auto rightBasic = as<BasicExpressionType>(right))
+ {
+ return TryJoinVectorAndScalarType(leftVector, rightBasic);
+ }
+ }
+
+ // HACK: trying to work trait types in here...
+ if(auto leftDeclRefType = as<DeclRefType>(left))
+ {
+ if( auto leftInterfaceRef = leftDeclRefType->declRef.as<InterfaceDecl>() )
+ {
+ //
+ return TryJoinTypeWithInterface(right, leftInterfaceRef);
+ }
+ }
+ if(auto rightDeclRefType = as<DeclRefType>(right))
+ {
+ if( auto rightInterfaceRef = rightDeclRefType->declRef.as<InterfaceDecl>() )
+ {
+ //
+ return TryJoinTypeWithInterface(left, rightInterfaceRef);
+ }
+ }
+
+ // TODO: all the cases for vectors apply to matrices too!
+
+ // Default case is that we just fail.
+ return nullptr;
+ }
+
+ SubstitutionSet SemanticsVisitor::TrySolveConstraintSystem(
+ ConstraintSystem* system,
+ DeclRef<GenericDecl> genericDeclRef)
+ {
+ // For now the "solver" is going to be ridiculously simplistic.
+
+ // The generic itself will have some constraints, and for now we add these
+ // to the system of constrains we will use for solving for the type variables.
+ //
+ // TODO: we need to decide whether constraints are used like this to influence
+ // how we solve for type/value variables, or whether constraints in the parameter
+ // list just work as a validation step *after* we've solved for the types.
+ //
+ // That is, should we allow `<T : Int>` to be written, and cause us to "infer"
+ // that `T` should be the type `Int`? That seems a little silly.
+ //
+ // Eventually, though, we may want to support type identity constraints, especially
+ // on associated types, like `<C where C : IContainer && C.IndexType == Int>`
+ // These seem more reasonable to have influence constraint solving, since it could
+ // conceivably let us specialize a `X<T> : IContainer` to `X<Int>` if we find
+ // that `X<T>.IndexType == T`.
+ for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef) )
+ {
+ if(!TryUnifyTypes(*system, GetSub(constraintDeclRef), GetSup(constraintDeclRef)))
+ return SubstitutionSet();
+ }
+ SubstitutionSet resultSubst = genericDeclRef.substitutions;
+ // We will loop over the generic parameters, and for
+ // each we will try to find a way to satisfy all
+ // the constraints for that parameter
+ List<RefPtr<Val>> args;
+ for (auto m : getMembers(genericDeclRef))
+ {
+ if (auto typeParam = m.as<GenericTypeParamDecl>())
+ {
+ RefPtr<Type> type = nullptr;
+ for (auto& c : system->constraints)
+ {
+ if (c.decl != typeParam.getDecl())
+ continue;
+
+ auto cType = as<Type>(c.val);
+ SLANG_RELEASE_ASSERT(cType);
+
+ if (!type)
+ {
+ type = cType;
+ }
+ else
+ {
+ auto joinType = TryJoinTypes(type, cType);
+ if (!joinType)
+ {
+ // failure!
+ return SubstitutionSet();
+ }
+ type = joinType;
+ }
+
+ c.satisfied = true;
+ }
+
+ if (!type)
+ {
+ // failure!
+ return SubstitutionSet();
+ }
+ args.add(type);
+ }
+ else if (auto valParam = m.as<GenericValueParamDecl>())
+ {
+ // TODO(tfoley): maybe support more than integers some day?
+ // TODO(tfoley): figure out how this needs to interact with
+ // compile-time integers that aren't just constants...
+ RefPtr<IntVal> val = nullptr;
+ for (auto& c : system->constraints)
+ {
+ if (c.decl != valParam.getDecl())
+ continue;
+
+ auto cVal = as<IntVal>(c.val);
+ SLANG_RELEASE_ASSERT(cVal);
+
+ if (!val)
+ {
+ val = cVal;
+ }
+ else
+ {
+ if(!val->EqualsVal(cVal))
+ {
+ // failure!
+ return SubstitutionSet();
+ }
+ }
+
+ c.satisfied = true;
+ }
+
+ if (!val)
+ {
+ // failure!
+ return SubstitutionSet();
+ }
+ args.add(val);
+ }
+ else
+ {
+ // ignore anything that isn't a generic parameter
+ }
+ }
+
+ // After we've solved for the explicit arguments, we need to
+ // make a second pass and consider the implicit arguments,
+ // based on what we've already determined to be the values
+ // for the explicit arguments.
+
+ // Before we begin, we are going to go ahead and create the
+ // "solved" substitution that we will return if everything works.
+ // This is because we are going to use this substitution,
+ // partially filled in with the results we know so far,
+ // in order to specialize any constraints on the generic.
+ //
+ // E.g., if the generic parameters were `<T : ISidekick>`, and
+ // we've already decided that `T` is `Robin`, then we want to
+ // search for a conformance `Robin : ISidekick`, which involved
+ // apply the substitutions we already know...
+
+ RefPtr<GenericSubstitution> solvedSubst = new GenericSubstitution();
+ solvedSubst->genericDecl = genericDeclRef.getDecl();
+ solvedSubst->outer = genericDeclRef.substitutions.substitutions;
+ solvedSubst->args = args;
+ resultSubst.substitutions = solvedSubst;
+
+ for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
+ {
+ DeclRef<GenericTypeConstraintDecl> constraintDeclRef(
+ constraintDecl,
+ solvedSubst);
+
+ // Extract the (substituted) sub- and super-type from the constraint.
+ auto sub = GetSub(constraintDeclRef);
+ auto sup = GetSup(constraintDeclRef);
+
+ // Search for a witness that shows the constraint is satisfied.
+ auto subTypeWitness = tryGetSubtypeWitness(sub, sup);
+ if(subTypeWitness)
+ {
+ // We found a witness, so it will become an (implicit) argument.
+ solvedSubst->args.add(subTypeWitness);
+ }
+ else
+ {
+ // No witness was found, so the inference will now fail.
+ //
+ // TODO: Ideally we should print an error message in
+ // this case, to let the user know why things failed.
+ return SubstitutionSet();
+ }
+
+ // TODO: We may need to mark some constrains in our constraint
+ // system as being solved now, as a result of the witness we found.
+ }
+
+ // Make sure we haven't constructed any spurious constraints
+ // that we aren't able to satisfy:
+ for (auto c : system->constraints)
+ {
+ if (!c.satisfied)
+ {
+ return SubstitutionSet();
+ }
+ }
+
+ return resultSubst;
+ }
+
+ bool SemanticsVisitor::TryUnifyVals(
+ ConstraintSystem& constraints,
+ RefPtr<Val> fst,
+ RefPtr<Val> snd)
+ {
+ // if both values are types, then unify types
+ if (auto fstType = as<Type>(fst))
+ {
+ if (auto sndType = as<Type>(snd))
+ {
+ return TryUnifyTypes(constraints, fstType, sndType);
+ }
+ }
+
+ // if both values are constant integers, then compare them
+ if (auto fstIntVal = as<ConstantIntVal>(fst))
+ {
+ if (auto sndIntVal = as<ConstantIntVal>(snd))
+ {
+ return fstIntVal->value == sndIntVal->value;
+ }
+ }
+
+ // Check if both are integer values in general
+ if (auto fstInt = as<IntVal>(fst))
+ {
+ if (auto sndInt = as<IntVal>(snd))
+ {
+ auto fstParam = as<GenericParamIntVal>(fstInt);
+ auto sndParam = as<GenericParamIntVal>(sndInt);
+
+ bool okay = false;
+ if (fstParam)
+ {
+ if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt))
+ okay = true;
+ }
+ if (sndParam)
+ {
+ if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt))
+ okay = true;
+ }
+ return okay;
+ }
+ }
+
+ if (auto fstWit = as<DeclaredSubtypeWitness>(fst))
+ {
+ if (auto sndWit = as<DeclaredSubtypeWitness>(snd))
+ {
+ auto constraintDecl1 = fstWit->declRef.as<TypeConstraintDecl>();
+ auto constraintDecl2 = sndWit->declRef.as<TypeConstraintDecl>();
+ SLANG_ASSERT(constraintDecl1);
+ SLANG_ASSERT(constraintDecl2);
+ return TryUnifyTypes(constraints,
+ constraintDecl1.getDecl()->getSup().type,
+ constraintDecl2.getDecl()->getSup().type);
+ }
+ }
+
+ SLANG_UNIMPLEMENTED_X("value unification case");
+
+ // default: fail
+ return false;
+ }
+
+ bool SemanticsVisitor::tryUnifySubstitutions(
+ ConstraintSystem& constraints,
+ RefPtr<Substitutions> fst,
+ RefPtr<Substitutions> snd)
+ {
+ // They must both be NULL or non-NULL
+ if (!fst || !snd)
+ return !fst && !snd;
+
+ if(auto fstGeneric = as<GenericSubstitution>(fst))
+ {
+ if(auto sndGeneric = as<GenericSubstitution>(snd))
+ {
+ return tryUnifyGenericSubstitutions(
+ constraints,
+ fstGeneric,
+ sndGeneric);
+ }
+ }
+
+ // TODO: need to handle other cases here
+
+ return false;
+ }
+
+ bool SemanticsVisitor::tryUnifyGenericSubstitutions(
+ ConstraintSystem& constraints,
+ RefPtr<GenericSubstitution> fst,
+ RefPtr<GenericSubstitution> snd)
+ {
+ SLANG_ASSERT(fst);
+ SLANG_ASSERT(snd);
+
+ auto fstGen = fst;
+ auto sndGen = snd;
+ // They must be specializing the same generic
+ if (fstGen->genericDecl != sndGen->genericDecl)
+ return false;
+
+ // Their arguments must unify
+ SLANG_RELEASE_ASSERT(fstGen->args.getCount() == sndGen->args.getCount());
+ Index argCount = fstGen->args.getCount();
+ bool okay = true;
+ for (Index aa = 0; aa < argCount; ++aa)
+ {
+ if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa]))
+ {
+ okay = false;
+ }
+ }
+
+ // Their "base" specializations must unify
+ if (!tryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer))
+ {
+ okay = false;
+ }
+
+ return okay;
+ }
+
+ bool SemanticsVisitor::TryUnifyTypeParam(
+ ConstraintSystem& constraints,
+ RefPtr<GenericTypeParamDecl> typeParamDecl,
+ RefPtr<Type> type)
+ {
+ // We want to constrain the given type parameter
+ // to equal the given type.
+ Constraint constraint;
+ constraint.decl = typeParamDecl.Ptr();
+ constraint.val = type;
+
+ constraints.constraints.add(constraint);
+
+ return true;
+ }
+
+ bool SemanticsVisitor::TryUnifyIntParam(
+ ConstraintSystem& constraints,
+ RefPtr<GenericValueParamDecl> paramDecl,
+ RefPtr<IntVal> val)
+ {
+ // We only want to accumulate constraints on
+ // the parameters of the declarations being
+ // specialized (don't accidentially constrain
+ // parameters of a generic function based on
+ // calls in its body).
+ if(paramDecl->ParentDecl != constraints.genericDecl)
+ return false;
+
+ // We want to constrain the given parameter to equal the given value.
+ Constraint constraint;
+ constraint.decl = paramDecl.Ptr();
+ constraint.val = val;
+
+ constraints.constraints.add(constraint);
+
+ return true;
+ }
+
+ bool SemanticsVisitor::TryUnifyIntParam(
+ ConstraintSystem& constraints,
+ DeclRef<VarDeclBase> const& varRef,
+ RefPtr<IntVal> val)
+ {
+ if(auto genericValueParamRef = varRef.as<GenericValueParamDecl>())
+ {
+ return TryUnifyIntParam(constraints, RefPtr<GenericValueParamDecl>(genericValueParamRef.getDecl()), val);
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
+ ConstraintSystem& constraints,
+ RefPtr<Type> fst,
+ RefPtr<Type> snd)
+ {
+ if (auto fstDeclRefType = as<DeclRefType>(fst))
+ {
+ auto fstDeclRef = fstDeclRefType->declRef;
+
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
+ return TryUnifyTypeParam(constraints, typeParamDecl, snd);
+
+ if (auto sndDeclRefType = as<DeclRefType>(snd))
+ {
+ auto sndDeclRef = sndDeclRefType->declRef;
+
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
+ return TryUnifyTypeParam(constraints, typeParamDecl, fst);
+
+ // can't be unified if they refer to different declarations.
+ if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false;
+
+ // next we need to unify the substitutions applied
+ // to each declaration reference.
+ if (!tryUnifySubstitutions(
+ constraints,
+ fstDeclRef.substitutions.substitutions,
+ sndDeclRef.substitutions.substitutions))
+ {
+ return false;
+ }
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ bool SemanticsVisitor::TryUnifyTypes(
+ ConstraintSystem& constraints,
+ RefPtr<Type> fst,
+ RefPtr<Type> snd)
+ {
+ if (fst->Equals(snd)) return true;
+
+ // An error type can unify with anything, just so we avoid cascading errors.
+
+ if (auto fstErrorType = as<ErrorType>(fst))
+ return true;
+
+ if (auto sndErrorType = as<ErrorType>(snd))
+ return true;
+
+ // A generic parameter type can unify with anything.
+ // TODO: there actually needs to be some kind of "occurs check" sort
+ // of thing here...
+
+ if (auto fstDeclRefType = as<DeclRefType>(fst))
+ {
+ auto fstDeclRef = fstDeclRefType->declRef;
+
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
+ {
+ if(typeParamDecl->ParentDecl == constraints.genericDecl )
+ return TryUnifyTypeParam(constraints, typeParamDecl, snd);
+ }
+ }
+
+ if (auto sndDeclRefType = as<DeclRefType>(snd))
+ {
+ auto sndDeclRef = sndDeclRefType->declRef;
+
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
+ {
+ if(typeParamDecl->ParentDecl == constraints.genericDecl )
+ return TryUnifyTypeParam(constraints, typeParamDecl, fst);
+ }
+ }
+
+ // If we can unify the types structurally, then we are golden
+ if(TryUnifyTypesByStructuralMatch(constraints, fst, snd))
+ return true;
+
+ // Now we need to consider cases where coercion might
+ // need to be applied. For now we can try to do this
+ // in a completely ad hoc fashion, but eventually we'd
+ // want to do it more formally.
+
+ if(auto fstVectorType = as<VectorExpressionType>(fst))
+ {
+ if(auto sndScalarType = as<BasicExpressionType>(snd))
+ {
+ return TryUnifyTypes(
+ constraints,
+ fstVectorType->elementType,
+ sndScalarType);
+ }
+ }
+
+ if(auto fstScalarType = as<BasicExpressionType>(fst))
+ {
+ if(auto sndVectorType = as<VectorExpressionType>(snd))
+ {
+ return TryUnifyTypes(
+ constraints,
+ fstScalarType,
+ sndVectorType->elementType);
+ }
+ }
+
+ // TODO: the same thing for vectors...
+
+ return false;
+ }
+
+
+}