summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-04 15:47:39 -0700
committerGitHub <noreply@github.com>2023-08-04 15:47:39 -0700
commita2d90fb275962da84611160f8ddd74d934a68dbd (patch)
tree066084537b9f4fe1f367de100ed6638a88a028c1 /source
parent17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff)
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/compiler-core/slang-diagnostic-sink.h2
-rw-r--r--source/compiler-core/slang-name.cpp7
-rw-r--r--source/compiler-core/slang-name.h1
-rw-r--r--source/compiler-core/slang-slice-allocator.h2
-rw-r--r--source/core/slang-array-view.h2
-rw-r--r--source/core/slang-hash.h2
-rw-r--r--source/core/slang-list.h5
-rw-r--r--source/core/slang-short-list.h8
-rw-r--r--source/core/slang-token-reader.cpp2
-rw-r--r--source/slang/core.meta.slang18
-rw-r--r--source/slang/hlsl.meta.slang6
-rw-r--r--source/slang/slang-ast-base.cpp33
-rw-r--r--source/slang/slang-ast-base.h548
-rw-r--r--source/slang/slang-ast-builder.cpp281
-rw-r--r--source/slang/slang-ast-builder.h362
-rw-r--r--source/slang/slang-ast-decl-ref.cpp461
-rw-r--r--source/slang/slang-ast-decl.cpp18
-rw-r--r--source/slang/slang-ast-decl.h27
-rw-r--r--source/slang/slang-ast-dump.cpp55
-rw-r--r--source/slang/slang-ast-expr.h13
-rw-r--r--source/slang/slang-ast-iterator.h5
-rw-r--r--source/slang/slang-ast-modifier.cpp7
-rw-r--r--source/slang/slang-ast-modifier.h17
-rw-r--r--source/slang/slang-ast-natural-layout.cpp46
-rw-r--r--source/slang/slang-ast-print.cpp25
-rw-r--r--source/slang/slang-ast-reflect.cpp2
-rw-r--r--source/slang/slang-ast-substitutions.cpp163
-rw-r--r--source/slang/slang-ast-support-types.cpp1
-rw-r--r--source/slang/slang-ast-support-types.h78
-rw-r--r--source/slang/slang-ast-synthesis.cpp3
-rw-r--r--source/slang/slang-ast-type.cpp827
-rw-r--r--source/slang/slang-ast-type.h349
-rw-r--r--source/slang/slang-ast-val.cpp1477
-rw-r--r--source/slang/slang-ast-val.h485
-rw-r--r--source/slang/slang-check-conformance.cpp76
-rw-r--r--source/slang/slang-check-constraint.cpp165
-rw-r--r--source/slang/slang-check-conversion.cpp43
-rw-r--r--source/slang/slang-check-decl.cpp564
-rw-r--r--source/slang/slang-check-expr.cpp185
-rw-r--r--source/slang/slang-check-impl.h83
-rw-r--r--source/slang/slang-check-inheritance.cpp42
-rw-r--r--source/slang/slang-check-modifier.cpp38
-rw-r--r--source/slang/slang-check-overload.cpp128
-rw-r--r--source/slang/slang-check-resolve-val.cpp48
-rw-r--r--source/slang/slang-check-shader.cpp42
-rw-r--r--source/slang/slang-check-stmt.cpp16
-rw-r--r--source/slang/slang-check-type.cpp64
-rw-r--r--source/slang/slang-check.cpp2
-rw-r--r--source/slang/slang-compiler.cpp6
-rwxr-xr-xsource/slang/slang-compiler.h6
-rw-r--r--source/slang/slang-doc-markdown-writer.cpp2
-rw-r--r--source/slang/slang-emit-c-like.cpp1
-rw-r--r--source/slang/slang-emit.cpp5
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h63
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-ir-union.cpp773
-rw-r--r--source/slang/slang-ir-union.h18
-rw-r--r--source/slang/slang-ir.cpp42
-rw-r--r--source/slang/slang-ir.h5
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp5
-rw-r--r--source/slang/slang-language-server-completion.cpp2
-rw-r--r--source/slang/slang-language-server.cpp33
-rw-r--r--source/slang/slang-lookup.cpp146
-rw-r--r--source/slang/slang-lookup.h3
-rw-r--r--source/slang/slang-lower-to-ir.cpp612
-rw-r--r--source/slang/slang-mangle.cpp82
-rw-r--r--source/slang/slang-parameter-binding.cpp24
-rw-r--r--source/slang/slang-parser.cpp69
-rw-r--r--source/slang/slang-reflection-api.cpp53
-rw-r--r--source/slang/slang-serialize-ast-type-info.h73
-rw-r--r--source/slang/slang-serialize-container.cpp64
-rw-r--r--source/slang/slang-serialize-type-info.h99
-rw-r--r--source/slang/slang-serialize-types.cpp3
-rw-r--r--source/slang/slang-serialize.cpp5
-rw-r--r--source/slang/slang-serialize.h15
-rw-r--r--source/slang/slang-syntax.cpp922
-rw-r--r--source/slang/slang-syntax.h84
-rw-r--r--source/slang/slang-type-layout.cpp143
-rw-r--r--source/slang/slang-type-layout.h29
-rw-r--r--source/slang/slang.cpp135
-rw-r--r--source/slang/slang.natvis335
82 files changed, 4374 insertions, 6321 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.h b/source/compiler-core/slang-diagnostic-sink.h
index e4d131e37..fc5e31b47 100644
--- a/source/compiler-core/slang-diagnostic-sink.h
+++ b/source/compiler-core/slang-diagnostic-sink.h
@@ -310,7 +310,7 @@ private:
class DiagnosticsLookup : public RefObject
{
public:
- static const Index kArenaInitialSize = 2048;
+ static const Index kArenaInitialSize = 65536;
/// Will take into account the slice name could be using different conventions
const DiagnosticInfo* findDiagnosticByName(const UnownedStringSlice& slice) const;
diff --git a/source/compiler-core/slang-name.cpp b/source/compiler-core/slang-name.cpp
index cc2033339..c815b8aa8 100644
--- a/source/compiler-core/slang-name.cpp
+++ b/source/compiler-core/slang-name.cpp
@@ -19,7 +19,7 @@ const char* getCstr(Name* name)
return name ? name->text.getBuffer() : nullptr;
}
-Name* NamePool::getName(String const& text)
+Name* NamePool::getName(UnownedStringSlice text)
{
RefPtr<Name> name;
if (rootPool->names.tryGetValue(text, name))
@@ -31,6 +31,11 @@ Name* NamePool::getName(String const& text)
return name;
}
+Name* NamePool::getName(String const& text)
+{
+ return getName(text.getUnownedSlice());
+}
+
Name* NamePool::tryGetName(String const& text)
{
RefPtr<Name> name;
diff --git a/source/compiler-core/slang-name.h b/source/compiler-core/slang-name.h
index cf702686b..f8c1201af 100644
--- a/source/compiler-core/slang-name.h
+++ b/source/compiler-core/slang-name.h
@@ -68,6 +68,7 @@ struct RootNamePool
struct NamePool
{
// Find or create the `Name` that represents the given `text`.
+ Name* getName(UnownedStringSlice text);
Name* getName(String const& text);
// Try find the `Name` that represents the given `text`.
// If the name does not exist, return nullptr
diff --git a/source/compiler-core/slang-slice-allocator.h b/source/compiler-core/slang-slice-allocator.h
index e4ba9e907..a6f0cd5c1 100644
--- a/source/compiler-core/slang-slice-allocator.h
+++ b/source/compiler-core/slang-slice-allocator.h
@@ -97,7 +97,7 @@ struct SliceAllocator
void deallocateAll() { m_arena.deallocateAll(); }
SliceAllocator():
- m_arena(1024)
+ m_arena(2097152)
{
}
protected:
diff --git a/source/core/slang-array-view.h b/source/core/slang-array-view.h
index 99609ef69..50270e0a0 100644
--- a/source/core/slang-array-view.h
+++ b/source/core/slang-array-view.h
@@ -197,6 +197,8 @@ namespace Slang
return ThisType(m_buffer + index, m_count - index);
}
+ T& getLast() { return m_buffer[m_count - 1]; }
+
ArrayView() : Super() {}
ArrayView(T* buffer, Index size) :Super(buffer, size) {}
};
diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h
index bc4b30ccc..5f6b1b060 100644
--- a/source/core/slang-hash.h
+++ b/source/core/slang-hash.h
@@ -138,7 +138,7 @@ namespace Slang
template<typename TKey>
static HashCode getHashCode(TKey const& key)
{
- return (HashCode)((PtrInt)key) / 16; // sizeof(typename std::remove_pointer<TKey>::type);
+ return (HashCode)((PtrInt)key) >> 2; // sizeof(typename std::remove_pointer<TKey>::type);
}
};
template<>
diff --git a/source/core/slang-list.h b/source/core/slang-list.h
index ff756035c..250b6dc49 100644
--- a/source/core/slang-list.h
+++ b/source/core/slang-list.h
@@ -52,6 +52,11 @@ namespace Slang
{
this->operator=(static_cast<List<T>&&>(list));
}
+ List(ArrayView<T> view)
+ : List()
+ {
+ addRange(view);
+ }
static List<T> makeRepeated(const T& val, Index count)
{
List<T> rs;
diff --git a/source/core/slang-short-list.h b/source/core/slang-short-list.h
index 5bad9faf8..adbb935e6 100644
--- a/source/core/slang-short-list.h
+++ b/source/core/slang-short-list.h
@@ -117,17 +117,17 @@ namespace Slang
}
};
- Iterator begin()
+ Iterator begin() const
{
Iterator rs;
- rs.container = this;
+ rs.container = const_cast<ThisType*>(this);
rs.index = 0;
return rs;
}
- Iterator end()
+ Iterator end() const
{
Iterator rs;
- rs.container = this;
+ rs.container = const_cast<ThisType*>(this);
rs.index = m_count;
return rs;
}
diff --git a/source/core/slang-token-reader.cpp b/source/core/slang-token-reader.cpp
index f6f29def3..e8ebfb9ec 100644
--- a/source/core/slang-token-reader.cpp
+++ b/source/core/slang-token-reader.cpp
@@ -416,7 +416,7 @@ namespace Misc {
tokenFlags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace;
pos++;
}
- else if (curChar == ' ' || curChar == '\t' || curChar == -62 || curChar == -96) // -62/-96:non-break space
+ else if (curChar == ' ' || curChar == '\t' || curChar == '\xC2' || curChar == '\xA0') // -62/-96:non-break space
{
tokenFlags |= TokenFlag::AfterWhitespace;
pos++;
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index ae70a83f4..efd0a743c 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -316,7 +316,7 @@ interface __FlagsEnumType : __EnumType
};
__generic<T, let N:int>
-__magic_type(ArrayType)
+__magic_type(ArrayExpressionType)
struct Array
{
}
@@ -876,7 +876,7 @@ extension int16_t
/// An `N` component vector with elements of type `T`.
__generic<T = float, let N : int = 4>
-__magic_type(Vector)
+__magic_type(VectorExpressionType)
struct vector
{
/// The element type of the vector
@@ -896,7 +896,7 @@ struct vector
/// A matrix with `R` rows and `C` columns, with elements of type `T`.
__generic<T = float, let R : int = 4, let C : int = 4>
-__magic_type(Matrix)
+__magic_type(MatrixExpressionType)
struct matrix
{
__intrinsic_op($(kIROp_MakeMatrixFromScalar))
@@ -957,12 +957,12 @@ for (int tt = 0; tt < kTypeCount; ++tt)
__generic<T>
__intrinsic_type($(kIROp_ConstantBufferType))
-__magic_type(ConstantBuffer)
+__magic_type(ConstantBufferType)
struct ConstantBuffer {}
__generic<T>
__intrinsic_type($(kIROp_TextureBufferType))
-__magic_type(TextureBuffer)
+__magic_type(TextureBufferType)
struct TextureBuffer {}
__generic<T>
@@ -1238,14 +1238,14 @@ extension matrix<T, R, C> : IDifferentiable
//@ public:
/// Sampling state for filtered texture fetches.
-__magic_type(SamplerState, $(int(SamplerStateFlavor::SamplerState)))
+__magic_type(SamplerStateType, $(int(SamplerStateFlavor::SamplerState)))
__intrinsic_type($(kIROp_SamplerStateType))
struct SamplerState
{
}
/// Sampling state for filtered texture fetches that include a comparison operation before filtering.
-__magic_type(SamplerState, $(int(SamplerStateFlavor::SamplerComparisonState)))
+__magic_type(SamplerStateType, $(int(SamplerStateFlavor::SamplerComparisonState)))
__intrinsic_type($(kIROp_SamplerComparisonStateType))
struct SamplerComparisonState
{
@@ -1347,12 +1347,12 @@ struct TextureTypeInfo
if(prefixInfo.combined)
{
- sb << "__magic_type(TextureSampler," << int(flavor) << ")\n";
+ sb << "__magic_type(TextureSamplerType," << int(flavor) << ")\n";
sb << "__intrinsic_type(" << (kIROp_TextureSamplerType + (int(flavor) << kIROpMeta_OtherShift)) << ")\n";
}
else
{
- sb << "__magic_type(Texture," << int(flavor) << ")\n";
+ sb << "__magic_type(TextureType," << int(flavor) << ")\n";
sb << "__intrinsic_type(" << (kIROp_TextureType + (int(flavor) << kIROpMeta_OtherShift)) << ")\n";
}
sb << "struct ";
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 9716a3e9e..1ab046b19 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -5048,7 +5048,7 @@ for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa)
bool isReadOnly = (access == SLANG_RESOURCE_ACCESS_READ);
auto flavor = TextureFlavor::create(TextureFlavor::Shape::ShapeBuffer, access).flavor;
sb << "__generic<T>\n";
- sb << "__magic_type(Texture," << int(flavor) << ")\n";
+ sb << "__magic_type(TextureType," << int(flavor) << ")\n";
sb << "__intrinsic_type(" << (kIROp_TextureType + (int(flavor) << kIROpMeta_OtherShift)) << ")\n";
sb << "struct ";
sb << kBaseBufferAccessLevels[aa].name;
@@ -5566,7 +5566,7 @@ static const int feedbackTexture2DFlavor = int(TextureFlavor::create(TextureFlav
static const int feedbackTexture2DArrayFlavor = int(TextureFlavor::create(TextureFlavor::Shape::Shape2D, SLANG_RESOURCE_ACCESS_WRITE, SLANG_TEXTURE_FEEDBACK_FLAG | SLANG_TEXTURE_ARRAY_FLAG).flavor);
}}}}
-__magic_type(Texture, $(feedbackTexture2DFlavor))
+__magic_type(TextureType, $(feedbackTexture2DFlavor))
__intrinsic_type($(kIROp_TextureType + (feedbackTexture2DFlavor << kIROpMeta_OtherShift)))
struct FeedbackTexture2D<T : __BuiltinSamplerFeedbackType>
{
@@ -5619,7 +5619,7 @@ struct FeedbackTexture2D<T : __BuiltinSamplerFeedbackType>
-__magic_type(Texture, $(feedbackTexture2DArrayFlavor))
+__magic_type(TextureType, $(feedbackTexture2DArrayFlavor))
__intrinsic_type($(kIROp_TextureType + (feedbackTexture2DArrayFlavor << kIROpMeta_OtherShift)))
struct FeedbackTexture2DArray<T : __BuiltinSamplerFeedbackType>
{
diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp
new file mode 100644
index 000000000..0ad2bb101
--- /dev/null
+++ b/source/slang/slang-ast-base.cpp
@@ -0,0 +1,33 @@
+#include "slang-ast-base.h"
+#include "slang-ast-builder.h"
+
+namespace Slang
+{
+void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
+{
+#ifdef _DEBUG
+ SLANG_UNUSED(inAstNodeType);
+ static int32_t uidCounter = 0;
+ static int32_t breakValue = 0;
+ uidCounter++;
+ _debugUID = uidCounter;
+ if (inAstBuilder->getId() == -1)
+ _debugUID = -_debugUID;
+ if (breakValue != 0 && _debugUID == breakValue)
+ SLANG_BREAKPOINT(0)
+#else
+ SLANG_UNUSED(inAstNodeType);
+ SLANG_UNUSED(inAstBuilder);
+#endif
+}
+DeclRefBase* Decl::getDefaultDeclRef()
+{
+ auto astBuilder = getCurrentASTBuilder();
+ if (astBuilder->getEpoch() != m_defaultDeclRefEpoch || !m_defaultDeclRef)
+ {
+ m_defaultDeclRef = astBuilder->getDirectDeclRef(this);
+ m_defaultDeclRefEpoch = astBuilder->getEpoch();
+ }
+ return m_defaultDeclRef;
+}
+}
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index b90014316..d8f4c8c6c 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -16,25 +16,26 @@
namespace Slang
{
+class ASTBuilder;
+struct SemanticsVisitor;
+
class NodeBase
{
SLANG_ABSTRACT_AST_CLASS(NodeBase)
// MUST be called before used. Called automatically via the ASTBuilder.
// Note that the astBuilder is not stored in the NodeBase derived types by default.
- SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /*astBuilder*/)
+ SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
{
+ SLANG_UNUSED(inAstBuilder);
astNodeType = inAstNodeType;
#ifdef _DEBUG
- static uint32_t uidCounter = 0;
- static uint32_t breakValue = 0;
- uidCounter++;
- _debugUID = uidCounter;
- if (breakValue != 0 && _debugUID == breakValue)
- SLANG_BREAKPOINT(0)
+ _initDebug(inAstNodeType, inAstBuilder);
#endif
}
+ void _initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder);
+
/// Get the class info
SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const { return *ASTClassInfo::getInfo(astNodeType); }
@@ -48,7 +49,7 @@ class NodeBase
// Handy when debugging, shouldn't be checked in though!
// virtual ~NodeBase() {}
#ifdef _DEBUG
- SLANG_UNREFLECTED uint32_t _debugUID = 0;
+ SLANG_UNREFLECTED int32_t _debugUID = 0;
#endif
};
@@ -82,7 +83,14 @@ SLANG_FORCE_INLINE const T* as(const NodeBase* node)
// to avoid confusion and bugs. Instead, use the `as<>()` method on `DeclRefBase` to
// get a `DeclRef<T>` for a specific node type.
template<typename T>
-T* as(const DeclRefBase* declRefBase) = delete;
+T* as(DeclRefBase* declRefBase, typename EnableIf<!IsBaseOf<DeclRefBase, T>::Value, void*>::type arg = nullptr) = delete;
+
+template<typename T>
+T* as(DeclRefBase* declRefBase, typename EnableIf<IsBaseOf<DeclRefBase, T>::Value, void*>::type arg = nullptr)
+{
+ SLANG_UNUSED(arg);
+ return dynamicCast<T>(declRefBase);
+}
template<typename T, typename U>
DeclRef<T> as(DeclRef<U> declRef) { return DeclRef<T>(declRef); }
@@ -116,6 +124,177 @@ class SyntaxNodeBase : public NodeBase
SourceLoc loc;
};
+enum class ValNodeOperandKind
+{
+ ConstantValue,
+ ValNode,
+ ASTNode,
+};
+
+struct ValNodeOperand
+{
+ ValNodeOperandKind kind = ValNodeOperandKind::ConstantValue;
+ union
+ {
+ NodeBase* nodeOperand;
+ int64_t intOperand;
+ } values;
+
+ ValNodeOperand()
+ {
+ values.nodeOperand = nullptr;
+ }
+
+ explicit ValNodeOperand(NodeBase* node)
+ {
+ if (as<Val>(node))
+ {
+ values.nodeOperand = (NodeBase*)node;
+ kind = ValNodeOperandKind::ValNode;
+ }
+ else
+ {
+ values.nodeOperand = node;
+ kind = ValNodeOperandKind::ASTNode;
+ }
+ }
+
+ template<typename T>
+ explicit ValNodeOperand(DeclRef<T> declRef) { values.nodeOperand = declRef.declRefBase; kind = ValNodeOperandKind::ValNode; }
+
+ template<typename T>
+ explicit ValNodeOperand(T* node)
+ {
+ if constexpr (std::is_base_of<Val, T>::value)
+ {
+ values.nodeOperand = (NodeBase*)node;
+ kind = ValNodeOperandKind::ValNode;
+ }
+ else if constexpr (std::is_base_of<NodeBase, T>::value)
+ {
+ values.nodeOperand = node;
+ kind = ValNodeOperandKind::ASTNode;
+ }
+ else
+ {
+ static_assert(std::is_base_of<Val, T>::value || std::is_base_of<NodeBase, T>::value, "pointer used as Val operand must be an AST node.");
+ }
+ }
+
+ template<typename EnumType>
+ explicit ValNodeOperand(EnumType intVal)
+ {
+ static_assert(std::is_trivial<EnumType>::value, "Type to construct NodeOperand must be trivial.");
+ static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size.");
+ values.intOperand = 0;
+ memcpy(&values, &intVal, sizeof(intVal));
+ kind = ValNodeOperandKind::ConstantValue;
+ }
+};
+
+struct ValNodeDesc
+{
+ ASTNodeType type;
+ ShortList<ValNodeOperand, 4> operands;
+
+ bool operator==(ValNodeDesc const& that) const;
+ HashCode getHashCode() const { return hashCode; }
+ void init();
+private:
+ HashCode hashCode = 0;
+};
+
+template<int N>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>&)
+{}
+
+template<int N, typename... Ts>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, ExpandedSpecializationArgs e, Ts... ts)
+{
+ for (auto arg : e)
+ {
+ list.add(ValNodeOperand(arg.val));
+ list.add(ValNodeOperand(arg.witness));
+ }
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<int N, typename T, typename... Ts>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, T t, Ts... ts)
+{
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<int N, typename T, typename... Ts>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, const List<T>& l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<int N, typename T, typename... Ts>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, ConstArrayView<T> l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<int N, typename T, typename... Ts>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, ArrayView<T> l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+inline void addOrAppendToNodeList(List<ValNodeOperand>&)
+{}
+
+template<typename... Ts>
+static void addOrAppendToNodeList(List<ValNodeOperand>& list, ExpandedSpecializationArgs e, Ts... ts)
+{
+ for (auto arg : e)
+ {
+ list.add(ValNodeOperand(arg.val));
+ list.add(ValNodeOperand(arg.witness));
+ }
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<typename T, typename... Ts>
+static void addOrAppendToNodeList(List<ValNodeOperand>& list, T t, Ts... ts)
+{
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<typename T, typename... Ts>
+static void addOrAppendToNodeList(List<ValNodeOperand>& list, const List<T>& l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<typename T, typename... Ts>
+static void addOrAppendToNodeList(List<ValNodeOperand>& list, ConstArrayView<T> l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
+template<typename T, typename... Ts>
+static void addOrAppendToNodeList(List<ValNodeOperand>& list, ArrayView<T> l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
// Base class for compile-time values (most often a type).
// These are *not* syntax nodes, because they do not have
// a unique location, and any two `Val`s representing
@@ -124,6 +303,54 @@ class SyntaxNodeBase : public NodeBase
class Val : public NodeBase
{
SLANG_ABSTRACT_AST_CLASS(Val)
+
+ template<typename T>
+ struct OperandView
+ {
+ const Val* val;
+ Index offset;
+ Index count;
+ OperandView()
+ {
+ val = nullptr;
+ offset = 0;
+ count = 0;
+ }
+ OperandView(const Val* val, Index offset, Index count) : val(val), offset(offset), count(count) {}
+ Index getCount() { return count; }
+ T* operator[](Index index) const
+ {
+ return as<T>(val->getOperand(index + offset));
+ }
+ struct Iterator
+ {
+ const Val* val;
+ Index i;
+ bool operator==(Iterator other) const
+ {
+ return val == other.val && i == other.i;
+ }
+ bool operator!=(Iterator other) const
+ {
+ return val != other.val || i != other.i;
+ }
+ T*& operator*() const
+ {
+ return *(T**)&val->m_operands[i].values.nodeOperand;
+ }
+ T** operator->() const
+ {
+ return (T**)&val->m_operands[i].values.nodeOperand;
+ }
+ Iterator& operator++()
+ {
+ i++;
+ return *this;
+ }
+ };
+ Iterator begin() const { return Iterator { val, offset }; }
+ Iterator end() const { return Iterator{ val, offset + count }; }
+ };
typedef IValVisitor Visitor;
@@ -140,25 +367,84 @@ class Val : public NodeBase
// decide whether they need to do anything).
Val* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- bool equalsVal(Val* val);
+ bool equals(Val* val) const
+ {
+ return this == val || const_cast<Val*>(this)->resolve() == val->resolve();
+ }
// Appends as text to the end of the builder
void toText(StringBuilder& out);
String toString();
HashCode getHashCode();
- bool operator == (const Val & v)
+ bool operator == (const Val & v) const
{
- return equalsVal(const_cast<Val*>(&v));
+ return equals(const_cast<Val*>(&v));
}
// Overrides should be public so base classes can access
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
+
+ Val* _resolveImplOverride();
+
+ Val* resolveImpl();
+ Val* resolve();
+ ValNodeDesc getDesc();
+
+ Val* getOperand(Index index) const
+ {
+ SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ValNode);
+ return (Val*)m_operands[index].values.nodeOperand;
+ }
+
+ Decl* getDeclOperand(Index index) const
+ {
+ SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ASTNode);
+ return (Decl*)(m_operands[index].values.nodeOperand);
+ }
+
+ int64_t getIntConstOperand(Index index) const
+ {
+ SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ConstantValue);
+ return m_operands[index].values.intOperand;
+ }
+
+ Index getOperandCount() const { return m_operands.getCount(); }
+
+ template<typename ... TArgs>
+ void setOperands(TArgs... args)
+ {
+ m_operands.clear();
+ addOrAppendToNodeList(m_operands, args...);
+ }
+ template<typename ... TArgs>
+ void addOperands(TArgs... args)
+ {
+ addOrAppendToNodeList(m_operands, args...);
+ }
+ template<typename T>
+ void addOperands(OperandView<T> operands)
+ {
+ for (auto v : operands)
+ m_operands.add(ValNodeOperand(v));
+ }
+ List<ValNodeOperand> m_operands;
+protected:
+ Val* defaultResolveImpl();
+private:
+ mutable Val* m_resolvedVal = nullptr;
+ SLANG_UNREFLECTED mutable Index m_resolvedValEpoch = 0;
};
+template<int N, typename T, typename... Ts>
+static void addOrAppendToNodeList(ShortList<ValNodeOperand, N>& list, Val::OperandView<T> l, Ts... ts)
+{
+ for (auto t : l)
+ list.add(ValNodeOperand(t));
+ addOrAppendToNodeList(list, ts...);
+}
+
struct ValSet
{
struct ValItem
@@ -176,9 +462,9 @@ struct ValSet
if (val == other.val)
return true;
if (val)
- return val->equalsVal(other.val);
+ return val->equals(other.val);
else if (other.val)
- return other.val->equalsVal(val);
+ return other.val->equals(val);
return false;
}
};
@@ -232,31 +518,32 @@ class Type: public Val
/// Type derived types store the AST builder they were constructed on. The builder calls this function
/// after constructing.
- SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) { Val::init(inAstNodeType, inAstBuilder); m_astBuilder = inAstBuilder; }
-
- /// Get the ASTBuilder that was used to construct this Type
- SLANG_FORCE_INLINE ASTBuilder* getASTBuilder() const { return m_astBuilder; }
-
- bool equals(Type* type);
-
- Type* getCanonicalType();
+ SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
+ {
+ Val::init(inAstNodeType, inAstBuilder);
+ m_astBuilderForReflection = inAstBuilder;
+ }
// Overrides should be public so base classes can access
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- bool _equalsValOverride(Val* val);
- bool _equalsImplOverride(Type* type);
Type* _createCanonicalTypeOverride();
+ Val* _resolveImplOverride();
- void _setASTBuilder(ASTBuilder* astBuilder) { m_astBuilder = astBuilder; }
+ Type* getCanonicalType()
+ {
+ return as<Type>(resolve());
+ }
+ ASTBuilder* getASTBuilderForReflection() const { return m_astBuilderForReflection; }
protected:
- bool equalsImpl(Type* type);
Type* createCanonicalType();
- Type* canonicalType = nullptr;
-
- SLANG_UNREFLECTED
- ASTBuilder* m_astBuilder = nullptr;
+ // We store the ASTBuilder to support reflection API only.
+ // It should not be used for anything else, especially not for constructing new AST nodes during
+ // semantic checking, since Val deduplication requires the entire semantic checking process to
+ // stick with one ASTBuilder.
+ // Call getCurrentASTBuilder() to obtain the right ASTBuilder for semantic checking.
+ SLANG_UNREFLECTED ASTBuilder* m_astBuilderForReflection;
};
template <typename T>
@@ -264,161 +551,68 @@ SLANG_FORCE_INLINE T* as(Type* obj) { return obj ? dynamicCast<T>(obj->getCanoni
template <typename T>
SLANG_FORCE_INLINE const T* as(const Type* obj) { return obj ? dynamicCast<T>(const_cast<Type*>(obj)->getCanonicalType()) : nullptr; }
-// A substitution represents a binding of certain
-// type-level variables to concrete argument values
-class Substitutions: public NodeBase
-{
- SLANG_ABSTRACT_AST_CLASS(Substitutions)
-
-
- // Apply a set of substitutions to the bindings in this substitution
- Substitutions* applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
-
- // Check if these are equivalent substitutions to another set
- bool equals(Substitutions* subst);
- HashCode getHashCode() const;
-
- // Overrides should be public so base classes can access
- Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
- bool _equalsOverride(Substitutions* subst);
- HashCode _getHashCodeOverride() const;
-
- Substitutions* getOuter() const { return outer; }
-protected:
- // The next outer that this one refines.
- Substitutions* outer = nullptr;
-};
-
-class GenericSubstitution : public Substitutions
-{
- SLANG_AST_CLASS(GenericSubstitution)
-
-private:
- // The generic declaration that defines the
- // parameters we are binding to arguments
- GenericDecl* genericDecl = nullptr;
-
- // The actual values of the arguments
- List<Val* > args;
-public:
- GenericDecl* getGenericDecl() const { return genericDecl; }
- List<Val*>& getArgs() { return args; }
- const List<Val*>& getArgs() const { return args; }
-
- // Overrides should be public so base classes can access
- Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
- bool _equalsOverride(Substitutions* subst);
- HashCode _getHashCodeOverride() const;
-
- GenericSubstitution(Substitutions* outerSubst, GenericDecl* decl, ArrayView<Val*> argVals)
- {
- outer = outerSubst;
- genericDecl = decl;
- args.addRange(argVals);
- }
-};
-
-class ThisTypeSubstitution : public Substitutions
-{
- SLANG_AST_CLASS(ThisTypeSubstitution)
-
- // The declaration of the interface that we are specializing
- InterfaceDecl* interfaceDecl = nullptr;
-
- // A witness that shows that the concrete type used to
- // specialize the interface conforms to the interface.
- SubtypeWitness* witness = nullptr;
-
- // Overrides should be public so base classes can access
- // The actual type that provides the lookup scope for an associated type
- Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
- bool _equalsOverride(Substitutions* subst);
- HashCode _getHashCodeOverride() const;
-
- ThisTypeSubstitution(Substitutions* outerSubst, InterfaceDecl* inInterfaceDecl, SubtypeWitness* inWitness)
- : interfaceDecl(inInterfaceDecl), witness(inWitness)
- {
- outer = outerSubst;
- }
-};
-
class Decl;
// A reference to a declaration, which may include
// substitutions for generic parameters.
class DeclRefBase : public Val
{
- SLANG_AST_CLASS(DeclRefBase)
+ SLANG_ABSTRACT_AST_CLASS(DeclRefBase)
- Decl* getDecl() const { return decl; }
+ Decl* getDecl() const { return getDeclOperand(0); }
- Substitutions* getSubst() const { return substitutions; }
+ // Apply substitutions to this declaration reference
+ DeclRefBase* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- DeclRefBase(Decl* decl)
- :decl(decl)
+ DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
+ SLANG_UNUSED(astBuilder);
+ SLANG_UNUSED(subst);
+ SLANG_UNUSED(ioDiff);
+ SLANG_UNREACHABLE("DeclRefBase::_substituteImplOverride not overrided.");
}
- DeclRefBase(Decl* decl, Substitutions* subst)
- :decl(decl), substitutions(subst)
+ void _toTextOverride(StringBuilder& out)
{
+ SLANG_UNUSED(out);
+ SLANG_UNREACHABLE("DeclRefBase::_toTextOverride not overrided.");
}
- // Apply substitutions to a type or declaration
- Type* substitute(ASTBuilder* astBuilder, Type* type) const;
-
- DeclRefBase* substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const;
-
- // Apply substitutions to an expression
- SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const;
-
- // Apply substitutions to this declaration reference
- DeclRefBase* substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const;
-
- Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+ Val* _resolveImplOverride()
{
- return substituteImpl(astBuilder, subst, ioDiff);
+ SLANG_UNREACHABLE("DeclRefBase::_resolveImplOverride not overrided.");
}
- bool _equalsValOverride(Val* val);
-
- bool _equalsImplOverride(DeclRefBase* declRef) { return equals(declRef); }
- void _toTextOverride(StringBuilder& out) { toText(out); }
+ DeclRefBase* _getBaseOverride()
+ {
+ SLANG_UNREACHABLE("DeclRefBase::_getBaseOverride not overrided.");
+ }
// Returns true if 'as' will return a valid cast
template <typename T>
- bool is() const { return Slang::as<T>(decl) != nullptr; }
-
- // Check if this is an equivalent declaration reference to another
- bool equals(DeclRefBase* declRef) const;
+ bool is() const { return Slang::as<T>(getDecl()) != nullptr; }
// Convenience accessors for common properties of declarations
Name* getName() const;
SourceLoc getNameLoc() const;
SourceLoc getLoc() const;
- DeclRefBase* getParent(ASTBuilder* astBuilder) const;
-
- HashCode getHashCode() const;
-
- // Debugging:
- String toString() const;
- void toText(StringBuilder& out) const;
-
-private:
-
- // The underlying declaration
- Decl* decl = nullptr;
- // Optionally, a chain of substitutions to perform
- Substitutions* substitutions = nullptr;
-
+ DeclRefBase* getParent();
+ String toString() const
+ {
+ StringBuilder sb;
+ const_cast<DeclRefBase*>(this)->toText(sb);
+ return sb.produceString();
+ }
+ DeclRefBase* getBase();
+ void toText(StringBuilder& out);
};
-SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) { declRef->toText(io); return io; }
+SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) { if (declRef) const_cast<DeclRefBase*>(declRef)->toText(io); return io; }
-SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const Decl* decl)
+SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Decl* decl)
{
if (decl)
- _printNestedDecl(nullptr, decl, io);
+ makeDeclRef(decl).declRefBase->toText(io);
return io;
}
@@ -488,12 +682,7 @@ public:
ContainerDecl* parentDecl = nullptr;
- // A direct DeclRef to this Decl.
- // For every Decl, we create a DeclRef node representing a direct reference to it
- // upon the creation of the Decl (implemented in ASTBuilder), and store that
- // DeclRef here so we can get a direct DeclRef from a Decl* (by calling makeDeclRef)
- // without requiring a ASTBuilder to be available.
- DeclRefBase* defaultDeclRef = nullptr;
+ DeclRefBase* getDefaultDeclRef();
NameLoc nameAndLoc;
@@ -514,6 +703,10 @@ public:
SLANG_RELEASE_ASSERT(state >= checkState.getState());
checkState.setState(state);
}
+
+private:
+ SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
+ SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1;
};
class Expr : public SyntaxNode
@@ -551,8 +744,7 @@ DeclRef<T>::DeclRef(Decl* decl)
DeclRefBase* declRef = nullptr;
if (decl)
{
- SLANG_ASSERT(decl->defaultDeclRef);
- declRef = decl->defaultDeclRef;
+ declRef = decl->getDefaultDeclRef();
}
init(declRef);
}
@@ -564,12 +756,6 @@ T* DeclRef<T>::getDecl() const
}
template<typename T>
-Substitutions* DeclRef<T>::getSubst() const
-{
- return declRefBase ? declRefBase->getSubst() : nullptr;
-}
-
-template<typename T>
Name* DeclRef<T>::getName() const
{
if (declRefBase)
@@ -592,9 +778,9 @@ SourceLoc DeclRef<T>::getLoc() const
}
template<typename T>
-DeclRef<ContainerDecl> DeclRef<T>::getParent(ASTBuilder* astBuilder) const
+DeclRef<ContainerDecl> DeclRef<T>::getParent() const
{
- if (declRefBase) return DeclRef<ContainerDecl>(declRefBase->getParent(astBuilder));
+ if (declRefBase) return DeclRef<ContainerDecl>(declRefBase->getParent());
return DeclRef<ContainerDecl>((DeclRefBase*)nullptr);
}
@@ -608,15 +794,17 @@ HashCode DeclRef<T>::getHashCode() const
template<typename T>
Type* DeclRef<T>::substitute(ASTBuilder* astBuilder, Type* type) const
{
+ SLANG_UNUSED(astBuilder);
if (!declRefBase) return type;
- return declRefBase->substitute(astBuilder, type);
+ return SubstitutionSet(*this).applyToType(astBuilder, type);
}
template<typename T>
SubstExpr<Expr> DeclRef<T>::substitute(ASTBuilder* astBuilder, Expr* expr) const
{
+ SLANG_UNUSED(astBuilder);
if (!declRefBase) return expr;
- return declRefBase->substitute(astBuilder, expr);
+ return applySubstitutionToExpr(SubstitutionSet(*this), expr);
}
// Apply substitutions to a type or declaration
@@ -624,23 +812,21 @@ template<typename T>
template<typename U>
DeclRef<U> DeclRef<T>::substitute(ASTBuilder* astBuilder, DeclRef<U> declRef) const
{
+ SLANG_UNUSED(astBuilder);
if (!declRefBase) return declRef;
- return DeclRef<U>(declRefBase->substitute(astBuilder, declRef.declRefBase));
+ return DeclRef<U>(SubstitutionSet(*this).applyToDeclRef(astBuilder, declRef.declRefBase));
}
// Apply substitutions to this declaration reference
template<typename T>
DeclRef<T> DeclRef<T>::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const
{
+ SLANG_UNUSED(astBuilder);
if (!declRefBase) return *this;
return DeclRef<T>(declRefBase->substituteImpl(astBuilder, subst, ioDiff));
}
-template<typename T>
-template<typename U>
-bool DeclRef<T>::equals(DeclRef<U> other) const
-{
- return declRefBase == other.declRefBase || (declRefBase && declRefBase->equals(other.declRefBase));
-}
+Val::OperandView<Val> tryGetGenericArguments(SubstitutionSet substSet, Decl* genericDecl);
+
} // namespace Slang
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 64a7abd8c..96fb6ac79 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -29,12 +29,6 @@ void SharedASTBuilder::init(Session* session)
// Clear the built in types
memset(m_builtinTypes, 0, sizeof(m_builtinTypes));
- // Create common shared types
- m_errorType = m_astBuilder->create<ErrorType>();
- m_bottomType = m_astBuilder->create<BottomType>();
- m_initializerListType = m_astBuilder->create<InitializerListType>();
- m_overloadedType = m_astBuilder->create<OverloadGroupType>();
-
// We can just iterate over the class pointers.
// NOTE! That this adds the names of the abstract classes too(!)
for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i)
@@ -151,6 +145,31 @@ Type* SharedASTBuilder::getDiffInterfaceType()
return m_diffInterfaceType;
}
+Type* SharedASTBuilder::getErrorType()
+{
+ if (!m_errorType)
+ m_errorType = m_astBuilder->getOrCreate<ErrorType>();
+ return m_errorType;
+}
+Type* SharedASTBuilder::getBottomType()
+{
+ if (!m_bottomType)
+ m_bottomType = m_astBuilder->getOrCreate<BottomType>();
+ return m_bottomType;
+}
+Type* SharedASTBuilder::getInitializerListType()
+{
+ if (!m_initializerListType)
+ m_initializerListType = m_astBuilder->getOrCreate<InitializerListType>();
+ return m_initializerListType;
+}
+Type* SharedASTBuilder::getOverloadedType()
+{
+ if (!m_overloadedType)
+ m_overloadedType = m_astBuilder->getOrCreate<OverloadGroupType>();
+ return m_overloadedType;
+}
+
SharedASTBuilder::~SharedASTBuilder()
{
// Release built in types..
@@ -208,19 +227,28 @@ Decl* SharedASTBuilder::tryFindMagicDecl(const String& name)
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Index& _getGlobalASTEpochId()
+{
+ static thread_local Index epochId = 1;
+ return epochId;
+}
+
ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name):
m_sharedASTBuilder(sharedASTBuilder),
m_name(name),
m_id(sharedASTBuilder->m_id++),
- m_arena(2048)
+ m_arena(2097152)
{
SLANG_ASSERT(sharedASTBuilder);
+ // Copy Val deduplication map over so we don't create duplicate Vals that are already
+ // existent in the stdlib.
+ m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes;
}
ASTBuilder::ASTBuilder():
m_sharedASTBuilder(nullptr),
m_id(-1),
- m_arena(2048)
+ m_arena(2097152)
{
m_name = "SharedASTBuilder::m_astBuilder";
}
@@ -233,6 +261,25 @@ ASTBuilder::~ASTBuilder()
SLANG_ASSERT(info->m_destructorFunc);
info->m_destructorFunc(node);
}
+ incrementEpoch();
+}
+
+Index ASTBuilder::getEpoch()
+{
+ return _getGlobalASTEpochId();
+}
+
+void ASTBuilder::incrementEpoch()
+{
+ _getGlobalASTEpochId()++;
+}
+
+void ASTBuilder::_verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc)
+{
+ if (!val)
+ return;
+ ValNodeDesc descOut = val->getDesc();
+ SLANG_ASSERT(descOut == expectedDesc);
}
NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
@@ -256,6 +303,13 @@ Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTy
return rsType;
}
+Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName)
+{
+ auto declRef = getBuiltinDeclRef(magicTypeName, genericArgs);
+ auto rsType = DeclRefType::create(this, declRef);
+ return rsType;
+}
+
PtrType* ASTBuilder::getPtrType(Type* valueType)
{
return dynamicCast<PtrType>(getPtrType(valueType, "PtrType"));
@@ -292,64 +346,57 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element
{
if (!elementCount)
elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
-
- auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount);
- if (!result->declRef.getDecl())
+ if (elementCount->getType() != getIntType())
{
- auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType"));
- auto arrayTypeDecl = arrayGenericDecl->inner;
- auto substitutions = getOrCreateGenericSubstitution(nullptr, arrayGenericDecl, elementType, elementCount);
- result->declRef = getSpecializedDeclRef<Decl>(arrayTypeDecl, substitutions);
+ // Canonicalize constant elementCount to int.
+ if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount))
+ {
+ elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue());
+ }
}
- return result;
+ Val* args[] = {elementType, elementCount};
+ return as<ArrayExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType"));
}
ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType)
{
- auto result = getOrCreate<ConstantBufferType>(elementType);
- if (!result->declRef.getDecl())
- {
- auto genericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ConstantBuffer"));
- auto typeDecl = genericDecl->inner;
- auto substitutions = getOrCreateGenericSubstitution(nullptr, genericDecl, elementType);
- result->declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions);
- }
- return result;
+ return as<ConstantBufferType>(getSpecializedBuiltinType(elementType, "ConstantBufferType"));
+}
+
+ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType)
+{
+ return as<ParameterBlockType>(getSpecializedBuiltinType(elementType, "ParameterBlockType"));
+}
+
+HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType)
+{
+ return as<HLSLStructuredBufferType>(getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType"));
+}
+
+SamplerStateType* ASTBuilder::getSamplerStateType()
+{
+ return as<SamplerStateType>(getSpecializedBuiltinType(nullptr, "HLSLStructuredBufferType"));
}
VectorExpressionType* ASTBuilder::getVectorType(
Type* elementType,
IntVal* elementCount)
{
- auto result = getOrCreate<VectorExpressionType>(elementType, elementCount);
- if (!result->declRef.getDecl())
+ // Canonicalize constant elementCount to int.
+ if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount))
{
- auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
- auto vectorTypeDecl = vectorGenericDecl->inner;
- auto substitutions = getOrCreateGenericSubstitution(nullptr, vectorGenericDecl, elementType, elementCount);
- result->declRef = getSpecializedDeclRef<Decl>(vectorTypeDecl, substitutions);
+ elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue());
}
- return result;
+ Val* args[] = { elementType, elementCount };
+ return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType"));
}
DifferentialPairType* ASTBuilder::getDifferentialPairType(
Type* valueType,
Witness* primalIsDifferentialWitness)
{
- auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType"));
-
- auto typeDecl = genericDecl->inner;
-
- auto substitutions = getOrCreateGenericSubstitution(
- nullptr,
- genericDecl,
- valueType,
- primalIsDifferentialWitness);
-
- auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions);
- auto rsType = DeclRefType::create(this, declRef);
-
- return as<DifferentialPairType>(rsType);
+ Val* args[] = { valueType, primalIsDifferentialWitness };
+ return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
@@ -377,20 +424,9 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier(
: as<HLSLIndicesModifier>(modifier) ? "IndicesType"
: as<HLSLPrimitivesModifier>(modifier) ? "PrimitivesType"
: (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr);
- auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl(declName));
-
- auto typeDecl = genericDecl->inner;
-
- auto substitutions = getOrCreateGenericSubstitution(
- nullptr,
- genericDecl,
- elementType,
- maxElementCount);
- auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions);
- auto rsType = DeclRefType::create(this, declRef);
-
- return as<MeshOutputType>(rsType);
+ Val* args[] = { elementType, maxElementCount };
+ return as<MeshOutputType>(getSpecializedBuiltinType(makeArrayView(args), declName));
}
Type* ASTBuilder::getDifferentiableInterfaceType()
@@ -403,13 +439,8 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
if (auto genericDecl = as<GenericDecl>(decl))
{
- decl = genericDecl->inner;
- Substitutions* subst = nullptr;
- if (genericArg)
- {
- subst = getOrCreateGenericSubstitution(nullptr, genericDecl, genericArg);
- }
- return getSpecializedDeclRef(decl, subst);
+ auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg));
+ return declRef;
}
else
{
@@ -418,6 +449,21 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va
return makeDeclRef(decl);
}
+DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs)
+{
+ auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
+ if (auto genericDecl = as<GenericDecl>(decl))
+ {
+ auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), genericArgs);
+ return declRef;
+ }
+ else
+ {
+ SLANG_ASSERT(!decl && !genericArgs.getCount());
+ }
+ return makeDeclRef(decl);
+}
+
Type* ASTBuilder::getAndType(Type* left, Type* right)
{
auto type = getOrCreate<AndType>(left, right);
@@ -426,9 +472,7 @@ Type* ASTBuilder::getAndType(Type* left, Type* right)
Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* modifiers)
{
- auto type = create<ModifiedType>();
- type->base = base;
- type->modifiers.addRange(modifiers, modifierCount);
+ auto type = getOrCreate<ModifiedType>(base, makeArrayView((Val**)modifiers, modifierCount));
return type;
}
@@ -447,15 +491,16 @@ Val* ASTBuilder::getNoDiffModifierVal()
return getOrCreate<NoDiffModifierVal>();
}
-Type* ASTBuilder::getFuncType(List<Type*> parameters, Type* result)
+FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType)
{
- auto errorType = getOrCreate<BottomType>();
+ if (!errorType)
+ errorType = getOrCreate<BottomType>();
return getOrCreate<FuncType>(parameters, result, errorType);
}
-Type* ASTBuilder::getTupleType(List<Type*>& types)
+TupleType* ASTBuilder::getTupleType(List<Type*>& types)
{
- return getOrCreate<TupleType>(types);
+ return getOrCreate<TupleType>(types.getArrayView());
}
TypeType* ASTBuilder::getTypeType(Type* type)
@@ -466,11 +511,11 @@ TypeType* ASTBuilder::getTypeType(Type* type)
TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(
Type* type)
{
- return getOrCreate<TypeEqualityWitness>(type);
+ return getOrCreate<TypeEqualityWitness>(type, type);
}
-SubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness(
+DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness(
Type* subType,
Type* superType,
DeclRef<Decl> const& declRef)
@@ -517,8 +562,8 @@ top:
// Let's call the intermediate type here `x`, we know that the `b <: c`
// witness is based on witnesses that `b <: x` and `x <: c`:
//
- auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->subToMid;
- auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->midToSup;
+ auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->getSubToMid();
+ auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->getMidToSup();
// We can recursively call this operation to produce a witness that
// `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`:
@@ -535,8 +580,8 @@ top:
goto top;
}
- auto aType = aIsSubtypeOfBWitness->sub;
- auto cType = bIsSubtypeOfCWitness->sup;
+ auto aType = aIsSubtypeOfBWitness->getSub();
+ auto cType = bIsSubtypeOfCWitness->getSup();
// If the right-hand side is a conjunction witness for `B <: C`
// of the form `(B <: X)&(B <: Y)`, then we have it that `C = X&Y`
@@ -565,8 +610,8 @@ top:
// the witness `W` that `B <: X&Y&...` as well as the index
// `i` of `C` within the conjunction.
//
- auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->conjunctionWitness;
- auto indexOfCInConjunction = bIsSubtypeViaExtraction->indexInConjunction;
+ auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->getConjunctionWitness();
+ auto indexOfCInConjunction = bIsSubtypeViaExtraction->getIndexInConjunction();
// We lift the extraction to the outside of the composition, by
// forming a witness for `A <: C` that is of the form
@@ -591,24 +636,14 @@ top:
// formal set of rules for the allowed structure of our witnesses to
// guarantee that our simplifications are sufficient.
- TransitiveSubtypeWitness* transitiveWitness = getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(
+ TransitiveSubtypeWitness* transitiveWitness = getOrCreate<TransitiveSubtypeWitness>(
aType,
cType,
aIsSubtypeOfBWitness,
bIsSubtypeOfCWitness);
- transitiveWitness->sub = aType;
- transitiveWitness->sup = cType;
- transitiveWitness->subToMid = aIsSubtypeOfBWitness;
- transitiveWitness->midToSup = bIsSubtypeOfCWitness;
-
return transitiveWitness;
}
-ThisTypeSubtypeWitness* ASTBuilder::getThisTypeSubtypeWitness(Type* subType, Type* superType)
-{
- return getOrCreate<ThisTypeSubtypeWitness>(subType, superType);
-}
-
SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
Type* subType,
Type* superType,
@@ -633,16 +668,11 @@ SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
//
// * What if the original witness is transitive?
- auto witness = getOrCreateWithDefaultCtor<ExtractFromConjunctionSubtypeWitness>(
+ auto witness = getOrCreate<ExtractFromConjunctionSubtypeWitness>(
subType,
superType,
conjunctionWitness,
indexOfSuperTypeInConjunction);
-
- witness->sub = subType;
- witness->sup = superType;
- witness->conjunctionWitness = conjunctionWitness;
- witness->indexInConjunction = indexOfSuperTypeInConjunction;
return witness;
}
@@ -662,11 +692,11 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
auto rExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsRWitness);
if(lExtract && rExtract)
{
- if (lExtract->indexInConjunction == 0
- && rExtract->indexInConjunction == 1)
+ if (lExtract->getIndexInConjunction() == 0
+ && rExtract->getIndexInConjunction() == 1)
{
- auto lInner = lExtract->conjunctionWitness;
- auto rInner = rExtract->conjunctionWitness;
+ auto lInner = lExtract->getConjunctionWitness();
+ auto rInner = rExtract->getConjunctionWitness();
if (lInner == rInner)
{
return lInner;
@@ -685,57 +715,30 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
// witness) deeper, so that we have more chances to expose a
// conjunction witness at higher levels.
- auto witness = getOrCreateWithDefaultCtor<ConjunctionSubtypeWitness>(
+ auto witness = getOrCreate<ConjunctionSubtypeWitness>(
sub,
lAndR,
subIsLWitness,
subIsRWitness);
- witness->componentWitnesses[0] = subIsLWitness;
- witness->componentWitnesses[1] = subIsRWitness;
- witness->sub = sub;
- witness->sup = lAndR;
return witness;
}
-bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const
+DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl)
{
- if (hashCode != that.hashCode) return false;
- if(type != that.type) return false;
- if(operands.getCount() != that.operands.getCount()) return false;
- for(Index i = 0; i < operands.getCount(); ++i)
- {
- // Note: we are comparing the operands directly for identity
- // (pointer equality) rather than doing the `Val`-level
- // equality check.
- //
- // The rationale here is that nodes that will be created
- // via a `NodeDesc` *should* all be going through the
- // deduplication path anyway, as should their operands.
- //
- if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
- }
- return true;
+ return builder->getMemberDeclRef(parent, decl);
}
-void ASTBuilder::NodeDesc::init()
+
+thread_local ASTBuilder* gCurrentASTBuilder = nullptr;
+
+ASTBuilder* getCurrentASTBuilder()
{
- Hasher hasher;
- hasher.hashValue(Int(type));
- for(Index i = 0; i < operands.getCount(); ++i)
- {
- // Note: we are hashing the raw pointer value rather
- // than the content of the value node. This is done
- // to match the semantics implemented for `==` on
- // `NodeDesc`.
- //
- hasher.hashValue(operands[i].values.nodeOperand);
- }
- hashCode = hasher.getResult();
+ return gCurrentASTBuilder;
}
-DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst)
+void setCurrentASTBuilder(ASTBuilder* astBuilder)
{
- return builder->getSpecializedDeclRef(decl, subst);
+ gCurrentASTBuilder = astBuilder;
}
} // namespace Slang
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index cf0975cdd..0d63e1060 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -39,6 +39,11 @@ public:
/// Get the `IDifferentiable` type
Type* getDiffInterfaceType();
+ Type* getErrorType();
+ Type* getBottomType();
+ Type* getInitializerListType();
+ Type* getOverloadedType();
+
const ReflectClassInfo* findClassInfo(Name* name);
SyntaxClass<NodeBase> findSyntaxClass(Name* name);
@@ -65,6 +70,8 @@ public:
~SharedASTBuilder();
+ ASTBuilder* getInnerASTBuilder() { return m_astBuilder; }
+
protected:
// State shared between ASTBuilders
@@ -108,79 +115,59 @@ protected:
class ASTBuilder : public RefObject
{
friend class SharedASTBuilder;
-public:
- // Node cache:
- struct NodeOperand
- {
- union
- {
- NodeBase* nodeOperand;
- int64_t intOperand;
- } values;
-
- NodeOperand()
- {
- values.nodeOperand = nullptr;
- }
-
- NodeOperand(NodeBase* node) { values.nodeOperand = node; }
-
- template<typename T>
- NodeOperand(DeclRef<T> declRef) { values.nodeOperand = declRef.declRefBase; }
-
- template<typename EnumType>
- NodeOperand(EnumType intVal)
- {
- static_assert(std::is_trivial<EnumType>::value, "Type to construct NodeOperand must be trivial.");
- static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size.");
- values.intOperand = 0;
- memcpy(&values, &intVal, sizeof(intVal));
- }
- };
- struct NodeDesc
- {
- ASTNodeType type;
- ShortList<NodeOperand, 4> operands;
-
- bool operator==(NodeDesc const& that) const;
- HashCode getHashCode() const { return hashCode; }
- void init();
- private:
- HashCode hashCode = 0;
- };
+public:
template<typename NodeCreateFunc>
- NodeBase* _getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc)
+ NodeBase* _getOrCreateImpl(ValNodeDesc const& desc, NodeCreateFunc createFunc)
{
if (auto found = m_cachedNodes.tryGetValue(desc))
return *found;
auto node = createFunc();
m_cachedNodes.add(desc, node);
+#ifdef _DEBUG
+ _verifyValDescConsistency(dynamicCast<Val>(node), desc);
+#endif
return node;
}
/// A cache for AST nodes that are entirely defined by their node type, with
/// no need for additional state.
- Dictionary<NodeDesc, NodeBase*> m_cachedNodes;
+ Dictionary<ValNodeDesc, NodeBase*> m_cachedNodes;
+
+ Dictionary<GenericDecl*, List<Val*>> m_cachedGenericDefaultArgs;
+
+ /// Create AST types
+ template <typename T>
+ T* createImpl()
+ {
+ auto alloced = m_arena.allocate(sizeof(T));
+ memset(alloced, 0, sizeof(T));
+ auto result = _initAndAdd(new (alloced) T);
+ return result;
+ }
- template<int N>
- static void addOrAppendToNodeList(ShortList<NodeOperand, N>&)
- {}
+ template<typename T, typename... TArgs>
+ T* createImpl(TArgs&&... args)
+ {
+ auto alloced = m_arena.allocate(sizeof(T));
+ memset(alloced, 0, sizeof(T));
+ auto result = _initAndAdd(new (alloced) T(std::forward<TArgs>(args)...));
+ return result;
+ }
- template<int N, typename T, typename... Ts>
- static void addOrAppendToNodeList(ShortList<NodeOperand, N>& list, T t, Ts... ts)
+ template <typename T>
+ T* create()
{
- list.add(t);
- addOrAppendToNodeList(list, ts...);
+ static_assert(!IsBaseOf<Val, T>::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead.");
+ return createImpl<T>();
}
- template<int N, typename T, typename... Ts>
- static void addOrAppendToNodeList(ShortList<NodeOperand, N>& list, const List<T>& l, Ts... ts )
+ template<typename T, typename... TArgs>
+ T* create(TArgs&&... args)
{
- for(auto t : l)
- list.add(t);
- addOrAppendToNodeList(list, ts...);
+ static_assert(!IsBaseOf<Val, T>::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead.");
+ return createImpl<T>(args...);
}
public:
@@ -195,37 +182,27 @@ public:
};
};
- MemoryArena& getArena() { return m_arena; }
+ Index getEpoch();
- /// Create AST types
- template <typename T>
- T* create()
- {
- auto alloced = m_arena.allocate(sizeof(T));
- memset(alloced, 0, sizeof(T));
- return _initAndAdd(new (alloced) T);
- }
+ void incrementEpoch();
- template<typename T, typename... TArgs>
- T* create(TArgs&&... args)
- {
- auto alloced = m_arena.allocate(sizeof(T));
- memset(alloced, 0, sizeof(T));
- return _initAndAdd(new (alloced) T(std::forward<TArgs>(args)...));
- }
+ MemoryArena& getArena() { return m_arena; }
+
+ void _verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc);
template<typename T, typename ... TArgs>
SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args)
{
SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
- NodeDesc desc;
+ ValNodeDesc desc;
desc.type = T::kType;
addOrAppendToNodeList(desc.operands, args...);
desc.init();
- return (T*)_getOrCreateImpl(desc, [&]()
+ auto result = (T*)_getOrCreateImpl(desc, [&]()
{
- return create<T>(args...);
+ return createImpl<T>(args...);
});
+ return result;
}
template<typename T>
@@ -233,63 +210,101 @@ public:
{
SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
- NodeDesc desc;
+ ValNodeDesc desc;
desc.type = T::kType;
desc.init();
- return (T*)_getOrCreateImpl(desc, [this]() { return create<T>(); });
+ auto result = (T*)_getOrCreateImpl(desc, [this]() { return createImpl<T>(); });
+#ifdef _DEBUG
+ _verifyValDescConsistency(dynamicCast<Val>(result), desc);
+#endif
+ return result;
}
- template<typename T, typename ... TArgs>
- SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(TArgs ... args)
+ InterfaceDecl* createInterfaceDecl(SourceLoc loc)
{
- SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
- NodeDesc desc;
- desc.type = T::kType;
- addOrAppendToNodeList(desc.operands, args...);
- desc.init();
- return (T*)_getOrCreateImpl(desc, [&]()
- {
- return create<T>();
- });
+ auto interfaceDecl = create<InterfaceDecl>();
+ // Always include a `This` member and a `This:IThisInterface` member.
+ auto thisDecl = create<ThisTypeDecl>();
+ thisDecl->nameAndLoc.name = m_sharedASTBuilder->getNamePool()->getName(UnownedStringSlice("This", 4));
+ thisDecl->nameAndLoc.loc = loc;
+ interfaceDecl->addMember(thisDecl);
+ auto thisConstraint = create<ThisTypeConstraintDecl>();
+ thisConstraint->loc = loc;
+ thisConstraint->base.type = DeclRefType::create(this, getDirectDeclRef(interfaceDecl));
+ thisDecl->addMember(thisConstraint);
+ return interfaceDecl;
}
template<typename T>
- SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(ConstArrayView<NodeOperand> operands)
+ DeclRef<T> getDirectDeclRef(T* decl)
{
- SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
- NodeDesc desc;
- desc.type = T::kType;
- desc.operands.addRange(operands);
- desc.init();
- return (T*)_getOrCreateImpl(desc, [&]()
- {
- return create<T>();
- });
+ if (!decl)
+ return DeclRef<T>();
+
+ auto result = DeclRef<T>(getOrCreate<DirectDeclRef>(decl));
+ return result;
}
- // This is the bottlneck through which all DeclRefs are created.
template<typename T>
- DeclRef<T> getSpecializedDeclRef(T* decl, Substitutions* subst)
+ DeclRef<T> getMemberDeclRef(DeclRef<Decl> parent, T* memberDecl)
{
- // We never create an actual DeclRefBase node to point to a null decl.
- if (!decl)
- return DeclRef<T>();
-
- // If we don't have substitutions, use the default decl ref if it is created.
- if (!subst)
+ if (!parent)
+ return getDirectDeclRef(memberDecl);
+ // A Generic value/type ParamDecl is always referred to directly.
+ if (as<GenericTypeParamDecl>(memberDecl) || as<GenericValueParamDecl>(memberDecl))
+ return getDirectDeclRef(memberDecl);
+ if (as<ThisTypeDecl>(memberDecl) && !as<InterfaceDecl>(memberDecl->parentDecl))
+ return as<T>(parent);
+
+ if (auto parentMemberDeclRef = as<MemberDeclRef>(parent.declRefBase))
{
- auto defaultDeclRef = static_cast<Decl*>(decl)->defaultDeclRef;
- if (defaultDeclRef)
- return defaultDeclRef;
+ return DeclRef<T>(getMemberDeclRef(parentMemberDeclRef->getParent(), memberDecl));
+ }
+ else if (auto lookupDeclRef = as<LookupDeclRef>(parent.declRefBase))
+ {
+ // Handle some specicial case rules due to the way some of our builtin decls are
+ // represented.
+ // - Member(Lookup(w, This), x) ==> Lookup(w, X)
+ // Lookup of x from This is a lookup from w directly.
+ // - Member(Lookup(w, someExtension), x) ==> Lookup(w, X)
+ // Lookup of a decl defined in an extension is to lookup directly.
+ // - Member(Lookup(w, AssociatedType), TypeConstraintDecl) ==> Lookup(w, TypeConstraintDecl)
+ // Type constraint of an associated type is defined directly in w.
+
+ auto parentDeclKind = lookupDeclRef->getDecl()->astNodeType;
+ switch (parentDeclKind)
+ {
+ case ASTNodeType::ThisTypeDecl:
+ case ASTNodeType::ExtensionDecl:
+ case ASTNodeType::AssocTypeDecl:
+ return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl);
+ default:
+ break;
+ }
+ }
+ else if (auto directDeclRef = as<DirectDeclRef>(parent.declRefBase))
+ {
+ return DeclRef<T>(getOrCreate<DirectDeclRef>(memberDecl));
}
- return getOrCreate<DeclRefBase>(decl, subst);
- }
+#if _DEBUG
+ // Verify that member is indeed a member of parent.
+ auto parentDecl = parent.getDecl();
+ while (as<ThisTypeDecl>(parentDecl))
+ parentDecl = parentDecl->parentDecl;
+ bool foundParent = false;
+ for (Decl* dd = memberDecl; dd; dd = dd->parentDecl)
+ {
+ if (dd == parentDecl)
+ {
+ foundParent = true;
+ break;
+ }
+ }
+ SLANG_ASSERT(foundParent);
+#endif
- template<typename T>
- DeclRef<T> getSpecializedDeclRef(T* decl, SubstitutionSet subst)
- {
- return getSpecializedDeclRef(decl, subst.substitutions);
+ return DeclRef<T>(getOrCreate<MemberDeclRef>(memberDecl, parent.declRefBase));
}
ConstantIntVal* getIntVal(Type* type, IntegerLiteralValue value)
@@ -297,61 +312,38 @@ public:
return getOrCreate<ConstantIntVal>(type, value);
}
- GenericSubstitution* getOrCreateGenericSubstitution(Substitutions* outer, GenericDecl* decl, ArrayView<Val*> args)
+ DeclRef<Decl> getGenericAppDeclRef(DeclRef<GenericDecl> genericDeclRef, ConstArrayView<Val*> args, Decl* innerDecl = nullptr)
{
- NodeDesc desc;
- desc.type = GenericSubstitution::kType;
- desc.operands.add(decl);
- for (auto arg : args)
- desc.operands.add(arg);
- if (outer)
- {
- desc.operands.add(outer);
- }
- desc.init();
- auto result = (GenericSubstitution*)_getOrCreateImpl(desc, [this]() {return create<GenericSubstitution>(); });
- if (result->args.getCount() != args.getCount())
- {
- SLANG_RELEASE_ASSERT(result->args.getCount() == 0);
- result->args.addRange(args);
- result->genericDecl = decl;
- result->outer = outer;
- }
- return result;
- }
+ if (!innerDecl)
+ innerDecl = genericDeclRef.getDecl()->inner;
- GenericSubstitution* getOrCreateGenericSubstitution(Substitutions* outer, GenericDecl* decl, const List<Val*>& args)
- {
- return getOrCreateGenericSubstitution(outer, decl, args.getArrayView());
+ return getOrCreate<GenericAppDeclRef>(innerDecl, genericDeclRef, args);
}
- template<typename... Args>
- GenericSubstitution* getOrCreateGenericSubstitution(Substitutions* outer, GenericDecl* decl, Args... args)
+ DeclRef<Decl> getGenericAppDeclRef(DeclRef<GenericDecl> genericDeclRef, Val::OperandView<Val> args, Decl* innerDecl = nullptr)
{
- List<Val*> vals;
- addToList(vals, args...);
- return getOrCreateGenericSubstitution(outer, decl, vals.getArrayView());
- }
+ if (!innerDecl)
+ innerDecl = genericDeclRef.getDecl()->inner;
+ return getOrCreate<GenericAppDeclRef>(innerDecl, genericDeclRef, args);
+ }
- ThisTypeSubstitution* getOrCreateThisTypeSubstitution(InterfaceDecl* interfaceDecl, SubtypeWitness* subtypeWitness, Substitutions* outer)
+ LookupDeclRef* getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup)
{
- NodeDesc desc;
- desc.type = ThisTypeSubstitution::kType;
- desc.operands.add(interfaceDecl);
- desc.operands.add(subtypeWitness);
- if (outer)
- {
- desc.operands.add(outer);
- }
+ ValNodeDesc desc;
+ desc.type = LookupDeclRef::kType;
+ desc.operands.add(ValNodeOperand(subtypeWitness));
+ desc.operands.add(ValNodeOperand(declToLookup));
desc.init();
- auto result = (ThisTypeSubstitution*)_getOrCreateImpl(desc, [this]() {return create<ThisTypeSubstitution>(); });
- result->interfaceDecl = interfaceDecl;
- result->witness = subtypeWitness;
- result->outer = outer;
+ auto result = getOrCreate<LookupDeclRef>(declToLookup, base, subtypeWitness);
return result;
}
+ LookupDeclRef* getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup)
+ {
+ return getLookupDeclRef(subtypeWitness->getSub(), subtypeWitness, declToLookup);
+ }
+
NodeBase* createByNodeType(ASTNodeType nodeType);
/// Get the built in types
@@ -371,11 +363,12 @@ public:
SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; }
Type* getSpecializedBuiltinType(Type* typeParam, const char* magicTypeName);
+ Type* getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName);
- Type* getInitializerListType() { return m_sharedASTBuilder->m_initializerListType; }
- Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; }
- Type* getErrorType() { return m_sharedASTBuilder->m_errorType; }
- Type* getBottomType() { return m_sharedASTBuilder->m_bottomType; }
+ Type* getInitializerListType() { return m_sharedASTBuilder->getInitializerListType(); }
+ Type* getOverloadedType() { return m_sharedASTBuilder->getOverloadedType(); }
+ Type* getErrorType() { return m_sharedASTBuilder->getErrorType(); }
+ Type* getBottomType() { return m_sharedASTBuilder->getBottomType(); }
Type* getStringType() { return m_sharedASTBuilder->getStringType(); }
Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); }
Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); }
@@ -407,13 +400,18 @@ public:
ConstantBufferType* getConstantBufferType(Type* elementType);
+ ParameterBlockType* getParameterBlockType(Type* elementType);
+
+ HLSLStructuredBufferType* getStructuredBufferType(Type* elementType);
+
+ SamplerStateType* getSamplerStateType();
+
DifferentialPairType* getDifferentialPairType(
Type* valueType,
Witness* primalIsDifferentialWitness);
DeclRef<InterfaceDecl> getDifferentiableInterfaceDecl();
Type* getDifferentiableInterfaceType();
- Decl* getDifferentiableAssociatedTypeRequirement();
bool isDifferentiableInterfaceAvailable();
@@ -423,6 +421,7 @@ public:
IntVal* maxElementCount);
DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg);
+ DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs);
Type* getAndType(Type* left, Type* right);
@@ -435,9 +434,9 @@ public:
Val* getSNormModifierVal();
Val* getNoDiffModifierVal();
- Type* getTupleType(List<Type*>& types);
+ TupleType* getTupleType(List<Type*>& types);
- Type* getFuncType(List<Type*> parameters, Type* result);
+ FuncType* getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType = nullptr);
TypeType* getTypeType(Type* type);
@@ -445,7 +444,7 @@ public:
TypeEqualityWitness* getTypeEqualityWitness(
Type* type);
- SubtypeWitness* getDeclaredSubtypeWitness(
+ DeclaredSubtypeWitness* getDeclaredSubtypeWitness(
Type* subType,
Type* superType,
DeclRef<Decl> const& declRef);
@@ -455,9 +454,6 @@ public:
SubtypeWitness* aIsSubtypeOfBWitness,
SubtypeWitness* bIsSubtypeOfCWitness);
- /// Produce a witness that `ThisType(IFoo) <: IFoo`.
- ThisTypeSubtypeWitness* getThisTypeSubtypeWitness(Type* subType, Type* superType);
-
/// Produce a witness that `T <: L` or `T <: R` given `T <: L&R`
SubtypeWitness* getExtractFromConjunctionSubtypeWitness(
Type* subType,
@@ -487,14 +483,14 @@ public:
/// Get the global session
Session* getGlobalSession() { return m_sharedASTBuilder->m_session; }
+ Index getId() { return m_id; }
+
/// Ctor
ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name);
/// Dtor
~ASTBuilder();
- Dictionary<Decl*, GenericSubstitution*> m_genericDefaultSubst;
-
protected:
// Special default Ctor that can only be used by SharedASTBuilder
ASTBuilder();
@@ -512,11 +508,12 @@ protected:
// Keep such that dtor can be run on ASTBuilder being dtored
m_dtorNodes.add(node);
}
- if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType)))
+ if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Val::kType)))
{
- auto decl = (Decl*)(node);
- decl->defaultDeclRef = getSpecializedDeclRef(decl, nullptr);
+ auto val = (Val*)(node);
+ val->m_resolvedValEpoch = getEpoch();
}
+
return node;
}
@@ -529,9 +526,30 @@ protected:
SharedASTBuilder* m_sharedASTBuilder;
MemoryArena m_arena;
+};
+
+// Retrieves the ASTBuilder for the current compilation session.
+ASTBuilder* getCurrentASTBuilder();
+// Sets the ASTBuilder for the current compilation session.
+void setCurrentASTBuilder(ASTBuilder* astBuilder);
+
+struct SetASTBuilderContextRAII
+{
+ ASTBuilder* previousASTBuilder = nullptr;
+ SetASTBuilderContextRAII(ASTBuilder* astBuilder)
+ {
+ previousASTBuilder = getCurrentASTBuilder();
+ setCurrentASTBuilder(astBuilder);
+ }
+ ~SetASTBuilderContextRAII()
+ {
+ setCurrentASTBuilder(previousASTBuilder);
+ }
};
+#define SLANG_AST_BUILDER_RAII(astBuilder) SetASTBuilderContextRAII _setASTBuilderContextRAII(astBuilder)
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp
new file mode 100644
index 000000000..4384a6df9
--- /dev/null
+++ b/source/slang/slang-ast-decl-ref.cpp
@@ -0,0 +1,461 @@
+#include "slang-ast-builder.h"
+#include "slang-ast-reflect.h"
+#include "slang-generated-ast.h"
+#include "slang-generated-ast-macro.h"
+#include "slang-check-impl.h"
+
+namespace Slang
+{
+
+DeclRefBase* DirectDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ SLANG_UNUSED(astBuilder);
+ SLANG_UNUSED(subst);
+ SLANG_UNUSED(ioDiff);
+ return this;
+}
+
+void DirectDeclRef::_toTextOverride(StringBuilder& out)
+{
+ if (getDecl()->getName() && getDecl()->getName()->text.getLength() != 0)
+ {
+ out << getDecl()->getName()->text;
+ }
+}
+
+Val* DirectDeclRef::_resolveImplOverride()
+{
+ return this;
+}
+
+DeclRefBase* DirectDeclRef::_getBaseOverride()
+{
+ return nullptr;
+}
+
+DeclRefBase* _getDeclRefFromVal(Val* val)
+{
+ if (auto declRefType = as<DeclRefType>(val))
+ return declRefType->getDeclRef();
+ else if (auto genParamIntVal = as<GenericParamIntVal>(val))
+ return genParamIntVal->getDeclRef();
+ else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val))
+ return declaredSubtypeWitness->getDeclRef();
+ else if (auto declRef = as<DeclRefBase>(val))
+ return declRef;
+ return nullptr;
+}
+
+DeclRefBase* _resolveAsDeclRef(DeclRefBase* declRefToResolve)
+{
+ if (auto rs = _getDeclRefFromVal(declRefToResolve->resolve()))
+ return rs;
+ return declRefToResolve;
+}
+
+DeclRefBase* MemberDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto substParent = getParentOperand()->substituteImpl(astBuilder, subst, &diff);
+ if (diff)
+ {
+ (*ioDiff)++;
+ return astBuilder->getMemberDeclRef(substParent, getDecl());
+ }
+ return this;
+}
+
+void MemberDeclRef::_toTextOverride(StringBuilder& out)
+{
+ getParentOperand()->toText(out);
+ if (out.getLength() && !out.endsWith("."))
+ out << ".";
+ if (getDecl()->getName() && getDecl()->getName()->text.getLength() != 0)
+ {
+ out << getDecl()->getName()->text;
+ }
+}
+
+Val* MemberDeclRef::_resolveImplOverride()
+{
+ auto resolvedParent = _resolveAsDeclRef(getParentOperand());
+ if (resolvedParent != getParentOperand())
+ {
+ return getCurrentASTBuilder()->getMemberDeclRef(resolvedParent, getDecl());
+ }
+ return this;
+}
+
+DeclRefBase* MemberDeclRef::_getBaseOverride()
+{
+ return getParentOperand();
+}
+
+Decl* LookupDeclRef::getSupDecl()
+{
+ if (auto supType = as<DeclRefType>(getWitness()->getSup()))
+ {
+ return supType->getDeclRef().getDecl();
+ }
+ // If we reach here, something is wrong.
+ SLANG_UNEXPECTED("Invalid lookup declref");
+}
+
+DeclRefBase* LookupDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+
+ auto substWitness = as<SubtypeWitness>(getWitness()->substituteImpl(astBuilder, subst, &diff));
+ if (diff == 0)
+ return this;
+ (*ioDiff)++;
+
+ auto substSource = as<Type>(getLookupSource()->substituteImpl(astBuilder, subst, &diff));
+ SLANG_ASSERT(substSource);
+
+ if (auto resolved = _getDeclRefFromVal(tryResolve(substWitness, substSource)))
+ return resolved;
+
+ return astBuilder->getLookupDeclRef(substSource, substWitness, getDecl());
+}
+
+void LookupDeclRef::_toTextOverride(StringBuilder& out)
+{
+ getLookupSource()->toText(out);
+ if (out.getLength() && !out.endsWith("."))
+ out << ".";
+ if (getDecl()->getName() && getDecl()->getName()->text.getLength() != 0)
+ {
+ out << getDecl()->getName()->text;
+ }
+}
+
+Val* LookupDeclRef::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+ Val* resolved = this;
+
+ auto newLookupSource = as<Type>(getLookupSource()->resolve());
+ SLANG_ASSERT(newLookupSource);
+
+ auto newWitness = as<SubtypeWitness>(getWitness()->resolve());
+ SLANG_ASSERT(newWitness);
+
+ if (auto resolvedVal = tryResolve(newWitness, newLookupSource))
+ return resolvedVal;
+ if (newLookupSource != getLookupSource() || newWitness != getWitness())
+ resolved = astBuilder->getLookupDeclRef(newLookupSource, newWitness, getDecl());
+ return resolved;
+}
+
+DeclRefBase* LookupDeclRef::_getBaseOverride()
+{
+ return nullptr;
+}
+
+Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource)
+{
+ auto astBuilder = getCurrentASTBuilder();
+ Decl* requirementKey = getDecl();
+ RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, newWitness, requirementKey);
+ switch (requirementWitness.getFlavor())
+ {
+ default:
+ // No usable value was found, so there is nothing we can do.
+ break;
+
+ case RequirementWitness::Flavor::val:
+ {
+ auto satisfyingVal = requirementWitness.getVal();
+ return satisfyingVal;
+ }
+ break;
+ }
+
+ // Hard code implementation of T.Differential.Differential == T.Differential rule.
+ auto builtinReq = requirementKey->findModifier<BuiltinRequirementModifier>();
+ bool isConstraint = false;
+ if (!builtinReq)
+ {
+ if (auto parentAssocType = as<AssocTypeDecl>(requirementKey->parentDecl))
+ {
+ builtinReq = parentAssocType->findModifier<BuiltinRequirementModifier>();
+ isConstraint = true;
+ }
+ if (!builtinReq)
+ return nullptr;
+ }
+ if (builtinReq->kind != BuiltinRequirementKind::DifferentialType)
+ return nullptr;
+ // Is the concrete type a Differential associated type?
+ auto innerDeclRefType = as<DeclRefType>(newLookupSource);
+ if (!innerDeclRefType)
+ return nullptr;
+ auto innerBuiltinReq = innerDeclRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>();
+ if (!innerBuiltinReq)
+ return nullptr;
+ if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType)
+ return nullptr;
+ if (isConstraint)
+ return newWitness;
+ if (innerDeclRefType->getDeclRef() != this)
+ {
+ auto result = innerDeclRefType->getDeclRef().declRefBase->resolve();
+ if (result)
+ return result;
+ }
+ return innerDeclRefType;
+}
+
+DeclRefBase* GenericAppDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto substGenericDeclRef = getGenericDeclRef()->substituteImpl(astBuilder, subst, &diff);
+ List<Val*> substArgs;
+ for (auto arg : getArgs())
+ {
+ substArgs.add(arg->substituteImpl(astBuilder, subst, &diff));
+ }
+ if (diff == 0)
+ return this;
+ (*ioDiff)++;
+ return astBuilder->getGenericAppDeclRef(substGenericDeclRef, substArgs.getArrayView(), getDecl());
+}
+
+GenericDecl* GenericAppDeclRef::getGenericDecl() { return as<GenericDecl>(getGenericDeclRef()->getDecl()); }
+
+
+void GenericAppDeclRef::_toTextOverride(StringBuilder& out)
+{
+ auto genericDecl = as<GenericDecl>(getGenericDeclRef()->getDecl());
+ Index paramCount = 0;
+ for (auto member : genericDecl->members)
+ if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member))
+ paramCount++;
+ getGenericDeclRef()->toText(out);
+ out << "<";
+ auto args = getArgs();
+ Index argCount = args.getCount();
+ for (Index aa = 0; aa < Math::Min(paramCount, argCount); ++aa)
+ {
+ if (aa != 0) out << ", ";
+ args[aa]->toText(out);
+ }
+ out << ">";
+}
+
+Val* GenericAppDeclRef::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+ Val* resolvedVal = this;
+ auto resolvedGenericDeclRef = _resolveAsDeclRef(getGenericDeclRef());
+ bool diff = false;
+ if (resolvedGenericDeclRef != getGenericDeclRef())
+ diff = true;
+ List<Val*> resolvedArgs;
+ for (auto arg : getArgs())
+ {
+ auto resolvedArg = arg->resolve();
+ resolvedArgs.add(resolvedArg);
+ if (resolvedArg != arg)
+ diff = true;
+ }
+ if (diff)
+ resolvedVal = astBuilder->getGenericAppDeclRef(resolvedGenericDeclRef, resolvedArgs.getArrayView(), getDecl());
+ return resolvedVal;
+}
+
+DeclRefBase* GenericAppDeclRef::_getBaseOverride()
+{
+ return getGenericDeclRef();
+}
+
+// Convenience accessors for common properties of declarations
+
+DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, substituteImpl, (astBuilder, subst, ioDiff));
+}
+
+DeclRefBase* DeclRefBase::getBase() { SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, getBase, ()); }
+void DeclRefBase::toText(StringBuilder& out) { SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, toText, (out)); }
+
+Name* DeclRefBase::getName() const
+{
+ return getDecl()->nameAndLoc.name;
+}
+
+SourceLoc DeclRefBase::getNameLoc() const
+{
+ return getDecl()->nameAndLoc.loc;
+}
+
+SourceLoc DeclRefBase::getLoc() const
+{
+ return getDecl()->loc;
+}
+
+DeclRefBase* DeclRefBase::getParent()
+{
+ auto astBuilder = getCurrentASTBuilder();
+ if (!getDecl()->parentDecl)
+ return nullptr;
+ auto parentDecl = getDecl()->parentDecl;
+ for (auto base = getBase(); base; base = base->getBase())
+ {
+ if (base->getDecl() == parentDecl)
+ return base;
+ bool parentIsChildOfBase = false;
+ for (auto dd = parentDecl->parentDecl; dd; dd = dd->parentDecl)
+ {
+ if (dd == base->getDecl())
+ {
+ parentIsChildOfBase = true;
+ break;
+ }
+ }
+ if (parentIsChildOfBase)
+ return astBuilder->getMemberDeclRef(base, parentDecl);
+ }
+ return astBuilder->getDirectDeclRef(parentDecl);
+}
+
+SubstitutionSet::operator bool() const
+{
+ return declRef != nullptr && !as<DirectDeclRef>(declRef);
+}
+
+Val::OperandView<Val> tryGetGenericArguments(SubstitutionSet substSet, Decl* genericDecl)
+{
+ if (!substSet.declRef)
+ return Val::OperandView<Val>();
+
+ DeclRefBase* currentDeclRef = substSet.declRef;
+ // search for a substitution that might apply to us
+ for (auto s = currentDeclRef; s; s = s->getBase())
+ {
+ auto genericAppDeclRef = as<GenericAppDeclRef>(s);
+ if (!genericAppDeclRef)
+ continue;
+
+ // the generic decl associated with the substitution list must be
+ // the generic decl that declared this parameter
+ auto parentGeneric = genericAppDeclRef->getGenericDecl();
+ if (parentGeneric != genericDecl)
+ continue;
+
+ return genericAppDeclRef->getArgs();
+ }
+ return Val::OperandView<Val>();
+}
+
+Type* SubstitutionSet::applyToType(ASTBuilder* astBuilder, Type* type) const
+{
+ if (!type)
+ return nullptr;
+ int diff = 0;
+ auto newType = as<Type>(type->substituteImpl(astBuilder, *this, &diff));
+ if (diff && newType)
+ return newType;
+ return type;
+}
+
+SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr)
+{
+ return SubstExpr<Expr>(expr, substSet);
+}
+
+
+DeclRefBase* SubstitutionSet::applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* otherDeclRef) const
+{
+ int diff = 0;
+ return otherDeclRef->substituteImpl(astBuilder, *this, &diff);
+}
+
+LookupDeclRef* SubstitutionSet::findLookupDeclRef() const
+{
+ for (auto s = declRef; s; s = s->getBase())
+ {
+ if (auto lookupDeclRef = as<LookupDeclRef>(s))
+ return lookupDeclRef;
+ }
+ return nullptr;
+}
+
+DeclRefBase* SubstitutionSet::getInnerMostNodeWithSubstInfo() const
+{
+ for (auto s = declRef; s; s = s->getBase())
+ {
+ if (as<LookupDeclRef>(s) || as<GenericAppDeclRef>(s))
+ return s;
+ }
+ return nullptr;
+}
+
+
+GenericAppDeclRef* SubstitutionSet::findGenericAppDeclRef(GenericDecl* genericDecl) const
+{
+ for (auto s = declRef; s; s = s->getBase())
+ {
+ if (auto genApp = as<GenericAppDeclRef>(s))
+ {
+ if (genApp->getGenericDecl() == genericDecl)
+ return genApp;
+ }
+ }
+ return nullptr;
+}
+
+GenericAppDeclRef* SubstitutionSet::findGenericAppDeclRef() const
+{
+ for (auto s = declRef; s; s = s->getBase())
+ {
+ if (auto genApp = as<GenericAppDeclRef>(s))
+ {
+ return genApp;
+ }
+ }
+ return nullptr;
+}
+
+DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
+ DeclRef<Decl> declRef)
+{
+ if (declRef.as<GenericTypeParamDecl>())
+ return declRef;
+ if (declRef.as<GenericValueParamDecl>())
+ return declRef;
+ if (declRef.as<GenericTypeConstraintDecl>())
+ return declRef;
+ ShortList<GenericDecl*> genericParentDecls;
+ auto lastSubstNode = SubstitutionSet(declRef).getInnerMostNodeWithSubstInfo();
+ auto lastGenApp = as<GenericAppDeclRef>(lastSubstNode);
+ for (auto dd = declRef.getDecl()->parentDecl; dd; dd = dd->parentDecl)
+ {
+ if (lastGenApp && dd == lastGenApp->getGenericDecl())
+ break;
+ if (auto gen = as<GenericDecl>(dd))
+ genericParentDecls.add(gen);
+ }
+ DeclRef<Decl> parentDeclRef = lastSubstNode;
+ for (auto i = genericParentDecls.getCount() - 1; i >= 0; i--)
+ {
+ auto current = genericParentDecls[i];
+ auto args = getDefaultSubstitutionArgs(astBuilder, semantics, current);
+ if (parentDeclRef)
+ {
+ parentDeclRef = astBuilder->getMemberDeclRef(parentDeclRef, current);
+ }
+ else
+ {
+ parentDeclRef = astBuilder->getDirectDeclRef(current);
+ }
+ parentDeclRef = astBuilder->getGenericAppDeclRef(parentDeclRef.as<GenericDecl>(), args.getArrayView());
+ }
+ if (parentDeclRef.getDecl() == declRef.getDecl())
+ return parentDeclRef;
+ return astBuilder->getMemberDeclRef(parentDeclRef, declRef.getDecl());
+}
+}
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp
index 2f1c7c47e..9dbd006a0 100644
--- a/source/slang/slang-ast-decl.cpp
+++ b/source/slang/slang-ast-decl.cpp
@@ -4,6 +4,7 @@
#include <assert.h>
#include "slang-generated-ast-macro.h"
+#include "slang-ast-decl.h"
namespace Slang {
@@ -118,4 +119,21 @@ bool isLocalVar(const Decl* decl)
return false;
}
+ThisTypeDecl* InterfaceDecl::getThisTypeDecl()
+{
+ for (auto member : members)
+ {
+ if (auto thisTypeDeclCandidate = as<ThisTypeDecl>(member))
+ {
+ return thisTypeDeclCandidate;
+ }
+ }
+ SLANG_UNREACHABLE("InterfaceDecl does not have a ThisType decl.");
+}
+
+InterfaceDecl* ThisTypeConstraintDecl::getInterfaceDecl()
+{
+ return as<InterfaceDecl>(parentDecl->parentDecl);
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 61e623366..8266d77c7 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -56,6 +56,15 @@ class ContainerDecl: public Decl
return transparentMembers;
}
+ void addMember(Decl* member)
+ {
+ if (member)
+ {
+ member->parentDecl = this;
+ members.add(member);
+ }
+ }
+
SLANG_UNREFLECTED // We don't want to reflect the following fields
private:
@@ -178,12 +187,19 @@ class EnumCaseDecl : public Decl
Expr* tagExpr = nullptr;
};
+// A member of InterfaceDecl representing the abstract ThisType.
+class ThisTypeDecl : public AggTypeDecl
+{
+ SLANG_AST_CLASS(ThisTypeDecl)
+};
+
// An interface which other types can conform to
class InterfaceDecl : public AggTypeDecl
{
SLANG_AST_CLASS(InterfaceDecl)
-};
+ ThisTypeDecl* getThisTypeDecl();
+};
class TypeConstraintDecl : public Decl
{
@@ -195,6 +211,15 @@ class TypeConstraintDecl : public Decl
const TypeExp& _getSupOverride() const;
};
+class ThisTypeConstraintDecl : public TypeConstraintDecl
+{
+ SLANG_AST_CLASS(ThisTypeConstraintDecl)
+
+ TypeExp base;
+ const TypeExp& _getSupOverride() const { return base; }
+ InterfaceDecl* getInterfaceDecl();
+};
+
// A kind of pseudo-member that represents an explicit
// or implicit inheritance relationship.
//
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index 0ab440a18..65718833b 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -65,18 +65,6 @@ struct ASTDumpContext
}
}
- void dump(Substitutions* subs)
- {
- if (subs == nullptr)
- {
- _dumpPtr(nullptr);
- }
- else
- {
- dumpObject(subs->getClassInfo(), subs);
- }
- }
-
void dump(const Name* name)
{
if (name == nullptr)
@@ -608,6 +596,40 @@ struct ASTDumpContext
m_writer->emit("\n");
}
+ template<int N>
+ void dump(const ShortList<ValNodeOperand, N>& operands)
+ {
+ m_writer->emit("(");
+ bool isFirst = true;
+ for (auto operand : operands)
+ {
+ if (!isFirst)
+ {
+ m_writer->emit(", ");
+ }
+ isFirst = false;
+ dumpField("operand", operand);
+ }
+
+ m_writer->emit(")");
+ }
+
+ void dump(ValNodeOperand operand)
+ {
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ dump(operand.values.intOperand);
+ break;
+ case ValNodeOperandKind::ValNode:
+ dump(operand.values.nodeOperand);
+ break;
+ case ValNodeOperandKind::ASTNode:
+ dump(operand.values.nodeOperand);
+ break;
+ }
+ }
+
void dump(ASTNodeType nodeType)
{
// Get the class
@@ -616,6 +638,15 @@ struct ASTDumpContext
m_writer->emit(info->m_name);
}
+ void dump(KeyValuePair<DeclRefBase*, SubtypeWitness*> pair)
+ {
+ m_writer->emit("(");
+ dump(pair.key);
+ m_writer->emit(", ");
+ dump(pair.value);
+ m_writer->emit(")");
+ }
+
void dumpObjectFull(NodeBase* node);
ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle):
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 07bf2f033..28ce2e4d1 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -560,17 +560,6 @@ class TreatAsDifferentiableExpr : public Expr
Flavor flavor;
};
- /// A type expression of the form `__TaggedUnion(A, ...)`.
- ///
- /// An expression of this form will resolve to a `TaggedUnionType`
- /// when checked.
- ///
-class TaggedUnionTypeExpr: public Expr
-{
- SLANG_AST_CLASS(TaggedUnionTypeExpr)
- List<TypeExp> caseTypes;
-};
-
/// A type expression of the form `This`
///
/// Refers to the type of `this` in the current context.
@@ -639,7 +628,7 @@ public:
DeclRef<GenericDecl> baseGenericDeclRef;
/// A substitution that includes the generic arguments known so far
- GenericSubstitution* substWithKnownGenericArgs = nullptr;
+ List<Val*> knownGenericArgs;
};
} // namespace Slang
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index ea3db6937..fb3d50b4f 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -99,11 +99,6 @@ struct ASTIterator
dispatchIfNotNull(expr->base.exp);
}
- void visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr)
- {
- iterator->maybeDispatchCallback(expr);
- }
-
void visitInvokeExpr(InvokeExpr* expr)
{
iterator->maybeDispatchCallback(expr);
diff --git a/source/slang/slang-ast-modifier.cpp b/source/slang/slang-ast-modifier.cpp
index 3daa9b056..84046a601 100644
--- a/source/slang/slang-ast-modifier.cpp
+++ b/source/slang/slang-ast-modifier.cpp
@@ -4,5 +4,10 @@
namespace Slang
{
-
+const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& DifferentiableAttribute::getMapTypeToIDifferentiableWitness()
+{
+ for (Index i = m_mapToIDifferentiableWitness.getCount(); i < m_typeToIDifferentiableWitnessMappings.getCount(); i++)
+ m_mapToIDifferentiableWitness.add(m_typeToIDifferentiableWitnessMappings[i].key, m_typeToIDifferentiableWitnessMappings[i].value);
+ return m_mapToIDifferentiableWitness;
+}
} // namespace Slang
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index fd317a2c2..8e7cc9193 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -600,7 +600,7 @@ class Attribute : public AttributeBase
{
SLANG_AST_CLASS(Attribute)
- AttributeArgumentValueDict intArgVals;
+ List<Val*> intArgVals;
};
class UserDefinedAttribute : public Attribute
@@ -1054,10 +1054,23 @@ class DifferentiableAttribute : public Attribute
{
SLANG_AST_CLASS(DifferentiableAttribute)
+ List<KeyValuePair<DeclRefBase*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings;
+
+ void addType(DeclRefBase* declRef, SubtypeWitness* witness)
+ {
+ getMapTypeToIDifferentiableWitness();
+ if (m_mapToIDifferentiableWitness.addIfNotExists(declRef, witness))
+ {
+ m_typeToIDifferentiableWitnessMappings.add(KeyValuePair<DeclRefBase*, SubtypeWitness*>(declRef, witness));
+ }
+ }
+
/// Mapping from types to subtype witnesses for conformance to IDifferentiable.
- OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapTypeToIDifferentiableWitness;
+ const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& getMapTypeToIDifferentiableWitness();
SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet;
+private:
+ OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapToIDifferentiableWitness;
};
class DllImportAttribute : public Attribute
diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp
index 1789c5cea..4a4ef37fb 100644
--- a/source/slang/slang-ast-natural-layout.cpp
+++ b/source/slang/slang-ast-natural-layout.cpp
@@ -70,9 +70,9 @@ Count ASTNaturalLayoutContext::_getCount(IntVal* intVal)
{
if (auto constIntVal = as<ConstantIntVal>(intVal))
{
- if (constIntVal->value >= 0)
+ if (constIntVal->getValue() >= 0)
{
- return Count(constIntVal->value);
+ return Count(constIntVal->getValue());
}
}
@@ -115,9 +115,9 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
{
if (VectorExpressionType* vecType = as<VectorExpressionType>(type))
{
- const Count elementCount = _getCount(vecType->elementCount);
+ const Count elementCount = _getCount(vecType->getElementCount());
return (elementCount > 0) ?
- calcSize(vecType->elementType) * elementCount :
+ calcSize(vecType->getElementType()) * elementCount :
NaturalSize::makeInvalid();
}
else if (auto matType = as<MatrixExpressionType>(type))
@@ -130,7 +130,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
}
else if (auto basicType = as<BasicExpressionType>(type))
{
- return NaturalSize::makeFromBaseType(basicType->baseType);
+ return NaturalSize::makeFromBaseType(basicType->getBaseType());
}
else if (as<PtrTypeBase>(type) || as<NullPtrType>(type))
{
@@ -146,7 +146,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
}
else if (auto namedType = as<NamedExpressionType>(type))
{
- return calcSize(namedType->innerType);
+ return calcSize(namedType->getCanonicalType());
}
else if (const auto tupleType = as<TupleType>(type))
{
@@ -154,9 +154,9 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
NaturalSize size = NaturalSize::makeEmpty();
// Accumulate over all the member types
- for (auto cur : tupleType->memberTypes)
+ for (auto cur = 0; cur < tupleType->getMemberCount(); cur++)
{
- const auto curSize = calcSize(cur);
+ const auto curSize = calcSize(tupleType->getMember(cur));
if (!curSize)
{
return NaturalSize::makeInvalid();
@@ -166,36 +166,14 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
return size;
}
- else if (const auto taggedUnion = as<TaggedUnionType>(type))
- {
- NaturalSize size = NaturalSize::makeInvalid();
-
- for( auto caseType : taggedUnion->caseTypes )
- {
- const NaturalSize caseSize = calcSize(caseType);
- if (!caseSize)
- {
- return NaturalSize::makeInvalid();
- }
- size = NaturalSize::calcUnion(size, caseSize);
- }
-
- // After we've computed the size required to hold all the
- // case types, we will allocate space for the tag field.
-
- // Currently we assume uint32_t on all targets
- size.append(NaturalSize::makeFromBaseType(BaseType::UInt));
-
- return size;
- }
else if( auto declRefType = as<DeclRefType>(type) )
{
- if (const auto enumDeclRef = declRefType->declRef.as<EnumDecl>())
+ if (const auto enumDeclRef = declRefType->getDeclRef().as<EnumDecl>())
{
Type* tagType = getTagType(m_astBuilder, enumDeclRef);
return calcSize(tagType);
}
- else if(const auto structDeclRef = declRefType->declRef.as<StructDecl>())
+ else if(const auto structDeclRef = declRefType->getDeclRef().as<StructDecl>())
{
// Poison the cache whilst we construct
m_typeToSize.add(type, NaturalSize::makeInvalid());
@@ -208,7 +186,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
// Look for a struct type that it inherits from
if (auto inheritedDeclRef = as<DeclRefType>(inherited->base.type))
{
- if (auto parentDecl = inheritedDeclRef->declRef.as<StructDecl>())
+ if (auto parentDecl = inheritedDeclRef->getDeclRef().as<StructDecl>())
{
// We can only inherit from one thing
size = calcSize(inherited->base.type);
@@ -237,7 +215,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
return size;
}
- else if (const auto typeDef = declRefType->declRef.as<TypeDefDecl>())
+ else if (const auto typeDef = declRefType->getDeclRef().as<TypeDefDecl>())
{
return calcSize(typeDef.getDecl()->type);
}
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp
index 84c521108..b80afeee1 100644
--- a/source/slang/slang-ast-print.cpp
+++ b/source/slang/slang-ast-print.cpp
@@ -36,12 +36,12 @@ void ASTPrinter::addType(Type* type)
{
if (auto vectorType = as<VectorExpressionType>(type))
{
- if (as<BasicExpressionType>(vectorType->elementType))
+ if (as<BasicExpressionType>(vectorType->getElementType()))
{
- vectorType->elementType->toText(m_builder);
- if (as<ConstantIntVal>(vectorType->elementCount))
+ vectorType->getElementType()->toText(m_builder);
+ if (as<ConstantIntVal>(vectorType->getElementCount()))
{
- m_builder << vectorType->elementCount;
+ m_builder << vectorType->getElementCount();
return;
}
}
@@ -107,14 +107,14 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth)
auto& sb = m_builder;
// Find the parent declaration
- auto parentDeclRef = declRef.getParent(m_astBuilder);
+ auto parentDeclRef = declRef.getParent();
// If the immediate parent is a generic, then we probably
// want the declaration above that...
auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>();
if (parentGenericDeclRef)
{
- parentDeclRef = parentGenericDeclRef.getParent(m_astBuilder);
+ parentDeclRef = parentGenericDeclRef.getParent();
}
// Depending on what the parent is, we may want to format things specially
@@ -172,12 +172,9 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth)
!declRef.as<GenericValueParamDecl>() &&
!declRef.as<GenericTypeParamDecl>())
{
- auto genSubst = as<GenericSubstitution>(declRef.getSubst());
- if (genSubst)
+ auto substArgs = tryGetGenericArguments(SubstitutionSet(declRef), parentGenericDeclRef.getDecl());
+ if (substArgs.getCount())
{
- SLANG_RELEASE_ASSERT(genSubst);
- SLANG_RELEASE_ASSERT(genSubst->getGenericDecl() == parentGenericDeclRef.getDecl());
-
// If the name we printed previously was an operator
// that ends with `<`, then immediately printing the
// generic arguments inside `<...>` may cause it to
@@ -193,7 +190,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth)
sb << "<";
bool first = true;
- for (auto arg : genSubst->getArgs())
+ for (auto arg : substArgs)
{
// When printing the representation of a specialized
// generic declaration we don't want to include the
@@ -331,7 +328,7 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>*
{
addGenericParams(genericDeclRef);
- addDeclParams(m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.getSubst()), outParamRange);
+ addDeclParams(m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner), outParamRange);
}
else
{
@@ -443,7 +440,7 @@ void ASTPrinter::addDeclResultType(const DeclRef<Decl>& inDeclRef)
DeclRef<Decl> declRef = inDeclRef;
if (auto genericDeclRef = declRef.as<GenericDecl>())
{
- declRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(genericDeclRef), genericDeclRef.getSubst());
+ declRef = m_astBuilder->getMemberDeclRef<Decl>(genericDeclRef, genericDeclRef.getDecl()->inner);
}
if (declRef.as<ConstructorDecl>())
diff --git a/source/slang/slang-ast-reflect.cpp b/source/slang/slang-ast-reflect.cpp
index b16568d2e..66e57a744 100644
--- a/source/slang/slang-ast-reflect.cpp
+++ b/source/slang/slang-ast-reflect.cpp
@@ -39,7 +39,7 @@ struct ASTConstructAccess
static void* create(void* context)
{
ASTBuilder* astBuilder = (ASTBuilder*)context;
- return astBuilder->create<T>();
+ return astBuilder->createImpl<T>();
}
static void destroy(void* ptr)
{
diff --git a/source/slang/slang-ast-substitutions.cpp b/source/slang/slang-ast-substitutions.cpp
deleted file mode 100644
index 7b052522e..000000000
--- a/source/slang/slang-ast-substitutions.cpp
+++ /dev/null
@@ -1,163 +0,0 @@
-// slang-ast-substitutions.cpp
-#include "slang-ast-builder.h"
-#include <assert.h>
-
-#include "slang-generated-ast-macro.h"
-
-namespace Slang {
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Substitutions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-Substitutions* Substitutions::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff)
-{
- SLANG_AST_NODE_VIRTUAL_CALL(Substitutions, applySubstitutionsShallow, (astBuilder, substSet, substOuter, ioDiff))
-}
-
-bool Substitutions::equals(Substitutions* subst)
-{
- SLANG_AST_NODE_VIRTUAL_CALL(Substitutions, equals, (subst))
-}
-
-HashCode Substitutions::getHashCode() const
-{
- SLANG_AST_NODE_CONST_VIRTUAL_CALL(Substitutions, getHashCode, ())
-}
-
-Substitutions* Substitutions::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff)
-{
- SLANG_UNUSED(astBuilder);
- SLANG_UNUSED(substSet);
- SLANG_UNUSED(substOuter);
- SLANG_UNUSED(ioDiff);
- SLANG_UNEXPECTED("Substitutions::_applySubstitutionsShallowOverride not overridden");
- //return Substitutions*();
-}
-
-bool Substitutions::_equalsOverride(Substitutions* subst)
-{
- SLANG_UNUSED(subst);
- SLANG_UNEXPECTED("Substitutions::_equalsOverride not overridden");
- //return false;
-}
-
-HashCode Substitutions::_getHashCodeOverride() const
-{
- SLANG_UNEXPECTED("Substitutions::_getHashCodeOverride not overridden");
- //return HashCode(0);
-}
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-Substitutions* GenericSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff)
-{
- int diff = 0;
-
- if (substOuter != outer) diff++;
-
- List<Val*> substArgs;
- for (auto a : args)
- {
- substArgs.add(a->substituteImpl(astBuilder, substSet, &diff));
- }
-
- if (!diff) return this;
-
- (*ioDiff)++;
-
- auto substSubst = astBuilder->getOrCreateGenericSubstitution(substOuter, genericDecl, substArgs);
- return substSubst;
-}
-
-bool GenericSubstitution::_equalsOverride(Substitutions* subst)
-{
- // both must be NULL, or non-NULL
- if (subst == nullptr)
- return false;
- if (this == subst)
- return true;
-
- auto genericSubst = as<GenericSubstitution>(subst);
- if (!genericSubst)
- return false;
- if (genericDecl != genericSubst->genericDecl)
- return false;
-
- Index argCount = args.getCount();
- SLANG_RELEASE_ASSERT(args.getCount() == genericSubst->args.getCount());
- for (Index aa = 0; aa < argCount; ++aa)
- {
- if (!args[aa]->equalsVal(genericSubst->args[aa]))
- return false;
- }
-
- if (!outer)
- return !genericSubst->outer;
-
- if (!outer->equals(genericSubst->outer))
- return false;
-
- return true;
-}
-
-HashCode GenericSubstitution::_getHashCodeOverride() const
-{
- HashCode rs = 0;
- for (auto && v : args)
- {
- rs ^= v->getHashCode();
- rs *= 16777619;
- }
- return rs;
-}
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisTypeSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-Substitutions* ThisTypeSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff)
-{
- int diff = 0;
-
- if (substOuter != outer) diff++;
-
- // NOTE: Must use .as because we must have a smart pointer here to keep in scope.
- auto substWitness = as<SubtypeWitness>(witness->substituteImpl(astBuilder, substSet, &diff));
-
- if (!diff) return this;
-
- (*ioDiff)++;
- ThisTypeSubstitution* substSubst;
-
- substSubst = astBuilder->getOrCreateThisTypeSubstitution(interfaceDecl, substWitness, substOuter);
- return substSubst;
-}
-
-bool ThisTypeSubstitution::_equalsOverride(Substitutions* subst)
-{
- if (!subst)
- return false;
- if (subst == this)
- return true;
-
- if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst))
- {
- // For our purposes, two this-type substitutions are
- // equivalent if they have the same type as `This`,
- // even if the specific witness values they use
- // might differ.
- //
- if (this->interfaceDecl != thisTypeSubst->interfaceDecl)
- return false;
-
- if (!this->witness->sub->equals(thisTypeSubst->witness->sub))
- return false;
-
- return true;
- }
- return false;
-}
-
-HashCode ThisTypeSubstitution::_getHashCodeOverride() const
-{
- return witness->sub->getHashCode();
-}
-
-} // namespace Slang
diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp
index a3df25ce9..6a957e427 100644
--- a/source/slang/slang-ast-support-types.cpp
+++ b/source/slang/slang-ast-support-types.cpp
@@ -68,4 +68,5 @@ UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr)
return UnownedStringSlice("bwd_diff");
return UnownedStringSlice();
}
+
}
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 4765d11ec..9140be967 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -28,7 +28,6 @@ namespace Slang
class Module;
class Name;
class Session;
- class Substitutions;
class SyntaxVisitor;
class FuncDecl;
class Layout;
@@ -50,6 +49,8 @@ namespace Slang
class Val;
class NodeBase;
+ class LookupDeclRef;
+ class GenericAppDeclRef;
template <typename T>
@@ -625,19 +626,27 @@ namespace Slang
struct SubstitutionSet
{
- Substitutions* substitutions = nullptr;
- operator Substitutions*() const
- {
- return substitutions;
- }
+ DeclRefBase* declRef = nullptr;
+ SubstitutionSet() = default;
+ SubstitutionSet(DeclRefBase* declRefBase)
+ :declRef(declRefBase)
+ {}
+ explicit operator bool() const;
+
+ template<typename F>
+ void forEachGenericSubstitution(F func) const;
+
+ template<typename F>
+ void forEachSubstitutionArg(F func) const;
+
+ Type* applyToType(ASTBuilder* astBuilder, Type* type) const;
+ DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const;
+
+ LookupDeclRef* findLookupDeclRef() const;
+ GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const;
+ GenericAppDeclRef* findGenericAppDeclRef() const;
+ DeclRefBase* getInnerMostNodeWithSubstInfo() const;
- SubstitutionSet() {}
- SubstitutionSet(Substitutions* subst)
- : substitutions(subst)
- {
- }
- bool equals(const SubstitutionSet& substSet) const;
- HashCode getHashCode() const;
};
/// An expression together with (optional) substutions to apply to it
@@ -741,6 +750,8 @@ namespace Slang
}
};
+ SubstExpr<Expr> applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr);
+
class ASTBuilder;
template<typename T>
@@ -752,7 +763,6 @@ namespace Slang
// try to find the concrete decl that satisfies the associatedtype requirement from the
// concrete type supplied by ThisTypeSubstittution.
Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef);
- void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out);
template<typename T = Decl>
struct DeclRef
@@ -780,13 +790,12 @@ namespace Slang
{}
T* getDecl() const;
- Substitutions* getSubst() const;
Name* getName() const;
SourceLoc getNameLoc() const;
SourceLoc getLoc() const;
- DeclRef<ContainerDecl> getParent(ASTBuilder* astBuilder) const;
+ DeclRef<ContainerDecl> getParent() const;
HashCode getHashCode() const;
Type* substitute(ASTBuilder* astBuilder, Type* type) const;
@@ -823,7 +832,10 @@ namespace Slang
}
template<typename U>
- bool equals(DeclRef<U> other) const;
+ bool equals(DeclRef<U> other) const
+ {
+ return declRefBase == other.declRefBase;
+ }
template<typename U>
bool operator == (DeclRef<U> other) const
@@ -979,17 +991,17 @@ namespace Slang
struct FilteredMemberRefList
{
List<Decl*> const& m_decls;
- SubstitutionSet m_substitutions;
+ DeclRef<Decl> m_parent;
MemberFilterStyle m_filterStyle;
ASTBuilder* m_astBuilder;
FilteredMemberRefList(
ASTBuilder* astBuilder,
List<Decl*> const& decls,
- SubstitutionSet substitutions,
+ DeclRef<Decl> parent,
MemberFilterStyle filterStyle = MemberFilterStyle::All)
: m_decls(decls)
- , m_substitutions(substitutions)
+ , m_parent(parent)
, m_filterStyle(filterStyle)
, m_astBuilder(astBuilder)
{}
@@ -1007,7 +1019,7 @@ namespace Slang
{
Decl*const* decl = getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index);
SLANG_ASSERT(decl);
- return _getSpecializedDeclRef(m_astBuilder, (T*)*decl, m_substitutions).template as<T>();
+ return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as<T>();
}
List<DeclRef<T>> toArray() const
@@ -1042,7 +1054,7 @@ namespace Slang
void operator++() { m_ptr = adjustFilterCursor<T>(m_filterStyle, m_ptr + 1, m_end); }
- DeclRef<T> operator*() { return _getSpecializedDeclRef(m_list->m_astBuilder, (T*)*m_ptr, m_list->m_substitutions).template as<T>(); }
+ DeclRef<T> operator*() { return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr).template as<T>(); }
};
Iterator begin() const { return Iterator(this, adjustFilterCursor<T>(m_filterStyle, m_decls.begin(), m_decls.end()), m_decls.end(), m_filterStyle); }
@@ -1431,7 +1443,18 @@ namespace Slang
{
SLANG_OBJ_CLASS(WitnessTable)
- RequirementDictionary requirementDictionary;
+ const RequirementDictionary& getRequirementDictionary()
+ {
+ if (m_requirementDictionary.getCount() != m_requirements.getCount())
+ {
+ for (Index i = m_requirementDictionary.getCount(); i < m_requirements.getCount(); i++)
+ {
+ auto& r = m_requirements[i];
+ m_requirementDictionary.add(r.key, r.value);
+ }
+ }
+ return m_requirementDictionary;
+ }
void add(Decl* decl, RequirementWitness const& witness);
@@ -1440,9 +1463,13 @@ namespace Slang
// The type witnessesd by the witness table (a concrete type).
Type* witnessedType;
- };
- typedef Dictionary<unsigned int, NodeBase*> AttributeArgumentValueDict;
+ // Satisfying values of each requirement.
+ List<KeyValuePair<Decl*, RequirementWitness>> m_requirements;
+
+ // Cached dictionary for looking up satisfying values.
+ SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary;
+ };
struct SpecializationParam
{
@@ -1551,6 +1578,7 @@ namespace Slang
/// Get the operator name from the higher order invoke expr.
UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr);
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp
index 65955e815..cb7d338c8 100644
--- a/source/slang/slang-ast-synthesis.cpp
+++ b/source/slang/slang-ast-synthesis.cpp
@@ -134,8 +134,7 @@ Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name)
{
auto rs = m_builder->create<StaticMemberExpr>();
auto typeExpr = m_builder->create<SharedTypeExpr>();
- auto typetype = m_builder->create<TypeType>();
- typetype->type = type;
+ auto typetype = m_builder->getOrCreate<TypeType>(type);
typeExpr->type = typetype;
rs->baseExpression = typeExpr;
rs->name = name;
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index ee5d1d40e..13133a7f8 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -1,49 +1,19 @@
// slang-ast-type.cpp
#include "slang-ast-builder.h"
+#include "slang-ast-modifier.h"
#include <assert.h>
#include <typeinfo>
#include "slang-syntax.h"
#include "slang-generated-ast-macro.h"
-
namespace Slang {
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-Type* Type::createCanonicalType()
-{
- SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ())
-}
-
-bool Type::equals(Type* type)
-{
- return getCanonicalType()->equalsImpl(type->getCanonicalType());
-}
-
-bool Type::equalsImpl(Type* type)
-{
- SLANG_AST_NODE_VIRTUAL_CALL(Type, equalsImpl, (type))
-}
-
-bool Type::_equalsImplOverride(Type* type)
-{
- SLANG_UNUSED(type)
- SLANG_UNEXPECTED("Type::_equalsImplOverride not overridden");
- //return false;
-}
-
Type* Type::_createCanonicalTypeOverride()
{
- SLANG_UNEXPECTED("Type::_createCanonicalTypeOverride not overridden");
- //return Type*();
-}
-
-bool Type::_equalsValOverride(Val* val)
-{
- if (auto type = dynamicCast<Type>(val))
- return const_cast<Type*>(this)->equals(type);
- return false;
+ return as<Type>(defaultResolveImpl());
}
Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
@@ -61,20 +31,6 @@ Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst
return canSubst;
}
-Type* Type::getCanonicalType()
-{
- Type* et = const_cast<Type*>(this);
- if (!et->canonicalType)
- {
- // TODO(tfoley): worry about thread safety here?
- auto canType = et->createCanonicalType();
- et->canonicalType = canType;
- if (!et->canonicalType)
- return getASTBuilder()->getErrorType();
- }
- return et->canonicalType;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! OverloadGroupType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void OverloadGroupType::_toTextOverride(StringBuilder& out)
@@ -82,21 +38,11 @@ void OverloadGroupType::_toTextOverride(StringBuilder& out)
out << toSlice("overload group");
}
-bool OverloadGroupType::_equalsImplOverride(Type * /*type*/)
-{
- return false;
-}
-
Type* OverloadGroupType::_createCanonicalTypeOverride()
{
return this;
}
-HashCode OverloadGroupType::_getHashCodeOverride()
-{
- return (HashCode)(size_t(this));
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! InitializerListType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void InitializerListType::_toTextOverride(StringBuilder& out)
@@ -104,21 +50,11 @@ void InitializerListType::_toTextOverride(StringBuilder& out)
out << toSlice("initializer list");
}
-bool InitializerListType::_equalsImplOverride(Type * /*type*/)
-{
- return false;
-}
-
Type* InitializerListType::_createCanonicalTypeOverride()
{
return this;
}
-HashCode InitializerListType::_getHashCodeOverride()
-{
- return (HashCode)(size_t(this));
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ErrorType::_toTextOverride(StringBuilder& out)
@@ -126,11 +62,6 @@ void ErrorType::_toTextOverride(StringBuilder& out)
out << toSlice("error");
}
-bool ErrorType::_equalsImplOverride(Type* type)
-{
- return as<ErrorType>(type);
-}
-
Type* ErrorType::_createCanonicalTypeOverride()
{
return this;
@@ -141,56 +72,21 @@ Val* ErrorType::_substituteImplOverride(ASTBuilder* /* astBuilder */, Substituti
return this;
}
-HashCode ErrorType::_getHashCodeOverride()
-{
- return HashCode(size_t(this));
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void BottomType::_toTextOverride(StringBuilder& out) { out << toSlice("never"); }
-bool BottomType::_equalsImplOverride(Type* type)
-{
- return as<BottomType>(type);
-}
-
-Type* BottomType::_createCanonicalTypeOverride() { return this; }
-
Val* BottomType::_substituteImplOverride(
ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/)
{
return this;
}
-HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); }
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void DeclRefType::_toTextOverride(StringBuilder& out)
{
- out << declRef;
-}
-
-HashCode DeclRefType::_getHashCodeOverride()
-{
- return (declRef.getHashCode() * 16777619) ^ (HashCode)(typeid(this).hash_code());
-}
-
-bool DeclRefType::_equalsImplOverride(Type * type)
-{
- if (auto declRefType = as<DeclRefType>(type))
- {
- return declRef.equals(declRefType->declRef);
- }
- return false;
-}
-
-Type* DeclRefType::_createCanonicalTypeOverride()
-{
- // A declaration reference is already canonical
- declRef.substitute(this->getASTBuilder(), this);
- return this;
+ out << getDeclRef();
}
Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff);
@@ -199,26 +95,47 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
{
if (!subst) return this;
- // the case we especially care about is when this type references a declaration
- // of a generic parameter, since that is what we might be substituting...
- if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl()))
+ int diff = 0;
+ DeclRef<Decl> substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+
+ // If this declref type is a direct reference to ThisType or a Generic parameter,
+ // and `subst` provides an argument for it, then we should just return that argument.
+ //
+ if (as<DirectDeclRef>(substDeclRef.declRefBase))
{
- if (auto result = maybeSubstituteGenericParam(this, genericTypeParamDecl, subst, ioDiff))
+ if (auto thisDecl = as<ThisTypeDecl>(substDeclRef.getDecl()))
+ {
+ auto lookupDeclRef = subst.findLookupDeclRef();
+ if (lookupDeclRef && lookupDeclRef->getSupDecl() == substDeclRef.getDecl()->parentDecl)
+ {
+ (*ioDiff)++;
+ return lookupDeclRef->getLookupSource();
+ }
+ }
+ else if (as<GenericTypeParamDecl>(substDeclRef.getDecl()) || as<GenericValueParamDecl>(substDeclRef.getDecl()))
{
- if (auto substDeclRefType = as<DeclRefType>(result))
+ auto resultVal = maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff);
+ if (resultVal)
{
- // After generic substitution, we may be able to further simplify
- // by looking up the actual type of an associated type.
- if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(
- astBuilder, substDeclRefType->declRef))
- return satisfyingVal;
+ (*ioDiff)++;
+ return resultVal;
}
- return result;
}
}
- int diff = 0;
- DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
+ // If this type is a reference to an associated type declaration,
+ // and the substitutions provide a "this type" substitution for
+ // the outer interface, then try to replace the type with the
+ // actual value of the associated type for the given implementation.
+ //
+ if (auto satisfyingVal = substDeclRef.declRefBase->resolve())
+ {
+ if (satisfyingVal != getDeclRef())
+ {
+ *ioDiff += 1;
+ return DeclRefType::create(astBuilder, substDeclRef);
+ }
+ }
if (!diff)
return this;
@@ -226,14 +143,6 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
// Make sure to record the difference!
*ioDiff += diff;
- // If this type is a reference to an associated type declaration,
- // and the substitutions provide a "this type" substitution for
- // the outer interface, then try to replace the type with the
- // actual value of the associated type for the given implementation.
- //
- if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef))
- return satisfyingVal;
-
// Re-construct the type in case we are using a specialized sub-class
return DeclRefType::create(astBuilder, substDeclRef);
}
@@ -254,40 +163,52 @@ BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride()
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool BasicExpressionType::_equalsImplOverride(Type * type)
+BasicExpressionType* BasicExpressionType::_getScalarTypeOverride()
{
- auto basicType = as<BasicExpressionType>(type);
- return basicType && basicType->baseType == this->baseType;
+ return this;
}
-Type* BasicExpressionType::_createCanonicalTypeOverride()
+static Val* _getGenericTypeArg(DeclRefBase* declRef, Index i)
{
- // A basic type is already canonical, in our setup
- return this;
+ auto args = findInnerMostGenericArgs(SubstitutionSet(declRef));
+ if (args.getCount() <= i)
+ return nullptr;
+
+ return args[i];
}
-BasicExpressionType* BasicExpressionType::_getScalarTypeOverride()
+static Val* _getGenericTypeArg(DeclRefType* declRefType, Index i)
{
- return this;
+ return _getGenericTypeArg(declRefType->getDeclRefBase(), i);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Type* TensorViewType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Type* VectorExpressionType::getElementType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
+IntVal* VectorExpressionType::getElementCount()
+{
+ return as<IntVal>(_getGenericTypeArg(this, 1));
+}
+
void VectorExpressionType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("vector<") << elementType << toSlice(",") << elementCount << toSlice(">");
+ out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount() << toSlice(">");
}
BasicExpressionType* VectorExpressionType::_getScalarTypeOverride()
{
- return as<BasicExpressionType>(elementType);
+ return as<BasicExpressionType>(getElementType());
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MatrixExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -304,24 +225,24 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride()
Type* MatrixExpressionType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
IntVal* MatrixExpressionType::getRowCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]);
+ return as<IntVal>(_getGenericTypeArg(this, 1));
}
IntVal* MatrixExpressionType::getColumnCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[2]);
+ return as<IntVal>(_getGenericTypeArg(this, 2));
}
Type* MatrixExpressionType::getRowType()
{
if (!rowType)
{
- rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount());
+ rowType = getCurrentASTBuilder()->getVectorType(getElementType(), getColumnCount());
}
return rowType;
}
@@ -330,12 +251,12 @@ Type* MatrixExpressionType::getRowType()
Type* ArrayExpressionType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
IntVal* ArrayExpressionType::getElementCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]);
+ return as<IntVal>(_getGenericTypeArg(this, 1));
}
void ArrayExpressionType::_toTextOverride(StringBuilder& out)
@@ -353,7 +274,7 @@ bool ArrayExpressionType::isUnsized()
{
if (auto constSize = as<ConstantIntVal>(getElementCount()))
{
- if (constSize->value == kUnsizedArrayMagicLength)
+ if (constSize->getValue() == kUnsizedArrayMagicLength)
return true;
}
return false;
@@ -363,27 +284,12 @@ bool ArrayExpressionType::isUnsized()
void TypeType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("typeof(") << type << toSlice(")");
-}
-
-bool TypeType::_equalsImplOverride(Type * t)
-{
- if (auto typeType = as<TypeType>(t))
- {
- return t->equals(typeType->type);
- }
- return false;
+ out << toSlice("typeof(") << getType() << toSlice(")");
}
Type* TypeType::_createCanonicalTypeOverride()
{
- return getASTBuilder()->getTypeType(type->getCanonicalType());
-}
-
-HashCode TypeType::_getHashCodeOverride()
-{
- SLANG_UNEXPECTED("TypeType::_getHashCodeOverride should be unreachable");
- //return HashCode(0);
+ return getCurrentASTBuilder()->getTypeType(getType()->getCanonicalType());
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericDeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -394,20 +300,6 @@ void GenericDeclRefType::_toTextOverride(StringBuilder& out)
out << toSlice("<DeclRef<GenericDecl>>");
}
-bool GenericDeclRefType::_equalsImplOverride(Type * type)
-{
- if (auto genericDeclRefType = as<GenericDeclRefType>(type))
- {
- return declRef.equals(genericDeclRefType->declRef);
- }
- return false;
-}
-
-HashCode GenericDeclRefType::_getHashCodeOverride()
-{
- return declRef.getHashCode();
-}
-
Type* GenericDeclRefType::_createCanonicalTypeOverride()
{
return this;
@@ -417,21 +309,7 @@ Type* GenericDeclRefType::_createCanonicalTypeOverride()
void NamespaceType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("namespace ") << declRef;
-}
-
-bool NamespaceType::_equalsImplOverride(Type * type)
-{
- if (auto namespaceType = as<NamespaceType>(type))
- {
- return declRef.equals(namespaceType->declRef);
- }
- return false;
-}
-
-HashCode NamespaceType::_getHashCodeOverride()
-{
- return declRef.getHashCode();
+ out << toSlice("namespace ") << getDeclRef();
}
Type* NamespaceType::_createCanonicalTypeOverride()
@@ -441,7 +319,7 @@ Type* NamespaceType::_createCanonicalTypeOverride()
Type* DifferentialPairType::getPrimalType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
@@ -449,51 +327,35 @@ Type* DifferentialPairType::getPrimalType()
Type* PtrTypeBase::getValueType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
}
Type* OptionalType::getValueType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
+Type* NativeRefType::getValueType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void NamedExpressionType::_toTextOverride(StringBuilder& out)
{
- if (declRef.getDecl())
+ if (getDeclRef().getDecl())
{
- _printNestedDecl(declRef.getSubst(), declRef.getDecl(), out);
+ getDeclRef().declRefBase->toText(out);
}
}
-bool NamedExpressionType::_equalsImplOverride(Type * /*type*/)
-{
- SLANG_UNEXPECTED("NamedExpressionType::_equalsImplOverride should be unreachable");
- //return false;
-}
-
Type* NamedExpressionType::_createCanonicalTypeOverride()
{
- if (!innerType)
- innerType = getType(m_astBuilder, declRef);
- if (innerType)
- return innerType->getCanonicalType();
- return nullptr;
-}
-
-HashCode NamedExpressionType::_getHashCodeOverride()
-{
- // Type equality is based on comparing canonical types,
- // so the hash code for a type needs to come from the
- // canonical version of the type. This really means
- // that `Type::getHashCode()` should dispatch out to
- // something like `Type::getHashCodeImpl()` on the
- // canonical version of a type, but it is less invasive
- // for now (and hopefully equivalent) to just have any
- // named types automaticlaly route hash-code requests
- // to their canonical type.
- return getCanonicalType()->getHashCode();
+ auto canType = getType(getCurrentASTBuilder(), getDeclRef());
+ if (canType)
+ return canType->getCanonicalType();
+ return getCurrentASTBuilder()->getErrorType();
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -533,58 +395,27 @@ void FuncType::_toTextOverride(StringBuilder& out)
}
out << ") -> " << getResultType();
- if (!getErrorType()->equals(getASTBuilder()->getBottomType()))
+ if (!getErrorType()->equals(getCurrentASTBuilder()->getBottomType()))
{
out << " throws " << getErrorType();
}
}
-bool FuncType::_equalsImplOverride(Type * type)
-{
- if (auto funcType = as<FuncType>(type))
- {
- auto paramCount = getParamCount();
- auto otherParamCount = funcType->getParamCount();
- if (paramCount != otherParamCount)
- return false;
-
- for (Index pp = 0; pp < paramCount; ++pp)
- {
- auto paramType = getParamType(pp);
- auto otherParamType = funcType->getParamType(pp);
- if (!paramType->equals(otherParamType))
- return false;
- }
-
- if (!resultType->equals(funcType->resultType))
- return false;
-
- if (!errorType->equals(funcType->errorType))
- return false;
-
- // TODO: if we ever introduce other kinds
- // of qualification on function types, we'd
- // want to consider it here.
- return true;
- }
- return false;
-}
-
Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
// result type
- Type* substResultType = as<Type>(resultType->substituteImpl(astBuilder, subst, &diff));
+ Type* substResultType = as<Type>(getResultType()->substituteImpl(astBuilder, subst, &diff));
// error type
- Type* substErrorType = as<Type>(errorType->substituteImpl(astBuilder, subst, &diff));
+ Type* substErrorType = as<Type>(getErrorType()->substituteImpl(astBuilder, subst, &diff));
// parameter types
List<Type*> substParamTypes;
- for (auto pp : paramTypes)
+ for (Index pp = 0; pp < getParamCount(); pp++ )
{
- substParamTypes.add(as<Type>(pp->substituteImpl(astBuilder, subst, &diff)));
+ substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)));
}
// early exit for no change...
@@ -592,138 +423,75 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s
return this;
(*ioDiff)++;
- FuncType* substType = astBuilder->create<FuncType>();
- substType->resultType = substResultType;
- substType->paramTypes = substParamTypes;
- substType->errorType = substErrorType;
+ FuncType* substType = astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType);
return substType;
}
Type* FuncType::_createCanonicalTypeOverride()
{
// result type
- Type* canResultType = resultType->getCanonicalType();
- Type* canErrorType = errorType->getCanonicalType();
+ Type* canResultType = getResultType()->getCanonicalType();
+ Type* canErrorType = getErrorType()->getCanonicalType();
// parameter types
List<Type*> canParamTypes;
- for (auto pp : paramTypes)
+ for (Index pp = 0; pp < getParamCount(); pp++)
{
- canParamTypes.add(pp->getCanonicalType());
+ canParamTypes.add(getParamType(pp)->getCanonicalType());
}
- FuncType* canType = getASTBuilder()->create<FuncType>();
- canType->resultType = canResultType;
- canType->paramTypes = canParamTypes;
- canType->errorType = canErrorType;
+ FuncType* canType = getCurrentASTBuilder()->getFuncType(canParamTypes.getArrayView(), canResultType, canErrorType);
return canType;
}
-HashCode FuncType::_getHashCodeOverride()
-{
- HashCode hashCode = getResultType()->getHashCode();
- Index paramCount = getParamCount();
- hashCode = combineHash(hashCode, Slang::getHashCode(paramCount));
- for (Index pp = 0; pp < paramCount; ++pp)
- {
- hashCode = combineHash(
- hashCode,
- getParamType(pp)->getHashCode());
- }
- combineHash(hashCode, getErrorType()->getHashCode());
- return hashCode;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void TupleType::_toTextOverride(StringBuilder& out)
{
out << toSlice("(");
- for (Index pp = 0; pp < memberTypes.getCount(); ++pp)
+ for (Index pp = 0; pp < getOperandCount(); ++pp)
{
if (pp != 0)
out << toSlice(", ");
- out << memberTypes[pp];
+ out << getOperand(pp);
}
out << toSlice(")");
}
-bool TupleType::_equalsImplOverride(Type * type)
-{
- if (const auto other = as<TupleType>(type))
- {
- auto paramCount = memberTypes.getCount();
- auto otherParamCount = other->memberTypes.getCount();
- if (paramCount != otherParamCount)
- return false;
-
- for (Index i = 0; i < memberTypes.getCount(); ++i)
- {
- if(!memberTypes[i]->equals(other->memberTypes[i]))
- return false;
- }
-
- return true;
- }
- return false;
-}
-
Val* TupleType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
// just recurse into the members
List<Type*> substMemberTypes;
- for (auto m : memberTypes)
- substMemberTypes.add(as<Type>(m->substituteImpl(astBuilder, subst, &diff)));
+ for (Index m = 0; m < getMemberCount(); m++)
+ substMemberTypes.add(as<Type>(getMember(m)->substituteImpl(astBuilder, subst, &diff)));
// early exit for no change...
if (!diff)
return this;
(*ioDiff)++;
- return astBuilder->create<TupleType>(std::move(substMemberTypes));
+ return astBuilder->getTupleType(substMemberTypes);
}
Type* TupleType::_createCanonicalTypeOverride()
{
// member types
List<Type*> canMemberTypes;
- for (auto m : memberTypes)
+ for (Index m = 0; m < getMemberCount(); m++)
{
- canMemberTypes.add(m->getCanonicalType());
+ canMemberTypes.add(getMember(m)->getCanonicalType());
}
- return getASTBuilder()->create<TupleType>(std::move(canMemberTypes));
-}
-
-HashCode TupleType::_getHashCodeOverride()
-{
- HashCode hashCode = Slang::getHashCode(kType);
- for(auto m : memberTypes)
- hashCode = combineHash(hashCode, m->getHashCode());
- return hashCode;
+ return getCurrentASTBuilder()->getTupleType(canMemberTypes);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ExtractExistentialType::_toTextOverride(StringBuilder& out)
{
- out << declRef << toSlice(".This");
-}
-
-bool ExtractExistentialType::_equalsImplOverride(Type* type)
-{
- if (auto extractExistential = as<ExtractExistentialType>(type))
- {
- return declRef.equals(extractExistential->declRef);
- }
- return false;
-}
-
-HashCode ExtractExistentialType::_getHashCodeOverride()
-{
- return combineHash(declRef.getHashCode(), originalInterfaceType->getHashCode(), originalInterfaceDeclRef.getHashCode());
+ out << getDeclRef() << toSlice(".This");
}
Type* ExtractExistentialType::_createCanonicalTypeOverride()
@@ -734,18 +502,16 @@ Type* ExtractExistentialType::_createCanonicalTypeOverride()
Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
- auto substOriginalInterfaceType = originalInterfaceType->substituteImpl(astBuilder, subst, &diff);
- auto substOriginalInterfaceDeclRef = originalInterfaceDeclRef.substituteImpl(astBuilder, subst, &diff);
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+ auto substOriginalInterfaceType = getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff);
+ auto substOriginalInterfaceDeclRef = getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff);
if (!diff)
return this;
(*ioDiff)++;
- ExtractExistentialType* substValue = astBuilder->create<ExtractExistentialType>();
- substValue->declRef = substDeclRef;
- substValue->originalInterfaceType = as<Type>(substOriginalInterfaceType);
- substValue->originalInterfaceDeclRef = substOriginalInterfaceDeclRef;
+ ExtractExistentialType* substValue = astBuilder->getOrCreate<ExtractExistentialType>(
+ substDeclRef, as<Type>(substOriginalInterfaceType), substOriginalInterfaceDeclRef);
return substValue;
}
@@ -754,165 +520,47 @@ SubtypeWitness* ExtractExistentialType::getSubtypeWitness()
if (auto cachedValue = this->cachedSubtypeWitness)
return cachedValue;
- ExtractExistentialSubtypeWitness* openedWitness = m_astBuilder->create<ExtractExistentialSubtypeWitness>();
- openedWitness->sub = this;
- openedWitness->sup = originalInterfaceType;
- openedWitness->declRef = this->declRef;
-
+ ExtractExistentialSubtypeWitness* openedWitness = getCurrentASTBuilder()->getOrCreate<ExtractExistentialSubtypeWitness>(this, getOriginalInterfaceType(), getDeclRef());
this->cachedSubtypeWitness = openedWitness;
return openedWitness;
}
-DeclRef<InterfaceDecl> ExtractExistentialType::getSpecializedInterfaceDeclRef()
+DeclRef<ThisTypeDecl> ExtractExistentialType::getThisTypeDeclRef()
{
- if (auto cachedValue = this->cachedSpecializedInterfaceDeclRef)
+ if (auto cachedValue = this->cachedThisTypeDeclRef)
return cachedValue;
- auto interfaceDecl = originalInterfaceDeclRef.getDecl();
+ auto interfaceDecl = getOriginalInterfaceDeclRef().getDecl();
SubtypeWitness* openedWitness = getSubtypeWitness();
- ThisTypeSubstitution* openedThisType = m_astBuilder->getOrCreateThisTypeSubstitution(
- interfaceDecl, openedWitness, originalInterfaceDeclRef.getSubst());
-
- DeclRef<InterfaceDecl> specialiedInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(interfaceDecl, openedThisType);
-
- this->cachedSpecializedInterfaceDeclRef = specialiedInterfaceDeclRef;
- return specialiedInterfaceDeclRef;
-}
-
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-void TaggedUnionType::_toTextOverride(StringBuilder& out)
-{
- out << toSlice("__TaggedUnion(");
- bool first = true;
- for (auto caseType : caseTypes)
- {
- if (!first)
+ ThisTypeDecl* thisTypeDecl = nullptr;
+ for (auto member : interfaceDecl->members)
+ if (as<ThisTypeDecl>(member))
{
- out << toSlice(", ");
+ thisTypeDecl = as<ThisTypeDecl>(member);
+ break;
}
- first = false;
-
- out << caseType;
- }
- out << toSlice(")");
-}
-
-bool TaggedUnionType::_equalsImplOverride(Type* type)
-{
- auto taggedUnion = as<TaggedUnionType>(type);
- if (!taggedUnion)
- return false;
-
- auto caseCount = caseTypes.getCount();
- if (caseCount != taggedUnion->caseTypes.getCount())
- return false;
-
- for (Index ii = 0; ii < caseCount; ++ii)
- {
- if (!caseTypes[ii]->equals(taggedUnion->caseTypes[ii]))
- return false;
- }
- return true;
-}
-
-HashCode TaggedUnionType::_getHashCodeOverride()
-{
- HashCode hashCode = 0;
- for (auto caseType : caseTypes)
- {
- hashCode = combineHash(hashCode, caseType->getHashCode());
- }
- return hashCode;
-}
-
-Type* TaggedUnionType::_createCanonicalTypeOverride()
-{
- TaggedUnionType* canType = m_astBuilder->create<TaggedUnionType>();
-
- for (auto caseType : caseTypes)
- {
- auto canCaseType = caseType->getCanonicalType();
- canType->caseTypes.add(canCaseType);
- }
-
- return canType;
-}
+ SLANG_ASSERT(thisTypeDecl);
-Val* TaggedUnionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
+ DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl);
- List<Type*> substCaseTypes;
- for (auto caseType : caseTypes)
- {
- substCaseTypes.add(as<Type>(caseType->substituteImpl(astBuilder, subst, &diff)));
- }
- if (!diff)
- return this;
-
- (*ioDiff)++;
-
- TaggedUnionType* substType = astBuilder->create<TaggedUnionType>();
- substType->caseTypes.swapWith(substCaseTypes);
- return substType;
+ this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef;
+ return specialiedInterfaceDeclRef;
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExistentialSpecializedType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ExistentialSpecializedType::_toTextOverride(StringBuilder& out)
{
- out << toSlice("__ExistentialSpecializedType(") << baseType;
- for (auto arg : args)
+ out << toSlice("__ExistentialSpecializedType(") << getBaseType();
+ for (Index i = 0; i < getArgCount(); i++)
{
- out << toSlice(", ") << arg.val;
+ out << toSlice(", ") << getArg(i).val;
}
out << toSlice(")");
}
-bool ExistentialSpecializedType::_equalsImplOverride(Type * type)
-{
- auto other = as<ExistentialSpecializedType>(type);
- if (!other)
- return false;
-
- if (!baseType->equals(other->baseType))
- return false;
-
- auto argCount = args.getCount();
- if (argCount != other->args.getCount())
- return false;
-
- for (Index ii = 0; ii < argCount; ++ii)
- {
- auto arg = args[ii];
- auto otherArg = other->args[ii];
-
- if (!arg.val->equalsVal(otherArg.val))
- return false;
-
- if (!areValsEqual(arg.witness, otherArg.witness))
- return false;
- }
- return true;
-}
-
-HashCode ExistentialSpecializedType::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashObject(baseType);
- for (auto arg : args)
- {
- hasher.hashObject(arg.val);
- if (auto witness = arg.witness)
- hasher.hashObject(witness);
- }
- return hasher.getResult();
-}
-
static Val* _getCanonicalValue(Val* val)
{
if (!val)
@@ -928,16 +576,21 @@ static Val* _getCanonicalValue(Val* val)
Type* ExistentialSpecializedType::_createCanonicalTypeOverride()
{
- ExistentialSpecializedType* canType = m_astBuilder->create<ExistentialSpecializedType>();
+ ExpandedSpecializationArgs newArgs;
- canType->baseType = baseType->getCanonicalType();
- for (auto arg : args)
+ for (Index ii = 0; ii < getArgCount(); ++ii)
{
+ auto arg = getArg(ii);
ExpandedSpecializationArg canArg;
canArg.val = _getCanonicalValue(arg.val);
canArg.witness = _getCanonicalValue(arg.witness);
- canType->args.add(canArg);
+ newArgs.add(canArg);
}
+
+ ExistentialSpecializedType* canType = getCurrentASTBuilder()->getOrCreate<ExistentialSpecializedType>(
+ getBaseType()->getCanonicalType(),
+ newArgs);
+
return canType;
}
@@ -951,11 +604,12 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder,
{
int diff = 0;
- auto substBaseType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff));
+ auto substBaseType = as<Type>(getBaseType()->substituteImpl(astBuilder, subst, &diff));
ExpandedSpecializationArgs substArgs;
- for (auto arg : args)
+ for (Index ii = 0; ii < getArgCount(); ++ii)
{
+ auto arg = getArg(ii);
ExpandedSpecializationArg substArg;
substArg.val = _substituteImpl(astBuilder, arg.val, subst, &diff);
substArg.witness = _substituteImpl(astBuilder, arg.witness, subst, &diff);
@@ -967,96 +621,22 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder,
(*ioDiff)++;
- ExistentialSpecializedType* substType = astBuilder->create<ExistentialSpecializedType>();
- substType->baseType = substBaseType;
- substType->args = substArgs;
+ ExistentialSpecializedType* substType = astBuilder->getOrCreate<ExistentialSpecializedType>(substBaseType, substArgs);
return substType;
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-void ThisType::_toTextOverride(StringBuilder& out)
-{
- out << interfaceDeclRef << toSlice(".This");
-}
-
-bool ThisType::_equalsImplOverride(Type * type)
-{
- auto other = as<ThisType>(type);
- if (!other)
- return false;
-
- if (!interfaceDeclRef.equals(other->interfaceDeclRef))
- return false;
-
- return true;
-}
-
-HashCode ThisType::_getHashCodeOverride()
-{
- return combineHash(
- HashCode(typeid(*this).hash_code()),
- interfaceDeclRef.getHashCode());
-}
-
-Type* ThisType::_createCanonicalTypeOverride()
+InterfaceDecl* ThisType::getInterfaceDecl()
{
- ThisType* canType = m_astBuilder->create<ThisType>();
-
- // TODO: need to canonicalize the decl-ref
- canType->interfaceDeclRef = interfaceDeclRef;
- return canType;
-}
-
-Val* ThisType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
-
- auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff);
-
- auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl());
- if (thisTypeSubst)
- {
- return thisTypeSubst->witness->sub;
- }
-
- if (!diff)
- return this;
-
- (*ioDiff)++;
-
- ThisType* substType = m_astBuilder->create<ThisType>();
- substType->interfaceDeclRef = substInterfaceDeclRef;
- return substType;
+ return dynamicCast<InterfaceDecl>(getDeclRefBase()->getDecl()->parentDecl);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void AndType::_toTextOverride(StringBuilder& out)
{
- out << left << toSlice(" & ") << right;
-}
-
-bool AndType::_equalsImplOverride(Type * type)
-{
- auto other = as<AndType>(type);
- if (!other)
- return false;
-
- if(!left->equals(other->left))
- return false;
- if(!right->equals(other->right))
- return false;
-
- return true;
-}
-
-HashCode AndType::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashObject(left);
- hasher.hashObject(right);
- return hasher.getResult();
+ out << getLeft() << toSlice(" & ") << getRight();
}
Type* AndType::_createCanonicalTypeOverride()
@@ -1094,9 +674,9 @@ Type* AndType::_createCanonicalTypeOverride()
// right now, in the name of getting something up and running.
//
- auto canLeft = left->getCanonicalType();
- auto canRight = right->getCanonicalType();
- auto canType = m_astBuilder->getAndType(canLeft, canRight);
+ auto canLeft = getLeft()->getCanonicalType();
+ auto canRight = getRight()->getCanonicalType();
+ auto canType = getCurrentASTBuilder()->getAndType(canLeft, canRight);
return canType;
}
@@ -1104,15 +684,15 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su
{
int diff = 0;
- auto substLeft = as<Type>(left ->substituteImpl(astBuilder, subst, &diff));
- auto substRight = as<Type>(right->substituteImpl(astBuilder, subst, &diff));
+ auto substLeft = as<Type>(getLeft()->substituteImpl(astBuilder, subst, &diff));
+ auto substRight = as<Type>(getRight()->substituteImpl(astBuilder, subst, &diff));
if(!diff)
return this;
(*ioDiff)++;
- auto substType = m_astBuilder->getAndType(substLeft, substRight);
+ auto substType = getCurrentASTBuilder()->getAndType(substLeft, substRight);
return substType;
}
@@ -1120,83 +700,35 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su
void ModifiedType::_toTextOverride(StringBuilder& out)
{
- for( auto modifier : modifiers )
+ for( Index i = 0; i < getModifierCount(); i++ )
{
- modifier->toText(out);
+ getModifier(i)->toText(out);
out.appendChar(' ');
}
- base->toText(out);
-}
-
-bool ModifiedType::_equalsImplOverride(Type* type)
-{
- auto other = as<ModifiedType>(type);
- if(!other)
- return false;
-
- if(!base->equals(other->base))
- return false;
-
- // TODO: Eventually we need to put the `modifiers` into
- // a canonical ordering as part of creation of a `ModifiedType`,
- // so that two instances that apply the same modifiers to
- // the same type will have those modifiers in a matching order.
- //
- // The simplest way to achieve that ordering *for now* would
- // be to sort the array by the integer AST node type tag.
- // That approach would of course not scale to modifiers that
- // have any operands of their own.
- //
- // Note that we would *also* need the logic that creates a
- // `ModifiedType` to detect when the base type is itself a
- // `ModifiedType` and produce a single `ModifiedType` with
- // a combined list of modifiers and a non-`ModifiedType` as
- // its base type.
- //
- auto modifierCount = modifiers.getCount();
- if(modifierCount != other->modifiers.getCount())
- return false;
-
- for( Index i = 0; i < modifierCount; ++i )
- {
- auto thisModifier = this->modifiers[i];
- auto otherModifier = other->modifiers[i];
- if(!thisModifier->equalsVal(otherModifier))
- return false;
- }
- return true;
-}
-
-HashCode ModifiedType::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashObject(base);
- for( auto modifier : modifiers )
- {
- hasher.hashObject(modifier);
- }
- return hasher.getResult();
+ getBase()->toText(out);
}
Type* ModifiedType::_createCanonicalTypeOverride()
{
- ModifiedType* canonical = m_astBuilder->create<ModifiedType>();
- canonical->base = base->getCanonicalType();
- for( auto modifier : modifiers )
+ List<Val*> modifiers;
+ for (Index i = 0; i < getModifierCount(); ++i)
{
- canonical->modifiers.add(modifier);
+ auto modifier = this->getModifier(i);
+ modifiers.add(modifier);
}
+ ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate<ModifiedType>(getBase()->getCanonicalType(), modifiers.getArrayView());
return canonical;
}
Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- Type* substBase = as<Type>(base->substituteImpl(astBuilder, subst, &diff));
+ Type* substBase = as<Type>(getBase()->substituteImpl(astBuilder, subst, &diff));
List<Val*> substModifiers;
- for( auto modifier : modifiers )
+ for (Index i = 0; i < getModifierCount(); ++i)
{
+ auto modifier = this->getModifier(i);
auto substModifier = modifier->substituteImpl(astBuilder, subst, &diff);
substModifiers.add(substModifier);
}
@@ -1206,12 +738,49 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS
*ioDiff = 1;
- ModifiedType* substType = m_astBuilder->create<ModifiedType>();
- substType->base = substBase;
- substType->modifiers = _Move(substModifiers);
+ ModifiedType* substType = getCurrentASTBuilder()->getOrCreate<ModifiedType>(substBase, substModifiers.getArrayView());
return substType;
}
+BaseType BasicExpressionType::getBaseType() const
+{
+ auto builtinType = getDeclRef().getDecl()->findModifier<BuiltinTypeModifier>();
+ return builtinType->tag;
+}
+
+FeedbackType::Kind FeedbackType::getKind() const
+{
+ auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>();
+ return FeedbackType::Kind(magicMod->tag);
+}
+
+TextureFlavor ResourceType::getFlavor() const
+{
+ auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>();
+ return TextureFlavor(magicMod->tag);
+}
+
+SamplerStateFlavor SamplerStateType::getFlavor() const
+{
+ auto magicMod = getDeclRef().getDecl()->findModifier<MagicTypeModifier>();
+ return SamplerStateFlavor(magicMod->tag);
+}
+
+Type* BuiltinGenericType::getElementType() const
+{
+ return as<Type>(_getGenericTypeArg(getDeclRefBase(), 0));
+}
+
+Type* ResourceType::getElementType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
+Val* TextureTypeBase::getSampleCount()
+{
+ return as<Type>(_getGenericTypeArg(this, 1));
+}
+
Type* removeParamDirType(Type* type)
{
for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index b32d62404..8948f742c 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -16,8 +16,6 @@ class OverloadGroupType : public Type
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
};
// The type of an initializer-list expression (before it has
@@ -29,8 +27,6 @@ class InitializerListType : public Type
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
};
// The type of an expression that was erroneous
@@ -41,8 +37,6 @@ class ErrorType : public Type
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
@@ -53,31 +47,28 @@ class BottomType : public Type
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// A type that takes the form of a reference to some declaration
-class DeclRefType : public Type
+class DeclRefType : public Type
{
SLANG_AST_CLASS(DeclRefType)
- DeclRef<Decl> declRef;
-
static DeclRefType* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef);
+ DeclRef<Decl> getDeclRef() const { return DeclRef<Decl>(as<DeclRefBase>(getOperand(0))); }
+ DeclRefBase* getDeclRefBase() const { return as<DeclRefBase>(getOperand(0)); }
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
DeclRefType(DeclRefBase* declRefBase)
- : declRef(declRefBase)
- {}
+ {
+ setOperands(declRefBase);
+ }
};
// Base class for types that can be used in arithmetic expressions
@@ -95,18 +86,15 @@ class BasicExpressionType : public ArithmeticExpressionType
{
SLANG_AST_CLASS(BasicExpressionType)
- BaseType baseType;
+ BaseType getBaseType() const;
// Overrides should be public so base classes can access
- Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
BasicExpressionType* _getScalarTypeOverride();
-protected:
- BasicExpressionType(
- Slang::BaseType baseType)
- : baseType(baseType)
- {}
+ BasicExpressionType(DeclRefBase* inDeclRef)
+ {
+ setOperands(inDeclRef);
+ }
};
// Base type for things that are built in to the compiler,
@@ -127,7 +115,7 @@ class FeedbackType : public BuiltinType
MipRegionUsed, /// SAMPLER_FEEDBACK_MIP_REGION_USED
};
- Kind kind;
+ Kind getKind() const;
};
// Resources that contain "elements" that can be fetched
@@ -135,37 +123,24 @@ class ResourceType : public BuiltinType
{
SLANG_ABSTRACT_AST_CLASS(ResourceType)
- // The type that results from fetching an element from this resource
- Type* elementType = nullptr;
-
- // Shape and access level information for this resource type
- TextureFlavor flavor;
+ TextureFlavor getFlavor() const;
TextureFlavor::Shape getBaseShape()
{
- return flavor.getBaseShape();
+ return getFlavor().getBaseShape();
}
- bool isMultisample() { return flavor.isMultisample(); }
- bool isArray() { return flavor.isArray(); }
- SlangResourceShape getShape() const { return flavor.getShape(); }
- SlangResourceAccess getAccess() { return flavor.getAccess(); }
+ bool isMultisample() { return getFlavor().isMultisample(); }
+ bool isArray() { return getFlavor().isArray(); }
+ SlangResourceShape getShape() const { return getFlavor().getShape(); }
+ SlangResourceAccess getAccess() { return getFlavor().getAccess(); }
+ Type* getElementType();
};
class TextureTypeBase : public ResourceType
{
SLANG_ABSTRACT_AST_CLASS(TextureTypeBase)
- // The sampleCount parameter of a RWTexture*MS resource.
- Val* sampleCount = nullptr;
-protected:
- TextureTypeBase(TextureFlavor inFlavor, Type* inElementType, Val* inSampleCount = nullptr)
- {
- elementType = inElementType;
- flavor = inFlavor;
- sampleCount = inSampleCount;
- }
-
- Val* getSampleCount() const { return sampleCount; }
+ Val* getSampleCount();
};
@@ -173,11 +148,6 @@ protected:
class TextureType : public TextureTypeBase
{
SLANG_AST_CLASS(TextureType)
-
-protected:
- TextureType(TextureFlavor flavor, Type* elementType, Val* inSampleCount = nullptr)
- : TextureTypeBase(flavor, elementType, inSampleCount)
- {}
};
@@ -186,37 +156,20 @@ protected:
class TextureSamplerType : public TextureTypeBase
{
SLANG_AST_CLASS(TextureSamplerType)
-
-protected:
- TextureSamplerType(TextureFlavor flavor, Type* elementType)
- : TextureTypeBase(flavor, elementType)
- {}
};
// This is a base type for `image*` types, as they exist in GLSL
class GLSLImageType : public TextureTypeBase
{
SLANG_AST_CLASS(GLSLImageType)
-
-protected:
- GLSLImageType(
- TextureFlavor flavor,
- Type* elementType)
- : TextureTypeBase(flavor, elementType)
- {}
};
class SamplerStateType : public BuiltinType
{
SLANG_AST_CLASS(SamplerStateType)
- // What flavor of sampler state is this
- SamplerStateFlavor flavor;
-
- SamplerStateType(SamplerStateFlavor inFlavor)
- {
- flavor = inFlavor;
- }
+ // Returns flavor of sampler state of this type.
+ SamplerStateFlavor getFlavor() const;
};
// Other cases of generic types known to the compiler
@@ -224,9 +177,7 @@ class BuiltinGenericType : public BuiltinType
{
SLANG_AST_CLASS(BuiltinGenericType)
- Type* elementType = nullptr;
-
- Type* getElementType() { return elementType; }
+ Type* getElementType() const;
};
// Types that behave like pointers, in that they can be
@@ -297,7 +248,6 @@ class HLSLConsumeStructuredBufferType : public HLSLStructuredBufferTypeBase
SLANG_AST_CLASS(HLSLConsumeStructuredBufferType)
};
-
class HLSLPatchType : public BuiltinType
{
SLANG_AST_CLASS(HLSLPatchType)
@@ -396,7 +346,6 @@ class VaryingParameterGroupType : public ParameterGroupType
class ConstantBufferType : public UniformParameterGroupType
{
SLANG_AST_CLASS(ConstantBufferType)
- ConstantBufferType(Type* elementType) { SLANG_UNUSED(elementType); }
};
@@ -435,11 +384,7 @@ class ParameterBlockType : public UniformParameterGroupType
class ArrayExpressionType : public DeclRefType
{
SLANG_AST_CLASS(ArrayExpressionType)
- ArrayExpressionType(Type* inElementType, IntVal* inElementCount)
- {
- SLANG_UNUSED(inElementType);
- SLANG_UNUSED(inElementCount);
- }
+
bool isUnsized();
void _toTextOverride(StringBuilder& out);
Type* getElementType();
@@ -453,21 +398,16 @@ class TypeType : public Type
{
SLANG_AST_CLASS(TypeType)
- // The type that this is the type of...
- Type* type = nullptr;
-
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
-protected:
- TypeType(Type* type)
- : type(type)
- {}
+ Type* getType() { return as<Type>(getOperand(0)); }
-
+ TypeType(Type* type)
+ {
+ setOperands(type);
+ }
};
// A differential pair type, e.g., `__DifferentialPair<T>`
@@ -487,20 +427,12 @@ class VectorExpressionType : public ArithmeticExpressionType
{
SLANG_AST_CLASS(VectorExpressionType)
- // The type of vector elements.
- // As an invariant, this should be a basic type or an alias.
- Type* elementType = nullptr;
-
- // The number of elements
- IntVal* elementCount = nullptr;
-
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
BasicExpressionType* _getScalarTypeOverride();
- VectorExpressionType(Type* inElementType, IntVal* inElementCount)
- : elementType(inElementType), elementCount(inElementCount)
- {}
+ Type* getElementType();
+ IntVal* getElementCount();
};
// A matrix type, e.g., `matrix<T,R,C>`
@@ -519,9 +451,7 @@ class MatrixExpressionType : public ArithmeticExpressionType
BasicExpressionType* _getScalarTypeOverride();
private:
- Type* rowType = nullptr;
-
- MatrixExpressionType(Type*, IntVal*, IntVal*) {}
+ SLANG_UNREFLECTED Type* rowType = nullptr;
};
class TensorViewType : public BuiltinType
@@ -529,8 +459,6 @@ class TensorViewType : public BuiltinType
SLANG_AST_CLASS(TensorViewType)
Type* getElementType();
-private:
- TensorViewType(Type*) {}
};
// Base class for built in string types
@@ -561,6 +489,7 @@ class DynamicType : public BuiltinType
class EnumTypeType : public BuiltinType
{
SLANG_AST_CLASS(EnumTypeType)
+
// TODO: provide accessors for the declaration, the "tag" type, etc.
};
@@ -640,22 +569,16 @@ class NamedExpressionType : public Type
{
SLANG_AST_CLASS(NamedExpressionType)
- DeclRef<TypeDefDecl> declRef;
- Type* innerType = nullptr;
+ DeclRef<TypeDefDecl> getDeclRef() { return as<DeclRefBase>(getOperand(0)); }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
-
-protected:
- NamedExpressionType(
- DeclRef<TypeDefDecl> declRef)
- : declRef(declRef)
- {}
-
+ NamedExpressionType(DeclRef<TypeDefDecl> inDeclRef)
+ {
+ setOperands(inDeclRef);
+ }
};
// A function type is defined by its parameter types
@@ -666,27 +589,24 @@ class FuncType : public Type
// Construct a unary function
FuncType(Type* paramType, Type* resultType, Type* errorType)
- : paramTypes{{paramType}}, resultType{resultType}, errorType{errorType}
- {}
-
- FuncType(List<Type*> parameters, Type* result, Type* error)
- : paramTypes(std::move(parameters)), resultType(result), errorType(error)
- {}
+ {
+ setOperands(paramType, resultType, errorType);
+ }
- // TODO: We may want to preserve parameter names
- // in the list here, just so that we can print
- // out friendly names when printing a function
- // type, even if they don't affect the actual
- // semantic type underneath.
+ FuncType(ArrayView<Type*> parameters, Type* result, Type* error)
+ {
+ for (auto paramType : parameters)
+ m_operands.add(ValNodeOperand(paramType));
+ m_operands.add(ValNodeOperand(result));
+ m_operands.add(ValNodeOperand(error));
+ }
- List<Type*> paramTypes;
- Type* resultType = nullptr;
- Type* errorType = nullptr;
+ OperandView<Type> getParamTypes() { return OperandView<Type>(this, 0, getOperandCount() - 2); }
- Index getParamCount() { return paramTypes.getCount(); }
- Type* getParamType(Index index) { return paramTypes[index]; }
- Type* getResultType() { return resultType; }
- Type* getErrorType() { return errorType; }
+ Index getParamCount() { return m_operands.getCount() - 2; }
+ Type* getParamType(Index index) { return as<Type>(getOperand(index)); }
+ Type* getResultType() { return as<Type>(getOperand(m_operands.getCount() - 2)); }
+ Type* getErrorType() { return as<Type>(getOperand(m_operands.getCount() - 1)); }
ParameterDirection getParamDirection(Index index);
@@ -694,8 +614,6 @@ class FuncType : public Type
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
};
// A tuple is a product of its member types
@@ -704,21 +622,19 @@ class TupleType : public Type
SLANG_AST_CLASS(TupleType)
// Construct a unary tupletion
- TupleType(List<Type*> memberTypes)
- : memberTypes(std::move(memberTypes))
- {}
-
- auto getMemberCount() { return memberTypes.getCount(); } const
- auto& getMember(Index i) { return memberTypes[i]; }
+ TupleType(ArrayView<Type*> memberTypes)
+ {
+ for (auto t : memberTypes)
+ m_operands.add(ValNodeOperand(t));
+ }
- List<Type*> memberTypes;
+ auto getMemberCount() const { return getOperandCount(); }
+ Type* getMember(Index i) const { return as<Type>(getOperand(i)); }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
};
// The "type" of an expression that names a generic declaration.
@@ -726,21 +642,16 @@ class GenericDeclRefType : public Type
{
SLANG_AST_CLASS(GenericDeclRefType)
- DeclRef<GenericDecl> declRef;
-
- DeclRef<GenericDecl> const& getDeclRef() const { return declRef; }
+ DeclRef<GenericDecl> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Type* _createCanonicalTypeOverride();
-protected:
- GenericDeclRefType(
- DeclRef<GenericDecl> declRef)
- : declRef(declRef)
- {}
+ GenericDeclRefType(DeclRef<GenericDecl> declRef)
+ {
+ setOperands(declRef);
+ }
};
// The "type" of a reference to a module or namespace
@@ -748,14 +659,15 @@ class NamespaceType : public Type
{
SLANG_AST_CLASS(NamespaceType)
- DeclRef<NamespaceDeclBase> declRef;
+ DeclRef<NamespaceDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); }
- DeclRef<NamespaceDeclBase> const& getDeclRef() const { return declRef; }
+ NamespaceType(DeclRef<NamespaceDeclBase> inDeclRef)
+ {
+ setOperands(inDeclRef);
+ }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Type* _createCanonicalTypeOverride();
};
@@ -765,25 +677,33 @@ class ExtractExistentialType : public Type
{
SLANG_AST_CLASS(ExtractExistentialType)
- DeclRef<VarDeclBase> declRef;
+ DeclRef<VarDeclBase> getDeclRef() const { return as<DeclRefBase>(getOperand(0)); }
// A reference to the original interface this type is known
// to be a subtype of.
//
- Type* originalInterfaceType;
- DeclRef<InterfaceDecl> originalInterfaceDeclRef;
+ Type* getOriginalInterfaceType() { return as<Type>(getOperand(1)); }
+ DeclRef<InterfaceDecl> getOriginalInterfaceDeclRef() { return as<DeclRefBase>(getOperand(2)); }
+
+ ExtractExistentialType(
+ DeclRef<VarDeclBase> inDeclRef,
+ Type* inOriginalInterfaceType,
+ DeclRef<InterfaceDecl> inOriginalInterfaceDeclRef)
+ {
+ setOperands(inDeclRef, inOriginalInterfaceType, inOriginalInterfaceDeclRef);
+ }
// Following fields will not be reflected (and thus won't be serialized, etc.)
SLANG_UNREFLECTED
- // A cached decl-ref to the original interface above, with
- // a this-type substitution that refers to the type extracted here.
+ // A cached decl-ref to the original interface's ThisType Decl, with
+ // a witness that refers to the type extracted here.
//
// This field is optional and can be filled in on-demand. It does *not*
// represent part of the logical value of this `Type`, and should not
// be serialized, included in hashes, etc.
//
- DeclRef<InterfaceDecl> cachedSpecializedInterfaceDeclRef;
+ DeclRef<ThisTypeDecl> cachedThisTypeDeclRef;
// A cached pointer to a witness that shows how this type is a subtype
// of `originalInterfaceType`.
@@ -792,8 +712,6 @@ SLANG_UNREFLECTED
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Type* _createCanonicalTypeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
@@ -803,62 +721,54 @@ SLANG_UNREFLECTED
///
SubtypeWitness* getSubtypeWitness();
- /// Get a interface decl-ref for the original interface specialized to this type
- /// (using a type-type substitution).
+ /// Get a decl-ref to the interface's ThisType decl, which represents a substitutable type
+ /// from which lookup can be performed.
///
/// This operation may create the decl-ref on demand and cache it.
///
- DeclRef<InterfaceDecl> getSpecializedInterfaceDeclRef();
-};
-
- /// A tagged union of zero or more other types.
-class TaggedUnionType : public Type
-{
- SLANG_AST_CLASS(TaggedUnionType)
-
- /// The distinct "cases" the tagged union can store.
- ///
- /// For each type in this array, the array index is the
- /// tag value for that case.
- ///
- List<Type*> caseTypes;
-
- // Overrides should be public so base classes can access
- void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
- Type* _createCanonicalTypeOverride();
- Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ DeclRef<ThisTypeDecl> getThisTypeDeclRef();
};
class ExistentialSpecializedType : public Type
{
SLANG_AST_CLASS(ExistentialSpecializedType)
- Type* baseType = nullptr;
- ExpandedSpecializationArgs args;
+ Type* getBaseType() { return as<Type>(getOperand(0)); }
+ ExpandedSpecializationArg getArg(Index i)
+ {
+ ExpandedSpecializationArg arg;
+ arg.val = getOperand(i * 2 + 1);
+ arg.witness = getOperand(i * 2 + 2);
+ return arg;
+ }
+ Index getArgCount() { return (getOperandCount() - 1) / 2; }
+
+ ExistentialSpecializedType(
+ Type* inBaseType,
+ ExpandedSpecializationArgs const& inArgs)
+ {
+ m_operands.add(ValNodeOperand(inBaseType));
+ for (auto arg : inArgs)
+ {
+ m_operands.add(ValNodeOperand(arg.val));
+ m_operands.add(ValNodeOperand(arg.witness));
+ }
+ }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Type* _createCanonicalTypeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// The type of `this` within a polymorphic declaration
-class ThisType : public Type
+class ThisType : public DeclRefType
{
SLANG_AST_CLASS(ThisType)
- DeclRef<InterfaceDecl> interfaceDeclRef;
+ ThisType(DeclRefBase* declRef) : DeclRefType(declRef) {}
- // Overrides should be public so base classes can access
- void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
- Type* _createCanonicalTypeOverride();
- Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ InterfaceDecl* getInterfaceDecl();
};
/// The type of `A & B` where `A` and `B` are types
@@ -868,17 +778,16 @@ class AndType : public Type
{
SLANG_AST_CLASS(AndType)
- Type* left;
- Type* right;
-
+ Type* getLeft() { return as<Type>(getOperand(0)); }
+ Type* getRight() { return as<Type>(getOperand(1)); }
+
AndType(Type* leftType, Type* rightType)
- : left(leftType), right(rightType)
- {}
+ {
+ setOperands(leftType, rightType);
+ }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Type* _createCanonicalTypeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
@@ -887,22 +796,32 @@ class ModifiedType : public Type
{
SLANG_AST_CLASS(ModifiedType)
- Type* base;
- List<Val*> modifiers;
+ Type* getBase()
+ {
+ return as<Type>(getOperand(0));
+ }
+
+ Index getModifierCount() { return getOperandCount() - 1; }
+ Val* getModifier(Index index) { return getOperand(index + 1); }
+
+ ModifiedType(Type* inBase, ArrayView<Val*> inModifiers)
+ {
+ m_operands.add(ValNodeOperand(inBase));
+ for (auto modifier : inModifiers)
+ m_operands.add(ValNodeOperand(modifier));
+ }
template<typename T>
T* findModifier()
{
- for (auto v : modifiers)
- if (auto rs = as<T>(v))
+ for (Index i = 1; i < getOperandCount(); i++)
+ if (auto rs = as<T>(getOperand(i)))
return rs;
return nullptr;
}
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
- bool _equalsImplOverride(Type* type);
- HashCode _getHashCodeOverride();
Type* _createCanonicalTypeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index b45300af8..056577eb0 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -6,9 +6,47 @@
#include "slang-generated-ast-macro.h"
#include "slang-diagnostics.h"
#include "slang-syntax.h"
+#include "slang-ast-val.h"
namespace Slang {
+
+bool ValNodeDesc::operator==(ValNodeDesc const& that) const
+{
+ if (hashCode != that.hashCode) return false;
+ if (type != that.type) return false;
+ if (operands.getCount() != that.operands.getCount()) return false;
+ for (Index i = 0; i < operands.getCount(); ++i)
+ {
+ // Note: we are comparing the operands directly for identity
+ // (pointer equality) rather than doing the `Val`-level
+ // equality check.
+ //
+ // The rationale here is that nodes that will be created
+ // via a `NodeDesc` *should* all be going through the
+ // deduplication path anyway, as should their operands.
+ //
+ if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
+ }
+ return true;
+}
+
+void ValNodeDesc::init()
+{
+ Hasher hasher;
+ hasher.hashValue(Int(type));
+ for (Index i = 0; i < operands.getCount(); ++i)
+ {
+ // Note: we are hashing the raw pointer value rather
+ // than the content of the value node. This is done
+ // to match the semantics implemented for `==` on
+ // `NodeDesc`.
+ //
+ hasher.hashValue(operands[i].values.nodeOperand);
+ }
+ hashCode = hasher.getResult();
+}
+
Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst)
{
if (!subst) return this;
@@ -21,14 +59,103 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD
SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff))
}
-bool Val::equalsVal(Val* val)
+void Val::toText(StringBuilder& out)
{
- SLANG_AST_NODE_VIRTUAL_CALL(Val, equalsVal, (val))
+ SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out))
}
-void Val::toText(StringBuilder& out)
+Val* Val::_resolveImplOverride()
{
- SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out))
+ SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden");
+}
+
+Val* Val::resolveImpl()
+{
+ SLANG_AST_NODE_VIRTUAL_CALL(Val, resolveImpl, ());
+}
+
+Val* Val::resolve()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ // If we are not in a proper checking context, just return the previously resolved val.
+ if (!astBuilder)
+ return m_resolvedVal? m_resolvedVal : this;
+ if (m_resolvedVal && m_resolvedValEpoch == getCurrentASTBuilder()->getEpoch())
+ {
+ SLANG_ASSERT(as<Val>(m_resolvedVal));
+ return m_resolvedVal;
+ }
+
+ // Update epoch now to avoid infinite recursion.
+ m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch();
+ m_resolvedVal = this;
+ m_resolvedVal = resolveImpl();
+
+ // Check if we are resolved to an existing Val in the AST cache.
+ ValNodeDesc newDesc;
+ newDesc.type = m_resolvedVal->astNodeType;
+ for (auto operand : m_resolvedVal->m_operands)
+ {
+ if (operand.kind == ValNodeOperandKind::ValNode)
+ {
+ auto valOperand = as<Val>(operand.values.nodeOperand);
+ if (valOperand)
+ {
+ operand.values.nodeOperand = valOperand->resolve();
+ }
+ }
+ newDesc.operands.add(operand);
+ }
+ newDesc.init();
+
+ NodeBase* existingNode = nullptr;
+ if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
+ m_resolvedVal = as<Val>(existingNode);
+
+#ifdef _DEBUG
+ if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0)
+ {
+ //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking.");
+ }
+#endif
+ return m_resolvedVal;
+}
+
+ValNodeDesc Val::getDesc()
+{
+ ValNodeDesc desc;
+ desc.type = astNodeType;
+ for (auto operand : m_operands)
+ desc.operands.add(operand);
+ desc.init();
+ return desc;
+}
+
+Val* Val::defaultResolveImpl()
+{
+ // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache.
+ ValNodeDesc newDesc;
+ newDesc.type = astNodeType;
+ for (auto operand : m_operands)
+ {
+ if (operand.kind == ValNodeOperandKind::ValNode)
+ {
+ auto valOperand = as<Val>(operand.values.nodeOperand);
+ if (valOperand)
+ {
+ operand.values.nodeOperand = valOperand->resolve();
+ }
+ }
+ newDesc.operands.add(operand);
+ }
+ newDesc.init();
+ auto astBuilder = getCurrentASTBuilder();
+
+ NodeBase* existingNode = nullptr;
+ if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
+ return as<Val>(existingNode);
+ return this;
}
String Val::toString()
@@ -40,7 +167,7 @@ String Val::toString()
HashCode Val::getHashCode()
{
- SLANG_AST_NODE_VIRTUAL_CALL(Val, getHashCode, ())
+ return Slang::getHashCode(resolve());
}
Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
@@ -52,124 +179,84 @@ Val* Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst,
return this;
}
-bool Val::_equalsValOverride(Val* val)
-{
- SLANG_UNUSED(val);
- SLANG_UNEXPECTED("Val::_equalsValOverride not overridden");
- //return false;
-}
-
void Val::_toTextOverride(StringBuilder& out)
{
SLANG_UNUSED(out);
SLANG_UNEXPECTED("Val::_toStringOverride not overridden");
}
-HashCode Val::_getHashCodeOverride()
-{
- SLANG_UNEXPECTED("Val::_getHashCodeOverride not overridden");
- //return HashCode(0);
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConstantIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ConstantIntVal::_equalsValOverride(Val* val)
-{
- if (auto intVal = as<ConstantIntVal>(val))
- return value == intVal->value;
- return false;
-}
-
void ConstantIntVal::_toTextOverride(StringBuilder& out)
{
- out << value;
-}
-
-HashCode ConstantIntVal::_getHashCodeOverride()
-{
- return (HashCode)value;
+ out << getValue();
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool GenericParamIntVal::_equalsValOverride(Val* val)
-{
- if (auto genericParamVal = as<GenericParamIntVal>(val))
- {
- return declRef.equals(genericParamVal->declRef);
- }
- return false;
-}
-
void GenericParamIntVal::_toTextOverride(StringBuilder& out)
{
- Name* name = declRef.getName();
+ Name* name = getDeclRef().getName();
if (name)
{
out << name->text;
}
}
-HashCode GenericParamIntVal::_getHashCodeOverride()
-{
- return declRef.getHashCode() ^ HashCode(0xFFFF);
-}
-
Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff)
{
// search for a substitution that might apply to us
- for (auto s = subst.substitutions; s; s = s->getOuter())
+ auto outerGeneric = as<GenericDecl>(paramDecl->parentDecl);
+ if (!outerGeneric)
+ return paramVal;
+
+ GenericAppDeclRef* genAppArgs = subst.findGenericAppDeclRef(outerGeneric);
+ if (!genAppArgs)
{
- auto genSubst = as<GenericSubstitution>(s);
- if (!genSubst)
- continue;
-
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = genSubst->getGenericDecl();
- if (genericDecl != paramDecl->parentDecl)
- continue;
-
- // In some cases, we construct a `DeclRef` to a `GenericDecl`
- // (or a declaration under one) that only includes argument
- // values for a prefix of the parameters of the generic.
- //
- // If we aren't careful, we could end up indexing into the
- // argument list past the available range.
- //
- Count argCount = genSubst->getArgs().getCount();
+ return paramVal;
+ }
- Count argIndex = 0;
- for (auto m : genericDecl->members)
+ auto args = genAppArgs->getArgs();
+
+ // In some cases, we construct a `DeclRef` to a `GenericDecl`
+ // (or a declaration under one) that only includes argument
+ // values for a prefix of the parameters of the generic.
+ //
+ // If we aren't careful, we could end up indexing into the
+ // argument list past the available range.
+ //
+ Count argCount = args.getCount();
+
+ Count argIndex = 0;
+ for (auto m : outerGeneric->members)
+ {
+ // If we have run out of arguments, then we can stop
+ // iterating over the parameters, because `this`
+ // parameter will not be replaced with anything by
+ // the substituion.
+ //
+ if (argIndex >= argCount)
{
- // If we have run out of arguments, then we can stop
- // iterating over the parameters, because `this`
- // parameter will not be replaced with anything by
- // the substituion.
- //
- if (argIndex >= argCount)
- {
- return paramVal;
- }
+ return paramVal;
+ }
- if (m == paramDecl)
- {
- // We've found it, so return the corresponding specialization argument
- (*ioDiff)++;
- return genSubst->getArgs()[argIndex];
- }
- else if (const auto typeParam = as<GenericTypeParamDecl>(m))
- {
- argIndex++;
- }
- else if (const auto valParam = as<GenericValueParamDecl>(m))
- {
- argIndex++;
- }
- else
- {
- }
+ if (m == paramDecl)
+ {
+ // We've found it, so return the corresponding specialization argument
+ (*ioDiff)++;
+ return args[argIndex];
+ }
+ else if (const auto typeParam = as<GenericTypeParamDecl>(m))
+ {
+ argIndex++;
+ }
+ else if (const auto valParam = as<GenericValueParamDecl>(m))
+ {
+ argIndex++;
+ }
+ else
+ {
}
}
@@ -180,7 +267,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet
Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff)
{
- if (auto result = maybeSubstituteGenericParam(this, declRef.getDecl(), subst, ioDiff))
+ if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff))
return result;
return this;
@@ -188,21 +275,11 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ErrorIntVal::_equalsValOverride(Val* val)
-{
- return as<ErrorIntVal>(val);
-}
-
void ErrorIntVal::_toTextOverride(StringBuilder& out)
{
out << toSlice("<error>");
}
-HashCode ErrorIntVal::_getHashCodeOverride()
-{
- return HashCode(typeid(this).hash_code());
-}
-
Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
SLANG_UNUSED(astBuilder);
@@ -211,97 +288,110 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
return this;
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-// TODO: should really have a `type.cpp` and a `witness.cpp`
-
-bool TypeEqualityWitness::_equalsValOverride(Val* val)
-{
- auto otherWitness = as<TypeEqualityWitness>(val);
- if (!otherWitness)
- return false;
- return sub->equals(otherWitness->sub);
-}
-
Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
- TypeEqualityWitness* rs = astBuilder->create<TypeEqualityWitness>();
- rs->sub = as<Type>(sub->substituteImpl(astBuilder, subst, ioDiff));
- rs->sup = as<Type>(sup->substituteImpl(astBuilder, subst, ioDiff));
+ auto type = as<Type>(getSub()->substituteImpl(astBuilder, subst, ioDiff));
+ TypeEqualityWitness* rs = astBuilder->getOrCreate<TypeEqualityWitness>(type, type);
return rs;
}
void TypeEqualityWitness::_toTextOverride(StringBuilder& out)
{
- out << toSlice("TypeEqualityWitness(") << sub << toSlice(")");
-}
-
-HashCode TypeEqualityWitness::_getHashCodeOverride()
-{
- return sub->getHashCode();
+ out << toSlice("TypeEqualityWitness(") << getSub() << toSlice(")");
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclaredSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool DeclaredSubtypeWitness::_equalsValOverride(Val* val)
+Val* DeclaredSubtypeWitness::_resolveImplOverride()
{
- auto otherWitness = as<DeclaredSubtypeWitness>(val);
- if (!otherWitness)
- return false;
+ auto resolvedDeclRef = getDeclRef().declRefBase->resolve();
+ if (auto resolvedVal = as<SubtypeWitness>(resolvedDeclRef))
+ return resolvedVal;
- return sub->equals(otherWitness->sub)
- && sup->equals(otherWitness->sup)
- && declRef.equals(otherWitness->declRef);
+ auto newSub = as<Type>(getSub()->resolve());
+ auto newSup = as<Type>(getSup()->resolve());
+
+ // If we are trying to lookup for a witness that A<:B from a witness(A<:B), we
+ // can just return the witness itself.
+ if (auto lookupDeclRef = as<LookupDeclRef>(resolvedDeclRef))
+ {
+ auto witnessToLookupFrom = lookupDeclRef->getWitness();
+ if (witnessToLookupFrom->getSub()->equals(newSub) &&
+ witnessToLookupFrom->getSup()->equals(newSup))
+ return witnessToLookupFrom;
+ }
+ auto newDeclRef = as<DeclRefBase>(resolvedDeclRef);
+ if (!newDeclRef)
+ newDeclRef = getDeclRef().declRefBase;
+ if (newSub != getSub() || newSup != getSup() || newDeclRef != getDeclRef())
+ {
+ return getCurrentASTBuilder()->getDeclaredSubtypeWitness(newSub, newSup, newDeclRef);
+ }
+ return this;
}
Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
- if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>())
+ if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>())
{
- auto genConstraintDecl = genConstraintDeclRef.getDecl();
+ auto genericDecl = as<GenericDecl>(getDeclRef().getDecl()->parentDecl);
+ if (!genericDecl)
+ goto breakLabel;
// search for a substitution that might apply to us
- for (auto s = subst.substitutions; s; s = s->getOuter())
+ auto args = tryGetGenericArguments(subst, genericDecl);
+ if (args.getCount() == 0)
+ goto breakLabel;
+
+ bool found = false;
+ Index index = 0;
+ for (auto m : genericDecl->members)
{
- if (auto genericSubst = as<GenericSubstitution>(s))
+ if (auto constraintParam = as<GenericTypeConstraintDecl>(m))
{
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = genericSubst->getGenericDecl();
- if (genericDecl != genConstraintDecl->parentDecl)
- continue;
-
- bool found = false;
- Index index = 0;
- for (auto m : genericDecl->members)
+ if (constraintParam == getDeclRef().getDecl())
{
- if (auto constraintParam = as<GenericTypeConstraintDecl>(m))
- {
- if (constraintParam == declRef.getDecl())
- {
- found = true;
- break;
- }
- index++;
- }
- }
- if (found)
- {
- (*ioDiff)++;
- auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() +
- genericDecl->getMembersOfType<GenericValueParamDecl>().getCount();
- SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount());
- return genericSubst->getArgs()[index + ordinaryParamCount];
+ found = true;
+ break;
}
+ index++;
+ }
+ }
+ if (found)
+ {
+ auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() +
+ genericDecl->getMembersOfType<GenericValueParamDecl>().getCount();
+ if (index + ordinaryParamCount < args.getCount())
+ {
+ (*ioDiff)++;
+ return args[index + ordinaryParamCount];
+ }
+ else
+ {
+ // When the `subst` represents a partial substitution, we may not have a corresponding argument.
+ // In this case we just return the original witness.
+ //
+ goto breakLabel;
}
}
}
+ else if (auto thisTypeConstraintDeclRef = getDeclRef().as<ThisTypeConstraintDecl>())
+ {
+ auto lookupSubst = subst.findLookupDeclRef();
+ if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl())
+ {
+ (*ioDiff)++;
+ return lookupSubst->getWitness();
+ }
+ }
+
+breakLabel:;
// Perform substitution on the constituent elements.
int diff = 0;
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
- auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
+
if (!diff)
return this;
@@ -317,7 +407,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
// so we'll need to change this location in the code if we ever clean
// up the hierarchy.
//
- if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.getDecl()))
+ if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(getDeclRef().getDecl()))
{
if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl))
{
@@ -326,12 +416,12 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
// At this point we have a constraint decl for an associated type,
// and we nee to see if we are dealing with a concrete substitution
// for the interface around that associated type.
- if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.getSubst(), interfaceDecl))
+ if (auto thisTypeWitness = findThisTypeWitness(subst, interfaceDecl))
{
// We need to look up the declaration that satisfies
// the requirement named by the associated type.
Decl* requirementKey = substTypeConstraintDecl;
- RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey);
+ RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey);
switch (requirementWitness.getFlavor())
{
default:
@@ -348,6 +438,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
}
}
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
auto rs = astBuilder->getDeclaredSubtypeWitness(
substSub, substSup, substDeclRef);
return rs;
@@ -355,34 +446,17 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out)
{
- out << toSlice("DeclaredSubtypeWitness(") << sub << toSlice(", ") << sup << toSlice(", ") << declRef << toSlice(")");
-}
-
-HashCode DeclaredSubtypeWitness::_getHashCodeOverride()
-{
- return declRef.getHashCode();
+ out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")");
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool TransitiveSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto otherWitness = as<TransitiveSubtypeWitness>(val);
- if (!otherWitness)
- return false;
-
- return sub->equals(otherWitness->sub)
- && sup->equals(otherWitness->sup)
- && subToMid->equalsVal(otherWitness->subToMid)
- && midToSup->equalsVal(otherWitness->midToSup);
-}
-
Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
int diff = 0;
- SubtypeWitness* substSubToMid = as<SubtypeWitness>(subToMid->substituteImpl(astBuilder, subst, &diff));
- SubtypeWitness* substMidToSup = as<SubtypeWitness>(midToSup->substituteImpl(astBuilder, subst, &diff));
+ SubtypeWitness* substSubToMid = as<SubtypeWitness>(getSubToMid()->substituteImpl(astBuilder, subst, &diff));
+ SubtypeWitness* substMidToSup = as<SubtypeWitness>(getMidToSup()->substituteImpl(astBuilder, subst, &diff));
// If nothing changed, then we can bail out early.
if (!diff)
@@ -407,16 +481,7 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out)
// witnesses, and rely on them to print
// the starting and ending types.
- out << toSlice("TransitiveSubtypeWitness(") << subToMid << toSlice(", ") << midToSup << toSlice(")");
-}
-
-HashCode TransitiveSubtypeWitness::_getHashCodeOverride()
-{
- auto hash = sub->getHashCode();
- hash = combineHash(hash, sup->getHashCode());
- hash = combineHash(hash, subToMid->getHashCode());
- hash = combineHash(hash, midToSup->getHashCode());
- return hash;
+ out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")");
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -425,9 +490,9 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
{
int diff = 0;
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
- auto substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff));
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
+ auto substWitness = as<SubtypeWitness>(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff));
// If nothing changed, then we can bail out early.
if (!diff)
@@ -447,138 +512,34 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
// simplification logic as needed.
//
return astBuilder->getExtractFromConjunctionSubtypeWitness(
- substSub, substSup, substWitness, indexInConjunction);
+ substSub, substSup, substWitness, getIndexInConjunction());
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val)
-{
- if (auto extractWitness = as<ExtractExistentialSubtypeWitness>(val))
- {
- return declRef.equals(extractWitness->declRef);
- }
- return false;
-}
-
void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out)
{
- out << toSlice("extractExistentialValue(") << declRef << toSlice(")");
-}
-
-HashCode ExtractExistentialSubtypeWitness::_getHashCodeOverride()
-{
- return declRef.getHashCode();
+ out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")");
}
Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
if (!diff)
return this;
(*ioDiff)++;
- ExtractExistentialSubtypeWitness* substValue = astBuilder->create<ExtractExistentialSubtypeWitness>();
- substValue->declRef = substDeclRef;
- substValue->sub = substSub;
- substValue->sup = substSup;
+ ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate<ExtractExistentialSubtypeWitness>(
+ substSub, substSup, substDeclRef);
return substValue;
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto taggedUnionWitness = as<TaggedUnionSubtypeWitness>(val);
- if (!taggedUnionWitness)
- return false;
-
- auto caseCount = caseWitnesses.getCount();
- if (caseCount != taggedUnionWitness->caseWitnesses.getCount())
- return false;
-
- for (Index ii = 0; ii < caseCount; ++ii)
- {
- if (!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii]))
- return false;
- }
-
- return true;
-}
-
-void TaggedUnionSubtypeWitness::_toTextOverride(StringBuilder& out)
-{
- out << toSlice("TaggedUnionSubtypeWitness(");
- bool first = true;
- for (auto caseWitness : caseWitnesses)
- {
- if (!first)
- {
- out << toSlice(", ");
- }
- first = false;
-
- out << caseWitness;
- }
- out << toSlice(")");
-}
-
-HashCode TaggedUnionSubtypeWitness::_getHashCodeOverride()
-{
- HashCode hash = 0;
- for (auto caseWitness : caseWitnesses)
- {
- hash = combineHash(hash, caseWitness->getHashCode());
- }
- return hash;
-}
-
-Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
-
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
-
- List<SubtypeWitness*> substCaseWitnesses;
- for (auto caseWitness : caseWitnesses)
- {
- substCaseWitnesses.add(
- as<SubtypeWitness>(caseWitness->substituteImpl(astBuilder, subst, &diff)));
- }
-
- if (!diff)
- return this;
-
- (*ioDiff)++;
-
- TaggedUnionSubtypeWitness* substWitness = astBuilder->create<TaggedUnionSubtypeWitness>();
- substWitness->sub = substSub;
- substWitness->sup = substSup;
- substWitness->caseWitnesses.swapWith(substCaseWitnesses);
- return substWitness;
-}
-
-bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto other = as<ConjunctionSubtypeWitness>(val);
- if (!other)
- return false;
-
- for (Index i = 0; i < kComponentCount; ++i)
- {
- if (!other->componentWitnesses[i]) return false;
- if (!other->componentWitnesses[i]->equalsVal(componentWitnesses[i])) return false;
- }
- return true;
-}
-
void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
out << "ConjunctionSubtypeWitness(";
@@ -586,34 +547,23 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
if (i != 0) out << ",";
- auto w = componentWitnesses[i];
+ auto w = getComponentWitness(i);
if (w) out << w;
}
out << ")";
}
-HashCode ConjunctionSubtypeWitness::_getHashCodeOverride()
-{
- HashCode result = 0;
- for (Index i = 0; i < kComponentCount; ++i)
- {
- auto w = componentWitnesses[i];
- if (w) result = combineHash(result, w->getHashCode());
- }
- return result;
-}
-
Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
Val* substComponentWitnesses[kComponentCount];
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ auto substSub = as<Type>(getSub()->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(getSup()->substituteImpl(astBuilder, subst, &diff));
for (Index i = 0; i < kComponentCount; ++i)
{
- auto w = componentWitnesses[i];
+ auto w = getComponentWitness(i);
substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr;
}
@@ -630,65 +580,25 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
auto result = astBuilder->getConjunctionSubtypeWitness(
substSub,
substSup,
- componentWitnesses[0],
- componentWitnesses[1]);
+ as<SubtypeWitness>(substComponentWitnesses[0]),
+ as<SubtypeWitness>(substComponentWitnesses[1]));
return result;
}
-bool ExtractFromConjunctionSubtypeWitness::_equalsValOverride(Val* val)
-{
- if (auto other = as<ExtractFromConjunctionSubtypeWitness>(val))
- {
- if(!sub->equals(other->sub)) return false;
- if(!sup->equals(other->sup)) return false;
- if(indexInConjunction != other->indexInConjunction) return false;
-
- return true;
- }
- return false;
-}
-
void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
out << "ExtractFromConjunctionSubtypeWitness(";
- if (conjunctionWitness)
- out << conjunctionWitness;
- if (sub)
- out << sub;
+ if (getConjunctionWitness())
+ out << getConjunctionWitness();
+ if (getSub())
+ out << getSub();
out << ",";
- if (sup)
- out << sup;
- out << "," << indexInConjunction;
+ if (getSup())
+ out << getSup();
+ out << "," << getIndexInConjunction();
out << ")";
}
-HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride()
-{
- return combineHash(
- conjunctionWitness ? conjunctionWitness->getHashCode() : 0,
- sub ? sub->getHashCode() : 0,
- sup ? sup->getHashCode() : 0,
- indexInConjunction);
-}
-
-// ModifierVal
-
-bool ModifierVal::_equalsValOverride(Val* val)
-{
- // TODO: This is assuming we can fully deduplicate the values that represent
- // modifiers, which may not actually be the case if there are multiple modules
- // being combined that use different `ASTBuilder`s.
- //
- return this == val;
-}
-
-HashCode ModifierVal::_getHashCodeOverride()
-{
- Hasher hasher;
- hasher.hashValue((void*) this);
- return hasher.getResult();
-}
-
// UNormModifierVal
void UNormModifierVal::_toTextOverride(StringBuilder& out)
@@ -735,48 +645,14 @@ Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitu
// PolynomialIntVal
-bool PolynomialIntVal::_equalsValOverride(Val* val)
-{
- if (auto genericParamVal = as<GenericParamIntVal>(val))
- {
- return constantTerm == 0 && terms.getCount() == 1 &&
- terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 &&
- terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) &&
- terms[0]->paramFactors[0]->power == 1;
- }
- else if (auto otherPolynomial = as<PolynomialIntVal>(val))
- {
- if (constantTerm != otherPolynomial->constantTerm)
- return false;
- if (terms.getCount() != otherPolynomial->terms.getCount())
- return false;
- for (Index i = 0; i < terms.getCount(); i++)
- {
- auto& thisTerm = *(terms[i]);
- auto& thatTerm = *(otherPolynomial->terms[i]);
- if (thisTerm.constFactor != thatTerm.constFactor)
- return false;
- if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount())
- return false;
- for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++)
- {
- if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power)
- return false;
- if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param))
- return false;
- }
- }
- return true;
- }
- return false;
-}
-
void PolynomialIntVal::_toTextOverride(StringBuilder& out)
{
+ auto constantTerm = getConstantTerm();
+ auto terms = getTerms();
for (Index i = 0; i < terms.getCount(); i++)
{
auto& term = *(terms[i]);
- if (term.constFactor > 0)
+ if (term.getConstFactor() > 0)
{
if (i > 0)
out << "+";
@@ -784,14 +660,14 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
else
out << "-";
bool isFirstFactor = true;
- if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0)
+ if (abs(term.getConstFactor()) != 1 || term.getParamFactors().getCount() == 0)
{
- out << abs(term.constFactor);
+ out << abs(term.getConstFactor());
isFirstFactor = false;
}
- for (Index j = 0; j < term.paramFactors.getCount(); j++)
+ for (Index j = 0; j < term.getParamFactors().getCount(); j++)
{
- auto factor = term.paramFactors[j];
+ auto factor = term.getParamFactors()[j];
if (isFirstFactor)
{
isFirstFactor = false;
@@ -800,10 +676,10 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
{
out << "*";
}
- factor->param->toText(out);
- if (factor->power != 1)
+ factor->getParam()->toText(out);
+ if (factor->getPower() != 1)
{
- out << "^^" << factor->power;
+ out << "^^" << factor->getPower();
}
}
}
@@ -821,227 +697,304 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
}
}
-HashCode PolynomialIntVal::_getHashCodeOverride()
+struct PolynomialIntValBuilder
{
- HashCode result = (HashCode)constantTerm;
- for (auto& term : terms)
+ ASTBuilder* astBuilder;
+
+ IntegerLiteralValue constantTerm = 0;
+ List<PolynomialIntValTerm*> terms;
+
+ PolynomialIntValBuilder(ASTBuilder* inAstBuilder)
+ : astBuilder(inAstBuilder)
+ {}
+
+ // compute val += opreand*multiplier;
+ bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier)
{
- if (!term) continue;
- result = combineHash(result, (HashCode)term->constFactor);
- for (auto& factor : term->paramFactors)
+ if (auto c = as<ConstantIntVal>(operand))
{
- result = combineHash(result, factor->param->getHashCode());
- result = combineHash(result, (HashCode)factor->power);
+ constantTerm += c->getValue() * multiplier;
+ return true;
}
+ else if (auto poly = as<PolynomialIntVal>(operand))
+ {
+ constantTerm += poly->getConstantTerm() * multiplier;
+ for (auto term : poly->getTerms())
+ {
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ multiplier * term->getConstFactor(), term->getParamFactors());
+ terms.add(newTerm);
+ }
+ return true;
+ }
+ else if (auto genVal = as<IntVal>(operand))
+ {
+ auto factor = astBuilder->getOrCreate<PolynomialIntValFactor>(genVal, 1);
+ auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(multiplier, makeArrayViewSingle(factor));
+ terms.add(term);
+ return true;
+ }
+ return false;
}
- return result;
-}
+
+ IntVal* canonicalize(Type* type)
+ {
+ List<PolynomialIntValTerm*> newTerms;
+ IntegerLiteralValue newConstantTerm = constantTerm;
+ auto addTerm = [&](PolynomialIntValTerm* newTerm)
+ {
+ for (auto& term : newTerms)
+ {
+ if (term->canCombineWith(*newTerm))
+ {
+ term = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ term->getConstFactor() + newTerm->getConstFactor(),
+ term->getParamFactors());
+ return;
+ }
+ }
+ newTerms.add(newTerm);
+ };
+ for (auto term : terms)
+ {
+ if (term->getConstFactor() == 0)
+ continue;
+ List<PolynomialIntValFactor*> newFactors;
+ List<bool> factorIsDifferent;
+ for (Index i = 0; i < term->getParamFactors().getCount(); i++)
+ {
+ auto factor = term->getParamFactors()[i];
+ bool factorFound = false;
+ for (Index j = 0; j < newFactors.getCount(); j++)
+ {
+ auto& newFactor = newFactors[j];
+ if (factor->getParam()->equals(newFactor->getParam()))
+ {
+ if (!factorIsDifferent[j])
+ {
+ factorIsDifferent[j] = true;
+ auto clonedFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower());
+ newFactor = clonedFactor;
+ }
+ newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(newFactor->getParam(), newFactor->getPower() + factor->getPower());
+ factorFound = true;
+ break;
+ }
+ }
+ if (!factorFound)
+ {
+ newFactors.add(factor);
+ factorIsDifferent.add(false);
+ }
+ }
+ List<PolynomialIntValFactor*> newFactors2;
+ // Remove zero-powered factors.
+ for (auto factor : newFactors)
+ {
+ if (factor->getPower() != 0)
+ newFactors2.add(factor);
+ }
+ if (newFactors2.getCount() == 0)
+ {
+ newConstantTerm += term->getConstFactor();
+ continue;
+ }
+ newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; });
+ bool isDifferent = false;
+ if (newFactors2.getCount() != term->getParamFactors().getCount())
+ isDifferent = true;
+ if (!isDifferent)
+ {
+ for (Index i = 0; i < term->getParamFactors().getCount(); i++)
+ if (term->getParamFactors()[i] != newFactors2[i])
+ {
+ isDifferent = true;
+ break;
+ }
+ }
+ if (!isDifferent)
+ {
+ addTerm(term);
+ }
+ else
+ {
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor(), newFactors2.getArrayView());
+ addTerm(newTerm);
+ }
+ }
+ List<PolynomialIntValTerm*> newTerms2;
+ for (auto term : newTerms)
+ {
+ if (term->getConstFactor() == 0)
+ continue;
+ newTerms2.add(term);
+ }
+ newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
+ terms = _Move(newTerms2);
+ constantTerm = newConstantTerm;
+ if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 &&
+ terms[0]->getParamFactors()[0]->getPower() == 1)
+ {
+ return terms[0]->getParamFactors()[0]->getParam();
+ }
+ if (terms.getCount() == 0)
+ return astBuilder->getIntVal(type, constantTerm);
+ return nullptr;
+ }
+
+ IntVal* getIntVal(Type* type)
+ {
+ if (auto canVal = canonicalize(type))
+ return canVal;
+ return astBuilder->getOrCreate<PolynomialIntVal>(type, constantTerm, terms.getArrayView());
+ }
+};
Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- IntegerLiteralValue evaluatedConstantTerm = constantTerm;
- List<PolynomialIntValTerm*> evaluatedTerms;
- for (auto& term : terms)
+ PolynomialIntValBuilder builder(astBuilder);
+ for (auto& term : getTerms())
{
IntegerLiteralValue evaluatedTermConstFactor;
List<PolynomialIntValFactor*> evaluatedTermParamFactors;
- evaluatedTermConstFactor = term->constFactor;
- for (auto& factor : term->paramFactors)
+ evaluatedTermConstFactor = term->getConstFactor();
+ for (auto& factor : term->getParamFactors())
{
- auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff);
+ auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff);
if (auto constantVal = as<ConstantIntVal>(substResult))
{
- evaluatedTermConstFactor *= constantVal->value;
+ evaluatedTermConstFactor *= constantVal->getValue();
}
else if (auto intResult = as<IntVal>(substResult))
{
- auto newFactor = astBuilder->create<PolynomialIntValFactor>();
- newFactor->param = intResult;
- newFactor->power = factor->power;
+ auto newFactor = astBuilder->getOrCreate<PolynomialIntValFactor>(intResult, factor->getPower());
evaluatedTermParamFactors.add(newFactor);
}
}
if (evaluatedTermParamFactors.getCount() == 0)
{
- evaluatedConstantTerm += evaluatedTermConstFactor;
+ builder.constantTerm += evaluatedTermConstFactor;
}
else
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->paramFactors = _Move(evaluatedTermParamFactors);
- newTerm->constFactor = evaluatedTermConstFactor;
- evaluatedTerms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView());
+ builder.terms.add(newTerm);
}
}
*ioDiff += diff;
- if (evaluatedTerms.getCount() == 0)
- return astBuilder->getIntVal(type, evaluatedConstantTerm);
+ if (builder.terms.getCount() == 0)
+ return astBuilder->getIntVal(getType(), builder.constantTerm);
if (diff != 0)
{
- auto newPolynomial = astBuilder->create<PolynomialIntVal>(type);
- newPolynomial->constantTerm = evaluatedConstantTerm;
- newPolynomial->terms = _Move(evaluatedTerms);
- return newPolynomial->canonicalize(astBuilder);
+ return builder.getIntVal(getType());
}
return this;
}
-
-// compute val += opreand*multiplier;
-bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier)
-{
- if (auto c = as<ConstantIntVal>(operand))
- {
- val->constantTerm += c->value * multiplier;
- return true;
- }
- else if (auto poly = as<PolynomialIntVal>(operand))
- {
- val->constantTerm += poly->constantTerm * multiplier;
- for (auto term : poly->terms)
- {
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = multiplier * term->constFactor;
- newTerm->paramFactors = term->paramFactors;
- val->terms.add(newTerm);
- }
- return true;
- }
- else if (auto genVal = as<IntVal>(operand))
- {
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = multiplier;
- auto factor = astBuilder->create<PolynomialIntValFactor>();
- factor->power = 1;
- factor->param = genVal;
- term->paramFactors.add(factor);
- val->terms.add(term);
- return true;
- }
- return false;
-}
-
-PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
+IntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
{
- auto result = astBuilder->create<PolynomialIntVal>(base->type);
- if (!addToPolynomialTerm(astBuilder, result, base, -1))
- return nullptr;
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.addToPolynomialTerm(base, -1);
+ return builder.getIntVal(base->getType());
}
-PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+IntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
- auto result = astBuilder->create<PolynomialIntVal>(op0->type);
- if (!addToPolynomialTerm(astBuilder, result, op0, 1))
- return nullptr;
- if (!addToPolynomialTerm(astBuilder, result, op1, -1))
- return nullptr;
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.addToPolynomialTerm(op0, 1);
+ builder.addToPolynomialTerm(op1, -1);
+ return builder.getIntVal(op0->getType());
}
-PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+IntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
- auto result = astBuilder->create<PolynomialIntVal>(op0->type);
- if (!addToPolynomialTerm(astBuilder, result, op0, 1))
- return nullptr;
- if (!addToPolynomialTerm(astBuilder, result, op1, 1))
- return nullptr;
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.addToPolynomialTerm(op0, 1);
+ builder.addToPolynomialTerm(op1, 1);
+ return builder.getIntVal(op0->getType());
}
-PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
if (auto poly0 = as<PolynomialIntVal>(op0))
{
if (auto poly1 = as<PolynomialIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
+ PolynomialIntValBuilder builder(astBuilder);
// add poly0.constant * poly1.constant
- result->constantTerm = poly0->constantTerm * poly1->constantTerm;
+ builder.constantTerm = poly0->getConstantTerm() * poly1->getConstantTerm();
// add poly0.constant * poly1.terms
- if (poly0->constantTerm != 0)
+ if (poly0->getConstantTerm() != 0)
{
- for (auto term : poly1->terms)
+ for (auto term : poly1->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = poly0->constantTerm * term->constFactor;
- newTerm->paramFactors.addRange(term->paramFactors);
- result->terms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors());
+ builder.terms.add(newTerm);
}
}
// add poly1.constant * poly0.terms
- if (poly1->constantTerm != 0)
+ if (poly1->getConstantTerm() != 0)
{
- for (auto term : poly0->terms)
+ for (auto term : poly0->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = poly1->constantTerm * term->constFactor;
- newTerm->paramFactors.addRange(term->paramFactors);
- result->terms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ poly1->getConstantTerm() * term->getConstFactor(),
+ term->getParamFactors());
+ builder.terms.add(newTerm);
}
}
// add poly1.terms * poly0.terms
- for (auto term0 : poly0->terms)
+ for (auto term0 : poly0->getTerms())
{
- for (auto term1 : poly1->terms)
+ for (auto term1 : poly1->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term0->constFactor * term1->constFactor;
- newTerm->paramFactors.addRange(term0->paramFactors);
- newTerm->paramFactors.addRange(term1->paramFactors);
- result->terms.add(newTerm);
+ List<PolynomialIntValFactor*> newFactors;
+ for (auto f : term0->getParamFactors()) newFactors.add(f);
+ for (auto f : term1->getParamFactors()) newFactors.add(f);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView());
+ builder.terms.add(newTerm);
}
}
- result->canonicalize(astBuilder);
- return result;
+ return builder.getIntVal(op0->getType());
}
else if (auto cVal1 = as<ConstantIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
- result->constantTerm = poly0->constantTerm * cVal1->value;
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- for (auto term : poly0->terms)
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue();
+ for (auto term : poly0->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor * cVal1->value;
- newTerm->paramFactors.addRange(term->paramFactors);
- newTerm->paramFactors.add(factor1);
- result->terms.add(newTerm);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(term->getConstFactor() * cVal1->getValue(), term->getParamFactors());
+ builder.terms.add(newTerm);
}
- result->canonicalize(astBuilder);
- return result;
+ return builder.getIntVal(poly0->getType());
}
else if (auto val1 = as<IntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
- result->constantTerm = 0;
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- if (poly0->constantTerm != 0)
+ PolynomialIntValBuilder builder(astBuilder);
+ auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1);
+ if (poly0->getConstantTerm() != 0)
{
- auto term0 = astBuilder->create<PolynomialIntValTerm>();
- term0->constFactor = poly0->constantTerm;
- term0->paramFactors.add(factor1);
- result->terms.add(term0);
+ auto term0 = astBuilder->getOrCreate<PolynomialIntValTerm>(poly0->getConstantTerm(), makeArrayViewSingle(factor1));
+ builder.terms.add(term0);
}
- for (auto term : poly0->terms)
+ for (auto term : poly0->getTerms())
{
- auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor;
- newTerm->paramFactors.addRange(term->paramFactors);
- newTerm->paramFactors.add(factor1);
- result->terms.add(newTerm);
+ List<PolynomialIntValFactor*> newFactors;
+ for (auto f: term->getParamFactors())
+ newFactors.add(f);
+ newFactors.add(factor1);
+ auto newTerm = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ term->getConstFactor(), newFactors.getArrayView());
+ builder.terms.add(newTerm);
}
- result->canonicalize(astBuilder);
- return result;
+ return builder.getIntVal(poly0->getType());
}
else
return nullptr;
@@ -1058,184 +1011,48 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
}
else if (auto cVal1 = as<ConstantIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(val0->type);
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = cVal1->value;
- auto factor0 = astBuilder->create<PolynomialIntValFactor>();
- factor0->power = 1;
- factor0->param = val0;
- term->paramFactors.add(factor0);
- result->terms.add(term);
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1);
+ auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(
+ cVal1->getValue(), makeArrayView(&factor0, 1));
+ builder.terms.add(term);
+ return builder.getIntVal(val0->getType());
}
else if (auto val1 = as<IntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>(val0->type);
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = 1;
- auto factor0 = astBuilder->create<PolynomialIntValFactor>();
- factor0->power = 1;
- factor0->param = val0;
- term->paramFactors.add(factor0);
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- term->paramFactors.add(factor1);
- result->terms.add(term);
- result->canonicalize(astBuilder);
- return result;
+ PolynomialIntValBuilder builder(astBuilder);
+ auto factor0 = astBuilder->getOrCreate<PolynomialIntValFactor>(val0, 1);
+ auto factor1 = astBuilder->getOrCreate<PolynomialIntValFactor>(val1, 1);
+ PolynomialIntValFactor* newFactors[] = { factor0, factor1 };
+ auto term = astBuilder->getOrCreate<PolynomialIntValTerm>(1, makeArrayView(newFactors));
+ builder.terms.add(term);
+ return builder.getIntVal(val0->getType());
}
}
return nullptr;
}
-IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
-{
- List<PolynomialIntValTerm*> newTerms;
- IntegerLiteralValue newConstantTerm = constantTerm;
- auto addTerm = [&](PolynomialIntValTerm* newTerm)
- {
- for (auto term : newTerms)
- {
- if (term->canCombineWith(*newTerm))
- {
- term->constFactor += newTerm->constFactor;
- return;
- }
- }
- newTerms.add(newTerm);
- };
- for (auto term : terms)
- {
- if (term->constFactor == 0)
- continue;
- List<PolynomialIntValFactor*> newFactors;
- List<bool> factorIsDifferent;
- for (Index i = 0; i < term->paramFactors.getCount(); i++)
- {
- auto factor = term->paramFactors[i];
- bool factorFound = false;
- for (Index j = 0; j < newFactors.getCount(); j++)
- {
- auto& newFactor = newFactors[j];
- if (factor->param->equalsVal(newFactor->param))
- {
- if (!factorIsDifferent[j])
- {
- factorIsDifferent[j] = true;
- auto clonedFactor = builder->create<PolynomialIntValFactor>();
- clonedFactor->param = newFactor->param;
- clonedFactor->power = newFactor->power;
- newFactor = clonedFactor;
- }
- newFactor->power += factor->power;
- factorFound = true;
- break;
- }
- }
- if (!factorFound)
- {
- newFactors.add(factor);
- factorIsDifferent.add(false);
- }
- }
- List<PolynomialIntValFactor*> newFactors2;
- for (auto factor : newFactors)
- {
- if (factor->power != 0)
- newFactors2.add(factor);
- }
- if (newFactors2.getCount() == 0)
- {
- newConstantTerm += term->constFactor;
- continue;
- }
- newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; });
- bool isDifferent = false;
- if (newFactors2.getCount() != term->paramFactors.getCount())
- isDifferent = true;
- if (!isDifferent)
- {
- for (Index i = 0; i < term->paramFactors.getCount(); i++)
- if (term->paramFactors[i] != newFactors2[i])
- {
- isDifferent = true;
- break;
- }
- }
- if (!isDifferent)
- {
- addTerm(term);
- }
- else
- {
- auto newTerm = builder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor;
- newTerm->paramFactors = _Move(newFactors2);
- addTerm(newTerm);
- }
- }
- List<PolynomialIntValTerm*> newTerms2;
- for (auto term : newTerms)
- {
- if (term->constFactor == 0)
- continue;
- newTerms2.add(term);
- }
- newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
- terms = _Move(newTerms2);
- constantTerm = newConstantTerm;
- if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 &&
- terms[0]->paramFactors[0]->power == 1)
- {
- return terms[0]->paramFactors[0]->param;
- }
- if (terms.getCount() == 0)
- return builder->getIntVal(type, constantTerm);
- return this;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool TypeCastIntVal::_equalsValOverride(Val* val)
-{
- if (auto typeCastIntVal = as<TypeCastIntVal>(val))
- {
- if (!type->equals(typeCastIntVal->type))
- return false;
- if (!base->equalsVal(typeCastIntVal->base))
- return false;
- return true;
- }
- return false;
-}
void TypeCastIntVal::_toTextOverride(StringBuilder& out)
{
- type->toText(out);
+ getType()->toText(out);
out << "(";
- base->toText(out);
+ getBase()->toText(out);
out << ")";
}
-HashCode TypeCastIntVal::_getHashCodeOverride()
-{
- HashCode result = type->getHashCode();
- result = combineHash(result, base->getHashCode());
- return result;
-}
-
Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
if (auto c = as<ConstantIntVal>(base))
{
- IntegerLiteralValue resultValue = c->value;
+ IntegerLiteralValue resultValue = c->getValue();
auto baseType = as<BasicExpressionType>(resultType);
if (baseType)
{
- switch (baseType->baseType)
+ switch (baseType->getBaseType())
{
case BaseType::Int:
resultValue = (int)resultValue;
@@ -1275,11 +1092,11 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val*
Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto substBase = base->substituteImpl(astBuilder, subst, &diff);
- if (substBase != base)
+ auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff);
+ if (substBase != getBase())
diff++;
- auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff));
- if (substType != type)
+ auto substType = as<Type>(getType()->substituteImpl(astBuilder, subst, &diff));
+ if (substType != getType())
diff++;
*ioDiff += diff;
if (diff)
@@ -1289,7 +1106,7 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
return newVal;
else
{
- auto result = astBuilder->create<TypeCastIntVal>(substType, substBase);
+ auto result = astBuilder->getOrCreate<TypeCastIntVal>(substType, substBase);
return result;
}
}
@@ -1297,29 +1114,20 @@ Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
return this;
}
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-bool FuncCallIntVal::_equalsValOverride(Val* val)
+Val* TypeCastIntVal::_resolveImplOverride()
{
- if (auto funcCallIntVal = as<FuncCallIntVal>(val))
- {
- if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef))
- return false;
- if (args.getCount() != funcCallIntVal->args.getCount())
- return false;
- for (Index i = 0; i < args.getCount(); i++)
- {
- if (!args[i]->equalsVal(funcCallIntVal->args[i]))
- return false;
- }
- return true;
- }
- return false;
+ if (auto resolved = tryFoldImpl(getCurrentASTBuilder(), getType(), getBase(), nullptr))
+ return resolved;
+ return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
void FuncCallIntVal::_toTextOverride(StringBuilder& out)
{
+ auto args = getArgs();
+ auto funcDeclRef = getFuncDeclRef();
+
auto argToText = [&](int index)
{
if (as<PolynomialIntVal>(args[index]) || as<FuncCallIntVal>(args[index]))
@@ -1369,14 +1177,37 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out)
}
}
-HashCode FuncCallIntVal::_getHashCodeOverride()
+Val* FuncCallIntVal::_resolveImplOverride()
{
- HashCode result = funcDeclRef.getHashCode();
+ auto astBuilder = getCurrentASTBuilder();
+ auto args = getArgs();
+ auto funcDeclRef = getFuncDeclRef();
+ auto funcType = getFuncType();
+
+ Val* resolvedVal = this;
+
+ auto newFuncDeclRef = as<DeclRefBase>(funcDeclRef.declRefBase->resolve());
+ if (!newFuncDeclRef)
+ return this;
+ bool diff = false;
+ List<IntVal*> newArgs;
for (auto arg : args)
{
- result = combineHash(result, arg->getHashCode());
+ auto newArg = as<IntVal>(arg->resolve());
+ if (!newArg)
+ return this;
+ newArgs.add(newArg);
+ if (newArg != arg)
+ diff = true;
}
- return result;
+
+ if (auto resolved = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr))
+ resolvedVal = resolved;
+ else if (diff)
+ {
+ resolvedVal = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, funcType, newArgs.getArrayView());
+ }
+ return resolvedVal;
}
Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
@@ -1413,25 +1244,25 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
#define BINARY_OPERATOR_CASE(op) \
if (opNameSlice == toSlice(#op)) \
{ \
- resultValue = constArgs[0]->value op constArgs[1]->value; \
+ resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \
} else
#define DIV_OPERATOR_CASE(op) \
if (opNameSlice == toSlice(#op)) \
{ \
- if (constArgs[1]->value == 0) \
+ if (constArgs[1]->getValue() == 0) \
{ \
if (sink) \
sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \
return nullptr; \
} \
- resultValue = constArgs[0]->value op constArgs[1]->value; \
+ resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \
} else
#define LOGICAL_OPERATOR_CASE(op) \
if (opNameSlice == toSlice(#op)) \
{ \
- resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \
+ resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \
} else
@@ -1463,9 +1294,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
LOGICAL_OPERATOR_CASE(&&)
LOGICAL_OPERATOR_CASE(||)
// Special cases need their "operator" names quoted.
- SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);)
- SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;)
- SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;)
+ SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);)
+ SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();)
+ SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();)
TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");)
return astBuilder->getIntVal(resultType, resultValue);
@@ -1483,9 +1314,9 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff);
+ auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff);
List<IntVal*> newArgs;
- for (auto& arg : args)
+ for (auto& arg : getArgs())
{
auto substArg = arg->substituteImpl(astBuilder, subst, &diff);
if (substArg != arg)
@@ -1496,15 +1327,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
if (diff)
{
// TODO: report diagnostics back.
- auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr);
+ auto newVal = tryFoldImpl(astBuilder, getType(), newFuncDeclRef, newArgs, nullptr);
if (newVal)
return newVal;
else
{
- auto result = astBuilder->create<FuncCallIntVal>(type);
- result->args = _Move(newArgs);
- result->funcDeclRef = newFuncDeclRef;
- result->funcType = funcType;
+ auto result = astBuilder->getOrCreate<FuncCallIntVal>(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView());
return result;
}
}
@@ -1514,40 +1342,47 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool WitnessLookupIntVal::_equalsValOverride(Val* val)
-{
- if (auto lookupIntVal = as<WitnessLookupIntVal>(val))
- {
- if (!witness->equalsVal(lookupIntVal->witness))
- return false;
- if (key != lookupIntVal->key)
- return false;
- return true;
- }
- return false;
-}
-
void WitnessLookupIntVal::_toTextOverride(StringBuilder& out)
{
- witness->sub->toText(out);
+ getWitness()->getSub()->toText(out);
out << ".";
- out << (key->getName() ? key->getName()->text : "??");
+ out << (getKey()->getName() ? getKey()->getName()->text : "??");
}
-HashCode WitnessLookupIntVal::_getHashCodeOverride()
+Val* WitnessLookupIntVal::_resolveImplOverride()
{
- HashCode result = witness->getHashCode();
- result = combineHash(result, Slang::getHashCode(key));
- return result;
+ auto astBuilder = getCurrentASTBuilder();
+
+ auto newWitness = as<SubtypeWitness>(getWitness()->resolve());
+ if (!newWitness)
+ return this;
+
+ auto witnessVal = tryLookUpRequirementWitness(astBuilder, newWitness, getKey());
+ if (witnessVal.getFlavor() == RequirementWitness::Flavor::val)
+ {
+ return witnessVal.getVal();
+ }
+
+ auto newType = as<Type>(getType()->resolve());
+ if (!newType)
+ return this;
+
+ if (newWitness != getWitness() || newType != getType())
+ {
+ return astBuilder->getOrCreate<WitnessLookupIntVal>(newType, newWitness, getKey());
+ }
+
+ return this;
}
+
Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto newWitness = witness->substituteImpl(astBuilder, subst, &diff);
+ auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff);
*ioDiff += diff;
if (diff)
{
- auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), key);
+ auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), getKey());
if (witnessEntry)
return witnessEntry;
}
@@ -1573,51 +1408,93 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes
{
if (auto result = tryFoldOrNull(astBuilder, witness, key))
return result;
- auto witnessResult = astBuilder->create<WitnessLookupIntVal>();
- witnessResult->witness = witness;
- witnessResult->key = key;
- witnessResult->type = type;
+ auto witnessResult = astBuilder->getOrCreate<WitnessLookupIntVal>(type, witness, key);
return witnessResult;
}
-
-bool DifferentiateVal::_equalsValOverride(Val* val)
-{
- if (auto other = as<DifferentiateVal>(val))
- {
- return other->astNodeType == astNodeType && other->func == func;
- }
- return false;
-}
-
void DifferentiateVal::_toTextOverride(StringBuilder& out)
{
out << "DifferentiateVal(";
- out << func;
+ out << getFunc();
out << ")";
}
-HashCode DifferentiateVal::_getHashCodeOverride()
-{
- HashCode result = (HashCode)astNodeType;
- result = combineHash(result, func.getHashCode());
- return result;
-}
-
Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
- auto newFunc = func.substituteImpl(astBuilder, subst, &diff);
+ auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff);
*ioDiff += diff;
if (diff)
{
auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType));
- result->func = newFunc;
+ result->getFunc() = newFunc;
return result;
}
// Nothing found: don't substitute.
return this;
}
+Val* DifferentiateVal::_resolveImplOverride()
+{
+ return this;
+}
+
+Val* PolynomialIntValFactor::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ auto newParam = as<IntVal>(getParam()->resolve());
+ if (newParam && newParam != getParam())
+ return astBuilder->getOrCreate<PolynomialIntValFactor>(newParam, getPower());
+
+ return this;
+}
+
+Val* PolynomialIntValTerm::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ bool diff = false;
+ List<PolynomialIntValFactor*> newFactors;
+ for (auto factor : getParamFactors())
+ {
+ auto newFactor = as<PolynomialIntValFactor>(factor->resolve());
+ if (!newFactor)
+ return this;
+
+ if (newFactor != factor)
+ diff = true;
+ newFactors.add(newFactor);
+ }
+
+ if (diff)
+ return astBuilder->getOrCreate<PolynomialIntValTerm>(getConstFactor(), newFactors.getArrayView());
+
+ return this;
+}
+
+Val* PolynomialIntVal::_resolveImplOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ bool diff = false;
+ PolynomialIntValBuilder builder(astBuilder);
+ builder.constantTerm = getConstantTerm();
+ for (auto term : getTerms())
+ {
+ auto newTerm = as<PolynomialIntValTerm>(term->resolve());
+ if (!newTerm)
+ return this;
+
+ if (newTerm != term)
+ diff = true;
+ builder.terms.add(newTerm);
+ }
+
+ if (diff)
+ return builder.getIntVal(getType());
+
+ return this;
+}
} // namespace Slang
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index cb4e94ebb..c45c42e02 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -8,17 +8,139 @@ namespace Slang {
// Syntax class definitions for compile-time values.
+class DirectDeclRef : public DeclRefBase
+{
+public:
+ SLANG_AST_CLASS(DirectDeclRef)
+
+ DirectDeclRef(Decl* decl)
+ {
+ setOperands(decl);
+ }
+
+ DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ void _toTextOverride(StringBuilder& out);
+ Val* _resolveImplOverride();
+ DeclRefBase* _getBaseOverride();
+};
+
+// Represent an static member of a base decl.
+// Note that we automatically fold the DeclRef if the path is known to be static.
+// For example, MemberDeclRef(DirectDeclRef(A), B) ==> DirectDeclRef(B),
+// and MemberDeclRef(MemberDeclRef(A, B), C) ==> MemberDeclRef(A, C).
+//
+class MemberDeclRef : public DeclRefBase
+{
+public:
+ SLANG_AST_CLASS(MemberDeclRef);
+
+ DeclRefBase* getParentOperand() { return as<DeclRefBase>(getOperand(1)); }
+
+ MemberDeclRef(Decl* decl, DeclRefBase* parent)
+ {
+ setOperands(decl, parent);
+ }
+
+ DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ void _toTextOverride(StringBuilder& out);
+
+ Val* _resolveImplOverride();
+
+ DeclRefBase* _getBaseOverride();
+};
+
+
+// Represent a lookup of SuperType::`m_decl` from `lookupSourceType` type that we know conforms to SuperType.
+class LookupDeclRef : public DeclRefBase
+{
+public:
+ SLANG_AST_CLASS(LookupDeclRef);
+
+ // m_decl represents the decl in SuperType that we want to lookup.
+
+ // The source type that we are looking up from.
+ Type* getLookupSource()
+ {
+ return as<Type>(getOperand(1));
+ }
+
+ // Witness that `lookupSourceType`:SuperType.
+ SubtypeWitness* getWitness()
+ {
+ return as<SubtypeWitness>(getOperand(2));
+ }
+
+ LookupDeclRef(Decl* declToLookup, Type* lookupSource, SubtypeWitness* witness)
+ {
+ setOperands(declToLookup, lookupSource, witness);
+ }
+
+ Decl* getSupDecl();
+
+ DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ void _toTextOverride(StringBuilder& out);
+
+ Val* _resolveImplOverride();
+
+ DeclRefBase* _getBaseOverride();
+
+private:
+ Val* tryResolve(SubtypeWitness* newWitness, Type* newLookupSource);
+};
+
+// Represents a specialization of a generic decl.
+class GenericAppDeclRef : public DeclRefBase
+{
+public:
+ SLANG_AST_CLASS(GenericAppDeclRef);
+
+ DeclRefBase* getGenericDeclRef() { return as<DeclRefBase>(getOperand(1)); }
+ Index getArgCount() { return getOperandCount() - 2; }
+ Val* getArg(Index index) { return getOperand(index + 2); }
+
+ GenericAppDeclRef(Decl* innerDecl, DeclRefBase* genericDeclRef, OperandView<Val> args)
+ {
+ m_operands.add(ValNodeOperand(innerDecl));
+ m_operands.add(ValNodeOperand(genericDeclRef));
+ for (auto arg : args)
+ {
+ m_operands.add(ValNodeOperand(arg));
+ }
+ }
+
+ GenericAppDeclRef(Decl* innerDecl, DeclRefBase* genericDeclRef, ConstArrayView<Val*> args)
+ {
+ m_operands.add(ValNodeOperand(innerDecl));
+ m_operands.add(ValNodeOperand(genericDeclRef));
+ for (auto arg : args)
+ {
+ m_operands.add(ValNodeOperand(arg));
+ }
+ }
+
+ GenericDecl* getGenericDecl();
+
+ OperandView<Val> getArgs() { return OperandView<Val>(this, 2, getArgCount()); }
+
+ DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ void _toTextOverride(StringBuilder& out);
+
+ Val* _resolveImplOverride();
+
+ DeclRefBase* _getBaseOverride();
+};
+
// A compile-time integer (may not have a specific concrete value)
class IntVal : public Val
{
SLANG_ABSTRACT_AST_CLASS(IntVal)
- Type* type;
-
- IntVal(Type* inType)
- : type(inType)
- {}
+ Type* getType() { return as<Type>(getOperand(0)); }
+ Val* _resolveImplOverride() { return this; }
};
// Trivial case of a value that is just a constant integer
@@ -26,18 +148,15 @@ class ConstantIntVal : public IntVal
{
SLANG_AST_CLASS(ConstantIntVal)
- IntegerLiteralValue value;
+ IntegerLiteralValue getValue() { return getIntConstOperand(1); }
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
-protected:
ConstantIntVal(Type* inType, IntegerLiteralValue inValue)
- : IntVal(inType), value(inValue)
- {}
-
+ {
+ setOperands(inType, inValue);
+ }
};
// The logical "value" of a reference to a generic value parameter
@@ -45,30 +164,31 @@ class GenericParamIntVal : public IntVal
{
SLANG_AST_CLASS(GenericParamIntVal)
- DeclRef<VarDeclBase> declRef;
+ DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); }
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
- : IntVal(inType), declRef(inDeclRef)
- {}
+ {
+ setOperands(inType, inDeclRef);
+ }
};
class TypeCastIntVal : public IntVal
{
SLANG_AST_CLASS(TypeCastIntVal)
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
- Val* base;
- TypeCastIntVal(Type* inType, Val* inBase) : IntVal(inType), base(inBase) {}
+ Val* getBase() { return getOperand(1); }
+ TypeCastIntVal(Type* inType, Val* inBase)
+ {
+ setOperands(inType, inBase);
+ }
static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink);
};
@@ -78,16 +198,21 @@ class FuncCallIntVal : public IntVal
{
SLANG_AST_CLASS(FuncCallIntVal)
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
- DeclRef<Decl> funcDeclRef;
- Type* funcType;
- List<IntVal*> args;
+ DeclRef<Decl> getFuncDeclRef() { return as<DeclRefBase>(getOperand(1)); }
+ Type* getFuncType() { return as<Type>(getOperand(2)); }
+ OperandView<IntVal> getArgs() { return OperandView<IntVal>(this, 3, getOperandCount() - 3); }
+ Index getArgCount() { return getOperandCount() - 3; }
- FuncCallIntVal(Type* inType) : IntVal(inType) {}
+ FuncCallIntVal(Type* inType, DeclRef<Decl> inFuncDeclRef, Type* inFuncType, ArrayView<IntVal*> inArgs)
+ {
+ setOperands(inType, inFuncDeclRef, inFuncType);
+ for (auto arg : inArgs)
+ m_operands.add(ValNodeOperand(arg));
+ }
static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink);
};
@@ -96,15 +221,17 @@ class WitnessLookupIntVal : public IntVal
{
SLANG_AST_CLASS(WitnessLookupIntVal)
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
- SubtypeWitness* witness;
- Decl* key;
+ SubtypeWitness* getWitness() { return as<SubtypeWitness>(getOperand(1)); }
+ Decl* getKey() { return as<Decl>(getDeclOperand(2)); }
- WitnessLookupIntVal(Type* inType) : IntVal(inType) {}
+ WitnessLookupIntVal(Type* inType, SubtypeWitness* witness, Decl* key)
+ {
+ setOperands(inType, witness, key);
+ }
static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key);
@@ -113,23 +240,31 @@ class WitnessLookupIntVal : public IntVal
// polynomial expression "2*a*b^3 + 1" will be represented as:
// { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] }
-class PolynomialIntValFactor : public NodeBase
+class PolynomialIntValFactor : public Val
{
SLANG_AST_CLASS(PolynomialIntValFactor)
public:
- IntVal* param;
- IntegerLiteralValue power;
+ IntVal* getParam() const { return as<IntVal>(getOperand(0)); }
+ IntegerLiteralValue getPower() const { return getIntConstOperand(1); }
+
+ PolynomialIntValFactor(IntVal* inParam, IntegerLiteralValue inPower)
+ {
+ setOperands(inParam, inPower);
+ }
+
+ Val* _resolveImplOverride();
+
// for sorting only.
bool operator<(const PolynomialIntValFactor& other) const
{
- if (auto thisGenParam = as<GenericParamIntVal>(param))
+ if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
{
- if (auto thatGenParam = as<GenericParamIntVal>(other.param))
+ if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
{
- if (thisGenParam->equalsVal(thatGenParam))
- return power < other.power;
+ if (thisGenParam->equals(thatGenParam))
+ return getPower() < other.getPower();
else
- return thisGenParam->declRef.getDecl() < thatGenParam->declRef.getDecl();
+ return thisGenParam->getDeclRef().getDecl() < thatGenParam->getDeclRef().getDecl();
}
else
{
@@ -138,64 +273,84 @@ public:
}
else
{
- if (const auto thatGenParam = as<GenericParamIntVal>(other.param))
+ if (const auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
{
return false;
}
- return param == other.param ? power < other.power : param < other.param;
+ return getParam() == other.getParam() ? getPower() < other.getPower() : getParam() < other.getParam();
}
}
// for sorting only.
bool operator==(const PolynomialIntValFactor& other) const
{
- if (auto thisGenParam = as<GenericParamIntVal>(param))
+ if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
{
- if (auto thatGenParam = as<GenericParamIntVal>(other.param))
+ if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
{
- if (thisGenParam->equalsVal(thatGenParam) && power == other.power)
+ if (thisGenParam->equals(thatGenParam) && getPower() == other.getPower())
return true;
}
return false;
}
- return power == other.power && param == other.param;
+ return getPower() == other.getPower() && getParam() == other.getParam();
}
bool equals(const PolynomialIntValFactor& other) const
{
- return power == other.power && param->equalsVal(other.param);
+ return getPower() == other.getPower() && getParam()->equals(other.getParam());
}
};
-class PolynomialIntValTerm : public NodeBase
+class PolynomialIntValTerm : public Val
{
SLANG_AST_CLASS(PolynomialIntValTerm)
public:
- IntegerLiteralValue constFactor;
- List<PolynomialIntValFactor*> paramFactors;
+ IntegerLiteralValue getConstFactor() const { return getIntConstOperand(0); }
+ OperandView<PolynomialIntValFactor> getParamFactors() const { return OperandView<PolynomialIntValFactor>(this, 1, getOperandCount() - 1); }
+
+ Val* _resolveImplOverride();
+
+ PolynomialIntValTerm(IntegerLiteralValue inConstFactor, ArrayView<PolynomialIntValFactor*> inParamFactors)
+ {
+ setOperands(inConstFactor);
+ addOperands(inParamFactors);
+ }
+
+ PolynomialIntValTerm(IntegerLiteralValue inConstFactor, OperandView<PolynomialIntValFactor> inParamFactors)
+ {
+ setOperands(inConstFactor);
+ addOperands(inParamFactors);
+ }
+
bool canCombineWith(const PolynomialIntValTerm& other) const
{
- if (paramFactors.getCount() != other.paramFactors.getCount())
+ auto paramFactors = getParamFactors();
+ if (paramFactors.getCount() != other.getParamFactors().getCount())
return false;
- for (Index i = 0; i < paramFactors.getCount(); i++)
+ for (Index i = 0; i < getParamFactors().getCount(); i++)
{
- if (!paramFactors[i]->equals(*other.paramFactors[i]))
+ if (!paramFactors[i]->equals(*other.getParamFactors()[i]))
return false;
}
return true;
}
bool operator<(const PolynomialIntValTerm& other) const
{
- if (constFactor < other.constFactor)
+ auto constFactor = getConstFactor();
+ auto paramFactors = getParamFactors();
+
+ if (constFactor < other.getConstFactor())
return true;
- else if (constFactor == other.constFactor)
+ else if (constFactor == other.getConstFactor())
{
+ auto otherParamFactors = other.getParamFactors();
for (Index i = 0; i < paramFactors.getCount(); i++)
{
- if (i >= other.paramFactors.getCount())
+ if (i >= otherParamFactors.getCount())
return false;
- if (*(paramFactors[i]) < *(other.paramFactors[i]))
+ if (*(paramFactors[i]) < *(otherParamFactors[i]))
return true;
- if (*(paramFactors[i]) == *(other.paramFactors[i]))
+ if (*(paramFactors[i]) == *(otherParamFactors[i]))
{
}
else
@@ -213,27 +368,25 @@ class PolynomialIntVal : public IntVal
SLANG_AST_CLASS(PolynomialIntVal)
public:
- List<PolynomialIntValTerm*> terms;
- IntegerLiteralValue constantTerm = 0;
+ IntegerLiteralValue getConstantTerm() { return getIntConstOperand(1); };
+ OperandView<PolynomialIntValTerm> getTerms() { return OperandView<PolynomialIntValTerm>(this, 2, getOperandCount() - 2); };
- bool isConstant() { return terms.getCount() == 0; }
- // Canonicalize the polynomial. If the polynomial can be simplified to a constant or a genericparam,
- // the method returns the value simplified to.
- // Otherwise, in-place modifications are performed and returns this.
- IntVal* canonicalize(ASTBuilder* builder);
+ bool isConstant() { return getOperandCount() == 1; }
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
- static PolynomialIntVal* neg(ASTBuilder* astBuilder, IntVal* base);
- static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
- static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
- static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
- PolynomialIntVal(Type* inType) : IntVal(inType) {}
-
+ static IntVal* neg(ASTBuilder* astBuilder, IntVal* base);
+ static IntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
+ static IntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
+ static IntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
+ PolynomialIntVal(Type* inType, IntegerLiteralValue inConstantTerm, ArrayView<PolynomialIntValTerm*> inTerms)
+ {
+ setOperands(inType, inConstantTerm);
+ addOperands(inTerms);
+ }
};
/// An unknown integer value indicating an erroneous sub-expression
@@ -241,17 +394,16 @@ class ErrorIntVal : public IntVal
{
SLANG_AST_CLASS(ErrorIntVal)
- ErrorIntVal(Type* inType) : IntVal(inType) {}
+ ErrorIntVal(Type* inType) { setOperands(inType); }
// TODO: We should probably eventually just have an `ErrorVal` here
// and have all `Val`s that represent ordinary values hold their
// `Type` so that we can have an `ErrorVal` of any type.
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride() { return this; }
};
// A witness to the fact that some proposition is true, encoded
@@ -301,25 +453,23 @@ class SubtypeWitness : public Witness
{
SLANG_ABSTRACT_AST_CLASS(SubtypeWitness)
- Type* sub = nullptr;
- Type* sup = nullptr;
+ Val* _resolveImplOverride();
+
+ Type* getSub() { return as<Type>(getOperand(0)); }
+ Type* getSup() { return as<Type>(getOperand(1)); }
};
class TypeEqualityWitness : public SubtypeWitness
{
SLANG_AST_CLASS(TypeEqualityWitness)
- TypeEqualityWitness(
- Type* type)
+ TypeEqualityWitness(Type* subType, Type* supType)
{
- this->sub = type;
- this->sup = type;
+ setOperands(subType, supType);
}
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
@@ -329,19 +479,19 @@ class DeclaredSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(DeclaredSubtypeWitness)
- DeclRef<Decl> declRef;
+ DeclRef<Decl> getDeclRef()
+ {
+ return as<DeclRefBase>(getOperand(2));
+ }
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
DeclaredSubtypeWitness(Type* inSub, Type* inSup, DeclRef<Decl> inDeclRef)
- : declRef(inDeclRef)
{
- sub = inSub;
- sup = inSup;
+ setOperands(inSub, inSup, inDeclRef);
}
};
@@ -351,20 +501,25 @@ class TransitiveSubtypeWitness : public SubtypeWitness
SLANG_AST_CLASS(TransitiveSubtypeWitness)
// Witness that `sub : mid`
- SubtypeWitness* subToMid = nullptr;
+ SubtypeWitness* getSubToMid()
+ {
+ return as<SubtypeWitness>(getOperand(2));
+ }
// Witness that `mid : sup`
- SubtypeWitness* midToSup = nullptr;
+ SubtypeWitness* getMidToSup()
+ {
+ return as<SubtypeWitness>(getOperand(3));
+ }
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- TransitiveSubtypeWitness(SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup)
- : subToMid(inSubToMid), midToSup(inMidToSup)
- {}
+ TransitiveSubtypeWitness(Type* subType, Type* supType, SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup)
+ {
+ setOperands(subType, supType, inSubToMid, inMidToSup);
+ }
};
// A witness that `sub : sup` because `sub` was wrapped into
@@ -374,52 +529,27 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness
SLANG_AST_CLASS(ExtractExistentialSubtypeWitness)
// The declaration of the existential value that has been opened
- DeclRef<VarDeclBase> declRef;
-
- // Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
- void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
- Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
-};
-
-// A witness that `sub : sup`, because `sub` is a tagged union
-// of the form `A | B | C | ...` and each of `A : sup`,
-// `B : sup`, `C : sup`, etc.
-//
-class TaggedUnionSubtypeWitness : public SubtypeWitness
-{
- SLANG_AST_CLASS(TaggedUnionSubtypeWitness)
+ DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(2)); }
- // Witnesses that each of the "case" types in the union
- // is a subtype of `sup`.
- //
- List<SubtypeWitness*> caseWitnesses;
+ ExtractExistentialSubtypeWitness(Type* inSub, Type* inSup, DeclRef<Decl> inDeclRef)
+ {
+ setOperands(inSub, inSup, inDeclRef);
+ }
// Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
- /// A witness of the fact that `ThisType(someInterface) : someInterface`
-class ThisTypeSubtypeWitness : public SubtypeWitness
-{
- SLANG_AST_CLASS(ThisTypeSubtypeWitness)
-
- ThisTypeSubtypeWitness(Type* subType, Type* supType)
- {
- sub = subType;
- sup = supType;
- }
-};
-
/// A witness of the fact that a user provided "__Dynamic" type argument is a
/// subtype to the existential type parameter.
class DynamicSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(DynamicSubtypeWitness)
+ DynamicSubtypeWitness(Type* inSub, Type* inSup)
+ {
+ setOperands(inSub, inSup);
+ }
};
/// A witness that `T : L & R` because `T : L` and `T : R`
@@ -431,23 +561,24 @@ class ConjunctionSubtypeWitness : public SubtypeWitness
// an operation that takes two witness tables `leftWitness`
// and `rightWitness`, and forms a pair/tuple of
// `(leftWitness, rightWitness)`.
+ static const int kComponentCount = 2;
- static const Count kComponentCount = 2;
- SubtypeWitness* componentWitnesses[kComponentCount];
+ ConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* left, SubtypeWitness* right)
+ {
+ setOperands(inSub, inSup, left, right);
+ }
- SubtypeWitness* getLeftWitness() const { return componentWitnesses[0]; }
- SubtypeWitness* getRightWitness() const { return componentWitnesses[1]; }
+ SubtypeWitness* getLeftWitness() const { return as<SubtypeWitness>(getOperand(2)); }
+ SubtypeWitness* getRightWitness() const { return as<SubtypeWitness>(getOperand(3)); }
- Count getComponentCount() const { return kComponentCount; }
+ Count getComponentCount() const { return 2; }
SubtypeWitness* getComponentWitness(Index index) const
{
SLANG_ASSERT(index >= 0 && index < kComponentCount);
- return componentWitnesses[index];
+ return as<SubtypeWitness>(getOperand(2 + index));
}
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
@@ -461,19 +592,22 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness
// `(leftWtiness, rightWitness)` and extracts one of the
// elements of it.
- /// Witness that `T < L & R`
- SubtypeWitness* conjunctionWitness;
+ /// Witness that `T < L & R`
+ SubtypeWitness* getConjunctionWitness() { return as<SubtypeWitness>(getOperand(2)); };
+
+ ExtractFromConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* witness, int index)
+ {
+ setOperands(inSub, inSup, witness, index);
+ }
/// The zero-based index of the super-type we care about in the conjunction
///
/// If `conjunctionWitness` is `T < L & R` then this index should be zero if
/// we want to represent `T < L` and one if we want `T < R`.
///
- int indexInConjunction;
+ int getIndexInConjunction() { return (int)getIntConstOperand(3); };
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
@@ -482,8 +616,7 @@ class ModifierVal : public Val
{
SLANG_AST_CLASS(ModifierVal)
- bool _equalsValOverride(Val* val);
- HashCode _getHashCodeOverride();
+ Val* _resolveImplOverride() { return this; }
};
class TypeModifierVal : public ModifierVal
@@ -525,37 +658,91 @@ class DifferentiateVal : public Val
{
SLANG_AST_CLASS(DifferentiateVal)
- DeclRef<Decl> func;
+ DifferentiateVal(DeclRef<Decl> inFunc)
+ {
+ setOperands(inFunc);
+ }
+
+ DeclRef<Decl> getFunc() { return as<DeclRefBase>(getOperand(0)); }
- bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
};
class ForwardDifferentiateVal : public DifferentiateVal
{
SLANG_AST_CLASS(ForwardDifferentiateVal)
+ ForwardDifferentiateVal(DeclRef<Decl> inFunc)
+ : DifferentiateVal(inFunc)
+ {}
};
class BackwardDifferentiateVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiateVal)
+
+ BackwardDifferentiateVal(DeclRef<Decl> inFunc)
+ : DifferentiateVal(inFunc)
+ {}
};
class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal)
+
+ BackwardDifferentiateIntermediateTypeVal(DeclRef<Decl> inFunc)
+ : DifferentiateVal(inFunc)
+ {}
};
class BackwardDifferentiatePrimalVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiatePrimalVal)
+
+ BackwardDifferentiatePrimalVal(DeclRef<Decl> inFunc)
+ : DifferentiateVal(inFunc)
+ {}
};
class BackwardDifferentiatePropagateVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiatePropagateVal)
+
+ BackwardDifferentiatePropagateVal(DeclRef<Decl> inFunc)
+ : DifferentiateVal(inFunc)
+ {}
};
+
+template<typename F>
+void SubstitutionSet::forEachGenericSubstitution(F func) const
+{
+ if (!declRef)
+ return;
+ for (auto subst = declRef; subst; subst = subst->getBase())
+ {
+ if (auto genSubst = as<GenericAppDeclRef>(subst))
+ func(genSubst->getGenericDecl(), genSubst->getArgs());
+ }
+}
+
+template<typename F>
+void SubstitutionSet::forEachSubstitutionArg(F func) const
+{
+ if (!declRef)
+ return;
+ for (auto subst = declRef; subst; subst = subst->getBase())
+ {
+ if (auto genSubst = as<GenericAppDeclRef>(subst))
+ {
+ for (auto arg : genSubst->getArgs())
+ func(arg);
+ }
+ else if (auto thisSubst = as<LookupDeclRef>(subst))
+ {
+ func(thisSubst->getWitness()->getSub());
+ }
+ }
+}
} // namespace Slang
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 1d19e01bf..4376b1135 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -54,7 +54,7 @@ namespace Slang
Type* superType)
{
SubtypeWitness* result = nullptr;
- if (getShared()->tryGetSubtypeWitness(subType, superType, result))
+ if (getShared()->tryGetSubtypeWitnessFromCache(subType, superType, result))
return result;
result = checkAndConstructSubtypeWitness(subType, superType);
getShared()->cacheSubtypeWitness(subType, superType, result);
@@ -107,11 +107,11 @@ namespace Slang
// First, make sure both sub type and super type decl are ready for lookup.
if (auto subDeclRefType = as<DeclRefType>(subType))
{
- ensureDecl(subDeclRefType->declRef.getDecl(), DeclCheckState::ReadyForLookup);
+ ensureDecl(subDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup);
}
if (auto superDeclRefType = as<DeclRefType>(subType))
{
- ensureDecl(superDeclRefType->declRef.getDecl(), DeclCheckState::ReadyForLookup);
+ ensureDecl(superDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup);
}
// In the common case, we can use the pre-computed inheritance information for `subType`
@@ -173,13 +173,13 @@ namespace Slang
DeclRef<Decl> superTypeDeclRef;
if (auto superDeclRefType = as<DeclRefType>(superType))
{
- superTypeDeclRef = superDeclRefType->declRef;
+ superTypeDeclRef = superDeclRefType->getDeclRef();
}
- if (auto dynamicType = as<DynamicType>(subType))
+ if (as<DynamicType>(subType))
{
// A __Dynamic type always conforms to the interface via its witness table.
- auto witness = m_astBuilder->create<DynamicSubtypeWitness>();
+ auto witness = m_astBuilder->getOrCreate<DynamicSubtypeWitness>(subType, superType);
return witness;
}
else if (auto conjunctionSuperType = as<AndType>(superType))
@@ -189,10 +189,10 @@ namespace Slang
// We therefore simply recursively test both `T <: L`
// and `T <: R`.
//
- auto leftWitness = isSubtype(subType, conjunctionSuperType->left);
+ auto leftWitness = isSubtype(subType, conjunctionSuperType->getLeft());
if (!leftWitness) return nullptr;
//
- auto rightWitness = isSubtype(subType, conjunctionSuperType->right);
+ auto rightWitness = isSubtype(subType, conjunctionSuperType->getRight());
if (!rightWitness) return nullptr;
// If both of the sub-relationships hold, we can construct
@@ -214,7 +214,7 @@ namespace Slang
// TODO(tfoley): We could add support for `ExtractExistentialType` to
// the inheritance linearization logic, and eliminate this case.
//
- auto interfaceDeclRef = extractExistentialType->originalInterfaceDeclRef;
+ auto interfaceDeclRef = extractExistentialType->getOriginalInterfaceDeclRef();
if (interfaceDeclRef.equals(superTypeDeclRef))
{
auto witness = extractExistentialType->getSubtypeWitness();
@@ -222,62 +222,6 @@ namespace Slang
}
return nullptr;
}
- //
- // TODO(tfoley): We should probably just remove `TaggedUnionType`,
- // since there is no useful code that relies on it any more.
- //
- else if(auto taggedUnionType = as<TaggedUnionType>(subType))
- {
- // A tagged union type conforms to an interface if all of
- // the constituent types in the tagged union conform.
- //
- // We will iterate over the "case" types in the tagged
- // union, and check if they conform to the interface.
- // Along the way we will collect the conformance witness
- // values for the case types.
- //
- List<SubtypeWitness*> caseWitnesses;
- for(auto caseType : taggedUnionType->caseTypes)
- {
- auto caseWitness = isSubtype(caseType, superType);
-
- if(!caseWitness)
- {
- return nullptr;
- }
-
- caseWitnesses.add(caseWitness);
- }
-
- // We also need to validate the requirements on
- // the interface to make sure that they are suitable for
- // use with a tagged-union type.
- //
- // For example, if the interface includes a `static` method
- // (which can therefore be called without a particular instance),
- // then we wouldn't know what implementation of that method
- // to use because there is no tag value to dispatch on.
- //
- // We will start out being conservative about what we accept
- // here, just to keep things simple.
- //
- if( auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>() )
- {
- if(!isInterfaceSafeForTaggedUnion(superInterfaceDeclRef))
- return nullptr;
- }
-
- // If we reach this point then we have a concrete
- // witness for each of the case types, and that is
- // enough to build a witness for the tagged union.
- //
- TaggedUnionSubtypeWitness* taggedUnionWitness = m_astBuilder->create<TaggedUnionSubtypeWitness>();
- taggedUnionWitness->sub = taggedUnionType;
- taggedUnionWitness->sup = superType;
- taggedUnionWitness->caseWitnesses.swapWith(caseWitnesses);
-
- return taggedUnionWitness;
- }
// default is failure
return nullptr;
@@ -287,7 +231,7 @@ namespace Slang
{
if (auto declRefType = as<DeclRefType>(type))
{
- if (auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>())
+ if (auto interfaceDeclRef = declRefType->getDeclRef().as<InterfaceDecl>())
return true;
}
return false;
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 22a92bf0a..b9d33a1c1 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -65,14 +65,14 @@ namespace Slang
// That is, the join of a vector and a scalar type is
// a vector type with a joined element type.
auto joinElementType = TryJoinTypes(
- vectorType->elementType,
+ vectorType->getElementType(),
scalarType);
if(!joinElementType)
return nullptr;
return createVectorType(
joinElementType,
- vectorType->elementCount);
+ vectorType->getElementCount());
}
Type* SemanticsVisitor::_tryJoinTypeWithInterface(
@@ -110,11 +110,11 @@ namespace Slang
for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++)
{
// Don't consider `type`, since we already know it doesn't work.
- if(baseTypeFlavorIndex == Int(basicType->baseType))
+ if(baseTypeFlavorIndex == Int(basicType->getBaseType()))
continue;
// Look up the type in our session.
- auto candidateType = type->getASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex));
+ auto candidateType = getCurrentASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex));
if(!candidateType)
continue;
@@ -186,8 +186,8 @@ namespace Slang
{
if (auto rightBasic = as<BasicExpressionType>(right))
{
- auto leftFlavor = leftBasic->baseType;
- auto rightFlavor = rightBasic->baseType;
+ auto leftFlavor = leftBasic->getBaseType();
+ auto rightFlavor = rightBasic->getBaseType();
// TODO(tfoley): Need a special-case rule here that if
// either operand is of type `half`, then we promote
@@ -217,19 +217,19 @@ namespace Slang
if(auto rightVector = as<VectorExpressionType>(right))
{
// Check if the vector sizes match
- if(!leftVector->elementCount->equalsVal(rightVector->elementCount))
+ if(!leftVector->getElementCount()->equals(rightVector->getElementCount()))
return nullptr;
// Try to join the element types
auto joinElementType = TryJoinTypes(
- leftVector->elementType,
- rightVector->elementType);
+ leftVector->getElementType(),
+ rightVector->getElementType());
if(!joinElementType)
return nullptr;
return createVectorType(
joinElementType,
- leftVector->elementCount);
+ leftVector->getElementCount());
}
// We can also join a vector and a scalar
@@ -242,7 +242,7 @@ namespace Slang
// HACK: trying to work trait types in here...
if(auto leftDeclRefType = as<DeclRefType>(left))
{
- if( auto leftInterfaceRef = leftDeclRefType->declRef.as<InterfaceDecl>() )
+ if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as<InterfaceDecl>() )
{
//
return _tryJoinTypeWithInterface(right, left);
@@ -250,7 +250,7 @@ namespace Slang
}
if(auto rightDeclRefType = as<DeclRefType>(right))
{
- if( auto rightInterfaceRef = rightDeclRefType->declRef.as<InterfaceDecl>() )
+ if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as<InterfaceDecl>() )
{
//
return _tryJoinTypeWithInterface(left, right);
@@ -263,10 +263,10 @@ namespace Slang
return nullptr;
}
- SubstitutionSet SemanticsVisitor::trySolveConstraintSystem(
+ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
- GenericSubstitution* substWithKnownGenericArgs)
+ ArrayView<Val*> knownGenericArgs)
{
// For now the "solver" is going to be ridiculously simplistic.
@@ -288,9 +288,8 @@ namespace Slang
for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef) )
{
if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef)))
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
- SubstitutionSet resultSubst = genericDeclRef.getSubst();
// Once have built up the full list of constraints we are trying to satisfy,
// we will attempt to solve for each parameter in a way that satisfies all
@@ -310,10 +309,10 @@ namespace Slang
// or not they are compatible with the constraints).
//
Count knownGenericArgCount = 0;
- if (substWithKnownGenericArgs)
+ if (knownGenericArgs.getCount())
{
- knownGenericArgCount = substWithKnownGenericArgs->getArgs().getCount();
- for (auto arg : substWithKnownGenericArgs->getArgs())
+ knownGenericArgCount = knownGenericArgs.getCount();
+ for (auto arg : knownGenericArgs)
{
args.add(arg);
}
@@ -364,7 +363,7 @@ namespace Slang
if (!joinType)
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
type = joinType;
}
@@ -375,7 +374,7 @@ namespace Slang
if (!type)
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
args.add(type);
}
@@ -417,10 +416,10 @@ namespace Slang
}
else
{
- if(!val->equalsVal(cVal))
+ if(!val->equals(cVal))
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
}
@@ -430,7 +429,7 @@ namespace Slang
if (!val)
{
// failure!
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
args.add(val);
}
@@ -456,14 +455,10 @@ namespace Slang
// search for a conformance `Robin : ISidekick`, which involved
// apply the substitutions we already know...
- GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(), genericDeclRef.getDecl(), args.getArrayView());
-
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
- DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef(
- constraintDecl,
- solvedSubst);
+ DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
+ genericDeclRef, args.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>();
// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
@@ -476,7 +471,7 @@ namespace Slang
// not provide an explicit type parameter to specialize a generic
// and the type parameter cannot be inferred from any arguments.
// In this case, we should fail the constraint check.
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
// Search for a witness that shows the constraint is satisfied.
@@ -492,7 +487,7 @@ namespace Slang
//
// TODO: Ideally we should print an error message in
// this case, to let the user know why things failed.
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
// TODO: We may need to mark some constrains in our constraint
@@ -505,13 +500,11 @@ namespace Slang
{
if (!c.satisfied)
{
- return SubstitutionSet();
+ return DeclRef<Decl>();
}
}
- resultSubst = m_astBuilder->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(), genericDeclRef.getDecl(), args);
- return resultSubst;
+ return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView());
}
bool SemanticsVisitor::TryUnifyVals(
@@ -533,7 +526,7 @@ namespace Slang
{
if (auto sndIntVal = as<ConstantIntVal>(snd))
{
- return fstIntVal->value == sndIntVal->value;
+ return fstIntVal->getValue() == sndIntVal->getValue();
}
}
@@ -541,23 +534,23 @@ namespace Slang
if (auto fstInt = as<IntVal>(fst))
{
if (auto tc = as<TypeCastIntVal>(fstInt))
- fstInt = as<IntVal>(tc->base);
+ fstInt = as<IntVal>(tc->getBase());
if (auto sndInt = as<IntVal>(snd))
{
if (auto tc = as<TypeCastIntVal>(sndInt))
- sndInt = as<IntVal>(tc->base);
+ sndInt = as<IntVal>(tc->getBase());
auto fstParam = as<GenericParamIntVal>(fstInt);
auto sndParam = as<GenericParamIntVal>(sndInt);
bool okay = false;
if (fstParam)
{
- if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt))
+ if(TryUnifyIntParam(constraints, fstParam->getDeclRef(), sndInt))
okay = true;
}
if (sndParam)
{
- if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt))
+ if(TryUnifyIntParam(constraints, sndParam->getDeclRef(), fstInt))
okay = true;
}
return okay;
@@ -568,8 +561,8 @@ namespace Slang
{
if (auto sndWit = as<DeclaredSubtypeWitness>(snd))
{
- auto constraintDecl1 = fstWit->declRef.as<TypeConstraintDecl>();
- auto constraintDecl2 = sndWit->declRef.as<TypeConstraintDecl>();
+ auto constraintDecl1 = fstWit->getDeclRef().as<TypeConstraintDecl>();
+ auto constraintDecl2 = sndWit->getDeclRef().as<TypeConstraintDecl>();
SLANG_ASSERT(constraintDecl1);
SLANG_ASSERT(constraintDecl2);
return TryUnifyTypes(constraints,
@@ -586,8 +579,8 @@ namespace Slang
if (auto sndWit = as<SubtypeWitness>(snd))
{
return TryUnifyTypes(constraints,
- fstWit->sup,
- sndWit->sup);
+ fstWit->getSup(),
+ sndWit->getSup());
}
}
@@ -597,35 +590,28 @@ namespace Slang
//return false;
}
- bool SemanticsVisitor::tryUnifySubstitutions(
- ConstraintSystem& constraints,
- Substitutions* fst,
- Substitutions* snd)
+ bool SemanticsVisitor::tryUnifyDeclRef(
+ ConstraintSystem& constraints,
+ DeclRefBase* fst,
+ DeclRefBase* snd)
{
- // They must both be NULL or non-NULL
- if (!fst || !snd)
- return !fst && !snd;
-
- if(auto fstGeneric = as<GenericSubstitution>(fst))
- {
- if(auto sndGeneric = as<GenericSubstitution>(snd))
- {
- return tryUnifyGenericSubstitutions(
- constraints,
- fstGeneric,
- sndGeneric);
- }
- }
-
- // TODO: need to handle other cases here
-
- return false;
+ if (fst == snd)
+ return true;
+ if (fst == nullptr || snd == nullptr)
+ return false;
+ auto fstGen = SubstitutionSet(fst).findGenericAppDeclRef();
+ auto sndGen = SubstitutionSet(snd).findGenericAppDeclRef();
+ if (fstGen == sndGen)
+ return true;
+ if (fstGen == nullptr || sndGen == nullptr)
+ return false;
+ return tryUnifyGenericAppDeclRef(constraints, fstGen, sndGen);
}
- bool SemanticsVisitor::tryUnifyGenericSubstitutions(
+ bool SemanticsVisitor::tryUnifyGenericAppDeclRef(
ConstraintSystem& constraints,
- GenericSubstitution* fst,
- GenericSubstitution* snd)
+ GenericAppDeclRef* fst,
+ GenericAppDeclRef* snd)
{
SLANG_ASSERT(fst);
SLANG_ASSERT(snd);
@@ -649,7 +635,10 @@ namespace Slang
}
// Their "base" specializations must unify
- if (!tryUnifySubstitutions(constraints, fstGen->getOuter(), sndGen->getOuter()))
+ auto fstBase = fst->getBase();
+ auto sndBase = snd->getBase();
+
+ if (!tryUnifyDeclRef(constraints, fstBase, sndBase))
{
okay = false;
}
@@ -718,14 +707,14 @@ namespace Slang
{
if (auto fstDeclRefType = as<DeclRefType>(fst))
{
- auto fstDeclRef = fstDeclRefType->declRef;
+ auto fstDeclRef = fstDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
return TryUnifyTypeParam(constraints, typeParamDecl, snd);
if (auto sndDeclRefType = as<DeclRefType>(snd))
{
- auto sndDeclRef = sndDeclRefType->declRef;
+ auto sndDeclRef = sndDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
return TryUnifyTypeParam(constraints, typeParamDecl, fst);
@@ -735,10 +724,10 @@ namespace Slang
// next we need to unify the substitutions applied
// to each declaration reference.
- if (!tryUnifySubstitutions(
+ if (!tryUnifyDeclRef(
constraints,
- fstDeclRef.getSubst(),
- sndDeclRef.getSubst()))
+ fstDeclRef,
+ sndDeclRef))
{
return false;
}
@@ -749,15 +738,15 @@ namespace Slang
{
if (auto sndFunType = as<FuncType>(snd))
{
- const Index numParams = fstFunType->paramTypes.getCount();
- if(numParams != sndFunType->paramTypes.getCount())
+ const Index numParams = fstFunType->getParamCount();
+ if(numParams != sndFunType->getParamCount())
return false;
for(Index i = 0; i < numParams; ++i)
{
- if(!TryUnifyTypes(constraints, fstFunType->paramTypes[i], sndFunType->paramTypes[i]))
+ if(!TryUnifyTypes(constraints, fstFunType->getParamType(i), sndFunType->getParamType(i)))
return false;
}
- return TryUnifyTypes(constraints, fstFunType->resultType, sndFunType->resultType);
+ return TryUnifyTypes(constraints, fstFunType->getResultType(), sndFunType->getResultType());
}
}
@@ -779,13 +768,13 @@ namespace Slang
//
if (auto fstAndType = as<AndType>(fst))
{
- return TryUnifyTypes(constraints, fstAndType->left, snd)
- && TryUnifyTypes(constraints, fstAndType->right, snd);
+ return TryUnifyTypes(constraints, fstAndType->getLeft(), snd)
+ && TryUnifyTypes(constraints, fstAndType->getRight(), snd);
}
else if (auto sndAndType = as<AndType>(snd))
{
- return TryUnifyTypes(constraints, fst, sndAndType->left)
- || TryUnifyTypes(constraints, fst, sndAndType->right);
+ return TryUnifyTypes(constraints, fst, sndAndType->getLeft())
+ || TryUnifyTypes(constraints, fst, sndAndType->getRight());
}
else
return false;
@@ -828,7 +817,7 @@ namespace Slang
if (auto fstDeclRefType = as<DeclRefType>(fst))
{
- auto fstDeclRef = fstDeclRefType->declRef;
+ auto fstDeclRef = fstDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
{
@@ -839,7 +828,7 @@ namespace Slang
if (auto sndDeclRefType = as<DeclRefType>(snd))
{
- auto sndDeclRef = sndDeclRefType->declRef;
+ auto sndDeclRef = sndDeclRefType->getDeclRef();
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
{
@@ -863,7 +852,7 @@ namespace Slang
{
return TryUnifyTypes(
constraints,
- fstVectorType->elementType,
+ fstVectorType->getElementType(),
sndScalarType);
}
}
@@ -875,7 +864,7 @@ namespace Slang
return TryUnifyTypes(
constraints,
fstScalarType,
- sndVectorType->elementType);
+ sndVectorType->getElementType());
}
}
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index abe9f4817..d89808c3d 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -61,7 +61,7 @@ namespace Slang
if(auto declRefType = as<DeclRefType>(type))
{
- if(as<StructDecl>(declRefType->declRef))
+ if(as<StructDecl>(declRefType->getDeclRef()))
return false;
}
@@ -174,7 +174,7 @@ namespace Slang
if(!baseDeclRefType)
return nullptr;
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
auto baseStructDeclRef = baseDeclRef.as<StructDecl>();
if(!baseStructDeclRef)
return nullptr;
@@ -193,7 +193,7 @@ namespace Slang
if (!baseDeclRefType)
return DeclRef<StructDecl>();
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
auto baseStructDeclRef = baseDeclRef.as<StructDecl>();
if (!baseStructDeclRef)
return DeclRef<StructDecl>();
@@ -244,13 +244,13 @@ namespace Slang
}
else if (auto toVecType = as<VectorExpressionType>(toType))
{
- auto toElementCount = toVecType->elementCount;
- auto toElementType = toVecType->elementType;
+ auto toElementCount = toVecType->getElementCount();
+ auto toElementType = toVecType->getElementType();
UInt elementCount = 0;
if (auto constElementCount = as<ConstantIntVal>(toElementCount))
{
- elementCount = (UInt) constElementCount->value;
+ elementCount = (UInt) constElementCount->getValue();
}
else
{
@@ -299,7 +299,7 @@ namespace Slang
UInt elementCount = 0;
if (auto constElementCount = as<ConstantIntVal>(toElementCount))
{
- elementCount = (UInt) constElementCount->value;
+ elementCount = (UInt) constElementCount->getValue();
}
else
{
@@ -388,7 +388,7 @@ namespace Slang
if (auto constRowCount = as<ConstantIntVal>(toMatrixType->getRowCount()))
{
- rowCount = (UInt) constRowCount->value;
+ rowCount = (UInt) constRowCount->getValue();
}
else
{
@@ -423,7 +423,7 @@ namespace Slang
}
else if(auto toDeclRefType = as<DeclRefType>(toType))
{
- auto toTypeDeclRef = toDeclRefType->declRef;
+ auto toTypeDeclRef = toDeclRefType->getDeclRef();
if(auto toStructDeclRef = toTypeDeclRef.as<StructDecl>())
{
// Trying to initialize a `struct` type given an initializer list.
@@ -570,7 +570,7 @@ namespace Slang
if( left == right )
return true;
- if( left->equalsVal(right) )
+ if( left->equals(right) )
return true;
return false;
@@ -581,9 +581,9 @@ namespace Slang
{
if(!type) return false;
- for( auto m : type->modifiers )
+ for (Index m = 0; m < type->getModifierCount(); m++)
{
- if(_doModifiersMatch(m, modifier))
+ if(_doModifiersMatch(type->getModifier(m), modifier))
return true;
}
@@ -632,7 +632,7 @@ namespace Slang
{
auto basicType = as<BasicExpressionType>(t);
if (!basicType) return false;
- switch (basicType->baseType)
+ switch (basicType->getBaseType())
{
case BaseType::Int8:
case BaseType::Int16:
@@ -650,7 +650,7 @@ namespace Slang
auto basicType = as<BasicExpressionType>(t);
if (!basicType) return 0;
- switch (basicType->baseType)
+ switch (basicType->getBaseType())
{
case BaseType::Int8:
case BaseType::UInt8:
@@ -770,10 +770,10 @@ namespace Slang
// on it, but the underlying types are otherwise the same.
//
auto toModified = as<ModifiedType>(toType);
- auto toBase = toModified ? toModified->base : toType;
+ auto toBase = toModified ? toModified->getBase() : toType;
//
auto fromModified = as<ModifiedType>(fromType);
- auto fromBase = fromModified ? fromModified->base : fromType;
+ auto fromBase = fromModified ? fromModified->getBase() : fromType;
if((toModified || fromModified) && toBase->equals(fromBase))
@@ -787,8 +787,9 @@ namespace Slang
//
if( toModified )
{
- for( auto modifier : toModified->modifiers )
+ for (Index m = 0; m < toModified->getModifierCount(); m++)
{
+ auto modifier = toModified->getModifier(m);
if(_hasMatchingModifier(fromModified, modifier))
continue;
@@ -804,8 +805,10 @@ namespace Slang
}
if( fromModified )
{
- for( auto modifier : fromModified->modifiers )
+ for (Index m = 0; m < fromModified->getModifierCount(); m++)
{
+ auto modifier = fromModified->getModifier(m);
+
if(_hasMatchingModifier(toModified, modifier))
continue;
@@ -923,7 +926,7 @@ namespace Slang
//
// TODO(tfoley): Under what circumstances would this check ever be needed?
//
- if (auto toParameterGroupType = as<ParameterGroupType>(toType))
+ if (as<ParameterGroupType>(toType))
{
return _failedCoercion(toType, outToExpr, fromExpr);
}
@@ -1141,7 +1144,7 @@ namespace Slang
{
if (auto val = as<ConstantIntVal>(intVal))
{
- if (isIntValueInRangeOfType(val->value, toType))
+ if (isIntValueInRangeOfType(val->getValue(), toType))
{
// OK.
shouldEmitGeneralWarning = false;
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index b1dd2d533..b6a5d94ef 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -13,6 +13,8 @@
#include "slang-lookup.h"
#include "slang-syntax.h"
#include "slang-ast-synthesis.h"
+#include "slang-ast-reflect.h"
+
#include <limits>
namespace Slang
@@ -201,12 +203,6 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- Val* resolveVal(Val* val);
- Type* resolveType(Type* type)
- {
- return (Type*)resolveVal(type);
- }
-
void visitTypeExp(TypeExp& exp)
{
exp.type = resolveType(exp.type);
@@ -581,7 +577,7 @@ namespace Slang
return getTypeForDeclRef(astBuilder, nullptr, nullptr, declRef, &typeResult, loc);
}
- DeclRef<ExtensionDecl> ApplyExtensionToType(
+ DeclRef<ExtensionDecl> applyExtensionToType(
SemanticsVisitor* semantics,
ExtensionDecl* extDecl,
Type* type)
@@ -589,118 +585,7 @@ namespace Slang
if(!semantics)
return DeclRef<ExtensionDecl>();
- return semantics->ApplyExtensionToType(extDecl, type);
- }
-
- GenericSubstitution* createDefaultSubstitutionsForGeneric(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- GenericDecl* genericDecl,
- Substitutions* outerSubst)
- {
- GenericSubstitution* cachedResult = nullptr;
- if (astBuilder->m_genericDefaultSubst.tryGetValue(genericDecl, cachedResult))
- {
- if (cachedResult->getOuter() == outerSubst)
- return cachedResult;
- }
-
- List<Val*> args;
-
- for( auto mm : genericDecl->members )
- {
- if( auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm) )
- {
- args.add(DeclRefType::create(astBuilder, astBuilder->getSpecializedDeclRef<Decl>(genericTypeParamDecl, outerSubst)));
- }
- else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) )
- {
- if (semantics)
- ensureDecl(semantics, genericValueParamDecl, DeclCheckState::ReadyForLookup);
-
- args.add(astBuilder->getOrCreate<GenericParamIntVal>(
- genericValueParamDecl->getType(),
- astBuilder->getSpecializedDeclRef(genericValueParamDecl, outerSubst)));
- }
- }
-
- bool shouldCache = true;
-
- // create default substitution arguments for constraints
- for (auto mm : genericDecl->members)
- {
- if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm))
- {
- if (semantics)
- {
- ensureDecl(semantics, genericTypeConstraintDecl, DeclCheckState::ReadyForReference);
- }
- auto constraintDeclRef = astBuilder->getSpecializedDeclRef<GenericTypeConstraintDecl>(genericTypeConstraintDecl, outerSubst);
- auto witness =
- astBuilder->getDeclaredSubtypeWitness(
- getSub(astBuilder, constraintDeclRef),
- getSup(astBuilder, constraintDeclRef),
- astBuilder->getSpecializedDeclRef(genericTypeConstraintDecl, outerSubst));
- // TODO: this is an ugly hack to prevent crashing.
- // In early stages of compilation witness->sub and witness->sup may not be checked yet.
- // When semanticVisitor is present we have used that to ensure the type is checked.
- // However due to how the code is written we cannot guarantee semanticVisitor is always available
- // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be
- // cached.
- if (!witness->sub)
- shouldCache = false;
- args.add(witness);
- }
- }
-
- GenericSubstitution* genericSubst = astBuilder->getOrCreateGenericSubstitution(outerSubst, genericDecl, args);
- if (shouldCache)
- astBuilder->m_genericDefaultSubst[genericDecl] = genericSubst;
- return genericSubst;
- }
-
- // Sometimes we need to refer to a declaration the way that it would be specialized
- // inside the context where it is declared (e.g., with generic parameters filled in
- // using their archetypes).
- //
- SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- Decl* decl,
- SubstitutionSet outerSubstSet)
- {
- auto dd = decl->parentDecl;
- if( auto genericDecl = as<GenericDecl>(dd) )
- {
- // We don't want to specialize references to anything
- // other than the "inner" declaration itself.
- if(decl != genericDecl->inner)
- return outerSubstSet;
-
- GenericSubstitution* genericSubst = createDefaultSubstitutionsForGeneric(
- astBuilder,
- semantics,
- genericDecl,
- outerSubstSet.substitutions);
-
- return SubstitutionSet(genericSubst);
- }
-
- return outerSubstSet;
- }
-
- SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- Decl* decl)
- {
- SubstitutionSet subst;
- if( auto parentDecl = decl->parentDecl )
- {
- subst = createDefaultSubstitutions(astBuilder, semantics, parentDecl);
- }
- subst = createDefaultSubstitutions(astBuilder, semantics, decl, subst);
- return subst;
+ return semantics->applyExtensionToType(extDecl, type);
}
bool SemanticsVisitor::isDeclUsableAsStaticMember(
@@ -1066,7 +951,7 @@ namespace Slang
auto baseExprType = memberExpr->baseExpression->type.type;
if (auto typeType = as<TypeType>(baseExprType))
{
- if (diffThisType->equals(typeType->type))
+ if (diffThisType->equals(typeType->getType()))
{
return;
}
@@ -1149,7 +1034,6 @@ namespace Slang
{
// A variable with an explicit type is simpler, for the
// most part.
-
TypeExp typeExp = CheckUsableType(varDecl->type);
varDecl->type = typeExp;
if (varDecl->type.equals(m_astBuilder->getVoidType()))
@@ -1256,7 +1140,7 @@ namespace Slang
{
if (auto basicType = as<BasicExpressionType>(varDecl->getType()))
{
- switch (basicType->baseType)
+ switch (basicType->getBaseType())
{
case BaseType::Bool:
case BaseType::Int8:
@@ -1429,11 +1313,11 @@ namespace Slang
{
if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base))
{
- auto subst = createDefaultSubstitutions(m_astBuilder, this, declRefType->declRef.getDecl());
- auto newType = DeclRefType::create(m_astBuilder, m_astBuilder->getSpecializedDeclRef(declRefType->declRef.getDecl(), subst));
+ auto newDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, declRefType->getDeclRef());
+ auto newType = DeclRefType::create(m_astBuilder, newDeclRef);
sharedTypeExpr->base.type = newType;
if (auto typetype = as<TypeType>(typeExp.exp->type))
- typetype->type = newType;
+ typeExp.exp->type = m_astBuilder->getTypeType(newType);
}
}
}
@@ -1477,20 +1361,20 @@ namespace Slang
}
// If `This` is nested inside a generic, we need to form a complete declref type to the
- // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
- // from requirementDeclRef to get the generic substitution for outer generic parameters, and
+ // newly synthesized aggTypeDecl here. This can be done by obtaining the this type witness
+ // from requirementDeclRef to get the generic arguments for the outer generic, and
// apply it to the newly synthesized decl.
SubstitutionSet substSet;
- if (auto thisTypeSusbt = findThisTypeSubstitution(
- requirementDeclRef.getSubst(),
- as<InterfaceDecl>(requirementDeclRef.getParent(m_astBuilder)).getDecl()))
+ if (auto thisWitness = findThisTypeWitness(
+ SubstitutionSet(requirementDeclRef),
+ as<InterfaceDecl>(requirementDeclRef.getParent()).getDecl()))
{
- if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ if (auto declRefType = as<DeclRefType>(thisWitness->getSub()))
{
- substSet = declRefType->declRef.getSubst();
+ substSet = SubstitutionSet(declRefType->getDeclRef());
}
}
- auto satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getSpecializedDeclRef(aggTypeDecl, substSet));
+ auto satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl));
// Helper function to add a `diffType` field into the synthesized type for the original
// `member`.
@@ -1513,8 +1397,7 @@ namespace Slang
fieldLookupExpr->type.type = diffMemberType;
auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
baseTypeExpr->base.type = differentialType;
- auto baseTypeType = m_astBuilder->create<TypeType>();
- baseTypeType->type = differentialType;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType);
baseTypeExpr->type.type = baseTypeType;
fieldLookupExpr->baseExpression = baseTypeExpr;
fieldLookupExpr->declRef = makeDeclRef(diffField);
@@ -1529,8 +1412,7 @@ namespace Slang
fieldLookupExpr->type.type = diffMemberType;
auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
baseTypeExpr->base.type = differentialType;
- auto baseTypeType = m_astBuilder->create<TypeType>();
- baseTypeType->type = differentialType;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType);
baseTypeExpr->type.type = baseTypeType;
fieldLookupExpr->baseExpression = baseTypeExpr;
fieldLookupExpr->declRef = makeDeclRef(diffField);
@@ -1545,7 +1427,7 @@ namespace Slang
{
if (auto declRefType = as<DeclRefType>(inheritanceDecl->base.type))
{
- if (declRefType->declRef == m_astBuilder->getDifferentiableInterfaceDecl())
+ if (declRefType->getDeclRef() == m_astBuilder->getDifferentiableInterfaceDecl())
{
hasDifferentialConformance = true;
break;
@@ -1590,7 +1472,7 @@ namespace Slang
if (auto baseDeclRefType = as<DeclRefType>(inheritance->base.type))
{
// Skip interface super types.
- if (baseDeclRefType->declRef.as<InterfaceDecl>())
+ if (baseDeclRefType->getDeclRef().as<InterfaceDecl>())
continue;
if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType))
{
@@ -1618,6 +1500,9 @@ namespace Slang
if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
{
witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
+
+ // Incrase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types.
+ m_astBuilder->incrementEpoch();
return true;
}
@@ -1747,11 +1632,11 @@ namespace Slang
auto baseType = as<DeclRefType>(inheritanceDecl->witnessTable->baseType);
if (!baseType)
return;
- if (baseType->declRef.getDecl() != m_astBuilder->getDifferentiableInterfaceDecl().getDecl())
+ if (baseType->getDeclRef().getDecl() != m_astBuilder->getDifferentiableInterfaceDecl().getDecl())
return;
RequirementWitness witnessValue;
auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType);
- if (!inheritanceDecl->witnessTable->requirementDictionary.tryGetValue(requirementDecl, witnessValue))
+ if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue))
return;
// A type used as differential type must have itself as its own differential type.
@@ -1763,7 +1648,7 @@ namespace Slang
auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType);
if (!differentialType->equals(diffDiffType))
{
- SourceLoc sourceLoc = differentialType->declRef.getDecl()->loc;
+ SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc;
getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType);
getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup());
}
@@ -2287,7 +2172,7 @@ namespace Slang
auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>(
requiredValueParamDeclRef.getDecl()->getType(),
satisfyingValueParamDeclRef);
- satisfyingVal->declRef = satisfyingValueParamDeclRef;
+ satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef;
requiredSubstArgs.add(satisfyingVal);
}
@@ -2311,21 +2196,16 @@ namespace Slang
}
}
- GenericSubstitution* requiredSubst = m_astBuilder->getOrCreateGenericSubstitution(
- requiredGenericDeclRef.getSubst(),
- requiredGenericDeclRef.getDecl(),
- requiredSubstArgs);
-
// Now that we have computed a set of specialization arguments that will
// specialize the generic requirement at the type parameters of the satisfying
// generic, we can construct a reference to that declaration and re-run some
// of the earlier checking logic with more type information usable.
//
- auto specializedRequiredGenericDeclRef = m_astBuilder->getSpecializedDeclRef<GenericDecl>(requiredGenericDeclRef.getDecl(), requiredSubst);
- auto specializedRequiredMemberDeclRefs = getMembers(m_astBuilder, specializedRequiredGenericDeclRef);
+ auto specializedRequiredGenericInnerDeclRef = m_astBuilder->getGenericAppDeclRef(
+ requiredGenericDeclRef, requiredSubstArgs.getArrayView());
for (Index i = 0; i < memberCount; i++)
{
- auto requiredMemberDeclRef = specializedRequiredMemberDeclRefs[i];
+ auto requiredMemberDeclRef = requiredMemberDeclRefs[i];
auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i];
if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as<GenericTypeParamDecl>())
@@ -2365,13 +2245,16 @@ namespace Slang
// In current code the sub type will always be one of the generic type parameters,
// and the super-type will always be an interface, but there should be no
// need to make use of those additional details here.
-
- auto requiredSubType = getSub(m_astBuilder, requiredConstraintDeclRef);
+ auto specializedRequiredConstraintDeclRef = m_astBuilder->getGenericAppDeclRef(
+ requiredGenericDeclRef,
+ requiredSubstArgs.getArrayView(),
+ requiredConstraintDeclRef.getDecl()).as<GenericTypeConstraintDecl>();
+ auto requiredSubType = getSub(m_astBuilder, specializedRequiredConstraintDeclRef);
auto satisfyingSubType = getSub(m_astBuilder, satisfyingConstraintDeclRef);
if (!satisfyingSubType->equals(requiredSubType))
return false;
- auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef);
+ auto requiredSuperType = getSup(m_astBuilder, specializedRequiredConstraintDeclRef);
auto satisfyingSuperType = getSup(m_astBuilder, satisfyingConstraintDeclRef);
if (!satisfyingSuperType->equals(requiredSuperType))
return false;
@@ -2400,8 +2283,8 @@ namespace Slang
// declaration (whatever it is) for an exact match.
//
return doesMemberSatisfyRequirement(
- m_astBuilder->getSpecializedDeclRef<Decl>(satisfyingGenericDeclRef.getDecl()->inner, satisfyingGenericDeclRef.getSubst()),
- m_astBuilder->getSpecializedDeclRef<Decl>(requiredGenericDeclRef.getDecl()->inner, requiredSubst),
+ m_astBuilder->getMemberDeclRef(satisfyingGenericDeclRef, getInner(satisfyingGenericDeclRef)),
+ specializedRequiredGenericInnerDeclRef,
witnessTable);
}
@@ -2444,7 +2327,7 @@ namespace Slang
{
// If we are seeing a placeholder that awaits synthesis, return false now to trigger
// auto synthesis.
- if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
+ if (declRefType->getDeclRef().getDecl()->hasModifier<ToBeSynthesizedModifier>())
return false;
}
// We need to confirm that the chosen type `satisfyingType`,
@@ -2466,7 +2349,7 @@ namespace Slang
// type can indeed satisfy the interface requirement.
witnessTable->add(
requiredAssociatedTypeDeclRef.getDecl(),
- RequirementWitness(satisfyingType));
+ RequirementWitness(satisfyingType->getCanonicalType()));
}
return conformance;
@@ -2563,7 +2446,7 @@ namespace Slang
// check if the specified type satisfies the constraints defined by the associated type
if (auto requiredTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>())
{
- ensureDecl(typedefDeclRef, DeclCheckState::CanUseAsType);
+ ensureDecl(typedefDeclRef, DeclCheckState::ReadyForLookup);
auto satisfyingType = getNamedType(m_astBuilder, typedefDeclRef);
return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable);
@@ -2648,9 +2531,6 @@ namespace Slang
{
if (auto constraintDecl = as<GenericTypeConstraintDecl>(member))
{
- getASTBuilder()->getSpecializedDeclRef(
- constraintDecl, requiredMemberDeclRef.getSubst());
-
auto synConstraintDecl = m_astBuilder->create<GenericTypeConstraintDecl>();
synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc();
synConstraintDecl->parentDecl = synGenericDecl;
@@ -2658,7 +2538,7 @@ namespace Slang
// For constraints of type T : Interface, where T is a simple type parameter,
// find the declaration of T
//
- if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->declRef.as<GenericTypeParamDecl>().getDecl())
+ if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->getDeclRef().as<GenericTypeParamDecl>().getDecl())
{
auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl];
@@ -2680,37 +2560,19 @@ namespace Slang
}
}
- // Get outer substitutions. (This inner-most substition
- // must be a ThisTypeSubstition)
- //
- Substitutions* outer = nullptr;
- if (auto thisTypeSubst = findThisTypeSubstitution(
- requiredMemberDeclRef.getSubst(),
- as<InterfaceDecl>(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl()))
- {
- outer = thisTypeSubst;
- }
-
// Override generic pointer to point to the original generic container.
// This will create a substitution of the synthesized parameters for the
// original parameters.
- //
- GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer);
- DeclRef<Decl> requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts);
-
- GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution(
- outer,
- requiredMemberDeclRef.getDecl(),
- createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs());
-
- // Substitute parameters of the synthesized generic for the parameters of the original generic.
- requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef);
+ //
+ auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl);
+ DeclRef<FuncDecl> requiredFuncDeclRef = m_astBuilder->getGenericAppDeclRef(
+ requiredMemberDeclRef, defaultArgs.getArrayView()).as<FuncDecl>();
- SLANG_ASSERT(requiredFuncDeclRef.as<FuncDecl>());
+ SLANG_ASSERT(requiredFuncDeclRef);
synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness(
context,
- requiredFuncDeclRef.as<FuncDecl>(),
+ requiredFuncDeclRef,
synArgs,
synThis);
synGenericDecl->inner->parentDecl = synGenericDecl;
@@ -2860,14 +2722,12 @@ namespace Slang
{
if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
{
- ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
+ ForwardDifferentiateVal* val = m_astBuilder->getOrCreate<ForwardDifferentiateVal>(satisfyingMemberDeclRef);
witnessTable->add(fwdReq, RequirementWitness(val));
}
else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
{
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
+ DifferentiateVal* val = m_astBuilder->getOrCreate<BackwardDifferentiateVal>(satisfyingMemberDeclRef);
witnessTable->add(bwdReq, RequirementWitness(val));
}
}
@@ -3127,7 +2987,7 @@ namespace Slang
// or uses, an associated type or `This`.
//
// Ideally we should be looking up the type using a `DeclRef` that
- // refers to the interface requirement using a `ThisTypeSubstitution`
+ // refers to the interface requirement using a `LookupDeclRef`
// that refers to the satisfying type declaration, and requirement
// checking for non-associated-type requirements should be done *after*
// requirement checking for associated-type requirements.
@@ -3577,7 +3437,7 @@ namespace Slang
// First we need to make sure the associated `Differential` type requirement is satisfied.
bool hasDifferentialAssocType = false;
- for (auto existingEntry : witnessTable->requirementDictionary)
+ for (auto& existingEntry : witnessTable->getRequirementDictionary())
{
if (auto builtinReqAttr = existingEntry.key->findModifier<BuiltinRequirementModifier>())
{
@@ -3726,20 +3586,33 @@ namespace Slang
// If `This` is nested inside a generic, we need to form a complete declref type to the
// newly synthesized method here in order to fill into the witness table.
- // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the
+ // This can be done by obtaining the ThisType witness from requirementDeclRef to get the
// generic substitution for outer generic parameters, and apply it here.
SubstitutionSet substSet;
- if (auto thisTypeSubst = findThisTypeSubstitution(
- requirementDeclRef.getSubst(),
+ if (auto thisTypeWitness = findThisTypeWitness(
+ SubstitutionSet(requirementDeclRef),
as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
{
- if (auto declRefType = as<DeclRefType>(thisTypeSubst->witness->sub))
+ if (auto declRefType = as<DeclRefType>(thisTypeWitness->getSub()))
{
- substSet = declRefType->declRef.getSubst();
+ substSet = SubstitutionSet(declRefType->getDeclRef());
}
}
+ if (auto outerGeneric = GetOuterGeneric(context->parentDecl))
+ {
+ // If the context->parentDecl is not the same as ThisType represented by genApp, then it must be an extension
+ // to ThisType. In this case, we need to form a new GenericAppDeclRef to specailizethe outer parent extension
+ // decl. Note that the extension might be a partial extension with some generic arguments missing, and
+ // we can't support that case right now. For now we can just assume the extension will have the same set
+ // of generic parameters as the target type.
+ auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, outerGeneric);
+ auto specializedParent = m_astBuilder->getGenericAppDeclRef(makeDeclRef(outerGeneric), defaultArgs.getArrayView());
+ auto specializedFunc = m_astBuilder->getMemberDeclRef(specializedParent, synFunc);
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(specializedFunc));
+ return true;
+ }
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getSpecializedDeclRef<Decl>(synFunc, substSet)));
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc)));
return true;
}
@@ -3767,11 +3640,16 @@ namespace Slang
// witness in the table for the requirement, so
// that we can bail out early.
//
- if(witnessTable->requirementDictionary.containsKey(requiredMemberDeclRef.getDecl()))
+ if(witnessTable->getRequirementDictionary().containsKey(requiredMemberDeclRef.getDecl()))
{
return true;
}
+ // The ThisType requirement is always satisfied.
+ if (as<ThisTypeDecl>(requiredMemberDeclRef.getDecl()))
+ {
+ return true;
+ }
// An important exception to the above is that an
// inheritance declaration in the interface is not going
@@ -3987,17 +3865,13 @@ namespace Slang
ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements);
// When comparing things like signatures, we need to do so in the context
- // of a this-type substitution that aligns the signatures in the interface
+ // of a LookupDeclRef that aligns the signatures in the interface
// with those in the concrete type. For example, we need to treat any uses
// of `This` in the interface as equivalent to the concrete type for the
// purpose of signature matching (and similarly for associated types).
//
- ThisTypeSubstitution* thisTypeSubst = m_astBuilder->getOrCreateThisTypeSubstitution(
- superInterfaceDeclRef.getDecl(),
- subTypeConformsToSuperInterfaceWitness,
- superInterfaceDeclRef.getSubst());
-
- auto specializedSuperInterfaceDeclRef = m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(superInterfaceDeclRef.getDecl(), thisTypeSubst);
+ auto thisTypeDeclRef = m_astBuilder->getLookupDeclRef(
+ subTypeConformsToSuperInterfaceWitness, superInterfaceDeclRef.getDecl()->getThisTypeDecl());
bool result = true;
@@ -4029,35 +3903,36 @@ namespace Slang
// constraints and solve for those type variables as part of the
// conformance-checking process.
//
- for(auto requiredMemberDeclRef : getMembers(m_astBuilder, specializedSuperInterfaceDeclRef))
+ for(auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef))
{
- if(!isAssociatedTypeDecl(requiredMemberDeclRef.getDecl()))
+ if(!isAssociatedTypeDecl(requiredMemberDecl.getDecl()))
continue;
-
+ auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef(subTypeConformsToSuperInterfaceWitness, requiredMemberDecl.getDecl());
auto requirementSatisfied = findWitnessForInterfaceRequirement(
context,
subType,
superInterfaceType,
inheritanceDecl,
- specializedSuperInterfaceDeclRef,
+ thisTypeDeclRef,
requiredMemberDeclRef,
witnessTable,
subTypeConformsToSuperInterfaceWitness);
result = result && requirementSatisfied;
}
- for(auto requiredMemberDeclRef : getMembers(m_astBuilder, specializedSuperInterfaceDeclRef))
+ for(auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef))
{
- if(isAssociatedTypeDecl(requiredMemberDeclRef.getDecl()))
+ if(isAssociatedTypeDecl(requiredMemberDecl.getDecl()))
continue;
- if (requiredMemberDeclRef.as<DerivativeRequirementDecl>())
+ if (requiredMemberDecl.as<DerivativeRequirementDecl>())
continue;
+ auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef(subTypeConformsToSuperInterfaceWitness, requiredMemberDecl.getDecl());
auto requirementSatisfied = findWitnessForInterfaceRequirement(
context,
subType,
superInterfaceType,
inheritanceDecl,
- specializedSuperInterfaceDeclRef,
+ thisTypeDeclRef,
requiredMemberDeclRef,
witnessTable,
subTypeConformsToSuperInterfaceWitness);
@@ -4089,25 +3964,27 @@ namespace Slang
// the time we are compiling and handle those, and punt on the larger issue
// for a bit longer.
//
- for(auto candidateExt : getCandidateExtensions(specializedSuperInterfaceDeclRef, this))
+ for(auto candidateExt : getCandidateExtensions(superInterfaceDeclRef, this))
{
// We need to apply the extension to the interface type that our
// concrete type is inheriting from.
//
- Type* targetType = DeclRefType::create(m_astBuilder, specializedSuperInterfaceDeclRef);
- auto extDeclRef = ApplyExtensionToType(candidateExt, targetType);
- if(!extDeclRef)
+ Type* targetType = DeclRefType::create(m_astBuilder, thisTypeDeclRef);
+ auto parentDeclRef = applyExtensionToType(candidateExt, targetType);
+ if(!parentDeclRef)
continue;
// Only inheritance clauses from the extension matter right now.
- for(auto requiredInheritanceDeclRef : getMembersOfType<InheritanceDecl>(m_astBuilder, extDeclRef))
+ for(auto requiredInheritanceDecl : getMembersOfType<InheritanceDecl>(m_astBuilder, candidateExt))
{
+ auto requiredInheritanceDeclRef = m_astBuilder->getLookupDeclRef(
+ subTypeConformsToSuperInterfaceWitness, requiredInheritanceDecl.getDecl());
auto requirementSatisfied = findWitnessForInterfaceRequirement(
context,
subType,
superInterfaceType,
inheritanceDecl,
- specializedSuperInterfaceDeclRef,
+ thisTypeDeclRef,
requiredInheritanceDeclRef,
witnessTable,
subTypeConformsToSuperInterfaceWitness);
@@ -4131,7 +4008,7 @@ namespace Slang
{
if (auto supereclRefType = as<DeclRefType>(superType))
{
- auto superTypeDeclRef = supereclRefType->declRef;
+ auto superTypeDeclRef = supereclRefType->getDeclRef();
if (auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>())
{
// The type is stating that it conforms to an interface.
@@ -4172,11 +4049,11 @@ namespace Slang
if( auto declRefType = as<DeclRefType>(subType) )
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
if (auto superDeclRefType = as<DeclRefType>(superType))
{
- auto superTypeDecl = superDeclRefType->declRef.getDecl();
+ auto superTypeDecl = superDeclRefType->getDeclRef().getDecl();
if (superTypeDecl->findModifier<ComInterfaceAttribute>())
{
// A struct cannot implement a COM Interface.
@@ -4228,10 +4105,7 @@ namespace Slang
// Look at the type being inherited from, and validate
// appropriately.
- DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->create<DeclaredSubtypeWitness>();
- subIsSuperWitness->declRef = makeDeclRef(inheritanceDecl);
- subIsSuperWitness->sub = subType;
- subIsSuperWitness->sup = superType;
+ DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->getDeclaredSubtypeWitness(subType, superType, makeDeclRef(inheritanceDecl));
ConformanceCheckingContext context;
context.conformingType = subType;
@@ -4333,7 +4207,7 @@ namespace Slang
{
return;
}
- auto baseDecl = baseDeclRefType->declRef.getDecl();
+ auto baseDecl = baseDeclRefType->getDeclRef().getDecl();
// Using the parent/child hierarchy baked into `Decl`s we
// can find the modules that contain both the `decl` doing
@@ -4415,7 +4289,7 @@ namespace Slang
continue;
}
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>();
if( !baseInterfaceDeclRef )
{
@@ -4476,7 +4350,7 @@ namespace Slang
continue;
}
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
if( auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>() )
{
}
@@ -4545,7 +4419,7 @@ namespace Slang
continue;
}
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
if (auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>())
{
}
@@ -4594,8 +4468,8 @@ namespace Slang
auto basicType = as<BasicExpressionType>(type);
if(!basicType)
return false;
-
- return isIntegerBaseType(basicType->baseType) || basicType->baseType == BaseType::Bool;
+ auto baseType = basicType->getBaseType();
+ return isIntegerBaseType(baseType) || baseType == BaseType::Bool;
}
bool SemanticsVisitor::isIntValueInRangeOfType(IntegerLiteralValue value, Type* type)
@@ -4604,7 +4478,7 @@ namespace Slang
if (!basicType)
return false;
- switch (basicType->baseType)
+ switch (basicType->getBaseType())
{
case BaseType::UInt8:
return (value >= 0 && value <= std::numeric_limits<uint8_t>::max()) || (value == -1);
@@ -4686,7 +4560,7 @@ namespace Slang
continue;
}
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
if( auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>() )
{
_validateCrossModuleInheritance(decl, inheritanceDecl);
@@ -4790,7 +4664,7 @@ namespace Slang
Decl* tagAssociatedTypeDecl = nullptr;
if(auto enumTypeTypeDeclRefType = dynamicCast<DeclRefType>(enumTypeType))
{
- if(auto enumTypeTypeInterfaceDecl = as<InterfaceDecl>(enumTypeTypeDeclRefType->declRef.getDecl()))
+ if(auto enumTypeTypeInterfaceDecl = as<InterfaceDecl>(enumTypeTypeDeclRefType->getDeclRef().getDecl()))
{
for(auto memberDecl : enumTypeTypeInterfaceDecl->members)
{
@@ -4861,7 +4735,7 @@ namespace Slang
{
if(auto constIntVal = as<ConstantIntVal>(explicitTagVal))
{
- defaultTag = constIntVal->value;
+ defaultTag = constIntVal->getValue();
}
else
{
@@ -5015,7 +4889,7 @@ namespace Slang
bool SemanticsVisitor::doGenericSignaturesMatch(
GenericDecl* left,
GenericDecl* right,
- GenericSubstitution** outSubstRightToLeft)
+ DeclRef<Decl>* outSpecializedRightInner)
{
// Our first goal here is to determine if `left` and
// `right` have equivalent lists of explicit
@@ -5133,9 +5007,9 @@ namespace Slang
// `foo2<T>` so that its constraint, after specialization,
// looks like `T : IFoo`.
//
- auto& substRightToLeft = *outSubstRightToLeft;
- List<Val*> leftArgs = getDefaultSubstitutionArgs(left);
- substRightToLeft = getASTBuilder()->getOrCreateGenericSubstitution(nullptr, right, leftArgs);
+ auto& substInnerRightToLeft = *outSpecializedRightInner;
+ List<Val*> leftArgs = getDefaultSubstitutionArgs(m_astBuilder, this, left);
+ substInnerRightToLeft = m_astBuilder->getGenericAppDeclRef(makeDeclRef(right), leftArgs.getArrayView());
// We should now be able to enumerate the constraints
// on `right` in a way that uses the same type parameters
@@ -5207,7 +5081,9 @@ namespace Slang
// arguments into account.
//
GenericTypeConstraintDecl* leftConstraint = leftConstraints[cc];
- DeclRef<GenericTypeConstraintDecl> rightConstraint = m_astBuilder->getSpecializedDeclRef(rightConstraints[cc], substRightToLeft);
+ auto unspecializedRightConstarintDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(rightConstraints[cc]));
+ DeclRef<GenericTypeConstraintDecl> rightConstraint = substInnerRightToLeft.substitute(
+ m_astBuilder, unspecializedRightConstarintDeclRef).as<GenericTypeConstraintDecl>();
// For now, every constraint has the form `sub : sup`
// to indicate that `sub` must be a subtype of `sup`.
@@ -5277,44 +5153,59 @@ namespace Slang
return true;
}
- List<Val*> SemanticsVisitor::getDefaultSubstitutionArgs(GenericDecl* genericDecl)
+ List<Val*> getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl)
{
List<Val*> args;
- for (auto dd : genericDecl->members)
- {
- if (dd == genericDecl->inner)
- continue;
+ if (astBuilder->m_cachedGenericDefaultArgs.tryGetValue(genericDecl, args))
+ return args;
- if (auto typeParam = as<GenericTypeParamDecl>(dd))
+ for (auto mm : genericDecl->members)
+ {
+ if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm))
{
- auto type = DeclRefType::create(m_astBuilder, makeDeclRef(typeParam));
- args.add(type);
+ args.add(DeclRefType::create(astBuilder, astBuilder->getDirectDeclRef(genericTypeParamDecl)));
}
- else if (auto valueParam = as<GenericValueParamDecl>(dd))
+ else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm))
{
- auto val = m_astBuilder->getOrCreate<GenericParamIntVal>(
- valueParam->getType(),
- DeclRef<VarDeclBase>(valueParam));
- args.add(val);
+ if (semantics)
+ semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup);
+
+ args.add(astBuilder->getOrCreate<GenericParamIntVal>(
+ genericValueParamDecl->getType(),
+ astBuilder->getDirectDeclRef(genericValueParamDecl)));
}
}
- // Add defaults for constraint parameters.
- for (auto dd : genericDecl->members)
+ bool shouldCache = true;
+
+ // create default substitution arguments for constraints
+ for (auto mm : genericDecl->members)
{
- if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd))
+ if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm))
{
- // Convert the constraint to an appropriate witness.
- auto witness = tryGetSubtypeWitness(constraintDecl->sub, constraintDecl->sup);
-
- // Must be non-null since we know there's a constraint. If null, something is
- // very wrong.
- //
- SLANG_ASSERT(witness);
-
+ if (semantics)
+ semantics->ensureDecl(genericTypeConstraintDecl, DeclCheckState::ReadyForReference);
+ auto constraintDeclRef = astBuilder->getDirectDeclRef<GenericTypeConstraintDecl>(genericTypeConstraintDecl);
+ auto witness =
+ astBuilder->getDeclaredSubtypeWitness(
+ getSub(astBuilder, constraintDeclRef),
+ getSup(astBuilder, constraintDeclRef),
+ constraintDeclRef);
+ // TODO: this is an ugly hack to prevent crashing.
+ // In early stages of compilation witness->sub and witness->sup may not be checked yet.
+ // When semanticVisitor is present we have used that to ensure the type is checked.
+ // However due to how the code is written we cannot guarantee semanticVisitor is always available
+ // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be
+ // cached.
+ if (!witness->getSub())
+ shouldCache = false;
args.add(witness);
}
}
+
+ if (shouldCache)
+ astBuilder->m_cachedGenericDefaultArgs[genericDecl] = args;
+
return args;
}
@@ -5442,11 +5333,11 @@ namespace Slang
// Then we will compare the parameter types of `foo2`
// against the specialization `foo1<U>`.
//
- GenericSubstitution* subst = nullptr;
- if(!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &subst))
+ DeclRef<Decl> specializedOldDeclInner;
+ if(!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &specializedOldDeclInner))
return SLANG_OK;
- oldDeclRef = getASTBuilder()->getSpecializedDeclRef(oldDecl, subst);
+ oldDeclRef = specializedOldDeclInner.as<FuncDecl>();
}
// If the parameter signatures don't match, then don't worry
@@ -5869,7 +5760,7 @@ namespace Slang
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
- auto declRef = m_astBuilder->getSpecializedDeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as<CallableDecl>();
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
interfaceDecl->members.add(reqDecl);
@@ -5884,7 +5775,7 @@ namespace Slang
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
// Requirement for backward derivative.
- auto declRef = m_astBuilder->getSpecializedDeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as<CallableDecl>();
auto originalFuncType = getFuncType(m_astBuilder, declRef);
auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType));
{
@@ -5953,7 +5844,7 @@ namespace Slang
IntegerLiteralValue SemanticsVisitor::GetMinBound(IntVal* val)
{
if (auto constantVal = as<ConstantIntVal>(val))
- return constantVal->value;
+ return constantVal->getValue();
// TODO(tfoley): Need to track intervals so that this isn't just a lie...
return 1;
@@ -6024,7 +5915,7 @@ namespace Slang
if (auto targetDeclRefType = as<DeclRefType>(decl->targetType))
{
// Attach our extension to that type as a candidate...
- if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>())
+ if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
{
auto aggTypeDecl = aggTypeDeclRef.getDecl();
@@ -6075,7 +5966,7 @@ namespace Slang
continue;
}
- auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseDeclRef = baseDeclRefType->getDeclRef();
auto baseInterfaceDeclRef = baseDeclRef.as<InterfaceDecl>();
if( !baseInterfaceDeclRef )
{
@@ -6106,9 +5997,9 @@ namespace Slang
// conform to the interface and fill in its
// requirements.
//
- ThisType* thisType = m_astBuilder->create<ThisType>();
- thisType->interfaceDeclRef = interfaceDeclRef;
- return thisType;
+ return DeclRefType::create(
+ m_astBuilder,
+ m_astBuilder->getDirectDeclRef(interfaceDeclRef.getDecl()->getThisTypeDecl()));
}
else if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>())
{
@@ -6159,7 +6050,7 @@ namespace Slang
{
if( auto declRefType = as<DeclRefType>(type) )
{
- return calcThisType(declRefType->declRef);
+ return calcThisType(declRefType->getDeclRef());
}
else
{
@@ -6404,7 +6295,7 @@ namespace Slang
return parentGeneric;
}
- DeclRef<ExtensionDecl> SemanticsVisitor::ApplyExtensionToType(
+ DeclRef<ExtensionDecl> SemanticsVisitor::applyExtensionToType(
ExtensionDecl* extDecl,
Type* type)
{
@@ -6438,15 +6329,15 @@ namespace Slang
if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type))
return DeclRef<ExtensionDecl>();
- auto constraintSubst = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl));
- if (!constraintSubst)
+ auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>());
+ if (!solvedDeclRef)
{
return DeclRef<ExtensionDecl>();
}
// Construct a reference to the extension with our constraint variables
// set as they were found by solving the constraint system.
- extDeclRef = m_astBuilder->getSpecializedDeclRef<Decl>(extDecl, constraintSubst).as<ExtensionDecl>();
+ extDeclRef = solvedDeclRef.as<ExtensionDecl>();
}
// Now extract the target type from our (possibly specialized) extension decl-ref.
@@ -6458,67 +6349,21 @@ namespace Slang
// substitution to the extension decl-ref.
if(auto targetDeclRefType = as<DeclRefType>(targetType))
{
- if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>())
+ if(auto targetInterfaceDeclRef = targetDeclRefType->getDeclRef().as<InterfaceDecl>())
{
// Okay, the target type is an interface.
//
- // Is the type we want to apply to also an interface?
- if(auto appDeclRefType = as<DeclRefType>(type))
+ // Is the type we want to apply to a ThisType?
+ if(auto appDeclRefType = as<ThisType>(type))
{
- if(auto appInterfaceDeclRef = appDeclRefType->declRef.as<InterfaceDecl>())
+ if(auto thisTypeLookupDeclRef = SubstitutionSet(appDeclRefType->getDeclRef()).findLookupDeclRef())
{
- if(appInterfaceDeclRef.getDecl() == targetInterfaceDeclRef.getDecl())
+ if(thisTypeLookupDeclRef->getDecl() == targetInterfaceDeclRef.getDecl())
{
// Looks like we have a match in the types,
- // now let's see if we have a this-type substitution.
- if(auto appThisTypeSubst = as<ThisTypeSubstitution>(appInterfaceDeclRef.getSubst()))
- {
- if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl())
- {
- // The type we want to apply to has a this-type substitution,
- // and (by construction) the target type currently does not.
- //
- SLANG_ASSERT(!as<ThisTypeSubstitution>(targetInterfaceDeclRef.getSubst()));
-
- // We will create a new substitution to apply to the target type.
- ThisTypeSubstitution* newTargetSubst = m_astBuilder->getOrCreateThisTypeSubstitution(
- appThisTypeSubst->interfaceDecl,
- appThisTypeSubst->witness,
- targetInterfaceDeclRef.getSubst());
-
- targetType = DeclRefType::create(m_astBuilder,
- m_astBuilder->getSpecializedDeclRef<InterfaceDecl>(targetInterfaceDeclRef.getDecl(), newTargetSubst));
-
- // Note: we are constructing a this-type substitution that
- // we will apply to the extension declaration as well.
- // This is not strictly allowed by our current representation
- // choices, but we need it in order to make sure that
- // references to the target type of the extension
- // declaration have a chance to resolve the way we want them to.
-
- ThisTypeSubstitution* newExtSubst = m_astBuilder->getOrCreateThisTypeSubstitution(
- appThisTypeSubst->interfaceDecl,
- appThisTypeSubst->witness,
- extDeclRef.getSubst());
-
- extDeclRef = m_astBuilder->getSpecializedDeclRef<ExtensionDecl>(
- extDeclRef.getDecl(),
- newExtSubst);
-
- // TODO: Ideally we should also apply the chosen specialization to
- // the decl-ref for the extension, so that subsequent lookup through
- // the members of this extension will retain that substitution and
- // be able to apply it.
- //
- // E.g., if an extension method returns a value of an associated
- // type, then we'd want that to become specialized to a concrete
- // type when using the extension method on a value of concrete type.
- //
- // The challenge here that makes me reluctant to just staple on
- // such a substitution is that it wouldn't follow our implicit
- // rules about where `ThisTypeSubstitution`s can appear.
- }
- }
+ // now let's see if `type`'s declref starts with a Lookup.
+ targetType = type;
+ extDeclRef = m_astBuilder->getLookupDeclRef(thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl());
}
}
}
@@ -6641,7 +6486,6 @@ namespace Slang
{
if( auto namespaceDeclRef = declRefExpr->declRef.as<NamespaceDeclBase>() )
{
- SLANG_ASSERT(!namespaceDeclRef.getSubst());
namespaceDecl = namespaceDeclRef.getDecl();
}
}
@@ -7007,7 +6851,7 @@ namespace Slang
// the extension to the type and see if we succeed in
// making a match.
//
- auto extDeclRef = ApplyExtensionToType(semantics, extDecl, aggType);
+ auto extDeclRef = applyExtensionToType(semantics, extDecl, aggType);
if(!extDeclRef)
continue;
@@ -7065,8 +6909,8 @@ namespace Slang
{
if (auto andType = as<AndType>(type))
{
- _getCanonicalConstraintTypes(outTypeList, andType->left);
- _getCanonicalConstraintTypes(outTypeList, andType->right);
+ _getCanonicalConstraintTypes(outTypeList, andType->getLeft());
+ _getCanonicalConstraintTypes(outTypeList, andType->getRight());
}
else
{
@@ -7087,7 +6931,7 @@ namespace Slang
assert(
genericTypeConstraintDecl.getDecl()->sub.type->astNodeType ==
ASTNodeType::DeclRefType);
- auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->declRef.getDecl();
+ auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl();
List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl);
assert(constraintTypes);
constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type);
@@ -7107,42 +6951,6 @@ namespace Slang
return result;
}
- Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val)
- {
- if (auto declRefType = as<DeclRefType>(val))
- {
- if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef))
- return as<Type>(concreteType);
- for (auto subst = declRefType->declRef.getSubst(); subst; subst=subst->getOuter())
- {
- if (auto genericSubst = as<GenericSubstitution>(subst))
- {
- ShortList<Val*> newArgs;
- for (auto& arg : genericSubst->getArgs())
- {
- arg = resolveVal(arg);
- SLANG_RELEASE_ASSERT(arg);
- }
- }
- }
- }
- else if (auto subtypeWitness = as<SubtypeWitness>(val))
- {
- auto sub = as<Type>(resolveVal(subtypeWitness->sub));
- auto sup = as<Type>(resolveVal(subtypeWitness->sup));
- if (sub && sup)
- {
- if (sub != subtypeWitness->sub || sup != subtypeWitness->sup)
- {
- auto newVal = tryGetSubtypeWitness(as<Type>(sub), as<Type>(sup));
- if (newVal)
- val = newVal;
- }
- }
- }
- return val;
- }
-
struct ArgsWithDirectionInfo
{
List<Expr*> args;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3c90c3ed8..e343e3113 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -22,7 +22,7 @@ namespace Slang
DeclRefType* SemanticsVisitor::getExprDeclRefType(Expr * expr)
{
if (auto typetype = as<TypeType>(expr->type))
- return dynamicCast<DeclRefType>(typetype->type);
+ return dynamicCast<DeclRefType>(typetype->getType());
else
return as<DeclRefType>(expr->type);
}
@@ -154,10 +154,8 @@ namespace Slang
//
return maybeMoveTemp(expr, [&](DeclRef<VarDeclBase> varDeclRef)
{
- ExtractExistentialType* openedType = m_astBuilder->create<ExtractExistentialType>();
- openedType->declRef = varDeclRef;
- openedType->originalInterfaceType = expr->type.type;
- openedType->originalInterfaceDeclRef = interfaceDeclRef;
+ ExtractExistentialType* openedType = m_astBuilder->getOrCreate<ExtractExistentialType>(
+ varDeclRef, expr->type.type, interfaceDeclRef);
ExtractExistentialValueExpr* openedValue = m_astBuilder->create<ExtractExistentialValueExpr>();
openedValue->declRef = varDeclRef;
@@ -202,29 +200,9 @@ namespace Slang
if(auto declRefType = as<DeclRefType>(exprType))
{
- if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>())
+ if(auto interfaceDeclRef = declRefType->getDeclRef().as<InterfaceDecl>())
{
- // Is there an this-type substitution being applied, so that
- // we are referencing the interface type through a concrete
- // type (e.g., a type parameter constrained to this interface)?
- //
- // Because of the way that substitutions need to mirror the nesting
- // hierarchy of declarations, any this-type substitution pertaining
- // to the chosen interface decl must be the first substitution on
- // the list (which is a linked list from the "inside" out).
- //
- auto thisTypeSubst = as<ThisTypeSubstitution>(interfaceDeclRef.getSubst());
- if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.getDecl())
- {
- // This isn't really an existential type, because somebody
- // has already filled in a this-type substitution.
- }
- else
- {
- // Okay, here is the case that matters.
- //
- return openExistential(expr, interfaceDeclRef);
- }
+ return openExistential(expr, interfaceDeclRef);
}
}
@@ -317,7 +295,7 @@ namespace Slang
// actually names a type, because in that case we are doing
// a static member reference.
//
- if (auto typeType = as<TypeType>(baseExpr->type))
+ if (auto typeType = as<TypeType>(baseExpr->type->getCanonicalType()))
{
// Before forming the reference, we will check if the
// member being referenced can even be used as a static
@@ -340,7 +318,7 @@ namespace Slang
getSink()->diagnose(
loc,
Diagnostics::staticRefToNonStaticMember,
- typeType->type,
+ typeType->getType(),
declRef.getName());
}
@@ -493,9 +471,9 @@ namespace Slang
case LookupResultItem::Breadcrumb::Kind::SuperType:
{
auto witness = as<SubtypeWitness>(breadcrumb->val);
- if (auto subDeclRefType = as<DeclRefType>(witness->sub))
+ if (auto subDeclRefType = as<DeclRefType>(witness->getSub()))
{
- if (!as<InterfaceDecl>(subDeclRefType->declRef.getDecl()))
+ if (!as<InterfaceDecl>(subDeclRefType->getDeclRef().getDecl()))
{
// Store the inner most concrete super type.
subType = subDeclRefType;
@@ -515,10 +493,13 @@ namespace Slang
return nullptr;
// Don't synthesize for generic parameters.
- auto parent = as<AggTypeDecl>(subType->declRef.getDecl());
+ auto parent = as<AggTypeDecl>(subType->getDeclRef().getDecl());
if (!parent)
return nullptr;
+ // Don't synthesize for ThisType.
+ if (as<ThisTypeDecl>(subType->getDeclRef().getDecl()))
+ return nullptr;
// If we reach here, we are expecting a synthesized decl defined in `subType`.
// Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl
@@ -607,7 +588,7 @@ namespace Slang
//
auto witness = as<SubtypeWitness>(breadcrumb->val);
SLANG_ASSERT(witness);
- auto expr = createCastToSuperTypeExpr(witness->sup, bb, witness);
+ auto expr = createCastToSuperTypeExpr(witness->getSup(), bb, witness);
// Note that we allow a cast of an l-value to
// be used as an l-value here because it enables
@@ -926,7 +907,7 @@ namespace Slang
if (auto declRefType = as<DeclRefType>(type))
{
- if (auto builtinRequirement = declRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>())
+ if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>())
{
if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
{
@@ -935,6 +916,7 @@ namespace Slang
return type;
}
}
+ type = resolveType(type);
if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())))
{
auto diffTypeLookupResult = lookUpMember(
@@ -964,10 +946,10 @@ namespace Slang
auto diffTypeExpr = ConstructLookupResultExpr(
diffTypeLookupResult.item,
baseTypeExpr,
- declRefType->declRef.getLoc(),
+ declRefType->getDeclRef().getLoc(),
baseTypeExpr);
- return ExtractTypeFromTypeRepr(diffTypeExpr);
+ return resolveType(ExtractTypeFromTypeRepr(diffTypeExpr));
}
}
}
@@ -991,7 +973,7 @@ namespace Slang
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
if (witness)
{
- m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.addIfNotExists(type->declRef, witness);
+ m_parentDifferentiableAttr->addType(type->getDeclRef(), witness);
}
}
@@ -1048,7 +1030,7 @@ namespace Slang
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}
- if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
{
foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
{
@@ -1061,23 +1043,13 @@ namespace Slang
maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType);
});
}
- for (auto subst = declRefType->declRef.getSubst(); subst; subst = subst->getOuter())
- {
- if (auto genSubst = as<GenericSubstitution>(subst))
+ SubstitutionSet(declRefType->getDeclRef()).forEachSubstitutionArg([&](Val* arg)
{
- for (auto arg : genSubst->getArgs())
+ if (auto typeArg = as<Type>(arg))
{
- if (auto typeArg = as<Type>(arg))
- {
- maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg);
- }
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg);
}
- }
- else if (auto thisSubst = as<ThisTypeSubstitution>(subst))
- {
- maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub);
- }
- }
+ });
return;
}
}
@@ -1302,7 +1274,7 @@ namespace Slang
if (auto constArgVal = as<ConstantIntVal>(argVal))
{
- constArgVals[a] = constArgVal->value;
+ constArgVals[a] = constArgVal->getValue();
}
else
{
@@ -1366,12 +1338,13 @@ namespace Slang
|| opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") ||
opName == getName("?:") || opName == getName("<<") || opName == getName(">>"))
{
- auto result = m_astBuilder->create<FuncCallIntVal>(invokeExpr.getExpr()->type.type);
- result->args.addRange(argVals, argCount);
- result->funcDeclRef = funcDeclRef;
- result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
- m_astBuilder, funcDeclRefExpr.getSubsts()));
- SLANG_RELEASE_ASSERT(result->funcType);
+ auto result = m_astBuilder->getOrCreate<FuncCallIntVal>(
+ invokeExpr.getExpr()->type.type,
+ funcDeclRef,
+ as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
+ m_astBuilder, funcDeclRefExpr.getSubsts())),
+ makeArrayView(argVals, argCount));
+ SLANG_RELEASE_ASSERT(result->getFuncType());
return result;
}
return nullptr;
@@ -1507,18 +1480,14 @@ namespace Slang
if (isInterfaceRequirement(decl))
{
- for (auto subst = declRef.getSubst(); subst; subst = subst->getOuter())
- {
- if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst))
- {
- auto val = WitnessLookupIntVal::tryFold(
- m_astBuilder,
- thisTypeSubst->witness,
- decl,
- declRef.substitute(m_astBuilder, decl->type.type));
- return as<IntVal>(val);
- }
- }
+ auto witness = findThisTypeWitness(SubstitutionSet(declRef), as<InterfaceDecl>(decl->parentDecl));
+
+ auto val = WitnessLookupIntVal::tryFold(
+ m_astBuilder,
+ witness,
+ decl,
+ declRef.substitute(m_astBuilder, decl->type.type));
+ return as<IntVal>(val);
}
if (!getInitExpr(m_astBuilder, declRef))
@@ -1785,7 +1754,7 @@ namespace Slang
getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported);
}
- auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type));
+ auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->getType()));
auto arrayType = getArrayType(
m_astBuilder,
elementType,
@@ -1804,7 +1773,7 @@ namespace Slang
{
return CheckSimpleSubscriptExpr(
subscriptExpr,
- vecType->elementType);
+ vecType->getElementType());
}
else if (auto matType = as<MatrixExpressionType>(baseType))
{
@@ -1975,8 +1944,8 @@ namespace Slang
if (basicTypeA && basicTypeB)
{
- const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->baseType);
- const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->baseType);
+ const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->getBaseType());
+ const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->getBaseType());
// TODO(JS): Initially this tries to limit where LValueImplict casts happen.
// We could in principal allow different sizes, as long as we converted to a temprorary
@@ -2021,7 +1990,7 @@ namespace Slang
// if this is still an invoke expression, test arguments passed to inout/out parameter are LValues
if(auto funcType = as<FuncType>(invoke->functionExpr->type))
{
- if (!funcType->errorType->equals(m_astBuilder->getBottomType()))
+ if (!funcType->getErrorType()->equals(m_astBuilder->getBottomType()))
{
// If the callee throws, make sure we are inside a try clause.
if (m_enclosingTryClauseType == TryClauseType::None)
@@ -2230,7 +2199,7 @@ namespace Slang
return result;
}
- Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr)
+ Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
{
// check the base expression first
expr->functionExpr = CheckTerm(expr->functionExpr);
@@ -2312,6 +2281,7 @@ namespace Slang
auto lookupResult = lookUp(
m_astBuilder,
this, expr->name, expr->scope);
+
if (expr->name == getSession()->getCompletionRequestTokenName())
{
auto scopeKind = CompletionSuggestions::ScopeKind::Expr;
@@ -2357,7 +2327,7 @@ namespace Slang
if (auto modifiedType = as<ModifiedType>(primalType))
{
if (modifiedType->findModifier<NoDiffModifierVal>())
- return modifiedType->base;
+ return modifiedType->getBase();
}
// Get a reference to the builtin 'IDifferentiable' interface
@@ -2379,23 +2349,23 @@ namespace Slang
// Resolve JVP type here.
// Note that this type checking needs to be in sync with
// the auto-generation logic in slang-ir-jvp-diff.cpp
-
- FuncType* jvpType = m_astBuilder->create<FuncType>();
+ List<Type*> paramTypes;
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = getDifferentialPairType(originalType->getResultType());
+ auto resultType = getDifferentialPairType(originalType->getResultType());
// No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
- jvpType->errorType = originalType->errorType;
+ SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType()));
+ auto errorType = originalType->getErrorType();
for (Index i = 0; i < originalType->getParamCount(); i++)
{
if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i)))
- jvpType->paramTypes.add(jvpParamType);
+ paramTypes.add(jvpParamType);
}
+ FuncType* jvpType = m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
return jvpType;
}
@@ -2405,16 +2375,15 @@ namespace Slang
// Resolve backward diff type here.
// Note that this type checking needs to be in sync with
// the auto-generation logic in slang-ir-jvp-diff.cpp
-
- FuncType* type = m_astBuilder->create<FuncType>();
+ List<Type*> paramTypes;
// The backward diff return type is void
//
- type->resultType = m_astBuilder->getVoidType();
+ auto resultType = m_astBuilder->getVoidType();
// No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
- type->errorType = originalType->errorType;
+ SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType()));
+ auto errorType = originalType->getErrorType();
for (Index i = 0; i < originalType->getParamCount(); i++)
{
@@ -2424,7 +2393,7 @@ namespace Slang
tryGetDifferentialType(m_astBuilder, outType->getValueType());
if (diffElementType)
{
- type->paramTypes.add(diffElementType);
+ paramTypes.add(diffElementType);
}
else
{
@@ -2447,16 +2416,16 @@ namespace Slang
derivType = inoutType->getValueType();
}
}
- type->paramTypes.add(derivType);
+ paramTypes.add(derivType);
}
}
// Last parameter is the initial derivative of the original return type
- auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType);
+ auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->getResultType());
if (dOutType)
- type->paramTypes.add(dOutType);
+ paramTypes.add(dOutType);
- return type;
+ return m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
}
struct HigherOrderInvokeExprCheckingActions
@@ -2473,9 +2442,8 @@ namespace Slang
if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>())
{
// Get inner function
- DeclRef<Decl> unspecializedInnerRef = astBuilder->getSpecializedDeclRef<Decl>(
- getInner(baseFuncGenericDeclRef),
- baseFuncGenericDeclRef.getSubst());
+ DeclRef<Decl> unspecializedInnerRef = createDefaultSubstitutionsIfNeeded(astBuilder, semantics,
+ astBuilder->getMemberDeclRef(baseFuncGenericDeclRef, getInner(baseFuncGenericDeclRef)));
auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>();
if (!callableDeclRef)
return nullptr;
@@ -2677,10 +2645,10 @@ namespace Slang
return false;
if (!isIntegerBaseType(getVectorBaseType(vectorType)))
return false;
- auto constElementCount = as<ConstantIntVal>(vectorType->elementCount);
+ auto constElementCount = as<ConstantIntVal>(vectorType->getElementCount());
if (!constElementCount)
return false;
- return constElementCount->value == 3;
+ return constElementCount->getValue() == 3;
};
expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this);
if (!isInt3Type(expr->threadGroupSize->type.type))
@@ -2836,7 +2804,7 @@ namespace Slang
//
if( auto declRefType = as<DeclRefType>(typeExp.type) )
{
- if(const auto structDeclRef = as<StructDecl>(declRefType->declRef))
+ if(const auto structDeclRef = as<StructDecl>(declRefType->getDeclRef()))
{
if( expr->arguments.getCount() == 1 )
{
@@ -3051,7 +3019,7 @@ namespace Slang
auto baseType = expr->type;
if (auto pointerLikeType = as<PointerLikeType>(baseType))
{
- auto elementType = QualType(pointerLikeType->elementType);
+ auto elementType = QualType(pointerLikeType->getElementType());
elementType.isLeftValue = baseType.isLeftValue;
auto derefExpr = m_astBuilder->create<DerefExpr>();
@@ -3230,7 +3198,7 @@ namespace Slang
if (auto constantColCount = as<ConstantIntVal>(baseColCount))
{
return CheckMatrixSwizzleExpr(memberRefExpr, baseElementType,
- constantRowCount->value, constantColCount->value);
+ constantRowCount->getValue(), constantColCount->getValue());
}
}
getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on matrix of unknown size");
@@ -3350,7 +3318,7 @@ namespace Slang
{
if (auto constantElementCount = as<ConstantIntVal>(baseElementCount))
{
- return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value);
+ return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->getValue());
}
else
{
@@ -3381,6 +3349,7 @@ namespace Slang
m_astBuilder,
this,
expr->name,
+ namespaceDeclRef.getDecl(),
namespaceDeclRef);
if (!lookupResult.isValid())
{
@@ -3406,7 +3375,7 @@ namespace Slang
//
// TODO: this duplicates a *lot* of logic with the case below.
// We need to fix that.
- auto type = typeType->type;
+ auto type = typeType->getType();
if (as<ErrorType>(type))
{
@@ -3577,7 +3546,7 @@ namespace Slang
for (auto lookupResult : overloadedExpr->lookupResult2)
{
bool shouldRemove = false;
- if (lookupResult.declRef.getParent(m_astBuilder).as<InterfaceDecl>())
+ if (lookupResult.declRef.getParent().as<InterfaceDecl>())
{
shouldRemove = true;
}
@@ -3627,8 +3596,8 @@ namespace Slang
{
return CheckSwizzleExpr(
expr,
- baseVecType->elementType,
- baseVecType->elementCount);
+ baseVecType->getElementType(),
+ baseVecType->getElementCount());
}
else if(auto baseScalarType = as<BasicExpressionType>(baseType))
{
@@ -3893,7 +3862,7 @@ namespace Slang
types.reserve(expr->parameters.getCount());
for(const auto& t : expr->parameters)
types.add(t.type);
- auto funcType = m_astBuilder->getFuncType(std::move(types), expr->result.type);
+ auto funcType = m_astBuilder->getFuncType(types.getArrayView(), expr->result.type);
expr->type = m_astBuilder->getTypeType(funcType);
return expr;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 575d4aff7..46cc329a9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -70,7 +70,7 @@ namespace Slang
{
if (auto basicType = as<BasicExpressionType>(typeIn))
{
- auto rs = makeBasicTypeKey(basicType->baseType);
+ auto rs = makeBasicTypeKey(basicType->getBaseType());
if (auto constInt = as<IntegerLiteralExpr>(exprIn))
{
if (constInt->value < 0)
@@ -83,11 +83,11 @@ namespace Slang
}
else if (auto vectorType = as<VectorExpressionType>(typeIn))
{
- if (auto elemCount = as<ConstantIntVal>(vectorType->elementCount))
+ if (auto elemCount = as<ConstantIntVal>(vectorType->getElementCount()))
{
- if( auto elemBasicType = as<BasicExpressionType>(vectorType->elementType) )
+ if( auto elemBasicType = as<BasicExpressionType>(vectorType->getElementType()) )
{
- return makeBasicTypeKey(elemBasicType->baseType, elemCount->value);
+ return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue());
}
}
}
@@ -99,7 +99,7 @@ namespace Slang
{
if( auto elemBasicType = as<BasicExpressionType>(matrixType->getElementType()) )
{
- return makeBasicTypeKey(elemBasicType->baseType, elemCount1->value, elemCount2->value);
+ return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue());
}
}
}
@@ -246,7 +246,7 @@ namespace Slang
// When required, a candidate can store a pre-checked list of
// arguments so that we don't have to repeat work across checking
// phases. Currently this is only needed for generics.
- Substitutions* subst = nullptr;
+ SubstitutionSet subst;
};
struct TypeCheckingCache
@@ -614,7 +614,7 @@ namespace Slang
InheritanceInfo getInheritanceInfo(DeclRef<ExtensionDecl> const& extension);
/// Try get subtype witness from cache, returns true if cache contains a result for the query.
- bool tryGetSubtypeWitness(Type* sub, Type* sup, SubtypeWitness*& outWitness)
+ bool tryGetSubtypeWitnessFromCache(Type* sub, Type* sup, SubtypeWitness*& outWitness)
{
auto pair = TypePair{ sub, sup };
return m_mapTypePairToSubtypeWitness.tryGetValue(pair, outWitness);
@@ -997,6 +997,21 @@ namespace Slang
void diagnoseDeprecatedDeclRefUsage(DeclRef<Decl> declRef, SourceLoc loc, Expr* originalExpr);
+ DeclRef<Decl> getDefaultDeclRef(Decl* decl)
+ {
+ return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl));
+ }
+
+ DeclRef<Decl> getSpecializedDeclRef(DeclRef<Decl> declToSpecialize, DeclRef<Decl> declRefWithSpecializationArgs)
+ {
+ return declRefWithSpecializationArgs.substitute(m_astBuilder, declToSpecialize);
+ }
+
+ DeclRef<Decl> getSpecializedDeclRef(Decl* declToSpecialize, DeclRef<Decl> declRefWithSpecializationArgs)
+ {
+ return declRefWithSpecializationArgs.substitute(m_astBuilder, getDefaultDeclRef(declToSpecialize));
+ }
+
DeclRefExpr* ConstructDeclRefExpr(
DeclRef<Decl> declRef,
Expr* baseExpr,
@@ -1026,6 +1041,18 @@ namespace Slang
SourceLoc loc,
Expr* originalExpr);
+
+ Val* resolveVal(Val* val)
+ {
+ if (!val) return nullptr;
+ return val->resolve();
+ }
+ Type* resolveType(Type* type)
+ {
+ return (Type*)resolveVal(type);
+ }
+ DeclRef<Decl> resolveDeclRef(DeclRef<Decl> declRef);
+
/// Attempt to "resolve" an overloaded `LookupResult` to only include the "best" results
LookupResult resolveOverloadedLookup(LookupResult const& lookupResult);
@@ -1651,12 +1678,12 @@ namespace Slang
List<GenericTypeConstraintDecl*>& outConstraints);
/// Determine if `left` and `right` have matching generic signatures.
- /// If they do, then outputs a substitution to `ioSubstRightToLeft` that
- /// can be used to specialize `right` to the parameters of `left`.
+ /// If they do, then outputs a specialized declRef to `ioSubstRightToLeft` that
+ /// represents a reference to `right` with the parameters of `left`.
bool doGenericSignaturesMatch(
GenericDecl* left,
GenericDecl* right,
- GenericSubstitution** outSubstRightToLeft);
+ DeclRef<Decl>* outSpecializedRightInner);
// Check if two functions have the same signature for the purposes
// of overload resolution.
@@ -1664,9 +1691,6 @@ namespace Slang
DeclRef<FuncDecl> fst,
DeclRef<FuncDecl> snd);
- List<Val*> getDefaultSubstitutionArgs(
- GenericDecl* genericDecl);
-
Result checkRedeclaration(Decl* newDecl, Decl* oldDecl);
Result checkFuncRedeclaration(FuncDecl* newDecl, FuncDecl* oldDecl);
void checkForRedeclaration(Decl* decl);
@@ -1901,12 +1925,13 @@ namespace Slang
// The `varSubst` argument provides the list of constraint
// variables that were created for the system.
//
- // Returns a new substitution representing the values that
+ // Returns a new declref to the inner decl of `genericDeclRef`,
+ // representing the specialized generic with the values
// we solved for along the way.
- SubstitutionSet trySolveConstraintSystem(
+ DeclRef<Decl> trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
- GenericSubstitution* substWithKnownGenericArgs = nullptr);
+ ArrayView<Val*> knownGenericArgs);
// State related to overload resolution for a call
@@ -2033,7 +2058,7 @@ namespace Slang
Expr* createGenericDeclRef(
Expr* baseExpr,
Expr* originalExpr,
- GenericSubstitution* subst);
+ SubstitutionSet substSet);
// Take an overload candidate that previously got through
// `TryCheckOverloadCandidate` above, and try to finish
@@ -2112,15 +2137,15 @@ namespace Slang
Val* fst,
Val* snd);
- bool tryUnifySubstitutions(
+ bool tryUnifyDeclRef(
ConstraintSystem& constraints,
- Substitutions* fst,
- Substitutions* snd);
+ DeclRefBase* fst,
+ DeclRefBase* snd);
- bool tryUnifyGenericSubstitutions(
- ConstraintSystem& constraints,
- GenericSubstitution* fst,
- GenericSubstitution* snd);
+ bool tryUnifyGenericAppDeclRef(
+ ConstraintSystem& constraints,
+ GenericAppDeclRef* fst,
+ GenericAppDeclRef* snd);
bool TryUnifyTypeParam(
ConstraintSystem& constraints,
@@ -2153,7 +2178,7 @@ namespace Slang
Type* snd);
// Is the candidate extension declaration actually applicable to the given type
- DeclRef<ExtensionDecl> ApplyExtensionToType(
+ DeclRef<ExtensionDecl> applyExtensionToType(
ExtensionDecl* extDecl,
Type* type);
@@ -2166,7 +2191,7 @@ namespace Slang
DeclRef<Decl> inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs,
+ ArrayView<Val*> knownGenericArgs,
List<Type*> *innerParameterTypes = nullptr);
void AddTypeOverloadCandidates(
@@ -2209,7 +2234,7 @@ namespace Slang
void addOverloadCandidatesForCallToGeneric(
LookupResultItem genericItem,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs = nullptr);
+ ArrayView<Val*> knownGenericArgs);
/// Check a generic application where the operands have already been checked.
Expr* checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr);
@@ -2283,7 +2308,7 @@ namespace Slang
visitor->ensureDecl(decl, state);
}
- DeclRef<ExtensionDecl> ApplyExtensionToType(
+ DeclRef<ExtensionDecl> applyExtensionToType(
SemanticsVisitor* semantics,
ExtensionDecl* extDecl,
Type* type);
@@ -2318,8 +2343,6 @@ namespace Slang
Expr* visitSharedTypeExpr(SharedTypeExpr* expr);
- Expr* visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr);
-
Expr* visitInvokeExpr(InvokeExpr *expr);
Expr* visitSelectExpr(SelectExpr* expr);
diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp
index 5a6adbae5..5fff47cf1 100644
--- a/source/slang/slang-check-inheritance.cpp
+++ b/source/slang/slang-check-inheritance.cpp
@@ -220,7 +220,7 @@ namespace Slang
DeclRef<Decl> baseDeclRef;
if (auto baseDeclRefType = as<DeclRefType>(baseType))
{
- baseDeclRef = baseDeclRefType->declRef;
+ baseDeclRef = baseDeclRefType->getDeclRef();
}
addDirectBaseFacet(
@@ -239,9 +239,9 @@ namespace Slang
// In the case where we have an aggregate type or `extension`
// declaration, we can use the explicit list of direct bases.
//
- for (auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(_getASTBuilder(), aggTypeDeclBaseRef))
+ for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef))
{
- visitor.ensureDecl(inheritanceDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl);
+ visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl);
// Note: In certain cases something takes the *syntactic* form of an inheritance
// clause, but it is not actually something that should be treated as implying
@@ -251,38 +251,20 @@ namespace Slang
// We skip such pseudo-inheritance relationships for the purposes of determining
// the linearized list of bases.
//
- if (inheritanceDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>())
+ if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>())
continue;
// The base type and subtype witness can easily be determined
// using the `InheritanceDecl`.
//
- auto baseType = getSup(astBuilder, inheritanceDeclRef);
+ auto baseType = getSup(astBuilder, typeConstraintDeclRef);
auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness(
selfType,
baseType,
- inheritanceDeclRef);
+ typeConstraintDeclRef);
addDirectBaseType(baseType, satisfyingWitness);
}
-
- // In the case of an `associatedtype`, the constraints on the associated
- // type are encoded as `GenericTypeConstraintDecl`s instead of `InheritanceDecl`s.
- //
- // TOD(tfoley): Can we try to unify the representations of these to avoid having
- // to iterate twice?
- //
- for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, aggTypeDeclBaseRef))
- {
- visitor.ensureDecl(constraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl);
-
- auto baseType = getSup(astBuilder, constraintDeclRef);
- auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness(
- selfType,
- baseType,
- constraintDeclRef);
- addDirectBaseType(baseType, satisfyingWitness);
- }
}
else if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>())
{
@@ -296,7 +278,7 @@ namespace Slang
// representation would need to take into account canonicalization of
// constraints.
- auto genericDeclRef = genericTypeParamDeclRef.getParent(astBuilder).as<GenericDecl>();
+ auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>();
SLANG_ASSERT(genericDeclRef);
ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric);
@@ -317,7 +299,7 @@ namespace Slang
auto subDeclRefType = as<DeclRefType>(subType);
if (!subDeclRefType)
continue;
- if (subDeclRefType->declRef != genericTypeParamDeclRef)
+ if (subDeclRefType->getDeclRef() != genericTypeParamDeclRef)
continue;
// Because the constraint is a declared inheritance relationship,
@@ -376,7 +358,7 @@ namespace Slang
// the extension to the type and see if we succeed in
// making a match.
//
- auto extDeclRef = ApplyExtensionToType(&visitor, extDecl, selfType);
+ auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType);
if (!extDeclRef)
continue;
@@ -858,15 +840,15 @@ namespace Slang
// bottleneck through the logic that gets shared between
// type and `extension` declarations.
//
- return _getInheritanceInfo(declRefType->declRef, declRefType);
+ return _getInheritanceInfo(declRefType->getDeclRef(), declRefType);
}
else if (auto conjunctionType = as<AndType>(type))
{
// In this case, we have a type of the form `L & R`,
// such that it is a subtype of both `L` and `R`.
//
- auto leftType = conjunctionType->left;
- auto rightType = conjunctionType->right;
+ auto leftType = conjunctionType->getLeft();
+ auto rightType = conjunctionType->getRight();
// The linearized inheritance list for the conjunction
// must include all the facets from the lists for `L`
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index ce80d0002..ab34c83dd 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -310,12 +310,12 @@ namespace Slang
{
return false;
}
- if (intValue->value < 1)
+ if (intValue->getValue() < 1)
{
- getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->value);
+ getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->getValue());
return false;
}
- value = int32_t(intValue->value);
+ value = int32_t(intValue->getValue());
}
values[i] = value;
}
@@ -341,13 +341,13 @@ namespace Slang
}
const IRIntegerValue kMaxAnyValueSize = 0x7FFF;
- if (value->value > kMaxAnyValueSize)
+ if (value->getValue() > kMaxAnyValueSize)
{
getSink()->diagnose(anyValueSizeAttr->loc, Diagnostics::anyValueSizeExceedsLimit, kMaxAnyValueSize);
return false;
}
- anyValueSizeAttr->size = int32_t(value->value);
+ anyValueSizeAttr->size = int32_t(value->getValue());
}
else if (auto bindingAttr = as<GLSLBindingAttribute>(attr))
{
@@ -369,8 +369,8 @@ namespace Slang
return false;
}
- bindingAttr->binding = int32_t(binding->value);
- bindingAttr->set = int32_t(set->value);
+ bindingAttr->binding = int32_t(binding->getValue());
+ bindingAttr->set = int32_t(set->getValue());
}
else if (auto simpleLayoutAttr = as<GLSLSimpleIntegerLayoutAttribute>(attr))
{
@@ -388,7 +388,7 @@ namespace Slang
return false;
}
- simpleLayoutAttr->value = int32_t(value->value);
+ simpleLayoutAttr->value = int32_t(value->getValue());
}
else if (auto maxVertexCountAttr = as<MaxVertexCountAttribute>(attr))
{
@@ -397,7 +397,7 @@ namespace Slang
if (!val) return false;
- maxVertexCountAttr->value = (int32_t)val->value;
+ maxVertexCountAttr->value = (int32_t)val->getValue();
}
else if (auto instanceAttr = as<InstanceAttribute>(attr))
{
@@ -406,7 +406,7 @@ namespace Slang
if (!val) return false;
- instanceAttr->value = (int32_t)val->value;
+ instanceAttr->value = (int32_t)val->getValue();
}
else if (auto entryPointAttr = as<EntryPointAttribute>(attr))
{
@@ -486,7 +486,7 @@ namespace Slang
//IntVal* outIntVal;
if (auto cInt = checkConstantEnumVal(attr->args[0]))
{
- targetClassId = (uint32_t)(cInt->value);
+ targetClassId = (uint32_t)(cInt->getValue());
}
else
{
@@ -515,7 +515,7 @@ namespace Slang
}
auto cint = checkConstantIntVal(attr->args[0]);
if (cint)
- forceUnrollAttr->maxIterations = (int32_t)cint->value;
+ forceUnrollAttr->maxIterations = (int32_t)cint->getValue();
}
else if (auto maxItersAttrs = as<MaxItersAttribute>(attr))
{
@@ -528,7 +528,7 @@ namespace Slang
auto cint = checkConstantIntVal(attr->args[0]);
if (cint)
{
- maxItersAttrs->value = (int32_t) cint->value;
+ maxItersAttrs->value = (int32_t) cint->getValue();
}
}
}
@@ -547,10 +547,12 @@ namespace Slang
bool typeChecked = false;
if (auto basicType = as<BasicExpressionType>(paramDecl->getType()))
{
- if (basicType->baseType == BaseType::Int)
+ if (basicType->getBaseType() == BaseType::Int)
{
if (auto cint = checkConstantIntVal(arg))
{
+ for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++)
+ attr->intArgVals.add(nullptr);
attr->intArgVals[(uint32_t)paramIndex] = cint;
}
typeChecked = true;
@@ -578,7 +580,7 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto cint = checkConstantIntVal(attr->args[0]);
if (cint)
- diffAttr->maxOrder = (int32_t)cint->value;
+ diffAttr->maxOrder = (int32_t)cint->getValue();
}
else if (auto formatAttr = as<FormatAttribute>(attr))
{
@@ -652,7 +654,7 @@ namespace Slang
if (!val) return false;
- rayPayloadAttr->location = (int32_t)val->value;
+ rayPayloadAttr->location = (int32_t)val->getValue();
}
else if (auto callablePayloadAttr = as<VulkanCallablePayloadAttribute>(attr))
{
@@ -661,7 +663,7 @@ namespace Slang
if (!val) return false;
- callablePayloadAttr->location = (int32_t)val->value;
+ callablePayloadAttr->location = (int32_t)val->getValue();
}
else if (auto hitObjectAttributesAttr = as<VulkanHitObjectAttributesAttribute>(attr))
{
@@ -670,7 +672,7 @@ namespace Slang
if (!val) return false;
- hitObjectAttributesAttr->location = (int32_t)val->value;
+ hitObjectAttributesAttr->location = (int32_t)val->getValue();
}
else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr))
{
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 2fa13f3fa..8709ae763 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -70,7 +70,7 @@ namespace Slang
{
if (auto resultType = as<DeclRefType>(candidate.resultType))
{
- if (resultType->declRef.as<ClassDecl>())
+ if (resultType->getDeclRef().as<ClassDecl>())
{
isClassType = true;
}
@@ -373,7 +373,7 @@ namespace Slang
//
if( !val )
{
- val = m_astBuilder->create<ErrorIntVal>();
+ val = m_astBuilder->getOrCreate<ErrorIntVal>(m_astBuilder->getIntType());
}
checkedArgs.add(val);
}
@@ -383,8 +383,8 @@ namespace Slang
}
}
- auto genSubst = m_astBuilder->getOrCreateGenericSubstitution(nullptr, genericDeclRef.getDecl(), checkedArgs.getArrayView());
- candidate.subst = genSubst;
+ auto genSubst = m_astBuilder->getGenericAppDeclRef(genericDeclRef, checkedArgs.getArrayView());
+ candidate.subst = SubstitutionSet(genSubst);
// Once we are done processing the parameters of the generic,
// we will have build up a usable `checkedArgs` array and
@@ -550,19 +550,17 @@ namespace Slang
// We should have the existing arguments to the generic
// handy, so that we can construct a substitution list.
- auto subst = as<GenericSubstitution>(candidate.subst);
- SLANG_ASSERT(subst);
+ auto substArgs = tryGetGenericArguments(candidate.subst, genericDeclRef.getDecl());
+ SLANG_ASSERT(substArgs.getCount());
- subst = getASTBuilder()->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(), genericDeclRef.getDecl(), subst->getArgs());
-
- List<Val*> newArgs = subst->getArgs();
+ List<Val*> newArgs;
+ for (auto arg : substArgs)
+ newArgs.add(arg);
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
- DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getSpecializedDeclRef(
- constraintDecl, subst);
-
+ DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(genericDeclRef, substArgs, constraintDecl).as<GenericTypeConstraintDecl>();
+
auto sub = getSub(m_astBuilder, constraintDeclRef);
auto sup = getSup(m_astBuilder, constraintDeclRef);
@@ -575,14 +573,14 @@ namespace Slang
{
if(context.mode != OverloadResolveContext::Mode::JustTrying)
{
- subTypeWitness = tryGetSubtypeWitness(sub, sup);
+ subTypeWitness = isSubtype(sub, sup);
getSink()->diagnose(context.loc, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup);
}
return false;
}
}
- candidate.subst = m_astBuilder->getOrCreateGenericSubstitution(nullptr, genericDeclRef.getDecl(), newArgs);
+ candidate.subst = SubstitutionSet(m_astBuilder->getGenericAppDeclRef(genericDeclRef, newArgs.getArrayView()));
// Done checking all the constraints, hooray.
return true;
@@ -617,7 +615,7 @@ namespace Slang
Expr* SemanticsVisitor::createGenericDeclRef(
Expr* baseExpr,
Expr* originalExpr,
- GenericSubstitution* subst)
+ SubstitutionSet substArgs)
{
auto baseDeclRefExpr = as<DeclRefExpr>(baseExpr);
if (!baseDeclRefExpr)
@@ -631,10 +629,9 @@ namespace Slang
SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration");
return CreateErrorExpr(originalExpr);
}
-
- subst = m_astBuilder->getOrCreateGenericSubstitution(baseGenericRef.getSubst(), baseGenericRef.getDecl(), subst->getArgs());
-
- DeclRef<Decl> innerDeclRef = m_astBuilder->getSpecializedDeclRef<Decl>(getInner(baseGenericRef), subst);
+ auto genSubst = substArgs.findGenericAppDeclRef(baseGenericRef.getDecl());
+ SLANG_ASSERT(genSubst);
+ DeclRef<Decl> innerDeclRef = m_astBuilder->getGenericAppDeclRef(baseGenericRef, genSubst->getArgs());
Expr* base = nullptr;
if (auto mbrExpr = as<MemberExpr>(baseExpr))
@@ -768,14 +765,16 @@ namespace Slang
expr->loc = context.loc;
expr->originalExpr = baseExpr;
expr->baseGenericDeclRef = as<DeclRefExpr>(baseExpr)->declRef.as<GenericDecl>();
- expr->substWithKnownGenericArgs = (GenericSubstitution*)candidate.subst;
+ auto args = tryGetGenericArguments(candidate.subst, expr->baseGenericDeclRef.getDecl());
+ for (auto arg : args)
+ expr->knownGenericArgs.add(arg);
return expr;
}
return createGenericDeclRef(
baseExpr,
context.originalExpr,
- as<GenericSubstitution>(candidate.subst));
+ candidate.subst);
break;
default:
@@ -801,12 +800,14 @@ namespace Slang
/// Does the given `declRef` represent an interface requirement?
bool isInterfaceRequirement(ASTBuilder* builder, DeclRef<Decl> const& declRef)
{
+ SLANG_UNUSED(builder);
+
if(!declRef)
return false;
- auto parent = declRef.getParent(builder);
+ auto parent = declRef.getParent();
if(parent.as<GenericDecl>())
- parent = parent.getParent(builder);
+ parent = parent.getParent();
if(parent.as<InterfaceDecl>())
return true;
@@ -826,7 +827,7 @@ namespace Slang
// "inner" declaration of a generic. That means that
// the parent of the decl ref must be a generic.
//
- auto parentGeneric = declRef.getParent(m_astBuilder).as<GenericDecl>();
+ auto parentGeneric = declRef.getParent().as<GenericDecl>();
if(!parentGeneric)
return 0;
//
@@ -863,7 +864,18 @@ namespace Slang
if(leftIsInterfaceRequirement != rightIsInterfaceRequirement)
return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement);
- // TODO: We should always have rules such that in a tie a declaration
+ // If both are interface requirements, prefer to more derived interface.
+ if (leftIsInterfaceRequirement && rightIsInterfaceRequirement)
+ {
+ auto leftType = DeclRefType::create(m_astBuilder, left.declRef.getParent());
+ auto rightType = DeclRefType::create(m_astBuilder, right.declRef.getParent());
+ if (isSubtype(leftType, rightType))
+ return -1;
+ if (isSubtype(rightType, leftType))
+ return 1;
+ }
+
+ // TODO: We should generalize above rules such that in a tie a declaration
// A::m is better than B::m when all other factors are equal and
// A inherits from B.
@@ -1227,7 +1239,7 @@ namespace Slang
DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs,
+ ArrayView<Val*> knownGenericArgs,
List<Type*> *innerParameterTypes)
{
// We have been asked to infer zero or more arguments to
@@ -1265,28 +1277,10 @@ namespace Slang
// the "inner" declaration of the generic (e.g., the `FuncitonDecl`
// under the `GenericDecl`).
//
- // In the case where no explicit arguments are available, we will
- // use any substitutions that were in place for referring to the
- // generic itself.
- //
- Substitutions* substForInnerDecl = genericDeclRef.getSubst();
- //
- // In the case where we have explicit/known arguments,
- // we will use those as our baseline substitutions.
- //
- if (substWithKnownGenericArgs)
- {
- substForInnerDecl = substWithKnownGenericArgs;
- }
-
- auto innerDecl = getInner(genericDeclRef);
- DeclRef<Decl> partiallySpecializedInnerRef = m_astBuilder->getSpecializedDeclRef<Decl>(
- innerDecl,
- substForInnerDecl);
-
// Check what type of declaration we are dealing with, and then try
// to match it up with the arguments accordingly...
- if (auto funcDeclRef = partiallySpecializedInnerRef.as<CallableDecl>())
+
+ if (auto funcDeclRef = as<CallableDecl>(genericDeclRef.getDecl()->inner))
{
List<Type*> paramTypes;
if (!innerParameterTypes)
@@ -1360,28 +1354,8 @@ namespace Slang
// TODO(tfoley): We probably need to pass along the explicit arguments here,
// so that the solver knows to accept those arguments as-is.
//
- auto constraintSubst = trySolveConstraintSystem(
- &constraints, genericDeclRef, substWithKnownGenericArgs);
- if (!constraintSubst)
- {
- // In this case, the solver failed to find a solution to the constraint
- // system, and we will signal that failure up to the client that called
- // this operation.
- //
- // TODO: We really ought to be passing up some kind of representation
- // of the failure, so that constraint-related issues can be reported to
- // the user. This could either be a return path here (returning some
- // diagnostics), or this code could have a "just trying" vs. "actually
- // do things" distinction like some other steps.
- //
- return DeclRef<Decl>();
- }
-
- // If we found a solution (that is, a set of argument values that satisfy
- // all the constraints), we can construct a reference to the inner
- // declaration that applies the generic to those arguments.
- //
- return m_astBuilder->getSpecializedDeclRef<Decl>(innerDecl, constraintSubst);
+ return trySolveConstraintSystem(
+ &constraints, genericDeclRef, knownGenericArgs);
}
void SemanticsVisitor::AddTypeOverloadCandidates(
@@ -1424,13 +1398,13 @@ namespace Slang
void SemanticsVisitor::addOverloadCandidatesForCallToGeneric(
LookupResultItem genericItem,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs)
+ ArrayView<Val*> knownGenericArgs)
{
auto genericDeclRef = genericItem.declRef.as<GenericDecl>();
SLANG_ASSERT(genericDeclRef);
// Try to infer generic arguments, based on the context
- DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, substWithKnownGenericArgs);
+ DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs);
if (innerRef)
{
@@ -1475,7 +1449,7 @@ namespace Slang
LookupResultItem innerItem;
innerItem.breadcrumbs = item.breadcrumbs;
innerItem.declRef = genericDeclRef;
- addOverloadCandidatesForCallToGeneric(innerItem, context);
+ addOverloadCandidatesForCallToGeneric(innerItem, context, ArrayView<Val*>());
}
else if( auto typeDefDeclRef = item.declRef.as<TypeDefDecl>() )
{
@@ -1578,7 +1552,7 @@ namespace Slang
addOverloadCandidatesForCallToGeneric(
LookupResultItem(partiallyAppliedGenericExpr->baseGenericDeclRef),
context,
- partiallyAppliedGenericExpr->substWithKnownGenericArgs);
+ partiallyAppliedGenericExpr->knownGenericArgs.getArrayView());
}
else if (auto typeType = as<TypeType>(funcExprType))
{
@@ -1588,7 +1562,7 @@ namespace Slang
//
// TODO(tfoley): are there any meaningful types left
// that aren't declaration references?
- auto type = typeType->type;
+ auto type = typeType->getType();
AddTypeOverloadCandidates(type, context);
return;
}
@@ -1633,12 +1607,16 @@ namespace Slang
paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii)));
// Try to infer generic arguments, based on the updated context.
+ OverloadResolveContext subContext = context;
DeclRef<Decl> innerRef = inferGenericArguments(
baseFuncGenericDeclRef,
context,
- nullptr,
+ ArrayView<Val*>(),
&paramTypes);
+ if (!innerRef)
+ return;
+
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
if (innerRef)
diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp
new file mode 100644
index 000000000..91722f82c
--- /dev/null
+++ b/source/slang/slang-check-resolve-val.cpp
@@ -0,0 +1,48 @@
+// slang-check-resolve-val.cpp
+
+// Logic for resolving/simplifying Types and DeclRefs.
+
+#include "slang-check-impl.h"
+
+#include "slang-lookup.h"
+#include "slang-syntax.h"
+#include "slang-ast-synthesis.h"
+#include "slang-ast-reflect.h"
+
+namespace Slang
+{
+
+Type* Type::createCanonicalType()
+{
+ SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ());
+}
+
+Val* Type::_resolveImplOverride()
+{
+ Val* resolvedVal = createCanonicalType();
+ return resolvedVal;
+}
+
+DeclRefBase* _resolveAsDeclRef(DeclRefBase* declRefToResolve);
+
+Type* DeclRefType::_createCanonicalTypeOverride()
+{
+ auto astBuilder = getCurrentASTBuilder();
+
+ // A declaration reference is already canonical
+ auto resolvedDeclRef = getDeclRef();
+ resolvedDeclRef = _resolveAsDeclRef(getDeclRef().declRefBase);
+ if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, resolvedDeclRef))
+ return as<Type>(satisfyingVal);
+ if (resolvedDeclRef != getDeclRef())
+ return DeclRefType::create(astBuilder, resolvedDeclRef);
+ return this;
+}
+
+
+Val* SubtypeWitness::_resolveImplOverride()
+{
+ return as<SubtypeWitness>(defaultResolveImpl());
+}
+
+}
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 657438222..d9bb11548 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -17,7 +17,7 @@ namespace Slang
auto basicType = as<BasicExpressionType>(type);
if (basicType)
{
- return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt);
+ return (basicType->getBaseType() == BaseType::Int || basicType->getBaseType() == BaseType::UInt);
}
}
// Can be an int/uint vector from size 1 to 3
@@ -27,20 +27,21 @@ namespace Slang
{
return false;
}
- auto elemCount = as<ConstantIntVal>(vectorType->elementCount);
- if (elemCount->value < 1 || elemCount->value > 3)
+ auto elemCount = as<ConstantIntVal>(vectorType->getElementCount());
+ if (elemCount->getValue() < 1 || elemCount->getValue() > 3)
{
return false;
}
// Must be a basic type
- auto basicType = as<BasicExpressionType>(vectorType->elementType);
+ auto basicType = as<BasicExpressionType>(vectorType->getElementType());
if (!basicType)
{
return false;
}
// Must be integral
- return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt);
+ auto baseType = basicType->getBaseType();
+ return (baseType == BaseType::Int || baseType == BaseType::UInt);
}
}
@@ -83,7 +84,7 @@ namespace Slang
if( auto declRefType = as<DeclRefType>(type) )
{
- auto typeDeclRef = declRefType->declRef;
+ auto typeDeclRef = declRefType->getDeclRef();
if( auto interfaceDeclRef = typeDeclRef.as<InterfaceDecl>() )
{
// Each leaf parameter of interface type adds a specialization
@@ -792,6 +793,8 @@ namespace Slang
void FrontEndCompileRequest::checkEntryPoints()
{
auto linkage = getLinkage();
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+
auto sink = getSink();
// The validation of entry points here will be modal, and controlled
@@ -1025,7 +1028,7 @@ namespace Slang
//
if( auto argDeclRefType = as<DeclRefType>(argType) )
{
- auto argDeclRef = argDeclRefType->declRef;
+ auto argDeclRef = argDeclRefType->getDeclRef();
if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>())
{
if(argGenericParamDeclRef.getDecl() == genericTypeParamDecl)
@@ -1193,7 +1196,7 @@ namespace Slang
// the semantic checking machinery to expand out
// the rest of the arguments via inference...
- auto genericDeclRef = m_funcDeclRef.getParent(getLinkage()->getASTBuilder()).as<GenericDecl>();
+ auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>();
SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters
List<Val*> genericArgs;
@@ -1203,19 +1206,13 @@ namespace Slang
auto specializationArg = args[ii];
genericArgs.add(specializationArg.val);
}
- GenericSubstitution* genericSubst =
- getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(),
- genericDeclRef.getDecl(),
- genericArgs.getArrayView());
+ auto genericInnerDeclRef = getLinkage()->getASTBuilder()->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView());
ASTBuilder* astBuilder = getLinkage()->getASTBuilder();
for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>(
getLinkage()->getASTBuilder(), DeclRef<ContainerDecl>(genericDeclRef)))
{
- DeclRef<GenericTypeConstraintDecl> constraintDeclRef = astBuilder->getSpecializedDeclRef(
- constraintDecl.getDecl(), genericSubst);
-
+ DeclRef<GenericTypeConstraintDecl> constraintDeclRef = astBuilder->getDirectDeclRef(constraintDecl.getDecl());
auto sub = getSub(astBuilder, constraintDeclRef);
auto sup = getSup(astBuilder, constraintDeclRef);
@@ -1233,12 +1230,8 @@ namespace Slang
}
}
- genericSubst =
- getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(),
- genericDeclRef.getDecl(),
- genericArgs);
- specializedFuncDeclRef = astBuilder->getSpecializedDeclRef(specializedFuncDeclRef.getDecl(), genericSubst);
+ specializedFuncDeclRef = getLinkage()->getASTBuilder()->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()).as<FuncDecl>();
+ SLANG_ASSERT(specializedFuncDeclRef);
}
info->specializedFuncDeclRef = specializedFuncDeclRef;
@@ -1418,9 +1411,8 @@ namespace Slang
specializationArgs.add(arg);
}
- ExistentialSpecializedType* specializedType = m_astBuilder->create<ExistentialSpecializedType>();
- specializedType->baseType = unspecializedType;
- specializedType->args = specializationArgs;
+ ExistentialSpecializedType* specializedType = m_astBuilder->getOrCreate<ExistentialSpecializedType>(
+ unspecializedType, specializationArgs);
m_specializedTypes.add(specializedType);
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index ba7e977e3..4b4257f75 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -526,7 +526,7 @@ namespace Slang
}
if (!stepSize)
return;
- if (stepSize->value > 0)
+ if (stepSize->getValue() > 0)
{
if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater ||
sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less)
@@ -535,7 +535,7 @@ namespace Slang
return;
}
}
- else if (stepSize->value < 0)
+ else if (stepSize->getValue() < 0)
{
if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less ||
sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater)
@@ -553,25 +553,25 @@ namespace Slang
if (!initialLitVal || !finalVal)
return;
- auto absStepSize = abs(stepSize->value);
+ auto absStepSize = abs(stepSize->getValue());
int adjustment = 0;
if (compareOp == kIROp_Geq || compareOp == kIROp_Leq)
adjustment = 1;
- auto iterations = (Math::Max(finalVal->value, initialLitVal->value) -
- Math::Min(finalVal->value, initialLitVal->value) + absStepSize - 1 + adjustment) /
+ auto iterations = (Math::Max(finalVal->getValue(), initialLitVal->getValue()) -
+ Math::Min(finalVal->getValue(), initialLitVal->getValue()) + absStepSize - 1 + adjustment) /
absStepSize;
switch (compareOp)
{
case kIROp_Geq:
case kIROp_Greater:
// Expect final value to be less than initial value.
- if (finalVal->value > initialLitVal->value)
+ if (finalVal->getValue() > initialLitVal->getValue())
iterations = 0;
break;
case kIROp_Leq:
case kIROp_Less:
- if (finalVal->value < initialLitVal->value)
+ if (finalVal->getValue() < initialLitVal->getValue())
iterations = 0;
break;
}
@@ -590,7 +590,7 @@ namespace Slang
litExpr->type.type = m_astBuilder->getIntType();
litExpr->token.setName(getNamePool()->getName(String(iterations)));
maxItersAttr->args.add(litExpr);
- maxItersAttr->intArgVals.add(0, m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations));
+ maxItersAttr->intArgVals.add(m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations));
maxItersAttr->value = (int32_t)iterations;
maxItersAttr->inductionVar = initialVar;
addModifier(stmt, maxItersAttr);
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index cee54388f..d5d3e5a5d 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -17,6 +17,7 @@ namespace Slang
sink);
SemanticsVisitor visitor(&sharedSemanticsContext);
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
auto typeOut = visitor.CheckProperType(typeExp);
return typeOut.type;
@@ -49,7 +50,7 @@ namespace Slang
if (!typeRepr) return nullptr;
if (auto typeType = as<TypeType>(typeRepr->type))
{
- return typeType->type;
+ return typeType->getType();
}
return m_astBuilder->getErrorType();
}
@@ -86,11 +87,17 @@ namespace Slang
Type* SemanticsVisitor::getRemovedModifierType(ModifiedType* modifiedType, ModifierVal* modifier)
{
- if (modifiedType->modifiers.getCount() == 1)
- return modifiedType->base;
- auto newModifiers = modifiedType->modifiers;
- newModifiers.remove(modifier);
- return m_astBuilder->getModifiedType(modifiedType->base, newModifiers);
+ if (modifiedType->getModifierCount() == 1)
+ return modifiedType->getBase();
+ List<Val*> newModifiers;
+ for (Index i = 0; i < modifiedType->getModifierCount(); i++)
+ {
+ auto m = modifiedType->getModifier(i);
+ if (m == modifier)
+ continue;
+ newModifiers.add(m);
+ }
+ return m_astBuilder->getModifiedType(modifiedType->getBase(), newModifiers);
}
Expr* SemanticsVisitor::ExpectATypeRepr(Expr* expr)
@@ -118,7 +125,7 @@ namespace Slang
auto typeRepr = ExpectATypeRepr(expr);
if (auto typeType = as<TypeType>(typeRepr->type))
{
- return typeType->type;
+ return typeType->getType();
}
return m_astBuilder->getErrorType();
}
@@ -142,7 +149,7 @@ namespace Slang
// constant expression in context, then we will instead construct
// a dummy "error" value to represent the result.
//
- val = m_astBuilder->create<ErrorIntVal>();
+ val = m_astBuilder->getOrCreate<ErrorIntVal>(m_astBuilder->getIntType());
return val;
}
@@ -160,7 +167,7 @@ namespace Slang
}
if (auto typeType = as<TypeType>(exp->type))
{
- return typeType->type;
+ return typeType->getType();
}
else if (const auto errorType = as<ErrorType>(exp->type))
{
@@ -187,10 +194,7 @@ namespace Slang
evaledArgs.add(ExtractGenericArgVal(argExpr));
}
- GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(
- genericDeclRef.getSubst(), genericDeclRef.getDecl(), evaledArgs);
-
- DeclRef<Decl> innerDeclRef = m_astBuilder->getSpecializedDeclRef(getInner(genericDeclRef), subst);
+ DeclRef<Decl> innerDeclRef = m_astBuilder->getGenericAppDeclRef(genericDeclRef, evaledArgs.getArrayView());
return DeclRefType::create(m_astBuilder, innerDeclRef);
}
@@ -198,9 +202,9 @@ namespace Slang
{
if (auto declRefValueType = as<DeclRefType>(type))
{
- if (as<ClassDecl>(declRefValueType->declRef.getDecl()))
+ if (as<ClassDecl>(declRefValueType->getDeclRef().getDecl()))
return true;
- if (as<InterfaceDecl>(declRefValueType->declRef.getDecl()))
+ if (as<InterfaceDecl>(declRefValueType->getDeclRef().getDecl()))
return true;
}
return false;
@@ -221,7 +225,7 @@ namespace Slang
if(auto typeType = as<TypeType>(expr->type))
{
- type = typeType->type;
+ type = typeType->getType();
}
}
@@ -358,7 +362,7 @@ namespace Slang
if (auto basicType = as<BasicExpressionType>(type))
{
// TODO: `void` shouldn't be a basic type, to make this easier to avoid
- if (basicType->baseType == BaseType::Void)
+ if (basicType->getBaseType() == BaseType::Void)
{
// TODO(tfoley): pick the right diagnostic message
getSink()->diagnose(result.exp, Diagnostics::invalidTypeVoid);
@@ -384,7 +388,7 @@ namespace Slang
{
if(auto rightConst = as<ConstantIntVal>(right))
{
- return leftConst->value == rightConst->value;
+ return leftConst->getValue() == rightConst->getValue();
}
}
@@ -392,16 +396,16 @@ namespace Slang
{
if(auto rightVar = as<GenericParamIntVal>(right))
{
- return leftVar->declRef.equals(rightVar->declRef);
+ return leftVar->getDeclRef().equals(rightVar->getDeclRef());
}
else if (const auto rightPoly = as<PolynomialIntVal>(right))
{
- return right->equalsVal(leftVar);
+ return right->equals(leftVar);
}
}
if (auto leftVar = as<PolynomialIntVal>(left))
{
- return leftVar->equalsVal(right);
+ return leftVar->equals(right);
}
return false;
}
@@ -423,22 +427,4 @@ namespace Slang
return expr;
}
- Expr* SemanticsExprVisitor::visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr)
- {
- // We have an expression of the form `__TaggedUnion(A, B, ...)`
- // which will evaluate to a tagged-union type over `A`, `B`, etc.
- //
- TaggedUnionType* type = m_astBuilder->create<TaggedUnionType>();
- expr->type = QualType(m_astBuilder->getTypeType(type));
-
- for( auto& caseTypeExpr : expr->caseTypes )
- {
- caseTypeExpr = CheckProperType(caseTypeExpr);
- type->caseTypes.add(caseTypeExpr.type);
- }
-
- return expr;
- }
-
-
}
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp
index 276e086df..780c109da 100644
--- a/source/slang/slang-check.cpp
+++ b/source/slang/slang-check.cpp
@@ -164,6 +164,8 @@ namespace Slang
TranslationUnitRequest* translationUnit,
LoadedModuleDictionary& loadedModules)
{
+ SLANG_AST_BUILDER_RAII(translationUnit->compileRequest->getLinkage()->getASTBuilder());
+
SharedSemanticsContext sharedSemanticsContext(
translationUnit->compileRequest->getLinkage(),
translationUnit->getModule(),
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 224e30fa1..4e1ab8e98 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -265,7 +265,7 @@ namespace Slang
{
if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness))
{
- auto declModule = getModule(declaredWitness->declRef.getDecl());
+ auto declModule = getModule(declaredWitness->getDeclRef().getDecl());
m_moduleDependencyList.addDependency(declModule);
m_fileDependencyList.addDependency(declModule);
if (m_requirementSet.add(declModule))
@@ -276,8 +276,8 @@ namespace Slang
}
else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness))
{
- addDepedencyFromWitness(transitiveWitness->midToSup);
- addDepedencyFromWitness(transitiveWitness->subToMid);
+ addDepedencyFromWitness(transitiveWitness->getMidToSup());
+ addDepedencyFromWitness(transitiveWitness->getSubToMid());
}
else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness))
{
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 7d73599ba..79f6b6ed4 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -650,9 +650,6 @@ namespace Slang
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; }
List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; }
- /// Get a list of tagged-union types referenced by the specialization parameters.
- List<TaggedUnionType*> const& getTaggedUnionTypes() { return m_taggedUnionTypes; }
-
RefPtr<IRModule> getIRModule() { return m_irModule; }
void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
@@ -679,9 +676,6 @@ namespace Slang
List<String> m_entryPointMangledNames;
List<String> m_entryPointNameOverrides;
- // Any tagged union types that were referenced by the specialization arguments.
- List<TaggedUnionType*> m_taggedUnionTypes;
-
List<Module*> m_moduleDependencies;
List<SourceFile*> m_fileDependencies;
List<RefPtr<ComponentType>> m_requirements;
diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp
index 9e9efeb64..c13dc9668 100644
--- a/source/slang/slang-doc-markdown-writer.cpp
+++ b/source/slang/slang-doc-markdown-writer.cpp
@@ -1119,7 +1119,7 @@ void DocMarkdownWriter::writeDescription(const ASTMarkup::Entry& entry)
void DocMarkdownWriter::writeDecl(const ASTMarkup::Entry& entry, Decl* decl)
{
// Skip these they will be output as part of their respective 'containers'
- if (as<ParamDecl>(decl) || as<EnumCaseDecl>(decl) || as<AssocTypeDecl>(decl) || as<InheritanceDecl>(decl))
+ if (as<ParamDecl>(decl) || as<EnumCaseDecl>(decl) || as<AssocTypeDecl>(decl) || as<InheritanceDecl>(decl) || as<ThisTypeDecl>(decl))
{
return;
}
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index adfc98dfd..5bd57e81c 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -14,7 +14,6 @@
#include "slang-ir-specialize.h"
#include "slang-ir-specialize-resources.h"
#include "slang-ir-ssa.h"
-#include "slang-ir-union.h"
#include "slang-ir-util.h"
#include "slang-ir-validate.h"
#include "slang-legalize-types.h"
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 8f4d68a75..343c18916 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -54,7 +54,6 @@
#include "slang-ir-strip-cached-dict.h"
#include "slang-ir-strip-witness-tables.h"
#include "slang-ir-synthesize-active-mask.h"
-#include "slang-ir-union.h"
#include "slang-ir-validate.h"
#include "slang-ir-wrap-structured-buffers.h"
#include "slang-ir-liveness.h"
@@ -347,10 +346,6 @@ Result linkAndOptimizeIR(
// Lower `Result<T,E>` types into ordinary struct types.
lowerResultType(irModule, sink);
- // Desguar any union types, since these will be illegal on
- // various targets.
- //
- desugarUnionTypes(irModule);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED");
#endif
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 4716ed427..69f3c4e0d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -54,8 +54,6 @@ INST(Nop, nop, 0, 0)
INST(VectorType, Vec, 2, HOISTABLE)
INST(MatrixType, Mat, 3, HOISTABLE)
- INST(TaggedUnionType, TaggedUnion, 0, HOISTABLE)
-
INST(ConjunctionType, Conjunction, 0, HOISTABLE)
INST(AttributedType, Attributed, 0, HOISTABLE)
INST(ResultType, Result, 2, HOISTABLE)
@@ -985,7 +983,6 @@ INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
INST(ArrayTypeLayout, arrayTypeLayout, 1, HOISTABLE)
INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, HOISTABLE)
INST(MatrixTypeLayout, matrixTypeLayout, 1, HOISTABLE)
- INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, HOISTABLE)
INST(ExistentialTypeLayout, existentialTypeLayout, 0, HOISTABLE)
INST(StructTypeLayout, structTypeLayout, 0, HOISTABLE)
// TODO(JS): Ideally we'd have the layout to the pointed to value type (ie 1 instead of 0 here). But to avoid infinite recursion we don't.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2312cc4f2..95f72b3cd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1717,59 +1717,6 @@ struct IRCaseTypeLayoutAttr : IRAttr
}
};
- /// Specialized layout information for tagged union types
-struct IRTaggedUnionTypeLayout : IRTypeLayout
-{
- typedef IRTypeLayout Super;
-
- IR_LEAF_ISA(TaggedUnionTypeLayout)
-
- /// Get the (byte) offset of the tagged union's tag (aka "discriminator") field
- LayoutSize getTagOffset()
- {
- return LayoutSize::fromRaw(LayoutSize::RawValue(getIntVal(cast<IRIntLit>(getOperand(0)))));
- }
-
- /// Get all the attributes representing layouts for the difference cases
- IROperandList<IRCaseTypeLayoutAttr> getCaseTypeLayoutAttrs()
- {
- return findAttrs<IRCaseTypeLayoutAttr>();
- }
-
- /// Get the number of cases for which layout information is stored
- UInt getCaseCount()
- {
- return getCaseTypeLayoutAttrs().getCount();
- }
-
- /// Get the layout information for the case at the given `index`
- IRTypeLayout* getCaseTypeLayout(UInt index)
- {
- return getCaseTypeLayoutAttrs()[index]->getTypeLayout();
- }
-
- /// Specialized builder for tagged union type layouts
- struct Builder : Super::Builder
- {
- Builder(IRBuilder* irBuilder, LayoutSize tagOffset);
-
- void addCaseTypeLayout(IRTypeLayout* typeLayout);
-
- IRTaggedUnionTypeLayout* build()
- {
- return cast<IRTaggedUnionTypeLayout>(Super::Builder::build());
- }
-
- protected:
- IROp getOp() SLANG_OVERRIDE { return kIROp_TaggedUnionTypeLayout; }
- void addOperandsImpl(List<IRInst*>& ioOperands) SLANG_OVERRIDE;
- void addAttrsImpl(List<IRInst*>& ioOperands) SLANG_OVERRIDE;
-
- IRInst* m_tagOffset = nullptr;
- List<IRAttr*> m_caseTypeLayoutAttrs;
- };
-};
-
/// Type layout for an existential/interface type.
struct IRExistentialTypeLayout : IRTypeLayout
{
@@ -3013,16 +2960,6 @@ public:
IRRate* rate,
IRType* dataType);
- IRType* getTaggedUnionType(
- UInt caseCount,
- IRType* const* caseTypes);
-
- IRType* getTaggedUnionType(
- List<IRType*> const& caseTypes)
- {
- return getTaggedUnionType(caseTypes.getCount(), caseTypes.getBuffer());
- }
-
IRType* getBindExistentialsType(
IRInst* baseType,
UInt slotArgCount,
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 081fd7486..364613074 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -240,7 +240,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
case kIROp_GlobalGenericParam:
case kIROp_WitnessTable:
case kIROp_InterfaceType:
- case kIROp_TaggedUnionType:
return cloneGlobalValue(this, originalValue);
case kIROp_BoolLit:
diff --git a/source/slang/slang-ir-union.cpp b/source/slang/slang-ir-union.cpp
deleted file mode 100644
index 1eb4955e7..000000000
--- a/source/slang/slang-ir-union.cpp
+++ /dev/null
@@ -1,773 +0,0 @@
-// slang-ir-union.cpp
-#include "slang-ir-union.h"
-
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
-
-namespace Slang {
-
-// This file will implement a pass to replace any union types (currently
-// just tagged unions) with plain `struct` types that attempt to provide
-// equivalent semantics. This will necessarily be a bit fragile, and there
-// will be fundamental limits to what the translation can support without
-// improved features in the target shading languages/ILs.
-
-struct DesugarUnionTypesContext
-{
- // We'll start with some basic state that we need to get the job done.
- //
- // This includes the IR module we are to process, as well as IR building
- // state that we will initialize once and then use throughout the pass.
- //
- IRModule* module;
- IRBuilder builderStorage;
- IRBuilder* getBuilder() { return &builderStorage; }
-
- // Because we will be replacing instructions that refer to unions with
- // different logic, we'll want to remove the original instructions.
- // However, we need to be careful about modifying the IR tree while also
- // iterating it, and to keep things simple for ourselves we'll go ahead
- // and build up a list of instruction to remove along the way, and then
- // remove them all at the end.
- //
- List<IRInst*> instsToRemove;
-
- // The overall flow of the pass is pretty simple, so we will walk through it now.
- //
- void processModule()
- {
- // We start by initializing our IR building state.
- //
- builderStorage = IRBuilder(module);
-
- // Next, we will search for any instruction that create or use
- // union types, and process them accordingingly (usually by
- // constructing a new instruction to replace them).
- //
- processInstRec(module->getModuleInst());
-
- // Along the way we will build up a list of the tagged union
- // types that we encountered, but we will refrain from replacing
- // them until we are done (so that we always know that the instructions
- // we process above refer to the original type, and not its
- // replacement.
- //
- for( auto info : taggedUnionInfos )
- {
- auto taggedUnionType = info->taggedUnionType;
- auto replacementInst = info->replacementInst;
-
- // TODO: We should consider transferring decorations from the source
- // type to the destination, but doing so carelessly could create
- // problems, since an IR struct type shouldn't have, e.g., a
- // `TaggedUnionTypeLayout` attached to it.
-
- taggedUnionType->replaceUsesWith(replacementInst);
- taggedUnionType->removeAndDeallocate();
- }
-
- // As described previously, we build up the `instsToRemove` list as
- // we iterate so that we can remove them all here and not risk
- // modifying the IR tree while also walking it.
- //
- // TODO: This might be overkill and we could conceivably just be
- // a bit careful in `processInstRec`.
- //
- for(auto inst : instsToRemove)
- {
- inst->removeAndDeallocate();
- }
- }
-
- // In order to replace a (tagged) union type, we will need to know
- // something about it, and we will use the `TaggedUnionInfo` type
- // to collect all the relevant information.
- //
- struct TaggedUnionInfo : public RefObject
- {
- // We obviously need to know the tagged union itself, and
- // we will also use this structure to track the instruction
- // (an IR struct type) that will replace it.
- //
- IRTaggedUnionType* taggedUnionType;
- IRInst* replacementInst;
-
- // In order to compute a suitable layout for the replacement
- // `struct` type we need to know how the tagged union itself
- // would be laid out in memory, so we require that all tagged
- // unions in the generated IR have an associated (target-specific)
- // layout.
- //
- IRTaggedUnionTypeLayout* taggedUnionTypeLayout;
-
- // The basic approach we will use 16-byte chunks (represented as an array
- // of `uint4`s) to reprent the "bulk" of a type, and then use a single field
- // that could be up to 12 bytes to represent the "rest" of the type.
- //
- // Note that there are deeply ingrained assumptions here that all types
- // are at least four bytes in size (so that unions cannot easily
- // accomodate `half` value), and that any types *larger* than four bytes
- // will need to be loaded/stored via multiple 4-byte loads/stores.
- //
- // With the basic idea out of the way, we need an IR level field
- // in our struct to hold the bulk data, which comprises a "key" for
- // looking up the field, and the type of the field itself. We also
- // keep track of how many bytes we put in our bulk storage.
- //
- // The bulk field might be:
- //
- // - null, if none of the case types was 16 bytes or more
- // - a single `uint4` for between 16 and 31 (inclusive) bytes
- // - an array of `uint4`s for 32 or more bytes
- //
- UInt64 bulkSize = 0;
- IRInst* bulkFieldKey = nullptr;
- IRType* bulkFieldType = nullptr;
-
- // The same basic idea then applies to the rest of the data.
- //
- // The "rest" field will be either be absent (if the size of the
- // type was evently divisible by 16), a scalar `uint`, or else
- // a 2- or 3-component vector of `uint`.
- //
- UInt64 restSize = 0;
- IRInst* restFieldKey = nullptr;
- IRType* restFieldType = nullptr;
-
- // Finally, since we are currently working with tagged unions,
- // we need a field to hold the tag, which will always be allocated
- // after the fields that hold the bulk/rest of the payload.
- //
- // This field is always a single `uint`.
- //
- // TODO: if/when we support untagged unions, they could be handled
- // by having this field be null.
- //
- IRInst* tagFieldKey;
- };
-
- // We will build up a list of all the tagged union types we encounter,
- // so that we can replace them with the synthesized types when we are done.
- //
- List<RefPtr<TaggedUnionInfo>> taggedUnionInfos;
-
- // It is possible that we will see the same tagged union type referenced
- // many times in the IR, but we only want to synthesize the information
- // above (including the various IR structures) once, so we also maintain
- // a map from the original IR type to the corresponding information.
- //
- Dictionary<IRInst*, TaggedUnionInfo*> mapIRTypeToTaggedUnionInfo;
-
- // We will process all instructions in the module in a single recursive walk.
- //
- void processInstRec(IRInst* inst)
- {
- processInst(inst);
-
- for( auto child : inst->getChildren() )
- {
- processInstRec(child);
- }
- }
- //
- // At each instruction, we will check if it is one of the union-related instructions
- // we need to replace, and process it accordingly.
- //
- void processInst(IRInst* inst)
- {
- switch( inst->getOp() )
- {
- default:
- // Any instruction not listed below either doesn't involve union types,
- // or handles them in a hands-off fashion that we don't need to care about.
- //
- // E.g., a `load` of a union type from a constant buffer will turn into
- // a load of the replacement `struct` type once we are done, and nothing
- // needs to be done to the `load` instruction.
- //
- break;
-
- case kIROp_TaggedUnionType:
- {
- // We clearly need to process the tagged union type itself, but the actual
- // work is handled by other functions. All we need to do here is ensure
- // that the information for this type gets generated, and then we can
- // rely on the main `processModule` function to do the actual replacement later.
- //
- auto type = cast<IRTaggedUnionType>(inst);
- getTaggedUnionInfo(type);
- }
- break;
-
- case kIROp_ExtractTaggedUnionTag:
- {
- // The case of extracting the tag from a tagged union is relatively
- // simple, because the replacement type will have a dedicated field or it.
- //
- // We start by finding the tagged union value the instruction is operating
- // on, and then looking up the information for its type (which had
- // better be a tagged union type).
- //
- auto taggedUnionVal = inst->getOperand(0);
- auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType());
-
- // Because the replacement type will have an explicit field for the tag,
- // we can simply emit a single field-extract instruction to read its value
- // out.
- //
- auto builder = getBuilder();
- builder->setInsertBefore(inst);
- auto replacement = builder->emitFieldExtract(
- inst->getFullType(),
- taggedUnionVal,
- taggedUnionInfo->tagFieldKey);
-
- // Now we can replace anything that used the original instruction with
- // the new field-extract operation, and add this instruction to the
- // list for later removal.
- //
- inst->replaceUsesWith(replacement);
- instsToRemove.add(inst);
- }
- break;
-
- case kIROp_ExtractTaggedUnionPayload:
- {
- // The most interesting case is when we are trying to extract a particular
- // payload (one of the case types) from a union. We may need to extract
- // one or more fields from the data stored in the union's replacement
- // type (the bulk/rest fields), and we may also have to convert them
- // to the type expected via bit-casts.
-
- // We can start things off easily enough by extracting the tagged union
- // value being operated on, as well as the information for its type.
- //
- auto taggedUnionVal = inst->getOperand(0);
- auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType());
-
- // Next we need to figure out which case is being extracted from the union.
- // The operand for the case tag should be a literal by construction.
- //
- auto caseTagVal = inst->getOperand(1);
- auto caseTagConst = as<IRIntLit>(caseTagVal);
- SLANG_ASSERT(caseTagConst);
-
- // The case type we are extracting will be the result type of the instruciton.
- //
- auto caseType = inst->getDataType();
- //
- // The tag value itself will be the index of the case type in the union
- // type (and its layout).
- //
- auto caseTagIndex = UInt(caseTagConst->getValue());
-
- // We can use the case tag value to look up the layout for the particular
- // case type we are extracting (this will allow us to resolve byte offsets
- // for fields, etc.).
- //
- auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout;
- SLANG_ASSERT(caseTagIndex < UInt(taggedUnionTypeLayout->getCaseCount()));
- auto caseTypeLayout = taggedUnionTypeLayout->getCaseTypeLayout(caseTagIndex);
-
- // At this point we know the type we are trying to extract, as well
- // as its layout. We will defer the actual implementation of extraction
- // to a (recursive) subroutine that can extract a (sub-)field from the
- // union at a given byte offset. Since we are extracting a full case
- // right now, the byte offset will be zero.
- //
- auto payloadVal = extractPayload(
- taggedUnionInfo,
- taggedUnionVal,
- caseType,
- caseTypeLayout,
- 0);
-
- // TODO: There is a significant flaw in the above approach when
- // the case type might be (or contain) an array. If we have a setup
- // like the following:
- //
- // union SomeUnion { float someCase[100]; ... }
- // ...
- // float result = someUnion.someCase[someIndex];
- //
- // The current logic would desugar this into something like:
- //
- // struct SomeUnion { uint4 bulk[100]; ... }
- // ...
- // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... }
- // float result = tmp[someIndex];
- //
- // The result is that we copy an entire 100-element array into local memory
- // just to fetch a single element, when it would be much nicer to just do:
- //
- // float result = asfloat(someUnion.bulk[someIndex].x);
- //
- // Achieving the latter code requires that rather than blindly translate
- // the `extractTaggedUnionPayload` instruction into a semantically equiavlent
- // value (which might lead to a big copy in the end), we should transitively
- // chase down any "access chains" off of `inst` and see what leaf values are
- // actually needed, and generated more tailored extraction logic for just
- // the elements/fields that actually get referenced.
- //
- // The more refined approach can be built on top of many of the same primitives,
- // so for now we will resign ourselves to the simpler but potentially less
- // efficient approach.
-
- // Now that we've extracted the value for the payload from the fields of
- // the replacement struct, we can use that extracted value to replace
- // this instruction, and schedule the original instruction for removal.
- //
- inst->replaceUsesWith(payloadVal);
- instsToRemove.add(inst);
- }
- break;
- }
- }
-
- // The `extractPayload` operation is the most important bit of translation we
- // need to do to make unions work. We have as input the following:
- //
- IRInst* extractPayload(
-
- // - Information about a tagged union type and its layout.
- TaggedUnionInfo* taggedUnionInfo,
-
- // - A single value of that tagged unon type.
- IRInst* taggedUnionVal,
-
- // - Type type of some "payload" field we want to extract from the union.
- IRType* payloadType,
-
- // - The memory layout of that payload type.
- IRTypeLayout* payloadTypeLayout,
-
- // - The byte offset at which we want to fetch the payload.
- UInt64 payloadOffset)
- {
- // We are going to be building some IR code no matter what.
- //
- auto builder = getBuilder();
-
- // The basic approach here will be to look at the type we
- // are trying to extract from the union, and whenever possible
- // recursively walk its structure so that we can express things
- // in terms of extraction of smaller/simpler types.
- //
- if( auto irStructType = as<IRStructType>(payloadType) )
- {
- // A structure type is a nice recursive case: we simply
- // want to extract each of its field recursively, and
- // then construct a fresh value of the `struct` type.
-
- // In all of the cases of this function we expect/require
- // there to be complete type layout information for the
- // types involved.
- //
- auto structTypeLayout = as<IRStructTypeLayout>(payloadTypeLayout);
- SLANG_ASSERT(structTypeLayout);
-
- // We are going to emit code to extract each of the fields
- // and collect them to use as operands to a `makeStruct`.
- //
- List<IRInst*> fieldVals;
-
- // We need to walk over the fields in the order the IR expects them
- UInt fieldCounter = 0;
- for( auto irField : irStructType->getFields() )
- {
- IRType* fieldType = irField->getFieldType();
-
- // TODO: We need to confirm/enforce that the fields of the
- // IR struct and the fields of the layout still align.
- //
- UInt fieldIndex = fieldCounter++;
- auto fieldLayout = structTypeLayout->getFieldLayout(fieldIndex);
- auto fieldTypeLayout = fieldLayout->getTypeLayout();
-
- // The offset of the field can be computed from the base
- // offset passed in, plus the reflection data for the field.
- //
- UInt64 fieldOffset = payloadOffset;
- if(auto resInfo = fieldLayout->findOffsetAttr(LayoutResourceKind::Uniform))
- fieldOffset += resInfo->getOffset();
-
- // We make a recursive call to extract each field, expecting
- // that this will bottom out eventually.
- //
- IRInst* fieldVal = extractPayload(
- taggedUnionInfo,
- taggedUnionVal,
- fieldType,
- fieldTypeLayout,
- fieldOffset);
- fieldVals.add(fieldVal);
- }
-
- // The final value is then just a new struct constructed from
- // the extracted field values.
- //
- auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals);
- return payloadVal;
- }
- else if( auto vecType = as<IRVectorType>(payloadType) )
- {
- auto elementType = vecType->getElementType();
-
- // We expect that by the time we are desugaring union types
- // all vector types have literal constant values for their
- // element count.
- //
- auto elementCountVal = vecType->getElementCount();
- auto elementCountConst = as<IRIntLit>(elementCountVal);
- SLANG_ASSERT(elementCountConst);
- UInt elementCount = UInt(elementCountConst->getValue());
-
- // HACK: There is currently no `VectorTypeLayout` and thus
- // no way to query the layout of the elements of a vector
- // type. Until that gets added we will kludge things here.
- //
- IRTypeLayout* elementTypeLayout = nullptr;
- size_t elementSize = 0;
- if(auto resInfo = payloadTypeLayout->findSizeAttr(LayoutResourceKind::Uniform))
- elementSize = resInfo->getSize().getFiniteValue() / elementCount;
-
- // Similar to the `struct` case above, we will extract a
- // value for each element of the vector, and then use
- // `makeVector` to construct the result value.
- //
- List<IRInst*> elementVals;
- for(UInt ii = 0; ii < elementCount; ++ii)
- {
- auto elementVal = extractPayload(
- taggedUnionInfo,
- taggedUnionVal,
- elementType,
- elementTypeLayout,
- payloadOffset + ii*elementSize);
- elementVals.add(elementVal);
- }
- return builder->emitMakeVector(vecType, elementVals);
- }
- else if( const auto matType = as<IRMatrixType>(payloadType) )
- {
- SLANG_UNIMPLEMENTED_X("matrix in union type");
- }
- else if( const auto arrayType = as<IRArrayType>(payloadType) )
- {
- SLANG_UNIMPLEMENTED_X("array in union type");
- }
- else
- {
- // If none of the above cases match, then we assume that
- // we have an individual scalar field that we need to fetch.
- //
- UInt64 payloadSize = 0;
- if( auto resInfo = payloadTypeLayout->findSizeAttr(LayoutResourceKind::Uniform) )
- {
- // TODO: somebody before this point should generate an error if
- // we have a `union` type that contains a potentially unbounded
- // amount of data.
- //
- payloadSize = resInfo->getSize().getFiniteValue();
- }
-
- if( payloadSize != 4 )
- {
- // TODO: We should handle the case of 64-bit fields by fetching
- // two `uint` values to form a `uint2`, and then using an
- // appropriate bit-cast to get from `uint2` to, e.g., `double`.
- //
- // The case of 16-bit and smaller fields is more troublesome, but
- // in the worst case we can load a `uint` and then use bitwise
- // ops to extract what we need before bitcasting.
- //
- // The right long-term solution is for downstream languages to have
- // better support for raw memory addressing.
-
- SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes");
- }
-
- // We know that we want to fetch a value of size `payloadSize`, and
- // we have a known base value and an initial offset into it.
- //
- IRInst* baseVal = taggedUnionVal;
- UInt64 offset = payloadOffset;
-
- // We are going to refine our `baseVal` and `offset` as we go, by
- // trying to narrow down the data we will access in the `struct`
- // type that will provide storage for the union.
- //
- // The first thing we want to check is if the value sits in the
- // "bulk" part of the storage, or the "rest."
- //
- UInt64 bulkSize = taggedUnionInfo->bulkSize;
- if( offset < bulkSize )
- {
- // If the value starts in the bulk area, then the whole
- // thing had better fit in the bulk area. The 16-byte
- // granularity rules for constant buffers should ensure
- // this property for us on current targets.
- //
- SLANG_ASSERT(offset + payloadSize <= bulkSize);
-
- // Since we know we'll be accessing the bulk storage,
- // we will extract it here. The extracted field will
- // be our new base value, but the `offset` doesn't need
- // to be updated since the bulk field sits at offset 0.
- //
- baseVal = builder->emitFieldExtract(
- taggedUnionInfo->bulkFieldType,
- baseVal,
- taggedUnionInfo->bulkFieldKey);
-
- // The bulk storage could be an array, if there are 32
- // or more bytes of bulk storage.
- //
- if( auto baseArrayType = as<IRArrayType>(baseVal->getDataType()) )
- {
- // If an array was allocated for bulk storage then
- // our leaf value resides entirely within a single
- // element (due to constant buffer layout rules),
- // and so we will fetch the appropriate element here.
- //
- // We will change our `baseVal` to the extracted element,
- // and then also adjust our `offset` to be relative
- // to that element.
- //
- size_t bulkElementSize = 16;
- auto index = offset / bulkElementSize;
- baseVal = builder->emitElementExtract(
- baseArrayType->getElementType(),
- baseVal,
- builder->getIntValue(builder->getIntType(), index));
- offset -= index*bulkElementSize;
- }
- }
- else
- {
- // If the offset of the field we want is past the end of
- // the bulk field then it must sit inside of the rest field,
- // and we'll extract it here. This establishes a new
- // base value, and we adjust the `offset` to be relative
- // to the rest field (which starts at an offset equal to `bulkSize`).
- //
- baseVal = builder->emitFieldExtract(
- taggedUnionInfo->restFieldType,
- baseVal,
- taggedUnionInfo->restFieldKey);
- offset -= bulkSize;
- }
-
- // We've now extracted a field that could be either a scalar or
- // a vector, and we have an offset into it. In the case where
- // the base value is a vector, we will extract out the appropriate
- // element.
- //
- if( auto baseVecType = as<IRVectorType>(baseVal->getDataType()) )
- {
- size_t vecElementSize = 4;
- auto index = offset / vecElementSize;
- baseVal = builder->emitElementExtract(
- baseVecType->getElementType(),
- baseVal,
- builder->getIntValue(builder->getIntType(), index));
- offset -= index*vecElementSize;
- }
-
- // At this point, our `baseVal` should be a single `uint`, and
- // it should provide the storage for the exact thing we wanted
- // to access (under the assumption that we always fetch 4 bytes
- // on 4-byte alignment).
- //
- IRInst* payloadVal = baseVal;
- SLANG_ASSERT(offset == 0);
-
- // TODO: we could imagine adding logic here to handle types less
- // than 4 bytes in size by shifting and masking the value we
- // just loaded.
-
- // The payload field we were trying to extract might have a type
- // other than `uint`, and to handle that case we need to employ
- // a bit-cast to get to the desired type.
- //
- if( payloadVal->getDataType() != payloadType )
- {
- payloadVal = builder->emitBitCast(
- payloadType,
- payloadVal);
- }
- return payloadVal;
- }
- }
-
- // All of the logic so far as assumed we can just call `getTaggedUnionInfo`
- // and have easy access to all the required information and the
- // synthesized replacement type.
- //
- TaggedUnionInfo* getTaggedUnionInfo(IRType* type)
- {
- // The big picture is fairly simple: we will lazily build and
- // memoize the information about tagged unions.
- //
- {
- TaggedUnionInfo* info = nullptr;
- if(mapIRTypeToTaggedUnionInfo.tryGetValue(type, info))
- return info;
- }
-
- // When we don't find information in our memo-cache, we
- // will construct it and add it to both the memo-cache
- // *and* a global list of all tagged unions encountered,
- // so that we can replacement them later.
- //
- auto info = createTaggedUnionInfo(type);
- mapIRTypeToTaggedUnionInfo.add(type, info.Ptr());
- taggedUnionInfos.add(info);
-
- return info;
- }
-
- // The actual logic for creating a `TaggedUnionInfo` is relatively
- // straightforward once we've decided what information we need.
- //
- RefPtr<TaggedUnionInfo> createTaggedUnionInfo(IRType* type)
- {
- // We expect that any type used as an operation to one of the
- // `extractTaggedUnion*` operations must be an IR tagged union.
- //
- // Note: If/when we ever expose `union`s to user and allow
- // then to create *generic* tagged union types it might appear
- // that this needs to be changed to account for a `specialize`
- // instruction in place of a concrete tagged union, but in
- // practice this pass needs to be performed late enough that
- // any such generic should be fully specialized.
- //
- auto taggedUnionType = as<IRTaggedUnionType>(type);
- SLANG_ASSERT(taggedUnionType);
-
- RefPtr<TaggedUnionInfo> info = new TaggedUnionInfo();
- info->taggedUnionType = taggedUnionType;
-
- // We are going to create an instruction to replace `type`,
- // and thus will be placing it into the same parent.
- //
- auto builder = getBuilder();
- builder->setInsertBefore(type);
-
- // A tagged union type will be replaced with an ordinary
- // `struct` type with fields to store all the relevant
- // data from any of the cases, plus a tag field.
- //
- auto structType = builder->createStructType();
- info->replacementInst = structType;
-
- // We require/expect the earlier code generation steps to have
- // associated a layout with every tagged union that appears in
- // the code.
- //
- auto layoutDecoration = type->findDecoration<IRLayoutDecoration>();
- SLANG_ASSERT(layoutDecoration);
- auto layout = layoutDecoration->getLayout();
- SLANG_ASSERT(layout);
- auto taggedUnionTypeLayout = as<IRTaggedUnionTypeLayout>(layout);
- SLANG_ASSERT(taggedUnionTypeLayout);
-
- info->taggedUnionTypeLayout = taggedUnionTypeLayout;
-
- // The size of the "payload" for the different cases (everything but
- // the tag) is taken to be the offset of the tag itself.
- //
- // TODO: this might be inaccurate if the payload size isn't a multiple
- // of the tag's alignment. We should deal with that when/if we support
- // types smaller than 4 bytes in unions.
- //
- auto payloadSize = taggedUnionTypeLayout->getTagOffset().getFiniteValue();
-
- // We are going to be construction IR code that makes use of the `int`
- // and `uint` types in several cases, so we go ahead and get a pointer
- // to those types here.
- //
- auto intType = getBuilder()->getIntType();
- auto uintType = getBuilder()->getBasicType(BaseType::UInt);
-
- // For now we will use a simple stragegy for how we encode a union,
- // which depends only on the total number of bytes needed, and not
- // on the makeup of the values being stored.
- //
- // We will start by allocating one or more `uint4` values (in an
- // array for the "or more" case) to hold the bulk of any large
- // payload value.
- //
- size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets
- auto bulkVectorCount = payloadSize / bulkVectorSize;
- auto bulkFieldSize = bulkVectorCount * bulkVectorSize;
- if( bulkVectorCount )
- {
- IRType* bulkFieldType = builder->getVectorType(
- uintType,
- builder->getIntValue(intType, 4));
-
- if( bulkVectorCount > 1 )
- {
- bulkFieldType = builder->getArrayType(
- bulkFieldType,
- builder->getIntValue(intType, bulkVectorCount));
- }
-
- auto bulkFieldKey = builder->createStructKey();
- builder->createStructField(structType, bulkFieldKey, bulkFieldType);
-
- info->bulkFieldKey = bulkFieldKey;
- info->bulkFieldType = bulkFieldType;
- }
- info->bulkSize = bulkFieldSize;
-
- // The rest of the data (anything that doesn't fit in the bulk field),
- // will get allocated into a single scalar or vector of `uint`.
- //
- auto restSize = payloadSize - bulkFieldSize;
- if( restSize )
- {
- size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets
- auto restElementCount = restSize / restElementSize;
- auto restFieldSize = restElementSize * restElementCount;
- SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity
-
- IRType* restFieldType = uintType;
- if( restElementCount > 1 )
- {
- restFieldType = builder->getVectorType(
- restFieldType,
- builder->getIntValue(intType, restElementCount));
- }
-
- auto restFieldKey = builder->createStructKey();
- builder->createStructField(structType, restFieldKey, restFieldType);
-
- info->restFieldKey = restFieldKey;
- info->restFieldType = restFieldType;
- info->restSize = restFieldSize;
- }
-
- // Finally, we add a field to represent the tag.
- //
- auto tagFieldType = uintType;
- auto tagFieldKey = builder->createStructKey();
- builder->createStructField(structType, tagFieldKey, tagFieldType);
-
- info->tagFieldKey = tagFieldKey;
-
- return info;
- }
-};
-
-void desugarUnionTypes(
- IRModule* module)
-{
- DesugarUnionTypesContext context;
- context.module = module;
-
- context.processModule();
-}
-
-} // namespace Slang
diff --git a/source/slang/slang-ir-union.h b/source/slang/slang-ir-union.h
deleted file mode 100644
index 81757dced..000000000
--- a/source/slang/slang-ir-union.h
+++ /dev/null
@@ -1,18 +0,0 @@
-// slang-ir-union.h
-#pragma once
-
-namespace Slang {
-
-struct IRModule;
-
- /// Desugar any unions types, and code using them, in `module`
- ///
- /// Union types will be replaced with ordinary `struct` types that store
- /// the data of the underlying type as a "bag of bits" and references
- /// to cases of the union will be replaced with logic to extract the
- /// relevant bits.
- ///
-void desugarUnionTypes(
- IRModule* module);
-
-} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a27bf8658..38d1eb520 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1004,32 +1004,6 @@ namespace Slang
}
//
- // IRTaggedUnionTypeLayout
- //
-
- IRTaggedUnionTypeLayout::Builder::Builder(IRBuilder* irBuilder, LayoutSize tagOffset)
- : Super::Builder(irBuilder)
- {
- m_tagOffset = irBuilder->getIntValue(irBuilder->getIntType(), tagOffset.raw);
- }
-
- void IRTaggedUnionTypeLayout::Builder::addCaseTypeLayout(IRTypeLayout* typeLayout)
- {
- m_caseTypeLayoutAttrs.add(getIRBuilder()->getCaseTypeLayoutAttr(typeLayout));
- }
-
- void IRTaggedUnionTypeLayout::Builder::addOperandsImpl(List<IRInst*>& ioOperands)
- {
- ioOperands.add(m_tagOffset);
- }
-
- void IRTaggedUnionTypeLayout::Builder::addAttrsImpl(List<IRInst*>& ioOperands)
- {
- for(auto attr : m_caseTypeLayoutAttrs)
- ioOperands.add(attr);
- }
-
- //
// IRVarLayout
//
@@ -2981,17 +2955,6 @@ namespace Slang
operands);
}
- IRType* IRBuilder::getTaggedUnionType(
- UInt caseCount,
- IRType* const* caseTypes)
- {
- return (IRType*)createIntrinsicInst(
- getTypeKind(),
- kIROp_TaggedUnionType,
- caseCount,
- (IRInst* const*) caseTypes);
- }
-
IRType* IRBuilder::getBindExistentialsType(
IRInst* baseType,
UInt slotArgCount,
@@ -3335,7 +3298,6 @@ namespace Slang
IRInst* const* args)
{
auto innerReturnVal = findInnerMostGenericReturnVal(as<IRGeneric>(genericVal));
-
if (as<IRWitnessTable>(innerReturnVal))
{
return createIntrinsicInst(
@@ -3371,7 +3333,7 @@ namespace Slang
// the emit logic, but this is a reasonably early place
// to catch it.
//
- SLANG_ASSERT(witnessTableVal->getOp() != kIROp_StructKey);
+ SLANG_ASSERT(witnessTableVal && witnessTableVal->getOp() != kIROp_StructKey);
IRInst* args[] = {witnessTableVal, interfaceMethodVal};
@@ -5536,6 +5498,8 @@ namespace Slang
return emitIntrinsicInst(
getNativePtrType((IRType*)valueType->getOperand(0)), kIROp_GetNativePtr, 1, &value);
break;
+ case kIROp_ExtractExistentialType:
+ return emitGetNativePtr(value->getOperand(0));
default:
SLANG_UNEXPECTED("invalid operand type for `getNativePtr`.");
UNREACHABLE_RETURN(nullptr);
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index b0d9bb109..97f98fce2 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1749,11 +1749,6 @@ struct IRInterfaceType : IRType
IR_LEAF_ISA(InterfaceType)
};
-struct IRTaggedUnionType : IRType
-{
- IR_LEAF_ISA(TaggedUnionType)
-};
-
struct IRConjunctionType : IRType
{
IR_LEAF_ISA(ConjunctionType)
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index e9c5f8fe5..d29cc4485 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -142,11 +142,6 @@ public:
bool visitSharedTypeExpr(SharedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); }
- bool visitTaggedUnionTypeExpr(TaggedUnionTypeExpr*)
- {
- return false;
- }
-
bool visitInvokeExpr(InvokeExpr* expr)
{
PushNode pushNodeRAII(context, expr);
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index 6bccca8d3..4ec0ac64f 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -596,7 +596,7 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::createSwizzleCan
{
const char* memberNames[4] = {"x", "y", "z", "w"};
Type* elementType = nullptr;
- elementType = vectorType->elementType;
+ elementType = vectorType->getElementType();
String typeStr;
if (elementType)
typeStr = elementType->toString();
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp
index e79716975..bc12ad34f 100644
--- a/source/slang/slang-language-server.cpp
+++ b/source/slang/slang-language-server.cpp
@@ -213,7 +213,7 @@ static bool isBoolType(Type* t)
auto basicType = as<BasicExpressionType>(t);
if (!basicType)
return false;
- return basicType->baseType == BaseType::Bool;
+ return basicType->getBaseType() == BaseType::Bool;
}
String getDeclKindString(DeclRef<Decl> declRef)
@@ -303,11 +303,11 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version)
sb << " = ";
if (isBoolType(varDecl->getType()))
{
- sb << (constantInt->value ? "true" : "false");
+ sb << (constantInt->getValue() ? "true" : "false");
}
else
{
- sb << constantInt->value;
+ sb << constantInt->getValue();
}
}
else
@@ -492,6 +492,8 @@ SlangResult LanguageServer::hover(
doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col);
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
Module* parsedModule = version->getOrLoadModule(canonicalPath);
if (!parsedModule)
{
@@ -741,6 +743,8 @@ SlangResult LanguageServer::gotoDefinition(
doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col);
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
Module* parsedModule = version->getOrLoadModule(canonicalPath);
if (!parsedModule)
{
@@ -1029,6 +1033,8 @@ SlangResult LanguageServer::semanticTokens(
}
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
Module* parsedModule = version->getOrLoadModule(canonicalPath);
if (!parsedModule)
{
@@ -1073,6 +1079,7 @@ String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation
return String();
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
SignatureInformation sigInfo;
@@ -1096,7 +1103,7 @@ String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation
bool isFirst = true;
printer.getStringBuilder() << "(";
int paramIndex = 0;
- for (auto param : funcType->paramTypes)
+ for (auto param : funcType->getParamTypes())
{
if (!isFirst)
printer.getStringBuilder() << ", ";
@@ -1134,6 +1141,8 @@ String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation
String LanguageServer::getDeclRefSignature(DeclRef<Decl> declRef, String* outDocumentation, List<Slang::Range<Index>>* outParamRanges)
{
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
ASTPrinter printer(
version->linkage->getASTBuilder(),
ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
@@ -1169,6 +1178,8 @@ SlangResult LanguageServer::signatureHelp(
doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col);
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
Module* parsedModule = version->getOrLoadModule(canonicalPath);
if (!parsedModule)
{
@@ -1289,7 +1300,7 @@ SlangResult LanguageServer::signatureHelp(
printer.getStringBuilder() << "func (";
bool isFirst = true;
- for (auto param : funcType->paramTypes)
+ for (auto param : funcType->getParamTypes())
{
if (!isFirst)
printer.getStringBuilder() << ", ";
@@ -1315,12 +1326,12 @@ SlangResult LanguageServer::signatureHelp(
if (auto declRefExpr = as<DeclRefExpr>(funcExpr))
{
- if (auto aggDecl = as<AggTypeDecl>(declRefExpr->declRef.getDecl()))
+ if (auto aggDeclRef = as<AggTypeDecl>(declRefExpr->declRef))
{
// Look for initializers
- for (auto member : aggDecl->getMembersOfType<ConstructorDecl>())
+ for (auto member : getMembersOfType<ConstructorDecl>(version->linkage->getASTBuilder(), aggDeclRef))
{
- addDeclRef(version->linkage->getASTBuilder()->getSpecializedDeclRef<Decl>(member, declRefExpr->declRef.getSubst()));
+ addDeclRef(member);
}
}
else
@@ -1379,6 +1390,8 @@ SlangResult LanguageServer::documentSymbol(
return SLANG_OK;
}
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
Module* parsedModule = version->getOrLoadModule(canonicalPath);
if (!parsedModule)
{
@@ -1400,6 +1413,8 @@ SlangResult LanguageServer::inlayHint(const LanguageServerProtocol::InlayHintPar
return SLANG_OK;
}
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
Module* parsedModule = version->getOrLoadModule(canonicalPath);
if (!parsedModule)
{
@@ -1518,6 +1533,8 @@ void LanguageServer::publishDiagnostics()
m_lastDiagnosticUpdateTime = std::chrono::system_clock::now();
auto version = m_workspace->getCurrentVersion();
+ SLANG_AST_BUILDER_RAII(version->linkage->getASTBuilder());
+
// Send updates to clear diagnostics for files that no longer have any messages.
List<String> filesToRemove;
for (auto& file : m_lastPublishedDiagnostics)
diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp
index 2eca91673..89d3380e4 100644
--- a/source/slang/slang-lookup.cpp
+++ b/source/slang/slang-lookup.cpp
@@ -16,7 +16,7 @@ void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state);
//
-DeclRef<ExtensionDecl> ApplyExtensionToType(
+DeclRef<ExtensionDecl> applyExtensionToType(
SemanticsVisitor* semantics,
ExtensionDecl* extDecl,
Type* type);
@@ -161,14 +161,12 @@ static bool _isUncheckedLocalVar(const Decl* decl)
static void _lookUpDirectAndTransparentMembers(
ASTBuilder* astBuilder,
Name* name,
- DeclRef<ContainerDecl> containerDeclRef,
+ ContainerDecl* containerDecl, // The container decl to find member with `name`.
+ DeclRef<Decl> parentDeclRef, // The parent of the resulting declref.
LookupRequest const& request,
LookupResult& result,
BreadcrumbInfo* inBreadcrumbs)
{
- ContainerDecl* containerDecl = containerDeclRef.getDecl();
-
-
if (request.isCompletionRequest())
{
// If we are looking up for completion suggestions,
@@ -182,7 +180,7 @@ static void _lookUpDirectAndTransparentMembers(
AddToLookupResult(
result,
CreateLookupResultItem(
- astBuilder->getSpecializedDeclRef<Decl>(member, containerDeclRef.getSubst()), inBreadcrumbs));
+ astBuilder->getMemberDeclRef<Decl>(parentDeclRef, member), inBreadcrumbs));
}
}
else
@@ -207,7 +205,7 @@ static void _lookUpDirectAndTransparentMembers(
continue;
// The declaration passed the test, so add it!
- AddToLookupResult(result, CreateLookupResultItem(astBuilder->getSpecializedDeclRef<Decl>(m, containerDeclRef.getSubst()), inBreadcrumbs));
+ AddToLookupResult(result, CreateLookupResultItem(astBuilder->getMemberDeclRef<Decl>(parentDeclRef, m), inBreadcrumbs));
}
}
@@ -215,9 +213,9 @@ static void _lookUpDirectAndTransparentMembers(
// if we already has a hit in the current container?
for(auto transparentInfo : containerDecl->getTransparentMembers())
{
- // The reference to the transparent member should use whatever
- // substitutions we used in referring to its outer container
- DeclRef<Decl> transparentMemberDeclRef = astBuilder->getSpecializedDeclRef(transparentInfo.decl, containerDeclRef.getSubst());
+ // The reference to the transparent member should use the same
+ // path as we used in referring to its parent.
+ DeclRef<Decl> transparentMemberDeclRef = astBuilder->getMemberDeclRef(parentDeclRef, transparentInfo.decl);
// We need to leave a breadcrumb so that we know that the result
// of lookup involves a member lookup step here
@@ -262,7 +260,8 @@ LookupResult lookUpDirectAndTransparentMembers(
ASTBuilder* astBuilder,
SemanticsVisitor* semantics,
Name* name,
- DeclRef<ContainerDecl> containerDeclRef,
+ ContainerDecl* containerDecl,
+ DeclRef<Decl> parentDeclRef,
LookupMask mask)
{
LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, nullptr);
@@ -270,36 +269,14 @@ LookupResult lookUpDirectAndTransparentMembers(
_lookUpDirectAndTransparentMembers(
astBuilder,
name,
- containerDeclRef,
+ containerDecl,
+ parentDeclRef,
request,
result,
nullptr);
return result;
}
-static SubtypeWitness* _makeSubtypeWitness(
- ASTBuilder* astBuilder,
- Type* subType,
- SubtypeWitness* subToMidWitness,
- Type* superType,
- SubtypeWitness* midtoSuperWitness)
-{
- SLANG_UNUSED(subType);
- SLANG_UNUSED(superType);
-
- if(subToMidWitness)
- {
- auto transitiveWitness = astBuilder->getTransitiveSubtypeWitness(
- subToMidWitness,
- midtoSuperWitness);
- return transitiveWitness;
- }
- else
- {
- return midtoSuperWitness;
- }
-}
-
// Specialize `declRefToSpecialize` with ThisType info if `superType` is an interface type.
DeclRef<Decl> _maybeSpecializeSuperTypeDeclRef(
ASTBuilder* astBuilder,
@@ -309,14 +286,10 @@ DeclRef<Decl> _maybeSpecializeSuperTypeDeclRef(
{
if (auto superDeclRefType = as<DeclRefType>(superType))
{
- if (auto superInterfaceDeclRef = superDeclRefType->declRef.as<InterfaceDecl>())
+ if (auto superInterfaceDeclRef = superDeclRefType->getDeclRef().as<InterfaceDecl>())
{
- ThisTypeSubstitution* thisTypeSubst = astBuilder->getOrCreateThisTypeSubstitution(
- superInterfaceDeclRef.getDecl(),
- subIsSuperWitness,
- declRefToSpecialize.getSubst());
-
- auto specializedDeclRef = astBuilder->getSpecializedDeclRef<Decl>(declRefToSpecialize.getDecl(), thisTypeSubst);
+ ThisTypeDecl* thisTypeDecl = superInterfaceDeclRef.getDecl()->getThisTypeDecl();
+ auto specializedDeclRef = astBuilder->getLookupDeclRef(subIsSuperWitness, thisTypeDecl);
return specializedDeclRef;
}
@@ -332,7 +305,7 @@ static Type* _maybeSpecializeSuperType(
{
if (auto superDeclRefType = as<DeclRefType>(superType))
{
- auto specializedDeclRef = _maybeSpecializeSuperTypeDeclRef(astBuilder, superDeclRefType->declRef, superType, subIsSuperWitness);
+ auto specializedDeclRef = _maybeSpecializeSuperTypeDeclRef(astBuilder, superDeclRefType->getDeclRef(), superType, subIsSuperWitness);
return DeclRefType::create(astBuilder, specializedDeclRef);
}
@@ -391,14 +364,21 @@ static void _lookUpMembersInSuperType(
}
static void _lookUpMembersInSuperTypeDeclImpl(
- ASTBuilder* astBuilder,
- Name* name,
+ ASTBuilder* astBuilder,
+ Name* name,
DeclRef<Decl> declRef,
- LookupRequest const& request,
- LookupResult& ioResult,
- BreadcrumbInfo* inBreadcrumbs)
+ LookupRequest const& request,
+ LookupResult& ioResult,
+ BreadcrumbInfo* inBreadcrumbs)
{
auto semantics = request.semantics;
+ if (!as<InterfaceDecl>(declRef.getDecl()) && name->text == "This")
+ {
+ // If we are looking for `This` in anything other than an InterfaceDecl,
+ // we just need to return the declRef itself.
+ AddToLookupResult(ioResult, CreateLookupResultItem(declRef, inBreadcrumbs));
+ return;
+ }
// If the semantics context hasn't been established yet (e.g. when looking up during parsing),
// we simply do a direct lookup without considering subtypes or extensions.
@@ -408,7 +388,7 @@ static void _lookUpMembersInSuperTypeDeclImpl(
// In this case we can only lookup in an aggregate type.
if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>())
{
- _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs);
+ _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef.getDecl(), aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs);
}
return;
}
@@ -464,7 +444,7 @@ static void _lookUpMembersInSuperTypeDeclImpl(
// relying on the modifier.
if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(facet->subtypeWitness))
{
- auto inheritanceDeclRef = declaredSubtypeWitness->declRef;
+ auto inheritanceDeclRef = declaredSubtypeWitness->getDeclRef();
if (inheritanceDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>())
continue;
}
@@ -473,6 +453,7 @@ static void _lookUpMembersInSuperTypeDeclImpl(
BreadcrumbInfo* newBreadcrumbs = inBreadcrumbs;
BreadcrumbInfo subtypeInfo;
+ auto parentDeclRef = containerDeclRef;
if (facet->directness != Facet::Directness::Self)
{
// Depending on the type of the facet, we may want to specialize the
@@ -487,9 +468,15 @@ static void _lookUpMembersInSuperTypeDeclImpl(
// we should also specialize the interface declRef with the concrete
// type info.
//
- containerDeclRef = _maybeSpecializeSuperTypeDeclRef(
+ parentDeclRef = _maybeSpecializeSuperTypeDeclRef(
astBuilder, containerDeclRef, facet->getType(), facet->subtypeWitness)
.as<ContainerDecl>();
+ if (as<ThisTypeDecl>(parentDeclRef.getDecl()) && name->text == "This")
+ {
+ // If we are going looking for `This` in a `ThisType`, we just need to return the declRef itself.
+ AddToLookupResult(ioResult, CreateLookupResultItem(parentDeclRef, inBreadcrumbs));
+ continue;
+ }
// If we are looking up in a base type, we also need to make sure
// to create a breadcrumb to track the sub to super indirection.
@@ -502,7 +489,7 @@ static void _lookUpMembersInSuperTypeDeclImpl(
newBreadcrumbs = &subtypeInfo;
}
}
- _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef, request, ioResult, newBreadcrumbs);
+ _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef.getDecl(), parentDeclRef, request, ioResult, newBreadcrumbs);
}
}
@@ -540,7 +527,7 @@ static void _lookUpMembersInSuperTypeImpl(
if(auto declRefType = as<DeclRefType>(superType))
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
_lookUpMembersInSuperTypeDeclImpl(astBuilder, name, declRef, request, ioResult, inBreadcrumbs);
}
@@ -551,36 +538,16 @@ static void _lookUpMembersInSuperTypeImpl(
// lookup will have a comparable substitution applied (allowing things like associated
// types, etc. used in the signature of a method to resolve correctly).
//
- auto interfaceDeclRef = extractExistentialType->getSpecializedInterfaceDeclRef();
- _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, interfaceDeclRef, request, ioResult, inBreadcrumbs);
- }
- else if( auto thisType = as<ThisType>(superType) )
- {
- // We need to create a witness that represents the next link in the
- // chain. The `leafIsSuperWitness` represents the knowledge that `leafType : superType`
- // (and we know that `superType == thisType`, but we now need to extend that
- // with the knowledge that `thisType : thisType->interfaceTypeDeclRef`.
- //
- auto interfaceType = DeclRefType::create(astBuilder, thisType->interfaceDeclRef);
-
- auto superIsInterfaceWitness = astBuilder->getThisTypeSubtypeWitness(superType, interfaceType);
-
- auto leafIsInterfaceWitness = _makeSubtypeWitness(
- astBuilder,
- leafType,
- leafIsSuperWitness,
- interfaceType,
- superIsInterfaceWitness);
-
- _lookUpMembersInSuperType(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, request, ioResult, inBreadcrumbs);
+ auto thisTypeDeclRef = extractExistentialType->getThisTypeDeclRef();
+ _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, thisTypeDeclRef, request, ioResult, inBreadcrumbs);
}
else if( auto andType = as<AndType>(superType) )
{
// We have a type of the form `leftType & rightType` and we need to perform
// lookup in both `leftType` and `rightType`.
//
- auto leftType = andType->left;
- auto rightType = andType->right;
+ auto leftType = andType->getLeft();
+ auto rightType = andType->getRight();
// Operationally, we are in a situation where we have a witness
// that the `leafType` we are doing lookup on is an subtype
@@ -731,7 +698,7 @@ static void _lookUpInScopes(
// just a decl.
//
DeclRef<ContainerDecl> containerDeclRef =
- astBuilder->getSpecializedDeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, request.semantics, containerDecl)).as<ContainerDecl>();
+ createDefaultSubstitutionsIfNeeded(astBuilder, request.semantics, makeDeclRef(containerDecl)).as<ContainerDecl>();
// If the container we are looking into represents a type
// or an `extension` of a type, then we need to treat
@@ -755,7 +722,7 @@ static void _lookUpInScopes(
breadcrumb.thisParameterMode = thisParameterMode;
breadcrumb.declRef = aggTypeDeclBaseRef;
breadcrumb.prev = nullptr;
-
+ BreadcrumbInfo* breadcrumbPtr = &breadcrumb;
Type* type = nullptr;
if (auto extDeclRef = aggTypeDeclBaseRef.as<ExtensionDecl>())
{
@@ -773,10 +740,25 @@ static void _lookUpInScopes(
else
{
assert(aggTypeDeclBaseRef.as<AggTypeDecl>());
- type = DeclRefType::create(astBuilder, aggTypeDeclBaseRef);
+ if (auto interfaceBase = as<InterfaceDecl>(aggTypeDeclBaseRef.getDecl()))
+ {
+ // When looking up inside an interface type, we are actually looking up through ThisType.
+ if (name != interfaceBase->getThisTypeDecl()->getName())
+ {
+ type = DeclRefType::create(astBuilder, astBuilder->getMemberDeclRef(aggTypeDeclBaseRef, interfaceBase->getThisTypeDecl()));
+ // Don't need any breadcrumb for looking up through ThisType, since we have already
+ // created the base type reference in the new `type`'s declref.
+ breadcrumbPtr = nullptr;
+ }
+ }
+
+ if (!type)
+ {
+ type = DeclRefType::create(astBuilder, aggTypeDeclBaseRef);
+ }
}
- _lookUpMembersInType(astBuilder, name, type, request, result, &breadcrumb);
+ _lookUpMembersInType(astBuilder, name, type, request, result, breadcrumbPtr);
}
else
{
@@ -784,7 +766,7 @@ static void _lookUpInScopes(
// type or `extension` declaration, so we can look up members
// in that scope much more simply.
//
- _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef, request, result, nullptr);
+ _lookUpDirectAndTransparentMembers(astBuilder, name, containerDeclRef.getDecl(), containerDeclRef, request, result, nullptr);
}
// Before we proceed up to the next outer scope to perform lookup
diff --git a/source/slang/slang-lookup.h b/source/slang/slang-lookup.h
index 69374c024..8af760f70 100644
--- a/source/slang/slang-lookup.h
+++ b/source/slang/slang-lookup.h
@@ -35,7 +35,8 @@ LookupResult lookUpDirectAndTransparentMembers(
ASTBuilder* astBuilder,
SemanticsVisitor* semantics,
Name* name,
- DeclRef<ContainerDecl> containerDeclRef,
+ ContainerDecl* containerDecl,
+ DeclRef<Decl> parentDeclRef, // The parent of the resulting declref.
LookupMask mask = LookupMask::Default);
// TODO: this belongs somewhere else
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 0773226d1..e0e97d6e7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -804,7 +804,7 @@ LoweredValInfo emitCallToDeclRef(
if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() )
{
- if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.getDecl()) && !as<InterfaceDecl>(ctorDeclRef.getParent(context->astBuilder).getDecl()))
+ if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.getDecl()) && !as<InterfaceDecl>(ctorDeclRef.getParent().getDecl()))
{
SLANG_UNREACHABLE("stdlib error: __init() has no definition.");
}
@@ -1399,7 +1399,7 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy
{
if (auto declRefType = as<DeclRefType>(typeConstraint->sub.type))
{
- if (declRefType->declRef.getDecl() == genericParamDecl)
+ if (declRefType->getDeclRef().getDecl() == genericParamDecl)
{
supTypes.add(lowerType(context, typeConstraint->getSup().type));
}
@@ -1408,6 +1408,36 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy
}
}
+
+// Check if declRef represents a witness that `ISomeInterface.This : ISomeInterface`.
+static bool _isThisTypeSubtypeWitness(DeclRefBase* declRef)
+{
+ auto lookupDeclRef = as<LookupDeclRef>(declRef);
+ if (!lookupDeclRef)
+ return false;
+ if (!as<ThisType>(lookupDeclRef->getLookupSource()))
+ return false;
+ auto declaredWitness = as<DeclaredSubtypeWitness>(lookupDeclRef->getWitness());
+ if (!declaredWitness)
+ return false;
+ if (!as<ThisTypeConstraintDecl>(declaredWitness->getDeclRef()))
+ return false;
+ return true;
+}
+
+// Returns whether `declRef` represents a trivial lookup of an interface requirement
+// through `ThisTypeDecl` made from within the same interface Decl.
+static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase* declRef)
+{
+ if (!_isThisTypeSubtypeWitness(declRef))
+ return false;
+ // This is a lookup from an interface's This type.
+ // If the lookup is made from an interface type itself rather than an extension of it,
+ // then it is a trivial lookup and we should lower it as a struct key.
+ return context->thisTypeWitness == nullptr;
+}
+
+
//
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo>
@@ -1424,24 +1454,24 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val)
{
- return emitDeclRef(context, val->declRef,
- lowerType(context, getType(context->astBuilder, val->declRef)));
+ return emitDeclRef(context, val->getDeclRef(),
+ lowerType(context, getType(context->astBuilder, val->getDeclRef())));
}
LoweredValInfo visitFuncCallIntVal(FuncCallIntVal* val)
{
TryClauseEnvironment tryEnv;
List<IRInst*> args;
- for (auto arg : val->args)
+ for (auto arg : val->getArgs())
{
auto loweredArg = lowerVal(context, arg);
args.add(loweredArg.val);
}
- auto funcType = lowerType(context, val->funcType);
+ auto funcType = lowerType(context, val->getFuncType());
return emitCallToDeclRef(
context,
as<IRFuncType>(funcType)->getResultType(),
- val->funcDeclRef,
+ val->getFuncDeclRef(),
funcType,
args,
tryEnv);
@@ -1449,17 +1479,17 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val)
{
- auto baseVal = lowerVal(context, val->base);
+ auto baseVal = lowerVal(context, val->getBase());
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
- auto type = lowerType(context, val->type);
+ auto type = lowerType(context, val->getType());
return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
}
LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val)
{
- auto witnessVal = lowerVal(context, val->witness);
- auto key = getInterfaceRequirementKey(context, val->key);
- auto type = lowerType(context, val->type);
+ auto witnessVal = lowerVal(context, val->getWitness());
+ auto key = getInterfaceRequirementKey(context, val->getKey());
+ auto type = lowerType(context, val->getType());
return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst(
type, witnessVal.val, key));
}
@@ -1467,16 +1497,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val)
{
auto irBuilder = getBuilder();
- auto type = lowerType(context, val->type);
- auto constTerm = irBuilder->getIntValue(type, val->constantTerm);
+ auto type = lowerType(context, val->getType());
+ auto constTerm = irBuilder->getIntValue(type, val->getConstantTerm());
auto resultVal = constTerm;
- for (auto term : val->terms)
+ for (auto term : val->getTerms())
{
- auto termVal = irBuilder->getIntValue(type, term->constFactor);
- for (auto factor : term->paramFactors)
+ auto termVal = irBuilder->getIntValue(type, term->getConstFactor());
+ for (auto factor : term->getParamFactors())
{
- auto factorVal = lowerVal(context, factor->param).val;
- for (IntegerLiteralValue i = 0; i < factor->power; i++)
+ auto factorVal = lowerVal(context, factor->getParam()).val;
+ for (IntegerLiteralValue i = 0; i < factor->getPower(); i++)
{
termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal);
}
@@ -1488,9 +1518,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val)
{
- return emitDeclRef(context, val->declRef,
+ if (as<ThisTypeConstraintDecl>(val->getDeclRef()))
+ return LoweredValInfo::simple(context->thisTypeWitness);
+
+ return emitDeclRef(context, val->getDeclRef(),
context->irBuilder->getWitnessTableType(
- lowerType(context, val->sup)));
+ lowerType(context, val->getSup())));
}
LoweredValInfo visitTransitiveSubtypeWitness(
@@ -1498,7 +1531,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
// The base (subToMid) will turn into a value with
// witness-table type.
- IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid);
+ IRInst* baseWitnessTable = lowerSimpleVal(context, val->getSubToMid());
IRInst* midToSup = nullptr;
// The next step should map to an interface requirement
@@ -1530,17 +1563,17 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(midToSup);
}
- if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->midToSup))
+ if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->getMidToSup()))
{
- midToSup = getInterfaceRequirementKey(context, declaredMidToSup->declRef.getDecl());
+ midToSup = getInterfaceRequirementKey(context, declaredMidToSup->getDeclRef().getDecl());
}
else
{
- midToSup = lowerSimpleVal(context, val->midToSup);
+ midToSup = lowerSimpleVal(context, val->getMidToSup());
}
return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst(
- getBuilder()->getWitnessTableType(lowerType(context, val->sup)),
+ getBuilder()->getWitnessTableType(lowerType(context, val->getSup())),
baseWitnessTable,
midToSup));
}
@@ -1550,7 +1583,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// TODO: properly fill in type info here.
// We should consider fold all cases of witness table entries to `Val`, and make the `DeclRef` case a `DeclRefVal`.
// So that we can hold the type in `DeclRefVal`.
- auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind());
SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
auto diff = getBuilder()->emitForwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val);
@@ -1559,7 +1592,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitBackwardDifferentiateVal(BackwardDifferentiateVal* val)
{
- auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind());
SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
auto diff = getBuilder()->emitBackwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val);
@@ -1568,7 +1601,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitBackwardDifferentiatePropagateVal(BackwardDifferentiatePropagateVal* val)
{
- auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind());
SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
auto diff = getBuilder()->emitBackwardDifferentiatePropagateInst(getBuilder()->getTypeKind(), funcVal.val);
@@ -1577,7 +1610,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitBackwardDifferentiatePrimalVal(BackwardDifferentiatePrimalVal* val)
{
- auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind());
SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
auto diff = getBuilder()->emitBackwardDifferentiatePrimalInst(getBuilder()->getTypeKind(), funcVal.val);
@@ -1586,280 +1619,18 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitBackwardDifferentiateIntermediateTypeVal(BackwardDifferentiateIntermediateTypeVal* val)
{
- auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ auto funcVal = emitDeclRef(context, val->getFunc(), context->irBuilder->getTypeKind());
SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
auto diff = getBuilder()->getBackwardDiffIntermediateContextType(funcVal.val);
return LoweredValInfo::simple(diff);
}
- LoweredValInfo visitTaggedUnionSubtypeWitness(
- TaggedUnionSubtypeWitness* val)
- {
- // The sub-type in this case is a tagged union `A | B | ...`,
- // and the witness holds an array of witnesses showing that each
- // "case" (`A`, `B`, etc.) is a subtype of the super-type.
-
- // We will start by getting the IR-level representation of the
- // sub type (the tagged union type).
- //
- auto irTaggedUnionType = lowerType(context, val->sub);
-
- // We can turn each of those per-case witnesses into a witness
- // table value:
- //
- auto caseCount = val->caseWitnesses.getCount();
- List<IRInst*> caseWitnessTables;
- for( auto caseWitness : val->caseWitnesses )
- {
- auto caseWitnessTable = lowerSimpleVal(context, caseWitness);
- caseWitnessTables.add(caseWitnessTable);
- }
-
- // Now we need to synthesize a witness table for the tagged union
- // value, showing how it can implement all of the requirements
- // of the super type by delegating to the appropriate implementation
- // on a per-case basis.
- //
- // We will assume here that the super-type is an interface, and it
- // will be left to the front-end to ensure this property.
- //
- auto supDeclRefType = as<DeclRefType>(val->sup);
- if(!supDeclRefType)
- {
- SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table");
- UNREACHABLE_RETURN(LoweredValInfo());
- }
- auto supInterfaceDeclRef = supDeclRefType->declRef.as<InterfaceDecl>();
- if( !supInterfaceDeclRef )
- {
- SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table");
- UNREACHABLE_RETURN(LoweredValInfo());
- }
-
- auto subType = lowerType(context, val->sub);
- auto irWitnessTableBaseType = lowerType(context, supDeclRefType);
- auto irWitnessTable = getBuilder()->createWitnessTable(irWitnessTableBaseType, subType);
-
- // Now we will iterate over the requirements (members) of the
- // interface and try to synthesize an appropriate value for each.
- //
- for( auto reqDeclRef : getMembers(context->astBuilder, supInterfaceDeclRef) )
- {
- // TODO: if there are any members we shouldn't process as a requirement,
- // then we should detect and skip them here.
- //
-
- // Every interface requirement will have a unique key that is used
- // when looking up the requirement in a concrete witness table.
- //
- auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl());
-
- if (!irReqKey)
- continue;
-
- // We expect that each of the witness tables in `caseWitnessTables`
- // will have an entry to match these keys. However, we may not
- // have a concrete `IRWitnessTable` for each of the case types, either
- // because they are a specialization of a generic (so that the witness
- // table reference is a `specialize` instruction at this point), or
- // they are a type external to this module (so that we have a declaration
- // rather than a definition of the witness table).
-
- // Our task is to create an IR value that can satisfy the interface
- // requirement for the tagged union type, by appropriately delegating
- // to the implementations of the same requirement in the case types.
- //
- IRInst* irSatisfyingVal = nullptr;
-
-
-
- if(auto callableDeclRef = reqDeclRef.as<CallableDecl>())
- {
- // We have something callable, so we need to synthesize
- // a function to satisfy it.
- //
- auto irFunc = getBuilder()->createFunc();
- irSatisfyingVal = irFunc;
-
- IRBuilder subBuilderStorage = *getBuilder();
- auto subBuilder = &subBuilderStorage;
- subBuilder->setInsertInto(irFunc);
-
- // We will start by setting up the function parameters,
- // which live in the entry block of the IR function.
- //
- auto entryBlock = subBuilder->emitBlock();
- subBuilder->setInsertInto(entryBlock);
-
- // Create a `this` parameter of the tagged-union type.
- //
- // TODO: need to handle the `[mutating]` case here...
- //
- auto irThisType = irTaggedUnionType;
- auto irThisParam = subBuilder->emitParam(irThisType);
-
- List<IRType*> irParamTypes;
- irParamTypes.add(irThisType);
-
- // Create the remaining parameters of the callable,
- // using a decl-ref specialized to the tagged union
- // type (so that things like associated types are
- // mapped to the correct witness value).
- //
- List<IRParam*> irParams;
- for( auto paramDeclRef : getMembersOfType<ParamDecl>(context->astBuilder, callableDeclRef) )
- {
- // TODO: need to handle `out` and `in out` here. Over all
- // there is a lot of duplication here with the existing logic
- // for emitting the signature of a `CallableDecl`, and we should
- // try to re-use that if at all possible.
- //
- auto irParamType = lowerType(context, getType(context->astBuilder, paramDeclRef));
- auto irParam = subBuilder->emitParam(irParamType);
-
- irParams.add(irParam);
- irParamTypes.add(irParamType);
- }
-
- auto irResultType = lowerType(context, getResultType(context->astBuilder, callableDeclRef));
-
- auto irFuncType = subBuilder->getFuncType(
- irParamTypes,
- irResultType);
- irFunc->setFullType(irFuncType);
-
- // The first thing our function needs to do is extract the tag
- // from the incoming `this` parameter.
- //
- auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam);
-
- // Next we want to emit a `switch` on the tag value, but before we
- // do that we need to generate the code for each of the cases so that
- // our `switch` has somewhere to branch to.
- //
- List<IRInst*> switchCaseOperands;
-
- IRBlock* defaultLabel = nullptr;
-
- for( Index ii = 0; ii < caseCount; ++ii )
- {
- auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii);
-
- subBuilder->setInsertInto(irFunc);
- auto caseLabel = subBuilder->emitBlock();
-
- if(!defaultLabel)
- defaultLabel = caseLabel;
-
- switchCaseOperands.add(caseTag);
- switchCaseOperands.add(caseLabel);
-
- subBuilder->setInsertInto(caseLabel);
-
- // We need to look up the satisfying value for this interface
- // requirement on the witness table of the particular case value.
- //
- // We already have the witness table, and the requirement key is
- // just `irReqKey`.
- //
- auto caseWitnessTable = caseWitnessTables[ii];
-
- // The subtle bit here is determining the type we expect the
- // satisfying value to have, since that depends on the actual
- // type that is satisfying the requirement.
- //
- IRType* caseResultType = irResultType;
- IRType* caseFuncType = nullptr;
- auto caseFunc = subBuilder->emitLookupInterfaceMethodInst(
- caseFuncType,
- caseWitnessTable,
- irReqKey);
-
- // We are going to emit a `call` to the satisfying value
- // for the case type, so we will collect the arguments for that call.
- //
- List<IRInst*> caseArgs;
-
- // The `this` argument to the call will need to represent the
- // appropriate field of our tagged union.
- //
- IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii);
- auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload(
- caseThisType,
- irThisParam, caseTag);
- caseArgs.add(caseThisArg);
-
- // The remaining arguments to the call will just be forwarded from
- // the parameters of the wrapper function.
- //
- // TODO: This would need to change if/when we started allowing `This` type
- // or associated-type parameters to be used at call sites where a tagged
- // union is used.
- //
- for( auto param : irParams )
- {
- caseArgs.add(param);
- }
-
- auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs);
-
- if( as<IRVoidType>(irResultType->getDataType()) )
- {
- subBuilder->emitReturn();
- }
- else
- {
- subBuilder->emitReturn(caseCall);
- }
- }
-
- // We will create a block to represent the supposedly-unreachable
- // code that will run if no `case` matches.
- //
- subBuilder->setInsertInto(irFunc);
- auto invalidLabel = subBuilder->emitBlock();
- subBuilder->setInsertInto(invalidLabel);
- subBuilder->emitUnreachable();
-
- if(!defaultLabel) defaultLabel = invalidLabel;
-
- // Now we have enough information to go back and emit the `switch` instruction
- // into the entry block.
- subBuilder->setInsertInto(entryBlock);
- subBuilder->emitSwitch(
- irTagVal, // value to `switch` on
- invalidLabel, // `break` label (block after the `switch` statement ends)
- defaultLabel, // `default` label (where to go if no `case` matches)
- switchCaseOperands.getCount(),
- switchCaseOperands.getBuffer());
- }
- else
- {
- // TODO: We need to handle other cases of interface requirements.
- SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table");
- UNREACHABLE_RETURN(LoweredValInfo());
- }
-
- // Once we've generating a value to satisfying the requirement, we install
- // it into the witness table for our tagged-union type.
- //
- getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal);
- }
- return LoweredValInfo::simple(irWitnessTable);
- }
-
LoweredValInfo visitDynamicSubtypeWitness(DynamicSubtypeWitness * /*val*/)
{
return LoweredValInfo::simple(nullptr);
}
- LoweredValInfo visitThisTypeSubtypeWitness(ThisTypeSubtypeWitness* val)
- {
- SLANG_UNUSED(val);
- return LoweredValInfo::simple(context->thisTypeWitness);
- }
-
LoweredValInfo visitConjunctionSubtypeWitness(ConjunctionSubtypeWitness* val)
{
// A witness `W = X & Y & ...` will lower as a tuple of the sub-witnesses
@@ -1892,14 +1663,14 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// witness that `T : L & R`, so lower that first and expect it to be
// a value of tuple type.
//
- auto conjunctionWitness = lowerSimpleVal(context, val->conjunctionWitness);
+ auto conjunctionWitness = lowerSimpleVal(context, val->getConjunctionWitness());
auto conjunctionTupleType = as<IRTupleType>(conjunctionWitness->getDataType());
SLANG_ASSERT(conjunctionTupleType);
// The `ExtractFromConjunctionSubtypeWitness` also stores the index of
// the witness/supertype we want in the conjunction `L & R`.
//
- auto indexInConjunction = val->indexInConjunction;
+ auto indexInConjunction = val->getIndexInConjunction();
// We want to extract the appropriate element from the tuple based on
// the index, but to know the type of the result we need to look up
@@ -1923,8 +1694,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
{
- auto type = lowerType(context, val->type);
- return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value));
+ auto type = lowerType(context, val->getType());
+ return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue()));
}
IRFuncType* visitFuncType(FuncType* type)
@@ -1964,7 +1735,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
IRType* visitDeclRefType(DeclRefType* type)
{
- auto declRef = type->declRef;
+ auto declRef = type->getDeclRef();
auto decl = declRef.getDecl();
// Check for types with teh `__intrinsic_type` modifier.
@@ -1988,13 +1759,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
IRType* visitBasicExpressionType(BasicExpressionType* type)
{
return getBuilder()->getBasicType(
- type->baseType);
+ type->getBaseType());
}
IRType* visitVectorExpressionType(VectorExpressionType* type)
{
- auto elementType = lowerType(context, type->elementType);
- auto elementCount = lowerSimpleVal(context, type->elementCount);
+ auto elementType = lowerType(context, type->getElementType());
+ auto elementCount = lowerSimpleVal(context, type->getElementCount());
return getBuilder()->getVectorType(
elementType,
@@ -2030,19 +1801,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
}
}
- // Lower substitution args and collect them into a list of IR operands.
- void _collectSubstitutionArgs(List<IRInst*>& operands, Substitutions* subst)
- {
- if (!subst) return;
- _collectSubstitutionArgs(operands, subst->getOuter());
- if (auto genSubst = as<GenericSubstitution>(subst))
- {
- for (auto arg : genSubst->getArgs())
- {
- operands.add(lowerVal(context, arg).val);
- }
- }
- }
// Lower a type where the type declaration being referenced is assumed
// to be an intrinsic type, which can thus be lowered to a simple IR
// type with the appropriate opcode.
@@ -2050,13 +1808,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
SLANG_ASSERT(getBuilder()->getInsertLoc().getMode() != IRInsertLoc::Mode::None);
- auto intrinsicTypeModifier = type->declRef.getDecl()->findModifier<IntrinsicTypeModifier>();
+ auto intrinsicTypeModifier = type->getDeclRef().getDecl()->findModifier<IntrinsicTypeModifier>();
SLANG_ASSERT(intrinsicTypeModifier);
IROp op = IROp(intrinsicTypeModifier->irOp);
List<IRInst*> operands;
// If there are any substitutions attached to the declRef,
// add them as operands of the IR type.
- _collectSubstitutionArgs(operands, type->declRef.getSubst());
+ SubstitutionSet(type->getDeclRef()).forEachSubstitutionArg([&](Val* arg)
+ {
+ operands.add(lowerVal(context, arg).val);
+ });
return getBuilder()->getType(
op,
static_cast<UInt>(operands.getCount()),
@@ -2095,7 +1856,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
IRType* visitExtractExistentialType(ExtractExistentialType* type)
{
- auto declRef = type->declRef;
+ auto declRef = type->getDeclRef();
auto existentialType = lowerType(context, getType(context->astBuilder, declRef));
IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType));
return getBuilder()->emitExtractExistentialType(existentialVal);
@@ -2103,50 +1864,20 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitExtractExistentialSubtypeWitness(ExtractExistentialSubtypeWitness* witness)
{
- auto declRef = witness->declRef;
+ auto declRef = witness->getDeclRef();
auto existentialType = lowerType(context, getType(context->astBuilder, declRef));
IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType));
return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal));
}
- LoweredValInfo visitTaggedUnionType(TaggedUnionType* type)
- {
- // A tagged union type will lower into an IR `union` over the cases,
- // along with an IR `struct` with a field for the union and a tag.
- // (Note: we are placing the tag after the payload to avoid padding
- // in the case where the payload is more aligned than the tag)
- //
- // TODO: should we be lowering directly like this, or have
- // an IR-level representation of tagged unions?
- //
-
- List<IRType*> irCaseTypes;
- for(auto caseType : type->caseTypes)
- {
- auto irCaseType = lowerType(context, caseType);
- irCaseTypes.add(irCaseType);
- }
-
- auto irType = getBuilder()->getTaggedUnionType(irCaseTypes);
- if(!irType->findDecoration<IRLinkageDecoration>())
- {
- // We need a way for later passes to attach layout information
- // to this type, so we will give it a mangled name here.
- //
- getBuilder()->addExportDecoration(
- irType,
- getMangledTypeName(context->astBuilder, type).getUnownedSlice());
- }
- return LoweredValInfo::simple(irType);
- }
-
LoweredValInfo visitExistentialSpecializedType(ExistentialSpecializedType* type)
{
- auto irBaseType = lowerType(context, type->baseType);
+ auto irBaseType = lowerType(context, type->getBaseType());
List<IRInst*> slotArgs;
- for(auto arg : type->args)
+ for (Index i = 0; i < type->getArgCount(); i++)
{
+ auto arg = type->getArg(i);
auto irArgVal = lowerSimpleVal(context, arg.val);
slotArgs.add(irArgVal);
@@ -2173,13 +1904,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
//
if (context->thisType != nullptr)
return LoweredValInfo::simple(context->thisType);
- return emitDeclRef(context, type->interfaceDeclRef, getBuilder()->getTypeKind());
+ return emitDeclRef(context, makeDeclRef(type->getInterfaceDecl()), getBuilder()->getTypeKind());
}
LoweredValInfo visitAndType(AndType* type)
{
- auto left = lowerType(context, type->left);
- auto right = lowerType(context, type->right);
+ auto left = lowerType(context, type->getLeft());
+ auto right = lowerType(context, type->getRight());
auto irType = getBuilder()->getConjunctionType(left, right);
return LoweredValInfo::simple(irType);
@@ -2187,11 +1918,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitModifiedType(ModifiedType* astType)
{
- IRType* irBase = lowerType(context, astType->base);
+ IRType* irBase = lowerType(context, astType->getBase());
List<IRAttr*> irAttrs;
- for(auto astModifier : astType->modifiers)
+ for(Index i = 0; i < astType->getModifierCount(); i++)
{
+ auto astModifier = astType->getModifier(i);
IRAttr* irAttr = (IRAttr*) lowerSimpleVal(context, astModifier);
if(irAttr)
irAttrs.add(irAttr);
@@ -2237,7 +1969,8 @@ LoweredValInfo lowerVal(
{
ValLoweringVisitor visitor;
visitor.context = context;
- return visitor.dispatch(val);
+ auto resolvedVal = val->resolve();
+ return visitor.dispatch(resolvedVal);
}
IRType* lowerType(
@@ -2786,8 +2519,8 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de
DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, SemanticsVisitor* semantics, Decl* decl)
{
- DeclRef<Decl> declRef = context->astBuilder->getSpecializedDeclRef(
- decl, createDefaultSubstitutions(context->astBuilder, semantics, decl));
+ DeclRef<Decl> declRef = createDefaultSubstitutionsIfNeeded(context->astBuilder, semantics,
+ makeDeclRef(decl));
return declRef;
}
//
@@ -2808,7 +2541,7 @@ static Type* _findReplacementThisParamType(
auto targetType = getTargetType(context->astBuilder, extensionDeclRef);
if(auto targetDeclRefType = as<DeclRefType>(targetType))
{
- if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->declRef))
+ if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->getDeclRef()))
return replacementType;
}
return targetType;
@@ -2816,8 +2549,7 @@ static Type* _findReplacementThisParamType(
if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>())
{
- auto thisType = context->astBuilder->create<ThisType>();
- thisType->interfaceDeclRef = interfaceDeclRef;
+ auto thisType = DeclRefType::create(context->astBuilder, interfaceDeclRef.getDecl()->getThisTypeDecl());
return thisType;
}
@@ -2853,13 +2585,13 @@ Type* getThisParamTypeForCallable(
IRGenContext* context,
DeclRef<Decl> callableDeclRef)
{
- auto parentDeclRef = callableDeclRef.getParent(context->astBuilder);
+ auto parentDeclRef = callableDeclRef.getParent();
if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>())
- parentDeclRef = subscriptDeclRef.getParent(context->astBuilder);
+ parentDeclRef = subscriptDeclRef.getParent();
if(auto genericDeclRef = parentDeclRef.as<GenericDecl>())
- parentDeclRef = genericDeclRef.getParent(context->astBuilder);
+ parentDeclRef = genericDeclRef.getParent();
return getThisParamTypeForContainer(context, parentDeclRef);
}
@@ -2997,7 +2729,7 @@ void collectParameterLists(
// The parameters introduced by any "parent" declarations
// will need to come first, so we'll deal with that
// logic here.
- if( auto parentDeclRef = declRef.getParent(context->astBuilder) )
+ if( auto parentDeclRef = declRef.getParent() )
{
// Compute the mode to use when collecting parameters from
// the outer declaration. The most important question here
@@ -3592,7 +3324,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
{
auto innerType = type;
while (auto modifiedType = as<ModifiedType>(innerType))
- innerType = modifiedType->base;
+ innerType = modifiedType->getBase();
return innerType;
}
@@ -3607,9 +3339,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else if (auto vectorType = as<VectorExpressionType>(type))
{
- UInt elementCount = (UInt) getIntVal(vectorType->elementCount);
+ UInt elementCount = (UInt) getIntVal(vectorType->getElementCount());
- auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType));
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->getElementType()));
List<IRInst*> args;
for(UInt ee = 0; ee < elementCount; ++ee)
@@ -3644,7 +3376,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else if (auto declRefType = as<DeclRefType>(type))
{
- DeclRef<Decl> declRef = declRefType->declRef;
+ DeclRef<Decl> declRef = declRefType->getDeclRef();
if (auto enumType = declRef.as<EnumDecl>())
{
return LoweredValInfo::simple(getBuilder()->getIntValue(irType, 0));
@@ -3735,7 +3467,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else if (auto vectorType = as<VectorExpressionType>(type))
{
- UInt elementCount = (UInt) getIntVal(vectorType->elementCount);
+ UInt elementCount = (UInt) getIntVal(vectorType->getElementCount());
for (UInt ee = 0; ee < argCount; ++ee)
{
@@ -3745,7 +3477,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
if(elementCount > argCount)
{
- auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType));
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->getElementType()));
for(UInt ee = argCount; ee < elementCount; ++ee)
{
args.add(irDefaultValue);
@@ -3781,7 +3513,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else if (auto declRefType = as<DeclRefType>(type))
{
- DeclRef<Decl> declRef = declRefType->declRef;
+ DeclRef<Decl> declRef = declRefType->getDeclRef();
if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>())
{
UInt argCounter = 0;
@@ -3896,19 +3628,19 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
- void _lowerSubstitutionArg(IRGenContext* subContext, GenericSubstitution* subst, Decl* paramDecl, Index argIndex)
+ void _lowerSubstitutionArg(IRGenContext* subContext, GenericAppDeclRef* subst, Decl* paramDecl, Index argIndex)
{
SLANG_ASSERT(argIndex < subst->getArgs().getCount());
auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]);
subContext->setValue(paramDecl, argVal);
}
- void _lowerSubstitutionEnv(IRGenContext* subContext, Substitutions* subst)
+ void _lowerSubstitutionEnv(IRGenContext* subContext, DeclRefBase* subst)
{
if(!subst) return;
- _lowerSubstitutionEnv(subContext, subst->getOuter());
+ _lowerSubstitutionEnv(subContext, subst->getBase());
- if (auto genSubst = as<GenericSubstitution>(subst))
+ if (auto genSubst = as<GenericAppDeclRef>(subst))
{
auto genDecl = genSubst->getGenericDecl();
@@ -3985,7 +3717,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRGenContext* subContext = &subContextStorage;
subContext->env = subEnv;
- _lowerSubstitutionEnv(subContext, argExpr.getSubsts());
+ _lowerSubstitutionEnv(subContext, argExpr.getSubsts() ? argExpr.getSubsts().declRef : nullptr);
addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups);
@@ -4148,7 +3880,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
{
if(auto declRefType = as<DeclRefType>(e->type))
{
- if(declRefType->declRef.as<InterfaceDecl>())
+ if(declRefType->getDeclRef().as<InterfaceDecl>())
{
e = castExpr->valueArg;
continue;
@@ -4387,7 +4119,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
{
if( auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(subTypeWitness) )
{
- return extractField(superType, value, declaredSubtypeWitness->declRef);
+ return extractField(superType, value, declaredSubtypeWitness->getDeclRef());
}
else
{
@@ -4414,7 +4146,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
//
if( auto declRefType = as<DeclRefType>(expr->type) )
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
if( auto interfaceDeclRef = declRef.as<InterfaceDecl>() )
{
// We have an expression that is "up-casting" some concrete value
@@ -4573,12 +4305,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
- LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/)
- {
- SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation");
- UNREACHABLE_RETURN(LoweredValInfo());
- }
-
LoweredValInfo visitThisTypeExpr(ThisTypeExpr* /*expr*/)
{
SLANG_UNIMPLEMENTED_X("this-type expression during code generation");
@@ -6676,7 +6402,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (auto declRefType = as<DeclRefType>(type))
{
- if (declRefType->declRef.getDecl()->findModifier<PublicModifier>())
+ if (declRefType->getDeclRef().getDecl()->findModifier<PublicModifier>())
return true;
}
return false;
@@ -6690,7 +6416,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
auto subBuilder = subContext->irBuilder;
- for(auto entry : astWitnessTable->requirementDictionary)
+ for(auto entry : astWitnessTable->getRequirementDictionary())
{
auto requiredMemberDecl = entry.key;
auto satisfyingWitness = entry.value;
@@ -6787,7 +6513,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto targetType = parentExtensionDecl->targetType;
if(auto targetDeclRefType = as<DeclRefType>(targetType))
{
- if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>())
+ if(auto targetInterfaceDeclRef = targetDeclRefType->getDeclRef().as<InterfaceDecl>())
{
return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl));
}
@@ -6815,7 +6541,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if(auto superDeclRefType = as<DeclRefType>(superType))
{
- if( superDeclRefType->declRef.as<StructDecl>() || superDeclRefType->declRef.as<ClassDecl>() )
+ if( superDeclRefType->getDeclRef().as<StructDecl>() || superDeclRefType->getDeclRef().as<ClassDecl>() )
{
// TODO: the witness that a type inherits from a `struct`
// type should probably be a key that will be used for
@@ -7675,6 +7401,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo::simple(finishOuterGenerics(subBuilder, loweredTagType, outerGeneric));
}
+ LoweredValInfo visitThisTypeDecl(ThisTypeDecl* decl)
+ {
+ auto interfaceType = ensureDecl(context, decl->parentDecl).val;
+ return LoweredValInfo::simple(context->irBuilder->getThisType(as<IRInterfaceType>(interfaceType)));
+ }
+
+ LoweredValInfo visitThisTypeConstraintDecl(ThisTypeConstraintDecl* decl)
+ {
+ SLANG_UNUSED(decl);
+ return LoweredValInfo();
+ }
+
LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl)
{
// Don't generate an IR `struct` for intrinsic types
@@ -7753,8 +7491,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto superType = inheritanceDecl->base;
if(auto superDeclRefType = as<DeclRefType>(superType))
{
- if (superDeclRefType->declRef.as<StructDecl>() ||
- superDeclRefType->declRef.as<ClassDecl>())
+ if (superDeclRefType->getDeclRef().as<StructDecl>() ||
+ superDeclRefType->getDeclRef().as<ClassDecl>())
{
auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl));
auto irSuperType = lowerType(context, superType.type);
@@ -7890,10 +7628,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
void lowerDifferentiableAttribute(IRGenContext* subContext, IRInst* inst, DifferentiableAttribute* attr)
{
auto irDict = getBuilder()->addDifferentiableTypeDictionaryDecoration(inst);
- for (auto& entry : attr->m_mapTypeToIDifferentiableWitness)
+ for (auto& entry : attr->getMapTypeToIDifferentiableWitness())
{
// Lower type and witness.
- IRType* irType = lowerType(subContext, entry.value->sub);
+ IRType* irType = lowerType(subContext, entry.value->getSub());
IRInst* irWitness = lowerVal(subContext, entry.value).val;
SLANG_ASSERT(irType);
@@ -8028,7 +7766,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type))
{
- auto typeParamDeclVal = subContext->findLoweredDecl(declRefType->declRef.getDecl());
+ auto typeParamDeclVal = subContext->findLoweredDecl(declRefType->getDeclRef().getDecl());
SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val);
subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType);
}
@@ -8162,7 +7900,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
if(auto targetDeclRefType = as<DeclRefType>(extensionAncestor->targetType))
{
- if(auto interfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>())
+ if(auto interfaceDeclRef = targetDeclRefType->getDeclRef().as<InterfaceDecl>())
{
return emitOuterInterfaceGeneric(subContext, extensionAncestor, targetDeclRefType, leafDecl);
}
@@ -8608,9 +8346,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(subContext, irFunc, decl);
addLinkageDecoration(subContext, irFunc, decl);
- if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
+ if (decl->body)
{
- lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
+ if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
+ {
+ lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
+ }
}
// Always force inline diff setter accessor to prevent downstream compiler from complaining
@@ -9375,7 +9116,7 @@ static void _addFlattenedTupleArgs(
LoweredValInfo emitDeclRef(
IRGenContext* context,
Decl* decl,
- Substitutions* subst,
+ DeclRefBase* subst,
IRType* type)
{
const auto initialSubst = subst;
@@ -9383,27 +9124,28 @@ LoweredValInfo emitDeclRef(
// We need to proceed by considering the specializations that
// have been put in place.
+ subst = SubstitutionSet(subst).getInnerMostNodeWithSubstInfo();
// If the declaration would not get wrapped in a `IRGeneric`,
// even if it is nested inside of an AST `GenericDecl`, then
// we should also ignore any generic substitutions.
if(!canDeclLowerToAGeneric(decl))
{
- while(auto genericSubst = as<GenericSubstitution>(subst))
- subst = genericSubst->getOuter();
+ while(auto genericSubst = SubstitutionSet(subst).findGenericAppDeclRef())
+ subst = genericSubst->getBase();
}
// In the simplest case, there is no specialization going
// on, and the decl-ref turns into a reference to the
// lowered IR value for the declaration.
- if(!subst)
+ if(!SubstitutionSet(subst) || _isTrivialLookupFromInterfaceThis(context, subst))
{
LoweredValInfo loweredDecl = ensureDecl(context, decl);
return loweredDecl;
}
// Otherwise, we look at the kind of substitution, and let it guide us.
- if(auto genericSubst = as<GenericSubstitution>(subst))
+ if(auto genericSubst = as<GenericAppDeclRef>(subst))
{
// A generic substitution means we will need to output
// a `specialize` instruction to specialize the generic.
@@ -9419,7 +9161,7 @@ LoweredValInfo emitDeclRef(
LoweredValInfo genericVal = emitDeclRef(
context,
decl,
- genericSubst->getOuter(),
+ genericSubst->getBase(),
context->irBuilder->getGenericKind());
// There's no reason to specialize something that maps to a NULL pointer.
@@ -9464,21 +9206,21 @@ LoweredValInfo emitDeclRef(
return LoweredValInfo::simple(irSpecializedVal);
}
- else if(auto thisTypeSubst = as<ThisTypeSubstitution>(subst))
+ else if(auto thisTypeSubst = as<LookupDeclRef>(subst))
{
- if(decl == thisTypeSubst->interfaceDecl)
+ if( as<ThisTypeDecl>(decl))
{
- // This is a reference to the interface type itself,
- // through the this-type substitution, so it is really
- // a reference to the this-type.
- return lowerType(context, thisTypeSubst->witness->sub);
+ // This is a reference to the ThisType from the interface,
+ // therefore we should just lower it as the sub type.
+ return lowerType(context, thisTypeSubst->getWitness()->getSub());
}
if(isInterfaceRequirement(decl))
{
- // Somebody is trying to look up an interface requirement
- // "through" some concrete type. We need to lower this decl-ref
- // as a lookup of the corresponding member in a witness table.
+ // If we reach here, somebody is trying to look up an interface
+ // requirement "through" some concrete type. We need to lower this
+ // decl-ref as a lookup of the corresponding member in a witness
+ // table.
//
// The witness table itself is referenced by the this-type
// substitution, so we can just lower that.
@@ -9491,7 +9233,7 @@ LoweredValInfo emitDeclRef(
// `ISomething<T>`. That is because we really care about the
// witness table for the concrete type that conforms to `ISomething<Foo>`.
//
- auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness);
+ auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->getWitness());
//
// The key to use for looking up the interface member is
// derived from the declaration.
@@ -9517,14 +9259,14 @@ LoweredValInfo emitDeclRef(
// are lowered as generics, where the generic parameter represents
// the `ThisType`.
//
- auto genericVal = emitDeclRef(context, decl, thisTypeSubst->getOuter(), context->irBuilder->getGenericKind());
+ auto genericVal = emitDeclRef(context, decl, thisTypeSubst->getBase(), context->irBuilder->getGenericKind());
auto irGenericVal = getSimpleVal(context, genericVal);
// In order to reference the member for a particular type, we
// specialize the generic for that type.
//
- IRInst* irSubType = lowerType(context, thisTypeSubst->witness->sub);
- IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->witness);
+ IRInst* irSubType = lowerType(context, thisTypeSubst->getWitness()->getSub());
+ IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->getWitness());
IRInst* irSpecializeArgs[] = { irSubType, irSubTypeWitness };
auto irSpecializedVal = context->irBuilder->emitSpecializeInst(
@@ -9550,7 +9292,7 @@ LoweredValInfo emitDeclRef(
return emitDeclRef(
context,
declRef.getDecl(),
- declRef.getSubst(),
+ declRef.declRefBase,
type);
}
@@ -9723,6 +9465,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
TranslationUnitRequest* translationUnit)
{
SLANG_PROFILE;
+ SLANG_AST_BUILDER_RAII(astBuilder);
auto session = translationUnit->getSession();
auto compileRequest = translationUnit->compileRequest;
@@ -10082,6 +9825,8 @@ RefPtr<IRModule> generateIRForSpecializedComponentType(
SpecializedComponentType* componentType,
DiagnosticSink* sink)
{
+ SLANG_AST_BUILDER_RAII(componentType->getLinkage()->getASTBuilder());
+
SpecializedComponentTypeIRGenContext context;
return context.process(componentType, sink);
}
@@ -10135,6 +9880,8 @@ RefPtr<IRModule> generateIRForTypeConformance(
Int conformanceIdOverride,
DiagnosticSink* sink)
{
+ SLANG_AST_BUILDER_RAII(typeConformance->getLinkage()->getASTBuilder());
+
TypeConformanceIRGenContext context;
return context.process(typeConformance, conformanceIdOverride, sink);
}
@@ -10296,20 +10043,6 @@ IRTypeLayout* lowerTypeLayout(
IRPointerTypeLayout::Builder builder(context->irBuilder);
return _lowerTypeLayoutCommon(context, &builder, ptrTypeLayout);
}
- else if( auto taggedUnionTypeLayout = as<TaggedUnionTypeLayout>(typeLayout) )
- {
- IRTaggedUnionTypeLayout::Builder builder(context->irBuilder, taggedUnionTypeLayout->tagOffset);
-
- for( auto caseTypeLayout : taggedUnionTypeLayout->caseTypeLayouts )
- {
- builder.addCaseTypeLayout(
- lowerTypeLayout(
- context,
- caseTypeLayout));
- }
-
- return _lowerTypeLayoutCommon(context, &builder, taggedUnionTypeLayout);
- }
else if( auto streamOutputTypeLayout = as<StreamOutputTypeLayout>(typeLayout) )
{
auto irElementTypeLayout = lowerTypeLayout(context, streamOutputTypeLayout->elementTypeLayout);
@@ -10453,6 +10186,9 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink)
auto program = getProgram();
auto linkage = program->getLinkage();
+
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+
auto session = linkage->getSessionImpl();
SharedIRGenContext sharedContextStorage(
@@ -10550,16 +10286,6 @@ RefPtr<IRModule> TargetProgram::createIRModuleForLayout(DiagnosticSink* sink)
builder->addLayoutDecoration(irFunc, irEntryPointLayout);
}
- for( auto taggedUnionTypeLayout : programLayout->taggedUnionTypeLayouts )
- {
- auto taggedUnionType = taggedUnionTypeLayout->getType();
- auto irType = lowerType(context, taggedUnionType);
-
- auto irTypeLayout = lowerTypeLayout(context, taggedUnionTypeLayout);
-
- builder->addLayoutDecoration(irType, irTypeLayout);
- }
-
// Lets strip and run DCE here
if (linkage->m_obfuscateCode)
{
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index 4d94d5283..b27a45484 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -141,7 +141,7 @@ namespace Slang
{
if( auto constVal = as<ConstantIntVal>(val) )
{
- auto cVal = constVal->value;
+ auto cVal = constVal->getValue();
if(cVal >= 0 && cVal <= 9 )
{
emit(context, (UInt)cVal);
@@ -190,13 +190,13 @@ namespace Slang
if( auto basicType = dynamicCast<BasicExpressionType>(type) )
{
- emitBaseType(context, basicType->baseType);
+ emitBaseType(context, basicType->getBaseType());
}
else if( auto vecType = dynamicCast<VectorExpressionType>(type) )
{
emitRaw(context, "v");
- emitSimpleIntVal(context, vecType->elementCount);
- emitType(context, vecType->elementType);
+ emitSimpleIntVal(context, vecType->getElementCount());
+ emitType(context, vecType->getElementType());
}
else if( auto matType = dynamicCast<MatrixExpressionType>(type) )
{
@@ -208,11 +208,11 @@ namespace Slang
}
else if( auto namedType = dynamicCast<NamedExpressionType>(type) )
{
- emitType(context, getType(context->astBuilder, namedType->declRef));
+ emitType(context, getType(context->astBuilder, namedType->getDeclRef()));
}
else if( auto declRefType = dynamicCast<DeclRefType>(type) )
{
- emitQualifiedName(context, declRefType->declRef);
+ emitQualifiedName(context, declRefType->getDeclRef());
}
else if (auto arrType = dynamicCast<ArrayExpressionType>(type))
{
@@ -220,19 +220,10 @@ namespace Slang
emitSimpleIntVal(context, arrType->getElementCount());
emitType(context, arrType->getElementType());
}
- else if( auto taggedUnionType = dynamicCast<TaggedUnionType>(type) )
- {
- emitRaw(context, "u");
- for( auto caseType : taggedUnionType->caseTypes )
- {
- emitType(context, caseType);
- }
- emitRaw(context, "U");
- }
else if( auto thisType = dynamicCast<ThisType>(type) )
{
emitRaw(context, "t");
- emitQualifiedName(context, thisType->interfaceDeclRef);
+ emitQualifiedName(context, thisType->getInterfaceDecl());
}
else if (const auto errorType = dynamicCast<ErrorType>(type))
{
@@ -300,50 +291,50 @@ namespace Slang
// "depth" (how many outer generics) and "index" (which
// parameter are they at the specified depth).
emitRaw(context, "K");
- emitName(context, genericParamIntVal->declRef.getName());
+ emitName(context, genericParamIntVal->getDeclRef().getName());
}
else if( auto constantIntVal = dynamicCast<ConstantIntVal>(val) )
{
// TODO: need to figure out what prefix/suffix is needed
// to allow demangling later.
emitRaw(context, "k");
- emit(context, (UInt) constantIntVal->value);
+ emit(context, (UInt) constantIntVal->getValue());
}
else if (auto funcCallIntVal = dynamicCast<FuncCallIntVal>(val))
{
emitRaw(context, "KC");
- emit(context, funcCallIntVal->args.getCount());
- emitName(context, funcCallIntVal->funcDeclRef.getName());
- for (Index i = 0; i < funcCallIntVal->args.getCount(); i++)
- emitVal(context, funcCallIntVal->args[i]);
+ emit(context, funcCallIntVal->getArgs().getCount());
+ emitName(context, funcCallIntVal->getFuncDeclRef().getName());
+ for (Index i = 0; i < funcCallIntVal->getArgs().getCount(); i++)
+ emitVal(context, funcCallIntVal->getArgs()[i]);
}
else if (auto lookupIntVal = dynamicCast<WitnessLookupIntVal>(val))
{
emitRaw(context, "KL");
- emitVal(context, lookupIntVal->witness);
- emitName(context, lookupIntVal->key->getName());
+ emitVal(context, lookupIntVal->getWitness());
+ emitName(context, lookupIntVal->getKey()->getName());
}
else if (const auto polynomialIntVal = dynamicCast<PolynomialIntVal>(val))
{
emitRaw(context, "KX");
- emit(context, (UInt)polynomialIntVal->constantTerm);
- emit(context, (UInt)polynomialIntVal->terms.getCount());
- for (auto term : polynomialIntVal->terms)
+ emit(context, (UInt)polynomialIntVal->getConstantTerm());
+ emit(context, (UInt)polynomialIntVal->getTerms().getCount());
+ for (auto term : polynomialIntVal->getTerms())
{
- emit(context, (UInt)term->constFactor);
- emit(context, (UInt)term->paramFactors.getCount());
- for (auto factor : term->paramFactors)
+ emit(context, (UInt)term->getConstFactor());
+ emit(context, (UInt)term->getParamFactors().getCount());
+ for (auto factor : term->getParamFactors())
{
- emitVal(context, factor->param);
- emit(context, (UInt)factor->power);
+ emitVal(context, factor->getParam());
+ emit(context, (UInt)factor->getPower());
}
}
}
else if (const auto typecastIntVal = dynamicCast<TypeCastIntVal>(val))
{
emitRaw(context, "KK");
- emitVal(context, typecastIntVal->type);
- emitVal(context, typecastIntVal->base);
+ emitVal(context, typecastIntVal->getType());
+ emitVal(context, typecastIntVal->getBase());
}
else
{
@@ -355,7 +346,7 @@ namespace Slang
ManglingContext* context,
DeclRef<Decl> declRef)
{
- auto parentDeclRef = declRef.getParent(context->astBuilder);
+ auto parentDeclRef = declRef.getParent();
auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>();
if( parentDeclRef )
{
@@ -423,14 +414,14 @@ namespace Slang
// There are two cases here: either we have specializations
// in place for the parent generic declaration, or we don't.
- auto subst = findInnerMostGenericSubstitution(declRef.getSubst());
- if( subst && subst->getGenericDecl() == parentGenericDeclRef.getDecl())
+ auto substArgs = tryGetGenericArguments(SubstitutionSet(declRef), parentGenericDeclRef.getDecl());
+ if (substArgs.getCount())
{
// This is the case where we *do* have substitutions.
emitRaw(context, "G");
- UInt genericArgCount = subst->getArgs().getCount();
+ UInt genericArgCount = substArgs.getCount();
emit(context, genericArgCount);
- for (auto aa : subst->getArgs())
+ for (auto aa : substArgs)
{
emitVal(context, aa);
}
@@ -441,7 +432,7 @@ namespace Slang
// information about the parameters of the generic here.
emitRaw(context, "g");
UInt genericParameterCount = 0;
- for( auto mm : getMembers(context->astBuilder, parentGenericDeclRef) )
+ for( auto mm : getMembers(context->astBuilder, parentGenericDeclRef.as<ContainerDecl>()) )
{
if(mm.is<GenericTypeParamDecl>())
{
@@ -569,7 +560,7 @@ namespace Slang
// mangling the generic and the inner entity
emitRaw(context, "G");
- SLANG_ASSERT(genericDecl.getSubst() == nullptr);
+ SLANG_ASSERT(SubstitutionSet(genericDecl).findGenericAppDeclRef() == nullptr);
auto innerDecl = getInner(genericDecl);
@@ -591,6 +582,7 @@ namespace Slang
static String getMangledName(ASTBuilder* astBuilder, DeclRef<Decl> const& declRef)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
ManglingContext context(astBuilder);
mangleName(&context, declRef);
return context.sb.produceString();
@@ -598,11 +590,15 @@ namespace Slang
String getMangledName(ASTBuilder* astBuilder, DeclRefBase* declRef)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
+
return getMangledName(astBuilder, DeclRef<Decl>(declRef));
}
String getMangledName(ASTBuilder* astBuilder, Decl* decl)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
+
return getMangledName(astBuilder, makeDeclRef(decl));
}
@@ -611,6 +607,7 @@ namespace Slang
DeclRef<Decl> sub,
DeclRef<Decl> sup)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
ManglingContext context(astBuilder);
emitRaw(&context, "_SW");
emitQualifiedName(&context, sub);
@@ -623,6 +620,7 @@ namespace Slang
DeclRef<Decl> sub,
Type* sup)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
// The mangled form for a witness that `sub`
// conforms to `sup` will be named:
//
@@ -640,6 +638,7 @@ namespace Slang
Type* sub,
Type* sup)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
// The mangled form for a witness that `sub`
// conforms to `sup` will be named:
//
@@ -654,6 +653,7 @@ namespace Slang
String getMangledTypeName(ASTBuilder* astBuilder, Type* type)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
ManglingContext context(astBuilder);
emitType(&context, type);
return context.sb.produceString();
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 47f370854..c0389d1cd 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2064,7 +2064,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
// otherwise they will include all of the above cases...
else if( auto declRefType = as<DeclRefType>(type) )
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
if (auto structDeclRef = declRef.as<StructDecl>())
{
@@ -2777,7 +2777,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
// Any generic specialization applied to the entry-point function
// must also be applied to its parameters.
- paramDeclRef = context->getASTBuilder()->getSpecializedDeclRef(paramDeclRef.getDecl(), entryPointFuncDeclRef.getSubst());
+ paramDeclRef = context->getASTBuilder()->getMemberDeclRef(entryPointFuncDeclRef, paramDeclRef.getDecl());
// When computing layout for an entry-point parameter,
// we want to make sure that the layout context has access
@@ -3033,24 +3033,6 @@ struct CollectParametersVisitor : ComponentTypeVisitor
// along.
//
visitChildren(specialized);
-
- // While we are at it, we will also make note of any
- // tagged-union types that were used as part of the
- // specialization arguments, since we need to make
- // sure that their layout information is computed
- // and made available for IR code generation.
- //
- // Note: this isn't really the best place for this logic to sit,
- // but it is the simplest place where we can collect all the tagged
- // union types that get referenced by a program.
- //
- for( auto taggedUnionType : specialized->getTaggedUnionTypes() )
- {
- SLANG_ASSERT(taggedUnionType);
- auto substType = taggedUnionType;
- auto typeLayout = createTypeLayout(m_context->layoutContext, substType);
- m_context->shared->programLayout->taggedUnionTypeLayouts.add(typeLayout);
- }
}
@@ -3755,6 +3737,8 @@ RefPtr<ProgramLayout> generateParameterBindings(
TargetProgram* targetProgram,
DiagnosticSink* sink)
{
+ SLANG_AST_BUILDER_RAII(targetProgram->getProgram()->getLinkage()->getASTBuilder());
+
auto program = targetProgram->getProgram();
auto targetReq = targetProgram->getTargetReq();
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index cab01d585..4448a96e1 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1323,10 +1323,9 @@ namespace Slang
// parent link is set up correctly.
static void AddMember(ContainerDecl* container, Decl* member)
{
- if (container && member)
+ if (container)
{
- member->parentDecl = container;
- container->members.add(member);
+ container->addMember(member);
}
}
@@ -1334,7 +1333,7 @@ namespace Slang
{
if (scope)
{
- AddMember(scope->containerDecl, member);
+ scope->containerDecl->addMember(member);
}
}
@@ -2149,30 +2148,6 @@ namespace Slang
}
return typeExpr;
}
-
- static Expr* parseTaggedUnionType(Parser* parser)
- {
- TaggedUnionTypeExpr* taggedUnionType = parser->astBuilder->create<TaggedUnionTypeExpr>();
-
- parser->ReadToken(TokenType::LParent);
- while(!AdvanceIfMatch(parser, MatchedTokenType::Parentheses))
- {
- auto caseType = parser->ParseTypeExp();
- taggedUnionType->caseTypes.add(caseType);
-
- if(AdvanceIf(parser, TokenType::RParent))
- break;
-
- parser->ReadToken(TokenType::Comma);
- }
-
- return taggedUnionType;
- }
-
- static NodeBase* parseTaggedUnionType(Parser* parser, void* /*unused*/)
- {
- return parseTaggedUnionType(parser);
- }
/// Parse an expression of the form __fwd_diff(fn) where fn is an
/// identifier pointing to a function.
static Expr* parseForwardDifferentiate(Parser* parser)
@@ -2234,19 +2209,6 @@ namespace Slang
return parseDispatchKernel(parser);
}
- /// Parse a `This` type expression
- static Expr* parseThisTypeExpr(Parser* parser)
- {
- ThisTypeExpr* expr = parser->astBuilder->create<ThisTypeExpr>();
- expr->scope = parser->currentScope;
- return expr;
- }
-
- static NodeBase* parseThisTypeExpr(Parser* parser, void* /*userData*/)
- {
- return parseThisTypeExpr(parser);
- }
-
// (a,b,c) style tuples, curently unused
#if 0
static Expr* parseTupleTypeExpr(Parser* parser)
@@ -2459,22 +2421,6 @@ namespace Slang
typeSpec.expr = createDeclRefType(parser, decl);
return typeSpec;
}
- // TODO: This case would not be needed if we had the
- // code below dispatch into `parseAtomicExpr`, which
- // already includes logic for keyword lookup.
- //
- // Leaving this case here for now to avoid breaking anything.
- //
- else if(AdvanceIf(parser, "__TaggedUnion"))
- {
- typeSpec.expr = parseTaggedUnionType(parser);
- return typeSpec;
- }
- else if(AdvanceIf(parser, "This"))
- {
- typeSpec.expr = parseThisTypeExpr(parser);
- return typeSpec;
- }
// Uncomment should we decide to enable (a,b,c) tuple types
// else if(parser->LookAheadToken(TokenType::LParent))
// {
@@ -3170,7 +3116,7 @@ namespace Slang
static NodeBase* parseInterfaceDecl(Parser* parser, void* /*userData*/)
{
- InterfaceDecl* decl = parser->astBuilder->create<InterfaceDecl>();
+ InterfaceDecl* decl = parser->astBuilder->createInterfaceDecl(parser->tokenReader.peekLoc());
parser->FillPosition(decl);
AdvanceIf(parser, TokenType::CompletionRequest);
@@ -4082,6 +4028,8 @@ namespace Slang
void Parser::parseSourceFile(ModuleDecl* program)
{
+ SLANG_AST_BUILDER_RAII(astBuilder);
+
if (outerScope)
{
currentScope = outerScope;
@@ -4328,6 +4276,7 @@ namespace Slang
parser->astBuilder,
nullptr, // no semantics visitor available yet
staticMemberExpr->name,
+ aggTypeDecl,
declRef);
if (!lookupResult.isValid() || lookupResult.isOverloaded())
@@ -6252,7 +6201,7 @@ namespace Slang
// Need to get the basic type, so we can fit to underlying type
if (auto basicExprType = as<BasicExpressionType>(intLit->type.type))
{
- value = _fixIntegerLiteral(basicExprType->baseType, value, nullptr, nullptr);
+ value = _fixIntegerLiteral(basicExprType->getBaseType(), value, nullptr, nullptr);
}
newLiteral->value = value;
@@ -6910,14 +6859,12 @@ namespace Slang
// !!!!!!!!!!!!!!!!!!!!!!! Expr !!!!!!!!!!!!!!!!!!!!!!!!!!!
_makeParseExpr("this", parseThisExpr),
- _makeParseExpr("This", parseThisTypeExpr),
_makeParseExpr("true", parseTrueExpr),
_makeParseExpr("false", parseFalseExpr),
_makeParseExpr("nullptr", parseNullPtrExpr),
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("no_diff", parseTreatAsDifferentiableExpr),
- _makeParseExpr("__TaggedUnion", parseTaggedUnionType),
_makeParseExpr("__fwd_diff", parseForwardDifferentiate),
_makeParseExpr("__bwd_diff", parseBackwardDifferentiate),
_makeParseExpr("fwd_diff", parseForwardDifferentiate),
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 7a79e9fcd..9f83d325d 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -269,11 +269,14 @@ SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflect
if (!userAttr) return SLANG_E_INVALID_ARG;
if (index >= (unsigned int)userAttr->args.getCount()) return SLANG_E_INVALID_ARG;
- NodeBase* val = nullptr;
- if (userAttr->intArgVals.tryGetValue(index, val))
+ if (userAttr->intArgVals.getCount() > (Index)index)
{
- *rs = (int)as<ConstantIntVal>(val)->value;
- return 0;
+ auto intVal = as<ConstantIntVal>(userAttr->intArgVals[index]);
+ if (intVal)
+ {
+ *rs = (int)intVal->getValue();
+ return 0;
+ }
}
return SLANG_E_INVALID_ARG;
}
@@ -387,7 +390,7 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType)
}
else if( auto declRefType = as<DeclRefType>(type) )
{
- const auto& declRef = declRefType->declRef;
+ const auto& declRef = declRefType->getDeclRef();
if(declRef.is<StructDecl>() )
{
return SLANG_TYPE_KIND_STRUCT;
@@ -429,7 +432,7 @@ SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inTyp
if(auto declRefType = as<DeclRefType>(type))
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
if( auto structDeclRef = declRef.as<StructDecl>())
{
return (unsigned int)getFields(
@@ -452,7 +455,7 @@ SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflect
if(auto declRefType = as<DeclRefType>(type))
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
if( auto structDeclRef = declRef.as<StructDecl>())
{
auto fields = getFields(
@@ -476,7 +479,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType)
}
else if( auto vectorType = as<VectorExpressionType>(type))
{
- return (size_t) getIntVal(vectorType->elementCount);
+ return (size_t) getIntVal(vectorType->getElementCount());
}
return 0;
@@ -493,15 +496,15 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy
}
else if( auto parameterGroupType = as<ParameterGroupType>(type))
{
- return convert(parameterGroupType->elementType);
+ return convert(parameterGroupType->getElementType());
}
else if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type))
{
- return convert(structuredBufferType->elementType);
+ return convert(structuredBufferType->getElementType());
}
else if( auto vectorType = as<VectorExpressionType>(type))
{
- return convert(vectorType->elementType);
+ return convert(vectorType->getElementType());
}
else if( auto matrixType = as<MatrixExpressionType>(type))
{
@@ -543,7 +546,7 @@ SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inTy
}
else if(auto vectorType = as<VectorExpressionType>(type))
{
- return (unsigned int) getIntVal(vectorType->elementCount);
+ return (unsigned int) getIntVal(vectorType->getElementCount());
}
else if( const auto basicType = as<BasicExpressionType>(type) )
{
@@ -564,12 +567,12 @@ SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* in
}
else if(auto vectorType = as<VectorExpressionType>(type))
{
- type = vectorType->elementType;
+ type = vectorType->getElementType();
}
if(auto basicType = as<BasicExpressionType>(type))
{
- switch (basicType->baseType)
+ switch (basicType->getBaseType())
{
#define CASE(BASE, TAG) \
case BaseType::BASE: return SLANG_SCALAR_TYPE_##TAG
@@ -606,7 +609,7 @@ SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionTyp
if (!type) return 0;
if (auto declRefType = as<DeclRefType>(type))
{
- return getUserAttributeCount(declRefType->declRef.getDecl());
+ return getUserAttributeCount(declRefType->getDeclRef().getDecl());
}
return 0;
}
@@ -616,7 +619,7 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangR
if (!type) return 0;
if (auto declRefType = as<DeclRefType>(type))
{
- return getUserAttributeByIndex(declRefType->declRef.getDecl(), index);
+ return getUserAttributeByIndex(declRefType->getDeclRef().getDecl(), index);
}
return 0;
}
@@ -626,10 +629,10 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName
if (!type) return 0;
if (auto declRefType = as<DeclRefType>(type))
{
- ASTBuilder* astBuilder = declRefType->getASTBuilder();
+ ASTBuilder* astBuilder = declRefType->getASTBuilderForReflection();
auto globalSession = astBuilder->getGlobalSession();
- return findUserAttributeByName(globalSession, declRefType->declRef.getDecl(), name);
+ return findUserAttributeByName(globalSession, declRefType->getDeclRef().getDecl(), name);
}
return 0;
}
@@ -714,7 +717,7 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType)
if( auto declRefType = as<DeclRefType>(type) )
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
// Don't return a name for auto-generated anonymous types
// that represent `cbuffer` members, etc.
@@ -778,13 +781,13 @@ SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangRefle
if (auto textureType = as<TextureTypeBase>(type))
{
- return convert(textureType->elementType);
+ return convert(textureType->getElementType());
}
// TODO: need a better way to handle this stuff...
#define CASE(TYPE, SHAPE, ACCESS) \
else if(as<TYPE>(type)) do { \
- return convert(as<TYPE>(type)->elementType); \
+ return convert(as<TYPE>(type)->getElementType()); \
} while(0)
// TODO: structured buffer needs to expose type layout!
@@ -1132,7 +1135,7 @@ SLANG_API SlangInt spReflectionType_getSpecializedTypeArgCount(SlangReflectionTy
auto specializedType = as<ExistentialSpecializedType>(type);
if(!specializedType) return 0;
- return specializedType->args.getCount();
+ return specializedType->getArgCount();
}
SLANG_API SlangReflectionType* spReflectionType_getSpecializedTypeArgType(SlangReflectionType* inType, SlangInt index)
@@ -1144,9 +1147,9 @@ SLANG_API SlangReflectionType* spReflectionType_getSpecializedTypeArgType(SlangR
if(!specializedType) return nullptr;
if(index < 0) return nullptr;
- if(index >= specializedType->args.getCount()) return nullptr;
+ if(index >= specializedType->getArgCount()) return nullptr;
- auto argType = as<Type>(specializedType->args[index].val);
+ auto argType = as<Type>(specializedType->getArg(index).val);
return convert(argType);
}
@@ -1405,7 +1408,7 @@ namespace Slang
{
if(auto declRefType = as<DeclRefType>(type))
{
- if(declRefType->declRef.as<InterfaceDecl>())
+ if(declRefType->getDeclRef().as<InterfaceDecl>())
{
return declRefType;
}
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index c28c8a6d6..351b6f519 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -39,20 +39,82 @@ struct SerialTypeInfo<SyntaxClass<T>>
}
};
+// MatrixCoord can just go as is
+template <>
+struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {};
+
+inline void serializePointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial)
+{
+ if (ptrValue)
+ ptrValue = ptrValue->resolve();
+ *(SerialIndex*)outSerial = writer->addPointer(ptrValue);
+}
+
+inline void deserializePointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr, Val* unusedForResolution)
+{
+ SLANG_UNUSED(unusedForResolution);
+
+ auto val = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<Val>();
+ *(Val**)outPtr = val;
+ if (val)
+ {
+ SLANG_ASSERT(as<Val>(val));
+ PostSerializationFixUp fixup;
+ fixup.kind = PostSerializationFixUpKind::ValPtr;
+ fixup.addressToModify = outPtr;
+ reader->getFixUps().add(fixup);
+ }
+}
template <typename T>
struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {};
-// MatrixCoord can just go as is
+// ValNodeOperand
template <>
-struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {};
+struct SerialTypeInfo<ValNodeOperand>
+{
+ typedef ValNodeOperand NativeType;
+ struct SerialType
+ {
+ int8_t kind;
+ int64_t val;
+ };
+ enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) };
+
+ static void toSerial(SerialWriter* writer, const void* native, void* serial)
+ {
+ auto& src = *(const NativeType*)native;
+ auto& dst = *(SerialType*)serial;
+ dst.kind = int8_t(src.kind);
+ if (src.kind == ValNodeOperandKind::ConstantValue)
+ dst.val = src.values.intOperand;
+ else if (src.kind == ValNodeOperandKind::ValNode)
+ serializePointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val);
+ else
+ serializePointerValue(writer, src.values.nodeOperand, (SerialIndex*)&dst.val);
+ }
+ static void toNative(SerialReader* reader, const void* serial, void* native)
+ {
+ auto& dst = *(NativeType*)native;
+ auto& src = *(const SerialType*)serial;
+
+ // Initialize
+ dst = NativeType();
+ dst.kind = ValNodeOperandKind(src.kind);
+ if (dst.kind == ValNodeOperandKind::ConstantValue)
+ dst.values.intOperand = int64_t(src.val);
+ else if (dst.kind == ValNodeOperandKind::ValNode)
+ deserializePointerValue(reader, (SerialIndex*)&src.val, (Val**)&dst.values.nodeOperand, (Val*)nullptr);
+ else
+ deserializePointerValue(reader, (SerialIndex*)&src.val, &dst.values.nodeOperand, (NodeBase*)nullptr);
+ }
+};
// LookupResultItem
SLANG_VALUE_TYPE_INFO(LookupResultItem)
// QualType
SLANG_VALUE_TYPE_INFO(QualType)
-
// LookupResult
template <>
struct SerialTypeInfo<LookupResult>
@@ -151,10 +213,6 @@ struct SerialTypeInfo<Modifiers>
}
};
-// ASTNodeType
-template <>
-struct SerialTypeInfo<ASTNodeType> : public SerialConvertTypeInfo<ASTNodeType, uint16_t> {};
-
// LookupResultItem_Breadcrumb::ThisParameterMode
template <>
struct SerialTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode> : public SerialConvertTypeInfo<LookupResultItem_Breadcrumb::ThisParameterMode, uint8_t> {};
@@ -170,6 +228,7 @@ struct SerialTypeInfo<RequirementWitness::Flavor> : public SerialConvertTypeInfo
// RequirementWitness
SLANG_VALUE_TYPE_INFO(RequirementWitness)
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp
index c75237896..293535b02 100644
--- a/source/slang/slang-serialize-container.cpp
+++ b/source/slang/slang-serialize-container.cpp
@@ -475,10 +475,6 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
// Set the sourceLocReader before doing de-serialize, such can lookup the remapped sourceLocs
reader.getExtraObjects().set(sourceLocReader);
- // Go through all of the AST nodes
- // 1) Set the ASTBuilder on Type nodes
-
-
// TODO(JS):
// If modules can have more complicated relationships (like a two modules can refer to symbols
// from each other), then we can make this work by
@@ -492,12 +488,14 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
// For now if we assume a module can only access symbols from another module, and not the reverse.
// So we just need to deserialize and we are done
SLANG_RETURN_ON_FAIL(reader.deserializeObjects());
-
+
// Get the root node. It's at index 1 (0 is the null value).
astRootNode = reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>();
- // 2) Add the extensions to the module mapTypeToCandidateExtensions cache
- // 3) We need to fix the callback pointers for parsing
+ // Go through all AST nodes:
+ // 1) Add the extensions to the module mapTypeToCandidateExtensions cache
+ // 2) We need to fix the callback pointers for parsing
+ // 3) Register all `Val`s to the ASTBuilder's deduplication map.
{
ModuleDecl* moduleDecl = as<ModuleDecl>(astRootNode);
@@ -505,6 +503,8 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
// Maps from keyword name name to index in (syntaxParseInfos)
// Will be filled in lazily if needed (for SyntaxDecl setup)
Dictionary<Name*, Index> syntaxKeywordDict;
+
+ OrderedDictionary<Val*, List<Val**>> valUses;
// Get the parse infos
const auto syntaxParseInfos = getSyntaxParseInfos();
@@ -512,21 +512,18 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
for (auto& obj : reader.getObjects())
{
+
if (obj.m_kind == SerialTypeKind::NodeBase)
{
NodeBase* nodeBase = (NodeBase*)obj.m_ptr;
SLANG_ASSERT(nodeBase);
- if (Type* type = dynamicCast<Type>(nodeBase))
- {
- type->_setASTBuilder(astBuilder);
- }
- else if (ExtensionDecl* extensionDecl = dynamicCast<ExtensionDecl>(nodeBase))
+ if (ExtensionDecl* extensionDecl = dynamicCast<ExtensionDecl>(nodeBase))
{
if (auto targetDeclRefType = as<DeclRefType>(extensionDecl->targetType))
{
// Attach our extension to that type as a candidate...
- if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>())
+ if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
{
auto aggTypeDecl = aggTypeDeclRef.getDecl();
@@ -567,6 +564,47 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
syntaxDecl->parseUserData = const_cast<ReflectClassInfo*>(syntaxDecl->syntaxClass.classInfo);
}
}
+ else if (Val* val = dynamicCast<Val>(nodeBase))
+ {
+ valUses[val] = List<Val**>();
+ }
+ }
+ }
+ // Go through fixup locations and deduplicate Vals.
+ // This is needed because we currently the same Val can be serialized multiple times
+ // in different modules. If we have a type defined in Module A and used in Module B,
+ // then both serialized Module A and Module B will contain a Type Val object that refers to A.
+ // When we load B, we should resolve those type references to the existing Type val instead.
+ // This step can be avoided if we can run deduplication while deserializing, which
+ // requires a different way of handling Val objects.
+ for (auto fixup : reader.getFixUps())
+ {
+ if (fixup.kind == PostSerializationFixUpKind::ValPtr)
+ {
+ auto list = valUses.tryGetValue(*(Val**)fixup.addressToModify);
+ if (list)
+ list->add((Val**)fixup.addressToModify);
+ }
+ }
+ SLANG_AST_BUILDER_RAII(astBuilder);
+ for (auto& valUseList : valUses)
+ {
+ auto val = valUseList.key;
+ auto desc = val->getDesc();
+ astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val);
+ }
+ for (auto& valUseList : valUses)
+ {
+ auto val = valUseList.key;
+ auto newVal = val->resolve();
+ if (val != newVal)
+ {
+ astBuilder->m_cachedNodes[val->getDesc()] = newVal;
+ for (auto use : valUseList.value)
+ {
+ if (*use != newVal)
+ *use = newVal;
+ }
}
}
}
diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h
index 971d45197..c4b20c5b9 100644
--- a/source/slang/slang-serialize-type-info.h
+++ b/source/slang/slang-serialize-type-info.h
@@ -3,6 +3,7 @@
#define SLANG_SERIALIZE_TYPE_INFO_H
#include "slang-serialize.h"
+
namespace Slang {
/* For the serialization system to work we need to defined how native types are represented in the serialized format.
@@ -87,7 +88,6 @@ struct SerialTypeInfo<float> : public SerialBasicTypeInfo<float> {};
template <>
struct SerialTypeInfo<double> : public SerialBasicTypeInfo<double> {};
-
// Fixed arrays
template <typename T, size_t N>
@@ -154,9 +154,26 @@ struct SerialTypeInfo<T, typename std::enable_if<std::is_enum<T>::value>::type>
: public SerialIdentityTypeInfo<T>
{};
+class Val;
+
// Pointer
-// Could handle different pointer base types with some more template magic here, but instead went with Pointer type to keep
-// things simpler.
+
+template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type>
+void serializePointerValue(SerialWriter* writer, T* ptrValue, SerialIndex* outSerial)
+{
+ static_assert(!IsBaseOf<Val, T>::Value);
+ *(SerialIndex*)outSerial = writer->addPointer(ptrValue);
+}
+
+template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type>
+void deserializePointerValue(SerialReader* reader, SerialIndex* inSerial, void* outPtr, T* unusedForResolution)
+{
+ static_assert(!IsBaseOf<Val, T>::Value);
+
+ SLANG_UNUSED(unusedForResolution);
+ *(T**)outPtr = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>();
+}
+
template <typename T>
struct SerialTypeInfo<T*>
{
@@ -166,11 +183,13 @@ struct SerialTypeInfo<T*>
static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial)
{
- *(SerialType*)outSerial = writer->addPointer(*(T**)inNative);
+ auto ptrToWrite = *(T**)inNative;
+ serializePointerValue(writer, ptrToWrite, (SerialIndex*)outSerial);
}
+
static void toNative(SerialReader* reader, const void* inSerial, void* outNative)
{
- *(T**)outNative = reader->getPointer(*(const SerialType*)inSerial).dynamicCast<T>();
+ deserializePointerValue(reader, (SerialIndex*)inSerial, outNative, (T*)nullptr);
}
};
@@ -257,74 +276,8 @@ struct SerialTypeInfo<String>
};
// Dictionary
-template <typename KEY, typename VALUE>
-struct SerialTypeInfo<Dictionary<KEY, VALUE>>
-{
- typedef Dictionary<KEY, VALUE> NativeType;
- struct SerialType
- {
- SerialIndex keys; ///< Index an array
- SerialIndex values; ///< Index an array
- };
-
- typedef typename SerialTypeInfo<KEY>::SerialType KeySerialType;
- typedef typename SerialTypeInfo<VALUE>::SerialType ValueSerialType;
-
- enum { SerialAlignment = SLANG_ALIGN_OF(SerialIndex) };
-
- static void toSerial(SerialWriter* writer, const void* native, void* serial)
- {
- auto& src = *(const NativeType*)native;
- auto& dst = *(SerialType*)serial;
-
- List<KeySerialType> keys;
- List<ValueSerialType> values;
-
- Index count = Index(src.getCount());
- keys.setCount(count);
- values.setCount(count);
-
- if (writer->getFlags() & SerialWriter::Flag::ZeroInitialize)
- {
- ::memset(keys.getBuffer(), 0, count * sizeof(KeySerialType));
- ::memset(values.getBuffer(), 0, count * sizeof(ValueSerialType));
- }
-
- Index i = 0;
- for (const auto& pair : src)
- {
- SerialTypeInfo<KEY>::toSerial(writer, &pair.key, &keys[i]);
- SerialTypeInfo<VALUE>::toSerial(writer, &pair.value, &values[i]);
- i++;
- }
-
- // When we add the array it is already converted to a serializable type, so add as SerialArray
- dst.keys = writer->addSerialArray<KEY>(keys.getBuffer(), count);
- dst.values = writer->addSerialArray<VALUE>(values.getBuffer(), count);
- }
- static void toNative(SerialReader* reader, const void* serial, void* native)
- {
- auto& src = *(const SerialType*)serial;
- auto& dst = *(NativeType*)native;
-
- // Clear it
- dst = NativeType();
-
- List<KEY> keys;
- List<VALUE> values;
-
- reader->getArray(src.keys, keys);
- reader->getArray(src.values, values);
-
- SLANG_ASSERT(keys.getCount() == values.getCount());
-
- const Index count = keys.getCount();
- for (Index i = 0; i < count; ++i)
- {
- dst.add(keys[i], values[i]);
- }
- }
-};
+// Note: We leave out SerialTypeInfo specialization for Dictionary, because
+// it does not have determinstic ordering.
// OrderedDictionary
template <typename KEY, typename VALUE>
diff --git a/source/slang/slang-serialize-types.cpp b/source/slang/slang-serialize-types.cpp
index 6c4512b1d..a091a2850 100644
--- a/source/slang/slang-serialize-types.cpp
+++ b/source/slang/slang-serialize-types.cpp
@@ -48,7 +48,8 @@ struct ByteReader
const int numPrefixBytes = encodeUnicodePointToUTF8(len, prefixBytes);
const Index baseIndex = stringTable.getCount();
- stringTable.setCount(baseIndex + numPrefixBytes + len);
+ auto newCount = baseIndex + numPrefixBytes + len;
+ stringTable.growToCount(newCount);
char* dst = stringTable.begin() + baseIndex;
diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp
index 2e8d6c6ba..1f5b6942d 100644
--- a/source/slang/slang-serialize.cpp
+++ b/source/slang/slang-serialize.cpp
@@ -2,6 +2,7 @@
#include "slang-serialize.h"
#include "slang-ast-base.h"
+#include "slang-ast-builder.h"
namespace Slang {
@@ -204,14 +205,14 @@ bool SerialClasses::isOk() const
SerialClasses::SerialClasses():
- m_arena(2048)
+ m_arena(2097152)
{
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!!
SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags flags)
- : m_arena(2048)
+ : m_arena(2097152)
, m_classes(classes)
, m_filter(filter)
, m_flags(flags)
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index cc617034d..ce7bfa87b 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -211,6 +211,17 @@ protected:
void* m_objects[Index(SerialExtraType::CountOf)];
};
+enum class PostSerializationFixUpKind
+{
+ ValPtr,
+};
+
+struct PostSerializationFixUp
+{
+ PostSerializationFixUpKind kind;
+ void* addressToModify;
+};
+
/* This class is the interface used by toNative implementations to recreate a type. */
class SerialReader : public RefObject
{
@@ -240,6 +251,8 @@ public:
/// Get the entries list
const List<const Entry*>& getEntries() const { return m_entries; }
+ List<PostSerializationFixUp>& getFixUps() { return m_fixUps; }
+
/// Access the objects list
/// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the RefObject* via addScope
List<SerialPointer>& getObjects() { return m_objects; }
@@ -277,6 +290,8 @@ protected:
SerialObjectFactory* m_objectFactory;
SerialClasses* m_classes; ///< Information used to deserialize
+
+ List<PostSerializationFixUp> m_fixUps;
};
// ---------------------------------------------------------------------------
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 227e468d6..ae44e0c70 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -284,14 +284,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
{
if(auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(subtypeWitness))
{
- if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.as<InheritanceDecl>())
+ if(auto inheritanceDeclRef = declaredSubtypeWitness->getDeclRef().as<InheritanceDecl>())
{
// A conformance that was declared as part of an inheritance clause
// will have built up a dictionary of the satisfying declarations
// for each of its requirements.
RequirementWitness requirementWitness;
auto witnessTable = inheritanceDeclRef.getDecl()->witnessTable;
- if(witnessTable && witnessTable->requirementDictionary.tryGetValue(requirementKey, requirementWitness))
+ if(witnessTable && witnessTable->getRequirementDictionary().tryGetValue(requirementKey, requirementWitness))
{
// The `inheritanceDeclRef` has substitutions applied to it that
// *aren't* present in the `requirementWitness`, because it was
@@ -338,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// So, in order to get the *right* end result, we need to apply
// the substitutions from the inheritance decl-ref to the witness.
//
- requirementWitness = requirementWitness.specialize(astBuilder, inheritanceDeclRef.getSubst());
+ requirementWitness = requirementWitness.specialize(astBuilder, SubstitutionSet(inheritanceDeclRef));
return requirementWitness;
}
@@ -346,17 +346,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness))
{
- if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->midToSup))
+ if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->getMidToSup()))
{
- auto midKey = declaredSubtypeWitnessMidToSup->declRef;
- auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey.getDecl());
+ auto midKey = declaredSubtypeWitnessMidToSup->getDeclRef();
+ auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->getSubToMid()), midKey.getDecl());
if (midWitness.getFlavor() == RequirementWitness::Flavor::witnessTable)
{
auto table = midWitness.getWitnessTable();
RequirementWitness result;
- if (table->requirementDictionary.tryGetValue(requirementKey, result))
+ if (table->getRequirementDictionary().tryGetValue(requirementKey, result))
{
- result = result.specialize(astBuilder, midKey.getSubst());
+ result = result.specialize(astBuilder, SubstitutionSet(midKey));
}
return result;
}
@@ -364,15 +364,32 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
else if (auto extractFromConjunctionTypeWitness = as<ExtractFromConjunctionSubtypeWitness>(subtypeWitness))
{
- if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->conjunctionWitness))
+ if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->getConjunctionWitness()))
{
auto componentWitness = as<SubtypeWitness>(
conjunctionTypeWitness->getComponentWitness(
- extractFromConjunctionTypeWitness->indexInConjunction));
+ extractFromConjunctionTypeWitness->getIndexInConjunction()));
return tryLookUpRequirementWitness(astBuilder, componentWitness, requirementKey);
}
}
+
+ // If we are looking for `ThisType`, just return subtype.
+ if (as<ThisTypeDecl>(requirementKey))
+ {
+ RequirementWitness result;
+ result.m_flavor = RequirementWitness::Flavor::val;
+ result.m_val = subtypeWitness->getSub();
+ return result;
+ }
+ // If we are looking for `ThisTypeConstraint`, just return the witness itself.
+ if (as<ThisTypeConstraintDecl>(requirementKey))
+ {
+ RequirementWitness result;
+ result.m_flavor = RequirementWitness::Flavor::val;
+ result.m_val = subtypeWitness;
+ return result;
+ }
// TODO: should handle the transitive case here too
return RequirementWitness();
@@ -384,125 +401,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
void WitnessTable::add(Decl* decl, RequirementWitness const& witness)
{
- SLANG_ASSERT(!requirementDictionary.containsKey(decl));
-
- requirementDictionary.add(decl, witness);
- }
-
- //
-
- static Type* ExtractGenericArgType(Val* val)
- {
- auto type = as<Type>(val);
- SLANG_RELEASE_ASSERT(type);
- return type;
- }
-
- static IntVal* ExtractGenericArgInteger(Val* val)
- {
- auto intVal = as<IntVal>(val);
- SLANG_RELEASE_ASSERT(intVal);
- return intVal;
- }
-
- DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- DeclRef<Decl> declRef)
- {
- // It is possible that `declRef` refers to a generic type,
- // but does not specify arguments for its generic parameters.
- // (E.g., this happens when referring to a generic type from
- // within its own member functions). To handle this case,
- // we will construct a default specialization at the use
- // site if needed.
- //
- // This same logic should also apply to declarations nested
- // more than one level inside of a generic (e.g., a `typdef`
- // inside of a generic `struct`).
- //
- // Similarly, it needs to work for multiple levels of
- // nested generics.
- //
-
- // First, we collect all the generic parents.
- ShortList<GenericDecl*> genericParents;
- Decl* dd = declRef.getDecl();
- for (;;)
- {
- Decl* childDecl = dd;
- Decl* parentDecl = dd->parentDecl;
- if (!parentDecl)
- break;
-
- dd = parentDecl;
-
- if (auto genericParentDecl = as<GenericDecl>(parentDecl))
- {
- // Don't specialize any parameters of a generic.
- if (childDecl != genericParentDecl->inner)
- break;
- genericParents.add(genericParentDecl);
- }
- }
-
-
- Substitutions* outerSubst = nullptr;
- for (Index i = genericParents.getCount()-1; i>=0; i--)
- {
- Decl* childDecl = genericParents[i]->inner;
- Decl* parentDecl = genericParents[i];
-
- if(auto genericParentDecl = as<GenericDecl>(parentDecl))
- {
- // Don't specialize any parameters of a generic.
- if(childDecl != genericParentDecl->inner)
- break;
-
- // We have a generic ancestor, but do we have an substitutions for it?
- GenericSubstitution* foundSubst = nullptr;
- for(auto s = declRef.getSubst(); s; s = s->getOuter())
- {
- auto genSubst = as<GenericSubstitution>(s);
- if(!genSubst)
- continue;
-
- if(genSubst->getGenericDecl() != genericParentDecl)
- continue;
-
- // Okay, we found a matching substitution,
- // so we just grab the args from the matching subst instead.
- foundSubst = genSubst;
- if (foundSubst->getOuter() != outerSubst)
- {
- foundSubst = astBuilder->getOrCreateGenericSubstitution(
- outerSubst, foundSubst->getGenericDecl(), foundSubst->getArgs());
- }
-
- break;
- }
-
- if(!foundSubst)
- {
- Substitutions* newSubst = createDefaultSubstitutionsForGeneric(
- astBuilder,
- semantics,
- genericParentDecl,
- outerSubst);
- outerSubst = newSubst;
- }
- else
- {
- outerSubst = foundSubst;
- }
- }
- }
-
- if(!outerSubst)
- return declRef;
-
- int diff = 0;
- return declRef.substituteImpl(astBuilder, outerSubst, &diff);
+ m_requirements.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness));
+ m_requirementDictionary.add(decl, witness);
}
// TODO: need to figure out how to unify this with the logic
@@ -511,245 +411,73 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
ASTBuilder* astBuilder,
DeclRef<Decl> declRef)
{
- declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
-
if (auto builtinMod = declRef.getDecl()->findModifier<BuiltinTypeModifier>())
{
- auto type = astBuilder->getOrCreate<BasicExpressionType>(builtinMod->tag);
- type->declRef = declRef;
+ // Always create builtin types in global AST builder.
+ if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder)
+ return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef);
+
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
+ auto type = astBuilder->getOrCreate<BasicExpressionType>(declRef.declRefBase);
return type;
}
else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>())
{
- GenericSubstitution* subst = nullptr;
- for(auto s = declRef.getSubst(); s; s = s->getOuter())
- {
- if(auto genericSubst = as<GenericSubstitution>(s))
- {
- subst = genericSubst;
- break;
- }
- }
+ // Always create builtin types in global AST builder.
+ if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder)
+ return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef);
- if (magicMod->magicName == "SamplerState")
- {
- auto type = astBuilder->getOrCreate<SamplerStateType>(SamplerStateFlavor(magicMod->tag));
- type->declRef = declRef;
- return type;
- }
- else if (magicMod->magicName == "Vector")
- {
- SLANG_ASSERT(subst && subst->getArgs().getCount() == 2);
- auto vecType = astBuilder->getOrCreate<VectorExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1]));
- vecType->declRef = declRef;
- vecType->elementType = ExtractGenericArgType(subst->getArgs()[0]);
- vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]);
- return vecType;
- }
- else if (magicMod->magicName == "ArrayType")
- {
- SLANG_ASSERT(subst && subst->getArgs().getCount() == 2);
- auto vecType = astBuilder->getOrCreate<ArrayExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1]));
- vecType->declRef = declRef;
- return vecType;
- }
- else if (magicMod->magicName == "Matrix")
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
+ auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice());
+ if (!classInfo.classInfo)
{
- SLANG_ASSERT(subst && subst->getArgs().getCount() == 3);
- auto matType = astBuilder->getOrCreate<MatrixExpressionType>(
- ExtractGenericArgType(subst->getArgs()[0]),
- ExtractGenericArgInteger(subst->getArgs()[1]),
- ExtractGenericArgInteger(subst->getArgs()[2]));
- matType->declRef = declRef;
- return matType;
+ SLANG_UNEXPECTED("unhandled type");
}
- else if (magicMod->magicName == "TensorViewType")
- {
- SLANG_ASSERT(subst && subst->getArgs().getCount() == 1);
- auto vecType = astBuilder->getOrCreate<TensorViewType>(ExtractGenericArgType(subst->getArgs()[0]));
- vecType->declRef = declRef;
- return vecType;
- }
- else if (magicMod->magicName == "Texture")
- {
- SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
- auto textureTag = TextureFlavor(magicMod->tag);
- Val* sampleCount = nullptr;
- if (textureTag.isMultisample())
+ ValNodeDesc nodeDesc = {};
+ nodeDesc.type = (ASTNodeType)classInfo.classInfo->m_classId;
+ nodeDesc.operands.add(ValNodeOperand(declRef));
+ nodeDesc.init();
+ NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc, [&]()
{
- if (subst->getArgs().getCount() >= 2)
- sampleCount = ExtractGenericArgInteger(subst->getArgs().getLast());
- }
- auto textureType = astBuilder->getOrCreate<TextureType>(
- textureTag,
- ExtractGenericArgType(subst->getArgs()[0]),
- sampleCount);
- textureType->declRef = declRef;
- return textureType;
- }
- else if (magicMod->magicName == "TextureSampler")
- {
- SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
- auto textureType = astBuilder->getOrCreate<TextureSamplerType>(
- TextureFlavor(magicMod->tag),
- ExtractGenericArgType(subst->getArgs()[0]));
- textureType->declRef = declRef;
- return textureType;
- }
- else if (magicMod->magicName == "GLSLImageType")
- {
- SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
- auto textureType = astBuilder->getOrCreate<GLSLImageType>(
- TextureFlavor(magicMod->tag),
- ExtractGenericArgType(subst->getArgs()[0]));
- textureType->declRef = declRef;
- return textureType;
- }
- else if (magicMod->magicName == "FeedbackType")
+ auto resultNode = as<DeclRefType>(classInfo.createInstance(astBuilder));
+ resultNode->setOperands(declRef);
+ return resultNode;
+ });
+ if (!type)
{
- SLANG_ASSERT(subst == nullptr);
- auto type = astBuilder->getOrCreateWithDefaultCtor<FeedbackType>(magicMod->tag);
- type->declRef = declRef;
- type->kind = FeedbackType::Kind(magicMod->tag);
- return type;
+ SLANG_UNEXPECTED("constructor failure");
}
- // TODO: eventually everything should follow this pattern,
- // and we can drive the dispatch with a table instead
- // of this ridiculously slow `if` cascade.
-
- #define CASE(n, T) \
- else if (magicMod->magicName == #n) \
- { \
- auto type = astBuilder->getOrCreateWithDefaultCtor<T>( \
- declRef.getDecl(), declRef.getSubst()); \
- type->declRef = declRef; \
- return type; \
- }
-
- CASE(HLSLInputPatchType, HLSLInputPatchType)
- CASE(HLSLOutputPatchType, HLSLOutputPatchType)
-
- #undef CASE
-
- #define CASE(n, T) \
- else if (magicMod->magicName == #n) \
- { \
- SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); \
- auto type = \
- astBuilder->getOrCreateWithDefaultCtor<T>(ExtractGenericArgType(subst->getArgs()[0])); \
- type->elementType = ExtractGenericArgType(subst->getArgs()[0]); \
- type->declRef = declRef; \
- return type; \
- }
-
- CASE(ConstantBuffer, ConstantBufferType)
- CASE(TextureBuffer, TextureBufferType)
- CASE(ParameterBlockType, ParameterBlockType)
- CASE(GLSLInputParameterGroupType, GLSLInputParameterGroupType)
- CASE(GLSLOutputParameterGroupType, GLSLOutputParameterGroupType)
- CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType)
-
- CASE(HLSLStructuredBufferType, HLSLStructuredBufferType)
- CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType)
- CASE(HLSLRasterizerOrderedStructuredBufferType, HLSLRasterizerOrderedStructuredBufferType)
- CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType)
- CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType)
-
- CASE(HLSLPointStreamType, HLSLPointStreamType)
- CASE(HLSLLineStreamType, HLSLLineStreamType)
- CASE(HLSLTriangleStreamType, HLSLTriangleStreamType)
-
- #undef CASE
-
- // "magic" builtin types which have no generic parameters
- #define CASE(n,T) \
- else if(magicMod->magicName == #n) { \
- auto type = astBuilder->getOrCreate<T>(); \
- type->declRef = declRef; \
- return type; \
- }
-
- CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType)
- CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType)
- CASE(HLSLRasterizerOrderedByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType)
- CASE(UntypedBufferResourceType, UntypedBufferResourceType)
-
- CASE(GLSLInputAttachmentType, GLSLInputAttachmentType)
-
- #undef CASE
-
- else
+ auto declRefType = dynamicCast<DeclRefType>(type);
+ if (!declRefType)
{
- auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice());
- if (!classInfo.classInfo)
- {
- SLANG_UNEXPECTED("unhandled type");
- }
-
- NodeBase* type = classInfo.createInstance(astBuilder);
- if (!type)
- {
- SLANG_UNEXPECTED("constructor failure");
- }
-
- auto declRefType = dynamicCast<DeclRefType>(type);
- if (!declRefType)
- {
- SLANG_UNEXPECTED("expected a declaration reference type");
- }
- declRefType->declRef = declRef;
- return declRefType;
+ SLANG_UNEXPECTED("expected a declaration reference type");
}
+ return declRefType;
+ }
+ else if (as<ThisTypeDecl>(declRef.getDecl()) && as<DirectDeclRef>(declRef.declRefBase))
+ {
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
+
+ return astBuilder->getOrCreate<ThisType>(declRef.declRefBase);
}
else
{
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
+
return astBuilder->getOrCreate<DeclRefType>(declRef.declRefBase);
}
}
//
- GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst)
+ Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst)
{
- for(Substitutions* s = subst; s; s = s->getOuter())
- {
- if(auto genericSubst = as<GenericSubstitution>(s))
- return genericSubst;
- }
- return nullptr;
- }
-
-
- // DeclRefBase
-
- Type* DeclRefBase::substitute(ASTBuilder* astBuilder, Type* type) const
- {
- // Note that type can be nullptr, and so this function can return nullptr (although only correctly when no substitutions)
-
- // No substitutions? Easy.
- if (!substitutions)
- return type;
-
- SLANG_ASSERT(type);
-
- // Otherwise we need to recurse on the type structure
- // and apply substitutions where it makes sense
- return Slang::as<Type>(type->substitute(astBuilder, substitutions));
- }
-
- DeclRefBase* DeclRefBase::substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const
- {
- if(!substitutions)
- return declRef;
-
- int diff = 0;
- return declRef->substituteImpl(astBuilder, substitutions, &diff);
- }
-
- SubstExpr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const
- {
- return SubstExpr<Expr>(expr, substitutions);
+ if (!subst.declRef)
+ return Val::OperandView<Val>();
+ if (auto genApp = subst.findGenericAppDeclRef())
+ return genApp->getArgs();
+ return Val::OperandView<Val>();
}
SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr)
@@ -764,7 +492,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
int diff = 0;
auto declRefBase = declRef.substituteImpl(astBuilder, substs, &diff);
- return astBuilder->getSpecializedDeclRef<Decl>(declRefBase.getDecl(), declRefBase.getSubst());
+ return declRefBase;
}
Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type)
@@ -790,332 +518,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return nullptr;
}
- Substitutions* specializeSubstitutionsShallow(
- ASTBuilder* astBuilder,
- Substitutions* substToSpecialize,
- Substitutions* substsToApply,
- Substitutions* restSubst,
- int* ioDiff)
- {
- SLANG_ASSERT(substToSpecialize);
- return substToSpecialize->applySubstitutionsShallow(astBuilder, substsToApply, restSubst, ioDiff);
- }
-
- // Construct new substitutions to apply to a declaration,
- // based on a provided substitution set to be applied
- Substitutions* specializeSubstitutions(
- ASTBuilder* astBuilder,
- Decl* declToSpecialize,
- Substitutions* substsToSpecialize,
- Substitutions* substsToApply,
- int* ioDiff)
- {
- // No declaration? Then nothing to specialize.
- if(!declToSpecialize)
- return nullptr;
-
- // No (remaining) substitutions to apply? Then we are done.
- if(!substsToApply)
- return substsToSpecialize;
-
- // Walk the hierarchy of the declaration to determine what specializations might apply.
- // We assume that the `substsToSpecialize` must be aligned with the ancestor
- // hierarchy of `declToSpecialize` such that if, e.g., the `declToSpecialize` is
- // nested directly in a generic, then `substToSpecialize` will either start with
- // the corresponding `GenericSubstitution` or there will be *no* generic substitutions
- // corresponding to that decl.
- for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->parentDecl)
- {
- if(auto ancestorGenericDecl = as<GenericDecl>(ancestorDecl))
- {
- // The declaration is nested inside a generic.
- // Does it already have a specialization for that generic?
- if(auto specGenericSubst = as<GenericSubstitution>(substsToSpecialize))
- {
- if(specGenericSubst->getGenericDecl() == ancestorGenericDecl)
- {
- // Yes. We have an existing specialization, so we will
- // keep one matching it in place.
- int diff = 0;
- auto restSubst = specializeSubstitutions(
- astBuilder,
- ancestorGenericDecl->parentDecl,
- specGenericSubst->getOuter(),
- substsToApply,
- &diff);
-
- auto firstSubst = specializeSubstitutionsShallow(
- astBuilder,
- specGenericSubst,
- substsToApply,
- restSubst,
- &diff);
-
- *ioDiff += diff;
- return firstSubst;
- }
- }
-
- // If the declaration is not already specialized
- // for the given generic, then see if we are trying
- // to *apply* such specializations to it.
- //
- // TODO: The way we handle things right now with
- // "default" specializations, this case shouldn't
- // actually come up.
- //
- for(auto s = substsToApply; s; s = s->getOuter())
- {
- auto appGenericSubst = as<GenericSubstitution>(s);
- if(!appGenericSubst)
- continue;
-
- if(appGenericSubst->getGenericDecl() != ancestorGenericDecl)
- continue;
-
- // The substitutions we are applying are trying
- // to specialize this generic, but we don't already
- // have a generic substitution in place.
- // We will need to create one.
-
- int diff = 0;
- auto restSubst = specializeSubstitutions(
- astBuilder,
- ancestorGenericDecl->parentDecl,
- substsToSpecialize,
- substsToApply,
- &diff);
-
- GenericSubstitution* firstSubst = astBuilder->getOrCreateGenericSubstitution(
- restSubst, ancestorGenericDecl, appGenericSubst->getArgs());
-
- (*ioDiff)++;
- return firstSubst;
- }
- }
- else if(auto ancestorInterfaceDecl = as<InterfaceDecl>(ancestorDecl))
- {
- // The task is basically the same as for the generic case:
- // We want to see if there is any existing substitution that
- // applies to this declaration, and use that if possible.
-
- // The declaration is nested inside a generic.
- // Does it already have a specialization for that generic?
- if(auto specThisTypeSubst = as<ThisTypeSubstitution>(substsToSpecialize))
- {
- if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl)
- {
- // Yes. We have an existing specialization, so we will
- // keep one matching it in place.
- int diff = 0;
- auto restSubst = specializeSubstitutions(
- astBuilder,
- ancestorInterfaceDecl->parentDecl,
- specThisTypeSubst->getOuter(),
- substsToApply,
- &diff);
-
- auto firstSubst = specializeSubstitutionsShallow(
- astBuilder,
- specThisTypeSubst,
- substsToApply,
- restSubst,
- &diff);
-
- *ioDiff += diff;
- return firstSubst;
- }
- }
-
- // Otherwise, check if we are trying to apply
- // a this-type substitution to the given interface
- //
- // Note: We want to skip the ThisTypeSubstitution that specializes
- // declToSpecialize itself (when declToSpecialize is an interface
- // decl and the subst specializes it), and only pull the
- // ThisTypeSubstitution when the decl is referencing a child of
- // the interface decl being specialized. This is because
- // by default an interface declref type is a "free" existential
- // type that shouldn't be specialized by someone else, unless
- // there is an "implicit" ThisType reference preceeding a child
- // reference.
- if (declToSpecialize != ancestorInterfaceDecl)
- {
- for (auto s = substsToApply; s; s = s->getOuter())
- {
- auto appThisTypeSubst = as<ThisTypeSubstitution>(s);
- if (!appThisTypeSubst)
- continue;
-
- if (appThisTypeSubst->interfaceDecl != ancestorInterfaceDecl)
- continue;
-
- int diff = 0;
- auto restSubst = specializeSubstitutions(
- astBuilder,
- ancestorInterfaceDecl->parentDecl,
- substsToSpecialize,
- substsToApply,
- &diff);
-
- ThisTypeSubstitution* firstSubst = astBuilder->getOrCreateThisTypeSubstitution(
- ancestorInterfaceDecl, appThisTypeSubst->witness, restSubst);
-
- (*ioDiff)++;
- return firstSubst;
- }
- }
- }
- }
-
- // If we reach here then we've walked the full hierarchy up from
- // `declToSpecialize` and either didn't run into an generic/interface
- // declarations, or we didn't find any attempt to specialize them
- // in either substitution.
- //
- // As an invariant, there should *not* be any generic or this-type
- // substitutions in `substToSpecialize`, because otherwise they
- // would be specializations that don't actually apply to the given
- // declaration.
- //
- // Note: this does *not* mean that `substsToApply` doesn't have
- // any generic or this-type substitutions; it just means that none
- // of them were applicable.
- //
- return nullptr;
- }
-
- DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const
- {
- // Nothing to do when we have no declaration.
- if(!decl)
- return const_cast<DeclRefBase*>(this);
-
- // Apply the given substitutions to any specializations
- // that have already been applied to this declaration.
- int diff = 0;
-
- auto substSubst = specializeSubstitutions(
- astBuilder,
- decl,
- substitutions,
- substSet.substitutions,
- &diff);
-
- if (!diff)
- return const_cast<DeclRefBase*>(this);
-
- *ioDiff += diff;
-
- DeclRefBase* substDeclRef = astBuilder->getSpecializedDeclRef(decl, substSubst);
-
- // TODO: The old code here used to try to translate a decl-ref
- // to an associated type in a decl-ref for the concrete type
- // in a particular implementation.
- //
- // I have only kept that logic in `DeclRefType::SubstituteImpl`,
- // but it may turn out it is needed here too.
-
- return substDeclRef;
- }
-
- bool DeclRefBase::_equalsValOverride(Val* val)
- {
- if (auto otherDeclRef = as<DeclRefBase>(val))
- return equals(otherDeclRef);
- return false;
- }
-
- // Check if this is an equivalent declaration reference to another
- bool DeclRefBase::equals(DeclRefBase* declRef) const
- {
- if (!declRef)
- return false;
- if (decl != declRef->decl)
- return false;
- if (!SubstitutionSet(substitutions).equals(declRef->substitutions))
- return false;
-
- return true;
- }
-
- // Convenience accessors for common properties of declarations
- Name* DeclRefBase::getName() const
- {
- return decl->nameAndLoc.name;
- }
- SourceLoc DeclRefBase::getNameLoc() const
- {
- return decl->nameAndLoc.loc;
- }
- SourceLoc DeclRefBase::getLoc() const
- {
- return decl->loc;
- }
-
- DeclRefBase* DeclRefBase::getParent(ASTBuilder* astBuilder) const
- {
- // Want access to the free function (the 'as' method by default gets priority)
- // Can access as method with this->as because it removes any ambiguity.
- using Slang::as;
-
- auto parentDecl = decl->parentDecl;
- if (!parentDecl)
- return nullptr;
-
- // Default is to apply the same set of substitutions/specializations
- // to the parent declaration as were applied to the child.
- Substitutions* substToApply = substitutions;
-
- if(auto interfaceDecl = as<InterfaceDecl>(decl))
- {
- // The declaration being referenced is an `interface` declaration,
- // and there might be a this-type substitution in place.
- // A reference to the parent of the interface declaration
- // should not include that substitution.
- if(auto thisTypeSubst = as<ThisTypeSubstitution>(substToApply))
- {
- if(thisTypeSubst->interfaceDecl == interfaceDecl)
- {
- // Strip away that specializations that apply to the interface.
- substToApply = thisTypeSubst->getOuter();
- }
- }
- }
-
- if (auto parentGenericDecl = as<GenericDecl>(parentDecl))
- {
- // The parent of this declaration is a generic, which means
- // that the decl-ref to the current declaration might include
- // substitutions that specialize the generic parameters.
- // A decl-ref to the parent generic should *not* include
- // those substitutions.
- //
- if(auto genericSubst = as<GenericSubstitution>(substToApply))
- {
- if(genericSubst->getGenericDecl() == parentGenericDecl)
- {
- // Strip away the specializations that were applied to the parent.
- substToApply = genericSubst->getOuter();
- }
- }
- }
-
- return astBuilder->getSpecializedDeclRef(parentDecl, substToApply);
- }
-
- HashCode DeclRefBase::getHashCode() const
- {
- return combineHash(PointerHash<1>::getHashCode(decl), SubstitutionSet(substitutions).getHashCode());
- }
-
// IntVal
IntegerLiteralValue getIntVal(IntVal* val)
{
if (auto constantVal = as<ConstantIntVal>(val))
{
- return constantVal->value;
+ return constantVal->getValue();
}
SLANG_UNEXPECTED("needed a known integer value");
//return 0;
@@ -1125,14 +534,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// HLSLPatchType
+ Val* getGenericArg(DeclRef<Decl> declRef, Index index)
+ {
+ auto subst = SubstitutionSet(declRef).findGenericAppDeclRef();
+ if (index < subst->getArgs().getCount())
+ return subst->getArgs()[index];
+ return nullptr;
+ }
+
Type* HLSLPatchType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(getGenericArg(getDeclRef(), 0));
}
IntVal* HLSLPatchType::getElementCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]);
+ return as<IntVal>(getGenericArg(getDeclRef(), 1));
}
// MeshOutputType
@@ -1143,12 +560,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
Type* MeshOutputType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]);
+ return as<Type>(getGenericArg(getDeclRef(), 0));
}
IntVal* MeshOutputType::getMaxElementCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]);
+ return as<IntVal>(getGenericArg(getDeclRef(), 1));
}
// Constructors for types
@@ -1174,17 +591,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
{
DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef).as<TypeDefDecl>();
- return astBuilder->create<NamedExpressionType>(specializedDeclRef);
+ return astBuilder->getOrCreate<NamedExpressionType>(specializedDeclRef);
}
FuncType* getFuncType(
ASTBuilder* astBuilder,
DeclRef<CallableDecl> const& declRef)
{
- FuncType* funcType = astBuilder->create<FuncType>();
-
- funcType->resultType = getResultType(astBuilder, declRef);
- funcType->errorType = getErrorCodeType(astBuilder, declRef);
+ List<Type*> paramTypes;
+ auto resultType = getResultType(astBuilder, declRef);
+ auto errorType = getErrorCodeType(astBuilder, declRef);
for (auto paramDeclRef : getParameters(astBuilder, declRef))
{
auto paramDecl = paramDeclRef.getDecl();
@@ -1204,9 +620,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
paramType = astBuilder->getOutType(paramType);
}
}
- funcType->paramTypes.add(paramType);
+ paramTypes.add(paramType);
}
+ FuncType* funcType = astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
return funcType;
}
@@ -1214,40 +631,34 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
ASTBuilder* astBuilder,
DeclRef<GenericDecl> const& declRef)
{
- return astBuilder->create<GenericDeclRefType>(declRef);
+ return astBuilder->getOrCreate<GenericDeclRefType>(declRef);
}
NamespaceType* getNamespaceType(
ASTBuilder* astBuilder,
DeclRef<NamespaceDeclBase> const& declRef)
{
- auto type = astBuilder->create<NamespaceType>();
- type->declRef = declRef;
+ auto type = astBuilder->getOrCreate<NamespaceType>(declRef);
return type;
}
SamplerStateType* getSamplerStateType(
ASTBuilder* astBuilder)
{
- return astBuilder->create<SamplerStateType>();
+ return astBuilder->getSamplerStateType();
}
- ThisTypeSubstitution* findThisTypeSubstitution(
- const Substitutions* substs,
+ SubtypeWitness* findThisTypeWitness(
+ SubstitutionSet substs,
InterfaceDecl* interfaceDecl)
{
- for(const Substitutions* s = substs; s; s = s->getOuter())
+ auto lookupDeclRef = substs.findLookupDeclRef();
+ if (!lookupDeclRef)
+ return nullptr;
+ if (lookupDeclRef->getSupDecl() == interfaceDecl)
{
- auto thisTypeSubst = as<ThisTypeSubstitution>(s);
- if(!thisTypeSubst)
- continue;
-
- if(thisTypeSubst->interfaceDecl != interfaceDecl)
- continue;
-
- return const_cast<ThisTypeSubstitution*>(thisTypeSubst);
+ return lookupDeclRef->getWitness();
}
-
return nullptr;
}
@@ -1259,20 +670,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
auto substAssocTypeDecl = substDeclRef.getDecl();
- for (auto s = substDeclRef.getSubst(); s; s = s->getOuter())
+ if (auto lookupDeclRef = SubstitutionSet(substDeclRef).findLookupDeclRef())
{
- auto thisSubst = as<ThisTypeSubstitution>(s);
- if (!thisSubst)
- continue;
-
if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl))
{
- if (thisSubst->interfaceDecl == interfaceDecl)
+ if (lookupDeclRef->getSupDecl() == interfaceDecl)
{
// We need to look up the declaration that satisfies
// the requirement named by the associated type.
Decl* requirementKey = substAssocTypeDecl;
- RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, thisSubst->witness, requirementKey);
+ RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, lookupDeclRef->getWitness(), requirementKey);
switch (requirementWitness.getFlavor())
{
default:
@@ -1296,17 +703,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
if (builtinReq->kind != BuiltinRequirementKind::DifferentialType)
return nullptr;
// Is the concrete type a Differential associated type?
- auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub);
+ auto innerDeclRefType = as<DeclRefType>(lookupDeclRef->getWitness()->getSub());
if (!innerDeclRefType)
return nullptr;
- auto innerBuiltinReq = innerDeclRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>();
+ auto innerBuiltinReq = innerDeclRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>();
if (!innerBuiltinReq)
return nullptr;
if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType)
return nullptr;
- if (!innerDeclRefType->declRef.equals(declRef))
+ if (!innerDeclRefType->getDeclRef().equals(declRef))
{
- auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->declRef);
+ auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->getDeclRef());
if (result)
return result;
}
@@ -1320,119 +727,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return nullptr;
}
- String DeclRefBase::toString() const
- {
- StringBuilder builder;
- toText(builder);
- return std::move(builder);
- }
-
- // Prints a partially qualified type name with generic substitutions.
- void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out)
- {
- // If there is a parent scope for the declaration, print it first.
- // Exclude top-level namespaces like `tu0` or `core`.
- if (decl->parentDecl && !Slang::as<ModuleDecl>(decl->parentDecl))
- {
- auto parentGeneric = Slang::as<GenericDecl>(decl->parentDecl);
-
- // Exclude function or operator names.
- // Avoids excessively verbose messages like `func<T>(func::T)`
- if (!(parentGeneric && Slang::as<CallableDecl>(parentGeneric->inner)))
- {
- _printNestedDecl(substitutions, decl->parentDecl, out);
-
- // If the parent is a generic for this type, skip *this* type.
- // Avoids duplicate types like `MyType<T>::MyType`
- if (parentGeneric && parentGeneric->inner == decl)
- return;
-
- out << ".";
- }
- }
- // If we have a ThisTypeSubstitution to an interface decl, print the substituted sub
- // type instead.
- for (;;)
- {
- if (auto interfaceDecl = const_cast<InterfaceDecl*>(as<InterfaceDecl>(decl)))
- {
- if (auto thisSubst = findThisTypeSubstitution(substitutions, interfaceDecl))
- {
- if (auto subTypeWitness = as<SubtypeWitness>(thisSubst->witness))
- {
- out << subTypeWitness->sub;
- break;
- }
- }
- }
- // Otherwise, just print this type's name.
- auto name = decl->getName();
- if (name)
- {
- out << name->text;
- }
- break;
- }
-
- // Look for generic substitutions on this type.
- for (const Substitutions* subst = substitutions; subst; subst = subst->getOuter())
- {
- auto genericSubstitution = Slang::as<GenericSubstitution>(subst);
- if (!genericSubstitution)
- continue;
-
- // If the substitution is for this type, print it.
- if (genericSubstitution->getGenericDecl() == decl)
- {
- out << "<";
- bool isFirst = true;
- for (const auto& it : genericSubstitution->getArgs())
- {
- // Don't print out witnesses.
- if (as<Witness>(it))
- continue;
- if (!isFirst)
- out << ", ";
- isFirst = false;
- it->toText(out);
- }
- out << ">";
-
- break;
- }
- }
- }
-
- void DeclRefBase::toText(StringBuilder& out) const
- {
- if (decl)
- {
- _printNestedDecl(substitutions, decl, out);
- }
- }
-
- bool SubstitutionSet::equals(const SubstitutionSet& substSet) const
- {
- if (substitutions == substSet.substitutions)
- {
- return true;
- }
- if (substitutions == nullptr || substSet.substitutions == nullptr)
- {
- return false;
- }
- return substitutions->equals(substSet.substitutions);
- }
-
- HashCode SubstitutionSet::getHashCode() const
- {
- HashCode rs = 0;
- if (substitutions)
- rs = combineHash(rs, substitutions->getHashCode());
- return rs;
- }
-
-
ModuleDecl* getModuleDecl(Decl* decl)
{
for( auto dd = decl; dd; dd = dd->parentDecl )
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index a63a2471c..4addb1d53 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -22,22 +22,22 @@ namespace Slang
inline bool areValsEqual(Val* left, Val* right)
{
if(!left || !right) return left == right;
- return left->equalsVal(right);
+ return left->equals(right);
}
//
inline BaseType getVectorBaseType(VectorExpressionType* vecType)
{
- auto basicExprType = as<BasicExpressionType>(vecType->elementType);
- return basicExprType->baseType;
+ auto basicExprType = as<BasicExpressionType>(vecType->getElementType());
+ return basicExprType->getBaseType();
}
inline int getVectorSize(VectorExpressionType* vecType)
{
- auto constantVal = as<ConstantIntVal>(vecType->elementCount);
+ auto constantVal = as<ConstantIntVal>(vecType->getElementCount());
if (constantVal)
- return (int) constantVal->value;
+ return (int) constantVal->getValue();
// TODO: what to do in this case?
return 0;
}
@@ -52,15 +52,21 @@ namespace Slang
DeclRef<AggTypeDecl> const& declRef,
SemanticsVisitor* semantics);
+ // Returns the members of `genericInnerDecl`'s enclosing generic decl.
+ inline FilteredMemberRefList<Decl> getGenericMembers(ASTBuilder* astBuilder, DeclRef<Decl> genericInnerDecl, MemberFilterStyle filterStyle = MemberFilterStyle::All)
+ {
+ return FilteredMemberRefList<Decl>(astBuilder, genericInnerDecl.getParent().getDecl()->members, genericInnerDecl, filterStyle);
+ }
+
inline FilteredMemberRefList<Decl> getMembers(ASTBuilder* astBuilder, DeclRef<ContainerDecl> declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All)
{
- return FilteredMemberRefList<Decl>(astBuilder, declRef.getDecl()->members, declRef.getSubst(), filterStyle);
+ return FilteredMemberRefList<Decl>(astBuilder, declRef.getDecl()->members, declRef, filterStyle);
}
template<typename T>
inline FilteredMemberRefList<T> getMembersOfType(ASTBuilder* astBuilder, DeclRef<ContainerDecl> declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All)
{
- return FilteredMemberRefList<T>(astBuilder, declRef.getDecl()->members, declRef.getSubst(), filterStyle);
+ return FilteredMemberRefList<T>(astBuilder, declRef.getDecl()->members, declRef, filterStyle);
}
void _foreachDirectOrExtensionMemberOfType(
@@ -70,7 +76,7 @@ namespace Slang
void (*callback)(DeclRefBase*, void*),
void const* userData);
- DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst);
+ DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl);
template<typename T, typename F>
inline void foreachDirectOrExtensionMemberOfType(
@@ -153,6 +159,26 @@ namespace Slang
/// If the given `structTypeDeclRef` inherits from another struct type, return that base struct decl
DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> structTypeDeclRef);
+ SubtypeWitness* findThisTypeWitness(
+ SubstitutionSet substs,
+ InterfaceDecl* interfaceDecl);
+
+ RequirementWitness tryLookUpRequirementWitness(
+ ASTBuilder* astBuilder,
+ SubtypeWitness* subtypeWitness,
+ Decl* requirementKey);
+
+ DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
+ DeclRef<Decl> declRef);
+
+ List<Val*> getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl);
+
+ Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst);
+
+ ParameterDirection getParameterDirection(VarDeclBase* varDecl);
+
inline Type* getTagType(ASTBuilder* astBuilder, DeclRef<EnumDecl> declRef)
{
return declRef.substitute(astBuilder, declRef.getDecl()->tagType);
@@ -192,8 +218,6 @@ namespace Slang
inline Decl* getInner(DeclRef<GenericDecl> declRef)
{
- // TODO: Should really return a `DeclRef<Decl>` for the inner
- // declaration, and not just a raw pointer
return declRef.getDecl()->inner;
}
@@ -288,46 +312,6 @@ namespace Slang
//
- ThisTypeSubstitution* findThisTypeSubstitution(
- const Substitutions* substs,
- InterfaceDecl* interfaceDecl);
-
- RequirementWitness tryLookUpRequirementWitness(
- ASTBuilder* astBuilder,
- SubtypeWitness* subtypeWitness,
- Decl* requirementKey);
-
- // TODO: where should this live?
- SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- Decl* decl,
- SubstitutionSet parentSubst);
-
- SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- Decl* decl);
-
- DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- DeclRef<Decl> declRef);
-
- GenericSubstitution* createDefaultSubstitutionsForGeneric(
- ASTBuilder* astBuilder,
- SemanticsVisitor* semantics,
- GenericDecl* genericDecl,
- Substitutions* outerSubst);
-
- GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst);
-
- ThisTypeSubstitution* findThisTypeSubstitution(
- const Substitutions* substs,
- InterfaceDecl* interfaceDecl);
-
- ParameterDirection getParameterDirection(VarDeclBase* varDecl);
-
enum class UserDefinedAttributeTargets
{
None = 0,
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 5cf8d2350..f5e14366d 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -1571,9 +1571,9 @@ static LayoutSize GetElementCount(IntVal* val)
if (auto constantVal = as<ConstantIntVal>(val))
{
- if (constantVal->value == kUnsizedArrayMagicLength)
+ if (constantVal->getValue() == kUnsizedArrayMagicLength)
return LayoutSize::infinite();
- return LayoutSize(LayoutSize::RawValue(constantVal->value));
+ return LayoutSize(LayoutSize::RawValue(constantVal->getValue()));
}
else if(const auto varRefVal = as<GenericParamIntVal>(val))
{
@@ -2766,7 +2766,7 @@ RefPtr<TypeLayout> createParameterGroupTypeLayout(
parameterGroupRules,
context.targetReq);
- auto elementType = parameterGroupType->elementType;
+ auto elementType = parameterGroupType->getElementType();
return _createParameterGroupTypeLayout(
context,
@@ -3642,24 +3642,6 @@ static void _addLayout(TypeLayoutContext const& context,
static TypeLayoutResult _updateLayout(TypeLayoutContext const& context,
Type* type,
- TypeLayout* layout,
- const SimpleLayoutInfo& info)
-{
- auto layoutResultPtr = context.layoutMap.tryGetValue(type);
- SLANG_ASSERT(layoutResultPtr);
- if (layoutResultPtr)
- {
- // Check the layout is the same!
- SLANG_ASSERT(layoutResultPtr->layout.get() == layout);
- // Update the info
- layoutResultPtr->info = info;
- }
-
- return TypeLayoutResult(layout, info);
-}
-
-static TypeLayoutResult _updateLayout(TypeLayoutContext const& context,
- Type* type,
const TypeLayoutResult& result)
{
auto layoutResultPtr = context.layoutMap.tryGetValue(type);
@@ -3791,7 +3773,7 @@ static TypeLayoutResult _createTypeLayout(
context, \
ShaderParameterKind::KIND, \
type_##TYPE, \
- type_##TYPE->elementType); \
+ type_##TYPE->getElementType()); \
return TypeLayoutResult(typeLayout, info); \
} while(0)
@@ -3826,14 +3808,14 @@ static TypeLayoutResult _createTypeLayout(
else if(auto basicType = as<BasicExpressionType>(type))
{
return createSimpleTypeLayout(
- rules->GetScalarLayout(basicType->baseType),
+ rules->GetScalarLayout(basicType->getBaseType()),
type,
rules);
}
else if(auto vecType = as<VectorExpressionType>(type))
{
- auto elementType = vecType->elementType;
- size_t elementCount = (size_t) getIntVal(vecType->elementCount);
+ auto elementType = vecType->getElementType();
+ size_t elementCount = (size_t) getIntVal(vecType->getElementCount());
auto element = _createTypeLayout(
context,
@@ -3842,7 +3824,7 @@ static TypeLayoutResult _createTypeLayout(
BaseType elementBaseType = BaseType::Void;
if (auto elementBasicType = as<BasicExpressionType>(elementType))
{
- elementBaseType = elementBasicType->baseType;
+ elementBaseType = elementBasicType->getBaseType();
}
auto info = rules->GetVectorLayout(elementBaseType, element.info, elementCount);
@@ -3874,7 +3856,7 @@ static TypeLayoutResult _createTypeLayout(
BaseType elementBaseType = BaseType::Void;
if (auto elementBasicType = as<BasicExpressionType>(elementType))
{
- elementBaseType = elementBasicType->baseType;
+ elementBaseType = elementBasicType->getBaseType();
}
// The `GetMatrixLayout` implementation in the layout rules
@@ -3972,7 +3954,7 @@ static TypeLayoutResult _createTypeLayout(
}
else if (auto declRefType = as<DeclRefType>(type))
{
- auto declRef = declRefType->declRef;
+ auto declRef = declRefType->getDeclRef();
if (auto structDeclRef = declRef.as<StructDecl>())
{
@@ -4346,99 +4328,20 @@ static TypeLayoutResult _createTypeLayout(
errorType,
rules);
}
- else if( auto taggedUnionType = as<TaggedUnionType>(type) )
+ else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) )
{
- // A tagged union type needs to be laid out as the maximum
- // size of any constituent type.
- //
- // In practice, only a tagged union of uniform data will
- // work, but for now we will compute the maximum usage
- // for each resource kind for generality.
- //
- // For the uniform data we will start with a size
- // of zero and an alignment of one for our base case
- // (this is what a tagged union of no cases would consume).
- //
- UniformLayoutInfo info(0, 1);
-
- RefPtr<TaggedUnionTypeLayout> taggedUnionLayout = new TaggedUnionTypeLayout();
-
- _addLayout(context, type, taggedUnionLayout);
-
- taggedUnionLayout->type = type;
- taggedUnionLayout->rules = rules;
-
- // Now we iterate over the case types and see if they
- // change our computed maximum size/alignement.
- //
- for( auto caseType : taggedUnionType->caseTypes )
+ ExpandedSpecializationArgs args;
+ for (Index i = 0; i < existentialSpecializedType->getArgCount(); ++i)
{
- // Note: A tagged union type is not expected to have any existential/interface type
- // slots; the case types that are provided must be fully specialized before the union is
- // formed. Thus we don't need to mess around with existential type slots here the
- // way we do for the `struct` case.
-
- auto caseTypeResult = _createTypeLayout(context, caseType);
- RefPtr<TypeLayout> caseTypeLayout = caseTypeResult.layout;
- UniformLayoutInfo caseTypeInfo = caseTypeResult.info.getUniformLayout();
-
- info.size = maximum(info.size, caseTypeInfo.size);
- info.alignment = std::max(info.alignment, caseTypeInfo.alignment);
-
- // We need to remember the layout of the case type
- // on the final `TaggedUnionTypeLayout`.
- //
- taggedUnionLayout->caseTypeLayouts.add(caseTypeLayout);
-
- // We also need to consider contributions for other
- // resource kinds beyond uniform data.
- //
- for( auto caseResInfo : caseTypeLayout->resourceInfos )
- {
- auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind);
- unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count);
- }
- }
-
- // After we've computed the size required to hold all the
- // case types, we will allocate space for the tag field.
- //
- // TODO: This assumes the tag will always be allocated out
- // of uniform storage, which means we can't support a tagged
- // union as part of a varying input/output signature. That is
- // probably a valid limitation, but it should get enforced
- // somewhere along the way.
- //
- {
- // The tag is always a `uint` for now.
- //
- auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt);
- info.size = _roundToAlignment(info.size, tagInfo.alignment);
-
- taggedUnionLayout->tagOffset = info.size;
-
- info.size += tagInfo.size;
- info.alignment = std::max(info.alignment, tagInfo.alignment);
+ args.add(existentialSpecializedType->getArg(i));
}
-
- // As a final step, if we are computing a full `TypeLayout`
- // we will make sure that its information on uniform layout
- // matches what we've computed in the `UniformLayoutInfo` we return.
- //
- taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size;
- taggedUnionLayout->uniformAlignment = info.alignment;
-
- return _updateLayout(context, type, taggedUnionLayout, info);
- }
- else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) )
- {
TypeLayoutContext subContext = context.withSpecializationArgs(
- existentialSpecializedType->args.getBuffer(),
- existentialSpecializedType->args.getCount());
+ args.getBuffer(),
+ args.getCount());
auto baseTypeLayoutResult = _createTypeLayout(
subContext,
- existentialSpecializedType->baseType);
+ existentialSpecializedType->getBaseType());
UniformLayoutInfo info = rules->BeginStructLayout();
rules->AddStructField(&info, baseTypeLayoutResult.info.getUniformLayout());
@@ -4534,7 +4437,7 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout(
if(auto basicType = as<BasicExpressionType>(type))
{
- auto baseType = basicType->baseType;
+ auto baseType = basicType->getBaseType();
RefPtr<TypeLayout> typeLayout = new TypeLayout();
typeLayout->type = type;
@@ -4550,13 +4453,13 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout(
}
else if(auto vecType = as<VectorExpressionType>(type))
{
- auto elementType = vecType->elementType;
- size_t elementCount = (size_t) getIntVal(vecType->elementCount);
+ auto elementType = vecType->getElementType();
+ size_t elementCount = (size_t) getIntVal(vecType->getElementCount());
BaseType elementBaseType = BaseType::Void;
if( auto elementBasicType = as<BasicExpressionType>(elementType) )
{
- elementBaseType = elementBasicType->baseType;
+ elementBaseType = elementBasicType->getBaseType();
}
// Note that we do *not* add any resource usage to the type
@@ -4592,7 +4495,7 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout(
BaseType elementBaseType = BaseType::Void;
if( auto elementBasicType = as<BasicExpressionType>(elementType) )
{
- elementBaseType = elementBasicType->baseType;
+ elementBaseType = elementBasicType->getBaseType();
}
// Just as for `_createTypeLayout`, we need to handle row- and
@@ -4711,7 +4614,7 @@ GlobalGenericParamDecl* GenericParamTypeLayout::getGlobalGenericParamDecl()
{
auto declRefType = as<DeclRefType>(type);
SLANG_ASSERT(declRefType);
- auto rsDeclRef = declRefType->declRef.as<GlobalGenericParamDecl>();
+ auto rsDeclRef = declRefType->getDeclRef().as<GlobalGenericParamDecl>();
return rsDeclRef.getDecl();
}
diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h
index 7b822eac4..c800d0931 100644
--- a/source/slang/slang-type-layout.h
+++ b/source/slang/slang-type-layout.h
@@ -725,28 +725,6 @@ public:
Index paramIndex = 0;
};
- /// Layout information for a tagged union type.
-class TaggedUnionTypeLayout : public TypeLayout
-{
-public:
- /// The layouts of each of the case types.
- ///
- /// The order of entries in this array matches
- /// the order of case types on the original
- /// `TaggedUnionType`, and the index of a case
- /// type is also the tag value for that case.
- ///
- List<RefPtr<TypeLayout>> caseTypeLayouts;
-
- /// The byte offset for the tag field.
- ///
- /// The tag field will always be allocated as
- /// a `uint`, so we don't store a separate layout
- /// for it.
- ///
- LayoutSize tagOffset;
-};
-
/// Layout information for an interface/existential type
///
/// This class is used to represent the layout of an interface type
@@ -912,13 +890,6 @@ public:
///
Dictionary<GlobalGenericParamDecl*, Val*> globalGenericArgs;
- /// Layouts for all tagged union types required by this program
- ///
- /// These are any tagged union types used by the specialization
- /// arguments that have been used to specialize the program.
- ///
- List<RefPtr<TypeLayout>> taggedUnionTypeLayouts;
-
/// Holds all of the string literals that have been hashed
StringSlicePool hashedStringLiteralPool;
};
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index e08bb2a62..266533874 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -161,11 +161,9 @@ void Session::init()
m_sharedASTBuilder = new SharedASTBuilder;
m_sharedASTBuilder->init(this);
- // Use to create a ASTBuilder
- RefPtr<ASTBuilder> builtinAstBuilder(new ASTBuilder(m_sharedASTBuilder, "m_builtInLinkage::m_astBuilder"));
-
// And the global ASTBuilder
- globalAstBuilder = new ASTBuilder(m_sharedASTBuilder, "globalAstBuilder");
+ auto builtinAstBuilder = m_sharedASTBuilder->getInnerASTBuilder();
+ globalAstBuilder = builtinAstBuilder;
// Make sure our source manager is initialized
builtinSourceManager.initialize(nullptr, nullptr);
@@ -367,6 +365,8 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes)
return SLANG_FAIL;
}
+ SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder());
+
// Make a file system to read it from
ComPtr<ISlangFileSystemExt> fileSystem;
SLANG_RETURN_ON_FAIL(loadArchiveFileSystem(stdLib, stdLibSizeInBytes, fileSystem));
@@ -397,6 +397,8 @@ SlangResult Session::saveStdLib(SlangArchiveType archiveType, ISlangBlob** outBl
return SLANG_FAIL;
}
+ SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder());
+
for (auto& pair : m_builtinLinkage->mapNameToLoadedModules)
{
const Name* moduleName = pair.key;
@@ -463,6 +465,7 @@ SlangResult Session::_readBuiltinModule(ISlangFileSystem* fileSystem, Scope* sco
options.namePool = linkageNamePool;
options.session = this;
options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder();
+ options.astBuilder = linkage->getASTBuilder();
options.sourceManager = sourceManger;
options.linkage = linkage;
@@ -920,6 +923,9 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka
, m_sourceManager(&m_defaultSourceManager)
, m_astBuilder(astBuilder)
{
+ if (builtinLinkage)
+ m_astBuilder->m_cachedNodes = builtinLinkage->getASTBuilder()->m_cachedNodes;
+
getNamePool()->setRootNamePool(session->getRootNamePool());
m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr);
@@ -990,6 +996,8 @@ SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL Linkage::getGlobalSession()
void Linkage::addTarget(
slang::TargetDesc const& desc)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto targetIndex = addTarget(CodeGenTarget(desc.format));
auto target = targets[targetIndex];
@@ -1018,6 +1026,8 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule(
const char* moduleName,
slang::IBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer);
if (isInLanguageServer())
@@ -1048,6 +1058,8 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource(
slang::IBlob* source,
slang::IBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer);
if (isInLanguageServer())
{
@@ -1096,6 +1108,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType(
slang::IComponentType** outCompositeComponentType,
ISlangBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
// Attempting to create a "composite" of just one component type should
// just return the component type itself, to avoid redundant work.
//
@@ -1131,6 +1145,8 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
SlangInt specializationArgCount,
ISlangBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto unspecializedType = asInternal(inUnspecializedType);
List<Type*> typeArgs;
@@ -1157,6 +1173,8 @@ SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout(
slang::LayoutRules rules,
ISlangBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto type = asInternal(inType);
if(targetIndex < 0 || targetIndex >= targets.getCount())
@@ -1187,6 +1205,8 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType(
slang::ContainerType containerType,
ISlangBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto type = asInternal(inType);
Type* containerTypeReflection = nullptr;
@@ -1197,29 +1217,20 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType(
{
case slang::ContainerType::ConstantBuffer:
{
- ConstantBufferType* cbType = getASTBuilder()->create<ConstantBufferType>();
- cbType->elementType = type;
- cbType->declRef = getASTBuilder()->getBuiltinDeclRef(
- "ConstantBuffer", static_cast<Val*>(type));
+ ConstantBufferType* cbType = getASTBuilder()->getConstantBufferType(type);
containerTypeReflection = cbType;
}
break;
case slang::ContainerType::ParameterBlock:
{
- ParameterBlockType* pbType = getASTBuilder()->create<ParameterBlockType>();
- pbType->elementType = type;
- pbType->declRef = getASTBuilder()->getBuiltinDeclRef(
- "ParameterBlock", static_cast<Val*>(type));
+ ParameterBlockType* pbType = getASTBuilder()->getParameterBlockType(type);
containerTypeReflection = pbType;
}
break;
case slang::ContainerType::StructuredBuffer:
{
HLSLStructuredBufferType* sbType =
- getASTBuilder()->create<HLSLStructuredBufferType>();
- sbType->elementType = type;
- sbType->declRef = getASTBuilder()->getBuiltinDeclRef(
- "HLSLStructuredBufferType", static_cast<Val*>(type));
+ getASTBuilder()->getStructuredBufferType(type);
containerTypeReflection = sbType;
}
break;
@@ -1244,16 +1255,20 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType(
SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getDynamicType()
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
return asExternal(getASTBuilder()->getSharedASTBuilder()->getDynamicType());
}
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeRTTIMangledName(
slang::TypeReflection* type, ISlangBlob** outNameBlob)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto internalType = asInternal(type);
if (auto declRefType = as<DeclRefType>(internalType))
{
- auto name = getMangledName(internalType->getASTBuilder(), declRefType->declRef);
+ auto name = getMangledName(m_astBuilder, declRefType->getDeclRef());
Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name);
*outNameBlob = blob.detach();
return SLANG_OK;
@@ -1264,9 +1279,11 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeRTTIMangledName(
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessMangledName(
slang::TypeReflection* type, slang::TypeReflection* interfaceType, ISlangBlob** outNameBlob)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto subType = asInternal(type);
auto supType = asInternal(interfaceType);
- auto name = getMangledNameForConformanceWitness(subType->getASTBuilder(), subType, supType);
+ auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType);
Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name);
*outNameBlob = blob.detach();
return SLANG_OK;
@@ -1277,14 +1294,16 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent
slang::TypeReflection* interfaceType,
uint32_t* outId)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
auto subType = asInternal(type);
auto supType = asInternal(interfaceType);
if (!subType || !supType)
return SLANG_FAIL;
- auto name = getMangledNameForConformanceWitness(subType->getASTBuilder(), subType, supType);
- auto interfaceName = getMangledTypeName(supType->getASTBuilder(), supType);
+ auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType);
+ auto interfaceName = getMangledTypeName(m_astBuilder, supType);
uint32_t resultIndex = 0;
if (mapMangledNameToRTTIObjectIndex.tryGetValue(name, resultIndex))
{
@@ -1313,6 +1332,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy
SlangInt conformanceIdOverride,
ISlangBlob** outDiagnostics)
{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
RefPtr<TypeConformance> result;
DiagnosticSink sink;
try
@@ -1550,6 +1571,8 @@ CapabilitySet TargetRequest::getTargetCaps()
TypeLayout* TargetRequest::getTypeLayout(Type* type)
{
+ SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder());
+
// TODO: We are not passing in a `ProgramLayout` here, although one
// is nominally required to establish the global ordering of
// generic type parameters, which might be referenced from field types.
@@ -1866,6 +1889,9 @@ Type* ComponentType::getTypeFromString(
Scope* scope = _createScopeForLegacyLookup(astBuilder);
auto linkage = getLinkage();
+
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+
Expr* typeExpr = linkage->parseTermString(
typeStr, scope);
type = checkProperType(linkage, TypeExp(typeExpr), sink);
@@ -2172,6 +2198,8 @@ void FrontEndCompileRequest::parseTranslationUnit(
{
auto linkage = getLinkage();
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+
// TODO(JS): NOTE! Here we are using the searchDirectories on the linkage. This is because
// currently the API only allows the setting search paths on linkage.
//
@@ -2376,6 +2404,8 @@ void FrontEndCompileRequest::checkAllTranslationUnits()
void FrontEndCompileRequest::generateIR()
{
+ SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder());
+
// Our task in this function is to generate IR code
// for all of the declarations in the translation
// units that were loaded.
@@ -2469,6 +2499,8 @@ static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request)
SlangResult FrontEndCompileRequest::executeActionsInner()
{
+ SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder());
+
// We currently allow GlSL files on the command line so that we can
// drive our "pass-through" mode, but we really want to issue an error
// message if the user is seriously asking us to compile them.
@@ -3272,7 +3304,7 @@ Module::Module(Linkage* linkage, ASTBuilder* astBuilder)
}
else
{
- m_astBuilder = new ASTBuilder(linkage->getASTBuilder()->getSharedASTBuilder(), "Module");
+ m_astBuilder = linkage->getASTBuilder();
}
addModuleDependency(this);
@@ -4091,36 +4123,28 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
maybeAddModule(module);
}
- void collectReferencedModules(Substitutions* substitution)
+ void collectReferencedModules(SubstitutionSet substitutions)
{
- if(auto genericSubst = as<GenericSubstitution>(substitution))
+ substitutions.forEachGenericSubstitution([this](GenericDecl*, Val::OperandView<Val> args)
{
- for(auto arg : genericSubst->getArgs())
+ for (auto arg : args)
{
collectReferencedModules(arg);
}
- }
- }
-
- void collectReferencedModules(SubstitutionSet const& substitutions)
- {
- for(auto subst = substitutions.substitutions; subst; subst = subst->getOuter())
- {
- collectReferencedModules(subst);
- }
+ });
}
- void collectReferencedModules(DeclRefBase const& declRef)
+ void collectReferencedModules(DeclRefBase* declRef)
{
- collectReferencedModules(declRef.getDecl());
- collectReferencedModules(declRef.getSubst());
+ collectReferencedModules(declRef->getDecl());
+ collectReferencedModules(SubstitutionSet(declRef));
}
void collectReferencedModules(Type* type)
{
if(auto declRefType = as<DeclRefType>(type))
{
- collectReferencedModules(declRefType->declRef);
+ collectReferencedModules(declRefType->getDeclRef());
}
// TODO: Handle non-decl-ref composite type cases
@@ -4135,7 +4159,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
}
else if (auto declRefVal = as<GenericParamIntVal>(val))
{
- collectReferencedModules(declRefVal->declRef);
+ collectReferencedModules(declRefVal->getDeclRef());
}
// TODO: other cases of values that could reference
@@ -4350,41 +4374,6 @@ SpecializedComponentType::SpecializedComponentType(
m_moduleDependencies.add(module);
}
- // The following is a bit of a hack.
- //
- // TODO: We should not need this hack any longer, since the
- // new approach to `switch`-based dynamic dispatch has made
- // the existing tagged-union support obsolete.
- //
- // Back-end code generation relies on us having computed layouts for all tagged
- // unions that end up being used in the code, which means we need a way to find
- // all such types that get used in a program (and the stuff it imports).
- //
- // For now we are assuming a tagged union type only comes into existence
- // as a (top-level) argument for a generic type parameter, so that we
- // can check for them here and cache them on the entry point.
- //
- // A longer-term strategy might need to consider any (tagged or untagged)
- // union types that get used inside of a module, and also take
- // those lists into account.
- //
- // An even longer-term strategy would be to allow type layout to
- // be performed on IR types, so taht we don't need to have front-end
- // code worrying about this stuff.
- //
- for(auto arg : specializationArgs)
- {
- auto argType = as<Type>(arg.val);
- if(!argType)
- continue;
-
- auto taggedUnionType = as<TaggedUnionType>(argType);
- if(!taggedUnionType)
- continue;
-
- m_taggedUnionTypes.add(taggedUnionType);
- }
-
// Because we are specializing shader code, the mangled entry
// point names for this component type may be different than
// for the base component type (e.g., the mangled name for `f<int>`
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
index 74b625183..912a8f2a7 100644
--- a/source/slang/slang.natvis
+++ b/source/slang/slang.natvis
@@ -11,48 +11,66 @@
</Type>
<Type Name="Slang::DeclRef&lt;*&gt;">
<DisplayString Condition="declRefBase == 0">DeclRef nullptr</DisplayString>
+
<DisplayString Condition="declRefBase != 0">{*declRefBase}</DisplayString>
<Expand>
- <ExpandedItem>declRefBase ? ($T1*)(declRefBase->decl) : ($T1*)0</ExpandedItem>
- <Synthetic Name="[Substitutions]">
- <Expand>
- <LinkedListItems>
- <HeadPointer>declRefBase->substitutions</HeadPointer>
- <NextPointer>outer</NextPointer>
- <ValueNode>this</ValueNode>
- </LinkedListItems>
- </Expand>
- </Synthetic>
+ <ExpandedItem>declRefBase</ExpandedItem>
</Expand>
</Type>
<Type Name="Slang::DeclRefBase">
- <DisplayString Condition="decl != 0 &amp;&amp; substitutions != 0">{*decl}{*substitutions}</DisplayString>
- <DisplayString Condition="decl != 0">{*decl}</DisplayString>
- <DisplayString Condition="decl == 0">DeclRefBase nullptr</DisplayString>
- <Expand>
- <ExpandedItem>decl</ExpandedItem>
- <Synthetic Name="[Substitutions]">
- <Expand>
- <LinkedListItems>
- <HeadPointer>substitutions.substitutions</HeadPointer>
- <NextPointer>outer</NextPointer>
- <ValueNode>this</ValueNode>
- </LinkedListItems>
- </Expand>
- </Synthetic>
- </Expand>
- </Type>
- <Type Name="Slang::GenericSubstitution">
- <DisplayString>GenSubst {(*genericDecl).nameAndLoc}</DisplayString>
+ <DisplayString Optional="true" Condition="m_operands.m_buffer[0].values.nodeOperand != 0">{astNodeType,en}#{_debugUID}({(Decl*)m_operands.m_buffer[0].values.nodeOperand}) </DisplayString>
+ <DisplayString Condition="m_operands.m_buffer[0].values.nodeOperand != 0">{astNodeType,en}({(Decl*)m_operands.m_buffer[0].values.nodeOperand})</DisplayString>
+ <DisplayString Condition="m_operands.m_buffer[0].values.nodeOperand == 0">DeclRefBase nullptr</DisplayString>
<Expand>
- <Item Name="genericDecl">genericDecl</Item>
- <ExpandedItem>args</ExpandedItem>
+ <Synthetic Name="[Decl]">
+ <DisplayString>{*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem>*(Decl*)m_operands.m_buffer[0].values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Condition="astNodeType == Slang::ASTNodeType::MemberDeclRef" Name="[Parent]">
+ <DisplayString>{*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)}</DisplayString>
+ <Expand>
+ <ExpandedItem>*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef" Name="[Base]">
+ <DisplayString>{*(Val*)(this->m_operands.m_buffer[1].values.nodeOperand)}</DisplayString>
+ <Expand>
+ <ExpandedItem>*(Val*)(this->m_operands.m_buffer[1].values.nodeOperand)</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef" Name="[Witness]">
+ <DisplayString>{*(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand)}</DisplayString>
+ <Expand>
+ <ExpandedItem>*(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand)</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef" Name="[BaseGeneric]">
+ <DisplayString>{*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)}</DisplayString>
+ <Expand>
+ <ExpandedItem>*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <CustomListItems Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef">
+ <Variable Name="index" InitialValue="2"/>
+ <Loop Condition="index&lt;m_operands.m_count">
+ <Item Name="Arg[{index-2}]">*(Val*)(this->m_operands.m_buffer[index].values.nodeOperand)</Item>
+ <Exec>index=index+1</Exec>
+ </Loop>
+ </CustomListItems>
</Expand>
</Type>
<Type Name="Slang::DeclRefType">
- <DisplayString>DeclRefType {declRef}</DisplayString>
+ <DisplayString Optional="true">{astNodeType,en}#{_debugUID} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand} </DisplayString>
+
+ <DisplayString>{astNodeType,en} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
<Expand>
- <ExpandedItem>declRef</ExpandedItem>
+ <Synthetic Name="DeclRefType">
+ <DisplayString Optional="true">{astNodeType,en}#{_debugUID} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en}#{m_operands.m_buffer[0].values.nodeOperand->_debugUID}</DisplayString>
+ <DisplayString>{astNodeType,en} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en}</DisplayString>
+ </Synthetic>
+ <ExpandedItem>*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand</ExpandedItem>
</Expand>
</Type>
<Type Name="Slang::FuncDecl">
@@ -223,13 +241,12 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OpenRefExpr">(Slang::OpenRefExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ForwardDifferentiateExpr">(Slang::ForwardDifferentiateExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BackwardDifferentiateExpr">(Slang::BackwardDifferentiateExpr*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionTypeExpr">(Slang::TaggedUnionTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeExpr">(Slang::ThisTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedTypeExpr">(Slang::ModifiedTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PointerTypeExpr">(Slang::PointerTypeExpr*)&amp;astNodeType</ExpandedItem>
<Item Name="[type]">type</Item>
- <Item Name="[Expr]">(Slang::Expr*)this,nd</Item>
+ <Item Name="[Expr]">(Slang::Expr*)this,!</Item>
</Expand>
</Type>
<Type Name="Slang::Stmt" Inheritable="false">
@@ -261,18 +278,19 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ContinueStmt">(Slang::ContinueStmt*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ReturnStmt">(Slang::ReturnStmt*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExpressionStmt">(Slang::ExpressionStmt*)&amp;astNodeType</ExpandedItem>
- <Item Name="[Stmt]">(Slang::Stmt*)this,nd</Item>
+ <Item Name="[Stmt]">(Slang::Stmt*)this,!</Item>
</Expand>
</Type>
<Type Name="Slang::Name">
<DisplayString>{text}</DisplayString>
</Type>
<Type Name="Slang::Decl" Inheritable="false">
- <DisplayString Condition="nameAndLoc.name!=0">{nameAndLoc.name->text}</DisplayString>
+ <DisplayString Condition="nameAndLoc.name!=0">{astNodeType,en} {nameAndLoc.name->text}</DisplayString>
<DisplayString Condition="nameAndLoc.name==0">{astNodeType,en}</DisplayString>
<Expand>
<Item Name="[Name]" Condition="nameAndLoc.name!=0">nameAndLoc.name->text</Item>
<Item Name="[Parent]">parentDecl</Item>
+ <Item Name="[CheckState]">Slang::DeclCheckState(checkState.m_raw &amp; ~Slang::DeclCheckStateExt::kBeingCheckedBit)</Item>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ContainerDecl">(Slang::ContainerDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtensionDecl">(Slang::ExtensionDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StructDecl">(Slang::StructDecl*)&amp;astNodeType</ExpandedItem>
@@ -314,7 +332,7 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SyntaxDecl">(Slang::SyntaxDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclGroup">(Slang::DeclGroup*)&amp;astNodeType</ExpandedItem>
- <Item Name="Decl">(Slang::DeclBase*)this,nd</Item>
+ <Item Name="Decl">(Slang::DeclBase*)this,!</Item>
</Expand>
</Type>
@@ -361,20 +379,57 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::EmptyDecl">(Slang::EmptyDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SyntaxDecl">(Slang::SyntaxDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclGroup">(Slang::DeclGroup*)&amp;astNodeType</ExpandedItem>
- <Item Name="Decl">(Slang::Decl*)this,nd</Item>
+ <Item Name="Decl">(Slang::Decl*)this,!</Item>
</Expand>
</Type>
+ <Type Name="Slang::TypeType" Inheritable="false">
+ <DisplayString Optional="true">{astNodeType,en} #{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
+ <DisplayString>{astNodeType,en} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
+ <Expand>
+ <ExpandedItem>*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)</ExpandedItem>
+ </Expand>
+ </Type>
+ <Type Name="Slang::FuncType" Inheritable="false">
+ <DisplayString Optional="true">{astNodeType,en} #{_debugUID}</DisplayString>
+ <DisplayString Optional="true">{astNodeType,en}</DisplayString>
+ <Expand>
+ <Synthetic Name="[ParamCount]">
+ <DisplayString>{m_operands.m_count-2}</DisplayString>
+ </Synthetic>
+ <ArrayItems>
+ <Size>m_operands.m_count-2</Size>
+ <ValuePointer>m_operands.m_buffer</ValuePointer>
+ </ArrayItems>
+ <Synthetic Name="[ResultType]">
+ <DisplayString>{m_operands.m_buffer[m_operands.m_count-2]}</DisplayString>
+ <Expand>
+ <ExpandedItem>m_operands.m_buffer[m_operands.m_count-2]</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Name="[ErrorType]">
+ <DisplayString>{m_operands.m_buffer[m_operands.m_count-1]}</DisplayString>
+ <Expand>
+ <ExpandedItem>m_operands.m_buffer[m_operands.m_count-1]</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ </Expand>
+ </Type>
<Type Name="Slang::Type" Inheritable="false">
- <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">{((Slang::DeclRefType*)&amp;astNodeType)->declRef}</DisplayString>
+ <DisplayString Optional="true" Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType#{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
+ <DisplayString Optional="true">{astNodeType,en} #{_debugUID}</DisplayString>
<DisplayString>{astNodeType,en}</DisplayString>
- <Expand>
+
+ <Expand>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OverloadGroupType">(Slang::OverloadGroupType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::InitializerListType">(Slang::InitializerListType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ErrorType">(Slang::ErrorType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BottomType">(Slang::BottomType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclRefType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentiableType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
+
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">(Slang::ArithmeticExpressionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">(Slang::BasicExpressionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">(Slang::VectorExpressionType*)&amp;astNodeType</ExpandedItem>
@@ -437,45 +492,51 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArrayExpressionType">(Slang::ArrayExpressionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TypeType">(Slang::TypeType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamedExpressionType">(Slang::NamedExpressionType*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FuncType">(Slang::FuncType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FuncType">(Slang::FuncType*)this</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">(Slang::GenericDeclRefType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamespaceType">(Slang::NamespaceType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">(Slang::ExtractExistentialType*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionType">(Slang::TaggedUnionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">(Slang::ExistentialSpecializedType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&amp;astNodeType</ExpandedItem>
- <Item Name="[Type]">(Slang::Type*)this,nd</Item>
- </Expand>
- </Type>
- <Type Name="Slang::Substitutions" Inheritable="false">
- <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">{*(Slang::GenericSubstitution*)&amp;astNodeType}</DisplayString>
- <DisplayString Condition="astNodeType == Slang::ASTNodeType::ThisTypeSubstitution">{*(Slang::ThisTypeSubstitution*)&amp;astNodeType}</DisplayString>
- <DisplayString>{astNodeType,en}</DisplayString>
- <DisplayString>{astNodeType,en}</DisplayString>
- <Expand>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">(Slang::GenericSubstitution*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeSubstitution">(Slang::ThisTypeSubstitution*)&amp;astNodeType</ExpandedItem>
+ <Item Name="[Raw View]">(Slang::Type*)this,!</Item>
</Expand>
</Type>
- <Type Name="Slang::GenericSubstitution" Inheritable="false">
- <DisplayString Condition="outer != 0">&lt;{args}&gt;{*outer}</DisplayString>
- <DisplayString>&lt;{args}&gt;</DisplayString>
- </Type>
- <Type Name="Slang::ThisTypeSubstitution" Inheritable="false">
- <DisplayString Condition="outer != 0">{*outer}[This=={witness->sub,na}]</DisplayString>
- <DisplayString>[{witness->sup,na}.This: {witness->sub,na}]</DisplayString>
- </Type>
<Type Name="Slang::SubstitutionSet">
- <DisplayString>{astNodeType,en}</DisplayString>
+ <DisplayString>SubstitutionSet{declRef,en}</DisplayString>
<Expand>
- <LinkedListItems>
- <HeadPointer>substitutions</HeadPointer>
- <NextPointer>outer</NextPointer>
- <ValueNode>(Slang::Substitutions*)this</ValueNode>
- </LinkedListItems>
+ <ExpandedItem>declRef</ExpandedItem>
+ <CustomListItems MaxItemsPerView="24">
+ <Variable Name="subst" InitialValue="declRef"/>
+ <Variable Name="substType" InitialValue="(Slang::ASTNodeType)0"/>
+ <Variable Name="shouldBreak" InitialValue="0"/>
+ <Loop Condition="subst != 0">
+ <Exec>substType = subst->astNodeType </Exec>
+ <Exec>shouldBreak = 1 </Exec>
+
+ <If Condition="substType == Slang::ASTNodeType::DirectDeclRef">
+ <Break/>
+ </If>
+ <If Condition="substType == Slang::ASTNodeType::MemberDeclRef">
+ <Exec>subst = (DeclRefBase*)(((Slang::MemberDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand)</Exec>
+ <Exec>shouldBreak = 0 </Exec>
+ </If>
+ <If Condition="substType == Slang::ASTNodeType::LookupDeclRef">
+ <Item>(LookupDeclRef*)subst</Item>
+ <Break/>
+ </If>
+ <If Condition="substType == Slang::ASTNodeType::GenericAppDeclRef">
+ <Item>(GenericAppDeclRef*)subst</Item>
+ <Exec>subst = (DeclRefBase*)(((Slang::GenericAppDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand)</Exec>
+ <Exec>shouldBreak = 0 </Exec>
+ </If>
+ <If Condition="shouldBreak">
+ <Break/>
+ </If>
+ </Loop>
+ </CustomListItems>
</Expand>
</Type>
<Type Name="Slang::AggTypeDecl">
@@ -484,9 +545,103 @@
<Item Name="[Members]">members</Item>
</Expand>
</Type>
+ <Type Name="Slang::ValNodeOperand">
+ <DisplayString Condition="kind==Slang::ValNodeOperandKind::ConstantValue">Const({values.intOperand})</DisplayString>
+ <DisplayString Condition="kind==Slang::ValNodeOperandKind::ValNode">{*(Val*)values.nodeOperand}</DisplayString>
+ <DisplayString>{values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ValNode">*(Val*)values.nodeOperand</ExpandedItem>
+ <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ASTNode">*values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Type>
<Type Name="Slang::Val" Inheritable="false">
- <DisplayString>{astNodeType,en}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">{*(Slang::DirectDeclRef*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef">{*(Slang::LookupDeclRef*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::MemberDeclRef">{*(Slang::MemberDeclRef*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef">{*(Slang::GenericAppDeclRef*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ConstantIntVal">{*(Slang::ConstantIntVal*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::PolynomialIntVal">{*(Slang::PolynomialIntVal*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericParamIntVal">{*(Slang::GenericParamIntVal*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclaredSubtypeWitness">{*(Slang::DeclaredSubtypeWitness*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness">{*(Slang::TransitiveSubtypeWitness*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::OverloadGroupType">{*(Slang::OverloadGroupType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::InitializerListType">{*(Slang::InitializerListType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ErrorType">{*(Slang::ErrorType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::BottomType">{*(Slang::BottomType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">{*(Slang::DeclRefType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">{*(Slang::DeclRefType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">{*(Slang::ArithmeticExpressionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">{*(Slang::BasicExpressionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">{*(Slang::VectorExpressionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::MatrixExpressionType">{*(Slang::MatrixExpressionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::BuiltinType">{*(Slang::BuiltinType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::FeedbackType">{*(Slang::FeedbackType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ResourceType">{*(Slang::ResourceType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureTypeBase">{*(Slang::TextureTypeBase*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureType">{*(Slang::TextureType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureSamplerType">{*(Slang::TextureSamplerType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLImageType">{*(Slang::GLSLImageType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::SamplerStateType">{*(Slang::SamplerStateType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::BuiltinGenericType">{*(Slang::BuiltinGenericType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::PointerLikeType">{*(Slang::PointerLikeType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ParameterGroupType">{*(Slang::ParameterGroupType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::UniformParameterGroupType">{*(Slang::UniformParameterGroupType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ConstantBufferType">{*(Slang::ConstantBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::TextureBufferType">{*(Slang::TextureBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLShaderStorageBufferType">{*(Slang::GLSLShaderStorageBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ParameterBlockType">{*(Slang::ParameterBlockType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::VaryingParameterGroupType">{*(Slang::VaryingParameterGroupType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLInputParameterGroupType">{*(Slang::GLSLInputParameterGroupType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLOutputParameterGroupType">{*(Slang::GLSLOutputParameterGroupType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferTypeBase">{*(Slang::HLSLStructuredBufferTypeBase*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferType">{*(Slang::HLSLStructuredBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRWStructuredBufferType">{*(Slang::HLSLRWStructuredBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedStructuredBufferType">{*(Slang::HLSLRasterizerOrderedStructuredBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLAppendStructuredBufferType">{*(Slang::HLSLAppendStructuredBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLConsumeStructuredBufferType">{*(Slang::HLSLConsumeStructuredBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLStreamOutputType">{*(Slang::HLSLStreamOutputType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLPointStreamType">{*(Slang::HLSLPointStreamType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLLineStreamType">{*(Slang::HLSLLineStreamType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLTriangleStreamType">{*(Slang::HLSLTriangleStreamType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::UntypedBufferResourceType">{*(Slang::UntypedBufferResourceType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLByteAddressBufferType">{*(Slang::HLSLByteAddressBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRWByteAddressBufferType">{*(Slang::HLSLRWByteAddressBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedByteAddressBufferType">{*(Slang::HLSLRasterizerOrderedByteAddressBufferType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::RaytracingAccelerationStructureType">{*(Slang::RaytracingAccelerationStructureType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLPatchType">{*(Slang::HLSLPatchType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLInputPatchType">{*(Slang::HLSLInputPatchType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::HLSLOutputPatchType">{*(Slang::HLSLOutputPatchType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GLSLInputAttachmentType">{*(Slang::GLSLInputAttachmentType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::StringTypeBase">{*(Slang::StringTypeBase*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::StringType">{*(Slang::StringType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::NativeStringType">{*(Slang::NativeStringType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::DynamicType">{*(Slang::DynamicType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::EnumTypeType">{*(Slang::EnumTypeType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::PtrTypeBase">{*(Slang::PtrTypeBase*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::PtrType">{*(Slang::PtrType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ParamDirectionType">{(Slang::ParamDirectionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::OutTypeBase">{*(Slang::OutTypeBase*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::OutType">{*(Slang::OutType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::InOutType">{*(Slang::InOutType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::RefType">{*(Slang::RefType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::NullPtrType">{*(Slang::NullPtrType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ArrayExpressionType">{*(Slang::ArrayExpressionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::TypeType">{*(Slang::TypeType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::NamedExpressionType">{*(Slang::NamedExpressionType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::FuncType">{*(Slang::FuncType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">{*(Slang::GenericDeclRefType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::NamespaceType">{*(Slang::NamespaceType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">{*(Slang::ExtractExistentialType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">{*(Slang::ExistentialSpecializedType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ThisType">{*(Slang::ThisType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::AndType">{*(Slang::AndType*)this}</DisplayString>
+ <DisplayString Condition="astNodeType == Slang::ASTNodeType::ModifiedType">{*(Slang::ModifiedType*)this}</DisplayString>
+
<Expand>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">(Slang::DirectDeclRef*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::LookupDeclRef">(Slang::LookupDeclRef*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::MemberDeclRef">(Slang::MemberDeclRef*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericAppDeclRef">(Slang::GenericAppDeclRef*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ConstantIntVal">(Slang::ConstantIntVal*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PolynomialIntVal">(Slang::PolynomialIntVal*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericParamIntVal">(Slang::GenericParamIntVal*)&amp;astNodeType</ExpandedItem>
@@ -560,11 +715,15 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">(Slang::GenericDeclRefType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamespaceType">(Slang::NamespaceType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">(Slang::ExtractExistentialType*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionType">(Slang::TaggedUnionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">(Slang::ExistentialSpecializedType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&amp;astNodeType</ExpandedItem>
+ <Synthetic Name="[RawOperands]">
+ <Expand>
+ <ExpandedItem>m_operands</ExpandedItem>
+ </Expand>
+ </Synthetic>
</Expand>
</Type>
<Type Name="Slang::Facet">
@@ -594,7 +753,41 @@
</Type>
<Type Name="Slang::SubtypeWitness">
<DisplayString Condition="astNodeType == Slang::ASTNodeType::TypeEqualityWitness">{*(Slang::TypeEqualityWitness*)this}</DisplayString>
- <DisplayString>{sub,na} &lt;: {sup,na}</DisplayString>
+ <DisplayString Optional="true">{astNodeType,en}#{_debugUID}({*(Type*)m_operands.m_buffer[0].values.nodeOperand,na} &lt;: {*(Type*)m_operands.m_buffer[1].values.nodeOperand,na})</DisplayString>
+ <DisplayString>{astNodeType,en}({*(Type*)m_operands.m_buffer[0].values.nodeOperand,na} &lt;: {*(Type*)m_operands.m_buffer[1].values.nodeOperand,na})</DisplayString>
+
+ <Expand>
+ <Synthetic Name="[Sub]">
+ <DisplayString>{*(Type*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem>(Type*)m_operands.m_buffer[0].values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Name="[Sup]">
+ <DisplayString>{*(Type*)m_operands.m_buffer[1].values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem>(Type*)m_operands.m_buffer[1].values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Name="[DeclRef]" Condition="astNodeType == Slang::ASTNodeType::DeclaredSubtypeWitness">
+ <DisplayString>{*(Val*)m_operands.m_buffer[2].values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem>(DeclRefBase*)m_operands.m_buffer[2].values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Name="[SubToMid]" Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness">
+ <DisplayString>{*(SubtypeWitness*)m_operands.m_buffer[2].values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem>(SubtypeWitness*)m_operands.m_buffer[2].values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ <Synthetic Name="[MidToSup]" Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness">
+ <DisplayString>{*(SubtypeWitness*)m_operands.m_buffer[3].values.nodeOperand}</DisplayString>
+ <Expand>
+ <ExpandedItem>(SubtypeWitness*)m_operands.m_buffer[3].values.nodeOperand</ExpandedItem>
+ </Expand>
+ </Synthetic>
+ </Expand>
</Type>
<Type Name="Slang::TypeEqualityWitness">
<DisplayString>{sub,na} == {sup,na}</DisplayString>