diff options
Diffstat (limited to 'source')
82 files changed, 4374 insertions, 6321 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.h b/source/compiler-core/slang-diagnostic-sink.h index e4d131e37..fc5e31b47 100644 --- a/source/compiler-core/slang-diagnostic-sink.h +++ b/source/compiler-core/slang-diagnostic-sink.h @@ -310,7 +310,7 @@ private: class DiagnosticsLookup : public RefObject { public: - static const Index kArenaInitialSize = 2048; + static const Index kArenaInitialSize = 65536; /// Will take into account the slice name could be using different conventions const DiagnosticInfo* findDiagnosticByName(const UnownedStringSlice& slice) const; diff --git a/source/compiler-core/slang-name.cpp b/source/compiler-core/slang-name.cpp index cc2033339..c815b8aa8 100644 --- a/source/compiler-core/slang-name.cpp +++ b/source/compiler-core/slang-name.cpp @@ -19,7 +19,7 @@ const char* getCstr(Name* name) return name ? name->text.getBuffer() : nullptr; } -Name* NamePool::getName(String const& text) +Name* NamePool::getName(UnownedStringSlice text) { RefPtr<Name> name; if (rootPool->names.tryGetValue(text, name)) @@ -31,6 +31,11 @@ Name* NamePool::getName(String const& text) return name; } +Name* NamePool::getName(String const& text) +{ + return getName(text.getUnownedSlice()); +} + Name* NamePool::tryGetName(String const& text) { RefPtr<Name> name; diff --git a/source/compiler-core/slang-name.h b/source/compiler-core/slang-name.h index cf702686b..f8c1201af 100644 --- a/source/compiler-core/slang-name.h +++ b/source/compiler-core/slang-name.h @@ -68,6 +68,7 @@ struct RootNamePool struct NamePool { // Find or create the `Name` that represents the given `text`. + Name* getName(UnownedStringSlice text); Name* getName(String const& text); // Try find the `Name` that represents the given `text`. // If the name does not exist, return nullptr diff --git a/source/compiler-core/slang-slice-allocator.h b/source/compiler-core/slang-slice-allocator.h index e4ba9e907..a6f0cd5c1 100644 --- a/source/compiler-core/slang-slice-allocator.h +++ b/source/compiler-core/slang-slice-allocator.h @@ -97,7 +97,7 @@ struct SliceAllocator void deallocateAll() { m_arena.deallocateAll(); } SliceAllocator(): - m_arena(1024) + m_arena(2097152) { } protected: diff --git a/source/core/slang-array-view.h b/source/core/slang-array-view.h index 99609ef69..50270e0a0 100644 --- a/source/core/slang-array-view.h +++ b/source/core/slang-array-view.h @@ -197,6 +197,8 @@ namespace Slang return ThisType(m_buffer + index, m_count - index); } + T& getLast() { return m_buffer[m_count - 1]; } + ArrayView() : Super() {} ArrayView(T* buffer, Index size) :Super(buffer, size) {} }; diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h index bc4b30ccc..5f6b1b060 100644 --- a/source/core/slang-hash.h +++ b/source/core/slang-hash.h @@ -138,7 +138,7 @@ namespace Slang template<typename TKey> static HashCode getHashCode(TKey const& key) { - return (HashCode)((PtrInt)key) / 16; // sizeof(typename std::remove_pointer<TKey>::type); + return (HashCode)((PtrInt)key) >> 2; // sizeof(typename std::remove_pointer<TKey>::type); } }; template<> diff --git a/source/core/slang-list.h b/source/core/slang-list.h index ff756035c..250b6dc49 100644 --- a/source/core/slang-list.h +++ b/source/core/slang-list.h @@ -52,6 +52,11 @@ namespace Slang { this->operator=(static_cast<List<T>&&>(list)); } + List(ArrayView<T> view) + : List() + { + addRange(view); + } static List<T> makeRepeated(const T& val, Index count) { List<T> rs; diff --git a/source/core/slang-short-list.h b/source/core/slang-short-list.h index 5bad9faf8..adbb935e6 100644 --- a/source/core/slang-short-list.h +++ b/source/core/slang-short-list.h @@ -117,17 +117,17 @@ namespace Slang } }; - Iterator begin() + Iterator begin() const { Iterator rs; - rs.container = this; + rs.container = const_cast<ThisType*>(this); rs.index = 0; return rs; } - Iterator end() + Iterator end() const { Iterator rs; - rs.container = this; + rs.container = const_cast<ThisType*>(this); rs.index = m_count; return rs; } diff --git a/source/core/slang-token-reader.cpp b/source/core/slang-token-reader.cpp index f6f29def3..e8ebfb9ec 100644 --- a/source/core/slang-token-reader.cpp +++ b/source/core/slang-token-reader.cpp @@ -416,7 +416,7 @@ namespace Misc { tokenFlags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; pos++; } - else if (curChar == ' ' || curChar == '\t' || curChar == -62 || curChar == -96) // -62/-96:non-break space + else if (curChar == ' ' || curChar == '\t' || curChar == '\xC2' || curChar == '\xA0') // -62/-96:non-break space { tokenFlags |= TokenFlag::AfterWhitespace; pos++; diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index ae70a83f4..efd0a743c 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -316,7 +316,7 @@ interface __FlagsEnumType : __EnumType }; __generic<T, let N:int> -__magic_type(ArrayType) +__magic_type(ArrayExpressionType) struct Array { } @@ -876,7 +876,7 @@ extension int16_t /// An `N` component vector with elements of type `T`. __generic<T = float, let N : int = 4> -__magic_type(Vector) +__magic_type(VectorExpressionType) struct vector { /// The element type of the vector @@ -896,7 +896,7 @@ struct vector /// A matrix with `R` rows and `C` columns, with elements of type `T`. __generic<T = float, let R : int = 4, let C : int = 4> -__magic_type(Matrix) +__magic_type(MatrixExpressionType) struct matrix { __intrinsic_op($(kIROp_MakeMatrixFromScalar)) @@ -957,12 +957,12 @@ for (int tt = 0; tt < kTypeCount; ++tt) __generic<T> __intrinsic_type($(kIROp_ConstantBufferType)) -__magic_type(ConstantBuffer) +__magic_type(ConstantBufferType) struct ConstantBuffer {} __generic<T> __intrinsic_type($(kIROp_TextureBufferType)) -__magic_type(TextureBuffer) +__magic_type(TextureBufferType) struct TextureBuffer {} __generic<T> @@ -1238,14 +1238,14 @@ extension matrix<T, R, C> : IDifferentiable //@ public: /// Sampling state for filtered texture fetches. -__magic_type(SamplerState, $(int(SamplerStateFlavor::SamplerState))) +__magic_type(SamplerStateType, $(int(SamplerStateFlavor::SamplerState))) __intrinsic_type($(kIROp_SamplerStateType)) struct SamplerState { } /// Sampling state for filtered texture fetches that include a comparison operation before filtering. -__magic_type(SamplerState, $(int(SamplerStateFlavor::SamplerComparisonState))) +__magic_type(SamplerStateType, $(int(SamplerStateFlavor::SamplerComparisonState))) __intrinsic_type($(kIROp_SamplerComparisonStateType)) struct SamplerComparisonState { @@ -1347,12 +1347,12 @@ struct TextureTypeInfo if(prefixInfo.combined) { - sb << "__magic_type(TextureSampler," << int(flavor) << ")\n"; + sb << "__magic_type(TextureSamplerType," << int(flavor) << ")\n"; sb << "__intrinsic_type(" << (kIROp_TextureSamplerType + (int(flavor) << kIROpMeta_OtherShift)) << ")\n"; } else { - sb << "__magic_type(Texture," << int(flavor) << ")\n"; + sb << "__magic_type(TextureType," << int(flavor) << ")\n"; sb << "__intrinsic_type(" << (kIROp_TextureType + (int(flavor) << kIROpMeta_OtherShift)) << ")\n"; } sb << "struct "; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 9716a3e9e..1ab046b19 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5048,7 +5048,7 @@ for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa) bool isReadOnly = (access == SLANG_RESOURCE_ACCESS_READ); auto flavor = TextureFlavor::create(TextureFlavor::Shape::ShapeBuffer, access).flavor; sb << "__generic<T>\n"; - sb << "__magic_type(Texture," << int(flavor) << ")\n"; + sb << "__magic_type(TextureType," << int(flavor) << ")\n"; sb << "__intrinsic_type(" << (kIROp_TextureType + (int(flavor) << kIROpMeta_OtherShift)) << ")\n"; sb << "struct "; sb << kBaseBufferAccessLevels[aa].name; @@ -5566,7 +5566,7 @@ static const int feedbackTexture2DFlavor = int(TextureFlavor::create(TextureFlav static const int feedbackTexture2DArrayFlavor = int(TextureFlavor::create(TextureFlavor::Shape::Shape2D, SLANG_RESOURCE_ACCESS_WRITE, SLANG_TEXTURE_FEEDBACK_FLAG | SLANG_TEXTURE_ARRAY_FLAG).flavor); }}}} -__magic_type(Texture, $(feedbackTexture2DFlavor)) +__magic_type(TextureType, $(feedbackTexture2DFlavor)) __intrinsic_type($(kIROp_TextureType + (feedbackTexture2DFlavor << kIROpMeta_OtherShift))) struct FeedbackTexture2D<T : __BuiltinSamplerFeedbackType> { @@ -5619,7 +5619,7 @@ struct FeedbackTexture2D<T : __BuiltinSamplerFeedbackType> -__magic_type(Texture, $(feedbackTexture2DArrayFlavor)) +__magic_type(TextureType, $(feedbackTexture2DArrayFlavor)) __intrinsic_type($(kIROp_TextureType + (feedbackTexture2DArrayFlavor << kIROpMeta_OtherShift))) struct FeedbackTexture2DArray<T : __BuiltinSamplerFeedbackType> { diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp new file mode 100644 index 000000000..0ad2bb101 --- /dev/null +++ b/source/slang/slang-ast-base.cpp @@ -0,0 +1,33 @@ +#include "slang-ast-base.h" +#include "slang-ast-builder.h" + +namespace Slang +{ +void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) +{ +#ifdef _DEBUG + SLANG_UNUSED(inAstNodeType); + static int32_t uidCounter = 0; + static int32_t breakValue = 0; + uidCounter++; + _debugUID = uidCounter; + if (inAstBuilder->getId() == -1) + _debugUID = -_debugUID; + if (breakValue != 0 && _debugUID == breakValue) + SLANG_BREAKPOINT(0) +#else + SLANG_UNUSED(inAstNodeType); + SLANG_UNUSED(inAstBuilder); +#endif +} +DeclRefBase* Decl::getDefaultDeclRef() +{ + auto astBuilder = getCurrentASTBuilder(); + if (astBuilder->getEpoch() != m_defaultDeclRefEpoch || !m_defaultDeclRef) + { + m_defaultDeclRef = astBuilder->getDirectDeclRef(this); + m_defaultDeclRefEpoch = astBuilder->getEpoch(); + } + return m_defaultDeclRef; +} +} diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index b90014316..d8f4c8c6c 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -16,25 +16,26 @@ namespace Slang { +class ASTBuilder; +struct SemanticsVisitor; + class NodeBase { SLANG_ABSTRACT_AST_CLASS(NodeBase) // MUST be called before used. Called automatically via the ASTBuilder. // Note that the astBuilder is not stored in the NodeBase derived types by default. - SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /*astBuilder*/) + SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) { + SLANG_UNUSED(inAstBuilder); astNodeType = inAstNodeType; #ifdef _DEBUG - static uint32_t uidCounter = 0; - static uint32_t breakValue = 0; - uidCounter++; - _debugUID = uidCounter; - if (breakValue != 0 && _debugUID == breakValue) - SLANG_BREAKPOINT(0) + _initDebug(inAstNodeType, inAstBuilder); #endif } + void _initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder); + /// Get the class info SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const { return *ASTClassInfo::getInfo(astNodeType); } @@ -48,7 +49,7 @@ class NodeBase // Handy when debugging, shouldn't be checked in though! // virtual ~NodeBase() {} #ifdef _DEBUG - SLANG_UNREFLECTED uint32_t _debugUID = 0; + SLANG_UNREFLECTED int32_t _debugUID = 0; #endif }; @@ -82,7 +83,14 @@ SLANG_FORCE_INLINE const T* as(const NodeBase* node) // to avoid confusion and bugs. Instead, use the `as<>()` method on `DeclRefBase` to // get a `DeclRef<T>` for a specific node type. template<typename T> -T* as(const DeclRefBase* declRefBase) = delete; +T* as(DeclRefBase* declRefBase, typename EnableIf<!IsBaseOf<DeclRefBase, T>::Value, void*>::type arg = nullptr) = delete; + +template<typename T> +T* as(DeclRefBase* declRefBase, typename EnableIf<IsBaseOf<DeclRefBase, T>::Value, void*>::type arg = nullptr) +{ + SLANG_UNUSED(arg); + return dynamicCast<T>(declRefBase); +} template<typename T, typename U> DeclRef<T> as(DeclRef<U> declRef) { return DeclRef<T>(declRef); } @@ -116,6 +124,177 @@ class SyntaxNodeBase : public NodeBase SourceLoc loc; }; +enum class ValNodeOperandKind +{ + ConstantValue, + ValNode, + ASTNode, +}; + +struct ValNodeOperand +{ + ValNodeOperandKind kind = ValNodeOperandKind::ConstantValue; + union + { + NodeBase* nodeOperand; + int64_t intOperand; + } values; + + ValNodeOperand() + { + values.nodeOperand = nullptr; + } + + explicit ValNodeOperand(NodeBase* node) + { + if (as<Val>(node)) + { + values.nodeOperand = (NodeBase*)node; + kind = ValNodeOperandKind::ValNode; + } + else + { + values.nodeOperand = node; + kind = ValNodeOperandKind::ASTNode; + } + } + + template<typename T> + explicit ValNodeOperand(DeclRef<T> declRef) { values.nodeOperand = declRef.declRefBase; kind = ValNodeOperandKind::ValNode; } + + template<typename T> + explicit ValNodeOperand(T* node) + { + if constexpr (std::is_base_of<Val, T>::value) + { + values.nodeOperand = (NodeBase*)node; + kind = ValNodeOperandKind::ValNode; + } + else if constexpr (std::is_base_of<NodeBase, T>::value) + { + values.nodeOperand = node; + kind = ValNodeOperandKind::ASTNode; + } + else + { + static_assert(std::is_base_of<Val, T>::value || std::is_base_of<NodeBase, T>::value, "pointer used as Val operand must be an AST node."); + } + } + + template<typename EnumType> + explicit ValNodeOperand(EnumType intVal) + { + static_assert(std::is_trivial<EnumType>::value, "Type to construct NodeOperand must be trivial."); + static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size."); + values.intOperand = 0; + memcpy(&values, &intVal, sizeof(intVal)); + kind = ValNodeOperandKind::ConstantValue; + } +}; + +struct ValNodeDesc +{ + ASTNodeType type; + ShortList<ValNodeOperand, 4> operands; + + bool operator==(ValNodeDesc const& that) const; + HashCode getHashCode() const { return hashCode; } + void init(); +private: + HashCode hashCode = 0; +}; + +template<int N> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>&) +{} + +template<int N, typename... Ts> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, ExpandedSpecializationArgs e, Ts... ts) +{ + for (auto arg : e) + { + list.add(ValNodeOperand(arg.val)); + list.add(ValNodeOperand(arg.witness)); + } + addOrAppendToNodeList(list, ts...); +} + +template<int N, typename T, typename... Ts> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, T t, Ts... ts) +{ + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +template<int N, typename T, typename... Ts> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, const List<T>& l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +template<int N, typename T, typename... Ts> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, ConstArrayView<T> l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +template<int N, typename T, typename... Ts> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, ArrayView<T> l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +inline void addOrAppendToNodeList(List<ValNodeOperand>&) +{} + +template<typename... Ts> +static void addOrAppendToNodeList(List<ValNodeOperand>& list, ExpandedSpecializationArgs e, Ts... ts) +{ + for (auto arg : e) + { + list.add(ValNodeOperand(arg.val)); + list.add(ValNodeOperand(arg.witness)); + } + addOrAppendToNodeList(list, ts...); +} + +template<typename T, typename... Ts> +static void addOrAppendToNodeList(List<ValNodeOperand>& list, T t, Ts... ts) +{ + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +template<typename T, typename... Ts> +static void addOrAppendToNodeList(List<ValNodeOperand>& list, const List<T>& l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +template<typename T, typename... Ts> +static void addOrAppendToNodeList(List<ValNodeOperand>& list, ConstArrayView<T> l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + +template<typename T, typename... Ts> +static void addOrAppendToNodeList(List<ValNodeOperand>& list, ArrayView<T> l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + // Base class for compile-time values (most often a type). // These are *not* syntax nodes, because they do not have // a unique location, and any two `Val`s representing @@ -124,6 +303,54 @@ class SyntaxNodeBase : public NodeBase class Val : public NodeBase { SLANG_ABSTRACT_AST_CLASS(Val) + + template<typename T> + struct OperandView + { + const Val* val; + Index offset; + Index count; + OperandView() + { + val = nullptr; + offset = 0; + count = 0; + } + OperandView(const Val* val, Index offset, Index count) : val(val), offset(offset), count(count) {} + Index getCount() { return count; } + T* operator[](Index index) const + { + return as<T>(val->getOperand(index + offset)); + } + struct Iterator + { + const Val* val; + Index i; + bool operator==(Iterator other) const + { + return val == other.val && i == other.i; + } + bool operator!=(Iterator other) const + { + return val != other.val || i != other.i; + } + T*& operator*() const + { + return *(T**)&val->m_operands[i].values.nodeOperand; + } + T** operator->() const + { + return (T**)&val->m_operands[i].values.nodeOperand; + } + Iterator& operator++() + { + i++; + return *this; + } + }; + Iterator begin() const { return Iterator { val, offset }; } + Iterator end() const { return Iterator{ val, offset + count }; } + }; typedef IValVisitor Visitor; @@ -140,25 +367,84 @@ class Val : public NodeBase // decide whether they need to do anything). Val* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - bool equalsVal(Val* val); + bool equals(Val* val) const + { + return this == val || const_cast<Val*>(this)->resolve() == val->resolve(); + } // Appends as text to the end of the builder void toText(StringBuilder& out); String toString(); HashCode getHashCode(); - bool operator == (const Val & v) + bool operator == (const Val & v) const { - return equalsVal(const_cast<Val*>(&v)); + return equals(const_cast<Val*>(&v)); } // Overrides should be public so base classes can access Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); + + Val* _resolveImplOverride(); + + Val* resolveImpl(); + Val* resolve(); + ValNodeDesc getDesc(); + + Val* getOperand(Index index) const + { + SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ValNode); + return (Val*)m_operands[index].values.nodeOperand; + } + + Decl* getDeclOperand(Index index) const + { + SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ASTNode); + return (Decl*)(m_operands[index].values.nodeOperand); + } + + int64_t getIntConstOperand(Index index) const + { + SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ConstantValue); + return m_operands[index].values.intOperand; + } + + Index getOperandCount() const { return m_operands.getCount(); } + + template<typename ... TArgs> + void setOperands(TArgs... args) + { + m_operands.clear(); + addOrAppendToNodeList(m_operands, args...); + } + template<typename ... TArgs> + void addOperands(TArgs... args) + { + addOrAppendToNodeList(m_operands, args...); + } + template<typename T> + void addOperands(OperandView<T> operands) + { + for (auto v : operands) + m_operands.add(ValNodeOperand(v)); + } + List<ValNodeOperand> m_operands; +protected: + Val* defaultResolveImpl(); +private: + mutable Val* m_resolvedVal = nullptr; + SLANG_UNREFLECTED mutable Index m_resolvedValEpoch = 0; }; +template<int N, typename T, typename... Ts> +static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, Val::OperandView<T> l, Ts... ts) +{ + for (auto t : l) + list.add(ValNodeOperand(t)); + addOrAppendToNodeList(list, ts...); +} + struct ValSet { struct ValItem @@ -176,9 +462,9 @@ struct ValSet if (val == other.val) return true; if (val) - return val->equalsVal(other.val); + return val->equals(other.val); else if (other.val) - return other.val->equalsVal(val); + return other.val->equals(val); return false; } }; @@ -232,31 +518,32 @@ class Type: public Val /// Type derived types store the AST builder they were constructed on. The builder calls this function /// after constructing. - SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) { Val::init(inAstNodeType, inAstBuilder); m_astBuilder = inAstBuilder; } - - /// Get the ASTBuilder that was used to construct this Type - SLANG_FORCE_INLINE ASTBuilder* getASTBuilder() const { return m_astBuilder; } - - bool equals(Type* type); - - Type* getCanonicalType(); + SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) + { + Val::init(inAstNodeType, inAstBuilder); + m_astBuilderForReflection = inAstBuilder; + } // Overrides should be public so base classes can access Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - bool _equalsValOverride(Val* val); - bool _equalsImplOverride(Type* type); Type* _createCanonicalTypeOverride(); + Val* _resolveImplOverride(); - void _setASTBuilder(ASTBuilder* astBuilder) { m_astBuilder = astBuilder; } + Type* getCanonicalType() + { + return as<Type>(resolve()); + } + ASTBuilder* getASTBuilderForReflection() const { return m_astBuilderForReflection; } protected: - bool equalsImpl(Type* type); Type* createCanonicalType(); - Type* canonicalType = nullptr; - - SLANG_UNREFLECTED - ASTBuilder* m_astBuilder = nullptr; + // We store the ASTBuilder to support reflection API only. + // It should not be used for anything else, especially not for constructing new AST nodes during + // semantic checking, since Val deduplication requires the entire semantic checking process to + // stick with one ASTBuilder. + // Call getCurrentASTBuilder() to obtain the right ASTBuilder for semantic checking. + SLANG_UNREFLECTED ASTBuilder* m_astBuilderForReflection; }; template <typename T> @@ -264,161 +551,68 @@ SLANG_FORCE_INLINE T* as(Type* obj) { return obj ? dynamicCast<T>(obj->getCanoni template <typename T> SLANG_FORCE_INLINE const T* as(const Type* obj) { return obj ? dynamicCast<T>(const_cast<Type*>(obj)->getCanonicalType()) : nullptr; } -// A substitution represents a binding of certain -// type-level variables to concrete argument values -class Substitutions: public NodeBase -{ - SLANG_ABSTRACT_AST_CLASS(Substitutions) - - - // Apply a set of substitutions to the bindings in this substitution - Substitutions* applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); - - // Check if these are equivalent substitutions to another set - bool equals(Substitutions* subst); - HashCode getHashCode() const; - - // Overrides should be public so base classes can access - Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); - bool _equalsOverride(Substitutions* subst); - HashCode _getHashCodeOverride() const; - - Substitutions* getOuter() const { return outer; } -protected: - // The next outer that this one refines. - Substitutions* outer = nullptr; -}; - -class GenericSubstitution : public Substitutions -{ - SLANG_AST_CLASS(GenericSubstitution) - -private: - // The generic declaration that defines the - // parameters we are binding to arguments - GenericDecl* genericDecl = nullptr; - - // The actual values of the arguments - List<Val* > args; -public: - GenericDecl* getGenericDecl() const { return genericDecl; } - List<Val*>& getArgs() { return args; } - const List<Val*>& getArgs() const { return args; } - - // Overrides should be public so base classes can access - Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); - bool _equalsOverride(Substitutions* subst); - HashCode _getHashCodeOverride() const; - - GenericSubstitution(Substitutions* outerSubst, GenericDecl* decl, ArrayView<Val*> argVals) - { - outer = outerSubst; - genericDecl = decl; - args.addRange(argVals); - } -}; - -class ThisTypeSubstitution : public Substitutions -{ - SLANG_AST_CLASS(ThisTypeSubstitution) - - // The declaration of the interface that we are specializing - InterfaceDecl* interfaceDecl = nullptr; - - // A witness that shows that the concrete type used to - // specialize the interface conforms to the interface. - SubtypeWitness* witness = nullptr; - - // Overrides should be public so base classes can access - // The actual type that provides the lookup scope for an associated type - Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); - bool _equalsOverride(Substitutions* subst); - HashCode _getHashCodeOverride() const; - - ThisTypeSubstitution(Substitutions* outerSubst, InterfaceDecl* inInterfaceDecl, SubtypeWitness* inWitness) - : interfaceDecl(inInterfaceDecl), witness(inWitness) - { - outer = outerSubst; - } -}; - class Decl; // A reference to a declaration, which may include // substitutions for generic parameters. class DeclRefBase : public Val { - SLANG_AST_CLASS(DeclRefBase) + SLANG_ABSTRACT_AST_CLASS(DeclRefBase) - Decl* getDecl() const { return decl; } + Decl* getDecl() const { return getDeclOperand(0); } - Substitutions* getSubst() const { return substitutions; } + // Apply substitutions to this declaration reference + DeclRefBase* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - DeclRefBase(Decl* decl) - :decl(decl) + DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { + SLANG_UNUSED(astBuilder); + SLANG_UNUSED(subst); + SLANG_UNUSED(ioDiff); + SLANG_UNREACHABLE("DeclRefBase::_substituteImplOverride not overrided."); } - DeclRefBase(Decl* decl, Substitutions* subst) - :decl(decl), substitutions(subst) + void _toTextOverride(StringBuilder& out) { + SLANG_UNUSED(out); + SLANG_UNREACHABLE("DeclRefBase::_toTextOverride not overrided."); } - // Apply substitutions to a type or declaration - Type* substitute(ASTBuilder* astBuilder, Type* type) const; - - DeclRefBase* substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const; - - // Apply substitutions to an expression - SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; - - // Apply substitutions to this declaration reference - DeclRefBase* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; - - Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) + Val* _resolveImplOverride() { - return substituteImpl(astBuilder, subst, ioDiff); + SLANG_UNREACHABLE("DeclRefBase::_resolveImplOverride not overrided."); } - bool _equalsValOverride(Val* val); - - bool _equalsImplOverride(DeclRefBase* declRef) { return equals(declRef); } - void _toTextOverride(StringBuilder& out) { toText(out); } + DeclRefBase* _getBaseOverride() + { + SLANG_UNREACHABLE("DeclRefBase::_getBaseOverride not overrided."); + } // Returns true if 'as' will return a valid cast template <typename T> - bool is() const { return Slang::as<T>(decl) != nullptr; } - - // Check if this is an equivalent declaration reference to another - bool equals(DeclRefBase* declRef) const; + bool is() const { return Slang::as<T>(getDecl()) != nullptr; } // Convenience accessors for common properties of declarations Name* getName() const; SourceLoc getNameLoc() const; SourceLoc getLoc() const; - DeclRefBase* getParent(ASTBuilder* astBuilder) const; - - HashCode getHashCode() const; - - // Debugging: - String toString() const; - void toText(StringBuilder& out) const; - -private: - - // The underlying declaration - Decl* decl = nullptr; - // Optionally, a chain of substitutions to perform - Substitutions* substitutions = nullptr; - + DeclRefBase* getParent(); + String toString() const + { + StringBuilder sb; + const_cast<DeclRefBase*>(this)->toText(sb); + return sb.produceString(); + } + DeclRefBase* getBase(); + void toText(StringBuilder& out); }; -SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) { declRef->toText(io); return io; } +SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) { if (declRef) const_cast<DeclRefBase*>(declRef)->toText(io); return io; } -SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const Decl* decl) +SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Decl* decl) { if (decl) - _printNestedDecl(nullptr, decl, io); + makeDeclRef(decl).declRefBase->toText(io); return io; } @@ -488,12 +682,7 @@ public: ContainerDecl* parentDecl = nullptr; - // A direct DeclRef to this Decl. - // For every Decl, we create a DeclRef node representing a direct reference to it - // upon the creation of the Decl (implemented in ASTBuilder), and store that - // DeclRef here so we can get a direct DeclRef from a Decl* (by calling makeDeclRef) - // without requiring a ASTBuilder to be available. - DeclRefBase* defaultDeclRef = nullptr; + DeclRefBase* getDefaultDeclRef(); NameLoc nameAndLoc; @@ -514,6 +703,10 @@ public: SLANG_RELEASE_ASSERT(state >= checkState.getState()); checkState.setState(state); } + +private: + SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr; + SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1; }; class Expr : public SyntaxNode @@ -551,8 +744,7 @@ DeclRef<T>::DeclRef(Decl* decl) DeclRefBase* declRef = nullptr; if (decl) { - SLANG_ASSERT(decl->defaultDeclRef); - declRef = decl->defaultDeclRef; + declRef = decl->getDefaultDeclRef(); } init(declRef); } @@ -564,12 +756,6 @@ T* DeclRef<T>::getDecl() const } template<typename T> -Substitutions* DeclRef<T>::getSubst() const -{ - return declRefBase ? declRefBase->getSubst() : nullptr; -} - -template<typename T> Name* DeclRef<T>::getName() const { if (declRefBase) @@ -592,9 +778,9 @@ SourceLoc DeclRef<T>::getLoc() const } template<typename T> -DeclRef<ContainerDecl> DeclRef<T>::getParent(ASTBuilder* astBuilder) const +DeclRef<ContainerDecl> DeclRef<T>::getParent() const { - if (declRefBase) return DeclRef<ContainerDecl>(declRefBase->getParent(astBuilder)); + if (declRefBase) return DeclRef<ContainerDecl>(declRefBase->getParent()); return DeclRef<ContainerDecl>((DeclRefBase*)nullptr); } @@ -608,15 +794,17 @@ HashCode DeclRef<T>::getHashCode() const template<typename T> Type* DeclRef<T>::substitute(ASTBuilder* astBuilder, Type* type) const { + SLANG_UNUSED(astBuilder); if (!declRefBase) return type; - return declRefBase->substitute(astBuilder, type); + return SubstitutionSet(*this).applyToType(astBuilder, type); } template<typename T> SubstExpr<Expr> DeclRef<T>::substitute(ASTBuilder* astBuilder, Expr* expr) const { + SLANG_UNUSED(astBuilder); if (!declRefBase) return expr; - return declRefBase->substitute(astBuilder, expr); + return applySubstitutionToExpr(SubstitutionSet(*this), expr); } // Apply substitutions to a type or declaration @@ -624,23 +812,21 @@ template<typename T> template<typename U> DeclRef<U> DeclRef<T>::substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const { + SLANG_UNUSED(astBuilder); if (!declRefBase) return declRef; - return DeclRef<U>(declRefBase->substitute(astBuilder, declRef.declRefBase)); + return DeclRef<U>(SubstitutionSet(*this).applyToDeclRef(astBuilder, declRef.declRefBase)); } // Apply substitutions to this declaration reference template<typename T> DeclRef<T> DeclRef<T>::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const { + SLANG_UNUSED(astBuilder); if (!declRefBase) return *this; return DeclRef<T>(declRefBase->substituteImpl(astBuilder, subst, ioDiff)); } -template<typename T> -template<typename U> -bool DeclRef<T>::equals(DeclRef<U> other) const -{ - return declRefBase == other.declRefBase || (declRefBase && declRefBase->equals(other.declRefBase)); -} +Val::OperandView<Val> tryGetGenericArguments(SubstitutionSet substSet, Decl* genericDecl); + } // namespace Slang diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 64a7abd8c..96fb6ac79 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -29,12 +29,6 @@ void SharedASTBuilder::init(Session* session) // Clear the built in types memset(m_builtinTypes, 0, sizeof(m_builtinTypes)); - // Create common shared types - m_errorType = m_astBuilder->create<ErrorType>(); - m_bottomType = m_astBuilder->create<BottomType>(); - m_initializerListType = m_astBuilder->create<InitializerListType>(); - m_overloadedType = m_astBuilder->create<OverloadGroupType>(); - // We can just iterate over the class pointers. // NOTE! That this adds the names of the abstract classes too(!) for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) @@ -151,6 +145,31 @@ Type* SharedASTBuilder::getDiffInterfaceType() return m_diffInterfaceType; } +Type* SharedASTBuilder::getErrorType() +{ + if (!m_errorType) + m_errorType = m_astBuilder->getOrCreate<ErrorType>(); + return m_errorType; +} +Type* SharedASTBuilder::getBottomType() +{ + if (!m_bottomType) + m_bottomType = m_astBuilder->getOrCreate<BottomType>(); + return m_bottomType; +} +Type* SharedASTBuilder::getInitializerListType() +{ + if (!m_initializerListType) + m_initializerListType = m_astBuilder->getOrCreate<InitializerListType>(); + return m_initializerListType; +} +Type* SharedASTBuilder::getOverloadedType() +{ + if (!m_overloadedType) + m_overloadedType = m_astBuilder->getOrCreate<OverloadGroupType>(); + return m_overloadedType; +} + SharedASTBuilder::~SharedASTBuilder() { // Release built in types.. @@ -208,19 +227,28 @@ Decl* SharedASTBuilder::tryFindMagicDecl(const String& name) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Index& _getGlobalASTEpochId() +{ + static thread_local Index epochId = 1; + return epochId; +} + ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name): m_sharedASTBuilder(sharedASTBuilder), m_name(name), m_id(sharedASTBuilder->m_id++), - m_arena(2048) + m_arena(2097152) { SLANG_ASSERT(sharedASTBuilder); + // Copy Val deduplication map over so we don't create duplicate Vals that are already + // existent in the stdlib. + m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes; } ASTBuilder::ASTBuilder(): m_sharedASTBuilder(nullptr), m_id(-1), - m_arena(2048) + m_arena(2097152) { m_name = "SharedASTBuilder::m_astBuilder"; } @@ -233,6 +261,25 @@ ASTBuilder::~ASTBuilder() SLANG_ASSERT(info->m_destructorFunc); info->m_destructorFunc(node); } + incrementEpoch(); +} + +Index ASTBuilder::getEpoch() +{ + return _getGlobalASTEpochId(); +} + +void ASTBuilder::incrementEpoch() +{ + _getGlobalASTEpochId()++; +} + +void ASTBuilder::_verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc) +{ + if (!val) + return; + ValNodeDesc descOut = val->getDesc(); + SLANG_ASSERT(descOut == expectedDesc); } NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) @@ -256,6 +303,13 @@ Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTy return rsType; } +Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName) +{ + auto declRef = getBuiltinDeclRef(magicTypeName, genericArgs); + auto rsType = DeclRefType::create(this, declRef); + return rsType; +} + PtrType* ASTBuilder::getPtrType(Type* valueType) { return dynamicCast<PtrType>(getPtrType(valueType, "PtrType")); @@ -292,64 +346,57 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element { if (!elementCount) elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength); - - auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount); - if (!result->declRef.getDecl()) + if (elementCount->getType() != getIntType()) { - auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType")); - auto arrayTypeDecl = arrayGenericDecl->inner; - auto substitutions = getOrCreateGenericSubstitution(nullptr, arrayGenericDecl, elementType, elementCount); - result->declRef = getSpecializedDeclRef<Decl>(arrayTypeDecl, substitutions); + // Canonicalize constant elementCount to int. + if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount)) + { + elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue()); + } } - return result; + Val* args[] = {elementType, elementCount}; + return as<ArrayExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType")); } ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType) { - auto result = getOrCreate<ConstantBufferType>(elementType); - if (!result->declRef.getDecl()) - { - auto genericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ConstantBuffer")); - auto typeDecl = genericDecl->inner; - auto substitutions = getOrCreateGenericSubstitution(nullptr, genericDecl, elementType); - result->declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions); - } - return result; + return as<ConstantBufferType>(getSpecializedBuiltinType(elementType, "ConstantBufferType")); +} + +ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType) +{ + return as<ParameterBlockType>(getSpecializedBuiltinType(elementType, "ParameterBlockType")); +} + +HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType) +{ + return as<HLSLStructuredBufferType>(getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType")); +} + +SamplerStateType* ASTBuilder::getSamplerStateType() +{ + return as<SamplerStateType>(getSpecializedBuiltinType(nullptr, "HLSLStructuredBufferType")); } VectorExpressionType* ASTBuilder::getVectorType( Type* elementType, IntVal* elementCount) { - auto result = getOrCreate<VectorExpressionType>(elementType, elementCount); - if (!result->declRef.getDecl()) + // Canonicalize constant elementCount to int. + if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount)) { - auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); - auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = getOrCreateGenericSubstitution(nullptr, vectorGenericDecl, elementType, elementCount); - result->declRef = getSpecializedDeclRef<Decl>(vectorTypeDecl, substitutions); + elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue()); } - return result; + Val* args[] = { elementType, elementCount }; + return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType")); } DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, Witness* primalIsDifferentialWitness) { - auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType")); - - auto typeDecl = genericDecl->inner; - - auto substitutions = getOrCreateGenericSubstitution( - nullptr, - genericDecl, - valueType, - primalIsDifferentialWitness); - - auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions); - auto rsType = DeclRefType::create(this, declRef); - - return as<DifferentialPairType>(rsType); + Val* args[] = { valueType, primalIsDifferentialWitness }; + return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl() @@ -377,20 +424,9 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier( : as<HLSLIndicesModifier>(modifier) ? "IndicesType" : as<HLSLPrimitivesModifier>(modifier) ? "PrimitivesType" : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr); - auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl(declName)); - - auto typeDecl = genericDecl->inner; - - auto substitutions = getOrCreateGenericSubstitution( - nullptr, - genericDecl, - elementType, - maxElementCount); - auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions); - auto rsType = DeclRefType::create(this, declRef); - - return as<MeshOutputType>(rsType); + Val* args[] = { elementType, maxElementCount }; + return as<MeshOutputType>(getSpecializedBuiltinType(makeArrayView(args), declName)); } Type* ASTBuilder::getDifferentiableInterfaceType() @@ -403,13 +439,8 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as<GenericDecl>(decl)) { - decl = genericDecl->inner; - Substitutions* subst = nullptr; - if (genericArg) - { - subst = getOrCreateGenericSubstitution(nullptr, genericDecl, genericArg); - } - return getSpecializedDeclRef(decl, subst); + auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg)); + return declRef; } else { @@ -418,6 +449,21 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va return makeDeclRef(decl); } +DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs) +{ + auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); + if (auto genericDecl = as<GenericDecl>(decl)) + { + auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), genericArgs); + return declRef; + } + else + { + SLANG_ASSERT(!decl && !genericArgs.getCount()); + } + return makeDeclRef(decl); +} + Type* ASTBuilder::getAndType(Type* left, Type* right) { auto type = getOrCreate<AndType>(left, right); @@ -426,9 +472,7 @@ Type* ASTBuilder::getAndType(Type* left, Type* right) Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* modifiers) { - auto type = create<ModifiedType>(); - type->base = base; - type->modifiers.addRange(modifiers, modifierCount); + auto type = getOrCreate<ModifiedType>(base, makeArrayView((Val**)modifiers, modifierCount)); return type; } @@ -447,15 +491,16 @@ Val* ASTBuilder::getNoDiffModifierVal() return getOrCreate<NoDiffModifierVal>(); } -Type* ASTBuilder::getFuncType(List<Type*> parameters, Type* result) +FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType) { - auto errorType = getOrCreate<BottomType>(); + if (!errorType) + errorType = getOrCreate<BottomType>(); return getOrCreate<FuncType>(parameters, result, errorType); } -Type* ASTBuilder::getTupleType(List<Type*>& types) +TupleType* ASTBuilder::getTupleType(List<Type*>& types) { - return getOrCreate<TupleType>(types); + return getOrCreate<TupleType>(types.getArrayView()); } TypeType* ASTBuilder::getTypeType(Type* type) @@ -466,11 +511,11 @@ TypeType* ASTBuilder::getTypeType(Type* type) TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness( Type* type) { - return getOrCreate<TypeEqualityWitness>(type); + return getOrCreate<TypeEqualityWitness>(type, type); } -SubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness( +DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness( Type* subType, Type* superType, DeclRef<Decl> const& declRef) @@ -517,8 +562,8 @@ top: // Let's call the intermediate type here `x`, we know that the `b <: c` // witness is based on witnesses that `b <: x` and `x <: c`: // - auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->subToMid; - auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->midToSup; + auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->getSubToMid(); + auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->getMidToSup(); // We can recursively call this operation to produce a witness that // `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`: @@ -535,8 +580,8 @@ top: goto top; } - auto aType = aIsSubtypeOfBWitness->sub; - auto cType = bIsSubtypeOfCWitness->sup; + auto aType = aIsSubtypeOfBWitness->getSub(); + auto cType = bIsSubtypeOfCWitness->getSup(); // If the right-hand side is a conjunction witness for `B <: C` // of the form `(B <: X)&(B <: Y)`, then we have it that `C = X&Y` @@ -565,8 +610,8 @@ top: // the witness `W` that `B <: X&Y&...` as well as the index // `i` of `C` within the conjunction. // - auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->conjunctionWitness; - auto indexOfCInConjunction = bIsSubtypeViaExtraction->indexInConjunction; + auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->getConjunctionWitness(); + auto indexOfCInConjunction = bIsSubtypeViaExtraction->getIndexInConjunction(); // We lift the extraction to the outside of the composition, by // forming a witness for `A <: C` that is of the form @@ -591,24 +636,14 @@ top: // formal set of rules for the allowed structure of our witnesses to // guarantee that our simplifications are sufficient. - TransitiveSubtypeWitness* transitiveWitness = getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>( + TransitiveSubtypeWitness* transitiveWitness = getOrCreate<TransitiveSubtypeWitness>( aType, cType, aIsSubtypeOfBWitness, bIsSubtypeOfCWitness); - transitiveWitness->sub = aType; - transitiveWitness->sup = cType; - transitiveWitness->subToMid = aIsSubtypeOfBWitness; - transitiveWitness->midToSup = bIsSubtypeOfCWitness; - return transitiveWitness; } -ThisTypeSubtypeWitness* ASTBuilder::getThisTypeSubtypeWitness(Type* subType, Type* superType) -{ - return getOrCreate<ThisTypeSubtypeWitness>(subType, superType); -} - SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness( Type* subType, Type* superType, @@ -633,16 +668,11 @@ SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness( // // * What if the original witness is transitive? - auto witness = getOrCreateWithDefaultCtor<ExtractFromConjunctionSubtypeWitness>( + auto witness = getOrCreate<ExtractFromConjunctionSubtypeWitness>( subType, superType, conjunctionWitness, indexOfSuperTypeInConjunction); - - witness->sub = subType; - witness->sup = superType; - witness->conjunctionWitness = conjunctionWitness; - witness->indexInConjunction = indexOfSuperTypeInConjunction; return witness; } @@ -662,11 +692,11 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( auto rExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsRWitness); if(lExtract && rExtract) { - if (lExtract->indexInConjunction == 0 - && rExtract->indexInConjunction == 1) + if (lExtract->getIndexInConjunction() == 0 + && rExtract->getIndexInConjunction() == 1) { - auto lInner = lExtract->conjunctionWitness; - auto rInner = rExtract->conjunctionWitness; + auto lInner = lExtract->getConjunctionWitness(); + auto rInner = rExtract->getConjunctionWitness(); if (lInner == rInner) { return lInner; @@ -685,57 +715,30 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( // witness) deeper, so that we have more chances to expose a // conjunction witness at higher levels. - auto witness = getOrCreateWithDefaultCtor<ConjunctionSubtypeWitness>( + auto witness = getOrCreate<ConjunctionSubtypeWitness>( sub, lAndR, subIsLWitness, subIsRWitness); - witness->componentWitnesses[0] = subIsLWitness; - witness->componentWitnesses[1] = subIsRWitness; - witness->sub = sub; - witness->sup = lAndR; return witness; } -bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const +DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl) { - if (hashCode != that.hashCode) return false; - if(type != that.type) return false; - if(operands.getCount() != that.operands.getCount()) return false; - for(Index i = 0; i < operands.getCount(); ++i) - { - // Note: we are comparing the operands directly for identity - // (pointer equality) rather than doing the `Val`-level - // equality check. - // - // The rationale here is that nodes that will be created - // via a `NodeDesc` *should* all be going through the - // deduplication path anyway, as should their operands. - // - if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; - } - return true; + return builder->getMemberDeclRef(parent, decl); } -void ASTBuilder::NodeDesc::init() + +thread_local ASTBuilder* gCurrentASTBuilder = nullptr; + +ASTBuilder* getCurrentASTBuilder() { - Hasher hasher; - hasher.hashValue(Int(type)); - for(Index i = 0; i < operands.getCount(); ++i) - { - // Note: we are hashing the raw pointer value rather - // than the content of the value node. This is done - // to match the semantics implemented for `==` on - // `NodeDesc`. - // - hasher.hashValue(operands[i].values.nodeOperand); - } - hashCode = hasher.getResult(); + return gCurrentASTBuilder; } -DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst) +void setCurrentASTBuilder(ASTBuilder* astBuilder) { - return builder->getSpecializedDeclRef(decl, subst); + gCurrentASTBuilder = astBuilder; } } // namespace Slang diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index cf0975cdd..0d63e1060 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -39,6 +39,11 @@ public: /// Get the `IDifferentiable` type Type* getDiffInterfaceType(); + Type* getErrorType(); + Type* getBottomType(); + Type* getInitializerListType(); + Type* getOverloadedType(); + const ReflectClassInfo* findClassInfo(Name* name); SyntaxClass<NodeBase> findSyntaxClass(Name* name); @@ -65,6 +70,8 @@ public: ~SharedASTBuilder(); + ASTBuilder* getInnerASTBuilder() { return m_astBuilder; } + protected: // State shared between ASTBuilders @@ -108,79 +115,59 @@ protected: class ASTBuilder : public RefObject { friend class SharedASTBuilder; -public: - // Node cache: - struct NodeOperand - { - union - { - NodeBase* nodeOperand; - int64_t intOperand; - } values; - - NodeOperand() - { - values.nodeOperand = nullptr; - } - - NodeOperand(NodeBase* node) { values.nodeOperand = node; } - - template<typename T> - NodeOperand(DeclRef<T> declRef) { values.nodeOperand = declRef.declRefBase; } - - template<typename EnumType> - NodeOperand(EnumType intVal) - { - static_assert(std::is_trivial<EnumType>::value, "Type to construct NodeOperand must be trivial."); - static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size."); - values.intOperand = 0; - memcpy(&values, &intVal, sizeof(intVal)); - } - }; - struct NodeDesc - { - ASTNodeType type; - ShortList<NodeOperand, 4> operands; - - bool operator==(NodeDesc const& that) const; - HashCode getHashCode() const { return hashCode; } - void init(); - private: - HashCode hashCode = 0; - }; +public: template<typename NodeCreateFunc> - NodeBase* _getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc) + NodeBase* _getOrCreateImpl(ValNodeDesc const& desc, NodeCreateFunc createFunc) { if (auto found = m_cachedNodes.tryGetValue(desc)) return *found; auto node = createFunc(); m_cachedNodes.add(desc, node); +#ifdef _DEBUG + _verifyValDescConsistency(dynamicCast<Val>(node), desc); +#endif return node; } /// A cache for AST nodes that are entirely defined by their node type, with /// no need for additional state. - Dictionary<NodeDesc, NodeBase*> m_cachedNodes; + Dictionary<ValNodeDesc, NodeBase*> m_cachedNodes; + + Dictionary<GenericDecl*, List<Val*>> m_cachedGenericDefaultArgs; + + /// Create AST types + template <typename T> + T* createImpl() + { + auto alloced = m_arena.allocate(sizeof(T)); + memset(alloced, 0, sizeof(T)); + auto result = _initAndAdd(new (alloced) T); + return result; + } - template<int N> - static void addOrAppendToNodeList(ShortList<NodeOperand, N>&) - {} + template<typename T, typename... TArgs> + T* createImpl(TArgs&&... args) + { + auto alloced = m_arena.allocate(sizeof(T)); + memset(alloced, 0, sizeof(T)); + auto result = _initAndAdd(new (alloced) T(std::forward<TArgs>(args)...)); + return result; + } - template<int N, typename T, typename... Ts> - static void addOrAppendToNodeList(ShortList<NodeOperand, N>& list, T t, Ts... ts) + template <typename T> + T* create() { - list.add(t); - addOrAppendToNodeList(list, ts...); + static_assert(!IsBaseOf<Val, T>::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead."); + return createImpl<T>(); } - template<int N, typename T, typename... Ts> - static void addOrAppendToNodeList(ShortList<NodeOperand, N>& list, const List<T>& l, Ts... ts ) + template<typename T, typename... TArgs> + T* create(TArgs&&... args) { - for(auto t : l) - list.add(t); - addOrAppendToNodeList(list, ts...); + static_assert(!IsBaseOf<Val, T>::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead."); + return createImpl<T>(args...); } public: @@ -195,37 +182,27 @@ public: }; }; - MemoryArena& getArena() { return m_arena; } + Index getEpoch(); - /// Create AST types - template <typename T> - T* create() - { - auto alloced = m_arena.allocate(sizeof(T)); - memset(alloced, 0, sizeof(T)); - return _initAndAdd(new (alloced) T); - } + void incrementEpoch(); - template<typename T, typename... TArgs> - T* create(TArgs&&... args) - { - auto alloced = m_arena.allocate(sizeof(T)); - memset(alloced, 0, sizeof(T)); - return _initAndAdd(new (alloced) T(std::forward<TArgs>(args)...)); - } + MemoryArena& getArena() { return m_arena; } + + void _verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc); template<typename T, typename ... TArgs> SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args) { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); - NodeDesc desc; + ValNodeDesc desc; desc.type = T::kType; addOrAppendToNodeList(desc.operands, args...); desc.init(); - return (T*)_getOrCreateImpl(desc, [&]() + auto result = (T*)_getOrCreateImpl(desc, [&]() { - return create<T>(args...); + return createImpl<T>(args...); }); + return result; } template<typename T> @@ -233,63 +210,101 @@ public: { SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); - NodeDesc desc; + ValNodeDesc desc; desc.type = T::kType; desc.init(); - return (T*)_getOrCreateImpl(desc, [this]() { return create<T>(); }); + auto result = (T*)_getOrCreateImpl(desc, [this]() { return createImpl<T>(); }); +#ifdef _DEBUG + _verifyValDescConsistency(dynamicCast<Val>(result), desc); +#endif + return result; } - template<typename T, typename ... TArgs> - SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(TArgs ... args) + InterfaceDecl* createInterfaceDecl(SourceLoc loc) { - SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); - NodeDesc desc; - desc.type = T::kType; - addOrAppendToNodeList(desc.operands, args...); - desc.init(); - return (T*)_getOrCreateImpl(desc, [&]() - { - return create<T>(); - }); + auto interfaceDecl = create<InterfaceDecl>(); + // Always include a `This` member and a `This:IThisInterface` member. + auto thisDecl = create<ThisTypeDecl>(); + thisDecl->nameAndLoc.name = m_sharedASTBuilder->getNamePool()->getName(UnownedStringSlice("This", 4)); + thisDecl->nameAndLoc.loc = loc; + interfaceDecl->addMember(thisDecl); + auto thisConstraint = create<ThisTypeConstraintDecl>(); + thisConstraint->loc = loc; + thisConstraint->base.type = DeclRefType::create(this, getDirectDeclRef(interfaceDecl)); + thisDecl->addMember(thisConstraint); + return interfaceDecl; } template<typename T> - SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(ConstArrayView<NodeOperand> operands) + DeclRef<T> getDirectDeclRef(T* decl) { - SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); - NodeDesc desc; - desc.type = T::kType; - desc.operands.addRange(operands); - desc.init(); - return (T*)_getOrCreateImpl(desc, [&]() - { - return create<T>(); - }); + if (!decl) + return DeclRef<T>(); + + auto result = DeclRef<T>(getOrCreate<DirectDeclRef>(decl)); + return result; } - // This is the bottlneck through which all DeclRefs are created. template<typename T> - DeclRef<T> getSpecializedDeclRef(T* decl, Substitutions* subst) + DeclRef<T> getMemberDeclRef(DeclRef<Decl> parent, T* memberDecl) { - // We never create an actual DeclRefBase node to point to a null decl. - if (!decl) - return DeclRef<T>(); - - // If we don't have substitutions, use the default decl ref if it is created. - if (!subst) + if (!parent) + return getDirectDeclRef(memberDecl); + // A Generic value/type ParamDecl is always referred to directly. + if (as<GenericTypeParamDecl>(memberDecl) || as<GenericValueParamDecl>(memberDecl)) + return getDirectDeclRef(memberDecl); + if (as<ThisTypeDecl>(memberDecl) && !as<InterfaceDecl>(memberDecl->parentDecl)) + return as<T>(parent); + + if (auto parentMemberDeclRef = as<MemberDeclRef>(parent.declRefBase)) { - auto defaultDeclRef = static_cast<Decl*>(decl)->defaultDeclRef; - if (defaultDeclRef) - return defaultDeclRef; + return DeclRef<T>(getMemberDeclRef(parentMemberDeclRef->getParent(), memberDecl)); + } + else if (auto lookupDeclRef = as<LookupDeclRef>(parent.declRefBase)) + { + // Handle some specicial case rules due to the way some of our builtin decls are + // represented. + // - Member(Lookup(w, This), x) ==> Lookup(w, X) + // Lookup of x from This is a lookup from w directly. + // - Member(Lookup(w, someExtension), x) ==> Lookup(w, X) + // Lookup of a decl defined in an extension is to lookup directly. + // - Member(Lookup(w, AssociatedType), TypeConstraintDecl) ==> Lookup(w, TypeConstraintDecl) + // Type constraint of an associated type is defined directly in w. + + auto parentDeclKind = lookupDeclRef->getDecl()->astNodeType; + switch (parentDeclKind) + { + case ASTNodeType::ThisTypeDecl: + case ASTNodeType::ExtensionDecl: + case ASTNodeType::AssocTypeDecl: + return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl); + default: + break; + } + } + else if (auto directDeclRef = as<DirectDeclRef>(parent.declRefBase)) + { + return DeclRef<T>(getOrCreate<DirectDeclRef>(memberDecl)); } - return getOrCreate<DeclRefBase>(decl, subst); - } +#if _DEBUG + // Verify that member is indeed a member of parent. + auto parentDecl = parent.getDecl(); + while (as<ThisTypeDecl>(parentDecl)) + parentDecl = parentDecl->parentDecl; + bool foundParent = false; + for (Decl* dd = memberDecl; dd; dd = dd->parentDecl) + { + if (dd == parentDecl) + { + foundParent = true; + break; + } + } + SLANG_ASSERT(foundParent); +#endif - template<typename T> - DeclRef<T> getSpecializedDeclRef(T* decl, SubstitutionSet subst) - { - return getSpecializedDeclRef(decl, subst.substitutions); + return DeclRef<T>(getOrCreate<MemberDeclRef>(memberDecl, parent.declRefBase)); } ConstantIntVal* getIntVal(Type* type, IntegerLiteralValue value) @@ -297,61 +312,38 @@ public: return getOrCreate<ConstantIntVal>(type, value); } - GenericSubstitution* getOrCreateGenericSubstitution(Substitutions* outer, GenericDecl* decl, ArrayView<Val*> args) + DeclRef<Decl> getGenericAppDeclRef(DeclRef<GenericDecl> genericDeclRef, ConstArrayView<Val*> args, Decl* innerDecl = nullptr) { - NodeDesc desc; - desc.type = GenericSubstitution::kType; - desc.operands.add(decl); - for (auto arg : args) - desc.operands.add(arg); - if (outer) - { - desc.operands.add(outer); - } - desc.init(); - auto result = (GenericSubstitution*)_getOrCreateImpl(desc, [this]() {return create<GenericSubstitution>(); }); - if (result->args.getCount() != args.getCount()) - { - SLANG_RELEASE_ASSERT(result->args.getCount() == 0); - result->args.addRange(args); - result->genericDecl = decl; - result->outer = outer; - } - return result; - } + if (!innerDecl) + innerDecl = genericDeclRef.getDecl()->inner; - GenericSubstitution* getOrCreateGenericSubstitution(Substitutions* outer, GenericDecl* decl, const List<Val*>& args) - { - return getOrCreateGenericSubstitution(outer, decl, args.getArrayView()); + return getOrCreate<GenericAppDeclRef>(innerDecl, genericDeclRef, args); } - template<typename... Args> - GenericSubstitution* getOrCreateGenericSubstitution(Substitutions* outer, GenericDecl* decl, Args... args) + DeclRef<Decl> getGenericAppDeclRef(DeclRef<GenericDecl> genericDeclRef, Val::OperandView<Val> args, Decl* innerDecl = nullptr) { - List<Val*> vals; - addToList(vals, args...); - return getOrCreateGenericSubstitution(outer, decl, vals.getArrayView()); - } + if (!innerDecl) + innerDecl = genericDeclRef.getDecl()->inner; + return getOrCreate<GenericAppDeclRef>(innerDecl, genericDeclRef, args); + } - ThisTypeSubstitution* getOrCreateThisTypeSubstitution(InterfaceDecl* interfaceDecl, SubtypeWitness* subtypeWitness, Substitutions* outer) + LookupDeclRef* getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup) { - NodeDesc desc; - desc.type = ThisTypeSubstitution::kType; - desc.operands.add(interfaceDecl); - desc.operands.add(subtypeWitness); - if (outer) - { - desc.operands.add(outer); - } + ValNodeDesc desc; + desc.type = LookupDeclRef::kType; + desc.operands.add(ValNodeOperand(subtypeWitness)); + desc.operands.add(ValNodeOperand(declToLookup)); desc.init(); - auto result = (ThisTypeSubstitution*)_getOrCreateImpl(desc, [this]() {return create<ThisTypeSubstitution>(); }); - result->interfaceDecl = interfaceDecl; - result->witness = subtypeWitness; - result->outer = outer; + auto result = getOrCreate<LookupDeclRef>(declToLookup, base, subtypeWitness); return result; } + LookupDeclRef* getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup) + { + return getLookupDeclRef(subtypeWitness->getSub(), subtypeWitness, declToLookup); + } + NodeBase* createByNodeType(ASTNodeType nodeType); /// Get the built in types @@ -371,11 +363,12 @@ public: SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; } Type* getSpecializedBuiltinType(Type* typeParam, const char* magicTypeName); + Type* getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName); - Type* getInitializerListType() { return m_sharedASTBuilder->m_initializerListType; } - Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; } - Type* getErrorType() { return m_sharedASTBuilder->m_errorType; } - Type* getBottomType() { return m_sharedASTBuilder->m_bottomType; } + Type* getInitializerListType() { return m_sharedASTBuilder->getInitializerListType(); } + Type* getOverloadedType() { return m_sharedASTBuilder->getOverloadedType(); } + Type* getErrorType() { return m_sharedASTBuilder->getErrorType(); } + Type* getBottomType() { return m_sharedASTBuilder->getBottomType(); } Type* getStringType() { return m_sharedASTBuilder->getStringType(); } Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); } Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); } @@ -407,13 +400,18 @@ public: ConstantBufferType* getConstantBufferType(Type* elementType); + ParameterBlockType* getParameterBlockType(Type* elementType); + + HLSLStructuredBufferType* getStructuredBufferType(Type* elementType); + + SamplerStateType* getSamplerStateType(); + DifferentialPairType* getDifferentialPairType( Type* valueType, Witness* primalIsDifferentialWitness); DeclRef<InterfaceDecl> getDifferentiableInterfaceDecl(); Type* getDifferentiableInterfaceType(); - Decl* getDifferentiableAssociatedTypeRequirement(); bool isDifferentiableInterfaceAvailable(); @@ -423,6 +421,7 @@ public: IntVal* maxElementCount); DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg); + DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs); Type* getAndType(Type* left, Type* right); @@ -435,9 +434,9 @@ public: Val* getSNormModifierVal(); Val* getNoDiffModifierVal(); - Type* getTupleType(List<Type*>& types); + TupleType* getTupleType(List<Type*>& types); - Type* getFuncType(List<Type*> parameters, Type* result); + FuncType* getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType = nullptr); TypeType* getTypeType(Type* type); @@ -445,7 +444,7 @@ public: TypeEqualityWitness* getTypeEqualityWitness( Type* type); - SubtypeWitness* getDeclaredSubtypeWitness( + DeclaredSubtypeWitness* getDeclaredSubtypeWitness( Type* subType, Type* superType, DeclRef<Decl> const& declRef); @@ -455,9 +454,6 @@ public: SubtypeWitness* aIsSubtypeOfBWitness, SubtypeWitness* bIsSubtypeOfCWitness); - /// Produce a witness that `ThisType(IFoo) <: IFoo`. - ThisTypeSubtypeWitness* getThisTypeSubtypeWitness(Type* subType, Type* superType); - /// Produce a witness that `T <: L` or `T <: R` given `T <: L&R` SubtypeWitness* getExtractFromConjunctionSubtypeWitness( Type* subType, @@ -487,14 +483,14 @@ public: /// Get the global session Session* getGlobalSession() { return m_sharedASTBuilder->m_session; } + Index getId() { return m_id; } + /// Ctor ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name); /// Dtor ~ASTBuilder(); - Dictionary<Decl*, GenericSubstitution*> m_genericDefaultSubst; - protected: // Special default Ctor that can only be used by SharedASTBuilder ASTBuilder(); @@ -512,11 +508,12 @@ protected: // Keep such that dtor can be run on ASTBuilder being dtored m_dtorNodes.add(node); } - if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType))) + if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Val::kType))) { - auto decl = (Decl*)(node); - decl->defaultDeclRef = getSpecializedDeclRef(decl, nullptr); + auto val = (Val*)(node); + val->m_resolvedValEpoch = getEpoch(); } + return node; } @@ -529,9 +526,30 @@ protected: SharedASTBuilder* m_sharedASTBuilder; MemoryArena m_arena; +}; + +// Retrieves the ASTBuilder for the current compilation session. +ASTBuilder* getCurrentASTBuilder(); +// Sets the ASTBuilder for the current compilation session. +void setCurrentASTBuilder(ASTBuilder* astBuilder); + +struct SetASTBuilderContextRAII +{ + ASTBuilder* previousASTBuilder = nullptr; + SetASTBuilderContextRAII(ASTBuilder* astBuilder) + { + previousASTBuilder = getCurrentASTBuilder(); + setCurrentASTBuilder(astBuilder); + } + ~SetASTBuilderContextRAII() + { + setCurrentASTBuilder(previousASTBuilder); + } }; +#define SLANG_AST_BUILDER_RAII(astBuilder) SetASTBuilderContextRAII _setASTBuilderContextRAII(astBuilder) + } // namespace Slang #endif diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp new file mode 100644 index 000000000..4384a6df9 --- /dev/null +++ b/source/slang/slang-ast-decl-ref.cpp @@ -0,0 +1,461 @@ +#include "slang-ast-builder.h" +#include "slang-ast-reflect.h" +#include "slang-generated-ast.h" +#include "slang-generated-ast-macro.h" +#include "slang-check-impl.h" + +namespace Slang +{ + +DeclRefBase* DirectDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + SLANG_UNUSED(astBuilder); + SLANG_UNUSED(subst); + SLANG_UNUSED(ioDiff); + return this; +} + +void DirectDeclRef::_toTextOverride(StringBuilder& out) +{ + if (getDecl()->getName() && getDecl()->getName()->text.getLength() != 0) + { + out << getDecl()->getName()->text; + } +} + +Val* DirectDeclRef::_resolveImplOverride() +{ + return this; +} + +DeclRefBase* DirectDeclRef::_getBaseOverride() +{ + return nullptr; +} + +DeclRefBase* _getDeclRefFromVal(Val* val) +{ + if (auto declRefType = as<DeclRefType>(val)) + return declRefType->getDeclRef(); + else if (auto genParamIntVal = as<GenericParamIntVal>(val)) + return genParamIntVal->getDeclRef(); + else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val)) + return declaredSubtypeWitness->getDeclRef(); + else if (auto declRef = as<DeclRefBase>(val)) + return declRef; + return nullptr; +} + +DeclRefBase* _resolveAsDeclRef(DeclRefBase* declRefToResolve) +{ + if (auto rs = _getDeclRefFromVal(declRefToResolve->resolve())) + return rs; + return declRefToResolve; +} + +DeclRefBase* MemberDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto substParent = getParentOperand()->substituteImpl(astBuilder, subst, &diff); + if (diff) + { + (*ioDiff)++; + return astBuilder->getMemberDeclRef(substParent, getDecl()); + } + return this; +} + +void MemberDeclRef::_toTextOverride(StringBuilder& out) +{ + getParentOperand()->toText(out); + if (out.getLength() && !out.endsWith(".")) + out << "."; + if (getDecl()->getName() && getDecl()->getName()->text.getLength() != 0) + { + out << getDecl()->getName()->text; + } +} + +Val* MemberDeclRef::_resolveImplOverride() +{ + auto resolvedParent = _resolveAsDeclRef(getParentOperand()); + if (resolvedParent != getParentOperand()) + { + return getCurrentASTBuilder()->getMemberDeclRef(resolvedParent, getDecl()); + } + return this; +} + +DeclRefBase* MemberDeclRef::_getBaseOverride() +{ + return getParentOperand(); +} + +Decl* LookupDeclRef::getSupDecl() +{ + if (auto supType = as<DeclRefType>(getWitness()->getSup())) + { + return supType->getDeclRef().getDecl(); + } + // If we reach here, something is wrong. + SLANG_UNEXPECTED("Invalid lookup declref"); +} + +DeclRefBase* LookupDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substWitness = as<SubtypeWitness>(getWitness()->substituteImpl(astBuilder, subst, &diff)); + if (diff == 0) + return this; + (*ioDiff)++; + + auto substSource = as<Type>(getLookupSource()->substituteImpl(astBuilder, subst, &diff)); + SLANG_ASSERT(substSource); + + if (auto resolved = _getDeclRefFromVal(tryResolve(substWitness, substSource))) + return resolved; + + return astBuilder->getLookupDeclRef(substSource, substWitness, getDecl()); +} + +void LookupDeclRef::_toTextOverride(StringBuilder& out) +{ + getLookupSource()->toText(out); + if (out.getLength() && !out.endsWith(".")) + out << "."; + if (getDecl()->getName() && getDecl()->getName()->text.getLength() != 0) + { + out << getDecl()->getName()->text; + } +} + +Val* LookupDeclRef::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + Val* resolved = this; + + auto newLookupSource = as<Type>(getLookupSource()->resolve()); + SLANG_ASSERT(newLookupSource); + + auto newWitness = as<SubtypeWitness>(getWitness()->resolve()); + SLANG_ASSERT(newWitness); + + if (auto resolvedVal = tryResolve(newWitness, newLookupSource)) + return resolvedVal; + if (newLookupSource != getLookupSource() || newWitness != getWitness()) + resolved = astBuilder->getLookupDeclRef(newLookupSource, newWitness, getDecl()); + return resolved; +} + +DeclRefBase* LookupDeclRef::_getBaseOverride() +{ + return nullptr; +} + +Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource) +{ + auto astBuilder = getCurrentASTBuilder(); + Decl* requirementKey = getDecl(); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, newWitness, requirementKey); + switch (requirementWitness.getFlavor()) + { + default: + // No usable value was found, so there is nothing we can do. + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + break; + } + + // Hard code implementation of T.Differential.Differential == T.Differential rule. + auto builtinReq = requirementKey->findModifier<BuiltinRequirementModifier>(); + bool isConstraint = false; + if (!builtinReq) + { + if (auto parentAssocType = as<AssocTypeDecl>(requirementKey->parentDecl)) + { + builtinReq = parentAssocType->findModifier<BuiltinRequirementModifier>(); + isConstraint = true; + } + if (!builtinReq) + return nullptr; + } + if (builtinReq->kind != BuiltinRequirementKind::DifferentialType) + return nullptr; + // Is the concrete type a Differential associated type? + auto innerDeclRefType = as<DeclRefType>(newLookupSource); + if (!innerDeclRefType) + return nullptr; + auto innerBuiltinReq = innerDeclRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>(); + if (!innerBuiltinReq) + return nullptr; + if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType) + return nullptr; + if (isConstraint) + return newWitness; + if (innerDeclRefType->getDeclRef() != this) + { + auto result = innerDeclRefType->getDeclRef().declRefBase->resolve(); + if (result) + return result; + } + return innerDeclRefType; +} + +DeclRefBase* GenericAppDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto substGenericDeclRef = getGenericDeclRef()->substituteImpl(astBuilder, subst, &diff); + List<Val*> substArgs; + for (auto arg : getArgs()) + { + substArgs.add(arg->substituteImpl(astBuilder, subst, &diff)); + } + if (diff == 0) + return this; + (*ioDiff)++; + return astBuilder->getGenericAppDeclRef(substGenericDeclRef, substArgs.getArrayView(), getDecl()); +} + +GenericDecl* GenericAppDeclRef::getGenericDecl() { return as<GenericDecl>(getGenericDeclRef()->getDecl()); } + + +void GenericAppDeclRef::_toTextOverride(StringBuilder& out) +{ + auto genericDecl = as<GenericDecl>(getGenericDeclRef()->getDecl()); + Index paramCount = 0; + for (auto member : genericDecl->members) + if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member)) + paramCount++; + getGenericDeclRef()->toText(out); + out << "<"; + auto args = getArgs(); + Index argCount = args.getCount(); + for (Index aa = 0; aa < Math::Min(paramCount, argCount); ++aa) + { + if (aa != 0) out << ", "; + args[aa]->toText(out); + } + out << ">"; +} + +Val* GenericAppDeclRef::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + Val* resolvedVal = this; + auto resolvedGenericDeclRef = _resolveAsDeclRef(getGenericDeclRef()); + bool diff = false; + if (resolvedGenericDeclRef != getGenericDeclRef()) + diff = true; + List<Val*> resolvedArgs; + for (auto arg : getArgs()) + { + auto resolvedArg = arg->resolve(); + resolvedArgs.add(resolvedArg); + if (resolvedArg != arg) + diff = true; + } + if (diff) + resolvedVal = astBuilder->getGenericAppDeclRef(resolvedGenericDeclRef, resolvedArgs.getArrayView(), getDecl()); + return resolvedVal; +} + +DeclRefBase* GenericAppDeclRef::_getBaseOverride() +{ + return getGenericDeclRef(); +} + +// Convenience accessors for common properties of declarations + +DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, substituteImpl, (astBuilder, subst, ioDiff)); +} + +DeclRefBase* DeclRefBase::getBase() { SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, getBase, ()); } +void DeclRefBase::toText(StringBuilder& out) { SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, toText, (out)); } + +Name* DeclRefBase::getName() const +{ + return getDecl()->nameAndLoc.name; +} + +SourceLoc DeclRefBase::getNameLoc() const +{ + return getDecl()->nameAndLoc.loc; +} + +SourceLoc DeclRefBase::getLoc() const +{ + return getDecl()->loc; +} + +DeclRefBase* DeclRefBase::getParent() +{ + auto astBuilder = getCurrentASTBuilder(); + if (!getDecl()->parentDecl) + return nullptr; + auto parentDecl = getDecl()->parentDecl; + for (auto base = getBase(); base; base = base->getBase()) + { + if (base->getDecl() == parentDecl) + return base; + bool parentIsChildOfBase = false; + for (auto dd = parentDecl->parentDecl; dd; dd = dd->parentDecl) + { + if (dd == base->getDecl()) + { + parentIsChildOfBase = true; + break; + } + } + if (parentIsChildOfBase) + return astBuilder->getMemberDeclRef(base, parentDecl); + } + return astBuilder->getDirectDeclRef(parentDecl); +} + +SubstitutionSet::operator bool() const +{ + return declRef != nullptr && !as<DirectDeclRef>(declRef); +} + +Val::OperandView<Val> tryGetGenericArguments(SubstitutionSet substSet, Decl* genericDecl) +{ + if (!substSet.declRef) + return Val::OperandView<Val>(); + + DeclRefBase* currentDeclRef = substSet.declRef; + // search for a substitution that might apply to us + for (auto s = currentDeclRef; s; s = s->getBase()) + { + auto genericAppDeclRef = as<GenericAppDeclRef>(s); + if (!genericAppDeclRef) + continue; + + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto parentGeneric = genericAppDeclRef->getGenericDecl(); + if (parentGeneric != genericDecl) + continue; + + return genericAppDeclRef->getArgs(); + } + return Val::OperandView<Val>(); +} + +Type* SubstitutionSet::applyToType(ASTBuilder* astBuilder, Type* type) const +{ + if (!type) + return nullptr; + int diff = 0; + auto newType = as<Type>(type->substituteImpl(astBuilder, *this, &diff)); + if (diff && newType) + return newType; + return type; +} + +SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr) +{ + return SubstExpr<Expr>(expr, substSet); +} + + +DeclRefBase* SubstitutionSet::applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* otherDeclRef) const +{ + int diff = 0; + return otherDeclRef->substituteImpl(astBuilder, *this, &diff); +} + +LookupDeclRef* SubstitutionSet::findLookupDeclRef() const +{ + for (auto s = declRef; s; s = s->getBase()) + { + if (auto lookupDeclRef = as<LookupDeclRef>(s)) + return lookupDeclRef; + } + return nullptr; +} + +DeclRefBase* SubstitutionSet::getInnerMostNodeWithSubstInfo() const +{ + for (auto s = declRef; s; s = s->getBase()) + { + if (as<LookupDeclRef>(s) || as<GenericAppDeclRef>(s)) + return s; + } + return nullptr; +} + + +GenericAppDeclRef* SubstitutionSet::findGenericAppDeclRef(GenericDecl* genericDecl) const +{ + for (auto s = declRef; s; s = s->getBase()) + { + if (auto genApp = as<GenericAppDeclRef>(s)) + { + if (genApp->getGenericDecl() == genericDecl) + return genApp; + } + } + return nullptr; +} + +GenericAppDeclRef* SubstitutionSet::findGenericAppDeclRef() const +{ + for (auto s = declRef; s; s = s->getBase()) + { + if (auto genApp = as<GenericAppDeclRef>(s)) + { + return genApp; + } + } + return nullptr; +} + +DeclRef<Decl> createDefaultSubstitutionsIfNeeded( + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, + DeclRef<Decl> declRef) +{ + if (declRef.as<GenericTypeParamDecl>()) + return declRef; + if (declRef.as<GenericValueParamDecl>()) + return declRef; + if (declRef.as<GenericTypeConstraintDecl>()) + return declRef; + ShortList<GenericDecl*> genericParentDecls; + auto lastSubstNode = SubstitutionSet(declRef).getInnerMostNodeWithSubstInfo(); + auto lastGenApp = as<GenericAppDeclRef>(lastSubstNode); + for (auto dd = declRef.getDecl()->parentDecl; dd; dd = dd->parentDecl) + { + if (lastGenApp && dd == lastGenApp->getGenericDecl()) + break; + if (auto gen = as<GenericDecl>(dd)) + genericParentDecls.add(gen); + } + DeclRef<Decl> parentDeclRef = lastSubstNode; + for (auto i = genericParentDecls.getCount() - 1; i >= 0; i--) + { + auto current = genericParentDecls[i]; + auto args = getDefaultSubstitutionArgs(astBuilder, semantics, current); + if (parentDeclRef) + { + parentDeclRef = astBuilder->getMemberDeclRef(parentDeclRef, current); + } + else + { + parentDeclRef = astBuilder->getDirectDeclRef(current); + } + parentDeclRef = astBuilder->getGenericAppDeclRef(parentDeclRef.as<GenericDecl>(), args.getArrayView()); + } + if (parentDeclRef.getDecl() == declRef.getDecl()) + return parentDeclRef; + return astBuilder->getMemberDeclRef(parentDeclRef, declRef.getDecl()); +} +} diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index 2f1c7c47e..9dbd006a0 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -4,6 +4,7 @@ #include <assert.h> #include "slang-generated-ast-macro.h" +#include "slang-ast-decl.h" namespace Slang { @@ -118,4 +119,21 @@ bool isLocalVar(const Decl* decl) return false; } +ThisTypeDecl* InterfaceDecl::getThisTypeDecl() +{ + for (auto member : members) + { + if (auto thisTypeDeclCandidate = as<ThisTypeDecl>(member)) + { + return thisTypeDeclCandidate; + } + } + SLANG_UNREACHABLE("InterfaceDecl does not have a ThisType decl."); +} + +InterfaceDecl* ThisTypeConstraintDecl::getInterfaceDecl() +{ + return as<InterfaceDecl>(parentDecl->parentDecl); +} + } // namespace Slang diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 61e623366..8266d77c7 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -56,6 +56,15 @@ class ContainerDecl: public Decl return transparentMembers; } + void addMember(Decl* member) + { + if (member) + { + member->parentDecl = this; + members.add(member); + } + } + SLANG_UNREFLECTED // We don't want to reflect the following fields private: @@ -178,12 +187,19 @@ class EnumCaseDecl : public Decl Expr* tagExpr = nullptr; }; +// A member of InterfaceDecl representing the abstract ThisType. +class ThisTypeDecl : public AggTypeDecl +{ + SLANG_AST_CLASS(ThisTypeDecl) +}; + // An interface which other types can conform to class InterfaceDecl : public AggTypeDecl { SLANG_AST_CLASS(InterfaceDecl) -}; + ThisTypeDecl* getThisTypeDecl(); +}; class TypeConstraintDecl : public Decl { @@ -195,6 +211,15 @@ class TypeConstraintDecl : public Decl const TypeExp& _getSupOverride() const; }; +class ThisTypeConstraintDecl : public TypeConstraintDecl +{ + SLANG_AST_CLASS(ThisTypeConstraintDecl) + + TypeExp base; + const TypeExp& _getSupOverride() const { return base; } + InterfaceDecl* getInterfaceDecl(); +}; + // A kind of pseudo-member that represents an explicit // or implicit inheritance relationship. // diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 0ab440a18..65718833b 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -65,18 +65,6 @@ struct ASTDumpContext } } - void dump(Substitutions* subs) - { - if (subs == nullptr) - { - _dumpPtr(nullptr); - } - else - { - dumpObject(subs->getClassInfo(), subs); - } - } - void dump(const Name* name) { if (name == nullptr) @@ -608,6 +596,40 @@ struct ASTDumpContext m_writer->emit("\n"); } + template<int N> + void dump(const ShortList<ValNodeOperand, N>& operands) + { + m_writer->emit("("); + bool isFirst = true; + for (auto operand : operands) + { + if (!isFirst) + { + m_writer->emit(", "); + } + isFirst = false; + dumpField("operand", operand); + } + + m_writer->emit(")"); + } + + void dump(ValNodeOperand operand) + { + switch (operand.kind) + { + case ValNodeOperandKind::ConstantValue: + dump(operand.values.intOperand); + break; + case ValNodeOperandKind::ValNode: + dump(operand.values.nodeOperand); + break; + case ValNodeOperandKind::ASTNode: + dump(operand.values.nodeOperand); + break; + } + } + void dump(ASTNodeType nodeType) { // Get the class @@ -616,6 +638,15 @@ struct ASTDumpContext m_writer->emit(info->m_name); } + void dump(KeyValuePair<DeclRefBase*, SubtypeWitness*> pair) + { + m_writer->emit("("); + dump(pair.key); + m_writer->emit(", "); + dump(pair.value); + m_writer->emit(")"); + } + void dumpObjectFull(NodeBase* node); ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle): diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 07bf2f033..28ce2e4d1 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -560,17 +560,6 @@ class TreatAsDifferentiableExpr : public Expr Flavor flavor; }; - /// A type expression of the form `__TaggedUnion(A, ...)`. - /// - /// An expression of this form will resolve to a `TaggedUnionType` - /// when checked. - /// -class TaggedUnionTypeExpr: public Expr -{ - SLANG_AST_CLASS(TaggedUnionTypeExpr) - List<TypeExp> caseTypes; -}; - /// A type expression of the form `This` /// /// Refers to the type of `this` in the current context. @@ -639,7 +628,7 @@ public: DeclRef<GenericDecl> baseGenericDeclRef; /// A substitution that includes the generic arguments known so far - GenericSubstitution* substWithKnownGenericArgs = nullptr; + List<Val*> knownGenericArgs; }; } // namespace Slang diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index ea3db6937..fb3d50b4f 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -99,11 +99,6 @@ struct ASTIterator dispatchIfNotNull(expr->base.exp); } - void visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) - { - iterator->maybeDispatchCallback(expr); - } - void visitInvokeExpr(InvokeExpr* expr) { iterator->maybeDispatchCallback(expr); diff --git a/source/slang/slang-ast-modifier.cpp b/source/slang/slang-ast-modifier.cpp index 3daa9b056..84046a601 100644 --- a/source/slang/slang-ast-modifier.cpp +++ b/source/slang/slang-ast-modifier.cpp @@ -4,5 +4,10 @@ namespace Slang { - +const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& DifferentiableAttribute::getMapTypeToIDifferentiableWitness() +{ + for (Index i = m_mapToIDifferentiableWitness.getCount(); i < m_typeToIDifferentiableWitnessMappings.getCount(); i++) + m_mapToIDifferentiableWitness.add(m_typeToIDifferentiableWitnessMappings[i].key, m_typeToIDifferentiableWitnessMappings[i].value); + return m_mapToIDifferentiableWitness; +} } // namespace Slang diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index fd317a2c2..8e7cc9193 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -600,7 +600,7 @@ class Attribute : public AttributeBase { SLANG_AST_CLASS(Attribute) - AttributeArgumentValueDict intArgVals; + List<Val*> intArgVals; }; class UserDefinedAttribute : public Attribute @@ -1054,10 +1054,23 @@ class DifferentiableAttribute : public Attribute { SLANG_AST_CLASS(DifferentiableAttribute) + List<KeyValuePair<DeclRefBase*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings; + + void addType(DeclRefBase* declRef, SubtypeWitness* witness) + { + getMapTypeToIDifferentiableWitness(); + if (m_mapToIDifferentiableWitness.addIfNotExists(declRef, witness)) + { + m_typeToIDifferentiableWitnessMappings.add(KeyValuePair<DeclRefBase*, SubtypeWitness*>(declRef, witness)); + } + } + /// Mapping from types to subtype witnesses for conformance to IDifferentiable. - OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; + const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& getMapTypeToIDifferentiableWitness(); SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet; +private: + OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapToIDifferentiableWitness; }; class DllImportAttribute : public Attribute diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp index 1789c5cea..4a4ef37fb 100644 --- a/source/slang/slang-ast-natural-layout.cpp +++ b/source/slang/slang-ast-natural-layout.cpp @@ -70,9 +70,9 @@ Count ASTNaturalLayoutContext::_getCount(IntVal* intVal) { if (auto constIntVal = as<ConstantIntVal>(intVal)) { - if (constIntVal->value >= 0) + if (constIntVal->getValue() >= 0) { - return Count(constIntVal->value); + return Count(constIntVal->getValue()); } } @@ -115,9 +115,9 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) { if (VectorExpressionType* vecType = as<VectorExpressionType>(type)) { - const Count elementCount = _getCount(vecType->elementCount); + const Count elementCount = _getCount(vecType->getElementCount()); return (elementCount > 0) ? - calcSize(vecType->elementType) * elementCount : + calcSize(vecType->getElementType()) * elementCount : NaturalSize::makeInvalid(); } else if (auto matType = as<MatrixExpressionType>(type)) @@ -130,7 +130,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) } else if (auto basicType = as<BasicExpressionType>(type)) { - return NaturalSize::makeFromBaseType(basicType->baseType); + return NaturalSize::makeFromBaseType(basicType->getBaseType()); } else if (as<PtrTypeBase>(type) || as<NullPtrType>(type)) { @@ -146,7 +146,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) } else if (auto namedType = as<NamedExpressionType>(type)) { - return calcSize(namedType->innerType); + return calcSize(namedType->getCanonicalType()); } else if (const auto tupleType = as<TupleType>(type)) { @@ -154,9 +154,9 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) NaturalSize size = NaturalSize::makeEmpty(); // Accumulate over all the member types - for (auto cur : tupleType->memberTypes) + for (auto cur = 0; cur < tupleType->getMemberCount(); cur++) { - const auto curSize = calcSize(cur); + const auto curSize = calcSize(tupleType->getMember(cur)); if (!curSize) { return NaturalSize::makeInvalid(); @@ -166,36 +166,14 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) return size; } - else if (const auto taggedUnion = as<TaggedUnionType>(type)) - { - NaturalSize size = NaturalSize::makeInvalid(); - - for( auto caseType : taggedUnion->caseTypes ) - { - const NaturalSize caseSize = calcSize(caseType); - if (!caseSize) - { - return NaturalSize::makeInvalid(); - } - size = NaturalSize::calcUnion(size, caseSize); - } - - // After we've computed the size required to hold all the - // case types, we will allocate space for the tag field. - - // Currently we assume uint32_t on all targets - size.append(NaturalSize::makeFromBaseType(BaseType::UInt)); - - return size; - } else if( auto declRefType = as<DeclRefType>(type) ) { - if (const auto enumDeclRef = declRefType->declRef.as<EnumDecl>()) + if (const auto enumDeclRef = declRefType->getDeclRef().as<EnumDecl>()) { Type* tagType = getTagType(m_astBuilder, enumDeclRef); return calcSize(tagType); } - else if(const auto structDeclRef = declRefType->declRef.as<StructDecl>()) + else if(const auto structDeclRef = declRefType->getDeclRef().as<StructDecl>()) { // Poison the cache whilst we construct m_typeToSize.add(type, NaturalSize::makeInvalid()); @@ -208,7 +186,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) // Look for a struct type that it inherits from if (auto inheritedDeclRef = as<DeclRefType>(inherited->base.type)) { - if (auto parentDecl = inheritedDeclRef->declRef.as<StructDecl>()) + if (auto parentDecl = inheritedDeclRef->getDeclRef().as<StructDecl>()) { // We can only inherit from one thing size = calcSize(inherited->base.type); @@ -237,7 +215,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) return size; } - else if (const auto typeDef = declRefType->declRef.as<TypeDefDecl>()) + else if (const auto typeDef = declRefType->getDeclRef().as<TypeDefDecl>()) { return calcSize(typeDef.getDecl()->type); } diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 84c521108..b80afeee1 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -36,12 +36,12 @@ void ASTPrinter::addType(Type* type) { if (auto vectorType = as<VectorExpressionType>(type)) { - if (as<BasicExpressionType>(vectorType->elementType)) + if (as<BasicExpressionType>(vectorType->getElementType())) { - vectorType->elementType->toText(m_builder); - if (as<ConstantIntVal>(vectorType->elementCount)) + vectorType->getElementType()->toText(m_builder); + if (as<ConstantIntVal>(vectorType->getElementCount())) { - m_builder << vectorType->elementCount; + m_builder << vectorType->getElementCount(); return; } } @@ -107,14 +107,14 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) auto& sb = m_builder; // Find the parent declaration - auto parentDeclRef = declRef.getParent(m_astBuilder); + auto parentDeclRef = declRef.getParent(); // If the immediate parent is a generic, then we probably // want the declaration above that... auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>(); if (parentGenericDeclRef) { - parentDeclRef = parentGenericDeclRef.getParent(m_astBuilder); + parentDeclRef = parentGenericDeclRef.getParent(); } // Depending on what the parent is, we may want to format things specially @@ -172,12 +172,9 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) !declRef.as<GenericValueParamDecl>() && !declRef.as<GenericTypeParamDecl>()) { - auto genSubst = as<GenericSubstitution>(declRef.getSubst()); - if (genSubst) + auto substArgs = tryGetGenericArguments(SubstitutionSet(declRef), parentGenericDeclRef.getDecl()); + if (substArgs.getCount()) { - SLANG_RELEASE_ASSERT(genSubst); - SLANG_RELEASE_ASSERT(genSubst->getGenericDecl() == parentGenericDeclRef.getDecl()); - // If the name we printed previously was an operator // that ends with `<`, then immediately printing the // generic arguments inside `<...>` may cause it to @@ -193,7 +190,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) sb << "<"; bool first = true; - for (auto arg : genSubst->getArgs()) + for (auto arg : substArgs) { // When printing the representation of a specialized // generic declaration we don't want to include the @@ -331,7 +328,7 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* { addGenericParams(genericDeclRef); - addDeclParams(m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.getSubst()), outParamRange); + addDeclParams(m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner), outParamRange); } else { @@ -443,7 +440,7 @@ void ASTPrinter::addDeclResultType(const DeclRef<Decl>& inDeclRef) DeclRef<Decl> declRef = inDeclRef; if (auto genericDeclRef = declRef.as<GenericDecl>()) { - declRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.getSubst()); + declRef = m_astBuilder->getMemberDeclRef<Decl>(genericDeclRef, genericDeclRef.getDecl()->inner); } if (declRef.as<ConstructorDecl>()) diff --git a/source/slang/slang-ast-reflect.cpp b/source/slang/slang-ast-reflect.cpp index b16568d2e..66e57a744 100644 --- a/source/slang/slang-ast-reflect.cpp +++ b/source/slang/slang-ast-reflect.cpp @@ -39,7 +39,7 @@ struct ASTConstructAccess static void* create(void* context) { ASTBuilder* astBuilder = (ASTBuilder*)context; - return astBuilder->create<T>(); + return astBuilder->createImpl<T>(); } static void destroy(void* ptr) { diff --git a/source/slang/slang-ast-substitutions.cpp b/source/slang/slang-ast-substitutions.cpp deleted file mode 100644 index 7b052522e..000000000 --- a/source/slang/slang-ast-substitutions.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// slang-ast-substitutions.cpp -#include "slang-ast-builder.h" -#include <assert.h> - -#include "slang-generated-ast-macro.h" - -namespace Slang { - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Substitutions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -Substitutions* Substitutions::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff) -{ - SLANG_AST_NODE_VIRTUAL_CALL(Substitutions, applySubstitutionsShallow, (astBuilder, substSet, substOuter, ioDiff)) -} - -bool Substitutions::equals(Substitutions* subst) -{ - SLANG_AST_NODE_VIRTUAL_CALL(Substitutions, equals, (subst)) -} - -HashCode Substitutions::getHashCode() const -{ - SLANG_AST_NODE_CONST_VIRTUAL_CALL(Substitutions, getHashCode, ()) -} - -Substitutions* Substitutions::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff) -{ - SLANG_UNUSED(astBuilder); - SLANG_UNUSED(substSet); - SLANG_UNUSED(substOuter); - SLANG_UNUSED(ioDiff); - SLANG_UNEXPECTED("Substitutions::_applySubstitutionsShallowOverride not overridden"); - //return Substitutions*(); -} - -bool Substitutions::_equalsOverride(Substitutions* subst) -{ - SLANG_UNUSED(subst); - SLANG_UNEXPECTED("Substitutions::_equalsOverride not overridden"); - //return false; -} - -HashCode Substitutions::_getHashCodeOverride() const -{ - SLANG_UNEXPECTED("Substitutions::_getHashCodeOverride not overridden"); - //return HashCode(0); -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -Substitutions* GenericSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff) -{ - int diff = 0; - - if (substOuter != outer) diff++; - - List<Val*> substArgs; - for (auto a : args) - { - substArgs.add(a->substituteImpl(astBuilder, substSet, &diff)); - } - - if (!diff) return this; - - (*ioDiff)++; - - auto substSubst = astBuilder->getOrCreateGenericSubstitution(substOuter, genericDecl, substArgs); - return substSubst; -} - -bool GenericSubstitution::_equalsOverride(Substitutions* subst) -{ - // both must be NULL, or non-NULL - if (subst == nullptr) - return false; - if (this == subst) - return true; - - auto genericSubst = as<GenericSubstitution>(subst); - if (!genericSubst) - return false; - if (genericDecl != genericSubst->genericDecl) - return false; - - Index argCount = args.getCount(); - SLANG_RELEASE_ASSERT(args.getCount() == genericSubst->args.getCount()); - for (Index aa = 0; aa < argCount; ++aa) - { - if (!args[aa]->equalsVal(genericSubst->args[aa])) - return false; - } - - if (!outer) - return !genericSubst->outer; - - if (!outer->equals(genericSubst->outer)) - return false; - - return true; -} - -HashCode GenericSubstitution::_getHashCodeOverride() const -{ - HashCode rs = 0; - for (auto && v : args) - { - rs ^= v->getHashCode(); - rs *= 16777619; - } - return rs; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisTypeSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -Substitutions* ThisTypeSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff) -{ - int diff = 0; - - if (substOuter != outer) diff++; - - // NOTE: Must use .as because we must have a smart pointer here to keep in scope. - auto substWitness = as<SubtypeWitness>(witness->substituteImpl(astBuilder, substSet, &diff)); - - if (!diff) return this; - - (*ioDiff)++; - ThisTypeSubstitution* substSubst; - - substSubst = astBuilder->getOrCreateThisTypeSubstitution(interfaceDecl, substWitness, substOuter); - return substSubst; -} - -bool ThisTypeSubstitution::_equalsOverride(Substitutions* subst) -{ - if (!subst) - return false; - if (subst == this) - return true; - - if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) - { - // For our purposes, two this-type substitutions are - // equivalent if they have the same type as `This`, - // even if the specific witness values they use - // might differ. - // - if (this->interfaceDecl != thisTypeSubst->interfaceDecl) - return false; - - if (!this->witness->sub->equals(thisTypeSubst->witness->sub)) - return false; - - return true; - } - return false; -} - -HashCode ThisTypeSubstitution::_getHashCodeOverride() const -{ - return witness->sub->getHashCode(); -} - -} // namespace Slang diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index a3df25ce9..6a957e427 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -68,4 +68,5 @@ UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr) return UnownedStringSlice("bwd_diff"); return UnownedStringSlice(); } + } diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 4765d11ec..9140be967 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -28,7 +28,6 @@ namespace Slang class Module; class Name; class Session; - class Substitutions; class SyntaxVisitor; class FuncDecl; class Layout; @@ -50,6 +49,8 @@ namespace Slang class Val; class NodeBase; + class LookupDeclRef; + class GenericAppDeclRef; template <typename T> @@ -625,19 +626,27 @@ namespace Slang struct SubstitutionSet { - Substitutions* substitutions = nullptr; - operator Substitutions*() const - { - return substitutions; - } + DeclRefBase* declRef = nullptr; + SubstitutionSet() = default; + SubstitutionSet(DeclRefBase* declRefBase) + :declRef(declRefBase) + {} + explicit operator bool() const; + + template<typename F> + void forEachGenericSubstitution(F func) const; + + template<typename F> + void forEachSubstitutionArg(F func) const; + + Type* applyToType(ASTBuilder* astBuilder, Type* type) const; + DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; + + LookupDeclRef* findLookupDeclRef() const; + GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; + GenericAppDeclRef* findGenericAppDeclRef() const; + DeclRefBase* getInnerMostNodeWithSubstInfo() const; - SubstitutionSet() {} - SubstitutionSet(Substitutions* subst) - : substitutions(subst) - { - } - bool equals(const SubstitutionSet& substSet) const; - HashCode getHashCode() const; }; /// An expression together with (optional) substutions to apply to it @@ -741,6 +750,8 @@ namespace Slang } }; + SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr); + class ASTBuilder; template<typename T> @@ -752,7 +763,6 @@ namespace Slang // try to find the concrete decl that satisfies the associatedtype requirement from the // concrete type supplied by ThisTypeSubstittution. Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef); - void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out); template<typename T = Decl> struct DeclRef @@ -780,13 +790,12 @@ namespace Slang {} T* getDecl() const; - Substitutions* getSubst() const; Name* getName() const; SourceLoc getNameLoc() const; SourceLoc getLoc() const; - DeclRef<ContainerDecl> getParent(ASTBuilder* astBuilder) const; + DeclRef<ContainerDecl> getParent() const; HashCode getHashCode() const; Type* substitute(ASTBuilder* astBuilder, Type* type) const; @@ -823,7 +832,10 @@ namespace Slang } template<typename U> - bool equals(DeclRef<U> other) const; + bool equals(DeclRef<U> other) const + { + return declRefBase == other.declRefBase; + } template<typename U> bool operator == (DeclRef<U> other) const @@ -979,17 +991,17 @@ namespace Slang struct FilteredMemberRefList { List<Decl*> const& m_decls; - SubstitutionSet m_substitutions; + DeclRef<Decl> m_parent; MemberFilterStyle m_filterStyle; ASTBuilder* m_astBuilder; FilteredMemberRefList( ASTBuilder* astBuilder, List<Decl*> const& decls, - SubstitutionSet substitutions, + DeclRef<Decl> parent, MemberFilterStyle filterStyle = MemberFilterStyle::All) : m_decls(decls) - , m_substitutions(substitutions) + , m_parent(parent) , m_filterStyle(filterStyle) , m_astBuilder(astBuilder) {} @@ -1007,7 +1019,7 @@ namespace Slang { Decl*const* decl = getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index); SLANG_ASSERT(decl); - return _getSpecializedDeclRef(m_astBuilder, (T*)*decl, m_substitutions).template as<T>(); + return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as<T>(); } List<DeclRef<T>> toArray() const @@ -1042,7 +1054,7 @@ namespace Slang void operator++() { m_ptr = adjustFilterCursor<T>(m_filterStyle, m_ptr + 1, m_end); } - DeclRef<T> operator*() { return _getSpecializedDeclRef(m_list->m_astBuilder, (T*)*m_ptr, m_list->m_substitutions).template as<T>(); } + DeclRef<T> operator*() { return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr).template as<T>(); } }; Iterator begin() const { return Iterator(this, adjustFilterCursor<T>(m_filterStyle, m_decls.begin(), m_decls.end()), m_decls.end(), m_filterStyle); } @@ -1431,7 +1443,18 @@ namespace Slang { SLANG_OBJ_CLASS(WitnessTable) - RequirementDictionary requirementDictionary; + const RequirementDictionary& getRequirementDictionary() + { + if (m_requirementDictionary.getCount() != m_requirements.getCount()) + { + for (Index i = m_requirementDictionary.getCount(); i < m_requirements.getCount(); i++) + { + auto& r = m_requirements[i]; + m_requirementDictionary.add(r.key, r.value); + } + } + return m_requirementDictionary; + } void add(Decl* decl, RequirementWitness const& witness); @@ -1440,9 +1463,13 @@ namespace Slang // The type witnessesd by the witness table (a concrete type). Type* witnessedType; - }; - typedef Dictionary<unsigned int, NodeBase*> AttributeArgumentValueDict; + // Satisfying values of each requirement. + List<KeyValuePair<Decl*, RequirementWitness>> m_requirements; + + // Cached dictionary for looking up satisfying values. + SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary; + }; struct SpecializationParam { @@ -1551,6 +1578,7 @@ namespace Slang /// Get the operator name from the higher order invoke expr. UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); + } // namespace Slang #endif diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index 65955e815..cb7d338c8 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -134,8 +134,7 @@ Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name) { auto rs = m_builder->create<StaticMemberExpr>(); auto typeExpr = m_builder->create<SharedTypeExpr>(); - auto typetype = m_builder->create<TypeType>(); - typetype->type = type; + auto typetype = m_builder->getOrCreate<TypeType>(type); typeExpr->type = typetype; rs->baseExpression = typeExpr; rs->name = name; diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index ee5d1d40e..13133a7f8 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1,49 +1,19 @@ // slang-ast-type.cpp #include "slang-ast-builder.h" +#include "slang-ast-modifier.h" #include <assert.h> #include <typeinfo> #include "slang-syntax.h" #include "slang-generated-ast-macro.h" - namespace Slang { // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Type* Type::createCanonicalType() -{ - SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ()) -} - -bool Type::equals(Type* type) -{ - return getCanonicalType()->equalsImpl(type->getCanonicalType()); -} - -bool Type::equalsImpl(Type* type) -{ - SLANG_AST_NODE_VIRTUAL_CALL(Type, equalsImpl, (type)) -} - -bool Type::_equalsImplOverride(Type* type) -{ - SLANG_UNUSED(type) - SLANG_UNEXPECTED("Type::_equalsImplOverride not overridden"); - //return false; -} - Type* Type::_createCanonicalTypeOverride() { - SLANG_UNEXPECTED("Type::_createCanonicalTypeOverride not overridden"); - //return Type*(); -} - -bool Type::_equalsValOverride(Val* val) -{ - if (auto type = dynamicCast<Type>(val)) - return const_cast<Type*>(this)->equals(type); - return false; + return as<Type>(defaultResolveImpl()); } Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) @@ -61,20 +31,6 @@ Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst return canSubst; } -Type* Type::getCanonicalType() -{ - Type* et = const_cast<Type*>(this); - if (!et->canonicalType) - { - // TODO(tfoley): worry about thread safety here? - auto canType = et->createCanonicalType(); - et->canonicalType = canType; - if (!et->canonicalType) - return getASTBuilder()->getErrorType(); - } - return et->canonicalType; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! OverloadGroupType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void OverloadGroupType::_toTextOverride(StringBuilder& out) @@ -82,21 +38,11 @@ void OverloadGroupType::_toTextOverride(StringBuilder& out) out << toSlice("overload group"); } -bool OverloadGroupType::_equalsImplOverride(Type * /*type*/) -{ - return false; -} - Type* OverloadGroupType::_createCanonicalTypeOverride() { return this; } -HashCode OverloadGroupType::_getHashCodeOverride() -{ - return (HashCode)(size_t(this)); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! InitializerListType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void InitializerListType::_toTextOverride(StringBuilder& out) @@ -104,21 +50,11 @@ void InitializerListType::_toTextOverride(StringBuilder& out) out << toSlice("initializer list"); } -bool InitializerListType::_equalsImplOverride(Type * /*type*/) -{ - return false; -} - Type* InitializerListType::_createCanonicalTypeOverride() { return this; } -HashCode InitializerListType::_getHashCodeOverride() -{ - return (HashCode)(size_t(this)); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ErrorType::_toTextOverride(StringBuilder& out) @@ -126,11 +62,6 @@ void ErrorType::_toTextOverride(StringBuilder& out) out << toSlice("error"); } -bool ErrorType::_equalsImplOverride(Type* type) -{ - return as<ErrorType>(type); -} - Type* ErrorType::_createCanonicalTypeOverride() { return this; @@ -141,56 +72,21 @@ Val* ErrorType::_substituteImplOverride(ASTBuilder* /* astBuilder */, Substituti return this; } -HashCode ErrorType::_getHashCodeOverride() -{ - return HashCode(size_t(this)); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void BottomType::_toTextOverride(StringBuilder& out) { out << toSlice("never"); } -bool BottomType::_equalsImplOverride(Type* type) -{ - return as<BottomType>(type); -} - -Type* BottomType::_createCanonicalTypeOverride() { return this; } - Val* BottomType::_substituteImplOverride( ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) { return this; } -HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); } - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void DeclRefType::_toTextOverride(StringBuilder& out) { - out << declRef; -} - -HashCode DeclRefType::_getHashCodeOverride() -{ - return (declRef.getHashCode() * 16777619) ^ (HashCode)(typeid(this).hash_code()); -} - -bool DeclRefType::_equalsImplOverride(Type * type) -{ - if (auto declRefType = as<DeclRefType>(type)) - { - return declRef.equals(declRefType->declRef); - } - return false; -} - -Type* DeclRefType::_createCanonicalTypeOverride() -{ - // A declaration reference is already canonical - declRef.substitute(this->getASTBuilder(), this); - return this; + out << getDeclRef(); } Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff); @@ -199,26 +95,47 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe { if (!subst) return this; - // the case we especially care about is when this type references a declaration - // of a generic parameter, since that is what we might be substituting... - if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl())) + int diff = 0; + DeclRef<Decl> substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + + // If this declref type is a direct reference to ThisType or a Generic parameter, + // and `subst` provides an argument for it, then we should just return that argument. + // + if (as<DirectDeclRef>(substDeclRef.declRefBase)) { - if (auto result = maybeSubstituteGenericParam(this, genericTypeParamDecl, subst, ioDiff)) + if (auto thisDecl = as<ThisTypeDecl>(substDeclRef.getDecl())) + { + auto lookupDeclRef = subst.findLookupDeclRef(); + if (lookupDeclRef && lookupDeclRef->getSupDecl() == substDeclRef.getDecl()->parentDecl) + { + (*ioDiff)++; + return lookupDeclRef->getLookupSource(); + } + } + else if (as<GenericTypeParamDecl>(substDeclRef.getDecl()) || as<GenericValueParamDecl>(substDeclRef.getDecl())) { - if (auto substDeclRefType = as<DeclRefType>(result)) + auto resultVal = maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff); + if (resultVal) { - // After generic substitution, we may be able to further simplify - // by looking up the actual type of an associated type. - if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst( - astBuilder, substDeclRefType->declRef)) - return satisfyingVal; + (*ioDiff)++; + return resultVal; } - return result; } } - int diff = 0; - DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + // If this type is a reference to an associated type declaration, + // and the substitutions provide a "this type" substitution for + // the outer interface, then try to replace the type with the + // actual value of the associated type for the given implementation. + // + if (auto satisfyingVal = substDeclRef.declRefBase->resolve()) + { + if (satisfyingVal != getDeclRef()) + { + *ioDiff += 1; + return DeclRefType::create(astBuilder, substDeclRef); + } + } if (!diff) return this; @@ -226,14 +143,6 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe // Make sure to record the difference! *ioDiff += diff; - // If this type is a reference to an associated type declaration, - // and the substitutions provide a "this type" substitution for - // the outer interface, then try to replace the type with the - // actual value of the associated type for the given implementation. - // - if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef)) - return satisfyingVal; - // Re-construct the type in case we are using a specialized sub-class return DeclRefType::create(astBuilder, substDeclRef); } @@ -254,40 +163,52 @@ BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool BasicExpressionType::_equalsImplOverride(Type * type) +BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() { - auto basicType = as<BasicExpressionType>(type); - return basicType && basicType->baseType == this->baseType; + return this; } -Type* BasicExpressionType::_createCanonicalTypeOverride() +static Val* _getGenericTypeArg(DeclRefBase* declRef, Index i) { - // A basic type is already canonical, in our setup - return this; + auto args = findInnerMostGenericArgs(SubstitutionSet(declRef)); + if (args.getCount() <= i) + return nullptr; + + return args[i]; } -BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() +static Val* _getGenericTypeArg(DeclRefType* declRefType, Index i) { - return this; + return _getGenericTypeArg(declRefType->getDeclRefBase(), i); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type* TensorViewType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Type* VectorExpressionType::getElementType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); +} + +IntVal* VectorExpressionType::getElementCount() +{ + return as<IntVal>(_getGenericTypeArg(this, 1)); +} + void VectorExpressionType::_toTextOverride(StringBuilder& out) { - out << toSlice("vector<") << elementType << toSlice(",") << elementCount << toSlice(">"); + out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount() << toSlice(">"); } BasicExpressionType* VectorExpressionType::_getScalarTypeOverride() { - return as<BasicExpressionType>(elementType); + return as<BasicExpressionType>(getElementType()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MatrixExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -304,24 +225,24 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride() Type* MatrixExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } IntVal* MatrixExpressionType::getRowCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(_getGenericTypeArg(this, 1)); } IntVal* MatrixExpressionType::getColumnCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[2]); + return as<IntVal>(_getGenericTypeArg(this, 2)); } Type* MatrixExpressionType::getRowType() { if (!rowType) { - rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount()); + rowType = getCurrentASTBuilder()->getVectorType(getElementType(), getColumnCount()); } return rowType; } @@ -330,12 +251,12 @@ Type* MatrixExpressionType::getRowType() Type* ArrayExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } IntVal* ArrayExpressionType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(_getGenericTypeArg(this, 1)); } void ArrayExpressionType::_toTextOverride(StringBuilder& out) @@ -353,7 +274,7 @@ bool ArrayExpressionType::isUnsized() { if (auto constSize = as<ConstantIntVal>(getElementCount())) { - if (constSize->value == kUnsizedArrayMagicLength) + if (constSize->getValue() == kUnsizedArrayMagicLength) return true; } return false; @@ -363,27 +284,12 @@ bool ArrayExpressionType::isUnsized() void TypeType::_toTextOverride(StringBuilder& out) { - out << toSlice("typeof(") << type << toSlice(")"); -} - -bool TypeType::_equalsImplOverride(Type * t) -{ - if (auto typeType = as<TypeType>(t)) - { - return t->equals(typeType->type); - } - return false; + out << toSlice("typeof(") << getType() << toSlice(")"); } Type* TypeType::_createCanonicalTypeOverride() { - return getASTBuilder()->getTypeType(type->getCanonicalType()); -} - -HashCode TypeType::_getHashCodeOverride() -{ - SLANG_UNEXPECTED("TypeType::_getHashCodeOverride should be unreachable"); - //return HashCode(0); + return getCurrentASTBuilder()->getTypeType(getType()->getCanonicalType()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericDeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -394,20 +300,6 @@ void GenericDeclRefType::_toTextOverride(StringBuilder& out) out << toSlice("<DeclRef<GenericDecl>>"); } -bool GenericDeclRefType::_equalsImplOverride(Type * type) -{ - if (auto genericDeclRefType = as<GenericDeclRefType>(type)) - { - return declRef.equals(genericDeclRefType->declRef); - } - return false; -} - -HashCode GenericDeclRefType::_getHashCodeOverride() -{ - return declRef.getHashCode(); -} - Type* GenericDeclRefType::_createCanonicalTypeOverride() { return this; @@ -417,21 +309,7 @@ Type* GenericDeclRefType::_createCanonicalTypeOverride() void NamespaceType::_toTextOverride(StringBuilder& out) { - out << toSlice("namespace ") << declRef; -} - -bool NamespaceType::_equalsImplOverride(Type * type) -{ - if (auto namespaceType = as<NamespaceType>(type)) - { - return declRef.equals(namespaceType->declRef); - } - return false; -} - -HashCode NamespaceType::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("namespace ") << getDeclRef(); } Type* NamespaceType::_createCanonicalTypeOverride() @@ -441,7 +319,7 @@ Type* NamespaceType::_createCanonicalTypeOverride() Type* DifferentialPairType::getPrimalType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } @@ -449,51 +327,35 @@ Type* DifferentialPairType::getPrimalType() Type* PtrTypeBase::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); } Type* OptionalType::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(_getGenericTypeArg(this, 0)); +} + +Type* NativeRefType::getValueType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void NamedExpressionType::_toTextOverride(StringBuilder& out) { - if (declRef.getDecl()) + if (getDeclRef().getDecl()) { - _printNestedDecl(declRef.getSubst(), declRef.getDecl(), out); + getDeclRef().declRefBase->toText(out); } } -bool NamedExpressionType::_equalsImplOverride(Type * /*type*/) -{ - SLANG_UNEXPECTED("NamedExpressionType::_equalsImplOverride should be unreachable"); - //return false; -} - Type* NamedExpressionType::_createCanonicalTypeOverride() { - if (!innerType) - innerType = getType(m_astBuilder, declRef); - if (innerType) - return innerType->getCanonicalType(); - return nullptr; -} - -HashCode NamedExpressionType::_getHashCodeOverride() -{ - // Type equality is based on comparing canonical types, - // so the hash code for a type needs to come from the - // canonical version of the type. This really means - // that `Type::getHashCode()` should dispatch out to - // something like `Type::getHashCodeImpl()` on the - // canonical version of a type, but it is less invasive - // for now (and hopefully equivalent) to just have any - // named types automaticlaly route hash-code requests - // to their canonical type. - return getCanonicalType()->getHashCode(); + auto canType = getType(getCurrentASTBuilder(), getDeclRef()); + if (canType) + return canType->getCanonicalType(); + return getCurrentASTBuilder()->getErrorType(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -533,58 +395,27 @@ void FuncType::_toTextOverride(StringBuilder& out) } out << ") -> " << getResultType(); - if (!getErrorType()->equals(getASTBuilder()->getBottomType())) + if (!getErrorType()->equals(getCurrentASTBuilder()->getBottomType())) { out << " throws " << getErrorType(); } } -bool FuncType::_equalsImplOverride(Type * type) -{ - if (auto funcType = as<FuncType>(type)) - { - auto paramCount = getParamCount(); - auto otherParamCount = funcType->getParamCount(); - if (paramCount != otherParamCount) - return false; - - for (Index pp = 0; pp < paramCount; ++pp) - { - auto paramType = getParamType(pp); - auto otherParamType = funcType->getParamType(pp); - if (!paramType->equals(otherParamType)) - return false; - } - - if (!resultType->equals(funcType->resultType)) - return false; - - if (!errorType->equals(funcType->errorType)) - return false; - - // TODO: if we ever introduce other kinds - // of qualification on function types, we'd - // want to consider it here. - return true; - } - return false; -} - Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; // result type - Type* substResultType = as<Type>(resultType->substituteImpl(astBuilder, subst, &diff)); + Type* substResultType = as<Type>(getResultType()->substituteImpl(astBuilder, subst, &diff)); // error type - Type* substErrorType = as<Type>(errorType->substituteImpl(astBuilder, subst, &diff)); + Type* substErrorType = as<Type>(getErrorType()->substituteImpl(astBuilder, subst, &diff)); // parameter types List<Type*> substParamTypes; - for (auto pp : paramTypes) + for (Index pp = 0; pp < getParamCount(); pp++ ) { - substParamTypes.add(as<Type>(pp->substituteImpl(astBuilder, subst, &diff))); + substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff))); } // early exit for no change... @@ -592,138 +423,75 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s return this; (*ioDiff)++; - FuncType* substType = astBuilder->create<FuncType>(); - substType->resultType = substResultType; - substType->paramTypes = substParamTypes; - substType->errorType = substErrorType; + FuncType* substType = astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType); return substType; } Type* FuncType::_createCanonicalTypeOverride() { // result type - Type* canResultType = resultType->getCanonicalType(); - Type* canErrorType = errorType->getCanonicalType(); + Type* canResultType = getResultType()->getCanonicalType(); + Type* canErrorType = getErrorType()->getCanonicalType(); // parameter types List<Type*> canParamTypes; - for (auto pp : paramTypes) + for (Index pp = 0; pp < getParamCount(); pp++) { - canParamTypes.add(pp->getCanonicalType()); + canParamTypes.add(getParamType(pp)->getCanonicalType()); } - FuncType* canType = getASTBuilder()->create<FuncType>(); - canType->resultType = canResultType; - canType->paramTypes = canParamTypes; - canType->errorType = canErrorType; + FuncType* canType = getCurrentASTBuilder()->getFuncType(canParamTypes.getArrayView(), canResultType, canErrorType); return canType; } -HashCode FuncType::_getHashCodeOverride() -{ - HashCode hashCode = getResultType()->getHashCode(); - Index paramCount = getParamCount(); - hashCode = combineHash(hashCode, Slang::getHashCode(paramCount)); - for (Index pp = 0; pp < paramCount; ++pp) - { - hashCode = combineHash( - hashCode, - getParamType(pp)->getHashCode()); - } - combineHash(hashCode, getErrorType()->getHashCode()); - return hashCode; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void TupleType::_toTextOverride(StringBuilder& out) { out << toSlice("("); - for (Index pp = 0; pp < memberTypes.getCount(); ++pp) + for (Index pp = 0; pp < getOperandCount(); ++pp) { if (pp != 0) out << toSlice(", "); - out << memberTypes[pp]; + out << getOperand(pp); } out << toSlice(")"); } -bool TupleType::_equalsImplOverride(Type * type) -{ - if (const auto other = as<TupleType>(type)) - { - auto paramCount = memberTypes.getCount(); - auto otherParamCount = other->memberTypes.getCount(); - if (paramCount != otherParamCount) - return false; - - for (Index i = 0; i < memberTypes.getCount(); ++i) - { - if(!memberTypes[i]->equals(other->memberTypes[i])) - return false; - } - - return true; - } - return false; -} - Val* TupleType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; // just recurse into the members List<Type*> substMemberTypes; - for (auto m : memberTypes) - substMemberTypes.add(as<Type>(m->substituteImpl(astBuilder, subst, &diff))); + for (Index m = 0; m < getMemberCount(); m++) + substMemberTypes.add(as<Type>(getMember(m)->substituteImpl(astBuilder, subst, &diff))); // early exit for no change... if (!diff) return this; (*ioDiff)++; - return astBuilder->create<TupleType>(std::move(substMemberTypes)); + return astBuilder->getTupleType(substMemberTypes); } Type* TupleType::_createCanonicalTypeOverride() { // member types List<Type*> canMemberTypes; - for (auto m : memberTypes) + for (Index m = 0; m < getMemberCount(); m++) { - canMemberTypes.add(m->getCanonicalType()); + canMemberTypes.add(getMember(m)->getCanonicalType()); } - return getASTBuilder()->create<TupleType>(std::move(canMemberTypes)); -} - -HashCode TupleType::_getHashCodeOverride() -{ - HashCode hashCode = Slang::getHashCode(kType); - for(auto m : memberTypes) - hashCode = combineHash(hashCode, m->getHashCode()); - return hashCode; + return getCurrentASTBuilder()->getTupleType(canMemberTypes); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExtractExistentialType::_toTextOverride(StringBuilder& out) { - out << declRef << toSlice(".This"); -} - -bool ExtractExistentialType::_equalsImplOverride(Type* type) -{ - if (auto extractExistential = as<ExtractExistentialType>(type)) - { - return declRef.equals(extractExistential->declRef); - } - return false; -} - -HashCode ExtractExistentialType::_getHashCodeOverride() -{ - return combineHash(declRef.getHashCode(), originalInterfaceType->getHashCode(), originalInterfaceDeclRef.getHashCode()); + out << getDeclRef() << toSlice(".This"); } Type* ExtractExistentialType::_createCanonicalTypeOverride() @@ -734,18 +502,16 @@ Type* ExtractExistentialType::_createCanonicalTypeOverride() Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - auto substOriginalInterfaceType = originalInterfaceType->substituteImpl(astBuilder, subst, &diff); - auto substOriginalInterfaceDeclRef = originalInterfaceDeclRef.substituteImpl(astBuilder, subst, &diff); + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substOriginalInterfaceType = getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff); + auto substOriginalInterfaceDeclRef = getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff); if (!diff) return this; (*ioDiff)++; - ExtractExistentialType* substValue = astBuilder->create<ExtractExistentialType>(); - substValue->declRef = substDeclRef; - substValue->originalInterfaceType = as<Type>(substOriginalInterfaceType); - substValue->originalInterfaceDeclRef = substOriginalInterfaceDeclRef; + ExtractExistentialType* substValue = astBuilder->getOrCreate<ExtractExistentialType>( + substDeclRef, as<Type>(substOriginalInterfaceType), substOriginalInterfaceDeclRef); return substValue; } @@ -754,165 +520,47 @@ SubtypeWitness* ExtractExistentialType::getSubtypeWitness() if (auto cachedValue = this->cachedSubtypeWitness) return cachedValue; - ExtractExistentialSubtypeWitness* openedWitness = m_astBuilder->create<ExtractExistentialSubtypeWitness>(); - openedWitness->sub = this; - openedWitness->sup = originalInterfaceType; - openedWitness->declRef = this->declRef; - + ExtractExistentialSubtypeWitness* openedWitness = getCurrentASTBuilder()->getOrCreate<ExtractExistentialSubtypeWitness>(this, getOriginalInterfaceType(), getDeclRef()); this->cachedSubtypeWitness = openedWitness; return openedWitness; } -DeclRef<InterfaceDecl> ExtractExistentialType::getSpecializedInterfaceDeclRef() +DeclRef<ThisTypeDecl> ExtractExistentialType::getThisTypeDeclRef() { - if (auto cachedValue = this->cachedSpecializedInterfaceDeclRef) + if (auto cachedValue = this->cachedThisTypeDeclRef) return cachedValue; - auto interfaceDecl = originalInterfaceDeclRef.getDecl(); + auto interfaceDecl = getOriginalInterfaceDeclRef().getDecl(); SubtypeWitness* openedWitness = getSubtypeWitness(); - ThisTypeSubstitution* openedThisType = m_astBuilder->getOrCreateThisTypeSubstitution( - interfaceDecl, openedWitness, originalInterfaceDeclRef.getSubst()); - - DeclRef<InterfaceDecl> specialiedInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(interfaceDecl, openedThisType); - - this->cachedSpecializedInterfaceDeclRef = specialiedInterfaceDeclRef; - return specialiedInterfaceDeclRef; -} - - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -void TaggedUnionType::_toTextOverride(StringBuilder& out) -{ - out << toSlice("__TaggedUnion("); - bool first = true; - for (auto caseType : caseTypes) - { - if (!first) + ThisTypeDecl* thisTypeDecl = nullptr; + for (auto member : interfaceDecl->members) + if (as<ThisTypeDecl>(member)) { - out << toSlice(", "); + thisTypeDecl = as<ThisTypeDecl>(member); + break; } - first = false; - - out << caseType; - } - out << toSlice(")"); -} - -bool TaggedUnionType::_equalsImplOverride(Type* type) -{ - auto taggedUnion = as<TaggedUnionType>(type); - if (!taggedUnion) - return false; - - auto caseCount = caseTypes.getCount(); - if (caseCount != taggedUnion->caseTypes.getCount()) - return false; - - for (Index ii = 0; ii < caseCount; ++ii) - { - if (!caseTypes[ii]->equals(taggedUnion->caseTypes[ii])) - return false; - } - return true; -} - -HashCode TaggedUnionType::_getHashCodeOverride() -{ - HashCode hashCode = 0; - for (auto caseType : caseTypes) - { - hashCode = combineHash(hashCode, caseType->getHashCode()); - } - return hashCode; -} - -Type* TaggedUnionType::_createCanonicalTypeOverride() -{ - TaggedUnionType* canType = m_astBuilder->create<TaggedUnionType>(); - - for (auto caseType : caseTypes) - { - auto canCaseType = caseType->getCanonicalType(); - canType->caseTypes.add(canCaseType); - } - - return canType; -} + SLANG_ASSERT(thisTypeDecl); -Val* TaggedUnionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; + DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl); - List<Type*> substCaseTypes; - for (auto caseType : caseTypes) - { - substCaseTypes.add(as<Type>(caseType->substituteImpl(astBuilder, subst, &diff))); - } - if (!diff) - return this; - - (*ioDiff)++; - - TaggedUnionType* substType = astBuilder->create<TaggedUnionType>(); - substType->caseTypes.swapWith(substCaseTypes); - return substType; + this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef; + return specialiedInterfaceDeclRef; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExistentialSpecializedType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExistentialSpecializedType::_toTextOverride(StringBuilder& out) { - out << toSlice("__ExistentialSpecializedType(") << baseType; - for (auto arg : args) + out << toSlice("__ExistentialSpecializedType(") << getBaseType(); + for (Index i = 0; i < getArgCount(); i++) { - out << toSlice(", ") << arg.val; + out << toSlice(", ") << getArg(i).val; } out << toSlice(")"); } -bool ExistentialSpecializedType::_equalsImplOverride(Type * type) -{ - auto other = as<ExistentialSpecializedType>(type); - if (!other) - return false; - - if (!baseType->equals(other->baseType)) - return false; - - auto argCount = args.getCount(); - if (argCount != other->args.getCount()) - return false; - - for (Index ii = 0; ii < argCount; ++ii) - { - auto arg = args[ii]; - auto otherArg = other->args[ii]; - - if (!arg.val->equalsVal(otherArg.val)) - return false; - - if (!areValsEqual(arg.witness, otherArg.witness)) - return false; - } - return true; -} - -HashCode ExistentialSpecializedType::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashObject(baseType); - for (auto arg : args) - { - hasher.hashObject(arg.val); - if (auto witness = arg.witness) - hasher.hashObject(witness); - } - return hasher.getResult(); -} - static Val* _getCanonicalValue(Val* val) { if (!val) @@ -928,16 +576,21 @@ static Val* _getCanonicalValue(Val* val) Type* ExistentialSpecializedType::_createCanonicalTypeOverride() { - ExistentialSpecializedType* canType = m_astBuilder->create<ExistentialSpecializedType>(); + ExpandedSpecializationArgs newArgs; - canType->baseType = baseType->getCanonicalType(); - for (auto arg : args) + for (Index ii = 0; ii < getArgCount(); ++ii) { + auto arg = getArg(ii); ExpandedSpecializationArg canArg; canArg.val = _getCanonicalValue(arg.val); canArg.witness = _getCanonicalValue(arg.witness); - canType->args.add(canArg); + newArgs.add(canArg); } + + ExistentialSpecializedType* canType = getCurrentASTBuilder()->getOrCreate<ExistentialSpecializedType>( + getBaseType()->getCanonicalType(), + newArgs); + return canType; } @@ -951,11 +604,12 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, { int diff = 0; - auto substBaseType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff)); + auto substBaseType = as<Type>(getBaseType()->substituteImpl(astBuilder, subst, &diff)); ExpandedSpecializationArgs substArgs; - for (auto arg : args) + for (Index ii = 0; ii < getArgCount(); ++ii) { + auto arg = getArg(ii); ExpandedSpecializationArg substArg; substArg.val = _substituteImpl(astBuilder, arg.val, subst, &diff); substArg.witness = _substituteImpl(astBuilder, arg.witness, subst, &diff); @@ -967,96 +621,22 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, (*ioDiff)++; - ExistentialSpecializedType* substType = astBuilder->create<ExistentialSpecializedType>(); - substType->baseType = substBaseType; - substType->args = substArgs; + ExistentialSpecializedType* substType = astBuilder->getOrCreate<ExistentialSpecializedType>(substBaseType, substArgs); return substType; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -void ThisType::_toTextOverride(StringBuilder& out) -{ - out << interfaceDeclRef << toSlice(".This"); -} - -bool ThisType::_equalsImplOverride(Type * type) -{ - auto other = as<ThisType>(type); - if (!other) - return false; - - if (!interfaceDeclRef.equals(other->interfaceDeclRef)) - return false; - - return true; -} - -HashCode ThisType::_getHashCodeOverride() -{ - return combineHash( - HashCode(typeid(*this).hash_code()), - interfaceDeclRef.getHashCode()); -} - -Type* ThisType::_createCanonicalTypeOverride() +InterfaceDecl* ThisType::getInterfaceDecl() { - ThisType* canType = m_astBuilder->create<ThisType>(); - - // TODO: need to canonicalize the decl-ref - canType->interfaceDeclRef = interfaceDeclRef; - return canType; -} - -Val* ThisType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff); - - auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl()); - if (thisTypeSubst) - { - return thisTypeSubst->witness->sub; - } - - if (!diff) - return this; - - (*ioDiff)++; - - ThisType* substType = m_astBuilder->create<ThisType>(); - substType->interfaceDeclRef = substInterfaceDeclRef; - return substType; + return dynamicCast<InterfaceDecl>(getDeclRefBase()->getDecl()->parentDecl); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void AndType::_toTextOverride(StringBuilder& out) { - out << left << toSlice(" & ") << right; -} - -bool AndType::_equalsImplOverride(Type * type) -{ - auto other = as<AndType>(type); - if (!other) - return false; - - if(!left->equals(other->left)) - return false; - if(!right->equals(other->right)) - return false; - - return true; -} - -HashCode AndType::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashObject(left); - hasher.hashObject(right); - return hasher.getResult(); + out << getLeft() << toSlice(" & ") << getRight(); } Type* AndType::_createCanonicalTypeOverride() @@ -1094,9 +674,9 @@ Type* AndType::_createCanonicalTypeOverride() // right now, in the name of getting something up and running. // - auto canLeft = left->getCanonicalType(); - auto canRight = right->getCanonicalType(); - auto canType = m_astBuilder->getAndType(canLeft, canRight); + auto canLeft = getLeft()->getCanonicalType(); + auto canRight = getRight()->getCanonicalType(); + auto canType = getCurrentASTBuilder()->getAndType(canLeft, canRight); return canType; } @@ -1104,15 +684,15 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su { int diff = 0; - auto substLeft = as<Type>(left ->substituteImpl(astBuilder, subst, &diff)); - auto substRight = as<Type>(right->substituteImpl(astBuilder, subst, &diff)); + auto substLeft = as<Type>(getLeft()->substituteImpl(astBuilder, subst, &diff)); + auto substRight = as<Type>(getRight()->substituteImpl(astBuilder, subst, &diff)); if(!diff) return this; (*ioDiff)++; - auto substType = m_astBuilder->getAndType(substLeft, substRight); + auto substType = getCurrentASTBuilder()->getAndType(substLeft, substRight); return substType; } @@ -1120,83 +700,35 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su void ModifiedType::_toTextOverride(StringBuilder& out) { - for( auto modifier : modifiers ) + for( Index i = 0; i < getModifierCount(); i++ ) { - modifier->toText(out); + getModifier(i)->toText(out); out.appendChar(' '); } - base->toText(out); -} - -bool ModifiedType::_equalsImplOverride(Type* type) -{ - auto other = as<ModifiedType>(type); - if(!other) - return false; - - if(!base->equals(other->base)) - return false; - - // TODO: Eventually we need to put the `modifiers` into - // a canonical ordering as part of creation of a `ModifiedType`, - // so that two instances that apply the same modifiers to - // the same type will have those modifiers in a matching order. - // - // The simplest way to achieve that ordering *for now* would - // be to sort the array by the integer AST node type tag. - // That approach would of course not scale to modifiers that - // have any operands of their own. - // - // Note that we would *also* need the logic that creates a - // `ModifiedType` to detect when the base type is itself a - // `ModifiedType` and produce a single `ModifiedType` with - // a combined list of modifiers and a non-`ModifiedType` as - // its base type. - // - auto modifierCount = modifiers.getCount(); - if(modifierCount != other->modifiers.getCount()) - return false; - - for( Index i = 0; i < modifierCount; ++i ) - { - auto thisModifier = this->modifiers[i]; - auto otherModifier = other->modifiers[i]; - if(!thisModifier->equalsVal(otherModifier)) - return false; - } - return true; -} - -HashCode ModifiedType::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashObject(base); - for( auto modifier : modifiers ) - { - hasher.hashObject(modifier); - } - return hasher.getResult(); + getBase()->toText(out); } Type* ModifiedType::_createCanonicalTypeOverride() { - ModifiedType* canonical = m_astBuilder->create<ModifiedType>(); - canonical->base = base->getCanonicalType(); - for( auto modifier : modifiers ) + List<Val*> modifiers; + for (Index i = 0; i < getModifierCount(); ++i) { - canonical->modifiers.add(modifier); + auto modifier = this->getModifier(i); + modifiers.add(modifier); } + ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate<ModifiedType>(getBase()->getCanonicalType(), modifiers.getArrayView()); return canonical; } Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - Type* substBase = as<Type>(base->substituteImpl(astBuilder, subst, &diff)); + Type* substBase = as<Type>(getBase()->substituteImpl(astBuilder, subst, &diff)); List<Val*> substModifiers; - for( auto modifier : modifiers ) + for (Index i = 0; i < getModifierCount(); ++i) { + auto modifier = this->getModifier(i); auto substModifier = modifier->substituteImpl(astBuilder, subst, &diff); substModifiers.add(substModifier); } @@ -1206,12 +738,49 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS *ioDiff = 1; - ModifiedType* substType = m_astBuilder->create<ModifiedType>(); - substType->base = substBase; - substType->modifiers = _Move(substModifiers); + ModifiedType* substType = getCurrentASTBuilder()->getOrCreate<ModifiedType>(substBase, substModifiers.getArrayView()); return substType; } +BaseType BasicExpressionType::getBaseType() const +{ + auto builtinType = getDeclRef().getDecl()->findModifier<BuiltinTypeModifier>(); + return builtinType->tag; +} + +FeedbackType::Kind FeedbackType::getKind() const +{ + auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>(); + return FeedbackType::Kind(magicMod->tag); +} + +TextureFlavor ResourceType::getFlavor() const +{ + auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>(); + return TextureFlavor(magicMod->tag); +} + +SamplerStateFlavor SamplerStateType::getFlavor() const +{ + auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>(); + return SamplerStateFlavor(magicMod->tag); +} + +Type* BuiltinGenericType::getElementType() const +{ + return as<Type>(_getGenericTypeArg(getDeclRefBase(), 0)); +} + +Type* ResourceType::getElementType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); +} + +Val* TextureTypeBase::getSampleCount() +{ + return as<Type>(_getGenericTypeArg(this, 1)); +} + Type* removeParamDirType(Type* type) { for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index b32d62404..8948f742c 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -16,8 +16,6 @@ class OverloadGroupType : public Type // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); }; // The type of an initializer-list expression (before it has @@ -29,8 +27,6 @@ class InitializerListType : public Type // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); }; // The type of an expression that was erroneous @@ -41,8 +37,6 @@ class ErrorType : public Type // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; @@ -53,31 +47,28 @@ class BottomType : public Type // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A type that takes the form of a reference to some declaration -class DeclRefType : public Type +class DeclRefType : public Type { SLANG_AST_CLASS(DeclRefType) - DeclRef<Decl> declRef; - static DeclRefType* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); + DeclRef<Decl> getDeclRef() const { return DeclRef<Decl>(as<DeclRefBase>(getOperand(0))); } + DeclRefBase* getDeclRefBase() const { return as<DeclRefBase>(getOperand(0)); } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); DeclRefType(DeclRefBase* declRefBase) - : declRef(declRefBase) - {} + { + setOperands(declRefBase); + } }; // Base class for types that can be used in arithmetic expressions @@ -95,18 +86,15 @@ class BasicExpressionType : public ArithmeticExpressionType { SLANG_AST_CLASS(BasicExpressionType) - BaseType baseType; + BaseType getBaseType() const; // Overrides should be public so base classes can access - Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); BasicExpressionType* _getScalarTypeOverride(); -protected: - BasicExpressionType( - Slang::BaseType baseType) - : baseType(baseType) - {} + BasicExpressionType(DeclRefBase* inDeclRef) + { + setOperands(inDeclRef); + } }; // Base type for things that are built in to the compiler, @@ -127,7 +115,7 @@ class FeedbackType : public BuiltinType MipRegionUsed, /// SAMPLER_FEEDBACK_MIP_REGION_USED }; - Kind kind; + Kind getKind() const; }; // Resources that contain "elements" that can be fetched @@ -135,37 +123,24 @@ class ResourceType : public BuiltinType { SLANG_ABSTRACT_AST_CLASS(ResourceType) - // The type that results from fetching an element from this resource - Type* elementType = nullptr; - - // Shape and access level information for this resource type - TextureFlavor flavor; + TextureFlavor getFlavor() const; TextureFlavor::Shape getBaseShape() { - return flavor.getBaseShape(); + return getFlavor().getBaseShape(); } - bool isMultisample() { return flavor.isMultisample(); } - bool isArray() { return flavor.isArray(); } - SlangResourceShape getShape() const { return flavor.getShape(); } - SlangResourceAccess getAccess() { return flavor.getAccess(); } + bool isMultisample() { return getFlavor().isMultisample(); } + bool isArray() { return getFlavor().isArray(); } + SlangResourceShape getShape() const { return getFlavor().getShape(); } + SlangResourceAccess getAccess() { return getFlavor().getAccess(); } + Type* getElementType(); }; class TextureTypeBase : public ResourceType { SLANG_ABSTRACT_AST_CLASS(TextureTypeBase) - // The sampleCount parameter of a RWTexture*MS resource. - Val* sampleCount = nullptr; -protected: - TextureTypeBase(TextureFlavor inFlavor, Type* inElementType, Val* inSampleCount = nullptr) - { - elementType = inElementType; - flavor = inFlavor; - sampleCount = inSampleCount; - } - - Val* getSampleCount() const { return sampleCount; } + Val* getSampleCount(); }; @@ -173,11 +148,6 @@ protected: class TextureType : public TextureTypeBase { SLANG_AST_CLASS(TextureType) - -protected: - TextureType(TextureFlavor flavor, Type* elementType, Val* inSampleCount = nullptr) - : TextureTypeBase(flavor, elementType, inSampleCount) - {} }; @@ -186,37 +156,20 @@ protected: class TextureSamplerType : public TextureTypeBase { SLANG_AST_CLASS(TextureSamplerType) - -protected: - TextureSamplerType(TextureFlavor flavor, Type* elementType) - : TextureTypeBase(flavor, elementType) - {} }; // This is a base type for `image*` types, as they exist in GLSL class GLSLImageType : public TextureTypeBase { SLANG_AST_CLASS(GLSLImageType) - -protected: - GLSLImageType( - TextureFlavor flavor, - Type* elementType) - : TextureTypeBase(flavor, elementType) - {} }; class SamplerStateType : public BuiltinType { SLANG_AST_CLASS(SamplerStateType) - // What flavor of sampler state is this - SamplerStateFlavor flavor; - - SamplerStateType(SamplerStateFlavor inFlavor) - { - flavor = inFlavor; - } + // Returns flavor of sampler state of this type. + SamplerStateFlavor getFlavor() const; }; // Other cases of generic types known to the compiler @@ -224,9 +177,7 @@ class BuiltinGenericType : public BuiltinType { SLANG_AST_CLASS(BuiltinGenericType) - Type* elementType = nullptr; - - Type* getElementType() { return elementType; } + Type* getElementType() const; }; // Types that behave like pointers, in that they can be @@ -297,7 +248,6 @@ class HLSLConsumeStructuredBufferType : public HLSLStructuredBufferTypeBase SLANG_AST_CLASS(HLSLConsumeStructuredBufferType) }; - class HLSLPatchType : public BuiltinType { SLANG_AST_CLASS(HLSLPatchType) @@ -396,7 +346,6 @@ class VaryingParameterGroupType : public ParameterGroupType class ConstantBufferType : public UniformParameterGroupType { SLANG_AST_CLASS(ConstantBufferType) - ConstantBufferType(Type* elementType) { SLANG_UNUSED(elementType); } }; @@ -435,11 +384,7 @@ class ParameterBlockType : public UniformParameterGroupType class ArrayExpressionType : public DeclRefType { SLANG_AST_CLASS(ArrayExpressionType) - ArrayExpressionType(Type* inElementType, IntVal* inElementCount) - { - SLANG_UNUSED(inElementType); - SLANG_UNUSED(inElementCount); - } + bool isUnsized(); void _toTextOverride(StringBuilder& out); Type* getElementType(); @@ -453,21 +398,16 @@ class TypeType : public Type { SLANG_AST_CLASS(TypeType) - // The type that this is the type of... - Type* type = nullptr; - // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); -protected: - TypeType(Type* type) - : type(type) - {} + Type* getType() { return as<Type>(getOperand(0)); } - + TypeType(Type* type) + { + setOperands(type); + } }; // A differential pair type, e.g., `__DifferentialPair<T>` @@ -487,20 +427,12 @@ class VectorExpressionType : public ArithmeticExpressionType { SLANG_AST_CLASS(VectorExpressionType) - // The type of vector elements. - // As an invariant, this should be a basic type or an alias. - Type* elementType = nullptr; - - // The number of elements - IntVal* elementCount = nullptr; - // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); BasicExpressionType* _getScalarTypeOverride(); - VectorExpressionType(Type* inElementType, IntVal* inElementCount) - : elementType(inElementType), elementCount(inElementCount) - {} + Type* getElementType(); + IntVal* getElementCount(); }; // A matrix type, e.g., `matrix<T,R,C>` @@ -519,9 +451,7 @@ class MatrixExpressionType : public ArithmeticExpressionType BasicExpressionType* _getScalarTypeOverride(); private: - Type* rowType = nullptr; - - MatrixExpressionType(Type*, IntVal*, IntVal*) {} + SLANG_UNREFLECTED Type* rowType = nullptr; }; class TensorViewType : public BuiltinType @@ -529,8 +459,6 @@ class TensorViewType : public BuiltinType SLANG_AST_CLASS(TensorViewType) Type* getElementType(); -private: - TensorViewType(Type*) {} }; // Base class for built in string types @@ -561,6 +489,7 @@ class DynamicType : public BuiltinType class EnumTypeType : public BuiltinType { SLANG_AST_CLASS(EnumTypeType) + // TODO: provide accessors for the declaration, the "tag" type, etc. }; @@ -640,22 +569,16 @@ class NamedExpressionType : public Type { SLANG_AST_CLASS(NamedExpressionType) - DeclRef<TypeDefDecl> declRef; - Type* innerType = nullptr; + DeclRef<TypeDefDecl> getDeclRef() { return as<DeclRefBase>(getOperand(0)); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); - -protected: - NamedExpressionType( - DeclRef<TypeDefDecl> declRef) - : declRef(declRef) - {} - + NamedExpressionType(DeclRef<TypeDefDecl> inDeclRef) + { + setOperands(inDeclRef); + } }; // A function type is defined by its parameter types @@ -666,27 +589,24 @@ class FuncType : public Type // Construct a unary function FuncType(Type* paramType, Type* resultType, Type* errorType) - : paramTypes{{paramType}}, resultType{resultType}, errorType{errorType} - {} - - FuncType(List<Type*> parameters, Type* result, Type* error) - : paramTypes(std::move(parameters)), resultType(result), errorType(error) - {} + { + setOperands(paramType, resultType, errorType); + } - // TODO: We may want to preserve parameter names - // in the list here, just so that we can print - // out friendly names when printing a function - // type, even if they don't affect the actual - // semantic type underneath. + FuncType(ArrayView<Type*> parameters, Type* result, Type* error) + { + for (auto paramType : parameters) + m_operands.add(ValNodeOperand(paramType)); + m_operands.add(ValNodeOperand(result)); + m_operands.add(ValNodeOperand(error)); + } - List<Type*> paramTypes; - Type* resultType = nullptr; - Type* errorType = nullptr; + OperandView<Type> getParamTypes() { return OperandView<Type>(this, 0, getOperandCount() - 2); } - Index getParamCount() { return paramTypes.getCount(); } - Type* getParamType(Index index) { return paramTypes[index]; } - Type* getResultType() { return resultType; } - Type* getErrorType() { return errorType; } + Index getParamCount() { return m_operands.getCount() - 2; } + Type* getParamType(Index index) { return as<Type>(getOperand(index)); } + Type* getResultType() { return as<Type>(getOperand(m_operands.getCount() - 2)); } + Type* getErrorType() { return as<Type>(getOperand(m_operands.getCount() - 1)); } ParameterDirection getParamDirection(Index index); @@ -694,8 +614,6 @@ class FuncType : public Type void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); }; // A tuple is a product of its member types @@ -704,21 +622,19 @@ class TupleType : public Type SLANG_AST_CLASS(TupleType) // Construct a unary tupletion - TupleType(List<Type*> memberTypes) - : memberTypes(std::move(memberTypes)) - {} - - auto getMemberCount() { return memberTypes.getCount(); } const - auto& getMember(Index i) { return memberTypes[i]; } + TupleType(ArrayView<Type*> memberTypes) + { + for (auto t : memberTypes) + m_operands.add(ValNodeOperand(t)); + } - List<Type*> memberTypes; + auto getMemberCount() const { return getOperandCount(); } + Type* getMember(Index i) const { return as<Type>(getOperand(i)); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); }; // The "type" of an expression that names a generic declaration. @@ -726,21 +642,16 @@ class GenericDeclRefType : public Type { SLANG_AST_CLASS(GenericDeclRefType) - DeclRef<GenericDecl> declRef; - - DeclRef<GenericDecl> const& getDeclRef() const { return declRef; } + DeclRef<GenericDecl> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Type* _createCanonicalTypeOverride(); -protected: - GenericDeclRefType( - DeclRef<GenericDecl> declRef) - : declRef(declRef) - {} + GenericDeclRefType(DeclRef<GenericDecl> declRef) + { + setOperands(declRef); + } }; // The "type" of a reference to a module or namespace @@ -748,14 +659,15 @@ class NamespaceType : public Type { SLANG_AST_CLASS(NamespaceType) - DeclRef<NamespaceDeclBase> declRef; + DeclRef<NamespaceDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } - DeclRef<NamespaceDeclBase> const& getDeclRef() const { return declRef; } + NamespaceType(DeclRef<NamespaceDeclBase> inDeclRef) + { + setOperands(inDeclRef); + } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Type* _createCanonicalTypeOverride(); }; @@ -765,25 +677,33 @@ class ExtractExistentialType : public Type { SLANG_AST_CLASS(ExtractExistentialType) - DeclRef<VarDeclBase> declRef; + DeclRef<VarDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); } // A reference to the original interface this type is known // to be a subtype of. // - Type* originalInterfaceType; - DeclRef<InterfaceDecl> originalInterfaceDeclRef; + Type* getOriginalInterfaceType() { return as<Type>(getOperand(1)); } + DeclRef<InterfaceDecl> getOriginalInterfaceDeclRef() { return as<DeclRefBase>(getOperand(2)); } + + ExtractExistentialType( + DeclRef<VarDeclBase> inDeclRef, + Type* inOriginalInterfaceType, + DeclRef<InterfaceDecl> inOriginalInterfaceDeclRef) + { + setOperands(inDeclRef, inOriginalInterfaceType, inOriginalInterfaceDeclRef); + } // Following fields will not be reflected (and thus won't be serialized, etc.) SLANG_UNREFLECTED - // A cached decl-ref to the original interface above, with - // a this-type substitution that refers to the type extracted here. + // A cached decl-ref to the original interface's ThisType Decl, with + // a witness that refers to the type extracted here. // // This field is optional and can be filled in on-demand. It does *not* // represent part of the logical value of this `Type`, and should not // be serialized, included in hashes, etc. // - DeclRef<InterfaceDecl> cachedSpecializedInterfaceDeclRef; + DeclRef<ThisTypeDecl> cachedThisTypeDeclRef; // A cached pointer to a witness that shows how this type is a subtype // of `originalInterfaceType`. @@ -792,8 +712,6 @@ SLANG_UNREFLECTED // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); @@ -803,62 +721,54 @@ SLANG_UNREFLECTED /// SubtypeWitness* getSubtypeWitness(); - /// Get a interface decl-ref for the original interface specialized to this type - /// (using a type-type substitution). + /// Get a decl-ref to the interface's ThisType decl, which represents a substitutable type + /// from which lookup can be performed. /// /// This operation may create the decl-ref on demand and cache it. /// - DeclRef<InterfaceDecl> getSpecializedInterfaceDeclRef(); -}; - - /// A tagged union of zero or more other types. -class TaggedUnionType : public Type -{ - SLANG_AST_CLASS(TaggedUnionType) - - /// The distinct "cases" the tagged union can store. - /// - /// For each type in this array, the array index is the - /// tag value for that case. - /// - List<Type*> caseTypes; - - // Overrides should be public so base classes can access - void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); - Type* _createCanonicalTypeOverride(); - Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRef<ThisTypeDecl> getThisTypeDeclRef(); }; class ExistentialSpecializedType : public Type { SLANG_AST_CLASS(ExistentialSpecializedType) - Type* baseType = nullptr; - ExpandedSpecializationArgs args; + Type* getBaseType() { return as<Type>(getOperand(0)); } + ExpandedSpecializationArg getArg(Index i) + { + ExpandedSpecializationArg arg; + arg.val = getOperand(i * 2 + 1); + arg.witness = getOperand(i * 2 + 2); + return arg; + } + Index getArgCount() { return (getOperandCount() - 1) / 2; } + + ExistentialSpecializedType( + Type* inBaseType, + ExpandedSpecializationArgs const& inArgs) + { + m_operands.add(ValNodeOperand(inBaseType)); + for (auto arg : inArgs) + { + m_operands.add(ValNodeOperand(arg.val)); + m_operands.add(ValNodeOperand(arg.witness)); + } + } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; /// The type of `this` within a polymorphic declaration -class ThisType : public Type +class ThisType : public DeclRefType { SLANG_AST_CLASS(ThisType) - DeclRef<InterfaceDecl> interfaceDeclRef; + ThisType(DeclRefBase* declRef) : DeclRefType(declRef) {} - // Overrides should be public so base classes can access - void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); - Type* _createCanonicalTypeOverride(); - Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + InterfaceDecl* getInterfaceDecl(); }; /// The type of `A & B` where `A` and `B` are types @@ -868,17 +778,16 @@ class AndType : public Type { SLANG_AST_CLASS(AndType) - Type* left; - Type* right; - + Type* getLeft() { return as<Type>(getOperand(0)); } + Type* getRight() { return as<Type>(getOperand(1)); } + AndType(Type* leftType, Type* rightType) - : left(leftType), right(rightType) - {} + { + setOperands(leftType, rightType); + } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; @@ -887,22 +796,32 @@ class ModifiedType : public Type { SLANG_AST_CLASS(ModifiedType) - Type* base; - List<Val*> modifiers; + Type* getBase() + { + return as<Type>(getOperand(0)); + } + + Index getModifierCount() { return getOperandCount() - 1; } + Val* getModifier(Index index) { return getOperand(index + 1); } + + ModifiedType(Type* inBase, ArrayView<Val*> inModifiers) + { + m_operands.add(ValNodeOperand(inBase)); + for (auto modifier : inModifiers) + m_operands.add(ValNodeOperand(modifier)); + } template<typename T> T* findModifier() { - for (auto v : modifiers) - if (auto rs = as<T>(v)) + for (Index i = 1; i < getOperandCount(); i++) + if (auto rs = as<T>(getOperand(i))) return rs; return nullptr; } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - bool _equalsImplOverride(Type* type); - HashCode _getHashCodeOverride(); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index b45300af8..056577eb0 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -6,9 +6,47 @@ #include "slang-generated-ast-macro.h" #include "slang-diagnostics.h" #include "slang-syntax.h" +#include "slang-ast-val.h" namespace Slang { + +bool ValNodeDesc::operator==(ValNodeDesc const& that) const +{ + if (hashCode != that.hashCode) return false; + if (type != that.type) return false; + if (operands.getCount() != that.operands.getCount()) return false; + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are comparing the operands directly for identity + // (pointer equality) rather than doing the `Val`-level + // equality check. + // + // The rationale here is that nodes that will be created + // via a `NodeDesc` *should* all be going through the + // deduplication path anyway, as should their operands. + // + if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; + } + return true; +} + +void ValNodeDesc::init() +{ + Hasher hasher; + hasher.hashValue(Int(type)); + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are hashing the raw pointer value rather + // than the content of the value node. This is done + // to match the semantics implemented for `==` on + // `NodeDesc`. + // + hasher.hashValue(operands[i].values.nodeOperand); + } + hashCode = hasher.getResult(); +} + Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) { if (!subst) return this; @@ -21,14 +59,103 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff)) } -bool Val::equalsVal(Val* val) +void Val::toText(StringBuilder& out) { - SLANG_AST_NODE_VIRTUAL_CALL(Val, equalsVal, (val)) + SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) } -void Val::toText(StringBuilder& out) +Val* Val::_resolveImplOverride() { - SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) + SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden"); +} + +Val* Val::resolveImpl() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, resolveImpl, ()); +} + +Val* Val::resolve() +{ + auto astBuilder = getCurrentASTBuilder(); + + // If we are not in a proper checking context, just return the previously resolved val. + if (!astBuilder) + return m_resolvedVal? m_resolvedVal : this; + if (m_resolvedVal && m_resolvedValEpoch == getCurrentASTBuilder()->getEpoch()) + { + SLANG_ASSERT(as<Val>(m_resolvedVal)); + return m_resolvedVal; + } + + // Update epoch now to avoid infinite recursion. + m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch(); + m_resolvedVal = this; + m_resolvedVal = resolveImpl(); + + // Check if we are resolved to an existing Val in the AST cache. + ValNodeDesc newDesc; + newDesc.type = m_resolvedVal->astNodeType; + for (auto operand : m_resolvedVal->m_operands) + { + if (operand.kind == ValNodeOperandKind::ValNode) + { + auto valOperand = as<Val>(operand.values.nodeOperand); + if (valOperand) + { + operand.values.nodeOperand = valOperand->resolve(); + } + } + newDesc.operands.add(operand); + } + newDesc.init(); + + NodeBase* existingNode = nullptr; + if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) + m_resolvedVal = as<Val>(existingNode); + +#ifdef _DEBUG + if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0) + { + //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking."); + } +#endif + return m_resolvedVal; +} + +ValNodeDesc Val::getDesc() +{ + ValNodeDesc desc; + desc.type = astNodeType; + for (auto operand : m_operands) + desc.operands.add(operand); + desc.init(); + return desc; +} + +Val* Val::defaultResolveImpl() +{ + // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache. + ValNodeDesc newDesc; + newDesc.type = astNodeType; + for (auto operand : m_operands) + { + if (operand.kind == ValNodeOperandKind::ValNode) + { + auto valOperand = as<Val>(operand.values.nodeOperand); + if (valOperand) + { + operand.values.nodeOperand = valOperand->resolve(); + } + } + newDesc.operands.add(operand); + } + newDesc.init(); + auto astBuilder = getCurrentASTBuilder(); + + NodeBase* existingNode = nullptr; + if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) + return as<Val>(existingNode); + return this; } String Val::toString() @@ -40,7 +167,7 @@ String Val::toString() HashCode Val::getHashCode() { - SLANG_AST_NODE_VIRTUAL_CALL(Val, getHashCode, ()) + return Slang::getHashCode(resolve()); } Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) @@ -52,124 +179,84 @@ Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, return this; } -bool Val::_equalsValOverride(Val* val) -{ - SLANG_UNUSED(val); - SLANG_UNEXPECTED("Val::_equalsValOverride not overridden"); - //return false; -} - void Val::_toTextOverride(StringBuilder& out) { SLANG_UNUSED(out); SLANG_UNEXPECTED("Val::_toStringOverride not overridden"); } -HashCode Val::_getHashCodeOverride() -{ - SLANG_UNEXPECTED("Val::_getHashCodeOverride not overridden"); - //return HashCode(0); -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConstantIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ConstantIntVal::_equalsValOverride(Val* val) -{ - if (auto intVal = as<ConstantIntVal>(val)) - return value == intVal->value; - return false; -} - void ConstantIntVal::_toTextOverride(StringBuilder& out) { - out << value; -} - -HashCode ConstantIntVal::_getHashCodeOverride() -{ - return (HashCode)value; + out << getValue(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool GenericParamIntVal::_equalsValOverride(Val* val) -{ - if (auto genericParamVal = as<GenericParamIntVal>(val)) - { - return declRef.equals(genericParamVal->declRef); - } - return false; -} - void GenericParamIntVal::_toTextOverride(StringBuilder& out) { - Name* name = declRef.getName(); + Name* name = getDeclRef().getName(); if (name) { out << name->text; } } -HashCode GenericParamIntVal::_getHashCodeOverride() -{ - return declRef.getHashCode() ^ HashCode(0xFFFF); -} - Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff) { // search for a substitution that might apply to us - for (auto s = subst.substitutions; s; s = s->getOuter()) + auto outerGeneric = as<GenericDecl>(paramDecl->parentDecl); + if (!outerGeneric) + return paramVal; + + GenericAppDeclRef* genAppArgs = subst.findGenericAppDeclRef(outerGeneric); + if (!genAppArgs) { - auto genSubst = as<GenericSubstitution>(s); - if (!genSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genSubst->getGenericDecl(); - if (genericDecl != paramDecl->parentDecl) - continue; - - // In some cases, we construct a `DeclRef` to a `GenericDecl` - // (or a declaration under one) that only includes argument - // values for a prefix of the parameters of the generic. - // - // If we aren't careful, we could end up indexing into the - // argument list past the available range. - // - Count argCount = genSubst->getArgs().getCount(); + return paramVal; + } - Count argIndex = 0; - for (auto m : genericDecl->members) + auto args = genAppArgs->getArgs(); + + // In some cases, we construct a `DeclRef` to a `GenericDecl` + // (or a declaration under one) that only includes argument + // values for a prefix of the parameters of the generic. + // + // If we aren't careful, we could end up indexing into the + // argument list past the available range. + // + Count argCount = args.getCount(); + + Count argIndex = 0; + for (auto m : outerGeneric->members) + { + // If we have run out of arguments, then we can stop + // iterating over the parameters, because `this` + // parameter will not be replaced with anything by + // the substituion. + // + if (argIndex >= argCount) { - // If we have run out of arguments, then we can stop - // iterating over the parameters, because `this` - // parameter will not be replaced with anything by - // the substituion. - // - if (argIndex >= argCount) - { - return paramVal; - } + return paramVal; + } - if (m == paramDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genSubst->getArgs()[argIndex]; - } - else if (const auto typeParam = as<GenericTypeParamDecl>(m)) - { - argIndex++; - } - else if (const auto valParam = as<GenericValueParamDecl>(m)) - { - argIndex++; - } - else - { - } + if (m == paramDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return args[argIndex]; + } + else if (const auto typeParam = as<GenericTypeParamDecl>(m)) + { + argIndex++; + } + else if (const auto valParam = as<GenericValueParamDecl>(m)) + { + argIndex++; + } + else + { } } @@ -180,7 +267,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) { - if (auto result = maybeSubstituteGenericParam(this, declRef.getDecl(), subst, ioDiff)) + if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff)) return result; return this; @@ -188,21 +275,11 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ErrorIntVal::_equalsValOverride(Val* val) -{ - return as<ErrorIntVal>(val); -} - void ErrorIntVal::_toTextOverride(StringBuilder& out) { out << toSlice("<error>"); } -HashCode ErrorIntVal::_getHashCodeOverride() -{ - return HashCode(typeid(this).hash_code()); -} - Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { SLANG_UNUSED(astBuilder); @@ -211,97 +288,110 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe return this; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -// TODO: should really have a `type.cpp` and a `witness.cpp` - -bool TypeEqualityWitness::_equalsValOverride(Val* val) -{ - auto otherWitness = as<TypeEqualityWitness>(val); - if (!otherWitness) - return false; - return sub->equals(otherWitness->sub); -} - Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - TypeEqualityWitness* rs = astBuilder->create<TypeEqualityWitness>(); - rs->sub = as<Type>(sub->substituteImpl(astBuilder, subst, ioDiff)); - rs->sup = as<Type>(sup->substituteImpl(astBuilder, subst, ioDiff)); + auto type = as<Type>(getSub()->substituteImpl(astBuilder, subst, ioDiff)); + TypeEqualityWitness* rs = astBuilder->getOrCreate<TypeEqualityWitness>(type, type); return rs; } void TypeEqualityWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("TypeEqualityWitness(") << sub << toSlice(")"); -} - -HashCode TypeEqualityWitness::_getHashCodeOverride() -{ - return sub->getHashCode(); + out << toSlice("TypeEqualityWitness(") << getSub() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclaredSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool DeclaredSubtypeWitness::_equalsValOverride(Val* val) +Val* DeclaredSubtypeWitness::_resolveImplOverride() { - auto otherWitness = as<DeclaredSubtypeWitness>(val); - if (!otherWitness) - return false; + auto resolvedDeclRef = getDeclRef().declRefBase->resolve(); + if (auto resolvedVal = as<SubtypeWitness>(resolvedDeclRef)) + return resolvedVal; - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && declRef.equals(otherWitness->declRef); + auto newSub = as<Type>(getSub()->resolve()); + auto newSup = as<Type>(getSup()->resolve()); + + // If we are trying to lookup for a witness that A<:B from a witness(A<:B), we + // can just return the witness itself. + if (auto lookupDeclRef = as<LookupDeclRef>(resolvedDeclRef)) + { + auto witnessToLookupFrom = lookupDeclRef->getWitness(); + if (witnessToLookupFrom->getSub()->equals(newSub) && + witnessToLookupFrom->getSup()->equals(newSup)) + return witnessToLookupFrom; + } + auto newDeclRef = as<DeclRefBase>(resolvedDeclRef); + if (!newDeclRef) + newDeclRef = getDeclRef().declRefBase; + if (newSub != getSub() || newSup != getSup() || newDeclRef != getDeclRef()) + { + return getCurrentASTBuilder()->getDeclaredSubtypeWitness(newSub, newSup, newDeclRef); + } + return this; } Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { - if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>()) + if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>()) { - auto genConstraintDecl = genConstraintDeclRef.getDecl(); + auto genericDecl = as<GenericDecl>(getDeclRef().getDecl()->parentDecl); + if (!genericDecl) + goto breakLabel; // search for a substitution that might apply to us - for (auto s = subst.substitutions; s; s = s->getOuter()) + auto args = tryGetGenericArguments(subst, genericDecl); + if (args.getCount() == 0) + goto breakLabel; + + bool found = false; + Index index = 0; + for (auto m : genericDecl->members) { - if (auto genericSubst = as<GenericSubstitution>(s)) + if (auto constraintParam = as<GenericTypeConstraintDecl>(m)) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->getGenericDecl(); - if (genericDecl != genConstraintDecl->parentDecl) - continue; - - bool found = false; - Index index = 0; - for (auto m : genericDecl->members) + if (constraintParam == getDeclRef().getDecl()) { - if (auto constraintParam = as<GenericTypeConstraintDecl>(m)) - { - if (constraintParam == declRef.getDecl()) - { - found = true; - break; - } - index++; - } - } - if (found) - { - (*ioDiff)++; - auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + - genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount()); - return genericSubst->getArgs()[index + ordinaryParamCount]; + found = true; + break; } + index++; + } + } + if (found) + { + auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + + genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); + if (index + ordinaryParamCount < args.getCount()) + { + (*ioDiff)++; + return args[index + ordinaryParamCount]; + } + else + { + // When the `subst` represents a partial substitution, we may not have a corresponding argument. + // In this case we just return the original witness. + // + goto breakLabel; } } } + else if (auto thisTypeConstraintDeclRef = getDeclRef().as<ThisTypeConstraintDecl>()) + { + auto lookupSubst = subst.findLookupDeclRef(); + if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) + { + (*ioDiff)++; + return lookupSubst->getWitness(); + } + } + +breakLabel:; // Perform substitution on the constituent elements. int diff = 0; - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); + if (!diff) return this; @@ -317,7 +407,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // so we'll need to change this location in the code if we ever clean // up the hierarchy. // - if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.getDecl())) + if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(getDeclRef().getDecl())) { if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl)) { @@ -326,12 +416,12 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub // At this point we have a constraint decl for an associated type, // and we nee to see if we are dealing with a concrete substitution // for the interface around that associated type. - if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.getSubst(), interfaceDecl)) + if (auto thisTypeWitness = findThisTypeWitness(subst, interfaceDecl)) { // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); switch (requirementWitness.getFlavor()) { default: @@ -348,6 +438,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } } + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); auto rs = astBuilder->getDeclaredSubtypeWitness( substSub, substSup, substDeclRef); return rs; @@ -355,34 +446,17 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("DeclaredSubtypeWitness(") << sub << toSlice(", ") << sup << toSlice(", ") << declRef << toSlice(")"); -} - -HashCode DeclaredSubtypeWitness::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool TransitiveSubtypeWitness::_equalsValOverride(Val* val) -{ - auto otherWitness = as<TransitiveSubtypeWitness>(val); - if (!otherWitness) - return false; - - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && subToMid->equalsVal(otherWitness->subToMid) - && midToSup->equalsVal(otherWitness->midToSup); -} - Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { int diff = 0; - SubtypeWitness* substSubToMid = as<SubtypeWitness>(subToMid->substituteImpl(astBuilder, subst, &diff)); - SubtypeWitness* substMidToSup = as<SubtypeWitness>(midToSup->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substSubToMid = as<SubtypeWitness>(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substMidToSup = as<SubtypeWitness>(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -407,16 +481,7 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) // witnesses, and rely on them to print // the starting and ending types. - out << toSlice("TransitiveSubtypeWitness(") << subToMid << toSlice(", ") << midToSup << toSlice(")"); -} - -HashCode TransitiveSubtypeWitness::_getHashCodeOverride() -{ - auto hash = sub->getHashCode(); - hash = combineHash(hash, sup->getHashCode()); - hash = combineHash(hash, subToMid->getHashCode()); - hash = combineHash(hash, midToSup->getHashCode()); - return hash; + out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -425,9 +490,9 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a { int diff = 0; - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - auto substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff)); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); + auto substWitness = as<SubtypeWitness>(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -447,138 +512,34 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a // simplification logic as needed. // return astBuilder->getExtractFromConjunctionSubtypeWitness( - substSub, substSup, substWitness, indexInConjunction); + substSub, substSup, substWitness, getIndexInConjunction()); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) -{ - if (auto extractWitness = as<ExtractExistentialSubtypeWitness>(val)) - { - return declRef.equals(extractWitness->declRef); - } - return false; -} - void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("extractExistentialValue(") << declRef << toSlice(")"); -} - -HashCode ExtractExistentialSubtypeWitness::_getHashCodeOverride() -{ - return declRef.getHashCode(); + out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")"); } Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); if (!diff) return this; (*ioDiff)++; - ExtractExistentialSubtypeWitness* substValue = astBuilder->create<ExtractExistentialSubtypeWitness>(); - substValue->declRef = substDeclRef; - substValue->sub = substSub; - substValue->sup = substSup; + ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate<ExtractExistentialSubtypeWitness>( + substSub, substSup, substDeclRef); return substValue; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val) -{ - auto taggedUnionWitness = as<TaggedUnionSubtypeWitness>(val); - if (!taggedUnionWitness) - return false; - - auto caseCount = caseWitnesses.getCount(); - if (caseCount != taggedUnionWitness->caseWitnesses.getCount()) - return false; - - for (Index ii = 0; ii < caseCount; ++ii) - { - if (!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii])) - return false; - } - - return true; -} - -void TaggedUnionSubtypeWitness::_toTextOverride(StringBuilder& out) -{ - out << toSlice("TaggedUnionSubtypeWitness("); - bool first = true; - for (auto caseWitness : caseWitnesses) - { - if (!first) - { - out << toSlice(", "); - } - first = false; - - out << caseWitness; - } - out << toSlice(")"); -} - -HashCode TaggedUnionSubtypeWitness::_getHashCodeOverride() -{ - HashCode hash = 0; - for (auto caseWitness : caseWitnesses) - { - hash = combineHash(hash, caseWitness->getHashCode()); - } - return hash; -} - -Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - - List<SubtypeWitness*> substCaseWitnesses; - for (auto caseWitness : caseWitnesses) - { - substCaseWitnesses.add( - as<SubtypeWitness>(caseWitness->substituteImpl(astBuilder, subst, &diff))); - } - - if (!diff) - return this; - - (*ioDiff)++; - - TaggedUnionSubtypeWitness* substWitness = astBuilder->create<TaggedUnionSubtypeWitness>(); - substWitness->sub = substSub; - substWitness->sup = substSup; - substWitness->caseWitnesses.swapWith(substCaseWitnesses); - return substWitness; -} - -bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val) -{ - auto other = as<ConjunctionSubtypeWitness>(val); - if (!other) - return false; - - for (Index i = 0; i < kComponentCount; ++i) - { - if (!other->componentWitnesses[i]) return false; - if (!other->componentWitnesses[i]->equalsVal(componentWitnesses[i])) return false; - } - return true; -} - void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ConjunctionSubtypeWitness("; @@ -586,34 +547,23 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { if (i != 0) out << ","; - auto w = componentWitnesses[i]; + auto w = getComponentWitness(i); if (w) out << w; } out << ")"; } -HashCode ConjunctionSubtypeWitness::_getHashCodeOverride() -{ - HashCode result = 0; - for (Index i = 0; i < kComponentCount; ++i) - { - auto w = componentWitnesses[i]; - if (w) result = combineHash(result, w->getHashCode()); - } - return result; -} - Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; Val* substComponentWitnesses[kComponentCount]; - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff)); for (Index i = 0; i < kComponentCount; ++i) { - auto w = componentWitnesses[i]; + auto w = getComponentWitness(i); substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr; } @@ -630,65 +580,25 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, auto result = astBuilder->getConjunctionSubtypeWitness( substSub, substSup, - componentWitnesses[0], - componentWitnesses[1]); + as<SubtypeWitness>(substComponentWitnesses[0]), + as<SubtypeWitness>(substComponentWitnesses[1])); return result; } -bool ExtractFromConjunctionSubtypeWitness::_equalsValOverride(Val* val) -{ - if (auto other = as<ExtractFromConjunctionSubtypeWitness>(val)) - { - if(!sub->equals(other->sub)) return false; - if(!sup->equals(other->sup)) return false; - if(indexInConjunction != other->indexInConjunction) return false; - - return true; - } - return false; -} - void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ExtractFromConjunctionSubtypeWitness("; - if (conjunctionWitness) - out << conjunctionWitness; - if (sub) - out << sub; + if (getConjunctionWitness()) + out << getConjunctionWitness(); + if (getSub()) + out << getSub(); out << ","; - if (sup) - out << sup; - out << "," << indexInConjunction; + if (getSup()) + out << getSup(); + out << "," << getIndexInConjunction(); out << ")"; } -HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride() -{ - return combineHash( - conjunctionWitness ? conjunctionWitness->getHashCode() : 0, - sub ? sub->getHashCode() : 0, - sup ? sup->getHashCode() : 0, - indexInConjunction); -} - -// ModifierVal - -bool ModifierVal::_equalsValOverride(Val* val) -{ - // TODO: This is assuming we can fully deduplicate the values that represent - // modifiers, which may not actually be the case if there are multiple modules - // being combined that use different `ASTBuilder`s. - // - return this == val; -} - -HashCode ModifierVal::_getHashCodeOverride() -{ - Hasher hasher; - hasher.hashValue((void*) this); - return hasher.getResult(); -} - // UNormModifierVal void UNormModifierVal::_toTextOverride(StringBuilder& out) @@ -735,48 +645,14 @@ Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitu // PolynomialIntVal -bool PolynomialIntVal::_equalsValOverride(Val* val) -{ - if (auto genericParamVal = as<GenericParamIntVal>(val)) - { - return constantTerm == 0 && terms.getCount() == 1 && - terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 && - terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) && - terms[0]->paramFactors[0]->power == 1; - } - else if (auto otherPolynomial = as<PolynomialIntVal>(val)) - { - if (constantTerm != otherPolynomial->constantTerm) - return false; - if (terms.getCount() != otherPolynomial->terms.getCount()) - return false; - for (Index i = 0; i < terms.getCount(); i++) - { - auto& thisTerm = *(terms[i]); - auto& thatTerm = *(otherPolynomial->terms[i]); - if (thisTerm.constFactor != thatTerm.constFactor) - return false; - if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount()) - return false; - for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++) - { - if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power) - return false; - if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param)) - return false; - } - } - return true; - } - return false; -} - void PolynomialIntVal::_toTextOverride(StringBuilder& out) { + auto constantTerm = getConstantTerm(); + auto terms = getTerms(); for (Index i = 0; i < terms.getCount(); i++) { auto& term = *(terms[i]); - if (term.constFactor > 0) + if (term.getConstFactor() > 0) { if (i > 0) out << "+"; @@ -784,14 +660,14 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) else out << "-"; bool isFirstFactor = true; - if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0) + if (abs(term.getConstFactor()) != 1 || term.getParamFactors().getCount() == 0) { - out << abs(term.constFactor); + out << abs(term.getConstFactor()); isFirstFactor = false; } - for (Index j = 0; j < term.paramFactors.getCount(); j++) + for (Index j = 0; j < term.getParamFactors().getCount(); j++) { - auto factor = term.paramFactors[j]; + auto factor = term.getParamFactors()[j]; if (isFirstFactor) { isFirstFactor = false; @@ -800,10 +676,10 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) { out << "*"; } - factor->param->toText(out); - if (factor->power != 1) + factor->getParam()->toText(out); + if (factor->getPower() != 1) { - out << "^^" << factor->power; + out << "^^" << factor->getPower(); } } } @@ -821,227 +697,304 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) } } -HashCode PolynomialIntVal::_getHashCodeOverride() +struct PolynomialIntValBuilder { - HashCode result = (HashCode)constantTerm; - for (auto& term : terms) + ASTBuilder* astBuilder; + + IntegerLiteralValue constantTerm = 0; + List<PolynomialIntValTerm*> terms; + + PolynomialIntValBuilder(ASTBuilder* inAstBuilder) + : astBuilder(inAstBuilder) + {} + + // compute val += opreand*multiplier; + bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier) { - if (!term) continue; - result = combineHash(result, (HashCode)term->constFactor); - for (auto& factor : term->paramFactors) + if (auto c = as<ConstantIntVal>(operand)) { - result = combineHash(result, factor->param->getHashCode()); - result = combineHash(result, (HashCode)factor->power); + constantTerm += c->getValue() * multiplier; + return true; } + else if (auto poly = as<PolynomialIntVal>(operand)) + { + constantTerm += poly->getConstantTerm() * multiplier; + for (auto term : poly->getTerms()) + { + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + multiplier * term->getConstFactor(), term->getParamFactors()); + terms.add(newTerm); + } + return true; + } + else if (auto genVal = as<IntVal>(operand)) + { + auto factor = astBuilder->getOrCreate<PolynomialIntValFactor>(genVal, 1); + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(multiplier, makeArrayViewSingle(factor)); + terms.add(term); + return true; + } + return false; } - return result; -} + + IntVal* canonicalize(Type* type) + { + List<PolynomialIntValTerm*> newTerms; + IntegerLiteralValue newConstantTerm = constantTerm; + auto addTerm = [&](PolynomialIntValTerm* newTerm) + { + for (auto& term : newTerms) + { + if (term->canCombineWith(*newTerm)) + { + term = astBuilder->getOrCreate<PolynomialIntValTerm>( + term->getConstFactor() + newTerm->getConstFactor(), + term->getParamFactors()); + return; + } + } + newTerms.add(newTerm); + }; + for (auto term : terms) + { + if (term->getConstFactor() == 0) + continue; + List<PolynomialIntValFactor*> newFactors; + List<bool> factorIsDifferent; + for (Index i = 0; i < term->getParamFactors().getCount(); i++) + { + auto factor = term->getParamFactors()[i]; + bool factorFound = false; + for (Index j = 0; j < newFactors.getCount(); j++) + { + auto& newFactor = newFactors[j]; + if (factor->getParam()->equals(newFactor->getParam())) + { + if (!factorIsDifferent[j]) + { + factorIsDifferent[j] = true; + auto clonedFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower()); + newFactor = clonedFactor; + } + newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower() + factor->getPower()); + factorFound = true; + break; + } + } + if (!factorFound) + { + newFactors.add(factor); + factorIsDifferent.add(false); + } + } + List<PolynomialIntValFactor*> newFactors2; + // Remove zero-powered factors. + for (auto factor : newFactors) + { + if (factor->getPower() != 0) + newFactors2.add(factor); + } + if (newFactors2.getCount() == 0) + { + newConstantTerm += term->getConstFactor(); + continue; + } + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + bool isDifferent = false; + if (newFactors2.getCount() != term->getParamFactors().getCount()) + isDifferent = true; + if (!isDifferent) + { + for (Index i = 0; i < term->getParamFactors().getCount(); i++) + if (term->getParamFactors()[i] != newFactors2[i]) + { + isDifferent = true; + break; + } + } + if (!isDifferent) + { + addTerm(term); + } + else + { + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor(), newFactors2.getArrayView()); + addTerm(newTerm); + } + } + List<PolynomialIntValTerm*> newTerms2; + for (auto term : newTerms) + { + if (term->getConstFactor() == 0) + continue; + newTerms2.add(term); + } + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + terms = _Move(newTerms2); + constantTerm = newConstantTerm; + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 && + terms[0]->getParamFactors()[0]->getPower() == 1) + { + return terms[0]->getParamFactors()[0]->getParam(); + } + if (terms.getCount() == 0) + return astBuilder->getIntVal(type, constantTerm); + return nullptr; + } + + IntVal* getIntVal(Type* type) + { + if (auto canVal = canonicalize(type)) + return canVal; + return astBuilder->getOrCreate<PolynomialIntVal>(type, constantTerm, terms.getArrayView()); + } +}; Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - IntegerLiteralValue evaluatedConstantTerm = constantTerm; - List<PolynomialIntValTerm*> evaluatedTerms; - for (auto& term : terms) + PolynomialIntValBuilder builder(astBuilder); + for (auto& term : getTerms()) { IntegerLiteralValue evaluatedTermConstFactor; List<PolynomialIntValFactor*> evaluatedTermParamFactors; - evaluatedTermConstFactor = term->constFactor; - for (auto& factor : term->paramFactors) + evaluatedTermConstFactor = term->getConstFactor(); + for (auto& factor : term->getParamFactors()) { - auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); + auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff); if (auto constantVal = as<ConstantIntVal>(substResult)) { - evaluatedTermConstFactor *= constantVal->value; + evaluatedTermConstFactor *= constantVal->getValue(); } else if (auto intResult = as<IntVal>(substResult)) { - auto newFactor = astBuilder->create<PolynomialIntValFactor>(); - newFactor->param = intResult; - newFactor->power = factor->power; + auto newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(intResult, factor->getPower()); evaluatedTermParamFactors.add(newFactor); } } if (evaluatedTermParamFactors.getCount() == 0) { - evaluatedConstantTerm += evaluatedTermConstFactor; + builder.constantTerm += evaluatedTermConstFactor; } else { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->paramFactors = _Move(evaluatedTermParamFactors); - newTerm->constFactor = evaluatedTermConstFactor; - evaluatedTerms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView()); + builder.terms.add(newTerm); } } *ioDiff += diff; - if (evaluatedTerms.getCount() == 0) - return astBuilder->getIntVal(type, evaluatedConstantTerm); + if (builder.terms.getCount() == 0) + return astBuilder->getIntVal(getType(), builder.constantTerm); if (diff != 0) { - auto newPolynomial = astBuilder->create<PolynomialIntVal>(type); - newPolynomial->constantTerm = evaluatedConstantTerm; - newPolynomial->terms = _Move(evaluatedTerms); - return newPolynomial->canonicalize(astBuilder); + return builder.getIntVal(getType()); } return this; } - -// compute val += opreand*multiplier; -bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) -{ - if (auto c = as<ConstantIntVal>(operand)) - { - val->constantTerm += c->value * multiplier; - return true; - } - else if (auto poly = as<PolynomialIntVal>(operand)) - { - val->constantTerm += poly->constantTerm * multiplier; - for (auto term : poly->terms) - { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = multiplier * term->constFactor; - newTerm->paramFactors = term->paramFactors; - val->terms.add(newTerm); - } - return true; - } - else if (auto genVal = as<IntVal>(operand)) - { - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = multiplier; - auto factor = astBuilder->create<PolynomialIntValFactor>(); - factor->power = 1; - factor->param = genVal; - term->paramFactors.add(factor); - val->terms.add(term); - return true; - } - return false; -} - -PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) +IntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) { - auto result = astBuilder->create<PolynomialIntVal>(base->type); - if (!addToPolynomialTerm(astBuilder, result, base, -1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(base, -1); + return builder.getIntVal(base->getType()); } -PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create<PolynomialIntVal>(op0->type); - if (!addToPolynomialTerm(astBuilder, result, op0, 1)) - return nullptr; - if (!addToPolynomialTerm(astBuilder, result, op1, -1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(op0, 1); + builder.addToPolynomialTerm(op1, -1); + return builder.getIntVal(op0->getType()); } -PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create<PolynomialIntVal>(op0->type); - if (!addToPolynomialTerm(astBuilder, result, op0, 1)) - return nullptr; - if (!addToPolynomialTerm(astBuilder, result, op1, 1)) - return nullptr; - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + builder.addToPolynomialTerm(op0, 1); + builder.addToPolynomialTerm(op1, 1); + return builder.getIntVal(op0->getType()); } -PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { if (auto poly0 = as<PolynomialIntVal>(op0)) { if (auto poly1 = as<PolynomialIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(poly0->type); + PolynomialIntValBuilder builder(astBuilder); // add poly0.constant * poly1.constant - result->constantTerm = poly0->constantTerm * poly1->constantTerm; + builder.constantTerm = poly0->getConstantTerm() * poly1->getConstantTerm(); // add poly0.constant * poly1.terms - if (poly0->constantTerm != 0) + if (poly0->getConstantTerm() != 0) { - for (auto term : poly1->terms) + for (auto term : poly1->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = poly0->constantTerm * term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors()); + builder.terms.add(newTerm); } } // add poly1.constant * poly0.terms - if (poly1->constantTerm != 0) + if (poly1->getConstantTerm() != 0) { - for (auto term : poly0->terms) + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = poly1->constantTerm * term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + poly1->getConstantTerm() * term->getConstFactor(), + term->getParamFactors()); + builder.terms.add(newTerm); } } // add poly1.terms * poly0.terms - for (auto term0 : poly0->terms) + for (auto term0 : poly0->getTerms()) { - for (auto term1 : poly1->terms) + for (auto term1 : poly1->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term0->constFactor * term1->constFactor; - newTerm->paramFactors.addRange(term0->paramFactors); - newTerm->paramFactors.addRange(term1->paramFactors); - result->terms.add(newTerm); + List<PolynomialIntValFactor*> newFactors; + for (auto f : term0->getParamFactors()) newFactors.add(f); + for (auto f : term1->getParamFactors()) newFactors.add(f); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView()); + builder.terms.add(newTerm); } } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(op0->getType()); } else if (auto cVal1 = as<ConstantIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(poly0->type); - result->constantTerm = poly0->constantTerm * cVal1->value; - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - for (auto term : poly0->terms) + PolynomialIntValBuilder builder(astBuilder); + builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue(); + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor * cVal1->value; - newTerm->paramFactors.addRange(term->paramFactors); - newTerm->paramFactors.add(factor1); - result->terms.add(newTerm); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor() * cVal1->getValue(), term->getParamFactors()); + builder.terms.add(newTerm); } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(poly0->getType()); } else if (auto val1 = as<IntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(poly0->type); - result->constantTerm = 0; - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - if (poly0->constantTerm != 0) + PolynomialIntValBuilder builder(astBuilder); + auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1); + if (poly0->getConstantTerm() != 0) { - auto term0 = astBuilder->create<PolynomialIntValTerm>(); - term0->constFactor = poly0->constantTerm; - term0->paramFactors.add(factor1); - result->terms.add(term0); + auto term0 = astBuilder->getOrCreate<PolynomialIntValTerm>(poly0->getConstantTerm(), makeArrayViewSingle(factor1)); + builder.terms.add(term0); } - for (auto term : poly0->terms) + for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor; - newTerm->paramFactors.addRange(term->paramFactors); - newTerm->paramFactors.add(factor1); - result->terms.add(newTerm); + List<PolynomialIntValFactor*> newFactors; + for (auto f: term->getParamFactors()) + newFactors.add(f); + newFactors.add(factor1); + auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>( + term->getConstFactor(), newFactors.getArrayView()); + builder.terms.add(newTerm); } - result->canonicalize(astBuilder); - return result; + return builder.getIntVal(poly0->getType()); } else return nullptr; @@ -1058,184 +1011,48 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto cVal1 = as<ConstantIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(val0->type); - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = cVal1->value; - auto factor0 = astBuilder->create<PolynomialIntValFactor>(); - factor0->power = 1; - factor0->param = val0; - term->paramFactors.add(factor0); - result->terms.add(term); - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1); + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>( + cVal1->getValue(), makeArrayView(&factor0, 1)); + builder.terms.add(term); + return builder.getIntVal(val0->getType()); } else if (auto val1 = as<IntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(val0->type); - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = 1; - auto factor0 = astBuilder->create<PolynomialIntValFactor>(); - factor0->power = 1; - factor0->param = val0; - term->paramFactors.add(factor0); - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - term->paramFactors.add(factor1); - result->terms.add(term); - result->canonicalize(astBuilder); - return result; + PolynomialIntValBuilder builder(astBuilder); + auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1); + auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1); + PolynomialIntValFactor* newFactors[] = { factor0, factor1 }; + auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(1, makeArrayView(newFactors)); + builder.terms.add(term); + return builder.getIntVal(val0->getType()); } } return nullptr; } -IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) -{ - List<PolynomialIntValTerm*> newTerms; - IntegerLiteralValue newConstantTerm = constantTerm; - auto addTerm = [&](PolynomialIntValTerm* newTerm) - { - for (auto term : newTerms) - { - if (term->canCombineWith(*newTerm)) - { - term->constFactor += newTerm->constFactor; - return; - } - } - newTerms.add(newTerm); - }; - for (auto term : terms) - { - if (term->constFactor == 0) - continue; - List<PolynomialIntValFactor*> newFactors; - List<bool> factorIsDifferent; - for (Index i = 0; i < term->paramFactors.getCount(); i++) - { - auto factor = term->paramFactors[i]; - bool factorFound = false; - for (Index j = 0; j < newFactors.getCount(); j++) - { - auto& newFactor = newFactors[j]; - if (factor->param->equalsVal(newFactor->param)) - { - if (!factorIsDifferent[j]) - { - factorIsDifferent[j] = true; - auto clonedFactor = builder->create<PolynomialIntValFactor>(); - clonedFactor->param = newFactor->param; - clonedFactor->power = newFactor->power; - newFactor = clonedFactor; - } - newFactor->power += factor->power; - factorFound = true; - break; - } - } - if (!factorFound) - { - newFactors.add(factor); - factorIsDifferent.add(false); - } - } - List<PolynomialIntValFactor*> newFactors2; - for (auto factor : newFactors) - { - if (factor->power != 0) - newFactors2.add(factor); - } - if (newFactors2.getCount() == 0) - { - newConstantTerm += term->constFactor; - continue; - } - newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); - bool isDifferent = false; - if (newFactors2.getCount() != term->paramFactors.getCount()) - isDifferent = true; - if (!isDifferent) - { - for (Index i = 0; i < term->paramFactors.getCount(); i++) - if (term->paramFactors[i] != newFactors2[i]) - { - isDifferent = true; - break; - } - } - if (!isDifferent) - { - addTerm(term); - } - else - { - auto newTerm = builder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor; - newTerm->paramFactors = _Move(newFactors2); - addTerm(newTerm); - } - } - List<PolynomialIntValTerm*> newTerms2; - for (auto term : newTerms) - { - if (term->constFactor == 0) - continue; - newTerms2.add(term); - } - newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); - terms = _Move(newTerms2); - constantTerm = newConstantTerm; - if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 && - terms[0]->paramFactors[0]->power == 1) - { - return terms[0]->paramFactors[0]->param; - } - if (terms.getCount() == 0) - return builder->getIntVal(type, constantTerm); - return this; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool TypeCastIntVal::_equalsValOverride(Val* val) -{ - if (auto typeCastIntVal = as<TypeCastIntVal>(val)) - { - if (!type->equals(typeCastIntVal->type)) - return false; - if (!base->equalsVal(typeCastIntVal->base)) - return false; - return true; - } - return false; -} void TypeCastIntVal::_toTextOverride(StringBuilder& out) { - type->toText(out); + getType()->toText(out); out << "("; - base->toText(out); + getBase()->toText(out); out << ")"; } -HashCode TypeCastIntVal::_getHashCodeOverride() -{ - HashCode result = type->getHashCode(); - result = combineHash(result, base->getHashCode()); - return result; -} - Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) { SLANG_UNUSED(sink); if (auto c = as<ConstantIntVal>(base)) { - IntegerLiteralValue resultValue = c->value; + IntegerLiteralValue resultValue = c->getValue(); auto baseType = as<BasicExpressionType>(resultType); if (baseType) { - switch (baseType->baseType) + switch (baseType->getBaseType()) { case BaseType::Int: resultValue = (int)resultValue; @@ -1275,11 +1092,11 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto substBase = base->substituteImpl(astBuilder, subst, &diff); - if (substBase != base) + auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff); + if (substBase != getBase()) diff++; - auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff)); - if (substType != type) + auto substType = as<Type>(getType()->substituteImpl(astBuilder, subst, &diff)); + if (substType != getType()) diff++; *ioDiff += diff; if (diff) @@ -1289,7 +1106,7 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return newVal; else { - auto result = astBuilder->create<TypeCastIntVal>(substType, substBase); + auto result = astBuilder->getOrCreate<TypeCastIntVal>(substType, substBase); return result; } } @@ -1297,29 +1114,20 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return this; } - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool FuncCallIntVal::_equalsValOverride(Val* val) +Val* TypeCastIntVal::_resolveImplOverride() { - if (auto funcCallIntVal = as<FuncCallIntVal>(val)) - { - if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef)) - return false; - if (args.getCount() != funcCallIntVal->args.getCount()) - return false; - for (Index i = 0; i < args.getCount(); i++) - { - if (!args[i]->equalsVal(funcCallIntVal->args[i])) - return false; - } - return true; - } - return false; + if (auto resolved = tryFoldImpl(getCurrentASTBuilder(), getType(), getBase(), nullptr)) + return resolved; + return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + void FuncCallIntVal::_toTextOverride(StringBuilder& out) { + auto args = getArgs(); + auto funcDeclRef = getFuncDeclRef(); + auto argToText = [&](int index) { if (as<PolynomialIntVal>(args[index]) || as<FuncCallIntVal>(args[index])) @@ -1369,14 +1177,37 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) } } -HashCode FuncCallIntVal::_getHashCodeOverride() +Val* FuncCallIntVal::_resolveImplOverride() { - HashCode result = funcDeclRef.getHashCode(); + auto astBuilder = getCurrentASTBuilder(); + auto args = getArgs(); + auto funcDeclRef = getFuncDeclRef(); + auto funcType = getFuncType(); + + Val* resolvedVal = this; + + auto newFuncDeclRef = as<DeclRefBase>(funcDeclRef.declRefBase->resolve()); + if (!newFuncDeclRef) + return this; + bool diff = false; + List<IntVal*> newArgs; for (auto arg : args) { - result = combineHash(result, arg->getHashCode()); + auto newArg = as<IntVal>(arg->resolve()); + if (!newArg) + return this; + newArgs.add(newArg); + if (newArg != arg) + diff = true; } - return result; + + if (auto resolved = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr)) + resolvedVal = resolved; + else if (diff) + { + resolvedVal = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, funcType, newArgs.getArrayView()); + } + return resolvedVal; } Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) @@ -1413,25 +1244,25 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR #define BINARY_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - resultValue = constArgs[0]->value op constArgs[1]->value; \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ } else #define DIV_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - if (constArgs[1]->value == 0) \ + if (constArgs[1]->getValue() == 0) \ { \ if (sink) \ sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ return nullptr; \ } \ - resultValue = constArgs[0]->value op constArgs[1]->value; \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ } else #define LOGICAL_OPERATOR_CASE(op) \ if (opNameSlice == toSlice(#op)) \ { \ - resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \ + resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \ } else @@ -1463,9 +1294,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR LOGICAL_OPERATOR_CASE(&&) LOGICAL_OPERATOR_CASE(||) // Special cases need their "operator" names quoted. - SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);) - SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;) - SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;) + SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);) + SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();) + SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();) TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");) return astBuilder->getIntVal(resultType, resultValue); @@ -1483,9 +1314,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff); + auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff); List<IntVal*> newArgs; - for (auto& arg : args) + for (auto& arg : getArgs()) { auto substArg = arg->substituteImpl(astBuilder, subst, &diff); if (substArg != arg) @@ -1496,15 +1327,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio if (diff) { // TODO: report diagnostics back. - auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr); + auto newVal = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr); if (newVal) return newVal; else { - auto result = astBuilder->create<FuncCallIntVal>(type); - result->args = _Move(newArgs); - result->funcDeclRef = newFuncDeclRef; - result->funcType = funcType; + auto result = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView()); return result; } } @@ -1514,40 +1342,47 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool WitnessLookupIntVal::_equalsValOverride(Val* val) -{ - if (auto lookupIntVal = as<WitnessLookupIntVal>(val)) - { - if (!witness->equalsVal(lookupIntVal->witness)) - return false; - if (key != lookupIntVal->key) - return false; - return true; - } - return false; -} - void WitnessLookupIntVal::_toTextOverride(StringBuilder& out) { - witness->sub->toText(out); + getWitness()->getSub()->toText(out); out << "."; - out << (key->getName() ? key->getName()->text : "??"); + out << (getKey()->getName() ? getKey()->getName()->text : "??"); } -HashCode WitnessLookupIntVal::_getHashCodeOverride() +Val* WitnessLookupIntVal::_resolveImplOverride() { - HashCode result = witness->getHashCode(); - result = combineHash(result, Slang::getHashCode(key)); - return result; + auto astBuilder = getCurrentASTBuilder(); + + auto newWitness = as<SubtypeWitness>(getWitness()->resolve()); + if (!newWitness) + return this; + + auto witnessVal = tryLookUpRequirementWitness(astBuilder, newWitness, getKey()); + if (witnessVal.getFlavor() == RequirementWitness::Flavor::val) + { + return witnessVal.getVal(); + } + + auto newType = as<Type>(getType()->resolve()); + if (!newType) + return this; + + if (newWitness != getWitness() || newType != getType()) + { + return astBuilder->getOrCreate<WitnessLookupIntVal>(newType, newWitness, getKey()); + } + + return this; } + Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newWitness = witness->substituteImpl(astBuilder, subst, &diff); + auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff); *ioDiff += diff; if (diff) { - auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), key); + auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), getKey()); if (witnessEntry) return witnessEntry; } @@ -1573,51 +1408,93 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes { if (auto result = tryFoldOrNull(astBuilder, witness, key)) return result; - auto witnessResult = astBuilder->create<WitnessLookupIntVal>(); - witnessResult->witness = witness; - witnessResult->key = key; - witnessResult->type = type; + auto witnessResult = astBuilder->getOrCreate<WitnessLookupIntVal>(type, witness, key); return witnessResult; } - -bool DifferentiateVal::_equalsValOverride(Val* val) -{ - if (auto other = as<DifferentiateVal>(val)) - { - return other->astNodeType == astNodeType && other->func == func; - } - return false; -} - void DifferentiateVal::_toTextOverride(StringBuilder& out) { out << "DifferentiateVal("; - out << func; + out << getFunc(); out << ")"; } -HashCode DifferentiateVal::_getHashCodeOverride() -{ - HashCode result = (HashCode)astNodeType; - result = combineHash(result, func.getHashCode()); - return result; -} - Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto newFunc = func.substituteImpl(astBuilder, subst, &diff); + auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff); *ioDiff += diff; if (diff) { auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType)); - result->func = newFunc; + result->getFunc() = newFunc; return result; } // Nothing found: don't substitute. return this; } +Val* DifferentiateVal::_resolveImplOverride() +{ + return this; +} + +Val* PolynomialIntValFactor::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + auto newParam = as<IntVal>(getParam()->resolve()); + if (newParam && newParam != getParam()) + return astBuilder->getOrCreate<PolynomialIntValFactor>(newParam, getPower()); + + return this; +} + +Val* PolynomialIntValTerm::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + bool diff = false; + List<PolynomialIntValFactor*> newFactors; + for (auto factor : getParamFactors()) + { + auto newFactor = as<PolynomialIntValFactor>(factor->resolve()); + if (!newFactor) + return this; + + if (newFactor != factor) + diff = true; + newFactors.add(newFactor); + } + + if (diff) + return astBuilder->getOrCreate<PolynomialIntValTerm>(getConstFactor(), newFactors.getArrayView()); + + return this; +} + +Val* PolynomialIntVal::_resolveImplOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + bool diff = false; + PolynomialIntValBuilder builder(astBuilder); + builder.constantTerm = getConstantTerm(); + for (auto term : getTerms()) + { + auto newTerm = as<PolynomialIntValTerm>(term->resolve()); + if (!newTerm) + return this; + + if (newTerm != term) + diff = true; + builder.terms.add(newTerm); + } + + if (diff) + return builder.getIntVal(getType()); + + return this; +} } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index cb4e94ebb..c45c42e02 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -8,17 +8,139 @@ namespace Slang { // Syntax class definitions for compile-time values. +class DirectDeclRef : public DeclRefBase +{ +public: + SLANG_AST_CLASS(DirectDeclRef) + + DirectDeclRef(Decl* decl) + { + setOperands(decl); + } + + DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + void _toTextOverride(StringBuilder& out); + Val* _resolveImplOverride(); + DeclRefBase* _getBaseOverride(); +}; + +// Represent an static member of a base decl. +// Note that we automatically fold the DeclRef if the path is known to be static. +// For example, MemberDeclRef(DirectDeclRef(A), B) ==> DirectDeclRef(B), +// and MemberDeclRef(MemberDeclRef(A, B), C) ==> MemberDeclRef(A, C). +// +class MemberDeclRef : public DeclRefBase +{ +public: + SLANG_AST_CLASS(MemberDeclRef); + + DeclRefBase* getParentOperand() { return as<DeclRefBase>(getOperand(1)); } + + MemberDeclRef(Decl* decl, DeclRefBase* parent) + { + setOperands(decl, parent); + } + + DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + void _toTextOverride(StringBuilder& out); + + Val* _resolveImplOverride(); + + DeclRefBase* _getBaseOverride(); +}; + + +// Represent a lookup of SuperType::`m_decl` from `lookupSourceType` type that we know conforms to SuperType. +class LookupDeclRef : public DeclRefBase +{ +public: + SLANG_AST_CLASS(LookupDeclRef); + + // m_decl represents the decl in SuperType that we want to lookup. + + // The source type that we are looking up from. + Type* getLookupSource() + { + return as<Type>(getOperand(1)); + } + + // Witness that `lookupSourceType`:SuperType. + SubtypeWitness* getWitness() + { + return as<SubtypeWitness>(getOperand(2)); + } + + LookupDeclRef(Decl* declToLookup, Type* lookupSource, SubtypeWitness* witness) + { + setOperands(declToLookup, lookupSource, witness); + } + + Decl* getSupDecl(); + + DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + void _toTextOverride(StringBuilder& out); + + Val* _resolveImplOverride(); + + DeclRefBase* _getBaseOverride(); + +private: + Val* tryResolve(SubtypeWitness* newWitness, Type* newLookupSource); +}; + +// Represents a specialization of a generic decl. +class GenericAppDeclRef : public DeclRefBase +{ +public: + SLANG_AST_CLASS(GenericAppDeclRef); + + DeclRefBase* getGenericDeclRef() { return as<DeclRefBase>(getOperand(1)); } + Index getArgCount() { return getOperandCount() - 2; } + Val* getArg(Index index) { return getOperand(index + 2); } + + GenericAppDeclRef(Decl* innerDecl, DeclRefBase* genericDeclRef, OperandView<Val> args) + { + m_operands.add(ValNodeOperand(innerDecl)); + m_operands.add(ValNodeOperand(genericDeclRef)); + for (auto arg : args) + { + m_operands.add(ValNodeOperand(arg)); + } + } + + GenericAppDeclRef(Decl* innerDecl, DeclRefBase* genericDeclRef, ConstArrayView<Val*> args) + { + m_operands.add(ValNodeOperand(innerDecl)); + m_operands.add(ValNodeOperand(genericDeclRef)); + for (auto arg : args) + { + m_operands.add(ValNodeOperand(arg)); + } + } + + GenericDecl* getGenericDecl(); + + OperandView<Val> getArgs() { return OperandView<Val>(this, 2, getArgCount()); } + + DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + void _toTextOverride(StringBuilder& out); + + Val* _resolveImplOverride(); + + DeclRefBase* _getBaseOverride(); +}; + // A compile-time integer (may not have a specific concrete value) class IntVal : public Val { SLANG_ABSTRACT_AST_CLASS(IntVal) - Type* type; - - IntVal(Type* inType) - : type(inType) - {} + Type* getType() { return as<Type>(getOperand(0)); } + Val* _resolveImplOverride() { return this; } }; // Trivial case of a value that is just a constant integer @@ -26,18 +148,15 @@ class ConstantIntVal : public IntVal { SLANG_AST_CLASS(ConstantIntVal) - IntegerLiteralValue value; + IntegerLiteralValue getValue() { return getIntConstOperand(1); } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); -protected: ConstantIntVal(Type* inType, IntegerLiteralValue inValue) - : IntVal(inType), value(inValue) - {} - + { + setOperands(inType, inValue); + } }; // The logical "value" of a reference to a generic value parameter @@ -45,30 +164,31 @@ class GenericParamIntVal : public IntVal { SLANG_AST_CLASS(GenericParamIntVal) - DeclRef<VarDeclBase> declRef; + DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) - : IntVal(inType), declRef(inDeclRef) - {} + { + setOperands(inType, inDeclRef); + } }; class TypeCastIntVal : public IntVal { SLANG_AST_CLASS(TypeCastIntVal) - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); - Val* base; - TypeCastIntVal(Type* inType, Val* inBase) : IntVal(inType), base(inBase) {} + Val* getBase() { return getOperand(1); } + TypeCastIntVal(Type* inType, Val* inBase) + { + setOperands(inType, inBase); + } static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink); }; @@ -78,16 +198,21 @@ class FuncCallIntVal : public IntVal { SLANG_AST_CLASS(FuncCallIntVal) - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); - DeclRef<Decl> funcDeclRef; - Type* funcType; - List<IntVal*> args; + DeclRef<Decl> getFuncDeclRef() { return as<DeclRefBase>(getOperand(1)); } + Type* getFuncType() { return as<Type>(getOperand(2)); } + OperandView<IntVal> getArgs() { return OperandView<IntVal>(this, 3, getOperandCount() - 3); } + Index getArgCount() { return getOperandCount() - 3; } - FuncCallIntVal(Type* inType) : IntVal(inType) {} + FuncCallIntVal(Type* inType, DeclRef<Decl> inFuncDeclRef, Type* inFuncType, ArrayView<IntVal*> inArgs) + { + setOperands(inType, inFuncDeclRef, inFuncType); + for (auto arg : inArgs) + m_operands.add(ValNodeOperand(arg)); + } static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink); }; @@ -96,15 +221,17 @@ class WitnessLookupIntVal : public IntVal { SLANG_AST_CLASS(WitnessLookupIntVal) - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); - SubtypeWitness* witness; - Decl* key; + SubtypeWitness* getWitness() { return as<SubtypeWitness>(getOperand(1)); } + Decl* getKey() { return as<Decl>(getDeclOperand(2)); } - WitnessLookupIntVal(Type* inType) : IntVal(inType) {} + WitnessLookupIntVal(Type* inType, SubtypeWitness* witness, Decl* key) + { + setOperands(inType, witness, key); + } static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key); @@ -113,23 +240,31 @@ class WitnessLookupIntVal : public IntVal // polynomial expression "2*a*b^3 + 1" will be represented as: // { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] } -class PolynomialIntValFactor : public NodeBase +class PolynomialIntValFactor : public Val { SLANG_AST_CLASS(PolynomialIntValFactor) public: - IntVal* param; - IntegerLiteralValue power; + IntVal* getParam() const { return as<IntVal>(getOperand(0)); } + IntegerLiteralValue getPower() const { return getIntConstOperand(1); } + + PolynomialIntValFactor(IntVal* inParam, IntegerLiteralValue inPower) + { + setOperands(inParam, inPower); + } + + Val* _resolveImplOverride(); + // for sorting only. bool operator<(const PolynomialIntValFactor& other) const { - if (auto thisGenParam = as<GenericParamIntVal>(param)) + if (auto thisGenParam = as<GenericParamIntVal>(getParam())) { - if (auto thatGenParam = as<GenericParamIntVal>(other.param)) + if (auto thatGenParam = as<GenericParamIntVal>(other.getParam())) { - if (thisGenParam->equalsVal(thatGenParam)) - return power < other.power; + if (thisGenParam->equals(thatGenParam)) + return getPower() < other.getPower(); else - return thisGenParam->declRef.getDecl() < thatGenParam->declRef.getDecl(); + return thisGenParam->getDeclRef().getDecl() < thatGenParam->getDeclRef().getDecl(); } else { @@ -138,64 +273,84 @@ public: } else { - if (const auto thatGenParam = as<GenericParamIntVal>(other.param)) + if (const auto thatGenParam = as<GenericParamIntVal>(other.getParam())) { return false; } - return param == other.param ? power < other.power : param < other.param; + return getParam() == other.getParam() ? getPower() < other.getPower() : getParam() < other.getParam(); } } // for sorting only. bool operator==(const PolynomialIntValFactor& other) const { - if (auto thisGenParam = as<GenericParamIntVal>(param)) + if (auto thisGenParam = as<GenericParamIntVal>(getParam())) { - if (auto thatGenParam = as<GenericParamIntVal>(other.param)) + if (auto thatGenParam = as<GenericParamIntVal>(other.getParam())) { - if (thisGenParam->equalsVal(thatGenParam) && power == other.power) + if (thisGenParam->equals(thatGenParam) && getPower() == other.getPower()) return true; } return false; } - return power == other.power && param == other.param; + return getPower() == other.getPower() && getParam() == other.getParam(); } bool equals(const PolynomialIntValFactor& other) const { - return power == other.power && param->equalsVal(other.param); + return getPower() == other.getPower() && getParam()->equals(other.getParam()); } }; -class PolynomialIntValTerm : public NodeBase +class PolynomialIntValTerm : public Val { SLANG_AST_CLASS(PolynomialIntValTerm) public: - IntegerLiteralValue constFactor; - List<PolynomialIntValFactor*> paramFactors; + IntegerLiteralValue getConstFactor() const { return getIntConstOperand(0); } + OperandView<PolynomialIntValFactor> getParamFactors() const { return OperandView<PolynomialIntValFactor>(this, 1, getOperandCount() - 1); } + + Val* _resolveImplOverride(); + + PolynomialIntValTerm(IntegerLiteralValue inConstFactor, ArrayView<PolynomialIntValFactor*> inParamFactors) + { + setOperands(inConstFactor); + addOperands(inParamFactors); + } + + PolynomialIntValTerm(IntegerLiteralValue inConstFactor, OperandView<PolynomialIntValFactor> inParamFactors) + { + setOperands(inConstFactor); + addOperands(inParamFactors); + } + bool canCombineWith(const PolynomialIntValTerm& other) const { - if (paramFactors.getCount() != other.paramFactors.getCount()) + auto paramFactors = getParamFactors(); + if (paramFactors.getCount() != other.getParamFactors().getCount()) return false; - for (Index i = 0; i < paramFactors.getCount(); i++) + for (Index i = 0; i < getParamFactors().getCount(); i++) { - if (!paramFactors[i]->equals(*other.paramFactors[i])) + if (!paramFactors[i]->equals(*other.getParamFactors()[i])) return false; } return true; } bool operator<(const PolynomialIntValTerm& other) const { - if (constFactor < other.constFactor) + auto constFactor = getConstFactor(); + auto paramFactors = getParamFactors(); + + if (constFactor < other.getConstFactor()) return true; - else if (constFactor == other.constFactor) + else if (constFactor == other.getConstFactor()) { + auto otherParamFactors = other.getParamFactors(); for (Index i = 0; i < paramFactors.getCount(); i++) { - if (i >= other.paramFactors.getCount()) + if (i >= otherParamFactors.getCount()) return false; - if (*(paramFactors[i]) < *(other.paramFactors[i])) + if (*(paramFactors[i]) < *(otherParamFactors[i])) return true; - if (*(paramFactors[i]) == *(other.paramFactors[i])) + if (*(paramFactors[i]) == *(otherParamFactors[i])) { } else @@ -213,27 +368,25 @@ class PolynomialIntVal : public IntVal SLANG_AST_CLASS(PolynomialIntVal) public: - List<PolynomialIntValTerm*> terms; - IntegerLiteralValue constantTerm = 0; + IntegerLiteralValue getConstantTerm() { return getIntConstOperand(1); }; + OperandView<PolynomialIntValTerm> getTerms() { return OperandView<PolynomialIntValTerm>(this, 2, getOperandCount() - 2); }; - bool isConstant() { return terms.getCount() == 0; } - // Canonicalize the polynomial. If the polynomial can be simplified to a constant or a genericparam, - // the method returns the value simplified to. - // Otherwise, in-place modifications are performed and returns this. - IntVal* canonicalize(ASTBuilder* builder); + bool isConstant() { return getOperandCount() == 1; } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); - static PolynomialIntVal* neg(ASTBuilder* astBuilder, IntVal* base); - static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); - static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); - static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); - PolynomialIntVal(Type* inType) : IntVal(inType) {} - + static IntVal* neg(ASTBuilder* astBuilder, IntVal* base); + static IntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + static IntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + static IntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + PolynomialIntVal(Type* inType, IntegerLiteralValue inConstantTerm, ArrayView<PolynomialIntValTerm*> inTerms) + { + setOperands(inType, inConstantTerm); + addOperands(inTerms); + } }; /// An unknown integer value indicating an erroneous sub-expression @@ -241,17 +394,16 @@ class ErrorIntVal : public IntVal { SLANG_AST_CLASS(ErrorIntVal) - ErrorIntVal(Type* inType) : IntVal(inType) {} + ErrorIntVal(Type* inType) { setOperands(inType); } // TODO: We should probably eventually just have an `ErrorVal` here // and have all `Val`s that represent ordinary values hold their // `Type` so that we can have an `ErrorVal` of any type. // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride() { return this; } }; // A witness to the fact that some proposition is true, encoded @@ -301,25 +453,23 @@ class SubtypeWitness : public Witness { SLANG_ABSTRACT_AST_CLASS(SubtypeWitness) - Type* sub = nullptr; - Type* sup = nullptr; + Val* _resolveImplOverride(); + + Type* getSub() { return as<Type>(getOperand(0)); } + Type* getSup() { return as<Type>(getOperand(1)); } }; class TypeEqualityWitness : public SubtypeWitness { SLANG_AST_CLASS(TypeEqualityWitness) - TypeEqualityWitness( - Type* type) + TypeEqualityWitness(Type* subType, Type* supType) { - this->sub = type; - this->sup = type; + setOperands(subType, supType); } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; @@ -329,19 +479,19 @@ class DeclaredSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(DeclaredSubtypeWitness) - DeclRef<Decl> declRef; + DeclRef<Decl> getDeclRef() + { + return as<DeclRefBase>(getOperand(2)); + } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); DeclaredSubtypeWitness(Type* inSub, Type* inSup, DeclRef<Decl> inDeclRef) - : declRef(inDeclRef) { - sub = inSub; - sup = inSup; + setOperands(inSub, inSup, inDeclRef); } }; @@ -351,20 +501,25 @@ class TransitiveSubtypeWitness : public SubtypeWitness SLANG_AST_CLASS(TransitiveSubtypeWitness) // Witness that `sub : mid` - SubtypeWitness* subToMid = nullptr; + SubtypeWitness* getSubToMid() + { + return as<SubtypeWitness>(getOperand(2)); + } // Witness that `mid : sup` - SubtypeWitness* midToSup = nullptr; + SubtypeWitness* getMidToSup() + { + return as<SubtypeWitness>(getOperand(3)); + } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - TransitiveSubtypeWitness(SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup) - : subToMid(inSubToMid), midToSup(inMidToSup) - {} + TransitiveSubtypeWitness(Type* subType, Type* supType, SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup) + { + setOperands(subType, supType, inSubToMid, inMidToSup); + } }; // A witness that `sub : sup` because `sub` was wrapped into @@ -374,52 +529,27 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness SLANG_AST_CLASS(ExtractExistentialSubtypeWitness) // The declaration of the existential value that has been opened - DeclRef<VarDeclBase> declRef; - - // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); - void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); - Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); -}; - -// A witness that `sub : sup`, because `sub` is a tagged union -// of the form `A | B | C | ...` and each of `A : sup`, -// `B : sup`, `C : sup`, etc. -// -class TaggedUnionSubtypeWitness : public SubtypeWitness -{ - SLANG_AST_CLASS(TaggedUnionSubtypeWitness) + DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(2)); } - // Witnesses that each of the "case" types in the union - // is a subtype of `sup`. - // - List<SubtypeWitness*> caseWitnesses; + ExtractExistentialSubtypeWitness(Type* inSub, Type* inSup, DeclRef<Decl> inDeclRef) + { + setOperands(inSub, inSup, inDeclRef); + } // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; - /// A witness of the fact that `ThisType(someInterface) : someInterface` -class ThisTypeSubtypeWitness : public SubtypeWitness -{ - SLANG_AST_CLASS(ThisTypeSubtypeWitness) - - ThisTypeSubtypeWitness(Type* subType, Type* supType) - { - sub = subType; - sup = supType; - } -}; - /// A witness of the fact that a user provided "__Dynamic" type argument is a /// subtype to the existential type parameter. class DynamicSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(DynamicSubtypeWitness) + DynamicSubtypeWitness(Type* inSub, Type* inSup) + { + setOperands(inSub, inSup); + } }; /// A witness that `T : L & R` because `T : L` and `T : R` @@ -431,23 +561,24 @@ class ConjunctionSubtypeWitness : public SubtypeWitness // an operation that takes two witness tables `leftWitness` // and `rightWitness`, and forms a pair/tuple of // `(leftWitness, rightWitness)`. + static const int kComponentCount = 2; - static const Count kComponentCount = 2; - SubtypeWitness* componentWitnesses[kComponentCount]; + ConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* left, SubtypeWitness* right) + { + setOperands(inSub, inSup, left, right); + } - SubtypeWitness* getLeftWitness() const { return componentWitnesses[0]; } - SubtypeWitness* getRightWitness() const { return componentWitnesses[1]; } + SubtypeWitness* getLeftWitness() const { return as<SubtypeWitness>(getOperand(2)); } + SubtypeWitness* getRightWitness() const { return as<SubtypeWitness>(getOperand(3)); } - Count getComponentCount() const { return kComponentCount; } + Count getComponentCount() const { return 2; } SubtypeWitness* getComponentWitness(Index index) const { SLANG_ASSERT(index >= 0 && index < kComponentCount); - return componentWitnesses[index]; + return as<SubtypeWitness>(getOperand(2 + index)); } - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; @@ -461,19 +592,22 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness // `(leftWtiness, rightWitness)` and extracts one of the // elements of it. - /// Witness that `T < L & R` - SubtypeWitness* conjunctionWitness; + /// Witness that `T < L & R` + SubtypeWitness* getConjunctionWitness() { return as<SubtypeWitness>(getOperand(2)); }; + + ExtractFromConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* witness, int index) + { + setOperands(inSub, inSup, witness, index); + } /// The zero-based index of the super-type we care about in the conjunction /// /// If `conjunctionWitness` is `T < L & R` then this index should be zero if /// we want to represent `T < L` and one if we want `T < R`. /// - int indexInConjunction; + int getIndexInConjunction() { return (int)getIntConstOperand(3); }; - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; @@ -482,8 +616,7 @@ class ModifierVal : public Val { SLANG_AST_CLASS(ModifierVal) - bool _equalsValOverride(Val* val); - HashCode _getHashCodeOverride(); + Val* _resolveImplOverride() { return this; } }; class TypeModifierVal : public ModifierVal @@ -525,37 +658,91 @@ class DifferentiateVal : public Val { SLANG_AST_CLASS(DifferentiateVal) - DeclRef<Decl> func; + DifferentiateVal(DeclRef<Decl> inFunc) + { + setOperands(inFunc); + } + + DeclRef<Decl> getFunc() { return as<DeclRefBase>(getOperand(0)); } - bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); }; class ForwardDifferentiateVal : public DifferentiateVal { SLANG_AST_CLASS(ForwardDifferentiateVal) + ForwardDifferentiateVal(DeclRef<Decl> inFunc) + : DifferentiateVal(inFunc) + {} }; class BackwardDifferentiateVal : public DifferentiateVal { SLANG_AST_CLASS(BackwardDifferentiateVal) + + BackwardDifferentiateVal(DeclRef<Decl> inFunc) + : DifferentiateVal(inFunc) + {} }; class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal { SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal) + + BackwardDifferentiateIntermediateTypeVal(DeclRef<Decl> inFunc) + : DifferentiateVal(inFunc) + {} }; class BackwardDifferentiatePrimalVal : public DifferentiateVal { SLANG_AST_CLASS(BackwardDifferentiatePrimalVal) + + BackwardDifferentiatePrimalVal(DeclRef<Decl> inFunc) + : DifferentiateVal(inFunc) + {} }; class BackwardDifferentiatePropagateVal : public DifferentiateVal { SLANG_AST_CLASS(BackwardDifferentiatePropagateVal) + + BackwardDifferentiatePropagateVal(DeclRef<Decl> inFunc) + : DifferentiateVal(inFunc) + {} }; + +template<typename F> +void SubstitutionSet::forEachGenericSubstitution(F func) const +{ + if (!declRef) + return; + for (auto subst = declRef; subst; subst = subst->getBase()) + { + if (auto genSubst = as<GenericAppDeclRef>(subst)) + func(genSubst->getGenericDecl(), genSubst->getArgs()); + } +} + +template<typename F> +void SubstitutionSet::forEachSubstitutionArg(F func) const +{ + if (!declRef) + return; + for (auto subst = declRef; subst; subst = subst->getBase()) + { + if (auto genSubst = as<GenericAppDeclRef>(subst)) + { + for (auto arg : genSubst->getArgs()) + func(arg); + } + else if (auto thisSubst = as<LookupDeclRef>(subst)) + { + func(thisSubst->getWitness()->getSub()); + } + } +} } // namespace Slang diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 1d19e01bf..4376b1135 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -54,7 +54,7 @@ namespace Slang Type* superType) { SubtypeWitness* result = nullptr; - if (getShared()->tryGetSubtypeWitness(subType, superType, result)) + if (getShared()->tryGetSubtypeWitnessFromCache(subType, superType, result)) return result; result = checkAndConstructSubtypeWitness(subType, superType); getShared()->cacheSubtypeWitness(subType, superType, result); @@ -107,11 +107,11 @@ namespace Slang // First, make sure both sub type and super type decl are ready for lookup. if (auto subDeclRefType = as<DeclRefType>(subType)) { - ensureDecl(subDeclRefType->declRef.getDecl(), DeclCheckState::ReadyForLookup); + ensureDecl(subDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); } if (auto superDeclRefType = as<DeclRefType>(subType)) { - ensureDecl(superDeclRefType->declRef.getDecl(), DeclCheckState::ReadyForLookup); + ensureDecl(superDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); } // In the common case, we can use the pre-computed inheritance information for `subType` @@ -173,13 +173,13 @@ namespace Slang DeclRef<Decl> superTypeDeclRef; if (auto superDeclRefType = as<DeclRefType>(superType)) { - superTypeDeclRef = superDeclRefType->declRef; + superTypeDeclRef = superDeclRefType->getDeclRef(); } - if (auto dynamicType = as<DynamicType>(subType)) + if (as<DynamicType>(subType)) { // A __Dynamic type always conforms to the interface via its witness table. - auto witness = m_astBuilder->create<DynamicSubtypeWitness>(); + auto witness = m_astBuilder->getOrCreate<DynamicSubtypeWitness>(subType, superType); return witness; } else if (auto conjunctionSuperType = as<AndType>(superType)) @@ -189,10 +189,10 @@ namespace Slang // We therefore simply recursively test both `T <: L` // and `T <: R`. // - auto leftWitness = isSubtype(subType, conjunctionSuperType->left); + auto leftWitness = isSubtype(subType, conjunctionSuperType->getLeft()); if (!leftWitness) return nullptr; // - auto rightWitness = isSubtype(subType, conjunctionSuperType->right); + auto rightWitness = isSubtype(subType, conjunctionSuperType->getRight()); if (!rightWitness) return nullptr; // If both of the sub-relationships hold, we can construct @@ -214,7 +214,7 @@ namespace Slang // TODO(tfoley): We could add support for `ExtractExistentialType` to // the inheritance linearization logic, and eliminate this case. // - auto interfaceDeclRef = extractExistentialType->originalInterfaceDeclRef; + auto interfaceDeclRef = extractExistentialType->getOriginalInterfaceDeclRef(); if (interfaceDeclRef.equals(superTypeDeclRef)) { auto witness = extractExistentialType->getSubtypeWitness(); @@ -222,62 +222,6 @@ namespace Slang } return nullptr; } - // - // TODO(tfoley): We should probably just remove `TaggedUnionType`, - // since there is no useful code that relies on it any more. - // - else if(auto taggedUnionType = as<TaggedUnionType>(subType)) - { - // A tagged union type conforms to an interface if all of - // the constituent types in the tagged union conform. - // - // We will iterate over the "case" types in the tagged - // union, and check if they conform to the interface. - // Along the way we will collect the conformance witness - // values for the case types. - // - List<SubtypeWitness*> caseWitnesses; - for(auto caseType : taggedUnionType->caseTypes) - { - auto caseWitness = isSubtype(caseType, superType); - - if(!caseWitness) - { - return nullptr; - } - - caseWitnesses.add(caseWitness); - } - - // We also need to validate the requirements on - // the interface to make sure that they are suitable for - // use with a tagged-union type. - // - // For example, if the interface includes a `static` method - // (which can therefore be called without a particular instance), - // then we wouldn't know what implementation of that method - // to use because there is no tag value to dispatch on. - // - // We will start out being conservative about what we accept - // here, just to keep things simple. - // - if( auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>() ) - { - if(!isInterfaceSafeForTaggedUnion(superInterfaceDeclRef)) - return nullptr; - } - - // If we reach this point then we have a concrete - // witness for each of the case types, and that is - // enough to build a witness for the tagged union. - // - TaggedUnionSubtypeWitness* taggedUnionWitness = m_astBuilder->create<TaggedUnionSubtypeWitness>(); - taggedUnionWitness->sub = taggedUnionType; - taggedUnionWitness->sup = superType; - taggedUnionWitness->caseWitnesses.swapWith(caseWitnesses); - - return taggedUnionWitness; - } // default is failure return nullptr; @@ -287,7 +231,7 @@ namespace Slang { if (auto declRefType = as<DeclRefType>(type)) { - if (auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) + if (auto interfaceDeclRef = declRefType->getDeclRef().as<InterfaceDecl>()) return true; } return false; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 22a92bf0a..b9d33a1c1 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -65,14 +65,14 @@ namespace Slang // That is, the join of a vector and a scalar type is // a vector type with a joined element type. auto joinElementType = TryJoinTypes( - vectorType->elementType, + vectorType->getElementType(), scalarType); if(!joinElementType) return nullptr; return createVectorType( joinElementType, - vectorType->elementCount); + vectorType->getElementCount()); } Type* SemanticsVisitor::_tryJoinTypeWithInterface( @@ -110,11 +110,11 @@ namespace Slang for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++) { // Don't consider `type`, since we already know it doesn't work. - if(baseTypeFlavorIndex == Int(basicType->baseType)) + if(baseTypeFlavorIndex == Int(basicType->getBaseType())) continue; // Look up the type in our session. - auto candidateType = type->getASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); + auto candidateType = getCurrentASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); if(!candidateType) continue; @@ -186,8 +186,8 @@ namespace Slang { if (auto rightBasic = as<BasicExpressionType>(right)) { - auto leftFlavor = leftBasic->baseType; - auto rightFlavor = rightBasic->baseType; + auto leftFlavor = leftBasic->getBaseType(); + auto rightFlavor = rightBasic->getBaseType(); // TODO(tfoley): Need a special-case rule here that if // either operand is of type `half`, then we promote @@ -217,19 +217,19 @@ namespace Slang if(auto rightVector = as<VectorExpressionType>(right)) { // Check if the vector sizes match - if(!leftVector->elementCount->equalsVal(rightVector->elementCount)) + if(!leftVector->getElementCount()->equals(rightVector->getElementCount())) return nullptr; // Try to join the element types auto joinElementType = TryJoinTypes( - leftVector->elementType, - rightVector->elementType); + leftVector->getElementType(), + rightVector->getElementType()); if(!joinElementType) return nullptr; return createVectorType( joinElementType, - leftVector->elementCount); + leftVector->getElementCount()); } // We can also join a vector and a scalar @@ -242,7 +242,7 @@ namespace Slang // HACK: trying to work trait types in here... if(auto leftDeclRefType = as<DeclRefType>(left)) { - if( auto leftInterfaceRef = leftDeclRefType->declRef.as<InterfaceDecl>() ) + if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as<InterfaceDecl>() ) { // return _tryJoinTypeWithInterface(right, left); @@ -250,7 +250,7 @@ namespace Slang } if(auto rightDeclRefType = as<DeclRefType>(right)) { - if( auto rightInterfaceRef = rightDeclRefType->declRef.as<InterfaceDecl>() ) + if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as<InterfaceDecl>() ) { // return _tryJoinTypeWithInterface(left, right); @@ -263,10 +263,10 @@ namespace Slang return nullptr; } - SubstitutionSet SemanticsVisitor::trySolveConstraintSystem( + DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( ConstraintSystem* system, DeclRef<GenericDecl> genericDeclRef, - GenericSubstitution* substWithKnownGenericArgs) + ArrayView<Val*> knownGenericArgs) { // For now the "solver" is going to be ridiculously simplistic. @@ -288,9 +288,8 @@ namespace Slang for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef) ) { if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) - return SubstitutionSet(); + return DeclRef<Decl>(); } - SubstitutionSet resultSubst = genericDeclRef.getSubst(); // Once have built up the full list of constraints we are trying to satisfy, // we will attempt to solve for each parameter in a way that satisfies all @@ -310,10 +309,10 @@ namespace Slang // or not they are compatible with the constraints). // Count knownGenericArgCount = 0; - if (substWithKnownGenericArgs) + if (knownGenericArgs.getCount()) { - knownGenericArgCount = substWithKnownGenericArgs->getArgs().getCount(); - for (auto arg : substWithKnownGenericArgs->getArgs()) + knownGenericArgCount = knownGenericArgs.getCount(); + for (auto arg : knownGenericArgs) { args.add(arg); } @@ -364,7 +363,7 @@ namespace Slang if (!joinType) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } type = joinType; } @@ -375,7 +374,7 @@ namespace Slang if (!type) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } args.add(type); } @@ -417,10 +416,10 @@ namespace Slang } else { - if(!val->equalsVal(cVal)) + if(!val->equals(cVal)) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } } @@ -430,7 +429,7 @@ namespace Slang if (!val) { // failure! - return SubstitutionSet(); + return DeclRef<Decl>(); } args.add(val); } @@ -456,14 +455,10 @@ namespace Slang // search for a conformance `Robin : ISidekick`, which involved // apply the substitutions we already know... - GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), genericDeclRef.getDecl(), args.getArrayView()); - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef( - constraintDecl, - solvedSubst); + DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef( + genericDeclRef, args.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>(); // Extract the (substituted) sub- and super-type from the constraint. auto sub = getSub(m_astBuilder, constraintDeclRef); @@ -476,7 +471,7 @@ namespace Slang // not provide an explicit type parameter to specialize a generic // and the type parameter cannot be inferred from any arguments. // In this case, we should fail the constraint check. - return SubstitutionSet(); + return DeclRef<Decl>(); } // Search for a witness that shows the constraint is satisfied. @@ -492,7 +487,7 @@ namespace Slang // // TODO: Ideally we should print an error message in // this case, to let the user know why things failed. - return SubstitutionSet(); + return DeclRef<Decl>(); } // TODO: We may need to mark some constrains in our constraint @@ -505,13 +500,11 @@ namespace Slang { if (!c.satisfied) { - return SubstitutionSet(); + return DeclRef<Decl>(); } } - resultSubst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), genericDeclRef.getDecl(), args); - return resultSubst; + return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView()); } bool SemanticsVisitor::TryUnifyVals( @@ -533,7 +526,7 @@ namespace Slang { if (auto sndIntVal = as<ConstantIntVal>(snd)) { - return fstIntVal->value == sndIntVal->value; + return fstIntVal->getValue() == sndIntVal->getValue(); } } @@ -541,23 +534,23 @@ namespace Slang if (auto fstInt = as<IntVal>(fst)) { if (auto tc = as<TypeCastIntVal>(fstInt)) - fstInt = as<IntVal>(tc->base); + fstInt = as<IntVal>(tc->getBase()); if (auto sndInt = as<IntVal>(snd)) { if (auto tc = as<TypeCastIntVal>(sndInt)) - sndInt = as<IntVal>(tc->base); + sndInt = as<IntVal>(tc->getBase()); auto fstParam = as<GenericParamIntVal>(fstInt); auto sndParam = as<GenericParamIntVal>(sndInt); bool okay = false; if (fstParam) { - if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt)) + if(TryUnifyIntParam(constraints, fstParam->getDeclRef(), sndInt)) okay = true; } if (sndParam) { - if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt)) + if(TryUnifyIntParam(constraints, sndParam->getDeclRef(), fstInt)) okay = true; } return okay; @@ -568,8 +561,8 @@ namespace Slang { if (auto sndWit = as<DeclaredSubtypeWitness>(snd)) { - auto constraintDecl1 = fstWit->declRef.as<TypeConstraintDecl>(); - auto constraintDecl2 = sndWit->declRef.as<TypeConstraintDecl>(); + auto constraintDecl1 = fstWit->getDeclRef().as<TypeConstraintDecl>(); + auto constraintDecl2 = sndWit->getDeclRef().as<TypeConstraintDecl>(); SLANG_ASSERT(constraintDecl1); SLANG_ASSERT(constraintDecl2); return TryUnifyTypes(constraints, @@ -586,8 +579,8 @@ namespace Slang if (auto sndWit = as<SubtypeWitness>(snd)) { return TryUnifyTypes(constraints, - fstWit->sup, - sndWit->sup); + fstWit->getSup(), + sndWit->getSup()); } } @@ -597,35 +590,28 @@ namespace Slang //return false; } - bool SemanticsVisitor::tryUnifySubstitutions( - ConstraintSystem& constraints, - Substitutions* fst, - Substitutions* snd) + bool SemanticsVisitor::tryUnifyDeclRef( + ConstraintSystem& constraints, + DeclRefBase* fst, + DeclRefBase* snd) { - // They must both be NULL or non-NULL - if (!fst || !snd) - return !fst && !snd; - - if(auto fstGeneric = as<GenericSubstitution>(fst)) - { - if(auto sndGeneric = as<GenericSubstitution>(snd)) - { - return tryUnifyGenericSubstitutions( - constraints, - fstGeneric, - sndGeneric); - } - } - - // TODO: need to handle other cases here - - return false; + if (fst == snd) + return true; + if (fst == nullptr || snd == nullptr) + return false; + auto fstGen = SubstitutionSet(fst).findGenericAppDeclRef(); + auto sndGen = SubstitutionSet(snd).findGenericAppDeclRef(); + if (fstGen == sndGen) + return true; + if (fstGen == nullptr || sndGen == nullptr) + return false; + return tryUnifyGenericAppDeclRef(constraints, fstGen, sndGen); } - bool SemanticsVisitor::tryUnifyGenericSubstitutions( + bool SemanticsVisitor::tryUnifyGenericAppDeclRef( ConstraintSystem& constraints, - GenericSubstitution* fst, - GenericSubstitution* snd) + GenericAppDeclRef* fst, + GenericAppDeclRef* snd) { SLANG_ASSERT(fst); SLANG_ASSERT(snd); @@ -649,7 +635,10 @@ namespace Slang } // Their "base" specializations must unify - if (!tryUnifySubstitutions(constraints, fstGen->getOuter(), sndGen->getOuter())) + auto fstBase = fst->getBase(); + auto sndBase = snd->getBase(); + + if (!tryUnifyDeclRef(constraints, fstBase, sndBase)) { okay = false; } @@ -718,14 +707,14 @@ namespace Slang { if (auto fstDeclRefType = as<DeclRefType>(fst)) { - auto fstDeclRef = fstDeclRefType->declRef; + auto fstDeclRef = fstDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, snd); if (auto sndDeclRefType = as<DeclRefType>(snd)) { - auto sndDeclRef = sndDeclRefType->declRef; + auto sndDeclRef = sndDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, fst); @@ -735,10 +724,10 @@ namespace Slang // next we need to unify the substitutions applied // to each declaration reference. - if (!tryUnifySubstitutions( + if (!tryUnifyDeclRef( constraints, - fstDeclRef.getSubst(), - sndDeclRef.getSubst())) + fstDeclRef, + sndDeclRef)) { return false; } @@ -749,15 +738,15 @@ namespace Slang { if (auto sndFunType = as<FuncType>(snd)) { - const Index numParams = fstFunType->paramTypes.getCount(); - if(numParams != sndFunType->paramTypes.getCount()) + const Index numParams = fstFunType->getParamCount(); + if(numParams != sndFunType->getParamCount()) return false; for(Index i = 0; i < numParams; ++i) { - if(!TryUnifyTypes(constraints, fstFunType->paramTypes[i], sndFunType->paramTypes[i])) + if(!TryUnifyTypes(constraints, fstFunType->getParamType(i), sndFunType->getParamType(i))) return false; } - return TryUnifyTypes(constraints, fstFunType->resultType, sndFunType->resultType); + return TryUnifyTypes(constraints, fstFunType->getResultType(), sndFunType->getResultType()); } } @@ -779,13 +768,13 @@ namespace Slang // if (auto fstAndType = as<AndType>(fst)) { - return TryUnifyTypes(constraints, fstAndType->left, snd) - && TryUnifyTypes(constraints, fstAndType->right, snd); + return TryUnifyTypes(constraints, fstAndType->getLeft(), snd) + && TryUnifyTypes(constraints, fstAndType->getRight(), snd); } else if (auto sndAndType = as<AndType>(snd)) { - return TryUnifyTypes(constraints, fst, sndAndType->left) - || TryUnifyTypes(constraints, fst, sndAndType->right); + return TryUnifyTypes(constraints, fst, sndAndType->getLeft()) + || TryUnifyTypes(constraints, fst, sndAndType->getRight()); } else return false; @@ -828,7 +817,7 @@ namespace Slang if (auto fstDeclRefType = as<DeclRefType>(fst)) { - auto fstDeclRef = fstDeclRefType->declRef; + auto fstDeclRef = fstDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl())) { @@ -839,7 +828,7 @@ namespace Slang if (auto sndDeclRefType = as<DeclRefType>(snd)) { - auto sndDeclRef = sndDeclRefType->declRef; + auto sndDeclRef = sndDeclRefType->getDeclRef(); if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl())) { @@ -863,7 +852,7 @@ namespace Slang { return TryUnifyTypes( constraints, - fstVectorType->elementType, + fstVectorType->getElementType(), sndScalarType); } } @@ -875,7 +864,7 @@ namespace Slang return TryUnifyTypes( constraints, fstScalarType, - sndVectorType->elementType); + sndVectorType->getElementType()); } } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index abe9f4817..d89808c3d 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -61,7 +61,7 @@ namespace Slang if(auto declRefType = as<DeclRefType>(type)) { - if(as<StructDecl>(declRefType->declRef)) + if(as<StructDecl>(declRefType->getDeclRef())) return false; } @@ -174,7 +174,7 @@ namespace Slang if(!baseDeclRefType) return nullptr; - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); auto baseStructDeclRef = baseDeclRef.as<StructDecl>(); if(!baseStructDeclRef) return nullptr; @@ -193,7 +193,7 @@ namespace Slang if (!baseDeclRefType) return DeclRef<StructDecl>(); - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); auto baseStructDeclRef = baseDeclRef.as<StructDecl>(); if (!baseStructDeclRef) return DeclRef<StructDecl>(); @@ -244,13 +244,13 @@ namespace Slang } else if (auto toVecType = as<VectorExpressionType>(toType)) { - auto toElementCount = toVecType->elementCount; - auto toElementType = toVecType->elementType; + auto toElementCount = toVecType->getElementCount(); + auto toElementType = toVecType->getElementType(); UInt elementCount = 0; if (auto constElementCount = as<ConstantIntVal>(toElementCount)) { - elementCount = (UInt) constElementCount->value; + elementCount = (UInt) constElementCount->getValue(); } else { @@ -299,7 +299,7 @@ namespace Slang UInt elementCount = 0; if (auto constElementCount = as<ConstantIntVal>(toElementCount)) { - elementCount = (UInt) constElementCount->value; + elementCount = (UInt) constElementCount->getValue(); } else { @@ -388,7 +388,7 @@ namespace Slang if (auto constRowCount = as<ConstantIntVal>(toMatrixType->getRowCount())) { - rowCount = (UInt) constRowCount->value; + rowCount = (UInt) constRowCount->getValue(); } else { @@ -423,7 +423,7 @@ namespace Slang } else if(auto toDeclRefType = as<DeclRefType>(toType)) { - auto toTypeDeclRef = toDeclRefType->declRef; + auto toTypeDeclRef = toDeclRefType->getDeclRef(); if(auto toStructDeclRef = toTypeDeclRef.as<StructDecl>()) { // Trying to initialize a `struct` type given an initializer list. @@ -570,7 +570,7 @@ namespace Slang if( left == right ) return true; - if( left->equalsVal(right) ) + if( left->equals(right) ) return true; return false; @@ -581,9 +581,9 @@ namespace Slang { if(!type) return false; - for( auto m : type->modifiers ) + for (Index m = 0; m < type->getModifierCount(); m++) { - if(_doModifiersMatch(m, modifier)) + if(_doModifiersMatch(type->getModifier(m), modifier)) return true; } @@ -632,7 +632,7 @@ namespace Slang { auto basicType = as<BasicExpressionType>(t); if (!basicType) return false; - switch (basicType->baseType) + switch (basicType->getBaseType()) { case BaseType::Int8: case BaseType::Int16: @@ -650,7 +650,7 @@ namespace Slang auto basicType = as<BasicExpressionType>(t); if (!basicType) return 0; - switch (basicType->baseType) + switch (basicType->getBaseType()) { case BaseType::Int8: case BaseType::UInt8: @@ -770,10 +770,10 @@ namespace Slang // on it, but the underlying types are otherwise the same. // auto toModified = as<ModifiedType>(toType); - auto toBase = toModified ? toModified->base : toType; + auto toBase = toModified ? toModified->getBase() : toType; // auto fromModified = as<ModifiedType>(fromType); - auto fromBase = fromModified ? fromModified->base : fromType; + auto fromBase = fromModified ? fromModified->getBase() : fromType; if((toModified || fromModified) && toBase->equals(fromBase)) @@ -787,8 +787,9 @@ namespace Slang // if( toModified ) { - for( auto modifier : toModified->modifiers ) + for (Index m = 0; m < toModified->getModifierCount(); m++) { + auto modifier = toModified->getModifier(m); if(_hasMatchingModifier(fromModified, modifier)) continue; @@ -804,8 +805,10 @@ namespace Slang } if( fromModified ) { - for( auto modifier : fromModified->modifiers ) + for (Index m = 0; m < fromModified->getModifierCount(); m++) { + auto modifier = fromModified->getModifier(m); + if(_hasMatchingModifier(toModified, modifier)) continue; @@ -923,7 +926,7 @@ namespace Slang // // TODO(tfoley): Under what circumstances would this check ever be needed? // - if (auto toParameterGroupType = as<ParameterGroupType>(toType)) + if (as<ParameterGroupType>(toType)) { return _failedCoercion(toType, outToExpr, fromExpr); } @@ -1141,7 +1144,7 @@ namespace Slang { if (auto val = as<ConstantIntVal>(intVal)) { - if (isIntValueInRangeOfType(val->value, toType)) + if (isIntValueInRangeOfType(val->getValue(), toType)) { // OK. shouldEmitGeneralWarning = false; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index b1dd2d533..b6a5d94ef 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -13,6 +13,8 @@ #include "slang-lookup.h" #include "slang-syntax.h" #include "slang-ast-synthesis.h" +#include "slang-ast-reflect.h" + #include <limits> namespace Slang @@ -201,12 +203,6 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - Val* resolveVal(Val* val); - Type* resolveType(Type* type) - { - return (Type*)resolveVal(type); - } - void visitTypeExp(TypeExp& exp) { exp.type = resolveType(exp.type); @@ -581,7 +577,7 @@ namespace Slang return getTypeForDeclRef(astBuilder, nullptr, nullptr, declRef, &typeResult, loc); } - DeclRef<ExtensionDecl> ApplyExtensionToType( + DeclRef<ExtensionDecl> applyExtensionToType( SemanticsVisitor* semantics, ExtensionDecl* extDecl, Type* type) @@ -589,118 +585,7 @@ namespace Slang if(!semantics) return DeclRef<ExtensionDecl>(); - return semantics->ApplyExtensionToType(extDecl, type); - } - - GenericSubstitution* createDefaultSubstitutionsForGeneric( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - GenericDecl* genericDecl, - Substitutions* outerSubst) - { - GenericSubstitution* cachedResult = nullptr; - if (astBuilder->m_genericDefaultSubst.tryGetValue(genericDecl, cachedResult)) - { - if (cachedResult->getOuter() == outerSubst) - return cachedResult; - } - - List<Val*> args; - - for( auto mm : genericDecl->members ) - { - if( auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm) ) - { - args.add(DeclRefType::create(astBuilder, astBuilder->getSpecializedDeclRef<Decl>(genericTypeParamDecl, outerSubst))); - } - else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) ) - { - if (semantics) - ensureDecl(semantics, genericValueParamDecl, DeclCheckState::ReadyForLookup); - - args.add(astBuilder->getOrCreate<GenericParamIntVal>( - genericValueParamDecl->getType(), - astBuilder->getSpecializedDeclRef(genericValueParamDecl, outerSubst))); - } - } - - bool shouldCache = true; - - // create default substitution arguments for constraints - for (auto mm : genericDecl->members) - { - if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm)) - { - if (semantics) - { - ensureDecl(semantics, genericTypeConstraintDecl, DeclCheckState::ReadyForReference); - } - auto constraintDeclRef = astBuilder->getSpecializedDeclRef<GenericTypeConstraintDecl>(genericTypeConstraintDecl, outerSubst); - auto witness = - astBuilder->getDeclaredSubtypeWitness( - getSub(astBuilder, constraintDeclRef), - getSup(astBuilder, constraintDeclRef), - astBuilder->getSpecializedDeclRef(genericTypeConstraintDecl, outerSubst)); - // TODO: this is an ugly hack to prevent crashing. - // In early stages of compilation witness->sub and witness->sup may not be checked yet. - // When semanticVisitor is present we have used that to ensure the type is checked. - // However due to how the code is written we cannot guarantee semanticVisitor is always available - // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be - // cached. - if (!witness->sub) - shouldCache = false; - args.add(witness); - } - } - - GenericSubstitution* genericSubst = astBuilder->getOrCreateGenericSubstitution(outerSubst, genericDecl, args); - if (shouldCache) - astBuilder->m_genericDefaultSubst[genericDecl] = genericSubst; - return genericSubst; - } - - // Sometimes we need to refer to a declaration the way that it would be specialized - // inside the context where it is declared (e.g., with generic parameters filled in - // using their archetypes). - // - SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - Decl* decl, - SubstitutionSet outerSubstSet) - { - auto dd = decl->parentDecl; - if( auto genericDecl = as<GenericDecl>(dd) ) - { - // We don't want to specialize references to anything - // other than the "inner" declaration itself. - if(decl != genericDecl->inner) - return outerSubstSet; - - GenericSubstitution* genericSubst = createDefaultSubstitutionsForGeneric( - astBuilder, - semantics, - genericDecl, - outerSubstSet.substitutions); - - return SubstitutionSet(genericSubst); - } - - return outerSubstSet; - } - - SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - Decl* decl) - { - SubstitutionSet subst; - if( auto parentDecl = decl->parentDecl ) - { - subst = createDefaultSubstitutions(astBuilder, semantics, parentDecl); - } - subst = createDefaultSubstitutions(astBuilder, semantics, decl, subst); - return subst; + return semantics->applyExtensionToType(extDecl, type); } bool SemanticsVisitor::isDeclUsableAsStaticMember( @@ -1066,7 +951,7 @@ namespace Slang auto baseExprType = memberExpr->baseExpression->type.type; if (auto typeType = as<TypeType>(baseExprType)) { - if (diffThisType->equals(typeType->type)) + if (diffThisType->equals(typeType->getType())) { return; } @@ -1149,7 +1034,6 @@ namespace Slang { // A variable with an explicit type is simpler, for the // most part. - TypeExp typeExp = CheckUsableType(varDecl->type); varDecl->type = typeExp; if (varDecl->type.equals(m_astBuilder->getVoidType())) @@ -1256,7 +1140,7 @@ namespace Slang { if (auto basicType = as<BasicExpressionType>(varDecl->getType())) { - switch (basicType->baseType) + switch (basicType->getBaseType()) { case BaseType::Bool: case BaseType::Int8: @@ -1429,11 +1313,11 @@ namespace Slang { if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base)) { - auto subst = createDefaultSubstitutions(m_astBuilder, this, declRefType->declRef.getDecl()); - auto newType = DeclRefType::create(m_astBuilder, m_astBuilder->getSpecializedDeclRef(declRefType->declRef.getDecl(), subst)); + auto newDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, declRefType->getDeclRef()); + auto newType = DeclRefType::create(m_astBuilder, newDeclRef); sharedTypeExpr->base.type = newType; if (auto typetype = as<TypeType>(typeExp.exp->type)) - typetype->type = newType; + typeExp.exp->type = m_astBuilder->getTypeType(newType); } } } @@ -1477,20 +1361,20 @@ namespace Slang } // If `This` is nested inside a generic, we need to form a complete declref type to the - // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution - // from requirementDeclRef to get the generic substitution for outer generic parameters, and + // newly synthesized aggTypeDecl here. This can be done by obtaining the this type witness + // from requirementDeclRef to get the generic arguments for the outer generic, and // apply it to the newly synthesized decl. SubstitutionSet substSet; - if (auto thisTypeSusbt = findThisTypeSubstitution( - requirementDeclRef.getSubst(), - as<InterfaceDecl>(requirementDeclRef.getParent(m_astBuilder)).getDecl())) + if (auto thisWitness = findThisTypeWitness( + SubstitutionSet(requirementDeclRef), + as<InterfaceDecl>(requirementDeclRef.getParent()).getDecl())) { - if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + if (auto declRefType = as<DeclRefType>(thisWitness->getSub())) { - substSet = declRefType->declRef.getSubst(); + substSet = SubstitutionSet(declRefType->getDeclRef()); } } - auto satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getSpecializedDeclRef(aggTypeDecl, substSet)); + auto satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl)); // Helper function to add a `diffType` field into the synthesized type for the original // `member`. @@ -1513,8 +1397,7 @@ namespace Slang fieldLookupExpr->type.type = diffMemberType; auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->create<TypeType>(); - baseTypeType->type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType); baseTypeExpr->type.type = baseTypeType; fieldLookupExpr->baseExpression = baseTypeExpr; fieldLookupExpr->declRef = makeDeclRef(diffField); @@ -1529,8 +1412,7 @@ namespace Slang fieldLookupExpr->type.type = diffMemberType; auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->create<TypeType>(); - baseTypeType->type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType); baseTypeExpr->type.type = baseTypeType; fieldLookupExpr->baseExpression = baseTypeExpr; fieldLookupExpr->declRef = makeDeclRef(diffField); @@ -1545,7 +1427,7 @@ namespace Slang { if (auto declRefType = as<DeclRefType>(inheritanceDecl->base.type)) { - if (declRefType->declRef == m_astBuilder->getDifferentiableInterfaceDecl()) + if (declRefType->getDeclRef() == m_astBuilder->getDifferentiableInterfaceDecl()) { hasDifferentialConformance = true; break; @@ -1590,7 +1472,7 @@ namespace Slang if (auto baseDeclRefType = as<DeclRefType>(inheritance->base.type)) { // Skip interface super types. - if (baseDeclRefType->declRef.as<InterfaceDecl>()) + if (baseDeclRefType->getDeclRef().as<InterfaceDecl>()) continue; if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType)) { @@ -1618,6 +1500,9 @@ namespace Slang if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) { witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + + // Incrase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. + m_astBuilder->incrementEpoch(); return true; } @@ -1747,11 +1632,11 @@ namespace Slang auto baseType = as<DeclRefType>(inheritanceDecl->witnessTable->baseType); if (!baseType) return; - if (baseType->declRef.getDecl() != m_astBuilder->getDifferentiableInterfaceDecl().getDecl()) + if (baseType->getDeclRef().getDecl() != m_astBuilder->getDifferentiableInterfaceDecl().getDecl()) return; RequirementWitness witnessValue; auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); - if (!inheritanceDecl->witnessTable->requirementDictionary.tryGetValue(requirementDecl, witnessValue)) + if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue)) return; // A type used as differential type must have itself as its own differential type. @@ -1763,7 +1648,7 @@ namespace Slang auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); if (!differentialType->equals(diffDiffType)) { - SourceLoc sourceLoc = differentialType->declRef.getDecl()->loc; + SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc; getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType); getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup()); } @@ -2287,7 +2172,7 @@ namespace Slang auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>( requiredValueParamDeclRef.getDecl()->getType(), satisfyingValueParamDeclRef); - satisfyingVal->declRef = satisfyingValueParamDeclRef; + satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef; requiredSubstArgs.add(satisfyingVal); } @@ -2311,21 +2196,16 @@ namespace Slang } } - GenericSubstitution* requiredSubst = m_astBuilder->getOrCreateGenericSubstitution( - requiredGenericDeclRef.getSubst(), - requiredGenericDeclRef.getDecl(), - requiredSubstArgs); - // Now that we have computed a set of specialization arguments that will // specialize the generic requirement at the type parameters of the satisfying // generic, we can construct a reference to that declaration and re-run some // of the earlier checking logic with more type information usable. // - auto specializedRequiredGenericDeclRef = m_astBuilder->getSpecializedDeclRef<GenericDecl>(requiredGenericDeclRef.getDecl(), requiredSubst); - auto specializedRequiredMemberDeclRefs = getMembers(m_astBuilder, specializedRequiredGenericDeclRef); + auto specializedRequiredGenericInnerDeclRef = m_astBuilder->getGenericAppDeclRef( + requiredGenericDeclRef, requiredSubstArgs.getArrayView()); for (Index i = 0; i < memberCount; i++) { - auto requiredMemberDeclRef = specializedRequiredMemberDeclRefs[i]; + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as<GenericTypeParamDecl>()) @@ -2365,13 +2245,16 @@ namespace Slang // In current code the sub type will always be one of the generic type parameters, // and the super-type will always be an interface, but there should be no // need to make use of those additional details here. - - auto requiredSubType = getSub(m_astBuilder, requiredConstraintDeclRef); + auto specializedRequiredConstraintDeclRef = m_astBuilder->getGenericAppDeclRef( + requiredGenericDeclRef, + requiredSubstArgs.getArrayView(), + requiredConstraintDeclRef.getDecl()).as<GenericTypeConstraintDecl>(); + auto requiredSubType = getSub(m_astBuilder, specializedRequiredConstraintDeclRef); auto satisfyingSubType = getSub(m_astBuilder, satisfyingConstraintDeclRef); if (!satisfyingSubType->equals(requiredSubType)) return false; - auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); + auto requiredSuperType = getSup(m_astBuilder, specializedRequiredConstraintDeclRef); auto satisfyingSuperType = getSup(m_astBuilder, satisfyingConstraintDeclRef); if (!satisfyingSuperType->equals(requiredSuperType)) return false; @@ -2400,8 +2283,8 @@ namespace Slang // declaration (whatever it is) for an exact match. // return doesMemberSatisfyRequirement( - m_astBuilder->getSpecializedDeclRef<Decl>(satisfyingGenericDeclRef.getDecl()->inner, satisfyingGenericDeclRef.getSubst()), - m_astBuilder->getSpecializedDeclRef<Decl>(requiredGenericDeclRef.getDecl()->inner, requiredSubst), + m_astBuilder->getMemberDeclRef(satisfyingGenericDeclRef, getInner(satisfyingGenericDeclRef)), + specializedRequiredGenericInnerDeclRef, witnessTable); } @@ -2444,7 +2327,7 @@ namespace Slang { // If we are seeing a placeholder that awaits synthesis, return false now to trigger // auto synthesis. - if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>()) + if (declRefType->getDeclRef().getDecl()->hasModifier<ToBeSynthesizedModifier>()) return false; } // We need to confirm that the chosen type `satisfyingType`, @@ -2466,7 +2349,7 @@ namespace Slang // type can indeed satisfy the interface requirement. witnessTable->add( requiredAssociatedTypeDeclRef.getDecl(), - RequirementWitness(satisfyingType)); + RequirementWitness(satisfyingType->getCanonicalType())); } return conformance; @@ -2563,7 +2446,7 @@ namespace Slang // check if the specified type satisfies the constraints defined by the associated type if (auto requiredTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { - ensureDecl(typedefDeclRef, DeclCheckState::CanUseAsType); + ensureDecl(typedefDeclRef, DeclCheckState::ReadyForLookup); auto satisfyingType = getNamedType(m_astBuilder, typedefDeclRef); return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); @@ -2648,9 +2531,6 @@ namespace Slang { if (auto constraintDecl = as<GenericTypeConstraintDecl>(member)) { - getASTBuilder()->getSpecializedDeclRef( - constraintDecl, requiredMemberDeclRef.getSubst()); - auto synConstraintDecl = m_astBuilder->create<GenericTypeConstraintDecl>(); synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); synConstraintDecl->parentDecl = synGenericDecl; @@ -2658,7 +2538,7 @@ namespace Slang // For constraints of type T : Interface, where T is a simple type parameter, // find the declaration of T // - if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->declRef.as<GenericTypeParamDecl>().getDecl()) + if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->getDeclRef().as<GenericTypeParamDecl>().getDecl()) { auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl]; @@ -2680,37 +2560,19 @@ namespace Slang } } - // Get outer substitutions. (This inner-most substition - // must be a ThisTypeSubstition) - // - Substitutions* outer = nullptr; - if (auto thisTypeSubst = findThisTypeSubstitution( - requiredMemberDeclRef.getSubst(), - as<InterfaceDecl>(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl())) - { - outer = thisTypeSubst; - } - // Override generic pointer to point to the original generic container. // This will create a substitution of the synthesized parameters for the // original parameters. - // - GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer); - DeclRef<Decl> requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts); - - GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution( - outer, - requiredMemberDeclRef.getDecl(), - createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs()); - - // Substitute parameters of the synthesized generic for the parameters of the original generic. - requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef); + // + auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); + DeclRef<FuncDecl> requiredFuncDeclRef = m_astBuilder->getGenericAppDeclRef( + requiredMemberDeclRef, defaultArgs.getArrayView()).as<FuncDecl>(); - SLANG_ASSERT(requiredFuncDeclRef.as<FuncDecl>()); + SLANG_ASSERT(requiredFuncDeclRef); synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness( context, - requiredFuncDeclRef.as<FuncDecl>(), + requiredFuncDeclRef, synArgs, synThis); synGenericDecl->inner->parentDecl = synGenericDecl; @@ -2860,14 +2722,12 @@ namespace Slang { if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) { - ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; + ForwardDifferentiateVal* val = m_astBuilder->getOrCreate<ForwardDifferentiateVal>(satisfyingMemberDeclRef); witnessTable->add(fwdReq, RequirementWitness(val)); } else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; + DifferentiateVal* val = m_astBuilder->getOrCreate<BackwardDifferentiateVal>(satisfyingMemberDeclRef); witnessTable->add(bwdReq, RequirementWitness(val)); } } @@ -3127,7 +2987,7 @@ namespace Slang // or uses, an associated type or `This`. // // Ideally we should be looking up the type using a `DeclRef` that - // refers to the interface requirement using a `ThisTypeSubstitution` + // refers to the interface requirement using a `LookupDeclRef` // that refers to the satisfying type declaration, and requirement // checking for non-associated-type requirements should be done *after* // requirement checking for associated-type requirements. @@ -3577,7 +3437,7 @@ namespace Slang // First we need to make sure the associated `Differential` type requirement is satisfied. bool hasDifferentialAssocType = false; - for (auto existingEntry : witnessTable->requirementDictionary) + for (auto& existingEntry : witnessTable->getRequirementDictionary()) { if (auto builtinReqAttr = existingEntry.key->findModifier<BuiltinRequirementModifier>()) { @@ -3726,20 +3586,33 @@ namespace Slang // If `This` is nested inside a generic, we need to form a complete declref type to the // newly synthesized method here in order to fill into the witness table. - // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the + // This can be done by obtaining the ThisType witness from requirementDeclRef to get the // generic substitution for outer generic parameters, and apply it here. SubstitutionSet substSet; - if (auto thisTypeSubst = findThisTypeSubstitution( - requirementDeclRef.getSubst(), + if (auto thisTypeWitness = findThisTypeWitness( + SubstitutionSet(requirementDeclRef), as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) { - if (auto declRefType = as<DeclRefType>(thisTypeSubst->witness->sub)) + if (auto declRefType = as<DeclRefType>(thisTypeWitness->getSub())) { - substSet = declRefType->declRef.getSubst(); + substSet = SubstitutionSet(declRefType->getDeclRef()); } } + if (auto outerGeneric = GetOuterGeneric(context->parentDecl)) + { + // If the context->parentDecl is not the same as ThisType represented by genApp, then it must be an extension + // to ThisType. In this case, we need to form a new GenericAppDeclRef to specailizethe outer parent extension + // decl. Note that the extension might be a partial extension with some generic arguments missing, and + // we can't support that case right now. For now we can just assume the extension will have the same set + // of generic parameters as the target type. + auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, outerGeneric); + auto specializedParent = m_astBuilder->getGenericAppDeclRef(makeDeclRef(outerGeneric), defaultArgs.getArrayView()); + auto specializedFunc = m_astBuilder->getMemberDeclRef(specializedParent, synFunc); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(specializedFunc)); + return true; + } - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getSpecializedDeclRef<Decl>(synFunc, substSet))); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc))); return true; } @@ -3767,11 +3640,16 @@ namespace Slang // witness in the table for the requirement, so // that we can bail out early. // - if(witnessTable->requirementDictionary.containsKey(requiredMemberDeclRef.getDecl())) + if(witnessTable->getRequirementDictionary().containsKey(requiredMemberDeclRef.getDecl())) { return true; } + // The ThisType requirement is always satisfied. + if (as<ThisTypeDecl>(requiredMemberDeclRef.getDecl())) + { + return true; + } // An important exception to the above is that an // inheritance declaration in the interface is not going @@ -3987,17 +3865,13 @@ namespace Slang ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); // When comparing things like signatures, we need to do so in the context - // of a this-type substitution that aligns the signatures in the interface + // of a LookupDeclRef that aligns the signatures in the interface // with those in the concrete type. For example, we need to treat any uses // of `This` in the interface as equivalent to the concrete type for the // purpose of signature matching (and similarly for associated types). // - ThisTypeSubstitution* thisTypeSubst = m_astBuilder->getOrCreateThisTypeSubstitution( - superInterfaceDeclRef.getDecl(), - subTypeConformsToSuperInterfaceWitness, - superInterfaceDeclRef.getSubst()); - - auto specializedSuperInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(superInterfaceDeclRef.getDecl(), thisTypeSubst); + auto thisTypeDeclRef = m_astBuilder->getLookupDeclRef( + subTypeConformsToSuperInterfaceWitness, superInterfaceDeclRef.getDecl()->getThisTypeDecl()); bool result = true; @@ -4029,35 +3903,36 @@ namespace Slang // constraints and solve for those type variables as part of the // conformance-checking process. // - for(auto requiredMemberDeclRef : getMembers(m_astBuilder, specializedSuperInterfaceDeclRef)) + for(auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef)) { - if(!isAssociatedTypeDecl(requiredMemberDeclRef.getDecl())) + if(!isAssociatedTypeDecl(requiredMemberDecl.getDecl())) continue; - + auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef(subTypeConformsToSuperInterfaceWitness, requiredMemberDecl.getDecl()); auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, superInterfaceType, inheritanceDecl, - specializedSuperInterfaceDeclRef, + thisTypeDeclRef, requiredMemberDeclRef, witnessTable, subTypeConformsToSuperInterfaceWitness); result = result && requirementSatisfied; } - for(auto requiredMemberDeclRef : getMembers(m_astBuilder, specializedSuperInterfaceDeclRef)) + for(auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef)) { - if(isAssociatedTypeDecl(requiredMemberDeclRef.getDecl())) + if(isAssociatedTypeDecl(requiredMemberDecl.getDecl())) continue; - if (requiredMemberDeclRef.as<DerivativeRequirementDecl>()) + if (requiredMemberDecl.as<DerivativeRequirementDecl>()) continue; + auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef(subTypeConformsToSuperInterfaceWitness, requiredMemberDecl.getDecl()); auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, superInterfaceType, inheritanceDecl, - specializedSuperInterfaceDeclRef, + thisTypeDeclRef, requiredMemberDeclRef, witnessTable, subTypeConformsToSuperInterfaceWitness); @@ -4089,25 +3964,27 @@ namespace Slang // the time we are compiling and handle those, and punt on the larger issue // for a bit longer. // - for(auto candidateExt : getCandidateExtensions(specializedSuperInterfaceDeclRef, this)) + for(auto candidateExt : getCandidateExtensions(superInterfaceDeclRef, this)) { // We need to apply the extension to the interface type that our // concrete type is inheriting from. // - Type* targetType = DeclRefType::create(m_astBuilder, specializedSuperInterfaceDeclRef); - auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); - if(!extDeclRef) + Type* targetType = DeclRefType::create(m_astBuilder, thisTypeDeclRef); + auto parentDeclRef = applyExtensionToType(candidateExt, targetType); + if(!parentDeclRef) continue; // Only inheritance clauses from the extension matter right now. - for(auto requiredInheritanceDeclRef : getMembersOfType<InheritanceDecl>(m_astBuilder, extDeclRef)) + for(auto requiredInheritanceDecl : getMembersOfType<InheritanceDecl>(m_astBuilder, candidateExt)) { + auto requiredInheritanceDeclRef = m_astBuilder->getLookupDeclRef( + subTypeConformsToSuperInterfaceWitness, requiredInheritanceDecl.getDecl()); auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, superInterfaceType, inheritanceDecl, - specializedSuperInterfaceDeclRef, + thisTypeDeclRef, requiredInheritanceDeclRef, witnessTable, subTypeConformsToSuperInterfaceWitness); @@ -4131,7 +4008,7 @@ namespace Slang { if (auto supereclRefType = as<DeclRefType>(superType)) { - auto superTypeDeclRef = supereclRefType->declRef; + auto superTypeDeclRef = supereclRefType->getDeclRef(); if (auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>()) { // The type is stating that it conforms to an interface. @@ -4172,11 +4049,11 @@ namespace Slang if( auto declRefType = as<DeclRefType>(subType) ) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); if (auto superDeclRefType = as<DeclRefType>(superType)) { - auto superTypeDecl = superDeclRefType->declRef.getDecl(); + auto superTypeDecl = superDeclRefType->getDeclRef().getDecl(); if (superTypeDecl->findModifier<ComInterfaceAttribute>()) { // A struct cannot implement a COM Interface. @@ -4228,10 +4105,7 @@ namespace Slang // Look at the type being inherited from, and validate // appropriately. - DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); - subIsSuperWitness->declRef = makeDeclRef(inheritanceDecl); - subIsSuperWitness->sub = subType; - subIsSuperWitness->sup = superType; + DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->getDeclaredSubtypeWitness(subType, superType, makeDeclRef(inheritanceDecl)); ConformanceCheckingContext context; context.conformingType = subType; @@ -4333,7 +4207,7 @@ namespace Slang { return; } - auto baseDecl = baseDeclRefType->declRef.getDecl(); + auto baseDecl = baseDeclRefType->getDeclRef().getDecl(); // Using the parent/child hierarchy baked into `Decl`s we // can find the modules that contain both the `decl` doing @@ -4415,7 +4289,7 @@ namespace Slang continue; } - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>(); if( !baseInterfaceDeclRef ) { @@ -4476,7 +4350,7 @@ namespace Slang continue; } - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); if( auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>() ) { } @@ -4545,7 +4419,7 @@ namespace Slang continue; } - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); if (auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>()) { } @@ -4594,8 +4468,8 @@ namespace Slang auto basicType = as<BasicExpressionType>(type); if(!basicType) return false; - - return isIntegerBaseType(basicType->baseType) || basicType->baseType == BaseType::Bool; + auto baseType = basicType->getBaseType(); + return isIntegerBaseType(baseType) || baseType == BaseType::Bool; } bool SemanticsVisitor::isIntValueInRangeOfType(IntegerLiteralValue value, Type* type) @@ -4604,7 +4478,7 @@ namespace Slang if (!basicType) return false; - switch (basicType->baseType) + switch (basicType->getBaseType()) { case BaseType::UInt8: return (value >= 0 && value <= std::numeric_limits<uint8_t>::max()) || (value == -1); @@ -4686,7 +4560,7 @@ namespace Slang continue; } - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); if( auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>() ) { _validateCrossModuleInheritance(decl, inheritanceDecl); @@ -4790,7 +4664,7 @@ namespace Slang Decl* tagAssociatedTypeDecl = nullptr; if(auto enumTypeTypeDeclRefType = dynamicCast<DeclRefType>(enumTypeType)) { - if(auto enumTypeTypeInterfaceDecl = as<InterfaceDecl>(enumTypeTypeDeclRefType->declRef.getDecl())) + if(auto enumTypeTypeInterfaceDecl = as<InterfaceDecl>(enumTypeTypeDeclRefType->getDeclRef().getDecl())) { for(auto memberDecl : enumTypeTypeInterfaceDecl->members) { @@ -4861,7 +4735,7 @@ namespace Slang { if(auto constIntVal = as<ConstantIntVal>(explicitTagVal)) { - defaultTag = constIntVal->value; + defaultTag = constIntVal->getValue(); } else { @@ -5015,7 +4889,7 @@ namespace Slang bool SemanticsVisitor::doGenericSignaturesMatch( GenericDecl* left, GenericDecl* right, - GenericSubstitution** outSubstRightToLeft) + DeclRef<Decl>* outSpecializedRightInner) { // Our first goal here is to determine if `left` and // `right` have equivalent lists of explicit @@ -5133,9 +5007,9 @@ namespace Slang // `foo2<T>` so that its constraint, after specialization, // looks like `T : IFoo`. // - auto& substRightToLeft = *outSubstRightToLeft; - List<Val*> leftArgs = getDefaultSubstitutionArgs(left); - substRightToLeft = getASTBuilder()->getOrCreateGenericSubstitution(nullptr, right, leftArgs); + auto& substInnerRightToLeft = *outSpecializedRightInner; + List<Val*> leftArgs = getDefaultSubstitutionArgs(m_astBuilder, this, left); + substInnerRightToLeft = m_astBuilder->getGenericAppDeclRef(makeDeclRef(right), leftArgs.getArrayView()); // We should now be able to enumerate the constraints // on `right` in a way that uses the same type parameters @@ -5207,7 +5081,9 @@ namespace Slang // arguments into account. // GenericTypeConstraintDecl* leftConstraint = leftConstraints[cc]; - DeclRef<GenericTypeConstraintDecl> rightConstraint = m_astBuilder->getSpecializedDeclRef(rightConstraints[cc], substRightToLeft); + auto unspecializedRightConstarintDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(rightConstraints[cc])); + DeclRef<GenericTypeConstraintDecl> rightConstraint = substInnerRightToLeft.substitute( + m_astBuilder, unspecializedRightConstarintDeclRef).as<GenericTypeConstraintDecl>(); // For now, every constraint has the form `sub : sup` // to indicate that `sub` must be a subtype of `sup`. @@ -5277,44 +5153,59 @@ namespace Slang return true; } - List<Val*> SemanticsVisitor::getDefaultSubstitutionArgs(GenericDecl* genericDecl) + List<Val*> getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl) { List<Val*> args; - for (auto dd : genericDecl->members) - { - if (dd == genericDecl->inner) - continue; + if (astBuilder->m_cachedGenericDefaultArgs.tryGetValue(genericDecl, args)) + return args; - if (auto typeParam = as<GenericTypeParamDecl>(dd)) + for (auto mm : genericDecl->members) + { + if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm)) { - auto type = DeclRefType::create(m_astBuilder, makeDeclRef(typeParam)); - args.add(type); + args.add(DeclRefType::create(astBuilder, astBuilder->getDirectDeclRef(genericTypeParamDecl))); } - else if (auto valueParam = as<GenericValueParamDecl>(dd)) + else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm)) { - auto val = m_astBuilder->getOrCreate<GenericParamIntVal>( - valueParam->getType(), - DeclRef<VarDeclBase>(valueParam)); - args.add(val); + if (semantics) + semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup); + + args.add(astBuilder->getOrCreate<GenericParamIntVal>( + genericValueParamDecl->getType(), + astBuilder->getDirectDeclRef(genericValueParamDecl))); } } - // Add defaults for constraint parameters. - for (auto dd : genericDecl->members) + bool shouldCache = true; + + // create default substitution arguments for constraints + for (auto mm : genericDecl->members) { - if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd)) + if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm)) { - // Convert the constraint to an appropriate witness. - auto witness = tryGetSubtypeWitness(constraintDecl->sub, constraintDecl->sup); - - // Must be non-null since we know there's a constraint. If null, something is - // very wrong. - // - SLANG_ASSERT(witness); - + if (semantics) + semantics->ensureDecl(genericTypeConstraintDecl, DeclCheckState::ReadyForReference); + auto constraintDeclRef = astBuilder->getDirectDeclRef<GenericTypeConstraintDecl>(genericTypeConstraintDecl); + auto witness = + astBuilder->getDeclaredSubtypeWitness( + getSub(astBuilder, constraintDeclRef), + getSup(astBuilder, constraintDeclRef), + constraintDeclRef); + // TODO: this is an ugly hack to prevent crashing. + // In early stages of compilation witness->sub and witness->sup may not be checked yet. + // When semanticVisitor is present we have used that to ensure the type is checked. + // However due to how the code is written we cannot guarantee semanticVisitor is always available + // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be + // cached. + if (!witness->getSub()) + shouldCache = false; args.add(witness); } } + + if (shouldCache) + astBuilder->m_cachedGenericDefaultArgs[genericDecl] = args; + return args; } @@ -5442,11 +5333,11 @@ namespace Slang // Then we will compare the parameter types of `foo2` // against the specialization `foo1<U>`. // - GenericSubstitution* subst = nullptr; - if(!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &subst)) + DeclRef<Decl> specializedOldDeclInner; + if(!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &specializedOldDeclInner)) return SLANG_OK; - oldDeclRef = getASTBuilder()->getSpecializedDeclRef(oldDecl, subst); + oldDeclRef = specializedOldDeclInner.as<FuncDecl>(); } // If the parameter signatures don't match, then don't worry @@ -5869,7 +5760,7 @@ namespace Slang auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); - auto declRef = m_astBuilder->getSpecializedDeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as<CallableDecl>(); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); interfaceDecl->members.add(reqDecl); @@ -5884,7 +5775,7 @@ namespace Slang if (decl->hasModifier<BackwardDifferentiableAttribute>()) { // Requirement for backward derivative. - auto declRef = m_astBuilder->getSpecializedDeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as<CallableDecl>(); auto originalFuncType = getFuncType(m_astBuilder, declRef); auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType)); { @@ -5953,7 +5844,7 @@ namespace Slang IntegerLiteralValue SemanticsVisitor::GetMinBound(IntVal* val) { if (auto constantVal = as<ConstantIntVal>(val)) - return constantVal->value; + return constantVal->getValue(); // TODO(tfoley): Need to track intervals so that this isn't just a lie... return 1; @@ -6024,7 +5915,7 @@ namespace Slang if (auto targetDeclRefType = as<DeclRefType>(decl->targetType)) { // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>()) + if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); @@ -6075,7 +5966,7 @@ namespace Slang continue; } - auto baseDeclRef = baseDeclRefType->declRef; + auto baseDeclRef = baseDeclRefType->getDeclRef(); auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>(); if( !baseInterfaceDeclRef ) { @@ -6106,9 +5997,9 @@ namespace Slang // conform to the interface and fill in its // requirements. // - ThisType* thisType = m_astBuilder->create<ThisType>(); - thisType->interfaceDeclRef = interfaceDeclRef; - return thisType; + return DeclRefType::create( + m_astBuilder, + m_astBuilder->getDirectDeclRef(interfaceDeclRef.getDecl()->getThisTypeDecl())); } else if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { @@ -6159,7 +6050,7 @@ namespace Slang { if( auto declRefType = as<DeclRefType>(type) ) { - return calcThisType(declRefType->declRef); + return calcThisType(declRefType->getDeclRef()); } else { @@ -6404,7 +6295,7 @@ namespace Slang return parentGeneric; } - DeclRef<ExtensionDecl> SemanticsVisitor::ApplyExtensionToType( + DeclRef<ExtensionDecl> SemanticsVisitor::applyExtensionToType( ExtensionDecl* extDecl, Type* type) { @@ -6438,15 +6329,15 @@ namespace Slang if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) return DeclRef<ExtensionDecl>(); - auto constraintSubst = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl)); - if (!constraintSubst) + auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>()); + if (!solvedDeclRef) { return DeclRef<ExtensionDecl>(); } // Construct a reference to the extension with our constraint variables // set as they were found by solving the constraint system. - extDeclRef = m_astBuilder->getSpecializedDeclRef<Decl>(extDecl, constraintSubst).as<ExtensionDecl>(); + extDeclRef = solvedDeclRef.as<ExtensionDecl>(); } // Now extract the target type from our (possibly specialized) extension decl-ref. @@ -6458,67 +6349,21 @@ namespace Slang // substitution to the extension decl-ref. if(auto targetDeclRefType = as<DeclRefType>(targetType)) { - if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>()) + if(auto targetInterfaceDeclRef = targetDeclRefType->getDeclRef().as<InterfaceDecl>()) { // Okay, the target type is an interface. // - // Is the type we want to apply to also an interface? - if(auto appDeclRefType = as<DeclRefType>(type)) + // Is the type we want to apply to a ThisType? + if(auto appDeclRefType = as<ThisType>(type)) { - if(auto appInterfaceDeclRef = appDeclRefType->declRef.as<InterfaceDecl>()) + if(auto thisTypeLookupDeclRef = SubstitutionSet(appDeclRefType->getDeclRef()).findLookupDeclRef()) { - if(appInterfaceDeclRef.getDecl() == targetInterfaceDeclRef.getDecl()) + if(thisTypeLookupDeclRef->getDecl() == targetInterfaceDeclRef.getDecl()) { // Looks like we have a match in the types, - // now let's see if we have a this-type substitution. - if(auto appThisTypeSubst = as<ThisTypeSubstitution>(appInterfaceDeclRef.getSubst())) - { - if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl()) - { - // The type we want to apply to has a this-type substitution, - // and (by construction) the target type currently does not. - // - SLANG_ASSERT(!as<ThisTypeSubstitution>(targetInterfaceDeclRef.getSubst())); - - // We will create a new substitution to apply to the target type. - ThisTypeSubstitution* newTargetSubst = m_astBuilder->getOrCreateThisTypeSubstitution( - appThisTypeSubst->interfaceDecl, - appThisTypeSubst->witness, - targetInterfaceDeclRef.getSubst()); - - targetType = DeclRefType::create(m_astBuilder, - m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(targetInterfaceDeclRef.getDecl(), newTargetSubst)); - - // Note: we are constructing a this-type substitution that - // we will apply to the extension declaration as well. - // This is not strictly allowed by our current representation - // choices, but we need it in order to make sure that - // references to the target type of the extension - // declaration have a chance to resolve the way we want them to. - - ThisTypeSubstitution* newExtSubst = m_astBuilder->getOrCreateThisTypeSubstitution( - appThisTypeSubst->interfaceDecl, - appThisTypeSubst->witness, - extDeclRef.getSubst()); - - extDeclRef = m_astBuilder->getSpecializedDeclRef<ExtensionDecl>( - extDeclRef.getDecl(), - newExtSubst); - - // TODO: Ideally we should also apply the chosen specialization to - // the decl-ref for the extension, so that subsequent lookup through - // the members of this extension will retain that substitution and - // be able to apply it. - // - // E.g., if an extension method returns a value of an associated - // type, then we'd want that to become specialized to a concrete - // type when using the extension method on a value of concrete type. - // - // The challenge here that makes me reluctant to just staple on - // such a substitution is that it wouldn't follow our implicit - // rules about where `ThisTypeSubstitution`s can appear. - } - } + // now let's see if `type`'s declref starts with a Lookup. + targetType = type; + extDeclRef = m_astBuilder->getLookupDeclRef(thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl()); } } } @@ -6641,7 +6486,6 @@ namespace Slang { if( auto namespaceDeclRef = declRefExpr->declRef.as<NamespaceDeclBase>() ) { - SLANG_ASSERT(!namespaceDeclRef.getSubst()); namespaceDecl = namespaceDeclRef.getDecl(); } } @@ -7007,7 +6851,7 @@ namespace Slang // the extension to the type and see if we succeed in // making a match. // - auto extDeclRef = ApplyExtensionToType(semantics, extDecl, aggType); + auto extDeclRef = applyExtensionToType(semantics, extDecl, aggType); if(!extDeclRef) continue; @@ -7065,8 +6909,8 @@ namespace Slang { if (auto andType = as<AndType>(type)) { - _getCanonicalConstraintTypes(outTypeList, andType->left); - _getCanonicalConstraintTypes(outTypeList, andType->right); + _getCanonicalConstraintTypes(outTypeList, andType->getLeft()); + _getCanonicalConstraintTypes(outTypeList, andType->getRight()); } else { @@ -7087,7 +6931,7 @@ namespace Slang assert( genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType); - auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->declRef.getDecl(); + auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); assert(constraintTypes); constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); @@ -7107,42 +6951,6 @@ namespace Slang return result; } - Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val) - { - if (auto declRefType = as<DeclRefType>(val)) - { - if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef)) - return as<Type>(concreteType); - for (auto subst = declRefType->declRef.getSubst(); subst; subst=subst->getOuter()) - { - if (auto genericSubst = as<GenericSubstitution>(subst)) - { - ShortList<Val*> newArgs; - for (auto& arg : genericSubst->getArgs()) - { - arg = resolveVal(arg); - SLANG_RELEASE_ASSERT(arg); - } - } - } - } - else if (auto subtypeWitness = as<SubtypeWitness>(val)) - { - auto sub = as<Type>(resolveVal(subtypeWitness->sub)); - auto sup = as<Type>(resolveVal(subtypeWitness->sup)); - if (sub && sup) - { - if (sub != subtypeWitness->sub || sup != subtypeWitness->sup) - { - auto newVal = tryGetSubtypeWitness(as<Type>(sub), as<Type>(sup)); - if (newVal) - val = newVal; - } - } - } - return val; - } - struct ArgsWithDirectionInfo { List<Expr*> args; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3c90c3ed8..e343e3113 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -22,7 +22,7 @@ namespace Slang DeclRefType* SemanticsVisitor::getExprDeclRefType(Expr * expr) { if (auto typetype = as<TypeType>(expr->type)) - return dynamicCast<DeclRefType>(typetype->type); + return dynamicCast<DeclRefType>(typetype->getType()); else return as<DeclRefType>(expr->type); } @@ -154,10 +154,8 @@ namespace Slang // return maybeMoveTemp(expr, [&](DeclRef<VarDeclBase> varDeclRef) { - ExtractExistentialType* openedType = m_astBuilder->create<ExtractExistentialType>(); - openedType->declRef = varDeclRef; - openedType->originalInterfaceType = expr->type.type; - openedType->originalInterfaceDeclRef = interfaceDeclRef; + ExtractExistentialType* openedType = m_astBuilder->getOrCreate<ExtractExistentialType>( + varDeclRef, expr->type.type, interfaceDeclRef); ExtractExistentialValueExpr* openedValue = m_astBuilder->create<ExtractExistentialValueExpr>(); openedValue->declRef = varDeclRef; @@ -202,29 +200,9 @@ namespace Slang if(auto declRefType = as<DeclRefType>(exprType)) { - if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) + if(auto interfaceDeclRef = declRefType->getDeclRef().as<InterfaceDecl>()) { - // Is there an this-type substitution being applied, so that - // we are referencing the interface type through a concrete - // type (e.g., a type parameter constrained to this interface)? - // - // Because of the way that substitutions need to mirror the nesting - // hierarchy of declarations, any this-type substitution pertaining - // to the chosen interface decl must be the first substitution on - // the list (which is a linked list from the "inside" out). - // - auto thisTypeSubst = as<ThisTypeSubstitution>(interfaceDeclRef.getSubst()); - if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.getDecl()) - { - // This isn't really an existential type, because somebody - // has already filled in a this-type substitution. - } - else - { - // Okay, here is the case that matters. - // - return openExistential(expr, interfaceDeclRef); - } + return openExistential(expr, interfaceDeclRef); } } @@ -317,7 +295,7 @@ namespace Slang // actually names a type, because in that case we are doing // a static member reference. // - if (auto typeType = as<TypeType>(baseExpr->type)) + if (auto typeType = as<TypeType>(baseExpr->type->getCanonicalType())) { // Before forming the reference, we will check if the // member being referenced can even be used as a static @@ -340,7 +318,7 @@ namespace Slang getSink()->diagnose( loc, Diagnostics::staticRefToNonStaticMember, - typeType->type, + typeType->getType(), declRef.getName()); } @@ -493,9 +471,9 @@ namespace Slang case LookupResultItem::Breadcrumb::Kind::SuperType: { auto witness = as<SubtypeWitness>(breadcrumb->val); - if (auto subDeclRefType = as<DeclRefType>(witness->sub)) + if (auto subDeclRefType = as<DeclRefType>(witness->getSub())) { - if (!as<InterfaceDecl>(subDeclRefType->declRef.getDecl())) + if (!as<InterfaceDecl>(subDeclRefType->getDeclRef().getDecl())) { // Store the inner most concrete super type. subType = subDeclRefType; @@ -515,10 +493,13 @@ namespace Slang return nullptr; // Don't synthesize for generic parameters. - auto parent = as<AggTypeDecl>(subType->declRef.getDecl()); + auto parent = as<AggTypeDecl>(subType->getDeclRef().getDecl()); if (!parent) return nullptr; + // Don't synthesize for ThisType. + if (as<ThisTypeDecl>(subType->getDeclRef().getDecl())) + return nullptr; // If we reach here, we are expecting a synthesized decl defined in `subType`. // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl @@ -607,7 +588,7 @@ namespace Slang // auto witness = as<SubtypeWitness>(breadcrumb->val); SLANG_ASSERT(witness); - auto expr = createCastToSuperTypeExpr(witness->sup, bb, witness); + auto expr = createCastToSuperTypeExpr(witness->getSup(), bb, witness); // Note that we allow a cast of an l-value to // be used as an l-value here because it enables @@ -926,7 +907,7 @@ namespace Slang if (auto declRefType = as<DeclRefType>(type)) { - if (auto builtinRequirement = declRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>()) + if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>()) { if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) { @@ -935,6 +916,7 @@ namespace Slang return type; } } + type = resolveType(type); if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()))) { auto diffTypeLookupResult = lookUpMember( @@ -964,10 +946,10 @@ namespace Slang auto diffTypeExpr = ConstructLookupResultExpr( diffTypeLookupResult.item, baseTypeExpr, - declRefType->declRef.getLoc(), + declRefType->getDeclRef().getLoc(), baseTypeExpr); - return ExtractTypeFromTypeRepr(diffTypeExpr); + return resolveType(ExtractTypeFromTypeRepr(diffTypeExpr)); } } } @@ -991,7 +973,7 @@ namespace Slang SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); if (witness) { - m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.addIfNotExists(type->declRef, witness); + m_parentDifferentiableAttr->addType(type->getDeclRef(), witness); } } @@ -1048,7 +1030,7 @@ namespace Slang { addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } - if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>()) { foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) { @@ -1061,23 +1043,13 @@ namespace Slang maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); }); } - for (auto subst = declRefType->declRef.getSubst(); subst; subst = subst->getOuter()) - { - if (auto genSubst = as<GenericSubstitution>(subst)) + SubstitutionSet(declRefType->getDeclRef()).forEachSubstitutionArg([&](Val* arg) { - for (auto arg : genSubst->getArgs()) + if (auto typeArg = as<Type>(arg)) { - if (auto typeArg = as<Type>(arg)) - { - maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); - } + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); } - } - else if (auto thisSubst = as<ThisTypeSubstitution>(subst)) - { - maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub); - } - } + }); return; } } @@ -1302,7 +1274,7 @@ namespace Slang if (auto constArgVal = as<ConstantIntVal>(argVal)) { - constArgVals[a] = constArgVal->value; + constArgVals[a] = constArgVal->getValue(); } else { @@ -1366,12 +1338,13 @@ namespace Slang || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") || opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) { - auto result = m_astBuilder->create<FuncCallIntVal>(invokeExpr.getExpr()->type.type); - result->args.addRange(argVals, argCount); - result->funcDeclRef = funcDeclRef; - result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute( - m_astBuilder, funcDeclRefExpr.getSubsts())); - SLANG_RELEASE_ASSERT(result->funcType); + auto result = m_astBuilder->getOrCreate<FuncCallIntVal>( + invokeExpr.getExpr()->type.type, + funcDeclRef, + as<Type>(funcDeclRefExpr.getExpr()->type->substitute( + m_astBuilder, funcDeclRefExpr.getSubsts())), + makeArrayView(argVals, argCount)); + SLANG_RELEASE_ASSERT(result->getFuncType()); return result; } return nullptr; @@ -1507,18 +1480,14 @@ namespace Slang if (isInterfaceRequirement(decl)) { - for (auto subst = declRef.getSubst(); subst; subst = subst->getOuter()) - { - if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) - { - auto val = WitnessLookupIntVal::tryFold( - m_astBuilder, - thisTypeSubst->witness, - decl, - declRef.substitute(m_astBuilder, decl->type.type)); - return as<IntVal>(val); - } - } + auto witness = findThisTypeWitness(SubstitutionSet(declRef), as<InterfaceDecl>(decl->parentDecl)); + + auto val = WitnessLookupIntVal::tryFold( + m_astBuilder, + witness, + decl, + declRef.substitute(m_astBuilder, decl->type.type)); + return as<IntVal>(val); } if (!getInitExpr(m_astBuilder, declRef)) @@ -1785,7 +1754,7 @@ namespace Slang getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported); } - auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->getType())); auto arrayType = getArrayType( m_astBuilder, elementType, @@ -1804,7 +1773,7 @@ namespace Slang { return CheckSimpleSubscriptExpr( subscriptExpr, - vecType->elementType); + vecType->getElementType()); } else if (auto matType = as<MatrixExpressionType>(baseType)) { @@ -1975,8 +1944,8 @@ namespace Slang if (basicTypeA && basicTypeB) { - const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->baseType); - const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->baseType); + const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->getBaseType()); + const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->getBaseType()); // TODO(JS): Initially this tries to limit where LValueImplict casts happen. // We could in principal allow different sizes, as long as we converted to a temprorary @@ -2021,7 +1990,7 @@ namespace Slang // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues if(auto funcType = as<FuncType>(invoke->functionExpr->type)) { - if (!funcType->errorType->equals(m_astBuilder->getBottomType())) + if (!funcType->getErrorType()->equals(m_astBuilder->getBottomType())) { // If the callee throws, make sure we are inside a try clause. if (m_enclosingTryClauseType == TryClauseType::None) @@ -2230,7 +2199,7 @@ namespace Slang return result; } - Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr) + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) { // check the base expression first expr->functionExpr = CheckTerm(expr->functionExpr); @@ -2312,6 +2281,7 @@ namespace Slang auto lookupResult = lookUp( m_astBuilder, this, expr->name, expr->scope); + if (expr->name == getSession()->getCompletionRequestTokenName()) { auto scopeKind = CompletionSuggestions::ScopeKind::Expr; @@ -2357,7 +2327,7 @@ namespace Slang if (auto modifiedType = as<ModifiedType>(primalType)) { if (modifiedType->findModifier<NoDiffModifierVal>()) - return modifiedType->base; + return modifiedType->getBase(); } // Get a reference to the builtin 'IDifferentiable' interface @@ -2379,23 +2349,23 @@ namespace Slang // Resolve JVP type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* jvpType = m_astBuilder->create<FuncType>(); + List<Type*> paramTypes; // The JVP return type is float if primal return type is float // void otherwise. // - jvpType->resultType = getDifferentialPairType(originalType->getResultType()); + auto resultType = getDifferentialPairType(originalType->getResultType()); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); - jvpType->errorType = originalType->errorType; + SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); + auto errorType = originalType->getErrorType(); for (Index i = 0; i < originalType->getParamCount(); i++) { if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) - jvpType->paramTypes.add(jvpParamType); + paramTypes.add(jvpParamType); } + FuncType* jvpType = m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType); return jvpType; } @@ -2405,16 +2375,15 @@ namespace Slang // Resolve backward diff type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* type = m_astBuilder->create<FuncType>(); + List<Type*> paramTypes; // The backward diff return type is void // - type->resultType = m_astBuilder->getVoidType(); + auto resultType = m_astBuilder->getVoidType(); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); - type->errorType = originalType->errorType; + SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); + auto errorType = originalType->getErrorType(); for (Index i = 0; i < originalType->getParamCount(); i++) { @@ -2424,7 +2393,7 @@ namespace Slang tryGetDifferentialType(m_astBuilder, outType->getValueType()); if (diffElementType) { - type->paramTypes.add(diffElementType); + paramTypes.add(diffElementType); } else { @@ -2447,16 +2416,16 @@ namespace Slang derivType = inoutType->getValueType(); } } - type->paramTypes.add(derivType); + paramTypes.add(derivType); } } // Last parameter is the initial derivative of the original return type - auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType); + auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->getResultType()); if (dOutType) - type->paramTypes.add(dOutType); + paramTypes.add(dOutType); - return type; + return m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType); } struct HigherOrderInvokeExprCheckingActions @@ -2473,9 +2442,8 @@ namespace Slang if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>()) { // Get inner function - DeclRef<Decl> unspecializedInnerRef = astBuilder->getSpecializedDeclRef<Decl>( - getInner(baseFuncGenericDeclRef), - baseFuncGenericDeclRef.getSubst()); + DeclRef<Decl> unspecializedInnerRef = createDefaultSubstitutionsIfNeeded(astBuilder, semantics, + astBuilder->getMemberDeclRef(baseFuncGenericDeclRef, getInner(baseFuncGenericDeclRef))); auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>(); if (!callableDeclRef) return nullptr; @@ -2677,10 +2645,10 @@ namespace Slang return false; if (!isIntegerBaseType(getVectorBaseType(vectorType))) return false; - auto constElementCount = as<ConstantIntVal>(vectorType->elementCount); + auto constElementCount = as<ConstantIntVal>(vectorType->getElementCount()); if (!constElementCount) return false; - return constElementCount->value == 3; + return constElementCount->getValue() == 3; }; expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this); if (!isInt3Type(expr->threadGroupSize->type.type)) @@ -2836,7 +2804,7 @@ namespace Slang // if( auto declRefType = as<DeclRefType>(typeExp.type) ) { - if(const auto structDeclRef = as<StructDecl>(declRefType->declRef)) + if(const auto structDeclRef = as<StructDecl>(declRefType->getDeclRef())) { if( expr->arguments.getCount() == 1 ) { @@ -3051,7 +3019,7 @@ namespace Slang auto baseType = expr->type; if (auto pointerLikeType = as<PointerLikeType>(baseType)) { - auto elementType = QualType(pointerLikeType->elementType); + auto elementType = QualType(pointerLikeType->getElementType()); elementType.isLeftValue = baseType.isLeftValue; auto derefExpr = m_astBuilder->create<DerefExpr>(); @@ -3230,7 +3198,7 @@ namespace Slang if (auto constantColCount = as<ConstantIntVal>(baseColCount)) { return CheckMatrixSwizzleExpr(memberRefExpr, baseElementType, - constantRowCount->value, constantColCount->value); + constantRowCount->getValue(), constantColCount->getValue()); } } getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on matrix of unknown size"); @@ -3350,7 +3318,7 @@ namespace Slang { if (auto constantElementCount = as<ConstantIntVal>(baseElementCount)) { - return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->getValue()); } else { @@ -3381,6 +3349,7 @@ namespace Slang m_astBuilder, this, expr->name, + namespaceDeclRef.getDecl(), namespaceDeclRef); if (!lookupResult.isValid()) { @@ -3406,7 +3375,7 @@ namespace Slang // // TODO: this duplicates a *lot* of logic with the case below. // We need to fix that. - auto type = typeType->type; + auto type = typeType->getType(); if (as<ErrorType>(type)) { @@ -3577,7 +3546,7 @@ namespace Slang for (auto lookupResult : overloadedExpr->lookupResult2) { bool shouldRemove = false; - if (lookupResult.declRef.getParent(m_astBuilder).as<InterfaceDecl>()) + if (lookupResult.declRef.getParent().as<InterfaceDecl>()) { shouldRemove = true; } @@ -3627,8 +3596,8 @@ namespace Slang { return CheckSwizzleExpr( expr, - baseVecType->elementType, - baseVecType->elementCount); + baseVecType->getElementType(), + baseVecType->getElementCount()); } else if(auto baseScalarType = as<BasicExpressionType>(baseType)) { @@ -3893,7 +3862,7 @@ namespace Slang types.reserve(expr->parameters.getCount()); for(const auto& t : expr->parameters) types.add(t.type); - auto funcType = m_astBuilder->getFuncType(std::move(types), expr->result.type); + auto funcType = m_astBuilder->getFuncType(types.getArrayView(), expr->result.type); expr->type = m_astBuilder->getTypeType(funcType); return expr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 575d4aff7..46cc329a9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -70,7 +70,7 @@ namespace Slang { if (auto basicType = as<BasicExpressionType>(typeIn)) { - auto rs = makeBasicTypeKey(basicType->baseType); + auto rs = makeBasicTypeKey(basicType->getBaseType()); if (auto constInt = as<IntegerLiteralExpr>(exprIn)) { if (constInt->value < 0) @@ -83,11 +83,11 @@ namespace Slang } else if (auto vectorType = as<VectorExpressionType>(typeIn)) { - if (auto elemCount = as<ConstantIntVal>(vectorType->elementCount)) + if (auto elemCount = as<ConstantIntVal>(vectorType->getElementCount())) { - if( auto elemBasicType = as<BasicExpressionType>(vectorType->elementType) ) + if( auto elemBasicType = as<BasicExpressionType>(vectorType->getElementType()) ) { - return makeBasicTypeKey(elemBasicType->baseType, elemCount->value); + return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue()); } } } @@ -99,7 +99,7 @@ namespace Slang { if( auto elemBasicType = as<BasicExpressionType>(matrixType->getElementType()) ) { - return makeBasicTypeKey(elemBasicType->baseType, elemCount1->value, elemCount2->value); + return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue()); } } } @@ -246,7 +246,7 @@ namespace Slang // When required, a candidate can store a pre-checked list of // arguments so that we don't have to repeat work across checking // phases. Currently this is only needed for generics. - Substitutions* subst = nullptr; + SubstitutionSet subst; }; struct TypeCheckingCache @@ -614,7 +614,7 @@ namespace Slang InheritanceInfo getInheritanceInfo(DeclRef<ExtensionDecl> const& extension); /// Try get subtype witness from cache, returns true if cache contains a result for the query. - bool tryGetSubtypeWitness(Type* sub, Type* sup, SubtypeWitness*& outWitness) + bool tryGetSubtypeWitnessFromCache(Type* sub, Type* sup, SubtypeWitness*& outWitness) { auto pair = TypePair{ sub, sup }; return m_mapTypePairToSubtypeWitness.tryGetValue(pair, outWitness); @@ -997,6 +997,21 @@ namespace Slang void diagnoseDeprecatedDeclRefUsage(DeclRef<Decl> declRef, SourceLoc loc, Expr* originalExpr); + DeclRef<Decl> getDefaultDeclRef(Decl* decl) + { + return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)); + } + + DeclRef<Decl> getSpecializedDeclRef(DeclRef<Decl> declToSpecialize, DeclRef<Decl> declRefWithSpecializationArgs) + { + return declRefWithSpecializationArgs.substitute(m_astBuilder, declToSpecialize); + } + + DeclRef<Decl> getSpecializedDeclRef(Decl* declToSpecialize, DeclRef<Decl> declRefWithSpecializationArgs) + { + return declRefWithSpecializationArgs.substitute(m_astBuilder, getDefaultDeclRef(declToSpecialize)); + } + DeclRefExpr* ConstructDeclRefExpr( DeclRef<Decl> declRef, Expr* baseExpr, @@ -1026,6 +1041,18 @@ namespace Slang SourceLoc loc, Expr* originalExpr); + + Val* resolveVal(Val* val) + { + if (!val) return nullptr; + return val->resolve(); + } + Type* resolveType(Type* type) + { + return (Type*)resolveVal(type); + } + DeclRef<Decl> resolveDeclRef(DeclRef<Decl> declRef); + /// Attempt to "resolve" an overloaded `LookupResult` to only include the "best" results LookupResult resolveOverloadedLookup(LookupResult const& lookupResult); @@ -1651,12 +1678,12 @@ namespace Slang List<GenericTypeConstraintDecl*>& outConstraints); /// Determine if `left` and `right` have matching generic signatures. - /// If they do, then outputs a substitution to `ioSubstRightToLeft` that - /// can be used to specialize `right` to the parameters of `left`. + /// If they do, then outputs a specialized declRef to `ioSubstRightToLeft` that + /// represents a reference to `right` with the parameters of `left`. bool doGenericSignaturesMatch( GenericDecl* left, GenericDecl* right, - GenericSubstitution** outSubstRightToLeft); + DeclRef<Decl>* outSpecializedRightInner); // Check if two functions have the same signature for the purposes // of overload resolution. @@ -1664,9 +1691,6 @@ namespace Slang DeclRef<FuncDecl> fst, DeclRef<FuncDecl> snd); - List<Val*> getDefaultSubstitutionArgs( - GenericDecl* genericDecl); - Result checkRedeclaration(Decl* newDecl, Decl* oldDecl); Result checkFuncRedeclaration(FuncDecl* newDecl, FuncDecl* oldDecl); void checkForRedeclaration(Decl* decl); @@ -1901,12 +1925,13 @@ namespace Slang // The `varSubst` argument provides the list of constraint // variables that were created for the system. // - // Returns a new substitution representing the values that + // Returns a new declref to the inner decl of `genericDeclRef`, + // representing the specialized generic with the values // we solved for along the way. - SubstitutionSet trySolveConstraintSystem( + DeclRef<Decl> trySolveConstraintSystem( ConstraintSystem* system, DeclRef<GenericDecl> genericDeclRef, - GenericSubstitution* substWithKnownGenericArgs = nullptr); + ArrayView<Val*> knownGenericArgs); // State related to overload resolution for a call @@ -2033,7 +2058,7 @@ namespace Slang Expr* createGenericDeclRef( Expr* baseExpr, Expr* originalExpr, - GenericSubstitution* subst); + SubstitutionSet substSet); // Take an overload candidate that previously got through // `TryCheckOverloadCandidate` above, and try to finish @@ -2112,15 +2137,15 @@ namespace Slang Val* fst, Val* snd); - bool tryUnifySubstitutions( + bool tryUnifyDeclRef( ConstraintSystem& constraints, - Substitutions* fst, - Substitutions* snd); + DeclRefBase* fst, + DeclRefBase* snd); - bool tryUnifyGenericSubstitutions( - ConstraintSystem& constraints, - GenericSubstitution* fst, - GenericSubstitution* snd); + bool tryUnifyGenericAppDeclRef( + ConstraintSystem& constraints, + GenericAppDeclRef* fst, + GenericAppDeclRef* snd); bool TryUnifyTypeParam( ConstraintSystem& constraints, @@ -2153,7 +2178,7 @@ namespace Slang Type* snd); // Is the candidate extension declaration actually applicable to the given type - DeclRef<ExtensionDecl> ApplyExtensionToType( + DeclRef<ExtensionDecl> applyExtensionToType( ExtensionDecl* extDecl, Type* type); @@ -2166,7 +2191,7 @@ namespace Slang DeclRef<Decl> inferGenericArguments( DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs, + ArrayView<Val*> knownGenericArgs, List<Type*> *innerParameterTypes = nullptr); void AddTypeOverloadCandidates( @@ -2209,7 +2234,7 @@ namespace Slang void addOverloadCandidatesForCallToGeneric( LookupResultItem genericItem, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs = nullptr); + ArrayView<Val*> knownGenericArgs); /// Check a generic application where the operands have already been checked. Expr* checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr); @@ -2283,7 +2308,7 @@ namespace Slang visitor->ensureDecl(decl, state); } - DeclRef<ExtensionDecl> ApplyExtensionToType( + DeclRef<ExtensionDecl> applyExtensionToType( SemanticsVisitor* semantics, ExtensionDecl* extDecl, Type* type); @@ -2318,8 +2343,6 @@ namespace Slang Expr* visitSharedTypeExpr(SharedTypeExpr* expr); - Expr* visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr); - Expr* visitInvokeExpr(InvokeExpr *expr); Expr* visitSelectExpr(SelectExpr* expr); diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 5a6adbae5..5fff47cf1 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -220,7 +220,7 @@ namespace Slang DeclRef<Decl> baseDeclRef; if (auto baseDeclRefType = as<DeclRefType>(baseType)) { - baseDeclRef = baseDeclRefType->declRef; + baseDeclRef = baseDeclRefType->getDeclRef(); } addDirectBaseFacet( @@ -239,9 +239,9 @@ namespace Slang // In the case where we have an aggregate type or `extension` // declaration, we can use the explicit list of direct bases. // - for (auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(_getASTBuilder(), aggTypeDeclBaseRef)) + for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef)) { - visitor.ensureDecl(inheritanceDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); + visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); // Note: In certain cases something takes the *syntactic* form of an inheritance // clause, but it is not actually something that should be treated as implying @@ -251,38 +251,20 @@ namespace Slang // We skip such pseudo-inheritance relationships for the purposes of determining // the linearized list of bases. // - if (inheritanceDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) + if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) continue; // The base type and subtype witness can easily be determined // using the `InheritanceDecl`. // - auto baseType = getSup(astBuilder, inheritanceDeclRef); + auto baseType = getSup(astBuilder, typeConstraintDeclRef); auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( selfType, baseType, - inheritanceDeclRef); + typeConstraintDeclRef); addDirectBaseType(baseType, satisfyingWitness); } - - // In the case of an `associatedtype`, the constraints on the associated - // type are encoded as `GenericTypeConstraintDecl`s instead of `InheritanceDecl`s. - // - // TOD(tfoley): Can we try to unify the representations of these to avoid having - // to iterate twice? - // - for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, aggTypeDeclBaseRef)) - { - visitor.ensureDecl(constraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); - - auto baseType = getSup(astBuilder, constraintDeclRef); - auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( - selfType, - baseType, - constraintDeclRef); - addDirectBaseType(baseType, satisfyingWitness); - } } else if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>()) { @@ -296,7 +278,7 @@ namespace Slang // representation would need to take into account canonicalization of // constraints. - auto genericDeclRef = genericTypeParamDeclRef.getParent(astBuilder).as<GenericDecl>(); + auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); @@ -317,7 +299,7 @@ namespace Slang auto subDeclRefType = as<DeclRefType>(subType); if (!subDeclRefType) continue; - if (subDeclRefType->declRef != genericTypeParamDeclRef) + if (subDeclRefType->getDeclRef() != genericTypeParamDeclRef) continue; // Because the constraint is a declared inheritance relationship, @@ -376,7 +358,7 @@ namespace Slang // the extension to the type and see if we succeed in // making a match. // - auto extDeclRef = ApplyExtensionToType(&visitor, extDecl, selfType); + auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType); if (!extDeclRef) continue; @@ -858,15 +840,15 @@ namespace Slang // bottleneck through the logic that gets shared between // type and `extension` declarations. // - return _getInheritanceInfo(declRefType->declRef, declRefType); + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType); } else if (auto conjunctionType = as<AndType>(type)) { // In this case, we have a type of the form `L & R`, // such that it is a subtype of both `L` and `R`. // - auto leftType = conjunctionType->left; - auto rightType = conjunctionType->right; + auto leftType = conjunctionType->getLeft(); + auto rightType = conjunctionType->getRight(); // The linearized inheritance list for the conjunction // must include all the facets from the lists for `L` diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index ce80d0002..ab34c83dd 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -310,12 +310,12 @@ namespace Slang { return false; } - if (intValue->value < 1) + if (intValue->getValue() < 1) { - getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->value); + getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->getValue()); return false; } - value = int32_t(intValue->value); + value = int32_t(intValue->getValue()); } values[i] = value; } @@ -341,13 +341,13 @@ namespace Slang } const IRIntegerValue kMaxAnyValueSize = 0x7FFF; - if (value->value > kMaxAnyValueSize) + if (value->getValue() > kMaxAnyValueSize) { getSink()->diagnose(anyValueSizeAttr->loc, Diagnostics::anyValueSizeExceedsLimit, kMaxAnyValueSize); return false; } - anyValueSizeAttr->size = int32_t(value->value); + anyValueSizeAttr->size = int32_t(value->getValue()); } else if (auto bindingAttr = as<GLSLBindingAttribute>(attr)) { @@ -369,8 +369,8 @@ namespace Slang return false; } - bindingAttr->binding = int32_t(binding->value); - bindingAttr->set = int32_t(set->value); + bindingAttr->binding = int32_t(binding->getValue()); + bindingAttr->set = int32_t(set->getValue()); } else if (auto simpleLayoutAttr = as<GLSLSimpleIntegerLayoutAttribute>(attr)) { @@ -388,7 +388,7 @@ namespace Slang return false; } - simpleLayoutAttr->value = int32_t(value->value); + simpleLayoutAttr->value = int32_t(value->getValue()); } else if (auto maxVertexCountAttr = as<MaxVertexCountAttribute>(attr)) { @@ -397,7 +397,7 @@ namespace Slang if (!val) return false; - maxVertexCountAttr->value = (int32_t)val->value; + maxVertexCountAttr->value = (int32_t)val->getValue(); } else if (auto instanceAttr = as<InstanceAttribute>(attr)) { @@ -406,7 +406,7 @@ namespace Slang if (!val) return false; - instanceAttr->value = (int32_t)val->value; + instanceAttr->value = (int32_t)val->getValue(); } else if (auto entryPointAttr = as<EntryPointAttribute>(attr)) { @@ -486,7 +486,7 @@ namespace Slang //IntVal* outIntVal; if (auto cInt = checkConstantEnumVal(attr->args[0])) { - targetClassId = (uint32_t)(cInt->value); + targetClassId = (uint32_t)(cInt->getValue()); } else { @@ -515,7 +515,7 @@ namespace Slang } auto cint = checkConstantIntVal(attr->args[0]); if (cint) - forceUnrollAttr->maxIterations = (int32_t)cint->value; + forceUnrollAttr->maxIterations = (int32_t)cint->getValue(); } else if (auto maxItersAttrs = as<MaxItersAttribute>(attr)) { @@ -528,7 +528,7 @@ namespace Slang auto cint = checkConstantIntVal(attr->args[0]); if (cint) { - maxItersAttrs->value = (int32_t) cint->value; + maxItersAttrs->value = (int32_t) cint->getValue(); } } } @@ -547,10 +547,12 @@ namespace Slang bool typeChecked = false; if (auto basicType = as<BasicExpressionType>(paramDecl->getType())) { - if (basicType->baseType == BaseType::Int) + if (basicType->getBaseType() == BaseType::Int) { if (auto cint = checkConstantIntVal(arg)) { + for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) + attr->intArgVals.add(nullptr); attr->intArgVals[(uint32_t)paramIndex] = cint; } typeChecked = true; @@ -578,7 +580,7 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto cint = checkConstantIntVal(attr->args[0]); if (cint) - diffAttr->maxOrder = (int32_t)cint->value; + diffAttr->maxOrder = (int32_t)cint->getValue(); } else if (auto formatAttr = as<FormatAttribute>(attr)) { @@ -652,7 +654,7 @@ namespace Slang if (!val) return false; - rayPayloadAttr->location = (int32_t)val->value; + rayPayloadAttr->location = (int32_t)val->getValue(); } else if (auto callablePayloadAttr = as<VulkanCallablePayloadAttribute>(attr)) { @@ -661,7 +663,7 @@ namespace Slang if (!val) return false; - callablePayloadAttr->location = (int32_t)val->value; + callablePayloadAttr->location = (int32_t)val->getValue(); } else if (auto hitObjectAttributesAttr = as<VulkanHitObjectAttributesAttribute>(attr)) { @@ -670,7 +672,7 @@ namespace Slang if (!val) return false; - hitObjectAttributesAttr->location = (int32_t)val->value; + hitObjectAttributesAttr->location = (int32_t)val->getValue(); } else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 2fa13f3fa..8709ae763 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -70,7 +70,7 @@ namespace Slang { if (auto resultType = as<DeclRefType>(candidate.resultType)) { - if (resultType->declRef.as<ClassDecl>()) + if (resultType->getDeclRef().as<ClassDecl>()) { isClassType = true; } @@ -373,7 +373,7 @@ namespace Slang // if( !val ) { - val = m_astBuilder->create<ErrorIntVal>(); + val = m_astBuilder->getOrCreate<ErrorIntVal>(m_astBuilder->getIntType()); } checkedArgs.add(val); } @@ -383,8 +383,8 @@ namespace Slang } } - auto genSubst = m_astBuilder->getOrCreateGenericSubstitution(nullptr, genericDeclRef.getDecl(), checkedArgs.getArrayView()); - candidate.subst = genSubst; + auto genSubst = m_astBuilder->getGenericAppDeclRef(genericDeclRef, checkedArgs.getArrayView()); + candidate.subst = SubstitutionSet(genSubst); // Once we are done processing the parameters of the generic, // we will have build up a usable `checkedArgs` array and @@ -550,19 +550,17 @@ namespace Slang // We should have the existing arguments to the generic // handy, so that we can construct a substitution list. - auto subst = as<GenericSubstitution>(candidate.subst); - SLANG_ASSERT(subst); + auto substArgs = tryGetGenericArguments(candidate.subst, genericDeclRef.getDecl()); + SLANG_ASSERT(substArgs.getCount()); - subst = getASTBuilder()->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), genericDeclRef.getDecl(), subst->getArgs()); - - List<Val*> newArgs = subst->getArgs(); + List<Val*> newArgs; + for (auto arg : substArgs) + newArgs.add(arg); for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef( - constraintDecl, subst); - + DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(genericDeclRef, substArgs, constraintDecl).as<GenericTypeConstraintDecl>(); + auto sub = getSub(m_astBuilder, constraintDeclRef); auto sup = getSup(m_astBuilder, constraintDeclRef); @@ -575,14 +573,14 @@ namespace Slang { if(context.mode != OverloadResolveContext::Mode::JustTrying) { - subTypeWitness = tryGetSubtypeWitness(sub, sup); + subTypeWitness = isSubtype(sub, sup); getSink()->diagnose(context.loc, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); } return false; } } - candidate.subst = m_astBuilder->getOrCreateGenericSubstitution(nullptr, genericDeclRef.getDecl(), newArgs); + candidate.subst = SubstitutionSet(m_astBuilder->getGenericAppDeclRef(genericDeclRef, newArgs.getArrayView())); // Done checking all the constraints, hooray. return true; @@ -617,7 +615,7 @@ namespace Slang Expr* SemanticsVisitor::createGenericDeclRef( Expr* baseExpr, Expr* originalExpr, - GenericSubstitution* subst) + SubstitutionSet substArgs) { auto baseDeclRefExpr = as<DeclRefExpr>(baseExpr); if (!baseDeclRefExpr) @@ -631,10 +629,9 @@ namespace Slang SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); return CreateErrorExpr(originalExpr); } - - subst = m_astBuilder->getOrCreateGenericSubstitution(baseGenericRef.getSubst(), baseGenericRef.getDecl(), subst->getArgs()); - - DeclRef<Decl> innerDeclRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(baseGenericRef), subst); + auto genSubst = substArgs.findGenericAppDeclRef(baseGenericRef.getDecl()); + SLANG_ASSERT(genSubst); + DeclRef<Decl> innerDeclRef = m_astBuilder->getGenericAppDeclRef(baseGenericRef, genSubst->getArgs()); Expr* base = nullptr; if (auto mbrExpr = as<MemberExpr>(baseExpr)) @@ -768,14 +765,16 @@ namespace Slang expr->loc = context.loc; expr->originalExpr = baseExpr; expr->baseGenericDeclRef = as<DeclRefExpr>(baseExpr)->declRef.as<GenericDecl>(); - expr->substWithKnownGenericArgs = (GenericSubstitution*)candidate.subst; + auto args = tryGetGenericArguments(candidate.subst, expr->baseGenericDeclRef.getDecl()); + for (auto arg : args) + expr->knownGenericArgs.add(arg); return expr; } return createGenericDeclRef( baseExpr, context.originalExpr, - as<GenericSubstitution>(candidate.subst)); + candidate.subst); break; default: @@ -801,12 +800,14 @@ namespace Slang /// Does the given `declRef` represent an interface requirement? bool isInterfaceRequirement(ASTBuilder* builder, DeclRef<Decl> const& declRef) { + SLANG_UNUSED(builder); + if(!declRef) return false; - auto parent = declRef.getParent(builder); + auto parent = declRef.getParent(); if(parent.as<GenericDecl>()) - parent = parent.getParent(builder); + parent = parent.getParent(); if(parent.as<InterfaceDecl>()) return true; @@ -826,7 +827,7 @@ namespace Slang // "inner" declaration of a generic. That means that // the parent of the decl ref must be a generic. // - auto parentGeneric = declRef.getParent(m_astBuilder).as<GenericDecl>(); + auto parentGeneric = declRef.getParent().as<GenericDecl>(); if(!parentGeneric) return 0; // @@ -863,7 +864,18 @@ namespace Slang if(leftIsInterfaceRequirement != rightIsInterfaceRequirement) return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement); - // TODO: We should always have rules such that in a tie a declaration + // If both are interface requirements, prefer to more derived interface. + if (leftIsInterfaceRequirement && rightIsInterfaceRequirement) + { + auto leftType = DeclRefType::create(m_astBuilder, left.declRef.getParent()); + auto rightType = DeclRefType::create(m_astBuilder, right.declRef.getParent()); + if (isSubtype(leftType, rightType)) + return -1; + if (isSubtype(rightType, leftType)) + return 1; + } + + // TODO: We should generalize above rules such that in a tie a declaration // A::m is better than B::m when all other factors are equal and // A inherits from B. @@ -1227,7 +1239,7 @@ namespace Slang DeclRef<Decl> SemanticsVisitor::inferGenericArguments( DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs, + ArrayView<Val*> knownGenericArgs, List<Type*> *innerParameterTypes) { // We have been asked to infer zero or more arguments to @@ -1265,28 +1277,10 @@ namespace Slang // the "inner" declaration of the generic (e.g., the `FuncitonDecl` // under the `GenericDecl`). // - // In the case where no explicit arguments are available, we will - // use any substitutions that were in place for referring to the - // generic itself. - // - Substitutions* substForInnerDecl = genericDeclRef.getSubst(); - // - // In the case where we have explicit/known arguments, - // we will use those as our baseline substitutions. - // - if (substWithKnownGenericArgs) - { - substForInnerDecl = substWithKnownGenericArgs; - } - - auto innerDecl = getInner(genericDeclRef); - DeclRef<Decl> partiallySpecializedInnerRef = m_astBuilder->getSpecializedDeclRef<Decl>( - innerDecl, - substForInnerDecl); - // Check what type of declaration we are dealing with, and then try // to match it up with the arguments accordingly... - if (auto funcDeclRef = partiallySpecializedInnerRef.as<CallableDecl>()) + + if (auto funcDeclRef = as<CallableDecl>(genericDeclRef.getDecl()->inner)) { List<Type*> paramTypes; if (!innerParameterTypes) @@ -1360,28 +1354,8 @@ namespace Slang // TODO(tfoley): We probably need to pass along the explicit arguments here, // so that the solver knows to accept those arguments as-is. // - auto constraintSubst = trySolveConstraintSystem( - &constraints, genericDeclRef, substWithKnownGenericArgs); - if (!constraintSubst) - { - // In this case, the solver failed to find a solution to the constraint - // system, and we will signal that failure up to the client that called - // this operation. - // - // TODO: We really ought to be passing up some kind of representation - // of the failure, so that constraint-related issues can be reported to - // the user. This could either be a return path here (returning some - // diagnostics), or this code could have a "just trying" vs. "actually - // do things" distinction like some other steps. - // - return DeclRef<Decl>(); - } - - // If we found a solution (that is, a set of argument values that satisfy - // all the constraints), we can construct a reference to the inner - // declaration that applies the generic to those arguments. - // - return m_astBuilder->getSpecializedDeclRef<Decl>(innerDecl, constraintSubst); + return trySolveConstraintSystem( + &constraints, genericDeclRef, knownGenericArgs); } void SemanticsVisitor::AddTypeOverloadCandidates( @@ -1424,13 +1398,13 @@ namespace Slang void SemanticsVisitor::addOverloadCandidatesForCallToGeneric( LookupResultItem genericItem, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs) + ArrayView<Val*> knownGenericArgs) { auto genericDeclRef = genericItem.declRef.as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // Try to infer generic arguments, based on the context - DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, substWithKnownGenericArgs); + DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs); if (innerRef) { @@ -1475,7 +1449,7 @@ namespace Slang LookupResultItem innerItem; innerItem.breadcrumbs = item.breadcrumbs; innerItem.declRef = genericDeclRef; - addOverloadCandidatesForCallToGeneric(innerItem, context); + addOverloadCandidatesForCallToGeneric(innerItem, context, ArrayView<Val*>()); } else if( auto typeDefDeclRef = item.declRef.as<TypeDefDecl>() ) { @@ -1578,7 +1552,7 @@ namespace Slang addOverloadCandidatesForCallToGeneric( LookupResultItem(partiallyAppliedGenericExpr->baseGenericDeclRef), context, - partiallyAppliedGenericExpr->substWithKnownGenericArgs); + partiallyAppliedGenericExpr->knownGenericArgs.getArrayView()); } else if (auto typeType = as<TypeType>(funcExprType)) { @@ -1588,7 +1562,7 @@ namespace Slang // // TODO(tfoley): are there any meaningful types left // that aren't declaration references? - auto type = typeType->type; + auto type = typeType->getType(); AddTypeOverloadCandidates(type, context); return; } @@ -1633,12 +1607,16 @@ namespace Slang paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii))); // Try to infer generic arguments, based on the updated context. + OverloadResolveContext subContext = context; DeclRef<Decl> innerRef = inferGenericArguments( baseFuncGenericDeclRef, context, - nullptr, + ArrayView<Val*>(), ¶mTypes); + if (!innerRef) + return; + OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; if (innerRef) diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp new file mode 100644 index 000000000..91722f82c --- /dev/null +++ b/source/slang/slang-check-resolve-val.cpp @@ -0,0 +1,48 @@ +// slang-check-resolve-val.cpp + +// Logic for resolving/simplifying Types and DeclRefs. + +#include "slang-check-impl.h" + +#include "slang-lookup.h" +#include "slang-syntax.h" +#include "slang-ast-synthesis.h" +#include "slang-ast-reflect.h" + +namespace Slang +{ + +Type* Type::createCanonicalType() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ()); +} + +Val* Type::_resolveImplOverride() +{ + Val* resolvedVal = createCanonicalType(); + return resolvedVal; +} + +DeclRefBase* _resolveAsDeclRef(DeclRefBase* declRefToResolve); + +Type* DeclRefType::_createCanonicalTypeOverride() +{ + auto astBuilder = getCurrentASTBuilder(); + + // A declaration reference is already canonical + auto resolvedDeclRef = getDeclRef(); + resolvedDeclRef = _resolveAsDeclRef(getDeclRef().declRefBase); + if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, resolvedDeclRef)) + return as<Type>(satisfyingVal); + if (resolvedDeclRef != getDeclRef()) + return DeclRefType::create(astBuilder, resolvedDeclRef); + return this; +} + + +Val* SubtypeWitness::_resolveImplOverride() +{ + return as<SubtypeWitness>(defaultResolveImpl()); +} + +} diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 657438222..d9bb11548 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -17,7 +17,7 @@ namespace Slang auto basicType = as<BasicExpressionType>(type); if (basicType) { - return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); + return (basicType->getBaseType() == BaseType::Int || basicType->getBaseType() == BaseType::UInt); } } // Can be an int/uint vector from size 1 to 3 @@ -27,20 +27,21 @@ namespace Slang { return false; } - auto elemCount = as<ConstantIntVal>(vectorType->elementCount); - if (elemCount->value < 1 || elemCount->value > 3) + auto elemCount = as<ConstantIntVal>(vectorType->getElementCount()); + if (elemCount->getValue() < 1 || elemCount->getValue() > 3) { return false; } // Must be a basic type - auto basicType = as<BasicExpressionType>(vectorType->elementType); + auto basicType = as<BasicExpressionType>(vectorType->getElementType()); if (!basicType) { return false; } // Must be integral - return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); + auto baseType = basicType->getBaseType(); + return (baseType == BaseType::Int || baseType == BaseType::UInt); } } @@ -83,7 +84,7 @@ namespace Slang if( auto declRefType = as<DeclRefType>(type) ) { - auto typeDeclRef = declRefType->declRef; + auto typeDeclRef = declRefType->getDeclRef(); if( auto interfaceDeclRef = typeDeclRef.as<InterfaceDecl>() ) { // Each leaf parameter of interface type adds a specialization @@ -792,6 +793,8 @@ namespace Slang void FrontEndCompileRequest::checkEntryPoints() { auto linkage = getLinkage(); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + auto sink = getSink(); // The validation of entry points here will be modal, and controlled @@ -1025,7 +1028,7 @@ namespace Slang // if( auto argDeclRefType = as<DeclRefType>(argType) ) { - auto argDeclRef = argDeclRefType->declRef; + auto argDeclRef = argDeclRefType->getDeclRef(); if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { if(argGenericParamDeclRef.getDecl() == genericTypeParamDecl) @@ -1193,7 +1196,7 @@ namespace Slang // the semantic checking machinery to expand out // the rest of the arguments via inference... - auto genericDeclRef = m_funcDeclRef.getParent(getLinkage()->getASTBuilder()).as<GenericDecl>(); + auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters List<Val*> genericArgs; @@ -1203,19 +1206,13 @@ namespace Slang auto specializationArg = args[ii]; genericArgs.add(specializationArg.val); } - GenericSubstitution* genericSubst = - getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), - genericDeclRef.getDecl(), - genericArgs.getArrayView()); + auto genericInnerDeclRef = getLinkage()->getASTBuilder()->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()); ASTBuilder* astBuilder = getLinkage()->getASTBuilder(); for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>( getLinkage()->getASTBuilder(), DeclRef<ContainerDecl>(genericDeclRef))) { - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = astBuilder->getSpecializedDeclRef( - constraintDecl.getDecl(), genericSubst); - + DeclRef<GenericTypeConstraintDecl> constraintDeclRef = astBuilder->getDirectDeclRef(constraintDecl.getDecl()); auto sub = getSub(astBuilder, constraintDeclRef); auto sup = getSup(astBuilder, constraintDeclRef); @@ -1233,12 +1230,8 @@ namespace Slang } } - genericSubst = - getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), - genericDeclRef.getDecl(), - genericArgs); - specializedFuncDeclRef = astBuilder->getSpecializedDeclRef(specializedFuncDeclRef.getDecl(), genericSubst); + specializedFuncDeclRef = getLinkage()->getASTBuilder()->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()).as<FuncDecl>(); + SLANG_ASSERT(specializedFuncDeclRef); } info->specializedFuncDeclRef = specializedFuncDeclRef; @@ -1418,9 +1411,8 @@ namespace Slang specializationArgs.add(arg); } - ExistentialSpecializedType* specializedType = m_astBuilder->create<ExistentialSpecializedType>(); - specializedType->baseType = unspecializedType; - specializedType->args = specializationArgs; + ExistentialSpecializedType* specializedType = m_astBuilder->getOrCreate<ExistentialSpecializedType>( + unspecializedType, specializationArgs); m_specializedTypes.add(specializedType); diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index ba7e977e3..4b4257f75 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -526,7 +526,7 @@ namespace Slang } if (!stepSize) return; - if (stepSize->value > 0) + if (stepSize->getValue() > 0) { if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater || sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less) @@ -535,7 +535,7 @@ namespace Slang return; } } - else if (stepSize->value < 0) + else if (stepSize->getValue() < 0) { if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less || sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater) @@ -553,25 +553,25 @@ namespace Slang if (!initialLitVal || !finalVal) return; - auto absStepSize = abs(stepSize->value); + auto absStepSize = abs(stepSize->getValue()); int adjustment = 0; if (compareOp == kIROp_Geq || compareOp == kIROp_Leq) adjustment = 1; - auto iterations = (Math::Max(finalVal->value, initialLitVal->value) - - Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) / + auto iterations = (Math::Max(finalVal->getValue(), initialLitVal->getValue()) - + Math::Min(finalVal->getValue(), initialLitVal->getValue()) + absStepSize - 1 + adjustment) / absStepSize; switch (compareOp) { case kIROp_Geq: case kIROp_Greater: // Expect final value to be less than initial value. - if (finalVal->value > initialLitVal->value) + if (finalVal->getValue() > initialLitVal->getValue()) iterations = 0; break; case kIROp_Leq: case kIROp_Less: - if (finalVal->value < initialLitVal->value) + if (finalVal->getValue() < initialLitVal->getValue()) iterations = 0; break; } @@ -590,7 +590,7 @@ namespace Slang litExpr->type.type = m_astBuilder->getIntType(); litExpr->token.setName(getNamePool()->getName(String(iterations))); maxItersAttr->args.add(litExpr); - maxItersAttr->intArgVals.add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); + maxItersAttr->intArgVals.add(m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); maxItersAttr->value = (int32_t)iterations; maxItersAttr->inductionVar = initialVar; addModifier(stmt, maxItersAttr); diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index cee54388f..d5d3e5a5d 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -17,6 +17,7 @@ namespace Slang sink); SemanticsVisitor visitor(&sharedSemanticsContext); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); auto typeOut = visitor.CheckProperType(typeExp); return typeOut.type; @@ -49,7 +50,7 @@ namespace Slang if (!typeRepr) return nullptr; if (auto typeType = as<TypeType>(typeRepr->type)) { - return typeType->type; + return typeType->getType(); } return m_astBuilder->getErrorType(); } @@ -86,11 +87,17 @@ namespace Slang Type* SemanticsVisitor::getRemovedModifierType(ModifiedType* modifiedType, ModifierVal* modifier) { - if (modifiedType->modifiers.getCount() == 1) - return modifiedType->base; - auto newModifiers = modifiedType->modifiers; - newModifiers.remove(modifier); - return m_astBuilder->getModifiedType(modifiedType->base, newModifiers); + if (modifiedType->getModifierCount() == 1) + return modifiedType->getBase(); + List<Val*> newModifiers; + for (Index i = 0; i < modifiedType->getModifierCount(); i++) + { + auto m = modifiedType->getModifier(i); + if (m == modifier) + continue; + newModifiers.add(m); + } + return m_astBuilder->getModifiedType(modifiedType->getBase(), newModifiers); } Expr* SemanticsVisitor::ExpectATypeRepr(Expr* expr) @@ -118,7 +125,7 @@ namespace Slang auto typeRepr = ExpectATypeRepr(expr); if (auto typeType = as<TypeType>(typeRepr->type)) { - return typeType->type; + return typeType->getType(); } return m_astBuilder->getErrorType(); } @@ -142,7 +149,7 @@ namespace Slang // constant expression in context, then we will instead construct // a dummy "error" value to represent the result. // - val = m_astBuilder->create<ErrorIntVal>(); + val = m_astBuilder->getOrCreate<ErrorIntVal>(m_astBuilder->getIntType()); return val; } @@ -160,7 +167,7 @@ namespace Slang } if (auto typeType = as<TypeType>(exp->type)) { - return typeType->type; + return typeType->getType(); } else if (const auto errorType = as<ErrorType>(exp->type)) { @@ -187,10 +194,7 @@ namespace Slang evaledArgs.add(ExtractGenericArgVal(argExpr)); } - GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution( - genericDeclRef.getSubst(), genericDeclRef.getDecl(), evaledArgs); - - DeclRef<Decl> innerDeclRef = m_astBuilder->getSpecializedDeclRef(getInner(genericDeclRef), subst); + DeclRef<Decl> innerDeclRef = m_astBuilder->getGenericAppDeclRef(genericDeclRef, evaledArgs.getArrayView()); return DeclRefType::create(m_astBuilder, innerDeclRef); } @@ -198,9 +202,9 @@ namespace Slang { if (auto declRefValueType = as<DeclRefType>(type)) { - if (as<ClassDecl>(declRefValueType->declRef.getDecl())) + if (as<ClassDecl>(declRefValueType->getDeclRef().getDecl())) return true; - if (as<InterfaceDecl>(declRefValueType->declRef.getDecl())) + if (as<InterfaceDecl>(declRefValueType->getDeclRef().getDecl())) return true; } return false; @@ -221,7 +225,7 @@ namespace Slang if(auto typeType = as<TypeType>(expr->type)) { - type = typeType->type; + type = typeType->getType(); } } @@ -358,7 +362,7 @@ namespace Slang if (auto basicType = as<BasicExpressionType>(type)) { // TODO: `void` shouldn't be a basic type, to make this easier to avoid - if (basicType->baseType == BaseType::Void) + if (basicType->getBaseType() == BaseType::Void) { // TODO(tfoley): pick the right diagnostic message getSink()->diagnose(result.exp, Diagnostics::invalidTypeVoid); @@ -384,7 +388,7 @@ namespace Slang { if(auto rightConst = as<ConstantIntVal>(right)) { - return leftConst->value == rightConst->value; + return leftConst->getValue() == rightConst->getValue(); } } @@ -392,16 +396,16 @@ namespace Slang { if(auto rightVar = as<GenericParamIntVal>(right)) { - return leftVar->declRef.equals(rightVar->declRef); + return leftVar->getDeclRef().equals(rightVar->getDeclRef()); } else if (const auto rightPoly = as<PolynomialIntVal>(right)) { - return right->equalsVal(leftVar); + return right->equals(leftVar); } } if (auto leftVar = as<PolynomialIntVal>(left)) { - return leftVar->equalsVal(right); + return leftVar->equals(right); } return false; } @@ -423,22 +427,4 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) - { - // We have an expression of the form `__TaggedUnion(A, B, ...)` - // which will evaluate to a tagged-union type over `A`, `B`, etc. - // - TaggedUnionType* type = m_astBuilder->create<TaggedUnionType>(); - expr->type = QualType(m_astBuilder->getTypeType(type)); - - for( auto& caseTypeExpr : expr->caseTypes ) - { - caseTypeExpr = CheckProperType(caseTypeExpr); - type->caseTypes.add(caseTypeExpr.type); - } - - return expr; - } - - } diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index 276e086df..780c109da 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -164,6 +164,8 @@ namespace Slang TranslationUnitRequest* translationUnit, LoadedModuleDictionary& loadedModules) { + SLANG_AST_BUILDER_RAII(translationUnit->compileRequest->getLinkage()->getASTBuilder()); + SharedSemanticsContext sharedSemanticsContext( translationUnit->compileRequest->getLinkage(), translationUnit->getModule(), diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 224e30fa1..4e1ab8e98 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -265,7 +265,7 @@ namespace Slang { if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) { - auto declModule = getModule(declaredWitness->declRef.getDecl()); + auto declModule = getModule(declaredWitness->getDeclRef().getDecl()); m_moduleDependencyList.addDependency(declModule); m_fileDependencyList.addDependency(declModule); if (m_requirementSet.add(declModule)) @@ -276,8 +276,8 @@ namespace Slang } else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) { - addDepedencyFromWitness(transitiveWitness->midToSup); - addDepedencyFromWitness(transitiveWitness->subToMid); + addDepedencyFromWitness(transitiveWitness->getMidToSup()); + addDepedencyFromWitness(transitiveWitness->getSubToMid()); } else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) { diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 7d73599ba..79f6b6ed4 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -650,9 +650,6 @@ namespace Slang List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; } - /// Get a list of tagged-union types referenced by the specialization parameters. - List<TaggedUnionType*> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } - RefPtr<IRModule> getIRModule() { return m_irModule; } void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; @@ -679,9 +676,6 @@ namespace Slang List<String> m_entryPointMangledNames; List<String> m_entryPointNameOverrides; - // Any tagged union types that were referenced by the specialization arguments. - List<TaggedUnionType*> m_taggedUnionTypes; - List<Module*> m_moduleDependencies; List<SourceFile*> m_fileDependencies; List<RefPtr<ComponentType>> m_requirements; diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index 9e9efeb64..c13dc9668 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -1119,7 +1119,7 @@ void DocMarkdownWriter::writeDescription(const ASTMarkup::Entry& entry) void DocMarkdownWriter::writeDecl(const ASTMarkup::Entry& entry, Decl* decl) { // Skip these they will be output as part of their respective 'containers' - if (as<ParamDecl>(decl) || as<EnumCaseDecl>(decl) || as<AssocTypeDecl>(decl) || as<InheritanceDecl>(decl)) + if (as<ParamDecl>(decl) || as<EnumCaseDecl>(decl) || as<AssocTypeDecl>(decl) || as<InheritanceDecl>(decl) || as<ThisTypeDecl>(decl)) { return; } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index adfc98dfd..5bd57e81c 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -14,7 +14,6 @@ #include "slang-ir-specialize.h" #include "slang-ir-specialize-resources.h" #include "slang-ir-ssa.h" -#include "slang-ir-union.h" #include "slang-ir-util.h" #include "slang-ir-validate.h" #include "slang-legalize-types.h" diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 8f4d68a75..343c18916 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -54,7 +54,6 @@ #include "slang-ir-strip-cached-dict.h" #include "slang-ir-strip-witness-tables.h" #include "slang-ir-synthesize-active-mask.h" -#include "slang-ir-union.h" #include "slang-ir-validate.h" #include "slang-ir-wrap-structured-buffers.h" #include "slang-ir-liveness.h" @@ -347,10 +346,6 @@ Result linkAndOptimizeIR( // Lower `Result<T,E>` types into ordinary struct types. lowerResultType(irModule, sink); - // Desguar any union types, since these will be illegal on - // various targets. - // - desugarUnionTypes(irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED"); #endif diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4716ed427..69f3c4e0d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -54,8 +54,6 @@ INST(Nop, nop, 0, 0) INST(VectorType, Vec, 2, HOISTABLE) INST(MatrixType, Mat, 3, HOISTABLE) - INST(TaggedUnionType, TaggedUnion, 0, HOISTABLE) - INST(ConjunctionType, Conjunction, 0, HOISTABLE) INST(AttributedType, Attributed, 0, HOISTABLE) INST(ResultType, Result, 2, HOISTABLE) @@ -985,7 +983,6 @@ INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) INST(ArrayTypeLayout, arrayTypeLayout, 1, HOISTABLE) INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, HOISTABLE) INST(MatrixTypeLayout, matrixTypeLayout, 1, HOISTABLE) - INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, HOISTABLE) INST(ExistentialTypeLayout, existentialTypeLayout, 0, HOISTABLE) INST(StructTypeLayout, structTypeLayout, 0, HOISTABLE) // TODO(JS): Ideally we'd have the layout to the pointed to value type (ie 1 instead of 0 here). But to avoid infinite recursion we don't. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2312cc4f2..95f72b3cd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1717,59 +1717,6 @@ struct IRCaseTypeLayoutAttr : IRAttr } }; - /// Specialized layout information for tagged union types -struct IRTaggedUnionTypeLayout : IRTypeLayout -{ - typedef IRTypeLayout Super; - - IR_LEAF_ISA(TaggedUnionTypeLayout) - - /// Get the (byte) offset of the tagged union's tag (aka "discriminator") field - LayoutSize getTagOffset() - { - return LayoutSize::fromRaw(LayoutSize::RawValue(getIntVal(cast<IRIntLit>(getOperand(0))))); - } - - /// Get all the attributes representing layouts for the difference cases - IROperandList<IRCaseTypeLayoutAttr> getCaseTypeLayoutAttrs() - { - return findAttrs<IRCaseTypeLayoutAttr>(); - } - - /// Get the number of cases for which layout information is stored - UInt getCaseCount() - { - return getCaseTypeLayoutAttrs().getCount(); - } - - /// Get the layout information for the case at the given `index` - IRTypeLayout* getCaseTypeLayout(UInt index) - { - return getCaseTypeLayoutAttrs()[index]->getTypeLayout(); - } - - /// Specialized builder for tagged union type layouts - struct Builder : Super::Builder - { - Builder(IRBuilder* irBuilder, LayoutSize tagOffset); - - void addCaseTypeLayout(IRTypeLayout* typeLayout); - - IRTaggedUnionTypeLayout* build() - { - return cast<IRTaggedUnionTypeLayout>(Super::Builder::build()); - } - - protected: - IROp getOp() SLANG_OVERRIDE { return kIROp_TaggedUnionTypeLayout; } - void addOperandsImpl(List<IRInst*>& ioOperands) SLANG_OVERRIDE; - void addAttrsImpl(List<IRInst*>& ioOperands) SLANG_OVERRIDE; - - IRInst* m_tagOffset = nullptr; - List<IRAttr*> m_caseTypeLayoutAttrs; - }; -}; - /// Type layout for an existential/interface type. struct IRExistentialTypeLayout : IRTypeLayout { @@ -3013,16 +2960,6 @@ public: IRRate* rate, IRType* dataType); - IRType* getTaggedUnionType( - UInt caseCount, - IRType* const* caseTypes); - - IRType* getTaggedUnionType( - List<IRType*> const& caseTypes) - { - return getTaggedUnionType(caseTypes.getCount(), caseTypes.getBuffer()); - } - IRType* getBindExistentialsType( IRInst* baseType, UInt slotArgCount, diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 081fd7486..364613074 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -240,7 +240,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_GlobalGenericParam: case kIROp_WitnessTable: case kIROp_InterfaceType: - case kIROp_TaggedUnionType: return cloneGlobalValue(this, originalValue); case kIROp_BoolLit: diff --git a/source/slang/slang-ir-union.cpp b/source/slang/slang-ir-union.cpp deleted file mode 100644 index 1eb4955e7..000000000 --- a/source/slang/slang-ir-union.cpp +++ /dev/null @@ -1,773 +0,0 @@ -// slang-ir-union.cpp -#include "slang-ir-union.h" - -#include "slang-ir.h" -#include "slang-ir-insts.h" - -namespace Slang { - -// This file will implement a pass to replace any union types (currently -// just tagged unions) with plain `struct` types that attempt to provide -// equivalent semantics. This will necessarily be a bit fragile, and there -// will be fundamental limits to what the translation can support without -// improved features in the target shading languages/ILs. - -struct DesugarUnionTypesContext -{ - // We'll start with some basic state that we need to get the job done. - // - // This includes the IR module we are to process, as well as IR building - // state that we will initialize once and then use throughout the pass. - // - IRModule* module; - IRBuilder builderStorage; - IRBuilder* getBuilder() { return &builderStorage; } - - // Because we will be replacing instructions that refer to unions with - // different logic, we'll want to remove the original instructions. - // However, we need to be careful about modifying the IR tree while also - // iterating it, and to keep things simple for ourselves we'll go ahead - // and build up a list of instruction to remove along the way, and then - // remove them all at the end. - // - List<IRInst*> instsToRemove; - - // The overall flow of the pass is pretty simple, so we will walk through it now. - // - void processModule() - { - // We start by initializing our IR building state. - // - builderStorage = IRBuilder(module); - - // Next, we will search for any instruction that create or use - // union types, and process them accordingingly (usually by - // constructing a new instruction to replace them). - // - processInstRec(module->getModuleInst()); - - // Along the way we will build up a list of the tagged union - // types that we encountered, but we will refrain from replacing - // them until we are done (so that we always know that the instructions - // we process above refer to the original type, and not its - // replacement. - // - for( auto info : taggedUnionInfos ) - { - auto taggedUnionType = info->taggedUnionType; - auto replacementInst = info->replacementInst; - - // TODO: We should consider transferring decorations from the source - // type to the destination, but doing so carelessly could create - // problems, since an IR struct type shouldn't have, e.g., a - // `TaggedUnionTypeLayout` attached to it. - - taggedUnionType->replaceUsesWith(replacementInst); - taggedUnionType->removeAndDeallocate(); - } - - // As described previously, we build up the `instsToRemove` list as - // we iterate so that we can remove them all here and not risk - // modifying the IR tree while also walking it. - // - // TODO: This might be overkill and we could conceivably just be - // a bit careful in `processInstRec`. - // - for(auto inst : instsToRemove) - { - inst->removeAndDeallocate(); - } - } - - // In order to replace a (tagged) union type, we will need to know - // something about it, and we will use the `TaggedUnionInfo` type - // to collect all the relevant information. - // - struct TaggedUnionInfo : public RefObject - { - // We obviously need to know the tagged union itself, and - // we will also use this structure to track the instruction - // (an IR struct type) that will replace it. - // - IRTaggedUnionType* taggedUnionType; - IRInst* replacementInst; - - // In order to compute a suitable layout for the replacement - // `struct` type we need to know how the tagged union itself - // would be laid out in memory, so we require that all tagged - // unions in the generated IR have an associated (target-specific) - // layout. - // - IRTaggedUnionTypeLayout* taggedUnionTypeLayout; - - // The basic approach we will use 16-byte chunks (represented as an array - // of `uint4`s) to reprent the "bulk" of a type, and then use a single field - // that could be up to 12 bytes to represent the "rest" of the type. - // - // Note that there are deeply ingrained assumptions here that all types - // are at least four bytes in size (so that unions cannot easily - // accomodate `half` value), and that any types *larger* than four bytes - // will need to be loaded/stored via multiple 4-byte loads/stores. - // - // With the basic idea out of the way, we need an IR level field - // in our struct to hold the bulk data, which comprises a "key" for - // looking up the field, and the type of the field itself. We also - // keep track of how many bytes we put in our bulk storage. - // - // The bulk field might be: - // - // - null, if none of the case types was 16 bytes or more - // - a single `uint4` for between 16 and 31 (inclusive) bytes - // - an array of `uint4`s for 32 or more bytes - // - UInt64 bulkSize = 0; - IRInst* bulkFieldKey = nullptr; - IRType* bulkFieldType = nullptr; - - // The same basic idea then applies to the rest of the data. - // - // The "rest" field will be either be absent (if the size of the - // type was evently divisible by 16), a scalar `uint`, or else - // a 2- or 3-component vector of `uint`. - // - UInt64 restSize = 0; - IRInst* restFieldKey = nullptr; - IRType* restFieldType = nullptr; - - // Finally, since we are currently working with tagged unions, - // we need a field to hold the tag, which will always be allocated - // after the fields that hold the bulk/rest of the payload. - // - // This field is always a single `uint`. - // - // TODO: if/when we support untagged unions, they could be handled - // by having this field be null. - // - IRInst* tagFieldKey; - }; - - // We will build up a list of all the tagged union types we encounter, - // so that we can replace them with the synthesized types when we are done. - // - List<RefPtr<TaggedUnionInfo>> taggedUnionInfos; - - // It is possible that we will see the same tagged union type referenced - // many times in the IR, but we only want to synthesize the information - // above (including the various IR structures) once, so we also maintain - // a map from the original IR type to the corresponding information. - // - Dictionary<IRInst*, TaggedUnionInfo*> mapIRTypeToTaggedUnionInfo; - - // We will process all instructions in the module in a single recursive walk. - // - void processInstRec(IRInst* inst) - { - processInst(inst); - - for( auto child : inst->getChildren() ) - { - processInstRec(child); - } - } - // - // At each instruction, we will check if it is one of the union-related instructions - // we need to replace, and process it accordingly. - // - void processInst(IRInst* inst) - { - switch( inst->getOp() ) - { - default: - // Any instruction not listed below either doesn't involve union types, - // or handles them in a hands-off fashion that we don't need to care about. - // - // E.g., a `load` of a union type from a constant buffer will turn into - // a load of the replacement `struct` type once we are done, and nothing - // needs to be done to the `load` instruction. - // - break; - - case kIROp_TaggedUnionType: - { - // We clearly need to process the tagged union type itself, but the actual - // work is handled by other functions. All we need to do here is ensure - // that the information for this type gets generated, and then we can - // rely on the main `processModule` function to do the actual replacement later. - // - auto type = cast<IRTaggedUnionType>(inst); - getTaggedUnionInfo(type); - } - break; - - case kIROp_ExtractTaggedUnionTag: - { - // The case of extracting the tag from a tagged union is relatively - // simple, because the replacement type will have a dedicated field or it. - // - // We start by finding the tagged union value the instruction is operating - // on, and then looking up the information for its type (which had - // better be a tagged union type). - // - auto taggedUnionVal = inst->getOperand(0); - auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); - - // Because the replacement type will have an explicit field for the tag, - // we can simply emit a single field-extract instruction to read its value - // out. - // - auto builder = getBuilder(); - builder->setInsertBefore(inst); - auto replacement = builder->emitFieldExtract( - inst->getFullType(), - taggedUnionVal, - taggedUnionInfo->tagFieldKey); - - // Now we can replace anything that used the original instruction with - // the new field-extract operation, and add this instruction to the - // list for later removal. - // - inst->replaceUsesWith(replacement); - instsToRemove.add(inst); - } - break; - - case kIROp_ExtractTaggedUnionPayload: - { - // The most interesting case is when we are trying to extract a particular - // payload (one of the case types) from a union. We may need to extract - // one or more fields from the data stored in the union's replacement - // type (the bulk/rest fields), and we may also have to convert them - // to the type expected via bit-casts. - - // We can start things off easily enough by extracting the tagged union - // value being operated on, as well as the information for its type. - // - auto taggedUnionVal = inst->getOperand(0); - auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); - - // Next we need to figure out which case is being extracted from the union. - // The operand for the case tag should be a literal by construction. - // - auto caseTagVal = inst->getOperand(1); - auto caseTagConst = as<IRIntLit>(caseTagVal); - SLANG_ASSERT(caseTagConst); - - // The case type we are extracting will be the result type of the instruciton. - // - auto caseType = inst->getDataType(); - // - // The tag value itself will be the index of the case type in the union - // type (and its layout). - // - auto caseTagIndex = UInt(caseTagConst->getValue()); - - // We can use the case tag value to look up the layout for the particular - // case type we are extracting (this will allow us to resolve byte offsets - // for fields, etc.). - // - auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout; - SLANG_ASSERT(caseTagIndex < UInt(taggedUnionTypeLayout->getCaseCount())); - auto caseTypeLayout = taggedUnionTypeLayout->getCaseTypeLayout(caseTagIndex); - - // At this point we know the type we are trying to extract, as well - // as its layout. We will defer the actual implementation of extraction - // to a (recursive) subroutine that can extract a (sub-)field from the - // union at a given byte offset. Since we are extracting a full case - // right now, the byte offset will be zero. - // - auto payloadVal = extractPayload( - taggedUnionInfo, - taggedUnionVal, - caseType, - caseTypeLayout, - 0); - - // TODO: There is a significant flaw in the above approach when - // the case type might be (or contain) an array. If we have a setup - // like the following: - // - // union SomeUnion { float someCase[100]; ... } - // ... - // float result = someUnion.someCase[someIndex]; - // - // The current logic would desugar this into something like: - // - // struct SomeUnion { uint4 bulk[100]; ... } - // ... - // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... } - // float result = tmp[someIndex]; - // - // The result is that we copy an entire 100-element array into local memory - // just to fetch a single element, when it would be much nicer to just do: - // - // float result = asfloat(someUnion.bulk[someIndex].x); - // - // Achieving the latter code requires that rather than blindly translate - // the `extractTaggedUnionPayload` instruction into a semantically equiavlent - // value (which might lead to a big copy in the end), we should transitively - // chase down any "access chains" off of `inst` and see what leaf values are - // actually needed, and generated more tailored extraction logic for just - // the elements/fields that actually get referenced. - // - // The more refined approach can be built on top of many of the same primitives, - // so for now we will resign ourselves to the simpler but potentially less - // efficient approach. - - // Now that we've extracted the value for the payload from the fields of - // the replacement struct, we can use that extracted value to replace - // this instruction, and schedule the original instruction for removal. - // - inst->replaceUsesWith(payloadVal); - instsToRemove.add(inst); - } - break; - } - } - - // The `extractPayload` operation is the most important bit of translation we - // need to do to make unions work. We have as input the following: - // - IRInst* extractPayload( - - // - Information about a tagged union type and its layout. - TaggedUnionInfo* taggedUnionInfo, - - // - A single value of that tagged unon type. - IRInst* taggedUnionVal, - - // - Type type of some "payload" field we want to extract from the union. - IRType* payloadType, - - // - The memory layout of that payload type. - IRTypeLayout* payloadTypeLayout, - - // - The byte offset at which we want to fetch the payload. - UInt64 payloadOffset) - { - // We are going to be building some IR code no matter what. - // - auto builder = getBuilder(); - - // The basic approach here will be to look at the type we - // are trying to extract from the union, and whenever possible - // recursively walk its structure so that we can express things - // in terms of extraction of smaller/simpler types. - // - if( auto irStructType = as<IRStructType>(payloadType) ) - { - // A structure type is a nice recursive case: we simply - // want to extract each of its field recursively, and - // then construct a fresh value of the `struct` type. - - // In all of the cases of this function we expect/require - // there to be complete type layout information for the - // types involved. - // - auto structTypeLayout = as<IRStructTypeLayout>(payloadTypeLayout); - SLANG_ASSERT(structTypeLayout); - - // We are going to emit code to extract each of the fields - // and collect them to use as operands to a `makeStruct`. - // - List<IRInst*> fieldVals; - - // We need to walk over the fields in the order the IR expects them - UInt fieldCounter = 0; - for( auto irField : irStructType->getFields() ) - { - IRType* fieldType = irField->getFieldType(); - - // TODO: We need to confirm/enforce that the fields of the - // IR struct and the fields of the layout still align. - // - UInt fieldIndex = fieldCounter++; - auto fieldLayout = structTypeLayout->getFieldLayout(fieldIndex); - auto fieldTypeLayout = fieldLayout->getTypeLayout(); - - // The offset of the field can be computed from the base - // offset passed in, plus the reflection data for the field. - // - UInt64 fieldOffset = payloadOffset; - if(auto resInfo = fieldLayout->findOffsetAttr(LayoutResourceKind::Uniform)) - fieldOffset += resInfo->getOffset(); - - // We make a recursive call to extract each field, expecting - // that this will bottom out eventually. - // - IRInst* fieldVal = extractPayload( - taggedUnionInfo, - taggedUnionVal, - fieldType, - fieldTypeLayout, - fieldOffset); - fieldVals.add(fieldVal); - } - - // The final value is then just a new struct constructed from - // the extracted field values. - // - auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals); - return payloadVal; - } - else if( auto vecType = as<IRVectorType>(payloadType) ) - { - auto elementType = vecType->getElementType(); - - // We expect that by the time we are desugaring union types - // all vector types have literal constant values for their - // element count. - // - auto elementCountVal = vecType->getElementCount(); - auto elementCountConst = as<IRIntLit>(elementCountVal); - SLANG_ASSERT(elementCountConst); - UInt elementCount = UInt(elementCountConst->getValue()); - - // HACK: There is currently no `VectorTypeLayout` and thus - // no way to query the layout of the elements of a vector - // type. Until that gets added we will kludge things here. - // - IRTypeLayout* elementTypeLayout = nullptr; - size_t elementSize = 0; - if(auto resInfo = payloadTypeLayout->findSizeAttr(LayoutResourceKind::Uniform)) - elementSize = resInfo->getSize().getFiniteValue() / elementCount; - - // Similar to the `struct` case above, we will extract a - // value for each element of the vector, and then use - // `makeVector` to construct the result value. - // - List<IRInst*> elementVals; - for(UInt ii = 0; ii < elementCount; ++ii) - { - auto elementVal = extractPayload( - taggedUnionInfo, - taggedUnionVal, - elementType, - elementTypeLayout, - payloadOffset + ii*elementSize); - elementVals.add(elementVal); - } - return builder->emitMakeVector(vecType, elementVals); - } - else if( const auto matType = as<IRMatrixType>(payloadType) ) - { - SLANG_UNIMPLEMENTED_X("matrix in union type"); - } - else if( const auto arrayType = as<IRArrayType>(payloadType) ) - { - SLANG_UNIMPLEMENTED_X("array in union type"); - } - else - { - // If none of the above cases match, then we assume that - // we have an individual scalar field that we need to fetch. - // - UInt64 payloadSize = 0; - if( auto resInfo = payloadTypeLayout->findSizeAttr(LayoutResourceKind::Uniform) ) - { - // TODO: somebody before this point should generate an error if - // we have a `union` type that contains a potentially unbounded - // amount of data. - // - payloadSize = resInfo->getSize().getFiniteValue(); - } - - if( payloadSize != 4 ) - { - // TODO: We should handle the case of 64-bit fields by fetching - // two `uint` values to form a `uint2`, and then using an - // appropriate bit-cast to get from `uint2` to, e.g., `double`. - // - // The case of 16-bit and smaller fields is more troublesome, but - // in the worst case we can load a `uint` and then use bitwise - // ops to extract what we need before bitcasting. - // - // The right long-term solution is for downstream languages to have - // better support for raw memory addressing. - - SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes"); - } - - // We know that we want to fetch a value of size `payloadSize`, and - // we have a known base value and an initial offset into it. - // - IRInst* baseVal = taggedUnionVal; - UInt64 offset = payloadOffset; - - // We are going to refine our `baseVal` and `offset` as we go, by - // trying to narrow down the data we will access in the `struct` - // type that will provide storage for the union. - // - // The first thing we want to check is if the value sits in the - // "bulk" part of the storage, or the "rest." - // - UInt64 bulkSize = taggedUnionInfo->bulkSize; - if( offset < bulkSize ) - { - // If the value starts in the bulk area, then the whole - // thing had better fit in the bulk area. The 16-byte - // granularity rules for constant buffers should ensure - // this property for us on current targets. - // - SLANG_ASSERT(offset + payloadSize <= bulkSize); - - // Since we know we'll be accessing the bulk storage, - // we will extract it here. The extracted field will - // be our new base value, but the `offset` doesn't need - // to be updated since the bulk field sits at offset 0. - // - baseVal = builder->emitFieldExtract( - taggedUnionInfo->bulkFieldType, - baseVal, - taggedUnionInfo->bulkFieldKey); - - // The bulk storage could be an array, if there are 32 - // or more bytes of bulk storage. - // - if( auto baseArrayType = as<IRArrayType>(baseVal->getDataType()) ) - { - // If an array was allocated for bulk storage then - // our leaf value resides entirely within a single - // element (due to constant buffer layout rules), - // and so we will fetch the appropriate element here. - // - // We will change our `baseVal` to the extracted element, - // and then also adjust our `offset` to be relative - // to that element. - // - size_t bulkElementSize = 16; - auto index = offset / bulkElementSize; - baseVal = builder->emitElementExtract( - baseArrayType->getElementType(), - baseVal, - builder->getIntValue(builder->getIntType(), index)); - offset -= index*bulkElementSize; - } - } - else - { - // If the offset of the field we want is past the end of - // the bulk field then it must sit inside of the rest field, - // and we'll extract it here. This establishes a new - // base value, and we adjust the `offset` to be relative - // to the rest field (which starts at an offset equal to `bulkSize`). - // - baseVal = builder->emitFieldExtract( - taggedUnionInfo->restFieldType, - baseVal, - taggedUnionInfo->restFieldKey); - offset -= bulkSize; - } - - // We've now extracted a field that could be either a scalar or - // a vector, and we have an offset into it. In the case where - // the base value is a vector, we will extract out the appropriate - // element. - // - if( auto baseVecType = as<IRVectorType>(baseVal->getDataType()) ) - { - size_t vecElementSize = 4; - auto index = offset / vecElementSize; - baseVal = builder->emitElementExtract( - baseVecType->getElementType(), - baseVal, - builder->getIntValue(builder->getIntType(), index)); - offset -= index*vecElementSize; - } - - // At this point, our `baseVal` should be a single `uint`, and - // it should provide the storage for the exact thing we wanted - // to access (under the assumption that we always fetch 4 bytes - // on 4-byte alignment). - // - IRInst* payloadVal = baseVal; - SLANG_ASSERT(offset == 0); - - // TODO: we could imagine adding logic here to handle types less - // than 4 bytes in size by shifting and masking the value we - // just loaded. - - // The payload field we were trying to extract might have a type - // other than `uint`, and to handle that case we need to employ - // a bit-cast to get to the desired type. - // - if( payloadVal->getDataType() != payloadType ) - { - payloadVal = builder->emitBitCast( - payloadType, - payloadVal); - } - return payloadVal; - } - } - - // All of the logic so far as assumed we can just call `getTaggedUnionInfo` - // and have easy access to all the required information and the - // synthesized replacement type. - // - TaggedUnionInfo* getTaggedUnionInfo(IRType* type) - { - // The big picture is fairly simple: we will lazily build and - // memoize the information about tagged unions. - // - { - TaggedUnionInfo* info = nullptr; - if(mapIRTypeToTaggedUnionInfo.tryGetValue(type, info)) - return info; - } - - // When we don't find information in our memo-cache, we - // will construct it and add it to both the memo-cache - // *and* a global list of all tagged unions encountered, - // so that we can replacement them later. - // - auto info = createTaggedUnionInfo(type); - mapIRTypeToTaggedUnionInfo.add(type, info.Ptr()); - taggedUnionInfos.add(info); - - return info; - } - - // The actual logic for creating a `TaggedUnionInfo` is relatively - // straightforward once we've decided what information we need. - // - RefPtr<TaggedUnionInfo> createTaggedUnionInfo(IRType* type) - { - // We expect that any type used as an operation to one of the - // `extractTaggedUnion*` operations must be an IR tagged union. - // - // Note: If/when we ever expose `union`s to user and allow - // then to create *generic* tagged union types it might appear - // that this needs to be changed to account for a `specialize` - // instruction in place of a concrete tagged union, but in - // practice this pass needs to be performed late enough that - // any such generic should be fully specialized. - // - auto taggedUnionType = as<IRTaggedUnionType>(type); - SLANG_ASSERT(taggedUnionType); - - RefPtr<TaggedUnionInfo> info = new TaggedUnionInfo(); - info->taggedUnionType = taggedUnionType; - - // We are going to create an instruction to replace `type`, - // and thus will be placing it into the same parent. - // - auto builder = getBuilder(); - builder->setInsertBefore(type); - - // A tagged union type will be replaced with an ordinary - // `struct` type with fields to store all the relevant - // data from any of the cases, plus a tag field. - // - auto structType = builder->createStructType(); - info->replacementInst = structType; - - // We require/expect the earlier code generation steps to have - // associated a layout with every tagged union that appears in - // the code. - // - auto layoutDecoration = type->findDecoration<IRLayoutDecoration>(); - SLANG_ASSERT(layoutDecoration); - auto layout = layoutDecoration->getLayout(); - SLANG_ASSERT(layout); - auto taggedUnionTypeLayout = as<IRTaggedUnionTypeLayout>(layout); - SLANG_ASSERT(taggedUnionTypeLayout); - - info->taggedUnionTypeLayout = taggedUnionTypeLayout; - - // The size of the "payload" for the different cases (everything but - // the tag) is taken to be the offset of the tag itself. - // - // TODO: this might be inaccurate if the payload size isn't a multiple - // of the tag's alignment. We should deal with that when/if we support - // types smaller than 4 bytes in unions. - // - auto payloadSize = taggedUnionTypeLayout->getTagOffset().getFiniteValue(); - - // We are going to be construction IR code that makes use of the `int` - // and `uint` types in several cases, so we go ahead and get a pointer - // to those types here. - // - auto intType = getBuilder()->getIntType(); - auto uintType = getBuilder()->getBasicType(BaseType::UInt); - - // For now we will use a simple stragegy for how we encode a union, - // which depends only on the total number of bytes needed, and not - // on the makeup of the values being stored. - // - // We will start by allocating one or more `uint4` values (in an - // array for the "or more" case) to hold the bulk of any large - // payload value. - // - size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets - auto bulkVectorCount = payloadSize / bulkVectorSize; - auto bulkFieldSize = bulkVectorCount * bulkVectorSize; - if( bulkVectorCount ) - { - IRType* bulkFieldType = builder->getVectorType( - uintType, - builder->getIntValue(intType, 4)); - - if( bulkVectorCount > 1 ) - { - bulkFieldType = builder->getArrayType( - bulkFieldType, - builder->getIntValue(intType, bulkVectorCount)); - } - - auto bulkFieldKey = builder->createStructKey(); - builder->createStructField(structType, bulkFieldKey, bulkFieldType); - - info->bulkFieldKey = bulkFieldKey; - info->bulkFieldType = bulkFieldType; - } - info->bulkSize = bulkFieldSize; - - // The rest of the data (anything that doesn't fit in the bulk field), - // will get allocated into a single scalar or vector of `uint`. - // - auto restSize = payloadSize - bulkFieldSize; - if( restSize ) - { - size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets - auto restElementCount = restSize / restElementSize; - auto restFieldSize = restElementSize * restElementCount; - SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity - - IRType* restFieldType = uintType; - if( restElementCount > 1 ) - { - restFieldType = builder->getVectorType( - restFieldType, - builder->getIntValue(intType, restElementCount)); - } - - auto restFieldKey = builder->createStructKey(); - builder->createStructField(structType, restFieldKey, restFieldType); - - info->restFieldKey = restFieldKey; - info->restFieldType = restFieldType; - info->restSize = restFieldSize; - } - - // Finally, we add a field to represent the tag. - // - auto tagFieldType = uintType; - auto tagFieldKey = builder->createStructKey(); - builder->createStructField(structType, tagFieldKey, tagFieldType); - - info->tagFieldKey = tagFieldKey; - - return info; - } -}; - -void desugarUnionTypes( - IRModule* module) -{ - DesugarUnionTypesContext context; - context.module = module; - - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-union.h b/source/slang/slang-ir-union.h deleted file mode 100644 index 81757dced..000000000 --- a/source/slang/slang-ir-union.h +++ /dev/null @@ -1,18 +0,0 @@ -// slang-ir-union.h -#pragma once - -namespace Slang { - -struct IRModule; - - /// Desugar any unions types, and code using them, in `module` - /// - /// Union types will be replaced with ordinary `struct` types that store - /// the data of the underlying type as a "bag of bits" and references - /// to cases of the union will be replaced with logic to extract the - /// relevant bits. - /// -void desugarUnionTypes( - IRModule* module); - -} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index a27bf8658..38d1eb520 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1004,32 +1004,6 @@ namespace Slang } // - // IRTaggedUnionTypeLayout - // - - IRTaggedUnionTypeLayout::Builder::Builder(IRBuilder* irBuilder, LayoutSize tagOffset) - : Super::Builder(irBuilder) - { - m_tagOffset = irBuilder->getIntValue(irBuilder->getIntType(), tagOffset.raw); - } - - void IRTaggedUnionTypeLayout::Builder::addCaseTypeLayout(IRTypeLayout* typeLayout) - { - m_caseTypeLayoutAttrs.add(getIRBuilder()->getCaseTypeLayoutAttr(typeLayout)); - } - - void IRTaggedUnionTypeLayout::Builder::addOperandsImpl(List<IRInst*>& ioOperands) - { - ioOperands.add(m_tagOffset); - } - - void IRTaggedUnionTypeLayout::Builder::addAttrsImpl(List<IRInst*>& ioOperands) - { - for(auto attr : m_caseTypeLayoutAttrs) - ioOperands.add(attr); - } - - // // IRVarLayout // @@ -2981,17 +2955,6 @@ namespace Slang operands); } - IRType* IRBuilder::getTaggedUnionType( - UInt caseCount, - IRType* const* caseTypes) - { - return (IRType*)createIntrinsicInst( - getTypeKind(), - kIROp_TaggedUnionType, - caseCount, - (IRInst* const*) caseTypes); - } - IRType* IRBuilder::getBindExistentialsType( IRInst* baseType, UInt slotArgCount, @@ -3335,7 +3298,6 @@ namespace Slang IRInst* const* args) { auto innerReturnVal = findInnerMostGenericReturnVal(as<IRGeneric>(genericVal)); - if (as<IRWitnessTable>(innerReturnVal)) { return createIntrinsicInst( @@ -3371,7 +3333,7 @@ namespace Slang // the emit logic, but this is a reasonably early place // to catch it. // - SLANG_ASSERT(witnessTableVal->getOp() != kIROp_StructKey); + SLANG_ASSERT(witnessTableVal && witnessTableVal->getOp() != kIROp_StructKey); IRInst* args[] = {witnessTableVal, interfaceMethodVal}; @@ -5536,6 +5498,8 @@ namespace Slang return emitIntrinsicInst( getNativePtrType((IRType*)valueType->getOperand(0)), kIROp_GetNativePtr, 1, &value); break; + case kIROp_ExtractExistentialType: + return emitGetNativePtr(value->getOperand(0)); default: SLANG_UNEXPECTED("invalid operand type for `getNativePtr`."); UNREACHABLE_RETURN(nullptr); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index b0d9bb109..97f98fce2 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1749,11 +1749,6 @@ struct IRInterfaceType : IRType IR_LEAF_ISA(InterfaceType) }; -struct IRTaggedUnionType : IRType -{ - IR_LEAF_ISA(TaggedUnionType) -}; - struct IRConjunctionType : IRType { IR_LEAF_ISA(ConjunctionType) diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index e9c5f8fe5..d29cc4485 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -142,11 +142,6 @@ public: bool visitSharedTypeExpr(SharedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); } - bool visitTaggedUnionTypeExpr(TaggedUnionTypeExpr*) - { - return false; - } - bool visitInvokeExpr(InvokeExpr* expr) { PushNode pushNodeRAII(context, expr); diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index 6bccca8d3..4ec0ac64f 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -596,7 +596,7 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::createSwizzleCan { const char* memberNames[4] = {"x", "y", "z", "w"}; Type* elementType = nullptr; - elementType = vectorType->elementType; + elementType = vectorType->getElementType(); String typeStr; if (elementType) typeStr = elementType->toString(); diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index e79716975..bc12ad34f 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -213,7 +213,7 @@ static bool isBoolType(Type* t) auto basicType = as<BasicExpressionType>(t); if (!basicType) return false; - return basicType->baseType == BaseType::Bool; + return basicType->getBaseType() == BaseType::Bool; } String getDeclKindString(DeclRef<Decl> declRef) @@ -303,11 +303,11 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) sb << " = "; if (isBoolType(varDecl->getType())) { - sb << (constantInt->value ? "true" : "false"); + sb << (constantInt->getValue() ? "true" : "false"); } else { - sb << constantInt->value; + sb << constantInt->getValue(); } } else @@ -492,6 +492,8 @@ SlangResult LanguageServer::hover( doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) { @@ -741,6 +743,8 @@ SlangResult LanguageServer::gotoDefinition( doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) { @@ -1029,6 +1033,8 @@ SlangResult LanguageServer::semanticTokens( } auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) { @@ -1073,6 +1079,7 @@ String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation return String(); auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); SignatureInformation sigInfo; @@ -1096,7 +1103,7 @@ String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation bool isFirst = true; printer.getStringBuilder() << "("; int paramIndex = 0; - for (auto param : funcType->paramTypes) + for (auto param : funcType->getParamTypes()) { if (!isFirst) printer.getStringBuilder() << ", "; @@ -1134,6 +1141,8 @@ String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation String LanguageServer::getDeclRefSignature(DeclRef<Decl> declRef, String* outDocumentation, List<Slang::Range<Index>>* outParamRanges) { auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + ASTPrinter printer( version->linkage->getASTBuilder(), ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | @@ -1169,6 +1178,8 @@ SlangResult LanguageServer::signatureHelp( doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) { @@ -1289,7 +1300,7 @@ SlangResult LanguageServer::signatureHelp( printer.getStringBuilder() << "func ("; bool isFirst = true; - for (auto param : funcType->paramTypes) + for (auto param : funcType->getParamTypes()) { if (!isFirst) printer.getStringBuilder() << ", "; @@ -1315,12 +1326,12 @@ SlangResult LanguageServer::signatureHelp( if (auto declRefExpr = as<DeclRefExpr>(funcExpr)) { - if (auto aggDecl = as<AggTypeDecl>(declRefExpr->declRef.getDecl())) + if (auto aggDeclRef = as<AggTypeDecl>(declRefExpr->declRef)) { // Look for initializers - for (auto member : aggDecl->getMembersOfType<ConstructorDecl>()) + for (auto member : getMembersOfType<ConstructorDecl>(version->linkage->getASTBuilder(), aggDeclRef)) { - addDeclRef(version->linkage->getASTBuilder()->getSpecializedDeclRef<Decl>(member, declRefExpr->declRef.getSubst())); + addDeclRef(member); } } else @@ -1379,6 +1390,8 @@ SlangResult LanguageServer::documentSymbol( return SLANG_OK; } auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) { @@ -1400,6 +1413,8 @@ SlangResult LanguageServer::inlayHint(const LanguageServerProtocol::InlayHintPar return SLANG_OK; } auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) { @@ -1518,6 +1533,8 @@ void LanguageServer::publishDiagnostics() m_lastDiagnosticUpdateTime = std::chrono::system_clock::now(); auto version = m_workspace->getCurrentVersion(); + SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder()); + // Send updates to clear diagnostics for files that no longer have any messages. List<String> filesToRemove; for (auto& file : m_lastPublishedDiagnostics) diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 2eca91673..89d3380e4 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -16,7 +16,7 @@ void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state); // -DeclRef<ExtensionDecl> ApplyExtensionToType( +DeclRef<ExtensionDecl> applyExtensionToType( SemanticsVisitor* semantics, ExtensionDecl* extDecl, Type* type); @@ -161,14 +161,12 @@ static bool _isUncheckedLocalVar(const Decl* decl) static void _lookUpDirectAndTransparentMembers( ASTBuilder* astBuilder, Name* name, - DeclRef<ContainerDecl> containerDeclRef, + ContainerDecl* containerDecl, // The container decl to find member with `name`. + DeclRef<Decl> parentDeclRef, // The parent of the resulting declref. LookupRequest const& request, LookupResult& result, BreadcrumbInfo* inBreadcrumbs) { - ContainerDecl* containerDecl = containerDeclRef.getDecl(); - - if (request.isCompletionRequest()) { // If we are looking up for completion suggestions, @@ -182,7 +180,7 @@ static void _lookUpDirectAndTransparentMembers( AddToLookupResult( result, CreateLookupResultItem( - astBuilder->getSpecializedDeclRef<Decl>(member, containerDeclRef.getSubst()), inBreadcrumbs)); + astBuilder->getMemberDeclRef<Decl>(parentDeclRef, member), inBreadcrumbs)); } } else @@ -207,7 +205,7 @@ static void _lookUpDirectAndTransparentMembers( continue; // The declaration passed the test, so add it! - AddToLookupResult(result, CreateLookupResultItem(astBuilder->getSpecializedDeclRef<Decl>(m, containerDeclRef.getSubst()), inBreadcrumbs)); + AddToLookupResult(result, CreateLookupResultItem(astBuilder->getMemberDeclRef<Decl>(parentDeclRef, m), inBreadcrumbs)); } } @@ -215,9 +213,9 @@ static void _lookUpDirectAndTransparentMembers( // if we already has a hit in the current container? for(auto transparentInfo : containerDecl->getTransparentMembers()) { - // The reference to the transparent member should use whatever - // substitutions we used in referring to its outer container - DeclRef<Decl> transparentMemberDeclRef = astBuilder->getSpecializedDeclRef(transparentInfo.decl, containerDeclRef.getSubst()); + // The reference to the transparent member should use the same + // path as we used in referring to its parent. + DeclRef<Decl> transparentMemberDeclRef = astBuilder->getMemberDeclRef(parentDeclRef, transparentInfo.decl); // We need to leave a breadcrumb so that we know that the result // of lookup involves a member lookup step here @@ -262,7 +260,8 @@ LookupResult lookUpDirectAndTransparentMembers( ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, - DeclRef<ContainerDecl> containerDeclRef, + ContainerDecl* containerDecl, + DeclRef<Decl> parentDeclRef, LookupMask mask) { LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, nullptr); @@ -270,36 +269,14 @@ LookupResult lookUpDirectAndTransparentMembers( _lookUpDirectAndTransparentMembers( astBuilder, name, - containerDeclRef, + containerDecl, + parentDeclRef, request, result, nullptr); return result; } -static SubtypeWitness* _makeSubtypeWitness( - ASTBuilder* astBuilder, - Type* subType, - SubtypeWitness* subToMidWitness, - Type* superType, - SubtypeWitness* midtoSuperWitness) -{ - SLANG_UNUSED(subType); - SLANG_UNUSED(superType); - - if(subToMidWitness) - { - auto transitiveWitness = astBuilder->getTransitiveSubtypeWitness( - subToMidWitness, - midtoSuperWitness); - return transitiveWitness; - } - else - { - return midtoSuperWitness; - } -} - // Specialize `declRefToSpecialize` with ThisType info if `superType` is an interface type. DeclRef<Decl> _maybeSpecializeSuperTypeDeclRef( ASTBuilder* astBuilder, @@ -309,14 +286,10 @@ DeclRef<Decl> _maybeSpecializeSuperTypeDeclRef( { if (auto superDeclRefType = as<DeclRefType>(superType)) { - if (auto superInterfaceDeclRef = superDeclRefType->declRef.as<InterfaceDecl>()) + if (auto superInterfaceDeclRef = superDeclRefType->getDeclRef().as<InterfaceDecl>()) { - ThisTypeSubstitution* thisTypeSubst = astBuilder->getOrCreateThisTypeSubstitution( - superInterfaceDeclRef.getDecl(), - subIsSuperWitness, - declRefToSpecialize.getSubst()); - - auto specializedDeclRef = astBuilder->getSpecializedDeclRef<Decl>(declRefToSpecialize.getDecl(), thisTypeSubst); + ThisTypeDecl* thisTypeDecl = superInterfaceDeclRef.getDecl()->getThisTypeDecl(); + auto specializedDeclRef = astBuilder->getLookupDeclRef(subIsSuperWitness, thisTypeDecl); return specializedDeclRef; } @@ -332,7 +305,7 @@ static Type* _maybeSpecializeSuperType( { if (auto superDeclRefType = as<DeclRefType>(superType)) { - auto specializedDeclRef = _maybeSpecializeSuperTypeDeclRef(astBuilder, superDeclRefType->declRef, superType, subIsSuperWitness); + auto specializedDeclRef = _maybeSpecializeSuperTypeDeclRef(astBuilder, superDeclRefType->getDeclRef(), superType, subIsSuperWitness); return DeclRefType::create(astBuilder, specializedDeclRef); } @@ -391,14 +364,21 @@ static void _lookUpMembersInSuperType( } static void _lookUpMembersInSuperTypeDeclImpl( - ASTBuilder* astBuilder, - Name* name, + ASTBuilder* astBuilder, + Name* name, DeclRef<Decl> declRef, - LookupRequest const& request, - LookupResult& ioResult, - BreadcrumbInfo* inBreadcrumbs) + LookupRequest const& request, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs) { auto semantics = request.semantics; + if (!as<InterfaceDecl>(declRef.getDecl()) && name->text == "This") + { + // If we are looking for `This` in anything other than an InterfaceDecl, + // we just need to return the declRef itself. + AddToLookupResult(ioResult, CreateLookupResultItem(declRef, inBreadcrumbs)); + return; + } // If the semantics context hasn't been established yet (e.g. when looking up during parsing), // we simply do a direct lookup without considering subtypes or extensions. @@ -408,7 +388,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( // In this case we can only lookup in an aggregate type. if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>()) { - _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs); + _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef.getDecl(), aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs); } return; } @@ -464,7 +444,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( // relying on the modifier. if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(facet->subtypeWitness)) { - auto inheritanceDeclRef = declaredSubtypeWitness->declRef; + auto inheritanceDeclRef = declaredSubtypeWitness->getDeclRef(); if (inheritanceDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) continue; } @@ -473,6 +453,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( BreadcrumbInfo* newBreadcrumbs = inBreadcrumbs; BreadcrumbInfo subtypeInfo; + auto parentDeclRef = containerDeclRef; if (facet->directness != Facet::Directness::Self) { // Depending on the type of the facet, we may want to specialize the @@ -487,9 +468,15 @@ static void _lookUpMembersInSuperTypeDeclImpl( // we should also specialize the interface declRef with the concrete // type info. // - containerDeclRef = _maybeSpecializeSuperTypeDeclRef( + parentDeclRef = _maybeSpecializeSuperTypeDeclRef( astBuilder, containerDeclRef, facet->getType(), facet->subtypeWitness) .as<ContainerDecl>(); + if (as<ThisTypeDecl>(parentDeclRef.getDecl()) && name->text == "This") + { + // If we are going looking for `This` in a `ThisType`, we just need to return the declRef itself. + AddToLookupResult(ioResult, CreateLookupResultItem(parentDeclRef, inBreadcrumbs)); + continue; + } // If we are looking up in a base type, we also need to make sure // to create a breadcrumb to track the sub to super indirection. @@ -502,7 +489,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( newBreadcrumbs = &subtypeInfo; } } - _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef, request, ioResult, newBreadcrumbs); + _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef.getDecl(), parentDeclRef, request, ioResult, newBreadcrumbs); } } @@ -540,7 +527,7 @@ static void _lookUpMembersInSuperTypeImpl( if(auto declRefType = as<DeclRefType>(superType)) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, declRef, request, ioResult, inBreadcrumbs); } @@ -551,36 +538,16 @@ static void _lookUpMembersInSuperTypeImpl( // lookup will have a comparable substitution applied (allowing things like associated // types, etc. used in the signature of a method to resolve correctly). // - auto interfaceDeclRef = extractExistentialType->getSpecializedInterfaceDeclRef(); - _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, interfaceDeclRef, request, ioResult, inBreadcrumbs); - } - else if( auto thisType = as<ThisType>(superType) ) - { - // We need to create a witness that represents the next link in the - // chain. The `leafIsSuperWitness` represents the knowledge that `leafType : superType` - // (and we know that `superType == thisType`, but we now need to extend that - // with the knowledge that `thisType : thisType->interfaceTypeDeclRef`. - // - auto interfaceType = DeclRefType::create(astBuilder, thisType->interfaceDeclRef); - - auto superIsInterfaceWitness = astBuilder->getThisTypeSubtypeWitness(superType, interfaceType); - - auto leafIsInterfaceWitness = _makeSubtypeWitness( - astBuilder, - leafType, - leafIsSuperWitness, - interfaceType, - superIsInterfaceWitness); - - _lookUpMembersInSuperType(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, request, ioResult, inBreadcrumbs); + auto thisTypeDeclRef = extractExistentialType->getThisTypeDeclRef(); + _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, thisTypeDeclRef, request, ioResult, inBreadcrumbs); } else if( auto andType = as<AndType>(superType) ) { // We have a type of the form `leftType & rightType` and we need to perform // lookup in both `leftType` and `rightType`. // - auto leftType = andType->left; - auto rightType = andType->right; + auto leftType = andType->getLeft(); + auto rightType = andType->getRight(); // Operationally, we are in a situation where we have a witness // that the `leafType` we are doing lookup on is an subtype @@ -731,7 +698,7 @@ static void _lookUpInScopes( // just a decl. // DeclRef<ContainerDecl> containerDeclRef = - astBuilder->getSpecializedDeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, request.semantics, containerDecl)).as<ContainerDecl>(); + createDefaultSubstitutionsIfNeeded(astBuilder, request.semantics, makeDeclRef(containerDecl)).as<ContainerDecl>(); // If the container we are looking into represents a type // or an `extension` of a type, then we need to treat @@ -755,7 +722,7 @@ static void _lookUpInScopes( breadcrumb.thisParameterMode = thisParameterMode; breadcrumb.declRef = aggTypeDeclBaseRef; breadcrumb.prev = nullptr; - + BreadcrumbInfo* breadcrumbPtr = &breadcrumb; Type* type = nullptr; if (auto extDeclRef = aggTypeDeclBaseRef.as<ExtensionDecl>()) { @@ -773,10 +740,25 @@ static void _lookUpInScopes( else { assert(aggTypeDeclBaseRef.as<AggTypeDecl>()); - type = DeclRefType::create(astBuilder, aggTypeDeclBaseRef); + if (auto interfaceBase = as<InterfaceDecl>(aggTypeDeclBaseRef.getDecl())) + { + // When looking up inside an interface type, we are actually looking up through ThisType. + if (name != interfaceBase->getThisTypeDecl()->getName()) + { + type = DeclRefType::create(astBuilder, astBuilder->getMemberDeclRef(aggTypeDeclBaseRef, interfaceBase->getThisTypeDecl())); + // Don't need any breadcrumb for looking up through ThisType, since we have already + // created the base type reference in the new `type`'s declref. + breadcrumbPtr = nullptr; + } + } + + if (!type) + { + type = DeclRefType::create(astBuilder, aggTypeDeclBaseRef); + } } - _lookUpMembersInType(astBuilder, name, type, request, result, &breadcrumb); + _lookUpMembersInType(astBuilder, name, type, request, result, breadcrumbPtr); } else { @@ -784,7 +766,7 @@ static void _lookUpInScopes( // type or `extension` declaration, so we can look up members // in that scope much more simply. // - _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef, request, result, nullptr); + _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef.getDecl(), containerDeclRef, request, result, nullptr); } // Before we proceed up to the next outer scope to perform lookup diff --git a/source/slang/slang-lookup.h b/source/slang/slang-lookup.h index 69374c024..8af760f70 100644 --- a/source/slang/slang-lookup.h +++ b/source/slang/slang-lookup.h @@ -35,7 +35,8 @@ LookupResult lookUpDirectAndTransparentMembers( ASTBuilder* astBuilder, SemanticsVisitor* semantics, Name* name, - DeclRef<ContainerDecl> containerDeclRef, + ContainerDecl* containerDecl, + DeclRef<Decl> parentDeclRef, // The parent of the resulting declref. LookupMask mask = LookupMask::Default); // TODO: this belongs somewhere else diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 0773226d1..e0e97d6e7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -804,7 +804,7 @@ LoweredValInfo emitCallToDeclRef( if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() ) { - if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.getDecl()) && !as<InterfaceDecl>(ctorDeclRef.getParent(context->astBuilder).getDecl())) + if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.getDecl()) && !as<InterfaceDecl>(ctorDeclRef.getParent().getDecl())) { SLANG_UNREACHABLE("stdlib error: __init() has no definition."); } @@ -1399,7 +1399,7 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy { if (auto declRefType = as<DeclRefType>(typeConstraint->sub.type)) { - if (declRefType->declRef.getDecl() == genericParamDecl) + if (declRefType->getDeclRef().getDecl() == genericParamDecl) { supTypes.add(lowerType(context, typeConstraint->getSup().type)); } @@ -1408,6 +1408,36 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy } } + +// Check if declRef represents a witness that `ISomeInterface.This : ISomeInterface`. +static bool _isThisTypeSubtypeWitness(DeclRefBase* declRef) +{ + auto lookupDeclRef = as<LookupDeclRef>(declRef); + if (!lookupDeclRef) + return false; + if (!as<ThisType>(lookupDeclRef->getLookupSource())) + return false; + auto declaredWitness = as<DeclaredSubtypeWitness>(lookupDeclRef->getWitness()); + if (!declaredWitness) + return false; + if (!as<ThisTypeConstraintDecl>(declaredWitness->getDeclRef())) + return false; + return true; +} + +// Returns whether `declRef` represents a trivial lookup of an interface requirement +// through `ThisTypeDecl` made from within the same interface Decl. +static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase* declRef) +{ + if (!_isThisTypeSubtypeWitness(declRef)) + return false; + // This is a lookup from an interface's This type. + // If the lookup is made from an interface type itself rather than an extension of it, + // then it is a trivial lookup and we should lower it as a struct key. + return context->thisTypeWitness == nullptr; +} + + // struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo> @@ -1424,24 +1454,24 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val) { - return emitDeclRef(context, val->declRef, - lowerType(context, getType(context->astBuilder, val->declRef))); + return emitDeclRef(context, val->getDeclRef(), + lowerType(context, getType(context->astBuilder, val->getDeclRef()))); } LoweredValInfo visitFuncCallIntVal(FuncCallIntVal* val) { TryClauseEnvironment tryEnv; List<IRInst*> args; - for (auto arg : val->args) + for (auto arg : val->getArgs()) { auto loweredArg = lowerVal(context, arg); args.add(loweredArg.val); } - auto funcType = lowerType(context, val->funcType); + auto funcType = lowerType(context, val->getFuncType()); return emitCallToDeclRef( context, as<IRFuncType>(funcType)->getResultType(), - val->funcDeclRef, + val->getFuncDeclRef(), funcType, args, tryEnv); @@ -1449,17 +1479,17 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val) { - auto baseVal = lowerVal(context, val->base); + auto baseVal = lowerVal(context, val->getBase()); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - auto type = lowerType(context, val->type); + auto type = lowerType(context, val->getType()); return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val)); } LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val) { - auto witnessVal = lowerVal(context, val->witness); - auto key = getInterfaceRequirementKey(context, val->key); - auto type = lowerType(context, val->type); + auto witnessVal = lowerVal(context, val->getWitness()); + auto key = getInterfaceRequirementKey(context, val->getKey()); + auto type = lowerType(context, val->getType()); return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( type, witnessVal.val, key)); } @@ -1467,16 +1497,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val) { auto irBuilder = getBuilder(); - auto type = lowerType(context, val->type); - auto constTerm = irBuilder->getIntValue(type, val->constantTerm); + auto type = lowerType(context, val->getType()); + auto constTerm = irBuilder->getIntValue(type, val->getConstantTerm()); auto resultVal = constTerm; - for (auto term : val->terms) + for (auto term : val->getTerms()) { - auto termVal = irBuilder->getIntValue(type, term->constFactor); - for (auto factor : term->paramFactors) + auto termVal = irBuilder->getIntValue(type, term->getConstFactor()); + for (auto factor : term->getParamFactors()) { - auto factorVal = lowerVal(context, factor->param).val; - for (IntegerLiteralValue i = 0; i < factor->power; i++) + auto factorVal = lowerVal(context, factor->getParam()).val; + for (IntegerLiteralValue i = 0; i < factor->getPower(); i++) { termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal); } @@ -1488,9 +1518,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) { - return emitDeclRef(context, val->declRef, + if (as<ThisTypeConstraintDecl>(val->getDeclRef())) + return LoweredValInfo::simple(context->thisTypeWitness); + + return emitDeclRef(context, val->getDeclRef(), context->irBuilder->getWitnessTableType( - lowerType(context, val->sup))); + lowerType(context, val->getSup()))); } LoweredValInfo visitTransitiveSubtypeWitness( @@ -1498,7 +1531,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower { // The base (subToMid) will turn into a value with // witness-table type. - IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid); + IRInst* baseWitnessTable = lowerSimpleVal(context, val->getSubToMid()); IRInst* midToSup = nullptr; // The next step should map to an interface requirement @@ -1530,17 +1563,17 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(midToSup); } - if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->midToSup)) + if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->getMidToSup())) { - midToSup = getInterfaceRequirementKey(context, declaredMidToSup->declRef.getDecl()); + midToSup = getInterfaceRequirementKey(context, declaredMidToSup->getDeclRef().getDecl()); } else { - midToSup = lowerSimpleVal(context, val->midToSup); + midToSup = lowerSimpleVal(context, val->getMidToSup()); } return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( - getBuilder()->getWitnessTableType(lowerType(context, val->sup)), + getBuilder()->getWitnessTableType(lowerType(context, val->getSup())), baseWitnessTable, midToSup)); } @@ -1550,7 +1583,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // TODO: properly fill in type info here. // We should consider fold all cases of witness table entries to `Val`, and make the `DeclRef` case a `DeclRefVal`. // So that we can hold the type in `DeclRefVal`. - auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind()); SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); auto diff = getBuilder()->emitForwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val); @@ -1559,7 +1592,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitBackwardDifferentiateVal(BackwardDifferentiateVal* val) { - auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind()); SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); auto diff = getBuilder()->emitBackwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val); @@ -1568,7 +1601,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitBackwardDifferentiatePropagateVal(BackwardDifferentiatePropagateVal* val) { - auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind()); SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); auto diff = getBuilder()->emitBackwardDifferentiatePropagateInst(getBuilder()->getTypeKind(), funcVal.val); @@ -1577,7 +1610,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitBackwardDifferentiatePrimalVal(BackwardDifferentiatePrimalVal* val) { - auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind()); SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); auto diff = getBuilder()->emitBackwardDifferentiatePrimalInst(getBuilder()->getTypeKind(), funcVal.val); @@ -1586,280 +1619,18 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitBackwardDifferentiateIntermediateTypeVal(BackwardDifferentiateIntermediateTypeVal* val) { - auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind()); SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); auto diff = getBuilder()->getBackwardDiffIntermediateContextType(funcVal.val); return LoweredValInfo::simple(diff); } - LoweredValInfo visitTaggedUnionSubtypeWitness( - TaggedUnionSubtypeWitness* val) - { - // The sub-type in this case is a tagged union `A | B | ...`, - // and the witness holds an array of witnesses showing that each - // "case" (`A`, `B`, etc.) is a subtype of the super-type. - - // We will start by getting the IR-level representation of the - // sub type (the tagged union type). - // - auto irTaggedUnionType = lowerType(context, val->sub); - - // We can turn each of those per-case witnesses into a witness - // table value: - // - auto caseCount = val->caseWitnesses.getCount(); - List<IRInst*> caseWitnessTables; - for( auto caseWitness : val->caseWitnesses ) - { - auto caseWitnessTable = lowerSimpleVal(context, caseWitness); - caseWitnessTables.add(caseWitnessTable); - } - - // Now we need to synthesize a witness table for the tagged union - // value, showing how it can implement all of the requirements - // of the super type by delegating to the appropriate implementation - // on a per-case basis. - // - // We will assume here that the super-type is an interface, and it - // will be left to the front-end to ensure this property. - // - auto supDeclRefType = as<DeclRefType>(val->sup); - if(!supDeclRefType) - { - SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - auto supInterfaceDeclRef = supDeclRefType->declRef.as<InterfaceDecl>(); - if( !supInterfaceDeclRef ) - { - SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - auto subType = lowerType(context, val->sub); - auto irWitnessTableBaseType = lowerType(context, supDeclRefType); - auto irWitnessTable = getBuilder()->createWitnessTable(irWitnessTableBaseType, subType); - - // Now we will iterate over the requirements (members) of the - // interface and try to synthesize an appropriate value for each. - // - for( auto reqDeclRef : getMembers(context->astBuilder, supInterfaceDeclRef) ) - { - // TODO: if there are any members we shouldn't process as a requirement, - // then we should detect and skip them here. - // - - // Every interface requirement will have a unique key that is used - // when looking up the requirement in a concrete witness table. - // - auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl()); - - if (!irReqKey) - continue; - - // We expect that each of the witness tables in `caseWitnessTables` - // will have an entry to match these keys. However, we may not - // have a concrete `IRWitnessTable` for each of the case types, either - // because they are a specialization of a generic (so that the witness - // table reference is a `specialize` instruction at this point), or - // they are a type external to this module (so that we have a declaration - // rather than a definition of the witness table). - - // Our task is to create an IR value that can satisfy the interface - // requirement for the tagged union type, by appropriately delegating - // to the implementations of the same requirement in the case types. - // - IRInst* irSatisfyingVal = nullptr; - - - - if(auto callableDeclRef = reqDeclRef.as<CallableDecl>()) - { - // We have something callable, so we need to synthesize - // a function to satisfy it. - // - auto irFunc = getBuilder()->createFunc(); - irSatisfyingVal = irFunc; - - IRBuilder subBuilderStorage = *getBuilder(); - auto subBuilder = &subBuilderStorage; - subBuilder->setInsertInto(irFunc); - - // We will start by setting up the function parameters, - // which live in the entry block of the IR function. - // - auto entryBlock = subBuilder->emitBlock(); - subBuilder->setInsertInto(entryBlock); - - // Create a `this` parameter of the tagged-union type. - // - // TODO: need to handle the `[mutating]` case here... - // - auto irThisType = irTaggedUnionType; - auto irThisParam = subBuilder->emitParam(irThisType); - - List<IRType*> irParamTypes; - irParamTypes.add(irThisType); - - // Create the remaining parameters of the callable, - // using a decl-ref specialized to the tagged union - // type (so that things like associated types are - // mapped to the correct witness value). - // - List<IRParam*> irParams; - for( auto paramDeclRef : getMembersOfType<ParamDecl>(context->astBuilder, callableDeclRef) ) - { - // TODO: need to handle `out` and `in out` here. Over all - // there is a lot of duplication here with the existing logic - // for emitting the signature of a `CallableDecl`, and we should - // try to re-use that if at all possible. - // - auto irParamType = lowerType(context, getType(context->astBuilder, paramDeclRef)); - auto irParam = subBuilder->emitParam(irParamType); - - irParams.add(irParam); - irParamTypes.add(irParamType); - } - - auto irResultType = lowerType(context, getResultType(context->astBuilder, callableDeclRef)); - - auto irFuncType = subBuilder->getFuncType( - irParamTypes, - irResultType); - irFunc->setFullType(irFuncType); - - // The first thing our function needs to do is extract the tag - // from the incoming `this` parameter. - // - auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam); - - // Next we want to emit a `switch` on the tag value, but before we - // do that we need to generate the code for each of the cases so that - // our `switch` has somewhere to branch to. - // - List<IRInst*> switchCaseOperands; - - IRBlock* defaultLabel = nullptr; - - for( Index ii = 0; ii < caseCount; ++ii ) - { - auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii); - - subBuilder->setInsertInto(irFunc); - auto caseLabel = subBuilder->emitBlock(); - - if(!defaultLabel) - defaultLabel = caseLabel; - - switchCaseOperands.add(caseTag); - switchCaseOperands.add(caseLabel); - - subBuilder->setInsertInto(caseLabel); - - // We need to look up the satisfying value for this interface - // requirement on the witness table of the particular case value. - // - // We already have the witness table, and the requirement key is - // just `irReqKey`. - // - auto caseWitnessTable = caseWitnessTables[ii]; - - // The subtle bit here is determining the type we expect the - // satisfying value to have, since that depends on the actual - // type that is satisfying the requirement. - // - IRType* caseResultType = irResultType; - IRType* caseFuncType = nullptr; - auto caseFunc = subBuilder->emitLookupInterfaceMethodInst( - caseFuncType, - caseWitnessTable, - irReqKey); - - // We are going to emit a `call` to the satisfying value - // for the case type, so we will collect the arguments for that call. - // - List<IRInst*> caseArgs; - - // The `this` argument to the call will need to represent the - // appropriate field of our tagged union. - // - IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii); - auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload( - caseThisType, - irThisParam, caseTag); - caseArgs.add(caseThisArg); - - // The remaining arguments to the call will just be forwarded from - // the parameters of the wrapper function. - // - // TODO: This would need to change if/when we started allowing `This` type - // or associated-type parameters to be used at call sites where a tagged - // union is used. - // - for( auto param : irParams ) - { - caseArgs.add(param); - } - - auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs); - - if( as<IRVoidType>(irResultType->getDataType()) ) - { - subBuilder->emitReturn(); - } - else - { - subBuilder->emitReturn(caseCall); - } - } - - // We will create a block to represent the supposedly-unreachable - // code that will run if no `case` matches. - // - subBuilder->setInsertInto(irFunc); - auto invalidLabel = subBuilder->emitBlock(); - subBuilder->setInsertInto(invalidLabel); - subBuilder->emitUnreachable(); - - if(!defaultLabel) defaultLabel = invalidLabel; - - // Now we have enough information to go back and emit the `switch` instruction - // into the entry block. - subBuilder->setInsertInto(entryBlock); - subBuilder->emitSwitch( - irTagVal, // value to `switch` on - invalidLabel, // `break` label (block after the `switch` statement ends) - defaultLabel, // `default` label (where to go if no `case` matches) - switchCaseOperands.getCount(), - switchCaseOperands.getBuffer()); - } - else - { - // TODO: We need to handle other cases of interface requirements. - SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - // Once we've generating a value to satisfying the requirement, we install - // it into the witness table for our tagged-union type. - // - getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal); - } - return LoweredValInfo::simple(irWitnessTable); - } - LoweredValInfo visitDynamicSubtypeWitness(DynamicSubtypeWitness * /*val*/) { return LoweredValInfo::simple(nullptr); } - LoweredValInfo visitThisTypeSubtypeWitness(ThisTypeSubtypeWitness* val) - { - SLANG_UNUSED(val); - return LoweredValInfo::simple(context->thisTypeWitness); - } - LoweredValInfo visitConjunctionSubtypeWitness(ConjunctionSubtypeWitness* val) { // A witness `W = X & Y & ...` will lower as a tuple of the sub-witnesses @@ -1892,14 +1663,14 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // witness that `T : L & R`, so lower that first and expect it to be // a value of tuple type. // - auto conjunctionWitness = lowerSimpleVal(context, val->conjunctionWitness); + auto conjunctionWitness = lowerSimpleVal(context, val->getConjunctionWitness()); auto conjunctionTupleType = as<IRTupleType>(conjunctionWitness->getDataType()); SLANG_ASSERT(conjunctionTupleType); // The `ExtractFromConjunctionSubtypeWitness` also stores the index of // the witness/supertype we want in the conjunction `L & R`. // - auto indexInConjunction = val->indexInConjunction; + auto indexInConjunction = val->getIndexInConjunction(); // We want to extract the appropriate element from the tuple based on // the index, but to know the type of the result we need to look up @@ -1923,8 +1694,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { - auto type = lowerType(context, val->type); - return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); + auto type = lowerType(context, val->getType()); + return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue())); } IRFuncType* visitFuncType(FuncType* type) @@ -1964,7 +1735,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower IRType* visitDeclRefType(DeclRefType* type) { - auto declRef = type->declRef; + auto declRef = type->getDeclRef(); auto decl = declRef.getDecl(); // Check for types with teh `__intrinsic_type` modifier. @@ -1988,13 +1759,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower IRType* visitBasicExpressionType(BasicExpressionType* type) { return getBuilder()->getBasicType( - type->baseType); + type->getBaseType()); } IRType* visitVectorExpressionType(VectorExpressionType* type) { - auto elementType = lowerType(context, type->elementType); - auto elementCount = lowerSimpleVal(context, type->elementCount); + auto elementType = lowerType(context, type->getElementType()); + auto elementCount = lowerSimpleVal(context, type->getElementCount()); return getBuilder()->getVectorType( elementType, @@ -2030,19 +1801,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower } } - // Lower substitution args and collect them into a list of IR operands. - void _collectSubstitutionArgs(List<IRInst*>& operands, Substitutions* subst) - { - if (!subst) return; - _collectSubstitutionArgs(operands, subst->getOuter()); - if (auto genSubst = as<GenericSubstitution>(subst)) - { - for (auto arg : genSubst->getArgs()) - { - operands.add(lowerVal(context, arg).val); - } - } - } // Lower a type where the type declaration being referenced is assumed // to be an intrinsic type, which can thus be lowered to a simple IR // type with the appropriate opcode. @@ -2050,13 +1808,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower { SLANG_ASSERT(getBuilder()->getInsertLoc().getMode() != IRInsertLoc::Mode::None); - auto intrinsicTypeModifier = type->declRef.getDecl()->findModifier<IntrinsicTypeModifier>(); + auto intrinsicTypeModifier = type->getDeclRef().getDecl()->findModifier<IntrinsicTypeModifier>(); SLANG_ASSERT(intrinsicTypeModifier); IROp op = IROp(intrinsicTypeModifier->irOp); List<IRInst*> operands; // If there are any substitutions attached to the declRef, // add them as operands of the IR type. - _collectSubstitutionArgs(operands, type->declRef.getSubst()); + SubstitutionSet(type->getDeclRef()).forEachSubstitutionArg([&](Val* arg) + { + operands.add(lowerVal(context, arg).val); + }); return getBuilder()->getType( op, static_cast<UInt>(operands.getCount()), @@ -2095,7 +1856,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower IRType* visitExtractExistentialType(ExtractExistentialType* type) { - auto declRef = type->declRef; + auto declRef = type->getDeclRef(); auto existentialType = lowerType(context, getType(context->astBuilder, declRef)); IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); return getBuilder()->emitExtractExistentialType(existentialVal); @@ -2103,50 +1864,20 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitExtractExistentialSubtypeWitness(ExtractExistentialSubtypeWitness* witness) { - auto declRef = witness->declRef; + auto declRef = witness->getDeclRef(); auto existentialType = lowerType(context, getType(context->astBuilder, declRef)); IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal)); } - LoweredValInfo visitTaggedUnionType(TaggedUnionType* type) - { - // A tagged union type will lower into an IR `union` over the cases, - // along with an IR `struct` with a field for the union and a tag. - // (Note: we are placing the tag after the payload to avoid padding - // in the case where the payload is more aligned than the tag) - // - // TODO: should we be lowering directly like this, or have - // an IR-level representation of tagged unions? - // - - List<IRType*> irCaseTypes; - for(auto caseType : type->caseTypes) - { - auto irCaseType = lowerType(context, caseType); - irCaseTypes.add(irCaseType); - } - - auto irType = getBuilder()->getTaggedUnionType(irCaseTypes); - if(!irType->findDecoration<IRLinkageDecoration>()) - { - // We need a way for later passes to attach layout information - // to this type, so we will give it a mangled name here. - // - getBuilder()->addExportDecoration( - irType, - getMangledTypeName(context->astBuilder, type).getUnownedSlice()); - } - return LoweredValInfo::simple(irType); - } - LoweredValInfo visitExistentialSpecializedType(ExistentialSpecializedType* type) { - auto irBaseType = lowerType(context, type->baseType); + auto irBaseType = lowerType(context, type->getBaseType()); List<IRInst*> slotArgs; - for(auto arg : type->args) + for (Index i = 0; i < type->getArgCount(); i++) { + auto arg = type->getArg(i); auto irArgVal = lowerSimpleVal(context, arg.val); slotArgs.add(irArgVal); @@ -2173,13 +1904,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // if (context->thisType != nullptr) return LoweredValInfo::simple(context->thisType); - return emitDeclRef(context, type->interfaceDeclRef, getBuilder()->getTypeKind()); + return emitDeclRef(context, makeDeclRef(type->getInterfaceDecl()), getBuilder()->getTypeKind()); } LoweredValInfo visitAndType(AndType* type) { - auto left = lowerType(context, type->left); - auto right = lowerType(context, type->right); + auto left = lowerType(context, type->getLeft()); + auto right = lowerType(context, type->getRight()); auto irType = getBuilder()->getConjunctionType(left, right); return LoweredValInfo::simple(irType); @@ -2187,11 +1918,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitModifiedType(ModifiedType* astType) { - IRType* irBase = lowerType(context, astType->base); + IRType* irBase = lowerType(context, astType->getBase()); List<IRAttr*> irAttrs; - for(auto astModifier : astType->modifiers) + for(Index i = 0; i < astType->getModifierCount(); i++) { + auto astModifier = astType->getModifier(i); IRAttr* irAttr = (IRAttr*) lowerSimpleVal(context, astModifier); if(irAttr) irAttrs.add(irAttr); @@ -2237,7 +1969,8 @@ LoweredValInfo lowerVal( { ValLoweringVisitor visitor; visitor.context = context; - return visitor.dispatch(val); + auto resolvedVal = val->resolve(); + return visitor.dispatch(resolvedVal); } IRType* lowerType( @@ -2786,8 +2519,8 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, SemanticsVisitor* semantics, Decl* decl) { - DeclRef<Decl> declRef = context->astBuilder->getSpecializedDeclRef( - decl, createDefaultSubstitutions(context->astBuilder, semantics, decl)); + DeclRef<Decl> declRef = createDefaultSubstitutionsIfNeeded(context->astBuilder, semantics, + makeDeclRef(decl)); return declRef; } // @@ -2808,7 +2541,7 @@ static Type* _findReplacementThisParamType( auto targetType = getTargetType(context->astBuilder, extensionDeclRef); if(auto targetDeclRefType = as<DeclRefType>(targetType)) { - if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->declRef)) + if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->getDeclRef())) return replacementType; } return targetType; @@ -2816,8 +2549,7 @@ static Type* _findReplacementThisParamType( if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>()) { - auto thisType = context->astBuilder->create<ThisType>(); - thisType->interfaceDeclRef = interfaceDeclRef; + auto thisType = DeclRefType::create(context->astBuilder, interfaceDeclRef.getDecl()->getThisTypeDecl()); return thisType; } @@ -2853,13 +2585,13 @@ Type* getThisParamTypeForCallable( IRGenContext* context, DeclRef<Decl> callableDeclRef) { - auto parentDeclRef = callableDeclRef.getParent(context->astBuilder); + auto parentDeclRef = callableDeclRef.getParent(); if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>()) - parentDeclRef = subscriptDeclRef.getParent(context->astBuilder); + parentDeclRef = subscriptDeclRef.getParent(); if(auto genericDeclRef = parentDeclRef.as<GenericDecl>()) - parentDeclRef = genericDeclRef.getParent(context->astBuilder); + parentDeclRef = genericDeclRef.getParent(); return getThisParamTypeForContainer(context, parentDeclRef); } @@ -2997,7 +2729,7 @@ void collectParameterLists( // The parameters introduced by any "parent" declarations // will need to come first, so we'll deal with that // logic here. - if( auto parentDeclRef = declRef.getParent(context->astBuilder) ) + if( auto parentDeclRef = declRef.getParent() ) { // Compute the mode to use when collecting parameters from // the outer declaration. The most important question here @@ -3592,7 +3324,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { auto innerType = type; while (auto modifiedType = as<ModifiedType>(innerType)) - innerType = modifiedType->base; + innerType = modifiedType->getBase(); return innerType; } @@ -3607,9 +3339,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else if (auto vectorType = as<VectorExpressionType>(type)) { - UInt elementCount = (UInt) getIntVal(vectorType->elementCount); + UInt elementCount = (UInt) getIntVal(vectorType->getElementCount()); - auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); + auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->getElementType())); List<IRInst*> args; for(UInt ee = 0; ee < elementCount; ++ee) @@ -3644,7 +3376,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else if (auto declRefType = as<DeclRefType>(type)) { - DeclRef<Decl> declRef = declRefType->declRef; + DeclRef<Decl> declRef = declRefType->getDeclRef(); if (auto enumType = declRef.as<EnumDecl>()) { return LoweredValInfo::simple(getBuilder()->getIntValue(irType, 0)); @@ -3735,7 +3467,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else if (auto vectorType = as<VectorExpressionType>(type)) { - UInt elementCount = (UInt) getIntVal(vectorType->elementCount); + UInt elementCount = (UInt) getIntVal(vectorType->getElementCount()); for (UInt ee = 0; ee < argCount; ++ee) { @@ -3745,7 +3477,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } if(elementCount > argCount) { - auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); + auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->getElementType())); for(UInt ee = argCount; ee < elementCount; ++ee) { args.add(irDefaultValue); @@ -3781,7 +3513,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else if (auto declRefType = as<DeclRefType>(type)) { - DeclRef<Decl> declRef = declRefType->declRef; + DeclRef<Decl> declRef = declRefType->getDeclRef(); if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { UInt argCounter = 0; @@ -3896,19 +3628,19 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } - void _lowerSubstitutionArg(IRGenContext* subContext, GenericSubstitution* subst, Decl* paramDecl, Index argIndex) + void _lowerSubstitutionArg(IRGenContext* subContext, GenericAppDeclRef* subst, Decl* paramDecl, Index argIndex) { SLANG_ASSERT(argIndex < subst->getArgs().getCount()); auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]); subContext->setValue(paramDecl, argVal); } - void _lowerSubstitutionEnv(IRGenContext* subContext, Substitutions* subst) + void _lowerSubstitutionEnv(IRGenContext* subContext, DeclRefBase* subst) { if(!subst) return; - _lowerSubstitutionEnv(subContext, subst->getOuter()); + _lowerSubstitutionEnv(subContext, subst->getBase()); - if (auto genSubst = as<GenericSubstitution>(subst)) + if (auto genSubst = as<GenericAppDeclRef>(subst)) { auto genDecl = genSubst->getGenericDecl(); @@ -3985,7 +3717,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> IRGenContext* subContext = &subContextStorage; subContext->env = subEnv; - _lowerSubstitutionEnv(subContext, argExpr.getSubsts()); + _lowerSubstitutionEnv(subContext, argExpr.getSubsts() ? argExpr.getSubsts().declRef : nullptr); addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); @@ -4148,7 +3880,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { if(auto declRefType = as<DeclRefType>(e->type)) { - if(declRefType->declRef.as<InterfaceDecl>()) + if(declRefType->getDeclRef().as<InterfaceDecl>()) { e = castExpr->valueArg; continue; @@ -4387,7 +4119,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { if( auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(subTypeWitness) ) { - return extractField(superType, value, declaredSubtypeWitness->declRef); + return extractField(superType, value, declaredSubtypeWitness->getDeclRef()); } else { @@ -4414,7 +4146,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // if( auto declRefType = as<DeclRefType>(expr->type) ) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); if( auto interfaceDeclRef = declRef.as<InterfaceDecl>() ) { // We have an expression that is "up-casting" some concrete value @@ -4573,12 +4305,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } - LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/) - { - SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - LoweredValInfo visitThisTypeExpr(ThisTypeExpr* /*expr*/) { SLANG_UNIMPLEMENTED_X("this-type expression during code generation"); @@ -6676,7 +6402,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto declRefType = as<DeclRefType>(type)) { - if (declRefType->declRef.getDecl()->findModifier<PublicModifier>()) + if (declRefType->getDeclRef().getDecl()->findModifier<PublicModifier>()) return true; } return false; @@ -6690,7 +6416,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { auto subBuilder = subContext->irBuilder; - for(auto entry : astWitnessTable->requirementDictionary) + for(auto entry : astWitnessTable->getRequirementDictionary()) { auto requiredMemberDecl = entry.key; auto satisfyingWitness = entry.value; @@ -6787,7 +6513,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto targetType = parentExtensionDecl->targetType; if(auto targetDeclRefType = as<DeclRefType>(targetType)) { - if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>()) + if(auto targetInterfaceDeclRef = targetDeclRefType->getDeclRef().as<InterfaceDecl>()) { return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); } @@ -6815,7 +6541,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if(auto superDeclRefType = as<DeclRefType>(superType)) { - if( superDeclRefType->declRef.as<StructDecl>() || superDeclRefType->declRef.as<ClassDecl>() ) + if( superDeclRefType->getDeclRef().as<StructDecl>() || superDeclRefType->getDeclRef().as<ClassDecl>() ) { // TODO: the witness that a type inherits from a `struct` // type should probably be a key that will be used for @@ -7675,6 +7401,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(finishOuterGenerics(subBuilder, loweredTagType, outerGeneric)); } + LoweredValInfo visitThisTypeDecl(ThisTypeDecl* decl) + { + auto interfaceType = ensureDecl(context, decl->parentDecl).val; + return LoweredValInfo::simple(context->irBuilder->getThisType(as<IRInterfaceType>(interfaceType))); + } + + LoweredValInfo visitThisTypeConstraintDecl(ThisTypeConstraintDecl* decl) + { + SLANG_UNUSED(decl); + return LoweredValInfo(); + } + LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) { // Don't generate an IR `struct` for intrinsic types @@ -7753,8 +7491,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto superType = inheritanceDecl->base; if(auto superDeclRefType = as<DeclRefType>(superType)) { - if (superDeclRefType->declRef.as<StructDecl>() || - superDeclRefType->declRef.as<ClassDecl>()) + if (superDeclRefType->getDeclRef().as<StructDecl>() || + superDeclRefType->getDeclRef().as<ClassDecl>()) { auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl)); auto irSuperType = lowerType(context, superType.type); @@ -7890,10 +7628,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> void lowerDifferentiableAttribute(IRGenContext* subContext, IRInst* inst, DifferentiableAttribute* attr) { auto irDict = getBuilder()->addDifferentiableTypeDictionaryDecoration(inst); - for (auto& entry : attr->m_mapTypeToIDifferentiableWitness) + for (auto& entry : attr->getMapTypeToIDifferentiableWitness()) { // Lower type and witness. - IRType* irType = lowerType(subContext, entry.value->sub); + IRType* irType = lowerType(subContext, entry.value->getSub()); IRInst* irWitness = lowerVal(subContext, entry.value).val; SLANG_ASSERT(irType); @@ -8028,7 +7766,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type)) { - auto typeParamDeclVal = subContext->findLoweredDecl(declRefType->declRef.getDecl()); + auto typeParamDeclVal = subContext->findLoweredDecl(declRefType->getDeclRef().getDecl()); SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val); subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType); } @@ -8162,7 +7900,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { if(auto targetDeclRefType = as<DeclRefType>(extensionAncestor->targetType)) { - if(auto interfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>()) + if(auto interfaceDeclRef = targetDeclRefType->getDeclRef().as<InterfaceDecl>()) { return emitOuterInterfaceGeneric(subContext, extensionAncestor, targetDeclRefType, leafDecl); } @@ -8608,9 +8346,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(subContext, irFunc, decl); addLinkageDecoration(subContext, irFunc, decl); - if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) + if (decl->body) { - lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); + if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) + { + lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); + } } // Always force inline diff setter accessor to prevent downstream compiler from complaining @@ -9375,7 +9116,7 @@ static void _addFlattenedTupleArgs( LoweredValInfo emitDeclRef( IRGenContext* context, Decl* decl, - Substitutions* subst, + DeclRefBase* subst, IRType* type) { const auto initialSubst = subst; @@ -9383,27 +9124,28 @@ LoweredValInfo emitDeclRef( // We need to proceed by considering the specializations that // have been put in place. + subst = SubstitutionSet(subst).getInnerMostNodeWithSubstInfo(); // If the declaration would not get wrapped in a `IRGeneric`, // even if it is nested inside of an AST `GenericDecl`, then // we should also ignore any generic substitutions. if(!canDeclLowerToAGeneric(decl)) { - while(auto genericSubst = as<GenericSubstitution>(subst)) - subst = genericSubst->getOuter(); + while(auto genericSubst = SubstitutionSet(subst).findGenericAppDeclRef()) + subst = genericSubst->getBase(); } // In the simplest case, there is no specialization going // on, and the decl-ref turns into a reference to the // lowered IR value for the declaration. - if(!subst) + if(!SubstitutionSet(subst) || _isTrivialLookupFromInterfaceThis(context, subst)) { LoweredValInfo loweredDecl = ensureDecl(context, decl); return loweredDecl; } // Otherwise, we look at the kind of substitution, and let it guide us. - if(auto genericSubst = as<GenericSubstitution>(subst)) + if(auto genericSubst = as<GenericAppDeclRef>(subst)) { // A generic substitution means we will need to output // a `specialize` instruction to specialize the generic. @@ -9419,7 +9161,7 @@ LoweredValInfo emitDeclRef( LoweredValInfo genericVal = emitDeclRef( context, decl, - genericSubst->getOuter(), + genericSubst->getBase(), context->irBuilder->getGenericKind()); // There's no reason to specialize something that maps to a NULL pointer. @@ -9464,21 +9206,21 @@ LoweredValInfo emitDeclRef( return LoweredValInfo::simple(irSpecializedVal); } - else if(auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) + else if(auto thisTypeSubst = as<LookupDeclRef>(subst)) { - if(decl == thisTypeSubst->interfaceDecl) + if( as<ThisTypeDecl>(decl)) { - // This is a reference to the interface type itself, - // through the this-type substitution, so it is really - // a reference to the this-type. - return lowerType(context, thisTypeSubst->witness->sub); + // This is a reference to the ThisType from the interface, + // therefore we should just lower it as the sub type. + return lowerType(context, thisTypeSubst->getWitness()->getSub()); } if(isInterfaceRequirement(decl)) { - // Somebody is trying to look up an interface requirement - // "through" some concrete type. We need to lower this decl-ref - // as a lookup of the corresponding member in a witness table. + // If we reach here, somebody is trying to look up an interface + // requirement "through" some concrete type. We need to lower this + // decl-ref as a lookup of the corresponding member in a witness + // table. // // The witness table itself is referenced by the this-type // substitution, so we can just lower that. @@ -9491,7 +9233,7 @@ LoweredValInfo emitDeclRef( // `ISomething<T>`. That is because we really care about the // witness table for the concrete type that conforms to `ISomething<Foo>`. // - auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness); + auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->getWitness()); // // The key to use for looking up the interface member is // derived from the declaration. @@ -9517,14 +9259,14 @@ LoweredValInfo emitDeclRef( // are lowered as generics, where the generic parameter represents // the `ThisType`. // - auto genericVal = emitDeclRef(context, decl, thisTypeSubst->getOuter(), context->irBuilder->getGenericKind()); + auto genericVal = emitDeclRef(context, decl, thisTypeSubst->getBase(), context->irBuilder->getGenericKind()); auto irGenericVal = getSimpleVal(context, genericVal); // In order to reference the member for a particular type, we // specialize the generic for that type. // - IRInst* irSubType = lowerType(context, thisTypeSubst->witness->sub); - IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->witness); + IRInst* irSubType = lowerType(context, thisTypeSubst->getWitness()->getSub()); + IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->getWitness()); IRInst* irSpecializeArgs[] = { irSubType, irSubTypeWitness }; auto irSpecializedVal = context->irBuilder->emitSpecializeInst( @@ -9550,7 +9292,7 @@ LoweredValInfo emitDeclRef( return emitDeclRef( context, declRef.getDecl(), - declRef.getSubst(), + declRef.declRefBase, type); } @@ -9723,6 +9465,7 @@ RefPtr<IRModule> generateIRForTranslationUnit( TranslationUnitRequest* translationUnit) { SLANG_PROFILE; + SLANG_AST_BUILDER_RAII(astBuilder); auto session = translationUnit->getSession(); auto compileRequest = translationUnit->compileRequest; @@ -10082,6 +9825,8 @@ RefPtr<IRModule> generateIRForSpecializedComponentType( SpecializedComponentType* componentType, DiagnosticSink* sink) { + SLANG_AST_BUILDER_RAII(componentType->getLinkage()->getASTBuilder()); + SpecializedComponentTypeIRGenContext context; return context.process(componentType, sink); } @@ -10135,6 +9880,8 @@ RefPtr<IRModule> generateIRForTypeConformance( Int conformanceIdOverride, DiagnosticSink* sink) { + SLANG_AST_BUILDER_RAII(typeConformance->getLinkage()->getASTBuilder()); + TypeConformanceIRGenContext context; return context.process(typeConformance, conformanceIdOverride, sink); } @@ -10296,20 +10043,6 @@ IRTypeLayout* lowerTypeLayout( IRPointerTypeLayout::Builder builder(context->irBuilder); return _lowerTypeLayoutCommon(context, &builder, ptrTypeLayout); } - else if( auto taggedUnionTypeLayout = as<TaggedUnionTypeLayout>(typeLayout) ) - { - IRTaggedUnionTypeLayout::Builder builder(context->irBuilder, taggedUnionTypeLayout->tagOffset); - - for( auto caseTypeLayout : taggedUnionTypeLayout->caseTypeLayouts ) - { - builder.addCaseTypeLayout( - lowerTypeLayout( - context, - caseTypeLayout)); - } - - return _lowerTypeLayoutCommon(context, &builder, taggedUnionTypeLayout); - } else if( auto streamOutputTypeLayout = as<StreamOutputTypeLayout>(typeLayout) ) { auto irElementTypeLayout = lowerTypeLayout(context, streamOutputTypeLayout->elementTypeLayout); @@ -10453,6 +10186,9 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink) auto program = getProgram(); auto linkage = program->getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + auto session = linkage->getSessionImpl(); SharedIRGenContext sharedContextStorage( @@ -10550,16 +10286,6 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink) builder->addLayoutDecoration(irFunc, irEntryPointLayout); } - for( auto taggedUnionTypeLayout : programLayout->taggedUnionTypeLayouts ) - { - auto taggedUnionType = taggedUnionTypeLayout->getType(); - auto irType = lowerType(context, taggedUnionType); - - auto irTypeLayout = lowerTypeLayout(context, taggedUnionTypeLayout); - - builder->addLayoutDecoration(irType, irTypeLayout); - } - // Lets strip and run DCE here if (linkage->m_obfuscateCode) { diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 4d94d5283..b27a45484 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -141,7 +141,7 @@ namespace Slang { if( auto constVal = as<ConstantIntVal>(val) ) { - auto cVal = constVal->value; + auto cVal = constVal->getValue(); if(cVal >= 0 && cVal <= 9 ) { emit(context, (UInt)cVal); @@ -190,13 +190,13 @@ namespace Slang if( auto basicType = dynamicCast<BasicExpressionType>(type) ) { - emitBaseType(context, basicType->baseType); + emitBaseType(context, basicType->getBaseType()); } else if( auto vecType = dynamicCast<VectorExpressionType>(type) ) { emitRaw(context, "v"); - emitSimpleIntVal(context, vecType->elementCount); - emitType(context, vecType->elementType); + emitSimpleIntVal(context, vecType->getElementCount()); + emitType(context, vecType->getElementType()); } else if( auto matType = dynamicCast<MatrixExpressionType>(type) ) { @@ -208,11 +208,11 @@ namespace Slang } else if( auto namedType = dynamicCast<NamedExpressionType>(type) ) { - emitType(context, getType(context->astBuilder, namedType->declRef)); + emitType(context, getType(context->astBuilder, namedType->getDeclRef())); } else if( auto declRefType = dynamicCast<DeclRefType>(type) ) { - emitQualifiedName(context, declRefType->declRef); + emitQualifiedName(context, declRefType->getDeclRef()); } else if (auto arrType = dynamicCast<ArrayExpressionType>(type)) { @@ -220,19 +220,10 @@ namespace Slang emitSimpleIntVal(context, arrType->getElementCount()); emitType(context, arrType->getElementType()); } - else if( auto taggedUnionType = dynamicCast<TaggedUnionType>(type) ) - { - emitRaw(context, "u"); - for( auto caseType : taggedUnionType->caseTypes ) - { - emitType(context, caseType); - } - emitRaw(context, "U"); - } else if( auto thisType = dynamicCast<ThisType>(type) ) { emitRaw(context, "t"); - emitQualifiedName(context, thisType->interfaceDeclRef); + emitQualifiedName(context, thisType->getInterfaceDecl()); } else if (const auto errorType = dynamicCast<ErrorType>(type)) { @@ -300,50 +291,50 @@ namespace Slang // "depth" (how many outer generics) and "index" (which // parameter are they at the specified depth). emitRaw(context, "K"); - emitName(context, genericParamIntVal->declRef.getName()); + emitName(context, genericParamIntVal->getDeclRef().getName()); } else if( auto constantIntVal = dynamicCast<ConstantIntVal>(val) ) { // TODO: need to figure out what prefix/suffix is needed // to allow demangling later. emitRaw(context, "k"); - emit(context, (UInt) constantIntVal->value); + emit(context, (UInt) constantIntVal->getValue()); } else if (auto funcCallIntVal = dynamicCast<FuncCallIntVal>(val)) { emitRaw(context, "KC"); - emit(context, funcCallIntVal->args.getCount()); - emitName(context, funcCallIntVal->funcDeclRef.getName()); - for (Index i = 0; i < funcCallIntVal->args.getCount(); i++) - emitVal(context, funcCallIntVal->args[i]); + emit(context, funcCallIntVal->getArgs().getCount()); + emitName(context, funcCallIntVal->getFuncDeclRef().getName()); + for (Index i = 0; i < funcCallIntVal->getArgs().getCount(); i++) + emitVal(context, funcCallIntVal->getArgs()[i]); } else if (auto lookupIntVal = dynamicCast<WitnessLookupIntVal>(val)) { emitRaw(context, "KL"); - emitVal(context, lookupIntVal->witness); - emitName(context, lookupIntVal->key->getName()); + emitVal(context, lookupIntVal->getWitness()); + emitName(context, lookupIntVal->getKey()->getName()); } else if (const auto polynomialIntVal = dynamicCast<PolynomialIntVal>(val)) { emitRaw(context, "KX"); - emit(context, (UInt)polynomialIntVal->constantTerm); - emit(context, (UInt)polynomialIntVal->terms.getCount()); - for (auto term : polynomialIntVal->terms) + emit(context, (UInt)polynomialIntVal->getConstantTerm()); + emit(context, (UInt)polynomialIntVal->getTerms().getCount()); + for (auto term : polynomialIntVal->getTerms()) { - emit(context, (UInt)term->constFactor); - emit(context, (UInt)term->paramFactors.getCount()); - for (auto factor : term->paramFactors) + emit(context, (UInt)term->getConstFactor()); + emit(context, (UInt)term->getParamFactors().getCount()); + for (auto factor : term->getParamFactors()) { - emitVal(context, factor->param); - emit(context, (UInt)factor->power); + emitVal(context, factor->getParam()); + emit(context, (UInt)factor->getPower()); } } } else if (const auto typecastIntVal = dynamicCast<TypeCastIntVal>(val)) { emitRaw(context, "KK"); - emitVal(context, typecastIntVal->type); - emitVal(context, typecastIntVal->base); + emitVal(context, typecastIntVal->getType()); + emitVal(context, typecastIntVal->getBase()); } else { @@ -355,7 +346,7 @@ namespace Slang ManglingContext* context, DeclRef<Decl> declRef) { - auto parentDeclRef = declRef.getParent(context->astBuilder); + auto parentDeclRef = declRef.getParent(); auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>(); if( parentDeclRef ) { @@ -423,14 +414,14 @@ namespace Slang // There are two cases here: either we have specializations // in place for the parent generic declaration, or we don't. - auto subst = findInnerMostGenericSubstitution(declRef.getSubst()); - if( subst && subst->getGenericDecl() == parentGenericDeclRef.getDecl()) + auto substArgs = tryGetGenericArguments(SubstitutionSet(declRef), parentGenericDeclRef.getDecl()); + if (substArgs.getCount()) { // This is the case where we *do* have substitutions. emitRaw(context, "G"); - UInt genericArgCount = subst->getArgs().getCount(); + UInt genericArgCount = substArgs.getCount(); emit(context, genericArgCount); - for (auto aa : subst->getArgs()) + for (auto aa : substArgs) { emitVal(context, aa); } @@ -441,7 +432,7 @@ namespace Slang // information about the parameters of the generic here. emitRaw(context, "g"); UInt genericParameterCount = 0; - for( auto mm : getMembers(context->astBuilder, parentGenericDeclRef) ) + for( auto mm : getMembers(context->astBuilder, parentGenericDeclRef.as<ContainerDecl>()) ) { if(mm.is<GenericTypeParamDecl>()) { @@ -569,7 +560,7 @@ namespace Slang // mangling the generic and the inner entity emitRaw(context, "G"); - SLANG_ASSERT(genericDecl.getSubst() == nullptr); + SLANG_ASSERT(SubstitutionSet(genericDecl).findGenericAppDeclRef() == nullptr); auto innerDecl = getInner(genericDecl); @@ -591,6 +582,7 @@ namespace Slang static String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const& declRef) { + SLANG_AST_BUILDER_RAII(astBuilder); ManglingContext context(astBuilder); mangleName(&context, declRef); return context.sb.produceString(); @@ -598,11 +590,15 @@ namespace Slang String getMangledName(ASTBuilder* astBuilder, DeclRefBase* declRef) { + SLANG_AST_BUILDER_RAII(astBuilder); + return getMangledName(astBuilder, DeclRef<Decl>(declRef)); } String getMangledName(ASTBuilder* astBuilder, Decl* decl) { + SLANG_AST_BUILDER_RAII(astBuilder); + return getMangledName(astBuilder, makeDeclRef(decl)); } @@ -611,6 +607,7 @@ namespace Slang DeclRef<Decl> sub, DeclRef<Decl> sup) { + SLANG_AST_BUILDER_RAII(astBuilder); ManglingContext context(astBuilder); emitRaw(&context, "_SW"); emitQualifiedName(&context, sub); @@ -623,6 +620,7 @@ namespace Slang DeclRef<Decl> sub, Type* sup) { + SLANG_AST_BUILDER_RAII(astBuilder); // The mangled form for a witness that `sub` // conforms to `sup` will be named: // @@ -640,6 +638,7 @@ namespace Slang Type* sub, Type* sup) { + SLANG_AST_BUILDER_RAII(astBuilder); // The mangled form for a witness that `sub` // conforms to `sup` will be named: // @@ -654,6 +653,7 @@ namespace Slang String getMangledTypeName(ASTBuilder* astBuilder, Type* type) { + SLANG_AST_BUILDER_RAII(astBuilder); ManglingContext context(astBuilder); emitType(&context, type); return context.sb.produceString(); diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 47f370854..c0389d1cd 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2064,7 +2064,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( // otherwise they will include all of the above cases... else if( auto declRefType = as<DeclRefType>(type) ) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); if (auto structDeclRef = declRef.as<StructDecl>()) { @@ -2777,7 +2777,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( // Any generic specialization applied to the entry-point function // must also be applied to its parameters. - paramDeclRef = context->getASTBuilder()->getSpecializedDeclRef(paramDeclRef.getDecl(), entryPointFuncDeclRef.getSubst()); + paramDeclRef = context->getASTBuilder()->getMemberDeclRef(entryPointFuncDeclRef, paramDeclRef.getDecl()); // When computing layout for an entry-point parameter, // we want to make sure that the layout context has access @@ -3033,24 +3033,6 @@ struct CollectParametersVisitor : ComponentTypeVisitor // along. // visitChildren(specialized); - - // While we are at it, we will also make note of any - // tagged-union types that were used as part of the - // specialization arguments, since we need to make - // sure that their layout information is computed - // and made available for IR code generation. - // - // Note: this isn't really the best place for this logic to sit, - // but it is the simplest place where we can collect all the tagged - // union types that get referenced by a program. - // - for( auto taggedUnionType : specialized->getTaggedUnionTypes() ) - { - SLANG_ASSERT(taggedUnionType); - auto substType = taggedUnionType; - auto typeLayout = createTypeLayout(m_context->layoutContext, substType); - m_context->shared->programLayout->taggedUnionTypeLayouts.add(typeLayout); - } } @@ -3755,6 +3737,8 @@ RefPtr<ProgramLayout> generateParameterBindings( TargetProgram* targetProgram, DiagnosticSink* sink) { + SLANG_AST_BUILDER_RAII(targetProgram->getProgram()->getLinkage()->getASTBuilder()); + auto program = targetProgram->getProgram(); auto targetReq = targetProgram->getTargetReq(); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index cab01d585..4448a96e1 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1323,10 +1323,9 @@ namespace Slang // parent link is set up correctly. static void AddMember(ContainerDecl* container, Decl* member) { - if (container && member) + if (container) { - member->parentDecl = container; - container->members.add(member); + container->addMember(member); } } @@ -1334,7 +1333,7 @@ namespace Slang { if (scope) { - AddMember(scope->containerDecl, member); + scope->containerDecl->addMember(member); } } @@ -2149,30 +2148,6 @@ namespace Slang } return typeExpr; } - - static Expr* parseTaggedUnionType(Parser* parser) - { - TaggedUnionTypeExpr* taggedUnionType = parser->astBuilder->create<TaggedUnionTypeExpr>(); - - parser->ReadToken(TokenType::LParent); - while(!AdvanceIfMatch(parser, MatchedTokenType::Parentheses)) - { - auto caseType = parser->ParseTypeExp(); - taggedUnionType->caseTypes.add(caseType); - - if(AdvanceIf(parser, TokenType::RParent)) - break; - - parser->ReadToken(TokenType::Comma); - } - - return taggedUnionType; - } - - static NodeBase* parseTaggedUnionType(Parser* parser, void* /*unused*/) - { - return parseTaggedUnionType(parser); - } /// Parse an expression of the form __fwd_diff(fn) where fn is an /// identifier pointing to a function. static Expr* parseForwardDifferentiate(Parser* parser) @@ -2234,19 +2209,6 @@ namespace Slang return parseDispatchKernel(parser); } - /// Parse a `This` type expression - static Expr* parseThisTypeExpr(Parser* parser) - { - ThisTypeExpr* expr = parser->astBuilder->create<ThisTypeExpr>(); - expr->scope = parser->currentScope; - return expr; - } - - static NodeBase* parseThisTypeExpr(Parser* parser, void* /*userData*/) - { - return parseThisTypeExpr(parser); - } - // (a,b,c) style tuples, curently unused #if 0 static Expr* parseTupleTypeExpr(Parser* parser) @@ -2459,22 +2421,6 @@ namespace Slang typeSpec.expr = createDeclRefType(parser, decl); return typeSpec; } - // TODO: This case would not be needed if we had the - // code below dispatch into `parseAtomicExpr`, which - // already includes logic for keyword lookup. - // - // Leaving this case here for now to avoid breaking anything. - // - else if(AdvanceIf(parser, "__TaggedUnion")) - { - typeSpec.expr = parseTaggedUnionType(parser); - return typeSpec; - } - else if(AdvanceIf(parser, "This")) - { - typeSpec.expr = parseThisTypeExpr(parser); - return typeSpec; - } // Uncomment should we decide to enable (a,b,c) tuple types // else if(parser->LookAheadToken(TokenType::LParent)) // { @@ -3170,7 +3116,7 @@ namespace Slang static NodeBase* parseInterfaceDecl(Parser* parser, void* /*userData*/) { - InterfaceDecl* decl = parser->astBuilder->create<InterfaceDecl>(); + InterfaceDecl* decl = parser->astBuilder->createInterfaceDecl(parser->tokenReader.peekLoc()); parser->FillPosition(decl); AdvanceIf(parser, TokenType::CompletionRequest); @@ -4082,6 +4028,8 @@ namespace Slang void Parser::parseSourceFile(ModuleDecl* program) { + SLANG_AST_BUILDER_RAII(astBuilder); + if (outerScope) { currentScope = outerScope; @@ -4328,6 +4276,7 @@ namespace Slang parser->astBuilder, nullptr, // no semantics visitor available yet staticMemberExpr->name, + aggTypeDecl, declRef); if (!lookupResult.isValid() || lookupResult.isOverloaded()) @@ -6252,7 +6201,7 @@ namespace Slang // Need to get the basic type, so we can fit to underlying type if (auto basicExprType = as<BasicExpressionType>(intLit->type.type)) { - value = _fixIntegerLiteral(basicExprType->baseType, value, nullptr, nullptr); + value = _fixIntegerLiteral(basicExprType->getBaseType(), value, nullptr, nullptr); } newLiteral->value = value; @@ -6910,14 +6859,12 @@ namespace Slang // !!!!!!!!!!!!!!!!!!!!!!! Expr !!!!!!!!!!!!!!!!!!!!!!!!!!! _makeParseExpr("this", parseThisExpr), - _makeParseExpr("This", parseThisTypeExpr), _makeParseExpr("true", parseTrueExpr), _makeParseExpr("false", parseFalseExpr), _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("none", parseNoneExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("no_diff", parseTreatAsDifferentiableExpr), - _makeParseExpr("__TaggedUnion", parseTaggedUnionType), _makeParseExpr("__fwd_diff", parseForwardDifferentiate), _makeParseExpr("__bwd_diff", parseBackwardDifferentiate), _makeParseExpr("fwd_diff", parseForwardDifferentiate), diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 7a79e9fcd..9f83d325d 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -269,11 +269,14 @@ SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflect if (!userAttr) return SLANG_E_INVALID_ARG; if (index >= (unsigned int)userAttr->args.getCount()) return SLANG_E_INVALID_ARG; - NodeBase* val = nullptr; - if (userAttr->intArgVals.tryGetValue(index, val)) + if (userAttr->intArgVals.getCount() > (Index)index) { - *rs = (int)as<ConstantIntVal>(val)->value; - return 0; + auto intVal = as<ConstantIntVal>(userAttr->intArgVals[index]); + if (intVal) + { + *rs = (int)intVal->getValue(); + return 0; + } } return SLANG_E_INVALID_ARG; } @@ -387,7 +390,7 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) } else if( auto declRefType = as<DeclRefType>(type) ) { - const auto& declRef = declRefType->declRef; + const auto& declRef = declRefType->getDeclRef(); if(declRef.is<StructDecl>() ) { return SLANG_TYPE_KIND_STRUCT; @@ -429,7 +432,7 @@ SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inTyp if(auto declRefType = as<DeclRefType>(type)) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); if( auto structDeclRef = declRef.as<StructDecl>()) { return (unsigned int)getFields( @@ -452,7 +455,7 @@ SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflect if(auto declRefType = as<DeclRefType>(type)) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); if( auto structDeclRef = declRef.as<StructDecl>()) { auto fields = getFields( @@ -476,7 +479,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) } else if( auto vectorType = as<VectorExpressionType>(type)) { - return (size_t) getIntVal(vectorType->elementCount); + return (size_t) getIntVal(vectorType->getElementCount()); } return 0; @@ -493,15 +496,15 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy } else if( auto parameterGroupType = as<ParameterGroupType>(type)) { - return convert(parameterGroupType->elementType); + return convert(parameterGroupType->getElementType()); } else if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type)) { - return convert(structuredBufferType->elementType); + return convert(structuredBufferType->getElementType()); } else if( auto vectorType = as<VectorExpressionType>(type)) { - return convert(vectorType->elementType); + return convert(vectorType->getElementType()); } else if( auto matrixType = as<MatrixExpressionType>(type)) { @@ -543,7 +546,7 @@ SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inTy } else if(auto vectorType = as<VectorExpressionType>(type)) { - return (unsigned int) getIntVal(vectorType->elementCount); + return (unsigned int) getIntVal(vectorType->getElementCount()); } else if( const auto basicType = as<BasicExpressionType>(type) ) { @@ -564,12 +567,12 @@ SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* in } else if(auto vectorType = as<VectorExpressionType>(type)) { - type = vectorType->elementType; + type = vectorType->getElementType(); } if(auto basicType = as<BasicExpressionType>(type)) { - switch (basicType->baseType) + switch (basicType->getBaseType()) { #define CASE(BASE, TAG) \ case BaseType::BASE: return SLANG_SCALAR_TYPE_##TAG @@ -606,7 +609,7 @@ SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionTyp if (!type) return 0; if (auto declRefType = as<DeclRefType>(type)) { - return getUserAttributeCount(declRefType->declRef.getDecl()); + return getUserAttributeCount(declRefType->getDeclRef().getDecl()); } return 0; } @@ -616,7 +619,7 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangR if (!type) return 0; if (auto declRefType = as<DeclRefType>(type)) { - return getUserAttributeByIndex(declRefType->declRef.getDecl(), index); + return getUserAttributeByIndex(declRefType->getDeclRef().getDecl(), index); } return 0; } @@ -626,10 +629,10 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName if (!type) return 0; if (auto declRefType = as<DeclRefType>(type)) { - ASTBuilder* astBuilder = declRefType->getASTBuilder(); + ASTBuilder* astBuilder = declRefType->getASTBuilderForReflection(); auto globalSession = astBuilder->getGlobalSession(); - return findUserAttributeByName(globalSession, declRefType->declRef.getDecl(), name); + return findUserAttributeByName(globalSession, declRefType->getDeclRef().getDecl(), name); } return 0; } @@ -714,7 +717,7 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) if( auto declRefType = as<DeclRefType>(type) ) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); // Don't return a name for auto-generated anonymous types // that represent `cbuffer` members, etc. @@ -778,13 +781,13 @@ SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangRefle if (auto textureType = as<TextureTypeBase>(type)) { - return convert(textureType->elementType); + return convert(textureType->getElementType()); } // TODO: need a better way to handle this stuff... #define CASE(TYPE, SHAPE, ACCESS) \ else if(as<TYPE>(type)) do { \ - return convert(as<TYPE>(type)->elementType); \ + return convert(as<TYPE>(type)->getElementType()); \ } while(0) // TODO: structured buffer needs to expose type layout! @@ -1132,7 +1135,7 @@ SLANG_API SlangInt spReflectionType_getSpecializedTypeArgCount(SlangReflectionTy auto specializedType = as<ExistentialSpecializedType>(type); if(!specializedType) return 0; - return specializedType->args.getCount(); + return specializedType->getArgCount(); } SLANG_API SlangReflectionType* spReflectionType_getSpecializedTypeArgType(SlangReflectionType* inType, SlangInt index) @@ -1144,9 +1147,9 @@ SLANG_API SlangReflectionType* spReflectionType_getSpecializedTypeArgType(SlangR if(!specializedType) return nullptr; if(index < 0) return nullptr; - if(index >= specializedType->args.getCount()) return nullptr; + if(index >= specializedType->getArgCount()) return nullptr; - auto argType = as<Type>(specializedType->args[index].val); + auto argType = as<Type>(specializedType->getArg(index).val); return convert(argType); } @@ -1405,7 +1408,7 @@ namespace Slang { if(auto declRefType = as<DeclRefType>(type)) { - if(declRefType->declRef.as<InterfaceDecl>()) + if(declRefType->getDeclRef().as<InterfaceDecl>()) { return declRefType; } diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index c28c8a6d6..351b6f519 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -39,20 +39,82 @@ struct SerialTypeInfo<SyntaxClass<T>> } }; +// MatrixCoord can just go as is +template <> +struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {}; + +inline void serializePointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial) +{ + if (ptrValue) + ptrValue = ptrValue->resolve(); + *(SerialIndex*)outSerial = writer->addPointer(ptrValue); +} + +inline void deserializePointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr, Val* unusedForResolution) +{ + SLANG_UNUSED(unusedForResolution); + + auto val = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<Val>(); + *(Val**)outPtr = val; + if (val) + { + SLANG_ASSERT(as<Val>(val)); + PostSerializationFixUp fixup; + fixup.kind = PostSerializationFixUpKind::ValPtr; + fixup.addressToModify = outPtr; + reader->getFixUps().add(fixup); + } +} template <typename T> struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {}; -// MatrixCoord can just go as is +// ValNodeOperand template <> -struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {}; +struct SerialTypeInfo<ValNodeOperand> +{ + typedef ValNodeOperand NativeType; + struct SerialType + { + int8_t kind; + int64_t val; + }; + enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) }; + + static void toSerial(SerialWriter* writer, const void* native, void* serial) + { + auto& src = *(const NativeType*)native; + auto& dst = *(SerialType*)serial; + dst.kind = int8_t(src.kind); + if (src.kind == ValNodeOperandKind::ConstantValue) + dst.val = src.values.intOperand; + else if (src.kind == ValNodeOperandKind::ValNode) + serializePointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val); + else + serializePointerValue(writer, src.values.nodeOperand, (SerialIndex*)&dst.val); + } + static void toNative(SerialReader* reader, const void* serial, void* native) + { + auto& dst = *(NativeType*)native; + auto& src = *(const SerialType*)serial; + + // Initialize + dst = NativeType(); + dst.kind = ValNodeOperandKind(src.kind); + if (dst.kind == ValNodeOperandKind::ConstantValue) + dst.values.intOperand = int64_t(src.val); + else if (dst.kind == ValNodeOperandKind::ValNode) + deserializePointerValue(reader, (SerialIndex*)&src.val, (Val**)&dst.values.nodeOperand, (Val*)nullptr); + else + deserializePointerValue(reader, (SerialIndex*)&src.val, &dst.values.nodeOperand, (NodeBase*)nullptr); + } +}; // LookupResultItem SLANG_VALUE_TYPE_INFO(LookupResultItem) // QualType SLANG_VALUE_TYPE_INFO(QualType) - // LookupResult template <> struct SerialTypeInfo<LookupResult> @@ -151,10 +213,6 @@ struct SerialTypeInfo<Modifiers> } }; -// ASTNodeType -template <> -struct SerialTypeInfo<ASTNodeType> : public SerialConvertTypeInfo<ASTNodeType, uint16_t> {}; - // LookupResultItem_Breadcrumb::ThisParameterMode template <> struct SerialTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode> : public SerialConvertTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode, uint8_t> {}; @@ -170,6 +228,7 @@ struct SerialTypeInfo<RequirementWitness::Flavor> : public SerialConvertTypeInfo // RequirementWitness SLANG_VALUE_TYPE_INFO(RequirementWitness) + } // namespace Slang #endif diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index c75237896..293535b02 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -475,10 +475,6 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( // Set the sourceLocReader before doing de-serialize, such can lookup the remapped sourceLocs reader.getExtraObjects().set(sourceLocReader); - // Go through all of the AST nodes - // 1) Set the ASTBuilder on Type nodes - - // TODO(JS): // If modules can have more complicated relationships (like a two modules can refer to symbols // from each other), then we can make this work by @@ -492,12 +488,14 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( // For now if we assume a module can only access symbols from another module, and not the reverse. // So we just need to deserialize and we are done SLANG_RETURN_ON_FAIL(reader.deserializeObjects()); - + // Get the root node. It's at index 1 (0 is the null value). astRootNode = reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>(); - // 2) Add the extensions to the module mapTypeToCandidateExtensions cache - // 3) We need to fix the callback pointers for parsing + // Go through all AST nodes: + // 1) Add the extensions to the module mapTypeToCandidateExtensions cache + // 2) We need to fix the callback pointers for parsing + // 3) Register all `Val`s to the ASTBuilder's deduplication map. { ModuleDecl* moduleDecl = as<ModuleDecl>(astRootNode); @@ -505,6 +503,8 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( // Maps from keyword name name to index in (syntaxParseInfos) // Will be filled in lazily if needed (for SyntaxDecl setup) Dictionary<Name*, Index> syntaxKeywordDict; + + OrderedDictionary<Val*, List<Val**>> valUses; // Get the parse infos const auto syntaxParseInfos = getSyntaxParseInfos(); @@ -512,21 +512,18 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( for (auto& obj : reader.getObjects()) { + if (obj.m_kind == SerialTypeKind::NodeBase) { NodeBase* nodeBase = (NodeBase*)obj.m_ptr; SLANG_ASSERT(nodeBase); - if (Type* type = dynamicCast<Type>(nodeBase)) - { - type->_setASTBuilder(astBuilder); - } - else if (ExtensionDecl* extensionDecl = dynamicCast<ExtensionDecl>(nodeBase)) + if (ExtensionDecl* extensionDecl = dynamicCast<ExtensionDecl>(nodeBase)) { if (auto targetDeclRefType = as<DeclRefType>(extensionDecl->targetType)) { // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>()) + if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); @@ -567,6 +564,47 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( syntaxDecl->parseUserData = const_cast<ReflectClassInfo*>(syntaxDecl->syntaxClass.classInfo); } } + else if (Val* val = dynamicCast<Val>(nodeBase)) + { + valUses[val] = List<Val**>(); + } + } + } + // Go through fixup locations and deduplicate Vals. + // This is needed because we currently the same Val can be serialized multiple times + // in different modules. If we have a type defined in Module A and used in Module B, + // then both serialized Module A and Module B will contain a Type Val object that refers to A. + // When we load B, we should resolve those type references to the existing Type val instead. + // This step can be avoided if we can run deduplication while deserializing, which + // requires a different way of handling Val objects. + for (auto fixup : reader.getFixUps()) + { + if (fixup.kind == PostSerializationFixUpKind::ValPtr) + { + auto list = valUses.tryGetValue(*(Val**)fixup.addressToModify); + if (list) + list->add((Val**)fixup.addressToModify); + } + } + SLANG_AST_BUILDER_RAII(astBuilder); + for (auto& valUseList : valUses) + { + auto val = valUseList.key; + auto desc = val->getDesc(); + astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val); + } + for (auto& valUseList : valUses) + { + auto val = valUseList.key; + auto newVal = val->resolve(); + if (val != newVal) + { + astBuilder->m_cachedNodes[val->getDesc()] = newVal; + for (auto use : valUseList.value) + { + if (*use != newVal) + *use = newVal; + } } } } diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h index 971d45197..c4b20c5b9 100644 --- a/source/slang/slang-serialize-type-info.h +++ b/source/slang/slang-serialize-type-info.h @@ -3,6 +3,7 @@ #define SLANG_SERIALIZE_TYPE_INFO_H #include "slang-serialize.h" + namespace Slang { /* For the serialization system to work we need to defined how native types are represented in the serialized format. @@ -87,7 +88,6 @@ struct SerialTypeInfo<float> : public SerialBasicTypeInfo<float> {}; template <> struct SerialTypeInfo<double> : public SerialBasicTypeInfo<double> {}; - // Fixed arrays template <typename T, size_t N> @@ -154,9 +154,26 @@ struct SerialTypeInfo<T, typename std::enable_if<std::is_enum<T>::value>::type> : public SerialIdentityTypeInfo<T> {}; +class Val; + // Pointer -// Could handle different pointer base types with some more template magic here, but instead went with Pointer type to keep -// things simpler. + +template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type> +void serializePointerValue(SerialWriter* writer, T* ptrValue, SerialIndex* outSerial) +{ + static_assert(!IsBaseOf<Val, T>::Value); + *(SerialIndex*)outSerial = writer->addPointer(ptrValue); +} + +template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type> +void deserializePointerValue(SerialReader* reader, SerialIndex* inSerial, void* outPtr, T* unusedForResolution) +{ + static_assert(!IsBaseOf<Val, T>::Value); + + SLANG_UNUSED(unusedForResolution); + *(T**)outPtr = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>(); +} + template <typename T> struct SerialTypeInfo<T*> { @@ -166,11 +183,13 @@ struct SerialTypeInfo<T*> static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) { - *(SerialType*)outSerial = writer->addPointer(*(T**)inNative); + auto ptrToWrite = *(T**)inNative; + serializePointerValue(writer, ptrToWrite, (SerialIndex*)outSerial); } + static void toNative(SerialReader* reader, const void* inSerial, void* outNative) { - *(T**)outNative = reader->getPointer(*(const SerialType*)inSerial).dynamicCast<T>(); + deserializePointerValue(reader, (SerialIndex*)inSerial, outNative, (T*)nullptr); } }; @@ -257,74 +276,8 @@ struct SerialTypeInfo<String> }; // Dictionary -template <typename KEY, typename VALUE> -struct SerialTypeInfo<Dictionary<KEY, VALUE>> -{ - typedef Dictionary<KEY, VALUE> NativeType; - struct SerialType - { - SerialIndex keys; ///< Index an array - SerialIndex values; ///< Index an array - }; - - typedef typename SerialTypeInfo<KEY>::SerialType KeySerialType; - typedef typename SerialTypeInfo<VALUE>::SerialType ValueSerialType; - - enum { SerialAlignment = SLANG_ALIGN_OF(SerialIndex) }; - - static void toSerial(SerialWriter* writer, const void* native, void* serial) - { - auto& src = *(const NativeType*)native; - auto& dst = *(SerialType*)serial; - - List<KeySerialType> keys; - List<ValueSerialType> values; - - Index count = Index(src.getCount()); - keys.setCount(count); - values.setCount(count); - - if (writer->getFlags() & SerialWriter::Flag::ZeroInitialize) - { - ::memset(keys.getBuffer(), 0, count * sizeof(KeySerialType)); - ::memset(values.getBuffer(), 0, count * sizeof(ValueSerialType)); - } - - Index i = 0; - for (const auto& pair : src) - { - SerialTypeInfo<KEY>::toSerial(writer, &pair.key, &keys[i]); - SerialTypeInfo<VALUE>::toSerial(writer, &pair.value, &values[i]); - i++; - } - - // When we add the array it is already converted to a serializable type, so add as SerialArray - dst.keys = writer->addSerialArray<KEY>(keys.getBuffer(), count); - dst.values = writer->addSerialArray<VALUE>(values.getBuffer(), count); - } - static void toNative(SerialReader* reader, const void* serial, void* native) - { - auto& src = *(const SerialType*)serial; - auto& dst = *(NativeType*)native; - - // Clear it - dst = NativeType(); - - List<KEY> keys; - List<VALUE> values; - - reader->getArray(src.keys, keys); - reader->getArray(src.values, values); - - SLANG_ASSERT(keys.getCount() == values.getCount()); - - const Index count = keys.getCount(); - for (Index i = 0; i < count; ++i) - { - dst.add(keys[i], values[i]); - } - } -}; +// Note: We leave out SerialTypeInfo specialization for Dictionary, because +// it does not have determinstic ordering. // OrderedDictionary template <typename KEY, typename VALUE> diff --git a/source/slang/slang-serialize-types.cpp b/source/slang/slang-serialize-types.cpp index 6c4512b1d..a091a2850 100644 --- a/source/slang/slang-serialize-types.cpp +++ b/source/slang/slang-serialize-types.cpp @@ -48,7 +48,8 @@ struct ByteReader const int numPrefixBytes = encodeUnicodePointToUTF8(len, prefixBytes); const Index baseIndex = stringTable.getCount(); - stringTable.setCount(baseIndex + numPrefixBytes + len); + auto newCount = baseIndex + numPrefixBytes + len; + stringTable.growToCount(newCount); char* dst = stringTable.begin() + baseIndex; diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp index 2e8d6c6ba..1f5b6942d 100644 --- a/source/slang/slang-serialize.cpp +++ b/source/slang/slang-serialize.cpp @@ -2,6 +2,7 @@ #include "slang-serialize.h" #include "slang-ast-base.h" +#include "slang-ast-builder.h" namespace Slang { @@ -204,14 +205,14 @@ bool SerialClasses::isOk() const SerialClasses::SerialClasses(): - m_arena(2048) + m_arena(2097152) { } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags flags) - : m_arena(2048) + : m_arena(2097152) , m_classes(classes) , m_filter(filter) , m_flags(flags) diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index cc617034d..ce7bfa87b 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -211,6 +211,17 @@ protected: void* m_objects[Index(SerialExtraType::CountOf)]; }; +enum class PostSerializationFixUpKind +{ + ValPtr, +}; + +struct PostSerializationFixUp +{ + PostSerializationFixUpKind kind; + void* addressToModify; +}; + /* This class is the interface used by toNative implementations to recreate a type. */ class SerialReader : public RefObject { @@ -240,6 +251,8 @@ public: /// Get the entries list const List<const Entry*>& getEntries() const { return m_entries; } + List<PostSerializationFixUp>& getFixUps() { return m_fixUps; } + /// Access the objects list /// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the RefObject* via addScope List<SerialPointer>& getObjects() { return m_objects; } @@ -277,6 +290,8 @@ protected: SerialObjectFactory* m_objectFactory; SerialClasses* m_classes; ///< Information used to deserialize + + List<PostSerializationFixUp> m_fixUps; }; // --------------------------------------------------------------------------- diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 227e468d6..ae44e0c70 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -284,14 +284,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { if(auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(subtypeWitness)) { - if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.as<InheritanceDecl>()) + if(auto inheritanceDeclRef = declaredSubtypeWitness->getDeclRef().as<InheritanceDecl>()) { // A conformance that was declared as part of an inheritance clause // will have built up a dictionary of the satisfying declarations // for each of its requirements. RequirementWitness requirementWitness; auto witnessTable = inheritanceDeclRef.getDecl()->witnessTable; - if(witnessTable && witnessTable->requirementDictionary.tryGetValue(requirementKey, requirementWitness)) + if(witnessTable && witnessTable->getRequirementDictionary().tryGetValue(requirementKey, requirementWitness)) { // The `inheritanceDeclRef` has substitutions applied to it that // *aren't* present in the `requirementWitness`, because it was @@ -338,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // So, in order to get the *right* end result, we need to apply // the substitutions from the inheritance decl-ref to the witness. // - requirementWitness = requirementWitness.specialize(astBuilder, inheritanceDeclRef.getSubst()); + requirementWitness = requirementWitness.specialize(astBuilder, SubstitutionSet(inheritanceDeclRef)); return requirementWitness; } @@ -346,17 +346,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness)) { - if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->midToSup)) + if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->getMidToSup())) { - auto midKey = declaredSubtypeWitnessMidToSup->declRef; - auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey.getDecl()); + auto midKey = declaredSubtypeWitnessMidToSup->getDeclRef(); + auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->getSubToMid()), midKey.getDecl()); if (midWitness.getFlavor() == RequirementWitness::Flavor::witnessTable) { auto table = midWitness.getWitnessTable(); RequirementWitness result; - if (table->requirementDictionary.tryGetValue(requirementKey, result)) + if (table->getRequirementDictionary().tryGetValue(requirementKey, result)) { - result = result.specialize(astBuilder, midKey.getSubst()); + result = result.specialize(astBuilder, SubstitutionSet(midKey)); } return result; } @@ -364,15 +364,32 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else if (auto extractFromConjunctionTypeWitness = as<ExtractFromConjunctionSubtypeWitness>(subtypeWitness)) { - if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->conjunctionWitness)) + if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->getConjunctionWitness())) { auto componentWitness = as<SubtypeWitness>( conjunctionTypeWitness->getComponentWitness( - extractFromConjunctionTypeWitness->indexInConjunction)); + extractFromConjunctionTypeWitness->getIndexInConjunction())); return tryLookUpRequirementWitness(astBuilder, componentWitness, requirementKey); } } + + // If we are looking for `ThisType`, just return subtype. + if (as<ThisTypeDecl>(requirementKey)) + { + RequirementWitness result; + result.m_flavor = RequirementWitness::Flavor::val; + result.m_val = subtypeWitness->getSub(); + return result; + } + // If we are looking for `ThisTypeConstraint`, just return the witness itself. + if (as<ThisTypeConstraintDecl>(requirementKey)) + { + RequirementWitness result; + result.m_flavor = RequirementWitness::Flavor::val; + result.m_val = subtypeWitness; + return result; + } // TODO: should handle the transitive case here too return RequirementWitness(); @@ -384,125 +401,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt void WitnessTable::add(Decl* decl, RequirementWitness const& witness) { - SLANG_ASSERT(!requirementDictionary.containsKey(decl)); - - requirementDictionary.add(decl, witness); - } - - // - - static Type* ExtractGenericArgType(Val* val) - { - auto type = as<Type>(val); - SLANG_RELEASE_ASSERT(type); - return type; - } - - static IntVal* ExtractGenericArgInteger(Val* val) - { - auto intVal = as<IntVal>(val); - SLANG_RELEASE_ASSERT(intVal); - return intVal; - } - - DeclRef<Decl> createDefaultSubstitutionsIfNeeded( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - DeclRef<Decl> declRef) - { - // It is possible that `declRef` refers to a generic type, - // but does not specify arguments for its generic parameters. - // (E.g., this happens when referring to a generic type from - // within its own member functions). To handle this case, - // we will construct a default specialization at the use - // site if needed. - // - // This same logic should also apply to declarations nested - // more than one level inside of a generic (e.g., a `typdef` - // inside of a generic `struct`). - // - // Similarly, it needs to work for multiple levels of - // nested generics. - // - - // First, we collect all the generic parents. - ShortList<GenericDecl*> genericParents; - Decl* dd = declRef.getDecl(); - for (;;) - { - Decl* childDecl = dd; - Decl* parentDecl = dd->parentDecl; - if (!parentDecl) - break; - - dd = parentDecl; - - if (auto genericParentDecl = as<GenericDecl>(parentDecl)) - { - // Don't specialize any parameters of a generic. - if (childDecl != genericParentDecl->inner) - break; - genericParents.add(genericParentDecl); - } - } - - - Substitutions* outerSubst = nullptr; - for (Index i = genericParents.getCount()-1; i>=0; i--) - { - Decl* childDecl = genericParents[i]->inner; - Decl* parentDecl = genericParents[i]; - - if(auto genericParentDecl = as<GenericDecl>(parentDecl)) - { - // Don't specialize any parameters of a generic. - if(childDecl != genericParentDecl->inner) - break; - - // We have a generic ancestor, but do we have an substitutions for it? - GenericSubstitution* foundSubst = nullptr; - for(auto s = declRef.getSubst(); s; s = s->getOuter()) - { - auto genSubst = as<GenericSubstitution>(s); - if(!genSubst) - continue; - - if(genSubst->getGenericDecl() != genericParentDecl) - continue; - - // Okay, we found a matching substitution, - // so we just grab the args from the matching subst instead. - foundSubst = genSubst; - if (foundSubst->getOuter() != outerSubst) - { - foundSubst = astBuilder->getOrCreateGenericSubstitution( - outerSubst, foundSubst->getGenericDecl(), foundSubst->getArgs()); - } - - break; - } - - if(!foundSubst) - { - Substitutions* newSubst = createDefaultSubstitutionsForGeneric( - astBuilder, - semantics, - genericParentDecl, - outerSubst); - outerSubst = newSubst; - } - else - { - outerSubst = foundSubst; - } - } - } - - if(!outerSubst) - return declRef; - - int diff = 0; - return declRef.substituteImpl(astBuilder, outerSubst, &diff); + m_requirements.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness)); + m_requirementDictionary.add(decl, witness); } // TODO: need to figure out how to unify this with the logic @@ -511,245 +411,73 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ASTBuilder* astBuilder, DeclRef<Decl> declRef) { - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); - if (auto builtinMod = declRef.getDecl()->findModifier<BuiltinTypeModifier>()) { - auto type = astBuilder->getOrCreate<BasicExpressionType>(builtinMod->tag); - type->declRef = declRef; + // Always create builtin types in global AST builder. + if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder) + return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef); + + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + auto type = astBuilder->getOrCreate<BasicExpressionType>(declRef.declRefBase); return type; } else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>()) { - GenericSubstitution* subst = nullptr; - for(auto s = declRef.getSubst(); s; s = s->getOuter()) - { - if(auto genericSubst = as<GenericSubstitution>(s)) - { - subst = genericSubst; - break; - } - } + // Always create builtin types in global AST builder. + if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder) + return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef); - if (magicMod->magicName == "SamplerState") - { - auto type = astBuilder->getOrCreate<SamplerStateType>(SamplerStateFlavor(magicMod->tag)); - type->declRef = declRef; - return type; - } - else if (magicMod->magicName == "Vector") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 2); - auto vecType = astBuilder->getOrCreate<VectorExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1])); - vecType->declRef = declRef; - vecType->elementType = ExtractGenericArgType(subst->getArgs()[0]); - vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]); - return vecType; - } - else if (magicMod->magicName == "ArrayType") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 2); - auto vecType = astBuilder->getOrCreate<ArrayExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1])); - vecType->declRef = declRef; - return vecType; - } - else if (magicMod->magicName == "Matrix") + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice()); + if (!classInfo.classInfo) { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 3); - auto matType = astBuilder->getOrCreate<MatrixExpressionType>( - ExtractGenericArgType(subst->getArgs()[0]), - ExtractGenericArgInteger(subst->getArgs()[1]), - ExtractGenericArgInteger(subst->getArgs()[2])); - matType->declRef = declRef; - return matType; + SLANG_UNEXPECTED("unhandled type"); } - else if (magicMod->magicName == "TensorViewType") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); - auto vecType = astBuilder->getOrCreate<TensorViewType>(ExtractGenericArgType(subst->getArgs()[0])); - vecType->declRef = declRef; - return vecType; - } - else if (magicMod->magicName == "Texture") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); - auto textureTag = TextureFlavor(magicMod->tag); - Val* sampleCount = nullptr; - if (textureTag.isMultisample()) + ValNodeDesc nodeDesc = {}; + nodeDesc.type = (ASTNodeType)classInfo.classInfo->m_classId; + nodeDesc.operands.add(ValNodeOperand(declRef)); + nodeDesc.init(); + NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc, [&]() { - if (subst->getArgs().getCount() >= 2) - sampleCount = ExtractGenericArgInteger(subst->getArgs().getLast()); - } - auto textureType = astBuilder->getOrCreate<TextureType>( - textureTag, - ExtractGenericArgType(subst->getArgs()[0]), - sampleCount); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->magicName == "TextureSampler") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); - auto textureType = astBuilder->getOrCreate<TextureSamplerType>( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->getArgs()[0])); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->magicName == "GLSLImageType") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); - auto textureType = astBuilder->getOrCreate<GLSLImageType>( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->getArgs()[0])); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->magicName == "FeedbackType") + auto resultNode = as<DeclRefType>(classInfo.createInstance(astBuilder)); + resultNode->setOperands(declRef); + return resultNode; + }); + if (!type) { - SLANG_ASSERT(subst == nullptr); - auto type = astBuilder->getOrCreateWithDefaultCtor<FeedbackType>(magicMod->tag); - type->declRef = declRef; - type->kind = FeedbackType::Kind(magicMod->tag); - return type; + SLANG_UNEXPECTED("constructor failure"); } - // TODO: eventually everything should follow this pattern, - // and we can drive the dispatch with a table instead - // of this ridiculously slow `if` cascade. - - #define CASE(n, T) \ - else if (magicMod->magicName == #n) \ - { \ - auto type = astBuilder->getOrCreateWithDefaultCtor<T>( \ - declRef.getDecl(), declRef.getSubst()); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(HLSLInputPatchType, HLSLInputPatchType) - CASE(HLSLOutputPatchType, HLSLOutputPatchType) - - #undef CASE - - #define CASE(n, T) \ - else if (magicMod->magicName == #n) \ - { \ - SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); \ - auto type = \ - astBuilder->getOrCreateWithDefaultCtor<T>(ExtractGenericArgType(subst->getArgs()[0])); \ - type->elementType = ExtractGenericArgType(subst->getArgs()[0]); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(ConstantBuffer, ConstantBufferType) - CASE(TextureBuffer, TextureBufferType) - CASE(ParameterBlockType, ParameterBlockType) - CASE(GLSLInputParameterGroupType, GLSLInputParameterGroupType) - CASE(GLSLOutputParameterGroupType, GLSLOutputParameterGroupType) - CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) - - CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) - CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) - CASE(HLSLRasterizerOrderedStructuredBufferType, HLSLRasterizerOrderedStructuredBufferType) - CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) - CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) - - CASE(HLSLPointStreamType, HLSLPointStreamType) - CASE(HLSLLineStreamType, HLSLLineStreamType) - CASE(HLSLTriangleStreamType, HLSLTriangleStreamType) - - #undef CASE - - // "magic" builtin types which have no generic parameters - #define CASE(n,T) \ - else if(magicMod->magicName == #n) { \ - auto type = astBuilder->getOrCreate<T>(); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) - CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) - CASE(HLSLRasterizerOrderedByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) - CASE(UntypedBufferResourceType, UntypedBufferResourceType) - - CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) - - #undef CASE - - else + auto declRefType = dynamicCast<DeclRefType>(type); + if (!declRefType) { - auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice()); - if (!classInfo.classInfo) - { - SLANG_UNEXPECTED("unhandled type"); - } - - NodeBase* type = classInfo.createInstance(astBuilder); - if (!type) - { - SLANG_UNEXPECTED("constructor failure"); - } - - auto declRefType = dynamicCast<DeclRefType>(type); - if (!declRefType) - { - SLANG_UNEXPECTED("expected a declaration reference type"); - } - declRefType->declRef = declRef; - return declRefType; + SLANG_UNEXPECTED("expected a declaration reference type"); } + return declRefType; + } + else if (as<ThisTypeDecl>(declRef.getDecl()) && as<DirectDeclRef>(declRef.declRefBase)) + { + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + + return astBuilder->getOrCreate<ThisType>(declRef.declRefBase); } else { + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + return astBuilder->getOrCreate<DeclRefType>(declRef.declRefBase); } } // - GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst) + Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst) { - for(Substitutions* s = subst; s; s = s->getOuter()) - { - if(auto genericSubst = as<GenericSubstitution>(s)) - return genericSubst; - } - return nullptr; - } - - - // DeclRefBase - - Type* DeclRefBase::substitute(ASTBuilder* astBuilder, Type* type) const - { - // Note that type can be nullptr, and so this function can return nullptr (although only correctly when no substitutions) - - // No substitutions? Easy. - if (!substitutions) - return type; - - SLANG_ASSERT(type); - - // Otherwise we need to recurse on the type structure - // and apply substitutions where it makes sense - return Slang::as<Type>(type->substitute(astBuilder, substitutions)); - } - - DeclRefBase* DeclRefBase::substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const - { - if(!substitutions) - return declRef; - - int diff = 0; - return declRef->substituteImpl(astBuilder, substitutions, &diff); - } - - SubstExpr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const - { - return SubstExpr<Expr>(expr, substitutions); + if (!subst.declRef) + return Val::OperandView<Val>(); + if (auto genApp = subst.findGenericAppDeclRef()) + return genApp->getArgs(); + return Val::OperandView<Val>(); } SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr) @@ -764,7 +492,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto declRefBase = declRef.substituteImpl(astBuilder, substs, &diff); - return astBuilder->getSpecializedDeclRef<Decl>(declRefBase.getDecl(), declRefBase.getSubst()); + return declRefBase; } Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type) @@ -790,332 +518,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - Substitutions* specializeSubstitutionsShallow( - ASTBuilder* astBuilder, - Substitutions* substToSpecialize, - Substitutions* substsToApply, - Substitutions* restSubst, - int* ioDiff) - { - SLANG_ASSERT(substToSpecialize); - return substToSpecialize->applySubstitutionsShallow(astBuilder, substsToApply, restSubst, ioDiff); - } - - // Construct new substitutions to apply to a declaration, - // based on a provided substitution set to be applied - Substitutions* specializeSubstitutions( - ASTBuilder* astBuilder, - Decl* declToSpecialize, - Substitutions* substsToSpecialize, - Substitutions* substsToApply, - int* ioDiff) - { - // No declaration? Then nothing to specialize. - if(!declToSpecialize) - return nullptr; - - // No (remaining) substitutions to apply? Then we are done. - if(!substsToApply) - return substsToSpecialize; - - // Walk the hierarchy of the declaration to determine what specializations might apply. - // We assume that the `substsToSpecialize` must be aligned with the ancestor - // hierarchy of `declToSpecialize` such that if, e.g., the `declToSpecialize` is - // nested directly in a generic, then `substToSpecialize` will either start with - // the corresponding `GenericSubstitution` or there will be *no* generic substitutions - // corresponding to that decl. - for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->parentDecl) - { - if(auto ancestorGenericDecl = as<GenericDecl>(ancestorDecl)) - { - // The declaration is nested inside a generic. - // Does it already have a specialization for that generic? - if(auto specGenericSubst = as<GenericSubstitution>(substsToSpecialize)) - { - if(specGenericSubst->getGenericDecl() == ancestorGenericDecl) - { - // Yes. We have an existing specialization, so we will - // keep one matching it in place. - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorGenericDecl->parentDecl, - specGenericSubst->getOuter(), - substsToApply, - &diff); - - auto firstSubst = specializeSubstitutionsShallow( - astBuilder, - specGenericSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - } - - // If the declaration is not already specialized - // for the given generic, then see if we are trying - // to *apply* such specializations to it. - // - // TODO: The way we handle things right now with - // "default" specializations, this case shouldn't - // actually come up. - // - for(auto s = substsToApply; s; s = s->getOuter()) - { - auto appGenericSubst = as<GenericSubstitution>(s); - if(!appGenericSubst) - continue; - - if(appGenericSubst->getGenericDecl() != ancestorGenericDecl) - continue; - - // The substitutions we are applying are trying - // to specialize this generic, but we don't already - // have a generic substitution in place. - // We will need to create one. - - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorGenericDecl->parentDecl, - substsToSpecialize, - substsToApply, - &diff); - - GenericSubstitution* firstSubst = astBuilder->getOrCreateGenericSubstitution( - restSubst, ancestorGenericDecl, appGenericSubst->getArgs()); - - (*ioDiff)++; - return firstSubst; - } - } - else if(auto ancestorInterfaceDecl = as<InterfaceDecl>(ancestorDecl)) - { - // The task is basically the same as for the generic case: - // We want to see if there is any existing substitution that - // applies to this declaration, and use that if possible. - - // The declaration is nested inside a generic. - // Does it already have a specialization for that generic? - if(auto specThisTypeSubst = as<ThisTypeSubstitution>(substsToSpecialize)) - { - if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl) - { - // Yes. We have an existing specialization, so we will - // keep one matching it in place. - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorInterfaceDecl->parentDecl, - specThisTypeSubst->getOuter(), - substsToApply, - &diff); - - auto firstSubst = specializeSubstitutionsShallow( - astBuilder, - specThisTypeSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - } - - // Otherwise, check if we are trying to apply - // a this-type substitution to the given interface - // - // Note: We want to skip the ThisTypeSubstitution that specializes - // declToSpecialize itself (when declToSpecialize is an interface - // decl and the subst specializes it), and only pull the - // ThisTypeSubstitution when the decl is referencing a child of - // the interface decl being specialized. This is because - // by default an interface declref type is a "free" existential - // type that shouldn't be specialized by someone else, unless - // there is an "implicit" ThisType reference preceeding a child - // reference. - if (declToSpecialize != ancestorInterfaceDecl) - { - for (auto s = substsToApply; s; s = s->getOuter()) - { - auto appThisTypeSubst = as<ThisTypeSubstitution>(s); - if (!appThisTypeSubst) - continue; - - if (appThisTypeSubst->interfaceDecl != ancestorInterfaceDecl) - continue; - - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorInterfaceDecl->parentDecl, - substsToSpecialize, - substsToApply, - &diff); - - ThisTypeSubstitution* firstSubst = astBuilder->getOrCreateThisTypeSubstitution( - ancestorInterfaceDecl, appThisTypeSubst->witness, restSubst); - - (*ioDiff)++; - return firstSubst; - } - } - } - } - - // If we reach here then we've walked the full hierarchy up from - // `declToSpecialize` and either didn't run into an generic/interface - // declarations, or we didn't find any attempt to specialize them - // in either substitution. - // - // As an invariant, there should *not* be any generic or this-type - // substitutions in `substToSpecialize`, because otherwise they - // would be specializations that don't actually apply to the given - // declaration. - // - // Note: this does *not* mean that `substsToApply` doesn't have - // any generic or this-type substitutions; it just means that none - // of them were applicable. - // - return nullptr; - } - - DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const - { - // Nothing to do when we have no declaration. - if(!decl) - return const_cast<DeclRefBase*>(this); - - // Apply the given substitutions to any specializations - // that have already been applied to this declaration. - int diff = 0; - - auto substSubst = specializeSubstitutions( - astBuilder, - decl, - substitutions, - substSet.substitutions, - &diff); - - if (!diff) - return const_cast<DeclRefBase*>(this); - - *ioDiff += diff; - - DeclRefBase* substDeclRef = astBuilder->getSpecializedDeclRef(decl, substSubst); - - // TODO: The old code here used to try to translate a decl-ref - // to an associated type in a decl-ref for the concrete type - // in a particular implementation. - // - // I have only kept that logic in `DeclRefType::SubstituteImpl`, - // but it may turn out it is needed here too. - - return substDeclRef; - } - - bool DeclRefBase::_equalsValOverride(Val* val) - { - if (auto otherDeclRef = as<DeclRefBase>(val)) - return equals(otherDeclRef); - return false; - } - - // Check if this is an equivalent declaration reference to another - bool DeclRefBase::equals(DeclRefBase* declRef) const - { - if (!declRef) - return false; - if (decl != declRef->decl) - return false; - if (!SubstitutionSet(substitutions).equals(declRef->substitutions)) - return false; - - return true; - } - - // Convenience accessors for common properties of declarations - Name* DeclRefBase::getName() const - { - return decl->nameAndLoc.name; - } - SourceLoc DeclRefBase::getNameLoc() const - { - return decl->nameAndLoc.loc; - } - SourceLoc DeclRefBase::getLoc() const - { - return decl->loc; - } - - DeclRefBase* DeclRefBase::getParent(ASTBuilder* astBuilder) const - { - // Want access to the free function (the 'as' method by default gets priority) - // Can access as method with this->as because it removes any ambiguity. - using Slang::as; - - auto parentDecl = decl->parentDecl; - if (!parentDecl) - return nullptr; - - // Default is to apply the same set of substitutions/specializations - // to the parent declaration as were applied to the child. - Substitutions* substToApply = substitutions; - - if(auto interfaceDecl = as<InterfaceDecl>(decl)) - { - // The declaration being referenced is an `interface` declaration, - // and there might be a this-type substitution in place. - // A reference to the parent of the interface declaration - // should not include that substitution. - if(auto thisTypeSubst = as<ThisTypeSubstitution>(substToApply)) - { - if(thisTypeSubst->interfaceDecl == interfaceDecl) - { - // Strip away that specializations that apply to the interface. - substToApply = thisTypeSubst->getOuter(); - } - } - } - - if (auto parentGenericDecl = as<GenericDecl>(parentDecl)) - { - // The parent of this declaration is a generic, which means - // that the decl-ref to the current declaration might include - // substitutions that specialize the generic parameters. - // A decl-ref to the parent generic should *not* include - // those substitutions. - // - if(auto genericSubst = as<GenericSubstitution>(substToApply)) - { - if(genericSubst->getGenericDecl() == parentGenericDecl) - { - // Strip away the specializations that were applied to the parent. - substToApply = genericSubst->getOuter(); - } - } - } - - return astBuilder->getSpecializedDeclRef(parentDecl, substToApply); - } - - HashCode DeclRefBase::getHashCode() const - { - return combineHash(PointerHash<1>::getHashCode(decl), SubstitutionSet(substitutions).getHashCode()); - } - // IntVal IntegerLiteralValue getIntVal(IntVal* val) { if (auto constantVal = as<ConstantIntVal>(val)) { - return constantVal->value; + return constantVal->getValue(); } SLANG_UNEXPECTED("needed a known integer value"); //return 0; @@ -1125,14 +534,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // HLSLPatchType + Val* getGenericArg(DeclRef<Decl> declRef, Index index) + { + auto subst = SubstitutionSet(declRef).findGenericAppDeclRef(); + if (index < subst->getArgs().getCount()) + return subst->getArgs()[index]; + return nullptr; + } + Type* HLSLPatchType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(getGenericArg(getDeclRef(), 0)); } IntVal* HLSLPatchType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(getGenericArg(getDeclRef(), 1)); } // MeshOutputType @@ -1143,12 +560,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt Type* MeshOutputType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(getGenericArg(getDeclRef(), 0)); } IntVal* MeshOutputType::getMaxElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(getGenericArg(getDeclRef(), 1)); } // Constructors for types @@ -1174,17 +591,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef).as<TypeDefDecl>(); - return astBuilder->create<NamedExpressionType>(specializedDeclRef); + return astBuilder->getOrCreate<NamedExpressionType>(specializedDeclRef); } FuncType* getFuncType( ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) { - FuncType* funcType = astBuilder->create<FuncType>(); - - funcType->resultType = getResultType(astBuilder, declRef); - funcType->errorType = getErrorCodeType(astBuilder, declRef); + List<Type*> paramTypes; + auto resultType = getResultType(astBuilder, declRef); + auto errorType = getErrorCodeType(astBuilder, declRef); for (auto paramDeclRef : getParameters(astBuilder, declRef)) { auto paramDecl = paramDeclRef.getDecl(); @@ -1204,9 +620,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt paramType = astBuilder->getOutType(paramType); } } - funcType->paramTypes.add(paramType); + paramTypes.add(paramType); } + FuncType* funcType = astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType); return funcType; } @@ -1214,40 +631,34 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ASTBuilder* astBuilder, DeclRef<GenericDecl> const& declRef) { - return astBuilder->create<GenericDeclRefType>(declRef); + return astBuilder->getOrCreate<GenericDeclRefType>(declRef); } NamespaceType* getNamespaceType( ASTBuilder* astBuilder, DeclRef<NamespaceDeclBase> const& declRef) { - auto type = astBuilder->create<NamespaceType>(); - type->declRef = declRef; + auto type = astBuilder->getOrCreate<NamespaceType>(declRef); return type; } SamplerStateType* getSamplerStateType( ASTBuilder* astBuilder) { - return astBuilder->create<SamplerStateType>(); + return astBuilder->getSamplerStateType(); } - ThisTypeSubstitution* findThisTypeSubstitution( - const Substitutions* substs, + SubtypeWitness* findThisTypeWitness( + SubstitutionSet substs, InterfaceDecl* interfaceDecl) { - for(const Substitutions* s = substs; s; s = s->getOuter()) + auto lookupDeclRef = substs.findLookupDeclRef(); + if (!lookupDeclRef) + return nullptr; + if (lookupDeclRef->getSupDecl() == interfaceDecl) { - auto thisTypeSubst = as<ThisTypeSubstitution>(s); - if(!thisTypeSubst) - continue; - - if(thisTypeSubst->interfaceDecl != interfaceDecl) - continue; - - return const_cast<ThisTypeSubstitution*>(thisTypeSubst); + return lookupDeclRef->getWitness(); } - return nullptr; } @@ -1259,20 +670,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto substAssocTypeDecl = substDeclRef.getDecl(); - for (auto s = substDeclRef.getSubst(); s; s = s->getOuter()) + if (auto lookupDeclRef = SubstitutionSet(substDeclRef).findLookupDeclRef()) { - auto thisSubst = as<ThisTypeSubstitution>(s); - if (!thisSubst) - continue; - if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) { - if (thisSubst->interfaceDecl == interfaceDecl) + if (lookupDeclRef->getSupDecl() == interfaceDecl) { // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, thisSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, lookupDeclRef->getWitness(), requirementKey); switch (requirementWitness.getFlavor()) { default: @@ -1296,17 +703,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if (builtinReq->kind != BuiltinRequirementKind::DifferentialType) return nullptr; // Is the concrete type a Differential associated type? - auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub); + auto innerDeclRefType = as<DeclRefType>(lookupDeclRef->getWitness()->getSub()); if (!innerDeclRefType) return nullptr; - auto innerBuiltinReq = innerDeclRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>(); + auto innerBuiltinReq = innerDeclRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>(); if (!innerBuiltinReq) return nullptr; if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType) return nullptr; - if (!innerDeclRefType->declRef.equals(declRef)) + if (!innerDeclRefType->getDeclRef().equals(declRef)) { - auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->declRef); + auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->getDeclRef()); if (result) return result; } @@ -1320,119 +727,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - String DeclRefBase::toString() const - { - StringBuilder builder; - toText(builder); - return std::move(builder); - } - - // Prints a partially qualified type name with generic substitutions. - void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out) - { - // If there is a parent scope for the declaration, print it first. - // Exclude top-level namespaces like `tu0` or `core`. - if (decl->parentDecl && !Slang::as<ModuleDecl>(decl->parentDecl)) - { - auto parentGeneric = Slang::as<GenericDecl>(decl->parentDecl); - - // Exclude function or operator names. - // Avoids excessively verbose messages like `func<T>(func::T)` - if (!(parentGeneric && Slang::as<CallableDecl>(parentGeneric->inner))) - { - _printNestedDecl(substitutions, decl->parentDecl, out); - - // If the parent is a generic for this type, skip *this* type. - // Avoids duplicate types like `MyType<T>::MyType` - if (parentGeneric && parentGeneric->inner == decl) - return; - - out << "."; - } - } - // If we have a ThisTypeSubstitution to an interface decl, print the substituted sub - // type instead. - for (;;) - { - if (auto interfaceDecl = const_cast<InterfaceDecl*>(as<InterfaceDecl>(decl))) - { - if (auto thisSubst = findThisTypeSubstitution(substitutions, interfaceDecl)) - { - if (auto subTypeWitness = as<SubtypeWitness>(thisSubst->witness)) - { - out << subTypeWitness->sub; - break; - } - } - } - // Otherwise, just print this type's name. - auto name = decl->getName(); - if (name) - { - out << name->text; - } - break; - } - - // Look for generic substitutions on this type. - for (const Substitutions* subst = substitutions; subst; subst = subst->getOuter()) - { - auto genericSubstitution = Slang::as<GenericSubstitution>(subst); - if (!genericSubstitution) - continue; - - // If the substitution is for this type, print it. - if (genericSubstitution->getGenericDecl() == decl) - { - out << "<"; - bool isFirst = true; - for (const auto& it : genericSubstitution->getArgs()) - { - // Don't print out witnesses. - if (as<Witness>(it)) - continue; - if (!isFirst) - out << ", "; - isFirst = false; - it->toText(out); - } - out << ">"; - - break; - } - } - } - - void DeclRefBase::toText(StringBuilder& out) const - { - if (decl) - { - _printNestedDecl(substitutions, decl, out); - } - } - - bool SubstitutionSet::equals(const SubstitutionSet& substSet) const - { - if (substitutions == substSet.substitutions) - { - return true; - } - if (substitutions == nullptr || substSet.substitutions == nullptr) - { - return false; - } - return substitutions->equals(substSet.substitutions); - } - - HashCode SubstitutionSet::getHashCode() const - { - HashCode rs = 0; - if (substitutions) - rs = combineHash(rs, substitutions->getHashCode()); - return rs; - } - - ModuleDecl* getModuleDecl(Decl* decl) { for( auto dd = decl; dd; dd = dd->parentDecl ) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index a63a2471c..4addb1d53 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -22,22 +22,22 @@ namespace Slang inline bool areValsEqual(Val* left, Val* right) { if(!left || !right) return left == right; - return left->equalsVal(right); + return left->equals(right); } // inline BaseType getVectorBaseType(VectorExpressionType* vecType) { - auto basicExprType = as<BasicExpressionType>(vecType->elementType); - return basicExprType->baseType; + auto basicExprType = as<BasicExpressionType>(vecType->getElementType()); + return basicExprType->getBaseType(); } inline int getVectorSize(VectorExpressionType* vecType) { - auto constantVal = as<ConstantIntVal>(vecType->elementCount); + auto constantVal = as<ConstantIntVal>(vecType->getElementCount()); if (constantVal) - return (int) constantVal->value; + return (int) constantVal->getValue(); // TODO: what to do in this case? return 0; } @@ -52,15 +52,21 @@ namespace Slang DeclRef<AggTypeDecl> const& declRef, SemanticsVisitor* semantics); + // Returns the members of `genericInnerDecl`'s enclosing generic decl. + inline FilteredMemberRefList<Decl> getGenericMembers(ASTBuilder* astBuilder, DeclRef<Decl> genericInnerDecl, MemberFilterStyle filterStyle = MemberFilterStyle::All) + { + return FilteredMemberRefList<Decl>(astBuilder, genericInnerDecl.getParent().getDecl()->members, genericInnerDecl, filterStyle); + } + inline FilteredMemberRefList<Decl> getMembers(ASTBuilder* astBuilder, DeclRef<ContainerDecl> declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) { - return FilteredMemberRefList<Decl>(astBuilder, declRef.getDecl()->members, declRef.getSubst(), filterStyle); + return FilteredMemberRefList<Decl>(astBuilder, declRef.getDecl()->members, declRef, filterStyle); } template<typename T> inline FilteredMemberRefList<T> getMembersOfType(ASTBuilder* astBuilder, DeclRef<ContainerDecl> declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) { - return FilteredMemberRefList<T>(astBuilder, declRef.getDecl()->members, declRef.getSubst(), filterStyle); + return FilteredMemberRefList<T>(astBuilder, declRef.getDecl()->members, declRef, filterStyle); } void _foreachDirectOrExtensionMemberOfType( @@ -70,7 +76,7 @@ namespace Slang void (*callback)(DeclRefBase*, void*), void const* userData); - DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst); + DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl); template<typename T, typename F> inline void foreachDirectOrExtensionMemberOfType( @@ -153,6 +159,26 @@ namespace Slang /// If the given `structTypeDeclRef` inherits from another struct type, return that base struct decl DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> structTypeDeclRef); + SubtypeWitness* findThisTypeWitness( + SubstitutionSet substs, + InterfaceDecl* interfaceDecl); + + RequirementWitness tryLookUpRequirementWitness( + ASTBuilder* astBuilder, + SubtypeWitness* subtypeWitness, + Decl* requirementKey); + + DeclRef<Decl> createDefaultSubstitutionsIfNeeded( + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, + DeclRef<Decl> declRef); + + List<Val*> getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl); + + Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst); + + ParameterDirection getParameterDirection(VarDeclBase* varDecl); + inline Type* getTagType(ASTBuilder* astBuilder, DeclRef<EnumDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->tagType); @@ -192,8 +218,6 @@ namespace Slang inline Decl* getInner(DeclRef<GenericDecl> declRef) { - // TODO: Should really return a `DeclRef<Decl>` for the inner - // declaration, and not just a raw pointer return declRef.getDecl()->inner; } @@ -288,46 +312,6 @@ namespace Slang // - ThisTypeSubstitution* findThisTypeSubstitution( - const Substitutions* substs, - InterfaceDecl* interfaceDecl); - - RequirementWitness tryLookUpRequirementWitness( - ASTBuilder* astBuilder, - SubtypeWitness* subtypeWitness, - Decl* requirementKey); - - // TODO: where should this live? - SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - Decl* decl, - SubstitutionSet parentSubst); - - SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - Decl* decl); - - DeclRef<Decl> createDefaultSubstitutionsIfNeeded( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - DeclRef<Decl> declRef); - - GenericSubstitution* createDefaultSubstitutionsForGeneric( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - GenericDecl* genericDecl, - Substitutions* outerSubst); - - GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst); - - ThisTypeSubstitution* findThisTypeSubstitution( - const Substitutions* substs, - InterfaceDecl* interfaceDecl); - - ParameterDirection getParameterDirection(VarDeclBase* varDecl); - enum class UserDefinedAttributeTargets { None = 0, diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 5cf8d2350..f5e14366d 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1571,9 +1571,9 @@ static LayoutSize GetElementCount(IntVal* val) if (auto constantVal = as<ConstantIntVal>(val)) { - if (constantVal->value == kUnsizedArrayMagicLength) + if (constantVal->getValue() == kUnsizedArrayMagicLength) return LayoutSize::infinite(); - return LayoutSize(LayoutSize::RawValue(constantVal->value)); + return LayoutSize(LayoutSize::RawValue(constantVal->getValue())); } else if(const auto varRefVal = as<GenericParamIntVal>(val)) { @@ -2766,7 +2766,7 @@ RefPtr<TypeLayout> createParameterGroupTypeLayout( parameterGroupRules, context.targetReq); - auto elementType = parameterGroupType->elementType; + auto elementType = parameterGroupType->getElementType(); return _createParameterGroupTypeLayout( context, @@ -3642,24 +3642,6 @@ static void _addLayout(TypeLayoutContext const& context, static TypeLayoutResult _updateLayout(TypeLayoutContext const& context, Type* type, - TypeLayout* layout, - const SimpleLayoutInfo& info) -{ - auto layoutResultPtr = context.layoutMap.tryGetValue(type); - SLANG_ASSERT(layoutResultPtr); - if (layoutResultPtr) - { - // Check the layout is the same! - SLANG_ASSERT(layoutResultPtr->layout.get() == layout); - // Update the info - layoutResultPtr->info = info; - } - - return TypeLayoutResult(layout, info); -} - -static TypeLayoutResult _updateLayout(TypeLayoutContext const& context, - Type* type, const TypeLayoutResult& result) { auto layoutResultPtr = context.layoutMap.tryGetValue(type); @@ -3791,7 +3773,7 @@ static TypeLayoutResult _createTypeLayout( context, \ ShaderParameterKind::KIND, \ type_##TYPE, \ - type_##TYPE->elementType); \ + type_##TYPE->getElementType()); \ return TypeLayoutResult(typeLayout, info); \ } while(0) @@ -3826,14 +3808,14 @@ static TypeLayoutResult _createTypeLayout( else if(auto basicType = as<BasicExpressionType>(type)) { return createSimpleTypeLayout( - rules->GetScalarLayout(basicType->baseType), + rules->GetScalarLayout(basicType->getBaseType()), type, rules); } else if(auto vecType = as<VectorExpressionType>(type)) { - auto elementType = vecType->elementType; - size_t elementCount = (size_t) getIntVal(vecType->elementCount); + auto elementType = vecType->getElementType(); + size_t elementCount = (size_t) getIntVal(vecType->getElementCount()); auto element = _createTypeLayout( context, @@ -3842,7 +3824,7 @@ static TypeLayoutResult _createTypeLayout( BaseType elementBaseType = BaseType::Void; if (auto elementBasicType = as<BasicExpressionType>(elementType)) { - elementBaseType = elementBasicType->baseType; + elementBaseType = elementBasicType->getBaseType(); } auto info = rules->GetVectorLayout(elementBaseType, element.info, elementCount); @@ -3874,7 +3856,7 @@ static TypeLayoutResult _createTypeLayout( BaseType elementBaseType = BaseType::Void; if (auto elementBasicType = as<BasicExpressionType>(elementType)) { - elementBaseType = elementBasicType->baseType; + elementBaseType = elementBasicType->getBaseType(); } // The `GetMatrixLayout` implementation in the layout rules @@ -3972,7 +3954,7 @@ static TypeLayoutResult _createTypeLayout( } else if (auto declRefType = as<DeclRefType>(type)) { - auto declRef = declRefType->declRef; + auto declRef = declRefType->getDeclRef(); if (auto structDeclRef = declRef.as<StructDecl>()) { @@ -4346,99 +4328,20 @@ static TypeLayoutResult _createTypeLayout( errorType, rules); } - else if( auto taggedUnionType = as<TaggedUnionType>(type) ) + else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) ) { - // A tagged union type needs to be laid out as the maximum - // size of any constituent type. - // - // In practice, only a tagged union of uniform data will - // work, but for now we will compute the maximum usage - // for each resource kind for generality. - // - // For the uniform data we will start with a size - // of zero and an alignment of one for our base case - // (this is what a tagged union of no cases would consume). - // - UniformLayoutInfo info(0, 1); - - RefPtr<TaggedUnionTypeLayout> taggedUnionLayout = new TaggedUnionTypeLayout(); - - _addLayout(context, type, taggedUnionLayout); - - taggedUnionLayout->type = type; - taggedUnionLayout->rules = rules; - - // Now we iterate over the case types and see if they - // change our computed maximum size/alignement. - // - for( auto caseType : taggedUnionType->caseTypes ) + ExpandedSpecializationArgs args; + for (Index i = 0; i < existentialSpecializedType->getArgCount(); ++i) { - // Note: A tagged union type is not expected to have any existential/interface type - // slots; the case types that are provided must be fully specialized before the union is - // formed. Thus we don't need to mess around with existential type slots here the - // way we do for the `struct` case. - - auto caseTypeResult = _createTypeLayout(context, caseType); - RefPtr<TypeLayout> caseTypeLayout = caseTypeResult.layout; - UniformLayoutInfo caseTypeInfo = caseTypeResult.info.getUniformLayout(); - - info.size = maximum(info.size, caseTypeInfo.size); - info.alignment = std::max(info.alignment, caseTypeInfo.alignment); - - // We need to remember the layout of the case type - // on the final `TaggedUnionTypeLayout`. - // - taggedUnionLayout->caseTypeLayouts.add(caseTypeLayout); - - // We also need to consider contributions for other - // resource kinds beyond uniform data. - // - for( auto caseResInfo : caseTypeLayout->resourceInfos ) - { - auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind); - unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count); - } - } - - // After we've computed the size required to hold all the - // case types, we will allocate space for the tag field. - // - // TODO: This assumes the tag will always be allocated out - // of uniform storage, which means we can't support a tagged - // union as part of a varying input/output signature. That is - // probably a valid limitation, but it should get enforced - // somewhere along the way. - // - { - // The tag is always a `uint` for now. - // - auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt); - info.size = _roundToAlignment(info.size, tagInfo.alignment); - - taggedUnionLayout->tagOffset = info.size; - - info.size += tagInfo.size; - info.alignment = std::max(info.alignment, tagInfo.alignment); + args.add(existentialSpecializedType->getArg(i)); } - - // As a final step, if we are computing a full `TypeLayout` - // we will make sure that its information on uniform layout - // matches what we've computed in the `UniformLayoutInfo` we return. - // - taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size; - taggedUnionLayout->uniformAlignment = info.alignment; - - return _updateLayout(context, type, taggedUnionLayout, info); - } - else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) ) - { TypeLayoutContext subContext = context.withSpecializationArgs( - existentialSpecializedType->args.getBuffer(), - existentialSpecializedType->args.getCount()); + args.getBuffer(), + args.getCount()); auto baseTypeLayoutResult = _createTypeLayout( subContext, - existentialSpecializedType->baseType); + existentialSpecializedType->getBaseType()); UniformLayoutInfo info = rules->BeginStructLayout(); rules->AddStructField(&info, baseTypeLayoutResult.info.getUniformLayout()); @@ -4534,7 +4437,7 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout( if(auto basicType = as<BasicExpressionType>(type)) { - auto baseType = basicType->baseType; + auto baseType = basicType->getBaseType(); RefPtr<TypeLayout> typeLayout = new TypeLayout(); typeLayout->type = type; @@ -4550,13 +4453,13 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout( } else if(auto vecType = as<VectorExpressionType>(type)) { - auto elementType = vecType->elementType; - size_t elementCount = (size_t) getIntVal(vecType->elementCount); + auto elementType = vecType->getElementType(); + size_t elementCount = (size_t) getIntVal(vecType->getElementCount()); BaseType elementBaseType = BaseType::Void; if( auto elementBasicType = as<BasicExpressionType>(elementType) ) { - elementBaseType = elementBasicType->baseType; + elementBaseType = elementBasicType->getBaseType(); } // Note that we do *not* add any resource usage to the type @@ -4592,7 +4495,7 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout( BaseType elementBaseType = BaseType::Void; if( auto elementBasicType = as<BasicExpressionType>(elementType) ) { - elementBaseType = elementBasicType->baseType; + elementBaseType = elementBasicType->getBaseType(); } // Just as for `_createTypeLayout`, we need to handle row- and @@ -4711,7 +4614,7 @@ GlobalGenericParamDecl* GenericParamTypeLayout::getGlobalGenericParamDecl() { auto declRefType = as<DeclRefType>(type); SLANG_ASSERT(declRefType); - auto rsDeclRef = declRefType->declRef.as<GlobalGenericParamDecl>(); + auto rsDeclRef = declRefType->getDeclRef().as<GlobalGenericParamDecl>(); return rsDeclRef.getDecl(); } diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index 7b822eac4..c800d0931 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -725,28 +725,6 @@ public: Index paramIndex = 0; }; - /// Layout information for a tagged union type. -class TaggedUnionTypeLayout : public TypeLayout -{ -public: - /// The layouts of each of the case types. - /// - /// The order of entries in this array matches - /// the order of case types on the original - /// `TaggedUnionType`, and the index of a case - /// type is also the tag value for that case. - /// - List<RefPtr<TypeLayout>> caseTypeLayouts; - - /// The byte offset for the tag field. - /// - /// The tag field will always be allocated as - /// a `uint`, so we don't store a separate layout - /// for it. - /// - LayoutSize tagOffset; -}; - /// Layout information for an interface/existential type /// /// This class is used to represent the layout of an interface type @@ -912,13 +890,6 @@ public: /// Dictionary<GlobalGenericParamDecl*, Val*> globalGenericArgs; - /// Layouts for all tagged union types required by this program - /// - /// These are any tagged union types used by the specialization - /// arguments that have been used to specialize the program. - /// - List<RefPtr<TypeLayout>> taggedUnionTypeLayouts; - /// Holds all of the string literals that have been hashed StringSlicePool hashedStringLiteralPool; }; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index e08bb2a62..266533874 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -161,11 +161,9 @@ void Session::init() m_sharedASTBuilder = new SharedASTBuilder; m_sharedASTBuilder->init(this); - // Use to create a ASTBuilder - RefPtr<ASTBuilder> builtinAstBuilder(new ASTBuilder(m_sharedASTBuilder, "m_builtInLinkage::m_astBuilder")); - // And the global ASTBuilder - globalAstBuilder = new ASTBuilder(m_sharedASTBuilder, "globalAstBuilder"); + auto builtinAstBuilder = m_sharedASTBuilder->getInnerASTBuilder(); + globalAstBuilder = builtinAstBuilder; // Make sure our source manager is initialized builtinSourceManager.initialize(nullptr, nullptr); @@ -367,6 +365,8 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes) return SLANG_FAIL; } + SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); + // Make a file system to read it from ComPtr<ISlangFileSystemExt> fileSystem; SLANG_RETURN_ON_FAIL(loadArchiveFileSystem(stdLib, stdLibSizeInBytes, fileSystem)); @@ -397,6 +397,8 @@ SlangResult Session::saveStdLib(SlangArchiveType archiveType, ISlangBlob** outBl return SLANG_FAIL; } + SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); + for (auto& pair : m_builtinLinkage->mapNameToLoadedModules) { const Name* moduleName = pair.key; @@ -463,6 +465,7 @@ SlangResult Session::_readBuiltinModule(ISlangFileSystem* fileSystem, Scope* sco options.namePool = linkageNamePool; options.session = this; options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); + options.astBuilder = linkage->getASTBuilder(); options.sourceManager = sourceManger; options.linkage = linkage; @@ -920,6 +923,9 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka , m_sourceManager(&m_defaultSourceManager) , m_astBuilder(astBuilder) { + if (builtinLinkage) + m_astBuilder->m_cachedNodes = builtinLinkage->getASTBuilder()->m_cachedNodes; + getNamePool()->setRootNamePool(session->getRootNamePool()); m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); @@ -990,6 +996,8 @@ SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL Linkage::getGlobalSession() void Linkage::addTarget( slang::TargetDesc const& desc) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto targetIndex = addTarget(CodeGenTarget(desc.format)); auto target = targets[targetIndex]; @@ -1018,6 +1026,8 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule( const char* moduleName, slang::IBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); if (isInLanguageServer()) @@ -1048,6 +1058,8 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource( slang::IBlob* source, slang::IBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); if (isInLanguageServer()) { @@ -1096,6 +1108,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType( slang::IComponentType** outCompositeComponentType, ISlangBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + // Attempting to create a "composite" of just one component type should // just return the component type itself, to avoid redundant work. // @@ -1131,6 +1145,8 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( SlangInt specializationArgCount, ISlangBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto unspecializedType = asInternal(inUnspecializedType); List<Type*> typeArgs; @@ -1157,6 +1173,8 @@ SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout( slang::LayoutRules rules, ISlangBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto type = asInternal(inType); if(targetIndex < 0 || targetIndex >= targets.getCount()) @@ -1187,6 +1205,8 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( slang::ContainerType containerType, ISlangBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto type = asInternal(inType); Type* containerTypeReflection = nullptr; @@ -1197,29 +1217,20 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( { case slang::ContainerType::ConstantBuffer: { - ConstantBufferType* cbType = getASTBuilder()->create<ConstantBufferType>(); - cbType->elementType = type; - cbType->declRef = getASTBuilder()->getBuiltinDeclRef( - "ConstantBuffer", static_cast<Val*>(type)); + ConstantBufferType* cbType = getASTBuilder()->getConstantBufferType(type); containerTypeReflection = cbType; } break; case slang::ContainerType::ParameterBlock: { - ParameterBlockType* pbType = getASTBuilder()->create<ParameterBlockType>(); - pbType->elementType = type; - pbType->declRef = getASTBuilder()->getBuiltinDeclRef( - "ParameterBlock", static_cast<Val*>(type)); + ParameterBlockType* pbType = getASTBuilder()->getParameterBlockType(type); containerTypeReflection = pbType; } break; case slang::ContainerType::StructuredBuffer: { HLSLStructuredBufferType* sbType = - getASTBuilder()->create<HLSLStructuredBufferType>(); - sbType->elementType = type; - sbType->declRef = getASTBuilder()->getBuiltinDeclRef( - "HLSLStructuredBufferType", static_cast<Val*>(type)); + getASTBuilder()->getStructuredBufferType(type); containerTypeReflection = sbType; } break; @@ -1244,16 +1255,20 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getDynamicType() { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + return asExternal(getASTBuilder()->getSharedASTBuilder()->getDynamicType()); } SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeRTTIMangledName( slang::TypeReflection* type, ISlangBlob** outNameBlob) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto internalType = asInternal(type); if (auto declRefType = as<DeclRefType>(internalType)) { - auto name = getMangledName(internalType->getASTBuilder(), declRefType->declRef); + auto name = getMangledName(m_astBuilder, declRefType->getDeclRef()); Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name); *outNameBlob = blob.detach(); return SLANG_OK; @@ -1264,9 +1279,11 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeRTTIMangledName( SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessMangledName( slang::TypeReflection* type, slang::TypeReflection* interfaceType, ISlangBlob** outNameBlob) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto subType = asInternal(type); auto supType = asInternal(interfaceType); - auto name = getMangledNameForConformanceWitness(subType->getASTBuilder(), subType, supType); + auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType); Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name); *outNameBlob = blob.detach(); return SLANG_OK; @@ -1277,14 +1294,16 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent slang::TypeReflection* interfaceType, uint32_t* outId) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + auto subType = asInternal(type); auto supType = asInternal(interfaceType); if (!subType || !supType) return SLANG_FAIL; - auto name = getMangledNameForConformanceWitness(subType->getASTBuilder(), subType, supType); - auto interfaceName = getMangledTypeName(supType->getASTBuilder(), supType); + auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType); + auto interfaceName = getMangledTypeName(m_astBuilder, supType); uint32_t resultIndex = 0; if (mapMangledNameToRTTIObjectIndex.tryGetValue(name, resultIndex)) { @@ -1313,6 +1332,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy SlangInt conformanceIdOverride, ISlangBlob** outDiagnostics) { + SLANG_AST_BUILDER_RAII(getASTBuilder()); + RefPtr<TypeConformance> result; DiagnosticSink sink; try @@ -1550,6 +1571,8 @@ CapabilitySet TargetRequest::getTargetCaps() TypeLayout* TargetRequest::getTypeLayout(Type* type) { + SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); + // TODO: We are not passing in a `ProgramLayout` here, although one // is nominally required to establish the global ordering of // generic type parameters, which might be referenced from field types. @@ -1866,6 +1889,9 @@ Type* ComponentType::getTypeFromString( Scope* scope = _createScopeForLegacyLookup(astBuilder); auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + Expr* typeExpr = linkage->parseTermString( typeStr, scope); type = checkProperType(linkage, TypeExp(typeExpr), sink); @@ -2172,6 +2198,8 @@ void FrontEndCompileRequest::parseTranslationUnit( { auto linkage = getLinkage(); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + // TODO(JS): NOTE! Here we are using the searchDirectories on the linkage. This is because // currently the API only allows the setting search paths on linkage. // @@ -2376,6 +2404,8 @@ void FrontEndCompileRequest::checkAllTranslationUnits() void FrontEndCompileRequest::generateIR() { + SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); + // Our task in this function is to generate IR code // for all of the declarations in the translation // units that were loaded. @@ -2469,6 +2499,8 @@ static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request) SlangResult FrontEndCompileRequest::executeActionsInner() { + SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); + // We currently allow GlSL files on the command line so that we can // drive our "pass-through" mode, but we really want to issue an error // message if the user is seriously asking us to compile them. @@ -3272,7 +3304,7 @@ Module::Module(Linkage* linkage, ASTBuilder* astBuilder) } else { - m_astBuilder = new ASTBuilder(linkage->getASTBuilder()->getSharedASTBuilder(), "Module"); + m_astBuilder = linkage->getASTBuilder(); } addModuleDependency(this); @@ -4091,36 +4123,28 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor maybeAddModule(module); } - void collectReferencedModules(Substitutions* substitution) + void collectReferencedModules(SubstitutionSet substitutions) { - if(auto genericSubst = as<GenericSubstitution>(substitution)) + substitutions.forEachGenericSubstitution([this](GenericDecl*, Val::OperandView<Val> args) { - for(auto arg : genericSubst->getArgs()) + for (auto arg : args) { collectReferencedModules(arg); } - } - } - - void collectReferencedModules(SubstitutionSet const& substitutions) - { - for(auto subst = substitutions.substitutions; subst; subst = subst->getOuter()) - { - collectReferencedModules(subst); - } + }); } - void collectReferencedModules(DeclRefBase const& declRef) + void collectReferencedModules(DeclRefBase* declRef) { - collectReferencedModules(declRef.getDecl()); - collectReferencedModules(declRef.getSubst()); + collectReferencedModules(declRef->getDecl()); + collectReferencedModules(SubstitutionSet(declRef)); } void collectReferencedModules(Type* type) { if(auto declRefType = as<DeclRefType>(type)) { - collectReferencedModules(declRefType->declRef); + collectReferencedModules(declRefType->getDeclRef()); } // TODO: Handle non-decl-ref composite type cases @@ -4135,7 +4159,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor } else if (auto declRefVal = as<GenericParamIntVal>(val)) { - collectReferencedModules(declRefVal->declRef); + collectReferencedModules(declRefVal->getDeclRef()); } // TODO: other cases of values that could reference @@ -4350,41 +4374,6 @@ SpecializedComponentType::SpecializedComponentType( m_moduleDependencies.add(module); } - // The following is a bit of a hack. - // - // TODO: We should not need this hack any longer, since the - // new approach to `switch`-based dynamic dispatch has made - // the existing tagged-union support obsolete. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a program (and the stuff it imports). - // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point. - // - // A longer-term strategy might need to consider any (tagged or untagged) - // union types that get used inside of a module, and also take - // those lists into account. - // - // An even longer-term strategy would be to allow type layout to - // be performed on IR types, so taht we don't need to have front-end - // code worrying about this stuff. - // - for(auto arg : specializationArgs) - { - auto argType = as<Type>(arg.val); - if(!argType) - continue; - - auto taggedUnionType = as<TaggedUnionType>(argType); - if(!taggedUnionType) - continue; - - m_taggedUnionTypes.add(taggedUnionType); - } - // Because we are specializing shader code, the mangled entry // point names for this component type may be different than // for the base component type (e.g., the mangled name for `f<int>` diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index 74b625183..912a8f2a7 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -11,48 +11,66 @@ </Type> <Type Name="Slang::DeclRef<*>"> <DisplayString Condition="declRefBase == 0">DeclRef nullptr</DisplayString> + <DisplayString Condition="declRefBase != 0">{*declRefBase}</DisplayString> <Expand> - <ExpandedItem>declRefBase ? ($T1*)(declRefBase->decl) : ($T1*)0</ExpandedItem> - <Synthetic Name="[Substitutions]"> - <Expand> - <LinkedListItems> - <HeadPointer>declRefBase->substitutions</HeadPointer> - <NextPointer>outer</NextPointer> - <ValueNode>this</ValueNode> - </LinkedListItems> - </Expand> - </Synthetic> + <ExpandedItem>declRefBase</ExpandedItem> </Expand> </Type> <Type Name="Slang::DeclRefBase"> - <DisplayString Condition="decl != 0 && substitutions != 0">{*decl}{*substitutions}</DisplayString> - <DisplayString Condition="decl != 0">{*decl}</DisplayString> - <DisplayString Condition="decl == 0">DeclRefBase nullptr</DisplayString> - <Expand> - <ExpandedItem>decl</ExpandedItem> - <Synthetic Name="[Substitutions]"> - <Expand> - <LinkedListItems> - <HeadPointer>substitutions.substitutions</HeadPointer> - <NextPointer>outer</NextPointer> - <ValueNode>this</ValueNode> - </LinkedListItems> - </Expand> - </Synthetic> - </Expand> - </Type> - <Type Name="Slang::GenericSubstitution"> - <DisplayString>GenSubst {(*genericDecl).nameAndLoc}</DisplayString> + <DisplayString Optional="true" Condition="m_operands.m_buffer[0].values.nodeOperand != 0">{astNodeType,en}#{_debugUID}({(Decl*)m_operands.m_buffer[0].values.nodeOperand}) </DisplayString> + <DisplayString Condition="m_operands.m_buffer[0].values.nodeOperand != 0">{astNodeType,en}({(Decl*)m_operands.m_buffer[0].values.nodeOperand})</DisplayString> + <DisplayString Condition="m_operands.m_buffer[0].values.nodeOperand == 0">DeclRefBase nullptr</DisplayString> <Expand> - <Item Name="genericDecl">genericDecl</Item> - <ExpandedItem>args</ExpandedItem> + <Synthetic Name="[Decl]"> + <DisplayString>{*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem>*(Decl*)m_operands.m_buffer[0].values.nodeOperand</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Condition="astNodeType == Slang::ASTNodeType::MemberDeclRef" Name="[Parent]"> + <DisplayString>{*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)}</DisplayString> + <Expand> + <ExpandedItem>*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef" Name="[Base]"> + <DisplayString>{*(Val*)(this->m_operands.m_buffer[1].values.nodeOperand)}</DisplayString> + <Expand> + <ExpandedItem>*(Val*)(this->m_operands.m_buffer[1].values.nodeOperand)</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef" Name="[Witness]"> + <DisplayString>{*(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand)}</DisplayString> + <Expand> + <ExpandedItem>*(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand)</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef" Name="[BaseGeneric]"> + <DisplayString>{*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)}</DisplayString> + <Expand> + <ExpandedItem>*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)</ExpandedItem> + </Expand> + </Synthetic> + <CustomListItems Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef"> + <Variable Name="index" InitialValue="2"/> + <Loop Condition="index<m_operands.m_count"> + <Item Name="Arg[{index-2}]">*(Val*)(this->m_operands.m_buffer[index].values.nodeOperand)</Item> + <Exec>index=index+1</Exec> + </Loop> + </CustomListItems> </Expand> </Type> <Type Name="Slang::DeclRefType"> - <DisplayString>DeclRefType {declRef}</DisplayString> + <DisplayString Optional="true">{astNodeType,en}#{_debugUID} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand} </DisplayString> + + <DisplayString>{astNodeType,en} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString> <Expand> - <ExpandedItem>declRef</ExpandedItem> + <Synthetic Name="DeclRefType"> + <DisplayString Optional="true">{astNodeType,en}#{_debugUID} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en}#{m_operands.m_buffer[0].values.nodeOperand->_debugUID}</DisplayString> + <DisplayString>{astNodeType,en} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en}</DisplayString> + </Synthetic> + <ExpandedItem>*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand</ExpandedItem> </Expand> </Type> <Type Name="Slang::FuncDecl"> @@ -223,13 +241,12 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OpenRefExpr">(Slang::OpenRefExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ForwardDifferentiateExpr">(Slang::ForwardDifferentiateExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BackwardDifferentiateExpr">(Slang::BackwardDifferentiateExpr*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionTypeExpr">(Slang::TaggedUnionTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeExpr">(Slang::ThisTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedTypeExpr">(Slang::ModifiedTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PointerTypeExpr">(Slang::PointerTypeExpr*)&astNodeType</ExpandedItem> <Item Name="[type]">type</Item> - <Item Name="[Expr]">(Slang::Expr*)this,nd</Item> + <Item Name="[Expr]">(Slang::Expr*)this,!</Item> </Expand> </Type> <Type Name="Slang::Stmt" Inheritable="false"> @@ -261,18 +278,19 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ContinueStmt">(Slang::ContinueStmt*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ReturnStmt">(Slang::ReturnStmt*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExpressionStmt">(Slang::ExpressionStmt*)&astNodeType</ExpandedItem> - <Item Name="[Stmt]">(Slang::Stmt*)this,nd</Item> + <Item Name="[Stmt]">(Slang::Stmt*)this,!</Item> </Expand> </Type> <Type Name="Slang::Name"> <DisplayString>{text}</DisplayString> </Type> <Type Name="Slang::Decl" Inheritable="false"> - <DisplayString Condition="nameAndLoc.name!=0">{nameAndLoc.name->text}</DisplayString> + <DisplayString Condition="nameAndLoc.name!=0">{astNodeType,en} {nameAndLoc.name->text}</DisplayString> <DisplayString Condition="nameAndLoc.name==0">{astNodeType,en}</DisplayString> <Expand> <Item Name="[Name]" Condition="nameAndLoc.name!=0">nameAndLoc.name->text</Item> <Item Name="[Parent]">parentDecl</Item> + <Item Name="[CheckState]">Slang::DeclCheckState(checkState.m_raw & ~Slang::DeclCheckStateExt::kBeingCheckedBit)</Item> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ContainerDecl">(Slang::ContainerDecl*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtensionDecl">(Slang::ExtensionDecl*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StructDecl">(Slang::StructDecl*)&astNodeType</ExpandedItem> @@ -314,7 +332,7 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SyntaxDecl">(Slang::SyntaxDecl*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclGroup">(Slang::DeclGroup*)&astNodeType</ExpandedItem> - <Item Name="Decl">(Slang::DeclBase*)this,nd</Item> + <Item Name="Decl">(Slang::DeclBase*)this,!</Item> </Expand> </Type> @@ -361,20 +379,57 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::EmptyDecl">(Slang::EmptyDecl*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SyntaxDecl">(Slang::SyntaxDecl*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclGroup">(Slang::DeclGroup*)&astNodeType</ExpandedItem> - <Item Name="Decl">(Slang::Decl*)this,nd</Item> + <Item Name="Decl">(Slang::Decl*)this,!</Item> </Expand> </Type> + <Type Name="Slang::TypeType" Inheritable="false"> + <DisplayString Optional="true">{astNodeType,en} #{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString> + <DisplayString>{astNodeType,en} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString> + <Expand> + <ExpandedItem>*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)</ExpandedItem> + </Expand> + </Type> + <Type Name="Slang::FuncType" Inheritable="false"> + <DisplayString Optional="true">{astNodeType,en} #{_debugUID}</DisplayString> + <DisplayString Optional="true">{astNodeType,en}</DisplayString> + <Expand> + <Synthetic Name="[ParamCount]"> + <DisplayString>{m_operands.m_count-2}</DisplayString> + </Synthetic> + <ArrayItems> + <Size>m_operands.m_count-2</Size> + <ValuePointer>m_operands.m_buffer</ValuePointer> + </ArrayItems> + <Synthetic Name="[ResultType]"> + <DisplayString>{m_operands.m_buffer[m_operands.m_count-2]}</DisplayString> + <Expand> + <ExpandedItem>m_operands.m_buffer[m_operands.m_count-2]</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Name="[ErrorType]"> + <DisplayString>{m_operands.m_buffer[m_operands.m_count-1]}</DisplayString> + <Expand> + <ExpandedItem>m_operands.m_buffer[m_operands.m_count-1]</ExpandedItem> + </Expand> + </Synthetic> + </Expand> + </Type> <Type Name="Slang::Type" Inheritable="false"> - <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">{((Slang::DeclRefType*)&astNodeType)->declRef}</DisplayString> + <DisplayString Optional="true" Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType#{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString> + <DisplayString Optional="true">{astNodeType,en} #{_debugUID}</DisplayString> <DisplayString>{astNodeType,en}</DisplayString> - <Expand> + + <Expand> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OverloadGroupType">(Slang::OverloadGroupType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::InitializerListType">(Slang::InitializerListType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ErrorType">(Slang::ErrorType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BottomType">(Slang::BottomType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclRefType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentiableType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> + + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">(Slang::ArithmeticExpressionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">(Slang::BasicExpressionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">(Slang::VectorExpressionType*)&astNodeType</ExpandedItem> @@ -437,45 +492,51 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArrayExpressionType">(Slang::ArrayExpressionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TypeType">(Slang::TypeType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamedExpressionType">(Slang::NamedExpressionType*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FuncType">(Slang::FuncType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FuncType">(Slang::FuncType*)this</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">(Slang::GenericDeclRefType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamespaceType">(Slang::NamespaceType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">(Slang::ExtractExistentialType*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionType">(Slang::TaggedUnionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">(Slang::ExistentialSpecializedType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&astNodeType</ExpandedItem> - <Item Name="[Type]">(Slang::Type*)this,nd</Item> - </Expand> - </Type> - <Type Name="Slang::Substitutions" Inheritable="false"> - <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">{*(Slang::GenericSubstitution*)&astNodeType}</DisplayString> - <DisplayString Condition="astNodeType == Slang::ASTNodeType::ThisTypeSubstitution">{*(Slang::ThisTypeSubstitution*)&astNodeType}</DisplayString> - <DisplayString>{astNodeType,en}</DisplayString> - <DisplayString>{astNodeType,en}</DisplayString> - <Expand> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">(Slang::GenericSubstitution*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeSubstitution">(Slang::ThisTypeSubstitution*)&astNodeType</ExpandedItem> + <Item Name="[Raw View]">(Slang::Type*)this,!</Item> </Expand> </Type> - <Type Name="Slang::GenericSubstitution" Inheritable="false"> - <DisplayString Condition="outer != 0"><{args}>{*outer}</DisplayString> - <DisplayString><{args}></DisplayString> - </Type> - <Type Name="Slang::ThisTypeSubstitution" Inheritable="false"> - <DisplayString Condition="outer != 0">{*outer}[This=={witness->sub,na}]</DisplayString> - <DisplayString>[{witness->sup,na}.This: {witness->sub,na}]</DisplayString> - </Type> <Type Name="Slang::SubstitutionSet"> - <DisplayString>{astNodeType,en}</DisplayString> + <DisplayString>SubstitutionSet{declRef,en}</DisplayString> <Expand> - <LinkedListItems> - <HeadPointer>substitutions</HeadPointer> - <NextPointer>outer</NextPointer> - <ValueNode>(Slang::Substitutions*)this</ValueNode> - </LinkedListItems> + <ExpandedItem>declRef</ExpandedItem> + <CustomListItems MaxItemsPerView="24"> + <Variable Name="subst" InitialValue="declRef"/> + <Variable Name="substType" InitialValue="(Slang::ASTNodeType)0"/> + <Variable Name="shouldBreak" InitialValue="0"/> + <Loop Condition="subst != 0"> + <Exec>substType = subst->astNodeType </Exec> + <Exec>shouldBreak = 1 </Exec> + + <If Condition="substType == Slang::ASTNodeType::DirectDeclRef"> + <Break/> + </If> + <If Condition="substType == Slang::ASTNodeType::MemberDeclRef"> + <Exec>subst = (DeclRefBase*)(((Slang::MemberDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand)</Exec> + <Exec>shouldBreak = 0 </Exec> + </If> + <If Condition="substType == Slang::ASTNodeType::LookupDeclRef"> + <Item>(LookupDeclRef*)subst</Item> + <Break/> + </If> + <If Condition="substType == Slang::ASTNodeType::GenericAppDeclRef"> + <Item>(GenericAppDeclRef*)subst</Item> + <Exec>subst = (DeclRefBase*)(((Slang::GenericAppDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand)</Exec> + <Exec>shouldBreak = 0 </Exec> + </If> + <If Condition="shouldBreak"> + <Break/> + </If> + </Loop> + </CustomListItems> </Expand> </Type> <Type Name="Slang::AggTypeDecl"> @@ -484,9 +545,103 @@ <Item Name="[Members]">members</Item> </Expand> </Type> + <Type Name="Slang::ValNodeOperand"> + <DisplayString Condition="kind==Slang::ValNodeOperandKind::ConstantValue">Const({values.intOperand})</DisplayString> + <DisplayString Condition="kind==Slang::ValNodeOperandKind::ValNode">{*(Val*)values.nodeOperand}</DisplayString> + <DisplayString>{values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ValNode">*(Val*)values.nodeOperand</ExpandedItem> + <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ASTNode">*values.nodeOperand</ExpandedItem> + </Expand> + </Type> <Type Name="Slang::Val" Inheritable="false"> - <DisplayString>{astNodeType,en}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">{*(Slang::DirectDeclRef*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef">{*(Slang::LookupDeclRef*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::MemberDeclRef">{*(Slang::MemberDeclRef*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef">{*(Slang::GenericAppDeclRef*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ConstantIntVal">{*(Slang::ConstantIntVal*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::PolynomialIntVal">{*(Slang::PolynomialIntVal*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericParamIntVal">{*(Slang::GenericParamIntVal*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclaredSubtypeWitness">{*(Slang::DeclaredSubtypeWitness*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness">{*(Slang::TransitiveSubtypeWitness*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::OverloadGroupType">{*(Slang::OverloadGroupType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::InitializerListType">{*(Slang::InitializerListType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ErrorType">{*(Slang::ErrorType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::BottomType">{*(Slang::BottomType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">{*(Slang::DeclRefType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">{*(Slang::DeclRefType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">{*(Slang::ArithmeticExpressionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">{*(Slang::BasicExpressionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">{*(Slang::VectorExpressionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::MatrixExpressionType">{*(Slang::MatrixExpressionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::BuiltinType">{*(Slang::BuiltinType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::FeedbackType">{*(Slang::FeedbackType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ResourceType">{*(Slang::ResourceType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureTypeBase">{*(Slang::TextureTypeBase*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureType">{*(Slang::TextureType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureSamplerType">{*(Slang::TextureSamplerType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLImageType">{*(Slang::GLSLImageType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::SamplerStateType">{*(Slang::SamplerStateType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::BuiltinGenericType">{*(Slang::BuiltinGenericType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::PointerLikeType">{*(Slang::PointerLikeType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ParameterGroupType">{*(Slang::ParameterGroupType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::UniformParameterGroupType">{*(Slang::UniformParameterGroupType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ConstantBufferType">{*(Slang::ConstantBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureBufferType">{*(Slang::TextureBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLShaderStorageBufferType">{*(Slang::GLSLShaderStorageBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ParameterBlockType">{*(Slang::ParameterBlockType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::VaryingParameterGroupType">{*(Slang::VaryingParameterGroupType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLInputParameterGroupType">{*(Slang::GLSLInputParameterGroupType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLOutputParameterGroupType">{*(Slang::GLSLOutputParameterGroupType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferTypeBase">{*(Slang::HLSLStructuredBufferTypeBase*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferType">{*(Slang::HLSLStructuredBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRWStructuredBufferType">{*(Slang::HLSLRWStructuredBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedStructuredBufferType">{*(Slang::HLSLRasterizerOrderedStructuredBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLAppendStructuredBufferType">{*(Slang::HLSLAppendStructuredBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLConsumeStructuredBufferType">{*(Slang::HLSLConsumeStructuredBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLStreamOutputType">{*(Slang::HLSLStreamOutputType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLPointStreamType">{*(Slang::HLSLPointStreamType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLLineStreamType">{*(Slang::HLSLLineStreamType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLTriangleStreamType">{*(Slang::HLSLTriangleStreamType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::UntypedBufferResourceType">{*(Slang::UntypedBufferResourceType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLByteAddressBufferType">{*(Slang::HLSLByteAddressBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRWByteAddressBufferType">{*(Slang::HLSLRWByteAddressBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedByteAddressBufferType">{*(Slang::HLSLRasterizerOrderedByteAddressBufferType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::RaytracingAccelerationStructureType">{*(Slang::RaytracingAccelerationStructureType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLPatchType">{*(Slang::HLSLPatchType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLInputPatchType">{*(Slang::HLSLInputPatchType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLOutputPatchType">{*(Slang::HLSLOutputPatchType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLInputAttachmentType">{*(Slang::GLSLInputAttachmentType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::StringTypeBase">{*(Slang::StringTypeBase*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::StringType">{*(Slang::StringType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::NativeStringType">{*(Slang::NativeStringType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::DynamicType">{*(Slang::DynamicType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::EnumTypeType">{*(Slang::EnumTypeType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::PtrTypeBase">{*(Slang::PtrTypeBase*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::PtrType">{*(Slang::PtrType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ParamDirectionType">{(Slang::ParamDirectionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::OutTypeBase">{*(Slang::OutTypeBase*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::OutType">{*(Slang::OutType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::InOutType">{*(Slang::InOutType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::RefType">{*(Slang::RefType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::NullPtrType">{*(Slang::NullPtrType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ArrayExpressionType">{*(Slang::ArrayExpressionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::TypeType">{*(Slang::TypeType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::NamedExpressionType">{*(Slang::NamedExpressionType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::FuncType">{*(Slang::FuncType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">{*(Slang::GenericDeclRefType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::NamespaceType">{*(Slang::NamespaceType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">{*(Slang::ExtractExistentialType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">{*(Slang::ExistentialSpecializedType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ThisType">{*(Slang::ThisType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::AndType">{*(Slang::AndType*)this}</DisplayString> + <DisplayString Condition="astNodeType == Slang::ASTNodeType::ModifiedType">{*(Slang::ModifiedType*)this}</DisplayString> + <Expand> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">(Slang::DirectDeclRef*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef">(Slang::LookupDeclRef*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::MemberDeclRef">(Slang::MemberDeclRef*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef">(Slang::GenericAppDeclRef*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ConstantIntVal">(Slang::ConstantIntVal*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PolynomialIntVal">(Slang::PolynomialIntVal*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericParamIntVal">(Slang::GenericParamIntVal*)&astNodeType</ExpandedItem> @@ -560,11 +715,15 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">(Slang::GenericDeclRefType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamespaceType">(Slang::NamespaceType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">(Slang::ExtractExistentialType*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionType">(Slang::TaggedUnionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">(Slang::ExistentialSpecializedType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&astNodeType</ExpandedItem> + <Synthetic Name="[RawOperands]"> + <Expand> + <ExpandedItem>m_operands</ExpandedItem> + </Expand> + </Synthetic> </Expand> </Type> <Type Name="Slang::Facet"> @@ -594,7 +753,41 @@ </Type> <Type Name="Slang::SubtypeWitness"> <DisplayString Condition="astNodeType == Slang::ASTNodeType::TypeEqualityWitness">{*(Slang::TypeEqualityWitness*)this}</DisplayString> - <DisplayString>{sub,na} <: {sup,na}</DisplayString> + <DisplayString Optional="true">{astNodeType,en}#{_debugUID}({*(Type*)m_operands.m_buffer[0].values.nodeOperand,na} <: {*(Type*)m_operands.m_buffer[1].values.nodeOperand,na})</DisplayString> + <DisplayString>{astNodeType,en}({*(Type*)m_operands.m_buffer[0].values.nodeOperand,na} <: {*(Type*)m_operands.m_buffer[1].values.nodeOperand,na})</DisplayString> + + <Expand> + <Synthetic Name="[Sub]"> + <DisplayString>{*(Type*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem>(Type*)m_operands.m_buffer[0].values.nodeOperand</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Name="[Sup]"> + <DisplayString>{*(Type*)m_operands.m_buffer[1].values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem>(Type*)m_operands.m_buffer[1].values.nodeOperand</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Name="[DeclRef]" Condition="astNodeType == Slang::ASTNodeType::DeclaredSubtypeWitness"> + <DisplayString>{*(Val*)m_operands.m_buffer[2].values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem>(DeclRefBase*)m_operands.m_buffer[2].values.nodeOperand</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Name="[SubToMid]" Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness"> + <DisplayString>{*(SubtypeWitness*)m_operands.m_buffer[2].values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem>(SubtypeWitness*)m_operands.m_buffer[2].values.nodeOperand</ExpandedItem> + </Expand> + </Synthetic> + <Synthetic Name="[MidToSup]" Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness"> + <DisplayString>{*(SubtypeWitness*)m_operands.m_buffer[3].values.nodeOperand}</DisplayString> + <Expand> + <ExpandedItem>(SubtypeWitness*)m_operands.m_buffer[3].values.nodeOperand</ExpandedItem> + </Expand> + </Synthetic> + </Expand> </Type> <Type Name="Slang::TypeEqualityWitness"> <DisplayString>{sub,na} == {sup,na}</DisplayString> |
