diff options
Diffstat (limited to 'source/slang/slang-linkable-impls.cpp')
| -rw-r--r-- | source/slang/slang-linkable-impls.cpp | 752 |
1 files changed, 752 insertions, 0 deletions
diff --git a/source/slang/slang-linkable-impls.cpp b/source/slang/slang-linkable-impls.cpp new file mode 100644 index 000000000..d03ecb3ca --- /dev/null +++ b/source/slang/slang-linkable-impls.cpp @@ -0,0 +1,752 @@ +// slang-linkable-impls.cpp +#include "slang-linkable-impls.h" + +#include "slang-lower-to-ir.h" // for `generateIRForTypeConformance` +#include "slang-mangle.h" + +namespace Slang +{ + +// +// CompositeComponentType +// + +RefPtr<ComponentType> CompositeComponentType::create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) +{ + // TODO: We should ideally be caching the results of + // composition on the `linkage`, so that if we get + // asked for the same composite again later we re-use + // it rather than re-create it. + // + // Similarly, we might want to do some amount of + // work to "canonicalize" the input for composition. + // E.g., if the user does: + // + // X = compose(A,B); + // Y = compose(C,D); + // Z = compose(X,Y); + // + // W = compose(A, B, C, D); + // + // Then there is no observable difference between + // Z and W, so we might prefer to have them be identical. + + // If there is only a single child, then we should + // just return that child rather than create a dummy composite. + // + if (childComponents.getCount() == 1) + { + return childComponents[0]; + } + + return new CompositeComponentType(linkage, childComponents); +} + + +CompositeComponentType::CompositeComponentType( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) + : ComponentType(linkage), m_childComponents(childComponents) +{ + HashSet<ComponentType*> requirementsSet; + for (auto child : childComponents) + { + child->enumerateModules([&](Module* module) { requirementsSet.add(module); }); + } + + for (auto child : childComponents) + { + auto childEntryPointCount = child->getEntryPointCount(); + for (Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + m_entryPointMangledNames.add(child->getEntryPointMangledName(cc)); + m_entryPointNameOverrides.add(child->getEntryPointNameOverride(cc)); + } + + auto childShaderParamCount = child->getShaderParamCount(); + for (Index pp = 0; pp < childShaderParamCount; ++pp) + { + m_shaderParams.add(child->getShaderParam(pp)); + } + + auto childSpecializationParamCount = child->getSpecializationParamCount(); + for (Index pp = 0; pp < childSpecializationParamCount; ++pp) + { + m_specializationParams.add(child->getSpecializationParam(pp)); + } + + for (auto module : child->getModuleDependencies()) + { + m_moduleDependencyList.addDependency(module); + } + for (auto sourceFile : child->getFileDependencies()) + { + m_fileDependencyList.addDependency(sourceFile); + } + + auto childRequirementCount = child->getRequirementCount(); + for (Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if (!requirementsSet.contains(childRequirement)) + { + requirementsSet.add(childRequirement); + m_requirements.add(childRequirement); + } + } + } +} + +void CompositeComponentType::buildHash(DigestBuilder<SHA1>& builder) +{ + auto componentCount = getChildComponentCount(); + + for (Index i = 0; i < componentCount; ++i) + { + getChildComponent(i)->buildHash(builder); + } +} + +Index CompositeComponentType::getEntryPointCount() +{ + return m_entryPoints.getCount(); +} + +RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) +{ + return m_entryPoints[index]; +} + +String CompositeComponentType::getEntryPointMangledName(Index index) +{ + return m_entryPointMangledNames[index]; +} + +String CompositeComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + +Index CompositeComponentType::getShaderParamCount() +{ + return m_shaderParams.getCount(); +} + +ShaderParamInfo CompositeComponentType::getShaderParam(Index index) +{ + return m_shaderParams[index]; +} + +Index CompositeComponentType::getSpecializationParamCount() +{ + return m_specializationParams.getCount(); +} + +SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) +{ + return m_specializationParams[index]; +} + +Index CompositeComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +List<Module*> const& CompositeComponentType::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} + +List<SourceFile*> const& CompositeComponentType::getFileDependencies() +{ + return m_fileDependencyList.getFileList(); +} + +void CompositeComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); +} + +RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(argCount); + + RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo(); + + Index offset = 0; + for (auto child : m_childComponents) + { + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs(args + offset, childParamCount, sink); + + specializationInfo->childInfos.add(childInfo); + + offset += childParamCount; + } + return specializationInfo; +} + +// +// SpecializedComponentType +// + +/// Utility type for collecting modules references by types/declarations +struct SpecializationArgModuleCollector : ComponentTypeVisitor +{ + HashSet<Module*> m_modulesSet; + List<Module*> m_modulesList; + + void addModule(Module* module) + { + m_modulesList.add(module); + m_modulesSet.add(module); + } + + void maybeAddModule(Module* module) + { + if (!module) + return; + if (m_modulesSet.contains(module)) + return; + + addModule(module); + } + + void collectReferencedModules(Decl* decl) + { + auto module = getModule(decl); + maybeAddModule(module); + } + + void collectReferencedModules(SubstitutionSet substitutions) + { + substitutions.forEachGenericSubstitution( + [this](GenericDecl*, Val::OperandView<Val> args) + { + for (auto arg : args) + { + collectReferencedModules(arg); + } + }); + } + + void collectReferencedModules(DeclRefBase* declRef) + { + collectReferencedModules(declRef->getDecl()); + collectReferencedModules(SubstitutionSet(declRef)); + } + + void collectReferencedModules(Type* type) + { + if (auto declRefType = as<DeclRefType>(type)) + { + collectReferencedModules(declRefType->getDeclRef()); + } + + // TODO: Handle non-decl-ref composite type cases + // (e.g., function types). + } + + void collectReferencedModules(Val* val) + { + if (auto type = as<Type>(val)) + { + collectReferencedModules(type); + } + else if (auto declRefVal = as<DeclRefIntVal>(val)) + { + collectReferencedModules(declRefVal->getDeclRef()); + } + + // TODO: other cases of values that could reference + // a declaration. + } + + void collectReferencedModules(List<ExpandedSpecializationArg> const& args) + { + for (auto arg : args) + { + collectReferencedModules(arg.val); + collectReferencedModules(arg.witness); + } + } + + // + // ComponentTypeVisitor methods + // + + void visitEntryPoint( + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + + if (!specializationInfo) + return; + + collectReferencedModules(specializationInfo->specializedFuncDeclRef); + collectReferencedModules(specializationInfo->existentialSpecializationArgs); + } + + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) + SLANG_OVERRIDE + { + SLANG_UNUSED(module); + + if (!specializationInfo) + return; + + for (auto arg : specializationInfo->genericArgs) + { + collectReferencedModules(arg.argVal); + } + collectReferencedModules(specializationInfo->existentialArgs); + } + + void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } +}; + +SpecializedComponentType::SpecializedComponentType( + ComponentType* base, + ComponentType::SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_specializationInfo(specializationInfo) + , m_specializationArgs(specializationArgs) +{ + m_optionSet.overrideWith(base->getOptionSet()); + + m_irModule = generateIRForSpecializedComponentType(this, sink); + + // We need to account for the fact that a specialized + // entity like `myShader<SomeType>` needs to not only + // depend on the module(s) that `myShader` depends on, + // but also on any modules that `SomeType` depends on. + // + // We will set up a "collector" type that will be + // used to build a list of these additional modules. + // + SpecializationArgModuleCollector moduleCollector; + + // We don't want to go adding additional requirements for + // modules that the base component type already includes, + // so we will add those to the set of modules in + // the collector before we starting trying to add others. + // + base->enumerateModules([&](Module* module) { moduleCollector.m_modulesSet.add(module); }); + + // In order to collect the additional modules, we need + // to inspect the specialization arguments and see what + // they depend on. + // + // Naively, it seems like we'd just want to iterate + // over `specializationArgs`, which gives the specialization + // arguments as the user supplied them. However, such + // an approach would have a subtle problem. + // + // If we have a generic entry point like: + // + // // In module A + // myShader<T : IThing> + // + // + // And the type `SomeType` that is being used as an argument doesn't + // directly conform to `IThing`: + // + // // In module B + // struct SomeType { ... } + // + // and the conformance of `SomeType` to `IThing` is + // coming from yet another module: + // + // // In module C + // import B; + // extension SomeType : IThing { ... } + // + // In this case, the specialized component for `myShader<SomeType>` + // needs to depend on all of: + // + // * Module A, because it defines `myShader` + // * Module B, because it defines `SomeType` + // * Module C, because it defines the conformance `SomeType : IThing` + // + // We thus need to iterate over a form of the specialization + // arguments that includes the "expanded" arguments like + // interface conformance witnesses that got added during + // semantic checking. + // + // The expanded arguments are being stored in the `specializationInfo` + // today (for use by downstream code generation), and the easiest + // way to walk that information and get to the leaf nodes where + // the expanded arguments are stored is to apply a visitor to + // the specialized component type we are in the middle of constructing. + // + moduleCollector.visitSpecialized(this); + + // Now that we've collected our additional information, we can + // start to build up the final lists for the specialized component type. + // + // The starting point for our lists comes from the base component type. + // + m_moduleDependencies = base->getModuleDependencies(); + m_fileDependencies = base->getFileDependencies(); + + Index baseRequirementCount = base->getRequirementCount(); + for (Index r = 0; r < baseRequirementCount; r++) + { + m_requirements.add(base->getRequirement(r)); + } + + // The specialized component type will need to have additional + // dependencies and requirements based on the modules that + // were collected when looking at the specialization arguments. + + // We want to avoid adding the same file dependency more than once. + // + HashSet<SourceFile*> fileDependencySet; + for (SourceFile* sourceFile : m_fileDependencies) + fileDependencySet.add(sourceFile); + + for (auto module : moduleCollector.m_modulesList) + { + // The specialized component type will have an open (unsatisfied) + // requirement for each of the modules that its specialization + // arguments need. + // + // Note: what this means in practice is that the component type + // records that the given module(s) will need to be linked in + // before final code can be generated, but it importantly + // does not dictate the final placement of the parameters from + // those modules in the layout. + // + m_requirements.add(module); + + // The speciialized component type will also have a dependency + // on all the files that any of the modules involved in + // it depend on (including those that are required but not + // yet linked in). + // + // The file path information is what a client would need to + // use to decide if kernel code is out of date compared to + // source files, so we want to include anything that could + // affect the validity of generated code. + // + for (SourceFile* sourceFile : module->getFileDependencies()) + { + if (fileDependencySet.contains(sourceFile)) + continue; + fileDependencySet.add(sourceFile); + m_fileDependencies.add(sourceFile); + } + + // Finalyl we also add the module for the specialization arguments + // to the list of modules that would be used for legacy lookup + // operations where we need an implicit/default scope to use + // and want it to be expansive. + // + // TODO: This stuff really isn't worth keeping around long + // term, and we should ditch the entire "legacy lookup" idea. + // + m_moduleDependencies.add(module); + } + + // Because we are specializing shader code, the mangled entry + // point names for this component type may be different than + // for the base component type (e.g., the mangled name for `f<int>` + // is different than that that of the generic `f` function + // itself). + // + // We will compute the mangled names of all the entry points and + // store them here, so that we don't have to do it on the fly. + // Because the `ComponentType` structure is hierarchical, we + // need to use a recursive visitor to compute the names, + // and we will define that visitor locally: + // + struct EntryPointMangledNameCollector : ComponentTypeVisitor + { + List<String>* mangledEntryPointNames; + List<String>* entryPointNameOverrides; + + void visitEntryPoint( + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + auto funcDeclRef = entryPoint->getFuncDeclRef(); + if (specializationInfo) + funcDeclRef = specializationInfo->specializedFuncDeclRef; + + (*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef)); + (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0)); + } + + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + (*entryPointNameOverrides).getLast() = entryPoint->getEntryPointNameOverride(0); + } + + void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE {} + void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } + EntryPointMangledNameCollector(ASTBuilder* astBuilder) + : m_astBuilder(astBuilder) + { + } + ASTBuilder* m_astBuilder; + }; + + // With the visitor defined, we apply it to ourself to compute + // and collect the mangled entry point names. + // + EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder()); + collector.mangledEntryPointNames = &m_entryPointMangledNames; + collector.entryPointNameOverrides = &m_entryPointNameOverrides; + collector.visitSpecialized(this); +} + +void SpecializedComponentType::buildHash(DigestBuilder<SHA1>& builder) +{ + auto specializationArgCount = getSpecializationArgCount(); + for (Index i = 0; i < specializationArgCount; ++i) + { + auto specializationArg = getSpecializationArg(i); + auto argString = specializationArg.val->toString(); + builder.append(argString); + } + + getBaseComponentType()->buildHash(builder); +} + +void SpecializedComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + SLANG_ASSERT(specializationInfo == nullptr); + SLANG_UNUSED(specializationInfo); + visitor->visitSpecialized(this); +} + +Index SpecializedComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +String SpecializedComponentType::getEntryPointMangledName(Index index) +{ + return m_entryPointMangledNames[index]; +} + +String SpecializedComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + +// +// RenamedEntryPointComponentType +// + +RenamedEntryPointComponentType::RenamedEntryPointComponentType(ComponentType* base, String newName) + : ComponentType(base->getLinkage()), m_base(base), m_entryPointNameOverride(newName) +{ +} + +void RenamedEntryPointComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + visitor->visitRenamedEntryPoint( + this, + as<EntryPoint::EntryPointSpecializationInfo>(specializationInfo)); +} + +void RenamedEntryPointComponentType::buildHash(DigestBuilder<SHA1>& builder) +{ + SLANG_UNUSED(builder); +} + +// +// TypeConformance +// + +TypeConformance::TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_subtypeWitness(witness) + , m_conformanceIdOverride(confomrmanceIdOverride) +{ + addDepedencyFromWitness(witness); + m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); +} + +void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) +{ + if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) + { + auto declModule = getModule(declaredWitness->getDeclRef().getDecl()); + m_moduleDependencyList.addDependency(declModule); + m_fileDependencyList.addDependency(declModule); + if (m_requirementSet.add(declModule)) + { + m_requirements.add(declModule); + } + // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. + } + else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) + { + addDepedencyFromWitness(transitiveWitness->getMidToSup()); + addDepedencyFromWitness(transitiveWitness->getSubToMid()); + } + else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) + { + auto componentCount = conjunctionWitness->getComponentCount(); + for (Index i = 0; i < componentCount; ++i) + { + auto w = as<SubtypeWitness>(conjunctionWitness->getComponentWitness(i)); + if (w) + addDepedencyFromWitness(w); + } + } +} + +ISlangUnknown* TypeConformance::getInterface(const Guid& guid) +{ + if (guid == slang::ITypeConformance::getTypeGuid()) + return static_cast<slang::ITypeConformance*>(this); + + return Super::getInterface(guid); +} + +void TypeConformance::buildHash(DigestBuilder<SHA1>& builder) +{ + // TODO: Implement some kind of hashInto for Val then replace this + auto subtypeWitness = m_subtypeWitness->toString(); + + builder.append(subtypeWitness); + builder.append(m_conformanceIdOverride); +} + +List<Module*> const& TypeConformance::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} + +List<SourceFile*> const& TypeConformance::getFileDependencies() +{ + return m_fileDependencyList.getFileList(); +} + +Index TypeConformance::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> TypeConformance::getRequirement(Index index) +{ + return m_requirements[index]; +} + +void TypeConformance::acceptVisitor( + ComponentTypeVisitor* visitor, + ComponentType::SpecializationInfo* specializationInfo) +{ + SLANG_UNUSED(specializationInfo); + visitor->visitTypeConformance(this); +} + +RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; +} + +// +// ComponentTypeVisitor +// + +void ComponentTypeVisitor::visitChildren( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = composite->getChildComponentCount(); + for (Index ii = 0; ii < childCount; ++ii) + { + auto child = composite->getChildComponent(ii); + auto childSpecializationInfo = + specializationInfo ? specializationInfo->childInfos[ii] : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} + +void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) +{ + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); +} + +} // namespace Slang |
