diff options
37 files changed, 562 insertions, 463 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.h b/source/compiler-core/slang-diagnostic-sink.h index c150de900..e4d131e37 100644 --- a/source/compiler-core/slang-diagnostic-sink.h +++ b/source/compiler-core/slang-diagnostic-sink.h @@ -112,7 +112,7 @@ void printDiagnosticArg(StringBuilder& sb, RefPtr<T> ptr) inline SourceLoc getDiagnosticPos(SourceLoc const& pos) { return pos; } SourceLoc getDiagnosticPos(Token const& token); - + template<typename T> SourceLoc getDiagnosticPos(RefPtr<T> const& ptr) diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index a2fdaf32e..5319dca7e 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -22,7 +22,7 @@ class NodeBase // MUST be called before used. Called automatically via the ASTBuilder. // Note that the astBuilder is not stored in the NodeBase derived types by default. - SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /* astBuilder*/) + SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /*astBuilder*/) { astNodeType = inAstNodeType; #ifdef _DEBUG @@ -78,6 +78,15 @@ SLANG_FORCE_INLINE const T* as(const NodeBase* node) return (node && ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) ? static_cast<const T*>(node) : nullptr; } +// Because DeclRefBase is now a `Val`, we prevent casting it directly into other nodes +// to avoid confusion and bugs. Instead, use the `as<>()` method on `DeclRefBase` to +// get a `DeclRef<T>` for a specific node type. +template<typename T> +T* as(const DeclRefBase* declRefBase) = delete; + +template<typename T, typename U> +DeclRef<T> as(DeclRef<U> declRef) { return DeclRef<T>(declRef); } + struct Scope : public NodeBase { SLANG_AST_CLASS(Scope) @@ -339,6 +348,86 @@ class ThisTypeSubstitution : public Substitutions {} }; +class Decl; + +// A reference to a declaration, which may include +// substitutions for generic parameters. +class DeclRefBase : public Val +{ + SLANG_AST_CLASS(DeclRefBase) + + Decl* getDecl() const { return decl; } + + Substitutions* getSubst() const { return substitutions; } + + DeclRefBase(Decl* decl) + :decl(decl) + { + } + + DeclRefBase(Decl* decl, Substitutions* subst) + :decl(decl), substitutions(subst) + { + } + + // Apply substitutions to a type or declaration + Type* substitute(ASTBuilder* astBuilder, Type* type) const; + + DeclRefBase* substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const; + + // Apply substitutions to an expression + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; + + // Apply substitutions to this declaration reference + DeclRefBase* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; + + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) + { + return substituteImpl(astBuilder, subst, ioDiff); + } + bool _equalsValOverride(Val* val); + + bool _equalsImplOverride(DeclRefBase* declRef) { return equals(declRef); } + + void _toTextOverride(StringBuilder& out) { toText(out); } + + // Returns true if 'as' will return a valid cast + template <typename T> + bool is() const { return Slang::as<T>(decl) != nullptr; } + + // Check if this is an equivalent declaration reference to another + bool equals(DeclRefBase* declRef) const; + + // Convenience accessors for common properties of declarations + Name* getName() const; + SourceLoc getNameLoc() const; + SourceLoc getLoc() const; + DeclRefBase* getParent(ASTBuilder* astBuilder) const; + + HashCode getHashCode() const; + + // Debugging: + String toString() const; + void toText(StringBuilder& out) const; + +private: + + // The underlying declaration + Decl* decl = nullptr; + // Optionally, a chain of substitutions to perform + Substitutions* substitutions; + +}; + +SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) { declRef->toText(io); return io; } + +SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const Decl* decl) +{ + if (decl) + _printNestedDecl(nullptr, decl, io); + return io; +} + class SyntaxNode : public SyntaxNodeBase { SLANG_ABSTRACT_AST_CLASS(SyntaxNode); @@ -405,20 +494,27 @@ public: ContainerDecl* parentDecl = nullptr; + // A direct DeclRef to this Decl. + // For every Decl, we create a DeclRef node representing a direct reference to it + // upon the creation of the Decl (implemented in ASTBuilder), and store that + // DeclRef here so we can get a direct DeclRef from a Decl* (by calling makeDeclRef) + // without requiring a ASTBuilder to be available. + DeclRefBase* defaultDeclRef = nullptr; + NameLoc nameAndLoc; RefPtr<MarkupEntry> markup; - Name* getName() { return nameAndLoc.name; } - SourceLoc getNameLoc() { return nameAndLoc.loc ; } - NameLoc getNameAndLoc() { return nameAndLoc ; } + Name* getName() const { return nameAndLoc.name; } + SourceLoc getNameLoc() const { return nameAndLoc.loc ; } + NameLoc getNameAndLoc() const { return nameAndLoc ; } DeclCheckStateExt checkState = DeclCheckState::Unchecked; // The next declaration defined in the same container with the same name Decl* nextInContainerWithSameName = nullptr; - bool isChecked(DeclCheckState state) { return checkState >= state; } + bool isChecked(DeclCheckState state) const { return checkState >= state; } void setCheckState(DeclCheckState state) { SLANG_RELEASE_ASSERT(state >= checkState.getState()); @@ -446,5 +542,111 @@ class Stmt : public ModifiableSyntaxNode void accept(IStmtVisitor* visitor, void* extra); }; +template<typename T> +void DeclRef<T>::init(DeclRefBase* base) +{ + if (base && !Slang::as<T>(base->getDecl())) + declRefBase = nullptr; + else + declRefBase = base; +} + +template<typename T> +DeclRef<T>::DeclRef(Decl* decl) +{ + DeclRefBase* declRef = nullptr; + if (decl) + { + SLANG_ASSERT(decl->defaultDeclRef); + declRef = decl->defaultDeclRef; + } + init(declRef); +} + +template<typename T> +T* DeclRef<T>::getDecl() const +{ + return declRefBase ? (T*)declRefBase->getDecl() : nullptr; +} + +template<typename T> +Substitutions* DeclRef<T>::getSubst() const +{ + return declRefBase ? declRefBase->getSubst() : nullptr; +} + +template<typename T> +Name* DeclRef<T>::getName() const +{ + if (declRefBase) + return declRefBase->getName(); + return nullptr; +} + +template<typename T> +SourceLoc DeclRef<T>::getNameLoc() const +{ + if (declRefBase) return declRefBase->getNameLoc(); + return SourceLoc(); +} + +template<typename T> +SourceLoc DeclRef<T>::getLoc() const +{ + if (declRefBase) return declRefBase->getLoc(); + return SourceLoc(); +} + +template<typename T> +DeclRef<ContainerDecl> DeclRef<T>::getParent(ASTBuilder* astBuilder) const +{ + if (declRefBase) return DeclRef<ContainerDecl>(declRefBase->getParent(astBuilder)); + return DeclRef<ContainerDecl>((DeclRefBase*)nullptr); +} + +template<typename T> +HashCode DeclRef<T>::getHashCode() const +{ + if (declRefBase) return declRefBase->getHashCode(); + return 0; +} + +template<typename T> +Type* DeclRef<T>::substitute(ASTBuilder* astBuilder, Type* type) const +{ + if (!declRefBase) return type; + return declRefBase->substitute(astBuilder, type); +} + +template<typename T> +SubstExpr<Expr> DeclRef<T>::substitute(ASTBuilder* astBuilder, Expr* expr) const +{ + if (!declRefBase) return expr; + return declRefBase->substitute(astBuilder, expr); +} + +// Apply substitutions to a type or declaration +template<typename T> +template<typename U> +DeclRef<U> DeclRef<T>::substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const +{ + if (!declRefBase) return declRef; + return DeclRef<U>(declRefBase->substitute(astBuilder, declRef.declRefBase)); +} + +// Apply substitutions to this declaration reference +template<typename T> +DeclRef<T> DeclRef<T>::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const +{ + if (!declRefBase) return *this; + return DeclRef<T>(declRefBase->substituteImpl(astBuilder, subst, ioDiff)); +} + +template<typename T> +template<typename U> +bool DeclRef<T>::equals(DeclRef<U> other) const +{ + return declRefBase == other.declRefBase || (declRefBase && declRefBase->equals(other.declRefBase)); +} } // namespace Slang diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index d914b4568..33bd23f43 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -167,7 +167,7 @@ SharedASTBuilder::~SharedASTBuilder() void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier) { - auto type = DeclRefType::create(m_astBuilder, DeclRef<Decl>(decl)); + auto type = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(decl)); m_builtinTypes[Index(modifier->tag)] = type; } @@ -294,7 +294,7 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength); auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount); - if (!result->declRef.decl) + if (!result->declRef.getDecl()) { auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType")); auto arrayTypeDecl = arrayGenericDecl->inner; @@ -309,7 +309,7 @@ VectorExpressionType* ASTBuilder::getVectorType( IntVal* elementCount) { auto result = getOrCreate<VectorExpressionType>(elementType, elementCount); - if (!result->declRef.decl) + if (!result->declRef.getDecl()) { auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); auto vectorTypeDecl = vectorGenericDecl->inner; @@ -340,8 +340,7 @@ DifferentialPairType* ASTBuilder::getDifferentialPairType( DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface() { - DeclRef<InterfaceDecl> declRef; - declRef.decl = dynamicCast<InterfaceDecl>(m_sharedASTBuilder->findMagicDecl("DifferentiableType")); + DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } @@ -381,22 +380,22 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier( DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { - DeclRef<Decl> declRef; - declRef.decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); - if (auto genericDecl = as<GenericDecl>(declRef.decl)) + auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); + if (auto genericDecl = as<GenericDecl>(decl)) { + decl = genericDecl->inner; + Substitutions* subst = nullptr; if (genericArg) { - auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, genericArg); - declRef.substitutions = substitutions; + subst = getOrCreate<GenericSubstitution>(genericDecl, genericArg); } - declRef.decl = genericDecl->inner; + return getSpecializedDeclRef(decl, subst); } else { SLANG_ASSERT(!genericArg); } - return declRef; + return makeDeclRef(decl); } Type* ASTBuilder::getAndType(Type* left, Type* right) @@ -458,8 +457,7 @@ bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const // via a `NodeDesc` *should* all be going through the // deduplication path anyway, as should their operands. // - if (operands[i].values.nodeOperand[0] != that.operands[i].values.nodeOperand[0]) return false; - if (operands[i].values.nodeOperand[1] != that.operands[i].values.nodeOperand[1]) return false; + if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; } return true; } @@ -474,8 +472,7 @@ HashCode ASTBuilder::NodeDesc::getHashCode() const // to match the semantics implemented for `==` on // `NodeDesc`. // - hasher.hashValue(operands[i].values.nodeOperand[0]); - hasher.hashValue(operands[i].values.nodeOperand[1]); + hasher.hashValue(operands[i].values.nodeOperand); } return hasher.getResult(); } diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 3207ef73c..618636417 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -114,21 +114,26 @@ public: { union { - NodeBase* nodeOperand[2]; - int64_t intOperand[2]; + NodeBase* nodeOperand; + int64_t intOperand; } values; + NodeOperand() { - values.nodeOperand[0] = nullptr; - values.nodeOperand[1] = nullptr; + values.nodeOperand = nullptr; } - NodeOperand(NodeBase* node) { values.nodeOperand[0] = node; values.nodeOperand[1] = nullptr; } + + NodeOperand(NodeBase* node) { values.nodeOperand = node; } + + template<typename T> + NodeOperand(DeclRef<T> declRef) { values.nodeOperand = declRef.declRefBase; } + template<typename EnumType> NodeOperand(EnumType intVal) { + static_assert(std::is_trivial<EnumType>::value, "Type to construct NodeOperand must be trivial."); static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size."); - values.intOperand[0] = 0; - values.intOperand[1] = 0; + values.intOperand = 0; memcpy(&values, &intVal, sizeof(intVal)); } }; @@ -254,16 +259,29 @@ public: }); } + // This is the bottlneck through which all DeclRefs are created. template<typename T> DeclRef<T> getSpecializedDeclRef(T* decl, Substitutions* subst) { - return DeclRef<T>(this, decl, subst); + // We never create an actual DeclRefBase node to point to a null decl. + if (!decl) + return DeclRef<T>(); + + // If we don't have substitutions, use the default decl ref if it is created. + if (!subst) + { + auto defaultDeclRef = static_cast<Decl*>(decl)->defaultDeclRef; + if (defaultDeclRef) + return defaultDeclRef; + } + + return getOrCreate<DeclRefBase>(decl, subst); } template<typename T> DeclRef<T> getSpecializedDeclRef(T* decl, SubstitutionSet subst) { - return DeclRef<T>(this, decl, subst); + return getSpecializedDeclRef(decl, subst.substitutions); } ConstantIntVal* getIntVal(Type* type, IntegerLiteralValue value) @@ -271,17 +289,9 @@ public: return getOrCreate<ConstantIntVal>(type, value); } - DeclRefType* getOrCreateDeclRefType(Decl* decl, Substitutions* outer) + DeclRefType* getOrCreateDeclRefType(DeclRefBase* declRef) { - NodeDesc desc; - desc.type = DeclRefType::kType; - desc.operands.add(decl); - if (outer) - { - desc.operands.add(outer); - } - auto result = (DeclRefType*)_getOrCreateImpl(desc, [&]() {return create<DeclRefType>(getSpecializedDeclRef(decl, outer)); }); - return result; + return getOrCreate<DeclRefType>(declRef); } GenericSubstitution* getOrCreateGenericSubstitution(GenericDecl* decl, const List<Val*>& args, Substitutions* outer) @@ -449,6 +459,11 @@ protected: // Keep such that dtor can be run on ASTBuilder being dtored m_dtorNodes.add(node); } + if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType))) + { + auto decl = (Decl*)(node); + decl->defaultDeclRef = getSpecializedDeclRef(decl, nullptr); + } return node; } diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 884f8b736..0ab440a18 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -436,7 +436,7 @@ struct ASTDumpContext m_writer->emit("}"); } - void dump(DeclRefBase declRef) + void dump(DeclRefBase* declRef) { StringBuilder sb; sb << declRef; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 1b829c836..fd317a2c2 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1055,7 +1055,7 @@ class DifferentiableAttribute : public Attribute SLANG_AST_CLASS(DifferentiableAttribute) /// Mapping from types to subtype witnesses for conformance to IDifferentiable. - OrderedDictionary<DeclRefBase, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; + OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet; }; diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index bc0410fee..65c3a23c9 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -172,7 +172,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) !declRef.as<GenericValueParamDecl>() && !declRef.as<GenericTypeParamDecl>()) { - auto genSubst = as<GenericSubstitution>(declRef.substitutions.substitutions); + auto genSubst = as<GenericSubstitution>(declRef.getSubst()); if (genSubst) { SLANG_RELEASE_ASSERT(genSubst); @@ -276,7 +276,7 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* auto rangeStart = sb.getLength(); - ParamDecl* paramDecl = paramDeclRef; + ParamDecl* paramDecl = paramDeclRef.getDecl(); { ScopePart scopePart(this, Part::Type::ParamType); @@ -331,7 +331,7 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* { addGenericParams(genericDeclRef); - addDeclParams(m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.substitutions), outParamRange); + addDeclParams(m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.getSubst()), outParamRange); } else { @@ -443,10 +443,10 @@ void ASTPrinter::addDeclResultType(const DeclRef<Decl>& inDeclRef) DeclRef<Decl> declRef = inDeclRef; if (auto genericDeclRef = declRef.as<GenericDecl>()) { - declRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.substitutions); + declRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.getSubst()); } - if (as<ConstructorDecl>(declRef)) + if (declRef.as<ConstructorDecl>()) { } else if (auto callableDeclRef = declRef.as<CallableDecl>()) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 67340d52c..19f3f42d4 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -63,10 +63,13 @@ namespace Slang void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); void printDiagnosticArg(StringBuilder& sb, QualType const& type); void printDiagnosticArg(StringBuilder& sb, Val* val); + void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase); + class SyntaxNode; SourceLoc getDiagnosticPos(SyntaxNode const* syntax); SourceLoc getDiagnosticPos(TypeExp const& typeExp); + SourceLoc getDiagnosticPos(DeclRefBase* declRef); typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); @@ -743,185 +746,108 @@ namespace Slang struct DeclRef; Module* getModule(Decl* decl); - // A reference to a declaration, which may include - // substitutions for generic parameters. - struct DeclRefBase - { - typedef Decl DeclType; - - // The underlying declaration - Decl* decl = nullptr; - Decl* getDecl() const { return decl; } - - // Optionally, a chain of substitutions to perform - SubstitutionSet substitutions; - - DeclRefBase() - {} - - DeclRefBase(Decl* decl) - :decl(decl) - { - } - - DeclRefBase(ASTBuilder* astBuilder, Decl* decl, SubstitutionSet subst) - :decl(decl), - substitutions(subst) - { - SLANG_RELEASE_ASSERT(astBuilder); - } - - DeclRefBase(ASTBuilder* astBuilder, Decl* decl, Substitutions* subst) - : decl(decl) - , substitutions(subst) - { - SLANG_RELEASE_ASSERT(astBuilder); - } - - // Apply substitutions to a type or declaration - Type* substitute(ASTBuilder* astBuilder, Type* type) const; - - DeclRefBase substitute(ASTBuilder* astBuilder, DeclRefBase declRef) const; - - // Apply substitutions to an expression - SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; - - // Apply substitutions to this declaration reference - DeclRefBase substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; - - // Returns true if 'as' will return a valid cast - template <typename T> - bool is() const { return Slang::as<T>(decl) != nullptr; } - - // "dynamic cast" to a more specific declaration reference type - template<typename T> - DeclRef<T> as() const; - - // Check if this is an equivalent declaration reference to another - bool equals(DeclRefBase const& declRef) const; - bool operator == (const DeclRefBase& other) const - { - return equals(other); - } - - // Convenience accessors for common properties of declarations - Name* getName() const; - SourceLoc getNameLoc() const; - SourceLoc getLoc() const; - DeclRefBase getParent(ASTBuilder* astBuilder) const; - - HashCode getHashCode() const; - - // Debugging: - String toString() const; - void toText(StringBuilder& out) const; - }; // If this is a declref to an associatedtype with a ThisTypeSubsitution, // try to find the concrete decl that satisfies the associatedtype requirement from the // concrete type supplied by ThisTypeSubstittution. Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef); - void _printNestedDecl(const Substitutions* substitutions, Decl* decl, StringBuilder& out); + void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out); template<typename T> - struct DeclRef : DeclRefBase + struct DeclRef { friend class ASTBuilder; - private: - DeclRef(ASTBuilder* builder, T* decl, SubstitutionSet subst) - : DeclRefBase(builder, decl, subst) - {} - - DeclRef(ASTBuilder* builder, T* decl, Substitutions* subst) - : DeclRefBase(builder, decl, SubstitutionSet(subst)) - {} public: typedef T DeclType; - + DeclRefBase* declRefBase; DeclRef() + :declRefBase(nullptr) {} + + void init(DeclRefBase* base); - DeclRef(T* decl) - : DeclRefBase(decl) - {} + DeclRef(Decl* decl); + + DeclRef(DeclRefBase* base) + { + init(base); + } template <typename U> DeclRef(DeclRef<U> const& other, typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : declRefBase(other.declRefBase) { - this->decl = other.decl; - this->substitutions = other.substitutions; } - T* getDecl() const - { - return (T*)decl; - } + T* getDecl() const; + Substitutions* getSubst() const; + + Name* getName() const; + + SourceLoc getNameLoc() const; + SourceLoc getLoc() const; + DeclRef<ContainerDecl> getParent(ASTBuilder* astBuilder) const; + HashCode getHashCode() const; + Type* substitute(ASTBuilder* astBuilder, Type* type) const; + + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; - operator T*() const + // Apply substitutions to a type or declaration + template<typename U> + DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const; + + // Apply substitutions to this declaration reference + DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; + + template<typename U> + DeclRef<U> as() const { - return getDecl(); + DeclRef<U> result = DeclRef<U>(declRefBase); + return result; } - // - static DeclRef<T> unsafeInit(DeclRefBase const& declRef) + template<typename U> + bool is() const { - DeclRef<T> rs; - rs.decl = declRef.decl; - rs.substitutions = declRef.substitutions; - return rs; + return Slang::as<U>(static_cast<NodeBase*>(getDecl())) != nullptr; } - Type* substitute(ASTBuilder* astBuilder, Type* type) const + operator DeclRefBase* () const { - return DeclRefBase::substitute(astBuilder, type); + return declRefBase; } - SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const + operator DeclRef<Decl>() const { - return DeclRefBase::substitute(astBuilder, expr); + return DeclRef<Decl>(declRefBase); } - // Apply substitutions to a type or declaration template<typename U> - DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const - { - return DeclRef<U>::unsafeInit(DeclRefBase::substitute(astBuilder, declRef)); - } + bool equals(DeclRef<U> other) const; - // Apply substitutions to this declaration reference - DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const + template<typename U> + bool operator == (DeclRef<U> other) const { - return DeclRef<T>::unsafeInit(DeclRefBase::substituteImpl(astBuilder, subst, ioDiff)); + return equals(other); } - DeclRef<ContainerDecl> getParent(ASTBuilder* astBuilder) const + explicit operator bool() const { - return DeclRef<ContainerDecl>::unsafeInit(DeclRefBase::getParent(astBuilder)); + return declRefBase; } }; - SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); - DeclRef<Decl> substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef<Decl> const& declRef); - Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); - - SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase& declRef) { declRef.toText(io); return io; } - - template<typename T> - DeclRef<T> DeclRefBase::as() const - { - DeclRef<T> result; - result.decl = Slang::as<T>(decl); - result.substitutions = substitutions; - return result; - } - template<typename T> inline DeclRef<T> makeDeclRef(T* decl) { return DeclRef<T>(decl); } + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); + DeclRef<Decl> substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef<Decl> const& declRef); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + enum class MemberFilterStyle { All, ///< All members @@ -1450,7 +1376,7 @@ namespace Slang : m_flavor(Flavor::none) {} - RequirementWitness(DeclRef<Decl> declRef) + RequirementWitness(DeclRefBase* declRef) : m_flavor(Flavor::declRef) , m_declRef(declRef) {} diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index 5ad14321d..872088d54 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -59,7 +59,7 @@ Expr* ASTSynthesizer::emitVarExpr(Name* name) Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl) { auto varExpr = m_builder->create<VarExpr>(); - varExpr->declRef = makeDeclRef(varDecl); + varExpr->declRef = makeDeclRef<Decl>(varDecl); varExpr->type = varDecl->type.type; return varExpr; } @@ -67,7 +67,7 @@ Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl) Expr* ASTSynthesizer::emitVarExpr(VarDecl* var, Type* type) { auto expr = m_builder->create<VarExpr>(); - expr->declRef = DeclRef<Decl>(var); + expr->declRef = makeDeclRef<Decl>(var); expr->type.type = type; expr->type.isLeftValue = true; return expr; @@ -76,7 +76,7 @@ Expr* ASTSynthesizer::emitVarExpr(VarDecl* var, Type* type) Expr* ASTSynthesizer::emitVarExpr(DeclStmt* varStmt, Type* type) { auto expr = m_builder->create<VarExpr>(); - expr->declRef = DeclRef<Decl>(as<Decl>(varStmt->decl)); + expr->declRef = makeDeclRef<Decl>(as<Decl>(varStmt->decl)); expr->type.type = type; expr->type.isLeftValue = true; return expr; diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index c4d301852..b24e0eb8e 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -274,7 +274,7 @@ BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type* TensorViewType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } @@ -304,17 +304,17 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride() Type* MatrixExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } IntVal* MatrixExpressionType::getRowCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); } IntVal* MatrixExpressionType::getColumnCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[2]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[2]); } Type* MatrixExpressionType::getRowType() @@ -330,12 +330,12 @@ Type* MatrixExpressionType::getRowType() Type* ArrayExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } IntVal* ArrayExpressionType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); } void ArrayExpressionType::_toTextOverride(StringBuilder& out) @@ -441,7 +441,7 @@ Type* NamespaceType::_createCanonicalTypeOverride() Type* DifferentialPairType::getPrimalType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } @@ -449,12 +449,12 @@ Type* DifferentialPairType::getPrimalType() Type* PtrTypeBase::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } Type* OptionalType::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -463,7 +463,7 @@ void NamedExpressionType::_toTextOverride(StringBuilder& out) { if (declRef.getDecl()) { - _printNestedDecl(declRef.substitutions, declRef.getDecl(), out); + _printNestedDecl(declRef.getSubst(), declRef.getDecl(), out); } } @@ -773,7 +773,7 @@ DeclRef<InterfaceDecl> ExtractExistentialType::getSpecializedInterfaceDeclRef() SubtypeWitness* openedWitness = getSubtypeWitness(); ThisTypeSubstitution* openedThisType = m_astBuilder->create<ThisTypeSubstitution>(); - openedThisType->outer = originalInterfaceDeclRef.substitutions.substitutions; + openedThisType->outer = originalInterfaceDeclRef.getSubst(); openedThisType->interfaceDecl = interfaceDecl; openedThisType->witness = openedWitness; diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 5113e65f8..c5ec7e161 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -66,7 +66,6 @@ class DeclRefType : public Type DeclRef<Decl> declRef; - static DeclRefType* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); // Overrides should be public so base classes can access @@ -76,9 +75,8 @@ class DeclRefType : public Type HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); -protected: - DeclRefType(DeclRef<Decl> declRef) - : declRef(declRef) + DeclRefType(DeclRefBase* declRefBase) + : declRef(declRefBase) {} }; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index f9b668f55..6850fdbfc 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -317,7 +317,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // so we'll need to change this location in the code if we ever clean // up the hierarchy. // - if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.decl)) + if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.getDecl())) { if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl)) { @@ -326,7 +326,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // At this point we have a constraint decl for an associated type, // and we nee to see if we are dealing with a concrete substitution // for the interface around that associated type. - if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.substitutions, interfaceDecl)) + if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.getSubst(), interfaceDecl)) { // We need to look up the declaration that satisfies // the requirement named by the associated type. @@ -349,7 +349,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } DeclaredSubtypeWitness* rs = astBuilder->getOrCreate<DeclaredSubtypeWitness>( - substSub, substSup, astBuilder->getSpecializedDeclRef(substDeclRef.getDecl(), substDeclRef.substitutions.substitutions)); + substSub, substSup, astBuilder->getSpecializedDeclRef(substDeclRef.getDecl(), substDeclRef.getSubst())); rs->sub = substSub; rs->sup = substSup; rs->declRef = substDeclRef; diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 218f03200..75979cb15 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -129,7 +129,7 @@ public: if (thisGenParam->equalsVal(thatGenParam)) return power < other.power; else - return thisGenParam->declRef.decl < thatGenParam->declRef.decl; + return thisGenParam->declRef.getDecl() < thatGenParam->declRef.getDecl(); } else { diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 8df4631ac..7379a6538 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -13,7 +13,7 @@ namespace Slang DeclaredSubtypeWitness* witness = m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( breadcrumb->sub, breadcrumb->sup, - m_astBuilder->getSpecializedDeclRef(breadcrumb->declRef.decl, breadcrumb->declRef.substitutions.substitutions)); + breadcrumb->declRef); return witness; } @@ -143,7 +143,7 @@ namespace Slang { DeclaredSubtypeWitness* declaredWitness = m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( - bb->sub, bb->sup, m_astBuilder->getSpecializedDeclRef(bb->declRef.decl, bb->declRef.substitutions.substitutions)); + bb->sub, bb->sup, bb->declRef); TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(); transitiveWitness->sub = subType; @@ -472,7 +472,7 @@ namespace Slang leftBreadcrumb.prev = inBreadcrumbs; leftBreadcrumb.sub = andType; leftBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); - leftBreadcrumb.declRef = makeDeclRef((Decl*)nullptr); + leftBreadcrumb.declRef = DeclRef<Decl>(); leftBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor; if(_isDeclaredSubtype(originalSubType, andType->left, superTypeDeclRef, outWitness, &leftBreadcrumb)) @@ -484,7 +484,7 @@ namespace Slang rightBreadcrumb.prev = inBreadcrumbs; rightBreadcrumb.sub = andType; rightBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); - rightBreadcrumb.declRef = makeDeclRef((Decl*)nullptr); + rightBreadcrumb.declRef = DeclRef<Decl>(); rightBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor; if(_isDeclaredSubtype(originalSubType, andType->right, superTypeDeclRef, outWitness, &rightBreadcrumb)) diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 65d716db0..c38c9e7f2 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -290,7 +290,7 @@ namespace Slang if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) return SubstitutionSet(); } - SubstitutionSet resultSubst = genericDeclRef.substitutions; + SubstitutionSet resultSubst = genericDeclRef.getSubst(); // Once have built up the full list of constraints we are trying to satisfy, // we will attempt to solve for each parameter in a way that satisfies all @@ -457,7 +457,7 @@ namespace Slang // apply the substitutions we already know... GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getDecl(), args, genericDeclRef.substitutions.substitutions); + genericDeclRef.getDecl(), args, genericDeclRef.getSubst()); for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { @@ -510,7 +510,7 @@ namespace Slang } resultSubst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getDecl(), args, genericDeclRef.substitutions.substitutions); + genericDeclRef.getDecl(), args, genericDeclRef.getSubst()); return resultSubst; } @@ -737,8 +737,8 @@ namespace Slang // to each declaration reference. if (!tryUnifySubstitutions( constraints, - fstDeclRef.substitutions.substitutions, - sndDeclRef.substitutions.substitutions)) + fstDeclRef.getSubst(), + sndDeclRef.getSubst())) { return false; } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 17b55a2cb..a6130d4e8 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -163,7 +163,7 @@ namespace Slang ioInitArgIndex); } - DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef) + DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> structTypeDeclRef) { auto inheritanceDecl = getMembersOfType<InheritanceDecl>(astBuilder, structTypeDeclRef).getFirstOrNull(); if(!inheritanceDecl) @@ -182,7 +182,7 @@ namespace Slang return baseDeclRefType; } - DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef) + DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> structTypeDeclRef) { auto inheritanceDecl = getMembersOfType<InheritanceDecl>(astBuilder, structTypeDeclRef).getFirstOrNull(); if (!inheritanceDecl) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 46b968279..06616b38c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1414,10 +1414,11 @@ namespace Slang { if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base)) { - declRefType->declRef.substitutions = createDefaultSubstitutions(m_astBuilder, this, declRefType->declRef.getDecl()); - + auto subst = createDefaultSubstitutions(m_astBuilder, this, declRefType->declRef.getDecl()); + auto newType = m_astBuilder->getOrCreateDeclRefType(m_astBuilder->getSpecializedDeclRef(declRefType->declRef.getDecl(), subst)); + sharedTypeExpr->base.type = newType; if (auto typetype = as<TypeType>(typeExp.exp->type)) - typetype->type = declRefType; + typetype->type = newType; } } } @@ -1466,19 +1467,19 @@ namespace Slang // apply it to the newly synthesized decl. SubstitutionSet substSet; if (auto thisTypeSusbt = findThisTypeSubstitution( - requirementDeclRef.substitutions, - as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + requirementDeclRef.getSubst(), + as<InterfaceDecl>(requirementDeclRef.getParent(m_astBuilder)).getDecl())) { if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) { - substSet = declRefType->declRef.substitutions; + substSet = declRefType->declRef.getSubst(); } } - auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(m_astBuilder->getSpecializedDeclRef(aggTypeDecl, substSet)); // Helper function to add a `diffType` field into the synthesized type for the original // `member`. - auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl)); + auto differentialType = DeclRefType::create(m_astBuilder, DeclRef<Decl>(makeDeclRef(aggTypeDecl))); auto addDiffMember = [&](Decl* member, Type* diffMemberType) { // If the field is differentiable, add a corresponding field in the associated Differential type. @@ -2089,7 +2090,7 @@ namespace Slang for( auto p : mapRequiredToSatisfyingAccessorDeclRef ) { witnessTable->add( - p.key, + p.key.getDecl(), RequirementWitness(p.value)); } // @@ -2141,7 +2142,7 @@ namespace Slang if (satisfyingVal) { witnessTable->add( - requiredMemberDeclRef, + requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal)); } else @@ -2270,9 +2271,7 @@ namespace Slang auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>( requiredValueParamDeclRef.getDecl()->getType(), - m_astBuilder->getSpecializedDeclRef( - satisfyingValueParamDeclRef.getDecl(), - satisfyingValueParamDeclRef.substitutions.substitutions)); + satisfyingValueParamDeclRef); satisfyingVal->declRef = satisfyingValueParamDeclRef; requiredSubstArgs.add(satisfyingVal); @@ -2300,7 +2299,7 @@ namespace Slang GenericSubstitution* requiredSubst = m_astBuilder->getOrCreateGenericSubstitution( requiredGenericDeclRef.getDecl(), requiredSubstArgs, - requiredGenericDeclRef.substitutions); + requiredGenericDeclRef.getSubst()); // Now that we have computed a set of specialization arguments that will // specialize the generic requirement at the type parameters of the satisfying @@ -2386,7 +2385,7 @@ namespace Slang // declaration (whatever it is) for an exact match. // return doesMemberSatisfyRequirement( - m_astBuilder->getSpecializedDeclRef<Decl>(satisfyingGenericDeclRef.getDecl()->inner, satisfyingGenericDeclRef.substitutions), + m_astBuilder->getSpecializedDeclRef<Decl>(satisfyingGenericDeclRef.getDecl()->inner, satisfyingGenericDeclRef.getSubst()), m_astBuilder->getSpecializedDeclRef<Decl>(requiredGenericDeclRef.getDecl()->inner, requiredSubst), witnessTable); } @@ -2409,7 +2408,7 @@ namespace Slang { // If a subtype witness was found, then the conformance // appears to hold, and we can satisfy that requirement. - witnessTable->add(requiredConstraintDeclRef, RequirementWitness(witness)); + witnessTable->add(requiredConstraintDeclRef.getDecl(), RequirementWitness(witness)); } else { @@ -2728,7 +2727,7 @@ namespace Slang witnessTable->add(bwdReq, RequirementWitness(val)); } } - witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); + witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); } bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( @@ -3242,9 +3241,9 @@ namespace Slang // for(auto p : mapRequiredAccessorToSynAccessor) { - witnessTable->add(p.key, RequirementWitness(makeDeclRef(p.value))); + witnessTable->add(p.key.getDecl(), RequirementWitness(makeDeclRef(p.value))); } - witnessTable->add(requiredMemberDeclRef, + witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(makeDeclRef(synPropertyDecl))); return true; } @@ -3475,16 +3474,16 @@ namespace Slang // generic substitution for outer generic parameters, and apply it here. SubstitutionSet substSet; if (auto thisTypeSusbt = findThisTypeSubstitution( - requirementDeclRef.substitutions, + requirementDeclRef.getSubst(), as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) { if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) { - substSet = declRefType->declRef.substitutions; + substSet = declRefType->declRef.getSubst(); } } - witnessTable->add(requirementDeclRef, RequirementWitness(m_astBuilder->getSpecializedDeclRef<Decl>(synFunc, substSet))); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getSpecializedDeclRef<Decl>(synFunc, substSet))); return true; } @@ -3538,9 +3537,7 @@ namespace Slang m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( superInterfaceType, reqType, - m_astBuilder->getSpecializedDeclRef( - requiredInheritanceDeclRef.getDecl(), - requiredInheritanceDeclRef.substitutions.substitutions)); + requiredInheritanceDeclRef); // ... TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, reqType, interfaceIsReqWitness); @@ -3742,7 +3739,7 @@ namespace Slang ThisTypeSubstitution* thisTypeSubst = m_astBuilder->create<ThisTypeSubstitution>(); thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); thisTypeSubst->witness = subTypeConformsToSuperInterfaceWitness; - thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; + thisTypeSubst->outer = superInterfaceDeclRef.getSubst(); auto specializedSuperInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(superInterfaceDeclRef.getDecl(), thisTypeSubst); @@ -3778,7 +3775,7 @@ namespace Slang // for(auto requiredMemberDeclRef : getMembers(m_astBuilder, specializedSuperInterfaceDeclRef)) { - if(!isAssociatedTypeDecl(requiredMemberDeclRef)) + if(!isAssociatedTypeDecl(requiredMemberDeclRef.getDecl())) continue; auto requirementSatisfied = findWitnessForInterfaceRequirement( @@ -3795,7 +3792,7 @@ namespace Slang } for(auto requiredMemberDeclRef : getMembers(m_astBuilder, specializedSuperInterfaceDeclRef)) { - if(isAssociatedTypeDecl(requiredMemberDeclRef)) + if(isAssociatedTypeDecl(requiredMemberDeclRef.getDecl())) continue; if (requiredMemberDeclRef.as<DerivativeRequirementDecl>()) continue; @@ -4080,7 +4077,7 @@ namespace Slang { return; } - auto baseDecl = baseDeclRefType->declRef.decl; + auto baseDecl = baseDeclRefType->declRef.getDecl(); // Using the parent/child hierarchy baked into `Decl`s we // can find the modules that contain both the `decl` doing @@ -5195,7 +5192,7 @@ namespace Slang if(!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &subst)) return SLANG_OK; - oldDeclRef.substitutions.substitutions = subst; + oldDeclRef = getASTBuilder()->getSpecializedDeclRef(oldDecl, subst); } // If the parameter signatures don't match, then don't worry @@ -6187,7 +6184,7 @@ namespace Slang if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) return DeclRef<ExtensionDecl>(); - auto constraintSubst = trySolveConstraintSystem(&constraints, m_astBuilder->getSpecializedDeclRef<Decl>(extGenericDecl, nullptr).as<GenericDecl>()); + auto constraintSubst = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl)); if (!constraintSubst) { return DeclRef<ExtensionDecl>(); @@ -6220,20 +6217,20 @@ namespace Slang { // Looks like we have a match in the types, // now let's see if we have a this-type substitution. - if(auto appThisTypeSubst = as<ThisTypeSubstitution>(appInterfaceDeclRef.substitutions.substitutions)) + if(auto appThisTypeSubst = as<ThisTypeSubstitution>(appInterfaceDeclRef.getSubst())) { if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl()) { // The type we want to apply to has a this-type substitution, // and (by construction) the target type currently does not. // - SLANG_ASSERT(!as<ThisTypeSubstitution>(targetInterfaceDeclRef.substitutions.substitutions)); + SLANG_ASSERT(!as<ThisTypeSubstitution>(targetInterfaceDeclRef.getSubst())); // We will create a new substitution to apply to the target type. ThisTypeSubstitution* newTargetSubst = m_astBuilder->create<ThisTypeSubstitution>(); newTargetSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; newTargetSubst->witness = appThisTypeSubst->witness; - newTargetSubst->outer = targetInterfaceDeclRef.substitutions.substitutions; + newTargetSubst->outer = targetInterfaceDeclRef.getSubst(); targetType = DeclRefType::create(m_astBuilder, m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(targetInterfaceDeclRef.getDecl(), newTargetSubst)); @@ -6248,7 +6245,7 @@ namespace Slang ThisTypeSubstitution* newExtSubst = m_astBuilder->create<ThisTypeSubstitution>(); newExtSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; newExtSubst->witness = appThisTypeSubst->witness; - newExtSubst->outer = extDeclRef.substitutions.substitutions; + newExtSubst->outer = extDeclRef.getSubst(); extDeclRef = m_astBuilder->getSpecializedDeclRef<ExtensionDecl>( extDeclRef.getDecl(), @@ -6390,7 +6387,7 @@ namespace Slang { if( auto namespaceDeclRef = declRefExpr->declRef.as<NamespaceDeclBase>() ) { - SLANG_ASSERT(!namespaceDeclRef.substitutions.substitutions); + SLANG_ASSERT(!namespaceDeclRef.getSubst()); namespaceDecl = namespaceDeclRef.getDecl(); } } @@ -6649,7 +6646,7 @@ namespace Slang if (auto declRefExpr = as<DeclRefExpr>(primalSubst->funcExpr)) { if (auto primalSubstFunc = declRefExpr->declRef.as<FunctionDeclBase>()) - return _getFuncDifferentiableLevelImpl(primalSubstFunc, recurseLimit - 1); + return _getFuncDifferentiableLevelImpl(primalSubstFunc.getDecl(), recurseLimit - 1); } } } @@ -6713,7 +6710,7 @@ namespace Slang SemanticsVisitor* semantics, DeclRef<ContainerDecl> const& containerDeclRef, SyntaxClassBase const& syntaxClass, - void (*callback)(DeclRefBase, void*), + void (*callback)(DeclRefBase*, void*), void const* userData) { // We are being asked to invoke the given callback on @@ -6725,7 +6722,7 @@ namespace Slang // for( auto memberDeclRef : getMembers(semantics->getASTBuilder(), containerDeclRef)) { - if( memberDeclRef.decl->getClass().isSubClassOfImpl(syntaxClass) ) + if( memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) { callback(memberDeclRef, (void*)userData); } @@ -6757,7 +6754,7 @@ namespace Slang for( auto memberDeclRef : getMembers(semantics->getASTBuilder(), extDeclRef) ) { - if( memberDeclRef.decl->getClass().isSubClassOfImpl(syntaxClass) ) + if( memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) { callback(memberDeclRef, (void*)userData); } @@ -6858,7 +6855,7 @@ namespace Slang { if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef)) return as<Type>(concreteType); - for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst=subst->outer) + for (auto subst = declRefType->declRef.getSubst(); subst; subst=subst->outer) { if (auto genericSubst = as<GenericSubstitution>(subst)) { @@ -7022,7 +7019,7 @@ namespace Slang for (auto param : func->getParameters()) { auto arg = astBuilder->create<VarExpr>(); - arg->declRef.decl = param; + arg->declRef = makeDeclRef(param); arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; @@ -7038,7 +7035,7 @@ namespace Slang for (auto param : originalFuncDecl->getParameters()) { auto arg = visitor->getASTBuilder()->create<VarExpr>(); - arg->declRef.decl = param; + arg->declRef = makeDeclRef(param); arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; @@ -7076,7 +7073,7 @@ namespace Slang for (auto param : originalFuncDecl->getParameters()) { auto arg = visitor->getASTBuilder()->create<VarExpr>(); - arg->declRef.decl = param; + arg->declRef = makeDeclRef(param); arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; @@ -7216,7 +7213,7 @@ namespace Slang auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); derivativeAttr->loc = derivativeOfAttr->loc; auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = visitor->getASTBuilder()->getSpecializedDeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRef = makeDeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl)); auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); declRefExpr->type.type = nullptr; derivativeAttr->args.add(declRefExpr); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 11b560d93..6b050aa89 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -213,8 +213,8 @@ namespace Slang // to the chosen interface decl must be the first substitution on // the list (which is a linked list from the "inside" out). // - auto thisTypeSubst = as<ThisTypeSubstitution>(interfaceDeclRef.substitutions.substitutions); - if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.decl) + auto thisTypeSubst = as<ThisTypeSubstitution>(interfaceDeclRef.getSubst()); + if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.getDecl()) { // This isn't really an existential type, because somebody // has already filled in a this-type substitution. @@ -538,8 +538,8 @@ namespace Slang synthesizedDecl = structDecl; auto typeDef = m_astBuilder->create<TypeAliasDecl>(); typeDef->nameAndLoc.name = getName("Differential"); - auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this,m_astBuilder->getSpecializedDeclRef(structDecl, nullptr)); - typeDef->type.type = m_astBuilder->getOrCreateDeclRefType(declRef.decl, declRef.substitutions); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); + typeDef->type.type = m_astBuilder->getOrCreateDeclRefType(declRef); typeDef->parentDecl = structDecl; structDecl->members.add(typeDef); } @@ -1052,7 +1052,7 @@ namespace Slang { foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) { - auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); + auto subType = m_astBuilder->getOrCreateDeclRefType(member); maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType); }); foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) @@ -1061,7 +1061,7 @@ namespace Slang maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); }); } - for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer) + for (auto subst = declRefType->declRef.getSubst(); subst; subst = subst->outer) { if (auto genSubst = as<GenericSubstitution>(subst)) { @@ -1507,7 +1507,7 @@ namespace Slang if (isInterfaceRequirement(decl)) { - for (auto subst = declRef.substitutions.substitutions; subst; subst = subst->outer) + for (auto subst = declRef.getSubst(); subst; subst = subst->outer) { if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) { @@ -1524,7 +1524,7 @@ namespace Slang if (!getInitExpr(m_astBuilder, declRef)) return nullptr; - ensureDecl(declRef.decl, DeclCheckState::Checked); + ensureDecl(declRef.getDecl(), DeclCheckState::Checked); ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo); return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), &newCircularityInfo); } @@ -1577,9 +1577,7 @@ namespace Slang { Val* valResult = m_astBuilder->getOrCreate<GenericParamIntVal>( declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()), - m_astBuilder->getSpecializedDeclRef( - genericValParamRef.getDecl(), - genericValParamRef.substitutions.substitutions)); + genericValParamRef); valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); return as<IntVal>(valResult); } @@ -2475,7 +2473,7 @@ namespace Slang // Get inner function DeclRef<Decl> unspecializedInnerRef = astBuilder->getSpecializedDeclRef<Decl>( getInner(baseFuncGenericDeclRef), - baseFuncGenericDeclRef.substitutions); + baseFuncGenericDeclRef.getSubst()); auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>(); if (!callableDeclRef) return nullptr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index e710f93ec..0a99fcb97 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -172,7 +172,7 @@ namespace Slang // Look at a candidate definition to be called and // see if it gives us a key to work with. // - Decl* funcDecl = item.declRef.decl; + Decl* funcDecl = item.declRef.getDecl(); if (auto genDecl = as<GenericDecl>(funcDecl)) funcDecl = genDecl->inner; @@ -707,9 +707,9 @@ namespace Slang void ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext = nullptr); /// Helper routine allowing `ensureDecl` to be called on a `DeclRef` - void ensureDecl(DeclRefBase const& declRef, DeclCheckState state) + void ensureDecl(DeclRefBase* declRef, DeclCheckState state) { - ensureDecl(declRef.getDecl(), state); + ensureDecl(declRef->getDecl(), state); } /// Helper routine allowing `ensureDecl` to be used on a `DeclBase` diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index a72ca621f..423d1f6bb 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -156,7 +156,7 @@ namespace Slang { auto expr = context.originalExpr; - auto decl = candidate.item.declRef.decl; + auto decl = candidate.item.declRef.getDecl(); if(const auto prefixExpr = as<PrefixExpr>(expr)) { @@ -516,16 +516,14 @@ namespace Slang SLANG_ASSERT(subst); subst->genericDecl = genericDeclRef.getDecl(); - subst->outer = genericDeclRef.substitutions.substitutions; + subst->outer = genericDeclRef.getSubst(); List<Val*> newArgs = subst->getArgs(); for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { - auto subset = genericDeclRef.substitutions; - subset.substitutions = subst; DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef( - constraintDecl, subset); + constraintDecl, subst); auto sub = getSub(m_astBuilder, constraintDeclRef); auto sup = getSup(m_astBuilder, constraintDeclRef); @@ -545,7 +543,7 @@ namespace Slang } } - candidate.subst = m_astBuilder->getOrCreateGenericSubstitution(genericDeclRef.getDecl(), newArgs, genericDeclRef.substitutions.substitutions); + candidate.subst = m_astBuilder->getOrCreateGenericSubstitution(genericDeclRef.getDecl(), newArgs, genericDeclRef.getSubst()); // Done checking all the constraints, hooray. return true; @@ -596,7 +594,7 @@ namespace Slang } subst->genericDecl = baseGenericRef.getDecl(); - subst->outer = baseGenericRef.substitutions.substitutions; + subst->outer = baseGenericRef.getSubst(); DeclRef<Decl> innerDeclRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(baseGenericRef), subst); @@ -822,8 +820,8 @@ namespace Slang // directly (it is only visible through the requirement witness // information for inheritance declarations). // - bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef); - bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef); + bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef.getDecl()); + bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef.getDecl()); if(leftIsInterfaceRequirement != rightIsInterfaceRequirement) return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement); @@ -1233,7 +1231,7 @@ namespace Slang // use any substitutions that were in place for referring to the // generic itself. // - Substitutions* substForInnerDecl = genericDeclRef.substitutions; + Substitutions* substForInnerDecl = genericDeclRef.getSubst(); // // In the case where we have explicit/known arguments, // we will use those as our baseline substitutions. @@ -1274,7 +1272,7 @@ namespace Slang // if (valueArgCount > valueParamCount) { - return DeclRef<Decl>(nullptr); + return DeclRef<Decl>(); } // If any of the arguments were specified explicitly (and are thus known), @@ -1310,7 +1308,7 @@ namespace Slang else { // TODO(tfoley): any other cases needed here? - return DeclRef<Decl>(nullptr); + return DeclRef<Decl>(); } // Once we have added all the appropriate constraints to the system, we @@ -1338,7 +1336,7 @@ namespace Slang // diagnostics), or this code could have a "just trying" vs. "actually // do things" distinction like some other steps. // - return DeclRef<Decl>(nullptr); + return DeclRef<Decl>(); } // If we found a solution (that is, a set of argument values that satisfy @@ -1623,7 +1621,8 @@ namespace Slang while (auto hoInner = as<HigherOrderInvokeExpr>(inner)) { lastInner = hoInner; - hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type); + if (innerRef) + hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type); inner = hoInner->baseFunction; } // Set inner expression to resolved declref expr. diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index acefa7660..69e419f75 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -583,7 +583,7 @@ namespace Slang return varDecl->getName(); } - Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& paramDeclRef) + Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> paramDeclRef) { auto paramType = getType(astBuilder, paramDeclRef); if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>()) @@ -1207,17 +1207,15 @@ namespace Slang getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution( genericDeclRef.getDecl(), genericArgs, - genericDeclRef.substitutions.substitutions); + genericDeclRef.getSubst()); + ASTBuilder* astBuilder = getLinkage()->getASTBuilder(); - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) + for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>( + getLinkage()->getASTBuilder(), DeclRef<ContainerDecl>(genericDeclRef))) { - auto constraintSubst = genericDeclRef.substitutions; - constraintSubst.substitutions = genericSubst; + DeclRef<GenericTypeConstraintDecl> constraintDeclRef = astBuilder->getSpecializedDeclRef( + constraintDecl.getDecl(), genericSubst); - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = getLinkage()->getASTBuilder()->getSpecializedDeclRef( - constraintDecl, constraintSubst); - - ASTBuilder* astBuilder = getLinkage()->getASTBuilder(); auto sub = getSub(astBuilder, constraintDeclRef); auto sup = getSup(astBuilder, constraintDeclRef); @@ -1239,8 +1237,8 @@ namespace Slang getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution( genericDeclRef.getDecl(), genericArgs, - genericDeclRef.substitutions.substitutions); - specializedFuncDeclRef.substitutions.substitutions = genericSubst; + genericDeclRef.getSubst()); + specializedFuncDeclRef = astBuilder->getSpecializedDeclRef(specializedFuncDeclRef.getDecl(), genericSubst); } info->specializedFuncDeclRef = specializedFuncDeclRef; diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 5f246cc15..6453f68ab 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -372,7 +372,7 @@ namespace Slang auto varDecl = as<VarDecl>(varStmt->decl); if (!varDecl) return; - initialVar.decl = varDecl; + initialVar = makeDeclRef<Decl>(varDecl); initialVal = varDecl->initExpr; } else if (auto exprStmt = as<ExpressionStmt>(stmt->initialStatement)) diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index a5a46c435..d62d60db4 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -188,12 +188,9 @@ namespace Slang } GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getDecl(), evaledArgs, genericDeclRef.substitutions.substitutions); - - DeclRef<Decl> innerDeclRef; - innerDeclRef.decl = getInner(genericDeclRef); - innerDeclRef.substitutions = SubstitutionSet(subst); + genericDeclRef.getDecl(), evaledArgs, genericDeclRef.getSubst()); + DeclRef<Decl> innerDeclRef = m_astBuilder->getSpecializedDeclRef(getInner(genericDeclRef), subst); return DeclRefType::create(m_astBuilder, innerDeclRef); } diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index a439b4454..9e9efeb64 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -271,7 +271,7 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) List<ASTPrinter::Part> parts; ASTPrinter printer(m_astBuilder, ASTPrinter::OptionFlag::ParamNames, &parts); - printer.addDeclSignature(m_astBuilder->getSpecializedDeclRef<Decl>(callableDecl, nullptr)); + printer.addDeclSignature(makeDeclRef(callableDecl)); Signature signature; getSignature(parts, signature); diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp index d52f631bd..ab6d8b5ab 100644 --- a/source/slang/slang-language-server-semantic-tokens.cpp +++ b/source/slang/slang-language-server-semantic-tokens.cpp @@ -62,7 +62,7 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS return; SemanticToken token = _createSemanticToken(manager, declRef->loc, declRef->name); - auto target = declRef->declRef.decl; + auto target = declRef->declRef.getDecl(); if (as<AggTypeDecl>(target)) { if (target->hasModifier<BuiltinTypeModifier>()) diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index db50f9bdb..e79716975 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -707,11 +707,11 @@ SlangResult LanguageServer::hover( } else if (auto decl = as<Decl>(leafNode)) { - fillDeclRefHoverInfo(version->linkage->getASTBuilder()->getSpecializedDeclRef(decl, nullptr)); + fillDeclRefHoverInfo(makeDeclRef(decl)); } else if (auto attr = as<Attribute>(leafNode)) { - fillDeclRefHoverInfo(version->linkage->getASTBuilder()->getSpecializedDeclRef(attr->attributeDecl, nullptr)); + fillDeclRefHoverInfo(makeDeclRef(attr->attributeDecl)); } if (sb.getLength() == 0) { @@ -1320,7 +1320,7 @@ SlangResult LanguageServer::signatureHelp( // Look for initializers for (auto member : aggDecl->getMembersOfType<ConstructorDecl>()) { - addDeclRef(version->linkage->getASTBuilder()->getSpecializedDeclRef<Decl>(member, declRefExpr->declRef.substitutions)); + addDeclRef(version->linkage->getASTBuilder()->getSpecializedDeclRef<Decl>(member, declRefExpr->declRef.getSubst())); } } else diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 46977b71d..b16671efb 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -176,7 +176,7 @@ static void _lookUpDirectAndTransparentMembers( AddToLookupResult( result, CreateLookupResultItem( - astBuilder->getSpecializedDeclRef<Decl>(member, containerDeclRef.substitutions), inBreadcrumbs)); + astBuilder->getSpecializedDeclRef<Decl>(member, containerDeclRef.getSubst()), inBreadcrumbs)); } } else @@ -201,7 +201,7 @@ static void _lookUpDirectAndTransparentMembers( continue; // The declaration passed the test, so add it! - AddToLookupResult(result, CreateLookupResultItem(astBuilder->getSpecializedDeclRef<Decl>(m, containerDeclRef.substitutions), inBreadcrumbs)); + AddToLookupResult(result, CreateLookupResultItem(astBuilder->getSpecializedDeclRef<Decl>(m, containerDeclRef.getSubst()), inBreadcrumbs)); } } @@ -211,7 +211,7 @@ static void _lookUpDirectAndTransparentMembers( { // The reference to the transparent member should use whatever // substitutions we used in referring to its outer container - DeclRef<Decl> transparentMemberDeclRef = astBuilder->getSpecializedDeclRef(transparentInfo.decl, containerDeclRef.substitutions); + DeclRef<Decl> transparentMemberDeclRef = astBuilder->getSpecializedDeclRef(transparentInfo.decl, containerDeclRef.getSubst()); // We need to leave a breadcrumb so that we know that the result // of lookup involves a member lookup step here @@ -320,7 +320,7 @@ static Type* _maybeSpecializeSuperType( ThisTypeSubstitution* thisTypeSubst = astBuilder->create<ThisTypeSubstitution>(); thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); thisTypeSubst->witness = subIsSuperWitness; - thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; + thisTypeSubst->outer = superInterfaceDeclRef.getSubst(); auto specializedInterfaceDeclRef = astBuilder->getSpecializedDeclRef<Decl>(superInterfaceDeclRef.getDecl(), thisTypeSubst); @@ -403,7 +403,7 @@ static void _lookUpMembersInSuperType( { if( request.semantics ) { - ensureDecl(request.semantics, intermediateIsSuperConstraint, DeclCheckState::CanUseBaseOfInheritanceDecl); + ensureDecl(request.semantics, intermediateIsSuperConstraint.getDecl(), DeclCheckState::CanUseBaseOfInheritanceDecl); } // The super-type in the constraint (e.g., `Foo` in `T : Foo`) @@ -450,11 +450,11 @@ static void _lookUpMembersInSuperTypeDeclImpl( auto genericDeclRef = genericTypeParamDeclRef.getParent(astBuilder).as<GenericDecl>(); assert(genericDeclRef); - for(auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDeclRef)) + for(auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, DeclRef<ContainerDecl>(genericDeclRef))) { if( semantics ) { - ensureDecl(semantics, constraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); + ensureDecl(semantics, constraintDeclRef.getDecl(), DeclCheckState::CanUseBaseOfInheritanceDecl); } // Does this constraint pertain to the type we are working on? diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index babadb0f5..cf564605c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -803,7 +803,7 @@ LoweredValInfo emitCallToDeclRef( if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() ) { - if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.decl) && !as<InterfaceDecl>(ctorDeclRef.decl->parentDecl)) + if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.getDecl()) && !as<InterfaceDecl>(ctorDeclRef.getParent(context->astBuilder).getDecl())) { SLANG_UNREACHABLE("stdlib error: __init() has no definition."); } @@ -1398,7 +1398,7 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy { if (auto declRefType = as<DeclRefType>(typeConstraint->sub.type)) { - if (declRefType->declRef.decl == genericParamDecl) + if (declRefType->declRef.getDecl() == genericParamDecl) { supTypes.add(lowerType(context, typeConstraint->getSup().type)); } @@ -1531,7 +1531,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->midToSup)) { - midToSup = getInterfaceRequirementKey(context, declaredMidToSup->declRef.decl); + midToSup = getInterfaceRequirementKey(context, declaredMidToSup->declRef.getDecl()); } else { @@ -2049,7 +2049,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower List<IRInst*> operands; // If there are any substitutions attached to the declRef, // add them as operands of the IR type. - _collectSubstitutionArgs(operands, type->declRef.substitutions.substitutions); + _collectSubstitutionArgs(operands, type->declRef.getSubst()); return getBuilder()->getType( op, static_cast<UInt>(operands.getCount()), @@ -2821,9 +2821,8 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, SemanticsVisitor* semantics, Decl* decl) { - DeclRef<Decl> declRef; - declRef.decl = decl; - declRef.substitutions = createDefaultSubstitutions(context->astBuilder, semantics, decl); + DeclRef<Decl> declRef = context->astBuilder->getSpecializedDeclRef( + decl, createDefaultSubstitutions(context->astBuilder, semantics, decl)); return declRef; } // @@ -2964,8 +2963,8 @@ IRLoweringParameterInfo getParameterInfo( IRLoweringParameterInfo info; info.type = getParamType(context->astBuilder, paramDecl); - info.decl = paramDecl; - info.direction = getParameterDirection(paramDecl); + info.decl = paramDecl.getDecl(); + info.direction = getParameterDirection(paramDecl.getDecl()); info.isThisParam = false; return info; } @@ -3039,13 +3038,13 @@ void collectParameterLists( // the outer declaration. The most important question here // is whether parameters of the outer declaration should // also count as parameters of the inner declaration. - ParameterListCollectMode innerMode = getModeForCollectingParentParameters(declRef, parentDeclRef); + ParameterListCollectMode innerMode = getModeForCollectingParentParameters(declRef.getDecl(), parentDeclRef.getDecl()); // Don't down-grade our `static`-ness along the chain. if(innerMode < mode) innerMode = mode; - ParameterDirection innerThisParamDirection = getThisParamDirection(declRef, thisParamDirection); + ParameterDirection innerThisParamDirection = getThisParamDirection(declRef.getDecl(), thisParamDirection); // Now collect any parameters from the parent declaration itself @@ -8010,7 +8009,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type)) { - auto typeParamDeclVal = subContext->findLoweredDecl(declRefType->declRef.decl); + auto typeParamDeclVal = subContext->findLoweredDecl(declRefType->declRef.getDecl()); SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val); subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType); } @@ -9531,8 +9530,8 @@ LoweredValInfo emitDeclRef( { return emitDeclRef( context, - declRef.decl, - declRef.substitutions.substitutions, + declRef.getDecl(), + declRef.getSubst(), type); } @@ -10008,7 +10007,7 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor auto shaderParam = module->getShaderParam(ii); auto specializationArgCount = shaderParam.specializationParamCount; - IRInst* irParam = getSimpleVal(context, ensureDecl(context, shaderParam.paramDeclRef)); + IRInst* irParam = getSimpleVal(context, ensureDecl(context, shaderParam.paramDeclRef.getDecl())); List<IRInst*> irSlotArgs; // Tracks if there are any type args that is not an IRDynamicType. bool hasConcreteTypeArg = false; @@ -10223,7 +10222,7 @@ IRTypeLayout* lowerTypeLayout( // so that if we run into another type layout for the // same entry point we will re-use the same keys. // - if( !context->mapEntryPointParamToKey.tryGetValue(paramDecl, irFieldKey) ) + if( !context->mapEntryPointParamToKey.tryGetValue(paramDecl.getDecl(), irFieldKey)) { irFieldKey = context->irBuilder->createStructKey(); @@ -10240,13 +10239,13 @@ IRTypeLayout* lowerTypeLayout( // of these keys will be local to a single `IREntryPointLayout`, // and we don't support combination at a finer granularity than that. - context->mapEntryPointParamToKey.add(paramDecl, irFieldKey); + context->mapEntryPointParamToKey.add(paramDecl.getDecl(), irFieldKey); } } else { irFieldKey = getSimpleVal(context, - ensureDecl(context, fieldDecl)); + ensureDecl(context, fieldDecl.getDecl())); } SLANG_ASSERT(irFieldKey); @@ -10465,7 +10464,7 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink) // has been emitted to this module, so that we will have something // to decorate. // - auto irVar = getSimpleVal(context, ensureDecl(context, varDecl)); + auto irVar = getSimpleVal(context, ensureDecl(context, varDecl.getDecl())); auto irLayout = lowerVarLayout(context, varLayout); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index ee34358ab..de1b58999 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -423,7 +423,7 @@ namespace Slang // There are two cases here: either we have specializations // in place for the parent generic declaration, or we don't. - auto subst = findInnerMostGenericSubstitution(declRef.substitutions); + auto subst = findInnerMostGenericSubstitution(declRef.getSubst()); if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() ) { // This is the case where we *do* have substitutions. @@ -484,7 +484,7 @@ namespace Slang for (auto type : constraint.value) { emitRaw(context, "C"); - emitQualifiedName(context, context->astBuilder->getSpecializedDeclRef(constraint.key, nullptr)); + emitQualifiedName(context, makeDeclRef(constraint.key)); emitType(context, type); } } @@ -531,6 +531,7 @@ namespace Slang // are asked to mangle the name of a `typedef`? auto decl = declRef.getDecl(); + if (!decl) return; // Handle `__extern_cpp` modifier by simply emitting // the given name. @@ -568,11 +569,11 @@ namespace Slang // mangling the generic and the inner entity emitRaw(context, "G"); - SLANG_ASSERT(genericDecl.substitutions == nullptr); + SLANG_ASSERT(genericDecl.getSubst() == nullptr); - auto innerDecl = makeDeclRef(getInner(genericDecl)); + auto innerDecl = getInner(genericDecl); - emitQualifiedName(context, innerDecl); + emitQualifiedName(context, makeDeclRef(innerDecl)); return; } else if (as<ForwardDerivativeRequirementDecl>(decl)) @@ -588,17 +589,16 @@ namespace Slang emitQualifiedName(context, declRef); } - String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const& declRef) + static String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const& declRef) { ManglingContext context(astBuilder); mangleName(&context, declRef); return context.sb.produceString(); } - String getMangledName(ASTBuilder* astBuilder, DeclRefBase const & declRef) + String getMangledName(ASTBuilder* astBuilder, DeclRefBase* declRef) { - return getMangledName(astBuilder, - astBuilder->getSpecializedDeclRef<Decl>(declRef.decl, declRef.substitutions)); + return getMangledName(astBuilder, DeclRef<Decl>(declRef)); } String getMangledName(ASTBuilder* astBuilder, Decl* decl) diff --git a/source/slang/slang-mangle.h b/source/slang/slang-mangle.h index 723b7250e..e28a6f09f 100644 --- a/source/slang/slang-mangle.h +++ b/source/slang/slang-mangle.h @@ -11,8 +11,7 @@ namespace Slang struct IRSpecialize; String getMangledName(ASTBuilder* astBuilder, Decl* decl); - String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const & declRef); - String getMangledName(ASTBuilder* astBuilder, DeclRefBase const & declRef); + String getMangledName(ASTBuilder* astBuilder, DeclRefBase* declRef); String getMangledNameFromNameString(const UnownedStringSlice& name); String getHashedName(const UnownedStringSlice& mangledName); diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 4679e58c3..e22bc0597 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -960,7 +960,7 @@ static void addExplicitParameterBindings_HLSL( // TODO: warning here! } - addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count); + addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), semanticInfo, count); } } @@ -1048,7 +1048,7 @@ static void addExplicitParameterBindings_GLSL( auto count = resInfo->count; semanticInfo.kind = kind; - addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count); + addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), semanticInfo, count); return; } @@ -1071,7 +1071,7 @@ static void addExplicitParameterBindings_GLSL( // // TODO(JS): I suppose there is some ambiguity here, because if we did a semantic lookup, and it didn't have a vulkanKind // or didn't parse correctly we wouldn't issue this message. - getSink(context)->diagnose(varDecl, Diagnostics::cannotInferVulkanBindingWithoutRegisterModifier, varDecl); + getSink(context)->diagnose(varDecl.getDecl(), Diagnostics::cannotInferVulkanBindingWithoutRegisterModifier, varDecl); return; } @@ -1098,7 +1098,7 @@ static void addExplicitParameterBindings_GLSL( { // If we made it here, there are shift options, but there isn't one for the space/kind specified // That could be a problem and unexpected, so issue a warning - getSink(context)->diagnose(varDecl, Diagnostics::hlslToVulkanMappingNotFound, varDecl); + getSink(context)->diagnose(varDecl.getDecl(), Diagnostics::hlslToVulkanMappingNotFound, varDecl); return; } @@ -1111,7 +1111,7 @@ static void addExplicitParameterBindings_GLSL( const LayoutSize count = resInfo->count; - addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count); + addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), semanticInfo, count); } // Given a single parameter, collect whatever information we have on @@ -2071,7 +2071,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( auto fieldResInfo = fieldVarLayout->FindResourceInfo(kind); if( !fieldResInfo ) { - if(!firstImplicit) firstImplicit = field; + if(!firstImplicit) firstImplicit = field.getDecl(); // In the implicit-layout case, we assign the field // the next available offset after the fields that @@ -2083,7 +2083,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( } else { - if(!firstExplicit) firstExplicit = field; + if(!firstExplicit) firstExplicit = field.getDecl(); // In the explicit case, the field already has offset // information, and we just need to update the computed @@ -2108,7 +2108,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( if( auto concreteType = findGlobalGenericSpecializationArg( layoutContext, - globalGenericParamDecl) ) + globalGenericParamDecl.getDecl()) ) { // If we know what concrete type has been used to specialize // the global generic type parameter, then we should use @@ -2134,7 +2134,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( // to the generic type, since we can't know how many "slots" // of varying input/output it would consume. // - return createTypeLayoutForGlobalGenericTypeParam(layoutContext, type, globalGenericParamDecl); + return createTypeLayoutForGlobalGenericTypeParam(layoutContext, type, globalGenericParamDecl.getDecl()); } } else if (auto associatedTypeParam = declRef.as<AssocTypeDecl>()) @@ -2726,7 +2726,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( // Any generic specialization applied to the entry-point function // must also be applied to its parameters. - paramDeclRef.substitutions = entryPointFuncDeclRef.substitutions; + paramDeclRef = context->getASTBuilder()->getSpecializedDeclRef(paramDeclRef.getDecl(), entryPointFuncDeclRef.getSubst()); // When computing layout for an entry-point parameter, // we want to make sure that the layout context has access @@ -3782,7 +3782,7 @@ RefPtr<ProgramLayout> generateParameterBindings( if( varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) { needDefaultConstantBuffer = true; - diagnoseGlobalUniform(&sharedContext, varLayout->varDecl); + diagnoseGlobalUniform(&sharedContext, varLayout->varDecl.getDecl()); } } } diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 07857c293..7a79e9fcd 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -433,7 +433,7 @@ SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inTyp if( auto structDeclRef = declRef.as<StructDecl>()) { return (unsigned int)getFields( - getModule(declRef.decl)->getLinkage()->getASTBuilder(), + getModule(declRef.getDecl())->getLinkage()->getASTBuilder(), structDeclRef, MemberFilterStyle::Instance) .getCount(); @@ -456,7 +456,7 @@ SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflect if( auto structDeclRef = declRef.as<StructDecl>()) { auto fields = getFields( - getModule(declRef)->getLinkage()->getASTBuilder(), structDeclRef, MemberFilterStyle::Instance); + getModule(declRef.getDecl())->getLinkage()->getASTBuilder(), structDeclRef, MemberFilterStyle::Instance); auto fieldDeclRef = fields[index]; return (SlangReflectionVariable*) fieldDeclRef.getDecl(); } @@ -924,7 +924,7 @@ SLANG_API SlangInt spReflectionTypeLayout_findFieldIndexByName(SlangReflectionTy for(Index f = 0; f < fieldCount; ++f) { auto field = structTypeLayout->fields[f]; - if(getReflectionName(field->varDecl)->text.getUnownedSlice() == name) + if(getReflectionName(field->varDecl.getDecl())->text.getUnownedSlice() == name) return f; } } diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index 0412ef4da..c28c8a6d6 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -39,52 +39,9 @@ struct SerialTypeInfo<SyntaxClass<T>> } }; -// All the templates for DeclRef<T> can use this implementation. -struct SerialDeclRefBaseTypeInfo -{ - typedef DeclRefBase NativeType; - struct SerialType - { - SerialIndex substitutions; - SerialIndex decl; - }; - enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SerialType& serial = *(SerialType*)outSerial; - const NativeType& native = *(const NativeType*)inNative; - - serial.decl = writer->addPointer(native.decl); - serial.substitutions = writer->addPointer(native.substitutions.substitutions); - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - DeclRefBase& native = *(DeclRefBase*)(outNative); - const SerialType& serial = *(const SerialType*)inSerial; - - native.decl = reader->getPointer(serial.decl).dynamicCast<Decl>(); - native.substitutions.substitutions = reader->getPointer(serial.substitutions).dynamicCast<Substitutions>(); - } - static const SerialFieldType* getFieldType() - { - static const SerialFieldType type = { sizeof(SerialType), uint8_t(SerialAlignment), &toSerial, &toNative }; - return &type; - } -}; -// Special case DeclRef, because it always uses the same type -template <typename T> -struct SerialGetFieldType<DeclRef<T>> -{ - static const SerialFieldType* getFieldType() { return SerialDeclRefBaseTypeInfo::getFieldType(); } -}; - template <typename T> -struct SerialTypeInfo<DeclRef<T>> : public SerialDeclRefBaseTypeInfo {}; - -template<> -struct SerialTypeInfo<DeclRefBase> : public SerialDeclRefBaseTypeInfo {}; +struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {}; // MatrixCoord can just go as is template <> diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 09329689e..fbb30e4eb 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -20,6 +20,11 @@ void printDiagnosticArg(StringBuilder& sb, Decl* decl) sb << getText(decl->getName()); } +void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase) +{ + printDiagnosticArg(sb, declRefBase->getDecl()); +} + void printDiagnosticArg(StringBuilder& sb, Type* type) { if (!type) @@ -64,6 +69,12 @@ SourceLoc getDiagnosticPos(TypeExp const& typeExp) return typeExp.exp->loc; } +SourceLoc getDiagnosticPos(DeclRefBase* declRef) +{ + if (!declRef) + return SourceLoc(); + return declRef->getDecl()->loc; +} // !!!!!!!!!!!!!!!!!!!!!!!!!!!!! Free functions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -327,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // So, in order to get the *right* end result, we need to apply // the substitutions from the inheritance decl-ref to the witness. // - requirementWitness = requirementWitness.specialize(astBuilder, inheritanceDeclRef.substitutions); + requirementWitness = requirementWitness.specialize(astBuilder, inheritanceDeclRef.getSubst()); return requirementWitness; } @@ -338,14 +349,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->midToSup)) { auto midKey = declaredSubtypeWitnessMidToSup->declRef; - auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey); + auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey.getDecl()); if (midWitness.getFlavor() == RequirementWitness::Flavor::witnessTable) { auto table = midWitness.getWitnessTable(); RequirementWitness result; if (table->requirementDictionary.tryGetValue(requirementKey, result)) { - result = result.specialize(astBuilder, midKey.substitutions); + result = result.specialize(astBuilder, midKey.getSubst()); } return result; } @@ -436,7 +447,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // We have a generic ancestor, but do we have an substitutions for it? GenericSubstitution* foundSubst = nullptr; - for(auto s = declRef.substitutions.substitutions; s; s = s->outer) + for(auto s = declRef.getSubst(); s; s = s->outer) { auto genSubst = as<GenericSubstitution>(s); if(!genSubst) @@ -489,7 +500,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>()) { GenericSubstitution* subst = nullptr; - for(auto s = declRef.substitutions.substitutions; s; s = s->outer) + for(auto s = declRef.getSubst(); s; s = s->outer) { if(auto genericSubst = as<GenericSubstitution>(s)) { @@ -581,7 +592,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt else if (magicMod->magicName == #n) \ { \ auto type = astBuilder->getOrCreateWithDefaultCtor<T>( \ - declRef.decl, declRef.substitutions.substitutions); \ + declRef.getDecl(), declRef.getSubst()); \ type->declRef = declRef; \ return type; \ } @@ -663,7 +674,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else { - return astBuilder->getOrCreateDeclRefType(declRef.decl, declRef.substitutions.substitutions); + return astBuilder->getOrCreateDeclRefType(declRef.declRefBase); } } @@ -697,13 +708,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return Slang::as<Type>(type->substitute(astBuilder, substitutions)); } - DeclRefBase DeclRefBase::substitute(ASTBuilder* astBuilder, DeclRefBase declRef) const + DeclRefBase* DeclRefBase::substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const { if(!substitutions) return declRef; int diff = 0; - return declRef.substituteImpl(astBuilder, substitutions, &diff); + return declRef->substituteImpl(astBuilder, substitutions, &diff); } SubstExpr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const @@ -723,7 +734,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto declRefBase = declRef.substituteImpl(astBuilder, substs, &diff); - return astBuilder->getSpecializedDeclRef<Decl>(declRefBase.decl, declRefBase.substitutions); + return astBuilder->getSpecializedDeclRef<Decl>(declRefBase.getDecl(), declRefBase.getSubst()); } Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type) @@ -944,11 +955,11 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - DeclRefBase DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const + DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const { // Nothing to do when we have no declaration. if(!decl) - return *this; + return const_cast<DeclRefBase*>(this); // Apply the given substitutions to any specializations // that have already been applied to this declaration. @@ -957,18 +968,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto substSubst = specializeSubstitutions( astBuilder, decl, - substitutions.substitutions, + substitutions, substSet.substitutions, &diff); if (!diff) - return *this; + return const_cast<DeclRefBase*>(this); *ioDiff += diff; - DeclRefBase substDeclRef; - substDeclRef.decl = decl; - substDeclRef.substitutions = substSubst; + DeclRefBase* substDeclRef = astBuilder->getSpecializedDeclRef(decl, substSubst); // TODO: The old code here used to try to translate a decl-ref // to an associated type in a decl-ref for the concrete type @@ -980,13 +989,21 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return substDeclRef; } + bool DeclRefBase::_equalsValOverride(Val* val) + { + if (auto otherDeclRef = as<DeclRefBase>(val)) + return equals(otherDeclRef); + return false; + } // Check if this is an equivalent declaration reference to another - bool DeclRefBase::equals(DeclRefBase const& declRef) const + bool DeclRefBase::equals(DeclRefBase* declRef) const { - if (decl != declRef.decl) + if (!declRef) + return false; + if (decl != declRef->decl) return false; - if (!substitutions.equals(declRef.substitutions)) + if (!SubstitutionSet(substitutions).equals(declRef->substitutions)) return false; return true; @@ -1006,7 +1023,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return decl->loc; } - DeclRefBase DeclRefBase::getParent(ASTBuilder* astBuilder) const + DeclRefBase* DeclRefBase::getParent(ASTBuilder* astBuilder) const { // Want access to the free function (the 'as' method by default gets priority) // Can access as method with this->as because it removes any ambiguity. @@ -1014,11 +1031,11 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto parentDecl = decl->parentDecl; if (!parentDecl) - return DeclRefBase(); + return nullptr; // Default is to apply the same set of substitutions/specializations // to the parent declaration as were applied to the child. - Substitutions* substToApply = substitutions.substitutions; + Substitutions* substToApply = substitutions; if(auto interfaceDecl = as<InterfaceDecl>(decl)) { @@ -1059,7 +1076,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt HashCode DeclRefBase::getHashCode() const { - return combineHash(PointerHash<1>::getHashCode(decl), substitutions.getHashCode()); + return combineHash(PointerHash<1>::getHashCode(decl), SubstitutionSet(substitutions).getHashCode()); } // IntVal @@ -1080,12 +1097,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt Type* HLSLPatchType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } IntVal* HLSLPatchType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); } // MeshOutputType @@ -1096,12 +1113,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt Type* MeshOutputType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); } IntVal* MeshOutputType::getMaxElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); } // Constructors for types @@ -1212,7 +1229,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto substAssocTypeDecl = substDeclRef.getDecl(); - for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) + for (auto s = substDeclRef.getSubst(); s; s = s->outer) { auto thisSubst = as<ThisTypeSubstitution>(s); if (!thisSubst) @@ -1252,7 +1269,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub); if (!innerDeclRefType) return nullptr; - auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier<BuiltinRequirementModifier>(); + auto innerBuiltinReq = innerDeclRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>(); if (!innerBuiltinReq) return nullptr; if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType) @@ -1281,7 +1298,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } // Prints a partially qualified type name with generic substitutions. - void _printNestedDecl(const Substitutions* substitutions, Decl* decl, StringBuilder& out) + void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out) { // If there is a parent scope for the declaration, print it first. // Exclude top-level namespaces like `tu0` or `core`. @@ -1307,7 +1324,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // type instead. for (;;) { - if (auto interfaceDecl = as<InterfaceDecl>(decl)) + if (auto interfaceDecl = const_cast<InterfaceDecl*>(as<InterfaceDecl>(decl))) { if (auto thisSubst = findThisTypeSubstitution(substitutions, interfaceDecl)) { diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 44bd1743f..a63a2471c 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -52,22 +52,22 @@ namespace Slang DeclRef<AggTypeDecl> const& declRef, SemanticsVisitor* semantics); - inline FilteredMemberRefList<Decl> getMembers(ASTBuilder* astBuilder, DeclRef<ContainerDecl> const& declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) + inline FilteredMemberRefList<Decl> getMembers(ASTBuilder* astBuilder, DeclRef<ContainerDecl> declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) { - return FilteredMemberRefList<Decl>(astBuilder, declRef.getDecl()->members, declRef.substitutions, filterStyle); + return FilteredMemberRefList<Decl>(astBuilder, declRef.getDecl()->members, declRef.getSubst(), filterStyle); } template<typename T> - inline FilteredMemberRefList<T> getMembersOfType(ASTBuilder* astBuilder, DeclRef<ContainerDecl> const& declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) + inline FilteredMemberRefList<T> getMembersOfType(ASTBuilder* astBuilder, DeclRef<ContainerDecl> declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) { - return FilteredMemberRefList<T>(astBuilder, declRef.getDecl()->members, declRef.substitutions, filterStyle); + return FilteredMemberRefList<T>(astBuilder, declRef.getDecl()->members, declRef.getSubst(), filterStyle); } void _foreachDirectOrExtensionMemberOfType( SemanticsVisitor* semantics, DeclRef<ContainerDecl> const& declRef, SyntaxClassBase const& syntaxClass, - void (*callback)(DeclRefBase, void*), + void (*callback)(DeclRefBase*, void*), void const* userData); DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst); @@ -82,9 +82,9 @@ namespace Slang { const F* userFunc; SemanticsVisitor* semanticsVisitor; - static void callback(DeclRefBase declRef, void* userData) + static void callback(DeclRefBase* declRef, void* userData) { - (*((*(Helper*)userData).userFunc))(_getSpecializedDeclRef(semanticsVisitorGetASTBuilder((*(Helper*)userData).semanticsVisitor), declRef.decl, declRef.substitutions).template as<T>()); + (*((*(Helper*)userData).userFunc))(DeclRef<T>(declRef)); } }; Helper helper; @@ -108,72 +108,72 @@ namespace Slang /// Name* getReflectionName(VarDeclBase* varDecl); - inline Type* getType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) + inline Type* getType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } /// same as getType, but take into account the additional type modifiers from the parameter's modifier list /// and return a ModifiedType if such modifiers exist. - Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& paramDeclRef); + Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> paramDeclRef); - inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) + inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->initExpr); } - inline Type* getType(ASTBuilder* astBuilder, DeclRef<PropertyDecl> const& declRef) + inline Type* getType(ASTBuilder* astBuilder, DeclRef<PropertyDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline Type* getType(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) + inline Type* getType(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline SubstExpr<Expr> getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) + inline SubstExpr<Expr> getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->tagExpr); } - inline Type* getTargetType(ASTBuilder* astBuilder, DeclRef<ExtensionDecl> const& declRef) + inline Type* getTargetType(ASTBuilder* astBuilder, DeclRef<ExtensionDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->targetType.Ptr()); } - inline FilteredMemberRefList<VarDecl> getFields(ASTBuilder* astBuilder, DeclRef<StructDecl> const& declRef, MemberFilterStyle filterStyle) + inline FilteredMemberRefList<VarDecl> getFields(ASTBuilder* astBuilder, DeclRef<StructDecl> declRef, MemberFilterStyle filterStyle) { return getMembersOfType<VarDecl>(astBuilder, declRef, filterStyle); } /// If the given `structTypeDeclRef` inherits from another struct type, return that base type - DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef); + DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> structTypeDeclRef); /// If the given `structTypeDeclRef` inherits from another struct type, return that base struct decl - DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef); + DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> structTypeDeclRef); - inline Type* getTagType(ASTBuilder* astBuilder, DeclRef<EnumDecl> const& declRef) + inline Type* getTagType(ASTBuilder* astBuilder, DeclRef<EnumDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->tagType); } - inline Type* getBaseType(ASTBuilder* astBuilder, DeclRef<InheritanceDecl> const& declRef) + inline Type* getBaseType(ASTBuilder* astBuilder, DeclRef<InheritanceDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->base.type); } - inline Type* getType(ASTBuilder* astBuilder, DeclRef<TypeDefDecl> const& declRef) + inline Type* getType(ASTBuilder* astBuilder, DeclRef<TypeDefDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline Type* getResultType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) + inline Type* getResultType(ASTBuilder* astBuilder, DeclRef<CallableDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->returnType.type); } - inline Type* getErrorCodeType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) + inline Type* getErrorCodeType(ASTBuilder* astBuilder, DeclRef<CallableDecl> declRef) { if (declRef.getDecl()->errorType.type) { @@ -185,12 +185,12 @@ namespace Slang } } - inline FilteredMemberRefList<ParamDecl> getParameters(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) + inline FilteredMemberRefList<ParamDecl> getParameters(ASTBuilder* astBuilder, DeclRef<CallableDecl> declRef) { return getMembersOfType<ParamDecl>(astBuilder, declRef); } - inline Decl* getInner(DeclRef<GenericDecl> const& declRef) + inline Decl* getInner(DeclRef<GenericDecl> declRef) { // TODO: Should really return a `DeclRef<Decl>` for the inner // declaration, and not just a raw pointer diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 1388eba73..c933c5bb5 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -3984,7 +3984,7 @@ static TypeLayoutResult _createTypeLayout( { if( auto concreteType = findGlobalGenericSpecializationArg( context, - globalGenericParamDecl) ) + globalGenericParamDecl.getDecl()) ) { // If we know what concrete type has been used to specialize // the global generic type parameter, then we should use @@ -3997,7 +3997,7 @@ static TypeLayoutResult _createTypeLayout( // Otherwise we must create a type layout that represents // the generic type parameter itself. // - return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl); + return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl.getDecl()); } } else if (auto assocTypeParam = declRef.as<AssocTypeDecl>()) diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5e0064226..daa9cda3b 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -4010,8 +4010,8 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor void collectReferencedModules(DeclRefBase const& declRef) { - collectReferencedModules(declRef.decl); - collectReferencedModules(declRef.substitutions); + collectReferencedModules(declRef.getDecl()); + collectReferencedModules(declRef.getSubst()); } void collectReferencedModules(Type* type) |
