diff options
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 295 |
1 files changed, 283 insertions, 12 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 254a37c24..4a25c2392 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1552,6 +1552,7 @@ void FrontEndCompileRequest::checkAllTranslationUnits() { checkTranslationUnit(translationUnit.Ptr()); } + checkEntryPoints(); } void FrontEndCompileRequest::generateIR() @@ -1676,7 +1677,6 @@ SlangResult FrontEndCompileRequest::executeActionsInner() if (getSink()->getErrorCount() != 0) return SLANG_FAIL; - // Look up all the entry points that are expected, // and use them to populate the `program` member. // @@ -1905,13 +1905,11 @@ SlangResult EndToEndCompileRequest::executeActions() int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) { - Index result = translationUnits.getCount(); - if (!moduleName) { // We want to ensure that symbols defined in different translation // units get unique mangled names, so that we can, e.g., tell apart - // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`, + // a `main()` function in `vertex.hlsl` and a `main()` in `fragment.hlsl`, // even when they are being compiled together. // String generatedName = "tu"; @@ -1925,11 +1923,17 @@ int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* mo translationUnit->moduleName = moduleName; - translationUnits.add(translationUnit); + return addTranslationUnit(translationUnit); +} +int FrontEndCompileRequest::addTranslationUnit(TranslationUnitRequest* translationUnit) +{ + Index result = translationUnits.getCount(); + translationUnits.add(translationUnit); return (int) result; } + void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, SourceFile* sourceFile) @@ -2041,6 +2045,7 @@ UInt Linkage::addTarget( } void Linkage::loadParsedModule( + RefPtr<FrontEndCompileRequest> compileRequest, RefPtr<TranslationUnitRequest> translationUnit, Name* name, const PathInfo& pathInfo) @@ -2061,7 +2066,7 @@ void Linkage::loadParsedModule( auto sink = translationUnit->compileRequest->getSink(); int errorCountBefore = sink->getErrorCount(); - checkTranslationUnit(translationUnit.Ptr()); + compileRequest->checkAllTranslationUnits(); int errorCountAfter = sink->getErrorCount(); if (errorCountAfter != errorCountBefore) @@ -2106,6 +2111,8 @@ RefPtr<Module> Linkage::loadModule( translationUnit->moduleName = name; translationUnit->sourceLanguage = SourceLanguage::Slang; + frontEndReq->addTranslationUnit(translationUnit); + auto module = translationUnit->getModule(); ModuleBeingImportedRAII moduleBeingImported( @@ -2132,6 +2139,7 @@ RefPtr<Module> Linkage::loadModule( } loadParsedModule( + frontEndReq, translationUnit, name, filePathInfo); @@ -2902,6 +2910,134 @@ RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpeci // SpecializedComponentType // +/// Utility type for collecting modules references by types/declarations +struct SpecializationArgModuleCollector : ComponentTypeVisitor +{ + HashSet<Module*> m_modulesSet; + List<Module*> m_modulesList; + + void addModule(Module* module) + { + m_modulesList.add(module); + m_modulesSet.Add(module); + } + + void maybeAddModule(Module* module) + { + if(!module) + return; + if(m_modulesSet.Contains(module)) + return; + + addModule(module); + } + + void collectReferencedModules(Decl* decl) + { + auto module = getModule(decl); + maybeAddModule(module); + } + + void collectReferencedModules(Substitutions* substitution) + { + if(auto genericSubst = as<GenericSubstitution>(substitution)) + { + for(auto arg : genericSubst->args) + { + collectReferencedModules(arg); + } + } + } + + void collectReferencedModules(SubstitutionSet const& substitutions) + { + for(auto subst = substitutions.substitutions; subst; subst = subst->outer) + { + collectReferencedModules(subst); + } + } + + void collectReferencedModules(DeclRefBase const& declRef) + { + collectReferencedModules(declRef.decl); + collectReferencedModules(declRef.substitutions); + } + + void collectReferencedModules(Type* type) + { + if(auto declRefType = as<DeclRefType>(type)) + { + collectReferencedModules(declRefType->declRef); + } + + // TODO: Handle non-decl-ref composite type cases + // (e.g., function types). + } + + void collectReferencedModules(Val* val) + { + if(auto type = as<Type>(val)) + { + collectReferencedModules(type); + } + else if (auto declRefVal = as<GenericParamIntVal>(val)) + { + collectReferencedModules(declRefVal->declRef); + } + + // TODO: other cases of values that could reference + // a declaration. + } + + void collectReferencedModules(List<ExpandedSpecializationArg> const& args) + { + for(auto arg : args) + { + collectReferencedModules(arg.val); + collectReferencedModules(arg.witness); + } + } + + // + // ComponentTypeVisitor methods + // + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + + if(!specializationInfo) + return; + + collectReferencedModules(specializationInfo->specializedFuncDeclRef); + collectReferencedModules(specializationInfo->existentialSpecializationArgs); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(module); + + if(!specializationInfo) + return; + + for(auto arg : specializationInfo->genericArgs) + { + collectReferencedModules(arg.argVal); + } + collectReferencedModules(specializationInfo->existentialArgs); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } +}; + SpecializedComponentType::SpecializedComponentType( ComponentType* base, ComponentType::SpecializationInfo* specializationInfo, @@ -2914,8 +3050,147 @@ SpecializedComponentType::SpecializedComponentType( { m_irModule = generateIRForSpecializedComponentType(this, sink); + // We need to account for the fact that a specialized + // entity like `myShader<SomeType>` needs to not only + // depend on the module(s) that `myShader` depends on, + // but also on any modules that `SomeType` depends on. + // + // We will set up a "collector" type that will be + // used to build a list of these additional modules. + // + SpecializationArgModuleCollector moduleCollector; + + // We don't want to go adding additional requirements for + // modules that the base component type already includes, + // so we will add those to the set of modules in + // the collector before we starting trying to add others. + // + base->enumerateModules([&](Module* module) + { + moduleCollector.m_modulesSet.Add(module); + }); + + // In order to collect the additional modules, we need + // to inspect the specialization arguments and see what + // they depend on. + // + // Naively, it seems like we'd just want to iterate + // over `specializationArgs`, which gives the specialization + // arguments as the user supplied them. However, such + // an approach would have a subtle problem. + // + // If we have a generic entry point like: + // + // // In module A + // myShader<T : IThing> + // + // + // And the type `SomeType` that is being used as an argument doesn't + // directly conform to `IThing`: + // + // // In module B + // struct SomeType { ... } + // + // and the conformance of `SomeType` to `IThing` is + // coming from yet another module: + // + // // In module C + // import B; + // extension SomeType : IThing { ... } + // + // In this case, the specialized component for `myShader<SomeType>` + // needs to depend on all of: + // + // * Module A, because it defines `myShader` + // * Module B, because it defines `SomeType` + // * Module C, because it defines the conformance `SomeType : IThing` + // + // We thus need to iterate over a form of the specialization + // arguments that includes the "expanded" arguments like + // interface conformance witnesses that got added during + // semantic checking. + // + // The expanded arguments are being stored in the `specializationInfo` + // today (for use by downstream code generation), and the easiest + // way to walk that information and get to the leaf nodes where + // the expanded arguments are stored is to apply a visitor to + // the specialized component type we are in the middle of constructing. + // + moduleCollector.visitSpecialized(this); + + // Now that we've collected our additional information, we can + // start to build up the final lists for the specialized component type. + // + // The starting point for our lists comes from the base component type. + // + m_moduleDependencies = base->getModuleDependencies(); + m_filePathDependencies = base->getFilePathDependencies(); + + Index baseRequirementCount = base->getRequirementCount(); + for( Index r = 0; r < baseRequirementCount; r++ ) + { + m_requirements.add(base->getRequirement(r)); + } + + // The specialized component type will need to have additional + // dependencies and requirements based on the modules that + // were collected when looking at the specialization arguments. + + // We want to avoid adding the same file path dependency more than once. + // + HashSet<String> filePathDependencySet; + for(auto path : m_filePathDependencies) + filePathDependencySet.Add(path); + + for(auto module : moduleCollector.m_modulesList) + { + // The specialized component type will have an open (unsatisfied) + // requirement for each of the modules that its specialization + // arguments need. + // + // Note: what this means in practice is that the component type + // records that the given module(s) will need to be linked in + // before final code can be generated, but it importantly + // does not dictate the final placement of the parameters from + // those modules in the layout. + // + m_requirements.add(module); + + // The speciialized component type will also have a dependency + // on all the file paths that any of the modules involved in + // it depend on (including those that are required but not + // yet linked in). + // + // The file path information is what a client would need to + // use to decide if kernel code is out of date compared to + // source files, so we want to include anything that could + // affect the validity of generated code. + // + for(auto path : module->getFilePathDependencies()) + { + if(filePathDependencySet.Contains(path)) + continue; + filePathDependencySet.Add(path); + m_filePathDependencies.add(path); + } + + // Finalyl we also add the module for the specialization arguments + // to the list of modules that would be used for legacy lookup + // operations where we need an implicit/default scope to use + // and want it to be expansive. + // + // TODO: This stuff really isn't worth keeping around long + // term, and we should ditch the entire "legacy lookup" idea. + // + m_moduleDependencies.add(module); + } + // The following is a bit of a hack. // + // TODO: We should not need this hack any longer, since the + // new approach to `switch`-based dynamic dispatch has made + // the existing tagged-union support obsolete. + // // Back-end code generation relies on us having computed layouts for all tagged // unions that end up being used in the code, which means we need a way to find // all such types that get used in a program (and the stuff it imports). @@ -3001,16 +3276,12 @@ void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, Spec Index SpecializedComponentType::getRequirementCount() { - // TODO: A specialized component type may have *more* requirements - // than the original, because it also needs to include the module(s) - // that define the types used for specialization arguments. - - return m_base->getRequirementCount(); + return m_requirements.getCount(); } RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) { - return m_base->getRequirement(index); + return m_requirements[index]; } String SpecializedComponentType::getEntryPointMangledName(Index index) |
