diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-04 15:47:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-04 15:47:39 -0700 |
| commit | a2d90fb275962da84611160f8ddd74d934a68dbd (patch) | |
| tree | 066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-ast-type.cpp | |
| parent | 17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff) | |
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val.
* Update project files
* Fix warning.
* Fix.
* Fix.
* Remove `Val::_equalsImplOverride`.
* Rmove `Val::_getHashCodeOverride`.
* Remove `semanticVisitor` param from `resolve`.
* Cleanups.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-type.cpp')
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 827 |
1 files changed, 198 insertions, 629 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index ee5d1d40e..13133a7f8 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1,49 +1,19 @@ // slang-ast-type.cpp #include "slang-ast-builder.h" +#include "slang-ast-modifier.h" #include <assert.h> #include <typeinfo> #include "slang-syntax.h" #include "slang-generated-ast-macro.h" - namespace Slang { // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Type* Type::createCanonicalType() -{ - SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ()) -} - -bool Type::equals(Type* type) -{ - return getCanonicalType()->equalsImpl(type->getCanonicalType()); -} - -bool Type::equalsImpl(Type* type) -{ - SLANG_AST_NODE_VIRTUAL_CALL(Type, equalsImpl, (type)) -} - -bool Type::_equalsImplOverride(Type* type) -{ - SLANG_UNUSED(type) - SLANG_UNEXPECTED("Type::_equalsImplOverride not overridden"); - //return false; -} - Type* Type::_createCanonicalTypeOverride() { - SLANG_UNEXPECTED("Type::_createCanonicalTypeOverride not overridden"); - //return Type*(); -} - -bool Type::_equalsValOverride(Val* val) -{ - if (auto type = dynamicCast<Type>(val)) - return const_cast<Type*>(this)->equals(type); - return false; + return as<Type>(defaultResolveImpl()); } Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) @@ -61,20 +31,6 @@ Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst return canSubst; } -Type* Type::getCanonicalType() -{ - Type* et = const_cast<Type*>(this); - if (!et->canonicalType) - { - // TODO(tfoley): worry about thread safety here? - auto canType = et->createCanonicalType(); - et->canonicalType = canType; - if (!et->canonicalType) - return getASTBuilder()->getErrorType(); - } - return et->canonicalType; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! OverloadGroupType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void OverloadGroupType::_toTextOverride(StringBuilder& out) @@ -82,21 +38,11 @@ void OverloadGroupType::_toTextOverride(StringBuilder& out) out << toSlice("overload group"); } -bool OverloadGroupType::_equalsImplOverride(Type * /*type*/) -{ - return false; -} - Type* OverloadGroupType::_createCanonicalTypeOverride() { return this; } -HashCode OverloadGroupType::_getHashCodeOverride() -{ - return (HashCode)(size_t(this)); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! InitializerListType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void InitializerListType::_toTextOverride(StringBuilder& out) @@ -104,21 +50,11 @@ void InitializerListType::_toTextOverride(StringBuilder& out) out << toSlice("initializer list"); } -bool InitializerListType::_equalsImplOverride(Type * /*type*/) -{ - return false; -} - Type* InitializerListType::_createCanonicalTypeOverride() { return this; } -HashCode InitializerListType::_getHashCodeOverride() -{ - return (HashCode)(size_t(this)); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ErrorType::_toTextOverride(StringBuilder& out) @@ -126,11 +62,6 @@ void ErrorType::_toTextOverride(StringBuilder& out) out << toSlice("error"); } -bool ErrorType::_equalsImplOverride(Type* type) -{ - return as<ErrorType>(type); -} - Type* ErrorType::_createCanonicalTypeOverride() { return this; @@ -141,56 +72,21 @@ Val* ErrorType::_substituteImplOverride(ASTBuilder* /* astBuilder */, Substituti return this; } -HashCode ErrorType::_getHashCodeOverride() -{ - return HashCode(size_t(this)); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void BottomType::_toTextOverride(StringBuilder& out) { out << toSlice("never"); } -bool BottomType::_equalsImplOverride(Type* type) -{ - return as<BottomType>(type); -} - -Type* BottomType::_createCanonicalTypeOverride() { return this; } - Val* BottomType::_substituteImplOverride( ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) { return this; } -HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); } - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void DeclRefType::_toTextOverride(StringBuilder& out) { - out << declRef; -} - -HashCode DeclRefType::_getHashCodeOverride() -{ - return (declRef.getHashCode() * 16777619) ^ (HashCode)(typeid(this).hash_code()); -} - -bool DeclRefType::_equalsImplOverride(Type * type) -{ - if (auto declRefType = as<DeclRefType>(type)) - { - return declRef.equals(declRefType->declRef); - } - return false; -} - -Type* DeclRefType::_createCanonicalTypeOverride() -{ - // A declaration reference is already canonical - declRef.substitute(this->getASTBuilder(), this); - return this; + out << getDeclRef(); } Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff); @@ -199,26 +95,47 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe { if (!subst) return this; - // the case we especially care about is when this type references a declaration - // of a generic parameter, since that is what we might be substituting... - if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl())) + int diff = 0; + DeclRef<Decl> substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + + // If this declref type is a direct reference to ThisType or a Generic parameter, + // and `subst` provides an argument for it, then we should just return that argument. + // + if (as<DirectDeclRef>(substDeclRef.declRefBase)) { - if (auto result = maybeSubstituteGenericParam(this, genericTypeParamDecl, subst, ioDiff)) + if (auto thisDecl = as<ThisTypeDecl>(substDeclRef.getDecl())) + { + auto lookupDeclRef = subst.findLookupDeclRef(); + if (lookupDeclRef && lookupDeclRef->getSupDecl() == substDeclRef.getDecl()->parentDecl) + { + (*ioDiff)++; + return lookupDeclRef->getLookupSource(); + } + } + else if (as<GenericTypeParamDecl>(substDeclRef.getDecl()) || as<GenericValueParamDecl>(substDeclRef.getDecl())) { - if (auto substDeclRefType = as<DeclRefType>(result)) + auto resultVal = maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff); + if (resultVal) { - // After generic substitution, we may be able to further simplify - // by looking up the actual type of an associated type. - if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst( - astBuilder, substDeclRefType->declRef)) - return satisfyingVal; + (*ioDiff)++; + return resultVal; } - return result; } } - int diff = 0; - DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + // If this type is a reference to an associated type declaration, + // and the substitutions provide a "this type" substitution for + // the outer interface, then try to replace the type with the + // actual value of the associated type for the given implementation. + // + if (auto satisfyingVal = substDeclRef.declRefBase->resolve()) + { + if (satisfyingVal != getDeclRef()) + { + *ioDiff += 1; + return DeclRefType::create(astBuilder, substDeclRef); + } + } if (!diff) return this; @@ -226,14 +143,6 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe // Make sure to record the difference! *ioDiff += diff; - // If this type is a reference to an associated type declaration, - // and the substitutions provide a "this type" substitution for - // the outer interface, then try to replace the type with the - // actual value of the associated type for the given implementation. - // - if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef)) - return satisfyingVal; - // Re-construct the type in case we are using a specialized sub-class return DeclRefType::create(astBuilder, substDeclRef); } @@ -254,40 +163,52 @@ BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool BasicExpressionType::_equalsImplOverride(Type * type) +BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() { - auto basicType = as<BasicExpressionType>(type); - return basicType && basicType->baseType == this->baseType; + return this; } -Type* BasicExpressionType::_createCanonicalTypeOverride() +static Val* _getGenericTypeArg(DeclRefBase* declRef, Index i) { - // A basic type is already canonical, in our setup - return this; + auto args = findInnerMostGenericArgs(SubstitutionSet(declRef)); + if (args.getCount() <= i) + return nullptr; + + return args[i]; } -BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() +static Val* _getGenericTypeArg(DeclRefType* declRefType, Index i) { - return this; + return _getGenericTypeArg(declRefType->getDeclRefBase(), i); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type* TensorViewType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Type* VectorExpressionType::getElementType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); +} + +IntVal* VectorExpressionType::getElementCount() +{ + return as<IntVal>(_getGenericTypeArg(this, 1)); +} + void VectorExpressionType::_toTextOverride(StringBuilder& out) { - out << toSlice("vector<") << elementType << toSlice(",") << elementCount << toSlice(">"); + out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount() << toSlice(">"); } BasicExpressionType* VectorExpressionType::_getScalarTypeOverride() { - return as<BasicExpressionType>(elementType); + return as<BasicExpressionType>(getElementType()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MatrixExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -304,24 +225,24 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride() Type* MatrixExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } IntVal* MatrixExpressionType::getRowCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(_getGenericTypeArg(this, 1)); } IntVal* MatrixExpressionType::getColumnCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[2]); + return as<IntVal>(_getGenericTypeArg(this, 2)); } Type* MatrixExpressionType::getRowType() { if (!rowType) { - rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount()); + rowType = getCurrentASTBuilder()->getVectorType(getElementType(), getColumnCount()); } return rowType; } @@ -330,12 +251,12 @@ Type* MatrixExpressionType::getRowType() Type* ArrayExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } IntVal* ArrayExpressionType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(_getGenericTypeArg(this, 1)); } void ArrayExpressionType::_toTextOverride(StringBuilder& out) @@ -353,7 +274,7 @@ bool ArrayExpressionType::isUnsized() { if (auto constSize = as<ConstantIntVal>(getElementCount())) { - if (constSize->value == kUnsizedArrayMagicLength) + if (constSize->getValue() == kUnsizedArrayMagicLength) return true; } return false; @@ -363,27 +284,12 @@ bool ArrayExpressionType::isUnsized() void TypeType::_toTextOverride(StringBuilder& out) { - out << toSlice("typeof(") << type << toSlice(")"); -} - -bool TypeType::_equalsImplOverride(Type * t) -{ - if (auto typeType = as<TypeType>(t)) - { - return t->equals(typeType->type); - } - return false; + out << toSlice("typeof(") << getType() << toSlice(")"); } Type* TypeType::_createCanonicalTypeOverride() { - return getASTBuilder()->getTypeType(type->getCanonicalType()); -} - -HashCode TypeType::_getHashCodeOverride() -{ - SLANG_UNEXPECTED("TypeType::_getHashCodeOverride should be unreachable"); - //return HashCode(0); + return getCurrentASTBuilder()->getTypeType(getType()->getCanonicalType()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericDeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -394,20 +300,6 @@ void GenericDeclRefType::_toTextOverride(StringBuilder& out) out << toSlice("<DeclRef<GenericDecl>>"); } -bool GenericDeclRefType::_equalsImplOverride(Type * type) -{ - if (auto genericDeclRefType = as<GenericDeclRefType>(type)) - { - return declRef.equals(genericDeclRefType->declRef); - } - return false; -} - -HashCode GenericDeclRefType::_getHashCodeOverride() -{ - return declRef.getHashCode(); -} - Type* GenericDeclRefType::_createCanonicalTypeOverride() { return this; @@ -417,21 +309,7 @@ Type* GenericDeclRefType::_createCanonicalTypeOverride() void NamespaceType::_toTextOverride(StringBuilder& out) { - out << toSlice("namespace ") << declRef; -} - -bool NamespaceType::_equalsImplOverride(Type * type) -{ - if (auto namespaceType = as<NamespaceType>(type)) - { - return declRef.equals(namespaceType->declRef); - } - return false; -} - -HashCode NamespaceType::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("namespace ") << getDeclRef(); } Type* NamespaceType::_createCanonicalTypeOverride() @@ -441,7 +319,7 @@ Type* NamespaceType::_createCanonicalTypeOverride() Type* DifferentialPairType::getPrimalType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } @@ -449,51 +327,35 @@ Type* DifferentialPairType::getPrimalType() Type* PtrTypeBase::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } Type* OptionalType::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); +} + +Type* NativeRefType::getValueType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void NamedExpressionType::_toTextOverride(StringBuilder& out) { - if (declRef.getDecl()) + if (getDeclRef().getDecl()) { - _printNestedDecl(declRef.getSubst(), declRef.getDecl(), out); + getDeclRef().declRefBase->toText(out); } } -bool NamedExpressionType::_equalsImplOverride(Type * /*type*/) -{ - SLANG_UNEXPECTED("NamedExpressionType::_equalsImplOverride should be unreachable"); - //return false; -} - Type* NamedExpressionType::_createCanonicalTypeOverride() { - if (!innerType) - innerType = getType(m_astBuilder, declRef); - if (innerType) - return innerType->getCanonicalType(); - return nullptr; -} - -HashCode NamedExpressionType::_getHashCodeOverride() -{ - // Type equality is based on comparing canonical types, - // so the hash code for a type needs to come from the - // canonical version of the type. This really means - // that `Type::getHashCode()` should dispatch out to - // something like `Type::getHashCodeImpl()` on the - // canonical version of a type, but it is less invasive - // for now (and hopefully equivalent) to just have any - // named types automaticlaly route hash-code requests - // to their canonical type. - return getCanonicalType()->getHashCode(); + auto canType = getType(getCurrentASTBuilder(), getDeclRef()); + if (canType) + return canType->getCanonicalType(); + return getCurrentASTBuilder()->getErrorType(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -533,58 +395,27 @@ void FuncType::_toTextOverride(StringBuilder& out) } out << ") -> " << getResultType(); - if (!getErrorType()->equals(getASTBuilder()->getBottomType())) + if (!getErrorType()->equals(getCurrentASTBuilder()->getBottomType())) { out << " throws " << getErrorType(); } } -bool FuncType::_equalsImplOverride(Type * type) -{ - if (auto funcType = as<FuncType>(type)) - { - auto paramCount = getParamCount(); - auto otherParamCount = funcType->getParamCount(); - if (paramCount != otherParamCount) - return false; - - for (Index pp = 0; pp < paramCount; ++pp) - { - auto paramType = getParamType(pp); - auto otherParamType = funcType->getParamType(pp); - if (!paramType->equals(otherParamType)) - return false; - } - - if (!resultType->equals(funcType->resultType)) - return false; - - if (!errorType->equals(funcType->errorType)) - return false; - - // TODO: if we ever introduce other kinds - // of qualification on function types, we'd - // want to consider it here. - return true; - } - return false; -} - Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; // result type - Type* substResultType = as<Type>(resultType->substituteImpl(astBuilder, subst, &diff)); + Type* substResultType = as<Type>(getResultType()->substituteImpl(astBuilder, subst, &diff)); // error type - Type* substErrorType = as<Type>(errorType->substituteImpl(astBuilder, subst, &diff)); + Type* substErrorType = as<Type>(getErrorType()->substituteImpl(astBuilder, subst, &diff)); // parameter types List<Type*> substParamTypes; - for (auto pp : paramTypes) + for (Index pp = 0; pp < getParamCount(); pp++ ) { - substParamTypes.add(as<Type>(pp->substituteImpl(astBuilder, subst, &diff))); + substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff))); } // early exit for no change... @@ -592,138 +423,75 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s return this; (*ioDiff)++; - FuncType* substType = astBuilder->create<FuncType>(); - substType->resultType = substResultType; - substType->paramTypes = substParamTypes; - substType->errorType = substErrorType; + FuncType* substType = astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType); return substType; } Type* FuncType::_createCanonicalTypeOverride() { // result type - Type* canResultType = resultType->getCanonicalType(); - Type* canErrorType = errorType->getCanonicalType(); + Type* canResultType = getResultType()->getCanonicalType(); + Type* canErrorType = getErrorType()->getCanonicalType(); // parameter types List<Type*> canParamTypes; - for (auto pp : paramTypes) + for (Index pp = 0; pp < getParamCount(); pp++) { - canParamTypes.add(pp->getCanonicalType()); + canParamTypes.add(getParamType(pp)->getCanonicalType()); } - FuncType* canType = getASTBuilder()->create<FuncType>(); - canType->resultType = canResultType; - canType->paramTypes = canParamTypes; - canType->errorType = canErrorType; + FuncType* canType = getCurrentASTBuilder()->getFuncType(canParamTypes.getArrayView(), canResultType, canErrorType); return canType; } -HashCode FuncType::_getHashCodeOverride() -{ - HashCode hashCode = getResultType()->getHashCode(); - Index paramCount = getParamCount(); - hashCode = combineHash(hashCode, Slang::getHashCode(paramCount)); - for (Index pp = 0; pp < paramCount; ++pp) - { - hashCode = combineHash( - hashCode, - getParamType(pp)->getHashCode()); - } - combineHash(hashCode, getErrorType()->getHashCode()); - return hashCode; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void TupleType::_toTextOverride(StringBuilder& out) { out << toSlice("("); - for (Index pp = 0; pp < memberTypes.getCount(); ++pp) + for (Index pp = 0; pp < getOperandCount(); ++pp) { if (pp != 0) out << toSlice(", "); - out << memberTypes[pp]; + out << getOperand(pp); } out << toSlice(")"); } -bool TupleType::_equalsImplOverride(Type * type) -{ - if (const auto other = as<TupleType>(type)) - { - auto paramCount = memberTypes.getCount(); - auto otherParamCount = other->memberTypes.getCount(); - if (paramCount != otherParamCount) - return false; - - for (Index i = 0; i < memberTypes.getCount(); ++i) - { - if(!memberTypes[i]->equals(other->memberTypes[i])) - return false; - } - - return true; - } - return false; -} - Val* TupleType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; // just recurse into the members List<Type*> substMemberTypes; - for (auto m : memberTypes) - substMemberTypes.add(as<Type>(m->substituteImpl(astBuilder, subst, &diff))); + for (Index m = 0; m < getMemberCount(); m++) + substMemberTypes.add(as<Type>(getMember(m)->substituteImpl(astBuilder, subst, &diff))); // early exit for no change... if (!diff) return this; (*ioDiff)++; - return astBuilder->create<TupleType>(std::move(substMemberTypes)); + return astBuilder->getTupleType(substMemberTypes); } Type* TupleType::_createCanonicalTypeOverride() { // member types List<Type*> canMemberTypes; - for (auto m : memberTypes) + for (Index m = 0; m < getMemberCount(); m++) { - canMemberTypes.add(m->getCanonicalType()); + canMemberTypes.add(getMember(m)->getCanonicalType()); } - return getASTBuilder()->create<TupleType>(std::move(canMemberTypes)); -} - -HashCode TupleType::_getHashCodeOverride() -{ - HashCode hashCode = Slang::getHashCode(kType); - for(auto m : memberTypes) - hashCode = combineHash(hashCode, m->getHashCode()); - return hashCode; + return getCurrentASTBuilder()->getTupleType(canMemberTypes); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExtractExistentialType::_toTextOverride(StringBuilder& out) { - out << declRef << toSlice(".This"); -} - -bool ExtractExistentialType::_equalsImplOverride(Type* type) -{ - if (auto extractExistential = as<ExtractExistentialType>(type)) - { - return declRef.equals(extractExistential->declRef); - } - return false; -} - -HashCode ExtractExistentialType::_getHashCodeOverride() -{ - return combineHash(declRef.getHashCode(), originalInterfaceType->getHashCode(), originalInterfaceDeclRef.getHashCode()); + out << getDeclRef() << toSlice(".This"); } Type* ExtractExistentialType::_createCanonicalTypeOverride() @@ -734,18 +502,16 @@ Type* ExtractExistentialType::_createCanonicalTypeOverride() Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - auto substOriginalInterfaceType = originalInterfaceType->substituteImpl(astBuilder, subst, &diff); - auto substOriginalInterfaceDeclRef = originalInterfaceDeclRef.substituteImpl(astBuilder, subst, &diff); + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substOriginalInterfaceType = getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff); + auto substOriginalInterfaceDeclRef = getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff); if (!diff) return this; (*ioDiff)++; - ExtractExistentialType* substValue = astBuilder->create<ExtractExistentialType>(); - substValue->declRef = substDeclRef; - substValue->originalInterfaceType = as<Type>(substOriginalInterfaceType); - substValue->originalInterfaceDeclRef = substOriginalInterfaceDeclRef; + ExtractExistentialType* substValue = astBuilder->getOrCreate<ExtractExistentialType>( + substDeclRef, as<Type>(substOriginalInterfaceType), substOriginalInterfaceDeclRef); return substValue; } @@ -754,165 +520,47 @@ SubtypeWitness* ExtractExistentialType::getSubtypeWitness() if (auto cachedValue = this->cachedSubtypeWitness) return cachedValue; - ExtractExistentialSubtypeWitness* openedWitness = m_astBuilder->create<ExtractExistentialSubtypeWitness>(); - openedWitness->sub = this; - openedWitness->sup = originalInterfaceType; - openedWitness->declRef = this->declRef; - + ExtractExistentialSubtypeWitness* openedWitness = getCurrentASTBuilder()->getOrCreate<ExtractExistentialSubtypeWitness>(this, getOriginalInterfaceType(), getDeclRef()); this->cachedSubtypeWitness = openedWitness; return openedWitness; } -DeclRef<InterfaceDecl> ExtractExistentialType::getSpecializedInterfaceDeclRef() +DeclRef<ThisTypeDecl> ExtractExistentialType::getThisTypeDeclRef() { - if (auto cachedValue = this->cachedSpecializedInterfaceDeclRef) + if (auto cachedValue = this->cachedThisTypeDeclRef) return cachedValue; - auto interfaceDecl = originalInterfaceDeclRef.getDecl(); + auto interfaceDecl = getOriginalInterfaceDeclRef().getDecl(); SubtypeWitness* openedWitness = getSubtypeWitness(); - ThisTypeSubstitution* openedThisType = m_astBuilder->getOrCreateThisTypeSubstitution( - interfaceDecl, openedWitness, originalInterfaceDeclRef.getSubst()); - - DeclRef<InterfaceDecl> specialiedInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(interfaceDecl, openedThisType); - - this->cachedSpecializedInterfaceDeclRef = specialiedInterfaceDeclRef; - return specialiedInterfaceDeclRef; -} - - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -void TaggedUnionType::_toTextOverride(StringBuilder& out) -{ - out << toSlice("__TaggedUnion("); - bool first = true; - for (auto caseType : caseTypes) - { - if (!first) + ThisTypeDecl* thisTypeDecl = nullptr; + for (auto member : interfaceDecl->members) + if (as<ThisTypeDecl>(member)) { - out << toSlice(", "); + thisTypeDecl = as<ThisTypeDecl>(member); + break; } - first = false; - - out << caseType; - } - out << toSlice(")"); -} - -bool TaggedUnionType::_equalsImplOverride(Type* type) -{ - auto taggedUnion = as<TaggedUnionType>(type); - if (!taggedUnion) - return false; - - auto caseCount = caseTypes.getCount(); - if (caseCount != taggedUnion->caseTypes.getCount()) - return false; - - for (Index ii = 0; ii < caseCount; ++ii) - { - if (!caseTypes[ii]->equals(taggedUnion->caseTypes[ii])) - return false; - } - return true; -} - -HashCode TaggedUnionType::_getHashCodeOverride() -{ - HashCode hashCode = 0; - for (auto caseType : caseTypes) - { - hashCode = combineHash(hashCode, caseType->getHashCode()); - } - return hashCode; -} - -Type* TaggedUnionType::_createCanonicalTypeOverride() -{ - TaggedUnionType* canType = m_astBuilder->create<TaggedUnionType>(); - - for (auto caseType : caseTypes) - { - auto canCaseType = caseType->getCanonicalType(); - canType->caseTypes.add(canCaseType); - } - - return canType; -} + SLANG_ASSERT(thisTypeDecl); -Val* TaggedUnionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; + DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl); - List<Type*> substCaseTypes; - for (auto caseType : caseTypes) - { - substCaseTypes.add(as<Type>(caseType->substituteImpl(astBuilder, subst, &diff))); - } - if (!diff) - return this; - - (*ioDiff)++; - - TaggedUnionType* substType = astBuilder->create<TaggedUnionType>(); - substType->caseTypes.swapWith(substCaseTypes); - return substType; + this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef; + return specialiedInterfaceDeclRef; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExistentialSpecializedType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExistentialSpecializedType::_toTextOverride(StringBuilder& out) { - out << toSlice("__ExistentialSpecializedType(") << baseType; - for (auto arg : args) + out << toSlice("__ExistentialSpecializedType(") << getBaseType(); + for (Index i = 0; i < getArgCount(); i++) { - out << toSlice(", ") << arg.val; + out << toSlice(", ") << getArg(i).val; } out << toSlice(")"); } -bool ExistentialSpecializedType::_equalsImplOverride(Type * type) -{ - auto other = as<ExistentialSpecializedType>(type); - if (!other) - return false; - - if (!baseType->equals(other->baseType)) - return false; - - auto argCount = args.getCount(); - if (argCount != other->args.getCount()) - return false; - - for (Index ii = 0; ii < argCount; ++ii) - { - auto arg = args[ii]; - auto otherArg = other->args[ii]; - - if (!arg.val->equalsVal(otherArg.val)) - return false; - - if (!areValsEqual(arg.witness, otherArg.witness)) - return false; - } - return true; -} - -HashCode ExistentialSpecializedType::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashObject(baseType); - for (auto arg : args) - { - hasher.hashObject(arg.val); - if (auto witness = arg.witness) - hasher.hashObject(witness); - } - return hasher.getResult(); -} - static Val* _getCanonicalValue(Val* val) { if (!val) @@ -928,16 +576,21 @@ static Val* _getCanonicalValue(Val* val) Type* ExistentialSpecializedType::_createCanonicalTypeOverride() { - ExistentialSpecializedType* canType = m_astBuilder->create<ExistentialSpecializedType>(); + ExpandedSpecializationArgs newArgs; - canType->baseType = baseType->getCanonicalType(); - for (auto arg : args) + for (Index ii = 0; ii < getArgCount(); ++ii) { + auto arg = getArg(ii); ExpandedSpecializationArg canArg; canArg.val = _getCanonicalValue(arg.val); canArg.witness = _getCanonicalValue(arg.witness); - canType->args.add(canArg); + newArgs.add(canArg); } + + ExistentialSpecializedType* canType = getCurrentASTBuilder()->getOrCreate<ExistentialSpecializedType>( + getBaseType()->getCanonicalType(), + newArgs); + return canType; } @@ -951,11 +604,12 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, { int diff = 0; - auto substBaseType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff)); + auto substBaseType = as<Type>(getBaseType()->substituteImpl(astBuilder, subst, &diff)); ExpandedSpecializationArgs substArgs; - for (auto arg : args) + for (Index ii = 0; ii < getArgCount(); ++ii) { + auto arg = getArg(ii); ExpandedSpecializationArg substArg; substArg.val = _substituteImpl(astBuilder, arg.val, subst, &diff); substArg.witness = _substituteImpl(astBuilder, arg.witness, subst, &diff); @@ -967,96 +621,22 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, (*ioDiff)++; - ExistentialSpecializedType* substType = astBuilder->create<ExistentialSpecializedType>(); - substType->baseType = substBaseType; - substType->args = substArgs; + ExistentialSpecializedType* substType = astBuilder->getOrCreate<ExistentialSpecializedType>(substBaseType, substArgs); return substType; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -void ThisType::_toTextOverride(StringBuilder& out) -{ - out << interfaceDeclRef << toSlice(".This"); -} - -bool ThisType::_equalsImplOverride(Type * type) -{ - auto other = as<ThisType>(type); - if (!other) - return false; - - if (!interfaceDeclRef.equals(other->interfaceDeclRef)) - return false; - - return true; -} - -HashCode ThisType::_getHashCodeOverride() -{ - return combineHash( - HashCode(typeid(*this).hash_code()), - interfaceDeclRef.getHashCode()); -} - -Type* ThisType::_createCanonicalTypeOverride() +InterfaceDecl* ThisType::getInterfaceDecl() { - ThisType* canType = m_astBuilder->create<ThisType>(); - - // TODO: need to canonicalize the decl-ref - canType->interfaceDeclRef = interfaceDeclRef; - return canType; -} - -Val* ThisType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff); - - auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl()); - if (thisTypeSubst) - { - return thisTypeSubst->witness->sub; - } - - if (!diff) - return this; - - (*ioDiff)++; - - ThisType* substType = m_astBuilder->create<ThisType>(); - substType->interfaceDeclRef = substInterfaceDeclRef; - return substType; + return dynamicCast<InterfaceDecl>(getDeclRefBase()->getDecl()->parentDecl); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void AndType::_toTextOverride(StringBuilder& out) { - out << left << toSlice(" & ") << right; -} - -bool AndType::_equalsImplOverride(Type * type) -{ - auto other = as<AndType>(type); - if (!other) - return false; - - if(!left->equals(other->left)) - return false; - if(!right->equals(other->right)) - return false; - - return true; -} - -HashCode AndType::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashObject(left); - hasher.hashObject(right); - return hasher.getResult(); + out << getLeft() << toSlice(" & ") << getRight(); } Type* AndType::_createCanonicalTypeOverride() @@ -1094,9 +674,9 @@ Type* AndType::_createCanonicalTypeOverride() // right now, in the name of getting something up and running. // - auto canLeft = left->getCanonicalType(); - auto canRight = right->getCanonicalType(); - auto canType = m_astBuilder->getAndType(canLeft, canRight); + auto canLeft = getLeft()->getCanonicalType(); + auto canRight = getRight()->getCanonicalType(); + auto canType = getCurrentASTBuilder()->getAndType(canLeft, canRight); return canType; } @@ -1104,15 +684,15 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su { int diff = 0; - auto substLeft = as<Type>(left ->substituteImpl(astBuilder, subst, &diff)); - auto substRight = as<Type>(right->substituteImpl(astBuilder, subst, &diff)); + auto substLeft = as<Type>(getLeft()->substituteImpl(astBuilder, subst, &diff)); + auto substRight = as<Type>(getRight()->substituteImpl(astBuilder, subst, &diff)); if(!diff) return this; (*ioDiff)++; - auto substType = m_astBuilder->getAndType(substLeft, substRight); + auto substType = getCurrentASTBuilder()->getAndType(substLeft, substRight); return substType; } @@ -1120,83 +700,35 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su void ModifiedType::_toTextOverride(StringBuilder& out) { - for( auto modifier : modifiers ) + for( Index i = 0; i < getModifierCount(); i++ ) { - modifier->toText(out); + getModifier(i)->toText(out); out.appendChar(' '); } - base->toText(out); -} - -bool ModifiedType::_equalsImplOverride(Type* type) -{ - auto other = as<ModifiedType>(type); - if(!other) - return false; - - if(!base->equals(other->base)) - return false; - - // TODO: Eventually we need to put the `modifiers` into - // a canonical ordering as part of creation of a `ModifiedType`, - // so that two instances that apply the same modifiers to - // the same type will have those modifiers in a matching order. - // - // The simplest way to achieve that ordering *for now* would - // be to sort the array by the integer AST node type tag. - // That approach would of course not scale to modifiers that - // have any operands of their own. - // - // Note that we would *also* need the logic that creates a - // `ModifiedType` to detect when the base type is itself a - // `ModifiedType` and produce a single `ModifiedType` with - // a combined list of modifiers and a non-`ModifiedType` as - // its base type. - // - auto modifierCount = modifiers.getCount(); - if(modifierCount != other->modifiers.getCount()) - return false; - - for( Index i = 0; i < modifierCount; ++i ) - { - auto thisModifier = this->modifiers[i]; - auto otherModifier = other->modifiers[i]; - if(!thisModifier->equalsVal(otherModifier)) - return false; - } - return true; -} - -HashCode ModifiedType::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashObject(base); - for( auto modifier : modifiers ) - { - hasher.hashObject(modifier); - } - return hasher.getResult(); + getBase()->toText(out); } Type* ModifiedType::_createCanonicalTypeOverride() { - ModifiedType* canonical = m_astBuilder->create<ModifiedType>(); - canonical->base = base->getCanonicalType(); - for( auto modifier : modifiers ) + List<Val*> modifiers; + for (Index i = 0; i < getModifierCount(); ++i) { - canonical->modifiers.add(modifier); + auto modifier = this->getModifier(i); + modifiers.add(modifier); } + ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate<ModifiedType>(getBase()->getCanonicalType(), modifiers.getArrayView()); return canonical; } Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - Type* substBase = as<Type>(base->substituteImpl(astBuilder, subst, &diff)); + Type* substBase = as<Type>(getBase()->substituteImpl(astBuilder, subst, &diff)); List<Val*> substModifiers; - for( auto modifier : modifiers ) + for (Index i = 0; i < getModifierCount(); ++i) { + auto modifier = this->getModifier(i); auto substModifier = modifier->substituteImpl(astBuilder, subst, &diff); substModifiers.add(substModifier); } @@ -1206,12 +738,49 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS *ioDiff = 1; - ModifiedType* substType = m_astBuilder->create<ModifiedType>(); - substType->base = substBase; - substType->modifiers = _Move(substModifiers); + ModifiedType* substType = getCurrentASTBuilder()->getOrCreate<ModifiedType>(substBase, substModifiers.getArrayView()); return substType; } +BaseType BasicExpressionType::getBaseType() const +{ + auto builtinType = getDeclRef().getDecl()->findModifier<BuiltinTypeModifier>(); + return builtinType->tag; +} + +FeedbackType::Kind FeedbackType::getKind() const +{ + auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>(); + return FeedbackType::Kind(magicMod->tag); +} + +TextureFlavor ResourceType::getFlavor() const +{ + auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>(); + return TextureFlavor(magicMod->tag); +} + +SamplerStateFlavor SamplerStateType::getFlavor() const +{ + auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>(); + return SamplerStateFlavor(magicMod->tag); +} + +Type* BuiltinGenericType::getElementType() const +{ + return as<Type>(_getGenericTypeArg(getDeclRefBase(), 0)); +} + +Type* ResourceType::getElementType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); +} + +Val* TextureTypeBase::getSampleCount() +{ + return as<Type>(_getGenericTypeArg(this, 1)); +} + Type* removeParamDirType(Type* type) { for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;) |
