diff options
Diffstat (limited to 'source')
67 files changed, 6678 insertions, 8127 deletions
diff --git a/source/core/slang-riff.cpp b/source/core/slang-riff.cpp index 0eb0381e1..c1e3e81c3 100644 --- a/source/core/slang-riff.cpp +++ b/source/core/slang-riff.cpp @@ -758,6 +758,24 @@ void RiffContainer::_addChunk(Chunk* chunk) } } +void RiffContainer::setCurrentChunk(Chunk* chunk) +{ + SLANG_ASSERT(chunk); + + switch (chunk->m_kind) + { + case Chunk::Kind::Data: + m_listChunk = nullptr; + m_dataChunk = static_cast<RiffContainer::DataChunk*>(chunk); + break; + + case Chunk::Kind::List: + m_dataChunk = nullptr; + m_listChunk = static_cast<RiffContainer::ListChunk*>(chunk); + break; + } +} + void RiffContainer::startChunk(Chunk::Kind kind, FourCC fourCC) { SLANG_ASSERT(m_listChunk || m_rootList == nullptr); @@ -857,7 +875,10 @@ void RiffContainer::setPayload(Data* data, const void* payload, size_t size) data->m_ownership = Ownership::Arena; data->m_size = size; - data->m_payload = m_arena.allocateAligned(size, kPayloadMinAlignment); + if (size) + { + data->m_payload = m_arena.allocateAligned(size, kPayloadMinAlignment); + } if (payload) { diff --git a/source/core/slang-riff.h b/source/core/slang-riff.h index 1e2c883b9..c858158e6 100644 --- a/source/core/slang-riff.h +++ b/source/core/slang-riff.h @@ -24,17 +24,11 @@ typedef uint32_t FourCC; #define SLANG_FOUR_CC(c0, c1, c2, c3) \ ((FourCC(c0) << 0) | (FourCC(c1) << 8) | (FourCC(c2) << 16) | (FourCC(c3) << 24)) -#define SLANG_FOUR_CC_GET_FIRST_CHAR(x) char((x) & 0xff) -#define SLANG_FOUR_CC_REPLACE_FIRST_CHAR(x, c) (((x) & 0xffffff00) | FourCC(c)) - #else #define SLANG_FOUR_CC(c0, c1, c2, c3) \ ((FourCC(c0) << 24) | (FourCC(c1) << 16) | (FourCC(c2) << 8) | (FourCC(c3) << 0)) -#define SLANG_FOUR_CC_GET_FIRST_CHAR(x) char((x) >> 24) -#define SLANG_FOUR_CC_REPLACE_FIRST_CHAR(x, c) (((x) & 0x00ffffff) | (FourCC(c) << 24)) - #endif enum @@ -451,6 +445,8 @@ public: /// Ctor RiffContainer(); + void setCurrentChunk(Chunk* chunk); + protected: void _addChunk(Chunk* chunk); ListChunk* _newListChunk(FourCC subType); diff --git a/source/slang-core-module/CMakeLists.txt b/source/slang-core-module/CMakeLists.txt index ba70d77b9..600190161 100644 --- a/source/slang-core-module/CMakeLists.txt +++ b/source/slang-core-module/CMakeLists.txt @@ -61,7 +61,6 @@ set(core_module_source_common_args core slang-capability-defs slang-fiddle-output - slang-reflect-headers SPIRV-Headers INCLUDE_DIRECTORIES_PRIVATE ../slang diff --git a/source/slang-wasm/CMakeLists.txt b/source/slang-wasm/CMakeLists.txt index c6c5601e9..152ed5094 100644 --- a/source/slang-wasm/CMakeLists.txt +++ b/source/slang-wasm/CMakeLists.txt @@ -17,7 +17,7 @@ if(EMSCRIPTEN) compiler-core slang-capability-defs slang-capability-lookup - slang-reflect-headers + slang-fiddle-output slang-lookup-tables INCLUDE_DIRECTORIES_PUBLIC ${slang_SOURCE_DIR}/include . ) diff --git a/source/slang/CMakeLists.txt b/source/slang/CMakeLists.txt index 2adc96939..daea0e002 100644 --- a/source/slang/CMakeLists.txt +++ b/source/slang/CMakeLists.txt @@ -97,60 +97,6 @@ slang_add_target( ) # -# generated headers for reflection -# - -set(SLANG_REFLECT_INPUT - slang-ast-support-types.h - slang-ast-base.h - slang-ast-decl.h - slang-ast-expr.h - slang-ast-modifier.h - slang-ast-stmt.h - slang-ast-type.h - slang-ast-val.h -) -# Make them absolute -list(TRANSFORM SLANG_REFLECT_INPUT PREPEND "${CMAKE_CURRENT_LIST_DIR}/") - -set(SLANG_REFLECT_GENERATED_HEADERS - slang-generated-obj.h - slang-generated-obj-macro.h - slang-generated-ast.h - slang-generated-ast-macro.h - slang-generated-value.h - slang-generated-value-macro.h -) -set(SLANG_REFLECT_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/ast-reflect") -list( - TRANSFORM SLANG_REFLECT_GENERATED_HEADERS - PREPEND "${SLANG_REFLECT_OUTPUT_DIR}/" -) - -add_custom_command( - OUTPUT ${SLANG_REFLECT_GENERATED_HEADERS} - COMMAND ${CMAKE_COMMAND} -E make_directory ${SLANG_REFLECT_OUTPUT_DIR} - COMMAND - slang-cpp-extractor ${SLANG_REFLECT_INPUT} -strip-prefix slang- -o - ${SLANG_REFLECT_OUTPUT_DIR}/slang-generated -output-fields -mark-suffix - _CLASS - DEPENDS ${SLANG_REFLECT_INPUT} slang-cpp-extractor - VERBATIM -) - -add_library( - slang-reflect-headers - INTERFACE - EXCLUDE_FROM_ALL - ${SLANG_REFLECT_GENERATED_HEADERS} -) -set_target_properties(slang-reflect-headers PROPERTIES FOLDER generated) -target_include_directories( - slang-reflect-headers - INTERFACE ${SLANG_REFLECT_OUTPUT_DIR} -) - -# # generated lookup tables # @@ -279,7 +225,6 @@ set(slang_link_args slang-capability-defs slang-capability-lookup slang-fiddle-output - slang-reflect-headers slang-lookup-tables SPIRV-Headers ) diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp index ac18da404..72e42a860 100644 --- a/source/slang/slang-ast-base.cpp +++ b/source/slang/slang-ast-base.cpp @@ -1,3 +1,4 @@ +// slang-ast-base.cpp #include "slang-ast-base.h" #include "slang-ast-builder.h" diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 8f85334d6..72da9cf56 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -2,25 +2,26 @@ #pragma once -#include "slang-ast-reflect.h" +#include "slang-ast-base.h.fiddle" +#include "slang-ast-forward-declarations.h" #include "slang-ast-support-types.h" #include "slang-capability.h" -#include "slang-generated-ast.h" -#include "slang-serialize-reflection.h" // This file defines the primary base classes for the hierarchy of // AST nodes and related objects. For example, this is where the // basic `Decl`, `Stmt`, `Expr`, `type`, etc. definitions come from. +FIDDLE() namespace Slang { class ASTBuilder; struct SemanticsVisitor; +FIDDLE(abstract) class NodeBase { - SLANG_ABSTRACT_AST_CLASS(NodeBase) + FIDDLE(...) // MUST be called before used. Called automatically via the ASTBuilder. // Note that the astBuilder is not stored in the NodeBase derived types by default. @@ -35,18 +36,12 @@ class NodeBase void _initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder); - /// Get the class info - SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const - { - return *ASTClassInfo::getInfo(astNodeType); - } - - SyntaxClass<NodeBase> getClass() { return SyntaxClass<NodeBase>(&getClassInfo()); } + SyntaxClass<NodeBase> getClass() const { return SyntaxClass<NodeBase>(astNodeType); } /// The type of the node. ASTNodeType(-1) is an invalid node type, and shouldn't appear on any /// correctly constructed (through ASTBuilder) NodeBase derived class. /// The actual type is set when constructed on the ASTBuilder. - ASTNodeType astNodeType = ASTNodeType(-1); + FIDDLE() ASTNodeType astNodeType = ASTNodeType(-1); #ifdef _DEBUG SLANG_UNREFLECTED int32_t _debugUID = 0; @@ -58,37 +53,25 @@ class NodeBase template<typename T> SLANG_FORCE_INLINE T* dynamicCast(NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<T*>(node) : nullptr; } template<typename T> SLANG_FORCE_INLINE const T* dynamicCast(const NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<const T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<const T*>(node) : nullptr; } template<typename T> SLANG_FORCE_INLINE T* as(NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<T*>(node) : nullptr; } template<typename T> SLANG_FORCE_INLINE const T* as(const NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<const T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<const T*>(node) : nullptr; } // Because DeclRefBase is now a `Val`, we prevent casting it directly into other nodes @@ -114,9 +97,10 @@ DeclRef<T> as(DeclRef<U> declRef) return DeclRef<T>(declRef); } -struct Scope : public NodeBase +FIDDLE() +class Scope : public NodeBase { - SLANG_AST_CLASS(Scope) + FIDDLE(...) // The container to use for lookup // @@ -135,12 +119,13 @@ struct Scope : public NodeBase // Base class for all nodes representing actual syntax // (thus having a location in the source code) +FIDDLE(abstract) class SyntaxNodeBase : public NodeBase { - SLANG_ABSTRACT_AST_CLASS(SyntaxNodeBase) + FIDDLE(...) // The primary source location associated with this AST node - SourceLoc loc; + FIDDLE() SourceLoc loc; }; enum class ValNodeOperandKind @@ -231,7 +216,7 @@ private: HashCode hashCode = 0; public: - ASTNodeType type; + SyntaxClass<NodeBase> type; ShortList<ValNodeOperand, 8> operands; inline bool operator==(ValNodeDesc const& that) const @@ -363,9 +348,10 @@ static void addOrAppendToNodeList(List<ValNodeOperand>& list, ArrayView<T> l, Ts // a unique location, and any two `Val`s representing // the same value should be conceptually equal. +FIDDLE(abstract) class Val : public NodeBase { - SLANG_ABSTRACT_AST_CLASS(Val) + FIDDLE(...) template<typename T> struct OperandView @@ -406,10 +392,6 @@ class Val : public NodeBase ConstIterator end() const { return ConstIterator{val, offset + count}; } }; - typedef IValVisitor Visitor; - - void accept(IValVisitor* visitor, void* extra); - // construct a new value by applying a set of parameter // substitutions to this one Val* substitute(ASTBuilder* astBuilder, SubstitutionSet subst); @@ -479,7 +461,7 @@ class Val : public NodeBase for (auto v : operands) m_operands.add(ValNodeOperand(v)); } - List<ValNodeOperand> m_operands; + FIDDLE() List<ValNodeOperand> m_operands; // Private use by the core module deserialization only. Since we know the Vals serialized into // the core module is already unique, we can just use `this` pointer as the `m_resolvedVal` so @@ -567,13 +549,10 @@ SLANG_FORCE_INLINE const T* as(const Type* obj); // "canonical" type. The representation caches a pointer to // a canonical type on every type, so we can easily // operate on the raw representation when needed. +FIDDLE(abstract) class Type : public Val { - SLANG_ABSTRACT_AST_CLASS(Type) - - typedef ITypeVisitor Visitor; - - void accept(ITypeVisitor* visitor, void* extra); + FIDDLE(...) /// Type derived types store the AST builder they were constructed on. The builder calls this /// function after constructing. @@ -618,9 +597,10 @@ class Decl; // A reference to a declaration, which may include // substitutions for generic parameters. +FIDDLE(abstract) class DeclRefBase : public Val { - SLANG_ABSTRACT_AST_CLASS(DeclRefBase) + FIDDLE(...) Decl* getDecl() const { return getDeclOperand(0); } @@ -687,9 +667,10 @@ SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Decl* decl) return io; } +FIDDLE(abstract) class SyntaxNode : public SyntaxNodeBase { - SLANG_ABSTRACT_AST_CLASS(SyntaxNode); + FIDDLE(...) }; // @@ -697,29 +678,28 @@ class SyntaxNode : public SyntaxNodeBase // (that is, we don't use a bitfield, even for simple/common flags). // This ensures that we can track source locations for all modifiers. // +FIDDLE(abstract) class Modifier : public SyntaxNode { - SLANG_ABSTRACT_AST_CLASS(Modifier) - typedef IModifierVisitor Visitor; - - void accept(IModifierVisitor* visitor, void* extra); + FIDDLE(...) // Next modifier in linked list of modifiers on same piece of syntax Modifier* next = nullptr; // The keyword that was used to introduce t that was used to name this modifier. - Name* keywordName = nullptr; + FIDDLE() Name* keywordName = nullptr; Name* getKeywordName() { return keywordName; } NameLoc getKeywordNameAndLoc() { return NameLoc(keywordName, loc); } }; // A syntax node which can have modifiers applied +FIDDLE(abstract) class ModifiableSyntaxNode : public SyntaxNode { - SLANG_ABSTRACT_AST_CLASS(ModifiableSyntaxNode) + FIDDLE(...) - Modifiers modifiers; + FIDDLE() Modifiers modifiers; template<typename T> FilteredModifierList<T> getModifiersOfType() @@ -748,28 +728,25 @@ struct ProvenenceNodeWithLoc }; // An intermediate type to represent either a single declaration, or a group of declarations +FIDDLE(abstract) class DeclBase : public ModifiableSyntaxNode { - SLANG_ABSTRACT_AST_CLASS(DeclBase) - - typedef IDeclVisitor Visitor; - - void accept(IDeclVisitor* visitor, void* extra); + FIDDLE(...) }; +FIDDLE(abstract) class Decl : public DeclBase { + FIDDLE(...) public: - SLANG_ABSTRACT_AST_CLASS(Decl) - - ContainerDecl* parentDecl = nullptr; + FIDDLE() ContainerDecl* parentDecl = nullptr; DeclRefBase* getDefaultDeclRef(); - NameLoc nameAndLoc; - CapabilitySet inferredCapabilityRequirements; + FIDDLE() NameLoc nameAndLoc; + FIDDLE() CapabilitySet inferredCapabilityRequirements; - RefPtr<MarkupEntry> markup; + FIDDLE() RefPtr<MarkupEntry> markup; Name* getName() const { return nameAndLoc.name; } SourceLoc getNameLoc() const { return nameAndLoc.loc; } @@ -797,26 +774,20 @@ private: SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr; }; +FIDDLE(abstract) class Expr : public SyntaxNode { - SLANG_ABSTRACT_AST_CLASS(Expr) - - typedef IExprVisitor Visitor; + FIDDLE(...) - QualType type; + FIDDLE() QualType type; bool checked = false; - - void accept(IExprVisitor* visitor, void* extra); }; +FIDDLE(abstract) class Stmt : public ModifiableSyntaxNode { - SLANG_ABSTRACT_AST_CLASS(Stmt) - - typedef IStmtVisitor Visitor; - - void accept(IStmtVisitor* visitor, void* extra); + FIDDLE(...) }; template<typename T> diff --git a/source/slang/slang-ast-boilerplate.cpp b/source/slang/slang-ast-boilerplate.cpp new file mode 100644 index 000000000..0313d4411 --- /dev/null +++ b/source/slang/slang-ast-boilerplate.cpp @@ -0,0 +1,54 @@ +// slang-ast-boilerplate.cpp + +#include "slang-ast-all.h" +#include "slang-ast-builder.h" +#include "slang-ast-forward-declarations.h" + +namespace Slang +{ +template<typename T> +struct Helper +{ + static void* create(ASTBuilder* builder) { return builder->createImpl<T>(); } + + static void destruct(void* obj) { ((T*)obj)->~T(); } +}; + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do +const SyntaxClassInfo $T::kSyntaxClassInfo = { + "$T", + ASTNodeType::$T, + $(#T.subclasses), +% if T.isAbstract then + nullptr, // create + nullptr, // destruct +% else + &Helper<$T>::create, + &Helper<$T>::destruct, +% end +}; +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-boilerplate.cpp.fiddle" +#endif // FIDDLE END + +static SyntaxClassInfo const* kAllSyntaxClasses[] = { +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + &$T::kSyntaxClassInfo, +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-ast-boilerplate.cpp.fiddle" +#endif // FIDDLE END +}; + +SyntaxClassBase::SyntaxClassBase(ASTNodeType tag) +{ + assert(int(tag) >= 0 && int(tag) < SLANG_COUNT_OF(kAllSyntaxClasses)); + _info = kAllSyntaxClasses[int(tag)]; +} + +} // namespace Slang diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index b3afa5310..5abef94b3 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -33,46 +33,34 @@ void SharedASTBuilder::init(Session* session) // NOTE! That this adds the names of the abstract classes too(!) for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) { - const ReflectClassInfo* info = ASTClassInfo::getInfo(ASTNodeType(i)); - if (info) - { - m_sliceToTypeMap.add(UnownedStringSlice(info->m_name), info); - Name* name = m_namePool->getName(String(info->m_name)); - m_nameToTypeMap.add(name, info); - } + auto syntaxClass = SyntaxClass(ASTNodeType(i)); + if (!syntaxClass) + continue; + auto nameText = syntaxClass.getName(); + m_sliceToTypeMap.add(nameText, syntaxClass); + Name* nameObj = m_namePool->getName(nameText); + m_nameToTypeMap.add(nameObj, syntaxClass); } } -const ReflectClassInfo* SharedASTBuilder::findClassInfo(const UnownedStringSlice& slice) -{ - const ReflectClassInfo* typeInfo; - return m_sliceToTypeMap.tryGetValue(slice, typeInfo) ? typeInfo : nullptr; -} - -SyntaxClass<NodeBase> SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice) +SyntaxClass<> SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice) { - const ReflectClassInfo* typeInfo; + SyntaxClass typeInfo; if (m_sliceToTypeMap.tryGetValue(slice, typeInfo)) { - return SyntaxClass<NodeBase>(typeInfo); + return typeInfo; } - return SyntaxClass<NodeBase>(); -} - -const ReflectClassInfo* SharedASTBuilder::findClassInfo(Name* name) -{ - const ReflectClassInfo* typeInfo; - return m_nameToTypeMap.tryGetValue(name, typeInfo) ? typeInfo : nullptr; + return getSyntaxClass<NodeBase>(); } SyntaxClass<NodeBase> SharedASTBuilder::findSyntaxClass(Name* name) { - const ReflectClassInfo* typeInfo; + SyntaxClass<NodeBase> typeInfo; if (m_nameToTypeMap.tryGetValue(name, typeInfo)) { - return SyntaxClass<NodeBase>(typeInfo); + return typeInfo; } - return SyntaxClass<NodeBase>(); + return getSyntaxClass<NodeBase>(); } Type* SharedASTBuilder::getStringType() @@ -256,9 +244,8 @@ ASTBuilder::~ASTBuilder() { for (NodeBase* node : m_dtorNodes) { - const ReflectClassInfo* info = ASTClassInfo::getInfo(node->astNodeType); - SLANG_ASSERT(info->m_destructorFunc); - info->m_destructorFunc(node); + auto nodeClass = node->getClass(); + nodeClass.destructInstance(node); } incrementEpoch(); } @@ -275,16 +262,8 @@ void ASTBuilder::incrementEpoch() NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) { - const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType); - - auto createFunc = info->m_createFunc; - SLANG_ASSERT(createFunc); - if (!createFunc) - { - return nullptr; - } - - return (NodeBase*)createFunc(this); + auto syntaxClass = SyntaxClass<NodeBase>(nodeType); + return syntaxClass.createInstance(this); } Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index daf49f3f7..a25fcea28 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -46,10 +46,8 @@ public: Type* getInitializerListType(); Type* getOverloadedType(); - const ReflectClassInfo* findClassInfo(Name* name); SyntaxClass<NodeBase> findSyntaxClass(Name* name); - const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice); SyntaxClass<NodeBase> findSyntaxClass(const UnownedStringSlice& slice); // Look up a magic declaration by its name @@ -113,8 +111,8 @@ protected: Dictionary<String, Decl*> m_magicDecls; Dictionary<BuiltinRequirementKind, Decl*> m_builtinRequirementDecls; - Dictionary<UnownedStringSlice, const ReflectClassInfo*> m_sliceToTypeMap; - Dictionary<Name*, const ReflectClassInfo*> m_nameToTypeMap; + Dictionary<UnownedStringSlice, SyntaxClass<NodeBase>> m_sliceToTypeMap; + Dictionary<Name*, SyntaxClass<NodeBase>> m_nameToTypeMap; NamePool* m_namePool = nullptr; @@ -160,7 +158,7 @@ struct ValKey { if (hashCode != desc.getHashCode()) return false; - if (val->astNodeType != desc.type) + if (val->getClass() != desc.type) return false; if (val->m_operands.getCount() != desc.operands.getCount()) return false; @@ -199,7 +197,7 @@ public: if (auto found = m_cachedNodes.tryGetValue(desc)) return *found; - auto node = as<Val>(createByNodeType(desc.type)); + auto node = as<Val>(desc.type.createInstance(this)); SLANG_ASSERT(node); for (auto& operand : desc.operands) node->m_operands.add(operand); @@ -268,12 +266,14 @@ public: MemoryArena& getArena() { return m_arena; } + NamePool* getNamePool() { return getSharedASTBuilder()->getNamePool(); } + template<typename T, typename... TArgs> SLANG_FORCE_INLINE T* getOrCreate(TArgs... args) { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); ValNodeDesc desc; - desc.type = T::kType; + desc.type = getSyntaxClass<T>(); addOrAppendToNodeList(desc.operands, args...); desc.init(); auto result = (T*)_getOrCreateImpl(_Move(desc)); @@ -286,7 +286,7 @@ public: SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); ValNodeDesc desc; - desc.type = T::kType; + desc.type = getSyntaxClass<T>(); desc.init(); auto result = (T*)_getOrCreateImpl(_Move(desc)); return result; @@ -642,19 +642,11 @@ public: DeclRef<Decl> declRef); /// Helpers to get type info from the SharedASTBuilder - const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) - { - return m_sharedASTBuilder->findClassInfo(slice); - } SyntaxClass<NodeBase> findSyntaxClass(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findSyntaxClass(slice); } - const ReflectClassInfo* findClassInfo(Name* name) - { - return m_sharedASTBuilder->findClassInfo(name); - } SyntaxClass<NodeBase> findSyntaxClass(Name* name) { return m_sharedASTBuilder->findSyntaxClass(name); @@ -695,12 +687,12 @@ protected: // Keep such that dtor can be run on ASTBuilder being dtored m_dtorNodes.add(node); } - if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Val::kType))) + if (node->getClass().isSubClassOf(getSyntaxClass<Val>())) { auto val = (Val*)(node); val->m_resolvedValEpoch = getEpoch(); } - else if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType))) + else if (node->getClass().isSubClassOf(getSyntaxClass<Decl>())) { ((Decl*)node)->m_defaultDeclRef = getOrCreate<DirectDeclRef>((Decl*)node); } diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index 9f140a524..a44e5b817 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -1,8 +1,9 @@ +// slang-ast-decl-ref.cpp + #include "slang-ast-builder.h" -#include "slang-ast-reflect.h" +#include "slang-ast-dispatch.h" +#include "slang-ast-forward-declarations.h" #include "slang-check-impl.h" -#include "slang-generated-ast-macro.h" -#include "slang-generated-ast.h" namespace Slang { diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index c0d0e9242..530f983d9 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -2,7 +2,7 @@ #include "slang-ast-decl.h" #include "slang-ast-builder.h" -#include "slang-generated-ast-macro.h" +#include "slang-ast-dispatch.h" #include "slang-syntax.h" #include <assert.h> @@ -12,7 +12,7 @@ namespace Slang const TypeExp& TypeConstraintDecl::getSup() const { - SLANG_AST_NODE_CONST_VIRTUAL_CALL(TypeConstraintDecl, getSup, ()) + SLANG_AST_NODE_VIRTUAL_CALL(TypeConstraintDecl, getSup, ()) } const TypeExp& TypeConstraintDecl::_getSupOverride() const diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index ff55340ac..261d2458a 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -3,31 +3,35 @@ #pragma once #include "slang-ast-base.h" +#include "slang-ast-decl.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for declarations. // A group of declarations that should be treated as a unit +FIDDLE() class DeclGroup : public DeclBase { - SLANG_AST_CLASS(DeclGroup) - - List<Decl*> decls; + FIDDLE(...) + FIDDLE() List<Decl*> decls; }; +FIDDLE() class UnresolvedDecl : public Decl { - SLANG_AST_CLASS(UnresolvedDecl) + FIDDLE(...) }; // A "container" decl is a parent to other declarations +FIDDLE(abstract) class ContainerDecl : public Decl { - SLANG_ABSTRACT_AST_CLASS(ContainerDecl) + FIDDLE(...) - List<Decl*> members; + FIDDLE() List<Decl*> members; SourceLoc closingSourceLoc; // The associated scope owned by this decl. @@ -86,32 +90,35 @@ class ContainerDecl : public Decl }; // Base class for all variable declarations +FIDDLE(abstract) class VarDeclBase : public Decl { - SLANG_ABSTRACT_AST_CLASS(VarDeclBase) + FIDDLE(...) // type of the variable - TypeExp type; + FIDDLE() TypeExp type; Type* getType() { return type.type; } // Initializer expression (optional) - Expr* initExpr = nullptr; + FIDDLE() Expr* initExpr = nullptr; // Folded IntVal if the initializer is a constant integer. - IntVal* val = nullptr; + FIDDLE() IntVal* val = nullptr; }; // Ordinary potentially-mutable variables (locals, globals, and member variables) +FIDDLE() class VarDecl : public VarDeclBase { - SLANG_AST_CLASS(VarDecl) + FIDDLE(...) }; // A variable declaration that is always immutable (whether local, global, or member variable) +FIDDLE() class LetDecl : public VarDecl { - SLANG_AST_CLASS(LetDecl) + FIDDLE(...) }; // An `AggTypeDeclBase` captures the shared functionality @@ -122,17 +129,18 @@ class LetDecl : public VarDecl // - Both can have declared bases // - Both expose a `this` variable in their body // +FIDDLE(abstract) class AggTypeDeclBase : public ContainerDecl { - SLANG_ABSTRACT_AST_CLASS(AggTypeDeclBase); + FIDDLE(...) }; // An extension to apply to an existing type +FIDDLE() class ExtensionDecl : public AggTypeDeclBase { - SLANG_AST_CLASS(ExtensionDecl) - - TypeExp targetType; + FIDDLE(...) + FIDDLE() TypeExp targetType; }; enum class TypeTag @@ -145,11 +153,11 @@ enum class TypeTag }; // Declaration of a type that represents some sort of aggregate +FIDDLE(abstract) class AggTypeDecl : public AggTypeDeclBase { - SLANG_ABSTRACT_AST_CLASS(AggTypeDecl) - - TypeTag typeTags = TypeTag::None; + FIDDLE(...) + FIDDLE() TypeTag typeTags = TypeTag::None; // Used if this type declaration is a wrapper, i.e. struct FooWrapper:IFoo = Foo; TypeExp wrappedType; @@ -162,23 +170,25 @@ class AggTypeDecl : public AggTypeDeclBase FilteredMemberList<VarDecl> getFields() { return getMembersOfType<VarDecl>(); } }; +FIDDLE() class StructDecl : public AggTypeDecl { - SLANG_AST_CLASS(StructDecl); - + FIDDLE(...) SLANG_UNREFLECTED // We will use these auxiliary to help in synthesizing the member initialize constructor. Slang::HashSet<VarDeclBase*> m_membersVisibleInCtor; }; +FIDDLE() class ClassDecl : public AggTypeDecl { - SLANG_AST_CLASS(ClassDecl) + FIDDLE(...) }; +FIDDLE() class GLSLInterfaceBlockDecl : public AggTypeDecl { - SLANG_AST_CLASS(GLSLInterfaceBlockDecl); + FIDDLE(...) }; // TODO: Is it appropriate to treat an `enum` as an aggregate type? @@ -186,11 +196,11 @@ class GLSLInterfaceBlockDecl : public AggTypeDecl // types are all `AggTypeDecl`, so this is the right choice for now // if we want `enum` types to be able to implement interfaces, etc. // +FIDDLE() class EnumDecl : public AggTypeDecl { - SLANG_AST_CLASS(EnumDecl) - - Type* tagType = nullptr; + FIDDLE(...) + FIDDLE() Type* tagType = nullptr; }; // A single case in an enum. @@ -203,39 +213,40 @@ class EnumDecl : public AggTypeDecl // case, with `0` as an explicit expression for its // _tag value_. // +FIDDLE() class EnumCaseDecl : public Decl { - SLANG_AST_CLASS(EnumCaseDecl) - + FIDDLE(...) // type of the parent `enum` - TypeExp type; + FIDDLE() TypeExp type; Type* getType() { return type.type; } // Tag value - Expr* tagExpr = nullptr; + FIDDLE() Expr* tagExpr = nullptr; - IntVal* tagVal = nullptr; + FIDDLE() IntVal* tagVal = nullptr; }; // A member of InterfaceDecl representing the abstract ThisType. +FIDDLE() class ThisTypeDecl : public AggTypeDecl { - SLANG_AST_CLASS(ThisTypeDecl) + FIDDLE(...) }; // An interface which other types can conform to +FIDDLE() class InterfaceDecl : public AggTypeDecl { - SLANG_AST_CLASS(InterfaceDecl) - + FIDDLE(...) ThisTypeDecl* getThisTypeDecl(); }; +FIDDLE(abstract) class TypeConstraintDecl : public Decl { - SLANG_ABSTRACT_AST_CLASS(TypeConstraintDecl) - + FIDDLE(...) const TypeExp& getSup() const; // Overrides should be public so base classes can access // Implement _getSupOverride on derived classes to change behavior of getSup, as if getSup is @@ -243,11 +254,11 @@ class TypeConstraintDecl : public Decl const TypeExp& _getSupOverride() const; }; +FIDDLE() class ThisTypeConstraintDecl : public TypeConstraintDecl { - SLANG_AST_CLASS(ThisTypeConstraintDecl) - - TypeExp base; + FIDDLE(...) + FIDDLE() TypeExp base; const TypeExp& _getSupOverride() const { return base; } InterfaceDecl* getInterfaceDecl(); }; @@ -255,18 +266,18 @@ class ThisTypeConstraintDecl : public TypeConstraintDecl // A kind of pseudo-member that represents an explicit // or implicit inheritance relationship. // +FIDDLE() class InheritanceDecl : public TypeConstraintDecl { - SLANG_AST_CLASS(InheritanceDecl) - + FIDDLE(...) // The type expression as written - TypeExp base; + FIDDLE() TypeExp base; // After checking, this dictionary will map members // required by the base type to their concrete // implementations in the type that contains // this inheritance declaration. - RefPtr<WitnessTable> witnessTable; + FIDDLE() RefPtr<WitnessTable> witnessTable; // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return base; } @@ -279,74 +290,82 @@ class InheritanceDecl : public TypeConstraintDecl // // TODO: probably all types will be aggregate decls eventually, // so that we can easily store conformances/constraints on type variables +FIDDLE(abstract) class SimpleTypeDecl : public Decl { - SLANG_ABSTRACT_AST_CLASS(SimpleTypeDecl) + FIDDLE(...) }; // A `typedef` declaration +FIDDLE() class TypeDefDecl : public SimpleTypeDecl { - SLANG_AST_CLASS(TypeDefDecl) - - TypeExp type; + FIDDLE(...) + FIDDLE() TypeExp type; }; +FIDDLE() class TypeAliasDecl : public TypeDefDecl { - SLANG_AST_CLASS(TypeAliasDecl) + FIDDLE(...) }; // An 'assoctype' declaration, it is a container of inheritance clauses +FIDDLE() class AssocTypeDecl : public AggTypeDecl { - SLANG_AST_CLASS(AssocTypeDecl) + FIDDLE(...) }; // A 'type_param' declaration, which defines a generic // entry-point parameter. Is a container of GenericTypeConstraintDecl +FIDDLE() class GlobalGenericParamDecl : public AggTypeDecl { - SLANG_AST_CLASS(GlobalGenericParamDecl) + FIDDLE(...) }; // A `__generic_value_param` declaration, which defines an existential // value parameter (not a type parameter. +FIDDLE() class GlobalGenericValueParamDecl : public VarDeclBase { - SLANG_AST_CLASS(GlobalGenericValueParamDecl) + FIDDLE(...) }; // A scope for local declarations (e.g., as part of a statement) +FIDDLE() class ScopeDecl : public ContainerDecl { - SLANG_AST_CLASS(ScopeDecl) + FIDDLE(...) }; // A function/initializer/subscript parameter (potentially mutable) +FIDDLE() class ParamDecl : public VarDeclBase { - SLANG_AST_CLASS(ParamDecl) + FIDDLE(...) }; // A parameter of a function declared in "modern" types (immutable unless explicitly `out` or // `inout`) +FIDDLE() class ModernParamDecl : public ParamDecl { - SLANG_AST_CLASS(ModernParamDecl) + FIDDLE(...) }; // Base class for things that have parameter lists and can thus be applied to arguments ("called") +FIDDLE(abstract) class CallableDecl : public ContainerDecl { - SLANG_ABSTRACT_AST_CLASS(CallableDecl) - + FIDDLE(...) FilteredMemberList<ParamDecl> getParameters() { return getMembersOfType<ParamDecl>(); } - TypeExp returnType; + FIDDLE() TypeExp returnType; // If this callable throws an error code, `errorType` is the type of the error code. - TypeExp errorType; + FIDDLE() TypeExp errorType; // Fields related to redeclaration, so that we // can support multiple specialized variations @@ -366,18 +385,18 @@ class CallableDecl : public ContainerDecl // Base class for callable things that may also have a body that is evaluated to produce their // result +FIDDLE(abstract) class FunctionDeclBase : public CallableDecl { - SLANG_ABSTRACT_AST_CLASS(FunctionDeclBase) - - Stmt* body = nullptr; + FIDDLE(...) + FIDDLE() Stmt* body = nullptr; }; // A constructor/initializer to create instances of a type +FIDDLE() class ConstructorDecl : public FunctionDeclBase { - SLANG_AST_CLASS(ConstructorDecl) - + FIDDLE(...) enum class ConstructorFlavor : int { UserDefined = 0x00, @@ -389,51 +408,61 @@ class ConstructorDecl : public FunctionDeclBase SynthesizedMemberInit = 0x02 }; - int m_flavor = (int)ConstructorFlavor::UserDefined; + FIDDLE() int m_flavor = (int)ConstructorFlavor::UserDefined; void addFlavor(ConstructorFlavor flavor) { m_flavor |= (int)flavor; } bool containsFlavor(ConstructorFlavor flavor) { return m_flavor & (int)flavor; } }; // A subscript operation used to index instances of a type +FIDDLE() class SubscriptDecl : public CallableDecl { - SLANG_AST_CLASS(SubscriptDecl) + FIDDLE(...) }; /// A property declaration that abstracts over storage with a getter/setter/etc. +FIDDLE() class PropertyDecl : public ContainerDecl { - SLANG_AST_CLASS(PropertyDecl) - - TypeExp type; + FIDDLE(...) + FIDDLE() TypeExp type; }; // An "accessor" for a subscript or property +FIDDLE(abstract) class AccessorDecl : public FunctionDeclBase { - SLANG_AST_CLASS(AccessorDecl) + FIDDLE(...) }; +FIDDLE() class GetterDecl : public AccessorDecl { - SLANG_AST_CLASS(GetterDecl) + FIDDLE(...) }; + +FIDDLE() class SetterDecl : public AccessorDecl { - SLANG_AST_CLASS(SetterDecl) + FIDDLE(...) }; + +FIDDLE() class RefAccessorDecl : public AccessorDecl { - SLANG_AST_CLASS(RefAccessorDecl) + FIDDLE(...) }; + +FIDDLE() class FuncDecl : public FunctionDeclBase { - SLANG_AST_CLASS(FuncDecl) + FIDDLE(...) }; +FIDDLE(abstract) class NamespaceDeclBase : public ContainerDecl { - SLANG_AST_CLASS(NamespaceDeclBase) + FIDDLE(...) }; // A `namespace` declaration inside some module, that provides @@ -444,16 +473,19 @@ class NamespaceDeclBase : public ContainerDecl // `NamespaceDecl` during parsing, so this declaration does // not directly represent what is present in the input syntax. // +FIDDLE() class NamespaceDecl : public NamespaceDeclBase { - SLANG_AST_CLASS(NamespaceDecl) + FIDDLE(...) }; // A "module" of code (essentially, a single translation unit) // that provides a scope for some number of declarations. +FIDDLE() class ModuleDecl : public NamespaceDeclBase { - SLANG_AST_CLASS(ModuleDecl) + FIDDLE(...) + // The API-level module that this declaration belong to. // // This field allows lookup of the `Module` based on a @@ -467,7 +499,7 @@ class ModuleDecl : public NamespaceDeclBase /// This mapping is filled in during semantic checking, as the decl declarations get checked or /// generated. /// - OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> mapDeclToAssociatedDecls; + FIDDLE() OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> mapDeclToAssociatedDecls; /// Whether the module is defined in legacy language. /// The legacy Slang language does not have visibility modifiers and everything is treated as @@ -477,9 +509,9 @@ class ModuleDecl : public NamespaceDeclBase /// visibility modifiers, or if the module uses new language constructs, e.g. `module`, /// `__include`, /// `__implementing` etc. - bool isInLegacyLanguage = true; + FIDDLE() bool isInLegacyLanguage = true; - DeclVisibility defaultVisibility = DeclVisibility::Internal; + FIDDLE() DeclVisibility defaultVisibility = DeclVisibility::Internal; SLANG_UNREFLECTED @@ -487,19 +519,21 @@ class ModuleDecl : public NamespaceDeclBase /// /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. /// - Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions; + FIDDLE() Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions; }; // Represents a transparent scope of declarations that are defined in a single source file. +FIDDLE() class FileDecl : public ContainerDecl { - SLANG_AST_CLASS(FileDecl); + FIDDLE(...) }; /// A declaration that brings members of another declaration or namespace into scope +FIDDLE() class UsingDecl : public Decl { - SLANG_AST_CLASS(UsingDecl) + FIDDLE(...) /// An expression that identifies the entity (e.g., a namespace) to be brought into `scope` Expr* arg = nullptr; @@ -509,9 +543,10 @@ class UsingDecl : public Decl Scope* scope = nullptr; }; +FIDDLE() class FileReferenceDeclBase : public Decl { - SLANG_AST_CLASS(FileReferenceDeclBase) + FIDDLE(...) // The name of the module we are trying to import NameLoc moduleNameAndLoc; @@ -524,107 +559,115 @@ class FileReferenceDeclBase : public Decl Scope* scope = nullptr; }; +FIDDLE() class ImportDecl : public FileReferenceDeclBase { - SLANG_AST_CLASS(ImportDecl) + FIDDLE(...) // The module that actually got imported - ModuleDecl* importedModuleDecl = nullptr; + FIDDLE() ModuleDecl* importedModuleDecl = nullptr; }; +FIDDLE(abstract) class IncludeDeclBase : public FileReferenceDeclBase { - SLANG_AST_CLASS(IncludeDeclBase) - + FIDDLE(...) FileDecl* fileDecl = nullptr; }; +FIDDLE() class IncludeDecl : public IncludeDeclBase { - SLANG_AST_CLASS(IncludeDecl) + FIDDLE(...) }; +FIDDLE() class ImplementingDecl : public IncludeDeclBase { - SLANG_AST_CLASS(ImplementingDecl) + FIDDLE(...) }; +FIDDLE() class ModuleDeclarationDecl : public Decl { - SLANG_AST_CLASS(ModuleDeclarationDecl) + FIDDLE(...) }; +FIDDLE() class RequireCapabilityDecl : public Decl { - SLANG_AST_CLASS(RequireCapabilityDecl) + FIDDLE(...) }; // A generic declaration, parameterized on types/values +FIDDLE() class GenericDecl : public ContainerDecl { - SLANG_AST_CLASS(GenericDecl) + FIDDLE(...) // The decl that is genericized... - Decl* inner = nullptr; + FIDDLE() Decl* inner = nullptr; }; +FIDDLE(abstract) class GenericTypeParamDeclBase : public SimpleTypeDecl { - SLANG_AST_CLASS(GenericTypeParamDeclBase) - + FIDDLE(...) // The index of the generic parameter. int parameterIndex = -1; }; +FIDDLE() class GenericTypeParamDecl : public GenericTypeParamDeclBase { - SLANG_AST_CLASS(GenericTypeParamDecl) + FIDDLE(...) // The bound for the type parameter represents a trait that any // type used as this parameter must conform to // TypeExp bound; // The "initializer" for the parameter represents a default value - TypeExp initType; + FIDDLE() TypeExp initType; }; +FIDDLE() class GenericTypePackParamDecl : public GenericTypeParamDeclBase { - SLANG_AST_CLASS(GenericTypePackParamDecl) + FIDDLE(...) }; // A constraint placed as part of a generic declaration +FIDDLE() class GenericTypeConstraintDecl : public TypeConstraintDecl { - SLANG_AST_CLASS(GenericTypeConstraintDecl) - + FIDDLE(...) // A type constraint like `T : U` is constraining `T` to be "below" `U` // on a lattice of types. This may not be a subtyping relationship // per se, but it makes sense to use that terminology here, so we // think of these fields as the sub-type and super-type, respectively. - TypeExp sub; - TypeExp sup; + FIDDLE() TypeExp sub; + FIDDLE() TypeExp sup; // If this decl is defined in a where clause, store the source location of the where token. SourceLoc whereTokenLoc = SourceLoc(); - bool isEqualityConstraint = false; + FIDDLE() bool isEqualityConstraint = false; // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return sup; } }; +FIDDLE() class TypeCoercionConstraintDecl : public Decl { - SLANG_AST_CLASS(TypeCoercionConstraintDecl) - + FIDDLE(...) SourceLoc whereTokenLoc = SourceLoc(); - TypeExp fromType; - TypeExp toType; + FIDDLE() TypeExp fromType; + FIDDLE() TypeExp toType; }; +FIDDLE() class GenericValueParamDecl : public VarDeclBase { - SLANG_AST_CLASS(GenericValueParamDecl) - + FIDDLE(...) // The index of the generic parameter. int parameterIndex = 0; }; @@ -638,20 +681,21 @@ class GenericValueParamDecl : public VarDeclBase // // layout(local_size_x = 16) in; // +FIDDLE() class EmptyDecl : public Decl { - SLANG_AST_CLASS(EmptyDecl) + FIDDLE(...) }; // A declaration used by the implementation to put syntax keywords // into the current scope. // +FIDDLE() class SyntaxDecl : public Decl { - SLANG_AST_CLASS(SyntaxDecl) - + FIDDLE(...) // What type of syntax node will be produced when parsing with this keyword? - SyntaxClass<NodeBase> syntaxClass; + FIDDLE() SyntaxClass<NodeBase> syntaxClass; SLANG_UNREFLECTED @@ -662,44 +706,48 @@ class SyntaxDecl : public Decl // A declaration of an attribute to be used with `[name(...)]` syntax. // +FIDDLE() class AttributeDecl : public ContainerDecl { - SLANG_AST_CLASS(AttributeDecl) + FIDDLE(...) // What type of syntax node will be produced to represent this attribute. - SyntaxClass<NodeBase> syntaxClass; + FIDDLE() SyntaxClass<NodeBase> syntaxClass; }; // A synthesized decl used as a placeholder for a differentiable function requirement. This decl // will be a child of interface decl. This allows us to form an interface requirement key for the // derivative of an interface function. The synthesized `DerivativeRequirementDecl` will be a child // of the original function requirement decl after an interface type is checked. +FIDDLE() class DerivativeRequirementDecl : public FunctionDeclBase { - SLANG_AST_CLASS(DerivativeRequirementDecl) - + FIDDLE(...) // The original requirement decl. - Decl* originalRequirementDecl = nullptr; + FIDDLE() Decl* originalRequirementDecl = nullptr; // Type to use for 'ThisType' - Type* diffThisType; + FIDDLE() Type* diffThisType; }; // A reference to a synthesized decl representing a differentiable function requirement, this decl // will be a child in the orignal function. +FIDDLE() class DerivativeRequirementReferenceDecl : public FunctionDeclBase { - SLANG_AST_CLASS(DerivativeRequirementReferenceDecl) - DerivativeRequirementDecl* referencedDecl; + FIDDLE(...) + FIDDLE() DerivativeRequirementDecl* referencedDecl; }; +FIDDLE() class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl { - SLANG_AST_CLASS(ForwardDerivativeRequirementDecl) + FIDDLE(...) }; +FIDDLE() class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl { - SLANG_AST_CLASS(BackwardDerivativeRequirementDecl) + FIDDLE(...) }; bool isInterfaceRequirement(Decl* decl); diff --git a/source/slang/slang-ast-dispatch.h b/source/slang/slang-ast-dispatch.h new file mode 100644 index 000000000..58c67a974 --- /dev/null +++ b/source/slang/slang-ast-dispatch.h @@ -0,0 +1,56 @@ +// slang-ast-dispatch.h +#pragma once + +#include "slang-ast-forward-declarations.h" +#include "slang-syntax.h" + +namespace Slang +{ + +template<typename Base, typename Result> +struct ASTNodeDispatcher +{ +}; + +#if 0 // FIDDLE TEMPLATE: +%function generateDispatcher(BASE) +template<typename R> +struct ASTNodeDispatcher<$BASE, R> +{ + template<typename F> + static R dispatch($BASE const* obj, F const& f) + { + switch (obj->getClass().getTag()) + { + default: + SLANG_UNEXPECTED("unhandled subclass in ASTNodeDispatcher::dispatch"); + +% for _,T in ipairs(BASE.subclasses) do +% if not T.isAbstract then + case ASTNodeType::$T: + return f(static_cast<$T*>(const_cast<$BASE*>(obj))); +% end +% end + } + } +}; +%end +%generateDispatcher(Slang.TypeConstraintDecl) +%generateDispatcher(Slang.ArithmeticExpressionType) +%generateDispatcher(Slang.DeclRefBase) +%generateDispatcher(Slang.Val) +%generateDispatcher(Slang.Type) +%generateDispatcher(Slang.SubtypeWitness) +%generateDispatcher(Slang.IntVal) +%generateDispatcher(Slang.Modifier) +%generateDispatcher(Slang.DeclBase) +%generateDispatcher(Slang.Decl) +%generateDispatcher(Slang.Expr) +%generateDispatcher(Slang.Stmt) +%generateDispatcher(Slang.NodeBase) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-dispatch.h.fiddle" +#endif // FIDDLE END + +} // namespace Slang diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index bd366be19..24b10344d 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -2,8 +2,8 @@ #include "slang-ast-dump.h" #include "../core/slang-string.h" +#include "slang-ast-dispatch.h" #include "slang-compiler.h" -#include "slang-generated-ast-macro.h" #include <assert.h> #include <limits> @@ -11,12 +11,10 @@ namespace Slang { - struct ASTDumpContext { struct ObjectInfo { - const ReflectClassInfo* m_typeInfo; NodeBase* m_object; bool m_isDumped; }; @@ -48,10 +46,10 @@ struct ASTDumpContext ASTDumpContext* m_context; }; - void dumpObject(const ReflectClassInfo& type, NodeBase* obj); + void dumpObject(NodeBase* obj); - void dumpObjectFull(const ReflectClassInfo& type, NodeBase* obj, Index objIndex); - void dumpObjectReference(const ReflectClassInfo& type, NodeBase* obj, Index objIndex); + void dumpObjectFull(NodeBase* obj, Index objIndex); + void dumpObjectReference(NodeBase* obj, Index objIndex); void dump(NodeBase* node) { @@ -61,7 +59,7 @@ struct ASTDumpContext } else { - dumpObject(node->getClassInfo(), node); + dumpObject(node); } } @@ -283,7 +281,7 @@ struct ASTDumpContext m_writer->emit(" }"); } - Index getObjectIndex(const ReflectClassInfo& typeInfo, NodeBase* obj) + Index getObjectIndex(NodeBase* obj) { Index* indexPtr = m_objectMap.tryGetValueOrAdd(obj, m_objects.getCount()); if (indexPtr) @@ -294,7 +292,6 @@ struct ASTDumpContext ObjectInfo info; info.m_isDumped = false; info.m_object = obj; - info.m_typeInfo = &typeInfo; m_objects.add(info); return m_objects.getCount() - 1; @@ -366,7 +363,7 @@ struct ASTDumpContext template<typename T> void dump(const SyntaxClass<T>& cls) { - m_writer->emit(cls.classInfo->m_name); + m_writer->emit(cls.getName()); } template<typename KEY, typename VALUE> @@ -568,7 +565,7 @@ struct ASTDumpContext ObjectInfo& info = m_objects[i]; if (!info.m_isDumped) { - dumpObjectFull(*info.m_typeInfo, info.m_object, i); + dumpObjectFull(info.m_object, i); } } } @@ -580,13 +577,12 @@ struct ASTDumpContext // Lets special case handling of module decls -> we only want to output as references // otherwise we end up dumping everything in every module. - const ReflectClassInfo& typeInfo = moduleDecl->getClassInfo(); - Index index = getObjectIndex(typeInfo, moduleDecl); + Index index = getObjectIndex(moduleDecl); // We don't want to fully dump, referenced modules as doing so dumps everything m_objects[index].m_isDumped = true; - dumpObjectReference(typeInfo, moduleDecl, index); + dumpObjectReference(moduleDecl, index); } else { @@ -640,9 +636,9 @@ struct ASTDumpContext void dump(ASTNodeType nodeType) { // Get the class - auto info = ASTClassInfo::getInfo(nodeType); + auto syntaxClass = SyntaxClass<NodeBase>(nodeType); // Write the name - m_writer->emit(info->m_name); + m_writer->emit(syntaxClass.getName()); } void dump(SourceLanguage language) { m_writer->emit((int)language); } @@ -775,35 +771,36 @@ struct ASTDumpContext struct ASTDumpAccess { +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + static void dump_($T * node, ASTDumpContext & context) + { +% if T.directSuperClass then + dump_(static_cast<$(T.directSuperClass)*>(node), context); +% end +% for _,f in ipairs(T.directFields) do + context.dumpField("$f", node->$f); +% end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-dump.cpp.fiddle" +#endif // FIDDLE END -#define SLANG_AST_DUMP_FIELD(FIELD_NAME, TYPE, param) \ - context.dumpField(#FIELD_NAME, static_cast<param*>(base)->FIELD_NAME); - -#define SLANG_AST_DUMP_FIELDS_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - case ASTNodeType::NAME: \ - { \ - SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_DUMP_FIELD, NAME) break; \ - } - - static void dump(ASTNodeType type, NodeBase* base, ASTDumpContext& context) + static void dump(NodeBase* base, ASTDumpContext& context) { - switch (type) - { - SLANG_ALL_ASTNode_NodeBase(SLANG_AST_DUMP_FIELDS_IMPL, _) default : break; - } + ASTNodeDispatcher<NodeBase, void>::dispatch(base, [&](auto b) { dump_(b, context); }); } }; -void ASTDumpContext::dumpObjectReference( - const ReflectClassInfo& type, - NodeBase* obj, - Index objIndex) +void ASTDumpContext::dumpObjectReference(NodeBase* obj, Index objIndex) { SLANG_UNUSED(obj); - ScopeWrite(this).getBuf() << type.m_name << ":" << objIndex; + ScopeWrite(this).getBuf() << obj->getClass().getName() << ":" << objIndex; } -void ASTDumpContext::dumpObjectFull(const ReflectClassInfo& type, NodeBase* obj, Index objIndex) +void ASTDumpContext::dumpObjectFull(NodeBase* obj, Index objIndex) { ObjectInfo& info = m_objects[objIndex]; SLANG_ASSERT(info.m_isDumped == false); @@ -811,42 +808,27 @@ void ASTDumpContext::dumpObjectFull(const ReflectClassInfo& type, NodeBase* obj, // We need to dump the fields. - ScopeWrite(this).getBuf() << type.m_name << ":" << objIndex << " {\n"; + ScopeWrite(this).getBuf() << obj->getClass().getName() << ":" << objIndex << " {\n"; m_writer->indent(); - List<const ReflectClassInfo*> allTypes; - { - const ReflectClassInfo* curType = &type; - do - { - allTypes.add(curType); - curType = curType->m_superClass; - } while (curType); - } - - // Okay we go backwards so we output in the 'normal' order - for (Index i = allTypes.getCount() - 1; i >= 0; --i) - { - const ReflectClassInfo* curType = allTypes[i]; - ASTDumpAccess::dump(ASTNodeType(curType->m_classId), obj, *this); - } + ASTDumpAccess::dump(obj, *this); m_writer->dedent(); m_writer->emit("}\n"); } -void ASTDumpContext::dumpObject(const ReflectClassInfo& typeInfo, NodeBase* obj) +void ASTDumpContext::dumpObject(NodeBase* obj) { - Index index = getObjectIndex(typeInfo, obj); + Index index = getObjectIndex(obj); ObjectInfo& info = m_objects[index]; if (info.m_isDumped || m_dumpStyle == ASTDumpUtil::Style::Flat) { - dumpObjectReference(typeInfo, obj, index); + dumpObjectReference(obj, index); } else { - dumpObjectFull(typeInfo, obj, index); + dumpObjectFull(obj, index); } } @@ -858,9 +840,8 @@ void ASTDumpContext::dumpObjectFull(NodeBase* node) } else { - const ReflectClassInfo& typeInfo = node->getClassInfo(); - Index index = getObjectIndex(typeInfo, node); - dumpObjectFull(typeInfo, node, index); + Index index = getObjectIndex(node); + dumpObjectFull(node, index); } } diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index c9bc86b79..cd5f9b6e8 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -1,9 +1,10 @@ // slang-ast-expr.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-expr.h.fiddle" +FIDDLE() namespace Slang { @@ -12,19 +13,20 @@ using SpvWord = uint32_t; // Syntax class definitions for expressions. // // A placeholder for where an Expr is expected but is missing from source. +FIDDLE() class IncompleteExpr : public Expr { - SLANG_AST_CLASS(IncompleteExpr) + FIDDLE(...) }; + // Base class for expressions that will reference declarations +FIDDLE(abstract) class DeclRefExpr : public Expr { - SLANG_ABSTRACT_AST_CLASS(DeclRefExpr) - - + FIDDLE(...) // The declaration of the symbol being referenced - DeclRef<Decl> declRef; + FIDDLE() DeclRef<Decl> declRef; // The name of the symbol being referenced Name* name = nullptr; @@ -36,22 +38,24 @@ class DeclRefExpr : public Expr Scope* scope = nullptr; }; +FIDDLE() class VarExpr : public DeclRefExpr { - SLANG_AST_CLASS(VarExpr) + FIDDLE(...) }; +FIDDLE() class DefaultConstructExpr : public Expr { - SLANG_AST_CLASS(DefaultConstructExpr) + FIDDLE(...) }; // An expression that references an overloaded set of declarations // having the same name. +FIDDLE() class OverloadedExpr : public Expr { - SLANG_AST_CLASS(OverloadedExpr) - + FIDDLE(...) // The name that was looked up and found to be overloaded Name* name = nullptr; @@ -67,10 +71,10 @@ class OverloadedExpr : public Expr // An expression that references an overloaded set of declarations // having the same name. +FIDDLE() class OverloadedExpr2 : public Expr { - SLANG_AST_CLASS(OverloadedExpr2) - + FIDDLE(...) // Optional: the base expression is this overloaded result // arose from a member-reference expression. Expr* base = nullptr; @@ -79,108 +83,117 @@ class OverloadedExpr2 : public Expr List<Expr*> candidiateExprs; }; +FIDDLE(abstract) class LiteralExpr : public Expr { - SLANG_ABSTRACT_AST_CLASS(LiteralExpr) + FIDDLE(...) // The token that was used to express the literal. This can be // used to get the raw text of the literal, including any suffix. Token token; - BaseType suffixType = BaseType::Void; + FIDDLE() BaseType suffixType = BaseType::Void; }; +FIDDLE() class IntegerLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(IntegerLiteralExpr) - - IntegerLiteralValue value; + FIDDLE(...) + FIDDLE() IntegerLiteralValue value; }; +FIDDLE() class FloatingPointLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(FloatingPointLiteralExpr) - FloatingPointLiteralValue value; + FIDDLE(...) + FIDDLE() FloatingPointLiteralValue value; }; +FIDDLE() class BoolLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(BoolLiteralExpr) - bool value; + FIDDLE(...) + FIDDLE() bool value; }; +FIDDLE() class NullPtrLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(NullPtrLiteralExpr) + FIDDLE(...) }; +FIDDLE() class NoneLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(NoneLiteralExpr) + FIDDLE(...) }; +FIDDLE() class StringLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(StringLiteralExpr) - + FIDDLE(...) // TODO: consider storing the "segments" of the string // literal, in the case where multiple literals were // lined up at the lexer level, e.g.: // // "first" "second" "third" // - String value; + FIDDLE() String value; }; // An initializer list, e.g. `{ 1, 2, 3 }` +FIDDLE() class InitializerListExpr : public Expr { - SLANG_AST_CLASS(InitializerListExpr) - List<Expr*> args; + FIDDLE(...) + FIDDLE() List<Expr*> args; bool useCStyleInitialization = true; }; +FIDDLE() class GetArrayLengthExpr : public Expr { - SLANG_AST_CLASS(GetArrayLengthExpr) - Expr* arrayExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* arrayExpr = nullptr; }; +FIDDLE() class ExpandExpr : public Expr { - SLANG_AST_CLASS(ExpandExpr) - Expr* baseExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* baseExpr = nullptr; }; +FIDDLE() class EachExpr : public Expr { - SLANG_AST_CLASS(EachExpr) - Expr* baseExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* baseExpr = nullptr; }; // A base class for expressions with arguments +FIDDLE(abstract) class ExprWithArgsBase : public Expr { - SLANG_ABSTRACT_AST_CLASS(ExprWithArgsBase) - - List<Expr*> arguments; + FIDDLE(...) + FIDDLE() List<Expr*> arguments; }; // An aggregate type constructor +FIDDLE() class AggTypeCtorExpr : public ExprWithArgsBase { - SLANG_AST_CLASS(AggTypeCtorExpr) - - TypeExp base; + FIDDLE(...) + FIDDLE() TypeExp base; }; // A base expression being applied to arguments: covers // both ordinary `()` function calls and `<>` generic application +FIDDLE(abstract) class AppExprBase : public ExprWithArgsBase { - SLANG_ABSTRACT_AST_CLASS(AppExprBase) - - Expr* functionExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* functionExpr = nullptr; // The original function expr before overload resolution. Expr* originalFunctionExpr = nullptr; @@ -190,14 +203,16 @@ class AppExprBase : public ExprWithArgsBase List<SourceLoc> argumentDelimeterLocs; }; +FIDDLE() class InvokeExpr : public AppExprBase { - SLANG_AST_CLASS(InvokeExpr) + FIDDLE(...) }; +FIDDLE() class ExplicitCtorInvokeExpr : public InvokeExpr { - SLANG_AST_CLASS(ExplicitCtorInvokeExpr) + FIDDLE(...) }; enum class TryClauseType @@ -211,69 +226,80 @@ enum class TryClauseType char const* getTryClauseTypeName(TryClauseType value); +FIDDLE() class TryExpr : public Expr { - SLANG_AST_CLASS(TryExpr) - - Expr* base; + FIDDLE(...) + FIDDLE() Expr* base; - TryClauseType tryClauseType = TryClauseType::Standard; + FIDDLE() TryClauseType tryClauseType = TryClauseType::Standard; // The scope of this expr. Scope* scope = nullptr; }; +FIDDLE() class NewExpr : public InvokeExpr { - SLANG_AST_CLASS(NewExpr) + FIDDLE(...) }; +FIDDLE() class OperatorExpr : public InvokeExpr { - SLANG_AST_CLASS(OperatorExpr) + FIDDLE(...) }; +FIDDLE() class InfixExpr : public OperatorExpr { - SLANG_AST_CLASS(InfixExpr) + FIDDLE(...) }; + +FIDDLE() class PrefixExpr : public OperatorExpr { - SLANG_AST_CLASS(PrefixExpr) + FIDDLE(...) }; + +FIDDLE() class PostfixExpr : public OperatorExpr { - SLANG_AST_CLASS(PostfixExpr) + FIDDLE(...) }; +FIDDLE() class IndexExpr : public Expr { - SLANG_AST_CLASS(IndexExpr) - Expr* baseExpression; - List<Expr*> indexExprs; + FIDDLE(...) + FIDDLE() Expr* baseExpression; + FIDDLE() List<Expr*> indexExprs; // The source location of `(`, `)`, and `,` that marks the start/end of the application op and // each argument expr. This info is used by language server. List<SourceLoc> argumentDelimeterLocs; }; +FIDDLE() class MemberExpr : public DeclRefExpr { - SLANG_AST_CLASS(MemberExpr) - Expr* baseExpression = nullptr; + FIDDLE(...) + FIDDLE() Expr* baseExpression = nullptr; SourceLoc memberOperatorLoc; }; // Member expression that is dereferenced, e.g. `a->b`. +FIDDLE() class DerefMemberExpr : public MemberExpr { - SLANG_AST_CLASS(DerefMemberExpr) + FIDDLE(...) }; // Member looked up on a type, rather than a value +FIDDLE() class StaticMemberExpr : public DeclRefExpr { - SLANG_AST_CLASS(StaticMemberExpr) + FIDDLE(...) Expr* baseExpression = nullptr; SourceLoc memberOperatorLoc; }; @@ -287,69 +313,77 @@ struct MatrixCoord int col; }; +FIDDLE() class MatrixSwizzleExpr : public Expr { - SLANG_AST_CLASS(MatrixSwizzleExpr) - Expr* base = nullptr; - int elementCount; - MatrixCoord elementCoords[4]; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; + FIDDLE() int elementCount; + FIDDLE() MatrixCoord elementCoords[4]; SourceLoc memberOpLoc; }; +FIDDLE() class SwizzleExpr : public Expr { - SLANG_AST_CLASS(SwizzleExpr) - Expr* base = nullptr; - ShortList<uint32_t, 4> elementIndices; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; + FIDDLE() ShortList<uint32_t, 4> elementIndices; SourceLoc memberOpLoc; }; // An operation to convert an l-value to a reference type. +FIDDLE() class MakeRefExpr : public Expr { - SLANG_AST_CLASS(MakeRefExpr) - Expr* base = nullptr; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; }; // A dereference of a pointer or pointer-like type +FIDDLE() class DerefExpr : public Expr { - SLANG_AST_CLASS(DerefExpr) - Expr* base = nullptr; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; }; // Any operation that performs type-casting +FIDDLE() class TypeCastExpr : public InvokeExpr { - SLANG_AST_CLASS(TypeCastExpr) + FIDDLE(...) // TypeExp TargetType; // Expr* Expression = nullptr; }; // An explicit type-cast that appear in the user's code with `(type) expr` syntax +FIDDLE() class ExplicitCastExpr : public TypeCastExpr { - SLANG_AST_CLASS(ExplicitCastExpr) + FIDDLE(...) }; // An implicit type-cast inserted during semantic checking +FIDDLE() class ImplicitCastExpr : public TypeCastExpr { - SLANG_AST_CLASS(ImplicitCastExpr) + FIDDLE(...) }; // A builtin cast expr generated during semantic checking, where there is // no associated conversion function decl. +FIDDLE() class BuiltinCastExpr : public Expr { - SLANG_AST_CLASS(BuiltinCastExpr); - Expr* base = nullptr; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; }; +FIDDLE() class LValueImplicitCastExpr : public TypeCastExpr { - SLANG_AST_CLASS(LValueImplicitCastExpr) - + FIDDLE(...) explicit LValueImplicitCastExpr(const TypeCastExpr& rhs) : Super(rhs) { @@ -359,10 +393,10 @@ class LValueImplicitCastExpr : public TypeCastExpr // To work around situations like int += uint // where we want to allow an LValue to work with an implicit cast. // The argument being cast *must* be an LValue. +FIDDLE() class OutImplicitCastExpr : public LValueImplicitCastExpr { - SLANG_AST_CLASS(OutImplicitCastExpr) - + FIDDLE(...) /// Allow explict construction from any TypeCastExpr explicit OutImplicitCastExpr(const TypeCastExpr& rhs) : Super(rhs) @@ -370,10 +404,10 @@ class OutImplicitCastExpr : public LValueImplicitCastExpr } }; +FIDDLE() class InOutImplicitCastExpr : public LValueImplicitCastExpr { - SLANG_AST_CLASS(InOutImplicitCastExpr) - + FIDDLE(...) /// Allow explict construction from any TypeCastExpr explicit InOutImplicitCastExpr(const TypeCastExpr& rhs) : Super(rhs) @@ -385,249 +419,266 @@ class InOutImplicitCastExpr : public LValueImplicitCastExpr /// /// The type being cast to is stored as this expression's `type`. /// +FIDDLE() class CastToSuperTypeExpr : public Expr { - SLANG_AST_CLASS(CastToSuperTypeExpr) - + FIDDLE(...) /// The value being cast to a super type /// /// The type being cast from is `valueArg->type`. /// - Expr* valueArg = nullptr; + FIDDLE() Expr* valueArg = nullptr; /// A witness showing that `valueArg`'s type is a sub-type of this expression's `type` - Val* witnessArg = nullptr; + FIDDLE() Val* witnessArg = nullptr; }; /// A `value is Type` expression that evaluates to `true` if type of `value` is a sub-type of /// `Type`. +FIDDLE() class IsTypeExpr : public Expr { - SLANG_AST_CLASS(IsTypeExpr) - - Expr* value = nullptr; - TypeExp typeExpr; + FIDDLE(...) + FIDDLE() Expr* value = nullptr; + FIDDLE() TypeExp typeExpr; // A witness showing that `typeExpr.type` is a subtype of `typeof(value)`. - Val* witnessArg = nullptr; + FIDDLE() Val* witnessArg = nullptr; // non-null if evaluates to a constant. - BoolLiteralExpr* constantVal = nullptr; + FIDDLE() BoolLiteralExpr* constantVal = nullptr; }; /// A `value as Type` expression that casts `value` to `Type` within type hierarchy. /// The result is undefined if `value` is not `Type`. +FIDDLE() class AsTypeExpr : public Expr { - SLANG_AST_CLASS(AsTypeExpr) - - Expr* value = nullptr; - Expr* typeExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* value = nullptr; + FIDDLE() Expr* typeExpr = nullptr; // A witness showing that `typeExpr` is a subtype of `typeof(value)`. - Val* witnessArg = nullptr; + FIDDLE() Val* witnessArg = nullptr; }; +FIDDLE(abstract) class SizeOfLikeExpr : public Expr { - SLANG_AST_CLASS(SizeOfLikeExpr); - + FIDDLE(...) // Set during the parse, could be an expression, a variable or a type - Expr* value = nullptr; + FIDDLE() Expr* value = nullptr; // The type the size/alignment needs to operate on. Set during traversal of SemanticsExprVisitor - Type* sizedType = nullptr; + FIDDLE() Type* sizedType = nullptr; }; +FIDDLE() class SizeOfExpr : public SizeOfLikeExpr { - SLANG_AST_CLASS(SizeOfExpr); + FIDDLE(...) }; +FIDDLE() class AlignOfExpr : public SizeOfLikeExpr { - SLANG_AST_CLASS(AlignOfExpr); + FIDDLE(...) }; +FIDDLE() class CountOfExpr : public SizeOfLikeExpr { - SLANG_AST_CLASS(CountOfExpr); + FIDDLE(...) }; +FIDDLE() class MakeOptionalExpr : public Expr { - SLANG_AST_CLASS(MakeOptionalExpr) - + FIDDLE(...) // If `value` is null, this constructs an `Optional<T>` that doesn't have a value. - Expr* value = nullptr; - Expr* typeExpr = nullptr; + FIDDLE() Expr* value = nullptr; + FIDDLE() Expr* typeExpr = nullptr; }; /// A cast of a value to the same type, with different modifiers. /// /// The type being cast to is stored as this expression's `type`. /// +FIDDLE() class ModifierCastExpr : public Expr { - SLANG_AST_CLASS(ModifierCastExpr) - + FIDDLE(...) /// The value being cast. /// /// The type being cast from is `valueArg->type`. /// - Expr* valueArg = nullptr; + FIDDLE() Expr* valueArg = nullptr; }; +FIDDLE() class SelectExpr : public OperatorExpr { - SLANG_AST_CLASS(SelectExpr) + FIDDLE(...) }; +FIDDLE() class LogicOperatorShortCircuitExpr : public OperatorExpr { - SLANG_AST_CLASS(LogicOperatorShortCircuitExpr) + FIDDLE(...) public: enum Flavor { And, // && Or, // || }; - Flavor flavor; + FIDDLE() Flavor flavor; }; +FIDDLE() class GenericAppExpr : public AppExprBase { - SLANG_AST_CLASS(GenericAppExpr) + FIDDLE(...) }; // An expression representing re-use of the syntax for a type in more // than once conceptually-distinct declaration +FIDDLE() class SharedTypeExpr : public Expr { - SLANG_AST_CLASS(SharedTypeExpr) + FIDDLE(...) // The underlying type expression that we want to share TypeExp base; }; +FIDDLE() class AssignExpr : public Expr { - SLANG_AST_CLASS(AssignExpr) - Expr* left = nullptr; - Expr* right = nullptr; + FIDDLE(...) + FIDDLE() Expr* left = nullptr; + FIDDLE() Expr* right = nullptr; }; // Just an expression inside parentheses `(exp)` // // We keep this around explicitly to be sure we don't lose any structure // when we do rewriter stuff. +FIDDLE() class ParenExpr : public Expr { - SLANG_AST_CLASS(ParenExpr) + FIDDLE(...) Expr* base = nullptr; }; // An object-oriented `this` expression, used to // refer to the current instance of an enclosing type. +FIDDLE() class ThisExpr : public Expr { - SLANG_AST_CLASS(ThisExpr) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; // Represent a reference to the virtual __return_val object holding the return value of // functions whose result type is non-copyable. +FIDDLE() class ReturnValExpr : public Expr { - SLANG_AST_CLASS(ReturnValExpr) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; // An expression that binds a temporary variable in a local expression context +FIDDLE() class LetExpr : public Expr { - SLANG_AST_CLASS(LetExpr) - VarDecl* decl = nullptr; - Expr* body = nullptr; + FIDDLE(...) + FIDDLE() VarDecl* decl = nullptr; + FIDDLE() Expr* body = nullptr; }; +FIDDLE() class ExtractExistentialValueExpr : public Expr { - SLANG_AST_CLASS(ExtractExistentialValueExpr) - DeclRef<VarDeclBase> declRef; + FIDDLE(...) + FIDDLE() DeclRef<VarDeclBase> declRef; Expr* originalExpr; }; +FIDDLE() class OpenRefExpr : public Expr { - SLANG_AST_CLASS(OpenRefExpr) - - Expr* innerExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* innerExpr = nullptr; }; +FIDDLE() class DetachExpr : public Expr { - SLANG_AST_CLASS(DetachExpr) - - Expr* inner = nullptr; + FIDDLE(...) + FIDDLE() Expr* inner = nullptr; }; /// Base class for higher-order function application /// Eg: foo(fn) where fn is a function expression. /// +FIDDLE(abstract) class HigherOrderInvokeExpr : public Expr { - SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) - Expr* baseFunction; - List<Name*> newParameterNames; + FIDDLE(...) + FIDDLE() Expr* baseFunction; + FIDDLE() List<Name*> newParameterNames; }; +FIDDLE() class PrimalSubstituteExpr : public HigherOrderInvokeExpr { - SLANG_AST_CLASS(PrimalSubstituteExpr) + FIDDLE(...) }; +FIDDLE(abstract) class DifferentiateExpr : public HigherOrderInvokeExpr { - SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr) + FIDDLE(...) }; /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// +FIDDLE() class ForwardDifferentiateExpr : public DifferentiateExpr { - SLANG_AST_CLASS(ForwardDifferentiateExpr) + FIDDLE(...) }; /// An expression of the form `__bwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// +FIDDLE() class BackwardDifferentiateExpr : public DifferentiateExpr { - SLANG_AST_CLASS(BackwardDifferentiateExpr) + FIDDLE(...) }; /// An expression of the form `__dispatch_kernel(fn, threadGroupSize, dispatchSize)` to /// dispatch a compute kernel from host. /// +FIDDLE() class DispatchKernelExpr : public HigherOrderInvokeExpr { - SLANG_AST_CLASS(DispatchKernelExpr) - Expr* threadGroupSize; - Expr* dispatchSize; + FIDDLE(...) + FIDDLE() Expr* threadGroupSize; + FIDDLE() Expr* dispatchSize; }; /// An express to mark its inner expression as an intended non-differential call. +FIDDLE() class TreatAsDifferentiableExpr : public Expr { - SLANG_AST_CLASS(TreatAsDifferentiableExpr) - - Expr* innerExpr; + FIDDLE(...) + FIDDLE() Expr* innerExpr; Scope* scope; enum Flavor @@ -645,70 +696,70 @@ class TreatAsDifferentiableExpr : public Expr Differentiable }; - Flavor flavor; + FIDDLE() Flavor flavor; }; /// A type expression of the form `This` /// /// Refers to the type of `this` in the current context. /// +FIDDLE() class ThisTypeExpr : public Expr { - SLANG_AST_CLASS(ThisTypeExpr) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; /// A type expression of the form `Left & Right`. +FIDDLE() class AndTypeExpr : public Expr { - SLANG_AST_CLASS(AndTypeExpr); - - TypeExp left; - TypeExp right; + FIDDLE(...) + FIDDLE() TypeExp left; + FIDDLE() TypeExp right; }; /// A type exprssion that applies one or more modifiers to another type +FIDDLE() class ModifiedTypeExpr : public Expr { - SLANG_AST_CLASS(ModifiedTypeExpr); - - Modifiers modifiers; - TypeExp base; + FIDDLE(...) + FIDDLE() Modifiers modifiers; + FIDDLE() TypeExp base; }; /// A type expression that rrepresents a pointer type, e.g. T* +FIDDLE() class PointerTypeExpr : public Expr { - SLANG_AST_CLASS(PointerTypeExpr); - - TypeExp base; + FIDDLE(...) + FIDDLE() TypeExp base; }; /// A type expression that represents a function type, e.g. (bool, int) -> float +FIDDLE() class FuncTypeExpr : public Expr { - SLANG_AST_CLASS(FuncTypeExpr); - - List<TypeExp> parameters; - TypeExp result; + FIDDLE(...) + FIDDLE() List<TypeExp> parameters; + FIDDLE() TypeExp result; }; +FIDDLE() class TupleTypeExpr : public Expr { - SLANG_AST_CLASS(TupleTypeExpr); - - List<TypeExp> members; + FIDDLE(...) + FIDDLE() List<TypeExp> members; }; /// An expression that applies a generic to arguments for some, /// but not all, of its explicit parameters. /// +FIDDLE() class PartiallyAppliedGenericExpr : public Expr { - SLANG_AST_CLASS(PartiallyAppliedGenericExpr); - + FIDDLE(...) public: Expr* originalExpr = nullptr; @@ -723,16 +774,17 @@ public: /// An expression that holds a set of argument exprs that got matched to a pack parameter /// during overload resolution. /// +FIDDLE() class PackExpr : public Expr { - SLANG_AST_CLASS(PackExpr) - - List<Expr*> args; + FIDDLE(...) + FIDDLE() List<Expr*> args; }; -class SPIRVAsmOperand +FIDDLE() +struct SPIRVAsmOperand { - SLANG_VALUE_CLASS(SPIRVAsmOperand); + FIDDLE(...) public: enum Flavor @@ -792,21 +844,21 @@ public: TypeExp type = TypeExp(); }; -class SPIRVAsmInst +FIDDLE() +struct SPIRVAsmInst { - SLANG_VALUE_CLASS(SPIRVAsmInst); - + FIDDLE(...) public: SPIRVAsmOperand opcode; List<SPIRVAsmOperand> operands; }; +FIDDLE() class SPIRVAsmExpr : public Expr { - SLANG_AST_CLASS(SPIRVAsmExpr); - + FIDDLE(...) public: - List<SPIRVAsmInst> insts; + FIDDLE() List<SPIRVAsmInst> insts; }; } // namespace Slang diff --git a/source/slang/slang-ast-forward-declarations.h b/source/slang/slang-ast-forward-declarations.h new file mode 100644 index 000000000..717bca1d9 --- /dev/null +++ b/source/slang/slang-ast-forward-declarations.h @@ -0,0 +1,29 @@ +// slang-ast-forward-declarations.h +#pragma once + +namespace Slang +{ + +enum class ASTNodeType +{ +#if 0 // FIDDLE TEMPLATE: +%for _, T in ipairs(Slang.NodeBase.subclasses) do + $T, +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-forward-declarations.h.fiddle" +#endif // FIDDLE END + CountOf +}; + +#if 0 // FIDDLE TEMPLATE: +%for _, T in ipairs(Slang.NodeBase.subclasses) do + class $T; +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-ast-forward-declarations.h.fiddle" +#endif // FIDDLE END + +} // namespace Slang diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index c7da945f2..2112d452e 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -38,9 +38,9 @@ struct ASTIterator { if (!expr) return; - expr->accept(this, nullptr); + this->dispatch(expr); } - bool visitExpr(Expr*) { return false; } + void visitExpr(Expr*) {} void visitBoolLiteralExpr(BoolLiteralExpr* expr) { iterator->maybeDispatchCallback(expr); } void visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr) { @@ -313,6 +313,8 @@ struct ASTIterator dispatchIfNotNull(o.expr); } } + + void visitDetachExpr(DetachExpr* expr) { iterator->maybeDispatchCallback(expr); } }; struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor> @@ -327,7 +329,7 @@ struct ASTIterator { if (!stmt) return; - stmt->accept(this, nullptr); + this->dispatch(stmt); } void visitDeclStmt(DeclStmt* stmt) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 563084361..e566eca9e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1,9 +1,10 @@ // slang-ast-modifier.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-modifier.h.fiddle" +FIDDLE() namespace Slang { @@ -11,246 +12,299 @@ namespace Slang // Simple modifiers have no state beyond their identity +FIDDLE() class InModifier : public Modifier { - SLANG_AST_CLASS(InModifier) + FIDDLE(...) }; + +FIDDLE() class OutModifier : public Modifier { - SLANG_AST_CLASS(OutModifier) + FIDDLE(...) }; + +FIDDLE() class ConstModifier : public Modifier { - SLANG_AST_CLASS(ConstModifier) + FIDDLE(...) }; + +FIDDLE() class BuiltinModifier : public Modifier { - SLANG_AST_CLASS(BuiltinModifier) + FIDDLE(...) }; + +FIDDLE() class InlineModifier : public Modifier { - SLANG_AST_CLASS(InlineModifier) + FIDDLE(...) }; + +FIDDLE(abstract) class VisibilityModifier : public Modifier { - SLANG_AST_CLASS(VisibilityModifier) + FIDDLE(...) }; + +FIDDLE() class PublicModifier : public VisibilityModifier { - SLANG_AST_CLASS(PublicModifier) + FIDDLE(...) }; + +FIDDLE() class PrivateModifier : public VisibilityModifier { - SLANG_AST_CLASS(PrivateModifier) + FIDDLE(...) }; + +FIDDLE() class InternalModifier : public VisibilityModifier { - SLANG_AST_CLASS(InternalModifier) + FIDDLE(...) }; + +FIDDLE() class RequireModifier : public Modifier { - SLANG_AST_CLASS(RequireModifier) + FIDDLE(...) }; + +FIDDLE() class ParamModifier : public Modifier { - SLANG_AST_CLASS(ParamModifier) + FIDDLE(...) }; + +FIDDLE() class ExternModifier : public Modifier { - SLANG_AST_CLASS(ExternModifier) + FIDDLE(...) }; + +FIDDLE() class HLSLExportModifier : public Modifier { - SLANG_AST_CLASS(HLSLExportModifier) + FIDDLE(...) }; + +FIDDLE() class TransparentModifier : public Modifier { - SLANG_AST_CLASS(TransparentModifier) + FIDDLE(...) }; + +FIDDLE() class FromCoreModuleModifier : public Modifier { - SLANG_AST_CLASS(FromCoreModuleModifier) + FIDDLE(...) }; + +FIDDLE() class PrefixModifier : public Modifier { - SLANG_AST_CLASS(PrefixModifier) + FIDDLE(...) }; + +FIDDLE() class PostfixModifier : public Modifier { - SLANG_AST_CLASS(PostfixModifier) + FIDDLE(...) }; + +FIDDLE() class ExportedModifier : public Modifier { - SLANG_AST_CLASS(ExportedModifier) + FIDDLE(...) }; + +FIDDLE() class ConstExprModifier : public Modifier { - SLANG_AST_CLASS(ConstExprModifier) + FIDDLE(...) }; + +FIDDLE() class ExternCppModifier : public Modifier { - SLANG_AST_CLASS(ExternCppModifier) + FIDDLE(...) }; + +FIDDLE() class GLSLPrecisionModifier : public Modifier { - SLANG_AST_CLASS(GLSLPrecisionModifier) + FIDDLE(...) }; + +FIDDLE() class GLSLModuleModifier : public Modifier { - SLANG_AST_CLASS(GLSLModuleModifier) + FIDDLE(...) }; + // Marks that the definition of a decl is not yet synthesized. +FIDDLE() class ToBeSynthesizedModifier : public Modifier { - SLANG_AST_CLASS(ToBeSynthesizedModifier) + FIDDLE(...) }; // Marks that the definition of a decl is synthesized. +FIDDLE() class SynthesizedModifier : public Modifier { - SLANG_AST_CLASS(SynthesizedModifier) + FIDDLE(...) }; // Marks a synthesized variable as local temporary variable. +FIDDLE() class LocalTempVarModifier : public Modifier { - SLANG_AST_CLASS(LocalTempVarModifier) + FIDDLE(...) }; // An `extern` variable in an extension is used to introduce additional attributes on an existing // field. +FIDDLE() class ExtensionExternVarModifier : public Modifier { - SLANG_AST_CLASS(ExtensionExternVarModifier) - DeclRef<Decl> originalDecl; + FIDDLE(...) + FIDDLE() DeclRef<Decl> originalDecl; }; // An 'ActualGlobal' is a global that is output as a normal global in CPU code. // Globals in HLSL/Slang are constant state passed into kernel execution +FIDDLE() class ActualGlobalModifier : public Modifier { - SLANG_AST_CLASS(ActualGlobalModifier) + FIDDLE(...) }; /// A modifier that indicates an `InheritanceDecl` should be ignored during name lookup (and related /// checks). +FIDDLE() class IgnoreForLookupModifier : public Modifier { - SLANG_AST_CLASS(IgnoreForLookupModifier) + FIDDLE(...) }; // A modifier that marks something as an operation that // has a one-to-one translation to the IR, and thus // has no direct definition in the high-level language. // +FIDDLE() class IntrinsicOpModifier : public Modifier { - SLANG_AST_CLASS(IntrinsicOpModifier) - + FIDDLE(...) // Token that names the intrinsic op. Token opToken; // The IR opcode for the intrinsic operation. // - uint32_t op = 0; + FIDDLE() uint32_t op = 0; }; // A modifier that marks something as an intrinsic function, // for some subset of targets. +FIDDLE() class TargetIntrinsicModifier : public Modifier { - SLANG_AST_CLASS(TargetIntrinsicModifier) - + FIDDLE(...) // Token that names the target that the operation // is an intrisic for. - Token targetToken; + FIDDLE() Token targetToken; // A custom definition for the operation, one of either an ident or a // string (the concatenation of several string literals) Token definitionIdent; - String definitionString; + FIDDLE() String definitionString; bool isString; // A predicate to be used on an identifier to guard this intrinsic Token predicateToken; NameLoc scrutinee; - DeclRef<Decl> scrutineeDeclRef; + FIDDLE() DeclRef<Decl> scrutineeDeclRef; }; // A modifier that marks a declaration as representing a // specialization that should be preferred on a particular // target. +FIDDLE() class SpecializedForTargetModifier : public Modifier { - SLANG_AST_CLASS(SpecializedForTargetModifier) - + FIDDLE(...) // Token that names the target that the operation // has been specialized for. - Token targetToken; + FIDDLE() Token targetToken; }; // A modifier to tag something as an intrinsic that requires // a certain GLSL extension to be enabled when used +FIDDLE() class RequiredGLSLExtensionModifier : public Modifier { - SLANG_AST_CLASS(RequiredGLSLExtensionModifier) - - Token extensionNameToken; + FIDDLE(...) + FIDDLE() Token extensionNameToken; }; // A modifier to tag something as an intrinsic that requires // a certain GLSL version to be enabled when used +FIDDLE() class RequiredGLSLVersionModifier : public Modifier { - SLANG_AST_CLASS(RequiredGLSLVersionModifier) - - Token versionNumberToken; + FIDDLE(...) + FIDDLE() Token versionNumberToken; }; // A modifier to tag something as an intrinsic that requires // a certain SPIRV version to be enabled when used. Specified as "major.minor" +FIDDLE() class RequiredSPIRVVersionModifier : public Modifier { - SLANG_AST_CLASS(RequiredSPIRVVersionModifier) - - SemanticVersion version; + FIDDLE(...) + FIDDLE() SemanticVersion version; }; // A modifier to tag something as an intrinsic that requires // a certain WGSL extension to be enabled when used +FIDDLE() class RequiredWGSLExtensionModifier : public Modifier { - SLANG_AST_CLASS(RequiredWGSLExtensionModifier) - - Token extensionNameToken; + FIDDLE(...) + FIDDLE() Token extensionNameToken; }; // A modifier to tag something as an intrinsic that requires // a certain CUDA SM version to be enabled when used. Specified as "major.minor" +FIDDLE() class RequiredCUDASMVersionModifier : public Modifier { - SLANG_AST_CLASS(RequiredCUDASMVersionModifier) - - SemanticVersion version; + FIDDLE(...) + FIDDLE() SemanticVersion version; }; +FIDDLE() class InOutModifier : public OutModifier { - SLANG_AST_CLASS(InOutModifier) + FIDDLE(...) }; // `__ref` modifier for by-reference parameter passing +FIDDLE() class RefModifier : public Modifier { - SLANG_AST_CLASS(RefModifier) + FIDDLE(...) }; // `__ref` modifier for by-reference parameter passing +FIDDLE() class ConstRefModifier : public Modifier { - SLANG_AST_CLASS(ConstRefModifier) + FIDDLE(...) }; // This is a special sentinel modifier that gets added @@ -267,132 +321,147 @@ class ConstRefModifier : public Modifier // / // b: RegisterModifier("x0") / // +FIDDLE() class SharedModifiers : public Modifier { - SLANG_AST_CLASS(SharedModifiers) + FIDDLE(...) }; // AST nodes to represent the begin/end of a `layout` modifier group +FIDDLE(abstract) class GLSLLayoutModifierGroupMarker : public Modifier { - SLANG_ABSTRACT_AST_CLASS(GLSLLayoutModifierGroupMarker) + FIDDLE(...) }; +FIDDLE() class GLSLLayoutModifierGroupBegin : public GLSLLayoutModifierGroupMarker { - SLANG_AST_CLASS(GLSLLayoutModifierGroupBegin) + FIDDLE(...) }; +FIDDLE() class GLSLLayoutModifierGroupEnd : public GLSLLayoutModifierGroupMarker { - SLANG_AST_CLASS(GLSLLayoutModifierGroupEnd) + FIDDLE(...) }; +FIDDLE() class GLSLUnparsedLayoutModifier : public Modifier { - SLANG_AST_CLASS(GLSLUnparsedLayoutModifier) + FIDDLE(...) }; +FIDDLE() class GLSLBufferDataLayoutModifier : public Modifier { - SLANG_AST_CLASS(GLSLBufferDataLayoutModifier) + FIDDLE(...) }; +FIDDLE() class GLSLStd140Modifier : public GLSLBufferDataLayoutModifier { - SLANG_AST_CLASS(GLSLStd140Modifier) + FIDDLE(...) }; +FIDDLE() class GLSLStd430Modifier : public GLSLBufferDataLayoutModifier { - SLANG_AST_CLASS(GLSLStd430Modifier) + FIDDLE(...) }; +FIDDLE() class GLSLScalarModifier : public GLSLBufferDataLayoutModifier { - SLANG_AST_CLASS(GLSLScalarModifier) + FIDDLE(...) }; // A catch-all for single-keyword modifiers +FIDDLE() class SimpleModifier : public Modifier { - SLANG_AST_CLASS(SimpleModifier) + FIDDLE(...) }; // Indicates that this is a variable declaration that corresponds to // a parameter block declaration in the source program. +FIDDLE() class ImplicitParameterGroupVariableModifier : public Modifier { - SLANG_AST_CLASS(ImplicitParameterGroupVariableModifier) + FIDDLE(...) }; // Indicates that this is a type that corresponds to the element // type of a parameter block declaration in the source program. +FIDDLE() class ImplicitParameterGroupElementTypeModifier : public Modifier { - SLANG_AST_CLASS(ImplicitParameterGroupElementTypeModifier) + FIDDLE(...) }; // An HLSL semantic +FIDDLE(abstract) class HLSLSemantic : public Modifier { - SLANG_ABSTRACT_AST_CLASS(HLSLSemantic) - - Token name; + FIDDLE(...) + FIDDLE() Token name; }; // An HLSL semantic that affects layout +FIDDLE() class HLSLLayoutSemantic : public HLSLSemantic { - SLANG_AST_CLASS(HLSLLayoutSemantic) - - Token registerName; - Token componentMask; + FIDDLE(...) + FIDDLE() Token registerName; + FIDDLE() Token componentMask; }; // An HLSL `register` semantic +FIDDLE() class HLSLRegisterSemantic : public HLSLLayoutSemantic { - SLANG_AST_CLASS(HLSLRegisterSemantic) - - Token spaceName; + FIDDLE(...) + FIDDLE() Token spaceName; }; // TODO(tfoley): `packoffset` +FIDDLE() class HLSLPackOffsetSemantic : public HLSLLayoutSemantic { - SLANG_AST_CLASS(HLSLPackOffsetSemantic) - - int uniformOffset = 0; + FIDDLE(...) + FIDDLE() int uniformOffset = 0; }; // An HLSL semantic that just associated a declaration with a semantic name +FIDDLE() class HLSLSimpleSemantic : public HLSLSemantic { - SLANG_AST_CLASS(HLSLSimpleSemantic) + FIDDLE(...) }; // A semantic applied to a field of a ray-payload type, to control access +FIDDLE() class RayPayloadAccessSemantic : public HLSLSemantic { - SLANG_AST_CLASS(RayPayloadAccessSemantic) - - List<Token> stageNameTokens; + FIDDLE(...) + FIDDLE() List<Token> stageNameTokens; }; +FIDDLE() class RayPayloadReadSemantic : public RayPayloadAccessSemantic { - SLANG_AST_CLASS(RayPayloadReadSemantic) + FIDDLE(...) }; +FIDDLE() class RayPayloadWriteSemantic : public RayPayloadAccessSemantic { - SLANG_AST_CLASS(RayPayloadWriteSemantic) + FIDDLE(...) }; @@ -400,73 +469,72 @@ class RayPayloadWriteSemantic : public RayPayloadAccessSemantic // Directives that came in via the preprocessor, but // that we need to keep around for later steps +FIDDLE() class GLSLPreprocessorDirective : public Modifier { - SLANG_AST_CLASS(GLSLPreprocessorDirective) + FIDDLE(...) }; // A GLSL `#version` directive +FIDDLE() class GLSLVersionDirective : public GLSLPreprocessorDirective { - SLANG_AST_CLASS(GLSLVersionDirective) - - + FIDDLE(...) // Token giving the version number to use - Token versionNumberToken; + FIDDLE() Token versionNumberToken; // Optional token giving the sub-profile to be used - Token glslProfileToken; + FIDDLE() Token glslProfileToken; }; // A GLSL `#extension` directive +FIDDLE() class GLSLExtensionDirective : public GLSLPreprocessorDirective { - SLANG_AST_CLASS(GLSLExtensionDirective) - - + FIDDLE(...) // Token giving the version number to use - Token extensionNameToken; + FIDDLE() Token extensionNameToken; // Optional token giving the sub-profile to be used - Token dispositionToken; + FIDDLE() Token dispositionToken; }; +FIDDLE() class ParameterGroupReflectionName : public Modifier { - SLANG_AST_CLASS(ParameterGroupReflectionName) - - NameLoc nameAndLoc; + FIDDLE(...) + FIDDLE() NameLoc nameAndLoc; }; // A modifier that indicates a built-in base type (e.g., `float`) +FIDDLE() class BuiltinTypeModifier : public Modifier { - SLANG_AST_CLASS(BuiltinTypeModifier) - - BaseType tag; + FIDDLE(...) + FIDDLE() BaseType tag; }; // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) // // TODO(tfoley): This deserves a better name than "magic" +FIDDLE() class MagicTypeModifier : public Modifier { - SLANG_AST_CLASS(MagicTypeModifier) - - ASTNodeType magicNodeType = ASTNodeType(-1); + FIDDLE(...) + FIDDLE() SyntaxClass<NodeBase> magicNodeType; /// Modifier has a name so call this magicModifier to disambiguate - String magicName; - uint32_t tag = uint32_t(0); + FIDDLE() String magicName; + FIDDLE() uint32_t tag = uint32_t(0); }; // A modifier that indicates a built-in associated type requirement (e.g., `Differential`) +FIDDLE() class BuiltinRequirementModifier : public Modifier { - SLANG_AST_CLASS(BuiltinRequirementModifier); - - BuiltinRequirementKind kind; + FIDDLE(...) + FIDDLE() BuiltinRequirementKind kind; }; @@ -475,47 +543,52 @@ class BuiltinRequirementModifier : public Modifier // // TODO: This should really subsume `BuiltinTypeModifier` and // `MagicTypeModifier` so that we don't have to apply all of them. +FIDDLE() class IntrinsicTypeModifier : public Modifier { - SLANG_AST_CLASS(IntrinsicTypeModifier) - + FIDDLE(...) // The IR opcode to use when constructing a type - uint32_t irOp; + FIDDLE() uint32_t irOp; Token opToken; // Additional literal opreands to provide when creating instances. // (e.g., for a texture type this passes in shape/mutability info) - List<uint32_t> irOperands; + FIDDLE() List<uint32_t> irOperands; }; // Modifiers that affect the storage layout for matrices +FIDDLE(abstract) class MatrixLayoutModifier : public Modifier { - SLANG_AST_CLASS(MatrixLayoutModifier) + FIDDLE(...) }; // Modifiers that specify row- and column-major layout, respectively +FIDDLE(abstract) class RowMajorLayoutModifier : public MatrixLayoutModifier { - SLANG_AST_CLASS(RowMajorLayoutModifier) + FIDDLE(...) }; +FIDDLE(abstract) class ColumnMajorLayoutModifier : public MatrixLayoutModifier { - SLANG_AST_CLASS(ColumnMajorLayoutModifier) + FIDDLE(...) }; // The HLSL flavor of those modifiers +FIDDLE() class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier { - SLANG_AST_CLASS(HLSLRowMajorLayoutModifier) + FIDDLE(...) }; +FIDDLE() class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier { - SLANG_AST_CLASS(HLSLColumnMajorLayoutModifier) + FIDDLE(...) }; @@ -526,676 +599,736 @@ class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier // we actually interpret that as requesting column-major. This makes // sense because we interpret matrix conventions backwards from how // GLSL specifies them. +FIDDLE() class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier { - SLANG_AST_CLASS(GLSLRowMajorLayoutModifier) + FIDDLE(...) }; +FIDDLE() class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier { - SLANG_AST_CLASS(GLSLColumnMajorLayoutModifier) + FIDDLE(...) }; // More HLSL Keyword +FIDDLE(abstract) class InterpolationModeModifier : public Modifier { - SLANG_ABSTRACT_AST_CLASS(InterpolationModeModifier) + FIDDLE(...) }; // HLSL `nointerpolation` modifier +FIDDLE() class HLSLNoInterpolationModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLNoInterpolationModifier) + FIDDLE(...) }; // HLSL `noperspective` modifier +FIDDLE() class HLSLNoPerspectiveModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLNoPerspectiveModifier) + FIDDLE(...) }; // HLSL `linear` modifier +FIDDLE() class HLSLLinearModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLLinearModifier) + FIDDLE(...) }; // HLSL `sample` modifier +FIDDLE() class HLSLSampleModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLSampleModifier) + FIDDLE(...) }; // HLSL `centroid` modifier +FIDDLE() class HLSLCentroidModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLCentroidModifier) + FIDDLE(...) }; /// Slang-defined `pervertex` modifier +FIDDLE() class PerVertexModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(PerVertexModifier) + FIDDLE(...) }; // HLSL `precise` modifier +FIDDLE() class PreciseModifier : public Modifier { - SLANG_AST_CLASS(PreciseModifier) + FIDDLE(...) }; // HLSL `shared` modifier (which is used by the effect system, // and shouldn't be confused with `groupshared`) +FIDDLE() class HLSLEffectSharedModifier : public Modifier { - SLANG_AST_CLASS(HLSLEffectSharedModifier) + FIDDLE(...) }; // HLSL `groupshared` modifier +FIDDLE() class HLSLGroupSharedModifier : public Modifier { - SLANG_AST_CLASS(HLSLGroupSharedModifier) + FIDDLE(...) }; // HLSL `static` modifier (probably doesn't need to be // treated as HLSL-specific) +FIDDLE() class HLSLStaticModifier : public Modifier { - SLANG_AST_CLASS(HLSLStaticModifier) + FIDDLE(...) }; // HLSL `uniform` modifier (distinct meaning from GLSL // use of the keyword) +FIDDLE() class HLSLUniformModifier : public Modifier { - SLANG_AST_CLASS(HLSLUniformModifier) + FIDDLE(...) }; // HLSL `volatile` modifier (ignored) +FIDDLE() class HLSLVolatileModifier : public Modifier { - SLANG_AST_CLASS(HLSLVolatileModifier) + FIDDLE(...) }; +FIDDLE() class AttributeTargetModifier : public Modifier { - SLANG_AST_CLASS(AttributeTargetModifier) - + FIDDLE(...) // A class to which the declared attribute type is applicable - SyntaxClass<NodeBase> syntaxClass; + FIDDLE() SyntaxClass<NodeBase> syntaxClass; }; // Base class for checked and unchecked `[name(arg0, ...)]` style attribute. +FIDDLE(abstract) class AttributeBase : public Modifier { - SLANG_AST_CLASS(AttributeBase) - - AttributeDecl* attributeDecl = nullptr; + FIDDLE(...) + FIDDLE() AttributeDecl* attributeDecl = nullptr; // The original identifier token representing the last part of the qualified name. Token originalIdentifierToken; - List<Expr*> args; + FIDDLE() List<Expr*> args; }; // A `[name(...)]` attribute that hasn't undergone any semantic analysis. // After analysis, this will be transformed into a more specific case. +FIDDLE() class UncheckedAttribute : public AttributeBase { - SLANG_AST_CLASS(UncheckedAttribute) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; // A GLSL layout qualifier whose value has not yet been resolved or validated. +FIDDLE() class UncheckedGLSLLayoutAttribute : public AttributeBase { - SLANG_AST_CLASS(UncheckedGLSLLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // GLSL `binding` layout qualifier, does not include `set`. +FIDDLE() class UncheckedGLSLBindingLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLBindingLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // GLSL `set` layout qualifier, does not include `binding`. +FIDDLE() class UncheckedGLSLSetLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLSetLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // GLSL `offset` layout qualifier. +FIDDLE() class UncheckedGLSLOffsetLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLOffsetLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLInputAttachmentIndexLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLInputAttachmentIndexLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLLocationLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLLocationLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLIndexLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLIndexLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLConstantIdAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLConstantIdAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLRayPayloadAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLRayPayloadAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLRayPayloadInAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLRayPayloadInAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; - +FIDDLE() class UncheckedGLSLHitObjectAttributesAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLHitObjectAttributesAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLCallablePayloadAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLCallablePayloadAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLCallablePayloadInAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLCallablePayloadInAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // A `[name(arg0, ...)]` style attribute that has been validated. +FIDDLE() class Attribute : public AttributeBase { - SLANG_AST_CLASS(Attribute) - - List<Val*> intArgVals; + FIDDLE(...) + FIDDLE() List<Val*> intArgVals; }; +FIDDLE() class UserDefinedAttribute : public Attribute { - SLANG_AST_CLASS(UserDefinedAttribute) + FIDDLE(...) }; +FIDDLE() class AttributeUsageAttribute : public Attribute { - SLANG_AST_CLASS(AttributeUsageAttribute) - - SyntaxClass<NodeBase> targetSyntaxClass; + FIDDLE(...) + FIDDLE() SyntaxClass<NodeBase> targetSyntaxClass; }; +FIDDLE() class NonDynamicUniformAttribute : public Attribute { - SLANG_AST_CLASS(NonDynamicUniformAttribute) + FIDDLE(...) }; +FIDDLE() class RequireCapabilityAttribute : public Attribute { - SLANG_AST_CLASS(RequireCapabilityAttribute) - CapabilitySet capabilitySet; + FIDDLE(...) + FIDDLE() CapabilitySet capabilitySet; }; // An `[unroll]` or `[unroll(count)]` attribute +FIDDLE() class UnrollAttribute : public Attribute { - SLANG_AST_CLASS(UnrollAttribute) + FIDDLE(...) }; // An `[unroll]` or `[unroll(count)]` attribute +FIDDLE() class ForceUnrollAttribute : public Attribute { - SLANG_AST_CLASS(ForceUnrollAttribute) - - int32_t maxIterations = 0; + FIDDLE(...) + FIDDLE() int32_t maxIterations = 0; }; // An `[maxiters(count)]` +FIDDLE() class MaxItersAttribute : public Attribute { - SLANG_AST_CLASS(MaxItersAttribute) - - IntVal* value = 0; + FIDDLE(...) + FIDDLE() IntVal* value = 0; }; // An inferred max iteration count on a loop. +FIDDLE() class InferredMaxItersAttribute : public Attribute { - SLANG_AST_CLASS(InferredMaxItersAttribute) - DeclRef<Decl> inductionVar; - int32_t value = 0; + FIDDLE(...) + FIDDLE() DeclRef<Decl> inductionVar; + FIDDLE() int32_t value = 0; }; +FIDDLE() class LoopAttribute : public Attribute { - SLANG_AST_CLASS(LoopAttribute) + FIDDLE(...) }; // `[loop]` + +FIDDLE() class FastOptAttribute : public Attribute { - SLANG_AST_CLASS(FastOptAttribute) + FIDDLE(...) }; // `[fastopt]` + +FIDDLE() class AllowUAVConditionAttribute : public Attribute { - SLANG_AST_CLASS(AllowUAVConditionAttribute) + FIDDLE(...) }; // `[allow_uav_condition]` + +FIDDLE() class BranchAttribute : public Attribute { - SLANG_AST_CLASS(BranchAttribute) + FIDDLE(...) }; // `[branch]` + +FIDDLE() class FlattenAttribute : public Attribute { - SLANG_AST_CLASS(FlattenAttribute) + FIDDLE(...) }; // `[flatten]` + +FIDDLE() class ForceCaseAttribute : public Attribute { - SLANG_AST_CLASS(ForceCaseAttribute) + FIDDLE(...) }; // `[forcecase]` + +FIDDLE() class CallAttribute : public Attribute { - SLANG_AST_CLASS(CallAttribute) + FIDDLE(...) }; // `[call]` +FIDDLE() class UnscopedEnumAttribute : public Attribute { - SLANG_AST_CLASS(UnscopedEnumAttribute) + FIDDLE(...) }; // Marks a enum to have `flags` semantics, where each enum case is a bitfield. +FIDDLE() class FlagsAttribute : public Attribute { - SLANG_AST_CLASS(FlagsAttribute); + FIDDLE(...) }; // [[vk_push_constant]] [[push_constant]] +FIDDLE() class PushConstantAttribute : public Attribute { - SLANG_AST_CLASS(PushConstantAttribute) + FIDDLE(...) }; // [[vk_specialization_constant]] [[specialization_constant]] +FIDDLE() class SpecializationConstantAttribute : public Attribute { - SLANG_AST_CLASS(SpecializationConstantAttribute) + FIDDLE(...) }; // [[vk_constant_id]] +FIDDLE() class VkConstantIdAttribute : public Attribute { - SLANG_AST_CLASS(VkConstantIdAttribute) - int location; + FIDDLE(...) + FIDDLE() int location; }; // [[vk_shader_record]] [[shader_record]] +FIDDLE() class ShaderRecordAttribute : public Attribute { - SLANG_AST_CLASS(ShaderRecordAttribute) + FIDDLE(...) }; // [[vk_binding]] +FIDDLE() class GLSLBindingAttribute : public Attribute { - SLANG_AST_CLASS(GLSLBindingAttribute) - - int32_t binding = 0; - int32_t set = 0; + FIDDLE(...) + FIDDLE() int32_t binding = 0; + FIDDLE() int32_t set = 0; }; +FIDDLE() class VkAliasedPointerAttribute : public Attribute { - SLANG_AST_CLASS(VkAliasedPointerAttribute) + FIDDLE(...) }; +FIDDLE() class VkRestrictPointerAttribute : public Attribute { - SLANG_AST_CLASS(VkRestrictPointerAttribute) + FIDDLE(...) }; +FIDDLE() class GLSLOffsetLayoutAttribute : public Attribute { - SLANG_AST_CLASS(GLSLOffsetLayoutAttribute) - - int64_t offset; + FIDDLE(...) + FIDDLE() int64_t offset; }; // Implicitly added offset qualifier when no offset is specified. +FIDDLE() class GLSLImplicitOffsetLayoutAttribute : public AttributeBase { - SLANG_AST_CLASS(GLSLImplicitOffsetLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class GLSLSimpleIntegerLayoutAttribute : public Attribute { - SLANG_AST_CLASS(GLSLSimpleIntegerLayoutAttribute) - - int32_t value = 0; + FIDDLE(...) + FIDDLE() int32_t value = 0; }; /// [[vk_input_attachment_index]] +FIDDLE() class GLSLInputAttachmentIndexLayoutAttribute : public Attribute { - SLANG_AST_CLASS(GLSLInputAttachmentIndexLayoutAttribute) - - IntegerLiteralValue location; + FIDDLE(...) + FIDDLE() IntegerLiteralValue location; }; // [[vk_location]] +FIDDLE() class GLSLLocationAttribute : public GLSLSimpleIntegerLayoutAttribute { - SLANG_AST_CLASS(GLSLLocationAttribute) + FIDDLE(...) }; // [[vk_index]] +FIDDLE() class GLSLIndexAttribute : public GLSLSimpleIntegerLayoutAttribute { - SLANG_AST_CLASS(GLSLIndexAttribute) + FIDDLE(...) }; // [[vk_offset]] +FIDDLE() class VkStructOffsetAttribute : public GLSLSimpleIntegerLayoutAttribute { - SLANG_AST_CLASS(VkStructOffsetAttribute) + FIDDLE(...) }; // [[vk_spirv_instruction]] +FIDDLE() class SPIRVInstructionOpAttribute : public Attribute { - SLANG_AST_CLASS(SPIRVInstructionOpAttribute) + FIDDLE(...) }; // [[spv_target_env_1_3]] +FIDDLE() class SPIRVTargetEnv13Attribute : public Attribute { - SLANG_AST_CLASS(SPIRVTargetEnv13Attribute); + FIDDLE(...) }; // [[disable_array_flattening]] +FIDDLE() class DisableArrayFlatteningAttribute : public Attribute { - SLANG_AST_CLASS(DisableArrayFlatteningAttribute); + FIDDLE(...) }; // A GLSL layout(local_size_x = 64, ... attribute) +FIDDLE() class GLSLLayoutLocalSizeAttribute : public Attribute { - SLANG_AST_CLASS(GLSLLayoutLocalSizeAttribute) - + FIDDLE(...) // The number of threads to use along each axis // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; + FIDDLE() IntVal* extents[3]; - bool axisIsSpecConstId[3]; + FIDDLE() bool axisIsSpecConstId[3]; // References to specialization constants, for defining the number of // threads with them. If set, the corresponding axis is set to nullptr // above. - DeclRef<VarDeclBase> specConstExtents[3]; + FIDDLE() DeclRef<VarDeclBase> specConstExtents[3]; }; +FIDDLE() class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute { - SLANG_AST_CLASS(GLSLLayoutDerivativeGroupQuadAttribute) + FIDDLE(...) }; +FIDDLE() class GLSLLayoutDerivativeGroupLinearAttribute : public Attribute { - SLANG_AST_CLASS(GLSLLayoutDerivativeGroupLinearAttribute) + FIDDLE(...) }; // TODO: for attributes that take arguments, the syntax node // classes should provide accessors for the values of those arguments. +FIDDLE() class MaxTessFactorAttribute : public Attribute { - SLANG_AST_CLASS(MaxTessFactorAttribute) + FIDDLE(...) }; +FIDDLE() class OutputControlPointsAttribute : public Attribute { - SLANG_AST_CLASS(OutputControlPointsAttribute) + FIDDLE(...) }; +FIDDLE() class OutputTopologyAttribute : public Attribute { - SLANG_AST_CLASS(OutputTopologyAttribute) + FIDDLE(...) }; +FIDDLE() class PartitioningAttribute : public Attribute { - SLANG_AST_CLASS(PartitioningAttribute) + FIDDLE(...) }; +FIDDLE() class PatchConstantFuncAttribute : public Attribute { - SLANG_AST_CLASS(PatchConstantFuncAttribute) - - FuncDecl* patchConstantFuncDecl = nullptr; + FIDDLE(...) + FIDDLE() FuncDecl* patchConstantFuncDecl = nullptr; }; + +FIDDLE() class DomainAttribute : public Attribute { - SLANG_AST_CLASS(DomainAttribute) + FIDDLE(...) }; +FIDDLE() class EarlyDepthStencilAttribute : public Attribute { - SLANG_AST_CLASS(EarlyDepthStencilAttribute) + FIDDLE(...) }; // `[earlydepthstencil]` // An HLSL `[numthreads(x,y,z)]` attribute +FIDDLE() class NumThreadsAttribute : public Attribute { - SLANG_AST_CLASS(NumThreadsAttribute) - + FIDDLE(...) // The number of threads to use along each axis // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; + FIDDLE() IntVal* extents[3]; // References to specialization constants, for defining the number of // threads with them. If set, the corresponding axis is set to nullptr // above. - DeclRef<VarDeclBase> specConstExtents[3]; + FIDDLE() DeclRef<VarDeclBase> specConstExtents[3]; }; +FIDDLE() class WaveSizeAttribute : public Attribute { - SLANG_AST_CLASS(WaveSizeAttribute) - + FIDDLE(...) // "numLanes" must be a compile time constant integer // value of an allowed wave size, which is one of the // followings: 4, 8, 16, 32, 64 or 128. // - IntVal* numLanes; + FIDDLE() IntVal* numLanes; }; +FIDDLE() class MaxVertexCountAttribute : public Attribute { - SLANG_AST_CLASS(MaxVertexCountAttribute) - + FIDDLE(...) // The number of max vertex count for geometry shader // // TODO: This should be an accessor that uses the // ordinary `args` list, rather than side data. - int32_t value; + FIDDLE() int32_t value; }; +FIDDLE() class InstanceAttribute : public Attribute { - SLANG_AST_CLASS(InstanceAttribute) - + FIDDLE(...) // The number of instances to run for geometry shader // // TODO: This should be an accessor that uses the // ordinary `args` list, rather than side data. - int32_t value; + FIDDLE() int32_t value; }; // A `[shader("stageName")]`/`[shader("capability")]` attribute which // marks an entry point for compiling. This attribute also specifies // the 'capabilities' implicitly supported by an entry point +FIDDLE() class EntryPointAttribute : public Attribute { - SLANG_AST_CLASS(EntryPointAttribute) - + FIDDLE(...) // The resolved capailities for our entry point. - CapabilitySet capabilitySet; + FIDDLE() CapabilitySet capabilitySet; }; // A `[__vulkanRayPayload(location)]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the input/output interface for a Vulkan // ray tracing shader to pass per-ray payload information. +FIDDLE() class VulkanRayPayloadAttribute : public Attribute { - SLANG_AST_CLASS(VulkanRayPayloadAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; +FIDDLE() class VulkanRayPayloadInAttribute : public Attribute { - SLANG_AST_CLASS(VulkanRayPayloadInAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; // A `[__vulkanCallablePayload(location)]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the input/output interface for a Vulkan // ray tracing shader to pass payload information to/from a callee. +FIDDLE() class VulkanCallablePayloadAttribute : public Attribute { - SLANG_AST_CLASS(VulkanCallablePayloadAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; +FIDDLE() class VulkanCallablePayloadInAttribute : public Attribute { - SLANG_AST_CLASS(VulkanCallablePayloadInAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; // A `[__vulkanHitAttributes]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the output interface for a Vulkan // intersection shader to pass hit attribute information. +FIDDLE() class VulkanHitAttributesAttribute : public Attribute { - SLANG_AST_CLASS(VulkanHitAttributesAttribute) + FIDDLE(...) }; // A `[__vulkanHitObjectAttributes(location)]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the attributes on a HitObject as part of // Shader ExecutionReordering +FIDDLE() class VulkanHitObjectAttributesAttribute : public Attribute { - SLANG_AST_CLASS(VulkanHitObjectAttributesAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; // A `[mutating]` attribute, which indicates that a member // function is allowed to modify things through its `this` // argument. // +FIDDLE() class MutatingAttribute : public Attribute { - SLANG_AST_CLASS(MutatingAttribute) + FIDDLE(...) }; // A `[nonmutating]` attribute, which indicates that a // `set` accessor does not need to modify anything through // its `this` parameter. // +FIDDLE() class NonmutatingAttribute : public Attribute { - SLANG_AST_CLASS(NonmutatingAttribute) + FIDDLE(...) }; // A `[constref]` attribute, which indicates that the `this` parameter of // a member function should be passed by const reference. // +FIDDLE() class ConstRefAttribute : public Attribute { - SLANG_AST_CLASS(ConstRefAttribute) + FIDDLE(...) }; // A `[ref]` attribute, which indicates that the `this` parameter of // a member function should be passed by reference. // +FIDDLE() class RefAttribute : public Attribute { - SLANG_AST_CLASS(RefAttribute) + FIDDLE(...) }; // A `[__readNone]` attribute, which indicates that a function @@ -1203,174 +1336,194 @@ class RefAttribute : public Attribute // reading or writing through any pointer arguments, or any other // state that could be observed by a caller. // +FIDDLE() class ReadNoneAttribute : public Attribute { - SLANG_AST_CLASS(ReadNoneAttribute) + FIDDLE(...) }; // A `[__GLSLRequireShaderInputParameter]` attribute to annotate // functions that require a shader input as parameter // +FIDDLE() class GLSLRequireShaderInputParameterAttribute : public Attribute { - SLANG_AST_CLASS(GLSLRequireShaderInputParameterAttribute) - - uint32_t parameterNumber; + FIDDLE(...) + FIDDLE() uint32_t parameterNumber; }; // HLSL modifiers for geometry shader input topology +FIDDLE() class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier { - SLANG_AST_CLASS(HLSLGeometryShaderInputPrimitiveTypeModifier) + FIDDLE(...) }; +FIDDLE() class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLPointModifier) + FIDDLE(...) }; +FIDDLE() class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLLineModifier) + FIDDLE(...) }; +FIDDLE() class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLTriangleModifier) + FIDDLE(...) }; +FIDDLE() class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLLineAdjModifier) + FIDDLE(...) }; +FIDDLE() class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLTriangleAdjModifier) + FIDDLE(...) }; // Mesh shader paramters +FIDDLE() class HLSLMeshShaderOutputModifier : public Modifier { - SLANG_AST_CLASS(HLSLMeshShaderOutputModifier) + FIDDLE(...) }; +FIDDLE() class HLSLVerticesModifier : public HLSLMeshShaderOutputModifier { - SLANG_AST_CLASS(HLSLVerticesModifier) + FIDDLE(...) }; +FIDDLE() class HLSLIndicesModifier : public HLSLMeshShaderOutputModifier { - SLANG_AST_CLASS(HLSLIndicesModifier) + FIDDLE(...) }; +FIDDLE() class HLSLPrimitivesModifier : public HLSLMeshShaderOutputModifier { - SLANG_AST_CLASS(HLSLPrimitivesModifier) + FIDDLE(...) }; +FIDDLE() class HLSLPayloadModifier : public Modifier { - SLANG_AST_CLASS(HLSLPayloadModifier) + FIDDLE(...) }; // A modifier to indicate that a constructor/initializer can be used // to perform implicit type conversion, and to specify the cost of // the conversion, if applied. +FIDDLE() class ImplicitConversionModifier : public Modifier { - SLANG_AST_CLASS(ImplicitConversionModifier) - + FIDDLE(...) // The conversion cost, used to rank conversions - ConversionCost cost = kConversionCost_None; + FIDDLE() ConversionCost cost = kConversionCost_None; // A builtin identifier for identifying conversions that need special treatment. - BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown; + FIDDLE() BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown; }; +FIDDLE() class FormatAttribute : public Attribute { - SLANG_AST_CLASS(FormatAttribute) - - ImageFormat format; + FIDDLE(...) + FIDDLE() ImageFormat format; }; +FIDDLE() class AllowAttribute : public Attribute { - SLANG_AST_CLASS(AllowAttribute) - - DiagnosticInfo const* diagnostic = nullptr; + FIDDLE(...) + FIDDLE() DiagnosticInfo const* diagnostic = nullptr; }; // A `[__extern]` attribute, which indicates that a function/type is defined externally // +FIDDLE() class ExternAttribute : public Attribute { - SLANG_AST_CLASS(ExternAttribute) + FIDDLE(...) }; // An `[__unsafeForceInlineExternal]` attribute indicates that the callee should be inlined // into call sites after initial IR generation (that is, as early as possible). // +FIDDLE() class UnsafeForceInlineEarlyAttribute : public Attribute { - SLANG_AST_CLASS(UnsafeForceInlineEarlyAttribute) + FIDDLE(...) }; // A `[ForceInline]` attribute indicates that the callee should be inlined // by the Slang compiler. // +FIDDLE() class ForceInlineAttribute : public Attribute { - SLANG_AST_CLASS(ForceInlineAttribute) + FIDDLE(...) }; /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. +FIDDLE(abstract) class InheritanceControlAttribute : public Attribute { - SLANG_AST_CLASS(InheritanceControlAttribute) + FIDDLE(...) }; /// An attribute that marks a type declaration as allowing the type to be inherited from in other /// modules. +FIDDLE() class OpenAttribute : public InheritanceControlAttribute { - SLANG_AST_CLASS(OpenAttribute) + FIDDLE(...) }; /// An attribute that marks a type declaration as disallowing the type to be inherited from in other /// modules. +FIDDLE() class SealedAttribute : public InheritanceControlAttribute { - SLANG_AST_CLASS(SealedAttribute) + FIDDLE(...) }; /// An attribute that marks a decl as a compiler built-in object. +FIDDLE() class BuiltinAttribute : public Attribute { - SLANG_AST_CLASS(BuiltinAttribute) + FIDDLE(...) }; /// An attribute that marks a decl as a compiler built-in object for the autodiff system. +FIDDLE() class AutoDiffBuiltinAttribute : public Attribute { - SLANG_AST_CLASS(AutoDiffBuiltinAttribute) + FIDDLE(...) }; /// An attribute that defines the size of `AnyValue` type to represent a polymoprhic value that /// conforms to the decorated interface type. +FIDDLE() class AnyValueSizeAttribute : public Attribute { - SLANG_AST_CLASS(AnyValueSizeAttribute) - - int32_t size; + FIDDLE(...) + FIDDLE() int32_t size; }; /// This is a stop-gap solution to break overload ambiguity in the core module. @@ -1379,24 +1532,27 @@ class AnyValueSizeAttribute : public Attribute /// In the future, we should enhance our type system to take into account the "specialized"-ness /// of an overload, such that `T overload1<T:IDerived>()` is more specialized than `T /// overload2<T:IBase>()` and preferred during overload resolution. +FIDDLE() class OverloadRankAttribute : public Attribute { - SLANG_AST_CLASS(OverloadRankAttribute) - int32_t rank; + FIDDLE(...) + FIDDLE() int32_t rank; }; /// An attribute that marks an interface for specialization use only. Any operation that triggers /// dynamic dispatch through the interface is a compile-time error. +FIDDLE() class SpecializeAttribute : public Attribute { - SLANG_AST_CLASS(SpecializeAttribute) + FIDDLE(...) }; /// An attribute that marks a type, function or variable as differentiable. +FIDDLE() class DifferentiableAttribute : public Attribute { - SLANG_AST_CLASS(DifferentiableAttribute) - + FIDDLE(...) + // TODO(tfoley): Why is there this duplication here? List<KeyValuePair<Type*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings; void addType(Type* declRef, SubtypeWitness* witness) @@ -1418,55 +1574,62 @@ private: OrderedDictionary<Type*, SubtypeWitness*> m_mapToIDifferentiableWitness; }; +FIDDLE() class DllImportAttribute : public Attribute { - SLANG_AST_CLASS(DllImportAttribute) + FIDDLE(...) + FIDDLE() String modulePath; - String modulePath; - - String functionName; + FIDDLE() String functionName; }; +FIDDLE() class DllExportAttribute : public Attribute { - SLANG_AST_CLASS(DllExportAttribute) + FIDDLE(...) }; +FIDDLE() class TorchEntryPointAttribute : public Attribute { - SLANG_AST_CLASS(TorchEntryPointAttribute) + FIDDLE(...) }; +FIDDLE() class CudaDeviceExportAttribute : public Attribute { - SLANG_AST_CLASS(CudaDeviceExportAttribute) + FIDDLE(...) }; +FIDDLE() class CudaKernelAttribute : public Attribute { - SLANG_AST_CLASS(CudaKernelAttribute) + FIDDLE(...) }; +FIDDLE() class CudaHostAttribute : public Attribute { - SLANG_AST_CLASS(CudaHostAttribute) + FIDDLE(...) }; +FIDDLE() class AutoPyBindCudaAttribute : public Attribute { - SLANG_AST_CLASS(AutoPyBindCudaAttribute) + FIDDLE(...) }; +FIDDLE() class PyExportAttribute : public Attribute { - SLANG_AST_CLASS(PyExportAttribute) - - String name; + FIDDLE(...) + FIDDLE() String name; }; +FIDDLE() class PreferRecomputeAttribute : public Attribute { - SLANG_AST_CLASS(PreferRecomputeAttribute) + FIDDLE(...) enum SideEffectBehavior { @@ -1474,87 +1637,94 @@ class PreferRecomputeAttribute : public Attribute Allow = 1 }; - SideEffectBehavior sideEffectBehavior; + FIDDLE() SideEffectBehavior sideEffectBehavior; }; +FIDDLE() class PreferCheckpointAttribute : public Attribute { - SLANG_AST_CLASS(PreferCheckpointAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeMemberAttribute : public Attribute { - SLANG_AST_CLASS(DerivativeMemberAttribute) - - DeclRefExpr* memberDeclRef; + FIDDLE(...) + FIDDLE() DeclRefExpr* memberDeclRef; }; /// An attribute that marks an interface type as a COM interface declaration. +FIDDLE() class ComInterfaceAttribute : public Attribute { - SLANG_AST_CLASS(ComInterfaceAttribute) - - String guid; + FIDDLE(...) + FIDDLE() String guid; }; /// A `[__requiresNVAPI]` attribute indicates that the declaration being modifed /// requires NVAPI operations for its implementation on D3D. +FIDDLE() class RequiresNVAPIAttribute : public Attribute { - SLANG_AST_CLASS(RequiresNVAPIAttribute) + FIDDLE(...) }; /// A `[RequirePrelude(target, "string")]` attribute indicates that the declaration being modifed /// requires a textual prelude to be injected in the resulting target code. +FIDDLE() class RequirePreludeAttribute : public Attribute { - SLANG_AST_CLASS(RequirePreludeAttribute) - - CapabilitySet capabilitySet; - String prelude; + FIDDLE(...) + FIDDLE() CapabilitySet capabilitySet; + FIDDLE() String prelude; }; /// A `[__AlwaysFoldIntoUseSite]` attribute indicates that the calls into the modified /// function should always be folded into use sites during source emit. +FIDDLE() class AlwaysFoldIntoUseSiteAttribute : public Attribute { - SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute) + FIDDLE(...) }; // A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface // should be treated as differentiable in IR validation step. // +FIDDLE() class TreatAsDifferentiableAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(TreatAsDifferentiableAttribute) + FIDDLE(...) }; /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. +FIDDLE() class ForwardDifferentiableAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(ForwardDifferentiableAttribute) + FIDDLE(...) }; +FIDDLE() class UserDefinedDerivativeAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(UserDefinedDerivativeAttribute) - - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; }; /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. +FIDDLE() class ForwardDerivativeAttribute : public UserDefinedDerivativeAttribute { - SLANG_AST_CLASS(ForwardDerivativeAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeOfAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(DerivativeOfAttribute) - - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; + FIDDLE() Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; @@ -1562,75 +1732,83 @@ class DerivativeOfAttribute : public DifferentiableAttribute /// derivative implementation for `primalFunction`. /// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative /// function itself is considered differentiable. +FIDDLE() class ForwardDerivativeOfAttribute : public DerivativeOfAttribute { - SLANG_AST_CLASS(ForwardDerivativeOfAttribute) + FIDDLE(...) }; /// The `[BackwardDifferentiable]` attribute indicates that a function can be /// backward-differentiated. +FIDDLE() class BackwardDifferentiableAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(BackwardDifferentiableAttribute) - int maxOrder = 0; + FIDDLE(...) + FIDDLE() int maxOrder = 0; }; /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should /// be used as the backward-derivative for the decorated function. +FIDDLE() class BackwardDerivativeAttribute : public UserDefinedDerivativeAttribute { - SLANG_AST_CLASS(BackwardDerivativeAttribute) + FIDDLE(...) }; /// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom /// backward-derivative implementation for `primalFunction`. +FIDDLE() class BackwardDerivativeOfAttribute : public DerivativeOfAttribute { - SLANG_AST_CLASS(BackwardDerivativeOfAttribute) + FIDDLE(...) }; /// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should /// be used as the primal function substitute when differentiating code that calls the primal /// function. +FIDDLE() class PrimalSubstituteAttribute : public Attribute { - SLANG_AST_CLASS(PrimalSubstituteAttribute) - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; }; /// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as /// the substitute primal function in a forward or backward derivative function. +FIDDLE() class PrimalSubstituteOfAttribute : public Attribute { - SLANG_AST_CLASS(PrimalSubstituteOfAttribute) - - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; + FIDDLE() Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be /// included for differentiation. +FIDDLE() class NoDiffThisAttribute : public Attribute { - SLANG_AST_CLASS(NoDiffThisAttribute) + FIDDLE(...) }; /// Indicates that the modified declaration is one of the "magic" declarations /// that NVAPI uses to communicate extended operations. When NVAPI is being included /// via the prelude for downstream compilation, declarations with this modifier /// will not be emitted, instead allowing the versions from the prelude to be used. +FIDDLE() class NVAPIMagicModifier : public Modifier { - SLANG_AST_CLASS(NVAPIMagicModifier) + FIDDLE(...) }; /// A modifier that attaches to a `ModuleDecl` to indicate the register/space binding /// that NVAPI wants to use, as indicated by, e.g., the `NV_SHADER_EXTN_SLOT` and /// `NV_SHADER_EXTN_REGISTER_SPACE` preprocessor definitions. +FIDDLE() class NVAPISlotModifier : public Modifier { - SLANG_AST_CLASS(NVAPISlotModifier) - + FIDDLE(...) /// The name of the register that is to be used (e.g., `"u3"`) /// /// This value will come from the `NV_SHADER_EXTN_SLOT` macro, if set. @@ -1639,7 +1817,7 @@ class NVAPISlotModifier : public Modifier /// an `NVAPISlotModifier` to a module; if no register name is defined, /// then the modifier should not be added. /// - String registerName; + FIDDLE() String registerName; /// The name of the register space to be used (e.g., `space1`) /// @@ -1648,7 +1826,7 @@ class NVAPISlotModifier : public Modifier /// /// It is valid for a user to specify a register name but not a space name, /// and in that case `spaceName` will be set to `"space0"`. - String spaceName; + FIDDLE() String spaceName; }; /// A `[noinline]` attribute represents a request by the application that, @@ -1657,41 +1835,48 @@ class NVAPISlotModifier : public Modifier /// Note that due to various limitations of different targets, it is entirely /// possible for such functions to be inlined or specialized to call sites. /// +FIDDLE() class NoInlineAttribute : public Attribute { - SLANG_AST_CLASS(NoInlineAttribute) + FIDDLE(...) }; /// A `[noRefInline]` attribute represents a request to not force inline a /// function specifically due to a refType parameter. +FIDDLE() class NoRefInlineAttribute : public Attribute { - SLANG_AST_CLASS(NoRefInlineAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeGroupQuadAttribute : public Attribute { - SLANG_AST_CLASS(DerivativeGroupQuadAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeGroupLinearAttribute : public Attribute { - SLANG_AST_CLASS(DerivativeGroupLinearAttribute) + FIDDLE(...) }; +FIDDLE() class MaximallyReconvergesAttribute : public Attribute { - SLANG_AST_CLASS(MaximallyReconvergesAttribute) + FIDDLE(...) }; +FIDDLE() class QuadDerivativesAttribute : public Attribute { - SLANG_AST_CLASS(QuadDerivativesAttribute) + FIDDLE(...) }; +FIDDLE() class RequireFullQuadsAttribute : public Attribute { - SLANG_AST_CLASS(RequireFullQuadsAttribute) + FIDDLE(...) }; /// A `[payload]` attribute indicates that a `struct` type will be used as @@ -1699,9 +1884,10 @@ class RequireFullQuadsAttribute : public Attribute /// for shaders in the ray tracing pipeline that might be invoked for /// such a ray. /// +FIDDLE() class PayloadAttribute : public Attribute { - SLANG_AST_CLASS(PayloadAttribute) + FIDDLE(...) }; /// A `[raypayload]` attribute indicates that a `struct` type will be used as @@ -1709,9 +1895,10 @@ class PayloadAttribute : public Attribute /// for shaders in the ray tracing pipeline that might be invoked for /// such a ray. /// +FIDDLE() class RayPayloadAttribute : public Attribute { - SLANG_AST_CLASS(RayPayloadAttribute) + FIDDLE(...) }; /// A `[deprecated("message")]` attribute indicates the target is @@ -1719,32 +1906,34 @@ class RayPayloadAttribute : public Attribute /// A compiler warning including the message will be raised if the /// deprecated value is used. /// +FIDDLE() class DeprecatedAttribute : public Attribute { - SLANG_AST_CLASS(DeprecatedAttribute) - - String message; + FIDDLE(...) + FIDDLE() String message; }; +FIDDLE() class NonCopyableTypeAttribute : public Attribute { - SLANG_AST_CLASS(NonCopyableTypeAttribute) + FIDDLE(...) }; +FIDDLE() class NoSideEffectAttribute : public Attribute { - SLANG_AST_CLASS(NoSideEffectAttribute) + FIDDLE(...) }; /// A `[KnownBuiltin("name")]` attribute allows the compiler to /// identify this declaration during compilation, despite obfuscation or /// linkage removing optimizations /// +FIDDLE() class KnownBuiltinAttribute : public Attribute { - SLANG_AST_CLASS(KnownBuiltinAttribute) - - String name; + FIDDLE(...) + FIDDLE() String name; }; /// A modifier that applies to types rather than declarations. @@ -1762,103 +1951,117 @@ class KnownBuiltinAttribute : public Attribute /// and instead want to belong to the type (or rather the type *specifier* /// from a parsing standpoint). /// +FIDDLE() class TypeModifier : public Modifier { - SLANG_AST_CLASS(TypeModifier) + FIDDLE(...) }; /// A kind of syntax element which appears as a modifier in the syntax, but /// we represent as a function over type expressions +FIDDLE() class WrappingTypeModifier : public TypeModifier { - SLANG_AST_CLASS(WrappingTypeModifier) + FIDDLE(...) }; /// A modifier that applies to a type and implies information about the /// underlying format of a resource that uses that type as its element type. /// +FIDDLE() class ResourceElementFormatModifier : public TypeModifier { - SLANG_AST_CLASS(ResourceElementFormatModifier) + FIDDLE(...) }; /// HLSL `unorm` modifier +FIDDLE() class UNormModifier : public ResourceElementFormatModifier { - SLANG_AST_CLASS(UNormModifier) + FIDDLE(...) }; /// HLSL `snorm` modifier +FIDDLE() class SNormModifier : public ResourceElementFormatModifier { - SLANG_AST_CLASS(SNormModifier) + FIDDLE(...) }; +FIDDLE() class NoDiffModifier : public TypeModifier { - SLANG_AST_CLASS(NoDiffModifier) + FIDDLE(...) }; +FIDDLE() class GloballyCoherentModifier : public SimpleModifier { - SLANG_AST_CLASS(GloballyCoherentModifier) + FIDDLE(...) }; // Some GLSL-specific modifiers +FIDDLE() class GLSLBufferModifier : public WrappingTypeModifier { - SLANG_AST_CLASS(GLSLBufferModifier) + FIDDLE(...) }; +FIDDLE() class GLSLWriteOnlyModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLWriteOnlyModifier) + FIDDLE(...) }; +FIDDLE() class GLSLReadOnlyModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLReadOnlyModifier) + FIDDLE(...) }; +FIDDLE() class GLSLVolatileModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLVolatileModifier) + FIDDLE(...) }; +FIDDLE() class GLSLRestrictModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLRestrictModifier) + FIDDLE(...) }; +FIDDLE() class GLSLPatchModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLPatchModifier) + FIDDLE(...) }; // +FIDDLE() class BitFieldModifier : public Modifier { - SLANG_AST_CLASS(BitFieldModifier) - - IntegerLiteralValue width; + FIDDLE(...) + FIDDLE() IntegerLiteralValue width; // Fields filled during semantic analysis - IntegerLiteralValue offset = 0; - DeclRef<VarDecl> backingDeclRef; + FIDDLE() IntegerLiteralValue offset = 0; + FIDDLE() DeclRef<VarDecl> backingDeclRef; }; +FIDDLE() class DynamicUniformModifier : public Modifier { - SLANG_AST_CLASS(DynamicUniformModifier) + FIDDLE(...) }; +FIDDLE() class MemoryQualifierSetModifier : public Modifier { - SLANG_AST_CLASS(MemoryQualifierSetModifier); - - List<Modifier*> memoryModifiers; + FIDDLE(...) + FIDDLE() List<Modifier*> memoryModifiers; - uint32_t memoryQualifiers = 0; + FIDDLE() uint32_t memoryQualifiers = 0; public: struct Flags diff --git a/source/slang/slang-ast-reflect.cpp b/source/slang/slang-ast-reflect.cpp deleted file mode 100644 index 3f4ba9534..000000000 --- a/source/slang/slang-ast-reflect.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "slang-ast-reflect.h" - -#include "../core/slang-smart-pointer.h" -#include "slang-ast-all.h" -#include "slang-generated-ast-macro.h" -#include "slang-visitor.h" -#include "slang.h" - -#include <assert.h> -#include <typeinfo> - -namespace Slang -{ - -#define SLANG_REFLECT_GET_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - infos.infos[int(ASTNodeType::NAME)] = &NAME::kReflectClassInfo; - -static ASTClassInfo::Infos _calcInfos() -{ - ASTClassInfo::Infos infos; - memset(&infos, 0, sizeof(infos)); - SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_GET_REFLECT_CLASS_INFO, _) return infos; -} - -/* static */ const ASTClassInfo::Infos ASTClassInfo::kInfos = _calcInfos(); - -// Now try and implement all of the classes -// Macro generated is of the format - -struct ASTConstructAccess -{ - template<typename T> - struct Impl - { - static void* create(void* context) - { - ASTBuilder* astBuilder = (ASTBuilder*)context; - return astBuilder->createImpl<T>(); - } - static void destroy(void* ptr) - { - // Needed because if type has non dtor, Visual Studio claims ptr not used - SLANG_UNUSED(ptr); - reinterpret_cast<T*>(ptr)->~T(); - } - }; -}; - -#define SLANG_GET_SUPER_BASE(SUPER) nullptr -#define SLANG_GET_SUPER_INNER(SUPER) &SUPER::kReflectClassInfo -#define SLANG_GET_SUPER_LEAF(SUPER) &SUPER::kReflectClassInfo - -#define SLANG_GET_CREATE_FUNC_ABSTRACT_AST(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_AST(NAME) &ASTConstructAccess::Impl<NAME>::create - -#define SLANG_GET_DESTROY_FUNC_ABSTRACT_AST(NAME) nullptr -#define SLANG_GET_DESTROY_FUNC_AST(NAME) &ASTConstructAccess::Impl<NAME>::destroy - -#define SLANG_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - /* static */ const ReflectClassInfo NAME::kReflectClassInfo = { \ - uint32_t(ASTNodeType::NAME), \ - uint32_t(ASTNodeType::LAST), \ - SLANG_GET_SUPER_##TYPE(SUPER), \ - #NAME, \ - SLANG_GET_CREATE_FUNC_##MARKER(NAME), \ - SLANG_GET_DESTROY_FUNC_##MARKER(NAME), \ - uint32_t(sizeof(NAME)), \ - uint8_t(SLANG_ALIGN_OF(NAME))}; - -SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_CLASS_INFO, _) - -// We dispatch to non 'abstract' types -#define SLANG_CASE_AST(NAME) \ - case ASTNodeType::NAME: \ - return visitor->dispatch_##NAME(static_cast<NAME*>(this), extra); -#define SLANG_CASE_ABSTRACT_AST(NAME) - -#define SLANG_CASE_DISPATCH(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - SLANG_CASE_##MARKER(NAME) - - void Val::accept(IValVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Val(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -void Type::accept(ITypeVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Type(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -void Modifier::accept(IModifierVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CASE_DISPATCH, _) default - : SLANG_ASSERT(!"Unknown type"); - } -} - -void DeclBase::accept(IDeclVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CASE_DISPATCH, _) default - : SLANG_ASSERT(!"Unknown type"); - } -} - -void Expr::accept(IExprVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Expr(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -void Stmt::accept(IStmtVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -} // namespace Slang diff --git a/source/slang/slang-ast-reflect.h b/source/slang/slang-ast-reflect.h deleted file mode 100644 index 56e42c8bd..000000000 --- a/source/slang/slang-ast-reflect.h +++ /dev/null @@ -1,59 +0,0 @@ -// slang-ast-reflect.h - -#ifndef SLANG_AST_REFLECT_H -#define SLANG_AST_REFLECT_H - -#include "slang-generated-ast.h" -#include "slang-serialize-reflection.h" - -// Implementation for SLANG_ABSTRACT_CLASS(x) using reflection from C++ extractor in -// slang-ast-generated.h -#define SLANG_AST_CLASS_REFLECT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ -protected: \ - NAME() = default; \ - \ -public: \ - typedef NAME This; \ - static constexpr ASTNodeType kType = ASTNodeType::NAME; \ - static const ReflectClassInfo kReflectClassInfo; \ - SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) \ - { \ - return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); \ - } \ - SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) friend class ASTBuilder; \ - friend struct ASTConstructAccess; \ - friend struct ASTFieldAccess; \ - friend struct ASTDumpAccess; - -// Macro definitions - use the SLANG_ASTNode_ definitions to invoke the IMPL to produce the code -// injected into AST classes -#define SLANG_ABSTRACT_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) -#define SLANG_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) - -// Macros for simulating virtual methods without virtual methods - -#define SLANG_AST_NODE_INVOKE(method, methodParams) _##method##Override methodParams - -#define SLANG_AST_NODE_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - case ASTNodeType::NAME: \ - return static_cast<NAME*>(this)->SLANG_AST_NODE_INVOKE param; - -#define SLANG_AST_NODE_VIRTUAL_CALL(base, methodName, methodParams) \ - switch (astNodeType) \ - { \ - SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CASE, (methodName, methodParams)) default \ - : return SLANG_AST_NODE_INVOKE(methodName, methodParams); \ - } - -// Same but for a method that's const -#define SLANG_AST_NODE_CONST_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - case ASTNodeType::NAME: \ - return static_cast<const NAME*>(this)->SLANG_AST_NODE_INVOKE param; -#define SLANG_AST_NODE_CONST_VIRTUAL_CALL(base, methodName, methodParams) \ - switch (astNodeType) \ - { \ - SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CONST_CASE, (methodName, methodParams)) default \ - : return SLANG_AST_NODE_INVOKE(methodName, methodParams); \ - } - -#endif // SLANG_AST_REFLECT_H diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 4107664bf..a1b7c274e 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -1,55 +1,56 @@ // slang-ast-stmt.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-stmt.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for statements. +FIDDLE(abstract) class ScopeStmt : public Stmt { - SLANG_ABSTRACT_AST_CLASS(ScopeStmt) - + FIDDLE(...) ScopeDecl* scopeDecl = nullptr; }; // A sequence of statements, treated as a single statement +FIDDLE() class SeqStmt : public Stmt { - SLANG_AST_CLASS(SeqStmt) - - List<Stmt*> stmts; + FIDDLE(...) + FIDDLE() List<Stmt*> stmts; }; // A statement with a label. +FIDDLE() class LabelStmt : public Stmt { - SLANG_AST_CLASS(LabelStmt) - - Token label; - Stmt* innerStmt; + FIDDLE(...) + FIDDLE() Token label; + FIDDLE() Stmt* innerStmt; }; // The simplest kind of scope statement: just a `{...}` block +FIDDLE() class BlockStmt : public ScopeStmt { - SLANG_AST_CLASS(BlockStmt) - + FIDDLE(...) /// TODO(JS): Having ranges of sourcelocs might be a good addition to AST nodes in general. SourceLoc closingSourceLoc; ///< The source location of the closing brace - Stmt* body = nullptr; + FIDDLE() Stmt* body = nullptr; }; // A statement that we aren't going to parse or check, because // we want to let a downstream compiler handle any issues +FIDDLE() class UnparsedStmt : public Stmt { - SLANG_AST_CLASS(UnparsedStmt) - + FIDDLE(...) // The tokens that were contained between `{` and `}` List<Token> tokens; Scope* currentScope = nullptr; @@ -58,41 +59,45 @@ class UnparsedStmt : public Stmt bool isInVariadicGenerics = false; }; +FIDDLE() class EmptyStmt : public Stmt { - SLANG_AST_CLASS(EmptyStmt) + FIDDLE(...) }; +FIDDLE() class DiscardStmt : public Stmt { - SLANG_AST_CLASS(DiscardStmt) + FIDDLE(...) }; +FIDDLE() class DeclStmt : public Stmt { - SLANG_AST_CLASS(DeclStmt) - - DeclBase* decl = nullptr; + FIDDLE(...) + FIDDLE() DeclBase* decl = nullptr; }; +FIDDLE() class IfStmt : public Stmt { - SLANG_AST_CLASS(IfStmt) - - Expr* predicate = nullptr; - Stmt* positiveStatement = nullptr; - Stmt* negativeStatement = nullptr; + FIDDLE(...) + FIDDLE() Expr* predicate = nullptr; + FIDDLE() Stmt* positiveStatement = nullptr; + FIDDLE() Stmt* negativeStatement = nullptr; }; +FIDDLE() class UniqueStmtIDNode : public Decl { - SLANG_AST_CLASS(UniqueStmtIDNode) + FIDDLE(...) }; // A statement that can be escaped with a `break` +FIDDLE(abstract) class BreakableStmt : public ScopeStmt { - SLANG_ABSTRACT_AST_CLASS(BreakableStmt) + FIDDLE(...) /// A unique ID for this statement. /// @@ -106,20 +111,21 @@ class BreakableStmt : public ScopeStmt static constexpr UniqueID kInvalidUniqueID = nullptr; }; +FIDDLE() class SwitchStmt : public BreakableStmt { - SLANG_AST_CLASS(SwitchStmt) - - Expr* condition = nullptr; - Stmt* body = nullptr; + FIDDLE(...) + FIDDLE() Expr* condition = nullptr; + FIDDLE() Stmt* body = nullptr; }; // A statement that is expected to appear lexically nested inside // some other construct, and thus needs to keep track of the // outer statement that it is associated with... +FIDDLE(abstract) class ChildStmt : public Stmt { - SLANG_ABSTRACT_AST_CLASS(ChildStmt) + FIDDLE(...) /// The unique ID of the enclosing statement this /// child statement refers to. @@ -127,33 +133,35 @@ class ChildStmt : public Stmt BreakableStmt::UniqueID targetOuterStmtID = BreakableStmt::kInvalidUniqueID; }; +FIDDLE() class TargetCaseStmt : public ChildStmt { - SLANG_AST_CLASS(TargetCaseStmt) - int32_t capability; - Token capabilityToken; - Stmt* body = nullptr; + FIDDLE(...) + FIDDLE() int32_t capability; + FIDDLE() Token capabilityToken; + FIDDLE() Stmt* body = nullptr; }; +FIDDLE() class TargetSwitchStmt : public BreakableStmt { - SLANG_AST_CLASS(TargetSwitchStmt) - - List<TargetCaseStmt*> targetCases; + FIDDLE(...) + FIDDLE() List<TargetCaseStmt*> targetCases; }; +FIDDLE() class StageSwitchStmt : public TargetSwitchStmt { - SLANG_AST_CLASS(StageSwitchStmt) + FIDDLE(...) }; +FIDDLE() class IntrinsicAsmStmt : public Stmt { - SLANG_AST_CLASS(IntrinsicAsmStmt) + FIDDLE(...) + FIDDLE() String asmText; - String asmText; - - List<Expr*> args; + FIDDLE() List<Expr*> args; }; // a `case` or `default` statement inside a `switch` @@ -161,129 +169,136 @@ class IntrinsicAsmStmt : public Stmt // Note(tfoley): A correct AST for a C-like language would treat // these as a labelled statement, and so they would contain a // sub-statement. I'm leaving that out for now for simplicity. +FIDDLE(abstract) class CaseStmtBase : public ChildStmt { - SLANG_ABSTRACT_AST_CLASS(CaseStmtBase) + FIDDLE(...) }; // a `case` statement inside a `switch` +FIDDLE() class CaseStmt : public CaseStmtBase { - SLANG_AST_CLASS(CaseStmt) - - Expr* expr = nullptr; + FIDDLE(...) + FIDDLE() Expr* expr = nullptr; - Val* exprVal = nullptr; + FIDDLE() Val* exprVal = nullptr; }; // a `default` statement inside a `switch` +FIDDLE() class DefaultStmt : public CaseStmtBase { - SLANG_AST_CLASS(DefaultStmt) + FIDDLE(...) }; // a `default` statement inside a `switch` +FIDDLE() class GpuForeachStmt : public ScopeStmt { - SLANG_AST_CLASS(GpuForeachStmt) - - Expr* device = nullptr; - Expr* gridDims = nullptr; - VarDecl* dispatchThreadID = nullptr; - Expr* kernelCall = nullptr; + FIDDLE(...) + FIDDLE() Expr* device = nullptr; + FIDDLE() Expr* gridDims = nullptr; + FIDDLE() VarDecl* dispatchThreadID = nullptr; + FIDDLE() Expr* kernelCall = nullptr; }; // A statement that represents a loop, and can thus be escaped with a `continue` +FIDDLE(abstract) class LoopStmt : public BreakableStmt { - SLANG_ABSTRACT_AST_CLASS(LoopStmt) + FIDDLE(...) }; // A `for` statement +FIDDLE() class ForStmt : public LoopStmt { - SLANG_AST_CLASS(ForStmt) - - Stmt* initialStatement = nullptr; - Expr* sideEffectExpression = nullptr; - Expr* predicateExpression = nullptr; - Stmt* statement = nullptr; + FIDDLE(...) + FIDDLE() Stmt* initialStatement = nullptr; + FIDDLE() Expr* sideEffectExpression = nullptr; + FIDDLE() Expr* predicateExpression = nullptr; + FIDDLE() Stmt* statement = nullptr; }; // A `for` statement in a language that doesn't restrict the scope // of the loop variable to the body. +FIDDLE() class UnscopedForStmt : public ForStmt { - SLANG_AST_CLASS(UnscopedForStmt); + FIDDLE(...) }; +FIDDLE() class WhileStmt : public LoopStmt { - SLANG_AST_CLASS(WhileStmt) - - Expr* predicate = nullptr; - Stmt* statement = nullptr; + FIDDLE(...) + FIDDLE() Expr* predicate = nullptr; + FIDDLE() Stmt* statement = nullptr; }; +FIDDLE() class DoWhileStmt : public LoopStmt { - SLANG_AST_CLASS(DoWhileStmt) - - Stmt* statement = nullptr; - Expr* predicate = nullptr; + FIDDLE(...) + FIDDLE() Stmt* statement = nullptr; + FIDDLE() Expr* predicate = nullptr; }; // A compile-time, range-based `for` loop, which will not appear in the output code +FIDDLE() class CompileTimeForStmt : public ScopeStmt { - SLANG_AST_CLASS(CompileTimeForStmt) - - VarDecl* varDecl = nullptr; - Expr* rangeBeginExpr = nullptr; - Expr* rangeEndExpr = nullptr; - Stmt* body = nullptr; - IntVal* rangeBeginVal = nullptr; - IntVal* rangeEndVal = nullptr; + FIDDLE(...) + FIDDLE() VarDecl* varDecl = nullptr; + FIDDLE() Expr* rangeBeginExpr = nullptr; + FIDDLE() Expr* rangeEndExpr = nullptr; + FIDDLE() Stmt* body = nullptr; + FIDDLE() IntVal* rangeBeginVal = nullptr; + FIDDLE() IntVal* rangeEndVal = nullptr; }; // The case of child statements that do control flow relative // to their parent statement. +FIDDLE(abstract) class JumpStmt : public ChildStmt { - SLANG_ABSTRACT_AST_CLASS(JumpStmt) + FIDDLE(...) }; +FIDDLE() class BreakStmt : public JumpStmt { - SLANG_AST_CLASS(BreakStmt) - + FIDDLE(...) Token targetLabel; }; +FIDDLE() class ContinueStmt : public JumpStmt { - SLANG_AST_CLASS(ContinueStmt) + FIDDLE(...) }; +FIDDLE() class ReturnStmt : public Stmt { - SLANG_AST_CLASS(ReturnStmt) - - Expr* expression = nullptr; + FIDDLE(...) + FIDDLE() Expr* expression = nullptr; }; +FIDDLE() class DeferStmt : public Stmt { - SLANG_AST_CLASS(DeferStmt) + FIDDLE(...) - Stmt* statement = nullptr; + FIDDLE() Stmt* statement = nullptr; }; +FIDDLE() class ExpressionStmt : public Stmt { - SLANG_AST_CLASS(ExpressionStmt) - - Expr* expression = nullptr; + FIDDLE(...) + FIDDLE() Expr* expression = nullptr; }; } // namespace Slang diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index d59b6b286..3ac352f0a 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -1,3 +1,4 @@ +// slang-ast-support-types.cpp #include "slang-ast-support-types.h" #include "slang-ast-base.h" diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index b3baee98f..87715d9e0 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -6,1716 +6,1731 @@ #include "../compiler-core/slang-name.h" #include "../core/slang-basic.h" #include "../core/slang-semantic-version.h" -#include "slang-ast-reflect.h" -#include "slang-generated-ast.h" +#include "slang-ast-forward-declarations.h" +#include "slang-ast-support-types.h.fiddle" #include "slang-profile.h" -#include "slang-ref-object-reflect.h" -#include "slang-serialize-reflection.h" #include "slang-type-system-shared.h" #include "slang.h" #include <assert.h> #include <type_traits> -namespace Slang -{ -class Module; -class Name; -class Session; -class SyntaxVisitor; -class FuncDecl; -class Layout; - -struct IExprVisitor; -struct IDeclVisitor; -struct IModifierVisitor; -struct IStmtVisitor; -struct ITypeVisitor; -struct IValVisitor; - -class Parser; -class SyntaxNode; - -class Decl; -struct QualType; -class Type; -struct TypeExp; -class Val; - -class NodeBase; -class LookupDeclRef; -class GenericAppDeclRef; -struct CapabilitySet; - -template<typename T> -T* as(NodeBase* node); - -template<typename T> -const T* as(const NodeBase* node); - -void printDiagnosticArg(StringBuilder& sb, Decl* decl); -void printDiagnosticArg(StringBuilder& sb, Type* type); -void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); -void printDiagnosticArg(StringBuilder& sb, QualType const& type); -void printDiagnosticArg(StringBuilder& sb, Val* val); -void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase); -void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType); -void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& set); -void printDiagnosticArg(StringBuilder& sb, List<CapabilityAtom>& set); - -struct QualifiedDeclPath -{ - DeclRefBase* declRef; - QualifiedDeclPath() = default; - QualifiedDeclPath(DeclRefBase* declRef) - : declRef(declRef) - { - } -}; -// Prints the fully qualified decl name. -void printDiagnosticArg(StringBuilder& sb, QualifiedDeclPath path); +#define SLANG_UNREFLECTED /* empty */ +FIDDLE(hidden class RefObject;) -class SyntaxNode; -SourceLoc getDiagnosticPos(SyntaxNode const* syntax); -SourceLoc getDiagnosticPos(TypeExp const& typeExp); -SourceLoc getDiagnosticPos(DeclRefBase* declRef); -SourceLoc getDiagnosticPos(Decl* decl); - -typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); - -typedef unsigned int ConversionCost; -enum : ConversionCost +FIDDLE() namespace Slang { - // No conversion at all - kConversionCost_None = 0, - - kConversionCost_GenericParamUpcast = 1, - kConversionCost_UnconstraintGenericParam = 20, - kConversionCost_SizedArrayToUnsizedArray = 30, - - // Convert between matrices of different layout - kConversionCost_MatrixLayout = 5, - - // Conversion from a buffer to the type it carries needs to add a minimal - // extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>` - // from one on `Foo` - kConversionCost_GetRef = 5, - kConversionCost_ImplicitDereference = 10, - kConversionCost_InRangeIntLitConversion = 23, - kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, - kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, - - kConversionCost_MutablePtrToConstPtr = 20, - - // Conversions based on explicit sub-typing relationships are the cheapest - // - // TODO(tfoley): We will eventually need a discipline for ranking - // when two up-casts are comparable. - kConversionCost_CastToInterface = 50, - - // Conversion that is lossless and keeps the "kind" of the value the same - kConversionCost_BoolToInt = - 120, // Converting bool to int has lower cost than other integer types to prevent ambiguity. - kConversionCost_RankPromotion = 150, - kConversionCost_NoneToOptional = 150, - kConversionCost_ValToOptional = 150, - kConversionCost_NullPtrToPtr = 150, - kConversionCost_PtrToVoidPtr = 150, - - // Conversions that are lossless, but change "kind" - kConversionCost_UnsignedToSignedPromotion = 200, - - // Same-size size unsigned->signed conversions are potentially lossy, but they are commonly - // allowed silently. - kConversionCost_SameSizeUnsignedToSignedConversion = 300, - - // Conversion from signed->unsigned integer of same or greater size - kConversionCost_SignedToUnsignedConversion = 250, +#define SLANG_AST_NODE_VIRTUAL_CALL(CLASS, METHOD, ARGS) \ + return ASTNodeDispatcher<CLASS, decltype(this->METHOD ARGS)>::dispatch( \ + this, \ + [&](auto _this) -> decltype(this->METHOD ARGS) \ + { return _this->_##METHOD##Override ARGS; }); + + class Module; + class Name; + class Session; + class SyntaxVisitor; + class FuncDecl; + class Layout; + + class Parser; + class SyntaxNode; + + class Decl; + struct QualType; + class Type; + struct TypeExp; + class Val; + + class DeclRefBase; + class NodeBase; + class LookupDeclRef; + class GenericAppDeclRef; + struct CapabilitySet; - // Cost of converting an integer to a floating-point type - kConversionCost_IntegerToFloatConversion = 400, - - // Cost of converting a pointer to bool - kConversionCost_PtrToBool = 400, - - // Cost of converting an integer to int16_t - kConversionCost_IntegerTruncate = 450, + template<typename T> + T* as(NodeBase * node); - // Cost of converting an integer to a half type - kConversionCost_IntegerToHalfConversion = 500, + template<typename T> + const T* as(const NodeBase* node); + + void printDiagnosticArg(StringBuilder & sb, Decl * decl); + void printDiagnosticArg(StringBuilder & sb, Type * type); + void printDiagnosticArg(StringBuilder & sb, TypeExp const& type); + void printDiagnosticArg(StringBuilder & sb, QualType const& type); + void printDiagnosticArg(StringBuilder & sb, Val * val); + void printDiagnosticArg(StringBuilder & sb, DeclRefBase * declRefBase); + void printDiagnosticArg(StringBuilder & sb, ASTNodeType nodeType); + void printDiagnosticArg(StringBuilder & sb, const CapabilitySet& set); + void printDiagnosticArg(StringBuilder & sb, List<CapabilityAtom> & set); + + struct QualifiedDeclPath + { + DeclRefBase* declRef; + QualifiedDeclPath() = default; + QualifiedDeclPath(DeclRefBase* declRef) + : declRef(declRef) + { + } + }; + // Prints the fully qualified decl name. + void printDiagnosticArg(StringBuilder & sb, QualifiedDeclPath path); - // Cost of using a concrete argument pack - kConversionCost_ParameterPack = 500, - // Default case (usable for user-defined conversions) - kConversionCost_Default = 500, + class SyntaxNode; + SourceLoc getDiagnosticPos(SyntaxNode const* syntax); + SourceLoc getDiagnosticPos(TypeExp const& typeExp); + SourceLoc getDiagnosticPos(DeclRefBase * declRef); + SourceLoc getDiagnosticPos(Decl * decl); + + typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); + + typedef unsigned int ConversionCost; + enum : ConversionCost + { + // No conversion at all + kConversionCost_None = 0, + + kConversionCost_GenericParamUpcast = 1, + kConversionCost_UnconstraintGenericParam = 20, + kConversionCost_SizedArrayToUnsizedArray = 30, + + // Convert between matrices of different layout + kConversionCost_MatrixLayout = 5, + + // Conversion from a buffer to the type it carries needs to add a minimal + // extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>` + // from one on `Foo` + kConversionCost_GetRef = 5, + kConversionCost_ImplicitDereference = 10, + kConversionCost_InRangeIntLitConversion = 23, + kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, + kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, + + kConversionCost_MutablePtrToConstPtr = 20, + + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, + + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_BoolToInt = 120, // Converting bool to int has lower cost than other integer + // types to prevent ambiguity. + kConversionCost_RankPromotion = 150, + kConversionCost_NoneToOptional = 150, + kConversionCost_ValToOptional = 150, + kConversionCost_NullPtrToPtr = 150, + kConversionCost_PtrToVoidPtr = 150, + + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, + + // Same-size size unsigned->signed conversions are potentially lossy, but they are commonly + // allowed silently. + kConversionCost_SameSizeUnsignedToSignedConversion = 300, + + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 250, + + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, + + // Cost of converting a pointer to bool + kConversionCost_PtrToBool = 400, + + // Cost of converting an integer to int16_t + kConversionCost_IntegerTruncate = 450, + + // Cost of converting an integer to a half type + kConversionCost_IntegerToHalfConversion = 500, + + // Cost of using a concrete argument pack + kConversionCost_ParameterPack = 500, + + // Default case (usable for user-defined conversions) + kConversionCost_Default = 500, + + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, + + // This is the cost of an explicit conversion, which should + // not actually be performed. + kConversionCost_Explicit = 90000, + + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_OneVectorToScalar = 1, + kConversionCost_ScalarToVector = 2, + kConversionCost_ScalarToMatrix = 10, + kConversionCost_ScalarIntegerToFloatMatrix = + kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, + + // Additional conversion cost to add when promoting from a scalar to + // a CoopVector (this will be added to the cost, if any, of converting + // the element type of the CoopVector) + kConversionCost_ScalarToCoopVector = 1, + + // Additional cost when casting an LValue. + kConversionCost_LValueCast = 800, + + // The cost of this conversion is defined by the type coercion constraint. + kConversionCost_TypeCoercionConstraint = 1000, + kConversionCost_TypeCoercionConstraintPlusScalarToVector = + kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector, + + // Conversion is impossible + kConversionCost_Impossible = 0xFFFFFFFF, + }; - // Catch-all for conversions that should be discouraged - // (i.e., that really shouldn't be made implicitly) - // - // TODO: make these conversions not be allowed implicitly in "Slang mode" - kConversionCost_GeneralConversion = 900, - - // This is the cost of an explicit conversion, which should - // not actually be performed. - kConversionCost_Explicit = 90000, - - // Additional conversion cost to add when promoting from a scalar to - // a vector (this will be added to the cost, if any, of converting - // the element type of the vector) - kConversionCost_OneVectorToScalar = 1, - kConversionCost_ScalarToVector = 2, - kConversionCost_ScalarToMatrix = 10, - kConversionCost_ScalarIntegerToFloatMatrix = - kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, - - // Additional conversion cost to add when promoting from a scalar to - // a CoopVector (this will be added to the cost, if any, of converting - // the element type of the CoopVector) - kConversionCost_ScalarToCoopVector = 1, - - // Additional cost when casting an LValue. - kConversionCost_LValueCast = 800, - - // The cost of this conversion is defined by the type coercion constraint. - kConversionCost_TypeCoercionConstraint = 1000, - kConversionCost_TypeCoercionConstraintPlusScalarToVector = - kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector, - - // Conversion is impossible - kConversionCost_Impossible = 0xFFFFFFFF, -}; - -typedef unsigned int BuiltinConversionKind; -enum : BuiltinConversionKind -{ - kBuiltinConversion_Unknown = 0, - kBuiltinConversion_FloatToDouble = 1, -}; + typedef unsigned int BuiltinConversionKind; + enum : BuiltinConversionKind + { + kBuiltinConversion_Unknown = 0, + kBuiltinConversion_FloatToDouble = 1, + }; -enum class ImageFormat -{ + enum class ImageFormat + { #define SLANG_FORMAT(NAME, OTHER) NAME, #include "slang-image-format-defs.h" #undef SLANG_FORMAT -}; - -struct ImageFormatInfo -{ - SlangScalarType scalarType; ///< If image format is not made up of channels of set sizes this - ///< will be SLANG_SCALAR_TYPE_NONE - uint8_t channelCount; ///< The number of channels - uint8_t sizeInBytes; ///< Size in bytes - UnownedStringSlice name; ///< The name associated with this type. NOTE! Currently these names - ///< *are* the GLSL format names. -}; - -const ImageFormatInfo& getImageFormatInfo(ImageFormat format); - -bool findImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); -bool findVkImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); - -char const* getGLSLNameForImageFormat(ImageFormat format); - -// TODO(tfoley): We should ditch this enumeration -// and just use the IR opcodes that represent these -// types directly. The one major complication there -// is that the order of the enum values currently -// matters, since it determines promotion rank. -// We either need to keep that restriction, or -// look up promotion rank by some other means. -// - -class Decl; -class Val; - -// Helper type for pairing up a name and the location where it appeared -struct NameLoc -{ - Name* name; - SourceLoc loc; + }; - NameLoc() - : name(nullptr) + struct ImageFormatInfo { - } + SlangScalarType scalarType; ///< If image format is not made up of channels of set sizes + ///< this will be SLANG_SCALAR_TYPE_NONE + uint8_t channelCount; ///< The number of channels + uint8_t sizeInBytes; ///< Size in bytes + UnownedStringSlice name; ///< The name associated with this type. NOTE! Currently these + ///< names *are* the GLSL format names. + }; - explicit NameLoc(Name* inName) - : name(inName) - { - } + const ImageFormatInfo& getImageFormatInfo(ImageFormat format); + bool findImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); + bool findVkImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); - NameLoc(Name* inName, SourceLoc inLoc) - : name(inName), loc(inLoc) - { - } + char const* getGLSLNameForImageFormat(ImageFormat format); - NameLoc(Token const& token) - : name(token.getNameOrNull()), loc(token.getLoc()) - { - } -}; + // TODO(tfoley): We should ditch this enumeration + // and just use the IR opcodes that represent these + // types directly. The one major complication there + // is that the order of the enum values currently + // matters, since it determines promotion rank. + // We either need to keep that restriction, or + // look up promotion rank by some other means. + // -struct StringSliceLoc -{ - UnownedStringSlice name; - SourceLoc loc; + class Decl; + class Val; - StringSliceLoc() - : name(nullptr) - { - } - explicit StringSliceLoc(const UnownedStringSlice& inName) - : name(inName) - { - } - StringSliceLoc(const UnownedStringSlice& inName, SourceLoc inLoc) - : name(inName), loc(inLoc) + // Helper type for pairing up a name and the location where it appeared + struct NameLoc { - } - StringSliceLoc(Token const& token) - : loc(token.getLoc()) - { - Name* tokenName = token.getNameOrNull(); - if (tokenName) + Name* name; + SourceLoc loc; + + NameLoc() + : name(nullptr) { - name = tokenName->text.getUnownedSlice(); } - } -}; - -// Helper class for iterating over a list of heap-allocated modifiers -struct ModifierList -{ - struct Iterator - { - Modifier* current = nullptr; - - Modifier* operator*() { return current; } - void operator++(); + explicit NameLoc(Name* inName) + : name(inName) + { + } - bool operator!=(Iterator other) { return current != other.current; }; - Iterator() - : current(nullptr) + NameLoc(Name* inName, SourceLoc inLoc) + : name(inName), loc(inLoc) { } - Iterator(Modifier* modifier) - : current(modifier) + NameLoc(Token const& token) + : name(token.getNameOrNull()), loc(token.getLoc()) { } }; - ModifierList() - : modifiers(nullptr) + struct StringSliceLoc { - } + UnownedStringSlice name; + SourceLoc loc; - ModifierList(Modifier* modifiers) - : modifiers(modifiers) - { - } + StringSliceLoc() + : name(nullptr) + { + } + explicit StringSliceLoc(const UnownedStringSlice& inName) + : name(inName) + { + } + StringSliceLoc(const UnownedStringSlice& inName, SourceLoc inLoc) + : name(inName), loc(inLoc) + { + } + StringSliceLoc(Token const& token) + : loc(token.getLoc()) + { + Name* tokenName = token.getNameOrNull(); + if (tokenName) + { + name = tokenName->text.getUnownedSlice(); + } + } + }; - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + // Helper class for iterating over a list of heap-allocated modifiers + struct ModifierList + { + struct Iterator + { + Modifier* current = nullptr; - Modifier* modifiers = nullptr; -}; + Modifier* operator*() { return current; } -// Helper class for iterating over heap-allocated modifiers -// of a specific type. -template<typename T> -struct FilteredModifierList -{ - struct Iterator - { - Modifier* current = nullptr; + void operator++(); - T* operator*() { return (T*)current; } + bool operator!=(Iterator other) { return current != other.current; }; - void operator++(); + Iterator() + : current(nullptr) + { + } - bool operator!=(Iterator other) { return current != other.current; }; + Iterator(Modifier* modifier) + : current(modifier) + { + } + }; - Iterator() - : current(nullptr) + ModifierList() + : modifiers(nullptr) { } - Iterator(Modifier* modifier) - : current(modifier) + ModifierList(Modifier* modifiers) + : modifiers(modifiers) { } + + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } + + Modifier* modifiers = nullptr; }; - FilteredModifierList() - : modifiers(nullptr) + // Helper class for iterating over heap-allocated modifiers + // of a specific type. + template<typename T> + struct FilteredModifierList { - } + struct Iterator + { + Modifier* current = nullptr; - FilteredModifierList(Modifier* modifiers) - : modifiers(adjust(modifiers)) - { - } + T* operator*() { return (T*)current; } - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + void operator++(); - static Modifier* adjust(Modifier* modifier); + bool operator!=(Iterator other) { return current != other.current; }; - Modifier* modifiers = nullptr; -}; + Iterator() + : current(nullptr) + { + } -// A set of modifiers attached to a syntax node -struct Modifiers -{ - // The first modifier in the linked list of heap-allocated modifiers - Modifier* first = nullptr; + Iterator(Modifier* modifier) + : current(modifier) + { + } + }; - template<typename T> - FilteredModifierList<T> getModifiersOfType() - { - return FilteredModifierList<T>(first); - } + FilteredModifierList() + : modifiers(nullptr) + { + } - // Find the first modifier of a given type, or return `nullptr` if none is found. - template<typename T> - T* findModifier() - { - return *getModifiersOfType<T>().begin(); - } + FilteredModifierList(Modifier* modifiers) + : modifiers(adjust(modifiers)) + { + } - template<typename T> - bool hasModifier() - { - return findModifier<T>() != nullptr; - } + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } - /// True if has no modifiers - bool isEmpty() const { return first == nullptr; } + static Modifier* adjust(Modifier* modifier); - FilteredModifierList<Modifier>::Iterator begin() - { - return FilteredModifierList<Modifier>::Iterator(first); - } - FilteredModifierList<Modifier>::Iterator end() + Modifier* modifiers = nullptr; + }; + + // A set of modifiers attached to a syntax node + struct Modifiers { - return FilteredModifierList<Modifier>::Iterator(nullptr); - } -}; + // The first modifier in the linked list of heap-allocated modifiers + Modifier* first = nullptr; -class NamedExpressionType; -class GenericDecl; -class ContainerDecl; + template<typename T> + FilteredModifierList<T> getModifiersOfType() + { + return FilteredModifierList<T>(first); + } -// Try to extract a simple integer value from an `IntVal`. -// This fill assert-fail if the object doesn't represent a literal value. -IntegerLiteralValue getIntVal(IntVal* val); + // Find the first modifier of a given type, or return `nullptr` if none is found. + template<typename T> + T* findModifier() + { + return *getModifiersOfType<T>().begin(); + } -/// Represents how much checking has been applied to a declaration. -enum class DeclCheckState : uint8_t -{ - /// The declaration has been parsed, but - /// is otherwise completely unchecked. - /// - Unchecked, + template<typename T> + bool hasModifier() + { + return findModifier<T>() != nullptr; + } - /// The declaration is parsed and inserted into the initial scope, - /// ready for future lookups from within the parser for disambiguation purposes. - ReadyForParserLookup, + /// True if has no modifiers + bool isEmpty() const { return first == nullptr; } - /// Basic checks on the modifiers of the declaration have been applied. - /// - /// For example, when a declaration has attributes, the transformation - /// of an attribute from the parsed-but-unchecked form into a checked - /// form (in which it has the appropriate C++ subclass) happens here. - /// - ModifiersChecked, + FilteredModifierList<Modifier>::Iterator begin() + { + return FilteredModifierList<Modifier>::Iterator(first); + } + FilteredModifierList<Modifier>::Iterator end() + { + return FilteredModifierList<Modifier>::Iterator(nullptr); + } + }; - /// Wiring up scopes of namespaces with their siblings defined in different - /// files/modules, and other namespaces imported via `using`. - ScopesWired, + class NamedExpressionType; + class GenericDecl; + class ContainerDecl; - /// The type/signature of the declaration has been checked. - /// - /// For a value declaration like a variable or function, this means that - /// the type of the declaration can be queried. - /// - /// For a type declaration like a `struct` or `typedef` this means - /// that a `Type` referring to that declaration can be formed. - /// - SignatureChecked, + // Try to extract a simple integer value from an `IntVal`. + // This fill assert-fail if the object doesn't represent a literal value. + IntegerLiteralValue getIntVal(IntVal * val); - /// The declaration's basic signature has been checked to the point that - /// it is ready to be referenced in other places. - /// - /// For a function, this means that it has been organized into a - /// "redeclration group" if there are multiple functions with the - /// same name in a scope. - /// - ReadyForReference, + /// Represents how much checking has been applied to a declaration. + enum class DeclCheckState : uint8_t + { + /// The declaration has been parsed, but + /// is otherwise completely unchecked. + /// + Unchecked, - /// The declaration is ready for lookup operations to be performed. - /// - /// For type declarations (e.g., aggregate types, generic type parameters) - /// this means that any base type or constraint clauses have been - /// sufficiently checked so that we can enumerate the inheritance - /// hierarchy of the type and discover all its members. - /// - ReadyForLookup, + /// The declaration is parsed and inserted into the initial scope, + /// ready for future lookups from within the parser for disambiguation purposes. + ReadyForParserLookup, - /// Any conformance declared on the declaration have been validated. - /// - /// In particular, this step means that a "witness table" has been - /// created to show how a type satisfies the requirements of any - /// interfaces it conforms to. - /// - ReadyForConformances, + /// Basic checks on the modifiers of the declaration have been applied. + /// + /// For example, when a declaration has attributes, the transformation + /// of an attribute from the parsed-but-unchecked form into a checked + /// form (in which it has the appropriate C++ subclass) happens here. + /// + ModifiersChecked, - /// Any DeclRefTypes with substitutions have been fully resolved - /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. - /// We need a separate pass to resolve these types because `A.X` - /// maybe synthesized and made available only after conformance checking. - TypesFullyResolved, + /// Wiring up scopes of namespaces with their siblings defined in different + /// files/modules, and other namespaces imported via `using`. + ScopesWired, - /// All attributes are fully checked. This is the final step before - /// checking the function body. - AttributesChecked, + /// The type/signature of the declaration has been checked. + /// + /// For a value declaration like a variable or function, this means that + /// the type of the declaration can be queried. + /// + /// For a type declaration like a `struct` or `typedef` this means + /// that a `Type` referring to that declaration can be formed. + /// + SignatureChecked, - /// The body/definition is checked. - /// - /// This step includes any validation of the declaration that is - /// immaterial to clients code using the declaration, but that is - /// nonetheless relevant to checking correctness. - /// - /// The canonical example here is checking the body of functions. - /// Client code cannot depend on *how* a function is implemented, - /// but we still need to (eventually) check the bodies of all - /// functions, so it belongs in the last phase of checking. - /// - DefinitionChecked, - DefaultConstructorReadyForUse = DefinitionChecked, + /// The declaration's basic signature has been checked to the point that + /// it is ready to be referenced in other places. + /// + /// For a function, this means that it has been organized into a + /// "redeclration group" if there are multiple functions with the + /// same name in a scope. + /// + ReadyForReference, - /// The capabilities required by the decl is infered and validated. - /// - CapabilityChecked, + /// The declaration is ready for lookup operations to be performed. + /// + /// For type declarations (e.g., aggregate types, generic type parameters) + /// this means that any base type or constraint clauses have been + /// sufficiently checked so that we can enumerate the inheritance + /// hierarchy of the type and discover all its members. + /// + ReadyForLookup, - // For convenience at sites that call `ensureDecl()`, we define - // some aliases for the above states that are expressed in terms - // of what client code needs to be able to do with a declaration. - // - // These aliases can be changed over time if we decide to add - // more phases to semantic checking. - - CanEnumerateBases = ReadyForLookup, - CanUseBaseOfInheritanceDecl = ReadyForLookup, - CanUseTypeOfValueDecl = ReadyForReference, - CanUseExtensionTargetType = ReadyForLookup, - CanUseAsType = ReadyForReference, - CanUseFuncSignature = ReadyForReference, - CanSpecializeGeneric = ReadyForReference, - CanReadInterfaceRequirements = ReadyForLookup, -}; - -/// A `DeclCheckState` plus a bit to track whether a declaration is currently being checked. -struct DeclCheckStateExt -{ - SLANG_VALUE_CLASS(DeclCheckStateExt) + /// Any conformance declared on the declaration have been validated. + /// + /// In particular, this step means that a "witness table" has been + /// created to show how a type satisfies the requirements of any + /// interfaces it conforms to. + /// + ReadyForConformances, - typedef uint8_t RawType; - DeclCheckStateExt() {} - DeclCheckStateExt(DeclCheckState state) - : m_raw(uint8_t(state)) - { - } + /// Any DeclRefTypes with substitutions have been fully resolved + /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. + /// We need a separate pass to resolve these types because `A.X` + /// maybe synthesized and made available only after conformance checking. + TypesFullyResolved, - enum : RawType - { - /// A flag to indicate that a declaration is being checked. + /// All attributes are fully checked. This is the final step before + /// checking the function body. + AttributesChecked, + + /// The body/definition is checked. /// - /// The value of this flag is chosen so that it can be - /// represented in the bits of a `DeclCheckState` without - /// colliding with the bits that represent actual states. + /// This step includes any validation of the declaration that is + /// immaterial to clients code using the declaration, but that is + /// nonetheless relevant to checking correctness. /// - kBeingCheckedBit = 0x80, - }; - - DeclCheckState getState() const { return DeclCheckState(m_raw & ~kBeingCheckedBit); } - void setState(DeclCheckState state) { m_raw = (m_raw & kBeingCheckedBit) | RawType(state); } + /// The canonical example here is checking the body of functions. + /// Client code cannot depend on *how* a function is implemented, + /// but we still need to (eventually) check the bodies of all + /// functions, so it belongs in the last phase of checking. + /// + DefinitionChecked, + DefaultConstructorReadyForUse = DefinitionChecked, - bool isBeingChecked() const { return (m_raw & kBeingCheckedBit) != 0; } + /// The capabilities required by the decl is infered and validated. + /// + CapabilityChecked, + + // For convenience at sites that call `ensureDecl()`, we define + // some aliases for the above states that are expressed in terms + // of what client code needs to be able to do with a declaration. + // + // These aliases can be changed over time if we decide to add + // more phases to semantic checking. + + CanEnumerateBases = ReadyForLookup, + CanUseBaseOfInheritanceDecl = ReadyForLookup, + CanUseTypeOfValueDecl = ReadyForReference, + CanUseExtensionTargetType = ReadyForLookup, + CanUseAsType = ReadyForReference, + CanUseFuncSignature = ReadyForReference, + CanSpecializeGeneric = ReadyForReference, + CanReadInterfaceRequirements = ReadyForLookup, + }; - void setIsBeingChecked(bool isBeingChecked) + /// A `DeclCheckState` plus a bit to track whether a declaration is currently being checked. + struct DeclCheckStateExt { - m_raw = (m_raw & ~kBeingCheckedBit) | (isBeingChecked ? kBeingCheckedBit : 0); - } + typedef uint8_t RawType; + DeclCheckStateExt() {} + DeclCheckStateExt(DeclCheckState state) + : m_raw(uint8_t(state)) + { + } - bool operator>=(DeclCheckState state) const { return getState() >= state; } + enum : RawType + { + /// A flag to indicate that a declaration is being checked. + /// + /// The value of this flag is chosen so that it can be + /// represented in the bits of a `DeclCheckState` without + /// colliding with the bits that represent actual states. + /// + kBeingCheckedBit = 0x80, + }; - RawType getRaw() const { return m_raw; } - void setRaw(RawType raw) { m_raw = raw; } + DeclCheckState getState() const { return DeclCheckState(m_raw & ~kBeingCheckedBit); } + void setState(DeclCheckState state) { m_raw = (m_raw & kBeingCheckedBit) | RawType(state); } - // TODO(JS): - // Unfortunately for automatic serialization to see this member, it has to be public. - // private: - RawType m_raw = 0; -}; + bool isBeingChecked() const { return (m_raw & kBeingCheckedBit) != 0; } -void addModifier(ModifiableSyntaxNode* syntax, Modifier* modifier); + void setIsBeingChecked(bool isBeingChecked) + { + m_raw = (m_raw & ~kBeingCheckedBit) | (isBeingChecked ? kBeingCheckedBit : 0); + } -void removeModifier(ModifiableSyntaxNode* syntax, Modifier* modifier); + bool operator>=(DeclCheckState state) const { return getState() >= state; } -struct QualType -{ - SLANG_VALUE_CLASS(QualType) + RawType getRaw() const { return m_raw; } + void setRaw(RawType raw) { m_raw = raw; } - Type* type = nullptr; - bool isLeftValue = false; - bool hasReadOnlyOnTarget = false; - bool isWriteOnly = false; + // TODO(JS): + // Unfortunately for automatic serialization to see this member, it has to be public. + // private: + RawType m_raw = 0; + }; - QualType() = default; + void addModifier(ModifiableSyntaxNode * syntax, Modifier * modifier); - QualType(Type* type); + void removeModifier(ModifiableSyntaxNode * syntax, Modifier * modifier); - QualType(Type* type, bool isLVal) - : QualType(type) + FIDDLE() + struct QualType { - isLeftValue = isLVal; - } + FIDDLE(...) + Type* type = nullptr; + bool isLeftValue = false; + bool hasReadOnlyOnTarget = false; + bool isWriteOnly = false; + QualType() = default; - Type* Ptr() { return type; } + QualType(Type* type); - operator Type*() { return type; } - Type* operator->() { return type; } -}; + QualType(Type* type, bool isLVal) + : QualType(type) + { + isLeftValue = isLVal; + } -class ASTBuilder; -struct ASTClassInfo -{ - struct Infos - { - const ReflectClassInfo* infos[int(ASTNodeType::CountOf)]; + Type* Ptr() { return type; } + + operator Type*() { return type; } + Type* operator->() { return type; } }; - SLANG_FORCE_INLINE static const ReflectClassInfo* getInfo(ASTNodeType type) - { - return kInfos.infos[int(type)]; - } - static const Infos kInfos; -}; -// A reference to a class of syntax node, that can be -// used to create instances on the fly -struct SyntaxClassBase -{ - SyntaxClassBase() {} + class ASTBuilder; - SyntaxClassBase(ReflectClassInfo const* inClassInfo) - : classInfo(inClassInfo) - { - } + struct SyntaxClassBase; + typedef SyntaxClassBase ReflectClassInfo; + typedef SyntaxClassBase ASTClassInfo; - void* createInstanceImpl(ASTBuilder* astBuilder) const + struct SyntaxClassInfo { - auto ci = classInfo; - if (!ci) - return nullptr; - - auto cf = ci->m_createFunc; - if (!cf) - return nullptr; + public: + char const* name; + ASTNodeType firstTag; + Count tagCount; + void* (*createFunc)(ASTBuilder*); + void (*destructFunc)(void*); - return cf(astBuilder); - } + template<typename T> + static SyntaxClassInfo* get() + { + return const_cast<SyntaxClassInfo*>(&T::kSyntaxClassInfo); + } + }; - SLANG_FORCE_INLINE bool isSubClassOfImpl(SyntaxClassBase const& super) const + // A reference to a class of syntax node, that can be + // used to create instances on the fly + struct SyntaxClassBase { - return classInfo ? classInfo->isSubClassOf(*super.classInfo) : false; - } + SyntaxClassBase() {} - ReflectClassInfo const* classInfo = nullptr; -}; + explicit SyntaxClassBase(ASTNodeType tag); -template<typename T> -struct SyntaxClass : SyntaxClassBase -{ - SyntaxClass() {} + SyntaxClassBase(SyntaxClassInfo const* info) + : _info(info) + { + } - template<typename U> - SyntaxClass( - SyntaxClass<U> const& other, - typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) - : SyntaxClassBase(other.classInfo) - { - } - T* createInstance(ASTBuilder* astBuilder) const { return (T*)createInstanceImpl(astBuilder); } + ASTNodeType getTag() const { return getInfo()->firstTag; } + UnownedTerminatedStringSlice getName() const; - SyntaxClass(const ReflectClassInfo* inClassInfo) - : SyntaxClassBase(inClassInfo) - { - } + void* createInstanceImpl(ASTBuilder* astBuilder) const; + void destructInstanceImpl(void* instance) const; - static SyntaxClass<T> getClass() { return SyntaxClass<T>(&T::kReflectClassInfo); } + bool isSubClassOf(SyntaxClassBase const& super) const; - template<typename U> - bool isSubClassOf(SyntaxClass<U> super) - { - return isSubClassOfImpl(super); - } + typedef SyntaxClassInfo Info; - template<typename U> - bool isSubClassOf() - { - return isSubClassOf(SyntaxClass<U>::getClass()); - } + Info* getInfo() const { return const_cast<Info*>(_info); } + operator Info*() const { return const_cast<Info*>(_info); } - template<typename U> - bool operator==(const SyntaxClass<U> other) const - { - return classInfo == other.classInfo; - } - template<typename U> - bool operator!=(const SyntaxClass<U> other) const - { - return classInfo != other.classInfo; - } -}; + bool operator==(SyntaxClassBase const& other) const { return _info == other._info; } -template<typename T> -SyntaxClass<T> getClass() -{ - return SyntaxClass<T>::getClass(); -} + bool operator!=(SyntaxClassBase const& other) const { return _info != other._info; } -struct SubstitutionSet -{ - DeclRefBase* declRef = nullptr; - - // The element index if the substitution is happening inside a pack expansion. - // For example, if we are substituting the pattern type of `expand each T`, where - // `T` is a type pack, then packExpansionIndex will have a value starting from 0 - // to the count of the type pack during expansion of the `expand` type when we - // substitute `each T` with the element of `T` at index `packExpansionIndex`. - int packExpansionIndex = -1; - - SubstitutionSet() = default; - SubstitutionSet(DeclRefBase* declRefBase) - : declRef(declRefBase) - { - } - explicit operator bool() const; - - template<typename F> - void forEachGenericSubstitution(F func) const; - - template<typename F> - void forEachSubstitutionArg(F func) const; - - Type* applyToType(ASTBuilder* astBuilder, Type* type) const; - DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; - - LookupDeclRef* findLookupDeclRef() const; - GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; - GenericAppDeclRef* findGenericAppDeclRef() const; - DeclRefBase* getInnerMostNodeWithSubstInfo() const; -}; - -/// An expression together with (optional) substutions to apply to it -/// -/// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. -/// Conceptually it represents the result of applying the substitutions, -/// recursively, to the given expression. -/// -/// `SubstExprBase` exists primarily to provide a non-templated base type -/// for `SubstExpr<T>`. Code should prefer to use `SubstExpr<Expr>` instead -/// of `SubstExprBase` as often as possible. -/// -struct SubstExprBase -{ -public: - /// Initialize as a null expression - SubstExprBase() {} + private: + Info const* _info = nullptr; + }; - /// Initialize as the given `expr` with no subsitutions applied - SubstExprBase(Expr* expr) - : m_expr(expr) - { - } + template<typename T> + struct SyntaxClass; - /// Initialize as the given `expr` with the given `substs` applied - SubstExprBase(Expr* expr, SubstitutionSet const& substs) - : m_expr(expr), m_substs(substs) - { - } + template<typename T> + SyntaxClass<T> getSyntaxClass(); - /// Get the underlying expression without any substitutions - Expr* getExpr() const { return m_expr; } + template<typename T = NodeBase> + struct SyntaxClass : SyntaxClassBase + { + SyntaxClass() {} - /// Get the subsitutions being applied, if any - SubstitutionSet const& getSubsts() const { return m_substs; } + template<typename U> + SyntaxClass( + SyntaxClass<U> const& other, + typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : SyntaxClassBase(other) + { + } -private: - Expr* m_expr = nullptr; - SubstitutionSet m_substs; + explicit SyntaxClass(SyntaxClassBase const& other) + : SyntaxClassBase(other) + { + } - typedef void (SubstExprBase::*SafeBool)(); - void SafeBoolTrue() {} + explicit SyntaxClass(ASTNodeType tag) + : SyntaxClassBase(tag) + { + } -public: - /// Test whether this is a non-null expression - operator SafeBool() { return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; } + explicit SyntaxClass(SyntaxClassInfo const* info) + : SyntaxClassBase(info) + { + } - /// Test whether this is a null expression - bool operator!() const { return m_expr == nullptr; } -}; + T* createInstance(ASTBuilder* astBuilder) const + { + return (T*)createInstanceImpl(astBuilder); + } + void destructInstance(T* instance) { destructInstanceImpl(instance); } -/// An expression together with (optional) substutions to apply to it -/// -/// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. -/// Conceptually it represents the result of applying the substitutions, -/// recursively, to the given expression. -/// -template<typename T> -struct SubstExpr : SubstExprBase -{ -private: - typedef SubstExprBase Super; + bool isSubClassOf(SyntaxClassBase const& other) + { + return SyntaxClassBase::isSubClassOf(other); + } -public: - /// Initialize as a null expression - SubstExpr() {} + template<typename U> + bool isSubClassOf() + { + return SyntaxClassBase::isSubClassOf(getSyntaxClass<U>()); + } + }; - /// Initialize as the given `expr` with no subsitutions applied - SubstExpr(T* expr) - : Super(expr) + template<typename T> + SyntaxClass<T> getSyntaxClass() { + return SyntaxClass<T>(SyntaxClassInfo::get<T>()); } - /// Initialize as the given `expr` with the given `substs` applied - SubstExpr(T* expr, SubstitutionSet const& substs) - : Super(expr, substs) + struct SubstitutionSet { - } + DeclRefBase* declRef = nullptr; - /// Initialize as a copy of the given `other` expression - template<typename U> - SubstExpr( - SubstExpr<U> const& other, - typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) - : Super(other.getExpr(), other.getSubsts()) - { - } + // The element index if the substitution is happening inside a pack expansion. + // For example, if we are substituting the pattern type of `expand each T`, where + // `T` is a type pack, then packExpansionIndex will have a value starting from 0 + // to the count of the type pack during expansion of the `expand` type when we + // substitute `each T` with the element of `T` at index `packExpansionIndex`. + int packExpansionIndex = -1; + + SubstitutionSet() = default; + SubstitutionSet(DeclRefBase* declRefBase) + : declRef(declRefBase) + { + } + explicit operator bool() const; - /// Get the underlying expression without any substitutions - T* getExpr() const { return (T*)Super::getExpr(); } + template<typename F> + void forEachGenericSubstitution(F func) const; + + template<typename F> + void forEachSubstitutionArg(F func) const; + + Type* applyToType(ASTBuilder* astBuilder, Type* type) const; + DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; + + LookupDeclRef* findLookupDeclRef() const; + GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; + GenericAppDeclRef* findGenericAppDeclRef() const; + DeclRefBase* getInnerMostNodeWithSubstInfo() const; + }; - /// Dynamic cast to an expression of type `U` + /// An expression together with (optional) substutions to apply to it /// - /// Returns a null expression if the cast fails, or if this expression was null. - template<typename U> - SubstExpr<U> as() + /// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + /// `SubstExprBase` exists primarily to provide a non-templated base type + /// for `SubstExpr<T>`. Code should prefer to use `SubstExpr<Expr>` instead + /// of `SubstExprBase` as often as possible. + /// + struct SubstExprBase { - return SubstExpr<U>(Slang::as<U>(getExpr()), getSubsts()); - } -}; + public: + /// Initialize as a null expression + SubstExprBase() {} -SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr); + /// Initialize as the given `expr` with no subsitutions applied + SubstExprBase(Expr* expr) + : m_expr(expr) + { + } -class ASTBuilder; + /// Initialize as the given `expr` with the given `substs` applied + SubstExprBase(Expr* expr, SubstitutionSet const& substs) + : m_expr(expr), m_substs(substs) + { + } -template<typename T> -struct DeclRef; -Module* getModule(Decl* decl); + /// Get the underlying expression without any substitutions + Expr* getExpr() const { return m_expr; } + /// Get the subsitutions being applied, if any + SubstitutionSet const& getSubsts() const { return m_substs; } -// If this is a declref to an associatedtype with a ThisTypeSubsitution, -// try to find the concrete decl that satisfies the associatedtype requirement from the -// concrete type supplied by ThisTypeSubstittution. -Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef); + private: + Expr* m_expr = nullptr; + SubstitutionSet m_substs; -template<typename T = Decl> -struct DeclRef -{ - friend class ASTBuilder; + typedef void (SubstExprBase::*SafeBool)(); + void SafeBoolTrue() {} -public: - typedef T DeclType; - DeclRefBase* declRefBase; - DeclRef() - : declRefBase(nullptr) + public: + /// Test whether this is a non-null expression + operator SafeBool() { return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; } + + /// Test whether this is a null expression + bool operator!() const { return m_expr == nullptr; } + }; + + /// An expression together with (optional) substutions to apply to it + /// + /// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + template<typename T> + struct SubstExpr : SubstExprBase { - } + private: + typedef SubstExprBase Super; - void init(DeclRefBase* base); + public: + /// Initialize as a null expression + SubstExpr() {} - DeclRef(Decl* decl); + /// Initialize as the given `expr` with no subsitutions applied + SubstExpr(T* expr) + : Super(expr) + { + } - DeclRef(DeclRefBase* base) { init(base); } + /// Initialize as the given `expr` with the given `substs` applied + SubstExpr(T* expr, SubstitutionSet const& substs) + : Super(expr, substs) + { + } - template<typename U, typename = typename EnableIf<IsConvertible<T*, U*>::Value, void>::type> - DeclRef(DeclRef<U> const& other) - : declRefBase(other.declRefBase) - { - } + /// Initialize as a copy of the given `other` expression + template<typename U> + SubstExpr( + SubstExpr<U> const& other, + typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : Super(other.getExpr(), other.getSubsts()) + { + } - T* getDecl() const; + /// Get the underlying expression without any substitutions + T* getExpr() const { return (T*)Super::getExpr(); } - Name* getName() const; + /// Dynamic cast to an expression of type `U` + /// + /// Returns a null expression if the cast fails, or if this expression was null. + template<typename U> + SubstExpr<U> as() + { + return SubstExpr<U>(Slang::as<U>(getExpr()), getSubsts()); + } + }; - SourceLoc getNameLoc() const; - SourceLoc getLoc() const; - DeclRef<ContainerDecl> getParent() const; - HashCode getHashCode() const; - Type* substitute(ASTBuilder* astBuilder, Type* type) const; + SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr * expr); - SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; + class ASTBuilder; - // Apply substitutions to a type or declaration - template<typename U> - DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const; + template<typename T> + struct DeclRef; + Module* getModule(Decl * decl); - // Apply substitutions to this declaration reference - DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; - template<typename U> - DeclRef<U> as() const - { - DeclRef<U> result = DeclRef<U>(declRefBase); - return result; - } + // If this is a declref to an associatedtype with a ThisTypeSubsitution, + // try to find the concrete decl that satisfies the associatedtype requirement from the + // concrete type supplied by ThisTypeSubstittution. + Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst( + ASTBuilder * builder, + DeclRef<Decl> declRef); - template<typename U> - bool is() const + template<typename T = Decl> + struct DeclRef { - return Slang::as<U>(static_cast<NodeBase*>(getDecl())) != nullptr; - } + friend class ASTBuilder; - operator DeclRefBase*() const { return declRefBase; } + public: + typedef T DeclType; + DeclRefBase* declRefBase; + DeclRef() + : declRefBase(nullptr) + { + } - operator DeclRef<Decl>() const { return DeclRef<Decl>(declRefBase); } + void init(DeclRefBase* base); - template<typename U> - bool equals(DeclRef<U> other) const - { - return declRefBase == other.declRefBase; - } + DeclRef(Decl* decl); - template<typename U> - bool operator==(DeclRef<U> other) const - { - return equals(other); - } + DeclRef(DeclRefBase* base) { init(base); } - template<typename U> - bool operator!=(DeclRef<U> other) const - { - return !equals(other); - } + template<typename U, typename = typename EnableIf<IsConvertible<T*, U*>::Value, void>::type> + DeclRef(DeclRef<U> const& other) + : declRefBase(other.declRefBase) + { + } - explicit operator bool() const { return declRefBase; } -}; + T* getDecl() const; -template<typename T> -inline DeclRef<T> makeDeclRef(T* decl) -{ - return DeclRef<T>(decl); -} + Name* getName() const; -SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); -DeclRef<Decl> substituteDeclRef( - SubstitutionSet const& substs, - ASTBuilder* astBuilder, - DeclRef<Decl> const& declRef); -Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + SourceLoc getNameLoc() const; + SourceLoc getLoc() const; + DeclRef<ContainerDecl> getParent() const; + HashCode getHashCode() const; + Type* substitute(ASTBuilder* astBuilder, Type* type) const; -enum class MemberFilterStyle -{ - All, ///< All members - Instance, ///< Only instance members - Static, ///< Only static (ie non instance) members -}; - -Decl* const* adjustFilterCursorImpl( - const ReflectClassInfo& clsInfo, - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end); -Decl* const* getFilterCursorByIndexImpl( - const ReflectClassInfo& clsInfo, - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end, - Index index); -Index getFilterCountImpl( - const ReflectClassInfo& clsInfo, - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end); - - -template<typename T> -Decl* const* adjustFilterCursor(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) -{ - return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end); -} - -/// Finds the element at index. If there is no element at the index (for example has too few -/// elements), returns nullptr. -template<typename T> -Decl* const* getFilterCursorByIndex( - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end, - Index index) -{ - return getFilterCursorByIndexImpl(T::kReflectClassInfo, filterStyle, ptr, end, index); -} + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; -template<typename T> -Index getFilterCount(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) -{ - return getFilterCountImpl(T::kReflectClassInfo, filterStyle, ptr, end); -} + // Apply substitutions to a type or declaration + template<typename U> + DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const; -template<typename T> -bool isFilterNonEmpty(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) -{ - return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end) != end; -} + // Apply substitutions to this declaration reference + DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; -template<typename T> -struct FilteredMemberList -{ - typedef Decl* Element; + template<typename U> + DeclRef<U> as() const + { + DeclRef<U> result = DeclRef<U>(declRefBase); + return result; + } - FilteredMemberList() - : m_begin(nullptr), m_end(nullptr) - { - } + template<typename U> + bool is() const + { + return Slang::as<U>(static_cast<NodeBase*>(getDecl())) != nullptr; + } - explicit FilteredMemberList( - List<Element> const& list, - MemberFilterStyle filterStyle = MemberFilterStyle::All) - : m_begin(adjustFilterCursor<T>(filterStyle, list.begin(), list.end())) - , m_end(list.end()) - , m_filterStyle(filterStyle) - { - } + operator DeclRefBase*() const { return declRefBase; } - struct Iterator - { - const Element* m_cursor; - const Element* m_end; - MemberFilterStyle m_filterStyle; + operator DeclRef<Decl>() const { return DeclRef<Decl>(declRefBase); } - bool operator!=(Iterator const& other) const { return m_cursor != other.m_cursor; } + template<typename U> + bool equals(DeclRef<U> other) const + { + return declRefBase == other.declRefBase; + } + + template<typename U> + bool operator==(DeclRef<U> other) const + { + return equals(other); + } - void operator++() { m_cursor = adjustFilterCursor<T>(m_filterStyle, m_cursor + 1, m_end); } + template<typename U> + bool operator!=(DeclRef<U> other) const + { + return !equals(other); + } - T* operator*() { return static_cast<T*>(*m_cursor); } + explicit operator bool() const { return declRefBase; } }; - Iterator begin() + template<typename T> + inline DeclRef<T> makeDeclRef(T * decl) { - Iterator iter = {m_begin, m_end, m_filterStyle}; - return iter; + return DeclRef<T>(decl); } - Iterator end() + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); + DeclRef<Decl> substituteDeclRef( + SubstitutionSet const& substs, + ASTBuilder* astBuilder, + DeclRef<Decl> const& declRef); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + + enum class MemberFilterStyle { - Iterator iter = {m_end, m_end, m_filterStyle}; - return iter; - } + All, ///< All members + Instance, ///< Only instance members + Static, ///< Only static (ie non instance) members + }; - // TODO(tfoley): It is ugly to have these. - // We should probably fix the call sites instead. - T* getFirst() { return *begin(); } - Index getCount() { return getFilterCount<T>(m_filterStyle, m_begin, m_end); } + Decl* const* adjustFilterCursorImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end); + Decl* const* getFilterCursorByIndexImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end, + Index index); + Index getFilterCountImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end); - T* operator[](Index index) const - { - Decl* const* ptr = getFilterCursorByIndex<T>(m_filterStyle, m_begin, m_end, index); - SLANG_ASSERT(ptr); - return static_cast<T*>(*ptr); - } - /// Returns true if empty (equivalent to getCount() == 0) - bool isEmpty() const + template<typename T> + Decl* const* adjustFilterCursor( + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end) { - /// Note we don't have to scan, because m_begin has already been adjusted, when the - /// FilteredMemberList is constructed - return m_begin == m_end; + return adjustFilterCursorImpl(getSyntaxClass<T>(), filterStyle, ptr, end); } - /// Returns true if non empty (equivalent to getCount() != 0 but faster) - bool isNonEmpty() const { return !isEmpty(); } - List<T*> toList() + /// Finds the element at index. If there is no element at the index (for example has too few + /// elements), returns nullptr. + template<typename T> + Decl* const* getFilterCursorByIndex( + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end, + Index index) { - List<T*> result; - for (auto element : (*this)) - { - result.add(element); - } - return result; + return getFilterCursorByIndexImpl(getSyntaxClass<T>(), filterStyle, ptr, end, index); } - const Element* - m_begin; ///< Is either equal to m_end, or points to first *valid* filtered member - const Element* m_end; - MemberFilterStyle m_filterStyle; -}; - -struct TransparentMemberInfo -{ - // The declaration of the transparent member - Decl* decl = nullptr; -}; - -template<typename T> -struct FilteredMemberRefList -{ - List<Decl*> const& m_decls; - DeclRef<Decl> m_parent; - MemberFilterStyle m_filterStyle; - ASTBuilder* m_astBuilder; - - FilteredMemberRefList( - ASTBuilder* astBuilder, - List<Decl*> const& decls, - DeclRef<Decl> parent, - MemberFilterStyle filterStyle = MemberFilterStyle::All) - : m_decls(decls), m_parent(parent), m_filterStyle(filterStyle), m_astBuilder(astBuilder) + template<typename T> + Index getFilterCount(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) { + return getFilterCountImpl(getSyntaxClass<T>(), filterStyle, ptr, end); } - Index getCount() const + template<typename T> + bool isFilterNonEmpty(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) { - return getFilterCount<T>(m_filterStyle, m_decls.begin(), m_decls.end()); + return adjustFilterCursorImpl(getSyntaxClass<T>(), filterStyle, ptr, end) != end; } - /// True if empty (equivalent to getCount == 0, but faster) - bool isEmpty() const { return !isNonEmpty(); } - /// True if non empty (equivalent to getCount() != 0 but faster) - bool isNonEmpty() const + template<typename T> + struct FilteredMemberList { - return isFilterNonEmpty<T>(m_filterStyle, m_decls.begin(), m_decls.end()); - } + typedef Decl* Element; - DeclRef<T> getFirstOrNull() { return isEmpty() ? DeclRef<T>() : (*this)[0]; } + FilteredMemberList() + : m_begin(nullptr), m_end(nullptr) + { + } - DeclRef<T> operator[](Index index) const - { - Decl* const* decl = - getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index); - SLANG_ASSERT(decl); - return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as<T>(); - } + explicit FilteredMemberList( + List<Element> const& list, + MemberFilterStyle filterStyle = MemberFilterStyle::All) + : m_begin(adjustFilterCursor<T>(filterStyle, list.begin(), list.end())) + , m_end(list.end()) + , m_filterStyle(filterStyle) + { + } - List<DeclRef<T>> toArray() const - { - List<DeclRef<T>> result; - for (auto d : *this) - result.add(d); - return result; - } + struct Iterator + { + const Element* m_cursor; + const Element* m_end; + MemberFilterStyle m_filterStyle; - struct Iterator - { - FilteredMemberRefList const* m_list; - Decl* const* m_ptr; - Decl* const* m_end; - MemberFilterStyle m_filterStyle; + bool operator!=(Iterator const& other) const { return m_cursor != other.m_cursor; } + + void operator++() + { + m_cursor = adjustFilterCursor<T>(m_filterStyle, m_cursor + 1, m_end); + } - Iterator() - : m_list(nullptr), m_ptr(nullptr), m_filterStyle(MemberFilterStyle::All) + T* operator*() { return static_cast<T*>(*m_cursor); } + }; + + Iterator begin() { + Iterator iter = {m_begin, m_end, m_filterStyle}; + return iter; } - Iterator( - FilteredMemberRefList const* list, - Decl* const* ptr, - Decl* const* end, - MemberFilterStyle filterStyle) - : m_list(list), m_ptr(ptr), m_end(end), m_filterStyle(filterStyle) + + Iterator end() { + Iterator iter = {m_end, m_end, m_filterStyle}; + return iter; } - bool operator!=(const Iterator& other) const { return m_ptr != other.m_ptr; } + // TODO(tfoley): It is ugly to have these. + // We should probably fix the call sites instead. + T* getFirst() { return *begin(); } + Index getCount() { return getFilterCount<T>(m_filterStyle, m_begin, m_end); } + + T* operator[](Index index) const + { + Decl* const* ptr = getFilterCursorByIndex<T>(m_filterStyle, m_begin, m_end, index); + SLANG_ASSERT(ptr); + return static_cast<T*>(*ptr); + } - void operator++() { m_ptr = adjustFilterCursor<T>(m_filterStyle, m_ptr + 1, m_end); } + /// Returns true if empty (equivalent to getCount() == 0) + bool isEmpty() const + { + /// Note we don't have to scan, because m_begin has already been adjusted, when the + /// FilteredMemberList is constructed + return m_begin == m_end; + } + /// Returns true if non empty (equivalent to getCount() != 0 but faster) + bool isNonEmpty() const { return !isEmpty(); } - DeclRef<T> operator*() + List<T*> toList() { - return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr) - .template as<T>(); + List<T*> result; + for (auto element : (*this)) + { + result.add(element); + } + return result; } + + const Element* + m_begin; ///< Is either equal to m_end, or points to first *valid* filtered member + const Element* m_end; + MemberFilterStyle m_filterStyle; }; - Iterator begin() const + struct TransparentMemberInfo { - return Iterator( - this, - adjustFilterCursor<T>(m_filterStyle, m_decls.begin(), m_decls.end()), - m_decls.end(), - m_filterStyle); - } - Iterator end() const { return Iterator(this, m_decls.end(), m_decls.end(), m_filterStyle); } -}; + // The declaration of the transparent member + Decl* decl = nullptr; + }; -// -// type Expressions -// + template<typename T> + struct FilteredMemberRefList + { + List<Decl*> const& m_decls; + DeclRef<Decl> m_parent; + MemberFilterStyle m_filterStyle; + ASTBuilder* m_astBuilder; + + FilteredMemberRefList( + ASTBuilder* astBuilder, + List<Decl*> const& decls, + DeclRef<Decl> parent, + MemberFilterStyle filterStyle = MemberFilterStyle::All) + : m_decls(decls), m_parent(parent), m_filterStyle(filterStyle), m_astBuilder(astBuilder) + { + } -// A "type expression" is a term that we expect to resolve to a type during checking. -// We store both the original syntax and the resolved type here. -struct TypeExp -{ - SLANG_VALUE_CLASS(TypeExp) - typedef TypeExp ThisType; + Index getCount() const + { + return getFilterCount<T>(m_filterStyle, m_decls.begin(), m_decls.end()); + } - TypeExp() {} - TypeExp(TypeExp const& other) - : exp(other.exp), type(other.type) - { - } - explicit TypeExp(Expr* exp) - : exp(exp) - { - } - explicit TypeExp(Type* type) - : type(type) - { - } - TypeExp(Expr* exp, Type* type) - : exp(exp), type(type) - { - } + /// True if empty (equivalent to getCount == 0, but faster) + bool isEmpty() const { return !isNonEmpty(); } + /// True if non empty (equivalent to getCount() != 0 but faster) + bool isNonEmpty() const + { + return isFilterNonEmpty<T>(m_filterStyle, m_decls.begin(), m_decls.end()); + } - Expr* exp = nullptr; - Type* type = nullptr; + DeclRef<T> getFirstOrNull() { return isEmpty() ? DeclRef<T>() : (*this)[0]; } - bool equals(Type* other); + DeclRef<T> operator[](Index index) const + { + Decl* const* decl = + getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index); + SLANG_ASSERT(decl); + return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as<T>(); + } - Type* Ptr() { return type; } - operator Type*() { return type; } - Type* operator->() { return Ptr(); } + List<DeclRef<T>> toArray() const + { + List<DeclRef<T>> result; + for (auto d : *this) + result.add(d); + return result; + } - ThisType& operator=(const ThisType& rhs) = default; + struct Iterator + { + FilteredMemberRefList const* m_list; + Decl* const* m_ptr; + Decl* const* m_end; + MemberFilterStyle m_filterStyle; + + Iterator() + : m_list(nullptr), m_ptr(nullptr), m_filterStyle(MemberFilterStyle::All) + { + } + Iterator( + FilteredMemberRefList const* list, + Decl* const* ptr, + Decl* const* end, + MemberFilterStyle filterStyle) + : m_list(list), m_ptr(ptr), m_end(end), m_filterStyle(filterStyle) + { + } + + bool operator!=(const Iterator& other) const { return m_ptr != other.m_ptr; } + + void operator++() { m_ptr = adjustFilterCursor<T>(m_filterStyle, m_ptr + 1, m_end); } + + DeclRef<T> operator*() + { + return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr) + .template as<T>(); + } + }; + + Iterator begin() const + { + return Iterator( + this, + adjustFilterCursor<T>(m_filterStyle, m_decls.begin(), m_decls.end()), + m_decls.end(), + m_filterStyle); + } + Iterator end() const { return Iterator(this, m_decls.end(), m_decls.end(), m_filterStyle); } + }; - // TypeExp accept(SyntaxVisitor* visitor); + // + // type Expressions + // - /// A global immutable TypeExp, that has no type or exp set. - static const TypeExp empty; -}; + // A "type expression" is a term that we expect to resolve to a type during checking. + // We store both the original syntax and the resolved type here. + FIDDLE() + struct TypeExp + { + FIDDLE(...) + typedef TypeExp ThisType; -// Masks to be applied when lookup up declarations -enum class LookupMask : uint8_t -{ - type = 0x1, - Function = 0x2, - Value = 0x4, - Attribute = 0x8, - SyntaxDecl = 0x10, - Default = type | Function | Value | SyntaxDecl, -}; - -/// Flags for options to be used when looking up declarations -enum class LookupOptions : uint8_t -{ - None = 0, - IgnoreBaseInterfaces = 1 << 0, - Completion = 1 << 1, ///< Lookup all applicable decls for code completion suggestions - NoDeref = 1 << 2, - ConsiderAllLocalNamesInScope = 1 << 3, - ///^ Normally we rely on the checking state of local names to determine - /// if they have been declared. If the scopes are currently - /// "under-construction" and not being checked, then it's safe to - /// consider all names we've inserted so far. This is used when - /// checking to see if a keyword is shadowed. - IgnoreInheritance = - 1 << 4, ///< Lookup only non inheritance children of a struct (including `extension`) - IgnoreTransparentMembers = 1 << 5, -}; -inline LookupOptions operator&(LookupOptions a, LookupOptions b) -{ - return (LookupOptions)((std::underlying_type_t<LookupOptions>)a & - (std::underlying_type_t<LookupOptions>)b); -} + TypeExp() {} + TypeExp(TypeExp const& other) + : exp(other.exp), type(other.type) + { + } + explicit TypeExp(Expr* exp) + : exp(exp) + { + } + explicit TypeExp(Type* type) + : type(type) + { + } + TypeExp(Expr* exp, Type* type) + : exp(exp), type(type) + { + } -class SerialRefObject; + Expr* exp = nullptr; + Type* type = nullptr; -// Make sure C++ extractor can see the base class. -SLANG_PRE_DECLARE(OBJ, class SerialRefObject) + bool equals(Type* other); -SLANG_TYPE_SET(OBJ, RefObject) -SLANG_TYPE_SET(VALUE, Value) -SLANG_TYPE_SET(AST, ASTNode) + Type* Ptr() { return type; } + operator Type*() { return type; } + Type* operator->() { return Ptr(); } -class LookupResultItem_Breadcrumb : public SerialRefObject -{ -public: - SLANG_OBJ_CLASS(LookupResultItem_Breadcrumb) + ThisType& operator=(const ThisType& rhs) = default; - enum class Kind : uint8_t - { - // The lookup process looked "through" an in-scope - // declaration to the fields inside of it, so that - // even if lookup started with a simple name `f`, - // it needs to result in a member expression `obj.f`. - Member, - - // The lookup process took a pointer(-like) value, and then - // proceeded to derefence it and look at the thing(s) - // it points to instead, so that the final expression - // needs to have `(*obj)` - Deref, - - // The lookup process saw a value `obj` of type `T` and - // took into account an in-scope constraint that says - // `T` is a subtype of some other type `U`, so that - // lookup was able to find a member through type `U` - // instead. - SuperType, - - // The lookup process considered a member of an - // enclosing type as being in scope, so that any - // reference to that member needs to use a `this` - // expression as appropriate. - This, + /// A global immutable TypeExp, that has no type or exp set. + static const TypeExp empty; }; - // The kind of lookup step that was performed - Kind kind; - - // For the `Kind::This` case, what does the implicit - // `this` or `This` parameter refer to? - // - enum class ThisParameterMode : uint8_t + // Masks to be applied when lookup up declarations + enum class LookupMask : uint8_t { - ImmutableValue, // An immutable `this` value - MutableValue, // A mutable `this` value - Type, // A `This` type - - Default = ImmutableValue, + type = 0x1, + Function = 0x2, + Value = 0x4, + Attribute = 0x8, + SyntaxDecl = 0x10, + Default = type | Function | Value | SyntaxDecl, }; - ThisParameterMode thisParameterMode = ThisParameterMode::Default; - - // As needed, a reference to the declaration that faciliated - // the lookup step. - // - // For a `Member` lookup step, this is the declaration whose - // members were implicitly pulled into scope. - // - // For a `Constraint` lookup step, this is the `ConstraintDecl` - // that serves to witness the subtype relationship. - // - DeclRef<Decl> declRef; - - Val* val = nullptr; - - // The next implicit step that the lookup process took to - // arrive at a final value. - RefPtr<LookupResultItem_Breadcrumb> next; - LookupResultItem_Breadcrumb( - Kind kind, - DeclRef<Decl> declRef, - Val* val, - RefPtr<LookupResultItem_Breadcrumb> next, - ThisParameterMode thisParameterMode = ThisParameterMode::Default) - : kind(kind), thisParameterMode(thisParameterMode), declRef(declRef), val(val), next(next) + /// Flags for options to be used when looking up declarations + enum class LookupOptions : uint8_t + { + None = 0, + IgnoreBaseInterfaces = 1 << 0, + Completion = 1 << 1, ///< Lookup all applicable decls for code completion suggestions + NoDeref = 1 << 2, + ConsiderAllLocalNamesInScope = 1 << 3, + ///^ Normally we rely on the checking state of local names to determine + /// if they have been declared. If the scopes are currently + /// "under-construction" and not being checked, then it's safe to + /// consider all names we've inserted so far. This is used when + /// checking to see if a keyword is shadowed. + IgnoreInheritance = + 1 << 4, ///< Lookup only non inheritance children of a struct (including `extension`) + IgnoreTransparentMembers = 1 << 5, + }; + inline LookupOptions operator&(LookupOptions a, LookupOptions b) { + return (LookupOptions)((std::underlying_type_t<LookupOptions>)a & + (std::underlying_type_t<LookupOptions>)b); } -protected: - // Needed for serialization - LookupResultItem_Breadcrumb() = default; -}; - -// Represents one item found during lookup -struct LookupResultItem -{ - SLANG_VALUE_CLASS(LookupResultItem) + class LookupResultItem_Breadcrumb : public RefObject + { + public: + enum class Kind : uint8_t + { + // The lookup process looked "through" an in-scope + // declaration to the fields inside of it, so that + // even if lookup started with a simple name `f`, + // it needs to result in a member expression `obj.f`. + Member, + + // The lookup process took a pointer(-like) value, and then + // proceeded to derefence it and look at the thing(s) + // it points to instead, so that the final expression + // needs to have `(*obj)` + Deref, + + // The lookup process saw a value `obj` of type `T` and + // took into account an in-scope constraint that says + // `T` is a subtype of some other type `U`, so that + // lookup was able to find a member through type `U` + // instead. + SuperType, + + // The lookup process considered a member of an + // enclosing type as being in scope, so that any + // reference to that member needs to use a `this` + // expression as appropriate. + This, + }; + + // The kind of lookup step that was performed + Kind kind; + + // For the `Kind::This` case, what does the implicit + // `this` or `This` parameter refer to? + // + enum class ThisParameterMode : uint8_t + { + ImmutableValue, // An immutable `this` value + MutableValue, // A mutable `this` value + Type, // A `This` type + + Default = ImmutableValue, + }; + ThisParameterMode thisParameterMode = ThisParameterMode::Default; + + // As needed, a reference to the declaration that faciliated + // the lookup step. + // + // For a `Member` lookup step, this is the declaration whose + // members were implicitly pulled into scope. + // + // For a `Constraint` lookup step, this is the `ConstraintDecl` + // that serves to witness the subtype relationship. + // + DeclRef<Decl> declRef; + + Val* val = nullptr; + + // The next implicit step that the lookup process took to + // arrive at a final value. + RefPtr<LookupResultItem_Breadcrumb> next; + + LookupResultItem_Breadcrumb( + Kind kind, + DeclRef<Decl> declRef, + Val* val, + RefPtr<LookupResultItem_Breadcrumb> next, + ThisParameterMode thisParameterMode = ThisParameterMode::Default) + : kind(kind) + , thisParameterMode(thisParameterMode) + , declRef(declRef) + , val(val) + , next(next) + { + } - typedef LookupResultItem_Breadcrumb Breadcrumb; + protected: + // Needed for serialization + LookupResultItem_Breadcrumb() = default; + }; - // Sometimes lookup finds an item, but there were additional - // "hops" taken to reach it. We need to remember these steps - // so that if/when we consturct a full expression we generate - // appropriate AST nodes for all the steps. - // - // We build up a list of these "breadcrumbs" while doing - // lookup, and store them alongside each item found. - // - // As an example, suppose we have an HLSL `cbuffer` declaration: - // - // cbuffer C { float4 f; } - // - // This is syntax sugar for a global-scope variable of - // type `ConstantBuffer<T>` where `T` is a `struct` containing - // all the members: - // - // struct Anon0 { float4 f; }; - // __transparent ConstantBuffer<Anon0> anon1; - // - // The `__transparent` modifier there captures the fact that - // when somebody writes `f` in their code, they expect it to - // "see through" the `cbuffer` declaration (or the global variable, - // in this case) and find the member inside. - // - // But when the user writes `f` we can't just create a simple - // `VarExpr` that refers directly to that field, because that - // doesn't actually reflect the required steps in a way that - // code generation can use. - // - // Instead we need to construct an expression like `(*anon1).f`, - // where there is are two additional steps in the process: - // - // 1. We needed to dereference the pointer-like type `ConstantBuffer<Anon0>` - // to get at a value of type `Anon0` - // 2. We needed to access a sub-field of the aggregate type `Anon0` - // - // We *could* just create these full-formed expressions during - // lookup, but this might mean creating a large number of - // AST nodes in cases where the user calls an overloaded function. - // At the very least we'd rather not heap-allocate in the common - // case where no "extra" steps need to be performed to get to - // the declarations. - // - // This is where "breadcrumbs" come in. A breadcrumb represents - // an extra "step" that must be performed to turn a declaration - // found by lookup into a valid expression to splice into the - // AST. Most of the time lookup result items don't have any - // breadcrumbs, so that no extra heap allocation takes place. - // When an item does have breadcrumbs, and it is chosen as - // the unique result (perhaps by overload resolution), then - // we can walk the list of breadcrumbs to create a full - // expression. - - - // A properly-specialized reference to the declaration that was found. - DeclRef<Decl> declRef; - - // Any breadcrumbs needed in order to turn that declaration - // reference into a well-formed expression. - // - // This is unused in the simple case where a declaration - // is being referenced directly (rather than through - // transparent members). - RefPtr<LookupResultItem_Breadcrumb> breadcrumbs; - - LookupResultItem() = default; - explicit LookupResultItem(DeclRef<Decl> declRef) - : declRef(declRef) - { - } - LookupResultItem(DeclRef<Decl> declRef, RefPtr<Breadcrumb> breadcrumbs) - : declRef(declRef), breadcrumbs(breadcrumbs) - { - } -}; + // Represents one item found during lookup + struct LookupResultItem + { + typedef LookupResultItem_Breadcrumb Breadcrumb; + + // Sometimes lookup finds an item, but there were additional + // "hops" taken to reach it. We need to remember these steps + // so that if/when we consturct a full expression we generate + // appropriate AST nodes for all the steps. + // + // We build up a list of these "breadcrumbs" while doing + // lookup, and store them alongside each item found. + // + // As an example, suppose we have an HLSL `cbuffer` declaration: + // + // cbuffer C { float4 f; } + // + // This is syntax sugar for a global-scope variable of + // type `ConstantBuffer<T>` where `T` is a `struct` containing + // all the members: + // + // struct Anon0 { float4 f; }; + // __transparent ConstantBuffer<Anon0> anon1; + // + // The `__transparent` modifier there captures the fact that + // when somebody writes `f` in their code, they expect it to + // "see through" the `cbuffer` declaration (or the global variable, + // in this case) and find the member inside. + // + // But when the user writes `f` we can't just create a simple + // `VarExpr` that refers directly to that field, because that + // doesn't actually reflect the required steps in a way that + // code generation can use. + // + // Instead we need to construct an expression like `(*anon1).f`, + // where there is are two additional steps in the process: + // + // 1. We needed to dereference the pointer-like type `ConstantBuffer<Anon0>` + // to get at a value of type `Anon0` + // 2. We needed to access a sub-field of the aggregate type `Anon0` + // + // We *could* just create these full-formed expressions during + // lookup, but this might mean creating a large number of + // AST nodes in cases where the user calls an overloaded function. + // At the very least we'd rather not heap-allocate in the common + // case where no "extra" steps need to be performed to get to + // the declarations. + // + // This is where "breadcrumbs" come in. A breadcrumb represents + // an extra "step" that must be performed to turn a declaration + // found by lookup into a valid expression to splice into the + // AST. Most of the time lookup result items don't have any + // breadcrumbs, so that no extra heap allocation takes place. + // When an item does have breadcrumbs, and it is chosen as + // the unique result (perhaps by overload resolution), then + // we can walk the list of breadcrumbs to create a full + // expression. + + + // A properly-specialized reference to the declaration that was found. + DeclRef<Decl> declRef; + + // Any breadcrumbs needed in order to turn that declaration + // reference into a well-formed expression. + // + // This is unused in the simple case where a declaration + // is being referenced directly (rather than through + // transparent members). + RefPtr<LookupResultItem_Breadcrumb> breadcrumbs; + + LookupResultItem() = default; + explicit LookupResultItem(DeclRef<Decl> declRef) + : declRef(declRef) + { + } + LookupResultItem(DeclRef<Decl> declRef, RefPtr<Breadcrumb> breadcrumbs) + : declRef(declRef), breadcrumbs(breadcrumbs) + { + } + }; -// Result of looking up a name in some lexical/semantic environment. -// Can be used to enumerate all the declarations matching that name, -// in the case where the result is overloaded. -struct LookupResult -{ - // The one item that was found, in the simple case - LookupResultItem item; + // Result of looking up a name in some lexical/semantic environment. + // Can be used to enumerate all the declarations matching that name, + // in the case where the result is overloaded. + struct LookupResult + { + // The one item that was found, in the simple case + LookupResultItem item; - // All of the items that were found, in the complex case. - // Note: if there was no overloading, then this list isn't - // used at all, to avoid allocation. - // - // Additionally, if `items` is used, then `item` *must* hold an item that - // is also in the items list (typically the first entry), as an invariant. - // Otherwise isValid/begin will not function correctly. - List<LookupResultItem> items; + // All of the items that were found, in the complex case. + // Note: if there was no overloading, then this list isn't + // used at all, to avoid allocation. + // + // Additionally, if `items` is used, then `item` *must* hold an item that + // is also in the items list (typically the first entry), as an invariant. + // Otherwise isValid/begin will not function correctly. + List<LookupResultItem> items; - // Was at least one result found? - bool isValid() const { return item.declRef.getDecl() != nullptr; } + // Was at least one result found? + bool isValid() const { return item.declRef.getDecl() != nullptr; } - bool isOverloaded() const { return items.getCount() > 1; } + bool isOverloaded() const { return items.getCount() > 1; } - Name* getName() const - { - return items.getCount() > 1 ? items[0].declRef.getName() : item.declRef.getName(); - } - LookupResultItem* begin() const - { - if (isValid()) + Name* getName() const { - if (isOverloaded()) - return const_cast<LookupResultItem*>(items.begin()); + return items.getCount() > 1 ? items[0].declRef.getName() : item.declRef.getName(); + } + LookupResultItem* begin() const + { + if (isValid()) + { + if (isOverloaded()) + return const_cast<LookupResultItem*>(items.begin()); + else + return const_cast<LookupResultItem*>(&item); + } else - return const_cast<LookupResultItem*>(&item); + return nullptr; } - else - return nullptr; - } - LookupResultItem* end() const - { - if (isValid()) + LookupResultItem* end() const { - if (isOverloaded()) - return const_cast<LookupResultItem*>(items.end()); + if (isValid()) + { + if (isOverloaded()) + return const_cast<LookupResultItem*>(items.end()); + else + return const_cast<LookupResultItem*>(&item + 1); + } else - return const_cast<LookupResultItem*>(&item + 1); + return nullptr; } - else - return nullptr; - } -}; - -// A helper to avoid having to include slang-check-impl.h in slang-syntax.h -struct SemanticsVisitor; -ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor*); - -struct LookupRequest -{ - SemanticsVisitor* semantics = nullptr; - Scope* scope = nullptr; - Scope* endScope = nullptr; + }; - // A decl to exclude from the lookup, used to exclude the current decl being checked, such as in - // typedef Foo Foo; to avoid finding itself. - Decl* declToExclude = nullptr; - LookupMask mask = LookupMask::Default; - LookupOptions options = LookupOptions::None; + // A helper to avoid having to include slang-check-impl.h in slang-syntax.h + struct SemanticsVisitor; + ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor*); - bool isCompletionRequest() const - { - return (options & LookupOptions::Completion) != LookupOptions::None; - } - bool shouldConsiderAllLocalNames() const + struct LookupRequest { - return (options & LookupOptions::ConsiderAllLocalNamesInScope) != LookupOptions::None; - } -}; + SemanticsVisitor* semantics = nullptr; + Scope* scope = nullptr; + Scope* endScope = nullptr; -struct WitnessTable; + // A decl to exclude from the lookup, used to exclude the current decl being checked, such + // as in typedef Foo Foo; to avoid finding itself. + Decl* declToExclude = nullptr; + LookupMask mask = LookupMask::Default; + LookupOptions options = LookupOptions::None; -// A value that witnesses the satisfaction of an interface -// requirement by a particular declaration or value. -struct RequirementWitness -{ - SLANG_VALUE_CLASS(RequirementWitness) + bool isCompletionRequest() const + { + return (options & LookupOptions::Completion) != LookupOptions::None; + } + bool shouldConsiderAllLocalNames() const + { + return (options & LookupOptions::ConsiderAllLocalNamesInScope) != LookupOptions::None; + } + }; - RequirementWitness() - : m_flavor(Flavor::none) - { - } + class WitnessTable; - RequirementWitness(DeclRefBase* declRef) - : m_flavor(Flavor::declRef), m_declRef(declRef) + // A value that witnesses the satisfaction of an interface + // requirement by a particular declaration or value. + struct RequirementWitness { - } + RequirementWitness() + : m_flavor(Flavor::none) + { + } - RequirementWitness(Val* val); + RequirementWitness(DeclRefBase* declRef) + : m_flavor(Flavor::declRef), m_declRef(declRef) + { + } - RequirementWitness(RefPtr<WitnessTable> witnessTable); + RequirementWitness(Val* val); - enum class Flavor - { - none, - declRef, - val, - witnessTable, - }; + RequirementWitness(RefPtr<WitnessTable> witnessTable); - Flavor getFlavor() const { return m_flavor; } + enum class Flavor + { + none, + declRef, + val, + witnessTable, + }; - DeclRef<Decl> getDeclRef() - { - SLANG_ASSERT(getFlavor() == Flavor::declRef); - return m_declRef; - } + Flavor getFlavor() const { return m_flavor; } - Val* getVal() - { - SLANG_ASSERT(getFlavor() == Flavor::val); - return m_val; - } + DeclRef<Decl> getDeclRef() + { + SLANG_ASSERT(getFlavor() == Flavor::declRef); + return m_declRef; + } - RefPtr<WitnessTable> getWitnessTable(); + Val* getVal() + { + SLANG_ASSERT(getFlavor() == Flavor::val); + return m_val; + } - RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + RefPtr<WitnessTable> getWitnessTable(); - Flavor m_flavor; - DeclRef<Decl> m_declRef; - RefPtr<RefObject> m_obj; - Val* m_val = nullptr; -}; + RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); -typedef OrderedDictionary<Decl*, RequirementWitness> RequirementDictionary; + Flavor m_flavor; + DeclRef<Decl> m_declRef; + RefPtr<RefObject> m_obj; + Val* m_val = nullptr; + }; -struct WitnessTable : SerialRefObject -{ - SLANG_OBJ_CLASS(WitnessTable) + typedef OrderedDictionary<Decl*, RequirementWitness> RequirementDictionary; - const RequirementDictionary& getRequirementDictionary() { return m_requirementDictionary; } + FIDDLE() + class WitnessTable : public RefObject + { + FIDDLE(...) + const RequirementDictionary& getRequirementDictionary() { return m_requirementDictionary; } - void add(Decl* decl, RequirementWitness const& witness); + void add(Decl* decl, RequirementWitness const& witness); - // The type that the witness table witnesses conformance to (e.g. an Interface) - Type* baseType; + // The type that the witness table witnesses conformance to (e.g. an Interface) + Type* baseType; - // The type witnessesd by the witness table (a concrete type). - Type* witnessedType; + // The type witnessesd by the witness table (a concrete type). + Type* witnessedType; - // Whether or not this witness table is an extern declaration. - bool isExtern = false; + // Whether or not this witness table is an extern declaration. + bool isExtern = false; - // Cached dictionary for looking up satisfying values. - RequirementDictionary m_requirementDictionary; + // Cached dictionary for looking up satisfying values. + RequirementDictionary m_requirementDictionary; - RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); -}; + RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + }; -struct SpecializationParam -{ - enum class Flavor + struct SpecializationParam { - GenericType, - GenericValue, - ExistentialType, - ExistentialValue, + enum class Flavor + { + GenericType, + GenericValue, + ExistentialType, + ExistentialValue, + }; + Flavor flavor; + SourceLoc loc; + NodeBase* object = nullptr; }; - Flavor flavor; - SourceLoc loc; - NodeBase* object = nullptr; -}; -typedef List<SpecializationParam> SpecializationParams; + typedef List<SpecializationParam> SpecializationParams; -struct SpecializationArg -{ - SLANG_VALUE_CLASS(SpecializationArg) - Val* val = nullptr; -}; -typedef List<SpecializationArg> SpecializationArgs; + struct SpecializationArg + { + Val* val = nullptr; + }; + typedef List<SpecializationArg> SpecializationArgs; -struct ExpandedSpecializationArg : SpecializationArg -{ - SLANG_VALUE_CLASS(ExpandedSpecializationArg) - Val* witness = nullptr; -}; -typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs; - -/// A reference-counted object to hold a list of candidate extensions -/// that might be applicable to a type based on its declaration. -/// -struct CandidateExtensionList : RefObject -{ - List<ExtensionDecl*> candidateExtensions; -}; + struct ExpandedSpecializationArg : SpecializationArg + { + Val* witness = nullptr; + }; + typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs; + /// A reference-counted object to hold a list of candidate extensions + /// that might be applicable to a type based on its declaration. + /// + FIDDLE() + class CandidateExtensionList : public RefObject + { + FIDDLE(...) + List<ExtensionDecl*> candidateExtensions; + }; -enum class DeclAssociationKind -{ - ForwardDerivativeFunc, - BackwardDerivativeFunc, - PrimalSubstituteFunc -}; -struct DeclAssociation : SerialRefObject -{ - SLANG_OBJ_CLASS(DeclAssociation) - DeclAssociationKind kind; - Decl* decl; -}; - -/// A reference-counted object to hold a list of associated decls for a decl. -/// -struct DeclAssociationList : SerialRefObject -{ - SLANG_OBJ_CLASS(DeclAssociationList) + enum class DeclAssociationKind + { + ForwardDerivativeFunc, + BackwardDerivativeFunc, + PrimalSubstituteFunc + }; - List<RefPtr<DeclAssociation>> associations; -}; + FIDDLE() + class DeclAssociation : public RefObject + { + FIDDLE(...) + DeclAssociationKind kind; + Decl* decl; + }; -/// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` -enum ParameterDirection -{ - kParameterDirection_In, ///< Copy in - kParameterDirection_Out, ///< Copy out - kParameterDirection_InOut, ///< Copy in, copy out - kParameterDirection_Ref, ///< By-reference - kParameterDirection_ConstRef, ///< By-const-reference -}; + /// A reference-counted object to hold a list of associated decls for a decl. + /// + FIDDLE() + class DeclAssociationList : public RefObject + { + FIDDLE(...) + List<RefPtr<DeclAssociation>> associations; + }; -void printDiagnosticArg(StringBuilder& sb, ParameterDirection direction); + /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` + enum ParameterDirection + { + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference + kParameterDirection_ConstRef, ///< By-const-reference + }; -/// The kind of a builtin interface requirement that can be automatically synthesized. -enum class BuiltinRequirementKind -{ - DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method - - DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement - DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement - DZeroFunc, ///< The `IDifferentiable.dzero` function requirement - DAddFunc, ///< The `IDifferentiable.dadd` function requirement - DMulFunc, ///< The `IDifferentiable.dmul` function requirement - - InitLogicalFromInt, ///< The `ILogical.__init` mtehod. - Equals, ///< The `ILogical.equals` mtehod. - LessThan, ///< The `ILogical.lessThan` mtehod. - LessThanOrEquals, ///< The `ILogical.lessThanOrEquals` mtehod. - Shl, ///< The `ILogical.shl` mtehod. - Shr, ///< The `ILogical.shr` mtehod. - BitAnd, ///< The `ILogical.bitAnd` mtehod. - BitOr, ///< The `ILogical.bitOr` mtehod. - BitXor, ///< The `ILogical.bitXor` mtehod. - BitNot, ///< The `ILogical.bitNot` mtehod. - And, ///< The `ILogical.and` mtehod. - Or, ///< The `ILogical.or` mtehod. - Not, ///< The `ILogical.not` mtehod. -}; - -enum class FunctionDifferentiableLevel -{ - None, - Forward, - Backward -}; + void printDiagnosticArg(StringBuilder & sb, ParameterDirection direction); + + /// The kind of a builtin interface requirement that can be automatically synthesized. + enum class BuiltinRequirementKind + { + DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method + + DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement + DZeroFunc, ///< The `IDifferentiable.dzero` function requirement + DAddFunc, ///< The `IDifferentiable.dadd` function requirement + DMulFunc, ///< The `IDifferentiable.dmul` function requirement + + InitLogicalFromInt, ///< The `ILogical.__init` mtehod. + Equals, ///< The `ILogical.equals` mtehod. + LessThan, ///< The `ILogical.lessThan` mtehod. + LessThanOrEquals, ///< The `ILogical.lessThanOrEquals` mtehod. + Shl, ///< The `ILogical.shl` mtehod. + Shr, ///< The `ILogical.shr` mtehod. + BitAnd, ///< The `ILogical.bitAnd` mtehod. + BitOr, ///< The `ILogical.bitOr` mtehod. + BitXor, ///< The `ILogical.bitXor` mtehod. + BitNot, ///< The `ILogical.bitNot` mtehod. + And, ///< The `ILogical.and` mtehod. + Or, ///< The `ILogical.or` mtehod. + Not, ///< The `ILogical.not` mtehod. + }; -/// Represents a markup (documentation) associated with a decl. -struct MarkupEntry : public SerialRefObject -{ - SLANG_OBJ_CLASS(MarkupEntry) + enum class FunctionDifferentiableLevel + { + None, + Forward, + Backward + }; - NodeBase* m_node; ///< The node this documentation is associated with - String m_markup; ///< The raw contents of of markup associated with the decoration - MarkupVisibility m_visibility = MarkupVisibility::Public; ///< How visible this decl is -}; + /// Represents a markup (documentation) associated with a decl. + FIDDLE() + class MarkupEntry : public RefObject + { + FIDDLE(...) + NodeBase* m_node; ///< The node this documentation is associated with + String m_markup; ///< The raw contents of of markup associated with the decoration + MarkupVisibility m_visibility = MarkupVisibility::Public; ///< How visible this decl is + }; -/// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s -/// inner most expr is `f`. -Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel); -inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) -{ - FunctionDifferentiableLevel level; - return getInnerMostExprFromHigherOrderExpr(expr, level); -} + /// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s + /// inner most expr is `f`. + Expr* getInnerMostExprFromHigherOrderExpr( + Expr * expr, + FunctionDifferentiableLevel & outDiffLevel); + inline Expr* getInnerMostExprFromHigherOrderExpr(Expr * expr) + { + FunctionDifferentiableLevel level; + return getInnerMostExprFromHigherOrderExpr(expr, level); + } -/// Get the operator name from the higher order invoke expr. -UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); + /// Get the operator name from the higher order invoke expr. + UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr * expr); -enum class DeclVisibility -{ - Private, - Internal, - Public, - Default = Internal, -}; + enum class DeclVisibility + { + Private, + Internal, + Public, + Default = Internal, + }; } // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 29a52a93a..ff4cc0d10 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1,7 +1,9 @@ // slang-ast-type.cpp +#include "slang-ast-type.h" + #include "slang-ast-builder.h" +#include "slang-ast-dispatch.h" #include "slang-ast-modifier.h" -#include "slang-generated-ast-macro.h" #include "slang-syntax.h" #include <assert.h> diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 7393092f9..dd4d2acd6 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -1,19 +1,20 @@ // slang-ast-type.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-type.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for types. // The type of a reference to an overloaded name +FIDDLE() class OverloadGroupType : public Type { - SLANG_AST_CLASS(OverloadGroupType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); @@ -21,20 +22,20 @@ class OverloadGroupType : public Type // The type of an initializer-list expression (before it has // been coerced to some other type) +FIDDLE() class InitializerListType : public Type { - SLANG_AST_CLASS(InitializerListType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); }; // The type of an expression that was erroneous +FIDDLE() class ErrorType : public Type { - SLANG_AST_CLASS(ErrorType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); @@ -42,20 +43,20 @@ class ErrorType : public Type }; // The bottom/empty type that has no values. +FIDDLE() class BottomType : public Type { - SLANG_AST_CLASS(BottomType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A type that takes the form of a reference to some declaration +FIDDLE() class DeclRefType : public Type { - SLANG_AST_CLASS(DeclRefType) - + FIDDLE(...) static Type* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); DeclRef<Decl> getDeclRef() const { return DeclRef<Decl>(as<DeclRefBase>(getOperand(0))); } @@ -83,20 +84,20 @@ bool isTypePack(Type* type); bool isAbstractTypePack(Type* type); // Base class for types that can be used in arithmetic expressions +FIDDLE(abstract) class ArithmeticExpressionType : public DeclRefType { - SLANG_ABSTRACT_AST_CLASS(ArithmeticExpressionType) - + FIDDLE(...) BasicExpressionType* getScalarType(); // Overrides should be public so base classes can access BasicExpressionType* _getScalarTypeOverride(); }; +FIDDLE() class BasicExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(BasicExpressionType) - + FIDDLE(...) BaseType getBaseType() const; // Overrides should be public so base classes can access @@ -108,45 +109,52 @@ class BasicExpressionType : public ArithmeticExpressionType // Base type for things that are built in to the compiler, // and will usually have special behavior or a custom // mapping to the IR level. +FIDDLE(abstract) class BuiltinType : public DeclRefType { - SLANG_ABSTRACT_AST_CLASS(BuiltinType) + FIDDLE(...) }; +FIDDLE(abstract) class DataLayoutType : public BuiltinType { - SLANG_ABSTRACT_AST_CLASS(DataLayoutType) + FIDDLE(...) }; +FIDDLE() class IBufferDataLayoutType : public BuiltinType { - SLANG_AST_CLASS(IBufferDataLayoutType) + FIDDLE(...) }; +FIDDLE() class DefaultDataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(DefaultDataLayoutType) + FIDDLE(...) }; +FIDDLE() class Std430DataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(Std430DataLayoutType) + FIDDLE(...) }; +FIDDLE() class Std140DataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(Std140DataLayoutType) + FIDDLE(...) }; +FIDDLE() class ScalarDataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(ScalarDataLayoutType) + FIDDLE(...) }; +FIDDLE() class FeedbackType : public BuiltinType { - SLANG_AST_CLASS(FeedbackType) - + FIDDLE(...) enum class Kind : uint8_t { MinMip, /// SAMPLER_FEEDBACK_MIN_MIP @@ -156,37 +164,43 @@ class FeedbackType : public BuiltinType Kind getKind() const; }; +FIDDLE(abstract) class TextureShapeType : public BuiltinType { - SLANG_ABSTRACT_AST_CLASS(TextureShapeType) + FIDDLE(...) }; +FIDDLE() class TextureShape1DType : public TextureShapeType { - SLANG_AST_CLASS(TextureShape1DType) + FIDDLE(...) }; +FIDDLE() class TextureShape2DType : public TextureShapeType { - SLANG_AST_CLASS(TextureShape2DType) + FIDDLE(...) }; +FIDDLE() class TextureShape3DType : public TextureShapeType { - SLANG_AST_CLASS(TextureShape3DType) + FIDDLE(...) }; +FIDDLE() class TextureShapeCubeType : public TextureShapeType { - SLANG_AST_CLASS(TextureShapeCubeType) + FIDDLE(...) }; +FIDDLE() class TextureShapeBufferType : public TextureShapeType { - SLANG_AST_CLASS(TextureShapeBufferType) + FIDDLE(...) }; // Resources that contain "elements" that can be fetched +FIDDLE(abstract) class ResourceType : public BuiltinType { - SLANG_ABSTRACT_AST_CLASS(ResourceType) - + FIDDLE(...) bool isMultisample(); bool isArray(); bool isShadow(); @@ -199,286 +213,322 @@ class ResourceType : public BuiltinType void _toTextOverride(StringBuilder& out); }; +FIDDLE(abstract) class TextureTypeBase : public ResourceType { - SLANG_ABSTRACT_AST_CLASS(TextureTypeBase) - + FIDDLE(...) Val* getSampleCount(); Val* getFormat(); }; +FIDDLE() class TextureType : public TextureTypeBase { - SLANG_AST_CLASS(TextureType) + FIDDLE(...) }; // This is a base type for `image*` types, as they exist in GLSL +FIDDLE() class GLSLImageType : public TextureTypeBase { - SLANG_AST_CLASS(GLSLImageType) + FIDDLE(...) }; +FIDDLE() class SubpassInputType : public BuiltinType { - SLANG_AST_CLASS(SubpassInputType) - + FIDDLE(...) bool isMultisample(); Type* getElementType(); }; +FIDDLE() class SamplerStateType : public BuiltinType { - SLANG_AST_CLASS(SamplerStateType) - + FIDDLE(...) // Returns flavor of sampler state of this type. SamplerStateFlavor getFlavor() const; }; // Other cases of generic types known to the compiler +FIDDLE() class BuiltinGenericType : public BuiltinType { - SLANG_AST_CLASS(BuiltinGenericType) - + FIDDLE(...) Type* getElementType() const; }; // Types that behave like pointers, in that they can be // dereferenced (implicitly) to access members defined // in the element type. +FIDDLE(abstract) class PointerLikeType : public BuiltinGenericType { - SLANG_AST_CLASS(PointerLikeType) + FIDDLE(...) }; +FIDDLE() class DynamicResourceType : public BuiltinType { - SLANG_AST_CLASS(DynamicResourceType) + FIDDLE(...) }; // HLSL buffer-type resources +FIDDLE(abstract) class HLSLStructuredBufferTypeBase : public BuiltinGenericType { - SLANG_AST_CLASS(HLSLStructuredBufferTypeBase) + FIDDLE(...) }; +FIDDLE() class HLSLStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRWStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLRWStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRasterizerOrderedStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLRasterizerOrderedStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class UntypedBufferResourceType : public BuiltinType { - SLANG_AST_CLASS(UntypedBufferResourceType) + FIDDLE(...) }; +FIDDLE() class HLSLByteAddressBufferType : public UntypedBufferResourceType { - SLANG_AST_CLASS(HLSLByteAddressBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRWByteAddressBufferType : public UntypedBufferResourceType { - SLANG_AST_CLASS(HLSLRWByteAddressBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRasterizerOrderedByteAddressBufferType : public UntypedBufferResourceType { - SLANG_AST_CLASS(HLSLRasterizerOrderedByteAddressBufferType) + FIDDLE(...) }; +FIDDLE() class RaytracingAccelerationStructureType : public UntypedBufferResourceType { - SLANG_AST_CLASS(RaytracingAccelerationStructureType) + FIDDLE(...) }; +FIDDLE() class HLSLAppendStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLAppendStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLConsumeStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLConsumeStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class GLSLAtomicUintType : public BuiltinType { - SLANG_AST_CLASS(GLSLAtomicUintType) + FIDDLE(...) }; +FIDDLE() class HLSLPatchType : public BuiltinType { - SLANG_AST_CLASS(HLSLPatchType) - + FIDDLE(...) Type* getElementType(); IntVal* getElementCount(); }; +FIDDLE() class HLSLInputPatchType : public HLSLPatchType { - SLANG_AST_CLASS(HLSLInputPatchType) + FIDDLE(...) }; +FIDDLE() class HLSLOutputPatchType : public HLSLPatchType { - SLANG_AST_CLASS(HLSLOutputPatchType) + FIDDLE(...) }; // HLSL geometry shader output stream types +FIDDLE() class HLSLStreamOutputType : public BuiltinGenericType { - SLANG_AST_CLASS(HLSLStreamOutputType) + FIDDLE(...) }; +FIDDLE() class HLSLPointStreamType : public HLSLStreamOutputType { - SLANG_AST_CLASS(HLSLPointStreamType) + FIDDLE(...) }; +FIDDLE() class HLSLLineStreamType : public HLSLStreamOutputType { - SLANG_AST_CLASS(HLSLLineStreamType) + FIDDLE(...) }; +FIDDLE() class HLSLTriangleStreamType : public HLSLStreamOutputType { - SLANG_AST_CLASS(HLSLTriangleStreamType) + FIDDLE(...) }; // mesh shader output types +FIDDLE() class MeshOutputType : public BuiltinGenericType { - SLANG_AST_CLASS(MeshOutputType) - + FIDDLE(...) Type* getElementType(); IntVal* getMaxElementCount(); }; +FIDDLE() class VerticesType : public MeshOutputType { - SLANG_AST_CLASS(VerticesType) + FIDDLE(...) }; +FIDDLE() class IndicesType : public MeshOutputType { - SLANG_AST_CLASS(IndicesType) + FIDDLE(...) }; +FIDDLE() class PrimitivesType : public MeshOutputType { - SLANG_AST_CLASS(PrimitivesType) + FIDDLE(...) }; // +FIDDLE() class GLSLInputAttachmentType : public BuiltinType { - SLANG_AST_CLASS(GLSLInputAttachmentType) + FIDDLE(...) }; +FIDDLE() class DescriptorHandleType : public PointerLikeType { - SLANG_AST_CLASS(DescriptorHandleType) + FIDDLE(...) }; // Base class for types used when desugaring parameter block // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. +FIDDLE(abstract) class ParameterGroupType : public PointerLikeType { - SLANG_AST_CLASS(ParameterGroupType) + FIDDLE(...) }; +FIDDLE() class UniformParameterGroupType : public ParameterGroupType { - SLANG_AST_CLASS(UniformParameterGroupType) + FIDDLE(...) Type* getLayoutType(); }; +FIDDLE() class VaryingParameterGroupType : public ParameterGroupType { - SLANG_AST_CLASS(VaryingParameterGroupType) + FIDDLE(...) }; // type for HLSL `cbuffer` declarations, and `ConstantBuffer<T>` // ALso used for GLSL `uniform` blocks. +FIDDLE() class ConstantBufferType : public UniformParameterGroupType { - SLANG_AST_CLASS(ConstantBufferType) + FIDDLE(...) }; // type for HLSL `tbuffer` declarations, and `TextureBuffer<T>` +FIDDLE() class TextureBufferType : public UniformParameterGroupType { - SLANG_AST_CLASS(TextureBufferType) + FIDDLE(...) }; // type for GLSL `in` and `out` blocks +FIDDLE() class GLSLInputParameterGroupType : public VaryingParameterGroupType { - SLANG_AST_CLASS(GLSLInputParameterGroupType) + FIDDLE(...) }; +FIDDLE() class GLSLOutputParameterGroupType : public VaryingParameterGroupType { - SLANG_AST_CLASS(GLSLOutputParameterGroupType) + FIDDLE(...) }; // type for GLSL `buffer` blocks +FIDDLE() class GLSLShaderStorageBufferType : public PointerLikeType { - SLANG_AST_CLASS(GLSLShaderStorageBufferType) + FIDDLE(...) }; // type for Slang `ParameterBlock<T>` type +FIDDLE() class ParameterBlockType : public UniformParameterGroupType { - SLANG_AST_CLASS(ParameterBlockType) + FIDDLE(...) }; +FIDDLE() class ArrayExpressionType : public DeclRefType { - SLANG_AST_CLASS(ArrayExpressionType) - + FIDDLE(...) bool isUnsized(); void _toTextOverride(StringBuilder& out); Type* getElementType(); IntVal* getElementCount(); }; +FIDDLE() class AtomicType : public DeclRefType { - SLANG_AST_CLASS(AtomicType) - + FIDDLE(...) Type* getElementType(); }; +FIDDLE() class CoopVectorExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(CoopVectorExpressionType) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); BasicExpressionType* _getScalarTypeOverride(); @@ -489,10 +539,10 @@ class CoopVectorExpressionType : public ArithmeticExpressionType // The "type" of an expression that resolves to a type. // For example, in the expression `float(2)` the sub-expression, // `float` would have the type `TypeType(float)`. +FIDDLE() class TypeType : public Type { - SLANG_AST_CLASS(TypeType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); @@ -503,38 +553,43 @@ class TypeType : public Type }; // A differential pair type, e.g., `__DifferentialPair<T>` +FIDDLE() class DifferentialPairType : public ArithmeticExpressionType { - SLANG_AST_CLASS(DifferentialPairType) + FIDDLE(...) Type* getPrimalType(); }; +FIDDLE() class DifferentialPtrPairType : public ArithmeticExpressionType { - SLANG_AST_CLASS(DifferentialPtrPairType) + FIDDLE(...) Type* getPrimalRefType(); }; +FIDDLE() class DifferentiableType : public BuiltinType { - SLANG_AST_CLASS(DifferentiableType) + FIDDLE(...) }; +FIDDLE() class DifferentiablePtrType : public BuiltinType { - SLANG_AST_CLASS(DifferentiablePtrType) + FIDDLE(...) }; +FIDDLE() class DefaultInitializableType : public BuiltinType { - SLANG_AST_CLASS(DefaultInitializableType); + FIDDLE(...) }; // A vector type, e.g., `vector<T,N>` +FIDDLE() class VectorExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(VectorExpressionType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); BasicExpressionType* _getScalarTypeOverride(); @@ -544,10 +599,10 @@ class VectorExpressionType : public ArithmeticExpressionType }; // A matrix type, e.g., `matrix<T,R,C,L>` +FIDDLE() class MatrixExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(MatrixExpressionType) - + FIDDLE(...) Type* getElementType(); IntVal* getRowCount(); IntVal* getColumnCount(); @@ -563,137 +618,152 @@ private: SLANG_UNREFLECTED Type* rowType = nullptr; }; +FIDDLE() class TensorViewType : public BuiltinType { - SLANG_AST_CLASS(TensorViewType) - + FIDDLE(...) Type* getElementType(); }; // Base class for built in string types +FIDDLE(abstract) class StringTypeBase : public BuiltinType { - SLANG_AST_CLASS(StringTypeBase) + FIDDLE(...) }; // The regular built-in `String` type +FIDDLE() class StringType : public StringTypeBase { - SLANG_AST_CLASS(StringType) + FIDDLE(...) }; // The string type native to the target +FIDDLE() class NativeStringType : public StringTypeBase { - SLANG_AST_CLASS(NativeStringType) + FIDDLE(...) }; // The built-in `__Dynamic` type +FIDDLE() class DynamicType : public BuiltinType { - SLANG_AST_CLASS(DynamicType) + FIDDLE(...) }; // Type built-in `__EnumType` type +FIDDLE() class EnumTypeType : public BuiltinType { - SLANG_AST_CLASS(EnumTypeType) - + FIDDLE(...) // TODO: provide accessors for the declaration, the "tag" type, etc. }; // Base class for types that map down to // simple pointers as part of code generation. +FIDDLE() class PtrTypeBase : public BuiltinType { - SLANG_AST_CLASS(PtrTypeBase) - + FIDDLE(...) // Get the type of the pointed-to value. Type* getValueType(); Val* getAddressSpace(); }; +FIDDLE() class NoneType : public BuiltinType { - SLANG_AST_CLASS(NoneType) + FIDDLE(...) }; +FIDDLE() class NullPtrType : public BuiltinType { - SLANG_AST_CLASS(NullPtrType) + FIDDLE(...) }; // A true (user-visible) pointer type, e.g., `T*` +FIDDLE() class PtrType : public PtrTypeBase { - SLANG_AST_CLASS(PtrType) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); }; /// A pointer-like type used to represent a parameter "direction" +FIDDLE() class ParamDirectionType : public PtrTypeBase { - SLANG_AST_CLASS(ParamDirectionType) + FIDDLE(...) }; // A type that represents the behind-the-scenes // logical pointer that is passed for an `out` // or `in out` parameter +FIDDLE(abstract) class OutTypeBase : public ParamDirectionType { - SLANG_AST_CLASS(OutTypeBase) + FIDDLE(...) }; // The type for an `out` parameter, e.g., `out T` +FIDDLE() class OutType : public OutTypeBase { - SLANG_AST_CLASS(OutType) + FIDDLE(...) }; // The type for an `in out` parameter, e.g., `in out T` +FIDDLE() class InOutType : public OutTypeBase { - SLANG_AST_CLASS(InOutType) + FIDDLE(...) }; +FIDDLE(abstract) class RefTypeBase : public ParamDirectionType { - SLANG_AST_CLASS(RefTypeBase) + FIDDLE(...) }; // The type for an `ref` parameter, e.g., `ref T` +FIDDLE() class RefType : public RefTypeBase { - SLANG_AST_CLASS(RefType) + FIDDLE(...) void _toTextOverride(StringBuilder& out); }; // The type for an `constref` parameter, e.g., `constref T` +FIDDLE() class ConstRefType : public RefTypeBase { - SLANG_AST_CLASS(ConstRefType) + FIDDLE(...) }; +FIDDLE() class OptionalType : public BuiltinType { - SLANG_AST_CLASS(OptionalType) + FIDDLE(...) Type* getValueType(); }; // A raw-pointer reference to an managed value. +FIDDLE() class NativeRefType : public BuiltinType { - SLANG_AST_CLASS(NativeRefType) + FIDDLE(...) Type* getValueType(); }; // A type alias of some kind (e.g., via `typedef`) +FIDDLE() class NamedExpressionType : public Type { - SLANG_AST_CLASS(NamedExpressionType) - + FIDDLE(...) DeclRef<TypeDefDecl> getDeclRef() { return as<DeclRefBase>(getOperand(0)); } // Overrides should be public so base classes can access @@ -705,10 +775,10 @@ class NamedExpressionType : public Type // A function type is defined by its parameter types // and its result type. +FIDDLE() class FuncType : public Type { - SLANG_AST_CLASS(FuncType) - + FIDDLE(...) // Construct a unary function FuncType(Type* paramType, Type* resultType, Type* errorType) { @@ -739,18 +809,19 @@ class FuncType : public Type }; // A tuple is a product of its member types +FIDDLE() class TupleType : public DeclRefType { - SLANG_AST_CLASS(TupleType) - + FIDDLE(...) Index getMemberCount() const; Type* getMember(Index i) const; Type* getTypePack() const; }; +FIDDLE() class EachType : public Type { - SLANG_AST_CLASS(EachType) + FIDDLE(...) Type* getElementType() const { return as<Type>(getOperand(0)); } DeclRefType* getElementDeclRefType() const { return as<DeclRefType>(getOperand(0)); } @@ -760,9 +831,10 @@ class EachType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class ExpandType : public Type { - SLANG_AST_CLASS(ExpandType) + FIDDLE(...) Type* getPatternType() const { return as<Type>(getOperand(0)); } Index getCapturedTypePackCount() { return getOperandCount() - 1; } Type* getCapturedTypePack(Index i) { return as<Type>(getOperand(i + 1)); } @@ -778,9 +850,10 @@ class ExpandType : public Type }; // A concrete pack of types. +FIDDLE() class ConcreteTypePack : public Type { - SLANG_AST_CLASS(ConcreteTypePack) + FIDDLE(...) ConcreteTypePack(ArrayView<Type*> types) { for (auto t : types) @@ -794,10 +867,10 @@ class ConcreteTypePack : public Type }; // The "type" of an expression that names a generic declaration. +FIDDLE() class GenericDeclRefType : public Type { - SLANG_AST_CLASS(GenericDeclRefType) - + FIDDLE(...) DeclRef<GenericDecl> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } // Overrides should be public so base classes can access @@ -808,10 +881,10 @@ class GenericDeclRefType : public Type }; // The "type" of a reference to a module or namespace +FIDDLE() class NamespaceType : public Type { - SLANG_AST_CLASS(NamespaceType) - + FIDDLE(...) DeclRef<NamespaceDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } NamespaceType(DeclRef<NamespaceDeclBase> inDeclRef) { setOperands(inDeclRef); } @@ -823,10 +896,10 @@ class NamespaceType : public Type // The concrete type for a value wrapped in an existential, accessible // when the existential is "opened" in some context. +FIDDLE() class ExtractExistentialType : public Type { - SLANG_AST_CLASS(ExtractExistentialType) - + FIDDLE(...) DeclRef<VarDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } // A reference to the original interface this type is known @@ -879,10 +952,10 @@ class ExtractExistentialType : public Type DeclRef<ThisTypeDecl> getThisTypeDeclRef(); }; +FIDDLE() class ExistentialSpecializedType : public Type { - SLANG_AST_CLASS(ExistentialSpecializedType) - + FIDDLE(...) Type* getBaseType() { return as<Type>(getOperand(0)); } ExpandedSpecializationArg getArg(Index i) { @@ -910,10 +983,10 @@ class ExistentialSpecializedType : public Type }; /// The type of `this` within a polymorphic declaration +FIDDLE() class ThisType : public DeclRefType { - SLANG_AST_CLASS(ThisType) - + FIDDLE(...) ThisType(DeclRefBase* declRef) : DeclRefType(declRef) { @@ -925,10 +998,10 @@ class ThisType : public DeclRefType /// The type of `A & B` where `A` and `B` are types /// /// A value `v` is of type `A & B` if it is both of type `A` and of type `B`. +FIDDLE() class AndType : public Type { - SLANG_AST_CLASS(AndType) - + FIDDLE(...) Type* getLeft() { return as<Type>(getOperand(0)); } Type* getRight() { return as<Type>(getOperand(1)); } @@ -940,10 +1013,10 @@ class AndType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class ModifiedType : public Type { - SLANG_AST_CLASS(ModifiedType) - + FIDDLE(...) Type* getBase() { return as<Type>(getOperand(0)); } Index getModifierCount() { return getOperandCount() - 1; } diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 7613dbe80..efb87b831 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -2,9 +2,9 @@ #include "slang-ast-val.h" #include "slang-ast-builder.h" +#include "slang-ast-dispatch.h" #include "slang-check-impl.h" #include "slang-diagnostics.h" -#include "slang-generated-ast-macro.h" #include "slang-mangle.h" #include "slang-syntax.h" @@ -17,7 +17,7 @@ namespace Slang void ValNodeDesc::init() { Hasher hasher; - hasher.hashValue(Int(type)); + hasher.hashValue(type.getTag()); for (Index i = 0; i < operands.getCount(); ++i) { // Note: we are hashing the raw pointer value rather @@ -90,7 +90,7 @@ Val* Val::defaultResolveImpl() // Default resolve implementation is to recursively resolve all operands, and lookup in // deduplication cache. ValNodeDesc newDesc; - newDesc.type = astNodeType; + newDesc.type = SyntaxClass<NodeBase>(astNodeType); bool diff = false; for (auto operand : m_operands) { diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 3a14be17b..cdfb0b51f 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -1,20 +1,21 @@ // slang-ast-val.h - #pragma once #include "slang-ast-base.h" #include "slang-ast-decl.h" +#include "slang-ast-val.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for compile-time values. +FIDDLE() class DirectDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(DirectDeclRef) - DirectDeclRef(Decl* decl) { setOperands(decl); } DeclRefBase* _substituteImplOverride( @@ -31,11 +32,11 @@ public: // For example, MemberDeclRef(DirectDeclRef(A), B) ==> DirectDeclRef(B), // and MemberDeclRef(MemberDeclRef(A, B), C) ==> MemberDeclRef(A, C). // +FIDDLE() class MemberDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(MemberDeclRef); - DeclRefBase* getParentOperand() { return as<DeclRefBase>(getOperand(1)); } MemberDeclRef(Decl* decl, DeclRefBase* parent) { setOperands(decl, parent); } @@ -55,11 +56,11 @@ public: // Represent a lookup of SuperType::`m_decl` from `lookupSourceType` type that we know conforms to // SuperType. +FIDDLE() class LookupDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(LookupDeclRef); - // m_decl represents the decl in SuperType that we want to lookup. // The source type that we are looking up from. @@ -91,11 +92,11 @@ private: }; // Represents a specialization of a generic decl. +FIDDLE() class GenericAppDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(GenericAppDeclRef); - DeclRefBase* getGenericDeclRef() { return as<DeclRefBase>(getOperand(1)); } Index getArgCount() { return getOperandCount() - 2; } Val* getArg(Index index) { return getOperand(index + 2); } @@ -137,10 +138,10 @@ public: }; // A compile-time integer (may not have a specific concrete value) +FIDDLE(abstract) class IntVal : public Val { - SLANG_ABSTRACT_AST_CLASS(IntVal) - + FIDDLE(...) Type* getType() { return as<Type>(getOperand(0)); } Val* _resolveImplOverride() { return this; } @@ -152,10 +153,10 @@ class IntVal : public Val }; // Trivial case of a value that is just a constant integer +FIDDLE() class ConstantIntVal : public IntVal { - SLANG_AST_CLASS(ConstantIntVal) - + FIDDLE(...) IntegerLiteralValue getValue() { return getIntConstOperand(1); } // Overrides should be public so base classes can access @@ -166,10 +167,10 @@ class ConstantIntVal : public IntVal }; // The logical "value" of a reference to a generic value parameter +FIDDLE() class GenericParamIntVal : public IntVal { - SLANG_AST_CLASS(GenericParamIntVal) - + FIDDLE(...) DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); } // Overrides should be public so base classes can access @@ -185,10 +186,10 @@ class GenericParamIntVal : public IntVal Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; +FIDDLE() class TypeCastIntVal : public IntVal { - SLANG_AST_CLASS(TypeCastIntVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); @@ -213,10 +214,10 @@ class TypeCastIntVal : public IntVal }; // An compile time int val as result of some general computation. +FIDDLE() class FuncCallIntVal : public IntVal { - SLANG_AST_CLASS(FuncCallIntVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); @@ -257,10 +258,10 @@ class FuncCallIntVal : public IntVal Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; +FIDDLE() class CountOfIntVal : public IntVal { - SLANG_AST_CLASS(CountOfIntVal) - + FIDDLE(...) CountOfIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); } Val* getTypeArg() { return getOperand(1); } @@ -275,10 +276,10 @@ class CountOfIntVal : public IntVal static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType); }; +FIDDLE() class WitnessLookupIntVal : public IntVal { - SLANG_AST_CLASS(WitnessLookupIntVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); @@ -300,9 +301,10 @@ class WitnessLookupIntVal : public IntVal // polynomial expression "2*a*b^3 + 1" will be represented as: // { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] } +FIDDLE() class PolynomialIntValFactor : public Val { - SLANG_AST_CLASS(PolynomialIntValFactor) + FIDDLE(...) public: IntVal* getParam() const { return as<IntVal>(getOperand(0)); } IntegerLiteralValue getPower() const { return getIntConstOperand(1); } @@ -361,9 +363,11 @@ public: return getPower() == other.getPower() && getParam()->equals(other.getParam()); } }; + +FIDDLE() class PolynomialIntValTerm : public Val { - SLANG_AST_CLASS(PolynomialIntValTerm) + FIDDLE(...) public: IntegerLiteralValue getConstFactor() const { return getIntConstOperand(0); } OperandView<PolynomialIntValFactor> getParamFactors() const @@ -440,9 +444,10 @@ public: } }; +FIDDLE() class PolynomialIntVal : public IntVal { - SLANG_AST_CLASS(PolynomialIntVal) + FIDDLE(...) public: IntegerLiteralValue getConstantTerm() { return getIntConstOperand(1); }; OperandView<PolynomialIntValTerm> getTerms() @@ -482,10 +487,10 @@ public: }; /// An unknown integer value indicating an erroneous sub-expression +FIDDLE() class ErrorIntVal : public IntVal { - SLANG_AST_CLASS(ErrorIntVal) - + FIDDLE(...) ErrorIntVal(Type* inType) { setOperands(inType); } // TODO: We should probably eventually just have an `ErrorVal` here @@ -532,9 +537,10 @@ class ErrorIntVal : public IntVal // the concrete declarations that provide the implementation // of `ILight` for `X`. // +FIDDLE(abstract) class Witness : public Val { - SLANG_ABSTRACT_AST_CLASS(Witness) + FIDDLE(...) }; // A witness that one type is a subtype of another @@ -542,10 +548,10 @@ class Witness : public Val // relationships and type-conforms-to-interface relationships) // // TODO: we may need to tease those apart. +FIDDLE(abstract) class SubtypeWitness : public Witness { - SLANG_ABSTRACT_AST_CLASS(SubtypeWitness) - + FIDDLE(...) Val* _resolveImplOverride(); Type* getSub() { return as<Type>(getOperand(0)); } @@ -555,10 +561,10 @@ class SubtypeWitness : public Witness ConversionCost getOverloadResolutionCost(); }; +FIDDLE() class TypePackSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(TypePackSubtypeWitness) - + FIDDLE(...) Type* getSub() { return as<Type>(getOperand(0)); } Type* getSup() { return as<Type>(getOperand(1)); } @@ -578,10 +584,10 @@ class TypePackSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class EachSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(EachSubtypeWitness) - + FIDDLE(...) EachSubtypeWitness(Type* sub, Type* sup, SubtypeWitness* patternWitness) { setOperands(sub, sup, patternWitness); @@ -594,10 +600,10 @@ class EachSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class ExpandSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ExpandSubtypeWitness) - + FIDDLE(...) ExpandSubtypeWitness(Type* sub, Type* sup, SubtypeWitness* patternWitness) { setOperands(sub, sup, patternWitness); @@ -610,10 +616,10 @@ class ExpandSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class TypeEqualityWitness : public SubtypeWitness { - SLANG_AST_CLASS(TypeEqualityWitness) - + FIDDLE(...) TypeEqualityWitness(Type* subType, Type* supType) { setOperands(subType, supType); } // Overrides should be public so base classes can access @@ -621,10 +627,10 @@ class TypeEqualityWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class TypeCoercionWitness : public Witness { - SLANG_AST_CLASS(TypeCoercionWitness) - + FIDDLE(...) Type* getFromType() { return as<Type>(getOperand(0)); } Type* getToType() { return as<Type>(getOperand(1)); } @@ -637,10 +643,10 @@ class TypeCoercionWitness : public Witness // A witness that one type is a subtype of another // because some in-scope declaration says so +FIDDLE() class DeclaredSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(DeclaredSubtypeWitness) - + FIDDLE(...) DeclRef<Decl> getDeclRef() { return as<DeclRefBase>(getOperand(2)); } bool isEquality() @@ -664,10 +670,10 @@ class DeclaredSubtypeWitness : public SubtypeWitness }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` +FIDDLE() class TransitiveSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(TransitiveSubtypeWitness) - + FIDDLE(...) // Witness that `sub : mid` SubtypeWitness* getSubToMid() { return as<SubtypeWitness>(getOperand(2)); } @@ -692,10 +698,10 @@ class TransitiveSubtypeWitness : public SubtypeWitness // A witness that `sub : sup` because `sub` was wrapped into // an existential of type `sup`. +FIDDLE() class ExtractExistentialSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ExtractExistentialSubtypeWitness) - + FIDDLE(...) // The declaration of the existential value that has been opened DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(2)); } @@ -711,17 +717,18 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness /// A witness of the fact that a user provided "__Dynamic" type argument is a /// subtype to the existential type parameter. +FIDDLE() class DynamicSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(DynamicSubtypeWitness) + FIDDLE(...) DynamicSubtypeWitness(Type* inSub, Type* inSup) { setOperands(inSub, inSup); } }; /// A witness that `T : L & R` because `T : L` and `T : R` +FIDDLE() class ConjunctionSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ConjunctionSubtypeWitness) - + FIDDLE(...) // At the operational level, this class of witness is // an operation that takes two witness tables `leftWitness` // and `rightWitness`, and forms a pair/tuple of @@ -750,10 +757,10 @@ class ConjunctionSubtypeWitness : public SubtypeWitness }; /// A witness that `T <: L` or `T <: R` because `T <: L&R` +FIDDLE() class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ExtractFromConjunctionSubtypeWitness) - + FIDDLE(...) // At the operational level, this class of witness is // an operation that takes a pair/tuple of witness tables // `(leftWtiness, rightWitness)` and extracts one of the @@ -785,52 +792,54 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness }; /// A value that represents a modifier attached to some other value +FIDDLE() class ModifierVal : public Val { - SLANG_AST_CLASS(ModifierVal) - + FIDDLE(...) Val* _resolveImplOverride() { return this; } }; +FIDDLE() class TypeModifierVal : public ModifierVal { - SLANG_AST_CLASS(TypeModifierVal) + FIDDLE(...) }; +FIDDLE() class ResourceFormatModifierVal : public TypeModifierVal { - SLANG_AST_CLASS(ResourceFormatModifierVal) + FIDDLE(...) }; +FIDDLE() class UNormModifierVal : public ResourceFormatModifierVal { - SLANG_AST_CLASS(UNormModifierVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class SNormModifierVal : public ResourceFormatModifierVal { - SLANG_AST_CLASS(SNormModifierVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class NoDiffModifierVal : public TypeModifierVal { - SLANG_AST_CLASS(NoDiffModifierVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; /// Represents the result of differentiating a function. +FIDDLE() class DifferentiateVal : public Val { - SLANG_AST_CLASS(DifferentiateVal) - + FIDDLE(...) DifferentiateVal(DeclRef<Decl> inFunc) { setOperands(inFunc); } DeclRef<Decl> getFunc() { return as<DeclRefBase>(getOperand(0)); } @@ -840,49 +849,50 @@ class DifferentiateVal : public Val Val* _resolveImplOverride(); }; +FIDDLE() class ForwardDifferentiateVal : public DifferentiateVal { - SLANG_AST_CLASS(ForwardDifferentiateVal) + FIDDLE(...) ForwardDifferentiateVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiateVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiateVal) - + FIDDLE(...) BackwardDifferentiateVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal) - + FIDDLE(...) BackwardDifferentiateIntermediateTypeVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiatePrimalVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiatePrimalVal) - + FIDDLE(...) BackwardDifferentiatePrimalVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiatePropagateVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiatePropagateVal) - + FIDDLE(...) BackwardDifferentiatePropagateVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 899b04b8b..e511fbc39 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12,8 +12,8 @@ // logic also orchestrates the overall flow and how // and when things get checked. +#include "slang-ast-forward-declarations.h" #include "slang-ast-iterator.h" -#include "slang-ast-reflect.h" #include "slang-ast-synthesis.h" #include "slang-lookup.h" #include "slang-parser.h" @@ -3590,7 +3590,7 @@ bool SemanticsVisitor::doesAccessorMatchRequirement( // auto satisfyingMemberClass = satisfyingMemberDeclRef.getDecl()->getClass(); auto requiredMemberClass = requiredMemberDeclRef.getDecl()->getClass(); - if (!satisfyingMemberClass.isSubClassOfImpl(requiredMemberClass)) + if (!satisfyingMemberClass.isSubClassOf(requiredMemberClass)) return false; // We do not check the parameters or return types of accessors @@ -11261,7 +11261,7 @@ void _foreachDirectOrExtensionMemberOfType( // for (auto memberDeclRef : getMembers(semantics->getASTBuilder(), containerDeclRef)) { - if (memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) + if (memberDeclRef.getDecl()->getClass().isSubClassOf(syntaxClass)) { callback(memberDeclRef, (void*)userData); } @@ -11294,7 +11294,7 @@ void _foreachDirectOrExtensionMemberOfType( for (auto memberDeclRef : getMembers(semantics->getASTBuilder(), extDeclRef)) { - if (memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) + if (memberDeclRef.getDecl()->getClass().isSubClassOf(syntaxClass)) { callback(memberDeclRef, (void*)userData); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2f91a6a77..7b774f300 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1195,7 +1195,7 @@ Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) auto baseDiffType = tryGetDifferentialType(builder, ptrType->getValueType()); if (!baseDiffType) return nullptr; - return builder->getPtrType(baseDiffType, ptrType->getClassInfo().m_name); + return builder->getPtrType(baseDiffType, ptrType->getClass().getName()); } else if (auto arrayType = as<ArrayExpressionType>(type)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index b818c9e06..44fdf45cf 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1,4 +1,5 @@ // slang-check-overload.cpp + #include "slang-ast-base.h" #include "slang-ast-print.h" #include "slang-check-impl.h" diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp index 92a9a9d6d..e16a470de 100644 --- a/source/slang/slang-check-resolve-val.cpp +++ b/source/slang/slang-check-resolve-val.cpp @@ -2,7 +2,8 @@ // Logic for resolving/simplifying Types and DeclRefs. -#include "slang-ast-reflect.h" +#include "slang-ast-dispatch.h" +#include "slang-ast-forward-declarations.h" #include "slang-ast-synthesis.h" #include "slang-check-impl.h" #include "slang-lookup.h" diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 2c8f3d0c0..db753713b 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -266,8 +266,10 @@ bool SemanticsVisitor::CoerceToProperTypeImpl( // diagnostic. // Get the AST node type info, so we can output a 'got' name - auto info = ASTClassInfo::getInfo(originalExpr->astNodeType); - diagSink->diagnose(originalExpr, Diagnostics::expectedAType, info->m_name); + diagSink->diagnose( + originalExpr, + Diagnostics::expectedAType, + originalExpr->getClass().getName()); } } @@ -296,7 +298,12 @@ bool SemanticsVisitor::CoerceToProperTypeImpl( { if (auto typeParam = as<GenericTypeParamDecl>(member)) { - if (!typeParam->initType.exp) + if (auto defaultArg = typeParam->initType.type) + { + if (outProperType) + args.add(defaultArg); + } + else { if (diagSink) { @@ -305,10 +312,6 @@ bool SemanticsVisitor::CoerceToProperTypeImpl( } return false; } - - // TODO: this is one place where syntax should get cloned! - if (outProperType) - args.add(ExtractGenericArgVal(typeParam->initType.exp)); } else if (auto valParam = as<GenericValueParamDecl>(member)) { diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 9815f6ff1..8e9b8f430 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2173,17 +2173,7 @@ SlangResult EndToEndCompileRequest::writeContainerToStream(Stream* stream) options.sourceManager = linkage->getSourceManager(); } - { - RiffContainer container; - { - SerialContainerData data; - SLANG_RETURN_ON_FAIL( - SerialContainerUtil::addEndToEndRequestToData(this, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); - } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - } + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, options, stream)); return SLANG_OK; } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 26a4bb43b..bfae6e400 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -34,6 +34,7 @@ namespace Slang struct PathInfo; struct IncludeHandler; struct SharedSemanticsContext; +struct ModuleChunkRef; class ProgramLayout; class PtrType; @@ -2086,8 +2087,39 @@ struct ContainerTypeKey } }; -/// A dictionary of currently loaded modules. Used by `findOrImportModule` to -/// lookup additional loaded modules. +/// A dictionary of modules to be considered when resolving `import`s, +/// beyond those that would normally be found through a `Linkage`. +/// +/// Checking of an `import` declaration will bottleneck through +/// `Linkage::findOrImportModule`, which would usually just check for +/// any module that had been previously loaded into the same `Linkage` +/// (e.g., by a call to `Linkage::loadModule()`). +/// +/// In the case where compilation is being done through an +/// explicit `FrontEndCompileRequest` or `EndToEndCompileRequest`, +/// the modules being compiled by that request do not get added to +/// the surrounding `Linkage`. +/// +/// There is a corner case when an explicit compile request has +/// multiple `TranslationUnitRequest`s, because the user (reasonably) +/// expects that if they compile `A.slang` and `B.slang` as two +/// distinct translation units in the same compile request, then +/// an `import B` inside of `A.slang` should resolve to reference +/// the code of `B.slang`. But because neither `A` nor `B` gets +/// added to the `Linkage`, and the `Linkage` is what usually +/// determines what is or isn't loaded, that intuition will +/// be wrong, without a bit of help. +/// +/// The `LoadedModuleDictionary` is thus filled in by a +/// `FrontEndCompileRequest` to collect the modules it is compiling, +/// so that they can cross-reference one another (albeit with +/// a current implementation restriction that modules in the +/// request can only `import` those earlier in the request...). +/// +/// The dictionary then gets passed around between nearly all of +/// the operations that deal with loading modules, to make sure +/// that they can detect a previously loaded module. +/// typedef Dictionary<Name*, Module*> LoadedModuleDictionary; enum ModuleBlobType @@ -2096,8 +2128,6 @@ enum ModuleBlobType IR }; -struct SerialContainerDataModule; - /// A context for loading and re-using code modules. class Linkage : public RefObject, public slang::ISession { @@ -2287,7 +2317,15 @@ public: /// Add a new target and return its index. UInt addTarget(CodeGenTarget target); - RefPtr<Module> loadModule( + /// "Bottleneck" routine for loading a module. + /// + /// All attempts to load a module, whether through + /// Slang API calls, `import` operations, or other + /// means, should bottleneck through `loadModuleImpl`, + /// or one of the specialized cases `loadSourceModuleImpl` + /// and `loadBinaryModuleImpl`. + /// + RefPtr<Module> loadModuleImpl( Name* name, const PathInfo& filePathInfo, ISlangBlob* fileContentsBlob, @@ -2296,17 +2334,49 @@ public: const LoadedModuleDictionary* additionalLoadedModules, ModuleBlobType blobType); - RefPtr<Module> loadModuleFromIRBlobImpl( + RefPtr<Module> loadSourceModuleImpl( Name* name, const PathInfo& filePathInfo, ISlangBlob* fileContentsBlob, SourceLoc const& loc, DiagnosticSink* sink, const LoadedModuleDictionary* additionalLoadedModules); - RefPtr<Module> loadDeserializedModule( + + RefPtr<Module> loadBinaryModuleImpl( Name* name, const PathInfo& filePathInfo, - SerialContainerDataModule& m, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink); + + /// Either finds a previously-loaded module matching what + /// was serialized into `moduleChunk`, or else attempts + /// to load the serialized module. + /// + /// If a previously-loaded module is found that matches the + /// name or path information in `moduleChunk`, then that + /// previously-loaded module is returned. + /// + /// Othwerise, attempts to load a module from `moduleChunk` + /// and, if successful, returns the freshly loaded module. + /// + /// Otherwise, return null. + /// + RefPtr<Module> findOrLoadSerializedModuleForModuleLibrary( + ModuleChunkRef moduleChunk, + DiagnosticSink* sink); + + RefPtr<Module> loadSerializedModule( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, + SourceLoc const& requestingLoc, + DiagnosticSink* sink); + + SlangResult loadSerializedModuleContents( + Module* module, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, DiagnosticSink* sink); SourceFile* loadSourceFile(String pathFrom, String path); @@ -2317,10 +2387,8 @@ public: Name* name, PathInfo const& pathInfo); - /// Load a module of the given name. - Module* loadModule(String const& name); - bool isBinaryModuleUpToDate(String fromPath, RiffContainer* container); + bool isBinaryModuleUpToDate(String fromPath, ModuleChunkRef moduleChunk); RefPtr<Module> findOrImportModule( Name* name, @@ -2328,12 +2396,6 @@ public: DiagnosticSink* sink, const LoadedModuleDictionary* loadedModules = nullptr); - void prepareDeserializedModule( - SerialContainerDataModule& moduleEntry, - const PathInfo& pathInfo, - Module* module, - DiagnosticSink* sink); - SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem); struct IncludeResult { diff --git a/source/slang/slang-doc-ast.cpp b/source/slang/slang-doc-ast.cpp index 0d4b69895..7e83d5d59 100644 --- a/source/slang/slang-doc-ast.cpp +++ b/source/slang/slang-doc-ast.cpp @@ -2,7 +2,7 @@ #include "slang-doc-ast.h" #include "../core/slang-string-util.h" -#include "slang/slang-ast-support-types.h" +#include "slang-ast-support-types.h" // #include "slang-ast-builder.h" // #include "slang-ast-print.h" diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index ada34f220..729802f4e 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2041,10 +2041,12 @@ ScalarizedVal adaptType(IRBuilder* builder, IRInst* val, IRType* toType, IRType* // Get array sizes once auto fromSize = getIntVal(fromArray->getElementCount()); auto toSize = getIntVal(toArray->getElementCount()); - SLANG_ASSERT(fromSize <= toSize); - // Extract elements one at a time up to the source array size - for (Index i = 0; i < fromSize; i++) + // Extract elements one at a time up to the minimum + // size, between the source and destination. + // + auto limit = fromSize < toSize ? fromSize : toSize; + for (Index i = 0; i < limit; i++) { auto element = builder->emitElementExtract( fromArray->getElementType(), diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d764cade5..2b94a1fa7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5843,13 +5843,19 @@ struct DestinationDrivenRValueExprLoweringVisitor } /// Emit code for a `try` invoke. - LoweredValInfo visitTryExpr(TryExpr* expr) + void visitTryExpr(TryExpr* expr) { auto invokeExpr = as<InvokeExpr>(expr->base); assert(invokeExpr); TryClauseEnvironment tryEnv; tryEnv.clauseType = expr->tryClauseType; - return sharedLoweringContext.visitInvokeExprImpl(invokeExpr, destination, tryEnv); + auto rValue = sharedLoweringContext.visitInvokeExprImpl(invokeExpr, destination, tryEnv); + if (rValue.flavor != LoweredValInfo::Flavor::None) + { + // If we weren't able to fuse the destination write during lowering rvalue, + // we should insert the assign operation now. + assign(context, destination, rValue); + } } }; diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 63dfffb94..f5878cb1d 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -381,7 +381,7 @@ void emitVal(ManglingContext* context, Val* val) } else if (auto modifier = as<ModifierVal>(val)) { - emitNameImpl(context, UnownedStringSlice(modifier->getClassInfo().m_name)); + emitNameImpl(context, UnownedStringSlice(modifier->getClass().getName())); } else { diff --git a/source/slang/slang-module-library.cpp b/source/slang/slang-module-library.cpp index 060c9007c..c03d5c2dd 100644 --- a/source/slang/slang-module-library.cpp +++ b/source/slang/slang-module-library.cpp @@ -45,6 +45,8 @@ SlangResult loadModuleLibrary( EndToEndCompileRequest* req, ComPtr<IModuleLibrary>& outLibrary) { + SLANG_UNUSED(path); + auto library = new ModuleLibrary; ComPtr<IModuleLibrary> scopeLibrary(library); @@ -55,54 +57,28 @@ SlangResult loadModuleLibrary( SLANG_RETURN_ON_FAIL(RiffUtil::read(&memoryStream, riffContainer)); auto linkage = req->getLinkage(); + auto sink = req->getSink(); + auto namePool = req->getNamePool(); + + auto container = ContainerChunkRef::find(&riffContainer); + + for (auto moduleChunk : container.getModules()) + { + auto loadedModule = linkage->findOrLoadSerializedModuleForModuleLibrary(moduleChunk, sink); + if (!loadedModule) + return SLANG_FAIL; + + library->m_modules.add(loadedModule); + } + + for (auto entryPointChunk : container.getEntryPoints()) { - SerialContainerData containerData; - - SerialContainerUtil::ReadOptions options; - options.namePool = req->getNamePool(); - options.session = req->getSession(); - options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); - options.sourceManager = linkage->getSourceManager(); - options.linkage = req->getLinkage(); - options.sink = req->getSink(); - options.astBuilder = linkage->getASTBuilder(); - options.modulePath = path; - SLANG_RETURN_ON_FAIL( - SerialContainerUtil::read(&riffContainer, options, nullptr, containerData)); - DiagnosticSink sink; - - // Modules in the container should be serialized in its depedency order, - // so that we always load the dependencies before the consuming module. - for (auto& module : containerData.modules) - { - // If the irModule is set, add it - if (module.irModule) - { - if (module.dependentFiles.getCount() == 0) - return SLANG_FAIL; - if (!module.astRootNode) - return SLANG_FAIL; - auto loadedModule = linkage->loadDeserializedModule( - as<ModuleDecl>(module.astRootNode)->getName(), - PathInfo::makePath(module.dependentFiles.getFirst()), - module, - &sink); - if (!loadedModule) - return SLANG_FAIL; - library->m_modules.add(loadedModule); - } - } - - for (const auto& entryPoint : containerData.entryPoints) - { - FrontEndCompileRequest::ExtraEntryPointInfo dst; - dst.mangledName = entryPoint.mangledName; - dst.name = entryPoint.name; - dst.profile = entryPoint.profile; - - // Add entry point - library->m_entryPoints.add(dst); - } + FrontEndCompileRequest::ExtraEntryPointInfo entryPointInfo; + entryPointInfo.mangledName = entryPointChunk.getMangledName(); + entryPointInfo.name = namePool->getName(entryPointChunk.getName()); + entryPointInfo.profile = entryPointChunk.getProfile(); + + library->m_entryPoints.add(entryPointInfo); } outLibrary.swap(scopeLibrary); diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index b09e36a1b..3c0bcf8db 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2424,7 +2424,7 @@ SlangResult OptionsParser::_parse(int argc, char const* const* argv) { CommandLineArg name; SLANG_RETURN_ON_FAIL(m_reader.expectArg(name)); - // TODO: doagnose deprecated option + // TODO: warn that this option is deprecated break; } case OptionKind::EmbedDownstreamIR: diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 00f15cbb3..a8573c909 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1771,46 +1771,46 @@ public: void visitDeclRefExpr(DeclRefExpr* expr) { expr->scope = scope; } void visitGenericAppExpr(GenericAppExpr* expr) { - expr->functionExpr->accept(this, nullptr); + dispatch(expr->functionExpr); for (auto arg : expr->arguments) - arg->accept(this, nullptr); + dispatch(arg); } void visitIndexExpr(IndexExpr* expr) { - expr->baseExpression->accept(this, nullptr); + dispatch(expr->baseExpression); for (auto arg : expr->indexExprs) - arg->accept(this, nullptr); + dispatch(arg); } void visitMemberExpr(MemberExpr* expr) { - expr->baseExpression->accept(this, nullptr); + dispatch(expr->baseExpression); expr->scope = scope; } void visitStaticMemberExpr(StaticMemberExpr* expr) { - expr->baseExpression->accept(this, nullptr); + dispatch(expr->baseExpression); expr->scope = scope; } void visitAppExprBase(AppExprBase* expr) { - expr->functionExpr->accept(this, nullptr); + dispatch(expr->functionExpr); for (auto arg : expr->arguments) - arg->accept(this, nullptr); + dispatch(arg); } void visitIsTypeExpr(IsTypeExpr* expr) { if (expr->typeExpr.exp) - expr->typeExpr.exp->accept(this, nullptr); + dispatch(expr->typeExpr.exp); } void visitAsTypeExpr(AsTypeExpr* expr) { if (expr->typeExpr) - expr->typeExpr->accept(this, nullptr); + dispatch(expr->typeExpr); } void visitSizeOfLikeExpr(SizeOfLikeExpr* expr) { if (expr->value) - expr->value->accept(this, nullptr); + dispatch(expr->value); } void visitExpr(Expr* /*expr*/) {} }; @@ -1910,7 +1910,7 @@ static Decl* parseTraditionalFuncDecl(Parser* parser, DeclaratorInfo const& decl // ReplaceScopeVisitor replaceScopeVisitor; replaceScopeVisitor.scope = parser->currentScope; - declaratorInfo.typeSpec->accept(&replaceScopeVisitor, nullptr); + replaceScopeVisitor.dispatch(declaratorInfo.typeSpec); decl->returnType = TypeExp(declaratorInfo.typeSpec); @@ -4377,7 +4377,7 @@ static NodeBase* parseTypeAliasDecl(Parser* parser, void* /*userData*/) // the class of AST node to construct. NodeBase* parseSimpleSyntax(Parser* parser, void* userData) { - SyntaxClassBase syntaxClass((ReflectClassInfo*)userData); + SyntaxClassBase syntaxClass((SyntaxClassInfo*)userData); return (NodeBase*)syntaxClass.createInstanceImpl(parser->astBuilder); } @@ -4411,7 +4411,7 @@ static NodeBase* parseSyntaxDecl(Parser* parser, void* /*userData*/) // to the `parseSimpleSyntax` callback that will just construct // an instance of that type to represent the keyword in the AST. SyntaxParseCallback parseCallback = &parseSimpleSyntax; - void* parseUserData = (void*)syntaxClass.classInfo; + void* parseUserData = (void*)syntaxClass.getInfo(); // Next we look for an initializer that will make this keyword // an alias for some existing keyword. @@ -4435,7 +4435,7 @@ static NodeBase* parseSyntaxDecl(Parser* parser, void* /*userData*/) // If we don't already have a syntax class specified, then // we will crib the one from the existing syntax, to ensure // that we are creating a drop-in alias. - if (!syntaxClass.classInfo) + if (!syntaxClass) syntaxClass = existingSyntax->syntaxClass; } } @@ -4445,7 +4445,7 @@ static NodeBase* parseSyntaxDecl(Parser* parser, void* /*userData*/) // // TODO: down the line this should be expanded so that the user can reference // an existing *function* to use to parse the chosen syntax. - if (!syntaxClass.classInfo) + if (!syntaxClass) { // TODO: diagnose: either a type or an existing keyword needs to be specified } @@ -4757,7 +4757,7 @@ static NodeBase* parseAttributeSyntaxDecl(Parser* parser, void* /*userData*/) auto classNameAndLoc = expectIdentifier(parser); syntaxClass = parser->astBuilder->findSyntaxClass(classNameAndLoc.name); - assert(syntaxClass.classInfo); + assert(syntaxClass); } else { @@ -8428,20 +8428,20 @@ static void addBuiltinSyntax( SyntaxParseCallback callback, void* userData = nullptr) { - addBuiltinSyntaxImpl(session, scope, name, callback, userData, getClass<T>()); + addBuiltinSyntaxImpl(session, scope, name, callback, userData, getSyntaxClass<T>()); } template<typename T> static void addSimpleModifierSyntax(Session* session, Scope* scope, char const* name) { - auto syntaxClass = getClass<T>(); + auto syntaxClass = getSyntaxClass<T>(); addBuiltinSyntaxImpl( session, scope, name, &parseSimpleSyntax, (void*)syntaxClass.classInfo, - getClass<T>()); + getSyntaxClass<T>()); } static IROp parseIROp(Parser* parser, Token& outToken) @@ -8931,10 +8931,10 @@ static NodeBase* parseMagicTypeModifier(Parser* parser, void* /*userData*/) modifier->tag = uint32_t(stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); } - auto classInfo = parser->astBuilder->findClassInfo(getName(parser, modifier->magicName)); - if (classInfo) + auto syntaxClass = parser->astBuilder->findSyntaxClass(getName(parser, modifier->magicName)); + if (syntaxClass) { - modifier->magicNodeType = ASTNodeType(classInfo->m_classId); + modifier->magicNodeType = syntaxClass; } // TODO: print diagnostic if the magic type name doesn't correspond to an actual ASTNodeType. parser->ReadToken(TokenType::RParent); @@ -9006,7 +9006,7 @@ static NodeBase* parseAttributeTargetModifier(Parser* parser, void* /*userData*/ static SyntaxParseInfo _makeParseExpr(const char* keywordName, SyntaxParseCallback callback) { SyntaxParseInfo entry; - entry.classInfo = &Expr::kReflectClassInfo; + entry.classInfo = getSyntaxClass<Expr>(); entry.keywordName = keywordName; entry.callback = callback; return entry; @@ -9016,18 +9016,18 @@ static SyntaxParseInfo _makeParseDecl(const char* keywordName, SyntaxParseCallba SyntaxParseInfo entry; entry.keywordName = keywordName; entry.callback = callback; - entry.classInfo = &Decl::kReflectClassInfo; + entry.classInfo = getSyntaxClass<Decl>(); return entry; } static SyntaxParseInfo _makeParseModifier( const char* keywordName, - const ReflectClassInfo& classInfo) + SyntaxClass<NodeBase> const& syntaxClass) { // If we just have class info - use simple parser SyntaxParseInfo entry; entry.keywordName = keywordName; entry.callback = &parseSimpleSyntax; - entry.classInfo = &classInfo; + entry.classInfo = syntaxClass; return entry; } static SyntaxParseInfo _makeParseModifier(const char* keywordName, SyntaxParseCallback callback) @@ -9035,7 +9035,7 @@ static SyntaxParseInfo _makeParseModifier(const char* keywordName, SyntaxParseCa SyntaxParseInfo entry; entry.keywordName = keywordName; entry.callback = callback; - entry.classInfo = &Modifier::kReflectClassInfo; + entry.classInfo = getSyntaxClass<Modifier>(); return entry; } @@ -9082,68 +9082,68 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = { // and which can be represented just by creating // a new AST node of the corresponding type. - _makeParseModifier("in", InModifier::kReflectClassInfo), - _makeParseModifier("out", OutModifier::kReflectClassInfo), - _makeParseModifier("inout", InOutModifier::kReflectClassInfo), - _makeParseModifier("__ref", RefModifier::kReflectClassInfo), - _makeParseModifier("__constref", ConstRefModifier::kReflectClassInfo), - _makeParseModifier("const", ConstModifier::kReflectClassInfo), - _makeParseModifier("__builtin", BuiltinModifier::kReflectClassInfo), - _makeParseModifier("highp", GLSLPrecisionModifier::kReflectClassInfo), - _makeParseModifier("lowp", GLSLPrecisionModifier::kReflectClassInfo), - _makeParseModifier("mediump", GLSLPrecisionModifier::kReflectClassInfo), - - _makeParseModifier("__global", ActualGlobalModifier::kReflectClassInfo), - - _makeParseModifier("inline", InlineModifier::kReflectClassInfo), - _makeParseModifier("public", PublicModifier::kReflectClassInfo), - _makeParseModifier("private", PrivateModifier::kReflectClassInfo), - _makeParseModifier("internal", InternalModifier::kReflectClassInfo), - - _makeParseModifier("require", RequireModifier::kReflectClassInfo), - _makeParseModifier("param", ParamModifier::kReflectClassInfo), - _makeParseModifier("extern", ExternModifier::kReflectClassInfo), - - _makeParseModifier("row_major", HLSLRowMajorLayoutModifier::kReflectClassInfo), - _makeParseModifier("column_major", HLSLColumnMajorLayoutModifier::kReflectClassInfo), - - _makeParseModifier("nointerpolation", HLSLNoInterpolationModifier::kReflectClassInfo), - _makeParseModifier("noperspective", HLSLNoPerspectiveModifier::kReflectClassInfo), - _makeParseModifier("linear", HLSLLinearModifier::kReflectClassInfo), - _makeParseModifier("sample", HLSLSampleModifier::kReflectClassInfo), - _makeParseModifier("centroid", HLSLCentroidModifier::kReflectClassInfo), - _makeParseModifier("precise", PreciseModifier::kReflectClassInfo), + _makeParseModifier("in", getSyntaxClass<InModifier>()), + _makeParseModifier("out", getSyntaxClass<OutModifier>()), + _makeParseModifier("inout", getSyntaxClass<InOutModifier>()), + _makeParseModifier("__ref", getSyntaxClass<RefModifier>()), + _makeParseModifier("__constref", getSyntaxClass<ConstRefModifier>()), + _makeParseModifier("const", getSyntaxClass<ConstModifier>()), + _makeParseModifier("__builtin", getSyntaxClass<BuiltinModifier>()), + _makeParseModifier("highp", getSyntaxClass<GLSLPrecisionModifier>()), + _makeParseModifier("lowp", getSyntaxClass<GLSLPrecisionModifier>()), + _makeParseModifier("mediump", getSyntaxClass<GLSLPrecisionModifier>()), + + _makeParseModifier("__global", getSyntaxClass<ActualGlobalModifier>()), + + _makeParseModifier("inline", getSyntaxClass<InlineModifier>()), + _makeParseModifier("public", getSyntaxClass<PublicModifier>()), + _makeParseModifier("private", getSyntaxClass<PrivateModifier>()), + _makeParseModifier("internal", getSyntaxClass<InternalModifier>()), + + _makeParseModifier("require", getSyntaxClass<RequireModifier>()), + _makeParseModifier("param", getSyntaxClass<ParamModifier>()), + _makeParseModifier("extern", getSyntaxClass<ExternModifier>()), + + _makeParseModifier("row_major", getSyntaxClass<HLSLRowMajorLayoutModifier>()), + _makeParseModifier("column_major", getSyntaxClass<HLSLColumnMajorLayoutModifier>()), + + _makeParseModifier("nointerpolation", getSyntaxClass<HLSLNoInterpolationModifier>()), + _makeParseModifier("noperspective", getSyntaxClass<HLSLNoPerspectiveModifier>()), + _makeParseModifier("linear", getSyntaxClass<HLSLLinearModifier>()), + _makeParseModifier("sample", getSyntaxClass<HLSLSampleModifier>()), + _makeParseModifier("centroid", getSyntaxClass<HLSLCentroidModifier>()), + _makeParseModifier("precise", getSyntaxClass<PreciseModifier>()), _makeParseModifier("shared", parseSharedModifier), - _makeParseModifier("groupshared", HLSLGroupSharedModifier::kReflectClassInfo), - _makeParseModifier("static", HLSLStaticModifier::kReflectClassInfo), - _makeParseModifier("uniform", HLSLUniformModifier::kReflectClassInfo), + _makeParseModifier("groupshared", getSyntaxClass<HLSLGroupSharedModifier>()), + _makeParseModifier("static", getSyntaxClass<HLSLStaticModifier>()), + _makeParseModifier("uniform", getSyntaxClass<HLSLUniformModifier>()), _makeParseModifier("volatile", parseVolatileModifier), _makeParseModifier("coherent", parseCoherentModifier), _makeParseModifier("restrict", parseRestrictModifier), _makeParseModifier("readonly", parseReadonlyModifier), _makeParseModifier("writeonly", parseWriteonlyModifier), - _makeParseModifier("export", HLSLExportModifier::kReflectClassInfo), - _makeParseModifier("dynamic_uniform", DynamicUniformModifier::kReflectClassInfo), + _makeParseModifier("export", getSyntaxClass<HLSLExportModifier>()), + _makeParseModifier("dynamic_uniform", getSyntaxClass<DynamicUniformModifier>()), // Modifiers for geometry shader input - _makeParseModifier("point", HLSLPointModifier::kReflectClassInfo), - _makeParseModifier("line", HLSLLineModifier::kReflectClassInfo), - _makeParseModifier("triangle", HLSLTriangleModifier::kReflectClassInfo), - _makeParseModifier("lineadj", HLSLLineAdjModifier::kReflectClassInfo), - _makeParseModifier("triangleadj", HLSLTriangleAdjModifier::kReflectClassInfo), + _makeParseModifier("point", getSyntaxClass<HLSLPointModifier>()), + _makeParseModifier("line", getSyntaxClass<HLSLLineModifier>()), + _makeParseModifier("triangle", getSyntaxClass<HLSLTriangleModifier>()), + _makeParseModifier("lineadj", getSyntaxClass<HLSLLineAdjModifier>()), + _makeParseModifier("triangleadj", getSyntaxClass<HLSLTriangleAdjModifier>()), // Modifiers for mesh shader parameters - _makeParseModifier("vertices", HLSLVerticesModifier::kReflectClassInfo), - _makeParseModifier("indices", HLSLIndicesModifier::kReflectClassInfo), - _makeParseModifier("primitives", HLSLPrimitivesModifier::kReflectClassInfo), - _makeParseModifier("payload", HLSLPayloadModifier::kReflectClassInfo), + _makeParseModifier("vertices", getSyntaxClass<HLSLVerticesModifier>()), + _makeParseModifier("indices", getSyntaxClass<HLSLIndicesModifier>()), + _makeParseModifier("primitives", getSyntaxClass<HLSLPrimitivesModifier>()), + _makeParseModifier("payload", getSyntaxClass<HLSLPayloadModifier>()), // Modifiers for unary operator declarations - _makeParseModifier("__prefix", PrefixModifier::kReflectClassInfo), - _makeParseModifier("__postfix", PostfixModifier::kReflectClassInfo), + _makeParseModifier("__prefix", getSyntaxClass<PrefixModifier>()), + _makeParseModifier("__postfix", getSyntaxClass<PostfixModifier>()), // Modifier to apply to `import` that should be re-exported - _makeParseModifier("__exported", ExportedModifier::kReflectClassInfo), + _makeParseModifier("__exported", getSyntaxClass<ExportedModifier>()), // Add syntax for more complex modifiers, which allow // or expect more tokens after the initial keyword. @@ -9208,7 +9208,7 @@ ModuleDecl* populateBaseLanguageModule(ASTBuilder* astBuilder, Scope* scope) scope, info.keywordName, info.callback, - const_cast<ReflectClassInfo*>(info.classInfo), + info.classInfo.getInfo(), info.classInfo); } diff --git a/source/slang/slang-parser.h b/source/slang/slang-parser.h index 9f9f4972a..c4e68a7fa 100644 --- a/source/slang/slang-parser.h +++ b/source/slang/slang-parser.h @@ -45,9 +45,9 @@ ModuleDecl* populateBaseLanguageModule(ASTBuilder* astBuilder, Scope* scope); /// for the `parseUserData` to be set the the associated classInfo struct SyntaxParseInfo { - const char* keywordName; ///< The keyword associated with this parse - SyntaxParseCallback callback; ///< The callback to apply to the parse - const ReflectClassInfo* classInfo; ///< + const char* keywordName; ///< The keyword associated with this parse + SyntaxParseCallback callback; ///< The callback to apply to the parse + SyntaxClass<NodeBase> classInfo; ///< }; /// Get all of the predefined SyntaxParseInfos diff --git a/source/slang/slang-ref-object-reflect.cpp b/source/slang/slang-ref-object-reflect.cpp deleted file mode 100644 index 303601b7b..000000000 --- a/source/slang/slang-ref-object-reflect.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "slang-ref-object-reflect.h" - -#include "slang-ast-support-types.h" -#include "slang-generated-obj-macro.h" -#include "slang-generated-obj.h" -#include "slang.h" - -// #include "slang-serialize.h" - -#include "slang-serialize-ast-type-info.h" - -namespace Slang -{ - -static const SerialClass* _addClass( - SerialClasses* serialClasses, - RefObjectType type, - RefObjectType super, - const List<SerialField>& fields) -{ - const SerialClass* superClass = - serialClasses->getSerialClass(SerialTypeKind::RefObject, SerialSubType(super)); - return serialClasses->add( - SerialTypeKind::RefObject, - SerialSubType(type), - fields.getBuffer(), - fields.getCount(), - superClass); -} - -#define SLANG_REF_OBJECT_ADD_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - fields.add(SerialField::make(#FIELD_NAME, &obj->FIELD_NAME)); - -// Note that the obj point is not nullptr, because some compilers notice this is 'indexing from -// null' and warn/error. So we offset from 1. -#define SLANG_REF_OBJECT_ADD_SERIAL_CLASS(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - { \ - NAME* obj = SerialField::getPtr<NAME>(); \ - SLANG_UNUSED(obj); \ - fields.clear(); \ - SLANG_FIELDS_RefObject_##NAME(SLANG_REF_OBJECT_ADD_SERIAL_FIELD, param) \ - _addClass(serialClasses, RefObjectType::NAME, RefObjectType::SUPER, fields); \ - } - -struct RefObjectAccess -{ - template<typename T> - static void* create(void* context) - { - SLANG_UNUSED(context) - return new T; - } - - static void calcClasses(SerialClasses* serialClasses) - { - // Add SerialRefObject first, and specially handle so that we add a null super class - serialClasses->add( - SerialTypeKind::RefObject, - SerialSubType(RefObjectType::SerialRefObject), - nullptr, - 0, - nullptr); - - // Add the rest in order such that Super class is always added before its children - List<SerialField> fields; - SLANG_CHILDREN_RefObject_SerialRefObject(SLANG_REF_OBJECT_ADD_SERIAL_CLASS, _) - } -}; - -#define SLANG_GET_SUPER_BASE(SUPER) nullptr -#define SLANG_GET_SUPER_INNER(SUPER) &SUPER::kReflectClassInfo -#define SLANG_GET_SUPER_LEAF(SUPER) &SUPER::kReflectClassInfo - -#define SLANG_GET_CREATE_FUNC_NONE(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_OBJ_ABSTRACT(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_OBJ(NAME) &RefObjectAccess::create<NAME> - -#define SLANG_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - /* static */ const ReflectClassInfo NAME::kReflectClassInfo = { \ - uint32_t(RefObjectType::NAME), \ - uint32_t(RefObjectType::LAST), \ - SLANG_GET_SUPER_##TYPE(SUPER), \ - #NAME, \ - SLANG_GET_CREATE_FUNC_##MARKER(NAME), \ - nullptr, \ - uint32_t(sizeof(NAME)), \ - uint8_t(SLANG_ALIGN_OF(NAME))}; - -SLANG_ALL_RefObject_SerialRefObject(SLANG_REFLECT_CLASS_INFO, _) - - /* static */ const SerialRefObjects SerialRefObjects::g_singleton; - -// Macro to set all of the entries in m_infos for SerialRefObjects -#define SLANG_GET_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - m_infos[Index(RefObjectType::NAME)] = &NAME::kReflectClassInfo; - -SerialRefObjects::SerialRefObjects(){ - SLANG_ALL_RefObject_SerialRefObject(SLANG_GET_REFLECT_CLASS_INFO, _)} - -/* static */ SlangResult SerialRefObjects::addSerialClasses(SerialClasses* serialClasses) -{ - RefObjectAccess::calcClasses(serialClasses); - return SLANG_OK; -} - -} // namespace Slang diff --git a/source/slang/slang-ref-object-reflect.h b/source/slang/slang-ref-object-reflect.h deleted file mode 100644 index 1a6bf4520..000000000 --- a/source/slang/slang-ref-object-reflect.h +++ /dev/null @@ -1,73 +0,0 @@ -// slang-ref-object-reflect.h - -#ifndef SLANG_REF_OBJECT_REFLECT_H -#define SLANG_REF_OBJECT_REFLECT_H - -#include "../core/slang-smart-pointer.h" -#include "slang-generated-obj.h" -#include "slang-serialize-reflection.h" - -class SerialClasses; - -struct RefObjectAccess; - -#define SLANG_OBJ_CLASS_REFLECT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ -public: \ - typedef NAME This; \ - static const ReflectClassInfo kReflectClassInfo; \ - virtual const ReflectClassInfo* getClassInfo() const SLANG_OVERRIDE \ - { \ - return &kReflectClassInfo; \ - } \ - \ - friend struct RefObjectAccess; \ - \ - SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) - -// Placed in any SerialRefObject derived class -#define SLANG_ABSTRACT_OBJ_CLASS(NAME) SLANG_RefObject_##NAME(SLANG_OBJ_CLASS_REFLECT_IMPL, _) -#define SLANG_OBJ_CLASS(NAME) SLANG_RefObject_##NAME(SLANG_OBJ_CLASS_REFLECT_IMPL, _) - -namespace Slang -{ - -class SerialClasses; - -// Is friended such that internally we have access to construct or get members -struct RefObjectAccess; - -// Base class for Serialized RefObject derived classes. The main feature is that gives away to get -// ReflectClassInfo via getClassInfo() method -class SerialRefObject : public RefObject -{ -public: - typedef RefObject Super; - typedef SerialRefObject This; - - static const ReflectClassInfo kReflectClassInfo; - - virtual const ReflectClassInfo* getClassInfo() const { return &kReflectClassInfo; } -}; - -// For turning RefObjectType back to ReflectClassInfo -struct SerialRefObjects -{ - /// Add serialization classes - static SlangResult addSerialClasses(SerialClasses* serialClasses); - - static const ReflectClassInfo* getClassInfo(RefObjectType type) - { - return g_singleton.m_infos[Index(type)]; - } - - - static const SerialRefObjects g_singleton; - -protected: - SerialRefObjects(); - const ReflectClassInfo* m_infos[Index(RefObjectType::CountOf)]; -}; - -} // namespace Slang - -#endif // SLANG_REF_OBJECT_REFLECT_H diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h deleted file mode 100644 index ef2fd7ad5..000000000 --- a/source/slang/slang-serialize-ast-type-info.h +++ /dev/null @@ -1,477 +0,0 @@ -// slang-serialize-ast-type-info.h -#ifndef SLANG_SERIALIZE_AST_TYPE_INFO_H -#define SLANG_SERIALIZE_AST_TYPE_INFO_H - -#include "slang-ast-all.h" -#include "slang-ast-support-types.h" -#include "slang-serialize-misc-type-info.h" -#include "slang-serialize-type-info.h" -#include "slang-serialize-value-type-info.h" - -namespace Slang -{ - -/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AST types !!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -// SyntaxClass<T> -template<typename T> -struct SerialTypeInfo<SyntaxClass<T>> -{ - typedef SyntaxClass<T> NativeType; - typedef uint16_t SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - dst = SerialType(src.classInfo->m_classId); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - dst.classInfo = ASTClassInfo::getInfo(ASTNodeType(src)); - } -}; - -// MatrixCoord can just go as is -template<> -struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> -{ -}; - -inline void serializeValPointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial) -{ - if (ptrValue) - ptrValue = ptrValue->resolve(); - *(SerialIndex*)outSerial = writer->addPointer(ptrValue); -} - -inline void deserializeValPointerValue( - SerialReader* reader, - const SerialIndex* inSerial, - void* outPtr) -{ - auto val = reader->getValPointer(*(const SerialIndex*)inSerial); - *(void**)outPtr = val.m_ptr; -} - -template<typename T> -struct PtrSerialTypeInfo<T, std::enable_if_t<std::is_base_of_v<Val, T>>> -{ - typedef T* NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - auto ptrValue = *(T**)inNative; - serializeValPointerValue(writer, ptrValue, (SerialIndex*)outSerial); - } - - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - deserializeValPointerValue(reader, (SerialIndex*)inSerial, outNative); - } -}; - -template<typename T> -struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> -{ -}; - -// UIntSet - -template<> -struct SerialTypeInfo<CapabilityAtomSet> -{ - typedef CapabilityAtomSet NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst = writer->addArray(src.getBuffer().getBuffer(), src.getBuffer().getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - List<CapabilityAtomSet::Element> UIntSetBuffer; - reader->getArray(src, UIntSetBuffer); - - dst = CapabilityAtomSet(); - for (Index i = 0; i < UIntSetBuffer.getCount(); i++) - dst.addRawElement(UIntSetBuffer[i], i); - } -}; - -// ~UIntSet - -template<> -struct SerialTypeInfo<CapabilityStageSet> -{ - struct SerialType - { - SerialIndex stage; - SerialIndex atomSet; - }; - - typedef CapabilityStageSet NativeType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<SerialTypeInfo<CapabilityStageSet>::SerialType> SatomSetsList; - SatomSetsList.setCount(src.atomSet.has_value()); - - if (src.atomSet) - { - auto& i = src.atomSet.value(); - SerialTypeInfo<CapabilityAtomSet>::toSerial(writer, &i, &SatomSetsList[0]); - } - - SerialTypeInfo<CapabilityAtom>::toSerial(writer, &src.stage, &dst.stage); - dst.atomSet = writer->addSerialArray<CapabilityStageSet>( - SatomSetsList.getBuffer(), - SatomSetsList.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - CapabilityAtom stage; - List<CapabilityAtomSet> items; - SerialTypeInfo<CapabilityAtom>::toNative(reader, &src.stage, &stage); - reader->getArray(src.atomSet, items); - - dst.stage = stage; - - for (auto i : items) - { - dst.addNewSet(std::move(i)); - } - } -}; - -template<> -struct SerialTypeInfo<CapabilityTargetSet> -{ - struct SerialType - { - SerialIndex target; - SerialIndex shaderStageSets; - }; - - typedef CapabilityTargetSet NativeType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<SerialTypeInfo<CapabilityStageSet>::SerialType> SStageSetList; - SStageSetList.setCount(src.shaderStageSets.getCount()); - Index iter = 0; - for (auto& i : src.shaderStageSets) - { - SerialTypeInfo<CapabilityStageSet>::toSerial(writer, &i.second, &SStageSetList[iter]); - iter++; - } - - SerialTypeInfo<CapabilityAtom>::toSerial(writer, &src.target, &dst.target); - dst.shaderStageSets = writer->addSerialArray<CapabilityStageSet>( - SStageSetList.getBuffer(), - SStageSetList.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - CapabilityAtom target; - List<CapabilityStageSet> items; - SerialTypeInfo<CapabilityAtom>::toNative(reader, &src.target, &target); - reader->getArray(src.shaderStageSets, items); - - dst.target = target; - - auto& shaderStageSets = dst.shaderStageSets; - shaderStageSets.clear(); - shaderStageSets.reserve(items.getCount()); - for (auto& i : items) - { - dst.shaderStageSets[i.stage] = i; - } - } -}; - -template<> -struct SerialTypeInfo<CapabilitySet> -{ - struct SerialType - { - SerialIndex m_targetSets; - }; - - typedef CapabilitySet NativeType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<SerialTypeInfo<CapabilityTargetSet>::SerialType> STargetSetList; - auto capabilityTargetSets = src.getCapabilityTargetSets(); - STargetSetList.setCount(capabilityTargetSets.getCount()); - Index iter = 0; - for (auto& i : capabilityTargetSets) - { - SerialTypeInfo<CapabilityTargetSet>::toSerial(writer, &i.second, &STargetSetList[iter]); - iter++; - } - - dst.m_targetSets = writer->addSerialArray<CapabilityTargetSet>( - STargetSetList.getBuffer(), - STargetSetList.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - List<CapabilityTargetSet> items; - reader->getArray(src.m_targetSets, items); - - auto& targetSets = dst.getCapabilityTargetSets(); - targetSets.clear(); - targetSets.reserve(items.getCount()); - for (auto& i : items) - { - targetSets[i.target] = i; - } - } -}; - -// ValNodeOperand -template<> -struct SerialTypeInfo<ValNodeOperand> -{ - typedef ValNodeOperand NativeType; - struct SerialType - { - int8_t kind; - int64_t val; - }; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - dst.kind = int8_t(src.kind); - if (src.kind == ValNodeOperandKind::ConstantValue) - dst.val = src.values.intOperand; - else if (src.kind == ValNodeOperandKind::ValNode) - serializeValPointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val); - else - SerialTypeInfo<NodeBase*>::toSerial( - writer, - &src.values.nodeOperand, - (SerialIndex*)&dst.val); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - // Initialize - dst = NativeType(); - dst.kind = ValNodeOperandKind(src.kind); - if (dst.kind == ValNodeOperandKind::ConstantValue) - dst.values.intOperand = int64_t(src.val); - else if (dst.kind == ValNodeOperandKind::ValNode) - deserializeValPointerValue( - reader, - (SerialIndex*)&src.val, - (Val**)&dst.values.nodeOperand); - else - SerialTypeInfo<NodeBase*>::toNative( - reader, - (SerialIndex*)&src.val, - (NodeBase**)&dst.values.nodeOperand); - } -}; - -// LookupResultItem -SLANG_VALUE_TYPE_INFO(LookupResultItem) -// QualType -SLANG_VALUE_TYPE_INFO(QualType) - -// LookupResult -template<> -struct SerialTypeInfo<LookupResult> -{ - typedef LookupResult NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - if (src.isOverloaded()) - { - // Save off as an array - dst = writer->addArray(src.items.getBuffer(), src.items.getCount()); - } - else if (src.item.declRef.getDecl()) - { - dst = writer->addArray(&src.item, 1); - } - else - { - dst = SerialIndex(0); - } - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - // Initialize - dst = NativeType(); - - List<LookupResultItem> items; - reader->getArray(src, items); - - if (items.getCount() == 1) - { - dst.item = items[0]; - } - else - { - dst.items.swapWith(items); - // We have to set item such that it is valid/member of items, if items is non empty - dst.item = dst.items[0]; - } - } -}; - -// SpecializationArg -SLANG_VALUE_TYPE_INFO(SpecializationArg) -// ExpandedSpecializationArg -SLANG_VALUE_TYPE_INFO(ExpandedSpecializationArg) -// TypeExp -SLANG_VALUE_TYPE_INFO(TypeExp) -// DeclCheckStateExt -SLANG_VALUE_TYPE_INFO(DeclCheckStateExt) - -// Modifiers -template<> -struct SerialTypeInfo<Modifiers> -{ - typedef Modifiers NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - // We need to make into an array - List<SerialIndex> modifierIndices; - for (Modifier* modifier : *(NativeType*)native) - { - modifierIndices.add(writer->addPointer(modifier)); - } - *(SerialType*)serial = - writer->addArray(modifierIndices.getBuffer(), modifierIndices.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - List<Modifier*> modifiers; - reader->getArray(*(const SerialType*)serial, modifiers); - - Modifier* prev = nullptr; - for (Modifier* modifier : modifiers) - { - if (prev) - { - prev->next = modifier; - } - } - - NativeType& dst = *(NativeType*)native; - dst.first = modifiers.getCount() > 0 ? modifiers[0] : nullptr; - } -}; - -// LookupResultItem_Breadcrumb::ThisParameterMode -template<> -struct SerialTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode> - : public SerialConvertTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode, uint8_t> -{ -}; - -// LookupResultItem_Breadcrumb::Kind -template<> -struct SerialTypeInfo<LookupResultItem_Breadcrumb::Kind> - : public SerialConvertTypeInfo<LookupResultItem_Breadcrumb::Kind, uint8_t> -{ -}; - -// RequirementWitness::Flavor -template<> -struct SerialTypeInfo<RequirementWitness::Flavor> - : public SerialConvertTypeInfo<RequirementWitness::Flavor, uint8_t> -{ -}; - -// RequirementWitness -SLANG_VALUE_TYPE_INFO(RequirementWitness) - -// SPIRVAsm -SLANG_VALUE_TYPE_INFO(SPIRVAsmOperand) -SLANG_VALUE_TYPE_INFO(SPIRVAsmInst) - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index a7837edea..aad3bcc57 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -1,208 +1,1542 @@ // slang-serialize-ast.cpp #include "slang-serialize-ast.h" -#include "slang-ast-dump.h" -#include "slang-ast-support-types.h" -#include "slang-generated-ast-macro.h" -#include "slang-generated-ast.h" -#include "slang-serialize-ast-type-info.h" -#include "slang-serialize-factory.h" +#include "slang-ast-dispatch.h" +#include "slang-compiler.h" +#include "slang-diagnostics.h" +#include "slang-mangle.h" namespace Slang { +// TODO(tfoley): have the parser export this, or a utility function +// for initializing a `SyntaxDecl` in the common case. +// +NodeBase* parseSimpleSyntax(Parser* parser, void* userData); -// !!!!!!!!!!!!!!!!!!!!!! Generate fields for a type !!!!!!!!!!!!!!!!!!!!!!!!!!! -static const SerialClass* _addClass( - SerialClasses* serialClasses, - ASTNodeType type, - ASTNodeType super, - const List<SerialField>& fields) +struct ASTEncodingContext { - const SerialClass* superClass = - serialClasses->getSerialClass(SerialTypeKind::NodeBase, SerialSubType(super)); - return serialClasses->add( - SerialTypeKind::NodeBase, - SerialSubType(type), - fields.getBuffer(), - fields.getCount(), - superClass); -} +private: + Encoder* encoder; + struct UnhandledCase + { + }; + + typedef Int DeclID; + Dictionary<Decl*, DeclID> mapDeclToID; + List<Decl*> decls; + + struct ImportedDeclInfo + { + Int moduleIndex = -1; + Decl* decl; + }; + List<ImportedDeclInfo> importedDecls; -#define SLANG_AST_ADD_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - fields.add(SerialField::make(#FIELD_NAME, &obj->FIELD_NAME)); + typedef Int ValID; + Dictionary<Val*, ValID> mapValToID; + List<Val*> vals; -// Note that the obj point is not nullptr, because some compilers notice this is 'indexing from -// null' and warn/error. So we offset from 1. -#define SLANG_AST_ADD_SERIAL_CLASS(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - { \ - NAME* obj = SerialField::getPtr<NAME>(); \ - SLANG_UNUSED(obj); \ - fields.clear(); \ - SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_ADD_SERIAL_FIELD, param) \ - _addClass(serialClasses, ASTNodeType::NAME, ASTNodeType::SUPER, fields); \ + ModuleDecl* _module = nullptr; + + SerialSourceLocWriter* _sourceLocWriter = nullptr; + +public: + ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) + : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) + { } -struct ASTFieldAccess -{ - static void calcClasses(SerialClasses* serialClasses) + template<typename T> + void encodeASTNodeContent(T* node) { - // Add NodeBase first, and specially handle so that we add a null super class - serialClasses->add( - SerialTypeKind::NodeBase, - SerialSubType(ASTNodeType::NodeBase), - nullptr, - 0, - nullptr); + Encoder::WithObject withObject(encoder); - // Add the rest in order such that Super class is always added before its children - List<SerialField> fields; - SLANG_CHILDREN_ASTNode_NodeBase(SLANG_AST_ADD_SERIAL_CLASS, _) + ASTNodeDispatcher<T, void>::dispatch(node, [&](auto n) { _encodeDataOf(n); }); + } + + void flush() + { + auto containerChunk = encoder->getRIFFChunk(); + + RiffContainer::Chunk* declChunk = nullptr; + RiffContainer::Chunk* importedDeclChunk = nullptr; + RiffContainer::Chunk* valChunk = nullptr; + { + Encoder::WithArray withList(encoder); + declChunk = encoder->getRIFFChunk(); + } + { + Encoder::WithArray withList(encoder); + importedDeclChunk = encoder->getRIFFChunk(); + } + { + Encoder::WithArray withList(encoder); + valChunk = encoder->getRIFFChunk(); + } + Int declIndex = 0; + Int importedDeclIndex = 0; + Int valIndex = 0; + + bool done = false; + do + { + done = true; + while (declIndex < decls.getCount()) + { + done = false; + encoder->setRIFFChunk(declChunk); + encodeASTNodeContent(decls[declIndex++]); + } + while (importedDeclIndex < importedDecls.getCount()) + { + done = false; + encoder->setRIFFChunk(importedDeclChunk); + encodeImportedDecl(importedDecls[importedDeclIndex++]); + } + while (valIndex < vals.getCount()) + { + done = false; + encoder->setRIFFChunk(valChunk); + encodeASTNodeContent(vals[valIndex++]); + } + } while (!done); + + RiffContainer::calcAndSetSize(containerChunk); + encoder->setRIFFChunk(containerChunk); + } + + ModuleDecl* findModuleForDecl(Decl* decl) + { + for (auto d = decl; d; d = d->parentDecl) + { + if (auto m = as<ModuleDecl>(d)) + return m; + } + return nullptr; } -}; -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!! + ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) + { + auto declModule = findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; + } + + DeclID getDeclID(Decl* decl) + { + SLANG_ASSERT(decl != nullptr); + + if (auto found = mapDeclToID.tryGetValue(decl)) + return *found; + + // We need to detect whether the declaration is an + // imported one, or one from this module itself. + // + // Imported declarations need to be handled very + // differently, since they'll involve resolving + // references to those other modules, and the + // declarations within them. + // + if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) + { + DeclID importedFromModuleDeclID = 0; + if (decl != importedFromModule) + { + importedFromModuleDeclID = getDeclID(importedFromModule); + } + + DeclID id = ~importedDecls.getCount(); + mapDeclToID.add(decl, id); + + ImportedDeclInfo info; + info.moduleIndex = ~importedFromModuleDeclID; + info.decl = decl; + importedDecls.add(info); + + return id; + } + else + { + DeclID id = decls.getCount(); + decls.add(decl); + mapDeclToID.add(decl, id); + + return id; + } + } + + void encodePtr(Decl* decl) + { + DeclID id = getDeclID(decl); + encoder->encode(id); + } + + ValID getValID(Val* val) + { + SLANG_ASSERT(val != nullptr); + + if (auto found = mapValToID.tryGetValue(val)) + return *found; + + // In order to ensure that values can be fully constructed + // from the get-go (so that they will get cached correctly), + // we conspire to ensure that every value is preceded by + // all of its operands. + // + for (auto operand : val->m_operands) + { + switch (operand.kind) + { + default: + break; + + case ValNodeOperandKind::ValNode: + if (auto operandNode = operand.values.nodeOperand) + { + SLANG_ASSERT(as<Val>(operandNode)); + getValID(static_cast<Val*>(operandNode)); + } + break; + + case ValNodeOperandKind::ASTNode: + if (auto operandNode = operand.values.nodeOperand) + { + SLANG_ASSERT(as<Decl>(operandNode)); + getDeclID(static_cast<Decl*>(operandNode)); + } + break; + } + } + auto resolved = val->resolve(); + if (resolved != val) + { + getValID(resolved); + } + + ValID id = vals.getCount(); + vals.add(val); + mapValToID.add(val, id); + return id; + } + + void encodePtr(Val* val) + { + ValID id = getValID(val); + encoder->encode(id); + } + + void encodeImportedDecl(ImportedDeclInfo const& info) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(info.moduleIndex); + auto decl = info.decl; + if (auto importedModuleDecl = as<ModuleDecl>(decl)) + { + SLANG_ASSERT(info.moduleIndex == -1); + encode(importedModuleDecl->getName()); + } + else + { + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); + encode(mangledName); + } + } + + void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } + void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } + void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } + + void encodePtr(Name* name) { encode(name->text); } + + void encodePtr(MarkupEntry* entry) + { + // TODO: is this case needed? + SLANG_UNUSED(entry); + } + + void encodePtr(DeclAssociationList* list) + { + // We serialize this as if it were a simple list + // of key-value pairs because... well... that's + // what it amounts to in practice. + // + Encoder::WithArray withArray(encoder); + for (auto association : list->associations) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(association->kind); + encode(association->decl); + } + } + + void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } + + void encodePtr(WitnessTable* witnessTable) + { + Encoder::WithObject withObject(encoder); + encode(witnessTable->baseType); + encode(witnessTable->witnessedType); + encode(witnessTable->isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + encode(witnessTable->m_requirementDictionary); + } + + void encodeValue(RequirementWitness const& witness) + { + Encoder::WithKeyValuePair withPair(encoder); + encodeEnum(witness.m_flavor); + switch (witness.m_flavor) + { + case RequirementWitness::Flavor::none: + break; + + case RequirementWitness::Flavor::declRef: + encode(witness.m_declRef); + break; + + case RequirementWitness::Flavor::val: + encode(witness.m_val); + break; + + case RequirementWitness::Flavor::witnessTable: + encode((WitnessTable*)witness.m_obj.Ptr()); + break; + } + } + + void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } + + void encodePtr(DeclBase* declBase) + { + if (auto decl = as<Decl>(declBase)) + { + encodePtr(decl); + } + else + { + encodeASTNodeContent(declBase); + } + } + + void encodeValue(UnhandledCase); + + void encodeValue(String const& value) { encoder->encode(value); } + + void encodeValue(Token const& value) + { + encode(value.type); + encode(TokenFlags(value.flags & ~TokenFlag::Name)); + encode(value.loc); + if (value.hasContent()) + encoder->encodeString(value.getContent()); + else + encode(nullptr); + } + + void encodeValue(NameLoc const& value) { encode(value.name); } + + void encodeValue(SemanticVersion value) { encoder->encode(value.toInteger()); } + + void encodeValue(CapabilitySet const& value) + { + // While the `CapabilityTargetSets` type is a dictionary, + // in practice each entry already embeds its own key + // (the target atom), so we can encode this as just + // an array of the `CapabilityTargetSet` values. + // + Encoder::WithArray withArray(encoder); + for (auto pair : value.getCapabilityTargetSets()) + { + encode(pair.second); + } + } + + void encodeValue(CapabilityTargetSet const& value) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(value.target); + + // Similar to the case for the `CapabilityTargetSets` above, + // each `CapabilityStageSet` already includes the stage atom, + // so we can simply encode the values from the dictionary. + // + Encoder::WithArray withArray(encoder); + for (auto pair : value.shaderStageSets) + { + encode(pair.second); + } + } + + void encodeValue(CapabilityStageSet const& value) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(value.stage); + encode(value.atomSet); + } + + void encodeValue(CapabilityAtomSet const& value) + { + Encoder::WithArray withArray(encoder); + for (auto rawAtom : value) + { + encode(CapabilityAtom(rawAtom)); + } + } -/* static */ void ASTSerialUtil::addSerialClasses(SerialClasses* serialClasses) + template<typename T> + void encodeValue(std::optional<T> const& value) + { + if (value) + encodeValue(*value); + else + encoder->encode(nullptr); + } + + void encodeValue(SyntaxClass<NodeBase> const& value) { encode(value.getTag()); } + + template<typename T> + void encodeValue(DeclRef<T> const& value) + { + encode((DeclRefBase*)value); + } + + void encodeValue(ValNodeOperand value) + { + Encoder::WithKeyValuePair withPair(encoder); + + encodeEnum(value.kind); + switch (value.kind) + { + case ValNodeOperandKind::ConstantValue: + encode(value.values.intOperand); + break; + + case ValNodeOperandKind::ValNode: + encode(static_cast<Val*>(value.values.nodeOperand)); + break; + + case ValNodeOperandKind::ASTNode: + { + if (auto decl = as<Decl>(value.values.nodeOperand)) + { + encode(decl); + } + else + { + SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); + } + } + break; + } + } + + void encodeValue(TypeExp value) { encode(value.type); } + + void encodeValue(QualType value) + { + Encoder::WithObject withObject(encoder); + encode(value.type); + encode(value.isLeftValue); + encode(value.hasReadOnlyOnTarget); + encode(value.isWriteOnly); + } + + void encodeValue(MatrixCoord value) + { + Encoder::WithObject withObject(encoder); + encode(value.row); + encode(value.col); + } + + void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } + + void encodeValue(SPIRVAsmOperand const& value) + { + Encoder::WithObject withObject(encoder); + encode(value.flavor); + encode(value.token); + encode(value.expr); + encode(value.bitwiseOrWith); + encode(value.knownValue); + encode(value.wrapInId); + encode(value.type); + } + + void encodeValue(SPIRVAsmInst const& value) + { + Encoder::WithObject withObject(encoder); + encode(value.opcode); + encode(value.operands); + } + + + template<typename T, typename = std::enable_if_t<std::is_same_v<T, bool>>> + void encodeValue(T value) + { + encoder->encodeBool(value); + } + + void encodeValue(Int32 value) { encoder->encode(value); } + void encodeValue(UInt32 value) { encoder->encode(value); } + void encodeValue(Int64 value) { encoder->encode(value); } + void encodeValue(UInt64 value) { encoder->encode(value); } + void encodeValue(float value) { encoder->encode(value); } + void encodeValue(double value) { encoder->encode(value); } + + void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } + + void encodeValue(nullptr_t) { encoder->encode(nullptr); } + + template<typename T> + void encodeEnum(T value) + { + encoder->encode(Int32(value)); + } + + void encodeValue(DeclVisibility value) { encodeEnum(value); } + void encodeValue(BaseType value) { encodeEnum(value); } + void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } + void encodeValue(ASTNodeType value) { encodeEnum(value); } + void encodeValue(ImageFormat value) { encodeEnum(value); } + void encodeValue(TypeTag value) { encodeEnum(value); } + void encodeValue(TryClauseType value) { encodeEnum(value); } + void encodeValue(CapabilityAtom value) { encodeEnum(value); } + void encodeValue(DeclAssociationKind value) { encodeEnum(value); } + void encodeValue(TokenType value) { encodeEnum(value); } + + void encodeValue(SourceLoc value) + { + if (!_sourceLocWriter) + { + encoder->encode(nullptr); + } + else + { + auto intermediate = _sourceLocWriter->addSourceLoc(value); + encoder->encode(intermediate); + } + } + + template<typename T> + void encodeValue(T const* ptr) + { + if (!ptr) + { + encoder->encode(nullptr); + } + else + { + encodePtr(const_cast<T*>(ptr)); + } + } + + template<typename T> + void encodeValue(RefPtr<T> const& ptr) + { + if (!ptr) + { + encoder->encode(nullptr); + } + else + { + encodePtr(ptr.Ptr()); + } + } + + void encodeValue(Modifiers const& modifiers) + { + Encoder::WithArray withArray(encoder); + for (auto m : const_cast<Modifiers&>(modifiers)) + { + encode(m); + } + } + + template<typename T, int N> + void encodeValue(ShortList<T, N> const& array) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); + } + } + + + template<typename T> + void encode(List<T> const& array) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); + } + } + + template<typename T, size_t N> + void encode(T const (&array)[N]) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); + } + } + + template<typename K, typename V> + void encode(OrderedDictionary<K, V> const& dictionary) + { + Encoder::WithArray withArray(encoder); + for (auto p : dictionary) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(p.key); + encode(p.value); + } + } + + template<typename K, typename V> + void encode(Dictionary<K, V> const& dictionary) + { + Encoder::WithArray withArray(encoder); + for (auto p : dictionary) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(p.first); + encode(p.second); + } + } + + template<typename T> + void encode(T const& value) + { + encodeValue(value); + } + + // for each class of node, we generate + // code to recursively serialize each + // of its fields. + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + void _encodeDataOf($T* obj) + { +%if T.directSuperClass then + _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); +%end +%for _,f in ipairs(T.directFields) do + encode(obj->$f); +%end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END +}; + +void writeSerializedModuleAST( + Encoder* encoder, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) { - ASTFieldAccess::calcClasses(serialClasses); + Encoder::WithObject withObject(encoder); + + // TODO: we should have a more careful pass here, + // where we only encode the public declarations + // + + ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); + context.getDeclID(moduleDecl); + context.flush(); } -/* static */ SlangResult ASTSerialUtil::testSerialize( - NodeBase* node, - RootNamePool* rootNamePool, - SharedASTBuilder* sharedASTBuilder, - SourceManager* sourceManager) +struct ASTDecodingContext { - RefPtr<SerialClasses> classes; +public: + ASTDecodingContext( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* rootChunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) + : _linkage(linkage) + , _astBuilder(astBuilder) + , _sink(sink) + , _rootChunk(static_cast<RiffContainer::ListChunk*>(rootChunk)) + , _sourceLocReader(sourceLocReader) + , _requestingSourceLoc(requestingSourceLoc) + { + } + + Linkage* _linkage = nullptr; + DiagnosticSink* _sink = nullptr; + SerialSourceLocReader* _sourceLocReader = nullptr; + SourceLoc _requestingSourceLoc; + + SlangResult decodeAll() + { + auto cursor = _rootChunk->getFirstContainedChunk(); + + // There are a few different top-level chunks that + // hold different arrays that we need in order + // to decode the entire module hierarchy. + // + // Basically, these lists correspond to the kinds + // of nodes in the AST hierarchy for which back-references + // are allowed (all other nodes should, barring + // weird corner cases, form a single tree-structured + // ownership hierarchy, rooted at the `ModuleDecl`. + // + + // First there is the list that actually encodes + // for the declarations in the module, including + // the `ModuleDecl` itself, which should be the + // first entry in the list. + // + auto declChunk = cursor; + cursor = cursor->m_next; + + // Next there is a list of all the declarations + // referenced inside of the module that need to + // be imported in from outside. + // + auto importedDeclChunk = cursor; + cursor = cursor->m_next; + + // Then there are all the `Val`-derived nodes that + // are needed by the module, which will need to be + // deduplicated so that they are unique within the + // current compilation context. + // + auto valChunk = cursor; + cursor = cursor->m_next; + + // The process of decoding the module is then spread + // over a number of steps. + // + // The first step is to process all of the imported + // declarations, so that other nodes can refer to + // them. + // + SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); + + // Next we process the declarations that are within + // the module itself, first creating an "empty shell" + // of each declaration that has the right size in + // memory (and the right `ASTNodeType` tag), so that + // we can wire up references to it (including circular + // references)... so long as nothing here tries to + // look *inside* the empty shell along the way. + // + SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); + + // Once all the `Decl`s that might be needed have + // been allocated, we can process all the `Val`s + // that might reference those`Decl`s (and one another). + // + // The nature of the `Val` representation ensures + // that there cannot be cirularities in the references + // between `Val`s, and the encoding process will have + // sorted the entries so that a `Val` only ever appears + // *after* its operands. + // + SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); + + // Once all the back-reference-able objects have been + // instantiated in memory, we can go back through the + // `Decl`s in the module and fill in those empty shells. + // + SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); + + // As a final pass, we perform any special cleanup actions + // that might be required to make the output valid for consumers. + // + // For example, this is where we set the `DeclCheckState` of everything + // we are loading to reflect the fact that everything we deserialize + // is (supposed to be) fully cheked. + // + SLANG_RETURN_ON_FAIL(cleanUpNodes()); + + + return SLANG_OK; + } + + typedef Int DeclID; + Decl* getDeclByID(DeclID id) + { + if (id >= 0) + { + return _decls[id]; + } + else + { + return _importedDecls[~id]; + } + } + +private: + struct UnhandledCase + { + }; + + ASTBuilder* _astBuilder = nullptr; + RiffContainer::ListChunk* _rootChunk = nullptr; + + List<Decl*> _decls; + List<Decl*> _importedDecls; + List<Val*> _vals; + + typedef Int ValID; + Val* getValByID(ValID id) { return _vals[id]; } + + SlangResult decodeImportedDecls(RiffContainer::Chunk* importedDeclChunk) + { + Decoder decoder(importedDeclChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + Int moduleIndex; + decode(moduleIndex, decoder); + + if (moduleIndex == -1) + { + Name* moduleName = nullptr; + decode(moduleName, decoder); + + Decl* importedModule = getImportedModule(moduleName); + _importedDecls.add(importedModule); + } + else + { + auto importedFromModuleDecl = as<ModuleDecl>(_importedDecls[moduleIndex]); + auto importedFromModule = importedFromModuleDecl->module; + + String mangledName; + decode(mangledName, decoder); + + auto importedNode = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + auto importedDecl = as<Decl>(importedNode); + _importedDecls.add(importedDecl); + } + } + return SLANG_OK; + } + + ModuleDecl* getImportedModule(Name* moduleName) + { + Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) + { + SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); + } + + return module->getModuleDecl(); + } + + SlangResult decodeVals(RiffContainer::Chunk* valChunk) + { + Decoder decoder(valChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Val* val = decodeValNode(decoder); + _vals.add(val); + } + return SLANG_OK; + } - SerialClassesUtil::create(classes); + SlangResult createEmptyShells(RiffContainer::Chunk* declChunk) + { + Decoder decoder(declChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + ASTNodeType nodeType; + + // Each of the declarations is expected to take + // the form of an object with a first field + // that holds the node type. + // + { + Decoder::WithObject withObject(decoder); + decode(nodeType, decoder); + } + + auto emptyShell = createEmptyShell(nodeType); + auto declEmptyShell = as<Decl>(emptyShell); + _decls.add(declEmptyShell); + } - List<uint8_t> contents; + return SLANG_OK; + } + Val* decodeValNode(Decoder& decoder) { - OwnedMemoryStream stream(FileAccess::ReadWrite); + Decoder::WithObject withObject(decoder); - ModuleDecl* moduleDecl = as<ModuleDecl>(node); - // Only serialize out things *in* this module - ModuleSerialFilter filterStorage(moduleDecl); + ASTNodeType nodeType; + decode(nodeType, decoder); - SerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + ValNodeDesc desc; + desc.type = SyntaxClass<NodeBase>(nodeType); - SerialWriter writer(classes, filter); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + ValNodeOperand operand; + decode(operand, decoder); + desc.operands.add(operand); + } - // Lets serialize it all - writer.addPointer(node); - // Let's stick it all in a stream - writer.write(&stream); + desc.init(); - stream.swapContents(contents); + auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); - NamePool namePool; - namePool.setRootNamePool(rootNamePool); + // Values created during deserialization are + // not expected to ever resolve further, because + // they should be coming from fully checked code. + // + // val->resolve(); + // val->_setUnique(); - ASTBuilder builder(sharedASTBuilder, "Serialize Check"); + return val; + } - SetASTBuilderContextRAII astBuilderRAII(&builder); + NodeBase* createEmptyShell(ASTNodeType nodeType) + { + return SyntaxClass<NodeBase>(nodeType).createInstance(_astBuilder); + } - DefaultSerialObjectFactory objectFactory(&builder); + SlangResult fillEmptyShells(RiffContainer::Chunk* declChunk) + { + Index declIndex = 0; - // We could now check that the loaded data matches + Decoder decoder(declChunk); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + auto declEmptyShell = _decls[declIndex++]; + decodeASTNodeContent(declEmptyShell, decoder); + } + return SLANG_OK; + } + + SlangResult cleanUpNodes() + { + for (auto decl : _decls) { - const List<SerialInfo::Entry*>& writtenEntries = writer.getEntries(); - List<const SerialInfo::Entry*> readEntries; + decl->checkState = DeclCheckState::CapabilityChecked; + } - SlangResult res = SerialReader::loadEntries( - contents.getBuffer(), - contents.getCount(), - classes, - readEntries); - SLANG_UNUSED(res); + return SLANG_OK; + } - SLANG_ASSERT(writtenEntries.getCount() == readEntries.getCount()); - // They should be identical up to the - for (Index i = 1; i < readEntries.getCount(); ++i) + void assignGenericParameterIndices(GenericDecl* genericDecl) + { + int parameterCounter = 0; + for (auto m : genericDecl->members) + { + if (auto typeParam = as<GenericTypeParamDeclBase>(m)) + { + typeParam->parameterIndex = parameterCounter++; + } + else if (auto valParam = as<GenericValueParamDecl>(m)) { - auto writtenEntry = writtenEntries[i]; - auto readEntry = readEntries[i]; + valParam->parameterIndex = parameterCounter++; + } + } + } + - const size_t writtenSize = writtenEntry->calcSize(classes); - const size_t readSize = readEntry->calcSize(classes); - SLANG_UNUSED(writtenSize); - SLANG_UNUSED(readSize); + void cleanUpASTNode(NodeBase* node) + { + if (auto expr = as<Expr>(node)) + { + expr->checked = true; + } + else if (auto genericDecl = as<GenericDecl>(node)) + { + assignGenericParameterIndices(genericDecl); + } + else if (auto syntaxDecl = as<SyntaxDecl>(node)) + { + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); + } + else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) + { + auto declScope = _astBuilder->create<Scope>(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; + } + } + + void decodeASTNodeContent(NodeBase* node, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); - SLANG_ASSERT(readSize == writtenSize); - // Check the payload is the same - SLANG_ASSERT(memcmp(readEntry, writtenEntry, readSize) == 0); + ASTNodeDispatcher<NodeBase, void>::dispatch( + node, + [&](auto n) { _decodeDataOf(n, decoder); }); + + cleanUpASTNode(node); + } + + DeclID decodeDeclID(Decoder& decoder) + { + DeclID result = decoder.decode<DeclID>(); + return result; + } + + ValID decodeValID(Decoder& decoder) + { + ValID result = decoder.decode<ValID>(); + return result; + } + + template<typename T> + void decodeASTNode(T*& node, Decoder& decoder) + { + ASTNodeType nodeType; + auto saved = decoder.getCursor(); + { + Decoder::WithObject withObject(decoder); + decode(nodeType, decoder); + } + decoder.setCursor(saved); + + auto shell = createEmptyShell(nodeType); + decodeASTNodeContent(shell, decoder); + + node = as<T>(shell); + } + + void decodePtr(Name*& name, Decoder& decoder, Name*) + { + String text; + decode(text, decoder); + + name = _astBuilder->getNamePool()->getName(text); + } + + void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) + { + // Mirroring the encoding logic, we decode this + // as a list of key-value pairs. + // + auto list = RefPtr(new DeclAssociationList()); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + auto association = RefPtr(new DeclAssociation()); + + Decoder::WithKeyValuePair withPair(decoder); + decode(association->kind, decoder); + decode(association->decl, decoder); + + list->associations.add(association); + } + + outList = list.detach(); + } + + void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) + { + Int id; + decode(id, decoder); + info = getDiagnosticsLookup()->getDiagnosticById(id); + } + + void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) + { + // TODO: is this case needed? + markupEntry = nullptr; + } + + void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) + { + auto result = RefPtr(new CandidateExtensionList()); + decode(result->candidateExtensions, decoder); + list = result.detach(); + } + + void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) + { + Decoder::WithObject withObject(decoder); + auto wt = RefPtr(new WitnessTable()); + decode(wt->baseType, decoder); + decode(wt->witnessedType, decoder); + decode(wt->isExtern, decoder); + decode(wt->m_requirementDictionary, decoder); + witnessTable = wt.detach(); + } + + void decodeValue(RequirementWitness& witness, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decodeEnum(witness.m_flavor, decoder); + switch (witness.m_flavor) + { + case RequirementWitness::Flavor::none: + break; + + case RequirementWitness::Flavor::declRef: + decode(witness.m_declRef, decoder); + break; + + case RequirementWitness::Flavor::val: + decode(witness.m_val, decoder); + break; + + case RequirementWitness::Flavor::witnessTable: + { + RefPtr<WitnessTable> object; + decode(object, decoder); + witness.m_obj = object; } + break; } + } - SerialReader reader(classes, nullptr); + template<typename T> + void decodePtr(T*& node, Decoder& decoder, Val*) + { + ValID id = decodeValID(decoder); + node = static_cast<T*>(getValByID(id)); + } + + template<typename T> + void decodePtr(T*& node, Decoder& decoder, Decl*) + { + DeclID id = decodeDeclID(decoder); + node = static_cast<T*>(getDeclByID(id)); + } + + template<typename T> + void decodePtr(T*& node, Decoder& decoder, DeclBase*) + { + if (decoder.getTag() == SerialBinary::kInt64FourCC) + { + DeclID id = decodeDeclID(decoder); + node = static_cast<T*>(getDeclByID(id)); + } + else { + decodeASTNode(node, decoder); + } + } + + template<typename T> + void decodePtr(T*& node, Decoder& decoder, NodeBase*) + { + decodeASTNode(node, decoder); + } + + + void decodeValue(UnhandledCase, Decoder& decoder); - SlangResult res = reader.load(contents.getBuffer(), contents.getCount(), &namePool); - SLANG_UNUSED(res); + void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } + + void decodeValue(Token& value, Decoder& decoder) + { + decode(value.type, decoder); + decode(value.flags, decoder); + decode(value.loc, decoder); + if (decoder.decodeNull()) + { } + else + { + Name* name = nullptr; + decode(name, decoder); + value.setName(name); + } + } + + void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } - // Lets see what we have - const ASTDumpUtil::Flags dumpFlags = - ASTDumpUtil::Flag::HideSourceLoc | ASTDumpUtil::Flag::HideScope; + void decodeValue(SemanticVersion& value, Decoder& decoder) + { + SemanticVersion::IntegerType rawValue = decoder.decode<SemanticVersion::IntegerType>(); + value.setFromInteger(rawValue); + } - String readDump; + void decodeValue(CapabilitySet& value, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) { - SourceWriter sourceWriter(sourceManager, LineDirectiveMode::None, nullptr); - ASTDumpUtil::dump( - reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>(), - ASTDumpUtil::Style::Hierachical, - dumpFlags, - &sourceWriter); - readDump = sourceWriter.getContentAndClear(); + CapabilityTargetSet targetSet; + decode(targetSet, decoder); + value.getCapabilityTargetSets()[targetSet.target] = targetSet; } - String origDump; + } + + void decodeValue(CapabilityTargetSet& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decode(value.target, decoder); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) { - SourceWriter sourceWriter(sourceManager, LineDirectiveMode::None, nullptr); - ASTDumpUtil::dump(node, ASTDumpUtil::Style::Hierachical, dumpFlags, &sourceWriter); - origDump = sourceWriter.getContentAndClear(); + CapabilityStageSet stageSet; + decode(stageSet, decoder); + value.shaderStageSets[stageSet.stage] = stageSet; } + } - // Write out - File::writeAllText("ast-read.ast-dump", readDump); - File::writeAllText("ast-orig.ast-dump", origDump); + void decodeValue(CapabilityStageSet& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decode(value.stage, decoder); + decode(value.atomSet, decoder); + } - if (readDump != origDump) + void decodeValue(CapabilityAtomSet& value, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) { - return SLANG_FAIL; + CapabilityAtom atom; + decode(atom, decoder); + value.add(UInt(atom)); } } - return SLANG_OK; -} + template<typename T> + void decodeValue(std::optional<T>& outValue, Decoder& decoder) + { + if (decoder.decodeNull()) + { + outValue.reset(); + } + else + { + T value; + decode(value, decoder); + outValue = value; + } + } -/* static */ List<uint8_t> ASTSerialUtil::serializeAST(ModuleDecl* moduleDecl) -{ - // TODO: we should store `classes` in GlobalSession to avoid recomputing them every time. - RefPtr<SerialClasses> classes; - SerialClassesUtil::create(classes); + void decodeValue(SyntaxClass<NodeBase>& syntaxClass, Decoder& decoder) + { + ASTNodeType nodeType; + decode(nodeType, decoder); + syntaxClass = SyntaxClass<NodeBase>(nodeType); + } - List<uint8_t> contents; - OwnedMemoryStream stream(FileAccess::ReadWrite); + template<typename T> + void decodeValue(DeclRef<T>& declRef, Decoder& decoder) + { + decode(declRef.declRefBase, decoder); + } - // Only serialize out things *in* this module - ModuleSerialFilter filterStorage(moduleDecl); + void decodeValue(ValNodeOperand& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); - SerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + decodeEnum(value.kind, decoder); + switch (value.kind) + { + case ValNodeOperandKind::ConstantValue: + decode(value.values.intOperand, decoder); + break; - SerialWriter writer(classes, filter); + case ValNodeOperandKind::ValNode: + { + Val* val = nullptr; + decode(val, decoder); + value.values.nodeOperand = val; + } + break; - // Lets serialize it all - writer.addPointer(moduleDecl); - // Let's stick it all in a stream - writer.write(&stream); + case ValNodeOperandKind::ASTNode: + { + Decl* decl = nullptr; + decode(decl, decoder); + value.values.nodeOperand = decl; + } + break; + } + } + + void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + + void decodeValue(QualType& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.type, decoder); + decode(value.isLeftValue, decoder); + decode(value.hasReadOnlyOnTarget, decoder); + decode(value.isWriteOnly, decoder); + } + + void decodeValue(MatrixCoord& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.row, decoder); + decode(value.col, decoder); + } + + void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + + void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.flavor, decoder); + decode(value.token, decoder); + decode(value.expr, decoder); + decode(value.bitwiseOrWith, decoder); + decode(value.knownValue, decoder); + decode(value.wrapInId, decoder); + decode(value.type, decoder); + } + + void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.opcode, decoder); + decode(value.operands, decoder); + } + + + template<typename T> + void decodeEnum(T& value, Decoder& decoder) + { + value = T(decoder.decode<Int32>()); + } + + template<typename T> + void decodeSimpleValue(T& value, Decoder& decoder) + { + value = decoder.decode<T>(); + } + + void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } + void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + + void decodeValue(uint8_t& value, Decoder& decoder) + { + value = uint8_t(decoder.decode<UInt32>()); + } + + void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } + + + void decodeValue(SourceLoc& value, Decoder& decoder) + { + if (!decoder.decodeNull()) + { + SerialSourceLocData::SourceLoc intermediate; + decoder.decode(intermediate); + + if (_sourceLocReader) + { + auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); + value = sourceLoc; + } + } + } + + template<typename T> + void decodeValue(T*& ptr, Decoder& decoder) + { + if (decoder.decodeNull()) + ptr = nullptr; + else + decodePtr(ptr, decoder, (T*)nullptr); + } + + template<typename T> + void decodeValue(RefPtr<T>& ptr, Decoder& decoder) + { + if (decoder.decodeNull()) + ptr = nullptr; + else + { + // Hi Future Tess, + // + // The next step here is decoding logic for `WitnessTable`s. + // + + decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); + } + } + + void decodeValue(Modifiers& modifiers, Decoder& decoder) + { + Modifier** link = &modifiers.first; + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Modifier* modifier = nullptr; + decode(modifier, decoder); + + *link = modifier; + link = &modifier->next; + } + } + + template<typename T, int N> + void decodeValue(ShortList<T, N>& array, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + T element; + decode(element, decoder); + array.add(element); + } + } - stream.swapContents(contents); - return contents; -} + template<typename T> + void decode(List<T>& array, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + T element; + decode(element, decoder); + array.add(element); + } + } + + template<typename T, size_t N> + void decode(T (&array)[N], Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + for (auto& element : array) + { + decode(element, decoder); + } + } + + template<typename K, typename V> + void decode(OrderedDictionary<K, V>& dictionary, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + K key; + V value; + decode(key, decoder); + decode(value, decoder); + + dictionary.add(key, value); + } + } + + template<typename K, typename V> + void decode(Dictionary<K, V>& dictionary, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + K key; + V value; + decode(key, decoder); + decode(value, decoder); + + dictionary.add(key, value); + } + } + + template<typename T> + void decode(T& outValue, Decoder& decoder) + { + decodeValue(outValue, decoder); + } + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + void _decodeDataOf($T* obj, Decoder& decoder) + { +% if T.directSuperClass then + _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); +% end +% for _,f in ipairs(T.directFields) do + decode(obj->$f, decoder); +% end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END +}; + +ModuleDecl* readSerializedModuleAST( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* chunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) +{ + ASTDecodingContext + context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); + context.decodeAll(); + auto node = context.getDeclByID(0); + auto moduleDecl = as<ModuleDecl>(node); + return moduleDecl; +} } // namespace Slang diff --git a/source/slang/slang-serialize-ast.h b/source/slang/slang-serialize-ast.h index b8af9484c..6adeae8dd 100644 --- a/source/slang/slang-serialize-ast.h +++ b/source/slang/slang-serialize-ast.h @@ -6,53 +6,23 @@ #include "slang-ast-all.h" #include "slang-ast-builder.h" #include "slang-ast-support-types.h" +#include "slang-serialize-source-loc.h" #include "slang-serialize.h" namespace Slang { - -/* Holds RIFF FourCC codes for AST types */ -struct ASTSerialBinary -{ - static const FourCC kRiffFourCC = RiffFourCC::kRiff; - - /// AST module LIST container - static const FourCC kSlangASTModuleFourCC = SLANG_FOUR_CC('S', 'A', 'm', 'l'); - /// AST module data - static const FourCC kSlangASTModuleDataFourCC = SLANG_FOUR_CC('S', 'A', 'm', 'd'); -}; - -class ModuleSerialFilter : public SerialFilter -{ -public: - // SerialFilter impl - virtual SerialIndex writePointer(SerialWriter* writer, const NodeBase* ptr) SLANG_OVERRIDE; - virtual SerialIndex writePointer(SerialWriter* writer, const RefObject* ptr) SLANG_OVERRIDE; - - ModuleSerialFilter(ModuleDecl* moduleDecl) - : m_moduleDecl(moduleDecl) - { - } - -protected: - ModuleDecl* m_moduleDecl; -}; - -struct ASTSerialUtil -{ - /// Add the AST related classes - static void addSerialClasses(SerialClasses* classes); - - /// Tries to serialize out, read back in and test the results are the same. - /// Will write dumped out node to files - static SlangResult testSerialize( - NodeBase* node, - RootNamePool* rootNamePool, - SharedASTBuilder* sharedASTBuilder, - SourceManager* sourceManager); - - static List<uint8_t> serializeAST(ModuleDecl* moduleDecl); -}; +void writeSerializedModuleAST( + Encoder* encoder, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter); + +ModuleDecl* readSerializedModuleAST( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* chunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc); } // namespace Slang diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index f82357459..c2253ed45 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -10,89 +10,239 @@ #include "slang-mangled-lexer.h" #include "slang-parser.h" #include "slang-serialize-ast.h" -#include "slang-serialize-factory.h" #include "slang-serialize-ir.h" #include "slang-serialize-source-loc.h" namespace Slang { - -/* static */ SlangResult SerialContainerUtil::write( - Module* module, - const WriteOptions& options, - Stream* stream) +struct ModuleEncodingContext { - RiffContainer container; +public: + ModuleEncodingContext(SerialContainerUtil::WriteOptions const& options, Stream* stream) + : options(options), encoder(stream), containerStringPool(StringSlicePool::Style::Default) { - SerialContainerData data; - SLANG_RETURN_ON_FAIL(SerialContainerUtil::addModuleToData(module, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); + if (options.optionFlags & SerialOptionFlag::SourceLocation) + { + sourceLocWriter = new SerialSourceLocWriter(options.sourceManager); + } } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::write( - FrontEndCompileRequest* frontEndReq, - const WriteOptions& options, - Stream* stream) -{ - RiffContainer container; + ~ModuleEncodingContext() + { + encoder.setRIFFChunk(encoder.getRIFF()->getRoot()); + encodeFinalPieces(); + } + + SlangResult encodeModuleList(FrontEndCompileRequest* frontEndReq) + { + // Encoding a front-end compile request into a RIFF + // is simply a matter of encoding the module for each + // of the translation units that got compiled. + // + Encoder::WithKeyValuePair withArray(&encoder, SerialBinary::kModuleListFourCc); + for (TranslationUnitRequest* translationUnit : frontEndReq->translationUnits) + { + SLANG_RETURN_ON_FAIL(encode(translationUnit->module)); + } + return SLANG_OK; + } + + SlangResult encode(FrontEndCompileRequest* frontEndReq) + { + Encoder::WithObject withObject(&encoder, SerialBinary::kContainerFourCc); + SLANG_RETURN_ON_FAIL(encodeModuleList(frontEndReq)); + return SLANG_OK; + } + + SlangResult encode(EndToEndCompileRequest* request) + { + Encoder::WithObject withObject(&encoder, SerialBinary::kContainerFourCc); + + // Encoding an end-to-end compile request starts with the same + // work as for a front-end request: we encode each of + // the modules for the translation units. + // + SLANG_RETURN_ON_FAIL(encodeModuleList(request->getFrontEndReq())); + // + // If code generation is disabled, then we can skip all further + // steps, and the encoding process is no different + // than for a front-end request. + // + if (request->getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) + { + return SLANG_OK; + } + + // If code generation is enabled, then we need to encode + // information on each of the code generation targets, as well + // as the entry points. + // + // We start with the targets, each of which will have a Slang IR + // representation of the layout information for the program + // on that target. + // + auto linkage = request->getLinkage(); + auto sink = request->getSink(); + auto program = request->getSpecializedGlobalAndEntryPointsComponentType(); + { + Encoder::WithArray withArray(&encoder); // kContainerFourCc + + for (auto target : linkage->targets) + { + auto targetProgram = program->getTargetProgram(target); + encode(targetProgram, sink); + } + } + + // The compiled `program` may also have zero or more entry points, + // and we need to encode information about each of them. + // + { + Encoder::WithArray withArray(&encoder, SerialBinary::kEntryPointListFourCc); + + auto entryPointCount = program->getEntryPointCount(); + for (Index ii = 0; ii < entryPointCount; ++ii) + { + auto entryPoint = program->getEntryPoint(ii); + auto entryPointMangledName = program->getEntryPointMangledName(ii); + encode(entryPoint, entryPointMangledName); + } + } + + return SLANG_OK; + } + + SlangResult encode(TargetProgram* targetProgram, DiagnosticSink* sink) { - SerialContainerData data; + // TODO: + // Serialization of target component IR is causing the embedded precompiled binary + // feature to fail. The resulting data modules contain both TU IR and TC IR, with only + // one module header. Yong suggested to ignore the TC IR for now, though also that + // OV was using the feature, so disabling this might cause problems. + + IRModule* irModule = targetProgram->getOrCreateIRModuleForLayout(sink); + + // Okay, we need to serialize this target program and its IR too... + IRSerialData serialData; + IRSerialWriter writer; + SLANG_RETURN_ON_FAIL( - SerialContainerUtil::addFrontEndRequestToData(frontEndReq, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); + writer.write(irModule, sourceLocWriter, options.optionFlags, &serialData)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, encoder.getRIFF())); + + return SLANG_OK; } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::write( - EndToEndCompileRequest* request, - const WriteOptions& options, - Stream* stream) -{ - RiffContainer container; + void encode(Name* name) { encoder.encode(name->text); } + + void encode(String const& value) { encoder.encode(value); } + + void encode(uint32_t value) { encoder.encode(UInt(value)); } + + void encodeData(void const* data, size_t size) { encoder.encodeData(data, size); } + + SlangResult encode(EntryPoint* entryPoint, String const& entryPointMangledName) { - SerialContainerData data; - SLANG_RETURN_ON_FAIL(SerialContainerUtil::addEndToEndRequestToData(request, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); + Encoder::WithObject withObject(&encoder, SerialBinary::kEntryPointFourCc); + + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kNameFourCC); + encode(entryPoint->getName()); + } + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kProfileFourCC); + encode(entryPoint->getProfile().raw); + } + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kMangledNameFourCC); + encode(entryPointMangledName); + } + + return SLANG_OK; } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::addModuleToData( - Module* module, - const WriteOptions& options, - SerialContainerData& outData) -{ - if (options.optionFlags & (SerialOptionFlag::ASTModule | SerialOptionFlag::IRModule)) + + SlangResult encode(Module* module) { - SerialContainerData::Module dstModule; + if (!(options.optionFlags & (SerialOptionFlag::IRModule | SerialOptionFlag::ASTModule))) + return SLANG_OK; - // NOTE: The astBuilder is not set here, as not needed to be scoped for serialization (it is - // assumed the TranslationUnitRequest stays in scope) + Encoder::WithObject withModule(&encoder, SerialBinary::kModuleFourCC); - if (options.optionFlags & SerialOptionFlag::ASTModule) + // The first piece that we write for a module is its header. + // The header is intended to provide information that can be + // used to determine if a precompiled module is up-to-date. + // + // Update(tfoley): Okay, let's skip the whole header idea and just + // serialize these things as properties of the module itself... { - // Root AST node - auto moduleDecl = module->getModuleDecl(); - SLANG_ASSERT(moduleDecl); + // So many things need the module name, that it makes + // sense to serialize it separately from all the rest. + // + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kNameFourCC); + encoder.encodeString(module->getNameObj()->text); + } + + // The header includes a digest of all the compile options and + // the files that the compiled result depended on. + // + auto digest = module->computeDigest(); + encoder.encodeData(PropertyKeys<Module>::Digest, digest.data, sizeof(digest.data)); - dstModule.astRootNode = moduleDecl; + // The header includes an array of the paths of all of the + // files that the compiled result depended on. + // + encodeModuleDependencyPaths(module); } - if (options.optionFlags & SerialOptionFlag::IRModule) + + // If serialization of Slang IR modules is enabled, and there + // is IR available for this module, then we we encode it. + // + if ((options.optionFlags & SerialOptionFlag::IRModule)) { - // IR module - dstModule.irModule = module->getIRModule(); - SLANG_ASSERT(dstModule.irModule); + if (auto irModule = module->getIRModule()) + { + Encoder::WithKeyValuePair withKey(&encoder, PropertyKeys<Module>::IRModule); + + IRSerialData serialData; + IRSerialWriter writer; + SLANG_RETURN_ON_FAIL( + writer.write(irModule, sourceLocWriter, options.optionFlags, &serialData)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, encoder.getRIFF())); + } } + // If serialization of AST information is enabled, and we have AST + // information available, then we serialize it here. + // + if (options.optionFlags & SerialOptionFlag::ASTModule) + { + if (auto moduleDecl = module->getModuleDecl()) + { + Encoder::WithKeyValuePair withKey(&encoder, PropertyKeys<Module>::ASTModule); + + writeSerializedModuleAST(&encoder, moduleDecl, sourceLocWriter); + } + } + + return SLANG_OK; + } + + SlangResult encodeModuleDependencyPaths(Module* module) + { + Encoder::WithObject withProperty(&encoder, PropertyKeys<Module>::FileDependencies); + + // TODO(tfoley): This is some of the most complicated logic + // in the encoding system, because it tries to translate + // the file dependency paths into something that isn't + // specific to the machine on which a module was built. + // + // The comments that follow are from the original implementation + // of this logic, because I cannot state with confidence + // that I know what's happening in all of this. + + // Here we assume that the first file in the file dependencies is the module's file path. // We store the module's file path as a relative path with respect to the first search // directory that contains the module, and store the paths of dependent files as relative @@ -155,6 +305,7 @@ namespace Slang } Path::getCanonical(linkageRoot, linkageRoot); + Encoder::WithArray withArray(&encoder); for (auto file : fileDependencies) { if (file->getPathInfo().hasFoundPath()) @@ -170,728 +321,314 @@ namespace Slang { auto relativeModulePath = Path::getRelativePath(linkageRoot, canonicalModulePath); - dstModule.dependentFiles.add(relativeModulePath); + + encoder.encodeString(relativeModulePath); } else { // For all other dependnet files, store them as relative paths with respect // to the module's path. canonicalFilePath = Path::getRelativePath(moduleDir, canonicalFilePath); - dstModule.dependentFiles.add(canonicalFilePath); + encoder.encodeString(canonicalFilePath); } } else { // If the module is coming from string instead of an actual file, store it as // is. - dstModule.dependentFiles.add(canonicalModulePath); + encoder.encodeString(canonicalModulePath); } } else { - dstModule.dependentFiles.add(file->getPathInfo().getMostUniqueIdentity()); + encoder.encodeString(file->getPathInfo().getMostUniqueIdentity()); } } - dstModule.digest = module->computeDigest(); - outData.modules.add(dstModule); - } - return SLANG_OK; -} - -/* static */ SlangResult SerialContainerUtil::addFrontEndRequestToData( - FrontEndCompileRequest* frontEndReq, - const WriteOptions& options, - SerialContainerData& outData) -{ - // Go through translation units, adding modules - for (TranslationUnitRequest* translationUnit : frontEndReq->translationUnits) - { - SLANG_RETURN_ON_FAIL(addModuleToData(translationUnit->module, options, outData)); - } - - return SLANG_OK; -} - -/* static */ SlangResult SerialContainerUtil::addEndToEndRequestToData( - EndToEndCompileRequest* request, - const WriteOptions& options, - SerialContainerData& out) -{ - auto linkage = request->getLinkage(); - auto sink = request->getSink(); - - // Output the parsed modules. - addFrontEndRequestToData(request->getFrontEndReq(), options, out); - - // If we are skipping code generation, then we are done. - if (request->getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) - { return SLANG_OK; } - // - auto program = request->getSpecializedGlobalAndEntryPointsComponentType(); - // Add all the target modules + SlangResult encodeFinalPieces() { - for (auto target : linkage->targets) + // We can now output the debug information. This is for all IR and AST + if (sourceLocWriter) { - auto targetProgram = program->getTargetProgram(target); - auto irModule = targetProgram->getOrCreateIRModuleForLayout(sink); - - SerialContainerData::TargetComponent targetComponent; + // Write out the debug info + SerialSourceLocData debugData; + sourceLocWriter->write(&debugData); - targetComponent.irModule = irModule; - - auto& dstTarget = targetComponent.target; - - dstTarget.floatingPointMode = target->getOptionSet().getFloatingPointMode(); - dstTarget.profile = target->getOptionSet().getProfile(); - dstTarget.flags = target->getOptionSet().getTargetFlags(); - dstTarget.codeGenTarget = target->getTarget(); - - out.targetComponents.add(targetComponent); + debugData.writeContainer(encoder.getRIFF()); } - } - // Entry points - { - auto entryPointCount = program->getEntryPointCount(); - for (Index ii = 0; ii < entryPointCount; ++ii) + // Write the container string table + if (containerStringPool.getAdded().getCount() > 0) { - auto entryPoint = program->getEntryPoint(ii); - auto entryPointMangledName = program->getEntryPointMangledName(ii); - - SerialContainerData::EntryPoint dstEntryPoint; + Encoder::WithKeyValuePair withKey(&encoder, SerialBinary::kStringTableFourCc); - dstEntryPoint.name = entryPoint->getName(); - dstEntryPoint.mangledName = entryPointMangledName; - dstEntryPoint.profile = entryPoint->getProfile(); + List<char> encodedTable; + SerialStringTableUtil::encodeStringTable(containerStringPool, encodedTable); - out.entryPoints.add(dstEntryPoint); + encoder.encodeData(encodedTable.getBuffer(), encodedTable.getCount()); } + + return SLANG_OK; } - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::write( - const SerialContainerData& data, - const WriteOptions& options, - RiffContainer* container) -{ +private: + SerialContainerUtil::WriteOptions const& options; RefPtr<SerialSourceLocWriter> sourceLocWriter; // The string pool used across the whole of the container - StringSlicePool containerStringPool(StringSlicePool::Style::Default); + StringSlicePool containerStringPool; - RiffContainer::ScopeChunk scopeModule( - container, - RiffContainer::Chunk::Kind::List, - SerialBinary::kContainerFourCc); + Encoder encoder; +}; - if (data.modules.getCount() && - (options.optionFlags & (SerialOptionFlag::IRModule | SerialOptionFlag::ASTModule))) - { - // Module list - RiffContainer::ScopeChunk moduleListScope( - container, - RiffContainer::Chunk::Kind::List, - SerialBinary::kModuleListFourCc); - - if (options.optionFlags & SerialOptionFlag::SourceLocation) - { - sourceLocWriter = new SerialSourceLocWriter(options.sourceManager); - } - - RefPtr<SerialClasses> serialClasses; - - for (const auto& module : data.modules) - { - // Okay, we need to serialize this module to our container file. - // We currently don't serialize it's name..., but support for that could be added. +// +// To serialize a module (or compile request) to a stream, we first +// construct a RIFF container from it, and then serialize that +// container out to a byte stream. +// - // First, we write a header that can be used to verify if the precompiled module is - // up-to-date. The header has: 1) a digest of all compile options and dependent source - // files. 2) a list of source file paths. - // - { - RiffContainer::ScopeChunk scopeHeader( - container, - RiffContainer::Chunk::Kind::Data, - SerialBinary::kModuleHeaderFourCc); - OwnedMemoryStream headerMemStream(FileAccess::Write); - StringBuilder filePathsSB; - for (auto fileDependency : module.dependentFiles) - filePathsSB << fileDependency << "\n"; - headerMemStream.write(module.digest.data, sizeof(module.digest.data)); - uint32_t fileListLength = (uint32_t)filePathsSB.getLength(); - headerMemStream.write(&fileListLength, sizeof(uint32_t)); - headerMemStream.write(filePathsSB.getBuffer(), fileListLength); - container->write( - headerMemStream.getContents().getBuffer(), - headerMemStream.getContents().getCount()); - } - - // Write the IR information - if ((options.optionFlags & SerialOptionFlag::IRModule) && module.irModule) - { - IRSerialData serialData; - IRSerialWriter writer; - SLANG_RETURN_ON_FAIL(writer.write( - module.irModule, - sourceLocWriter, - options.optionFlags, - &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, container)); - } - - // Write the AST information - - if (options.optionFlags & SerialOptionFlag::ASTModule) - { - if (ModuleDecl* moduleDecl = as<ModuleDecl>(module.astRootNode)) - { - // Put in AST module - RiffContainer::ScopeChunk scopeASTModule( - container, - RiffContainer::Chunk::Kind::List, - ASTSerialBinary::kSlangASTModuleFourCC); - - if (!serialClasses) - { - SLANG_RETURN_ON_FAIL(SerialClassesUtil::create(serialClasses)); - } - - ModuleSerialFilter filter(moduleDecl); - auto astWriterFlag = SerialWriter::Flag::ZeroInitialize; - if ((options.optionFlags & SerialOptionFlag::ASTFunctionBody) == 0) - astWriterFlag = (SerialWriter::Flag::Enum)( - astWriterFlag | SerialWriter::Flag::SkipFunctionBody); - - SerialWriter writer(serialClasses, &filter, astWriterFlag); - - writer.getExtraObjects().set(sourceLocWriter); - - // Add the module and everything that isn't filtered out in the filter. - writer.addPointer(moduleDecl); +/* static */ SlangResult SerialContainerUtil::write( + Module* module, + const WriteOptions& options, + Stream* stream) +{ + ModuleEncodingContext context(options, stream); + SLANG_RETURN_ON_FAIL(context.encode(module)); + return SLANG_OK; +} +/* static */ SlangResult SerialContainerUtil::write( + FrontEndCompileRequest* request, + const WriteOptions& options, + Stream* stream) +{ + ModuleEncodingContext context(options, stream); + SLANG_RETURN_ON_FAIL(context.encode(request)); + return SLANG_OK; +} - // We can now serialize it into the riff container. - SLANG_RETURN_ON_FAIL(writer.writeIntoContainer( - ASTSerialBinary::kSlangASTModuleDataFourCC, - container)); - } - } - } +/* static */ SlangResult SerialContainerUtil::write( + EndToEndCompileRequest* request, + const WriteOptions& options, + Stream* stream) +{ + ModuleEncodingContext context(options, stream); + SLANG_RETURN_ON_FAIL(context.encode(request)); + return SLANG_OK; +} - // TODO: - // Serialization of target component IR is causing the embedded precompiled binary - // feature to fail. The resulting data modules contain both TU IR and TC IR, with only - // one module header. Yong suggested to ignore the TC IR for now, though also that - // OV was using the feature, so disabling this might cause problems. -#if 0 - if (data.targetComponents.getCount() && (options.optionFlags & SerialOptionFlag::IRModule)) - { - // TODO: in the case where we have specialization, we might need - // to serialize IR related to `program`... +String StringChunkRef::getValue() +{ + return Decoder(ptr()).decodeString(); +} - for (const auto& targetComponent : data.targetComponents) - { - IRModule* irModule = targetComponent.irModule; +ChunkRefList<StringChunkRef> ModuleChunkRef::getFileDependencies() +{ + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, PropertyKeys<Module>::FileDependencies); + return ChunkRefList<StringChunkRef>(as<RiffContainer::ListChunk>(decoder.getCursor())); +} - // Okay, we need to serialize this target program and its IR too... - IRSerialData serialData; - IRSerialWriter writer; +ModuleChunkRef ModuleChunkRef::find(RiffContainer* container) +{ + auto found = container->getRoot()->findListRec(SerialBinary::kModuleFourCC); + return ModuleChunkRef(found); +} - SLANG_RETURN_ON_FAIL(writer.write(irModule, sourceLocWriter, options.optionFlags, &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, options.compressionType, container)); - } - } -#endif +SHA1::Digest ModuleChunkRef::getDigest() +{ + auto foundChunk = + static_cast<RiffContainer::DataChunk*>(ptr()->findContained(PropertyKeys<Module>::Digest)); + if (!foundChunk) + { + SLANG_UNEXPECTED("module chunk had no digest"); } - - if (data.entryPoints.getCount()) + if (foundChunk->calcPayloadSize() != sizeof(SHA1::Digest)) { - for (const auto& entryPoint : data.entryPoints) - { - RiffContainer::ScopeChunk entryPointScope( - container, - RiffContainer::Chunk::Kind::Data, - SerialBinary::kEntryPointFourCc); - - SerialContainerBinary::EntryPoint dst; - - dst.name = uint32_t(containerStringPool.add(entryPoint.name->text)); - dst.profile = entryPoint.profile.raw; - dst.mangledName = uint32_t(containerStringPool.add(entryPoint.mangledName)); - - container->write(&dst, sizeof(dst)); - } + SLANG_UNEXPECTED("module digest chunk had wrong size"); } - // We can now output the debug information. This is for all IR and AST - if (sourceLocWriter) - { - // Write out the debug info - SerialSourceLocData debugData; - sourceLocWriter->write(&debugData); + SHA1::Digest digest; + foundChunk->getPayload(&digest); + return digest; +} - debugData.writeContainer(container); - } +String ModuleChunkRef::getName() +{ + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kNameFourCC); + return decoder.decodeString(); +} - // Write the container string table - if (containerStringPool.getAdded().getCount() > 0) - { - RiffContainer::ScopeChunk stringTableScope( - container, - RiffContainer::Chunk::Kind::Data, - SerialBinary::kStringTableFourCc); - List<char> encodedTable; - SerialStringTableUtil::encodeStringTable(containerStringPool, encodedTable); +IRModuleChunkRef ModuleChunkRef::findIR() +{ + auto foundProperty = ptr()->findContainedList(PropertyKeys<Module>::IRModule); + if (!foundProperty) + return IRModuleChunkRef(nullptr); + return IRModuleChunkRef( + static_cast<RiffContainer::ListChunk*>(foundProperty->getFirstContainedChunk())); +} - container->write(encodedTable.getBuffer(), encodedTable.getCount()); - } +ASTModuleChunkRef ModuleChunkRef::findAST() +{ + auto foundProperty = ptr()->findContainedList(PropertyKeys<Module>::ASTModule); + if (!foundProperty) + return ASTModuleChunkRef(nullptr); + return ASTModuleChunkRef( + static_cast<RiffContainer::ListChunk*>(foundProperty->getFirstContainedChunk())); +} - return SLANG_OK; +ContainerChunkRef ContainerChunkRef::find(RiffContainer* container) +{ + auto found = container->getRoot()->findListRec(SerialBinary::kContainerFourCc); + return ContainerChunkRef(found); } +ChunkRefList<ModuleChunkRef> ContainerChunkRef::getModules() +{ + auto found = ptr()->findContainedList(SerialBinary::kModuleListFourCc); + return ChunkRefList<ModuleChunkRef>(found); +} -static List<ExtensionDecl*>& _getCandidateExtensionList( - AggTypeDecl* typeDecl, - Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>>& mapTypeToCandidateExtensions) +ChunkRefList<EntryPointChunkRef> ContainerChunkRef::getEntryPoints() { - RefPtr<CandidateExtensionList> entry; - if (!mapTypeToCandidateExtensions.tryGetValue(typeDecl, entry)) - { - entry = new CandidateExtensionList(); - mapTypeToCandidateExtensions.add(typeDecl, entry); - } - return entry->candidateExtensions; + auto found = ptr()->findContainedList(SerialBinary::kEntryPointListFourCc); + return ChunkRefList<EntryPointChunkRef>(found); } -/* static */ Result SerialContainerUtil::read( - RiffContainer* container, - const ReadOptions& options, - const LoadedModuleDictionary* additionalLoadedModules, - SerialContainerData& out) +String EntryPointChunkRef::getMangledName() const { - out.clear(); + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kMangledNameFourCC); + return decoder.decodeString(); +} - RiffContainer::ListChunk* containerChunk = - container->getRoot()->findListRec(SerialBinary::kContainerFourCc); - if (!containerChunk) - { - // Must be a container - return SLANG_FAIL; - } +String EntryPointChunkRef::getName() const +{ + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kNameFourCC); + return decoder.decodeString(); +} - StringSlicePool containerStringPool(StringSlicePool::Style::Default); +Profile EntryPointChunkRef::getProfile() const +{ + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kProfileFourCC); - if (RiffContainer::Data* stringTableData = - containerChunk->findContainedData(SerialBinary::kStringTableFourCc)) - { - SerialStringTableUtil::decodeStringTable( - (const char*)stringTableData->getPayload(), - stringTableData->getSize(), - containerStringPool); - } + Profile::RawVal rawVal; + decoder.decode(rawVal); - RefPtr<SerialSourceLocReader> sourceLocReader; - RefPtr<SerialClasses> serialClasses; + return Profile(rawVal); +} - // Debug information - if (auto debugChunk = containerChunk->findContainedList(SerialSourceLocData::kDebugFourCc)) - { - // Read into data - SerialSourceLocData sourceLocData; - SLANG_RETURN_ON_FAIL(sourceLocData.readContainer(debugChunk)); - // Turn into DebugReader - sourceLocReader = new SerialSourceLocReader; - SLANG_RETURN_ON_FAIL(sourceLocReader->read(&sourceLocData, options.sourceManager)); - } +RiffContainer::ListChunk* findDebugChunk(RiffContainer::Chunk* startingChunk) +{ + if (!startingChunk) + return nullptr; - // Create a source loc representing the binary module. - SourceLoc binaryModuleLoc = SourceLoc(); + RiffContainer::ListChunk* container = as<RiffContainer::ListChunk>(startingChunk); + if (!container) + container = startingChunk->m_parent; - if (options.modulePath.getLength()) + for (; container; container = container->m_parent) { - auto srcManager = options.linkage->getSourceManager(); - auto modulePathInfo = PathInfo::makePath(options.modulePath); - auto srcFile = srcManager->findSourceFileByPathRecursively(modulePathInfo.foundPath); - if (!srcFile) + if (auto debugChunk = container->findContainedList(SerialSourceLocData::kDebugFourCc)) { - srcFile = srcManager->createSourceFileWithString(modulePathInfo, String()); - srcManager->addSourceFile(options.modulePath, srcFile); + return debugChunk; } - auto srcView = srcManager->createSourceView(srcFile, &modulePathInfo, SourceLoc()); - binaryModuleLoc = srcView->getRange().begin; } - // Add modules - if (RiffContainer::ListChunk* moduleList = - containerChunk->findContainedList(SerialBinary::kModuleListFourCc)) - { - RiffContainer::Chunk* chunk = moduleList->getFirstContainedChunk(); - while (chunk) - { - auto startChunk = chunk; - - RefPtr<ASTBuilder> astBuilder = options.astBuilder; - NodeBase* astRootNode = nullptr; - RefPtr<IRModule> irModule; - SerialContainerData::Module module; - if (auto headerChunk = - as<RiffContainer::DataChunk>(chunk, SerialBinary::kModuleHeaderFourCc)) - { - MemoryStreamBase memStream( - FileAccess::Read, - headerChunk->getSingleData()->getPayload(), - headerChunk->getSingleData()->getSize()); - size_t readSize = 0; - memStream.read(module.digest.data, sizeof(SHA1::Digest), readSize); - if (readSize != sizeof(SHA1::Digest)) - return SLANG_FAIL; - uint32_t fileListLength = 0; - memStream.read(&fileListLength, sizeof(uint32_t), readSize); - if (readSize != sizeof(uint32_t)) - return SLANG_FAIL; - List<uint8_t> fileListContent; - fileListContent.setCount(fileListLength); - memStream.read(fileListContent.getBuffer(), fileListContent.getCount(), readSize); - if (readSize != (size_t)fileListContent.getCount()) - return SLANG_FAIL; - UnownedStringSlice fileListString( - (const char*)fileListContent.getBuffer(), - fileListContent.getCount()); - List<UnownedStringSlice> fileList; - StringUtil::split(fileListString, '\n', fileList); - for (auto file : fileList) - { - if (file.getLength()) - { - module.dependentFiles.add(file); - } - } - // Onto next chunk - chunk = chunk->m_next; - } - - if (auto irChunk = as<RiffContainer::ListChunk>(chunk, IRSerialBinary::kIRModuleFourCc)) - { - if (!options.readHeaderOnly) - { - IRSerialData serialData; - SLANG_RETURN_ON_FAIL(IRSerialReader::readContainer(irChunk, &serialData)); - - // Read IR back from serialData - IRSerialReader reader; - SLANG_RETURN_ON_FAIL( - reader.read(serialData, options.session, sourceLocReader, irModule)); - } - - // Onto next chunk - chunk = chunk->m_next; - } - - if (auto astChunk = - as<RiffContainer::ListChunk>(chunk, ASTSerialBinary::kSlangASTModuleFourCC)) - { - if (!options.readHeaderOnly) - { - RiffContainer::Data* astData = - astChunk->findContainedData(ASTSerialBinary::kSlangASTModuleDataFourCC); - - if (astData) - { - if (!serialClasses) - { - SLANG_RETURN_ON_FAIL(SerialClassesUtil::create(serialClasses)); - } - - // TODO(JS): We probably want to store off better information about each of - // the translation unit including some kind of 'name'. For now we just - // generate a name. - - StringBuilder buf; - buf << "tu" << out.modules.getCount(); - if (!astBuilder) - { - astBuilder = - new ASTBuilder(options.sharedASTBuilder, buf.produceString()); - } - - /// We need to make the current ASTBuilder available for access via - /// thread_local global. - SetASTBuilderContextRAII astBuilderRAII(astBuilder); - - DefaultSerialObjectFactory objectFactory(astBuilder); - - SerialReader reader(serialClasses, &objectFactory); - - // Sets up the entry table - one entry for each 'object'. - // No native objects are constructed. No objects are deserialized. - SLANG_RETURN_ON_FAIL(reader.loadEntries( - (const uint8_t*)astData->getPayload(), - astData->getSize())); - - // Construct a native object for each table entry (where appropriate). - // Note that this *doesn't* set all object pointers - some are special cased - // and created on demand (strings) and imported symbols will have their - // object pointers unset (they are resolved in next step) - SLANG_RETURN_ON_FAIL(reader.constructObjects(options.namePool)); - - // Resolve external references if the linkage is specified - if (options.linkage) - { - const auto& entries = reader.getEntries(); - auto& objects = reader.getObjects(); - const Index entriesCount = entries.getCount(); - - String currentModuleName; - Module* currentModule = nullptr; - - // Index from 1 (0 is null) - for (Index i = 1; i < entriesCount; ++i) - { - const SerialInfo::Entry* entry = entries[i]; - if (entry->typeKind == SerialTypeKind::ImportSymbol) - { - // Import symbols are always serialized with a mangled name in - // the form of <module_name>!<symbol_mangled_name>. As - // symbol_mangled_name may not contain the name of its parent - // module in the case of an `extern` or `export` symbol. - // - UnownedStringSlice mangledName = - reader.getStringSlice(SerialIndex(i)); - List<UnownedStringSlice> slicesOut; - StringUtil::split(mangledName, '!', slicesOut); - if (slicesOut.getCount() != 2) - return SLANG_FAIL; - auto moduleName = slicesOut[0]; - mangledName = slicesOut[1]; - - // If we already have looked up this module and it has the same - // name just use what we have - Module* readModule = nullptr; - if (currentModule && - moduleName == currentModuleName.getUnownedSlice()) - { - readModule = currentModule; - } - else - { - // The modules are loaded on the linkage. - Linkage* linkage = options.linkage; - - NamePool* namePool = linkage->getNamePool(); - Name* moduleNameName = namePool->getName(moduleName); - readModule = linkage->findOrImportModule( - moduleNameName, - binaryModuleLoc, - options.sink, - additionalLoadedModules); - if (!readModule) - { - return SLANG_FAIL; - } - - // Set the current module and name - currentModule = readModule; - currentModuleName = moduleName; - } - - // Look up the symbol - NodeBase* nodeBase = - readModule->findExportFromMangledName(mangledName); - - if (!nodeBase) - { - if (options.sink) - { - options.sink->diagnose( - SourceLoc::fromRaw(0), - Diagnostics::unableToFindSymbolInModule, - mangledName, - moduleName); - } - - // If didn't find the export then we create an - // UnresolvedDecl node to represent the error. - auto unresolved = astBuilder->create<UnresolvedDecl>(); - unresolved->nameAndLoc.name = - options.linkage->getNamePool()->getName(mangledName); - nodeBase = unresolved; - } - - // set the result - objects[i] = nodeBase; - } - } - } - - // Set the sourceLocReader before doing de-serialize, such can lookup the - // remapped sourceLocs - reader.getExtraObjects().set(sourceLocReader); - - // TODO(JS): - // If modules can have more complicated relationships (like a two modules - // can refer to symbols from each other), then we can make this work by 1) - // deserialize *without* the external symbols being set up 2) calculate the - // symbols 3) deserialize the other module (in the same way) 4) run - // deserializeObjects *again* on each module This is less efficient than it - // might be (because deserialize phase is done twice) so if this is - // necessary may want a mechanism that *just* does reference lookups. - // - // For now if we assume a module can only access symbols from another - // module, and not the reverse. So we just need to deserialize and we are - // done - SLANG_RETURN_ON_FAIL(reader.deserializeObjects()); - - // Get the root node. It's at index 1 (0 is the null value). - astRootNode = reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>(); - - // Go through all AST nodes: - // 1) Add the extensions to the module mapTypeToCandidateExtensions cache - // 2) We need to fix the callback pointers for parsing - // 3) Register all `Val`s to the ASTBuilder's deduplication map. - - { - ModuleDecl* moduleDecl = as<ModuleDecl>(astRootNode); - - // Maps from keyword name name to index in (syntaxParseInfos) - // Will be filled in lazily if needed (for SyntaxDecl setup) - Dictionary<Name*, Index> syntaxKeywordDict; - - OrderedDictionary<Val*, List<Val**>> valUses; - - // Get the parse infos - const auto syntaxParseInfos = getSyntaxParseInfos(); - SLANG_ASSERT(syntaxParseInfos.getCount()); - - for (auto& obj : reader.getObjects()) - { - - if (obj.m_kind == SerialTypeKind::NodeBase) - { - NodeBase* nodeBase = (NodeBase*)obj.m_ptr; - SLANG_ASSERT(nodeBase); - - if (ExtensionDecl* extensionDecl = - dynamicCast<ExtensionDecl>(nodeBase)) - { - if (auto targetDeclRefType = - as<DeclRefType>(extensionDecl->targetType)) - { - ShortList<AggTypeDecl*> baseDecls; - getExtensionTargetDeclList( - astBuilder, - targetDeclRefType, - extensionDecl, - baseDecls); - for (auto baseDecl : baseDecls) - { - _getCandidateExtensionList( - baseDecl, - moduleDecl->mapTypeToCandidateExtensions) - .add(extensionDecl); - } - } - } - else if ( - SyntaxDecl* syntaxDecl = dynamicCast<SyntaxDecl>(nodeBase)) - { - // Set up the dictionary lazily - if (syntaxKeywordDict.getCount() == 0) - { - NamePool* namePool = options.session->getNamePool(); - for (Index i = 0; i < syntaxParseInfos.getCount(); ++i) - { - const auto& entry = syntaxParseInfos[i]; - syntaxKeywordDict.add( - namePool->getName(entry.keywordName), - i); - } - // Must have something in it at this point - SLANG_ASSERT(syntaxKeywordDict.getCount()); - } - - // Look up the index - Index* entryIndexPtr = - syntaxKeywordDict.tryGetValue(syntaxDecl->getName()); - if (entryIndexPtr) - { - // Set up SyntaxDecl based on the ParseSyntaxIndo - auto& info = syntaxParseInfos[*entryIndexPtr]; - syntaxDecl->parseCallback = *info.callback; - syntaxDecl->parseUserData = - const_cast<ReflectClassInfo*>(info.classInfo); - } - else - { - // If we don't find a setup entry, we use - // `parseSimpleSyntax`, and set the parseUserData to the - // ReflectClassInfo (as parseSimpleSyntax needs this) - syntaxDecl->parseCallback = &parseSimpleSyntax; - SLANG_ASSERT(syntaxDecl->syntaxClass.classInfo); - syntaxDecl->parseUserData = - const_cast<ReflectClassInfo*>( - syntaxDecl->syntaxClass.classInfo); - } - } - else if (Val* val = dynamicCast<Val>(nodeBase)) - { - val->_setUnique(); - } - } - } - } - } - } - - // Onto next chunk - chunk = chunk->m_next; - } - - if (astBuilder || irModule) - { - module.astBuilder = astBuilder; - module.astRootNode = astRootNode; - module.irModule = irModule; - - out.modules.add(module); - } - - // If no progress, step to next chunk - chunk = (chunk == startChunk) ? chunk->m_next : chunk; - } - } + return nullptr; +} - // Add all the entry points - { - List<RiffContainer::DataChunk*> entryPointChunks; - containerChunk->findContained(SerialBinary::kEntryPointFourCc, entryPointChunks); +SlangResult readSourceLocationsFromDebugChunk( + RiffContainer::ListChunk* debugChunk, + SourceManager* sourceManager, + RefPtr<SerialSourceLocReader>& outReader) +{ + if (!debugChunk) + return SLANG_FAIL; - for (auto entryPointChunk : entryPointChunks) - { - auto reader = entryPointChunk->asReadHelper(); + // Source location serialization uses the old approach where + // there is an intermediate in-memory data structure that the + // raw data from the RIFF gets deserialized into, before that + // intermediate representation gets transformed into something + // more directly usable. + // + // Thus we start with a first step where we simply read the data + // from the RIFF into the intermediate structure. + // + SerialSourceLocData intermediateData; + SLANG_RETURN_ON_FAIL(intermediateData.readContainer(debugChunk)); - SerialContainerBinary::EntryPoint srcEntryPoint; - SLANG_RETURN_ON_FAIL(reader.read(srcEntryPoint)); + // After reading the data into the intermediate representation, + // we turn it into a `SerialSourceLocReader`, which vends source + // location information to other deserialization tasks (both IR + // and AST deserialization). + // + auto reader = RefPtr(new SerialSourceLocReader()); + SLANG_RETURN_ON_FAIL(reader->read(&intermediateData, sourceManager)); - SerialContainerData::EntryPoint dstEntryPoint; + outReader = reader; + return SLANG_OK; +} - dstEntryPoint.name = options.namePool->getName( - containerStringPool.getSlice(StringSlicePool::Handle(srcEntryPoint.name))); - dstEntryPoint.profile.raw = srcEntryPoint.profile; - dstEntryPoint.mangledName = - containerStringPool.getSlice(StringSlicePool::Handle(srcEntryPoint.mangledName)); +SlangResult decodeModuleIR( + RefPtr<IRModule>& outIRModule, + RiffContainer::Chunk* chunk, + Session* session, + SerialSourceLocReader* sourceLocReader) +{ + // IR serialization still uses the older approach, where + // data gets deserialized from the RIFF into an intermediate + // data structure (`IRSerialData`), and then the actual + // in-memory structures are created based on the intermediate. + // + // Thus we start by running the `IRSerialReader::readContainer` + // logic to get the `IRSerialData` representation. + // + // TODO(tfoley): This should all get streamlined so that we + // are deserializing IR nodes directly from the format written + // into the RIFF. + // + auto listChunk = as<RiffContainer::ListChunk>(chunk); + if (!listChunk) + return SLANG_FAIL; + IRSerialData serialData; + SLANG_RETURN_ON_FAIL(IRSerialReader::readContainer(listChunk, &serialData)); - out.entryPoints.add(dstEntryPoint); - } - } + // Next we read the actual IR representation out from the + // `serialData`. This is the step that may pull source-location + // information from the provided `sourceLocReader`. + // + IRSerialReader reader; + SLANG_RETURN_ON_FAIL(reader.read(serialData, session, sourceLocReader, outIRModule)); return SLANG_OK; } diff --git a/source/slang/slang-serialize-container.h b/source/slang/slang-serialize-container.h index 8ddc5072a..4c1053a6d 100644 --- a/source/slang/slang-serialize-container.h +++ b/source/slang/slang-serialize-container.h @@ -12,72 +12,6 @@ namespace Slang class EndToEndCompileRequest; -/* The binary representation actually held in riff/file format*/ -struct SerialContainerBinary -{ - struct Target - { - uint32_t target; - uint32_t flags; - uint32_t profile; - uint32_t floatingPointMode; - }; - - struct EntryPoint - { - uint32_t name; - uint32_t profile; - uint32_t mangledName; - }; -}; - -struct SerialContainerDataModule -{ - RefPtr<IRModule> irModule; ///< The IR for the module - RefPtr<ASTBuilder> astBuilder; ///< The astBuilder that owns the astRootNode - NodeBase* astRootNode = nullptr; ///< The module decl - List<String> dependentFiles; - SHA1::Digest digest; -}; - -/* Struct that holds all the data that can be held in a 'container' */ -struct SerialContainerData -{ - struct Target - { - CodeGenTarget codeGenTarget = CodeGenTarget::Unknown; - SlangTargetFlags flags = kDefaultTargetFlags; - Profile profile; - FloatingPointMode floatingPointMode = FloatingPointMode::Default; - }; - - struct TargetComponent - { - // IR module for a specific compilation target - Target target; - RefPtr<IRModule> irModule; - }; - - typedef SerialContainerDataModule Module; - - struct EntryPoint - { - Name* name = nullptr; - Profile profile; - String mangledName; - }; - - void clear() - { - entryPoints.clear(); - modules.clear(); - targetComponents.clear(); - } - - List<Module> modules; - List<TargetComponent> targetComponents; - List<EntryPoint> entryPoints; -}; struct SerialContainerUtil { @@ -104,37 +38,6 @@ struct SerialContainerUtil String modulePath; }; - /// Add module to outData - static SlangResult addModuleToData( - Module* module, - const WriteOptions& options, - SerialContainerData& outData); - - /// Get the serializable contents of the request as data - static SlangResult addEndToEndRequestToData( - EndToEndCompileRequest* request, - const WriteOptions& options, - SerialContainerData& outData); - - /// Convert front end request into something serializable - static SlangResult addFrontEndRequestToData( - FrontEndCompileRequest* request, - const WriteOptions& options, - SerialContainerData& outData); - - /// Write the data into the container - static SlangResult write( - const SerialContainerData& data, - const WriteOptions& options, - RiffContainer* container); - - /// Read the container into outData - static SlangResult read( - RiffContainer* container, - const ReadOptions& options, - const LoadedModuleDictionary* additionalLoadedModules, - SerialContainerData& outData); - /// Verify IR serialization static SlangResult verifyIRSerialize( IRModule* module, @@ -153,6 +56,192 @@ struct SerialContainerUtil static SlangResult write(Module* module, const WriteOptions& options, Stream* stream); }; + +struct ChunkRef +{ +public: + ChunkRef(RiffContainer::Chunk* chunk) + : _chunk(chunk) + { + } + + RiffContainer::Chunk* ptr() const { return _chunk; } + +protected: + RiffContainer::Chunk* _chunk = nullptr; +}; + +struct DataChunkRef : ChunkRef +{ +public: + DataChunkRef(RiffContainer::DataChunk* chunk) + : ChunkRef(chunk) + { + } + + RiffContainer::DataChunk* ptr() const { return static_cast<RiffContainer::DataChunk*>(_chunk); } + + operator RiffContainer::DataChunk*() const { return ptr(); } +}; + + +template<typename T> +struct ChunkRefList +{ +public: + struct Iterator + { + public: + Iterator(RiffContainer::Chunk* chunk) + : _chunk(chunk) + { + } + + bool operator!=(Iterator const& other) const { return _chunk != other._chunk; } + + void operator++() { _chunk = _chunk->m_next; } + + T operator*() + { + ChunkRef ref(_chunk); + return *(T*)&ref; + } + + private: + RiffContainer::Chunk* _chunk = nullptr; + }; + + Iterator begin() const { return _list ? _list->getFirstContainedChunk() : nullptr; } + Iterator end() const { return Iterator(nullptr); } + + Count getCount() + { + Count count = 0; + for (auto i : *this) + count++; + return count; + } + + T getFirst() { return *begin(); } + + ChunkRefList() {} + + ChunkRefList(RiffContainer::ListChunk* list) + : _list(list) + { + } + + operator RiffContainer::ListChunk*() const { return _list; } + +private: + RiffContainer::ListChunk* _list = nullptr; +}; + +struct ListChunkRef : ChunkRef +{ +public: + ListChunkRef(RiffContainer::Chunk* chunk) + : ChunkRef(chunk) + { + } + + RiffContainer::ListChunk* ptr() const { return static_cast<RiffContainer::ListChunk*>(_chunk); } + + operator RiffContainer::ListChunk*() const { return ptr(); } +}; + + +struct StringChunkRef : DataChunkRef +{ +public: + String getValue(); +}; + +struct IRModuleChunkRef : ListChunkRef +{ +public: + explicit IRModuleChunkRef(RiffContainer::ListChunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct ASTModuleChunkRef : ListChunkRef +{ +public: + explicit ASTModuleChunkRef(RiffContainer::ListChunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct ModuleChunkRef : ListChunkRef +{ +public: + static ModuleChunkRef find(RiffContainer* container); + + String getName(); + + IRModuleChunkRef findIR(); + ASTModuleChunkRef findAST(); + + SHA1::Digest getDigest(); + + ChunkRefList<StringChunkRef> getFileDependencies(); + +protected: + ModuleChunkRef(RiffContainer::Chunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct EntryPointChunkRef : ListChunkRef +{ +public: + String getMangledName() const; + String getName() const; + Profile getProfile() const; + +protected: + EntryPointChunkRef(RiffContainer::Chunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct ContainerChunkRef : ListChunkRef +{ +public: + static ContainerChunkRef find(RiffContainer* container); + + ChunkRefList<ModuleChunkRef> getModules(); + + ChunkRefList<EntryPointChunkRef> getEntryPoints(); + +protected: + ContainerChunkRef(RiffContainer::Chunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +/// Attempt to find a debug-info chunk relative to +/// the given `startingChunk`. +/// +RiffContainer::ListChunk* findDebugChunk(RiffContainer::Chunk* startingChunk); + +SlangResult readSourceLocationsFromDebugChunk( + RiffContainer::ListChunk* debugChunk, + SourceManager* sourceManager, + RefPtr<SerialSourceLocReader>& outReader); + +SlangResult decodeModuleIR( + RefPtr<IRModule>& outIRModule, + RiffContainer::Chunk* chunk, + Session* session, + SerialSourceLocReader* sourceLocReader); + } // namespace Slang #endif diff --git a/source/slang/slang-serialize-factory.cpp b/source/slang/slang-serialize-factory.cpp deleted file mode 100644 index 5ad1e4911..000000000 --- a/source/slang/slang-serialize-factory.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// slang-serialize-factory.cpp -#include "slang-serialize-factory.h" - -#include "../core/slang-math.h" -#include "slang-ast-builder.h" -#include "slang-ast-reflect.h" -#include "slang-ref-object-reflect.h" -#include "slang-serialize-ast.h" - -// Needed for ModuleSerialFilter -// Needed for 'findModuleForDecl' -#include "slang-legalize-types.h" -#include "slang-mangle.h" - -namespace Slang -{ - -/* !!!!!!!!!!!!!!!!!!!!!! DefaultSerialObjectFactory !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -void* DefaultSerialObjectFactory::create(SerialTypeKind typeKind, SerialSubType subType) -{ - switch (typeKind) - { - case SerialTypeKind::NodeBase: - { - return m_astBuilder->createByNodeType(ASTNodeType(subType)); - } - case SerialTypeKind::RefObject: - { - const ReflectClassInfo* info = SerialRefObjects::getClassInfo(RefObjectType(subType)); - - if (info && info->m_createFunc) - { - RefObject* obj = reinterpret_cast<RefObject*>(info->m_createFunc(nullptr)); - return _add(obj); - } - return nullptr; - } - default: - break; - } - - return nullptr; -} - -void* DefaultSerialObjectFactory::getOrCreateVal(ValNodeDesc&& desc) -{ - return m_astBuilder->_getOrCreateImpl(_Move(desc)); -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ModuleSerialFilter !!!!!!!!!!!!!!!!!!!!!!!! - -SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const RefObject* inPtr) -{ - // We don't serialize Module - if (as<Module>(inPtr)) - { - writer->setPointerIndex(inPtr, SerialIndex(0)); - return SerialIndex(0); - } - - // For now for everything else just write it - return writer->writeObject(inPtr); -} - -SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBase* inPtr) -{ - NodeBase* ptr = const_cast<NodeBase*>(inPtr); - SLANG_ASSERT(ptr); - - - if (Decl* decl = as<Decl>(ptr)) - { - ModuleDecl* moduleDecl = findModuleForDecl(decl); - if (moduleDecl && moduleDecl != m_moduleDecl) - { - ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder(); - - // It's a reference to a declaration in another module, so first get the symbol name. - // Note that we will always name an import symbol in the form of - // <module_name>!<symbol_mangled_name> for serialization. - // This is because <symbol_mangled_name> does not necessarily include the name of its - // parent module when it is qualified as `extern` or `export`. - // - String mangledName = - getText(moduleDecl->getName()) + "!" + getMangledName(astBuilder, decl); - - // Add as an import symbol - return writer->addImportSymbol(mangledName); - } - else - { - // Okay... we can just write it out then - return writer->writeObject(ptr); - } - } - // For now for everything else just write it - return writer->writeObject(ptr); -} - -/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialClassesUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -/* static */ SlangResult SerialClassesUtil::addSerialClasses(SerialClasses* serialClasses) -{ - ASTSerialUtil::addSerialClasses(serialClasses); - SerialRefObjects::addSerialClasses(serialClasses); - - // Check if it seems ok - SLANG_ASSERT(serialClasses->isOk()); - - return SLANG_OK; -} - -/* static */ SlangResult SerialClassesUtil::create(RefPtr<SerialClasses>& out) -{ - RefPtr<SerialClasses> classes(new SerialClasses); - SLANG_RETURN_ON_FAIL(addSerialClasses(classes)); - - out = classes; - return SLANG_OK; -} - -} // namespace Slang diff --git a/source/slang/slang-serialize-factory.h b/source/slang/slang-serialize-factory.h deleted file mode 100644 index ef13fff83..000000000 --- a/source/slang/slang-serialize-factory.h +++ /dev/null @@ -1,49 +0,0 @@ -// slang-serialize-factory.h -#ifndef SLANG_SERIALIZE_FACTORY_H -#define SLANG_SERIALIZE_FACTORY_H - -#include "slang-serialize.h" - -namespace Slang -{ - -// !!!!!!!!!!!!!!!!!!!!! DefaultSerialObjectFactory !!!!!!!!!!!!!!!!!!!!!!!!!!! - -class ASTBuilder; - -class DefaultSerialObjectFactory : public SerialObjectFactory -{ -public: - virtual void* create(SerialTypeKind typeKind, SerialSubType subType) SLANG_OVERRIDE; - virtual void* getOrCreateVal(ValNodeDesc&& desc) SLANG_OVERRIDE; - - DefaultSerialObjectFactory(ASTBuilder* astBuilder) - : m_astBuilder(astBuilder) - { - } - -protected: - RefObject* _add(RefObject* obj) - { - m_scope.add(obj); - return obj; - } - - // We keep RefObjects in scope - List<RefPtr<RefObject>> m_scope; - ASTBuilder* m_astBuilder; -}; - - -struct SerialClassesUtil -{ - /// Add all types to serialClasses - static SlangResult addSerialClasses(SerialClasses* serialClasses); - /// Create SerialClasses with all the types added - static SlangResult create(RefPtr<SerialClasses>& out); -}; - - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-misc-type-info.h b/source/slang/slang-serialize-misc-type-info.h deleted file mode 100644 index 121b205d5..000000000 --- a/source/slang/slang-serialize-misc-type-info.h +++ /dev/null @@ -1,224 +0,0 @@ -// slang-serialize-misc-type-info.h -#ifndef SLANG_SERIALIZE_MISC_TYPE_INFO_H -#define SLANG_SERIALIZE_MISC_TYPE_INFO_H - -#include "../compiler-core/slang-source-loc.h" -#include "slang-compiler.h" -#include "slang-serialize-type-info.h" - -namespace Slang -{ - -/* Conversion for serialization for some more misc Slang types - */ - - -// Because is sized, we don't need to convert -template<> -struct SerialTypeInfo<FeedbackType::Kind> : public SerialIdentityTypeInfo<FeedbackType::Kind> -{ -}; - -// SamplerStateFlavor - -template<> -struct SerialTypeInfo<SamplerStateFlavor> - : public SerialConvertTypeInfo<SamplerStateFlavor, uint8_t> -{ -}; - -// ImageFormat -template<> -struct SerialTypeInfo<ImageFormat> : public SerialConvertTypeInfo<ImageFormat, uint8_t> -{ -}; - -// Stage -template<> -struct SerialTypeInfo<Stage> : public SerialConvertTypeInfo<Stage, uint8_t> -{ -}; - -// TokenType -template<> -struct SerialTypeInfo<TokenType> : public SerialConvertTypeInfo<TokenType, uint8_t> -{ -}; - -// BaseType -template<> -struct SerialTypeInfo<BaseType> : public SerialConvertTypeInfo<BaseType, uint8_t> -{ -}; - -// SemanticVersion -template<> -struct SerialTypeInfo<SemanticVersion> : public SerialIdentityTypeInfo<SemanticVersion> -{ -}; - -// SourceLoc - -// Make the type exposed, so we can look for it if we want to remap. -template<> -struct SerialTypeInfo<SourceLoc> -{ - typedef SourceLoc NativeType; - typedef SerialSourceLoc SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialSourceLoc) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SerialSourceLocWriter* sourceLocWriter = - writer->getExtraObjects().get<SerialSourceLocWriter>(); - *(SerialType*)outSerial = sourceLocWriter - ? sourceLocWriter->addSourceLoc(*(const NativeType*)inNative) - : SerialType(0); - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - SerialSourceLocReader* sourceLocReader = - reader->getExtraObjects().get<SerialSourceLocReader>(); - *(NativeType*)outNative = sourceLocReader - ? sourceLocReader->getSourceLoc(*(const SerialType*)inSerial) - : NativeType::fromRaw(0); - } -}; - -// Token -template<> -struct SerialTypeInfo<Token> -{ - typedef Token NativeType; - struct SerialType - { - SerialTypeInfo<BaseType>::SerialType type; - SerialTypeInfo<SourceLoc>::SerialType loc; - SerialIndex name; - }; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - SerialTypeInfo<TokenType>::toSerial(writer, &src.type, &dst.type); - SerialTypeInfo<SourceLoc>::toSerial(writer, &src.loc, &dst.loc); - - if (src.flags & TokenFlag::Name) - { - dst.name = writer->addName(src.getName()); - } - else - { - dst.name = writer->addString(src.getContent()); - } - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - dst.flags = 0; - dst.charsNameUnion.chars = nullptr; - - SerialTypeInfo<TokenType>::toNative(reader, &src.type, &dst.type); - SerialTypeInfo<SourceLoc>::toNative(reader, &src.loc, &dst.loc); - - // At the other end all token content will appear as Names. - if (src.name != SerialIndex(0)) - { - dst.charsNameUnion.name = reader->getName(src.name); - dst.flags |= TokenFlag::Name; - } - } -}; - -// NameLoc -template<> -struct SerialTypeInfo<NameLoc> -{ - typedef NameLoc NativeType; - struct SerialType - { - SerialTypeInfo<SourceLoc>::SerialType loc; - SerialIndex name; - }; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst.name = writer->addName(src.name); - SerialTypeInfo<SourceLoc>::toSerial(writer, &src.loc, &dst.loc); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - dst.name = reader->getName(src.name); - SerialTypeInfo<SourceLoc>::toNative(reader, &src.loc, &dst.loc); - } -}; - -// DiagnosticInfo -template<> -struct SerialTypeInfo<const DiagnosticInfo*> -{ - typedef const DiagnosticInfo* NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - dst = src ? writer->addString(UnownedStringSlice(src->name)) : SerialIndex(0); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - if (src == SerialIndex(0)) - { - dst = nullptr; - } - else - { - dst = findDiagnosticByName(reader->getStringSlice(src)); - } - } -}; - -// DeclAssociation -template<> -struct SerialTypeInfo<DeclAssociation> : SerialIdentityTypeInfo<DeclAssociation> -{ -}; -template<> -struct SerialTypeInfo<DeclAssociationKind> - : public SerialConvertTypeInfo<DeclAssociationKind, uint8_t> -{ -}; - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-reflection.cpp b/source/slang/slang-serialize-reflection.cpp deleted file mode 100644 index 60ab31e17..000000000 --- a/source/slang/slang-serialize-reflection.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// slang-serialize-reflection.cpp -#include "slang-serialize-reflection.h" - -#include "slang-serialize.h" - -namespace Slang -{ - -bool ReflectClassInfo::isSubClassOfSlow(const ThisType& super) const -{ - ReflectClassInfo const* info = this; - while (info) - { - if (info == &super) - return true; - info = info->m_superClass; - } - return false; -} - -#if 0 - -// #if'd out because produces a warning->error if not used. -static bool _checkSubClassRange(ReflectClassInfo*const* typeInfos, Index typeInfosCount) -{ - for (Index i = 0; i < typeInfosCount; ++i) - { - for (Index j = 0; j < typeInfosCount; ++j) - { - auto a = typeInfos[i]; - auto b = typeInfos[j]; - if (a->isSubClassOf(*b) != a->isSubClassOfSlow(*b)) - { - return false; - } - } - } - - return true; -} - -#endif - -static uint32_t _calcRangeRec( - ReflectClassInfo* classInfo, - const Dictionary<const ReflectClassInfo*, List<ReflectClassInfo*>>& childMap, - uint32_t index) -{ - classInfo->m_classId = index++; - // Do the calc range for all the children - auto list = childMap.tryGetValue(classInfo); - - if (list) - { - for (auto child : *list) - { - index = _calcRangeRec(child, childMap, index); - } - } - - classInfo->m_lastClassId = index; - return index; -} - -static ReflectClassInfo* _calcRoot(ReflectClassInfo* classInfo) -{ - while (classInfo->m_superClass) - { - classInfo = const_cast<ReflectClassInfo*>(classInfo->m_superClass); - } - return classInfo; -} - - -/* static */ void ReflectClassInfo::calcClassIdHierachy( - uint32_t baseIndex, - ReflectClassInfo* const* typeInfos, - Index typeInfosCount) -{ - SLANG_ASSERT(typeInfosCount > 0); - - // TODO(JS): - // Note that the calculating of the ranges could be done more efficiently by adding to an array - // of struct { super, class }, sorting, by super classs and using a dictionary to map from class - // it's first in list of super class use. This works for now though. - - // The root cannot be shared with another hierarchy - as doing so will mean that the range will - // be incorrect (it would need to span both trees) - ReflectClassInfo* root = _calcRoot(typeInfos[0]); - - // We want to produce a map from a node that holds all of it's children - Dictionary<const ThisType*, List<ThisType*>> childMap; - - const List<ThisType*> emptyList; - { - for (Index i = 0; i < typeInfosCount; ++i) - { - auto typeInfo = typeInfos[i]; - if (typeInfo->m_superClass) - { - // Add to that item - List<ThisType*>* list = - childMap.tryGetValueOrAdd(typeInfo->m_superClass, emptyList); - if (!list) - { - list = childMap.tryGetValue(typeInfo->m_superClass); - } - SLANG_ASSERT(list); - list->add(typeInfo); - } - - // The root should be the same for all types - SLANG_ASSERT(_calcRoot(typeInfo) == root); - } - } - - // We want to recursively work out a range - _calcRangeRec(root, childMap, baseIndex); - - // SLANG_ASSERT(_checkSubClassRange(typeInfos, typeInfoCount)); -} - -} // namespace Slang diff --git a/source/slang/slang-serialize-reflection.h b/source/slang/slang-serialize-reflection.h deleted file mode 100644 index 63ea1e7b6..000000000 --- a/source/slang/slang-serialize-reflection.h +++ /dev/null @@ -1,86 +0,0 @@ -// slang-serialize-reflection.h -#ifndef SLANG_SERIALIZE_REFLECTION_H -#define SLANG_SERIALIZE_REFLECTION_H - -#include "../compiler-core/slang-name.h" - -namespace Slang -{ - -struct ReflectClassInfo -{ - typedef ReflectClassInfo ThisType; - - typedef void* (*CreateFunc)(void* context); - typedef void (*DestructorFunc)(void* ptr); - - /// A constant time implementation of isSubClassOf - SLANG_FORCE_INLINE bool isSubClassOf(const ThisType& super) const - { - // We include super.m_classId, because it's a subclass of itself. - return m_classId >= super.m_classId && m_classId <= super.m_lastClassId; - } - - SLANG_FORCE_INLINE static bool isValidTypeId(uint32_t typeId) { return int32_t(typeId) >= 0; } - - // True if typeId derives from this type - SLANG_FORCE_INLINE bool isDerivedFrom(uint32_t typeId) const - { - SLANG_ASSERT(isValidTypeId(typeId) && isValidTypeId(m_classId)); - return typeId >= m_classId && typeId <= m_lastClassId; - } - - SLANG_FORCE_INLINE static bool isSubClassOf(uint32_t type, const ThisType& super) - { - SLANG_ASSERT(isValidTypeId(type) && isValidTypeId(super.m_classId)); - // We include super.m_classId, because it's a subclass of itself. - return type >= super.m_classId && type <= super.m_lastClassId; - } - - /// Will produce the same result as isSubClassOf (if enumerated), but more slowly by traversing - /// the m_superClass Works without initRange being called. - bool isSubClassOfSlow(const ThisType& super) const; - - /// Calculate infos m_classId for all the infos specified such that they are honor the - /// inheritance relationship such that a m_classId of a child is > m_classId && <= m_lastClassId - static void calcClassIdHierachy( - uint32_t baseIndex, - ReflectClassInfo* const* infos, - Index infosCount); - - uint32_t m_classId; ///< Not necessarily set. - uint32_t m_lastClassId; - - const ReflectClassInfo* - m_superClass; ///< The super class of this class, or nullptr if has no super class. - const char* m_name; ///< Textual class name, for debugging - CreateFunc m_createFunc; ///< Callback to use when creating instances (using an ASTBuilder for - ///< backing memory) - DestructorFunc m_destructorFunc; ///< The destructor for this type. Being just destructor, does - ///< not free backing memory for type. - - uint32_t m_sizeInBytes; ///< Total size of the type - uint8_t m_alignment; ///< The required alignment of the type -}; - -// Does nothing - just a mark to the C++ extractor -#define SLANG_REFLECTED -#define SLANG_UNREFLECTED - -#define SLANG_PRE_DECLARE(SUFFIX, DEF) - -#define SLANG_TYPE_SET(SUFFIX, ...) - -// Use these macros to help define Super, and making the base definition NOT have a Super -// definition. For example something like... - -#define SLANG_CLASS_REFLECT_SUPER_BASE(SUPER) -#define SLANG_CLASS_REFLECT_SUPER_INNER(SUPER) typedef SUPER Super; -#define SLANG_CLASS_REFLECT_SUPER_LEAF(SUPER) typedef SUPER Super; - -// Mark a value class -#define SLANG_VALUE_CLASS(x) - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-source-loc.h b/source/slang/slang-serialize-source-loc.h index 10d084fb6..24e1813a4 100644 --- a/source/slang/slang-serialize-source-loc.h +++ b/source/slang/slang-serialize-source-loc.h @@ -147,8 +147,6 @@ public: class SerialSourceLocReader : public RefObject { public: - static const SerialExtraType kExtraType = SerialExtraType::SourceLocReader; - Index findViewIndex(SerialSourceLocData::SourceLoc loc); SourceLoc getSourceLoc(SerialSourceLocData::SourceLoc loc); @@ -186,8 +184,6 @@ protected: class SerialSourceLocWriter : public RefObject { public: - static const SerialExtraType kExtraType = SerialExtraType::SourceLocWriter; - class Source : public RefObject { public: diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h deleted file mode 100644 index 20662c319..000000000 --- a/source/slang/slang-serialize-type-info.h +++ /dev/null @@ -1,491 +0,0 @@ -// slang-serialize-type-info.h -#ifndef SLANG_SERIALIZE_TYPE_INFO_H -#define SLANG_SERIALIZE_TYPE_INFO_H - -#include "slang-serialize.h" - -namespace Slang -{ - -/* For the serialization system to work we need to defined how native types are represented in the -serialized format. This information is defined by specializing SerialTypeInfo with the native type -to be converted This header provides conversion for common Slang types. -*/ - - -// We need to have a way to map between the two. -// If no mapping is needed, (just a copy), then we don't bother with the functions -template<typename T> -struct SerialBasicTypeInfo -{ - typedef T NativeType; - typedef T SerialType; - - // We want the alignment to be the same as the size of the type for basic types - // NOTE! Might be different from SLANG_ALIGN_OF(SerialType) - enum - { - SerialAlignment = sizeof(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - *(T*)serial = *(const T*)native; - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - *(T*)native = *(const T*)serial; - } - - static const SerialType* getType() - { - static const SerialType type = - {sizeof(SerialType), uint8_t(SerialAlignment), &toSerial, &toNative}; - return &type; - } -}; - -template<typename NATIVE_T, typename SERIAL_T> -struct SerialConvertTypeInfo -{ - typedef NATIVE_T NativeType; - typedef SERIAL_T SerialType; - - enum - { - SerialAlignment = SerialBasicTypeInfo<SERIAL_T>::SerialAlignment - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - *(SERIAL_T*)serial = SERIAL_T(*(const NATIVE_T*)native); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - *(NATIVE_T*)native = NATIVE_T(*(const SERIAL_T*)serial); - } -}; - -template<typename T> -struct SerialIdentityTypeInfo -{ - typedef T NativeType; - typedef T SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - *(T*)serial = *(const T*)native; - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - *(T*)native = *(const T*)serial; - } -}; - -// Don't need to convert the index type - -template<> -struct SerialTypeInfo<SerialIndex> : public SerialIdentityTypeInfo<SerialIndex> -{ -}; - -// Implement for Basic Types - -template<> -struct SerialTypeInfo<uint8_t> : public SerialBasicTypeInfo<uint8_t> -{ -}; -template<> -struct SerialTypeInfo<uint16_t> : public SerialBasicTypeInfo<uint16_t> -{ -}; -template<> -struct SerialTypeInfo<uint32_t> : public SerialBasicTypeInfo<uint32_t> -{ -}; -template<> -struct SerialTypeInfo<uint64_t> : public SerialBasicTypeInfo<uint64_t> -{ -}; - -template<> -struct SerialTypeInfo<int8_t> : public SerialBasicTypeInfo<int8_t> -{ -}; -template<> -struct SerialTypeInfo<int16_t> : public SerialBasicTypeInfo<int16_t> -{ -}; -template<> -struct SerialTypeInfo<int32_t> : public SerialBasicTypeInfo<int32_t> -{ -}; -template<> -struct SerialTypeInfo<int64_t> : public SerialBasicTypeInfo<int64_t> -{ -}; - -template<> -struct SerialTypeInfo<float> : public SerialBasicTypeInfo<float> -{ -}; -template<> -struct SerialTypeInfo<double> : public SerialBasicTypeInfo<double> -{ -}; - -// Fixed arrays - -template<typename T, size_t N> -struct SerialTypeInfo<T[N]> -{ - typedef SerialTypeInfo<T> ElementASTSerialType; - typedef typename ElementASTSerialType::SerialType SerialElementType; - - typedef T NativeType[N]; - typedef SerialElementType SerialType[N]; - - enum - { - SerialAlignment = SerialTypeInfo<T>::SerialAlignment - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SerialElementType* serial = (SerialElementType*)outSerial; - - if (writer->getFlags() & SerialWriter::Flag::ZeroInitialize) - { - ::memset(outSerial, 0, sizeof(SerialElementType) * N); - } - - const T* native = (const T*)inNative; - for (Index i = 0; i < Index(N); ++i) - { - ElementASTSerialType::toSerial(writer, native + i, serial + i); - } - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - const SerialElementType* serial = (const SerialElementType*)inSerial; - T* native = (T*)outNative; - for (Index i = 0; i < Index(N); ++i) - { - ElementASTSerialType::toNative(reader, serial + i, native + i); - } - } -}; - -// Special case bool - as we can't rely on size alignment -template<> -struct SerialTypeInfo<bool> -{ - typedef bool NativeType; - typedef uint8_t SerialType; - - enum - { - SerialAlignment = sizeof(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SLANG_UNUSED(writer); - *(SerialType*)outSerial = *(const NativeType*)inNative ? 1 : 0; - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - SLANG_UNUSED(reader); - *(NativeType*)outNative = (*(const SerialType*)inSerial) != 0; - } -}; - -// Specialization for all enum types -template<typename T> -struct SerialTypeInfo<T, typename std::enable_if<std::is_enum<T>::value>::type> - : public SerialIdentityTypeInfo<T> -{ -}; - -class Val; - -// Pointer - -template<typename T, typename /*sfinaeType*/ = void> -struct PtrSerialTypeInfo -{ - typedef T* NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - auto ptrToWrite = *(T**)inNative; - static_assert(!IsBaseOf<Val, T>::Value); - *(SerialIndex*)outSerial = writer->addPointer(ptrToWrite); - } - - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - *(T**)outNative = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>(); - } -}; - -template<typename T> -struct SerialTypeInfo<T*> : public PtrSerialTypeInfo<T> -{ -}; - -// RefPtr (pretty much the same as T* - except for native rep) -template<typename T> -struct SerialTypeInfo<RefPtr<T>> -{ - typedef RefPtr<T> NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - *(SerialType*)serial = writer->addPointer(src); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - *(NativeType*)native = reader->getPointer(*(const SerialType*)serial).dynamicCast<T>(); - } -}; - -// Special case Name -template<> -struct SerialTypeInfo<Name*> : public SerialTypeInfo<RefObject*> -{ - // Special case - typedef Name* NativeType; - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - *(Name**)outNative = reader->getName(*(const SerialType*)inSerial); - } -}; - -template<> -struct SerialTypeInfo<const Name*> : public SerialTypeInfo<Name*> -{ -}; - -// List -template<typename T, typename ALLOCATOR> -struct SerialTypeInfo<List<T, ALLOCATOR>> -{ - typedef List<T, ALLOCATOR> NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst = writer->addArray(src.getBuffer(), src.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - reader->getArray(src, dst); - } -}; - -// ShortList -template<typename T, int n, typename ALLOCATOR> -struct SerialTypeInfo<ShortList<T, n, ALLOCATOR>> -{ - typedef ShortList<T, n, ALLOCATOR> NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst = writer->addArray(src.getArrayView().getBuffer(), src.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - reader->getArray(src, dst); - } -}; - -// String -template<> -struct SerialTypeInfo<String> -{ - typedef String NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - *(SerialType*)serial = writer->addString(src); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - dst = reader->getString(src); - } -}; - -// Dictionary -// Note: We leave out SerialTypeInfo specialization for Dictionary, because -// it does not have determinstic ordering. - -// OrderedDictionary -template<typename KEY, typename VALUE> -struct SerialTypeInfo<OrderedDictionary<KEY, VALUE>> -{ - typedef OrderedDictionary<KEY, VALUE> NativeType; - struct SerialType - { - SerialIndex keys; ///< Index an array - SerialIndex values; ///< Index an array - }; - - typedef typename SerialTypeInfo<KEY>::SerialType KeySerialType; - typedef typename SerialTypeInfo<VALUE>::SerialType ValueSerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<KeySerialType> keys; - List<ValueSerialType> values; - - Index count = Index(src.getCount()); - keys.setCount(count); - values.setCount(count); - - if (writer->getFlags() & SerialWriter::Flag::ZeroInitialize) - { - ::memset(keys.getBuffer(), 0, count * sizeof(KeySerialType)); - ::memset(values.getBuffer(), 0, count * sizeof(ValueSerialType)); - } - - Index i = 0; - for (const auto& pair : src) - { - SerialTypeInfo<KEY>::toSerial(writer, &pair.key, &keys[i]); - SerialTypeInfo<VALUE>::toSerial(writer, &pair.value, &values[i]); - i++; - } - - // When we add the array it is already converted to a serializable type, so add as - // SerialArray - dst.keys = writer->addSerialArray<KEY>(keys.getBuffer(), count); - dst.values = writer->addSerialArray<VALUE>(values.getBuffer(), count); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - // Clear it - dst = NativeType(); - - List<KEY> keys; - List<VALUE> values; - - reader->getArray(src.keys, keys); - reader->getArray(src.values, values); - - SLANG_ASSERT(keys.getCount() == values.getCount()); - - const Index count = keys.getCount(); - for (Index i = 0; i < count; ++i) - { - dst.add(keys[i], values[i]); - } - } -}; - -// KeyValuePair -template<typename KEY, typename VALUE> -struct SerialTypeInfo<KeyValuePair<KEY, VALUE>> -{ - typedef KeyValuePair<KEY, VALUE> NativeType; - - typedef typename SerialTypeInfo<KEY>::SerialType KeySerialType; - typedef typename SerialTypeInfo<VALUE>::SerialType ValueSerialType; - - struct SerialType - { - KeySerialType key; - ValueSerialType value; - }; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - SerialTypeInfo<KEY>::toSerial(writer, &src.key, &dst.key); - SerialTypeInfo<VALUE>::toSerial(writer, &src.value, &dst.value); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - SerialTypeInfo<KEY>::toNative(reader, &src.key, &dst.key); - SerialTypeInfo<VALUE>::toNative(reader, &src.value, &dst.value); - } -}; - - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-types.h b/source/slang/slang-serialize-types.h index 217c14b44..cd2b4c99c 100644 --- a/source/slang/slang-serialize-types.h +++ b/source/slang/slang-serialize-types.h @@ -11,14 +11,7 @@ namespace Slang { - -// An enumeration of types that can be set -enum class SerialExtraType -{ - SourceLocReader, - SourceLocWriter, - CountOf, -}; +class Module; // Options for IR/AST/Debug serialization @@ -35,7 +28,6 @@ struct SerialOptionFlag ASTModule = 0x04, ///< If set will output AST modules - typically required, but potentially ///< not desired (for example with obsfucation) IRModule = 0x08, ///< If set will output IR modules - typically required - ASTFunctionBody = 0x10, ///< If set will serialize AST function bodies. }; }; typedef SerialOptionFlag::Type SerialOptionFlags; @@ -123,6 +115,20 @@ struct SerialListUtil } }; +template<typename T> +struct PropertyKeys +{ +}; + +template<> +struct PropertyKeys<Module> +{ + static const FourCC Digest = SLANG_FOUR_CC('S', 'H', 'A', '1'); + static const FourCC ASTModule = SLANG_FOUR_CC('a', 's', 't', ' '); + static const FourCC IRModule = SLANG_FOUR_CC('i', 'r', ' ', ' '); + static const FourCC FileDependencies = SLANG_FOUR_CC('f', 'd', 'e', 'p'); +}; + // For types/FourCC that work for serializing in general (not just IR). struct SerialBinary { @@ -140,8 +146,44 @@ struct SerialBinary /// An entry point static const FourCC kEntryPointFourCc = SLANG_FOUR_CC('E', 'P', 'n', 't'); - // Module header - static const FourCC kModuleHeaderFourCc = SLANG_FOUR_CC('S', 'm', 'h', 'd'); + static const FourCC kEntryPointListFourCc = SLANG_FOUR_CC('e', 'p', 't', 's'); + + // Module + static const FourCC kModuleFourCC = SLANG_FOUR_CC('s', 'm', 'o', 'd'); + + // The following are "generic" codes, suitable for + // use when serializing content using JSON-like structure. + // + static const FourCC kObjectFourCC = SLANG_FOUR_CC('o', 'b', 'j', ' '); + static const FourCC kPairFourCC = SLANG_FOUR_CC('p', 'a', 'i', 'r'); + static const FourCC kArrayFourCC = SLANG_FOUR_CC('a', 'r', 'r', 'y'); + static const FourCC kDictionaryFourCC = SLANG_FOUR_CC('d', 'i', 'c', 't'); + static const FourCC kNullFourCC = SLANG_FOUR_CC('n', 'u', 'l', 'l'); + static const FourCC kStringFourCC = SLANG_FOUR_CC('s', 't', 'r', ' '); + static const FourCC kTrueFourCC = SLANG_FOUR_CC('t', 'r', 'u', 'e'); + static const FourCC kFalseFourCC = SLANG_FOUR_CC('f', 'a', 'l', 's'); + static const FourCC kInt32FourCC = SLANG_FOUR_CC('i', '3', '2', ' '); + static const FourCC kUInt32FourCC = SLANG_FOUR_CC('u', '3', '2', ' '); + static const FourCC kFloat32FourCC = SLANG_FOUR_CC('f', '3', '2', ' '); + static const FourCC kInt64FourCC = SLANG_FOUR_CC('i', '6', '4', ' '); + static const FourCC kUInt64FourCC = SLANG_FOUR_CC('u', '6', '4', ' '); + static const FourCC kFloat64FourCC = SLANG_FOUR_CC('f', '6', '4', ' '); + + // The following codes are suitable for use when serializing + // content that represents a logical file system. + // + static const FourCC kDirectoryFourCC = SLANG_FOUR_CC('d', 'i', 'r', ' '); + static const FourCC kFileFourCC = SLANG_FOUR_CC('f', 'i', 'l', 'e'); + static const FourCC kNameFourCC = SLANG_FOUR_CC('n', 'a', 'm', 'e'); + static const FourCC kPathFourCC = SLANG_FOUR_CC('p', 'a', 't', 'h'); + static const FourCC kDataFourCC = SLANG_FOUR_CC('d', 'a', 't', 'a'); + + // TODO(tfoley): Figure out where to put all of these so that + // they can be more usefully addressed. + // + static const FourCC kMangledNameFourCC = SLANG_FOUR_CC('m', 'g', 'n', 'm'); + static const FourCC kProfileFourCC = SLANG_FOUR_CC('p', 'r', 'o', 'f'); + struct ArrayHeader { diff --git a/source/slang/slang-serialize-value-type-info.h b/source/slang/slang-serialize-value-type-info.h deleted file mode 100644 index 3ebbdc858..000000000 --- a/source/slang/slang-serialize-value-type-info.h +++ /dev/null @@ -1,83 +0,0 @@ -// slang-serialize-value-type-info.h - -#ifndef SLANG_SERIALIZE_VALUE_TYPE_INFO_H -#define SLANG_SERIALIZE_VALUE_TYPE_INFO_H - -#include "slang-ast-support-types.h" -#include "slang-generated-value-macro.h" -#include "slang-generated-value.h" -#include "slang-serialize-misc-type-info.h" -#include "slang-serialize-type-info.h" -#include "slang-serialize.h" - -// Create the functions to automatically convert between value types - -namespace Slang -{ - -// TODO(JS): We may want to strip const or other modifiers -// Just strips the brackets. -#define SLANG_VALUE_GET_TYPE(TYPE) TYPE - -#define SLANG_VALUE_FIELD_TO_SERIAL(FIELD_NAME, TYPE, param) \ - SerialTypeInfo<decltype(src->FIELD_NAME)>::toSerial(writer, &src->FIELD_NAME, &dst->FIELD_NAME); -#define SLANG_VALUE_FIELD_TO_NATIVE(FIELD_NAME, TYPE, param) \ - SerialTypeInfo<decltype(dst->FIELD_NAME)>::toNative(reader, &src->FIELD_NAME, &dst->FIELD_NAME); - -#define SLANG_IF_HAS_SUPER_BASE(x) -#define SLANG_IF_HAS_SUPER_INNER(x) x -#define SLANG_IF_HAS_SUPER_LEAF(x) x - -#define SLANG_VALUE_TO_SERIAL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - static void toSerial(SerialWriter* writer, const void* native, void* serial) \ - { \ - SLANG_IF_HAS_SUPER_##TYPE( \ - SerialTypeInfo<SUPER>::toSerial(writer, native, serial);) auto dst = \ - (SerialType*)serial; \ - auto src = (const NativeType*)native; \ - SLANG_FIELDS_Value_##NAME(SLANG_VALUE_FIELD_TO_SERIAL, param) \ - } - -#define SLANG_VALUE_TO_NATIVE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - static void toNative(SerialReader* reader, const void* serial, void* native) \ - { \ - SLANG_IF_HAS_SUPER_##TYPE( \ - SerialTypeInfo<SUPER>::toNative(reader, serial, native);) auto src = \ - (const SerialType*)serial; \ - auto dst = (NativeType*)native; \ - SLANG_FIELDS_Value_##NAME(SLANG_VALUE_FIELD_TO_NATIVE, param) \ - } - -// #define SLANG_VALUE_SERIAL_FIELD(FIELD_NAME, TYPE, param) SerialTypeInfo<SLANG_VALUE_GET_TYPE -// TYPE>::SerialType FIELD_NAME; -#define SLANG_VALUE_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - SerialTypeInfo<decltype(((param*)nullptr)->FIELD_NAME)>::SerialType FIELD_NAME; - -#define SLANG_VALUE_SERIAL_STRUCT(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - struct SerialType SLANG_IF_HAS_SUPER_##TYPE( : SerialTypeInfo<SUPER>::SerialType) \ - { \ - SLANG_FIELDS_Value_##NAME(SLANG_VALUE_SERIAL_FIELD, NAME) \ - }; - -#define SLANG_VALUE_TYPE_INFO_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - template<> \ - struct SerialTypeInfo<NAME> \ - { \ - typedef NAME NativeType; \ - SLANG_VALUE_SERIAL_STRUCT(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - \ - enum \ - { \ - SerialAlignment = SLANG_ALIGN_OF(SerialType) \ - }; \ - \ - SLANG_VALUE_TO_NATIVE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - SLANG_VALUE_TO_SERIAL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - }; - -#define SLANG_VALUE_TYPE_INFO(NAME) SLANG_Value_##NAME(SLANG_VALUE_TYPE_INFO_IMPL, _) - - -} // namespace Slang - -#endif // SLANG_SERIALIZE_VALUE_TYPE_INFO_H diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp index 2a1a92302..a1c555a9b 100644 --- a/source/slang/slang-serialize.cpp +++ b/source/slang/slang-serialize.cpp @@ -8,1182 +8,4 @@ namespace Slang { -const SerialClass* SerialClasses::add(const SerialClass* cls) -{ - List<const SerialClass*>& classes = m_classesByTypeKind[Index(cls->typeKind)]; - - if (cls->subType >= classes.getCount()) - { - classes.setCount(cls->subType + 1); - } - else - { - if (classes[cls->subType]) - { - SLANG_ASSERT(!"Type is already set"); - return nullptr; - } - } - - SerialClass* copy = _createSerialClass(cls); - classes[cls->subType] = copy; - - return copy; -} - -const SerialClass* SerialClasses::add( - SerialTypeKind kind, - SerialSubType subType, - const SerialField* fields, - Index fieldsCount, - const SerialClass* superCls) -{ - SerialClass cls; - cls.typeKind = kind; - cls.subType = subType; - - cls.fields = fields; - cls.fieldsCount = fieldsCount; - - // If the superCls is set it must be owned - SLANG_ASSERT(superCls == nullptr || isOwned(superCls)); - - cls.super = superCls; - - // Set to invalid values for now - cls.alignment = 0; - cls.size = 0; - cls.flags = 0; - - return add(&cls); -} - -const SerialClass* SerialClasses::addUnserialized(SerialTypeKind kind, SerialSubType subType) -{ - List<const SerialClass*>& classes = m_classesByTypeKind[Index(kind)]; - - if (subType >= classes.getCount()) - { - classes.setCount(subType + 1); - } - else - { - if (classes[subType]) - { - SLANG_ASSERT(!"Type is already set"); - return nullptr; - } - } - - SerialClass* dst = m_arena.allocate<SerialClass>(); - - dst->typeKind = kind; - dst->subType = subType; - - dst->size = 0; - dst->alignment = 0; - - dst->fields = nullptr; - dst->fieldsCount = 0; - dst->flags = SerialClassFlag::DontSerialize; - dst->super = nullptr; - - classes[subType] = dst; - return dst; -} - -bool SerialClasses::isOwned(const SerialClass* cls) const -{ - const List<const SerialClass*>& classes = m_classesByTypeKind[Index(cls->typeKind)]; - return cls->subType < classes.getCount() && classes[cls->subType] == cls; -} - -SerialClass* SerialClasses::_createSerialClass(const SerialClass* cls) -{ - uint32_t maxAlignment = 1; - uint32_t offset = 0; - - if (cls->super) - { - SLANG_ASSERT(isOwned(cls->super)); - - maxAlignment = cls->super->alignment; - offset = cls->super->size; - } - - // Can't be 0 - SLANG_ASSERT(maxAlignment != 0); - // Must be a power of 2 - SLANG_ASSERT((maxAlignment & (maxAlignment - 1)) == 0); - - // Check it is correctly aligned - SLANG_ASSERT((offset & (maxAlignment - 1)) == 0); - - SerialField* dstFields = m_arena.allocateArray<SerialField>(cls->fieldsCount); - - // Okay, go through fields setting their offset - const SerialField* srcFields = cls->fields; - for (Index j = 0; j < cls->fieldsCount; j++) - { - const SerialField& srcField = srcFields[j]; - SerialField& dstField = dstFields[j]; - - // Copy the field - dstField = srcField; - - uint32_t alignment = srcField.type->serialAlignment; - // Make sure the offset is aligned for the field requirement - offset = (offset + alignment - 1) & ~(alignment - 1); - - // Save the field offset - dstField.serialOffset = uint32_t(offset); - - // Move past the field - offset += uint32_t(srcField.type->serialSizeInBytes); - - // Calc the maximum alignment - maxAlignment = (alignment > maxAlignment) ? alignment : maxAlignment; - } - - // Align with maximum alignment - offset = (offset + maxAlignment - 1) & ~(maxAlignment - 1); - - SerialClass* dst = m_arena.allocate<SerialClass>(); - *dst = *cls; - - dst->alignment = uint8_t(maxAlignment); - dst->size = uint32_t(offset); - - dst->fields = dstFields; - - return dst; -} - -bool SerialClasses::isOk() const -{ - StringSlicePool pool(StringSlicePool::Style::Default); - - for (const auto& classes : m_classesByTypeKind) - { - for (const SerialClass* cls : classes) - { - // It is possible potentially to have gaps - if (cls == nullptr) - { - continue; - } - - if (cls->super && cls->super->typeKind != cls->typeKind) - { - // If has a super type, must be the same typeKind - return false; - } - - // Make sure the fields are uniquely named - - pool.clear(); - - { - const SerialClass* curCls = cls; - - do - { - for (Index i = 0; i < curCls->fieldsCount; ++i) - { - const SerialField& field = curCls->fields[i]; - - StringSlicePool::Handle handle; - if (pool.findOrAdd(UnownedStringSlice(field.name), handle)) - { - return false; - } - } - - // Add the fields of the parent - curCls = curCls->super; - } while (curCls); - } - } - } - - return true; -} - - -SerialClasses::SerialClasses() - : m_arena(2097152) -{ -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!! - -SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags flags) - : m_arena(2097152), m_classes(classes), m_filter(filter), m_flags(flags) -{ - // 0 is always the null pointer - m_entries.add(nullptr); - m_ptrMap.add(nullptr, 0); -} - -struct SkipFunctionBodyRAII -{ - FunctionDeclBase* funcDecl = nullptr; - Stmt* oldBody = nullptr; - SkipFunctionBodyRAII(SerialWriter::Flags flags, const SerialClass* serialCls, const void* ptr) - { - if ((flags & SerialWriter::Flag::SkipFunctionBody) == 0) - return; - - if (serialCls->typeKind != SerialTypeKind::NodeBase) - return; - auto cls = serialCls; - while (cls) - { - auto astNodeType = (ASTNodeType)cls->subType; - if (astNodeType == ASTNodeType::FunctionDeclBase) - { - funcDecl = (FunctionDeclBase*)ptr; - break; - } - cls = cls->super; - } - if (funcDecl) - { - oldBody = funcDecl->body; - // We always need to include body of unsafeForceInlineEarly functions - // since they will need to be available at IR lowering time of the - // user module for pre-linking inling. - if (!isUnsafeForceInlineFunc(funcDecl)) - { - funcDecl->body = nullptr; - } - } - } - ~SkipFunctionBodyRAII() - { - if (funcDecl) - { - funcDecl->body = oldBody; - } - } -}; - -SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void* ptr) -{ - if (serialCls->flags & SerialClassFlag::DontSerialize) - { - return SerialIndex(0); - } - - if (serialCls->typeKind == SerialTypeKind::NodeBase && - ReflectClassInfo::isSubClassOf(serialCls->subType, Val::kReflectClassInfo)) - { - return writeValObject((Val*)ptr); - } - - // If we are skipping function bodies, set the body field to nullptr, and - // restore it after serialization. - SkipFunctionBodyRAII clearFunctionBodyRAII(m_flags, serialCls, ptr); - - // This pointer cannot be in the map - SLANG_ASSERT(m_ptrMap.tryGetValue(ptr) == nullptr); - - typedef SerialInfo::ObjectEntry ObjectEntry; - - ObjectEntry* nodeEntry = (ObjectEntry*)m_arena.allocateAligned( - sizeof(ObjectEntry) + serialCls->size, - SerialInfo::MAX_ALIGNMENT); - - nodeEntry->typeKind = serialCls->typeKind; - nodeEntry->subType = serialCls->subType; - nodeEntry->_pad0 = 0; - - nodeEntry->info = SerialInfo::makeEntryInfo(serialCls->alignment); - - // We add before adding fields, so if the fields point to this, the entry will be set - auto index = _add(ptr, nodeEntry); - - // Point to start of payload - uint8_t* serialPayload = (uint8_t*)(nodeEntry + 1); - - if (m_flags & Flag::ZeroInitialize) - { - ::memset(serialPayload, 0, serialCls->size); - } - - while (serialCls) - { - for (Index i = 0; i < serialCls->fieldsCount; ++i) - { - auto field = serialCls->fields[i]; - - // Work out the offsets - auto srcField = ((const uint8_t*)ptr) + field.nativeOffset; - auto dstField = serialPayload + field.serialOffset; - - field.type->toSerialFunc(this, srcField, dstField); - } - - // Get the super class - serialCls = serialCls->super; - } - - return index; -} - -SerialIndex SerialWriter::writeObject(const NodeBase* node) -{ - const SerialClass* serialClass = - m_classes->getSerialClass(SerialTypeKind::NodeBase, SerialSubType(node->astNodeType)); - return writeObject(serialClass, (const void*)node); -} - -SerialIndex SerialWriter::writeValObject(const Val* node) -{ - typedef SerialInfo::ValEntry ValEntry; - - size_t size = node->getOperandCount() * sizeof(SerialInfo::SerialValOperand); - ValEntry* nodeEntry = - (ValEntry*)m_arena.allocateAligned(sizeof(ValEntry) + size, SerialInfo::MAX_ALIGNMENT); - - nodeEntry->typeKind = SerialTypeKind::NodeBase; - nodeEntry->subType = (SerialSubType)node->astNodeType; - nodeEntry->operandCount = (uint32_t)node->getOperandCount(); - nodeEntry->info = SerialInfo::makeEntryInfo(SerialInfo::MAX_ALIGNMENT); - - // We add before adding fields, so if the fields point to this, the entry will be set - auto index = _add(node, nodeEntry); - - ShortList<SerialIndex, 4> serializedOperands; - - for (Index i = 0; i < node->getOperandCount(); i++) - { - auto operand = node->m_operands[i]; - switch (operand.kind) - { - case ValNodeOperandKind::ConstantValue: - serializedOperands.add((SerialIndex)0); - break; - case ValNodeOperandKind::ValNode: - case ValNodeOperandKind::ASTNode: - serializedOperands.add(addPointer(operand.values.nodeOperand)); - break; - } - } - - SLANG_ASSERT(serializedOperands.getCount() == node->getOperandCount()); - - auto serialOperands = (SerialInfo::SerialValOperand*)(nodeEntry + 1); - for (Index i = 0; i < node->getOperandCount(); i++) - { - auto serialOperand = serialOperands + i; - auto operand = node->m_operands[i]; - serialOperand->type = (int)operand.kind; - switch (operand.kind) - { - case ValNodeOperandKind::ConstantValue: - serialOperand->payload = operand.values.intOperand; - break; - case ValNodeOperandKind::ValNode: - serialOperand->payload = (uint64_t)serializedOperands[i]; - break; - case ValNodeOperandKind::ASTNode: - serialOperand->payload = (uint64_t)serializedOperands[i]; - break; - } - } - return index; -} - -SerialIndex SerialWriter::writeObject(const RefObject* obj) -{ - const SerialRefObject* serialObj = as<const SerialRefObject>(obj); - if (!serialObj) - { - SLANG_ASSERT(!"Unhandled type"); - return SerialIndex(0); - } - - const ReflectClassInfo* classInfo = serialObj->getClassInfo(); - SLANG_ASSERT(classInfo); - - const SerialClass* serialClass = - m_classes->getSerialClass(SerialTypeKind::RefObject, SerialSubType(classInfo->m_classId)); - return writeObject(serialClass, (const void*)obj); -} - -void SerialWriter::setPointerIndex(const NodeBase* ptr, SerialIndex index) -{ - m_ptrMap.add(ptr, Index(index)); -} - -void SerialWriter::setPointerIndex(const RefObject* ptr, SerialIndex index) -{ - m_ptrMap.add(ptr, Index(index)); -} - -SerialIndex SerialWriter::addPointer(const NodeBase* node) -{ - // Null is always 0 - if (node == nullptr) - { - return SerialIndex(0); - } - // Look up in the map - Index* indexPtr = m_ptrMap.tryGetValue(node); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - if (m_filter) - { - return m_filter->writePointer(this, node); - } - else - { - return writeObject(node); - } -} - -SerialIndex SerialWriter::addPointer(const RefObject* obj) -{ - // Null is always 0 - if (obj == nullptr) - { - return SerialIndex(0); - } - // Look up in the map - Index* indexPtr = m_ptrMap.tryGetValue(obj); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - // TODO(JS): - // Arguably the lookup for these types should be done the same way as arbitrary RefObject types - // and have a enum for them, such we can use a switch instead of all this casting - - if (auto stringRep = dynamicCast<StringRepresentation>(obj)) - { - SerialIndex index = addString(StringRepresentation::asSlice(stringRep)); - m_ptrMap.add(obj, Index(index)); - return index; - } - else if (auto name = dynamicCast<const Name>(obj)) - { - return addName(name); - } - - if (m_filter) - { - return m_filter->writePointer(this, obj); - } - else - { - return writeObject(obj); - } -} - -SerialIndex SerialWriter::_addStringSlice( - SerialTypeKind typeKind, - SliceMap& sliceMap, - const UnownedStringSlice& slice) -{ - typedef ByteEncodeUtil Util; - typedef SerialInfo::StringEntry StringEntry; - - if (slice.getLength() == 0) - { - return SerialIndex(0); - } - - Index* indexPtr = sliceMap.tryGetValue(slice); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - // Okay we need to add the string - - uint8_t encodeBuf[Util::kMaxLiteEncodeUInt32]; - const int encodeCount = Util::encodeLiteUInt32(uint32_t(slice.getLength()), encodeBuf); - - StringEntry* entry = (StringEntry*)m_arena.allocateUnaligned( - SLANG_OFFSET_OF(StringEntry, sizeAndChars) + encodeCount + slice.getLength()); - entry->info = SerialInfo::EntryInfo::Alignment1; - entry->typeKind = typeKind; - - uint8_t* dst = (uint8_t*)(entry->sizeAndChars); - for (int i = 0; i < encodeCount; ++i) - { - dst[i] = encodeBuf[i]; - } - - memcpy(dst + encodeCount, slice.begin(), slice.getLength()); - - // Make a key that will stay in scope -> it's actually just stored in the arena. - // NOTE! without terminating 0 - UnownedStringSlice keySlice(((const char*)dst) + encodeCount, slice.getLength()); - - Index newIndex = m_entries.getCount(); - sliceMap.add(keySlice, newIndex); - - m_entries.add(entry); - return SerialIndex(newIndex); -} - -SerialIndex SerialWriter::addString(const String& in) -{ - return addPointer(in.getStringRepresentation()); -} - -SerialIndex SerialWriter::addName(const Name* name) -{ - if (name == nullptr) - { - return SerialIndex(0); - } - - // Look it up - Index* indexPtr = m_ptrMap.tryGetValue(name); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - SerialIndex index = addString(name->text); - m_ptrMap.add(name, Index(index)); - return index; -} - -SerialIndex SerialWriter::addSerialArray( - size_t elementSize, - size_t alignment, - const void* elements, - Index elementCount) -{ - typedef SerialInfo::ArrayEntry Entry; - - if (elementCount == 0) - { - return SerialIndex(0); - } - - SLANG_ASSERT(alignment >= 1 && alignment <= SerialInfo::MAX_ALIGNMENT); - - // We must at a minimum have the alignment for the array prefix info - alignment = (alignment < SLANG_ALIGN_OF(Entry)) ? SLANG_ALIGN_OF(Entry) : alignment; - - size_t payloadSize = elementCount * elementSize; - - Entry* entry = (Entry*)m_arena.allocateAligned(sizeof(Entry) + payloadSize, alignment); - - entry->typeKind = SerialTypeKind::Array; - entry->info = SerialInfo::makeEntryInfo(int(alignment)); - entry->elementSize = uint16_t(elementSize); - entry->elementCount = uint32_t(elementCount); - - memcpy(entry + 1, elements, payloadSize); - - m_entries.add(entry); - return SerialIndex(m_entries.getCount() - 1); -} - -static const uint8_t s_fixBuffer[SerialInfo::MAX_ALIGNMENT]{ - 0, -}; - -SlangResult SerialWriter::write(Stream* stream) -{ - const Int entriesCount = m_entries.getCount(); - - // Add a sentinal so we don't need special handling for - SerialInfo::Entry sentinal; - sentinal.typeKind = SerialTypeKind::String; - sentinal.info = SerialInfo::EntryInfo::Alignment1; - - m_entries.add(&sentinal); - m_entries.removeLast(); - - SerialInfo::Entry** entries = m_entries.getBuffer(); - // Note strictly required in our impl of List. But by writing this and - // knowing that removeLast cannot release memory, means the sentinal must be at the last - // position. - entries[entriesCount] = &sentinal; - - { - size_t offset = 0; - - SerialInfo::Entry* entry = entries[1]; - // We start on 1, because 0 is nullptr and not used for anything - for (Index i = 1; i < entriesCount; ++i) - { - SerialInfo::Entry* next = entries[i + 1]; - - // Before writing we need to store the next alignment - - const size_t nextAlignment = SerialInfo::getAlignment(next->info); - const size_t alignment = SerialInfo::getAlignment(entry->info); - SLANG_UNUSED(alignment); - - entry->info = SerialInfo::combineWithNext(entry->info, next->info); - - // Check we are aligned correctly - SLANG_ASSERT((offset & (alignment - 1)) == 0); - - // When we write, we need to make sure it take into account the next alignment - const size_t entrySize = entry->calcSize(m_classes); - - // Work out the fix for next alignment - size_t nextOffset = offset + entrySize; - nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1); - - size_t alignmentFixSize = nextOffset - (offset + entrySize); - - // The fix must be less than max alignment. We require it to be less because we aligned - // each Entry to MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra bytes we can - // write - SLANG_ASSERT(alignmentFixSize < SerialInfo::MAX_ALIGNMENT); - - SLANG_RETURN_ON_FAIL(stream->write(entry, entrySize)); - // If we needed to fix so that subsequent alignment is right, write out extra bytes here - if (alignmentFixSize) - { - SLANG_RETURN_ON_FAIL(stream->write(s_fixBuffer, alignmentFixSize)); - } - - // Onto next - offset = nextOffset; - entry = next; - } - } - - return SLANG_OK; -} - -SlangResult SerialWriter::writeIntoContainer(FourCC fourCc, RiffContainer* container) -{ - typedef RiffContainer::Chunk Chunk; - typedef RiffContainer::ScopeChunk ScopeChunk; - - { - ScopeChunk scopeData(container, Chunk::Kind::Data, fourCc); - - { - // Sentinel so we don't need special handling for end of list - SerialInfo::Entry sentinal; - sentinal.typeKind = SerialTypeKind::String; - sentinal.info = SerialInfo::EntryInfo::Alignment1; - - size_t offset = 0; - const Int entriesCount = m_entries.getCount(); - - { - m_entries.add(&sentinal); - m_entries.removeLast(); - // Note strictly required in our impl of List. But by writing this and - // knowing that removeLast cannot release memory, means the sentinal must be at the - // last position. - m_entries.getBuffer()[entriesCount] = &sentinal; - } - - SerialInfo::Entry* const* entries = m_entries.getBuffer(); - - SerialInfo::Entry* entry = entries[1]; - // We start on 1, because 0 is nullptr and not used for anything - for (Index i = 1; i < entriesCount; ++i) - { - SerialInfo::Entry* next = entries[i + 1]; - - // Before writing we need to store the next alignment - - const size_t nextAlignment = SerialInfo::getAlignment(next->info); - const size_t alignment = SerialInfo::getAlignment(entry->info); - SLANG_UNUSED(alignment); - - entry->info = SerialInfo::combineWithNext(entry->info, next->info); - - // Check we are aligned correctly - SLANG_ASSERT((offset & (alignment - 1)) == 0); - - // When we write, we need to make sure it take into account the next alignment - const size_t entrySize = entry->calcSize(m_classes); - - // Work out the fix for next alignment - size_t nextOffset = offset + entrySize; - nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1); - - size_t alignmentFixSize = nextOffset - (offset + entrySize); - - // The fix must be less than max alignment. We require it to be less because we - // aligned each Entry to MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra - // bytes we can write - SLANG_ASSERT(alignmentFixSize < SerialInfo::MAX_ALIGNMENT); - - container->write(entry, entrySize); - if (alignmentFixSize) - { - container->write(s_fixBuffer, alignmentFixSize); - } - - // Onto next - offset = nextOffset; - entry = next; - } - } - } - - return SLANG_OK; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialInfo::Entry !!!!!!!!!!!!!!!!!!!!!!!! - -size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const -{ - switch (typeKind) - { - case SerialTypeKind::ImportSymbol: - case SerialTypeKind::String: - { - auto entry = static_cast<const StringEntry*>(this); - const uint8_t* cur = (const uint8_t*)entry->sizeAndChars; - uint32_t charsSize; - int sizeSize = ByteEncodeUtil::decodeLiteUInt32(cur, &charsSize); - return SLANG_OFFSET_OF(StringEntry, sizeAndChars) + sizeSize + charsSize; - } - case SerialTypeKind::Array: - { - auto entry = static_cast<const ArrayEntry*>(this); - return sizeof(ArrayEntry) + entry->elementSize * entry->elementCount; - } - case SerialTypeKind::RefObject: - case SerialTypeKind::NodeBase: - { - auto entry = static_cast<const ObjectEntry*>(this); - - auto serialClass = serialClasses->getSerialClass(typeKind, entry->subType); - - if (ReflectClassInfo::isSubClassOf(entry->subType, Val::kReflectClassInfo)) - return sizeof(ValEntry) + - static_cast<const ValEntry*>(this)->operandCount * sizeof(SerialValOperand); - - // Align by the alignment of the entry - size_t alignment = getAlignment(entry->info); - size_t size = sizeof(ObjectEntry) + serialClass->size; - - size = size + (alignment - 1) & ~(alignment - 1); - return size; - } - - default: - break; - } - - SLANG_ASSERT(!"Unknown type"); - return 0; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialReader !!!!!!!!!!!!!!!!!!!!!!!!!!!! - -SerialReader::~SerialReader() -{ - for (const RefObject* obj : m_scope) - { - const_cast<RefObject*>(obj)->releaseReference(); - } -} - -const void* SerialReader::getArray(SerialIndex index, Index& outCount) -{ - if (index == SerialIndex(0)) - { - outCount = 0; - return nullptr; - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - switch (entry->typeKind) - { - case SerialTypeKind::Array: - { - auto arrayEntry = static_cast<const SerialInfo::ArrayEntry*>(entry); - outCount = Index(arrayEntry->elementCount); - return (arrayEntry + 1); - } - default: - break; - } - - SLANG_ASSERT(!"Not an array"); - outCount = 0; - return nullptr; -} - -SerialPointer SerialReader::getPointer(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return SerialPointer(); - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - const SerialPointer& ptr = m_objects[Index(index)]; - - switch (entry->typeKind) - { - case SerialTypeKind::String: - { - // Hmm. Tricky -> we don't know if will be cast as Name or String. Lets assume string. - String string = getString(index); - return SerialPointer(string.getStringRepresentation()); - } - case SerialTypeKind::ImportSymbol: - { - if (ptr.m_kind == SerialTypeKind::Unknown) - { - // TODO(JS): - // Could have an error here, because import symbol was not set - // For now just return nullptr - return SerialPointer(); - } - break; - } - default: - break; - } - - return ptr; -} - -SerialPointer SerialReader::getValPointer(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return SerialPointer(); - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - - SerialPointer& ptr = m_objects[Index(index)]; - - if (ptr.m_ptr) - return ptr; - - const SerialInfo::ValEntry* entry = (SerialInfo::ValEntry*)m_entries[Index(index)]; - ValNodeDesc desc; - desc.type = (ASTNodeType)entry->subType; - auto readPtr = (SerialInfo::SerialValOperand*)(entry + 1); - for (uint32_t i = 0; i < entry->operandCount; i++) - { - auto serialOperand = readPtr[i]; - ValNodeOperand operand; - operand.kind = (ValNodeOperandKind)(serialOperand.type); - switch (operand.kind) - { - case ValNodeOperandKind::ConstantValue: - operand.values.intOperand = serialOperand.payload; - break; - case ValNodeOperandKind::ASTNode: - operand.values.nodeOperand = - (NodeBase*)getPointer((SerialIndex)serialOperand.payload).m_ptr; - break; - case ValNodeOperandKind::ValNode: - operand.values.nodeOperand = - (Val*)getValPointer((SerialIndex)serialOperand.payload).m_ptr; - break; - } - desc.operands.add(operand); - } - desc.init(); - ptr.m_kind = SerialTypeKind::NodeBase; - ptr.m_ptr = this->m_objectFactory->getOrCreateVal(_Move(desc)); - return ptr; -} - -String SerialReader::getString(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return String(); - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - // It has to be a string type - if (entry->typeKind != SerialTypeKind::String) - { - SLANG_ASSERT(!"Not a string"); - return String(); - } - - RefObject* obj = m_objects[Index(index)].dynamicCast<RefObject>(); - - if (obj) - { - StringRepresentation* stringRep = dynamicCast<StringRepresentation>(obj); - if (stringRep) - { - return String(stringRep); - } - // Must be a name then - Name* name = dynamicCast<Name>(obj); - SLANG_ASSERT(name); - return name->text; - } - - // Okay we need to construct as a string - UnownedStringSlice slice = getStringSlice(index); - - StringRepresentation* stringRep = nullptr; - - const Index length = slice.getLength(); - if (length) - { - stringRep = StringRepresentation::createWithCapacityAndLength(length, length); - memcpy(stringRep->getData(), slice.begin(), length * sizeof(char)); - addScope(stringRep); - } - - m_objects[Index(index)] = stringRep; - return String(stringRep); -} - -Name* SerialReader::getName(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return nullptr; - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - // It has to be a string type - if (entry->typeKind != SerialTypeKind::String) - { - SLANG_ASSERT(!"Not a string"); - return nullptr; - } - - RefObject* obj = m_objects[Index(index)].dynamicCast<RefObject>(); - - if (obj) - { - Name* name = dynamicCast<Name>(obj); - if (name) - { - return name; - } - // Can only be a string then - StringRepresentation* stringRep = dynamicCast<StringRepresentation>(obj); - SLANG_ASSERT(stringRep); - - // I don't need to scope, as scoped in NamePool - name = m_namePool->getName(String(stringRep)); - - // Store as name, as can always access the inner string if needed - m_objects[Index(index)] = name; - return name; - } - - UnownedStringSlice slice = getStringSlice(index); - String string(slice); - Name* name = m_namePool->getName(string); - // Don't need to add to scope, because scoped on the pool - m_objects[Index(index)] = name; - return name; -} - -UnownedStringSlice SerialReader::getStringSlice(SerialIndex index) -{ - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - // It has to be a string type - if (entry->typeKind == SerialTypeKind::String || - entry->typeKind == SerialTypeKind::ImportSymbol) - { - auto stringEntry = static_cast<const SerialInfo::StringEntry*>(entry); - - const uint8_t* src = (const uint8_t*)stringEntry->sizeAndChars; - - // Decode the string - uint32_t size; - int sizeSize = ByteEncodeUtil::decodeLiteUInt32(src, &size); - return UnownedStringSlice((const char*)src + sizeSize, size); - } - - // Can't be accessed as a slice - SLANG_ASSERT(!"Not accessible as a slice"); - return UnownedStringSlice(); -} - -/* static */ SlangResult SerialReader::loadEntries( - const uint8_t* data, - size_t dataCount, - SerialClasses* serialClasses, - List<const Entry*>& outEntries) -{ - // Check the input data is at least aligned to the max alignment (otherwise everything cannot be - // aligned correctly) - SLANG_ASSERT((size_t(data) & (SerialInfo::MAX_ALIGNMENT - 1)) == 0); - - outEntries.setCount(1); - outEntries[0] = nullptr; - - const uint8_t* const end = data + dataCount; - - const uint8_t* cur = data; - while (cur < end) - { - const Entry* entry = (const Entry*)cur; - outEntries.add(entry); - - const size_t entrySize = entry->calcSize(serialClasses); - cur += entrySize; - - // Need to get the next alignment - const size_t nextAlignment = SerialInfo::getNextAlignment(entry->info); - - // Need to fix cur with the alignment - cur = (const uint8_t*)((size_t(cur) + nextAlignment - 1) & ~(nextAlignment - 1)); - } - - return SLANG_OK; -} - -SlangResult SerialReader::constructObjects(NamePool* namePool) -{ - m_namePool = namePool; - - m_objects.clearAndDeallocate(); - m_objects.setCount(m_entries.getCount()); - memset(m_objects.getBuffer(), 0, m_objects.getCount() * sizeof(void*)); - - // Go through entries, constructing objects. - for (Index i = 1; i < m_entries.getCount(); ++i) - { - const Entry* entry = m_entries[i]; - - switch (entry->typeKind) - { - case SerialTypeKind::ImportSymbol: - { - // We don't construct any object for an imported symbol. - // It will be the responsibility of external code to interpet the symbols and *set* - // the appopriate objects prior to a call to `deserializeObjects` - break; - } - case SerialTypeKind::String: - { - // Don't need to construct an object. This is probably a StringRepresentation, or a - // Name Will evaluate lazily. - break; - } - case SerialTypeKind::RefObject: - case SerialTypeKind::NodeBase: - { - auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry); - - // Don't create object for Vals. - if (objectEntry->typeKind == SerialTypeKind::NodeBase && - ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo)) - break; - - void* obj = m_objectFactory->create(objectEntry->typeKind, objectEntry->subType); - if (!obj) - { - return SLANG_FAIL; - } - m_objects[i].set(entry->typeKind, obj); - break; - } - case SerialTypeKind::Array: - { - // Don't need to construct an object, as will be accessed and interpreted by the - // object that holds it - break; - } - } - } - - return SLANG_OK; -} - -SlangResult SerialReader::deserializeObjects() -{ - // Deserialize - for (Index i = 1; i < m_entries.getCount(); ++i) - { - const Entry* entry = m_entries[i]; - // First see if there is anything to construct - SerialPointer& dstPtr = m_objects[i]; - if (!dstPtr) - { - continue; - } - switch (entry->typeKind) - { - case SerialTypeKind::NodeBase: - case SerialTypeKind::RefObject: - { - auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry); - auto serialClass = - m_classes->getSerialClass(objectEntry->typeKind, objectEntry->subType); - if (!serialClass) - { - return SLANG_FAIL; - } - if (ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo)) - continue; - - const uint8_t* src = (const uint8_t*)(objectEntry + 1); - uint8_t* dst = (uint8_t*)dstPtr.m_ptr; - - // It must be constructed - SLANG_ASSERT(dst); - - while (serialClass) - { - for (Index j = 0; j < serialClass->fieldsCount; ++j) - { - auto field = serialClass->fields[j]; - auto fieldType = field.type; - fieldType->toNativeFunc( - this, - src + field.serialOffset, - dst + field.nativeOffset); - } - - // Get the super class - serialClass = serialClass->super; - } - - break; - } - default: - break; - } - } - - return SLANG_OK; -} - - -SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* namePool) -{ - // Load and place entries into entries table - SLANG_RETURN_ON_FAIL(loadEntries(data, dataCount)); - // Construct all of the objects - SLANG_RETURN_ON_FAIL(constructObjects(namePool)); - SLANG_RETURN_ON_FAIL(deserializeObjects()); - return SLANG_OK; -} - } // namespace Slang diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index 4c7189d62..20916f735 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -27,702 +27,465 @@ class NodeBase; class Val; struct ValNodeDesc; -// Pre-declare -class SerialClasses; -class SerialWriter; -class SerialReader; - -struct SerialClass; -struct SerialField; - -// Type used to implement mechanisms to convert to and from serial types. -template<typename T, typename /*enumTypeSFINAE*/ = void> -struct SerialTypeInfo; - -enum class SerialTypeKind : uint8_t +struct Encoder { - Unknown, - - String, ///< String - Array, ///< Array - ImportSymbol, ///< Holds the name of the import symbol. Represented in exactly the same way as a - ///< string - - NodeBase, ///< NodeBase derived - RefObject, ///< RefObject derived types - - CountOf, -}; -typedef uint16_t SerialSubType; - -struct SerialInfo -{ - enum - { - // Data held in serialized format, the maximally allowed alignment - MAX_ALIGNMENT = 8, - }; - - // We only allow up to MAX_ALIGNMENT bytes of alignment. We store alignments as shifts, so 2 - // bits needed for 1 - 8 - enum class EntryInfo : uint8_t - { - Alignment1 = 0, - }; - - static EntryInfo makeEntryInfo(int alignment, int nextAlignment) - { - // Make sure they are power of 2 - SLANG_ASSERT((alignment & (alignment - 1)) == 0); - SLANG_ASSERT((nextAlignment & (nextAlignment - 1)) == 0); - - const int alignmentShift = ByteEncodeUtil::calcMsb8(alignment); - const int nextAlignmentShift = ByteEncodeUtil::calcMsb8(nextAlignment); - return EntryInfo((nextAlignmentShift << 2) | alignmentShift); - } - static EntryInfo makeEntryInfo(int alignment) - { - // Make sure they are power of 2 - SLANG_ASSERT((alignment & (alignment - 1)) == 0); - return EntryInfo(ByteEncodeUtil::calcMsb8(alignment)); - } - /// Apply with the next alignment - static EntryInfo combineWithNext(EntryInfo cur, EntryInfo next) +public: + Encoder(Stream* stream) + : _stream(stream) { - return EntryInfo((int(cur) & ~0xc0) | ((int(next) & 3) << 2)); } - static int getAlignment(EntryInfo info) { return 1 << (int(info) & 3); } - static int getNextAlignment(EntryInfo info) { return 1 << ((int(info) >> 2) & 3); } - - /* Alignment is a little tricky. We have a 'Entry' header before the payload. The payload - alignment may change. If we only align on the Entry header, then it's size *must* be some modulo - of the maximum alignment allowed. + ~Encoder() { RiffUtil::write(&_riff, _stream); } - We could hold Entry separate from payload. We could make the header not require the alignment of - the payload - but then we'd need payload alignment separate from entry alignment. - */ - struct Entry + void beginArray(FourCC typeCode) { - SerialTypeKind typeKind; - EntryInfo info; + _riff.startChunk(RiffContainer::Chunk::Kind::List, typeCode); + } - size_t calcSize(SerialClasses* serialClasses) const; - }; + void beginArray() { beginArray(SerialBinary::kArrayFourCC); } - struct StringEntry : Entry + void endArray() { - char sizeAndChars[1]; - }; + _riff.endChunk(); + // TODO: maybe end key... + } - struct ObjectEntry : Entry + void beginObject(FourCC typeCode) { - SerialSubType - subType; ///< Can be ASTType or other subtypes (as used for RefObjects for example) - uint32_t _pad0; ///< Necessary, because a node *can* have MAX_ALIGNEMENT - }; + _riff.startChunk(RiffContainer::Chunk::Kind::List, typeCode); + } - struct ValEntry : Entry - { - SerialSubType subType; - uint32_t operandCount; - }; + void beginObject() { beginObject(SerialBinary::kObjectFourCC); } - struct ArrayEntry : Entry - { - uint16_t elementSize; - uint32_t elementCount; - }; + void endObject() { _riff.endChunk(); } - struct SerialValOperand + void beginKeyValuePair() { - int type; - uint64_t payload; - }; -}; + _riff.startChunk(RiffContainer::Chunk::Kind::List, SerialBinary::kPairFourCC); + } -typedef uint32_t SerialIndexRaw; -enum class SerialIndex : SerialIndexRaw; + void endKeyValuePair() { _riff.endChunk(); } -/* A type to convert pointers into types such that they can be passed around to readers/writers -without having to know the specific type. If there was a base class that all the serialized types -derived from, that was dynamically castable this would not be necessary */ -struct SerialPointer -{ - // Helpers so we can choose what kind of pointer we have based on the (unused) type of the - // pointer passed in - SLANG_FORCE_INLINE RefObject* _get(const RefObject*) - { - return m_kind == SerialTypeKind::RefObject ? reinterpret_cast<RefObject*>(m_ptr) : nullptr; - } - SLANG_FORCE_INLINE NodeBase* _get(const NodeBase*) + void beginKeyValuePair(FourCC keyCode) { - return m_kind == SerialTypeKind::NodeBase ? reinterpret_cast<NodeBase*>(m_ptr) : nullptr; + _riff.startChunk(RiffContainer::Chunk::Kind::List, keyCode); } - template<typename T> - T* dynamicCast() + void encodeData(FourCC typeCode, void const* data, size_t size) { - return Slang::dynamicCast<T>(_get((T*)nullptr)); + _riff.startChunk(RiffContainer::Chunk::Kind::Data, typeCode); + _riff.write(data, size); + _riff.endChunk(); } - SerialPointer() - : m_kind(SerialTypeKind::Unknown), m_ptr(nullptr) + void encodeData(void const* data, size_t size) { + encodeData(SerialBinary::kDataFourCC, data, size); } - SerialPointer(RefObject* in) - : m_kind(SerialTypeKind::RefObject), m_ptr((void*)in) - { - } - SerialPointer(NodeBase* in) - : m_kind(SerialTypeKind::NodeBase), m_ptr((void*)in) + void encode(nullptr_t) { encodeData(SerialBinary::kNullFourCC, nullptr, 0); } + + void encodeBool(bool value) { + encodeData(value ? SerialBinary::kTrueFourCC : SerialBinary::kFalseFourCC, nullptr, 0); } - /// True if the ptr is set - SLANG_FORCE_INLINE operator bool() const { return m_ptr != nullptr; } + void encode(Int32 value) { encodeData(SerialBinary::kInt32FourCC, &value, sizeof(value)); } - /// Directly set pointer/kind - void set(SerialTypeKind kind, void* ptr) - { - m_kind = kind; - m_ptr = ptr; - } + void encode(UInt32 value) { encodeData(SerialBinary::kUInt32FourCC, &value, sizeof(value)); } - static SerialTypeKind getKind(const RefObject*) { return SerialTypeKind::RefObject; } - static SerialTypeKind getKind(const NodeBase*) { return SerialTypeKind::NodeBase; } + void encode(Int64 value) { encodeData(SerialBinary::kInt64FourCC, &value, sizeof(value)); } - SerialTypeKind m_kind; - void* m_ptr; -}; + void encode(UInt64 value) { encodeData(SerialBinary::kUInt64FourCC, &value, sizeof(value)); } -class SerialFilter -{ -public: - virtual SerialIndex writePointer(SerialWriter* writer, const NodeBase* ptr) = 0; - virtual SerialIndex writePointer(SerialWriter* writer, const RefObject* ptr) = 0; -}; + void encode(float value) { encodeData(SerialBinary::kFloat32FourCC, &value, sizeof(value)); } -class SerialObjectFactory -{ -public: - virtual void* create(SerialTypeKind typeKind, SerialSubType subType) = 0; - virtual void* getOrCreateVal(ValNodeDesc&& desc) = 0; -}; + void encode(double value) { encodeData(SerialBinary::kFloat64FourCC, &value, sizeof(value)); } -class SerialExtraObjects -{ -public: - template<typename T> - void set(T* obj) - { - m_objects[Index(T::kExtraType)] = obj; - } - template<typename T> - void set(const RefPtr<T>& obj) + void encodeString(String const& value) { - m_objects[Index(T::kExtraType)] = obj.Ptr(); + Int size = value.getLength(); + encodeData(SerialBinary::kStringFourCC, value.getBuffer(), size); } - /// Get the extra type - template<typename T> - T* get() - { - return reinterpret_cast<T*>(m_objects[Index(T::kExtraType)]); - } - SerialExtraObjects() - { - for (auto& obj : m_objects) - obj = nullptr; - } + void encode(String const& value) { encodeString(value); } -protected: - void* m_objects[Index(SerialExtraType::CountOf)]; -}; + struct WithArray + { + public: + WithArray(Encoder* encoder) + : _encoder(encoder) + { + encoder->beginArray(); + } -enum class PostSerializationFixUpKind -{ - ValPtr, -}; + WithArray(Encoder* encoder, FourCC typeCode) + : _encoder(encoder) + { + encoder->beginArray(typeCode); + } -/* This class is the interface used by toNative implementations to recreate a type. */ -class SerialReader : public RefObject -{ -public: - typedef SerialInfo::Entry Entry; + ~WithArray() { _encoder->endArray(); } - template<typename T> - void getArray(SerialIndex index, List<T>& out); + private: + Encoder* _encoder; + }; - template<typename T, int n> - void getArray(SerialIndex index, ShortList<T, n>& out); + struct WithObject + { + public: + WithObject(Encoder* encoder) + : _encoder(encoder) + { + encoder->beginObject(); + } - const void* getArray(SerialIndex index, Index& outCount); + WithObject(Encoder* encoder, FourCC typeCode) + : _encoder(encoder) + { + encoder->beginObject(typeCode); + } - SerialPointer getPointer(SerialIndex index); - SerialPointer getValPointer(SerialIndex index); + ~WithObject() { _encoder->endObject(); } - String getString(SerialIndex index); - Name* getName(SerialIndex index); - UnownedStringSlice getStringSlice(SerialIndex index); + private: + Encoder* _encoder; + }; - SlangResult loadEntries(const uint8_t* data, size_t dataCount) - { - return loadEntries(data, dataCount, m_classes, m_entries); - } - /// For each entry construct an object. Does *NOT* deserialize them - SlangResult constructObjects(NamePool* namePool); - /// Entries must be loaded (with loadEntries), and objects constructed (with constructObjects) - /// before deserializing - SlangResult deserializeObjects(); - - /// NOTE! data must stay ins scope when reading takes place - SlangResult load(const uint8_t* data, size_t dataCount, NamePool* namePool); - - /// Get the entries list - const List<const Entry*>& getEntries() const { return m_entries; } - - /// Access the objects list - /// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the - /// RefObject* via addScope - List<SerialPointer>& getObjects() { return m_objects; } - const List<SerialPointer>& getObjects() const { return m_objects; } - - /// Add an object to be kept in scope - void addScopeWithoutAddRef(const RefObject* obj) { m_scope.add(obj); } - /// Add obj with a reference - void addScope(const RefObject* obj) + struct WithKeyValuePair { - const_cast<RefObject*>(obj)->addReference(); - m_scope.add(obj); - } + public: + WithKeyValuePair(Encoder* encoder) + : _encoder(encoder) + { + encoder->beginKeyValuePair(); + } - /// Used for attaching extra objects necessary for serializing - SerialExtraObjects& getExtraObjects() { return m_extraObjects; } + WithKeyValuePair(Encoder* encoder, FourCC typeCode) + : _encoder(encoder) + { + encoder->beginKeyValuePair(typeCode); + } - /// Ctor - SerialReader(SerialClasses* classes, SerialObjectFactory* objectFactory) - : m_classes(classes), m_objectFactory(objectFactory) - { - } - ~SerialReader(); + ~WithKeyValuePair() { _encoder->endKeyValuePair(); } - /// Load the entries table (without deserializing anything) - /// NOTE! data must stay ins scope for outEntries to be valid - static SlangResult loadEntries( - const uint8_t* data, - size_t dataCount, - SerialClasses* serialClasses, - List<const Entry*>& outEntries); + private: + Encoder* _encoder; + }; -protected: - List<const Entry*> m_entries; ///< The entries +private: + Stream* _stream = nullptr; - List<SerialPointer> m_objects; ///< The constructed objects - NamePool* m_namePool; ///< Pool names are added to + // Implementation details below... + RiffContainer _riff; - List<const RefObject*> m_scope; ///< Keeping objects in scope +public: + RiffContainer* getRIFF() { return &_riff; } - SerialExtraObjects m_extraObjects; + RiffContainer::Chunk* getRIFFChunk() { return _riff.getCurrentChunk(); } - SerialObjectFactory* m_objectFactory; - SerialClasses* m_classes; ///< Information used to deserialize + void setRIFFChunk(RiffContainer::Chunk* chunk) { _riff.setCurrentChunk(chunk); } }; -// --------------------------------------------------------------------------- -template<typename T> -void SerialReader::getArray(SerialIndex index, List<T>& out) +struct Decoder { - typedef SerialTypeInfo<T> ElementTypeInfo; - typedef typename ElementTypeInfo::SerialType ElementSerialType; - - Index count; - auto serialElements = (const ElementSerialType*)getArray(index, count); - - if (count == 0) +public: + Decoder(RiffContainer::Chunk* chunk) + : _chunk(chunk) { - out.clear(); - return; } - if (std::is_same<T, ElementSerialType>::value) + bool decodeBool() { - // If they are the same we can just write out - out.clear(); - out.insertRange(0, (const T*)serialElements, count); - } - else - { - // Else we need to convert - out.setCount(count); - for (Index i = 0; i < count; ++i) + switch (getTag()) { - ElementTypeInfo::toNative(this, (const void*)&serialElements[i], (void*)&out[i]); + case SerialBinary::kTrueFourCC: + _chunk = _chunk->m_next; + return true; + case SerialBinary::kFalseFourCC: + _chunk = _chunk->m_next; + return false; + + default: + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(false); } } -} - -template<typename T, int n> -void SerialReader::getArray(SerialIndex index, ShortList<T, n>& out) -{ - typedef SerialTypeInfo<T> ElementTypeInfo; - typedef typename ElementTypeInfo::SerialType ElementSerialType; - - Index count; - auto serialElements = (const ElementSerialType*)getArray(index, count); - if (count == 0) + String decodeString() { - out.clear(); - return; - } + if (getTag() != SerialBinary::kStringFourCC) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(""); + } - if (std::is_same<T, ElementSerialType>::value) - { - // If they are the same we can just write out - out.clear(); - out.addRange((const T*)serialElements, count); - } - else - { - // Else we need to convert - out.setCount(count); - for (Index i = 0; i < count; ++i) + auto dataChunk = as<RiffContainer::DataChunk>(_chunk); + if (!dataChunk) { - ElementTypeInfo::toNative(this, (const void*)&serialElements[i], (void*)&out[i]); + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(""); } + + auto size = dataChunk->calcPayloadSize(); + + String value; + value.appendRepeatedChar(' ', size); + dataChunk->getPayload((char*)value.getBuffer()); + + _chunk = _chunk->m_next; + return value; } -} -/* This is a class used tby toSerial implementations to turn native type into the serial type */ -class SerialWriter : public RefObject -{ -public: - typedef uint32_t Flags; - struct Flag + void decodeData(FourCC typeTag, void* outData, size_t dataSize) { - enum Enum : Flags + if (getTag() == typeTag) { - /// If set will zero initialize backing memory. This is slower but - /// is desirable to make two serializations of the same thing produce the - /// identical serialized result. - ZeroInitialize = 0x1, - - /// If set will not serialize function body. - SkipFunctionBody = 0x2, - }; - }; - - SerialIndex addPointer(const NodeBase* ptr); - SerialIndex addPointer(const RefObject* ptr); - - /// Write the object at ptr of type serialCls - SerialIndex writeObject(const SerialClass* serialCls, const void* ptr); + auto dataChunk = as<RiffContainer::DataChunk>(_chunk); + if (dataChunk) + { + if (dataChunk->calcPayloadSize() >= dataSize) + { + dataChunk->getPayload(outData); + _chunk = _chunk->m_next; + return; + } + } + } - /// Write the object at the pointer - SerialIndex writeObject(const NodeBase* ptr); - SerialIndex writeObject(const RefObject* ptr); - SerialIndex writeValObject(const Val* ptr); + SLANG_UNEXPECTED("invalid format in RIFF"); + } - /// Add an array - may need to convert to serialized format template<typename T> - SerialIndex addArray(const T* in, Index count); - - template<typename NATIVE_TYPE> - /// Add an array where all the elements are already in serialized format (ie there is no need to - /// do a conversion) - SerialIndex addSerialArray(const void* elements, Index elementCount) + T _decodeSimpleValue(FourCC typeTag) { - typedef SerialTypeInfo<NATIVE_TYPE> TypeInfo; - return addSerialArray( - sizeof(typename TypeInfo::SerialType), - SerialTypeInfo<NATIVE_TYPE>::SerialAlignment, - elements, - elementCount); + T value; + decodeData(typeTag, &value, sizeof(value)); + return value; } - /// Add an array where all the elements are already in serialized format (ie there is no need to - /// do a conversion) - SerialIndex addSerialArray( - size_t elementSize, - size_t alignment, - const void* elements, - Index elementCount); + Int64 decodeInt64() { return _decodeSimpleValue<Int64>(SerialBinary::kInt64FourCC); } - /// Add the string - SerialIndex addString(const UnownedStringSlice& slice) - { - return _addStringSlice(SerialTypeKind::String, m_sliceMap, slice); - } - SerialIndex addString(const String& in); - SerialIndex addName(const Name* name); + UInt64 decodeUInt64() { return _decodeSimpleValue<UInt64>(SerialBinary::kUInt64FourCC); } - /// Adding import symbols - SerialIndex addImportSymbol(const UnownedStringSlice& slice) - { - return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); - } - SerialIndex addImportSymbol(const String& string) - { - return _addStringSlice( - SerialTypeKind::ImportSymbol, - m_importSymbolMap, - string.getUnownedSlice()); - } + Int32 decodeInt32() { return _decodeSimpleValue<Int32>(SerialBinary::kInt32FourCC); } - /// Set a the ptr associated with an index. - /// NOTE! That there cannot be a pre-existing setting. - void setPointerIndex(const NodeBase* ptr, SerialIndex index); - void setPointerIndex(const RefObject* ptr, SerialIndex index); + UInt32 decodeUInt32() { return _decodeSimpleValue<UInt32>(SerialBinary::kUInt32FourCC); } - /// Get the entries table holding how each index maps to an entry - const List<SerialInfo::Entry*>& getEntries() const { return m_entries; } + float decodeFloat32() { return _decodeSimpleValue<float>(SerialBinary::kFloat32FourCC); } - /// Write to a stream - SlangResult write(Stream* stream); + double decodeFloat64() { return _decodeSimpleValue<double>(SerialBinary::kFloat64FourCC); } - /// Write a data chunk with fourCC - SlangResult writeIntoContainer(FourCC fourCC, RiffContainer* container); - /// Used for attaching extra objects necessary for serializing - SerialExtraObjects& getExtraObjects() { return m_extraObjects; } + FourCC getTag() { return _chunk ? _chunk->m_fourCC : 0; } - /// Get the flag - Flags getFlags() const { return m_flags; } + Int32 _decodeImpl(Int32*) { return decodeInt32(); } + UInt32 _decodeImpl(UInt32*) { return decodeUInt32(); } - /// Ctor - SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags flags = Flag::ZeroInitialize); + Int64 _decodeImpl(Int64*) { return decodeInt64(); } + UInt64 _decodeImpl(UInt64*) { return decodeUInt64(); } -protected: - typedef Dictionary<UnownedStringSlice, Index> SliceMap; + float _decodeImpl(float*) { return decodeFloat32(); } + double _decodeImpl(double*) { return decodeFloat64(); } - SerialIndex _addStringSlice( - SerialTypeKind typeKind, - SliceMap& sliceMap, - const UnownedStringSlice& slice); + template<typename T> + T decode() + { + return _decodeImpl((T*)nullptr); + } - SerialIndex _add(const void* nativePtr, SerialInfo::Entry* entry) + template<typename T> + void decode(T& outValue) { - m_entries.add(entry); - // Okay I need to allocate space for this - SerialIndex index = SerialIndex(m_entries.getCount() - 1); - // Add to the map - m_ptrMap.add(nativePtr, Index(index)); - return index; + outValue = _decodeImpl((T*)nullptr); } - Dictionary<const void*, Index> m_ptrMap; // Maps a pointer to an entry index + void beginArray(FourCC typeCode = SerialBinary::kArrayFourCC) + { + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } - // NOTE! Assumes the content stays in scope! - SliceMap m_sliceMap; - SliceMap m_importSymbolMap; + if (listChunk->m_fourCC != typeCode) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } - SerialExtraObjects m_extraObjects; ///< Extra objects + _chunk = listChunk->getFirstContainedChunk(); + } - List<SerialInfo::Entry*> m_entries; ///< The entries - MemoryArena m_arena; ///< Holds the payloads - SerialClasses* m_classes; - SerialFilter* m_filter; ///< Filter to control what is serialized + void beginObject(FourCC typeCode = SerialBinary::kObjectFourCC) + { + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } - Flags m_flags; ///< Flags to control behavior -}; + if (listChunk->m_fourCC != typeCode) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } -// --------------------------------------------------------------------------- -template<typename T> -SerialIndex SerialWriter::addArray(const T* in, Index count) -{ - typedef SerialTypeInfo<T> ElementTypeInfo; - typedef typename ElementTypeInfo::SerialType ElementSerialType; + _chunk = listChunk->getFirstContainedChunk(); + } - if (std::is_same<T, ElementSerialType>::value) + void beginKeyValuePair(FourCC typeCode = SerialBinary::kPairFourCC) { - // If they are the same we can just write out - return addSerialArray(sizeof(T), SLANG_ALIGN_OF(ElementSerialType), in, count); + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } + + if (listChunk->m_fourCC != typeCode) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } + + _chunk = listChunk->getFirstContainedChunk(); } - else - { - // Else we need to convert - List<ElementSerialType> work; - work.setCount(count); - if (getFlags() & Flag::ZeroInitialize) + void beginProperty(FourCC propertyCode) + { + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) { - ::memset(work.getBuffer(), 0, sizeof(ElementSerialType) * count); + SLANG_UNEXPECTED("invalid format in RIFF"); } - for (Index i = 0; i < count; ++i) + auto found = listChunk->findContainedList(propertyCode); + if (!found) { - ElementTypeInfo::toSerial(this, &in[i], &work[i]); + SLANG_UNEXPECTED("invalid format in RIFF"); } - return addSerialArray( - sizeof(ElementSerialType), - SLANG_ALIGN_OF(ElementSerialType), - work.getBuffer(), - count); - } -} -/* A SerialFieldType describes the size of field, it's alignment, and contains the -functions that convert between serial and native data */ -struct SerialFieldType -{ - typedef void (*ToSerialFunc)(SerialWriter* writer, const void* src, void* dst); - typedef void (*ToNativeFunc)(SerialReader* reader, const void* src, void* dst); + _chunk = found->getFirstContainedChunk(); + } - size_t serialSizeInBytes; - uint8_t serialAlignment; - ToSerialFunc toSerialFunc; - ToNativeFunc toNativeFunc; -}; + bool hasElements() { return _chunk != nullptr; } -/* Describes a field in a SerialClass. */ -struct SerialField -{ - /// Returns a suitable ptr for use in make. - /// NOTE! Sets to 1 so it's constant and not 0 (and so nullptr) - template<typename T> - static T* getPtr() + bool isNull() { - return (T*)1; + if (_chunk == nullptr) + return true; + if (getTag() == SerialBinary::kNullFourCC) + return true; + return false; } - template<typename T> - static SerialField make(const char* name, T* in); - - const char* name; ///< The name of the field - const SerialFieldType* type; ///< The type of the field - uint32_t nativeOffset; ///< Offset to field from base of type - uint32_t serialOffset; ///< Offset in serial type -}; + bool decodeNull() + { + if (!isNull()) + return false; -typedef uint8_t SerialClassFlags; + if (_chunk != nullptr) + { + _chunk = _chunk->m_next; + } + return true; + } -struct SerialClassFlag -{ - enum Enum : SerialClassFlags + struct WithArray { - DontSerialize = - 0x01, ///< If set the type is not serialized, so can turn into SerialIndex(0) - }; -}; + public: + WithArray(Decoder& decoder) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginArray(); + } -/* SerialClass defines the type (typeKind/subType) and the fields in just this class definition (ie -not it's super class). Also contains a pointer to the super type if there is one */ -struct SerialClass -{ - SerialTypeKind typeKind; ///< The type kind - SerialSubType subType; ///< Subtype - meaning depends on typeKind + WithArray(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginArray(typeCode); + } - uint8_t alignment; ///< Alignment of this type - SerialClassFlags flags; ///< Flags + ~WithArray() { _decoder._chunk = _saved->m_next; } - uint32_t size; ///< Size of the field in bytes + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - Index fieldsCount; - const SerialField* fields; + struct WithObject + { + public: + WithObject(Decoder& decoder) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginObject(); + } - const SerialClass* super; ///< The super class -}; + WithObject(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginObject(typeCode); + } -// An instance could be shared across Sessions, but for simplicity of life time -// here we don't deal with that -class SerialClasses : public RefObject -{ -public: - /// Will add it's own copy into m_classesByType - /// In process will calculate alignment, offset etc for fields - /// NOTE! the super set, *must* be an already added to this SerialClasses - const SerialClass* add(const SerialClass* cls); + ~WithObject() { _decoder._chunk = _saved->m_next; } - const SerialClass* add( - SerialTypeKind kind, - SerialSubType subType, - const SerialField* fields, - Index fieldsCount, - const SerialClass* superCls); + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - /// Add a type which will not serialize - const SerialClass* addUnserialized(SerialTypeKind kind, SerialSubType subType); + struct WithKeyValuePair + { + public: + WithKeyValuePair(Decoder& decoder) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginKeyValuePair(); + } - /// Returns true if this cls is *owned* by this SerialClasses - bool isOwned(const SerialClass* cls) const; + WithKeyValuePair(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + _decoder.beginKeyValuePair(typeCode); + } - /// Returns true if the SerialClasses structure appears ok - bool isOk() const; + ~WithKeyValuePair() { _decoder._chunk = _saved->m_next; } - /// Get a serial class based on its type/subType - const SerialClass* getSerialClass(SerialTypeKind typeKind, SerialSubType subType) const - { - const auto& classes = m_classesByTypeKind[Index(typeKind)]; - return (subType < classes.getCount()) ? classes[subType] : nullptr; - } + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - /// Ctor - SerialClasses(); + struct WithProperty + { + public: + WithProperty(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + _decoder.beginProperty(typeCode); + } -protected: - SerialClass* _createSerialClass(const SerialClass* cls); + ~WithProperty() { _decoder._chunk = _saved->m_next; } - MemoryArena m_arena; + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - List<const SerialClass*> m_classesByTypeKind[Index(SerialTypeKind::CountOf)]; -}; -// !!!!!!!!!!!!!!!!!!!!! SerialGetFieldType<T> !!!!!!!!!!!!!!!!!!!!!!!!!!! -// Getting the type info, let's use a static variable to hold the state to keep simple + RiffContainer::Chunk* getCursor() { return _chunk; } + void setCursor(RiffContainer::Chunk* chunk) { _chunk = chunk; } -template<typename T> -struct SerialGetFieldType -{ - static const SerialFieldType* getFieldType() - { - typedef SerialTypeInfo<T> Info; - static const SerialFieldType type = { - sizeof(typename Info::SerialType), - uint8_t(Info::SerialAlignment), - &Info::toSerial, - &Info::toNative}; - return &type; - } +private: + RiffContainer::Chunk* _chunk = nullptr; }; -// !!!!!!!!!!!!!!!!!!!!! SerialGetFieldType<T> !!!!!!!!!!!!!!!!!!!!!!!!!!! - -template<typename T> -/* static */ SerialField SerialField::make(const char* name, T* in) -{ - uint8_t* ptr = reinterpret_cast<uint8_t*>(in); - - SerialField field; - field.name = name; - field.type = SerialGetFieldType<T>::getFieldType(); - // This only works because we in is an offset from 1 - field.nativeOffset = uint32_t(size_t(ptr) - 1); - field.serialOffset = 0; - return field; -} - -// !!!!!!!!!!!!!!!!!!!!! Convenience functions !!!!!!!!!!!!!!!!!!!!!!!!!!! - -template<typename NATIVE_TYPE, typename SERIAL_TYPE> -SLANG_FORCE_INLINE void toSerialValue( - SerialWriter* writer, - const NATIVE_TYPE& src, - SERIAL_TYPE& dst) -{ - SerialTypeInfo<NATIVE_TYPE>::toSerial(writer, &src, &dst); -} - -template<typename SERIAL_TYPE, typename NATIVE_TYPE> -SLANG_FORCE_INLINE void toNativeValue( - SerialReader* reader, - const SERIAL_TYPE& src, - NATIVE_TYPE& dst) -{ - SerialTypeInfo<NATIVE_TYPE>::toNative(reader, &src, &dst); -} } // namespace Slang diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index efb5814e6..e45311abc 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -9,6 +9,41 @@ namespace Slang { +bool SyntaxClassBase::isSubClassOf(SyntaxClassBase const& other) const +{ + auto selfInfo = getInfo(); + auto otherInfo = other.getInfo(); + if (!selfInfo || !otherInfo) + return false; + return unsigned((int)selfInfo->firstTag - (int)otherInfo->firstTag) < + unsigned(otherInfo->tagCount); +} + +UnownedTerminatedStringSlice SyntaxClassBase::getName() const +{ + return _info ? UnownedTerminatedStringSlice(_info->name) : UnownedTerminatedStringSlice(); +} + +void* SyntaxClassBase::createInstanceImpl(ASTBuilder* astBuilder) const +{ + if (!_info) + return nullptr; + if (!_info->createFunc) + return nullptr; + + return _info->createFunc(astBuilder); +} + +void SyntaxClassBase::destructInstanceImpl(void* instance) const +{ + if (!_info) + return; + if (!_info->destructFunc) + return; + + return _info->destructFunc(instance); +} + /* static */ const TypeExp TypeExp::empty; @@ -227,13 +262,13 @@ void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType) sb << "discard"; break; default: - if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Expr)) + if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Expr>()) sb << "expression"; - else if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Stmt)) + else if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Stmt>()) sb << "statement"; - else if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Decl)) + else if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Decl>()) sb << "decl"; - else if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Val)) + else if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Val>()) sb << "val"; else sb << "node"; @@ -326,7 +361,7 @@ Decl* const* adjustFilterCursorImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo)) + if (decl->getClass().isSubClassOf(clsInfo)) { return ptr; } @@ -338,7 +373,7 @@ Decl* const* adjustFilterCursorImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && !decl->hasModifier<HLSLStaticModifier>()) { return ptr; @@ -351,7 +386,7 @@ Decl* const* adjustFilterCursorImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && decl->hasModifier<HLSLStaticModifier>()) { return ptr; @@ -378,7 +413,7 @@ Decl* const* getFilterCursorByIndexImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo)) + if (decl->getClass().isSubClassOf(clsInfo)) { if (index <= 0) { @@ -394,7 +429,7 @@ Decl* const* getFilterCursorByIndexImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && !decl->hasModifier<HLSLStaticModifier>()) { if (index <= 0) @@ -411,7 +446,7 @@ Decl* const* getFilterCursorByIndexImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && decl->hasModifier<HLSLStaticModifier>()) { if (index <= 0) @@ -428,7 +463,7 @@ Decl* const* getFilterCursorByIndexImpl( } Index getFilterCountImpl( - const ReflectClassInfo& clsInfo, + const SyntaxClassBase& clsInfo, MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) @@ -442,7 +477,7 @@ Index getFilterCountImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - count += Index(decl->getClassInfo().isSubClassOf(clsInfo)); + count += Index(decl->getClass().isSubClassOf(clsInfo)); } break; } @@ -452,7 +487,7 @@ Index getFilterCountImpl( { Decl* decl = *ptr; count += Index( - decl->getClassInfo().isSubClassOf(clsInfo) && + decl->getClass().isSubClassOf(clsInfo) && !decl->hasModifier<HLSLStaticModifier>()); } break; @@ -463,7 +498,7 @@ Index getFilterCountImpl( { Decl* decl = *ptr; count += Index( - decl->getClassInfo().isSubClassOf(clsInfo) && + decl->getClass().isSubClassOf(clsInfo) && decl->hasModifier<HLSLStaticModifier>()); } break; @@ -701,7 +736,7 @@ Type* DeclRefType::create(ASTBuilder* astBuilder, DeclRef<Decl> declRef) } else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>()) { - if (magicMod->magicNodeType == ASTNodeType(-1)) + if (!magicMod->magicNodeType) { SLANG_UNEXPECTED("unhandled type"); } diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index accc490f2..8d78872a6 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -117,7 +117,7 @@ inline void foreachDirectOrExtensionMemberOfType( _foreachDirectOrExtensionMemberOfType( semantics, declRef, - getClass<T>(), + getSyntaxClass<T>(), &Helper::callback, &helper); } diff --git a/source/slang/slang-value-reflect.cpp b/source/slang/slang-value-reflect.cpp deleted file mode 100644 index aa2b6fbe2..000000000 --- a/source/slang/slang-value-reflect.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "slang-value-reflect.h" - -#include "slang-generated-value-macro.h" -#include "slang-generated-value.h" -#include "slang.h" - -namespace Slang -{ - - -} // namespace Slang diff --git a/source/slang/slang-value-reflect.h b/source/slang/slang-value-reflect.h deleted file mode 100644 index 1110ad225..000000000 --- a/source/slang/slang-value-reflect.h +++ /dev/null @@ -1,12 +0,0 @@ -// slang-value-reflect.h - -#ifndef SLANG_VALUE_REFLECT_H -#define SLANG_VALUE_REFLECT_H - -#include "slang-generated-value-macro.h" -#include "slang-generated-value.h" - -// Create the functions to automatically convert between value types - - -#endif // SLANG_VALUE_REFLECT_H diff --git a/source/slang/slang-visitor.h b/source/slang/slang-visitor.h index 580029289..180956bde 100644 --- a/source/slang/slang-visitor.h +++ b/source/slang/slang-visitor.h @@ -5,235 +5,233 @@ // This file defines the basic "Visitor" pattern for doing dispatch // over the various categories of syntax node. -#include "slang-generated-ast-macro.h" +#include "slang-ast-dispatch.h" +#include "slang-ast-forward-declarations.h" #include "slang-syntax.h" namespace Slang { -// Macros to generate from ast-generated-macro file the vistors - -// Only runs 'param' macro if the marker is NONE (ie not ABSTRACT here) -#define SLANG_CLASS_ONLY_ABSTRACT_AST(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) -#define SLANG_CLASS_ONLY_AST(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - param(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) - -#define SLANG_CLASS_ONLY(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - SLANG_CLASS_ONLY_##MARKER(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) - -// Dispatch decl -#define SLANG_VISITOR_DISPATCH_DECL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - // Dispatch -#define SLANG_VISITOR_DISPATCH_RESULT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { \ - *(Result*)extra = ((Derived*)this)->visit##NAME(obj); \ - } - -#define SLANG_VISITOR_DISPATCH_VOID_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { \ - ((Derived*)this)->visit##NAME(obj); \ +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_DISPATCH_RESULT_IMPL(baseType) +% for _,T in ipairs(baseType.subclasses) do +% if not T.isAbstract then + Result _dispatchImpl($T* obj) + { + return ((Derived*)this)->visit$T(obj); } +% end +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END // Visitor with and without result -#define SLANG_VISITOR_RESULT_VISIT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - Result visit##NAME(NAME* obj) \ - { \ - return ((Derived*)this)->visit##SUPER(obj); \ - } - -#define SLANG_VISITOR_VOID_VISIT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - void visit##NAME(NAME* obj) \ - { \ - ((Derived*)this)->visit##SUPER(obj); \ +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_VISIT_RESULT_IMPL(baseType) +% for _,T in ipairs(baseType.subclasses) do + Result visit$T($T* obj) + { + return ((Derived*)this)->visit$(T.directSuperClass)(obj); } +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END // Args -#define SLANG_VISITOR_DISPATCH_ARG_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void* arg) override \ - { \ - ((Derived*)this)->visit##NAME(obj, *(Arg*)arg); \ - } +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_DISPATCH_ARG_IMPL(baseType) +% for _, T in ipairs(baseType.subclasses) do +% if not T.isAbstract then +virtual void _dispatchImpl($T* obj, Arg const& arg) +{ + ((Derived*)this)->visit$T(obj, arg); +} +% end +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 2 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END + +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_VISIT_ARG_IMPL(baseType) +% for _, T in ipairs(baseType.subclasses) do +void visit$T($T* obj, Arg const& arg) +{ + ((Derived*)this)->visit$(T.directSuperClass)(obj, arg); +} +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 3 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END -#define SLANG_VISITOR_VOID_VISIT_ARG_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - void visit##NAME(NAME* obj, Arg const& arg) \ - { \ - ((Derived*)this)->visit##SUPER(obj, arg); \ - } // // type Visitors // -struct ITypeVisitor -{ - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - // Suppress VS2017 Unreachable code warning #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 4702) #endif -template<typename Derived, typename Result = void, typename Base = ITypeVisitor> -struct TypeVisitor : Base +template<typename Derived, typename Result = void> +struct TypeVisitor { Result dispatch(Type* type) { - Result result; - type->accept(this, &result); - return result; + return ASTNodeDispatcher<Type, Result>::dispatch( + type, + [&](auto obj) { return _dispatchImpl(obj); }); } Result dispatchType(Type* type) { - Result result; - type->accept(this, &result); - return result; + return ASTNodeDispatcher<Type, Result>::dispatch( + type, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Type(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived, typename Base> -struct TypeVisitor<Derived, void, Base> : Base -{ - void dispatch(Type* type) { type->accept(this, 0); } - - void dispatchType(Type* type) { type->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Type(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Type) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Type) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 4 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; -template<typename Derived, typename Arg, typename Base = ITypeVisitor> -struct TypeVisitorWithArg : Base +template<typename Derived, typename Arg> +struct TypeVisitorWithArg { - void dispatch(Type* type, Arg const& arg) { type->accept(this, (void*)&arg); } + void dispatch(Type* type, Arg const& arg) + { + ASTNodeDispatcher<Type, void>::dispatch(type, [&](auto obj) { _dispatchImpl(obj, arg); }); + } - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_ARG_IMPL) - SLANG_CHILDREN_ASTNode_Type(SLANG_VISITOR_VOID_VISIT_ARG_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_ARG_IMPL(Slang.Type) + % SLANG_VISITOR_VISIT_ARG_IMPL(Slang.Type) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 5 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Expression Visitors // -struct IExprVisitor -{ - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct ExprVisitor : IExprVisitor +struct ExprVisitor { Result dispatch(Expr* expr) { - Result result; - expr->accept(this, &result); - return result; + return ASTNodeDispatcher<Expr, Result>::dispatch( + expr, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Expr(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct ExprVisitor<Derived, void> : IExprVisitor -{ - void dispatch(Expr* expr) { expr->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Expr(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Expr) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Expr) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 6 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; template<typename Derived, typename Arg> -struct ExprVisitorWithArg : IExprVisitor +struct ExprVisitorWithArg { - void dispatch(Expr* obj, Arg const& arg) { obj->accept(this, (void*)&arg); } + void dispatch(Expr* expr, Arg const& arg) + { + ASTNodeDispatcher<Expr, void>::dispatch(expr, [&](auto obj) { _dispatchImpl(obj, arg); }); + } - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_ARG_IMPL) - SLANG_CHILDREN_ASTNode_Expr(SLANG_VISITOR_VOID_VISIT_ARG_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_ARG_IMPL(Slang.Expr) + % SLANG_VISITOR_VISIT_ARG_IMPL(Slang.Expr) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 7 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Statement Visitors // -struct IStmtVisitor -{ - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct StmtVisitor : IStmtVisitor +struct StmtVisitor { Result dispatch(Stmt* stmt) { - Result result; - stmt->accept(this, &result); - return result; + return ASTNodeDispatcher<Stmt, Result>::dispatch( + stmt, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Stmt(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct StmtVisitor<Derived, void> : IStmtVisitor -{ - void dispatch(Stmt* stmt) { stmt->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Stmt(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Stmt) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Stmt) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 8 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Declaration Visitors // -struct IDeclVisitor -{ - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct DeclVisitor : IDeclVisitor +struct DeclVisitor { Result dispatch(DeclBase* decl) { - Result result; - decl->accept(this, &result); - return result; + return ASTNodeDispatcher<DeclBase, Result>::dispatch( + decl, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct DeclVisitor<Derived, void> : IDeclVisitor -{ - void dispatch(DeclBase* decl) { decl->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.DeclBase) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.DeclBase) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 9 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; template<typename Derived, typename Arg> -struct DeclVisitorWithArg : IDeclVisitor +struct DeclVisitorWithArg { - void dispatch(DeclBase* obj, Arg const& arg) { obj->accept(this, (void*)&arg); } + void dispatch(DeclBase* decl, Arg const& arg) + { + ASTNodeDispatcher<Expr, void>::dispatch(decl, [&](auto obj) { _dispatchImpl(obj, arg); }); + } - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_ARG_IMPL) - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_VISITOR_VOID_VISIT_ARG_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_ARG_IMPL(Slang.DeclBase) + % SLANG_VISITOR_VISIT_ARG_IMPL(Slang.DeclBase) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 10 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; @@ -241,64 +239,46 @@ struct DeclVisitorWithArg : IDeclVisitor // Modifier Visitors // -struct IModifierVisitor -{ - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct ModifierVisitor : IModifierVisitor +struct ModifierVisitor { Result dispatch(Modifier* modifier) { - Result result; - modifier->accept(this, &result); - return result; + return ASTNodeDispatcher<Modifier, Result>::dispatch( + modifier, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Modifier(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct ModifierVisitor<Derived, void> : IModifierVisitor -{ - void dispatch(Modifier* modifier) { modifier->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Modifier(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Modifier) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Modifier) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 11 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Val Visitors // -struct IValVisitor : ITypeVisitor -{ - SLANG_CHILDREN_ASTNode_Val(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void, typename TypeResult = void> -struct ValVisitor : TypeVisitor<Derived, TypeResult, IValVisitor> +struct ValVisitor : TypeVisitor<Derived, TypeResult> { Result dispatch(Val* val) { - Result result; - val->accept(this, &result); - return result; + return ASTNodeDispatcher<Val, Result>::dispatch( + val, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Val(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Val(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct ValVisitor<Derived, void, void> : TypeVisitor<Derived, void, IValVisitor> -{ - void dispatch(Val* val) { val->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Val(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Val(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Val) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Val) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 12 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // Re-activate VS2017 warning settings diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 351ab6f06..99457647d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -570,64 +570,110 @@ SlangResult Session::saveCoreModule(SlangArchiveType archiveType, ISlangBlob** o } SlangResult Session::saveBuiltinModule( - slang::BuiltinModuleName builtinModuleName, + slang::BuiltinModuleName moduleTag, SlangArchiveType archiveType, ISlangBlob** outBlob) { + // If no builtin modules have been loaded, then there is + // nothing to save, and we fail immediately. + // if (m_builtinLinkage->mapNameToLoadedModules.getCount() == 0) { - // There is no standard lib loaded return SLANG_FAIL; } - BuiltinModuleInfo builtinModuleInfo = getBuiltinModuleInfo(builtinModuleName); - - // Make a file system to read it from - ComPtr<ISlangMutableFileSystem> fileSystem; - SLANG_RETURN_ON_FAIL(createArchiveFileSystem(archiveType, fileSystem)); - - // Must have archiveFileSystem interface - auto archiveFileSystem = as<IArchiveFileSystem>(fileSystem); - if (!archiveFileSystem) - { - return SLANG_FAIL; - } + // The module will need to be looked up by its name, and + // will also be serialized out to a path with a matching name. + // + BuiltinModuleInfo moduleInfo = getBuiltinModuleInfo(moduleTag); + const char* moduleName = moduleInfo.name; + // If we cannot find a loaded module in the linkage with + // the appropriate name, then for some reason it hasn't + // been loaded, and we fail. + // RefPtr<Module> module; m_builtinLinkage->mapNameToLoadedModules.tryGetValue( - getNameObj(UnownedStringSlice(builtinModuleInfo.name)), + getNameObj(UnownedStringSlice(moduleName)), module); if (!module) { return SLANG_FAIL; } + // AST serialization needs access to an AST builder, so + // we establish a current builder for the duration of + // the serialization process. + // SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); - // Set up options - SerialContainerUtil::WriteOptions options; + // The serialized module will be represented as a logical + // file in an archive, so we create a logical file system + // to represent that archive. + // + ComPtr<ISlangMutableFileSystem> fileSystem; + SLANG_RETURN_ON_FAIL(createArchiveFileSystem(archiveType, fileSystem)); + // + // The created file system must support the `IArchiveFileSystem` + // interface (since we created it with `createArchiveFileSystem`). + // + auto archiveFileSystem = as<IArchiveFileSystem>(fileSystem); + if (!archiveFileSystem) + { + return SLANG_FAIL; + } - // Save with SourceLocation information - options.optionFlags |= SerialOptionFlag::SourceLocation; + // The output file name that we'll write to in that file system + // is just the builtin module name with a `.slang-module` suffix. + // + StringBuilder moduleFileName; + moduleFileName << moduleName << ".slang-module"; - // TODO(JS): Should this be the Session::getBuiltinSourceManager()? + // The module serialization step has some options that we need + // to configure appropriately. + // + SerialContainerUtil::WriteOptions options; + // + // We want builtin modules to be saved with their source location + // information. + // + options.optionFlags |= SerialOptionFlag::SourceLocation; + // + // And in order to work with source locations, the serialization + // process will also need access to the source manager that + // can translate locations into their humane format. + // options.sourceManager = m_builtinLinkage->getSourceManager(); - StringBuilder builder; - builder << builtinModuleInfo.name << ".slang-module"; - + // At this point we can finally delegate down to the next level, + // which handles the serialization of a Slang module into a + // byte stream. + // OwnedMemoryStream stream(FileAccess::Write); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(module, options, &stream)); - auto contents = stream.getContents(); - // Write into the file system - SLANG_RETURN_ON_FAIL( - fileSystem->saveFile(builder.getBuffer(), contents.getBuffer(), contents.getCount())); + // Once the stream that represents the module has been written, we can + // write it to a file in the logical file system. + // + // TODO(tfoley): why can't the file system let us open the file for output? + // + SLANG_RETURN_ON_FAIL(fileSystem->saveFile( + moduleFileName.getBuffer(), + contents.getBuffer(), + contents.getCount())); + + // And finally, we can ask the archive file system to serialize itself + // out as a blob of bytes, which yields the final serialized representation + // of the module. + // + SLANG_RETURN_ON_FAIL(archiveFileSystem->storeArchive( + // The `true` here indicates that the blob that gets created should own + // its content, independent from the file system object itself; otherwise + // the file system might return a blob that shares storage with itself. + true, + outBlob)); - // Now need to convert into a blob - SLANG_RETURN_ON_FAIL(archiveFileSystem->storeArchive(true, outBlob)); return SLANG_OK; } @@ -654,74 +700,98 @@ SlangResult Session::_readBuiltinModule( SLANG_RETURN_ON_FAIL(RiffUtil::read(&stream, riffContainer)); } - // Load up the module + Linkage* linkage = getBuiltinLinkage(); + SourceManager* sourceManager = getBuiltinSourceManager(); + NamePool* sessionNamePool = &namePool; - SerialContainerData containerData; + auto moduleChunk = ModuleChunkRef::find(&riffContainer); + if (!moduleChunk) + return SLANG_FAIL; - Linkage* linkage = getBuiltinLinkage(); + SHA1::Digest moduleDigest = moduleChunk.getDigest(); - SourceManager* sourceManger = getBuiltinSourceManager(); + auto irChunk = moduleChunk.findIR(); + if (!irChunk) + return SLANG_FAIL; - NamePool* sessionNamePool = &namePool; - NamePool* linkageNamePool = linkage->getNamePool(); + auto astChunk = moduleChunk.findAST(); + if (!astChunk) + return SLANG_FAIL; - SerialContainerUtil::ReadOptions options; - options.namePool = linkageNamePool; - options.session = this; - options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); - options.astBuilder = linkage->getASTBuilder(); - options.sourceManager = sourceManger; - options.linkage = linkage; + // Source location information is stored as a distinct + // chunk from the IR and AST, so we need to search for + // that chunk and then set up the information for use + // in the IR and AST deserialization (if we find anything). + // + RefPtr<SerialSourceLocReader> sourceLocReader; + if (auto debugChunk = findDebugChunk(moduleChunk.ptr())) + { + SLANG_RETURN_ON_FAIL( + readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); + } - // Hmm - don't have a suitable sink yet, so attempt to just not have one - options.sink = nullptr; + // At this point we create the `Module` object that will + // represent the builtin module we are reading, although + // it is still possible that deserialization will fail + // at one of the following steps. + // + auto astBuilder = linkage->getASTBuilder(); + RefPtr<Module> module(new Module(linkage, astBuilder)); + module->setName(moduleName); + module->setDigest(moduleDigest); - SLANG_RETURN_ON_FAIL( - SerialContainerUtil::read(&riffContainer, options, nullptr, containerData)); - for (auto& srcModule : containerData.modules) + // Next, we set about deserializing the AST representation + // of the module. + // + auto moduleDecl = readSerializedModuleAST( + linkage, + astBuilder, + nullptr, // no sink + astChunk, + sourceLocReader, + SourceLoc()); + if (!moduleDecl) { - RefPtr<Module> module(new Module(linkage, srcModule.astBuilder)); - module->setName(moduleName); - module->setDigest(srcModule.digest); - - ModuleDecl* moduleDecl = as<ModuleDecl>(srcModule.astRootNode); - // Set the module back reference on the decl - moduleDecl->module = module; + return SLANG_FAIL; + } + moduleDecl->module = module; + module->setModuleDecl(moduleDecl); - if (moduleDecl) - { - if (isFromCoreModule(moduleDecl)) - { - registerBuiltinDecls(this, moduleDecl); - } + if (isFromCoreModule(moduleDecl)) + { + registerBuiltinDecls(this, moduleDecl); + } - module->setModuleDecl(moduleDecl); - } + // After the AST module has been read in, we next look + // to deserialize the IR module. + // + RefPtr<IRModule> irModule; + SLANG_RETURN_ON_FAIL(decodeModuleIR(irModule, irChunk, this, sourceLocReader)); - srcModule.irModule->setName(module->getNameObj()); - module->setIRModule(srcModule.irModule); + irModule->setName(module->getNameObj()); + module->setIRModule(irModule); - // Put in the loaded module map - linkage->mapNameToLoadedModules.add(sessionNamePool->getName(moduleName), module); + // Put in the loaded module map + linkage->mapNameToLoadedModules.add(sessionNamePool->getName(moduleName), module); - // Add the resulting code to the appropriate scope - if (!scope->containerDecl) - { - // We are the first chunk of code to be loaded for this scope - scope->containerDecl = moduleDecl; - } - else - { - // We need to create a new scope to link into the whole thing - auto subScope = linkage->getASTBuilder()->create<Scope>(); - subScope->containerDecl = moduleDecl; - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - } - outModule = module.get(); + // Add the resulting code to the appropriate scope + if (!scope->containerDecl) + { + // We are the first chunk of code to be loaded for this scope + scope->containerDecl = moduleDecl; } + else + { + // We need to create a new scope to link into the whole thing + auto subScope = linkage->getASTBuilder()->create<Scope>(); + subScope->containerDecl = moduleDecl; + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; + } + + outModule = module.get(); return SLANG_OK; } @@ -1526,9 +1596,10 @@ slang::IModule* Linkage::loadModuleFromBlob( pathInfo = PathInfo::makeNormal(pathStr, cannonicalPath); } } - auto module = loadModule(name, pathInfo, source, SourceLoc(), &sink, nullptr, blobType); + RefPtr<Module> module = + loadModuleImpl(name, pathInfo, source, SourceLoc(), &sink, nullptr, blobType); sink.getBlobIfNeeded(outDiagnostics); - return asExternal(module); + return asExternal(module.detach()); } catch (const AbortCompilationException& e) { @@ -4057,101 +4128,157 @@ void Linkage::loadParsedModule( loadedModulesList.add(loadedModule); } -RefPtr<Module> Linkage::loadDeserializedModule( - Name* name, - const PathInfo& filePathInfo, - SerialContainerData::Module& moduleEntry, +RefPtr<Module> Linkage::findOrLoadSerializedModuleForModuleLibrary( + ModuleChunkRef moduleChunk, DiagnosticSink* sink) { - SLANG_AST_BUILDER_RAII(m_astBuilder); RefPtr<Module> resultModule; - if (mapNameToLoadedModules.tryGetValue(name, resultModule)) - return resultModule; - if (mapPathToLoadedModule.tryGetValue(filePathInfo.getMostUniqueIdentity(), resultModule)) + + // We will attempt things in a few different steps, trying to + // decode as little of the serialized module as necessary at + // each step, so that we don't waste time on the heavyweight + // stuff when we didn't need to. + // + // The first step is to simply decode the module name, and + // see if we have a already loaded a matching module. + + auto moduleName = getNamePool()->getName(moduleChunk.getName()); + if (mapNameToLoadedModules.tryGetValue(moduleName, resultModule)) return resultModule; - resultModule = new Module(this, m_astBuilder); - prepareDeserializedModule(moduleEntry, filePathInfo, resultModule, sink); + // It is possible that the module has been loaded, but somehow + // under a different name, so next we decode the list of file + // paths that the module depends on, and then rely on the assumption + // that the first of those paths represents the file for the module + // itself to detect if we've already loaded a module from that + // path. + // + // Note: While this is a distasteful assumption to make, it is + // one that gets made in several parts of the compiler codebase + // already. It isn't something that can be fixed in just one + // place at this point. + + auto fileDependenciesChunk = moduleChunk.getFileDependencies(); + auto firstFileDependencyChunk = fileDependenciesChunk.getFirst(); + if (!firstFileDependencyChunk) + return nullptr; + + auto modulePathInfo = PathInfo::makePath(firstFileDependencyChunk.getValue()); + if (mapPathToLoadedModule.tryGetValue(modulePathInfo.getMostUniqueIdentity(), resultModule)) + return resultModule; - loadedModulesList.add(resultModule); - mapPathToLoadedModule.add(filePathInfo.getMostUniqueIdentity(), resultModule); - mapNameToLoadedModules.add(name, resultModule); - return resultModule; + // If we failed to find a previously-loaded module, then we + // will go ahead and load the module from the serialized form. + // + PathInfo filePathInfo; + return loadSerializedModule(moduleName, modulePathInfo, moduleChunk, SourceLoc(), sink); } -RefPtr<Module> Linkage::loadModuleFromIRBlobImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules) +RefPtr<Module> Linkage::loadSerializedModule( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, + SourceLoc const& requestingLoc, + DiagnosticSink* sink) { - SLANG_AST_BUILDER_RAII(m_astBuilder); + auto astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); - RefPtr<Module> resultModule = new Module(this, getASTBuilder()); - resultModule->setName(name); - ModuleBeingImportedRAII moduleBeingImported(this, resultModule, name, loc); + auto module = RefPtr(new Module(this, astBuilder)); + module->setName(moduleName); - String mostUniqueIdentity = filePathInfo.getMostUniqueIdentity(); - SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + // Just as if we were processing an `import` declaration in + // source code, we will track the fact that this serialized + // modlue is (effectively) being imported, so that we can + // diagnose anything troublesome, like an attempt at a + // recursive import. + // + ModuleBeingImportedRAII moduleBeingImported(this, module, moduleName, requestingLoc); - RiffContainer container; - MemoryStreamBase readStream( - FileAccess::Read, - fileContentsBlob->getBufferPointer(), - fileContentsBlob->getBufferSize()); - SLANG_RETURN_NULL_ON_FAIL(RiffUtil::read(&readStream, container)); + // We will register the module in our data structures to + // track loaded modules, and then remove it in the case + // where there is some kind of failure. + // + String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); - if (m_optionSet.getBoolOption(CompilerOptionName::UseUpToDateBinaryModule)) + mapPathToLoadedModule.add(mostUniqueIdentity, module); + mapNameToLoadedModules.add(moduleName, module); + try { - if (!isBinaryModuleUpToDate(filePathInfo.foundPath, &container)) + if (SLANG_FAILED( + loadSerializedModuleContents(module, moduleFilePathInfo, moduleChunk, sink))) + { + mapPathToLoadedModule.remove(mostUniqueIdentity); + mapNameToLoadedModules.remove(moduleName); return nullptr; - } + } - mapPathToLoadedModule.add(mostUniqueIdentity, resultModule); - mapNameToLoadedModules.add(name, resultModule); - - SerialContainerUtil::ReadOptions readOptions; - readOptions.linkage = this; - readOptions.astBuilder = getASTBuilder(); - readOptions.session = getSessionImpl(); - readOptions.sharedASTBuilder = getASTBuilder()->getSharedASTBuilder(); - readOptions.sink = sink; - readOptions.sourceManager = getSourceManager(); - readOptions.namePool = getNamePool(); - readOptions.modulePath = filePathInfo.foundPath; - SerialContainerData containerData; - if (SLANG_FAILED(SerialContainerUtil::read( - &container, - readOptions, - additionalLoadedModules, - containerData)) || - containerData.modules.getCount() != 1) + loadedModulesList.add(module); + return module; + } + catch (...) { mapPathToLoadedModule.remove(mostUniqueIdentity); - mapNameToLoadedModules.remove(name); - return nullptr; + mapNameToLoadedModules.remove(moduleName); + throw; } - auto moduleEntry = containerData.modules.getFirst(); +} - prepareDeserializedModule(moduleEntry, filePathInfo, resultModule, sink); +RefPtr<Module> Linkage::loadBinaryModuleImpl( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ISlangBlob* moduleFileContents, + SourceLoc const& requestingLoc, + DiagnosticSink* sink) +{ + auto astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); - loadedModulesList.add(resultModule); - resultModule->setPathInfo(filePathInfo); - resultModule->getIRModule()->setName(resultModule->getNameObj()); + // We start by reading the content of the file into + // an in-memory RIFF container. + // + // TODO(tfoley): this is an unnecessary copy step, since + // we can simply use the contents of the blob directly + // and navigate it in-memory. + // + RiffContainer riffContainer; + { + MemoryStreamBase readStream( + FileAccess::Read, + moduleFileContents->getBufferPointer(), + moduleFileContents->getBufferSize()); + SLANG_RETURN_NULL_ON_FAIL(RiffUtil::read(&readStream, riffContainer)); + } - return resultModule; -} + auto moduleChunkRef = ModuleChunkRef::find(&riffContainer); + if (!moduleChunkRef) + { + return nullptr; + } -Module* Linkage::loadModule(String const& name) -{ - // TODO: We either need to have a diagnostics sink - // get passed into this operation, or associate - // one with the linkage. + // Next, we attempt to check if the binary module is up to + // date with the compilation options in use as well as + // the contents of all the files its compilation depended + // on (as determined by its hash). // - DiagnosticSink* sink = nullptr; - return findOrImportModule(getNamePool()->getName(name), SourceLoc(), sink); + String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + if (m_optionSet.getBoolOption(CompilerOptionName::UseUpToDateBinaryModule)) + { + if (!isBinaryModuleUpToDate(moduleFilePathInfo.foundPath, moduleChunkRef)) + { + return nullptr; + } + } + + // If everything seems reasonable, then we will go ahead and load + // the module more completely from that serialized representation. + // + RefPtr<Module> module = + loadSerializedModule(moduleName, moduleFilePathInfo, moduleChunkRef, requestingLoc, sink); + + return module; } void Linkage::_diagnoseErrorInImportedModule(DiagnosticSink* sink) @@ -4166,24 +4293,43 @@ void Linkage::_diagnoseErrorInImportedModule(DiagnosticSink* sink) } } -RefPtr<Module> Linkage::loadModule( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* sourceBlob, - SourceLoc const& srcLoc, +RefPtr<Module> Linkage::loadModuleImpl( + Name* moduleName, + const PathInfo& modulePathInfo, + ISlangBlob* moduleBlob, + SourceLoc const& requestingLoc, DiagnosticSink* sink, const LoadedModuleDictionary* additionalLoadedModules, ModuleBlobType blobType) { - if (blobType == ModuleBlobType::IR) - return loadModuleFromIRBlobImpl( - name, - filePathInfo, - sourceBlob, - srcLoc, + switch (blobType) + { + case ModuleBlobType::IR: + return loadBinaryModuleImpl(moduleName, modulePathInfo, moduleBlob, requestingLoc, sink); + + case ModuleBlobType::Source: + return loadSourceModuleImpl( + moduleName, + modulePathInfo, + moduleBlob, + requestingLoc, sink, additionalLoadedModules); + default: + SLANG_UNEXPECTED("unknown module blob type"); + UNREACHABLE_RETURN(nullptr); + } +} + +RefPtr<Module> Linkage::loadSourceModuleImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* sourceBlob, + SourceLoc const& srcLoc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules) +{ RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, nullptr, sink); frontEndReq->additionalLoadedModules = additionalLoadedModules; @@ -4275,8 +4421,10 @@ RefPtr<Module> Linkage::loadModule( return nullptr; } - if (module) - module->setPathInfo(filePathInfo); + if (!module) + return nullptr; + + module->setPathInfo(filePathInfo); return module; } @@ -4319,126 +4467,263 @@ String getFileNameFromModuleName(Name* name, bool translateUnderScore) } RefPtr<Module> Linkage::findOrImportModule( - Name* name, - SourceLoc const& loc, + Name* moduleName, + SourceLoc const& requestingLoc, DiagnosticSink* sink, const LoadedModuleDictionary* loadedModules) { // Have we already loaded a module matching this name? // - RefPtr<LoadedModule> loadedModule; - if (mapNameToLoadedModules.tryGetValue(name, loadedModule)) + RefPtr<LoadedModule> previouslyLoadedModule; + if (mapNameToLoadedModules.tryGetValue(moduleName, previouslyLoadedModule)) { // If the map shows a null module having been loaded, // then that means there was a prior load attempt, // but it failed, so we won't bother trying again. // - if (!loadedModule) + if (!previouslyLoadedModule) return nullptr; // If state shows us that the module is already being // imported deeper on the call stack, then we've // hit a recursive case, and that is an error. // - if (isBeingImported(loadedModule)) + if (isBeingImported(previouslyLoadedModule)) { // We seem to be in the middle of loading this module - sink->diagnose(loc, Diagnostics::recursiveModuleImport, name); + sink->diagnose(requestingLoc, Diagnostics::recursiveModuleImport, moduleName); return nullptr; } - return loadedModule; + return previouslyLoadedModule; } // If the user is providing an additional list of loaded modules, we find // if the module being imported is in that list. This allows a translation // unit to use previously checked translation units in the same // FrontEndCompileRequest. - Module* previouslyLoadedModule = nullptr; - if (loadedModules && loadedModules->tryGetValue(name, previouslyLoadedModule)) { - return previouslyLoadedModule; + Module* previouslyLoadedLocalModule = nullptr; + if (loadedModules && loadedModules->tryGetValue(moduleName, previouslyLoadedLocalModule)) + { + return previouslyLoadedLocalModule; + } } - if (name == getSessionImpl()->glslModuleName) + // If the name being requested matches the name of a built-in module, + // then we will special-case the process by loading that builtin + // module directly. + // + // TODO: right now this logic is only considering the built-in `glsl` + // module, but it should probably be generalized so that we can more + // easily support having multiple built-in modules rather than just + // putting everything into `core`. + // + if (moduleName == getSessionImpl()->glslModuleName) { // This is a builtin glsl module, just load it from embedded definition. auto glslModule = getSessionImpl()->getBuiltinModule(slang::BuiltinModuleName::GLSL); if (!glslModule) { - sink->diagnose(loc, Diagnostics::glslModuleNotAvailable, name); + // Note: the way this logic is currently written, if the built-in + // `glsl` module fails to load, then we will *not* fall back to + // searching for a user-defined module in a file like `glsl.slang`. + // + // It is unclear if this should be the default behavior or not. + // Should built-in modules be prioritized over user modules? + // Should built-in modules shadow user modules, even when the + // built-in module fails to load, for some reason? + // + sink->diagnose(requestingLoc, Diagnostics::glslModuleNotAvailable, moduleName); } return glslModule; } - // Next, try to find the file of the given name, - // using our ordinary include-handling logic. + // We are going to use a loop to search for a suitable file to + // load the module from, to account for a few key choices: + // + // * We can both load modules from a source `.slang` file, + // or from a binary `.slang-module` file. + // + // * For a variety of reasons, the `import` logic has historically + // translated underscores in a module name into dashes (so that + // `import my_module` will look for `my-module.slang`), and we + // try to support both that convention as well as a convention + // that preserves underscores. + // + // To try to keep this logic as orthogonal as possible, we first + // construct lists of the options we want to iterate over, and + // then do the actual loop later. - IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); + ShortList<ModuleBlobType, 2> typesToTry; + if (isInLanguageServer()) + { + // When in language server, we always prefer to use source module if it is available. + typesToTry.add(ModuleBlobType::Source); + typesToTry.add(ModuleBlobType::IR); + } + else + { + // Look for a precompiled module first, if not exist, load from source. + typesToTry.add(ModuleBlobType::IR); + typesToTry.add(ModuleBlobType::Source); + } - // Get the original path info - PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); - PathInfo filePathInfo; + // We will always search for a file name that directly matches the + // module name as written first, and then search for one with + // underscores replaced by dashes. The latter is the original + // behavior that `import` provided, but it seems safest to prefer + // the exact name spelled in the user's code when there might + // actually be ambiguity. + // + auto defaultSourceFileName = getFileNameFromModuleName(moduleName, false); + auto alternativeSourceFileName = getFileNameFromModuleName(moduleName, true); + String sourceFileNamesToTry[] = {defaultSourceFileName, alternativeSourceFileName}; + // We are going to look for the candidate file using the same + // logic that would be used for a preprocessor `#include`, + // so we set up the necessary state. + // + IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); - // Look for a precompiled module first, if not exist, load from source. - bool shouldCheckBinaryModuleSettings[2] = {true, false}; + // Just like with a `#include`, the search will take into + // account the path to the file where the request to import + // this module came from (e.g. the source file with the + // `import` declaration), if such a path is available. + // + PathInfo requestingPathInfo = + getSourceManager()->getPathInfo(requestingLoc, SourceLocType::Actual); - for (auto checkBinaryModule : shouldCheckBinaryModuleSettings) + for (auto type : typesToTry) { - // When in language server, we always prefer to use source module if it is available. - if (isInLanguageServer()) - checkBinaryModule = !checkBinaryModule; - - // Try without translating `_` to `-` first, if that fails, try translating. - for (int translateUnderScore = 0; translateUnderScore <= 1; translateUnderScore++) + for (auto sourceFileName : sourceFileNamesToTry) { - auto moduleSourceFileName = getFileNameFromModuleName(name, translateUnderScore == 1); + // The `sourceFileName` will have the `.slang` extension, + // so if we are looking for a binary module, we need + // to change the extension we will look for. + // String fileName; - if (checkBinaryModule == 1) - fileName = Path::replaceExt(moduleSourceFileName, "slang-module"); - else - fileName = moduleSourceFileName; + switch (type) + { + case ModuleBlobType::Source: + fileName = sourceFileName; + break; - ComPtr<ISlangBlob> fileContents; + case ModuleBlobType::IR: + fileName = Path::replaceExt(sourceFileName, "slang-module"); + break; + } - // We have to load via the found path - as that is how file was originally loaded + // We now search for a file matching the desired name, + // using the same logic as for a `#include`. + // + // TODO: We might want to consider how to handle the case + // of an `import` with a relative path a little specially, + // since it could in theory be possible for two `.slang` + // files with the same base name to exist in different + // directories in a project, and we'd want file-relative + // `import`s to work for each, without having either one + // be able to "claim" the bare identifier of the base + // name for itself. + // + PathInfo filePathInfo; if (SLANG_FAILED( - includeSystem.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) + includeSystem.findFile(fileName, requestingPathInfo.foundPath, filePathInfo))) { + // If we failed to find the file at this step, we + // will continue the search for our other options. + // continue; } - // Maybe this was loaded previously at a different relative name? + // We will *again* search for a previously loaded module. + // + // It is possible that the same file will have been loaded + // as a module under two different module names. The easiest + // way for this to happen is if there are `import` declarations + // using both the underscore and dash conventions (e.g., both + // `import "my-module.slang"` and `import my_module`). + // + // This case may also arise if one file `import`s a module using + // just an identifier for its name, but another `import`s it + // using a path (e.g., `import "subdir/file.slang"`). + // + // No matter how the situation arises, we only want to have one + // copy of the "same" module loaded at a given time, so we + // will re-use the existing module if we find one here. + // if (mapPathToLoadedModule.tryGetValue( filePathInfo.getMostUniqueIdentity(), - loadedModule)) - return loadedModule; + previouslyLoadedModule)) + { + // TODO: If we find a previously-loaded module at this step, + // then we should probably register that module under the + // given `moduleName` in the map of loaded modules, so + // that subsequent `import`s using the same form will find it. + // + return previouslyLoadedModule; + } - // Try to load it - if (!fileContents && SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) + // Now we try to load the content of the file. + // + // If for some reason we could find a file at the + // given path, but for some reason couldn't *open* + // and *read* it, then we continue the search + // using whatever other candidate file names are left. + // + ComPtr<ISlangBlob> fileContents; + if (SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) { continue; } - // We've found a file that we can load for the given module, so - // go ahead and perform the module-load action - auto resultModule = loadModule( - name, + // If we found a real file and were able to load its contents, + // then we'll go ahead and try to load a module from it, + // whether by compiling it or decoding the binary. + // + auto module = loadModuleImpl( + moduleName, filePathInfo, fileContents, - loc, + requestingLoc, sink, loadedModules, - (checkBinaryModule == 1 ? ModuleBlobType::IR : ModuleBlobType::Source)); - if (resultModule) - return resultModule; + type); + + // If the attempt to load the module from the given path + // was successful, we go ahead and use it, without trying + // out any other options. + // + if (module) + return module; } } - // Error: we cannot find the file. - sink->diagnose(loc, Diagnostics::cannotOpenFile, getFileNameFromModuleName(name, false)); - mapNameToLoadedModules[name] = nullptr; + // If we tried out all of our candidate file names + // and failed with each of them, then we diagnose + // an error based on the original *source* file + // name. + // + // TODO: this should really be an error message + // that clearly states something like "no file + // suitable for module `whatever` was found + // and loaded. + // + // Ideally that error message would include whatever + // of the candidate file names from the loop above + // got furthest along in the process (or just a + // list of the file names that were tried, if + // nothing was even found via the include system). + // + sink->diagnose(requestingLoc, Diagnostics::cannotOpenFile, defaultSourceFileName); + + // If the attempt to import the module failed, then + // we will stick a null pointer into the map of loaded + // modules, so that subsequent attempts to load a module + // with this name will return null without having to + // go through all the above steps yet again. + // + mapNameToLoadedModules[moduleName] = nullptr; return nullptr; } @@ -4454,27 +4739,19 @@ SourceFile* Linkage::loadSourceFile(String pathFrom, String path) } // Check if a serialized module is up-to-date with current compiler options and source files. -bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* container) +bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* riffContainer) { - DiagnosticSink sink; - SerialContainerUtil::ReadOptions readOptions; - readOptions.linkage = this; - readOptions.astBuilder = getASTBuilder(); - readOptions.session = getSessionImpl(); - readOptions.sharedASTBuilder = getASTBuilder()->getSharedASTBuilder(); - readOptions.sink = &sink; - readOptions.sourceManager = getSourceManager(); - readOptions.namePool = getNamePool(); - readOptions.readHeaderOnly = true; - - SerialContainerData containerData; - if (SLANG_FAILED(SerialContainerUtil::read(container, readOptions, nullptr, containerData))) + auto moduleChunk = ModuleChunkRef::find(riffContainer); + if (!moduleChunk) return false; - if (containerData.modules.getCount() != 1) - return false; + return isBinaryModuleUpToDate(fromPath, moduleChunk); +} + +bool Linkage::isBinaryModuleUpToDate(String fromPath, ModuleChunkRef moduleChunk) +{ + SHA1::Digest existingDigest = moduleChunk.getDigest(); - auto& moduleHeader = containerData.modules[0]; DigestBuilder<SHA1> digestBuilder; auto version = String(getBuildTagString()); digestBuilder.append(version); @@ -4482,9 +4759,12 @@ bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* container) // Find the canonical path of the directory containing the module source file. String moduleSrcPath = ""; - if (moduleHeader.dependentFiles.getCount()) + + auto dependencyChunks = moduleChunk.getFileDependencies(); + if (auto firstDependencyChunk = dependencyChunks.getFirst()) { - moduleSrcPath = moduleHeader.dependentFiles.getFirst(); + moduleSrcPath = firstDependencyChunk.getValue(); + IncludeSystem includeSystem( &getSearchDirectories(), getFileSystemExt(), @@ -4497,21 +4777,22 @@ bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* container) } } - for (auto file : moduleHeader.dependentFiles) + for (auto dependencyChunk : dependencyChunks) { + auto file = dependencyChunk.getValue(); auto sourceFile = loadSourceFile(fromPath, file); if (!sourceFile) { // If we cannot find the source file from `fromPath`, // try again from the module's source file path. - if (moduleHeader.dependentFiles.getCount() != 0) + if (dependencyChunks.getFirst()) sourceFile = loadSourceFile(moduleSrcPath, file); } if (!sourceFile) return false; digestBuilder.append(sourceFile->getDigest()); } - return digestBuilder.finalize() == moduleHeader.digest; + return digestBuilder.finalize() == existingDigest; } SLANG_NO_THROW bool SLANG_MCALL @@ -6243,20 +6524,100 @@ void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) getSourceManager()->setFileSystemExt(m_fileSystemExt); } -void Linkage::prepareDeserializedModule( - SerialContainerData::Module& moduleEntry, - const PathInfo& filePathInfo, +SlangResult Linkage::loadSerializedModuleContents( Module* module, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, DiagnosticSink* sink) { - module->setIRModule(moduleEntry.irModule); - module->setModuleDecl(as<ModuleDecl>(moduleEntry.astRootNode)); + // At this point we've dealt with basically all of + // the formalities, and we just need to get down + // to the real work of decoding the information + // in the `moduleChunk`. + + auto sourceManager = getSourceManager(); + RefPtr<SerialSourceLocReader> sourceLocReader; + if (auto debugChunk = findDebugChunk(moduleChunk.ptr())) + { + SLANG_RETURN_ON_FAIL( + readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); + } + + auto astChunk = moduleChunk.findAST(); + if (!astChunk) + return SLANG_FAIL; + + auto irChunk = moduleChunk.findIR(); + if (!irChunk) + return SLANG_FAIL; + + auto astBuilder = getASTBuilder(); + auto session = getSessionImpl(); + + // For the purposes of any modules referenced + // by the module we're about to decode, we will + // construct a source location that represents + // the module itself (if possible). + // + // TODO(tfoley): This logic seems like overkill, given + // that many (most? all?) control-flow paths that can + // reach this routine will have already found a `SourceFile` + // to represent the module, as part of even getting the + // `moduleFilePathInfo` to pass in + // + // The approach here is more or less exactly copied + // from what the old `SerialContainerUtil::read` function + // used to do, with the hopes that it will as many tests + // passing as possible. + // + // Down the line somebody should scrutinize all of this + // kind of logic in the compiler codebase, because there + // is something that feels unclean about how paths are being handled. + // + SourceLoc serializedModuleLoc; + { + auto sourceFile = + sourceManager->findSourceFileByPathRecursively(moduleFilePathInfo.foundPath); + if (!sourceFile) + { + sourceFile = sourceManager->createSourceFileWithString(moduleFilePathInfo, String()); + sourceManager->addSourceFile(moduleFilePathInfo.getMostUniqueIdentity(), sourceFile); + } + auto sourceView = + sourceManager->createSourceView(sourceFile, &moduleFilePathInfo, SourceLoc()); + serializedModuleLoc = sourceView->getRange().begin; + } + + auto moduleDecl = readSerializedModuleAST( + this, + astBuilder, + sink, + astChunk, + sourceLocReader, + serializedModuleLoc); + if (!moduleDecl) + return SLANG_FAIL; + module->setModuleDecl(moduleDecl); + + RefPtr<IRModule> irModule; + SLANG_RETURN_ON_FAIL(decodeModuleIR(irModule, irChunk, session, sourceLocReader)); + module->setIRModule(irModule); + + // The handling of file dependencies is complicated, because of + // the way that the encoding logic tried to make all of the + // paths be relative to the primary source file for the module. + // + // We end up needing to undo some amount of that work here. + // + module->clearFileDependency(); - String moduleSourcePath = filePathInfo.foundPath; + String moduleSourcePath = moduleFilePathInfo.foundPath; bool isFirst = true; - for (auto file : moduleEntry.dependentFiles) + for (auto depenencyFileChunk : moduleChunk.getFileDependencies()) { - auto sourceFile = loadSourceFile(filePathInfo.foundPath, file); + auto encodedDependencyFilePath = depenencyFileChunk.getValue(); + + auto sourceFile = loadSourceFile(moduleFilePathInfo.foundPath, encodedDependencyFilePath); if (isFirst) { // The first file is the source for the main module file. @@ -6270,20 +6631,19 @@ void Linkage::prepareDeserializedModule( // it relative to the module source path. if (!sourceFile) { - sourceFile = loadSourceFile(moduleSourcePath, file); + sourceFile = loadSourceFile(moduleSourcePath, encodedDependencyFilePath); } if (sourceFile) { module->addFileDependency(sourceFile); } } - module->setPathInfo(filePathInfo); - module->setDigest(moduleEntry.digest); + module->setPathInfo(moduleFilePathInfo); + module->setDigest(moduleChunk.getDigest()); module->_collectShaderParams(); module->_discoverEntryPoints(sink, targets); // Hook up fileDecl's scope to module's scope. - auto moduleDecl = module->getModuleDecl(); for (auto globalDecl : moduleDecl->members) { if (auto fileDecl = as<FileDecl>(globalDecl)) @@ -6291,6 +6651,8 @@ void Linkage::prepareDeserializedModule( addSiblingScopeForContainerDecl(m_astBuilder, moduleDecl->ownedScope, fileDecl); } } + + return SLANG_OK; } void Linkage::setRequireCacheFileSystem(bool requireCacheFileSystem) |
