diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-decl.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-conformance.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 212 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 31 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-check.cpp | 1 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lookup.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 43 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 4 |
13 files changed, 308 insertions, 45 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index b3dbbef58..833acd681 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -96,9 +96,6 @@ class ExtensionDecl : public AggTypeDeclBase SLANG_CLASS(ExtensionDecl) TypeExp targetType; - - // next extension attached to the same nominal type - ExtensionDecl* nextCandidateExtension = nullptr; }; // Declaration of a type that represents some sort of aggregate @@ -106,9 +103,6 @@ class AggTypeDecl : public AggTypeDeclBase { SLANG_ABSTRACT_CLASS(AggTypeDecl) - // extensions that might apply to this declaration - ExtensionDecl* candidateExtensions = nullptr; - FilteredMemberList<VarDecl> getFields() { return getMembersOfType<VarDecl>(); @@ -371,6 +365,12 @@ class ModuleDecl : public NamespaceDeclBase // its chain of parents. // Module* module = nullptr; + + /// Map a type to the list of extensions of that type (if any) declared in this module + /// + /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. + /// + Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions; }; class ImportDecl : public Decl diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 8143d4e17..076c6cbbb 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1337,6 +1337,14 @@ namespace Slang }; typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs; + /// A reference-counted object to hold a list of candidate extensions + /// that might be applicable to a type based on its declaration. + /// + struct CandidateExtensionList : RefObject + { + List<ExtensionDecl*> candidateExtensions; + }; + } // namespace Slang #endif diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 2a0e661a3..4dcfb3065 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -180,7 +180,8 @@ namespace Slang { ensureDecl(aggTypeDeclRef, DeclCheckState::CanEnumerateBases); - for( auto inheritanceDeclRef : getMembersOfTypeWithExt<InheritanceDecl>(aggTypeDeclRef)) + bool found = false; + foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> const& inheritanceDeclRef) { ensureDecl(inheritanceDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); @@ -209,9 +210,12 @@ namespace Slang if(_isDeclaredSubtype(originalSubType, inheritedType, superTypeDeclRef, outWitness, &breadcrumb)) { - return true; + found = true; } - } + }); + if(found) + return true; + // if an inheritance decl is not found, try to find a GenericTypeConstraintDecl for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef)) { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e4cf40122..2879d7fa2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2028,8 +2028,9 @@ namespace Slang // // For now we will just walk through the extensions that are known at // the time we are compiling and handle those, and punt on the larger issue - // for abit longer. - for(auto candidateExt = interfaceDeclRef.getDecl()->candidateExtensions; candidateExt; candidateExt = candidateExt->nextCandidateExtension) + // for a bit longer. + // + for(auto candidateExt : getCandidateExtensions(interfaceDeclRef, this)) { // We need to apply the extension to the interface type that our // concrete type is inheriting from. @@ -3519,8 +3520,9 @@ namespace Slang if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); - decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; - aggTypeDecl->candidateExtensions = decl; + + getShared()->registerCandidateExtension(aggTypeDecl, decl); + return; } } @@ -4089,6 +4091,208 @@ namespace Slang } } + /// Get a reference to the candidate extension list for `typeDecl` in the given dictionary + /// + /// Note: this function creates an empty list of candidates for the given type if + /// a matching entry doesn't exist already. + /// + static List<ExtensionDecl*>& _getCandidateExtensionList( + AggTypeDecl* typeDecl, + Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>>& mapTypeToCandidateExtensions) + { + RefPtr<CandidateExtensionList> entry; + if( !mapTypeToCandidateExtensions.TryGetValue(typeDecl, entry) ) + { + entry = new CandidateExtensionList(); + mapTypeToCandidateExtensions.Add(typeDecl, entry); + } + return entry->candidateExtensions; + } + + List<ExtensionDecl*> const& SharedSemanticsContext::getCandidateExtensionsForTypeDecl(AggTypeDecl* decl) + { + // We are caching the lists of candidate extensions on the shared + // context, so we will only build the lists if they either have + // not been built before, or if some code caused the lists to + // be invalidated. + // + // TODO: Similar to the rebuilding of lookup tables in `ContainerDecl`s, + // we probably want to optimize this logic to gracefully handle new + // extensions encountered during checking instead of tearing the whole + // thing down. For now this potentially-quadratic behavior is acceptable + // because there just aren't that many extension declarations being used. + // + if( !m_candidateExtensionListsBuilt ) + { + m_candidateExtensionListsBuilt = true; + + // We need to make sure that all extensions that were declared + // as part of our standard-library modules are always visible, + // even if they are not explicit `import`ed into user code. + // + for( auto module : getSession()->stdlibModules ) + { + _addCandidateExtensionsFromModule(module->getModuleDecl()); + } + + // There are two primary modes in which the `SharedSemanticsContext` + // gets used. + // + // In the first mode, we are checking an entire `ModuelDecl`, and we + // need to always check things from the "point of view" of that module + // (so that the extensions that should be visible are based on what + // that module can access via `import`s). + // + // In the second mode, we are checking code related to API interactions + // by the user (e.g., parsing a type from a string, specializing an + // entry point to type arguments, etc.). In these cases there is no + // clear module that should determine the point of view for looking + // up extensions, and we instead need/want to consider any extensions + // from all modules loaded into the linkage. + // + // We differentiate these cases based on whether a "primary" module + // was set at the time the `SharedSemanticsContext` was constructed. + // + if( m_module ) + { + // We have a "primary" module that is being checked, and we should + // look up extensions based on what would be visible to that + // module. + // + // We need to consider the extensions declared in the module itself, + // along with everything the module imported. + // + // Note: there is an implicit assumption here that the `importedModules` + // member on the `SharedSemanticsContext` is accurate in this case. + // + _addCandidateExtensionsFromModule(m_module->getModuleDecl()); + for( auto moduleDecl : this->importedModules ) + { + _addCandidateExtensionsFromModule(moduleDecl); + } + } + else + { + // We are in one of the many ad hoc checking modes where we really + // want to resolve things based on the totality of what is + // available/defined within the current linkage. + // + for( auto module : m_linkage->loadedModulesList ) + { + _addCandidateExtensionsFromModule(module->getModuleDecl()); + } + } + } + + // Once we are sure that the dictionary-of-arrays of extensions + // has been populated, we return to the user the entry they + // asked for. + // + return _getCandidateExtensionList(decl, m_mapTypeDeclToCandidateExtensions); + } + + void SharedSemanticsContext::registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl) + { + // The primary cache of extension declarations is on the `ModuleDecl`. + // We will add the `extDecl` to the cache for the module it belongs to. + // + // We can be sure that the resulting cache won't have lifetime issues, + // because all the extensions it contains are owned by the module itself, + // and the types used as keys had to be reachable/referenceable from the + // code inside the module for the given `extDecl` to extend them. + // + auto moduleDecl = getModuleDecl(extDecl); + _getCandidateExtensionList(typeDecl, moduleDecl->mapTypeToCandidateExtensions).add(extDecl); + + // Because we've loaded a new extension, we need to invalidate whatever + // information the `SharedSemanticsContext` had cached about loaded + // extensions, and force it to rebuild its cache to include the + // new extension we just added. + // + // TODO: We should probably just go ahead and add `extDecl` directly + // into the appropriate entry here, and do a similar step on each + // `import`. + // + m_candidateExtensionListsBuilt = false; + m_mapTypeDeclToCandidateExtensions.Clear(); + } + + void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl) + { + for( auto& entry : moduleDecl->mapTypeToCandidateExtensions ) + { + auto& list = _getCandidateExtensionList(entry.Key, m_mapTypeDeclToCandidateExtensions); + list.addRange(entry.Value->candidateExtensions); + } + } + + List<ExtensionDecl*> const& getCandidateExtensions( + DeclRef<AggTypeDecl> const& declRef, + SemanticsVisitor* semantics) + { + auto decl = declRef.getDecl(); + auto shared = semantics->getShared(); + return shared->getCandidateExtensionsForTypeDecl(decl); + } + + void _foreachDirectOrExtensionMemberOfType( + SemanticsVisitor* semantics, + DeclRef<ContainerDecl> const& containerDeclRef, + SyntaxClassBase const& syntaxClass, + void (*callback)(DeclRefBase, void*), + void const* userData) + { + // We are being asked to invoke the given callback on + // each direct member of `containerDeclRef`, along with + // any members added via `extension` declarations, that + // have the correct AST node class (`syntaxClass`). + // + // We start with the direct members. + // + for( auto memberDeclRef : getMembers(containerDeclRef) ) + { + if( memberDeclRef.decl->getClass().isSubClassOfImpl(syntaxClass) ) + { + callback(memberDeclRef, (void*)userData); + } + } + + // Next, in the case wher ethe type can be subject to extensions, + // we loop over the applicable extensions and their member.s + // + if(auto aggTypeDeclRef = containerDeclRef.as<AggTypeDecl>()) + { + auto aggType = DeclRefType::create(semantics->getASTBuilder(), aggTypeDeclRef); + for(auto extDecl : getCandidateExtensions(aggTypeDeclRef, semantics)) + { + // Note that `extDecl` may have been declared for a type + // base on the declaration that `aggTypeDeclRef` refers + // to, but that does not guarantee that it applies to + // the type itself. E.g., we might have an extension of + // `vector<float, N>` for any `N`, but the current type is + // `vector<int, 2>` so that the extension doesn't match. + // + // In order to make sure that we don't enumerate members + // that don't make sense in context, we must apply + // the extension to the type and see if we succeed in + // making a match. + // + auto extDeclRef = ApplyExtensionToType(semantics, extDecl, aggType); + if(!extDeclRef) + continue; + + for( auto memberDeclRef : getMembers(extDeclRef) ) + { + if( memberDeclRef.decl->getClass().isSubClassOfImpl(syntaxClass) ) + { + callback(memberDeclRef, (void*)userData); + } + } + } + } + } + + static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SharedSemanticsContext* shared) { switch(state) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 931e331a6..97e77ec3e 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -199,8 +199,12 @@ namespace Slang /// Shared state for a semantics-checking session. struct SharedSemanticsContext { - Linkage* m_linkage = nullptr; - DiagnosticSink* m_sink = nullptr; + Linkage* m_linkage = nullptr; + + /// The (optional) "primary" module that is the parent to everything that will be checked. + Module* m_module = nullptr; + + DiagnosticSink* m_sink = nullptr; DiagnosticSink* getSink() { @@ -217,8 +221,10 @@ namespace Slang public: SharedSemanticsContext( Linkage* linkage, + Module* module, DiagnosticSink* sink) : m_linkage(linkage) + , m_module(module) , m_sink(sink) {} @@ -231,6 +237,27 @@ namespace Slang { return m_linkage; } + + Module* getModule() + { + return m_module; + } + + /// Get the list of extension declarations that appear to apply to `decl` in this context + List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); + + /// Register a candidate extension `extDecl` for `typeDecl` encountered during checking. + void registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl); + + private: + /// Mapping from type declarations to the known extensiosn that apply to them + Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> m_mapTypeDeclToCandidateExtensions; + + /// Is the `m_mapTypeDeclToCandidateExtensions` dictionary valid and up to date? + bool m_candidateExtensionListsBuilt = false; + + /// Add candidate extensions declared in `moduleDecl` to `m_mapTypeDeclToCandidateExtensions` + void _addCandidateExtensionsFromModule(ModuleDecl* moduleDecl); }; struct SemanticsVisitor diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 17ca9fc78..fe35f8a19 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -954,7 +954,7 @@ namespace Slang { SLANG_ASSERT(argCount == getSpecializationParamCount()); - SharedSemanticsContext semanticsContext(getLinkage(), sink); + SharedSemanticsContext semanticsContext(getLinkage(), this, sink); SemanticsVisitor visitor(&semanticsContext); RefPtr<Module::ModuleSpecializationInfo> specializationInfo = new Module::ModuleSpecializationInfo(); @@ -1129,7 +1129,7 @@ namespace Slang { auto linkage = componentType->getLinkage(); - SharedSemanticsContext semanticsContext(linkage, sink); + SharedSemanticsContext semanticsContext(linkage, nullptr, sink); SemanticsVisitor semanticsVisitor(&semanticsContext); auto argCount = argExprs.getCount(); @@ -1152,7 +1152,7 @@ namespace Slang auto args = inArgs; auto argCount = inArgCount; - SharedSemanticsContext sharedSemanticsContext(getLinkage(), sink); + SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink); SemanticsVisitor visitor(&sharedSemanticsContext); // The first N arguments will be for the explicit generic parameters @@ -1331,6 +1331,7 @@ namespace Slang SharedSemanticsContext sharedSemanticsContext( linkage, + nullptr, sink); SemanticsVisitor semantics(&sharedSemanticsContext); @@ -1367,7 +1368,7 @@ namespace Slang // TODO: We should cache and re-use specialized types // when the exact same arguments are provided again later. - SharedSemanticsContext sharedSemanticsContext(this, sink); + SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink); SemanticsVisitor visitor(&sharedSemanticsContext); SpecializationParams specializationParams; @@ -1425,7 +1426,7 @@ namespace Slang // We have an appropriate number of arguments for the global specialization parameters, // and now we need to check that the arguments conform to the declared constraints. // - SharedSemanticsContext visitor(linkage, sink); + SharedSemanticsContext visitor(linkage, nullptr, sink); List<SpecializationArg> specializationArgs; _extractSpecializationArgs(unspecializedProgram, specializationArgExprs, specializationArgs, sink); diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 98b376a69..6c9b0c8ca 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -13,6 +13,7 @@ namespace Slang { SharedSemanticsContext sharedSemanticsContext( linkage, + nullptr, sink); SemanticsVisitor visitor(&sharedSemanticsContext); diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index b86d90cec..e97cb0b61 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -201,6 +201,7 @@ namespace Slang { SharedSemanticsContext sharedSemanticsContext( translationUnit->compileRequest->getLinkage(), + translationUnit->getModule(), translationUnit->compileRequest->getSink()); SemanticsDeclVisitorBase visitor(&sharedSemanticsContext); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index cf1983511..146cf9ed0 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2078,7 +2078,7 @@ namespace Slang RefPtr<Scope> slangLanguageScope; ModuleDecl* baseModuleDecl = nullptr; - List<RefPtr<Module>> loadedModuleCode; + List<RefPtr<Module>> stdlibModules; SourceManager builtinSourceManager; diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 25e1eedce..01451470e 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -492,7 +492,7 @@ static void _lookUpMembersInSuperTypeDeclImpl( // directly with that type. // ensureDecl(request.semantics, aggTypeDeclRef.getDecl(), DeclCheckState::ReadyForLookup); - for(auto extDecl = getCandidateExtensions(aggTypeDeclRef); extDecl; extDecl = extDecl->nextCandidateExtension) + for(auto extDecl : getCandidateExtensions(aggTypeDeclRef, semantics)) { // Note: In this case `extDecl` is an extension that was declared to apply // (conditionally) to `aggTypeDeclRef`, which is the decl-ref part of diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 06603547f..d207b45bf 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1242,18 +1242,26 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return rs; } - - -Module* getModule(Decl* decl) + +ModuleDecl* getModuleDecl(Decl* decl) { for( auto dd = decl; dd; dd = dd->parentDecl ) { if(auto moduleDecl = as<ModuleDecl>(dd)) - return moduleDecl->module; + return moduleDecl; } return nullptr; } +Module* getModule(Decl* decl) +{ + auto moduleDecl = getModuleDecl(decl); + if(!moduleDecl) + return nullptr; + + return moduleDecl->module; +} + bool findImageFormatByName(char const* name, ImageFormat* outFormat) { static const struct diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 81c191ccf..a0c54f914 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -46,10 +46,11 @@ namespace Slang // Declarations // - inline ExtensionDecl* getCandidateExtensions(DeclRef<AggTypeDecl> const& declRef) - { - return declRef.getDecl()->candidateExtensions; - } + struct SemanticsVisitor; + + List<ExtensionDecl*> const& getCandidateExtensions( + DeclRef<AggTypeDecl> const& declRef, + SemanticsVisitor* semantics); inline FilteredMemberRefList<Decl> getMembers(DeclRef<ContainerDecl> const& declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) { @@ -62,22 +63,27 @@ namespace Slang return FilteredMemberRefList<T>(declRef.getDecl()->members, declRef.substitutions, filterStyle); } - template<typename T> - inline List<DeclRef<T>> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef, MemberFilterStyle filterStyle = MemberFilterStyle::All) + void _foreachDirectOrExtensionMemberOfType( + SemanticsVisitor* semantics, + DeclRef<ContainerDecl> const& declRef, + SyntaxClassBase const& syntaxClass, + void (*callback)(DeclRefBase, void*), + void const* userData); + + template<typename T, typename F> + inline void foreachDirectOrExtensionMemberOfType( + SemanticsVisitor* semantics, + DeclRef<ContainerDecl> const& declRef, + F const& func) { - List<DeclRef<T>> rs; - for (auto d : getMembersOfType<T>(declRef, filterStyle)) - rs.add(d); - if (auto aggDeclRef = declRef.as<AggTypeDecl>()) + struct Helper { - for (auto ext = getCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension) + static void callback(DeclRefBase declRef, void* userData) { - auto extMembers = getMembersOfType<T>(DeclRef<ContainerDecl>(ext, declRef.substitutions), filterStyle); - for (auto mbr : extMembers) - rs.add(mbr); + (*(F*)userData)(DeclRef<T>((T*) declRef.decl, declRef.substitutions)); } - } - return rs; + }; + _foreachDirectOrExtensionMemberOfType(semantics, declRef, getClass<T>(), &Helper::callback, &func); } /// The the user-level name for a variable that might be a shader parameter. @@ -256,7 +262,10 @@ namespace Slang All = 7 }; - /// Get the module that a declaration is associated with, if any. + /// Get the module dclaration that a declaration is associated with, if any. + ModuleDecl* getModuleDecl(Decl* decl); + + /// Get the module that a declaration is associated with, if any. Module* getModule(Decl* decl); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b79cecb59..8523a445d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2683,13 +2683,13 @@ void Session::addBuiltinSource( // We need to retain this AST so that we can use it in other code // (Note that the `Scope` type does not retain the AST it points to) - loadedModuleCode.add(module); + stdlibModules.add(module); } Session::~Session() { // destroy modules next - loadedModuleCode = decltype(loadedModuleCode)(); + stdlibModules = decltype(stdlibModules)(); } } |
