summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-constraint.cpp
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-05-29 16:36:49 -0700
committerGitHub <noreply@github.com>2025-05-29 16:36:49 -0700
commit984d7f22f8a0909dc870c65bb927094c54f55402 (patch)
treeab255bf44e14f6cbaa09522f90b12464f1c6a339 /source/slang/slang-check-constraint.cpp
parentf4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff)
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures: - CoopMat<...>::MapElement(functype(...)); - CoopMat<...>::MapElement(capturing-lambda); - CoopMat<...>::MapElement(not-capturing-lambda); - Tuple<CoopMat<...>,...>::MapElement(functype(...)); - Tuple<CoopMat<...>,...>::MapElement(capturing-lambda); - Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
-rw-r--r--source/slang/slang-check-constraint.cpp243
1 files changed, 242 insertions, 1 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 6f9191135..6259e2fb8 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -1210,6 +1210,191 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(
constraints.constraints.add(c);
}
+struct IndexSpan
+{
+ Index index;
+ Index count;
+
+ IndexSpan()
+ : index(0), count(0)
+ {
+ }
+ IndexSpan(Index idx, Index cnt)
+ : index(idx), count(cnt)
+ {
+ }
+};
+
+struct IndexSpanPair
+{
+ IndexSpan first;
+ IndexSpan second;
+
+ IndexSpanPair() {}
+ IndexSpanPair(IndexSpan f, IndexSpan s)
+ : first(f), second(s)
+ {
+ }
+};
+
+// Helper function to unwrap a type and count expandable types
+static void unwrapTypeAndCountExpandable(
+ Type* type,
+ ShortList<Type*>& outTypes,
+ int& outExpandableCount)
+{
+ if (auto concretePack = as<ConcreteTypePack>(type))
+ {
+ for (Index i = 0; i < concretePack->getTypeCount(); ++i)
+ {
+ outTypes.add(concretePack->getElementType(i));
+ if (isAbstractTypePack(concretePack->getElementType(i)))
+ outExpandableCount++;
+ }
+ }
+ else if (isAbstractTypePack(type))
+ {
+ outTypes.add(type);
+ outExpandableCount++;
+ }
+}
+
+// Helper function to map type arguments between two types, handling expandable types
+static bool matchTypeArgMapping(
+ Type* firstType,
+ Type* secondType,
+ ShortList<Type*>& outFlattenedFirst,
+ ShortList<Type*>& outFlattenedSecond,
+ ShortList<IndexSpanPair>& outMapping)
+{
+ // Unwrap and flatten the types
+ ShortList<Type*>& firstTypes = outFlattenedFirst;
+ ShortList<Type*>& secondTypes = outFlattenedSecond;
+
+ // Count expandable types as we unwrap
+ int firstExpandableCount = 0;
+ int secondExpandableCount = 0;
+
+ // Unwrap both types using the helper function
+ unwrapTypeAndCountExpandable(firstType, firstTypes, firstExpandableCount);
+ unwrapTypeAndCountExpandable(secondType, secondTypes, secondExpandableCount);
+
+ // We need to figure out which side should be expanding.
+ // Consider the following cases,
+ //
+ // left = [ expand, expand ]
+ // right = [ int, float, expand ]
+ // when one side has more non-expandable types, the other side should expand to match it.
+ // in this case, "left" should expand to cover "int" and "float".
+ //
+ // left = [ int, float, expand, expand ]
+ // right = [ int, float, expand ]
+ // when the number of the non-expandable types are same, we want to expand side that has
+ // fewer expandable types. In this case, "right" should expand to cover the first "expand".
+ //
+ // left = ConcreteTypePack(ExpandType, ExpandType)
+ // right = ConcreteTypePack(int, bool, float, double).
+ // In this case, we shouldn't be mapping the first ExpandType to int and the second
+ // ExpandType to bool, float, double. Instead, they should evenly divide the second type
+ // pack, so we map first ExpandType with int, bool, and second ExpandType to float, double.
+ //
+ int firstCount = (int)firstTypes.getCount();
+ int secondCount = (int)secondTypes.getCount();
+ int countDifference =
+ (firstCount - firstExpandableCount) - (secondCount - secondExpandableCount);
+
+ bool shouldExpandFirst =
+ (firstExpandableCount > 0) &&
+ ((countDifference < 0) ||
+ (countDifference == 0 && firstExpandableCount < secondExpandableCount));
+
+ bool shouldExpandSecond =
+ (secondExpandableCount > 0) &&
+ ((countDifference > 0) ||
+ (countDifference == 0 && firstExpandableCount > secondExpandableCount));
+
+ // We need to figure out how much types should match per each expandable type.
+ int typesPerExpand = 0;
+ if (shouldExpandSecond)
+ {
+ // More types on first, need to expand second
+ int countToMatch = countDifference + firstExpandableCount;
+ SLANG_ASSERT(secondExpandableCount != 0);
+ if (countToMatch % secondExpandableCount != 0)
+ return false;
+ typesPerExpand = countToMatch / secondExpandableCount;
+ }
+ else if (shouldExpandFirst)
+ {
+ // More types on second, need to expand first
+ int countToMatch = -countDifference + secondExpandableCount;
+ SLANG_ASSERT(firstExpandableCount != 0);
+ if (countToMatch % firstExpandableCount != 0)
+ return false;
+ typesPerExpand = countToMatch / firstExpandableCount;
+ }
+ // If countDifference == 0, no expansion needed
+
+ // Generate the mapping
+ Index firstIndex = 0;
+ Index secondIndex = 0;
+
+ while (firstIndex < firstCount && secondIndex < secondCount)
+ {
+ IndexSpanPair mapping;
+
+ // Determine spans based on expandable types and count difference
+ if (shouldExpandFirst)
+ {
+ // Expanding first to match second
+ if (isAbstractTypePack(firstTypes[firstIndex]))
+ {
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, typesPerExpand);
+ secondIndex += typesPerExpand;
+ }
+ else
+ {
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, 1);
+ secondIndex++;
+ }
+ firstIndex++;
+ }
+ else if (shouldExpandSecond)
+ {
+ // Expanding second to match first
+ if (isAbstractTypePack(secondTypes[secondIndex]))
+ {
+ mapping.first = IndexSpan(firstIndex, typesPerExpand);
+ mapping.second = IndexSpan(secondIndex, 1);
+ firstIndex += typesPerExpand;
+ }
+ else
+ {
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, 1);
+ firstIndex++;
+ }
+ secondIndex++;
+ }
+ else
+ {
+ // No expansion needed
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, 1);
+ firstIndex++;
+ secondIndex++;
+ }
+
+ outMapping.add(mapping);
+ }
+
+ SLANG_ASSERT(!shouldExpandSecond || firstIndex == firstCount);
+ SLANG_ASSERT(!shouldExpandFirst || secondIndex == secondCount);
+ return true;
+}
+
bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
@@ -1244,7 +1429,63 @@ bool SemanticsVisitor::TryUnifyTypes(
return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd);
}
- // If one of the types is a type pack, we need to recursively unify the element types.
+ // Unwrap ConcreteTypePack and call TryUnifyTypes recursively.
+ ShortList<IndexSpanPair> typeMapping;
+ ShortList<Type*> flattenedFirst;
+ ShortList<Type*> flattenedSecond;
+ if (matchTypeArgMapping(fst, snd, flattenedFirst, flattenedSecond, typeMapping) &&
+ typeMapping.getCount() > 1)
+ {
+ // Apply unification based on the mapping
+ for (const auto& mapping : typeMapping)
+ {
+ // Make sure it is one of three cases: 1:1, 1:N or N:1
+ SLANG_ASSERT(mapping.first.count > 0 && mapping.second.count > 0);
+ SLANG_ASSERT(mapping.first.count == 1 || mapping.second.count == 1);
+
+ // Helper function to create QualType from mapping span
+ auto mayPackAndGetQualType = [this](
+ const IndexSpan& span,
+ ShortList<Type*>& typeList,
+ bool isLeftValue,
+ Type* otherType) -> QualType
+ {
+ // When the `otherType` is GenericTypePackParamDecl or ExpandType,
+ // we need to create a ConcreteTypePack so that we can handle them
+ // recursively.
+ if (isDeclRefTypeOf<GenericTypePackParamDecl>(otherType) ||
+ as<ExpandType>(otherType))
+ {
+ // Multiple types: create ConcreteTypePack
+ auto typesView = makeArrayView(&typeList[span.index], span.count);
+ auto typePack = m_astBuilder->getTypePack(typesView);
+ return QualType(typePack, isLeftValue);
+ }
+
+ SLANG_ASSERT(span.count == 1);
+ return QualType(typeList[span.index], isLeftValue);
+ };
+
+ // Get the types directly from the mapping
+ QualType firstArg = mayPackAndGetQualType(
+ mapping.first,
+ flattenedFirst,
+ fst.isLeftValue,
+ flattenedSecond[mapping.second.index]);
+ QualType secondArg = mayPackAndGetQualType(
+ mapping.second,
+ flattenedSecond,
+ snd.isLeftValue,
+ flattenedFirst[mapping.first.index]);
+
+ // Perform the unification
+ if (!TryUnifyTypes(constraints, unifyCtx, firstArg, secondArg))
+ return false;
+ }
+
+ return true;
+ }
+
if (auto fstTypePack = as<ConcreteTypePack>(fst))
{
if (auto sndTypePack = as<ConcreteTypePack>(snd))