summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ast-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-14 18:41:48 -0700
committerGitHub <noreply@github.com>2024-08-14 18:41:48 -0700
commit071f1b6062b459928ebfd6f2f60a8d6ad021112b (patch)
tree2ba65eb40f39701db6fc775a9258ec8079d161a0 /source/slang/slang-ast-type.cpp
parent35a3d32c87f079749f6b100d01b289c3da02d7d6 (diff)
Variadic Generics Part 1: parsing and type checking. (#4833)
Diffstat (limited to 'source/slang/slang-ast-type.cpp')
-rw-r--r--source/slang/slang-ast-type.cpp206
1 files changed, 181 insertions, 25 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 44585ee30..1c9f68a48 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -9,6 +9,22 @@
#include "slang-generated-ast-macro.h"
namespace Slang {
+bool isAbstractTypePack(Type* type)
+{
+ if (as<ExpandType>(type))
+ return true;
+ if (isDeclRefTypeOf<GenericTypePackParamDecl>(type))
+ return true;
+ return false;
+}
+
+bool isTypePack(Type* type)
+{
+ if (as<ConcreteTypePack>(type))
+ return true;
+ return isAbstractTypePack(type);
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Type* Type::_createCanonicalTypeOverride()
@@ -119,7 +135,7 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
return lookupDeclRef->getLookupSource();
}
}
- else if (as<GenericTypeParamDecl>(substDeclRef.getDecl()) || as<GenericValueParamDecl>(substDeclRef.getDecl()))
+ else if (as<GenericTypeParamDeclBase>(substDeclRef.getDecl()) || as<GenericValueParamDecl>(substDeclRef.getDecl()))
{
auto resultVal = maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff);
if (resultVal)
@@ -259,6 +275,26 @@ Type* MatrixExpressionType::getRowType()
return rowType;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Type* TupleType::getMember(Index i) const
+{
+ if (auto typePack = as<ConcreteTypePack>(_getGenericTypeArg(getDeclRefBase(), 0)))
+ return typePack->getElementType(i);
+ return nullptr;
+}
+
+Index TupleType::getMemberCount() const
+{
+ if (auto typePack = as<ConcreteTypePack>(_getGenericTypeArg(getDeclRefBase(), 0)))
+ return typePack->getTypeCount();
+ return 0;
+}
+
+Type* TupleType::getTypePack() const
+{
+ return as<Type>(_getGenericTypeArg(getDeclRefBase(), 0));
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Type* ArrayExpressionType::getElementType()
@@ -520,47 +556,167 @@ Type* FuncType::_createCanonicalTypeOverride()
return canType;
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-void TupleType::_toTextOverride(StringBuilder& out)
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! EachType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+void EachType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("(");
- for (Index pp = 0; pp < getOperandCount(); ++pp)
+ out << "each ";
+ if (getElementType())
{
- if (pp != 0)
- out << toSlice(", ");
- out << getOperand(pp);
+ getElementType()->toText(out);
}
- out << toSlice(")");
+ else
+ {
+ out << "<null>";
+ }
+}
+
+Type* EachType::_createCanonicalTypeOverride()
+{
+ return this;
}
-Val* TupleType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+Val* EachType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
+ auto substElementType = as<Type>(getElementType()->substituteImpl(astBuilder, subst, &diff));
+ if (!diff)
+ return this;
+ if (auto typePack = as<ConcreteTypePack>(substElementType))
+ {
+ if (subst.packExpansionIndex >= 0 && subst.packExpansionIndex < typePack->getTypeCount())
+ {
+ (*ioDiff)++;
+ return typePack->getElementType(subst.packExpansionIndex);
+ }
+ }
+ else if (auto expandType = as<ExpandType>(substElementType))
+ {
+ if (auto innerEach = as<EachType>(expandType->getPatternType()))
+ {
+ (*ioDiff)++;
+ return innerEach;
+ }
+ }
+ (*ioDiff)++;
+ return astBuilder->getEachType(substElementType);
+}
- // just recurse into the members
- List<Type*> substMemberTypes;
- for (Index m = 0; m < getMemberCount(); m++)
- substMemberTypes.add(as<Type>(getMember(m)->substituteImpl(astBuilder, subst, &diff)));
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExpandType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+void ExpandType::_toTextOverride(StringBuilder& out)
+{
+ out << "expand ";
+ getPatternType()->toText(out);
+}
- // early exit for no change...
- if (!diff)
+Type* ExpandType::_createCanonicalTypeOverride()
+{
+ auto canonicalPatternType = getPatternType()->getCanonicalType();
+ if (canonicalPatternType == getPatternType())
return this;
+ ShortList<Type*> capturedPacks;
+ for (Index i = 0; i < getCapturedTypePackCount(); i++)
+ {
+ capturedPacks.add(getCapturedTypePack(i));
+ }
+ return getCurrentASTBuilder()->getExpandType(canonicalPatternType, capturedPacks.getArrayView().arrayView);
+}
- (*ioDiff)++;
- return astBuilder->getTupleType(substMemberTypes);
+Val* ExpandType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+
+ // Given ExpandType(PatternType, CapturedTypePackParams), we first need to know
+ // if all captured GenericTypePackParams can be substituted into concrete type packs.
+ // We can't expand the ExpandType into a concrete type pack, if any of the captured type
+ // pack parameters aren't concrete themselves.
+ //
+ ShortList<Type*> capturedPacks;
+ ShortList<ConcreteTypePack*> concreteTypePacks;
+ for (Index i = 0; i < getCapturedTypePackCount(); i++)
+ {
+ auto substCapturedTypePack = getCapturedTypePack(i)->substituteImpl(astBuilder, subst, &diff);
+ if (auto expandType = as<ExpandType>(substCapturedTypePack))
+ {
+ for (Index j = 0; j < expandType->getCapturedTypePackCount(); j++)
+ capturedPacks.add(expandType->getCapturedTypePack(j));
+ }
+ else
+ {
+ capturedPacks.add(as<Type>(substCapturedTypePack));
+ if (auto pack = as<ConcreteTypePack>(capturedPacks.getLast()))
+ {
+ concreteTypePacks.add(pack);
+ }
+ }
+ }
+
+ if (!diff || concreteTypePacks.getCount() != capturedPacks.getCount())
+ {
+ auto substPatternType = getPatternType()->substituteImpl(astBuilder, subst, &diff);
+ if (!diff)
+ return this;
+
+ // If some part of pattern type or captured type can be substituted into something else,
+ // but not all of the captured types are resolved to concrete type packs yet, we will just
+ // create a new ExpandType with the substituted pattern/capture types, instead of actually
+ // expanding into a concrete type pack.
+ (*ioDiff)++;
+ return astBuilder->getExpandType(as<Type>(substPatternType), capturedPacks.getArrayView().arrayView);
+ }
+ else
+ {
+ // All type pack parameters are now concrete type packs, so we can construct a concrete type pack
+ // by substituting the pattern type with each element of the captured type pack.
+ ShortList<Type*> expandedTypes;
+ SLANG_ASSERT(capturedPacks.getCount() != 0);
+
+ for (Index i = 0; i < concreteTypePacks[0]->getTypeCount(); i++)
+ {
+ subst.packExpansionIndex = i;
+ auto substElementType = getPatternType()->substituteImpl(astBuilder, subst, &diff);
+ expandedTypes.add(as<Type>(substElementType));
+ }
+ if (!diff)
+ return this;
+ (*ioDiff)++;
+ return astBuilder->getTypePack(expandedTypes.getArrayView().arrayView);
+ }
}
-Type* TupleType::_createCanonicalTypeOverride()
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConcreteTypePack !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+void ConcreteTypePack::_toTextOverride(StringBuilder& out)
{
- // member types
- List<Type*> canMemberTypes;
- for (Index m = 0; m < getMemberCount(); m++)
+ for (Index i = 0; i < getTypeCount(); i++)
{
- canMemberTypes.add(getMember(m)->getCanonicalType());
+ if (i != 0)
+ out << ", ";
+ getElementType(i)->toText(out);
}
+}
- return getCurrentASTBuilder()->getTupleType(canMemberTypes);
+Type* ConcreteTypePack::_createCanonicalTypeOverride()
+{
+ ShortList<Type*> canonicalElementTypes;
+ for (Index i = 0; i < getTypeCount(); i++)
+ {
+ canonicalElementTypes.add(getElementType(i)->getCanonicalType());
+ }
+ return getCurrentASTBuilder()->getTypePack(canonicalElementTypes.getArrayView().arrayView);
+}
+
+Val* ConcreteTypePack::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ ShortList<Type*> substElementTypes;
+ for (Index i = 0; i < getTypeCount(); i++)
+ {
+ auto substType = as<Type>(getElementType(i)->substituteImpl(astBuilder, subst, &diff));
+ substElementTypes.add(substType);
+ }
+ if (!diff)
+ return this;
+ (*ioDiff)++;
+ return getCurrentASTBuilder()->getTypePack(substElementTypes.getArrayView().arrayView);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!