summaryrefslogtreecommitdiffstats
path: root/source/slang/slang.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang.cpp')
-rw-r--r--source/slang/slang.cpp700
1 files changed, 598 insertions, 102 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index b18e4d4d9..b030e5cf9 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -436,44 +436,31 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule(
return asExternal(module);
}
-SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createProgram(
- slang::ProgramDesc const& desc,
- slang::IProgram** outProgram)
+SLANG_NO_THROW slang::IComponentType* SLANG_MCALL Linkage::createCompositeComponentType(
+ slang::IComponentType* const* componentTypes,
+ SlangInt componentTypeCount,
+ ISlangBlob** outDiagnostics)
{
- RefPtr<Program> program = new Program(this);
-
- auto itemCount = desc.itemCount;
- for(SlangInt ii = 0; ii < itemCount; ++ii)
- {
- auto& item = desc.items[ii];
- switch(item.kind)
- {
- case slang::ProgramDesc::Item::Kind::Program:
- {
- Program* existingProgram = asInternal(item.program);
- for(auto referencedModule : existingProgram->getModuleDependencies())
- {
- program->addReferencedLeafModule(referencedModule);
- }
-
- // TODO: Need to decide whether to include the entry points as well...
- }
- break;
+ // Attempting to create a "composite" of just one component type should
+ // just return the component type itself, to avoid redundant work.
+ //
+ if( componentTypeCount == 1)
+ return componentTypes[0];
- case slang::ProgramDesc::Item::Kind::Module:
- {
- Module* module = asInternal(item.module);
- program->addReferencedModule(module);
- }
- break;
+ DiagnosticSink sink(getSourceManager());
- default:
- return SLANG_E_INVALID_ARG;
- }
+ List<RefPtr<ComponentType>> childComponents;
+ for( Int cc = 0; cc < componentTypeCount; ++cc )
+ {
+ childComponents.add(asInternal(componentTypes[cc]));
}
- *outProgram = asExternal(program.detach());
- return SLANG_OK;
+ RefPtr<ComponentType> composite = CompositeComponentType::create(
+ this,
+ childComponents);
+
+ sink.getBlobIfNeeded(outDiagnostics);
+ return asExternal(composite.detach());
}
SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
@@ -781,7 +768,9 @@ RefPtr<Type> checkProperType(
TypeExp typeExp,
DiagnosticSink* sink);
-Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink)
+Type* ComponentType::getTypeFromString(
+ String const& typeStr,
+ DiagnosticSink* sink)
{
// If we've looked up this type name before,
// then we can re-use it.
@@ -803,7 +792,7 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink)
for(auto module : getModuleDependencies())
scopesToTry.add(module->getModuleDecl()->scope);
- auto linkage = getLinkageImpl();
+ auto linkage = getLinkage();
for(auto& s : scopesToTry)
{
RefPtr<Expr> typeExpr = linkage->parseTypeString(
@@ -899,10 +888,16 @@ void FrontEndCompileRequest::parseTranslationUnit(
}
}
-RefPtr<Program> createUnspecializedProgram(
+RefPtr<ComponentType> createUnspecializedGlobalComponentType(
+ FrontEndCompileRequest* compileRequest);
+
+RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
FrontEndCompileRequest* compileRequest);
-RefPtr<Program> createSpecializedProgram(
+RefPtr<ComponentType> createSpecializedGlobalComponentType(
+ EndToEndCompileRequest* endToEndReq);
+
+RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType(
EndToEndCompileRequest* endToEndReq);
void FrontEndCompileRequest::checkAllTranslationUnits()
@@ -1032,7 +1027,11 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
// Look up all the entry points that are expected,
// and use them to populate the `program` member.
//
- m_program = createUnspecializedProgram(this);
+ m_globalComponentType = createUnspecializedGlobalComponentType(this);
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ m_globalAndEntryPointsComponentType = createUnspecializedGlobalAndEntryPointsComponentType(this);
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -1052,7 +1051,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
//
for(auto targetReq : getLinkage()->targets)
{
- auto targetProgram = m_program->getTargetProgram(targetReq);
+ auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq);
targetProgram->getOrCreateLayout(getSink());
}
if (getSink()->GetErrorCount() != 0)
@@ -1064,7 +1063,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
BackEndCompileRequest::BackEndCompileRequest(
Linkage* linkage,
DiagnosticSink* sink,
- Program* program)
+ ComponentType* program)
: CompileRequestBase(linkage, sink)
, m_program(program)
{}
@@ -1146,7 +1145,8 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
// that was computed in the front-end for all subsequent
// reflection queries, etc.
//
- m_specializedProgram = getUnspecializedProgram();
+ m_specializedGlobalComponentType = getUnspecializedGlobalComponentType();
+ m_specializedGlobalAndEntryPointsComponentType = getUnspecializedGlobalAndEntryPointsComponentType();
return SLANG_OK;
}
@@ -1156,7 +1156,11 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
//
if (passThrough == PassThroughMode::None)
{
- m_specializedProgram = createSpecializedProgram(this);
+ m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this);
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ m_specializedGlobalAndEntryPointsComponentType = createSpecializedGlobalAndEntryPointsComponentType(this);
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -1166,7 +1170,7 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
//
for (auto targetReq : getLinkage()->targets)
{
- auto targetProgram = m_specializedProgram->getTargetProgram(targetReq);
+ auto targetProgram = m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq);
targetProgram->getOrCreateLayout(getSink());
}
if (getSink()->GetErrorCount() != 0)
@@ -1178,20 +1182,27 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
// to make sure that the logic in `generateOutput`
// sees something worth processing.
//
- auto specializedProgram = new Program(getLinkage());
- m_specializedProgram = specializedProgram;
+ List<RefPtr<ComponentType>> dummyEntryPoints;
for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs())
{
- RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough(
+ RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough(
+ getLinkage(),
entryPointReq->getName(),
entryPointReq->getProfile());
- specializedProgram->addEntryPoint(entryPoint, getSink());
+ dummyEntryPoints.add(dummyEntryPoint);
}
+
+ RefPtr<ComponentType> composedProgram = CompositeComponentType::create(
+ getLinkage(),
+ dummyEntryPoints);
+
+ m_specializedGlobalComponentType = getUnspecializedGlobalComponentType();
+ m_specializedGlobalAndEntryPointsComponentType = composedProgram;
}
// Generate output code, in whatever format was requested
- getBackEndReq()->setProgram(getSpecializedProgram());
+ getBackEndReq()->setProgram(getSpecializedGlobalAndEntryPointsComponentType());
generateOutput(this);
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -1324,7 +1335,7 @@ int EndToEndCompileRequest::addEntryPoint(
EntryPointInfo entryPointInfo;
for (auto typeName : genericTypeNames)
- entryPointInfo.genericArgStrings.add(typeName);
+ entryPointInfo.specializationArgStrings.add(typeName);
Index result = entryPoints.getCount();
entryPoints.add(_Move(entryPointInfo));
@@ -1617,8 +1628,10 @@ void FilePathDependencyList::addDependency(Module* module)
//
Module::Module(Linkage* linkage)
- : m_linkage(linkage)
-{}
+ : ComponentType(linkage)
+{
+ addModuleDependency(this);
+}
ISlangUnknown* Module::getInterface(const Guid& guid)
{
@@ -1638,35 +1651,40 @@ void Module::addFilePathDependency(String const& path)
m_filePathDependencyList.addDependency(path);
}
-// Program
+void Module::setModuleDecl(ModuleDecl* moduleDecl)
+{
+ m_moduleDecl = moduleDecl;
+}
+
+// ComponentType
-static const Guid IID_IProgram = SLANG_UUID_IProgram;
+static const Guid IID_IComponentType = SLANG_UUID_IComponentType;
-Program::Program(Linkage* linkage)
+ComponentType::ComponentType(Linkage* linkage)
: m_linkage(linkage)
{}
-ISlangUnknown* Program::getInterface(Guid const& guid)
+ISlangUnknown* ComponentType::getInterface(Guid const& guid)
{
if(guid == IID_ISlangUnknown
- || guid == IID_IProgram)
+ || guid == IID_IComponentType)
{
- return static_cast<slang::IProgram*>(this);
+ return static_cast<slang::IComponentType*>(this);
}
return nullptr;
}
-SLANG_NO_THROW slang::ISession* SLANG_MCALL Program::getSession()
+SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession()
{
return m_linkage;
}
-SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout(
+SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL ComponentType::getLayout(
Int targetIndex,
slang::IBlob** outDiagnostics)
{
- auto linkage = getLinkageImpl();
+ auto linkage = getLinkage();
if(targetIndex < 0 || targetIndex >= linkage->targets.getCount())
return nullptr;
auto target = linkage->targets[targetIndex];
@@ -1678,13 +1696,13 @@ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout(
return asExternal(programLayout);
}
-SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode(
+SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode(
SlangInt entryPointIndex,
Int targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
- auto linkage = getLinkageImpl();
+ auto linkage = getLinkage();
if(targetIndex < 0 || targetIndex >= linkage->targets.getCount())
return SLANG_E_INVALID_ARG;
auto target = linkage->targets[targetIndex];
@@ -1702,57 +1720,535 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode(
return SLANG_OK;
}
+RefPtr<ComponentType> ComponentType::specialize(
+ SpecializationArg const* inSpecializationArgs,
+ SlangInt specializationArgCount,
+ DiagnosticSink* sink)
+{
+ List<SpecializationArg> specializationArgs;
+ specializationArgs.addRange(
+ inSpecializationArgs,
+ specializationArgCount);
+
+ // We next need to validate that the specialization arguments
+ // make sense, and also expand them to include any derived data
+ // (e.g., interface conformance witnesses) that doesn't get
+ // passed explicitly through the API interface.
+ //
+ RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs(
+ specializationArgs.getBuffer(),
+ specializationArgCount,
+ sink);
+
+ return new SpecializedComponentType(
+ this,
+ specializationInfo,
+ specializationArgs,
+ sink);
+}
-void Program::addReferencedModule(Module* module)
+SLANG_NO_THROW slang::IComponentType* SLANG_MCALL ComponentType::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ ISlangBlob** outDiagnostics)
{
- m_moduleDependencyList.addDependency(module);
- m_filePathDependencyList.addDependency(module);
+ DiagnosticSink sink(getLinkage()->getSourceManager());
+
+ // First let's check if the number of arguments given matches
+ // the number of parameters that are present on this component type.
+ //
+ auto specializationParamCount = getSpecializationParamCount();
+ if( specializationArgCount != specializationParamCount )
+ {
+ // TODO: diagnose
+ sink.getBlobIfNeeded(outDiagnostics);
+ return nullptr;
+ }
+
+ List<SpecializationArg> expandedArgs;
+ for( Int aa = 0; aa < specializationArgCount; ++aa )
+ {
+ auto apiArg = specializationArgs[aa];
+
+ SpecializationArg expandedArg;
+ switch(apiArg.kind)
+ {
+ case slang::SpecializationArg::Kind::Type:
+ expandedArg.val = asInternal(apiArg.type);
+ break;
+
+ default:
+ sink.getBlobIfNeeded(outDiagnostics);
+ return nullptr;
+ }
+ expandedArgs.add(expandedArg);
+ }
+
+ auto specializedComponentType = specialize(
+ expandedArgs.getBuffer(),
+ expandedArgs.getCount(),
+ &sink);
+
+ sink.getBlobIfNeeded(outDiagnostics);
+
+ return specializedComponentType;
}
-void Program::addReferencedLeafModule(Module* module)
+ /// Visitor used by `ComponentType::enumerateModules`
+struct EnumerateModulesVisitor : ComponentTypeVisitor
{
- m_moduleDependencyList.addLeafDependency(module);
- m_filePathDependencyList.addDependency(module);
+ EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData)
+ : m_callback(callback)
+ , m_userData(userData)
+ {}
+
+ ComponentType::EnumerateModulesCallback m_callback;
+ void* m_userData;
+
+ void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {}
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
+ {
+ m_callback(module, m_userData);
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ visitChildren(specialized);
+ }
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(legacy, specializationInfo);
+ }
+};
+
+
+void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData)
+{
+ EnumerateModulesVisitor visitor(callback, userData);
+ acceptVisitor(&visitor, nullptr);
+}
+
+ /// Visitor used by `ComponentType::enumerateIRModules`
+struct EnumerateIRModulesVisitor : ComponentTypeVisitor
+{
+ EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData)
+ : m_callback(callback)
+ , m_userData(userData)
+ {}
+
+ ComponentType::EnumerateIRModulesCallback m_callback;
+ void* m_userData;
+
+ void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {}
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
+ {
+ m_callback(module->getIRModule(), m_userData);
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ visitChildren(specialized);
+
+ m_callback(specialized->getIRModule(), m_userData);
+ }
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(legacy, specializationInfo);
+ }
+};
+
+void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData)
+{
+ EnumerateIRModulesVisitor visitor(callback, userData);
+ acceptVisitor(&visitor, nullptr);
+}
+
+//
+// CompositeComponentType
+//
+
+RefPtr<ComponentType> CompositeComponentType::create(
+ Linkage* linkage,
+ List<RefPtr<ComponentType>> const& childComponents)
+{
+ // TODO: We should ideally be caching the results of
+ // composition on the `linkage`, so that if we get
+ // asked for the same composite again later we re-use
+ // it rather than re-create it.
+ //
+ // Similarly, we might want to do some amount of
+ // work to "canonicalize" the input for composition.
+ // E.g., if the user does:
+ //
+ // X = compose(A,B);
+ // Y = compose(C,D);
+ // Z = compose(X,Y);
+ //
+ // W = compose(A, B, C, D);
+ //
+ // Then there is no observable difference between
+ // Z and W, so we might prefer to have them be identical.
+
+ // If there is only a single child, then we should
+ // just return that child rather than create a dummy composite.
+ //
+ if( childComponents.getCount() == 1 )
+ {
+ return childComponents[0];
+ }
+
+ return new CompositeComponentType(linkage, childComponents);
+}
+
+
+CompositeComponentType::CompositeComponentType(
+ Linkage* linkage,
+ List<RefPtr<ComponentType>> const& childComponents)
+ : ComponentType(linkage)
+ , m_childComponents(childComponents)
+{
+ HashSet<ComponentType*> requirementsSet;
+ for(auto child : childComponents )
+ {
+ child->enumerateModules([&](Module* module)
+ {
+ requirementsSet.Add(module);
+ });
+ }
+
+ for(auto child : childComponents )
+ {
+ auto childEntryPointCount = child->getEntryPointCount();
+ for(Index cc = 0; cc < childEntryPointCount; ++cc)
+ {
+ m_entryPoints.add(child->getEntryPoint(cc));
+ }
+
+ auto childShaderParamCount = child->getShaderParamCount();
+ for(Index pp = 0; pp < childShaderParamCount; ++pp)
+ {
+ m_shaderParams.add(child->getShaderParam(pp));
+ }
+
+ auto childSpecializationParamCount = child->getSpecializationParamCount();
+ for(Index pp = 0; pp < childSpecializationParamCount; ++pp)
+ {
+ m_specializationParams.add(child->getSpecializationParam(pp));
+ }
+
+ for(auto module : child->getModuleDependencies())
+ {
+ m_moduleDependencyList.addDependency(module);
+ }
+ for(auto filePath : child->getFilePathDependencies())
+ {
+ m_filePathDependencyList.addDependency(filePath);
+ }
+
+ auto childRequirementCount = child->getRequirementCount();
+ for(Index rr = 0; rr < childRequirementCount; ++rr)
+ {
+ auto childRequirement = child->getRequirement(rr);
+ if(!requirementsSet.Contains(childRequirement))
+ {
+ requirementsSet.Add(childRequirement);
+ m_requirements.add(childRequirement);
+ }
+ }
+ }
+}
+
+Index CompositeComponentType::getEntryPointCount()
+{
+ return m_entryPoints.getCount();
+}
+
+RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index)
+{
+ return m_entryPoints[index];
+}
+
+Index CompositeComponentType::getShaderParamCount()
+{
+ return m_shaderParams.getCount();
}
-void Program::addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink)
+GlobalShaderParamInfo CompositeComponentType::getShaderParam(Index index)
{
- List<RefPtr<EntryPoint>> entryPoints;
- entryPoints.add(entryPoint);
+ return m_shaderParams[index];
+}
+
+Index CompositeComponentType::getSpecializationParamCount()
+{
+ return m_specializationParams.getCount();
+}
+
+SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index)
+{
+ return m_specializationParams[index];
+}
+
+Index CompositeComponentType::getRequirementCount()
+{
+ return m_requirements.getCount();
+}
+
+RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index)
+{
+ return m_requirements[index];
+}
+
+List<Module*> const& CompositeComponentType::getModuleDependencies()
+{
+ return m_moduleDependencyList.getModuleList();
+}
- RefPtr<EntryPointGroup> entryPointGroup = EntryPointGroup::create(getLinkageImpl(), entryPoints, sink);
+List<String> const& CompositeComponentType::getFilePathDependencies()
+{
+ return m_filePathDependencyList.getFilePathList();
+}
- addEntryPointGroup(entryPointGroup);
+void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+{
+ visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo));
}
-void Program::addEntryPointGroup(EntryPointGroup* entryPointGroup)
+
+RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
{
- m_entryPointGroups.add(entryPointGroup);
+ SLANG_UNUSED(argCount);
+
+ RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo();
- for(auto entryPoint : entryPointGroup->getEntryPoints())
+ Index offset = 0;
+ for(auto child : m_childComponents)
{
- m_entryPoints.add(entryPoint);
- for(auto module : entryPoint->getModuleDependencies())
+ auto childParamCount = child->getSpecializationParamCount();
+ SLANG_ASSERT(offset + childParamCount <= argCount);
+
+ auto childInfo = child->_validateSpecializationArgs(
+ args + offset,
+ childParamCount,
+ sink);
+
+ specializationInfo->childInfos.add(childInfo);
+
+ offset += childParamCount;
+ }
+ return specializationInfo;
+}
+
+//
+// SpecializedComponentType
+//
+
+SpecializedComponentType::SpecializedComponentType(
+ ComponentType* base,
+ ComponentType::SpecializationInfo* specializationInfo,
+ List<SpecializationArg> const& specializationArgs,
+ DiagnosticSink* sink)
+ : ComponentType(base->getLinkage())
+ , m_base(base)
+ , m_specializationInfo(specializationInfo)
+ , m_specializationArgs(specializationArgs)
+{
+ m_irModule = generateIRForSpecializedComponentType(this, sink);
+
+ // The following is a bit of a hack.
+ //
+ // Back-end code generation relies on us having computed layouts for all tagged
+ // unions that end up being used in the code, which means we need a way to find
+ // all such types that get used in a program (and the stuff it imports).
+ //
+ // For now we are assuming a tagged union type only comes into existence
+ // as a (top-level) argument for a generic type parameter, so that we
+ // can check for them here and cache them on the entry point.
+ //
+ // A longer-term strategy might need to consider any (tagged or untagged)
+ // union types that get used inside of a module, and also take
+ // those lists into account.
+ //
+ // An even longer-term strategy would be to allow type layout to
+ // be performed on IR types, so taht we don't need to have front-end
+ // code worrying about this stuff.
+ //
+ for(auto arg : specializationArgs)
+ {
+ auto argType = as<Type>(arg.val);
+ if(!argType)
+ continue;
+
+ auto taggedUnionType = as<TaggedUnionType>(argType);
+ if(!taggedUnionType)
+ continue;
+
+ m_taggedUnionTypes.add(taggedUnionType);
+ }
+}
+
+void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+{
+ SLANG_ASSERT(specializationInfo == nullptr);
+ SLANG_UNUSED(specializationInfo);
+ visitor->visitSpecialized(this);
+}
+
+Index SpecializedComponentType::getRequirementCount()
+{
+ // TODO: A specialized component type may have *more* requirements
+ // than the original, because it also needs to include the module(s)
+ // that define the types used for specialization arguments.
+
+ return m_base->getRequirementCount();
+}
+
+RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index)
+{
+ return m_base->getRequirement(index);
+}
+
+//
+// LegacyProgram
+//
+
+LegacyProgram::LegacyProgram(
+ Linkage* linkage,
+ List<RefPtr<TranslationUnitRequest>> const& translationUnits,
+ DiagnosticSink* sink)
+ : ComponentType(linkage)
+ , m_translationUnits(translationUnits)
+{
+ HashSet<ComponentType*> requirementsSet;
+
+ for(auto translationUnit : translationUnits )
+ {
+ ComponentType* child = translationUnit->getModule();
+
+ auto childEntryPointCount = child->getEntryPointCount();
+ for(Index cc = 0; cc < childEntryPointCount; ++cc)
+ {
+ m_entryPoints.add(child->getEntryPoint(cc));
+ }
+
+ for(auto module : child->getModuleDependencies())
+ {
+ m_moduleDependencies.addDependency(module);
+ }
+ for(auto filePath : child->getFilePathDependencies())
{
- addReferencedModule(module);
+ m_fileDependencies.addDependency(filePath);
+ }
+
+ auto childRequirementCount = child->getRequirementCount();
+ for(Index rr = 0; rr < childRequirementCount; ++rr)
+ {
+ auto childRequirement = child->getRequirement(rr);
+ if(!requirementsSet.Contains(childRequirement))
+ {
+ requirementsSet.Add(childRequirement);
+ m_requirements.add(childRequirement);
+ }
}
}
+
+ _collectShaderParams(sink);
}
-RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink)
+void LegacyProgram::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
{
- if(!m_irModule)
+ visitor->visitLegacy(this, as<CompositeComponentType::CompositeSpecializationInfo>(specializationInfo));
+}
+
+RefPtr<ComponentType::SpecializationInfo> LegacyProgram::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
+{
+ SLANG_UNUSED(argCount);
+
+ RefPtr<CompositeComponentType::CompositeSpecializationInfo> info = new CompositeComponentType::CompositeSpecializationInfo();
+
+ Index offset = 0;
+ for(auto translationUnit : m_translationUnits)
{
- m_irModule = generateIRForProgram(
- m_linkage->getSessionImpl(),
- this,
+ ComponentType* child = translationUnit->getModule();
+ auto childParamCount = child->getSpecializationParamCount();
+ SLANG_ASSERT(offset + childParamCount <= argCount);
+
+ auto childInfo = child->_validateSpecializationArgs(
+ args + offset,
+ childParamCount,
sink);
+
+ info->childInfos.add(childInfo);
+
+ offset += childParamCount;
}
- return m_irModule;
+ return info;
}
+Index LegacyProgram::getRequirementCount()
+{
+ return m_requirements.getCount();
+}
+
+RefPtr<ComponentType> LegacyProgram::getRequirement(Index index)
+{
+ return m_requirements[index];
+}
+
+void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo)
+{
+ auto childCount = composite->getChildComponentCount();
+ for(Index ii = 0; ii < childCount; ++ii)
+ {
+ auto child = composite->getChildComponent(ii);
+ auto childSpecializationInfo = specializationInfo
+ ? specializationInfo->childInfos[ii]
+ : nullptr;
+
+ child->acceptVisitor(this, childSpecializationInfo);
+ }
+}
+
+void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized)
+{
+ specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo());
+}
+
+void ComponentTypeVisitor::visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo)
+{
+ auto childCount = legacy->getTranslationUnitCount();
+ for(Index ii = 0; ii < childCount; ++ii)
+ {
+ auto translationUnit = legacy->getTranslationUnit(ii);
+ ComponentType* child = translationUnit->getModule();
+ auto childSpecializationInfo = specializationInfo
+ ? specializationInfo->childInfos[ii]
+ : nullptr;
+
+ child->acceptVisitor(this, childSpecializationInfo);
+ }
+}
-TargetProgram* Program::getTargetProgram(TargetRequest* target)
+TargetProgram* ComponentType::getTargetProgram(TargetRequest* target)
{
RefPtr<TargetProgram> targetProgram;
if(!m_targetPrograms.TryGetValue(target, targetProgram))
@@ -1768,12 +2264,12 @@ TargetProgram* Program::getTargetProgram(TargetRequest* target)
//
TargetProgram::TargetProgram(
- Program* program,
+ ComponentType* componentType,
TargetRequest* targetReq)
- : m_program(program)
+ : m_program(componentType)
, m_targetReq(targetReq)
{
- m_entryPointResults.setCount(program->getEntryPoints().getCount());
+ m_entryPointResults.setCount(componentType->getEntryPointCount());
}
//
@@ -2458,10 +2954,10 @@ SLANG_API SlangResult spSetGlobalGenericArgs(
if (!request) return SLANG_FAIL;
auto req = convert(request);
- auto& genericArgStrings = req->globalGenericArgStrings;
- genericArgStrings.clear();
+ auto& argStrings = req->globalSpecializationArgStrings;
+ argStrings.clear();
for (int i = 0; i < genericArgCount; i++)
- genericArgStrings.add(genericArgs[i]);
+ argStrings.add(genericArgs[i]);
return SLANG_OK;
}
@@ -2477,7 +2973,7 @@ SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam(
if(!typeName) return SLANG_FAIL;
auto req = convert(request);
- auto& typeArgStrings = req->globalExistentialSlotArgStrings;
+ auto& typeArgStrings = req->globalSpecializationArgStrings;
if(Index(slotIndex) >= typeArgStrings.getCount())
typeArgStrings.setCount(slotIndex+1);
typeArgStrings[slotIndex] = String(typeName);
@@ -2501,7 +2997,7 @@ SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam(
return SLANG_FAIL;
auto& entryPointInfo = req->entryPoints[entryPointIndex];
- auto& typeArgStrings = entryPointInfo.existentialArgStrings;
+ auto& typeArgStrings = entryPointInfo.specializationArgStrings;
if(Index(slotIndex) >= typeArgStrings.getCount())
typeArgStrings.setCount(slotIndex+1);
typeArgStrings[slotIndex] = String(typeName);
@@ -2569,7 +3065,7 @@ spGetDependencyFileCount(
if(!request) return 0;
auto req = convert(request);
auto frontEndReq = req->getFrontEndReq();
- auto program = frontEndReq->getProgram();
+ auto program = frontEndReq->getGlobalAndEntryPointsComponentType();
return (int) program->getFilePathDependencies().getCount();
}
@@ -2583,7 +3079,7 @@ spGetDependencyFilePath(
if(!request) return 0;
auto req = convert(request);
auto frontEndReq = req->getFrontEndReq();
- auto program = frontEndReq->getProgram();
+ auto program = frontEndReq->getGlobalAndEntryPointsComponentType();
return program->getFilePathDependencies()[index].begin();
}
@@ -2613,7 +3109,7 @@ SLANG_API void const* spGetEntryPointCode(
using namespace Slang;
auto req = convert(request);
auto linkage = req->getLinkage();
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
// TODO: We should really accept a target index in this API
Index targetIndex = 0;
@@ -2668,7 +3164,7 @@ SLANG_API SlangResult spGetEntryPointCodeBlob(
auto req = convert(request);
auto linkage = req->getLinkage();
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
Index targetCount = linkage->targets.getCount();
if((targetIndex < 0) || (targetIndex >= targetCount))
@@ -2715,13 +3211,13 @@ SLANG_API void const* spGetCompileRequestCode(
SLANG_API SlangResult spCompileRequest_getProgram(
SlangCompileRequest* request,
- slang::IProgram** outProgram)
+ slang::IComponentType** outProgram)
{
if( !request ) return SLANG_ERROR_INVALID_PARAMETER;
auto req = convert(request);
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
- *outProgram = Slang::ComPtr<slang::IProgram>(program).detach();
+ *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach();
return SLANG_OK;
}
@@ -2731,7 +3227,7 @@ SLANG_API SlangReflection* spGetReflection(
if( !request ) return 0;
auto req = convert(request);
auto linkage = req->getLinkage();
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
// Note(tfoley): The API signature doesn't let the client
// specify which target they want to access reflection