summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-04 15:47:39 -0700
committerGitHub <noreply@github.com>2023-08-04 15:47:39 -0700
commita2d90fb275962da84611160f8ddd74d934a68dbd (patch)
tree066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-ast-type.cpp
parent17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff)
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-type.cpp')
-rw-r--r--source/slang/slang-ast-type.cpp827
1 files changed, 198 insertions, 629 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index ee5d1d40e..13133a7f8 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -1,49 +1,19 @@
// slang-ast-type.cpp
#include "slang-ast-builder.h"
+#include "slang-ast-modifier.h"
#include <assert.h>
#include <typeinfo>
#include "slang-syntax.h"
#include "slang-generated-ast-macro.h"
-
namespace Slang {
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-Type* Type::createCanonicalType()
-{
- SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ())
-}
-
-bool Type::equals(Type* type)
-{
- return getCanonicalType()->equalsImpl(type->getCanonicalType());
-}
-
-bool Type::equalsImpl(Type* type)
-{
- SLANG_AST_NODE_VIRTUAL_CALL(Type, equalsImpl, (type))
-}
-
-bool Type::_equalsImplOverride(Type* type)
-{
- SLANG_UNUSED(type)
- SLANG_UNEXPECTED("Type::_equalsImplOverride not overridden");
- //return false;
-}
-
Type* Type::_createCanonicalTypeOverride()
{
- SLANG_UNEXPECTED("Type::_createCanonicalTypeOverride not overridden");
- //return Type*();
-}
-
-bool Type::_equalsValOverride(Val* val)
-{
- if (auto type = dynamicCast<Type>(val))
- return const_cast<Type*>(this)->equals(type);
- return false;
+ return as<Type>(defaultResolveImpl());
}
Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
@@ -61,20 +31,6 @@ Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst
return canSubst;
}
-Type* Type::getCanonicalType()
-{
- Type* et = const_cast<Type*>(this);
- if (!et->canonicalType)
- {
- // TODO(tfoley): worry about thread safety here?
- auto canType = et->createCanonicalType();
- et->canonicalType = canType;
- if (!et->canonicalType)
- return getASTBuilder()->getErrorType();
- }
- return et->canonicalType;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! OverloadGroupType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void OverloadGroupType::_toTextOverride(StringBuilder& out)
@@ -82,21 +38,11 @@ void OverloadGroupType::_toTextOverride(StringBuilder& out)
out << toSlice("overload group");
}
-bool OverloadGroupType::_equalsImplOverride(Type * /*type*/)
-{
- return false;
-}
-
Type* OverloadGroupType::_createCanonicalTypeOverride()
{
return this;
}
-HashCode OverloadGroupType::_getHashCodeOverride()
-{
- return (HashCode)(size_t(this));
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! InitializerListType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void InitializerListType::_toTextOverride(StringBuilder& out)
@@ -104,21 +50,11 @@ void InitializerListType::_toTextOverride(StringBuilder& out)
out << toSlice("initializer list");
}
-bool InitializerListType::_equalsImplOverride(Type * /*type*/)
-{
- return false;
-}
-
Type* InitializerListType::_createCanonicalTypeOverride()
{
return this;
}
-HashCode InitializerListType::_getHashCodeOverride()
-{
- return (HashCode)(size_t(this));
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ErrorType::_toTextOverride(StringBuilder& out)
@@ -126,11 +62,6 @@ void ErrorType::_toTextOverride(StringBuilder& out)
out << toSlice("error");
}
-bool ErrorType::_equalsImplOverride(Type* type)
-{
- return as<ErrorType>(type);
-}
-
Type* ErrorType::_createCanonicalTypeOverride()
{
return this;
@@ -141,56 +72,21 @@ Val* ErrorType::_substituteImplOverride(ASTBuilder* /* astBuilder */, Substituti
return this;
}
-HashCode ErrorType::_getHashCodeOverride()
-{
- return HashCode(size_t(this));
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void BottomType::_toTextOverride(StringBuilder& out) { out << toSlice("never"); }
-bool BottomType::_equalsImplOverride(Type* type)
-{
- return as<BottomType>(type);
-}
-
-Type* BottomType::_createCanonicalTypeOverride() { return this; }
-
Val* BottomType::_substituteImplOverride(
ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/)
{
return this;
}
-HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); }
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void DeclRefType::_toTextOverride(StringBuilder& out)
{
- out << declRef;
-}
-
-HashCode DeclRefType::_getHashCodeOverride()
-{
- return (declRef.getHashCode() * 16777619) ^ (HashCode)(typeid(this).hash_code());
-}
-
-bool DeclRefType::_equalsImplOverride(Type * type)
-{
- if (auto declRefType = as<DeclRefType>(type))
- {
- return declRef.equals(declRefType->declRef);
- }
- return false;
-}
-
-Type* DeclRefType::_createCanonicalTypeOverride()
-{
- // A declaration reference is already canonical
- declRef.substitute(this->getASTBuilder(), this);
- return this;
+ out << getDeclRef();
}
Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff);
@@ -199,26 +95,47 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
{
if (!subst) return this;
- // the case we especially care about is when this type references a declaration
- // of a generic parameter, since that is what we might be substituting...
- if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl()))
+ int diff = 0;
+ DeclRef<Decl> substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+
+ // If this declref type is a direct reference to ThisType or a Generic parameter,
+ // and `subst` provides an argument for it, then we should just return that argument.
+ //
+ if (as<DirectDeclRef>(substDeclRef.declRefBase))
{
- if (auto result = maybeSubstituteGenericParam(this, genericTypeParamDecl, subst, ioDiff))
+ if (auto thisDecl = as<ThisTypeDecl>(substDeclRef.getDecl()))
+ {
+ auto lookupDeclRef = subst.findLookupDeclRef();
+ if (lookupDeclRef && lookupDeclRef->getSupDecl() == substDeclRef.getDecl()->parentDecl)
+ {
+ (*ioDiff)++;
+ return lookupDeclRef->getLookupSource();
+ }
+ }
+ else if (as<GenericTypeParamDecl>(substDeclRef.getDecl()) || as<GenericValueParamDecl>(substDeclRef.getDecl()))
{
- if (auto substDeclRefType = as<DeclRefType>(result))
+ auto resultVal = maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff);
+ if (resultVal)
{
- // After generic substitution, we may be able to further simplify
- // by looking up the actual type of an associated type.
- if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(
- astBuilder, substDeclRefType->declRef))
- return satisfyingVal;
+ (*ioDiff)++;
+ return resultVal;
}
- return result;
}
}
- int diff = 0;
- DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
+ // If this type is a reference to an associated type declaration,
+ // and the substitutions provide a "this type" substitution for
+ // the outer interface, then try to replace the type with the
+ // actual value of the associated type for the given implementation.
+ //
+ if (auto satisfyingVal = substDeclRef.declRefBase->resolve())
+ {
+ if (satisfyingVal != getDeclRef())
+ {
+ *ioDiff += 1;
+ return DeclRefType::create(astBuilder, substDeclRef);
+ }
+ }
if (!diff)
return this;
@@ -226,14 +143,6 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
// Make sure to record the difference!
*ioDiff += diff;
- // If this type is a reference to an associated type declaration,
- // and the substitutions provide a "this type" substitution for
- // the outer interface, then try to replace the type with the
- // actual value of the associated type for the given implementation.
- //
- if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef))
- return satisfyingVal;
-
// Re-construct the type in case we are using a specialized sub-class
return DeclRefType::create(astBuilder, substDeclRef);
}
@@ -254,40 +163,52 @@ BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride()
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool BasicExpressionType::_equalsImplOverride(Type * type)
+BasicExpressionType* BasicExpressionType::_getScalarTypeOverride()
{
- auto basicType = as<BasicExpressionType>(type);
- return basicType && basicType->baseType == this->baseType;
+ return this;
}
-Type* BasicExpressionType::_createCanonicalTypeOverride()
+static Val* _getGenericTypeArg(DeclRefBase* declRef, Index i)
{
- // A basic type is already canonical, in our setup
- return this;
+ auto args = findInnerMostGenericArgs(SubstitutionSet(declRef));
+ if (args.getCount() <= i)
+ return nullptr;
+
+ return args[i];
}
-BasicExpressionType* BasicExpressionType::_getScalarTypeOverride()
+static Val* _getGenericTypeArg(DeclRefType* declRefType, Index i)
{
- return this;
+ return _getGenericTypeArg(declRefType->getDeclRefBase(), i);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Type* TensorViewType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Type* VectorExpressionType::getElementType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
+IntVal* VectorExpressionType::getElementCount()
+{
+ return as<IntVal>(_getGenericTypeArg(this, 1));
+}
+
void VectorExpressionType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("vector<") << elementType << toSlice(",") << elementCount << toSlice(">");
+ out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount() << toSlice(">");
}
BasicExpressionType* VectorExpressionType::_getScalarTypeOverride()
{
- return as<BasicExpressionType>(elementType);
+ return as<BasicExpressionType>(getElementType());
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MatrixExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -304,24 +225,24 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride()
Type* MatrixExpressionType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
IntVal* MatrixExpressionType::getRowCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]);
+ return as<IntVal>(_getGenericTypeArg(this, 1));
}
IntVal* MatrixExpressionType::getColumnCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[2]);
+ return as<IntVal>(_getGenericTypeArg(this, 2));
}
Type* MatrixExpressionType::getRowType()
{
if (!rowType)
{
- rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount());
+ rowType = getCurrentASTBuilder()->getVectorType(getElementType(), getColumnCount());
}
return rowType;
}
@@ -330,12 +251,12 @@ Type* MatrixExpressionType::getRowType()
Type* ArrayExpressionType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
IntVal* ArrayExpressionType::getElementCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]);
+ return as<IntVal>(_getGenericTypeArg(this, 1));
}
void ArrayExpressionType::_toTextOverride(StringBuilder& out)
@@ -353,7 +274,7 @@ bool ArrayExpressionType::isUnsized()
{
if (auto constSize = as<ConstantIntVal>(getElementCount()))
{
- if (constSize->value == kUnsizedArrayMagicLength)
+ if (constSize->getValue() == kUnsizedArrayMagicLength)
return true;
}
return false;
@@ -363,27 +284,12 @@ bool ArrayExpressionType::isUnsized()
void TypeType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("typeof(") << type << toSlice(")");
-}
-
-bool TypeType::_equalsImplOverride(Type * t)
-{
- if (auto typeType = as<TypeType>(t))
- {
- return t->equals(typeType->type);
- }
- return false;
+ out << toSlice("typeof(") << getType() << toSlice(")");
}
Type* TypeType::_createCanonicalTypeOverride()
{
- return getASTBuilder()->getTypeType(type->getCanonicalType());
-}
-
-HashCode TypeType::_getHashCodeOverride()
-{
- SLANG_UNEXPECTED("TypeType::_getHashCodeOverride should be unreachable");
- //return HashCode(0);
+ return getCurrentASTBuilder()->getTypeType(getType()->getCanonicalType());
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericDeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -394,20 +300,6 @@ void GenericDeclRefType::_toTextOverride(StringBuilder& out)
out << toSlice("<DeclRef<GenericDecl>>");
}
-bool GenericDeclRefType::_equalsImplOverride(Type * type)
-{
- if (auto genericDeclRefType = as<GenericDeclRefType>(type))
- {
- return declRef.equals(genericDeclRefType->declRef);
- }
- return false;
-}
-
-HashCode GenericDeclRefType::_getHashCodeOverride()
-{
- return declRef.getHashCode();
-}
-
Type* GenericDeclRefType::_createCanonicalTypeOverride()
{
return this;
@@ -417,21 +309,7 @@ Type* GenericDeclRefType::_createCanonicalTypeOverride()
void NamespaceType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("namespace ") << declRef;
-}
-
-bool NamespaceType::_equalsImplOverride(Type * type)
-{
- if (auto namespaceType = as<NamespaceType>(type))
- {
- return declRef.equals(namespaceType->declRef);
- }
- return false;
-}
-
-HashCode NamespaceType::_getHashCodeOverride()
-{
- return declRef.getHashCode();
+ out << toSlice("namespace ") << getDeclRef();
}
Type* NamespaceType::_createCanonicalTypeOverride()
@@ -441,7 +319,7 @@ Type* NamespaceType::_createCanonicalTypeOverride()
Type* DifferentialPairType::getPrimalType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
@@ -449,51 +327,35 @@ Type* DifferentialPairType::getPrimalType()
Type* PtrTypeBase::getValueType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
Type* OptionalType::getValueType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
+Type* NativeRefType::getValueType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void NamedExpressionType::_toTextOverride(StringBuilder& out)
{
- if (declRef.getDecl())
+ if (getDeclRef().getDecl())
{
- _printNestedDecl(declRef.getSubst(), declRef.getDecl(), out);
+ getDeclRef().declRefBase->toText(out);
}
}
-bool NamedExpressionType::_equalsImplOverride(Type * /*type*/)
-{
- SLANG_UNEXPECTED("NamedExpressionType::_equalsImplOverride should be unreachable");
- //return false;
-}
-
Type* NamedExpressionType::_createCanonicalTypeOverride()
{
- if (!innerType)
- innerType = getType(m_astBuilder, declRef);
- if (innerType)
- return innerType->getCanonicalType();
- return nullptr;
-}
-
-HashCode NamedExpressionType::_getHashCodeOverride()
-{
- // Type equality is based on comparing canonical types,
- // so the hash code for a type needs to come from the
- // canonical version of the type. This really means
- // that `Type::getHashCode()` should dispatch out to
- // something like `Type::getHashCodeImpl()` on the
- // canonical version of a type, but it is less invasive
- // for now (and hopefully equivalent) to just have any
- // named types automaticlaly route hash-code requests
- // to their canonical type.
- return getCanonicalType()->getHashCode();
+ auto canType = getType(getCurrentASTBuilder(), getDeclRef());
+ if (canType)
+ return canType->getCanonicalType();
+ return getCurrentASTBuilder()->getErrorType();
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -533,58 +395,27 @@ void FuncType::_toTextOverride(StringBuilder& out)
}
out << ") -> " << getResultType();
- if (!getErrorType()->equals(getASTBuilder()->getBottomType()))
+ if (!getErrorType()->equals(getCurrentASTBuilder()->getBottomType()))
{
out << " throws " << getErrorType();
}
}
-bool FuncType::_equalsImplOverride(Type * type)
-{
- if (auto funcType = as<FuncType>(type))
- {
- auto paramCount = getParamCount();
- auto otherParamCount = funcType->getParamCount();
- if (paramCount != otherParamCount)
- return false;
-
- for (Index pp = 0; pp < paramCount; ++pp)
- {
- auto paramType = getParamType(pp);
- auto otherParamType = funcType->getParamType(pp);
- if (!paramType->equals(otherParamType))
- return false;
- }
-
- if (!resultType->equals(funcType->resultType))
- return false;
-
- if (!errorType->equals(funcType->errorType))
- return false;
-
- // TODO: if we ever introduce other kinds
- // of qualification on function types, we'd
- // want to consider it here.
- return true;
- }
- return false;
-}
-
Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
// result type
- Type* substResultType = as<Type>(resultType->substituteImpl(astBuilder, subst, &diff));
+ Type* substResultType = as<Type>(getResultType()->substituteImpl(astBuilder, subst, &diff));
// error type
- Type* substErrorType = as<Type>(errorType->substituteImpl(astBuilder, subst, &diff));
+ Type* substErrorType = as<Type>(getErrorType()->substituteImpl(astBuilder, subst, &diff));
// parameter types
List<Type*> substParamTypes;
- for (auto pp : paramTypes)
+ for (Index pp = 0; pp < getParamCount(); pp++ )
{
- substParamTypes.add(as<Type>(pp->substituteImpl(astBuilder, subst, &diff)));
+ substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)));
}
// early exit for no change...
@@ -592,138 +423,75 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s
return this;
(*ioDiff)++;
- FuncType* substType = astBuilder->create<FuncType>();
- substType->resultType = substResultType;
- substType->paramTypes = substParamTypes;
- substType->errorType = substErrorType;
+ FuncType* substType = astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType);
return substType;
}
Type* FuncType::_createCanonicalTypeOverride()
{
// result type
- Type* canResultType = resultType->getCanonicalType();
- Type* canErrorType = errorType->getCanonicalType();
+ Type* canResultType = getResultType()->getCanonicalType();
+ Type* canErrorType = getErrorType()->getCanonicalType();
// parameter types
List<Type*> canParamTypes;
- for (auto pp : paramTypes)
+ for (Index pp = 0; pp < getParamCount(); pp++)
{
- canParamTypes.add(pp->getCanonicalType());
+ canParamTypes.add(getParamType(pp)->getCanonicalType());
}
- FuncType* canType = getASTBuilder()->create<FuncType>();
- canType->resultType = canResultType;
- canType->paramTypes = canParamTypes;
- canType->errorType = canErrorType;
+ FuncType* canType = getCurrentASTBuilder()->getFuncType(canParamTypes.getArrayView(), canResultType, canErrorType);
return canType;
}
-HashCode FuncType::_getHashCodeOverride()
-{
- HashCode hashCode = getResultType()->getHashCode();
- Index paramCount = getParamCount();
- hashCode = combineHash(hashCode, Slang::getHashCode(paramCount));
- for (Index pp = 0; pp < paramCount; ++pp)
- {
- hashCode = combineHash(
- hashCode,
- getParamType(pp)->getHashCode());
- }
- combineHash(hashCode, getErrorType()->getHashCode());
- return hashCode;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void TupleType::_toTextOverride(StringBuilder& out)
{
out << toSlice("(");
- for (Index pp = 0; pp < memberTypes.getCount(); ++pp)
+ for (Index pp = 0; pp < getOperandCount(); ++pp)
{
if (pp != 0)
out << toSlice(", ");
- out << memberTypes[pp];
+ out << getOperand(pp);
}
out << toSlice(")");
}
-bool TupleType::_equalsImplOverride(Type * type)
-{
- if (const auto other = as<TupleType>(type))
- {
- auto paramCount = memberTypes.getCount();
- auto otherParamCount = other->memberTypes.getCount();
- if (paramCount != otherParamCount)
- return false;
-
- for (Index i = 0; i < memberTypes.getCount(); ++i)
- {
- if(!memberTypes[i]->equals(other->memberTypes[i]))
- return false;
- }
-
- return true;
- }
- return false;
-}
-
Val* TupleType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
// just recurse into the members
List<Type*> substMemberTypes;
- for (auto m : memberTypes)
- substMemberTypes.add(as<Type>(m->substituteImpl(astBuilder, subst, &diff)));
+ for (Index m = 0; m < getMemberCount(); m++)
+ substMemberTypes.add(as<Type>(getMember(m)->substituteImpl(astBuilder, subst, &diff)));
// early exit for no change...
if (!diff)
return this;
(*ioDiff)++;
- return astBuilder->create<TupleType>(std::move(substMemberTypes));
+ return astBuilder->getTupleType(substMemberTypes);
}
Type* TupleType::_createCanonicalTypeOverride()
{
// member types
List<Type*> canMemberTypes;
- for (auto m : memberTypes)
+ for (Index m = 0; m < getMemberCount(); m++)
{
- canMemberTypes.add(m->getCanonicalType());
+ canMemberTypes.add(getMember(m)->getCanonicalType());
}
- return getASTBuilder()->create<TupleType>(std::move(canMemberTypes));
-}
-
-HashCode TupleType::_getHashCodeOverride()
-{
- HashCode hashCode = Slang::getHashCode(kType);
- for(auto m : memberTypes)
- hashCode = combineHash(hashCode, m->getHashCode());
- return hashCode;
+ return getCurrentASTBuilder()->getTupleType(canMemberTypes);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ExtractExistentialType::_toTextOverride(StringBuilder& out)
{
- out << declRef << toSlice(".This");
-}
-
-bool ExtractExistentialType::_equalsImplOverride(Type* type)
-{
- if (auto extractExistential = as<ExtractExistentialType>(type))
- {
- return declRef.equals(extractExistential->declRef);
- }
- return false;
-}
-
-HashCode ExtractExistentialType::_getHashCodeOverride()
-{
- return combineHash(declRef.getHashCode(), originalInterfaceType->getHashCode(), originalInterfaceDeclRef.getHashCode());
+ out << getDeclRef() << toSlice(".This");
}
Type* ExtractExistentialType::_createCanonicalTypeOverride()
@@ -734,18 +502,16 @@ Type* ExtractExistentialType::_createCanonicalTypeOverride()
Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
- auto substOriginalInterfaceType = originalInterfaceType->substituteImpl(astBuilder, subst, &diff);
- auto substOriginalInterfaceDeclRef = originalInterfaceDeclRef.substituteImpl(astBuilder, subst, &diff);
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+ auto substOriginalInterfaceType = getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff);
+ auto substOriginalInterfaceDeclRef = getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff);
if (!diff)
return this;
(*ioDiff)++;
- ExtractExistentialType* substValue = astBuilder->create<ExtractExistentialType>();
- substValue->declRef = substDeclRef;
- substValue->originalInterfaceType = as<Type>(substOriginalInterfaceType);
- substValue->originalInterfaceDeclRef = substOriginalInterfaceDeclRef;
+ ExtractExistentialType* substValue = astBuilder->getOrCreate<ExtractExistentialType>(
+ substDeclRef, as<Type>(substOriginalInterfaceType), substOriginalInterfaceDeclRef);
return substValue;
}
@@ -754,165 +520,47 @@ SubtypeWitness* ExtractExistentialType::getSubtypeWitness()
if (auto cachedValue = this->cachedSubtypeWitness)
return cachedValue;
- ExtractExistentialSubtypeWitness* openedWitness = m_astBuilder->create<ExtractExistentialSubtypeWitness>();
- openedWitness->sub = this;
- openedWitness->sup = originalInterfaceType;
- openedWitness->declRef = this->declRef;
-
+ ExtractExistentialSubtypeWitness* openedWitness = getCurrentASTBuilder()->getOrCreate<ExtractExistentialSubtypeWitness>(this, getOriginalInterfaceType(), getDeclRef());
this->cachedSubtypeWitness = openedWitness;
return openedWitness;
}
-DeclRef<InterfaceDecl> ExtractExistentialType::getSpecializedInterfaceDeclRef()
+DeclRef<ThisTypeDecl> ExtractExistentialType::getThisTypeDeclRef()
{
- if (auto cachedValue = this->cachedSpecializedInterfaceDeclRef)
+ if (auto cachedValue = this->cachedThisTypeDeclRef)
return cachedValue;
- auto interfaceDecl = originalInterfaceDeclRef.getDecl();
+ auto interfaceDecl = getOriginalInterfaceDeclRef().getDecl();
SubtypeWitness* openedWitness = getSubtypeWitness();
- ThisTypeSubstitution* openedThisType = m_astBuilder->getOrCreateThisTypeSubstitution(
- interfaceDecl, openedWitness, originalInterfaceDeclRef.getSubst());
-
- DeclRef<InterfaceDecl> specialiedInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(interfaceDecl, openedThisType);
-
- this->cachedSpecializedInterfaceDeclRef = specialiedInterfaceDeclRef;
- return specialiedInterfaceDeclRef;
-}
-
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-void TaggedUnionType::_toTextOverride(StringBuilder& out)
-{
- out << toSlice("__TaggedUnion(");
- bool first = true;
- for (auto caseType : caseTypes)
- {
- if (!first)
+ ThisTypeDecl* thisTypeDecl = nullptr;
+ for (auto member : interfaceDecl->members)
+ if (as<ThisTypeDecl>(member))
{
- out << toSlice(", ");
+ thisTypeDecl = as<ThisTypeDecl>(member);
+ break;
}
- first = false;
-
- out << caseType;
- }
- out << toSlice(")");
-}
-
-bool TaggedUnionType::_equalsImplOverride(Type* type)
-{
- auto taggedUnion = as<TaggedUnionType>(type);
- if (!taggedUnion)
- return false;
-
- auto caseCount = caseTypes.getCount();
- if (caseCount != taggedUnion->caseTypes.getCount())
- return false;
-
- for (Index ii = 0; ii < caseCount; ++ii)
- {
- if (!caseTypes[ii]->equals(taggedUnion->caseTypes[ii]))
- return false;
- }
- return true;
-}
-
-HashCode TaggedUnionType::_getHashCodeOverride()
-{
- HashCode hashCode = 0;
- for (auto caseType : caseTypes)
- {
- hashCode = combineHash(hashCode, caseType->getHashCode());
- }
- return hashCode;
-}
-
-Type* TaggedUnionType::_createCanonicalTypeOverride()
-{
- TaggedUnionType* canType = m_astBuilder->create<TaggedUnionType>();
-
- for (auto caseType : caseTypes)
- {
- auto canCaseType = caseType->getCanonicalType();
- canType->caseTypes.add(canCaseType);
- }
-
- return canType;
-}
+ SLANG_ASSERT(thisTypeDecl);
-Val* TaggedUnionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
+ DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl);
- List<Type*> substCaseTypes;
- for (auto caseType : caseTypes)
- {
- substCaseTypes.add(as<Type>(caseType->substituteImpl(astBuilder, subst, &diff)));
- }
- if (!diff)
- return this;
-
- (*ioDiff)++;
-
- TaggedUnionType* substType = astBuilder->create<TaggedUnionType>();
- substType->caseTypes.swapWith(substCaseTypes);
- return substType;
+ this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef;
+ return specialiedInterfaceDeclRef;
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExistentialSpecializedType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ExistentialSpecializedType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("__ExistentialSpecializedType(") << baseType;
- for (auto arg : args)
+ out << toSlice("__ExistentialSpecializedType(") << getBaseType();
+ for (Index i = 0; i < getArgCount(); i++)
{
- out << toSlice(", ") << arg.val;
+ out << toSlice(", ") << getArg(i).val;
}
out << toSlice(")");
}
-bool ExistentialSpecializedType::_equalsImplOverride(Type * type)
-{
- auto other = as<ExistentialSpecializedType>(type);
- if (!other)
- return false;
-
- if (!baseType->equals(other->baseType))
- return false;
-
- auto argCount = args.getCount();
- if (argCount != other->args.getCount())
- return false;
-
- for (Index ii = 0; ii < argCount; ++ii)
- {
- auto arg = args[ii];
- auto otherArg = other->args[ii];
-
- if (!arg.val->equalsVal(otherArg.val))
- return false;
-
- if (!areValsEqual(arg.witness, otherArg.witness))
- return false;
- }
- return true;
-}
-
-HashCode ExistentialSpecializedType::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashObject(baseType);
- for (auto arg : args)
- {
- hasher.hashObject(arg.val);
- if (auto witness = arg.witness)
- hasher.hashObject(witness);
- }
- return hasher.getResult();
-}
-
static Val* _getCanonicalValue(Val* val)
{
if (!val)
@@ -928,16 +576,21 @@ static Val* _getCanonicalValue(Val* val)
Type* ExistentialSpecializedType::_createCanonicalTypeOverride()
{
- ExistentialSpecializedType* canType = m_astBuilder->create<ExistentialSpecializedType>();
+ ExpandedSpecializationArgs newArgs;
- canType->baseType = baseType->getCanonicalType();
- for (auto arg : args)
+ for (Index ii = 0; ii < getArgCount(); ++ii)
{
+ auto arg = getArg(ii);
ExpandedSpecializationArg canArg;
canArg.val = _getCanonicalValue(arg.val);
canArg.witness = _getCanonicalValue(arg.witness);
- canType->args.add(canArg);
+ newArgs.add(canArg);
}
+
+ ExistentialSpecializedType* canType = getCurrentASTBuilder()->getOrCreate<ExistentialSpecializedType>(
+ getBaseType()->getCanonicalType(),
+ newArgs);
+
return canType;
}
@@ -951,11 +604,12 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder,
{
int diff = 0;
- auto substBaseType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff));
+ auto substBaseType = as<Type>(getBaseType()->substituteImpl(astBuilder, subst, &diff));
ExpandedSpecializationArgs substArgs;
- for (auto arg : args)
+ for (Index ii = 0; ii < getArgCount(); ++ii)
{
+ auto arg = getArg(ii);
ExpandedSpecializationArg substArg;
substArg.val = _substituteImpl(astBuilder, arg.val, subst, &diff);
substArg.witness = _substituteImpl(astBuilder, arg.witness, subst, &diff);
@@ -967,96 +621,22 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder,
(*ioDiff)++;
- ExistentialSpecializedType* substType = astBuilder->create<ExistentialSpecializedType>();
- substType->baseType = substBaseType;
- substType->args = substArgs;
+ ExistentialSpecializedType* substType = astBuilder->getOrCreate<ExistentialSpecializedType>(substBaseType, substArgs);
return substType;
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-void ThisType::_toTextOverride(StringBuilder& out)
-{
- out << interfaceDeclRef << toSlice(".This");
-}
-
-bool ThisType::_equalsImplOverride(Type * type)
-{
- auto other = as<ThisType>(type);
- if (!other)
- return false;
-
- if (!interfaceDeclRef.equals(other->interfaceDeclRef))
- return false;
-
- return true;
-}
-
-HashCode ThisType::_getHashCodeOverride()
-{
- return combineHash(
- HashCode(typeid(*this).hash_code()),
- interfaceDeclRef.getHashCode());
-}
-
-Type* ThisType::_createCanonicalTypeOverride()
+InterfaceDecl* ThisType::getInterfaceDecl()
{
- ThisType* canType = m_astBuilder->create<ThisType>();
-
- // TODO: need to canonicalize the decl-ref
- canType->interfaceDeclRef = interfaceDeclRef;
- return canType;
-}
-
-Val* ThisType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
-
- auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff);
-
- auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl());
- if (thisTypeSubst)
- {
- return thisTypeSubst->witness->sub;
- }
-
- if (!diff)
- return this;
-
- (*ioDiff)++;
-
- ThisType* substType = m_astBuilder->create<ThisType>();
- substType->interfaceDeclRef = substInterfaceDeclRef;
- return substType;
+ return dynamicCast<InterfaceDecl>(getDeclRefBase()->getDecl()->parentDecl);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void AndType::_toTextOverride(StringBuilder& out)
{
- out << left << toSlice(" & ") << right;
-}
-
-bool AndType::_equalsImplOverride(Type * type)
-{
- auto other = as<AndType>(type);
- if (!other)
- return false;
-
- if(!left->equals(other->left))
- return false;
- if(!right->equals(other->right))
- return false;
-
- return true;
-}
-
-HashCode AndType::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashObject(left);
- hasher.hashObject(right);
- return hasher.getResult();
+ out << getLeft() << toSlice(" & ") << getRight();
}
Type* AndType::_createCanonicalTypeOverride()
@@ -1094,9 +674,9 @@ Type* AndType::_createCanonicalTypeOverride()
// right now, in the name of getting something up and running.
//
- auto canLeft = left->getCanonicalType();
- auto canRight = right->getCanonicalType();
- auto canType = m_astBuilder->getAndType(canLeft, canRight);
+ auto canLeft = getLeft()->getCanonicalType();
+ auto canRight = getRight()->getCanonicalType();
+ auto canType = getCurrentASTBuilder()->getAndType(canLeft, canRight);
return canType;
}
@@ -1104,15 +684,15 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su
{
int diff = 0;
- auto substLeft = as<Type>(left ->substituteImpl(astBuilder, subst, &diff));
- auto substRight = as<Type>(right->substituteImpl(astBuilder, subst, &diff));
+ auto substLeft = as<Type>(getLeft()->substituteImpl(astBuilder, subst, &diff));
+ auto substRight = as<Type>(getRight()->substituteImpl(astBuilder, subst, &diff));
if(!diff)
return this;
(*ioDiff)++;
- auto substType = m_astBuilder->getAndType(substLeft, substRight);
+ auto substType = getCurrentASTBuilder()->getAndType(substLeft, substRight);
return substType;
}
@@ -1120,83 +700,35 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su
void ModifiedType::_toTextOverride(StringBuilder& out)
{
- for( auto modifier : modifiers )
+ for( Index i = 0; i < getModifierCount(); i++ )
{
- modifier->toText(out);
+ getModifier(i)->toText(out);
out.appendChar(' ');
}
- base->toText(out);
-}
-
-bool ModifiedType::_equalsImplOverride(Type* type)
-{
- auto other = as<ModifiedType>(type);
- if(!other)
- return false;
-
- if(!base->equals(other->base))
- return false;
-
- // TODO: Eventually we need to put the `modifiers` into
- // a canonical ordering as part of creation of a `ModifiedType`,
- // so that two instances that apply the same modifiers to
- // the same type will have those modifiers in a matching order.
- //
- // The simplest way to achieve that ordering *for now* would
- // be to sort the array by the integer AST node type tag.
- // That approach would of course not scale to modifiers that
- // have any operands of their own.
- //
- // Note that we would *also* need the logic that creates a
- // `ModifiedType` to detect when the base type is itself a
- // `ModifiedType` and produce a single `ModifiedType` with
- // a combined list of modifiers and a non-`ModifiedType` as
- // its base type.
- //
- auto modifierCount = modifiers.getCount();
- if(modifierCount != other->modifiers.getCount())
- return false;
-
- for( Index i = 0; i < modifierCount; ++i )
- {
- auto thisModifier = this->modifiers[i];
- auto otherModifier = other->modifiers[i];
- if(!thisModifier->equalsVal(otherModifier))
- return false;
- }
- return true;
-}
-
-HashCode ModifiedType::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashObject(base);
- for( auto modifier : modifiers )
- {
- hasher.hashObject(modifier);
- }
- return hasher.getResult();
+ getBase()->toText(out);
}
Type* ModifiedType::_createCanonicalTypeOverride()
{
- ModifiedType* canonical = m_astBuilder->create<ModifiedType>();
- canonical->base = base->getCanonicalType();
- for( auto modifier : modifiers )
+ List<Val*> modifiers;
+ for (Index i = 0; i < getModifierCount(); ++i)
{
- canonical->modifiers.add(modifier);
+ auto modifier = this->getModifier(i);
+ modifiers.add(modifier);
}
+ ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate<ModifiedType>(getBase()->getCanonicalType(), modifiers.getArrayView());
return canonical;
}
Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- Type* substBase = as<Type>(base->substituteImpl(astBuilder, subst, &diff));
+ Type* substBase = as<Type>(getBase()->substituteImpl(astBuilder, subst, &diff));
List<Val*> substModifiers;
- for( auto modifier : modifiers )
+ for (Index i = 0; i < getModifierCount(); ++i)
{
+ auto modifier = this->getModifier(i);
auto substModifier = modifier->substituteImpl(astBuilder, subst, &diff);
substModifiers.add(substModifier);
}
@@ -1206,12 +738,49 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS
*ioDiff = 1;
- ModifiedType* substType = m_astBuilder->create<ModifiedType>();
- substType->base = substBase;
- substType->modifiers = _Move(substModifiers);
+ ModifiedType* substType = getCurrentASTBuilder()->getOrCreate<ModifiedType>(substBase, substModifiers.getArrayView());
return substType;
}
+BaseType BasicExpressionType::getBaseType() const
+{
+ auto builtinType = getDeclRef().getDecl()->findModifier<BuiltinTypeModifier>();
+ return builtinType->tag;
+}
+
+FeedbackType::Kind FeedbackType::getKind() const
+{
+ auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>();
+ return FeedbackType::Kind(magicMod->tag);
+}
+
+TextureFlavor ResourceType::getFlavor() const
+{
+ auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>();
+ return TextureFlavor(magicMod->tag);
+}
+
+SamplerStateFlavor SamplerStateType::getFlavor() const
+{
+ auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>();
+ return SamplerStateFlavor(magicMod->tag);
+}
+
+Type* BuiltinGenericType::getElementType() const
+{
+ return as<Type>(_getGenericTypeArg(getDeclRefBase(), 0));
+}
+
+Type* ResourceType::getElementType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
+Val* TextureTypeBase::getSampleCount()
+{
+ return as<Type>(_getGenericTypeArg(this, 1));
+}
+
Type* removeParamDirType(Type* type)
{
for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;)