summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-09-13 13:11:48 -0700
committerGitHub <noreply@github.com>2022-09-13 13:11:48 -0700
commitf216b77752b9e4aea52882b2110ceb1cc64a2171 (patch)
treefbb33485b7260bc0f89b406e1be6fb8196f94196 /source
parent9f3e83cf0d664c87a618edf08d834829178030e6 (diff)
Deduplicate AST type nodes and cache lookup operations. (#2397)
* wip: dedup AST type nodes and cache lookup. * Fix. * Remove profiling. * Fixes. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/core/slang-array.h17
-rw-r--r--source/core/slang-short-list.h2
-rw-r--r--source/slang/slang-api.cpp10
-rw-r--r--source/slang/slang-ast-base.h37
-rw-r--r--source/slang/slang-ast-builder.cpp89
-rw-r--r--source/slang/slang-ast-builder.h211
-rw-r--r--source/slang/slang-ast-print.cpp2
-rw-r--r--source/slang/slang-ast-substitutions.cpp13
-rw-r--r--source/slang/slang-ast-type.cpp12
-rw-r--r--source/slang/slang-ast-type.h21
-rw-r--r--source/slang/slang-ast-val.cpp15
-rw-r--r--source/slang/slang-ast-val.h15
-rw-r--r--source/slang/slang-check-conformance.cpp18
-rw-r--r--source/slang/slang-check-constraint.cpp17
-rw-r--r--source/slang/slang-check-conversion.cpp2
-rw-r--r--source/slang/slang-check-decl.cpp124
-rw-r--r--source/slang/slang-check-expr.cpp17
-rw-r--r--source/slang/slang-check-impl.h24
-rw-r--r--source/slang/slang-check-overload.cpp10
-rw-r--r--source/slang/slang-check-shader.cpp20
-rw-r--r--source/slang/slang-check-stmt.cpp2
-rw-r--r--source/slang/slang-check-type.cpp22
-rw-r--r--source/slang/slang-lookup.cpp56
-rw-r--r--source/slang/slang-lower-to-ir.cpp20
-rw-r--r--source/slang/slang-mangle.cpp4
-rw-r--r--source/slang/slang-syntax.cpp98
-rw-r--r--source/slang/slang-syntax.h12
-rw-r--r--source/slang/slang.cpp8
28 files changed, 573 insertions, 325 deletions
diff --git a/source/core/slang-array.h b/source/core/slang-array.h
index 2647e163f..8d22e1848 100644
--- a/source/core/slang-array.h
+++ b/source/core/slang-array.h
@@ -125,6 +125,23 @@ namespace Slang
insertArray(rs, args...);
return rs;
}
+
+
+ template<typename TList>
+ void addToList(TList&)
+ {
+ }
+ template<typename TList, typename T>
+ void addToList(TList& list, T node)
+ {
+ list.add(node);
+ }
+ template<typename TList, typename T, typename ... TArgs>
+ void addToList(TList& list, T node, TArgs ... args)
+ {
+ list.add(node);
+ addToList(list, args...);
+ }
}
#endif
diff --git a/source/core/slang-short-list.h b/source/core/slang-short-list.h
index 7d51a8abf..5bad9faf8 100644
--- a/source/core/slang-short-list.h
+++ b/source/core/slang-short-list.h
@@ -288,6 +288,8 @@ namespace Slang
void addRange(ArrayView<T> list) { addRange(list.m_buffer, list.m_count); }
+ void addRange(ConstArrayView<T> list) { addRange(list.m_buffer, list.m_count); }
+
template<int _otherShortListSize, typename TOtherAllocator>
void addRange(const ShortList<T, _otherShortListSize, TOtherAllocator>& list)
{
diff --git a/source/slang/slang-api.cpp b/source/slang/slang-api.cpp
index a3b1e5409..45c583060 100644
--- a/source/slang/slang-api.cpp
+++ b/source/slang/slang-api.cpp
@@ -98,11 +98,19 @@ SLANG_API SlangResult slang_createGlobalSession(
{
Slang::String cacheFilename;
uint64_t dllTimestamp = 0;
+#define SLANG_PROFILE_STDLIB_COMPILE 0
+#if SLANG_PROFILE_STDLIB_COMPILE
+ auto startTime = std::chrono::high_resolution_clock::now();
+#else
if (tryLoadStdLibFromCache(globalSession, cacheFilename, dllTimestamp) != SLANG_OK)
+#endif
{
// Compile std lib from embeded source.
SLANG_RETURN_ON_FAIL(globalSession->compileStdLib(0));
-
+#if SLANG_PROFILE_STDLIB_COMPILE
+ auto timeElapsed = std::chrono::high_resolution_clock::now() - startTime;
+ printf("stdlib compilation time: %.1fms\n", timeElapsed.count() / 1000000.0);
+#endif
// Store the compiled stdlib to cache file.
trySaveStdLibToCache(globalSession, cacheFilename, dllTimestamp);
}
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index 3126aab71..dd08fece2 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -22,7 +22,18 @@ class NodeBase
// MUST be called before used. Called automatically via the ASTBuilder.
// Note that the astBuilder is not stored in the NodeBase derived types by default.
- SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /* astBuilder*/ ) { astNodeType = inAstNodeType; }
+ SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /* astBuilder*/)
+ {
+ astNodeType = inAstNodeType;
+#ifdef _DEBUG
+ static uint32_t uidCounter = 0;
+ static uint32_t breakValue = 0;
+ uidCounter++;
+ _debugUID = uidCounter;
+ if (breakValue != 0 && _debugUID == breakValue)
+ SLANG_BREAKPOINT(0)
+#endif
+ }
/// Get the class info
SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const { return *ASTClassInfo::getInfo(astNodeType); }
@@ -36,6 +47,9 @@ class NodeBase
// Handy when debugging, shouldn't be checked in though!
// virtual ~NodeBase() {}
+#ifdef _DEBUG
+ SLANG_UNREFLECTED uint32_t _debugUID = 0;
+#endif
};
// Casting of NodeBase
@@ -228,13 +242,28 @@ class GenericSubstitution : public Substitutions
// parameters we are binding to arguments
GenericDecl* genericDecl = nullptr;
+private:
// The actual values of the arguments
List<Val* > args;
-
+public:
+ const List<Val*>& getArgs() const { return args; }
// Overrides should be public so base classes can access
Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
bool _equalsOverride(Substitutions* subst);
HashCode _getHashCodeOverride() const;
+
+ GenericSubstitution(GenericDecl* decl)
+ {
+ genericDecl = decl;
+ }
+
+ template<typename... TArgs>
+ GenericSubstitution(GenericDecl* decl, TArgs... inArgs)
+ {
+ genericDecl = decl;
+ addToList(args, inArgs...);
+ }
+
};
class ThisTypeSubstitution : public Substitutions
@@ -253,6 +282,10 @@ class ThisTypeSubstitution : public Substitutions
Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
bool _equalsOverride(Substitutions* subst);
HashCode _getHashCodeOverride() const;
+
+ ThisTypeSubstitution(InterfaceDecl* inInterfaceDecl, SubtypeWitness* inWitness)
+ : interfaceDecl(inInterfaceDecl), witness(inWitness)
+ {}
};
class SyntaxNode : public SyntaxNodeBase
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index fa8051171..f8c208ac1 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -224,7 +224,7 @@ NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName)
{
- auto declRef = getBuiltinDeclRef(magicTypeName, makeConstArrayViewSingle<Val*>(typeParam));
+ auto declRef = getBuiltinDeclRef(magicTypeName, typeParam);
auto rsType = DeclRefType::create(this, declRef);
return rsType;
}
@@ -263,9 +263,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
- ArrayExpressionType* arrayType = create<ArrayExpressionType>();
- arrayType->baseType = elementType;
- arrayType->arrayLength = elementCount;
+ ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount);
+ if (!arrayType->baseType)
+ {
+ arrayType->baseType = elementType;
+ arrayType->arrayLength = elementCount;
+ }
return arrayType;
}
@@ -273,18 +276,15 @@ VectorExpressionType* ASTBuilder::getVectorType(
Type* elementType,
IntVal* elementCount)
{
- auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
-
- auto vectorTypeDecl = vectorGenericDecl->inner;
-
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = vectorGenericDecl;
- substitutions->args.add(elementType);
- substitutions->args.add(elementCount);
-
- auto declRef = DeclRef<Decl>(vectorTypeDecl, substitutions);
-
- return as<VectorExpressionType>(DeclRefType::create(this, declRef));
+ auto result = getOrCreate<VectorExpressionType>(elementType, elementCount);
+ if (!result->declRef.decl)
+ {
+ auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
+ auto vectorTypeDecl = vectorGenericDecl->inner;
+ auto substitutions = getOrCreate<GenericSubstitution>(vectorGenericDecl, elementType, elementCount);
+ result->declRef = DeclRef<Decl>(vectorTypeDecl, substitutions);
+ }
+ return result;
}
DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness)
@@ -293,10 +293,7 @@ DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witne
auto typeDecl = genericDecl->inner;
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = genericDecl;
- substitutions->args.add(valueType);
- substitutions->args.add(conformanceWitness);
+ auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, valueType, conformanceWitness);
auto declRef = DeclRef<Decl>(typeDecl, substitutions);
auto rsType = DeclRefType::create(this, declRef);
@@ -311,34 +308,29 @@ DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface()
return declRef;
}
-DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs)
+DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
DeclRef<Decl> declRef;
declRef.decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
if (auto genericDecl = as<GenericDecl>(declRef.decl))
{
- if (genericArgs.getCount())
+ if (genericArg)
{
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = genericDecl;
- for (auto arg : genericArgs)
- substitutions->args.add(arg);
+ auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, genericArg);
declRef.substitutions = substitutions;
}
declRef.decl = genericDecl->inner;
}
else
{
- SLANG_ASSERT(genericArgs.getCount() == 0);
+ SLANG_ASSERT(!genericArg);
}
return declRef;
}
Type* ASTBuilder::getAndType(Type* left, Type* right)
{
- auto type = create<AndType>();
- type->left = left;
- type->right = right;
+ auto type = getOrCreate<AndType>(left, right);
return type;
}
@@ -350,46 +342,26 @@ Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* m
return type;
}
-NodeBase* ASTBuilder::_getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc, void* createFuncUserData)
-{
- if(auto found = m_cachedNodes.TryGetValue(desc))
- return *found;
-
- auto node = createFunc(this, desc, createFuncUserData);
-
- auto operandCount = desc.operandCount;
- NodeBase** operandsCopy = m_arena.allocateAndZeroArray<NodeBase*>(desc.operandCount);
- for(Index i = 0; i < operandCount; ++i)
- operandsCopy[i] = desc.operands[i];
-
- NodeDesc descCopy = desc;
- descCopy.operands = operandsCopy;
- m_cachedNodes.Add(descCopy, node);
-
- return node;
-}
-
-
Val* ASTBuilder::getUNormModifierVal()
{
- return _getOrCreate<UNormModifierVal>();
+ return getOrCreate<UNormModifierVal>();
}
Val* ASTBuilder::getSNormModifierVal()
{
- return _getOrCreate<SNormModifierVal>();
+ return getOrCreate<SNormModifierVal>();
}
TypeType* ASTBuilder::getTypeType(Type* type)
{
- return create<TypeType>(type);
+ return getOrCreate<TypeType>(type);
}
bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const
{
if(type != that.type) return false;
- if(operandCount != that.operandCount) return false;
- for(Index i = 0; i < operandCount; ++i)
+ if(operands.getCount() != that.operands.getCount()) return false;
+ for(Index i = 0; i < operands.getCount(); ++i)
{
// Note: we are comparing the operands directly for identity
// (pointer equality) rather than doing the `Val`-level
@@ -399,7 +371,7 @@ bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const
// via a `NodeDesc` *should* all be going through the
// deduplication path anyway, as should their operands.
//
- if(operands[i] != that.operands[i]) return false;
+ if(operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
}
return true;
}
@@ -407,15 +379,14 @@ HashCode ASTBuilder::NodeDesc::getHashCode() const
{
Hasher hasher;
hasher.hashValue(Int(type));
- hasher.hashValue(operandCount);
- for(Index i = 0; i < operandCount; ++i)
+ for(Index i = 0; i < operands.getCount(); ++i)
{
// Note: we are hashing the raw pointer value rather
// than the content of the value node. This is done
// to match the semantics implemented for `==` on
// `NodeDesc`.
//
- hasher.hashValue((void*) operands[i]);
+ hasher.hashValue(operands[i].values.nodeOperand);
}
return hasher.getResult();
}
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index d363cb8f6..788257fa0 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -98,6 +98,49 @@ class ASTBuilder : public RefObject
{
friend class SharedASTBuilder;
public:
+ // Node cache:
+ struct NodeOperand
+ {
+ union
+ {
+ NodeBase* nodeOperand;
+ int64_t intOperand;
+ } values;
+ NodeOperand() { values.nodeOperand = nullptr; }
+ NodeOperand(NodeBase* node) { values.nodeOperand = node; }
+ template<typename EnumType>
+ NodeOperand(EnumType intVal)
+ {
+ static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size.");
+ values.intOperand = 0;
+ memcpy(&values, &intVal, sizeof(intVal));
+ }
+ };
+ struct NodeDesc
+ {
+ ASTNodeType type;
+ ShortList<NodeOperand, 4> operands;
+
+ bool operator==(NodeDesc const& that) const;
+ HashCode getHashCode() const;
+ };
+
+ template<typename NodeCreateFunc>
+ NodeBase* _getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc)
+ {
+ if (auto found = m_cachedNodes.TryGetValue(desc))
+ return *found;
+
+ auto node = createFunc();
+ m_cachedNodes.Add(desc, node);
+ return node;
+ }
+
+ /// A cache for AST nodes that are entirely defined by their node type, with
+ /// no need for additional state.
+ Dictionary<NodeDesc, NodeBase*> m_cachedNodes;
+
+public:
// For compile time check to see if thing being constructed is an AST type
template <typename T>
@@ -113,10 +156,109 @@ public:
/// Create AST types
template <typename T>
T* create() { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T); }
- template<typename T, typename P0>
- T* create(const P0& p0) { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T(p0)); }
- template<typename T, typename P0, typename P1>
- T* create(const P0& p0, const P1& p1) { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T(p0, p1));}
+ template<typename T, typename... TArgs>
+ T* create(TArgs... args) { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T(args...)); }
+
+ template<typename T, typename ... TArgs>
+ SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args)
+ {
+ SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
+ NodeDesc desc;
+ desc.type = T::kType;
+ addToList(desc.operands, args...);
+ return (T*)_getOrCreateImpl(desc, [&]()
+ {
+ return create<T>(args...);
+ });
+ }
+
+ template<typename T>
+ SLANG_FORCE_INLINE T* getOrCreate()
+ {
+ SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
+
+ NodeDesc desc;
+ desc.type = T::kType;
+ return (T*)_getOrCreateImpl(desc, [this]() { return create<T>(); });
+ }
+
+ template<typename T, typename ... TArgs>
+ SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(TArgs ... args)
+ {
+ SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
+ NodeDesc desc;
+ desc.type = T::kType;
+ addToList(desc.operands, args...);
+ return (T*)_getOrCreateImpl(desc, [&]()
+ {
+ return create<T>();
+ });
+ }
+
+ template<typename T>
+ SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(ConstArrayView<NodeOperand> operands)
+ {
+ SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
+ NodeDesc desc;
+ desc.type = T::kType;
+ desc.operands.addRange(operands);
+ return (T*)_getOrCreateImpl(desc, [&]()
+ {
+ return create<T>();
+ });
+ }
+
+ DeclRefType* getOrCreateDeclRefType(Decl* decl, Substitutions* outer)
+ {
+ NodeDesc desc;
+ desc.type = DeclRefType::kType;
+ desc.operands.add(decl);
+ if (outer)
+ {
+ desc.operands.add(outer);
+ }
+ auto result = (DeclRefType*)_getOrCreateImpl(desc, [&]() {return create<DeclRefType>(decl, outer); });
+ return result;
+ }
+
+ GenericSubstitution* getOrCreateGenericSubstitution(GenericDecl* decl, const List<Val*>& args, Substitutions* outer)
+ {
+ NodeDesc desc;
+ desc.type = GenericSubstitution::kType;
+ desc.operands.add(decl);
+ for (auto arg : args)
+ desc.operands.add(arg);
+ if (outer)
+ {
+ desc.operands.add(outer);
+ }
+ auto result = (GenericSubstitution*)_getOrCreateImpl(desc, [this]() {return create<GenericSubstitution>(); });
+ if (result->args.getCount() != args.getCount())
+ {
+ SLANG_RELEASE_ASSERT(result->args.getCount() == 0);
+ result->args.addRange(args);
+ result->genericDecl = decl;
+ result->outer = outer;
+ }
+ return result;
+ }
+
+ ThisTypeSubstitution* getOrCreateThisTypeSubstitution(InterfaceDecl* interfaceDecl, SubtypeWitness* subtypeWitness, Substitutions* outer)
+ {
+ NodeDesc desc;
+ desc.type = ThisTypeSubstitution::kType;
+ desc.operands.add(interfaceDecl);
+ desc.operands.add(subtypeWitness);
+ if (outer)
+ {
+ desc.operands.add(outer);
+ }
+ auto result = (ThisTypeSubstitution*)_getOrCreateImpl(desc, [this]() {return create<ThisTypeSubstitution>(); });
+ result->interfaceDecl = interfaceDecl;
+ result->witness = subtypeWitness;
+ result->outer = outer;
+ return result;
+ }
NodeBase* createByNodeType(ASTNodeType nodeType);
@@ -173,7 +315,7 @@ public:
DeclRef<InterfaceDecl> getDifferentiableInterface();
- DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs);
+ DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg);
Type* getAndType(Type* left, Type* right);
@@ -208,10 +350,13 @@ public:
/// Dtor
~ASTBuilder();
+ Dictionary<Decl*, GenericSubstitution*> m_genericDefaultSubst;
+
protected:
// Special default Ctor that can only be used by SharedASTBuilder
ASTBuilder();
+
template <typename T>
SLANG_FORCE_INLINE T* _initAndAdd(T* node)
{
@@ -237,62 +382,6 @@ protected:
MemoryArena m_arena;
- struct NodeDesc
- {
- ASTNodeType type;
- Count operandCount = 0;
- NodeBase* const* operands = nullptr;
-
- bool operator==(NodeDesc const& that) const;
- HashCode getHashCode() const;
- };
-
- /// A cache for AST nodes that are entirely defined by their node type, with
- /// no need for additional state.
- Dictionary<NodeDesc, NodeBase*> m_cachedNodes;
-
-
- typedef NodeBase* (*NodeCreateFunc)(ASTBuilder* astBuilder, NodeDesc const& desc, void* userData);
-
- NodeBase* _getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc, void* createFuncUserData);
-
- template<typename T>
- SLANG_FORCE_INLINE T* _getOrCreate(Count operandCount, NodeBase* const* operands)
- {
- SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
-
- struct Helper
- {
- static NodeBase* create(ASTBuilder* astBuilder, NodeDesc const& desc, void* /*userData*/)
- {
- return astBuilder->create<T>(desc.operandCount, desc.operands);
- }
- };
-
- NodeDesc desc;
- desc.type = T::kType;
- desc.operandCount = operandCount;
- desc.operands = operands;
- return (T*) _getOrCreateImpl(desc, &Helper::create, nullptr);
- }
-
- template<typename T>
- SLANG_FORCE_INLINE T* _getOrCreate()
- {
- SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value);
-
- struct Helper
- {
- static NodeBase* create(ASTBuilder* astBuilder, NodeDesc const& /*desc*/, void* /*userData*/)
- {
- return astBuilder->create<T>();
- }
- };
-
- NodeDesc desc;
- desc.type = T::kType;
- return (T*) _getOrCreateImpl(desc, &Helper::create, nullptr);
- }
};
} // namespace Slang
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp
index e81191086..830c6bf34 100644
--- a/source/slang/slang-ast-print.cpp
+++ b/source/slang/slang-ast-print.cpp
@@ -190,7 +190,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth)
sb << "<";
bool first = true;
- for (auto arg : genSubst->args)
+ for (auto arg : genSubst->getArgs())
{
// When printing the representation of a specialized
// generic declaration we don't want to include the
diff --git a/source/slang/slang-ast-substitutions.cpp b/source/slang/slang-ast-substitutions.cpp
index 6656f1fa6..7c6cf1bc6 100644
--- a/source/slang/slang-ast-substitutions.cpp
+++ b/source/slang/slang-ast-substitutions.cpp
@@ -63,10 +63,8 @@ Substitutions* GenericSubstitution::_applySubstitutionsShallowOverride(ASTBuilde
if (!diff) return this;
(*ioDiff)++;
- auto substSubst = astBuilder->create<GenericSubstitution>();
- substSubst->genericDecl = genericDecl;
- substSubst->args = substArgs;
- substSubst->outer = substOuter;
+
+ auto substSubst = astBuilder->getOrCreateGenericSubstitution(genericDecl, substArgs, substOuter);
return substSubst;
}
@@ -126,10 +124,9 @@ Substitutions* ThisTypeSubstitution::_applySubstitutionsShallowOverride(ASTBuild
if (!diff) return this;
(*ioDiff)++;
- auto substSubst = astBuilder->create<ThisTypeSubstitution>();
- substSubst->interfaceDecl = interfaceDecl;
- substSubst->witness = substWitness;
- substSubst->outer = substOuter;
+ ThisTypeSubstitution* substSubst;
+
+ substSubst = astBuilder->getOrCreateThisTypeSubstitution(interfaceDecl, substWitness, substOuter);
return substSubst;
}
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index a84f04a32..077d6de0a 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -224,7 +224,7 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
{
// We've found it, so return the corresponding specialization argument
(*ioDiff)++;
- return genericSubst->args[index];
+ return genericSubst->getArgs()[index];
}
else if (auto typeParam = as<GenericTypeParamDecl>(m))
{
@@ -351,17 +351,17 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride()
Type* MatrixExpressionType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]);
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
}
IntVal* MatrixExpressionType::getRowCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]);
+ return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]);
}
IntVal* MatrixExpressionType::getColumnCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]);
+ return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[2]);
}
Type* MatrixExpressionType::getRowType()
@@ -518,12 +518,12 @@ Type* NamespaceType::_createCanonicalTypeOverride()
Type* PtrTypeBase::getValueType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]);
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
}
Type* OptionalType::getValueType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]);
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 458c24a23..5bb91e5da 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -78,9 +78,13 @@ class DeclRefType : public Type
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
protected:
- DeclRefType( DeclRef<Decl> declRef)
+ DeclRefType(DeclRef<Decl> declRef)
: declRef(declRef)
{}
+
+ DeclRefType(Decl* decl, Substitutions* substitutions)
+ : declRef(decl, substitutions)
+ {}
};
// Base class for types that can be used in arithmetic expressions
@@ -210,6 +214,11 @@ class SamplerStateType : public BuiltinType
// What flavor of sampler state is this
SamplerStateFlavor flavor;
+
+ SamplerStateType(SamplerStateFlavor inFlavor)
+ {
+ flavor = inFlavor;
+ }
};
// Other cases of generic types known to the compiler
@@ -467,6 +476,10 @@ class VectorExpressionType : public ArithmeticExpressionType
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
BasicExpressionType* _getScalarTypeOverride();
+
+ VectorExpressionType(Type* inElementType, IntVal* inElementCount)
+ : elementType(inElementType), elementCount(inElementCount)
+ {}
};
// A matrix type, e.g., `matrix<T,R,C>`
@@ -486,6 +499,8 @@ class MatrixExpressionType : public ArithmeticExpressionType
private:
Type* rowType = nullptr;
+
+ MatrixExpressionType(Type*, IntVal*, IntVal*) {}
};
// Base class for built in string types
@@ -794,6 +809,10 @@ class AndType : public Type
Type* left;
Type* right;
+ AndType(Type* leftType, Type* rightType)
+ : left(leftType), right(rightType)
+ {}
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
bool _equalsImplOverride(Type* type);
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 64d32f4b4..a70a79535 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -137,7 +137,7 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S
{
// We've found it, so return the corresponding specialization argument
(*ioDiff)++;
- return genSubst->args[index];
+ return genSubst->getArgs()[index];
}
else if (auto typeParam = as<GenericTypeParamDecl>(m))
{
@@ -265,8 +265,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
(*ioDiff)++;
auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() +
genericDecl->getMembersOfType<GenericValueParamDecl>().getCount();
- SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.getCount());
- return genericSubst->args[index + ordinaryParamCount];
+ SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount());
+ return genericSubst->getArgs()[index + ordinaryParamCount];
}
}
}
@@ -323,7 +323,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
}
}
- DeclaredSubtypeWitness* rs = astBuilder->create<DeclaredSubtypeWitness>();
+ DeclaredSubtypeWitness* rs = astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ substSub, substSup, substDeclRef.getDecl(), substDeclRef.substitutions.substitutions);
rs->sub = substSub;
rs->sup = substSup;
rs->declRef = substDeclRef;
@@ -722,7 +723,7 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
*ioDiff += diff;
if (evaluatedTerms.getCount() == 0)
- return astBuilder->create<ConstantIntVal>(type, evaluatedConstantTerm);
+ return astBuilder->getOrCreate<ConstantIntVal>(type, evaluatedConstantTerm);
if (diff != 0)
{
auto newPolynomial = astBuilder->create<PolynomialIntVal>(type);
@@ -1035,7 +1036,7 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
return terms[0]->paramFactors[0]->param;
}
if (terms.getCount() == 0)
- return builder->create<ConstantIntVal>(type, constantTerm);
+ return builder->getOrCreate<ConstantIntVal>(type, constantTerm);
return this;
}
@@ -1207,7 +1208,7 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
{
SLANG_UNREACHABLE("constant folding of FuncCallIntVal");
}
- return astBuilder->create<ConstantIntVal>(resultType, resultValue);
+ return astBuilder->getOrCreate<ConstantIntVal>(resultType, resultValue);
}
return nullptr;
}
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 69797b3a5..0f7bdec8e 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -53,6 +53,10 @@ class GenericParamIntVal : public IntVal
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ GenericParamIntVal(Type* inType, VarDeclBase* inDecl, Substitutions* inSubst)
+ : IntVal(inType), declRef(inDecl, inSubst)
+ {}
+
protected:
GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
: IntVal(inType), declRef(inDeclRef)
@@ -315,6 +319,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ DeclaredSubtypeWitness(Type* inSub, Type* inSup, Decl* decl, Substitutions* subst)
+ : declRef(decl, subst)
+ {
+ sub = inSub;
+ sup = inSup;
+ }
};
// A witness that `sub : sup` because `sub : mid` and `mid : sup`
@@ -333,6 +344,10 @@ class TransitiveSubtypeWitness : public SubtypeWitness
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ TransitiveSubtypeWitness(SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup)
+ : subToMid(inSubToMid), midToSup(inMidToSup)
+ {}
};
// A witness taht `sub : sup` because `sub` was wrapped into
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 5889c6140..e0c1f3702 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -10,10 +10,11 @@ namespace Slang
DeclaredSubtypeWitness* SemanticsVisitor::createSimpleSubtypeWitness(
TypeWitnessBreadcrumb* breadcrumb)
{
- DeclaredSubtypeWitness* witness = m_astBuilder->create<DeclaredSubtypeWitness>();
- witness->sub = breadcrumb->sub;
- witness->sup = breadcrumb->sup;
- witness->declRef = breadcrumb->declRef;
+ DeclaredSubtypeWitness* witness = m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ breadcrumb->sub,
+ breadcrumb->sup,
+ breadcrumb->declRef.decl,
+ breadcrumb->declRef.substitutions.substitutions);
return witness;
}
@@ -82,12 +83,11 @@ namespace Slang
// where `[...]` represents the "hole" we leave
// open to fill in next.
//
- DeclaredSubtypeWitness* declaredWitness = m_astBuilder->create<DeclaredSubtypeWitness>();
- declaredWitness->sub = bb->sub;
- declaredWitness->sup = bb->sup;
- declaredWitness->declRef = bb->declRef;
+ DeclaredSubtypeWitness* declaredWitness =
+ m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions);
- TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->create<TransitiveSubtypeWitness>();
+ TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness);
transitiveWitness->sub = subType;
transitiveWitness->sup = bb->sup;
transitiveWitness->midToSup = declaredWitness;
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index a03043f31..129d3ed0c 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -391,11 +391,8 @@ namespace Slang
// search for a conformance `Robin : ISidekick`, which involved
// apply the substitutions we already know...
- GenericSubstitution* solvedSubst = m_astBuilder->create<GenericSubstitution>();
- solvedSubst->genericDecl = genericDeclRef.getDecl();
- solvedSubst->outer = genericDeclRef.substitutions.substitutions;
- solvedSubst->args = args;
- resultSubst.substitutions = solvedSubst;
+ GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution(
+ genericDeclRef.getDecl(), args, genericDeclRef.substitutions.substitutions);
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
@@ -412,7 +409,7 @@ namespace Slang
if(subTypeWitness)
{
// We found a witness, so it will become an (implicit) argument.
- solvedSubst->args.add(subTypeWitness);
+ args.add(subTypeWitness);
}
else
{
@@ -437,6 +434,8 @@ namespace Slang
}
}
+ resultSubst = m_astBuilder->getOrCreateGenericSubstitution(
+ genericDeclRef.getDecl(), args, genericDeclRef.substitutions.substitutions);
return resultSubst;
}
@@ -546,12 +545,12 @@ namespace Slang
return false;
// Their arguments must unify
- SLANG_RELEASE_ASSERT(fstGen->args.getCount() == sndGen->args.getCount());
- Index argCount = fstGen->args.getCount();
+ SLANG_RELEASE_ASSERT(fstGen->getArgs().getCount() == sndGen->getArgs().getCount());
+ Index argCount = fstGen->getArgs().getCount();
bool okay = true;
for (Index aa = 0; aa < argCount; ++aa)
{
- if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa]))
+ if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], sndGen->getArgs()[aa]))
{
okay = false;
}
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 4efacd703..2f5447ffb 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -349,7 +349,7 @@ namespace Slang
// We have a new type for the conversion, based on what
// we learned.
toType = m_astBuilder->getArrayType(toElementType,
- m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount));
+ m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount));
}
}
else if(auto toMatrixType = as<MatrixExpressionType>(toType))
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 9f19023ee..125f0cb08 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -519,42 +519,73 @@ namespace Slang
return semantics->ApplyExtensionToType(extDecl, type);
}
+ void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state)
+ {
+ visitor->ensureDecl(decl, state);
+ }
+
GenericSubstitution* createDefaultSubstitutionsForGeneric(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
GenericDecl* genericDecl,
Substitutions* outerSubst)
{
- GenericSubstitution* genericSubst = astBuilder->create<GenericSubstitution>();
- genericSubst->genericDecl = genericDecl;
- genericSubst->outer = outerSubst;
+ GenericSubstitution* cachedResult = nullptr;
+ if (astBuilder->m_genericDefaultSubst.TryGetValue(genericDecl, cachedResult))
+ {
+ if (cachedResult->outer == outerSubst)
+ return cachedResult;
+ }
+
+ List<Val*> args;
for( auto mm : genericDecl->members )
{
if( auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm) )
{
- genericSubst->args.add(DeclRefType::create(astBuilder, DeclRef<Decl>(genericTypeParamDecl, outerSubst)));
+ args.add(DeclRefType::create(astBuilder, DeclRef<Decl>(genericTypeParamDecl, outerSubst)));
}
else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) )
{
- genericSubst->args.add(astBuilder->create<GenericParamIntVal>(
+ args.add(astBuilder->getOrCreate<GenericParamIntVal>(
genericValueParamDecl->getType(),
- DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst)));
+ genericValueParamDecl, outerSubst));
}
}
+ bool shouldCache = true;
+
// create default substitution arguments for constraints
for (auto mm : genericDecl->members)
{
if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm))
{
- DeclaredSubtypeWitness* witness = astBuilder->create<DeclaredSubtypeWitness>();
- witness->declRef = DeclRef<Decl>(genericTypeConstraintDecl, outerSubst);
- witness->sub = genericTypeConstraintDecl->sub.type;
- witness->sup = genericTypeConstraintDecl->sup.type;
- genericSubst->args.add(witness);
+ if (semantics)
+ {
+ ensureDecl(semantics, genericTypeConstraintDecl, DeclCheckState::ReadyForReference);
+ }
+ auto constraintDeclRef = DeclRef<GenericTypeConstraintDecl>(genericTypeConstraintDecl, outerSubst);
+ DeclaredSubtypeWitness* witness =
+ astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ getSub(astBuilder, constraintDeclRef),
+ getSup(astBuilder, constraintDeclRef),
+ genericTypeConstraintDecl,
+ outerSubst);
+ // TODO: this is an ugly hack to prevent crashing.
+ // In early stages of compilation witness->sub and witness->sup may not be checked yet.
+ // When semanticVisitor is present we have used that to ensure the type is checked.
+ // However due to how the code is written we cannot guarantee semanticVisitor is always available
+ // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be
+ // cached.
+ if (!witness->sub)
+ shouldCache = false;
+ args.add(witness);
}
}
+ GenericSubstitution* genericSubst = astBuilder->getOrCreateGenericSubstitution(genericDecl, args, outerSubst);
+ if (shouldCache)
+ astBuilder->m_genericDefaultSubst[genericDecl] = genericSubst;
return genericSubst;
}
@@ -563,7 +594,8 @@ namespace Slang
// using their archetypes).
//
SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
Decl* decl,
SubstitutionSet outerSubstSet)
{
@@ -577,6 +609,7 @@ namespace Slang
GenericSubstitution* genericSubst = createDefaultSubstitutionsForGeneric(
astBuilder,
+ semantics,
genericDecl,
outerSubstSet.substitutions);
@@ -587,23 +620,19 @@ namespace Slang
}
SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
Decl* decl)
{
SubstitutionSet subst;
if( auto parentDecl = decl->parentDecl )
{
- subst = createDefaultSubstitutions(astBuilder, parentDecl);
+ subst = createDefaultSubstitutions(astBuilder, semantics, parentDecl);
}
- subst = createDefaultSubstitutions(astBuilder, decl, subst);
+ subst = createDefaultSubstitutions(astBuilder, semantics, decl, subst);
return subst;
}
- void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state)
- {
- visitor->ensureDecl(decl, state);
- }
-
bool SemanticsVisitor::isDeclUsableAsStaticMember(
Decl* decl)
{
@@ -1195,7 +1224,7 @@ namespace Slang
{
if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base))
{
- declRefType->declRef.substitutions = createDefaultSubstitutions(m_astBuilder, declRefType->declRef.getDecl());
+ declRefType->declRef.substitutions = createDefaultSubstitutions(m_astBuilder, this, declRefType->declRef.getDecl());
if (auto typetype = as<TypeType>(typeExp.exp->type))
typetype->type = declRefType;
@@ -1754,9 +1783,7 @@ namespace Slang
// compare `Derived::doThing` against `IBase::doThing<U>` where the `U` there is
// the parameter of `Dervived::doThing`.
//
- GenericSubstitution* requiredSubst = m_astBuilder->create<GenericSubstitution>();
- requiredSubst->genericDecl = requiredGenericDeclRef.getDecl();
- requiredSubst->outer = requiredGenericDeclRef.substitutions;
+ List<Val*> requiredSubstArgs;
for (Index i = 0; i < memberCount; i++)
{
@@ -1769,17 +1796,20 @@ namespace Slang
SLANG_ASSERT(satisfyingTypeParamDeclRef);
auto satisfyingType = DeclRefType::create(m_astBuilder, satisfyingTypeParamDeclRef);
- requiredSubst->args.add(satisfyingType);
+ requiredSubstArgs.add(satisfyingType);
}
else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>())
{
auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>();
SLANG_ASSERT(satisfyingValueParamDeclRef);
- auto satisfyingVal = m_astBuilder->create<GenericParamIntVal>();
+ auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ requiredValueParamDeclRef.getDecl()->getType(),
+ satisfyingValueParamDeclRef.getDecl(),
+ satisfyingValueParamDeclRef.substitutions.substitutions);
satisfyingVal->declRef = satisfyingValueParamDeclRef;
- requiredSubst->args.add(satisfyingVal);
+ requiredSubstArgs.add(satisfyingVal);
}
}
for (Index i = 0; i < memberCount; i++)
@@ -1792,15 +1822,20 @@ namespace Slang
auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>();
SLANG_ASSERT(satisfyingConstraintDeclRef);
- auto satisfyingWitness = m_astBuilder->create<DeclaredSubtypeWitness>();
+ auto satisfyingWitness = m_astBuilder->getOrCreate<DeclaredSubtypeWitness>();
satisfyingWitness->sub = getSub(m_astBuilder, satisfyingConstraintDeclRef);
satisfyingWitness->sup = getSup(m_astBuilder, satisfyingConstraintDeclRef);
satisfyingWitness->declRef = satisfyingConstraintDeclRef;
- requiredSubst->args.add(satisfyingWitness);
+ requiredSubstArgs.add(satisfyingWitness);
}
}
+ GenericSubstitution* requiredSubst = m_astBuilder->getOrCreateGenericSubstitution(
+ requiredGenericDeclRef.getDecl(),
+ requiredSubstArgs,
+ requiredGenericDeclRef.substitutions);
+
// Now that we have computed a set of specialization arguments that will
// specialize the generic requirement at the type parameters of the satisfying
// generic, we can construct a reference to that declaration and re-run some
@@ -2764,13 +2799,15 @@ namespace Slang
auto reqType = getBaseType(m_astBuilder, requiredInheritanceDeclRef);
- DeclaredSubtypeWitness* interfaceIsReqWitness = m_astBuilder->create<DeclaredSubtypeWitness>();
- interfaceIsReqWitness->sub = superInterfaceType;
- interfaceIsReqWitness->sup = reqType;
- interfaceIsReqWitness->declRef = requiredInheritanceDeclRef;
+ DeclaredSubtypeWitness* interfaceIsReqWitness =
+ m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ superInterfaceType,
+ reqType,
+ requiredInheritanceDeclRef.getDecl(),
+ requiredInheritanceDeclRef.substitutions.substitutions);
// ...
- TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->create<TransitiveSubtypeWitness>();
+ TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, reqType, interfaceIsReqWitness);
subIsReqWitness->sub = subType;
subIsReqWitness->sup = reqType;
subIsReqWitness->subToMid = subTypeConformsToSuperInterfaceWitness;
@@ -3232,7 +3269,7 @@ namespace Slang
void SemanticsVisitor::checkExtensionConformance(ExtensionDecl* decl)
{
- auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, makeDeclRef(decl)).as<ExtensionDecl>();
+ auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as<ExtensionDecl>();
auto targetType = getTargetType(m_astBuilder, declRef);
for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
@@ -3265,7 +3302,7 @@ namespace Slang
auto astBuilder = getASTBuilder();
- auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, makeDeclRef(decl)).as<AggTypeDeclBase>();
+ auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, this, makeDeclRef(decl)).as<AggTypeDeclBase>();
auto type = DeclRefType::create(astBuilder, declRef);
// TODO: Need to figure out what this should do for
@@ -4222,8 +4259,7 @@ namespace Slang
GenericSubstitution* SemanticsVisitor::createDummySubstitutions(
GenericDecl* genericDecl)
{
- GenericSubstitution* subst = m_astBuilder->create<GenericSubstitution>();
- subst->genericDecl = genericDecl;
+ List<Val*> args;
for (auto dd : genericDecl->members)
{
if (dd == genericDecl->inner)
@@ -4232,17 +4268,19 @@ namespace Slang
if (auto typeParam = as<GenericTypeParamDecl>(dd))
{
auto type = DeclRefType::create(m_astBuilder, makeDeclRef(typeParam));
- subst->args.add(type);
+ args.add(type);
}
else if (auto valueParam = as<GenericValueParamDecl>(dd))
{
- auto val = m_astBuilder->create<GenericParamIntVal>(
+ auto val = m_astBuilder->getOrCreate<GenericParamIntVal>(
valueParam->getType(),
- makeDeclRef(valueParam));
- subst->args.add(val);
+ valueParam,
+ nullptr);
+ args.add(val);
}
// TODO: need to handle constraints here?
}
+ GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(genericDecl, args, nullptr);
return subst;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 11b001c2c..4d55669e2 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -865,8 +865,7 @@ namespace Slang
IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr)
{
- // TODO(tfoley): don't keep allocating here!
- return m_astBuilder->create<ConstantIntVal>(expr->type.type, expr->value);
+ return m_astBuilder->getOrCreate<ConstantIntVal>(expr->type.type, expr->value);
}
IntVal* SemanticsVisitor::tryConstantFoldExpr(
@@ -1091,7 +1090,7 @@ namespace Slang
}
}
- IntVal* result = m_astBuilder->create<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue);
+ IntVal* result = m_astBuilder->getOrCreate<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue);
return result;
}
@@ -1166,7 +1165,6 @@ namespace Slang
expr = getBaseExpr(parenExpr);
}
- // TODO(tfoley): more serious constant folding here
if (auto intLitExpr = expr.as<IntegerLiteralExpr>())
{
return getIntVal(intLitExpr);
@@ -1176,7 +1174,7 @@ namespace Slang
{
// If it's a boolean, we allow promotion to int.
const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value);
- return m_astBuilder->create<ConstantIntVal>(m_astBuilder->getBoolType(), value);
+ return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value);
}
// it is possible that we are referring to a generic value param
@@ -1186,9 +1184,10 @@ namespace Slang
if (auto genericValParamRef = declRef.as<GenericValueParamDecl>())
{
- Val* valResult = m_astBuilder->create<GenericParamIntVal>(
+ Val* valResult = m_astBuilder->getOrCreate<GenericParamIntVal>(
declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()),
- genericValParamRef);
+ genericValParamRef.getDecl(),
+ genericValParamRef.substitutions.substitutions);
valResult = valResult->substitute(m_astBuilder, expr.getSubsts());
return as<IntVal>(valResult);
}
@@ -2158,7 +2157,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
@@ -2279,7 +2278,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 091990cf3..ef500d4fa 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -202,10 +202,32 @@ namespace Slang
Substitutions* subst = nullptr;
};
+ struct LookupRequestKey
+ {
+ NodeBase* base;
+ Name* name;
+ LookupOptions options;
+ LookupMask mask;
+ bool operator==(const LookupRequestKey& other) const
+ {
+ return base == other.base && name == other.name && options == other.options && mask == other.mask;
+ }
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher.hashValue(base);
+ hasher.hashValue(name);
+ hasher.hashValue(options);
+ hasher.hashValue(mask);
+ return hasher.getResult();
+ }
+ };
+
struct TypeCheckingCache
{
Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
+ Dictionary<LookupRequestKey, LookupResult> lookupCache;
};
/// Shared state for a semantics-checking session.
@@ -1467,7 +1489,7 @@ namespace Slang
//
bool TryCheckOverloadCandidateConstraints(
OverloadResolveContext& context,
- OverloadCandidate const& candidate);
+ OverloadCandidate& candidate);
// Try to check an overload candidate, but bail out
// if any step fails
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 9109130e2..55263453d 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -210,7 +210,7 @@ namespace Slang
//
auto genSubst = m_astBuilder->create<GenericSubstitution>();
candidate.subst = genSubst;
- auto& checkedArgs = genSubst->args;
+ auto& checkedArgs = (List<Val*>&)genSubst->getArgs();
// Rather than bail out as soon as we hit a problem,
// we are going to process *all* of the parameters of the
@@ -474,7 +474,7 @@ namespace Slang
bool SemanticsVisitor::TryCheckOverloadCandidateConstraints(
OverloadResolveContext& context,
- OverloadCandidate const& candidate)
+ OverloadCandidate& candidate)
{
// We only need this step for generics, so always succeed on
// everything else.
@@ -493,6 +493,8 @@ namespace Slang
subst->genericDecl = genericDeclRef.getDecl();
subst->outer = genericDeclRef.substitutions.substitutions;
+ List<Val*> newArgs = subst->getArgs();
+
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
auto subset = genericDeclRef.substitutions;
@@ -506,7 +508,7 @@ namespace Slang
auto subTypeWitness = tryGetSubtypeWitness(sub, sup);
if(subTypeWitness)
{
- subst->args.add(subTypeWitness);
+ newArgs.add(subTypeWitness);
}
else
{
@@ -518,6 +520,8 @@ namespace Slang
}
}
+ candidate.subst = m_astBuilder->getOrCreateGenericSubstitution(genericDeclRef.getDecl(), newArgs, genericDeclRef.substitutions.substitutions);
+
// Done checking all the constraints, hooray.
return true;
}
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 2e577b6d5..d7200d47c 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -1120,7 +1120,7 @@ namespace Slang
if(!intVal)
{
sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl);
- intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(m_astBuilder->getIntType(), 0);
+ intVal = getLinkage()->getASTBuilder()->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), 0);
}
ModuleSpecializationInfo::GenericArgInfo expandedArg;
@@ -1192,15 +1192,18 @@ namespace Slang
auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>();
SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters
- GenericSubstitution* genericSubst = getLinkage()->getASTBuilder()->create<GenericSubstitution>();
- genericSubst->outer = genericDeclRef.substitutions.substitutions;
- genericSubst->genericDecl = genericDeclRef.getDecl();
+ List<Val*> genericArgs;
for(Index ii = 0; ii < genericSpecializationParamCount; ++ii)
{
auto specializationArg = args[ii];
- genericSubst->args.add(specializationArg.val);
+ genericArgs.add(specializationArg.val);
}
+ GenericSubstitution* genericSubst =
+ getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution(
+ genericDeclRef.getDecl(),
+ genericArgs,
+ genericDeclRef.substitutions.substitutions);
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
@@ -1218,7 +1221,7 @@ namespace Slang
auto subTypeWitness = visitor.tryGetSubtypeWitness(sub, sup);
if(subTypeWitness)
{
- genericSubst->args.add(subTypeWitness);
+ genericArgs.add(subTypeWitness);
}
else
{
@@ -1228,6 +1231,11 @@ namespace Slang
}
}
+ genericSubst =
+ getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution(
+ genericDeclRef.getDecl(),
+ genericArgs,
+ genericDeclRef.substitutions.substitutions);
specializedFuncDeclRef.substitutions.substitutions = genericSubst;
}
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 2bf6cb830..a25a8683d 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -163,7 +163,7 @@ namespace Slang
}
else
{
- ConstantIntVal* rangeBeginConst = m_astBuilder->create<ConstantIntVal>();
+ ConstantIntVal* rangeBeginConst = m_astBuilder->getOrCreate<ConstantIntVal>();
rangeBeginConst->value = 0;
rangeBeginVal = rangeBeginConst;
}
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index c59e1b308..d402dde03 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -171,15 +171,16 @@ namespace Slang
DeclRef<GenericDecl> genericDeclRef,
List<Expr*> const& args)
{
- GenericSubstitution* subst = m_astBuilder->create<GenericSubstitution>();
- subst->genericDecl = genericDeclRef.getDecl();
- subst->outer = genericDeclRef.substitutions.substitutions;
+ List<Val*> evaledArgs;
for (auto argExpr : args)
{
- subst->args.add(ExtractGenericArgVal(argExpr));
+ evaledArgs.add(ExtractGenericArgVal(argExpr));
}
+ GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(
+ genericDeclRef.getDecl(), evaledArgs, genericDeclRef.substitutions.substitutions);
+
DeclRef<Decl> innerDeclRef;
innerDeclRef.decl = getInner(genericDeclRef);
innerDeclRef.substitutions = SubstitutionSet(subst);
@@ -403,18 +404,7 @@ namespace Slang
Type* elementType,
IntVal* elementCount)
{
- auto vectorGenericDecl = as<GenericDecl>(m_astBuilder->getSharedASTBuilder()->findMagicDecl("Vector"));
-
- auto vectorTypeDecl = vectorGenericDecl->inner;
-
- auto substitutions = m_astBuilder->create<GenericSubstitution>();
- substitutions->genericDecl = vectorGenericDecl;
- substitutions->args.add(elementType);
- substitutions->args.add(elementCount);
-
- auto declRef = DeclRef<Decl>(vectorTypeDecl, substitutions);
-
- return as<VectorExpressionType>(DeclRefType::create(m_astBuilder, declRef));
+ return m_astBuilder->getVectorType(elementType, elementCount);
}
Expr* SemanticsExprVisitor::visitSharedTypeExpr(SharedTypeExpr* expr)
diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp
index 48f7ea099..a8818fd5c 100644
--- a/source/slang/slang-lookup.cpp
+++ b/source/slang/slang-lookup.cpp
@@ -843,7 +843,7 @@ static void _lookUpInScopes(
// have a null `containerDecl` and needs to be
// skipped over.
//
- if(!containerDecl)
+ if (!containerDecl)
continue;
// TODO: If we need default substitutions to be applied to
@@ -852,8 +852,8 @@ static void _lookUpInScopes(
// just a decl.
//
DeclRef<ContainerDecl> containerDeclRef =
- DeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, containerDecl)).as<ContainerDecl>();
-
+ DeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, request.semantics, containerDecl)).as<ContainerDecl>();
+
// If the container we are looking into represents a type
// or an `extension` of a type, then we need to treat
// this step as lookup into the `this` variable (or the
@@ -878,9 +878,9 @@ static void _lookUpInScopes(
breadcrumb.prev = nullptr;
Type* type = nullptr;
- if(auto extDeclRef = aggTypeDeclBaseRef.as<ExtensionDecl>())
+ if (auto extDeclRef = aggTypeDeclBaseRef.as<ExtensionDecl>())
{
- if( request.semantics )
+ if (request.semantics)
{
ensureDecl(request.semantics, extDeclRef.getDecl(), DeclCheckState::CanUseExtensionTargetType);
}
@@ -918,14 +918,14 @@ static void _lookUpInScopes(
// of some nested type, then there shouldn't be an implicit `this`
// expression for the outer type, but instead an implicit `This`.
//
- if( containerDeclRef.is<ConstructorDecl>() )
+ if (containerDeclRef.is<ConstructorDecl>())
{
// In the context of an `__init` declaration, the members of
// the surrounding type are accessible through a mutable `this`.
//
thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::MutableValue;
}
- else if( containerDeclRef.is<SetterDecl>() )
+ else if (containerDeclRef.is<SetterDecl>())
{
// In the context of a `set` accessor, the members of the
// surrounding type are accessible through a mutable `this`.
@@ -937,19 +937,19 @@ static void _lookUpInScopes(
//
thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::MutableValue;
}
- else if( auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>() )
+ else if (auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>())
{
// The implicit `this`/`This` for a function-like declaration
// depends on modifiers attached to the declaration.
//
- if( funcDeclRef.getDecl()->hasModifier<HLSLStaticModifier>() )
+ if (funcDeclRef.getDecl()->hasModifier<HLSLStaticModifier>())
{
// A `static` method only has access to an implicit `This`,
// and does not have a `this` expression available.
//
thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Type;
}
- else if( funcDeclRef.getDecl()->hasModifier<MutatingAttribute>() )
+ else if (funcDeclRef.getDecl()->hasModifier<MutatingAttribute>())
{
// In a non-`static` method marked `[mutating]` there is
// an implicit `this` parameter that is mutable.
@@ -964,7 +964,7 @@ static void _lookUpInScopes(
thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::ImmutableValue;
}
}
- else if( containerDeclRef.as<AggTypeDeclBase>() )
+ else if (containerDeclRef.as<AggTypeDeclBase>())
{
// When lookup moves from a nested typed declaration to an
// outer scope, there is no ability to use an implicit `this`
@@ -972,7 +972,6 @@ static void _lookUpInScopes(
//
thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Type;
}
- // TODO: What other cases need to be enumerated here?
}
if (result.isValid())
@@ -1002,9 +1001,27 @@ LookupResult lookUp(
Scope* scope,
LookupMask mask)
{
- LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, scope);
LookupResult result;
+ LookupRequestKey key;
+ TypeCheckingCache* typeCheckingCache = nullptr;
+ if (semantics)
+ {
+ typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache();
+ key.base = scope;
+ key.name = name;
+ key.options = LookupOptions::None;
+ key.mask = mask;
+ if (typeCheckingCache->lookupCache.TryGetValue(key, result))
+ {
+ return result;
+ }
+ }
+ LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, scope);
_lookUpInScopes(astBuilder, name, request, result);
+ if (typeCheckingCache)
+ {
+ typeCheckingCache->lookupCache[key] = result;
+ }
return result;
}
@@ -1016,9 +1033,20 @@ LookupResult lookUpMember(
LookupMask mask,
LookupOptions options)
{
- LookupRequest request = initLookupRequest(semantics, name, mask, options, nullptr);
+ TypeCheckingCache* typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache();
+ LookupRequestKey key;
+ key.base = type;
+ key.name = name;
+ key.options = options;
+ key.mask = mask;
LookupResult result;
+ if (typeCheckingCache->lookupCache.TryGetValue(key, result))
+ {
+ return result;
+ }
+ LookupRequest request = initLookupRequest(semantics, name, mask, options, nullptr);
_lookUpMembersInType(astBuilder, name, type, request, result, nullptr);
+ typeCheckingCache->lookupCache[key] = result;
return result;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 99af2abd6..221abe2b8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1813,7 +1813,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
_collectSubstitutionArgs(operands, subst->outer);
if (auto genSubst = as<GenericSubstitution>(subst))
{
- for (auto arg : genSubst->args)
+ for (auto arg : genSubst->getArgs())
{
operands.add(lowerVal(context, arg).val);
}
@@ -2561,19 +2561,19 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de
return kParameterDirection_In;
}
-DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, Decl* decl)
+DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, SemanticsVisitor* semantics, Decl* decl)
{
DeclRef<Decl> declRef;
declRef.decl = decl;
- declRef.substitutions = createDefaultSubstitutions(context->astBuilder, decl);
+ declRef.substitutions = createDefaultSubstitutions(context->astBuilder, semantics, decl);
return declRef;
}
//
// The client should actually call the templated wrapper, to preserve type information.
template<typename D>
-DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl)
+DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, SemanticsVisitor* semantics, D* decl)
{
- DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, decl);
+ DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, semantics, decl);
return declRef.as<D>();
}
@@ -3502,8 +3502,8 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
void _lowerSubstitutionArg(IRGenContext* subContext, GenericSubstitution* subst, Decl* paramDecl, Index argIndex)
{
- SLANG_ASSERT(argIndex < subst->args.getCount());
- auto argVal = lowerVal(subContext, subst->args[argIndex]);
+ SLANG_ASSERT(argIndex < subst->getArgs().getCount());
+ auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]);
setValue(subContext, paramDecl, argVal);
}
@@ -7669,7 +7669,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
funcTypeContext,
- createDefaultSpecializedDeclRef(funcTypeContext, decl),
+ createDefaultSpecializedDeclRef(funcTypeContext, nullptr, decl),
info);
auto irFuncType = info.type;
@@ -7710,7 +7710,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,
- createDefaultSpecializedDeclRef(context, decl),
+ createDefaultSpecializedDeclRef(context, nullptr, decl),
info);
auto irFuncType = info.type;
@@ -8368,7 +8368,7 @@ LoweredValInfo emitDeclRef(
// We have the IR value for the generic we'd like to specialize,
// and now we need to get the value for the arguments.
List<IRInst*> irArgs;
- for (auto argVal : genericSubst->args)
+ for (auto argVal : genericSubst->getArgs())
{
auto irArgVal = lowerSimpleVal(context, argVal);
SLANG_ASSERT(irArgVal);
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index c230776db..c15402fbb 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -374,9 +374,9 @@ namespace Slang
{
// This is the case where we *do* have substitutions.
emitRaw(context, "G");
- UInt genericArgCount = subst->args.getCount();
+ UInt genericArgCount = subst->getArgs().getCount();
emit(context, genericArgCount);
- for( auto aa : subst->args )
+ for (auto aa : subst->getArgs())
{
emitVal(context, aa);
}
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 957d4e661..9a7200e9d 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -355,6 +355,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
DeclRef<Decl> declRef)
{
// It is possible that `declRef` refers to a generic type,
@@ -413,7 +414,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
if(!foundSubst)
{
Substitutions* newSubst = createDefaultSubstitutionsForGeneric(
- astBuilder,
+ astBuilder,
+ semantics,
genericParentDecl,
nullptr);
@@ -436,7 +438,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
ASTBuilder* astBuilder,
DeclRef<Decl> declRef)
{
- declRef = createDefaultSubstitutionsIfNeeded(astBuilder, declRef);
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
if (auto builtinMod = declRef.getDecl()->findModifier<BuiltinTypeModifier>())
{
@@ -458,58 +460,60 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
if (magicMod->magicName == "SamplerState")
{
- auto type = astBuilder->create<SamplerStateType>();
+ auto type = astBuilder->getOrCreate<SamplerStateType>(SamplerStateFlavor(magicMod->tag));
type->declRef = declRef;
- type->flavor = SamplerStateFlavor(magicMod->tag);
return type;
}
else if (magicMod->magicName == "Vector")
{
- SLANG_ASSERT(subst && subst->args.getCount() == 2);
- auto vecType = astBuilder->create<VectorExpressionType>();
+ SLANG_ASSERT(subst && subst->getArgs().getCount() == 2);
+ auto vecType = astBuilder->getOrCreate<VectorExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1]));
vecType->declRef = declRef;
- vecType->elementType = ExtractGenericArgType(subst->args[0]);
- vecType->elementCount = ExtractGenericArgInteger(subst->args[1]);
+ vecType->elementType = ExtractGenericArgType(subst->getArgs()[0]);
+ vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]);
return vecType;
}
else if (magicMod->magicName == "Matrix")
{
- SLANG_ASSERT(subst && subst->args.getCount() == 3);
- auto matType = astBuilder->create<MatrixExpressionType>();
+ SLANG_ASSERT(subst && subst->getArgs().getCount() == 3);
+ auto matType = astBuilder->getOrCreate<MatrixExpressionType>(
+ ExtractGenericArgType(subst->getArgs()[0]),
+ ExtractGenericArgInteger(subst->getArgs()[1]),
+ ExtractGenericArgInteger(subst->getArgs()[2]));
matType->declRef = declRef;
return matType;
}
else if (magicMod->magicName == "Texture")
{
- SLANG_ASSERT(subst && subst->args.getCount() >= 1);
- auto textureType = astBuilder->create<TextureType>(
+ SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
+ auto textureType = astBuilder->getOrCreate<TextureType>(
TextureFlavor(magicMod->tag),
- ExtractGenericArgType(subst->args[0]));
+ ExtractGenericArgType(subst->getArgs()[0]));
textureType->declRef = declRef;
return textureType;
}
else if (magicMod->magicName == "TextureSampler")
{
- SLANG_ASSERT(subst && subst->args.getCount() >= 1);
+ SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
auto textureType = astBuilder->create<TextureSamplerType>(
TextureFlavor(magicMod->tag),
- ExtractGenericArgType(subst->args[0]));
+ ExtractGenericArgType(subst->getArgs()[0]));
textureType->declRef = declRef;
return textureType;
}
else if (magicMod->magicName == "GLSLImageType")
{
- SLANG_ASSERT(subst && subst->args.getCount() >= 1);
- auto textureType = astBuilder->create<GLSLImageType>(
+ SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
+ auto textureType = astBuilder->getOrCreate<GLSLImageType>(
TextureFlavor(magicMod->tag),
- ExtractGenericArgType(subst->args[0]));
+ ExtractGenericArgType(subst->getArgs()[0]));
textureType->declRef = declRef;
return textureType;
}
else if (magicMod->magicName == "FeedbackType")
{
SLANG_ASSERT(subst == nullptr);
- auto type = astBuilder->create<FeedbackType>();
+ auto type = astBuilder->getOrCreateWithDefaultCtor<FeedbackType>(magicMod->tag);
type->declRef = declRef;
type->kind = FeedbackType::Kind(magicMod->tag);
return type;
@@ -519,11 +523,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// and we can drive the dispatch with a table instead
// of this ridiculously slow `if` cascade.
- #define CASE(n,T) \
- else if(magicMod->magicName == #n) { \
- auto type = astBuilder->create<T>(); \
- type->declRef = declRef; \
- return type; \
+ #define CASE(n, T) \
+ else if (magicMod->magicName == #n) \
+ { \
+ auto type = astBuilder->getOrCreateWithDefaultCtor<T>( \
+ declRef.decl, declRef.substitutions.substitutions); \
+ type->declRef = declRef; \
+ return type; \
}
CASE(HLSLInputPatchType, HLSLInputPatchType)
@@ -531,14 +537,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
#undef CASE
- #define CASE(n,T) \
- else if(magicMod->magicName == #n) { \
- SLANG_ASSERT(subst && subst->args.getCount() == 1); \
- auto type = astBuilder->create<T>(); \
- type->elementType = ExtractGenericArgType(subst->args[0]); \
- type->declRef = declRef; \
- return type; \
- }
+ #define CASE(n, T) \
+ else if (magicMod->magicName == #n) \
+ { \
+ SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); \
+ auto type = \
+ astBuilder->getOrCreateWithDefaultCtor<T>(ExtractGenericArgType(subst->getArgs()[0])); \
+ type->elementType = ExtractGenericArgType(subst->getArgs()[0]); \
+ type->declRef = declRef; \
+ return type; \
+ }
CASE(ConstantBuffer, ConstantBufferType)
CASE(TextureBuffer, TextureBufferType)
@@ -561,8 +569,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// "magic" builtin types which have no generic parameters
#define CASE(n,T) \
- else if(magicMod->magicName == #n) { \
- auto type = astBuilder->create<T>(); \
+ else if(magicMod->magicName == #n) { \
+ auto type = astBuilder->getOrCreate<T>(); \
type->declRef = declRef; \
return type; \
}
@@ -601,7 +609,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
else
{
- return astBuilder->create<DeclRefType>(declRef);
+ return astBuilder->getOrCreateDeclRefType(declRef.decl, declRef.substitutions.substitutions);
}
}
@@ -786,10 +794,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
substsToApply,
&diff);
- GenericSubstitution* firstSubst = astBuilder->create<GenericSubstitution>();
- firstSubst->genericDecl = ancestorGenericDecl;
- firstSubst->args = appGenericSubst->args;
- firstSubst->outer = restSubst;
+ GenericSubstitution* firstSubst = astBuilder->getOrCreateGenericSubstitution(
+ ancestorGenericDecl, appGenericSubst->getArgs(), restSubst);
(*ioDiff)++;
return firstSubst;
@@ -849,10 +855,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
substsToApply,
&diff);
- ThisTypeSubstitution* firstSubst = astBuilder->create<ThisTypeSubstitution>();
- firstSubst->interfaceDecl = ancestorInterfaceDecl;
- firstSubst->witness = appThisTypeSubst->witness;
- firstSubst->outer = restSubst;
+ ThisTypeSubstitution* firstSubst = astBuilder->getOrCreateThisTypeSubstitution(
+ ancestorInterfaceDecl, appThisTypeSubst->witness, restSubst);
(*ioDiff)++;
return firstSubst;
@@ -1013,12 +1017,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
Type* HLSLPatchType::getElementType()
{
- return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]);
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
}
IntVal* HLSLPatchType::getElementCount()
{
- return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]);
+ return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]);
}
// Constructors for types
@@ -1047,7 +1051,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
ASTBuilder* astBuilder,
DeclRef<TypeDefDecl> const& declRef)
{
- DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, declRef).as<TypeDefDecl>();
+ DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef).as<TypeDefDecl>();
return astBuilder->create<NamedExpressionType>(specializedDeclRef);
}
@@ -1179,7 +1183,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
{
out << "<";
bool isFirst = true;
- for (const auto& it : genericSubstitution->args)
+ for (const auto& it : genericSubstitution->getArgs())
{
if (!isFirst)
out << ", ";
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 332661e12..4e1900636 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -286,20 +286,24 @@ namespace Slang
// TODO: where should this live?
SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
Decl* decl,
SubstitutionSet parentSubst);
SubstitutionSet createDefaultSubstitutions(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
Decl* decl);
DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
DeclRef<Decl> declRef);
GenericSubstitution* createDefaultSubstitutionsForGeneric(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
GenericDecl* genericDecl,
Substitutions* outerSubst);
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index a0355c266..b68dbc14a 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1178,7 +1178,7 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType(
ConstantBufferType* cbType = getASTBuilder()->create<ConstantBufferType>();
cbType->elementType = type;
cbType->declRef = getASTBuilder()->getBuiltinDeclRef(
- "ConstantBuffer", makeConstArrayViewSingle<Val*>(static_cast<Val*>(type)));
+ "ConstantBuffer", static_cast<Val*>(type));
containerTypeReflection = cbType;
}
break;
@@ -1187,7 +1187,7 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType(
ParameterBlockType* pbType = getASTBuilder()->create<ParameterBlockType>();
pbType->elementType = type;
pbType->declRef = getASTBuilder()->getBuiltinDeclRef(
- "ParameterBlock", makeConstArrayViewSingle<Val*>(static_cast<Val*>(type)));
+ "ParameterBlock", static_cast<Val*>(type));
containerTypeReflection = pbType;
}
break;
@@ -1197,7 +1197,7 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType(
getASTBuilder()->create<HLSLStructuredBufferType>();
sbType->elementType = type;
sbType->declRef = getASTBuilder()->getBuiltinDeclRef(
- "HLSLStructuredBufferType", makeConstArrayViewSingle<Val*>(static_cast<Val*>(type)));
+ "HLSLStructuredBufferType", static_cast<Val*>(type));
containerTypeReflection = sbType;
}
break;
@@ -3839,7 +3839,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
{
if(auto genericSubst = as<GenericSubstitution>(substitution))
{
- for(auto arg : genericSubst->args)
+ for(auto arg : genericSubst->getArgs())
{
collectReferencedModules(arg);
}