diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2019-08-08 11:22:32 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-08-08 11:22:32 -0700 |
| commit | 2552217b76c0bd83e18fceba1d35a367bf569eca (patch) | |
| tree | 0651175e4601af75bc18687c853068f013e6c1b9 /source/slang/slang.cpp | |
| parent | 81ce78d08a7e3fbe74f2fd41c5a258ea4b078245 (diff) | |
Revise new COM-lite API (#1007)
* Revise new COM-lite API
This change revises the "COM-lite" API that was recently introduced to try to streamline it and introduce some missing central/base concepts.
The central new abstraction in the API is the notion of a "component type," which is a unit of shader code composition. A component type can have:
* IR code for some number of functions/types/etc.
* Zero or more global shader parameters
* Zero or more "entry point" functions at which execution can start
* Zero or more "specialization" parameters (types or values that must be filled in before kernel code can be generated)
* Zero or more "requirements" (dependencies on other component types that must be satisfied before kernel code can be generated)
Both individual compiled modules, and validated entry points are then examples of component types, and we additionally define a few services that apply to all component types:
* We can take N component types and compose them to create a new component type that combines their code, shader parameters, entry points, and specialization parameters. A composed component type may also include requirements from the sub-component types, but it is also possible that by composing thing we satisfy requirements (if `A` requires `B`, and we compose `A` and `B`, then the requirement is now satisfied, and doesn't appear on the composite).
* We can take a component type with N specialization parameters, and specialize it by giving N compatible specialization arguments. The result of specialization is a new component type with zero specialization parameters. Under the right circumstances the specialzed component type will be layout compatible with the unspecialized one.
* One more example that isn't exposed in the public API today is that we can take a component with requirements and "complete" it by automatically composing it with component types that satisfy those requirements. This can be seen as a kind of linking step that pulls together the transitive closure of dependencies.
* We can query the layout for the shader parameters and entry points of a component type, for a specific target.
* We can query compiled kernel code for an entry point in a component type (for a specific target). This only works for component types with zero specialization parameters and zero requirements.
The idea is that by giving users a fairly general algebra of operations on component types, they can compose final programs in ways that meet their requirements. For example, it becomes possible to incrementally "grow" a component type to represent the global root signature for ray tracing shaders as new entry points are added, in such a way that it always stays layout-compatible with kernels that have already been compiled.
Much of the implementation work here is in implementing the unifying component type abstraction, and in particular re-writing code that used to assume a program consisted of a flat list of modules and entry points to work with a hierarchical representation that reflects the underlying algebra (e.g., with types to represent composite and specialized component types).
There's also a hidden "legacy" case of a component type to deal with some legacy compiler behaviors that can't be directly modeled on top of the simple algebra with modules and entry points.
This API is by no means feature-complete or fully developed. It is expected that we will flesh it out more when bringing up application code (e.g., Falcor) on top of the revamped API.
One notable thing that went away in this change is explicit support for "entry point groups" and notions of local root signatures (especially the Falcor-specific handling of the `shared` keyword, which a previous change turned into an explicitly supported feature). With the new "building blocks" approach, it should be possible for a DXR application to deal with local root signatures as a matter of policy (on top of the API we provide). If/when we need to provide some kind of emulation of local root signatures for Vulkan (and/or if Vulkan is extended with an explicit notion of local root signatures), we might need to revisit this choice.
* Fix debug build
There was invalid code inside an `assert()`, so the release build didn't catch it.
* fixup: warnings
* fixup: more warnings-as-errors
* fixup: review notes
* fixup: use component type visitors in place of dynamic casting
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 700 |
1 files changed, 598 insertions, 102 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b18e4d4d9..b030e5cf9 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -436,44 +436,31 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule( return asExternal(module); } -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createProgram( - slang::ProgramDesc const& desc, - slang::IProgram** outProgram) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL Linkage::createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics) { - RefPtr<Program> program = new Program(this); - - auto itemCount = desc.itemCount; - for(SlangInt ii = 0; ii < itemCount; ++ii) - { - auto& item = desc.items[ii]; - switch(item.kind) - { - case slang::ProgramDesc::Item::Kind::Program: - { - Program* existingProgram = asInternal(item.program); - for(auto referencedModule : existingProgram->getModuleDependencies()) - { - program->addReferencedLeafModule(referencedModule); - } - - // TODO: Need to decide whether to include the entry points as well... - } - break; + // Attempting to create a "composite" of just one component type should + // just return the component type itself, to avoid redundant work. + // + if( componentTypeCount == 1) + return componentTypes[0]; - case slang::ProgramDesc::Item::Kind::Module: - { - Module* module = asInternal(item.module); - program->addReferencedModule(module); - } - break; + DiagnosticSink sink(getSourceManager()); - default: - return SLANG_E_INVALID_ARG; - } + List<RefPtr<ComponentType>> childComponents; + for( Int cc = 0; cc < componentTypeCount; ++cc ) + { + childComponents.add(asInternal(componentTypes[cc])); } - *outProgram = asExternal(program.detach()); - return SLANG_OK; + RefPtr<ComponentType> composite = CompositeComponentType::create( + this, + childComponents); + + sink.getBlobIfNeeded(outDiagnostics); + return asExternal(composite.detach()); } SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( @@ -781,7 +768,9 @@ RefPtr<Type> checkProperType( TypeExp typeExp, DiagnosticSink* sink); -Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) +Type* ComponentType::getTypeFromString( + String const& typeStr, + DiagnosticSink* sink) { // If we've looked up this type name before, // then we can re-use it. @@ -803,7 +792,7 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) for(auto module : getModuleDependencies()) scopesToTry.add(module->getModuleDecl()->scope); - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); for(auto& s : scopesToTry) { RefPtr<Expr> typeExpr = linkage->parseTypeString( @@ -899,10 +888,16 @@ void FrontEndCompileRequest::parseTranslationUnit( } } -RefPtr<Program> createUnspecializedProgram( +RefPtr<ComponentType> createUnspecializedGlobalComponentType( + FrontEndCompileRequest* compileRequest); + +RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( FrontEndCompileRequest* compileRequest); -RefPtr<Program> createSpecializedProgram( +RefPtr<ComponentType> createSpecializedGlobalComponentType( + EndToEndCompileRequest* endToEndReq); + +RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType( EndToEndCompileRequest* endToEndReq); void FrontEndCompileRequest::checkAllTranslationUnits() @@ -1032,7 +1027,11 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // Look up all the entry points that are expected, // and use them to populate the `program` member. // - m_program = createUnspecializedProgram(this); + m_globalComponentType = createUnspecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_globalAndEntryPointsComponentType = createUnspecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1052,7 +1051,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // for(auto targetReq : getLinkage()->targets) { - auto targetProgram = m_program->getTargetProgram(targetReq); + auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1064,7 +1063,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() BackEndCompileRequest::BackEndCompileRequest( Linkage* linkage, DiagnosticSink* sink, - Program* program) + ComponentType* program) : CompileRequestBase(linkage, sink) , m_program(program) {} @@ -1146,7 +1145,8 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // that was computed in the front-end for all subsequent // reflection queries, etc. // - m_specializedProgram = getUnspecializedProgram(); + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = getUnspecializedGlobalAndEntryPointsComponentType(); return SLANG_OK; } @@ -1156,7 +1156,11 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // if (passThrough == PassThroughMode::None) { - m_specializedProgram = createSpecializedProgram(this); + m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_specializedGlobalAndEntryPointsComponentType = createSpecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1166,7 +1170,7 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // for (auto targetReq : getLinkage()->targets) { - auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + auto targetProgram = m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1178,20 +1182,27 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // to make sure that the logic in `generateOutput` // sees something worth processing. // - auto specializedProgram = new Program(getLinkage()); - m_specializedProgram = specializedProgram; + List<RefPtr<ComponentType>> dummyEntryPoints; for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough( + RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough( + getLinkage(), entryPointReq->getName(), entryPointReq->getProfile()); - specializedProgram->addEntryPoint(entryPoint, getSink()); + dummyEntryPoints.add(dummyEntryPoint); } + + RefPtr<ComponentType> composedProgram = CompositeComponentType::create( + getLinkage(), + dummyEntryPoints); + + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = composedProgram; } // Generate output code, in whatever format was requested - getBackEndReq()->setProgram(getSpecializedProgram()); + getBackEndReq()->setProgram(getSpecializedGlobalAndEntryPointsComponentType()); generateOutput(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1324,7 +1335,7 @@ int EndToEndCompileRequest::addEntryPoint( EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPointInfo.genericArgStrings.add(typeName); + entryPointInfo.specializationArgStrings.add(typeName); Index result = entryPoints.getCount(); entryPoints.add(_Move(entryPointInfo)); @@ -1617,8 +1628,10 @@ void FilePathDependencyList::addDependency(Module* module) // Module::Module(Linkage* linkage) - : m_linkage(linkage) -{} + : ComponentType(linkage) +{ + addModuleDependency(this); +} ISlangUnknown* Module::getInterface(const Guid& guid) { @@ -1638,35 +1651,40 @@ void Module::addFilePathDependency(String const& path) m_filePathDependencyList.addDependency(path); } -// Program +void Module::setModuleDecl(ModuleDecl* moduleDecl) +{ + m_moduleDecl = moduleDecl; +} + +// ComponentType -static const Guid IID_IProgram = SLANG_UUID_IProgram; +static const Guid IID_IComponentType = SLANG_UUID_IComponentType; -Program::Program(Linkage* linkage) +ComponentType::ComponentType(Linkage* linkage) : m_linkage(linkage) {} -ISlangUnknown* Program::getInterface(Guid const& guid) +ISlangUnknown* ComponentType::getInterface(Guid const& guid) { if(guid == IID_ISlangUnknown - || guid == IID_IProgram) + || guid == IID_IComponentType) { - return static_cast<slang::IProgram*>(this); + return static_cast<slang::IComponentType*>(this); } return nullptr; } -SLANG_NO_THROW slang::ISession* SLANG_MCALL Program::getSession() +SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession() { return m_linkage; } -SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( +SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL ComponentType::getLayout( Int targetIndex, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return nullptr; auto target = linkage->targets[targetIndex]; @@ -1678,13 +1696,13 @@ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( return asExternal(programLayout); } -SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( SlangInt entryPointIndex, Int targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return SLANG_E_INVALID_ARG; auto target = linkage->targets[targetIndex]; @@ -1702,57 +1720,535 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( return SLANG_OK; } +RefPtr<ComponentType> ComponentType::specialize( + SpecializationArg const* inSpecializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink) +{ + List<SpecializationArg> specializationArgs; + specializationArgs.addRange( + inSpecializationArgs, + specializationArgCount); + + // We next need to validate that the specialization arguments + // make sense, and also expand them to include any derived data + // (e.g., interface conformance witnesses) that doesn't get + // passed explicitly through the API interface. + // + RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs( + specializationArgs.getBuffer(), + specializationArgCount, + sink); + + return new SpecializedComponentType( + this, + specializationInfo, + specializationArgs, + sink); +} -void Program::addReferencedModule(Module* module) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL ComponentType::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) { - m_moduleDependencyList.addDependency(module); - m_filePathDependencyList.addDependency(module); + DiagnosticSink sink(getLinkage()->getSourceManager()); + + // First let's check if the number of arguments given matches + // the number of parameters that are present on this component type. + // + auto specializationParamCount = getSpecializationParamCount(); + if( specializationArgCount != specializationParamCount ) + { + // TODO: diagnose + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + + List<SpecializationArg> expandedArgs; + for( Int aa = 0; aa < specializationArgCount; ++aa ) + { + auto apiArg = specializationArgs[aa]; + + SpecializationArg expandedArg; + switch(apiArg.kind) + { + case slang::SpecializationArg::Kind::Type: + expandedArg.val = asInternal(apiArg.type); + break; + + default: + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + expandedArgs.add(expandedArg); + } + + auto specializedComponentType = specialize( + expandedArgs.getBuffer(), + expandedArgs.getCount(), + &sink); + + sink.getBlobIfNeeded(outDiagnostics); + + return specializedComponentType; } -void Program::addReferencedLeafModule(Module* module) + /// Visitor used by `ComponentType::enumerateModules` +struct EnumerateModulesVisitor : ComponentTypeVisitor { - m_moduleDependencyList.addLeafDependency(module); - m_filePathDependencyList.addDependency(module); + EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module, m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + + +void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData) +{ + EnumerateModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + + /// Visitor used by `ComponentType::enumerateIRModules` +struct EnumerateIRModulesVisitor : ComponentTypeVisitor +{ + EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateIRModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module->getIRModule(), m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + + m_callback(specialized->getIRModule(), m_userData); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + +void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) +{ + EnumerateIRModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + +// +// CompositeComponentType +// + +RefPtr<ComponentType> CompositeComponentType::create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) +{ + // TODO: We should ideally be caching the results of + // composition on the `linkage`, so that if we get + // asked for the same composite again later we re-use + // it rather than re-create it. + // + // Similarly, we might want to do some amount of + // work to "canonicalize" the input for composition. + // E.g., if the user does: + // + // X = compose(A,B); + // Y = compose(C,D); + // Z = compose(X,Y); + // + // W = compose(A, B, C, D); + // + // Then there is no observable difference between + // Z and W, so we might prefer to have them be identical. + + // If there is only a single child, then we should + // just return that child rather than create a dummy composite. + // + if( childComponents.getCount() == 1 ) + { + return childComponents[0]; + } + + return new CompositeComponentType(linkage, childComponents); +} + + +CompositeComponentType::CompositeComponentType( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) + : ComponentType(linkage) + , m_childComponents(childComponents) +{ + HashSet<ComponentType*> requirementsSet; + for(auto child : childComponents ) + { + child->enumerateModules([&](Module* module) + { + requirementsSet.Add(module); + }); + } + + for(auto child : childComponents ) + { + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + auto childShaderParamCount = child->getShaderParamCount(); + for(Index pp = 0; pp < childShaderParamCount; ++pp) + { + m_shaderParams.add(child->getShaderParam(pp)); + } + + auto childSpecializationParamCount = child->getSpecializationParamCount(); + for(Index pp = 0; pp < childSpecializationParamCount; ++pp) + { + m_specializationParams.add(child->getSpecializationParam(pp)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencyList.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) + { + m_filePathDependencyList.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } + } + } +} + +Index CompositeComponentType::getEntryPointCount() +{ + return m_entryPoints.getCount(); +} + +RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) +{ + return m_entryPoints[index]; +} + +Index CompositeComponentType::getShaderParamCount() +{ + return m_shaderParams.getCount(); } -void Program::addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) +GlobalShaderParamInfo CompositeComponentType::getShaderParam(Index index) { - List<RefPtr<EntryPoint>> entryPoints; - entryPoints.add(entryPoint); + return m_shaderParams[index]; +} + +Index CompositeComponentType::getSpecializationParamCount() +{ + return m_specializationParams.getCount(); +} + +SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) +{ + return m_specializationParams[index]; +} + +Index CompositeComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +List<Module*> const& CompositeComponentType::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} - RefPtr<EntryPointGroup> entryPointGroup = EntryPointGroup::create(getLinkageImpl(), entryPoints, sink); +List<String> const& CompositeComponentType::getFilePathDependencies() +{ + return m_filePathDependencyList.getFilePathList(); +} - addEntryPointGroup(entryPointGroup); +void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); } -void Program::addEntryPointGroup(EntryPointGroup* entryPointGroup) + +RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) { - m_entryPointGroups.add(entryPointGroup); + SLANG_UNUSED(argCount); + + RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo(); - for(auto entryPoint : entryPointGroup->getEntryPoints()) + Index offset = 0; + for(auto child : m_childComponents) { - m_entryPoints.add(entryPoint); - for(auto module : entryPoint->getModuleDependencies()) + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, + sink); + + specializationInfo->childInfos.add(childInfo); + + offset += childParamCount; + } + return specializationInfo; +} + +// +// SpecializedComponentType +// + +SpecializedComponentType::SpecializedComponentType( + ComponentType* base, + ComponentType::SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_specializationInfo(specializationInfo) + , m_specializationArgs(specializationArgs) +{ + m_irModule = generateIRForSpecializedComponentType(this, sink); + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so taht we don't need to have front-end + // code worrying about this stuff. + // + for(auto arg : specializationArgs) + { + auto argType = as<Type>(arg.val); + if(!argType) + continue; + + auto taggedUnionType = as<TaggedUnionType>(argType); + if(!taggedUnionType) + continue; + + m_taggedUnionTypes.add(taggedUnionType); + } +} + +void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + SLANG_ASSERT(specializationInfo == nullptr); + SLANG_UNUSED(specializationInfo); + visitor->visitSpecialized(this); +} + +Index SpecializedComponentType::getRequirementCount() +{ + // TODO: A specialized component type may have *more* requirements + // than the original, because it also needs to include the module(s) + // that define the types used for specialization arguments. + + return m_base->getRequirementCount(); +} + +RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) +{ + return m_base->getRequirement(index); +} + +// +// LegacyProgram +// + +LegacyProgram::LegacyProgram( + Linkage* linkage, + List<RefPtr<TranslationUnitRequest>> const& translationUnits, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_translationUnits(translationUnits) +{ + HashSet<ComponentType*> requirementsSet; + + for(auto translationUnit : translationUnits ) + { + ComponentType* child = translationUnit->getModule(); + + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencies.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) { - addReferencedModule(module); + m_fileDependencies.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } } } + + _collectShaderParams(sink); } -RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink) +void LegacyProgram::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { - if(!m_irModule) + visitor->visitLegacy(this, as<CompositeComponentType::CompositeSpecializationInfo>(specializationInfo)); +} + +RefPtr<ComponentType::SpecializationInfo> LegacyProgram::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(argCount); + + RefPtr<CompositeComponentType::CompositeSpecializationInfo> info = new CompositeComponentType::CompositeSpecializationInfo(); + + Index offset = 0; + for(auto translationUnit : m_translationUnits) { - m_irModule = generateIRForProgram( - m_linkage->getSessionImpl(), - this, + ComponentType* child = translationUnit->getModule(); + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, sink); + + info->childInfos.add(childInfo); + + offset += childParamCount; } - return m_irModule; + return info; } +Index LegacyProgram::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> LegacyProgram::getRequirement(Index index) +{ + return m_requirements[index]; +} + +void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = composite->getChildComponentCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto child = composite->getChildComponent(ii); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} + +void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) +{ + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); +} + +void ComponentTypeVisitor::visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = legacy->getTranslationUnitCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto translationUnit = legacy->getTranslationUnit(ii); + ComponentType* child = translationUnit->getModule(); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} -TargetProgram* Program::getTargetProgram(TargetRequest* target) +TargetProgram* ComponentType::getTargetProgram(TargetRequest* target) { RefPtr<TargetProgram> targetProgram; if(!m_targetPrograms.TryGetValue(target, targetProgram)) @@ -1768,12 +2264,12 @@ TargetProgram* Program::getTargetProgram(TargetRequest* target) // TargetProgram::TargetProgram( - Program* program, + ComponentType* componentType, TargetRequest* targetReq) - : m_program(program) + : m_program(componentType) , m_targetReq(targetReq) { - m_entryPointResults.setCount(program->getEntryPoints().getCount()); + m_entryPointResults.setCount(componentType->getEntryPointCount()); } // @@ -2458,10 +2954,10 @@ SLANG_API SlangResult spSetGlobalGenericArgs( if (!request) return SLANG_FAIL; auto req = convert(request); - auto& genericArgStrings = req->globalGenericArgStrings; - genericArgStrings.clear(); + auto& argStrings = req->globalSpecializationArgStrings; + argStrings.clear(); for (int i = 0; i < genericArgCount; i++) - genericArgStrings.add(genericArgs[i]); + argStrings.add(genericArgs[i]); return SLANG_OK; } @@ -2477,7 +2973,7 @@ SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam( if(!typeName) return SLANG_FAIL; auto req = convert(request); - auto& typeArgStrings = req->globalExistentialSlotArgStrings; + auto& typeArgStrings = req->globalSpecializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2501,7 +2997,7 @@ SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam( return SLANG_FAIL; auto& entryPointInfo = req->entryPoints[entryPointIndex]; - auto& typeArgStrings = entryPointInfo.existentialArgStrings; + auto& typeArgStrings = entryPointInfo.specializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2569,7 +3065,7 @@ spGetDependencyFileCount( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return (int) program->getFilePathDependencies().getCount(); } @@ -2583,7 +3079,7 @@ spGetDependencyFilePath( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return program->getFilePathDependencies()[index].begin(); } @@ -2613,7 +3109,7 @@ SLANG_API void const* spGetEntryPointCode( using namespace Slang; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // TODO: We should really accept a target index in this API Index targetIndex = 0; @@ -2668,7 +3164,7 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); Index targetCount = linkage->targets.getCount(); if((targetIndex < 0) || (targetIndex >= targetCount)) @@ -2715,13 +3211,13 @@ SLANG_API void const* spGetCompileRequestCode( SLANG_API SlangResult spCompileRequest_getProgram( SlangCompileRequest* request, - slang::IProgram** outProgram) + slang::IComponentType** outProgram) { if( !request ) return SLANG_ERROR_INVALID_PARAMETER; auto req = convert(request); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); - *outProgram = Slang::ComPtr<slang::IProgram>(program).detach(); + *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); return SLANG_OK; } @@ -2731,7 +3227,7 @@ SLANG_API SlangReflection* spGetReflection( if( !request ) return 0; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection |
