diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2020-05-28 14:01:51 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-05-28 14:01:51 -0400 |
| commit | c2d31347ea06c768045e7c503ef0188e0e5356de (patch) | |
| tree | 1a4ee67aafca0a709ae691104023431bb6829825 /source | |
| parent | e5d0f3360f44a4cdd2390e7817db17bb3cc0dd04 (diff) | |
WIP: ASTBuilder (#1358)
* Compiles.
* Small tidy up around session/ASTBuilder.
* Tests are now passing.
* Fix Visual Studio project.
* Fix using new X to use builder when protectedness of Ctor is not enough.
Substitute->substitute
* Add some missing ast nodes created outside of ASTBuilder.
* Compile time check that ASTBuilder is making an AST type.
* Moced findClasInfo and findSyntaxClass (essentially the same thing) to SharedASTBuilder from Session.
Diffstat (limited to 'source')
39 files changed, 1304 insertions, 1224 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index ffa065323..9e99d008f 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -22,7 +22,10 @@ struct ReflectClassInfo; class NodeBase : public RefObject { SLANG_ABSTRACT_CLASS(NodeBase) - + + // By default AST types do *not* store the builder. This is called when constructed tho. + SLANG_FORCE_INLINE void setASTBuilder(ASTBuilder* astBuilder) { SLANG_UNUSED(astBuilder); } + SyntaxClass<NodeBase> getClass() { return SyntaxClass<NodeBase>(&getClassInfo()); } }; @@ -78,14 +81,14 @@ class Val : public NodeBase // construct a new value by applying a set of parameter // substitutions to this one - RefPtr<Val> substitute(SubstitutionSet subst); + RefPtr<Val> substitute(ASTBuilder* astBuilder, SubstitutionSet subst); // Lower-level interface for substitution. Like the basic // `Substitute` above, but also takes a by-reference // integer parameter that should be incremented when // returning a modified value (this can help the caller // decide whether they need to do anything). - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff); + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); virtual bool equalsVal(Val* val) = 0; virtual String toString() = 0; @@ -124,15 +127,20 @@ class Type: public Val void accept(ITypeVisitor* visitor, void* extra); -public: - Session* getSession() { return this->session; } - void setSession(Session* s) { this->session = s; } + /// Type derived types store the AST builder they were constructed on. The builder calls this function + /// after constructing + SLANG_FORCE_INLINE void setASTBuilder(ASTBuilder* astBuilder) { m_astBuilder = astBuilder; } + + /// Get the ASTBuilder that was used to construct this Type + SLANG_FORCE_INLINE ASTBuilder* getASTBuilder() const { return m_astBuilder; } + //Session* getSession() { return this->session; } + bool equals(Type* type); Type* getCanonicalType(); - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; virtual bool equalsVal(Val* val) override; @@ -145,7 +153,7 @@ protected: Type* canonicalType = nullptr; SLANG_UNREFLECTED - Session* session = nullptr; + ASTBuilder* m_astBuilder = nullptr; }; template <typename T> @@ -159,11 +167,14 @@ class Substitutions: public RefObject { SLANG_ABSTRACT_CLASS(Substitutions) + // By default AST types do *not* store the builder. This is called when constructed tho. + SLANG_FORCE_INLINE void setASTBuilder(ASTBuilder* astBuilder) { SLANG_UNUSED(astBuilder); } + // The next outer that this one refines. RefPtr<Substitutions> outer; // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) = 0; + virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) = 0; // Check if these are equivalent substitutions to another set virtual bool equals(Substitutions* subst) = 0; @@ -182,7 +193,7 @@ class GenericSubstitution : public Substitutions List<RefPtr<Val> > args; // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; + virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; // Check if these are equivalent substitutions to another set virtual bool equals(Substitutions* subst) override; @@ -212,7 +223,7 @@ class ThisTypeSubstitution : public Substitutions // The actual type that provides the lookup scope for an associated type // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; + virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; // Check if these are equivalent substitutions to another set virtual bool equals(Substitutions* subst) override; @@ -239,7 +250,7 @@ class GlobalGenericParamSubstitution : public Substitutions List<ConstraintArg> constraintArgs; // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; + virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; // Check if these are equivalent substitutions to another set virtual bool equals(Substitutions* subst) override; @@ -361,4 +372,5 @@ class Stmt : public ModifiableSyntaxNode void accept(IStmtVisitor* visitor, void* extra); }; + } // namespace Slang diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp new file mode 100644 index 000000000..7e40e52ec --- /dev/null +++ b/source/slang/slang-ast-builder.cpp @@ -0,0 +1,225 @@ +// slang-ast-builder.cpp +#include "slang-ast-builder.h" +#include <assert.h> + +#include "slang-compiler.h" + +namespace Slang { + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SharedASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +SharedASTBuilder::SharedASTBuilder() +{ +} + + +void SharedASTBuilder::init(Session* session) +{ + m_namePool = session->getNamePool(); + + // Save the associated session + m_session = session; + + // We just want as a place to store allocations of shared types + RefPtr<ASTBuilder> astBuilder(new ASTBuilder); + + astBuilder->m_sharedASTBuilder = this; + + memset(m_builtinTypes, 0, sizeof(m_builtinTypes)); + + m_errorType = m_astBuilder->create<ErrorType>(); + m_initializerListType = m_astBuilder->create<InitializerListType>(); + m_overloadedType = m_astBuilder->create<OverloadGroupType>(); + + m_astBuilder = astBuilder.detach(); + + // We can just iterate over the class pointers. + // NOTE! That this adds the names of the abstract classes too(!) + for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) + { + const ReflectClassInfo* info = ReflectClassInfo::getInfo(ASTNodeType(i)); + if (info) + { + m_sliceToTypeMap.Add(UnownedStringSlice(info->m_name), info); + Name* name = m_namePool->getName(String(info->m_name)); + m_nameToTypeMap.Add(name, info); + } + } +} + +const ReflectClassInfo* SharedASTBuilder::findClassInfo(const UnownedStringSlice& slice) +{ + const ReflectClassInfo* typeInfo; + return m_sliceToTypeMap.TryGetValue(slice, typeInfo) ? typeInfo : nullptr; +} + +SyntaxClass<RefObject> SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice) +{ + const ReflectClassInfo* typeInfo; + if (m_sliceToTypeMap.TryGetValue(slice, typeInfo)) + { + return SyntaxClass<RefObject>(typeInfo); + } + return SyntaxClass<RefObject>(); +} + +const ReflectClassInfo* SharedASTBuilder::findClassInfo(Name* name) +{ + const ReflectClassInfo* typeInfo; + return m_nameToTypeMap.TryGetValue(name, typeInfo) ? typeInfo : nullptr; +} + +SyntaxClass<RefObject> SharedASTBuilder::findSyntaxClass(Name* name) +{ + const ReflectClassInfo* typeInfo; + if (m_nameToTypeMap.TryGetValue(name, typeInfo)) + { + return SyntaxClass<RefObject>(typeInfo); + } + return SyntaxClass<RefObject>(); +} + +Type* SharedASTBuilder::getStringType() +{ + if (!m_stringType) + { + auto stringTypeDecl = findMagicDecl("StringType"); + m_stringType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(stringTypeDecl)); + } + return m_stringType; +} + +Type* SharedASTBuilder::getEnumTypeType() +{ + if (!m_enumTypeType) + { + auto enumTypeTypeDecl = findMagicDecl("EnumTypeType"); + m_enumTypeType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(enumTypeTypeDecl)); + } + return m_enumTypeType; +} + +SharedASTBuilder::~SharedASTBuilder() +{ + // Release built in types.. + for (Index i = 0; i < SLANG_COUNT_OF(m_builtinTypes); ++i) + { + m_builtinTypes[i].setNull(); + } + + if (m_astBuilder) + { + m_astBuilder->releaseReference(); + } +} + +void SharedASTBuilder::registerBuiltinDecl(RefPtr<Decl> decl, RefPtr<BuiltinTypeModifier> modifier) +{ + auto type = DeclRefType::create(m_astBuilder, DeclRef<Decl>(decl.Ptr(), nullptr)); + m_builtinTypes[Index(modifier->tag)] = type; +} + +void SharedASTBuilder::registerMagicDecl(RefPtr<Decl> decl, RefPtr<MagicTypeModifier> modifier) +{ + // In some cases the modifier will have been applied to the + // "inner" declaration of a `GenericDecl`, but what we + // actually want to register is the generic itself. + // + auto declToRegister = decl; + if (auto genericDecl = as<GenericDecl>(decl->parentDecl)) + declToRegister = genericDecl; + + m_magicDecls[modifier->name] = declToRegister.Ptr(); +} + +RefPtr<Decl> SharedASTBuilder::findMagicDecl(const String& name) +{ + return m_magicDecls[name].GetValue(); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder): + m_sharedASTBuilder(sharedASTBuilder) +{ + SLANG_ASSERT(sharedASTBuilder); +} + +ASTBuilder::ASTBuilder(): + m_sharedASTBuilder(nullptr) +{ +} + +RefPtr<PtrType> ASTBuilder::getPtrType( RefPtr<Type> valueType) +{ + return getPtrType(valueType, "PtrType").dynamicCast<PtrType>(); +} + +// Construct the type `Out<valueType>` +RefPtr<OutType> ASTBuilder::getOutType(RefPtr<Type> valueType) +{ + return getPtrType(valueType, "OutType").dynamicCast<OutType>(); +} + +RefPtr<InOutType> ASTBuilder::getInOutType(RefPtr<Type> valueType) +{ + return getPtrType(valueType, "InOutType").dynamicCast<InOutType>(); +} + +RefPtr<RefType> ASTBuilder::getRefType(RefPtr<Type> valueType) +{ + return getPtrType(valueType, "RefType").dynamicCast<RefType>(); +} + +RefPtr<PtrTypeBase> ASTBuilder::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName) +{ + auto genericDecl = m_sharedASTBuilder->findMagicDecl(ptrTypeName).dynamicCast<GenericDecl>(); + return getPtrType(valueType, genericDecl); +} + +RefPtr<PtrTypeBase> ASTBuilder::getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl) +{ + auto typeDecl = genericDecl->inner; + + auto substitutions = create<GenericSubstitution>(); + substitutions->genericDecl = genericDecl; + substitutions->args.add(valueType); + + auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions); + auto rsType = DeclRefType::create(this, declRef); + return as<PtrTypeBase>(rsType); +} + +RefPtr<ArrayExpressionType> ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) +{ + RefPtr<ArrayExpressionType> arrayType = create<ArrayExpressionType>(); + arrayType->baseType = elementType; + arrayType->arrayLength = elementCount; + return arrayType; +} + +RefPtr<VectorExpressionType> ASTBuilder::getVectorType( + RefPtr<Type> elementType, + RefPtr<IntVal> elementCount) +{ + auto vectorGenericDecl = m_sharedASTBuilder->findMagicDecl("Vector").as<GenericDecl>(); + + auto vectorTypeDecl = vectorGenericDecl->inner; + + auto substitutions = new GenericSubstitution(); + substitutions->genericDecl = vectorGenericDecl.Ptr(); + substitutions->args.add(elementType); + substitutions->args.add(elementCount); + + auto declRef = DeclRef<Decl>(vectorTypeDecl.Ptr(), substitutions); + + return DeclRefType::create(this, declRef).as<VectorExpressionType>(); +} + +RefPtr<TypeType> ASTBuilder::getTypeType(Type* type) +{ + return create<TypeType>(type); +} + + +} // namespace Slang diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h new file mode 100644 index 000000000..0cb853424 --- /dev/null +++ b/source/slang/slang-ast-builder.h @@ -0,0 +1,174 @@ +// slang-ast-dump.h +#ifndef SLANG_AST_BUILDER_H +#define SLANG_AST_BUILDER_H + +#include "slang-ast-support-types.h" +#include "slang-ast-all.h" + +#include "../core/slang-type-traits.h" + +namespace Slang +{ + +class SharedASTBuilder : public RefObject +{ + friend class ASTBuilder; +public: + + void registerBuiltinDecl(RefPtr<Decl> decl, RefPtr<BuiltinTypeModifier> modifier); + void registerMagicDecl(RefPtr<Decl> decl, RefPtr<MagicTypeModifier> modifier); + + /// Get the string type + Type* getStringType(); + /// Get the enum type type + Type* getEnumTypeType(); + + const ReflectClassInfo* findClassInfo(Name* name); + SyntaxClass<RefObject> findSyntaxClass(Name* name); + + const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice); + SyntaxClass<RefObject> findSyntaxClass(const UnownedStringSlice& slice); + + // Look up a magic declaration by its name + RefPtr<Decl> findMagicDecl(String const& name); + + /// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session. + NamePool* getNamePool() { return m_namePool; } + + /// Must be called before used + void init(Session* session); + + SharedASTBuilder(); + + ~SharedASTBuilder(); + +protected: + // State shared between ASTBuilders + + RefPtr<Type> m_errorType; + RefPtr<Type> m_initializerListType; + RefPtr<Type> m_overloadedType; + + // The following types are created lazily, such that part of their definition + // can be in the standard library + // + // Note(tfoley): These logically belong to `Type`, + // but order-of-declaration stuff makes that tricky + // + // TODO(tfoley): These should really belong to the compilation context! + // + RefPtr<Type> m_stringType; + RefPtr<Type> m_enumTypeType; + + RefPtr<Type> m_builtinTypes[Index(BaseType::CountOf)]; + + Dictionary<String, Decl*> m_magicDecls; + + Dictionary<UnownedStringSlice, const ReflectClassInfo*> m_sliceToTypeMap; + Dictionary<Name*, const ReflectClassInfo*> m_nameToTypeMap; + + NamePool* m_namePool = nullptr; + + // This is a private builder used for these shared types + ASTBuilder* m_astBuilder = nullptr; + Session* m_session = nullptr; +}; + +class ASTBuilder : public RefObject +{ + friend class SharedASTBuilder; +public: + + // For compile time check to see if thing being constructed is an AST type + template <typename T> + struct IsValidType + { + enum + { + Value = IsBaseOf<NodeBase, T>::Value || IsBaseOf<Substitutions, T>::Value + }; + }; + + /// Create AST type. + template <typename T> + T* create() { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); T* node = new T; node->setASTBuilder(this); return node; } + + template<typename T, typename P0> + T* create(const P0& p0) { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); T* node = new T(p0); node->setASTBuilder(this); return node;} + + template<typename T, typename P0, typename P1> + T* create(const P0& p0, const P1& p1) { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); T* node = new T(p0, p1); node->setASTBuilder(this); return node; } + + /// Get the built in types + SLANG_FORCE_INLINE Type* getBoolType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Bool)]; } + SLANG_FORCE_INLINE Type* getHalfType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Half)]; } + SLANG_FORCE_INLINE Type* getFloatType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Float)]; } + SLANG_FORCE_INLINE Type* getDoubleType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Double)]; } + SLANG_FORCE_INLINE Type* getIntType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int)]; } + SLANG_FORCE_INLINE Type* getInt64Type() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int64)]; } + SLANG_FORCE_INLINE Type* getUIntType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt)]; } + SLANG_FORCE_INLINE Type* getUInt64Type() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt64)]; } + SLANG_FORCE_INLINE Type* getVoidType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Void)]; } + + /// Get a builtin type by the BaseType + SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; } + + Type* getInitializerListType() { return m_sharedASTBuilder->m_initializerListType; } + Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; } + Type* getErrorType() { return m_sharedASTBuilder->m_errorType; } + Type* getStringType() { return m_sharedASTBuilder->getStringType(); } + Type* getEnumTypeType() { return m_sharedASTBuilder->getEnumTypeType(); } + + // Construct the type `Ptr<valueType>`, where `Ptr` + // is looked up as a builtin type. + RefPtr<PtrType> getPtrType(RefPtr<Type> valueType); + + // Construct the type `Out<valueType>` + RefPtr<OutType> getOutType(RefPtr<Type> valueType); + + // Construct the type `InOut<valueType>` + RefPtr<InOutType> getInOutType(RefPtr<Type> valueType); + + // Construct the type `Ref<valueType>` + RefPtr<RefType> getRefType(RefPtr<Type> valueType); + + // Construct a pointer type like `Ptr<valueType>`, but where + // the actual type name for the pointer type is given by `ptrTypeName` + RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName); + + // Construct a pointer type like `Ptr<valueType>`, but where + // the generic declaration for the pointer type is `genericDecl` + RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl); + + RefPtr<ArrayExpressionType> getArrayType(Type* elementType, IntVal* elementCount); + + RefPtr<VectorExpressionType> getVectorType(RefPtr<Type> elementType, RefPtr<IntVal> elementCount); + + RefPtr<TypeType> getTypeType(Type* type); + + /// Helpers to get type info from the SharedASTBuilder + const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findClassInfo(slice); } + SyntaxClass<RefObject> findSyntaxClass(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findSyntaxClass(slice); } + + const ReflectClassInfo* findClassInfo(Name* name) { return m_sharedASTBuilder->findClassInfo(name); } + SyntaxClass<RefObject> findSyntaxClass(Name* name) { return m_sharedASTBuilder->findSyntaxClass(name); } + + /// Get the shared AST builder + SharedASTBuilder* getSharedASTBuilder() { return m_sharedASTBuilder; } + + /// Get the global session + Session* getGlobalSession() { return m_sharedASTBuilder->m_session; } + + /// Ctor + ASTBuilder(SharedASTBuilder* sharedASTBuilder); + +protected: + // Special default Ctor that can only be used by SharedASTBuilder + ASTBuilder(); + + SharedASTBuilder* m_sharedASTBuilder; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ast-reflect.cpp b/source/slang/slang-ast-reflect.cpp index eb511689d..7efccd3e4 100644 --- a/source/slang/slang-ast-reflect.cpp +++ b/source/slang/slang-ast-reflect.cpp @@ -44,11 +44,13 @@ bool ReflectClassInfo::isSubClassOfSlow(const ThisType& super) const // Now try and implement all of the classes // Macro generated is of the format - -template <typename T> -struct CreateImpl +struct ASTConstructAccess { - static void* create() { return new T; } + template <typename T> + struct CreateImpl + { + static void* create() { return new T; } + }; }; #define SLANG_GET_SUPER_BASE(SUPER) nullptr @@ -56,10 +58,10 @@ struct CreateImpl #define SLANG_GET_SUPER_LEAF(SUPER) &SUPER::kReflectClassInfo #define SLANG_GET_CREATE_FUNC_ABSTRACT(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_NONE(NAME) &CreateImpl<NAME>::create +#define SLANG_GET_CREATE_FUNC_NONE(NAME) &ASTConstructAccess::CreateImpl<NAME>::create #define SLANG_GET_CREATE_FUNC_NON_VISITOR_ABSTRACT(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_NON_VISITOR(NAME) &CreateImpl<NAME>::create +#define SLANG_GET_CREATE_FUNC_NON_VISITOR(NAME) &ASTConstructAccess::CreateImpl<NAME>::create #define SLANG_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ diff --git a/source/slang/slang-ast-reflect.h b/source/slang/slang-ast-reflect.h index 5384e4afd..a5875d34a 100644 --- a/source/slang/slang-ast-reflect.h +++ b/source/slang/slang-ast-reflect.h @@ -12,12 +12,17 @@ // Implementation for SLANG_ABSTRACT_CLASS(x) using reflection from C++ extractor in slang-ast-generated.h #define SLANG_CLASS_REFLECT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ + protected: \ + NAME() = default; \ public: \ + typedef NAME This; \ typedef SUPER Super; \ static const ASTNodeType kType = ASTNodeType::NAME; \ static const ReflectClassInfo kReflectClassInfo; \ SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) { return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); } \ virtual const ReflectClassInfo& getClassInfo() const SLANG_AST_OVERRIDE_##TYPE { return kReflectClassInfo; } \ + friend class ASTBuilder; \ + friend struct ASTConstructAccess; // Macro definitions - use the SLANG_ASTNode_ definitions to invoke the IMPL to produce the code // injected into AST classes diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 946106e78..9f7f8aa4d 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -177,7 +177,6 @@ class CompileTimeForStmt : public ScopeStmt class JumpStmt : public ChildStmt { SLANG_ABSTRACT_CLASS(JumpStmt) - }; class BreakStmt : public JumpStmt diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index af5689a5c..5cb7f3202 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -610,6 +610,8 @@ namespace Slang HashCode getHashCode() const; }; + class ASTBuilder; + template<typename T> struct DeclRef; @@ -644,15 +646,15 @@ namespace Slang {} // Apply substitutions to a type or declaration - RefPtr<Type> Substitute(RefPtr<Type> type) const; + RefPtr<Type> substitute(ASTBuilder* astBuilder, RefPtr<Type> type) const; - DeclRefBase Substitute(DeclRefBase declRef) const; + DeclRefBase substitute(ASTBuilder* astBuilder, DeclRefBase declRef) const; // Apply substitutions to an expression - RefPtr<Expr> Substitute(RefPtr<Expr> expr) const; + RefPtr<Expr> substitute(ASTBuilder* astBuilder, RefPtr<Expr> expr) const; // Apply substitutions to this declaration reference - DeclRefBase SubstituteImpl(SubstitutionSet subst, int* ioDiff); + DeclRefBase substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); // Returns true if 'as' will return a valid cast template <typename T> @@ -719,26 +721,26 @@ namespace Slang return DeclRef<T>((T*) declRef.decl, declRef.substitutions); } - RefPtr<Type> Substitute(RefPtr<Type> type) const + RefPtr<Type> substitute(ASTBuilder* astBuilder, RefPtr<Type> type) const { - return DeclRefBase::Substitute(type); + return DeclRefBase::substitute(astBuilder, type); } - RefPtr<Expr> Substitute(RefPtr<Expr> expr) const + RefPtr<Expr> substitute(ASTBuilder* astBuilder, RefPtr<Expr> expr) const { - return DeclRefBase::Substitute(expr); + return DeclRefBase::substitute(astBuilder, expr); } // Apply substitutions to a type or declaration template<typename U> - DeclRef<U> Substitute(DeclRef<U> declRef) const + DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const { - return DeclRef<U>::unsafeInit(DeclRefBase::Substitute(declRef)); + return DeclRef<U>::unsafeInit(DeclRefBase::substitute(astBuilder, declRef)); } // Apply substitutions to this declaration reference - DeclRef<T> SubstituteImpl(SubstitutionSet subst, int* ioDiff) + DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { - return DeclRef<T>::unsafeInit(DeclRefBase::SubstituteImpl(subst, ioDiff)); + return DeclRef<T>::unsafeInit(DeclRefBase::substituteImpl(astBuilder, subst, ioDiff)); } DeclRef<ContainerDecl> GetParent() const @@ -1284,7 +1286,7 @@ namespace Slang RefPtr<WitnessTable> getWitnessTable(); - RequirementWitness specialize(SubstitutionSet const& subst); + RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); Flavor m_flavor; DeclRef<Decl> m_declRef; diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index a1ed29b8a..68935108c 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -13,12 +13,11 @@ class OverloadGroupType : public Type { SLANG_CLASS(OverloadGroupType) -public: virtual String toString() override; protected: - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; }; @@ -31,8 +30,8 @@ class InitializerListType : public Type virtual String toString() override; protected: - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; }; @@ -41,13 +40,12 @@ class ErrorType : public Type { SLANG_CLASS(ErrorType) -public: virtual String toString() override; protected: - virtual bool equalsImpl(Type * type) override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; virtual HashCode getHashCode() override; }; @@ -59,22 +57,18 @@ class DeclRefType : public Type DeclRef<Decl> declRef; virtual String toString() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; - static RefPtr<DeclRefType> Create( - Session* session, - DeclRef<Decl> declRef); + static RefPtr<DeclRefType> create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); - DeclRefType() - {} - DeclRefType( - DeclRef<Decl> declRef) +protected: + DeclRefType( DeclRef<Decl> declRef) : declRef(declRef) {} -protected: + virtual HashCode getHashCode() override; - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; }; // Base class for types that can be used in arithmetic expressions @@ -90,18 +84,17 @@ class BasicExpressionType : public ArithmeticExpressionType { SLANG_CLASS(BasicExpressionType) - BaseType baseType; - BasicExpressionType() {} +protected: BasicExpressionType( Slang::BaseType baseType) : baseType(baseType) {} -protected: + virtual BasicExpressionType* GetScalarType() override; - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; }; @@ -111,7 +104,6 @@ protected: class BuiltinType : public DeclRefType { SLANG_ABSTRACT_CLASS(BuiltinType) - }; // Resources that contain "elements" that can be fetched @@ -139,14 +131,11 @@ class TextureTypeBase : public ResourceType { SLANG_ABSTRACT_CLASS(TextureTypeBase) - TextureTypeBase() - {} - TextureTypeBase( - TextureFlavor flavor, - RefPtr<Type> elementType) +protected: + TextureTypeBase(TextureFlavor inFlavor, RefPtr<Type> inElementType) { - this->elementType = elementType; - this->flavor = flavor; + elementType = inElementType; + flavor = inFlavor; } }; @@ -154,11 +143,8 @@ class TextureType : public TextureTypeBase { SLANG_CLASS(TextureType) - TextureType() - {} - TextureType( - TextureFlavor flavor, - RefPtr<Type> elementType) +protected: + TextureType(TextureFlavor flavor, RefPtr<Type> elementType) : TextureTypeBase(flavor, elementType) {} }; @@ -169,11 +155,8 @@ class TextureSamplerType : public TextureTypeBase { SLANG_CLASS(TextureSamplerType) - TextureSamplerType() - {} - TextureSamplerType( - TextureFlavor flavor, - RefPtr<Type> elementType) +protected: + TextureSamplerType(TextureFlavor flavor, RefPtr<Type> elementType) : TextureTypeBase(flavor, elementType) {} }; @@ -183,8 +166,7 @@ class GLSLImageType : public TextureTypeBase { SLANG_CLASS(GLSLImageType) - GLSLImageType() - {} +protected: GLSLImageType( TextureFlavor flavor, RefPtr<Type> elementType) @@ -396,9 +378,9 @@ class ArrayExpressionType : public Type virtual String toString() override; protected: - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual bool equalsImpl(Type* type) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; virtual HashCode getHashCode() override; }; @@ -412,18 +394,15 @@ class TypeType : public Type // The type that this is the type of... RefPtr<Type> type; -public: - TypeType() - {} + virtual String toString() override; + +protected: TypeType(RefPtr<Type> type) : type(type) {} - virtual String toString() override; - -protected: - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; }; @@ -527,20 +506,18 @@ class NamedExpressionType : public Type SLANG_CLASS(NamedExpressionType) DeclRef<TypeDefDecl> declRef; - RefPtr<Type> innerType; - NamedExpressionType() - {} + + virtual String toString() override; + +protected: NamedExpressionType( DeclRef<TypeDefDecl> declRef) : declRef(declRef) {} - virtual String toString() override; - -protected: - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; }; @@ -559,18 +536,16 @@ class FuncType : public Type List<RefPtr<Type>> paramTypes; RefPtr<Type> resultType; - FuncType() - {} - UInt getParamCount() { return paramTypes.getCount(); } Type* getParamType(UInt index) { return paramTypes[index]; } Type* getResultType() { return resultType; } virtual String toString() override; + protected: - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; - virtual bool equalsImpl(Type * type) override; virtual RefPtr<Type> createCanonicalType() override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; }; @@ -581,19 +556,17 @@ class GenericDeclRefType : public Type DeclRef<GenericDecl> declRef; - GenericDeclRefType() - {} - GenericDeclRefType( - DeclRef<GenericDecl> declRef) - : declRef(declRef) - {} - DeclRef<GenericDecl> const& getDeclRef() const { return declRef; } virtual String toString() override; protected: - virtual bool equalsImpl(Type * type) override; + GenericDeclRefType( + DeclRef<GenericDecl> declRef) + : declRef(declRef) + {} + + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; virtual RefPtr<Type> createCanonicalType() override; }; @@ -605,15 +578,12 @@ class NamespaceType : public Type DeclRef<NamespaceDeclBase> declRef; - NamespaceType() - {} - DeclRef<NamespaceDeclBase> const& getDeclRef() const { return declRef; } virtual String toString() override; protected: - virtual bool equalsImpl(Type * type) override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; virtual RefPtr<Type> createCanonicalType() override; }; @@ -627,10 +597,10 @@ class ExtractExistentialType : public Type DeclRef<VarDeclBase> declRef; virtual String toString() override; - virtual bool equalsImpl(Type * type) override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; /// A tagged union of zero or more other types. @@ -646,10 +616,10 @@ class TaggedUnionType : public Type List<RefPtr<Type>> caseTypes; virtual String toString() override; - virtual bool equalsImpl(Type * type) override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; class ExistentialSpecializedType : public Type @@ -660,10 +630,10 @@ class ExistentialSpecializedType : public Type ExpandedSpecializationArgs args; virtual String toString() override; - virtual bool equalsImpl(Type * type) override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; /// The type of `this` within a polymorphic declaration @@ -674,11 +644,10 @@ class ThisType : public Type DeclRef<InterfaceDecl> interfaceDeclRef; virtual String toString() override; - virtual bool equalsImpl(Type * type) override; + virtual bool equalsImpl(Type* type) override; virtual HashCode getHashCode() override; virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; - + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index a6425e060..5345a389f 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -21,10 +21,9 @@ class ConstantIntVal : public IntVal IntegerLiteralValue value; - ConstantIntVal() - {} - ConstantIntVal(IntegerLiteralValue value) - : value(value) +protected: + ConstantIntVal(IntegerLiteralValue inValue) + : value(inValue) {} virtual bool equalsVal(Val* val) override; @@ -32,23 +31,22 @@ class ConstantIntVal : public IntVal virtual HashCode getHashCode() override; }; -// The logical "value" of a rererence to a generic value parameter +// The logical "value" of a reference to a generic value parameter class GenericParamIntVal : public IntVal { SLANG_CLASS(GenericParamIntVal) DeclRef<VarDeclBase> declRef; - GenericParamIntVal() - {} - GenericParamIntVal(DeclRef<VarDeclBase> declRef) - : declRef(declRef) +protected: + GenericParamIntVal(DeclRef<VarDeclBase> inDeclRef) + : declRef(inDeclRef) {} virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; /// An unknown integer value indicating an erroneous sub-expression @@ -60,13 +58,10 @@ class ErrorIntVal : public IntVal // and have all `Val`s that represent ordinary values hold their // `Type` so that we can have an `ErrorVal` of any type. - ErrorIntVal() - {} - virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; // A witness to the fact that some proposition is true, encoded @@ -124,11 +119,10 @@ class TypeEqualityWitness : public SubtypeWitness { SLANG_CLASS(TypeEqualityWitness) - virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int * ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; // A witness that one type is a subtype of another @@ -142,7 +136,7 @@ class DeclaredSubtypeWitness : public SubtypeWitness virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int * ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` @@ -159,7 +153,7 @@ class TransitiveSubtypeWitness : public SubtypeWitness virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int * ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; // A witness taht `sub : sup` because `sub` was wrapped into @@ -174,7 +168,7 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int * ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; // A witness that `sub : sup`, because `sub` is a tagged union @@ -193,7 +187,7 @@ class TaggedUnionSubtypeWitness : public SubtypeWitness virtual bool equalsVal(Val* val) override; virtual String toString() override; virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(SubstitutionSet subst, int * ioDiff) override; + virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; } // namespace Slang diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 639716ee5..4a149c10f 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -10,7 +10,7 @@ namespace Slang RefPtr<DeclaredSubtypeWitness> SemanticsVisitor::createSimpleSubtypeWitness( TypeWitnessBreadcrumb* breadcrumb) { - RefPtr<DeclaredSubtypeWitness> witness = new DeclaredSubtypeWitness(); + RefPtr<DeclaredSubtypeWitness> witness = m_astBuilder->create<DeclaredSubtypeWitness>(); witness->sub = breadcrumb->sub; witness->sup = breadcrumb->sup; witness->declRef = breadcrumb->declRef; @@ -79,7 +79,7 @@ namespace Slang // where `[...]` represents the "hole" we leave // open to fill in next. // - RefPtr<TransitiveSubtypeWitness> transitiveWitness = new TransitiveSubtypeWitness(); + RefPtr<TransitiveSubtypeWitness> transitiveWitness = m_astBuilder->create<TransitiveSubtypeWitness>(); transitiveWitness->sub = bb->sub; transitiveWitness->sup = bb->sup; transitiveWitness->midToSup = bb->declRef; @@ -190,7 +190,7 @@ namespace Slang // loops better). This would also help avoid checking multiply-inherited // conformances multiple times. - auto inheritedType = getBaseType(inheritanceDeclRef); + auto inheritedType = getBaseType(m_astBuilder, inheritanceDeclRef); // We need to ensure that the witness that gets created // is a composite one, reflecting lookup through @@ -211,7 +211,7 @@ namespace Slang for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef)) { ensureDecl(genConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); - auto inheritedType = GetSup(genConstraintDeclRef); + auto inheritedType = GetSup(m_astBuilder, genConstraintDeclRef); TypeWitnessBreadcrumb breadcrumb; breadcrumb.prev = inBreadcrumbs; breadcrumb.sub = type; @@ -233,8 +233,8 @@ namespace Slang for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef) ) { - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); + auto sub = GetSub(m_astBuilder, constraintDeclRef); + auto sup = GetSup(m_astBuilder, constraintDeclRef); auto subDeclRef = as<DeclRefType>(sub); if(!subDeclRef) @@ -313,9 +313,9 @@ namespace Slang // if(outWitness) { - RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = new TaggedUnionSubtypeWitness(); + RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = m_astBuilder->create<TaggedUnionSubtypeWitness>(); taggedUnionWitness->sub = taggedUnionType; - taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef); + taggedUnionWitness->sup = DeclRefType::create(m_astBuilder, interfaceDeclRef); taggedUnionWitness->caseWitnesses.swapWith(caseWitnesses); *outWitness = taggedUnionWitness; @@ -346,7 +346,7 @@ namespace Slang RefPtr<Val> SemanticsVisitor::createTypeEqualityWitness( Type* type) { - RefPtr<TypeEqualityWitness> rs = new TypeEqualityWitness(); + RefPtr<TypeEqualityWitness> rs = m_astBuilder->create<TypeEqualityWitness>(); rs->sub = type; rs->sup = type; return rs; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index a5a5620ae..315f9c5b1 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -114,7 +114,7 @@ namespace Slang continue; // Look up the type in our session. - auto candidateType = type->getSession()->getBuiltinType(BaseType(baseTypeFlavorIndex)); + auto candidateType = type->getASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); if(!candidateType) continue; @@ -286,7 +286,7 @@ namespace Slang // that `X<T>.IndexType == T`. for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef) ) { - if(!TryUnifyTypes(*system, GetSub(constraintDeclRef), GetSup(constraintDeclRef))) + if(!TryUnifyTypes(*system, GetSub(m_astBuilder, constraintDeclRef), GetSup(m_astBuilder, constraintDeclRef))) return SubstitutionSet(); } SubstitutionSet resultSubst = genericDeclRef.substitutions; @@ -391,7 +391,7 @@ namespace Slang // search for a conformance `Robin : ISidekick`, which involved // apply the substitutions we already know... - RefPtr<GenericSubstitution> solvedSubst = new GenericSubstitution(); + RefPtr<GenericSubstitution> solvedSubst = m_astBuilder->create<GenericSubstitution>(); solvedSubst->genericDecl = genericDeclRef.getDecl(); solvedSubst->outer = genericDeclRef.substitutions.substitutions; solvedSubst->args = args; @@ -404,8 +404,8 @@ namespace Slang solvedSubst); // Extract the (substituted) sub- and super-type from the constraint. - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); + auto sub = GetSub(m_astBuilder, constraintDeclRef); + auto sup = GetSup(m_astBuilder, constraintDeclRef); // Search for a witness that shows the constraint is satisfied. auto subTypeWitness = tryGetSubtypeWitness(sub, sup); diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index e514caebf..dfcf7ece2 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -310,9 +310,8 @@ namespace Slang // We have a new type for the conversion, based on what // we learned. - toType = getSession()->getArrayType( - toElementType, - new ConstantIntVal(elementCount)); + toType = m_astBuilder->getArrayType(toElementType, + m_astBuilder->create<ConstantIntVal>(elementCount)); } } else if(auto toMatrixType = as<MatrixExpressionType>(toType)) @@ -383,7 +382,7 @@ namespace Slang { RefPtr<Expr> coercedArg; bool argResult = _readValueFromInitializerList( - GetType(fieldDeclRef), + GetType(m_astBuilder, fieldDeclRef), outToExpr ? &coercedArg : nullptr, fromInitializerListExpr, ioArgIndex); @@ -418,7 +417,7 @@ namespace Slang // if(outToExpr) { - auto toInitializerListExpr = new InitializerListExpr(); + auto toInitializerListExpr = m_astBuilder->create<InitializerListExpr>(); toInitializerListExpr->loc = fromInitializerListExpr->loc; toInitializerListExpr->type = QualType(toType); toInitializerListExpr->args = coercedArgs; @@ -575,7 +574,7 @@ namespace Slang RefPtr<DerefExpr> derefExpr; if(outToExpr) { - derefExpr = new DerefExpr(); + derefExpr = m_astBuilder->create<DerefExpr>(); derefExpr->base = fromExpr; derefExpr->type = QualType(fromElementType); } @@ -814,7 +813,7 @@ namespace Slang RefPtr<TypeCastExpr> SemanticsVisitor::createImplicitCastExpr() { - return new ImplicitCastExpr(); + return m_astBuilder->create<ImplicitCastExpr>(); } RefPtr<Expr> SemanticsVisitor::CreateImplicitCastExpr( @@ -823,9 +822,9 @@ namespace Slang { RefPtr<TypeCastExpr> castExpr = createImplicitCastExpr(); - auto typeType = getTypeType(toType); + auto typeType = m_astBuilder->getTypeType(toType); - auto typeExpr = new SharedTypeExpr(); + auto typeExpr = m_astBuilder->create<SharedTypeExpr>(); typeExpr->type.type = typeType; typeExpr->base.type = toType; @@ -841,7 +840,7 @@ namespace Slang RefPtr<Expr> fromExpr, RefPtr<Val> witness) { - RefPtr<CastToInterfaceExpr> expr = new CastToInterfaceExpr(); + RefPtr<CastToInterfaceExpr> expr = m_astBuilder->create<CastToInterfaceExpr>(); expr->loc = fromExpr->loc; expr->type = QualType(toType); expr->valueArg = fromExpr; @@ -866,7 +865,7 @@ namespace Slang // really shouldn't *change* the expression that is passed in, but should // introduce new AST nodes to coerce its value to a different type... return CreateImplicitCastExpr( - getSession()->getErrorType(), + m_astBuilder->getErrorType(), fromExpr); } return expr; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index b69b3ad7d..b08baa42f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -272,7 +272,7 @@ namespace Slang // Get the type to use when referencing a declaration QualType getTypeForDeclRef( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* sema, DiagnosticSink* sink, DeclRef<Decl> declRef, @@ -304,7 +304,7 @@ namespace Slang if(_isLocalVar(varDecl)) { sema->getSink()->diagnose(varDecl, Diagnostics::localVariableUsedBeforeDeclared, varDecl); - return QualType(session->getErrorType()); + return QualType(astBuilder->getErrorType()); } } } @@ -322,7 +322,7 @@ namespace Slang if (auto varDeclRef = declRef.as<VarDeclBase>()) { QualType qualType; - qualType.type = GetType(varDeclRef); + qualType.type = GetType(astBuilder, varDeclRef); bool isLValue = true; if(varDeclRef.getDecl()->findModifier<ConstModifier>()) @@ -368,37 +368,37 @@ namespace Slang else if( auto enumCaseDeclRef = declRef.as<EnumCaseDecl>() ) { QualType qualType; - qualType.type = getType(enumCaseDeclRef); + qualType.type = getType(astBuilder, enumCaseDeclRef); qualType.IsLeftValue = false; return qualType; } else if (auto typeAliasDeclRef = declRef.as<TypeDefDecl>()) { - auto type = getNamedType(session, typeAliasDeclRef); + auto type = getNamedType(astBuilder, typeAliasDeclRef); *outTypeResult = type; - return QualType(getTypeType(type)); + return QualType(astBuilder->getTypeType(type)); } else if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { - auto type = DeclRefType::Create(session, aggTypeDeclRef); + auto type = DeclRefType::create(astBuilder, aggTypeDeclRef); *outTypeResult = type; - return QualType(getTypeType(type)); + return QualType(astBuilder->getTypeType(type)); } else if (auto simpleTypeDeclRef = declRef.as<SimpleTypeDecl>()) { - auto type = DeclRefType::Create(session, simpleTypeDeclRef); + auto type = DeclRefType::create(astBuilder, simpleTypeDeclRef); *outTypeResult = type; - return QualType(getTypeType(type)); + return QualType(astBuilder->getTypeType(type)); } else if (auto genericDeclRef = declRef.as<GenericDecl>()) { - auto type = getGenericDeclRefType(session, genericDeclRef); + auto type = getGenericDeclRefType(astBuilder, genericDeclRef); *outTypeResult = type; - return QualType(getTypeType(type)); + return QualType(astBuilder->getTypeType(type)); } else if (auto funcDeclRef = declRef.as<CallableDecl>()) { - auto type = getFuncType(session, funcDeclRef); + auto type = getFuncType(astBuilder, funcDeclRef); return QualType(type); } else if (auto constraintDeclRef = declRef.as<TypeConstraintDecl>()) @@ -406,12 +406,12 @@ namespace Slang // When we access a constraint or an inheritance decl (as a member), // we are conceptually performing a "cast" to the given super-type, // with the declaration showing that such a cast is legal. - auto type = GetSup(constraintDeclRef); + auto type = GetSup(astBuilder, constraintDeclRef); return QualType(type); } else if( auto namespaceDeclRef = declRef.as<NamespaceDeclBase>()) { - auto type = getNamespaceType(session, namespaceDeclRef); + auto type = getNamespaceType(astBuilder, namespaceDeclRef); return QualType(type); } if( sink ) @@ -430,16 +430,16 @@ namespace Slang // sink->diagnose(loc, Diagnostics::undefinedIdentifier2, declRef.GetName()); } - return QualType(session->getErrorType()); + return QualType(astBuilder->getErrorType()); } QualType getTypeForDeclRef( - Session* session, + ASTBuilder* astBuilder, DeclRef<Decl> declRef, SourceLoc loc) { RefPtr<Type> typeResult; - return getTypeForDeclRef(session, nullptr, nullptr, declRef, &typeResult, loc); + return getTypeForDeclRef(astBuilder, nullptr, nullptr, declRef, &typeResult, loc); } DeclRef<ExtensionDecl> ApplyExtensionToType( @@ -453,12 +453,12 @@ namespace Slang return semantics->ApplyExtensionToType(extDecl, type); } - RefPtr<GenericSubstitution> createDefaultSubsitutionsForGeneric( - Session* session, + RefPtr<GenericSubstitution> createDefaultSubstitutionsForGeneric( + ASTBuilder* astBuilder, GenericDecl* genericDecl, RefPtr<Substitutions> outerSubst) { - RefPtr<GenericSubstitution> genericSubst = new GenericSubstitution(); + RefPtr<GenericSubstitution> genericSubst = astBuilder->create<GenericSubstitution>(); genericSubst->genericDecl = genericDecl; genericSubst->outer = outerSubst; @@ -466,11 +466,11 @@ namespace Slang { if( auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm) ) { - genericSubst->args.add(DeclRefType::Create(session, DeclRef<Decl>(genericTypeParamDecl, outerSubst))); + genericSubst->args.add(DeclRefType::create(astBuilder, DeclRef<Decl>(genericTypeParamDecl, outerSubst))); } else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) ) { - genericSubst->args.add(new GenericParamIntVal(DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst))); + genericSubst->args.add(astBuilder->create<GenericParamIntVal>(DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst))); } } @@ -479,7 +479,7 @@ namespace Slang { if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm)) { - RefPtr<DeclaredSubtypeWitness> witness = new DeclaredSubtypeWitness(); + RefPtr<DeclaredSubtypeWitness> witness = astBuilder->create<DeclaredSubtypeWitness>(); witness->declRef = DeclRef<Decl>(genericTypeConstraintDecl, outerSubst); witness->sub = genericTypeConstraintDecl->sub.type; witness->sup = genericTypeConstraintDecl->sup.type; @@ -495,7 +495,7 @@ namespace Slang // using their archetypes). // SubstitutionSet createDefaultSubstitutions( - Session* session, + ASTBuilder* astBuilder, Decl* decl, SubstitutionSet outerSubstSet) { @@ -507,8 +507,8 @@ namespace Slang if(decl != genericDecl->inner) return outerSubstSet; - RefPtr<GenericSubstitution> genericSubst = createDefaultSubsitutionsForGeneric( - session, + RefPtr<GenericSubstitution> genericSubst = createDefaultSubstitutionsForGeneric( + astBuilder, genericDecl, outerSubstSet.substitutions); @@ -519,15 +519,15 @@ namespace Slang } SubstitutionSet createDefaultSubstitutions( - Session* session, + ASTBuilder* astBuilder, Decl* decl) { SubstitutionSet subst; if( auto parentDecl = decl->parentDecl ) { - subst = createDefaultSubstitutions(session, parentDecl); + subst = createDefaultSubstitutions(astBuilder, parentDecl); } - subst = createDefaultSubstitutions(session, decl, subst); + subst = createDefaultSubstitutions(astBuilder, decl, subst); return subst; } @@ -776,7 +776,7 @@ namespace Slang if(!initExpr) { getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); - varDecl->type.type = getSession()->getErrorType(); + varDecl->type.type = m_astBuilder->getErrorType(); } else { @@ -803,7 +803,7 @@ namespace Slang TypeExp typeExp = CheckUsableType(varDecl->type); varDecl->type = typeExp; - if (varDecl->type.equals(getSession()->getVoidType())) + if (varDecl->type.equals(m_astBuilder->getVoidType())) { getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); } @@ -921,7 +921,7 @@ namespace Slang { if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base)) { - declRefType->declRef.substitutions = createDefaultSubstitutions(getSession(), declRefType->declRef.getDecl()); + declRefType->declRef.substitutions = createDefaultSubstitutions(m_astBuilder, declRefType->declRef.getDecl()); if (auto typetype = as<TypeType>(typeExp.exp->type)) typetype->type = declRefType; @@ -1041,13 +1041,15 @@ namespace Slang /// static void _registerBuiltinDeclsRec(Session* session, Decl* decl) { + SharedASTBuilder* sharedASTBuilder = session->m_sharedASTBuilder; + if (auto builtinMod = decl->findModifier<BuiltinTypeModifier>()) { - registerBuiltinDecl(session, decl, builtinMod); + sharedASTBuilder->registerBuiltinDecl(decl, builtinMod); } if (auto magicMod = decl->findModifier<MagicTypeModifier>()) { - registerMagicDecl(session, decl, magicMod); + sharedASTBuilder->registerMagicDecl(decl, magicMod); } if(auto containerDecl = as<ContainerDecl>(decl)) @@ -1283,7 +1285,7 @@ namespace Slang for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(requiredAssociatedTypeDeclRef)) { // Grab the type we expect to conform to from the constraint. - auto requiredSuperType = GetSup(requiredConstraintDeclRef); + auto requiredSuperType = GetSup(m_astBuilder, requiredConstraintDeclRef); // Perform a search for a witness to the subtype relationship. auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); @@ -1397,7 +1399,7 @@ namespace Slang { ensureDecl(subAggTypeDeclRef, DeclCheckState::CanUseAsType); - auto satisfyingType = DeclRefType::Create(getSession(), subAggTypeDeclRef); + auto satisfyingType = DeclRefType::create(m_astBuilder, subAggTypeDeclRef); return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); } } @@ -1409,7 +1411,7 @@ namespace Slang { ensureDecl(typedefDeclRef, DeclCheckState::CanUseAsType); - auto satisfyingType = getNamedType(getSession(), typedefDeclRef); + auto satisfyingType = getNamedType(m_astBuilder, typedefDeclRef); return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); } } @@ -1461,7 +1463,7 @@ namespace Slang context, type, requiredInheritanceDeclRef.getDecl(), - getBaseType(requiredInheritanceDeclRef)); + getBaseType(m_astBuilder, requiredInheritanceDeclRef)); if(!satisfyingWitnessTable) return false; @@ -1503,7 +1505,7 @@ namespace Slang // on subsequent checking in this function to // rule out inherited abstract members. // - auto lookupResult = lookUpMember(getSession(), this, name, type); + auto lookupResult = lookUpMember(m_astBuilder, this, name, type); // Iterate over the members and look for one that matches // the expected signature for the requirement. @@ -1607,9 +1609,7 @@ namespace Slang // // TODO: need to decide if a this-type substitution is needed here. // It probably it. - RefPtr<Type> targetType = DeclRefType::Create( - getSession(), - interfaceDeclRef); + RefPtr<Type> targetType = DeclRefType::create(m_astBuilder, interfaceDeclRef); auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); if(!extDeclRef) continue; @@ -1713,8 +1713,8 @@ namespace Slang void SemanticsVisitor::checkExtensionConformance(ExtensionDecl* decl) { - auto declRef = createDefaultSubstitutionsIfNeeded(getSession(), makeDeclRef(decl)).as<ExtensionDecl>(); - auto targetType = GetTargetType(declRef); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, makeDeclRef(decl)).as<ExtensionDecl>(); + auto targetType = GetTargetType(m_astBuilder, declRef); for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) { @@ -1744,8 +1744,10 @@ namespace Slang // For non-interface types we need to check conformance. // - auto declRef = createDefaultSubstitutionsIfNeeded(getSession(), makeDeclRef(decl)).as<AggTypeDeclBase>(); - auto type = DeclRefType::Create(getSession(), declRef); + auto astBuilder = getASTBuilder(); + + auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, makeDeclRef(decl)).as<AggTypeDeclBase>(); + auto type = DeclRefType::create(astBuilder, declRef); // TODO: Need to figure out what this should do for // `abstract` types if we ever add them. Should they @@ -1845,7 +1847,7 @@ namespace Slang // type of their tag. if(!tagType) { - tagType = getSession()->getIntType(); + tagType = m_astBuilder->getIntType(); } else { @@ -1867,12 +1869,12 @@ namespace Slang // seems like the best place to do it. { // First, look up the type of the `__EnumType` interface. - RefPtr<Type> enumTypeType = getSession()->getEnumTypeType(); + RefPtr<Type> enumTypeType = getASTBuilder()->getEnumTypeType(); - RefPtr<InheritanceDecl> enumConformanceDecl = new InheritanceDecl(); + RefPtr<InheritanceDecl> enumConformanceDecl = m_astBuilder->create<InheritanceDecl>(); enumConformanceDecl->parentDecl = decl; enumConformanceDecl->loc = decl->loc; - enumConformanceDecl->base.type = getSession()->getEnumTypeType(); + enumConformanceDecl->base.type = getASTBuilder()->getEnumTypeType(); decl->members.add(enumConformanceDecl); // The `__EnumType` interface has one required member, the `__Tag` type. @@ -1921,9 +1923,7 @@ namespace Slang void SemanticsDeclBodyVisitor::visitEnumDecl(EnumDecl* decl) { - auto enumType = DeclRefType::Create( - getSession(), - makeDeclRef(decl)); + auto enumType = DeclRefType::create(m_astBuilder, makeDeclRef(decl)); auto tagType = decl->tagType; @@ -1979,7 +1979,7 @@ namespace Slang { // This tag has no initializer, so it should use // the default tag value we are tracking. - RefPtr<IntegerLiteralExpr> tagValExpr = new IntegerLiteralExpr(); + RefPtr<IntegerLiteralExpr> tagValExpr = m_astBuilder->create<IntegerLiteralExpr>(); tagValExpr->loc = caseDecl->loc; tagValExpr->type = QualType(tagType); tagValExpr->value = defaultTag; @@ -2297,12 +2297,12 @@ namespace Slang // and `sup` types are pairwise equivalent. // auto leftSub = leftConstraint->sub; - auto rightSub = GetSub(rightConstraint); + auto rightSub = GetSub(m_astBuilder, rightConstraint); if(!leftSub->equals(rightSub)) return false; auto leftSup = leftConstraint->sup; - auto rightSup = GetSup(rightConstraint); + auto rightSup = GetSup(m_astBuilder, rightConstraint); if(!leftSup->equals(rightSup)) return false; } @@ -2336,7 +2336,7 @@ namespace Slang auto sndParam = sndParams[ii]; // If a given parameter type doesn't match, then signatures don't match - if (!GetType(fstParam)->equals(GetType(sndParam))) + if (!GetType(m_astBuilder, fstParam)->equals(GetType(m_astBuilder, sndParam))) return false; // If one parameter is `out` and the other isn't, then they don't match @@ -2361,7 +2361,7 @@ namespace Slang RefPtr<GenericSubstitution> SemanticsVisitor::createDummySubstitutions( GenericDecl* genericDecl) { - RefPtr<GenericSubstitution> subst = new GenericSubstitution(); + RefPtr<GenericSubstitution> subst = m_astBuilder->create<GenericSubstitution>(); subst->genericDecl = genericDecl; for (auto dd : genericDecl->members) { @@ -2370,13 +2370,12 @@ namespace Slang if (auto typeParam = as<GenericTypeParamDecl>(dd)) { - auto type = DeclRefType::Create(getSession(), - makeDeclRef(typeParam)); + auto type = DeclRefType::create(m_astBuilder, makeDeclRef(typeParam)); subst->args.add(type); } else if (auto valueParam = as<GenericValueParamDecl>(dd)) { - auto val = new GenericParamIntVal( + auto val = m_astBuilder->create<GenericParamIntVal>( makeDeclRef(valueParam)); subst->args.add(val); } @@ -2544,8 +2543,8 @@ namespace Slang // consider result types earlier, as part of the signature // matching step. // - auto resultType = GetResultType(newDeclRef); - auto prevResultType = GetResultType(oldDeclRef); + auto resultType = GetResultType(m_astBuilder, newDeclRef); + auto prevResultType = GetResultType(m_astBuilder, oldDeclRef); if (!resultType->equals(prevResultType)) { // Bad redeclaration @@ -2779,7 +2778,7 @@ namespace Slang } else { - resultType = TypeExp(getSession()->getVoidType()); + resultType = TypeExp(m_astBuilder->getVoidType()); } funcDecl->returnType = resultType; @@ -2823,6 +2822,7 @@ namespace Slang // Create a new array type based on the size we found, // and install it into our type. varDecl->type.type = getArrayType( + m_astBuilder, arrayType->baseType, elementCount); } @@ -2879,8 +2879,7 @@ namespace Slang // conform to the interface and fill in its // requirements. // - RefPtr<ThisType> thisType = new ThisType(); - thisType->setSession(getSession()); + RefPtr<ThisType> thisType = m_astBuilder->create<ThisType>(); thisType->interfaceDeclRef = interfaceDeclRef; return thisType; } @@ -2895,9 +2894,7 @@ namespace Slang // would need to refer to the eventual concrete // type, much like the `interface` case above. // - return DeclRefType::Create( - getSession(), - aggTypeDeclRef); + return DeclRefType::create(m_astBuilder, aggTypeDeclRef); } else if (auto extDeclRef = declRef.as<ExtensionDecl>()) { @@ -2922,7 +2919,7 @@ namespace Slang // sooner or later. // ensureDecl(extDeclRef, DeclCheckState::CanUseExtensionTargetType); - auto targetType = GetTargetType(extDeclRef); + auto targetType = GetTargetType(m_astBuilder, extDeclRef); return calcThisType(targetType); } else @@ -2964,7 +2961,7 @@ namespace Slang if( !thisType ) { getSink()->diagnose(decl, Diagnostics::initializerNotInsideType); - thisType = getSession()->getErrorType(); + thisType = m_astBuilder->getErrorType(); } return thisType; } @@ -2997,7 +2994,7 @@ namespace Slang if(!anyAccessors) { - RefPtr<GetterDecl> getterDecl = new GetterDecl(); + RefPtr<GetterDecl> getterDecl = m_astBuilder->create<GetterDecl>(); getterDecl->loc = decl->loc; getterDecl->parentDecl = decl; @@ -3067,7 +3064,7 @@ namespace Slang } // Now extract the target type from our (possibly specialized) extension decl-ref. - RefPtr<Type> targetType = GetTargetType(extDeclRef); + RefPtr<Type> targetType = GetTargetType(m_astBuilder, extDeclRef); // As a bit of a kludge here, if the target type of the extension is // an interface, and the `type` we are trying to match up has a this-type @@ -3098,12 +3095,12 @@ namespace Slang SLANG_ASSERT(!targetInterfaceDeclRef.substitutions.substitutions.as<ThisTypeSubstitution>()); // We will create a new substitution to apply to the target type. - RefPtr<ThisTypeSubstitution> newTargetSubst = new ThisTypeSubstitution(); + RefPtr<ThisTypeSubstitution> newTargetSubst = m_astBuilder->create<ThisTypeSubstitution>(); newTargetSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; newTargetSubst->witness = appThisTypeSubst->witness; newTargetSubst->outer = targetInterfaceDeclRef.substitutions.substitutions; - targetType = DeclRefType::Create(getSession(), + targetType = DeclRefType::create(m_astBuilder, DeclRef<InterfaceDecl>(targetInterfaceDeclRef.getDecl(), newTargetSubst)); // Note: we are constructing a this-type substitution that @@ -3113,7 +3110,7 @@ namespace Slang // references to the target type of the extension // declaration have a chance to resolve the way we want them to. - RefPtr<ThisTypeSubstitution> newExtSubst = new ThisTypeSubstitution(); + RefPtr<ThisTypeSubstitution> newExtSubst = m_astBuilder->create<ThisTypeSubstitution>(); newExtSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; newExtSubst->witness = appThisTypeSubst->witness; newExtSubst->outer = extDeclRef.substitutions.substitutions; @@ -3155,7 +3152,7 @@ namespace Slang { RefPtr<Type> typeResult; return getTypeForDeclRef( - getSession(), + m_astBuilder, this, getSink(), declRef, diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 6bf096ae9..e0f439ab3 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -31,7 +31,7 @@ namespace Slang template<typename F> RefPtr<Expr> SemanticsVisitor::moveTemp(RefPtr<Expr> const& expr, F const& func) { - RefPtr<VarDecl> varDecl = new VarDecl(); + RefPtr<VarDecl> varDecl = m_astBuilder->create<VarDecl>(); varDecl->parentDecl = nullptr; // TODO: need to fill this in somehow! varDecl->checkState = DeclCheckState::Checked; varDecl->nameAndLoc.loc = expr->loc; @@ -40,7 +40,7 @@ namespace Slang auto varDeclRef = makeDeclRef(varDecl.Ptr()); - RefPtr<LetExpr> letExpr = new LetExpr(); + RefPtr<LetExpr> letExpr = m_astBuilder->create<LetExpr>(); letExpr->decl = varDecl; auto body = func(varDeclRef); @@ -102,23 +102,23 @@ namespace Slang auto interfaceDecl = interfaceDeclRef.getDecl(); return maybeMoveTemp(expr, [&](DeclRef<VarDeclBase> varDeclRef) { - RefPtr<ExtractExistentialType> openedType = new ExtractExistentialType(); + RefPtr<ExtractExistentialType> openedType = m_astBuilder->create<ExtractExistentialType>(); openedType->declRef = varDeclRef; - RefPtr<ExtractExistentialSubtypeWitness> openedWitness = new ExtractExistentialSubtypeWitness(); + RefPtr<ExtractExistentialSubtypeWitness> openedWitness = m_astBuilder->create<ExtractExistentialSubtypeWitness>(); openedWitness->sub = openedType; openedWitness->sup = expr->type.type; openedWitness->declRef = varDeclRef; - RefPtr<ThisTypeSubstitution> openedThisType = new ThisTypeSubstitution(); + RefPtr<ThisTypeSubstitution> openedThisType = m_astBuilder->create<ThisTypeSubstitution>(); openedThisType->outer = interfaceDeclRef.substitutions.substitutions; openedThisType->interfaceDecl = interfaceDecl; openedThisType->witness = openedWitness; DeclRef<InterfaceDecl> substDeclRef = DeclRef<InterfaceDecl>(interfaceDecl, openedThisType); - auto substDeclRefType = DeclRefType::Create(getSession(), substDeclRef); + auto substDeclRefType = DeclRefType::create(m_astBuilder, substDeclRef); - RefPtr<ExtractExistentialValueExpr> openedValue = new ExtractExistentialValueExpr(); + RefPtr<ExtractExistentialValueExpr> openedValue = m_astBuilder->create<ExtractExistentialValueExpr>(); openedValue->declRef = varDeclRef; openedValue->type = QualType(substDeclRefType); @@ -218,7 +218,7 @@ namespace Slang declRef.GetName()); } - auto expr = new StaticMemberExpr(); + auto expr = m_astBuilder->create<StaticMemberExpr>(); expr->loc = loc; expr->type = type; expr->baseExpression = baseExpr; @@ -230,11 +230,11 @@ namespace Slang { // Extract the type of the baseExpr auto baseExprType = baseExpr->type.type; - RefPtr<SharedTypeExpr> baseTypeExpr = new SharedTypeExpr(); + RefPtr<SharedTypeExpr> baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); baseTypeExpr->base.type = baseExprType; - baseTypeExpr->type.type = getTypeType(baseExprType); + baseTypeExpr->type.type = m_astBuilder->getTypeType(baseExprType); - auto expr = new StaticMemberExpr(); + auto expr = m_astBuilder->create<StaticMemberExpr>(); expr->loc = loc; expr->type = type; expr->baseExpression = baseTypeExpr; @@ -247,7 +247,7 @@ namespace Slang // If the base expression wasn't a type, then this // is a normal member expression. // - auto expr = new MemberExpr(); + auto expr = m_astBuilder->create<MemberExpr>(); expr->loc = loc; expr->type = type; expr->baseExpression = baseExpr; @@ -274,7 +274,7 @@ namespace Slang // If there is no base expression, then the result must // be an ordinary variable expression. // - auto expr = new VarExpr(); + auto expr = m_astBuilder->create<VarExpr>(); expr->loc = loc; expr->name = declRef.GetName(); expr->type = type; @@ -290,7 +290,7 @@ namespace Slang auto ptrLikeType = as<PointerLikeType>(base->type); SLANG_ASSERT(ptrLikeType); - auto derefExpr = new DerefExpr(); + auto derefExpr = m_astBuilder->create<DerefExpr>(); derefExpr->loc = loc; derefExpr->base = base; derefExpr->type = QualType(ptrLikeType->elementType); @@ -364,9 +364,9 @@ namespace Slang // will be `typeof(This)`, which conceptually // `typeof(typeof(this))` // - auto thisTypeType = getTypeType(thisType); + auto thisTypeType = m_astBuilder->getTypeType(thisType); - auto typeExpr = new SharedTypeExpr(); + auto typeExpr = m_astBuilder->create<SharedTypeExpr>(); typeExpr->type.type = thisTypeType; typeExpr->base.type = thisType; @@ -380,7 +380,7 @@ namespace Slang // refernece to `this.someStaticMember` will be translated // over to `This.someStaticMember`. // - RefPtr<ThisExpr> expr = new ThisExpr(); + RefPtr<ThisExpr> expr = m_astBuilder->create<ThisExpr>(); expr->type.type = thisType; expr->loc = loc; @@ -412,11 +412,11 @@ namespace Slang { if (lookupResult.isOverloaded()) { - auto overloadedExpr = new OverloadedExpr(); + auto overloadedExpr = m_astBuilder->create<OverloadedExpr>(); overloadedExpr->name = name; overloadedExpr->loc = loc; overloadedExpr->type = QualType( - getSession()->getOverloadedType()); + m_astBuilder->getOverloadedType()); overloadedExpr->base = baseExpr; overloadedExpr->lookupResult2 = lookupResult; return overloadedExpr; @@ -588,7 +588,7 @@ namespace Slang RefPtr<Expr> SemanticsVisitor::CreateErrorExpr(Expr* expr) { - expr->type = QualType(getSession()->getErrorType()); + expr->type = QualType(m_astBuilder->getErrorType()); return expr; } @@ -617,7 +617,7 @@ namespace Slang RefPtr<Expr> SemanticsExprVisitor::visitBoolLiteralExpr(BoolLiteralExpr* expr) { - expr->type = getSession()->getBoolType(); + expr->type = m_astBuilder->getBoolType(); return expr; } @@ -636,7 +636,7 @@ namespace Slang // if(!expr->type.type) { - expr->type = getSession()->getIntType(); + expr->type = m_astBuilder->getIntType(); } return expr; } @@ -645,21 +645,21 @@ namespace Slang { if(!expr->type.type) { - expr->type = getSession()->getFloatType(); + expr->type = m_astBuilder->getFloatType(); } return expr; } RefPtr<Expr> SemanticsExprVisitor::visitStringLiteralExpr(StringLiteralExpr* expr) { - expr->type = getSession()->getStringType(); + expr->type = m_astBuilder->getStringType(); return expr; } IntVal* SemanticsVisitor::GetIntVal(IntegerLiteralExpr* expr) { // TODO(tfoley): don't keep allocating here! - return new ConstantIntVal(expr->value); + return m_astBuilder->create<ConstantIntVal>(expr->value); } RefPtr<IntVal> SemanticsVisitor::TryConstantFoldExpr( @@ -784,7 +784,7 @@ namespace Slang return nullptr; } - RefPtr<IntVal> result = new ConstantIntVal(resultValue); + RefPtr<IntVal> result = m_astBuilder->create<ConstantIntVal>(resultValue); return result; } @@ -811,7 +811,7 @@ namespace Slang if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { // TODO(tfoley): handle the case of non-`int` value parameters... - return new GenericParamIntVal(genericValParamRef); + return m_astBuilder->create<GenericParamIntVal>(genericValParamRef); } // We may also need to check for references to variables that @@ -826,7 +826,7 @@ namespace Slang if(auto constAttr = varDecl->findModifier<ConstModifier>()) { // HLSL `static const` can be used as a constant expression - if(auto initExpr = getInitExpr(varRef)) + if(auto initExpr = getInitExpr(m_astBuilder, varRef)) { return TryConstantFoldExpr(initExpr.Ptr()); } @@ -836,7 +836,7 @@ namespace Slang else if(auto enumRef = declRef.as<EnumCaseDecl>()) { // The cases in an `enum` declaration can also be used as constant expressions, - if(auto tagExpr = getTagExpr(enumRef)) + if(auto tagExpr = getTagExpr(m_astBuilder, enumRef)) { return TryConstantFoldExpr(tagExpr.Ptr()); } @@ -882,7 +882,7 @@ namespace Slang if(IsErrorExpr(inExpr)) return nullptr; // First coerce the expression to the expected type - auto expr = coerce(getSession()->getIntType(),inExpr); + auto expr = coerce(m_astBuilder->getIntType(),inExpr); // No need to issue further errors if the type coercion failed. if(IsErrorExpr(expr)) return nullptr; @@ -923,8 +923,8 @@ namespace Slang auto baseExpr = subscriptExpr->baseExpression; auto indexExpr = subscriptExpr->indexExpression; - if (!indexExpr->type->equals(getSession()->getIntType()) && - !indexExpr->type->equals(getSession()->getUIntType())) + if (!indexExpr->type->equals(m_astBuilder->getIntType()) && + !indexExpr->type->equals(m_astBuilder->getUIntType())) { getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); return CreateErrorExpr(subscriptExpr.Ptr()); @@ -973,10 +973,11 @@ namespace Slang auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); auto arrayType = getArrayType( + m_astBuilder, elementType, elementCount); - subscriptExpr->type = QualType(getTypeType(arrayType)); + subscriptExpr->type = QualType(m_astBuilder->getTypeType(arrayType)); return subscriptExpr; } else if (auto baseArrayType = as<ArrayExpressionType>(baseType)) @@ -1010,7 +1011,7 @@ namespace Slang { Name* name = getName("operator[]"); LookupResult lookupResult = lookUpMember( - getSession(), + m_astBuilder, this, name, baseType); @@ -1028,7 +1029,7 @@ namespace Slang RefPtr<Expr> subscriptFuncExpr = createLookupResultExpr( name, lookupResult, subscriptExpr->baseExpression, subscriptExpr->loc); - RefPtr<InvokeExpr> subscriptCallExpr = new InvokeExpr(); + RefPtr<InvokeExpr> subscriptCallExpr = m_astBuilder->create<InvokeExpr>(); subscriptCallExpr->loc = subscriptExpr->loc; subscriptCallExpr->functionExpr = subscriptFuncExpr; @@ -1223,9 +1224,9 @@ namespace Slang if (expr->declRef) return expr; - expr->type = QualType(getSession()->getErrorType()); + expr->type = QualType(m_astBuilder->getErrorType()); auto lookupResult = lookUp( - getSession(), + m_astBuilder, this, expr->name, expr->scope); if (lookupResult.isValid()) { @@ -1316,7 +1317,7 @@ namespace Slang // of explicit default initializers for `struct` fields to // make this a major concern (since they aren't supported in HLSL). // - RefPtr<InitializerListExpr> initListExpr = new InitializerListExpr(); + RefPtr<InitializerListExpr> initListExpr = m_astBuilder->create<InitializerListExpr>(); auto checkedInitListExpr = visitInitializerListExpr(initListExpr); return coerce(typeExp.type, initListExpr); } @@ -1342,7 +1343,7 @@ namespace Slang auto elementType = QualType(pointerLikeType->elementType); elementType.IsLeftValue = baseType.IsLeftValue; - auto derefExpr = new DerefExpr(); + auto derefExpr = m_astBuilder->create<DerefExpr>(); derefExpr->base = expr; derefExpr->type = elementType; @@ -1360,7 +1361,7 @@ namespace Slang RefPtr<Type> baseElementType, IntegerLiteralValue baseElementCount) { - RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr(); + RefPtr<SwizzleExpr> swizExpr = m_astBuilder->create<SwizzleExpr>(); swizExpr->loc = memberRefExpr->loc; swizExpr->base = memberRefExpr->baseExpression; @@ -1439,7 +1440,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - new ConstantIntVal(elementCount))); + m_astBuilder->create<ConstantIntVal>(elementCount))); } // A swizzle can be used as an l-value as long as there @@ -1484,7 +1485,7 @@ namespace Slang // we can reference the declaration here. // LookupResult lookupResult = lookUpDirectAndTransparentMembers( - getSession(), + m_astBuilder, this, expr->name, namespaceDeclRef); @@ -1515,7 +1516,7 @@ namespace Slang } LookupResult lookupResult = lookUpMember( - getSession(), + m_astBuilder, this, expr->name, type); @@ -1640,7 +1641,7 @@ namespace Slang SLANG_ASSERT(as<StaticMemberExpr>(expr) || as<MemberExpr>(expr)); getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->name, baseType); - expr->type = QualType(getSession()->getErrorType()); + expr->type = QualType(m_astBuilder->getErrorType()); return expr; } @@ -1703,7 +1704,7 @@ namespace Slang else { LookupResult lookupResult = lookUpMember( - getSession(), + m_astBuilder, this, expr->name, baseType.Ptr()); @@ -1734,7 +1735,7 @@ namespace Slang arg = CheckTerm(arg); } - expr->type = getSession()->getInitializerListType(); + expr->type = m_astBuilder->getInitializerListType(); return expr; } @@ -1814,7 +1815,7 @@ namespace Slang if( auto typeOrExtensionDecl = as<AggTypeDeclBase>(containerDecl) ) { auto thisType = calcThisType(makeDeclRef(typeOrExtensionDecl)); - auto thisTypeType = getTypeType(thisType); + auto thisTypeType = m_astBuilder->getTypeType(thisType); expr->type.type = thisTypeType; return expr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index fb55e5bef..fba43adcb 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -11,9 +11,7 @@ namespace Slang { - RefPtr<TypeType> getTypeType( - Type* type); - + /// Should the given `decl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( Decl* decl); @@ -229,18 +227,25 @@ namespace Slang return m_linkage->getSessionImpl(); } + Linkage* getLinkage() + { + return m_linkage; + } }; struct SemanticsVisitor { SemanticsVisitor( SharedSemanticsContext* shared) - : m_shared(shared) + : m_shared(shared), + m_astBuilder(shared->getLinkage()->getASTBuilder()) {} SharedSemanticsContext* m_shared = nullptr; + ASTBuilder* m_astBuilder = nullptr; SharedSemanticsContext* getShared() { return m_shared; } + ASTBuilder* getASTBuilder() { return m_astBuilder;} DiagnosticSink* getSink() { return m_shared->getSink(); } @@ -1304,7 +1309,7 @@ namespace Slang DeclRefExpr* expr, QualType const& baseType); - SharedSemanticsContext & operator = (const SharedSemanticsContext &) = delete; + SharedSemanticsContext& operator=(const SharedSemanticsContext &) = delete; // diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index b8b3846dc..42a9735af 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -80,15 +80,13 @@ namespace Slang AttributeDecl* SemanticsVisitor::lookUpAttributeDecl(Name* attributeName, Scope* scope) { - auto session = getSession(); - // We start by looking for an existing attribute matching // the name `attributeName`. // { // Look up the name and see what attributes we find. // - auto lookupResult = lookUp(session, this, attributeName, scope, LookupMask::Attribute); + auto lookupResult = lookUp(m_astBuilder, this, attributeName, scope, LookupMask::Attribute); // If the result was overloaded, then that means there // are multiple attributes matching the name, and we @@ -118,7 +116,7 @@ namespace Slang // If the attribute was `[Something(...)]` then we will // look for a `struct` named `SomethingAttribute`. // - LookupResult lookupResult = lookUp(session, this, session->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type); + LookupResult lookupResult = lookUp(m_astBuilder, this, m_astBuilder->getGlobalSession()->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type); // // If we didn't find a matching type name, then we give up. // @@ -140,12 +138,12 @@ namespace Slang // We will now synthesize a new `AttributeDecl` to mirror // what was declared on the `struct` type. // - RefPtr<AttributeDecl> attrDecl = new AttributeDecl(); + RefPtr<AttributeDecl> attrDecl = m_astBuilder->create<AttributeDecl>(); attrDecl->nameAndLoc.name = attributeName; attrDecl->nameAndLoc.loc = structDecl->nameAndLoc.loc; attrDecl->loc = structDecl->loc; - RefPtr<AttributeTargetModifier> targetModifier = new AttributeTargetModifier(); + RefPtr<AttributeTargetModifier> targetModifier = m_astBuilder->create<AttributeTargetModifier>(); targetModifier->syntaxClass = attrUsageAttr->targetSyntaxClass; targetModifier->loc = attrUsageAttr->loc; addModifier(attrDecl, targetModifier); @@ -155,8 +153,8 @@ namespace Slang // // User-defined attributes create instances of // `UserDefinedAttribute`. - // - attrDecl->syntaxClass = session->findSyntaxClass(session->getNameObj("UserDefinedAttribute")); + // + attrDecl->syntaxClass = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("UserDefinedAttribute")); // The fields of the user-defined `struct` type become // the parameters of the new attribute. @@ -169,7 +167,7 @@ namespace Slang { ensureDecl(varMember, DeclCheckState::CanUseTypeOfValueDecl); - RefPtr<ParamDecl> paramDecl = new ParamDecl(); + RefPtr<ParamDecl> paramDecl = m_astBuilder->create<ParamDecl>(); paramDecl->nameAndLoc = member->nameAndLoc; paramDecl->type = varMember->type; paramDecl->loc = member->loc; @@ -236,17 +234,17 @@ namespace Slang { if (typeFlags == (int)UserDefinedAttributeTargets::Struct) { - cls = getSession()->findSyntaxClass(getSession()->getNameObj("StructDecl")); + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("StructDecl")); return true; } if (typeFlags == (int)UserDefinedAttributeTargets::Var) { - cls = getSession()->findSyntaxClass(getSession()->getNameObj("VarDecl")); + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("VarDecl")); return true; } if (typeFlags == (int)UserDefinedAttributeTargets::Function) { - cls = getSession()->findSyntaxClass(getSession()->getNameObj("FuncDecl")); + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("FuncDecl")); return true; } return false; diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 745597a76..6af2eec58 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -165,9 +165,9 @@ namespace Slang // // Along the way we will build up a `GenericSubstitution` // to represent the arguments that have been coerced to - // appropriateforms. + // appropriate forms. // - auto genSubst = new GenericSubstitution(); + auto genSubst = m_astBuilder->create<GenericSubstitution>(); candidate.subst = genSubst; auto& checkedArgs = genSubst->args; @@ -228,7 +228,7 @@ namespace Slang // if( !typeArg.type ) { - typeArg.type = getSession()->getErrorType(); + typeArg.type = m_astBuilder->getErrorType(); success = false; } @@ -257,7 +257,7 @@ namespace Slang if (context.mode == OverloadResolveContext::Mode::JustTrying) { ConversionCost cost = kConversionCost_None; - if (!canCoerce(GetType(valParamRef), arg->type, &cost)) + if (!canCoerce(GetType(m_astBuilder, valParamRef), arg->type, &cost)) { success = false; } @@ -265,7 +265,7 @@ namespace Slang } else { - arg = coerce(GetType(valParamRef), arg); + arg = coerce(GetType(m_astBuilder, valParamRef), arg); } } @@ -288,7 +288,7 @@ namespace Slang // if( !val ) { - val = new ErrorIntVal(); + val = m_astBuilder->create<ErrorIntVal>(); } checkedArgs.add(val); } @@ -343,10 +343,10 @@ namespace Slang if( context.disallowNestedConversions ) { // We need an exact match in this case. - if(!GetType(param)->equals(argType)) + if(!GetType(m_astBuilder, param)->equals(argType)) return false; } - else if (!canCoerce(GetType(param), argType, &cost)) + else if (!canCoerce(GetType(m_astBuilder, param), argType, &cost)) { return false; } @@ -354,7 +354,7 @@ namespace Slang } else { - arg = coerce(GetType(param), arg); + arg = coerce(GetType(m_astBuilder, param), arg); } } return true; @@ -396,8 +396,8 @@ namespace Slang DeclRef<GenericTypeConstraintDecl> constraintDeclRef( constraintDecl, subset); - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); + auto sub = GetSub(m_astBuilder, constraintDeclRef); + auto sup = GetSup(m_astBuilder, constraintDeclRef); auto subTypeWitness = tryGetSubtypeWitness(sub, sup); if(subTypeWitness) @@ -523,7 +523,7 @@ namespace Slang RefPtr<AppExprBase> callExpr = as<InvokeExpr>(context.originalExpr); if(!callExpr) { - callExpr = new InvokeExpr(); + callExpr = m_astBuilder->create<InvokeExpr>(); callExpr->loc = context.loc; for(Index aa = 0; aa < context.argCount; ++aa) @@ -940,7 +940,7 @@ namespace Slang OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Func; candidate.item = item; - candidate.resultType = GetResultType(funcDeclRef); + candidate.resultType = GetResultType(m_astBuilder, funcDeclRef); AddOverloadCandidate(context, candidate); } @@ -1037,7 +1037,7 @@ namespace Slang // So the question is then whether a mismatch during the // unification step should be taken as an immediate failure... - TryUnifyTypes(constraints, context.getArgType(aa), GetType(params[aa])); + TryUnifyTypes(constraints, context.getArgType(aa), GetType(m_astBuilder, params[aa])); #endif } } @@ -1086,7 +1086,7 @@ namespace Slang // some just to make code that does, e.g., `float(1.0f)` work. LookupResult initializers = lookUpMember( - getSession(), + m_astBuilder, this, getName("$init"), type); @@ -1106,9 +1106,7 @@ namespace Slang } else if (auto aggTypeDeclRef = item.declRef.as<AggTypeDecl>()) { - auto type = DeclRefType::Create( - getSession(), - aggTypeDeclRef); + auto type = DeclRefType::create(m_astBuilder, aggTypeDeclRef); AddTypeOverloadCandidates(type, context); } else if (auto genericDeclRef = item.declRef.as<GenericDecl>()) @@ -1142,14 +1140,12 @@ namespace Slang } else if( auto typeDefDeclRef = item.declRef.as<TypeDefDecl>() ) { - auto type = getNamedType(getSession(), typeDefDeclRef); + auto type = getNamedType(m_astBuilder, typeDefDeclRef); AddTypeOverloadCandidates(type, context); } else if( auto genericTypeParamDeclRef = item.declRef.as<GenericTypeParamDecl>() ) { - auto type = DeclRefType::Create( - getSession(), - genericTypeParamDeclRef); + auto type = DeclRefType::create(m_astBuilder, genericTypeParamDeclRef); AddTypeOverloadCandidates(type, context); } else @@ -1328,7 +1324,7 @@ namespace Slang { if (!first) sb << ", "; - formatType(sb, GetType(paramDeclRef)); + formatType(sb, GetType(m_astBuilder, paramDeclRef)); first = false; @@ -1356,7 +1352,7 @@ namespace Slang sb << getText(genericValParam.GetName()); sb << ":"; - formatType(sb, GetType(genericValParam)); + formatType(sb, GetType(m_astBuilder, genericValParam)); } else {} @@ -1395,7 +1391,7 @@ namespace Slang else if(auto callableDeclRef = declRef.as<CallableDecl>()) { sb << " -> "; - formatType(sb, GetResultType(callableDeclRef)); + formatType(sb, GetResultType(m_astBuilder, callableDeclRef)); } } @@ -1632,7 +1628,7 @@ namespace Slang { // Nothing at all was found that we could even consider invoking getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExprType); - expr->type = QualType(getSession()->getErrorType()); + expr->type = QualType(m_astBuilder->getErrorType()); return expr; } } @@ -1748,7 +1744,7 @@ namespace Slang // There were multiple viable candidates, but that isn't an error: we just need // to complete all of them and create an overloaded expression as a result. - auto overloadedExpr = new OverloadedExpr2(); + auto overloadedExpr = m_astBuilder->create<OverloadedExpr2>(); overloadedExpr->base = context.baseExpr; for (auto candidate : context.bestCandidates) { diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index ec0c5137e..f68439f8a 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -46,11 +46,13 @@ namespace Slang /// Recursively walk `paramDeclRef` and add any existential/interface specialization parameters to `ioSpecializationParams`. static void _collectExistentialSpecializationParamsRec( + ASTBuilder* astBuilder, SpecializationParams& ioSpecializationParams, DeclRef<VarDeclBase> paramDeclRef); /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`. static void _collectExistentialSpecializationParamsRec( + ASTBuilder* astBuilder, SpecializationParams& ioSpecializationParams, Type* type, SourceLoc loc) @@ -66,6 +68,7 @@ namespace Slang if( auto parameterGroupType = as<ParameterGroupType>(type) ) { _collectExistentialSpecializationParamsRec( + astBuilder, ioSpecializationParams, parameterGroupType->getElementType(), loc); @@ -95,6 +98,7 @@ namespace Slang for( auto fieldDeclRef : GetFields(structDeclRef, MemberFilterStyle::Instance) ) { _collectExistentialSpecializationParamsRec( + astBuilder, ioSpecializationParams, fieldDeclRef); } @@ -107,24 +111,27 @@ namespace Slang } static void _collectExistentialSpecializationParamsRec( + ASTBuilder* astBuilder, SpecializationParams& ioSpecializationParams, DeclRef<VarDeclBase> paramDeclRef) { _collectExistentialSpecializationParamsRec( + astBuilder, ioSpecializationParams, - GetType(paramDeclRef), + GetType(astBuilder, paramDeclRef), paramDeclRef.getLoc()); } /// Collect any interface/existential specialization parameters for `paramDeclRef` into `ioParamInfo` and `ioSpecializationParams` static void _collectExistentialSpecializationParamsForShaderParam( + ASTBuilder* astBuilder, ShaderParamInfo& ioParamInfo, SpecializationParams& ioSpecializationParams, DeclRef<VarDeclBase> paramDeclRef) { Index beginParamIndex = ioSpecializationParams.getCount(); - _collectExistentialSpecializationParamsRec(ioSpecializationParams, paramDeclRef); + _collectExistentialSpecializationParamsRec(astBuilder, ioSpecializationParams, paramDeclRef); Index endParamIndex = ioSpecializationParams.getCount(); ioParamInfo.firstSpecializationParamIndex = beginParamIndex; @@ -199,6 +206,7 @@ namespace Slang shaderParamInfo.paramDeclRef = paramDeclRef; _collectExistentialSpecializationParamsForShaderParam( + getLinkage()->getASTBuilder(), shaderParamInfo, m_existentialSpecializationParams, paramDeclRef); @@ -617,6 +625,7 @@ namespace Slang // with the correct parameter. // _collectExistentialSpecializationParamsForShaderParam( + getLinkage()->getASTBuilder(), shaderParamInfo, m_specializationParams, makeDeclRef(globalVar.Ptr())); @@ -966,7 +975,7 @@ namespace Slang if(!argType) { sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, genericTypeParamDecl); - argType = getLinkage()->getSessionImpl()->getErrorType(); + argType = getLinkage()->getASTBuilder()->getErrorType(); } // TODO: There is a serious flaw to this checking logic if we ever have cases where @@ -1026,7 +1035,7 @@ namespace Slang for(auto constraintDecl : genericTypeParamDecl->getMembersOfType<GenericTypeConstraintDecl>()) { // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraintDecl, nullptr)); + auto interfaceType = GetSup(getLinkage()->getASTBuilder(), DeclRef<GenericTypeConstraintDecl>(constraintDecl, nullptr)); // Use our semantic-checking logic to search for a witness to the required conformance auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); @@ -1058,7 +1067,7 @@ namespace Slang if(!argType) { sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, interfaceType); - argType = getLinkage()->getSessionImpl()->getErrorType(); + argType = getLinkage()->getASTBuilder()->getErrorType(); } auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); @@ -1092,7 +1101,7 @@ namespace Slang if(!intVal) { sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); - intVal = new ConstantIntVal(0); + intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(0); } ModuleSpecializationInfo::GenericArgInfo expandedArg; @@ -1166,7 +1175,7 @@ namespace Slang auto genericDeclRef = m_funcDeclRef.GetParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - RefPtr<GenericSubstitution> genericSubst = new GenericSubstitution(); + RefPtr<GenericSubstitution> genericSubst = getLinkage()->getASTBuilder()->create<GenericSubstitution>(); genericSubst->outer = genericDeclRef.substitutions.substitutions; genericSubst->genericDecl = genericDeclRef.getDecl(); @@ -1184,8 +1193,10 @@ namespace Slang DeclRef<GenericTypeConstraintDecl> constraintDeclRef( constraintDecl, constraintSubst); - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); + ASTBuilder* astBuilder = getLinkage()->getASTBuilder(); + + auto sub = GetSub(astBuilder, constraintDeclRef); + auto sup = GetSup(astBuilder, constraintDeclRef); auto subTypeWitness = visitor.tryGetSubtypeWitness(sub, sup); if(subTypeWitness) @@ -1360,7 +1371,7 @@ namespace Slang SemanticsVisitor visitor(&sharedSemanticsContext); SpecializationParams specializationParams; - _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType, SourceLoc()); + _collectExistentialSpecializationParamsRec(getASTBuilder(), specializationParams, unspecializedType, SourceLoc()); assert(specializationParams.getCount() == argCount); @@ -1376,7 +1387,7 @@ namespace Slang specializationArgs.add(arg); } - RefPtr<ExistentialSpecializedType> specializedType = new ExistentialSpecializedType(); + RefPtr<ExistentialSpecializedType> specializedType = m_astBuilder->create<ExistentialSpecializedType>(); specializedType->baseType = unspecializedType; specializedType->args = specializationArgs; diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index fa2eb59a0..71cf97396 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -112,7 +112,7 @@ namespace Slang { RefPtr<Expr> e = expr; e = CheckTerm(e); - e = coerce(getSession()->getBoolType(), e); + e = coerce(m_astBuilder->getBoolType(), e); return e; } @@ -153,8 +153,8 @@ namespace Slang { WithOuterStmt subContext(this, stmt); - stmt->varDecl->type.type = getSession()->getIntType(); - addModifier(stmt->varDecl, new ConstModifier()); + stmt->varDecl->type.type = m_astBuilder->getIntType(); + addModifier(stmt->varDecl, m_astBuilder->create<ConstModifier>()); stmt->varDecl->setCheckState(DeclCheckState::Checked); RefPtr<IntVal> rangeBeginVal; @@ -166,7 +166,7 @@ namespace Slang } else { - RefPtr<ConstantIntVal> rangeBeginConst = new ConstantIntVal(); + RefPtr<ConstantIntVal> rangeBeginConst = m_astBuilder->create<ConstantIntVal>(); rangeBeginConst->value = 0; rangeBeginVal = rangeBeginConst; } @@ -250,7 +250,7 @@ namespace Slang auto function = getParentFunc(); if (!stmt->expression) { - if (function && !function->returnType.equals(getSession()->getVoidType())) + if (function && !function->returnType.equals(m_astBuilder->getVoidType())) { getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); } @@ -258,7 +258,7 @@ namespace Slang else { stmt->expression = CheckTerm(stmt->expression); - if (!stmt->expression->type->equals(getSession()->getErrorType())) + if (!stmt->expression->type->equals(m_astBuilder->getErrorType())) { if (function) { diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index e6c924be8..db5d555cd 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -37,7 +37,7 @@ namespace Slang { return typeType->type; } - return getSession()->getErrorType(); + return m_astBuilder->getErrorType(); } RefPtr<Type> SemanticsVisitor::TranslateTypeNode(const RefPtr<Expr> & node) @@ -97,7 +97,7 @@ namespace Slang { return typeType->type; } - return getSession()->getErrorType(); + return m_astBuilder->getErrorType(); } RefPtr<Type> SemanticsVisitor::ExtractGenericArgType(RefPtr<Expr> exp) @@ -114,7 +114,7 @@ namespace Slang // constant expression in context, then we will instead construct // a dummy "error" value to represent the result. // - val = new ErrorIntVal(); + val = m_astBuilder->create<ErrorIntVal>(); return val; } @@ -149,7 +149,7 @@ namespace Slang DeclRef<GenericDecl> genericDeclRef, List<RefPtr<Expr>> const& args) { - RefPtr<GenericSubstitution> subst = new GenericSubstitution(); + RefPtr<GenericSubstitution> subst = m_astBuilder->create<GenericSubstitution>(); subst->genericDecl = genericDeclRef.getDecl(); subst->outer = genericDeclRef.substitutions.substitutions; @@ -162,9 +162,7 @@ namespace Slang innerDeclRef.decl = GetInner(genericDeclRef); innerDeclRef.substitutions = SubstitutionSet(subst); - return DeclRefType::Create( - getSession(), - innerDeclRef); + return DeclRefType::create(m_astBuilder, innerDeclRef); } bool SemanticsVisitor::CoerceToProperTypeImpl( @@ -216,7 +214,7 @@ namespace Slang if (diagSink) { diagSink->diagnose(typeExp.exp.Ptr(), Diagnostics::genericTypeNeedsArgs, typeExp); - *outProperType = getSession()->getErrorType(); + *outProperType = m_astBuilder->getErrorType(); } return false; } @@ -232,7 +230,7 @@ namespace Slang if (diagSink) { diagSink->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = getSession()->getErrorType(); + *outProperType = m_astBuilder->getErrorType(); } return false; } @@ -293,7 +291,7 @@ namespace Slang { // TODO(tfoley): pick the right diagnostic message getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); - result.type = getSession()->getErrorType(); + result.type = m_astBuilder->getErrorType(); return result; } } @@ -334,21 +332,18 @@ namespace Slang RefPtr<Type> elementType, RefPtr<IntVal> elementCount) { - auto session = getSession(); - auto vectorGenericDecl = findMagicDecl( - session, "Vector").as<GenericDecl>(); + auto vectorGenericDecl = m_astBuilder->getSharedASTBuilder()->findMagicDecl("Vector").as<GenericDecl>(); + auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = new GenericSubstitution(); + auto substitutions = m_astBuilder->create<GenericSubstitution>(); substitutions->genericDecl = vectorGenericDecl.Ptr(); substitutions->args.add(elementType); substitutions->args.add(elementCount); auto declRef = DeclRef<Decl>(vectorTypeDecl.Ptr(), substitutions); - return DeclRefType::Create( - session, - declRef).as<VectorExpressionType>(); + return DeclRefType::create(m_astBuilder, declRef).as<VectorExpressionType>(); } RefPtr<Expr> SemanticsExprVisitor::visitSharedTypeExpr(SharedTypeExpr* expr) @@ -366,8 +361,8 @@ namespace Slang // We have an expression of the form `__TaggedUnion(A, B, ...)` // which will evaluate to a tagged-union type over `A`, `B`, etc. // - RefPtr<TaggedUnionType> type = new TaggedUnionType(); - expr->type = QualType(getTypeType(type)); + RefPtr<TaggedUnionType> type = m_astBuilder->create<TaggedUnionType>(); + expr->type = QualType(m_astBuilder->getTypeType(type)); for( auto& caseTypeExpr : expr->caseTypes ) { diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 83f362ec7..489808a44 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -188,7 +188,7 @@ namespace Slang funcDeclRef.GetName(), profile, funcDeclRef); - entryPoint->m_mangledName = getMangledName(funcDeclRef); + entryPoint->m_mangledName = getMangledName(linkage->getASTBuilder(), funcDeclRef); return entryPoint; } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 852898de6..2c23a89bc 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1228,7 +1228,7 @@ namespace Slang SlangMatrixLayoutMode mode); /// Create an initially-empty linkage - Linkage(Session* session); + Linkage(Session* session, ASTBuilder* astBuilder); /// Get the parent session for this linkage Session* getSessionImpl() { return m_session; } @@ -1256,6 +1256,10 @@ namespace Slang NamePool* getNamePool() { return &namePool; } + ASTBuilder* getASTBuilder() { return m_astBuilder; } + + RefPtr<ASTBuilder> m_astBuilder; + // Modules that have been dynamically loaded via `import` // // This is a list of unique modules loaded, in the order they were encountered. @@ -2043,6 +2047,12 @@ namespace Slang Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } // + /// This AST Builder should only be used for creating AST nodes that are global across requests + /// not doing so could lead to memory being consumed but not used. + ASTBuilder* getGlobalASTBuilder() { return globalAstBuilder; } + + RefPtr<ASTBuilder> globalAstBuilder; + // Generated code for stdlib, etc. String stdlibPath; String coreLibraryCode; @@ -2054,72 +2064,8 @@ namespace Slang String getCoreLibraryCode(); String getHLSLLibraryCode(); - // Basic types that we don't want to re-create all the time - RefPtr<Type> errorType; - RefPtr<Type> initializerListType; - RefPtr<Type> overloadedType; - RefPtr<Type> constExprRate; - RefPtr<Type> irBasicBlockType; - - RefPtr<Type> stringType; - RefPtr<Type> enumTypeType; - - RefPtr<Type> builtinTypes[Index(BaseType::CountOf)]; - Dictionary<String, Decl*> magicDecls; - - void initializeTypes(); - - Type* getBoolType(); - Type* getHalfType(); - Type* getFloatType(); - Type* getDoubleType(); - Type* getIntType(); - Type* getInt64Type(); - Type* getUIntType(); - Type* getUInt64Type(); - Type* getVoidType(); - Type* getBuiltinType(BaseType flavor); - - Type* getInitializerListType(); - Type* getOverloadedType(); - Type* getErrorType(); - Type* getStringType(); - - Type* getEnumTypeType(); - - // Construct the type `Ptr<valueType>`, where `Ptr` - // is looked up as a builtin type. - RefPtr<PtrType> getPtrType(RefPtr<Type> valueType); - - // Construct the type `Out<valueType>` - RefPtr<OutType> getOutType(RefPtr<Type> valueType); - - // Construct the type `InOut<valueType>` - RefPtr<InOutType> getInOutType(RefPtr<Type> valueType); - - // Construct the type `Ref<valueType>` - RefPtr<RefType> getRefType(RefPtr<Type> valueType); - - // Construct a pointer type like `Ptr<valueType>`, but where - // the actual type name for the pointer type is given by `ptrTypeName` - RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName); - - // Construct a pointer type like `Ptr<valueType>`, but where - // the generic declaration for the pointer type is `genericDecl` - RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl); - - RefPtr<ArrayExpressionType> getArrayType( - Type* elementType, - IntVal* elementCount); - - RefPtr<VectorExpressionType> getVectorType( - RefPtr<Type> elementType, - RefPtr<IntVal> elementCount); - - SyntaxClass<RefObject> findSyntaxClass(Name* name); - - Dictionary<Name*, SyntaxClass<RefObject> > mapNameToSyntaxClass; + RefPtr<SharedASTBuilder> m_sharedASTBuilder; // cache used by type checking, implemented in check.cpp TypeCheckingCache* typeCheckingCache = nullptr; diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 1a1d2e621..b174d068c 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -173,7 +173,7 @@ LookupResultItem CreateLookupResultItem( } static void _lookUpMembersInValue( - Session* session, + ASTBuilder* astBuilder, Name* name, DeclRef<Decl> valueDeclRef, LookupRequest const& request, @@ -187,7 +187,7 @@ static void _lookUpMembersInValue( /// inheritance clauses, etc. /// static void _lookUpDirectAndTransparentMembers( - Session* session, + ASTBuilder* astBuilder, Name* name, DeclRef<ContainerDecl> containerDeclRef, LookupRequest const& request, @@ -236,7 +236,7 @@ static void _lookUpDirectAndTransparentMembers( memberRefBreadcrumb.prev = inBreadcrumbs; _lookUpMembersInValue( - session, + astBuilder, name, transparentMemberDeclRef, request, @@ -247,7 +247,7 @@ static void _lookUpDirectAndTransparentMembers( /// Perform "direct" lookup in a container declaration LookupResult lookUpDirectAndTransparentMembers( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, DeclRef<ContainerDecl> containerDeclRef, @@ -258,7 +258,7 @@ LookupResult lookUpDirectAndTransparentMembers( request.mask = mask; LookupResult result; _lookUpDirectAndTransparentMembers( - session, + astBuilder, name, containerDeclRef, request, @@ -269,6 +269,7 @@ LookupResult lookUpDirectAndTransparentMembers( static RefPtr<SubtypeWitness> _makeSubtypeWitness( + ASTBuilder* astBuilder, Type* subType, SubtypeWitness* subToMidWitness, Type* superType, @@ -276,7 +277,7 @@ static RefPtr<SubtypeWitness> _makeSubtypeWitness( { if(subToMidWitness) { - RefPtr<TransitiveSubtypeWitness> transitiveWitness = new TransitiveSubtypeWitness(); + RefPtr<TransitiveSubtypeWitness> transitiveWitness = astBuilder->create<TransitiveSubtypeWitness>(); transitiveWitness->subToMid = subToMidWitness; transitiveWitness->midToSup = midToSuperConstraint; transitiveWitness->sub = subType; @@ -285,7 +286,7 @@ static RefPtr<SubtypeWitness> _makeSubtypeWitness( } else { - RefPtr<DeclaredSubtypeWitness> declaredWitness = new DeclaredSubtypeWitness(); + RefPtr<DeclaredSubtypeWitness> declaredWitness = astBuilder->create<DeclaredSubtypeWitness>(); declaredWitness->declRef = midToSuperConstraint; declaredWitness->sub = subType; declaredWitness->sup = superType; @@ -295,7 +296,7 @@ static RefPtr<SubtypeWitness> _makeSubtypeWitness( // Same as the above, but we are specializing a type instead of a decl-ref static RefPtr<Type> _maybeSpecializeSuperType( - Session* session, + ASTBuilder* astBuilder, Type* superType, SubtypeWitness* subIsSuperWitness) { @@ -303,14 +304,14 @@ static RefPtr<Type> _maybeSpecializeSuperType( { if (auto superInterfaceDeclRef = superDeclRefType->declRef.as<InterfaceDecl>()) { - RefPtr<ThisTypeSubstitution> thisTypeSubst = new ThisTypeSubstitution(); + RefPtr<ThisTypeSubstitution> thisTypeSubst = astBuilder->create<ThisTypeSubstitution>(); thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); thisTypeSubst->witness = subIsSuperWitness; thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; auto specializedInterfaceDeclRef = DeclRef<Decl>(superInterfaceDeclRef.getDecl(), thisTypeSubst); - auto specializedInterfaceType = DeclRefType::Create(session, specializedInterfaceDeclRef); + auto specializedInterfaceType = DeclRefType::create(astBuilder, specializedInterfaceDeclRef); return specializedInterfaceType; } } @@ -319,7 +320,7 @@ static RefPtr<Type> _maybeSpecializeSuperType( } static void _lookUpMembersInType( - Session* session, + ASTBuilder* astBuilder, Name* name, RefPtr<Type> type, LookupRequest const& request, @@ -327,7 +328,7 @@ static void _lookUpMembersInType( BreadcrumbInfo* breadcrumbs); static void _lookUpMembersInSuperTypeImpl( - Session* session, + ASTBuilder* astBuilder, Name* name, Type* leafType, Type* superType, @@ -338,7 +339,7 @@ static void _lookUpMembersInSuperTypeImpl( static void _lookUpMembersInSuperType( - Session* session, + ASTBuilder* astBuilder, Name* name, Type* leafType, SubtypeWitness* leafIsIntermediateWitness, @@ -355,12 +356,13 @@ static void _lookUpMembersInSuperType( // The super-type in the constraint (e.g., `Foo` in `T : Foo`) // will tell us a type we should use for lookup. // - auto superType = GetSup(intermediateIsSuperConstraint); + auto superType = GetSup(astBuilder, intermediateIsSuperConstraint); // // We will go ahead and perform lookup using `superType`, // after dealing with some details. auto leafIsSuperWitness = _makeSubtypeWitness( + astBuilder, leafType, leafIsIntermediateWitness, superType, @@ -372,7 +374,7 @@ static void _lookUpMembersInSuperType( // be applied to any members we look up. // superType = _maybeSpecializeSuperType( - session, + astBuilder, superType, leafIsSuperWitness); @@ -394,11 +396,11 @@ static void _lookUpMembersInSuperType( // we might end up seeing the same interface via different "paths" and // we wouldn't want that to lead to overload-resolution failure. // - _lookUpMembersInSuperTypeImpl(session, name, leafType, superType, leafIsSuperWitness, request, ioResult, &breadcrumb); + _lookUpMembersInSuperTypeImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, request, ioResult, &breadcrumb); } static void _lookUpMembersInSuperTypeDeclImpl( - Session* session, + ASTBuilder* astBuilder, Name* name, Type* leafType, Type* superType, @@ -435,7 +437,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( // We want constraints of the form `T : Foo` where `T` is the // generic parameter in question, and `Foo` is whatever we are // constraining it to. - auto subType = GetSub(constraintDeclRef); + auto subType = GetSub(astBuilder, constraintDeclRef); auto subDeclRefType = as<DeclRefType>(subType); if(!subDeclRefType) continue; @@ -443,7 +445,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( continue; _lookUpMembersInSuperType( - session, + astBuilder, name, leafType, leafIsSuperWitness, @@ -458,7 +460,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(declRef.as<ContainerDecl>())) { _lookUpMembersInSuperType( - session, + astBuilder, name, leafType, leafIsSuperWitness, @@ -474,7 +476,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( // type or an `extension`, so the first thing to do is to look for // matching members declared directly in the body of the type/`extension`. // - _lookUpDirectAndTransparentMembers(session, name, aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs); + _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs); // There are further lookup steps that we can only perform when a // semantic checking context is available to us. That means that @@ -506,7 +508,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( // was found through an extension. // _lookUpMembersInSuperTypeDeclImpl( - session, + astBuilder, name, leafType, superType, @@ -525,14 +527,14 @@ static void _lookUpMembersInSuperTypeDeclImpl( for (auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclBaseRef)) { ensureDecl(semantics, inheritanceDeclRef.getDecl(), DeclCheckState::CanUseBaseOfInheritanceDecl); - _lookUpMembersInSuperType(session, name, leafType, leafIsSuperWitness, inheritanceDeclRef, request, ioResult, inBreadcrumbs); + _lookUpMembersInSuperType(astBuilder, name, leafType, leafIsSuperWitness, inheritanceDeclRef, request, ioResult, inBreadcrumbs); } } } } static void _lookUpMembersInSuperTypeImpl( - Session* session, + ASTBuilder* astBuilder, Name* name, Type* leafType, Type* superType, @@ -553,7 +555,7 @@ static void _lookUpMembersInSuperTypeImpl( // Recursively perform lookup on the result of deref _lookUpMembersInType( - session, + astBuilder, name, pointerLikeType->elementType, request, ioResult, &derefBreacrumb); return; } @@ -564,7 +566,7 @@ static void _lookUpMembersInSuperTypeImpl( { auto declRef = declRefType->declRef; - _lookUpMembersInSuperTypeDeclImpl(session, name, leafType, superType, leafIsSuperWitness, declRef, request, ioResult, inBreadcrumbs); + _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, declRef, request, ioResult, inBreadcrumbs); } } @@ -580,7 +582,7 @@ static void _lookUpMembersInSuperTypeImpl( /// set of members visible on `type`. /// static void _lookUpMembersInType( - Session* session, + ASTBuilder* astBuilder, Name* name, RefPtr<Type> type, LookupRequest const& request, @@ -592,7 +594,7 @@ static void _lookUpMembersInType( return; } - _lookUpMembersInSuperTypeImpl(session, name, type, type, nullptr, request, ioResult, breadcrumbs); + _lookUpMembersInSuperTypeImpl(astBuilder, name, type, type, nullptr, request, ioResult, breadcrumbs); } /// Look up members by `name` in the given `valueDeclRef`. @@ -602,7 +604,7 @@ static void _lookUpMembersInType( /// kind of lookup we'd expect for `valueDeclRef.<name>`. /// static void _lookUpMembersInValue( - Session* session, + ASTBuilder* astBuilder, Name* name, DeclRef<Decl> valueDeclRef, LookupRequest const& request, @@ -613,17 +615,13 @@ static void _lookUpMembersInValue( // be reduced to the problem of looking up `name` // in the *type* of that value. // - auto valueType = getTypeForDeclRef( - session, - valueDeclRef, - SourceLoc()); - return _lookUpMembersInType( - session, - name, valueType, request, ioResult, breadcrumbs); + auto valueType = getTypeForDeclRef(astBuilder, valueDeclRef, SourceLoc()); + + return _lookUpMembersInType(astBuilder, name, valueType, request, ioResult, breadcrumbs); } static void _lookUpInScopes( - Session* session, + ASTBuilder* astBuilder, Name* name, LookupRequest const& request, LookupResult& result) @@ -656,7 +654,7 @@ static void _lookUpInScopes( // just a decl. // DeclRef<ContainerDecl> containerDeclRef = - DeclRef<Decl>(containerDecl, createDefaultSubstitutions(session, containerDecl)).as<ContainerDecl>(); + DeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, containerDecl)).as<ContainerDecl>(); // If the container we are looking into represents a type // or an `extension` of a type, then we need to treat @@ -693,15 +691,15 @@ static void _lookUpInScopes( // declaration, then the `this` expression will have // a type that uses the "target type" of the `extension`. // - type = GetTargetType(extDeclRef); + type = GetTargetType(astBuilder, extDeclRef); } else { assert(aggTypeDeclBaseRef.as<AggTypeDecl>()); - type = DeclRefType::Create(session, aggTypeDeclBaseRef); + type = DeclRefType::create(astBuilder, aggTypeDeclBaseRef); } - _lookUpMembersInType(session, name, type, request, result, &breadcrumb); + _lookUpMembersInType(astBuilder, name, type, request, result, &breadcrumb); } else { @@ -709,7 +707,7 @@ static void _lookUpInScopes( // type or `extension` declaration, so we can look up members // in that scope much more simply. // - _lookUpDirectAndTransparentMembers(session, name, containerDeclRef, request, result, nullptr); + _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef, request, result, nullptr); } // Before we proceed up to the next outer scope to perform lookup @@ -779,7 +777,7 @@ static void _lookUpInScopes( } LookupResult lookUp( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, RefPtr<Scope> scope, @@ -791,12 +789,12 @@ LookupResult lookUp( request.mask = mask; LookupResult result; - _lookUpInScopes(session, name, request, result); + _lookUpInScopes(astBuilder, name, request, result); return result; } void lookUpMemberImpl( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, Type* type, @@ -805,7 +803,7 @@ void lookUpMemberImpl( LookupMask mask); LookupResult lookUpMember( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, Type* type, @@ -816,7 +814,7 @@ LookupResult lookUpMember( request.mask = mask; LookupResult result; - _lookUpMembersInType(session, name, type, request, result, nullptr); + _lookUpMembersInType(astBuilder, name, type, request, result, nullptr); return result; } diff --git a/source/slang/slang-lookup.h b/source/slang/slang-lookup.h index 0f44121c6..253355d7b 100644 --- a/source/slang/slang-lookup.h +++ b/source/slang/slang-lookup.h @@ -18,7 +18,7 @@ void buildMemberDictionary(ContainerDecl* decl); // Look up a name in the given scope, proceeding up through // parent scopes as needed. LookupResult lookUp( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, RefPtr<Scope> scope, @@ -26,7 +26,7 @@ LookupResult lookUp( // Perform member lookup in the context of a type LookupResult lookUpMember( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, Type* type, @@ -34,7 +34,7 @@ LookupResult lookUpMember( /// Perform "direct" lookup in a container declaration LookupResult lookUpDirectAndTransparentMembers( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, DeclRef<ContainerDecl> containerDeclRef, @@ -43,7 +43,7 @@ LookupResult lookUpDirectAndTransparentMembers( // TODO: this belongs somewhere else QualType getTypeForDeclRef( - Session* session, + ASTBuilder* astBuilder, SemanticsVisitor* sema, DiagnosticSink* sink, DeclRef<Decl> declRef, @@ -51,7 +51,7 @@ QualType getTypeForDeclRef( SourceLoc loc); QualType getTypeForDeclRef( - Session* session, + ASTBuilder* astBuilder, DeclRef<Decl> declRef, SourceLoc loc); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5b4694d36..f612e4ade 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -360,6 +360,8 @@ struct SharedIRGenContext struct IRGenContext { + ASTBuilder* astBuilder; + // Shared state for the IR generation process SharedIRGenContext* shared; @@ -378,8 +380,9 @@ struct IRGenContext // might be insufficient. LoweredValInfo thisVal; - explicit IRGenContext(SharedIRGenContext* inShared) + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared) + , astBuilder(inAstBuilder) , env(&inShared->globalEnv) , irBuilder(nullptr) {} @@ -969,7 +972,7 @@ static void addLinkageDecoration( IRInst* inst, Decl* decl) { - String mangledName = getMangledName(decl); + String mangledName = getMangledName(context->astBuilder, decl); if (context->shared->m_obfuscateCode) { @@ -1025,7 +1028,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val) { return emitDeclRef(context, val->declRef, - lowerType(context, GetType(val->declRef))); + lowerType(context, GetType(context->astBuilder, val->declRef))); } LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) @@ -1176,14 +1179,14 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // for emitting the signature of a `CallableDecl`, and we should // try to re-use that if at all possible. // - auto irParamType = lowerType(context, GetType(paramDeclRef)); + auto irParamType = lowerType(context, GetType(context->astBuilder, paramDeclRef)); auto irParam = subBuilder->emitParam(irParamType); irParams.add(irParam); irParamTypes.add(irParamType); } - auto irResultType = lowerType(context, GetResultType(callableDeclRef)); + auto irResultType = lowerType(context, GetResultType(context->astBuilder, callableDeclRef)); auto irFuncType = subBuilder->getFuncType( irParamTypes, @@ -1481,7 +1484,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower IRType* visitExtractExistentialType(ExtractExistentialType* type) { auto declRef = type->declRef; - auto existentialType = lowerType(context, GetType(declRef)); + auto existentialType = lowerType(context, GetType(context->astBuilder, declRef)); IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); return getBuilder()->emitExtractExistentialType(existentialVal); } @@ -1489,7 +1492,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitExtractExistentialSubtypeWitness(ExtractExistentialSubtypeWitness* witness) { auto declRef = witness->declRef; - auto existentialType = lowerType(context, GetType(declRef)); + auto existentialType = lowerType(context, GetType(context->astBuilder, declRef)); IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal)); } @@ -1520,7 +1523,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // getBuilder()->addExportDecoration( irType, - getMangledTypeName(type).getUnownedSlice()); + getMangledTypeName(context->astBuilder, type).getUnownedSlice()); } return LoweredValInfo::simple(irType); } @@ -1922,7 +1925,7 @@ DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, Decl* d { DeclRef<Decl> declRef; declRef.decl = decl; - declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl); + declRef.substitutions = createDefaultSubstitutions(context->astBuilder, decl); return declRef; } // @@ -1950,11 +1953,11 @@ RefPtr<Type> getThisParamTypeForContainer( { if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() ) { - return DeclRefType::Create(context->getSession(), aggTypeDeclRef); + return DeclRefType::create(context->astBuilder, aggTypeDeclRef); } else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() ) { - return GetTargetType(extensionDeclRef); + return GetTargetType(context->astBuilder, extensionDeclRef); } return nullptr; @@ -1984,6 +1987,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> IRGenContext* context; IRBuilder* getBuilder() { return context->irBuilder; } + ASTBuilder* getASTBuilder() { return context->astBuilder; } // Lower an expression that should have the same l-value-ness // as the visitor itself. @@ -2514,7 +2518,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef)) { auto paramDecl = paramDeclRef.getDecl(); - IRType* paramType = lowerType(context, GetType(paramDeclRef)); + IRType* paramType = lowerType(context, GetType(getASTBuilder(), paramDeclRef)); auto paramDirection = getParameterDirection(paramDecl); UInt argIndex = argCounter++; @@ -2529,7 +2533,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // but there are still parameters remaining. This must mean // that these parameters have default argument expressions // associated with them. - argExpr = getInitExpr(paramDeclRef); + argExpr = getInitExpr(getASTBuilder(), paramDeclRef); // Assert that such an expression must have been present. SLANG_ASSERT(argExpr); @@ -2879,7 +2883,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) { - auto existentialType = lowerType(context, GetType(expr->declRef)); + auto existentialType = lowerType(context, GetType(getASTBuilder(), expr->declRef)); auto existentialVal = getSimpleVal(context, emitDeclRef(context, expr->declRef, existentialType)); auto openedType = lowerType(context, expr->type); @@ -4461,9 +4465,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } else { - subType = DeclRefType::Create( - context->getSession(), - makeDeclRef(parentDecl)); + subType = DeclRefType::create(context->astBuilder, makeDeclRef(parentDecl)); } // What is the super-type that we have declared we inherit from? @@ -4473,7 +4475,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // on the type that is conforming, and the type that it conforms to. // // TODO: This approach doesn't really make sense for generic `extension` conformances. - auto mangledName = getMangledNameForConformanceWitness(subType, superType); + auto mangledName = getMangledNameForConformanceWitness(context->astBuilder, subType, superType); // A witness table may need to be generic, if the outer // declaration (either a type declaration or an `extension`) @@ -6657,9 +6659,7 @@ static void lowerProgramEntryPointToIR( // First, lower the entry point like an ordinary function - - auto session = context->getSession(); - auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); + auto entryPointFuncType = lowerType(context, getFuncType(context->astBuilder, entryPointFuncDeclRef)); auto builder = context->irBuilder; builder->setInsertInto(builder->getModule()->getModuleInst()); @@ -6669,7 +6669,7 @@ static void lowerProgramEntryPointToIR( if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) { - builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); + builder->addExportDecoration(loweredEntryPointFunc, getMangledName(context->astBuilder, entryPointFuncDeclRef).getUnownedSlice()); } // We may have shader parameters of interface/existential type, @@ -6725,6 +6725,7 @@ static void ensureAllDeclsRec( } IRModule* generateIRForTranslationUnit( + ASTBuilder* astBuilder, TranslationUnitRequest* translationUnit) { auto compileRequest = translationUnit->compileRequest; @@ -6736,7 +6737,7 @@ IRModule* generateIRForTranslationUnit( translationUnit->getModuleDecl()); SharedIRGenContext* sharedContext = &sharedContextStorage; - IRGenContext contextStorage(sharedContext); + IRGenContext contextStorage(sharedContext, astBuilder); IRGenContext* context = &contextStorage; SharedIRBuilder sharedBuilderStorage; @@ -6926,7 +6927,7 @@ struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor ); SharedIRGenContext* sharedContext = &sharedContextStorage; - IRGenContext contextStorage(sharedContext); + IRGenContext contextStorage(sharedContext, linkage->getASTBuilder()); context = &contextStorage; SharedIRBuilder sharedBuilderStorage; @@ -7034,8 +7035,8 @@ RefPtr<IRModule> TargetProgram::getOrCreateIRModuleForLayout(DiagnosticSink* sin /// Specialized IR generation context for when generating IR for layouts. struct IRLayoutGenContext : IRGenContext { - IRLayoutGenContext(SharedIRGenContext* shared) - : IRGenContext(shared) + IRLayoutGenContext(SharedIRGenContext* shared, ASTBuilder* astBuilder) + : IRGenContext(shared, astBuilder) {} /// Cache for custom key instructions used for entry-point parameter layout information. @@ -7325,7 +7326,9 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink) linkage->m_obfuscateCode); auto sharedContext = &sharedContextStorage; - IRLayoutGenContext contextStorage(sharedContext); + ASTBuilder* astBuilder = linkage->getASTBuilder(); + + IRLayoutGenContext contextStorage(sharedContext, astBuilder); auto context = &contextStorage; SharedIRBuilder sharedBuilderStorage; @@ -7375,12 +7378,12 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink) if(!funcDeclRef) continue; - auto irFuncType = lowerType(context, getFuncType(session, funcDeclRef)); + auto irFuncType = lowerType(context, getFuncType(astBuilder, funcDeclRef)); auto irFunc = getSimpleVal(context, emitDeclRef(context, funcDeclRef, irFuncType)); if( !irFunc->findDecoration<IRLinkageDecoration>() ) { - builder->addImportDecoration(irFunc, getMangledName(funcDeclRef).getUnownedSlice()); + builder->addImportDecoration(irFunc, getMangledName(astBuilder, funcDeclRef).getUnownedSlice()); } auto irEntryPointLayout = lowerEntryPointLayout(context, entryPointLayout); diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h index 33dfd9d27..dbf2e550a 100644 --- a/source/slang/slang-lower-to-ir.h +++ b/source/slang/slang-lower-to-ir.h @@ -27,6 +27,7 @@ namespace Slang /// that are imported before code generation can be performed. /// IRModule* generateIRForTranslationUnit( + ASTBuilder* astBuilder, TranslationUnitRequest* translationUnit); /// Generate an IR module to represent the specializations applied by `componentType`. diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 0ecc12b45..5e141228d 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -7,6 +7,11 @@ namespace Slang { struct ManglingContext { + ManglingContext(ASTBuilder* inAstBuilder): + astBuilder(inAstBuilder) + { + } + ASTBuilder* astBuilder; StringBuilder sb; }; @@ -130,7 +135,7 @@ namespace Slang } else if( auto namedType = dynamicCast<NamedExpressionType>(type) ) { - emitType(context, GetType(namedType->declRef)); + emitType(context, GetType(context->astBuilder, namedType->declRef)); } else if( auto declRefType = dynamicCast<DeclRefType>(type) ) { @@ -246,7 +251,7 @@ namespace Slang if(auto inheritanceDeclRef = declRef.as<InheritanceDecl>()) { emit(context, "I"); - emitType(context, GetSup(inheritanceDeclRef)); + emitType(context, GetSup(context->astBuilder, inheritanceDeclRef)); return; } @@ -259,7 +264,7 @@ namespace Slang // that is in the same module as the type it extends should // be treated as equivalent to the type itself. emit(context, "X"); - emitType(context, GetTargetType(extensionDeclRef)); + emitType(context, GetTargetType(context->astBuilder, extensionDeclRef)); return; } @@ -334,7 +339,7 @@ namespace Slang else if(auto genericValueParamDecl = mm.as<GenericValueParamDecl>()) { emitRaw(context, "v"); - emitType(context, GetType(genericValueParamDecl)); + emitType(context, GetType(context->astBuilder, genericValueParamDecl)); } else if(mm.as<GenericTypeConstraintDecl>()) { @@ -366,14 +371,14 @@ namespace Slang for(auto paramDeclRef : parameters) { - emitType(context, GetType(paramDeclRef)); + emitType(context, GetType(context->astBuilder, paramDeclRef)); } // Don't print result type for an initializer/constructor, // since it is implicit in the qualified name. if (!callableDeclRef.is<ConstructorDecl>()) { - emitType(context, GetResultType(callableDeclRef)); + emitType(context, GetResultType(context->astBuilder, callableDeclRef)); } } } @@ -419,29 +424,30 @@ namespace Slang emitQualifiedName(context, declRef); } - String getMangledName(DeclRef<Decl> const& declRef) + String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const& declRef) { - ManglingContext context; + ManglingContext context(astBuilder); mangleName(&context, declRef); return context.sb.ProduceString(); } - String getMangledName(DeclRefBase const & declRef) + String getMangledName(ASTBuilder* astBuilder, DeclRefBase const & declRef) { - return getMangledName( + return getMangledName(astBuilder, DeclRef<Decl>(declRef.decl, declRef.substitutions)); } - String getMangledName(Decl* decl) + String getMangledName(ASTBuilder* astBuilder, Decl* decl) { - return getMangledName(makeDeclRef(decl)); + return getMangledName(astBuilder, makeDeclRef(decl)); } String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, DeclRef<Decl> sub, DeclRef<Decl> sup) { - ManglingContext context; + ManglingContext context(astBuilder); emitRaw(&context, "_SW"); emitQualifiedName(&context, sub); emitQualifiedName(&context, sup); @@ -449,6 +455,7 @@ namespace Slang } String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, DeclRef<Decl> sub, Type* sup) { @@ -457,7 +464,7 @@ namespace Slang // // {Conforms(sub,sup)} => _SW{sub}{sup} // - ManglingContext context; + ManglingContext context(astBuilder); emitRaw(&context, "_SW"); emitQualifiedName(&context, sub); emitType(&context, sup); @@ -465,6 +472,7 @@ namespace Slang } String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, Type* sub, Type* sup) { @@ -473,16 +481,16 @@ namespace Slang // // {Conforms(sub,sup)} => _SW{sub}{sup} // - ManglingContext context; + ManglingContext context(astBuilder); emitRaw(&context, "_SW"); emitType(&context, sub); emitType(&context, sup); return context.sb.ProduceString(); } - String getMangledTypeName(Type* type) + String getMangledTypeName(ASTBuilder* astBuilder, Type* type) { - ManglingContext context; + ManglingContext context(astBuilder); emitType(&context, type); return context.sb.ProduceString(); } diff --git a/source/slang/slang-mangle.h b/source/slang/slang-mangle.h index 186f8bae4..e579ebfda 100644 --- a/source/slang/slang-mangle.h +++ b/source/slang/slang-mangle.h @@ -10,22 +10,27 @@ namespace Slang { struct IRSpecialize; - String getMangledName(Decl* decl); - String getMangledName(DeclRef<Decl> const & declRef); - String getMangledName(DeclRefBase const & declRef); + String getMangledName(ASTBuilder* astBuilder, Decl* decl); + String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const & declRef); + String getMangledName(ASTBuilder* astBuilder, DeclRefBase const & declRef); String getHashedName(const UnownedStringSlice& mangledName); String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, Type* sub, Type* sup); String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, DeclRef<Decl> sub, DeclRef<Decl> sup); String getMangledNameForConformanceWitness( + ASTBuilder* astBuilder, DeclRef<Decl> sub, Type* sup); - String getMangledTypeName(Type* type); + String getMangledTypeName( + ASTBuilder* astBuilder, + Type* type); } #endif diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 39cf16229..c4f7dc9d5 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -407,6 +407,8 @@ struct ParameterBindingContext TargetRequest* getTargetRequest() { return shared->getTargetRequest(); } LayoutRulesFamilyImpl* getRulesFamily() { return layoutContext.getRulesFamily(); } + ASTBuilder* getASTBuilder() { return shared->getLinkage()->getASTBuilder(); } + Linkage* getLinkage() { return shared->getLinkage(); } }; @@ -700,10 +702,12 @@ static void collectGlobalScopeParameter( ShaderParamInfo const& shaderParamInfo, SubstitutionSet globalGenericSubst) { + auto astBuilder = context->getASTBuilder(); + auto varDeclRef = shaderParamInfo.paramDeclRef; // We apply any substitutions for global generic parameters here. - auto type = GetType(varDeclRef)->substitute(globalGenericSubst).as<Type>(); + auto type = GetType(astBuilder, varDeclRef)->substitute(astBuilder, globalGenericSubst).as<Type>(); // We use a single operation to both check whether the // variable represents a shader parameter, and to compute @@ -1922,7 +1926,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( auto fieldTypeLayout = processEntryPointVaryingParameterDecl( context, field.getDecl(), - GetType(field), + GetType(context->getASTBuilder(), field), state, fieldVarLayout); @@ -2039,7 +2043,7 @@ static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( RefPtr<VarLayout> paramVarLayout, EntryPointParameterState& state) { - auto paramType = GetType(paramDeclRef); + auto paramType = GetType(context->getASTBuilder(), paramDeclRef); SLANG_ASSERT(paramType); if( paramDeclRef.getDecl()->hasModifier<HLSLUniformModifier>() ) @@ -2431,6 +2435,8 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) { + auto astBuilder = context->getASTBuilder(); + // We will take responsibility for creating and filling in // the `EntryPointLayout` object here. // @@ -2471,7 +2477,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( if(specializationInfo) entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef; - auto entryPointType = DeclRefType::Create(context->getLinkage()->getSessionImpl(), entryPointFuncDeclRef); + auto entryPointType = DeclRefType::create(astBuilder, entryPointFuncDeclRef); entryPointLayout->entryPoint = entryPointFuncDeclRef; @@ -2575,10 +2581,10 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( // TODO: Ideally we should make the layout process more robust to empty/void // types and apply this logic unconditionally. // - auto resultType = GetResultType(entryPointFuncDeclRef); + auto resultType = GetResultType(astBuilder, entryPointFuncDeclRef); SLANG_ASSERT(resultType); - if( !resultType->equals(resultType->getSession()->getVoidType()) ) + if( !resultType->equals(astBuilder->getVoidType()) ) { state.loc = entryPointFuncDeclRef.getLoc(); state.directionMask = kEntryPointParameterDirection_Output; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index f8622964f..fec7147b5 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -84,6 +84,7 @@ namespace Slang public: NamePool* namePool; SourceLanguage sourceLanguage; + ASTBuilder* astBuilder; NamePool* getNamePool() { return namePool; } SourceLanguage getSourceLanguage() { return sourceLanguage; } @@ -131,19 +132,18 @@ namespace Slang currentScope = currentScope->parent; } Parser( - Session* session, + ASTBuilder* inAstBuilder, TokenSpan const& _tokens, DiagnosticSink * sink, RefPtr<Scope> const& outerScope) : tokenReader(_tokens) + , astBuilder(inAstBuilder) , sink(sink) , outerScope(outerScope) - , m_session(session) {} Parser(const Parser & other) = default; - Session* m_session = nullptr; - Session* getSession() { return m_session; } + //Session* getSession() { return m_session; } Token ReadToken(); Token ReadToken(TokenType type); @@ -588,7 +588,7 @@ namespace Slang RefPtr<RefObject> ParseTypeDef(Parser* parser, void* /*userData*/) { - RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl(); + RefPtr<TypeDefDecl> typeDefDecl = parser->astBuilder->create<TypeDefDecl>(); // TODO(tfoley): parse an actual declarator auto type = parser->ParseTypeExp(); @@ -725,7 +725,7 @@ namespace Slang Token nameToken = parseAttributeName(parser); - RefPtr<UncheckedAttribute> modifier = new UncheckedAttribute(); + RefPtr<UncheckedAttribute> modifier = parser->astBuilder->create<UncheckedAttribute>(); modifier->name = nameToken.getName(); modifier->loc = nameToken.getLoc(); modifier->scope = parser->currentScope; @@ -786,7 +786,7 @@ namespace Slang // Let's look up the name and see what we find. auto lookupResult = lookUp( - parser->getSession(), + parser->astBuilder, nullptr, // no semantics visitor available yet name, parser->currentScope); @@ -931,7 +931,7 @@ namespace Slang { parser->haveSeenAnyImportDecls = true; - auto decl = new ImportDecl(); + auto decl = parser->astBuilder->create<ImportDecl>(); decl->scope = parser->currentScope; if (peekTokenType(parser) == TokenType::StringLiteral) @@ -1099,7 +1099,7 @@ namespace Slang if (AdvanceIf(parser, "let")) { // default case is a type parameter - auto paramDecl = new GenericValueParamDecl(); + auto paramDecl = parser->astBuilder->create<GenericValueParamDecl>(); paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); if (AdvanceIf(parser, TokenType::Colon)) { @@ -1114,24 +1114,24 @@ namespace Slang else { // default case is a type parameter - RefPtr<GenericTypeParamDecl> paramDecl = new GenericTypeParamDecl(); + RefPtr<GenericTypeParamDecl> paramDecl = parser->astBuilder->create<GenericTypeParamDecl>(); parser->FillPosition(paramDecl); paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); if (AdvanceIf(parser, TokenType::Colon)) { // The user is apply a constraint to this type parameter... - auto paramConstraint = new GenericTypeConstraintDecl(); + auto paramConstraint = parser->astBuilder->create<GenericTypeConstraintDecl>(); parser->FillPosition(paramConstraint); - auto paramType = DeclRefType::Create( - parser->getSession(), + auto paramType = DeclRefType::create( + parser->astBuilder, DeclRef<Decl>(paramDecl, nullptr)); - auto paramTypeExpr = new SharedTypeExpr(); + auto paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); paramTypeExpr->loc = paramDecl->loc; paramTypeExpr->base.type = paramType; - paramTypeExpr->type = QualType(getTypeType(paramType)); + paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType)); paramConstraint->sub = TypeExp(paramTypeExpr); paramConstraint->sup = parser->ParseTypeExp(); @@ -1184,7 +1184,7 @@ namespace Slang // TODO: may want more advanced disambiguation than this... if (parser->LookAheadToken(TokenType::OpLess)) { - RefPtr<GenericDecl> genericDecl = new GenericDecl(); + RefPtr<GenericDecl> genericDecl = parser->astBuilder->create<GenericDecl>(); parser->FillPosition(genericDecl); parser->PushScope(genericDecl); ParseGenericDeclImpl(parser, genericDecl, parseInner); @@ -1199,7 +1199,7 @@ namespace Slang static RefPtr<RefObject> ParseGenericDecl(Parser* parser, void*) { - RefPtr<GenericDecl> decl = new GenericDecl(); + RefPtr<GenericDecl> decl = parser->astBuilder->create<GenericDecl>(); parser->FillPosition(decl.Ptr()); parser->PushScope(decl.Ptr()); ParseGenericDeclImpl(parser, decl.Ptr(), [=](GenericDecl* genDecl) {return ParseSingleDecl(parser, genDecl); }); @@ -1288,7 +1288,7 @@ namespace Slang Parser* parser, DeclaratorInfo const& declaratorInfo) { - RefPtr<FuncDecl> decl = new FuncDecl(); + RefPtr<FuncDecl> decl = parser->astBuilder->create<FuncDecl>(); parser->FillPosition(decl.Ptr()); decl->loc = declaratorInfo.nameAndLoc.loc; decl->nameAndLoc = declaratorInfo.nameAndLoc; @@ -1325,19 +1325,20 @@ namespace Slang } static RefPtr<VarDeclBase> CreateVarDeclForContext( + ASTBuilder* astBuilder, ContainerDecl* containerDecl ) { if (as<CallableDecl>(containerDecl)) { // Function parameters always use their dedicated syntax class. // - return new ParamDecl(); + return astBuilder->create<ParamDecl>(); } else { // Globals, locals, and member variables all use the same syntax class. // - return new VarDecl(); + return astBuilder->create<VarDecl>(); } } @@ -1555,6 +1556,7 @@ namespace Slang } static void UnwrapDeclarator( + ASTBuilder* astBuilder, RefPtr<Declarator> declarator, DeclaratorInfo* ioInfo) { @@ -1586,7 +1588,7 @@ namespace Slang // TODO(tfoley): we don't support pointers for now auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr(); - auto arrayTypeExpr = new IndexExpr(); + auto arrayTypeExpr = astBuilder->create<IndexExpr>(); arrayTypeExpr->loc = arrayDeclarator->openBracketLoc; arrayTypeExpr->baseExpression = ioInfo->typeSpec; arrayTypeExpr->indexExpression = arrayDeclarator->elementCountExpr; @@ -1604,10 +1606,11 @@ namespace Slang } static void UnwrapDeclarator( + ASTBuilder* astBuilder, InitDeclarator const& initDeclarator, DeclaratorInfo* ioInfo) { - UnwrapDeclarator(initDeclarator.declarator, ioInfo); + UnwrapDeclarator(astBuilder, initDeclarator.declarator, ioInfo); ioInfo->semantics = initDeclarator.semantics; ioInfo->initializer = initDeclarator.initializer; } @@ -1618,6 +1621,7 @@ namespace Slang SourceLoc startPosition; RefPtr<Decl> decl; RefPtr<DeclGroup> group; + ASTBuilder* astBuilder = nullptr; // Add a new declaration to the potential group void addDecl( @@ -1627,7 +1631,7 @@ namespace Slang if( decl ) { - group = new DeclGroup(); + group = astBuilder->create<DeclGroup>(); group->loc = startPosition; group->decls.add(decl); decl = nullptr; @@ -1665,7 +1669,7 @@ namespace Slang // // TODO: do this better, e.g. by filling in the `declRef` field directly - auto expr = new VarExpr(); + auto expr = parser->astBuilder->create<VarExpr>(); expr->scope = parser->currentScope.Ptr(); expr->loc = decl->getNameLoc(); expr->name = decl->getName(); @@ -1687,7 +1691,7 @@ namespace Slang Parser* parser, RefPtr<Expr> base) { - RefPtr<GenericAppExpr> genericApp = new GenericAppExpr(); + RefPtr<GenericAppExpr> genericApp = parser->astBuilder->create<GenericAppExpr>(); parser->FillPosition(genericApp.Ptr()); // set up scope for lookup genericApp->functionExpr = base; @@ -1716,7 +1720,7 @@ namespace Slang static bool isGenericName(Parser* parser, Name* name) { auto lookupResult = lookUp( - parser->getSession(), + parser->astBuilder, nullptr, // no semantics visitor available yet name, parser->currentScope); @@ -1771,7 +1775,7 @@ namespace Slang { // When called the :: or . have been consumed, so don't need to consume here. - RefPtr<MemberExpr> memberExpr = new MemberExpr(); + RefPtr<MemberExpr> memberExpr = parser->astBuilder->create<MemberExpr>(); parser->FillPosition(memberExpr.Ptr()); memberExpr->baseExpression = base; @@ -1787,7 +1791,7 @@ namespace Slang auto typeExpr = inTypeExpr; while (parser->LookAheadToken(TokenType::LBracket)) { - RefPtr<IndexExpr> arrType = new IndexExpr(); + RefPtr<IndexExpr> arrType = parser->astBuilder->create<IndexExpr>(); arrType->loc = typeExpr->loc; arrType->baseExpression = typeExpr; parser->ReadToken(TokenType::LBracket); @@ -1803,7 +1807,7 @@ namespace Slang static RefPtr<Expr> parseTaggedUnionType(Parser* parser) { - RefPtr<TaggedUnionTypeExpr> taggedUnionType = new TaggedUnionTypeExpr(); + RefPtr<TaggedUnionTypeExpr> taggedUnionType = parser->astBuilder->create<TaggedUnionTypeExpr>(); parser->ReadToken(TokenType::LParent); while(!AdvanceIfMatch(parser, TokenType::RParent)) @@ -1828,7 +1832,7 @@ namespace Slang /// Parse a `This` type expression static RefPtr<Expr> parseThisTypeExpr(Parser* parser) { - RefPtr<ThisTypeExpr> expr = new ThisTypeExpr(); + RefPtr<ThisTypeExpr> expr = parser->astBuilder->create<ThisTypeExpr>(); expr->scope = parser->currentScope; return expr; } @@ -1900,7 +1904,7 @@ namespace Slang Token typeName = parser->ReadToken(TokenType::Identifier); - auto basicType = new VarExpr(); + auto basicType = parser->astBuilder->create<VarExpr>(); basicType->scope = parser->currentScope.Ptr(); basicType->loc = typeName.loc; basicType->name = typeName.getNameOrNull(); @@ -1945,6 +1949,7 @@ namespace Slang // declaration DeclGroupBuilder declGroupBuilder; declGroupBuilder.startPosition = startPosition; + declGroupBuilder.astBuilder = parser->astBuilder; // The type specifier may include a declaration. E.g., // it might declare a `struct` type. @@ -2030,7 +2035,7 @@ namespace Slang && !initDeclarator.semantics) { // Looks like a function, so parse it like one. - UnwrapDeclarator(initDeclarator, &declaratorInfo); + UnwrapDeclarator(parser->astBuilder, initDeclarator, &declaratorInfo); return parseTraditionalFuncDecl(parser, declaratorInfo); } @@ -2039,8 +2044,8 @@ namespace Slang if( AdvanceIf(parser, TokenType::Semicolon) ) { // easy case: we only had a single declaration! - UnwrapDeclarator(initDeclarator, &declaratorInfo); - RefPtr<VarDeclBase> firstDecl = CreateVarDeclForContext(containerDecl); + UnwrapDeclarator(parser->astBuilder, initDeclarator, &declaratorInfo); + RefPtr<VarDeclBase> firstDecl = CreateVarDeclForContext(parser->astBuilder, containerDecl); CompleteVarDecl(parser, firstDecl, declaratorInfo); declGroupBuilder.addDecl(firstDecl); @@ -2054,16 +2059,16 @@ namespace Slang // about it once, so we need to share structure rather than just // clone syntax. - auto sharedTypeSpec = new SharedTypeExpr(); + auto sharedTypeSpec = parser->astBuilder->create<SharedTypeExpr>(); sharedTypeSpec->loc = typeSpec.expr->loc; sharedTypeSpec->base = TypeExp(typeSpec.expr); for(;;) { declaratorInfo.typeSpec = sharedTypeSpec; - UnwrapDeclarator(initDeclarator, &declaratorInfo); + UnwrapDeclarator(parser->astBuilder, initDeclarator, &declaratorInfo); - RefPtr<VarDeclBase> varDecl = CreateVarDeclForContext(containerDecl); + RefPtr<VarDeclBase> varDecl = CreateVarDeclForContext(parser->astBuilder, containerDecl); CompleteVarDecl(parser, varDecl, declaratorInfo); declGroupBuilder.addDecl(varDecl); @@ -2182,21 +2187,21 @@ namespace Slang { if (parser->LookAheadToken("register")) { - RefPtr<HLSLRegisterSemantic> semantic = new HLSLRegisterSemantic(); + RefPtr<HLSLRegisterSemantic> semantic = parser->astBuilder->create<HLSLRegisterSemantic>(); parser->FillPosition(semantic); parseHLSLRegisterSemantic(parser, semantic.Ptr()); return semantic; } else if (parser->LookAheadToken("packoffset")) { - RefPtr<HLSLPackOffsetSemantic> semantic = new HLSLPackOffsetSemantic(); + RefPtr<HLSLPackOffsetSemantic> semantic = parser->astBuilder->create<HLSLPackOffsetSemantic>(); parser->FillPosition(semantic); parseHLSLPackOffsetSemantic(parser, semantic.Ptr()); return semantic; } else if (parser->LookAheadToken(TokenType::Identifier)) { - RefPtr<HLSLSimpleSemantic> semantic = new HLSLSimpleSemantic(); + RefPtr<HLSLSimpleSemantic> semantic = parser->astBuilder->create<HLSLSimpleSemantic>(); parser->FillPosition(semantic); semantic->name = parser->ReadToken(TokenType::Identifier); return semantic; @@ -2286,8 +2291,8 @@ namespace Slang // We are going to represent each buffer as a pair of declarations. // The first is a type declaration that holds all the members, while // the second is a variable declaration that uses the buffer type. - RefPtr<StructDecl> bufferDataTypeDecl = new StructDecl(); - RefPtr<VarDecl> bufferVarDecl = new VarDecl(); + RefPtr<StructDecl> bufferDataTypeDecl = parser->astBuilder->create<StructDecl>(); + RefPtr<VarDecl> bufferVarDecl = parser->astBuilder->create<VarDecl>(); // Both declarations will have a location that points to the name parser->FillPosition(bufferDataTypeDecl.Ptr()); @@ -2296,7 +2301,7 @@ namespace Slang auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); // Attach the reflection name to the block so we can use it - auto reflectionNameModifier = new ParameterGroupReflectionName(); + auto reflectionNameModifier = parser->astBuilder->create<ParameterGroupReflectionName>(); reflectionNameModifier->nameAndLoc = NameLoc(reflectionNameToken); addModifier(bufferVarDecl, reflectionNameModifier); @@ -2304,8 +2309,8 @@ namespace Slang bufferVarDecl->nameAndLoc.name = generateName(parser, "parameterGroup_" + String(reflectionNameToken.getContent())); bufferDataTypeDecl->nameAndLoc.name = generateName(parser, "ParameterGroup_" + String(reflectionNameToken.getContent())); - addModifier(bufferDataTypeDecl, new ImplicitParameterGroupElementTypeModifier()); - addModifier(bufferVarDecl, new ImplicitParameterGroupVariableModifier()); + addModifier(bufferDataTypeDecl, parser->astBuilder->create<ImplicitParameterGroupElementTypeModifier>()); + addModifier(bufferVarDecl, parser->astBuilder->create<ImplicitParameterGroupVariableModifier>()); // TODO(tfoley): We end up constructing unchecked syntax here that // is expected to type check into the right form, but it might be @@ -2313,13 +2318,13 @@ namespace Slang // these constructs directly into the AST and *then* desugar them. // Construct a type expression to reference the buffer data type - auto bufferDataTypeExpr = new VarExpr(); + auto bufferDataTypeExpr = parser->astBuilder->create<VarExpr>(); bufferDataTypeExpr->loc = bufferDataTypeDecl->loc; bufferDataTypeExpr->name = bufferDataTypeDecl->nameAndLoc.name; bufferDataTypeExpr->scope = parser->currentScope.Ptr(); // Construct a type expression to reference the type constructor - auto bufferWrapperTypeExpr = new VarExpr(); + auto bufferWrapperTypeExpr = parser->astBuilder->create<VarExpr>(); bufferWrapperTypeExpr->loc = bufferWrapperTypeNamePos; bufferWrapperTypeExpr->name = getName(parser, bufferWrapperTypeName); @@ -2329,7 +2334,7 @@ namespace Slang // Construct a type expression that represents the type for the variable, // which is the wrapper type applied to the data type - auto bufferVarTypeExpr = new GenericAppExpr(); + auto bufferVarTypeExpr = parser->astBuilder->create<GenericAppExpr>(); bufferVarTypeExpr->loc = bufferVarDecl->loc; bufferVarTypeExpr->functionExpr = bufferWrapperTypeExpr; bufferVarTypeExpr->arguments.add(bufferDataTypeExpr); @@ -2346,7 +2351,7 @@ namespace Slang // All HLSL buffer declarations are "transparent" in that their // members are implicitly made visible in the parent scope. // We achieve this by applying the transparent modifier to the variable. - auto transparentModifier = new TransparentModifier(); + auto transparentModifier = parser->astBuilder->create<TransparentModifier>(); transparentModifier->next = bufferVarDecl->modifiers.first; bufferVarDecl->modifiers.first = transparentModifier; @@ -2385,7 +2390,7 @@ namespace Slang { auto base = parser->ParseTypeExp(); - auto inheritanceDecl = new InheritanceDecl(); + auto inheritanceDecl = parser->astBuilder->create<InheritanceDecl>(); inheritanceDecl->loc = base.exp->loc; inheritanceDecl->nameAndLoc.name = getName(parser, "$inheritance"); inheritanceDecl->base = base; @@ -2398,7 +2403,7 @@ namespace Slang static RefPtr<RefObject> ParseExtensionDecl(Parser* parser, void* /*userData*/) { - RefPtr<ExtensionDecl> decl = new ExtensionDecl(); + RefPtr<ExtensionDecl> decl = parser->astBuilder->create<ExtensionDecl>(); parser->FillPosition(decl.Ptr()); decl->targetType = parser->ParseTypeExp(); parseOptionalInheritanceClause(parser, decl); @@ -2408,24 +2413,22 @@ namespace Slang } - void parseOptionalGenericConstraints(Parser * parser, ContainerDecl* decl) + void parseOptionalGenericConstraints(Parser* parser, ContainerDecl* decl) { if (AdvanceIf(parser, TokenType::Colon)) { do { - RefPtr<GenericTypeConstraintDecl> paramConstraint = new GenericTypeConstraintDecl(); + RefPtr<GenericTypeConstraintDecl> paramConstraint = parser->astBuilder->create<GenericTypeConstraintDecl>(); parser->FillPosition(paramConstraint); // substitution needs to be filled during check - RefPtr<DeclRefType> paramType = DeclRefType::Create( - parser->getSession(), - DeclRef<Decl>(decl, nullptr)); + RefPtr<DeclRefType> paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl, nullptr)); - RefPtr<SharedTypeExpr> paramTypeExpr = new SharedTypeExpr(); + RefPtr<SharedTypeExpr> paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); paramTypeExpr->loc = decl->loc; paramTypeExpr->base.type = paramType; - paramTypeExpr->type = QualType(getTypeType(paramType)); + paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType)); paramConstraint->sub = TypeExp(paramTypeExpr); paramConstraint->sup = parser->ParseTypeExp(); @@ -2435,9 +2438,9 @@ namespace Slang } } - RefPtr<RefObject> parseAssocType(Parser * parser, void *) + RefPtr<RefObject> parseAssocType(Parser* parser, void *) { - RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl(); + RefPtr<AssocTypeDecl> assocTypeDecl = parser->astBuilder->create<AssocTypeDecl>(); auto nameToken = parser->ReadToken(TokenType::Identifier); assocTypeDecl->nameAndLoc = NameLoc(nameToken); @@ -2449,7 +2452,7 @@ namespace Slang RefPtr<RefObject> parseGlobalGenericTypeParamDecl(Parser * parser, void *) { - RefPtr<GlobalGenericParamDecl> genParamDecl = new GlobalGenericParamDecl(); + RefPtr<GlobalGenericParamDecl> genParamDecl = parser->astBuilder->create<GlobalGenericParamDecl>(); auto nameToken = parser->ReadToken(TokenType::Identifier); genParamDecl->nameAndLoc = NameLoc(nameToken); genParamDecl->loc = nameToken.loc; @@ -2460,7 +2463,7 @@ namespace Slang RefPtr<RefObject> parseGlobalGenericValueParamDecl(Parser * parser, void *) { - RefPtr<GlobalGenericValueParamDecl> genericParamDecl = new GlobalGenericValueParamDecl(); + RefPtr<GlobalGenericValueParamDecl> genericParamDecl = parser->astBuilder->create<GlobalGenericValueParamDecl>(); auto nameToken = parser->ReadToken(TokenType::Identifier); genericParamDecl->nameAndLoc = NameLoc(nameToken); genericParamDecl->loc = nameToken.loc; @@ -2481,7 +2484,7 @@ namespace Slang static RefPtr<RefObject> parseInterfaceDecl(Parser* parser, void* /*userData*/) { - RefPtr<InterfaceDecl> decl = new InterfaceDecl(); + RefPtr<InterfaceDecl> decl = parser->astBuilder->create<InterfaceDecl>(); parser->FillPosition(decl.Ptr()); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); @@ -2588,7 +2591,7 @@ namespace Slang // if( !namespaceDecl ) { - namespaceDecl = new NamespaceDecl(); + namespaceDecl = parser->astBuilder->create<NamespaceDecl>(); namespaceDecl->nameAndLoc = nameAndLoc; // In the case where we are creating the first @@ -2613,7 +2616,7 @@ namespace Slang static RefPtr<RefObject> parseConstructorDecl(Parser* parser, void* /*userData*/) { - RefPtr<ConstructorDecl> decl = new ConstructorDecl(); + RefPtr<ConstructorDecl> decl = parser->astBuilder->create<ConstructorDecl>(); parser->FillPosition(decl.Ptr()); parser->PushScope(decl); @@ -2642,15 +2645,15 @@ namespace Slang RefPtr<AccessorDecl> decl; if( AdvanceIf(parser, "get") ) { - decl = new GetterDecl(); + decl = parser->astBuilder->create<GetterDecl>(); } else if( AdvanceIf(parser, "set") ) { - decl = new SetterDecl(); + decl = parser->astBuilder->create<SetterDecl>(); } else if( AdvanceIf(parser, "ref") ) { - decl = new RefAccessorDecl(); + decl = parser->astBuilder->create<RefAccessorDecl>(); } else { @@ -2674,7 +2677,7 @@ namespace Slang static RefPtr<RefObject> ParseSubscriptDecl(Parser* parser, void* /*userData*/) { - RefPtr<SubscriptDecl> decl = new SubscriptDecl(); + RefPtr<SubscriptDecl> decl = parser->astBuilder->create<SubscriptDecl>(); parser->FillPosition(decl.Ptr()); parser->PushScope(decl); @@ -2742,7 +2745,7 @@ namespace Slang static RefPtr<RefObject> parseLetDecl( Parser* parser, void* /*userData*/) { - RefPtr<LetDecl> decl = new LetDecl(); + RefPtr<LetDecl> decl = parser->astBuilder->create<LetDecl>(); parseModernVarDeclCommon(parser, decl); return decl; } @@ -2750,7 +2753,7 @@ namespace Slang static RefPtr<RefObject> parseVarDecl( Parser* parser, void* /*userData*/) { - RefPtr<VarDecl> decl = new VarDecl(); + RefPtr<VarDecl> decl = parser->astBuilder->create<VarDecl>(); parseModernVarDeclCommon(parser, decl); return decl; } @@ -2758,7 +2761,7 @@ namespace Slang static RefPtr<ParamDecl> parseModernParamDecl( Parser* parser) { - RefPtr<ParamDecl> decl = new ParamDecl(); + RefPtr<ParamDecl> decl = parser->astBuilder->create<ParamDecl>(); // TODO: "modern" parameters should not accept keyword-based // modifiers and should only accept `[attribute]` syntax for @@ -2790,7 +2793,7 @@ namespace Slang static RefPtr<RefObject> parseFuncDecl( Parser* parser, void* /*userData*/) { - RefPtr<FuncDecl> decl = new FuncDecl(); + RefPtr<FuncDecl> decl = parser->astBuilder->create<FuncDecl>(); parser->FillPosition(decl.Ptr()); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); @@ -2812,7 +2815,7 @@ namespace Slang static RefPtr<RefObject> parseTypeAliasDecl( Parser* parser, void* /*userData*/) { - RefPtr<TypeAliasDecl> decl = new TypeAliasDecl(); + RefPtr<TypeAliasDecl> decl = parser->astBuilder->create<TypeAliasDecl>(); parser->FillPosition(decl.Ptr()); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); @@ -2860,7 +2863,7 @@ namespace Slang // User is specifying the class that should be construted auto classNameAndLoc = expectIdentifier(parser); - syntaxClass = parser->getSession()->findSyntaxClass(classNameAndLoc.name); + syntaxClass = parser->astBuilder->findSyntaxClass(classNameAndLoc.name); } // If the user specified a syntax class, then we will default @@ -2911,7 +2914,7 @@ namespace Slang // TODO: skip creating the declaration if anything failed, just to not screw things // up for downstream code? - RefPtr<SyntaxDecl> syntaxDecl = new SyntaxDecl(); + RefPtr<SyntaxDecl> syntaxDecl = parser->astBuilder->create<SyntaxDecl>(); syntaxDecl->nameAndLoc = nameAndLoc; syntaxDecl->loc = nameAndLoc.loc; syntaxDecl->syntaxClass = syntaxClass; @@ -2929,7 +2932,7 @@ namespace Slang { auto nameAndLoc = expectIdentifier(parser); - RefPtr<ParamDecl> paramDecl = new ParamDecl(); + RefPtr<ParamDecl> paramDecl = parser->astBuilder->create<ParamDecl>(); paramDecl->nameAndLoc = nameAndLoc; if(AdvanceIf(parser, TokenType::Colon)) @@ -2971,7 +2974,7 @@ namespace Slang // First we parse the attribute name. auto nameAndLoc = expectIdentifier(parser); - RefPtr<AttributeDecl> attrDecl = new AttributeDecl(); + RefPtr<AttributeDecl> attrDecl = parser->astBuilder->create<AttributeDecl>(); if(AdvanceIf(parser, TokenType::LParent)) { while(!AdvanceIfMatch(parser, TokenType::RParent)) @@ -2999,7 +3002,7 @@ namespace Slang // User is specifying the class that should be construted auto classNameAndLoc = expectIdentifier(parser); - syntaxClass = parser->getSession()->findSyntaxClass(classNameAndLoc.name); + syntaxClass = parser->astBuilder->findSyntaxClass(classNameAndLoc.name); } else { @@ -3090,7 +3093,7 @@ namespace Slang { advanceToken(parser); - decl = new EmptyDecl(); + decl = parser->astBuilder->create<EmptyDecl>(); decl->loc = loc; } break; @@ -3113,7 +3116,7 @@ namespace Slang // so we want to give later passes a way to detect which modifiers // were shared, vs. which ones are specific to a single declaration. - auto sharedModifiers = new SharedModifiers(); + auto sharedModifiers = parser->astBuilder->create<SharedModifiers>(); sharedModifiers->next = modifiers.first; modifiers.first = sharedModifiers; @@ -3217,7 +3220,7 @@ namespace Slang RefPtr<Decl> Parser::ParseStruct() { - RefPtr<StructDecl> rs = new StructDecl(); + RefPtr<StructDecl> rs = astBuilder->create<StructDecl>(); FillPosition(rs.Ptr()); ReadToken("struct"); @@ -3236,7 +3239,7 @@ namespace Slang RefPtr<ClassDecl> Parser::ParseClass() { - RefPtr<ClassDecl> rs = new ClassDecl(); + RefPtr<ClassDecl> rs = astBuilder->create<ClassDecl>(); FillPosition(rs.Ptr()); ReadToken("class"); rs->nameAndLoc = expectIdentifier(this); @@ -3249,7 +3252,7 @@ namespace Slang static RefPtr<EnumCaseDecl> parseEnumCaseDecl(Parser* parser) { - RefPtr<EnumCaseDecl> decl = new EnumCaseDecl(); + RefPtr<EnumCaseDecl> decl = parser->astBuilder->create<EnumCaseDecl>(); decl->nameAndLoc = expectIdentifier(parser); if(AdvanceIf(parser, TokenType::OpAssign)) @@ -3262,7 +3265,7 @@ namespace Slang static RefPtr<Decl> parseEnumDecl(Parser* parser) { - RefPtr<EnumDecl> decl = new EnumDecl(); + RefPtr<EnumDecl> decl = parser->astBuilder->create<EnumDecl>(); parser->FillPosition(decl); parser->ReadToken("enum"); @@ -3299,7 +3302,7 @@ namespace Slang static RefPtr<Stmt> ParseSwitchStmt(Parser* parser) { - RefPtr<SwitchStmt> stmt = new SwitchStmt(); + RefPtr<SwitchStmt> stmt = parser->astBuilder->create<SwitchStmt>(); parser->FillPosition(stmt.Ptr()); parser->ReadToken("switch"); parser->ReadToken(TokenType::LParent); @@ -3311,7 +3314,7 @@ namespace Slang static RefPtr<Stmt> ParseCaseStmt(Parser* parser) { - RefPtr<CaseStmt> stmt = new CaseStmt(); + RefPtr<CaseStmt> stmt = parser->astBuilder->create<CaseStmt>(); parser->FillPosition(stmt.Ptr()); parser->ReadToken("case"); stmt->expr = parser->ParseExpression(); @@ -3321,7 +3324,7 @@ namespace Slang static RefPtr<Stmt> ParseDefaultStmt(Parser* parser) { - RefPtr<DefaultStmt> stmt = new DefaultStmt(); + RefPtr<DefaultStmt> stmt = parser->astBuilder->create<DefaultStmt>(); parser->FillPosition(stmt.Ptr()); parser->ReadToken("default"); parser->ReadToken(TokenType::Colon); @@ -3331,7 +3334,7 @@ namespace Slang static bool isTypeName(Parser* parser, Name* name) { auto lookupResult = lookUp( - parser->getSession(), + parser->astBuilder, nullptr, // no semantics visitor available yet name, parser->currentScope); @@ -3365,8 +3368,8 @@ namespace Slang RefPtr<Stmt> parseCompileTimeForStmt( Parser* parser) { - RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); - RefPtr<CompileTimeForStmt> stmt = new CompileTimeForStmt(); + RefPtr<ScopeDecl> scopeDecl = parser->astBuilder->create<ScopeDecl>(); + RefPtr<CompileTimeForStmt> stmt = parser->astBuilder->create<CompileTimeForStmt>(); stmt->scopeDecl = scopeDecl; @@ -3374,7 +3377,7 @@ namespace Slang parser->ReadToken(TokenType::LParent); NameLoc varNameAndLoc = expectIdentifier(parser); - RefPtr<VarDecl> varDecl = new VarDecl(); + RefPtr<VarDecl> varDecl = parser->astBuilder->create<VarDecl>(); varDecl->nameAndLoc = varNameAndLoc; varDecl->loc = varNameAndLoc.loc; @@ -3446,7 +3449,7 @@ namespace Slang statement = ParseReturnStatement(); else if (LookAheadToken("discard")) { - statement = new DiscardStmt(); + statement = astBuilder->create<DiscardStmt>(); FillPosition(statement.Ptr()); ReadToken("discard"); ReadToken(TokenType::Semicolon); @@ -3543,7 +3546,7 @@ namespace Slang } else if (LookAheadToken(TokenType::Semicolon)) { - statement = new EmptyStmt(); + statement = astBuilder->create<EmptyStmt>(); FillPosition(statement.Ptr()); ReadToken(TokenType::Semicolon); } @@ -3568,8 +3571,8 @@ namespace Slang RefPtr<Stmt> Parser::parseBlockStatement() { - RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); - RefPtr<BlockStmt> blockStatement = new BlockStmt(); + RefPtr<ScopeDecl> scopeDecl = astBuilder->create<ScopeDecl>(); + RefPtr<BlockStmt> blockStatement = astBuilder->create<BlockStmt>(); blockStatement->scopeDecl = scopeDecl; pushScopeAndSetParent(scopeDecl.Ptr()); ReadToken(TokenType::LBrace); @@ -3595,7 +3598,7 @@ namespace Slang } else { - RefPtr<SeqStmt> newBody = new SeqStmt(); + RefPtr<SeqStmt> newBody = astBuilder->create<SeqStmt>(); newBody->loc = blockStatement->loc; newBody->stmts.add(body); newBody->stmts.add(stmt); @@ -3609,7 +3612,7 @@ namespace Slang if(!body) { - body = new EmptyStmt(); + body = astBuilder->create<EmptyStmt>(); body->loc = blockStatement->loc; } @@ -3620,7 +3623,7 @@ namespace Slang RefPtr<DeclStmt> Parser::parseVarDeclrStatement( Modifiers modifiers) { - RefPtr<DeclStmt>varDeclrStatement = new DeclStmt(); + RefPtr<DeclStmt>varDeclrStatement = astBuilder->create<DeclStmt>(); FillPosition(varDeclrStatement.Ptr()); auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers); @@ -3630,7 +3633,7 @@ namespace Slang RefPtr<IfStmt> Parser::parseIfStatement() { - RefPtr<IfStmt> ifStatement = new IfStmt(); + RefPtr<IfStmt> ifStatement = astBuilder->create<IfStmt>(); FillPosition(ifStatement.Ptr()); ReadToken("if"); ReadToken(TokenType::LParent); @@ -3647,7 +3650,7 @@ namespace Slang RefPtr<ForStmt> Parser::ParseForStatement() { - RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); + RefPtr<ScopeDecl> scopeDecl = astBuilder->create<ScopeDecl>(); // HLSL implements the bad approach to scoping a `for` loop // variable, and we want to respect that, but *only* when @@ -3663,11 +3666,11 @@ namespace Slang RefPtr<ForStmt> stmt; if (brokenScoping) { - stmt = new UnscopedForStmt(); + stmt = astBuilder->create<UnscopedForStmt>(); } else { - stmt = new ForStmt(); + stmt = astBuilder->create<ForStmt>(); } stmt->scopeDecl = scopeDecl; @@ -3708,7 +3711,7 @@ namespace Slang RefPtr<WhileStmt> Parser::ParseWhileStatement() { - RefPtr<WhileStmt> whileStatement = new WhileStmt(); + RefPtr<WhileStmt> whileStatement = astBuilder->create<WhileStmt>(); FillPosition(whileStatement.Ptr()); ReadToken("while"); ReadToken(TokenType::LParent); @@ -3720,7 +3723,7 @@ namespace Slang RefPtr<DoWhileStmt> Parser::ParseDoWhileStatement() { - RefPtr<DoWhileStmt> doWhileStatement = new DoWhileStmt(); + RefPtr<DoWhileStmt> doWhileStatement = astBuilder->create<DoWhileStmt>(); FillPosition(doWhileStatement.Ptr()); ReadToken("do"); doWhileStatement->statement = ParseStatement(); @@ -3734,7 +3737,7 @@ namespace Slang RefPtr<BreakStmt> Parser::ParseBreakStatement() { - RefPtr<BreakStmt> breakStatement = new BreakStmt(); + RefPtr<BreakStmt> breakStatement = astBuilder->create<BreakStmt>(); FillPosition(breakStatement.Ptr()); ReadToken("break"); ReadToken(TokenType::Semicolon); @@ -3743,7 +3746,7 @@ namespace Slang RefPtr<ContinueStmt> Parser::ParseContinueStatement() { - RefPtr<ContinueStmt> continueStatement = new ContinueStmt(); + RefPtr<ContinueStmt> continueStatement = astBuilder->create<ContinueStmt>(); FillPosition(continueStatement.Ptr()); ReadToken("continue"); ReadToken(TokenType::Semicolon); @@ -3752,7 +3755,7 @@ namespace Slang RefPtr<ReturnStmt> Parser::ParseReturnStatement() { - RefPtr<ReturnStmt> returnStatement = new ReturnStmt(); + RefPtr<ReturnStmt> returnStatement = astBuilder->create<ReturnStmt>(); FillPosition(returnStatement.Ptr()); ReadToken("return"); if (!LookAheadToken(TokenType::Semicolon)) @@ -3763,7 +3766,7 @@ namespace Slang RefPtr<ExpressionStmt> Parser::ParseExpressionStatement() { - RefPtr<ExpressionStmt> statement = new ExpressionStmt(); + RefPtr<ExpressionStmt> statement = astBuilder->create<ExpressionStmt>(); FillPosition(statement.Ptr()); statement->expression = ParseExpression(); @@ -3774,14 +3777,14 @@ namespace Slang RefPtr<ParamDecl> Parser::ParseParameter() { - RefPtr<ParamDecl> parameter = new ParamDecl(); + RefPtr<ParamDecl> parameter = astBuilder->create<ParamDecl>(); parameter->modifiers = ParseModifiers(this); DeclaratorInfo declaratorInfo; declaratorInfo.typeSpec = ParseType(); InitDeclarator initDeclarator = parseInitDeclarator(this, kDeclaratorParseOption_AllowEmpty); - UnwrapDeclarator(initDeclarator, &declaratorInfo); + UnwrapDeclarator(astBuilder, initDeclarator, &declaratorInfo); // Assume it is a variable-like declarator CompleteVarDecl(this, parameter, declaratorInfo); @@ -3901,7 +3904,7 @@ namespace Slang break; } - auto opExpr = new VarExpr(); + auto opExpr = parser->astBuilder->create<VarExpr>(); opExpr->name = getName(parser, opToken.getContent()); opExpr->scope = parser->currentScope; opExpr->loc = opToken.loc; @@ -3911,12 +3914,12 @@ namespace Slang } static RefPtr<Expr> createInfixExpr( - Parser* /*parser*/, + Parser* parser, RefPtr<Expr> left, RefPtr<Expr> op, RefPtr<Expr> right) { - RefPtr<InfixExpr> expr = new InfixExpr(); + RefPtr<InfixExpr> expr = parser->astBuilder->create<InfixExpr>(); expr->loc = op->loc; expr->functionExpr = op; expr->arguments.add(left); @@ -3943,7 +3946,7 @@ namespace Slang // one non-binary case we need to deal with. if(opTokenType == TokenType::QuestionMark) { - RefPtr<SelectExpr> select = new SelectExpr(); + RefPtr<SelectExpr> select = parser->astBuilder->create<SelectExpr>(); select->loc = op->loc; select->functionExpr = op; @@ -3971,7 +3974,7 @@ namespace Slang if (opTokenType == TokenType::OpAssign) { - RefPtr<AssignExpr> assignExpr = new AssignExpr(); + RefPtr<AssignExpr> assignExpr = parser->astBuilder->create<AssignExpr>(); assignExpr->loc = op->loc; assignExpr->left = expr; assignExpr->right = right; @@ -4069,14 +4072,14 @@ namespace Slang // Parse OOP `this` expression syntax static RefPtr<RefObject> parseThisExpr(Parser* parser, void* /*userData*/) { - RefPtr<ThisExpr> expr = new ThisExpr(); + RefPtr<ThisExpr> expr = parser->astBuilder->create<ThisExpr>(); expr->scope = parser->currentScope; return expr; } - static RefPtr<Expr> parseBoolLitExpr(Parser* /*parser*/, bool value) + static RefPtr<Expr> parseBoolLitExpr(Parser* parser, bool value) { - RefPtr<BoolLiteralExpr> expr = new BoolLiteralExpr(); + RefPtr<BoolLiteralExpr> expr = parser->astBuilder->create<BoolLiteralExpr>(); expr->value = value; return expr; } @@ -4280,7 +4283,7 @@ namespace Slang if (peekTypeName(parser) && parser->LookAheadToken(TokenType::RParent, 1)) { - RefPtr<TypeCastExpr> tcexpr = new ExplicitCastExpr(); + RefPtr<TypeCastExpr> tcexpr = parser->astBuilder->create<ExplicitCastExpr>(); parser->FillPosition(tcexpr.Ptr()); tcexpr->functionExpr = parser->ParseType(); parser->ReadToken(TokenType::RParent); @@ -4295,7 +4298,7 @@ namespace Slang RefPtr<Expr> base = parser->ParseExpression(); parser->ReadToken(TokenType::RParent); - RefPtr<ParenExpr> parenExpr = new ParenExpr(); + RefPtr<ParenExpr> parenExpr = parser->astBuilder->create<ParenExpr>(); parenExpr->loc = openParen.loc; parenExpr->base = base; return parenExpr; @@ -4305,7 +4308,7 @@ namespace Slang // An initializer list `{ expr, ... }` case TokenType::LBrace: { - RefPtr<InitializerListExpr> initExpr = new InitializerListExpr(); + RefPtr<InitializerListExpr> initExpr = parser->astBuilder->create<InitializerListExpr>(); parser->FillPosition(initExpr.Ptr()); // Initializer list @@ -4335,7 +4338,7 @@ namespace Slang case TokenType::IntegerLiteral: { - RefPtr<IntegerLiteralExpr> constExpr = new IntegerLiteralExpr(); + RefPtr<IntegerLiteralExpr> constExpr = parser->astBuilder->create<IntegerLiteralExpr>(); parser->FillPosition(constExpr.Ptr()); auto token = parser->tokenReader.advanceToken(); @@ -4411,9 +4414,9 @@ namespace Slang } value = _fixIntegerLiteral(suffixBaseType, value, &token, parser->sink); - - auto session = parser->getSession(); - Type* suffixType = (suffixBaseType == BaseType::Void) ? session->getErrorType() : session->getBuiltinType(suffixBaseType); + + ASTBuilder* astBuilder = parser->astBuilder; + Type* suffixType = (suffixBaseType == BaseType::Void) ? astBuilder->getErrorType() : astBuilder->getBuiltinType(suffixBaseType); constExpr->value = value; constExpr->type = QualType(suffixType); @@ -4424,7 +4427,7 @@ namespace Slang case TokenType::FloatingPointLiteral: { - RefPtr<FloatingPointLiteralExpr> constExpr = new FloatingPointLiteralExpr(); + RefPtr<FloatingPointLiteralExpr> constExpr = parser->astBuilder->create<FloatingPointLiteralExpr>(); parser->FillPosition(constExpr.Ptr()); auto token = parser->tokenReader.advanceToken(); @@ -4526,9 +4529,9 @@ namespace Slang } } - Session* session = parser->getSession(); + ASTBuilder* astBuilder = parser->astBuilder; - Type* suffixType = (suffixBaseType == BaseType::Void) ? session->getErrorType() : session->getBuiltinType(suffixBaseType); + Type* suffixType = (suffixBaseType == BaseType::Void) ? astBuilder->getErrorType() : astBuilder->getBuiltinType(suffixBaseType); constExpr->value = fixedValue; constExpr->type = QualType(suffixType); @@ -4538,7 +4541,7 @@ namespace Slang case TokenType::StringLiteral: { - RefPtr<StringLiteralExpr> constExpr = new StringLiteralExpr(); + RefPtr<StringLiteralExpr> constExpr = parser->astBuilder->create<StringLiteralExpr>(); auto token = parser->tokenReader.advanceToken(); constExpr->token = token; parser->FillPosition(constExpr.Ptr()); @@ -4580,7 +4583,7 @@ namespace Slang } // Default behavior is just to create a name expression - RefPtr<VarExpr> varExpr = new VarExpr(); + RefPtr<VarExpr> varExpr = parser->astBuilder->create<VarExpr>(); varExpr->scope = parser->currentScope.Ptr(); parser->FillPosition(varExpr.Ptr()); @@ -4611,7 +4614,7 @@ namespace Slang case TokenType::OpInc: case TokenType::OpDec: { - RefPtr<OperatorExpr> postfixExpr = new PostfixExpr(); + RefPtr<OperatorExpr> postfixExpr = parser->astBuilder->create<PostfixExpr>(); parser->FillPosition(postfixExpr.Ptr()); postfixExpr->functionExpr = parseOperator(parser); postfixExpr->arguments.add(expr); @@ -4623,7 +4626,7 @@ namespace Slang // Subscript operation `a[i]` case TokenType::LBracket: { - RefPtr<IndexExpr> indexExpr = new IndexExpr(); + RefPtr<IndexExpr> indexExpr = parser->astBuilder->create<IndexExpr>(); indexExpr->baseExpression = expr; parser->FillPosition(indexExpr.Ptr()); parser->ReadToken(TokenType::LBracket); @@ -4641,7 +4644,7 @@ namespace Slang // Call oepration `f(x)` case TokenType::LParent: { - RefPtr<InvokeExpr> invokeExpr = new InvokeExpr(); + RefPtr<InvokeExpr> invokeExpr = parser->astBuilder->create<InvokeExpr>(); invokeExpr->functionExpr = expr; parser->FillPosition(invokeExpr.Ptr()); parser->ReadToken(TokenType::LParent); @@ -4666,7 +4669,7 @@ namespace Slang // Scope access `x::m` case TokenType::Scope: { - RefPtr<StaticMemberExpr> staticMemberExpr = new StaticMemberExpr(); + RefPtr<StaticMemberExpr> staticMemberExpr = parser->astBuilder->create<StaticMemberExpr>(); // TODO(tfoley): why would a member expression need this? staticMemberExpr->scope = parser->currentScope.Ptr(); @@ -4686,7 +4689,7 @@ namespace Slang // Member access `x.m` case TokenType::Dot: { - RefPtr<MemberExpr> memberExpr = new MemberExpr(); + RefPtr<MemberExpr> memberExpr = parser->astBuilder->create<MemberExpr>(); // TODO(tfoley): why would a member expression need this? memberExpr->scope = parser->currentScope.Ptr(); @@ -4747,7 +4750,7 @@ namespace Slang case TokenType::OpInc: case TokenType::OpDec: { - RefPtr<PrefixExpr> prefixExpr = new PrefixExpr(); + RefPtr<PrefixExpr> prefixExpr = parser->astBuilder->create<PrefixExpr>(); parser->FillPosition(prefixExpr.Ptr()); prefixExpr->functionExpr = parseOperator(parser); @@ -4760,7 +4763,7 @@ namespace Slang case TokenType::OpAdd: case TokenType::OpSub: { - RefPtr<PrefixExpr> prefixExpr = new PrefixExpr(); + RefPtr<PrefixExpr> prefixExpr = parser->astBuilder->create<PrefixExpr>(); parser->FillPosition(prefixExpr.Ptr()); prefixExpr->functionExpr = parseOperator(parser); @@ -4768,7 +4771,7 @@ namespace Slang if (auto intLit = as<IntegerLiteralExpr>(arg)) { - RefPtr<IntegerLiteralExpr> newLiteral = new IntegerLiteralExpr(*intLit); + RefPtr<IntegerLiteralExpr> newLiteral = parser->astBuilder->create<IntegerLiteralExpr>(*intLit); IRIntegerValue value = _foldIntegerPrefixOp(tokenType, newLiteral->value); @@ -4783,7 +4786,7 @@ namespace Slang } else if (auto floatLit = as<FloatingPointLiteralExpr>(arg)) { - RefPtr<FloatingPointLiteralExpr> newLiteral = new FloatingPointLiteralExpr(*floatLit); + RefPtr<FloatingPointLiteralExpr> newLiteral = parser->astBuilder->create<FloatingPointLiteralExpr>(*floatLit); newLiteral->value = _foldFloatPrefixOp(tokenType, floatLit->value); return newLiteral; } @@ -4802,14 +4805,14 @@ namespace Slang } RefPtr<Expr> parseTermFromSourceFile( - Session* session, + ASTBuilder* astBuilder, TokenSpan const& tokens, DiagnosticSink* sink, RefPtr<Scope> const& outerScope, NamePool* namePool, SourceLanguage sourceLanguage) { - Parser parser(session, tokens, sink, outerScope); + Parser parser(astBuilder, tokens, sink, outerScope); parser.currentScope = outerScope; parser.namePool = namePool; parser.sourceLanguage = sourceLanguage; @@ -4818,12 +4821,13 @@ namespace Slang // Parse a source file into an existing translation unit void parseSourceFile( + ASTBuilder* astBuilder, TranslationUnitRequest* translationUnit, TokenSpan const& tokens, DiagnosticSink* sink, RefPtr<Scope> const& outerScope) { - Parser parser(translationUnit->getSession(), tokens, sink, outerScope); + Parser parser(astBuilder, tokens, sink, outerScope); parser.namePool = translationUnit->getNamePool(); parser.sourceLanguage = translationUnit->sourceLanguage; @@ -4840,7 +4844,9 @@ namespace Slang { Name* name = session->getNamePool()->getName(nameText); - RefPtr<SyntaxDecl> syntaxDecl = new SyntaxDecl(); + ASTBuilder* globalASTBuilder = session->getGlobalASTBuilder(); + + RefPtr<SyntaxDecl> syntaxDecl = globalASTBuilder->create<SyntaxDecl>(); syntaxDecl->nameAndLoc = NameLoc(name); syntaxDecl->syntaxClass = syntaxClass; syntaxDecl->parseCallback = callback; @@ -4872,7 +4878,7 @@ namespace Slang static RefPtr<RefObject> parseIntrinsicOpModifier(Parser* parser, void* /*userData*/) { - RefPtr<IntrinsicOpModifier> modifier = new IntrinsicOpModifier(); + RefPtr<IntrinsicOpModifier> modifier = parser->astBuilder->create<IntrinsicOpModifier>(); // We allow a few difference forms here: // @@ -4920,7 +4926,7 @@ namespace Slang static RefPtr<RefObject> parseTargetIntrinsicModifier(Parser* parser, void* /*userData*/) { - auto modifier = new TargetIntrinsicModifier(); + auto modifier = parser->astBuilder->create<TargetIntrinsicModifier>(); if (AdvanceIf(parser, TokenType::LParent)) { @@ -4946,7 +4952,7 @@ namespace Slang static RefPtr<RefObject> parseSpecializedForTargetModifier(Parser* parser, void* /*userData*/) { - auto modifier = new SpecializedForTargetModifier(); + auto modifier = parser->astBuilder->create<SpecializedForTargetModifier>(); if (AdvanceIf(parser, TokenType::LParent)) { modifier->targetToken = parser->ReadToken(TokenType::Identifier); @@ -4957,7 +4963,7 @@ namespace Slang static RefPtr<RefObject> parseGLSLExtensionModifier(Parser* parser, void* /*userData*/) { - auto modifier = new RequiredGLSLExtensionModifier(); + auto modifier = parser->astBuilder->create<RequiredGLSLExtensionModifier>(); parser->ReadToken(TokenType::LParent); modifier->extensionNameToken = parser->ReadToken(TokenType::Identifier); @@ -4968,7 +4974,7 @@ namespace Slang static RefPtr<RefObject> parseGLSLVersionModifier(Parser* parser, void* /*userData*/) { - auto modifier = new RequiredGLSLVersionModifier(); + auto modifier = parser->astBuilder->create<RequiredGLSLVersionModifier>(); parser->ReadToken(TokenType::LParent); modifier->versionNumberToken = parser->ReadToken(TokenType::IntegerLiteral); @@ -5013,7 +5019,7 @@ namespace Slang SemanticVersion version; if (SLANG_SUCCEEDED(parseSemanticVersion(parser, token, version))) { - auto modifier = new RequiredSPIRVVersionModifier(); + auto modifier = parser->astBuilder->create<RequiredSPIRVVersionModifier>(); modifier->version = version; return modifier; } @@ -5027,7 +5033,7 @@ namespace Slang SemanticVersion version; if (SLANG_SUCCEEDED(parseSemanticVersion(parser, token, version))) { - auto modifier = new RequiredCUDASMVersionModifier(); + auto modifier = parser->astBuilder->create<RequiredCUDASMVersionModifier>(); modifier->version = version; return modifier; } @@ -5041,7 +5047,7 @@ namespace Slang RefPtr<UncheckedAttribute> numThreadsAttrib; - listBuilder.add(new GLSLLayoutModifierGroupBegin()); + listBuilder.add(parser->astBuilder->create<GLSLLayoutModifierGroupBegin>()); parser->ReadToken(TokenType::LParent); while (!AdvanceIfMatch(parser, TokenType::RParent)) @@ -5062,7 +5068,7 @@ namespace Slang { if (!numThreadsAttrib) { - numThreadsAttrib = new UncheckedAttribute; + numThreadsAttrib = parser->astBuilder->create<UncheckedAttribute>(); numThreadsAttrib->args.setCount(3); // Just mark the loc and name from the first in the list @@ -5089,7 +5095,7 @@ namespace Slang GLSLBindingAttribute* attr = listBuilder.find<GLSLBindingAttribute>(); if (!attr) { - attr = new GLSLBindingAttribute(); + attr = parser->astBuilder->create<GLSLBindingAttribute>(); listBuilder.add(attr); } @@ -5117,13 +5123,13 @@ namespace Slang { RefPtr<Modifier> modifier; -#define CASE(key, type) if (nameText == #key) { modifier = new type; } else +#define CASE(key, type) if (nameText == #key) { modifier = parser->astBuilder->create<type>(); } else CASE(push_constant, PushConstantAttribute) CASE(shaderRecordNV, ShaderRecordAttribute) CASE(constant_id, GLSLConstantIDLayoutModifier) CASE(location, GLSLLocationLayoutModifier) { - modifier = new GLSLUnparsedLayoutModifier(); + modifier = parser->astBuilder->create<GLSLUnparsedLayoutModifier>(); } SLANG_ASSERT(modifier); #undef CASE @@ -5153,14 +5159,14 @@ namespace Slang listBuilder.add(numThreadsAttrib); } - listBuilder.add(new GLSLLayoutModifierGroupEnd()); + listBuilder.add(parser->astBuilder->create<GLSLLayoutModifierGroupEnd>()); return listBuilder.getFirst(); } static RefPtr<RefObject> parseBuiltinTypeModifier(Parser* parser, void* /*userData*/) { - RefPtr<BuiltinTypeModifier> modifier = new BuiltinTypeModifier(); + RefPtr<BuiltinTypeModifier> modifier = parser->astBuilder->create<BuiltinTypeModifier>(); parser->ReadToken(TokenType::LParent); modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); parser->ReadToken(TokenType::RParent); @@ -5170,7 +5176,7 @@ namespace Slang static RefPtr<RefObject> parseMagicTypeModifier(Parser* parser, void* /*userData*/) { - RefPtr<MagicTypeModifier> modifier = new MagicTypeModifier(); + RefPtr<MagicTypeModifier> modifier = parser->astBuilder->create<MagicTypeModifier>(); parser->ReadToken(TokenType::LParent); modifier->name = parser->ReadToken(TokenType::Identifier).getContent(); if (AdvanceIf(parser, TokenType::Comma)) @@ -5184,7 +5190,7 @@ namespace Slang static RefPtr<RefObject> parseIntrinsicTypeModifier(Parser* parser, void* /*userData*/) { - RefPtr<IntrinsicTypeModifier> modifier = new IntrinsicTypeModifier(); + RefPtr<IntrinsicTypeModifier> modifier = parser->astBuilder->create<IntrinsicTypeModifier>(); parser->ReadToken(TokenType::LParent); modifier->irOp = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); while( AdvanceIf(parser, TokenType::Comma) ) @@ -5198,7 +5204,7 @@ namespace Slang } static RefPtr<RefObject> parseImplicitConversionModifier(Parser* parser, void* /*userData*/) { - RefPtr<ImplicitConversionModifier> modifier = new ImplicitConversionModifier(); + RefPtr<ImplicitConversionModifier> modifier = parser->astBuilder->create<ImplicitConversionModifier>(); ConversionCost cost = kConversionCost_Default; if( AdvanceIf(parser, TokenType::LParent) ) @@ -5216,19 +5222,21 @@ namespace Slang auto syntaxClassNameAndLoc = expectIdentifier(parser); expect(parser, TokenType::RParent); - auto syntaxClass = parser->getSession()->findSyntaxClass(syntaxClassNameAndLoc.name); + auto syntaxClass = parser->astBuilder->findSyntaxClass(syntaxClassNameAndLoc.name); - RefPtr<AttributeTargetModifier> modifier = new AttributeTargetModifier(); + RefPtr<AttributeTargetModifier> modifier = parser->astBuilder->create<AttributeTargetModifier>(); modifier->syntaxClass = syntaxClass; return modifier; } RefPtr<ModuleDecl> populateBaseLanguageModule( - Session* session, + ASTBuilder* astBuilder, RefPtr<Scope> scope) { - RefPtr<ModuleDecl> moduleDecl = new ModuleDecl(); + Session* session = astBuilder->getGlobalSession(); + + RefPtr<ModuleDecl> moduleDecl = astBuilder->create<ModuleDecl>(); scope->containerDecl = moduleDecl; // Add syntax for declaration keywords diff --git a/source/slang/slang-parser.h b/source/slang/slang-parser.h index 1c21b9474..f3587793e 100644 --- a/source/slang/slang-parser.h +++ b/source/slang/slang-parser.h @@ -9,13 +9,14 @@ namespace Slang { // Parse a source file into an existing translation unit void parseSourceFile( + ASTBuilder* astBuilder, TranslationUnitRequest* translationUnit, TokenSpan const& tokens, DiagnosticSink* sink, RefPtr<Scope> const& outerScope); RefPtr<Expr> parseTermFromSourceFile( - Session* session, + ASTBuilder* astBuilder, TokenSpan const& tokens, DiagnosticSink* sink, RefPtr<Scope> const& outerScope, @@ -23,7 +24,7 @@ namespace Slang SourceLanguage sourceLanguage); RefPtr<ModuleDecl> populateBaseLanguageModule( - Session* session, + ASTBuilder* astBuilder, RefPtr<Scope> scope); } diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp index 1316efa5d..23d006b3f 100644 --- a/source/slang/slang-reflection.cpp +++ b/source/slang/slang-reflection.cpp @@ -489,7 +489,10 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName if (!type) return 0; if (auto declRefType = as<DeclRefType>(type)) { - return findUserAttributeByName(declRefType->getSession(), declRefType->declRef.getDecl(), name); + ASTBuilder* astBuilder = declRefType->getASTBuilder(); + auto globalSession = astBuilder->getGlobalSession(); + + return findUserAttributeByName(globalSession, declRefType->declRef.getDecl(), name); } return 0; } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 70b8a4239..54d27bd57 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -244,10 +244,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return false; } - RefPtr<Val> Type::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> Type::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto canSubst = getCanonicalType()->substituteImpl(subst, &diff); + auto canSubst = getCanonicalType()->substituteImpl(astBuilder, subst, &diff); // If nothing changed, then don't drop any sugar that is applied if (!diff) @@ -279,168 +279,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return et->canonicalType; } - void Session::initializeTypes() - { - errorType = new ErrorType(); - errorType->setSession(this); - - initializerListType = new InitializerListType(); - initializerListType->setSession(this); - - overloadedType = new OverloadGroupType(); - overloadedType->setSession(this); - } - - Type* Session::getBoolType() - { - return getBuiltinType(BaseType::Bool); - } - - Type* Session::getHalfType() - { - return getBuiltinType(BaseType::Half); - } - - Type* Session::getFloatType() - { - return getBuiltinType(BaseType::Float); - } - - Type* Session::getDoubleType() - { - return getBuiltinType(BaseType::Double); - } - - Type* Session::getIntType() - { - return getBuiltinType(BaseType::Int); - } - - Type* Session::getInt64Type() - { - return getBuiltinType(BaseType::Int64); - } - - Type* Session::getUIntType() - { - return getBuiltinType(BaseType::UInt); - } - - Type* Session::getUInt64Type() - { - return getBuiltinType(BaseType::UInt64); - } - - Type* Session::getVoidType() - { - return getBuiltinType(BaseType::Void); - } - - Type* Session::getBuiltinType(BaseType flavor) - { - return builtinTypes[int(flavor)]; - } - - Type* Session::getInitializerListType() - { - return initializerListType; - } - - Type* Session::getOverloadedType() - { - return overloadedType; - } - - Type* Session::getErrorType() - { - return errorType; - } - - Type* Session::getStringType() - { - if (stringType == nullptr) - { - auto stringTypeDecl = findMagicDecl(this, "StringType"); - stringType = DeclRefType::Create(this, makeDeclRef<Decl>(stringTypeDecl)); - } - return stringType; - } - - Type* Session::getEnumTypeType() - { - if (enumTypeType == nullptr) - { - auto enumTypeTypeDecl = findMagicDecl(this, "EnumTypeType"); - enumTypeType = DeclRefType::Create(this, makeDeclRef<Decl>(enumTypeTypeDecl)); - } - return enumTypeType; - } - - RefPtr<PtrType> Session::getPtrType( - RefPtr<Type> valueType) - { - return getPtrType(valueType, "PtrType").dynamicCast<PtrType>(); - } - - // Construct the type `Out<valueType>` - RefPtr<OutType> Session::getOutType(RefPtr<Type> valueType) - { - return getPtrType(valueType, "OutType").dynamicCast<OutType>(); - } - - RefPtr<InOutType> Session::getInOutType(RefPtr<Type> valueType) - { - return getPtrType(valueType, "InOutType").dynamicCast<InOutType>(); - } - - RefPtr<RefType> Session::getRefType(RefPtr<Type> valueType) - { - return getPtrType(valueType, "RefType").dynamicCast<RefType>(); - } - - RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName) - { - auto genericDecl = findMagicDecl(this, ptrTypeName).dynamicCast<GenericDecl>(); - return getPtrType(valueType, genericDecl); - } - - RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl) - { - auto typeDecl = genericDecl->inner; - - auto substitutions = new GenericSubstitution(); - substitutions->genericDecl = genericDecl; - substitutions->args.add(valueType); - - auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions); - auto rsType = DeclRefType::Create( - this, - declRef); - return as<PtrTypeBase>( rsType); - } - - RefPtr<ArrayExpressionType> Session::getArrayType( - Type* elementType, - IntVal* elementCount) - { - RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); - arrayType->setSession(this); - arrayType->baseType = elementType; - arrayType->arrayLength = elementCount; - return arrayType; - } - - SyntaxClass<RefObject> Session::findSyntaxClass(Name* name) - { - SyntaxClass<RefObject> syntaxClass; - if (mapNameToSyntaxClass.TryGetValue(name, syntaxClass)) - return syntaxClass; - - return SyntaxClass<RefObject>(); - } - - - bool ArrayExpressionType::equalsImpl(Type* type) { auto arrType = as<ArrayExpressionType>(type); @@ -449,16 +287,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType.Ptr())); } - RefPtr<Val> ArrayExpressionType::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> ArrayExpressionType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto elementType = baseType->substituteImpl(subst, &diff).as<Type>(); - auto arrlen = arrayLength->substituteImpl(subst, &diff).as<IntVal>(); + auto elementType = baseType->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto arrlen = arrayLength->substituteImpl(astBuilder, subst, &diff).as<IntVal>(); SLANG_ASSERT(arrlen); if (diff) { *ioDiff = 1; auto rsType = getArrayType( + astBuilder, elementType, arrlen); return rsType; @@ -469,11 +308,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt RefPtr<Type> ArrayExpressionType::createCanonicalType() { auto canonicalElementType = baseType->getCanonicalType(); - auto canonicalArrayType = getArrayType( + auto canonicalArrayType = getASTBuilder()->getArrayType( canonicalElementType, arrayLength); return canonicalArrayType; } + HashCode ArrayExpressionType::getHashCode() { if (arrayLength) @@ -538,7 +378,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } - RequirementWitness RequirementWitness::specialize(SubstitutionSet const& subst) + RequirementWitness RequirementWitness::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst) { switch(getFlavor()) { @@ -551,7 +391,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { int diff = 0; return RequirementWitness( - getDeclRef().SubstituteImpl(subst, &diff)); + getDeclRef().substituteImpl(astBuilder, subst, &diff)); } case RequirementWitness::Flavor::val: @@ -560,12 +400,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt SLANG_ASSERT(val); return RequirementWitness( - val->substitute(subst)); + val->substitute(astBuilder, subst)); } } } RequirementWitness tryLookUpRequirementWitness( + ASTBuilder* astBuilder, SubtypeWitness* subtypeWitness, Decl* requirementKey) { @@ -625,7 +466,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // So, in order to get the *right* end result, we need to apply // the substitutions from the inheritance decl-ref to the witness. // - requirementWitness = requirementWitness.specialize(inheritanceDeclRef.substitutions); + requirementWitness = requirementWitness.specialize(astBuilder, inheritanceDeclRef.substitutions); return requirementWitness; } @@ -637,7 +478,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return RequirementWitness(); } - RefPtr<Val> DeclRefType::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> DeclRefType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { if (!subst) return this; @@ -698,7 +539,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } } int diff = 0; - DeclRef<Decl> substDeclRef = declRef.SubstituteImpl(subst, &diff); + DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); if (!diff) return this; @@ -726,7 +567,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisSubst->witness, requirementKey); switch(requirementWitness.getFlavor()) { default: @@ -746,7 +587,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } // Re-construct the type in case we are using a specialized sub-class - return DeclRefType::Create(getSession(), substDeclRef); + return DeclRefType::create(astBuilder, substDeclRef); } static RefPtr<Type> ExtractGenericArgType(RefPtr<Val> val) @@ -764,7 +605,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } DeclRef<Decl> createDefaultSubstitutionsIfNeeded( - Session* session, + ASTBuilder* astBuilder, DeclRef<Decl> declRef) { // It is possible that `declRef` refers to a generic type, @@ -822,8 +663,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if(!foundSubst) { - RefPtr<Substitutions> newSubst = createDefaultSubsitutionsForGeneric( - session, + RefPtr<Substitutions> newSubst = createDefaultSubstitutionsForGeneric( + astBuilder, genericParentDecl, nullptr); @@ -837,21 +678,20 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return declRef; int diff = 0; - return declRef.SubstituteImpl(substsToApply, &diff); + return declRef.substituteImpl(astBuilder, substsToApply, &diff); } // TODO: need to figure out how to unify this with the logic // in the generic case... - RefPtr<DeclRefType> DeclRefType::Create( - Session* session, + RefPtr<DeclRefType> DeclRefType::create( + ASTBuilder* astBuilder, DeclRef<Decl> declRef) { - declRef = createDefaultSubstitutionsIfNeeded(session, declRef); + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, declRef); if (auto builtinMod = declRef.getDecl()->findModifier<BuiltinTypeModifier>()) { - auto type = new BasicExpressionType(builtinMod->tag); - type->setSession(session); + auto type = astBuilder->create<BasicExpressionType>(builtinMod->tag); type->declRef = declRef; return type; } @@ -869,8 +709,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if (magicMod->name == "SamplerState") { - auto type = new SamplerStateType(); - type->setSession(session); + auto type = astBuilder->create<SamplerStateType>(); type->declRef = declRef; type->flavor = SamplerStateFlavor(magicMod->tag); return type; @@ -878,8 +717,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt else if (magicMod->name == "Vector") { SLANG_ASSERT(subst && subst->args.getCount() == 2); - auto vecType = new VectorExpressionType(); - vecType->setSession(session); + auto vecType = astBuilder->create<VectorExpressionType>(); vecType->declRef = declRef; vecType->elementType = ExtractGenericArgType(subst->args[0]); vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); @@ -888,38 +726,34 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt else if (magicMod->name == "Matrix") { SLANG_ASSERT(subst && subst->args.getCount() == 3); - auto matType = new MatrixExpressionType(); - matType->setSession(session); + auto matType = astBuilder->create<MatrixExpressionType>(); matType->declRef = declRef; return matType; } else if (magicMod->name == "Texture") { SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = new TextureType( + auto textureType = astBuilder->create<TextureType>( TextureFlavor(magicMod->tag), ExtractGenericArgType(subst->args[0])); - textureType->setSession(session); textureType->declRef = declRef; return textureType; } else if (magicMod->name == "TextureSampler") { SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = new TextureSamplerType( + auto textureType = astBuilder->create<TextureSamplerType>( TextureFlavor(magicMod->tag), ExtractGenericArgType(subst->args[0])); - textureType->setSession(session); textureType->declRef = declRef; return textureType; } else if (magicMod->name == "GLSLImageType") { SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = new GLSLImageType( + auto textureType = astBuilder->create<GLSLImageType>( TextureFlavor(magicMod->tag), ExtractGenericArgType(subst->args[0])); - textureType->setSession(session); textureType->declRef = declRef; return textureType; } @@ -930,8 +764,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt #define CASE(n,T) \ else if(magicMod->name == #n) { \ - auto type = new T(); \ - type->setSession(session); \ + auto type = astBuilder->create<T>(); \ type->declRef = declRef; \ return type; \ } @@ -944,8 +777,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt #define CASE(n,T) \ else if(magicMod->name == #n) { \ SLANG_ASSERT(subst && subst->args.getCount() == 1); \ - auto type = new T(); \ - type->setSession(session); \ + auto type = astBuilder->create<T>(); \ type->elementType = ExtractGenericArgType(subst->args[0]); \ type->declRef = declRef; \ return type; \ @@ -973,8 +805,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // "magic" builtin types which have no generic parameters #define CASE(n,T) \ else if(magicMod->name == #n) { \ - auto type = new T(); \ - type->setSession(session); \ + auto type = astBuilder->create<T>(); \ type->declRef = declRef; \ return type; \ } @@ -990,8 +821,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt else { - auto classInfo = session->findSyntaxClass( - session->getNamePool()->getName(magicMod->name)); + auto classInfo = astBuilder->findSyntaxClass(magicMod->name.getUnownedSlice()); if (!classInfo.classInfo) { SLANG_UNEXPECTED("unhandled type"); @@ -1008,16 +838,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { SLANG_UNEXPECTED("expected a declaration reference type"); } - declRefType->session = session; declRefType->declRef = declRef; return declRefType; } } else { - auto type = new DeclRefType(declRef); - type->setSession(session); - return type; + return astBuilder->create<DeclRefType>(declRef); } } @@ -1084,7 +911,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return this; } - RefPtr<Val> ErrorType::substituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) + RefPtr<Val> ErrorType::substituteImpl(ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) { return this; } @@ -1111,7 +938,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt RefPtr<Type> NamedExpressionType::createCanonicalType() { if (!innerType) - innerType = GetType(declRef); + innerType = GetType(m_astBuilder, declRef); return innerType->getCanonicalType(); } @@ -1174,18 +1001,18 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return false; } - RefPtr<Val> FuncType::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> FuncType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; // result type - RefPtr<Type> substResultType = resultType->substituteImpl(subst, &diff).as<Type>(); + RefPtr<Type> substResultType = resultType->substituteImpl(astBuilder, subst, &diff).as<Type>(); // parameter types List<RefPtr<Type>> substParamTypes; for( auto pp : paramTypes ) { - substParamTypes.add(pp->substituteImpl(subst, &diff).as<Type>()); + substParamTypes.add(pp->substituteImpl(astBuilder, subst, &diff).as<Type>()); } // early exit for no change... @@ -1193,8 +1020,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return this; (*ioDiff)++; - RefPtr<FuncType> substType = new FuncType(); - substType->session = session; + RefPtr<FuncType> substType = astBuilder->create<FuncType>(); substType->resultType = substResultType; substType->paramTypes = substParamTypes; return substType; @@ -1212,8 +1038,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt canParamTypes.add(pp->getCanonicalType()); } - RefPtr<FuncType> canType = new FuncType(); - canType->session = session; + RefPtr<FuncType> canType = getASTBuilder()->create<FuncType>(); canType->resultType = resultType; canType->paramTypes = canParamTypes; @@ -1254,8 +1079,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt RefPtr<Type> TypeType::createCanonicalType() { - auto canType = getTypeType(type->getCanonicalType()); - return canType; + return getASTBuilder()->getTypeType(type->getCanonicalType()); } HashCode TypeType::getHashCode() @@ -1381,30 +1205,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { if( !rowType ) { - rowType = getSession()->getVectorType(getElementType(), getColumnCount()); + rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount()); } return rowType; } - RefPtr<VectorExpressionType> Session::getVectorType( - RefPtr<Type> elementType, - RefPtr<IntVal> elementCount) - { - auto vectorGenericDecl = findMagicDecl( - this, "Vector").as<GenericDecl>(); - auto vectorTypeDecl = vectorGenericDecl->inner; - - auto substitutions = new GenericSubstitution(); - substitutions->genericDecl = vectorGenericDecl.Ptr(); - substitutions->args.add(elementType); - substitutions->args.add(elementCount); - - auto declRef = DeclRef<Decl>(vectorTypeDecl.Ptr(), substitutions); - - return DeclRefType::Create( - this, - declRef).as<VectorExpressionType>(); - } + // PtrTypeBase @@ -1435,7 +1241,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return declRef.getHashCode() ^ HashCode(0xFFFF); } - RefPtr<Val> GenericParamIntVal::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> GenericParamIntVal::substituteImpl(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) { // search for a substitution that might apply to us for(auto s = subst.substitutions; s; s = s->outer) @@ -1498,8 +1304,9 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return HashCode(typeid(this).hash_code()); } - RefPtr<Val> ErrorIntVal::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> ErrorIntVal::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { + SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); SLANG_UNUSED(ioDiff); return this; @@ -1507,7 +1314,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // Substitutions - RefPtr<Substitutions> GenericSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) + RefPtr<Substitutions> GenericSubstitution::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) { int diff = 0; @@ -1516,13 +1323,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt List<RefPtr<Val>> substArgs; for (auto a : args) { - substArgs.add(a->substituteImpl(substSet, &diff)); + substArgs.add(a->substituteImpl(astBuilder, substSet, &diff)); } if (!diff) return this; (*ioDiff)++; - auto substSubst = new GenericSubstitution(); + auto substSubst = astBuilder->create<GenericSubstitution>(); substSubst->genericDecl = genericDecl; substSubst->args = substArgs; substSubst->outer = substOuter; @@ -1560,19 +1367,19 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return true; } - RefPtr<Substitutions> ThisTypeSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) + RefPtr<Substitutions> ThisTypeSubstitution::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) { int diff = 0; if(substOuter != outer) diff++; // NOTE: Must use .as because we must have a smart pointer here to keep in scope. - auto substWitness = witness->substituteImpl(substSet, &diff).as<SubtypeWitness>(); + auto substWitness = witness->substituteImpl(astBuilder, substSet, &diff).as<SubtypeWitness>(); if (!diff) return this; (*ioDiff)++; - auto substSubst = new ThisTypeSubstitution(); + auto substSubst = astBuilder->create<ThisTypeSubstitution>(); substSubst->interfaceDecl = interfaceDecl; substSubst->witness = substWitness; substSubst->outer = substOuter; @@ -1609,7 +1416,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return witness->getHashCode(); } - RefPtr<Substitutions> GlobalGenericParamSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) + RefPtr<Substitutions> GlobalGenericParamSubstitution::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) { // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl // return a copy of that GlobalGenericParamSubstitution @@ -1617,14 +1424,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if(substOuter != outer) diff++; - auto substActualType = actualType->substituteImpl(substSet, &diff).as<Type>(); + auto substActualType = actualType->substituteImpl(astBuilder, substSet, &diff).as<Type>(); List<ConstraintArg> substConstraintArgs; for(auto constraintArg : constraintArgs) { ConstraintArg substConstraintArg; substConstraintArg.decl = constraintArg.decl; - substConstraintArg.val = constraintArg.val->substituteImpl(substSet, &diff); + substConstraintArg.val = constraintArg.val->substituteImpl(astBuilder, substSet, &diff); substConstraintArgs.add(substConstraintArg); } @@ -1634,7 +1441,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt (*ioDiff)++; - RefPtr<GlobalGenericParamSubstitution> substSubst = new GlobalGenericParamSubstitution(); + RefPtr<GlobalGenericParamSubstitution> substSubst = astBuilder->create<GlobalGenericParamSubstitution>(); substSubst->paramDecl = paramDecl; substSubst->actualType = substActualType; substSubst->constraintArgs = substConstraintArgs; @@ -1670,7 +1477,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // DeclRefBase - RefPtr<Type> DeclRefBase::Substitute(RefPtr<Type> type) const + RefPtr<Type> DeclRefBase::substitute(ASTBuilder* astBuilder, RefPtr<Type> type) const { // Note that type can be nullptr, and so this function can return nullptr (although only correctly when no substitutions) @@ -1682,19 +1489,19 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // Otherwise we need to recurse on the type structure // and apply substitutions where it makes sense - return type->substitute(substitutions).as<Type>(); + return type->substitute(astBuilder, substitutions).as<Type>(); } - DeclRefBase DeclRefBase::Substitute(DeclRefBase declRef) const + DeclRefBase DeclRefBase::substitute(ASTBuilder* astBuilder, DeclRefBase declRef) const { if(!substitutions) return declRef; int diff = 0; - return declRef.SubstituteImpl(substitutions, &diff); + return declRef.substituteImpl(astBuilder, substitutions, &diff); } - RefPtr<Expr> DeclRefBase::Substitute(RefPtr<Expr> expr) const + RefPtr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, RefPtr<Expr> expr) const { // No substitutions? Easy. if (!substitutions) @@ -1740,16 +1547,18 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } RefPtr<Substitutions> specializeSubstitutionsShallow( + ASTBuilder* astBuilder, RefPtr<Substitutions> substToSpecialize, RefPtr<Substitutions> substsToApply, RefPtr<Substitutions> restSubst, int* ioDiff) { SLANG_ASSERT(substToSpecialize); - return substToSpecialize->applySubstitutionsShallow(substsToApply, restSubst, ioDiff); + return substToSpecialize->applySubstitutionsShallow(astBuilder, substsToApply, restSubst, ioDiff); } RefPtr<Substitutions> specializeGlobalGenericSubstitutions( + ASTBuilder* astBuilder, Decl* declToSpecialize, RefPtr<Substitutions> substsToSpecialize, RefPtr<Substitutions> substsToApply, @@ -1768,6 +1577,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto restSubst = specializeGlobalGenericSubstitutions( + astBuilder, declToSpecialize, specSubst->outer, substsToApply, @@ -1775,6 +1585,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ioParametersFound); auto firstSubst = specializeSubstitutionsShallow( + astBuilder, specGlobalGenericSubst, substsToApply, restSubst, @@ -1819,7 +1630,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if(ioParametersFound.Contains(appGlobalGenericSubst->paramDecl)) continue; - RefPtr<GlobalGenericParamSubstitution> newSubst = new GlobalGenericParamSubstitution(); + RefPtr<GlobalGenericParamSubstitution> newSubst = astBuilder->create<GlobalGenericParamSubstitution>(); newSubst->paramDecl = appGlobalGenericSubst->paramDecl; newSubst->actualType = appGlobalGenericSubst->actualType; newSubst->constraintArgs = appGlobalGenericSubst->constraintArgs; @@ -1832,6 +1643,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } RefPtr<Substitutions> specializeGlobalGenericSubstitutions( + ASTBuilder* astBuilder, Decl* declToSpecialize, RefPtr<Substitutions> substsToSpecialize, RefPtr<Substitutions> substsToApply, @@ -1840,13 +1652,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // Keep track of any parameters already present in the // existing substitution. HashSet<GlobalGenericParamDecl*> parametersFound; - return specializeGlobalGenericSubstitutions(declToSpecialize, substsToSpecialize, substsToApply, ioDiff, parametersFound); + return specializeGlobalGenericSubstitutions(astBuilder, declToSpecialize, substsToSpecialize, substsToApply, ioDiff, parametersFound); } // Construct new substitutions to apply to a declaration, // based on a provided substitution set to be applied RefPtr<Substitutions> specializeSubstitutions( + ASTBuilder* astBuilder, Decl* declToSpecialize, RefPtr<Substitutions> substsToSpecialize, RefPtr<Substitutions> substsToApply, @@ -1880,12 +1693,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // keep one matching it in place. int diff = 0; auto restSubst = specializeSubstitutions( + astBuilder, ancestorGenericDecl->parentDecl, specGenericSubst->outer, substsToApply, &diff); auto firstSubst = specializeSubstitutionsShallow( + astBuilder, specGenericSubst, substsToApply, restSubst, @@ -1920,12 +1735,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto restSubst = specializeSubstitutions( + astBuilder, ancestorGenericDecl->parentDecl, substsToSpecialize, substsToApply, &diff); - RefPtr<GenericSubstitution> firstSubst = new GenericSubstitution(); + RefPtr<GenericSubstitution> firstSubst = astBuilder->create<GenericSubstitution>(); firstSubst->genericDecl = ancestorGenericDecl; firstSubst->args = appGenericSubst->args; firstSubst->outer = restSubst; @@ -1950,12 +1766,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // keep one matching it in place. int diff = 0; auto restSubst = specializeSubstitutions( + astBuilder, ancestorInterfaceDecl->parentDecl, specThisTypeSubst->outer, substsToApply, &diff); auto firstSubst = specializeSubstitutionsShallow( + astBuilder, specThisTypeSubst, substsToApply, restSubst, @@ -1980,12 +1798,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto restSubst = specializeSubstitutions( + astBuilder, ancestorInterfaceDecl->parentDecl, substsToSpecialize, substsToApply, &diff); - RefPtr<ThisTypeSubstitution> firstSubst = new ThisTypeSubstitution(); + RefPtr<ThisTypeSubstitution> firstSubst = astBuilder->create<ThisTypeSubstitution>(); firstSubst->interfaceDecl = ancestorInterfaceDecl; firstSubst->witness = appThisTypeSubst->witness; firstSubst->outer = restSubst; @@ -2016,13 +1835,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // of them were applicable. // return specializeGlobalGenericSubstitutions( + astBuilder, declToSpecialize, substsToSpecialize, substsToApply, ioDiff); } - DeclRefBase DeclRefBase::SubstituteImpl(SubstitutionSet substSet, int* ioDiff) + DeclRefBase DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) { // Nothing to do when we have no declaration. if(!decl) @@ -2033,6 +1853,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto substSubst = specializeSubstitutions( + astBuilder, decl, substitutions.substitutions, substSet.substitutions, @@ -2138,14 +1959,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // Val - RefPtr<Val> Val::substitute(SubstitutionSet subst) + RefPtr<Val> Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) { if (!subst) return this; int diff = 0; - return substituteImpl(subst, &diff); + return substituteImpl(astBuilder, subst, &diff); } - RefPtr<Val> Val::substituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) + RefPtr<Val> Val::substituteImpl(ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) { // Default behavior is to not substitute at all return this; @@ -2182,64 +2003,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return (HashCode) value; } - // - - void registerBuiltinDecl( - Session* session, - RefPtr<Decl> decl, - RefPtr<BuiltinTypeModifier> modifier) - { - auto type = DeclRefType::Create( - session, - DeclRef<Decl>(decl.Ptr(), nullptr)); - session->builtinTypes[(int)modifier->tag] = type; - } - - void registerMagicDecl( - Session* session, - RefPtr<Decl> decl, - RefPtr<MagicTypeModifier> modifier) - { - // In some cases the modifier will have been applied to the - // "inner" declaration of a `GenericDecl`, but what we - // actually want to register is the generic itself. - // - auto declToRegister = decl; - if(auto genericDecl = as<GenericDecl>(decl->parentDecl)) - declToRegister = genericDecl; - - session->magicDecls[modifier->name] = declToRegister.Ptr(); - } - - RefPtr<Decl> findMagicDecl( - Session* session, - String const& name) - { - return session->magicDecls[name].GetValue(); - } - - // - - SyntaxNodeBase* createInstanceOfSyntaxClassByName( - String const& name) - { - if(0) {} - #define CASE(NAME) \ - else if(name == #NAME) return new NAME() - - CASE(GLSLBufferModifier); - CASE(GLSLWriteOnlyModifier); - CASE(GLSLReadOnlyModifier); - CASE(GLSLPatchModifier); - CASE(SimpleModifier); - - #undef CASE - else - { - SLANG_UNEXPECTED("unhandled syntax class name"); - UNREACHABLE_RETURN(nullptr); - } - } + // @@ -2258,72 +2022,59 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // Constructors for types RefPtr<ArrayExpressionType> getArrayType( + ASTBuilder* astBuilder, Type* elementType, IntVal* elementCount) { - auto session = elementType->getSession(); - auto arrayType = new ArrayExpressionType(); - arrayType->setSession(session); + auto arrayType = astBuilder->create<ArrayExpressionType>(); arrayType->baseType = elementType; arrayType->arrayLength = elementCount; return arrayType; } RefPtr<ArrayExpressionType> getArrayType( + ASTBuilder* astBuilder, Type* elementType) { - auto session = elementType->getSession(); - auto arrayType = new ArrayExpressionType(); - arrayType->setSession(session); + auto arrayType = astBuilder->create<ArrayExpressionType>(); arrayType->baseType = elementType; return arrayType; } RefPtr<NamedExpressionType> getNamedType( - Session* session, + ASTBuilder* astBuilder, DeclRef<TypeDefDecl> const& declRef) { - DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(session, declRef).as<TypeDefDecl>(); + DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, declRef).as<TypeDefDecl>(); - auto namedType = new NamedExpressionType(specializedDeclRef); - namedType->setSession(session); - return namedType; - } - - RefPtr<TypeType> getTypeType( - Type* type) - { - auto session = type->getSession(); - auto typeType = new TypeType(type); - typeType->setSession(session); - return typeType; + return astBuilder->create<NamedExpressionType>(specializedDeclRef); } + RefPtr<FuncType> getFuncType( - Session* session, + ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) { - RefPtr<FuncType> funcType = new FuncType(); - funcType->setSession(session); + RefPtr<FuncType> funcType = astBuilder->create<FuncType>(); - funcType->resultType = GetResultType(declRef); + funcType->resultType = GetResultType(astBuilder, declRef); for (auto paramDeclRef : GetParameters(declRef)) { auto paramDecl = paramDeclRef.getDecl(); - auto paramType = GetType(paramDeclRef); + auto paramType = GetType(astBuilder, paramDeclRef); if( paramDecl->findModifier<RefModifier>() ) { - paramType = session->getRefType(paramType); + paramType = astBuilder->getRefType(paramType); } else if( paramDecl->findModifier<OutModifier>() ) { if(paramDecl->findModifier<InOutModifier>() || paramDecl->findModifier<InModifier>()) { - paramType = session->getInOutType(paramType); + paramType = astBuilder->getInOutType(paramType); } else { - paramType = session->getOutType(paramType); + paramType = astBuilder->getOutType(paramType); } } funcType->paramTypes.add(paramType); @@ -2333,30 +2084,25 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } RefPtr<GenericDeclRefType> getGenericDeclRefType( - Session* session, + ASTBuilder* astBuilder, DeclRef<GenericDecl> const& declRef) { - auto genericDeclRefType = new GenericDeclRefType(declRef); - genericDeclRefType->setSession(session); - return genericDeclRefType; + return astBuilder->create<GenericDeclRefType>(declRef); } RefPtr<NamespaceType> getNamespaceType( - Session* session, + ASTBuilder* astBuilder, DeclRef<NamespaceDeclBase> const& declRef) { - auto type = new NamespaceType; - type->setSession(session); + auto type = astBuilder->create<NamespaceType>(); type->declRef = declRef; return type; } RefPtr<SamplerStateType> getSamplerStateType( - Session* session) + ASTBuilder* astBuilder) { - auto samplerStateType = new SamplerStateType(); - samplerStateType->setSession(session); - return samplerStateType; + return astBuilder->create<SamplerStateType>(); } // TODO: should really have a `type.cpp` and a `witness.cpp` @@ -2369,11 +2115,11 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return sub->equals(otherWitness->sub); } - RefPtr<Val> TypeEqualityWitness::substituteImpl(SubstitutionSet subst, int * ioDiff) + RefPtr<Val> TypeEqualityWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - RefPtr<TypeEqualityWitness> rs = new TypeEqualityWitness(); - rs->sub = sub->substituteImpl(subst, ioDiff).as<Type>(); - rs->sup = sup->substituteImpl(subst, ioDiff).as<Type>(); + RefPtr<TypeEqualityWitness> rs = astBuilder->create<TypeEqualityWitness>(); + rs->sub = sub->substituteImpl(astBuilder, subst, ioDiff).as<Type>(); + rs->sup = sup->substituteImpl(astBuilder, subst, ioDiff).as<Type>(); return rs; } @@ -2417,7 +2163,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - RefPtr<Val> DeclaredSubtypeWitness::substituteImpl(SubstitutionSet subst, int * ioDiff) + RefPtr<Val> DeclaredSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>()) { @@ -2477,9 +2223,9 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // Perform substitution on the constituent elements. int diff = 0; - auto substSub = sub->substituteImpl(subst, &diff).as<Type>(); - auto substSup = sup->substituteImpl(subst, &diff).as<Type>(); - auto substDeclRef = declRef.SubstituteImpl(subst, &diff); + auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); if (!diff) return this; @@ -2509,7 +2255,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisTypeSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey); switch(requirementWitness.getFlavor()) { default: @@ -2529,7 +2275,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt - RefPtr<DeclaredSubtypeWitness> rs = new DeclaredSubtypeWitness(); + RefPtr<DeclaredSubtypeWitness> rs = astBuilder->create<DeclaredSubtypeWitness>(); rs->sub = substSub; rs->sup = substSup; rs->declRef = substDeclRef; @@ -2568,14 +2314,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt && midToSup.equals(otherWitness->midToSup); } - RefPtr<Val> TransitiveSubtypeWitness::substituteImpl(SubstitutionSet subst, int * ioDiff) + RefPtr<Val> TransitiveSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { int diff = 0; - RefPtr<Type> substSub = sub->substituteImpl(subst, &diff).as<Type>(); - RefPtr<Type> substSup = sup->substituteImpl(subst, &diff).as<Type>(); - RefPtr<SubtypeWitness> substSubToMid = subToMid->substituteImpl(subst, &diff).as<SubtypeWitness>(); - DeclRef<Decl> substMidToSup = midToSup.SubstituteImpl(subst, &diff); + RefPtr<Type> substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + RefPtr<Type> substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); + RefPtr<SubtypeWitness> substSubToMid = subToMid->substituteImpl(astBuilder, subst, &diff).as<SubtypeWitness>(); + DeclRef<Decl> substMidToSup = midToSup.substituteImpl(astBuilder, subst, &diff); // If nothing changed, then we can bail out early. if (!diff) @@ -2601,7 +2347,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // In the simple case, we just construct a new transitive subtype // witness, and we move on with life. - RefPtr<TransitiveSubtypeWitness> result = new TransitiveSubtypeWitness(); + RefPtr<TransitiveSubtypeWitness> result = astBuilder->create<TransitiveSubtypeWitness>(); result->sub = substSub; result->sup = substSup; result->subToMid = substSubToMid; @@ -2695,16 +2441,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return this; } - RefPtr<Val> ExtractExistentialType::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> ExtractExistentialType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.SubstituteImpl(subst, &diff); + auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); if(!diff) return this; (*ioDiff)++; - RefPtr<ExtractExistentialType> substValue = new ExtractExistentialType(); + RefPtr<ExtractExistentialType> substValue = astBuilder->create<ExtractExistentialType>(); substValue->declRef = declRef; return substValue; } @@ -2734,20 +2480,20 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return declRef.getHashCode(); } - RefPtr<Val> ExtractExistentialSubtypeWitness::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> ExtractExistentialSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.SubstituteImpl(subst, &diff); - auto substSub = sub->substituteImpl(subst, &diff).as<Type>(); - auto substSup = sup->substituteImpl(subst, &diff).as<Type>(); + auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); if(!diff) return this; (*ioDiff)++; - RefPtr<ExtractExistentialSubtypeWitness> substValue = new ExtractExistentialSubtypeWitness(); + RefPtr<ExtractExistentialSubtypeWitness> substValue = astBuilder->create<ExtractExistentialSubtypeWitness>(); substValue->declRef = declRef; substValue->sub = substSub; substValue->sup = substSup; @@ -2804,9 +2550,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt RefPtr<Type> TaggedUnionType::createCanonicalType() { - RefPtr<TaggedUnionType> canType = new TaggedUnionType(); - canType->setSession(getSession()); - + RefPtr<TaggedUnionType> canType = m_astBuilder->create<TaggedUnionType>(); + for( auto caseType : caseTypes ) { auto canCaseType = caseType->getCanonicalType(); @@ -2816,22 +2561,21 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return canType; } - RefPtr<Val> TaggedUnionType::substituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr<Val> TaggedUnionType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; List<RefPtr<Type>> substCaseTypes; for( auto caseType : caseTypes ) { - substCaseTypes.add(caseType->substituteImpl(subst, &diff).as<Type>()); + substCaseTypes.add(caseType->substituteImpl(astBuilder, subst, &diff).as<Type>()); } if(!diff) return this; (*ioDiff)++; - RefPtr<TaggedUnionType> substType = new TaggedUnionType(); - substType->setSession(getSession()); + RefPtr<TaggedUnionType> substType = astBuilder->create<TaggedUnionType>(); substType->caseTypes.swapWith(substCaseTypes); return substType; } @@ -2885,17 +2629,17 @@ HashCode TaggedUnionSubtypeWitness::getHashCode() return hash; } -RefPtr<Val> TaggedUnionSubtypeWitness::substituteImpl(SubstitutionSet subst, int* ioDiff) +RefPtr<Val> TaggedUnionSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substSub = sub->substituteImpl(subst, &diff).as<Type>(); - auto substSup = sup->substituteImpl(subst, &diff).as<Type>(); + auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); List<RefPtr<Val>> substCaseWitnesses; for( auto caseWitness : caseWitnesses ) { - substCaseWitnesses.add(caseWitness->substituteImpl(subst, &diff)); + substCaseWitnesses.add(caseWitness->substituteImpl(astBuilder, subst, &diff)); } if(!diff) @@ -2903,7 +2647,7 @@ RefPtr<Val> TaggedUnionSubtypeWitness::substituteImpl(SubstitutionSet subst, int (*ioDiff)++; - RefPtr<TaggedUnionSubtypeWitness> substWitness = new TaggedUnionSubtypeWitness(); + RefPtr<TaggedUnionSubtypeWitness> substWitness = astBuilder->create<TaggedUnionSubtypeWitness>(); substWitness->sub = substSub; substWitness->sup = substSup; substWitness->caseWitnesses.swapWith(substCaseWitnesses); @@ -3027,9 +2771,8 @@ RefPtr<Val> getCanonicalValue(Val* val) RefPtr<Type> ExistentialSpecializedType::createCanonicalType() { - RefPtr<ExistentialSpecializedType> canType = new ExistentialSpecializedType(); - canType->setSession(getSession()); - + RefPtr<ExistentialSpecializedType> canType = m_astBuilder->create<ExistentialSpecializedType>(); + canType->baseType = baseType->getCanonicalType(); for( auto arg : args ) { @@ -3041,24 +2784,24 @@ RefPtr<Type> ExistentialSpecializedType::createCanonicalType() return canType; } -RefPtr<Val> substituteImpl(Val* val, SubstitutionSet subst, int* ioDiff) +RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, Val* val, SubstitutionSet subst, int* ioDiff) { if(!val) return nullptr; - return val->substituteImpl(subst, ioDiff); + return val->substituteImpl(astBuilder, subst, ioDiff); } -RefPtr<Val> ExistentialSpecializedType::substituteImpl(SubstitutionSet subst, int* ioDiff) +RefPtr<Val> ExistentialSpecializedType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substBaseType = baseType->substituteImpl(subst, &diff).as<Type>(); + auto substBaseType = baseType->substituteImpl(astBuilder, subst, &diff).as<Type>(); ExpandedSpecializationArgs substArgs; for( auto arg : args ) { ExpandedSpecializationArg substArg; - substArg.val = Slang::substituteImpl(arg.val, subst, &diff); - substArg.witness = Slang::substituteImpl(arg.witness, subst, &diff); + substArg.val = Slang::substituteImpl(astBuilder, arg.val, subst, &diff); + substArg.witness = Slang::substituteImpl(astBuilder, arg.witness, subst, &diff); substArgs.add(substArg); } @@ -3067,8 +2810,7 @@ RefPtr<Val> ExistentialSpecializedType::substituteImpl(SubstitutionSet subst, in (*ioDiff)++; - RefPtr<ExistentialSpecializedType> substType = new ExistentialSpecializedType(); - substType->setSession(getSession()); + RefPtr<ExistentialSpecializedType> substType = astBuilder->create<ExistentialSpecializedType>(); substType->baseType = substBaseType; substType->args = substArgs; return substType; @@ -3107,19 +2849,18 @@ HashCode ThisType::getHashCode() RefPtr<Type> ThisType::createCanonicalType() { - RefPtr<ThisType> canType = new ThisType(); - canType->setSession(getSession()); - + RefPtr<ThisType> canType = m_astBuilder->create<ThisType>(); + // TODO: need to canonicalize the decl-ref canType->interfaceDeclRef = interfaceDeclRef; return canType; } -RefPtr<Val> ThisType::substituteImpl(SubstitutionSet subst, int* ioDiff) +RefPtr<Val> ThisType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substInterfaceDeclRef = interfaceDeclRef.SubstituteImpl(subst, &diff); + auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff); auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl()); if( thisTypeSubst ) @@ -3132,8 +2873,7 @@ RefPtr<Val> ThisType::substituteImpl(SubstitutionSet subst, int* ioDiff) (*ioDiff)++; - RefPtr<ThisType> substType = new ThisType(); - substType->setSession(getSession()); + RefPtr<ThisType> substType = m_astBuilder->create<ThisType>(); substType->interfaceDeclRef = substInterfaceDeclRef; return substType; } diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index d1e626f25..266f27599 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -1,45 +1,21 @@ #ifndef SLANG_SYNTAX_H #define SLANG_SYNTAX_H -#include "slang-ast-support-types.h" - -#include "slang-ast-all.h" +#include "slang-ast-builder.h" namespace Slang { - inline RefPtr<Type> GetSub(DeclRef<GenericTypeConstraintDecl> const& declRef) + inline RefPtr<Type> GetSub(ASTBuilder* astBuilder, DeclRef<GenericTypeConstraintDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->sub.Ptr()); + return declRef.substitute(astBuilder, declRef.getDecl()->sub.Ptr()); } - inline RefPtr<Type> GetSup(DeclRef<TypeConstraintDecl> const& declRef) + inline RefPtr<Type> GetSup(ASTBuilder* astBuilder, DeclRef<TypeConstraintDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->getSup().type); + return declRef.substitute(astBuilder, declRef.getDecl()->getSup().type); } - // Note(tfoley): These logically belong to `Type`, - // but order-of-declaration stuff makes that tricky - // - // TODO(tfoley): These should really belong to the compilation context! - // - void registerBuiltinDecl( - Session* session, - RefPtr<Decl> decl, - RefPtr<BuiltinTypeModifier> modifier); - void registerMagicDecl( - Session* session, - RefPtr<Decl> decl, - RefPtr<MagicTypeModifier> modifier); - - // Look up a magic declaration by its name - RefPtr<Decl> findMagicDecl( - Session* session, - String const& name); - - // Create an instance of a syntax class by name - SyntaxNodeBase* createInstanceOfSyntaxClassByName( - String const& name); // `Val` @@ -119,29 +95,29 @@ namespace Slang /// Name* getReflectionName(VarDeclBase* varDecl); - inline RefPtr<Type> GetType(DeclRef<VarDeclBase> const& declRef) + inline RefPtr<Type> GetType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) { - return declRef.Substitute(declRef.getDecl()->type.Ptr()); + return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline RefPtr<Expr> getInitExpr(DeclRef<VarDeclBase> const& declRef) + inline RefPtr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) { - return declRef.Substitute(declRef.getDecl()->initExpr); + return declRef.substitute(astBuilder, declRef.getDecl()->initExpr); } - inline RefPtr<Type> getType(DeclRef<EnumCaseDecl> const& declRef) + inline RefPtr<Type> getType(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->type.Ptr()); + return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline RefPtr<Expr> getTagExpr(DeclRef<EnumCaseDecl> const& declRef) + inline RefPtr<Expr> getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->tagExpr); + return declRef.substitute(astBuilder, declRef.getDecl()->tagExpr); } - inline RefPtr<Type> GetTargetType(DeclRef<ExtensionDecl> const& declRef) + inline RefPtr<Type> GetTargetType(ASTBuilder* astBuilder, DeclRef<ExtensionDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->targetType.Ptr()); + return declRef.substitute(astBuilder, declRef.getDecl()->targetType.Ptr()); } inline FilteredMemberRefList<VarDecl> GetFields(DeclRef<StructDecl> const& declRef, MemberFilterStyle filterStyle) @@ -151,19 +127,19 @@ namespace Slang - inline RefPtr<Type> getBaseType(DeclRef<InheritanceDecl> const& declRef) + inline RefPtr<Type> getBaseType(ASTBuilder* astBuilder, DeclRef<InheritanceDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->base.type); + return declRef.substitute(astBuilder, declRef.getDecl()->base.type); } - inline RefPtr<Type> GetType(DeclRef<TypeDefDecl> const& declRef) + inline RefPtr<Type> GetType(ASTBuilder* astBuilder, DeclRef<TypeDefDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->type.Ptr()); + return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline RefPtr<Type> GetResultType(DeclRef<CallableDecl> const& declRef) + inline RefPtr<Type> GetResultType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->returnType.type.Ptr()); + return declRef.substitute(astBuilder, declRef.getDecl()->returnType.type.Ptr()); } inline FilteredMemberRefList<ParamDecl> GetParameters(DeclRef<CallableDecl> const& declRef) @@ -182,33 +158,32 @@ namespace Slang // RefPtr<ArrayExpressionType> getArrayType( + ASTBuilder* astBuilder, Type* elementType, IntVal* elementCount); RefPtr<ArrayExpressionType> getArrayType( + ASTBuilder* astBuilder, Type* elementType); RefPtr<NamedExpressionType> getNamedType( - Session* session, + ASTBuilder* astBuilder, DeclRef<TypeDefDecl> const& declRef); - RefPtr<TypeType> getTypeType( - Type* type); - RefPtr<FuncType> getFuncType( - Session* session, + ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef); RefPtr<GenericDeclRefType> getGenericDeclRefType( - Session* session, + ASTBuilder* astBuilder, DeclRef<GenericDecl> const& declRef); RefPtr<NamespaceType> getNamespaceType( - Session* session, + ASTBuilder* astBuilder, DeclRef<NamespaceDeclBase> const& declRef); RefPtr<SamplerStateType> getSamplerStateType( - Session* session); + ASTBuilder* astBuilder); // Definitions that can't come earlier despite @@ -237,20 +212,20 @@ namespace Slang // TODO: where should this live? SubstitutionSet createDefaultSubstitutions( - Session* session, + ASTBuilder* astBuilder, Decl* decl, SubstitutionSet parentSubst); SubstitutionSet createDefaultSubstitutions( - Session* session, + ASTBuilder* astBuilder, Decl* decl); DeclRef<Decl> createDefaultSubstitutionsIfNeeded( - Session* session, + ASTBuilder* astBuilder, DeclRef<Decl> declRef); - RefPtr<GenericSubstitution> createDefaultSubsitutionsForGeneric( - Session* session, + RefPtr<GenericSubstitution> createDefaultSubstitutionsForGeneric( + ASTBuilder* astBuilder, GenericDecl* genericDecl, RefPtr<Substitutions> outerSubst); diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 29d6c66d4..3f005f18e 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1298,9 +1298,12 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout) { + auto astBuilder = targetReq->getLinkage()->getASTBuilder(); + LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq); TypeLayoutContext context; + context.astBuilder = astBuilder; context.targetReq = targetReq; context.programLayout = programLayout; context.rules = nullptr; @@ -3460,7 +3463,7 @@ static TypeLayoutResult _createTypeLayout( TypeLayoutResult fieldResult = _createTypeLayout( fieldLayoutContext, - GetType(field).Ptr(), + GetType(context.astBuilder, field).Ptr(), field.getDecl()); auto fieldTypeLayout = fieldResult.layout; diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index fdf980fa9..fbe95608e 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -912,6 +912,8 @@ struct LayoutRulesFamilyImpl struct TypeLayoutContext { + ASTBuilder* astBuilder; + // The layout rules to use (e.g., we compute // layout differently in a `cbuffer` vs. the // parameter list of a fragment shader). diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 2c698eeaf..486034876 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -118,26 +118,18 @@ void Session::init() // Set all the shared library function pointers to nullptr ::memset(m_sharedLibraryFunctions, 0, sizeof(m_sharedLibraryFunctions)); - - // Initialize the lookup table of syntax classes: - - // We can just iterate over the class pointers. - // NOTE! That this adds the names of the abstract classes too(!) - { - for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) - { - const ReflectClassInfo* info = ReflectClassInfo::getInfo(ASTNodeType(i)); - if (info) - { - mapNameToSyntaxClass.Add(getNamePool()->getName(info->m_name), SyntaxClass<Slang::RefObject>(info)); - } - } - } + // Set up shared AST builder + m_sharedASTBuilder = new SharedASTBuilder; + m_sharedASTBuilder->init(this); + // Use to create a ASTBuilder + RefPtr<ASTBuilder> builtinAstBuilder(new ASTBuilder(m_sharedASTBuilder)); + // Make sure our source manager is initialized builtinSourceManager.initialize(nullptr, nullptr); - m_builtinLinkage = new Linkage(this); + + m_builtinLinkage = new Linkage(this, builtinAstBuilder); // Because the `Session` retains the builtin `Linkage`, // we need to make sure that the parent pointer inside @@ -149,9 +141,7 @@ void Session::init() // m_builtinLinkage->_stopRetainingParentSession(); - // Initialize representations of some very basic types: - initializeTypes(); - + // Create scopes for various language builtins. // // TODO: load these on-demand to avoid parsing @@ -160,7 +150,7 @@ void Session::init() baseLanguageScope = new Scope(); auto baseModuleDecl = populateBaseLanguageModule( - this, + m_builtinLinkage->getASTBuilder(), baseLanguageScope); loadedModuleCode.add(baseModuleDecl); @@ -198,7 +188,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession( slang::SessionDesc const& desc, slang::ISession** outSession) { - RefPtr<Linkage> linkage = new Linkage(this); + RefPtr<ASTBuilder> astBuilder(new ASTBuilder(m_sharedASTBuilder)); + RefPtr<Linkage> linkage = new Linkage(this, astBuilder); Int targetCount = desc.targetCount; for(Int ii = 0; ii < targetCount; ++ii) @@ -507,10 +498,11 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) // -Linkage::Linkage(Session* session) +Linkage::Linkage(Session* session, ASTBuilder* astBuilder) : m_session(session) , m_retainedSession(session) , m_sourceManager(&m_defaultSourceManager) + , m_astBuilder(astBuilder) { getNamePool()->setRootNamePool(session->getRootNamePool()); @@ -868,7 +860,7 @@ RefPtr<Expr> Linkage::parseTermString(String typeStr, RefPtr<Scope> scope) nullptr); return parseTermFromSourceFile( - getSessionImpl(), + getASTBuilder(), tokens, &sink, scope, getNamePool(), SourceLanguage::Slang); } @@ -960,7 +952,7 @@ void FrontEndCompileRequest::parseTranslationUnit( combinedPreprocessorDefinitions.Add(def.Key, def.Value); auto module = translationUnit->getModule(); - RefPtr<ModuleDecl> translationUnitSyntax = new ModuleDecl(); + RefPtr<ModuleDecl> translationUnitSyntax = linkage->getASTBuilder()->create<ModuleDecl>(); translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName; translationUnitSyntax->module = module; module->setModuleDecl(translationUnitSyntax); @@ -980,7 +972,7 @@ void FrontEndCompileRequest::parseTranslationUnit( // if( m_isStandardLibraryCode ) { - translationUnitSyntax->modifiers.first = new FromStdLibModifier(); + translationUnitSyntax->modifiers.first = linkage->getASTBuilder()->create<FromStdLibModifier>(); } for (auto sourceFile : translationUnit->getSourceFiles()) @@ -994,6 +986,7 @@ void FrontEndCompileRequest::parseTranslationUnit( module); parseSourceFile( + linkage->getASTBuilder(), translationUnit, tokens, getSink(), @@ -1062,7 +1055,7 @@ void FrontEndCompileRequest::generateIR() // * it can generate diagnostics /// Generate IR for translation unit - RefPtr<IRModule> irModule(generateIRForTranslationUnit(translationUnit)); + RefPtr<IRModule> irModule(generateIRForTranslationUnit(getLinkage()->getASTBuilder(), translationUnit)); if (verifyDebugSerialization) { @@ -1212,7 +1205,8 @@ EndToEndCompileRequest::EndToEndCompileRequest( : m_session(session) , m_sink(nullptr) { - m_linkage = new Linkage(session); + RefPtr<ASTBuilder> astBuilder(new ASTBuilder(session->m_sharedASTBuilder)); + m_linkage = new Linkage(session, astBuilder); init(); } @@ -1534,7 +1528,7 @@ void Linkage::loadParsedModule( // If we didn't run into any errors, then try to generate // IR code for the imported module. SLANG_ASSERT(errorCountAfter == 0); - loadedModule->setIRModule(generateIRForTranslationUnit(translationUnit)); + loadedModule->setIRModule(generateIRForTranslationUnit(getASTBuilder(), translationUnit)); } loadedModulesList.add(loadedModule); } @@ -2337,7 +2331,7 @@ SpecializedComponentType::SpecializedComponentType( if(specializationInfo) funcDeclRef = specializationInfo->specializedFuncDeclRef; - (*mangledEntryPointNames).add(getMangledName(funcDeclRef)); + (*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef)); } void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE @@ -2346,12 +2340,18 @@ SpecializedComponentType::SpecializedComponentType( { visitChildren(composite, specializationInfo); } void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE { visitChildren(specialized); } + + EntryPointMangledNameCollector(ASTBuilder* astBuilder): + m_astBuilder(astBuilder) + { + } + ASTBuilder* m_astBuilder; }; // With the visitor defined, we apply it to ourself to compute // and collect the mangled entry point names. // - EntryPointMangledNameCollector collector; + EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder()); collector.mangledEntryPointNames = &m_entryPointMangledNames; collector.visitSpecialized(this); } @@ -2605,19 +2605,8 @@ void Session::addBuiltinSource( Session::~Session() { - // free all built-in types first - errorType = nullptr; - initializerListType = nullptr; - overloadedType = nullptr; - irBasicBlockType = nullptr; - constExprRate = nullptr; - destroyTypeCheckingCache(); - for (Index i = 0; i < SLANG_COUNT_OF(builtinTypes); ++i) - { - builtinTypes[i].setNull(); - } // destroy modules next loadedModuleCode = decltype(loadedModuleCode)(); } diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 8c815e1e1..c3e936f8a 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -191,6 +191,7 @@ <ClInclude Include="hlsl.meta.slang.h" /> <ClInclude Include="slang-ast-all.h" /> <ClInclude Include="slang-ast-base.h" /> + <ClInclude Include="slang-ast-builder.h" /> <ClInclude Include="slang-ast-decl.h" /> <ClInclude Include="slang-ast-dump.h" /> <ClInclude Include="slang-ast-expr.h" /> @@ -273,6 +274,7 @@ <ClInclude Include="slang-visitor.h" /> </ItemGroup> <ItemGroup> + <ClCompile Include="slang-ast-builder.cpp" /> <ClCompile Include="slang-ast-dump.cpp" /> <ClCompile Include="slang-ast-reflect.cpp" /> <ClCompile Include="slang-check-conformance.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 31e03d73d..bec8547e1 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -24,6 +24,9 @@ <ClInclude Include="slang-ast-base.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="slang-ast-builder.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="slang-ast-decl.h"> <Filter>Header Files</Filter> </ClInclude> @@ -266,6 +269,9 @@ </ClInclude> </ItemGroup> <ItemGroup> + <ClCompile Include="slang-ast-builder.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="slang-ast-dump.cpp"> <Filter>Source Files</Filter> </ClCompile> |
