diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-05-29 16:36:49 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-29 16:36:49 -0700 |
| commit | 984d7f22f8a0909dc870c65bb927094c54f55402 (patch) | |
| tree | ab255bf44e14f6cbaa09522f90b12464f1c6a339 /source/slang/slang-check-constraint.cpp | |
| parent | f4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff) | |
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures:
- CoopMat<...>::MapElement(functype(...));
- CoopMat<...>::MapElement(capturing-lambda);
- CoopMat<...>::MapElement(not-capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(functype(...));
- Tuple<CoopMat<...>,...>::MapElement(capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 243 |
1 files changed, 242 insertions, 1 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 6f9191135..6259e2fb8 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -1210,6 +1210,191 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam( constraints.constraints.add(c); } +struct IndexSpan +{ + Index index; + Index count; + + IndexSpan() + : index(0), count(0) + { + } + IndexSpan(Index idx, Index cnt) + : index(idx), count(cnt) + { + } +}; + +struct IndexSpanPair +{ + IndexSpan first; + IndexSpan second; + + IndexSpanPair() {} + IndexSpanPair(IndexSpan f, IndexSpan s) + : first(f), second(s) + { + } +}; + +// Helper function to unwrap a type and count expandable types +static void unwrapTypeAndCountExpandable( + Type* type, + ShortList<Type*>& outTypes, + int& outExpandableCount) +{ + if (auto concretePack = as<ConcreteTypePack>(type)) + { + for (Index i = 0; i < concretePack->getTypeCount(); ++i) + { + outTypes.add(concretePack->getElementType(i)); + if (isAbstractTypePack(concretePack->getElementType(i))) + outExpandableCount++; + } + } + else if (isAbstractTypePack(type)) + { + outTypes.add(type); + outExpandableCount++; + } +} + +// Helper function to map type arguments between two types, handling expandable types +static bool matchTypeArgMapping( + Type* firstType, + Type* secondType, + ShortList<Type*>& outFlattenedFirst, + ShortList<Type*>& outFlattenedSecond, + ShortList<IndexSpanPair>& outMapping) +{ + // Unwrap and flatten the types + ShortList<Type*>& firstTypes = outFlattenedFirst; + ShortList<Type*>& secondTypes = outFlattenedSecond; + + // Count expandable types as we unwrap + int firstExpandableCount = 0; + int secondExpandableCount = 0; + + // Unwrap both types using the helper function + unwrapTypeAndCountExpandable(firstType, firstTypes, firstExpandableCount); + unwrapTypeAndCountExpandable(secondType, secondTypes, secondExpandableCount); + + // We need to figure out which side should be expanding. + // Consider the following cases, + // + // left = [ expand, expand ] + // right = [ int, float, expand ] + // when one side has more non-expandable types, the other side should expand to match it. + // in this case, "left" should expand to cover "int" and "float". + // + // left = [ int, float, expand, expand ] + // right = [ int, float, expand ] + // when the number of the non-expandable types are same, we want to expand side that has + // fewer expandable types. In this case, "right" should expand to cover the first "expand". + // + // left = ConcreteTypePack(ExpandType, ExpandType) + // right = ConcreteTypePack(int, bool, float, double). + // In this case, we shouldn't be mapping the first ExpandType to int and the second + // ExpandType to bool, float, double. Instead, they should evenly divide the second type + // pack, so we map first ExpandType with int, bool, and second ExpandType to float, double. + // + int firstCount = (int)firstTypes.getCount(); + int secondCount = (int)secondTypes.getCount(); + int countDifference = + (firstCount - firstExpandableCount) - (secondCount - secondExpandableCount); + + bool shouldExpandFirst = + (firstExpandableCount > 0) && + ((countDifference < 0) || + (countDifference == 0 && firstExpandableCount < secondExpandableCount)); + + bool shouldExpandSecond = + (secondExpandableCount > 0) && + ((countDifference > 0) || + (countDifference == 0 && firstExpandableCount > secondExpandableCount)); + + // We need to figure out how much types should match per each expandable type. + int typesPerExpand = 0; + if (shouldExpandSecond) + { + // More types on first, need to expand second + int countToMatch = countDifference + firstExpandableCount; + SLANG_ASSERT(secondExpandableCount != 0); + if (countToMatch % secondExpandableCount != 0) + return false; + typesPerExpand = countToMatch / secondExpandableCount; + } + else if (shouldExpandFirst) + { + // More types on second, need to expand first + int countToMatch = -countDifference + secondExpandableCount; + SLANG_ASSERT(firstExpandableCount != 0); + if (countToMatch % firstExpandableCount != 0) + return false; + typesPerExpand = countToMatch / firstExpandableCount; + } + // If countDifference == 0, no expansion needed + + // Generate the mapping + Index firstIndex = 0; + Index secondIndex = 0; + + while (firstIndex < firstCount && secondIndex < secondCount) + { + IndexSpanPair mapping; + + // Determine spans based on expandable types and count difference + if (shouldExpandFirst) + { + // Expanding first to match second + if (isAbstractTypePack(firstTypes[firstIndex])) + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, typesPerExpand); + secondIndex += typesPerExpand; + } + else + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + secondIndex++; + } + firstIndex++; + } + else if (shouldExpandSecond) + { + // Expanding second to match first + if (isAbstractTypePack(secondTypes[secondIndex])) + { + mapping.first = IndexSpan(firstIndex, typesPerExpand); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex += typesPerExpand; + } + else + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex++; + } + secondIndex++; + } + else + { + // No expansion needed + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex++; + secondIndex++; + } + + outMapping.add(mapping); + } + + SLANG_ASSERT(!shouldExpandSecond || firstIndex == firstCount); + SLANG_ASSERT(!shouldExpandFirst || secondIndex == secondCount); + return true; +} + bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, ValUnificationContext unifyCtx, @@ -1244,7 +1429,63 @@ bool SemanticsVisitor::TryUnifyTypes( return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd); } - // If one of the types is a type pack, we need to recursively unify the element types. + // Unwrap ConcreteTypePack and call TryUnifyTypes recursively. + ShortList<IndexSpanPair> typeMapping; + ShortList<Type*> flattenedFirst; + ShortList<Type*> flattenedSecond; + if (matchTypeArgMapping(fst, snd, flattenedFirst, flattenedSecond, typeMapping) && + typeMapping.getCount() > 1) + { + // Apply unification based on the mapping + for (const auto& mapping : typeMapping) + { + // Make sure it is one of three cases: 1:1, 1:N or N:1 + SLANG_ASSERT(mapping.first.count > 0 && mapping.second.count > 0); + SLANG_ASSERT(mapping.first.count == 1 || mapping.second.count == 1); + + // Helper function to create QualType from mapping span + auto mayPackAndGetQualType = [this]( + const IndexSpan& span, + ShortList<Type*>& typeList, + bool isLeftValue, + Type* otherType) -> QualType + { + // When the `otherType` is GenericTypePackParamDecl or ExpandType, + // we need to create a ConcreteTypePack so that we can handle them + // recursively. + if (isDeclRefTypeOf<GenericTypePackParamDecl>(otherType) || + as<ExpandType>(otherType)) + { + // Multiple types: create ConcreteTypePack + auto typesView = makeArrayView(&typeList[span.index], span.count); + auto typePack = m_astBuilder->getTypePack(typesView); + return QualType(typePack, isLeftValue); + } + + SLANG_ASSERT(span.count == 1); + return QualType(typeList[span.index], isLeftValue); + }; + + // Get the types directly from the mapping + QualType firstArg = mayPackAndGetQualType( + mapping.first, + flattenedFirst, + fst.isLeftValue, + flattenedSecond[mapping.second.index]); + QualType secondArg = mayPackAndGetQualType( + mapping.second, + flattenedSecond, + snd.isLeftValue, + flattenedFirst[mapping.first.index]); + + // Perform the unification + if (!TryUnifyTypes(constraints, unifyCtx, firstArg, secondArg)) + return false; + } + + return true; + } + if (auto fstTypePack = as<ConcreteTypePack>(fst)) { if (auto sndTypePack = as<ConcreteTypePack>(snd)) |
