summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-check.cpp1312
-rw-r--r--source/slang/slang-compiler.cpp144
-rw-r--r--source/slang/slang-compiler.h920
-rw-r--r--source/slang/slang-diagnostic-defs.h9
-rw-r--r--source/slang/slang-emit-c-like.cpp8
-rw-r--r--source/slang/slang-emit-glsl.cpp6
-rw-r--r--source/slang/slang-emit-hlsl.cpp4
-rw-r--r--source/slang/slang-emit.cpp54
-rw-r--r--source/slang/slang-ir-bind-existentials.cpp92
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir-link.cpp189
-rw-r--r--source/slang/slang-ir.cpp16
-rw-r--r--source/slang/slang-lower-to-ir.cpp216
-rw-r--r--source/slang/slang-lower-to-ir.h26
-rw-r--r--source/slang/slang-parameter-binding.cpp1043
-rw-r--r--source/slang/slang-reflection.cpp181
-rw-r--r--source/slang/slang-syntax.cpp73
-rw-r--r--source/slang/slang-syntax.h38
-rw-r--r--source/slang/slang-type-defs.h2
-rw-r--r--source/slang/slang-type-layout.cpp112
-rw-r--r--source/slang/slang-type-layout.h108
-rw-r--r--source/slang/slang.cpp700
23 files changed, 3406 insertions, 1852 deletions
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp
index edf41fff0..b7c1ccdc2 100644
--- a/source/slang/slang-check.cpp
+++ b/source/slang/slang-check.cpp
@@ -7107,8 +7107,7 @@ namespace Slang
{
if(context.mode != OverloadResolveContext::Mode::JustTrying)
{
- // TODO: diagnose a problem here
- getSink()->diagnose(context.loc, Diagnostics::unimplemented, "generic constraint not satisfied");
+ getSink()->diagnose(context.loc, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup);
}
return false;
}
@@ -9663,15 +9662,15 @@ namespace Slang
}
}
- /// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`.
- static void _collectExistentialTypeParamsRec(
- ExistentialTypeSlots& ioSlots,
+ /// Recursively walk `paramDeclRef` and add any existential/interface specialization parameters to `ioSpecializationParams`.
+ static void _collectExistentialSpecializationParamsRec(
+ SpecializationParams& ioSpecializationParams,
DeclRef<VarDeclBase> paramDeclRef);
- /// Recursively walk `type` and discover any required existential type parameters.
- static void _collectExistentialTypeParamsRec(
- ExistentialTypeSlots& ioSlots,
- Type* type)
+ /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`.
+ static void _collectExistentialSpecializationParamsRec(
+ SpecializationParams& ioSpecializationParams,
+ Type* type)
{
// Whether or not something is an array does not affect
// the number of existential slots it introduces.
@@ -9683,7 +9682,9 @@ namespace Slang
if( auto parameterGroupType = as<ParameterGroupType>(type) )
{
- _collectExistentialTypeParamsRec(ioSlots, parameterGroupType->getElementType());
+ _collectExistentialSpecializationParamsRec(
+ ioSpecializationParams,
+ parameterGroupType->getElementType());
return;
}
@@ -9692,9 +9693,14 @@ namespace Slang
auto typeDeclRef = declRefType->declRef;
if( auto interfaceDeclRef = typeDeclRef.as<InterfaceDecl>() )
{
- // Each leaf parameter of interface type adds one slot.
+ // Each leaf parameter of interface type adds a specialization
+ // parameter, which determines the concrete type(s) that may
+ // be provided as arguments for that parameter.
//
- ioSlots.paramTypes.add(type);
+ SpecializationParam specializationParam;
+ specializationParam.flavor = SpecializationParam::Flavor::ExistentialType;
+ specializationParam.object = type;
+ ioSpecializationParams.add(specializationParam);
}
else if( auto structDeclRef = typeDeclRef.as<StructDecl>() )
{
@@ -9706,7 +9712,9 @@ namespace Slang
if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
continue;
- _collectExistentialTypeParamsRec(ioSlots, fieldDeclRef);
+ _collectExistentialSpecializationParamsRec(
+ ioSpecializationParams,
+ fieldDeclRef);
}
}
}
@@ -9716,26 +9724,58 @@ namespace Slang
// element types.
}
- static void _collectExistentialTypeParamsRec(
- ExistentialTypeSlots& ioSlots,
+ static void _collectExistentialSpecializationParamsRec(
+ SpecializationParams& ioSpecializationParams,
DeclRef<VarDeclBase> paramDeclRef)
{
- _collectExistentialTypeParamsRec(ioSlots, GetType(paramDeclRef));
+ _collectExistentialSpecializationParamsRec(
+ ioSpecializationParams,
+ GetType(paramDeclRef));
}
- /// Add information about a shader parameter to `ioParams` and `ioSlots`
- static void _collectExistentialSlotsForShaderParam(
+ /// Collect any interface/existential specialization parameters for `paramDeclRef` into `ioParamInfo` and `ioSpecializationParams`
+ static void _collectExistentialSpecializationParamsForShaderParam(
ShaderParamInfo& ioParamInfo,
- ExistentialTypeSlots& ioSlots,
+ SpecializationParams& ioSpecializationParams,
DeclRef<VarDeclBase> paramDeclRef)
{
- Index startSlot = ioSlots.paramTypes.getCount();
- _collectExistentialTypeParamsRec(ioSlots, paramDeclRef);
- Index endSlot = ioSlots.paramTypes.getCount();
+ Index beginParamIndex = ioSpecializationParams.getCount();
+ _collectExistentialSpecializationParamsRec(ioSpecializationParams, paramDeclRef);
+ Index endParamIndex = ioSpecializationParams.getCount();
- ioParamInfo.firstExistentialTypeSlot = UInt(startSlot);
- ioParamInfo.existentialTypeSlotCount = UInt(endSlot - startSlot);;
+ ioParamInfo.firstSpecializationParamIndex = beginParamIndex;
+ ioParamInfo.specializationParamCount = endParamIndex - beginParamIndex;
+ }
+
+ void EntryPoint::_collectGenericSpecializationParamsRec(Decl* decl)
+ {
+ if(!decl)
+ return;
+
+ _collectGenericSpecializationParamsRec(decl->ParentDecl);
+
+ auto genericDecl = as<GenericDecl>(decl);
+ if(!genericDecl)
+ return;
+
+ for(auto m : genericDecl->Members)
+ {
+ if(auto genericTypeParam = as<GenericTypeParamDecl>(m))
+ {
+ SpecializationParam param;
+ param.flavor = SpecializationParam::Flavor::GenericType;
+ param.object = genericTypeParam;
+ m_genericSpecializationParams.add(param);
+ }
+ else if(auto genericValParam = as<GenericValueParamDecl>(m))
+ {
+ SpecializationParam param;
+ param.flavor = SpecializationParam::Flavor::GenericValue;
+ param.object = genericValParam;
+ m_genericSpecializationParams.add(param);
+ }
+ }
}
/// Enumerate the existential-type parameters of an `EntryPoint`.
@@ -9744,6 +9784,24 @@ namespace Slang
///
void EntryPoint::_collectShaderParams()
{
+ // We don't currently treat an entry point as having any
+ // *global* shader parameters.
+ //
+ // TODO: We could probably clean up the code a bit by treating
+ // an entry point as introducing a global shader parameter
+ // that is based on the implicit "parameters struct" type
+ // of the entry point itself.
+
+ // We collect the generic parameters of the entry point,
+ // along with those of any outer generics first.
+ //
+ _collectGenericSpecializationParamsRec(getFuncDecl());
+
+ // After geneic specialization parameters have been collected,
+ // we look through the value parameters of the entry point
+ // function and see if any of them introduce existential/interface
+ // specialization parameters.
+ //
// Note: we defensively test whether there is a function decl-ref
// because this routine gets called from the constructor, and
// a "dummy" entry point will have a null pointer for the function.
@@ -9755,9 +9813,9 @@ namespace Slang
ShaderParamInfo shaderParamInfo;
shaderParamInfo.paramDeclRef = paramDeclRef;
- _collectExistentialSlotsForShaderParam(
+ _collectExistentialSpecializationParamsForShaderParam(
shaderParamInfo,
- m_existentialSlots,
+ m_existentialSpecializationParams,
paramDeclRef);
m_shaderParams.add(shaderParamInfo);
@@ -9765,111 +9823,6 @@ namespace Slang
}
}
-static bool shouldUseFalcorCustomSharedKeywordSemantics(
- EntryPointGroup* entryPointGroup)
-{
- if( !entryPointGroup->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics )
- return false;
-
- // As a sanity check, if we are being asked to lay out an
- // empty entry-point group, then don't apply the convention.
- //
- if(entryPointGroup->getEntryPointCount() == 0)
- return false;
-
- // Otherwise we will look at the first entry point in the group,
- // and use that to determine whether it looks like we are compiling
- // ray-tracing shaders or not.
- //
- switch( entryPointGroup->getEntryPoint(0)->getStage() )
- {
- case Stage::AnyHit:
- case Stage::Callable:
- case Stage::ClosestHit:
- case Stage::Intersection:
- case Stage::Miss:
- case Stage::RayGeneration:
- return true;
-
- default:
- return false;
- }
-}
-
-
-static bool shouldUseFalcorCustomSharedKeywordSemantics(
- Program* program)
-{
- if( !program->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics )
- return false;
-
- // As a sanity check, if we are being asked to lay out a program
- // with *no* entry points, then we don't apply the convention.
- //
- if(program->getEntryPointGroupCount() == 0)
- return false;
-
- // Otherwise we let the first entry-point group determine if we should
- // apply the policy for the entire program.
- //
- // Note: this could lead to confusing results if a `Program` mixes
- // entry point groups for RT and non-RT pipelines, but that isn't
- // a scenario we expect to come up and this whole routine is handling
- // "do what I mean" semantics for legacy behavior.
- //
- return shouldUseFalcorCustomSharedKeywordSemantics(program->getEntryPointGroup(0));
-}
-
-
-void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink)
-{
- // If and only if we are in the special mode for Falcor support
- // where non-`shared` global shader parameters are actually
- // supposed to go into the "local root signature," then we
- // will consider such parameters as if they were entry-point
- // parameters, attached to the group.
- //
- if( shouldUseFalcorCustomSharedKeywordSemantics(this) )
- {
- for( auto module : getModuleDependencies() )
- {
- auto moduleDecl = module->getModuleDecl();
- for( auto globalVar : moduleDecl->getMembersOfType<VarDecl>() )
- {
- // Don't consider globals that aren't shader parameters.
- //
- if(!isGlobalShaderParameter(globalVar))
- continue;
-
- // Don't consider global shader paramters that were marked
- // `shared`, since that is how global-root-signature parameters
- // are being specified.
- //
- if( globalVar->HasModifier<HLSLEffectSharedModifier>() )
- continue;
-
- auto paramDeclRef = makeDeclRef(globalVar.Ptr());
-
- ShaderParamInfo shaderParamInfo;
- shaderParamInfo.paramDeclRef = paramDeclRef;
-
- ExistentialTypeSlots slots;
- _collectExistentialSlotsForShaderParam(
- shaderParamInfo,
- slots,
- paramDeclRef);
-
- if( slots.paramTypes.getCount() != 0 )
- {
- sink->diagnose(globalVar, Diagnostics::typeParametersNotAllowedOnEntryPointGlobal, globalVar);
- }
-
- m_shaderParams.add(shaderParamInfo);
- }
- }
- }
-}
-
// Validate that an entry point function conforms to any additional
// constraints based on the stage (and profile?) it specifies.
void validateEntryPoint(
@@ -10015,6 +9968,7 @@ void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink)
//
auto compileRequest = entryPointReq->getCompileRequest();
auto translationUnit = entryPointReq->getTranslationUnit();
+ auto linkage = compileRequest->getLinkage();
auto sink = compileRequest->getSink();
auto translationUnitSyntax = translationUnit->getModuleDecl();
@@ -10133,8 +10087,8 @@ void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink)
// a more uniform representation in the AST?
}
-
RefPtr<EntryPoint> entryPoint = EntryPoint::create(
+ linkage,
makeDeclRef(entryPointFuncDecl),
entryPointProfile);
@@ -10548,11 +10502,97 @@ static bool doesParameterMatch(
return true;
}
- /// Enumerate the existential-type parameters of a `Program`.
- ///
- /// Any parameters found will be added to the list of existential slots on `this`.
- ///
- void Program::_collectShaderParams(DiagnosticSink* sink)
+ void Module::_collectShaderParams()
+ {
+ auto moduleDecl = m_moduleDecl;
+
+ // We are going to walk the global declarations in the body of the
+ // module, and use those to build up our lists of:
+ //
+ // * Global shader parameters
+ // * Specialization parameters (both generic and interface/existential)
+ // * Requirements (`import`ed modules)
+ //
+ // For requirements, we want to be careful to only
+ // add each required module once (in case the same
+ // module got `import`ed multiple times), so we
+ // will keep a set of the modules we've already
+ // seen and processed.
+ //
+ HashSet<Module*> requiredModuleSet;
+
+ for( auto globalDecl : moduleDecl->Members )
+ {
+ if(auto globalVar = globalDecl.as<VarDecl>())
+ {
+ // We do not want to consider global variable declarations
+ // that don't represents shader parameters. This includes
+ // things like `static` globals and `groupshared` variables.
+ //
+ if(!isGlobalShaderParameter(globalVar))
+ continue;
+
+ // At this point we know we have a global shader parameter.
+
+ GlobalShaderParamInfo shaderParamInfo;
+ shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr());
+
+ // We need to consider what specialization parameters
+ // are introduced by this shader parameter. This step
+ // fills in fields on `shaderParamInfo` so that we
+ // can assocaite specialization arguments supplied later
+ // with the correct parameter.
+ //
+ _collectExistentialSpecializationParamsForShaderParam(
+ shaderParamInfo,
+ m_specializationParams,
+ makeDeclRef(globalVar.Ptr()));
+
+ m_shaderParams.add(shaderParamInfo);
+ }
+ else if( auto globalGenericParam = as<GlobalGenericParamDecl>(globalDecl) )
+ {
+ // A global generic type parameter declaration introduces
+ // a suitable specialization parameter.
+ //
+ SpecializationParam specializationParam;
+ specializationParam.flavor = SpecializationParam::Flavor::GenericType;
+ specializationParam.object = globalGenericParam;
+ m_specializationParams.add(specializationParam);
+ }
+ else if( auto importDecl = as<ImportDecl>(globalDecl) )
+ {
+ // An `import` declaration creates a requirement dependency
+ // from this module to another module.
+ //
+ auto importedModule = getModule(importDecl->importedModuleDecl);
+ if(!requiredModuleSet.Contains(importedModule))
+ {
+ requiredModuleSet.Add(importedModule);
+ m_requirements.add(importedModule);
+ }
+ }
+ }
+ }
+
+ Index Module::getRequirementCount()
+ {
+ return m_requirements.getCount();
+ }
+
+ RefPtr<ComponentType> Module::getRequirement(Index index)
+ {
+ return m_requirements[index];
+ }
+
+ void Module::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+ {
+ visitor->visitModule(this, as<ModuleSpecializationInfo>(specializationInfo));
+ }
+
+
+ /// Enumerate the parameters of a `LegacyProgram`.
+ void LegacyProgram::_collectShaderParams(DiagnosticSink* sink)
{
// We need to collect all of the global shader parameters
// referenced by the compile request, and for each we
@@ -10568,10 +10608,22 @@ static bool doesParameterMatch(
// To deal with the first issue, we will maintain a map from a parameter
// name to the index of an existing parameter with that name.
//
+ // TODO: Eventually we should deprecate support for the
+ // deduplication feature of `LegaqcyProgram`, at which point
+ // this entire type and all its complications can be eliminated
+ // from the code (that includes a lot of support in the "parameter
+ // binding" step for shader parameters with multiple declarations).
+ // Until that point this type will have a fair amount of duplication
+ // with stuff in `Module` and `CompositeComponentType`.
+
+ // We use a dictionary to keep track of any shader parameter
+ // we've alrady collected with a given name.
+ //
Dictionary<Name*, Int> mapNameToParamIndex;
- for( auto module : getModuleDependencies() )
+ for( auto translationUnit : m_translationUnits )
{
+ auto module = translationUnit->getModule();
auto moduleDecl = module->getModuleDecl();
for( auto globalVar : moduleDecl->getMembersOfType<VarDecl>() )
{
@@ -10582,27 +10634,6 @@ static bool doesParameterMatch(
if(!isGlobalShaderParameter(globalVar))
continue;
- // HACK: In order to support existing policy in the Falcor
- // application, we support a custom mode where only
- // global variables marked as `shared` should be considered
- // as global shader parameters (that go in the "global
- // root signature" for DXR).
- //
- // TODO: Eliminate this special case once all of the client
- // application code has been ported to use more general-purpose
- // Slang mechanisms (e.g., entry-point `uniform` parameters).
- //
- if( shouldUseFalcorCustomSharedKeywordSemantics(this) )
- {
- if( !globalVar->HasModifier<HLSLEffectSharedModifier>() )
- {
- // Skip a non-`shared` global for purposes of enumerating
- // shader parameters.
- //
- continue;
- }
- }
-
// This declaration may represent the same logical parameter
// as a declaration that came from a different translation unit.
// If that is the case, we want to re-use the same `ShaderParamInfo`
@@ -10646,9 +10677,9 @@ static bool doesParameterMatch(
GlobalShaderParamInfo shaderParamInfo;
shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr());
- _collectExistentialSlotsForShaderParam(
+ _collectExistentialSpecializationParamsForShaderParam(
shaderParamInfo,
- m_globalExistentialSlots,
+ m_specializationParams,
makeDeclRef(globalVar.Ptr()));
m_shaderParams.add(shaderParamInfo);
@@ -10656,13 +10687,59 @@ static bool doesParameterMatch(
}
}
- /// Create a `Program` to represent the compiled code.
+ /// Create a new component type based on `inComponentType`, but with all its requiremetns filled.
+ RefPtr<ComponentType> fillRequirements(
+ ComponentType* inComponentType)
+ {
+ auto linkage = inComponentType->getLinkage();
+
+ // We are going to simplify things by solving the problem iteratively.
+ // If the current `componentType` has requirements for `A`, `B`, ... etc.
+ // then we will create a composite of `componentType`, `A`, `B`, ...
+ // and then see if the resulting composite has any requirements.
+ //
+ // This avoids the problem of trying to compute teh transitive closure
+ // of the requirements relationship (while dealing with deduplication,
+ // etc.)
+
+ RefPtr<ComponentType> componentType = inComponentType;
+ for(;;)
+ {
+ auto requirementCount = componentType->getRequirementCount();
+ if(requirementCount == 0)
+ break;
+
+ List<RefPtr<ComponentType>> allComponents;
+ allComponents.add(componentType);
+
+ for(Index rr = 0; rr < requirementCount; ++rr)
+ {
+ auto requirement = componentType->getRequirement(rr);
+ allComponents.add(requirement);
+ }
+
+ componentType = CompositeComponentType::create(
+ linkage,
+ allComponents);
+ }
+ return componentType;
+ }
+
+ /// Create a component type to represent the "global scope" of a compile request.
///
- /// The created program will comprise all of the translation
- /// units that were compiled as part of the request, as
- /// well as any entry points in those translation units.
+ /// This component type will include all the modules and their global
+ /// parameters from the compile request, but not anything specific
+ /// to any entry point functions.
///
- RefPtr<Program> createUnspecializedProgram(
+ /// The layout for this component type will thus represent the things that
+ /// a user is likely to want to have stay the same across all compiled
+ /// entry points.
+ ///
+ /// The component type that this function creates is unspecialized, in
+ /// that it doesn't take into account any specialization arguments
+ /// that might have been supplied as part of the compile request.
+ ///
+ RefPtr<ComponentType> createUnspecializedGlobalComponentType(
FrontEndCompileRequest* compileRequest)
{
// We want our resulting program to depend on
@@ -10677,16 +10754,51 @@ static bool doesParameterMatch(
//
auto linkage = compileRequest->getLinkage();
auto sink = compileRequest->getSink();
- auto program = new Program(linkage);
- for(auto translationUnit : compileRequest->translationUnits )
+
+ RefPtr<ComponentType> globalComponentType;
+ if(compileRequest->translationUnits.getCount() == 1)
{
- program->addReferencedLeafModule(translationUnit->getModule());
+ // The common case is that a compilation only uses
+ // a single translation unit, and thus results in
+ // a single `Module`. We can then use that module
+ // as the component type that represents the global scope.
+ //
+ globalComponentType = compileRequest->translationUnits[0]->getModule();
}
- for(auto translationUnit : compileRequest->translationUnits )
+ else
{
- program->addReferencedModule(translationUnit->getModule());
+ globalComponentType = new LegacyProgram(
+ linkage,
+ compileRequest->translationUnits,
+ sink);
}
+ return fillRequirements(globalComponentType);
+ }
+
+ /// Create a component type that represents the global scope for a compile request,
+ /// along with any entry point functions.
+ ///
+ /// The resulting component type will include the global-scope information
+ /// first, so its layout will be compatible with the result of
+ /// `createUnspecializedGlobalComponentType`.
+ ///
+ /// The new component type will also add on any entry-point functions
+ /// that were requested and will thus include space for their `uniform` parameters.
+ /// If multiple entry points were requested then they will be given non-overlapping
+ /// parameter bindings, consistent with them being used together in
+ /// a single pipeline state, hit group, etc.
+ ///
+ /// The result of this function is unspecialized and doesn't take into
+ /// account any specialization arguments the user might have supplied.
+ ///
+ RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
+ FrontEndCompileRequest* compileRequest)
+ {
+ auto linkage = compileRequest->getLinkage();
+ auto sink = compileRequest->getSink();
+
+ auto globalComponentType = compileRequest->getGlobalComponentType();
// The validation of entry points here will be modal, and controlled
// by whether the user specified any entry points directly via
@@ -10699,6 +10811,9 @@ static bool doesParameterMatch(
//
bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0;
+ List<RefPtr<ComponentType>> allComponentTypes;
+ allComponentTypes.add(globalComponentType);
+
if( anyExplicitEntryPoints )
{
// If there were any explicit requests for entry points to be
@@ -10716,8 +10831,9 @@ static bool doesParameterMatch(
// but didn't specify any groups (since the current
// compilation API doesn't allow for grouping).
//
- program->addEntryPoint(entryPoint, sink);
entryPointReq->getTranslationUnit()->entryPoints.add(entryPoint);
+
+ allComponentTypes.add(entryPoint);
}
}
@@ -10777,6 +10893,7 @@ static bool doesParameterMatch(
profile.setStage(entryPointAttr->stage);
RefPtr<EntryPoint> entryPoint = EntryPoint::create(
+ linkage,
makeDeclRef(funcDecl),
profile);
@@ -10792,179 +10909,340 @@ static bool doesParameterMatch(
// group, so that its entry-point parameters lay out
// independent of the others.
//
- program->addEntryPoint(entryPoint, sink);
translationUnit->entryPoints.add(entryPoint);
+
+ allComponentTypes.add(entryPoint);
}
}
}
- program->_collectShaderParams(sink);
-
- return program;
+ if(allComponentTypes.getCount() > 1)
+ {
+ auto composite = CompositeComponentType::create(
+ linkage,
+ allComponentTypes);
+ return composite;
+ }
+ else
+ {
+ return globalComponentType;
+ }
}
- static void _specializeExistentialTypeParams(
- Linkage* linkage,
- ExistentialTypeSlots& ioSlots,
- List<RefPtr<Expr>> const& args,
+ RefPtr<ComponentType::SpecializationInfo> Module::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
DiagnosticSink* sink)
{
- Index slotCount = ioSlots.paramTypes.getCount();
- Index argCount = args.getCount();
+ SLANG_ASSERT(argCount == getSpecializationParamCount());
- if( slotCount != argCount )
- {
- sink->diagnose(SourceLoc(), Diagnostics::mismatchExistentialSlotArgCount, slotCount, argCount);
- return;
- }
+ SemanticsVisitor visitor(getLinkage(), sink);
- SemanticsVisitor visitor(linkage, sink);
+ RefPtr<Module::ModuleSpecializationInfo> specializationInfo = new Module::ModuleSpecializationInfo();
- for( Index ii = 0; ii < slotCount; ++ii )
+ for( Index ii = 0; ii < argCount; ++ii )
{
- auto slotType = ioSlots.paramTypes[ii];
- auto argExpr = args[ii];
+ auto& arg = args[ii];
+ auto& param = m_specializationParams[ii];
- auto argType = checkProperType(linkage, TypeExp(argExpr), sink);
- if(!argType)
+ auto argType = arg.val.as<Type>();
+ SLANG_ASSERT(argType);
+
+ switch( param.flavor )
{
- // TODO: Each slot should track a source location and/or a `VarDeclBase`
- // that names the parameter that the slot corresponds to.
+ case SpecializationParam::Flavor::GenericType:
+ {
+ auto genericTypeParamDecl = param.object.as<GlobalGenericParamDecl>();
+ SLANG_ASSERT(genericTypeParamDecl);
- sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgNotAType, ii);
- return;
+ // TODO: There is a serious flaw to this checking logic if we ever have cases where
+ // the constraints on one `type_param` can depend on another `type_param`, e.g.:
+ //
+ // type_param A;
+ // type_param B : ISidekick<A>;
+ //
+ // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to
+ // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being
+ // set to `Batman` to know whether the setting for `B` is valid. In this limit
+ // the constraints can be mutually recursive (so `A : IMentor<B>`).
+ //
+ // The only way to check things correctly is to validate each conformance under
+ // a set of assumptions (substitutions) that includes all the type substitutions,
+ // and possibly also all the other constraints *except* the one to be validated.
+ //
+ // We will punt on this for now, and just check each constraint in isolation.
+
+ // As a quick sanity check, see if the argument that is being supplied for a
+ // global generic type parameter is a reference to *another* global generic
+ // type parameter, since that should always be an error.
+ //
+ if( auto argDeclRefType = argType.as<DeclRefType>() )
+ {
+ auto argDeclRef = argDeclRefType->declRef;
+ if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>())
+ {
+ if(argGenericParamDeclRef.getDecl() == genericTypeParamDecl)
+ {
+ // We are trying to specialize a generic parameter using itself.
+ sink->diagnose(genericTypeParamDecl,
+ Diagnostics::cannotSpecializeGlobalGenericToItself,
+ genericTypeParamDecl->getName());
+ continue;
+ }
+ else
+ {
+ // We are trying to specialize a generic parameter using a *different*
+ // global generic type parameter.
+ sink->diagnose(genericTypeParamDecl,
+ Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam,
+ genericTypeParamDecl->getName(),
+ argGenericParamDeclRef.GetName());
+ continue;
+ }
+ }
+ }
+
+ ModuleSpecializationInfo::GenericArgInfo genericArgInfo;
+ genericArgInfo.paramDecl = genericTypeParamDecl;
+ genericArgInfo.argVal = argType;
+ specializationInfo->genericArgs.add(genericArgInfo);
+
+ // Walk through the declared constraints for the parameter,
+ // and check that the argument actually satisfies them.
+ for(auto constraintDecl : genericTypeParamDecl->getMembersOfType<GenericTypeConstraintDecl>())
+ {
+ // Get the type that the constraint is enforcing conformance to
+ auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraintDecl, nullptr));
+
+ // Use our semantic-checking logic to search for a witness to the required conformance
+ auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType);
+ if (!witness)
+ {
+ // If no witness was found, then we will be unable to satisfy
+ // the conformances required.
+ sink->diagnose(genericTypeParamDecl,
+ Diagnostics::typeArgumentForGenericParameterDoesNotConformToInterface,
+ argType,
+ genericTypeParamDecl->nameAndLoc.name,
+ interfaceType);
+ }
+
+ ModuleSpecializationInfo::GenericArgInfo constraintArgInfo;
+ constraintArgInfo.paramDecl = constraintDecl;
+ constraintArgInfo.argVal = witness;
+ specializationInfo->genericArgs.add(constraintArgInfo);
+ }
+ }
+ break;
+
+ case SpecializationParam::Flavor::ExistentialType:
+ {
+ auto interfaceType = param.object.as<Type>();
+ SLANG_ASSERT(interfaceType);
+
+ auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType);
+ if (!witness)
+ {
+ // If no witness was found, then we will be unable to satisfy
+ // the conformances required.
+ sink->diagnose(SourceLoc(),
+ Diagnostics::typeArgumentDoesNotConformToInterface,
+ argType,
+ interfaceType);
+ }
+
+ ExpandedSpecializationArg expandedArg;
+ expandedArg.val = argType;
+ expandedArg.witness = witness;
+
+ specializationInfo->existentialArgs.add(expandedArg);
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unhandled specialization parameter flavor");
}
+ }
+ return specializationInfo;
+ }
- auto witness = visitor.tryGetSubtypeWitness(argType, slotType);
- if (!witness)
+
+ static void _extractSpecializationArgs(
+ ComponentType* componentType,
+ List<RefPtr<Expr>> const& argExprs,
+ List<SpecializationArg>& outArgs,
+ DiagnosticSink* sink)
+ {
+ auto linkage = componentType->getLinkage();
+
+ auto argCount = argExprs.getCount();
+ for(Index ii = 0; ii < argCount; ++ii )
+ {
+ auto argExpr = argExprs[ii];
+ auto paramInfo = componentType->getSpecializationParam(ii);
+
+ // TODO: We should support non-type arguments here
+
+ auto argType = checkProperType(linkage, TypeExp(argExpr), sink);
+ if( !argType )
{
// If no witness was found, then we will be unable to satisfy
// the conformances required.
- sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgDoesNotConform, ii, slotType);
- return;
+ sink->diagnose(argExpr,
+ Diagnostics::expectedAType,
+ argExpr->type);
+ continue;
}
- ExistentialTypeSlots::Arg arg;
- arg.type = argType;
- arg.witness = witness;
- ioSlots.args.add(arg);
+ SpecializationArg arg;
+ arg.val = argType;
+ outArgs.add(arg);
}
}
- void EntryPoint::_specializeExistentialTypeParams(
- List<RefPtr<Expr>> const& args,
+ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArgsImpl(
+ SpecializationArg const* inArgs,
+ Index inArgCount,
DiagnosticSink* sink)
{
- Slang::_specializeExistentialTypeParams(getLinkage(), m_existentialSlots, args, sink);
- }
+ auto args = inArgs;
+ auto argCount = inArgCount;
- /// Create a specialization an existing entry point based on generic arguments.
- RefPtr<EntryPoint> createSpecializedEntryPoint(
- EntryPoint* unspecializedEntryPoint,
- List<RefPtr<Expr>> const& genericArgs,
- List<RefPtr<Expr>> const& existentialArgs,
- DiagnosticSink* sink)
- {
- auto linkage = unspecializedEntryPoint->getLinkage();
+ SemanticsVisitor visitor(getLinkage(), sink);
- // TODO: Need to be careful in case entry point already has a decl-ref,
- // pertaining to outer specializations (e.g., when entry point was
- // nested in a generic type.
+ // The first N arguments will be for the explicit generic parameters
+ // of the entry point (if it has any).
//
- auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl();
+ auto genericSpecializationParamCount = getGenericSpecializationParamCount();
+ SLANG_ASSERT(argCount >= genericSpecializationParamCount);
- SemanticsVisitor semantics(
- linkage,
- sink);
+ Result result = SLANG_OK;
- DeclRef<FuncDecl> entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl.Ptr());
- if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) )
+ RefPtr<EntryPointSpecializationInfo> info = new EntryPointSpecializationInfo();
+
+ DeclRef<FuncDecl> specializedFuncDeclRef = m_funcDeclRef;
+ if(genericSpecializationParamCount)
{
- // We will construct a suitable `GenericAppExpr` to represent
- // the user-specified `genericDecl` being applied to the
- // supplied `genericArgs`, and then use the existing
- // semantic checking logic that would apply to an explicit
- // generic application like `F<A,B,C>` if it were
- // encountered in the source code.
+ // We need to construct a generic application and use
+ // the semantic checking machinery to expand out
+ // the rest of the arguments via inference...
- auto session = linkage->getSessionImpl();
- auto genericDeclRef = makeDeclRef(genericDecl);
+ auto genericDeclRef = m_funcDeclRef.GetParent().as<GenericDecl>();
+ SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters
- // The first pieces is a `VarExpr` that refers to `genericDecl`.
- //
- // TODO: This would not be needed if we instead parsed
- // the supplied entry-point name into an expression
- // earlier in this function.
- //
- RefPtr<VarExpr> genericExpr = new VarExpr();
- genericExpr->declRef = genericDeclRef;
- genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef);
-
- // Next we construct the actual `GenericAppExpr`
- //
- RefPtr<GenericAppExpr> genericAppExpr = new GenericAppExpr();
- genericAppExpr->FunctionExpr = genericExpr;
- genericAppExpr->Arguments = genericArgs;
+ RefPtr<GenericSubstitution> genericSubst = new GenericSubstitution();
+ genericSubst->outer = genericDeclRef.substitutions.substitutions;
+ genericSubst->genericDecl = genericDeclRef.getDecl();
- // We use the semantics visitor to perform the
- // actual checking logic (this might report
- // errors)
- //
- auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr);
-
- // Now we need to extract an appropriate decl-ref for the entry
- // point from the `checkedExpr`.
- //
- if( auto declRefExpr = checkedExpr.as<DeclRefExpr>() )
+ for(Index ii = 0; ii < genericSpecializationParamCount; ++ii)
{
- // TODO: We should eventually check for the case
- // where we have a `MemberExpr` or another case of
- // `DeclRefExpr` that cannot be summarized as just
- // its decl-ref.
- //
- // The basic `VarExpr` and `StaticMemberExpr` cases
- // should be allow-able.
-
- entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>();
+ auto specializationArg = args[ii];
+ genericSubst->args.add(specializationArg.val);
}
- else if( semantics.IsErrorExpr(checkedExpr) )
+
+ for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
- // Any semantic error that occured should have been
- // reported already.
- return nullptr;
+ auto constraintSubst = genericDeclRef.substitutions;
+ constraintSubst.substitutions = genericSubst;
+
+ DeclRef<GenericTypeConstraintDecl> constraintDeclRef(
+ constraintDecl, constraintSubst);
+
+ auto sub = GetSub(constraintDeclRef);
+ auto sup = GetSup(constraintDeclRef);
+
+ auto subTypeWitness = visitor.tryGetSubtypeWitness(sub, sup);
+ if(subTypeWitness)
+ {
+ genericSubst->args.add(subTypeWitness);
+ }
+ else
+ {
+ // TODO: diagnose a problem here
+ sink->diagnose(constraintDecl, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup);
+ result = SLANG_FAIL;
+ continue;
+ }
}
- else
+
+ specializedFuncDeclRef.substitutions.substitutions = genericSubst;
+ }
+
+ info->specializedFuncDeclRef = specializedFuncDeclRef;
+
+ // Once the generic parameters (if any) have been dealt with,
+ // any remaining specialization arguments are for existential/interface
+ // specialization parameters, attached to the value parameters
+ // of the entry point.
+ //
+ args += genericSpecializationParamCount;
+ argCount -= genericSpecializationParamCount;
+
+ auto existentialSpecializationParamCount = getExistentialSpecializationParamCount();
+ SLANG_ASSERT(argCount == existentialSpecializationParamCount);
+
+ for( Index ii = 0; ii < existentialSpecializationParamCount; ++ii )
+ {
+ auto& param = m_existentialSpecializationParams[ii];
+ auto& specializationArg = args[ii];
+
+ // TODO: We need to handle all the cases of "flavor" for the `param`s (not just types)
+
+ auto paramType = param.object.as<Type>();
+ auto argType = specializationArg.val.as<Type>();
+
+ auto witness = visitor.tryGetSubtypeWitness(argType, paramType);
+ if (!witness)
{
- // The result of specializing a reference to a generic
- // function should always be a `DeclRefExpr`
- //
- SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`");
- UNREACHABLE_RETURN(nullptr);
+ // If no witness was found, then we will be unable to satisfy
+ // the conformances required.
+ sink->diagnose(SourceLoc(), Diagnostics::typeArgumentDoesNotConformToInterface, argType, paramType);
+ result = SLANG_FAIL;
+ continue;
}
+
+ ExpandedSpecializationArg expandedArg;
+ expandedArg.val = specializationArg.val;
+ expandedArg.witness = witness;
+ info->existentialSpecializationArgs.add(expandedArg);
}
- RefPtr<EntryPoint> specializedEntryPoint = EntryPoint::create(
- entryPointFuncDeclRef,
- unspecializedEntryPoint->getProfile());
+ return info;
+ }
- // Next we need to validate the existential arguments.
- specializedEntryPoint->_specializeExistentialTypeParams(existentialArgs, sink);
+ /// Create a specialization an existing entry point based on specialization argument expressions.
+ RefPtr<ComponentType> createSpecializedEntryPoint(
+ EntryPoint* unspecializedEntryPoint,
+ List<RefPtr<Expr>> const& argExprs,
+ DiagnosticSink* sink)
+ {
+ // We need to convert all of the `Expr` arguments
+ // into `SpecializationArg`s, so that we can bottleneck
+ // through the shared logic.
+ //
+ List<SpecializationArg> args;
+ _extractSpecializationArgs(unspecializedEntryPoint, argExprs, args, sink);
+ if(sink->GetErrorCount())
+ return nullptr;
- return specializedEntryPoint;
+ return unspecializedEntryPoint->specialize(
+ args.getBuffer(),
+ args.getCount(),
+ sink);
}
- /// Parse an array of strings as generic arguments.
+ /// Parse an array of strings as specialization arguments.
///
/// Names in the strings will be parsed in the context of
/// the code loaded into the given compile request.
///
- void parseGenericArgStrings(
+ void parseSpecializationArgStrings(
EndToEndCompileRequest* endToEndReq,
List<String> const& genericArgStrings,
List<RefPtr<Expr>>& outGenericArgs)
{
- auto unspecialiedProgram = endToEndReq->getUnspecializedProgram();
+ auto unspecialiedProgram = endToEndReq->getUnspecializedGlobalComponentType();
// TODO: Building a list of `scopesToTry` here shouldn't
// be required, since the `Scope` type itself has the ability
@@ -11004,351 +11282,109 @@ static bool doesParameterMatch(
}
}
+ if(!argExpr)
+ {
+ sink->diagnose(SourceLoc(), Diagnostics::internalCompilerError, "couldn't parse specialization argument");
+ return;
+ }
+
outGenericArgs.add(argExpr);
}
}
- void Program::_specializeExistentialTypeParams(
- List<RefPtr<Expr>> const& args,
- DiagnosticSink* sink)
- {
- Slang::_specializeExistentialTypeParams(getLinkageImpl(), m_globalExistentialSlots, args, sink);
- }
-
Type* Linkage::specializeType(
Type* unspecializedType,
Int argCount,
Type* const* args,
DiagnosticSink* sink)
{
+ SLANG_ASSERT(unspecializedType);
+
// TODO: We should cache and re-use specialized types
// when the exact same arguments are provided again later.
SemanticsVisitor visitor(this, sink);
+ SpecializationParams specializationParams;
+ _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType);
- ExistentialTypeSlots slots;
- _collectExistentialTypeParamsRec(slots, unspecializedType);
-
- assert(slots.paramTypes.getCount() == argCount);
+ assert(specializationParams.getCount() == argCount);
+ ExpandedSpecializationArgs specializationArgs;
for( Int aa = 0; aa < argCount; ++aa )
{
+ auto paramType = specializationParams[aa].object.as<Type>();
auto argType = args[aa];
- ExistentialTypeSlots::Arg arg;
- arg.type = argType;
- arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]);
- slots.args.add(arg);
+ ExpandedSpecializationArg arg;
+ arg.val = argType;
+ arg.witness = visitor.tryGetSubtypeWitness(argType, paramType);
+ specializationArgs.add(arg);
}
RefPtr<ExistentialSpecializedType> specializedType = new ExistentialSpecializedType();
specializedType->baseType = unspecializedType;
- specializedType->slots = slots;
+ specializedType->args = specializationArgs;
m_specializedTypes.add(specializedType);
return specializedType;
}
- // Shared implementation logic for the `_createSpecializedProgram*` entry points.
- static RefPtr<Program> _createSpecializedProgramImpl(
+ /// Shared implementation logic for the `_createSpecializedProgram*` entry points.
+ static RefPtr<ComponentType> _createSpecializedProgramImpl(
Linkage* linkage,
- Program* unspecializedProgram,
- List<RefPtr<Expr>> const& globalGenericArgs,
- List<RefPtr<Expr>> const& globalExistentialArgs,
+ ComponentType* unspecializedProgram,
+ List<RefPtr<Expr>> const& specializationArgExprs,
DiagnosticSink* sink)
{
- // TODO: If there are no specialization arguments,
+ // If there are no specialization arguments,
// then the the result of specialization should
- // be the same as the input... *but* we are promising
- // to return a program without any entry points,
- // and that screws things up here.
- //
- // For now we just carefully avod this early-exit
- // if we have any entry points.
+ // be the same as the input.
//
- // Eventually we should try to revise the model so
- // that specialization of a program that includes
- // entry points can make some kind of sense.
- //
- if( globalGenericArgs.getCount() == 0
- && globalExistentialArgs.getCount() == 0
- && unspecializedProgram->getEntryPointCount() == 0)
+ auto specializationArgCount = specializationArgExprs.getCount();
+ if( specializationArgCount == 0 )
{
return unspecializedProgram;
}
- // The given `unspecializedProgram` should be one that
- // was checked through the front-end, so that now we
- // only need to check if the given arguments can satisfy
- // the requirements of the global generic parameters.
- //
- // The new program needs to start off with the same
- // module dependency list as the original.
- //
- RefPtr<Program> specializedProgram = new Program(linkage);
- for(auto module : unspecializedProgram->getModuleDependencies())
- {
- specializedProgram->addReferencedLeafModule(module);
- }
-
-
- // We will collect all the global generic parameters
- // defined in the modules being referenced, to find
- // the global generic parameter signature of the
- // program.
- //
- // TODO: Note that this doesn't handle the case where one
- // or more of the type *arguments* that we are specifying
- // ends up requiring additional modules to be referenced,
- // which might in turn introduce new global generic parameters.
- //
- List<RefPtr<GlobalGenericParamDecl>> globalGenericParams;
- for(auto module : unspecializedProgram->getModuleDependencies())
+ auto specializationParamCount = unspecializedProgram->getSpecializationParamCount();
+ if(specializationArgCount != specializationParamCount )
{
- for(auto param : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>())
- globalGenericParams.add(param);
- }
-
- // Next, we will check whether the supplied arguments can
- // satisfy those parameters.
- //
- // An easy early-out case will be if the number of
- // arguments isn't correct.
- //
- if (globalGenericParams.getCount() != globalGenericArgs.getCount())
- {
- sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments,
- globalGenericParams.getCount(),
- globalGenericArgs.getCount());
+ sink->diagnose(SourceLoc(), Diagnostics::mismatchSpecializationArguments,
+ specializationParamCount,
+ specializationArgCount);
return nullptr;
}
- // We have an appropriate number of arguments for the global generic parameters,
+ // We have an appropriate number of arguments for the global specialization parameters,
// and now we need to check that the arguments conform to the declared constraints.
//
SemanticsVisitor visitor(linkage, sink);
- // Along the way, we will build up an appropriate set of substitutions to represent
- // the generic arguments and their conformances.
- //
- RefPtr<Substitutions> globalGenericSubsts;
- auto globalGenericSubstLink = &globalGenericSubsts;
- //
- // TODO: There is a serious flaw to this checking logic if we ever have cases where
- // the constraints on one `type_param` can depend on another `type_param`, e.g.:
- //
- // type_param A;
- // type_param B : ISidekick<A>;
- //
- // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to
- // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being
- // set to `Batman` to know whether the setting for `B` is valid. In this limit
- // the constraints can be mutually recursive (so `A : IMentor<B>`).
- //
- // The only way to check things correctly is to validate each conformance under
- // a set of assumptions (substitutions) that includes all the type substitutions,
- // and possibly also all the other constraints *except* the one to be validated.
- //
- // We will punt on this for now, and just check each constraint in isolation.
- //
- Index argCounter = 0;
- for(auto& globalGenericParam : globalGenericParams)
- {
- // Get the argument that matches this parameter.
- Index argIndex = argCounter++;
- SLANG_ASSERT(argIndex < globalGenericArgs.getCount());
- auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink);
- if (!globalGenericArg)
- {
- sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName());
- return nullptr;
- }
-
- // As a quick sanity check, see if the argument that is being supplied for a parameter
- // is just the parameter itself, because this should always be an error:
- //
- if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() )
- {
- auto argDeclRef = argDeclRefType->declRef;
- if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>())
- {
- if(argGenericParamDeclRef.getDecl() == globalGenericParam)
- {
- // We are trying to specialize a generic parameter using itself.
- sink->diagnose(globalGenericParam,
- Diagnostics::cannotSpecializeGlobalGenericToItself,
- globalGenericParam->getName());
- continue;
- }
- else
- {
- // We are trying to specialize a generic parameter using a *different*
- // global generic type parameter.
- sink->diagnose(globalGenericParam,
- Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam,
- globalGenericParam->getName(),
- argGenericParamDeclRef.GetName());
- continue;
- }
- }
- }
-
- // Create a substitution for this parameter/argument.
- RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution();
- subst->paramDecl = globalGenericParam;
- subst->actualType = globalGenericArg;
-
- // Walk through the declared constraints for the parameter,
- // and check that the argument actually satisfies them.
- for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>())
- {
- // Get the type that the constraint is enforcing conformance to
- auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr));
-
- // Use our semantic-checking logic to search for a witness to the required conformance
- auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType);
- if (!witness)
- {
- // If no witness was found, then we will be unable to satisfy
- // the conformances required.
- sink->diagnose(globalGenericParam,
- Diagnostics::typeArgumentDoesNotConformToInterface,
- globalGenericParam->nameAndLoc.name,
- globalGenericArg,
- interfaceType);
- }
-
- // Attach the concrete witness for this conformance to the
- // substutiton
- GlobalGenericParamSubstitution::ConstraintArg constraintArg;
- constraintArg.decl = constraint;
- constraintArg.val = witness;
- subst->constraintArgs.add(constraintArg);
- }
-
- // Add the substitution for this parameter to the global substitution
- // set that we are building.
-
- *globalGenericSubstLink = subst;
- globalGenericSubstLink = &subst->outer;
- }
+ List<SpecializationArg> specializationArgs;
+ _extractSpecializationArgs(unspecializedProgram, specializationArgExprs, specializationArgs, sink);
if(sink->GetErrorCount())
return nullptr;
- specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts);
-
- return specializedProgram;
- }
-
- /// Create a specialized copy of `unspecializedProgram`.
- ///
- /// The specialized program will include the entry points
- /// from the original program (whether thsoe entry points
- /// are specialized or not).
- ///
- static RefPtr<Program> _createSpecializedProgram(
- Linkage* linkage,
- Program* unspecializedProgram,
- List<RefPtr<Expr>> const& globalGenericArgs,
- List<RefPtr<Expr>> const& globalExistentialArgs,
- DiagnosticSink* sink)
- {
- auto specializedProgram = _createSpecializedProgramImpl(
- linkage,
- unspecializedProgram,
- globalGenericArgs,
- globalExistentialArgs,
+ auto specializedProgram = unspecializedProgram->specialize(
+ specializationArgs.getBuffer(),
+ specializationArgs.getCount(),
sink);
- // We need to ensure that the specialized program has whatever
- // entry points (and groups) the unspecialized one had.
- //
- for( auto entryPointGroup : unspecializedProgram->getEntryPointGroups() )
- {
- specializedProgram->addEntryPointGroup(entryPointGroup);
- }
-
- // Now deal with the shader parameters and existential arguments
- //
- // Note: We should in theory be able to just copy over the shader
- // parameters and existential slot information from the unspecialized
- // program. This could save some time, but it would also mean that
- // the only way to create a specialized program is by creating an
- // unspecialized on first, which is maybe not always desirable.
- //
- specializedProgram->_collectShaderParams(sink);
- specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink);
-
return specializedProgram;
}
- SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::specializeProgram(
- slang::IProgram* inUnspecializedProgram,
- SlangInt specializationArgCount,
- slang::SpecializationArg const* specializationArgs,
- slang::IProgram** outSpecializedProgram,
- ISlangBlob** outDiagnostics)
- {
- auto unspecializedProgram = asInternal(inUnspecializedProgram);
-
- if( specializationArgCount == 0 )
- {
- *outSpecializedProgram = ComPtr<slang::IProgram>(asExternal(unspecializedProgram)).detach();
- return SLANG_OK;
- }
-
- List<RefPtr<Expr>> globalGenericArgs;
-
- List<RefPtr<Expr>> globalExistentialArgs;
- for( Int ii = 0; ii < specializationArgCount; ++ii )
- {
- auto& specializationArg = specializationArgs[ii];
- switch( specializationArg.kind )
- {
- case slang::SpecializationArg::Kind::Type:
- {
- auto typeArg = asInternal(specializationArg.type);
- RefPtr<SharedTypeExpr> argExpr = new SharedTypeExpr();
- argExpr->base = TypeExp(typeArg);
- argExpr->type = QualType(getTypeType(typeArg));
-
- globalExistentialArgs.add(argExpr);
- }
- break;
-
- default:
- return SLANG_E_INVALID_ARG;
- }
- }
-
- DiagnosticSink sink(getSourceManager());
- auto specializedProgram = _createSpecializedProgram(
- this,
- unspecializedProgram,
- globalGenericArgs,
- globalExistentialArgs,
- &sink);
- sink.getBlobIfNeeded(outDiagnostics);
-
- if(!specializedProgram)
- return SLANG_FAIL;
-
- *outSpecializedProgram = ComPtr<slang::IProgram>(asExternal(specializedProgram)).detach();
- return SLANG_OK;
- }
-
- /// Specialize an entry point that was checked by the front-end, based on generic arguments.
+ /// Specialize an entry point that was checked by the front-end, based on specialization arguments.
///
- /// If the end-to-end compile request included generic argument strings
+ /// If the end-to-end compile request included specialization argument strings
/// for this entry point, then they will be parsed, checked, and used
/// as arguments to the generic entry point.
///
/// Returns a specialized entry point if everything worked as expected.
/// Returns null and diagnoses errors if anything goes wrong.
///
- RefPtr<EntryPoint> createSpecializedEntryPoint(
+ RefPtr<ComponentType> createSpecializedEntryPoint(
EndToEndCompileRequest* endToEndReq,
EntryPoint* unspecializedEntryPoint,
EndToEndCompileRequest::EntryPointInfo const& entryPointInfo)
@@ -11359,33 +11395,35 @@ static bool doesParameterMatch(
// If the user specified generic arguments for the entry point,
// then we will need to parse the arguments first.
//
- List<RefPtr<Expr>> genericArgs;
- parseGenericArgStrings(
- endToEndReq,
- entryPointInfo.genericArgStrings,
- genericArgs);
-
- List<RefPtr<Expr>> existentialArgs;
- parseGenericArgStrings(
+ List<RefPtr<Expr>> specializationArgExprs;
+ parseSpecializationArgStrings(
endToEndReq,
- entryPointInfo.existentialArgStrings,
- existentialArgs);
+ entryPointInfo.specializationArgStrings,
+ specializationArgExprs);
// Next we specialize the entry point function given the parsed
// generic argument expressions.
//
auto entryPoint = createSpecializedEntryPoint(
unspecializedEntryPoint,
- genericArgs,
- existentialArgs,
+ specializationArgExprs,
sink);
return entryPoint;
}
- /// Create a specialized program based on the given compile request.
+ /// Create a specialized component type for the global scope of the given compile request.
+ ///
+ /// The specialized program will be consistent with that created by
+ /// `createUnspecializedGlobalComponentType`, and will simply fill in
+ /// its specialization parameters with the arguments (if any) supllied
+ /// as part fo the end-to-end compile request.
+ ///
+ /// The layout of the new component type will be consistent with that
+ /// of the original *if* there are no global generic type parameters
+ /// (only interface/existential parameters).
///
- RefPtr<Program> createSpecializedProgram(
+ RefPtr<ComponentType> createSpecializedGlobalComponentType(
EndToEndCompileRequest* endToEndReq)
{
// The compile request must have already completed front-end processing,
@@ -11393,36 +11431,32 @@ static bool doesParameterMatch(
// to parse and check any generic arguments that are being supplied for
// global or entry-point generic parameters.
//
- auto unspecializedProgram = endToEndReq->getUnspecializedProgram();
+ auto unspecializedProgram = endToEndReq->getUnspecializedGlobalComponentType();
auto linkage = endToEndReq->getLinkage();
+ auto sink = endToEndReq->getSink();
- // First, let's parse the generic argument strings that were
- // provided via the API, so taht we can match them
+ // First, let's parse the specialization argument strings that were
+ // provided via the API, so that we can match them
// against what was declared in the program.
//
- List<RefPtr<Expr>> globalGenericArgs;
- parseGenericArgStrings(
+ List<RefPtr<Expr>> globalSpecializationArgs;
+ parseSpecializationArgStrings(
endToEndReq,
- endToEndReq->globalGenericArgStrings,
- globalGenericArgs);
+ endToEndReq->globalSpecializationArgStrings,
+ globalSpecializationArgs);
- // Also handle global existential type arguments.
- List<RefPtr<Expr>> globalExistentialArgs;
- parseGenericArgStrings(
- endToEndReq,
- endToEndReq->globalExistentialSlotArgStrings,
- globalExistentialArgs);
+ // Don't proceed further if anything failed to parse.
+ if(sink->GetErrorCount())
+ return nullptr;
// Now we create the initial specialized program by
// applying the global generic arguments (if any) to the
// unspecialized program.
//
- auto sink = endToEndReq->getSink();
auto specializedProgram = _createSpecializedProgramImpl(
- endToEndReq->getLinkage(),
+ linkage,
unspecializedProgram,
- globalGenericArgs,
- globalExistentialArgs,
+ globalSpecializationArgs,
sink);
// If anything went wrong with the global generic
@@ -11450,33 +11484,69 @@ static bool doesParameterMatch(
endToEndReq->entryPoints.setCount(entryPointCount);
}
- Index entryPointCounter = 0;
+ return specializedProgram;
+ }
- for( auto unspecializedEntryPointGroup : unspecializedProgram->getEntryPointGroups() )
- {
- List<RefPtr<EntryPoint>> specializedEntryPoints;
- for( auto unspecializedEntryPoint : unspecializedEntryPointGroup->getEntryPoints() )
- {
- Index entryPointIndex = entryPointCounter++;
- auto& entryPointInfo = endToEndReq->entryPoints[entryPointIndex];
+ /// Create a specialized program based on the given compile request.
+ ///
+ /// The specialized program created here includes both the global
+ /// scope for all the translation units involved and all the entry
+ /// points, and it also includes any specialization arguments
+ /// that were supplied.
+ ///
+ /// It is important to note that this function specializes
+ /// the global scope and the entry points in isolation and then
+ /// composes them, and that this can lead to different layout
+ /// from the result of `createUnspecializedGlobalAndEntryPointsComponentType`.
+ ///
+ /// If we have a module `M` with entry point `E`, and each has one
+ /// specialization parameter, then `createUnspecialized...` will yield:
+ ///
+ /// compose(M,E)
+ ///
+ /// That composed type will have two specialization parameters (the one
+ /// from `M` plus the one from `E`) and so we might specialize it to get:
+ ///
+ /// specialize(compose(M,E), X, Y)
+ ///
+ /// while if we use `createSpecialized...` we will get:
+ ///
+ /// compose(specialize(M,X), specialize(E,Y))
+ ///
+ /// While these options are semantically equivalent, they would not lay
+ /// out the same way in memory.
+ ///
+ /// There are many reasons why an application might prefer one over the
+ /// other, and an application that cares should use the more explicit
+ /// APIs to construct what they want. The behavior of this function
+ /// is just to provide a reasonable default for use by end-to-end
+ /// compilation (e.g., from the command line).
+ ///
+ RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType(
+ EndToEndCompileRequest* endToEndReq)
+ {
+ auto specializedGlobalComponentType = endToEndReq->getSpecializedGlobalComponentType();
- auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo);
- specializedEntryPoints.add(specializedEntryPoint);
- }
+ List<RefPtr<ComponentType>> allComponentTypes;
+ allComponentTypes.add(specializedGlobalComponentType);
- RefPtr<EntryPointGroup> specializedEntryPointGroup = EntryPointGroup::create(linkage, specializedEntryPoints, endToEndReq->getSink());
- specializedProgram->addEntryPointGroup(specializedEntryPointGroup);
- }
+ auto unspecializedGlobalAndEntryPointsComponentType = endToEndReq->getUnspecializedGlobalAndEntryPointsComponentType();
+ auto entryPointCount = unspecializedGlobalAndEntryPointsComponentType->getEntryPointCount();
- // Finalize the information for the specialized program,
- // now that we have computed its entry point list, etc.
- //
- specializedProgram->_collectShaderParams(sink);
- specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink);
+ for(Index ii = 0; ii < entryPointCount; ++ii)
+ {
+ auto& entryPointInfo = endToEndReq->entryPoints[ii];
+ auto unspecializedEntryPoint = unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii);
- return specializedProgram;
+ auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo);
+ allComponentTypes.add(specializedEntryPoint);
+ }
+
+ RefPtr<ComponentType> composed = CompositeComponentType::create(endToEndReq->getLinkage(), allComponentTypes);
+ return composed;
}
+
void checkTranslationUnit(
TranslationUnitRequest* translationUnit)
{
@@ -11488,6 +11558,8 @@ static bool doesParameterMatch(
// checking that is required on all declarations
// in the translation unit.
visitor.checkDecl(translationUnit->getModuleDecl());
+
+ translationUnit->getModule()->_collectShaderParams();
}
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 64afef136..1c0bed065 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -199,10 +199,12 @@ namespace Slang
//
RefPtr<EntryPoint> EntryPoint::create(
+ Linkage* linkage,
DeclRef<FuncDecl> funcDeclRef,
Profile profile)
{
RefPtr<EntryPoint> entryPoint = new EntryPoint(
+ linkage,
funcDeclRef.GetName(),
profile,
funcDeclRef);
@@ -210,10 +212,12 @@ namespace Slang
}
RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough(
+ Linkage* linkage,
Name* name,
Profile profile)
{
RefPtr<EntryPoint> entryPoint = new EntryPoint(
+ linkage,
name,
profile,
DeclRef<FuncDecl>());
@@ -221,103 +225,93 @@ namespace Slang
}
EntryPoint::EntryPoint(
+ Linkage* linkage,
Name* name,
Profile profile,
DeclRef<FuncDecl> funcDeclRef)
- : m_name(name)
+ : ComponentType(linkage)
+ , m_name(name)
, m_profile(profile)
, m_funcDeclRef(funcDeclRef)
{
- // In order for later code generation to work, we need to track what
- // modules each entry point depends on. We will build up the dependency
- // list here when an `EntryPoint` gets created.
- //
- // We know an entry point depends on the module that declared the
- // entry-point function itself.
- //
- // Note: we are carefully handling the case where `module` could
- // be null, becase of "dummy" entry points created for pass-through
- // compilation.
+ // Collect any specialization parameters used by the entry point
//
- if(auto module = getModule())
+ _collectShaderParams();
+ }
+
+ Module* EntryPoint::getModule()
+ {
+ return Slang::getModule(getFuncDecl());
+ }
+
+ Index EntryPoint::getSpecializationParamCount()
+ {
+ return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount();
+ }
+
+ SpecializationParam const& EntryPoint::getSpecializationParam(Index index)
+ {
+ auto genericParamCount = m_genericSpecializationParams.getCount();
+ if(index < genericParamCount)
{
- m_dependencyList.addDependency(module);
+ return m_genericSpecializationParams[index];
}
- //
- // TODO: We also need to include the modules needed by any generic
- // arguments in the dependency list, since in the general case they
- // might come from modules other than the one defining the entry point.
+ else
+ {
+ return m_existentialSpecializationParams[index - genericParamCount];
+ }
+ }
- // The following is a bit of a hack.
- //
- // Back-end code generation relies on us having computed layouts for all tagged
- // unions that end up being used in the code, which means we need a way to find
- // all such types that get used in a program (and the stuff it imports).
- //
- // For now we are assuming a tagged union type only comes into existence
- // as a (top-level) argument for a generic type parameter, so that we
- // can check for them here and cache them on the entry point.
+ Index EntryPoint::getRequirementCount()
+ {
+ // The only requirement of an entry point is the module that contains it.
//
- // A longer-term strategy might need to consider any (tagged or untagged)
- // union types that get used inside of a module, and also take
- // those lists into account.
+ // TODO: We will eventually want to support the case of an entry
+ // point nested in a `struct` type, in which case there should be
+ // a single requirement representing that outer type (so that multiple
+ // entry points nested under the same type can share the storage
+ // for parameters at that scope).
+
+ // Note: the defensive coding is here because the
+ // "dummy" entry points we create for pass-through
+ // compilation will not have an associated module.
//
- // An even longer-term strategy would be to allow type layout to
- // be performed on IR types, so taht we don't need to have front-end
- // code worrying about this stuff.
- //
- for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer )
+ if( auto module = getModule() )
{
- if( auto genericSubst = as<GenericSubstitution>(subst) )
- {
- for( auto arg : genericSubst->args )
- {
- if( auto taggedUnionType = as<TaggedUnionType>(arg) )
- {
- m_taggedUnionTypes.add(taggedUnionType);
- }
- }
- }
+ return 1;
}
-
- // Collect any existential-type parameters used by the entry point
- //
- _collectShaderParams();
+ return 0;
}
- Module* EntryPoint::getModule()
+ RefPtr<ComponentType> EntryPoint::getRequirement(Index index)
{
- return Slang::getModule(getFuncDecl());
+ SLANG_UNUSED(index);
+ SLANG_ASSERT(index == 0);
+ SLANG_ASSERT(getModule());
+ return getModule();
}
- Linkage* EntryPoint::getLinkage()
+ void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
{
- return getModule()->getLinkage();
+ visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo));
}
- //
- // EntryPointGroup
- //
-
- RefPtr<EntryPointGroup> EntryPointGroup::create(
- Linkage* linkage,
- List<RefPtr<EntryPoint>> const& entryPoints,
- DiagnosticSink* sink)
+ List<Module*> const& EntryPoint::getModuleDependencies()
{
- RefPtr<EntryPointGroup> group = new EntryPointGroup(linkage);
-
- for( auto entryPoint : entryPoints )
- {
- for( auto module : entryPoint->getModuleDependencies() )
- {
- group->m_dependencyList.addDependency(module);
- }
- group->m_entryPoints.add(entryPoint);
- }
+ if(auto module = getModule())
+ return module->getModuleDependencies();
- group->_collectShaderParams(sink);
+ static List<Module*> empty;
+ return empty;
+ }
- return group;
+ List<String> const& EntryPoint::getFilePathDependencies()
+ {
+ if(auto module = getModule())
+ return getModule()->getFilePathDependencies();
+
+ static List<String> empty;
+ return empty;
}
//
@@ -1949,7 +1943,7 @@ SlangResult dissassembleDXILUsingDXC(
TargetRequest* targetReq,
Int entryPointIndex)
{
- auto program = compileRequest->getSpecializedProgram();
+ auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType();
auto targetProgram = program->getTargetProgram(targetReq);
auto backEndReq = compileRequest->getBackEndReq();
@@ -2020,7 +2014,7 @@ SlangResult dissassembleDXILUsingDXC(
return result;
RefPtr<BackEndCompileRequest> backEndRequest = new BackEndCompileRequest(
- m_program->getLinkageImpl(),
+ m_program->getLinkage(),
sink,
m_program);
@@ -2083,7 +2077,7 @@ SlangResult dissassembleDXILUsingDXC(
if (compileRequest->isCommandLineCompile)
{
auto linkage = compileRequest->getLinkage();
- auto program = compileRequest->getSpecializedProgram();
+ auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType();
for (auto targetReq : linkage->targets)
{
Index entryPointCount = program->getEntryPointCount();
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index b7f62ae5b..46b257fe3 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -116,7 +116,6 @@ namespace Slang
class Linkage;
class Module;
- class Program;
class FrontEndCompileRequest;
class BackEndCompileRequest;
class EndToEndCompileRequest;
@@ -146,13 +145,19 @@ namespace Slang
struct ShaderParamInfo
{
DeclRef<VarDeclBase> paramDeclRef;
- UInt firstExistentialTypeSlot = 0;
- UInt existentialTypeSlotCount = 0;
+ Int firstSpecializationParamIndex = 0;
+ Int specializationParamCount = 0;
};
/// Extended information specific to global shader parameters
struct GlobalShaderParamInfo : ShaderParamInfo
{
+ // TODO: This type should be eliminated if/when we remove
+ // support for compilation with multiple translation units
+ // that all declare the "same" shader parameter (e.g., a
+ // `cbuffer`) and expect those duplicate declarations
+ // to get the same parameter binding/layout.
+
// Additional global-scope declarations that are conceptually
// declaring the "same" parameter as the `paramDeclRef`.
List<DeclRef<VarDeclBase>> additionalParamDeclRefs;
@@ -203,7 +208,7 @@ namespace Slang
{
public:
/// Get the list of modules that are depended on.
- List<RefPtr<Module>> const& getModuleList() { return m_moduleList; }
+ List<Module*> const& getModuleList() { return m_moduleList; }
/// Add a module and everything it depends on to the list.
void addDependency(Module* module);
@@ -214,8 +219,8 @@ namespace Slang
private:
void _addDependency(Module* module);
- List<RefPtr<Module>> m_moduleList;
- HashSet<Module*> m_moduleSet;
+ List<Module*> m_moduleList;
+ HashSet<Module*> m_moduleSet;
};
/// Tracks an unordered list of filesystem paths that something depends on
@@ -245,6 +250,375 @@ namespace Slang
HashSet<String> m_filePathSet;
};
+ class EntryPoint;
+
+ class ComponentType;
+ class ComponentTypeVisitor;
+
+ /// Base class for "component types" that represent the pieces a final
+ /// shader program gets linked together from.
+ ///
+ class ComponentType : public RefObject, public slang::IComponentType
+ {
+ public:
+ //
+ // ISlangUnknown interface
+ //
+
+ SLANG_REF_OBJECT_IUNKNOWN_ALL;
+ ISlangUnknown* getInterface(Guid const& guid);
+
+ //
+ // slang::IComponentType interface
+ //
+
+ SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE;
+ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE;
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE;
+ SLANG_NO_THROW IComponentType* SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE;
+
+ /// Get the linkage (aka "session" in the public API) for this component type.
+ Linkage* getLinkage() { return m_linkage; }
+
+ /// Get the target-specific version of this program for the given `target`.
+ ///
+ /// The `target` must be a target on the `Linkage` that was used to create this program.
+ TargetProgram* getTargetProgram(TargetRequest* target);
+
+ /// Get the number of entry points linked into this component type.
+ virtual Index getEntryPointCount() = 0;
+
+ /// Get one of the entry points linked into this component type.
+ virtual RefPtr<EntryPoint> getEntryPoint(Index index) = 0;
+
+ /// Get the number of global shader parameters linked into this component type.
+ virtual Index getShaderParamCount() = 0;
+
+ /// Get one of the global shader parametesr linked into this component type.
+ virtual GlobalShaderParamInfo getShaderParam(Index index) = 0;
+
+ /// Get the number of (unspecialized) specialization parameters for the component type.
+ virtual Index getSpecializationParamCount() = 0;
+
+ /// Get the specialization parameter at `index`.
+ virtual SpecializationParam const& getSpecializationParam(Index index) = 0;
+
+ /// Get the number of "requirements" that this component type has.
+ ///
+ /// A requirement represents another component type that this component
+ /// needs in order to function correctly. For example, the dependency
+ /// of one module on another module that it `import`s is represented
+ /// as a requirement, as is the dependency of an entry point on the
+ /// module that defines it.
+ ///
+ virtual Index getRequirementCount() = 0;
+
+ /// Get the requirement at `index`.
+ virtual RefPtr<ComponentType> getRequirement(Index index) = 0;
+
+ /// Parse a type from a string, in the context of this component type.
+ ///
+ /// Any names in the string will be resolved using the modules
+ /// referenced by the program.
+ ///
+ /// On an error, returns null and reports diagnostic messages
+ /// to the provided `sink`.
+ ///
+ /// TODO: This function shouldn't be on the base class, since
+ /// it only really makes sense on `Module` and (as a compatibility
+ /// feature) on `LegacyProgram`.
+ ///
+ Type* getTypeFromString(
+ String const& typeStr,
+ DiagnosticSink* sink);
+
+ /// Get a list of modules that this component type depends on.
+ ///
+ virtual List<Module*> const& getModuleDependencies() = 0;
+
+ /// Get the full list of filesystem paths this component type depends on.
+ ///
+ virtual List<String> const& getFilePathDependencies() = 0;
+
+ /// Callback for use with `enumerateIRModules`
+ typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData);
+
+ /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type.
+ void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData);
+
+ /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type.
+ template<typename F>
+ void enumerateIRModules(F const& callback)
+ {
+ struct Helper
+ {
+ static void helper(IRModule* irModule, void* userData)
+ {
+ (*(F*)userData)(irModule);
+ }
+ };
+ enumerateIRModules(&Helper::helper, (void*)&callback);
+ }
+
+ /// Callback for use with `enumerateModules`
+ typedef void (*EnumerateModulesCallback)(Module* module, void* userData);
+
+ /// Invoke `callback` on all the modules that are (transitively) linked into this component type.
+ void enumerateModules(EnumerateModulesCallback callback, void* userData);
+
+ /// Invoke `callback` on all the modules that are (transitively) linked into this component type.
+ template<typename F>
+ void enumerateModules(F const& callback)
+ {
+ struct Helper
+ {
+ static void helper(Module* module, void* userData)
+ {
+ (*(F*)userData)(module);
+ }
+ };
+ enumerateModules(&Helper::helper, (void*)&callback);
+ }
+
+ /// Side-band information generated when specializing this component type.
+ ///
+ /// Difference subclasses of `ComponentType` are expected to create their
+ /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`.
+ /// Later, whenever we want to use a specialized component type we will
+ /// also have the `SpecializationInfo` available and will expect it to
+ /// have the correct (subclass-specific) type.
+ ///
+ class SpecializationInfo : public RefObject
+ {
+ };
+
+ /// Validate the given specialization `args` and compute any side-band specialization info.
+ ///
+ /// Any errors will be reported to `sink`, which can thus be used to test
+ /// if the operation was successful.
+ ///
+ /// A null return value is allowed, since not all subclasses require
+ /// custom side-band specialization information.
+ ///
+ /// This function is an implementation detail of `specialize()`.
+ ///
+ virtual RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) = 0;
+
+ /// Validate the given specialization `args` and compute any side-band specialization info.
+ ///
+ /// Any errors will be reported to `sink`, which can thus be used to test
+ /// if the operation was successful.
+ ///
+ /// A null return value is allowed, since not all subclasses require
+ /// custom side-band specialization information.
+ ///
+ /// This function is an implementation detail of `specialize()`.
+ ///
+ RefPtr<SpecializationInfo> _validateSpecializationArgs(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
+ {
+ if(argCount == 0) return nullptr;
+ return _validateSpecializationArgsImpl(args, argCount, sink);
+ }
+
+ /// Specialize this component type given `specializationArgs`
+ ///
+ /// Any diagnostics will be reported to `sink`, which can be used
+ /// to determine if the operation was successful. It is allowed
+ /// for this operation to have a non-null return even when an
+ /// error is ecnountered.
+ ///
+ RefPtr<ComponentType> specialize(
+ SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ DiagnosticSink* sink);
+
+ /// Invoke `visitor` on this component type, using the appropriate dynamic type.
+ ///
+ /// This function implements the "visitor pattern" for `ComponentType`.
+ ///
+ /// If the `specializationInfo` argument is non-null, it must be specialization
+ /// information generated for this specific component type by `_validateSpecializationArgs`.
+ /// In that case, appropriately-typed specialization information will be passed
+ /// when invoking the `visitor`.
+ ///
+ virtual void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) = 0;
+
+ protected:
+ ComponentType(Linkage* linkage);
+
+ private:
+ RefPtr<Linkage> m_linkage;
+
+ // Cache of target-specific programs for each target.
+ Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms;
+
+ // Any types looked up dynamically using `getTypeFromString`
+ //
+ // TODO: Remove this. Type lookup should only be supported on `Module`s.
+ //
+ Dictionary<String, RefPtr<Type>> m_types;
+ };
+
+ /// A component type built up from other component types.
+ class CompositeComponentType : public ComponentType
+ {
+ public:
+ static RefPtr<ComponentType> create(
+ Linkage* linkage,
+ List<RefPtr<ComponentType>> const& childComponents);
+
+ List<RefPtr<ComponentType>> const& getChildComponents() { return m_childComponents; };
+ Index getChildComponentCount() { return m_childComponents.getCount(); }
+ RefPtr<ComponentType> getChildComponent(Index index) { return m_childComponents[index]; }
+
+ Index getEntryPointCount() SLANG_OVERRIDE;
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE;
+
+ Index getShaderParamCount() SLANG_OVERRIDE;
+ GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE;
+
+ Index getSpecializationParamCount() SLANG_OVERRIDE;
+ SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE;
+
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
+
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE;
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE;
+
+ class CompositeSpecializationInfo : public SpecializationInfo
+ {
+ public:
+ List<RefPtr<SpecializationInfo>> childInfos;
+ };
+
+ protected:
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
+
+
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE;
+
+ private:
+ CompositeComponentType(
+ Linkage* linkage,
+ List<RefPtr<ComponentType>> const& childComponents);
+
+ List<RefPtr<ComponentType>> m_childComponents;
+
+ // The following arrays hold the concatenated entry points, parameters,
+ // etc. from the child components. This approach allows for reasonably
+ // fast (constant time) access through operations like `getShaderParam`,
+ // but means that the memory usage of a composite is proportional to
+ // the sum of the memory usage of the children, rather than being fixed
+ // by the number of children (as it would be if we just stored
+ // `m_childComponents`).
+ //
+ // TODO: We could conceivably build some O(numChildren) arrays that
+ // support binary-search to provide logarithmic-time access to entry
+ // points, parameters, etc. while giving a better overall memory usage.
+ //
+ List<EntryPoint*> m_entryPoints;
+ List<GlobalShaderParamInfo> m_shaderParams;
+ List<SpecializationParam> m_specializationParams;
+ List<ComponentType*> m_requirements;
+
+ ModuleDependencyList m_moduleDependencyList;
+ FilePathDependencyList m_filePathDependencyList;
+ };
+
+ /// A component type created by specializing another component type.
+ class SpecializedComponentType : public ComponentType
+ {
+ public:
+ SpecializedComponentType(
+ ComponentType* base,
+ SpecializationInfo* specializationInfo,
+ List<SpecializationArg> const& specializationArgs,
+ DiagnosticSink* sink);
+
+ /// Get the base (unspecialized) component type that is being specialized.
+ RefPtr<ComponentType> getBaseComponentType() { return m_base; }
+
+ RefPtr<SpecializationInfo> getSpecializationInfo() { return m_specializationInfo; }
+
+ /// Get the number of arguments supplied for existential type parameters.
+ ///
+ /// Note that the number of arguments may not match the number of parameters.
+ /// In particular, an unspecialized entry point may have many parameters, but zero arguments.
+ Index getSpecializationArgCount() { return m_specializationArgs.getCount(); }
+
+ /// Get the existential type argument (type and witness table) at `index`.
+ SpecializationArg const& getSpecializationArg(Index index) { return m_specializationArgs[index]; }
+
+ /// Get an array of all existential type arguments.
+ SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); }
+
+ Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); }
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); }
+
+ Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); }
+ GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); }
+
+ Index getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
+ SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); static SpecializationParam dummy; return dummy; }
+
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
+
+ /// TODO: These should include requirements/dependencies for the types
+ /// referenced in the specialization arguments...
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_base->getModuleDependencies(); }
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_base->getFilePathDependencies(); }
+
+ /// Get a list of tagged-union types referenced by the specialization parameters.
+ List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; }
+
+ RefPtr<IRModule> getIRModule() { return m_irModule; }
+
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
+
+ protected:
+
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(args);
+ SLANG_UNUSED(argCount);
+ SLANG_UNUSED(sink);
+ return nullptr;
+ }
+
+ private:
+ RefPtr<ComponentType> m_base;
+ RefPtr<SpecializationInfo> m_specializationInfo;
+ SpecializationArgs m_specializationArgs;
+ RefPtr<IRModule> m_irModule;
+
+ // Any tagged union types that were referenced by the specialization arguments.
+ List<RefPtr<TaggedUnionType>> m_taggedUnionTypes;
+
+ };
+
/// Describes an entry point for the purposes of layout and code generation.
///
/// This class also tracks any generic arguments to the entry point,
@@ -255,11 +629,12 @@ namespace Slang
/// `getName()` and `getProfile()` methods should be expected to
/// return useful data on pass-through entry points.
///
- class EntryPoint : public RefObject
+ class EntryPoint : public ComponentType
{
public:
/// Create an entry point that refers to the given function.
static RefPtr<EntryPoint> create(
+ Linkage* linkage,
DeclRef<FuncDecl> funcDeclRef,
Profile profile);
@@ -286,55 +661,66 @@ namespace Slang
/// Get the module that contains the entry point.
Module* getModule();
- /// Get the linkage that contains the module for this entry point.
- Linkage* getLinkage();
-
/// Get a list of modules that this entry point depends on.
///
/// This will include the module that defines the entry point (see `getModule()`),
/// but may also include modules that are required by its generic type arguments.
///
- List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); }
-
- /// Get a list of tagged-union types referenced by the entry point's generic parameters.
- List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; }
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); }
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE; // { return getModule()->getFilePathDependencies(); }
/// Create a dummy `EntryPoint` that is only usable for pass-through compilation.
static RefPtr<EntryPoint> createDummyForPassThrough(
+ Linkage* linkage,
Name* name,
Profile profile);
/// Get the number of existential type parameters for the entry point.
- Index getExistentialTypeParamCount() { return m_existentialSlots.paramTypes.getCount(); }
+ Index getSpecializationParamCount() SLANG_OVERRIDE;
/// Get the existential type parameter at `index`.
- Type* getExistentialTypeParam(Index index) { return m_existentialSlots.paramTypes[index]; }
+ SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE;
- /// Get the number of arguments supplied for existential type parameters.
- ///
- /// Note that the number of arguments may not match the number of parameters.
- /// In particular, an unspecialized entry point may have many parameters, but zero arguments.
- Index getExistentialTypeArgCount() { return m_existentialSlots.args.getCount(); }
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
- /// Get the existential type argument (type and witness table) at `index`.
- ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_existentialSlots.args[index]; }
+ SpecializationParams const& getExistentialSpecializationParams() { return m_existentialSpecializationParams; }
- /// Get an array of all existential type arguments.
- ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_existentialSlots.args.getBuffer(); }
+ Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); }
+ Index getExistentialSpecializationParamCount() { return m_existentialSpecializationParams.getCount(); }
/// Get an array of all entry-point shader parameters.
List<ShaderParamInfo> const& getShaderParams() { return m_shaderParams; }
- void _specializeExistentialTypeParams(
- List<RefPtr<Expr>> const& args,
- DiagnosticSink* sink);
+ Index getEntryPointCount() SLANG_OVERRIDE { return 1; };
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; }
+
+ Index getShaderParamCount() SLANG_OVERRIDE { return 0; }
+ GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return GlobalShaderParamInfo(); }
+
+ class EntryPointSpecializationInfo : public SpecializationInfo
+ {
+ public:
+ DeclRef<FuncDecl> specializedFuncDeclRef;
+ List<ExpandedSpecializationArg> existentialSpecializationArgs;
+ };
+
+ protected:
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
+
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE;
private:
EntryPoint(
+ Linkage* linkage,
Name* name,
Profile profile,
DeclRef<FuncDecl> funcDeclRef);
+ void _collectGenericSpecializationParamsRec(Decl* decl);
void _collectShaderParams();
// The name of the entry point function (e.g., `main`)
@@ -345,8 +731,8 @@ namespace Slang
//
DeclRef<FuncDecl> m_funcDeclRef;
- /// The existential/interface slots associated with the entry point parameter scope.
- ExistentialTypeSlots m_existentialSlots;
+ SpecializationParams m_genericSpecializationParams;
+ SpecializationParams m_existentialSpecializationParams;
/// Information about entry-point parameters
List<ShaderParamInfo> m_shaderParams;
@@ -362,61 +748,6 @@ namespace Slang
// intrinsic to the entry point.
//
Profile m_profile;
-
- // Any tagged union types that were referenced by the generic arguments of the entry point.
- List<RefPtr<TaggedUnionType>> m_taggedUnionTypes;
-
- // Modules the entry point depends on.
- ModuleDependencyList m_dependencyList;
- };
-
- class EntryPointGroup : public RefObject
- {
- public:
- static RefPtr<EntryPointGroup> create(
- Linkage* linkage,
- List<RefPtr<EntryPoint>> const& entryPoints,
- DiagnosticSink* sink);
-
- Linkage* getLinkageImpl() { return m_linkage; }
-
- /// Get the number of entry points in the group
- Index getEntryPointCount() { return m_entryPoints.getCount(); }
-
- /// Get the entry point at the given `index`.
- RefPtr<EntryPoint> getEntryPoint(Index index) { return m_entryPoints[index]; }
-
- /// Get the full ist of entry points in the group.
- List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; }
-
- /// Get a list of modules that this entry point group depends on.
- ///
- /// This will include the dependencies of all of the entry points in the group.
- ///
- List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); }
-
- /// Get an array of all entry-point-group shader parameters.
- List<ShaderParamInfo> const& getShaderParams() { return m_shaderParams; }
-
- private:
- EntryPointGroup(Linkage* linkage)
- : m_linkage(linkage)
- {}
-
- void _collectShaderParams(DiagnosticSink* sink);
-
- Linkage* m_linkage;
- List<RefPtr<EntryPoint>> m_entryPoints;
-
- /// Information about shader parameters to be associated with the entry-point group itself.
- ///
- /// This list captures parameters that logically belong to the group itself, rather than
- /// to any specific entry point in the group.
- ///
- List<ShaderParamInfo> m_shaderParams;
-
- /// Modules the entry point group depends on.
- ModuleDependencyList m_dependencyList;
};
enum class PassThroughMode : SlangPassThrough
@@ -439,19 +770,52 @@ namespace Slang
/// may span multiple Slang source files), and provides access
/// to both the AST and IR representations of that code.
///
- class Module : public RefObject, public slang::IModule
+ class Module : public ComponentType, public slang::IModule
{
+ typedef ComponentType Super;
+
public:
SLANG_REF_OBJECT_IUNKNOWN_ALL
ISlangUnknown* getInterface(const Guid& guid);
+
+ // Forward `IComponentType` methods
+
+ SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE
+ {
+ return Super::getSession();
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getLayout(targetIndex, outDiagnostics);
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ }
+
+ SLANG_NO_THROW IComponentType* SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ ISlangBlob** outDiagnostics) SLANG_OVERRIDE
+ {
+ return Super::specialize(specializationArgs, specializationArgCount, outDiagnostics);
+ }
+
+ //
+
/// Create a module (initially empty).
Module(Linkage* linkage);
- /// Get the parent linkage of this module.
- Linkage* getLinkage() { return m_linkage; }
-
/// Get the AST for the module (if it has been parsed)
ModuleDecl* getModuleDecl() { return m_moduleDecl; }
@@ -459,7 +823,7 @@ namespace Slang
IRModule* getIRModule() { return m_irModule; }
/// Get the list of other modules this module depends on
- List<RefPtr<Module>> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); }
+ List<Module*> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); }
/// Get the list of filesystem paths this module depends on
List<String> const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); }
@@ -474,7 +838,7 @@ namespace Slang
///
/// This should only be called once, during creation of the module.
///
- void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; }
+ void setModuleDecl(ModuleDecl* moduleDecl);// { m_moduleDecl = moduleDecl; }
/// Set the IR for this module.
///
@@ -482,16 +846,65 @@ namespace Slang
///
void setIRModule(IRModule* irModule) { m_irModule = irModule; }
- private:
- // The parent linkage
- Linkage* m_linkage = nullptr;
+ Index getEntryPointCount() SLANG_OVERRIDE { return 0; }
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; }
+
+ Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); }
+ GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; }
+
+ Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); }
+ SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; }
+
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
+
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); }
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); }
+
+ /// Collect information on the shader parameters of the module.
+ ///
+ /// This method should only be called once, after the core
+ /// structured of the module (its AST and IR) have been created,
+ /// and before any of the `ComponentType` APIs are used.
+ ///
+ /// TODO: We might eventually consider a non-stateful approach
+ /// to constructing a `Module`.
+ ///
+ void _collectShaderParams();
+ class ModuleSpecializationInfo : public SpecializationInfo
+ {
+ public:
+ struct GenericArgInfo
+ {
+ RefPtr<Decl> paramDecl;
+ RefPtr<Val> argVal;
+ };
+
+ List<GenericArgInfo> genericArgs;
+ List<ExpandedSpecializationArg> existentialArgs;
+ };
+
+ protected:
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
+
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE;
+
+ private:
// The AST for the module
RefPtr<ModuleDecl> m_moduleDecl;
// The IR for the module
RefPtr<IRModule> m_irModule = nullptr;
+ List<GlobalShaderParamInfo> m_shaderParams;
+ SpecializationParams m_specializationParams;
+
+ List<Module*> m_requirements;
+
// List of modules this module depends on
ModuleDependencyList m_moduleDependencyList;
@@ -650,15 +1063,10 @@ namespace Slang
SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModule(
const char* moduleName,
slang::IBlob** outDiagnostics = nullptr) override;
- SLANG_NO_THROW SlangResult SLANG_MCALL createProgram(
- slang::ProgramDesc const& desc,
- slang::IProgram** outProgram) override;
- SLANG_NO_THROW SlangResult SLANG_MCALL specializeProgram(
- slang::IProgram* program,
- SlangInt specializationArgCount,
- slang::SpecializationArg const* specializationArgs,
- slang::IProgram** outSpecializedProgram,
- ISlangBlob** outDiagnostics = nullptr) override;
+ SLANG_NO_THROW slang::IComponentType* SLANG_MCALL createCompositeComponentType(
+ slang::IComponentType* const* componentTypes,
+ SlangInt componentTypeCount,
+ ISlangBlob** outDiagnostics = nullptr) override;
SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType(
slang::TypeReflection* type,
slang::SpecializationArg const* specializationArgs,
@@ -980,202 +1388,133 @@ namespace Slang
int translationUnitIndex,
String const& path);
- Program* getProgram() { return m_program; }
+ /// Get a component type that represents the global scope of the compile request.
+ ComponentType* getGlobalComponentType() { return m_globalComponentType; }
+
+ /// Get a component type that represents the global scope of the compile request, plus the requested entry points.
+ ComponentType* getGlobalAndEntryPointsComponentType() { return m_globalAndEntryPointsComponentType; }
private:
- RefPtr<Program> m_program;
+ /// A component type that includes only the global scopes of the translation unit(s) that were compiled.
+ RefPtr<ComponentType> m_globalComponentType;
+
+ /// A component type that extends the global scopes with all of the entry points that were specified.
+ RefPtr<ComponentType> m_globalAndEntryPointsComponentType;
};
- /// A collection of code modules and entry points that are intended to be used together.
+ /// A "legacy" program composes multiple translation units from a single compile request,
+ /// and takes care to treat global declarations of the same name from different translation
+ /// units as representing the "same" parameter.
///
- /// A `Program` establishes that certain pieces of code are intended
- /// to be used togehter so that, e.g., layout can make sure to allocate
- /// space for the global shader parameters in all referenced modules.
+ /// TODO: This type only exists to support a single requirement: that multiple translation
+ /// units can be compiled in one pass and be guaranteed that the "same" parameter declared
+ /// in different translation units (hence different modules) will get the same layout.
+ /// This feature should be deprecated and removed as soon as possible, since the complexity
+ /// it creates in the codebase is not justified by its limited utility.
///
- class Program : public RefObject, public slang::IProgram
+ class LegacyProgram : public ComponentType
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL;
- ISlangUnknown* getInterface(Guid const& guid);
+ LegacyProgram(
+ Linkage* linkage,
+ List<RefPtr<TranslationUnitRequest>> const& translationUnits,
+ DiagnosticSink* sink);
- SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;
+ Index getTranslationUnitCount() { return m_translationUnits.getCount(); }
+ RefPtr<TranslationUnitRequest> getTranslationUnit(Index index) { return m_translationUnits[index]; }
- SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
- SlangInt targetIndex,
- slang::IBlob** outDiagnostics) override;
-
- SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
- SlangInt entryPointIndex,
- SlangInt targetIndex,
- slang::IBlob** outCode,
- slang::IBlob** outDiagnostics) override;
-
-
- /// Create a new program, initially empty.
- ///
- /// All code loaded into the program must come
- /// from the given `linkage`.
- Program(
- Linkage* linkage);
-
- /// Get the linkage that this program uses.
- Linkage* getLinkageImpl() { return m_linkage; }
-
- /// Get the number of entry points added to the program
- Index getEntryPointCount() { return m_entryPoints.getCount(); }
-
- /// Get the entry point at the given `index`.
- RefPtr<EntryPoint> getEntryPoint(Index index) { return m_entryPoints[index]; }
-
- /// Get the full ist of entry points on the program.
- List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; }
-
-
- Index getEntryPointGroupCount() { return m_entryPointGroups.getCount(); }
- RefPtr<EntryPointGroup> getEntryPointGroup(Index index) { return m_entryPointGroups[index]; }
- List<RefPtr<EntryPointGroup>> const& getEntryPointGroups() { return m_entryPointGroups; }
-
-
- /// Get the substitution (if any) that represents how global generics are specialized.
- RefPtr<Substitutions> getGlobalGenericSubstitution() { return m_globalGenericSubst; }
-
- /// Get the full list of modules this program depends on
- List<RefPtr<Module>> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); }
-
- /// Get the full list of filesystem paths this program depends on
- List<String> getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); }
-
- /// Get the target-specific version of this program for the given `target`.
- ///
- /// The `target` must be a target on the `Linkage` that was used to create this program.
- TargetProgram* getTargetProgram(TargetRequest* target);
-
- /// Add a module (and everything it depends on) to the list of references
- void addReferencedModule(Module* module);
-
- /// Add a module (but not the things it depends on) to the list of references
- ///
- /// This is a compatiblity hack for legacy compiler behavior.
- void addReferencedLeafModule(Module* module);
+ Index getEntryPointCount() SLANG_OVERRIDE { return 0; }
+ RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; }
+ Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); }
+ GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; }
- /// Add an entry point to the program
- ///
- /// This also adds everything the entry point depends on to the list of references.
- ///
- void addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink);
+ Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); }
+ SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; }
- /// Add an entry point group to the program
- ///
- /// This also adds everything the entry point group depends on to the list of references.
- ///
- void addEntryPointGroup(EntryPointGroup* entryPointGroup);
+ Index getRequirementCount() SLANG_OVERRIDE;
+ RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
- /// Set the global generic argument substitution to use.
- void setGlobalGenericSubsitution(RefPtr<Substitutions> subst)
- {
- m_globalGenericSubst = subst;
- }
+ List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies.getModuleList(); }
+ List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_fileDependencies.getFilePathList(); }
- /// Parse a type from a string, in the context of this program.
- ///
- /// Any names in the string will be resolved using the modules
- /// referenced by the program.
- ///
- /// On an error, returns null and reports diagnostic messages
- /// to the provided `sink`.
- ///
- Type* getTypeFromString(String typeStr, DiagnosticSink* sink);
-
- /// Get the IR module that represents this program and its entry points.
- ///
- /// The IR module for a program tries to be minimal, and in the
- /// common case will only include symbols with `[import]` declarations
- /// for the entry point(s) of the program, and any types they
- /// depend on.
- ///
- /// This IR module is intended to be linked against the IR modules
- /// for all of the dependencies (see `getModuleDependencies()`) to
- /// provide complete code.
- ///
- RefPtr<IRModule> getOrCreateIRModule(DiagnosticSink* sink);
-
- /// Get the number of existential type parameters for the program.
- Index getExistentialTypeParamCount() { return m_globalExistentialSlots.paramTypes.getCount(); }
-
- /// Get the existential type parameter at `index`.
- Type* getExistentialTypeParam(Index index) { return m_globalExistentialSlots.paramTypes[index]; }
-
- /// Get the number of arguments supplied for existential type parameters.
- ///
- /// Note that the number of arguments may not match the number of parameters.
- /// In particular, an unspecialized program may have many parameters, but zero arguments.
- Index getExistentialTypeArgCount() { return m_globalExistentialSlots.args.getCount(); }
-
- /// Get the existential type argument (type and witness table) at `index`.
- ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_globalExistentialSlots.args[index]; }
-
- /// Get an array of all existential type arguments.
- ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_globalExistentialSlots.args.getBuffer(); }
-
- /// Get an array of all global shader parameters.
- List<GlobalShaderParamInfo> const& getShaderParams() { return m_shaderParams; }
+ protected:
+ void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
- void _collectShaderParams(DiagnosticSink* sink);
- void _specializeExistentialTypeParams(
- List<RefPtr<Expr>> const& args,
- DiagnosticSink* sink);
+ RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink) SLANG_OVERRIDE;
private:
+ void _collectShaderParams(DiagnosticSink* sink);
- // The linakge this program is associated with.
- //
- // Note that a `Program` keeps its associated linkage alive,
- // and not vice versa.
- //
- RefPtr<Linkage> m_linkage;
-
- // Tracking data for the list of modules dependend on
- ModuleDependencyList m_moduleDependencyList;
-
- // Tracking data for the list of filesystem paths dependend on
- FilePathDependencyList m_filePathDependencyList;
-
- // Entry points that are part of the program.
- List<RefPtr<EntryPoint> > m_entryPoints;
-
- // Entry points that are part of the program.
- List<RefPtr<EntryPointGroup> > m_entryPointGroups;
-
- // Specializations for global generic parameters (if any)
- RefPtr<Substitutions> m_globalGenericSubst;
-
- // The existential/interface slots associated with the global scope.
- ExistentialTypeSlots m_globalExistentialSlots;
+ List<RefPtr<TranslationUnitRequest>> m_translationUnits;
- /// Information about global shader parameters
+ List<EntryPoint*> m_entryPoints;
List<GlobalShaderParamInfo> m_shaderParams;
+ List<ComponentType*> m_requirements;
+ SpecializationParams m_specializationParams;
+ ModuleDependencyList m_moduleDependencies;
+ FilePathDependencyList m_fileDependencies;
+ };
- // Generated IR for this program.
- RefPtr<IRModule> m_irModule;
-
- // Cache of target-specific programs for each target.
- Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms;
+ /// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses.
+ class ComponentTypeVisitor
+ {
+ public:
+ // The following methods should be overriden in a concrete subclass
+ // to customize how it acts on each of the concrete types of component.
+ //
+ // In cases where the application wants to simply "recurse" on a
+ // composite, specialized, or legacy component type it can use
+ // the `visitChildren` methods below.
+ //
+ virtual void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0;
+ virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0;
+ virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0;
+ virtual void visitSpecialized(SpecializedComponentType* specialized) = 0;
+ virtual void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0;
- // Any types looked up dynamically using `getTypeFromString`
- Dictionary<String, RefPtr<Type>> m_types;
+ protected:
+ // These helpers can be used to recurse into the logical children of a
+ // component type, and are useful for the common case where a visitor
+ // only cares about a few leaf cases.
+ //
+ // Note that for a `LegacyProgram` the "children" in this case are the
+ // `Module`s of the translation units that make up the legacy program.
+ // In some cases this is what is desired, but in others it is incorrect
+ // to treat a legacy program as a composition of modules, and instead
+ // it should be treated directly as a leaf case. Clients should make
+ // an informed decision based on an understanding of what `LegacyProgram` is used for.
+ //
+ void visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo);
+ void visitChildren(SpecializedComponentType* specialized);
+ void visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo);
};
- /// A `Program` specialized for a particular `TargetRequest`
+ /// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest`
+ ///
+ /// TODO: This should probably be renamed to `TargetComponentType`.
+ ///
+ /// By binding a component type to a specific target, a `TargetProgram` allows
+ /// for things like layout to be computed, that fundamentally depend on
+ /// the choice of target.
+ ///
+ /// A `TargetProgram` handles request for compiled kernel code for
+ /// entry point functions. In practice, kernel code can only be
+ /// correctly generated when the underlying `ComponentType` is "fully linked"
+ /// (has no remaining unsatisfied requirements).
+ ///
class TargetProgram : public RefObject
{
public:
TargetProgram(
- Program* program,
+ ComponentType* componentType,
TargetRequest* targetReq);
/// Get the underlying program
- Program* getProgram() { return m_program; }
+ ComponentType* getProgram() { return m_program; }
/// Get the underlying target
TargetRequest* getTargetReq() { return m_targetReq; }
@@ -1232,7 +1571,7 @@ namespace Slang
private:
// The program being compiled or laid out
- Program* m_program;
+ ComponentType* m_program;
// The target that code/layout will be generated for
TargetRequest* m_targetReq;
@@ -1253,7 +1592,7 @@ namespace Slang
BackEndCompileRequest(
Linkage* linkage,
DiagnosticSink* sink,
- Program* program = nullptr);
+ ComponentType* program = nullptr);
// Should we dump intermediate results along the way, for debugging?
bool shouldDumpIntermediates = false;
@@ -1263,8 +1602,8 @@ namespace Slang
LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; }
- Program* getProgram() { return m_program; }
- void setProgram(Program* program) { m_program = program; }
+ ComponentType* getProgram() { return m_program; }
+ void setProgram(ComponentType* program) { m_program = program; }
// Should R/W images without explicit formats be assumed to have "unknown" format?
//
@@ -1273,7 +1612,7 @@ namespace Slang
bool useUnknownImageFormatAsDefault = false;
private:
- RefPtr<Program> m_program;
+ RefPtr<ComponentType> m_program;
};
/// A compile request that spans the front and back ends of the compiler
@@ -1311,11 +1650,8 @@ namespace Slang
// Should we just pass the input to another compiler?
PassThroughMode passThrough = PassThroughMode::None;
- /// Source code for the generic arguments to use for the global generic parameters of the program.
- List<String> globalGenericArgStrings;
-
- /// Types to use to fill global existential "slots"
- List<String> globalExistentialSlotArgStrings;
+ /// Source code for the specialization arguments to use for the global specialization parameters of the program.
+ List<String> globalSpecializationArgStrings;
bool shouldSkipCodegen = false;
@@ -1331,11 +1667,8 @@ namespace Slang
class EntryPointInfo : public RefObject
{
public:
- /// Source code for the generic arguments to use for the generic parameters of the entry point.
- List<String> genericArgStrings;
-
- /// Source code for the type arguments to plug into the existential type "slots" of the entry point
- List<String> existentialArgStrings;
+ /// Source code for the specialization arguments to use for the specialization parameters of the entry point.
+ List<String> specializationArgStrings;
};
List<EntryPointInfo> entryPoints;
@@ -1370,8 +1703,15 @@ namespace Slang
FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; }
BackEndCompileRequest* getBackEndReq() { return m_backEndReq; }
- Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); }
- Program* getSpecializedProgram() { return m_specializedProgram; }
+
+ ComponentType* getUnspecializedGlobalComponentType() { return getFrontEndReq()->getGlobalComponentType(); }
+ ComponentType* getUnspecializedGlobalAndEntryPointsComponentType()
+ {
+ return getFrontEndReq()->getGlobalAndEntryPointsComponentType();
+ }
+
+ ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; }
+ ComponentType* getSpecializedGlobalAndEntryPointsComponentType() { return m_specializedGlobalAndEntryPointsComponentType; }
private:
void init();
@@ -1380,8 +1720,8 @@ namespace Slang
RefPtr<Linkage> m_linkage;
DiagnosticSink m_sink;
RefPtr<FrontEndCompileRequest> m_frontEndReq;
- RefPtr<Program> m_unspecializedProgram;
- RefPtr<Program> m_specializedProgram;
+ RefPtr<ComponentType> m_specializedGlobalComponentType;
+ RefPtr<ComponentType> m_specializedGlobalAndEntryPointsComponentType;
RefPtr<BackEndCompileRequest> m_backEndReq;
// For output
@@ -1623,14 +1963,14 @@ inline slang::IModule* asExternal(Module* module)
return static_cast<slang::IModule*>(module);
}
-inline Program* asInternal(slang::IProgram* module)
+inline ComponentType* asInternal(slang::IComponentType* componentType)
{
- return static_cast<Program*>(module);
+ return static_cast<ComponentType*>(componentType);
}
-inline slang::IProgram* asExternal(Program* module)
+inline slang::IComponentType* asExternal(ComponentType* componentType)
{
- return static_cast<slang::IProgram*>(module);
+ return static_cast<slang::IComponentType*>(componentType);
}
static inline slang::ProgramLayout* asExternal(ProgramLayout* programLayout)
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e8114dc95..802c89f8b 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -240,6 +240,8 @@ DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an imm
DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'")
DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'")
+DIAGNOSTIC(30060, Error, expectedAType, "expected a type got a '$0'")
+
DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'")
DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body")
@@ -361,21 +363,22 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl
DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration")
DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.")
-DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.")
+DIAGNOSTIC(38021, Error, typeArgumentForGenericParameterDoesNotConformToInterface, "type argument `$0` for generic parameter `$1` does not conform to interface `$2`.")
DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself")
DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')")
+
DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0");
DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'")
-DIAGNOSTIC(38025, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)")
+DIAGNOSTIC(38025, Error, mismatchSpecializationArguments, "expected $0 specialization arguments ($1 provided)")
DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.")
DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)")
DIAGNOSTIC(38028, Error, existentialSlotArgNotAType, "existential slot argument $0 was not a type")
-DIAGNOSTIC(38029, Error, existentialSlotArgDoesNotConform, "existential slot argument $0 does not conform to the required interface '$1'")
+DIAGNOSTIC(38029, Error,typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'")
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.")
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index c7b5b602b..26af7b9f4 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -2459,13 +2459,13 @@ String CLikeSourceEmitter::getFuncName(IRFunc* func)
// name for an entry-point function, but other
// targets should try to use the original name.
//
- // TODO: always use `main`, and have any code
- // that wraps this know to use `main` instead
- // of the original entry-point name...
+ // TODO: always use the original name, and
+ // use the appropriate options for glslang to
+ // make it support a non-`main` name.
//
if (getSourceStyle() != SourceStyle::GLSL)
{
- return getText(entryPointLayout->entryPoint->getName());
+ return getText(entryPointLayout->getFuncDecl()->getName());
}
//
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 3b9023703..d3c852fd6 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -687,20 +687,20 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL
break;
case Stage::Geometry:
{
- if (auto attrib = entryPointLayout->entryPoint->FindModifier<MaxVertexCountAttribute>())
+ if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<MaxVertexCountAttribute>())
{
m_writer->emit("layout(max_vertices = ");
m_writer->emit(attrib->value);
m_writer->emit(") out;\n");
}
- if (auto attrib = entryPointLayout->entryPoint->FindModifier<InstanceAttribute>())
+ if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<InstanceAttribute>())
{
m_writer->emit("layout(invocations = ");
m_writer->emit(attrib->value);
m_writer->emit(") in;\n");
}
- for (auto pp : entryPointLayout->entryPoint->GetParameters())
+ for (auto pp : entryPointLayout->getFuncDecl()->GetParameters())
{
if (auto inputPrimitiveTypeModifier = pp->FindModifier<HLSLGeometryShaderInputPrimitiveTypeModifier>())
{
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 9c0f7f02b..4abd692f8 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -294,13 +294,13 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint
break;
case Stage::Geometry:
{
- if (auto attrib = entryPointLayout->entryPoint->FindModifier<MaxVertexCountAttribute>())
+ if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<MaxVertexCountAttribute>())
{
m_writer->emit("[maxvertexcount(");
m_writer->emit(attrib->value);
m_writer->emit(")]\n");
}
- if (auto attrib = entryPointLayout->entryPoint->FindModifier<InstanceAttribute>())
+ if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<InstanceAttribute>())
{
m_writer->emit("[instance(");
m_writer->emit(attrib->value);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c51040b98..d5dcaca98 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -49,36 +49,34 @@ enum class BuiltInCOp
EntryPointLayout* findEntryPointLayout(
ProgramLayout* programLayout,
- EntryPoint* entryPoint,
- EntryPointGroupLayout** outEntryPointGroupLayout = nullptr)
+ EntryPoint* entryPoint)
{
- for( auto entryPointGroupLayout : programLayout->entryPointGroups )
+ // TODO: This function shouldn't need to exist, and it
+ // somewhat hampers the capabilities of the compiler (e.g.,
+ // it isn't supported to have a single program contain
+ // two different "instances" of the same entry point).
+ //
+ // Code that cares about layouts should be looking up
+ // the entry point layout by index on a `ProgramLayout`,
+ // knowing that those indices will align with the order
+ // of entry points on the `ComponentType` for the program.
+
+ for( auto entryPointLayout : programLayout->entryPoints )
{
- for( auto entryPointLayout : entryPointGroupLayout->entryPoints )
- {
- if(entryPointLayout->entryPoint->getName() != entryPoint->getName())
- continue;
-
- // TODO: We need to be careful about this check, since it relies on
- // the profile information in the layout matching that in the request.
- //
- // What we really seem to want here is some dictionary mapping the
- // `EntryPoint` directly to the `EntryPointLayout`, and maybe
- // that is precisely what we should build...
- //
- if(entryPointLayout->profile != entryPoint->getProfile())
- continue;
-
- // TODO: can't easily filter on translation unit here...
- // Ideally the `EntryPoint` should get filled in with a pointer
- // the specific function declaration that represents the entry point.
-
- if( outEntryPointGroupLayout )
- {
- *outEntryPointGroupLayout = entryPointGroupLayout;
- }
- return entryPointLayout;
- }
+ if(entryPointLayout->entryPoint.GetName() != entryPoint->getName())
+ continue;
+
+ // TODO: We need to be careful about this check, since it relies on
+ // the profile information in the layout matching that in the request.
+ //
+ // What we really seem to want here is some dictionary mapping the
+ // `EntryPoint` directly to the `EntryPointLayout`, and maybe
+ // that is precisely what we should build...
+ //
+ if(entryPointLayout->profile != entryPoint->getProfile())
+ continue;
+
+ return entryPointLayout;
}
return nullptr;
diff --git a/source/slang/slang-ir-bind-existentials.cpp b/source/slang/slang-ir-bind-existentials.cpp
index e426e6e92..a99da3410 100644
--- a/source/slang/slang-ir-bind-existentials.cpp
+++ b/source/slang/slang-ir-bind-existentials.cpp
@@ -69,24 +69,8 @@ struct BindExistentialSlots
void processGlobalExistentialSlots()
{
- // If there are any global existential slots, we will expect
- // to find a `bindGlobalExistentialSlots` instruction at module scope.
- //
- // We will start out by finding that instruction, if it exists.
- //
- IRInst* bindGlobalExistentialSlotsInst = nullptr;
- for( auto inst : module->getGlobalInsts() )
- {
- if( inst->op == kIROp_BindGlobalExistentialSlots )
- {
- bindGlobalExistentialSlotsInst = inst;
- break;
- }
- }
-
- // Now we will start looking for global shader parameters that make
- // use of existential slots (we can determine this from their
- // layout).
+ // We will search for global shader parameters that make
+ // use of existential specialization parameters.
//
for( auto inst : module->getGlobalInsts() )
{
@@ -96,23 +80,27 @@ struct BindExistentialSlots
if(!globalParam)
continue;
- // We will delegate to a subroutine for the meat
- // of the work, since much of it can be shared
- // with the case for entry-point existential
- // parameters.
+ // We only care about global shader parameters
+ // that have existential specialization parameters,
+ // and we expect all such parameters to have a
+ // `[bindExistentialSlots(...)]` decoration that
+ // was added during IR linking.
//
- processParameter(globalParam, bindGlobalExistentialSlotsInst);
- }
+ auto bindSlotsInst = globalParam->findDecorationImpl(kIROp_BindExistentialSlotsDecoration);
+ if(!bindSlotsInst)
+ continue;
- // Once we are done looping over global shader parameters,
- // all of the relevant information from the
- // `bindGlobalExistentialSlots` instruction will have
- // been moved to the parameters themselves, so we
- // can eliminate the binding instruction.
- //
- if( bindGlobalExistentialSlotsInst )
- {
- bindGlobalExistentialSlotsInst->removeAndDeallocate();
+ replaceTypeUsingExistentialSlots(
+ globalParam,
+ bindSlotsInst->getOperandCount(),
+ bindSlotsInst->getOperands());
+
+ // Once we have propagated the information from
+ // the `[bindExistentialSlots(...)]` decoration
+ // down into the parameter's type, we no longer
+ // need the decoration.
+ //
+ bindSlotsInst->removeAndDeallocate();
}
}
@@ -150,9 +138,25 @@ struct BindExistentialSlots
// We then need to process each of the entry-point
// parameters just like we did for global parameters.
//
+ // Because the existential slot arguments for *all* of the parameters
+ // are attached in a single `[bindExistentialSlots(...)]` decoration,
+ // we need to carve them up appropriately across the parameters.
+ // The way we do this is a bit of a kludge, in that we track a
+ // single `slotOffset` and increment it for each parameter by the
+ // number of arguments it consumed.
+ //
+ // Note: a better approach here might rely on the layout information
+ // for the parameters, which should directly encode an offset for
+ // the existential specialization parameters it uses. The challenge
+ // with this is that we'd need to correctly interpret the offset
+ // relative to any global-scope specialization parameters or
+ // generic specialization parameters of the entry point.
+ // Ultimately the simplistic counter approach is less complicated.
+ //
+ Index slotOffset = 0;
for( auto param : func->getParams() )
{
- processParameter(param, bindEntryPointExistentialSlotsInst);
+ processEntryPointParameter(param, bindEntryPointExistentialSlotsInst, slotOffset);
}
// TODO: We would need to consider what to do if
@@ -181,9 +185,10 @@ struct BindExistentialSlots
// function `param` and a `[bindExistentialSlots(...)]`
// decoration; both use the same subroutine.
//
- void processParameter(
+ void processEntryPointParameter(
IRInst* param,
- IRInst* bindSlotsInst)
+ IRInst* bindSlotsInst,
+ Index& ioSlotOperandOffset)
{
// We expect all shader parameters to have layout information,
// but to be defensive we will skip any that don't.
@@ -206,7 +211,6 @@ struct BindExistentialSlots
// find out the stating slot, and the information on
// the type to find out the number of slots.
//
- UInt firstSlot = resInfo->index;
UInt slotCount = 0;
if(auto typeResInfo = varLayout->getTypeLayout()->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam))
slotCount = UInt(typeResInfo->count.getFiniteValue());
@@ -233,7 +237,8 @@ struct BindExistentialSlots
// this parameter.
//
UInt bindOperandCount = bindSlotsInst->getOperandCount();
- if( 2*(firstSlot + slotCount) > bindOperandCount )
+ UInt slotOperandCount = 2*slotCount;
+ if( (ioSlotOperandOffset + slotOperandCount) > bindOperandCount )
{
sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter);
return;
@@ -244,7 +249,7 @@ struct BindExistentialSlots
// keeping in mind that each slot accounts for two
// operands.
//
- auto operandsForInst = bindSlotsInst->getOperands() + firstSlot;
+ auto operandsForInst = bindSlotsInst->getOperands() + ioSlotOperandOffset;
// Once we've found the operands that are relevent to
// the slots used by `param`, we will defer to a routine
@@ -253,17 +258,17 @@ struct BindExistentialSlots
//
replaceTypeUsingExistentialSlots(
param,
- slotCount,
+ slotOperandCount,
operandsForInst);
+
+ ioSlotOperandOffset += slotOperandCount;
}
void replaceTypeUsingExistentialSlots(
IRInst* inst,
- UInt slotCount,
+ UInt slotOperandCount,
IRUse const* slotArgs)
{
- SLANG_UNUSED(slotCount);
-
// We are going to alter the type of the
// given `inst` based on information in
// the `slotArgs`.
@@ -282,7 +287,6 @@ struct BindExistentialSlots
// a witness table, so the total number of operands
// is twice the number of slots we are filling.
//
- UInt slotOperandCount = slotCount*2;
List<IRInst*> slotOperands;
for(UInt ii = 0; ii < slotOperandCount; ++ii)
slotOperands.add(slotArgs[ii].get());
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 0c218f2aa..f9c157465 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -206,7 +206,6 @@ INST(Specialize, specialize, 2, 0)
INST(lookup_interface_method, lookup_interface_method, 2, 0)
INST(lookup_witness_table, lookup_witness_table, 2, 0)
INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0)
-INST(BindGlobalExistentialSlots, bindGlobalExistentialSlots, 0, 0)
INST(Construct, construct, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 359c8e98d..cd59630d8 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1175,10 +1175,6 @@ struct IRBuilder
IRInst* param,
IRInst* val);
- IRInst* emitBindGlobalExistentialSlots(
- UInt argCount,
- IRInst* const* args);
-
IRDecoration* addBindExistentialSlotsDecoration(
IRInst* value,
UInt argCount,
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 7e2ac1f98..56b06a499 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -8,14 +8,13 @@
namespace Slang
{
-// Needed for lookup up entry-point layouts.
-//
-// TODO: maybe arrange so that codegen is driven from the layout layer
-// instead of the input/request layer.
+ /// Find a suitable layout for `entryPoint` in `programLayout`.
+ ///
+ /// TODO: This function should be eliminated. See its body
+ /// for an explanation of the problems.
EntryPointLayout* findEntryPointLayout(
ProgramLayout* programLayout,
- EntryPoint* entryPoint,
- EntryPointGroupLayout** outEntryPointGroupLayout);
+ EntryPoint* entryPoint);
struct IRSpecSymbol : RefObject
{
@@ -339,9 +338,10 @@ IRType* cloneType(
}
void cloneGlobalValueWithCodeCommon(
- IRSpecContextBase* context,
- IRGlobalValueWithCode* clonedValue,
- IRGlobalValueWithCode* originalValue);
+ IRSpecContextBase* context,
+ IRGlobalValueWithCode* clonedValue,
+ IRGlobalValueWithCode* originalValue,
+ IROriginalValuesForClone const& originalValues);
IRRate* cloneRate(
IRSpecContextBase* context,
@@ -382,11 +382,58 @@ IRGlobalVar* cloneGlobalVarImpl(
cloneGlobalValueWithCodeCommon(
context,
clonedVar,
- originalVar);
+ originalVar,
+ originalValues);
return clonedVar;
}
+ /// Clone certain special decorations for `clonedInst` from its (potentially multiple) definitions.
+ ///
+ /// In most cases, once we've decided on the "best" definition to use for an IR instruction,
+ /// we only want the linking process to use the decorations from the single best definition.
+ /// In some casses, though, the canonical best definition might not have all the information.
+ ///
+ /// A concrete example is the `[bindExistentialsSlots(...)]` decorations for global shader
+ /// parameters and entry points. These decorations are only generated as part of the IR
+ /// associated with a specialization of a program, and not the original IR for the modules
+ /// of the program.
+ ///
+ /// This function scans through all the `originalValues` that were considered for `clonedInst`,
+ /// and copies over any decorations that are allowed to come from a non-"best" definition.
+ /// For a given decoration opcode, only one such decoration will ever be copied, and nothing
+ /// will be copied if the instruction already has a matching decoration (that was cloned
+ /// from the "best" definition).
+ ///
+static void cloneExtraDecorations(
+ IRSpecContextBase* context,
+ IRInst* clonedInst,
+ IROriginalValuesForClone const& originalValues)
+{
+ IRBuilder builderStorage = *context->builder;
+ IRBuilder* builder = &builderStorage;
+ builder->setInsertInto(clonedInst);
+
+ for(auto sym = originalValues.sym; sym; sym = sym->nextWithSameName)
+ {
+ for(auto decoration : sym->irGlobalValue->getDecorations())
+ {
+ switch(decoration->op)
+ {
+ default:
+ break;
+
+ case kIROp_BindExistentialSlotsDecoration:
+ if(!clonedInst->findDecorationImpl(decoration->op))
+ {
+ cloneInst(context, builder, decoration);
+ }
+ break;
+ }
+ }
+ }
+}
+
void cloneSimpleGlobalValueImpl(
IRSpecContextBase* context,
IRInst* originalInst,
@@ -407,6 +454,12 @@ void cloneSimpleGlobalValueImpl(
{
cloneInst(context, builder, child);
}
+
+ // Also clone certain decorations if they appear on *any*
+ // definition of the symbol (not necessarily the one
+ // we picked as the primary/best).
+ //
+ cloneExtraDecorations(context, clonedInst, originalValues);
}
IRGlobalParam* cloneGlobalParamImpl(
@@ -469,7 +522,8 @@ IRGeneric* cloneGenericImpl(
cloneGlobalValueWithCodeCommon(
context,
clonedVal,
- originalVal);
+ originalVal,
+ originalValues);
return clonedVal;
}
@@ -545,7 +599,8 @@ IRInterfaceType* cloneInterfaceTypeImpl(
void cloneGlobalValueWithCodeCommon(
IRSpecContextBase* context,
IRGlobalValueWithCode* clonedValue,
- IRGlobalValueWithCode* originalValue)
+ IRGlobalValueWithCode* originalValue,
+ IROriginalValuesForClone const& originalValues)
{
// Next we are going to clone the actual code.
IRBuilder builderStorage = *context->builder;
@@ -553,6 +608,7 @@ void cloneGlobalValueWithCodeCommon(
builder->setInsertInto(clonedValue);
cloneDecorations(context, clonedValue, originalValue);
+ cloneExtraDecorations(context, clonedValue, originalValues);
// We will walk through the blocks of the function, and clone each of them.
//
@@ -629,10 +685,11 @@ void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const
}
void cloneFunctionCommon(
- IRSpecContextBase* context,
- IRFunc* clonedFunc,
- IRFunc* originalFunc,
- bool checkDuplicate = true)
+ IRSpecContextBase* context,
+ IRFunc* clonedFunc,
+ IRFunc* originalFunc,
+ IROriginalValuesForClone const& originalValues,
+ bool checkDuplicate = true)
{
// First clone all the simple properties.
clonedFunc->setFullType(cloneType(context, originalFunc->getFullType()));
@@ -640,7 +697,8 @@ void cloneFunctionCommon(
cloneGlobalValueWithCodeCommon(
context,
clonedFunc,
- originalFunc);
+ originalFunc,
+ originalValues);
// Shuffle the function to the end of the list, because
// it needs to follow its dependencies.
@@ -667,7 +725,6 @@ IRInst* specializeGeneric(
IRFunc* specializeIRForEntryPoint(
IRSpecContext* context,
- EntryPoint* entryPoint,
EntryPointLayout* entryPointLayout)
{
// We start by looking up the IR symbol that
@@ -679,7 +736,7 @@ IRFunc* specializeIRForEntryPoint(
// so that the mangled name of the decl-ref is
// not the same as the mangled name of the decl.
//
- auto mangledName = getMangledName(entryPoint->getFuncDeclRef());
+ auto mangledName = getMangledName(entryPointLayout->getFuncDeclRef());
RefPtr<IRSpecSymbol> sym;
if (!context->getSymbols().TryGetValue(mangledName, sym))
{
@@ -947,7 +1004,7 @@ IRFunc* cloneFuncImpl(
{
auto clonedFunc = builder->createFunc();
registerClonedValue(context, clonedFunc, originalValues);
- cloneFunctionCommon(context, clonedFunc, originalFunc);
+ cloneFunctionCommon(context, clonedFunc, originalFunc, originalValues);
return clonedFunc;
}
@@ -1227,7 +1284,9 @@ LinkedIR linkIR(
CodeGenTarget target,
TargetRequest* targetReq)
{
- auto sink = compileRequest->getSink();
+ // TODO: We need to make sure that the program we are being asked
+ // to compile has been "resolved" so that it has no outstanding
+ // unsatisfied requirements.
IRSpecializationState stateStorage;
auto state = &stateStorage;
@@ -1252,12 +1311,11 @@ LinkedIR linkIR(
// accelerate lookup, we will create a symbol table for looking
// up IR definitions by their mangled name.
//
- auto originalProgramIRModule = program->getOrCreateIRModule(sink);
- insertGlobalValueSymbols(sharedContext, originalProgramIRModule);
- for (auto module : program->getModuleDependencies())
+ program->enumerateIRModules([&](IRModule* irModule)
{
- insertGlobalValueSymbols(sharedContext, module->getIRModule());
- }
+ insertGlobalValueSymbols(sharedContext, irModule);
+ });
+
auto context = state->getContext();
context->shared = sharedContext;
@@ -1282,32 +1340,9 @@ LinkedIR linkIR(
context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout);
}
- EntryPointGroupLayout* entryPointGroupLayout = nullptr;
- auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint, &entryPointGroupLayout);
-
- auto offsetEntryPointLayout = entryPointLayout->getAbsoluteLayout(entryPointGroupLayout);
+ auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint);
- // Note: when we are doing the compatibility approach for Falcor, we
- // can have global-scope symbols that are actually part of the
- // local root signature (entry point group), so we need to make
- // sure to apply those layouts appropriately.
- auto entryPointGroupStructLayout = getScopeStructLayout(entryPointGroupLayout);
- for(auto entry : entryPointGroupStructLayout->mapVarToLayout)
- {
- if(!entry.Key)
- continue;
-
- auto mangledName = getMangledName(entry.Key);
- auto groupVarLayout = entry.Value;
-
- // We need to "adjust" the layout that was computed for the parameter
- // because it will be relative to the start of the entry-point group,
- // rather than absolute.
- //
- auto absoluteVarLayout = groupVarLayout->getAbsoluteLayout(entryPointGroupLayout->parametersLayout);
-
- context->globalVarLayouts.AddIfNotExists(mangledName, absoluteVarLayout);
- }
+ auto offsetEntryPointLayout = entryPointLayout;
context->builder->setInsertInto(context->getModule()->getModuleInst());
@@ -1316,7 +1351,7 @@ LinkedIR linkIR(
// TODO: This step should *not* be needed with the current IR
// specialization approach, so we should consider removing it.
//
- for (auto sym :context->getSymbols())
+ for (auto sym : context->getSymbols())
{
if (sym.Value->irGlobalValue->op == kIROp_WitnessTable)
cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue);
@@ -1327,28 +1362,31 @@ LinkedIR linkIR(
// the entry point function itself, and rely on
// this step to recursively copy over anything else
// it might reference.
- auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, offsetEntryPointLayout);
+ auto irEntryPoint = specializeIRForEntryPoint(context, offsetEntryPointLayout);
- // HACK: right now the bindings for global generic parameters are coming in
- // as part of the original IR module, and we need to make sure these get
- // copied over, even if they aren't referenced.
+ // Bindings for global generic parameters are currently represented
+ // as stand-alone global-scope instructions in the IR module for
+ // `SpecializedComponentType`s. These instructions are required for
+ // correct codegen, and so we must make sure to copy them all over,
+ // even though they are not directly referenced.
//
- for(auto inst : originalProgramIRModule->getGlobalInsts())
- {
- auto bindInst = as<IRBindGlobalGenericParam>(inst);
- if(!bindInst)
- continue;
-
- cloneValue(context, bindInst);
- }
-
- for(auto inst : originalProgramIRModule->getGlobalInsts())
+ // TODO: We should change these to decorations, akin to how
+ // `[bindExistentialSlots(...)]` works, so that they can be attached
+ // to the relevant parameters and cloned via `cloneExtraDecorations`.
+ // In the long run we do not want to *ever* iterate over all the
+ // instructions in all the input modules.
+ //
+ program->enumerateIRModules([&](IRModule* irModule)
{
- if(inst->op != kIROp_BindGlobalExistentialSlots)
- continue;
+ for(auto inst : irModule->getGlobalInsts())
+ {
+ auto bindInst = as<IRBindGlobalGenericParam>(inst);
+ if(!bindInst)
+ continue;
- cloneValue(context, inst);
- }
+ cloneValue(context, bindInst);
+ }
+ });
// HACK: we need to ensure that any tagged union types
// in the IR module have layout information copied over to them.
@@ -1357,7 +1395,7 @@ LinkedIR linkIR(
// instructions, since we expected the tagged union type(s) to
// be referenced by them.
//
- for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts )
+ for( auto taggedUnionTypeLayout : programLayout->taggedUnionTypeLayouts )
{
auto taggedUnionType = taggedUnionTypeLayout->getType();
auto mangledName = getMangledTypeName(taggedUnionType);
@@ -1377,6 +1415,15 @@ LinkedIR linkIR(
// we have global variables with initializers, since
// these should get run whether or not the entry point
// references them.
+ //
+ // Or alternatively we can define by fiat that the initializers
+ // on global variables get run at an unspecified time between
+ // program startup and the first access to a given global.
+ // Such a definition gives us the freedom to eliminate globals
+ // that are never accessed, while still doing "eager"
+ // initialization for globals that are referenced (instead of
+ // having to add the overhead of lazy initialization a la
+ // function-`static` variables).
// Now that we've cloned the entry point and everything
// it refers to, we can package up the data we return
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ed86e4d6d..2e8de5e37 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2880,22 +2880,6 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitBindGlobalExistentialSlots(
- UInt argCount,
- IRInst* const* args)
- {
- auto inst = createInstWithTrailingArgs<IRInst>(
- this,
- kIROp_BindGlobalExistentialSlots,
- getVoidType(),
- 0,
- nullptr,
- argCount,
- args);
- addInst(inst);
- return inst;
- }
-
IRDecoration* IRBuilder::addBindExistentialSlotsDecoration(
IRInst* value,
UInt argCount,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index eef69b1b5..e41be0e33 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1613,13 +1613,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto irBaseType = lowerType(context, type->baseType);
List<IRInst*> slotArgs;
- for(auto arg : type->slots.args)
+ for(auto arg : type->args)
{
- auto irArgType = lowerType(context, arg.type);
- auto irArgWitness = lowerSimpleVal(context, arg.witness);
+ auto irArgVal = lowerSimpleVal(context, arg.val);
+ slotArgs.add(irArgVal);
- slotArgs.add(irArgType);
- slotArgs.add(irArgWitness);
+ if(auto witness = arg.witness)
+ {
+ auto irArgWitness = lowerSimpleVal(context, witness);
+ slotArgs.add(irArgWitness);
+ }
}
auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer());
@@ -6265,13 +6268,18 @@ static void lowerFrontEndEntryPointToIR(
}
static void lowerProgramEntryPointToIR(
- IRGenContext* context,
- EntryPoint* entryPoint)
+ IRGenContext* context,
+ EntryPoint* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo)
{
+ auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef();
+ if(specializationInfo)
+ entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef;
+
// First, lower the entry point like an ordinary function
+
auto session = context->getSession();
- auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef();
auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef));
auto builder = context->irBuilder;
@@ -6280,7 +6288,6 @@ static void lowerProgramEntryPointToIR(
auto loweredEntryPointFunc = getSimpleVal(context,
emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType));
- //
if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>())
{
builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice());
@@ -6289,26 +6296,24 @@ static void lowerProgramEntryPointToIR(
// We may have shader parameters of interface/existential type,
// which need us to supply concrete type information for specialization.
//
- auto existentialTypeArgCount = entryPoint->getExistentialTypeArgCount();
- if( existentialTypeArgCount )
+ if(specializationInfo && specializationInfo->existentialSpecializationArgs.getCount() != 0)
{
List<IRInst*> existentialSlotArgs;
- for( Index ii = 0; ii < existentialTypeArgCount; ++ii )
+ for(auto arg : specializationInfo->existentialSpecializationArgs )
{
- auto arg = entryPoint->getExistentialTypeArg(ii);
-
- auto irArgType = lowerType(context, arg.type);
- auto irWitnessTable = lowerSimpleVal(context, arg.witness);
+ auto irArgType = lowerSimpleVal(context, arg.val);
existentialSlotArgs.add(irArgType);
- existentialSlotArgs.add(irWitnessTable);
+
+ if(auto witness = arg.witness)
+ {
+ auto irWitnessTable = lowerSimpleVal(context, witness);
+ existentialSlotArgs.add(irWitnessTable);
+ }
}
builder->addBindExistentialSlotsDecoration(loweredEntryPointFunc, existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer());
}
-
-
-
}
/// Ensure that `decl` and all relevant declarations under it get emitted.
@@ -6451,97 +6456,136 @@ IRModule* generateIRForTranslationUnit(
return module;
}
-RefPtr<IRModule> generateIRForProgram(
- Session* session,
- Program* program,
- DiagnosticSink* sink)
+ /// Context for generating IR code to represent a `SpecializedComponentType`
+struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor
{
-// auto compileRequest = translationUnit->compileRequest;
+ DiagnosticSink* sink;
+ Linkage* linkage;
+ Session* session;
+ IRGenContext* context;
+ IRBuilder* builder;
- SharedIRGenContext sharedContextStorage(
- session,
- sink);
- SharedIRGenContext* sharedContext = &sharedContextStorage;
+ RefPtr<IRModule> process(
+ SpecializedComponentType* componentType,
+ DiagnosticSink* inSink)
+ {
+ sink = inSink;
- IRGenContext contextStorage(sharedContext);
- IRGenContext* context = &contextStorage;
+ linkage = componentType->getLinkage();
+ session = linkage->getSessionImpl();
- SharedIRBuilder sharedBuilderStorage;
- SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
- sharedBuilder->module = nullptr;
- sharedBuilder->session = session;
+ SharedIRGenContext sharedContextStorage(
+ session,
+ sink);
+ SharedIRGenContext* sharedContext = &sharedContextStorage;
- IRBuilder builderStorage;
- IRBuilder* builder = &builderStorage;
- builder->sharedBuilder = sharedBuilder;
+ IRGenContext contextStorage(sharedContext);
+ context = &contextStorage;
- RefPtr<IRModule> module = builder->createModule();
- sharedBuilder->module = module;
+ SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = session;
- context->irBuilder = builder;
+ IRBuilder builderStorage;
+ builder = &builderStorage;
+ builder->sharedBuilder = sharedBuilder;
- // We need to emit symbols for all of the entry
- // points in the program; this is especially
- // important in the case where a generic entry
- // point is being specialized.
- //
- for(auto entryPoint : program->getEntryPoints())
- {
- lowerProgramEntryPointToIR(context, entryPoint);
- }
+ RefPtr<IRModule> module = builder->createModule();
+ sharedBuilder->module = module;
+ builder->setInsertInto(module->getModuleInst());
- // Now lower all the arguments supplied for global generic
- // type parameters.
- //
- for (RefPtr<Substitutions> subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer)
- {
- auto gSubst = subst.as<GlobalGenericParamSubstitution>();
- if(!gSubst)
- continue;
-
- IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl));
- IRType* typeVal = lowerType(context, gSubst->actualType);
+ context->irBuilder = builder;
- // bind `typeParam` to `typeVal`
- builder->emitBindGlobalGenericParam(typeParam, typeVal);
+ componentType->acceptVisitor(this, nullptr);
- for (auto& constraintArg : gSubst->constraintArgs)
- {
- IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl));
- IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val);
+ return module;
+ }
- // bind `constraintParam` to `constraintVal`
- builder->emitBindGlobalGenericParam(constraintParam, constraintVal);
- }
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // We need to emit symbols for all of the entry
+ // points in the program; this is especially
+ // important in the case where a generic entry
+ // point is being specialized.
+ //
+ lowerProgramEntryPointToIR(context, entryPoint, specializationInfo);
}
- // We may have shader parameters of interface/existential type,
- // which need us to supply concrete type information for specialization.
- //
- auto existentialTypeArgCount = program->getExistentialTypeArgCount();
- if( existentialTypeArgCount )
+ void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
{
- List<IRInst*> existentialSlotArgs;
- for( Index ii = 0; ii < existentialTypeArgCount; ++ii )
+ // We've hit a leaf module, so we should be able to bind any global
+ // generic type parameters here...
+ //
+ if( specializationInfo )
{
- auto arg = program->getExistentialTypeArg(ii);
+ for( auto genericArgInfo : specializationInfo->genericArgs )
+ {
+ IRInst* irParam = getSimpleVal(context, ensureDecl(context, genericArgInfo.paramDecl));
+ IRInst* irVal = lowerSimpleVal(context, genericArgInfo.argVal);
- auto irArgType = lowerType(context, arg.type);
- auto irWitnessTable = lowerSimpleVal(context, arg.witness);
+ // bind `irParam` to `irVal`
+ builder->emitBindGlobalGenericParam(irParam, irVal);
+ }
- existentialSlotArgs.add(irArgType);
- existentialSlotArgs.add(irWitnessTable);
+ auto shaderParamCount = module->getShaderParamCount();
+ Index existentialArgOffset = 0;
+
+ for( Index ii = 0; ii < shaderParamCount; ++ii )
+ {
+ auto shaderParam = module->getShaderParam(ii);
+ auto specializationArgCount = shaderParam.specializationParamCount;
+
+ IRInst* irParam = getSimpleVal(context, ensureDecl(context, shaderParam.paramDeclRef));
+ List<IRInst*> irSlotArgs;
+ for( Index jj = 0; jj < specializationArgCount; ++jj )
+ {
+ auto& specializationArg = specializationInfo->existentialArgs[existentialArgOffset++];
+
+ auto irType = lowerSimpleVal(context, specializationArg.val);
+ auto irWitness = lowerSimpleVal(context, specializationArg.witness);
+
+ irSlotArgs.add(irType);
+ irSlotArgs.add(irWitness);
+ }
+
+ builder->addBindExistentialSlotsDecoration(
+ irParam,
+ irSlotArgs.getCount(),
+ irSlotArgs.getBuffer());
+ }
}
+ }
- builder->emitBindGlobalExistentialSlots(existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer());
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
}
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ visitChildren(specialized);
+ }
- // TODO: Should we apply any of the validation or
- // mandatory optimization passes here?
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // TODO: This case should be akin to the `Module` case,
+ // and deal with global-scope specialization parameters
+ // directly.
+ //
+ SLANG_UNUSED(legacy);
+ SLANG_UNUSED(specializationInfo);
+ SLANG_UNIMPLEMENTED_X("legacy program case");
+ }
+};
- return module;
+RefPtr<IRModule> generateIRForSpecializedComponentType(
+ SpecializedComponentType* componentType,
+ DiagnosticSink* sink)
+{
+ SpecializedComponentTypeIRGenContext context;
+ return context.process(componentType, sink);
}
} // namespace Slang
diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h
index 060efb88b..33dfd9d27 100644
--- a/source/slang/slang-lower-to-ir.h
+++ b/source/slang/slang-lower-to-ir.h
@@ -15,14 +15,32 @@ namespace Slang
{
class EntryPoint;
class ProgramLayout;
+ class SpecializedComponentType;
class TranslationUnitRequest;
+ /// Generate an IR module to represent the code in the given `translationUnit`.
+ ///
+ /// The generated module will include IR definitions for any functions/types
+ /// in `translationUnit`, but it is *not* guaranteed to contain any definitions
+ /// from modules that are `import`ed into `translationUnit`. The resulting IR
+ /// module must be linked against other IR modules that define any symbols
+ /// that are imported before code generation can be performed.
+ ///
IRModule* generateIRForTranslationUnit(
TranslationUnitRequest* translationUnit);
- RefPtr<IRModule> generateIRForProgram(
- Session* session,
- Program* program,
- DiagnosticSink* sink);
+ /// Generate an IR module to represent the specializations applied by `componentType`.
+ ///
+ /// The generated IR will encode how `componentType` specializes global or
+ /// entry-point specialization parameters to concrete arguments (e.g., types).
+ ///
+ /// The generated IR module is *not* guaranteed to contain anything more, such
+ /// as the actual definitions of functions or types being specialized. The
+ /// resulting IR module must be linked against other IR modules that define
+ /// those symbols before code generation can be performed.
+ ///
+ RefPtr<IRModule> generateIRForSpecializedComponentType(
+ SpecializedComponentType* componentType,
+ DiagnosticSink* sink);
}
#endif
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index dc99f55e2..722725af7 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -674,16 +674,22 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
EntryPointParameterState const& state,
RefPtr<VarLayout> varLayout);
-// Collect a single declaration into our set of parameters
-static void collectGlobalGenericParameter(
- ParameterBindingContext* context,
- RefPtr<GlobalGenericParamDecl> paramDecl)
+static RefPtr<VarLayout> _createVarLayout(
+ TypeLayout* typeLayout,
+ DeclRef<VarDeclBase> varDeclRef)
{
- RefPtr<GenericParamLayout> layout = new GenericParamLayout();
- layout->decl = paramDecl;
- layout->index = (int)context->shared->programLayout->globalGenericParams.getCount();
- context->shared->programLayout->globalGenericParams.add(layout);
- context->shared->programLayout->globalGenericParamsMap[layout->decl->getName()->text] = layout.Ptr();
+ RefPtr<VarLayout> varLayout = new VarLayout();
+ varLayout->typeLayout = typeLayout;
+ varLayout->varDecl = varDeclRef;
+
+ if(auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout)
+ {
+ RefPtr<VarLayout> pendingVarLayout = new VarLayout();
+ pendingVarLayout->typeLayout = pendingDataTypeLayout;
+ varLayout->pendingVarLayout = pendingVarLayout;
+ }
+
+ return varLayout;
}
// Collect a single declaration into our set of parameters
@@ -712,9 +718,7 @@ static void collectGlobalScopeParameter(
return;
// Now create a variable layout that we can use
- RefPtr<VarLayout> varLayout = new VarLayout();
- varLayout->typeLayout = typeLayout;
- varLayout->varDecl = varDeclRef;
+ RefPtr<VarLayout> varLayout = _createVarLayout(typeLayout, varDeclRef);
// The logic in `check.cpp` that created the `GlobalShaderParamInfo`
// will have identified any cases where there might be multiple
@@ -748,6 +752,7 @@ static void collectGlobalScopeParameter(
RefPtr<VarLayout> additionalVarLayout = new VarLayout();
additionalVarLayout->typeLayout = typeLayout;
additionalVarLayout->varDecl = additionalVarDeclRef;
+ additionalVarLayout->pendingVarLayout = varLayout->pendingVarLayout;
parameterInfo->varLayouts.add(additionalVarLayout);
}
@@ -1770,15 +1775,40 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
return structLayout;
}
- else if (auto globalGenericParam = declRef.as<GlobalGenericParamDecl>())
+ else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>())
{
- auto genParamTypeLayout = new GenericParamTypeLayout();
- // we should have already populated ProgramLayout::genericEntryPointParams list at this point,
- // so we can find the index of this generic param decl in the list
- genParamTypeLayout->type = type;
- genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl());
- genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1;
- return genParamTypeLayout;
+ auto& layoutContext = context->layoutContext;
+
+ if( auto concreteType = findGlobalGenericSpecializationArg(
+ layoutContext,
+ globalGenericParamDecl) )
+ {
+ // If we know what concrete type has been used to specialize
+ // the global generic type parameter, then we should use
+ // the concrete type instead.
+ //
+ // Note: it should be illegal for the user to use a generic
+ // type parameter in a varying parameter list without giving
+ // it an explicit user-defined semantic. Otherwise, it would be possible
+ // that the concrete type that gets plugged in is a user-defined
+ // `struct` that uses some `SV_` semantics in its definition,
+ // so that any static information about what system values
+ // the entry point uses would be incorrect.
+ //
+ return processEntryPointVaryingParameter(context, concreteType, state, varLayout);
+ }
+ else
+ {
+ // If we don't know a concrete type, then we aren't generating final
+ // code, so the reflection information should show the generic
+ // type parameter.
+ //
+ // We don't make any attempt to assign varying parameter resources
+ // to the generic type, since we can't know how many "slots"
+ // of varying input/output it would consume.
+ //
+ return createTypeLayoutForGlobalGenericTypeParam(layoutContext, type, globalGenericParamDecl);
+ }
}
else if (auto associatedTypeParam = declRef.as<AssocTypeDecl>())
{
@@ -1804,15 +1834,12 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
/// Compute the type layout for a parameter declared directly on an entry point.
static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout(
ParameterBindingContext* context,
- SubstitutionSet typeSubst,
DeclRef<VarDeclBase> paramDeclRef,
RefPtr<VarLayout> paramVarLayout,
EntryPointParameterState& state)
{
- auto paramDeclRefType = GetType(paramDeclRef);
- SLANG_ASSERT(paramDeclRefType);
-
- auto paramType = paramDeclRefType->Substitute(typeSubst).as<Type>();
+ auto paramType = GetType(paramDeclRef);
+ SLANG_ASSERT(paramType);
if( paramDeclRef.getDecl()->HasModifier<HLSLUniformModifier>() )
{
@@ -1940,6 +1967,12 @@ struct ScopeLayoutBuilder
{
m_structLayout->mapVarToLayout.Add(firstVarLayout->varDecl.getDecl(), firstVarLayout);
}
+ }
+
+ void addParameter(
+ RefPtr<VarLayout> varLayout)
+ {
+ _addParameter(varLayout, nullptr);
// Any "pending" items on a field type become "pending" items
// on the overall `struct` type layout.
@@ -1948,42 +1981,58 @@ struct ScopeLayoutBuilder
// `struct` layout logic in `type-layout.cpp`. If this gets any
// more complicated we should see if there is a way to share it.
//
- if( auto fieldPendingDataTypeLayout = firstVarLayout->typeLayout->pendingDataTypeLayout )
+ if( auto fieldPendingDataTypeLayout = varLayout->typeLayout->pendingDataTypeLayout )
{
m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules);
- auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(firstVarLayout->varDecl, fieldPendingDataTypeLayout);
+ auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(varLayout->varDecl, fieldPendingDataTypeLayout);
m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout();
- if( parameterInfo )
- {
- for( auto& varLayout : parameterInfo->varLayouts )
- {
- varLayout->pendingVarLayout = fieldPendingDataVarLayout;
- }
- }
- else
- {
- firstVarLayout->pendingVarLayout = fieldPendingDataVarLayout;
- }
+ varLayout->pendingVarLayout = fieldPendingDataVarLayout;
}
}
void addParameter(
- RefPtr<VarLayout> varLayout)
- {
- _addParameter(varLayout, nullptr);
- }
-
- void addParameter(
ParameterInfo* parameterInfo)
{
SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0);
auto firstVarLayout = parameterInfo->varLayouts.getFirst();
_addParameter(firstVarLayout, parameterInfo);
- }
+ // Global parameters will have their non-orindary/uniform
+ // pending data handled by the main parameter binding
+ // logic, but we still need to construct a layout
+ // that includes any pending data.
+ //
+ if(auto fieldPendingVarLayout = firstVarLayout->pendingVarLayout)
+ {
+ auto fieldPendingTypeLayout = fieldPendingVarLayout->typeLayout;
+
+ m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules);
+ m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout();
+
+ auto fieldUniformLayoutInfo = fieldPendingTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ LayoutSize fieldUniformSize = fieldUniformLayoutInfo ? fieldUniformLayoutInfo->count : 0;
+ if( fieldUniformSize != 0 )
+ {
+ // Make sure uniform fields get laid out properly...
+
+ UniformLayoutInfo fieldInfo(
+ fieldUniformSize,
+ fieldPendingTypeLayout->uniformAlignment);
+
+ LayoutSize uniformOffset = m_rules->AddStructField(
+ m_pendingDataTypeLayoutBuilder.getStructLayoutInfo(),
+ fieldInfo);
+
+ fieldPendingVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue();
+ }
+
+ m_pendingDataTypeLayoutBuilder.getTypeLayout()->fields.add(fieldPendingVarLayout);
+ }
+
+ }
// Add a "simple" parameter that cannot have any user-defined
// register or binding modifiers, so that its layout computation
@@ -2088,21 +2137,48 @@ static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding(
return info;
}
+ /// Remove resource usage from `typeLayout` that should only be stored per-entry-point.
+ ///
+ /// This is used when constructing the overall layout for an entry point, to make sure
+ /// that certain kinds of resource usage from the entry point don't "leak" into
+ /// the resource usage of the overall program.
+ ///
+static void removePerEntryPointParameterKinds(
+ TypeLayout* typeLayout)
+{
+ typeLayout->removeResourceUsage(LayoutResourceKind::VaryingInput);
+ typeLayout->removeResourceUsage(LayoutResourceKind::VaryingOutput);
+ typeLayout->removeResourceUsage(LayoutResourceKind::ShaderRecord);
+ typeLayout->removeResourceUsage(LayoutResourceKind::HitAttributes);
+ typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialObjectParam);
+ typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialTypeParam);
+}
+
/// Iterate over the parameters of an entry point to compute its requirements.
///
static RefPtr<EntryPointLayout> collectEntryPointParameters(
- ParameterBindingContext* context,
- EntryPoint* entryPoint,
- SubstitutionSet typeSubst)
+ ParameterBindingContext* context,
+ EntryPoint* entryPoint,
+ EntryPoint::EntryPointSpecializationInfo* specializationInfo)
{
DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef();
+ // If specialization was applied to the entry point, then the side-band
+ // information that was generated will have a more specialized reference
+ // to the entry point with generic parameters filled in. We should
+ // use that version if it is available.
+ //
+ if(specializationInfo)
+ entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef;
+
+ auto entryPointType = DeclRefType::Create(context->getLinkage()->getSessionImpl(), entryPointFuncDeclRef);
+
// We will take responsibility for creating and filling in
// the `EntryPointLayout` object here.
//
RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout();
entryPointLayout->profile = entryPoint->getProfile();
- entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl();
+ entryPointLayout->entryPoint = entryPointFuncDeclRef;
// The entry point layout must be added to the output
// program layout so that it can be accessed by reflection.
@@ -2114,19 +2190,6 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
//
context->entryPointLayout = entryPointLayout;
- // Note: this isn't really the best place for this logic to sit,
- // but it is the simplest place where we have a direct correspondence
- // between a single `EntryPoint` and its matching `EntryPointLayout`,
- // so we'll use it.
- //
- for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() )
- {
- SLANG_ASSERT(taggedUnionType);
- auto substType = taggedUnionType->Substitute(typeSubst).as<Type>();
- auto typeLayout = createTypeLayout(context->layoutContext, substType);
- entryPointLayout->taggedUnionTypeLayouts.add(typeLayout);
- }
-
// We are going to iterate over the entry-point parameters,
// and while we do so we will go ahead and perform layout/binding
// assignment for two cases:
@@ -2149,22 +2212,33 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
ScopeLayoutBuilder scopeBuilder;
scopeBuilder.beginLayout(context);
auto paramsStructLayout = scopeBuilder.m_structLayout;
+ paramsStructLayout->type = entryPointType;
for( auto& shaderParamInfo : entryPoint->getShaderParams() )
{
auto paramDeclRef = shaderParamInfo.paramDeclRef;
+ // Any generic specialization applied to the entry-point function
+ // must also be applied to its parameters.
+ paramDeclRef.substitutions = entryPointFuncDeclRef.substitutions;
+
// When computing layout for an entry-point parameter,
// we want to make sure that the layout context has access
// to the existential type arguments (if any) that were
// provided for the entry-point existential type parameters (if any).
//
- context->layoutContext= context->layoutContext
- .withExistentialTypeArgs(
- entryPoint->getExistentialTypeArgCount(),
- entryPoint->getExistentialTypeArgs())
- .withExistentialTypeSlotsOffsetBy(
- shaderParamInfo.firstExistentialTypeSlot);
+ if(specializationInfo)
+ {
+ auto& existentialSpecializationArgs = specializationInfo->existentialSpecializationArgs;
+ auto genericSpecializationParamCount = entryPoint->getGenericSpecializationParamCount();
+
+ context->layoutContext = context->layoutContext
+ .withSpecializationArgs(
+ existentialSpecializationArgs.getBuffer(),
+ existentialSpecializationArgs.getCount())
+ .withSpecializationArgsOffsetBy(
+ shaderParamInfo.firstSpecializationParamIndex - genericSpecializationParamCount);
+ }
// Any error messages we emit during the process should
// refer to the location of this parameter.
@@ -2183,7 +2257,6 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
auto paramTypeLayout = computeEntryPointParameterTypeLayout(
context,
- typeSubst,
paramDeclRef,
paramVarLayout,
state);
@@ -2204,6 +2277,26 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
//
scopeBuilder.addSimpleParameter(paramVarLayout);
}
+
+ // We don't want certain kinds of resource usage within an entry
+ // point to "leak" into the overall resource usage of the entry
+ // point and thus lead to offsetting of successive entry points.
+ //
+ // For example if we have a vertex and a fragment entry point
+ // in the some program, and each has one varying input, then
+ // the both the vertex and fragment varying outputs should have
+ // a location/index of zero. It would be bad if the fragment
+ // input (or whichever entry point comes second in the global
+ // ordering) started at location one, because then it wouldn't
+ // line up correctly with any vertex stage outputs.
+ //
+ // We handle this with a bit of a kludge, by removing the
+ // particular `LayoutResourceKind`s that are susceptible to
+ // this problem from the overall resource usage of the entry
+ // point.
+ //
+ removePerEntryPointParameterKinds(scopeBuilder.m_structLayout);
+
entryPointLayout->parametersLayout = scopeBuilder.endLayout();
// For an entry point with a non-`void` return type, we need to process the
@@ -2212,7 +2305,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
// TODO: Ideally we should make the layout process more robust to empty/void
// types and apply this logic unconditionally.
//
- auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as<Type>();
+ auto resultType = GetResultType(entryPointFuncDeclRef);
SLANG_ASSERT(resultType);
if( !resultType->Equals(resultType->getSession()->getVoidType()) )
@@ -2226,7 +2319,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
auto resultTypeLayout = processEntryPointVaryingParameterDecl(
context,
entryPointFuncDeclRef.getDecl(),
- resultType->Substitute(typeSubst).as<Type>(),
+ resultType,
state,
resultLayout);
@@ -2248,136 +2341,257 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
return entryPointLayout;
}
- /// Remove resource usage from `typeLayout` that should only be stored per-entry-point.
- ///
- /// This is used when constructing the layout for an entry point group, to make sure
- /// that certain kinds of resource usage from the entry point don't "leak" into
- /// the resource usage of the group.
- ///
-static void removePerEntryPointParameterKinds(
- TypeLayout* typeLayout)
+ /// Visitor used by `collectGlobalGenericArguments`
+struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor
{
- typeLayout->removeResourceUsage(LayoutResourceKind::VaryingInput);
- typeLayout->removeResourceUsage(LayoutResourceKind::VaryingOutput);
- typeLayout->removeResourceUsage(LayoutResourceKind::ShaderRecord);
- typeLayout->removeResourceUsage(LayoutResourceKind::HitAttributes);
- typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialObjectParam);
- typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialTypeParam);
-}
+ CollectGlobalGenericArgumentsVisitor(
+ ParameterBindingContext* context)
+ : m_context(context)
+ {}
-static void collectParameters(
- ParameterBindingContext* inContext,
- Program* program)
-{
- // All of the parameters in translation units directly
- // referenced in the compile request are part of one
- // logical namespace/"linkage" so that two parameters
- // with the same name should represent the same
- // parameter, and get the same binding(s)
+ ParameterBindingContext* m_context;
- ParameterBindingContext contextData = *inContext;
- auto context = &contextData;
- context->stage = Stage::Unknown;
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(entryPoint);
+ SLANG_UNUSED(specializationInfo);
+ }
- auto globalGenericSubst = program->getGlobalGenericSubstitution();
+ void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(module);
- // We will start by looking for any global generic type parameters.
+ if(!specializationInfo)
+ return;
- for(RefPtr<Module> module : program->getModuleDependencies())
- {
- for( auto genParamDecl : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>() )
+ for(auto& globalGenericArg : specializationInfo->genericArgs)
{
- collectGlobalGenericParameter(context, genParamDecl);
+ if(auto globalGenericTypeParamDecl = as<GlobalGenericParamDecl>(globalGenericArg.paramDecl))
+ {
+ m_context->shared->programLayout->globalGenericArgs.Add(globalGenericTypeParamDecl, globalGenericArg.argVal);
+ }
}
}
- // Once we have enumerated global generic type parameters, we can
- // begin enumerating shader parameters, starting at the global scope.
- //
- // Because we have already enumerated the global generic type parameters,
- // we will be able to look up the index of a global generic type parameter
- // when we see it referenced in the type of one of the shader parameters.
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
- for(auto& globalParamInfo : program->getShaderParams() )
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
{
- // When computing layout for a global shader parameter,
- // we want to make sure that the layout context has access
- // to the existential type arguments (if any) that were
- // provided for the global existential type parameters (if any).
- //
- context->layoutContext= context->layoutContext
- .withExistentialTypeArgs(
- program->getExistentialTypeArgCount(),
- program->getExistentialTypeArgs())
- .withExistentialTypeSlotsOffsetBy(
- globalParamInfo.firstExistentialTypeSlot);
+ specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo());
+ }
- collectGlobalScopeParameter(context, globalParamInfo, globalGenericSubst);
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // TODO: Need to do something in this case...
+ SLANG_UNUSED(legacy);
+ SLANG_UNUSED(specializationInfo);
}
+};
+
+ /// Collect an ordered list of all the specialization arguments given for global generic specialization parameters in `program`.
+ ///
+ /// This information is used to accelerate the process of mapping a global generic type
+ /// to its definition during type layout.
+ ///
+static void collectGlobalGenericArguments(
+ ParameterBindingContext* context,
+ ComponentType* program)
+{
+ CollectGlobalGenericArgumentsVisitor visitor(context);
+ program->acceptVisitor(&visitor, nullptr);
+}
- // Next consider parameters for entry points
- for( auto entryPointGroup : program->getEntryPointGroups() )
+ /// Collect information about the (unspecialized) specialization parameters of `program` into `context`.
+ ///
+ /// This function computes the reflection/layout for for the specialization parameters, so
+ /// that they can be exposed to the API user.
+ ///
+static void collectSpecializationParams(
+ ParameterBindingContext* context,
+ ComponentType* program)
+{
+ auto specializationParamCount = program->getSpecializationParamCount();
+ for(Index ii = 0; ii < specializationParamCount; ++ii)
{
- RefPtr<EntryPointGroupLayout> entryPointGroupLayout = new EntryPointGroupLayout();
- entryPointGroupLayout->group = entryPointGroup;
+ auto specializationParam = program->getSpecializationParam(ii);
+ switch(specializationParam.flavor)
+ {
+ case SpecializationParam::Flavor::GenericType:
+ case SpecializationParam::Flavor::GenericValue:
+ {
+ RefPtr<GenericSpecializationParamLayout> paramLayout = new GenericSpecializationParamLayout();
+ paramLayout->decl = specializationParam.object.as<Decl>();
+ context->shared->programLayout->specializationParams.add(paramLayout);
+ }
+ break;
- context->shared->programLayout->entryPointGroups.add(entryPointGroupLayout);
+ case SpecializationParam::Flavor::ExistentialType:
+ case SpecializationParam::Flavor::ExistentialValue:
+ {
+ RefPtr<ExistentialSpecializationParamLayout> paramLayout = new ExistentialSpecializationParamLayout();
+ paramLayout->type = specializationParam.object.as<Type>();
+ context->shared->programLayout->specializationParams.add(paramLayout);
+ }
+ break;
+ default:
+ SLANG_UNEXPECTED("unhandled specialization parameter flavor");
+ break;
+ }
+ }
+}
+
+ /// Visitor used by `collectParameters()`
+struct CollectParametersVisitor : ComponentTypeVisitor
+{
+ CollectParametersVisitor(
+ ParameterBindingContext* context)
+ : m_context(context)
+ {}
- ScopeLayoutBuilder scopeBuilder;
- scopeBuilder.beginLayout(context);
- auto entryPointGroupParamsStructLayout = scopeBuilder.m_structLayout;
+ ParameterBindingContext* m_context;
- // First lay out any shader parameters that belong to the group
- // itself, rather than to its nested entry points.
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // The parameters of a composite component type can
+ // be determined by just visiting its children in order.
//
- // This ensures that looking up one of the parameters of the
- // group by index in its parameters truct will Just Work.
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ // The parameters of a specialized component type
+ // are just those of its base component type, with
+ // appropriate specialization information passed
+ // along.
//
- for( auto groupParam : entryPointGroup->getShaderParams() )
+ visitChildren(specialized);
+
+ // While we are at it, we will also make note of any
+ // tagged-union types that were used as part of the
+ // specialization arguments, since we need to make
+ // sure that their layout information is computed
+ // and made available for IR code generation.
+ //
+ // Note: this isn't really the best place for this logic to sit,
+ // but it is the simplest place where we can collect all the tagged
+ // union types that get referenced by a program.
+ //
+ for( auto taggedUnionType : specialized->getTaggedUnionTypes() )
{
- auto paramDeclRef = groupParam.paramDeclRef;
- auto paramType = GetType(paramDeclRef);
+ SLANG_ASSERT(taggedUnionType);
+ auto substType = taggedUnionType;
+ auto typeLayout = createTypeLayout(m_context->layoutContext, substType);
+ m_context->shared->programLayout->taggedUnionTypeLayouts.add(typeLayout);
+ }
+ }
- RefPtr<VarLayout> paramVarLayout = new VarLayout();
- paramVarLayout->varDecl = paramDeclRef;
- auto paramTypeLayout = createTypeLayout(
- context->layoutContext.with(context->getRulesFamily()->getConstantBufferRules()),
- paramType);
- paramVarLayout->typeLayout = paramTypeLayout;
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // An entry point is a leaf case.
+ //
+ // In our current model an entry point does not introduce
+ // any global shader parameters, but in practice it effectively
+ // acts a lot like a single global shader parameter named after
+ // the entry point and with a `struct` type that combines
+ // all the `uniform` entry point parameters.
+ //
+ // Later passes will need to make sure that the entry point
+ // gets enumerated in the right order relative to any global
+ // shader parameters.
+ //
- scopeBuilder.addSimpleParameter(paramVarLayout);
- }
+ ParameterBindingContext contextData = *m_context;
+ auto context = &contextData;
+ context->stage = entryPoint->getStage();
- for(auto entryPoint : entryPointGroup->getEntryPoints())
- {
- // Note: we do not want the entry point group to accumulate
- // locations for varying input/output parameters: those
- // should be specific to each entry point.
- //
- // We address this issue by manually removing any
- // layout information for the relevant resource kinds
- // from the group's layout before adding the parameters
- // of any entry point for layout.
- //
- removePerEntryPointParameterKinds(scopeBuilder.m_structLayout);
+ collectEntryPointParameters(context, entryPoint, specializationInfo);
+ }
- context->stage = entryPoint->getStage();
- auto entryPointLayout = collectEntryPointParameters(context, entryPoint, globalGenericSubst);
+ void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // A single module represents a leaf case for layout.
+ //
+ // We will enumerate the (global) shader parameters declared
+ // in the module and add each to our canonical ordering.
+ //
+ auto paramCount = module->getShaderParamCount();
- auto entryPointParamsLayout = entryPointLayout->parametersLayout;
- auto entryPointParamsTypeLayout = entryPointParamsLayout->typeLayout;
+ ExpandedSpecializationArg* specializationArgs = specializationInfo
+ ? specializationInfo->existentialArgs.getBuffer()
+ : nullptr;
- scopeBuilder.addSimpleParameter(entryPointParamsLayout);
+ for(Index pp = 0; pp < paramCount; ++pp)
+ {
+ auto shaderParamInfo = module->getShaderParam(pp);
+ if(specializationArgs)
+ {
+ m_context->layoutContext = m_context->layoutContext.withSpecializationArgs(
+ specializationArgs,
+ shaderParamInfo.specializationParamCount);
+ specializationArgs += shaderParamInfo.specializationParamCount;
+ }
- entryPointGroupLayout->entryPoints.add(entryPointLayout);
+ collectGlobalScopeParameter(m_context, shaderParamInfo, SubstitutionSet());
}
- removePerEntryPointParameterKinds(scopeBuilder.m_structLayout);
+ }
+
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // A legacy program is also a leaf case, and we
+ // can enumerate its parameters directly.
+ //
+ // Note: there is a mismatch here where we really
+ // ought to be tracking specialization arguments
+ // for a `LegacyProgram` akin to how they are
+ // tracked for a `Module`, but right now we try
+ // to do it like a `CompositeComponentType`.
+ // As a result we are just ignoring specialization
+ // information here, which will lead to incorrect
+ // results if somebody every uses specialization
+ // together with the "legacy" program case.
+ //
+ // TODO: eliminate this problem by getting rid of
+ // `LegacyProgram`, rather than spend time trying
+ // to make this corner case actually work.
+ //
+ SLANG_UNUSED(specializationInfo);
- entryPointGroupLayout->parametersLayout = scopeBuilder.endLayout();
+ auto paramCount = legacy->getShaderParamCount();
+ for(Index pp = 0; pp < paramCount; ++pp)
+ {
+ collectGlobalScopeParameter(m_context, legacy->getShaderParam(pp), SubstitutionSet());
+ }
}
- context->entryPointLayout = nullptr;
+};
+
+ /// Recursively collect the global shader parameters and entry points in `program`.
+ ///
+ /// This function is used to establish the global ordering of parameters and
+ /// entry points used for layout.
+ ///
+static void collectParameters(
+ ParameterBindingContext* inContext,
+ ComponentType* program)
+{
+ // All of the parameters in translation units directly
+ // referenced in the compile request are part of one
+ // logical namespace/"linkage" so that two parameters
+ // with the same name should represent the same
+ // parameter, and get the same binding(s)
+
+ ParameterBindingContext contextData = *inContext;
+ auto context = &contextData;
+ context->stage = Stage::Unknown;
+
+ CollectParametersVisitor visitor(context);
+ program->acceptVisitor(&visitor, nullptr);
}
/// Emit a diagnostic about a uniform parameter at global scope.
@@ -2418,6 +2632,302 @@ static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingConte
return numUsed;
}
+ /// Keep track of the running global counter for entry points and global parameters visited.
+ ///
+ /// Because of explicit `register` and `[[vk::binding(...)]]` support, parameter binding
+ /// needs to proceed in multiple passes, and each pass must both visit the things that
+ /// need layout (parameters and entry points) in the same order in each pass, and must
+ /// also be able to look up the side-band information that flows between passes.
+ ///
+ /// Currently the `ParameterBindingContext` keeps separate arrays for global shader
+ /// parameters and entry points, but in the global ordering for layout they can be
+ /// interleaved. There is also no simple tracking structure that relates a global
+ /// parameter or entry point to its index in those arrays. Instead, we just keep
+ /// running counters during our passes over the program so that we can easily
+ /// compute the linear index of each entry point and global parameter as it
+ /// is encountered.
+ ///
+struct ParameterBindingVisitorCounters
+{
+ Index entryPointCounter = 0;
+ Index globalParamCounter = 0;
+};
+
+ /// Recursive routine to "complete" all binding for parameters and entry points in `componentType`.
+ ///
+ /// This includes allocation of as-yet-unused register/binding ranges to parameters (which
+ /// will then affect the ranges of registers/bindings that are available to subsequent
+ /// parameters), and imporantly *also* includes allocate of space to any "pending"
+ /// data for interface/existential type parameters/fields.
+ ///
+static void _completeBindings(
+ ParameterBindingContext* context,
+ ComponentType* componentType,
+ ParameterBindingVisitorCounters* ioCounters);
+
+ /// A visitor used by `_completeBindings`.
+ ///
+ /// This visitor walks the structure of a `ComponentType` to ensure that
+ /// any shader parameters (and entry points) it contains that *don't*
+ /// have explicit bindings on them get allocated registers/bindings
+ /// as appropriate.
+ ///
+ /// The main complication of this visitor is how it handles the
+ /// `SpecializedComponentType` case, because a specialized component
+ /// type needs to be handled as an atomic unit that lays out the
+ /// same in all contexts.
+ ///
+struct CompleteBindingsVisitor : ComponentTypeVisitor
+{
+ CompleteBindingsVisitor(ParameterBindingContext* context, ParameterBindingVisitorCounters* counters)
+ : m_context(context)
+ , m_counters(counters)
+ {}
+
+ ParameterBindingContext* m_context;
+ ParameterBindingVisitorCounters* m_counters;
+
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(entryPoint);
+ SLANG_UNUSED(specializationInfo);
+
+ // We compute the index of the entry point in the global ordering,
+ // so we can look up the tracking data in our context. As a result
+ // we don't actually make use of the parameters that were passed in.
+ //
+ auto globalEntryPointIndex = m_counters->entryPointCounter++;
+ auto globalEntryPointInfo = m_context->shared->programLayout->entryPoints[globalEntryPointIndex];
+
+
+ // We mostly treat an entry point like a single shader parameter that
+ // uses its `parametersLayout`.
+ //
+ completeBindingsForParameter(m_context, globalEntryPointInfo->parametersLayout);
+ }
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(specializationInfo);
+ // A module is a leaf case: we just want to visit each parameter.
+ visitLeafParams(module);
+ }
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(specializationInfo);
+ // A legacy program is a leaf case: we just want to visit each parameter.
+ visitLeafParams(legacy);
+ }
+
+ void visitLeafParams(ComponentType* componentType)
+ {
+ auto paramCount = componentType->getShaderParamCount();
+ for(Index ii = 0; ii < paramCount; ++ii)
+ {
+ auto globalParamIndex = m_counters->globalParamCounter++;
+ auto globalParamInfo = m_context->shared->parameters[globalParamIndex];
+
+ completeBindingsForParameter(m_context, globalParamInfo);
+ }
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ // We just wnat to recurse on the children of the composite in order.
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ // The handling of a specialized component type here is subtle.
+ //
+ // We do *not* simply recurse on the base component type.
+ // Doing so would ensure that the parameters would get
+ // registers/bindings allocated to them, but it wouldn't
+ // allocate space for the "pending" data related to
+ // existential/interface parameters.
+ //
+ // Instead, we recursive through `_completeBindings`,
+ // which has the job of allocating space for the parameters,
+ // and then for any "pending" data required.
+ //
+ // Handling things this way ensures that a particular
+ // `SpecializedComponentType` gets laid out exactly
+ // the same wherever it gets used, rather than
+ // getting laid out differently when it is placed
+ // into different compositions.
+ //
+ auto base = specialized->getBaseComponentType();
+ _completeBindings(m_context, base, m_counters);
+ }
+};
+
+ /// A visitor used by `_completeBindings`.
+ ///
+ /// This visitor is used to follow up after the `CompleteBindingsVisitor`
+ /// any ensure that any "pending" data required by the parameters that
+ /// got laid out now gets a location.
+ ///
+ /// To make a concrete example:
+ ///
+ /// Texture2D a;
+ /// IThing b;
+ /// Texture2D c;
+ ///
+ /// If these parameters were laid out with `b` specialized to a type
+ /// that contains a single `Texture2D`, then the `CompleteBindingsVisitor`
+ /// would visit `a`, `b`, and then `c` in order. It would give `a` the
+ /// first register/binding available (say, `t0`). It would then make
+ /// a note that due to specialization, `b`, needs a `t` register as well,
+ /// but it *cannot* be allocated just yet, because doing so would change
+ /// the location of `c`, so it is marked as "pending." Then `c` would
+ /// be visited and get `t1`. As a result the registers given to `a`
+ /// and `c` are independent of how `b` gets specialized.
+ ///
+ /// Next, the `FlushPendingDataVisitor` comes through and applies to
+ /// the parameters again. For `a` there is no pending data, but for
+ /// `b` there is a pending request for a `t` register, so it gets allocated
+ /// now (getting `t2`). The `c` parameter then has no pending data, so
+ /// we are done.
+ ///
+ /// *When* the pending data gets flushed is then significant. In general,
+ /// the order in which modules get composed an specialized is signficaint.
+ /// The module above (let's call it `M`) has one specialization parameter
+ /// (for `b`), and if we want to compose it with another module `N` that
+ /// has no specialization parameters, we could compute either:
+ ///
+ /// compose(specialize(M, SomeType), N)
+ ///
+ /// or:
+ ///
+ /// specialize(compose(M,N), SomeType)
+ ///
+ /// In the first case, the "pending" data for `M` gets flushed right after `M`,
+ /// so that `specialize(M,SomeType)` can have a consistent layout
+ /// regardless of how it is used. In the second case, the pending data for
+ /// `M` only gets flushed after `N`'s parameters are allocated, thus guaranteeing
+ /// that the `compose(M,N)` part has a consistent layout regardless of what
+ /// type gets plugged in during specialization.
+ ///
+ /// There are trade-offs to be made by an application about which approach
+ /// to prefer, and the compiler supports either policy choice.
+ ///
+struct FlushPendingDataVisitor : ComponentTypeVisitor
+{
+ FlushPendingDataVisitor(ParameterBindingContext* context, ParameterBindingVisitorCounters* counters)
+ : m_context(context)
+ , m_counters(counters)
+ {}
+
+ ParameterBindingContext* m_context;
+ ParameterBindingVisitorCounters* m_counters;
+
+ void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(entryPoint);
+ SLANG_UNUSED(specializationInfo);
+
+ auto globalEntryPointIndex = m_counters->entryPointCounter++;
+ auto globalEntryPointInfo = m_context->shared->programLayout->entryPoints[globalEntryPointIndex];
+
+ // We need to allocate space for any "pending" data that
+ // appeared in the entry-point parameter list.
+ //
+ _allocateBindingsForPendingData(m_context, globalEntryPointInfo->parametersLayout->pendingVarLayout);
+ }
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(specializationInfo);
+ visitLeafParams(module);
+ }
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(specializationInfo);
+ visitLeafParams(legacy);
+ }
+
+ void visitLeafParams(ComponentType* componentType)
+ {
+ // In the "leaf" case we just allocate space for any
+ // pending data in the parameters, in order.
+ //
+ auto paramCount = componentType->getShaderParamCount();
+ for(Index ii = 0; ii < paramCount; ++ii)
+ {
+ auto globalParamIndex = m_counters->globalParamCounter++;
+ auto globalParamInfo = m_context->shared->parameters[globalParamIndex];
+ auto firstVarLayout = globalParamInfo->varLayouts[0];
+
+ _allocateBindingsForPendingData(m_context, firstVarLayout->pendingVarLayout);
+ }
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ // Because `SpecializedComponentType` was a special case for `CompleteBindingsVisitor`,
+ // it ends up being a special case here too.
+ //
+ // The `CompleteBindings...` pass treated a `SpecializedComponentType`
+ // as an atomic unit. Any "pending" data that came from its parameters
+ // will already have been dealt with, so it would be incorrect for
+ // us to recurse into `specialized`.
+ //
+ // Instead, we just need to *skip* `specialized`, since it was
+ // completely handled already. This isn't quite as simple
+ // as just doing nothing, because our passes are using
+ // some global counters to find the absolute/linear index
+ // of each parameter and entry point as it is encountered.
+ // We will simply bump those counters by the number of
+ // parameters and entry points contained under `specialized`,
+ // which is luckily provided by the `ComponentType` API.
+ //
+ m_counters->globalParamCounter += specialized->getShaderParamCount();
+ m_counters->entryPointCounter += specialized->getEntryPointCount();
+ }
+
+};
+
+static void _completeBindings(
+ ParameterBindingContext* context,
+ ComponentType* componentType,
+ ParameterBindingVisitorCounters* ioCounters)
+{
+ ParameterBindingVisitorCounters savedCounters = *ioCounters;
+
+ CompleteBindingsVisitor completeBindingsVisitor(context, ioCounters);
+ componentType->acceptVisitor(&completeBindingsVisitor, nullptr);
+
+ FlushPendingDataVisitor flushVisitor(context, &savedCounters);
+ componentType->acceptVisitor(&flushVisitor, nullptr);
+}
+
+ /// "Complete" binding of parametesr in the given `program`.
+ ///
+ /// Completing binding involves both assigning registers/bindings
+ /// to an parameters that didn't get explicit locations, and then
+ /// also providing locations to any "pending" data that needed
+ /// space allocated (used for existential/interface type parameters).
+ ///
+static void _completeBindings(
+ ParameterBindingContext* context,
+ ComponentType* program)
+{
+ // The process of completing binding has a recursive structure,
+ // so we will immediately delegate to a subroutine that handles
+ // the recursion.
+ //
+ ParameterBindingVisitorCounters counters;
+ _completeBindings(context, program, &counters);
+}
+
RefPtr<ProgramLayout> generateParameterBindings(
TargetProgram* targetProgram,
DiagnosticSink* sink)
@@ -2450,20 +2960,67 @@ RefPtr<ProgramLayout> generateParameterBindings(
context.shared = &sharedContext;
context.layoutContext = layoutContext;
- // Walk through AST to discover all the parameters
+ // We want to start by finding out what (if anything) has
+ // been bound to the global generic parameters of the
+ // program, since we need to know these types to compute
+ // layout for parameters that use the generic type parameters.
+ //
+ collectGlobalGenericArguments(&context, program);
+
+ // Next we want to collect a full listing of all the shader
+ // parameters that need to be considered for layout, along
+ // with all of the entry points, which also need their
+ // parameters laid out and thus act pretty much like global
+ // parameters themselves.
+ //
collectParameters(&context, program);
- // Now walk through the parameters to generate initial binding information
+ // We will also collect basic information on the specialization
+ // parameters exposed by the program.
+ //
+ // Whereas `collectGlobalGenericArguments` was collecting the
+ // concrete types that have been plugged into specialization
+ // parameters, this step is about collecting the *unspecialized*
+ // parameters (if any) for the purposes of reflection.
+ //
+ collectSpecializationParams(&context, program);
+
+ // Once we have a canonical list of all the shader parameters
+ // (and entry points) in need of layout, we will walk through
+ // the parameters that might have explicit binding annotations,
+ // and "reserve" the registers/bindings/etc. that those parameters
+ // declare so that subequent automatic layout steps do not try to
+ // overlap them.
+ //
+ // Along the way we will issue diagnostics if there appear to
+ // be overlapping, conflicting, or inconsistent explicit bindings.
+ //
+ // Note that we do *not* support explicit binding annotations
+ // on entry point parameters, so we only consider global shader
+ // parameters here.
+ //
+ // (Also note that explicit bindings end up being the main
+ // source of complexity in the layout system, and we could greatly
+ // simplify this file by eliminating support for explicit
+ // binding in the future)
+ //
for( auto& parameter : sharedContext.parameters )
{
generateParameterBindings(&context, parameter);
}
- // Determine if there are any global-scope parameters that use `Uniform`
- // resources, and thus need to get packaged into a constant buffer.
+ // Once we have a canonical list of all the parameters, we can
+ // detect if there are any global-scope parameters that make use
+ // of `LayoutResourceKind::Uniform`, since such parameters would
+ // need to be packaged into a "default" constant buffer.
+ // The fxc/dxc compilers support this step, and in reflection
+ // refer to the generated constant buffer as `$Globals`.
+ //
+ // Note that this logic doesn't account for the existance of
+ // "legacy" (non-buffer-bound) uniforms in GLSL for OpenGL.
+ // If we wanted to support legaqcy uniforms we would probably
+ // want to do so through a different feature.
//
- // Note: this doesn't account for GLSL's support for "legacy" uniforms
- // at global scope, which don't get assigned a CB.
bool needDefaultConstantBuffer = false;
for( auto& parameterInfo : sharedContext.parameters )
{
@@ -2525,6 +3082,32 @@ RefPtr<ProgramLayout> generateParameterBindings(
break;
}
}
+
+ // We also need a default space for any entry-point parameters
+ // that consume appropriate resource kinds.
+ //
+ for(auto& entryPoint : sharedContext.programLayout->entryPoints)
+ {
+ auto paramsLayout = entryPoint->parametersLayout;
+ for(auto resInfo : paramsLayout->resourceInfos )
+ {
+ switch(resInfo.kind)
+ {
+ default:
+ break;
+
+ case LayoutResourceKind::RegisterSpace:
+ case LayoutResourceKind::VaryingInput:
+ case LayoutResourceKind::VaryingOutput:
+ case LayoutResourceKind::HitAttributes:
+ case LayoutResourceKind::RayPayload:
+ continue;
+ }
+
+ needDefaultSpace = true;
+ break;
+ }
+ }
}
// If we need a space for default bindings, then allocate it here.
@@ -2575,12 +3158,14 @@ RefPtr<ProgramLayout> generateParameterBindings(
&context,
needDefaultConstantBuffer);
- // Now walk through again to actually give everything
- // ranges of registers...
- for( auto& parameter : sharedContext.parameters )
- {
- completeBindingsForParameter(&context, parameter);
- }
+ // Now that all of the explicit bindings have been dealt with
+ // and we've also allocate any space/buffer that is required
+ // for global-scope parameters, we will go through the
+ // shader parameters and entry points yet again, in order
+ // to actually allocate specific bindings/registers to
+ // parameters and entry points that need them.
+ //
+ _completeBindings(&context, program);
// Next we need to create a type layout to reflect the information
// we have collected, and we will use the `ScopeLayoutBuilder`
@@ -2602,104 +3187,6 @@ RefPtr<ProgramLayout> generateParameterBindings(
cbInfo->index = globalConstantBufferBinding.index;
}
- // After we have laid out all the ordinary global parameters,
- // we need to "flush" out any pending data that was associated with
- // the global scope as part of dealing with interface-type parameters.
- //
- _allocateBindingsForPendingData(&context, globalScopeVarLayout->pendingVarLayout);
-
- // After we have allocated registers/bindings to everything
- // in the global scope we will process the parameters
- // of the entry points.
- //
- // Note: at the moment we are laying out *all* information related to global-scope
- // parameters (including pending data from interface-type parameters) before
- // anything pertaining to entry points. This is a crucial design choice to
- // get right, and we might want to revisit it based on experience.
-
- // In some cases, a user will want to ensure that all the
- // entry points they compile get non-overlapping
- // registers/bindings. E.g., if you have a vertex and fragment
- // shader being compiled together for Vulkan, you probably want distinct
- // bindings for their entry-point `uniform` parametres, so
- // that they can be used together.
- //
- // In other cases, however, a user probably doesn't want us
- // to conservatively allocate non-overlapping bindings.
- // E.g., if they have a bunch of compute shaders in a single
- // file, then they probably want each compute shader to
- // compute its parameter layout "from scratch" as if the
- // others don't exist.
- //
- // The way we handle this is by putting the entry points of a
- // `Program` into groups, and ensuring that within each group
- // we allocate parameters that don't overlap, but we don't
- // worry about overlap across groups.
- //
- for( auto entryPointGroup : sharedContext.programLayout->entryPointGroups )
- {
- // We save off the allocation state as it was before the entry-point
- // group, so that we can restore it to this state after each group.
- //
- // TODO: We probably ought to wrap all the state relevant to allocation
- // of registers/bindings into a single struct/field so that we only
- // have one thing to save/restore here even if new state gets added.
- //
- auto savedGlobalSpaceUsedRangeSets = sharedContext.globalSpaceUsedRangeSets;
- auto savedUsedSpaces = sharedContext.usedSpaces;
-
- // The group will have been allocated a layout that combines the
- // usage of all of the contained entry-points, so we just need to
- // allocate the entry-point group to be placed after all the global-scope
- // parameters.
- //
- auto entryPointGroupParamsLayout = entryPointGroup->parametersLayout;
- completeBindingsForParameter(&context, entryPointGroupParamsLayout);
-
- _allocateBindingsForPendingData(&context, entryPointGroupParamsLayout->pendingVarLayout);
-
- // TODO: Should we add the offset information from the group to
- // the layout information for each entry point (and thence to its parameters)?
- //
- // This seems important if we want to allow clients to conveniently
- // ignore groups when doing their reflection queries.
-
- // Once we've allocated bindigns for the parameters of entry points
- // in the group, we restore the state for tracking what register/bindings
- // are used to where it was before.
- //
- sharedContext.globalSpaceUsedRangeSets = savedGlobalSpaceUsedRangeSets;
- sharedContext.usedSpaces = savedUsedSpaces;
- }
-
- // HACK: we want global parameters to not have to deal with offsetting
- // by the `VarLayout` stored in `globalScopeVarLayout`, so we will scan
- // through and for any global parameter that used "pending" data, we will manually
- // offset all of its resource infos to account for where the global pending data
- // got placed.
- //
- // TODO: A more appropriate solution would be to pass the `globalScopeVarLayout`
- // down into the pass that puts layout information onto global parameters in
- // the IR, and apply the offsetting there.
- //
- for( auto& parameterInfo : sharedContext.parameters )
- {
- for( auto varLayout : parameterInfo->varLayouts )
- {
- auto pendingVarLayout = varLayout->pendingVarLayout;
- if(!pendingVarLayout) continue;
-
- for( auto& resInfo : pendingVarLayout->resourceInfos )
- {
- if( auto globalResInfo = globalScopeVarLayout->pendingVarLayout->FindResourceInfo(resInfo.kind) )
- {
- resInfo.index += globalResInfo->index;
- resInfo.space += globalResInfo->space;
- }
- }
- }
- }
-
programLayout->parametersLayout = globalScopeVarLayout;
{
@@ -2723,7 +3210,7 @@ ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink)
}
void generateParameterBindings(
- Program* program,
+ ComponentType* program,
TargetRequest* targetReq,
DiagnosticSink* sink)
{
diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp
index 3871e2ccc..5e6fd10cd 100644
--- a/source/slang/slang-reflection.cpp
+++ b/source/slang/slang-reflection.cpp
@@ -51,9 +51,9 @@ static inline SlangReflectionTypeLayout* convert(TypeLayout* type)
return (SlangReflectionTypeLayout*) type;
}
-static inline GenericParamLayout* convert(SlangReflectionTypeParameter * typeParam)
+static inline SpecializationParamLayout* convert(SlangReflectionTypeParameter * typeParam)
{
- return (GenericParamLayout*)typeParam;
+ return (SpecializationParamLayout*) typeParam;
}
static inline VarDeclBase* convert(SlangReflectionVariable* var)
@@ -86,16 +86,6 @@ static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint)
return (SlangReflectionEntryPoint*) entryPoint;
}
-static inline EntryPointGroupLayout* convert(SlangEntryPointGroupLayout* entryPointGroup)
-{
- return (EntryPointGroupLayout*) entryPointGroup;
-}
-
-static inline SlangEntryPointGroupLayout* convert(EntryPointGroupLayout* entryPointGroup)
-{
- return (SlangEntryPointGroupLayout*) entryPointGroup;
-}
-
static inline ProgramLayout* convert(SlangReflection* program)
{
return (ProgramLayout*) program;
@@ -865,7 +855,7 @@ SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLay
if(auto genericParamTypeLayout = as<GenericParamTypeLayout>(typeLayout))
{
- return genericParamTypeLayout->paramIndex;
+ return (int) genericParamTypeLayout->paramIndex;
}
else
{
@@ -1222,7 +1212,7 @@ SLANG_API char const* spReflectionEntryPoint_getName(
auto entryPointLayout = convert(inEntryPoint);
if(!entryPointLayout) return 0;
- return getText(entryPointLayout->entryPoint->getName()).begin();
+ return getText(entryPointLayout->entryPoint.GetName()).begin();
}
SLANG_API unsigned spReflectionEntryPoint_getParameterCount(
@@ -1270,7 +1260,7 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
SlangUInt sizeAlongAxis[3] = { 1, 1, 1 };
// First look for the HLSL case, where we have an attribute attached to the entry point function
- auto numThreadsAttribute = entryPointFunc->FindModifier<NumThreadsAttribute>();
+ auto numThreadsAttribute = entryPointFunc.getDecl()->FindModifier<NumThreadsAttribute>();
if (numThreadsAttribute)
{
sizeAlongAxis[0] = numThreadsAttribute->x;
@@ -1281,7 +1271,7 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
{
// Fall back to the GLSL case, which requires a search over global-scope declarations
// to look for as with the `local_size_*` qualifier
- auto module = as<ModuleDecl>(entryPointFunc->ParentDecl);
+ auto module = as<ModuleDecl>(entryPointFunc.getDecl()->ParentDecl);
if (module)
{
for (auto dd : module->Members)
@@ -1323,11 +1313,43 @@ SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput(
return (entryPointLayout->flags & EntryPointLayout::Flag::usesAnySampleRateInput) != 0;
}
+SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getVarLayout(
+ SlangReflectionEntryPoint* inEntryPoint)
+{
+ auto entryPointLayout = convert(inEntryPoint);
+ if(!entryPointLayout)
+ return nullptr;
+
+ return convert(entryPointLayout->parametersLayout);
+}
+
+static bool hasDefaultConstantBuffer(ScopeLayout* layout)
+{
+ auto typeLayout = layout->parametersLayout->getTypeLayout();
+ return as<ParameterGroupTypeLayout>(typeLayout) != nullptr;
+}
+
+SLANG_API int spReflectionEntryPoint_hasDefaultConstantBuffer(
+ SlangReflectionEntryPoint* inEntryPoint)
+{
+ auto entryPointLayout = convert(inEntryPoint);
+ if(!entryPointLayout)
+ return 0;
+
+ return hasDefaultConstantBuffer(entryPointLayout);
+}
+
+
// SlangReflectionTypeParameter
SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter * inTypeParam)
{
- auto typeParam = convert(inTypeParam);
- return typeParam->decl->getName()->text.getBuffer();
+ auto specializationParam = convert(inTypeParam);
+ if( auto genericParamLayout = as<GenericSpecializationParamLayout>(specializationParam) )
+ {
+ return genericParamLayout->decl->getName()->text.getBuffer();
+ }
+ // TODO: Add case for existential type parameter? They don't have as simple of a notion of "name" as the generic case...
+ return nullptr;
}
SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter * inTypeParam)
@@ -1338,16 +1360,34 @@ SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParamet
SLANG_API unsigned int spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* inTypeParam)
{
- auto typeParam = convert(inTypeParam);
- auto constraints = typeParam->decl->getMembersOfType<GenericTypeConstraintDecl>();
- return (unsigned int)constraints.getCount();
+ auto specializationParam = convert(inTypeParam);
+ if(auto genericParamLayout = as<GenericSpecializationParamLayout>(specializationParam))
+ {
+ if( auto globalGenericParamDecl = as<GlobalGenericParamDecl>(genericParamLayout->decl) )
+ {
+ auto constraints = globalGenericParamDecl->getMembersOfType<GenericTypeConstraintDecl>();
+ return (unsigned int)constraints.getCount();
+ }
+ // TODO: Add case for entry-point generic parameters.
+ }
+ // TODO: Add case for existential type parameters.
+ return 0;
}
SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter * inTypeParam, unsigned index)
{
- auto typeParam = convert(inTypeParam);
- auto constraints = typeParam->decl->getMembersOfType<GenericTypeConstraintDecl>();
- return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr();
+ auto specializationParam = convert(inTypeParam);
+ if(auto genericParamLayout = as<GenericSpecializationParamLayout>(specializationParam))
+ {
+ if( auto globalGenericParamDecl = as<GlobalGenericParamDecl>(genericParamLayout->decl) )
+ {
+ auto constraints = globalGenericParamDecl->getMembersOfType<GenericTypeConstraintDecl>();
+ return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr();
+ }
+ // TODO: Add case for entry-point generic parameters.
+ }
+ // TODO: Add case for existential type parameters.
+ return 0;
}
// Shader Reflection
@@ -1379,22 +1419,32 @@ SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflec
SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection * reflection)
{
auto program = convert(reflection);
- return (unsigned int)program->globalGenericParams.getCount();
+ return (unsigned int) program->specializationParams.getCount();
}
SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection * reflection, unsigned int index)
{
auto program = convert(reflection);
- return (SlangReflectionTypeParameter*)program->globalGenericParams[index].Ptr();
+ return (SlangReflectionTypeParameter*) program->specializationParams[index].Ptr();
}
SLANG_API SlangReflectionTypeParameter * spReflection_FindTypeParameter(SlangReflection * inProgram, char const * name)
{
auto program = convert(inProgram);
if (!program) return nullptr;
- GenericParamLayout * result = nullptr;
- program->globalGenericParamsMap.TryGetValue(name, result);
- return (SlangReflectionTypeParameter*)result;
+ for( auto& param : program->specializationParams )
+ {
+ auto genericParamLayout = as<GenericSpecializationParamLayout>(param);
+ if(!genericParamLayout)
+ continue;
+
+ if(getText(genericParamLayout->decl->getName()) != UnownedTerminatedStringSlice(name))
+ continue;
+
+ return (SlangReflectionTypeParameter*) genericParamLayout;
+ }
+
+ return 0;
}
SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram)
@@ -1421,7 +1471,7 @@ SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangRefl
// TODO: improve on naive linear search
for(auto ep : program->entryPoints)
{
- if(ep->entryPoint->getName()->text == name)
+ if(ep->entryPoint.GetName()->text == name)
{
return convert(ep);
}
@@ -1430,75 +1480,6 @@ SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangRefl
return nullptr;
}
-
-SLANG_API SlangInt spReflection_getEntryPointGroupCount(SlangReflection* inProgram)
-{
- auto program = convert(inProgram);
- if(!program) return 0;
-
- return program->entryPointGroups.getCount();
-}
-
-SLANG_API SlangEntryPointGroupLayout* spReflection_getEntryPointGroupByIndex(SlangReflection* inProgram, SlangInt index)
-{
- auto program = convert(inProgram);
- if(!program) return 0;
-
- if(index < 0) return nullptr;
- if(index >= program->entryPointGroups.getCount()) return nullptr;
-
- return convert(program->entryPointGroups[(int) index].Ptr());
-}
-
-SLANG_API SlangInt spEntryPointGroupLayout_getEntryPointCount(SlangEntryPointGroupLayout* inGroup)
-{
- auto group = convert(inGroup);
- if(!group) return 0;
-
- return group->entryPoints.getCount();
-}
-
-SLANG_API SlangReflectionEntryPoint* spEntryPointGroupLayout_getEntryPointByIndex(SlangEntryPointGroupLayout* inGroup, SlangInt index)
-{
- auto group = convert(inGroup);
- if(!group) return 0;
-
- if(index < 0) return nullptr;
- if(index >= group->entryPoints.getCount()) return nullptr;
-
- return convert(group->entryPoints[(int) index].Ptr());
-}
-
-SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getVarLayout(SlangEntryPointGroupLayout* inGroup)
-{
- auto group = convert(inGroup);
- if(!group) return 0;
-
- return convert(group->parametersLayout);
-}
-
-SLANG_API SlangInt spEntryPointGroupLayout_getParameterCount(SlangEntryPointGroupLayout* inGroup)
-{
- auto groupLayout = convert(inGroup);
- if(!groupLayout) return 0;
-
- auto& params = groupLayout->group->getShaderParams();
- return params.getCount();
-}
-
-SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getParameterByIndex(SlangEntryPointGroupLayout* inGroup, SlangInt index)
-{
- auto groupLayout = convert(inGroup);
- if(!groupLayout) return nullptr;
-
- auto& params = groupLayout->group->getShaderParams();
- if(index < 0) return nullptr;
- if(index >= params.getCount()) return nullptr;
-
- return convert(getScopeStructLayout(groupLayout)->fields[index]);
-}
-
-
SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram)
{
auto program = convert(inProgram);
@@ -1531,7 +1512,7 @@ SLANG_API SlangReflectionType* spReflection_specializeType(
auto unspecializedType = convert(inType);
if(!unspecializedType) return nullptr;
- auto linkage = programLayout->getProgram()->getLinkageImpl();
+ auto linkage = programLayout->getProgram()->getLinkage();
DiagnosticSink sink(linkage->getSourceManager());
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 08d671241..c4152d78c 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -2766,10 +2766,10 @@ String ExistentialSpecializedType::ToString()
String result;
result.append("__ExistentialSpecializedType(");
result.append(baseType->ToString());
- for( auto arg : slots.args )
+ for( auto arg : args )
{
result.append(", ");
- result.append(arg.type->ToString());
+ result.append(arg.val->ToString());
}
result.append(")");
return result;
@@ -2784,16 +2784,19 @@ bool ExistentialSpecializedType::EqualsImpl(Type * type)
if(!baseType->Equals(other->baseType))
return false;
- auto argCount = slots.args.getCount();
- if(argCount != other->slots.args.getCount())
+ auto argCount = args.getCount();
+ if(argCount != other->args.getCount())
return false;
for( Index ii = 0; ii < argCount; ++ii )
{
- if(!slots.args[ii].type->Equals(other->slots.args[ii].type))
+ auto arg = args[ii];
+ auto otherArg = other->args[ii];
+
+ if(!arg.val->EqualsVal(otherArg.val))
return false;
- if(!slots.args[ii].witness->EqualsVal(other->slots.args[ii].witness))
+ if(!areValsEqual(arg.witness, otherArg.witness))
return false;
}
return true;
@@ -2803,51 +2806,63 @@ int ExistentialSpecializedType::GetHashCode()
{
Hasher hasher;
hasher.hashObject(baseType);
- for(auto arg : slots.args)
+ for(auto arg : args)
{
- hasher.hashObject(arg.type);
- hasher.hashObject(arg.witness);
+ hasher.hashObject(arg.val);
+ if(auto witness = arg.witness)
+ hasher.hashObject(witness);
}
return hasher.getResult();
}
+RefPtr<Val> getCanonicalValue(Val* val)
+{
+ if(!val)
+ return nullptr;
+ if(auto type = as<Type>(val))
+ {
+ return type->GetCanonicalType();
+ }
+ // TODO: We may eventually need/want some sort of canonicalization
+ // for non-type values, but for now there is nothing to do.
+ return val;
+}
+
RefPtr<Type> ExistentialSpecializedType::CreateCanonicalType()
{
RefPtr<ExistentialSpecializedType> canType = new ExistentialSpecializedType();
canType->setSession(getSession());
canType->baseType = baseType->GetCanonicalType();
- for( auto paramType : slots.paramTypes )
+ for( auto arg : args )
{
- canType->slots.paramTypes.add( paramType->GetCanonicalType() );
- }
- for( auto arg : slots.args )
- {
- ExistentialTypeSlots::Arg canArg;
- canArg.type = arg.type->GetCanonicalType();
- canArg.witness = arg.witness;
- canType->slots.args.add(canArg);
+ ExpandedSpecializationArg canArg;
+ canArg.val = getCanonicalValue(arg.val);
+ canArg.witness = getCanonicalValue(arg.witness);
+ canType->args.add(canArg);
}
return canType;
}
+RefPtr<Val> substituteImpl(Val* val, SubstitutionSet subst, int* ioDiff)
+{
+ if(!val) return nullptr;
+ return val->SubstituteImpl(subst, ioDiff);
+}
+
RefPtr<Val> ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
auto substBaseType = baseType->SubstituteImpl(subst, &diff).as<Type>();
- ExistentialTypeSlots substSlots;
- for( auto paramType : slots.paramTypes )
- {
- substSlots.paramTypes.add( paramType->SubstituteImpl(subst, &diff).as<Type>() );
- }
- for( auto arg : slots.args )
+ ExpandedSpecializationArgs substArgs;
+ for( auto arg : args )
{
- ExistentialTypeSlots::Arg substArg;
- substArg.type = arg.type->SubstituteImpl(subst, &diff).as<Type>();
- substArg.witness = arg.witness->SubstituteImpl(subst, &diff);
- substSlots.args.add(substArg);
+ ExpandedSpecializationArg substArg;
+ substArg.val = Slang::substituteImpl(arg.val, subst, &diff);
+ substArg.witness = Slang::substituteImpl(arg.witness, subst, &diff);
+ substArgs.add(substArg);
}
if(!diff)
@@ -2858,7 +2873,7 @@ RefPtr<Val> ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, in
RefPtr<ExistentialSpecializedType> substType = new ExistentialSpecializedType();
substType->setSession(getSession());
substType->baseType = substBaseType;
- substType->slots = substSlots;
+ substType->args = substArgs;
return substType;
}
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index d156750be..88a2ca847 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -1132,31 +1132,31 @@ namespace Slang
typedef Dictionary<unsigned int, RefPtr<RefObject>> AttributeArgumentValueDict;
- /// Collects information about existential type parameters and their arguments.
- struct ExistentialTypeSlots
+ struct SpecializationParam
{
- /// For each type parameter, holds the interface/existential type that constrains it.
- List<RefPtr<Type>> paramTypes;
-
- /// An argument for an existential type parameter.
- ///
- /// Comprises a concrete type and a witness for its conformance to the desired
- /// interface/existential type for the corresponding parameter.
- ///
- struct Arg
+ enum class Flavor
{
- RefPtr<Type> type;
- RefPtr<Val> witness;
+ GenericType,
+ GenericValue,
+ ExistentialType,
+ ExistentialValue,
};
+ Flavor flavor;
+ RefPtr<RefObject> object;
+ };
+ typedef List<SpecializationParam> SpecializationParams;
- /// Any arguments provided for the existential type parameters.
- ///
- /// It is possible for `args` to be empty even if `paramTypes` is non-empty;
- /// that situation represents an unspecialized program or entry point.
- ///
- List<Arg> args;
+ struct SpecializationArg
+ {
+ RefPtr<Val> val;
};
+ typedef List<SpecializationArg> SpecializationArgs;
+ struct ExpandedSpecializationArg : SpecializationArg
+ {
+ RefPtr<Val> witness;
+ };
+ typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs;
// Generate class definition for all syntax classes
#define SYNTAX_FIELD(TYPE, NAME) TYPE NAME;
diff --git a/source/slang/slang-type-defs.h b/source/slang/slang-type-defs.h
index d9907bafe..7afc23411 100644
--- a/source/slang/slang-type-defs.h
+++ b/source/slang/slang-type-defs.h
@@ -479,7 +479,7 @@ END_SYNTAX_CLASS()
SYNTAX_CLASS(ExistentialSpecializedType, Type)
RAW(
RefPtr<Type> baseType;
- ExistentialTypeSlots slots;
+ ExpandedSpecializationArgs args;
virtual String ToString() override;
virtual bool EqualsImpl(Type * type) override;
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 9b6571372..f76b29a51 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -1222,19 +1222,11 @@ EntryPointLayout* EntryPointLayout::getAbsoluteLayout(
{
adjustedLayout->resultLayout = baseResultLayout->getAbsoluteLayout(parentLayout);
}
- adjustedLayout->taggedUnionTypeLayouts = this->taggedUnionTypeLayouts;
m_absoluteLayout = adjustedLayout;
return adjustedLayout;
}
-EntryPointLayout* EntryPointLayout::getAbsoluteLayout(EntryPointGroupLayout* parentGroup)
-{
- SLANG_ASSERT(parentGroup);
- return getAbsoluteLayout(parentGroup->parametersLayout);
-}
-
-
VarLayout* VarLayout::getAbsoluteLayout(VarLayout* parentAbsoluteLayout)
{
if( !m_absoluteLayout )
@@ -1933,9 +1925,31 @@ static TypeLayoutResult _createTypeLayout(
return _createTypeLayout(subContext, type);
}
-int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl)
+RefPtr<Type> findGlobalGenericSpecializationArg(
+ TypeLayoutContext const& context,
+ GlobalGenericParamDecl* decl)
{
- return (int)genericParameters.findFirstIndex([=](RefPtr<GenericParamLayout> & x) {return x->decl.Ptr() == decl; });
+ RefPtr<Val> arg;
+ context.programLayout->globalGenericArgs.TryGetValue(decl, arg);
+ return arg.as<Type>();
+}
+
+Index findGlobalGenericSpecializationParamIndex(
+ ComponentType* type,
+ GlobalGenericParamDecl* decl)
+{
+ Index paramCount = type->getSpecializationParamCount();
+ for( Index pp = 0; pp < paramCount; ++pp )
+ {
+ auto param = type->getSpecializationParam(pp);
+ if(param.flavor != SpecializationParam::Flavor::GenericType)
+ continue;
+ if(param.object.Ptr() != decl)
+ continue;
+
+ return pp;
+ }
+ return -1;
}
// When constructing a new var layout from an existing one,
@@ -2393,6 +2407,37 @@ TypeLayoutResult StructTypeLayoutBuilder::getTypeLayoutResult()
return TypeLayoutResult(m_typeLayout, m_info);
}
+static TypeLayoutResult _createTypeLayoutForGlobalGenericTypeParam(
+ TypeLayoutContext const& context,
+ Type* type,
+ GlobalGenericParamDecl* globalGenericParamDecl)
+{
+ SimpleLayoutInfo info;
+ info.alignment = 0;
+ info.size = 0;
+ info.kind = LayoutResourceKind::GenericResource;
+
+ RefPtr<GenericParamTypeLayout> typeLayout = new GenericParamTypeLayout();
+ // we should have already populated ProgramLayout::genericEntryPointParams list at this point,
+ // so we can find the index of this generic param decl in the list
+ typeLayout->type = type;
+ typeLayout->paramIndex = findGlobalGenericSpecializationParamIndex(
+ context.programLayout->getProgram(),
+ globalGenericParamDecl);
+ typeLayout->rules = context.rules;
+ typeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1;
+
+ return TypeLayoutResult(typeLayout, info);
+}
+
+RefPtr<TypeLayout> createTypeLayoutForGlobalGenericTypeParam(
+ TypeLayoutContext const& context,
+ Type* type,
+ GlobalGenericParamDecl* globalGenericParamDecl)
+{
+ return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl).layout;
+}
+
static TypeLayoutResult _createTypeLayout(
TypeLayoutContext const& context,
Type* type)
@@ -2825,7 +2870,7 @@ static TypeLayoutResult _createTypeLayout(
// to all the incoming specialized type slots that haven't already
// been consumed/claimed by preceding fields.
//
- auto fieldLayoutContext = context.withExistentialTypeSlotsOffsetBy(baseExistentialSlotIndex);
+ auto fieldLayoutContext = context.withSpecializationArgsOffsetBy(baseExistentialSlotIndex);
TypeLayoutResult fieldResult = _createTypeLayout(
fieldLayoutContext,
@@ -2863,22 +2908,25 @@ static TypeLayoutResult _createTypeLayout(
return typeLayoutBuilder.getTypeLayoutResult();
}
- else if (auto globalGenParam = declRef.as<GlobalGenericParamDecl>())
+ else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>())
{
- SimpleLayoutInfo info;
- info.alignment = 0;
- info.size = 0;
- info.kind = LayoutResourceKind::GenericResource;
-
- auto genParamTypeLayout = new GenericParamTypeLayout();
- // we should have already populated ProgramLayout::genericEntryPointParams list at this point,
- // so we can find the index of this generic param decl in the list
- genParamTypeLayout->type = type;
- genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl());
- genParamTypeLayout->rules = rules;
- genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1;
-
- return TypeLayoutResult(genParamTypeLayout, info);
+ if( auto concreteType = findGlobalGenericSpecializationArg(
+ context,
+ globalGenericParamDecl) )
+ {
+ // If we know what concrete type has been used to specialize
+ // the global generic type parameter, then we should use
+ // the concrete type instead.
+ //
+ return _createTypeLayout(context, concreteType);
+ }
+ else
+ {
+ // Otherwise we must create a type layout that represents
+ // the generic type parameter itself.
+ //
+ return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl);
+ }
}
else if (auto assocTypeParam = declRef.as<AssocTypeDecl>())
{
@@ -2934,9 +2982,11 @@ static TypeLayoutResult _createTypeLayout(
// If there are any concrete types available, the first one will be
// the value that should be plugged into the slot we just introduced.
//
- if( context.existentialTypeArgCount )
+ if( context.specializationArgCount )
{
- RefPtr<Type> concreteType = context.existentialTypeArgs[0].type;
+ auto& specializationArg = context.specializationArgs[0];
+ RefPtr<Type> concreteType = specializationArg.val.as<Type>();
+ SLANG_ASSERT(concreteType);
RefPtr<TypeLayout> concreteTypeLayout = createTypeLayout(context, concreteType);
@@ -3046,9 +3096,9 @@ static TypeLayoutResult _createTypeLayout(
}
else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) )
{
- TypeLayoutContext subContext = context.withExistentialTypeArgs(
- existentialSpecializedType->slots.args.getCount(),
- existentialSpecializedType->slots.args.getBuffer());
+ TypeLayoutContext subContext = context.withSpecializationArgs(
+ existentialSpecializedType->args.getBuffer(),
+ existentialSpecializedType->args.getCount());
auto baseTypeLayoutResult = _createTypeLayout(
subContext,
diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h
index e066c2700..b7b3c3207 100644
--- a/source/slang/slang-type-layout.h
+++ b/source/slang/slang-type-layout.h
@@ -622,7 +622,7 @@ class GenericParamTypeLayout : public TypeLayout
{
public:
RefPtr<GlobalGenericParamDecl> getGlobalGenericParamDecl();
- int paramIndex = 0;
+ Index paramIndex = 0;
};
/// Layout information for a tagged union type.
@@ -667,8 +667,6 @@ public:
StructTypeLayout* getScopeStructLayout(
ScopeLayout* programLayout);
-class EntryPointGroupLayout;
-
// Layout information for a single shader entry point
// within a program
//
@@ -682,7 +680,10 @@ class EntryPointLayout : public ScopeLayout
{
public:
// The corresponding function declaration
- RefPtr<FuncDecl> entryPoint;
+ DeclRef<FuncDecl> entryPoint;
+
+ DeclRef<FuncDecl> getFuncDeclRef() { return entryPoint; }
+ FuncDecl* getFuncDecl() { return entryPoint.getDecl(); }
// The shader profile that was used to compile the entry point
Profile profile;
@@ -696,30 +697,38 @@ public:
};
unsigned flags = 0;
- /// Layouts for all tagged union types required by this entry point.
- ///
- /// These are any tagged union types used by the generic
- /// arguments that this entry point is being compiled with.
- List<RefPtr<TypeLayout>> taggedUnionTypeLayouts;
-
EntryPointLayout* getAbsoluteLayout(VarLayout* parentLayout);
- EntryPointLayout* getAbsoluteLayout(EntryPointGroupLayout* parentGroup);
RefPtr<EntryPointLayout> m_absoluteLayout;
};
-class EntryPointGroupLayout : public ScopeLayout
+ /// Reflection/layout information about a specialization parameter
+class SpecializationParamLayout : public Layout
{
public:
- RefPtr<EntryPointGroup> group;
- List<RefPtr<EntryPointLayout>> entryPoints;
+ Index index;
};
-class GenericParamLayout : public Layout
+ /// Reflection/layout information about a generic specialization parameter
+class GenericSpecializationParamLayout : public SpecializationParamLayout
{
public:
- RefPtr<GlobalGenericParamDecl> decl;
- int index;
+ /// The declaration of the generic parameter.
+ ///
+ /// Could be any subclass of `Decl` that represents a generic value or type parameter.
+ RefPtr<Decl> decl;
+};
+
+ /// Reflection/layout information about an existential/interface specialization parameter.
+class ExistentialSpecializationParamLayout : public SpecializationParamLayout
+{
+public:
+ /// The type that needs to be specialized.
+ ///
+ /// Currently, this will be an `interface` type that any concrete
+ /// type argument getting plugged in must conform to.
+ ///
+ RefPtr<Type> type;
};
// Layout information for the global scope of a program
@@ -748,7 +757,7 @@ public:
TargetProgram* getTargetProgram() { return targetProgram; }
TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); }
- Program* getProgram() { return targetProgram->getProgram(); }
+ ComponentType* getProgram() { return targetProgram->getProgram(); }
// We catalog the requested entry points here,
@@ -756,12 +765,21 @@ public:
// will (eventually) belong there...
List<RefPtr<EntryPointLayout>> entryPoints;
- // Entry points can also be grouped for layout purposes (e.g., to form
- // ray-tracing hit groups), so this array represents those groups
- List<RefPtr<EntryPointGroupLayout>> entryPointGroups;
+ /// Reflection information on (unspecialized) specialization parameters.
+ List<RefPtr<SpecializationParamLayout>> specializationParams;
- List<RefPtr<GenericParamLayout>> globalGenericParams;
- Dictionary<String, GenericParamLayout*> globalGenericParamsMap;
+ /// Concrete argument values that were provided to specific global generic parameters.
+ ///
+ /// Not useful for reflection, but valuable for code generation.
+ ///
+ Dictionary<GlobalGenericParamDecl*, RefPtr<Val>> globalGenericArgs;
+
+ /// Layouts for all tagged union types required by this program
+ ///
+ /// These are any tagged union types used by the specialization
+ /// arguments that have been used to specialize the program.
+ ///
+ List<RefPtr<TypeLayout>> taggedUnionTypeLayouts;
};
StructTypeLayout* getGlobalStructLayout(
@@ -908,8 +926,6 @@ struct LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0;
};
-typedef List<RefPtr<GenericParamLayout>> GenericParamLayouts;
-
struct TypeLayoutContext
{
// The layout rules to use (e.g., we compute
@@ -930,10 +946,10 @@ struct TypeLayoutContext
MatrixLayoutMode matrixLayoutMode;
// The concrete types (if any) to plug into the currently in-scope
- // existential type slots.
+ // specialization params.
//
- Int existentialTypeArgCount = 0;
- ExistentialTypeSlots::Arg const* existentialTypeArgs = nullptr;
+ Int specializationArgCount = 0;
+ ExpandedSpecializationArg const* specializationArgs = nullptr;
LayoutRulesImpl* getRules() { return rules; }
LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); }
@@ -952,29 +968,29 @@ struct TypeLayoutContext
return result;
}
- TypeLayoutContext withExistentialTypeArgs(
- Int argCount,
- ExistentialTypeSlots::Arg const* args) const
+ TypeLayoutContext withSpecializationArgs(
+ ExpandedSpecializationArg const* args,
+ Int argCount) const
{
TypeLayoutContext result = *this;
- result.existentialTypeArgCount = argCount;
- result.existentialTypeArgs = args;
+ result.specializationArgCount = argCount;
+ result.specializationArgs = args;
return result;
}
- TypeLayoutContext withExistentialTypeSlotsOffsetBy(
+ TypeLayoutContext withSpecializationArgsOffsetBy(
Int offset) const
{
TypeLayoutContext result = *this;
- if( existentialTypeArgCount > offset )
+ if( specializationArgCount > offset )
{
- result.existentialTypeArgCount = existentialTypeArgCount - offset;
- result.existentialTypeArgs = existentialTypeArgs + offset;
+ result.specializationArgCount = specializationArgCount - offset;
+ result.specializationArgs = specializationArgs + offset;
}
else
{
- result.existentialTypeArgCount = 0;
- result.existentialTypeArgs = nullptr;
+ result.specializationArgCount = 0;
+ result.specializationArgs = nullptr;
}
return result;
@@ -1061,6 +1077,8 @@ public:
///
TypeLayoutResult getTypeLayoutResult();
+ UniformLayoutInfo* getStructLayoutInfo() { return &m_info; }
+
private:
/// The layout rules being used, if layout has begun.
LayoutRulesImpl* m_rules = nullptr;
@@ -1139,8 +1157,16 @@ createStructuredBufferTypeLayout(
RefPtr<Type> structuredBufferType,
RefPtr<Type> elementType);
-int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl);
-//
+ /// Create a type layout for an unspecialized `globalGenericParamDecl`.
+RefPtr<TypeLayout> createTypeLayoutForGlobalGenericTypeParam(
+ TypeLayoutContext const& context,
+ Type* type,
+ GlobalGenericParamDecl* globalGenericParamDecl);
+
+ /// Find the concrete type (if any) that was plugged in for the global generic type parameter `decl`.
+RefPtr<Type> findGlobalGenericSpecializationArg(
+ TypeLayoutContext const& context,
+ GlobalGenericParamDecl* decl);
// Given an existing type layout `oldTypeLayout`, apply offsets
// to any contained fields based on the resource infos in `offsetVarLayout`.
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index b18e4d4d9..b030e5cf9 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -436,44 +436,31 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule(
return asExternal(module);
}
-SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createProgram(
- slang::ProgramDesc const& desc,
- slang::IProgram** outProgram)
+SLANG_NO_THROW slang::IComponentType* SLANG_MCALL Linkage::createCompositeComponentType(
+ slang::IComponentType* const* componentTypes,
+ SlangInt componentTypeCount,
+ ISlangBlob** outDiagnostics)
{
- RefPtr<Program> program = new Program(this);
-
- auto itemCount = desc.itemCount;
- for(SlangInt ii = 0; ii < itemCount; ++ii)
- {
- auto& item = desc.items[ii];
- switch(item.kind)
- {
- case slang::ProgramDesc::Item::Kind::Program:
- {
- Program* existingProgram = asInternal(item.program);
- for(auto referencedModule : existingProgram->getModuleDependencies())
- {
- program->addReferencedLeafModule(referencedModule);
- }
-
- // TODO: Need to decide whether to include the entry points as well...
- }
- break;
+ // Attempting to create a "composite" of just one component type should
+ // just return the component type itself, to avoid redundant work.
+ //
+ if( componentTypeCount == 1)
+ return componentTypes[0];
- case slang::ProgramDesc::Item::Kind::Module:
- {
- Module* module = asInternal(item.module);
- program->addReferencedModule(module);
- }
- break;
+ DiagnosticSink sink(getSourceManager());
- default:
- return SLANG_E_INVALID_ARG;
- }
+ List<RefPtr<ComponentType>> childComponents;
+ for( Int cc = 0; cc < componentTypeCount; ++cc )
+ {
+ childComponents.add(asInternal(componentTypes[cc]));
}
- *outProgram = asExternal(program.detach());
- return SLANG_OK;
+ RefPtr<ComponentType> composite = CompositeComponentType::create(
+ this,
+ childComponents);
+
+ sink.getBlobIfNeeded(outDiagnostics);
+ return asExternal(composite.detach());
}
SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
@@ -781,7 +768,9 @@ RefPtr<Type> checkProperType(
TypeExp typeExp,
DiagnosticSink* sink);
-Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink)
+Type* ComponentType::getTypeFromString(
+ String const& typeStr,
+ DiagnosticSink* sink)
{
// If we've looked up this type name before,
// then we can re-use it.
@@ -803,7 +792,7 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink)
for(auto module : getModuleDependencies())
scopesToTry.add(module->getModuleDecl()->scope);
- auto linkage = getLinkageImpl();
+ auto linkage = getLinkage();
for(auto& s : scopesToTry)
{
RefPtr<Expr> typeExpr = linkage->parseTypeString(
@@ -899,10 +888,16 @@ void FrontEndCompileRequest::parseTranslationUnit(
}
}
-RefPtr<Program> createUnspecializedProgram(
+RefPtr<ComponentType> createUnspecializedGlobalComponentType(
+ FrontEndCompileRequest* compileRequest);
+
+RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
FrontEndCompileRequest* compileRequest);
-RefPtr<Program> createSpecializedProgram(
+RefPtr<ComponentType> createSpecializedGlobalComponentType(
+ EndToEndCompileRequest* endToEndReq);
+
+RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType(
EndToEndCompileRequest* endToEndReq);
void FrontEndCompileRequest::checkAllTranslationUnits()
@@ -1032,7 +1027,11 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
// Look up all the entry points that are expected,
// and use them to populate the `program` member.
//
- m_program = createUnspecializedProgram(this);
+ m_globalComponentType = createUnspecializedGlobalComponentType(this);
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ m_globalAndEntryPointsComponentType = createUnspecializedGlobalAndEntryPointsComponentType(this);
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -1052,7 +1051,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
//
for(auto targetReq : getLinkage()->targets)
{
- auto targetProgram = m_program->getTargetProgram(targetReq);
+ auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq);
targetProgram->getOrCreateLayout(getSink());
}
if (getSink()->GetErrorCount() != 0)
@@ -1064,7 +1063,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner()
BackEndCompileRequest::BackEndCompileRequest(
Linkage* linkage,
DiagnosticSink* sink,
- Program* program)
+ ComponentType* program)
: CompileRequestBase(linkage, sink)
, m_program(program)
{}
@@ -1146,7 +1145,8 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
// that was computed in the front-end for all subsequent
// reflection queries, etc.
//
- m_specializedProgram = getUnspecializedProgram();
+ m_specializedGlobalComponentType = getUnspecializedGlobalComponentType();
+ m_specializedGlobalAndEntryPointsComponentType = getUnspecializedGlobalAndEntryPointsComponentType();
return SLANG_OK;
}
@@ -1156,7 +1156,11 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
//
if (passThrough == PassThroughMode::None)
{
- m_specializedProgram = createSpecializedProgram(this);
+ m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this);
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ m_specializedGlobalAndEntryPointsComponentType = createSpecializedGlobalAndEntryPointsComponentType(this);
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -1166,7 +1170,7 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
//
for (auto targetReq : getLinkage()->targets)
{
- auto targetProgram = m_specializedProgram->getTargetProgram(targetReq);
+ auto targetProgram = m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq);
targetProgram->getOrCreateLayout(getSink());
}
if (getSink()->GetErrorCount() != 0)
@@ -1178,20 +1182,27 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
// to make sure that the logic in `generateOutput`
// sees something worth processing.
//
- auto specializedProgram = new Program(getLinkage());
- m_specializedProgram = specializedProgram;
+ List<RefPtr<ComponentType>> dummyEntryPoints;
for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs())
{
- RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough(
+ RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough(
+ getLinkage(),
entryPointReq->getName(),
entryPointReq->getProfile());
- specializedProgram->addEntryPoint(entryPoint, getSink());
+ dummyEntryPoints.add(dummyEntryPoint);
}
+
+ RefPtr<ComponentType> composedProgram = CompositeComponentType::create(
+ getLinkage(),
+ dummyEntryPoints);
+
+ m_specializedGlobalComponentType = getUnspecializedGlobalComponentType();
+ m_specializedGlobalAndEntryPointsComponentType = composedProgram;
}
// Generate output code, in whatever format was requested
- getBackEndReq()->setProgram(getSpecializedProgram());
+ getBackEndReq()->setProgram(getSpecializedGlobalAndEntryPointsComponentType());
generateOutput(this);
if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
@@ -1324,7 +1335,7 @@ int EndToEndCompileRequest::addEntryPoint(
EntryPointInfo entryPointInfo;
for (auto typeName : genericTypeNames)
- entryPointInfo.genericArgStrings.add(typeName);
+ entryPointInfo.specializationArgStrings.add(typeName);
Index result = entryPoints.getCount();
entryPoints.add(_Move(entryPointInfo));
@@ -1617,8 +1628,10 @@ void FilePathDependencyList::addDependency(Module* module)
//
Module::Module(Linkage* linkage)
- : m_linkage(linkage)
-{}
+ : ComponentType(linkage)
+{
+ addModuleDependency(this);
+}
ISlangUnknown* Module::getInterface(const Guid& guid)
{
@@ -1638,35 +1651,40 @@ void Module::addFilePathDependency(String const& path)
m_filePathDependencyList.addDependency(path);
}
-// Program
+void Module::setModuleDecl(ModuleDecl* moduleDecl)
+{
+ m_moduleDecl = moduleDecl;
+}
+
+// ComponentType
-static const Guid IID_IProgram = SLANG_UUID_IProgram;
+static const Guid IID_IComponentType = SLANG_UUID_IComponentType;
-Program::Program(Linkage* linkage)
+ComponentType::ComponentType(Linkage* linkage)
: m_linkage(linkage)
{}
-ISlangUnknown* Program::getInterface(Guid const& guid)
+ISlangUnknown* ComponentType::getInterface(Guid const& guid)
{
if(guid == IID_ISlangUnknown
- || guid == IID_IProgram)
+ || guid == IID_IComponentType)
{
- return static_cast<slang::IProgram*>(this);
+ return static_cast<slang::IComponentType*>(this);
}
return nullptr;
}
-SLANG_NO_THROW slang::ISession* SLANG_MCALL Program::getSession()
+SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession()
{
return m_linkage;
}
-SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout(
+SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL ComponentType::getLayout(
Int targetIndex,
slang::IBlob** outDiagnostics)
{
- auto linkage = getLinkageImpl();
+ auto linkage = getLinkage();
if(targetIndex < 0 || targetIndex >= linkage->targets.getCount())
return nullptr;
auto target = linkage->targets[targetIndex];
@@ -1678,13 +1696,13 @@ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout(
return asExternal(programLayout);
}
-SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode(
+SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode(
SlangInt entryPointIndex,
Int targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
- auto linkage = getLinkageImpl();
+ auto linkage = getLinkage();
if(targetIndex < 0 || targetIndex >= linkage->targets.getCount())
return SLANG_E_INVALID_ARG;
auto target = linkage->targets[targetIndex];
@@ -1702,57 +1720,535 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode(
return SLANG_OK;
}
+RefPtr<ComponentType> ComponentType::specialize(
+ SpecializationArg const* inSpecializationArgs,
+ SlangInt specializationArgCount,
+ DiagnosticSink* sink)
+{
+ List<SpecializationArg> specializationArgs;
+ specializationArgs.addRange(
+ inSpecializationArgs,
+ specializationArgCount);
+
+ // We next need to validate that the specialization arguments
+ // make sense, and also expand them to include any derived data
+ // (e.g., interface conformance witnesses) that doesn't get
+ // passed explicitly through the API interface.
+ //
+ RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs(
+ specializationArgs.getBuffer(),
+ specializationArgCount,
+ sink);
+
+ return new SpecializedComponentType(
+ this,
+ specializationInfo,
+ specializationArgs,
+ sink);
+}
-void Program::addReferencedModule(Module* module)
+SLANG_NO_THROW slang::IComponentType* SLANG_MCALL ComponentType::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ ISlangBlob** outDiagnostics)
{
- m_moduleDependencyList.addDependency(module);
- m_filePathDependencyList.addDependency(module);
+ DiagnosticSink sink(getLinkage()->getSourceManager());
+
+ // First let's check if the number of arguments given matches
+ // the number of parameters that are present on this component type.
+ //
+ auto specializationParamCount = getSpecializationParamCount();
+ if( specializationArgCount != specializationParamCount )
+ {
+ // TODO: diagnose
+ sink.getBlobIfNeeded(outDiagnostics);
+ return nullptr;
+ }
+
+ List<SpecializationArg> expandedArgs;
+ for( Int aa = 0; aa < specializationArgCount; ++aa )
+ {
+ auto apiArg = specializationArgs[aa];
+
+ SpecializationArg expandedArg;
+ switch(apiArg.kind)
+ {
+ case slang::SpecializationArg::Kind::Type:
+ expandedArg.val = asInternal(apiArg.type);
+ break;
+
+ default:
+ sink.getBlobIfNeeded(outDiagnostics);
+ return nullptr;
+ }
+ expandedArgs.add(expandedArg);
+ }
+
+ auto specializedComponentType = specialize(
+ expandedArgs.getBuffer(),
+ expandedArgs.getCount(),
+ &sink);
+
+ sink.getBlobIfNeeded(outDiagnostics);
+
+ return specializedComponentType;
}
-void Program::addReferencedLeafModule(Module* module)
+ /// Visitor used by `ComponentType::enumerateModules`
+struct EnumerateModulesVisitor : ComponentTypeVisitor
{
- m_moduleDependencyList.addLeafDependency(module);
- m_filePathDependencyList.addDependency(module);
+ EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData)
+ : m_callback(callback)
+ , m_userData(userData)
+ {}
+
+ ComponentType::EnumerateModulesCallback m_callback;
+ void* m_userData;
+
+ void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {}
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
+ {
+ m_callback(module, m_userData);
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ visitChildren(specialized);
+ }
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(legacy, specializationInfo);
+ }
+};
+
+
+void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData)
+{
+ EnumerateModulesVisitor visitor(callback, userData);
+ acceptVisitor(&visitor, nullptr);
+}
+
+ /// Visitor used by `ComponentType::enumerateIRModules`
+struct EnumerateIRModulesVisitor : ComponentTypeVisitor
+{
+ EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData)
+ : m_callback(callback)
+ , m_userData(userData)
+ {}
+
+ ComponentType::EnumerateIRModulesCallback m_callback;
+ void* m_userData;
+
+ void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {}
+
+ void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE
+ {
+ m_callback(module->getIRModule(), m_userData);
+ }
+
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(composite, specializationInfo);
+ }
+
+ void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE
+ {
+ visitChildren(specialized);
+
+ m_callback(specialized->getIRModule(), m_userData);
+ }
+
+ void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE
+ {
+ visitChildren(legacy, specializationInfo);
+ }
+};
+
+void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData)
+{
+ EnumerateIRModulesVisitor visitor(callback, userData);
+ acceptVisitor(&visitor, nullptr);
+}
+
+//
+// CompositeComponentType
+//
+
+RefPtr<ComponentType> CompositeComponentType::create(
+ Linkage* linkage,
+ List<RefPtr<ComponentType>> const& childComponents)
+{
+ // TODO: We should ideally be caching the results of
+ // composition on the `linkage`, so that if we get
+ // asked for the same composite again later we re-use
+ // it rather than re-create it.
+ //
+ // Similarly, we might want to do some amount of
+ // work to "canonicalize" the input for composition.
+ // E.g., if the user does:
+ //
+ // X = compose(A,B);
+ // Y = compose(C,D);
+ // Z = compose(X,Y);
+ //
+ // W = compose(A, B, C, D);
+ //
+ // Then there is no observable difference between
+ // Z and W, so we might prefer to have them be identical.
+
+ // If there is only a single child, then we should
+ // just return that child rather than create a dummy composite.
+ //
+ if( childComponents.getCount() == 1 )
+ {
+ return childComponents[0];
+ }
+
+ return new CompositeComponentType(linkage, childComponents);
+}
+
+
+CompositeComponentType::CompositeComponentType(
+ Linkage* linkage,
+ List<RefPtr<ComponentType>> const& childComponents)
+ : ComponentType(linkage)
+ , m_childComponents(childComponents)
+{
+ HashSet<ComponentType*> requirementsSet;
+ for(auto child : childComponents )
+ {
+ child->enumerateModules([&](Module* module)
+ {
+ requirementsSet.Add(module);
+ });
+ }
+
+ for(auto child : childComponents )
+ {
+ auto childEntryPointCount = child->getEntryPointCount();
+ for(Index cc = 0; cc < childEntryPointCount; ++cc)
+ {
+ m_entryPoints.add(child->getEntryPoint(cc));
+ }
+
+ auto childShaderParamCount = child->getShaderParamCount();
+ for(Index pp = 0; pp < childShaderParamCount; ++pp)
+ {
+ m_shaderParams.add(child->getShaderParam(pp));
+ }
+
+ auto childSpecializationParamCount = child->getSpecializationParamCount();
+ for(Index pp = 0; pp < childSpecializationParamCount; ++pp)
+ {
+ m_specializationParams.add(child->getSpecializationParam(pp));
+ }
+
+ for(auto module : child->getModuleDependencies())
+ {
+ m_moduleDependencyList.addDependency(module);
+ }
+ for(auto filePath : child->getFilePathDependencies())
+ {
+ m_filePathDependencyList.addDependency(filePath);
+ }
+
+ auto childRequirementCount = child->getRequirementCount();
+ for(Index rr = 0; rr < childRequirementCount; ++rr)
+ {
+ auto childRequirement = child->getRequirement(rr);
+ if(!requirementsSet.Contains(childRequirement))
+ {
+ requirementsSet.Add(childRequirement);
+ m_requirements.add(childRequirement);
+ }
+ }
+ }
+}
+
+Index CompositeComponentType::getEntryPointCount()
+{
+ return m_entryPoints.getCount();
+}
+
+RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index)
+{
+ return m_entryPoints[index];
+}
+
+Index CompositeComponentType::getShaderParamCount()
+{
+ return m_shaderParams.getCount();
}
-void Program::addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink)
+GlobalShaderParamInfo CompositeComponentType::getShaderParam(Index index)
{
- List<RefPtr<EntryPoint>> entryPoints;
- entryPoints.add(entryPoint);
+ return m_shaderParams[index];
+}
+
+Index CompositeComponentType::getSpecializationParamCount()
+{
+ return m_specializationParams.getCount();
+}
+
+SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index)
+{
+ return m_specializationParams[index];
+}
+
+Index CompositeComponentType::getRequirementCount()
+{
+ return m_requirements.getCount();
+}
+
+RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index)
+{
+ return m_requirements[index];
+}
+
+List<Module*> const& CompositeComponentType::getModuleDependencies()
+{
+ return m_moduleDependencyList.getModuleList();
+}
- RefPtr<EntryPointGroup> entryPointGroup = EntryPointGroup::create(getLinkageImpl(), entryPoints, sink);
+List<String> const& CompositeComponentType::getFilePathDependencies()
+{
+ return m_filePathDependencyList.getFilePathList();
+}
- addEntryPointGroup(entryPointGroup);
+void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+{
+ visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo));
}
-void Program::addEntryPointGroup(EntryPointGroup* entryPointGroup)
+
+RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
{
- m_entryPointGroups.add(entryPointGroup);
+ SLANG_UNUSED(argCount);
+
+ RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo();
- for(auto entryPoint : entryPointGroup->getEntryPoints())
+ Index offset = 0;
+ for(auto child : m_childComponents)
{
- m_entryPoints.add(entryPoint);
- for(auto module : entryPoint->getModuleDependencies())
+ auto childParamCount = child->getSpecializationParamCount();
+ SLANG_ASSERT(offset + childParamCount <= argCount);
+
+ auto childInfo = child->_validateSpecializationArgs(
+ args + offset,
+ childParamCount,
+ sink);
+
+ specializationInfo->childInfos.add(childInfo);
+
+ offset += childParamCount;
+ }
+ return specializationInfo;
+}
+
+//
+// SpecializedComponentType
+//
+
+SpecializedComponentType::SpecializedComponentType(
+ ComponentType* base,
+ ComponentType::SpecializationInfo* specializationInfo,
+ List<SpecializationArg> const& specializationArgs,
+ DiagnosticSink* sink)
+ : ComponentType(base->getLinkage())
+ , m_base(base)
+ , m_specializationInfo(specializationInfo)
+ , m_specializationArgs(specializationArgs)
+{
+ m_irModule = generateIRForSpecializedComponentType(this, sink);
+
+ // The following is a bit of a hack.
+ //
+ // Back-end code generation relies on us having computed layouts for all tagged
+ // unions that end up being used in the code, which means we need a way to find
+ // all such types that get used in a program (and the stuff it imports).
+ //
+ // For now we are assuming a tagged union type only comes into existence
+ // as a (top-level) argument for a generic type parameter, so that we
+ // can check for them here and cache them on the entry point.
+ //
+ // A longer-term strategy might need to consider any (tagged or untagged)
+ // union types that get used inside of a module, and also take
+ // those lists into account.
+ //
+ // An even longer-term strategy would be to allow type layout to
+ // be performed on IR types, so taht we don't need to have front-end
+ // code worrying about this stuff.
+ //
+ for(auto arg : specializationArgs)
+ {
+ auto argType = as<Type>(arg.val);
+ if(!argType)
+ continue;
+
+ auto taggedUnionType = as<TaggedUnionType>(argType);
+ if(!taggedUnionType)
+ continue;
+
+ m_taggedUnionTypes.add(taggedUnionType);
+ }
+}
+
+void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
+{
+ SLANG_ASSERT(specializationInfo == nullptr);
+ SLANG_UNUSED(specializationInfo);
+ visitor->visitSpecialized(this);
+}
+
+Index SpecializedComponentType::getRequirementCount()
+{
+ // TODO: A specialized component type may have *more* requirements
+ // than the original, because it also needs to include the module(s)
+ // that define the types used for specialization arguments.
+
+ return m_base->getRequirementCount();
+}
+
+RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index)
+{
+ return m_base->getRequirement(index);
+}
+
+//
+// LegacyProgram
+//
+
+LegacyProgram::LegacyProgram(
+ Linkage* linkage,
+ List<RefPtr<TranslationUnitRequest>> const& translationUnits,
+ DiagnosticSink* sink)
+ : ComponentType(linkage)
+ , m_translationUnits(translationUnits)
+{
+ HashSet<ComponentType*> requirementsSet;
+
+ for(auto translationUnit : translationUnits )
+ {
+ ComponentType* child = translationUnit->getModule();
+
+ auto childEntryPointCount = child->getEntryPointCount();
+ for(Index cc = 0; cc < childEntryPointCount; ++cc)
+ {
+ m_entryPoints.add(child->getEntryPoint(cc));
+ }
+
+ for(auto module : child->getModuleDependencies())
+ {
+ m_moduleDependencies.addDependency(module);
+ }
+ for(auto filePath : child->getFilePathDependencies())
{
- addReferencedModule(module);
+ m_fileDependencies.addDependency(filePath);
+ }
+
+ auto childRequirementCount = child->getRequirementCount();
+ for(Index rr = 0; rr < childRequirementCount; ++rr)
+ {
+ auto childRequirement = child->getRequirement(rr);
+ if(!requirementsSet.Contains(childRequirement))
+ {
+ requirementsSet.Add(childRequirement);
+ m_requirements.add(childRequirement);
+ }
}
}
+
+ _collectShaderParams(sink);
}
-RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink)
+void LegacyProgram::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
{
- if(!m_irModule)
+ visitor->visitLegacy(this, as<CompositeComponentType::CompositeSpecializationInfo>(specializationInfo));
+}
+
+RefPtr<ComponentType::SpecializationInfo> LegacyProgram::_validateSpecializationArgsImpl(
+ SpecializationArg const* args,
+ Index argCount,
+ DiagnosticSink* sink)
+{
+ SLANG_UNUSED(argCount);
+
+ RefPtr<CompositeComponentType::CompositeSpecializationInfo> info = new CompositeComponentType::CompositeSpecializationInfo();
+
+ Index offset = 0;
+ for(auto translationUnit : m_translationUnits)
{
- m_irModule = generateIRForProgram(
- m_linkage->getSessionImpl(),
- this,
+ ComponentType* child = translationUnit->getModule();
+ auto childParamCount = child->getSpecializationParamCount();
+ SLANG_ASSERT(offset + childParamCount <= argCount);
+
+ auto childInfo = child->_validateSpecializationArgs(
+ args + offset,
+ childParamCount,
sink);
+
+ info->childInfos.add(childInfo);
+
+ offset += childParamCount;
}
- return m_irModule;
+ return info;
}
+Index LegacyProgram::getRequirementCount()
+{
+ return m_requirements.getCount();
+}
+
+RefPtr<ComponentType> LegacyProgram::getRequirement(Index index)
+{
+ return m_requirements[index];
+}
+
+void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo)
+{
+ auto childCount = composite->getChildComponentCount();
+ for(Index ii = 0; ii < childCount; ++ii)
+ {
+ auto child = composite->getChildComponent(ii);
+ auto childSpecializationInfo = specializationInfo
+ ? specializationInfo->childInfos[ii]
+ : nullptr;
+
+ child->acceptVisitor(this, childSpecializationInfo);
+ }
+}
+
+void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized)
+{
+ specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo());
+}
+
+void ComponentTypeVisitor::visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo)
+{
+ auto childCount = legacy->getTranslationUnitCount();
+ for(Index ii = 0; ii < childCount; ++ii)
+ {
+ auto translationUnit = legacy->getTranslationUnit(ii);
+ ComponentType* child = translationUnit->getModule();
+ auto childSpecializationInfo = specializationInfo
+ ? specializationInfo->childInfos[ii]
+ : nullptr;
+
+ child->acceptVisitor(this, childSpecializationInfo);
+ }
+}
-TargetProgram* Program::getTargetProgram(TargetRequest* target)
+TargetProgram* ComponentType::getTargetProgram(TargetRequest* target)
{
RefPtr<TargetProgram> targetProgram;
if(!m_targetPrograms.TryGetValue(target, targetProgram))
@@ -1768,12 +2264,12 @@ TargetProgram* Program::getTargetProgram(TargetRequest* target)
//
TargetProgram::TargetProgram(
- Program* program,
+ ComponentType* componentType,
TargetRequest* targetReq)
- : m_program(program)
+ : m_program(componentType)
, m_targetReq(targetReq)
{
- m_entryPointResults.setCount(program->getEntryPoints().getCount());
+ m_entryPointResults.setCount(componentType->getEntryPointCount());
}
//
@@ -2458,10 +2954,10 @@ SLANG_API SlangResult spSetGlobalGenericArgs(
if (!request) return SLANG_FAIL;
auto req = convert(request);
- auto& genericArgStrings = req->globalGenericArgStrings;
- genericArgStrings.clear();
+ auto& argStrings = req->globalSpecializationArgStrings;
+ argStrings.clear();
for (int i = 0; i < genericArgCount; i++)
- genericArgStrings.add(genericArgs[i]);
+ argStrings.add(genericArgs[i]);
return SLANG_OK;
}
@@ -2477,7 +2973,7 @@ SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam(
if(!typeName) return SLANG_FAIL;
auto req = convert(request);
- auto& typeArgStrings = req->globalExistentialSlotArgStrings;
+ auto& typeArgStrings = req->globalSpecializationArgStrings;
if(Index(slotIndex) >= typeArgStrings.getCount())
typeArgStrings.setCount(slotIndex+1);
typeArgStrings[slotIndex] = String(typeName);
@@ -2501,7 +2997,7 @@ SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam(
return SLANG_FAIL;
auto& entryPointInfo = req->entryPoints[entryPointIndex];
- auto& typeArgStrings = entryPointInfo.existentialArgStrings;
+ auto& typeArgStrings = entryPointInfo.specializationArgStrings;
if(Index(slotIndex) >= typeArgStrings.getCount())
typeArgStrings.setCount(slotIndex+1);
typeArgStrings[slotIndex] = String(typeName);
@@ -2569,7 +3065,7 @@ spGetDependencyFileCount(
if(!request) return 0;
auto req = convert(request);
auto frontEndReq = req->getFrontEndReq();
- auto program = frontEndReq->getProgram();
+ auto program = frontEndReq->getGlobalAndEntryPointsComponentType();
return (int) program->getFilePathDependencies().getCount();
}
@@ -2583,7 +3079,7 @@ spGetDependencyFilePath(
if(!request) return 0;
auto req = convert(request);
auto frontEndReq = req->getFrontEndReq();
- auto program = frontEndReq->getProgram();
+ auto program = frontEndReq->getGlobalAndEntryPointsComponentType();
return program->getFilePathDependencies()[index].begin();
}
@@ -2613,7 +3109,7 @@ SLANG_API void const* spGetEntryPointCode(
using namespace Slang;
auto req = convert(request);
auto linkage = req->getLinkage();
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
// TODO: We should really accept a target index in this API
Index targetIndex = 0;
@@ -2668,7 +3164,7 @@ SLANG_API SlangResult spGetEntryPointCodeBlob(
auto req = convert(request);
auto linkage = req->getLinkage();
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
Index targetCount = linkage->targets.getCount();
if((targetIndex < 0) || (targetIndex >= targetCount))
@@ -2715,13 +3211,13 @@ SLANG_API void const* spGetCompileRequestCode(
SLANG_API SlangResult spCompileRequest_getProgram(
SlangCompileRequest* request,
- slang::IProgram** outProgram)
+ slang::IComponentType** outProgram)
{
if( !request ) return SLANG_ERROR_INVALID_PARAMETER;
auto req = convert(request);
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
- *outProgram = Slang::ComPtr<slang::IProgram>(program).detach();
+ *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach();
return SLANG_OK;
}
@@ -2731,7 +3227,7 @@ SLANG_API SlangReflection* spGetReflection(
if( !request ) return 0;
auto req = convert(request);
auto linkage = req->getLinkage();
- auto program = req->getSpecializedProgram();
+ auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
// Note(tfoley): The API signature doesn't let the client
// specify which target they want to access reflection