diff options
| author | Yong He <yonghe@outlook.com> | 2022-09-13 13:11:48 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-09-13 13:11:48 -0700 |
| commit | f216b77752b9e4aea52882b2110ceb1cc64a2171 (patch) | |
| tree | fbb33485b7260bc0f89b406e1be6fb8196f94196 /source | |
| parent | 9f3e83cf0d664c87a618edf08d834829178030e6 (diff) | |
Deduplicate AST type nodes and cache lookup operations. (#2397)
* wip: dedup AST type nodes and cache lookup.
* Fix.
* Remove profiling.
* Fixes.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
28 files changed, 573 insertions, 325 deletions
diff --git a/source/core/slang-array.h b/source/core/slang-array.h index 2647e163f..8d22e1848 100644 --- a/source/core/slang-array.h +++ b/source/core/slang-array.h @@ -125,6 +125,23 @@ namespace Slang insertArray(rs, args...); return rs; } + + + template<typename TList> + void addToList(TList&) + { + } + template<typename TList, typename T> + void addToList(TList& list, T node) + { + list.add(node); + } + template<typename TList, typename T, typename ... TArgs> + void addToList(TList& list, T node, TArgs ... args) + { + list.add(node); + addToList(list, args...); + } } #endif diff --git a/source/core/slang-short-list.h b/source/core/slang-short-list.h index 7d51a8abf..5bad9faf8 100644 --- a/source/core/slang-short-list.h +++ b/source/core/slang-short-list.h @@ -288,6 +288,8 @@ namespace Slang void addRange(ArrayView<T> list) { addRange(list.m_buffer, list.m_count); } + void addRange(ConstArrayView<T> list) { addRange(list.m_buffer, list.m_count); } + template<int _otherShortListSize, typename TOtherAllocator> void addRange(const ShortList<T, _otherShortListSize, TOtherAllocator>& list) { diff --git a/source/slang/slang-api.cpp b/source/slang/slang-api.cpp index a3b1e5409..45c583060 100644 --- a/source/slang/slang-api.cpp +++ b/source/slang/slang-api.cpp @@ -98,11 +98,19 @@ SLANG_API SlangResult slang_createGlobalSession( { Slang::String cacheFilename; uint64_t dllTimestamp = 0; +#define SLANG_PROFILE_STDLIB_COMPILE 0 +#if SLANG_PROFILE_STDLIB_COMPILE + auto startTime = std::chrono::high_resolution_clock::now(); +#else if (tryLoadStdLibFromCache(globalSession, cacheFilename, dllTimestamp) != SLANG_OK) +#endif { // Compile std lib from embeded source. SLANG_RETURN_ON_FAIL(globalSession->compileStdLib(0)); - +#if SLANG_PROFILE_STDLIB_COMPILE + auto timeElapsed = std::chrono::high_resolution_clock::now() - startTime; + printf("stdlib compilation time: %.1fms\n", timeElapsed.count() / 1000000.0); +#endif // Store the compiled stdlib to cache file. trySaveStdLibToCache(globalSession, cacheFilename, dllTimestamp); } diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 3126aab71..dd08fece2 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -22,7 +22,18 @@ class NodeBase // MUST be called before used. Called automatically via the ASTBuilder. // Note that the astBuilder is not stored in the NodeBase derived types by default. - SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /* astBuilder*/ ) { astNodeType = inAstNodeType; } + SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* /* astBuilder*/) + { + astNodeType = inAstNodeType; +#ifdef _DEBUG + static uint32_t uidCounter = 0; + static uint32_t breakValue = 0; + uidCounter++; + _debugUID = uidCounter; + if (breakValue != 0 && _debugUID == breakValue) + SLANG_BREAKPOINT(0) +#endif + } /// Get the class info SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const { return *ASTClassInfo::getInfo(astNodeType); } @@ -36,6 +47,9 @@ class NodeBase // Handy when debugging, shouldn't be checked in though! // virtual ~NodeBase() {} +#ifdef _DEBUG + SLANG_UNREFLECTED uint32_t _debugUID = 0; +#endif }; // Casting of NodeBase @@ -228,13 +242,28 @@ class GenericSubstitution : public Substitutions // parameters we are binding to arguments GenericDecl* genericDecl = nullptr; +private: // The actual values of the arguments List<Val* > args; - +public: + const List<Val*>& getArgs() const { return args; } // Overrides should be public so base classes can access Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); bool _equalsOverride(Substitutions* subst); HashCode _getHashCodeOverride() const; + + GenericSubstitution(GenericDecl* decl) + { + genericDecl = decl; + } + + template<typename... TArgs> + GenericSubstitution(GenericDecl* decl, TArgs... inArgs) + { + genericDecl = decl; + addToList(args, inArgs...); + } + }; class ThisTypeSubstitution : public Substitutions @@ -253,6 +282,10 @@ class ThisTypeSubstitution : public Substitutions Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); bool _equalsOverride(Substitutions* subst); HashCode _getHashCodeOverride() const; + + ThisTypeSubstitution(InterfaceDecl* inInterfaceDecl, SubtypeWitness* inWitness) + : interfaceDecl(inInterfaceDecl), witness(inWitness) + {} }; class SyntaxNode : public SyntaxNodeBase diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index fa8051171..f8c208ac1 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -224,7 +224,7 @@ NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName) { - auto declRef = getBuiltinDeclRef(magicTypeName, makeConstArrayViewSingle<Val*>(typeParam)); + auto declRef = getBuiltinDeclRef(magicTypeName, typeParam); auto rsType = DeclRefType::create(this, declRef); return rsType; } @@ -263,9 +263,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) { - ArrayExpressionType* arrayType = create<ArrayExpressionType>(); - arrayType->baseType = elementType; - arrayType->arrayLength = elementCount; + ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount); + if (!arrayType->baseType) + { + arrayType->baseType = elementType; + arrayType->arrayLength = elementCount; + } return arrayType; } @@ -273,18 +276,15 @@ VectorExpressionType* ASTBuilder::getVectorType( Type* elementType, IntVal* elementCount) { - auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); - - auto vectorTypeDecl = vectorGenericDecl->inner; - - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = vectorGenericDecl; - substitutions->args.add(elementType); - substitutions->args.add(elementCount); - - auto declRef = DeclRef<Decl>(vectorTypeDecl, substitutions); - - return as<VectorExpressionType>(DeclRefType::create(this, declRef)); + auto result = getOrCreate<VectorExpressionType>(elementType, elementCount); + if (!result->declRef.decl) + { + auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); + auto vectorTypeDecl = vectorGenericDecl->inner; + auto substitutions = getOrCreate<GenericSubstitution>(vectorGenericDecl, elementType, elementCount); + result->declRef = DeclRef<Decl>(vectorTypeDecl, substitutions); + } + return result; } DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness) @@ -293,10 +293,7 @@ DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witne auto typeDecl = genericDecl->inner; - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = genericDecl; - substitutions->args.add(valueType); - substitutions->args.add(conformanceWitness); + auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, valueType, conformanceWitness); auto declRef = DeclRef<Decl>(typeDecl, substitutions); auto rsType = DeclRefType::create(this, declRef); @@ -311,34 +308,29 @@ DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface() return declRef; } -DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs) +DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { DeclRef<Decl> declRef; declRef.decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as<GenericDecl>(declRef.decl)) { - if (genericArgs.getCount()) + if (genericArg) { - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = genericDecl; - for (auto arg : genericArgs) - substitutions->args.add(arg); + auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, genericArg); declRef.substitutions = substitutions; } declRef.decl = genericDecl->inner; } else { - SLANG_ASSERT(genericArgs.getCount() == 0); + SLANG_ASSERT(!genericArg); } return declRef; } Type* ASTBuilder::getAndType(Type* left, Type* right) { - auto type = create<AndType>(); - type->left = left; - type->right = right; + auto type = getOrCreate<AndType>(left, right); return type; } @@ -350,46 +342,26 @@ Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* m return type; } -NodeBase* ASTBuilder::_getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc, void* createFuncUserData) -{ - if(auto found = m_cachedNodes.TryGetValue(desc)) - return *found; - - auto node = createFunc(this, desc, createFuncUserData); - - auto operandCount = desc.operandCount; - NodeBase** operandsCopy = m_arena.allocateAndZeroArray<NodeBase*>(desc.operandCount); - for(Index i = 0; i < operandCount; ++i) - operandsCopy[i] = desc.operands[i]; - - NodeDesc descCopy = desc; - descCopy.operands = operandsCopy; - m_cachedNodes.Add(descCopy, node); - - return node; -} - - Val* ASTBuilder::getUNormModifierVal() { - return _getOrCreate<UNormModifierVal>(); + return getOrCreate<UNormModifierVal>(); } Val* ASTBuilder::getSNormModifierVal() { - return _getOrCreate<SNormModifierVal>(); + return getOrCreate<SNormModifierVal>(); } TypeType* ASTBuilder::getTypeType(Type* type) { - return create<TypeType>(type); + return getOrCreate<TypeType>(type); } bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const { if(type != that.type) return false; - if(operandCount != that.operandCount) return false; - for(Index i = 0; i < operandCount; ++i) + if(operands.getCount() != that.operands.getCount()) return false; + for(Index i = 0; i < operands.getCount(); ++i) { // Note: we are comparing the operands directly for identity // (pointer equality) rather than doing the `Val`-level @@ -399,7 +371,7 @@ bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const // via a `NodeDesc` *should* all be going through the // deduplication path anyway, as should their operands. // - if(operands[i] != that.operands[i]) return false; + if(operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; } return true; } @@ -407,15 +379,14 @@ HashCode ASTBuilder::NodeDesc::getHashCode() const { Hasher hasher; hasher.hashValue(Int(type)); - hasher.hashValue(operandCount); - for(Index i = 0; i < operandCount; ++i) + for(Index i = 0; i < operands.getCount(); ++i) { // Note: we are hashing the raw pointer value rather // than the content of the value node. This is done // to match the semantics implemented for `==` on // `NodeDesc`. // - hasher.hashValue((void*) operands[i]); + hasher.hashValue(operands[i].values.nodeOperand); } return hasher.getResult(); } diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index d363cb8f6..788257fa0 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -98,6 +98,49 @@ class ASTBuilder : public RefObject { friend class SharedASTBuilder; public: + // Node cache: + struct NodeOperand + { + union + { + NodeBase* nodeOperand; + int64_t intOperand; + } values; + NodeOperand() { values.nodeOperand = nullptr; } + NodeOperand(NodeBase* node) { values.nodeOperand = node; } + template<typename EnumType> + NodeOperand(EnumType intVal) + { + static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size."); + values.intOperand = 0; + memcpy(&values, &intVal, sizeof(intVal)); + } + }; + struct NodeDesc + { + ASTNodeType type; + ShortList<NodeOperand, 4> operands; + + bool operator==(NodeDesc const& that) const; + HashCode getHashCode() const; + }; + + template<typename NodeCreateFunc> + NodeBase* _getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc) + { + if (auto found = m_cachedNodes.TryGetValue(desc)) + return *found; + + auto node = createFunc(); + m_cachedNodes.Add(desc, node); + return node; + } + + /// A cache for AST nodes that are entirely defined by their node type, with + /// no need for additional state. + Dictionary<NodeDesc, NodeBase*> m_cachedNodes; + +public: // For compile time check to see if thing being constructed is an AST type template <typename T> @@ -113,10 +156,109 @@ public: /// Create AST types template <typename T> T* create() { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T); } - template<typename T, typename P0> - T* create(const P0& p0) { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T(p0)); } - template<typename T, typename P0, typename P1> - T* create(const P0& p0, const P1& p1) { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T(p0, p1));} + template<typename T, typename... TArgs> + T* create(TArgs... args) { return _initAndAdd(new (m_arena.allocate(sizeof(T))) T(args...)); } + + template<typename T, typename ... TArgs> + SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args) + { + SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); + NodeDesc desc; + desc.type = T::kType; + addToList(desc.operands, args...); + return (T*)_getOrCreateImpl(desc, [&]() + { + return create<T>(args...); + }); + } + + template<typename T> + SLANG_FORCE_INLINE T* getOrCreate() + { + SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); + + NodeDesc desc; + desc.type = T::kType; + return (T*)_getOrCreateImpl(desc, [this]() { return create<T>(); }); + } + + template<typename T, typename ... TArgs> + SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(TArgs ... args) + { + SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); + NodeDesc desc; + desc.type = T::kType; + addToList(desc.operands, args...); + return (T*)_getOrCreateImpl(desc, [&]() + { + return create<T>(); + }); + } + + template<typename T> + SLANG_FORCE_INLINE T* getOrCreateWithDefaultCtor(ConstArrayView<NodeOperand> operands) + { + SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); + NodeDesc desc; + desc.type = T::kType; + desc.operands.addRange(operands); + return (T*)_getOrCreateImpl(desc, [&]() + { + return create<T>(); + }); + } + + DeclRefType* getOrCreateDeclRefType(Decl* decl, Substitutions* outer) + { + NodeDesc desc; + desc.type = DeclRefType::kType; + desc.operands.add(decl); + if (outer) + { + desc.operands.add(outer); + } + auto result = (DeclRefType*)_getOrCreateImpl(desc, [&]() {return create<DeclRefType>(decl, outer); }); + return result; + } + + GenericSubstitution* getOrCreateGenericSubstitution(GenericDecl* decl, const List<Val*>& args, Substitutions* outer) + { + NodeDesc desc; + desc.type = GenericSubstitution::kType; + desc.operands.add(decl); + for (auto arg : args) + desc.operands.add(arg); + if (outer) + { + desc.operands.add(outer); + } + auto result = (GenericSubstitution*)_getOrCreateImpl(desc, [this]() {return create<GenericSubstitution>(); }); + if (result->args.getCount() != args.getCount()) + { + SLANG_RELEASE_ASSERT(result->args.getCount() == 0); + result->args.addRange(args); + result->genericDecl = decl; + result->outer = outer; + } + return result; + } + + ThisTypeSubstitution* getOrCreateThisTypeSubstitution(InterfaceDecl* interfaceDecl, SubtypeWitness* subtypeWitness, Substitutions* outer) + { + NodeDesc desc; + desc.type = ThisTypeSubstitution::kType; + desc.operands.add(interfaceDecl); + desc.operands.add(subtypeWitness); + if (outer) + { + desc.operands.add(outer); + } + auto result = (ThisTypeSubstitution*)_getOrCreateImpl(desc, [this]() {return create<ThisTypeSubstitution>(); }); + result->interfaceDecl = interfaceDecl; + result->witness = subtypeWitness; + result->outer = outer; + return result; + } NodeBase* createByNodeType(ASTNodeType nodeType); @@ -173,7 +315,7 @@ public: DeclRef<InterfaceDecl> getDifferentiableInterface(); - DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs); + DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg); Type* getAndType(Type* left, Type* right); @@ -208,10 +350,13 @@ public: /// Dtor ~ASTBuilder(); + Dictionary<Decl*, GenericSubstitution*> m_genericDefaultSubst; + protected: // Special default Ctor that can only be used by SharedASTBuilder ASTBuilder(); + template <typename T> SLANG_FORCE_INLINE T* _initAndAdd(T* node) { @@ -237,62 +382,6 @@ protected: MemoryArena m_arena; - struct NodeDesc - { - ASTNodeType type; - Count operandCount = 0; - NodeBase* const* operands = nullptr; - - bool operator==(NodeDesc const& that) const; - HashCode getHashCode() const; - }; - - /// A cache for AST nodes that are entirely defined by their node type, with - /// no need for additional state. - Dictionary<NodeDesc, NodeBase*> m_cachedNodes; - - - typedef NodeBase* (*NodeCreateFunc)(ASTBuilder* astBuilder, NodeDesc const& desc, void* userData); - - NodeBase* _getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc, void* createFuncUserData); - - template<typename T> - SLANG_FORCE_INLINE T* _getOrCreate(Count operandCount, NodeBase* const* operands) - { - SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); - - struct Helper - { - static NodeBase* create(ASTBuilder* astBuilder, NodeDesc const& desc, void* /*userData*/) - { - return astBuilder->create<T>(desc.operandCount, desc.operands); - } - }; - - NodeDesc desc; - desc.type = T::kType; - desc.operandCount = operandCount; - desc.operands = operands; - return (T*) _getOrCreateImpl(desc, &Helper::create, nullptr); - } - - template<typename T> - SLANG_FORCE_INLINE T* _getOrCreate() - { - SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); - - struct Helper - { - static NodeBase* create(ASTBuilder* astBuilder, NodeDesc const& /*desc*/, void* /*userData*/) - { - return astBuilder->create<T>(); - } - }; - - NodeDesc desc; - desc.type = T::kType; - return (T*) _getOrCreateImpl(desc, &Helper::create, nullptr); - } }; } // namespace Slang diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index e81191086..830c6bf34 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -190,7 +190,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) sb << "<"; bool first = true; - for (auto arg : genSubst->args) + for (auto arg : genSubst->getArgs()) { // When printing the representation of a specialized // generic declaration we don't want to include the diff --git a/source/slang/slang-ast-substitutions.cpp b/source/slang/slang-ast-substitutions.cpp index 6656f1fa6..7c6cf1bc6 100644 --- a/source/slang/slang-ast-substitutions.cpp +++ b/source/slang/slang-ast-substitutions.cpp @@ -63,10 +63,8 @@ Substitutions* GenericSubstitution::_applySubstitutionsShallowOverride(ASTBuilde if (!diff) return this; (*ioDiff)++; - auto substSubst = astBuilder->create<GenericSubstitution>(); - substSubst->genericDecl = genericDecl; - substSubst->args = substArgs; - substSubst->outer = substOuter; + + auto substSubst = astBuilder->getOrCreateGenericSubstitution(genericDecl, substArgs, substOuter); return substSubst; } @@ -126,10 +124,9 @@ Substitutions* ThisTypeSubstitution::_applySubstitutionsShallowOverride(ASTBuild if (!diff) return this; (*ioDiff)++; - auto substSubst = astBuilder->create<ThisTypeSubstitution>(); - substSubst->interfaceDecl = interfaceDecl; - substSubst->witness = substWitness; - substSubst->outer = substOuter; + ThisTypeSubstitution* substSubst; + + substSubst = astBuilder->getOrCreateThisTypeSubstitution(interfaceDecl, substWitness, substOuter); return substSubst; } diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index a84f04a32..077d6de0a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -224,7 +224,7 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe { // We've found it, so return the corresponding specialization argument (*ioDiff)++; - return genericSubst->args[index]; + return genericSubst->getArgs()[index]; } else if (auto typeParam = as<GenericTypeParamDecl>(m)) { @@ -351,17 +351,17 @@ BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride() Type* MatrixExpressionType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } IntVal* MatrixExpressionType::getRowCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); } IntVal* MatrixExpressionType::getColumnCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[2]); } Type* MatrixExpressionType::getRowType() @@ -518,12 +518,12 @@ Type* NamespaceType::_createCanonicalTypeOverride() Type* PtrTypeBase::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } Type* OptionalType::getValueType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 458c24a23..5bb91e5da 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -78,9 +78,13 @@ class DeclRefType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); protected: - DeclRefType( DeclRef<Decl> declRef) + DeclRefType(DeclRef<Decl> declRef) : declRef(declRef) {} + + DeclRefType(Decl* decl, Substitutions* substitutions) + : declRef(decl, substitutions) + {} }; // Base class for types that can be used in arithmetic expressions @@ -210,6 +214,11 @@ class SamplerStateType : public BuiltinType // What flavor of sampler state is this SamplerStateFlavor flavor; + + SamplerStateType(SamplerStateFlavor inFlavor) + { + flavor = inFlavor; + } }; // Other cases of generic types known to the compiler @@ -467,6 +476,10 @@ class VectorExpressionType : public ArithmeticExpressionType // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); BasicExpressionType* _getScalarTypeOverride(); + + VectorExpressionType(Type* inElementType, IntVal* inElementCount) + : elementType(inElementType), elementCount(inElementCount) + {} }; // A matrix type, e.g., `matrix<T,R,C>` @@ -486,6 +499,8 @@ class MatrixExpressionType : public ArithmeticExpressionType private: Type* rowType = nullptr; + + MatrixExpressionType(Type*, IntVal*, IntVal*) {} }; // Base class for built in string types @@ -794,6 +809,10 @@ class AndType : public Type Type* left; Type* right; + AndType(Type* leftType, Type* rightType) + : left(leftType), right(rightType) + {} + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); bool _equalsImplOverride(Type* type); diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 64d32f4b4..a70a79535 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -137,7 +137,7 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S { // We've found it, so return the corresponding specialization argument (*ioDiff)++; - return genSubst->args[index]; + return genSubst->getArgs()[index]; } else if (auto typeParam = as<GenericTypeParamDecl>(m)) { @@ -265,8 +265,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub (*ioDiff)++; auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.getCount()); - return genericSubst->args[index + ordinaryParamCount]; + SLANG_ASSERT(index + ordinaryParamCount < genericSubst->getArgs().getCount()); + return genericSubst->getArgs()[index + ordinaryParamCount]; } } } @@ -323,7 +323,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } } - DeclaredSubtypeWitness* rs = astBuilder->create<DeclaredSubtypeWitness>(); + DeclaredSubtypeWitness* rs = astBuilder->getOrCreate<DeclaredSubtypeWitness>( + substSub, substSup, substDeclRef.getDecl(), substDeclRef.substitutions.substitutions); rs->sub = substSub; rs->sup = substSup; rs->declRef = substDeclRef; @@ -722,7 +723,7 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut *ioDiff += diff; if (evaluatedTerms.getCount() == 0) - return astBuilder->create<ConstantIntVal>(type, evaluatedConstantTerm); + return astBuilder->getOrCreate<ConstantIntVal>(type, evaluatedConstantTerm); if (diff != 0) { auto newPolynomial = astBuilder->create<PolynomialIntVal>(type); @@ -1035,7 +1036,7 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) return terms[0]->paramFactors[0]->param; } if (terms.getCount() == 0) - return builder->create<ConstantIntVal>(type, constantTerm); + return builder->getOrCreate<ConstantIntVal>(type, constantTerm); return this; } @@ -1207,7 +1208,7 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR { SLANG_UNREACHABLE("constant folding of FuncCallIntVal"); } - return astBuilder->create<ConstantIntVal>(resultType, resultValue); + return astBuilder->getOrCreate<ConstantIntVal>(resultType, resultValue); } return nullptr; } diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 69797b3a5..0f7bdec8e 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -53,6 +53,10 @@ class GenericParamIntVal : public IntVal HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + GenericParamIntVal(Type* inType, VarDeclBase* inDecl, Substitutions* inSubst) + : IntVal(inType), declRef(inDecl, inSubst) + {} + protected: GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) : IntVal(inType), declRef(inDeclRef) @@ -315,6 +319,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness void _toTextOverride(StringBuilder& out); HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + DeclaredSubtypeWitness(Type* inSub, Type* inSup, Decl* decl, Substitutions* subst) + : declRef(decl, subst) + { + sub = inSub; + sup = inSup; + } }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` @@ -333,6 +344,10 @@ class TransitiveSubtypeWitness : public SubtypeWitness void _toTextOverride(StringBuilder& out); HashCode _getHashCodeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + TransitiveSubtypeWitness(SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup) + : subToMid(inSubToMid), midToSup(inMidToSup) + {} }; // A witness taht `sub : sup` because `sub` was wrapped into diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 5889c6140..e0c1f3702 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -10,10 +10,11 @@ namespace Slang DeclaredSubtypeWitness* SemanticsVisitor::createSimpleSubtypeWitness( TypeWitnessBreadcrumb* breadcrumb) { - DeclaredSubtypeWitness* witness = m_astBuilder->create<DeclaredSubtypeWitness>(); - witness->sub = breadcrumb->sub; - witness->sup = breadcrumb->sup; - witness->declRef = breadcrumb->declRef; + DeclaredSubtypeWitness* witness = m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( + breadcrumb->sub, + breadcrumb->sup, + breadcrumb->declRef.decl, + breadcrumb->declRef.substitutions.substitutions); return witness; } @@ -82,12 +83,11 @@ namespace Slang // where `[...]` represents the "hole" we leave // open to fill in next. // - DeclaredSubtypeWitness* declaredWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); - declaredWitness->sub = bb->sub; - declaredWitness->sup = bb->sup; - declaredWitness->declRef = bb->declRef; + DeclaredSubtypeWitness* declaredWitness = + m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( + bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); - TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->create<TransitiveSubtypeWitness>(); + TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness); transitiveWitness->sub = subType; transitiveWitness->sup = bb->sup; transitiveWitness->midToSup = declaredWitness; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index a03043f31..129d3ed0c 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -391,11 +391,8 @@ namespace Slang // search for a conformance `Robin : ISidekick`, which involved // apply the substitutions we already know... - GenericSubstitution* solvedSubst = m_astBuilder->create<GenericSubstitution>(); - solvedSubst->genericDecl = genericDeclRef.getDecl(); - solvedSubst->outer = genericDeclRef.substitutions.substitutions; - solvedSubst->args = args; - resultSubst.substitutions = solvedSubst; + GenericSubstitution* solvedSubst = m_astBuilder->getOrCreateGenericSubstitution( + genericDeclRef.getDecl(), args, genericDeclRef.substitutions.substitutions); for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { @@ -412,7 +409,7 @@ namespace Slang if(subTypeWitness) { // We found a witness, so it will become an (implicit) argument. - solvedSubst->args.add(subTypeWitness); + args.add(subTypeWitness); } else { @@ -437,6 +434,8 @@ namespace Slang } } + resultSubst = m_astBuilder->getOrCreateGenericSubstitution( + genericDeclRef.getDecl(), args, genericDeclRef.substitutions.substitutions); return resultSubst; } @@ -546,12 +545,12 @@ namespace Slang return false; // Their arguments must unify - SLANG_RELEASE_ASSERT(fstGen->args.getCount() == sndGen->args.getCount()); - Index argCount = fstGen->args.getCount(); + SLANG_RELEASE_ASSERT(fstGen->getArgs().getCount() == sndGen->getArgs().getCount()); + Index argCount = fstGen->getArgs().getCount(); bool okay = true; for (Index aa = 0; aa < argCount; ++aa) { - if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa])) + if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], sndGen->getArgs()[aa])) { okay = false; } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 4efacd703..2f5447ffb 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -349,7 +349,7 @@ namespace Slang // We have a new type for the conversion, based on what // we learned. toType = m_astBuilder->getArrayType(toElementType, - m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)); + m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)); } } else if(auto toMatrixType = as<MatrixExpressionType>(toType)) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 9f19023ee..125f0cb08 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -519,42 +519,73 @@ namespace Slang return semantics->ApplyExtensionToType(extDecl, type); } + void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state) + { + visitor->ensureDecl(decl, state); + } + GenericSubstitution* createDefaultSubstitutionsForGeneric( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, GenericDecl* genericDecl, Substitutions* outerSubst) { - GenericSubstitution* genericSubst = astBuilder->create<GenericSubstitution>(); - genericSubst->genericDecl = genericDecl; - genericSubst->outer = outerSubst; + GenericSubstitution* cachedResult = nullptr; + if (astBuilder->m_genericDefaultSubst.TryGetValue(genericDecl, cachedResult)) + { + if (cachedResult->outer == outerSubst) + return cachedResult; + } + + List<Val*> args; for( auto mm : genericDecl->members ) { if( auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm) ) { - genericSubst->args.add(DeclRefType::create(astBuilder, DeclRef<Decl>(genericTypeParamDecl, outerSubst))); + args.add(DeclRefType::create(astBuilder, DeclRef<Decl>(genericTypeParamDecl, outerSubst))); } else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) ) { - genericSubst->args.add(astBuilder->create<GenericParamIntVal>( + args.add(astBuilder->getOrCreate<GenericParamIntVal>( genericValueParamDecl->getType(), - DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst))); + genericValueParamDecl, outerSubst)); } } + bool shouldCache = true; + // create default substitution arguments for constraints for (auto mm : genericDecl->members) { if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm)) { - DeclaredSubtypeWitness* witness = astBuilder->create<DeclaredSubtypeWitness>(); - witness->declRef = DeclRef<Decl>(genericTypeConstraintDecl, outerSubst); - witness->sub = genericTypeConstraintDecl->sub.type; - witness->sup = genericTypeConstraintDecl->sup.type; - genericSubst->args.add(witness); + if (semantics) + { + ensureDecl(semantics, genericTypeConstraintDecl, DeclCheckState::ReadyForReference); + } + auto constraintDeclRef = DeclRef<GenericTypeConstraintDecl>(genericTypeConstraintDecl, outerSubst); + DeclaredSubtypeWitness* witness = + astBuilder->getOrCreate<DeclaredSubtypeWitness>( + getSub(astBuilder, constraintDeclRef), + getSup(astBuilder, constraintDeclRef), + genericTypeConstraintDecl, + outerSubst); + // TODO: this is an ugly hack to prevent crashing. + // In early stages of compilation witness->sub and witness->sup may not be checked yet. + // When semanticVisitor is present we have used that to ensure the type is checked. + // However due to how the code is written we cannot guarantee semanticVisitor is always available + // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be + // cached. + if (!witness->sub) + shouldCache = false; + args.add(witness); } } + GenericSubstitution* genericSubst = astBuilder->getOrCreateGenericSubstitution(genericDecl, args, outerSubst); + if (shouldCache) + astBuilder->m_genericDefaultSubst[genericDecl] = genericSubst; return genericSubst; } @@ -563,7 +594,8 @@ namespace Slang // using their archetypes). // SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, Decl* decl, SubstitutionSet outerSubstSet) { @@ -577,6 +609,7 @@ namespace Slang GenericSubstitution* genericSubst = createDefaultSubstitutionsForGeneric( astBuilder, + semantics, genericDecl, outerSubstSet.substitutions); @@ -587,23 +620,19 @@ namespace Slang } SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, Decl* decl) { SubstitutionSet subst; if( auto parentDecl = decl->parentDecl ) { - subst = createDefaultSubstitutions(astBuilder, parentDecl); + subst = createDefaultSubstitutions(astBuilder, semantics, parentDecl); } - subst = createDefaultSubstitutions(astBuilder, decl, subst); + subst = createDefaultSubstitutions(astBuilder, semantics, decl, subst); return subst; } - void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state) - { - visitor->ensureDecl(decl, state); - } - bool SemanticsVisitor::isDeclUsableAsStaticMember( Decl* decl) { @@ -1195,7 +1224,7 @@ namespace Slang { if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base)) { - declRefType->declRef.substitutions = createDefaultSubstitutions(m_astBuilder, declRefType->declRef.getDecl()); + declRefType->declRef.substitutions = createDefaultSubstitutions(m_astBuilder, this, declRefType->declRef.getDecl()); if (auto typetype = as<TypeType>(typeExp.exp->type)) typetype->type = declRefType; @@ -1754,9 +1783,7 @@ namespace Slang // compare `Derived::doThing` against `IBase::doThing<U>` where the `U` there is // the parameter of `Dervived::doThing`. // - GenericSubstitution* requiredSubst = m_astBuilder->create<GenericSubstitution>(); - requiredSubst->genericDecl = requiredGenericDeclRef.getDecl(); - requiredSubst->outer = requiredGenericDeclRef.substitutions; + List<Val*> requiredSubstArgs; for (Index i = 0; i < memberCount; i++) { @@ -1769,17 +1796,20 @@ namespace Slang SLANG_ASSERT(satisfyingTypeParamDeclRef); auto satisfyingType = DeclRefType::create(m_astBuilder, satisfyingTypeParamDeclRef); - requiredSubst->args.add(satisfyingType); + requiredSubstArgs.add(satisfyingType); } else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) { auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>(); SLANG_ASSERT(satisfyingValueParamDeclRef); - auto satisfyingVal = m_astBuilder->create<GenericParamIntVal>(); + auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>( + requiredValueParamDeclRef.getDecl()->getType(), + satisfyingValueParamDeclRef.getDecl(), + satisfyingValueParamDeclRef.substitutions.substitutions); satisfyingVal->declRef = satisfyingValueParamDeclRef; - requiredSubst->args.add(satisfyingVal); + requiredSubstArgs.add(satisfyingVal); } } for (Index i = 0; i < memberCount; i++) @@ -1792,15 +1822,20 @@ namespace Slang auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>(); SLANG_ASSERT(satisfyingConstraintDeclRef); - auto satisfyingWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + auto satisfyingWitness = m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(); satisfyingWitness->sub = getSub(m_astBuilder, satisfyingConstraintDeclRef); satisfyingWitness->sup = getSup(m_astBuilder, satisfyingConstraintDeclRef); satisfyingWitness->declRef = satisfyingConstraintDeclRef; - requiredSubst->args.add(satisfyingWitness); + requiredSubstArgs.add(satisfyingWitness); } } + GenericSubstitution* requiredSubst = m_astBuilder->getOrCreateGenericSubstitution( + requiredGenericDeclRef.getDecl(), + requiredSubstArgs, + requiredGenericDeclRef.substitutions); + // Now that we have computed a set of specialization arguments that will // specialize the generic requirement at the type parameters of the satisfying // generic, we can construct a reference to that declaration and re-run some @@ -2764,13 +2799,15 @@ namespace Slang auto reqType = getBaseType(m_astBuilder, requiredInheritanceDeclRef); - DeclaredSubtypeWitness* interfaceIsReqWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); - interfaceIsReqWitness->sub = superInterfaceType; - interfaceIsReqWitness->sup = reqType; - interfaceIsReqWitness->declRef = requiredInheritanceDeclRef; + DeclaredSubtypeWitness* interfaceIsReqWitness = + m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( + superInterfaceType, + reqType, + requiredInheritanceDeclRef.getDecl(), + requiredInheritanceDeclRef.substitutions.substitutions); // ... - TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->create<TransitiveSubtypeWitness>(); + TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, reqType, interfaceIsReqWitness); subIsReqWitness->sub = subType; subIsReqWitness->sup = reqType; subIsReqWitness->subToMid = subTypeConformsToSuperInterfaceWitness; @@ -3232,7 +3269,7 @@ namespace Slang void SemanticsVisitor::checkExtensionConformance(ExtensionDecl* decl) { - auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, makeDeclRef(decl)).as<ExtensionDecl>(); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as<ExtensionDecl>(); auto targetType = getTargetType(m_astBuilder, declRef); for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) @@ -3265,7 +3302,7 @@ namespace Slang auto astBuilder = getASTBuilder(); - auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, makeDeclRef(decl)).as<AggTypeDeclBase>(); + auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, this, makeDeclRef(decl)).as<AggTypeDeclBase>(); auto type = DeclRefType::create(astBuilder, declRef); // TODO: Need to figure out what this should do for @@ -4222,8 +4259,7 @@ namespace Slang GenericSubstitution* SemanticsVisitor::createDummySubstitutions( GenericDecl* genericDecl) { - GenericSubstitution* subst = m_astBuilder->create<GenericSubstitution>(); - subst->genericDecl = genericDecl; + List<Val*> args; for (auto dd : genericDecl->members) { if (dd == genericDecl->inner) @@ -4232,17 +4268,19 @@ namespace Slang if (auto typeParam = as<GenericTypeParamDecl>(dd)) { auto type = DeclRefType::create(m_astBuilder, makeDeclRef(typeParam)); - subst->args.add(type); + args.add(type); } else if (auto valueParam = as<GenericValueParamDecl>(dd)) { - auto val = m_astBuilder->create<GenericParamIntVal>( + auto val = m_astBuilder->getOrCreate<GenericParamIntVal>( valueParam->getType(), - makeDeclRef(valueParam)); - subst->args.add(val); + valueParam, + nullptr); + args.add(val); } // TODO: need to handle constraints here? } + GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(genericDecl, args, nullptr); return subst; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 11b001c2c..4d55669e2 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -865,8 +865,7 @@ namespace Slang IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr) { - // TODO(tfoley): don't keep allocating here! - return m_astBuilder->create<ConstantIntVal>(expr->type.type, expr->value); + return m_astBuilder->getOrCreate<ConstantIntVal>(expr->type.type, expr->value); } IntVal* SemanticsVisitor::tryConstantFoldExpr( @@ -1091,7 +1090,7 @@ namespace Slang } } - IntVal* result = m_astBuilder->create<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue); + IntVal* result = m_astBuilder->getOrCreate<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue); return result; } @@ -1166,7 +1165,6 @@ namespace Slang expr = getBaseExpr(parenExpr); } - // TODO(tfoley): more serious constant folding here if (auto intLitExpr = expr.as<IntegerLiteralExpr>()) { return getIntVal(intLitExpr); @@ -1176,7 +1174,7 @@ namespace Slang { // If it's a boolean, we allow promotion to int. const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value); - return m_astBuilder->create<ConstantIntVal>(m_astBuilder->getBoolType(), value); + return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value); } // it is possible that we are referring to a generic value param @@ -1186,9 +1184,10 @@ namespace Slang if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { - Val* valResult = m_astBuilder->create<GenericParamIntVal>( + Val* valResult = m_astBuilder->getOrCreate<GenericParamIntVal>( declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()), - genericValParamRef); + genericValParamRef.getDecl(), + genericValParamRef.substitutions.substitutions); valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); return as<IntVal>(valResult); } @@ -2158,7 +2157,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there @@ -2279,7 +2278,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 091990cf3..ef500d4fa 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -202,10 +202,32 @@ namespace Slang Substitutions* subst = nullptr; }; + struct LookupRequestKey + { + NodeBase* base; + Name* name; + LookupOptions options; + LookupMask mask; + bool operator==(const LookupRequestKey& other) const + { + return base == other.base && name == other.name && options == other.options && mask == other.mask; + } + HashCode getHashCode() const + { + Hasher hasher; + hasher.hashValue(base); + hasher.hashValue(name); + hasher.hashValue(options); + hasher.hashValue(mask); + return hasher.getResult(); + } + }; + struct TypeCheckingCache { Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache; Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache; + Dictionary<LookupRequestKey, LookupResult> lookupCache; }; /// Shared state for a semantics-checking session. @@ -1467,7 +1489,7 @@ namespace Slang // bool TryCheckOverloadCandidateConstraints( OverloadResolveContext& context, - OverloadCandidate const& candidate); + OverloadCandidate& candidate); // Try to check an overload candidate, but bail out // if any step fails diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 9109130e2..55263453d 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -210,7 +210,7 @@ namespace Slang // auto genSubst = m_astBuilder->create<GenericSubstitution>(); candidate.subst = genSubst; - auto& checkedArgs = genSubst->args; + auto& checkedArgs = (List<Val*>&)genSubst->getArgs(); // Rather than bail out as soon as we hit a problem, // we are going to process *all* of the parameters of the @@ -474,7 +474,7 @@ namespace Slang bool SemanticsVisitor::TryCheckOverloadCandidateConstraints( OverloadResolveContext& context, - OverloadCandidate const& candidate) + OverloadCandidate& candidate) { // We only need this step for generics, so always succeed on // everything else. @@ -493,6 +493,8 @@ namespace Slang subst->genericDecl = genericDeclRef.getDecl(); subst->outer = genericDeclRef.substitutions.substitutions; + List<Val*> newArgs = subst->getArgs(); + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { auto subset = genericDeclRef.substitutions; @@ -506,7 +508,7 @@ namespace Slang auto subTypeWitness = tryGetSubtypeWitness(sub, sup); if(subTypeWitness) { - subst->args.add(subTypeWitness); + newArgs.add(subTypeWitness); } else { @@ -518,6 +520,8 @@ namespace Slang } } + candidate.subst = m_astBuilder->getOrCreateGenericSubstitution(genericDeclRef.getDecl(), newArgs, genericDeclRef.substitutions.substitutions); + // Done checking all the constraints, hooray. return true; } diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 2e577b6d5..d7200d47c 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -1120,7 +1120,7 @@ namespace Slang if(!intVal) { sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); - intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(m_astBuilder->getIntType(), 0); + intVal = getLinkage()->getASTBuilder()->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), 0); } ModuleSpecializationInfo::GenericArgInfo expandedArg; @@ -1192,15 +1192,18 @@ namespace Slang auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - GenericSubstitution* genericSubst = getLinkage()->getASTBuilder()->create<GenericSubstitution>(); - genericSubst->outer = genericDeclRef.substitutions.substitutions; - genericSubst->genericDecl = genericDeclRef.getDecl(); + List<Val*> genericArgs; for(Index ii = 0; ii < genericSpecializationParamCount; ++ii) { auto specializationArg = args[ii]; - genericSubst->args.add(specializationArg.val); + genericArgs.add(specializationArg.val); } + GenericSubstitution* genericSubst = + getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution( + genericDeclRef.getDecl(), + genericArgs, + genericDeclRef.substitutions.substitutions); for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { @@ -1218,7 +1221,7 @@ namespace Slang auto subTypeWitness = visitor.tryGetSubtypeWitness(sub, sup); if(subTypeWitness) { - genericSubst->args.add(subTypeWitness); + genericArgs.add(subTypeWitness); } else { @@ -1228,6 +1231,11 @@ namespace Slang } } + genericSubst = + getLinkage()->getASTBuilder()->getOrCreateGenericSubstitution( + genericDeclRef.getDecl(), + genericArgs, + genericDeclRef.substitutions.substitutions); specializedFuncDeclRef.substitutions.substitutions = genericSubst; } diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 2bf6cb830..a25a8683d 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -163,7 +163,7 @@ namespace Slang } else { - ConstantIntVal* rangeBeginConst = m_astBuilder->create<ConstantIntVal>(); + ConstantIntVal* rangeBeginConst = m_astBuilder->getOrCreate<ConstantIntVal>(); rangeBeginConst->value = 0; rangeBeginVal = rangeBeginConst; } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index c59e1b308..d402dde03 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -171,15 +171,16 @@ namespace Slang DeclRef<GenericDecl> genericDeclRef, List<Expr*> const& args) { - GenericSubstitution* subst = m_astBuilder->create<GenericSubstitution>(); - subst->genericDecl = genericDeclRef.getDecl(); - subst->outer = genericDeclRef.substitutions.substitutions; + List<Val*> evaledArgs; for (auto argExpr : args) { - subst->args.add(ExtractGenericArgVal(argExpr)); + evaledArgs.add(ExtractGenericArgVal(argExpr)); } + GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution( + genericDeclRef.getDecl(), evaledArgs, genericDeclRef.substitutions.substitutions); + DeclRef<Decl> innerDeclRef; innerDeclRef.decl = getInner(genericDeclRef); innerDeclRef.substitutions = SubstitutionSet(subst); @@ -403,18 +404,7 @@ namespace Slang Type* elementType, IntVal* elementCount) { - auto vectorGenericDecl = as<GenericDecl>(m_astBuilder->getSharedASTBuilder()->findMagicDecl("Vector")); - - auto vectorTypeDecl = vectorGenericDecl->inner; - - auto substitutions = m_astBuilder->create<GenericSubstitution>(); - substitutions->genericDecl = vectorGenericDecl; - substitutions->args.add(elementType); - substitutions->args.add(elementCount); - - auto declRef = DeclRef<Decl>(vectorTypeDecl, substitutions); - - return as<VectorExpressionType>(DeclRefType::create(m_astBuilder, declRef)); + return m_astBuilder->getVectorType(elementType, elementCount); } Expr* SemanticsExprVisitor::visitSharedTypeExpr(SharedTypeExpr* expr) diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 48f7ea099..a8818fd5c 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -843,7 +843,7 @@ static void _lookUpInScopes( // have a null `containerDecl` and needs to be // skipped over. // - if(!containerDecl) + if (!containerDecl) continue; // TODO: If we need default substitutions to be applied to @@ -852,8 +852,8 @@ static void _lookUpInScopes( // just a decl. // DeclRef<ContainerDecl> containerDeclRef = - DeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, containerDecl)).as<ContainerDecl>(); - + DeclRef<Decl>(containerDecl, createDefaultSubstitutions(astBuilder, request.semantics, containerDecl)).as<ContainerDecl>(); + // If the container we are looking into represents a type // or an `extension` of a type, then we need to treat // this step as lookup into the `this` variable (or the @@ -878,9 +878,9 @@ static void _lookUpInScopes( breadcrumb.prev = nullptr; Type* type = nullptr; - if(auto extDeclRef = aggTypeDeclBaseRef.as<ExtensionDecl>()) + if (auto extDeclRef = aggTypeDeclBaseRef.as<ExtensionDecl>()) { - if( request.semantics ) + if (request.semantics) { ensureDecl(request.semantics, extDeclRef.getDecl(), DeclCheckState::CanUseExtensionTargetType); } @@ -918,14 +918,14 @@ static void _lookUpInScopes( // of some nested type, then there shouldn't be an implicit `this` // expression for the outer type, but instead an implicit `This`. // - if( containerDeclRef.is<ConstructorDecl>() ) + if (containerDeclRef.is<ConstructorDecl>()) { // In the context of an `__init` declaration, the members of // the surrounding type are accessible through a mutable `this`. // thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::MutableValue; } - else if( containerDeclRef.is<SetterDecl>() ) + else if (containerDeclRef.is<SetterDecl>()) { // In the context of a `set` accessor, the members of the // surrounding type are accessible through a mutable `this`. @@ -937,19 +937,19 @@ static void _lookUpInScopes( // thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::MutableValue; } - else if( auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>() ) + else if (auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>()) { // The implicit `this`/`This` for a function-like declaration // depends on modifiers attached to the declaration. // - if( funcDeclRef.getDecl()->hasModifier<HLSLStaticModifier>() ) + if (funcDeclRef.getDecl()->hasModifier<HLSLStaticModifier>()) { // A `static` method only has access to an implicit `This`, // and does not have a `this` expression available. // thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Type; } - else if( funcDeclRef.getDecl()->hasModifier<MutatingAttribute>() ) + else if (funcDeclRef.getDecl()->hasModifier<MutatingAttribute>()) { // In a non-`static` method marked `[mutating]` there is // an implicit `this` parameter that is mutable. @@ -964,7 +964,7 @@ static void _lookUpInScopes( thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::ImmutableValue; } } - else if( containerDeclRef.as<AggTypeDeclBase>() ) + else if (containerDeclRef.as<AggTypeDeclBase>()) { // When lookup moves from a nested typed declaration to an // outer scope, there is no ability to use an implicit `this` @@ -972,7 +972,6 @@ static void _lookUpInScopes( // thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Type; } - // TODO: What other cases need to be enumerated here? } if (result.isValid()) @@ -1002,9 +1001,27 @@ LookupResult lookUp( Scope* scope, LookupMask mask) { - LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, scope); LookupResult result; + LookupRequestKey key; + TypeCheckingCache* typeCheckingCache = nullptr; + if (semantics) + { + typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache(); + key.base = scope; + key.name = name; + key.options = LookupOptions::None; + key.mask = mask; + if (typeCheckingCache->lookupCache.TryGetValue(key, result)) + { + return result; + } + } + LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, scope); _lookUpInScopes(astBuilder, name, request, result); + if (typeCheckingCache) + { + typeCheckingCache->lookupCache[key] = result; + } return result; } @@ -1016,9 +1033,20 @@ LookupResult lookUpMember( LookupMask mask, LookupOptions options) { - LookupRequest request = initLookupRequest(semantics, name, mask, options, nullptr); + TypeCheckingCache* typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache(); + LookupRequestKey key; + key.base = type; + key.name = name; + key.options = options; + key.mask = mask; LookupResult result; + if (typeCheckingCache->lookupCache.TryGetValue(key, result)) + { + return result; + } + LookupRequest request = initLookupRequest(semantics, name, mask, options, nullptr); _lookUpMembersInType(astBuilder, name, type, request, result, nullptr); + typeCheckingCache->lookupCache[key] = result; return result; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 99af2abd6..221abe2b8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1813,7 +1813,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower _collectSubstitutionArgs(operands, subst->outer); if (auto genSubst = as<GenericSubstitution>(subst)) { - for (auto arg : genSubst->args) + for (auto arg : genSubst->getArgs()) { operands.add(lowerVal(context, arg).val); } @@ -2561,19 +2561,19 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de return kParameterDirection_In; } -DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, Decl* decl) +DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, SemanticsVisitor* semantics, Decl* decl) { DeclRef<Decl> declRef; declRef.decl = decl; - declRef.substitutions = createDefaultSubstitutions(context->astBuilder, decl); + declRef.substitutions = createDefaultSubstitutions(context->astBuilder, semantics, decl); return declRef; } // // The client should actually call the templated wrapper, to preserve type information. template<typename D> -DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl) +DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, SemanticsVisitor* semantics, D* decl) { - DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, decl); + DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, semantics, decl); return declRef.as<D>(); } @@ -3502,8 +3502,8 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> void _lowerSubstitutionArg(IRGenContext* subContext, GenericSubstitution* subst, Decl* paramDecl, Index argIndex) { - SLANG_ASSERT(argIndex < subst->args.getCount()); - auto argVal = lowerVal(subContext, subst->args[argIndex]); + SLANG_ASSERT(argIndex < subst->getArgs().getCount()); + auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]); setValue(subContext, paramDecl, argVal); } @@ -7669,7 +7669,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( funcTypeContext, - createDefaultSpecializedDeclRef(funcTypeContext, decl), + createDefaultSpecializedDeclRef(funcTypeContext, nullptr, decl), info); auto irFuncType = info.type; @@ -7710,7 +7710,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, - createDefaultSpecializedDeclRef(context, decl), + createDefaultSpecializedDeclRef(context, nullptr, decl), info); auto irFuncType = info.type; @@ -8368,7 +8368,7 @@ LoweredValInfo emitDeclRef( // We have the IR value for the generic we'd like to specialize, // and now we need to get the value for the arguments. List<IRInst*> irArgs; - for (auto argVal : genericSubst->args) + for (auto argVal : genericSubst->getArgs()) { auto irArgVal = lowerSimpleVal(context, argVal); SLANG_ASSERT(irArgVal); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index c230776db..c15402fbb 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -374,9 +374,9 @@ namespace Slang { // This is the case where we *do* have substitutions. emitRaw(context, "G"); - UInt genericArgCount = subst->args.getCount(); + UInt genericArgCount = subst->getArgs().getCount(); emit(context, genericArgCount); - for( auto aa : subst->args ) + for (auto aa : subst->getArgs()) { emitVal(context, aa); } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 957d4e661..9a7200e9d 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -355,6 +355,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt DeclRef<Decl> createDefaultSubstitutionsIfNeeded( ASTBuilder* astBuilder, + SemanticsVisitor* semantics, DeclRef<Decl> declRef) { // It is possible that `declRef` refers to a generic type, @@ -413,7 +414,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if(!foundSubst) { Substitutions* newSubst = createDefaultSubstitutionsForGeneric( - astBuilder, + astBuilder, + semantics, genericParentDecl, nullptr); @@ -436,7 +438,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ASTBuilder* astBuilder, DeclRef<Decl> declRef) { - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, declRef); + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); if (auto builtinMod = declRef.getDecl()->findModifier<BuiltinTypeModifier>()) { @@ -458,58 +460,60 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if (magicMod->magicName == "SamplerState") { - auto type = astBuilder->create<SamplerStateType>(); + auto type = astBuilder->getOrCreate<SamplerStateType>(SamplerStateFlavor(magicMod->tag)); type->declRef = declRef; - type->flavor = SamplerStateFlavor(magicMod->tag); return type; } else if (magicMod->magicName == "Vector") { - SLANG_ASSERT(subst && subst->args.getCount() == 2); - auto vecType = astBuilder->create<VectorExpressionType>(); + SLANG_ASSERT(subst && subst->getArgs().getCount() == 2); + auto vecType = astBuilder->getOrCreate<VectorExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1])); vecType->declRef = declRef; - vecType->elementType = ExtractGenericArgType(subst->args[0]); - vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); + vecType->elementType = ExtractGenericArgType(subst->getArgs()[0]); + vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]); return vecType; } else if (magicMod->magicName == "Matrix") { - SLANG_ASSERT(subst && subst->args.getCount() == 3); - auto matType = astBuilder->create<MatrixExpressionType>(); + SLANG_ASSERT(subst && subst->getArgs().getCount() == 3); + auto matType = astBuilder->getOrCreate<MatrixExpressionType>( + ExtractGenericArgType(subst->getArgs()[0]), + ExtractGenericArgInteger(subst->getArgs()[1]), + ExtractGenericArgInteger(subst->getArgs()[2])); matType->declRef = declRef; return matType; } else if (magicMod->magicName == "Texture") { - SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = astBuilder->create<TextureType>( + SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); + auto textureType = astBuilder->getOrCreate<TextureType>( TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); + ExtractGenericArgType(subst->getArgs()[0])); textureType->declRef = declRef; return textureType; } else if (magicMod->magicName == "TextureSampler") { - SLANG_ASSERT(subst && subst->args.getCount() >= 1); + SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); auto textureType = astBuilder->create<TextureSamplerType>( TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); + ExtractGenericArgType(subst->getArgs()[0])); textureType->declRef = declRef; return textureType; } else if (magicMod->magicName == "GLSLImageType") { - SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = astBuilder->create<GLSLImageType>( + SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); + auto textureType = astBuilder->getOrCreate<GLSLImageType>( TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); + ExtractGenericArgType(subst->getArgs()[0])); textureType->declRef = declRef; return textureType; } else if (magicMod->magicName == "FeedbackType") { SLANG_ASSERT(subst == nullptr); - auto type = astBuilder->create<FeedbackType>(); + auto type = astBuilder->getOrCreateWithDefaultCtor<FeedbackType>(magicMod->tag); type->declRef = declRef; type->kind = FeedbackType::Kind(magicMod->tag); return type; @@ -519,11 +523,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // and we can drive the dispatch with a table instead // of this ridiculously slow `if` cascade. - #define CASE(n,T) \ - else if(magicMod->magicName == #n) { \ - auto type = astBuilder->create<T>(); \ - type->declRef = declRef; \ - return type; \ + #define CASE(n, T) \ + else if (magicMod->magicName == #n) \ + { \ + auto type = astBuilder->getOrCreateWithDefaultCtor<T>( \ + declRef.decl, declRef.substitutions.substitutions); \ + type->declRef = declRef; \ + return type; \ } CASE(HLSLInputPatchType, HLSLInputPatchType) @@ -531,14 +537,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt #undef CASE - #define CASE(n,T) \ - else if(magicMod->magicName == #n) { \ - SLANG_ASSERT(subst && subst->args.getCount() == 1); \ - auto type = astBuilder->create<T>(); \ - type->elementType = ExtractGenericArgType(subst->args[0]); \ - type->declRef = declRef; \ - return type; \ - } + #define CASE(n, T) \ + else if (magicMod->magicName == #n) \ + { \ + SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); \ + auto type = \ + astBuilder->getOrCreateWithDefaultCtor<T>(ExtractGenericArgType(subst->getArgs()[0])); \ + type->elementType = ExtractGenericArgType(subst->getArgs()[0]); \ + type->declRef = declRef; \ + return type; \ + } CASE(ConstantBuffer, ConstantBufferType) CASE(TextureBuffer, TextureBufferType) @@ -561,8 +569,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // "magic" builtin types which have no generic parameters #define CASE(n,T) \ - else if(magicMod->magicName == #n) { \ - auto type = astBuilder->create<T>(); \ + else if(magicMod->magicName == #n) { \ + auto type = astBuilder->getOrCreate<T>(); \ type->declRef = declRef; \ return type; \ } @@ -601,7 +609,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else { - return astBuilder->create<DeclRefType>(declRef); + return astBuilder->getOrCreateDeclRefType(declRef.decl, declRef.substitutions.substitutions); } } @@ -786,10 +794,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt substsToApply, &diff); - GenericSubstitution* firstSubst = astBuilder->create<GenericSubstitution>(); - firstSubst->genericDecl = ancestorGenericDecl; - firstSubst->args = appGenericSubst->args; - firstSubst->outer = restSubst; + GenericSubstitution* firstSubst = astBuilder->getOrCreateGenericSubstitution( + ancestorGenericDecl, appGenericSubst->getArgs(), restSubst); (*ioDiff)++; return firstSubst; @@ -849,10 +855,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt substsToApply, &diff); - ThisTypeSubstitution* firstSubst = astBuilder->create<ThisTypeSubstitution>(); - firstSubst->interfaceDecl = ancestorInterfaceDecl; - firstSubst->witness = appThisTypeSubst->witness; - firstSubst->outer = restSubst; + ThisTypeSubstitution* firstSubst = astBuilder->getOrCreateThisTypeSubstitution( + ancestorInterfaceDecl, appThisTypeSubst->witness, restSubst); (*ioDiff)++; return firstSubst; @@ -1013,12 +1017,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt Type* HLSLPatchType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } IntVal* HLSLPatchType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); } // Constructors for types @@ -1047,7 +1051,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ASTBuilder* astBuilder, DeclRef<TypeDefDecl> const& declRef) { - DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, declRef).as<TypeDefDecl>(); + DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef).as<TypeDefDecl>(); return astBuilder->create<NamedExpressionType>(specializedDeclRef); } @@ -1179,7 +1183,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { out << "<"; bool isFirst = true; - for (const auto& it : genericSubstitution->args) + for (const auto& it : genericSubstitution->getArgs()) { if (!isFirst) out << ", "; diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 332661e12..4e1900636 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -286,20 +286,24 @@ namespace Slang // TODO: where should this live? SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, Decl* decl, SubstitutionSet parentSubst); SubstitutionSet createDefaultSubstitutions( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, Decl* decl); DeclRef<Decl> createDefaultSubstitutionsIfNeeded( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, DeclRef<Decl> declRef); GenericSubstitution* createDefaultSubstitutionsForGeneric( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, GenericDecl* genericDecl, Substitutions* outerSubst); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index a0355c266..b68dbc14a 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1178,7 +1178,7 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( ConstantBufferType* cbType = getASTBuilder()->create<ConstantBufferType>(); cbType->elementType = type; cbType->declRef = getASTBuilder()->getBuiltinDeclRef( - "ConstantBuffer", makeConstArrayViewSingle<Val*>(static_cast<Val*>(type))); + "ConstantBuffer", static_cast<Val*>(type)); containerTypeReflection = cbType; } break; @@ -1187,7 +1187,7 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( ParameterBlockType* pbType = getASTBuilder()->create<ParameterBlockType>(); pbType->elementType = type; pbType->declRef = getASTBuilder()->getBuiltinDeclRef( - "ParameterBlock", makeConstArrayViewSingle<Val*>(static_cast<Val*>(type))); + "ParameterBlock", static_cast<Val*>(type)); containerTypeReflection = pbType; } break; @@ -1197,7 +1197,7 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( getASTBuilder()->create<HLSLStructuredBufferType>(); sbType->elementType = type; sbType->declRef = getASTBuilder()->getBuiltinDeclRef( - "HLSLStructuredBufferType", makeConstArrayViewSingle<Val*>(static_cast<Val*>(type))); + "HLSLStructuredBufferType", static_cast<Val*>(type)); containerTypeReflection = sbType; } break; @@ -3839,7 +3839,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor { if(auto genericSubst = as<GenericSubstitution>(substitution)) { - for(auto arg : genericSubst->args) + for(auto arg : genericSubst->getArgs()) { collectReferencedModules(arg); } |
