summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-shader.cpp85
-rwxr-xr-xsource/slang/slang-compiler.h14
-rw-r--r--source/slang/slang-ir-legalize-types.cpp4
-rw-r--r--source/slang/slang.cpp295
4 files changed, 346 insertions, 52 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 800a0519f..b6ed0224a 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -784,30 +784,10 @@ namespace Slang
return fillRequirements(globalComponentType);
}
- /// Create a component type that represents the global scope for a compile request,
- /// along with any entry point functions.
- ///
- /// The resulting component type will include the global-scope information
- /// first, so its layout will be compatible with the result of
- /// `createUnspecializedGlobalComponentType`.
- ///
- /// The new component type will also add on any entry-point functions
- /// that were requested and will thus include space for their `uniform` parameters.
- /// If multiple entry points were requested then they will be given non-overlapping
- /// parameter bindings, consistent with them being used together in
- /// a single pipeline state, hit group, etc.
- ///
- /// The result of this function is unspecialized and doesn't take into
- /// account any specialization arguments the user might have supplied.
- ///
- RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
- FrontEndCompileRequest* compileRequest,
- List<RefPtr<ComponentType>>& outUnspecializedEntryPoints)
+ void FrontEndCompileRequest::checkEntryPoints()
{
- auto linkage = compileRequest->getLinkage();
- auto sink = compileRequest->getSink();
-
- auto globalComponentType = compileRequest->getGlobalComponentType();
+ auto linkage = getLinkage();
+ auto sink = getSink();
// The validation of entry points here will be modal, and controlled
// by whether the user specified any entry points directly via
@@ -818,17 +798,14 @@ namespace Slang
// First, check if the user requested any entry points explicitly via
// the API or command line.
//
- bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0;
-
- List<RefPtr<ComponentType>> allComponentTypes;
- allComponentTypes.add(globalComponentType);
+ bool anyExplicitEntryPoints = getEntryPointReqCount() != 0;
if( anyExplicitEntryPoints )
{
// If there were any explicit requests for entry points to be
// checked, then we will *only* check those.
//
- for(auto entryPointReq : compileRequest->getEntryPointReqs())
+ for(auto entryPointReq : getEntryPointReqs())
{
auto entryPoint = findAndValidateEntryPoint(
entryPointReq);
@@ -841,9 +818,6 @@ namespace Slang
// compilation API doesn't allow for grouping).
//
entryPointReq->getTranslationUnit()->module->_addEntryPoint(entryPoint);
-
- outUnspecializedEntryPoints.add(entryPoint);
- allComponentTypes.add(entryPoint);
}
}
@@ -868,10 +842,10 @@ namespace Slang
// For now we'll start with an extremely basic approach that
// should work for typical HLSL code.
//
- Index translationUnitCount = compileRequest->translationUnits.getCount();
+ Index translationUnitCount = translationUnits.getCount();
for(Index tt = 0; tt < translationUnitCount; ++tt)
{
- auto translationUnit = compileRequest->translationUnits[tt];
+ auto translationUnit = translationUnits[tt];
for( auto globalDecl : translationUnit->getModuleDecl()->members )
{
auto maybeFuncDecl = globalDecl;
@@ -920,12 +894,51 @@ namespace Slang
// independent of the others.
//
translationUnit->module->_addEntryPoint(entryPoint);
-
- outUnspecializedEntryPoints.add(entryPoint);
- allComponentTypes.add(entryPoint);
}
}
}
+ }
+
+
+ /// Create a component type that represents the global scope for a compile request,
+ /// along with any entry point functions.
+ ///
+ /// The resulting component type will include the global-scope information
+ /// first, so its layout will be compatible with the result of
+ /// `createUnspecializedGlobalComponentType`.
+ ///
+ /// The new component type will also add on any entry-point functions
+ /// that were requested and will thus include space for their `uniform` parameters.
+ /// If multiple entry points were requested then they will be given non-overlapping
+ /// parameter bindings, consistent with them being used together in
+ /// a single pipeline state, hit group, etc.
+ ///
+ /// The result of this function is unspecialized and doesn't take into
+ /// account any specialization arguments the user might have supplied.
+ ///
+ RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
+ FrontEndCompileRequest* compileRequest,
+ List<RefPtr<ComponentType>>& outUnspecializedEntryPoints)
+ {
+ auto linkage = compileRequest->getLinkage();
+
+ auto globalComponentType = compileRequest->getGlobalComponentType();
+
+ List<RefPtr<ComponentType>> allComponentTypes;
+ allComponentTypes.add(globalComponentType);
+
+ Index translationUnitCount = compileRequest->translationUnits.getCount();
+ for(Index tt = 0; tt < translationUnitCount; ++tt)
+ {
+ auto translationUnit = compileRequest->translationUnits[tt];
+ auto module = translationUnit->getModule();
+
+ for(auto entryPoint : module->getEntryPoints() )
+ {
+ outUnspecializedEntryPoints.add(entryPoint);
+ allComponentTypes.add(entryPoint);
+ }
+ }
// Also consider entry points that were introduced via adding
// a library reference...
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 2388f5247..9e3550066 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -602,10 +602,8 @@ namespace Slang
Index getRequirementCount() SLANG_OVERRIDE;
RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
- /// TODO: These should include requirements/dependencies for the types
- /// referenced in the specialization arguments...
- List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_base->getModuleDependencies(); }
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_base->getFilePathDependencies(); }
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; }
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencies; }
/// Get a list of tagged-union types referenced by the specialization parameters.
List<TaggedUnionType*> const& getTaggedUnionTypes() { return m_taggedUnionTypes; }
@@ -638,6 +636,9 @@ namespace Slang
// Any tagged union types that were referenced by the specialization arguments.
List<TaggedUnionType*> m_taggedUnionTypes;
+ List<Module*> m_moduleDependencies;
+ List<String> m_filePathDependencies;
+ List<RefPtr<ComponentType>> m_requirements;
};
/// Describes an entry point for the purposes of layout and code generation.
@@ -1362,6 +1363,7 @@ namespace Slang
DiagnosticSink* sink);
void loadParsedModule(
+ RefPtr<FrontEndCompileRequest> compileRequest,
RefPtr<TranslationUnitRequest> translationUnit,
Name* name,
PathInfo const& pathInfo);
@@ -1560,6 +1562,8 @@ namespace Slang
// of the translation units in the program
void checkAllTranslationUnits();
+ void checkEntryPoints();
+
void generateIR();
SlangResult executeActionsInner();
@@ -1576,6 +1580,8 @@ namespace Slang
/// @return The zero-based index of the translation unit in this compile request.
int addTranslationUnit(SourceLanguage language, Name* moduleName);
+ int addTranslationUnit(TranslationUnitRequest* translationUnit);
+
void addTranslationUnitSourceFile(
int translationUnitIndex,
SourceFile* sourceFile);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 91229b80d..91427ce9c 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -1409,6 +1409,10 @@ static LegalVal legalizeInst(
break;
}
+ if(as<IRAttr>(inst))
+ return LegalVal::simple(inst);
+
+
// We will iterate over all the operands, extract the legalized
// value of each, and collect them in an array for subsequent use.
//
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 254a37c24..4a25c2392 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1552,6 +1552,7 @@ void FrontEndCompileRequest::checkAllTranslationUnits()
{
checkTranslationUnit(translationUnit.Ptr());
}
+ checkEntryPoints();
}
void FrontEndCompileRequest::generateIR()
@@ -1676,7 +1677,6 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
if (getSink()->getErrorCount() != 0)
return SLANG_FAIL;
-
// Look up all the entry points that are expected,
// and use them to populate the `program` member.
//
@@ -1905,13 +1905,11 @@ SlangResult EndToEndCompileRequest::executeActions()
int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName)
{
- Index result = translationUnits.getCount();
-
if (!moduleName)
{
// We want to ensure that symbols defined in different translation
// units get unique mangled names, so that we can, e.g., tell apart
- // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`,
+ // a `main()` function in `vertex.hlsl` and a `main()` in `fragment.hlsl`,
// even when they are being compiled together.
//
String generatedName = "tu";
@@ -1925,11 +1923,17 @@ int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* mo
translationUnit->moduleName = moduleName;
- translationUnits.add(translationUnit);
+ return addTranslationUnit(translationUnit);
+}
+int FrontEndCompileRequest::addTranslationUnit(TranslationUnitRequest* translationUnit)
+{
+ Index result = translationUnits.getCount();
+ translationUnits.add(translationUnit);
return (int) result;
}
+
void FrontEndCompileRequest::addTranslationUnitSourceFile(
int translationUnitIndex,
SourceFile* sourceFile)
@@ -2041,6 +2045,7 @@ UInt Linkage::addTarget(
}
void Linkage::loadParsedModule(
+ RefPtr<FrontEndCompileRequest> compileRequest,
RefPtr<TranslationUnitRequest> translationUnit,
Name* name,
const PathInfo& pathInfo)
@@ -2061,7 +2066,7 @@ void Linkage::loadParsedModule(
auto sink = translationUnit->compileRequest->getSink();
int errorCountBefore = sink->getErrorCount();
- checkTranslationUnit(translationUnit.Ptr());
+ compileRequest->checkAllTranslationUnits();
int errorCountAfter = sink->getErrorCount();
if (errorCountAfter != errorCountBefore)
@@ -2106,6 +2111,8 @@ RefPtr<Module> Linkage::loadModule(
translationUnit->moduleName = name;
translationUnit->sourceLanguage = SourceLanguage::Slang;
+ frontEndReq->addTranslationUnit(translationUnit);
+
auto module = translationUnit->getModule();
ModuleBeingImportedRAII moduleBeingImported(
@@ -2132,6 +2139,7 @@ RefPtr<Module> Linkage::loadModule(
}
loadParsedModule(
+ frontEndReq,
translationUnit,
name,
filePathInfo);
@@ -2902,6 +2910,134 @@ RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpeci
// SpecializedComponentType
//
+/// Utility type for collecting modules references by types/declarations
+struct SpecializationArgModuleCollector : ComponentTypeVisitor
+{
+ HashSet<Module*> m_modulesSet;
+ List<Module*> m_modulesList;
+
+ void addModule(Module* module)
+ {
+ m_modulesList.add(module);
+ m_modulesSet.Add(module);
+ }
+
+ void maybeAddModule(Module* module)
+ {
+ if(!module)
+ return;
+ if(m_modulesSet.Contains(module))
+ return;
+
+ addModule(module);
+ }
+
+ void collectReferencedModules(Decl* decl)
+ {
+ auto module = getModule(decl);
+ maybeAddModule(module);
+ }
+
+ void collectReferencedModules(Substitutions* substitution)
+ {
+ if(auto genericSubst = as<GenericSubstitution>(substitution))
+ {
+ for(auto arg : genericSubst->args)
+ {
+ collectReferencedModules(arg);
+ }
+ }
+ }
+
+ void collectReferencedModules(SubstitutionSet const& substitutions)
+ {
+ for(auto subst = substitutions.substitutions; subst; subst = subst->outer)
+ {
+ collectReferencedModules(subst);
+ }
+ }
+
+ void collectReferencedModules(DeclRefBase const& declRef)
+ {
+ collectReferencedModules(declRef.decl);
+ collectReferencedModules(declRef.substitutions);
+ }
+
+ void collectReferencedModules(Type* type)
+ {
+ if(auto declRefType = as<DeclRefType>(type))
+ {
+ collectReferencedModules(declRefType->declRef);
+ }
+
+ // TODO: Handle non-decl-ref composite type cases
+ // (e.g., function types).
+ }
+
+ void collectReferencedModules(Val* val)
+ {
+ if(auto type = as<Type>(val))
+ {
+ collectReferencedModules(type);
+ }
+ else if (auto declRefVal = as<GenericParamIntVal>(val))
+ {
+ collectReferencedModules(declRefVal->declRef);
+ }
+
+ // TODO: other cases of values that could reference
+ // a declaration.
+ }
+
+ void collectReferencedModules(List<ExpandedSpecializationArg> const& args)
+ {
+ for(auto arg : args)
+ {
+ collectReferencedModules(arg.val);
+ collectReferencedModules(arg.witness);
+ }
+ }
+
+ //
+ // ComponentTypeVisitor methods
+ //
+
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(entryPoint);
+
+ if(!specializationInfo)
+ return;
+
+ collectReferencedModules(specializationInfo->specializedFuncDeclRef);
+ collectReferencedModules(specializationInfo->existentialSpecializationArgs);
+ }
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(module);
+
+ if(!specializationInfo)
+ return;
+
+ for(auto arg : specializationInfo->genericArgs)
+ {
+ collectReferencedModules(arg.argVal);
+ }
+ collectReferencedModules(specializationInfo->existentialArgs);
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ visitChildren(specialized);
+ }
+};
+
SpecializedComponentType::SpecializedComponentType(
ComponentType* base,
ComponentType::SpecializationInfo* specializationInfo,
@@ -2914,8 +3050,147 @@ SpecializedComponentType::SpecializedComponentType(
{
m_irModule = generateIRForSpecializedComponentType(this, sink);
+ // We need to account for the fact that a specialized
+ // entity like `myShader<SomeType>` needs to not only
+ // depend on the module(s) that `myShader` depends on,
+ // but also on any modules that `SomeType` depends on.
+ //
+ // We will set up a "collector" type that will be
+ // used to build a list of these additional modules.
+ //
+ SpecializationArgModuleCollector moduleCollector;
+
+ // We don't want to go adding additional requirements for
+ // modules that the base component type already includes,
+ // so we will add those to the set of modules in
+ // the collector before we starting trying to add others.
+ //
+ base->enumerateModules([&](Module* module)
+ {
+ moduleCollector.m_modulesSet.Add(module);
+ });
+
+ // In order to collect the additional modules, we need
+ // to inspect the specialization arguments and see what
+ // they depend on.
+ //
+ // Naively, it seems like we'd just want to iterate
+ // over `specializationArgs`, which gives the specialization
+ // arguments as the user supplied them. However, such
+ // an approach would have a subtle problem.
+ //
+ // If we have a generic entry point like:
+ //
+ // // In module A
+ // myShader<T : IThing>
+ //
+ //
+ // And the type `SomeType` that is being used as an argument doesn't
+ // directly conform to `IThing`:
+ //
+ // // In module B
+ // struct SomeType { ... }
+ //
+ // and the conformance of `SomeType` to `IThing` is
+ // coming from yet another module:
+ //
+ // // In module C
+ // import B;
+ // extension SomeType : IThing { ... }
+ //
+ // In this case, the specialized component for `myShader<SomeType>`
+ // needs to depend on all of:
+ //
+ // * Module A, because it defines `myShader`
+ // * Module B, because it defines `SomeType`
+ // * Module C, because it defines the conformance `SomeType : IThing`
+ //
+ // We thus need to iterate over a form of the specialization
+ // arguments that includes the "expanded" arguments like
+ // interface conformance witnesses that got added during
+ // semantic checking.
+ //
+ // The expanded arguments are being stored in the `specializationInfo`
+ // today (for use by downstream code generation), and the easiest
+ // way to walk that information and get to the leaf nodes where
+ // the expanded arguments are stored is to apply a visitor to
+ // the specialized component type we are in the middle of constructing.
+ //
+ moduleCollector.visitSpecialized(this);
+
+ // Now that we've collected our additional information, we can
+ // start to build up the final lists for the specialized component type.
+ //
+ // The starting point for our lists comes from the base component type.
+ //
+ m_moduleDependencies = base->getModuleDependencies();
+ m_filePathDependencies = base->getFilePathDependencies();
+
+ Index baseRequirementCount = base->getRequirementCount();
+ for( Index r = 0; r < baseRequirementCount; r++ )
+ {
+ m_requirements.add(base->getRequirement(r));
+ }
+
+ // The specialized component type will need to have additional
+ // dependencies and requirements based on the modules that
+ // were collected when looking at the specialization arguments.
+
+ // We want to avoid adding the same file path dependency more than once.
+ //
+ HashSet<String> filePathDependencySet;
+ for(auto path : m_filePathDependencies)
+ filePathDependencySet.Add(path);
+
+ for(auto module : moduleCollector.m_modulesList)
+ {
+ // The specialized component type will have an open (unsatisfied)
+ // requirement for each of the modules that its specialization
+ // arguments need.
+ //
+ // Note: what this means in practice is that the component type
+ // records that the given module(s) will need to be linked in
+ // before final code can be generated, but it importantly
+ // does not dictate the final placement of the parameters from
+ // those modules in the layout.
+ //
+ m_requirements.add(module);
+
+ // The speciialized component type will also have a dependency
+ // on all the file paths that any of the modules involved in
+ // it depend on (including those that are required but not
+ // yet linked in).
+ //
+ // The file path information is what a client would need to
+ // use to decide if kernel code is out of date compared to
+ // source files, so we want to include anything that could
+ // affect the validity of generated code.
+ //
+ for(auto path : module->getFilePathDependencies())
+ {
+ if(filePathDependencySet.Contains(path))
+ continue;
+ filePathDependencySet.Add(path);
+ m_filePathDependencies.add(path);
+ }
+
+ // Finalyl we also add the module for the specialization arguments
+ // to the list of modules that would be used for legacy lookup
+ // operations where we need an implicit/default scope to use
+ // and want it to be expansive.
+ //
+ // TODO: This stuff really isn't worth keeping around long
+ // term, and we should ditch the entire "legacy lookup" idea.
+ //
+ m_moduleDependencies.add(module);
+ }
+
// The following is a bit of a hack.
//
+ // TODO: We should not need this hack any longer, since the
+ // new approach to `switch`-based dynamic dispatch has made
+ // the existing tagged-union support obsolete.
+ //
// Back-end code generation relies on us having computed layouts for all tagged
// unions that end up being used in the code, which means we need a way to find
// all such types that get used in a program (and the stuff it imports).
@@ -3001,16 +3276,12 @@ void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, Spec
Index SpecializedComponentType::getRequirementCount()
{
- // TODO: A specialized component type may have *more* requirements
- // than the original, because it also needs to include the module(s)
- // that define the types used for specialization arguments.
-
- return m_base->getRequirementCount();
+ return m_requirements.getCount();
}
RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index)
{
- return m_base->getRequirement(index);
+ return m_requirements[index];
}
String SpecializedComponentType::getEntryPointMangledName(Index index)