diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 85 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 295 |
4 files changed, 346 insertions, 52 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 800a0519f..b6ed0224a 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -784,30 +784,10 @@ namespace Slang return fillRequirements(globalComponentType); } - /// Create a component type that represents the global scope for a compile request, - /// along with any entry point functions. - /// - /// The resulting component type will include the global-scope information - /// first, so its layout will be compatible with the result of - /// `createUnspecializedGlobalComponentType`. - /// - /// The new component type will also add on any entry-point functions - /// that were requested and will thus include space for their `uniform` parameters. - /// If multiple entry points were requested then they will be given non-overlapping - /// parameter bindings, consistent with them being used together in - /// a single pipeline state, hit group, etc. - /// - /// The result of this function is unspecialized and doesn't take into - /// account any specialization arguments the user might have supplied. - /// - RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( - FrontEndCompileRequest* compileRequest, - List<RefPtr<ComponentType>>& outUnspecializedEntryPoints) + void FrontEndCompileRequest::checkEntryPoints() { - auto linkage = compileRequest->getLinkage(); - auto sink = compileRequest->getSink(); - - auto globalComponentType = compileRequest->getGlobalComponentType(); + auto linkage = getLinkage(); + auto sink = getSink(); // The validation of entry points here will be modal, and controlled // by whether the user specified any entry points directly via @@ -818,17 +798,14 @@ namespace Slang // First, check if the user requested any entry points explicitly via // the API or command line. // - bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; - - List<RefPtr<ComponentType>> allComponentTypes; - allComponentTypes.add(globalComponentType); + bool anyExplicitEntryPoints = getEntryPointReqCount() != 0; if( anyExplicitEntryPoints ) { // If there were any explicit requests for entry points to be // checked, then we will *only* check those. // - for(auto entryPointReq : compileRequest->getEntryPointReqs()) + for(auto entryPointReq : getEntryPointReqs()) { auto entryPoint = findAndValidateEntryPoint( entryPointReq); @@ -841,9 +818,6 @@ namespace Slang // compilation API doesn't allow for grouping). // entryPointReq->getTranslationUnit()->module->_addEntryPoint(entryPoint); - - outUnspecializedEntryPoints.add(entryPoint); - allComponentTypes.add(entryPoint); } } @@ -868,10 +842,10 @@ namespace Slang // For now we'll start with an extremely basic approach that // should work for typical HLSL code. // - Index translationUnitCount = compileRequest->translationUnits.getCount(); + Index translationUnitCount = translationUnits.getCount(); for(Index tt = 0; tt < translationUnitCount; ++tt) { - auto translationUnit = compileRequest->translationUnits[tt]; + auto translationUnit = translationUnits[tt]; for( auto globalDecl : translationUnit->getModuleDecl()->members ) { auto maybeFuncDecl = globalDecl; @@ -920,12 +894,51 @@ namespace Slang // independent of the others. // translationUnit->module->_addEntryPoint(entryPoint); - - outUnspecializedEntryPoints.add(entryPoint); - allComponentTypes.add(entryPoint); } } } + } + + + /// Create a component type that represents the global scope for a compile request, + /// along with any entry point functions. + /// + /// The resulting component type will include the global-scope information + /// first, so its layout will be compatible with the result of + /// `createUnspecializedGlobalComponentType`. + /// + /// The new component type will also add on any entry-point functions + /// that were requested and will thus include space for their `uniform` parameters. + /// If multiple entry points were requested then they will be given non-overlapping + /// parameter bindings, consistent with them being used together in + /// a single pipeline state, hit group, etc. + /// + /// The result of this function is unspecialized and doesn't take into + /// account any specialization arguments the user might have supplied. + /// + RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( + FrontEndCompileRequest* compileRequest, + List<RefPtr<ComponentType>>& outUnspecializedEntryPoints) + { + auto linkage = compileRequest->getLinkage(); + + auto globalComponentType = compileRequest->getGlobalComponentType(); + + List<RefPtr<ComponentType>> allComponentTypes; + allComponentTypes.add(globalComponentType); + + Index translationUnitCount = compileRequest->translationUnits.getCount(); + for(Index tt = 0; tt < translationUnitCount; ++tt) + { + auto translationUnit = compileRequest->translationUnits[tt]; + auto module = translationUnit->getModule(); + + for(auto entryPoint : module->getEntryPoints() ) + { + outUnspecializedEntryPoints.add(entryPoint); + allComponentTypes.add(entryPoint); + } + } // Also consider entry points that were introduced via adding // a library reference... diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 2388f5247..9e3550066 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -602,10 +602,8 @@ namespace Slang Index getRequirementCount() SLANG_OVERRIDE; RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - /// TODO: These should include requirements/dependencies for the types - /// referenced in the specialization arguments... - List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_base->getModuleDependencies(); } - List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_base->getFilePathDependencies(); } + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } + List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencies; } /// Get a list of tagged-union types referenced by the specialization parameters. List<TaggedUnionType*> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } @@ -638,6 +636,9 @@ namespace Slang // Any tagged union types that were referenced by the specialization arguments. List<TaggedUnionType*> m_taggedUnionTypes; + List<Module*> m_moduleDependencies; + List<String> m_filePathDependencies; + List<RefPtr<ComponentType>> m_requirements; }; /// Describes an entry point for the purposes of layout and code generation. @@ -1362,6 +1363,7 @@ namespace Slang DiagnosticSink* sink); void loadParsedModule( + RefPtr<FrontEndCompileRequest> compileRequest, RefPtr<TranslationUnitRequest> translationUnit, Name* name, PathInfo const& pathInfo); @@ -1560,6 +1562,8 @@ namespace Slang // of the translation units in the program void checkAllTranslationUnits(); + void checkEntryPoints(); + void generateIR(); SlangResult executeActionsInner(); @@ -1576,6 +1580,8 @@ namespace Slang /// @return The zero-based index of the translation unit in this compile request. int addTranslationUnit(SourceLanguage language, Name* moduleName); + int addTranslationUnit(TranslationUnitRequest* translationUnit); + void addTranslationUnitSourceFile( int translationUnitIndex, SourceFile* sourceFile); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 91229b80d..91427ce9c 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1409,6 +1409,10 @@ static LegalVal legalizeInst( break; } + if(as<IRAttr>(inst)) + return LegalVal::simple(inst); + + // We will iterate over all the operands, extract the legalized // value of each, and collect them in an array for subsequent use. // diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 254a37c24..4a25c2392 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1552,6 +1552,7 @@ void FrontEndCompileRequest::checkAllTranslationUnits() { checkTranslationUnit(translationUnit.Ptr()); } + checkEntryPoints(); } void FrontEndCompileRequest::generateIR() @@ -1676,7 +1677,6 @@ SlangResult FrontEndCompileRequest::executeActionsInner() if (getSink()->getErrorCount() != 0) return SLANG_FAIL; - // Look up all the entry points that are expected, // and use them to populate the `program` member. // @@ -1905,13 +1905,11 @@ SlangResult EndToEndCompileRequest::executeActions() int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) { - Index result = translationUnits.getCount(); - if (!moduleName) { // We want to ensure that symbols defined in different translation // units get unique mangled names, so that we can, e.g., tell apart - // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`, + // a `main()` function in `vertex.hlsl` and a `main()` in `fragment.hlsl`, // even when they are being compiled together. // String generatedName = "tu"; @@ -1925,11 +1923,17 @@ int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* mo translationUnit->moduleName = moduleName; - translationUnits.add(translationUnit); + return addTranslationUnit(translationUnit); +} +int FrontEndCompileRequest::addTranslationUnit(TranslationUnitRequest* translationUnit) +{ + Index result = translationUnits.getCount(); + translationUnits.add(translationUnit); return (int) result; } + void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, SourceFile* sourceFile) @@ -2041,6 +2045,7 @@ UInt Linkage::addTarget( } void Linkage::loadParsedModule( + RefPtr<FrontEndCompileRequest> compileRequest, RefPtr<TranslationUnitRequest> translationUnit, Name* name, const PathInfo& pathInfo) @@ -2061,7 +2066,7 @@ void Linkage::loadParsedModule( auto sink = translationUnit->compileRequest->getSink(); int errorCountBefore = sink->getErrorCount(); - checkTranslationUnit(translationUnit.Ptr()); + compileRequest->checkAllTranslationUnits(); int errorCountAfter = sink->getErrorCount(); if (errorCountAfter != errorCountBefore) @@ -2106,6 +2111,8 @@ RefPtr<Module> Linkage::loadModule( translationUnit->moduleName = name; translationUnit->sourceLanguage = SourceLanguage::Slang; + frontEndReq->addTranslationUnit(translationUnit); + auto module = translationUnit->getModule(); ModuleBeingImportedRAII moduleBeingImported( @@ -2132,6 +2139,7 @@ RefPtr<Module> Linkage::loadModule( } loadParsedModule( + frontEndReq, translationUnit, name, filePathInfo); @@ -2902,6 +2910,134 @@ RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpeci // SpecializedComponentType // +/// Utility type for collecting modules references by types/declarations +struct SpecializationArgModuleCollector : ComponentTypeVisitor +{ + HashSet<Module*> m_modulesSet; + List<Module*> m_modulesList; + + void addModule(Module* module) + { + m_modulesList.add(module); + m_modulesSet.Add(module); + } + + void maybeAddModule(Module* module) + { + if(!module) + return; + if(m_modulesSet.Contains(module)) + return; + + addModule(module); + } + + void collectReferencedModules(Decl* decl) + { + auto module = getModule(decl); + maybeAddModule(module); + } + + void collectReferencedModules(Substitutions* substitution) + { + if(auto genericSubst = as<GenericSubstitution>(substitution)) + { + for(auto arg : genericSubst->args) + { + collectReferencedModules(arg); + } + } + } + + void collectReferencedModules(SubstitutionSet const& substitutions) + { + for(auto subst = substitutions.substitutions; subst; subst = subst->outer) + { + collectReferencedModules(subst); + } + } + + void collectReferencedModules(DeclRefBase const& declRef) + { + collectReferencedModules(declRef.decl); + collectReferencedModules(declRef.substitutions); + } + + void collectReferencedModules(Type* type) + { + if(auto declRefType = as<DeclRefType>(type)) + { + collectReferencedModules(declRefType->declRef); + } + + // TODO: Handle non-decl-ref composite type cases + // (e.g., function types). + } + + void collectReferencedModules(Val* val) + { + if(auto type = as<Type>(val)) + { + collectReferencedModules(type); + } + else if (auto declRefVal = as<GenericParamIntVal>(val)) + { + collectReferencedModules(declRefVal->declRef); + } + + // TODO: other cases of values that could reference + // a declaration. + } + + void collectReferencedModules(List<ExpandedSpecializationArg> const& args) + { + for(auto arg : args) + { + collectReferencedModules(arg.val); + collectReferencedModules(arg.witness); + } + } + + // + // ComponentTypeVisitor methods + // + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + + if(!specializationInfo) + return; + + collectReferencedModules(specializationInfo->specializedFuncDeclRef); + collectReferencedModules(specializationInfo->existentialSpecializationArgs); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(module); + + if(!specializationInfo) + return; + + for(auto arg : specializationInfo->genericArgs) + { + collectReferencedModules(arg.argVal); + } + collectReferencedModules(specializationInfo->existentialArgs); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } +}; + SpecializedComponentType::SpecializedComponentType( ComponentType* base, ComponentType::SpecializationInfo* specializationInfo, @@ -2914,8 +3050,147 @@ SpecializedComponentType::SpecializedComponentType( { m_irModule = generateIRForSpecializedComponentType(this, sink); + // We need to account for the fact that a specialized + // entity like `myShader<SomeType>` needs to not only + // depend on the module(s) that `myShader` depends on, + // but also on any modules that `SomeType` depends on. + // + // We will set up a "collector" type that will be + // used to build a list of these additional modules. + // + SpecializationArgModuleCollector moduleCollector; + + // We don't want to go adding additional requirements for + // modules that the base component type already includes, + // so we will add those to the set of modules in + // the collector before we starting trying to add others. + // + base->enumerateModules([&](Module* module) + { + moduleCollector.m_modulesSet.Add(module); + }); + + // In order to collect the additional modules, we need + // to inspect the specialization arguments and see what + // they depend on. + // + // Naively, it seems like we'd just want to iterate + // over `specializationArgs`, which gives the specialization + // arguments as the user supplied them. However, such + // an approach would have a subtle problem. + // + // If we have a generic entry point like: + // + // // In module A + // myShader<T : IThing> + // + // + // And the type `SomeType` that is being used as an argument doesn't + // directly conform to `IThing`: + // + // // In module B + // struct SomeType { ... } + // + // and the conformance of `SomeType` to `IThing` is + // coming from yet another module: + // + // // In module C + // import B; + // extension SomeType : IThing { ... } + // + // In this case, the specialized component for `myShader<SomeType>` + // needs to depend on all of: + // + // * Module A, because it defines `myShader` + // * Module B, because it defines `SomeType` + // * Module C, because it defines the conformance `SomeType : IThing` + // + // We thus need to iterate over a form of the specialization + // arguments that includes the "expanded" arguments like + // interface conformance witnesses that got added during + // semantic checking. + // + // The expanded arguments are being stored in the `specializationInfo` + // today (for use by downstream code generation), and the easiest + // way to walk that information and get to the leaf nodes where + // the expanded arguments are stored is to apply a visitor to + // the specialized component type we are in the middle of constructing. + // + moduleCollector.visitSpecialized(this); + + // Now that we've collected our additional information, we can + // start to build up the final lists for the specialized component type. + // + // The starting point for our lists comes from the base component type. + // + m_moduleDependencies = base->getModuleDependencies(); + m_filePathDependencies = base->getFilePathDependencies(); + + Index baseRequirementCount = base->getRequirementCount(); + for( Index r = 0; r < baseRequirementCount; r++ ) + { + m_requirements.add(base->getRequirement(r)); + } + + // The specialized component type will need to have additional + // dependencies and requirements based on the modules that + // were collected when looking at the specialization arguments. + + // We want to avoid adding the same file path dependency more than once. + // + HashSet<String> filePathDependencySet; + for(auto path : m_filePathDependencies) + filePathDependencySet.Add(path); + + for(auto module : moduleCollector.m_modulesList) + { + // The specialized component type will have an open (unsatisfied) + // requirement for each of the modules that its specialization + // arguments need. + // + // Note: what this means in practice is that the component type + // records that the given module(s) will need to be linked in + // before final code can be generated, but it importantly + // does not dictate the final placement of the parameters from + // those modules in the layout. + // + m_requirements.add(module); + + // The speciialized component type will also have a dependency + // on all the file paths that any of the modules involved in + // it depend on (including those that are required but not + // yet linked in). + // + // The file path information is what a client would need to + // use to decide if kernel code is out of date compared to + // source files, so we want to include anything that could + // affect the validity of generated code. + // + for(auto path : module->getFilePathDependencies()) + { + if(filePathDependencySet.Contains(path)) + continue; + filePathDependencySet.Add(path); + m_filePathDependencies.add(path); + } + + // Finalyl we also add the module for the specialization arguments + // to the list of modules that would be used for legacy lookup + // operations where we need an implicit/default scope to use + // and want it to be expansive. + // + // TODO: This stuff really isn't worth keeping around long + // term, and we should ditch the entire "legacy lookup" idea. + // + m_moduleDependencies.add(module); + } + // The following is a bit of a hack. // + // TODO: We should not need this hack any longer, since the + // new approach to `switch`-based dynamic dispatch has made + // the existing tagged-union support obsolete. + // // Back-end code generation relies on us having computed layouts for all tagged // unions that end up being used in the code, which means we need a way to find // all such types that get used in a program (and the stuff it imports). @@ -3001,16 +3276,12 @@ void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, Spec Index SpecializedComponentType::getRequirementCount() { - // TODO: A specialized component type may have *more* requirements - // than the original, because it also needs to include the module(s) - // that define the types used for specialization arguments. - - return m_base->getRequirementCount(); + return m_requirements.getCount(); } RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) { - return m_base->getRequirement(index); + return m_requirements[index]; } String SpecializedComponentType::getEntryPointMangledName(Index index) |
