diff options
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 700 |
1 files changed, 598 insertions, 102 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b18e4d4d9..b030e5cf9 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -436,44 +436,31 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule( return asExternal(module); } -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createProgram( - slang::ProgramDesc const& desc, - slang::IProgram** outProgram) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL Linkage::createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics) { - RefPtr<Program> program = new Program(this); - - auto itemCount = desc.itemCount; - for(SlangInt ii = 0; ii < itemCount; ++ii) - { - auto& item = desc.items[ii]; - switch(item.kind) - { - case slang::ProgramDesc::Item::Kind::Program: - { - Program* existingProgram = asInternal(item.program); - for(auto referencedModule : existingProgram->getModuleDependencies()) - { - program->addReferencedLeafModule(referencedModule); - } - - // TODO: Need to decide whether to include the entry points as well... - } - break; + // Attempting to create a "composite" of just one component type should + // just return the component type itself, to avoid redundant work. + // + if( componentTypeCount == 1) + return componentTypes[0]; - case slang::ProgramDesc::Item::Kind::Module: - { - Module* module = asInternal(item.module); - program->addReferencedModule(module); - } - break; + DiagnosticSink sink(getSourceManager()); - default: - return SLANG_E_INVALID_ARG; - } + List<RefPtr<ComponentType>> childComponents; + for( Int cc = 0; cc < componentTypeCount; ++cc ) + { + childComponents.add(asInternal(componentTypes[cc])); } - *outProgram = asExternal(program.detach()); - return SLANG_OK; + RefPtr<ComponentType> composite = CompositeComponentType::create( + this, + childComponents); + + sink.getBlobIfNeeded(outDiagnostics); + return asExternal(composite.detach()); } SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( @@ -781,7 +768,9 @@ RefPtr<Type> checkProperType( TypeExp typeExp, DiagnosticSink* sink); -Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) +Type* ComponentType::getTypeFromString( + String const& typeStr, + DiagnosticSink* sink) { // If we've looked up this type name before, // then we can re-use it. @@ -803,7 +792,7 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) for(auto module : getModuleDependencies()) scopesToTry.add(module->getModuleDecl()->scope); - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); for(auto& s : scopesToTry) { RefPtr<Expr> typeExpr = linkage->parseTypeString( @@ -899,10 +888,16 @@ void FrontEndCompileRequest::parseTranslationUnit( } } -RefPtr<Program> createUnspecializedProgram( +RefPtr<ComponentType> createUnspecializedGlobalComponentType( + FrontEndCompileRequest* compileRequest); + +RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( FrontEndCompileRequest* compileRequest); -RefPtr<Program> createSpecializedProgram( +RefPtr<ComponentType> createSpecializedGlobalComponentType( + EndToEndCompileRequest* endToEndReq); + +RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType( EndToEndCompileRequest* endToEndReq); void FrontEndCompileRequest::checkAllTranslationUnits() @@ -1032,7 +1027,11 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // Look up all the entry points that are expected, // and use them to populate the `program` member. // - m_program = createUnspecializedProgram(this); + m_globalComponentType = createUnspecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_globalAndEntryPointsComponentType = createUnspecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1052,7 +1051,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // for(auto targetReq : getLinkage()->targets) { - auto targetProgram = m_program->getTargetProgram(targetReq); + auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1064,7 +1063,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() BackEndCompileRequest::BackEndCompileRequest( Linkage* linkage, DiagnosticSink* sink, - Program* program) + ComponentType* program) : CompileRequestBase(linkage, sink) , m_program(program) {} @@ -1146,7 +1145,8 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // that was computed in the front-end for all subsequent // reflection queries, etc. // - m_specializedProgram = getUnspecializedProgram(); + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = getUnspecializedGlobalAndEntryPointsComponentType(); return SLANG_OK; } @@ -1156,7 +1156,11 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // if (passThrough == PassThroughMode::None) { - m_specializedProgram = createSpecializedProgram(this); + m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_specializedGlobalAndEntryPointsComponentType = createSpecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1166,7 +1170,7 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // for (auto targetReq : getLinkage()->targets) { - auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + auto targetProgram = m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1178,20 +1182,27 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // to make sure that the logic in `generateOutput` // sees something worth processing. // - auto specializedProgram = new Program(getLinkage()); - m_specializedProgram = specializedProgram; + List<RefPtr<ComponentType>> dummyEntryPoints; for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough( + RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough( + getLinkage(), entryPointReq->getName(), entryPointReq->getProfile()); - specializedProgram->addEntryPoint(entryPoint, getSink()); + dummyEntryPoints.add(dummyEntryPoint); } + + RefPtr<ComponentType> composedProgram = CompositeComponentType::create( + getLinkage(), + dummyEntryPoints); + + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = composedProgram; } // Generate output code, in whatever format was requested - getBackEndReq()->setProgram(getSpecializedProgram()); + getBackEndReq()->setProgram(getSpecializedGlobalAndEntryPointsComponentType()); generateOutput(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1324,7 +1335,7 @@ int EndToEndCompileRequest::addEntryPoint( EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPointInfo.genericArgStrings.add(typeName); + entryPointInfo.specializationArgStrings.add(typeName); Index result = entryPoints.getCount(); entryPoints.add(_Move(entryPointInfo)); @@ -1617,8 +1628,10 @@ void FilePathDependencyList::addDependency(Module* module) // Module::Module(Linkage* linkage) - : m_linkage(linkage) -{} + : ComponentType(linkage) +{ + addModuleDependency(this); +} ISlangUnknown* Module::getInterface(const Guid& guid) { @@ -1638,35 +1651,40 @@ void Module::addFilePathDependency(String const& path) m_filePathDependencyList.addDependency(path); } -// Program +void Module::setModuleDecl(ModuleDecl* moduleDecl) +{ + m_moduleDecl = moduleDecl; +} + +// ComponentType -static const Guid IID_IProgram = SLANG_UUID_IProgram; +static const Guid IID_IComponentType = SLANG_UUID_IComponentType; -Program::Program(Linkage* linkage) +ComponentType::ComponentType(Linkage* linkage) : m_linkage(linkage) {} -ISlangUnknown* Program::getInterface(Guid const& guid) +ISlangUnknown* ComponentType::getInterface(Guid const& guid) { if(guid == IID_ISlangUnknown - || guid == IID_IProgram) + || guid == IID_IComponentType) { - return static_cast<slang::IProgram*>(this); + return static_cast<slang::IComponentType*>(this); } return nullptr; } -SLANG_NO_THROW slang::ISession* SLANG_MCALL Program::getSession() +SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession() { return m_linkage; } -SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( +SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL ComponentType::getLayout( Int targetIndex, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return nullptr; auto target = linkage->targets[targetIndex]; @@ -1678,13 +1696,13 @@ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( return asExternal(programLayout); } -SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( SlangInt entryPointIndex, Int targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return SLANG_E_INVALID_ARG; auto target = linkage->targets[targetIndex]; @@ -1702,57 +1720,535 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( return SLANG_OK; } +RefPtr<ComponentType> ComponentType::specialize( + SpecializationArg const* inSpecializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink) +{ + List<SpecializationArg> specializationArgs; + specializationArgs.addRange( + inSpecializationArgs, + specializationArgCount); + + // We next need to validate that the specialization arguments + // make sense, and also expand them to include any derived data + // (e.g., interface conformance witnesses) that doesn't get + // passed explicitly through the API interface. + // + RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs( + specializationArgs.getBuffer(), + specializationArgCount, + sink); + + return new SpecializedComponentType( + this, + specializationInfo, + specializationArgs, + sink); +} -void Program::addReferencedModule(Module* module) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL ComponentType::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) { - m_moduleDependencyList.addDependency(module); - m_filePathDependencyList.addDependency(module); + DiagnosticSink sink(getLinkage()->getSourceManager()); + + // First let's check if the number of arguments given matches + // the number of parameters that are present on this component type. + // + auto specializationParamCount = getSpecializationParamCount(); + if( specializationArgCount != specializationParamCount ) + { + // TODO: diagnose + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + + List<SpecializationArg> expandedArgs; + for( Int aa = 0; aa < specializationArgCount; ++aa ) + { + auto apiArg = specializationArgs[aa]; + + SpecializationArg expandedArg; + switch(apiArg.kind) + { + case slang::SpecializationArg::Kind::Type: + expandedArg.val = asInternal(apiArg.type); + break; + + default: + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + expandedArgs.add(expandedArg); + } + + auto specializedComponentType = specialize( + expandedArgs.getBuffer(), + expandedArgs.getCount(), + &sink); + + sink.getBlobIfNeeded(outDiagnostics); + + return specializedComponentType; } -void Program::addReferencedLeafModule(Module* module) + /// Visitor used by `ComponentType::enumerateModules` +struct EnumerateModulesVisitor : ComponentTypeVisitor { - m_moduleDependencyList.addLeafDependency(module); - m_filePathDependencyList.addDependency(module); + EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module, m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + + +void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData) +{ + EnumerateModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + + /// Visitor used by `ComponentType::enumerateIRModules` +struct EnumerateIRModulesVisitor : ComponentTypeVisitor +{ + EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateIRModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module->getIRModule(), m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + + m_callback(specialized->getIRModule(), m_userData); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + +void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) +{ + EnumerateIRModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + +// +// CompositeComponentType +// + +RefPtr<ComponentType> CompositeComponentType::create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) +{ + // TODO: We should ideally be caching the results of + // composition on the `linkage`, so that if we get + // asked for the same composite again later we re-use + // it rather than re-create it. + // + // Similarly, we might want to do some amount of + // work to "canonicalize" the input for composition. + // E.g., if the user does: + // + // X = compose(A,B); + // Y = compose(C,D); + // Z = compose(X,Y); + // + // W = compose(A, B, C, D); + // + // Then there is no observable difference between + // Z and W, so we might prefer to have them be identical. + + // If there is only a single child, then we should + // just return that child rather than create a dummy composite. + // + if( childComponents.getCount() == 1 ) + { + return childComponents[0]; + } + + return new CompositeComponentType(linkage, childComponents); +} + + +CompositeComponentType::CompositeComponentType( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) + : ComponentType(linkage) + , m_childComponents(childComponents) +{ + HashSet<ComponentType*> requirementsSet; + for(auto child : childComponents ) + { + child->enumerateModules([&](Module* module) + { + requirementsSet.Add(module); + }); + } + + for(auto child : childComponents ) + { + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + auto childShaderParamCount = child->getShaderParamCount(); + for(Index pp = 0; pp < childShaderParamCount; ++pp) + { + m_shaderParams.add(child->getShaderParam(pp)); + } + + auto childSpecializationParamCount = child->getSpecializationParamCount(); + for(Index pp = 0; pp < childSpecializationParamCount; ++pp) + { + m_specializationParams.add(child->getSpecializationParam(pp)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencyList.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) + { + m_filePathDependencyList.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } + } + } +} + +Index CompositeComponentType::getEntryPointCount() +{ + return m_entryPoints.getCount(); +} + +RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) +{ + return m_entryPoints[index]; +} + +Index CompositeComponentType::getShaderParamCount() +{ + return m_shaderParams.getCount(); } -void Program::addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) +GlobalShaderParamInfo CompositeComponentType::getShaderParam(Index index) { - List<RefPtr<EntryPoint>> entryPoints; - entryPoints.add(entryPoint); + return m_shaderParams[index]; +} + +Index CompositeComponentType::getSpecializationParamCount() +{ + return m_specializationParams.getCount(); +} + +SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) +{ + return m_specializationParams[index]; +} + +Index CompositeComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +List<Module*> const& CompositeComponentType::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} - RefPtr<EntryPointGroup> entryPointGroup = EntryPointGroup::create(getLinkageImpl(), entryPoints, sink); +List<String> const& CompositeComponentType::getFilePathDependencies() +{ + return m_filePathDependencyList.getFilePathList(); +} - addEntryPointGroup(entryPointGroup); +void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); } -void Program::addEntryPointGroup(EntryPointGroup* entryPointGroup) + +RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) { - m_entryPointGroups.add(entryPointGroup); + SLANG_UNUSED(argCount); + + RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo(); - for(auto entryPoint : entryPointGroup->getEntryPoints()) + Index offset = 0; + for(auto child : m_childComponents) { - m_entryPoints.add(entryPoint); - for(auto module : entryPoint->getModuleDependencies()) + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, + sink); + + specializationInfo->childInfos.add(childInfo); + + offset += childParamCount; + } + return specializationInfo; +} + +// +// SpecializedComponentType +// + +SpecializedComponentType::SpecializedComponentType( + ComponentType* base, + ComponentType::SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_specializationInfo(specializationInfo) + , m_specializationArgs(specializationArgs) +{ + m_irModule = generateIRForSpecializedComponentType(this, sink); + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so taht we don't need to have front-end + // code worrying about this stuff. + // + for(auto arg : specializationArgs) + { + auto argType = as<Type>(arg.val); + if(!argType) + continue; + + auto taggedUnionType = as<TaggedUnionType>(argType); + if(!taggedUnionType) + continue; + + m_taggedUnionTypes.add(taggedUnionType); + } +} + +void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + SLANG_ASSERT(specializationInfo == nullptr); + SLANG_UNUSED(specializationInfo); + visitor->visitSpecialized(this); +} + +Index SpecializedComponentType::getRequirementCount() +{ + // TODO: A specialized component type may have *more* requirements + // than the original, because it also needs to include the module(s) + // that define the types used for specialization arguments. + + return m_base->getRequirementCount(); +} + +RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) +{ + return m_base->getRequirement(index); +} + +// +// LegacyProgram +// + +LegacyProgram::LegacyProgram( + Linkage* linkage, + List<RefPtr<TranslationUnitRequest>> const& translationUnits, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_translationUnits(translationUnits) +{ + HashSet<ComponentType*> requirementsSet; + + for(auto translationUnit : translationUnits ) + { + ComponentType* child = translationUnit->getModule(); + + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencies.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) { - addReferencedModule(module); + m_fileDependencies.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } } } + + _collectShaderParams(sink); } -RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink) +void LegacyProgram::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { - if(!m_irModule) + visitor->visitLegacy(this, as<CompositeComponentType::CompositeSpecializationInfo>(specializationInfo)); +} + +RefPtr<ComponentType::SpecializationInfo> LegacyProgram::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(argCount); + + RefPtr<CompositeComponentType::CompositeSpecializationInfo> info = new CompositeComponentType::CompositeSpecializationInfo(); + + Index offset = 0; + for(auto translationUnit : m_translationUnits) { - m_irModule = generateIRForProgram( - m_linkage->getSessionImpl(), - this, + ComponentType* child = translationUnit->getModule(); + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, sink); + + info->childInfos.add(childInfo); + + offset += childParamCount; } - return m_irModule; + return info; } +Index LegacyProgram::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> LegacyProgram::getRequirement(Index index) +{ + return m_requirements[index]; +} + +void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = composite->getChildComponentCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto child = composite->getChildComponent(ii); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} + +void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) +{ + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); +} + +void ComponentTypeVisitor::visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = legacy->getTranslationUnitCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto translationUnit = legacy->getTranslationUnit(ii); + ComponentType* child = translationUnit->getModule(); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} -TargetProgram* Program::getTargetProgram(TargetRequest* target) +TargetProgram* ComponentType::getTargetProgram(TargetRequest* target) { RefPtr<TargetProgram> targetProgram; if(!m_targetPrograms.TryGetValue(target, targetProgram)) @@ -1768,12 +2264,12 @@ TargetProgram* Program::getTargetProgram(TargetRequest* target) // TargetProgram::TargetProgram( - Program* program, + ComponentType* componentType, TargetRequest* targetReq) - : m_program(program) + : m_program(componentType) , m_targetReq(targetReq) { - m_entryPointResults.setCount(program->getEntryPoints().getCount()); + m_entryPointResults.setCount(componentType->getEntryPointCount()); } // @@ -2458,10 +2954,10 @@ SLANG_API SlangResult spSetGlobalGenericArgs( if (!request) return SLANG_FAIL; auto req = convert(request); - auto& genericArgStrings = req->globalGenericArgStrings; - genericArgStrings.clear(); + auto& argStrings = req->globalSpecializationArgStrings; + argStrings.clear(); for (int i = 0; i < genericArgCount; i++) - genericArgStrings.add(genericArgs[i]); + argStrings.add(genericArgs[i]); return SLANG_OK; } @@ -2477,7 +2973,7 @@ SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam( if(!typeName) return SLANG_FAIL; auto req = convert(request); - auto& typeArgStrings = req->globalExistentialSlotArgStrings; + auto& typeArgStrings = req->globalSpecializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2501,7 +2997,7 @@ SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam( return SLANG_FAIL; auto& entryPointInfo = req->entryPoints[entryPointIndex]; - auto& typeArgStrings = entryPointInfo.existentialArgStrings; + auto& typeArgStrings = entryPointInfo.specializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2569,7 +3065,7 @@ spGetDependencyFileCount( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return (int) program->getFilePathDependencies().getCount(); } @@ -2583,7 +3079,7 @@ spGetDependencyFilePath( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return program->getFilePathDependencies()[index].begin(); } @@ -2613,7 +3109,7 @@ SLANG_API void const* spGetEntryPointCode( using namespace Slang; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // TODO: We should really accept a target index in this API Index targetIndex = 0; @@ -2668,7 +3164,7 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); Index targetCount = linkage->targets.getCount(); if((targetIndex < 0) || (targetIndex >= targetCount)) @@ -2715,13 +3211,13 @@ SLANG_API void const* spGetCompileRequestCode( SLANG_API SlangResult spCompileRequest_getProgram( SlangCompileRequest* request, - slang::IProgram** outProgram) + slang::IComponentType** outProgram) { if( !request ) return SLANG_ERROR_INVALID_PARAMETER; auto req = convert(request); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); - *outProgram = Slang::ComPtr<slang::IProgram>(program).detach(); + *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); return SLANG_OK; } @@ -2731,7 +3227,7 @@ SLANG_API SlangReflection* spGetReflection( if( !request ) return 0; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection |
