diff options
| author | Theresa Foley <10618364+tangent-vector@users.noreply.github.com> | 2025-04-22 13:26:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-22 13:26:57 -0700 |
| commit | 1cf3f18a9ca1905a5bc51790ca723815dd5b1400 (patch) | |
| tree | 097a6db7b7e4196f3e68996e8ae68ed8f054fb1f /source | |
| parent | ed5940a629ae05e9571bfe355d22f0728347dcb4 (diff) | |
A new approach to AST serialization (#6854)
* A new approach to AST serialization
This change completely overhauls the way that AST nodes are being serialized, and the offline source-code generation steps that enable that serialization.
In practice, this ends up being a complete overhaul of the way that *modules* are being serialized (not just the AST part), although things like the serialization format for the Slang IR and for source locations are not affected.
The rest of this commit message is broken down in to sections, in an attempt to help guide anybody looking at the code in how to make sense of all the changes.
The Old C++ Extractor
---------------------
AST serialization used to be driven by information scraped using the `slang-cpp-extractor` tool, which did an ad hoc parse of the C++ declarations of the AST node types and then generated a set of "X macros" that could be for macro-based code generation within the rest of the compiler.
While the existing approach was functional, it wasn't easy to understand or maintain, and it has been getting in the way of forward progress on other features we'd like to work on in the language and compiler.
This change removes the `slang-cpp-extractor` tool entirely.
Marking Up the AST Declarations
-------------------------------
The most notable change that contributors to the compiler may notice is the large number of invocations of a macro `FIDDLE()` on the declarations of the AST node types.
The basic idea is that only declarations (namespaces, types, fields) that are preceded by `FIDDLE()` are visible to the code generator tool.
So if somebody is working with the AST and wondering why a new node type isn't working, or why a field they added isn't being serialized correctly, it is probably because they need to add `FIDDLE()` in front of it.
Generating the Boilerplate Code
-------------------------------
The file `slang-ast-boilerplate.cpp` provides a good example of how the information extracted from the marked-up AST declarations gets used.
In that file, the `FIDDLE TEMPLATE` construct is used to generate type information for each of the AST node types.
Similar logic is used in `slang-ast-forward-declarations.h` to generate the declaration of the `ASTNodeType` enumeration, and forward-declare all the AST node classes.
For many parts of the code, simply including that file replaces the need for the old `slang-generated-*.h` files.
Replacing Visitors and Related Logic
------------------------------------
The old visitor types for the AST used the macros that were generated by `slang-cpp-extractor`, so something new was needed to replace them.
The same goes for the `SLANG_AST_NODE_VIRTUAL_CALL` macros.
The core of the solution implemented here is in `slang-ast-dispatch.h`.
Given a "dispatchable" AST node type (say, `Expr`), a call like:
```
ASTNodeDispatcher<Expr,R>(expr, [&](auto e) { return doSomething(e); })
```
is an expression of type `R`, which does the equivalent of something like:
```
switch(expr->getTag())
{
case ASTNodeType::VarExpr: return doSomething(static_cast<VarExpr*>(expr));
// ...
}
```
The `SLANG_AST_NODE_VIRTUAL_CALL` macro is now implemented in terms of `ASTNodeDispatcher`.
The implementation of the visitor types is more involved.
The code in this change retains some of the macro names from the original version, just to try and make the parallels more clear.
The visitor types are all implemented on top of the `ASTNodeDispatcher` approach, and use `FIDDLE TEMPLATE` to generate all the boilerplate `visit*()` method declarations.
Refactoring of `Linkage` Module Loading
---------------------------------------
Needing to revisit all the places where modules get deserialized made it clear that there is a lot of complexity and apparent duplication in the core routines on the `Linkage` that get used for loading modules.
This change tries to clean up some of that logic, but it is worth noting that there are two legacy features that get in the way of making things as clean as they should be:
* The `LoadedModuleDictionary` type that gets passed around a lot exists entirely to handle the corner case where somebody uses the Slang API to perform a compilation with multiple `TranslationUnitRequest`s in the same `FrontEndCompileRequest`, and one of the translation units `import`s the module defined by another of the translation units.
* There are a lot of special-case behaviors and routines entirely there to support the `ModuleLibrary` feature, although that feature should be considered deprecated (or at least subject to getting entirely re-designed down the line).
The basic idea of the cleanup is that all of the (non-deprecated) ways load a module from a serialized binary, or compile one from source should now bottleneck through `loadModuleImpl`, which then bifurcates into `loadSourceModuleImpl` for the compilation case and `loadBinaryModuleImpl` for the deserialization case.
High-Level Serialization Approach
---------------------------------
The old serialization logic used the [RIFF](https://en.wikipedia.org/wiki/Resource_Interchange_File_Format) format to encode the high-level structure of things, and this change retains that usage (and actually doubles down on the RIFF usage).
The old serialization system relied on the idea that for any given type `Foo` that wants to support serialization, there should be something like a `SerialFooData` type in C++, that can represent the state of a `Foo`, and then the actual serialization applied to that `SerialFooData`. This means that in most cases there are four pieces of code written:
* During serialization:
* Copying the data of a `Foo` in memory over to a `SerialFooData` in memory
* Writing the state of a `SerialFooData` into the serialized data stream
* During deserialization:
* Reading the state of a `SerialFooData` from a serialized data stream
* Copying the data of the `SerialFooData` in memory over to a `Foo`
The new logic gets rid of the intermediate `SerialFooData`.
In the serialization direction, we take a `Foo` and write it to the `RIFFContainer` directly, or using some other utilities layered on top of it.
In the deserialization direction, we have additional flexibility. Given a `RIFFContainer::Chunk*` that represents a serialized `Foo`, we often navigate through the in-memory representation of the RIFF data to get to the parts of the serialized value that we actually want/need, without needing to deserialize the entire `Foo`.
To support this kind of operation, this change introduces a few helper types like `ContainerChunkRef` an `ModuleChunkRef`, that are little more than typed wrappers around a `RIFFContainer::Chunk*`.
The Module "Container" Part
---------------------------
A serialized `Module` is encoded as a RIFF chunk, using logic in `slang-serialize-container.cpp` - both before and after this change.
This change reorganizes a lot of the code in that file, to account for the way that eliminating the intermediate `SerialContainerData` type streamlines the overall task of writing out the parts of the module.
In the deserialization logic... there isn't really much to do in `slang-serialize-container.cpp`. Most of the logic in `slang.cpp` and `slang-module-library.cpp` that pertains to deserializing modules uses the `ModuleChunkRef`-based approach, and simply extracts the pieces of the serialized module that it needs.
The Actual Serialization of the AST
-----------------------------------
The actual AST serialization logic is in `slang-serialize-ast.cpp`.
The basic approach in both the writing and reading directions is:
* Use the `FIDDLE TEMPLATE` system to generate a set of functions, one for each AST node type, that recursively invoke the read/write logic on each field of that node (after recursively invoking the case for its direct superclass)
* Use the `ASTNodeDispatcher` system to dispatch out to those functions whene reading or writing anything derived from `NodeBase`
* For now, handle all types *not* derived from `NodeBase` by hand.
There's a lot of room for improvement around that last item: it should be just as easy to generate the serialization and deserialization logic for other types that don't inherit from `NodeBase`, but the current change tries to err on the side of making the logic as explicit and simplistic as possible, rather than trying to get too clever too soon.
The actual serialization *format* used for the AST is almost comically simplistic: the code uses hierarchical RIFF chunks to emulate a JSON-like structure. This is a very wasteful representation (e.g., a `bool` or a null pointer each take up *8 bytes*), but the goal for now is to start with the simplest thing that could possibly work, and only add more cleverness once we are sure it won't get in the way of important future improvements (like lazy/on-demand deserialization or IR and AST, to improve compiler startup times).
The files `slang-serialize.{h,cpp}` have been co-opted to define a new pair of types `Encoder` and `Decoder` that are used for a more-or-less stream-oriented way or reading or writing RIFF chunks for the JSON-like structure.
Almost everything related to the actual AST serialization could do with a cleanup pass, and some time spent on picking good/better names for everything.
Smaller Stuff
-------------
* Cleaned up a lot of code that was using bare `ASTNodeType` or the extractor's `ReflectClassInfo` type to consistently use `SyntaxClass`.
* Fixed an apparent bug in how the destination-driven code genarator was handling `TryExpr`s
* Fixed an apparent bug in how the GLSL legalization pass was handling translation of certain `SV_*` semantics.
* format code
* fixup: template errors caught by non-VS compilers
* format code
* fixup: more template errors
* fixup: more stuff VS didn't catch
* fixup: it's amazing VS doesn't catch these...
* fixup: yet more template stuff VS ignores
* fixup: more VS template nonsense
* fixup: unreachable return macro usage
* fixup: more unreacable returns
* fixup: unused parameter
* fixup: strict aliasing
* fixup: allow missing entry point list chunk
* fixup: wasm build script
* fixup: AST changes since this PR was created
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
67 files changed, 6678 insertions, 8127 deletions
diff --git a/source/core/slang-riff.cpp b/source/core/slang-riff.cpp index 0eb0381e1..c1e3e81c3 100644 --- a/source/core/slang-riff.cpp +++ b/source/core/slang-riff.cpp @@ -758,6 +758,24 @@ void RiffContainer::_addChunk(Chunk* chunk) } } +void RiffContainer::setCurrentChunk(Chunk* chunk) +{ + SLANG_ASSERT(chunk); + + switch (chunk->m_kind) + { + case Chunk::Kind::Data: + m_listChunk = nullptr; + m_dataChunk = static_cast<RiffContainer::DataChunk*>(chunk); + break; + + case Chunk::Kind::List: + m_dataChunk = nullptr; + m_listChunk = static_cast<RiffContainer::ListChunk*>(chunk); + break; + } +} + void RiffContainer::startChunk(Chunk::Kind kind, FourCC fourCC) { SLANG_ASSERT(m_listChunk || m_rootList == nullptr); @@ -857,7 +875,10 @@ void RiffContainer::setPayload(Data* data, const void* payload, size_t size) data->m_ownership = Ownership::Arena; data->m_size = size; - data->m_payload = m_arena.allocateAligned(size, kPayloadMinAlignment); + if (size) + { + data->m_payload = m_arena.allocateAligned(size, kPayloadMinAlignment); + } if (payload) { diff --git a/source/core/slang-riff.h b/source/core/slang-riff.h index 1e2c883b9..c858158e6 100644 --- a/source/core/slang-riff.h +++ b/source/core/slang-riff.h @@ -24,17 +24,11 @@ typedef uint32_t FourCC; #define SLANG_FOUR_CC(c0, c1, c2, c3) \ ((FourCC(c0) << 0) | (FourCC(c1) << 8) | (FourCC(c2) << 16) | (FourCC(c3) << 24)) -#define SLANG_FOUR_CC_GET_FIRST_CHAR(x) char((x) & 0xff) -#define SLANG_FOUR_CC_REPLACE_FIRST_CHAR(x, c) (((x) & 0xffffff00) | FourCC(c)) - #else #define SLANG_FOUR_CC(c0, c1, c2, c3) \ ((FourCC(c0) << 24) | (FourCC(c1) << 16) | (FourCC(c2) << 8) | (FourCC(c3) << 0)) -#define SLANG_FOUR_CC_GET_FIRST_CHAR(x) char((x) >> 24) -#define SLANG_FOUR_CC_REPLACE_FIRST_CHAR(x, c) (((x) & 0x00ffffff) | (FourCC(c) << 24)) - #endif enum @@ -451,6 +445,8 @@ public: /// Ctor RiffContainer(); + void setCurrentChunk(Chunk* chunk); + protected: void _addChunk(Chunk* chunk); ListChunk* _newListChunk(FourCC subType); diff --git a/source/slang-core-module/CMakeLists.txt b/source/slang-core-module/CMakeLists.txt index ba70d77b9..600190161 100644 --- a/source/slang-core-module/CMakeLists.txt +++ b/source/slang-core-module/CMakeLists.txt @@ -61,7 +61,6 @@ set(core_module_source_common_args core slang-capability-defs slang-fiddle-output - slang-reflect-headers SPIRV-Headers INCLUDE_DIRECTORIES_PRIVATE ../slang diff --git a/source/slang-wasm/CMakeLists.txt b/source/slang-wasm/CMakeLists.txt index c6c5601e9..152ed5094 100644 --- a/source/slang-wasm/CMakeLists.txt +++ b/source/slang-wasm/CMakeLists.txt @@ -17,7 +17,7 @@ if(EMSCRIPTEN) compiler-core slang-capability-defs slang-capability-lookup - slang-reflect-headers + slang-fiddle-output slang-lookup-tables INCLUDE_DIRECTORIES_PUBLIC ${slang_SOURCE_DIR}/include . ) diff --git a/source/slang/CMakeLists.txt b/source/slang/CMakeLists.txt index 2adc96939..daea0e002 100644 --- a/source/slang/CMakeLists.txt +++ b/source/slang/CMakeLists.txt @@ -97,60 +97,6 @@ slang_add_target( ) # -# generated headers for reflection -# - -set(SLANG_REFLECT_INPUT - slang-ast-support-types.h - slang-ast-base.h - slang-ast-decl.h - slang-ast-expr.h - slang-ast-modifier.h - slang-ast-stmt.h - slang-ast-type.h - slang-ast-val.h -) -# Make them absolute -list(TRANSFORM SLANG_REFLECT_INPUT PREPEND "${CMAKE_CURRENT_LIST_DIR}/") - -set(SLANG_REFLECT_GENERATED_HEADERS - slang-generated-obj.h - slang-generated-obj-macro.h - slang-generated-ast.h - slang-generated-ast-macro.h - slang-generated-value.h - slang-generated-value-macro.h -) -set(SLANG_REFLECT_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/ast-reflect") -list( - TRANSFORM SLANG_REFLECT_GENERATED_HEADERS - PREPEND "${SLANG_REFLECT_OUTPUT_DIR}/" -) - -add_custom_command( - OUTPUT ${SLANG_REFLECT_GENERATED_HEADERS} - COMMAND ${CMAKE_COMMAND} -E make_directory ${SLANG_REFLECT_OUTPUT_DIR} - COMMAND - slang-cpp-extractor ${SLANG_REFLECT_INPUT} -strip-prefix slang- -o - ${SLANG_REFLECT_OUTPUT_DIR}/slang-generated -output-fields -mark-suffix - _CLASS - DEPENDS ${SLANG_REFLECT_INPUT} slang-cpp-extractor - VERBATIM -) - -add_library( - slang-reflect-headers - INTERFACE - EXCLUDE_FROM_ALL - ${SLANG_REFLECT_GENERATED_HEADERS} -) -set_target_properties(slang-reflect-headers PROPERTIES FOLDER generated) -target_include_directories( - slang-reflect-headers - INTERFACE ${SLANG_REFLECT_OUTPUT_DIR} -) - -# # generated lookup tables # @@ -279,7 +225,6 @@ set(slang_link_args slang-capability-defs slang-capability-lookup slang-fiddle-output - slang-reflect-headers slang-lookup-tables SPIRV-Headers ) diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp index ac18da404..72e42a860 100644 --- a/source/slang/slang-ast-base.cpp +++ b/source/slang/slang-ast-base.cpp @@ -1,3 +1,4 @@ +// slang-ast-base.cpp #include "slang-ast-base.h" #include "slang-ast-builder.h" diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 8f85334d6..72da9cf56 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -2,25 +2,26 @@ #pragma once -#include "slang-ast-reflect.h" +#include "slang-ast-base.h.fiddle" +#include "slang-ast-forward-declarations.h" #include "slang-ast-support-types.h" #include "slang-capability.h" -#include "slang-generated-ast.h" -#include "slang-serialize-reflection.h" // This file defines the primary base classes for the hierarchy of // AST nodes and related objects. For example, this is where the // basic `Decl`, `Stmt`, `Expr`, `type`, etc. definitions come from. +FIDDLE() namespace Slang { class ASTBuilder; struct SemanticsVisitor; +FIDDLE(abstract) class NodeBase { - SLANG_ABSTRACT_AST_CLASS(NodeBase) + FIDDLE(...) // MUST be called before used. Called automatically via the ASTBuilder. // Note that the astBuilder is not stored in the NodeBase derived types by default. @@ -35,18 +36,12 @@ class NodeBase void _initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder); - /// Get the class info - SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const - { - return *ASTClassInfo::getInfo(astNodeType); - } - - SyntaxClass<NodeBase> getClass() { return SyntaxClass<NodeBase>(&getClassInfo()); } + SyntaxClass<NodeBase> getClass() const { return SyntaxClass<NodeBase>(astNodeType); } /// The type of the node. ASTNodeType(-1) is an invalid node type, and shouldn't appear on any /// correctly constructed (through ASTBuilder) NodeBase derived class. /// The actual type is set when constructed on the ASTBuilder. - ASTNodeType astNodeType = ASTNodeType(-1); + FIDDLE() ASTNodeType astNodeType = ASTNodeType(-1); #ifdef _DEBUG SLANG_UNREFLECTED int32_t _debugUID = 0; @@ -58,37 +53,25 @@ class NodeBase template<typename T> SLANG_FORCE_INLINE T* dynamicCast(NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<T*>(node) : nullptr; } template<typename T> SLANG_FORCE_INLINE const T* dynamicCast(const NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<const T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<const T*>(node) : nullptr; } template<typename T> SLANG_FORCE_INLINE T* as(NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<T*>(node) : nullptr; } template<typename T> SLANG_FORCE_INLINE const T* as(const NodeBase* node) { - return (node && - ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) - ? static_cast<const T*>(node) - : nullptr; + return (node && node->getClass().isSubClassOf<T>()) ? static_cast<const T*>(node) : nullptr; } // Because DeclRefBase is now a `Val`, we prevent casting it directly into other nodes @@ -114,9 +97,10 @@ DeclRef<T> as(DeclRef<U> declRef) return DeclRef<T>(declRef); } -struct Scope : public NodeBase +FIDDLE() +class Scope : public NodeBase { - SLANG_AST_CLASS(Scope) + FIDDLE(...) // The container to use for lookup // @@ -135,12 +119,13 @@ struct Scope : public NodeBase // Base class for all nodes representing actual syntax // (thus having a location in the source code) +FIDDLE(abstract) class SyntaxNodeBase : public NodeBase { - SLANG_ABSTRACT_AST_CLASS(SyntaxNodeBase) + FIDDLE(...) // The primary source location associated with this AST node - SourceLoc loc; + FIDDLE() SourceLoc loc; }; enum class ValNodeOperandKind @@ -231,7 +216,7 @@ private: HashCode hashCode = 0; public: - ASTNodeType type; + SyntaxClass<NodeBase> type; ShortList<ValNodeOperand, 8> operands; inline bool operator==(ValNodeDesc const& that) const @@ -363,9 +348,10 @@ static void addOrAppendToNodeList(List<ValNodeOperand>& list, ArrayView<T> l, Ts // a unique location, and any two `Val`s representing // the same value should be conceptually equal. +FIDDLE(abstract) class Val : public NodeBase { - SLANG_ABSTRACT_AST_CLASS(Val) + FIDDLE(...) template<typename T> struct OperandView @@ -406,10 +392,6 @@ class Val : public NodeBase ConstIterator end() const { return ConstIterator{val, offset + count}; } }; - typedef IValVisitor Visitor; - - void accept(IValVisitor* visitor, void* extra); - // construct a new value by applying a set of parameter // substitutions to this one Val* substitute(ASTBuilder* astBuilder, SubstitutionSet subst); @@ -479,7 +461,7 @@ class Val : public NodeBase for (auto v : operands) m_operands.add(ValNodeOperand(v)); } - List<ValNodeOperand> m_operands; + FIDDLE() List<ValNodeOperand> m_operands; // Private use by the core module deserialization only. Since we know the Vals serialized into // the core module is already unique, we can just use `this` pointer as the `m_resolvedVal` so @@ -567,13 +549,10 @@ SLANG_FORCE_INLINE const T* as(const Type* obj); // "canonical" type. The representation caches a pointer to // a canonical type on every type, so we can easily // operate on the raw representation when needed. +FIDDLE(abstract) class Type : public Val { - SLANG_ABSTRACT_AST_CLASS(Type) - - typedef ITypeVisitor Visitor; - - void accept(ITypeVisitor* visitor, void* extra); + FIDDLE(...) /// Type derived types store the AST builder they were constructed on. The builder calls this /// function after constructing. @@ -618,9 +597,10 @@ class Decl; // A reference to a declaration, which may include // substitutions for generic parameters. +FIDDLE(abstract) class DeclRefBase : public Val { - SLANG_ABSTRACT_AST_CLASS(DeclRefBase) + FIDDLE(...) Decl* getDecl() const { return getDeclOperand(0); } @@ -687,9 +667,10 @@ SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Decl* decl) return io; } +FIDDLE(abstract) class SyntaxNode : public SyntaxNodeBase { - SLANG_ABSTRACT_AST_CLASS(SyntaxNode); + FIDDLE(...) }; // @@ -697,29 +678,28 @@ class SyntaxNode : public SyntaxNodeBase // (that is, we don't use a bitfield, even for simple/common flags). // This ensures that we can track source locations for all modifiers. // +FIDDLE(abstract) class Modifier : public SyntaxNode { - SLANG_ABSTRACT_AST_CLASS(Modifier) - typedef IModifierVisitor Visitor; - - void accept(IModifierVisitor* visitor, void* extra); + FIDDLE(...) // Next modifier in linked list of modifiers on same piece of syntax Modifier* next = nullptr; // The keyword that was used to introduce t that was used to name this modifier. - Name* keywordName = nullptr; + FIDDLE() Name* keywordName = nullptr; Name* getKeywordName() { return keywordName; } NameLoc getKeywordNameAndLoc() { return NameLoc(keywordName, loc); } }; // A syntax node which can have modifiers applied +FIDDLE(abstract) class ModifiableSyntaxNode : public SyntaxNode { - SLANG_ABSTRACT_AST_CLASS(ModifiableSyntaxNode) + FIDDLE(...) - Modifiers modifiers; + FIDDLE() Modifiers modifiers; template<typename T> FilteredModifierList<T> getModifiersOfType() @@ -748,28 +728,25 @@ struct ProvenenceNodeWithLoc }; // An intermediate type to represent either a single declaration, or a group of declarations +FIDDLE(abstract) class DeclBase : public ModifiableSyntaxNode { - SLANG_ABSTRACT_AST_CLASS(DeclBase) - - typedef IDeclVisitor Visitor; - - void accept(IDeclVisitor* visitor, void* extra); + FIDDLE(...) }; +FIDDLE(abstract) class Decl : public DeclBase { + FIDDLE(...) public: - SLANG_ABSTRACT_AST_CLASS(Decl) - - ContainerDecl* parentDecl = nullptr; + FIDDLE() ContainerDecl* parentDecl = nullptr; DeclRefBase* getDefaultDeclRef(); - NameLoc nameAndLoc; - CapabilitySet inferredCapabilityRequirements; + FIDDLE() NameLoc nameAndLoc; + FIDDLE() CapabilitySet inferredCapabilityRequirements; - RefPtr<MarkupEntry> markup; + FIDDLE() RefPtr<MarkupEntry> markup; Name* getName() const { return nameAndLoc.name; } SourceLoc getNameLoc() const { return nameAndLoc.loc; } @@ -797,26 +774,20 @@ private: SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr; }; +FIDDLE(abstract) class Expr : public SyntaxNode { - SLANG_ABSTRACT_AST_CLASS(Expr) - - typedef IExprVisitor Visitor; + FIDDLE(...) - QualType type; + FIDDLE() QualType type; bool checked = false; - - void accept(IExprVisitor* visitor, void* extra); }; +FIDDLE(abstract) class Stmt : public ModifiableSyntaxNode { - SLANG_ABSTRACT_AST_CLASS(Stmt) - - typedef IStmtVisitor Visitor; - - void accept(IStmtVisitor* visitor, void* extra); + FIDDLE(...) }; template<typename T> diff --git a/source/slang/slang-ast-boilerplate.cpp b/source/slang/slang-ast-boilerplate.cpp new file mode 100644 index 000000000..0313d4411 --- /dev/null +++ b/source/slang/slang-ast-boilerplate.cpp @@ -0,0 +1,54 @@ +// slang-ast-boilerplate.cpp + +#include "slang-ast-all.h" +#include "slang-ast-builder.h" +#include "slang-ast-forward-declarations.h" + +namespace Slang +{ +template<typename T> +struct Helper +{ + static void* create(ASTBuilder* builder) { return builder->createImpl<T>(); } + + static void destruct(void* obj) { ((T*)obj)->~T(); } +}; + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do +const SyntaxClassInfo $T::kSyntaxClassInfo = { + "$T", + ASTNodeType::$T, + $(#T.subclasses), +% if T.isAbstract then + nullptr, // create + nullptr, // destruct +% else + &Helper<$T>::create, + &Helper<$T>::destruct, +% end +}; +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-boilerplate.cpp.fiddle" +#endif // FIDDLE END + +static SyntaxClassInfo const* kAllSyntaxClasses[] = { +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + &$T::kSyntaxClassInfo, +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-ast-boilerplate.cpp.fiddle" +#endif // FIDDLE END +}; + +SyntaxClassBase::SyntaxClassBase(ASTNodeType tag) +{ + assert(int(tag) >= 0 && int(tag) < SLANG_COUNT_OF(kAllSyntaxClasses)); + _info = kAllSyntaxClasses[int(tag)]; +} + +} // namespace Slang diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index b3afa5310..5abef94b3 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -33,46 +33,34 @@ void SharedASTBuilder::init(Session* session) // NOTE! That this adds the names of the abstract classes too(!) for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) { - const ReflectClassInfo* info = ASTClassInfo::getInfo(ASTNodeType(i)); - if (info) - { - m_sliceToTypeMap.add(UnownedStringSlice(info->m_name), info); - Name* name = m_namePool->getName(String(info->m_name)); - m_nameToTypeMap.add(name, info); - } + auto syntaxClass = SyntaxClass(ASTNodeType(i)); + if (!syntaxClass) + continue; + auto nameText = syntaxClass.getName(); + m_sliceToTypeMap.add(nameText, syntaxClass); + Name* nameObj = m_namePool->getName(nameText); + m_nameToTypeMap.add(nameObj, syntaxClass); } } -const ReflectClassInfo* SharedASTBuilder::findClassInfo(const UnownedStringSlice& slice) -{ - const ReflectClassInfo* typeInfo; - return m_sliceToTypeMap.tryGetValue(slice, typeInfo) ? typeInfo : nullptr; -} - -SyntaxClass<NodeBase> SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice) +SyntaxClass<> SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice) { - const ReflectClassInfo* typeInfo; + SyntaxClass typeInfo; if (m_sliceToTypeMap.tryGetValue(slice, typeInfo)) { - return SyntaxClass<NodeBase>(typeInfo); + return typeInfo; } - return SyntaxClass<NodeBase>(); -} - -const ReflectClassInfo* SharedASTBuilder::findClassInfo(Name* name) -{ - const ReflectClassInfo* typeInfo; - return m_nameToTypeMap.tryGetValue(name, typeInfo) ? typeInfo : nullptr; + return getSyntaxClass<NodeBase>(); } SyntaxClass<NodeBase> SharedASTBuilder::findSyntaxClass(Name* name) { - const ReflectClassInfo* typeInfo; + SyntaxClass<NodeBase> typeInfo; if (m_nameToTypeMap.tryGetValue(name, typeInfo)) { - return SyntaxClass<NodeBase>(typeInfo); + return typeInfo; } - return SyntaxClass<NodeBase>(); + return getSyntaxClass<NodeBase>(); } Type* SharedASTBuilder::getStringType() @@ -256,9 +244,8 @@ ASTBuilder::~ASTBuilder() { for (NodeBase* node : m_dtorNodes) { - const ReflectClassInfo* info = ASTClassInfo::getInfo(node->astNodeType); - SLANG_ASSERT(info->m_destructorFunc); - info->m_destructorFunc(node); + auto nodeClass = node->getClass(); + nodeClass.destructInstance(node); } incrementEpoch(); } @@ -275,16 +262,8 @@ void ASTBuilder::incrementEpoch() NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) { - const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType); - - auto createFunc = info->m_createFunc; - SLANG_ASSERT(createFunc); - if (!createFunc) - { - return nullptr; - } - - return (NodeBase*)createFunc(this); + auto syntaxClass = SyntaxClass<NodeBase>(nodeType); + return syntaxClass.createInstance(this); } Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index daf49f3f7..a25fcea28 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -46,10 +46,8 @@ public: Type* getInitializerListType(); Type* getOverloadedType(); - const ReflectClassInfo* findClassInfo(Name* name); SyntaxClass<NodeBase> findSyntaxClass(Name* name); - const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice); SyntaxClass<NodeBase> findSyntaxClass(const UnownedStringSlice& slice); // Look up a magic declaration by its name @@ -113,8 +111,8 @@ protected: Dictionary<String, Decl*> m_magicDecls; Dictionary<BuiltinRequirementKind, Decl*> m_builtinRequirementDecls; - Dictionary<UnownedStringSlice, const ReflectClassInfo*> m_sliceToTypeMap; - Dictionary<Name*, const ReflectClassInfo*> m_nameToTypeMap; + Dictionary<UnownedStringSlice, SyntaxClass<NodeBase>> m_sliceToTypeMap; + Dictionary<Name*, SyntaxClass<NodeBase>> m_nameToTypeMap; NamePool* m_namePool = nullptr; @@ -160,7 +158,7 @@ struct ValKey { if (hashCode != desc.getHashCode()) return false; - if (val->astNodeType != desc.type) + if (val->getClass() != desc.type) return false; if (val->m_operands.getCount() != desc.operands.getCount()) return false; @@ -199,7 +197,7 @@ public: if (auto found = m_cachedNodes.tryGetValue(desc)) return *found; - auto node = as<Val>(createByNodeType(desc.type)); + auto node = as<Val>(desc.type.createInstance(this)); SLANG_ASSERT(node); for (auto& operand : desc.operands) node->m_operands.add(operand); @@ -268,12 +266,14 @@ public: MemoryArena& getArena() { return m_arena; } + NamePool* getNamePool() { return getSharedASTBuilder()->getNamePool(); } + template<typename T, typename... TArgs> SLANG_FORCE_INLINE T* getOrCreate(TArgs... args) { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); ValNodeDesc desc; - desc.type = T::kType; + desc.type = getSyntaxClass<T>(); addOrAppendToNodeList(desc.operands, args...); desc.init(); auto result = (T*)_getOrCreateImpl(_Move(desc)); @@ -286,7 +286,7 @@ public: SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); ValNodeDesc desc; - desc.type = T::kType; + desc.type = getSyntaxClass<T>(); desc.init(); auto result = (T*)_getOrCreateImpl(_Move(desc)); return result; @@ -642,19 +642,11 @@ public: DeclRef<Decl> declRef); /// Helpers to get type info from the SharedASTBuilder - const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) - { - return m_sharedASTBuilder->findClassInfo(slice); - } SyntaxClass<NodeBase> findSyntaxClass(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findSyntaxClass(slice); } - const ReflectClassInfo* findClassInfo(Name* name) - { - return m_sharedASTBuilder->findClassInfo(name); - } SyntaxClass<NodeBase> findSyntaxClass(Name* name) { return m_sharedASTBuilder->findSyntaxClass(name); @@ -695,12 +687,12 @@ protected: // Keep such that dtor can be run on ASTBuilder being dtored m_dtorNodes.add(node); } - if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Val::kType))) + if (node->getClass().isSubClassOf(getSyntaxClass<Val>())) { auto val = (Val*)(node); val->m_resolvedValEpoch = getEpoch(); } - else if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType))) + else if (node->getClass().isSubClassOf(getSyntaxClass<Decl>())) { ((Decl*)node)->m_defaultDeclRef = getOrCreate<DirectDeclRef>((Decl*)node); } diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index 9f140a524..a44e5b817 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -1,8 +1,9 @@ +// slang-ast-decl-ref.cpp + #include "slang-ast-builder.h" -#include "slang-ast-reflect.h" +#include "slang-ast-dispatch.h" +#include "slang-ast-forward-declarations.h" #include "slang-check-impl.h" -#include "slang-generated-ast-macro.h" -#include "slang-generated-ast.h" namespace Slang { diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index c0d0e9242..530f983d9 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -2,7 +2,7 @@ #include "slang-ast-decl.h" #include "slang-ast-builder.h" -#include "slang-generated-ast-macro.h" +#include "slang-ast-dispatch.h" #include "slang-syntax.h" #include <assert.h> @@ -12,7 +12,7 @@ namespace Slang const TypeExp& TypeConstraintDecl::getSup() const { - SLANG_AST_NODE_CONST_VIRTUAL_CALL(TypeConstraintDecl, getSup, ()) + SLANG_AST_NODE_VIRTUAL_CALL(TypeConstraintDecl, getSup, ()) } const TypeExp& TypeConstraintDecl::_getSupOverride() const diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index ff55340ac..261d2458a 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -3,31 +3,35 @@ #pragma once #include "slang-ast-base.h" +#include "slang-ast-decl.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for declarations. // A group of declarations that should be treated as a unit +FIDDLE() class DeclGroup : public DeclBase { - SLANG_AST_CLASS(DeclGroup) - - List<Decl*> decls; + FIDDLE(...) + FIDDLE() List<Decl*> decls; }; +FIDDLE() class UnresolvedDecl : public Decl { - SLANG_AST_CLASS(UnresolvedDecl) + FIDDLE(...) }; // A "container" decl is a parent to other declarations +FIDDLE(abstract) class ContainerDecl : public Decl { - SLANG_ABSTRACT_AST_CLASS(ContainerDecl) + FIDDLE(...) - List<Decl*> members; + FIDDLE() List<Decl*> members; SourceLoc closingSourceLoc; // The associated scope owned by this decl. @@ -86,32 +90,35 @@ class ContainerDecl : public Decl }; // Base class for all variable declarations +FIDDLE(abstract) class VarDeclBase : public Decl { - SLANG_ABSTRACT_AST_CLASS(VarDeclBase) + FIDDLE(...) // type of the variable - TypeExp type; + FIDDLE() TypeExp type; Type* getType() { return type.type; } // Initializer expression (optional) - Expr* initExpr = nullptr; + FIDDLE() Expr* initExpr = nullptr; // Folded IntVal if the initializer is a constant integer. - IntVal* val = nullptr; + FIDDLE() IntVal* val = nullptr; }; // Ordinary potentially-mutable variables (locals, globals, and member variables) +FIDDLE() class VarDecl : public VarDeclBase { - SLANG_AST_CLASS(VarDecl) + FIDDLE(...) }; // A variable declaration that is always immutable (whether local, global, or member variable) +FIDDLE() class LetDecl : public VarDecl { - SLANG_AST_CLASS(LetDecl) + FIDDLE(...) }; // An `AggTypeDeclBase` captures the shared functionality @@ -122,17 +129,18 @@ class LetDecl : public VarDecl // - Both can have declared bases // - Both expose a `this` variable in their body // +FIDDLE(abstract) class AggTypeDeclBase : public ContainerDecl { - SLANG_ABSTRACT_AST_CLASS(AggTypeDeclBase); + FIDDLE(...) }; // An extension to apply to an existing type +FIDDLE() class ExtensionDecl : public AggTypeDeclBase { - SLANG_AST_CLASS(ExtensionDecl) - - TypeExp targetType; + FIDDLE(...) + FIDDLE() TypeExp targetType; }; enum class TypeTag @@ -145,11 +153,11 @@ enum class TypeTag }; // Declaration of a type that represents some sort of aggregate +FIDDLE(abstract) class AggTypeDecl : public AggTypeDeclBase { - SLANG_ABSTRACT_AST_CLASS(AggTypeDecl) - - TypeTag typeTags = TypeTag::None; + FIDDLE(...) + FIDDLE() TypeTag typeTags = TypeTag::None; // Used if this type declaration is a wrapper, i.e. struct FooWrapper:IFoo = Foo; TypeExp wrappedType; @@ -162,23 +170,25 @@ class AggTypeDecl : public AggTypeDeclBase FilteredMemberList<VarDecl> getFields() { return getMembersOfType<VarDecl>(); } }; +FIDDLE() class StructDecl : public AggTypeDecl { - SLANG_AST_CLASS(StructDecl); - + FIDDLE(...) SLANG_UNREFLECTED // We will use these auxiliary to help in synthesizing the member initialize constructor. Slang::HashSet<VarDeclBase*> m_membersVisibleInCtor; }; +FIDDLE() class ClassDecl : public AggTypeDecl { - SLANG_AST_CLASS(ClassDecl) + FIDDLE(...) }; +FIDDLE() class GLSLInterfaceBlockDecl : public AggTypeDecl { - SLANG_AST_CLASS(GLSLInterfaceBlockDecl); + FIDDLE(...) }; // TODO: Is it appropriate to treat an `enum` as an aggregate type? @@ -186,11 +196,11 @@ class GLSLInterfaceBlockDecl : public AggTypeDecl // types are all `AggTypeDecl`, so this is the right choice for now // if we want `enum` types to be able to implement interfaces, etc. // +FIDDLE() class EnumDecl : public AggTypeDecl { - SLANG_AST_CLASS(EnumDecl) - - Type* tagType = nullptr; + FIDDLE(...) + FIDDLE() Type* tagType = nullptr; }; // A single case in an enum. @@ -203,39 +213,40 @@ class EnumDecl : public AggTypeDecl // case, with `0` as an explicit expression for its // _tag value_. // +FIDDLE() class EnumCaseDecl : public Decl { - SLANG_AST_CLASS(EnumCaseDecl) - + FIDDLE(...) // type of the parent `enum` - TypeExp type; + FIDDLE() TypeExp type; Type* getType() { return type.type; } // Tag value - Expr* tagExpr = nullptr; + FIDDLE() Expr* tagExpr = nullptr; - IntVal* tagVal = nullptr; + FIDDLE() IntVal* tagVal = nullptr; }; // A member of InterfaceDecl representing the abstract ThisType. +FIDDLE() class ThisTypeDecl : public AggTypeDecl { - SLANG_AST_CLASS(ThisTypeDecl) + FIDDLE(...) }; // An interface which other types can conform to +FIDDLE() class InterfaceDecl : public AggTypeDecl { - SLANG_AST_CLASS(InterfaceDecl) - + FIDDLE(...) ThisTypeDecl* getThisTypeDecl(); }; +FIDDLE(abstract) class TypeConstraintDecl : public Decl { - SLANG_ABSTRACT_AST_CLASS(TypeConstraintDecl) - + FIDDLE(...) const TypeExp& getSup() const; // Overrides should be public so base classes can access // Implement _getSupOverride on derived classes to change behavior of getSup, as if getSup is @@ -243,11 +254,11 @@ class TypeConstraintDecl : public Decl const TypeExp& _getSupOverride() const; }; +FIDDLE() class ThisTypeConstraintDecl : public TypeConstraintDecl { - SLANG_AST_CLASS(ThisTypeConstraintDecl) - - TypeExp base; + FIDDLE(...) + FIDDLE() TypeExp base; const TypeExp& _getSupOverride() const { return base; } InterfaceDecl* getInterfaceDecl(); }; @@ -255,18 +266,18 @@ class ThisTypeConstraintDecl : public TypeConstraintDecl // A kind of pseudo-member that represents an explicit // or implicit inheritance relationship. // +FIDDLE() class InheritanceDecl : public TypeConstraintDecl { - SLANG_AST_CLASS(InheritanceDecl) - + FIDDLE(...) // The type expression as written - TypeExp base; + FIDDLE() TypeExp base; // After checking, this dictionary will map members // required by the base type to their concrete // implementations in the type that contains // this inheritance declaration. - RefPtr<WitnessTable> witnessTable; + FIDDLE() RefPtr<WitnessTable> witnessTable; // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return base; } @@ -279,74 +290,82 @@ class InheritanceDecl : public TypeConstraintDecl // // TODO: probably all types will be aggregate decls eventually, // so that we can easily store conformances/constraints on type variables +FIDDLE(abstract) class SimpleTypeDecl : public Decl { - SLANG_ABSTRACT_AST_CLASS(SimpleTypeDecl) + FIDDLE(...) }; // A `typedef` declaration +FIDDLE() class TypeDefDecl : public SimpleTypeDecl { - SLANG_AST_CLASS(TypeDefDecl) - - TypeExp type; + FIDDLE(...) + FIDDLE() TypeExp type; }; +FIDDLE() class TypeAliasDecl : public TypeDefDecl { - SLANG_AST_CLASS(TypeAliasDecl) + FIDDLE(...) }; // An 'assoctype' declaration, it is a container of inheritance clauses +FIDDLE() class AssocTypeDecl : public AggTypeDecl { - SLANG_AST_CLASS(AssocTypeDecl) + FIDDLE(...) }; // A 'type_param' declaration, which defines a generic // entry-point parameter. Is a container of GenericTypeConstraintDecl +FIDDLE() class GlobalGenericParamDecl : public AggTypeDecl { - SLANG_AST_CLASS(GlobalGenericParamDecl) + FIDDLE(...) }; // A `__generic_value_param` declaration, which defines an existential // value parameter (not a type parameter. +FIDDLE() class GlobalGenericValueParamDecl : public VarDeclBase { - SLANG_AST_CLASS(GlobalGenericValueParamDecl) + FIDDLE(...) }; // A scope for local declarations (e.g., as part of a statement) +FIDDLE() class ScopeDecl : public ContainerDecl { - SLANG_AST_CLASS(ScopeDecl) + FIDDLE(...) }; // A function/initializer/subscript parameter (potentially mutable) +FIDDLE() class ParamDecl : public VarDeclBase { - SLANG_AST_CLASS(ParamDecl) + FIDDLE(...) }; // A parameter of a function declared in "modern" types (immutable unless explicitly `out` or // `inout`) +FIDDLE() class ModernParamDecl : public ParamDecl { - SLANG_AST_CLASS(ModernParamDecl) + FIDDLE(...) }; // Base class for things that have parameter lists and can thus be applied to arguments ("called") +FIDDLE(abstract) class CallableDecl : public ContainerDecl { - SLANG_ABSTRACT_AST_CLASS(CallableDecl) - + FIDDLE(...) FilteredMemberList<ParamDecl> getParameters() { return getMembersOfType<ParamDecl>(); } - TypeExp returnType; + FIDDLE() TypeExp returnType; // If this callable throws an error code, `errorType` is the type of the error code. - TypeExp errorType; + FIDDLE() TypeExp errorType; // Fields related to redeclaration, so that we // can support multiple specialized variations @@ -366,18 +385,18 @@ class CallableDecl : public ContainerDecl // Base class for callable things that may also have a body that is evaluated to produce their // result +FIDDLE(abstract) class FunctionDeclBase : public CallableDecl { - SLANG_ABSTRACT_AST_CLASS(FunctionDeclBase) - - Stmt* body = nullptr; + FIDDLE(...) + FIDDLE() Stmt* body = nullptr; }; // A constructor/initializer to create instances of a type +FIDDLE() class ConstructorDecl : public FunctionDeclBase { - SLANG_AST_CLASS(ConstructorDecl) - + FIDDLE(...) enum class ConstructorFlavor : int { UserDefined = 0x00, @@ -389,51 +408,61 @@ class ConstructorDecl : public FunctionDeclBase SynthesizedMemberInit = 0x02 }; - int m_flavor = (int)ConstructorFlavor::UserDefined; + FIDDLE() int m_flavor = (int)ConstructorFlavor::UserDefined; void addFlavor(ConstructorFlavor flavor) { m_flavor |= (int)flavor; } bool containsFlavor(ConstructorFlavor flavor) { return m_flavor & (int)flavor; } }; // A subscript operation used to index instances of a type +FIDDLE() class SubscriptDecl : public CallableDecl { - SLANG_AST_CLASS(SubscriptDecl) + FIDDLE(...) }; /// A property declaration that abstracts over storage with a getter/setter/etc. +FIDDLE() class PropertyDecl : public ContainerDecl { - SLANG_AST_CLASS(PropertyDecl) - - TypeExp type; + FIDDLE(...) + FIDDLE() TypeExp type; }; // An "accessor" for a subscript or property +FIDDLE(abstract) class AccessorDecl : public FunctionDeclBase { - SLANG_AST_CLASS(AccessorDecl) + FIDDLE(...) }; +FIDDLE() class GetterDecl : public AccessorDecl { - SLANG_AST_CLASS(GetterDecl) + FIDDLE(...) }; + +FIDDLE() class SetterDecl : public AccessorDecl { - SLANG_AST_CLASS(SetterDecl) + FIDDLE(...) }; + +FIDDLE() class RefAccessorDecl : public AccessorDecl { - SLANG_AST_CLASS(RefAccessorDecl) + FIDDLE(...) }; + +FIDDLE() class FuncDecl : public FunctionDeclBase { - SLANG_AST_CLASS(FuncDecl) + FIDDLE(...) }; +FIDDLE(abstract) class NamespaceDeclBase : public ContainerDecl { - SLANG_AST_CLASS(NamespaceDeclBase) + FIDDLE(...) }; // A `namespace` declaration inside some module, that provides @@ -444,16 +473,19 @@ class NamespaceDeclBase : public ContainerDecl // `NamespaceDecl` during parsing, so this declaration does // not directly represent what is present in the input syntax. // +FIDDLE() class NamespaceDecl : public NamespaceDeclBase { - SLANG_AST_CLASS(NamespaceDecl) + FIDDLE(...) }; // A "module" of code (essentially, a single translation unit) // that provides a scope for some number of declarations. +FIDDLE() class ModuleDecl : public NamespaceDeclBase { - SLANG_AST_CLASS(ModuleDecl) + FIDDLE(...) + // The API-level module that this declaration belong to. // // This field allows lookup of the `Module` based on a @@ -467,7 +499,7 @@ class ModuleDecl : public NamespaceDeclBase /// This mapping is filled in during semantic checking, as the decl declarations get checked or /// generated. /// - OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> mapDeclToAssociatedDecls; + FIDDLE() OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> mapDeclToAssociatedDecls; /// Whether the module is defined in legacy language. /// The legacy Slang language does not have visibility modifiers and everything is treated as @@ -477,9 +509,9 @@ class ModuleDecl : public NamespaceDeclBase /// visibility modifiers, or if the module uses new language constructs, e.g. `module`, /// `__include`, /// `__implementing` etc. - bool isInLegacyLanguage = true; + FIDDLE() bool isInLegacyLanguage = true; - DeclVisibility defaultVisibility = DeclVisibility::Internal; + FIDDLE() DeclVisibility defaultVisibility = DeclVisibility::Internal; SLANG_UNREFLECTED @@ -487,19 +519,21 @@ class ModuleDecl : public NamespaceDeclBase /// /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. /// - Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions; + FIDDLE() Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions; }; // Represents a transparent scope of declarations that are defined in a single source file. +FIDDLE() class FileDecl : public ContainerDecl { - SLANG_AST_CLASS(FileDecl); + FIDDLE(...) }; /// A declaration that brings members of another declaration or namespace into scope +FIDDLE() class UsingDecl : public Decl { - SLANG_AST_CLASS(UsingDecl) + FIDDLE(...) /// An expression that identifies the entity (e.g., a namespace) to be brought into `scope` Expr* arg = nullptr; @@ -509,9 +543,10 @@ class UsingDecl : public Decl Scope* scope = nullptr; }; +FIDDLE() class FileReferenceDeclBase : public Decl { - SLANG_AST_CLASS(FileReferenceDeclBase) + FIDDLE(...) // The name of the module we are trying to import NameLoc moduleNameAndLoc; @@ -524,107 +559,115 @@ class FileReferenceDeclBase : public Decl Scope* scope = nullptr; }; +FIDDLE() class ImportDecl : public FileReferenceDeclBase { - SLANG_AST_CLASS(ImportDecl) + FIDDLE(...) // The module that actually got imported - ModuleDecl* importedModuleDecl = nullptr; + FIDDLE() ModuleDecl* importedModuleDecl = nullptr; }; +FIDDLE(abstract) class IncludeDeclBase : public FileReferenceDeclBase { - SLANG_AST_CLASS(IncludeDeclBase) - + FIDDLE(...) FileDecl* fileDecl = nullptr; }; +FIDDLE() class IncludeDecl : public IncludeDeclBase { - SLANG_AST_CLASS(IncludeDecl) + FIDDLE(...) }; +FIDDLE() class ImplementingDecl : public IncludeDeclBase { - SLANG_AST_CLASS(ImplementingDecl) + FIDDLE(...) }; +FIDDLE() class ModuleDeclarationDecl : public Decl { - SLANG_AST_CLASS(ModuleDeclarationDecl) + FIDDLE(...) }; +FIDDLE() class RequireCapabilityDecl : public Decl { - SLANG_AST_CLASS(RequireCapabilityDecl) + FIDDLE(...) }; // A generic declaration, parameterized on types/values +FIDDLE() class GenericDecl : public ContainerDecl { - SLANG_AST_CLASS(GenericDecl) + FIDDLE(...) // The decl that is genericized... - Decl* inner = nullptr; + FIDDLE() Decl* inner = nullptr; }; +FIDDLE(abstract) class GenericTypeParamDeclBase : public SimpleTypeDecl { - SLANG_AST_CLASS(GenericTypeParamDeclBase) - + FIDDLE(...) // The index of the generic parameter. int parameterIndex = -1; }; +FIDDLE() class GenericTypeParamDecl : public GenericTypeParamDeclBase { - SLANG_AST_CLASS(GenericTypeParamDecl) + FIDDLE(...) // The bound for the type parameter represents a trait that any // type used as this parameter must conform to // TypeExp bound; // The "initializer" for the parameter represents a default value - TypeExp initType; + FIDDLE() TypeExp initType; }; +FIDDLE() class GenericTypePackParamDecl : public GenericTypeParamDeclBase { - SLANG_AST_CLASS(GenericTypePackParamDecl) + FIDDLE(...) }; // A constraint placed as part of a generic declaration +FIDDLE() class GenericTypeConstraintDecl : public TypeConstraintDecl { - SLANG_AST_CLASS(GenericTypeConstraintDecl) - + FIDDLE(...) // A type constraint like `T : U` is constraining `T` to be "below" `U` // on a lattice of types. This may not be a subtyping relationship // per se, but it makes sense to use that terminology here, so we // think of these fields as the sub-type and super-type, respectively. - TypeExp sub; - TypeExp sup; + FIDDLE() TypeExp sub; + FIDDLE() TypeExp sup; // If this decl is defined in a where clause, store the source location of the where token. SourceLoc whereTokenLoc = SourceLoc(); - bool isEqualityConstraint = false; + FIDDLE() bool isEqualityConstraint = false; // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return sup; } }; +FIDDLE() class TypeCoercionConstraintDecl : public Decl { - SLANG_AST_CLASS(TypeCoercionConstraintDecl) - + FIDDLE(...) SourceLoc whereTokenLoc = SourceLoc(); - TypeExp fromType; - TypeExp toType; + FIDDLE() TypeExp fromType; + FIDDLE() TypeExp toType; }; +FIDDLE() class GenericValueParamDecl : public VarDeclBase { - SLANG_AST_CLASS(GenericValueParamDecl) - + FIDDLE(...) // The index of the generic parameter. int parameterIndex = 0; }; @@ -638,20 +681,21 @@ class GenericValueParamDecl : public VarDeclBase // // layout(local_size_x = 16) in; // +FIDDLE() class EmptyDecl : public Decl { - SLANG_AST_CLASS(EmptyDecl) + FIDDLE(...) }; // A declaration used by the implementation to put syntax keywords // into the current scope. // +FIDDLE() class SyntaxDecl : public Decl { - SLANG_AST_CLASS(SyntaxDecl) - + FIDDLE(...) // What type of syntax node will be produced when parsing with this keyword? - SyntaxClass<NodeBase> syntaxClass; + FIDDLE() SyntaxClass<NodeBase> syntaxClass; SLANG_UNREFLECTED @@ -662,44 +706,48 @@ class SyntaxDecl : public Decl // A declaration of an attribute to be used with `[name(...)]` syntax. // +FIDDLE() class AttributeDecl : public ContainerDecl { - SLANG_AST_CLASS(AttributeDecl) + FIDDLE(...) // What type of syntax node will be produced to represent this attribute. - SyntaxClass<NodeBase> syntaxClass; + FIDDLE() SyntaxClass<NodeBase> syntaxClass; }; // A synthesized decl used as a placeholder for a differentiable function requirement. This decl // will be a child of interface decl. This allows us to form an interface requirement key for the // derivative of an interface function. The synthesized `DerivativeRequirementDecl` will be a child // of the original function requirement decl after an interface type is checked. +FIDDLE() class DerivativeRequirementDecl : public FunctionDeclBase { - SLANG_AST_CLASS(DerivativeRequirementDecl) - + FIDDLE(...) // The original requirement decl. - Decl* originalRequirementDecl = nullptr; + FIDDLE() Decl* originalRequirementDecl = nullptr; // Type to use for 'ThisType' - Type* diffThisType; + FIDDLE() Type* diffThisType; }; // A reference to a synthesized decl representing a differentiable function requirement, this decl // will be a child in the orignal function. +FIDDLE() class DerivativeRequirementReferenceDecl : public FunctionDeclBase { - SLANG_AST_CLASS(DerivativeRequirementReferenceDecl) - DerivativeRequirementDecl* referencedDecl; + FIDDLE(...) + FIDDLE() DerivativeRequirementDecl* referencedDecl; }; +FIDDLE() class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl { - SLANG_AST_CLASS(ForwardDerivativeRequirementDecl) + FIDDLE(...) }; +FIDDLE() class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl { - SLANG_AST_CLASS(BackwardDerivativeRequirementDecl) + FIDDLE(...) }; bool isInterfaceRequirement(Decl* decl); diff --git a/source/slang/slang-ast-dispatch.h b/source/slang/slang-ast-dispatch.h new file mode 100644 index 000000000..58c67a974 --- /dev/null +++ b/source/slang/slang-ast-dispatch.h @@ -0,0 +1,56 @@ +// slang-ast-dispatch.h +#pragma once + +#include "slang-ast-forward-declarations.h" +#include "slang-syntax.h" + +namespace Slang +{ + +template<typename Base, typename Result> +struct ASTNodeDispatcher +{ +}; + +#if 0 // FIDDLE TEMPLATE: +%function generateDispatcher(BASE) +template<typename R> +struct ASTNodeDispatcher<$BASE, R> +{ + template<typename F> + static R dispatch($BASE const* obj, F const& f) + { + switch (obj->getClass().getTag()) + { + default: + SLANG_UNEXPECTED("unhandled subclass in ASTNodeDispatcher::dispatch"); + +% for _,T in ipairs(BASE.subclasses) do +% if not T.isAbstract then + case ASTNodeType::$T: + return f(static_cast<$T*>(const_cast<$BASE*>(obj))); +% end +% end + } + } +}; +%end +%generateDispatcher(Slang.TypeConstraintDecl) +%generateDispatcher(Slang.ArithmeticExpressionType) +%generateDispatcher(Slang.DeclRefBase) +%generateDispatcher(Slang.Val) +%generateDispatcher(Slang.Type) +%generateDispatcher(Slang.SubtypeWitness) +%generateDispatcher(Slang.IntVal) +%generateDispatcher(Slang.Modifier) +%generateDispatcher(Slang.DeclBase) +%generateDispatcher(Slang.Decl) +%generateDispatcher(Slang.Expr) +%generateDispatcher(Slang.Stmt) +%generateDispatcher(Slang.NodeBase) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-dispatch.h.fiddle" +#endif // FIDDLE END + +} // namespace Slang diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index bd366be19..24b10344d 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -2,8 +2,8 @@ #include "slang-ast-dump.h" #include "../core/slang-string.h" +#include "slang-ast-dispatch.h" #include "slang-compiler.h" -#include "slang-generated-ast-macro.h" #include <assert.h> #include <limits> @@ -11,12 +11,10 @@ namespace Slang { - struct ASTDumpContext { struct ObjectInfo { - const ReflectClassInfo* m_typeInfo; NodeBase* m_object; bool m_isDumped; }; @@ -48,10 +46,10 @@ struct ASTDumpContext ASTDumpContext* m_context; }; - void dumpObject(const ReflectClassInfo& type, NodeBase* obj); + void dumpObject(NodeBase* obj); - void dumpObjectFull(const ReflectClassInfo& type, NodeBase* obj, Index objIndex); - void dumpObjectReference(const ReflectClassInfo& type, NodeBase* obj, Index objIndex); + void dumpObjectFull(NodeBase* obj, Index objIndex); + void dumpObjectReference(NodeBase* obj, Index objIndex); void dump(NodeBase* node) { @@ -61,7 +59,7 @@ struct ASTDumpContext } else { - dumpObject(node->getClassInfo(), node); + dumpObject(node); } } @@ -283,7 +281,7 @@ struct ASTDumpContext m_writer->emit(" }"); } - Index getObjectIndex(const ReflectClassInfo& typeInfo, NodeBase* obj) + Index getObjectIndex(NodeBase* obj) { Index* indexPtr = m_objectMap.tryGetValueOrAdd(obj, m_objects.getCount()); if (indexPtr) @@ -294,7 +292,6 @@ struct ASTDumpContext ObjectInfo info; info.m_isDumped = false; info.m_object = obj; - info.m_typeInfo = &typeInfo; m_objects.add(info); return m_objects.getCount() - 1; @@ -366,7 +363,7 @@ struct ASTDumpContext template<typename T> void dump(const SyntaxClass<T>& cls) { - m_writer->emit(cls.classInfo->m_name); + m_writer->emit(cls.getName()); } template<typename KEY, typename VALUE> @@ -568,7 +565,7 @@ struct ASTDumpContext ObjectInfo& info = m_objects[i]; if (!info.m_isDumped) { - dumpObjectFull(*info.m_typeInfo, info.m_object, i); + dumpObjectFull(info.m_object, i); } } } @@ -580,13 +577,12 @@ struct ASTDumpContext // Lets special case handling of module decls -> we only want to output as references // otherwise we end up dumping everything in every module. - const ReflectClassInfo& typeInfo = moduleDecl->getClassInfo(); - Index index = getObjectIndex(typeInfo, moduleDecl); + Index index = getObjectIndex(moduleDecl); // We don't want to fully dump, referenced modules as doing so dumps everything m_objects[index].m_isDumped = true; - dumpObjectReference(typeInfo, moduleDecl, index); + dumpObjectReference(moduleDecl, index); } else { @@ -640,9 +636,9 @@ struct ASTDumpContext void dump(ASTNodeType nodeType) { // Get the class - auto info = ASTClassInfo::getInfo(nodeType); + auto syntaxClass = SyntaxClass<NodeBase>(nodeType); // Write the name - m_writer->emit(info->m_name); + m_writer->emit(syntaxClass.getName()); } void dump(SourceLanguage language) { m_writer->emit((int)language); } @@ -775,35 +771,36 @@ struct ASTDumpContext struct ASTDumpAccess { +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + static void dump_($T * node, ASTDumpContext & context) + { +% if T.directSuperClass then + dump_(static_cast<$(T.directSuperClass)*>(node), context); +% end +% for _,f in ipairs(T.directFields) do + context.dumpField("$f", node->$f); +% end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-dump.cpp.fiddle" +#endif // FIDDLE END -#define SLANG_AST_DUMP_FIELD(FIELD_NAME, TYPE, param) \ - context.dumpField(#FIELD_NAME, static_cast<param*>(base)->FIELD_NAME); - -#define SLANG_AST_DUMP_FIELDS_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - case ASTNodeType::NAME: \ - { \ - SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_DUMP_FIELD, NAME) break; \ - } - - static void dump(ASTNodeType type, NodeBase* base, ASTDumpContext& context) + static void dump(NodeBase* base, ASTDumpContext& context) { - switch (type) - { - SLANG_ALL_ASTNode_NodeBase(SLANG_AST_DUMP_FIELDS_IMPL, _) default : break; - } + ASTNodeDispatcher<NodeBase, void>::dispatch(base, [&](auto b) { dump_(b, context); }); } }; -void ASTDumpContext::dumpObjectReference( - const ReflectClassInfo& type, - NodeBase* obj, - Index objIndex) +void ASTDumpContext::dumpObjectReference(NodeBase* obj, Index objIndex) { SLANG_UNUSED(obj); - ScopeWrite(this).getBuf() << type.m_name << ":" << objIndex; + ScopeWrite(this).getBuf() << obj->getClass().getName() << ":" << objIndex; } -void ASTDumpContext::dumpObjectFull(const ReflectClassInfo& type, NodeBase* obj, Index objIndex) +void ASTDumpContext::dumpObjectFull(NodeBase* obj, Index objIndex) { ObjectInfo& info = m_objects[objIndex]; SLANG_ASSERT(info.m_isDumped == false); @@ -811,42 +808,27 @@ void ASTDumpContext::dumpObjectFull(const ReflectClassInfo& type, NodeBase* obj, // We need to dump the fields. - ScopeWrite(this).getBuf() << type.m_name << ":" << objIndex << " {\n"; + ScopeWrite(this).getBuf() << obj->getClass().getName() << ":" << objIndex << " {\n"; m_writer->indent(); - List<const ReflectClassInfo*> allTypes; - { - const ReflectClassInfo* curType = &type; - do - { - allTypes.add(curType); - curType = curType->m_superClass; - } while (curType); - } - - // Okay we go backwards so we output in the 'normal' order - for (Index i = allTypes.getCount() - 1; i >= 0; --i) - { - const ReflectClassInfo* curType = allTypes[i]; - ASTDumpAccess::dump(ASTNodeType(curType->m_classId), obj, *this); - } + ASTDumpAccess::dump(obj, *this); m_writer->dedent(); m_writer->emit("}\n"); } -void ASTDumpContext::dumpObject(const ReflectClassInfo& typeInfo, NodeBase* obj) +void ASTDumpContext::dumpObject(NodeBase* obj) { - Index index = getObjectIndex(typeInfo, obj); + Index index = getObjectIndex(obj); ObjectInfo& info = m_objects[index]; if (info.m_isDumped || m_dumpStyle == ASTDumpUtil::Style::Flat) { - dumpObjectReference(typeInfo, obj, index); + dumpObjectReference(obj, index); } else { - dumpObjectFull(typeInfo, obj, index); + dumpObjectFull(obj, index); } } @@ -858,9 +840,8 @@ void ASTDumpContext::dumpObjectFull(NodeBase* node) } else { - const ReflectClassInfo& typeInfo = node->getClassInfo(); - Index index = getObjectIndex(typeInfo, node); - dumpObjectFull(typeInfo, node, index); + Index index = getObjectIndex(node); + dumpObjectFull(node, index); } } diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index c9bc86b79..cd5f9b6e8 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -1,9 +1,10 @@ // slang-ast-expr.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-expr.h.fiddle" +FIDDLE() namespace Slang { @@ -12,19 +13,20 @@ using SpvWord = uint32_t; // Syntax class definitions for expressions. // // A placeholder for where an Expr is expected but is missing from source. +FIDDLE() class IncompleteExpr : public Expr { - SLANG_AST_CLASS(IncompleteExpr) + FIDDLE(...) }; + // Base class for expressions that will reference declarations +FIDDLE(abstract) class DeclRefExpr : public Expr { - SLANG_ABSTRACT_AST_CLASS(DeclRefExpr) - - + FIDDLE(...) // The declaration of the symbol being referenced - DeclRef<Decl> declRef; + FIDDLE() DeclRef<Decl> declRef; // The name of the symbol being referenced Name* name = nullptr; @@ -36,22 +38,24 @@ class DeclRefExpr : public Expr Scope* scope = nullptr; }; +FIDDLE() class VarExpr : public DeclRefExpr { - SLANG_AST_CLASS(VarExpr) + FIDDLE(...) }; +FIDDLE() class DefaultConstructExpr : public Expr { - SLANG_AST_CLASS(DefaultConstructExpr) + FIDDLE(...) }; // An expression that references an overloaded set of declarations // having the same name. +FIDDLE() class OverloadedExpr : public Expr { - SLANG_AST_CLASS(OverloadedExpr) - + FIDDLE(...) // The name that was looked up and found to be overloaded Name* name = nullptr; @@ -67,10 +71,10 @@ class OverloadedExpr : public Expr // An expression that references an overloaded set of declarations // having the same name. +FIDDLE() class OverloadedExpr2 : public Expr { - SLANG_AST_CLASS(OverloadedExpr2) - + FIDDLE(...) // Optional: the base expression is this overloaded result // arose from a member-reference expression. Expr* base = nullptr; @@ -79,108 +83,117 @@ class OverloadedExpr2 : public Expr List<Expr*> candidiateExprs; }; +FIDDLE(abstract) class LiteralExpr : public Expr { - SLANG_ABSTRACT_AST_CLASS(LiteralExpr) + FIDDLE(...) // The token that was used to express the literal. This can be // used to get the raw text of the literal, including any suffix. Token token; - BaseType suffixType = BaseType::Void; + FIDDLE() BaseType suffixType = BaseType::Void; }; +FIDDLE() class IntegerLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(IntegerLiteralExpr) - - IntegerLiteralValue value; + FIDDLE(...) + FIDDLE() IntegerLiteralValue value; }; +FIDDLE() class FloatingPointLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(FloatingPointLiteralExpr) - FloatingPointLiteralValue value; + FIDDLE(...) + FIDDLE() FloatingPointLiteralValue value; }; +FIDDLE() class BoolLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(BoolLiteralExpr) - bool value; + FIDDLE(...) + FIDDLE() bool value; }; +FIDDLE() class NullPtrLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(NullPtrLiteralExpr) + FIDDLE(...) }; +FIDDLE() class NoneLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(NoneLiteralExpr) + FIDDLE(...) }; +FIDDLE() class StringLiteralExpr : public LiteralExpr { - SLANG_AST_CLASS(StringLiteralExpr) - + FIDDLE(...) // TODO: consider storing the "segments" of the string // literal, in the case where multiple literals were // lined up at the lexer level, e.g.: // // "first" "second" "third" // - String value; + FIDDLE() String value; }; // An initializer list, e.g. `{ 1, 2, 3 }` +FIDDLE() class InitializerListExpr : public Expr { - SLANG_AST_CLASS(InitializerListExpr) - List<Expr*> args; + FIDDLE(...) + FIDDLE() List<Expr*> args; bool useCStyleInitialization = true; }; +FIDDLE() class GetArrayLengthExpr : public Expr { - SLANG_AST_CLASS(GetArrayLengthExpr) - Expr* arrayExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* arrayExpr = nullptr; }; +FIDDLE() class ExpandExpr : public Expr { - SLANG_AST_CLASS(ExpandExpr) - Expr* baseExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* baseExpr = nullptr; }; +FIDDLE() class EachExpr : public Expr { - SLANG_AST_CLASS(EachExpr) - Expr* baseExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* baseExpr = nullptr; }; // A base class for expressions with arguments +FIDDLE(abstract) class ExprWithArgsBase : public Expr { - SLANG_ABSTRACT_AST_CLASS(ExprWithArgsBase) - - List<Expr*> arguments; + FIDDLE(...) + FIDDLE() List<Expr*> arguments; }; // An aggregate type constructor +FIDDLE() class AggTypeCtorExpr : public ExprWithArgsBase { - SLANG_AST_CLASS(AggTypeCtorExpr) - - TypeExp base; + FIDDLE(...) + FIDDLE() TypeExp base; }; // A base expression being applied to arguments: covers // both ordinary `()` function calls and `<>` generic application +FIDDLE(abstract) class AppExprBase : public ExprWithArgsBase { - SLANG_ABSTRACT_AST_CLASS(AppExprBase) - - Expr* functionExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* functionExpr = nullptr; // The original function expr before overload resolution. Expr* originalFunctionExpr = nullptr; @@ -190,14 +203,16 @@ class AppExprBase : public ExprWithArgsBase List<SourceLoc> argumentDelimeterLocs; }; +FIDDLE() class InvokeExpr : public AppExprBase { - SLANG_AST_CLASS(InvokeExpr) + FIDDLE(...) }; +FIDDLE() class ExplicitCtorInvokeExpr : public InvokeExpr { - SLANG_AST_CLASS(ExplicitCtorInvokeExpr) + FIDDLE(...) }; enum class TryClauseType @@ -211,69 +226,80 @@ enum class TryClauseType char const* getTryClauseTypeName(TryClauseType value); +FIDDLE() class TryExpr : public Expr { - SLANG_AST_CLASS(TryExpr) - - Expr* base; + FIDDLE(...) + FIDDLE() Expr* base; - TryClauseType tryClauseType = TryClauseType::Standard; + FIDDLE() TryClauseType tryClauseType = TryClauseType::Standard; // The scope of this expr. Scope* scope = nullptr; }; +FIDDLE() class NewExpr : public InvokeExpr { - SLANG_AST_CLASS(NewExpr) + FIDDLE(...) }; +FIDDLE() class OperatorExpr : public InvokeExpr { - SLANG_AST_CLASS(OperatorExpr) + FIDDLE(...) }; +FIDDLE() class InfixExpr : public OperatorExpr { - SLANG_AST_CLASS(InfixExpr) + FIDDLE(...) }; + +FIDDLE() class PrefixExpr : public OperatorExpr { - SLANG_AST_CLASS(PrefixExpr) + FIDDLE(...) }; + +FIDDLE() class PostfixExpr : public OperatorExpr { - SLANG_AST_CLASS(PostfixExpr) + FIDDLE(...) }; +FIDDLE() class IndexExpr : public Expr { - SLANG_AST_CLASS(IndexExpr) - Expr* baseExpression; - List<Expr*> indexExprs; + FIDDLE(...) + FIDDLE() Expr* baseExpression; + FIDDLE() List<Expr*> indexExprs; // The source location of `(`, `)`, and `,` that marks the start/end of the application op and // each argument expr. This info is used by language server. List<SourceLoc> argumentDelimeterLocs; }; +FIDDLE() class MemberExpr : public DeclRefExpr { - SLANG_AST_CLASS(MemberExpr) - Expr* baseExpression = nullptr; + FIDDLE(...) + FIDDLE() Expr* baseExpression = nullptr; SourceLoc memberOperatorLoc; }; // Member expression that is dereferenced, e.g. `a->b`. +FIDDLE() class DerefMemberExpr : public MemberExpr { - SLANG_AST_CLASS(DerefMemberExpr) + FIDDLE(...) }; // Member looked up on a type, rather than a value +FIDDLE() class StaticMemberExpr : public DeclRefExpr { - SLANG_AST_CLASS(StaticMemberExpr) + FIDDLE(...) Expr* baseExpression = nullptr; SourceLoc memberOperatorLoc; }; @@ -287,69 +313,77 @@ struct MatrixCoord int col; }; +FIDDLE() class MatrixSwizzleExpr : public Expr { - SLANG_AST_CLASS(MatrixSwizzleExpr) - Expr* base = nullptr; - int elementCount; - MatrixCoord elementCoords[4]; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; + FIDDLE() int elementCount; + FIDDLE() MatrixCoord elementCoords[4]; SourceLoc memberOpLoc; }; +FIDDLE() class SwizzleExpr : public Expr { - SLANG_AST_CLASS(SwizzleExpr) - Expr* base = nullptr; - ShortList<uint32_t, 4> elementIndices; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; + FIDDLE() ShortList<uint32_t, 4> elementIndices; SourceLoc memberOpLoc; }; // An operation to convert an l-value to a reference type. +FIDDLE() class MakeRefExpr : public Expr { - SLANG_AST_CLASS(MakeRefExpr) - Expr* base = nullptr; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; }; // A dereference of a pointer or pointer-like type +FIDDLE() class DerefExpr : public Expr { - SLANG_AST_CLASS(DerefExpr) - Expr* base = nullptr; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; }; // Any operation that performs type-casting +FIDDLE() class TypeCastExpr : public InvokeExpr { - SLANG_AST_CLASS(TypeCastExpr) + FIDDLE(...) // TypeExp TargetType; // Expr* Expression = nullptr; }; // An explicit type-cast that appear in the user's code with `(type) expr` syntax +FIDDLE() class ExplicitCastExpr : public TypeCastExpr { - SLANG_AST_CLASS(ExplicitCastExpr) + FIDDLE(...) }; // An implicit type-cast inserted during semantic checking +FIDDLE() class ImplicitCastExpr : public TypeCastExpr { - SLANG_AST_CLASS(ImplicitCastExpr) + FIDDLE(...) }; // A builtin cast expr generated during semantic checking, where there is // no associated conversion function decl. +FIDDLE() class BuiltinCastExpr : public Expr { - SLANG_AST_CLASS(BuiltinCastExpr); - Expr* base = nullptr; + FIDDLE(...) + FIDDLE() Expr* base = nullptr; }; +FIDDLE() class LValueImplicitCastExpr : public TypeCastExpr { - SLANG_AST_CLASS(LValueImplicitCastExpr) - + FIDDLE(...) explicit LValueImplicitCastExpr(const TypeCastExpr& rhs) : Super(rhs) { @@ -359,10 +393,10 @@ class LValueImplicitCastExpr : public TypeCastExpr // To work around situations like int += uint // where we want to allow an LValue to work with an implicit cast. // The argument being cast *must* be an LValue. +FIDDLE() class OutImplicitCastExpr : public LValueImplicitCastExpr { - SLANG_AST_CLASS(OutImplicitCastExpr) - + FIDDLE(...) /// Allow explict construction from any TypeCastExpr explicit OutImplicitCastExpr(const TypeCastExpr& rhs) : Super(rhs) @@ -370,10 +404,10 @@ class OutImplicitCastExpr : public LValueImplicitCastExpr } }; +FIDDLE() class InOutImplicitCastExpr : public LValueImplicitCastExpr { - SLANG_AST_CLASS(InOutImplicitCastExpr) - + FIDDLE(...) /// Allow explict construction from any TypeCastExpr explicit InOutImplicitCastExpr(const TypeCastExpr& rhs) : Super(rhs) @@ -385,249 +419,266 @@ class InOutImplicitCastExpr : public LValueImplicitCastExpr /// /// The type being cast to is stored as this expression's `type`. /// +FIDDLE() class CastToSuperTypeExpr : public Expr { - SLANG_AST_CLASS(CastToSuperTypeExpr) - + FIDDLE(...) /// The value being cast to a super type /// /// The type being cast from is `valueArg->type`. /// - Expr* valueArg = nullptr; + FIDDLE() Expr* valueArg = nullptr; /// A witness showing that `valueArg`'s type is a sub-type of this expression's `type` - Val* witnessArg = nullptr; + FIDDLE() Val* witnessArg = nullptr; }; /// A `value is Type` expression that evaluates to `true` if type of `value` is a sub-type of /// `Type`. +FIDDLE() class IsTypeExpr : public Expr { - SLANG_AST_CLASS(IsTypeExpr) - - Expr* value = nullptr; - TypeExp typeExpr; + FIDDLE(...) + FIDDLE() Expr* value = nullptr; + FIDDLE() TypeExp typeExpr; // A witness showing that `typeExpr.type` is a subtype of `typeof(value)`. - Val* witnessArg = nullptr; + FIDDLE() Val* witnessArg = nullptr; // non-null if evaluates to a constant. - BoolLiteralExpr* constantVal = nullptr; + FIDDLE() BoolLiteralExpr* constantVal = nullptr; }; /// A `value as Type` expression that casts `value` to `Type` within type hierarchy. /// The result is undefined if `value` is not `Type`. +FIDDLE() class AsTypeExpr : public Expr { - SLANG_AST_CLASS(AsTypeExpr) - - Expr* value = nullptr; - Expr* typeExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* value = nullptr; + FIDDLE() Expr* typeExpr = nullptr; // A witness showing that `typeExpr` is a subtype of `typeof(value)`. - Val* witnessArg = nullptr; + FIDDLE() Val* witnessArg = nullptr; }; +FIDDLE(abstract) class SizeOfLikeExpr : public Expr { - SLANG_AST_CLASS(SizeOfLikeExpr); - + FIDDLE(...) // Set during the parse, could be an expression, a variable or a type - Expr* value = nullptr; + FIDDLE() Expr* value = nullptr; // The type the size/alignment needs to operate on. Set during traversal of SemanticsExprVisitor - Type* sizedType = nullptr; + FIDDLE() Type* sizedType = nullptr; }; +FIDDLE() class SizeOfExpr : public SizeOfLikeExpr { - SLANG_AST_CLASS(SizeOfExpr); + FIDDLE(...) }; +FIDDLE() class AlignOfExpr : public SizeOfLikeExpr { - SLANG_AST_CLASS(AlignOfExpr); + FIDDLE(...) }; +FIDDLE() class CountOfExpr : public SizeOfLikeExpr { - SLANG_AST_CLASS(CountOfExpr); + FIDDLE(...) }; +FIDDLE() class MakeOptionalExpr : public Expr { - SLANG_AST_CLASS(MakeOptionalExpr) - + FIDDLE(...) // If `value` is null, this constructs an `Optional<T>` that doesn't have a value. - Expr* value = nullptr; - Expr* typeExpr = nullptr; + FIDDLE() Expr* value = nullptr; + FIDDLE() Expr* typeExpr = nullptr; }; /// A cast of a value to the same type, with different modifiers. /// /// The type being cast to is stored as this expression's `type`. /// +FIDDLE() class ModifierCastExpr : public Expr { - SLANG_AST_CLASS(ModifierCastExpr) - + FIDDLE(...) /// The value being cast. /// /// The type being cast from is `valueArg->type`. /// - Expr* valueArg = nullptr; + FIDDLE() Expr* valueArg = nullptr; }; +FIDDLE() class SelectExpr : public OperatorExpr { - SLANG_AST_CLASS(SelectExpr) + FIDDLE(...) }; +FIDDLE() class LogicOperatorShortCircuitExpr : public OperatorExpr { - SLANG_AST_CLASS(LogicOperatorShortCircuitExpr) + FIDDLE(...) public: enum Flavor { And, // && Or, // || }; - Flavor flavor; + FIDDLE() Flavor flavor; }; +FIDDLE() class GenericAppExpr : public AppExprBase { - SLANG_AST_CLASS(GenericAppExpr) + FIDDLE(...) }; // An expression representing re-use of the syntax for a type in more // than once conceptually-distinct declaration +FIDDLE() class SharedTypeExpr : public Expr { - SLANG_AST_CLASS(SharedTypeExpr) + FIDDLE(...) // The underlying type expression that we want to share TypeExp base; }; +FIDDLE() class AssignExpr : public Expr { - SLANG_AST_CLASS(AssignExpr) - Expr* left = nullptr; - Expr* right = nullptr; + FIDDLE(...) + FIDDLE() Expr* left = nullptr; + FIDDLE() Expr* right = nullptr; }; // Just an expression inside parentheses `(exp)` // // We keep this around explicitly to be sure we don't lose any structure // when we do rewriter stuff. +FIDDLE() class ParenExpr : public Expr { - SLANG_AST_CLASS(ParenExpr) + FIDDLE(...) Expr* base = nullptr; }; // An object-oriented `this` expression, used to // refer to the current instance of an enclosing type. +FIDDLE() class ThisExpr : public Expr { - SLANG_AST_CLASS(ThisExpr) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; // Represent a reference to the virtual __return_val object holding the return value of // functions whose result type is non-copyable. +FIDDLE() class ReturnValExpr : public Expr { - SLANG_AST_CLASS(ReturnValExpr) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; // An expression that binds a temporary variable in a local expression context +FIDDLE() class LetExpr : public Expr { - SLANG_AST_CLASS(LetExpr) - VarDecl* decl = nullptr; - Expr* body = nullptr; + FIDDLE(...) + FIDDLE() VarDecl* decl = nullptr; + FIDDLE() Expr* body = nullptr; }; +FIDDLE() class ExtractExistentialValueExpr : public Expr { - SLANG_AST_CLASS(ExtractExistentialValueExpr) - DeclRef<VarDeclBase> declRef; + FIDDLE(...) + FIDDLE() DeclRef<VarDeclBase> declRef; Expr* originalExpr; }; +FIDDLE() class OpenRefExpr : public Expr { - SLANG_AST_CLASS(OpenRefExpr) - - Expr* innerExpr = nullptr; + FIDDLE(...) + FIDDLE() Expr* innerExpr = nullptr; }; +FIDDLE() class DetachExpr : public Expr { - SLANG_AST_CLASS(DetachExpr) - - Expr* inner = nullptr; + FIDDLE(...) + FIDDLE() Expr* inner = nullptr; }; /// Base class for higher-order function application /// Eg: foo(fn) where fn is a function expression. /// +FIDDLE(abstract) class HigherOrderInvokeExpr : public Expr { - SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) - Expr* baseFunction; - List<Name*> newParameterNames; + FIDDLE(...) + FIDDLE() Expr* baseFunction; + FIDDLE() List<Name*> newParameterNames; }; +FIDDLE() class PrimalSubstituteExpr : public HigherOrderInvokeExpr { - SLANG_AST_CLASS(PrimalSubstituteExpr) + FIDDLE(...) }; +FIDDLE(abstract) class DifferentiateExpr : public HigherOrderInvokeExpr { - SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr) + FIDDLE(...) }; /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// +FIDDLE() class ForwardDifferentiateExpr : public DifferentiateExpr { - SLANG_AST_CLASS(ForwardDifferentiateExpr) + FIDDLE(...) }; /// An expression of the form `__bwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// +FIDDLE() class BackwardDifferentiateExpr : public DifferentiateExpr { - SLANG_AST_CLASS(BackwardDifferentiateExpr) + FIDDLE(...) }; /// An expression of the form `__dispatch_kernel(fn, threadGroupSize, dispatchSize)` to /// dispatch a compute kernel from host. /// +FIDDLE() class DispatchKernelExpr : public HigherOrderInvokeExpr { - SLANG_AST_CLASS(DispatchKernelExpr) - Expr* threadGroupSize; - Expr* dispatchSize; + FIDDLE(...) + FIDDLE() Expr* threadGroupSize; + FIDDLE() Expr* dispatchSize; }; /// An express to mark its inner expression as an intended non-differential call. +FIDDLE() class TreatAsDifferentiableExpr : public Expr { - SLANG_AST_CLASS(TreatAsDifferentiableExpr) - - Expr* innerExpr; + FIDDLE(...) + FIDDLE() Expr* innerExpr; Scope* scope; enum Flavor @@ -645,70 +696,70 @@ class TreatAsDifferentiableExpr : public Expr Differentiable }; - Flavor flavor; + FIDDLE() Flavor flavor; }; /// A type expression of the form `This` /// /// Refers to the type of `this` in the current context. /// +FIDDLE() class ThisTypeExpr : public Expr { - SLANG_AST_CLASS(ThisTypeExpr) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; /// A type expression of the form `Left & Right`. +FIDDLE() class AndTypeExpr : public Expr { - SLANG_AST_CLASS(AndTypeExpr); - - TypeExp left; - TypeExp right; + FIDDLE(...) + FIDDLE() TypeExp left; + FIDDLE() TypeExp right; }; /// A type exprssion that applies one or more modifiers to another type +FIDDLE() class ModifiedTypeExpr : public Expr { - SLANG_AST_CLASS(ModifiedTypeExpr); - - Modifiers modifiers; - TypeExp base; + FIDDLE(...) + FIDDLE() Modifiers modifiers; + FIDDLE() TypeExp base; }; /// A type expression that rrepresents a pointer type, e.g. T* +FIDDLE() class PointerTypeExpr : public Expr { - SLANG_AST_CLASS(PointerTypeExpr); - - TypeExp base; + FIDDLE(...) + FIDDLE() TypeExp base; }; /// A type expression that represents a function type, e.g. (bool, int) -> float +FIDDLE() class FuncTypeExpr : public Expr { - SLANG_AST_CLASS(FuncTypeExpr); - - List<TypeExp> parameters; - TypeExp result; + FIDDLE(...) + FIDDLE() List<TypeExp> parameters; + FIDDLE() TypeExp result; }; +FIDDLE() class TupleTypeExpr : public Expr { - SLANG_AST_CLASS(TupleTypeExpr); - - List<TypeExp> members; + FIDDLE(...) + FIDDLE() List<TypeExp> members; }; /// An expression that applies a generic to arguments for some, /// but not all, of its explicit parameters. /// +FIDDLE() class PartiallyAppliedGenericExpr : public Expr { - SLANG_AST_CLASS(PartiallyAppliedGenericExpr); - + FIDDLE(...) public: Expr* originalExpr = nullptr; @@ -723,16 +774,17 @@ public: /// An expression that holds a set of argument exprs that got matched to a pack parameter /// during overload resolution. /// +FIDDLE() class PackExpr : public Expr { - SLANG_AST_CLASS(PackExpr) - - List<Expr*> args; + FIDDLE(...) + FIDDLE() List<Expr*> args; }; -class SPIRVAsmOperand +FIDDLE() +struct SPIRVAsmOperand { - SLANG_VALUE_CLASS(SPIRVAsmOperand); + FIDDLE(...) public: enum Flavor @@ -792,21 +844,21 @@ public: TypeExp type = TypeExp(); }; -class SPIRVAsmInst +FIDDLE() +struct SPIRVAsmInst { - SLANG_VALUE_CLASS(SPIRVAsmInst); - + FIDDLE(...) public: SPIRVAsmOperand opcode; List<SPIRVAsmOperand> operands; }; +FIDDLE() class SPIRVAsmExpr : public Expr { - SLANG_AST_CLASS(SPIRVAsmExpr); - + FIDDLE(...) public: - List<SPIRVAsmInst> insts; + FIDDLE() List<SPIRVAsmInst> insts; }; } // namespace Slang diff --git a/source/slang/slang-ast-forward-declarations.h b/source/slang/slang-ast-forward-declarations.h new file mode 100644 index 000000000..717bca1d9 --- /dev/null +++ b/source/slang/slang-ast-forward-declarations.h @@ -0,0 +1,29 @@ +// slang-ast-forward-declarations.h +#pragma once + +namespace Slang +{ + +enum class ASTNodeType +{ +#if 0 // FIDDLE TEMPLATE: +%for _, T in ipairs(Slang.NodeBase.subclasses) do + $T, +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-ast-forward-declarations.h.fiddle" +#endif // FIDDLE END + CountOf +}; + +#if 0 // FIDDLE TEMPLATE: +%for _, T in ipairs(Slang.NodeBase.subclasses) do + class $T; +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-ast-forward-declarations.h.fiddle" +#endif // FIDDLE END + +} // namespace Slang diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index c7da945f2..2112d452e 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -38,9 +38,9 @@ struct ASTIterator { if (!expr) return; - expr->accept(this, nullptr); + this->dispatch(expr); } - bool visitExpr(Expr*) { return false; } + void visitExpr(Expr*) {} void visitBoolLiteralExpr(BoolLiteralExpr* expr) { iterator->maybeDispatchCallback(expr); } void visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr) { @@ -313,6 +313,8 @@ struct ASTIterator dispatchIfNotNull(o.expr); } } + + void visitDetachExpr(DetachExpr* expr) { iterator->maybeDispatchCallback(expr); } }; struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor> @@ -327,7 +329,7 @@ struct ASTIterator { if (!stmt) return; - stmt->accept(this, nullptr); + this->dispatch(stmt); } void visitDeclStmt(DeclStmt* stmt) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 563084361..e566eca9e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1,9 +1,10 @@ // slang-ast-modifier.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-modifier.h.fiddle" +FIDDLE() namespace Slang { @@ -11,246 +12,299 @@ namespace Slang // Simple modifiers have no state beyond their identity +FIDDLE() class InModifier : public Modifier { - SLANG_AST_CLASS(InModifier) + FIDDLE(...) }; + +FIDDLE() class OutModifier : public Modifier { - SLANG_AST_CLASS(OutModifier) + FIDDLE(...) }; + +FIDDLE() class ConstModifier : public Modifier { - SLANG_AST_CLASS(ConstModifier) + FIDDLE(...) }; + +FIDDLE() class BuiltinModifier : public Modifier { - SLANG_AST_CLASS(BuiltinModifier) + FIDDLE(...) }; + +FIDDLE() class InlineModifier : public Modifier { - SLANG_AST_CLASS(InlineModifier) + FIDDLE(...) }; + +FIDDLE(abstract) class VisibilityModifier : public Modifier { - SLANG_AST_CLASS(VisibilityModifier) + FIDDLE(...) }; + +FIDDLE() class PublicModifier : public VisibilityModifier { - SLANG_AST_CLASS(PublicModifier) + FIDDLE(...) }; + +FIDDLE() class PrivateModifier : public VisibilityModifier { - SLANG_AST_CLASS(PrivateModifier) + FIDDLE(...) }; + +FIDDLE() class InternalModifier : public VisibilityModifier { - SLANG_AST_CLASS(InternalModifier) + FIDDLE(...) }; + +FIDDLE() class RequireModifier : public Modifier { - SLANG_AST_CLASS(RequireModifier) + FIDDLE(...) }; + +FIDDLE() class ParamModifier : public Modifier { - SLANG_AST_CLASS(ParamModifier) + FIDDLE(...) }; + +FIDDLE() class ExternModifier : public Modifier { - SLANG_AST_CLASS(ExternModifier) + FIDDLE(...) }; + +FIDDLE() class HLSLExportModifier : public Modifier { - SLANG_AST_CLASS(HLSLExportModifier) + FIDDLE(...) }; + +FIDDLE() class TransparentModifier : public Modifier { - SLANG_AST_CLASS(TransparentModifier) + FIDDLE(...) }; + +FIDDLE() class FromCoreModuleModifier : public Modifier { - SLANG_AST_CLASS(FromCoreModuleModifier) + FIDDLE(...) }; + +FIDDLE() class PrefixModifier : public Modifier { - SLANG_AST_CLASS(PrefixModifier) + FIDDLE(...) }; + +FIDDLE() class PostfixModifier : public Modifier { - SLANG_AST_CLASS(PostfixModifier) + FIDDLE(...) }; + +FIDDLE() class ExportedModifier : public Modifier { - SLANG_AST_CLASS(ExportedModifier) + FIDDLE(...) }; + +FIDDLE() class ConstExprModifier : public Modifier { - SLANG_AST_CLASS(ConstExprModifier) + FIDDLE(...) }; + +FIDDLE() class ExternCppModifier : public Modifier { - SLANG_AST_CLASS(ExternCppModifier) + FIDDLE(...) }; + +FIDDLE() class GLSLPrecisionModifier : public Modifier { - SLANG_AST_CLASS(GLSLPrecisionModifier) + FIDDLE(...) }; + +FIDDLE() class GLSLModuleModifier : public Modifier { - SLANG_AST_CLASS(GLSLModuleModifier) + FIDDLE(...) }; + // Marks that the definition of a decl is not yet synthesized. +FIDDLE() class ToBeSynthesizedModifier : public Modifier { - SLANG_AST_CLASS(ToBeSynthesizedModifier) + FIDDLE(...) }; // Marks that the definition of a decl is synthesized. +FIDDLE() class SynthesizedModifier : public Modifier { - SLANG_AST_CLASS(SynthesizedModifier) + FIDDLE(...) }; // Marks a synthesized variable as local temporary variable. +FIDDLE() class LocalTempVarModifier : public Modifier { - SLANG_AST_CLASS(LocalTempVarModifier) + FIDDLE(...) }; // An `extern` variable in an extension is used to introduce additional attributes on an existing // field. +FIDDLE() class ExtensionExternVarModifier : public Modifier { - SLANG_AST_CLASS(ExtensionExternVarModifier) - DeclRef<Decl> originalDecl; + FIDDLE(...) + FIDDLE() DeclRef<Decl> originalDecl; }; // An 'ActualGlobal' is a global that is output as a normal global in CPU code. // Globals in HLSL/Slang are constant state passed into kernel execution +FIDDLE() class ActualGlobalModifier : public Modifier { - SLANG_AST_CLASS(ActualGlobalModifier) + FIDDLE(...) }; /// A modifier that indicates an `InheritanceDecl` should be ignored during name lookup (and related /// checks). +FIDDLE() class IgnoreForLookupModifier : public Modifier { - SLANG_AST_CLASS(IgnoreForLookupModifier) + FIDDLE(...) }; // A modifier that marks something as an operation that // has a one-to-one translation to the IR, and thus // has no direct definition in the high-level language. // +FIDDLE() class IntrinsicOpModifier : public Modifier { - SLANG_AST_CLASS(IntrinsicOpModifier) - + FIDDLE(...) // Token that names the intrinsic op. Token opToken; // The IR opcode for the intrinsic operation. // - uint32_t op = 0; + FIDDLE() uint32_t op = 0; }; // A modifier that marks something as an intrinsic function, // for some subset of targets. +FIDDLE() class TargetIntrinsicModifier : public Modifier { - SLANG_AST_CLASS(TargetIntrinsicModifier) - + FIDDLE(...) // Token that names the target that the operation // is an intrisic for. - Token targetToken; + FIDDLE() Token targetToken; // A custom definition for the operation, one of either an ident or a // string (the concatenation of several string literals) Token definitionIdent; - String definitionString; + FIDDLE() String definitionString; bool isString; // A predicate to be used on an identifier to guard this intrinsic Token predicateToken; NameLoc scrutinee; - DeclRef<Decl> scrutineeDeclRef; + FIDDLE() DeclRef<Decl> scrutineeDeclRef; }; // A modifier that marks a declaration as representing a // specialization that should be preferred on a particular // target. +FIDDLE() class SpecializedForTargetModifier : public Modifier { - SLANG_AST_CLASS(SpecializedForTargetModifier) - + FIDDLE(...) // Token that names the target that the operation // has been specialized for. - Token targetToken; + FIDDLE() Token targetToken; }; // A modifier to tag something as an intrinsic that requires // a certain GLSL extension to be enabled when used +FIDDLE() class RequiredGLSLExtensionModifier : public Modifier { - SLANG_AST_CLASS(RequiredGLSLExtensionModifier) - - Token extensionNameToken; + FIDDLE(...) + FIDDLE() Token extensionNameToken; }; // A modifier to tag something as an intrinsic that requires // a certain GLSL version to be enabled when used +FIDDLE() class RequiredGLSLVersionModifier : public Modifier { - SLANG_AST_CLASS(RequiredGLSLVersionModifier) - - Token versionNumberToken; + FIDDLE(...) + FIDDLE() Token versionNumberToken; }; // A modifier to tag something as an intrinsic that requires // a certain SPIRV version to be enabled when used. Specified as "major.minor" +FIDDLE() class RequiredSPIRVVersionModifier : public Modifier { - SLANG_AST_CLASS(RequiredSPIRVVersionModifier) - - SemanticVersion version; + FIDDLE(...) + FIDDLE() SemanticVersion version; }; // A modifier to tag something as an intrinsic that requires // a certain WGSL extension to be enabled when used +FIDDLE() class RequiredWGSLExtensionModifier : public Modifier { - SLANG_AST_CLASS(RequiredWGSLExtensionModifier) - - Token extensionNameToken; + FIDDLE(...) + FIDDLE() Token extensionNameToken; }; // A modifier to tag something as an intrinsic that requires // a certain CUDA SM version to be enabled when used. Specified as "major.minor" +FIDDLE() class RequiredCUDASMVersionModifier : public Modifier { - SLANG_AST_CLASS(RequiredCUDASMVersionModifier) - - SemanticVersion version; + FIDDLE(...) + FIDDLE() SemanticVersion version; }; +FIDDLE() class InOutModifier : public OutModifier { - SLANG_AST_CLASS(InOutModifier) + FIDDLE(...) }; // `__ref` modifier for by-reference parameter passing +FIDDLE() class RefModifier : public Modifier { - SLANG_AST_CLASS(RefModifier) + FIDDLE(...) }; // `__ref` modifier for by-reference parameter passing +FIDDLE() class ConstRefModifier : public Modifier { - SLANG_AST_CLASS(ConstRefModifier) + FIDDLE(...) }; // This is a special sentinel modifier that gets added @@ -267,132 +321,147 @@ class ConstRefModifier : public Modifier // / // b: RegisterModifier("x0") / // +FIDDLE() class SharedModifiers : public Modifier { - SLANG_AST_CLASS(SharedModifiers) + FIDDLE(...) }; // AST nodes to represent the begin/end of a `layout` modifier group +FIDDLE(abstract) class GLSLLayoutModifierGroupMarker : public Modifier { - SLANG_ABSTRACT_AST_CLASS(GLSLLayoutModifierGroupMarker) + FIDDLE(...) }; +FIDDLE() class GLSLLayoutModifierGroupBegin : public GLSLLayoutModifierGroupMarker { - SLANG_AST_CLASS(GLSLLayoutModifierGroupBegin) + FIDDLE(...) }; +FIDDLE() class GLSLLayoutModifierGroupEnd : public GLSLLayoutModifierGroupMarker { - SLANG_AST_CLASS(GLSLLayoutModifierGroupEnd) + FIDDLE(...) }; +FIDDLE() class GLSLUnparsedLayoutModifier : public Modifier { - SLANG_AST_CLASS(GLSLUnparsedLayoutModifier) + FIDDLE(...) }; +FIDDLE() class GLSLBufferDataLayoutModifier : public Modifier { - SLANG_AST_CLASS(GLSLBufferDataLayoutModifier) + FIDDLE(...) }; +FIDDLE() class GLSLStd140Modifier : public GLSLBufferDataLayoutModifier { - SLANG_AST_CLASS(GLSLStd140Modifier) + FIDDLE(...) }; +FIDDLE() class GLSLStd430Modifier : public GLSLBufferDataLayoutModifier { - SLANG_AST_CLASS(GLSLStd430Modifier) + FIDDLE(...) }; +FIDDLE() class GLSLScalarModifier : public GLSLBufferDataLayoutModifier { - SLANG_AST_CLASS(GLSLScalarModifier) + FIDDLE(...) }; // A catch-all for single-keyword modifiers +FIDDLE() class SimpleModifier : public Modifier { - SLANG_AST_CLASS(SimpleModifier) + FIDDLE(...) }; // Indicates that this is a variable declaration that corresponds to // a parameter block declaration in the source program. +FIDDLE() class ImplicitParameterGroupVariableModifier : public Modifier { - SLANG_AST_CLASS(ImplicitParameterGroupVariableModifier) + FIDDLE(...) }; // Indicates that this is a type that corresponds to the element // type of a parameter block declaration in the source program. +FIDDLE() class ImplicitParameterGroupElementTypeModifier : public Modifier { - SLANG_AST_CLASS(ImplicitParameterGroupElementTypeModifier) + FIDDLE(...) }; // An HLSL semantic +FIDDLE(abstract) class HLSLSemantic : public Modifier { - SLANG_ABSTRACT_AST_CLASS(HLSLSemantic) - - Token name; + FIDDLE(...) + FIDDLE() Token name; }; // An HLSL semantic that affects layout +FIDDLE() class HLSLLayoutSemantic : public HLSLSemantic { - SLANG_AST_CLASS(HLSLLayoutSemantic) - - Token registerName; - Token componentMask; + FIDDLE(...) + FIDDLE() Token registerName; + FIDDLE() Token componentMask; }; // An HLSL `register` semantic +FIDDLE() class HLSLRegisterSemantic : public HLSLLayoutSemantic { - SLANG_AST_CLASS(HLSLRegisterSemantic) - - Token spaceName; + FIDDLE(...) + FIDDLE() Token spaceName; }; // TODO(tfoley): `packoffset` +FIDDLE() class HLSLPackOffsetSemantic : public HLSLLayoutSemantic { - SLANG_AST_CLASS(HLSLPackOffsetSemantic) - - int uniformOffset = 0; + FIDDLE(...) + FIDDLE() int uniformOffset = 0; }; // An HLSL semantic that just associated a declaration with a semantic name +FIDDLE() class HLSLSimpleSemantic : public HLSLSemantic { - SLANG_AST_CLASS(HLSLSimpleSemantic) + FIDDLE(...) }; // A semantic applied to a field of a ray-payload type, to control access +FIDDLE() class RayPayloadAccessSemantic : public HLSLSemantic { - SLANG_AST_CLASS(RayPayloadAccessSemantic) - - List<Token> stageNameTokens; + FIDDLE(...) + FIDDLE() List<Token> stageNameTokens; }; +FIDDLE() class RayPayloadReadSemantic : public RayPayloadAccessSemantic { - SLANG_AST_CLASS(RayPayloadReadSemantic) + FIDDLE(...) }; +FIDDLE() class RayPayloadWriteSemantic : public RayPayloadAccessSemantic { - SLANG_AST_CLASS(RayPayloadWriteSemantic) + FIDDLE(...) }; @@ -400,73 +469,72 @@ class RayPayloadWriteSemantic : public RayPayloadAccessSemantic // Directives that came in via the preprocessor, but // that we need to keep around for later steps +FIDDLE() class GLSLPreprocessorDirective : public Modifier { - SLANG_AST_CLASS(GLSLPreprocessorDirective) + FIDDLE(...) }; // A GLSL `#version` directive +FIDDLE() class GLSLVersionDirective : public GLSLPreprocessorDirective { - SLANG_AST_CLASS(GLSLVersionDirective) - - + FIDDLE(...) // Token giving the version number to use - Token versionNumberToken; + FIDDLE() Token versionNumberToken; // Optional token giving the sub-profile to be used - Token glslProfileToken; + FIDDLE() Token glslProfileToken; }; // A GLSL `#extension` directive +FIDDLE() class GLSLExtensionDirective : public GLSLPreprocessorDirective { - SLANG_AST_CLASS(GLSLExtensionDirective) - - + FIDDLE(...) // Token giving the version number to use - Token extensionNameToken; + FIDDLE() Token extensionNameToken; // Optional token giving the sub-profile to be used - Token dispositionToken; + FIDDLE() Token dispositionToken; }; +FIDDLE() class ParameterGroupReflectionName : public Modifier { - SLANG_AST_CLASS(ParameterGroupReflectionName) - - NameLoc nameAndLoc; + FIDDLE(...) + FIDDLE() NameLoc nameAndLoc; }; // A modifier that indicates a built-in base type (e.g., `float`) +FIDDLE() class BuiltinTypeModifier : public Modifier { - SLANG_AST_CLASS(BuiltinTypeModifier) - - BaseType tag; + FIDDLE(...) + FIDDLE() BaseType tag; }; // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) // // TODO(tfoley): This deserves a better name than "magic" +FIDDLE() class MagicTypeModifier : public Modifier { - SLANG_AST_CLASS(MagicTypeModifier) - - ASTNodeType magicNodeType = ASTNodeType(-1); + FIDDLE(...) + FIDDLE() SyntaxClass<NodeBase> magicNodeType; /// Modifier has a name so call this magicModifier to disambiguate - String magicName; - uint32_t tag = uint32_t(0); + FIDDLE() String magicName; + FIDDLE() uint32_t tag = uint32_t(0); }; // A modifier that indicates a built-in associated type requirement (e.g., `Differential`) +FIDDLE() class BuiltinRequirementModifier : public Modifier { - SLANG_AST_CLASS(BuiltinRequirementModifier); - - BuiltinRequirementKind kind; + FIDDLE(...) + FIDDLE() BuiltinRequirementKind kind; }; @@ -475,47 +543,52 @@ class BuiltinRequirementModifier : public Modifier // // TODO: This should really subsume `BuiltinTypeModifier` and // `MagicTypeModifier` so that we don't have to apply all of them. +FIDDLE() class IntrinsicTypeModifier : public Modifier { - SLANG_AST_CLASS(IntrinsicTypeModifier) - + FIDDLE(...) // The IR opcode to use when constructing a type - uint32_t irOp; + FIDDLE() uint32_t irOp; Token opToken; // Additional literal opreands to provide when creating instances. // (e.g., for a texture type this passes in shape/mutability info) - List<uint32_t> irOperands; + FIDDLE() List<uint32_t> irOperands; }; // Modifiers that affect the storage layout for matrices +FIDDLE(abstract) class MatrixLayoutModifier : public Modifier { - SLANG_AST_CLASS(MatrixLayoutModifier) + FIDDLE(...) }; // Modifiers that specify row- and column-major layout, respectively +FIDDLE(abstract) class RowMajorLayoutModifier : public MatrixLayoutModifier { - SLANG_AST_CLASS(RowMajorLayoutModifier) + FIDDLE(...) }; +FIDDLE(abstract) class ColumnMajorLayoutModifier : public MatrixLayoutModifier { - SLANG_AST_CLASS(ColumnMajorLayoutModifier) + FIDDLE(...) }; // The HLSL flavor of those modifiers +FIDDLE() class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier { - SLANG_AST_CLASS(HLSLRowMajorLayoutModifier) + FIDDLE(...) }; +FIDDLE() class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier { - SLANG_AST_CLASS(HLSLColumnMajorLayoutModifier) + FIDDLE(...) }; @@ -526,676 +599,736 @@ class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier // we actually interpret that as requesting column-major. This makes // sense because we interpret matrix conventions backwards from how // GLSL specifies them. +FIDDLE() class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier { - SLANG_AST_CLASS(GLSLRowMajorLayoutModifier) + FIDDLE(...) }; +FIDDLE() class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier { - SLANG_AST_CLASS(GLSLColumnMajorLayoutModifier) + FIDDLE(...) }; // More HLSL Keyword +FIDDLE(abstract) class InterpolationModeModifier : public Modifier { - SLANG_ABSTRACT_AST_CLASS(InterpolationModeModifier) + FIDDLE(...) }; // HLSL `nointerpolation` modifier +FIDDLE() class HLSLNoInterpolationModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLNoInterpolationModifier) + FIDDLE(...) }; // HLSL `noperspective` modifier +FIDDLE() class HLSLNoPerspectiveModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLNoPerspectiveModifier) + FIDDLE(...) }; // HLSL `linear` modifier +FIDDLE() class HLSLLinearModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLLinearModifier) + FIDDLE(...) }; // HLSL `sample` modifier +FIDDLE() class HLSLSampleModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLSampleModifier) + FIDDLE(...) }; // HLSL `centroid` modifier +FIDDLE() class HLSLCentroidModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(HLSLCentroidModifier) + FIDDLE(...) }; /// Slang-defined `pervertex` modifier +FIDDLE() class PerVertexModifier : public InterpolationModeModifier { - SLANG_AST_CLASS(PerVertexModifier) + FIDDLE(...) }; // HLSL `precise` modifier +FIDDLE() class PreciseModifier : public Modifier { - SLANG_AST_CLASS(PreciseModifier) + FIDDLE(...) }; // HLSL `shared` modifier (which is used by the effect system, // and shouldn't be confused with `groupshared`) +FIDDLE() class HLSLEffectSharedModifier : public Modifier { - SLANG_AST_CLASS(HLSLEffectSharedModifier) + FIDDLE(...) }; // HLSL `groupshared` modifier +FIDDLE() class HLSLGroupSharedModifier : public Modifier { - SLANG_AST_CLASS(HLSLGroupSharedModifier) + FIDDLE(...) }; // HLSL `static` modifier (probably doesn't need to be // treated as HLSL-specific) +FIDDLE() class HLSLStaticModifier : public Modifier { - SLANG_AST_CLASS(HLSLStaticModifier) + FIDDLE(...) }; // HLSL `uniform` modifier (distinct meaning from GLSL // use of the keyword) +FIDDLE() class HLSLUniformModifier : public Modifier { - SLANG_AST_CLASS(HLSLUniformModifier) + FIDDLE(...) }; // HLSL `volatile` modifier (ignored) +FIDDLE() class HLSLVolatileModifier : public Modifier { - SLANG_AST_CLASS(HLSLVolatileModifier) + FIDDLE(...) }; +FIDDLE() class AttributeTargetModifier : public Modifier { - SLANG_AST_CLASS(AttributeTargetModifier) - + FIDDLE(...) // A class to which the declared attribute type is applicable - SyntaxClass<NodeBase> syntaxClass; + FIDDLE() SyntaxClass<NodeBase> syntaxClass; }; // Base class for checked and unchecked `[name(arg0, ...)]` style attribute. +FIDDLE(abstract) class AttributeBase : public Modifier { - SLANG_AST_CLASS(AttributeBase) - - AttributeDecl* attributeDecl = nullptr; + FIDDLE(...) + FIDDLE() AttributeDecl* attributeDecl = nullptr; // The original identifier token representing the last part of the qualified name. Token originalIdentifierToken; - List<Expr*> args; + FIDDLE() List<Expr*> args; }; // A `[name(...)]` attribute that hasn't undergone any semantic analysis. // After analysis, this will be transformed into a more specific case. +FIDDLE() class UncheckedAttribute : public AttributeBase { - SLANG_AST_CLASS(UncheckedAttribute) - + FIDDLE(...) SLANG_UNREFLECTED Scope* scope = nullptr; }; // A GLSL layout qualifier whose value has not yet been resolved or validated. +FIDDLE() class UncheckedGLSLLayoutAttribute : public AttributeBase { - SLANG_AST_CLASS(UncheckedGLSLLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // GLSL `binding` layout qualifier, does not include `set`. +FIDDLE() class UncheckedGLSLBindingLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLBindingLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // GLSL `set` layout qualifier, does not include `binding`. +FIDDLE() class UncheckedGLSLSetLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLSetLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // GLSL `offset` layout qualifier. +FIDDLE() class UncheckedGLSLOffsetLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLOffsetLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLInputAttachmentIndexLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLInputAttachmentIndexLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLLocationLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLLocationLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLIndexLayoutAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLIndexLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLConstantIdAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLConstantIdAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLRayPayloadAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLRayPayloadAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLRayPayloadInAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLRayPayloadInAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; - +FIDDLE() class UncheckedGLSLHitObjectAttributesAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLHitObjectAttributesAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLCallablePayloadAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLCallablePayloadAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class UncheckedGLSLCallablePayloadInAttribute : public UncheckedGLSLLayoutAttribute { - SLANG_AST_CLASS(UncheckedGLSLCallablePayloadInAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; // A `[name(arg0, ...)]` style attribute that has been validated. +FIDDLE() class Attribute : public AttributeBase { - SLANG_AST_CLASS(Attribute) - - List<Val*> intArgVals; + FIDDLE(...) + FIDDLE() List<Val*> intArgVals; }; +FIDDLE() class UserDefinedAttribute : public Attribute { - SLANG_AST_CLASS(UserDefinedAttribute) + FIDDLE(...) }; +FIDDLE() class AttributeUsageAttribute : public Attribute { - SLANG_AST_CLASS(AttributeUsageAttribute) - - SyntaxClass<NodeBase> targetSyntaxClass; + FIDDLE(...) + FIDDLE() SyntaxClass<NodeBase> targetSyntaxClass; }; +FIDDLE() class NonDynamicUniformAttribute : public Attribute { - SLANG_AST_CLASS(NonDynamicUniformAttribute) + FIDDLE(...) }; +FIDDLE() class RequireCapabilityAttribute : public Attribute { - SLANG_AST_CLASS(RequireCapabilityAttribute) - CapabilitySet capabilitySet; + FIDDLE(...) + FIDDLE() CapabilitySet capabilitySet; }; // An `[unroll]` or `[unroll(count)]` attribute +FIDDLE() class UnrollAttribute : public Attribute { - SLANG_AST_CLASS(UnrollAttribute) + FIDDLE(...) }; // An `[unroll]` or `[unroll(count)]` attribute +FIDDLE() class ForceUnrollAttribute : public Attribute { - SLANG_AST_CLASS(ForceUnrollAttribute) - - int32_t maxIterations = 0; + FIDDLE(...) + FIDDLE() int32_t maxIterations = 0; }; // An `[maxiters(count)]` +FIDDLE() class MaxItersAttribute : public Attribute { - SLANG_AST_CLASS(MaxItersAttribute) - - IntVal* value = 0; + FIDDLE(...) + FIDDLE() IntVal* value = 0; }; // An inferred max iteration count on a loop. +FIDDLE() class InferredMaxItersAttribute : public Attribute { - SLANG_AST_CLASS(InferredMaxItersAttribute) - DeclRef<Decl> inductionVar; - int32_t value = 0; + FIDDLE(...) + FIDDLE() DeclRef<Decl> inductionVar; + FIDDLE() int32_t value = 0; }; +FIDDLE() class LoopAttribute : public Attribute { - SLANG_AST_CLASS(LoopAttribute) + FIDDLE(...) }; // `[loop]` + +FIDDLE() class FastOptAttribute : public Attribute { - SLANG_AST_CLASS(FastOptAttribute) + FIDDLE(...) }; // `[fastopt]` + +FIDDLE() class AllowUAVConditionAttribute : public Attribute { - SLANG_AST_CLASS(AllowUAVConditionAttribute) + FIDDLE(...) }; // `[allow_uav_condition]` + +FIDDLE() class BranchAttribute : public Attribute { - SLANG_AST_CLASS(BranchAttribute) + FIDDLE(...) }; // `[branch]` + +FIDDLE() class FlattenAttribute : public Attribute { - SLANG_AST_CLASS(FlattenAttribute) + FIDDLE(...) }; // `[flatten]` + +FIDDLE() class ForceCaseAttribute : public Attribute { - SLANG_AST_CLASS(ForceCaseAttribute) + FIDDLE(...) }; // `[forcecase]` + +FIDDLE() class CallAttribute : public Attribute { - SLANG_AST_CLASS(CallAttribute) + FIDDLE(...) }; // `[call]` +FIDDLE() class UnscopedEnumAttribute : public Attribute { - SLANG_AST_CLASS(UnscopedEnumAttribute) + FIDDLE(...) }; // Marks a enum to have `flags` semantics, where each enum case is a bitfield. +FIDDLE() class FlagsAttribute : public Attribute { - SLANG_AST_CLASS(FlagsAttribute); + FIDDLE(...) }; // [[vk_push_constant]] [[push_constant]] +FIDDLE() class PushConstantAttribute : public Attribute { - SLANG_AST_CLASS(PushConstantAttribute) + FIDDLE(...) }; // [[vk_specialization_constant]] [[specialization_constant]] +FIDDLE() class SpecializationConstantAttribute : public Attribute { - SLANG_AST_CLASS(SpecializationConstantAttribute) + FIDDLE(...) }; // [[vk_constant_id]] +FIDDLE() class VkConstantIdAttribute : public Attribute { - SLANG_AST_CLASS(VkConstantIdAttribute) - int location; + FIDDLE(...) + FIDDLE() int location; }; // [[vk_shader_record]] [[shader_record]] +FIDDLE() class ShaderRecordAttribute : public Attribute { - SLANG_AST_CLASS(ShaderRecordAttribute) + FIDDLE(...) }; // [[vk_binding]] +FIDDLE() class GLSLBindingAttribute : public Attribute { - SLANG_AST_CLASS(GLSLBindingAttribute) - - int32_t binding = 0; - int32_t set = 0; + FIDDLE(...) + FIDDLE() int32_t binding = 0; + FIDDLE() int32_t set = 0; }; +FIDDLE() class VkAliasedPointerAttribute : public Attribute { - SLANG_AST_CLASS(VkAliasedPointerAttribute) + FIDDLE(...) }; +FIDDLE() class VkRestrictPointerAttribute : public Attribute { - SLANG_AST_CLASS(VkRestrictPointerAttribute) + FIDDLE(...) }; +FIDDLE() class GLSLOffsetLayoutAttribute : public Attribute { - SLANG_AST_CLASS(GLSLOffsetLayoutAttribute) - - int64_t offset; + FIDDLE(...) + FIDDLE() int64_t offset; }; // Implicitly added offset qualifier when no offset is specified. +FIDDLE() class GLSLImplicitOffsetLayoutAttribute : public AttributeBase { - SLANG_AST_CLASS(GLSLImplicitOffsetLayoutAttribute) - + FIDDLE(...) SLANG_UNREFLECTED }; +FIDDLE() class GLSLSimpleIntegerLayoutAttribute : public Attribute { - SLANG_AST_CLASS(GLSLSimpleIntegerLayoutAttribute) - - int32_t value = 0; + FIDDLE(...) + FIDDLE() int32_t value = 0; }; /// [[vk_input_attachment_index]] +FIDDLE() class GLSLInputAttachmentIndexLayoutAttribute : public Attribute { - SLANG_AST_CLASS(GLSLInputAttachmentIndexLayoutAttribute) - - IntegerLiteralValue location; + FIDDLE(...) + FIDDLE() IntegerLiteralValue location; }; // [[vk_location]] +FIDDLE() class GLSLLocationAttribute : public GLSLSimpleIntegerLayoutAttribute { - SLANG_AST_CLASS(GLSLLocationAttribute) + FIDDLE(...) }; // [[vk_index]] +FIDDLE() class GLSLIndexAttribute : public GLSLSimpleIntegerLayoutAttribute { - SLANG_AST_CLASS(GLSLIndexAttribute) + FIDDLE(...) }; // [[vk_offset]] +FIDDLE() class VkStructOffsetAttribute : public GLSLSimpleIntegerLayoutAttribute { - SLANG_AST_CLASS(VkStructOffsetAttribute) + FIDDLE(...) }; // [[vk_spirv_instruction]] +FIDDLE() class SPIRVInstructionOpAttribute : public Attribute { - SLANG_AST_CLASS(SPIRVInstructionOpAttribute) + FIDDLE(...) }; // [[spv_target_env_1_3]] +FIDDLE() class SPIRVTargetEnv13Attribute : public Attribute { - SLANG_AST_CLASS(SPIRVTargetEnv13Attribute); + FIDDLE(...) }; // [[disable_array_flattening]] +FIDDLE() class DisableArrayFlatteningAttribute : public Attribute { - SLANG_AST_CLASS(DisableArrayFlatteningAttribute); + FIDDLE(...) }; // A GLSL layout(local_size_x = 64, ... attribute) +FIDDLE() class GLSLLayoutLocalSizeAttribute : public Attribute { - SLANG_AST_CLASS(GLSLLayoutLocalSizeAttribute) - + FIDDLE(...) // The number of threads to use along each axis // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; + FIDDLE() IntVal* extents[3]; - bool axisIsSpecConstId[3]; + FIDDLE() bool axisIsSpecConstId[3]; // References to specialization constants, for defining the number of // threads with them. If set, the corresponding axis is set to nullptr // above. - DeclRef<VarDeclBase> specConstExtents[3]; + FIDDLE() DeclRef<VarDeclBase> specConstExtents[3]; }; +FIDDLE() class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute { - SLANG_AST_CLASS(GLSLLayoutDerivativeGroupQuadAttribute) + FIDDLE(...) }; +FIDDLE() class GLSLLayoutDerivativeGroupLinearAttribute : public Attribute { - SLANG_AST_CLASS(GLSLLayoutDerivativeGroupLinearAttribute) + FIDDLE(...) }; // TODO: for attributes that take arguments, the syntax node // classes should provide accessors for the values of those arguments. +FIDDLE() class MaxTessFactorAttribute : public Attribute { - SLANG_AST_CLASS(MaxTessFactorAttribute) + FIDDLE(...) }; +FIDDLE() class OutputControlPointsAttribute : public Attribute { - SLANG_AST_CLASS(OutputControlPointsAttribute) + FIDDLE(...) }; +FIDDLE() class OutputTopologyAttribute : public Attribute { - SLANG_AST_CLASS(OutputTopologyAttribute) + FIDDLE(...) }; +FIDDLE() class PartitioningAttribute : public Attribute { - SLANG_AST_CLASS(PartitioningAttribute) + FIDDLE(...) }; +FIDDLE() class PatchConstantFuncAttribute : public Attribute { - SLANG_AST_CLASS(PatchConstantFuncAttribute) - - FuncDecl* patchConstantFuncDecl = nullptr; + FIDDLE(...) + FIDDLE() FuncDecl* patchConstantFuncDecl = nullptr; }; + +FIDDLE() class DomainAttribute : public Attribute { - SLANG_AST_CLASS(DomainAttribute) + FIDDLE(...) }; +FIDDLE() class EarlyDepthStencilAttribute : public Attribute { - SLANG_AST_CLASS(EarlyDepthStencilAttribute) + FIDDLE(...) }; // `[earlydepthstencil]` // An HLSL `[numthreads(x,y,z)]` attribute +FIDDLE() class NumThreadsAttribute : public Attribute { - SLANG_AST_CLASS(NumThreadsAttribute) - + FIDDLE(...) // The number of threads to use along each axis // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; + FIDDLE() IntVal* extents[3]; // References to specialization constants, for defining the number of // threads with them. If set, the corresponding axis is set to nullptr // above. - DeclRef<VarDeclBase> specConstExtents[3]; + FIDDLE() DeclRef<VarDeclBase> specConstExtents[3]; }; +FIDDLE() class WaveSizeAttribute : public Attribute { - SLANG_AST_CLASS(WaveSizeAttribute) - + FIDDLE(...) // "numLanes" must be a compile time constant integer // value of an allowed wave size, which is one of the // followings: 4, 8, 16, 32, 64 or 128. // - IntVal* numLanes; + FIDDLE() IntVal* numLanes; }; +FIDDLE() class MaxVertexCountAttribute : public Attribute { - SLANG_AST_CLASS(MaxVertexCountAttribute) - + FIDDLE(...) // The number of max vertex count for geometry shader // // TODO: This should be an accessor that uses the // ordinary `args` list, rather than side data. - int32_t value; + FIDDLE() int32_t value; }; +FIDDLE() class InstanceAttribute : public Attribute { - SLANG_AST_CLASS(InstanceAttribute) - + FIDDLE(...) // The number of instances to run for geometry shader // // TODO: This should be an accessor that uses the // ordinary `args` list, rather than side data. - int32_t value; + FIDDLE() int32_t value; }; // A `[shader("stageName")]`/`[shader("capability")]` attribute which // marks an entry point for compiling. This attribute also specifies // the 'capabilities' implicitly supported by an entry point +FIDDLE() class EntryPointAttribute : public Attribute { - SLANG_AST_CLASS(EntryPointAttribute) - + FIDDLE(...) // The resolved capailities for our entry point. - CapabilitySet capabilitySet; + FIDDLE() CapabilitySet capabilitySet; }; // A `[__vulkanRayPayload(location)]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the input/output interface for a Vulkan // ray tracing shader to pass per-ray payload information. +FIDDLE() class VulkanRayPayloadAttribute : public Attribute { - SLANG_AST_CLASS(VulkanRayPayloadAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; +FIDDLE() class VulkanRayPayloadInAttribute : public Attribute { - SLANG_AST_CLASS(VulkanRayPayloadInAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; // A `[__vulkanCallablePayload(location)]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the input/output interface for a Vulkan // ray tracing shader to pass payload information to/from a callee. +FIDDLE() class VulkanCallablePayloadAttribute : public Attribute { - SLANG_AST_CLASS(VulkanCallablePayloadAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; +FIDDLE() class VulkanCallablePayloadInAttribute : public Attribute { - SLANG_AST_CLASS(VulkanCallablePayloadInAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; // A `[__vulkanHitAttributes]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the output interface for a Vulkan // intersection shader to pass hit attribute information. +FIDDLE() class VulkanHitAttributesAttribute : public Attribute { - SLANG_AST_CLASS(VulkanHitAttributesAttribute) + FIDDLE(...) }; // A `[__vulkanHitObjectAttributes(location)]` attribute, which is used in the // core module implementation to indicate that a variable // actually represents the attributes on a HitObject as part of // Shader ExecutionReordering +FIDDLE() class VulkanHitObjectAttributesAttribute : public Attribute { - SLANG_AST_CLASS(VulkanHitObjectAttributesAttribute) - - int location; + FIDDLE(...) + FIDDLE() int location; }; // A `[mutating]` attribute, which indicates that a member // function is allowed to modify things through its `this` // argument. // +FIDDLE() class MutatingAttribute : public Attribute { - SLANG_AST_CLASS(MutatingAttribute) + FIDDLE(...) }; // A `[nonmutating]` attribute, which indicates that a // `set` accessor does not need to modify anything through // its `this` parameter. // +FIDDLE() class NonmutatingAttribute : public Attribute { - SLANG_AST_CLASS(NonmutatingAttribute) + FIDDLE(...) }; // A `[constref]` attribute, which indicates that the `this` parameter of // a member function should be passed by const reference. // +FIDDLE() class ConstRefAttribute : public Attribute { - SLANG_AST_CLASS(ConstRefAttribute) + FIDDLE(...) }; // A `[ref]` attribute, which indicates that the `this` parameter of // a member function should be passed by reference. // +FIDDLE() class RefAttribute : public Attribute { - SLANG_AST_CLASS(RefAttribute) + FIDDLE(...) }; // A `[__readNone]` attribute, which indicates that a function @@ -1203,174 +1336,194 @@ class RefAttribute : public Attribute // reading or writing through any pointer arguments, or any other // state that could be observed by a caller. // +FIDDLE() class ReadNoneAttribute : public Attribute { - SLANG_AST_CLASS(ReadNoneAttribute) + FIDDLE(...) }; // A `[__GLSLRequireShaderInputParameter]` attribute to annotate // functions that require a shader input as parameter // +FIDDLE() class GLSLRequireShaderInputParameterAttribute : public Attribute { - SLANG_AST_CLASS(GLSLRequireShaderInputParameterAttribute) - - uint32_t parameterNumber; + FIDDLE(...) + FIDDLE() uint32_t parameterNumber; }; // HLSL modifiers for geometry shader input topology +FIDDLE() class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier { - SLANG_AST_CLASS(HLSLGeometryShaderInputPrimitiveTypeModifier) + FIDDLE(...) }; +FIDDLE() class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLPointModifier) + FIDDLE(...) }; +FIDDLE() class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLLineModifier) + FIDDLE(...) }; +FIDDLE() class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLTriangleModifier) + FIDDLE(...) }; +FIDDLE() class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLLineAdjModifier) + FIDDLE(...) }; +FIDDLE() class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { - SLANG_AST_CLASS(HLSLTriangleAdjModifier) + FIDDLE(...) }; // Mesh shader paramters +FIDDLE() class HLSLMeshShaderOutputModifier : public Modifier { - SLANG_AST_CLASS(HLSLMeshShaderOutputModifier) + FIDDLE(...) }; +FIDDLE() class HLSLVerticesModifier : public HLSLMeshShaderOutputModifier { - SLANG_AST_CLASS(HLSLVerticesModifier) + FIDDLE(...) }; +FIDDLE() class HLSLIndicesModifier : public HLSLMeshShaderOutputModifier { - SLANG_AST_CLASS(HLSLIndicesModifier) + FIDDLE(...) }; +FIDDLE() class HLSLPrimitivesModifier : public HLSLMeshShaderOutputModifier { - SLANG_AST_CLASS(HLSLPrimitivesModifier) + FIDDLE(...) }; +FIDDLE() class HLSLPayloadModifier : public Modifier { - SLANG_AST_CLASS(HLSLPayloadModifier) + FIDDLE(...) }; // A modifier to indicate that a constructor/initializer can be used // to perform implicit type conversion, and to specify the cost of // the conversion, if applied. +FIDDLE() class ImplicitConversionModifier : public Modifier { - SLANG_AST_CLASS(ImplicitConversionModifier) - + FIDDLE(...) // The conversion cost, used to rank conversions - ConversionCost cost = kConversionCost_None; + FIDDLE() ConversionCost cost = kConversionCost_None; // A builtin identifier for identifying conversions that need special treatment. - BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown; + FIDDLE() BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown; }; +FIDDLE() class FormatAttribute : public Attribute { - SLANG_AST_CLASS(FormatAttribute) - - ImageFormat format; + FIDDLE(...) + FIDDLE() ImageFormat format; }; +FIDDLE() class AllowAttribute : public Attribute { - SLANG_AST_CLASS(AllowAttribute) - - DiagnosticInfo const* diagnostic = nullptr; + FIDDLE(...) + FIDDLE() DiagnosticInfo const* diagnostic = nullptr; }; // A `[__extern]` attribute, which indicates that a function/type is defined externally // +FIDDLE() class ExternAttribute : public Attribute { - SLANG_AST_CLASS(ExternAttribute) + FIDDLE(...) }; // An `[__unsafeForceInlineExternal]` attribute indicates that the callee should be inlined // into call sites after initial IR generation (that is, as early as possible). // +FIDDLE() class UnsafeForceInlineEarlyAttribute : public Attribute { - SLANG_AST_CLASS(UnsafeForceInlineEarlyAttribute) + FIDDLE(...) }; // A `[ForceInline]` attribute indicates that the callee should be inlined // by the Slang compiler. // +FIDDLE() class ForceInlineAttribute : public Attribute { - SLANG_AST_CLASS(ForceInlineAttribute) + FIDDLE(...) }; /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. +FIDDLE(abstract) class InheritanceControlAttribute : public Attribute { - SLANG_AST_CLASS(InheritanceControlAttribute) + FIDDLE(...) }; /// An attribute that marks a type declaration as allowing the type to be inherited from in other /// modules. +FIDDLE() class OpenAttribute : public InheritanceControlAttribute { - SLANG_AST_CLASS(OpenAttribute) + FIDDLE(...) }; /// An attribute that marks a type declaration as disallowing the type to be inherited from in other /// modules. +FIDDLE() class SealedAttribute : public InheritanceControlAttribute { - SLANG_AST_CLASS(SealedAttribute) + FIDDLE(...) }; /// An attribute that marks a decl as a compiler built-in object. +FIDDLE() class BuiltinAttribute : public Attribute { - SLANG_AST_CLASS(BuiltinAttribute) + FIDDLE(...) }; /// An attribute that marks a decl as a compiler built-in object for the autodiff system. +FIDDLE() class AutoDiffBuiltinAttribute : public Attribute { - SLANG_AST_CLASS(AutoDiffBuiltinAttribute) + FIDDLE(...) }; /// An attribute that defines the size of `AnyValue` type to represent a polymoprhic value that /// conforms to the decorated interface type. +FIDDLE() class AnyValueSizeAttribute : public Attribute { - SLANG_AST_CLASS(AnyValueSizeAttribute) - - int32_t size; + FIDDLE(...) + FIDDLE() int32_t size; }; /// This is a stop-gap solution to break overload ambiguity in the core module. @@ -1379,24 +1532,27 @@ class AnyValueSizeAttribute : public Attribute /// In the future, we should enhance our type system to take into account the "specialized"-ness /// of an overload, such that `T overload1<T:IDerived>()` is more specialized than `T /// overload2<T:IBase>()` and preferred during overload resolution. +FIDDLE() class OverloadRankAttribute : public Attribute { - SLANG_AST_CLASS(OverloadRankAttribute) - int32_t rank; + FIDDLE(...) + FIDDLE() int32_t rank; }; /// An attribute that marks an interface for specialization use only. Any operation that triggers /// dynamic dispatch through the interface is a compile-time error. +FIDDLE() class SpecializeAttribute : public Attribute { - SLANG_AST_CLASS(SpecializeAttribute) + FIDDLE(...) }; /// An attribute that marks a type, function or variable as differentiable. +FIDDLE() class DifferentiableAttribute : public Attribute { - SLANG_AST_CLASS(DifferentiableAttribute) - + FIDDLE(...) + // TODO(tfoley): Why is there this duplication here? List<KeyValuePair<Type*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings; void addType(Type* declRef, SubtypeWitness* witness) @@ -1418,55 +1574,62 @@ private: OrderedDictionary<Type*, SubtypeWitness*> m_mapToIDifferentiableWitness; }; +FIDDLE() class DllImportAttribute : public Attribute { - SLANG_AST_CLASS(DllImportAttribute) + FIDDLE(...) + FIDDLE() String modulePath; - String modulePath; - - String functionName; + FIDDLE() String functionName; }; +FIDDLE() class DllExportAttribute : public Attribute { - SLANG_AST_CLASS(DllExportAttribute) + FIDDLE(...) }; +FIDDLE() class TorchEntryPointAttribute : public Attribute { - SLANG_AST_CLASS(TorchEntryPointAttribute) + FIDDLE(...) }; +FIDDLE() class CudaDeviceExportAttribute : public Attribute { - SLANG_AST_CLASS(CudaDeviceExportAttribute) + FIDDLE(...) }; +FIDDLE() class CudaKernelAttribute : public Attribute { - SLANG_AST_CLASS(CudaKernelAttribute) + FIDDLE(...) }; +FIDDLE() class CudaHostAttribute : public Attribute { - SLANG_AST_CLASS(CudaHostAttribute) + FIDDLE(...) }; +FIDDLE() class AutoPyBindCudaAttribute : public Attribute { - SLANG_AST_CLASS(AutoPyBindCudaAttribute) + FIDDLE(...) }; +FIDDLE() class PyExportAttribute : public Attribute { - SLANG_AST_CLASS(PyExportAttribute) - - String name; + FIDDLE(...) + FIDDLE() String name; }; +FIDDLE() class PreferRecomputeAttribute : public Attribute { - SLANG_AST_CLASS(PreferRecomputeAttribute) + FIDDLE(...) enum SideEffectBehavior { @@ -1474,87 +1637,94 @@ class PreferRecomputeAttribute : public Attribute Allow = 1 }; - SideEffectBehavior sideEffectBehavior; + FIDDLE() SideEffectBehavior sideEffectBehavior; }; +FIDDLE() class PreferCheckpointAttribute : public Attribute { - SLANG_AST_CLASS(PreferCheckpointAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeMemberAttribute : public Attribute { - SLANG_AST_CLASS(DerivativeMemberAttribute) - - DeclRefExpr* memberDeclRef; + FIDDLE(...) + FIDDLE() DeclRefExpr* memberDeclRef; }; /// An attribute that marks an interface type as a COM interface declaration. +FIDDLE() class ComInterfaceAttribute : public Attribute { - SLANG_AST_CLASS(ComInterfaceAttribute) - - String guid; + FIDDLE(...) + FIDDLE() String guid; }; /// A `[__requiresNVAPI]` attribute indicates that the declaration being modifed /// requires NVAPI operations for its implementation on D3D. +FIDDLE() class RequiresNVAPIAttribute : public Attribute { - SLANG_AST_CLASS(RequiresNVAPIAttribute) + FIDDLE(...) }; /// A `[RequirePrelude(target, "string")]` attribute indicates that the declaration being modifed /// requires a textual prelude to be injected in the resulting target code. +FIDDLE() class RequirePreludeAttribute : public Attribute { - SLANG_AST_CLASS(RequirePreludeAttribute) - - CapabilitySet capabilitySet; - String prelude; + FIDDLE(...) + FIDDLE() CapabilitySet capabilitySet; + FIDDLE() String prelude; }; /// A `[__AlwaysFoldIntoUseSite]` attribute indicates that the calls into the modified /// function should always be folded into use sites during source emit. +FIDDLE() class AlwaysFoldIntoUseSiteAttribute : public Attribute { - SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute) + FIDDLE(...) }; // A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface // should be treated as differentiable in IR validation step. // +FIDDLE() class TreatAsDifferentiableAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(TreatAsDifferentiableAttribute) + FIDDLE(...) }; /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. +FIDDLE() class ForwardDifferentiableAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(ForwardDifferentiableAttribute) + FIDDLE(...) }; +FIDDLE() class UserDefinedDerivativeAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(UserDefinedDerivativeAttribute) - - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; }; /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. +FIDDLE() class ForwardDerivativeAttribute : public UserDefinedDerivativeAttribute { - SLANG_AST_CLASS(ForwardDerivativeAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeOfAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(DerivativeOfAttribute) - - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; + FIDDLE() Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; @@ -1562,75 +1732,83 @@ class DerivativeOfAttribute : public DifferentiableAttribute /// derivative implementation for `primalFunction`. /// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative /// function itself is considered differentiable. +FIDDLE() class ForwardDerivativeOfAttribute : public DerivativeOfAttribute { - SLANG_AST_CLASS(ForwardDerivativeOfAttribute) + FIDDLE(...) }; /// The `[BackwardDifferentiable]` attribute indicates that a function can be /// backward-differentiated. +FIDDLE() class BackwardDifferentiableAttribute : public DifferentiableAttribute { - SLANG_AST_CLASS(BackwardDifferentiableAttribute) - int maxOrder = 0; + FIDDLE(...) + FIDDLE() int maxOrder = 0; }; /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should /// be used as the backward-derivative for the decorated function. +FIDDLE() class BackwardDerivativeAttribute : public UserDefinedDerivativeAttribute { - SLANG_AST_CLASS(BackwardDerivativeAttribute) + FIDDLE(...) }; /// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom /// backward-derivative implementation for `primalFunction`. +FIDDLE() class BackwardDerivativeOfAttribute : public DerivativeOfAttribute { - SLANG_AST_CLASS(BackwardDerivativeOfAttribute) + FIDDLE(...) }; /// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should /// be used as the primal function substitute when differentiating code that calls the primal /// function. +FIDDLE() class PrimalSubstituteAttribute : public Attribute { - SLANG_AST_CLASS(PrimalSubstituteAttribute) - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; }; /// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as /// the substitute primal function in a forward or backward derivative function. +FIDDLE() class PrimalSubstituteOfAttribute : public Attribute { - SLANG_AST_CLASS(PrimalSubstituteOfAttribute) - - Expr* funcExpr; + FIDDLE(...) + FIDDLE() Expr* funcExpr; + FIDDLE() Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be /// included for differentiation. +FIDDLE() class NoDiffThisAttribute : public Attribute { - SLANG_AST_CLASS(NoDiffThisAttribute) + FIDDLE(...) }; /// Indicates that the modified declaration is one of the "magic" declarations /// that NVAPI uses to communicate extended operations. When NVAPI is being included /// via the prelude for downstream compilation, declarations with this modifier /// will not be emitted, instead allowing the versions from the prelude to be used. +FIDDLE() class NVAPIMagicModifier : public Modifier { - SLANG_AST_CLASS(NVAPIMagicModifier) + FIDDLE(...) }; /// A modifier that attaches to a `ModuleDecl` to indicate the register/space binding /// that NVAPI wants to use, as indicated by, e.g., the `NV_SHADER_EXTN_SLOT` and /// `NV_SHADER_EXTN_REGISTER_SPACE` preprocessor definitions. +FIDDLE() class NVAPISlotModifier : public Modifier { - SLANG_AST_CLASS(NVAPISlotModifier) - + FIDDLE(...) /// The name of the register that is to be used (e.g., `"u3"`) /// /// This value will come from the `NV_SHADER_EXTN_SLOT` macro, if set. @@ -1639,7 +1817,7 @@ class NVAPISlotModifier : public Modifier /// an `NVAPISlotModifier` to a module; if no register name is defined, /// then the modifier should not be added. /// - String registerName; + FIDDLE() String registerName; /// The name of the register space to be used (e.g., `space1`) /// @@ -1648,7 +1826,7 @@ class NVAPISlotModifier : public Modifier /// /// It is valid for a user to specify a register name but not a space name, /// and in that case `spaceName` will be set to `"space0"`. - String spaceName; + FIDDLE() String spaceName; }; /// A `[noinline]` attribute represents a request by the application that, @@ -1657,41 +1835,48 @@ class NVAPISlotModifier : public Modifier /// Note that due to various limitations of different targets, it is entirely /// possible for such functions to be inlined or specialized to call sites. /// +FIDDLE() class NoInlineAttribute : public Attribute { - SLANG_AST_CLASS(NoInlineAttribute) + FIDDLE(...) }; /// A `[noRefInline]` attribute represents a request to not force inline a /// function specifically due to a refType parameter. +FIDDLE() class NoRefInlineAttribute : public Attribute { - SLANG_AST_CLASS(NoRefInlineAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeGroupQuadAttribute : public Attribute { - SLANG_AST_CLASS(DerivativeGroupQuadAttribute) + FIDDLE(...) }; +FIDDLE() class DerivativeGroupLinearAttribute : public Attribute { - SLANG_AST_CLASS(DerivativeGroupLinearAttribute) + FIDDLE(...) }; +FIDDLE() class MaximallyReconvergesAttribute : public Attribute { - SLANG_AST_CLASS(MaximallyReconvergesAttribute) + FIDDLE(...) }; +FIDDLE() class QuadDerivativesAttribute : public Attribute { - SLANG_AST_CLASS(QuadDerivativesAttribute) + FIDDLE(...) }; +FIDDLE() class RequireFullQuadsAttribute : public Attribute { - SLANG_AST_CLASS(RequireFullQuadsAttribute) + FIDDLE(...) }; /// A `[payload]` attribute indicates that a `struct` type will be used as @@ -1699,9 +1884,10 @@ class RequireFullQuadsAttribute : public Attribute /// for shaders in the ray tracing pipeline that might be invoked for /// such a ray. /// +FIDDLE() class PayloadAttribute : public Attribute { - SLANG_AST_CLASS(PayloadAttribute) + FIDDLE(...) }; /// A `[raypayload]` attribute indicates that a `struct` type will be used as @@ -1709,9 +1895,10 @@ class PayloadAttribute : public Attribute /// for shaders in the ray tracing pipeline that might be invoked for /// such a ray. /// +FIDDLE() class RayPayloadAttribute : public Attribute { - SLANG_AST_CLASS(RayPayloadAttribute) + FIDDLE(...) }; /// A `[deprecated("message")]` attribute indicates the target is @@ -1719,32 +1906,34 @@ class RayPayloadAttribute : public Attribute /// A compiler warning including the message will be raised if the /// deprecated value is used. /// +FIDDLE() class DeprecatedAttribute : public Attribute { - SLANG_AST_CLASS(DeprecatedAttribute) - - String message; + FIDDLE(...) + FIDDLE() String message; }; +FIDDLE() class NonCopyableTypeAttribute : public Attribute { - SLANG_AST_CLASS(NonCopyableTypeAttribute) + FIDDLE(...) }; +FIDDLE() class NoSideEffectAttribute : public Attribute { - SLANG_AST_CLASS(NoSideEffectAttribute) + FIDDLE(...) }; /// A `[KnownBuiltin("name")]` attribute allows the compiler to /// identify this declaration during compilation, despite obfuscation or /// linkage removing optimizations /// +FIDDLE() class KnownBuiltinAttribute : public Attribute { - SLANG_AST_CLASS(KnownBuiltinAttribute) - - String name; + FIDDLE(...) + FIDDLE() String name; }; /// A modifier that applies to types rather than declarations. @@ -1762,103 +1951,117 @@ class KnownBuiltinAttribute : public Attribute /// and instead want to belong to the type (or rather the type *specifier* /// from a parsing standpoint). /// +FIDDLE() class TypeModifier : public Modifier { - SLANG_AST_CLASS(TypeModifier) + FIDDLE(...) }; /// A kind of syntax element which appears as a modifier in the syntax, but /// we represent as a function over type expressions +FIDDLE() class WrappingTypeModifier : public TypeModifier { - SLANG_AST_CLASS(WrappingTypeModifier) + FIDDLE(...) }; /// A modifier that applies to a type and implies information about the /// underlying format of a resource that uses that type as its element type. /// +FIDDLE() class ResourceElementFormatModifier : public TypeModifier { - SLANG_AST_CLASS(ResourceElementFormatModifier) + FIDDLE(...) }; /// HLSL `unorm` modifier +FIDDLE() class UNormModifier : public ResourceElementFormatModifier { - SLANG_AST_CLASS(UNormModifier) + FIDDLE(...) }; /// HLSL `snorm` modifier +FIDDLE() class SNormModifier : public ResourceElementFormatModifier { - SLANG_AST_CLASS(SNormModifier) + FIDDLE(...) }; +FIDDLE() class NoDiffModifier : public TypeModifier { - SLANG_AST_CLASS(NoDiffModifier) + FIDDLE(...) }; +FIDDLE() class GloballyCoherentModifier : public SimpleModifier { - SLANG_AST_CLASS(GloballyCoherentModifier) + FIDDLE(...) }; // Some GLSL-specific modifiers +FIDDLE() class GLSLBufferModifier : public WrappingTypeModifier { - SLANG_AST_CLASS(GLSLBufferModifier) + FIDDLE(...) }; +FIDDLE() class GLSLWriteOnlyModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLWriteOnlyModifier) + FIDDLE(...) }; +FIDDLE() class GLSLReadOnlyModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLReadOnlyModifier) + FIDDLE(...) }; +FIDDLE() class GLSLVolatileModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLVolatileModifier) + FIDDLE(...) }; +FIDDLE() class GLSLRestrictModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLRestrictModifier) + FIDDLE(...) }; +FIDDLE() class GLSLPatchModifier : public SimpleModifier { - SLANG_AST_CLASS(GLSLPatchModifier) + FIDDLE(...) }; // +FIDDLE() class BitFieldModifier : public Modifier { - SLANG_AST_CLASS(BitFieldModifier) - - IntegerLiteralValue width; + FIDDLE(...) + FIDDLE() IntegerLiteralValue width; // Fields filled during semantic analysis - IntegerLiteralValue offset = 0; - DeclRef<VarDecl> backingDeclRef; + FIDDLE() IntegerLiteralValue offset = 0; + FIDDLE() DeclRef<VarDecl> backingDeclRef; }; +FIDDLE() class DynamicUniformModifier : public Modifier { - SLANG_AST_CLASS(DynamicUniformModifier) + FIDDLE(...) }; +FIDDLE() class MemoryQualifierSetModifier : public Modifier { - SLANG_AST_CLASS(MemoryQualifierSetModifier); - - List<Modifier*> memoryModifiers; + FIDDLE(...) + FIDDLE() List<Modifier*> memoryModifiers; - uint32_t memoryQualifiers = 0; + FIDDLE() uint32_t memoryQualifiers = 0; public: struct Flags diff --git a/source/slang/slang-ast-reflect.cpp b/source/slang/slang-ast-reflect.cpp deleted file mode 100644 index 3f4ba9534..000000000 --- a/source/slang/slang-ast-reflect.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "slang-ast-reflect.h" - -#include "../core/slang-smart-pointer.h" -#include "slang-ast-all.h" -#include "slang-generated-ast-macro.h" -#include "slang-visitor.h" -#include "slang.h" - -#include <assert.h> -#include <typeinfo> - -namespace Slang -{ - -#define SLANG_REFLECT_GET_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - infos.infos[int(ASTNodeType::NAME)] = &NAME::kReflectClassInfo; - -static ASTClassInfo::Infos _calcInfos() -{ - ASTClassInfo::Infos infos; - memset(&infos, 0, sizeof(infos)); - SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_GET_REFLECT_CLASS_INFO, _) return infos; -} - -/* static */ const ASTClassInfo::Infos ASTClassInfo::kInfos = _calcInfos(); - -// Now try and implement all of the classes -// Macro generated is of the format - -struct ASTConstructAccess -{ - template<typename T> - struct Impl - { - static void* create(void* context) - { - ASTBuilder* astBuilder = (ASTBuilder*)context; - return astBuilder->createImpl<T>(); - } - static void destroy(void* ptr) - { - // Needed because if type has non dtor, Visual Studio claims ptr not used - SLANG_UNUSED(ptr); - reinterpret_cast<T*>(ptr)->~T(); - } - }; -}; - -#define SLANG_GET_SUPER_BASE(SUPER) nullptr -#define SLANG_GET_SUPER_INNER(SUPER) &SUPER::kReflectClassInfo -#define SLANG_GET_SUPER_LEAF(SUPER) &SUPER::kReflectClassInfo - -#define SLANG_GET_CREATE_FUNC_ABSTRACT_AST(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_AST(NAME) &ASTConstructAccess::Impl<NAME>::create - -#define SLANG_GET_DESTROY_FUNC_ABSTRACT_AST(NAME) nullptr -#define SLANG_GET_DESTROY_FUNC_AST(NAME) &ASTConstructAccess::Impl<NAME>::destroy - -#define SLANG_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - /* static */ const ReflectClassInfo NAME::kReflectClassInfo = { \ - uint32_t(ASTNodeType::NAME), \ - uint32_t(ASTNodeType::LAST), \ - SLANG_GET_SUPER_##TYPE(SUPER), \ - #NAME, \ - SLANG_GET_CREATE_FUNC_##MARKER(NAME), \ - SLANG_GET_DESTROY_FUNC_##MARKER(NAME), \ - uint32_t(sizeof(NAME)), \ - uint8_t(SLANG_ALIGN_OF(NAME))}; - -SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_CLASS_INFO, _) - -// We dispatch to non 'abstract' types -#define SLANG_CASE_AST(NAME) \ - case ASTNodeType::NAME: \ - return visitor->dispatch_##NAME(static_cast<NAME*>(this), extra); -#define SLANG_CASE_ABSTRACT_AST(NAME) - -#define SLANG_CASE_DISPATCH(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - SLANG_CASE_##MARKER(NAME) - - void Val::accept(IValVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Val(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -void Type::accept(ITypeVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Type(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -void Modifier::accept(IModifierVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CASE_DISPATCH, _) default - : SLANG_ASSERT(!"Unknown type"); - } -} - -void DeclBase::accept(IDeclVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CASE_DISPATCH, _) default - : SLANG_ASSERT(!"Unknown type"); - } -} - -void Expr::accept(IExprVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Expr(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -void Stmt::accept(IStmtVisitor* visitor, void* extra) -{ - const ReflectClassInfo& classInfo = getClassInfo(); - const ASTNodeType astType = ASTNodeType(classInfo.m_classId); - - switch (astType) - { - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); - } -} - -} // namespace Slang diff --git a/source/slang/slang-ast-reflect.h b/source/slang/slang-ast-reflect.h deleted file mode 100644 index 56e42c8bd..000000000 --- a/source/slang/slang-ast-reflect.h +++ /dev/null @@ -1,59 +0,0 @@ -// slang-ast-reflect.h - -#ifndef SLANG_AST_REFLECT_H -#define SLANG_AST_REFLECT_H - -#include "slang-generated-ast.h" -#include "slang-serialize-reflection.h" - -// Implementation for SLANG_ABSTRACT_CLASS(x) using reflection from C++ extractor in -// slang-ast-generated.h -#define SLANG_AST_CLASS_REFLECT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ -protected: \ - NAME() = default; \ - \ -public: \ - typedef NAME This; \ - static constexpr ASTNodeType kType = ASTNodeType::NAME; \ - static const ReflectClassInfo kReflectClassInfo; \ - SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) \ - { \ - return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); \ - } \ - SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) friend class ASTBuilder; \ - friend struct ASTConstructAccess; \ - friend struct ASTFieldAccess; \ - friend struct ASTDumpAccess; - -// Macro definitions - use the SLANG_ASTNode_ definitions to invoke the IMPL to produce the code -// injected into AST classes -#define SLANG_ABSTRACT_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) -#define SLANG_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) - -// Macros for simulating virtual methods without virtual methods - -#define SLANG_AST_NODE_INVOKE(method, methodParams) _##method##Override methodParams - -#define SLANG_AST_NODE_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - case ASTNodeType::NAME: \ - return static_cast<NAME*>(this)->SLANG_AST_NODE_INVOKE param; - -#define SLANG_AST_NODE_VIRTUAL_CALL(base, methodName, methodParams) \ - switch (astNodeType) \ - { \ - SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CASE, (methodName, methodParams)) default \ - : return SLANG_AST_NODE_INVOKE(methodName, methodParams); \ - } - -// Same but for a method that's const -#define SLANG_AST_NODE_CONST_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - case ASTNodeType::NAME: \ - return static_cast<const NAME*>(this)->SLANG_AST_NODE_INVOKE param; -#define SLANG_AST_NODE_CONST_VIRTUAL_CALL(base, methodName, methodParams) \ - switch (astNodeType) \ - { \ - SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CONST_CASE, (methodName, methodParams)) default \ - : return SLANG_AST_NODE_INVOKE(methodName, methodParams); \ - } - -#endif // SLANG_AST_REFLECT_H diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 4107664bf..a1b7c274e 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -1,55 +1,56 @@ // slang-ast-stmt.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-stmt.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for statements. +FIDDLE(abstract) class ScopeStmt : public Stmt { - SLANG_ABSTRACT_AST_CLASS(ScopeStmt) - + FIDDLE(...) ScopeDecl* scopeDecl = nullptr; }; // A sequence of statements, treated as a single statement +FIDDLE() class SeqStmt : public Stmt { - SLANG_AST_CLASS(SeqStmt) - - List<Stmt*> stmts; + FIDDLE(...) + FIDDLE() List<Stmt*> stmts; }; // A statement with a label. +FIDDLE() class LabelStmt : public Stmt { - SLANG_AST_CLASS(LabelStmt) - - Token label; - Stmt* innerStmt; + FIDDLE(...) + FIDDLE() Token label; + FIDDLE() Stmt* innerStmt; }; // The simplest kind of scope statement: just a `{...}` block +FIDDLE() class BlockStmt : public ScopeStmt { - SLANG_AST_CLASS(BlockStmt) - + FIDDLE(...) /// TODO(JS): Having ranges of sourcelocs might be a good addition to AST nodes in general. SourceLoc closingSourceLoc; ///< The source location of the closing brace - Stmt* body = nullptr; + FIDDLE() Stmt* body = nullptr; }; // A statement that we aren't going to parse or check, because // we want to let a downstream compiler handle any issues +FIDDLE() class UnparsedStmt : public Stmt { - SLANG_AST_CLASS(UnparsedStmt) - + FIDDLE(...) // The tokens that were contained between `{` and `}` List<Token> tokens; Scope* currentScope = nullptr; @@ -58,41 +59,45 @@ class UnparsedStmt : public Stmt bool isInVariadicGenerics = false; }; +FIDDLE() class EmptyStmt : public Stmt { - SLANG_AST_CLASS(EmptyStmt) + FIDDLE(...) }; +FIDDLE() class DiscardStmt : public Stmt { - SLANG_AST_CLASS(DiscardStmt) + FIDDLE(...) }; +FIDDLE() class DeclStmt : public Stmt { - SLANG_AST_CLASS(DeclStmt) - - DeclBase* decl = nullptr; + FIDDLE(...) + FIDDLE() DeclBase* decl = nullptr; }; +FIDDLE() class IfStmt : public Stmt { - SLANG_AST_CLASS(IfStmt) - - Expr* predicate = nullptr; - Stmt* positiveStatement = nullptr; - Stmt* negativeStatement = nullptr; + FIDDLE(...) + FIDDLE() Expr* predicate = nullptr; + FIDDLE() Stmt* positiveStatement = nullptr; + FIDDLE() Stmt* negativeStatement = nullptr; }; +FIDDLE() class UniqueStmtIDNode : public Decl { - SLANG_AST_CLASS(UniqueStmtIDNode) + FIDDLE(...) }; // A statement that can be escaped with a `break` +FIDDLE(abstract) class BreakableStmt : public ScopeStmt { - SLANG_ABSTRACT_AST_CLASS(BreakableStmt) + FIDDLE(...) /// A unique ID for this statement. /// @@ -106,20 +111,21 @@ class BreakableStmt : public ScopeStmt static constexpr UniqueID kInvalidUniqueID = nullptr; }; +FIDDLE() class SwitchStmt : public BreakableStmt { - SLANG_AST_CLASS(SwitchStmt) - - Expr* condition = nullptr; - Stmt* body = nullptr; + FIDDLE(...) + FIDDLE() Expr* condition = nullptr; + FIDDLE() Stmt* body = nullptr; }; // A statement that is expected to appear lexically nested inside // some other construct, and thus needs to keep track of the // outer statement that it is associated with... +FIDDLE(abstract) class ChildStmt : public Stmt { - SLANG_ABSTRACT_AST_CLASS(ChildStmt) + FIDDLE(...) /// The unique ID of the enclosing statement this /// child statement refers to. @@ -127,33 +133,35 @@ class ChildStmt : public Stmt BreakableStmt::UniqueID targetOuterStmtID = BreakableStmt::kInvalidUniqueID; }; +FIDDLE() class TargetCaseStmt : public ChildStmt { - SLANG_AST_CLASS(TargetCaseStmt) - int32_t capability; - Token capabilityToken; - Stmt* body = nullptr; + FIDDLE(...) + FIDDLE() int32_t capability; + FIDDLE() Token capabilityToken; + FIDDLE() Stmt* body = nullptr; }; +FIDDLE() class TargetSwitchStmt : public BreakableStmt { - SLANG_AST_CLASS(TargetSwitchStmt) - - List<TargetCaseStmt*> targetCases; + FIDDLE(...) + FIDDLE() List<TargetCaseStmt*> targetCases; }; +FIDDLE() class StageSwitchStmt : public TargetSwitchStmt { - SLANG_AST_CLASS(StageSwitchStmt) + FIDDLE(...) }; +FIDDLE() class IntrinsicAsmStmt : public Stmt { - SLANG_AST_CLASS(IntrinsicAsmStmt) + FIDDLE(...) + FIDDLE() String asmText; - String asmText; - - List<Expr*> args; + FIDDLE() List<Expr*> args; }; // a `case` or `default` statement inside a `switch` @@ -161,129 +169,136 @@ class IntrinsicAsmStmt : public Stmt // Note(tfoley): A correct AST for a C-like language would treat // these as a labelled statement, and so they would contain a // sub-statement. I'm leaving that out for now for simplicity. +FIDDLE(abstract) class CaseStmtBase : public ChildStmt { - SLANG_ABSTRACT_AST_CLASS(CaseStmtBase) + FIDDLE(...) }; // a `case` statement inside a `switch` +FIDDLE() class CaseStmt : public CaseStmtBase { - SLANG_AST_CLASS(CaseStmt) - - Expr* expr = nullptr; + FIDDLE(...) + FIDDLE() Expr* expr = nullptr; - Val* exprVal = nullptr; + FIDDLE() Val* exprVal = nullptr; }; // a `default` statement inside a `switch` +FIDDLE() class DefaultStmt : public CaseStmtBase { - SLANG_AST_CLASS(DefaultStmt) + FIDDLE(...) }; // a `default` statement inside a `switch` +FIDDLE() class GpuForeachStmt : public ScopeStmt { - SLANG_AST_CLASS(GpuForeachStmt) - - Expr* device = nullptr; - Expr* gridDims = nullptr; - VarDecl* dispatchThreadID = nullptr; - Expr* kernelCall = nullptr; + FIDDLE(...) + FIDDLE() Expr* device = nullptr; + FIDDLE() Expr* gridDims = nullptr; + FIDDLE() VarDecl* dispatchThreadID = nullptr; + FIDDLE() Expr* kernelCall = nullptr; }; // A statement that represents a loop, and can thus be escaped with a `continue` +FIDDLE(abstract) class LoopStmt : public BreakableStmt { - SLANG_ABSTRACT_AST_CLASS(LoopStmt) + FIDDLE(...) }; // A `for` statement +FIDDLE() class ForStmt : public LoopStmt { - SLANG_AST_CLASS(ForStmt) - - Stmt* initialStatement = nullptr; - Expr* sideEffectExpression = nullptr; - Expr* predicateExpression = nullptr; - Stmt* statement = nullptr; + FIDDLE(...) + FIDDLE() Stmt* initialStatement = nullptr; + FIDDLE() Expr* sideEffectExpression = nullptr; + FIDDLE() Expr* predicateExpression = nullptr; + FIDDLE() Stmt* statement = nullptr; }; // A `for` statement in a language that doesn't restrict the scope // of the loop variable to the body. +FIDDLE() class UnscopedForStmt : public ForStmt { - SLANG_AST_CLASS(UnscopedForStmt); + FIDDLE(...) }; +FIDDLE() class WhileStmt : public LoopStmt { - SLANG_AST_CLASS(WhileStmt) - - Expr* predicate = nullptr; - Stmt* statement = nullptr; + FIDDLE(...) + FIDDLE() Expr* predicate = nullptr; + FIDDLE() Stmt* statement = nullptr; }; +FIDDLE() class DoWhileStmt : public LoopStmt { - SLANG_AST_CLASS(DoWhileStmt) - - Stmt* statement = nullptr; - Expr* predicate = nullptr; + FIDDLE(...) + FIDDLE() Stmt* statement = nullptr; + FIDDLE() Expr* predicate = nullptr; }; // A compile-time, range-based `for` loop, which will not appear in the output code +FIDDLE() class CompileTimeForStmt : public ScopeStmt { - SLANG_AST_CLASS(CompileTimeForStmt) - - VarDecl* varDecl = nullptr; - Expr* rangeBeginExpr = nullptr; - Expr* rangeEndExpr = nullptr; - Stmt* body = nullptr; - IntVal* rangeBeginVal = nullptr; - IntVal* rangeEndVal = nullptr; + FIDDLE(...) + FIDDLE() VarDecl* varDecl = nullptr; + FIDDLE() Expr* rangeBeginExpr = nullptr; + FIDDLE() Expr* rangeEndExpr = nullptr; + FIDDLE() Stmt* body = nullptr; + FIDDLE() IntVal* rangeBeginVal = nullptr; + FIDDLE() IntVal* rangeEndVal = nullptr; }; // The case of child statements that do control flow relative // to their parent statement. +FIDDLE(abstract) class JumpStmt : public ChildStmt { - SLANG_ABSTRACT_AST_CLASS(JumpStmt) + FIDDLE(...) }; +FIDDLE() class BreakStmt : public JumpStmt { - SLANG_AST_CLASS(BreakStmt) - + FIDDLE(...) Token targetLabel; }; +FIDDLE() class ContinueStmt : public JumpStmt { - SLANG_AST_CLASS(ContinueStmt) + FIDDLE(...) }; +FIDDLE() class ReturnStmt : public Stmt { - SLANG_AST_CLASS(ReturnStmt) - - Expr* expression = nullptr; + FIDDLE(...) + FIDDLE() Expr* expression = nullptr; }; +FIDDLE() class DeferStmt : public Stmt { - SLANG_AST_CLASS(DeferStmt) + FIDDLE(...) - Stmt* statement = nullptr; + FIDDLE() Stmt* statement = nullptr; }; +FIDDLE() class ExpressionStmt : public Stmt { - SLANG_AST_CLASS(ExpressionStmt) - - Expr* expression = nullptr; + FIDDLE(...) + FIDDLE() Expr* expression = nullptr; }; } // namespace Slang diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index d59b6b286..3ac352f0a 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -1,3 +1,4 @@ +// slang-ast-support-types.cpp #include "slang-ast-support-types.h" #include "slang-ast-base.h" diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index b3baee98f..87715d9e0 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -6,1716 +6,1731 @@ #include "../compiler-core/slang-name.h" #include "../core/slang-basic.h" #include "../core/slang-semantic-version.h" -#include "slang-ast-reflect.h" -#include "slang-generated-ast.h" +#include "slang-ast-forward-declarations.h" +#include "slang-ast-support-types.h.fiddle" #include "slang-profile.h" -#include "slang-ref-object-reflect.h" -#include "slang-serialize-reflection.h" #include "slang-type-system-shared.h" #include "slang.h" #include <assert.h> #include <type_traits> -namespace Slang -{ -class Module; -class Name; -class Session; -class SyntaxVisitor; -class FuncDecl; -class Layout; - -struct IExprVisitor; -struct IDeclVisitor; -struct IModifierVisitor; -struct IStmtVisitor; -struct ITypeVisitor; -struct IValVisitor; - -class Parser; -class SyntaxNode; - -class Decl; -struct QualType; -class Type; -struct TypeExp; -class Val; - -class NodeBase; -class LookupDeclRef; -class GenericAppDeclRef; -struct CapabilitySet; - -template<typename T> -T* as(NodeBase* node); - -template<typename T> -const T* as(const NodeBase* node); - -void printDiagnosticArg(StringBuilder& sb, Decl* decl); -void printDiagnosticArg(StringBuilder& sb, Type* type); -void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); -void printDiagnosticArg(StringBuilder& sb, QualType const& type); -void printDiagnosticArg(StringBuilder& sb, Val* val); -void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase); -void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType); -void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& set); -void printDiagnosticArg(StringBuilder& sb, List<CapabilityAtom>& set); - -struct QualifiedDeclPath -{ - DeclRefBase* declRef; - QualifiedDeclPath() = default; - QualifiedDeclPath(DeclRefBase* declRef) - : declRef(declRef) - { - } -}; -// Prints the fully qualified decl name. -void printDiagnosticArg(StringBuilder& sb, QualifiedDeclPath path); +#define SLANG_UNREFLECTED /* empty */ +FIDDLE(hidden class RefObject;) -class SyntaxNode; -SourceLoc getDiagnosticPos(SyntaxNode const* syntax); -SourceLoc getDiagnosticPos(TypeExp const& typeExp); -SourceLoc getDiagnosticPos(DeclRefBase* declRef); -SourceLoc getDiagnosticPos(Decl* decl); - -typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); - -typedef unsigned int ConversionCost; -enum : ConversionCost +FIDDLE() namespace Slang { - // No conversion at all - kConversionCost_None = 0, - - kConversionCost_GenericParamUpcast = 1, - kConversionCost_UnconstraintGenericParam = 20, - kConversionCost_SizedArrayToUnsizedArray = 30, - - // Convert between matrices of different layout - kConversionCost_MatrixLayout = 5, - - // Conversion from a buffer to the type it carries needs to add a minimal - // extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>` - // from one on `Foo` - kConversionCost_GetRef = 5, - kConversionCost_ImplicitDereference = 10, - kConversionCost_InRangeIntLitConversion = 23, - kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, - kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, - - kConversionCost_MutablePtrToConstPtr = 20, - - // Conversions based on explicit sub-typing relationships are the cheapest - // - // TODO(tfoley): We will eventually need a discipline for ranking - // when two up-casts are comparable. - kConversionCost_CastToInterface = 50, - - // Conversion that is lossless and keeps the "kind" of the value the same - kConversionCost_BoolToInt = - 120, // Converting bool to int has lower cost than other integer types to prevent ambiguity. - kConversionCost_RankPromotion = 150, - kConversionCost_NoneToOptional = 150, - kConversionCost_ValToOptional = 150, - kConversionCost_NullPtrToPtr = 150, - kConversionCost_PtrToVoidPtr = 150, - - // Conversions that are lossless, but change "kind" - kConversionCost_UnsignedToSignedPromotion = 200, - - // Same-size size unsigned->signed conversions are potentially lossy, but they are commonly - // allowed silently. - kConversionCost_SameSizeUnsignedToSignedConversion = 300, - - // Conversion from signed->unsigned integer of same or greater size - kConversionCost_SignedToUnsignedConversion = 250, +#define SLANG_AST_NODE_VIRTUAL_CALL(CLASS, METHOD, ARGS) \ + return ASTNodeDispatcher<CLASS, decltype(this->METHOD ARGS)>::dispatch( \ + this, \ + [&](auto _this) -> decltype(this->METHOD ARGS) \ + { return _this->_##METHOD##Override ARGS; }); + + class Module; + class Name; + class Session; + class SyntaxVisitor; + class FuncDecl; + class Layout; + + class Parser; + class SyntaxNode; + + class Decl; + struct QualType; + class Type; + struct TypeExp; + class Val; + + class DeclRefBase; + class NodeBase; + class LookupDeclRef; + class GenericAppDeclRef; + struct CapabilitySet; - // Cost of converting an integer to a floating-point type - kConversionCost_IntegerToFloatConversion = 400, - - // Cost of converting a pointer to bool - kConversionCost_PtrToBool = 400, - - // Cost of converting an integer to int16_t - kConversionCost_IntegerTruncate = 450, + template<typename T> + T* as(NodeBase * node); - // Cost of converting an integer to a half type - kConversionCost_IntegerToHalfConversion = 500, + template<typename T> + const T* as(const NodeBase* node); + + void printDiagnosticArg(StringBuilder & sb, Decl * decl); + void printDiagnosticArg(StringBuilder & sb, Type * type); + void printDiagnosticArg(StringBuilder & sb, TypeExp const& type); + void printDiagnosticArg(StringBuilder & sb, QualType const& type); + void printDiagnosticArg(StringBuilder & sb, Val * val); + void printDiagnosticArg(StringBuilder & sb, DeclRefBase * declRefBase); + void printDiagnosticArg(StringBuilder & sb, ASTNodeType nodeType); + void printDiagnosticArg(StringBuilder & sb, const CapabilitySet& set); + void printDiagnosticArg(StringBuilder & sb, List<CapabilityAtom> & set); + + struct QualifiedDeclPath + { + DeclRefBase* declRef; + QualifiedDeclPath() = default; + QualifiedDeclPath(DeclRefBase* declRef) + : declRef(declRef) + { + } + }; + // Prints the fully qualified decl name. + void printDiagnosticArg(StringBuilder & sb, QualifiedDeclPath path); - // Cost of using a concrete argument pack - kConversionCost_ParameterPack = 500, - // Default case (usable for user-defined conversions) - kConversionCost_Default = 500, + class SyntaxNode; + SourceLoc getDiagnosticPos(SyntaxNode const* syntax); + SourceLoc getDiagnosticPos(TypeExp const& typeExp); + SourceLoc getDiagnosticPos(DeclRefBase * declRef); + SourceLoc getDiagnosticPos(Decl * decl); + + typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); + + typedef unsigned int ConversionCost; + enum : ConversionCost + { + // No conversion at all + kConversionCost_None = 0, + + kConversionCost_GenericParamUpcast = 1, + kConversionCost_UnconstraintGenericParam = 20, + kConversionCost_SizedArrayToUnsizedArray = 30, + + // Convert between matrices of different layout + kConversionCost_MatrixLayout = 5, + + // Conversion from a buffer to the type it carries needs to add a minimal + // extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>` + // from one on `Foo` + kConversionCost_GetRef = 5, + kConversionCost_ImplicitDereference = 10, + kConversionCost_InRangeIntLitConversion = 23, + kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, + kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, + + kConversionCost_MutablePtrToConstPtr = 20, + + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, + + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_BoolToInt = 120, // Converting bool to int has lower cost than other integer + // types to prevent ambiguity. + kConversionCost_RankPromotion = 150, + kConversionCost_NoneToOptional = 150, + kConversionCost_ValToOptional = 150, + kConversionCost_NullPtrToPtr = 150, + kConversionCost_PtrToVoidPtr = 150, + + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, + + // Same-size size unsigned->signed conversions are potentially lossy, but they are commonly + // allowed silently. + kConversionCost_SameSizeUnsignedToSignedConversion = 300, + + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 250, + + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, + + // Cost of converting a pointer to bool + kConversionCost_PtrToBool = 400, + + // Cost of converting an integer to int16_t + kConversionCost_IntegerTruncate = 450, + + // Cost of converting an integer to a half type + kConversionCost_IntegerToHalfConversion = 500, + + // Cost of using a concrete argument pack + kConversionCost_ParameterPack = 500, + + // Default case (usable for user-defined conversions) + kConversionCost_Default = 500, + + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, + + // This is the cost of an explicit conversion, which should + // not actually be performed. + kConversionCost_Explicit = 90000, + + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_OneVectorToScalar = 1, + kConversionCost_ScalarToVector = 2, + kConversionCost_ScalarToMatrix = 10, + kConversionCost_ScalarIntegerToFloatMatrix = + kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, + + // Additional conversion cost to add when promoting from a scalar to + // a CoopVector (this will be added to the cost, if any, of converting + // the element type of the CoopVector) + kConversionCost_ScalarToCoopVector = 1, + + // Additional cost when casting an LValue. + kConversionCost_LValueCast = 800, + + // The cost of this conversion is defined by the type coercion constraint. + kConversionCost_TypeCoercionConstraint = 1000, + kConversionCost_TypeCoercionConstraintPlusScalarToVector = + kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector, + + // Conversion is impossible + kConversionCost_Impossible = 0xFFFFFFFF, + }; - // Catch-all for conversions that should be discouraged - // (i.e., that really shouldn't be made implicitly) - // - // TODO: make these conversions not be allowed implicitly in "Slang mode" - kConversionCost_GeneralConversion = 900, - - // This is the cost of an explicit conversion, which should - // not actually be performed. - kConversionCost_Explicit = 90000, - - // Additional conversion cost to add when promoting from a scalar to - // a vector (this will be added to the cost, if any, of converting - // the element type of the vector) - kConversionCost_OneVectorToScalar = 1, - kConversionCost_ScalarToVector = 2, - kConversionCost_ScalarToMatrix = 10, - kConversionCost_ScalarIntegerToFloatMatrix = - kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, - - // Additional conversion cost to add when promoting from a scalar to - // a CoopVector (this will be added to the cost, if any, of converting - // the element type of the CoopVector) - kConversionCost_ScalarToCoopVector = 1, - - // Additional cost when casting an LValue. - kConversionCost_LValueCast = 800, - - // The cost of this conversion is defined by the type coercion constraint. - kConversionCost_TypeCoercionConstraint = 1000, - kConversionCost_TypeCoercionConstraintPlusScalarToVector = - kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector, - - // Conversion is impossible - kConversionCost_Impossible = 0xFFFFFFFF, -}; - -typedef unsigned int BuiltinConversionKind; -enum : BuiltinConversionKind -{ - kBuiltinConversion_Unknown = 0, - kBuiltinConversion_FloatToDouble = 1, -}; + typedef unsigned int BuiltinConversionKind; + enum : BuiltinConversionKind + { + kBuiltinConversion_Unknown = 0, + kBuiltinConversion_FloatToDouble = 1, + }; -enum class ImageFormat -{ + enum class ImageFormat + { #define SLANG_FORMAT(NAME, OTHER) NAME, #include "slang-image-format-defs.h" #undef SLANG_FORMAT -}; - -struct ImageFormatInfo -{ - SlangScalarType scalarType; ///< If image format is not made up of channels of set sizes this - ///< will be SLANG_SCALAR_TYPE_NONE - uint8_t channelCount; ///< The number of channels - uint8_t sizeInBytes; ///< Size in bytes - UnownedStringSlice name; ///< The name associated with this type. NOTE! Currently these names - ///< *are* the GLSL format names. -}; - -const ImageFormatInfo& getImageFormatInfo(ImageFormat format); - -bool findImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); -bool findVkImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); - -char const* getGLSLNameForImageFormat(ImageFormat format); - -// TODO(tfoley): We should ditch this enumeration -// and just use the IR opcodes that represent these -// types directly. The one major complication there -// is that the order of the enum values currently -// matters, since it determines promotion rank. -// We either need to keep that restriction, or -// look up promotion rank by some other means. -// - -class Decl; -class Val; - -// Helper type for pairing up a name and the location where it appeared -struct NameLoc -{ - Name* name; - SourceLoc loc; + }; - NameLoc() - : name(nullptr) + struct ImageFormatInfo { - } + SlangScalarType scalarType; ///< If image format is not made up of channels of set sizes + ///< this will be SLANG_SCALAR_TYPE_NONE + uint8_t channelCount; ///< The number of channels + uint8_t sizeInBytes; ///< Size in bytes + UnownedStringSlice name; ///< The name associated with this type. NOTE! Currently these + ///< names *are* the GLSL format names. + }; - explicit NameLoc(Name* inName) - : name(inName) - { - } + const ImageFormatInfo& getImageFormatInfo(ImageFormat format); + bool findImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); + bool findVkImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); - NameLoc(Name* inName, SourceLoc inLoc) - : name(inName), loc(inLoc) - { - } + char const* getGLSLNameForImageFormat(ImageFormat format); - NameLoc(Token const& token) - : name(token.getNameOrNull()), loc(token.getLoc()) - { - } -}; + // TODO(tfoley): We should ditch this enumeration + // and just use the IR opcodes that represent these + // types directly. The one major complication there + // is that the order of the enum values currently + // matters, since it determines promotion rank. + // We either need to keep that restriction, or + // look up promotion rank by some other means. + // -struct StringSliceLoc -{ - UnownedStringSlice name; - SourceLoc loc; + class Decl; + class Val; - StringSliceLoc() - : name(nullptr) - { - } - explicit StringSliceLoc(const UnownedStringSlice& inName) - : name(inName) - { - } - StringSliceLoc(const UnownedStringSlice& inName, SourceLoc inLoc) - : name(inName), loc(inLoc) + // Helper type for pairing up a name and the location where it appeared + struct NameLoc { - } - StringSliceLoc(Token const& token) - : loc(token.getLoc()) - { - Name* tokenName = token.getNameOrNull(); - if (tokenName) + Name* name; + SourceLoc loc; + + NameLoc() + : name(nullptr) { - name = tokenName->text.getUnownedSlice(); } - } -}; - -// Helper class for iterating over a list of heap-allocated modifiers -struct ModifierList -{ - struct Iterator - { - Modifier* current = nullptr; - - Modifier* operator*() { return current; } - void operator++(); + explicit NameLoc(Name* inName) + : name(inName) + { + } - bool operator!=(Iterator other) { return current != other.current; }; - Iterator() - : current(nullptr) + NameLoc(Name* inName, SourceLoc inLoc) + : name(inName), loc(inLoc) { } - Iterator(Modifier* modifier) - : current(modifier) + NameLoc(Token const& token) + : name(token.getNameOrNull()), loc(token.getLoc()) { } }; - ModifierList() - : modifiers(nullptr) + struct StringSliceLoc { - } + UnownedStringSlice name; + SourceLoc loc; - ModifierList(Modifier* modifiers) - : modifiers(modifiers) - { - } + StringSliceLoc() + : name(nullptr) + { + } + explicit StringSliceLoc(const UnownedStringSlice& inName) + : name(inName) + { + } + StringSliceLoc(const UnownedStringSlice& inName, SourceLoc inLoc) + : name(inName), loc(inLoc) + { + } + StringSliceLoc(Token const& token) + : loc(token.getLoc()) + { + Name* tokenName = token.getNameOrNull(); + if (tokenName) + { + name = tokenName->text.getUnownedSlice(); + } + } + }; - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + // Helper class for iterating over a list of heap-allocated modifiers + struct ModifierList + { + struct Iterator + { + Modifier* current = nullptr; - Modifier* modifiers = nullptr; -}; + Modifier* operator*() { return current; } -// Helper class for iterating over heap-allocated modifiers -// of a specific type. -template<typename T> -struct FilteredModifierList -{ - struct Iterator - { - Modifier* current = nullptr; + void operator++(); - T* operator*() { return (T*)current; } + bool operator!=(Iterator other) { return current != other.current; }; - void operator++(); + Iterator() + : current(nullptr) + { + } - bool operator!=(Iterator other) { return current != other.current; }; + Iterator(Modifier* modifier) + : current(modifier) + { + } + }; - Iterator() - : current(nullptr) + ModifierList() + : modifiers(nullptr) { } - Iterator(Modifier* modifier) - : current(modifier) + ModifierList(Modifier* modifiers) + : modifiers(modifiers) { } + + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } + + Modifier* modifiers = nullptr; }; - FilteredModifierList() - : modifiers(nullptr) + // Helper class for iterating over heap-allocated modifiers + // of a specific type. + template<typename T> + struct FilteredModifierList { - } + struct Iterator + { + Modifier* current = nullptr; - FilteredModifierList(Modifier* modifiers) - : modifiers(adjust(modifiers)) - { - } + T* operator*() { return (T*)current; } - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + void operator++(); - static Modifier* adjust(Modifier* modifier); + bool operator!=(Iterator other) { return current != other.current; }; - Modifier* modifiers = nullptr; -}; + Iterator() + : current(nullptr) + { + } -// A set of modifiers attached to a syntax node -struct Modifiers -{ - // The first modifier in the linked list of heap-allocated modifiers - Modifier* first = nullptr; + Iterator(Modifier* modifier) + : current(modifier) + { + } + }; - template<typename T> - FilteredModifierList<T> getModifiersOfType() - { - return FilteredModifierList<T>(first); - } + FilteredModifierList() + : modifiers(nullptr) + { + } - // Find the first modifier of a given type, or return `nullptr` if none is found. - template<typename T> - T* findModifier() - { - return *getModifiersOfType<T>().begin(); - } + FilteredModifierList(Modifier* modifiers) + : modifiers(adjust(modifiers)) + { + } - template<typename T> - bool hasModifier() - { - return findModifier<T>() != nullptr; - } + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } - /// True if has no modifiers - bool isEmpty() const { return first == nullptr; } + static Modifier* adjust(Modifier* modifier); - FilteredModifierList<Modifier>::Iterator begin() - { - return FilteredModifierList<Modifier>::Iterator(first); - } - FilteredModifierList<Modifier>::Iterator end() + Modifier* modifiers = nullptr; + }; + + // A set of modifiers attached to a syntax node + struct Modifiers { - return FilteredModifierList<Modifier>::Iterator(nullptr); - } -}; + // The first modifier in the linked list of heap-allocated modifiers + Modifier* first = nullptr; -class NamedExpressionType; -class GenericDecl; -class ContainerDecl; + template<typename T> + FilteredModifierList<T> getModifiersOfType() + { + return FilteredModifierList<T>(first); + } -// Try to extract a simple integer value from an `IntVal`. -// This fill assert-fail if the object doesn't represent a literal value. -IntegerLiteralValue getIntVal(IntVal* val); + // Find the first modifier of a given type, or return `nullptr` if none is found. + template<typename T> + T* findModifier() + { + return *getModifiersOfType<T>().begin(); + } -/// Represents how much checking has been applied to a declaration. -enum class DeclCheckState : uint8_t -{ - /// The declaration has been parsed, but - /// is otherwise completely unchecked. - /// - Unchecked, + template<typename T> + bool hasModifier() + { + return findModifier<T>() != nullptr; + } - /// The declaration is parsed and inserted into the initial scope, - /// ready for future lookups from within the parser for disambiguation purposes. - ReadyForParserLookup, + /// True if has no modifiers + bool isEmpty() const { return first == nullptr; } - /// Basic checks on the modifiers of the declaration have been applied. - /// - /// For example, when a declaration has attributes, the transformation - /// of an attribute from the parsed-but-unchecked form into a checked - /// form (in which it has the appropriate C++ subclass) happens here. - /// - ModifiersChecked, + FilteredModifierList<Modifier>::Iterator begin() + { + return FilteredModifierList<Modifier>::Iterator(first); + } + FilteredModifierList<Modifier>::Iterator end() + { + return FilteredModifierList<Modifier>::Iterator(nullptr); + } + }; - /// Wiring up scopes of namespaces with their siblings defined in different - /// files/modules, and other namespaces imported via `using`. - ScopesWired, + class NamedExpressionType; + class GenericDecl; + class ContainerDecl; - /// The type/signature of the declaration has been checked. - /// - /// For a value declaration like a variable or function, this means that - /// the type of the declaration can be queried. - /// - /// For a type declaration like a `struct` or `typedef` this means - /// that a `Type` referring to that declaration can be formed. - /// - SignatureChecked, + // Try to extract a simple integer value from an `IntVal`. + // This fill assert-fail if the object doesn't represent a literal value. + IntegerLiteralValue getIntVal(IntVal * val); - /// The declaration's basic signature has been checked to the point that - /// it is ready to be referenced in other places. - /// - /// For a function, this means that it has been organized into a - /// "redeclration group" if there are multiple functions with the - /// same name in a scope. - /// - ReadyForReference, + /// Represents how much checking has been applied to a declaration. + enum class DeclCheckState : uint8_t + { + /// The declaration has been parsed, but + /// is otherwise completely unchecked. + /// + Unchecked, - /// The declaration is ready for lookup operations to be performed. - /// - /// For type declarations (e.g., aggregate types, generic type parameters) - /// this means that any base type or constraint clauses have been - /// sufficiently checked so that we can enumerate the inheritance - /// hierarchy of the type and discover all its members. - /// - ReadyForLookup, + /// The declaration is parsed and inserted into the initial scope, + /// ready for future lookups from within the parser for disambiguation purposes. + ReadyForParserLookup, - /// Any conformance declared on the declaration have been validated. - /// - /// In particular, this step means that a "witness table" has been - /// created to show how a type satisfies the requirements of any - /// interfaces it conforms to. - /// - ReadyForConformances, + /// Basic checks on the modifiers of the declaration have been applied. + /// + /// For example, when a declaration has attributes, the transformation + /// of an attribute from the parsed-but-unchecked form into a checked + /// form (in which it has the appropriate C++ subclass) happens here. + /// + ModifiersChecked, - /// Any DeclRefTypes with substitutions have been fully resolved - /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. - /// We need a separate pass to resolve these types because `A.X` - /// maybe synthesized and made available only after conformance checking. - TypesFullyResolved, + /// Wiring up scopes of namespaces with their siblings defined in different + /// files/modules, and other namespaces imported via `using`. + ScopesWired, - /// All attributes are fully checked. This is the final step before - /// checking the function body. - AttributesChecked, + /// The type/signature of the declaration has been checked. + /// + /// For a value declaration like a variable or function, this means that + /// the type of the declaration can be queried. + /// + /// For a type declaration like a `struct` or `typedef` this means + /// that a `Type` referring to that declaration can be formed. + /// + SignatureChecked, - /// The body/definition is checked. - /// - /// This step includes any validation of the declaration that is - /// immaterial to clients code using the declaration, but that is - /// nonetheless relevant to checking correctness. - /// - /// The canonical example here is checking the body of functions. - /// Client code cannot depend on *how* a function is implemented, - /// but we still need to (eventually) check the bodies of all - /// functions, so it belongs in the last phase of checking. - /// - DefinitionChecked, - DefaultConstructorReadyForUse = DefinitionChecked, + /// The declaration's basic signature has been checked to the point that + /// it is ready to be referenced in other places. + /// + /// For a function, this means that it has been organized into a + /// "redeclration group" if there are multiple functions with the + /// same name in a scope. + /// + ReadyForReference, - /// The capabilities required by the decl is infered and validated. - /// - CapabilityChecked, + /// The declaration is ready for lookup operations to be performed. + /// + /// For type declarations (e.g., aggregate types, generic type parameters) + /// this means that any base type or constraint clauses have been + /// sufficiently checked so that we can enumerate the inheritance + /// hierarchy of the type and discover all its members. + /// + ReadyForLookup, - // For convenience at sites that call `ensureDecl()`, we define - // some aliases for the above states that are expressed in terms - // of what client code needs to be able to do with a declaration. - // - // These aliases can be changed over time if we decide to add - // more phases to semantic checking. - - CanEnumerateBases = ReadyForLookup, - CanUseBaseOfInheritanceDecl = ReadyForLookup, - CanUseTypeOfValueDecl = ReadyForReference, - CanUseExtensionTargetType = ReadyForLookup, - CanUseAsType = ReadyForReference, - CanUseFuncSignature = ReadyForReference, - CanSpecializeGeneric = ReadyForReference, - CanReadInterfaceRequirements = ReadyForLookup, -}; - -/// A `DeclCheckState` plus a bit to track whether a declaration is currently being checked. -struct DeclCheckStateExt -{ - SLANG_VALUE_CLASS(DeclCheckStateExt) + /// Any conformance declared on the declaration have been validated. + /// + /// In particular, this step means that a "witness table" has been + /// created to show how a type satisfies the requirements of any + /// interfaces it conforms to. + /// + ReadyForConformances, - typedef uint8_t RawType; - DeclCheckStateExt() {} - DeclCheckStateExt(DeclCheckState state) - : m_raw(uint8_t(state)) - { - } + /// Any DeclRefTypes with substitutions have been fully resolved + /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. + /// We need a separate pass to resolve these types because `A.X` + /// maybe synthesized and made available only after conformance checking. + TypesFullyResolved, - enum : RawType - { - /// A flag to indicate that a declaration is being checked. + /// All attributes are fully checked. This is the final step before + /// checking the function body. + AttributesChecked, + + /// The body/definition is checked. /// - /// The value of this flag is chosen so that it can be - /// represented in the bits of a `DeclCheckState` without - /// colliding with the bits that represent actual states. + /// This step includes any validation of the declaration that is + /// immaterial to clients code using the declaration, but that is + /// nonetheless relevant to checking correctness. /// - kBeingCheckedBit = 0x80, - }; - - DeclCheckState getState() const { return DeclCheckState(m_raw & ~kBeingCheckedBit); } - void setState(DeclCheckState state) { m_raw = (m_raw & kBeingCheckedBit) | RawType(state); } + /// The canonical example here is checking the body of functions. + /// Client code cannot depend on *how* a function is implemented, + /// but we still need to (eventually) check the bodies of all + /// functions, so it belongs in the last phase of checking. + /// + DefinitionChecked, + DefaultConstructorReadyForUse = DefinitionChecked, - bool isBeingChecked() const { return (m_raw & kBeingCheckedBit) != 0; } + /// The capabilities required by the decl is infered and validated. + /// + CapabilityChecked, + + // For convenience at sites that call `ensureDecl()`, we define + // some aliases for the above states that are expressed in terms + // of what client code needs to be able to do with a declaration. + // + // These aliases can be changed over time if we decide to add + // more phases to semantic checking. + + CanEnumerateBases = ReadyForLookup, + CanUseBaseOfInheritanceDecl = ReadyForLookup, + CanUseTypeOfValueDecl = ReadyForReference, + CanUseExtensionTargetType = ReadyForLookup, + CanUseAsType = ReadyForReference, + CanUseFuncSignature = ReadyForReference, + CanSpecializeGeneric = ReadyForReference, + CanReadInterfaceRequirements = ReadyForLookup, + }; - void setIsBeingChecked(bool isBeingChecked) + /// A `DeclCheckState` plus a bit to track whether a declaration is currently being checked. + struct DeclCheckStateExt { - m_raw = (m_raw & ~kBeingCheckedBit) | (isBeingChecked ? kBeingCheckedBit : 0); - } + typedef uint8_t RawType; + DeclCheckStateExt() {} + DeclCheckStateExt(DeclCheckState state) + : m_raw(uint8_t(state)) + { + } - bool operator>=(DeclCheckState state) const { return getState() >= state; } + enum : RawType + { + /// A flag to indicate that a declaration is being checked. + /// + /// The value of this flag is chosen so that it can be + /// represented in the bits of a `DeclCheckState` without + /// colliding with the bits that represent actual states. + /// + kBeingCheckedBit = 0x80, + }; - RawType getRaw() const { return m_raw; } - void setRaw(RawType raw) { m_raw = raw; } + DeclCheckState getState() const { return DeclCheckState(m_raw & ~kBeingCheckedBit); } + void setState(DeclCheckState state) { m_raw = (m_raw & kBeingCheckedBit) | RawType(state); } - // TODO(JS): - // Unfortunately for automatic serialization to see this member, it has to be public. - // private: - RawType m_raw = 0; -}; + bool isBeingChecked() const { return (m_raw & kBeingCheckedBit) != 0; } -void addModifier(ModifiableSyntaxNode* syntax, Modifier* modifier); + void setIsBeingChecked(bool isBeingChecked) + { + m_raw = (m_raw & ~kBeingCheckedBit) | (isBeingChecked ? kBeingCheckedBit : 0); + } -void removeModifier(ModifiableSyntaxNode* syntax, Modifier* modifier); + bool operator>=(DeclCheckState state) const { return getState() >= state; } -struct QualType -{ - SLANG_VALUE_CLASS(QualType) + RawType getRaw() const { return m_raw; } + void setRaw(RawType raw) { m_raw = raw; } - Type* type = nullptr; - bool isLeftValue = false; - bool hasReadOnlyOnTarget = false; - bool isWriteOnly = false; + // TODO(JS): + // Unfortunately for automatic serialization to see this member, it has to be public. + // private: + RawType m_raw = 0; + }; - QualType() = default; + void addModifier(ModifiableSyntaxNode * syntax, Modifier * modifier); - QualType(Type* type); + void removeModifier(ModifiableSyntaxNode * syntax, Modifier * modifier); - QualType(Type* type, bool isLVal) - : QualType(type) + FIDDLE() + struct QualType { - isLeftValue = isLVal; - } + FIDDLE(...) + Type* type = nullptr; + bool isLeftValue = false; + bool hasReadOnlyOnTarget = false; + bool isWriteOnly = false; + QualType() = default; - Type* Ptr() { return type; } + QualType(Type* type); - operator Type*() { return type; } - Type* operator->() { return type; } -}; + QualType(Type* type, bool isLVal) + : QualType(type) + { + isLeftValue = isLVal; + } -class ASTBuilder; -struct ASTClassInfo -{ - struct Infos - { - const ReflectClassInfo* infos[int(ASTNodeType::CountOf)]; + Type* Ptr() { return type; } + + operator Type*() { return type; } + Type* operator->() { return type; } }; - SLANG_FORCE_INLINE static const ReflectClassInfo* getInfo(ASTNodeType type) - { - return kInfos.infos[int(type)]; - } - static const Infos kInfos; -}; -// A reference to a class of syntax node, that can be -// used to create instances on the fly -struct SyntaxClassBase -{ - SyntaxClassBase() {} + class ASTBuilder; - SyntaxClassBase(ReflectClassInfo const* inClassInfo) - : classInfo(inClassInfo) - { - } + struct SyntaxClassBase; + typedef SyntaxClassBase ReflectClassInfo; + typedef SyntaxClassBase ASTClassInfo; - void* createInstanceImpl(ASTBuilder* astBuilder) const + struct SyntaxClassInfo { - auto ci = classInfo; - if (!ci) - return nullptr; - - auto cf = ci->m_createFunc; - if (!cf) - return nullptr; + public: + char const* name; + ASTNodeType firstTag; + Count tagCount; + void* (*createFunc)(ASTBuilder*); + void (*destructFunc)(void*); - return cf(astBuilder); - } + template<typename T> + static SyntaxClassInfo* get() + { + return const_cast<SyntaxClassInfo*>(&T::kSyntaxClassInfo); + } + }; - SLANG_FORCE_INLINE bool isSubClassOfImpl(SyntaxClassBase const& super) const + // A reference to a class of syntax node, that can be + // used to create instances on the fly + struct SyntaxClassBase { - return classInfo ? classInfo->isSubClassOf(*super.classInfo) : false; - } + SyntaxClassBase() {} - ReflectClassInfo const* classInfo = nullptr; -}; + explicit SyntaxClassBase(ASTNodeType tag); -template<typename T> -struct SyntaxClass : SyntaxClassBase -{ - SyntaxClass() {} + SyntaxClassBase(SyntaxClassInfo const* info) + : _info(info) + { + } - template<typename U> - SyntaxClass( - SyntaxClass<U> const& other, - typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) - : SyntaxClassBase(other.classInfo) - { - } - T* createInstance(ASTBuilder* astBuilder) const { return (T*)createInstanceImpl(astBuilder); } + ASTNodeType getTag() const { return getInfo()->firstTag; } + UnownedTerminatedStringSlice getName() const; - SyntaxClass(const ReflectClassInfo* inClassInfo) - : SyntaxClassBase(inClassInfo) - { - } + void* createInstanceImpl(ASTBuilder* astBuilder) const; + void destructInstanceImpl(void* instance) const; - static SyntaxClass<T> getClass() { return SyntaxClass<T>(&T::kReflectClassInfo); } + bool isSubClassOf(SyntaxClassBase const& super) const; - template<typename U> - bool isSubClassOf(SyntaxClass<U> super) - { - return isSubClassOfImpl(super); - } + typedef SyntaxClassInfo Info; - template<typename U> - bool isSubClassOf() - { - return isSubClassOf(SyntaxClass<U>::getClass()); - } + Info* getInfo() const { return const_cast<Info*>(_info); } + operator Info*() const { return const_cast<Info*>(_info); } - template<typename U> - bool operator==(const SyntaxClass<U> other) const - { - return classInfo == other.classInfo; - } - template<typename U> - bool operator!=(const SyntaxClass<U> other) const - { - return classInfo != other.classInfo; - } -}; + bool operator==(SyntaxClassBase const& other) const { return _info == other._info; } -template<typename T> -SyntaxClass<T> getClass() -{ - return SyntaxClass<T>::getClass(); -} + bool operator!=(SyntaxClassBase const& other) const { return _info != other._info; } -struct SubstitutionSet -{ - DeclRefBase* declRef = nullptr; - - // The element index if the substitution is happening inside a pack expansion. - // For example, if we are substituting the pattern type of `expand each T`, where - // `T` is a type pack, then packExpansionIndex will have a value starting from 0 - // to the count of the type pack during expansion of the `expand` type when we - // substitute `each T` with the element of `T` at index `packExpansionIndex`. - int packExpansionIndex = -1; - - SubstitutionSet() = default; - SubstitutionSet(DeclRefBase* declRefBase) - : declRef(declRefBase) - { - } - explicit operator bool() const; - - template<typename F> - void forEachGenericSubstitution(F func) const; - - template<typename F> - void forEachSubstitutionArg(F func) const; - - Type* applyToType(ASTBuilder* astBuilder, Type* type) const; - DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; - - LookupDeclRef* findLookupDeclRef() const; - GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; - GenericAppDeclRef* findGenericAppDeclRef() const; - DeclRefBase* getInnerMostNodeWithSubstInfo() const; -}; - -/// An expression together with (optional) substutions to apply to it -/// -/// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. -/// Conceptually it represents the result of applying the substitutions, -/// recursively, to the given expression. -/// -/// `SubstExprBase` exists primarily to provide a non-templated base type -/// for `SubstExpr<T>`. Code should prefer to use `SubstExpr<Expr>` instead -/// of `SubstExprBase` as often as possible. -/// -struct SubstExprBase -{ -public: - /// Initialize as a null expression - SubstExprBase() {} + private: + Info const* _info = nullptr; + }; - /// Initialize as the given `expr` with no subsitutions applied - SubstExprBase(Expr* expr) - : m_expr(expr) - { - } + template<typename T> + struct SyntaxClass; - /// Initialize as the given `expr` with the given `substs` applied - SubstExprBase(Expr* expr, SubstitutionSet const& substs) - : m_expr(expr), m_substs(substs) - { - } + template<typename T> + SyntaxClass<T> getSyntaxClass(); - /// Get the underlying expression without any substitutions - Expr* getExpr() const { return m_expr; } + template<typename T = NodeBase> + struct SyntaxClass : SyntaxClassBase + { + SyntaxClass() {} - /// Get the subsitutions being applied, if any - SubstitutionSet const& getSubsts() const { return m_substs; } + template<typename U> + SyntaxClass( + SyntaxClass<U> const& other, + typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : SyntaxClassBase(other) + { + } -private: - Expr* m_expr = nullptr; - SubstitutionSet m_substs; + explicit SyntaxClass(SyntaxClassBase const& other) + : SyntaxClassBase(other) + { + } - typedef void (SubstExprBase::*SafeBool)(); - void SafeBoolTrue() {} + explicit SyntaxClass(ASTNodeType tag) + : SyntaxClassBase(tag) + { + } -public: - /// Test whether this is a non-null expression - operator SafeBool() { return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; } + explicit SyntaxClass(SyntaxClassInfo const* info) + : SyntaxClassBase(info) + { + } - /// Test whether this is a null expression - bool operator!() const { return m_expr == nullptr; } -}; + T* createInstance(ASTBuilder* astBuilder) const + { + return (T*)createInstanceImpl(astBuilder); + } + void destructInstance(T* instance) { destructInstanceImpl(instance); } -/// An expression together with (optional) substutions to apply to it -/// -/// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. -/// Conceptually it represents the result of applying the substitutions, -/// recursively, to the given expression. -/// -template<typename T> -struct SubstExpr : SubstExprBase -{ -private: - typedef SubstExprBase Super; + bool isSubClassOf(SyntaxClassBase const& other) + { + return SyntaxClassBase::isSubClassOf(other); + } -public: - /// Initialize as a null expression - SubstExpr() {} + template<typename U> + bool isSubClassOf() + { + return SyntaxClassBase::isSubClassOf(getSyntaxClass<U>()); + } + }; - /// Initialize as the given `expr` with no subsitutions applied - SubstExpr(T* expr) - : Super(expr) + template<typename T> + SyntaxClass<T> getSyntaxClass() { + return SyntaxClass<T>(SyntaxClassInfo::get<T>()); } - /// Initialize as the given `expr` with the given `substs` applied - SubstExpr(T* expr, SubstitutionSet const& substs) - : Super(expr, substs) + struct SubstitutionSet { - } + DeclRefBase* declRef = nullptr; - /// Initialize as a copy of the given `other` expression - template<typename U> - SubstExpr( - SubstExpr<U> const& other, - typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) - : Super(other.getExpr(), other.getSubsts()) - { - } + // The element index if the substitution is happening inside a pack expansion. + // For example, if we are substituting the pattern type of `expand each T`, where + // `T` is a type pack, then packExpansionIndex will have a value starting from 0 + // to the count of the type pack during expansion of the `expand` type when we + // substitute `each T` with the element of `T` at index `packExpansionIndex`. + int packExpansionIndex = -1; + + SubstitutionSet() = default; + SubstitutionSet(DeclRefBase* declRefBase) + : declRef(declRefBase) + { + } + explicit operator bool() const; - /// Get the underlying expression without any substitutions - T* getExpr() const { return (T*)Super::getExpr(); } + template<typename F> + void forEachGenericSubstitution(F func) const; + + template<typename F> + void forEachSubstitutionArg(F func) const; + + Type* applyToType(ASTBuilder* astBuilder, Type* type) const; + DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; + + LookupDeclRef* findLookupDeclRef() const; + GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; + GenericAppDeclRef* findGenericAppDeclRef() const; + DeclRefBase* getInnerMostNodeWithSubstInfo() const; + }; - /// Dynamic cast to an expression of type `U` + /// An expression together with (optional) substutions to apply to it /// - /// Returns a null expression if the cast fails, or if this expression was null. - template<typename U> - SubstExpr<U> as() + /// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + /// `SubstExprBase` exists primarily to provide a non-templated base type + /// for `SubstExpr<T>`. Code should prefer to use `SubstExpr<Expr>` instead + /// of `SubstExprBase` as often as possible. + /// + struct SubstExprBase { - return SubstExpr<U>(Slang::as<U>(getExpr()), getSubsts()); - } -}; + public: + /// Initialize as a null expression + SubstExprBase() {} -SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr); + /// Initialize as the given `expr` with no subsitutions applied + SubstExprBase(Expr* expr) + : m_expr(expr) + { + } -class ASTBuilder; + /// Initialize as the given `expr` with the given `substs` applied + SubstExprBase(Expr* expr, SubstitutionSet const& substs) + : m_expr(expr), m_substs(substs) + { + } -template<typename T> -struct DeclRef; -Module* getModule(Decl* decl); + /// Get the underlying expression without any substitutions + Expr* getExpr() const { return m_expr; } + /// Get the subsitutions being applied, if any + SubstitutionSet const& getSubsts() const { return m_substs; } -// If this is a declref to an associatedtype with a ThisTypeSubsitution, -// try to find the concrete decl that satisfies the associatedtype requirement from the -// concrete type supplied by ThisTypeSubstittution. -Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef); + private: + Expr* m_expr = nullptr; + SubstitutionSet m_substs; -template<typename T = Decl> -struct DeclRef -{ - friend class ASTBuilder; + typedef void (SubstExprBase::*SafeBool)(); + void SafeBoolTrue() {} -public: - typedef T DeclType; - DeclRefBase* declRefBase; - DeclRef() - : declRefBase(nullptr) + public: + /// Test whether this is a non-null expression + operator SafeBool() { return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; } + + /// Test whether this is a null expression + bool operator!() const { return m_expr == nullptr; } + }; + + /// An expression together with (optional) substutions to apply to it + /// + /// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + template<typename T> + struct SubstExpr : SubstExprBase { - } + private: + typedef SubstExprBase Super; - void init(DeclRefBase* base); + public: + /// Initialize as a null expression + SubstExpr() {} - DeclRef(Decl* decl); + /// Initialize as the given `expr` with no subsitutions applied + SubstExpr(T* expr) + : Super(expr) + { + } - DeclRef(DeclRefBase* base) { init(base); } + /// Initialize as the given `expr` with the given `substs` applied + SubstExpr(T* expr, SubstitutionSet const& substs) + : Super(expr, substs) + { + } - template<typename U, typename = typename EnableIf<IsConvertible<T*, U*>::Value, void>::type> - DeclRef(DeclRef<U> const& other) - : declRefBase(other.declRefBase) - { - } + /// Initialize as a copy of the given `other` expression + template<typename U> + SubstExpr( + SubstExpr<U> const& other, + typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : Super(other.getExpr(), other.getSubsts()) + { + } - T* getDecl() const; + /// Get the underlying expression without any substitutions + T* getExpr() const { return (T*)Super::getExpr(); } - Name* getName() const; + /// Dynamic cast to an expression of type `U` + /// + /// Returns a null expression if the cast fails, or if this expression was null. + template<typename U> + SubstExpr<U> as() + { + return SubstExpr<U>(Slang::as<U>(getExpr()), getSubsts()); + } + }; - SourceLoc getNameLoc() const; - SourceLoc getLoc() const; - DeclRef<ContainerDecl> getParent() const; - HashCode getHashCode() const; - Type* substitute(ASTBuilder* astBuilder, Type* type) const; + SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr * expr); - SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; + class ASTBuilder; - // Apply substitutions to a type or declaration - template<typename U> - DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const; + template<typename T> + struct DeclRef; + Module* getModule(Decl * decl); - // Apply substitutions to this declaration reference - DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; - template<typename U> - DeclRef<U> as() const - { - DeclRef<U> result = DeclRef<U>(declRefBase); - return result; - } + // If this is a declref to an associatedtype with a ThisTypeSubsitution, + // try to find the concrete decl that satisfies the associatedtype requirement from the + // concrete type supplied by ThisTypeSubstittution. + Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst( + ASTBuilder * builder, + DeclRef<Decl> declRef); - template<typename U> - bool is() const + template<typename T = Decl> + struct DeclRef { - return Slang::as<U>(static_cast<NodeBase*>(getDecl())) != nullptr; - } + friend class ASTBuilder; - operator DeclRefBase*() const { return declRefBase; } + public: + typedef T DeclType; + DeclRefBase* declRefBase; + DeclRef() + : declRefBase(nullptr) + { + } - operator DeclRef<Decl>() const { return DeclRef<Decl>(declRefBase); } + void init(DeclRefBase* base); - template<typename U> - bool equals(DeclRef<U> other) const - { - return declRefBase == other.declRefBase; - } + DeclRef(Decl* decl); - template<typename U> - bool operator==(DeclRef<U> other) const - { - return equals(other); - } + DeclRef(DeclRefBase* base) { init(base); } - template<typename U> - bool operator!=(DeclRef<U> other) const - { - return !equals(other); - } + template<typename U, typename = typename EnableIf<IsConvertible<T*, U*>::Value, void>::type> + DeclRef(DeclRef<U> const& other) + : declRefBase(other.declRefBase) + { + } - explicit operator bool() const { return declRefBase; } -}; + T* getDecl() const; -template<typename T> -inline DeclRef<T> makeDeclRef(T* decl) -{ - return DeclRef<T>(decl); -} + Name* getName() const; -SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); -DeclRef<Decl> substituteDeclRef( - SubstitutionSet const& substs, - ASTBuilder* astBuilder, - DeclRef<Decl> const& declRef); -Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + SourceLoc getNameLoc() const; + SourceLoc getLoc() const; + DeclRef<ContainerDecl> getParent() const; + HashCode getHashCode() const; + Type* substitute(ASTBuilder* astBuilder, Type* type) const; -enum class MemberFilterStyle -{ - All, ///< All members - Instance, ///< Only instance members - Static, ///< Only static (ie non instance) members -}; - -Decl* const* adjustFilterCursorImpl( - const ReflectClassInfo& clsInfo, - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end); -Decl* const* getFilterCursorByIndexImpl( - const ReflectClassInfo& clsInfo, - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end, - Index index); -Index getFilterCountImpl( - const ReflectClassInfo& clsInfo, - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end); - - -template<typename T> -Decl* const* adjustFilterCursor(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) -{ - return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end); -} - -/// Finds the element at index. If there is no element at the index (for example has too few -/// elements), returns nullptr. -template<typename T> -Decl* const* getFilterCursorByIndex( - MemberFilterStyle filterStyle, - Decl* const* ptr, - Decl* const* end, - Index index) -{ - return getFilterCursorByIndexImpl(T::kReflectClassInfo, filterStyle, ptr, end, index); -} + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; -template<typename T> -Index getFilterCount(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) -{ - return getFilterCountImpl(T::kReflectClassInfo, filterStyle, ptr, end); -} + // Apply substitutions to a type or declaration + template<typename U> + DeclRef<U> substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const; -template<typename T> -bool isFilterNonEmpty(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) -{ - return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end) != end; -} + // Apply substitutions to this declaration reference + DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; -template<typename T> -struct FilteredMemberList -{ - typedef Decl* Element; + template<typename U> + DeclRef<U> as() const + { + DeclRef<U> result = DeclRef<U>(declRefBase); + return result; + } - FilteredMemberList() - : m_begin(nullptr), m_end(nullptr) - { - } + template<typename U> + bool is() const + { + return Slang::as<U>(static_cast<NodeBase*>(getDecl())) != nullptr; + } - explicit FilteredMemberList( - List<Element> const& list, - MemberFilterStyle filterStyle = MemberFilterStyle::All) - : m_begin(adjustFilterCursor<T>(filterStyle, list.begin(), list.end())) - , m_end(list.end()) - , m_filterStyle(filterStyle) - { - } + operator DeclRefBase*() const { return declRefBase; } - struct Iterator - { - const Element* m_cursor; - const Element* m_end; - MemberFilterStyle m_filterStyle; + operator DeclRef<Decl>() const { return DeclRef<Decl>(declRefBase); } - bool operator!=(Iterator const& other) const { return m_cursor != other.m_cursor; } + template<typename U> + bool equals(DeclRef<U> other) const + { + return declRefBase == other.declRefBase; + } + + template<typename U> + bool operator==(DeclRef<U> other) const + { + return equals(other); + } - void operator++() { m_cursor = adjustFilterCursor<T>(m_filterStyle, m_cursor + 1, m_end); } + template<typename U> + bool operator!=(DeclRef<U> other) const + { + return !equals(other); + } - T* operator*() { return static_cast<T*>(*m_cursor); } + explicit operator bool() const { return declRefBase; } }; - Iterator begin() + template<typename T> + inline DeclRef<T> makeDeclRef(T * decl) { - Iterator iter = {m_begin, m_end, m_filterStyle}; - return iter; + return DeclRef<T>(decl); } - Iterator end() + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); + DeclRef<Decl> substituteDeclRef( + SubstitutionSet const& substs, + ASTBuilder* astBuilder, + DeclRef<Decl> const& declRef); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + + enum class MemberFilterStyle { - Iterator iter = {m_end, m_end, m_filterStyle}; - return iter; - } + All, ///< All members + Instance, ///< Only instance members + Static, ///< Only static (ie non instance) members + }; - // TODO(tfoley): It is ugly to have these. - // We should probably fix the call sites instead. - T* getFirst() { return *begin(); } - Index getCount() { return getFilterCount<T>(m_filterStyle, m_begin, m_end); } + Decl* const* adjustFilterCursorImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end); + Decl* const* getFilterCursorByIndexImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end, + Index index); + Index getFilterCountImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end); - T* operator[](Index index) const - { - Decl* const* ptr = getFilterCursorByIndex<T>(m_filterStyle, m_begin, m_end, index); - SLANG_ASSERT(ptr); - return static_cast<T*>(*ptr); - } - /// Returns true if empty (equivalent to getCount() == 0) - bool isEmpty() const + template<typename T> + Decl* const* adjustFilterCursor( + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end) { - /// Note we don't have to scan, because m_begin has already been adjusted, when the - /// FilteredMemberList is constructed - return m_begin == m_end; + return adjustFilterCursorImpl(getSyntaxClass<T>(), filterStyle, ptr, end); } - /// Returns true if non empty (equivalent to getCount() != 0 but faster) - bool isNonEmpty() const { return !isEmpty(); } - List<T*> toList() + /// Finds the element at index. If there is no element at the index (for example has too few + /// elements), returns nullptr. + template<typename T> + Decl* const* getFilterCursorByIndex( + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end, + Index index) { - List<T*> result; - for (auto element : (*this)) - { - result.add(element); - } - return result; + return getFilterCursorByIndexImpl(getSyntaxClass<T>(), filterStyle, ptr, end, index); } - const Element* - m_begin; ///< Is either equal to m_end, or points to first *valid* filtered member - const Element* m_end; - MemberFilterStyle m_filterStyle; -}; - -struct TransparentMemberInfo -{ - // The declaration of the transparent member - Decl* decl = nullptr; -}; - -template<typename T> -struct FilteredMemberRefList -{ - List<Decl*> const& m_decls; - DeclRef<Decl> m_parent; - MemberFilterStyle m_filterStyle; - ASTBuilder* m_astBuilder; - - FilteredMemberRefList( - ASTBuilder* astBuilder, - List<Decl*> const& decls, - DeclRef<Decl> parent, - MemberFilterStyle filterStyle = MemberFilterStyle::All) - : m_decls(decls), m_parent(parent), m_filterStyle(filterStyle), m_astBuilder(astBuilder) + template<typename T> + Index getFilterCount(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) { + return getFilterCountImpl(getSyntaxClass<T>(), filterStyle, ptr, end); } - Index getCount() const + template<typename T> + bool isFilterNonEmpty(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) { - return getFilterCount<T>(m_filterStyle, m_decls.begin(), m_decls.end()); + return adjustFilterCursorImpl(getSyntaxClass<T>(), filterStyle, ptr, end) != end; } - /// True if empty (equivalent to getCount == 0, but faster) - bool isEmpty() const { return !isNonEmpty(); } - /// True if non empty (equivalent to getCount() != 0 but faster) - bool isNonEmpty() const + template<typename T> + struct FilteredMemberList { - return isFilterNonEmpty<T>(m_filterStyle, m_decls.begin(), m_decls.end()); - } + typedef Decl* Element; - DeclRef<T> getFirstOrNull() { return isEmpty() ? DeclRef<T>() : (*this)[0]; } + FilteredMemberList() + : m_begin(nullptr), m_end(nullptr) + { + } - DeclRef<T> operator[](Index index) const - { - Decl* const* decl = - getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index); - SLANG_ASSERT(decl); - return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as<T>(); - } + explicit FilteredMemberList( + List<Element> const& list, + MemberFilterStyle filterStyle = MemberFilterStyle::All) + : m_begin(adjustFilterCursor<T>(filterStyle, list.begin(), list.end())) + , m_end(list.end()) + , m_filterStyle(filterStyle) + { + } - List<DeclRef<T>> toArray() const - { - List<DeclRef<T>> result; - for (auto d : *this) - result.add(d); - return result; - } + struct Iterator + { + const Element* m_cursor; + const Element* m_end; + MemberFilterStyle m_filterStyle; - struct Iterator - { - FilteredMemberRefList const* m_list; - Decl* const* m_ptr; - Decl* const* m_end; - MemberFilterStyle m_filterStyle; + bool operator!=(Iterator const& other) const { return m_cursor != other.m_cursor; } + + void operator++() + { + m_cursor = adjustFilterCursor<T>(m_filterStyle, m_cursor + 1, m_end); + } - Iterator() - : m_list(nullptr), m_ptr(nullptr), m_filterStyle(MemberFilterStyle::All) + T* operator*() { return static_cast<T*>(*m_cursor); } + }; + + Iterator begin() { + Iterator iter = {m_begin, m_end, m_filterStyle}; + return iter; } - Iterator( - FilteredMemberRefList const* list, - Decl* const* ptr, - Decl* const* end, - MemberFilterStyle filterStyle) - : m_list(list), m_ptr(ptr), m_end(end), m_filterStyle(filterStyle) + + Iterator end() { + Iterator iter = {m_end, m_end, m_filterStyle}; + return iter; } - bool operator!=(const Iterator& other) const { return m_ptr != other.m_ptr; } + // TODO(tfoley): It is ugly to have these. + // We should probably fix the call sites instead. + T* getFirst() { return *begin(); } + Index getCount() { return getFilterCount<T>(m_filterStyle, m_begin, m_end); } + + T* operator[](Index index) const + { + Decl* const* ptr = getFilterCursorByIndex<T>(m_filterStyle, m_begin, m_end, index); + SLANG_ASSERT(ptr); + return static_cast<T*>(*ptr); + } - void operator++() { m_ptr = adjustFilterCursor<T>(m_filterStyle, m_ptr + 1, m_end); } + /// Returns true if empty (equivalent to getCount() == 0) + bool isEmpty() const + { + /// Note we don't have to scan, because m_begin has already been adjusted, when the + /// FilteredMemberList is constructed + return m_begin == m_end; + } + /// Returns true if non empty (equivalent to getCount() != 0 but faster) + bool isNonEmpty() const { return !isEmpty(); } - DeclRef<T> operator*() + List<T*> toList() { - return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr) - .template as<T>(); + List<T*> result; + for (auto element : (*this)) + { + result.add(element); + } + return result; } + + const Element* + m_begin; ///< Is either equal to m_end, or points to first *valid* filtered member + const Element* m_end; + MemberFilterStyle m_filterStyle; }; - Iterator begin() const + struct TransparentMemberInfo { - return Iterator( - this, - adjustFilterCursor<T>(m_filterStyle, m_decls.begin(), m_decls.end()), - m_decls.end(), - m_filterStyle); - } - Iterator end() const { return Iterator(this, m_decls.end(), m_decls.end(), m_filterStyle); } -}; + // The declaration of the transparent member + Decl* decl = nullptr; + }; -// -// type Expressions -// + template<typename T> + struct FilteredMemberRefList + { + List<Decl*> const& m_decls; + DeclRef<Decl> m_parent; + MemberFilterStyle m_filterStyle; + ASTBuilder* m_astBuilder; + + FilteredMemberRefList( + ASTBuilder* astBuilder, + List<Decl*> const& decls, + DeclRef<Decl> parent, + MemberFilterStyle filterStyle = MemberFilterStyle::All) + : m_decls(decls), m_parent(parent), m_filterStyle(filterStyle), m_astBuilder(astBuilder) + { + } -// A "type expression" is a term that we expect to resolve to a type during checking. -// We store both the original syntax and the resolved type here. -struct TypeExp -{ - SLANG_VALUE_CLASS(TypeExp) - typedef TypeExp ThisType; + Index getCount() const + { + return getFilterCount<T>(m_filterStyle, m_decls.begin(), m_decls.end()); + } - TypeExp() {} - TypeExp(TypeExp const& other) - : exp(other.exp), type(other.type) - { - } - explicit TypeExp(Expr* exp) - : exp(exp) - { - } - explicit TypeExp(Type* type) - : type(type) - { - } - TypeExp(Expr* exp, Type* type) - : exp(exp), type(type) - { - } + /// True if empty (equivalent to getCount == 0, but faster) + bool isEmpty() const { return !isNonEmpty(); } + /// True if non empty (equivalent to getCount() != 0 but faster) + bool isNonEmpty() const + { + return isFilterNonEmpty<T>(m_filterStyle, m_decls.begin(), m_decls.end()); + } - Expr* exp = nullptr; - Type* type = nullptr; + DeclRef<T> getFirstOrNull() { return isEmpty() ? DeclRef<T>() : (*this)[0]; } - bool equals(Type* other); + DeclRef<T> operator[](Index index) const + { + Decl* const* decl = + getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index); + SLANG_ASSERT(decl); + return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as<T>(); + } - Type* Ptr() { return type; } - operator Type*() { return type; } - Type* operator->() { return Ptr(); } + List<DeclRef<T>> toArray() const + { + List<DeclRef<T>> result; + for (auto d : *this) + result.add(d); + return result; + } - ThisType& operator=(const ThisType& rhs) = default; + struct Iterator + { + FilteredMemberRefList const* m_list; + Decl* const* m_ptr; + Decl* const* m_end; + MemberFilterStyle m_filterStyle; + + Iterator() + : m_list(nullptr), m_ptr(nullptr), m_filterStyle(MemberFilterStyle::All) + { + } + Iterator( + FilteredMemberRefList const* list, + Decl* const* ptr, + Decl* const* end, + MemberFilterStyle filterStyle) + : m_list(list), m_ptr(ptr), m_end(end), m_filterStyle(filterStyle) + { + } + + bool operator!=(const Iterator& other) const { return m_ptr != other.m_ptr; } + + void operator++() { m_ptr = adjustFilterCursor<T>(m_filterStyle, m_ptr + 1, m_end); } + + DeclRef<T> operator*() + { + return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr) + .template as<T>(); + } + }; + + Iterator begin() const + { + return Iterator( + this, + adjustFilterCursor<T>(m_filterStyle, m_decls.begin(), m_decls.end()), + m_decls.end(), + m_filterStyle); + } + Iterator end() const { return Iterator(this, m_decls.end(), m_decls.end(), m_filterStyle); } + }; - // TypeExp accept(SyntaxVisitor* visitor); + // + // type Expressions + // - /// A global immutable TypeExp, that has no type or exp set. - static const TypeExp empty; -}; + // A "type expression" is a term that we expect to resolve to a type during checking. + // We store both the original syntax and the resolved type here. + FIDDLE() + struct TypeExp + { + FIDDLE(...) + typedef TypeExp ThisType; -// Masks to be applied when lookup up declarations -enum class LookupMask : uint8_t -{ - type = 0x1, - Function = 0x2, - Value = 0x4, - Attribute = 0x8, - SyntaxDecl = 0x10, - Default = type | Function | Value | SyntaxDecl, -}; - -/// Flags for options to be used when looking up declarations -enum class LookupOptions : uint8_t -{ - None = 0, - IgnoreBaseInterfaces = 1 << 0, - Completion = 1 << 1, ///< Lookup all applicable decls for code completion suggestions - NoDeref = 1 << 2, - ConsiderAllLocalNamesInScope = 1 << 3, - ///^ Normally we rely on the checking state of local names to determine - /// if they have been declared. If the scopes are currently - /// "under-construction" and not being checked, then it's safe to - /// consider all names we've inserted so far. This is used when - /// checking to see if a keyword is shadowed. - IgnoreInheritance = - 1 << 4, ///< Lookup only non inheritance children of a struct (including `extension`) - IgnoreTransparentMembers = 1 << 5, -}; -inline LookupOptions operator&(LookupOptions a, LookupOptions b) -{ - return (LookupOptions)((std::underlying_type_t<LookupOptions>)a & - (std::underlying_type_t<LookupOptions>)b); -} + TypeExp() {} + TypeExp(TypeExp const& other) + : exp(other.exp), type(other.type) + { + } + explicit TypeExp(Expr* exp) + : exp(exp) + { + } + explicit TypeExp(Type* type) + : type(type) + { + } + TypeExp(Expr* exp, Type* type) + : exp(exp), type(type) + { + } -class SerialRefObject; + Expr* exp = nullptr; + Type* type = nullptr; -// Make sure C++ extractor can see the base class. -SLANG_PRE_DECLARE(OBJ, class SerialRefObject) + bool equals(Type* other); -SLANG_TYPE_SET(OBJ, RefObject) -SLANG_TYPE_SET(VALUE, Value) -SLANG_TYPE_SET(AST, ASTNode) + Type* Ptr() { return type; } + operator Type*() { return type; } + Type* operator->() { return Ptr(); } -class LookupResultItem_Breadcrumb : public SerialRefObject -{ -public: - SLANG_OBJ_CLASS(LookupResultItem_Breadcrumb) + ThisType& operator=(const ThisType& rhs) = default; - enum class Kind : uint8_t - { - // The lookup process looked "through" an in-scope - // declaration to the fields inside of it, so that - // even if lookup started with a simple name `f`, - // it needs to result in a member expression `obj.f`. - Member, - - // The lookup process took a pointer(-like) value, and then - // proceeded to derefence it and look at the thing(s) - // it points to instead, so that the final expression - // needs to have `(*obj)` - Deref, - - // The lookup process saw a value `obj` of type `T` and - // took into account an in-scope constraint that says - // `T` is a subtype of some other type `U`, so that - // lookup was able to find a member through type `U` - // instead. - SuperType, - - // The lookup process considered a member of an - // enclosing type as being in scope, so that any - // reference to that member needs to use a `this` - // expression as appropriate. - This, + /// A global immutable TypeExp, that has no type or exp set. + static const TypeExp empty; }; - // The kind of lookup step that was performed - Kind kind; - - // For the `Kind::This` case, what does the implicit - // `this` or `This` parameter refer to? - // - enum class ThisParameterMode : uint8_t + // Masks to be applied when lookup up declarations + enum class LookupMask : uint8_t { - ImmutableValue, // An immutable `this` value - MutableValue, // A mutable `this` value - Type, // A `This` type - - Default = ImmutableValue, + type = 0x1, + Function = 0x2, + Value = 0x4, + Attribute = 0x8, + SyntaxDecl = 0x10, + Default = type | Function | Value | SyntaxDecl, }; - ThisParameterMode thisParameterMode = ThisParameterMode::Default; - - // As needed, a reference to the declaration that faciliated - // the lookup step. - // - // For a `Member` lookup step, this is the declaration whose - // members were implicitly pulled into scope. - // - // For a `Constraint` lookup step, this is the `ConstraintDecl` - // that serves to witness the subtype relationship. - // - DeclRef<Decl> declRef; - - Val* val = nullptr; - - // The next implicit step that the lookup process took to - // arrive at a final value. - RefPtr<LookupResultItem_Breadcrumb> next; - LookupResultItem_Breadcrumb( - Kind kind, - DeclRef<Decl> declRef, - Val* val, - RefPtr<LookupResultItem_Breadcrumb> next, - ThisParameterMode thisParameterMode = ThisParameterMode::Default) - : kind(kind), thisParameterMode(thisParameterMode), declRef(declRef), val(val), next(next) + /// Flags for options to be used when looking up declarations + enum class LookupOptions : uint8_t + { + None = 0, + IgnoreBaseInterfaces = 1 << 0, + Completion = 1 << 1, ///< Lookup all applicable decls for code completion suggestions + NoDeref = 1 << 2, + ConsiderAllLocalNamesInScope = 1 << 3, + ///^ Normally we rely on the checking state of local names to determine + /// if they have been declared. If the scopes are currently + /// "under-construction" and not being checked, then it's safe to + /// consider all names we've inserted so far. This is used when + /// checking to see if a keyword is shadowed. + IgnoreInheritance = + 1 << 4, ///< Lookup only non inheritance children of a struct (including `extension`) + IgnoreTransparentMembers = 1 << 5, + }; + inline LookupOptions operator&(LookupOptions a, LookupOptions b) { + return (LookupOptions)((std::underlying_type_t<LookupOptions>)a & + (std::underlying_type_t<LookupOptions>)b); } -protected: - // Needed for serialization - LookupResultItem_Breadcrumb() = default; -}; - -// Represents one item found during lookup -struct LookupResultItem -{ - SLANG_VALUE_CLASS(LookupResultItem) + class LookupResultItem_Breadcrumb : public RefObject + { + public: + enum class Kind : uint8_t + { + // The lookup process looked "through" an in-scope + // declaration to the fields inside of it, so that + // even if lookup started with a simple name `f`, + // it needs to result in a member expression `obj.f`. + Member, + + // The lookup process took a pointer(-like) value, and then + // proceeded to derefence it and look at the thing(s) + // it points to instead, so that the final expression + // needs to have `(*obj)` + Deref, + + // The lookup process saw a value `obj` of type `T` and + // took into account an in-scope constraint that says + // `T` is a subtype of some other type `U`, so that + // lookup was able to find a member through type `U` + // instead. + SuperType, + + // The lookup process considered a member of an + // enclosing type as being in scope, so that any + // reference to that member needs to use a `this` + // expression as appropriate. + This, + }; + + // The kind of lookup step that was performed + Kind kind; + + // For the `Kind::This` case, what does the implicit + // `this` or `This` parameter refer to? + // + enum class ThisParameterMode : uint8_t + { + ImmutableValue, // An immutable `this` value + MutableValue, // A mutable `this` value + Type, // A `This` type + + Default = ImmutableValue, + }; + ThisParameterMode thisParameterMode = ThisParameterMode::Default; + + // As needed, a reference to the declaration that faciliated + // the lookup step. + // + // For a `Member` lookup step, this is the declaration whose + // members were implicitly pulled into scope. + // + // For a `Constraint` lookup step, this is the `ConstraintDecl` + // that serves to witness the subtype relationship. + // + DeclRef<Decl> declRef; + + Val* val = nullptr; + + // The next implicit step that the lookup process took to + // arrive at a final value. + RefPtr<LookupResultItem_Breadcrumb> next; + + LookupResultItem_Breadcrumb( + Kind kind, + DeclRef<Decl> declRef, + Val* val, + RefPtr<LookupResultItem_Breadcrumb> next, + ThisParameterMode thisParameterMode = ThisParameterMode::Default) + : kind(kind) + , thisParameterMode(thisParameterMode) + , declRef(declRef) + , val(val) + , next(next) + { + } - typedef LookupResultItem_Breadcrumb Breadcrumb; + protected: + // Needed for serialization + LookupResultItem_Breadcrumb() = default; + }; - // Sometimes lookup finds an item, but there were additional - // "hops" taken to reach it. We need to remember these steps - // so that if/when we consturct a full expression we generate - // appropriate AST nodes for all the steps. - // - // We build up a list of these "breadcrumbs" while doing - // lookup, and store them alongside each item found. - // - // As an example, suppose we have an HLSL `cbuffer` declaration: - // - // cbuffer C { float4 f; } - // - // This is syntax sugar for a global-scope variable of - // type `ConstantBuffer<T>` where `T` is a `struct` containing - // all the members: - // - // struct Anon0 { float4 f; }; - // __transparent ConstantBuffer<Anon0> anon1; - // - // The `__transparent` modifier there captures the fact that - // when somebody writes `f` in their code, they expect it to - // "see through" the `cbuffer` declaration (or the global variable, - // in this case) and find the member inside. - // - // But when the user writes `f` we can't just create a simple - // `VarExpr` that refers directly to that field, because that - // doesn't actually reflect the required steps in a way that - // code generation can use. - // - // Instead we need to construct an expression like `(*anon1).f`, - // where there is are two additional steps in the process: - // - // 1. We needed to dereference the pointer-like type `ConstantBuffer<Anon0>` - // to get at a value of type `Anon0` - // 2. We needed to access a sub-field of the aggregate type `Anon0` - // - // We *could* just create these full-formed expressions during - // lookup, but this might mean creating a large number of - // AST nodes in cases where the user calls an overloaded function. - // At the very least we'd rather not heap-allocate in the common - // case where no "extra" steps need to be performed to get to - // the declarations. - // - // This is where "breadcrumbs" come in. A breadcrumb represents - // an extra "step" that must be performed to turn a declaration - // found by lookup into a valid expression to splice into the - // AST. Most of the time lookup result items don't have any - // breadcrumbs, so that no extra heap allocation takes place. - // When an item does have breadcrumbs, and it is chosen as - // the unique result (perhaps by overload resolution), then - // we can walk the list of breadcrumbs to create a full - // expression. - - - // A properly-specialized reference to the declaration that was found. - DeclRef<Decl> declRef; - - // Any breadcrumbs needed in order to turn that declaration - // reference into a well-formed expression. - // - // This is unused in the simple case where a declaration - // is being referenced directly (rather than through - // transparent members). - RefPtr<LookupResultItem_Breadcrumb> breadcrumbs; - - LookupResultItem() = default; - explicit LookupResultItem(DeclRef<Decl> declRef) - : declRef(declRef) - { - } - LookupResultItem(DeclRef<Decl> declRef, RefPtr<Breadcrumb> breadcrumbs) - : declRef(declRef), breadcrumbs(breadcrumbs) - { - } -}; + // Represents one item found during lookup + struct LookupResultItem + { + typedef LookupResultItem_Breadcrumb Breadcrumb; + + // Sometimes lookup finds an item, but there were additional + // "hops" taken to reach it. We need to remember these steps + // so that if/when we consturct a full expression we generate + // appropriate AST nodes for all the steps. + // + // We build up a list of these "breadcrumbs" while doing + // lookup, and store them alongside each item found. + // + // As an example, suppose we have an HLSL `cbuffer` declaration: + // + // cbuffer C { float4 f; } + // + // This is syntax sugar for a global-scope variable of + // type `ConstantBuffer<T>` where `T` is a `struct` containing + // all the members: + // + // struct Anon0 { float4 f; }; + // __transparent ConstantBuffer<Anon0> anon1; + // + // The `__transparent` modifier there captures the fact that + // when somebody writes `f` in their code, they expect it to + // "see through" the `cbuffer` declaration (or the global variable, + // in this case) and find the member inside. + // + // But when the user writes `f` we can't just create a simple + // `VarExpr` that refers directly to that field, because that + // doesn't actually reflect the required steps in a way that + // code generation can use. + // + // Instead we need to construct an expression like `(*anon1).f`, + // where there is are two additional steps in the process: + // + // 1. We needed to dereference the pointer-like type `ConstantBuffer<Anon0>` + // to get at a value of type `Anon0` + // 2. We needed to access a sub-field of the aggregate type `Anon0` + // + // We *could* just create these full-formed expressions during + // lookup, but this might mean creating a large number of + // AST nodes in cases where the user calls an overloaded function. + // At the very least we'd rather not heap-allocate in the common + // case where no "extra" steps need to be performed to get to + // the declarations. + // + // This is where "breadcrumbs" come in. A breadcrumb represents + // an extra "step" that must be performed to turn a declaration + // found by lookup into a valid expression to splice into the + // AST. Most of the time lookup result items don't have any + // breadcrumbs, so that no extra heap allocation takes place. + // When an item does have breadcrumbs, and it is chosen as + // the unique result (perhaps by overload resolution), then + // we can walk the list of breadcrumbs to create a full + // expression. + + + // A properly-specialized reference to the declaration that was found. + DeclRef<Decl> declRef; + + // Any breadcrumbs needed in order to turn that declaration + // reference into a well-formed expression. + // + // This is unused in the simple case where a declaration + // is being referenced directly (rather than through + // transparent members). + RefPtr<LookupResultItem_Breadcrumb> breadcrumbs; + + LookupResultItem() = default; + explicit LookupResultItem(DeclRef<Decl> declRef) + : declRef(declRef) + { + } + LookupResultItem(DeclRef<Decl> declRef, RefPtr<Breadcrumb> breadcrumbs) + : declRef(declRef), breadcrumbs(breadcrumbs) + { + } + }; -// Result of looking up a name in some lexical/semantic environment. -// Can be used to enumerate all the declarations matching that name, -// in the case where the result is overloaded. -struct LookupResult -{ - // The one item that was found, in the simple case - LookupResultItem item; + // Result of looking up a name in some lexical/semantic environment. + // Can be used to enumerate all the declarations matching that name, + // in the case where the result is overloaded. + struct LookupResult + { + // The one item that was found, in the simple case + LookupResultItem item; - // All of the items that were found, in the complex case. - // Note: if there was no overloading, then this list isn't - // used at all, to avoid allocation. - // - // Additionally, if `items` is used, then `item` *must* hold an item that - // is also in the items list (typically the first entry), as an invariant. - // Otherwise isValid/begin will not function correctly. - List<LookupResultItem> items; + // All of the items that were found, in the complex case. + // Note: if there was no overloading, then this list isn't + // used at all, to avoid allocation. + // + // Additionally, if `items` is used, then `item` *must* hold an item that + // is also in the items list (typically the first entry), as an invariant. + // Otherwise isValid/begin will not function correctly. + List<LookupResultItem> items; - // Was at least one result found? - bool isValid() const { return item.declRef.getDecl() != nullptr; } + // Was at least one result found? + bool isValid() const { return item.declRef.getDecl() != nullptr; } - bool isOverloaded() const { return items.getCount() > 1; } + bool isOverloaded() const { return items.getCount() > 1; } - Name* getName() const - { - return items.getCount() > 1 ? items[0].declRef.getName() : item.declRef.getName(); - } - LookupResultItem* begin() const - { - if (isValid()) + Name* getName() const { - if (isOverloaded()) - return const_cast<LookupResultItem*>(items.begin()); + return items.getCount() > 1 ? items[0].declRef.getName() : item.declRef.getName(); + } + LookupResultItem* begin() const + { + if (isValid()) + { + if (isOverloaded()) + return const_cast<LookupResultItem*>(items.begin()); + else + return const_cast<LookupResultItem*>(&item); + } else - return const_cast<LookupResultItem*>(&item); + return nullptr; } - else - return nullptr; - } - LookupResultItem* end() const - { - if (isValid()) + LookupResultItem* end() const { - if (isOverloaded()) - return const_cast<LookupResultItem*>(items.end()); + if (isValid()) + { + if (isOverloaded()) + return const_cast<LookupResultItem*>(items.end()); + else + return const_cast<LookupResultItem*>(&item + 1); + } else - return const_cast<LookupResultItem*>(&item + 1); + return nullptr; } - else - return nullptr; - } -}; - -// A helper to avoid having to include slang-check-impl.h in slang-syntax.h -struct SemanticsVisitor; -ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor*); - -struct LookupRequest -{ - SemanticsVisitor* semantics = nullptr; - Scope* scope = nullptr; - Scope* endScope = nullptr; + }; - // A decl to exclude from the lookup, used to exclude the current decl being checked, such as in - // typedef Foo Foo; to avoid finding itself. - Decl* declToExclude = nullptr; - LookupMask mask = LookupMask::Default; - LookupOptions options = LookupOptions::None; + // A helper to avoid having to include slang-check-impl.h in slang-syntax.h + struct SemanticsVisitor; + ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor*); - bool isCompletionRequest() const - { - return (options & LookupOptions::Completion) != LookupOptions::None; - } - bool shouldConsiderAllLocalNames() const + struct LookupRequest { - return (options & LookupOptions::ConsiderAllLocalNamesInScope) != LookupOptions::None; - } -}; + SemanticsVisitor* semantics = nullptr; + Scope* scope = nullptr; + Scope* endScope = nullptr; -struct WitnessTable; + // A decl to exclude from the lookup, used to exclude the current decl being checked, such + // as in typedef Foo Foo; to avoid finding itself. + Decl* declToExclude = nullptr; + LookupMask mask = LookupMask::Default; + LookupOptions options = LookupOptions::None; -// A value that witnesses the satisfaction of an interface -// requirement by a particular declaration or value. -struct RequirementWitness -{ - SLANG_VALUE_CLASS(RequirementWitness) + bool isCompletionRequest() const + { + return (options & LookupOptions::Completion) != LookupOptions::None; + } + bool shouldConsiderAllLocalNames() const + { + return (options & LookupOptions::ConsiderAllLocalNamesInScope) != LookupOptions::None; + } + }; - RequirementWitness() - : m_flavor(Flavor::none) - { - } + class WitnessTable; - RequirementWitness(DeclRefBase* declRef) - : m_flavor(Flavor::declRef), m_declRef(declRef) + // A value that witnesses the satisfaction of an interface + // requirement by a particular declaration or value. + struct RequirementWitness { - } + RequirementWitness() + : m_flavor(Flavor::none) + { + } - RequirementWitness(Val* val); + RequirementWitness(DeclRefBase* declRef) + : m_flavor(Flavor::declRef), m_declRef(declRef) + { + } - RequirementWitness(RefPtr<WitnessTable> witnessTable); + RequirementWitness(Val* val); - enum class Flavor - { - none, - declRef, - val, - witnessTable, - }; + RequirementWitness(RefPtr<WitnessTable> witnessTable); - Flavor getFlavor() const { return m_flavor; } + enum class Flavor + { + none, + declRef, + val, + witnessTable, + }; - DeclRef<Decl> getDeclRef() - { - SLANG_ASSERT(getFlavor() == Flavor::declRef); - return m_declRef; - } + Flavor getFlavor() const { return m_flavor; } - Val* getVal() - { - SLANG_ASSERT(getFlavor() == Flavor::val); - return m_val; - } + DeclRef<Decl> getDeclRef() + { + SLANG_ASSERT(getFlavor() == Flavor::declRef); + return m_declRef; + } - RefPtr<WitnessTable> getWitnessTable(); + Val* getVal() + { + SLANG_ASSERT(getFlavor() == Flavor::val); + return m_val; + } - RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + RefPtr<WitnessTable> getWitnessTable(); - Flavor m_flavor; - DeclRef<Decl> m_declRef; - RefPtr<RefObject> m_obj; - Val* m_val = nullptr; -}; + RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); -typedef OrderedDictionary<Decl*, RequirementWitness> RequirementDictionary; + Flavor m_flavor; + DeclRef<Decl> m_declRef; + RefPtr<RefObject> m_obj; + Val* m_val = nullptr; + }; -struct WitnessTable : SerialRefObject -{ - SLANG_OBJ_CLASS(WitnessTable) + typedef OrderedDictionary<Decl*, RequirementWitness> RequirementDictionary; - const RequirementDictionary& getRequirementDictionary() { return m_requirementDictionary; } + FIDDLE() + class WitnessTable : public RefObject + { + FIDDLE(...) + const RequirementDictionary& getRequirementDictionary() { return m_requirementDictionary; } - void add(Decl* decl, RequirementWitness const& witness); + void add(Decl* decl, RequirementWitness const& witness); - // The type that the witness table witnesses conformance to (e.g. an Interface) - Type* baseType; + // The type that the witness table witnesses conformance to (e.g. an Interface) + Type* baseType; - // The type witnessesd by the witness table (a concrete type). - Type* witnessedType; + // The type witnessesd by the witness table (a concrete type). + Type* witnessedType; - // Whether or not this witness table is an extern declaration. - bool isExtern = false; + // Whether or not this witness table is an extern declaration. + bool isExtern = false; - // Cached dictionary for looking up satisfying values. - RequirementDictionary m_requirementDictionary; + // Cached dictionary for looking up satisfying values. + RequirementDictionary m_requirementDictionary; - RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); -}; + RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + }; -struct SpecializationParam -{ - enum class Flavor + struct SpecializationParam { - GenericType, - GenericValue, - ExistentialType, - ExistentialValue, + enum class Flavor + { + GenericType, + GenericValue, + ExistentialType, + ExistentialValue, + }; + Flavor flavor; + SourceLoc loc; + NodeBase* object = nullptr; }; - Flavor flavor; - SourceLoc loc; - NodeBase* object = nullptr; -}; -typedef List<SpecializationParam> SpecializationParams; + typedef List<SpecializationParam> SpecializationParams; -struct SpecializationArg -{ - SLANG_VALUE_CLASS(SpecializationArg) - Val* val = nullptr; -}; -typedef List<SpecializationArg> SpecializationArgs; + struct SpecializationArg + { + Val* val = nullptr; + }; + typedef List<SpecializationArg> SpecializationArgs; -struct ExpandedSpecializationArg : SpecializationArg -{ - SLANG_VALUE_CLASS(ExpandedSpecializationArg) - Val* witness = nullptr; -}; -typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs; - -/// A reference-counted object to hold a list of candidate extensions -/// that might be applicable to a type based on its declaration. -/// -struct CandidateExtensionList : RefObject -{ - List<ExtensionDecl*> candidateExtensions; -}; + struct ExpandedSpecializationArg : SpecializationArg + { + Val* witness = nullptr; + }; + typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs; + /// A reference-counted object to hold a list of candidate extensions + /// that might be applicable to a type based on its declaration. + /// + FIDDLE() + class CandidateExtensionList : public RefObject + { + FIDDLE(...) + List<ExtensionDecl*> candidateExtensions; + }; -enum class DeclAssociationKind -{ - ForwardDerivativeFunc, - BackwardDerivativeFunc, - PrimalSubstituteFunc -}; -struct DeclAssociation : SerialRefObject -{ - SLANG_OBJ_CLASS(DeclAssociation) - DeclAssociationKind kind; - Decl* decl; -}; - -/// A reference-counted object to hold a list of associated decls for a decl. -/// -struct DeclAssociationList : SerialRefObject -{ - SLANG_OBJ_CLASS(DeclAssociationList) + enum class DeclAssociationKind + { + ForwardDerivativeFunc, + BackwardDerivativeFunc, + PrimalSubstituteFunc + }; - List<RefPtr<DeclAssociation>> associations; -}; + FIDDLE() + class DeclAssociation : public RefObject + { + FIDDLE(...) + DeclAssociationKind kind; + Decl* decl; + }; -/// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` -enum ParameterDirection -{ - kParameterDirection_In, ///< Copy in - kParameterDirection_Out, ///< Copy out - kParameterDirection_InOut, ///< Copy in, copy out - kParameterDirection_Ref, ///< By-reference - kParameterDirection_ConstRef, ///< By-const-reference -}; + /// A reference-counted object to hold a list of associated decls for a decl. + /// + FIDDLE() + class DeclAssociationList : public RefObject + { + FIDDLE(...) + List<RefPtr<DeclAssociation>> associations; + }; -void printDiagnosticArg(StringBuilder& sb, ParameterDirection direction); + /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` + enum ParameterDirection + { + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference + kParameterDirection_ConstRef, ///< By-const-reference + }; -/// The kind of a builtin interface requirement that can be automatically synthesized. -enum class BuiltinRequirementKind -{ - DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method - - DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement - DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement - DZeroFunc, ///< The `IDifferentiable.dzero` function requirement - DAddFunc, ///< The `IDifferentiable.dadd` function requirement - DMulFunc, ///< The `IDifferentiable.dmul` function requirement - - InitLogicalFromInt, ///< The `ILogical.__init` mtehod. - Equals, ///< The `ILogical.equals` mtehod. - LessThan, ///< The `ILogical.lessThan` mtehod. - LessThanOrEquals, ///< The `ILogical.lessThanOrEquals` mtehod. - Shl, ///< The `ILogical.shl` mtehod. - Shr, ///< The `ILogical.shr` mtehod. - BitAnd, ///< The `ILogical.bitAnd` mtehod. - BitOr, ///< The `ILogical.bitOr` mtehod. - BitXor, ///< The `ILogical.bitXor` mtehod. - BitNot, ///< The `ILogical.bitNot` mtehod. - And, ///< The `ILogical.and` mtehod. - Or, ///< The `ILogical.or` mtehod. - Not, ///< The `ILogical.not` mtehod. -}; - -enum class FunctionDifferentiableLevel -{ - None, - Forward, - Backward -}; + void printDiagnosticArg(StringBuilder & sb, ParameterDirection direction); + + /// The kind of a builtin interface requirement that can be automatically synthesized. + enum class BuiltinRequirementKind + { + DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method + + DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement + DZeroFunc, ///< The `IDifferentiable.dzero` function requirement + DAddFunc, ///< The `IDifferentiable.dadd` function requirement + DMulFunc, ///< The `IDifferentiable.dmul` function requirement + + InitLogicalFromInt, ///< The `ILogical.__init` mtehod. + Equals, ///< The `ILogical.equals` mtehod. + LessThan, ///< The `ILogical.lessThan` mtehod. + LessThanOrEquals, ///< The `ILogical.lessThanOrEquals` mtehod. + Shl, ///< The `ILogical.shl` mtehod. + Shr, ///< The `ILogical.shr` mtehod. + BitAnd, ///< The `ILogical.bitAnd` mtehod. + BitOr, ///< The `ILogical.bitOr` mtehod. + BitXor, ///< The `ILogical.bitXor` mtehod. + BitNot, ///< The `ILogical.bitNot` mtehod. + And, ///< The `ILogical.and` mtehod. + Or, ///< The `ILogical.or` mtehod. + Not, ///< The `ILogical.not` mtehod. + }; -/// Represents a markup (documentation) associated with a decl. -struct MarkupEntry : public SerialRefObject -{ - SLANG_OBJ_CLASS(MarkupEntry) + enum class FunctionDifferentiableLevel + { + None, + Forward, + Backward + }; - NodeBase* m_node; ///< The node this documentation is associated with - String m_markup; ///< The raw contents of of markup associated with the decoration - MarkupVisibility m_visibility = MarkupVisibility::Public; ///< How visible this decl is -}; + /// Represents a markup (documentation) associated with a decl. + FIDDLE() + class MarkupEntry : public RefObject + { + FIDDLE(...) + NodeBase* m_node; ///< The node this documentation is associated with + String m_markup; ///< The raw contents of of markup associated with the decoration + MarkupVisibility m_visibility = MarkupVisibility::Public; ///< How visible this decl is + }; -/// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s -/// inner most expr is `f`. -Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel); -inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) -{ - FunctionDifferentiableLevel level; - return getInnerMostExprFromHigherOrderExpr(expr, level); -} + /// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s + /// inner most expr is `f`. + Expr* getInnerMostExprFromHigherOrderExpr( + Expr * expr, + FunctionDifferentiableLevel & outDiffLevel); + inline Expr* getInnerMostExprFromHigherOrderExpr(Expr * expr) + { + FunctionDifferentiableLevel level; + return getInnerMostExprFromHigherOrderExpr(expr, level); + } -/// Get the operator name from the higher order invoke expr. -UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); + /// Get the operator name from the higher order invoke expr. + UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr * expr); -enum class DeclVisibility -{ - Private, - Internal, - Public, - Default = Internal, -}; + enum class DeclVisibility + { + Private, + Internal, + Public, + Default = Internal, + }; } // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 29a52a93a..ff4cc0d10 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1,7 +1,9 @@ // slang-ast-type.cpp +#include "slang-ast-type.h" + #include "slang-ast-builder.h" +#include "slang-ast-dispatch.h" #include "slang-ast-modifier.h" -#include "slang-generated-ast-macro.h" #include "slang-syntax.h" #include <assert.h> diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 7393092f9..dd4d2acd6 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -1,19 +1,20 @@ // slang-ast-type.h - #pragma once #include "slang-ast-base.h" +#include "slang-ast-type.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for types. // The type of a reference to an overloaded name +FIDDLE() class OverloadGroupType : public Type { - SLANG_AST_CLASS(OverloadGroupType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); @@ -21,20 +22,20 @@ class OverloadGroupType : public Type // The type of an initializer-list expression (before it has // been coerced to some other type) +FIDDLE() class InitializerListType : public Type { - SLANG_AST_CLASS(InitializerListType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); }; // The type of an expression that was erroneous +FIDDLE() class ErrorType : public Type { - SLANG_AST_CLASS(ErrorType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); @@ -42,20 +43,20 @@ class ErrorType : public Type }; // The bottom/empty type that has no values. +FIDDLE() class BottomType : public Type { - SLANG_AST_CLASS(BottomType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A type that takes the form of a reference to some declaration +FIDDLE() class DeclRefType : public Type { - SLANG_AST_CLASS(DeclRefType) - + FIDDLE(...) static Type* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); DeclRef<Decl> getDeclRef() const { return DeclRef<Decl>(as<DeclRefBase>(getOperand(0))); } @@ -83,20 +84,20 @@ bool isTypePack(Type* type); bool isAbstractTypePack(Type* type); // Base class for types that can be used in arithmetic expressions +FIDDLE(abstract) class ArithmeticExpressionType : public DeclRefType { - SLANG_ABSTRACT_AST_CLASS(ArithmeticExpressionType) - + FIDDLE(...) BasicExpressionType* getScalarType(); // Overrides should be public so base classes can access BasicExpressionType* _getScalarTypeOverride(); }; +FIDDLE() class BasicExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(BasicExpressionType) - + FIDDLE(...) BaseType getBaseType() const; // Overrides should be public so base classes can access @@ -108,45 +109,52 @@ class BasicExpressionType : public ArithmeticExpressionType // Base type for things that are built in to the compiler, // and will usually have special behavior or a custom // mapping to the IR level. +FIDDLE(abstract) class BuiltinType : public DeclRefType { - SLANG_ABSTRACT_AST_CLASS(BuiltinType) + FIDDLE(...) }; +FIDDLE(abstract) class DataLayoutType : public BuiltinType { - SLANG_ABSTRACT_AST_CLASS(DataLayoutType) + FIDDLE(...) }; +FIDDLE() class IBufferDataLayoutType : public BuiltinType { - SLANG_AST_CLASS(IBufferDataLayoutType) + FIDDLE(...) }; +FIDDLE() class DefaultDataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(DefaultDataLayoutType) + FIDDLE(...) }; +FIDDLE() class Std430DataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(Std430DataLayoutType) + FIDDLE(...) }; +FIDDLE() class Std140DataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(Std140DataLayoutType) + FIDDLE(...) }; +FIDDLE() class ScalarDataLayoutType : public DataLayoutType { - SLANG_AST_CLASS(ScalarDataLayoutType) + FIDDLE(...) }; +FIDDLE() class FeedbackType : public BuiltinType { - SLANG_AST_CLASS(FeedbackType) - + FIDDLE(...) enum class Kind : uint8_t { MinMip, /// SAMPLER_FEEDBACK_MIN_MIP @@ -156,37 +164,43 @@ class FeedbackType : public BuiltinType Kind getKind() const; }; +FIDDLE(abstract) class TextureShapeType : public BuiltinType { - SLANG_ABSTRACT_AST_CLASS(TextureShapeType) + FIDDLE(...) }; +FIDDLE() class TextureShape1DType : public TextureShapeType { - SLANG_AST_CLASS(TextureShape1DType) + FIDDLE(...) }; +FIDDLE() class TextureShape2DType : public TextureShapeType { - SLANG_AST_CLASS(TextureShape2DType) + FIDDLE(...) }; +FIDDLE() class TextureShape3DType : public TextureShapeType { - SLANG_AST_CLASS(TextureShape3DType) + FIDDLE(...) }; +FIDDLE() class TextureShapeCubeType : public TextureShapeType { - SLANG_AST_CLASS(TextureShapeCubeType) + FIDDLE(...) }; +FIDDLE() class TextureShapeBufferType : public TextureShapeType { - SLANG_AST_CLASS(TextureShapeBufferType) + FIDDLE(...) }; // Resources that contain "elements" that can be fetched +FIDDLE(abstract) class ResourceType : public BuiltinType { - SLANG_ABSTRACT_AST_CLASS(ResourceType) - + FIDDLE(...) bool isMultisample(); bool isArray(); bool isShadow(); @@ -199,286 +213,322 @@ class ResourceType : public BuiltinType void _toTextOverride(StringBuilder& out); }; +FIDDLE(abstract) class TextureTypeBase : public ResourceType { - SLANG_ABSTRACT_AST_CLASS(TextureTypeBase) - + FIDDLE(...) Val* getSampleCount(); Val* getFormat(); }; +FIDDLE() class TextureType : public TextureTypeBase { - SLANG_AST_CLASS(TextureType) + FIDDLE(...) }; // This is a base type for `image*` types, as they exist in GLSL +FIDDLE() class GLSLImageType : public TextureTypeBase { - SLANG_AST_CLASS(GLSLImageType) + FIDDLE(...) }; +FIDDLE() class SubpassInputType : public BuiltinType { - SLANG_AST_CLASS(SubpassInputType) - + FIDDLE(...) bool isMultisample(); Type* getElementType(); }; +FIDDLE() class SamplerStateType : public BuiltinType { - SLANG_AST_CLASS(SamplerStateType) - + FIDDLE(...) // Returns flavor of sampler state of this type. SamplerStateFlavor getFlavor() const; }; // Other cases of generic types known to the compiler +FIDDLE() class BuiltinGenericType : public BuiltinType { - SLANG_AST_CLASS(BuiltinGenericType) - + FIDDLE(...) Type* getElementType() const; }; // Types that behave like pointers, in that they can be // dereferenced (implicitly) to access members defined // in the element type. +FIDDLE(abstract) class PointerLikeType : public BuiltinGenericType { - SLANG_AST_CLASS(PointerLikeType) + FIDDLE(...) }; +FIDDLE() class DynamicResourceType : public BuiltinType { - SLANG_AST_CLASS(DynamicResourceType) + FIDDLE(...) }; // HLSL buffer-type resources +FIDDLE(abstract) class HLSLStructuredBufferTypeBase : public BuiltinGenericType { - SLANG_AST_CLASS(HLSLStructuredBufferTypeBase) + FIDDLE(...) }; +FIDDLE() class HLSLStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRWStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLRWStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRasterizerOrderedStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLRasterizerOrderedStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class UntypedBufferResourceType : public BuiltinType { - SLANG_AST_CLASS(UntypedBufferResourceType) + FIDDLE(...) }; +FIDDLE() class HLSLByteAddressBufferType : public UntypedBufferResourceType { - SLANG_AST_CLASS(HLSLByteAddressBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRWByteAddressBufferType : public UntypedBufferResourceType { - SLANG_AST_CLASS(HLSLRWByteAddressBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLRasterizerOrderedByteAddressBufferType : public UntypedBufferResourceType { - SLANG_AST_CLASS(HLSLRasterizerOrderedByteAddressBufferType) + FIDDLE(...) }; +FIDDLE() class RaytracingAccelerationStructureType : public UntypedBufferResourceType { - SLANG_AST_CLASS(RaytracingAccelerationStructureType) + FIDDLE(...) }; +FIDDLE() class HLSLAppendStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLAppendStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class HLSLConsumeStructuredBufferType : public HLSLStructuredBufferTypeBase { - SLANG_AST_CLASS(HLSLConsumeStructuredBufferType) + FIDDLE(...) }; +FIDDLE() class GLSLAtomicUintType : public BuiltinType { - SLANG_AST_CLASS(GLSLAtomicUintType) + FIDDLE(...) }; +FIDDLE() class HLSLPatchType : public BuiltinType { - SLANG_AST_CLASS(HLSLPatchType) - + FIDDLE(...) Type* getElementType(); IntVal* getElementCount(); }; +FIDDLE() class HLSLInputPatchType : public HLSLPatchType { - SLANG_AST_CLASS(HLSLInputPatchType) + FIDDLE(...) }; +FIDDLE() class HLSLOutputPatchType : public HLSLPatchType { - SLANG_AST_CLASS(HLSLOutputPatchType) + FIDDLE(...) }; // HLSL geometry shader output stream types +FIDDLE() class HLSLStreamOutputType : public BuiltinGenericType { - SLANG_AST_CLASS(HLSLStreamOutputType) + FIDDLE(...) }; +FIDDLE() class HLSLPointStreamType : public HLSLStreamOutputType { - SLANG_AST_CLASS(HLSLPointStreamType) + FIDDLE(...) }; +FIDDLE() class HLSLLineStreamType : public HLSLStreamOutputType { - SLANG_AST_CLASS(HLSLLineStreamType) + FIDDLE(...) }; +FIDDLE() class HLSLTriangleStreamType : public HLSLStreamOutputType { - SLANG_AST_CLASS(HLSLTriangleStreamType) + FIDDLE(...) }; // mesh shader output types +FIDDLE() class MeshOutputType : public BuiltinGenericType { - SLANG_AST_CLASS(MeshOutputType) - + FIDDLE(...) Type* getElementType(); IntVal* getMaxElementCount(); }; +FIDDLE() class VerticesType : public MeshOutputType { - SLANG_AST_CLASS(VerticesType) + FIDDLE(...) }; +FIDDLE() class IndicesType : public MeshOutputType { - SLANG_AST_CLASS(IndicesType) + FIDDLE(...) }; +FIDDLE() class PrimitivesType : public MeshOutputType { - SLANG_AST_CLASS(PrimitivesType) + FIDDLE(...) }; // +FIDDLE() class GLSLInputAttachmentType : public BuiltinType { - SLANG_AST_CLASS(GLSLInputAttachmentType) + FIDDLE(...) }; +FIDDLE() class DescriptorHandleType : public PointerLikeType { - SLANG_AST_CLASS(DescriptorHandleType) + FIDDLE(...) }; // Base class for types used when desugaring parameter block // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. +FIDDLE(abstract) class ParameterGroupType : public PointerLikeType { - SLANG_AST_CLASS(ParameterGroupType) + FIDDLE(...) }; +FIDDLE() class UniformParameterGroupType : public ParameterGroupType { - SLANG_AST_CLASS(UniformParameterGroupType) + FIDDLE(...) Type* getLayoutType(); }; +FIDDLE() class VaryingParameterGroupType : public ParameterGroupType { - SLANG_AST_CLASS(VaryingParameterGroupType) + FIDDLE(...) }; // type for HLSL `cbuffer` declarations, and `ConstantBuffer<T>` // ALso used for GLSL `uniform` blocks. +FIDDLE() class ConstantBufferType : public UniformParameterGroupType { - SLANG_AST_CLASS(ConstantBufferType) + FIDDLE(...) }; // type for HLSL `tbuffer` declarations, and `TextureBuffer<T>` +FIDDLE() class TextureBufferType : public UniformParameterGroupType { - SLANG_AST_CLASS(TextureBufferType) + FIDDLE(...) }; // type for GLSL `in` and `out` blocks +FIDDLE() class GLSLInputParameterGroupType : public VaryingParameterGroupType { - SLANG_AST_CLASS(GLSLInputParameterGroupType) + FIDDLE(...) }; +FIDDLE() class GLSLOutputParameterGroupType : public VaryingParameterGroupType { - SLANG_AST_CLASS(GLSLOutputParameterGroupType) + FIDDLE(...) }; // type for GLSL `buffer` blocks +FIDDLE() class GLSLShaderStorageBufferType : public PointerLikeType { - SLANG_AST_CLASS(GLSLShaderStorageBufferType) + FIDDLE(...) }; // type for Slang `ParameterBlock<T>` type +FIDDLE() class ParameterBlockType : public UniformParameterGroupType { - SLANG_AST_CLASS(ParameterBlockType) + FIDDLE(...) }; +FIDDLE() class ArrayExpressionType : public DeclRefType { - SLANG_AST_CLASS(ArrayExpressionType) - + FIDDLE(...) bool isUnsized(); void _toTextOverride(StringBuilder& out); Type* getElementType(); IntVal* getElementCount(); }; +FIDDLE() class AtomicType : public DeclRefType { - SLANG_AST_CLASS(AtomicType) - + FIDDLE(...) Type* getElementType(); }; +FIDDLE() class CoopVectorExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(CoopVectorExpressionType) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); BasicExpressionType* _getScalarTypeOverride(); @@ -489,10 +539,10 @@ class CoopVectorExpressionType : public ArithmeticExpressionType // The "type" of an expression that resolves to a type. // For example, in the expression `float(2)` the sub-expression, // `float` would have the type `TypeType(float)`. +FIDDLE() class TypeType : public Type { - SLANG_AST_CLASS(TypeType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); @@ -503,38 +553,43 @@ class TypeType : public Type }; // A differential pair type, e.g., `__DifferentialPair<T>` +FIDDLE() class DifferentialPairType : public ArithmeticExpressionType { - SLANG_AST_CLASS(DifferentialPairType) + FIDDLE(...) Type* getPrimalType(); }; +FIDDLE() class DifferentialPtrPairType : public ArithmeticExpressionType { - SLANG_AST_CLASS(DifferentialPtrPairType) + FIDDLE(...) Type* getPrimalRefType(); }; +FIDDLE() class DifferentiableType : public BuiltinType { - SLANG_AST_CLASS(DifferentiableType) + FIDDLE(...) }; +FIDDLE() class DifferentiablePtrType : public BuiltinType { - SLANG_AST_CLASS(DifferentiablePtrType) + FIDDLE(...) }; +FIDDLE() class DefaultInitializableType : public BuiltinType { - SLANG_AST_CLASS(DefaultInitializableType); + FIDDLE(...) }; // A vector type, e.g., `vector<T,N>` +FIDDLE() class VectorExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(VectorExpressionType) - + FIDDLE(...) // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); BasicExpressionType* _getScalarTypeOverride(); @@ -544,10 +599,10 @@ class VectorExpressionType : public ArithmeticExpressionType }; // A matrix type, e.g., `matrix<T,R,C,L>` +FIDDLE() class MatrixExpressionType : public ArithmeticExpressionType { - SLANG_AST_CLASS(MatrixExpressionType) - + FIDDLE(...) Type* getElementType(); IntVal* getRowCount(); IntVal* getColumnCount(); @@ -563,137 +618,152 @@ private: SLANG_UNREFLECTED Type* rowType = nullptr; }; +FIDDLE() class TensorViewType : public BuiltinType { - SLANG_AST_CLASS(TensorViewType) - + FIDDLE(...) Type* getElementType(); }; // Base class for built in string types +FIDDLE(abstract) class StringTypeBase : public BuiltinType { - SLANG_AST_CLASS(StringTypeBase) + FIDDLE(...) }; // The regular built-in `String` type +FIDDLE() class StringType : public StringTypeBase { - SLANG_AST_CLASS(StringType) + FIDDLE(...) }; // The string type native to the target +FIDDLE() class NativeStringType : public StringTypeBase { - SLANG_AST_CLASS(NativeStringType) + FIDDLE(...) }; // The built-in `__Dynamic` type +FIDDLE() class DynamicType : public BuiltinType { - SLANG_AST_CLASS(DynamicType) + FIDDLE(...) }; // Type built-in `__EnumType` type +FIDDLE() class EnumTypeType : public BuiltinType { - SLANG_AST_CLASS(EnumTypeType) - + FIDDLE(...) // TODO: provide accessors for the declaration, the "tag" type, etc. }; // Base class for types that map down to // simple pointers as part of code generation. +FIDDLE() class PtrTypeBase : public BuiltinType { - SLANG_AST_CLASS(PtrTypeBase) - + FIDDLE(...) // Get the type of the pointed-to value. Type* getValueType(); Val* getAddressSpace(); }; +FIDDLE() class NoneType : public BuiltinType { - SLANG_AST_CLASS(NoneType) + FIDDLE(...) }; +FIDDLE() class NullPtrType : public BuiltinType { - SLANG_AST_CLASS(NullPtrType) + FIDDLE(...) }; // A true (user-visible) pointer type, e.g., `T*` +FIDDLE() class PtrType : public PtrTypeBase { - SLANG_AST_CLASS(PtrType) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); }; /// A pointer-like type used to represent a parameter "direction" +FIDDLE() class ParamDirectionType : public PtrTypeBase { - SLANG_AST_CLASS(ParamDirectionType) + FIDDLE(...) }; // A type that represents the behind-the-scenes // logical pointer that is passed for an `out` // or `in out` parameter +FIDDLE(abstract) class OutTypeBase : public ParamDirectionType { - SLANG_AST_CLASS(OutTypeBase) + FIDDLE(...) }; // The type for an `out` parameter, e.g., `out T` +FIDDLE() class OutType : public OutTypeBase { - SLANG_AST_CLASS(OutType) + FIDDLE(...) }; // The type for an `in out` parameter, e.g., `in out T` +FIDDLE() class InOutType : public OutTypeBase { - SLANG_AST_CLASS(InOutType) + FIDDLE(...) }; +FIDDLE(abstract) class RefTypeBase : public ParamDirectionType { - SLANG_AST_CLASS(RefTypeBase) + FIDDLE(...) }; // The type for an `ref` parameter, e.g., `ref T` +FIDDLE() class RefType : public RefTypeBase { - SLANG_AST_CLASS(RefType) + FIDDLE(...) void _toTextOverride(StringBuilder& out); }; // The type for an `constref` parameter, e.g., `constref T` +FIDDLE() class ConstRefType : public RefTypeBase { - SLANG_AST_CLASS(ConstRefType) + FIDDLE(...) }; +FIDDLE() class OptionalType : public BuiltinType { - SLANG_AST_CLASS(OptionalType) + FIDDLE(...) Type* getValueType(); }; // A raw-pointer reference to an managed value. +FIDDLE() class NativeRefType : public BuiltinType { - SLANG_AST_CLASS(NativeRefType) + FIDDLE(...) Type* getValueType(); }; // A type alias of some kind (e.g., via `typedef`) +FIDDLE() class NamedExpressionType : public Type { - SLANG_AST_CLASS(NamedExpressionType) - + FIDDLE(...) DeclRef<TypeDefDecl> getDeclRef() { return as<DeclRefBase>(getOperand(0)); } // Overrides should be public so base classes can access @@ -705,10 +775,10 @@ class NamedExpressionType : public Type // A function type is defined by its parameter types // and its result type. +FIDDLE() class FuncType : public Type { - SLANG_AST_CLASS(FuncType) - + FIDDLE(...) // Construct a unary function FuncType(Type* paramType, Type* resultType, Type* errorType) { @@ -739,18 +809,19 @@ class FuncType : public Type }; // A tuple is a product of its member types +FIDDLE() class TupleType : public DeclRefType { - SLANG_AST_CLASS(TupleType) - + FIDDLE(...) Index getMemberCount() const; Type* getMember(Index i) const; Type* getTypePack() const; }; +FIDDLE() class EachType : public Type { - SLANG_AST_CLASS(EachType) + FIDDLE(...) Type* getElementType() const { return as<Type>(getOperand(0)); } DeclRefType* getElementDeclRefType() const { return as<DeclRefType>(getOperand(0)); } @@ -760,9 +831,10 @@ class EachType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class ExpandType : public Type { - SLANG_AST_CLASS(ExpandType) + FIDDLE(...) Type* getPatternType() const { return as<Type>(getOperand(0)); } Index getCapturedTypePackCount() { return getOperandCount() - 1; } Type* getCapturedTypePack(Index i) { return as<Type>(getOperand(i + 1)); } @@ -778,9 +850,10 @@ class ExpandType : public Type }; // A concrete pack of types. +FIDDLE() class ConcreteTypePack : public Type { - SLANG_AST_CLASS(ConcreteTypePack) + FIDDLE(...) ConcreteTypePack(ArrayView<Type*> types) { for (auto t : types) @@ -794,10 +867,10 @@ class ConcreteTypePack : public Type }; // The "type" of an expression that names a generic declaration. +FIDDLE() class GenericDeclRefType : public Type { - SLANG_AST_CLASS(GenericDeclRefType) - + FIDDLE(...) DeclRef<GenericDecl> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } // Overrides should be public so base classes can access @@ -808,10 +881,10 @@ class GenericDeclRefType : public Type }; // The "type" of a reference to a module or namespace +FIDDLE() class NamespaceType : public Type { - SLANG_AST_CLASS(NamespaceType) - + FIDDLE(...) DeclRef<NamespaceDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } NamespaceType(DeclRef<NamespaceDeclBase> inDeclRef) { setOperands(inDeclRef); } @@ -823,10 +896,10 @@ class NamespaceType : public Type // The concrete type for a value wrapped in an existential, accessible // when the existential is "opened" in some context. +FIDDLE() class ExtractExistentialType : public Type { - SLANG_AST_CLASS(ExtractExistentialType) - + FIDDLE(...) DeclRef<VarDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } // A reference to the original interface this type is known @@ -879,10 +952,10 @@ class ExtractExistentialType : public Type DeclRef<ThisTypeDecl> getThisTypeDeclRef(); }; +FIDDLE() class ExistentialSpecializedType : public Type { - SLANG_AST_CLASS(ExistentialSpecializedType) - + FIDDLE(...) Type* getBaseType() { return as<Type>(getOperand(0)); } ExpandedSpecializationArg getArg(Index i) { @@ -910,10 +983,10 @@ class ExistentialSpecializedType : public Type }; /// The type of `this` within a polymorphic declaration +FIDDLE() class ThisType : public DeclRefType { - SLANG_AST_CLASS(ThisType) - + FIDDLE(...) ThisType(DeclRefBase* declRef) : DeclRefType(declRef) { @@ -925,10 +998,10 @@ class ThisType : public DeclRefType /// The type of `A & B` where `A` and `B` are types /// /// A value `v` is of type `A & B` if it is both of type `A` and of type `B`. +FIDDLE() class AndType : public Type { - SLANG_AST_CLASS(AndType) - + FIDDLE(...) Type* getLeft() { return as<Type>(getOperand(0)); } Type* getRight() { return as<Type>(getOperand(1)); } @@ -940,10 +1013,10 @@ class AndType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class ModifiedType : public Type { - SLANG_AST_CLASS(ModifiedType) - + FIDDLE(...) Type* getBase() { return as<Type>(getOperand(0)); } Index getModifierCount() { return getOperandCount() - 1; } diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 7613dbe80..efb87b831 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -2,9 +2,9 @@ #include "slang-ast-val.h" #include "slang-ast-builder.h" +#include "slang-ast-dispatch.h" #include "slang-check-impl.h" #include "slang-diagnostics.h" -#include "slang-generated-ast-macro.h" #include "slang-mangle.h" #include "slang-syntax.h" @@ -17,7 +17,7 @@ namespace Slang void ValNodeDesc::init() { Hasher hasher; - hasher.hashValue(Int(type)); + hasher.hashValue(type.getTag()); for (Index i = 0; i < operands.getCount(); ++i) { // Note: we are hashing the raw pointer value rather @@ -90,7 +90,7 @@ Val* Val::defaultResolveImpl() // Default resolve implementation is to recursively resolve all operands, and lookup in // deduplication cache. ValNodeDesc newDesc; - newDesc.type = astNodeType; + newDesc.type = SyntaxClass<NodeBase>(astNodeType); bool diff = false; for (auto operand : m_operands) { diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 3a14be17b..cdfb0b51f 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -1,20 +1,21 @@ // slang-ast-val.h - #pragma once #include "slang-ast-base.h" #include "slang-ast-decl.h" +#include "slang-ast-val.h.fiddle" +FIDDLE() namespace Slang { // Syntax class definitions for compile-time values. +FIDDLE() class DirectDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(DirectDeclRef) - DirectDeclRef(Decl* decl) { setOperands(decl); } DeclRefBase* _substituteImplOverride( @@ -31,11 +32,11 @@ public: // For example, MemberDeclRef(DirectDeclRef(A), B) ==> DirectDeclRef(B), // and MemberDeclRef(MemberDeclRef(A, B), C) ==> MemberDeclRef(A, C). // +FIDDLE() class MemberDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(MemberDeclRef); - DeclRefBase* getParentOperand() { return as<DeclRefBase>(getOperand(1)); } MemberDeclRef(Decl* decl, DeclRefBase* parent) { setOperands(decl, parent); } @@ -55,11 +56,11 @@ public: // Represent a lookup of SuperType::`m_decl` from `lookupSourceType` type that we know conforms to // SuperType. +FIDDLE() class LookupDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(LookupDeclRef); - // m_decl represents the decl in SuperType that we want to lookup. // The source type that we are looking up from. @@ -91,11 +92,11 @@ private: }; // Represents a specialization of a generic decl. +FIDDLE() class GenericAppDeclRef : public DeclRefBase { + FIDDLE(...) public: - SLANG_AST_CLASS(GenericAppDeclRef); - DeclRefBase* getGenericDeclRef() { return as<DeclRefBase>(getOperand(1)); } Index getArgCount() { return getOperandCount() - 2; } Val* getArg(Index index) { return getOperand(index + 2); } @@ -137,10 +138,10 @@ public: }; // A compile-time integer (may not have a specific concrete value) +FIDDLE(abstract) class IntVal : public Val { - SLANG_ABSTRACT_AST_CLASS(IntVal) - + FIDDLE(...) Type* getType() { return as<Type>(getOperand(0)); } Val* _resolveImplOverride() { return this; } @@ -152,10 +153,10 @@ class IntVal : public Val }; // Trivial case of a value that is just a constant integer +FIDDLE() class ConstantIntVal : public IntVal { - SLANG_AST_CLASS(ConstantIntVal) - + FIDDLE(...) IntegerLiteralValue getValue() { return getIntConstOperand(1); } // Overrides should be public so base classes can access @@ -166,10 +167,10 @@ class ConstantIntVal : public IntVal }; // The logical "value" of a reference to a generic value parameter +FIDDLE() class GenericParamIntVal : public IntVal { - SLANG_AST_CLASS(GenericParamIntVal) - + FIDDLE(...) DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); } // Overrides should be public so base classes can access @@ -185,10 +186,10 @@ class GenericParamIntVal : public IntVal Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; +FIDDLE() class TypeCastIntVal : public IntVal { - SLANG_AST_CLASS(TypeCastIntVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); @@ -213,10 +214,10 @@ class TypeCastIntVal : public IntVal }; // An compile time int val as result of some general computation. +FIDDLE() class FuncCallIntVal : public IntVal { - SLANG_AST_CLASS(FuncCallIntVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); @@ -257,10 +258,10 @@ class FuncCallIntVal : public IntVal Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; +FIDDLE() class CountOfIntVal : public IntVal { - SLANG_AST_CLASS(CountOfIntVal) - + FIDDLE(...) CountOfIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); } Val* getTypeArg() { return getOperand(1); } @@ -275,10 +276,10 @@ class CountOfIntVal : public IntVal static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType); }; +FIDDLE() class WitnessLookupIntVal : public IntVal { - SLANG_AST_CLASS(WitnessLookupIntVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); @@ -300,9 +301,10 @@ class WitnessLookupIntVal : public IntVal // polynomial expression "2*a*b^3 + 1" will be represented as: // { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] } +FIDDLE() class PolynomialIntValFactor : public Val { - SLANG_AST_CLASS(PolynomialIntValFactor) + FIDDLE(...) public: IntVal* getParam() const { return as<IntVal>(getOperand(0)); } IntegerLiteralValue getPower() const { return getIntConstOperand(1); } @@ -361,9 +363,11 @@ public: return getPower() == other.getPower() && getParam()->equals(other.getParam()); } }; + +FIDDLE() class PolynomialIntValTerm : public Val { - SLANG_AST_CLASS(PolynomialIntValTerm) + FIDDLE(...) public: IntegerLiteralValue getConstFactor() const { return getIntConstOperand(0); } OperandView<PolynomialIntValFactor> getParamFactors() const @@ -440,9 +444,10 @@ public: } }; +FIDDLE() class PolynomialIntVal : public IntVal { - SLANG_AST_CLASS(PolynomialIntVal) + FIDDLE(...) public: IntegerLiteralValue getConstantTerm() { return getIntConstOperand(1); }; OperandView<PolynomialIntValTerm> getTerms() @@ -482,10 +487,10 @@ public: }; /// An unknown integer value indicating an erroneous sub-expression +FIDDLE() class ErrorIntVal : public IntVal { - SLANG_AST_CLASS(ErrorIntVal) - + FIDDLE(...) ErrorIntVal(Type* inType) { setOperands(inType); } // TODO: We should probably eventually just have an `ErrorVal` here @@ -532,9 +537,10 @@ class ErrorIntVal : public IntVal // the concrete declarations that provide the implementation // of `ILight` for `X`. // +FIDDLE(abstract) class Witness : public Val { - SLANG_ABSTRACT_AST_CLASS(Witness) + FIDDLE(...) }; // A witness that one type is a subtype of another @@ -542,10 +548,10 @@ class Witness : public Val // relationships and type-conforms-to-interface relationships) // // TODO: we may need to tease those apart. +FIDDLE(abstract) class SubtypeWitness : public Witness { - SLANG_ABSTRACT_AST_CLASS(SubtypeWitness) - + FIDDLE(...) Val* _resolveImplOverride(); Type* getSub() { return as<Type>(getOperand(0)); } @@ -555,10 +561,10 @@ class SubtypeWitness : public Witness ConversionCost getOverloadResolutionCost(); }; +FIDDLE() class TypePackSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(TypePackSubtypeWitness) - + FIDDLE(...) Type* getSub() { return as<Type>(getOperand(0)); } Type* getSup() { return as<Type>(getOperand(1)); } @@ -578,10 +584,10 @@ class TypePackSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class EachSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(EachSubtypeWitness) - + FIDDLE(...) EachSubtypeWitness(Type* sub, Type* sup, SubtypeWitness* patternWitness) { setOperands(sub, sup, patternWitness); @@ -594,10 +600,10 @@ class EachSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class ExpandSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ExpandSubtypeWitness) - + FIDDLE(...) ExpandSubtypeWitness(Type* sub, Type* sup, SubtypeWitness* patternWitness) { setOperands(sub, sup, patternWitness); @@ -610,10 +616,10 @@ class ExpandSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class TypeEqualityWitness : public SubtypeWitness { - SLANG_AST_CLASS(TypeEqualityWitness) - + FIDDLE(...) TypeEqualityWitness(Type* subType, Type* supType) { setOperands(subType, supType); } // Overrides should be public so base classes can access @@ -621,10 +627,10 @@ class TypeEqualityWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class TypeCoercionWitness : public Witness { - SLANG_AST_CLASS(TypeCoercionWitness) - + FIDDLE(...) Type* getFromType() { return as<Type>(getOperand(0)); } Type* getToType() { return as<Type>(getOperand(1)); } @@ -637,10 +643,10 @@ class TypeCoercionWitness : public Witness // A witness that one type is a subtype of another // because some in-scope declaration says so +FIDDLE() class DeclaredSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(DeclaredSubtypeWitness) - + FIDDLE(...) DeclRef<Decl> getDeclRef() { return as<DeclRefBase>(getOperand(2)); } bool isEquality() @@ -664,10 +670,10 @@ class DeclaredSubtypeWitness : public SubtypeWitness }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` +FIDDLE() class TransitiveSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(TransitiveSubtypeWitness) - + FIDDLE(...) // Witness that `sub : mid` SubtypeWitness* getSubToMid() { return as<SubtypeWitness>(getOperand(2)); } @@ -692,10 +698,10 @@ class TransitiveSubtypeWitness : public SubtypeWitness // A witness that `sub : sup` because `sub` was wrapped into // an existential of type `sup`. +FIDDLE() class ExtractExistentialSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ExtractExistentialSubtypeWitness) - + FIDDLE(...) // The declaration of the existential value that has been opened DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(2)); } @@ -711,17 +717,18 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness /// A witness of the fact that a user provided "__Dynamic" type argument is a /// subtype to the existential type parameter. +FIDDLE() class DynamicSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(DynamicSubtypeWitness) + FIDDLE(...) DynamicSubtypeWitness(Type* inSub, Type* inSup) { setOperands(inSub, inSup); } }; /// A witness that `T : L & R` because `T : L` and `T : R` +FIDDLE() class ConjunctionSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ConjunctionSubtypeWitness) - + FIDDLE(...) // At the operational level, this class of witness is // an operation that takes two witness tables `leftWitness` // and `rightWitness`, and forms a pair/tuple of @@ -750,10 +757,10 @@ class ConjunctionSubtypeWitness : public SubtypeWitness }; /// A witness that `T <: L` or `T <: R` because `T <: L&R` +FIDDLE() class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness { - SLANG_AST_CLASS(ExtractFromConjunctionSubtypeWitness) - + FIDDLE(...) // At the operational level, this class of witness is // an operation that takes a pair/tuple of witness tables // `(leftWtiness, rightWitness)` and extracts one of the @@ -785,52 +792,54 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness }; /// A value that represents a modifier attached to some other value +FIDDLE() class ModifierVal : public Val { - SLANG_AST_CLASS(ModifierVal) - + FIDDLE(...) Val* _resolveImplOverride() { return this; } }; +FIDDLE() class TypeModifierVal : public ModifierVal { - SLANG_AST_CLASS(TypeModifierVal) + FIDDLE(...) }; +FIDDLE() class ResourceFormatModifierVal : public TypeModifierVal { - SLANG_AST_CLASS(ResourceFormatModifierVal) + FIDDLE(...) }; +FIDDLE() class UNormModifierVal : public ResourceFormatModifierVal { - SLANG_AST_CLASS(UNormModifierVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class SNormModifierVal : public ResourceFormatModifierVal { - SLANG_AST_CLASS(SNormModifierVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +FIDDLE() class NoDiffModifierVal : public TypeModifierVal { - SLANG_AST_CLASS(NoDiffModifierVal) - + FIDDLE(...) void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; /// Represents the result of differentiating a function. +FIDDLE() class DifferentiateVal : public Val { - SLANG_AST_CLASS(DifferentiateVal) - + FIDDLE(...) DifferentiateVal(DeclRef<Decl> inFunc) { setOperands(inFunc); } DeclRef<Decl> getFunc() { return as<DeclRefBase>(getOperand(0)); } @@ -840,49 +849,50 @@ class DifferentiateVal : public Val Val* _resolveImplOverride(); }; +FIDDLE() class ForwardDifferentiateVal : public DifferentiateVal { - SLANG_AST_CLASS(ForwardDifferentiateVal) + FIDDLE(...) ForwardDifferentiateVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiateVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiateVal) - + FIDDLE(...) BackwardDifferentiateVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal) - + FIDDLE(...) BackwardDifferentiateIntermediateTypeVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiatePrimalVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiatePrimalVal) - + FIDDLE(...) BackwardDifferentiatePrimalVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { } }; +FIDDLE() class BackwardDifferentiatePropagateVal : public DifferentiateVal { - SLANG_AST_CLASS(BackwardDifferentiatePropagateVal) - + FIDDLE(...) BackwardDifferentiatePropagateVal(DeclRef<Decl> inFunc) : DifferentiateVal(inFunc) { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 899b04b8b..e511fbc39 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12,8 +12,8 @@ // logic also orchestrates the overall flow and how // and when things get checked. +#include "slang-ast-forward-declarations.h" #include "slang-ast-iterator.h" -#include "slang-ast-reflect.h" #include "slang-ast-synthesis.h" #include "slang-lookup.h" #include "slang-parser.h" @@ -3590,7 +3590,7 @@ bool SemanticsVisitor::doesAccessorMatchRequirement( // auto satisfyingMemberClass = satisfyingMemberDeclRef.getDecl()->getClass(); auto requiredMemberClass = requiredMemberDeclRef.getDecl()->getClass(); - if (!satisfyingMemberClass.isSubClassOfImpl(requiredMemberClass)) + if (!satisfyingMemberClass.isSubClassOf(requiredMemberClass)) return false; // We do not check the parameters or return types of accessors @@ -11261,7 +11261,7 @@ void _foreachDirectOrExtensionMemberOfType( // for (auto memberDeclRef : getMembers(semantics->getASTBuilder(), containerDeclRef)) { - if (memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) + if (memberDeclRef.getDecl()->getClass().isSubClassOf(syntaxClass)) { callback(memberDeclRef, (void*)userData); } @@ -11294,7 +11294,7 @@ void _foreachDirectOrExtensionMemberOfType( for (auto memberDeclRef : getMembers(semantics->getASTBuilder(), extDeclRef)) { - if (memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) + if (memberDeclRef.getDecl()->getClass().isSubClassOf(syntaxClass)) { callback(memberDeclRef, (void*)userData); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2f91a6a77..7b774f300 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1195,7 +1195,7 @@ Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) auto baseDiffType = tryGetDifferentialType(builder, ptrType->getValueType()); if (!baseDiffType) return nullptr; - return builder->getPtrType(baseDiffType, ptrType->getClassInfo().m_name); + return builder->getPtrType(baseDiffType, ptrType->getClass().getName()); } else if (auto arrayType = as<ArrayExpressionType>(type)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index b818c9e06..44fdf45cf 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1,4 +1,5 @@ // slang-check-overload.cpp + #include "slang-ast-base.h" #include "slang-ast-print.h" #include "slang-check-impl.h" diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp index 92a9a9d6d..e16a470de 100644 --- a/source/slang/slang-check-resolve-val.cpp +++ b/source/slang/slang-check-resolve-val.cpp @@ -2,7 +2,8 @@ // Logic for resolving/simplifying Types and DeclRefs. -#include "slang-ast-reflect.h" +#include "slang-ast-dispatch.h" +#include "slang-ast-forward-declarations.h" #include "slang-ast-synthesis.h" #include "slang-check-impl.h" #include "slang-lookup.h" diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 2c8f3d0c0..db753713b 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -266,8 +266,10 @@ bool SemanticsVisitor::CoerceToProperTypeImpl( // diagnostic. // Get the AST node type info, so we can output a 'got' name - auto info = ASTClassInfo::getInfo(originalExpr->astNodeType); - diagSink->diagnose(originalExpr, Diagnostics::expectedAType, info->m_name); + diagSink->diagnose( + originalExpr, + Diagnostics::expectedAType, + originalExpr->getClass().getName()); } } @@ -296,7 +298,12 @@ bool SemanticsVisitor::CoerceToProperTypeImpl( { if (auto typeParam = as<GenericTypeParamDecl>(member)) { - if (!typeParam->initType.exp) + if (auto defaultArg = typeParam->initType.type) + { + if (outProperType) + args.add(defaultArg); + } + else { if (diagSink) { @@ -305,10 +312,6 @@ bool SemanticsVisitor::CoerceToProperTypeImpl( } return false; } - - // TODO: this is one place where syntax should get cloned! - if (outProperType) - args.add(ExtractGenericArgVal(typeParam->initType.exp)); } else if (auto valParam = as<GenericValueParamDecl>(member)) { diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 9815f6ff1..8e9b8f430 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2173,17 +2173,7 @@ SlangResult EndToEndCompileRequest::writeContainerToStream(Stream* stream) options.sourceManager = linkage->getSourceManager(); } - { - RiffContainer container; - { - SerialContainerData data; - SLANG_RETURN_ON_FAIL( - SerialContainerUtil::addEndToEndRequestToData(this, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); - } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - } + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, options, stream)); return SLANG_OK; } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 26a4bb43b..bfae6e400 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -34,6 +34,7 @@ namespace Slang struct PathInfo; struct IncludeHandler; struct SharedSemanticsContext; +struct ModuleChunkRef; class ProgramLayout; class PtrType; @@ -2086,8 +2087,39 @@ struct ContainerTypeKey } }; -/// A dictionary of currently loaded modules. Used by `findOrImportModule` to -/// lookup additional loaded modules. +/// A dictionary of modules to be considered when resolving `import`s, +/// beyond those that would normally be found through a `Linkage`. +/// +/// Checking of an `import` declaration will bottleneck through +/// `Linkage::findOrImportModule`, which would usually just check for +/// any module that had been previously loaded into the same `Linkage` +/// (e.g., by a call to `Linkage::loadModule()`). +/// +/// In the case where compilation is being done through an +/// explicit `FrontEndCompileRequest` or `EndToEndCompileRequest`, +/// the modules being compiled by that request do not get added to +/// the surrounding `Linkage`. +/// +/// There is a corner case when an explicit compile request has +/// multiple `TranslationUnitRequest`s, because the user (reasonably) +/// expects that if they compile `A.slang` and `B.slang` as two +/// distinct translation units in the same compile request, then +/// an `import B` inside of `A.slang` should resolve to reference +/// the code of `B.slang`. But because neither `A` nor `B` gets +/// added to the `Linkage`, and the `Linkage` is what usually +/// determines what is or isn't loaded, that intuition will +/// be wrong, without a bit of help. +/// +/// The `LoadedModuleDictionary` is thus filled in by a +/// `FrontEndCompileRequest` to collect the modules it is compiling, +/// so that they can cross-reference one another (albeit with +/// a current implementation restriction that modules in the +/// request can only `import` those earlier in the request...). +/// +/// The dictionary then gets passed around between nearly all of +/// the operations that deal with loading modules, to make sure +/// that they can detect a previously loaded module. +/// typedef Dictionary<Name*, Module*> LoadedModuleDictionary; enum ModuleBlobType @@ -2096,8 +2128,6 @@ enum ModuleBlobType IR }; -struct SerialContainerDataModule; - /// A context for loading and re-using code modules. class Linkage : public RefObject, public slang::ISession { @@ -2287,7 +2317,15 @@ public: /// Add a new target and return its index. UInt addTarget(CodeGenTarget target); - RefPtr<Module> loadModule( + /// "Bottleneck" routine for loading a module. + /// + /// All attempts to load a module, whether through + /// Slang API calls, `import` operations, or other + /// means, should bottleneck through `loadModuleImpl`, + /// or one of the specialized cases `loadSourceModuleImpl` + /// and `loadBinaryModuleImpl`. + /// + RefPtr<Module> loadModuleImpl( Name* name, const PathInfo& filePathInfo, ISlangBlob* fileContentsBlob, @@ -2296,17 +2334,49 @@ public: const LoadedModuleDictionary* additionalLoadedModules, ModuleBlobType blobType); - RefPtr<Module> loadModuleFromIRBlobImpl( + RefPtr<Module> loadSourceModuleImpl( Name* name, const PathInfo& filePathInfo, ISlangBlob* fileContentsBlob, SourceLoc const& loc, DiagnosticSink* sink, const LoadedModuleDictionary* additionalLoadedModules); - RefPtr<Module> loadDeserializedModule( + + RefPtr<Module> loadBinaryModuleImpl( Name* name, const PathInfo& filePathInfo, - SerialContainerDataModule& m, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink); + + /// Either finds a previously-loaded module matching what + /// was serialized into `moduleChunk`, or else attempts + /// to load the serialized module. + /// + /// If a previously-loaded module is found that matches the + /// name or path information in `moduleChunk`, then that + /// previously-loaded module is returned. + /// + /// Othwerise, attempts to load a module from `moduleChunk` + /// and, if successful, returns the freshly loaded module. + /// + /// Otherwise, return null. + /// + RefPtr<Module> findOrLoadSerializedModuleForModuleLibrary( + ModuleChunkRef moduleChunk, + DiagnosticSink* sink); + + RefPtr<Module> loadSerializedModule( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, + SourceLoc const& requestingLoc, + DiagnosticSink* sink); + + SlangResult loadSerializedModuleContents( + Module* module, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, DiagnosticSink* sink); SourceFile* loadSourceFile(String pathFrom, String path); @@ -2317,10 +2387,8 @@ public: Name* name, PathInfo const& pathInfo); - /// Load a module of the given name. - Module* loadModule(String const& name); - bool isBinaryModuleUpToDate(String fromPath, RiffContainer* container); + bool isBinaryModuleUpToDate(String fromPath, ModuleChunkRef moduleChunk); RefPtr<Module> findOrImportModule( Name* name, @@ -2328,12 +2396,6 @@ public: DiagnosticSink* sink, const LoadedModuleDictionary* loadedModules = nullptr); - void prepareDeserializedModule( - SerialContainerDataModule& moduleEntry, - const PathInfo& pathInfo, - Module* module, - DiagnosticSink* sink); - SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem); struct IncludeResult { diff --git a/source/slang/slang-doc-ast.cpp b/source/slang/slang-doc-ast.cpp index 0d4b69895..7e83d5d59 100644 --- a/source/slang/slang-doc-ast.cpp +++ b/source/slang/slang-doc-ast.cpp @@ -2,7 +2,7 @@ #include "slang-doc-ast.h" #include "../core/slang-string-util.h" -#include "slang/slang-ast-support-types.h" +#include "slang-ast-support-types.h" // #include "slang-ast-builder.h" // #include "slang-ast-print.h" diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index ada34f220..729802f4e 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2041,10 +2041,12 @@ ScalarizedVal adaptType(IRBuilder* builder, IRInst* val, IRType* toType, IRType* // Get array sizes once auto fromSize = getIntVal(fromArray->getElementCount()); auto toSize = getIntVal(toArray->getElementCount()); - SLANG_ASSERT(fromSize <= toSize); - // Extract elements one at a time up to the source array size - for (Index i = 0; i < fromSize; i++) + // Extract elements one at a time up to the minimum + // size, between the source and destination. + // + auto limit = fromSize < toSize ? fromSize : toSize; + for (Index i = 0; i < limit; i++) { auto element = builder->emitElementExtract( fromArray->getElementType(), diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d764cade5..2b94a1fa7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5843,13 +5843,19 @@ struct DestinationDrivenRValueExprLoweringVisitor } /// Emit code for a `try` invoke. - LoweredValInfo visitTryExpr(TryExpr* expr) + void visitTryExpr(TryExpr* expr) { auto invokeExpr = as<InvokeExpr>(expr->base); assert(invokeExpr); TryClauseEnvironment tryEnv; tryEnv.clauseType = expr->tryClauseType; - return sharedLoweringContext.visitInvokeExprImpl(invokeExpr, destination, tryEnv); + auto rValue = sharedLoweringContext.visitInvokeExprImpl(invokeExpr, destination, tryEnv); + if (rValue.flavor != LoweredValInfo::Flavor::None) + { + // If we weren't able to fuse the destination write during lowering rvalue, + // we should insert the assign operation now. + assign(context, destination, rValue); + } } }; diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 63dfffb94..f5878cb1d 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -381,7 +381,7 @@ void emitVal(ManglingContext* context, Val* val) } else if (auto modifier = as<ModifierVal>(val)) { - emitNameImpl(context, UnownedStringSlice(modifier->getClassInfo().m_name)); + emitNameImpl(context, UnownedStringSlice(modifier->getClass().getName())); } else { diff --git a/source/slang/slang-module-library.cpp b/source/slang/slang-module-library.cpp index 060c9007c..c03d5c2dd 100644 --- a/source/slang/slang-module-library.cpp +++ b/source/slang/slang-module-library.cpp @@ -45,6 +45,8 @@ SlangResult loadModuleLibrary( EndToEndCompileRequest* req, ComPtr<IModuleLibrary>& outLibrary) { + SLANG_UNUSED(path); + auto library = new ModuleLibrary; ComPtr<IModuleLibrary> scopeLibrary(library); @@ -55,54 +57,28 @@ SlangResult loadModuleLibrary( SLANG_RETURN_ON_FAIL(RiffUtil::read(&memoryStream, riffContainer)); auto linkage = req->getLinkage(); + auto sink = req->getSink(); + auto namePool = req->getNamePool(); + + auto container = ContainerChunkRef::find(&riffContainer); + + for (auto moduleChunk : container.getModules()) + { + auto loadedModule = linkage->findOrLoadSerializedModuleForModuleLibrary(moduleChunk, sink); + if (!loadedModule) + return SLANG_FAIL; + + library->m_modules.add(loadedModule); + } + + for (auto entryPointChunk : container.getEntryPoints()) { - SerialContainerData containerData; - - SerialContainerUtil::ReadOptions options; - options.namePool = req->getNamePool(); - options.session = req->getSession(); - options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); - options.sourceManager = linkage->getSourceManager(); - options.linkage = req->getLinkage(); - options.sink = req->getSink(); - options.astBuilder = linkage->getASTBuilder(); - options.modulePath = path; - SLANG_RETURN_ON_FAIL( - SerialContainerUtil::read(&riffContainer, options, nullptr, containerData)); - DiagnosticSink sink; - - // Modules in the container should be serialized in its depedency order, - // so that we always load the dependencies before the consuming module. - for (auto& module : containerData.modules) - { - // If the irModule is set, add it - if (module.irModule) - { - if (module.dependentFiles.getCount() == 0) - return SLANG_FAIL; - if (!module.astRootNode) - return SLANG_FAIL; - auto loadedModule = linkage->loadDeserializedModule( - as<ModuleDecl>(module.astRootNode)->getName(), - PathInfo::makePath(module.dependentFiles.getFirst()), - module, - &sink); - if (!loadedModule) - return SLANG_FAIL; - library->m_modules.add(loadedModule); - } - } - - for (const auto& entryPoint : containerData.entryPoints) - { - FrontEndCompileRequest::ExtraEntryPointInfo dst; - dst.mangledName = entryPoint.mangledName; - dst.name = entryPoint.name; - dst.profile = entryPoint.profile; - - // Add entry point - library->m_entryPoints.add(dst); - } + FrontEndCompileRequest::ExtraEntryPointInfo entryPointInfo; + entryPointInfo.mangledName = entryPointChunk.getMangledName(); + entryPointInfo.name = namePool->getName(entryPointChunk.getName()); + entryPointInfo.profile = entryPointChunk.getProfile(); + + library->m_entryPoints.add(entryPointInfo); } outLibrary.swap(scopeLibrary); diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index b09e36a1b..3c0bcf8db 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2424,7 +2424,7 @@ SlangResult OptionsParser::_parse(int argc, char const* const* argv) { CommandLineArg name; SLANG_RETURN_ON_FAIL(m_reader.expectArg(name)); - // TODO: doagnose deprecated option + // TODO: warn that this option is deprecated break; } case OptionKind::EmbedDownstreamIR: diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 00f15cbb3..a8573c909 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1771,46 +1771,46 @@ public: void visitDeclRefExpr(DeclRefExpr* expr) { expr->scope = scope; } void visitGenericAppExpr(GenericAppExpr* expr) { - expr->functionExpr->accept(this, nullptr); + dispatch(expr->functionExpr); for (auto arg : expr->arguments) - arg->accept(this, nullptr); + dispatch(arg); } void visitIndexExpr(IndexExpr* expr) { - expr->baseExpression->accept(this, nullptr); + dispatch(expr->baseExpression); for (auto arg : expr->indexExprs) - arg->accept(this, nullptr); + dispatch(arg); } void visitMemberExpr(MemberExpr* expr) { - expr->baseExpression->accept(this, nullptr); + dispatch(expr->baseExpression); expr->scope = scope; } void visitStaticMemberExpr(StaticMemberExpr* expr) { - expr->baseExpression->accept(this, nullptr); + dispatch(expr->baseExpression); expr->scope = scope; } void visitAppExprBase(AppExprBase* expr) { - expr->functionExpr->accept(this, nullptr); + dispatch(expr->functionExpr); for (auto arg : expr->arguments) - arg->accept(this, nullptr); + dispatch(arg); } void visitIsTypeExpr(IsTypeExpr* expr) { if (expr->typeExpr.exp) - expr->typeExpr.exp->accept(this, nullptr); + dispatch(expr->typeExpr.exp); } void visitAsTypeExpr(AsTypeExpr* expr) { if (expr->typeExpr) - expr->typeExpr->accept(this, nullptr); + dispatch(expr->typeExpr); } void visitSizeOfLikeExpr(SizeOfLikeExpr* expr) { if (expr->value) - expr->value->accept(this, nullptr); + dispatch(expr->value); } void visitExpr(Expr* /*expr*/) {} }; @@ -1910,7 +1910,7 @@ static Decl* parseTraditionalFuncDecl(Parser* parser, DeclaratorInfo const& decl // ReplaceScopeVisitor replaceScopeVisitor; replaceScopeVisitor.scope = parser->currentScope; - declaratorInfo.typeSpec->accept(&replaceScopeVisitor, nullptr); + replaceScopeVisitor.dispatch(declaratorInfo.typeSpec); decl->returnType = TypeExp(declaratorInfo.typeSpec); @@ -4377,7 +4377,7 @@ static NodeBase* parseTypeAliasDecl(Parser* parser, void* /*userData*/) // the class of AST node to construct. NodeBase* parseSimpleSyntax(Parser* parser, void* userData) { - SyntaxClassBase syntaxClass((ReflectClassInfo*)userData); + SyntaxClassBase syntaxClass((SyntaxClassInfo*)userData); return (NodeBase*)syntaxClass.createInstanceImpl(parser->astBuilder); } @@ -4411,7 +4411,7 @@ static NodeBase* parseSyntaxDecl(Parser* parser, void* /*userData*/) // to the `parseSimpleSyntax` callback that will just construct // an instance of that type to represent the keyword in the AST. SyntaxParseCallback parseCallback = &parseSimpleSyntax; - void* parseUserData = (void*)syntaxClass.classInfo; + void* parseUserData = (void*)syntaxClass.getInfo(); // Next we look for an initializer that will make this keyword // an alias for some existing keyword. @@ -4435,7 +4435,7 @@ static NodeBase* parseSyntaxDecl(Parser* parser, void* /*userData*/) // If we don't already have a syntax class specified, then // we will crib the one from the existing syntax, to ensure // that we are creating a drop-in alias. - if (!syntaxClass.classInfo) + if (!syntaxClass) syntaxClass = existingSyntax->syntaxClass; } } @@ -4445,7 +4445,7 @@ static NodeBase* parseSyntaxDecl(Parser* parser, void* /*userData*/) // // TODO: down the line this should be expanded so that the user can reference // an existing *function* to use to parse the chosen syntax. - if (!syntaxClass.classInfo) + if (!syntaxClass) { // TODO: diagnose: either a type or an existing keyword needs to be specified } @@ -4757,7 +4757,7 @@ static NodeBase* parseAttributeSyntaxDecl(Parser* parser, void* /*userData*/) auto classNameAndLoc = expectIdentifier(parser); syntaxClass = parser->astBuilder->findSyntaxClass(classNameAndLoc.name); - assert(syntaxClass.classInfo); + assert(syntaxClass); } else { @@ -8428,20 +8428,20 @@ static void addBuiltinSyntax( SyntaxParseCallback callback, void* userData = nullptr) { - addBuiltinSyntaxImpl(session, scope, name, callback, userData, getClass<T>()); + addBuiltinSyntaxImpl(session, scope, name, callback, userData, getSyntaxClass<T>()); } template<typename T> static void addSimpleModifierSyntax(Session* session, Scope* scope, char const* name) { - auto syntaxClass = getClass<T>(); + auto syntaxClass = getSyntaxClass<T>(); addBuiltinSyntaxImpl( session, scope, name, &parseSimpleSyntax, (void*)syntaxClass.classInfo, - getClass<T>()); + getSyntaxClass<T>()); } static IROp parseIROp(Parser* parser, Token& outToken) @@ -8931,10 +8931,10 @@ static NodeBase* parseMagicTypeModifier(Parser* parser, void* /*userData*/) modifier->tag = uint32_t(stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); } - auto classInfo = parser->astBuilder->findClassInfo(getName(parser, modifier->magicName)); - if (classInfo) + auto syntaxClass = parser->astBuilder->findSyntaxClass(getName(parser, modifier->magicName)); + if (syntaxClass) { - modifier->magicNodeType = ASTNodeType(classInfo->m_classId); + modifier->magicNodeType = syntaxClass; } // TODO: print diagnostic if the magic type name doesn't correspond to an actual ASTNodeType. parser->ReadToken(TokenType::RParent); @@ -9006,7 +9006,7 @@ static NodeBase* parseAttributeTargetModifier(Parser* parser, void* /*userData*/ static SyntaxParseInfo _makeParseExpr(const char* keywordName, SyntaxParseCallback callback) { SyntaxParseInfo entry; - entry.classInfo = &Expr::kReflectClassInfo; + entry.classInfo = getSyntaxClass<Expr>(); entry.keywordName = keywordName; entry.callback = callback; return entry; @@ -9016,18 +9016,18 @@ static SyntaxParseInfo _makeParseDecl(const char* keywordName, SyntaxParseCallba SyntaxParseInfo entry; entry.keywordName = keywordName; entry.callback = callback; - entry.classInfo = &Decl::kReflectClassInfo; + entry.classInfo = getSyntaxClass<Decl>(); return entry; } static SyntaxParseInfo _makeParseModifier( const char* keywordName, - const ReflectClassInfo& classInfo) + SyntaxClass<NodeBase> const& syntaxClass) { // If we just have class info - use simple parser SyntaxParseInfo entry; entry.keywordName = keywordName; entry.callback = &parseSimpleSyntax; - entry.classInfo = &classInfo; + entry.classInfo = syntaxClass; return entry; } static SyntaxParseInfo _makeParseModifier(const char* keywordName, SyntaxParseCallback callback) @@ -9035,7 +9035,7 @@ static SyntaxParseInfo _makeParseModifier(const char* keywordName, SyntaxParseCa SyntaxParseInfo entry; entry.keywordName = keywordName; entry.callback = callback; - entry.classInfo = &Modifier::kReflectClassInfo; + entry.classInfo = getSyntaxClass<Modifier>(); return entry; } @@ -9082,68 +9082,68 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = { // and which can be represented just by creating // a new AST node of the corresponding type. - _makeParseModifier("in", InModifier::kReflectClassInfo), - _makeParseModifier("out", OutModifier::kReflectClassInfo), - _makeParseModifier("inout", InOutModifier::kReflectClassInfo), - _makeParseModifier("__ref", RefModifier::kReflectClassInfo), - _makeParseModifier("__constref", ConstRefModifier::kReflectClassInfo), - _makeParseModifier("const", ConstModifier::kReflectClassInfo), - _makeParseModifier("__builtin", BuiltinModifier::kReflectClassInfo), - _makeParseModifier("highp", GLSLPrecisionModifier::kReflectClassInfo), - _makeParseModifier("lowp", GLSLPrecisionModifier::kReflectClassInfo), - _makeParseModifier("mediump", GLSLPrecisionModifier::kReflectClassInfo), - - _makeParseModifier("__global", ActualGlobalModifier::kReflectClassInfo), - - _makeParseModifier("inline", InlineModifier::kReflectClassInfo), - _makeParseModifier("public", PublicModifier::kReflectClassInfo), - _makeParseModifier("private", PrivateModifier::kReflectClassInfo), - _makeParseModifier("internal", InternalModifier::kReflectClassInfo), - - _makeParseModifier("require", RequireModifier::kReflectClassInfo), - _makeParseModifier("param", ParamModifier::kReflectClassInfo), - _makeParseModifier("extern", ExternModifier::kReflectClassInfo), - - _makeParseModifier("row_major", HLSLRowMajorLayoutModifier::kReflectClassInfo), - _makeParseModifier("column_major", HLSLColumnMajorLayoutModifier::kReflectClassInfo), - - _makeParseModifier("nointerpolation", HLSLNoInterpolationModifier::kReflectClassInfo), - _makeParseModifier("noperspective", HLSLNoPerspectiveModifier::kReflectClassInfo), - _makeParseModifier("linear", HLSLLinearModifier::kReflectClassInfo), - _makeParseModifier("sample", HLSLSampleModifier::kReflectClassInfo), - _makeParseModifier("centroid", HLSLCentroidModifier::kReflectClassInfo), - _makeParseModifier("precise", PreciseModifier::kReflectClassInfo), + _makeParseModifier("in", getSyntaxClass<InModifier>()), + _makeParseModifier("out", getSyntaxClass<OutModifier>()), + _makeParseModifier("inout", getSyntaxClass<InOutModifier>()), + _makeParseModifier("__ref", getSyntaxClass<RefModifier>()), + _makeParseModifier("__constref", getSyntaxClass<ConstRefModifier>()), + _makeParseModifier("const", getSyntaxClass<ConstModifier>()), + _makeParseModifier("__builtin", getSyntaxClass<BuiltinModifier>()), + _makeParseModifier("highp", getSyntaxClass<GLSLPrecisionModifier>()), + _makeParseModifier("lowp", getSyntaxClass<GLSLPrecisionModifier>()), + _makeParseModifier("mediump", getSyntaxClass<GLSLPrecisionModifier>()), + + _makeParseModifier("__global", getSyntaxClass<ActualGlobalModifier>()), + + _makeParseModifier("inline", getSyntaxClass<InlineModifier>()), + _makeParseModifier("public", getSyntaxClass<PublicModifier>()), + _makeParseModifier("private", getSyntaxClass<PrivateModifier>()), + _makeParseModifier("internal", getSyntaxClass<InternalModifier>()), + + _makeParseModifier("require", getSyntaxClass<RequireModifier>()), + _makeParseModifier("param", getSyntaxClass<ParamModifier>()), + _makeParseModifier("extern", getSyntaxClass<ExternModifier>()), + + _makeParseModifier("row_major", getSyntaxClass<HLSLRowMajorLayoutModifier>()), + _makeParseModifier("column_major", getSyntaxClass<HLSLColumnMajorLayoutModifier>()), + + _makeParseModifier("nointerpolation", getSyntaxClass<HLSLNoInterpolationModifier>()), + _makeParseModifier("noperspective", getSyntaxClass<HLSLNoPerspectiveModifier>()), + _makeParseModifier("linear", getSyntaxClass<HLSLLinearModifier>()), + _makeParseModifier("sample", getSyntaxClass<HLSLSampleModifier>()), + _makeParseModifier("centroid", getSyntaxClass<HLSLCentroidModifier>()), + _makeParseModifier("precise", getSyntaxClass<PreciseModifier>()), _makeParseModifier("shared", parseSharedModifier), - _makeParseModifier("groupshared", HLSLGroupSharedModifier::kReflectClassInfo), - _makeParseModifier("static", HLSLStaticModifier::kReflectClassInfo), - _makeParseModifier("uniform", HLSLUniformModifier::kReflectClassInfo), + _makeParseModifier("groupshared", getSyntaxClass<HLSLGroupSharedModifier>()), + _makeParseModifier("static", getSyntaxClass<HLSLStaticModifier>()), + _makeParseModifier("uniform", getSyntaxClass<HLSLUniformModifier>()), _makeParseModifier("volatile", parseVolatileModifier), _makeParseModifier("coherent", parseCoherentModifier), _makeParseModifier("restrict", parseRestrictModifier), _makeParseModifier("readonly", parseReadonlyModifier), _makeParseModifier("writeonly", parseWriteonlyModifier), - _makeParseModifier("export", HLSLExportModifier::kReflectClassInfo), - _makeParseModifier("dynamic_uniform", DynamicUniformModifier::kReflectClassInfo), + _makeParseModifier("export", getSyntaxClass<HLSLExportModifier>()), + _makeParseModifier("dynamic_uniform", getSyntaxClass<DynamicUniformModifier>()), // Modifiers for geometry shader input - _makeParseModifier("point", HLSLPointModifier::kReflectClassInfo), - _makeParseModifier("line", HLSLLineModifier::kReflectClassInfo), - _makeParseModifier("triangle", HLSLTriangleModifier::kReflectClassInfo), - _makeParseModifier("lineadj", HLSLLineAdjModifier::kReflectClassInfo), - _makeParseModifier("triangleadj", HLSLTriangleAdjModifier::kReflectClassInfo), + _makeParseModifier("point", getSyntaxClass<HLSLPointModifier>()), + _makeParseModifier("line", getSyntaxClass<HLSLLineModifier>()), + _makeParseModifier("triangle", getSyntaxClass<HLSLTriangleModifier>()), + _makeParseModifier("lineadj", getSyntaxClass<HLSLLineAdjModifier>()), + _makeParseModifier("triangleadj", getSyntaxClass<HLSLTriangleAdjModifier>()), // Modifiers for mesh shader parameters - _makeParseModifier("vertices", HLSLVerticesModifier::kReflectClassInfo), - _makeParseModifier("indices", HLSLIndicesModifier::kReflectClassInfo), - _makeParseModifier("primitives", HLSLPrimitivesModifier::kReflectClassInfo), - _makeParseModifier("payload", HLSLPayloadModifier::kReflectClassInfo), + _makeParseModifier("vertices", getSyntaxClass<HLSLVerticesModifier>()), + _makeParseModifier("indices", getSyntaxClass<HLSLIndicesModifier>()), + _makeParseModifier("primitives", getSyntaxClass<HLSLPrimitivesModifier>()), + _makeParseModifier("payload", getSyntaxClass<HLSLPayloadModifier>()), // Modifiers for unary operator declarations - _makeParseModifier("__prefix", PrefixModifier::kReflectClassInfo), - _makeParseModifier("__postfix", PostfixModifier::kReflectClassInfo), + _makeParseModifier("__prefix", getSyntaxClass<PrefixModifier>()), + _makeParseModifier("__postfix", getSyntaxClass<PostfixModifier>()), // Modifier to apply to `import` that should be re-exported - _makeParseModifier("__exported", ExportedModifier::kReflectClassInfo), + _makeParseModifier("__exported", getSyntaxClass<ExportedModifier>()), // Add syntax for more complex modifiers, which allow // or expect more tokens after the initial keyword. @@ -9208,7 +9208,7 @@ ModuleDecl* populateBaseLanguageModule(ASTBuilder* astBuilder, Scope* scope) scope, info.keywordName, info.callback, - const_cast<ReflectClassInfo*>(info.classInfo), + info.classInfo.getInfo(), info.classInfo); } diff --git a/source/slang/slang-parser.h b/source/slang/slang-parser.h index 9f9f4972a..c4e68a7fa 100644 --- a/source/slang/slang-parser.h +++ b/source/slang/slang-parser.h @@ -45,9 +45,9 @@ ModuleDecl* populateBaseLanguageModule(ASTBuilder* astBuilder, Scope* scope); /// for the `parseUserData` to be set the the associated classInfo struct SyntaxParseInfo { - const char* keywordName; ///< The keyword associated with this parse - SyntaxParseCallback callback; ///< The callback to apply to the parse - const ReflectClassInfo* classInfo; ///< + const char* keywordName; ///< The keyword associated with this parse + SyntaxParseCallback callback; ///< The callback to apply to the parse + SyntaxClass<NodeBase> classInfo; ///< }; /// Get all of the predefined SyntaxParseInfos diff --git a/source/slang/slang-ref-object-reflect.cpp b/source/slang/slang-ref-object-reflect.cpp deleted file mode 100644 index 303601b7b..000000000 --- a/source/slang/slang-ref-object-reflect.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "slang-ref-object-reflect.h" - -#include "slang-ast-support-types.h" -#include "slang-generated-obj-macro.h" -#include "slang-generated-obj.h" -#include "slang.h" - -// #include "slang-serialize.h" - -#include "slang-serialize-ast-type-info.h" - -namespace Slang -{ - -static const SerialClass* _addClass( - SerialClasses* serialClasses, - RefObjectType type, - RefObjectType super, - const List<SerialField>& fields) -{ - const SerialClass* superClass = - serialClasses->getSerialClass(SerialTypeKind::RefObject, SerialSubType(super)); - return serialClasses->add( - SerialTypeKind::RefObject, - SerialSubType(type), - fields.getBuffer(), - fields.getCount(), - superClass); -} - -#define SLANG_REF_OBJECT_ADD_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - fields.add(SerialField::make(#FIELD_NAME, &obj->FIELD_NAME)); - -// Note that the obj point is not nullptr, because some compilers notice this is 'indexing from -// null' and warn/error. So we offset from 1. -#define SLANG_REF_OBJECT_ADD_SERIAL_CLASS(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - { \ - NAME* obj = SerialField::getPtr<NAME>(); \ - SLANG_UNUSED(obj); \ - fields.clear(); \ - SLANG_FIELDS_RefObject_##NAME(SLANG_REF_OBJECT_ADD_SERIAL_FIELD, param) \ - _addClass(serialClasses, RefObjectType::NAME, RefObjectType::SUPER, fields); \ - } - -struct RefObjectAccess -{ - template<typename T> - static void* create(void* context) - { - SLANG_UNUSED(context) - return new T; - } - - static void calcClasses(SerialClasses* serialClasses) - { - // Add SerialRefObject first, and specially handle so that we add a null super class - serialClasses->add( - SerialTypeKind::RefObject, - SerialSubType(RefObjectType::SerialRefObject), - nullptr, - 0, - nullptr); - - // Add the rest in order such that Super class is always added before its children - List<SerialField> fields; - SLANG_CHILDREN_RefObject_SerialRefObject(SLANG_REF_OBJECT_ADD_SERIAL_CLASS, _) - } -}; - -#define SLANG_GET_SUPER_BASE(SUPER) nullptr -#define SLANG_GET_SUPER_INNER(SUPER) &SUPER::kReflectClassInfo -#define SLANG_GET_SUPER_LEAF(SUPER) &SUPER::kReflectClassInfo - -#define SLANG_GET_CREATE_FUNC_NONE(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_OBJ_ABSTRACT(NAME) nullptr -#define SLANG_GET_CREATE_FUNC_OBJ(NAME) &RefObjectAccess::create<NAME> - -#define SLANG_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - /* static */ const ReflectClassInfo NAME::kReflectClassInfo = { \ - uint32_t(RefObjectType::NAME), \ - uint32_t(RefObjectType::LAST), \ - SLANG_GET_SUPER_##TYPE(SUPER), \ - #NAME, \ - SLANG_GET_CREATE_FUNC_##MARKER(NAME), \ - nullptr, \ - uint32_t(sizeof(NAME)), \ - uint8_t(SLANG_ALIGN_OF(NAME))}; - -SLANG_ALL_RefObject_SerialRefObject(SLANG_REFLECT_CLASS_INFO, _) - - /* static */ const SerialRefObjects SerialRefObjects::g_singleton; - -// Macro to set all of the entries in m_infos for SerialRefObjects -#define SLANG_GET_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - m_infos[Index(RefObjectType::NAME)] = &NAME::kReflectClassInfo; - -SerialRefObjects::SerialRefObjects(){ - SLANG_ALL_RefObject_SerialRefObject(SLANG_GET_REFLECT_CLASS_INFO, _)} - -/* static */ SlangResult SerialRefObjects::addSerialClasses(SerialClasses* serialClasses) -{ - RefObjectAccess::calcClasses(serialClasses); - return SLANG_OK; -} - -} // namespace Slang diff --git a/source/slang/slang-ref-object-reflect.h b/source/slang/slang-ref-object-reflect.h deleted file mode 100644 index 1a6bf4520..000000000 --- a/source/slang/slang-ref-object-reflect.h +++ /dev/null @@ -1,73 +0,0 @@ -// slang-ref-object-reflect.h - -#ifndef SLANG_REF_OBJECT_REFLECT_H -#define SLANG_REF_OBJECT_REFLECT_H - -#include "../core/slang-smart-pointer.h" -#include "slang-generated-obj.h" -#include "slang-serialize-reflection.h" - -class SerialClasses; - -struct RefObjectAccess; - -#define SLANG_OBJ_CLASS_REFLECT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ -public: \ - typedef NAME This; \ - static const ReflectClassInfo kReflectClassInfo; \ - virtual const ReflectClassInfo* getClassInfo() const SLANG_OVERRIDE \ - { \ - return &kReflectClassInfo; \ - } \ - \ - friend struct RefObjectAccess; \ - \ - SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) - -// Placed in any SerialRefObject derived class -#define SLANG_ABSTRACT_OBJ_CLASS(NAME) SLANG_RefObject_##NAME(SLANG_OBJ_CLASS_REFLECT_IMPL, _) -#define SLANG_OBJ_CLASS(NAME) SLANG_RefObject_##NAME(SLANG_OBJ_CLASS_REFLECT_IMPL, _) - -namespace Slang -{ - -class SerialClasses; - -// Is friended such that internally we have access to construct or get members -struct RefObjectAccess; - -// Base class for Serialized RefObject derived classes. The main feature is that gives away to get -// ReflectClassInfo via getClassInfo() method -class SerialRefObject : public RefObject -{ -public: - typedef RefObject Super; - typedef SerialRefObject This; - - static const ReflectClassInfo kReflectClassInfo; - - virtual const ReflectClassInfo* getClassInfo() const { return &kReflectClassInfo; } -}; - -// For turning RefObjectType back to ReflectClassInfo -struct SerialRefObjects -{ - /// Add serialization classes - static SlangResult addSerialClasses(SerialClasses* serialClasses); - - static const ReflectClassInfo* getClassInfo(RefObjectType type) - { - return g_singleton.m_infos[Index(type)]; - } - - - static const SerialRefObjects g_singleton; - -protected: - SerialRefObjects(); - const ReflectClassInfo* m_infos[Index(RefObjectType::CountOf)]; -}; - -} // namespace Slang - -#endif // SLANG_REF_OBJECT_REFLECT_H diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h deleted file mode 100644 index ef2fd7ad5..000000000 --- a/source/slang/slang-serialize-ast-type-info.h +++ /dev/null @@ -1,477 +0,0 @@ -// slang-serialize-ast-type-info.h -#ifndef SLANG_SERIALIZE_AST_TYPE_INFO_H -#define SLANG_SERIALIZE_AST_TYPE_INFO_H - -#include "slang-ast-all.h" -#include "slang-ast-support-types.h" -#include "slang-serialize-misc-type-info.h" -#include "slang-serialize-type-info.h" -#include "slang-serialize-value-type-info.h" - -namespace Slang -{ - -/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AST types !!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -// SyntaxClass<T> -template<typename T> -struct SerialTypeInfo<SyntaxClass<T>> -{ - typedef SyntaxClass<T> NativeType; - typedef uint16_t SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - dst = SerialType(src.classInfo->m_classId); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - dst.classInfo = ASTClassInfo::getInfo(ASTNodeType(src)); - } -}; - -// MatrixCoord can just go as is -template<> -struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> -{ -}; - -inline void serializeValPointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial) -{ - if (ptrValue) - ptrValue = ptrValue->resolve(); - *(SerialIndex*)outSerial = writer->addPointer(ptrValue); -} - -inline void deserializeValPointerValue( - SerialReader* reader, - const SerialIndex* inSerial, - void* outPtr) -{ - auto val = reader->getValPointer(*(const SerialIndex*)inSerial); - *(void**)outPtr = val.m_ptr; -} - -template<typename T> -struct PtrSerialTypeInfo<T, std::enable_if_t<std::is_base_of_v<Val, T>>> -{ - typedef T* NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - auto ptrValue = *(T**)inNative; - serializeValPointerValue(writer, ptrValue, (SerialIndex*)outSerial); - } - - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - deserializeValPointerValue(reader, (SerialIndex*)inSerial, outNative); - } -}; - -template<typename T> -struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> -{ -}; - -// UIntSet - -template<> -struct SerialTypeInfo<CapabilityAtomSet> -{ - typedef CapabilityAtomSet NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst = writer->addArray(src.getBuffer().getBuffer(), src.getBuffer().getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - List<CapabilityAtomSet::Element> UIntSetBuffer; - reader->getArray(src, UIntSetBuffer); - - dst = CapabilityAtomSet(); - for (Index i = 0; i < UIntSetBuffer.getCount(); i++) - dst.addRawElement(UIntSetBuffer[i], i); - } -}; - -// ~UIntSet - -template<> -struct SerialTypeInfo<CapabilityStageSet> -{ - struct SerialType - { - SerialIndex stage; - SerialIndex atomSet; - }; - - typedef CapabilityStageSet NativeType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<SerialTypeInfo<CapabilityStageSet>::SerialType> SatomSetsList; - SatomSetsList.setCount(src.atomSet.has_value()); - - if (src.atomSet) - { - auto& i = src.atomSet.value(); - SerialTypeInfo<CapabilityAtomSet>::toSerial(writer, &i, &SatomSetsList[0]); - } - - SerialTypeInfo<CapabilityAtom>::toSerial(writer, &src.stage, &dst.stage); - dst.atomSet = writer->addSerialArray<CapabilityStageSet>( - SatomSetsList.getBuffer(), - SatomSetsList.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - CapabilityAtom stage; - List<CapabilityAtomSet> items; - SerialTypeInfo<CapabilityAtom>::toNative(reader, &src.stage, &stage); - reader->getArray(src.atomSet, items); - - dst.stage = stage; - - for (auto i : items) - { - dst.addNewSet(std::move(i)); - } - } -}; - -template<> -struct SerialTypeInfo<CapabilityTargetSet> -{ - struct SerialType - { - SerialIndex target; - SerialIndex shaderStageSets; - }; - - typedef CapabilityTargetSet NativeType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<SerialTypeInfo<CapabilityStageSet>::SerialType> SStageSetList; - SStageSetList.setCount(src.shaderStageSets.getCount()); - Index iter = 0; - for (auto& i : src.shaderStageSets) - { - SerialTypeInfo<CapabilityStageSet>::toSerial(writer, &i.second, &SStageSetList[iter]); - iter++; - } - - SerialTypeInfo<CapabilityAtom>::toSerial(writer, &src.target, &dst.target); - dst.shaderStageSets = writer->addSerialArray<CapabilityStageSet>( - SStageSetList.getBuffer(), - SStageSetList.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - CapabilityAtom target; - List<CapabilityStageSet> items; - SerialTypeInfo<CapabilityAtom>::toNative(reader, &src.target, &target); - reader->getArray(src.shaderStageSets, items); - - dst.target = target; - - auto& shaderStageSets = dst.shaderStageSets; - shaderStageSets.clear(); - shaderStageSets.reserve(items.getCount()); - for (auto& i : items) - { - dst.shaderStageSets[i.stage] = i; - } - } -}; - -template<> -struct SerialTypeInfo<CapabilitySet> -{ - struct SerialType - { - SerialIndex m_targetSets; - }; - - typedef CapabilitySet NativeType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<SerialTypeInfo<CapabilityTargetSet>::SerialType> STargetSetList; - auto capabilityTargetSets = src.getCapabilityTargetSets(); - STargetSetList.setCount(capabilityTargetSets.getCount()); - Index iter = 0; - for (auto& i : capabilityTargetSets) - { - SerialTypeInfo<CapabilityTargetSet>::toSerial(writer, &i.second, &STargetSetList[iter]); - iter++; - } - - dst.m_targetSets = writer->addSerialArray<CapabilityTargetSet>( - STargetSetList.getBuffer(), - STargetSetList.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - List<CapabilityTargetSet> items; - reader->getArray(src.m_targetSets, items); - - auto& targetSets = dst.getCapabilityTargetSets(); - targetSets.clear(); - targetSets.reserve(items.getCount()); - for (auto& i : items) - { - targetSets[i.target] = i; - } - } -}; - -// ValNodeOperand -template<> -struct SerialTypeInfo<ValNodeOperand> -{ - typedef ValNodeOperand NativeType; - struct SerialType - { - int8_t kind; - int64_t val; - }; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - dst.kind = int8_t(src.kind); - if (src.kind == ValNodeOperandKind::ConstantValue) - dst.val = src.values.intOperand; - else if (src.kind == ValNodeOperandKind::ValNode) - serializeValPointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val); - else - SerialTypeInfo<NodeBase*>::toSerial( - writer, - &src.values.nodeOperand, - (SerialIndex*)&dst.val); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - // Initialize - dst = NativeType(); - dst.kind = ValNodeOperandKind(src.kind); - if (dst.kind == ValNodeOperandKind::ConstantValue) - dst.values.intOperand = int64_t(src.val); - else if (dst.kind == ValNodeOperandKind::ValNode) - deserializeValPointerValue( - reader, - (SerialIndex*)&src.val, - (Val**)&dst.values.nodeOperand); - else - SerialTypeInfo<NodeBase*>::toNative( - reader, - (SerialIndex*)&src.val, - (NodeBase**)&dst.values.nodeOperand); - } -}; - -// LookupResultItem -SLANG_VALUE_TYPE_INFO(LookupResultItem) -// QualType -SLANG_VALUE_TYPE_INFO(QualType) - -// LookupResult -template<> -struct SerialTypeInfo<LookupResult> -{ - typedef LookupResult NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - if (src.isOverloaded()) - { - // Save off as an array - dst = writer->addArray(src.items.getBuffer(), src.items.getCount()); - } - else if (src.item.declRef.getDecl()) - { - dst = writer->addArray(&src.item, 1); - } - else - { - dst = SerialIndex(0); - } - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - // Initialize - dst = NativeType(); - - List<LookupResultItem> items; - reader->getArray(src, items); - - if (items.getCount() == 1) - { - dst.item = items[0]; - } - else - { - dst.items.swapWith(items); - // We have to set item such that it is valid/member of items, if items is non empty - dst.item = dst.items[0]; - } - } -}; - -// SpecializationArg -SLANG_VALUE_TYPE_INFO(SpecializationArg) -// ExpandedSpecializationArg -SLANG_VALUE_TYPE_INFO(ExpandedSpecializationArg) -// TypeExp -SLANG_VALUE_TYPE_INFO(TypeExp) -// DeclCheckStateExt -SLANG_VALUE_TYPE_INFO(DeclCheckStateExt) - -// Modifiers -template<> -struct SerialTypeInfo<Modifiers> -{ - typedef Modifiers NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - // We need to make into an array - List<SerialIndex> modifierIndices; - for (Modifier* modifier : *(NativeType*)native) - { - modifierIndices.add(writer->addPointer(modifier)); - } - *(SerialType*)serial = - writer->addArray(modifierIndices.getBuffer(), modifierIndices.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - List<Modifier*> modifiers; - reader->getArray(*(const SerialType*)serial, modifiers); - - Modifier* prev = nullptr; - for (Modifier* modifier : modifiers) - { - if (prev) - { - prev->next = modifier; - } - } - - NativeType& dst = *(NativeType*)native; - dst.first = modifiers.getCount() > 0 ? modifiers[0] : nullptr; - } -}; - -// LookupResultItem_Breadcrumb::ThisParameterMode -template<> -struct SerialTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode> - : public SerialConvertTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode, uint8_t> -{ -}; - -// LookupResultItem_Breadcrumb::Kind -template<> -struct SerialTypeInfo<LookupResultItem_Breadcrumb::Kind> - : public SerialConvertTypeInfo<LookupResultItem_Breadcrumb::Kind, uint8_t> -{ -}; - -// RequirementWitness::Flavor -template<> -struct SerialTypeInfo<RequirementWitness::Flavor> - : public SerialConvertTypeInfo<RequirementWitness::Flavor, uint8_t> -{ -}; - -// RequirementWitness -SLANG_VALUE_TYPE_INFO(RequirementWitness) - -// SPIRVAsm -SLANG_VALUE_TYPE_INFO(SPIRVAsmOperand) -SLANG_VALUE_TYPE_INFO(SPIRVAsmInst) - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index a7837edea..aad3bcc57 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -1,208 +1,1542 @@ // slang-serialize-ast.cpp #include "slang-serialize-ast.h" -#include "slang-ast-dump.h" -#include "slang-ast-support-types.h" -#include "slang-generated-ast-macro.h" -#include "slang-generated-ast.h" -#include "slang-serialize-ast-type-info.h" -#include "slang-serialize-factory.h" +#include "slang-ast-dispatch.h" +#include "slang-compiler.h" +#include "slang-diagnostics.h" +#include "slang-mangle.h" namespace Slang { +// TODO(tfoley): have the parser export this, or a utility function +// for initializing a `SyntaxDecl` in the common case. +// +NodeBase* parseSimpleSyntax(Parser* parser, void* userData); -// !!!!!!!!!!!!!!!!!!!!!! Generate fields for a type !!!!!!!!!!!!!!!!!!!!!!!!!!! -static const SerialClass* _addClass( - SerialClasses* serialClasses, - ASTNodeType type, - ASTNodeType super, - const List<SerialField>& fields) +struct ASTEncodingContext { - const SerialClass* superClass = - serialClasses->getSerialClass(SerialTypeKind::NodeBase, SerialSubType(super)); - return serialClasses->add( - SerialTypeKind::NodeBase, - SerialSubType(type), - fields.getBuffer(), - fields.getCount(), - superClass); -} +private: + Encoder* encoder; + struct UnhandledCase + { + }; + + typedef Int DeclID; + Dictionary<Decl*, DeclID> mapDeclToID; + List<Decl*> decls; + + struct ImportedDeclInfo + { + Int moduleIndex = -1; + Decl* decl; + }; + List<ImportedDeclInfo> importedDecls; -#define SLANG_AST_ADD_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - fields.add(SerialField::make(#FIELD_NAME, &obj->FIELD_NAME)); + typedef Int ValID; + Dictionary<Val*, ValID> mapValToID; + List<Val*> vals; -// Note that the obj point is not nullptr, because some compilers notice this is 'indexing from -// null' and warn/error. So we offset from 1. -#define SLANG_AST_ADD_SERIAL_CLASS(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - { \ - NAME* obj = SerialField::getPtr<NAME>(); \ - SLANG_UNUSED(obj); \ - fields.clear(); \ - SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_ADD_SERIAL_FIELD, param) \ - _addClass(serialClasses, ASTNodeType::NAME, ASTNodeType::SUPER, fields); \ + ModuleDecl* _module = nullptr; + + SerialSourceLocWriter* _sourceLocWriter = nullptr; + +public: + ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) + : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) + { } -struct ASTFieldAccess -{ - static void calcClasses(SerialClasses* serialClasses) + template<typename T> + void encodeASTNodeContent(T* node) { - // Add NodeBase first, and specially handle so that we add a null super class - serialClasses->add( - SerialTypeKind::NodeBase, - SerialSubType(ASTNodeType::NodeBase), - nullptr, - 0, - nullptr); + Encoder::WithObject withObject(encoder); - // Add the rest in order such that Super class is always added before its children - List<SerialField> fields; - SLANG_CHILDREN_ASTNode_NodeBase(SLANG_AST_ADD_SERIAL_CLASS, _) + ASTNodeDispatcher<T, void>::dispatch(node, [&](auto n) { _encodeDataOf(n); }); + } + + void flush() + { + auto containerChunk = encoder->getRIFFChunk(); + + RiffContainer::Chunk* declChunk = nullptr; + RiffContainer::Chunk* importedDeclChunk = nullptr; + RiffContainer::Chunk* valChunk = nullptr; + { + Encoder::WithArray withList(encoder); + declChunk = encoder->getRIFFChunk(); + } + { + Encoder::WithArray withList(encoder); + importedDeclChunk = encoder->getRIFFChunk(); + } + { + Encoder::WithArray withList(encoder); + valChunk = encoder->getRIFFChunk(); + } + Int declIndex = 0; + Int importedDeclIndex = 0; + Int valIndex = 0; + + bool done = false; + do + { + done = true; + while (declIndex < decls.getCount()) + { + done = false; + encoder->setRIFFChunk(declChunk); + encodeASTNodeContent(decls[declIndex++]); + } + while (importedDeclIndex < importedDecls.getCount()) + { + done = false; + encoder->setRIFFChunk(importedDeclChunk); + encodeImportedDecl(importedDecls[importedDeclIndex++]); + } + while (valIndex < vals.getCount()) + { + done = false; + encoder->setRIFFChunk(valChunk); + encodeASTNodeContent(vals[valIndex++]); + } + } while (!done); + + RiffContainer::calcAndSetSize(containerChunk); + encoder->setRIFFChunk(containerChunk); + } + + ModuleDecl* findModuleForDecl(Decl* decl) + { + for (auto d = decl; d; d = d->parentDecl) + { + if (auto m = as<ModuleDecl>(d)) + return m; + } + return nullptr; } -}; -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!! + ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) + { + auto declModule = findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; + } + + DeclID getDeclID(Decl* decl) + { + SLANG_ASSERT(decl != nullptr); + + if (auto found = mapDeclToID.tryGetValue(decl)) + return *found; + + // We need to detect whether the declaration is an + // imported one, or one from this module itself. + // + // Imported declarations need to be handled very + // differently, since they'll involve resolving + // references to those other modules, and the + // declarations within them. + // + if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) + { + DeclID importedFromModuleDeclID = 0; + if (decl != importedFromModule) + { + importedFromModuleDeclID = getDeclID(importedFromModule); + } + + DeclID id = ~importedDecls.getCount(); + mapDeclToID.add(decl, id); + + ImportedDeclInfo info; + info.moduleIndex = ~importedFromModuleDeclID; + info.decl = decl; + importedDecls.add(info); + + return id; + } + else + { + DeclID id = decls.getCount(); + decls.add(decl); + mapDeclToID.add(decl, id); + + return id; + } + } + + void encodePtr(Decl* decl) + { + DeclID id = getDeclID(decl); + encoder->encode(id); + } + + ValID getValID(Val* val) + { + SLANG_ASSERT(val != nullptr); + + if (auto found = mapValToID.tryGetValue(val)) + return *found; + + // In order to ensure that values can be fully constructed + // from the get-go (so that they will get cached correctly), + // we conspire to ensure that every value is preceded by + // all of its operands. + // + for (auto operand : val->m_operands) + { + switch (operand.kind) + { + default: + break; + + case ValNodeOperandKind::ValNode: + if (auto operandNode = operand.values.nodeOperand) + { + SLANG_ASSERT(as<Val>(operandNode)); + getValID(static_cast<Val*>(operandNode)); + } + break; + + case ValNodeOperandKind::ASTNode: + if (auto operandNode = operand.values.nodeOperand) + { + SLANG_ASSERT(as<Decl>(operandNode)); + getDeclID(static_cast<Decl*>(operandNode)); + } + break; + } + } + auto resolved = val->resolve(); + if (resolved != val) + { + getValID(resolved); + } + + ValID id = vals.getCount(); + vals.add(val); + mapValToID.add(val, id); + return id; + } + + void encodePtr(Val* val) + { + ValID id = getValID(val); + encoder->encode(id); + } + + void encodeImportedDecl(ImportedDeclInfo const& info) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(info.moduleIndex); + auto decl = info.decl; + if (auto importedModuleDecl = as<ModuleDecl>(decl)) + { + SLANG_ASSERT(info.moduleIndex == -1); + encode(importedModuleDecl->getName()); + } + else + { + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); + encode(mangledName); + } + } + + void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } + void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } + void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } + + void encodePtr(Name* name) { encode(name->text); } + + void encodePtr(MarkupEntry* entry) + { + // TODO: is this case needed? + SLANG_UNUSED(entry); + } + + void encodePtr(DeclAssociationList* list) + { + // We serialize this as if it were a simple list + // of key-value pairs because... well... that's + // what it amounts to in practice. + // + Encoder::WithArray withArray(encoder); + for (auto association : list->associations) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(association->kind); + encode(association->decl); + } + } + + void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } + + void encodePtr(WitnessTable* witnessTable) + { + Encoder::WithObject withObject(encoder); + encode(witnessTable->baseType); + encode(witnessTable->witnessedType); + encode(witnessTable->isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + encode(witnessTable->m_requirementDictionary); + } + + void encodeValue(RequirementWitness const& witness) + { + Encoder::WithKeyValuePair withPair(encoder); + encodeEnum(witness.m_flavor); + switch (witness.m_flavor) + { + case RequirementWitness::Flavor::none: + break; + + case RequirementWitness::Flavor::declRef: + encode(witness.m_declRef); + break; + + case RequirementWitness::Flavor::val: + encode(witness.m_val); + break; + + case RequirementWitness::Flavor::witnessTable: + encode((WitnessTable*)witness.m_obj.Ptr()); + break; + } + } + + void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } + + void encodePtr(DeclBase* declBase) + { + if (auto decl = as<Decl>(declBase)) + { + encodePtr(decl); + } + else + { + encodeASTNodeContent(declBase); + } + } + + void encodeValue(UnhandledCase); + + void encodeValue(String const& value) { encoder->encode(value); } + + void encodeValue(Token const& value) + { + encode(value.type); + encode(TokenFlags(value.flags & ~TokenFlag::Name)); + encode(value.loc); + if (value.hasContent()) + encoder->encodeString(value.getContent()); + else + encode(nullptr); + } + + void encodeValue(NameLoc const& value) { encode(value.name); } + + void encodeValue(SemanticVersion value) { encoder->encode(value.toInteger()); } + + void encodeValue(CapabilitySet const& value) + { + // While the `CapabilityTargetSets` type is a dictionary, + // in practice each entry already embeds its own key + // (the target atom), so we can encode this as just + // an array of the `CapabilityTargetSet` values. + // + Encoder::WithArray withArray(encoder); + for (auto pair : value.getCapabilityTargetSets()) + { + encode(pair.second); + } + } + + void encodeValue(CapabilityTargetSet const& value) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(value.target); + + // Similar to the case for the `CapabilityTargetSets` above, + // each `CapabilityStageSet` already includes the stage atom, + // so we can simply encode the values from the dictionary. + // + Encoder::WithArray withArray(encoder); + for (auto pair : value.shaderStageSets) + { + encode(pair.second); + } + } + + void encodeValue(CapabilityStageSet const& value) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(value.stage); + encode(value.atomSet); + } + + void encodeValue(CapabilityAtomSet const& value) + { + Encoder::WithArray withArray(encoder); + for (auto rawAtom : value) + { + encode(CapabilityAtom(rawAtom)); + } + } -/* static */ void ASTSerialUtil::addSerialClasses(SerialClasses* serialClasses) + template<typename T> + void encodeValue(std::optional<T> const& value) + { + if (value) + encodeValue(*value); + else + encoder->encode(nullptr); + } + + void encodeValue(SyntaxClass<NodeBase> const& value) { encode(value.getTag()); } + + template<typename T> + void encodeValue(DeclRef<T> const& value) + { + encode((DeclRefBase*)value); + } + + void encodeValue(ValNodeOperand value) + { + Encoder::WithKeyValuePair withPair(encoder); + + encodeEnum(value.kind); + switch (value.kind) + { + case ValNodeOperandKind::ConstantValue: + encode(value.values.intOperand); + break; + + case ValNodeOperandKind::ValNode: + encode(static_cast<Val*>(value.values.nodeOperand)); + break; + + case ValNodeOperandKind::ASTNode: + { + if (auto decl = as<Decl>(value.values.nodeOperand)) + { + encode(decl); + } + else + { + SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); + } + } + break; + } + } + + void encodeValue(TypeExp value) { encode(value.type); } + + void encodeValue(QualType value) + { + Encoder::WithObject withObject(encoder); + encode(value.type); + encode(value.isLeftValue); + encode(value.hasReadOnlyOnTarget); + encode(value.isWriteOnly); + } + + void encodeValue(MatrixCoord value) + { + Encoder::WithObject withObject(encoder); + encode(value.row); + encode(value.col); + } + + void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } + + void encodeValue(SPIRVAsmOperand const& value) + { + Encoder::WithObject withObject(encoder); + encode(value.flavor); + encode(value.token); + encode(value.expr); + encode(value.bitwiseOrWith); + encode(value.knownValue); + encode(value.wrapInId); + encode(value.type); + } + + void encodeValue(SPIRVAsmInst const& value) + { + Encoder::WithObject withObject(encoder); + encode(value.opcode); + encode(value.operands); + } + + + template<typename T, typename = std::enable_if_t<std::is_same_v<T, bool>>> + void encodeValue(T value) + { + encoder->encodeBool(value); + } + + void encodeValue(Int32 value) { encoder->encode(value); } + void encodeValue(UInt32 value) { encoder->encode(value); } + void encodeValue(Int64 value) { encoder->encode(value); } + void encodeValue(UInt64 value) { encoder->encode(value); } + void encodeValue(float value) { encoder->encode(value); } + void encodeValue(double value) { encoder->encode(value); } + + void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } + + void encodeValue(nullptr_t) { encoder->encode(nullptr); } + + template<typename T> + void encodeEnum(T value) + { + encoder->encode(Int32(value)); + } + + void encodeValue(DeclVisibility value) { encodeEnum(value); } + void encodeValue(BaseType value) { encodeEnum(value); } + void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } + void encodeValue(ASTNodeType value) { encodeEnum(value); } + void encodeValue(ImageFormat value) { encodeEnum(value); } + void encodeValue(TypeTag value) { encodeEnum(value); } + void encodeValue(TryClauseType value) { encodeEnum(value); } + void encodeValue(CapabilityAtom value) { encodeEnum(value); } + void encodeValue(DeclAssociationKind value) { encodeEnum(value); } + void encodeValue(TokenType value) { encodeEnum(value); } + + void encodeValue(SourceLoc value) + { + if (!_sourceLocWriter) + { + encoder->encode(nullptr); + } + else + { + auto intermediate = _sourceLocWriter->addSourceLoc(value); + encoder->encode(intermediate); + } + } + + template<typename T> + void encodeValue(T const* ptr) + { + if (!ptr) + { + encoder->encode(nullptr); + } + else + { + encodePtr(const_cast<T*>(ptr)); + } + } + + template<typename T> + void encodeValue(RefPtr<T> const& ptr) + { + if (!ptr) + { + encoder->encode(nullptr); + } + else + { + encodePtr(ptr.Ptr()); + } + } + + void encodeValue(Modifiers const& modifiers) + { + Encoder::WithArray withArray(encoder); + for (auto m : const_cast<Modifiers&>(modifiers)) + { + encode(m); + } + } + + template<typename T, int N> + void encodeValue(ShortList<T, N> const& array) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); + } + } + + + template<typename T> + void encode(List<T> const& array) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); + } + } + + template<typename T, size_t N> + void encode(T const (&array)[N]) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); + } + } + + template<typename K, typename V> + void encode(OrderedDictionary<K, V> const& dictionary) + { + Encoder::WithArray withArray(encoder); + for (auto p : dictionary) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(p.key); + encode(p.value); + } + } + + template<typename K, typename V> + void encode(Dictionary<K, V> const& dictionary) + { + Encoder::WithArray withArray(encoder); + for (auto p : dictionary) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(p.first); + encode(p.second); + } + } + + template<typename T> + void encode(T const& value) + { + encodeValue(value); + } + + // for each class of node, we generate + // code to recursively serialize each + // of its fields. + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + void _encodeDataOf($T* obj) + { +%if T.directSuperClass then + _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); +%end +%for _,f in ipairs(T.directFields) do + encode(obj->$f); +%end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END +}; + +void writeSerializedModuleAST( + Encoder* encoder, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) { - ASTFieldAccess::calcClasses(serialClasses); + Encoder::WithObject withObject(encoder); + + // TODO: we should have a more careful pass here, + // where we only encode the public declarations + // + + ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); + context.getDeclID(moduleDecl); + context.flush(); } -/* static */ SlangResult ASTSerialUtil::testSerialize( - NodeBase* node, - RootNamePool* rootNamePool, - SharedASTBuilder* sharedASTBuilder, - SourceManager* sourceManager) +struct ASTDecodingContext { - RefPtr<SerialClasses> classes; +public: + ASTDecodingContext( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* rootChunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) + : _linkage(linkage) + , _astBuilder(astBuilder) + , _sink(sink) + , _rootChunk(static_cast<RiffContainer::ListChunk*>(rootChunk)) + , _sourceLocReader(sourceLocReader) + , _requestingSourceLoc(requestingSourceLoc) + { + } + + Linkage* _linkage = nullptr; + DiagnosticSink* _sink = nullptr; + SerialSourceLocReader* _sourceLocReader = nullptr; + SourceLoc _requestingSourceLoc; + + SlangResult decodeAll() + { + auto cursor = _rootChunk->getFirstContainedChunk(); + + // There are a few different top-level chunks that + // hold different arrays that we need in order + // to decode the entire module hierarchy. + // + // Basically, these lists correspond to the kinds + // of nodes in the AST hierarchy for which back-references + // are allowed (all other nodes should, barring + // weird corner cases, form a single tree-structured + // ownership hierarchy, rooted at the `ModuleDecl`. + // + + // First there is the list that actually encodes + // for the declarations in the module, including + // the `ModuleDecl` itself, which should be the + // first entry in the list. + // + auto declChunk = cursor; + cursor = cursor->m_next; + + // Next there is a list of all the declarations + // referenced inside of the module that need to + // be imported in from outside. + // + auto importedDeclChunk = cursor; + cursor = cursor->m_next; + + // Then there are all the `Val`-derived nodes that + // are needed by the module, which will need to be + // deduplicated so that they are unique within the + // current compilation context. + // + auto valChunk = cursor; + cursor = cursor->m_next; + + // The process of decoding the module is then spread + // over a number of steps. + // + // The first step is to process all of the imported + // declarations, so that other nodes can refer to + // them. + // + SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); + + // Next we process the declarations that are within + // the module itself, first creating an "empty shell" + // of each declaration that has the right size in + // memory (and the right `ASTNodeType` tag), so that + // we can wire up references to it (including circular + // references)... so long as nothing here tries to + // look *inside* the empty shell along the way. + // + SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); + + // Once all the `Decl`s that might be needed have + // been allocated, we can process all the `Val`s + // that might reference those`Decl`s (and one another). + // + // The nature of the `Val` representation ensures + // that there cannot be cirularities in the references + // between `Val`s, and the encoding process will have + // sorted the entries so that a `Val` only ever appears + // *after* its operands. + // + SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); + + // Once all the back-reference-able objects have been + // instantiated in memory, we can go back through the + // `Decl`s in the module and fill in those empty shells. + // + SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); + + // As a final pass, we perform any special cleanup actions + // that might be required to make the output valid for consumers. + // + // For example, this is where we set the `DeclCheckState` of everything + // we are loading to reflect the fact that everything we deserialize + // is (supposed to be) fully cheked. + // + SLANG_RETURN_ON_FAIL(cleanUpNodes()); + + + return SLANG_OK; + } + + typedef Int DeclID; + Decl* getDeclByID(DeclID id) + { + if (id >= 0) + { + return _decls[id]; + } + else + { + return _importedDecls[~id]; + } + } + +private: + struct UnhandledCase + { + }; + + ASTBuilder* _astBuilder = nullptr; + RiffContainer::ListChunk* _rootChunk = nullptr; + + List<Decl*> _decls; + List<Decl*> _importedDecls; + List<Val*> _vals; + + typedef Int ValID; + Val* getValByID(ValID id) { return _vals[id]; } + + SlangResult decodeImportedDecls(RiffContainer::Chunk* importedDeclChunk) + { + Decoder decoder(importedDeclChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + Int moduleIndex; + decode(moduleIndex, decoder); + + if (moduleIndex == -1) + { + Name* moduleName = nullptr; + decode(moduleName, decoder); + + Decl* importedModule = getImportedModule(moduleName); + _importedDecls.add(importedModule); + } + else + { + auto importedFromModuleDecl = as<ModuleDecl>(_importedDecls[moduleIndex]); + auto importedFromModule = importedFromModuleDecl->module; + + String mangledName; + decode(mangledName, decoder); + + auto importedNode = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + auto importedDecl = as<Decl>(importedNode); + _importedDecls.add(importedDecl); + } + } + return SLANG_OK; + } + + ModuleDecl* getImportedModule(Name* moduleName) + { + Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) + { + SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); + } + + return module->getModuleDecl(); + } + + SlangResult decodeVals(RiffContainer::Chunk* valChunk) + { + Decoder decoder(valChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Val* val = decodeValNode(decoder); + _vals.add(val); + } + return SLANG_OK; + } - SerialClassesUtil::create(classes); + SlangResult createEmptyShells(RiffContainer::Chunk* declChunk) + { + Decoder decoder(declChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + ASTNodeType nodeType; + + // Each of the declarations is expected to take + // the form of an object with a first field + // that holds the node type. + // + { + Decoder::WithObject withObject(decoder); + decode(nodeType, decoder); + } + + auto emptyShell = createEmptyShell(nodeType); + auto declEmptyShell = as<Decl>(emptyShell); + _decls.add(declEmptyShell); + } - List<uint8_t> contents; + return SLANG_OK; + } + Val* decodeValNode(Decoder& decoder) { - OwnedMemoryStream stream(FileAccess::ReadWrite); + Decoder::WithObject withObject(decoder); - ModuleDecl* moduleDecl = as<ModuleDecl>(node); - // Only serialize out things *in* this module - ModuleSerialFilter filterStorage(moduleDecl); + ASTNodeType nodeType; + decode(nodeType, decoder); - SerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + ValNodeDesc desc; + desc.type = SyntaxClass<NodeBase>(nodeType); - SerialWriter writer(classes, filter); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + ValNodeOperand operand; + decode(operand, decoder); + desc.operands.add(operand); + } - // Lets serialize it all - writer.addPointer(node); - // Let's stick it all in a stream - writer.write(&stream); + desc.init(); - stream.swapContents(contents); + auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); - NamePool namePool; - namePool.setRootNamePool(rootNamePool); + // Values created during deserialization are + // not expected to ever resolve further, because + // they should be coming from fully checked code. + // + // val->resolve(); + // val->_setUnique(); - ASTBuilder builder(sharedASTBuilder, "Serialize Check"); + return val; + } - SetASTBuilderContextRAII astBuilderRAII(&builder); + NodeBase* createEmptyShell(ASTNodeType nodeType) + { + return SyntaxClass<NodeBase>(nodeType).createInstance(_astBuilder); + } - DefaultSerialObjectFactory objectFactory(&builder); + SlangResult fillEmptyShells(RiffContainer::Chunk* declChunk) + { + Index declIndex = 0; - // We could now check that the loaded data matches + Decoder decoder(declChunk); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + auto declEmptyShell = _decls[declIndex++]; + decodeASTNodeContent(declEmptyShell, decoder); + } + return SLANG_OK; + } + + SlangResult cleanUpNodes() + { + for (auto decl : _decls) { - const List<SerialInfo::Entry*>& writtenEntries = writer.getEntries(); - List<const SerialInfo::Entry*> readEntries; + decl->checkState = DeclCheckState::CapabilityChecked; + } - SlangResult res = SerialReader::loadEntries( - contents.getBuffer(), - contents.getCount(), - classes, - readEntries); - SLANG_UNUSED(res); + return SLANG_OK; + } - SLANG_ASSERT(writtenEntries.getCount() == readEntries.getCount()); - // They should be identical up to the - for (Index i = 1; i < readEntries.getCount(); ++i) + void assignGenericParameterIndices(GenericDecl* genericDecl) + { + int parameterCounter = 0; + for (auto m : genericDecl->members) + { + if (auto typeParam = as<GenericTypeParamDeclBase>(m)) + { + typeParam->parameterIndex = parameterCounter++; + } + else if (auto valParam = as<GenericValueParamDecl>(m)) { - auto writtenEntry = writtenEntries[i]; - auto readEntry = readEntries[i]; + valParam->parameterIndex = parameterCounter++; + } + } + } + - const size_t writtenSize = writtenEntry->calcSize(classes); - const size_t readSize = readEntry->calcSize(classes); - SLANG_UNUSED(writtenSize); - SLANG_UNUSED(readSize); + void cleanUpASTNode(NodeBase* node) + { + if (auto expr = as<Expr>(node)) + { + expr->checked = true; + } + else if (auto genericDecl = as<GenericDecl>(node)) + { + assignGenericParameterIndices(genericDecl); + } + else if (auto syntaxDecl = as<SyntaxDecl>(node)) + { + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); + } + else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) + { + auto declScope = _astBuilder->create<Scope>(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; + } + } + + void decodeASTNodeContent(NodeBase* node, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); - SLANG_ASSERT(readSize == writtenSize); - // Check the payload is the same - SLANG_ASSERT(memcmp(readEntry, writtenEntry, readSize) == 0); + ASTNodeDispatcher<NodeBase, void>::dispatch( + node, + [&](auto n) { _decodeDataOf(n, decoder); }); + + cleanUpASTNode(node); + } + + DeclID decodeDeclID(Decoder& decoder) + { + DeclID result = decoder.decode<DeclID>(); + return result; + } + + ValID decodeValID(Decoder& decoder) + { + ValID result = decoder.decode<ValID>(); + return result; + } + + template<typename T> + void decodeASTNode(T*& node, Decoder& decoder) + { + ASTNodeType nodeType; + auto saved = decoder.getCursor(); + { + Decoder::WithObject withObject(decoder); + decode(nodeType, decoder); + } + decoder.setCursor(saved); + + auto shell = createEmptyShell(nodeType); + decodeASTNodeContent(shell, decoder); + + node = as<T>(shell); + } + + void decodePtr(Name*& name, Decoder& decoder, Name*) + { + String text; + decode(text, decoder); + + name = _astBuilder->getNamePool()->getName(text); + } + + void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) + { + // Mirroring the encoding logic, we decode this + // as a list of key-value pairs. + // + auto list = RefPtr(new DeclAssociationList()); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + auto association = RefPtr(new DeclAssociation()); + + Decoder::WithKeyValuePair withPair(decoder); + decode(association->kind, decoder); + decode(association->decl, decoder); + + list->associations.add(association); + } + + outList = list.detach(); + } + + void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) + { + Int id; + decode(id, decoder); + info = getDiagnosticsLookup()->getDiagnosticById(id); + } + + void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) + { + // TODO: is this case needed? + markupEntry = nullptr; + } + + void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) + { + auto result = RefPtr(new CandidateExtensionList()); + decode(result->candidateExtensions, decoder); + list = result.detach(); + } + + void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) + { + Decoder::WithObject withObject(decoder); + auto wt = RefPtr(new WitnessTable()); + decode(wt->baseType, decoder); + decode(wt->witnessedType, decoder); + decode(wt->isExtern, decoder); + decode(wt->m_requirementDictionary, decoder); + witnessTable = wt.detach(); + } + + void decodeValue(RequirementWitness& witness, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decodeEnum(witness.m_flavor, decoder); + switch (witness.m_flavor) + { + case RequirementWitness::Flavor::none: + break; + + case RequirementWitness::Flavor::declRef: + decode(witness.m_declRef, decoder); + break; + + case RequirementWitness::Flavor::val: + decode(witness.m_val, decoder); + break; + + case RequirementWitness::Flavor::witnessTable: + { + RefPtr<WitnessTable> object; + decode(object, decoder); + witness.m_obj = object; } + break; } + } - SerialReader reader(classes, nullptr); + template<typename T> + void decodePtr(T*& node, Decoder& decoder, Val*) + { + ValID id = decodeValID(decoder); + node = static_cast<T*>(getValByID(id)); + } + + template<typename T> + void decodePtr(T*& node, Decoder& decoder, Decl*) + { + DeclID id = decodeDeclID(decoder); + node = static_cast<T*>(getDeclByID(id)); + } + + template<typename T> + void decodePtr(T*& node, Decoder& decoder, DeclBase*) + { + if (decoder.getTag() == SerialBinary::kInt64FourCC) + { + DeclID id = decodeDeclID(decoder); + node = static_cast<T*>(getDeclByID(id)); + } + else { + decodeASTNode(node, decoder); + } + } + + template<typename T> + void decodePtr(T*& node, Decoder& decoder, NodeBase*) + { + decodeASTNode(node, decoder); + } + + + void decodeValue(UnhandledCase, Decoder& decoder); - SlangResult res = reader.load(contents.getBuffer(), contents.getCount(), &namePool); - SLANG_UNUSED(res); + void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } + + void decodeValue(Token& value, Decoder& decoder) + { + decode(value.type, decoder); + decode(value.flags, decoder); + decode(value.loc, decoder); + if (decoder.decodeNull()) + { } + else + { + Name* name = nullptr; + decode(name, decoder); + value.setName(name); + } + } + + void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } - // Lets see what we have - const ASTDumpUtil::Flags dumpFlags = - ASTDumpUtil::Flag::HideSourceLoc | ASTDumpUtil::Flag::HideScope; + void decodeValue(SemanticVersion& value, Decoder& decoder) + { + SemanticVersion::IntegerType rawValue = decoder.decode<SemanticVersion::IntegerType>(); + value.setFromInteger(rawValue); + } - String readDump; + void decodeValue(CapabilitySet& value, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) { - SourceWriter sourceWriter(sourceManager, LineDirectiveMode::None, nullptr); - ASTDumpUtil::dump( - reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>(), - ASTDumpUtil::Style::Hierachical, - dumpFlags, - &sourceWriter); - readDump = sourceWriter.getContentAndClear(); + CapabilityTargetSet targetSet; + decode(targetSet, decoder); + value.getCapabilityTargetSets()[targetSet.target] = targetSet; } - String origDump; + } + + void decodeValue(CapabilityTargetSet& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decode(value.target, decoder); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) { - SourceWriter sourceWriter(sourceManager, LineDirectiveMode::None, nullptr); - ASTDumpUtil::dump(node, ASTDumpUtil::Style::Hierachical, dumpFlags, &sourceWriter); - origDump = sourceWriter.getContentAndClear(); + CapabilityStageSet stageSet; + decode(stageSet, decoder); + value.shaderStageSets[stageSet.stage] = stageSet; } + } - // Write out - File::writeAllText("ast-read.ast-dump", readDump); - File::writeAllText("ast-orig.ast-dump", origDump); + void decodeValue(CapabilityStageSet& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decode(value.stage, decoder); + decode(value.atomSet, decoder); + } - if (readDump != origDump) + void decodeValue(CapabilityAtomSet& value, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) { - return SLANG_FAIL; + CapabilityAtom atom; + decode(atom, decoder); + value.add(UInt(atom)); } } - return SLANG_OK; -} + template<typename T> + void decodeValue(std::optional<T>& outValue, Decoder& decoder) + { + if (decoder.decodeNull()) + { + outValue.reset(); + } + else + { + T value; + decode(value, decoder); + outValue = value; + } + } -/* static */ List<uint8_t> ASTSerialUtil::serializeAST(ModuleDecl* moduleDecl) -{ - // TODO: we should store `classes` in GlobalSession to avoid recomputing them every time. - RefPtr<SerialClasses> classes; - SerialClassesUtil::create(classes); + void decodeValue(SyntaxClass<NodeBase>& syntaxClass, Decoder& decoder) + { + ASTNodeType nodeType; + decode(nodeType, decoder); + syntaxClass = SyntaxClass<NodeBase>(nodeType); + } - List<uint8_t> contents; - OwnedMemoryStream stream(FileAccess::ReadWrite); + template<typename T> + void decodeValue(DeclRef<T>& declRef, Decoder& decoder) + { + decode(declRef.declRefBase, decoder); + } - // Only serialize out things *in* this module - ModuleSerialFilter filterStorage(moduleDecl); + void decodeValue(ValNodeOperand& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); - SerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + decodeEnum(value.kind, decoder); + switch (value.kind) + { + case ValNodeOperandKind::ConstantValue: + decode(value.values.intOperand, decoder); + break; - SerialWriter writer(classes, filter); + case ValNodeOperandKind::ValNode: + { + Val* val = nullptr; + decode(val, decoder); + value.values.nodeOperand = val; + } + break; - // Lets serialize it all - writer.addPointer(moduleDecl); - // Let's stick it all in a stream - writer.write(&stream); + case ValNodeOperandKind::ASTNode: + { + Decl* decl = nullptr; + decode(decl, decoder); + value.values.nodeOperand = decl; + } + break; + } + } + + void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + + void decodeValue(QualType& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.type, decoder); + decode(value.isLeftValue, decoder); + decode(value.hasReadOnlyOnTarget, decoder); + decode(value.isWriteOnly, decoder); + } + + void decodeValue(MatrixCoord& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.row, decoder); + decode(value.col, decoder); + } + + void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + + void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.flavor, decoder); + decode(value.token, decoder); + decode(value.expr, decoder); + decode(value.bitwiseOrWith, decoder); + decode(value.knownValue, decoder); + decode(value.wrapInId, decoder); + decode(value.type, decoder); + } + + void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.opcode, decoder); + decode(value.operands, decoder); + } + + + template<typename T> + void decodeEnum(T& value, Decoder& decoder) + { + value = T(decoder.decode<Int32>()); + } + + template<typename T> + void decodeSimpleValue(T& value, Decoder& decoder) + { + value = decoder.decode<T>(); + } + + void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } + void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + + void decodeValue(uint8_t& value, Decoder& decoder) + { + value = uint8_t(decoder.decode<UInt32>()); + } + + void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } + + + void decodeValue(SourceLoc& value, Decoder& decoder) + { + if (!decoder.decodeNull()) + { + SerialSourceLocData::SourceLoc intermediate; + decoder.decode(intermediate); + + if (_sourceLocReader) + { + auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); + value = sourceLoc; + } + } + } + + template<typename T> + void decodeValue(T*& ptr, Decoder& decoder) + { + if (decoder.decodeNull()) + ptr = nullptr; + else + decodePtr(ptr, decoder, (T*)nullptr); + } + + template<typename T> + void decodeValue(RefPtr<T>& ptr, Decoder& decoder) + { + if (decoder.decodeNull()) + ptr = nullptr; + else + { + // Hi Future Tess, + // + // The next step here is decoding logic for `WitnessTable`s. + // + + decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); + } + } + + void decodeValue(Modifiers& modifiers, Decoder& decoder) + { + Modifier** link = &modifiers.first; + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Modifier* modifier = nullptr; + decode(modifier, decoder); + + *link = modifier; + link = &modifier->next; + } + } + + template<typename T, int N> + void decodeValue(ShortList<T, N>& array, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + T element; + decode(element, decoder); + array.add(element); + } + } - stream.swapContents(contents); - return contents; -} + template<typename T> + void decode(List<T>& array, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + T element; + decode(element, decoder); + array.add(element); + } + } + + template<typename T, size_t N> + void decode(T (&array)[N], Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + for (auto& element : array) + { + decode(element, decoder); + } + } + + template<typename K, typename V> + void decode(OrderedDictionary<K, V>& dictionary, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + K key; + V value; + decode(key, decoder); + decode(value, decoder); + + dictionary.add(key, value); + } + } + + template<typename K, typename V> + void decode(Dictionary<K, V>& dictionary, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + K key; + V value; + decode(key, decoder); + decode(value, decoder); + + dictionary.add(key, value); + } + } + + template<typename T> + void decode(T& outValue, Decoder& decoder) + { + decodeValue(outValue, decoder); + } + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + void _decodeDataOf($T* obj, Decoder& decoder) + { +% if T.directSuperClass then + _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); +% end +% for _,f in ipairs(T.directFields) do + decode(obj->$f, decoder); +% end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END +}; + +ModuleDecl* readSerializedModuleAST( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* chunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) +{ + ASTDecodingContext + context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); + context.decodeAll(); + auto node = context.getDeclByID(0); + auto moduleDecl = as<ModuleDecl>(node); + return moduleDecl; +} } // namespace Slang diff --git a/source/slang/slang-serialize-ast.h b/source/slang/slang-serialize-ast.h index b8af9484c..6adeae8dd 100644 --- a/source/slang/slang-serialize-ast.h +++ b/source/slang/slang-serialize-ast.h @@ -6,53 +6,23 @@ #include "slang-ast-all.h" #include "slang-ast-builder.h" #include "slang-ast-support-types.h" +#include "slang-serialize-source-loc.h" #include "slang-serialize.h" namespace Slang { - -/* Holds RIFF FourCC codes for AST types */ -struct ASTSerialBinary -{ - static const FourCC kRiffFourCC = RiffFourCC::kRiff; - - /// AST module LIST container - static const FourCC kSlangASTModuleFourCC = SLANG_FOUR_CC('S', 'A', 'm', 'l'); - /// AST module data - static const FourCC kSlangASTModuleDataFourCC = SLANG_FOUR_CC('S', 'A', 'm', 'd'); -}; - -class ModuleSerialFilter : public SerialFilter -{ -public: - // SerialFilter impl - virtual SerialIndex writePointer(SerialWriter* writer, const NodeBase* ptr) SLANG_OVERRIDE; - virtual SerialIndex writePointer(SerialWriter* writer, const RefObject* ptr) SLANG_OVERRIDE; - - ModuleSerialFilter(ModuleDecl* moduleDecl) - : m_moduleDecl(moduleDecl) - { - } - -protected: - ModuleDecl* m_moduleDecl; -}; - -struct ASTSerialUtil -{ - /// Add the AST related classes - static void addSerialClasses(SerialClasses* classes); - - /// Tries to serialize out, read back in and test the results are the same. - /// Will write dumped out node to files - static SlangResult testSerialize( - NodeBase* node, - RootNamePool* rootNamePool, - SharedASTBuilder* sharedASTBuilder, - SourceManager* sourceManager); - - static List<uint8_t> serializeAST(ModuleDecl* moduleDecl); -}; +void writeSerializedModuleAST( + Encoder* encoder, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter); + +ModuleDecl* readSerializedModuleAST( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* chunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc); } // namespace Slang diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index f82357459..c2253ed45 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -10,89 +10,239 @@ #include "slang-mangled-lexer.h" #include "slang-parser.h" #include "slang-serialize-ast.h" -#include "slang-serialize-factory.h" #include "slang-serialize-ir.h" #include "slang-serialize-source-loc.h" namespace Slang { - -/* static */ SlangResult SerialContainerUtil::write( - Module* module, - const WriteOptions& options, - Stream* stream) +struct ModuleEncodingContext { - RiffContainer container; +public: + ModuleEncodingContext(SerialContainerUtil::WriteOptions const& options, Stream* stream) + : options(options), encoder(stream), containerStringPool(StringSlicePool::Style::Default) { - SerialContainerData data; - SLANG_RETURN_ON_FAIL(SerialContainerUtil::addModuleToData(module, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); + if (options.optionFlags & SerialOptionFlag::SourceLocation) + { + sourceLocWriter = new SerialSourceLocWriter(options.sourceManager); + } } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::write( - FrontEndCompileRequest* frontEndReq, - const WriteOptions& options, - Stream* stream) -{ - RiffContainer container; + ~ModuleEncodingContext() + { + encoder.setRIFFChunk(encoder.getRIFF()->getRoot()); + encodeFinalPieces(); + } + + SlangResult encodeModuleList(FrontEndCompileRequest* frontEndReq) + { + // Encoding a front-end compile request into a RIFF + // is simply a matter of encoding the module for each + // of the translation units that got compiled. + // + Encoder::WithKeyValuePair withArray(&encoder, SerialBinary::kModuleListFourCc); + for (TranslationUnitRequest* translationUnit : frontEndReq->translationUnits) + { + SLANG_RETURN_ON_FAIL(encode(translationUnit->module)); + } + return SLANG_OK; + } + + SlangResult encode(FrontEndCompileRequest* frontEndReq) + { + Encoder::WithObject withObject(&encoder, SerialBinary::kContainerFourCc); + SLANG_RETURN_ON_FAIL(encodeModuleList(frontEndReq)); + return SLANG_OK; + } + + SlangResult encode(EndToEndCompileRequest* request) + { + Encoder::WithObject withObject(&encoder, SerialBinary::kContainerFourCc); + + // Encoding an end-to-end compile request starts with the same + // work as for a front-end request: we encode each of + // the modules for the translation units. + // + SLANG_RETURN_ON_FAIL(encodeModuleList(request->getFrontEndReq())); + // + // If code generation is disabled, then we can skip all further + // steps, and the encoding process is no different + // than for a front-end request. + // + if (request->getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) + { + return SLANG_OK; + } + + // If code generation is enabled, then we need to encode + // information on each of the code generation targets, as well + // as the entry points. + // + // We start with the targets, each of which will have a Slang IR + // representation of the layout information for the program + // on that target. + // + auto linkage = request->getLinkage(); + auto sink = request->getSink(); + auto program = request->getSpecializedGlobalAndEntryPointsComponentType(); + { + Encoder::WithArray withArray(&encoder); // kContainerFourCc + + for (auto target : linkage->targets) + { + auto targetProgram = program->getTargetProgram(target); + encode(targetProgram, sink); + } + } + + // The compiled `program` may also have zero or more entry points, + // and we need to encode information about each of them. + // + { + Encoder::WithArray withArray(&encoder, SerialBinary::kEntryPointListFourCc); + + auto entryPointCount = program->getEntryPointCount(); + for (Index ii = 0; ii < entryPointCount; ++ii) + { + auto entryPoint = program->getEntryPoint(ii); + auto entryPointMangledName = program->getEntryPointMangledName(ii); + encode(entryPoint, entryPointMangledName); + } + } + + return SLANG_OK; + } + + SlangResult encode(TargetProgram* targetProgram, DiagnosticSink* sink) { - SerialContainerData data; + // TODO: + // Serialization of target component IR is causing the embedded precompiled binary + // feature to fail. The resulting data modules contain both TU IR and TC IR, with only + // one module header. Yong suggested to ignore the TC IR for now, though also that + // OV was using the feature, so disabling this might cause problems. + + IRModule* irModule = targetProgram->getOrCreateIRModuleForLayout(sink); + + // Okay, we need to serialize this target program and its IR too... + IRSerialData serialData; + IRSerialWriter writer; + SLANG_RETURN_ON_FAIL( - SerialContainerUtil::addFrontEndRequestToData(frontEndReq, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); + writer.write(irModule, sourceLocWriter, options.optionFlags, &serialData)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, encoder.getRIFF())); + + return SLANG_OK; } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::write( - EndToEndCompileRequest* request, - const WriteOptions& options, - Stream* stream) -{ - RiffContainer container; + void encode(Name* name) { encoder.encode(name->text); } + + void encode(String const& value) { encoder.encode(value); } + + void encode(uint32_t value) { encoder.encode(UInt(value)); } + + void encodeData(void const* data, size_t size) { encoder.encodeData(data, size); } + + SlangResult encode(EntryPoint* entryPoint, String const& entryPointMangledName) { - SerialContainerData data; - SLANG_RETURN_ON_FAIL(SerialContainerUtil::addEndToEndRequestToData(request, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); + Encoder::WithObject withObject(&encoder, SerialBinary::kEntryPointFourCc); + + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kNameFourCC); + encode(entryPoint->getName()); + } + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kProfileFourCC); + encode(entryPoint->getProfile().raw); + } + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kMangledNameFourCC); + encode(entryPointMangledName); + } + + return SLANG_OK; } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::addModuleToData( - Module* module, - const WriteOptions& options, - SerialContainerData& outData) -{ - if (options.optionFlags & (SerialOptionFlag::ASTModule | SerialOptionFlag::IRModule)) + + SlangResult encode(Module* module) { - SerialContainerData::Module dstModule; + if (!(options.optionFlags & (SerialOptionFlag::IRModule | SerialOptionFlag::ASTModule))) + return SLANG_OK; - // NOTE: The astBuilder is not set here, as not needed to be scoped for serialization (it is - // assumed the TranslationUnitRequest stays in scope) + Encoder::WithObject withModule(&encoder, SerialBinary::kModuleFourCC); - if (options.optionFlags & SerialOptionFlag::ASTModule) + // The first piece that we write for a module is its header. + // The header is intended to provide information that can be + // used to determine if a precompiled module is up-to-date. + // + // Update(tfoley): Okay, let's skip the whole header idea and just + // serialize these things as properties of the module itself... { - // Root AST node - auto moduleDecl = module->getModuleDecl(); - SLANG_ASSERT(moduleDecl); + // So many things need the module name, that it makes + // sense to serialize it separately from all the rest. + // + { + Encoder::WithObject withProperty(&encoder, SerialBinary::kNameFourCC); + encoder.encodeString(module->getNameObj()->text); + } + + // The header includes a digest of all the compile options and + // the files that the compiled result depended on. + // + auto digest = module->computeDigest(); + encoder.encodeData(PropertyKeys<Module>::Digest, digest.data, sizeof(digest.data)); - dstModule.astRootNode = moduleDecl; + // The header includes an array of the paths of all of the + // files that the compiled result depended on. + // + encodeModuleDependencyPaths(module); } - if (options.optionFlags & SerialOptionFlag::IRModule) + + // If serialization of Slang IR modules is enabled, and there + // is IR available for this module, then we we encode it. + // + if ((options.optionFlags & SerialOptionFlag::IRModule)) { - // IR module - dstModule.irModule = module->getIRModule(); - SLANG_ASSERT(dstModule.irModule); + if (auto irModule = module->getIRModule()) + { + Encoder::WithKeyValuePair withKey(&encoder, PropertyKeys<Module>::IRModule); + + IRSerialData serialData; + IRSerialWriter writer; + SLANG_RETURN_ON_FAIL( + writer.write(irModule, sourceLocWriter, options.optionFlags, &serialData)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, encoder.getRIFF())); + } } + // If serialization of AST information is enabled, and we have AST + // information available, then we serialize it here. + // + if (options.optionFlags & SerialOptionFlag::ASTModule) + { + if (auto moduleDecl = module->getModuleDecl()) + { + Encoder::WithKeyValuePair withKey(&encoder, PropertyKeys<Module>::ASTModule); + + writeSerializedModuleAST(&encoder, moduleDecl, sourceLocWriter); + } + } + + return SLANG_OK; + } + + SlangResult encodeModuleDependencyPaths(Module* module) + { + Encoder::WithObject withProperty(&encoder, PropertyKeys<Module>::FileDependencies); + + // TODO(tfoley): This is some of the most complicated logic + // in the encoding system, because it tries to translate + // the file dependency paths into something that isn't + // specific to the machine on which a module was built. + // + // The comments that follow are from the original implementation + // of this logic, because I cannot state with confidence + // that I know what's happening in all of this. + + // Here we assume that the first file in the file dependencies is the module's file path. // We store the module's file path as a relative path with respect to the first search // directory that contains the module, and store the paths of dependent files as relative @@ -155,6 +305,7 @@ namespace Slang } Path::getCanonical(linkageRoot, linkageRoot); + Encoder::WithArray withArray(&encoder); for (auto file : fileDependencies) { if (file->getPathInfo().hasFoundPath()) @@ -170,728 +321,314 @@ namespace Slang { auto relativeModulePath = Path::getRelativePath(linkageRoot, canonicalModulePath); - dstModule.dependentFiles.add(relativeModulePath); + + encoder.encodeString(relativeModulePath); } else { // For all other dependnet files, store them as relative paths with respect // to the module's path. canonicalFilePath = Path::getRelativePath(moduleDir, canonicalFilePath); - dstModule.dependentFiles.add(canonicalFilePath); + encoder.encodeString(canonicalFilePath); } } else { // If the module is coming from string instead of an actual file, store it as // is. - dstModule.dependentFiles.add(canonicalModulePath); + encoder.encodeString(canonicalModulePath); } } else { - dstModule.dependentFiles.add(file->getPathInfo().getMostUniqueIdentity()); + encoder.encodeString(file->getPathInfo().getMostUniqueIdentity()); } } - dstModule.digest = module->computeDigest(); - outData.modules.add(dstModule); - } - return SLANG_OK; -} - -/* static */ SlangResult SerialContainerUtil::addFrontEndRequestToData( - FrontEndCompileRequest* frontEndReq, - const WriteOptions& options, - SerialContainerData& outData) -{ - // Go through translation units, adding modules - for (TranslationUnitRequest* translationUnit : frontEndReq->translationUnits) - { - SLANG_RETURN_ON_FAIL(addModuleToData(translationUnit->module, options, outData)); - } - - return SLANG_OK; -} - -/* static */ SlangResult SerialContainerUtil::addEndToEndRequestToData( - EndToEndCompileRequest* request, - const WriteOptions& options, - SerialContainerData& out) -{ - auto linkage = request->getLinkage(); - auto sink = request->getSink(); - - // Output the parsed modules. - addFrontEndRequestToData(request->getFrontEndReq(), options, out); - - // If we are skipping code generation, then we are done. - if (request->getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) - { return SLANG_OK; } - // - auto program = request->getSpecializedGlobalAndEntryPointsComponentType(); - // Add all the target modules + SlangResult encodeFinalPieces() { - for (auto target : linkage->targets) + // We can now output the debug information. This is for all IR and AST + if (sourceLocWriter) { - auto targetProgram = program->getTargetProgram(target); - auto irModule = targetProgram->getOrCreateIRModuleForLayout(sink); - - SerialContainerData::TargetComponent targetComponent; + // Write out the debug info + SerialSourceLocData debugData; + sourceLocWriter->write(&debugData); - targetComponent.irModule = irModule; - - auto& dstTarget = targetComponent.target; - - dstTarget.floatingPointMode = target->getOptionSet().getFloatingPointMode(); - dstTarget.profile = target->getOptionSet().getProfile(); - dstTarget.flags = target->getOptionSet().getTargetFlags(); - dstTarget.codeGenTarget = target->getTarget(); - - out.targetComponents.add(targetComponent); + debugData.writeContainer(encoder.getRIFF()); } - } - // Entry points - { - auto entryPointCount = program->getEntryPointCount(); - for (Index ii = 0; ii < entryPointCount; ++ii) + // Write the container string table + if (containerStringPool.getAdded().getCount() > 0) { - auto entryPoint = program->getEntryPoint(ii); - auto entryPointMangledName = program->getEntryPointMangledName(ii); - - SerialContainerData::EntryPoint dstEntryPoint; + Encoder::WithKeyValuePair withKey(&encoder, SerialBinary::kStringTableFourCc); - dstEntryPoint.name = entryPoint->getName(); - dstEntryPoint.mangledName = entryPointMangledName; - dstEntryPoint.profile = entryPoint->getProfile(); + List<char> encodedTable; + SerialStringTableUtil::encodeStringTable(containerStringPool, encodedTable); - out.entryPoints.add(dstEntryPoint); + encoder.encodeData(encodedTable.getBuffer(), encodedTable.getCount()); } + + return SLANG_OK; } - return SLANG_OK; -} -/* static */ SlangResult SerialContainerUtil::write( - const SerialContainerData& data, - const WriteOptions& options, - RiffContainer* container) -{ +private: + SerialContainerUtil::WriteOptions const& options; RefPtr<SerialSourceLocWriter> sourceLocWriter; // The string pool used across the whole of the container - StringSlicePool containerStringPool(StringSlicePool::Style::Default); + StringSlicePool containerStringPool; - RiffContainer::ScopeChunk scopeModule( - container, - RiffContainer::Chunk::Kind::List, - SerialBinary::kContainerFourCc); + Encoder encoder; +}; - if (data.modules.getCount() && - (options.optionFlags & (SerialOptionFlag::IRModule | SerialOptionFlag::ASTModule))) - { - // Module list - RiffContainer::ScopeChunk moduleListScope( - container, - RiffContainer::Chunk::Kind::List, - SerialBinary::kModuleListFourCc); - - if (options.optionFlags & SerialOptionFlag::SourceLocation) - { - sourceLocWriter = new SerialSourceLocWriter(options.sourceManager); - } - - RefPtr<SerialClasses> serialClasses; - - for (const auto& module : data.modules) - { - // Okay, we need to serialize this module to our container file. - // We currently don't serialize it's name..., but support for that could be added. +// +// To serialize a module (or compile request) to a stream, we first +// construct a RIFF container from it, and then serialize that +// container out to a byte stream. +// - // First, we write a header that can be used to verify if the precompiled module is - // up-to-date. The header has: 1) a digest of all compile options and dependent source - // files. 2) a list of source file paths. - // - { - RiffContainer::ScopeChunk scopeHeader( - container, - RiffContainer::Chunk::Kind::Data, - SerialBinary::kModuleHeaderFourCc); - OwnedMemoryStream headerMemStream(FileAccess::Write); - StringBuilder filePathsSB; - for (auto fileDependency : module.dependentFiles) - filePathsSB << fileDependency << "\n"; - headerMemStream.write(module.digest.data, sizeof(module.digest.data)); - uint32_t fileListLength = (uint32_t)filePathsSB.getLength(); - headerMemStream.write(&fileListLength, sizeof(uint32_t)); - headerMemStream.write(filePathsSB.getBuffer(), fileListLength); - container->write( - headerMemStream.getContents().getBuffer(), - headerMemStream.getContents().getCount()); - } - - // Write the IR information - if ((options.optionFlags & SerialOptionFlag::IRModule) && module.irModule) - { - IRSerialData serialData; - IRSerialWriter writer; - SLANG_RETURN_ON_FAIL(writer.write( - module.irModule, - sourceLocWriter, - options.optionFlags, - &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, container)); - } - - // Write the AST information - - if (options.optionFlags & SerialOptionFlag::ASTModule) - { - if (ModuleDecl* moduleDecl = as<ModuleDecl>(module.astRootNode)) - { - // Put in AST module - RiffContainer::ScopeChunk scopeASTModule( - container, - RiffContainer::Chunk::Kind::List, - ASTSerialBinary::kSlangASTModuleFourCC); - - if (!serialClasses) - { - SLANG_RETURN_ON_FAIL(SerialClassesUtil::create(serialClasses)); - } - - ModuleSerialFilter filter(moduleDecl); - auto astWriterFlag = SerialWriter::Flag::ZeroInitialize; - if ((options.optionFlags & SerialOptionFlag::ASTFunctionBody) == 0) - astWriterFlag = (SerialWriter::Flag::Enum)( - astWriterFlag | SerialWriter::Flag::SkipFunctionBody); - - SerialWriter writer(serialClasses, &filter, astWriterFlag); - - writer.getExtraObjects().set(sourceLocWriter); - - // Add the module and everything that isn't filtered out in the filter. - writer.addPointer(moduleDecl); +/* static */ SlangResult SerialContainerUtil::write( + Module* module, + const WriteOptions& options, + Stream* stream) +{ + ModuleEncodingContext context(options, stream); + SLANG_RETURN_ON_FAIL(context.encode(module)); + return SLANG_OK; +} +/* static */ SlangResult SerialContainerUtil::write( + FrontEndCompileRequest* request, + const WriteOptions& options, + Stream* stream) +{ + ModuleEncodingContext context(options, stream); + SLANG_RETURN_ON_FAIL(context.encode(request)); + return SLANG_OK; +} - // We can now serialize it into the riff container. - SLANG_RETURN_ON_FAIL(writer.writeIntoContainer( - ASTSerialBinary::kSlangASTModuleDataFourCC, - container)); - } - } - } +/* static */ SlangResult SerialContainerUtil::write( + EndToEndCompileRequest* request, + const WriteOptions& options, + Stream* stream) +{ + ModuleEncodingContext context(options, stream); + SLANG_RETURN_ON_FAIL(context.encode(request)); + return SLANG_OK; +} - // TODO: - // Serialization of target component IR is causing the embedded precompiled binary - // feature to fail. The resulting data modules contain both TU IR and TC IR, with only - // one module header. Yong suggested to ignore the TC IR for now, though also that - // OV was using the feature, so disabling this might cause problems. -#if 0 - if (data.targetComponents.getCount() && (options.optionFlags & SerialOptionFlag::IRModule)) - { - // TODO: in the case where we have specialization, we might need - // to serialize IR related to `program`... +String StringChunkRef::getValue() +{ + return Decoder(ptr()).decodeString(); +} - for (const auto& targetComponent : data.targetComponents) - { - IRModule* irModule = targetComponent.irModule; +ChunkRefList<StringChunkRef> ModuleChunkRef::getFileDependencies() +{ + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, PropertyKeys<Module>::FileDependencies); + return ChunkRefList<StringChunkRef>(as<RiffContainer::ListChunk>(decoder.getCursor())); +} - // Okay, we need to serialize this target program and its IR too... - IRSerialData serialData; - IRSerialWriter writer; +ModuleChunkRef ModuleChunkRef::find(RiffContainer* container) +{ + auto found = container->getRoot()->findListRec(SerialBinary::kModuleFourCC); + return ModuleChunkRef(found); +} - SLANG_RETURN_ON_FAIL(writer.write(irModule, sourceLocWriter, options.optionFlags, &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, options.compressionType, container)); - } - } -#endif +SHA1::Digest ModuleChunkRef::getDigest() +{ + auto foundChunk = + static_cast<RiffContainer::DataChunk*>(ptr()->findContained(PropertyKeys<Module>::Digest)); + if (!foundChunk) + { + SLANG_UNEXPECTED("module chunk had no digest"); } - - if (data.entryPoints.getCount()) + if (foundChunk->calcPayloadSize() != sizeof(SHA1::Digest)) { - for (const auto& entryPoint : data.entryPoints) - { - RiffContainer::ScopeChunk entryPointScope( - container, - RiffContainer::Chunk::Kind::Data, - SerialBinary::kEntryPointFourCc); - - SerialContainerBinary::EntryPoint dst; - - dst.name = uint32_t(containerStringPool.add(entryPoint.name->text)); - dst.profile = entryPoint.profile.raw; - dst.mangledName = uint32_t(containerStringPool.add(entryPoint.mangledName)); - - container->write(&dst, sizeof(dst)); - } + SLANG_UNEXPECTED("module digest chunk had wrong size"); } - // We can now output the debug information. This is for all IR and AST - if (sourceLocWriter) - { - // Write out the debug info - SerialSourceLocData debugData; - sourceLocWriter->write(&debugData); + SHA1::Digest digest; + foundChunk->getPayload(&digest); + return digest; +} - debugData.writeContainer(container); - } +String ModuleChunkRef::getName() +{ + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kNameFourCC); + return decoder.decodeString(); +} - // Write the container string table - if (containerStringPool.getAdded().getCount() > 0) - { - RiffContainer::ScopeChunk stringTableScope( - container, - RiffContainer::Chunk::Kind::Data, - SerialBinary::kStringTableFourCc); - List<char> encodedTable; - SerialStringTableUtil::encodeStringTable(containerStringPool, encodedTable); +IRModuleChunkRef ModuleChunkRef::findIR() +{ + auto foundProperty = ptr()->findContainedList(PropertyKeys<Module>::IRModule); + if (!foundProperty) + return IRModuleChunkRef(nullptr); + return IRModuleChunkRef( + static_cast<RiffContainer::ListChunk*>(foundProperty->getFirstContainedChunk())); +} - container->write(encodedTable.getBuffer(), encodedTable.getCount()); - } +ASTModuleChunkRef ModuleChunkRef::findAST() +{ + auto foundProperty = ptr()->findContainedList(PropertyKeys<Module>::ASTModule); + if (!foundProperty) + return ASTModuleChunkRef(nullptr); + return ASTModuleChunkRef( + static_cast<RiffContainer::ListChunk*>(foundProperty->getFirstContainedChunk())); +} - return SLANG_OK; +ContainerChunkRef ContainerChunkRef::find(RiffContainer* container) +{ + auto found = container->getRoot()->findListRec(SerialBinary::kContainerFourCc); + return ContainerChunkRef(found); } +ChunkRefList<ModuleChunkRef> ContainerChunkRef::getModules() +{ + auto found = ptr()->findContainedList(SerialBinary::kModuleListFourCc); + return ChunkRefList<ModuleChunkRef>(found); +} -static List<ExtensionDecl*>& _getCandidateExtensionList( - AggTypeDecl* typeDecl, - Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>>& mapTypeToCandidateExtensions) +ChunkRefList<EntryPointChunkRef> ContainerChunkRef::getEntryPoints() { - RefPtr<CandidateExtensionList> entry; - if (!mapTypeToCandidateExtensions.tryGetValue(typeDecl, entry)) - { - entry = new CandidateExtensionList(); - mapTypeToCandidateExtensions.add(typeDecl, entry); - } - return entry->candidateExtensions; + auto found = ptr()->findContainedList(SerialBinary::kEntryPointListFourCc); + return ChunkRefList<EntryPointChunkRef>(found); } -/* static */ Result SerialContainerUtil::read( - RiffContainer* container, - const ReadOptions& options, - const LoadedModuleDictionary* additionalLoadedModules, - SerialContainerData& out) +String EntryPointChunkRef::getMangledName() const { - out.clear(); + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kMangledNameFourCC); + return decoder.decodeString(); +} - RiffContainer::ListChunk* containerChunk = - container->getRoot()->findListRec(SerialBinary::kContainerFourCc); - if (!containerChunk) - { - // Must be a container - return SLANG_FAIL; - } +String EntryPointChunkRef::getName() const +{ + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kNameFourCC); + return decoder.decodeString(); +} - StringSlicePool containerStringPool(StringSlicePool::Style::Default); +Profile EntryPointChunkRef::getProfile() const +{ + // TODO(tfoley): This kind of logic needs a way + // to be greatly simplified, so that we don't + // have to express such complicated logic for + // simply extracting a single string property... + // + Decoder decoder(ptr()); + Decoder::WithProperty withProperty(decoder, SerialBinary::kProfileFourCC); - if (RiffContainer::Data* stringTableData = - containerChunk->findContainedData(SerialBinary::kStringTableFourCc)) - { - SerialStringTableUtil::decodeStringTable( - (const char*)stringTableData->getPayload(), - stringTableData->getSize(), - containerStringPool); - } + Profile::RawVal rawVal; + decoder.decode(rawVal); - RefPtr<SerialSourceLocReader> sourceLocReader; - RefPtr<SerialClasses> serialClasses; + return Profile(rawVal); +} - // Debug information - if (auto debugChunk = containerChunk->findContainedList(SerialSourceLocData::kDebugFourCc)) - { - // Read into data - SerialSourceLocData sourceLocData; - SLANG_RETURN_ON_FAIL(sourceLocData.readContainer(debugChunk)); - // Turn into DebugReader - sourceLocReader = new SerialSourceLocReader; - SLANG_RETURN_ON_FAIL(sourceLocReader->read(&sourceLocData, options.sourceManager)); - } +RiffContainer::ListChunk* findDebugChunk(RiffContainer::Chunk* startingChunk) +{ + if (!startingChunk) + return nullptr; - // Create a source loc representing the binary module. - SourceLoc binaryModuleLoc = SourceLoc(); + RiffContainer::ListChunk* container = as<RiffContainer::ListChunk>(startingChunk); + if (!container) + container = startingChunk->m_parent; - if (options.modulePath.getLength()) + for (; container; container = container->m_parent) { - auto srcManager = options.linkage->getSourceManager(); - auto modulePathInfo = PathInfo::makePath(options.modulePath); - auto srcFile = srcManager->findSourceFileByPathRecursively(modulePathInfo.foundPath); - if (!srcFile) + if (auto debugChunk = container->findContainedList(SerialSourceLocData::kDebugFourCc)) { - srcFile = srcManager->createSourceFileWithString(modulePathInfo, String()); - srcManager->addSourceFile(options.modulePath, srcFile); + return debugChunk; } - auto srcView = srcManager->createSourceView(srcFile, &modulePathInfo, SourceLoc()); - binaryModuleLoc = srcView->getRange().begin; } - // Add modules - if (RiffContainer::ListChunk* moduleList = - containerChunk->findContainedList(SerialBinary::kModuleListFourCc)) - { - RiffContainer::Chunk* chunk = moduleList->getFirstContainedChunk(); - while (chunk) - { - auto startChunk = chunk; - - RefPtr<ASTBuilder> astBuilder = options.astBuilder; - NodeBase* astRootNode = nullptr; - RefPtr<IRModule> irModule; - SerialContainerData::Module module; - if (auto headerChunk = - as<RiffContainer::DataChunk>(chunk, SerialBinary::kModuleHeaderFourCc)) - { - MemoryStreamBase memStream( - FileAccess::Read, - headerChunk->getSingleData()->getPayload(), - headerChunk->getSingleData()->getSize()); - size_t readSize = 0; - memStream.read(module.digest.data, sizeof(SHA1::Digest), readSize); - if (readSize != sizeof(SHA1::Digest)) - return SLANG_FAIL; - uint32_t fileListLength = 0; - memStream.read(&fileListLength, sizeof(uint32_t), readSize); - if (readSize != sizeof(uint32_t)) - return SLANG_FAIL; - List<uint8_t> fileListContent; - fileListContent.setCount(fileListLength); - memStream.read(fileListContent.getBuffer(), fileListContent.getCount(), readSize); - if (readSize != (size_t)fileListContent.getCount()) - return SLANG_FAIL; - UnownedStringSlice fileListString( - (const char*)fileListContent.getBuffer(), - fileListContent.getCount()); - List<UnownedStringSlice> fileList; - StringUtil::split(fileListString, '\n', fileList); - for (auto file : fileList) - { - if (file.getLength()) - { - module.dependentFiles.add(file); - } - } - // Onto next chunk - chunk = chunk->m_next; - } - - if (auto irChunk = as<RiffContainer::ListChunk>(chunk, IRSerialBinary::kIRModuleFourCc)) - { - if (!options.readHeaderOnly) - { - IRSerialData serialData; - SLANG_RETURN_ON_FAIL(IRSerialReader::readContainer(irChunk, &serialData)); - - // Read IR back from serialData - IRSerialReader reader; - SLANG_RETURN_ON_FAIL( - reader.read(serialData, options.session, sourceLocReader, irModule)); - } - - // Onto next chunk - chunk = chunk->m_next; - } - - if (auto astChunk = - as<RiffContainer::ListChunk>(chunk, ASTSerialBinary::kSlangASTModuleFourCC)) - { - if (!options.readHeaderOnly) - { - RiffContainer::Data* astData = - astChunk->findContainedData(ASTSerialBinary::kSlangASTModuleDataFourCC); - - if (astData) - { - if (!serialClasses) - { - SLANG_RETURN_ON_FAIL(SerialClassesUtil::create(serialClasses)); - } - - // TODO(JS): We probably want to store off better information about each of - // the translation unit including some kind of 'name'. For now we just - // generate a name. - - StringBuilder buf; - buf << "tu" << out.modules.getCount(); - if (!astBuilder) - { - astBuilder = - new ASTBuilder(options.sharedASTBuilder, buf.produceString()); - } - - /// We need to make the current ASTBuilder available for access via - /// thread_local global. - SetASTBuilderContextRAII astBuilderRAII(astBuilder); - - DefaultSerialObjectFactory objectFactory(astBuilder); - - SerialReader reader(serialClasses, &objectFactory); - - // Sets up the entry table - one entry for each 'object'. - // No native objects are constructed. No objects are deserialized. - SLANG_RETURN_ON_FAIL(reader.loadEntries( - (const uint8_t*)astData->getPayload(), - astData->getSize())); - - // Construct a native object for each table entry (where appropriate). - // Note that this *doesn't* set all object pointers - some are special cased - // and created on demand (strings) and imported symbols will have their - // object pointers unset (they are resolved in next step) - SLANG_RETURN_ON_FAIL(reader.constructObjects(options.namePool)); - - // Resolve external references if the linkage is specified - if (options.linkage) - { - const auto& entries = reader.getEntries(); - auto& objects = reader.getObjects(); - const Index entriesCount = entries.getCount(); - - String currentModuleName; - Module* currentModule = nullptr; - - // Index from 1 (0 is null) - for (Index i = 1; i < entriesCount; ++i) - { - const SerialInfo::Entry* entry = entries[i]; - if (entry->typeKind == SerialTypeKind::ImportSymbol) - { - // Import symbols are always serialized with a mangled name in - // the form of <module_name>!<symbol_mangled_name>. As - // symbol_mangled_name may not contain the name of its parent - // module in the case of an `extern` or `export` symbol. - // - UnownedStringSlice mangledName = - reader.getStringSlice(SerialIndex(i)); - List<UnownedStringSlice> slicesOut; - StringUtil::split(mangledName, '!', slicesOut); - if (slicesOut.getCount() != 2) - return SLANG_FAIL; - auto moduleName = slicesOut[0]; - mangledName = slicesOut[1]; - - // If we already have looked up this module and it has the same - // name just use what we have - Module* readModule = nullptr; - if (currentModule && - moduleName == currentModuleName.getUnownedSlice()) - { - readModule = currentModule; - } - else - { - // The modules are loaded on the linkage. - Linkage* linkage = options.linkage; - - NamePool* namePool = linkage->getNamePool(); - Name* moduleNameName = namePool->getName(moduleName); - readModule = linkage->findOrImportModule( - moduleNameName, - binaryModuleLoc, - options.sink, - additionalLoadedModules); - if (!readModule) - { - return SLANG_FAIL; - } - - // Set the current module and name - currentModule = readModule; - currentModuleName = moduleName; - } - - // Look up the symbol - NodeBase* nodeBase = - readModule->findExportFromMangledName(mangledName); - - if (!nodeBase) - { - if (options.sink) - { - options.sink->diagnose( - SourceLoc::fromRaw(0), - Diagnostics::unableToFindSymbolInModule, - mangledName, - moduleName); - } - - // If didn't find the export then we create an - // UnresolvedDecl node to represent the error. - auto unresolved = astBuilder->create<UnresolvedDecl>(); - unresolved->nameAndLoc.name = - options.linkage->getNamePool()->getName(mangledName); - nodeBase = unresolved; - } - - // set the result - objects[i] = nodeBase; - } - } - } - - // Set the sourceLocReader before doing de-serialize, such can lookup the - // remapped sourceLocs - reader.getExtraObjects().set(sourceLocReader); - - // TODO(JS): - // If modules can have more complicated relationships (like a two modules - // can refer to symbols from each other), then we can make this work by 1) - // deserialize *without* the external symbols being set up 2) calculate the - // symbols 3) deserialize the other module (in the same way) 4) run - // deserializeObjects *again* on each module This is less efficient than it - // might be (because deserialize phase is done twice) so if this is - // necessary may want a mechanism that *just* does reference lookups. - // - // For now if we assume a module can only access symbols from another - // module, and not the reverse. So we just need to deserialize and we are - // done - SLANG_RETURN_ON_FAIL(reader.deserializeObjects()); - - // Get the root node. It's at index 1 (0 is the null value). - astRootNode = reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>(); - - // Go through all AST nodes: - // 1) Add the extensions to the module mapTypeToCandidateExtensions cache - // 2) We need to fix the callback pointers for parsing - // 3) Register all `Val`s to the ASTBuilder's deduplication map. - - { - ModuleDecl* moduleDecl = as<ModuleDecl>(astRootNode); - - // Maps from keyword name name to index in (syntaxParseInfos) - // Will be filled in lazily if needed (for SyntaxDecl setup) - Dictionary<Name*, Index> syntaxKeywordDict; - - OrderedDictionary<Val*, List<Val**>> valUses; - - // Get the parse infos - const auto syntaxParseInfos = getSyntaxParseInfos(); - SLANG_ASSERT(syntaxParseInfos.getCount()); - - for (auto& obj : reader.getObjects()) - { - - if (obj.m_kind == SerialTypeKind::NodeBase) - { - NodeBase* nodeBase = (NodeBase*)obj.m_ptr; - SLANG_ASSERT(nodeBase); - - if (ExtensionDecl* extensionDecl = - dynamicCast<ExtensionDecl>(nodeBase)) - { - if (auto targetDeclRefType = - as<DeclRefType>(extensionDecl->targetType)) - { - ShortList<AggTypeDecl*> baseDecls; - getExtensionTargetDeclList( - astBuilder, - targetDeclRefType, - extensionDecl, - baseDecls); - for (auto baseDecl : baseDecls) - { - _getCandidateExtensionList( - baseDecl, - moduleDecl->mapTypeToCandidateExtensions) - .add(extensionDecl); - } - } - } - else if ( - SyntaxDecl* syntaxDecl = dynamicCast<SyntaxDecl>(nodeBase)) - { - // Set up the dictionary lazily - if (syntaxKeywordDict.getCount() == 0) - { - NamePool* namePool = options.session->getNamePool(); - for (Index i = 0; i < syntaxParseInfos.getCount(); ++i) - { - const auto& entry = syntaxParseInfos[i]; - syntaxKeywordDict.add( - namePool->getName(entry.keywordName), - i); - } - // Must have something in it at this point - SLANG_ASSERT(syntaxKeywordDict.getCount()); - } - - // Look up the index - Index* entryIndexPtr = - syntaxKeywordDict.tryGetValue(syntaxDecl->getName()); - if (entryIndexPtr) - { - // Set up SyntaxDecl based on the ParseSyntaxIndo - auto& info = syntaxParseInfos[*entryIndexPtr]; - syntaxDecl->parseCallback = *info.callback; - syntaxDecl->parseUserData = - const_cast<ReflectClassInfo*>(info.classInfo); - } - else - { - // If we don't find a setup entry, we use - // `parseSimpleSyntax`, and set the parseUserData to the - // ReflectClassInfo (as parseSimpleSyntax needs this) - syntaxDecl->parseCallback = &parseSimpleSyntax; - SLANG_ASSERT(syntaxDecl->syntaxClass.classInfo); - syntaxDecl->parseUserData = - const_cast<ReflectClassInfo*>( - syntaxDecl->syntaxClass.classInfo); - } - } - else if (Val* val = dynamicCast<Val>(nodeBase)) - { - val->_setUnique(); - } - } - } - } - } - } - - // Onto next chunk - chunk = chunk->m_next; - } - - if (astBuilder || irModule) - { - module.astBuilder = astBuilder; - module.astRootNode = astRootNode; - module.irModule = irModule; - - out.modules.add(module); - } - - // If no progress, step to next chunk - chunk = (chunk == startChunk) ? chunk->m_next : chunk; - } - } + return nullptr; +} - // Add all the entry points - { - List<RiffContainer::DataChunk*> entryPointChunks; - containerChunk->findContained(SerialBinary::kEntryPointFourCc, entryPointChunks); +SlangResult readSourceLocationsFromDebugChunk( + RiffContainer::ListChunk* debugChunk, + SourceManager* sourceManager, + RefPtr<SerialSourceLocReader>& outReader) +{ + if (!debugChunk) + return SLANG_FAIL; - for (auto entryPointChunk : entryPointChunks) - { - auto reader = entryPointChunk->asReadHelper(); + // Source location serialization uses the old approach where + // there is an intermediate in-memory data structure that the + // raw data from the RIFF gets deserialized into, before that + // intermediate representation gets transformed into something + // more directly usable. + // + // Thus we start with a first step where we simply read the data + // from the RIFF into the intermediate structure. + // + SerialSourceLocData intermediateData; + SLANG_RETURN_ON_FAIL(intermediateData.readContainer(debugChunk)); - SerialContainerBinary::EntryPoint srcEntryPoint; - SLANG_RETURN_ON_FAIL(reader.read(srcEntryPoint)); + // After reading the data into the intermediate representation, + // we turn it into a `SerialSourceLocReader`, which vends source + // location information to other deserialization tasks (both IR + // and AST deserialization). + // + auto reader = RefPtr(new SerialSourceLocReader()); + SLANG_RETURN_ON_FAIL(reader->read(&intermediateData, sourceManager)); - SerialContainerData::EntryPoint dstEntryPoint; + outReader = reader; + return SLANG_OK; +} - dstEntryPoint.name = options.namePool->getName( - containerStringPool.getSlice(StringSlicePool::Handle(srcEntryPoint.name))); - dstEntryPoint.profile.raw = srcEntryPoint.profile; - dstEntryPoint.mangledName = - containerStringPool.getSlice(StringSlicePool::Handle(srcEntryPoint.mangledName)); +SlangResult decodeModuleIR( + RefPtr<IRModule>& outIRModule, + RiffContainer::Chunk* chunk, + Session* session, + SerialSourceLocReader* sourceLocReader) +{ + // IR serialization still uses the older approach, where + // data gets deserialized from the RIFF into an intermediate + // data structure (`IRSerialData`), and then the actual + // in-memory structures are created based on the intermediate. + // + // Thus we start by running the `IRSerialReader::readContainer` + // logic to get the `IRSerialData` representation. + // + // TODO(tfoley): This should all get streamlined so that we + // are deserializing IR nodes directly from the format written + // into the RIFF. + // + auto listChunk = as<RiffContainer::ListChunk>(chunk); + if (!listChunk) + return SLANG_FAIL; + IRSerialData serialData; + SLANG_RETURN_ON_FAIL(IRSerialReader::readContainer(listChunk, &serialData)); - out.entryPoints.add(dstEntryPoint); - } - } + // Next we read the actual IR representation out from the + // `serialData`. This is the step that may pull source-location + // information from the provided `sourceLocReader`. + // + IRSerialReader reader; + SLANG_RETURN_ON_FAIL(reader.read(serialData, session, sourceLocReader, outIRModule)); return SLANG_OK; } diff --git a/source/slang/slang-serialize-container.h b/source/slang/slang-serialize-container.h index 8ddc5072a..4c1053a6d 100644 --- a/source/slang/slang-serialize-container.h +++ b/source/slang/slang-serialize-container.h @@ -12,72 +12,6 @@ namespace Slang class EndToEndCompileRequest; -/* The binary representation actually held in riff/file format*/ -struct SerialContainerBinary -{ - struct Target - { - uint32_t target; - uint32_t flags; - uint32_t profile; - uint32_t floatingPointMode; - }; - - struct EntryPoint - { - uint32_t name; - uint32_t profile; - uint32_t mangledName; - }; -}; - -struct SerialContainerDataModule -{ - RefPtr<IRModule> irModule; ///< The IR for the module - RefPtr<ASTBuilder> astBuilder; ///< The astBuilder that owns the astRootNode - NodeBase* astRootNode = nullptr; ///< The module decl - List<String> dependentFiles; - SHA1::Digest digest; -}; - -/* Struct that holds all the data that can be held in a 'container' */ -struct SerialContainerData -{ - struct Target - { - CodeGenTarget codeGenTarget = CodeGenTarget::Unknown; - SlangTargetFlags flags = kDefaultTargetFlags; - Profile profile; - FloatingPointMode floatingPointMode = FloatingPointMode::Default; - }; - - struct TargetComponent - { - // IR module for a specific compilation target - Target target; - RefPtr<IRModule> irModule; - }; - - typedef SerialContainerDataModule Module; - - struct EntryPoint - { - Name* name = nullptr; - Profile profile; - String mangledName; - }; - - void clear() - { - entryPoints.clear(); - modules.clear(); - targetComponents.clear(); - } - - List<Module> modules; - List<TargetComponent> targetComponents; - List<EntryPoint> entryPoints; -}; struct SerialContainerUtil { @@ -104,37 +38,6 @@ struct SerialContainerUtil String modulePath; }; - /// Add module to outData - static SlangResult addModuleToData( - Module* module, - const WriteOptions& options, - SerialContainerData& outData); - - /// Get the serializable contents of the request as data - static SlangResult addEndToEndRequestToData( - EndToEndCompileRequest* request, - const WriteOptions& options, - SerialContainerData& outData); - - /// Convert front end request into something serializable - static SlangResult addFrontEndRequestToData( - FrontEndCompileRequest* request, - const WriteOptions& options, - SerialContainerData& outData); - - /// Write the data into the container - static SlangResult write( - const SerialContainerData& data, - const WriteOptions& options, - RiffContainer* container); - - /// Read the container into outData - static SlangResult read( - RiffContainer* container, - const ReadOptions& options, - const LoadedModuleDictionary* additionalLoadedModules, - SerialContainerData& outData); - /// Verify IR serialization static SlangResult verifyIRSerialize( IRModule* module, @@ -153,6 +56,192 @@ struct SerialContainerUtil static SlangResult write(Module* module, const WriteOptions& options, Stream* stream); }; + +struct ChunkRef +{ +public: + ChunkRef(RiffContainer::Chunk* chunk) + : _chunk(chunk) + { + } + + RiffContainer::Chunk* ptr() const { return _chunk; } + +protected: + RiffContainer::Chunk* _chunk = nullptr; +}; + +struct DataChunkRef : ChunkRef +{ +public: + DataChunkRef(RiffContainer::DataChunk* chunk) + : ChunkRef(chunk) + { + } + + RiffContainer::DataChunk* ptr() const { return static_cast<RiffContainer::DataChunk*>(_chunk); } + + operator RiffContainer::DataChunk*() const { return ptr(); } +}; + + +template<typename T> +struct ChunkRefList +{ +public: + struct Iterator + { + public: + Iterator(RiffContainer::Chunk* chunk) + : _chunk(chunk) + { + } + + bool operator!=(Iterator const& other) const { return _chunk != other._chunk; } + + void operator++() { _chunk = _chunk->m_next; } + + T operator*() + { + ChunkRef ref(_chunk); + return *(T*)&ref; + } + + private: + RiffContainer::Chunk* _chunk = nullptr; + }; + + Iterator begin() const { return _list ? _list->getFirstContainedChunk() : nullptr; } + Iterator end() const { return Iterator(nullptr); } + + Count getCount() + { + Count count = 0; + for (auto i : *this) + count++; + return count; + } + + T getFirst() { return *begin(); } + + ChunkRefList() {} + + ChunkRefList(RiffContainer::ListChunk* list) + : _list(list) + { + } + + operator RiffContainer::ListChunk*() const { return _list; } + +private: + RiffContainer::ListChunk* _list = nullptr; +}; + +struct ListChunkRef : ChunkRef +{ +public: + ListChunkRef(RiffContainer::Chunk* chunk) + : ChunkRef(chunk) + { + } + + RiffContainer::ListChunk* ptr() const { return static_cast<RiffContainer::ListChunk*>(_chunk); } + + operator RiffContainer::ListChunk*() const { return ptr(); } +}; + + +struct StringChunkRef : DataChunkRef +{ +public: + String getValue(); +}; + +struct IRModuleChunkRef : ListChunkRef +{ +public: + explicit IRModuleChunkRef(RiffContainer::ListChunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct ASTModuleChunkRef : ListChunkRef +{ +public: + explicit ASTModuleChunkRef(RiffContainer::ListChunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct ModuleChunkRef : ListChunkRef +{ +public: + static ModuleChunkRef find(RiffContainer* container); + + String getName(); + + IRModuleChunkRef findIR(); + ASTModuleChunkRef findAST(); + + SHA1::Digest getDigest(); + + ChunkRefList<StringChunkRef> getFileDependencies(); + +protected: + ModuleChunkRef(RiffContainer::Chunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct EntryPointChunkRef : ListChunkRef +{ +public: + String getMangledName() const; + String getName() const; + Profile getProfile() const; + +protected: + EntryPointChunkRef(RiffContainer::Chunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +struct ContainerChunkRef : ListChunkRef +{ +public: + static ContainerChunkRef find(RiffContainer* container); + + ChunkRefList<ModuleChunkRef> getModules(); + + ChunkRefList<EntryPointChunkRef> getEntryPoints(); + +protected: + ContainerChunkRef(RiffContainer::Chunk* chunk) + : ListChunkRef(chunk) + { + } +}; + +/// Attempt to find a debug-info chunk relative to +/// the given `startingChunk`. +/// +RiffContainer::ListChunk* findDebugChunk(RiffContainer::Chunk* startingChunk); + +SlangResult readSourceLocationsFromDebugChunk( + RiffContainer::ListChunk* debugChunk, + SourceManager* sourceManager, + RefPtr<SerialSourceLocReader>& outReader); + +SlangResult decodeModuleIR( + RefPtr<IRModule>& outIRModule, + RiffContainer::Chunk* chunk, + Session* session, + SerialSourceLocReader* sourceLocReader); + } // namespace Slang #endif diff --git a/source/slang/slang-serialize-factory.cpp b/source/slang/slang-serialize-factory.cpp deleted file mode 100644 index 5ad1e4911..000000000 --- a/source/slang/slang-serialize-factory.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// slang-serialize-factory.cpp -#include "slang-serialize-factory.h" - -#include "../core/slang-math.h" -#include "slang-ast-builder.h" -#include "slang-ast-reflect.h" -#include "slang-ref-object-reflect.h" -#include "slang-serialize-ast.h" - -// Needed for ModuleSerialFilter -// Needed for 'findModuleForDecl' -#include "slang-legalize-types.h" -#include "slang-mangle.h" - -namespace Slang -{ - -/* !!!!!!!!!!!!!!!!!!!!!! DefaultSerialObjectFactory !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -void* DefaultSerialObjectFactory::create(SerialTypeKind typeKind, SerialSubType subType) -{ - switch (typeKind) - { - case SerialTypeKind::NodeBase: - { - return m_astBuilder->createByNodeType(ASTNodeType(subType)); - } - case SerialTypeKind::RefObject: - { - const ReflectClassInfo* info = SerialRefObjects::getClassInfo(RefObjectType(subType)); - - if (info && info->m_createFunc) - { - RefObject* obj = reinterpret_cast<RefObject*>(info->m_createFunc(nullptr)); - return _add(obj); - } - return nullptr; - } - default: - break; - } - - return nullptr; -} - -void* DefaultSerialObjectFactory::getOrCreateVal(ValNodeDesc&& desc) -{ - return m_astBuilder->_getOrCreateImpl(_Move(desc)); -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ModuleSerialFilter !!!!!!!!!!!!!!!!!!!!!!!! - -SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const RefObject* inPtr) -{ - // We don't serialize Module - if (as<Module>(inPtr)) - { - writer->setPointerIndex(inPtr, SerialIndex(0)); - return SerialIndex(0); - } - - // For now for everything else just write it - return writer->writeObject(inPtr); -} - -SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBase* inPtr) -{ - NodeBase* ptr = const_cast<NodeBase*>(inPtr); - SLANG_ASSERT(ptr); - - - if (Decl* decl = as<Decl>(ptr)) - { - ModuleDecl* moduleDecl = findModuleForDecl(decl); - if (moduleDecl && moduleDecl != m_moduleDecl) - { - ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder(); - - // It's a reference to a declaration in another module, so first get the symbol name. - // Note that we will always name an import symbol in the form of - // <module_name>!<symbol_mangled_name> for serialization. - // This is because <symbol_mangled_name> does not necessarily include the name of its - // parent module when it is qualified as `extern` or `export`. - // - String mangledName = - getText(moduleDecl->getName()) + "!" + getMangledName(astBuilder, decl); - - // Add as an import symbol - return writer->addImportSymbol(mangledName); - } - else - { - // Okay... we can just write it out then - return writer->writeObject(ptr); - } - } - // For now for everything else just write it - return writer->writeObject(ptr); -} - -/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialClassesUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -/* static */ SlangResult SerialClassesUtil::addSerialClasses(SerialClasses* serialClasses) -{ - ASTSerialUtil::addSerialClasses(serialClasses); - SerialRefObjects::addSerialClasses(serialClasses); - - // Check if it seems ok - SLANG_ASSERT(serialClasses->isOk()); - - return SLANG_OK; -} - -/* static */ SlangResult SerialClassesUtil::create(RefPtr<SerialClasses>& out) -{ - RefPtr<SerialClasses> classes(new SerialClasses); - SLANG_RETURN_ON_FAIL(addSerialClasses(classes)); - - out = classes; - return SLANG_OK; -} - -} // namespace Slang diff --git a/source/slang/slang-serialize-factory.h b/source/slang/slang-serialize-factory.h deleted file mode 100644 index ef13fff83..000000000 --- a/source/slang/slang-serialize-factory.h +++ /dev/null @@ -1,49 +0,0 @@ -// slang-serialize-factory.h -#ifndef SLANG_SERIALIZE_FACTORY_H -#define SLANG_SERIALIZE_FACTORY_H - -#include "slang-serialize.h" - -namespace Slang -{ - -// !!!!!!!!!!!!!!!!!!!!! DefaultSerialObjectFactory !!!!!!!!!!!!!!!!!!!!!!!!!!! - -class ASTBuilder; - -class DefaultSerialObjectFactory : public SerialObjectFactory -{ -public: - virtual void* create(SerialTypeKind typeKind, SerialSubType subType) SLANG_OVERRIDE; - virtual void* getOrCreateVal(ValNodeDesc&& desc) SLANG_OVERRIDE; - - DefaultSerialObjectFactory(ASTBuilder* astBuilder) - : m_astBuilder(astBuilder) - { - } - -protected: - RefObject* _add(RefObject* obj) - { - m_scope.add(obj); - return obj; - } - - // We keep RefObjects in scope - List<RefPtr<RefObject>> m_scope; - ASTBuilder* m_astBuilder; -}; - - -struct SerialClassesUtil -{ - /// Add all types to serialClasses - static SlangResult addSerialClasses(SerialClasses* serialClasses); - /// Create SerialClasses with all the types added - static SlangResult create(RefPtr<SerialClasses>& out); -}; - - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-misc-type-info.h b/source/slang/slang-serialize-misc-type-info.h deleted file mode 100644 index 121b205d5..000000000 --- a/source/slang/slang-serialize-misc-type-info.h +++ /dev/null @@ -1,224 +0,0 @@ -// slang-serialize-misc-type-info.h -#ifndef SLANG_SERIALIZE_MISC_TYPE_INFO_H -#define SLANG_SERIALIZE_MISC_TYPE_INFO_H - -#include "../compiler-core/slang-source-loc.h" -#include "slang-compiler.h" -#include "slang-serialize-type-info.h" - -namespace Slang -{ - -/* Conversion for serialization for some more misc Slang types - */ - - -// Because is sized, we don't need to convert -template<> -struct SerialTypeInfo<FeedbackType::Kind> : public SerialIdentityTypeInfo<FeedbackType::Kind> -{ -}; - -// SamplerStateFlavor - -template<> -struct SerialTypeInfo<SamplerStateFlavor> - : public SerialConvertTypeInfo<SamplerStateFlavor, uint8_t> -{ -}; - -// ImageFormat -template<> -struct SerialTypeInfo<ImageFormat> : public SerialConvertTypeInfo<ImageFormat, uint8_t> -{ -}; - -// Stage -template<> -struct SerialTypeInfo<Stage> : public SerialConvertTypeInfo<Stage, uint8_t> -{ -}; - -// TokenType -template<> -struct SerialTypeInfo<TokenType> : public SerialConvertTypeInfo<TokenType, uint8_t> -{ -}; - -// BaseType -template<> -struct SerialTypeInfo<BaseType> : public SerialConvertTypeInfo<BaseType, uint8_t> -{ -}; - -// SemanticVersion -template<> -struct SerialTypeInfo<SemanticVersion> : public SerialIdentityTypeInfo<SemanticVersion> -{ -}; - -// SourceLoc - -// Make the type exposed, so we can look for it if we want to remap. -template<> -struct SerialTypeInfo<SourceLoc> -{ - typedef SourceLoc NativeType; - typedef SerialSourceLoc SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialSourceLoc) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SerialSourceLocWriter* sourceLocWriter = - writer->getExtraObjects().get<SerialSourceLocWriter>(); - *(SerialType*)outSerial = sourceLocWriter - ? sourceLocWriter->addSourceLoc(*(const NativeType*)inNative) - : SerialType(0); - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - SerialSourceLocReader* sourceLocReader = - reader->getExtraObjects().get<SerialSourceLocReader>(); - *(NativeType*)outNative = sourceLocReader - ? sourceLocReader->getSourceLoc(*(const SerialType*)inSerial) - : NativeType::fromRaw(0); - } -}; - -// Token -template<> -struct SerialTypeInfo<Token> -{ - typedef Token NativeType; - struct SerialType - { - SerialTypeInfo<BaseType>::SerialType type; - SerialTypeInfo<SourceLoc>::SerialType loc; - SerialIndex name; - }; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - SerialTypeInfo<TokenType>::toSerial(writer, &src.type, &dst.type); - SerialTypeInfo<SourceLoc>::toSerial(writer, &src.loc, &dst.loc); - - if (src.flags & TokenFlag::Name) - { - dst.name = writer->addName(src.getName()); - } - else - { - dst.name = writer->addString(src.getContent()); - } - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - dst.flags = 0; - dst.charsNameUnion.chars = nullptr; - - SerialTypeInfo<TokenType>::toNative(reader, &src.type, &dst.type); - SerialTypeInfo<SourceLoc>::toNative(reader, &src.loc, &dst.loc); - - // At the other end all token content will appear as Names. - if (src.name != SerialIndex(0)) - { - dst.charsNameUnion.name = reader->getName(src.name); - dst.flags |= TokenFlag::Name; - } - } -}; - -// NameLoc -template<> -struct SerialTypeInfo<NameLoc> -{ - typedef NameLoc NativeType; - struct SerialType - { - SerialTypeInfo<SourceLoc>::SerialType loc; - SerialIndex name; - }; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst.name = writer->addName(src.name); - SerialTypeInfo<SourceLoc>::toSerial(writer, &src.loc, &dst.loc); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - dst.name = reader->getName(src.name); - SerialTypeInfo<SourceLoc>::toNative(reader, &src.loc, &dst.loc); - } -}; - -// DiagnosticInfo -template<> -struct SerialTypeInfo<const DiagnosticInfo*> -{ - typedef const DiagnosticInfo* NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - dst = src ? writer->addString(UnownedStringSlice(src->name)) : SerialIndex(0); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - if (src == SerialIndex(0)) - { - dst = nullptr; - } - else - { - dst = findDiagnosticByName(reader->getStringSlice(src)); - } - } -}; - -// DeclAssociation -template<> -struct SerialTypeInfo<DeclAssociation> : SerialIdentityTypeInfo<DeclAssociation> -{ -}; -template<> -struct SerialTypeInfo<DeclAssociationKind> - : public SerialConvertTypeInfo<DeclAssociationKind, uint8_t> -{ -}; - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-reflection.cpp b/source/slang/slang-serialize-reflection.cpp deleted file mode 100644 index 60ab31e17..000000000 --- a/source/slang/slang-serialize-reflection.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// slang-serialize-reflection.cpp -#include "slang-serialize-reflection.h" - -#include "slang-serialize.h" - -namespace Slang -{ - -bool ReflectClassInfo::isSubClassOfSlow(const ThisType& super) const -{ - ReflectClassInfo const* info = this; - while (info) - { - if (info == &super) - return true; - info = info->m_superClass; - } - return false; -} - -#if 0 - -// #if'd out because produces a warning->error if not used. -static bool _checkSubClassRange(ReflectClassInfo*const* typeInfos, Index typeInfosCount) -{ - for (Index i = 0; i < typeInfosCount; ++i) - { - for (Index j = 0; j < typeInfosCount; ++j) - { - auto a = typeInfos[i]; - auto b = typeInfos[j]; - if (a->isSubClassOf(*b) != a->isSubClassOfSlow(*b)) - { - return false; - } - } - } - - return true; -} - -#endif - -static uint32_t _calcRangeRec( - ReflectClassInfo* classInfo, - const Dictionary<const ReflectClassInfo*, List<ReflectClassInfo*>>& childMap, - uint32_t index) -{ - classInfo->m_classId = index++; - // Do the calc range for all the children - auto list = childMap.tryGetValue(classInfo); - - if (list) - { - for (auto child : *list) - { - index = _calcRangeRec(child, childMap, index); - } - } - - classInfo->m_lastClassId = index; - return index; -} - -static ReflectClassInfo* _calcRoot(ReflectClassInfo* classInfo) -{ - while (classInfo->m_superClass) - { - classInfo = const_cast<ReflectClassInfo*>(classInfo->m_superClass); - } - return classInfo; -} - - -/* static */ void ReflectClassInfo::calcClassIdHierachy( - uint32_t baseIndex, - ReflectClassInfo* const* typeInfos, - Index typeInfosCount) -{ - SLANG_ASSERT(typeInfosCount > 0); - - // TODO(JS): - // Note that the calculating of the ranges could be done more efficiently by adding to an array - // of struct { super, class }, sorting, by super classs and using a dictionary to map from class - // it's first in list of super class use. This works for now though. - - // The root cannot be shared with another hierarchy - as doing so will mean that the range will - // be incorrect (it would need to span both trees) - ReflectClassInfo* root = _calcRoot(typeInfos[0]); - - // We want to produce a map from a node that holds all of it's children - Dictionary<const ThisType*, List<ThisType*>> childMap; - - const List<ThisType*> emptyList; - { - for (Index i = 0; i < typeInfosCount; ++i) - { - auto typeInfo = typeInfos[i]; - if (typeInfo->m_superClass) - { - // Add to that item - List<ThisType*>* list = - childMap.tryGetValueOrAdd(typeInfo->m_superClass, emptyList); - if (!list) - { - list = childMap.tryGetValue(typeInfo->m_superClass); - } - SLANG_ASSERT(list); - list->add(typeInfo); - } - - // The root should be the same for all types - SLANG_ASSERT(_calcRoot(typeInfo) == root); - } - } - - // We want to recursively work out a range - _calcRangeRec(root, childMap, baseIndex); - - // SLANG_ASSERT(_checkSubClassRange(typeInfos, typeInfoCount)); -} - -} // namespace Slang diff --git a/source/slang/slang-serialize-reflection.h b/source/slang/slang-serialize-reflection.h deleted file mode 100644 index 63ea1e7b6..000000000 --- a/source/slang/slang-serialize-reflection.h +++ /dev/null @@ -1,86 +0,0 @@ -// slang-serialize-reflection.h -#ifndef SLANG_SERIALIZE_REFLECTION_H -#define SLANG_SERIALIZE_REFLECTION_H - -#include "../compiler-core/slang-name.h" - -namespace Slang -{ - -struct ReflectClassInfo -{ - typedef ReflectClassInfo ThisType; - - typedef void* (*CreateFunc)(void* context); - typedef void (*DestructorFunc)(void* ptr); - - /// A constant time implementation of isSubClassOf - SLANG_FORCE_INLINE bool isSubClassOf(const ThisType& super) const - { - // We include super.m_classId, because it's a subclass of itself. - return m_classId >= super.m_classId && m_classId <= super.m_lastClassId; - } - - SLANG_FORCE_INLINE static bool isValidTypeId(uint32_t typeId) { return int32_t(typeId) >= 0; } - - // True if typeId derives from this type - SLANG_FORCE_INLINE bool isDerivedFrom(uint32_t typeId) const - { - SLANG_ASSERT(isValidTypeId(typeId) && isValidTypeId(m_classId)); - return typeId >= m_classId && typeId <= m_lastClassId; - } - - SLANG_FORCE_INLINE static bool isSubClassOf(uint32_t type, const ThisType& super) - { - SLANG_ASSERT(isValidTypeId(type) && isValidTypeId(super.m_classId)); - // We include super.m_classId, because it's a subclass of itself. - return type >= super.m_classId && type <= super.m_lastClassId; - } - - /// Will produce the same result as isSubClassOf (if enumerated), but more slowly by traversing - /// the m_superClass Works without initRange being called. - bool isSubClassOfSlow(const ThisType& super) const; - - /// Calculate infos m_classId for all the infos specified such that they are honor the - /// inheritance relationship such that a m_classId of a child is > m_classId && <= m_lastClassId - static void calcClassIdHierachy( - uint32_t baseIndex, - ReflectClassInfo* const* infos, - Index infosCount); - - uint32_t m_classId; ///< Not necessarily set. - uint32_t m_lastClassId; - - const ReflectClassInfo* - m_superClass; ///< The super class of this class, or nullptr if has no super class. - const char* m_name; ///< Textual class name, for debugging - CreateFunc m_createFunc; ///< Callback to use when creating instances (using an ASTBuilder for - ///< backing memory) - DestructorFunc m_destructorFunc; ///< The destructor for this type. Being just destructor, does - ///< not free backing memory for type. - - uint32_t m_sizeInBytes; ///< Total size of the type - uint8_t m_alignment; ///< The required alignment of the type -}; - -// Does nothing - just a mark to the C++ extractor -#define SLANG_REFLECTED -#define SLANG_UNREFLECTED - -#define SLANG_PRE_DECLARE(SUFFIX, DEF) - -#define SLANG_TYPE_SET(SUFFIX, ...) - -// Use these macros to help define Super, and making the base definition NOT have a Super -// definition. For example something like... - -#define SLANG_CLASS_REFLECT_SUPER_BASE(SUPER) -#define SLANG_CLASS_REFLECT_SUPER_INNER(SUPER) typedef SUPER Super; -#define SLANG_CLASS_REFLECT_SUPER_LEAF(SUPER) typedef SUPER Super; - -// Mark a value class -#define SLANG_VALUE_CLASS(x) - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-source-loc.h b/source/slang/slang-serialize-source-loc.h index 10d084fb6..24e1813a4 100644 --- a/source/slang/slang-serialize-source-loc.h +++ b/source/slang/slang-serialize-source-loc.h @@ -147,8 +147,6 @@ public: class SerialSourceLocReader : public RefObject { public: - static const SerialExtraType kExtraType = SerialExtraType::SourceLocReader; - Index findViewIndex(SerialSourceLocData::SourceLoc loc); SourceLoc getSourceLoc(SerialSourceLocData::SourceLoc loc); @@ -186,8 +184,6 @@ protected: class SerialSourceLocWriter : public RefObject { public: - static const SerialExtraType kExtraType = SerialExtraType::SourceLocWriter; - class Source : public RefObject { public: diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h deleted file mode 100644 index 20662c319..000000000 --- a/source/slang/slang-serialize-type-info.h +++ /dev/null @@ -1,491 +0,0 @@ -// slang-serialize-type-info.h -#ifndef SLANG_SERIALIZE_TYPE_INFO_H -#define SLANG_SERIALIZE_TYPE_INFO_H - -#include "slang-serialize.h" - -namespace Slang -{ - -/* For the serialization system to work we need to defined how native types are represented in the -serialized format. This information is defined by specializing SerialTypeInfo with the native type -to be converted This header provides conversion for common Slang types. -*/ - - -// We need to have a way to map between the two. -// If no mapping is needed, (just a copy), then we don't bother with the functions -template<typename T> -struct SerialBasicTypeInfo -{ - typedef T NativeType; - typedef T SerialType; - - // We want the alignment to be the same as the size of the type for basic types - // NOTE! Might be different from SLANG_ALIGN_OF(SerialType) - enum - { - SerialAlignment = sizeof(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - *(T*)serial = *(const T*)native; - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - *(T*)native = *(const T*)serial; - } - - static const SerialType* getType() - { - static const SerialType type = - {sizeof(SerialType), uint8_t(SerialAlignment), &toSerial, &toNative}; - return &type; - } -}; - -template<typename NATIVE_T, typename SERIAL_T> -struct SerialConvertTypeInfo -{ - typedef NATIVE_T NativeType; - typedef SERIAL_T SerialType; - - enum - { - SerialAlignment = SerialBasicTypeInfo<SERIAL_T>::SerialAlignment - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - *(SERIAL_T*)serial = SERIAL_T(*(const NATIVE_T*)native); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - *(NATIVE_T*)native = NATIVE_T(*(const SERIAL_T*)serial); - } -}; - -template<typename T> -struct SerialIdentityTypeInfo -{ - typedef T NativeType; - typedef T SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - SLANG_UNUSED(writer); - *(T*)serial = *(const T*)native; - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - SLANG_UNUSED(reader); - *(T*)native = *(const T*)serial; - } -}; - -// Don't need to convert the index type - -template<> -struct SerialTypeInfo<SerialIndex> : public SerialIdentityTypeInfo<SerialIndex> -{ -}; - -// Implement for Basic Types - -template<> -struct SerialTypeInfo<uint8_t> : public SerialBasicTypeInfo<uint8_t> -{ -}; -template<> -struct SerialTypeInfo<uint16_t> : public SerialBasicTypeInfo<uint16_t> -{ -}; -template<> -struct SerialTypeInfo<uint32_t> : public SerialBasicTypeInfo<uint32_t> -{ -}; -template<> -struct SerialTypeInfo<uint64_t> : public SerialBasicTypeInfo<uint64_t> -{ -}; - -template<> -struct SerialTypeInfo<int8_t> : public SerialBasicTypeInfo<int8_t> -{ -}; -template<> -struct SerialTypeInfo<int16_t> : public SerialBasicTypeInfo<int16_t> -{ -}; -template<> -struct SerialTypeInfo<int32_t> : public SerialBasicTypeInfo<int32_t> -{ -}; -template<> -struct SerialTypeInfo<int64_t> : public SerialBasicTypeInfo<int64_t> -{ -}; - -template<> -struct SerialTypeInfo<float> : public SerialBasicTypeInfo<float> -{ -}; -template<> -struct SerialTypeInfo<double> : public SerialBasicTypeInfo<double> -{ -}; - -// Fixed arrays - -template<typename T, size_t N> -struct SerialTypeInfo<T[N]> -{ - typedef SerialTypeInfo<T> ElementASTSerialType; - typedef typename ElementASTSerialType::SerialType SerialElementType; - - typedef T NativeType[N]; - typedef SerialElementType SerialType[N]; - - enum - { - SerialAlignment = SerialTypeInfo<T>::SerialAlignment - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SerialElementType* serial = (SerialElementType*)outSerial; - - if (writer->getFlags() & SerialWriter::Flag::ZeroInitialize) - { - ::memset(outSerial, 0, sizeof(SerialElementType) * N); - } - - const T* native = (const T*)inNative; - for (Index i = 0; i < Index(N); ++i) - { - ElementASTSerialType::toSerial(writer, native + i, serial + i); - } - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - const SerialElementType* serial = (const SerialElementType*)inSerial; - T* native = (T*)outNative; - for (Index i = 0; i < Index(N); ++i) - { - ElementASTSerialType::toNative(reader, serial + i, native + i); - } - } -}; - -// Special case bool - as we can't rely on size alignment -template<> -struct SerialTypeInfo<bool> -{ - typedef bool NativeType; - typedef uint8_t SerialType; - - enum - { - SerialAlignment = sizeof(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - SLANG_UNUSED(writer); - *(SerialType*)outSerial = *(const NativeType*)inNative ? 1 : 0; - } - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - SLANG_UNUSED(reader); - *(NativeType*)outNative = (*(const SerialType*)inSerial) != 0; - } -}; - -// Specialization for all enum types -template<typename T> -struct SerialTypeInfo<T, typename std::enable_if<std::is_enum<T>::value>::type> - : public SerialIdentityTypeInfo<T> -{ -}; - -class Val; - -// Pointer - -template<typename T, typename /*sfinaeType*/ = void> -struct PtrSerialTypeInfo -{ - typedef T* NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) - { - auto ptrToWrite = *(T**)inNative; - static_assert(!IsBaseOf<Val, T>::Value); - *(SerialIndex*)outSerial = writer->addPointer(ptrToWrite); - } - - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - *(T**)outNative = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>(); - } -}; - -template<typename T> -struct SerialTypeInfo<T*> : public PtrSerialTypeInfo<T> -{ -}; - -// RefPtr (pretty much the same as T* - except for native rep) -template<typename T> -struct SerialTypeInfo<RefPtr<T>> -{ - typedef RefPtr<T> NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - *(SerialType*)serial = writer->addPointer(src); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - *(NativeType*)native = reader->getPointer(*(const SerialType*)serial).dynamicCast<T>(); - } -}; - -// Special case Name -template<> -struct SerialTypeInfo<Name*> : public SerialTypeInfo<RefObject*> -{ - // Special case - typedef Name* NativeType; - static void toNative(SerialReader* reader, const void* inSerial, void* outNative) - { - *(Name**)outNative = reader->getName(*(const SerialType*)inSerial); - } -}; - -template<> -struct SerialTypeInfo<const Name*> : public SerialTypeInfo<Name*> -{ -}; - -// List -template<typename T, typename ALLOCATOR> -struct SerialTypeInfo<List<T, ALLOCATOR>> -{ - typedef List<T, ALLOCATOR> NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst = writer->addArray(src.getBuffer(), src.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - reader->getArray(src, dst); - } -}; - -// ShortList -template<typename T, int n, typename ALLOCATOR> -struct SerialTypeInfo<ShortList<T, n, ALLOCATOR>> -{ - typedef ShortList<T, n, ALLOCATOR> NativeType; - typedef SerialIndex SerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - dst = writer->addArray(src.getArrayView().getBuffer(), src.getCount()); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& dst = *(NativeType*)native; - auto& src = *(const SerialType*)serial; - - reader->getArray(src, dst); - } -}; - -// String -template<> -struct SerialTypeInfo<String> -{ - typedef String NativeType; - typedef SerialIndex SerialType; - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - *(SerialType*)serial = writer->addString(src); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - dst = reader->getString(src); - } -}; - -// Dictionary -// Note: We leave out SerialTypeInfo specialization for Dictionary, because -// it does not have determinstic ordering. - -// OrderedDictionary -template<typename KEY, typename VALUE> -struct SerialTypeInfo<OrderedDictionary<KEY, VALUE>> -{ - typedef OrderedDictionary<KEY, VALUE> NativeType; - struct SerialType - { - SerialIndex keys; ///< Index an array - SerialIndex values; ///< Index an array - }; - - typedef typename SerialTypeInfo<KEY>::SerialType KeySerialType; - typedef typename SerialTypeInfo<VALUE>::SerialType ValueSerialType; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialIndex) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<KeySerialType> keys; - List<ValueSerialType> values; - - Index count = Index(src.getCount()); - keys.setCount(count); - values.setCount(count); - - if (writer->getFlags() & SerialWriter::Flag::ZeroInitialize) - { - ::memset(keys.getBuffer(), 0, count * sizeof(KeySerialType)); - ::memset(values.getBuffer(), 0, count * sizeof(ValueSerialType)); - } - - Index i = 0; - for (const auto& pair : src) - { - SerialTypeInfo<KEY>::toSerial(writer, &pair.key, &keys[i]); - SerialTypeInfo<VALUE>::toSerial(writer, &pair.value, &values[i]); - i++; - } - - // When we add the array it is already converted to a serializable type, so add as - // SerialArray - dst.keys = writer->addSerialArray<KEY>(keys.getBuffer(), count); - dst.values = writer->addSerialArray<VALUE>(values.getBuffer(), count); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - // Clear it - dst = NativeType(); - - List<KEY> keys; - List<VALUE> values; - - reader->getArray(src.keys, keys); - reader->getArray(src.values, values); - - SLANG_ASSERT(keys.getCount() == values.getCount()); - - const Index count = keys.getCount(); - for (Index i = 0; i < count; ++i) - { - dst.add(keys[i], values[i]); - } - } -}; - -// KeyValuePair -template<typename KEY, typename VALUE> -struct SerialTypeInfo<KeyValuePair<KEY, VALUE>> -{ - typedef KeyValuePair<KEY, VALUE> NativeType; - - typedef typename SerialTypeInfo<KEY>::SerialType KeySerialType; - typedef typename SerialTypeInfo<VALUE>::SerialType ValueSerialType; - - struct SerialType - { - KeySerialType key; - ValueSerialType value; - }; - - enum - { - SerialAlignment = SLANG_ALIGN_OF(SerialType) - }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - SerialTypeInfo<KEY>::toSerial(writer, &src.key, &dst.key); - SerialTypeInfo<VALUE>::toSerial(writer, &src.value, &dst.value); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - SerialTypeInfo<KEY>::toNative(reader, &src.key, &dst.key); - SerialTypeInfo<VALUE>::toNative(reader, &src.value, &dst.value); - } -}; - - -} // namespace Slang - -#endif diff --git a/source/slang/slang-serialize-types.h b/source/slang/slang-serialize-types.h index 217c14b44..cd2b4c99c 100644 --- a/source/slang/slang-serialize-types.h +++ b/source/slang/slang-serialize-types.h @@ -11,14 +11,7 @@ namespace Slang { - -// An enumeration of types that can be set -enum class SerialExtraType -{ - SourceLocReader, - SourceLocWriter, - CountOf, -}; +class Module; // Options for IR/AST/Debug serialization @@ -35,7 +28,6 @@ struct SerialOptionFlag ASTModule = 0x04, ///< If set will output AST modules - typically required, but potentially ///< not desired (for example with obsfucation) IRModule = 0x08, ///< If set will output IR modules - typically required - ASTFunctionBody = 0x10, ///< If set will serialize AST function bodies. }; }; typedef SerialOptionFlag::Type SerialOptionFlags; @@ -123,6 +115,20 @@ struct SerialListUtil } }; +template<typename T> +struct PropertyKeys +{ +}; + +template<> +struct PropertyKeys<Module> +{ + static const FourCC Digest = SLANG_FOUR_CC('S', 'H', 'A', '1'); + static const FourCC ASTModule = SLANG_FOUR_CC('a', 's', 't', ' '); + static const FourCC IRModule = SLANG_FOUR_CC('i', 'r', ' ', ' '); + static const FourCC FileDependencies = SLANG_FOUR_CC('f', 'd', 'e', 'p'); +}; + // For types/FourCC that work for serializing in general (not just IR). struct SerialBinary { @@ -140,8 +146,44 @@ struct SerialBinary /// An entry point static const FourCC kEntryPointFourCc = SLANG_FOUR_CC('E', 'P', 'n', 't'); - // Module header - static const FourCC kModuleHeaderFourCc = SLANG_FOUR_CC('S', 'm', 'h', 'd'); + static const FourCC kEntryPointListFourCc = SLANG_FOUR_CC('e', 'p', 't', 's'); + + // Module + static const FourCC kModuleFourCC = SLANG_FOUR_CC('s', 'm', 'o', 'd'); + + // The following are "generic" codes, suitable for + // use when serializing content using JSON-like structure. + // + static const FourCC kObjectFourCC = SLANG_FOUR_CC('o', 'b', 'j', ' '); + static const FourCC kPairFourCC = SLANG_FOUR_CC('p', 'a', 'i', 'r'); + static const FourCC kArrayFourCC = SLANG_FOUR_CC('a', 'r', 'r', 'y'); + static const FourCC kDictionaryFourCC = SLANG_FOUR_CC('d', 'i', 'c', 't'); + static const FourCC kNullFourCC = SLANG_FOUR_CC('n', 'u', 'l', 'l'); + static const FourCC kStringFourCC = SLANG_FOUR_CC('s', 't', 'r', ' '); + static const FourCC kTrueFourCC = SLANG_FOUR_CC('t', 'r', 'u', 'e'); + static const FourCC kFalseFourCC = SLANG_FOUR_CC('f', 'a', 'l', 's'); + static const FourCC kInt32FourCC = SLANG_FOUR_CC('i', '3', '2', ' '); + static const FourCC kUInt32FourCC = SLANG_FOUR_CC('u', '3', '2', ' '); + static const FourCC kFloat32FourCC = SLANG_FOUR_CC('f', '3', '2', ' '); + static const FourCC kInt64FourCC = SLANG_FOUR_CC('i', '6', '4', ' '); + static const FourCC kUInt64FourCC = SLANG_FOUR_CC('u', '6', '4', ' '); + static const FourCC kFloat64FourCC = SLANG_FOUR_CC('f', '6', '4', ' '); + + // The following codes are suitable for use when serializing + // content that represents a logical file system. + // + static const FourCC kDirectoryFourCC = SLANG_FOUR_CC('d', 'i', 'r', ' '); + static const FourCC kFileFourCC = SLANG_FOUR_CC('f', 'i', 'l', 'e'); + static const FourCC kNameFourCC = SLANG_FOUR_CC('n', 'a', 'm', 'e'); + static const FourCC kPathFourCC = SLANG_FOUR_CC('p', 'a', 't', 'h'); + static const FourCC kDataFourCC = SLANG_FOUR_CC('d', 'a', 't', 'a'); + + // TODO(tfoley): Figure out where to put all of these so that + // they can be more usefully addressed. + // + static const FourCC kMangledNameFourCC = SLANG_FOUR_CC('m', 'g', 'n', 'm'); + static const FourCC kProfileFourCC = SLANG_FOUR_CC('p', 'r', 'o', 'f'); + struct ArrayHeader { diff --git a/source/slang/slang-serialize-value-type-info.h b/source/slang/slang-serialize-value-type-info.h deleted file mode 100644 index 3ebbdc858..000000000 --- a/source/slang/slang-serialize-value-type-info.h +++ /dev/null @@ -1,83 +0,0 @@ -// slang-serialize-value-type-info.h - -#ifndef SLANG_SERIALIZE_VALUE_TYPE_INFO_H -#define SLANG_SERIALIZE_VALUE_TYPE_INFO_H - -#include "slang-ast-support-types.h" -#include "slang-generated-value-macro.h" -#include "slang-generated-value.h" -#include "slang-serialize-misc-type-info.h" -#include "slang-serialize-type-info.h" -#include "slang-serialize.h" - -// Create the functions to automatically convert between value types - -namespace Slang -{ - -// TODO(JS): We may want to strip const or other modifiers -// Just strips the brackets. -#define SLANG_VALUE_GET_TYPE(TYPE) TYPE - -#define SLANG_VALUE_FIELD_TO_SERIAL(FIELD_NAME, TYPE, param) \ - SerialTypeInfo<decltype(src->FIELD_NAME)>::toSerial(writer, &src->FIELD_NAME, &dst->FIELD_NAME); -#define SLANG_VALUE_FIELD_TO_NATIVE(FIELD_NAME, TYPE, param) \ - SerialTypeInfo<decltype(dst->FIELD_NAME)>::toNative(reader, &src->FIELD_NAME, &dst->FIELD_NAME); - -#define SLANG_IF_HAS_SUPER_BASE(x) -#define SLANG_IF_HAS_SUPER_INNER(x) x -#define SLANG_IF_HAS_SUPER_LEAF(x) x - -#define SLANG_VALUE_TO_SERIAL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - static void toSerial(SerialWriter* writer, const void* native, void* serial) \ - { \ - SLANG_IF_HAS_SUPER_##TYPE( \ - SerialTypeInfo<SUPER>::toSerial(writer, native, serial);) auto dst = \ - (SerialType*)serial; \ - auto src = (const NativeType*)native; \ - SLANG_FIELDS_Value_##NAME(SLANG_VALUE_FIELD_TO_SERIAL, param) \ - } - -#define SLANG_VALUE_TO_NATIVE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - static void toNative(SerialReader* reader, const void* serial, void* native) \ - { \ - SLANG_IF_HAS_SUPER_##TYPE( \ - SerialTypeInfo<SUPER>::toNative(reader, serial, native);) auto src = \ - (const SerialType*)serial; \ - auto dst = (NativeType*)native; \ - SLANG_FIELDS_Value_##NAME(SLANG_VALUE_FIELD_TO_NATIVE, param) \ - } - -// #define SLANG_VALUE_SERIAL_FIELD(FIELD_NAME, TYPE, param) SerialTypeInfo<SLANG_VALUE_GET_TYPE -// TYPE>::SerialType FIELD_NAME; -#define SLANG_VALUE_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - SerialTypeInfo<decltype(((param*)nullptr)->FIELD_NAME)>::SerialType FIELD_NAME; - -#define SLANG_VALUE_SERIAL_STRUCT(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - struct SerialType SLANG_IF_HAS_SUPER_##TYPE( : SerialTypeInfo<SUPER>::SerialType) \ - { \ - SLANG_FIELDS_Value_##NAME(SLANG_VALUE_SERIAL_FIELD, NAME) \ - }; - -#define SLANG_VALUE_TYPE_INFO_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - template<> \ - struct SerialTypeInfo<NAME> \ - { \ - typedef NAME NativeType; \ - SLANG_VALUE_SERIAL_STRUCT(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - \ - enum \ - { \ - SerialAlignment = SLANG_ALIGN_OF(SerialType) \ - }; \ - \ - SLANG_VALUE_TO_NATIVE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - SLANG_VALUE_TO_SERIAL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - }; - -#define SLANG_VALUE_TYPE_INFO(NAME) SLANG_Value_##NAME(SLANG_VALUE_TYPE_INFO_IMPL, _) - - -} // namespace Slang - -#endif // SLANG_SERIALIZE_VALUE_TYPE_INFO_H diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp index 2a1a92302..a1c555a9b 100644 --- a/source/slang/slang-serialize.cpp +++ b/source/slang/slang-serialize.cpp @@ -8,1182 +8,4 @@ namespace Slang { -const SerialClass* SerialClasses::add(const SerialClass* cls) -{ - List<const SerialClass*>& classes = m_classesByTypeKind[Index(cls->typeKind)]; - - if (cls->subType >= classes.getCount()) - { - classes.setCount(cls->subType + 1); - } - else - { - if (classes[cls->subType]) - { - SLANG_ASSERT(!"Type is already set"); - return nullptr; - } - } - - SerialClass* copy = _createSerialClass(cls); - classes[cls->subType] = copy; - - return copy; -} - -const SerialClass* SerialClasses::add( - SerialTypeKind kind, - SerialSubType subType, - const SerialField* fields, - Index fieldsCount, - const SerialClass* superCls) -{ - SerialClass cls; - cls.typeKind = kind; - cls.subType = subType; - - cls.fields = fields; - cls.fieldsCount = fieldsCount; - - // If the superCls is set it must be owned - SLANG_ASSERT(superCls == nullptr || isOwned(superCls)); - - cls.super = superCls; - - // Set to invalid values for now - cls.alignment = 0; - cls.size = 0; - cls.flags = 0; - - return add(&cls); -} - -const SerialClass* SerialClasses::addUnserialized(SerialTypeKind kind, SerialSubType subType) -{ - List<const SerialClass*>& classes = m_classesByTypeKind[Index(kind)]; - - if (subType >= classes.getCount()) - { - classes.setCount(subType + 1); - } - else - { - if (classes[subType]) - { - SLANG_ASSERT(!"Type is already set"); - return nullptr; - } - } - - SerialClass* dst = m_arena.allocate<SerialClass>(); - - dst->typeKind = kind; - dst->subType = subType; - - dst->size = 0; - dst->alignment = 0; - - dst->fields = nullptr; - dst->fieldsCount = 0; - dst->flags = SerialClassFlag::DontSerialize; - dst->super = nullptr; - - classes[subType] = dst; - return dst; -} - -bool SerialClasses::isOwned(const SerialClass* cls) const -{ - const List<const SerialClass*>& classes = m_classesByTypeKind[Index(cls->typeKind)]; - return cls->subType < classes.getCount() && classes[cls->subType] == cls; -} - -SerialClass* SerialClasses::_createSerialClass(const SerialClass* cls) -{ - uint32_t maxAlignment = 1; - uint32_t offset = 0; - - if (cls->super) - { - SLANG_ASSERT(isOwned(cls->super)); - - maxAlignment = cls->super->alignment; - offset = cls->super->size; - } - - // Can't be 0 - SLANG_ASSERT(maxAlignment != 0); - // Must be a power of 2 - SLANG_ASSERT((maxAlignment & (maxAlignment - 1)) == 0); - - // Check it is correctly aligned - SLANG_ASSERT((offset & (maxAlignment - 1)) == 0); - - SerialField* dstFields = m_arena.allocateArray<SerialField>(cls->fieldsCount); - - // Okay, go through fields setting their offset - const SerialField* srcFields = cls->fields; - for (Index j = 0; j < cls->fieldsCount; j++) - { - const SerialField& srcField = srcFields[j]; - SerialField& dstField = dstFields[j]; - - // Copy the field - dstField = srcField; - - uint32_t alignment = srcField.type->serialAlignment; - // Make sure the offset is aligned for the field requirement - offset = (offset + alignment - 1) & ~(alignment - 1); - - // Save the field offset - dstField.serialOffset = uint32_t(offset); - - // Move past the field - offset += uint32_t(srcField.type->serialSizeInBytes); - - // Calc the maximum alignment - maxAlignment = (alignment > maxAlignment) ? alignment : maxAlignment; - } - - // Align with maximum alignment - offset = (offset + maxAlignment - 1) & ~(maxAlignment - 1); - - SerialClass* dst = m_arena.allocate<SerialClass>(); - *dst = *cls; - - dst->alignment = uint8_t(maxAlignment); - dst->size = uint32_t(offset); - - dst->fields = dstFields; - - return dst; -} - -bool SerialClasses::isOk() const -{ - StringSlicePool pool(StringSlicePool::Style::Default); - - for (const auto& classes : m_classesByTypeKind) - { - for (const SerialClass* cls : classes) - { - // It is possible potentially to have gaps - if (cls == nullptr) - { - continue; - } - - if (cls->super && cls->super->typeKind != cls->typeKind) - { - // If has a super type, must be the same typeKind - return false; - } - - // Make sure the fields are uniquely named - - pool.clear(); - - { - const SerialClass* curCls = cls; - - do - { - for (Index i = 0; i < curCls->fieldsCount; ++i) - { - const SerialField& field = curCls->fields[i]; - - StringSlicePool::Handle handle; - if (pool.findOrAdd(UnownedStringSlice(field.name), handle)) - { - return false; - } - } - - // Add the fields of the parent - curCls = curCls->super; - } while (curCls); - } - } - } - - return true; -} - - -SerialClasses::SerialClasses() - : m_arena(2097152) -{ -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!! - -SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags flags) - : m_arena(2097152), m_classes(classes), m_filter(filter), m_flags(flags) -{ - // 0 is always the null pointer - m_entries.add(nullptr); - m_ptrMap.add(nullptr, 0); -} - -struct SkipFunctionBodyRAII -{ - FunctionDeclBase* funcDecl = nullptr; - Stmt* oldBody = nullptr; - SkipFunctionBodyRAII(SerialWriter::Flags flags, const SerialClass* serialCls, const void* ptr) - { - if ((flags & SerialWriter::Flag::SkipFunctionBody) == 0) - return; - - if (serialCls->typeKind != SerialTypeKind::NodeBase) - return; - auto cls = serialCls; - while (cls) - { - auto astNodeType = (ASTNodeType)cls->subType; - if (astNodeType == ASTNodeType::FunctionDeclBase) - { - funcDecl = (FunctionDeclBase*)ptr; - break; - } - cls = cls->super; - } - if (funcDecl) - { - oldBody = funcDecl->body; - // We always need to include body of unsafeForceInlineEarly functions - // since they will need to be available at IR lowering time of the - // user module for pre-linking inling. - if (!isUnsafeForceInlineFunc(funcDecl)) - { - funcDecl->body = nullptr; - } - } - } - ~SkipFunctionBodyRAII() - { - if (funcDecl) - { - funcDecl->body = oldBody; - } - } -}; - -SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void* ptr) -{ - if (serialCls->flags & SerialClassFlag::DontSerialize) - { - return SerialIndex(0); - } - - if (serialCls->typeKind == SerialTypeKind::NodeBase && - ReflectClassInfo::isSubClassOf(serialCls->subType, Val::kReflectClassInfo)) - { - return writeValObject((Val*)ptr); - } - - // If we are skipping function bodies, set the body field to nullptr, and - // restore it after serialization. - SkipFunctionBodyRAII clearFunctionBodyRAII(m_flags, serialCls, ptr); - - // This pointer cannot be in the map - SLANG_ASSERT(m_ptrMap.tryGetValue(ptr) == nullptr); - - typedef SerialInfo::ObjectEntry ObjectEntry; - - ObjectEntry* nodeEntry = (ObjectEntry*)m_arena.allocateAligned( - sizeof(ObjectEntry) + serialCls->size, - SerialInfo::MAX_ALIGNMENT); - - nodeEntry->typeKind = serialCls->typeKind; - nodeEntry->subType = serialCls->subType; - nodeEntry->_pad0 = 0; - - nodeEntry->info = SerialInfo::makeEntryInfo(serialCls->alignment); - - // We add before adding fields, so if the fields point to this, the entry will be set - auto index = _add(ptr, nodeEntry); - - // Point to start of payload - uint8_t* serialPayload = (uint8_t*)(nodeEntry + 1); - - if (m_flags & Flag::ZeroInitialize) - { - ::memset(serialPayload, 0, serialCls->size); - } - - while (serialCls) - { - for (Index i = 0; i < serialCls->fieldsCount; ++i) - { - auto field = serialCls->fields[i]; - - // Work out the offsets - auto srcField = ((const uint8_t*)ptr) + field.nativeOffset; - auto dstField = serialPayload + field.serialOffset; - - field.type->toSerialFunc(this, srcField, dstField); - } - - // Get the super class - serialCls = serialCls->super; - } - - return index; -} - -SerialIndex SerialWriter::writeObject(const NodeBase* node) -{ - const SerialClass* serialClass = - m_classes->getSerialClass(SerialTypeKind::NodeBase, SerialSubType(node->astNodeType)); - return writeObject(serialClass, (const void*)node); -} - -SerialIndex SerialWriter::writeValObject(const Val* node) -{ - typedef SerialInfo::ValEntry ValEntry; - - size_t size = node->getOperandCount() * sizeof(SerialInfo::SerialValOperand); - ValEntry* nodeEntry = - (ValEntry*)m_arena.allocateAligned(sizeof(ValEntry) + size, SerialInfo::MAX_ALIGNMENT); - - nodeEntry->typeKind = SerialTypeKind::NodeBase; - nodeEntry->subType = (SerialSubType)node->astNodeType; - nodeEntry->operandCount = (uint32_t)node->getOperandCount(); - nodeEntry->info = SerialInfo::makeEntryInfo(SerialInfo::MAX_ALIGNMENT); - - // We add before adding fields, so if the fields point to this, the entry will be set - auto index = _add(node, nodeEntry); - - ShortList<SerialIndex, 4> serializedOperands; - - for (Index i = 0; i < node->getOperandCount(); i++) - { - auto operand = node->m_operands[i]; - switch (operand.kind) - { - case ValNodeOperandKind::ConstantValue: - serializedOperands.add((SerialIndex)0); - break; - case ValNodeOperandKind::ValNode: - case ValNodeOperandKind::ASTNode: - serializedOperands.add(addPointer(operand.values.nodeOperand)); - break; - } - } - - SLANG_ASSERT(serializedOperands.getCount() == node->getOperandCount()); - - auto serialOperands = (SerialInfo::SerialValOperand*)(nodeEntry + 1); - for (Index i = 0; i < node->getOperandCount(); i++) - { - auto serialOperand = serialOperands + i; - auto operand = node->m_operands[i]; - serialOperand->type = (int)operand.kind; - switch (operand.kind) - { - case ValNodeOperandKind::ConstantValue: - serialOperand->payload = operand.values.intOperand; - break; - case ValNodeOperandKind::ValNode: - serialOperand->payload = (uint64_t)serializedOperands[i]; - break; - case ValNodeOperandKind::ASTNode: - serialOperand->payload = (uint64_t)serializedOperands[i]; - break; - } - } - return index; -} - -SerialIndex SerialWriter::writeObject(const RefObject* obj) -{ - const SerialRefObject* serialObj = as<const SerialRefObject>(obj); - if (!serialObj) - { - SLANG_ASSERT(!"Unhandled type"); - return SerialIndex(0); - } - - const ReflectClassInfo* classInfo = serialObj->getClassInfo(); - SLANG_ASSERT(classInfo); - - const SerialClass* serialClass = - m_classes->getSerialClass(SerialTypeKind::RefObject, SerialSubType(classInfo->m_classId)); - return writeObject(serialClass, (const void*)obj); -} - -void SerialWriter::setPointerIndex(const NodeBase* ptr, SerialIndex index) -{ - m_ptrMap.add(ptr, Index(index)); -} - -void SerialWriter::setPointerIndex(const RefObject* ptr, SerialIndex index) -{ - m_ptrMap.add(ptr, Index(index)); -} - -SerialIndex SerialWriter::addPointer(const NodeBase* node) -{ - // Null is always 0 - if (node == nullptr) - { - return SerialIndex(0); - } - // Look up in the map - Index* indexPtr = m_ptrMap.tryGetValue(node); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - if (m_filter) - { - return m_filter->writePointer(this, node); - } - else - { - return writeObject(node); - } -} - -SerialIndex SerialWriter::addPointer(const RefObject* obj) -{ - // Null is always 0 - if (obj == nullptr) - { - return SerialIndex(0); - } - // Look up in the map - Index* indexPtr = m_ptrMap.tryGetValue(obj); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - // TODO(JS): - // Arguably the lookup for these types should be done the same way as arbitrary RefObject types - // and have a enum for them, such we can use a switch instead of all this casting - - if (auto stringRep = dynamicCast<StringRepresentation>(obj)) - { - SerialIndex index = addString(StringRepresentation::asSlice(stringRep)); - m_ptrMap.add(obj, Index(index)); - return index; - } - else if (auto name = dynamicCast<const Name>(obj)) - { - return addName(name); - } - - if (m_filter) - { - return m_filter->writePointer(this, obj); - } - else - { - return writeObject(obj); - } -} - -SerialIndex SerialWriter::_addStringSlice( - SerialTypeKind typeKind, - SliceMap& sliceMap, - const UnownedStringSlice& slice) -{ - typedef ByteEncodeUtil Util; - typedef SerialInfo::StringEntry StringEntry; - - if (slice.getLength() == 0) - { - return SerialIndex(0); - } - - Index* indexPtr = sliceMap.tryGetValue(slice); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - // Okay we need to add the string - - uint8_t encodeBuf[Util::kMaxLiteEncodeUInt32]; - const int encodeCount = Util::encodeLiteUInt32(uint32_t(slice.getLength()), encodeBuf); - - StringEntry* entry = (StringEntry*)m_arena.allocateUnaligned( - SLANG_OFFSET_OF(StringEntry, sizeAndChars) + encodeCount + slice.getLength()); - entry->info = SerialInfo::EntryInfo::Alignment1; - entry->typeKind = typeKind; - - uint8_t* dst = (uint8_t*)(entry->sizeAndChars); - for (int i = 0; i < encodeCount; ++i) - { - dst[i] = encodeBuf[i]; - } - - memcpy(dst + encodeCount, slice.begin(), slice.getLength()); - - // Make a key that will stay in scope -> it's actually just stored in the arena. - // NOTE! without terminating 0 - UnownedStringSlice keySlice(((const char*)dst) + encodeCount, slice.getLength()); - - Index newIndex = m_entries.getCount(); - sliceMap.add(keySlice, newIndex); - - m_entries.add(entry); - return SerialIndex(newIndex); -} - -SerialIndex SerialWriter::addString(const String& in) -{ - return addPointer(in.getStringRepresentation()); -} - -SerialIndex SerialWriter::addName(const Name* name) -{ - if (name == nullptr) - { - return SerialIndex(0); - } - - // Look it up - Index* indexPtr = m_ptrMap.tryGetValue(name); - if (indexPtr) - { - return SerialIndex(*indexPtr); - } - - SerialIndex index = addString(name->text); - m_ptrMap.add(name, Index(index)); - return index; -} - -SerialIndex SerialWriter::addSerialArray( - size_t elementSize, - size_t alignment, - const void* elements, - Index elementCount) -{ - typedef SerialInfo::ArrayEntry Entry; - - if (elementCount == 0) - { - return SerialIndex(0); - } - - SLANG_ASSERT(alignment >= 1 && alignment <= SerialInfo::MAX_ALIGNMENT); - - // We must at a minimum have the alignment for the array prefix info - alignment = (alignment < SLANG_ALIGN_OF(Entry)) ? SLANG_ALIGN_OF(Entry) : alignment; - - size_t payloadSize = elementCount * elementSize; - - Entry* entry = (Entry*)m_arena.allocateAligned(sizeof(Entry) + payloadSize, alignment); - - entry->typeKind = SerialTypeKind::Array; - entry->info = SerialInfo::makeEntryInfo(int(alignment)); - entry->elementSize = uint16_t(elementSize); - entry->elementCount = uint32_t(elementCount); - - memcpy(entry + 1, elements, payloadSize); - - m_entries.add(entry); - return SerialIndex(m_entries.getCount() - 1); -} - -static const uint8_t s_fixBuffer[SerialInfo::MAX_ALIGNMENT]{ - 0, -}; - -SlangResult SerialWriter::write(Stream* stream) -{ - const Int entriesCount = m_entries.getCount(); - - // Add a sentinal so we don't need special handling for - SerialInfo::Entry sentinal; - sentinal.typeKind = SerialTypeKind::String; - sentinal.info = SerialInfo::EntryInfo::Alignment1; - - m_entries.add(&sentinal); - m_entries.removeLast(); - - SerialInfo::Entry** entries = m_entries.getBuffer(); - // Note strictly required in our impl of List. But by writing this and - // knowing that removeLast cannot release memory, means the sentinal must be at the last - // position. - entries[entriesCount] = &sentinal; - - { - size_t offset = 0; - - SerialInfo::Entry* entry = entries[1]; - // We start on 1, because 0 is nullptr and not used for anything - for (Index i = 1; i < entriesCount; ++i) - { - SerialInfo::Entry* next = entries[i + 1]; - - // Before writing we need to store the next alignment - - const size_t nextAlignment = SerialInfo::getAlignment(next->info); - const size_t alignment = SerialInfo::getAlignment(entry->info); - SLANG_UNUSED(alignment); - - entry->info = SerialInfo::combineWithNext(entry->info, next->info); - - // Check we are aligned correctly - SLANG_ASSERT((offset & (alignment - 1)) == 0); - - // When we write, we need to make sure it take into account the next alignment - const size_t entrySize = entry->calcSize(m_classes); - - // Work out the fix for next alignment - size_t nextOffset = offset + entrySize; - nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1); - - size_t alignmentFixSize = nextOffset - (offset + entrySize); - - // The fix must be less than max alignment. We require it to be less because we aligned - // each Entry to MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra bytes we can - // write - SLANG_ASSERT(alignmentFixSize < SerialInfo::MAX_ALIGNMENT); - - SLANG_RETURN_ON_FAIL(stream->write(entry, entrySize)); - // If we needed to fix so that subsequent alignment is right, write out extra bytes here - if (alignmentFixSize) - { - SLANG_RETURN_ON_FAIL(stream->write(s_fixBuffer, alignmentFixSize)); - } - - // Onto next - offset = nextOffset; - entry = next; - } - } - - return SLANG_OK; -} - -SlangResult SerialWriter::writeIntoContainer(FourCC fourCc, RiffContainer* container) -{ - typedef RiffContainer::Chunk Chunk; - typedef RiffContainer::ScopeChunk ScopeChunk; - - { - ScopeChunk scopeData(container, Chunk::Kind::Data, fourCc); - - { - // Sentinel so we don't need special handling for end of list - SerialInfo::Entry sentinal; - sentinal.typeKind = SerialTypeKind::String; - sentinal.info = SerialInfo::EntryInfo::Alignment1; - - size_t offset = 0; - const Int entriesCount = m_entries.getCount(); - - { - m_entries.add(&sentinal); - m_entries.removeLast(); - // Note strictly required in our impl of List. But by writing this and - // knowing that removeLast cannot release memory, means the sentinal must be at the - // last position. - m_entries.getBuffer()[entriesCount] = &sentinal; - } - - SerialInfo::Entry* const* entries = m_entries.getBuffer(); - - SerialInfo::Entry* entry = entries[1]; - // We start on 1, because 0 is nullptr and not used for anything - for (Index i = 1; i < entriesCount; ++i) - { - SerialInfo::Entry* next = entries[i + 1]; - - // Before writing we need to store the next alignment - - const size_t nextAlignment = SerialInfo::getAlignment(next->info); - const size_t alignment = SerialInfo::getAlignment(entry->info); - SLANG_UNUSED(alignment); - - entry->info = SerialInfo::combineWithNext(entry->info, next->info); - - // Check we are aligned correctly - SLANG_ASSERT((offset & (alignment - 1)) == 0); - - // When we write, we need to make sure it take into account the next alignment - const size_t entrySize = entry->calcSize(m_classes); - - // Work out the fix for next alignment - size_t nextOffset = offset + entrySize; - nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1); - - size_t alignmentFixSize = nextOffset - (offset + entrySize); - - // The fix must be less than max alignment. We require it to be less because we - // aligned each Entry to MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra - // bytes we can write - SLANG_ASSERT(alignmentFixSize < SerialInfo::MAX_ALIGNMENT); - - container->write(entry, entrySize); - if (alignmentFixSize) - { - container->write(s_fixBuffer, alignmentFixSize); - } - - // Onto next - offset = nextOffset; - entry = next; - } - } - } - - return SLANG_OK; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialInfo::Entry !!!!!!!!!!!!!!!!!!!!!!!! - -size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const -{ - switch (typeKind) - { - case SerialTypeKind::ImportSymbol: - case SerialTypeKind::String: - { - auto entry = static_cast<const StringEntry*>(this); - const uint8_t* cur = (const uint8_t*)entry->sizeAndChars; - uint32_t charsSize; - int sizeSize = ByteEncodeUtil::decodeLiteUInt32(cur, &charsSize); - return SLANG_OFFSET_OF(StringEntry, sizeAndChars) + sizeSize + charsSize; - } - case SerialTypeKind::Array: - { - auto entry = static_cast<const ArrayEntry*>(this); - return sizeof(ArrayEntry) + entry->elementSize * entry->elementCount; - } - case SerialTypeKind::RefObject: - case SerialTypeKind::NodeBase: - { - auto entry = static_cast<const ObjectEntry*>(this); - - auto serialClass = serialClasses->getSerialClass(typeKind, entry->subType); - - if (ReflectClassInfo::isSubClassOf(entry->subType, Val::kReflectClassInfo)) - return sizeof(ValEntry) + - static_cast<const ValEntry*>(this)->operandCount * sizeof(SerialValOperand); - - // Align by the alignment of the entry - size_t alignment = getAlignment(entry->info); - size_t size = sizeof(ObjectEntry) + serialClass->size; - - size = size + (alignment - 1) & ~(alignment - 1); - return size; - } - - default: - break; - } - - SLANG_ASSERT(!"Unknown type"); - return 0; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialReader !!!!!!!!!!!!!!!!!!!!!!!!!!!! - -SerialReader::~SerialReader() -{ - for (const RefObject* obj : m_scope) - { - const_cast<RefObject*>(obj)->releaseReference(); - } -} - -const void* SerialReader::getArray(SerialIndex index, Index& outCount) -{ - if (index == SerialIndex(0)) - { - outCount = 0; - return nullptr; - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - switch (entry->typeKind) - { - case SerialTypeKind::Array: - { - auto arrayEntry = static_cast<const SerialInfo::ArrayEntry*>(entry); - outCount = Index(arrayEntry->elementCount); - return (arrayEntry + 1); - } - default: - break; - } - - SLANG_ASSERT(!"Not an array"); - outCount = 0; - return nullptr; -} - -SerialPointer SerialReader::getPointer(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return SerialPointer(); - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - const SerialPointer& ptr = m_objects[Index(index)]; - - switch (entry->typeKind) - { - case SerialTypeKind::String: - { - // Hmm. Tricky -> we don't know if will be cast as Name or String. Lets assume string. - String string = getString(index); - return SerialPointer(string.getStringRepresentation()); - } - case SerialTypeKind::ImportSymbol: - { - if (ptr.m_kind == SerialTypeKind::Unknown) - { - // TODO(JS): - // Could have an error here, because import symbol was not set - // For now just return nullptr - return SerialPointer(); - } - break; - } - default: - break; - } - - return ptr; -} - -SerialPointer SerialReader::getValPointer(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return SerialPointer(); - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - - SerialPointer& ptr = m_objects[Index(index)]; - - if (ptr.m_ptr) - return ptr; - - const SerialInfo::ValEntry* entry = (SerialInfo::ValEntry*)m_entries[Index(index)]; - ValNodeDesc desc; - desc.type = (ASTNodeType)entry->subType; - auto readPtr = (SerialInfo::SerialValOperand*)(entry + 1); - for (uint32_t i = 0; i < entry->operandCount; i++) - { - auto serialOperand = readPtr[i]; - ValNodeOperand operand; - operand.kind = (ValNodeOperandKind)(serialOperand.type); - switch (operand.kind) - { - case ValNodeOperandKind::ConstantValue: - operand.values.intOperand = serialOperand.payload; - break; - case ValNodeOperandKind::ASTNode: - operand.values.nodeOperand = - (NodeBase*)getPointer((SerialIndex)serialOperand.payload).m_ptr; - break; - case ValNodeOperandKind::ValNode: - operand.values.nodeOperand = - (Val*)getValPointer((SerialIndex)serialOperand.payload).m_ptr; - break; - } - desc.operands.add(operand); - } - desc.init(); - ptr.m_kind = SerialTypeKind::NodeBase; - ptr.m_ptr = this->m_objectFactory->getOrCreateVal(_Move(desc)); - return ptr; -} - -String SerialReader::getString(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return String(); - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - // It has to be a string type - if (entry->typeKind != SerialTypeKind::String) - { - SLANG_ASSERT(!"Not a string"); - return String(); - } - - RefObject* obj = m_objects[Index(index)].dynamicCast<RefObject>(); - - if (obj) - { - StringRepresentation* stringRep = dynamicCast<StringRepresentation>(obj); - if (stringRep) - { - return String(stringRep); - } - // Must be a name then - Name* name = dynamicCast<Name>(obj); - SLANG_ASSERT(name); - return name->text; - } - - // Okay we need to construct as a string - UnownedStringSlice slice = getStringSlice(index); - - StringRepresentation* stringRep = nullptr; - - const Index length = slice.getLength(); - if (length) - { - stringRep = StringRepresentation::createWithCapacityAndLength(length, length); - memcpy(stringRep->getData(), slice.begin(), length * sizeof(char)); - addScope(stringRep); - } - - m_objects[Index(index)] = stringRep; - return String(stringRep); -} - -Name* SerialReader::getName(SerialIndex index) -{ - if (index == SerialIndex(0)) - { - return nullptr; - } - - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - // It has to be a string type - if (entry->typeKind != SerialTypeKind::String) - { - SLANG_ASSERT(!"Not a string"); - return nullptr; - } - - RefObject* obj = m_objects[Index(index)].dynamicCast<RefObject>(); - - if (obj) - { - Name* name = dynamicCast<Name>(obj); - if (name) - { - return name; - } - // Can only be a string then - StringRepresentation* stringRep = dynamicCast<StringRepresentation>(obj); - SLANG_ASSERT(stringRep); - - // I don't need to scope, as scoped in NamePool - name = m_namePool->getName(String(stringRep)); - - // Store as name, as can always access the inner string if needed - m_objects[Index(index)] = name; - return name; - } - - UnownedStringSlice slice = getStringSlice(index); - String string(slice); - Name* name = m_namePool->getName(string); - // Don't need to add to scope, because scoped on the pool - m_objects[Index(index)] = name; - return name; -} - -UnownedStringSlice SerialReader::getStringSlice(SerialIndex index) -{ - SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); - const Entry* entry = m_entries[Index(index)]; - - // It has to be a string type - if (entry->typeKind == SerialTypeKind::String || - entry->typeKind == SerialTypeKind::ImportSymbol) - { - auto stringEntry = static_cast<const SerialInfo::StringEntry*>(entry); - - const uint8_t* src = (const uint8_t*)stringEntry->sizeAndChars; - - // Decode the string - uint32_t size; - int sizeSize = ByteEncodeUtil::decodeLiteUInt32(src, &size); - return UnownedStringSlice((const char*)src + sizeSize, size); - } - - // Can't be accessed as a slice - SLANG_ASSERT(!"Not accessible as a slice"); - return UnownedStringSlice(); -} - -/* static */ SlangResult SerialReader::loadEntries( - const uint8_t* data, - size_t dataCount, - SerialClasses* serialClasses, - List<const Entry*>& outEntries) -{ - // Check the input data is at least aligned to the max alignment (otherwise everything cannot be - // aligned correctly) - SLANG_ASSERT((size_t(data) & (SerialInfo::MAX_ALIGNMENT - 1)) == 0); - - outEntries.setCount(1); - outEntries[0] = nullptr; - - const uint8_t* const end = data + dataCount; - - const uint8_t* cur = data; - while (cur < end) - { - const Entry* entry = (const Entry*)cur; - outEntries.add(entry); - - const size_t entrySize = entry->calcSize(serialClasses); - cur += entrySize; - - // Need to get the next alignment - const size_t nextAlignment = SerialInfo::getNextAlignment(entry->info); - - // Need to fix cur with the alignment - cur = (const uint8_t*)((size_t(cur) + nextAlignment - 1) & ~(nextAlignment - 1)); - } - - return SLANG_OK; -} - -SlangResult SerialReader::constructObjects(NamePool* namePool) -{ - m_namePool = namePool; - - m_objects.clearAndDeallocate(); - m_objects.setCount(m_entries.getCount()); - memset(m_objects.getBuffer(), 0, m_objects.getCount() * sizeof(void*)); - - // Go through entries, constructing objects. - for (Index i = 1; i < m_entries.getCount(); ++i) - { - const Entry* entry = m_entries[i]; - - switch (entry->typeKind) - { - case SerialTypeKind::ImportSymbol: - { - // We don't construct any object for an imported symbol. - // It will be the responsibility of external code to interpet the symbols and *set* - // the appopriate objects prior to a call to `deserializeObjects` - break; - } - case SerialTypeKind::String: - { - // Don't need to construct an object. This is probably a StringRepresentation, or a - // Name Will evaluate lazily. - break; - } - case SerialTypeKind::RefObject: - case SerialTypeKind::NodeBase: - { - auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry); - - // Don't create object for Vals. - if (objectEntry->typeKind == SerialTypeKind::NodeBase && - ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo)) - break; - - void* obj = m_objectFactory->create(objectEntry->typeKind, objectEntry->subType); - if (!obj) - { - return SLANG_FAIL; - } - m_objects[i].set(entry->typeKind, obj); - break; - } - case SerialTypeKind::Array: - { - // Don't need to construct an object, as will be accessed and interpreted by the - // object that holds it - break; - } - } - } - - return SLANG_OK; -} - -SlangResult SerialReader::deserializeObjects() -{ - // Deserialize - for (Index i = 1; i < m_entries.getCount(); ++i) - { - const Entry* entry = m_entries[i]; - // First see if there is anything to construct - SerialPointer& dstPtr = m_objects[i]; - if (!dstPtr) - { - continue; - } - switch (entry->typeKind) - { - case SerialTypeKind::NodeBase: - case SerialTypeKind::RefObject: - { - auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry); - auto serialClass = - m_classes->getSerialClass(objectEntry->typeKind, objectEntry->subType); - if (!serialClass) - { - return SLANG_FAIL; - } - if (ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo)) - continue; - - const uint8_t* src = (const uint8_t*)(objectEntry + 1); - uint8_t* dst = (uint8_t*)dstPtr.m_ptr; - - // It must be constructed - SLANG_ASSERT(dst); - - while (serialClass) - { - for (Index j = 0; j < serialClass->fieldsCount; ++j) - { - auto field = serialClass->fields[j]; - auto fieldType = field.type; - fieldType->toNativeFunc( - this, - src + field.serialOffset, - dst + field.nativeOffset); - } - - // Get the super class - serialClass = serialClass->super; - } - - break; - } - default: - break; - } - } - - return SLANG_OK; -} - - -SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* namePool) -{ - // Load and place entries into entries table - SLANG_RETURN_ON_FAIL(loadEntries(data, dataCount)); - // Construct all of the objects - SLANG_RETURN_ON_FAIL(constructObjects(namePool)); - SLANG_RETURN_ON_FAIL(deserializeObjects()); - return SLANG_OK; -} - } // namespace Slang diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index 4c7189d62..20916f735 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -27,702 +27,465 @@ class NodeBase; class Val; struct ValNodeDesc; -// Pre-declare -class SerialClasses; -class SerialWriter; -class SerialReader; - -struct SerialClass; -struct SerialField; - -// Type used to implement mechanisms to convert to and from serial types. -template<typename T, typename /*enumTypeSFINAE*/ = void> -struct SerialTypeInfo; - -enum class SerialTypeKind : uint8_t +struct Encoder { - Unknown, - - String, ///< String - Array, ///< Array - ImportSymbol, ///< Holds the name of the import symbol. Represented in exactly the same way as a - ///< string - - NodeBase, ///< NodeBase derived - RefObject, ///< RefObject derived types - - CountOf, -}; -typedef uint16_t SerialSubType; - -struct SerialInfo -{ - enum - { - // Data held in serialized format, the maximally allowed alignment - MAX_ALIGNMENT = 8, - }; - - // We only allow up to MAX_ALIGNMENT bytes of alignment. We store alignments as shifts, so 2 - // bits needed for 1 - 8 - enum class EntryInfo : uint8_t - { - Alignment1 = 0, - }; - - static EntryInfo makeEntryInfo(int alignment, int nextAlignment) - { - // Make sure they are power of 2 - SLANG_ASSERT((alignment & (alignment - 1)) == 0); - SLANG_ASSERT((nextAlignment & (nextAlignment - 1)) == 0); - - const int alignmentShift = ByteEncodeUtil::calcMsb8(alignment); - const int nextAlignmentShift = ByteEncodeUtil::calcMsb8(nextAlignment); - return EntryInfo((nextAlignmentShift << 2) | alignmentShift); - } - static EntryInfo makeEntryInfo(int alignment) - { - // Make sure they are power of 2 - SLANG_ASSERT((alignment & (alignment - 1)) == 0); - return EntryInfo(ByteEncodeUtil::calcMsb8(alignment)); - } - /// Apply with the next alignment - static EntryInfo combineWithNext(EntryInfo cur, EntryInfo next) +public: + Encoder(Stream* stream) + : _stream(stream) { - return EntryInfo((int(cur) & ~0xc0) | ((int(next) & 3) << 2)); } - static int getAlignment(EntryInfo info) { return 1 << (int(info) & 3); } - static int getNextAlignment(EntryInfo info) { return 1 << ((int(info) >> 2) & 3); } - - /* Alignment is a little tricky. We have a 'Entry' header before the payload. The payload - alignment may change. If we only align on the Entry header, then it's size *must* be some modulo - of the maximum alignment allowed. + ~Encoder() { RiffUtil::write(&_riff, _stream); } - We could hold Entry separate from payload. We could make the header not require the alignment of - the payload - but then we'd need payload alignment separate from entry alignment. - */ - struct Entry + void beginArray(FourCC typeCode) { - SerialTypeKind typeKind; - EntryInfo info; + _riff.startChunk(RiffContainer::Chunk::Kind::List, typeCode); + } - size_t calcSize(SerialClasses* serialClasses) const; - }; + void beginArray() { beginArray(SerialBinary::kArrayFourCC); } - struct StringEntry : Entry + void endArray() { - char sizeAndChars[1]; - }; + _riff.endChunk(); + // TODO: maybe end key... + } - struct ObjectEntry : Entry + void beginObject(FourCC typeCode) { - SerialSubType - subType; ///< Can be ASTType or other subtypes (as used for RefObjects for example) - uint32_t _pad0; ///< Necessary, because a node *can* have MAX_ALIGNEMENT - }; + _riff.startChunk(RiffContainer::Chunk::Kind::List, typeCode); + } - struct ValEntry : Entry - { - SerialSubType subType; - uint32_t operandCount; - }; + void beginObject() { beginObject(SerialBinary::kObjectFourCC); } - struct ArrayEntry : Entry - { - uint16_t elementSize; - uint32_t elementCount; - }; + void endObject() { _riff.endChunk(); } - struct SerialValOperand + void beginKeyValuePair() { - int type; - uint64_t payload; - }; -}; + _riff.startChunk(RiffContainer::Chunk::Kind::List, SerialBinary::kPairFourCC); + } -typedef uint32_t SerialIndexRaw; -enum class SerialIndex : SerialIndexRaw; + void endKeyValuePair() { _riff.endChunk(); } -/* A type to convert pointers into types such that they can be passed around to readers/writers -without having to know the specific type. If there was a base class that all the serialized types -derived from, that was dynamically castable this would not be necessary */ -struct SerialPointer -{ - // Helpers so we can choose what kind of pointer we have based on the (unused) type of the - // pointer passed in - SLANG_FORCE_INLINE RefObject* _get(const RefObject*) - { - return m_kind == SerialTypeKind::RefObject ? reinterpret_cast<RefObject*>(m_ptr) : nullptr; - } - SLANG_FORCE_INLINE NodeBase* _get(const NodeBase*) + void beginKeyValuePair(FourCC keyCode) { - return m_kind == SerialTypeKind::NodeBase ? reinterpret_cast<NodeBase*>(m_ptr) : nullptr; + _riff.startChunk(RiffContainer::Chunk::Kind::List, keyCode); } - template<typename T> - T* dynamicCast() + void encodeData(FourCC typeCode, void const* data, size_t size) { - return Slang::dynamicCast<T>(_get((T*)nullptr)); + _riff.startChunk(RiffContainer::Chunk::Kind::Data, typeCode); + _riff.write(data, size); + _riff.endChunk(); } - SerialPointer() - : m_kind(SerialTypeKind::Unknown), m_ptr(nullptr) + void encodeData(void const* data, size_t size) { + encodeData(SerialBinary::kDataFourCC, data, size); } - SerialPointer(RefObject* in) - : m_kind(SerialTypeKind::RefObject), m_ptr((void*)in) - { - } - SerialPointer(NodeBase* in) - : m_kind(SerialTypeKind::NodeBase), m_ptr((void*)in) + void encode(nullptr_t) { encodeData(SerialBinary::kNullFourCC, nullptr, 0); } + + void encodeBool(bool value) { + encodeData(value ? SerialBinary::kTrueFourCC : SerialBinary::kFalseFourCC, nullptr, 0); } - /// True if the ptr is set - SLANG_FORCE_INLINE operator bool() const { return m_ptr != nullptr; } + void encode(Int32 value) { encodeData(SerialBinary::kInt32FourCC, &value, sizeof(value)); } - /// Directly set pointer/kind - void set(SerialTypeKind kind, void* ptr) - { - m_kind = kind; - m_ptr = ptr; - } + void encode(UInt32 value) { encodeData(SerialBinary::kUInt32FourCC, &value, sizeof(value)); } - static SerialTypeKind getKind(const RefObject*) { return SerialTypeKind::RefObject; } - static SerialTypeKind getKind(const NodeBase*) { return SerialTypeKind::NodeBase; } + void encode(Int64 value) { encodeData(SerialBinary::kInt64FourCC, &value, sizeof(value)); } - SerialTypeKind m_kind; - void* m_ptr; -}; + void encode(UInt64 value) { encodeData(SerialBinary::kUInt64FourCC, &value, sizeof(value)); } -class SerialFilter -{ -public: - virtual SerialIndex writePointer(SerialWriter* writer, const NodeBase* ptr) = 0; - virtual SerialIndex writePointer(SerialWriter* writer, const RefObject* ptr) = 0; -}; + void encode(float value) { encodeData(SerialBinary::kFloat32FourCC, &value, sizeof(value)); } -class SerialObjectFactory -{ -public: - virtual void* create(SerialTypeKind typeKind, SerialSubType subType) = 0; - virtual void* getOrCreateVal(ValNodeDesc&& desc) = 0; -}; + void encode(double value) { encodeData(SerialBinary::kFloat64FourCC, &value, sizeof(value)); } -class SerialExtraObjects -{ -public: - template<typename T> - void set(T* obj) - { - m_objects[Index(T::kExtraType)] = obj; - } - template<typename T> - void set(const RefPtr<T>& obj) + void encodeString(String const& value) { - m_objects[Index(T::kExtraType)] = obj.Ptr(); + Int size = value.getLength(); + encodeData(SerialBinary::kStringFourCC, value.getBuffer(), size); } - /// Get the extra type - template<typename T> - T* get() - { - return reinterpret_cast<T*>(m_objects[Index(T::kExtraType)]); - } - SerialExtraObjects() - { - for (auto& obj : m_objects) - obj = nullptr; - } + void encode(String const& value) { encodeString(value); } -protected: - void* m_objects[Index(SerialExtraType::CountOf)]; -}; + struct WithArray + { + public: + WithArray(Encoder* encoder) + : _encoder(encoder) + { + encoder->beginArray(); + } -enum class PostSerializationFixUpKind -{ - ValPtr, -}; + WithArray(Encoder* encoder, FourCC typeCode) + : _encoder(encoder) + { + encoder->beginArray(typeCode); + } -/* This class is the interface used by toNative implementations to recreate a type. */ -class SerialReader : public RefObject -{ -public: - typedef SerialInfo::Entry Entry; + ~WithArray() { _encoder->endArray(); } - template<typename T> - void getArray(SerialIndex index, List<T>& out); + private: + Encoder* _encoder; + }; - template<typename T, int n> - void getArray(SerialIndex index, ShortList<T, n>& out); + struct WithObject + { + public: + WithObject(Encoder* encoder) + : _encoder(encoder) + { + encoder->beginObject(); + } - const void* getArray(SerialIndex index, Index& outCount); + WithObject(Encoder* encoder, FourCC typeCode) + : _encoder(encoder) + { + encoder->beginObject(typeCode); + } - SerialPointer getPointer(SerialIndex index); - SerialPointer getValPointer(SerialIndex index); + ~WithObject() { _encoder->endObject(); } - String getString(SerialIndex index); - Name* getName(SerialIndex index); - UnownedStringSlice getStringSlice(SerialIndex index); + private: + Encoder* _encoder; + }; - SlangResult loadEntries(const uint8_t* data, size_t dataCount) - { - return loadEntries(data, dataCount, m_classes, m_entries); - } - /// For each entry construct an object. Does *NOT* deserialize them - SlangResult constructObjects(NamePool* namePool); - /// Entries must be loaded (with loadEntries), and objects constructed (with constructObjects) - /// before deserializing - SlangResult deserializeObjects(); - - /// NOTE! data must stay ins scope when reading takes place - SlangResult load(const uint8_t* data, size_t dataCount, NamePool* namePool); - - /// Get the entries list - const List<const Entry*>& getEntries() const { return m_entries; } - - /// Access the objects list - /// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the - /// RefObject* via addScope - List<SerialPointer>& getObjects() { return m_objects; } - const List<SerialPointer>& getObjects() const { return m_objects; } - - /// Add an object to be kept in scope - void addScopeWithoutAddRef(const RefObject* obj) { m_scope.add(obj); } - /// Add obj with a reference - void addScope(const RefObject* obj) + struct WithKeyValuePair { - const_cast<RefObject*>(obj)->addReference(); - m_scope.add(obj); - } + public: + WithKeyValuePair(Encoder* encoder) + : _encoder(encoder) + { + encoder->beginKeyValuePair(); + } - /// Used for attaching extra objects necessary for serializing - SerialExtraObjects& getExtraObjects() { return m_extraObjects; } + WithKeyValuePair(Encoder* encoder, FourCC typeCode) + : _encoder(encoder) + { + encoder->beginKeyValuePair(typeCode); + } - /// Ctor - SerialReader(SerialClasses* classes, SerialObjectFactory* objectFactory) - : m_classes(classes), m_objectFactory(objectFactory) - { - } - ~SerialReader(); + ~WithKeyValuePair() { _encoder->endKeyValuePair(); } - /// Load the entries table (without deserializing anything) - /// NOTE! data must stay ins scope for outEntries to be valid - static SlangResult loadEntries( - const uint8_t* data, - size_t dataCount, - SerialClasses* serialClasses, - List<const Entry*>& outEntries); + private: + Encoder* _encoder; + }; -protected: - List<const Entry*> m_entries; ///< The entries +private: + Stream* _stream = nullptr; - List<SerialPointer> m_objects; ///< The constructed objects - NamePool* m_namePool; ///< Pool names are added to + // Implementation details below... + RiffContainer _riff; - List<const RefObject*> m_scope; ///< Keeping objects in scope +public: + RiffContainer* getRIFF() { return &_riff; } - SerialExtraObjects m_extraObjects; + RiffContainer::Chunk* getRIFFChunk() { return _riff.getCurrentChunk(); } - SerialObjectFactory* m_objectFactory; - SerialClasses* m_classes; ///< Information used to deserialize + void setRIFFChunk(RiffContainer::Chunk* chunk) { _riff.setCurrentChunk(chunk); } }; -// --------------------------------------------------------------------------- -template<typename T> -void SerialReader::getArray(SerialIndex index, List<T>& out) +struct Decoder { - typedef SerialTypeInfo<T> ElementTypeInfo; - typedef typename ElementTypeInfo::SerialType ElementSerialType; - - Index count; - auto serialElements = (const ElementSerialType*)getArray(index, count); - - if (count == 0) +public: + Decoder(RiffContainer::Chunk* chunk) + : _chunk(chunk) { - out.clear(); - return; } - if (std::is_same<T, ElementSerialType>::value) + bool decodeBool() { - // If they are the same we can just write out - out.clear(); - out.insertRange(0, (const T*)serialElements, count); - } - else - { - // Else we need to convert - out.setCount(count); - for (Index i = 0; i < count; ++i) + switch (getTag()) { - ElementTypeInfo::toNative(this, (const void*)&serialElements[i], (void*)&out[i]); + case SerialBinary::kTrueFourCC: + _chunk = _chunk->m_next; + return true; + case SerialBinary::kFalseFourCC: + _chunk = _chunk->m_next; + return false; + + default: + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(false); } } -} - -template<typename T, int n> -void SerialReader::getArray(SerialIndex index, ShortList<T, n>& out) -{ - typedef SerialTypeInfo<T> ElementTypeInfo; - typedef typename ElementTypeInfo::SerialType ElementSerialType; - - Index count; - auto serialElements = (const ElementSerialType*)getArray(index, count); - if (count == 0) + String decodeString() { - out.clear(); - return; - } + if (getTag() != SerialBinary::kStringFourCC) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(""); + } - if (std::is_same<T, ElementSerialType>::value) - { - // If they are the same we can just write out - out.clear(); - out.addRange((const T*)serialElements, count); - } - else - { - // Else we need to convert - out.setCount(count); - for (Index i = 0; i < count; ++i) + auto dataChunk = as<RiffContainer::DataChunk>(_chunk); + if (!dataChunk) { - ElementTypeInfo::toNative(this, (const void*)&serialElements[i], (void*)&out[i]); + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(""); } + + auto size = dataChunk->calcPayloadSize(); + + String value; + value.appendRepeatedChar(' ', size); + dataChunk->getPayload((char*)value.getBuffer()); + + _chunk = _chunk->m_next; + return value; } -} -/* This is a class used tby toSerial implementations to turn native type into the serial type */ -class SerialWriter : public RefObject -{ -public: - typedef uint32_t Flags; - struct Flag + void decodeData(FourCC typeTag, void* outData, size_t dataSize) { - enum Enum : Flags + if (getTag() == typeTag) { - /// If set will zero initialize backing memory. This is slower but - /// is desirable to make two serializations of the same thing produce the - /// identical serialized result. - ZeroInitialize = 0x1, - - /// If set will not serialize function body. - SkipFunctionBody = 0x2, - }; - }; - - SerialIndex addPointer(const NodeBase* ptr); - SerialIndex addPointer(const RefObject* ptr); - - /// Write the object at ptr of type serialCls - SerialIndex writeObject(const SerialClass* serialCls, const void* ptr); + auto dataChunk = as<RiffContainer::DataChunk>(_chunk); + if (dataChunk) + { + if (dataChunk->calcPayloadSize() >= dataSize) + { + dataChunk->getPayload(outData); + _chunk = _chunk->m_next; + return; + } + } + } - /// Write the object at the pointer - SerialIndex writeObject(const NodeBase* ptr); - SerialIndex writeObject(const RefObject* ptr); - SerialIndex writeValObject(const Val* ptr); + SLANG_UNEXPECTED("invalid format in RIFF"); + } - /// Add an array - may need to convert to serialized format template<typename T> - SerialIndex addArray(const T* in, Index count); - - template<typename NATIVE_TYPE> - /// Add an array where all the elements are already in serialized format (ie there is no need to - /// do a conversion) - SerialIndex addSerialArray(const void* elements, Index elementCount) + T _decodeSimpleValue(FourCC typeTag) { - typedef SerialTypeInfo<NATIVE_TYPE> TypeInfo; - return addSerialArray( - sizeof(typename TypeInfo::SerialType), - SerialTypeInfo<NATIVE_TYPE>::SerialAlignment, - elements, - elementCount); + T value; + decodeData(typeTag, &value, sizeof(value)); + return value; } - /// Add an array where all the elements are already in serialized format (ie there is no need to - /// do a conversion) - SerialIndex addSerialArray( - size_t elementSize, - size_t alignment, - const void* elements, - Index elementCount); + Int64 decodeInt64() { return _decodeSimpleValue<Int64>(SerialBinary::kInt64FourCC); } - /// Add the string - SerialIndex addString(const UnownedStringSlice& slice) - { - return _addStringSlice(SerialTypeKind::String, m_sliceMap, slice); - } - SerialIndex addString(const String& in); - SerialIndex addName(const Name* name); + UInt64 decodeUInt64() { return _decodeSimpleValue<UInt64>(SerialBinary::kUInt64FourCC); } - /// Adding import symbols - SerialIndex addImportSymbol(const UnownedStringSlice& slice) - { - return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); - } - SerialIndex addImportSymbol(const String& string) - { - return _addStringSlice( - SerialTypeKind::ImportSymbol, - m_importSymbolMap, - string.getUnownedSlice()); - } + Int32 decodeInt32() { return _decodeSimpleValue<Int32>(SerialBinary::kInt32FourCC); } - /// Set a the ptr associated with an index. - /// NOTE! That there cannot be a pre-existing setting. - void setPointerIndex(const NodeBase* ptr, SerialIndex index); - void setPointerIndex(const RefObject* ptr, SerialIndex index); + UInt32 decodeUInt32() { return _decodeSimpleValue<UInt32>(SerialBinary::kUInt32FourCC); } - /// Get the entries table holding how each index maps to an entry - const List<SerialInfo::Entry*>& getEntries() const { return m_entries; } + float decodeFloat32() { return _decodeSimpleValue<float>(SerialBinary::kFloat32FourCC); } - /// Write to a stream - SlangResult write(Stream* stream); + double decodeFloat64() { return _decodeSimpleValue<double>(SerialBinary::kFloat64FourCC); } - /// Write a data chunk with fourCC - SlangResult writeIntoContainer(FourCC fourCC, RiffContainer* container); - /// Used for attaching extra objects necessary for serializing - SerialExtraObjects& getExtraObjects() { return m_extraObjects; } + FourCC getTag() { return _chunk ? _chunk->m_fourCC : 0; } - /// Get the flag - Flags getFlags() const { return m_flags; } + Int32 _decodeImpl(Int32*) { return decodeInt32(); } + UInt32 _decodeImpl(UInt32*) { return decodeUInt32(); } - /// Ctor - SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags flags = Flag::ZeroInitialize); + Int64 _decodeImpl(Int64*) { return decodeInt64(); } + UInt64 _decodeImpl(UInt64*) { return decodeUInt64(); } -protected: - typedef Dictionary<UnownedStringSlice, Index> SliceMap; + float _decodeImpl(float*) { return decodeFloat32(); } + double _decodeImpl(double*) { return decodeFloat64(); } - SerialIndex _addStringSlice( - SerialTypeKind typeKind, - SliceMap& sliceMap, - const UnownedStringSlice& slice); + template<typename T> + T decode() + { + return _decodeImpl((T*)nullptr); + } - SerialIndex _add(const void* nativePtr, SerialInfo::Entry* entry) + template<typename T> + void decode(T& outValue) { - m_entries.add(entry); - // Okay I need to allocate space for this - SerialIndex index = SerialIndex(m_entries.getCount() - 1); - // Add to the map - m_ptrMap.add(nativePtr, Index(index)); - return index; + outValue = _decodeImpl((T*)nullptr); } - Dictionary<const void*, Index> m_ptrMap; // Maps a pointer to an entry index + void beginArray(FourCC typeCode = SerialBinary::kArrayFourCC) + { + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } - // NOTE! Assumes the content stays in scope! - SliceMap m_sliceMap; - SliceMap m_importSymbolMap; + if (listChunk->m_fourCC != typeCode) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } - SerialExtraObjects m_extraObjects; ///< Extra objects + _chunk = listChunk->getFirstContainedChunk(); + } - List<SerialInfo::Entry*> m_entries; ///< The entries - MemoryArena m_arena; ///< Holds the payloads - SerialClasses* m_classes; - SerialFilter* m_filter; ///< Filter to control what is serialized + void beginObject(FourCC typeCode = SerialBinary::kObjectFourCC) + { + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } - Flags m_flags; ///< Flags to control behavior -}; + if (listChunk->m_fourCC != typeCode) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } -// --------------------------------------------------------------------------- -template<typename T> -SerialIndex SerialWriter::addArray(const T* in, Index count) -{ - typedef SerialTypeInfo<T> ElementTypeInfo; - typedef typename ElementTypeInfo::SerialType ElementSerialType; + _chunk = listChunk->getFirstContainedChunk(); + } - if (std::is_same<T, ElementSerialType>::value) + void beginKeyValuePair(FourCC typeCode = SerialBinary::kPairFourCC) { - // If they are the same we can just write out - return addSerialArray(sizeof(T), SLANG_ALIGN_OF(ElementSerialType), in, count); + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } + + if (listChunk->m_fourCC != typeCode) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } + + _chunk = listChunk->getFirstContainedChunk(); } - else - { - // Else we need to convert - List<ElementSerialType> work; - work.setCount(count); - if (getFlags() & Flag::ZeroInitialize) + void beginProperty(FourCC propertyCode) + { + auto listChunk = as<RiffContainer::ListChunk>(_chunk); + if (!listChunk) { - ::memset(work.getBuffer(), 0, sizeof(ElementSerialType) * count); + SLANG_UNEXPECTED("invalid format in RIFF"); } - for (Index i = 0; i < count; ++i) + auto found = listChunk->findContainedList(propertyCode); + if (!found) { - ElementTypeInfo::toSerial(this, &in[i], &work[i]); + SLANG_UNEXPECTED("invalid format in RIFF"); } - return addSerialArray( - sizeof(ElementSerialType), - SLANG_ALIGN_OF(ElementSerialType), - work.getBuffer(), - count); - } -} -/* A SerialFieldType describes the size of field, it's alignment, and contains the -functions that convert between serial and native data */ -struct SerialFieldType -{ - typedef void (*ToSerialFunc)(SerialWriter* writer, const void* src, void* dst); - typedef void (*ToNativeFunc)(SerialReader* reader, const void* src, void* dst); + _chunk = found->getFirstContainedChunk(); + } - size_t serialSizeInBytes; - uint8_t serialAlignment; - ToSerialFunc toSerialFunc; - ToNativeFunc toNativeFunc; -}; + bool hasElements() { return _chunk != nullptr; } -/* Describes a field in a SerialClass. */ -struct SerialField -{ - /// Returns a suitable ptr for use in make. - /// NOTE! Sets to 1 so it's constant and not 0 (and so nullptr) - template<typename T> - static T* getPtr() + bool isNull() { - return (T*)1; + if (_chunk == nullptr) + return true; + if (getTag() == SerialBinary::kNullFourCC) + return true; + return false; } - template<typename T> - static SerialField make(const char* name, T* in); - - const char* name; ///< The name of the field - const SerialFieldType* type; ///< The type of the field - uint32_t nativeOffset; ///< Offset to field from base of type - uint32_t serialOffset; ///< Offset in serial type -}; + bool decodeNull() + { + if (!isNull()) + return false; -typedef uint8_t SerialClassFlags; + if (_chunk != nullptr) + { + _chunk = _chunk->m_next; + } + return true; + } -struct SerialClassFlag -{ - enum Enum : SerialClassFlags + struct WithArray { - DontSerialize = - 0x01, ///< If set the type is not serialized, so can turn into SerialIndex(0) - }; -}; + public: + WithArray(Decoder& decoder) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginArray(); + } -/* SerialClass defines the type (typeKind/subType) and the fields in just this class definition (ie -not it's super class). Also contains a pointer to the super type if there is one */ -struct SerialClass -{ - SerialTypeKind typeKind; ///< The type kind - SerialSubType subType; ///< Subtype - meaning depends on typeKind + WithArray(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginArray(typeCode); + } - uint8_t alignment; ///< Alignment of this type - SerialClassFlags flags; ///< Flags + ~WithArray() { _decoder._chunk = _saved->m_next; } - uint32_t size; ///< Size of the field in bytes + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - Index fieldsCount; - const SerialField* fields; + struct WithObject + { + public: + WithObject(Decoder& decoder) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginObject(); + } - const SerialClass* super; ///< The super class -}; + WithObject(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginObject(typeCode); + } -// An instance could be shared across Sessions, but for simplicity of life time -// here we don't deal with that -class SerialClasses : public RefObject -{ -public: - /// Will add it's own copy into m_classesByType - /// In process will calculate alignment, offset etc for fields - /// NOTE! the super set, *must* be an already added to this SerialClasses - const SerialClass* add(const SerialClass* cls); + ~WithObject() { _decoder._chunk = _saved->m_next; } - const SerialClass* add( - SerialTypeKind kind, - SerialSubType subType, - const SerialField* fields, - Index fieldsCount, - const SerialClass* superCls); + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - /// Add a type which will not serialize - const SerialClass* addUnserialized(SerialTypeKind kind, SerialSubType subType); + struct WithKeyValuePair + { + public: + WithKeyValuePair(Decoder& decoder) + : _decoder(decoder) + { + _saved = decoder._chunk; + decoder.beginKeyValuePair(); + } - /// Returns true if this cls is *owned* by this SerialClasses - bool isOwned(const SerialClass* cls) const; + WithKeyValuePair(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + _decoder.beginKeyValuePair(typeCode); + } - /// Returns true if the SerialClasses structure appears ok - bool isOk() const; + ~WithKeyValuePair() { _decoder._chunk = _saved->m_next; } - /// Get a serial class based on its type/subType - const SerialClass* getSerialClass(SerialTypeKind typeKind, SerialSubType subType) const - { - const auto& classes = m_classesByTypeKind[Index(typeKind)]; - return (subType < classes.getCount()) ? classes[subType] : nullptr; - } + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - /// Ctor - SerialClasses(); + struct WithProperty + { + public: + WithProperty(Decoder& decoder, FourCC typeCode) + : _decoder(decoder) + { + _saved = decoder._chunk; + _decoder.beginProperty(typeCode); + } -protected: - SerialClass* _createSerialClass(const SerialClass* cls); + ~WithProperty() { _decoder._chunk = _saved->m_next; } - MemoryArena m_arena; + private: + RiffContainer::Chunk* _saved; + Decoder& _decoder; + }; - List<const SerialClass*> m_classesByTypeKind[Index(SerialTypeKind::CountOf)]; -}; -// !!!!!!!!!!!!!!!!!!!!! SerialGetFieldType<T> !!!!!!!!!!!!!!!!!!!!!!!!!!! -// Getting the type info, let's use a static variable to hold the state to keep simple + RiffContainer::Chunk* getCursor() { return _chunk; } + void setCursor(RiffContainer::Chunk* chunk) { _chunk = chunk; } -template<typename T> -struct SerialGetFieldType -{ - static const SerialFieldType* getFieldType() - { - typedef SerialTypeInfo<T> Info; - static const SerialFieldType type = { - sizeof(typename Info::SerialType), - uint8_t(Info::SerialAlignment), - &Info::toSerial, - &Info::toNative}; - return &type; - } +private: + RiffContainer::Chunk* _chunk = nullptr; }; -// !!!!!!!!!!!!!!!!!!!!! SerialGetFieldType<T> !!!!!!!!!!!!!!!!!!!!!!!!!!! - -template<typename T> -/* static */ SerialField SerialField::make(const char* name, T* in) -{ - uint8_t* ptr = reinterpret_cast<uint8_t*>(in); - - SerialField field; - field.name = name; - field.type = SerialGetFieldType<T>::getFieldType(); - // This only works because we in is an offset from 1 - field.nativeOffset = uint32_t(size_t(ptr) - 1); - field.serialOffset = 0; - return field; -} - -// !!!!!!!!!!!!!!!!!!!!! Convenience functions !!!!!!!!!!!!!!!!!!!!!!!!!!! - -template<typename NATIVE_TYPE, typename SERIAL_TYPE> -SLANG_FORCE_INLINE void toSerialValue( - SerialWriter* writer, - const NATIVE_TYPE& src, - SERIAL_TYPE& dst) -{ - SerialTypeInfo<NATIVE_TYPE>::toSerial(writer, &src, &dst); -} - -template<typename SERIAL_TYPE, typename NATIVE_TYPE> -SLANG_FORCE_INLINE void toNativeValue( - SerialReader* reader, - const SERIAL_TYPE& src, - NATIVE_TYPE& dst) -{ - SerialTypeInfo<NATIVE_TYPE>::toNative(reader, &src, &dst); -} } // namespace Slang diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index efb5814e6..e45311abc 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -9,6 +9,41 @@ namespace Slang { +bool SyntaxClassBase::isSubClassOf(SyntaxClassBase const& other) const +{ + auto selfInfo = getInfo(); + auto otherInfo = other.getInfo(); + if (!selfInfo || !otherInfo) + return false; + return unsigned((int)selfInfo->firstTag - (int)otherInfo->firstTag) < + unsigned(otherInfo->tagCount); +} + +UnownedTerminatedStringSlice SyntaxClassBase::getName() const +{ + return _info ? UnownedTerminatedStringSlice(_info->name) : UnownedTerminatedStringSlice(); +} + +void* SyntaxClassBase::createInstanceImpl(ASTBuilder* astBuilder) const +{ + if (!_info) + return nullptr; + if (!_info->createFunc) + return nullptr; + + return _info->createFunc(astBuilder); +} + +void SyntaxClassBase::destructInstanceImpl(void* instance) const +{ + if (!_info) + return; + if (!_info->destructFunc) + return; + + return _info->destructFunc(instance); +} + /* static */ const TypeExp TypeExp::empty; @@ -227,13 +262,13 @@ void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType) sb << "discard"; break; default: - if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Expr)) + if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Expr>()) sb << "expression"; - else if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Stmt)) + else if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Stmt>()) sb << "statement"; - else if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Decl)) + else if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Decl>()) sb << "decl"; - else if (ASTClassInfo::getInfo(nodeType)->isDerivedFrom((uint32_t)ASTNodeType::Val)) + else if (SyntaxClass<NodeBase>(nodeType).isSubClassOf<Val>()) sb << "val"; else sb << "node"; @@ -326,7 +361,7 @@ Decl* const* adjustFilterCursorImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo)) + if (decl->getClass().isSubClassOf(clsInfo)) { return ptr; } @@ -338,7 +373,7 @@ Decl* const* adjustFilterCursorImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && !decl->hasModifier<HLSLStaticModifier>()) { return ptr; @@ -351,7 +386,7 @@ Decl* const* adjustFilterCursorImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && decl->hasModifier<HLSLStaticModifier>()) { return ptr; @@ -378,7 +413,7 @@ Decl* const* getFilterCursorByIndexImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo)) + if (decl->getClass().isSubClassOf(clsInfo)) { if (index <= 0) { @@ -394,7 +429,7 @@ Decl* const* getFilterCursorByIndexImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && !decl->hasModifier<HLSLStaticModifier>()) { if (index <= 0) @@ -411,7 +446,7 @@ Decl* const* getFilterCursorByIndexImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - if (decl->getClassInfo().isSubClassOf(clsInfo) && + if (decl->getClass().isSubClassOf(clsInfo) && decl->hasModifier<HLSLStaticModifier>()) { if (index <= 0) @@ -428,7 +463,7 @@ Decl* const* getFilterCursorByIndexImpl( } Index getFilterCountImpl( - const ReflectClassInfo& clsInfo, + const SyntaxClassBase& clsInfo, MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) @@ -442,7 +477,7 @@ Index getFilterCountImpl( for (; ptr != end; ptr++) { Decl* decl = *ptr; - count += Index(decl->getClassInfo().isSubClassOf(clsInfo)); + count += Index(decl->getClass().isSubClassOf(clsInfo)); } break; } @@ -452,7 +487,7 @@ Index getFilterCountImpl( { Decl* decl = *ptr; count += Index( - decl->getClassInfo().isSubClassOf(clsInfo) && + decl->getClass().isSubClassOf(clsInfo) && !decl->hasModifier<HLSLStaticModifier>()); } break; @@ -463,7 +498,7 @@ Index getFilterCountImpl( { Decl* decl = *ptr; count += Index( - decl->getClassInfo().isSubClassOf(clsInfo) && + decl->getClass().isSubClassOf(clsInfo) && decl->hasModifier<HLSLStaticModifier>()); } break; @@ -701,7 +736,7 @@ Type* DeclRefType::create(ASTBuilder* astBuilder, DeclRef<Decl> declRef) } else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>()) { - if (magicMod->magicNodeType == ASTNodeType(-1)) + if (!magicMod->magicNodeType) { SLANG_UNEXPECTED("unhandled type"); } diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index accc490f2..8d78872a6 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -117,7 +117,7 @@ inline void foreachDirectOrExtensionMemberOfType( _foreachDirectOrExtensionMemberOfType( semantics, declRef, - getClass<T>(), + getSyntaxClass<T>(), &Helper::callback, &helper); } diff --git a/source/slang/slang-value-reflect.cpp b/source/slang/slang-value-reflect.cpp deleted file mode 100644 index aa2b6fbe2..000000000 --- a/source/slang/slang-value-reflect.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "slang-value-reflect.h" - -#include "slang-generated-value-macro.h" -#include "slang-generated-value.h" -#include "slang.h" - -namespace Slang -{ - - -} // namespace Slang diff --git a/source/slang/slang-value-reflect.h b/source/slang/slang-value-reflect.h deleted file mode 100644 index 1110ad225..000000000 --- a/source/slang/slang-value-reflect.h +++ /dev/null @@ -1,12 +0,0 @@ -// slang-value-reflect.h - -#ifndef SLANG_VALUE_REFLECT_H -#define SLANG_VALUE_REFLECT_H - -#include "slang-generated-value-macro.h" -#include "slang-generated-value.h" - -// Create the functions to automatically convert between value types - - -#endif // SLANG_VALUE_REFLECT_H diff --git a/source/slang/slang-visitor.h b/source/slang/slang-visitor.h index 580029289..180956bde 100644 --- a/source/slang/slang-visitor.h +++ b/source/slang/slang-visitor.h @@ -5,235 +5,233 @@ // This file defines the basic "Visitor" pattern for doing dispatch // over the various categories of syntax node. -#include "slang-generated-ast-macro.h" +#include "slang-ast-dispatch.h" +#include "slang-ast-forward-declarations.h" #include "slang-syntax.h" namespace Slang { -// Macros to generate from ast-generated-macro file the vistors - -// Only runs 'param' macro if the marker is NONE (ie not ABSTRACT here) -#define SLANG_CLASS_ONLY_ABSTRACT_AST(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) -#define SLANG_CLASS_ONLY_AST(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - param(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) - -#define SLANG_CLASS_ONLY(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - SLANG_CLASS_ONLY_##MARKER(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) - -// Dispatch decl -#define SLANG_VISITOR_DISPATCH_DECL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - // Dispatch -#define SLANG_VISITOR_DISPATCH_RESULT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { \ - *(Result*)extra = ((Derived*)this)->visit##NAME(obj); \ - } - -#define SLANG_VISITOR_DISPATCH_VOID_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { \ - ((Derived*)this)->visit##NAME(obj); \ +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_DISPATCH_RESULT_IMPL(baseType) +% for _,T in ipairs(baseType.subclasses) do +% if not T.isAbstract then + Result _dispatchImpl($T* obj) + { + return ((Derived*)this)->visit$T(obj); } +% end +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END // Visitor with and without result -#define SLANG_VISITOR_RESULT_VISIT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - Result visit##NAME(NAME* obj) \ - { \ - return ((Derived*)this)->visit##SUPER(obj); \ - } - -#define SLANG_VISITOR_VOID_VISIT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - void visit##NAME(NAME* obj) \ - { \ - ((Derived*)this)->visit##SUPER(obj); \ +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_VISIT_RESULT_IMPL(baseType) +% for _,T in ipairs(baseType.subclasses) do + Result visit$T($T* obj) + { + return ((Derived*)this)->visit$(T.directSuperClass)(obj); } +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END // Args -#define SLANG_VISITOR_DISPATCH_ARG_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - virtual void dispatch_##NAME(NAME* obj, void* arg) override \ - { \ - ((Derived*)this)->visit##NAME(obj, *(Arg*)arg); \ - } +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_DISPATCH_ARG_IMPL(baseType) +% for _, T in ipairs(baseType.subclasses) do +% if not T.isAbstract then +virtual void _dispatchImpl($T* obj, Arg const& arg) +{ + ((Derived*)this)->visit$T(obj, arg); +} +% end +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 2 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END + +#if 0 // FIDDLE TEMPLATE: +%function SLANG_VISITOR_VISIT_ARG_IMPL(baseType) +% for _, T in ipairs(baseType.subclasses) do +void visit$T($T* obj, Arg const& arg) +{ + ((Derived*)this)->visit$(T.directSuperClass)(obj, arg); +} +% end +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 3 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END -#define SLANG_VISITOR_VOID_VISIT_ARG_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - void visit##NAME(NAME* obj, Arg const& arg) \ - { \ - ((Derived*)this)->visit##SUPER(obj, arg); \ - } // // type Visitors // -struct ITypeVisitor -{ - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - // Suppress VS2017 Unreachable code warning #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 4702) #endif -template<typename Derived, typename Result = void, typename Base = ITypeVisitor> -struct TypeVisitor : Base +template<typename Derived, typename Result = void> +struct TypeVisitor { Result dispatch(Type* type) { - Result result; - type->accept(this, &result); - return result; + return ASTNodeDispatcher<Type, Result>::dispatch( + type, + [&](auto obj) { return _dispatchImpl(obj); }); } Result dispatchType(Type* type) { - Result result; - type->accept(this, &result); - return result; + return ASTNodeDispatcher<Type, Result>::dispatch( + type, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Type(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived, typename Base> -struct TypeVisitor<Derived, void, Base> : Base -{ - void dispatch(Type* type) { type->accept(this, 0); } - - void dispatchType(Type* type) { type->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Type(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Type) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Type) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 4 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; -template<typename Derived, typename Arg, typename Base = ITypeVisitor> -struct TypeVisitorWithArg : Base +template<typename Derived, typename Arg> +struct TypeVisitorWithArg { - void dispatch(Type* type, Arg const& arg) { type->accept(this, (void*)&arg); } + void dispatch(Type* type, Arg const& arg) + { + ASTNodeDispatcher<Type, void>::dispatch(type, [&](auto obj) { _dispatchImpl(obj, arg); }); + } - SLANG_CHILDREN_ASTNode_Type(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_ARG_IMPL) - SLANG_CHILDREN_ASTNode_Type(SLANG_VISITOR_VOID_VISIT_ARG_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_ARG_IMPL(Slang.Type) + % SLANG_VISITOR_VISIT_ARG_IMPL(Slang.Type) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 5 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Expression Visitors // -struct IExprVisitor -{ - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct ExprVisitor : IExprVisitor +struct ExprVisitor { Result dispatch(Expr* expr) { - Result result; - expr->accept(this, &result); - return result; + return ASTNodeDispatcher<Expr, Result>::dispatch( + expr, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Expr(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct ExprVisitor<Derived, void> : IExprVisitor -{ - void dispatch(Expr* expr) { expr->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Expr(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Expr) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Expr) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 6 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; template<typename Derived, typename Arg> -struct ExprVisitorWithArg : IExprVisitor +struct ExprVisitorWithArg { - void dispatch(Expr* obj, Arg const& arg) { obj->accept(this, (void*)&arg); } + void dispatch(Expr* expr, Arg const& arg) + { + ASTNodeDispatcher<Expr, void>::dispatch(expr, [&](auto obj) { _dispatchImpl(obj, arg); }); + } - SLANG_CHILDREN_ASTNode_Expr(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_ARG_IMPL) - SLANG_CHILDREN_ASTNode_Expr(SLANG_VISITOR_VOID_VISIT_ARG_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_ARG_IMPL(Slang.Expr) + % SLANG_VISITOR_VISIT_ARG_IMPL(Slang.Expr) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 7 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Statement Visitors // -struct IStmtVisitor -{ - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct StmtVisitor : IStmtVisitor +struct StmtVisitor { Result dispatch(Stmt* stmt) { - Result result; - stmt->accept(this, &result); - return result; + return ASTNodeDispatcher<Stmt, Result>::dispatch( + stmt, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Stmt(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct StmtVisitor<Derived, void> : IStmtVisitor -{ - void dispatch(Stmt* stmt) { stmt->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Stmt(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Stmt) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Stmt) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 8 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Declaration Visitors // -struct IDeclVisitor -{ - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct DeclVisitor : IDeclVisitor +struct DeclVisitor { Result dispatch(DeclBase* decl) { - Result result; - decl->accept(this, &result); - return result; + return ASTNodeDispatcher<DeclBase, Result>::dispatch( + decl, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct DeclVisitor<Derived, void> : IDeclVisitor -{ - void dispatch(DeclBase* decl) { decl->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.DeclBase) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.DeclBase) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 9 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; template<typename Derived, typename Arg> -struct DeclVisitorWithArg : IDeclVisitor +struct DeclVisitorWithArg { - void dispatch(DeclBase* obj, Arg const& arg) { obj->accept(this, (void*)&arg); } + void dispatch(DeclBase* decl, Arg const& arg) + { + ASTNodeDispatcher<Expr, void>::dispatch(decl, [&](auto obj) { _dispatchImpl(obj, arg); }); + } - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_ARG_IMPL) - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_VISITOR_VOID_VISIT_ARG_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_ARG_IMPL(Slang.DeclBase) + % SLANG_VISITOR_VISIT_ARG_IMPL(Slang.DeclBase) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 10 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; @@ -241,64 +239,46 @@ struct DeclVisitorWithArg : IDeclVisitor // Modifier Visitors // -struct IModifierVisitor -{ - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void> -struct ModifierVisitor : IModifierVisitor +struct ModifierVisitor { Result dispatch(Modifier* modifier) { - Result result; - modifier->accept(this, &result); - return result; + return ASTNodeDispatcher<Modifier, Result>::dispatch( + modifier, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Modifier(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct ModifierVisitor<Derived, void> : IModifierVisitor -{ - void dispatch(Modifier* modifier) { modifier->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Modifier(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Modifier) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Modifier) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 11 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // // Val Visitors // -struct IValVisitor : ITypeVisitor -{ - SLANG_CHILDREN_ASTNode_Val(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_DECL) -}; - template<typename Derived, typename Result = void, typename TypeResult = void> -struct ValVisitor : TypeVisitor<Derived, TypeResult, IValVisitor> +struct ValVisitor : TypeVisitor<Derived, TypeResult> { Result dispatch(Val* val) { - Result result; - val->accept(this, &result); - return result; + return ASTNodeDispatcher<Val, Result>::dispatch( + val, + [&](auto obj) { return _dispatchImpl(obj); }); } - SLANG_CHILDREN_ASTNode_Val(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_RESULT_IMPL) - SLANG_CHILDREN_ASTNode_Val(SLANG_VISITOR_RESULT_VISIT_IMPL, _) -}; - -template<typename Derived> -struct ValVisitor<Derived, void, void> : TypeVisitor<Derived, void, IValVisitor> -{ - void dispatch(Val* val) { val->accept(this, 0); } - - SLANG_CHILDREN_ASTNode_Val(SLANG_CLASS_ONLY, SLANG_VISITOR_DISPATCH_VOID_IMPL) - SLANG_CHILDREN_ASTNode_Val(SLANG_VISITOR_VOID_VISIT_IMPL, _) +#if 0 // FIDDLE TEMPLATE: + % SLANG_VISITOR_DISPATCH_RESULT_IMPL(Slang.Val) + % SLANG_VISITOR_VISIT_RESULT_IMPL(Slang.Val) +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 12 +#include "slang-visitor.h.fiddle" +#endif // FIDDLE END }; // Re-activate VS2017 warning settings diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 351ab6f06..99457647d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -570,64 +570,110 @@ SlangResult Session::saveCoreModule(SlangArchiveType archiveType, ISlangBlob** o } SlangResult Session::saveBuiltinModule( - slang::BuiltinModuleName builtinModuleName, + slang::BuiltinModuleName moduleTag, SlangArchiveType archiveType, ISlangBlob** outBlob) { + // If no builtin modules have been loaded, then there is + // nothing to save, and we fail immediately. + // if (m_builtinLinkage->mapNameToLoadedModules.getCount() == 0) { - // There is no standard lib loaded return SLANG_FAIL; } - BuiltinModuleInfo builtinModuleInfo = getBuiltinModuleInfo(builtinModuleName); - - // Make a file system to read it from - ComPtr<ISlangMutableFileSystem> fileSystem; - SLANG_RETURN_ON_FAIL(createArchiveFileSystem(archiveType, fileSystem)); - - // Must have archiveFileSystem interface - auto archiveFileSystem = as<IArchiveFileSystem>(fileSystem); - if (!archiveFileSystem) - { - return SLANG_FAIL; - } + // The module will need to be looked up by its name, and + // will also be serialized out to a path with a matching name. + // + BuiltinModuleInfo moduleInfo = getBuiltinModuleInfo(moduleTag); + const char* moduleName = moduleInfo.name; + // If we cannot find a loaded module in the linkage with + // the appropriate name, then for some reason it hasn't + // been loaded, and we fail. + // RefPtr<Module> module; m_builtinLinkage->mapNameToLoadedModules.tryGetValue( - getNameObj(UnownedStringSlice(builtinModuleInfo.name)), + getNameObj(UnownedStringSlice(moduleName)), module); if (!module) { return SLANG_FAIL; } + // AST serialization needs access to an AST builder, so + // we establish a current builder for the duration of + // the serialization process. + // SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); - // Set up options - SerialContainerUtil::WriteOptions options; + // The serialized module will be represented as a logical + // file in an archive, so we create a logical file system + // to represent that archive. + // + ComPtr<ISlangMutableFileSystem> fileSystem; + SLANG_RETURN_ON_FAIL(createArchiveFileSystem(archiveType, fileSystem)); + // + // The created file system must support the `IArchiveFileSystem` + // interface (since we created it with `createArchiveFileSystem`). + // + auto archiveFileSystem = as<IArchiveFileSystem>(fileSystem); + if (!archiveFileSystem) + { + return SLANG_FAIL; + } - // Save with SourceLocation information - options.optionFlags |= SerialOptionFlag::SourceLocation; + // The output file name that we'll write to in that file system + // is just the builtin module name with a `.slang-module` suffix. + // + StringBuilder moduleFileName; + moduleFileName << moduleName << ".slang-module"; - // TODO(JS): Should this be the Session::getBuiltinSourceManager()? + // The module serialization step has some options that we need + // to configure appropriately. + // + SerialContainerUtil::WriteOptions options; + // + // We want builtin modules to be saved with their source location + // information. + // + options.optionFlags |= SerialOptionFlag::SourceLocation; + // + // And in order to work with source locations, the serialization + // process will also need access to the source manager that + // can translate locations into their humane format. + // options.sourceManager = m_builtinLinkage->getSourceManager(); - StringBuilder builder; - builder << builtinModuleInfo.name << ".slang-module"; - + // At this point we can finally delegate down to the next level, + // which handles the serialization of a Slang module into a + // byte stream. + // OwnedMemoryStream stream(FileAccess::Write); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(module, options, &stream)); - auto contents = stream.getContents(); - // Write into the file system - SLANG_RETURN_ON_FAIL( - fileSystem->saveFile(builder.getBuffer(), contents.getBuffer(), contents.getCount())); + // Once the stream that represents the module has been written, we can + // write it to a file in the logical file system. + // + // TODO(tfoley): why can't the file system let us open the file for output? + // + SLANG_RETURN_ON_FAIL(fileSystem->saveFile( + moduleFileName.getBuffer(), + contents.getBuffer(), + contents.getCount())); + + // And finally, we can ask the archive file system to serialize itself + // out as a blob of bytes, which yields the final serialized representation + // of the module. + // + SLANG_RETURN_ON_FAIL(archiveFileSystem->storeArchive( + // The `true` here indicates that the blob that gets created should own + // its content, independent from the file system object itself; otherwise + // the file system might return a blob that shares storage with itself. + true, + outBlob)); - // Now need to convert into a blob - SLANG_RETURN_ON_FAIL(archiveFileSystem->storeArchive(true, outBlob)); return SLANG_OK; } @@ -654,74 +700,98 @@ SlangResult Session::_readBuiltinModule( SLANG_RETURN_ON_FAIL(RiffUtil::read(&stream, riffContainer)); } - // Load up the module + Linkage* linkage = getBuiltinLinkage(); + SourceManager* sourceManager = getBuiltinSourceManager(); + NamePool* sessionNamePool = &namePool; - SerialContainerData containerData; + auto moduleChunk = ModuleChunkRef::find(&riffContainer); + if (!moduleChunk) + return SLANG_FAIL; - Linkage* linkage = getBuiltinLinkage(); + SHA1::Digest moduleDigest = moduleChunk.getDigest(); - SourceManager* sourceManger = getBuiltinSourceManager(); + auto irChunk = moduleChunk.findIR(); + if (!irChunk) + return SLANG_FAIL; - NamePool* sessionNamePool = &namePool; - NamePool* linkageNamePool = linkage->getNamePool(); + auto astChunk = moduleChunk.findAST(); + if (!astChunk) + return SLANG_FAIL; - SerialContainerUtil::ReadOptions options; - options.namePool = linkageNamePool; - options.session = this; - options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); - options.astBuilder = linkage->getASTBuilder(); - options.sourceManager = sourceManger; - options.linkage = linkage; + // Source location information is stored as a distinct + // chunk from the IR and AST, so we need to search for + // that chunk and then set up the information for use + // in the IR and AST deserialization (if we find anything). + // + RefPtr<SerialSourceLocReader> sourceLocReader; + if (auto debugChunk = findDebugChunk(moduleChunk.ptr())) + { + SLANG_RETURN_ON_FAIL( + readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); + } - // Hmm - don't have a suitable sink yet, so attempt to just not have one - options.sink = nullptr; + // At this point we create the `Module` object that will + // represent the builtin module we are reading, although + // it is still possible that deserialization will fail + // at one of the following steps. + // + auto astBuilder = linkage->getASTBuilder(); + RefPtr<Module> module(new Module(linkage, astBuilder)); + module->setName(moduleName); + module->setDigest(moduleDigest); - SLANG_RETURN_ON_FAIL( - SerialContainerUtil::read(&riffContainer, options, nullptr, containerData)); - for (auto& srcModule : containerData.modules) + // Next, we set about deserializing the AST representation + // of the module. + // + auto moduleDecl = readSerializedModuleAST( + linkage, + astBuilder, + nullptr, // no sink + astChunk, + sourceLocReader, + SourceLoc()); + if (!moduleDecl) { - RefPtr<Module> module(new Module(linkage, srcModule.astBuilder)); - module->setName(moduleName); - module->setDigest(srcModule.digest); - - ModuleDecl* moduleDecl = as<ModuleDecl>(srcModule.astRootNode); - // Set the module back reference on the decl - moduleDecl->module = module; + return SLANG_FAIL; + } + moduleDecl->module = module; + module->setModuleDecl(moduleDecl); - if (moduleDecl) - { - if (isFromCoreModule(moduleDecl)) - { - registerBuiltinDecls(this, moduleDecl); - } + if (isFromCoreModule(moduleDecl)) + { + registerBuiltinDecls(this, moduleDecl); + } - module->setModuleDecl(moduleDecl); - } + // After the AST module has been read in, we next look + // to deserialize the IR module. + // + RefPtr<IRModule> irModule; + SLANG_RETURN_ON_FAIL(decodeModuleIR(irModule, irChunk, this, sourceLocReader)); - srcModule.irModule->setName(module->getNameObj()); - module->setIRModule(srcModule.irModule); + irModule->setName(module->getNameObj()); + module->setIRModule(irModule); - // Put in the loaded module map - linkage->mapNameToLoadedModules.add(sessionNamePool->getName(moduleName), module); + // Put in the loaded module map + linkage->mapNameToLoadedModules.add(sessionNamePool->getName(moduleName), module); - // Add the resulting code to the appropriate scope - if (!scope->containerDecl) - { - // We are the first chunk of code to be loaded for this scope - scope->containerDecl = moduleDecl; - } - else - { - // We need to create a new scope to link into the whole thing - auto subScope = linkage->getASTBuilder()->create<Scope>(); - subScope->containerDecl = moduleDecl; - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - } - outModule = module.get(); + // Add the resulting code to the appropriate scope + if (!scope->containerDecl) + { + // We are the first chunk of code to be loaded for this scope + scope->containerDecl = moduleDecl; } + else + { + // We need to create a new scope to link into the whole thing + auto subScope = linkage->getASTBuilder()->create<Scope>(); + subScope->containerDecl = moduleDecl; + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; + } + + outModule = module.get(); return SLANG_OK; } @@ -1526,9 +1596,10 @@ slang::IModule* Linkage::loadModuleFromBlob( pathInfo = PathInfo::makeNormal(pathStr, cannonicalPath); } } - auto module = loadModule(name, pathInfo, source, SourceLoc(), &sink, nullptr, blobType); + RefPtr<Module> module = + loadModuleImpl(name, pathInfo, source, SourceLoc(), &sink, nullptr, blobType); sink.getBlobIfNeeded(outDiagnostics); - return asExternal(module); + return asExternal(module.detach()); } catch (const AbortCompilationException& e) { @@ -4057,101 +4128,157 @@ void Linkage::loadParsedModule( loadedModulesList.add(loadedModule); } -RefPtr<Module> Linkage::loadDeserializedModule( - Name* name, - const PathInfo& filePathInfo, - SerialContainerData::Module& moduleEntry, +RefPtr<Module> Linkage::findOrLoadSerializedModuleForModuleLibrary( + ModuleChunkRef moduleChunk, DiagnosticSink* sink) { - SLANG_AST_BUILDER_RAII(m_astBuilder); RefPtr<Module> resultModule; - if (mapNameToLoadedModules.tryGetValue(name, resultModule)) - return resultModule; - if (mapPathToLoadedModule.tryGetValue(filePathInfo.getMostUniqueIdentity(), resultModule)) + + // We will attempt things in a few different steps, trying to + // decode as little of the serialized module as necessary at + // each step, so that we don't waste time on the heavyweight + // stuff when we didn't need to. + // + // The first step is to simply decode the module name, and + // see if we have a already loaded a matching module. + + auto moduleName = getNamePool()->getName(moduleChunk.getName()); + if (mapNameToLoadedModules.tryGetValue(moduleName, resultModule)) return resultModule; - resultModule = new Module(this, m_astBuilder); - prepareDeserializedModule(moduleEntry, filePathInfo, resultModule, sink); + // It is possible that the module has been loaded, but somehow + // under a different name, so next we decode the list of file + // paths that the module depends on, and then rely on the assumption + // that the first of those paths represents the file for the module + // itself to detect if we've already loaded a module from that + // path. + // + // Note: While this is a distasteful assumption to make, it is + // one that gets made in several parts of the compiler codebase + // already. It isn't something that can be fixed in just one + // place at this point. + + auto fileDependenciesChunk = moduleChunk.getFileDependencies(); + auto firstFileDependencyChunk = fileDependenciesChunk.getFirst(); + if (!firstFileDependencyChunk) + return nullptr; + + auto modulePathInfo = PathInfo::makePath(firstFileDependencyChunk.getValue()); + if (mapPathToLoadedModule.tryGetValue(modulePathInfo.getMostUniqueIdentity(), resultModule)) + return resultModule; - loadedModulesList.add(resultModule); - mapPathToLoadedModule.add(filePathInfo.getMostUniqueIdentity(), resultModule); - mapNameToLoadedModules.add(name, resultModule); - return resultModule; + // If we failed to find a previously-loaded module, then we + // will go ahead and load the module from the serialized form. + // + PathInfo filePathInfo; + return loadSerializedModule(moduleName, modulePathInfo, moduleChunk, SourceLoc(), sink); } -RefPtr<Module> Linkage::loadModuleFromIRBlobImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules) +RefPtr<Module> Linkage::loadSerializedModule( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, + SourceLoc const& requestingLoc, + DiagnosticSink* sink) { - SLANG_AST_BUILDER_RAII(m_astBuilder); + auto astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); - RefPtr<Module> resultModule = new Module(this, getASTBuilder()); - resultModule->setName(name); - ModuleBeingImportedRAII moduleBeingImported(this, resultModule, name, loc); + auto module = RefPtr(new Module(this, astBuilder)); + module->setName(moduleName); - String mostUniqueIdentity = filePathInfo.getMostUniqueIdentity(); - SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + // Just as if we were processing an `import` declaration in + // source code, we will track the fact that this serialized + // modlue is (effectively) being imported, so that we can + // diagnose anything troublesome, like an attempt at a + // recursive import. + // + ModuleBeingImportedRAII moduleBeingImported(this, module, moduleName, requestingLoc); - RiffContainer container; - MemoryStreamBase readStream( - FileAccess::Read, - fileContentsBlob->getBufferPointer(), - fileContentsBlob->getBufferSize()); - SLANG_RETURN_NULL_ON_FAIL(RiffUtil::read(&readStream, container)); + // We will register the module in our data structures to + // track loaded modules, and then remove it in the case + // where there is some kind of failure. + // + String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); - if (m_optionSet.getBoolOption(CompilerOptionName::UseUpToDateBinaryModule)) + mapPathToLoadedModule.add(mostUniqueIdentity, module); + mapNameToLoadedModules.add(moduleName, module); + try { - if (!isBinaryModuleUpToDate(filePathInfo.foundPath, &container)) + if (SLANG_FAILED( + loadSerializedModuleContents(module, moduleFilePathInfo, moduleChunk, sink))) + { + mapPathToLoadedModule.remove(mostUniqueIdentity); + mapNameToLoadedModules.remove(moduleName); return nullptr; - } + } - mapPathToLoadedModule.add(mostUniqueIdentity, resultModule); - mapNameToLoadedModules.add(name, resultModule); - - SerialContainerUtil::ReadOptions readOptions; - readOptions.linkage = this; - readOptions.astBuilder = getASTBuilder(); - readOptions.session = getSessionImpl(); - readOptions.sharedASTBuilder = getASTBuilder()->getSharedASTBuilder(); - readOptions.sink = sink; - readOptions.sourceManager = getSourceManager(); - readOptions.namePool = getNamePool(); - readOptions.modulePath = filePathInfo.foundPath; - SerialContainerData containerData; - if (SLANG_FAILED(SerialContainerUtil::read( - &container, - readOptions, - additionalLoadedModules, - containerData)) || - containerData.modules.getCount() != 1) + loadedModulesList.add(module); + return module; + } + catch (...) { mapPathToLoadedModule.remove(mostUniqueIdentity); - mapNameToLoadedModules.remove(name); - return nullptr; + mapNameToLoadedModules.remove(moduleName); + throw; } - auto moduleEntry = containerData.modules.getFirst(); +} - prepareDeserializedModule(moduleEntry, filePathInfo, resultModule, sink); +RefPtr<Module> Linkage::loadBinaryModuleImpl( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ISlangBlob* moduleFileContents, + SourceLoc const& requestingLoc, + DiagnosticSink* sink) +{ + auto astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); - loadedModulesList.add(resultModule); - resultModule->setPathInfo(filePathInfo); - resultModule->getIRModule()->setName(resultModule->getNameObj()); + // We start by reading the content of the file into + // an in-memory RIFF container. + // + // TODO(tfoley): this is an unnecessary copy step, since + // we can simply use the contents of the blob directly + // and navigate it in-memory. + // + RiffContainer riffContainer; + { + MemoryStreamBase readStream( + FileAccess::Read, + moduleFileContents->getBufferPointer(), + moduleFileContents->getBufferSize()); + SLANG_RETURN_NULL_ON_FAIL(RiffUtil::read(&readStream, riffContainer)); + } - return resultModule; -} + auto moduleChunkRef = ModuleChunkRef::find(&riffContainer); + if (!moduleChunkRef) + { + return nullptr; + } -Module* Linkage::loadModule(String const& name) -{ - // TODO: We either need to have a diagnostics sink - // get passed into this operation, or associate - // one with the linkage. + // Next, we attempt to check if the binary module is up to + // date with the compilation options in use as well as + // the contents of all the files its compilation depended + // on (as determined by its hash). // - DiagnosticSink* sink = nullptr; - return findOrImportModule(getNamePool()->getName(name), SourceLoc(), sink); + String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + if (m_optionSet.getBoolOption(CompilerOptionName::UseUpToDateBinaryModule)) + { + if (!isBinaryModuleUpToDate(moduleFilePathInfo.foundPath, moduleChunkRef)) + { + return nullptr; + } + } + + // If everything seems reasonable, then we will go ahead and load + // the module more completely from that serialized representation. + // + RefPtr<Module> module = + loadSerializedModule(moduleName, moduleFilePathInfo, moduleChunkRef, requestingLoc, sink); + + return module; } void Linkage::_diagnoseErrorInImportedModule(DiagnosticSink* sink) @@ -4166,24 +4293,43 @@ void Linkage::_diagnoseErrorInImportedModule(DiagnosticSink* sink) } } -RefPtr<Module> Linkage::loadModule( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* sourceBlob, - SourceLoc const& srcLoc, +RefPtr<Module> Linkage::loadModuleImpl( + Name* moduleName, + const PathInfo& modulePathInfo, + ISlangBlob* moduleBlob, + SourceLoc const& requestingLoc, DiagnosticSink* sink, const LoadedModuleDictionary* additionalLoadedModules, ModuleBlobType blobType) { - if (blobType == ModuleBlobType::IR) - return loadModuleFromIRBlobImpl( - name, - filePathInfo, - sourceBlob, - srcLoc, + switch (blobType) + { + case ModuleBlobType::IR: + return loadBinaryModuleImpl(moduleName, modulePathInfo, moduleBlob, requestingLoc, sink); + + case ModuleBlobType::Source: + return loadSourceModuleImpl( + moduleName, + modulePathInfo, + moduleBlob, + requestingLoc, sink, additionalLoadedModules); + default: + SLANG_UNEXPECTED("unknown module blob type"); + UNREACHABLE_RETURN(nullptr); + } +} + +RefPtr<Module> Linkage::loadSourceModuleImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* sourceBlob, + SourceLoc const& srcLoc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules) +{ RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, nullptr, sink); frontEndReq->additionalLoadedModules = additionalLoadedModules; @@ -4275,8 +4421,10 @@ RefPtr<Module> Linkage::loadModule( return nullptr; } - if (module) - module->setPathInfo(filePathInfo); + if (!module) + return nullptr; + + module->setPathInfo(filePathInfo); return module; } @@ -4319,126 +4467,263 @@ String getFileNameFromModuleName(Name* name, bool translateUnderScore) } RefPtr<Module> Linkage::findOrImportModule( - Name* name, - SourceLoc const& loc, + Name* moduleName, + SourceLoc const& requestingLoc, DiagnosticSink* sink, const LoadedModuleDictionary* loadedModules) { // Have we already loaded a module matching this name? // - RefPtr<LoadedModule> loadedModule; - if (mapNameToLoadedModules.tryGetValue(name, loadedModule)) + RefPtr<LoadedModule> previouslyLoadedModule; + if (mapNameToLoadedModules.tryGetValue(moduleName, previouslyLoadedModule)) { // If the map shows a null module having been loaded, // then that means there was a prior load attempt, // but it failed, so we won't bother trying again. // - if (!loadedModule) + if (!previouslyLoadedModule) return nullptr; // If state shows us that the module is already being // imported deeper on the call stack, then we've // hit a recursive case, and that is an error. // - if (isBeingImported(loadedModule)) + if (isBeingImported(previouslyLoadedModule)) { // We seem to be in the middle of loading this module - sink->diagnose(loc, Diagnostics::recursiveModuleImport, name); + sink->diagnose(requestingLoc, Diagnostics::recursiveModuleImport, moduleName); return nullptr; } - return loadedModule; + return previouslyLoadedModule; } // If the user is providing an additional list of loaded modules, we find // if the module being imported is in that list. This allows a translation // unit to use previously checked translation units in the same // FrontEndCompileRequest. - Module* previouslyLoadedModule = nullptr; - if (loadedModules && loadedModules->tryGetValue(name, previouslyLoadedModule)) { - return previouslyLoadedModule; + Module* previouslyLoadedLocalModule = nullptr; + if (loadedModules && loadedModules->tryGetValue(moduleName, previouslyLoadedLocalModule)) + { + return previouslyLoadedLocalModule; + } } - if (name == getSessionImpl()->glslModuleName) + // If the name being requested matches the name of a built-in module, + // then we will special-case the process by loading that builtin + // module directly. + // + // TODO: right now this logic is only considering the built-in `glsl` + // module, but it should probably be generalized so that we can more + // easily support having multiple built-in modules rather than just + // putting everything into `core`. + // + if (moduleName == getSessionImpl()->glslModuleName) { // This is a builtin glsl module, just load it from embedded definition. auto glslModule = getSessionImpl()->getBuiltinModule(slang::BuiltinModuleName::GLSL); if (!glslModule) { - sink->diagnose(loc, Diagnostics::glslModuleNotAvailable, name); + // Note: the way this logic is currently written, if the built-in + // `glsl` module fails to load, then we will *not* fall back to + // searching for a user-defined module in a file like `glsl.slang`. + // + // It is unclear if this should be the default behavior or not. + // Should built-in modules be prioritized over user modules? + // Should built-in modules shadow user modules, even when the + // built-in module fails to load, for some reason? + // + sink->diagnose(requestingLoc, Diagnostics::glslModuleNotAvailable, moduleName); } return glslModule; } - // Next, try to find the file of the given name, - // using our ordinary include-handling logic. + // We are going to use a loop to search for a suitable file to + // load the module from, to account for a few key choices: + // + // * We can both load modules from a source `.slang` file, + // or from a binary `.slang-module` file. + // + // * For a variety of reasons, the `import` logic has historically + // translated underscores in a module name into dashes (so that + // `import my_module` will look for `my-module.slang`), and we + // try to support both that convention as well as a convention + // that preserves underscores. + // + // To try to keep this logic as orthogonal as possible, we first + // construct lists of the options we want to iterate over, and + // then do the actual loop later. - IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); + ShortList<ModuleBlobType, 2> typesToTry; + if (isInLanguageServer()) + { + // When in language server, we always prefer to use source module if it is available. + typesToTry.add(ModuleBlobType::Source); + typesToTry.add(ModuleBlobType::IR); + } + else + { + // Look for a precompiled module first, if not exist, load from source. + typesToTry.add(ModuleBlobType::IR); + typesToTry.add(ModuleBlobType::Source); + } - // Get the original path info - PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); - PathInfo filePathInfo; + // We will always search for a file name that directly matches the + // module name as written first, and then search for one with + // underscores replaced by dashes. The latter is the original + // behavior that `import` provided, but it seems safest to prefer + // the exact name spelled in the user's code when there might + // actually be ambiguity. + // + auto defaultSourceFileName = getFileNameFromModuleName(moduleName, false); + auto alternativeSourceFileName = getFileNameFromModuleName(moduleName, true); + String sourceFileNamesToTry[] = {defaultSourceFileName, alternativeSourceFileName}; + // We are going to look for the candidate file using the same + // logic that would be used for a preprocessor `#include`, + // so we set up the necessary state. + // + IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); - // Look for a precompiled module first, if not exist, load from source. - bool shouldCheckBinaryModuleSettings[2] = {true, false}; + // Just like with a `#include`, the search will take into + // account the path to the file where the request to import + // this module came from (e.g. the source file with the + // `import` declaration), if such a path is available. + // + PathInfo requestingPathInfo = + getSourceManager()->getPathInfo(requestingLoc, SourceLocType::Actual); - for (auto checkBinaryModule : shouldCheckBinaryModuleSettings) + for (auto type : typesToTry) { - // When in language server, we always prefer to use source module if it is available. - if (isInLanguageServer()) - checkBinaryModule = !checkBinaryModule; - - // Try without translating `_` to `-` first, if that fails, try translating. - for (int translateUnderScore = 0; translateUnderScore <= 1; translateUnderScore++) + for (auto sourceFileName : sourceFileNamesToTry) { - auto moduleSourceFileName = getFileNameFromModuleName(name, translateUnderScore == 1); + // The `sourceFileName` will have the `.slang` extension, + // so if we are looking for a binary module, we need + // to change the extension we will look for. + // String fileName; - if (checkBinaryModule == 1) - fileName = Path::replaceExt(moduleSourceFileName, "slang-module"); - else - fileName = moduleSourceFileName; + switch (type) + { + case ModuleBlobType::Source: + fileName = sourceFileName; + break; - ComPtr<ISlangBlob> fileContents; + case ModuleBlobType::IR: + fileName = Path::replaceExt(sourceFileName, "slang-module"); + break; + } - // We have to load via the found path - as that is how file was originally loaded + // We now search for a file matching the desired name, + // using the same logic as for a `#include`. + // + // TODO: We might want to consider how to handle the case + // of an `import` with a relative path a little specially, + // since it could in theory be possible for two `.slang` + // files with the same base name to exist in different + // directories in a project, and we'd want file-relative + // `import`s to work for each, without having either one + // be able to "claim" the bare identifier of the base + // name for itself. + // + PathInfo filePathInfo; if (SLANG_FAILED( - includeSystem.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) + includeSystem.findFile(fileName, requestingPathInfo.foundPath, filePathInfo))) { + // If we failed to find the file at this step, we + // will continue the search for our other options. + // continue; } - // Maybe this was loaded previously at a different relative name? + // We will *again* search for a previously loaded module. + // + // It is possible that the same file will have been loaded + // as a module under two different module names. The easiest + // way for this to happen is if there are `import` declarations + // using both the underscore and dash conventions (e.g., both + // `import "my-module.slang"` and `import my_module`). + // + // This case may also arise if one file `import`s a module using + // just an identifier for its name, but another `import`s it + // using a path (e.g., `import "subdir/file.slang"`). + // + // No matter how the situation arises, we only want to have one + // copy of the "same" module loaded at a given time, so we + // will re-use the existing module if we find one here. + // if (mapPathToLoadedModule.tryGetValue( filePathInfo.getMostUniqueIdentity(), - loadedModule)) - return loadedModule; + previouslyLoadedModule)) + { + // TODO: If we find a previously-loaded module at this step, + // then we should probably register that module under the + // given `moduleName` in the map of loaded modules, so + // that subsequent `import`s using the same form will find it. + // + return previouslyLoadedModule; + } - // Try to load it - if (!fileContents && SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) + // Now we try to load the content of the file. + // + // If for some reason we could find a file at the + // given path, but for some reason couldn't *open* + // and *read* it, then we continue the search + // using whatever other candidate file names are left. + // + ComPtr<ISlangBlob> fileContents; + if (SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) { continue; } - // We've found a file that we can load for the given module, so - // go ahead and perform the module-load action - auto resultModule = loadModule( - name, + // If we found a real file and were able to load its contents, + // then we'll go ahead and try to load a module from it, + // whether by compiling it or decoding the binary. + // + auto module = loadModuleImpl( + moduleName, filePathInfo, fileContents, - loc, + requestingLoc, sink, loadedModules, - (checkBinaryModule == 1 ? ModuleBlobType::IR : ModuleBlobType::Source)); - if (resultModule) - return resultModule; + type); + + // If the attempt to load the module from the given path + // was successful, we go ahead and use it, without trying + // out any other options. + // + if (module) + return module; } } - // Error: we cannot find the file. - sink->diagnose(loc, Diagnostics::cannotOpenFile, getFileNameFromModuleName(name, false)); - mapNameToLoadedModules[name] = nullptr; + // If we tried out all of our candidate file names + // and failed with each of them, then we diagnose + // an error based on the original *source* file + // name. + // + // TODO: this should really be an error message + // that clearly states something like "no file + // suitable for module `whatever` was found + // and loaded. + // + // Ideally that error message would include whatever + // of the candidate file names from the loop above + // got furthest along in the process (or just a + // list of the file names that were tried, if + // nothing was even found via the include system). + // + sink->diagnose(requestingLoc, Diagnostics::cannotOpenFile, defaultSourceFileName); + + // If the attempt to import the module failed, then + // we will stick a null pointer into the map of loaded + // modules, so that subsequent attempts to load a module + // with this name will return null without having to + // go through all the above steps yet again. + // + mapNameToLoadedModules[moduleName] = nullptr; return nullptr; } @@ -4454,27 +4739,19 @@ SourceFile* Linkage::loadSourceFile(String pathFrom, String path) } // Check if a serialized module is up-to-date with current compiler options and source files. -bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* container) +bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* riffContainer) { - DiagnosticSink sink; - SerialContainerUtil::ReadOptions readOptions; - readOptions.linkage = this; - readOptions.astBuilder = getASTBuilder(); - readOptions.session = getSessionImpl(); - readOptions.sharedASTBuilder = getASTBuilder()->getSharedASTBuilder(); - readOptions.sink = &sink; - readOptions.sourceManager = getSourceManager(); - readOptions.namePool = getNamePool(); - readOptions.readHeaderOnly = true; - - SerialContainerData containerData; - if (SLANG_FAILED(SerialContainerUtil::read(container, readOptions, nullptr, containerData))) + auto moduleChunk = ModuleChunkRef::find(riffContainer); + if (!moduleChunk) return false; - if (containerData.modules.getCount() != 1) - return false; + return isBinaryModuleUpToDate(fromPath, moduleChunk); +} + +bool Linkage::isBinaryModuleUpToDate(String fromPath, ModuleChunkRef moduleChunk) +{ + SHA1::Digest existingDigest = moduleChunk.getDigest(); - auto& moduleHeader = containerData.modules[0]; DigestBuilder<SHA1> digestBuilder; auto version = String(getBuildTagString()); digestBuilder.append(version); @@ -4482,9 +4759,12 @@ bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* container) // Find the canonical path of the directory containing the module source file. String moduleSrcPath = ""; - if (moduleHeader.dependentFiles.getCount()) + + auto dependencyChunks = moduleChunk.getFileDependencies(); + if (auto firstDependencyChunk = dependencyChunks.getFirst()) { - moduleSrcPath = moduleHeader.dependentFiles.getFirst(); + moduleSrcPath = firstDependencyChunk.getValue(); + IncludeSystem includeSystem( &getSearchDirectories(), getFileSystemExt(), @@ -4497,21 +4777,22 @@ bool Linkage::isBinaryModuleUpToDate(String fromPath, RiffContainer* container) } } - for (auto file : moduleHeader.dependentFiles) + for (auto dependencyChunk : dependencyChunks) { + auto file = dependencyChunk.getValue(); auto sourceFile = loadSourceFile(fromPath, file); if (!sourceFile) { // If we cannot find the source file from `fromPath`, // try again from the module's source file path. - if (moduleHeader.dependentFiles.getCount() != 0) + if (dependencyChunks.getFirst()) sourceFile = loadSourceFile(moduleSrcPath, file); } if (!sourceFile) return false; digestBuilder.append(sourceFile->getDigest()); } - return digestBuilder.finalize() == moduleHeader.digest; + return digestBuilder.finalize() == existingDigest; } SLANG_NO_THROW bool SLANG_MCALL @@ -6243,20 +6524,100 @@ void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) getSourceManager()->setFileSystemExt(m_fileSystemExt); } -void Linkage::prepareDeserializedModule( - SerialContainerData::Module& moduleEntry, - const PathInfo& filePathInfo, +SlangResult Linkage::loadSerializedModuleContents( Module* module, + const PathInfo& moduleFilePathInfo, + ModuleChunkRef moduleChunk, DiagnosticSink* sink) { - module->setIRModule(moduleEntry.irModule); - module->setModuleDecl(as<ModuleDecl>(moduleEntry.astRootNode)); + // At this point we've dealt with basically all of + // the formalities, and we just need to get down + // to the real work of decoding the information + // in the `moduleChunk`. + + auto sourceManager = getSourceManager(); + RefPtr<SerialSourceLocReader> sourceLocReader; + if (auto debugChunk = findDebugChunk(moduleChunk.ptr())) + { + SLANG_RETURN_ON_FAIL( + readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); + } + + auto astChunk = moduleChunk.findAST(); + if (!astChunk) + return SLANG_FAIL; + + auto irChunk = moduleChunk.findIR(); + if (!irChunk) + return SLANG_FAIL; + + auto astBuilder = getASTBuilder(); + auto session = getSessionImpl(); + + // For the purposes of any modules referenced + // by the module we're about to decode, we will + // construct a source location that represents + // the module itself (if possible). + // + // TODO(tfoley): This logic seems like overkill, given + // that many (most? all?) control-flow paths that can + // reach this routine will have already found a `SourceFile` + // to represent the module, as part of even getting the + // `moduleFilePathInfo` to pass in + // + // The approach here is more or less exactly copied + // from what the old `SerialContainerUtil::read` function + // used to do, with the hopes that it will as many tests + // passing as possible. + // + // Down the line somebody should scrutinize all of this + // kind of logic in the compiler codebase, because there + // is something that feels unclean about how paths are being handled. + // + SourceLoc serializedModuleLoc; + { + auto sourceFile = + sourceManager->findSourceFileByPathRecursively(moduleFilePathInfo.foundPath); + if (!sourceFile) + { + sourceFile = sourceManager->createSourceFileWithString(moduleFilePathInfo, String()); + sourceManager->addSourceFile(moduleFilePathInfo.getMostUniqueIdentity(), sourceFile); + } + auto sourceView = + sourceManager->createSourceView(sourceFile, &moduleFilePathInfo, SourceLoc()); + serializedModuleLoc = sourceView->getRange().begin; + } + + auto moduleDecl = readSerializedModuleAST( + this, + astBuilder, + sink, + astChunk, + sourceLocReader, + serializedModuleLoc); + if (!moduleDecl) + return SLANG_FAIL; + module->setModuleDecl(moduleDecl); + + RefPtr<IRModule> irModule; + SLANG_RETURN_ON_FAIL(decodeModuleIR(irModule, irChunk, session, sourceLocReader)); + module->setIRModule(irModule); + + // The handling of file dependencies is complicated, because of + // the way that the encoding logic tried to make all of the + // paths be relative to the primary source file for the module. + // + // We end up needing to undo some amount of that work here. + // + module->clearFileDependency(); - String moduleSourcePath = filePathInfo.foundPath; + String moduleSourcePath = moduleFilePathInfo.foundPath; bool isFirst = true; - for (auto file : moduleEntry.dependentFiles) + for (auto depenencyFileChunk : moduleChunk.getFileDependencies()) { - auto sourceFile = loadSourceFile(filePathInfo.foundPath, file); + auto encodedDependencyFilePath = depenencyFileChunk.getValue(); + + auto sourceFile = loadSourceFile(moduleFilePathInfo.foundPath, encodedDependencyFilePath); if (isFirst) { // The first file is the source for the main module file. @@ -6270,20 +6631,19 @@ void Linkage::prepareDeserializedModule( // it relative to the module source path. if (!sourceFile) { - sourceFile = loadSourceFile(moduleSourcePath, file); + sourceFile = loadSourceFile(moduleSourcePath, encodedDependencyFilePath); } if (sourceFile) { module->addFileDependency(sourceFile); } } - module->setPathInfo(filePathInfo); - module->setDigest(moduleEntry.digest); + module->setPathInfo(moduleFilePathInfo); + module->setDigest(moduleChunk.getDigest()); module->_collectShaderParams(); module->_discoverEntryPoints(sink, targets); // Hook up fileDecl's scope to module's scope. - auto moduleDecl = module->getModuleDecl(); for (auto globalDecl : moduleDecl->members) { if (auto fileDecl = as<FileDecl>(globalDecl)) @@ -6291,6 +6651,8 @@ void Linkage::prepareDeserializedModule( addSiblingScopeForContainerDecl(m_astBuilder, moduleDecl->ownedScope, fileDecl); } } + + return SLANG_OK; } void Linkage::setRequireCacheFileSystem(bool requireCacheFileSystem) |
