diff options
Diffstat (limited to 'source/slang')
23 files changed, 3406 insertions, 1852 deletions
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index edf41fff0..b7c1ccdc2 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -7107,8 +7107,7 @@ namespace Slang { if(context.mode != OverloadResolveContext::Mode::JustTrying) { - // TODO: diagnose a problem here - getSink()->diagnose(context.loc, Diagnostics::unimplemented, "generic constraint not satisfied"); + getSink()->diagnose(context.loc, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); } return false; } @@ -9663,15 +9662,15 @@ namespace Slang } } - /// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`. - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, + /// Recursively walk `paramDeclRef` and add any existential/interface specialization parameters to `ioSpecializationParams`. + static void _collectExistentialSpecializationParamsRec( + SpecializationParams& ioSpecializationParams, DeclRef<VarDeclBase> paramDeclRef); - /// Recursively walk `type` and discover any required existential type parameters. - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, - Type* type) + /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`. + static void _collectExistentialSpecializationParamsRec( + SpecializationParams& ioSpecializationParams, + Type* type) { // Whether or not something is an array does not affect // the number of existential slots it introduces. @@ -9683,7 +9682,9 @@ namespace Slang if( auto parameterGroupType = as<ParameterGroupType>(type) ) { - _collectExistentialTypeParamsRec(ioSlots, parameterGroupType->getElementType()); + _collectExistentialSpecializationParamsRec( + ioSpecializationParams, + parameterGroupType->getElementType()); return; } @@ -9692,9 +9693,14 @@ namespace Slang auto typeDeclRef = declRefType->declRef; if( auto interfaceDeclRef = typeDeclRef.as<InterfaceDecl>() ) { - // Each leaf parameter of interface type adds one slot. + // Each leaf parameter of interface type adds a specialization + // parameter, which determines the concrete type(s) that may + // be provided as arguments for that parameter. // - ioSlots.paramTypes.add(type); + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::ExistentialType; + specializationParam.object = type; + ioSpecializationParams.add(specializationParam); } else if( auto structDeclRef = typeDeclRef.as<StructDecl>() ) { @@ -9706,7 +9712,9 @@ namespace Slang if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) continue; - _collectExistentialTypeParamsRec(ioSlots, fieldDeclRef); + _collectExistentialSpecializationParamsRec( + ioSpecializationParams, + fieldDeclRef); } } } @@ -9716,26 +9724,58 @@ namespace Slang // element types. } - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, + static void _collectExistentialSpecializationParamsRec( + SpecializationParams& ioSpecializationParams, DeclRef<VarDeclBase> paramDeclRef) { - _collectExistentialTypeParamsRec(ioSlots, GetType(paramDeclRef)); + _collectExistentialSpecializationParamsRec( + ioSpecializationParams, + GetType(paramDeclRef)); } - /// Add information about a shader parameter to `ioParams` and `ioSlots` - static void _collectExistentialSlotsForShaderParam( + /// Collect any interface/existential specialization parameters for `paramDeclRef` into `ioParamInfo` and `ioSpecializationParams` + static void _collectExistentialSpecializationParamsForShaderParam( ShaderParamInfo& ioParamInfo, - ExistentialTypeSlots& ioSlots, + SpecializationParams& ioSpecializationParams, DeclRef<VarDeclBase> paramDeclRef) { - Index startSlot = ioSlots.paramTypes.getCount(); - _collectExistentialTypeParamsRec(ioSlots, paramDeclRef); - Index endSlot = ioSlots.paramTypes.getCount(); + Index beginParamIndex = ioSpecializationParams.getCount(); + _collectExistentialSpecializationParamsRec(ioSpecializationParams, paramDeclRef); + Index endParamIndex = ioSpecializationParams.getCount(); - ioParamInfo.firstExistentialTypeSlot = UInt(startSlot); - ioParamInfo.existentialTypeSlotCount = UInt(endSlot - startSlot);; + ioParamInfo.firstSpecializationParamIndex = beginParamIndex; + ioParamInfo.specializationParamCount = endParamIndex - beginParamIndex; + } + + void EntryPoint::_collectGenericSpecializationParamsRec(Decl* decl) + { + if(!decl) + return; + + _collectGenericSpecializationParamsRec(decl->ParentDecl); + + auto genericDecl = as<GenericDecl>(decl); + if(!genericDecl) + return; + + for(auto m : genericDecl->Members) + { + if(auto genericTypeParam = as<GenericTypeParamDecl>(m)) + { + SpecializationParam param; + param.flavor = SpecializationParam::Flavor::GenericType; + param.object = genericTypeParam; + m_genericSpecializationParams.add(param); + } + else if(auto genericValParam = as<GenericValueParamDecl>(m)) + { + SpecializationParam param; + param.flavor = SpecializationParam::Flavor::GenericValue; + param.object = genericValParam; + m_genericSpecializationParams.add(param); + } + } } /// Enumerate the existential-type parameters of an `EntryPoint`. @@ -9744,6 +9784,24 @@ namespace Slang /// void EntryPoint::_collectShaderParams() { + // We don't currently treat an entry point as having any + // *global* shader parameters. + // + // TODO: We could probably clean up the code a bit by treating + // an entry point as introducing a global shader parameter + // that is based on the implicit "parameters struct" type + // of the entry point itself. + + // We collect the generic parameters of the entry point, + // along with those of any outer generics first. + // + _collectGenericSpecializationParamsRec(getFuncDecl()); + + // After geneic specialization parameters have been collected, + // we look through the value parameters of the entry point + // function and see if any of them introduce existential/interface + // specialization parameters. + // // Note: we defensively test whether there is a function decl-ref // because this routine gets called from the constructor, and // a "dummy" entry point will have a null pointer for the function. @@ -9755,9 +9813,9 @@ namespace Slang ShaderParamInfo shaderParamInfo; shaderParamInfo.paramDeclRef = paramDeclRef; - _collectExistentialSlotsForShaderParam( + _collectExistentialSpecializationParamsForShaderParam( shaderParamInfo, - m_existentialSlots, + m_existentialSpecializationParams, paramDeclRef); m_shaderParams.add(shaderParamInfo); @@ -9765,111 +9823,6 @@ namespace Slang } } -static bool shouldUseFalcorCustomSharedKeywordSemantics( - EntryPointGroup* entryPointGroup) -{ - if( !entryPointGroup->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics ) - return false; - - // As a sanity check, if we are being asked to lay out an - // empty entry-point group, then don't apply the convention. - // - if(entryPointGroup->getEntryPointCount() == 0) - return false; - - // Otherwise we will look at the first entry point in the group, - // and use that to determine whether it looks like we are compiling - // ray-tracing shaders or not. - // - switch( entryPointGroup->getEntryPoint(0)->getStage() ) - { - case Stage::AnyHit: - case Stage::Callable: - case Stage::ClosestHit: - case Stage::Intersection: - case Stage::Miss: - case Stage::RayGeneration: - return true; - - default: - return false; - } -} - - -static bool shouldUseFalcorCustomSharedKeywordSemantics( - Program* program) -{ - if( !program->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics ) - return false; - - // As a sanity check, if we are being asked to lay out a program - // with *no* entry points, then we don't apply the convention. - // - if(program->getEntryPointGroupCount() == 0) - return false; - - // Otherwise we let the first entry-point group determine if we should - // apply the policy for the entire program. - // - // Note: this could lead to confusing results if a `Program` mixes - // entry point groups for RT and non-RT pipelines, but that isn't - // a scenario we expect to come up and this whole routine is handling - // "do what I mean" semantics for legacy behavior. - // - return shouldUseFalcorCustomSharedKeywordSemantics(program->getEntryPointGroup(0)); -} - - -void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) -{ - // If and only if we are in the special mode for Falcor support - // where non-`shared` global shader parameters are actually - // supposed to go into the "local root signature," then we - // will consider such parameters as if they were entry-point - // parameters, attached to the group. - // - if( shouldUseFalcorCustomSharedKeywordSemantics(this) ) - { - for( auto module : getModuleDependencies() ) - { - auto moduleDecl = module->getModuleDecl(); - for( auto globalVar : moduleDecl->getMembersOfType<VarDecl>() ) - { - // Don't consider globals that aren't shader parameters. - // - if(!isGlobalShaderParameter(globalVar)) - continue; - - // Don't consider global shader paramters that were marked - // `shared`, since that is how global-root-signature parameters - // are being specified. - // - if( globalVar->HasModifier<HLSLEffectSharedModifier>() ) - continue; - - auto paramDeclRef = makeDeclRef(globalVar.Ptr()); - - ShaderParamInfo shaderParamInfo; - shaderParamInfo.paramDeclRef = paramDeclRef; - - ExistentialTypeSlots slots; - _collectExistentialSlotsForShaderParam( - shaderParamInfo, - slots, - paramDeclRef); - - if( slots.paramTypes.getCount() != 0 ) - { - sink->diagnose(globalVar, Diagnostics::typeParametersNotAllowedOnEntryPointGlobal, globalVar); - } - - m_shaderParams.add(shaderParamInfo); - } - } - } -} - // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( @@ -10015,6 +9968,7 @@ void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) // auto compileRequest = entryPointReq->getCompileRequest(); auto translationUnit = entryPointReq->getTranslationUnit(); + auto linkage = compileRequest->getLinkage(); auto sink = compileRequest->getSink(); auto translationUnitSyntax = translationUnit->getModuleDecl(); @@ -10133,8 +10087,8 @@ void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) // a more uniform representation in the AST? } - RefPtr<EntryPoint> entryPoint = EntryPoint::create( + linkage, makeDeclRef(entryPointFuncDecl), entryPointProfile); @@ -10548,11 +10502,97 @@ static bool doesParameterMatch( return true; } - /// Enumerate the existential-type parameters of a `Program`. - /// - /// Any parameters found will be added to the list of existential slots on `this`. - /// - void Program::_collectShaderParams(DiagnosticSink* sink) + void Module::_collectShaderParams() + { + auto moduleDecl = m_moduleDecl; + + // We are going to walk the global declarations in the body of the + // module, and use those to build up our lists of: + // + // * Global shader parameters + // * Specialization parameters (both generic and interface/existential) + // * Requirements (`import`ed modules) + // + // For requirements, we want to be careful to only + // add each required module once (in case the same + // module got `import`ed multiple times), so we + // will keep a set of the modules we've already + // seen and processed. + // + HashSet<Module*> requiredModuleSet; + + for( auto globalDecl : moduleDecl->Members ) + { + if(auto globalVar = globalDecl.as<VarDecl>()) + { + // We do not want to consider global variable declarations + // that don't represents shader parameters. This includes + // things like `static` globals and `groupshared` variables. + // + if(!isGlobalShaderParameter(globalVar)) + continue; + + // At this point we know we have a global shader parameter. + + GlobalShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); + + // We need to consider what specialization parameters + // are introduced by this shader parameter. This step + // fills in fields on `shaderParamInfo` so that we + // can assocaite specialization arguments supplied later + // with the correct parameter. + // + _collectExistentialSpecializationParamsForShaderParam( + shaderParamInfo, + m_specializationParams, + makeDeclRef(globalVar.Ptr())); + + m_shaderParams.add(shaderParamInfo); + } + else if( auto globalGenericParam = as<GlobalGenericParamDecl>(globalDecl) ) + { + // A global generic type parameter declaration introduces + // a suitable specialization parameter. + // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::GenericType; + specializationParam.object = globalGenericParam; + m_specializationParams.add(specializationParam); + } + else if( auto importDecl = as<ImportDecl>(globalDecl) ) + { + // An `import` declaration creates a requirement dependency + // from this module to another module. + // + auto importedModule = getModule(importDecl->importedModuleDecl); + if(!requiredModuleSet.Contains(importedModule)) + { + requiredModuleSet.Add(importedModule); + m_requirements.add(importedModule); + } + } + } + } + + Index Module::getRequirementCount() + { + return m_requirements.getCount(); + } + + RefPtr<ComponentType> Module::getRequirement(Index index) + { + return m_requirements[index]; + } + + void Module::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + { + visitor->visitModule(this, as<ModuleSpecializationInfo>(specializationInfo)); + } + + + /// Enumerate the parameters of a `LegacyProgram`. + void LegacyProgram::_collectShaderParams(DiagnosticSink* sink) { // We need to collect all of the global shader parameters // referenced by the compile request, and for each we @@ -10568,10 +10608,22 @@ static bool doesParameterMatch( // To deal with the first issue, we will maintain a map from a parameter // name to the index of an existing parameter with that name. // + // TODO: Eventually we should deprecate support for the + // deduplication feature of `LegaqcyProgram`, at which point + // this entire type and all its complications can be eliminated + // from the code (that includes a lot of support in the "parameter + // binding" step for shader parameters with multiple declarations). + // Until that point this type will have a fair amount of duplication + // with stuff in `Module` and `CompositeComponentType`. + + // We use a dictionary to keep track of any shader parameter + // we've alrady collected with a given name. + // Dictionary<Name*, Int> mapNameToParamIndex; - for( auto module : getModuleDependencies() ) + for( auto translationUnit : m_translationUnits ) { + auto module = translationUnit->getModule(); auto moduleDecl = module->getModuleDecl(); for( auto globalVar : moduleDecl->getMembersOfType<VarDecl>() ) { @@ -10582,27 +10634,6 @@ static bool doesParameterMatch( if(!isGlobalShaderParameter(globalVar)) continue; - // HACK: In order to support existing policy in the Falcor - // application, we support a custom mode where only - // global variables marked as `shared` should be considered - // as global shader parameters (that go in the "global - // root signature" for DXR). - // - // TODO: Eliminate this special case once all of the client - // application code has been ported to use more general-purpose - // Slang mechanisms (e.g., entry-point `uniform` parameters). - // - if( shouldUseFalcorCustomSharedKeywordSemantics(this) ) - { - if( !globalVar->HasModifier<HLSLEffectSharedModifier>() ) - { - // Skip a non-`shared` global for purposes of enumerating - // shader parameters. - // - continue; - } - } - // This declaration may represent the same logical parameter // as a declaration that came from a different translation unit. // If that is the case, we want to re-use the same `ShaderParamInfo` @@ -10646,9 +10677,9 @@ static bool doesParameterMatch( GlobalShaderParamInfo shaderParamInfo; shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); - _collectExistentialSlotsForShaderParam( + _collectExistentialSpecializationParamsForShaderParam( shaderParamInfo, - m_globalExistentialSlots, + m_specializationParams, makeDeclRef(globalVar.Ptr())); m_shaderParams.add(shaderParamInfo); @@ -10656,13 +10687,59 @@ static bool doesParameterMatch( } } - /// Create a `Program` to represent the compiled code. + /// Create a new component type based on `inComponentType`, but with all its requiremetns filled. + RefPtr<ComponentType> fillRequirements( + ComponentType* inComponentType) + { + auto linkage = inComponentType->getLinkage(); + + // We are going to simplify things by solving the problem iteratively. + // If the current `componentType` has requirements for `A`, `B`, ... etc. + // then we will create a composite of `componentType`, `A`, `B`, ... + // and then see if the resulting composite has any requirements. + // + // This avoids the problem of trying to compute teh transitive closure + // of the requirements relationship (while dealing with deduplication, + // etc.) + + RefPtr<ComponentType> componentType = inComponentType; + for(;;) + { + auto requirementCount = componentType->getRequirementCount(); + if(requirementCount == 0) + break; + + List<RefPtr<ComponentType>> allComponents; + allComponents.add(componentType); + + for(Index rr = 0; rr < requirementCount; ++rr) + { + auto requirement = componentType->getRequirement(rr); + allComponents.add(requirement); + } + + componentType = CompositeComponentType::create( + linkage, + allComponents); + } + return componentType; + } + + /// Create a component type to represent the "global scope" of a compile request. /// - /// The created program will comprise all of the translation - /// units that were compiled as part of the request, as - /// well as any entry points in those translation units. + /// This component type will include all the modules and their global + /// parameters from the compile request, but not anything specific + /// to any entry point functions. /// - RefPtr<Program> createUnspecializedProgram( + /// The layout for this component type will thus represent the things that + /// a user is likely to want to have stay the same across all compiled + /// entry points. + /// + /// The component type that this function creates is unspecialized, in + /// that it doesn't take into account any specialization arguments + /// that might have been supplied as part of the compile request. + /// + RefPtr<ComponentType> createUnspecializedGlobalComponentType( FrontEndCompileRequest* compileRequest) { // We want our resulting program to depend on @@ -10677,16 +10754,51 @@ static bool doesParameterMatch( // auto linkage = compileRequest->getLinkage(); auto sink = compileRequest->getSink(); - auto program = new Program(linkage); - for(auto translationUnit : compileRequest->translationUnits ) + + RefPtr<ComponentType> globalComponentType; + if(compileRequest->translationUnits.getCount() == 1) { - program->addReferencedLeafModule(translationUnit->getModule()); + // The common case is that a compilation only uses + // a single translation unit, and thus results in + // a single `Module`. We can then use that module + // as the component type that represents the global scope. + // + globalComponentType = compileRequest->translationUnits[0]->getModule(); } - for(auto translationUnit : compileRequest->translationUnits ) + else { - program->addReferencedModule(translationUnit->getModule()); + globalComponentType = new LegacyProgram( + linkage, + compileRequest->translationUnits, + sink); } + return fillRequirements(globalComponentType); + } + + /// Create a component type that represents the global scope for a compile request, + /// along with any entry point functions. + /// + /// The resulting component type will include the global-scope information + /// first, so its layout will be compatible with the result of + /// `createUnspecializedGlobalComponentType`. + /// + /// The new component type will also add on any entry-point functions + /// that were requested and will thus include space for their `uniform` parameters. + /// If multiple entry points were requested then they will be given non-overlapping + /// parameter bindings, consistent with them being used together in + /// a single pipeline state, hit group, etc. + /// + /// The result of this function is unspecialized and doesn't take into + /// account any specialization arguments the user might have supplied. + /// + RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( + FrontEndCompileRequest* compileRequest) + { + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + + auto globalComponentType = compileRequest->getGlobalComponentType(); // The validation of entry points here will be modal, and controlled // by whether the user specified any entry points directly via @@ -10699,6 +10811,9 @@ static bool doesParameterMatch( // bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; + List<RefPtr<ComponentType>> allComponentTypes; + allComponentTypes.add(globalComponentType); + if( anyExplicitEntryPoints ) { // If there were any explicit requests for entry points to be @@ -10716,8 +10831,9 @@ static bool doesParameterMatch( // but didn't specify any groups (since the current // compilation API doesn't allow for grouping). // - program->addEntryPoint(entryPoint, sink); entryPointReq->getTranslationUnit()->entryPoints.add(entryPoint); + + allComponentTypes.add(entryPoint); } } @@ -10777,6 +10893,7 @@ static bool doesParameterMatch( profile.setStage(entryPointAttr->stage); RefPtr<EntryPoint> entryPoint = EntryPoint::create( + linkage, makeDeclRef(funcDecl), profile); @@ -10792,179 +10909,340 @@ static bool doesParameterMatch( // group, so that its entry-point parameters lay out // independent of the others. // - program->addEntryPoint(entryPoint, sink); translationUnit->entryPoints.add(entryPoint); + + allComponentTypes.add(entryPoint); } } } - program->_collectShaderParams(sink); - - return program; + if(allComponentTypes.getCount() > 1) + { + auto composite = CompositeComponentType::create( + linkage, + allComponentTypes); + return composite; + } + else + { + return globalComponentType; + } } - static void _specializeExistentialTypeParams( - Linkage* linkage, - ExistentialTypeSlots& ioSlots, - List<RefPtr<Expr>> const& args, + RefPtr<ComponentType::SpecializationInfo> Module::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, DiagnosticSink* sink) { - Index slotCount = ioSlots.paramTypes.getCount(); - Index argCount = args.getCount(); + SLANG_ASSERT(argCount == getSpecializationParamCount()); - if( slotCount != argCount ) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchExistentialSlotArgCount, slotCount, argCount); - return; - } + SemanticsVisitor visitor(getLinkage(), sink); - SemanticsVisitor visitor(linkage, sink); + RefPtr<Module::ModuleSpecializationInfo> specializationInfo = new Module::ModuleSpecializationInfo(); - for( Index ii = 0; ii < slotCount; ++ii ) + for( Index ii = 0; ii < argCount; ++ii ) { - auto slotType = ioSlots.paramTypes[ii]; - auto argExpr = args[ii]; + auto& arg = args[ii]; + auto& param = m_specializationParams[ii]; - auto argType = checkProperType(linkage, TypeExp(argExpr), sink); - if(!argType) + auto argType = arg.val.as<Type>(); + SLANG_ASSERT(argType); + + switch( param.flavor ) { - // TODO: Each slot should track a source location and/or a `VarDeclBase` - // that names the parameter that the slot corresponds to. + case SpecializationParam::Flavor::GenericType: + { + auto genericTypeParamDecl = param.object.as<GlobalGenericParamDecl>(); + SLANG_ASSERT(genericTypeParamDecl); - sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgNotAType, ii); - return; + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick<A>; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor<B>`). + // + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + + // As a quick sanity check, see if the argument that is being supplied for a + // global generic type parameter is a reference to *another* global generic + // type parameter, since that should always be an error. + // + if( auto argDeclRefType = argType.as<DeclRefType>() ) + { + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) + { + if(argGenericParamDeclRef.getDecl() == genericTypeParamDecl) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(genericTypeParamDecl, + Diagnostics::cannotSpecializeGlobalGenericToItself, + genericTypeParamDecl->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose(genericTypeParamDecl, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + genericTypeParamDecl->getName(), + argGenericParamDeclRef.GetName()); + continue; + } + } + } + + ModuleSpecializationInfo::GenericArgInfo genericArgInfo; + genericArgInfo.paramDecl = genericTypeParamDecl; + genericArgInfo.argVal = argType; + specializationInfo->genericArgs.add(genericArgInfo); + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraintDecl : genericTypeParamDecl->getMembersOfType<GenericTypeConstraintDecl>()) + { + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraintDecl, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance + auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(genericTypeParamDecl, + Diagnostics::typeArgumentForGenericParameterDoesNotConformToInterface, + argType, + genericTypeParamDecl->nameAndLoc.name, + interfaceType); + } + + ModuleSpecializationInfo::GenericArgInfo constraintArgInfo; + constraintArgInfo.paramDecl = constraintDecl; + constraintArgInfo.argVal = witness; + specializationInfo->genericArgs.add(constraintArgInfo); + } + } + break; + + case SpecializationParam::Flavor::ExistentialType: + { + auto interfaceType = param.object.as<Type>(); + SLANG_ASSERT(interfaceType); + + auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(SourceLoc(), + Diagnostics::typeArgumentDoesNotConformToInterface, + argType, + interfaceType); + } + + ExpandedSpecializationArg expandedArg; + expandedArg.val = argType; + expandedArg.witness = witness; + + specializationInfo->existentialArgs.add(expandedArg); + } + break; + + default: + SLANG_UNEXPECTED("unhandled specialization parameter flavor"); } + } + return specializationInfo; + } - auto witness = visitor.tryGetSubtypeWitness(argType, slotType); - if (!witness) + + static void _extractSpecializationArgs( + ComponentType* componentType, + List<RefPtr<Expr>> const& argExprs, + List<SpecializationArg>& outArgs, + DiagnosticSink* sink) + { + auto linkage = componentType->getLinkage(); + + auto argCount = argExprs.getCount(); + for(Index ii = 0; ii < argCount; ++ii ) + { + auto argExpr = argExprs[ii]; + auto paramInfo = componentType->getSpecializationParam(ii); + + // TODO: We should support non-type arguments here + + auto argType = checkProperType(linkage, TypeExp(argExpr), sink); + if( !argType ) { // If no witness was found, then we will be unable to satisfy // the conformances required. - sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgDoesNotConform, ii, slotType); - return; + sink->diagnose(argExpr, + Diagnostics::expectedAType, + argExpr->type); + continue; } - ExistentialTypeSlots::Arg arg; - arg.type = argType; - arg.witness = witness; - ioSlots.args.add(arg); + SpecializationArg arg; + arg.val = argType; + outArgs.add(arg); } } - void EntryPoint::_specializeExistentialTypeParams( - List<RefPtr<Expr>> const& args, + RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArgsImpl( + SpecializationArg const* inArgs, + Index inArgCount, DiagnosticSink* sink) { - Slang::_specializeExistentialTypeParams(getLinkage(), m_existentialSlots, args, sink); - } + auto args = inArgs; + auto argCount = inArgCount; - /// Create a specialization an existing entry point based on generic arguments. - RefPtr<EntryPoint> createSpecializedEntryPoint( - EntryPoint* unspecializedEntryPoint, - List<RefPtr<Expr>> const& genericArgs, - List<RefPtr<Expr>> const& existentialArgs, - DiagnosticSink* sink) - { - auto linkage = unspecializedEntryPoint->getLinkage(); + SemanticsVisitor visitor(getLinkage(), sink); - // TODO: Need to be careful in case entry point already has a decl-ref, - // pertaining to outer specializations (e.g., when entry point was - // nested in a generic type. + // The first N arguments will be for the explicit generic parameters + // of the entry point (if it has any). // - auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + auto genericSpecializationParamCount = getGenericSpecializationParamCount(); + SLANG_ASSERT(argCount >= genericSpecializationParamCount); - SemanticsVisitor semantics( - linkage, - sink); + Result result = SLANG_OK; - DeclRef<FuncDecl> entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl.Ptr()); - if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) ) + RefPtr<EntryPointSpecializationInfo> info = new EntryPointSpecializationInfo(); + + DeclRef<FuncDecl> specializedFuncDeclRef = m_funcDeclRef; + if(genericSpecializationParamCount) { - // We will construct a suitable `GenericAppExpr` to represent - // the user-specified `genericDecl` being applied to the - // supplied `genericArgs`, and then use the existing - // semantic checking logic that would apply to an explicit - // generic application like `F<A,B,C>` if it were - // encountered in the source code. + // We need to construct a generic application and use + // the semantic checking machinery to expand out + // the rest of the arguments via inference... - auto session = linkage->getSessionImpl(); - auto genericDeclRef = makeDeclRef(genericDecl); + auto genericDeclRef = m_funcDeclRef.GetParent().as<GenericDecl>(); + SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - // The first pieces is a `VarExpr` that refers to `genericDecl`. - // - // TODO: This would not be needed if we instead parsed - // the supplied entry-point name into an expression - // earlier in this function. - // - RefPtr<VarExpr> genericExpr = new VarExpr(); - genericExpr->declRef = genericDeclRef; - genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef); - - // Next we construct the actual `GenericAppExpr` - // - RefPtr<GenericAppExpr> genericAppExpr = new GenericAppExpr(); - genericAppExpr->FunctionExpr = genericExpr; - genericAppExpr->Arguments = genericArgs; + RefPtr<GenericSubstitution> genericSubst = new GenericSubstitution(); + genericSubst->outer = genericDeclRef.substitutions.substitutions; + genericSubst->genericDecl = genericDeclRef.getDecl(); - // We use the semantics visitor to perform the - // actual checking logic (this might report - // errors) - // - auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr); - - // Now we need to extract an appropriate decl-ref for the entry - // point from the `checkedExpr`. - // - if( auto declRefExpr = checkedExpr.as<DeclRefExpr>() ) + for(Index ii = 0; ii < genericSpecializationParamCount; ++ii) { - // TODO: We should eventually check for the case - // where we have a `MemberExpr` or another case of - // `DeclRefExpr` that cannot be summarized as just - // its decl-ref. - // - // The basic `VarExpr` and `StaticMemberExpr` cases - // should be allow-able. - - entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); + auto specializationArg = args[ii]; + genericSubst->args.add(specializationArg.val); } - else if( semantics.IsErrorExpr(checkedExpr) ) + + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { - // Any semantic error that occured should have been - // reported already. - return nullptr; + auto constraintSubst = genericDeclRef.substitutions; + constraintSubst.substitutions = genericSubst; + + DeclRef<GenericTypeConstraintDecl> constraintDeclRef( + constraintDecl, constraintSubst); + + auto sub = GetSub(constraintDeclRef); + auto sup = GetSup(constraintDeclRef); + + auto subTypeWitness = visitor.tryGetSubtypeWitness(sub, sup); + if(subTypeWitness) + { + genericSubst->args.add(subTypeWitness); + } + else + { + // TODO: diagnose a problem here + sink->diagnose(constraintDecl, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); + result = SLANG_FAIL; + continue; + } } - else + + specializedFuncDeclRef.substitutions.substitutions = genericSubst; + } + + info->specializedFuncDeclRef = specializedFuncDeclRef; + + // Once the generic parameters (if any) have been dealt with, + // any remaining specialization arguments are for existential/interface + // specialization parameters, attached to the value parameters + // of the entry point. + // + args += genericSpecializationParamCount; + argCount -= genericSpecializationParamCount; + + auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); + SLANG_ASSERT(argCount == existentialSpecializationParamCount); + + for( Index ii = 0; ii < existentialSpecializationParamCount; ++ii ) + { + auto& param = m_existentialSpecializationParams[ii]; + auto& specializationArg = args[ii]; + + // TODO: We need to handle all the cases of "flavor" for the `param`s (not just types) + + auto paramType = param.object.as<Type>(); + auto argType = specializationArg.val.as<Type>(); + + auto witness = visitor.tryGetSubtypeWitness(argType, paramType); + if (!witness) { - // The result of specializing a reference to a generic - // function should always be a `DeclRefExpr` - // - SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); - UNREACHABLE_RETURN(nullptr); + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(SourceLoc(), Diagnostics::typeArgumentDoesNotConformToInterface, argType, paramType); + result = SLANG_FAIL; + continue; } + + ExpandedSpecializationArg expandedArg; + expandedArg.val = specializationArg.val; + expandedArg.witness = witness; + info->existentialSpecializationArgs.add(expandedArg); } - RefPtr<EntryPoint> specializedEntryPoint = EntryPoint::create( - entryPointFuncDeclRef, - unspecializedEntryPoint->getProfile()); + return info; + } - // Next we need to validate the existential arguments. - specializedEntryPoint->_specializeExistentialTypeParams(existentialArgs, sink); + /// Create a specialization an existing entry point based on specialization argument expressions. + RefPtr<ComponentType> createSpecializedEntryPoint( + EntryPoint* unspecializedEntryPoint, + List<RefPtr<Expr>> const& argExprs, + DiagnosticSink* sink) + { + // We need to convert all of the `Expr` arguments + // into `SpecializationArg`s, so that we can bottleneck + // through the shared logic. + // + List<SpecializationArg> args; + _extractSpecializationArgs(unspecializedEntryPoint, argExprs, args, sink); + if(sink->GetErrorCount()) + return nullptr; - return specializedEntryPoint; + return unspecializedEntryPoint->specialize( + args.getBuffer(), + args.getCount(), + sink); } - /// Parse an array of strings as generic arguments. + /// Parse an array of strings as specialization arguments. /// /// Names in the strings will be parsed in the context of /// the code loaded into the given compile request. /// - void parseGenericArgStrings( + void parseSpecializationArgStrings( EndToEndCompileRequest* endToEndReq, List<String> const& genericArgStrings, List<RefPtr<Expr>>& outGenericArgs) { - auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); + auto unspecialiedProgram = endToEndReq->getUnspecializedGlobalComponentType(); // TODO: Building a list of `scopesToTry` here shouldn't // be required, since the `Scope` type itself has the ability @@ -11004,351 +11282,109 @@ static bool doesParameterMatch( } } + if(!argExpr) + { + sink->diagnose(SourceLoc(), Diagnostics::internalCompilerError, "couldn't parse specialization argument"); + return; + } + outGenericArgs.add(argExpr); } } - void Program::_specializeExistentialTypeParams( - List<RefPtr<Expr>> const& args, - DiagnosticSink* sink) - { - Slang::_specializeExistentialTypeParams(getLinkageImpl(), m_globalExistentialSlots, args, sink); - } - Type* Linkage::specializeType( Type* unspecializedType, Int argCount, Type* const* args, DiagnosticSink* sink) { + SLANG_ASSERT(unspecializedType); + // TODO: We should cache and re-use specialized types // when the exact same arguments are provided again later. SemanticsVisitor visitor(this, sink); + SpecializationParams specializationParams; + _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType); - ExistentialTypeSlots slots; - _collectExistentialTypeParamsRec(slots, unspecializedType); - - assert(slots.paramTypes.getCount() == argCount); + assert(specializationParams.getCount() == argCount); + ExpandedSpecializationArgs specializationArgs; for( Int aa = 0; aa < argCount; ++aa ) { + auto paramType = specializationParams[aa].object.as<Type>(); auto argType = args[aa]; - ExistentialTypeSlots::Arg arg; - arg.type = argType; - arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]); - slots.args.add(arg); + ExpandedSpecializationArg arg; + arg.val = argType; + arg.witness = visitor.tryGetSubtypeWitness(argType, paramType); + specializationArgs.add(arg); } RefPtr<ExistentialSpecializedType> specializedType = new ExistentialSpecializedType(); specializedType->baseType = unspecializedType; - specializedType->slots = slots; + specializedType->args = specializationArgs; m_specializedTypes.add(specializedType); return specializedType; } - // Shared implementation logic for the `_createSpecializedProgram*` entry points. - static RefPtr<Program> _createSpecializedProgramImpl( + /// Shared implementation logic for the `_createSpecializedProgram*` entry points. + static RefPtr<ComponentType> _createSpecializedProgramImpl( Linkage* linkage, - Program* unspecializedProgram, - List<RefPtr<Expr>> const& globalGenericArgs, - List<RefPtr<Expr>> const& globalExistentialArgs, + ComponentType* unspecializedProgram, + List<RefPtr<Expr>> const& specializationArgExprs, DiagnosticSink* sink) { - // TODO: If there are no specialization arguments, + // If there are no specialization arguments, // then the the result of specialization should - // be the same as the input... *but* we are promising - // to return a program without any entry points, - // and that screws things up here. - // - // For now we just carefully avod this early-exit - // if we have any entry points. + // be the same as the input. // - // Eventually we should try to revise the model so - // that specialization of a program that includes - // entry points can make some kind of sense. - // - if( globalGenericArgs.getCount() == 0 - && globalExistentialArgs.getCount() == 0 - && unspecializedProgram->getEntryPointCount() == 0) + auto specializationArgCount = specializationArgExprs.getCount(); + if( specializationArgCount == 0 ) { return unspecializedProgram; } - // The given `unspecializedProgram` should be one that - // was checked through the front-end, so that now we - // only need to check if the given arguments can satisfy - // the requirements of the global generic parameters. - // - // The new program needs to start off with the same - // module dependency list as the original. - // - RefPtr<Program> specializedProgram = new Program(linkage); - for(auto module : unspecializedProgram->getModuleDependencies()) - { - specializedProgram->addReferencedLeafModule(module); - } - - - // We will collect all the global generic parameters - // defined in the modules being referenced, to find - // the global generic parameter signature of the - // program. - // - // TODO: Note that this doesn't handle the case where one - // or more of the type *arguments* that we are specifying - // ends up requiring additional modules to be referenced, - // which might in turn introduce new global generic parameters. - // - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - for(auto module : unspecializedProgram->getModuleDependencies()) + auto specializationParamCount = unspecializedProgram->getSpecializationParamCount(); + if(specializationArgCount != specializationParamCount ) { - for(auto param : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>()) - globalGenericParams.add(param); - } - - // Next, we will check whether the supplied arguments can - // satisfy those parameters. - // - // An easy early-out case will be if the number of - // arguments isn't correct. - // - if (globalGenericParams.getCount() != globalGenericArgs.getCount()) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, - globalGenericParams.getCount(), - globalGenericArgs.getCount()); + sink->diagnose(SourceLoc(), Diagnostics::mismatchSpecializationArguments, + specializationParamCount, + specializationArgCount); return nullptr; } - // We have an appropriate number of arguments for the global generic parameters, + // We have an appropriate number of arguments for the global specialization parameters, // and now we need to check that the arguments conform to the declared constraints. // SemanticsVisitor visitor(linkage, sink); - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. - // - RefPtr<Substitutions> globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick<A>; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor<B>`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - // - Index argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) - { - // Get the argument that matches this parameter. - Index argIndex = argCounter++; - SLANG_ASSERT(argIndex < globalGenericArgs.getCount()); - auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); - if (!globalGenericArg) - { - sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); - return nullptr; - } - - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) - { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) - { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - continue; - } - } - } - - // Create a substitution for this parameter/argument. - RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; - - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); - - // Use our semantic-checking logic to search for a witness to the required conformance - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); - } - - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.add(constraintArg); - } - - // Add the substitution for this parameter to the global substitution - // set that we are building. - - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; - } + List<SpecializationArg> specializationArgs; + _extractSpecializationArgs(unspecializedProgram, specializationArgExprs, specializationArgs, sink); if(sink->GetErrorCount()) return nullptr; - specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - - return specializedProgram; - } - - /// Create a specialized copy of `unspecializedProgram`. - /// - /// The specialized program will include the entry points - /// from the original program (whether thsoe entry points - /// are specialized or not). - /// - static RefPtr<Program> _createSpecializedProgram( - Linkage* linkage, - Program* unspecializedProgram, - List<RefPtr<Expr>> const& globalGenericArgs, - List<RefPtr<Expr>> const& globalExistentialArgs, - DiagnosticSink* sink) - { - auto specializedProgram = _createSpecializedProgramImpl( - linkage, - unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, + auto specializedProgram = unspecializedProgram->specialize( + specializationArgs.getBuffer(), + specializationArgs.getCount(), sink); - // We need to ensure that the specialized program has whatever - // entry points (and groups) the unspecialized one had. - // - for( auto entryPointGroup : unspecializedProgram->getEntryPointGroups() ) - { - specializedProgram->addEntryPointGroup(entryPointGroup); - } - - // Now deal with the shader parameters and existential arguments - // - // Note: We should in theory be able to just copy over the shader - // parameters and existential slot information from the unspecialized - // program. This could save some time, but it would also mean that - // the only way to create a specialized program is by creating an - // unspecialized on first, which is maybe not always desirable. - // - specializedProgram->_collectShaderParams(sink); - specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); - return specializedProgram; } - SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::specializeProgram( - slang::IProgram* inUnspecializedProgram, - SlangInt specializationArgCount, - slang::SpecializationArg const* specializationArgs, - slang::IProgram** outSpecializedProgram, - ISlangBlob** outDiagnostics) - { - auto unspecializedProgram = asInternal(inUnspecializedProgram); - - if( specializationArgCount == 0 ) - { - *outSpecializedProgram = ComPtr<slang::IProgram>(asExternal(unspecializedProgram)).detach(); - return SLANG_OK; - } - - List<RefPtr<Expr>> globalGenericArgs; - - List<RefPtr<Expr>> globalExistentialArgs; - for( Int ii = 0; ii < specializationArgCount; ++ii ) - { - auto& specializationArg = specializationArgs[ii]; - switch( specializationArg.kind ) - { - case slang::SpecializationArg::Kind::Type: - { - auto typeArg = asInternal(specializationArg.type); - RefPtr<SharedTypeExpr> argExpr = new SharedTypeExpr(); - argExpr->base = TypeExp(typeArg); - argExpr->type = QualType(getTypeType(typeArg)); - - globalExistentialArgs.add(argExpr); - } - break; - - default: - return SLANG_E_INVALID_ARG; - } - } - - DiagnosticSink sink(getSourceManager()); - auto specializedProgram = _createSpecializedProgram( - this, - unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, - &sink); - sink.getBlobIfNeeded(outDiagnostics); - - if(!specializedProgram) - return SLANG_FAIL; - - *outSpecializedProgram = ComPtr<slang::IProgram>(asExternal(specializedProgram)).detach(); - return SLANG_OK; - } - - /// Specialize an entry point that was checked by the front-end, based on generic arguments. + /// Specialize an entry point that was checked by the front-end, based on specialization arguments. /// - /// If the end-to-end compile request included generic argument strings + /// If the end-to-end compile request included specialization argument strings /// for this entry point, then they will be parsed, checked, and used /// as arguments to the generic entry point. /// /// Returns a specialized entry point if everything worked as expected. /// Returns null and diagnoses errors if anything goes wrong. /// - RefPtr<EntryPoint> createSpecializedEntryPoint( + RefPtr<ComponentType> createSpecializedEntryPoint( EndToEndCompileRequest* endToEndReq, EntryPoint* unspecializedEntryPoint, EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) @@ -11359,33 +11395,35 @@ static bool doesParameterMatch( // If the user specified generic arguments for the entry point, // then we will need to parse the arguments first. // - List<RefPtr<Expr>> genericArgs; - parseGenericArgStrings( - endToEndReq, - entryPointInfo.genericArgStrings, - genericArgs); - - List<RefPtr<Expr>> existentialArgs; - parseGenericArgStrings( + List<RefPtr<Expr>> specializationArgExprs; + parseSpecializationArgStrings( endToEndReq, - entryPointInfo.existentialArgStrings, - existentialArgs); + entryPointInfo.specializationArgStrings, + specializationArgExprs); // Next we specialize the entry point function given the parsed // generic argument expressions. // auto entryPoint = createSpecializedEntryPoint( unspecializedEntryPoint, - genericArgs, - existentialArgs, + specializationArgExprs, sink); return entryPoint; } - /// Create a specialized program based on the given compile request. + /// Create a specialized component type for the global scope of the given compile request. + /// + /// The specialized program will be consistent with that created by + /// `createUnspecializedGlobalComponentType`, and will simply fill in + /// its specialization parameters with the arguments (if any) supllied + /// as part fo the end-to-end compile request. + /// + /// The layout of the new component type will be consistent with that + /// of the original *if* there are no global generic type parameters + /// (only interface/existential parameters). /// - RefPtr<Program> createSpecializedProgram( + RefPtr<ComponentType> createSpecializedGlobalComponentType( EndToEndCompileRequest* endToEndReq) { // The compile request must have already completed front-end processing, @@ -11393,36 +11431,32 @@ static bool doesParameterMatch( // to parse and check any generic arguments that are being supplied for // global or entry-point generic parameters. // - auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + auto unspecializedProgram = endToEndReq->getUnspecializedGlobalComponentType(); auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); - // First, let's parse the generic argument strings that were - // provided via the API, so taht we can match them + // First, let's parse the specialization argument strings that were + // provided via the API, so that we can match them // against what was declared in the program. // - List<RefPtr<Expr>> globalGenericArgs; - parseGenericArgStrings( + List<RefPtr<Expr>> globalSpecializationArgs; + parseSpecializationArgStrings( endToEndReq, - endToEndReq->globalGenericArgStrings, - globalGenericArgs); + endToEndReq->globalSpecializationArgStrings, + globalSpecializationArgs); - // Also handle global existential type arguments. - List<RefPtr<Expr>> globalExistentialArgs; - parseGenericArgStrings( - endToEndReq, - endToEndReq->globalExistentialSlotArgStrings, - globalExistentialArgs); + // Don't proceed further if anything failed to parse. + if(sink->GetErrorCount()) + return nullptr; // Now we create the initial specialized program by // applying the global generic arguments (if any) to the // unspecialized program. // - auto sink = endToEndReq->getSink(); auto specializedProgram = _createSpecializedProgramImpl( - endToEndReq->getLinkage(), + linkage, unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, + globalSpecializationArgs, sink); // If anything went wrong with the global generic @@ -11450,33 +11484,69 @@ static bool doesParameterMatch( endToEndReq->entryPoints.setCount(entryPointCount); } - Index entryPointCounter = 0; + return specializedProgram; + } - for( auto unspecializedEntryPointGroup : unspecializedProgram->getEntryPointGroups() ) - { - List<RefPtr<EntryPoint>> specializedEntryPoints; - for( auto unspecializedEntryPoint : unspecializedEntryPointGroup->getEntryPoints() ) - { - Index entryPointIndex = entryPointCounter++; - auto& entryPointInfo = endToEndReq->entryPoints[entryPointIndex]; + /// Create a specialized program based on the given compile request. + /// + /// The specialized program created here includes both the global + /// scope for all the translation units involved and all the entry + /// points, and it also includes any specialization arguments + /// that were supplied. + /// + /// It is important to note that this function specializes + /// the global scope and the entry points in isolation and then + /// composes them, and that this can lead to different layout + /// from the result of `createUnspecializedGlobalAndEntryPointsComponentType`. + /// + /// If we have a module `M` with entry point `E`, and each has one + /// specialization parameter, then `createUnspecialized...` will yield: + /// + /// compose(M,E) + /// + /// That composed type will have two specialization parameters (the one + /// from `M` plus the one from `E`) and so we might specialize it to get: + /// + /// specialize(compose(M,E), X, Y) + /// + /// while if we use `createSpecialized...` we will get: + /// + /// compose(specialize(M,X), specialize(E,Y)) + /// + /// While these options are semantically equivalent, they would not lay + /// out the same way in memory. + /// + /// There are many reasons why an application might prefer one over the + /// other, and an application that cares should use the more explicit + /// APIs to construct what they want. The behavior of this function + /// is just to provide a reasonable default for use by end-to-end + /// compilation (e.g., from the command line). + /// + RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType( + EndToEndCompileRequest* endToEndReq) + { + auto specializedGlobalComponentType = endToEndReq->getSpecializedGlobalComponentType(); - auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); - specializedEntryPoints.add(specializedEntryPoint); - } + List<RefPtr<ComponentType>> allComponentTypes; + allComponentTypes.add(specializedGlobalComponentType); - RefPtr<EntryPointGroup> specializedEntryPointGroup = EntryPointGroup::create(linkage, specializedEntryPoints, endToEndReq->getSink()); - specializedProgram->addEntryPointGroup(specializedEntryPointGroup); - } + auto unspecializedGlobalAndEntryPointsComponentType = endToEndReq->getUnspecializedGlobalAndEntryPointsComponentType(); + auto entryPointCount = unspecializedGlobalAndEntryPointsComponentType->getEntryPointCount(); - // Finalize the information for the specialized program, - // now that we have computed its entry point list, etc. - // - specializedProgram->_collectShaderParams(sink); - specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); + for(Index ii = 0; ii < entryPointCount; ++ii) + { + auto& entryPointInfo = endToEndReq->entryPoints[ii]; + auto unspecializedEntryPoint = unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii); - return specializedProgram; + auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + allComponentTypes.add(specializedEntryPoint); + } + + RefPtr<ComponentType> composed = CompositeComponentType::create(endToEndReq->getLinkage(), allComponentTypes); + return composed; } + void checkTranslationUnit( TranslationUnitRequest* translationUnit) { @@ -11488,6 +11558,8 @@ static bool doesParameterMatch( // checking that is required on all declarations // in the translation unit. visitor.checkDecl(translationUnit->getModuleDecl()); + + translationUnit->getModule()->_collectShaderParams(); } diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 64afef136..1c0bed065 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -199,10 +199,12 @@ namespace Slang // RefPtr<EntryPoint> EntryPoint::create( + Linkage* linkage, DeclRef<FuncDecl> funcDeclRef, Profile profile) { RefPtr<EntryPoint> entryPoint = new EntryPoint( + linkage, funcDeclRef.GetName(), profile, funcDeclRef); @@ -210,10 +212,12 @@ namespace Slang } RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough( + Linkage* linkage, Name* name, Profile profile) { RefPtr<EntryPoint> entryPoint = new EntryPoint( + linkage, name, profile, DeclRef<FuncDecl>()); @@ -221,103 +225,93 @@ namespace Slang } EntryPoint::EntryPoint( + Linkage* linkage, Name* name, Profile profile, DeclRef<FuncDecl> funcDeclRef) - : m_name(name) + : ComponentType(linkage) + , m_name(name) , m_profile(profile) , m_funcDeclRef(funcDeclRef) { - // In order for later code generation to work, we need to track what - // modules each entry point depends on. We will build up the dependency - // list here when an `EntryPoint` gets created. - // - // We know an entry point depends on the module that declared the - // entry-point function itself. - // - // Note: we are carefully handling the case where `module` could - // be null, becase of "dummy" entry points created for pass-through - // compilation. + // Collect any specialization parameters used by the entry point // - if(auto module = getModule()) + _collectShaderParams(); + } + + Module* EntryPoint::getModule() + { + return Slang::getModule(getFuncDecl()); + } + + Index EntryPoint::getSpecializationParamCount() + { + return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount(); + } + + SpecializationParam const& EntryPoint::getSpecializationParam(Index index) + { + auto genericParamCount = m_genericSpecializationParams.getCount(); + if(index < genericParamCount) { - m_dependencyList.addDependency(module); + return m_genericSpecializationParams[index]; } - // - // TODO: We also need to include the modules needed by any generic - // arguments in the dependency list, since in the general case they - // might come from modules other than the one defining the entry point. + else + { + return m_existentialSpecializationParams[index - genericParamCount]; + } + } - // The following is a bit of a hack. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a program (and the stuff it imports). - // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point. + Index EntryPoint::getRequirementCount() + { + // The only requirement of an entry point is the module that contains it. // - // A longer-term strategy might need to consider any (tagged or untagged) - // union types that get used inside of a module, and also take - // those lists into account. + // TODO: We will eventually want to support the case of an entry + // point nested in a `struct` type, in which case there should be + // a single requirement representing that outer type (so that multiple + // entry points nested under the same type can share the storage + // for parameters at that scope). + + // Note: the defensive coding is here because the + // "dummy" entry points we create for pass-through + // compilation will not have an associated module. // - // An even longer-term strategy would be to allow type layout to - // be performed on IR types, so taht we don't need to have front-end - // code worrying about this stuff. - // - for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer ) + if( auto module = getModule() ) { - if( auto genericSubst = as<GenericSubstitution>(subst) ) - { - for( auto arg : genericSubst->args ) - { - if( auto taggedUnionType = as<TaggedUnionType>(arg) ) - { - m_taggedUnionTypes.add(taggedUnionType); - } - } - } + return 1; } - - // Collect any existential-type parameters used by the entry point - // - _collectShaderParams(); + return 0; } - Module* EntryPoint::getModule() + RefPtr<ComponentType> EntryPoint::getRequirement(Index index) { - return Slang::getModule(getFuncDecl()); + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + SLANG_ASSERT(getModule()); + return getModule(); } - Linkage* EntryPoint::getLinkage() + void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { - return getModule()->getLinkage(); + visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo)); } - // - // EntryPointGroup - // - - RefPtr<EntryPointGroup> EntryPointGroup::create( - Linkage* linkage, - List<RefPtr<EntryPoint>> const& entryPoints, - DiagnosticSink* sink) + List<Module*> const& EntryPoint::getModuleDependencies() { - RefPtr<EntryPointGroup> group = new EntryPointGroup(linkage); - - for( auto entryPoint : entryPoints ) - { - for( auto module : entryPoint->getModuleDependencies() ) - { - group->m_dependencyList.addDependency(module); - } - group->m_entryPoints.add(entryPoint); - } + if(auto module = getModule()) + return module->getModuleDependencies(); - group->_collectShaderParams(sink); + static List<Module*> empty; + return empty; + } - return group; + List<String> const& EntryPoint::getFilePathDependencies() + { + if(auto module = getModule()) + return getModule()->getFilePathDependencies(); + + static List<String> empty; + return empty; } // @@ -1949,7 +1943,7 @@ SlangResult dissassembleDXILUsingDXC( TargetRequest* targetReq, Int entryPointIndex) { - auto program = compileRequest->getSpecializedProgram(); + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); auto targetProgram = program->getTargetProgram(targetReq); auto backEndReq = compileRequest->getBackEndReq(); @@ -2020,7 +2014,7 @@ SlangResult dissassembleDXILUsingDXC( return result; RefPtr<BackEndCompileRequest> backEndRequest = new BackEndCompileRequest( - m_program->getLinkageImpl(), + m_program->getLinkage(), sink, m_program); @@ -2083,7 +2077,7 @@ SlangResult dissassembleDXILUsingDXC( if (compileRequest->isCommandLineCompile) { auto linkage = compileRequest->getLinkage(); - auto program = compileRequest->getSpecializedProgram(); + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); for (auto targetReq : linkage->targets) { Index entryPointCount = program->getEntryPointCount(); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b7f62ae5b..46b257fe3 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -116,7 +116,6 @@ namespace Slang class Linkage; class Module; - class Program; class FrontEndCompileRequest; class BackEndCompileRequest; class EndToEndCompileRequest; @@ -146,13 +145,19 @@ namespace Slang struct ShaderParamInfo { DeclRef<VarDeclBase> paramDeclRef; - UInt firstExistentialTypeSlot = 0; - UInt existentialTypeSlotCount = 0; + Int firstSpecializationParamIndex = 0; + Int specializationParamCount = 0; }; /// Extended information specific to global shader parameters struct GlobalShaderParamInfo : ShaderParamInfo { + // TODO: This type should be eliminated if/when we remove + // support for compilation with multiple translation units + // that all declare the "same" shader parameter (e.g., a + // `cbuffer`) and expect those duplicate declarations + // to get the same parameter binding/layout. + // Additional global-scope declarations that are conceptually // declaring the "same" parameter as the `paramDeclRef`. List<DeclRef<VarDeclBase>> additionalParamDeclRefs; @@ -203,7 +208,7 @@ namespace Slang { public: /// Get the list of modules that are depended on. - List<RefPtr<Module>> const& getModuleList() { return m_moduleList; } + List<Module*> const& getModuleList() { return m_moduleList; } /// Add a module and everything it depends on to the list. void addDependency(Module* module); @@ -214,8 +219,8 @@ namespace Slang private: void _addDependency(Module* module); - List<RefPtr<Module>> m_moduleList; - HashSet<Module*> m_moduleSet; + List<Module*> m_moduleList; + HashSet<Module*> m_moduleSet; }; /// Tracks an unordered list of filesystem paths that something depends on @@ -245,6 +250,375 @@ namespace Slang HashSet<String> m_filePathSet; }; + class EntryPoint; + + class ComponentType; + class ComponentTypeVisitor; + + /// Base class for "component types" that represent the pieces a final + /// shader program gets linked together from. + /// + class ComponentType : public RefObject, public slang::IComponentType + { + public: + // + // ISlangUnknown interface + // + + SLANG_REF_OBJECT_IUNKNOWN_ALL; + ISlangUnknown* getInterface(Guid const& guid); + + // + // slang::IComponentType interface + // + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE; + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW IComponentType* SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + + /// Get the linkage (aka "session" in the public API) for this component type. + Linkage* getLinkage() { return m_linkage; } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Get the number of entry points linked into this component type. + virtual Index getEntryPointCount() = 0; + + /// Get one of the entry points linked into this component type. + virtual RefPtr<EntryPoint> getEntryPoint(Index index) = 0; + + /// Get the number of global shader parameters linked into this component type. + virtual Index getShaderParamCount() = 0; + + /// Get one of the global shader parametesr linked into this component type. + virtual GlobalShaderParamInfo getShaderParam(Index index) = 0; + + /// Get the number of (unspecialized) specialization parameters for the component type. + virtual Index getSpecializationParamCount() = 0; + + /// Get the specialization parameter at `index`. + virtual SpecializationParam const& getSpecializationParam(Index index) = 0; + + /// Get the number of "requirements" that this component type has. + /// + /// A requirement represents another component type that this component + /// needs in order to function correctly. For example, the dependency + /// of one module on another module that it `import`s is represented + /// as a requirement, as is the dependency of an entry point on the + /// module that defines it. + /// + virtual Index getRequirementCount() = 0; + + /// Get the requirement at `index`. + virtual RefPtr<ComponentType> getRequirement(Index index) = 0; + + /// Parse a type from a string, in the context of this component type. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + /// TODO: This function shouldn't be on the base class, since + /// it only really makes sense on `Module` and (as a compatibility + /// feature) on `LegacyProgram`. + /// + Type* getTypeFromString( + String const& typeStr, + DiagnosticSink* sink); + + /// Get a list of modules that this component type depends on. + /// + virtual List<Module*> const& getModuleDependencies() = 0; + + /// Get the full list of filesystem paths this component type depends on. + /// + virtual List<String> const& getFilePathDependencies() = 0; + + /// Callback for use with `enumerateIRModules` + typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type. + void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type. + template<typename F> + void enumerateIRModules(F const& callback) + { + struct Helper + { + static void helper(IRModule* irModule, void* userData) + { + (*(F*)userData)(irModule); + } + }; + enumerateIRModules(&Helper::helper, (void*)&callback); + } + + /// Callback for use with `enumerateModules` + typedef void (*EnumerateModulesCallback)(Module* module, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component type. + void enumerateModules(EnumerateModulesCallback callback, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component type. + template<typename F> + void enumerateModules(F const& callback) + { + struct Helper + { + static void helper(Module* module, void* userData) + { + (*(F*)userData)(module); + } + }; + enumerateModules(&Helper::helper, (void*)&callback); + } + + /// Side-band information generated when specializing this component type. + /// + /// Difference subclasses of `ComponentType` are expected to create their + /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`. + /// Later, whenever we want to use a specialized component type we will + /// also have the `SpecializationInfo` available and will expect it to + /// have the correct (subclass-specific) type. + /// + class SpecializationInfo : public RefObject + { + }; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + virtual RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) = 0; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + RefPtr<SpecializationInfo> _validateSpecializationArgs( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) + { + if(argCount == 0) return nullptr; + return _validateSpecializationArgsImpl(args, argCount, sink); + } + + /// Specialize this component type given `specializationArgs` + /// + /// Any diagnostics will be reported to `sink`, which can be used + /// to determine if the operation was successful. It is allowed + /// for this operation to have a non-null return even when an + /// error is ecnountered. + /// + RefPtr<ComponentType> specialize( + SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink); + + /// Invoke `visitor` on this component type, using the appropriate dynamic type. + /// + /// This function implements the "visitor pattern" for `ComponentType`. + /// + /// If the `specializationInfo` argument is non-null, it must be specialization + /// information generated for this specific component type by `_validateSpecializationArgs`. + /// In that case, appropriately-typed specialization information will be passed + /// when invoking the `visitor`. + /// + virtual void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) = 0; + + protected: + ComponentType(Linkage* linkage); + + private: + RefPtr<Linkage> m_linkage; + + // Cache of target-specific programs for each target. + Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + // + // TODO: Remove this. Type lookup should only be supported on `Module`s. + // + Dictionary<String, RefPtr<Type>> m_types; + }; + + /// A component type built up from other component types. + class CompositeComponentType : public ComponentType + { + public: + static RefPtr<ComponentType> create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents); + + List<RefPtr<ComponentType>> const& getChildComponents() { return m_childComponents; }; + Index getChildComponentCount() { return m_childComponents.getCount(); } + RefPtr<ComponentType> getChildComponent(Index index) { return m_childComponents[index]; } + + Index getEntryPointCount() SLANG_OVERRIDE; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE; + + Index getShaderParamCount() SLANG_OVERRIDE; + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; + + Index getSpecializationParamCount() SLANG_OVERRIDE; + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; + List<String> const& getFilePathDependencies() SLANG_OVERRIDE; + + class CompositeSpecializationInfo : public SpecializationInfo + { + public: + List<RefPtr<SpecializationInfo>> childInfos; + }; + + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + + private: + CompositeComponentType( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents); + + List<RefPtr<ComponentType>> m_childComponents; + + // The following arrays hold the concatenated entry points, parameters, + // etc. from the child components. This approach allows for reasonably + // fast (constant time) access through operations like `getShaderParam`, + // but means that the memory usage of a composite is proportional to + // the sum of the memory usage of the children, rather than being fixed + // by the number of children (as it would be if we just stored + // `m_childComponents`). + // + // TODO: We could conceivably build some O(numChildren) arrays that + // support binary-search to provide logarithmic-time access to entry + // points, parameters, etc. while giving a better overall memory usage. + // + List<EntryPoint*> m_entryPoints; + List<GlobalShaderParamInfo> m_shaderParams; + List<SpecializationParam> m_specializationParams; + List<ComponentType*> m_requirements; + + ModuleDependencyList m_moduleDependencyList; + FilePathDependencyList m_filePathDependencyList; + }; + + /// A component type created by specializing another component type. + class SpecializedComponentType : public ComponentType + { + public: + SpecializedComponentType( + ComponentType* base, + SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink); + + /// Get the base (unspecialized) component type that is being specialized. + RefPtr<ComponentType> getBaseComponentType() { return m_base; } + + RefPtr<SpecializationInfo> getSpecializationInfo() { return m_specializationInfo; } + + /// Get the number of arguments supplied for existential type parameters. + /// + /// Note that the number of arguments may not match the number of parameters. + /// In particular, an unspecialized entry point may have many parameters, but zero arguments. + Index getSpecializationArgCount() { return m_specializationArgs.getCount(); } + + /// Get the existential type argument (type and witness table) at `index`. + SpecializationArg const& getSpecializationArg(Index index) { return m_specializationArgs[index]; } + + /// Get an array of all existential type arguments. + SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); } + + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); } + + Index getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); static SpecializationParam dummy; return dummy; } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + /// TODO: These should include requirements/dependencies for the types + /// referenced in the specialization arguments... + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_base->getModuleDependencies(); } + List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_base->getFilePathDependencies(); } + + /// Get a list of tagged-union types referenced by the specialization parameters. + List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + + RefPtr<IRModule> getIRModule() { return m_irModule; } + + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + protected: + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE + { + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; + } + + private: + RefPtr<ComponentType> m_base; + RefPtr<SpecializationInfo> m_specializationInfo; + SpecializationArgs m_specializationArgs; + RefPtr<IRModule> m_irModule; + + // Any tagged union types that were referenced by the specialization arguments. + List<RefPtr<TaggedUnionType>> m_taggedUnionTypes; + + }; + /// Describes an entry point for the purposes of layout and code generation. /// /// This class also tracks any generic arguments to the entry point, @@ -255,11 +629,12 @@ namespace Slang /// `getName()` and `getProfile()` methods should be expected to /// return useful data on pass-through entry points. /// - class EntryPoint : public RefObject + class EntryPoint : public ComponentType { public: /// Create an entry point that refers to the given function. static RefPtr<EntryPoint> create( + Linkage* linkage, DeclRef<FuncDecl> funcDeclRef, Profile profile); @@ -286,55 +661,66 @@ namespace Slang /// Get the module that contains the entry point. Module* getModule(); - /// Get the linkage that contains the module for this entry point. - Linkage* getLinkage(); - /// Get a list of modules that this entry point depends on. /// /// This will include the module that defines the entry point (see `getModule()`), /// but may also include modules that are required by its generic type arguments. /// - List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); } - - /// Get a list of tagged-union types referenced by the entry point's generic parameters. - List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } + List<String> const& getFilePathDependencies() SLANG_OVERRIDE; // { return getModule()->getFilePathDependencies(); } /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. static RefPtr<EntryPoint> createDummyForPassThrough( + Linkage* linkage, Name* name, Profile profile); /// Get the number of existential type parameters for the entry point. - Index getExistentialTypeParamCount() { return m_existentialSlots.paramTypes.getCount(); } + Index getSpecializationParamCount() SLANG_OVERRIDE; /// Get the existential type parameter at `index`. - Type* getExistentialTypeParam(Index index) { return m_existentialSlots.paramTypes[index]; } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized entry point may have many parameters, but zero arguments. - Index getExistentialTypeArgCount() { return m_existentialSlots.args.getCount(); } + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - /// Get the existential type argument (type and witness table) at `index`. - ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_existentialSlots.args[index]; } + SpecializationParams const& getExistentialSpecializationParams() { return m_existentialSpecializationParams; } - /// Get an array of all existential type arguments. - ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_existentialSlots.args.getBuffer(); } + Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); } + Index getExistentialSpecializationParamCount() { return m_existentialSpecializationParams.getCount(); } /// Get an array of all entry-point shader parameters. List<ShaderParamInfo> const& getShaderParams() { return m_shaderParams; } - void _specializeExistentialTypeParams( - List<RefPtr<Expr>> const& args, - DiagnosticSink* sink); + Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; } + + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return GlobalShaderParamInfo(); } + + class EntryPointSpecializationInfo : public SpecializationInfo + { + public: + DeclRef<FuncDecl> specializedFuncDeclRef; + List<ExpandedSpecializationArg> existentialSpecializationArgs; + }; + + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; private: EntryPoint( + Linkage* linkage, Name* name, Profile profile, DeclRef<FuncDecl> funcDeclRef); + void _collectGenericSpecializationParamsRec(Decl* decl); void _collectShaderParams(); // The name of the entry point function (e.g., `main`) @@ -345,8 +731,8 @@ namespace Slang // DeclRef<FuncDecl> m_funcDeclRef; - /// The existential/interface slots associated with the entry point parameter scope. - ExistentialTypeSlots m_existentialSlots; + SpecializationParams m_genericSpecializationParams; + SpecializationParams m_existentialSpecializationParams; /// Information about entry-point parameters List<ShaderParamInfo> m_shaderParams; @@ -362,61 +748,6 @@ namespace Slang // intrinsic to the entry point. // Profile m_profile; - - // Any tagged union types that were referenced by the generic arguments of the entry point. - List<RefPtr<TaggedUnionType>> m_taggedUnionTypes; - - // Modules the entry point depends on. - ModuleDependencyList m_dependencyList; - }; - - class EntryPointGroup : public RefObject - { - public: - static RefPtr<EntryPointGroup> create( - Linkage* linkage, - List<RefPtr<EntryPoint>> const& entryPoints, - DiagnosticSink* sink); - - Linkage* getLinkageImpl() { return m_linkage; } - - /// Get the number of entry points in the group - Index getEntryPointCount() { return m_entryPoints.getCount(); } - - /// Get the entry point at the given `index`. - RefPtr<EntryPoint> getEntryPoint(Index index) { return m_entryPoints[index]; } - - /// Get the full ist of entry points in the group. - List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; } - - /// Get a list of modules that this entry point group depends on. - /// - /// This will include the dependencies of all of the entry points in the group. - /// - List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); } - - /// Get an array of all entry-point-group shader parameters. - List<ShaderParamInfo> const& getShaderParams() { return m_shaderParams; } - - private: - EntryPointGroup(Linkage* linkage) - : m_linkage(linkage) - {} - - void _collectShaderParams(DiagnosticSink* sink); - - Linkage* m_linkage; - List<RefPtr<EntryPoint>> m_entryPoints; - - /// Information about shader parameters to be associated with the entry-point group itself. - /// - /// This list captures parameters that logically belong to the group itself, rather than - /// to any specific entry point in the group. - /// - List<ShaderParamInfo> m_shaderParams; - - /// Modules the entry point group depends on. - ModuleDependencyList m_dependencyList; }; enum class PassThroughMode : SlangPassThrough @@ -439,19 +770,52 @@ namespace Slang /// may span multiple Slang source files), and provides access /// to both the AST and IR representations of that code. /// - class Module : public RefObject, public slang::IModule + class Module : public ComponentType, public slang::IModule { + typedef ComponentType Super; + public: SLANG_REF_OBJECT_IUNKNOWN_ALL ISlangUnknown* getInterface(const Guid& guid); + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW IComponentType* SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize(specializationArgs, specializationArgCount, outDiagnostics); + } + + // + /// Create a module (initially empty). Module(Linkage* linkage); - /// Get the parent linkage of this module. - Linkage* getLinkage() { return m_linkage; } - /// Get the AST for the module (if it has been parsed) ModuleDecl* getModuleDecl() { return m_moduleDecl; } @@ -459,7 +823,7 @@ namespace Slang IRModule* getIRModule() { return m_irModule; } /// Get the list of other modules this module depends on - List<RefPtr<Module>> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } + List<Module*> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } /// Get the list of filesystem paths this module depends on List<String> const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } @@ -474,7 +838,7 @@ namespace Slang /// /// This should only be called once, during creation of the module. /// - void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; } + void setModuleDecl(ModuleDecl* moduleDecl);// { m_moduleDecl = moduleDecl; } /// Set the IR for this module. /// @@ -482,16 +846,65 @@ namespace Slang /// void setIRModule(IRModule* irModule) { m_irModule = irModule; } - private: - // The parent linkage - Linkage* m_linkage = nullptr; + Index getEntryPointCount() SLANG_OVERRIDE { return 0; } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } + + Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); } + List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); } + + /// Collect information on the shader parameters of the module. + /// + /// This method should only be called once, after the core + /// structured of the module (its AST and IR) have been created, + /// and before any of the `ComponentType` APIs are used. + /// + /// TODO: We might eventually consider a non-stateful approach + /// to constructing a `Module`. + /// + void _collectShaderParams(); + class ModuleSpecializationInfo : public SpecializationInfo + { + public: + struct GenericArgInfo + { + RefPtr<Decl> paramDecl; + RefPtr<Val> argVal; + }; + + List<GenericArgInfo> genericArgs; + List<ExpandedSpecializationArg> existentialArgs; + }; + + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + + private: // The AST for the module RefPtr<ModuleDecl> m_moduleDecl; // The IR for the module RefPtr<IRModule> m_irModule = nullptr; + List<GlobalShaderParamInfo> m_shaderParams; + SpecializationParams m_specializationParams; + + List<Module*> m_requirements; + // List of modules this module depends on ModuleDependencyList m_moduleDependencyList; @@ -650,15 +1063,10 @@ namespace Slang SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModule( const char* moduleName, slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createProgram( - slang::ProgramDesc const& desc, - slang::IProgram** outProgram) override; - SLANG_NO_THROW SlangResult SLANG_MCALL specializeProgram( - slang::IProgram* program, - SlangInt specializationArgCount, - slang::SpecializationArg const* specializationArgs, - slang::IProgram** outSpecializedProgram, - ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::IComponentType* SLANG_MCALL createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics = nullptr) override; SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType( slang::TypeReflection* type, slang::SpecializationArg const* specializationArgs, @@ -980,202 +1388,133 @@ namespace Slang int translationUnitIndex, String const& path); - Program* getProgram() { return m_program; } + /// Get a component type that represents the global scope of the compile request. + ComponentType* getGlobalComponentType() { return m_globalComponentType; } + + /// Get a component type that represents the global scope of the compile request, plus the requested entry points. + ComponentType* getGlobalAndEntryPointsComponentType() { return m_globalAndEntryPointsComponentType; } private: - RefPtr<Program> m_program; + /// A component type that includes only the global scopes of the translation unit(s) that were compiled. + RefPtr<ComponentType> m_globalComponentType; + + /// A component type that extends the global scopes with all of the entry points that were specified. + RefPtr<ComponentType> m_globalAndEntryPointsComponentType; }; - /// A collection of code modules and entry points that are intended to be used together. + /// A "legacy" program composes multiple translation units from a single compile request, + /// and takes care to treat global declarations of the same name from different translation + /// units as representing the "same" parameter. /// - /// A `Program` establishes that certain pieces of code are intended - /// to be used togehter so that, e.g., layout can make sure to allocate - /// space for the global shader parameters in all referenced modules. + /// TODO: This type only exists to support a single requirement: that multiple translation + /// units can be compiled in one pass and be guaranteed that the "same" parameter declared + /// in different translation units (hence different modules) will get the same layout. + /// This feature should be deprecated and removed as soon as possible, since the complexity + /// it creates in the codebase is not justified by its limited utility. /// - class Program : public RefObject, public slang::IProgram + class LegacyProgram : public ComponentType { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL; - ISlangUnknown* getInterface(Guid const& guid); + LegacyProgram( + Linkage* linkage, + List<RefPtr<TranslationUnitRequest>> const& translationUnits, + DiagnosticSink* sink); - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; + Index getTranslationUnitCount() { return m_translationUnits.getCount(); } + RefPtr<TranslationUnitRequest> getTranslationUnit(Index index) { return m_translationUnits[index]; } - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( - SlangInt targetIndex, - slang::IBlob** outDiagnostics) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) override; - - - /// Create a new program, initially empty. - /// - /// All code loaded into the program must come - /// from the given `linkage`. - Program( - Linkage* linkage); - - /// Get the linkage that this program uses. - Linkage* getLinkageImpl() { return m_linkage; } - - /// Get the number of entry points added to the program - Index getEntryPointCount() { return m_entryPoints.getCount(); } - - /// Get the entry point at the given `index`. - RefPtr<EntryPoint> getEntryPoint(Index index) { return m_entryPoints[index]; } - - /// Get the full ist of entry points on the program. - List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; } - - - Index getEntryPointGroupCount() { return m_entryPointGroups.getCount(); } - RefPtr<EntryPointGroup> getEntryPointGroup(Index index) { return m_entryPointGroups[index]; } - List<RefPtr<EntryPointGroup>> const& getEntryPointGroups() { return m_entryPointGroups; } - - - /// Get the substitution (if any) that represents how global generics are specialized. - RefPtr<Substitutions> getGlobalGenericSubstitution() { return m_globalGenericSubst; } - - /// Get the full list of modules this program depends on - List<RefPtr<Module>> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); } - - /// Get the full list of filesystem paths this program depends on - List<String> getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); } - - /// Get the target-specific version of this program for the given `target`. - /// - /// The `target` must be a target on the `Linkage` that was used to create this program. - TargetProgram* getTargetProgram(TargetRequest* target); - - /// Add a module (and everything it depends on) to the list of references - void addReferencedModule(Module* module); - - /// Add a module (but not the things it depends on) to the list of references - /// - /// This is a compatiblity hack for legacy compiler behavior. - void addReferencedLeafModule(Module* module); + Index getEntryPointCount() SLANG_OVERRIDE { return 0; } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } + Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } - /// Add an entry point to the program - /// - /// This also adds everything the entry point depends on to the list of references. - /// - void addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink); + Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; } - /// Add an entry point group to the program - /// - /// This also adds everything the entry point group depends on to the list of references. - /// - void addEntryPointGroup(EntryPointGroup* entryPointGroup); + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - /// Set the global generic argument substitution to use. - void setGlobalGenericSubsitution(RefPtr<Substitutions> subst) - { - m_globalGenericSubst = subst; - } + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies.getModuleList(); } + List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_fileDependencies.getFilePathList(); } - /// Parse a type from a string, in the context of this program. - /// - /// Any names in the string will be resolved using the modules - /// referenced by the program. - /// - /// On an error, returns null and reports diagnostic messages - /// to the provided `sink`. - /// - Type* getTypeFromString(String typeStr, DiagnosticSink* sink); - - /// Get the IR module that represents this program and its entry points. - /// - /// The IR module for a program tries to be minimal, and in the - /// common case will only include symbols with `[import]` declarations - /// for the entry point(s) of the program, and any types they - /// depend on. - /// - /// This IR module is intended to be linked against the IR modules - /// for all of the dependencies (see `getModuleDependencies()`) to - /// provide complete code. - /// - RefPtr<IRModule> getOrCreateIRModule(DiagnosticSink* sink); - - /// Get the number of existential type parameters for the program. - Index getExistentialTypeParamCount() { return m_globalExistentialSlots.paramTypes.getCount(); } - - /// Get the existential type parameter at `index`. - Type* getExistentialTypeParam(Index index) { return m_globalExistentialSlots.paramTypes[index]; } - - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized program may have many parameters, but zero arguments. - Index getExistentialTypeArgCount() { return m_globalExistentialSlots.args.getCount(); } - - /// Get the existential type argument (type and witness table) at `index`. - ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_globalExistentialSlots.args[index]; } - - /// Get an array of all existential type arguments. - ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_globalExistentialSlots.args.getBuffer(); } - - /// Get an array of all global shader parameters. - List<GlobalShaderParamInfo> const& getShaderParams() { return m_shaderParams; } + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; - void _collectShaderParams(DiagnosticSink* sink); - void _specializeExistentialTypeParams( - List<RefPtr<Expr>> const& args, - DiagnosticSink* sink); + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; private: + void _collectShaderParams(DiagnosticSink* sink); - // The linakge this program is associated with. - // - // Note that a `Program` keeps its associated linkage alive, - // and not vice versa. - // - RefPtr<Linkage> m_linkage; - - // Tracking data for the list of modules dependend on - ModuleDependencyList m_moduleDependencyList; - - // Tracking data for the list of filesystem paths dependend on - FilePathDependencyList m_filePathDependencyList; - - // Entry points that are part of the program. - List<RefPtr<EntryPoint> > m_entryPoints; - - // Entry points that are part of the program. - List<RefPtr<EntryPointGroup> > m_entryPointGroups; - - // Specializations for global generic parameters (if any) - RefPtr<Substitutions> m_globalGenericSubst; - - // The existential/interface slots associated with the global scope. - ExistentialTypeSlots m_globalExistentialSlots; + List<RefPtr<TranslationUnitRequest>> m_translationUnits; - /// Information about global shader parameters + List<EntryPoint*> m_entryPoints; List<GlobalShaderParamInfo> m_shaderParams; + List<ComponentType*> m_requirements; + SpecializationParams m_specializationParams; + ModuleDependencyList m_moduleDependencies; + FilePathDependencyList m_fileDependencies; + }; - // Generated IR for this program. - RefPtr<IRModule> m_irModule; - - // Cache of target-specific programs for each target. - Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms; + /// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses. + class ComponentTypeVisitor + { + public: + // The following methods should be overriden in a concrete subclass + // to customize how it acts on each of the concrete types of component. + // + // In cases where the application wants to simply "recurse" on a + // composite, specialized, or legacy component type it can use + // the `visitChildren` methods below. + // + virtual void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; + virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0; + virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; + virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; + virtual void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; - // Any types looked up dynamically using `getTypeFromString` - Dictionary<String, RefPtr<Type>> m_types; + protected: + // These helpers can be used to recurse into the logical children of a + // component type, and are useful for the common case where a visitor + // only cares about a few leaf cases. + // + // Note that for a `LegacyProgram` the "children" in this case are the + // `Module`s of the translation units that make up the legacy program. + // In some cases this is what is desired, but in others it is incorrect + // to treat a legacy program as a composition of modules, and instead + // it should be treated directly as a leaf case. Clients should make + // an informed decision based on an understanding of what `LegacyProgram` is used for. + // + void visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo); + void visitChildren(SpecializedComponentType* specialized); + void visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo); }; - /// A `Program` specialized for a particular `TargetRequest` + /// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest` + /// + /// TODO: This should probably be renamed to `TargetComponentType`. + /// + /// By binding a component type to a specific target, a `TargetProgram` allows + /// for things like layout to be computed, that fundamentally depend on + /// the choice of target. + /// + /// A `TargetProgram` handles request for compiled kernel code for + /// entry point functions. In practice, kernel code can only be + /// correctly generated when the underlying `ComponentType` is "fully linked" + /// (has no remaining unsatisfied requirements). + /// class TargetProgram : public RefObject { public: TargetProgram( - Program* program, + ComponentType* componentType, TargetRequest* targetReq); /// Get the underlying program - Program* getProgram() { return m_program; } + ComponentType* getProgram() { return m_program; } /// Get the underlying target TargetRequest* getTargetReq() { return m_targetReq; } @@ -1232,7 +1571,7 @@ namespace Slang private: // The program being compiled or laid out - Program* m_program; + ComponentType* m_program; // The target that code/layout will be generated for TargetRequest* m_targetReq; @@ -1253,7 +1592,7 @@ namespace Slang BackEndCompileRequest( Linkage* linkage, DiagnosticSink* sink, - Program* program = nullptr); + ComponentType* program = nullptr); // Should we dump intermediate results along the way, for debugging? bool shouldDumpIntermediates = false; @@ -1263,8 +1602,8 @@ namespace Slang LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; } - Program* getProgram() { return m_program; } - void setProgram(Program* program) { m_program = program; } + ComponentType* getProgram() { return m_program; } + void setProgram(ComponentType* program) { m_program = program; } // Should R/W images without explicit formats be assumed to have "unknown" format? // @@ -1273,7 +1612,7 @@ namespace Slang bool useUnknownImageFormatAsDefault = false; private: - RefPtr<Program> m_program; + RefPtr<ComponentType> m_program; }; /// A compile request that spans the front and back ends of the compiler @@ -1311,11 +1650,8 @@ namespace Slang // Should we just pass the input to another compiler? PassThroughMode passThrough = PassThroughMode::None; - /// Source code for the generic arguments to use for the global generic parameters of the program. - List<String> globalGenericArgStrings; - - /// Types to use to fill global existential "slots" - List<String> globalExistentialSlotArgStrings; + /// Source code for the specialization arguments to use for the global specialization parameters of the program. + List<String> globalSpecializationArgStrings; bool shouldSkipCodegen = false; @@ -1331,11 +1667,8 @@ namespace Slang class EntryPointInfo : public RefObject { public: - /// Source code for the generic arguments to use for the generic parameters of the entry point. - List<String> genericArgStrings; - - /// Source code for the type arguments to plug into the existential type "slots" of the entry point - List<String> existentialArgStrings; + /// Source code for the specialization arguments to use for the specialization parameters of the entry point. + List<String> specializationArgStrings; }; List<EntryPointInfo> entryPoints; @@ -1370,8 +1703,15 @@ namespace Slang FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } BackEndCompileRequest* getBackEndReq() { return m_backEndReq; } - Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); } - Program* getSpecializedProgram() { return m_specializedProgram; } + + ComponentType* getUnspecializedGlobalComponentType() { return getFrontEndReq()->getGlobalComponentType(); } + ComponentType* getUnspecializedGlobalAndEntryPointsComponentType() + { + return getFrontEndReq()->getGlobalAndEntryPointsComponentType(); + } + + ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; } + ComponentType* getSpecializedGlobalAndEntryPointsComponentType() { return m_specializedGlobalAndEntryPointsComponentType; } private: void init(); @@ -1380,8 +1720,8 @@ namespace Slang RefPtr<Linkage> m_linkage; DiagnosticSink m_sink; RefPtr<FrontEndCompileRequest> m_frontEndReq; - RefPtr<Program> m_unspecializedProgram; - RefPtr<Program> m_specializedProgram; + RefPtr<ComponentType> m_specializedGlobalComponentType; + RefPtr<ComponentType> m_specializedGlobalAndEntryPointsComponentType; RefPtr<BackEndCompileRequest> m_backEndReq; // For output @@ -1623,14 +1963,14 @@ inline slang::IModule* asExternal(Module* module) return static_cast<slang::IModule*>(module); } -inline Program* asInternal(slang::IProgram* module) +inline ComponentType* asInternal(slang::IComponentType* componentType) { - return static_cast<Program*>(module); + return static_cast<ComponentType*>(componentType); } -inline slang::IProgram* asExternal(Program* module) +inline slang::IComponentType* asExternal(ComponentType* componentType) { - return static_cast<slang::IProgram*>(module); + return static_cast<slang::IComponentType*>(componentType); } static inline slang::ProgramLayout* asExternal(ProgramLayout* programLayout) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e8114dc95..802c89f8b 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -240,6 +240,8 @@ DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an imm DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") +DIAGNOSTIC(30060, Error, expectedAType, "expected a type got a '$0'") + DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'") DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body") @@ -361,21 +363,22 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") -DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.") +DIAGNOSTIC(38021, Error, typeArgumentForGenericParameterDoesNotConformToInterface, "type argument `$0` for generic parameter `$1` does not conform to interface `$2`.") DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") + DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0"); DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") -DIAGNOSTIC(38025, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)") +DIAGNOSTIC(38025, Error, mismatchSpecializationArguments, "expected $0 specialization arguments ($1 provided)") DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") DIAGNOSTIC(38028, Error, existentialSlotArgNotAType, "existential slot argument $0 was not a type") -DIAGNOSTIC(38029, Error, existentialSlotArgDoesNotConform, "existential slot argument $0 does not conform to the required interface '$1'") +DIAGNOSTIC(38029, Error,typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index c7b5b602b..26af7b9f4 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2459,13 +2459,13 @@ String CLikeSourceEmitter::getFuncName(IRFunc* func) // name for an entry-point function, but other // targets should try to use the original name. // - // TODO: always use `main`, and have any code - // that wraps this know to use `main` instead - // of the original entry-point name... + // TODO: always use the original name, and + // use the appropriate options for glslang to + // make it support a non-`main` name. // if (getSourceStyle() != SourceStyle::GLSL) { - return getText(entryPointLayout->entryPoint->getName()); + return getText(entryPointLayout->getFuncDecl()->getName()); } // diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 3b9023703..d3c852fd6 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -687,20 +687,20 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL break; case Stage::Geometry: { - if (auto attrib = entryPointLayout->entryPoint->FindModifier<MaxVertexCountAttribute>()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<MaxVertexCountAttribute>()) { m_writer->emit("layout(max_vertices = "); m_writer->emit(attrib->value); m_writer->emit(") out;\n"); } - if (auto attrib = entryPointLayout->entryPoint->FindModifier<InstanceAttribute>()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<InstanceAttribute>()) { m_writer->emit("layout(invocations = "); m_writer->emit(attrib->value); m_writer->emit(") in;\n"); } - for (auto pp : entryPointLayout->entryPoint->GetParameters()) + for (auto pp : entryPointLayout->getFuncDecl()->GetParameters()) { if (auto inputPrimitiveTypeModifier = pp->FindModifier<HLSLGeometryShaderInputPrimitiveTypeModifier>()) { diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 9c0f7f02b..4abd692f8 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -294,13 +294,13 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint break; case Stage::Geometry: { - if (auto attrib = entryPointLayout->entryPoint->FindModifier<MaxVertexCountAttribute>()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<MaxVertexCountAttribute>()) { m_writer->emit("[maxvertexcount("); m_writer->emit(attrib->value); m_writer->emit(")]\n"); } - if (auto attrib = entryPointLayout->entryPoint->FindModifier<InstanceAttribute>()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<InstanceAttribute>()) { m_writer->emit("[instance("); m_writer->emit(attrib->value); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c51040b98..d5dcaca98 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -49,36 +49,34 @@ enum class BuiltInCOp EntryPointLayout* findEntryPointLayout( ProgramLayout* programLayout, - EntryPoint* entryPoint, - EntryPointGroupLayout** outEntryPointGroupLayout = nullptr) + EntryPoint* entryPoint) { - for( auto entryPointGroupLayout : programLayout->entryPointGroups ) + // TODO: This function shouldn't need to exist, and it + // somewhat hampers the capabilities of the compiler (e.g., + // it isn't supported to have a single program contain + // two different "instances" of the same entry point). + // + // Code that cares about layouts should be looking up + // the entry point layout by index on a `ProgramLayout`, + // knowing that those indices will align with the order + // of entry points on the `ComponentType` for the program. + + for( auto entryPointLayout : programLayout->entryPoints ) { - for( auto entryPointLayout : entryPointGroupLayout->entryPoints ) - { - if(entryPointLayout->entryPoint->getName() != entryPoint->getName()) - continue; - - // TODO: We need to be careful about this check, since it relies on - // the profile information in the layout matching that in the request. - // - // What we really seem to want here is some dictionary mapping the - // `EntryPoint` directly to the `EntryPointLayout`, and maybe - // that is precisely what we should build... - // - if(entryPointLayout->profile != entryPoint->getProfile()) - continue; - - // TODO: can't easily filter on translation unit here... - // Ideally the `EntryPoint` should get filled in with a pointer - // the specific function declaration that represents the entry point. - - if( outEntryPointGroupLayout ) - { - *outEntryPointGroupLayout = entryPointGroupLayout; - } - return entryPointLayout; - } + if(entryPointLayout->entryPoint.GetName() != entryPoint->getName()) + continue; + + // TODO: We need to be careful about this check, since it relies on + // the profile information in the layout matching that in the request. + // + // What we really seem to want here is some dictionary mapping the + // `EntryPoint` directly to the `EntryPointLayout`, and maybe + // that is precisely what we should build... + // + if(entryPointLayout->profile != entryPoint->getProfile()) + continue; + + return entryPointLayout; } return nullptr; diff --git a/source/slang/slang-ir-bind-existentials.cpp b/source/slang/slang-ir-bind-existentials.cpp index e426e6e92..a99da3410 100644 --- a/source/slang/slang-ir-bind-existentials.cpp +++ b/source/slang/slang-ir-bind-existentials.cpp @@ -69,24 +69,8 @@ struct BindExistentialSlots void processGlobalExistentialSlots() { - // If there are any global existential slots, we will expect - // to find a `bindGlobalExistentialSlots` instruction at module scope. - // - // We will start out by finding that instruction, if it exists. - // - IRInst* bindGlobalExistentialSlotsInst = nullptr; - for( auto inst : module->getGlobalInsts() ) - { - if( inst->op == kIROp_BindGlobalExistentialSlots ) - { - bindGlobalExistentialSlotsInst = inst; - break; - } - } - - // Now we will start looking for global shader parameters that make - // use of existential slots (we can determine this from their - // layout). + // We will search for global shader parameters that make + // use of existential specialization parameters. // for( auto inst : module->getGlobalInsts() ) { @@ -96,23 +80,27 @@ struct BindExistentialSlots if(!globalParam) continue; - // We will delegate to a subroutine for the meat - // of the work, since much of it can be shared - // with the case for entry-point existential - // parameters. + // We only care about global shader parameters + // that have existential specialization parameters, + // and we expect all such parameters to have a + // `[bindExistentialSlots(...)]` decoration that + // was added during IR linking. // - processParameter(globalParam, bindGlobalExistentialSlotsInst); - } + auto bindSlotsInst = globalParam->findDecorationImpl(kIROp_BindExistentialSlotsDecoration); + if(!bindSlotsInst) + continue; - // Once we are done looping over global shader parameters, - // all of the relevant information from the - // `bindGlobalExistentialSlots` instruction will have - // been moved to the parameters themselves, so we - // can eliminate the binding instruction. - // - if( bindGlobalExistentialSlotsInst ) - { - bindGlobalExistentialSlotsInst->removeAndDeallocate(); + replaceTypeUsingExistentialSlots( + globalParam, + bindSlotsInst->getOperandCount(), + bindSlotsInst->getOperands()); + + // Once we have propagated the information from + // the `[bindExistentialSlots(...)]` decoration + // down into the parameter's type, we no longer + // need the decoration. + // + bindSlotsInst->removeAndDeallocate(); } } @@ -150,9 +138,25 @@ struct BindExistentialSlots // We then need to process each of the entry-point // parameters just like we did for global parameters. // + // Because the existential slot arguments for *all* of the parameters + // are attached in a single `[bindExistentialSlots(...)]` decoration, + // we need to carve them up appropriately across the parameters. + // The way we do this is a bit of a kludge, in that we track a + // single `slotOffset` and increment it for each parameter by the + // number of arguments it consumed. + // + // Note: a better approach here might rely on the layout information + // for the parameters, which should directly encode an offset for + // the existential specialization parameters it uses. The challenge + // with this is that we'd need to correctly interpret the offset + // relative to any global-scope specialization parameters or + // generic specialization parameters of the entry point. + // Ultimately the simplistic counter approach is less complicated. + // + Index slotOffset = 0; for( auto param : func->getParams() ) { - processParameter(param, bindEntryPointExistentialSlotsInst); + processEntryPointParameter(param, bindEntryPointExistentialSlotsInst, slotOffset); } // TODO: We would need to consider what to do if @@ -181,9 +185,10 @@ struct BindExistentialSlots // function `param` and a `[bindExistentialSlots(...)]` // decoration; both use the same subroutine. // - void processParameter( + void processEntryPointParameter( IRInst* param, - IRInst* bindSlotsInst) + IRInst* bindSlotsInst, + Index& ioSlotOperandOffset) { // We expect all shader parameters to have layout information, // but to be defensive we will skip any that don't. @@ -206,7 +211,6 @@ struct BindExistentialSlots // find out the stating slot, and the information on // the type to find out the number of slots. // - UInt firstSlot = resInfo->index; UInt slotCount = 0; if(auto typeResInfo = varLayout->getTypeLayout()->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) slotCount = UInt(typeResInfo->count.getFiniteValue()); @@ -233,7 +237,8 @@ struct BindExistentialSlots // this parameter. // UInt bindOperandCount = bindSlotsInst->getOperandCount(); - if( 2*(firstSlot + slotCount) > bindOperandCount ) + UInt slotOperandCount = 2*slotCount; + if( (ioSlotOperandOffset + slotOperandCount) > bindOperandCount ) { sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter); return; @@ -244,7 +249,7 @@ struct BindExistentialSlots // keeping in mind that each slot accounts for two // operands. // - auto operandsForInst = bindSlotsInst->getOperands() + firstSlot; + auto operandsForInst = bindSlotsInst->getOperands() + ioSlotOperandOffset; // Once we've found the operands that are relevent to // the slots used by `param`, we will defer to a routine @@ -253,17 +258,17 @@ struct BindExistentialSlots // replaceTypeUsingExistentialSlots( param, - slotCount, + slotOperandCount, operandsForInst); + + ioSlotOperandOffset += slotOperandCount; } void replaceTypeUsingExistentialSlots( IRInst* inst, - UInt slotCount, + UInt slotOperandCount, IRUse const* slotArgs) { - SLANG_UNUSED(slotCount); - // We are going to alter the type of the // given `inst` based on information in // the `slotArgs`. @@ -282,7 +287,6 @@ struct BindExistentialSlots // a witness table, so the total number of operands // is twice the number of slots we are filling. // - UInt slotOperandCount = slotCount*2; List<IRInst*> slotOperands; for(UInt ii = 0; ii < slotOperandCount; ++ii) slotOperands.add(slotArgs[ii].get()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 0c218f2aa..f9c157465 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -206,7 +206,6 @@ INST(Specialize, specialize, 2, 0) INST(lookup_interface_method, lookup_interface_method, 2, 0) INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) -INST(BindGlobalExistentialSlots, bindGlobalExistentialSlots, 0, 0) INST(Construct, construct, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 359c8e98d..cd59630d8 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1175,10 +1175,6 @@ struct IRBuilder IRInst* param, IRInst* val); - IRInst* emitBindGlobalExistentialSlots( - UInt argCount, - IRInst* const* args); - IRDecoration* addBindExistentialSlotsDecoration( IRInst* value, UInt argCount, diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 7e2ac1f98..56b06a499 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -8,14 +8,13 @@ namespace Slang { -// Needed for lookup up entry-point layouts. -// -// TODO: maybe arrange so that codegen is driven from the layout layer -// instead of the input/request layer. + /// Find a suitable layout for `entryPoint` in `programLayout`. + /// + /// TODO: This function should be eliminated. See its body + /// for an explanation of the problems. EntryPointLayout* findEntryPointLayout( ProgramLayout* programLayout, - EntryPoint* entryPoint, - EntryPointGroupLayout** outEntryPointGroupLayout); + EntryPoint* entryPoint); struct IRSpecSymbol : RefObject { @@ -339,9 +338,10 @@ IRType* cloneType( } void cloneGlobalValueWithCodeCommon( - IRSpecContextBase* context, - IRGlobalValueWithCode* clonedValue, - IRGlobalValueWithCode* originalValue); + IRSpecContextBase* context, + IRGlobalValueWithCode* clonedValue, + IRGlobalValueWithCode* originalValue, + IROriginalValuesForClone const& originalValues); IRRate* cloneRate( IRSpecContextBase* context, @@ -382,11 +382,58 @@ IRGlobalVar* cloneGlobalVarImpl( cloneGlobalValueWithCodeCommon( context, clonedVar, - originalVar); + originalVar, + originalValues); return clonedVar; } + /// Clone certain special decorations for `clonedInst` from its (potentially multiple) definitions. + /// + /// In most cases, once we've decided on the "best" definition to use for an IR instruction, + /// we only want the linking process to use the decorations from the single best definition. + /// In some casses, though, the canonical best definition might not have all the information. + /// + /// A concrete example is the `[bindExistentialsSlots(...)]` decorations for global shader + /// parameters and entry points. These decorations are only generated as part of the IR + /// associated with a specialization of a program, and not the original IR for the modules + /// of the program. + /// + /// This function scans through all the `originalValues` that were considered for `clonedInst`, + /// and copies over any decorations that are allowed to come from a non-"best" definition. + /// For a given decoration opcode, only one such decoration will ever be copied, and nothing + /// will be copied if the instruction already has a matching decoration (that was cloned + /// from the "best" definition). + /// +static void cloneExtraDecorations( + IRSpecContextBase* context, + IRInst* clonedInst, + IROriginalValuesForClone const& originalValues) +{ + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedInst); + + for(auto sym = originalValues.sym; sym; sym = sym->nextWithSameName) + { + for(auto decoration : sym->irGlobalValue->getDecorations()) + { + switch(decoration->op) + { + default: + break; + + case kIROp_BindExistentialSlotsDecoration: + if(!clonedInst->findDecorationImpl(decoration->op)) + { + cloneInst(context, builder, decoration); + } + break; + } + } + } +} + void cloneSimpleGlobalValueImpl( IRSpecContextBase* context, IRInst* originalInst, @@ -407,6 +454,12 @@ void cloneSimpleGlobalValueImpl( { cloneInst(context, builder, child); } + + // Also clone certain decorations if they appear on *any* + // definition of the symbol (not necessarily the one + // we picked as the primary/best). + // + cloneExtraDecorations(context, clonedInst, originalValues); } IRGlobalParam* cloneGlobalParamImpl( @@ -469,7 +522,8 @@ IRGeneric* cloneGenericImpl( cloneGlobalValueWithCodeCommon( context, clonedVal, - originalVal); + originalVal, + originalValues); return clonedVal; } @@ -545,7 +599,8 @@ IRInterfaceType* cloneInterfaceTypeImpl( void cloneGlobalValueWithCodeCommon( IRSpecContextBase* context, IRGlobalValueWithCode* clonedValue, - IRGlobalValueWithCode* originalValue) + IRGlobalValueWithCode* originalValue, + IROriginalValuesForClone const& originalValues) { // Next we are going to clone the actual code. IRBuilder builderStorage = *context->builder; @@ -553,6 +608,7 @@ void cloneGlobalValueWithCodeCommon( builder->setInsertInto(clonedValue); cloneDecorations(context, clonedValue, originalValue); + cloneExtraDecorations(context, clonedValue, originalValues); // We will walk through the blocks of the function, and clone each of them. // @@ -629,10 +685,11 @@ void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const } void cloneFunctionCommon( - IRSpecContextBase* context, - IRFunc* clonedFunc, - IRFunc* originalFunc, - bool checkDuplicate = true) + IRSpecContextBase* context, + IRFunc* clonedFunc, + IRFunc* originalFunc, + IROriginalValuesForClone const& originalValues, + bool checkDuplicate = true) { // First clone all the simple properties. clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); @@ -640,7 +697,8 @@ void cloneFunctionCommon( cloneGlobalValueWithCodeCommon( context, clonedFunc, - originalFunc); + originalFunc, + originalValues); // Shuffle the function to the end of the list, because // it needs to follow its dependencies. @@ -667,7 +725,6 @@ IRInst* specializeGeneric( IRFunc* specializeIRForEntryPoint( IRSpecContext* context, - EntryPoint* entryPoint, EntryPointLayout* entryPointLayout) { // We start by looking up the IR symbol that @@ -679,7 +736,7 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - auto mangledName = getMangledName(entryPoint->getFuncDeclRef()); + auto mangledName = getMangledName(entryPointLayout->getFuncDeclRef()); RefPtr<IRSpecSymbol> sym; if (!context->getSymbols().TryGetValue(mangledName, sym)) { @@ -947,7 +1004,7 @@ IRFunc* cloneFuncImpl( { auto clonedFunc = builder->createFunc(); registerClonedValue(context, clonedFunc, originalValues); - cloneFunctionCommon(context, clonedFunc, originalFunc); + cloneFunctionCommon(context, clonedFunc, originalFunc, originalValues); return clonedFunc; } @@ -1227,7 +1284,9 @@ LinkedIR linkIR( CodeGenTarget target, TargetRequest* targetReq) { - auto sink = compileRequest->getSink(); + // TODO: We need to make sure that the program we are being asked + // to compile has been "resolved" so that it has no outstanding + // unsatisfied requirements. IRSpecializationState stateStorage; auto state = &stateStorage; @@ -1252,12 +1311,11 @@ LinkedIR linkIR( // accelerate lookup, we will create a symbol table for looking // up IR definitions by their mangled name. // - auto originalProgramIRModule = program->getOrCreateIRModule(sink); - insertGlobalValueSymbols(sharedContext, originalProgramIRModule); - for (auto module : program->getModuleDependencies()) + program->enumerateIRModules([&](IRModule* irModule) { - insertGlobalValueSymbols(sharedContext, module->getIRModule()); - } + insertGlobalValueSymbols(sharedContext, irModule); + }); + auto context = state->getContext(); context->shared = sharedContext; @@ -1282,32 +1340,9 @@ LinkedIR linkIR( context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } - EntryPointGroupLayout* entryPointGroupLayout = nullptr; - auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint, &entryPointGroupLayout); - - auto offsetEntryPointLayout = entryPointLayout->getAbsoluteLayout(entryPointGroupLayout); + auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint); - // Note: when we are doing the compatibility approach for Falcor, we - // can have global-scope symbols that are actually part of the - // local root signature (entry point group), so we need to make - // sure to apply those layouts appropriately. - auto entryPointGroupStructLayout = getScopeStructLayout(entryPointGroupLayout); - for(auto entry : entryPointGroupStructLayout->mapVarToLayout) - { - if(!entry.Key) - continue; - - auto mangledName = getMangledName(entry.Key); - auto groupVarLayout = entry.Value; - - // We need to "adjust" the layout that was computed for the parameter - // because it will be relative to the start of the entry-point group, - // rather than absolute. - // - auto absoluteVarLayout = groupVarLayout->getAbsoluteLayout(entryPointGroupLayout->parametersLayout); - - context->globalVarLayouts.AddIfNotExists(mangledName, absoluteVarLayout); - } + auto offsetEntryPointLayout = entryPointLayout; context->builder->setInsertInto(context->getModule()->getModuleInst()); @@ -1316,7 +1351,7 @@ LinkedIR linkIR( // TODO: This step should *not* be needed with the current IR // specialization approach, so we should consider removing it. // - for (auto sym :context->getSymbols()) + for (auto sym : context->getSymbols()) { if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); @@ -1327,28 +1362,31 @@ LinkedIR linkIR( // the entry point function itself, and rely on // this step to recursively copy over anything else // it might reference. - auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, offsetEntryPointLayout); + auto irEntryPoint = specializeIRForEntryPoint(context, offsetEntryPointLayout); - // HACK: right now the bindings for global generic parameters are coming in - // as part of the original IR module, and we need to make sure these get - // copied over, even if they aren't referenced. + // Bindings for global generic parameters are currently represented + // as stand-alone global-scope instructions in the IR module for + // `SpecializedComponentType`s. These instructions are required for + // correct codegen, and so we must make sure to copy them all over, + // even though they are not directly referenced. // - for(auto inst : originalProgramIRModule->getGlobalInsts()) - { - auto bindInst = as<IRBindGlobalGenericParam>(inst); - if(!bindInst) - continue; - - cloneValue(context, bindInst); - } - - for(auto inst : originalProgramIRModule->getGlobalInsts()) + // TODO: We should change these to decorations, akin to how + // `[bindExistentialSlots(...)]` works, so that they can be attached + // to the relevant parameters and cloned via `cloneExtraDecorations`. + // In the long run we do not want to *ever* iterate over all the + // instructions in all the input modules. + // + program->enumerateIRModules([&](IRModule* irModule) { - if(inst->op != kIROp_BindGlobalExistentialSlots) - continue; + for(auto inst : irModule->getGlobalInsts()) + { + auto bindInst = as<IRBindGlobalGenericParam>(inst); + if(!bindInst) + continue; - cloneValue(context, inst); - } + cloneValue(context, bindInst); + } + }); // HACK: we need to ensure that any tagged union types // in the IR module have layout information copied over to them. @@ -1357,7 +1395,7 @@ LinkedIR linkIR( // instructions, since we expected the tagged union type(s) to // be referenced by them. // - for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts ) + for( auto taggedUnionTypeLayout : programLayout->taggedUnionTypeLayouts ) { auto taggedUnionType = taggedUnionTypeLayout->getType(); auto mangledName = getMangledTypeName(taggedUnionType); @@ -1377,6 +1415,15 @@ LinkedIR linkIR( // we have global variables with initializers, since // these should get run whether or not the entry point // references them. + // + // Or alternatively we can define by fiat that the initializers + // on global variables get run at an unspecified time between + // program startup and the first access to a given global. + // Such a definition gives us the freedom to eliminate globals + // that are never accessed, while still doing "eager" + // initialization for globals that are referenced (instead of + // having to add the overhead of lazy initialization a la + // function-`static` variables). // Now that we've cloned the entry point and everything // it refers to, we can package up the data we return diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ed86e4d6d..2e8de5e37 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2880,22 +2880,6 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitBindGlobalExistentialSlots( - UInt argCount, - IRInst* const* args) - { - auto inst = createInstWithTrailingArgs<IRInst>( - this, - kIROp_BindGlobalExistentialSlots, - getVoidType(), - 0, - nullptr, - argCount, - args); - addInst(inst); - return inst; - } - IRDecoration* IRBuilder::addBindExistentialSlotsDecoration( IRInst* value, UInt argCount, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index eef69b1b5..e41be0e33 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1613,13 +1613,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto irBaseType = lowerType(context, type->baseType); List<IRInst*> slotArgs; - for(auto arg : type->slots.args) + for(auto arg : type->args) { - auto irArgType = lowerType(context, arg.type); - auto irArgWitness = lowerSimpleVal(context, arg.witness); + auto irArgVal = lowerSimpleVal(context, arg.val); + slotArgs.add(irArgVal); - slotArgs.add(irArgType); - slotArgs.add(irArgWitness); + if(auto witness = arg.witness) + { + auto irArgWitness = lowerSimpleVal(context, witness); + slotArgs.add(irArgWitness); + } } auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer()); @@ -6265,13 +6268,18 @@ static void lowerFrontEndEntryPointToIR( } static void lowerProgramEntryPointToIR( - IRGenContext* context, - EntryPoint* entryPoint) + IRGenContext* context, + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) { + auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + if(specializationInfo) + entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef; + // First, lower the entry point like an ordinary function + auto session = context->getSession(); - auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); auto builder = context->irBuilder; @@ -6280,7 +6288,6 @@ static void lowerProgramEntryPointToIR( auto loweredEntryPointFunc = getSimpleVal(context, emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); - // if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) { builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); @@ -6289,26 +6296,24 @@ static void lowerProgramEntryPointToIR( // We may have shader parameters of interface/existential type, // which need us to supply concrete type information for specialization. // - auto existentialTypeArgCount = entryPoint->getExistentialTypeArgCount(); - if( existentialTypeArgCount ) + if(specializationInfo && specializationInfo->existentialSpecializationArgs.getCount() != 0) { List<IRInst*> existentialSlotArgs; - for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) + for(auto arg : specializationInfo->existentialSpecializationArgs ) { - auto arg = entryPoint->getExistentialTypeArg(ii); - - auto irArgType = lowerType(context, arg.type); - auto irWitnessTable = lowerSimpleVal(context, arg.witness); + auto irArgType = lowerSimpleVal(context, arg.val); existentialSlotArgs.add(irArgType); - existentialSlotArgs.add(irWitnessTable); + + if(auto witness = arg.witness) + { + auto irWitnessTable = lowerSimpleVal(context, witness); + existentialSlotArgs.add(irWitnessTable); + } } builder->addBindExistentialSlotsDecoration(loweredEntryPointFunc, existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); } - - - } /// Ensure that `decl` and all relevant declarations under it get emitted. @@ -6451,97 +6456,136 @@ IRModule* generateIRForTranslationUnit( return module; } -RefPtr<IRModule> generateIRForProgram( - Session* session, - Program* program, - DiagnosticSink* sink) + /// Context for generating IR code to represent a `SpecializedComponentType` +struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor { -// auto compileRequest = translationUnit->compileRequest; + DiagnosticSink* sink; + Linkage* linkage; + Session* session; + IRGenContext* context; + IRBuilder* builder; - SharedIRGenContext sharedContextStorage( - session, - sink); - SharedIRGenContext* sharedContext = &sharedContextStorage; + RefPtr<IRModule> process( + SpecializedComponentType* componentType, + DiagnosticSink* inSink) + { + sink = inSink; - IRGenContext contextStorage(sharedContext); - IRGenContext* context = &contextStorage; + linkage = componentType->getLinkage(); + session = linkage->getSessionImpl(); - SharedIRBuilder sharedBuilderStorage; - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = nullptr; - sharedBuilder->session = session; + SharedIRGenContext sharedContextStorage( + session, + sink); + SharedIRGenContext* sharedContext = &sharedContextStorage; - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; + IRGenContext contextStorage(sharedContext); + context = &contextStorage; - RefPtr<IRModule> module = builder->createModule(); - sharedBuilder->module = module; + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; - context->irBuilder = builder; + IRBuilder builderStorage; + builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; - // We need to emit symbols for all of the entry - // points in the program; this is especially - // important in the case where a generic entry - // point is being specialized. - // - for(auto entryPoint : program->getEntryPoints()) - { - lowerProgramEntryPointToIR(context, entryPoint); - } + RefPtr<IRModule> module = builder->createModule(); + sharedBuilder->module = module; + builder->setInsertInto(module->getModuleInst()); - // Now lower all the arguments supplied for global generic - // type parameters. - // - for (RefPtr<Substitutions> subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer) - { - auto gSubst = subst.as<GlobalGenericParamSubstitution>(); - if(!gSubst) - continue; - - IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); - IRType* typeVal = lowerType(context, gSubst->actualType); + context->irBuilder = builder; - // bind `typeParam` to `typeVal` - builder->emitBindGlobalGenericParam(typeParam, typeVal); + componentType->acceptVisitor(this, nullptr); - for (auto& constraintArg : gSubst->constraintArgs) - { - IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); - IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + return module; + } - // bind `constraintParam` to `constraintVal` - builder->emitBindGlobalGenericParam(constraintParam, constraintVal); - } + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // We need to emit symbols for all of the entry + // points in the program; this is especially + // important in the case where a generic entry + // point is being specialized. + // + lowerProgramEntryPointToIR(context, entryPoint, specializationInfo); } - // We may have shader parameters of interface/existential type, - // which need us to supply concrete type information for specialization. - // - auto existentialTypeArgCount = program->getExistentialTypeArgCount(); - if( existentialTypeArgCount ) + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { - List<IRInst*> existentialSlotArgs; - for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) + // We've hit a leaf module, so we should be able to bind any global + // generic type parameters here... + // + if( specializationInfo ) { - auto arg = program->getExistentialTypeArg(ii); + for( auto genericArgInfo : specializationInfo->genericArgs ) + { + IRInst* irParam = getSimpleVal(context, ensureDecl(context, genericArgInfo.paramDecl)); + IRInst* irVal = lowerSimpleVal(context, genericArgInfo.argVal); - auto irArgType = lowerType(context, arg.type); - auto irWitnessTable = lowerSimpleVal(context, arg.witness); + // bind `irParam` to `irVal` + builder->emitBindGlobalGenericParam(irParam, irVal); + } - existentialSlotArgs.add(irArgType); - existentialSlotArgs.add(irWitnessTable); + auto shaderParamCount = module->getShaderParamCount(); + Index existentialArgOffset = 0; + + for( Index ii = 0; ii < shaderParamCount; ++ii ) + { + auto shaderParam = module->getShaderParam(ii); + auto specializationArgCount = shaderParam.specializationParamCount; + + IRInst* irParam = getSimpleVal(context, ensureDecl(context, shaderParam.paramDeclRef)); + List<IRInst*> irSlotArgs; + for( Index jj = 0; jj < specializationArgCount; ++jj ) + { + auto& specializationArg = specializationInfo->existentialArgs[existentialArgOffset++]; + + auto irType = lowerSimpleVal(context, specializationArg.val); + auto irWitness = lowerSimpleVal(context, specializationArg.witness); + + irSlotArgs.add(irType); + irSlotArgs.add(irWitness); + } + + builder->addBindExistentialSlotsDecoration( + irParam, + irSlotArgs.getCount(), + irSlotArgs.getBuffer()); + } } + } - builder->emitBindGlobalExistentialSlots(existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); } + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } - // TODO: Should we apply any of the validation or - // mandatory optimization passes here? + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // TODO: This case should be akin to the `Module` case, + // and deal with global-scope specialization parameters + // directly. + // + SLANG_UNUSED(legacy); + SLANG_UNUSED(specializationInfo); + SLANG_UNIMPLEMENTED_X("legacy program case"); + } +}; - return module; +RefPtr<IRModule> generateIRForSpecializedComponentType( + SpecializedComponentType* componentType, + DiagnosticSink* sink) +{ + SpecializedComponentTypeIRGenContext context; + return context.process(componentType, sink); } } // namespace Slang diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h index 060efb88b..33dfd9d27 100644 --- a/source/slang/slang-lower-to-ir.h +++ b/source/slang/slang-lower-to-ir.h @@ -15,14 +15,32 @@ namespace Slang { class EntryPoint; class ProgramLayout; + class SpecializedComponentType; class TranslationUnitRequest; + /// Generate an IR module to represent the code in the given `translationUnit`. + /// + /// The generated module will include IR definitions for any functions/types + /// in `translationUnit`, but it is *not* guaranteed to contain any definitions + /// from modules that are `import`ed into `translationUnit`. The resulting IR + /// module must be linked against other IR modules that define any symbols + /// that are imported before code generation can be performed. + /// IRModule* generateIRForTranslationUnit( TranslationUnitRequest* translationUnit); - RefPtr<IRModule> generateIRForProgram( - Session* session, - Program* program, - DiagnosticSink* sink); + /// Generate an IR module to represent the specializations applied by `componentType`. + /// + /// The generated IR will encode how `componentType` specializes global or + /// entry-point specialization parameters to concrete arguments (e.g., types). + /// + /// The generated IR module is *not* guaranteed to contain anything more, such + /// as the actual definitions of functions or types being specialized. The + /// resulting IR module must be linked against other IR modules that define + /// those symbols before code generation can be performed. + /// + RefPtr<IRModule> generateIRForSpecializedComponentType( + SpecializedComponentType* componentType, + DiagnosticSink* sink); } #endif diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index dc99f55e2..722725af7 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -674,16 +674,22 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( EntryPointParameterState const& state, RefPtr<VarLayout> varLayout); -// Collect a single declaration into our set of parameters -static void collectGlobalGenericParameter( - ParameterBindingContext* context, - RefPtr<GlobalGenericParamDecl> paramDecl) +static RefPtr<VarLayout> _createVarLayout( + TypeLayout* typeLayout, + DeclRef<VarDeclBase> varDeclRef) { - RefPtr<GenericParamLayout> layout = new GenericParamLayout(); - layout->decl = paramDecl; - layout->index = (int)context->shared->programLayout->globalGenericParams.getCount(); - context->shared->programLayout->globalGenericParams.add(layout); - context->shared->programLayout->globalGenericParamsMap[layout->decl->getName()->text] = layout.Ptr(); + RefPtr<VarLayout> varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + varLayout->varDecl = varDeclRef; + + if(auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout) + { + RefPtr<VarLayout> pendingVarLayout = new VarLayout(); + pendingVarLayout->typeLayout = pendingDataTypeLayout; + varLayout->pendingVarLayout = pendingVarLayout; + } + + return varLayout; } // Collect a single declaration into our set of parameters @@ -712,9 +718,7 @@ static void collectGlobalScopeParameter( return; // Now create a variable layout that we can use - RefPtr<VarLayout> varLayout = new VarLayout(); - varLayout->typeLayout = typeLayout; - varLayout->varDecl = varDeclRef; + RefPtr<VarLayout> varLayout = _createVarLayout(typeLayout, varDeclRef); // The logic in `check.cpp` that created the `GlobalShaderParamInfo` // will have identified any cases where there might be multiple @@ -748,6 +752,7 @@ static void collectGlobalScopeParameter( RefPtr<VarLayout> additionalVarLayout = new VarLayout(); additionalVarLayout->typeLayout = typeLayout; additionalVarLayout->varDecl = additionalVarDeclRef; + additionalVarLayout->pendingVarLayout = varLayout->pendingVarLayout; parameterInfo->varLayouts.add(additionalVarLayout); } @@ -1770,15 +1775,40 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( return structLayout; } - else if (auto globalGenericParam = declRef.as<GlobalGenericParamDecl>()) + else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>()) { - auto genParamTypeLayout = new GenericParamTypeLayout(); - // we should have already populated ProgramLayout::genericEntryPointParams list at this point, - // so we can find the index of this generic param decl in the list - genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl()); - genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; - return genParamTypeLayout; + auto& layoutContext = context->layoutContext; + + if( auto concreteType = findGlobalGenericSpecializationArg( + layoutContext, + globalGenericParamDecl) ) + { + // If we know what concrete type has been used to specialize + // the global generic type parameter, then we should use + // the concrete type instead. + // + // Note: it should be illegal for the user to use a generic + // type parameter in a varying parameter list without giving + // it an explicit user-defined semantic. Otherwise, it would be possible + // that the concrete type that gets plugged in is a user-defined + // `struct` that uses some `SV_` semantics in its definition, + // so that any static information about what system values + // the entry point uses would be incorrect. + // + return processEntryPointVaryingParameter(context, concreteType, state, varLayout); + } + else + { + // If we don't know a concrete type, then we aren't generating final + // code, so the reflection information should show the generic + // type parameter. + // + // We don't make any attempt to assign varying parameter resources + // to the generic type, since we can't know how many "slots" + // of varying input/output it would consume. + // + return createTypeLayoutForGlobalGenericTypeParam(layoutContext, type, globalGenericParamDecl); + } } else if (auto associatedTypeParam = declRef.as<AssocTypeDecl>()) { @@ -1804,15 +1834,12 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( /// Compute the type layout for a parameter declared directly on an entry point. static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( ParameterBindingContext* context, - SubstitutionSet typeSubst, DeclRef<VarDeclBase> paramDeclRef, RefPtr<VarLayout> paramVarLayout, EntryPointParameterState& state) { - auto paramDeclRefType = GetType(paramDeclRef); - SLANG_ASSERT(paramDeclRefType); - - auto paramType = paramDeclRefType->Substitute(typeSubst).as<Type>(); + auto paramType = GetType(paramDeclRef); + SLANG_ASSERT(paramType); if( paramDeclRef.getDecl()->HasModifier<HLSLUniformModifier>() ) { @@ -1940,6 +1967,12 @@ struct ScopeLayoutBuilder { m_structLayout->mapVarToLayout.Add(firstVarLayout->varDecl.getDecl(), firstVarLayout); } + } + + void addParameter( + RefPtr<VarLayout> varLayout) + { + _addParameter(varLayout, nullptr); // Any "pending" items on a field type become "pending" items // on the overall `struct` type layout. @@ -1948,42 +1981,58 @@ struct ScopeLayoutBuilder // `struct` layout logic in `type-layout.cpp`. If this gets any // more complicated we should see if there is a way to share it. // - if( auto fieldPendingDataTypeLayout = firstVarLayout->typeLayout->pendingDataTypeLayout ) + if( auto fieldPendingDataTypeLayout = varLayout->typeLayout->pendingDataTypeLayout ) { m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules); - auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(firstVarLayout->varDecl, fieldPendingDataTypeLayout); + auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(varLayout->varDecl, fieldPendingDataTypeLayout); m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); - if( parameterInfo ) - { - for( auto& varLayout : parameterInfo->varLayouts ) - { - varLayout->pendingVarLayout = fieldPendingDataVarLayout; - } - } - else - { - firstVarLayout->pendingVarLayout = fieldPendingDataVarLayout; - } + varLayout->pendingVarLayout = fieldPendingDataVarLayout; } } void addParameter( - RefPtr<VarLayout> varLayout) - { - _addParameter(varLayout, nullptr); - } - - void addParameter( ParameterInfo* parameterInfo) { SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); auto firstVarLayout = parameterInfo->varLayouts.getFirst(); _addParameter(firstVarLayout, parameterInfo); - } + // Global parameters will have their non-orindary/uniform + // pending data handled by the main parameter binding + // logic, but we still need to construct a layout + // that includes any pending data. + // + if(auto fieldPendingVarLayout = firstVarLayout->pendingVarLayout) + { + auto fieldPendingTypeLayout = fieldPendingVarLayout->typeLayout; + + m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules); + m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); + + auto fieldUniformLayoutInfo = fieldPendingTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + LayoutSize fieldUniformSize = fieldUniformLayoutInfo ? fieldUniformLayoutInfo->count : 0; + if( fieldUniformSize != 0 ) + { + // Make sure uniform fields get laid out properly... + + UniformLayoutInfo fieldInfo( + fieldUniformSize, + fieldPendingTypeLayout->uniformAlignment); + + LayoutSize uniformOffset = m_rules->AddStructField( + m_pendingDataTypeLayoutBuilder.getStructLayoutInfo(), + fieldInfo); + + fieldPendingVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + + m_pendingDataTypeLayoutBuilder.getTypeLayout()->fields.add(fieldPendingVarLayout); + } + + } // Add a "simple" parameter that cannot have any user-defined // register or binding modifiers, so that its layout computation @@ -2088,21 +2137,48 @@ static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( return info; } + /// Remove resource usage from `typeLayout` that should only be stored per-entry-point. + /// + /// This is used when constructing the overall layout for an entry point, to make sure + /// that certain kinds of resource usage from the entry point don't "leak" into + /// the resource usage of the overall program. + /// +static void removePerEntryPointParameterKinds( + TypeLayout* typeLayout) +{ + typeLayout->removeResourceUsage(LayoutResourceKind::VaryingInput); + typeLayout->removeResourceUsage(LayoutResourceKind::VaryingOutput); + typeLayout->removeResourceUsage(LayoutResourceKind::ShaderRecord); + typeLayout->removeResourceUsage(LayoutResourceKind::HitAttributes); + typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialObjectParam); + typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialTypeParam); +} + /// Iterate over the parameters of an entry point to compute its requirements. /// static RefPtr<EntryPointLayout> collectEntryPointParameters( - ParameterBindingContext* context, - EntryPoint* entryPoint, - SubstitutionSet typeSubst) + ParameterBindingContext* context, + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) { DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + // If specialization was applied to the entry point, then the side-band + // information that was generated will have a more specialized reference + // to the entry point with generic parameters filled in. We should + // use that version if it is available. + // + if(specializationInfo) + entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef; + + auto entryPointType = DeclRefType::Create(context->getLinkage()->getSessionImpl(), entryPointFuncDeclRef); + // We will take responsibility for creating and filling in // the `EntryPointLayout` object here. // RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout(); entryPointLayout->profile = entryPoint->getProfile(); - entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); + entryPointLayout->entryPoint = entryPointFuncDeclRef; // The entry point layout must be added to the output // program layout so that it can be accessed by reflection. @@ -2114,19 +2190,6 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( // context->entryPointLayout = entryPointLayout; - // Note: this isn't really the best place for this logic to sit, - // but it is the simplest place where we have a direct correspondence - // between a single `EntryPoint` and its matching `EntryPointLayout`, - // so we'll use it. - // - for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() ) - { - SLANG_ASSERT(taggedUnionType); - auto substType = taggedUnionType->Substitute(typeSubst).as<Type>(); - auto typeLayout = createTypeLayout(context->layoutContext, substType); - entryPointLayout->taggedUnionTypeLayouts.add(typeLayout); - } - // We are going to iterate over the entry-point parameters, // and while we do so we will go ahead and perform layout/binding // assignment for two cases: @@ -2149,22 +2212,33 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( ScopeLayoutBuilder scopeBuilder; scopeBuilder.beginLayout(context); auto paramsStructLayout = scopeBuilder.m_structLayout; + paramsStructLayout->type = entryPointType; for( auto& shaderParamInfo : entryPoint->getShaderParams() ) { auto paramDeclRef = shaderParamInfo.paramDeclRef; + // Any generic specialization applied to the entry-point function + // must also be applied to its parameters. + paramDeclRef.substitutions = entryPointFuncDeclRef.substitutions; + // When computing layout for an entry-point parameter, // we want to make sure that the layout context has access // to the existential type arguments (if any) that were // provided for the entry-point existential type parameters (if any). // - context->layoutContext= context->layoutContext - .withExistentialTypeArgs( - entryPoint->getExistentialTypeArgCount(), - entryPoint->getExistentialTypeArgs()) - .withExistentialTypeSlotsOffsetBy( - shaderParamInfo.firstExistentialTypeSlot); + if(specializationInfo) + { + auto& existentialSpecializationArgs = specializationInfo->existentialSpecializationArgs; + auto genericSpecializationParamCount = entryPoint->getGenericSpecializationParamCount(); + + context->layoutContext = context->layoutContext + .withSpecializationArgs( + existentialSpecializationArgs.getBuffer(), + existentialSpecializationArgs.getCount()) + .withSpecializationArgsOffsetBy( + shaderParamInfo.firstSpecializationParamIndex - genericSpecializationParamCount); + } // Any error messages we emit during the process should // refer to the location of this parameter. @@ -2183,7 +2257,6 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( auto paramTypeLayout = computeEntryPointParameterTypeLayout( context, - typeSubst, paramDeclRef, paramVarLayout, state); @@ -2204,6 +2277,26 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( // scopeBuilder.addSimpleParameter(paramVarLayout); } + + // We don't want certain kinds of resource usage within an entry + // point to "leak" into the overall resource usage of the entry + // point and thus lead to offsetting of successive entry points. + // + // For example if we have a vertex and a fragment entry point + // in the some program, and each has one varying input, then + // the both the vertex and fragment varying outputs should have + // a location/index of zero. It would be bad if the fragment + // input (or whichever entry point comes second in the global + // ordering) started at location one, because then it wouldn't + // line up correctly with any vertex stage outputs. + // + // We handle this with a bit of a kludge, by removing the + // particular `LayoutResourceKind`s that are susceptible to + // this problem from the overall resource usage of the entry + // point. + // + removePerEntryPointParameterKinds(scopeBuilder.m_structLayout); + entryPointLayout->parametersLayout = scopeBuilder.endLayout(); // For an entry point with a non-`void` return type, we need to process the @@ -2212,7 +2305,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( // TODO: Ideally we should make the layout process more robust to empty/void // types and apply this logic unconditionally. // - auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as<Type>(); + auto resultType = GetResultType(entryPointFuncDeclRef); SLANG_ASSERT(resultType); if( !resultType->Equals(resultType->getSession()->getVoidType()) ) @@ -2226,7 +2319,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( auto resultTypeLayout = processEntryPointVaryingParameterDecl( context, entryPointFuncDeclRef.getDecl(), - resultType->Substitute(typeSubst).as<Type>(), + resultType, state, resultLayout); @@ -2248,136 +2341,257 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( return entryPointLayout; } - /// Remove resource usage from `typeLayout` that should only be stored per-entry-point. - /// - /// This is used when constructing the layout for an entry point group, to make sure - /// that certain kinds of resource usage from the entry point don't "leak" into - /// the resource usage of the group. - /// -static void removePerEntryPointParameterKinds( - TypeLayout* typeLayout) + /// Visitor used by `collectGlobalGenericArguments` +struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor { - typeLayout->removeResourceUsage(LayoutResourceKind::VaryingInput); - typeLayout->removeResourceUsage(LayoutResourceKind::VaryingOutput); - typeLayout->removeResourceUsage(LayoutResourceKind::ShaderRecord); - typeLayout->removeResourceUsage(LayoutResourceKind::HitAttributes); - typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialObjectParam); - typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialTypeParam); -} + CollectGlobalGenericArgumentsVisitor( + ParameterBindingContext* context) + : m_context(context) + {} -static void collectParameters( - ParameterBindingContext* inContext, - Program* program) -{ - // All of the parameters in translation units directly - // referenced in the compile request are part of one - // logical namespace/"linkage" so that two parameters - // with the same name should represent the same - // parameter, and get the same binding(s) + ParameterBindingContext* m_context; - ParameterBindingContext contextData = *inContext; - auto context = &contextData; - context->stage = Stage::Unknown; + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(specializationInfo); + } - auto globalGenericSubst = program->getGlobalGenericSubstitution(); + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(module); - // We will start by looking for any global generic type parameters. + if(!specializationInfo) + return; - for(RefPtr<Module> module : program->getModuleDependencies()) - { - for( auto genParamDecl : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>() ) + for(auto& globalGenericArg : specializationInfo->genericArgs) { - collectGlobalGenericParameter(context, genParamDecl); + if(auto globalGenericTypeParamDecl = as<GlobalGenericParamDecl>(globalGenericArg.paramDecl)) + { + m_context->shared->programLayout->globalGenericArgs.Add(globalGenericTypeParamDecl, globalGenericArg.argVal); + } } } - // Once we have enumerated global generic type parameters, we can - // begin enumerating shader parameters, starting at the global scope. - // - // Because we have already enumerated the global generic type parameters, - // we will be able to look up the index of a global generic type parameter - // when we see it referenced in the type of one of the shader parameters. + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } - for(auto& globalParamInfo : program->getShaderParams() ) + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE { - // When computing layout for a global shader parameter, - // we want to make sure that the layout context has access - // to the existential type arguments (if any) that were - // provided for the global existential type parameters (if any). - // - context->layoutContext= context->layoutContext - .withExistentialTypeArgs( - program->getExistentialTypeArgCount(), - program->getExistentialTypeArgs()) - .withExistentialTypeSlotsOffsetBy( - globalParamInfo.firstExistentialTypeSlot); + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); + } - collectGlobalScopeParameter(context, globalParamInfo, globalGenericSubst); + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // TODO: Need to do something in this case... + SLANG_UNUSED(legacy); + SLANG_UNUSED(specializationInfo); } +}; + + /// Collect an ordered list of all the specialization arguments given for global generic specialization parameters in `program`. + /// + /// This information is used to accelerate the process of mapping a global generic type + /// to its definition during type layout. + /// +static void collectGlobalGenericArguments( + ParameterBindingContext* context, + ComponentType* program) +{ + CollectGlobalGenericArgumentsVisitor visitor(context); + program->acceptVisitor(&visitor, nullptr); +} - // Next consider parameters for entry points - for( auto entryPointGroup : program->getEntryPointGroups() ) + /// Collect information about the (unspecialized) specialization parameters of `program` into `context`. + /// + /// This function computes the reflection/layout for for the specialization parameters, so + /// that they can be exposed to the API user. + /// +static void collectSpecializationParams( + ParameterBindingContext* context, + ComponentType* program) +{ + auto specializationParamCount = program->getSpecializationParamCount(); + for(Index ii = 0; ii < specializationParamCount; ++ii) { - RefPtr<EntryPointGroupLayout> entryPointGroupLayout = new EntryPointGroupLayout(); - entryPointGroupLayout->group = entryPointGroup; + auto specializationParam = program->getSpecializationParam(ii); + switch(specializationParam.flavor) + { + case SpecializationParam::Flavor::GenericType: + case SpecializationParam::Flavor::GenericValue: + { + RefPtr<GenericSpecializationParamLayout> paramLayout = new GenericSpecializationParamLayout(); + paramLayout->decl = specializationParam.object.as<Decl>(); + context->shared->programLayout->specializationParams.add(paramLayout); + } + break; - context->shared->programLayout->entryPointGroups.add(entryPointGroupLayout); + case SpecializationParam::Flavor::ExistentialType: + case SpecializationParam::Flavor::ExistentialValue: + { + RefPtr<ExistentialSpecializationParamLayout> paramLayout = new ExistentialSpecializationParamLayout(); + paramLayout->type = specializationParam.object.as<Type>(); + context->shared->programLayout->specializationParams.add(paramLayout); + } + break; + default: + SLANG_UNEXPECTED("unhandled specialization parameter flavor"); + break; + } + } +} + + /// Visitor used by `collectParameters()` +struct CollectParametersVisitor : ComponentTypeVisitor +{ + CollectParametersVisitor( + ParameterBindingContext* context) + : m_context(context) + {} - ScopeLayoutBuilder scopeBuilder; - scopeBuilder.beginLayout(context); - auto entryPointGroupParamsStructLayout = scopeBuilder.m_structLayout; + ParameterBindingContext* m_context; - // First lay out any shader parameters that belong to the group - // itself, rather than to its nested entry points. + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // The parameters of a composite component type can + // be determined by just visiting its children in order. // - // This ensures that looking up one of the parameters of the - // group by index in its parameters truct will Just Work. + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + // The parameters of a specialized component type + // are just those of its base component type, with + // appropriate specialization information passed + // along. // - for( auto groupParam : entryPointGroup->getShaderParams() ) + visitChildren(specialized); + + // While we are at it, we will also make note of any + // tagged-union types that were used as part of the + // specialization arguments, since we need to make + // sure that their layout information is computed + // and made available for IR code generation. + // + // Note: this isn't really the best place for this logic to sit, + // but it is the simplest place where we can collect all the tagged + // union types that get referenced by a program. + // + for( auto taggedUnionType : specialized->getTaggedUnionTypes() ) { - auto paramDeclRef = groupParam.paramDeclRef; - auto paramType = GetType(paramDeclRef); + SLANG_ASSERT(taggedUnionType); + auto substType = taggedUnionType; + auto typeLayout = createTypeLayout(m_context->layoutContext, substType); + m_context->shared->programLayout->taggedUnionTypeLayouts.add(typeLayout); + } + } - RefPtr<VarLayout> paramVarLayout = new VarLayout(); - paramVarLayout->varDecl = paramDeclRef; - auto paramTypeLayout = createTypeLayout( - context->layoutContext.with(context->getRulesFamily()->getConstantBufferRules()), - paramType); - paramVarLayout->typeLayout = paramTypeLayout; + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // An entry point is a leaf case. + // + // In our current model an entry point does not introduce + // any global shader parameters, but in practice it effectively + // acts a lot like a single global shader parameter named after + // the entry point and with a `struct` type that combines + // all the `uniform` entry point parameters. + // + // Later passes will need to make sure that the entry point + // gets enumerated in the right order relative to any global + // shader parameters. + // - scopeBuilder.addSimpleParameter(paramVarLayout); - } + ParameterBindingContext contextData = *m_context; + auto context = &contextData; + context->stage = entryPoint->getStage(); - for(auto entryPoint : entryPointGroup->getEntryPoints()) - { - // Note: we do not want the entry point group to accumulate - // locations for varying input/output parameters: those - // should be specific to each entry point. - // - // We address this issue by manually removing any - // layout information for the relevant resource kinds - // from the group's layout before adding the parameters - // of any entry point for layout. - // - removePerEntryPointParameterKinds(scopeBuilder.m_structLayout); + collectEntryPointParameters(context, entryPoint, specializationInfo); + } - context->stage = entryPoint->getStage(); - auto entryPointLayout = collectEntryPointParameters(context, entryPoint, globalGenericSubst); + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // A single module represents a leaf case for layout. + // + // We will enumerate the (global) shader parameters declared + // in the module and add each to our canonical ordering. + // + auto paramCount = module->getShaderParamCount(); - auto entryPointParamsLayout = entryPointLayout->parametersLayout; - auto entryPointParamsTypeLayout = entryPointParamsLayout->typeLayout; + ExpandedSpecializationArg* specializationArgs = specializationInfo + ? specializationInfo->existentialArgs.getBuffer() + : nullptr; - scopeBuilder.addSimpleParameter(entryPointParamsLayout); + for(Index pp = 0; pp < paramCount; ++pp) + { + auto shaderParamInfo = module->getShaderParam(pp); + if(specializationArgs) + { + m_context->layoutContext = m_context->layoutContext.withSpecializationArgs( + specializationArgs, + shaderParamInfo.specializationParamCount); + specializationArgs += shaderParamInfo.specializationParamCount; + } - entryPointGroupLayout->entryPoints.add(entryPointLayout); + collectGlobalScopeParameter(m_context, shaderParamInfo, SubstitutionSet()); } - removePerEntryPointParameterKinds(scopeBuilder.m_structLayout); + } + + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // A legacy program is also a leaf case, and we + // can enumerate its parameters directly. + // + // Note: there is a mismatch here where we really + // ought to be tracking specialization arguments + // for a `LegacyProgram` akin to how they are + // tracked for a `Module`, but right now we try + // to do it like a `CompositeComponentType`. + // As a result we are just ignoring specialization + // information here, which will lead to incorrect + // results if somebody every uses specialization + // together with the "legacy" program case. + // + // TODO: eliminate this problem by getting rid of + // `LegacyProgram`, rather than spend time trying + // to make this corner case actually work. + // + SLANG_UNUSED(specializationInfo); - entryPointGroupLayout->parametersLayout = scopeBuilder.endLayout(); + auto paramCount = legacy->getShaderParamCount(); + for(Index pp = 0; pp < paramCount; ++pp) + { + collectGlobalScopeParameter(m_context, legacy->getShaderParam(pp), SubstitutionSet()); + } } - context->entryPointLayout = nullptr; +}; + + /// Recursively collect the global shader parameters and entry points in `program`. + /// + /// This function is used to establish the global ordering of parameters and + /// entry points used for layout. + /// +static void collectParameters( + ParameterBindingContext* inContext, + ComponentType* program) +{ + // All of the parameters in translation units directly + // referenced in the compile request are part of one + // logical namespace/"linkage" so that two parameters + // with the same name should represent the same + // parameter, and get the same binding(s) + + ParameterBindingContext contextData = *inContext; + auto context = &contextData; + context->stage = Stage::Unknown; + + CollectParametersVisitor visitor(context); + program->acceptVisitor(&visitor, nullptr); } /// Emit a diagnostic about a uniform parameter at global scope. @@ -2418,6 +2632,302 @@ static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingConte return numUsed; } + /// Keep track of the running global counter for entry points and global parameters visited. + /// + /// Because of explicit `register` and `[[vk::binding(...)]]` support, parameter binding + /// needs to proceed in multiple passes, and each pass must both visit the things that + /// need layout (parameters and entry points) in the same order in each pass, and must + /// also be able to look up the side-band information that flows between passes. + /// + /// Currently the `ParameterBindingContext` keeps separate arrays for global shader + /// parameters and entry points, but in the global ordering for layout they can be + /// interleaved. There is also no simple tracking structure that relates a global + /// parameter or entry point to its index in those arrays. Instead, we just keep + /// running counters during our passes over the program so that we can easily + /// compute the linear index of each entry point and global parameter as it + /// is encountered. + /// +struct ParameterBindingVisitorCounters +{ + Index entryPointCounter = 0; + Index globalParamCounter = 0; +}; + + /// Recursive routine to "complete" all binding for parameters and entry points in `componentType`. + /// + /// This includes allocation of as-yet-unused register/binding ranges to parameters (which + /// will then affect the ranges of registers/bindings that are available to subsequent + /// parameters), and imporantly *also* includes allocate of space to any "pending" + /// data for interface/existential type parameters/fields. + /// +static void _completeBindings( + ParameterBindingContext* context, + ComponentType* componentType, + ParameterBindingVisitorCounters* ioCounters); + + /// A visitor used by `_completeBindings`. + /// + /// This visitor walks the structure of a `ComponentType` to ensure that + /// any shader parameters (and entry points) it contains that *don't* + /// have explicit bindings on them get allocated registers/bindings + /// as appropriate. + /// + /// The main complication of this visitor is how it handles the + /// `SpecializedComponentType` case, because a specialized component + /// type needs to be handled as an atomic unit that lays out the + /// same in all contexts. + /// +struct CompleteBindingsVisitor : ComponentTypeVisitor +{ + CompleteBindingsVisitor(ParameterBindingContext* context, ParameterBindingVisitorCounters* counters) + : m_context(context) + , m_counters(counters) + {} + + ParameterBindingContext* m_context; + ParameterBindingVisitorCounters* m_counters; + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(specializationInfo); + + // We compute the index of the entry point in the global ordering, + // so we can look up the tracking data in our context. As a result + // we don't actually make use of the parameters that were passed in. + // + auto globalEntryPointIndex = m_counters->entryPointCounter++; + auto globalEntryPointInfo = m_context->shared->programLayout->entryPoints[globalEntryPointIndex]; + + + // We mostly treat an entry point like a single shader parameter that + // uses its `parametersLayout`. + // + completeBindingsForParameter(m_context, globalEntryPointInfo->parametersLayout); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + // A module is a leaf case: we just want to visit each parameter. + visitLeafParams(module); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + // A legacy program is a leaf case: we just want to visit each parameter. + visitLeafParams(legacy); + } + + void visitLeafParams(ComponentType* componentType) + { + auto paramCount = componentType->getShaderParamCount(); + for(Index ii = 0; ii < paramCount; ++ii) + { + auto globalParamIndex = m_counters->globalParamCounter++; + auto globalParamInfo = m_context->shared->parameters[globalParamIndex]; + + completeBindingsForParameter(m_context, globalParamInfo); + } + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // We just wnat to recurse on the children of the composite in order. + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + // The handling of a specialized component type here is subtle. + // + // We do *not* simply recurse on the base component type. + // Doing so would ensure that the parameters would get + // registers/bindings allocated to them, but it wouldn't + // allocate space for the "pending" data related to + // existential/interface parameters. + // + // Instead, we recursive through `_completeBindings`, + // which has the job of allocating space for the parameters, + // and then for any "pending" data required. + // + // Handling things this way ensures that a particular + // `SpecializedComponentType` gets laid out exactly + // the same wherever it gets used, rather than + // getting laid out differently when it is placed + // into different compositions. + // + auto base = specialized->getBaseComponentType(); + _completeBindings(m_context, base, m_counters); + } +}; + + /// A visitor used by `_completeBindings`. + /// + /// This visitor is used to follow up after the `CompleteBindingsVisitor` + /// any ensure that any "pending" data required by the parameters that + /// got laid out now gets a location. + /// + /// To make a concrete example: + /// + /// Texture2D a; + /// IThing b; + /// Texture2D c; + /// + /// If these parameters were laid out with `b` specialized to a type + /// that contains a single `Texture2D`, then the `CompleteBindingsVisitor` + /// would visit `a`, `b`, and then `c` in order. It would give `a` the + /// first register/binding available (say, `t0`). It would then make + /// a note that due to specialization, `b`, needs a `t` register as well, + /// but it *cannot* be allocated just yet, because doing so would change + /// the location of `c`, so it is marked as "pending." Then `c` would + /// be visited and get `t1`. As a result the registers given to `a` + /// and `c` are independent of how `b` gets specialized. + /// + /// Next, the `FlushPendingDataVisitor` comes through and applies to + /// the parameters again. For `a` there is no pending data, but for + /// `b` there is a pending request for a `t` register, so it gets allocated + /// now (getting `t2`). The `c` parameter then has no pending data, so + /// we are done. + /// + /// *When* the pending data gets flushed is then significant. In general, + /// the order in which modules get composed an specialized is signficaint. + /// The module above (let's call it `M`) has one specialization parameter + /// (for `b`), and if we want to compose it with another module `N` that + /// has no specialization parameters, we could compute either: + /// + /// compose(specialize(M, SomeType), N) + /// + /// or: + /// + /// specialize(compose(M,N), SomeType) + /// + /// In the first case, the "pending" data for `M` gets flushed right after `M`, + /// so that `specialize(M,SomeType)` can have a consistent layout + /// regardless of how it is used. In the second case, the pending data for + /// `M` only gets flushed after `N`'s parameters are allocated, thus guaranteeing + /// that the `compose(M,N)` part has a consistent layout regardless of what + /// type gets plugged in during specialization. + /// + /// There are trade-offs to be made by an application about which approach + /// to prefer, and the compiler supports either policy choice. + /// +struct FlushPendingDataVisitor : ComponentTypeVisitor +{ + FlushPendingDataVisitor(ParameterBindingContext* context, ParameterBindingVisitorCounters* counters) + : m_context(context) + , m_counters(counters) + {} + + ParameterBindingContext* m_context; + ParameterBindingVisitorCounters* m_counters; + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(specializationInfo); + + auto globalEntryPointIndex = m_counters->entryPointCounter++; + auto globalEntryPointInfo = m_context->shared->programLayout->entryPoints[globalEntryPointIndex]; + + // We need to allocate space for any "pending" data that + // appeared in the entry-point parameter list. + // + _allocateBindingsForPendingData(m_context, globalEntryPointInfo->parametersLayout->pendingVarLayout); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + visitLeafParams(module); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + visitLeafParams(legacy); + } + + void visitLeafParams(ComponentType* componentType) + { + // In the "leaf" case we just allocate space for any + // pending data in the parameters, in order. + // + auto paramCount = componentType->getShaderParamCount(); + for(Index ii = 0; ii < paramCount; ++ii) + { + auto globalParamIndex = m_counters->globalParamCounter++; + auto globalParamInfo = m_context->shared->parameters[globalParamIndex]; + auto firstVarLayout = globalParamInfo->varLayouts[0]; + + _allocateBindingsForPendingData(m_context, firstVarLayout->pendingVarLayout); + } + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + // Because `SpecializedComponentType` was a special case for `CompleteBindingsVisitor`, + // it ends up being a special case here too. + // + // The `CompleteBindings...` pass treated a `SpecializedComponentType` + // as an atomic unit. Any "pending" data that came from its parameters + // will already have been dealt with, so it would be incorrect for + // us to recurse into `specialized`. + // + // Instead, we just need to *skip* `specialized`, since it was + // completely handled already. This isn't quite as simple + // as just doing nothing, because our passes are using + // some global counters to find the absolute/linear index + // of each parameter and entry point as it is encountered. + // We will simply bump those counters by the number of + // parameters and entry points contained under `specialized`, + // which is luckily provided by the `ComponentType` API. + // + m_counters->globalParamCounter += specialized->getShaderParamCount(); + m_counters->entryPointCounter += specialized->getEntryPointCount(); + } + +}; + +static void _completeBindings( + ParameterBindingContext* context, + ComponentType* componentType, + ParameterBindingVisitorCounters* ioCounters) +{ + ParameterBindingVisitorCounters savedCounters = *ioCounters; + + CompleteBindingsVisitor completeBindingsVisitor(context, ioCounters); + componentType->acceptVisitor(&completeBindingsVisitor, nullptr); + + FlushPendingDataVisitor flushVisitor(context, &savedCounters); + componentType->acceptVisitor(&flushVisitor, nullptr); +} + + /// "Complete" binding of parametesr in the given `program`. + /// + /// Completing binding involves both assigning registers/bindings + /// to an parameters that didn't get explicit locations, and then + /// also providing locations to any "pending" data that needed + /// space allocated (used for existential/interface type parameters). + /// +static void _completeBindings( + ParameterBindingContext* context, + ComponentType* program) +{ + // The process of completing binding has a recursive structure, + // so we will immediately delegate to a subroutine that handles + // the recursion. + // + ParameterBindingVisitorCounters counters; + _completeBindings(context, program, &counters); +} + RefPtr<ProgramLayout> generateParameterBindings( TargetProgram* targetProgram, DiagnosticSink* sink) @@ -2450,20 +2960,67 @@ RefPtr<ProgramLayout> generateParameterBindings( context.shared = &sharedContext; context.layoutContext = layoutContext; - // Walk through AST to discover all the parameters + // We want to start by finding out what (if anything) has + // been bound to the global generic parameters of the + // program, since we need to know these types to compute + // layout for parameters that use the generic type parameters. + // + collectGlobalGenericArguments(&context, program); + + // Next we want to collect a full listing of all the shader + // parameters that need to be considered for layout, along + // with all of the entry points, which also need their + // parameters laid out and thus act pretty much like global + // parameters themselves. + // collectParameters(&context, program); - // Now walk through the parameters to generate initial binding information + // We will also collect basic information on the specialization + // parameters exposed by the program. + // + // Whereas `collectGlobalGenericArguments` was collecting the + // concrete types that have been plugged into specialization + // parameters, this step is about collecting the *unspecialized* + // parameters (if any) for the purposes of reflection. + // + collectSpecializationParams(&context, program); + + // Once we have a canonical list of all the shader parameters + // (and entry points) in need of layout, we will walk through + // the parameters that might have explicit binding annotations, + // and "reserve" the registers/bindings/etc. that those parameters + // declare so that subequent automatic layout steps do not try to + // overlap them. + // + // Along the way we will issue diagnostics if there appear to + // be overlapping, conflicting, or inconsistent explicit bindings. + // + // Note that we do *not* support explicit binding annotations + // on entry point parameters, so we only consider global shader + // parameters here. + // + // (Also note that explicit bindings end up being the main + // source of complexity in the layout system, and we could greatly + // simplify this file by eliminating support for explicit + // binding in the future) + // for( auto& parameter : sharedContext.parameters ) { generateParameterBindings(&context, parameter); } - // Determine if there are any global-scope parameters that use `Uniform` - // resources, and thus need to get packaged into a constant buffer. + // Once we have a canonical list of all the parameters, we can + // detect if there are any global-scope parameters that make use + // of `LayoutResourceKind::Uniform`, since such parameters would + // need to be packaged into a "default" constant buffer. + // The fxc/dxc compilers support this step, and in reflection + // refer to the generated constant buffer as `$Globals`. + // + // Note that this logic doesn't account for the existance of + // "legacy" (non-buffer-bound) uniforms in GLSL for OpenGL. + // If we wanted to support legaqcy uniforms we would probably + // want to do so through a different feature. // - // Note: this doesn't account for GLSL's support for "legacy" uniforms - // at global scope, which don't get assigned a CB. bool needDefaultConstantBuffer = false; for( auto& parameterInfo : sharedContext.parameters ) { @@ -2525,6 +3082,32 @@ RefPtr<ProgramLayout> generateParameterBindings( break; } } + + // We also need a default space for any entry-point parameters + // that consume appropriate resource kinds. + // + for(auto& entryPoint : sharedContext.programLayout->entryPoints) + { + auto paramsLayout = entryPoint->parametersLayout; + for(auto resInfo : paramsLayout->resourceInfos ) + { + switch(resInfo.kind) + { + default: + break; + + case LayoutResourceKind::RegisterSpace: + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + case LayoutResourceKind::HitAttributes: + case LayoutResourceKind::RayPayload: + continue; + } + + needDefaultSpace = true; + break; + } + } } // If we need a space for default bindings, then allocate it here. @@ -2575,12 +3158,14 @@ RefPtr<ProgramLayout> generateParameterBindings( &context, needDefaultConstantBuffer); - // Now walk through again to actually give everything - // ranges of registers... - for( auto& parameter : sharedContext.parameters ) - { - completeBindingsForParameter(&context, parameter); - } + // Now that all of the explicit bindings have been dealt with + // and we've also allocate any space/buffer that is required + // for global-scope parameters, we will go through the + // shader parameters and entry points yet again, in order + // to actually allocate specific bindings/registers to + // parameters and entry points that need them. + // + _completeBindings(&context, program); // Next we need to create a type layout to reflect the information // we have collected, and we will use the `ScopeLayoutBuilder` @@ -2602,104 +3187,6 @@ RefPtr<ProgramLayout> generateParameterBindings( cbInfo->index = globalConstantBufferBinding.index; } - // After we have laid out all the ordinary global parameters, - // we need to "flush" out any pending data that was associated with - // the global scope as part of dealing with interface-type parameters. - // - _allocateBindingsForPendingData(&context, globalScopeVarLayout->pendingVarLayout); - - // After we have allocated registers/bindings to everything - // in the global scope we will process the parameters - // of the entry points. - // - // Note: at the moment we are laying out *all* information related to global-scope - // parameters (including pending data from interface-type parameters) before - // anything pertaining to entry points. This is a crucial design choice to - // get right, and we might want to revisit it based on experience. - - // In some cases, a user will want to ensure that all the - // entry points they compile get non-overlapping - // registers/bindings. E.g., if you have a vertex and fragment - // shader being compiled together for Vulkan, you probably want distinct - // bindings for their entry-point `uniform` parametres, so - // that they can be used together. - // - // In other cases, however, a user probably doesn't want us - // to conservatively allocate non-overlapping bindings. - // E.g., if they have a bunch of compute shaders in a single - // file, then they probably want each compute shader to - // compute its parameter layout "from scratch" as if the - // others don't exist. - // - // The way we handle this is by putting the entry points of a - // `Program` into groups, and ensuring that within each group - // we allocate parameters that don't overlap, but we don't - // worry about overlap across groups. - // - for( auto entryPointGroup : sharedContext.programLayout->entryPointGroups ) - { - // We save off the allocation state as it was before the entry-point - // group, so that we can restore it to this state after each group. - // - // TODO: We probably ought to wrap all the state relevant to allocation - // of registers/bindings into a single struct/field so that we only - // have one thing to save/restore here even if new state gets added. - // - auto savedGlobalSpaceUsedRangeSets = sharedContext.globalSpaceUsedRangeSets; - auto savedUsedSpaces = sharedContext.usedSpaces; - - // The group will have been allocated a layout that combines the - // usage of all of the contained entry-points, so we just need to - // allocate the entry-point group to be placed after all the global-scope - // parameters. - // - auto entryPointGroupParamsLayout = entryPointGroup->parametersLayout; - completeBindingsForParameter(&context, entryPointGroupParamsLayout); - - _allocateBindingsForPendingData(&context, entryPointGroupParamsLayout->pendingVarLayout); - - // TODO: Should we add the offset information from the group to - // the layout information for each entry point (and thence to its parameters)? - // - // This seems important if we want to allow clients to conveniently - // ignore groups when doing their reflection queries. - - // Once we've allocated bindigns for the parameters of entry points - // in the group, we restore the state for tracking what register/bindings - // are used to where it was before. - // - sharedContext.globalSpaceUsedRangeSets = savedGlobalSpaceUsedRangeSets; - sharedContext.usedSpaces = savedUsedSpaces; - } - - // HACK: we want global parameters to not have to deal with offsetting - // by the `VarLayout` stored in `globalScopeVarLayout`, so we will scan - // through and for any global parameter that used "pending" data, we will manually - // offset all of its resource infos to account for where the global pending data - // got placed. - // - // TODO: A more appropriate solution would be to pass the `globalScopeVarLayout` - // down into the pass that puts layout information onto global parameters in - // the IR, and apply the offsetting there. - // - for( auto& parameterInfo : sharedContext.parameters ) - { - for( auto varLayout : parameterInfo->varLayouts ) - { - auto pendingVarLayout = varLayout->pendingVarLayout; - if(!pendingVarLayout) continue; - - for( auto& resInfo : pendingVarLayout->resourceInfos ) - { - if( auto globalResInfo = globalScopeVarLayout->pendingVarLayout->FindResourceInfo(resInfo.kind) ) - { - resInfo.index += globalResInfo->index; - resInfo.space += globalResInfo->space; - } - } - } - } - programLayout->parametersLayout = globalScopeVarLayout; { @@ -2723,7 +3210,7 @@ ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink) } void generateParameterBindings( - Program* program, + ComponentType* program, TargetRequest* targetReq, DiagnosticSink* sink) { diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp index 3871e2ccc..5e6fd10cd 100644 --- a/source/slang/slang-reflection.cpp +++ b/source/slang/slang-reflection.cpp @@ -51,9 +51,9 @@ static inline SlangReflectionTypeLayout* convert(TypeLayout* type) return (SlangReflectionTypeLayout*) type; } -static inline GenericParamLayout* convert(SlangReflectionTypeParameter * typeParam) +static inline SpecializationParamLayout* convert(SlangReflectionTypeParameter * typeParam) { - return (GenericParamLayout*)typeParam; + return (SpecializationParamLayout*) typeParam; } static inline VarDeclBase* convert(SlangReflectionVariable* var) @@ -86,16 +86,6 @@ static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint) return (SlangReflectionEntryPoint*) entryPoint; } -static inline EntryPointGroupLayout* convert(SlangEntryPointGroupLayout* entryPointGroup) -{ - return (EntryPointGroupLayout*) entryPointGroup; -} - -static inline SlangEntryPointGroupLayout* convert(EntryPointGroupLayout* entryPointGroup) -{ - return (SlangEntryPointGroupLayout*) entryPointGroup; -} - static inline ProgramLayout* convert(SlangReflection* program) { return (ProgramLayout*) program; @@ -865,7 +855,7 @@ SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLay if(auto genericParamTypeLayout = as<GenericParamTypeLayout>(typeLayout)) { - return genericParamTypeLayout->paramIndex; + return (int) genericParamTypeLayout->paramIndex; } else { @@ -1222,7 +1212,7 @@ SLANG_API char const* spReflectionEntryPoint_getName( auto entryPointLayout = convert(inEntryPoint); if(!entryPointLayout) return 0; - return getText(entryPointLayout->entryPoint->getName()).begin(); + return getText(entryPointLayout->entryPoint.GetName()).begin(); } SLANG_API unsigned spReflectionEntryPoint_getParameterCount( @@ -1270,7 +1260,7 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( SlangUInt sizeAlongAxis[3] = { 1, 1, 1 }; // First look for the HLSL case, where we have an attribute attached to the entry point function - auto numThreadsAttribute = entryPointFunc->FindModifier<NumThreadsAttribute>(); + auto numThreadsAttribute = entryPointFunc.getDecl()->FindModifier<NumThreadsAttribute>(); if (numThreadsAttribute) { sizeAlongAxis[0] = numThreadsAttribute->x; @@ -1281,7 +1271,7 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( { // Fall back to the GLSL case, which requires a search over global-scope declarations // to look for as with the `local_size_*` qualifier - auto module = as<ModuleDecl>(entryPointFunc->ParentDecl); + auto module = as<ModuleDecl>(entryPointFunc.getDecl()->ParentDecl); if (module) { for (auto dd : module->Members) @@ -1323,11 +1313,43 @@ SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( return (entryPointLayout->flags & EntryPointLayout::Flag::usesAnySampleRateInput) != 0; } +SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getVarLayout( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) + return nullptr; + + return convert(entryPointLayout->parametersLayout); +} + +static bool hasDefaultConstantBuffer(ScopeLayout* layout) +{ + auto typeLayout = layout->parametersLayout->getTypeLayout(); + return as<ParameterGroupTypeLayout>(typeLayout) != nullptr; +} + +SLANG_API int spReflectionEntryPoint_hasDefaultConstantBuffer( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) + return 0; + + return hasDefaultConstantBuffer(entryPointLayout); +} + + // SlangReflectionTypeParameter SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter * inTypeParam) { - auto typeParam = convert(inTypeParam); - return typeParam->decl->getName()->text.getBuffer(); + auto specializationParam = convert(inTypeParam); + if( auto genericParamLayout = as<GenericSpecializationParamLayout>(specializationParam) ) + { + return genericParamLayout->decl->getName()->text.getBuffer(); + } + // TODO: Add case for existential type parameter? They don't have as simple of a notion of "name" as the generic case... + return nullptr; } SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter * inTypeParam) @@ -1338,16 +1360,34 @@ SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParamet SLANG_API unsigned int spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* inTypeParam) { - auto typeParam = convert(inTypeParam); - auto constraints = typeParam->decl->getMembersOfType<GenericTypeConstraintDecl>(); - return (unsigned int)constraints.getCount(); + auto specializationParam = convert(inTypeParam); + if(auto genericParamLayout = as<GenericSpecializationParamLayout>(specializationParam)) + { + if( auto globalGenericParamDecl = as<GlobalGenericParamDecl>(genericParamLayout->decl) ) + { + auto constraints = globalGenericParamDecl->getMembersOfType<GenericTypeConstraintDecl>(); + return (unsigned int)constraints.getCount(); + } + // TODO: Add case for entry-point generic parameters. + } + // TODO: Add case for existential type parameters. + return 0; } SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter * inTypeParam, unsigned index) { - auto typeParam = convert(inTypeParam); - auto constraints = typeParam->decl->getMembersOfType<GenericTypeConstraintDecl>(); - return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr(); + auto specializationParam = convert(inTypeParam); + if(auto genericParamLayout = as<GenericSpecializationParamLayout>(specializationParam)) + { + if( auto globalGenericParamDecl = as<GlobalGenericParamDecl>(genericParamLayout->decl) ) + { + auto constraints = globalGenericParamDecl->getMembersOfType<GenericTypeConstraintDecl>(); + return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr(); + } + // TODO: Add case for entry-point generic parameters. + } + // TODO: Add case for existential type parameters. + return 0; } // Shader Reflection @@ -1379,22 +1419,32 @@ SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflec SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection * reflection) { auto program = convert(reflection); - return (unsigned int)program->globalGenericParams.getCount(); + return (unsigned int) program->specializationParams.getCount(); } SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection * reflection, unsigned int index) { auto program = convert(reflection); - return (SlangReflectionTypeParameter*)program->globalGenericParams[index].Ptr(); + return (SlangReflectionTypeParameter*) program->specializationParams[index].Ptr(); } SLANG_API SlangReflectionTypeParameter * spReflection_FindTypeParameter(SlangReflection * inProgram, char const * name) { auto program = convert(inProgram); if (!program) return nullptr; - GenericParamLayout * result = nullptr; - program->globalGenericParamsMap.TryGetValue(name, result); - return (SlangReflectionTypeParameter*)result; + for( auto& param : program->specializationParams ) + { + auto genericParamLayout = as<GenericSpecializationParamLayout>(param); + if(!genericParamLayout) + continue; + + if(getText(genericParamLayout->decl->getName()) != UnownedTerminatedStringSlice(name)) + continue; + + return (SlangReflectionTypeParameter*) genericParamLayout; + } + + return 0; } SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram) @@ -1421,7 +1471,7 @@ SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangRefl // TODO: improve on naive linear search for(auto ep : program->entryPoints) { - if(ep->entryPoint->getName()->text == name) + if(ep->entryPoint.GetName()->text == name) { return convert(ep); } @@ -1430,75 +1480,6 @@ SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangRefl return nullptr; } - -SLANG_API SlangInt spReflection_getEntryPointGroupCount(SlangReflection* inProgram) -{ - auto program = convert(inProgram); - if(!program) return 0; - - return program->entryPointGroups.getCount(); -} - -SLANG_API SlangEntryPointGroupLayout* spReflection_getEntryPointGroupByIndex(SlangReflection* inProgram, SlangInt index) -{ - auto program = convert(inProgram); - if(!program) return 0; - - if(index < 0) return nullptr; - if(index >= program->entryPointGroups.getCount()) return nullptr; - - return convert(program->entryPointGroups[(int) index].Ptr()); -} - -SLANG_API SlangInt spEntryPointGroupLayout_getEntryPointCount(SlangEntryPointGroupLayout* inGroup) -{ - auto group = convert(inGroup); - if(!group) return 0; - - return group->entryPoints.getCount(); -} - -SLANG_API SlangReflectionEntryPoint* spEntryPointGroupLayout_getEntryPointByIndex(SlangEntryPointGroupLayout* inGroup, SlangInt index) -{ - auto group = convert(inGroup); - if(!group) return 0; - - if(index < 0) return nullptr; - if(index >= group->entryPoints.getCount()) return nullptr; - - return convert(group->entryPoints[(int) index].Ptr()); -} - -SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getVarLayout(SlangEntryPointGroupLayout* inGroup) -{ - auto group = convert(inGroup); - if(!group) return 0; - - return convert(group->parametersLayout); -} - -SLANG_API SlangInt spEntryPointGroupLayout_getParameterCount(SlangEntryPointGroupLayout* inGroup) -{ - auto groupLayout = convert(inGroup); - if(!groupLayout) return 0; - - auto& params = groupLayout->group->getShaderParams(); - return params.getCount(); -} - -SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getParameterByIndex(SlangEntryPointGroupLayout* inGroup, SlangInt index) -{ - auto groupLayout = convert(inGroup); - if(!groupLayout) return nullptr; - - auto& params = groupLayout->group->getShaderParams(); - if(index < 0) return nullptr; - if(index >= params.getCount()) return nullptr; - - return convert(getScopeStructLayout(groupLayout)->fields[index]); -} - - SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram) { auto program = convert(inProgram); @@ -1531,7 +1512,7 @@ SLANG_API SlangReflectionType* spReflection_specializeType( auto unspecializedType = convert(inType); if(!unspecializedType) return nullptr; - auto linkage = programLayout->getProgram()->getLinkageImpl(); + auto linkage = programLayout->getProgram()->getLinkage(); DiagnosticSink sink(linkage->getSourceManager()); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 08d671241..c4152d78c 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -2766,10 +2766,10 @@ String ExistentialSpecializedType::ToString() String result; result.append("__ExistentialSpecializedType("); result.append(baseType->ToString()); - for( auto arg : slots.args ) + for( auto arg : args ) { result.append(", "); - result.append(arg.type->ToString()); + result.append(arg.val->ToString()); } result.append(")"); return result; @@ -2784,16 +2784,19 @@ bool ExistentialSpecializedType::EqualsImpl(Type * type) if(!baseType->Equals(other->baseType)) return false; - auto argCount = slots.args.getCount(); - if(argCount != other->slots.args.getCount()) + auto argCount = args.getCount(); + if(argCount != other->args.getCount()) return false; for( Index ii = 0; ii < argCount; ++ii ) { - if(!slots.args[ii].type->Equals(other->slots.args[ii].type)) + auto arg = args[ii]; + auto otherArg = other->args[ii]; + + if(!arg.val->EqualsVal(otherArg.val)) return false; - if(!slots.args[ii].witness->EqualsVal(other->slots.args[ii].witness)) + if(!areValsEqual(arg.witness, otherArg.witness)) return false; } return true; @@ -2803,51 +2806,63 @@ int ExistentialSpecializedType::GetHashCode() { Hasher hasher; hasher.hashObject(baseType); - for(auto arg : slots.args) + for(auto arg : args) { - hasher.hashObject(arg.type); - hasher.hashObject(arg.witness); + hasher.hashObject(arg.val); + if(auto witness = arg.witness) + hasher.hashObject(witness); } return hasher.getResult(); } +RefPtr<Val> getCanonicalValue(Val* val) +{ + if(!val) + return nullptr; + if(auto type = as<Type>(val)) + { + return type->GetCanonicalType(); + } + // TODO: We may eventually need/want some sort of canonicalization + // for non-type values, but for now there is nothing to do. + return val; +} + RefPtr<Type> ExistentialSpecializedType::CreateCanonicalType() { RefPtr<ExistentialSpecializedType> canType = new ExistentialSpecializedType(); canType->setSession(getSession()); canType->baseType = baseType->GetCanonicalType(); - for( auto paramType : slots.paramTypes ) + for( auto arg : args ) { - canType->slots.paramTypes.add( paramType->GetCanonicalType() ); - } - for( auto arg : slots.args ) - { - ExistentialTypeSlots::Arg canArg; - canArg.type = arg.type->GetCanonicalType(); - canArg.witness = arg.witness; - canType->slots.args.add(canArg); + ExpandedSpecializationArg canArg; + canArg.val = getCanonicalValue(arg.val); + canArg.witness = getCanonicalValue(arg.witness); + canType->args.add(canArg); } return canType; } +RefPtr<Val> substituteImpl(Val* val, SubstitutionSet subst, int* ioDiff) +{ + if(!val) return nullptr; + return val->SubstituteImpl(subst, ioDiff); +} + RefPtr<Val> ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) { int diff = 0; auto substBaseType = baseType->SubstituteImpl(subst, &diff).as<Type>(); - ExistentialTypeSlots substSlots; - for( auto paramType : slots.paramTypes ) - { - substSlots.paramTypes.add( paramType->SubstituteImpl(subst, &diff).as<Type>() ); - } - for( auto arg : slots.args ) + ExpandedSpecializationArgs substArgs; + for( auto arg : args ) { - ExistentialTypeSlots::Arg substArg; - substArg.type = arg.type->SubstituteImpl(subst, &diff).as<Type>(); - substArg.witness = arg.witness->SubstituteImpl(subst, &diff); - substSlots.args.add(substArg); + ExpandedSpecializationArg substArg; + substArg.val = Slang::substituteImpl(arg.val, subst, &diff); + substArg.witness = Slang::substituteImpl(arg.witness, subst, &diff); + substArgs.add(substArg); } if(!diff) @@ -2858,7 +2873,7 @@ RefPtr<Val> ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, in RefPtr<ExistentialSpecializedType> substType = new ExistentialSpecializedType(); substType->setSession(getSession()); substType->baseType = substBaseType; - substType->slots = substSlots; + substType->args = substArgs; return substType; } diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index d156750be..88a2ca847 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -1132,31 +1132,31 @@ namespace Slang typedef Dictionary<unsigned int, RefPtr<RefObject>> AttributeArgumentValueDict; - /// Collects information about existential type parameters and their arguments. - struct ExistentialTypeSlots + struct SpecializationParam { - /// For each type parameter, holds the interface/existential type that constrains it. - List<RefPtr<Type>> paramTypes; - - /// An argument for an existential type parameter. - /// - /// Comprises a concrete type and a witness for its conformance to the desired - /// interface/existential type for the corresponding parameter. - /// - struct Arg + enum class Flavor { - RefPtr<Type> type; - RefPtr<Val> witness; + GenericType, + GenericValue, + ExistentialType, + ExistentialValue, }; + Flavor flavor; + RefPtr<RefObject> object; + }; + typedef List<SpecializationParam> SpecializationParams; - /// Any arguments provided for the existential type parameters. - /// - /// It is possible for `args` to be empty even if `paramTypes` is non-empty; - /// that situation represents an unspecialized program or entry point. - /// - List<Arg> args; + struct SpecializationArg + { + RefPtr<Val> val; }; + typedef List<SpecializationArg> SpecializationArgs; + struct ExpandedSpecializationArg : SpecializationArg + { + RefPtr<Val> witness; + }; + typedef List<ExpandedSpecializationArg> ExpandedSpecializationArgs; // Generate class definition for all syntax classes #define SYNTAX_FIELD(TYPE, NAME) TYPE NAME; diff --git a/source/slang/slang-type-defs.h b/source/slang/slang-type-defs.h index d9907bafe..7afc23411 100644 --- a/source/slang/slang-type-defs.h +++ b/source/slang/slang-type-defs.h @@ -479,7 +479,7 @@ END_SYNTAX_CLASS() SYNTAX_CLASS(ExistentialSpecializedType, Type) RAW( RefPtr<Type> baseType; - ExistentialTypeSlots slots; + ExpandedSpecializationArgs args; virtual String ToString() override; virtual bool EqualsImpl(Type * type) override; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 9b6571372..f76b29a51 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1222,19 +1222,11 @@ EntryPointLayout* EntryPointLayout::getAbsoluteLayout( { adjustedLayout->resultLayout = baseResultLayout->getAbsoluteLayout(parentLayout); } - adjustedLayout->taggedUnionTypeLayouts = this->taggedUnionTypeLayouts; m_absoluteLayout = adjustedLayout; return adjustedLayout; } -EntryPointLayout* EntryPointLayout::getAbsoluteLayout(EntryPointGroupLayout* parentGroup) -{ - SLANG_ASSERT(parentGroup); - return getAbsoluteLayout(parentGroup->parametersLayout); -} - - VarLayout* VarLayout::getAbsoluteLayout(VarLayout* parentAbsoluteLayout) { if( !m_absoluteLayout ) @@ -1933,9 +1925,31 @@ static TypeLayoutResult _createTypeLayout( return _createTypeLayout(subContext, type); } -int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl) +RefPtr<Type> findGlobalGenericSpecializationArg( + TypeLayoutContext const& context, + GlobalGenericParamDecl* decl) { - return (int)genericParameters.findFirstIndex([=](RefPtr<GenericParamLayout> & x) {return x->decl.Ptr() == decl; }); + RefPtr<Val> arg; + context.programLayout->globalGenericArgs.TryGetValue(decl, arg); + return arg.as<Type>(); +} + +Index findGlobalGenericSpecializationParamIndex( + ComponentType* type, + GlobalGenericParamDecl* decl) +{ + Index paramCount = type->getSpecializationParamCount(); + for( Index pp = 0; pp < paramCount; ++pp ) + { + auto param = type->getSpecializationParam(pp); + if(param.flavor != SpecializationParam::Flavor::GenericType) + continue; + if(param.object.Ptr() != decl) + continue; + + return pp; + } + return -1; } // When constructing a new var layout from an existing one, @@ -2393,6 +2407,37 @@ TypeLayoutResult StructTypeLayoutBuilder::getTypeLayoutResult() return TypeLayoutResult(m_typeLayout, m_info); } +static TypeLayoutResult _createTypeLayoutForGlobalGenericTypeParam( + TypeLayoutContext const& context, + Type* type, + GlobalGenericParamDecl* globalGenericParamDecl) +{ + SimpleLayoutInfo info; + info.alignment = 0; + info.size = 0; + info.kind = LayoutResourceKind::GenericResource; + + RefPtr<GenericParamTypeLayout> typeLayout = new GenericParamTypeLayout(); + // we should have already populated ProgramLayout::genericEntryPointParams list at this point, + // so we can find the index of this generic param decl in the list + typeLayout->type = type; + typeLayout->paramIndex = findGlobalGenericSpecializationParamIndex( + context.programLayout->getProgram(), + globalGenericParamDecl); + typeLayout->rules = context.rules; + typeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; + + return TypeLayoutResult(typeLayout, info); +} + +RefPtr<TypeLayout> createTypeLayoutForGlobalGenericTypeParam( + TypeLayoutContext const& context, + Type* type, + GlobalGenericParamDecl* globalGenericParamDecl) +{ + return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl).layout; +} + static TypeLayoutResult _createTypeLayout( TypeLayoutContext const& context, Type* type) @@ -2825,7 +2870,7 @@ static TypeLayoutResult _createTypeLayout( // to all the incoming specialized type slots that haven't already // been consumed/claimed by preceding fields. // - auto fieldLayoutContext = context.withExistentialTypeSlotsOffsetBy(baseExistentialSlotIndex); + auto fieldLayoutContext = context.withSpecializationArgsOffsetBy(baseExistentialSlotIndex); TypeLayoutResult fieldResult = _createTypeLayout( fieldLayoutContext, @@ -2863,22 +2908,25 @@ static TypeLayoutResult _createTypeLayout( return typeLayoutBuilder.getTypeLayoutResult(); } - else if (auto globalGenParam = declRef.as<GlobalGenericParamDecl>()) + else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>()) { - SimpleLayoutInfo info; - info.alignment = 0; - info.size = 0; - info.kind = LayoutResourceKind::GenericResource; - - auto genParamTypeLayout = new GenericParamTypeLayout(); - // we should have already populated ProgramLayout::genericEntryPointParams list at this point, - // so we can find the index of this generic param decl in the list - genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); - genParamTypeLayout->rules = rules; - genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; - - return TypeLayoutResult(genParamTypeLayout, info); + if( auto concreteType = findGlobalGenericSpecializationArg( + context, + globalGenericParamDecl) ) + { + // If we know what concrete type has been used to specialize + // the global generic type parameter, then we should use + // the concrete type instead. + // + return _createTypeLayout(context, concreteType); + } + else + { + // Otherwise we must create a type layout that represents + // the generic type parameter itself. + // + return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl); + } } else if (auto assocTypeParam = declRef.as<AssocTypeDecl>()) { @@ -2934,9 +2982,11 @@ static TypeLayoutResult _createTypeLayout( // If there are any concrete types available, the first one will be // the value that should be plugged into the slot we just introduced. // - if( context.existentialTypeArgCount ) + if( context.specializationArgCount ) { - RefPtr<Type> concreteType = context.existentialTypeArgs[0].type; + auto& specializationArg = context.specializationArgs[0]; + RefPtr<Type> concreteType = specializationArg.val.as<Type>(); + SLANG_ASSERT(concreteType); RefPtr<TypeLayout> concreteTypeLayout = createTypeLayout(context, concreteType); @@ -3046,9 +3096,9 @@ static TypeLayoutResult _createTypeLayout( } else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) ) { - TypeLayoutContext subContext = context.withExistentialTypeArgs( - existentialSpecializedType->slots.args.getCount(), - existentialSpecializedType->slots.args.getBuffer()); + TypeLayoutContext subContext = context.withSpecializationArgs( + existentialSpecializedType->args.getBuffer(), + existentialSpecializedType->args.getCount()); auto baseTypeLayoutResult = _createTypeLayout( subContext, diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index e066c2700..b7b3c3207 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -622,7 +622,7 @@ class GenericParamTypeLayout : public TypeLayout { public: RefPtr<GlobalGenericParamDecl> getGlobalGenericParamDecl(); - int paramIndex = 0; + Index paramIndex = 0; }; /// Layout information for a tagged union type. @@ -667,8 +667,6 @@ public: StructTypeLayout* getScopeStructLayout( ScopeLayout* programLayout); -class EntryPointGroupLayout; - // Layout information for a single shader entry point // within a program // @@ -682,7 +680,10 @@ class EntryPointLayout : public ScopeLayout { public: // The corresponding function declaration - RefPtr<FuncDecl> entryPoint; + DeclRef<FuncDecl> entryPoint; + + DeclRef<FuncDecl> getFuncDeclRef() { return entryPoint; } + FuncDecl* getFuncDecl() { return entryPoint.getDecl(); } // The shader profile that was used to compile the entry point Profile profile; @@ -696,30 +697,38 @@ public: }; unsigned flags = 0; - /// Layouts for all tagged union types required by this entry point. - /// - /// These are any tagged union types used by the generic - /// arguments that this entry point is being compiled with. - List<RefPtr<TypeLayout>> taggedUnionTypeLayouts; - EntryPointLayout* getAbsoluteLayout(VarLayout* parentLayout); - EntryPointLayout* getAbsoluteLayout(EntryPointGroupLayout* parentGroup); RefPtr<EntryPointLayout> m_absoluteLayout; }; -class EntryPointGroupLayout : public ScopeLayout + /// Reflection/layout information about a specialization parameter +class SpecializationParamLayout : public Layout { public: - RefPtr<EntryPointGroup> group; - List<RefPtr<EntryPointLayout>> entryPoints; + Index index; }; -class GenericParamLayout : public Layout + /// Reflection/layout information about a generic specialization parameter +class GenericSpecializationParamLayout : public SpecializationParamLayout { public: - RefPtr<GlobalGenericParamDecl> decl; - int index; + /// The declaration of the generic parameter. + /// + /// Could be any subclass of `Decl` that represents a generic value or type parameter. + RefPtr<Decl> decl; +}; + + /// Reflection/layout information about an existential/interface specialization parameter. +class ExistentialSpecializationParamLayout : public SpecializationParamLayout +{ +public: + /// The type that needs to be specialized. + /// + /// Currently, this will be an `interface` type that any concrete + /// type argument getting plugged in must conform to. + /// + RefPtr<Type> type; }; // Layout information for the global scope of a program @@ -748,7 +757,7 @@ public: TargetProgram* getTargetProgram() { return targetProgram; } TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); } - Program* getProgram() { return targetProgram->getProgram(); } + ComponentType* getProgram() { return targetProgram->getProgram(); } // We catalog the requested entry points here, @@ -756,12 +765,21 @@ public: // will (eventually) belong there... List<RefPtr<EntryPointLayout>> entryPoints; - // Entry points can also be grouped for layout purposes (e.g., to form - // ray-tracing hit groups), so this array represents those groups - List<RefPtr<EntryPointGroupLayout>> entryPointGroups; + /// Reflection information on (unspecialized) specialization parameters. + List<RefPtr<SpecializationParamLayout>> specializationParams; - List<RefPtr<GenericParamLayout>> globalGenericParams; - Dictionary<String, GenericParamLayout*> globalGenericParamsMap; + /// Concrete argument values that were provided to specific global generic parameters. + /// + /// Not useful for reflection, but valuable for code generation. + /// + Dictionary<GlobalGenericParamDecl*, RefPtr<Val>> globalGenericArgs; + + /// Layouts for all tagged union types required by this program + /// + /// These are any tagged union types used by the specialization + /// arguments that have been used to specialize the program. + /// + List<RefPtr<TypeLayout>> taggedUnionTypeLayouts; }; StructTypeLayout* getGlobalStructLayout( @@ -908,8 +926,6 @@ struct LayoutRulesFamilyImpl virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0; }; -typedef List<RefPtr<GenericParamLayout>> GenericParamLayouts; - struct TypeLayoutContext { // The layout rules to use (e.g., we compute @@ -930,10 +946,10 @@ struct TypeLayoutContext MatrixLayoutMode matrixLayoutMode; // The concrete types (if any) to plug into the currently in-scope - // existential type slots. + // specialization params. // - Int existentialTypeArgCount = 0; - ExistentialTypeSlots::Arg const* existentialTypeArgs = nullptr; + Int specializationArgCount = 0; + ExpandedSpecializationArg const* specializationArgs = nullptr; LayoutRulesImpl* getRules() { return rules; } LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); } @@ -952,29 +968,29 @@ struct TypeLayoutContext return result; } - TypeLayoutContext withExistentialTypeArgs( - Int argCount, - ExistentialTypeSlots::Arg const* args) const + TypeLayoutContext withSpecializationArgs( + ExpandedSpecializationArg const* args, + Int argCount) const { TypeLayoutContext result = *this; - result.existentialTypeArgCount = argCount; - result.existentialTypeArgs = args; + result.specializationArgCount = argCount; + result.specializationArgs = args; return result; } - TypeLayoutContext withExistentialTypeSlotsOffsetBy( + TypeLayoutContext withSpecializationArgsOffsetBy( Int offset) const { TypeLayoutContext result = *this; - if( existentialTypeArgCount > offset ) + if( specializationArgCount > offset ) { - result.existentialTypeArgCount = existentialTypeArgCount - offset; - result.existentialTypeArgs = existentialTypeArgs + offset; + result.specializationArgCount = specializationArgCount - offset; + result.specializationArgs = specializationArgs + offset; } else { - result.existentialTypeArgCount = 0; - result.existentialTypeArgs = nullptr; + result.specializationArgCount = 0; + result.specializationArgs = nullptr; } return result; @@ -1061,6 +1077,8 @@ public: /// TypeLayoutResult getTypeLayoutResult(); + UniformLayoutInfo* getStructLayoutInfo() { return &m_info; } + private: /// The layout rules being used, if layout has begun. LayoutRulesImpl* m_rules = nullptr; @@ -1139,8 +1157,16 @@ createStructuredBufferTypeLayout( RefPtr<Type> structuredBufferType, RefPtr<Type> elementType); -int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl); -// + /// Create a type layout for an unspecialized `globalGenericParamDecl`. +RefPtr<TypeLayout> createTypeLayoutForGlobalGenericTypeParam( + TypeLayoutContext const& context, + Type* type, + GlobalGenericParamDecl* globalGenericParamDecl); + + /// Find the concrete type (if any) that was plugged in for the global generic type parameter `decl`. +RefPtr<Type> findGlobalGenericSpecializationArg( + TypeLayoutContext const& context, + GlobalGenericParamDecl* decl); // Given an existing type layout `oldTypeLayout`, apply offsets // to any contained fields based on the resource infos in `offsetVarLayout`. diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b18e4d4d9..b030e5cf9 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -436,44 +436,31 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule( return asExternal(module); } -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createProgram( - slang::ProgramDesc const& desc, - slang::IProgram** outProgram) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL Linkage::createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics) { - RefPtr<Program> program = new Program(this); - - auto itemCount = desc.itemCount; - for(SlangInt ii = 0; ii < itemCount; ++ii) - { - auto& item = desc.items[ii]; - switch(item.kind) - { - case slang::ProgramDesc::Item::Kind::Program: - { - Program* existingProgram = asInternal(item.program); - for(auto referencedModule : existingProgram->getModuleDependencies()) - { - program->addReferencedLeafModule(referencedModule); - } - - // TODO: Need to decide whether to include the entry points as well... - } - break; + // Attempting to create a "composite" of just one component type should + // just return the component type itself, to avoid redundant work. + // + if( componentTypeCount == 1) + return componentTypes[0]; - case slang::ProgramDesc::Item::Kind::Module: - { - Module* module = asInternal(item.module); - program->addReferencedModule(module); - } - break; + DiagnosticSink sink(getSourceManager()); - default: - return SLANG_E_INVALID_ARG; - } + List<RefPtr<ComponentType>> childComponents; + for( Int cc = 0; cc < componentTypeCount; ++cc ) + { + childComponents.add(asInternal(componentTypes[cc])); } - *outProgram = asExternal(program.detach()); - return SLANG_OK; + RefPtr<ComponentType> composite = CompositeComponentType::create( + this, + childComponents); + + sink.getBlobIfNeeded(outDiagnostics); + return asExternal(composite.detach()); } SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( @@ -781,7 +768,9 @@ RefPtr<Type> checkProperType( TypeExp typeExp, DiagnosticSink* sink); -Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) +Type* ComponentType::getTypeFromString( + String const& typeStr, + DiagnosticSink* sink) { // If we've looked up this type name before, // then we can re-use it. @@ -803,7 +792,7 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) for(auto module : getModuleDependencies()) scopesToTry.add(module->getModuleDecl()->scope); - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); for(auto& s : scopesToTry) { RefPtr<Expr> typeExpr = linkage->parseTypeString( @@ -899,10 +888,16 @@ void FrontEndCompileRequest::parseTranslationUnit( } } -RefPtr<Program> createUnspecializedProgram( +RefPtr<ComponentType> createUnspecializedGlobalComponentType( + FrontEndCompileRequest* compileRequest); + +RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( FrontEndCompileRequest* compileRequest); -RefPtr<Program> createSpecializedProgram( +RefPtr<ComponentType> createSpecializedGlobalComponentType( + EndToEndCompileRequest* endToEndReq); + +RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType( EndToEndCompileRequest* endToEndReq); void FrontEndCompileRequest::checkAllTranslationUnits() @@ -1032,7 +1027,11 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // Look up all the entry points that are expected, // and use them to populate the `program` member. // - m_program = createUnspecializedProgram(this); + m_globalComponentType = createUnspecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_globalAndEntryPointsComponentType = createUnspecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1052,7 +1051,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // for(auto targetReq : getLinkage()->targets) { - auto targetProgram = m_program->getTargetProgram(targetReq); + auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1064,7 +1063,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() BackEndCompileRequest::BackEndCompileRequest( Linkage* linkage, DiagnosticSink* sink, - Program* program) + ComponentType* program) : CompileRequestBase(linkage, sink) , m_program(program) {} @@ -1146,7 +1145,8 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // that was computed in the front-end for all subsequent // reflection queries, etc. // - m_specializedProgram = getUnspecializedProgram(); + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = getUnspecializedGlobalAndEntryPointsComponentType(); return SLANG_OK; } @@ -1156,7 +1156,11 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // if (passThrough == PassThroughMode::None) { - m_specializedProgram = createSpecializedProgram(this); + m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_specializedGlobalAndEntryPointsComponentType = createSpecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1166,7 +1170,7 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // for (auto targetReq : getLinkage()->targets) { - auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + auto targetProgram = m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1178,20 +1182,27 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // to make sure that the logic in `generateOutput` // sees something worth processing. // - auto specializedProgram = new Program(getLinkage()); - m_specializedProgram = specializedProgram; + List<RefPtr<ComponentType>> dummyEntryPoints; for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough( + RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough( + getLinkage(), entryPointReq->getName(), entryPointReq->getProfile()); - specializedProgram->addEntryPoint(entryPoint, getSink()); + dummyEntryPoints.add(dummyEntryPoint); } + + RefPtr<ComponentType> composedProgram = CompositeComponentType::create( + getLinkage(), + dummyEntryPoints); + + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = composedProgram; } // Generate output code, in whatever format was requested - getBackEndReq()->setProgram(getSpecializedProgram()); + getBackEndReq()->setProgram(getSpecializedGlobalAndEntryPointsComponentType()); generateOutput(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1324,7 +1335,7 @@ int EndToEndCompileRequest::addEntryPoint( EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPointInfo.genericArgStrings.add(typeName); + entryPointInfo.specializationArgStrings.add(typeName); Index result = entryPoints.getCount(); entryPoints.add(_Move(entryPointInfo)); @@ -1617,8 +1628,10 @@ void FilePathDependencyList::addDependency(Module* module) // Module::Module(Linkage* linkage) - : m_linkage(linkage) -{} + : ComponentType(linkage) +{ + addModuleDependency(this); +} ISlangUnknown* Module::getInterface(const Guid& guid) { @@ -1638,35 +1651,40 @@ void Module::addFilePathDependency(String const& path) m_filePathDependencyList.addDependency(path); } -// Program +void Module::setModuleDecl(ModuleDecl* moduleDecl) +{ + m_moduleDecl = moduleDecl; +} + +// ComponentType -static const Guid IID_IProgram = SLANG_UUID_IProgram; +static const Guid IID_IComponentType = SLANG_UUID_IComponentType; -Program::Program(Linkage* linkage) +ComponentType::ComponentType(Linkage* linkage) : m_linkage(linkage) {} -ISlangUnknown* Program::getInterface(Guid const& guid) +ISlangUnknown* ComponentType::getInterface(Guid const& guid) { if(guid == IID_ISlangUnknown - || guid == IID_IProgram) + || guid == IID_IComponentType) { - return static_cast<slang::IProgram*>(this); + return static_cast<slang::IComponentType*>(this); } return nullptr; } -SLANG_NO_THROW slang::ISession* SLANG_MCALL Program::getSession() +SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession() { return m_linkage; } -SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( +SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL ComponentType::getLayout( Int targetIndex, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return nullptr; auto target = linkage->targets[targetIndex]; @@ -1678,13 +1696,13 @@ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( return asExternal(programLayout); } -SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( SlangInt entryPointIndex, Int targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return SLANG_E_INVALID_ARG; auto target = linkage->targets[targetIndex]; @@ -1702,57 +1720,535 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( return SLANG_OK; } +RefPtr<ComponentType> ComponentType::specialize( + SpecializationArg const* inSpecializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink) +{ + List<SpecializationArg> specializationArgs; + specializationArgs.addRange( + inSpecializationArgs, + specializationArgCount); + + // We next need to validate that the specialization arguments + // make sense, and also expand them to include any derived data + // (e.g., interface conformance witnesses) that doesn't get + // passed explicitly through the API interface. + // + RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs( + specializationArgs.getBuffer(), + specializationArgCount, + sink); + + return new SpecializedComponentType( + this, + specializationInfo, + specializationArgs, + sink); +} -void Program::addReferencedModule(Module* module) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL ComponentType::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) { - m_moduleDependencyList.addDependency(module); - m_filePathDependencyList.addDependency(module); + DiagnosticSink sink(getLinkage()->getSourceManager()); + + // First let's check if the number of arguments given matches + // the number of parameters that are present on this component type. + // + auto specializationParamCount = getSpecializationParamCount(); + if( specializationArgCount != specializationParamCount ) + { + // TODO: diagnose + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + + List<SpecializationArg> expandedArgs; + for( Int aa = 0; aa < specializationArgCount; ++aa ) + { + auto apiArg = specializationArgs[aa]; + + SpecializationArg expandedArg; + switch(apiArg.kind) + { + case slang::SpecializationArg::Kind::Type: + expandedArg.val = asInternal(apiArg.type); + break; + + default: + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + expandedArgs.add(expandedArg); + } + + auto specializedComponentType = specialize( + expandedArgs.getBuffer(), + expandedArgs.getCount(), + &sink); + + sink.getBlobIfNeeded(outDiagnostics); + + return specializedComponentType; } -void Program::addReferencedLeafModule(Module* module) + /// Visitor used by `ComponentType::enumerateModules` +struct EnumerateModulesVisitor : ComponentTypeVisitor { - m_moduleDependencyList.addLeafDependency(module); - m_filePathDependencyList.addDependency(module); + EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module, m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + + +void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData) +{ + EnumerateModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + + /// Visitor used by `ComponentType::enumerateIRModules` +struct EnumerateIRModulesVisitor : ComponentTypeVisitor +{ + EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateIRModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module->getIRModule(), m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + + m_callback(specialized->getIRModule(), m_userData); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + +void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) +{ + EnumerateIRModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + +// +// CompositeComponentType +// + +RefPtr<ComponentType> CompositeComponentType::create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) +{ + // TODO: We should ideally be caching the results of + // composition on the `linkage`, so that if we get + // asked for the same composite again later we re-use + // it rather than re-create it. + // + // Similarly, we might want to do some amount of + // work to "canonicalize" the input for composition. + // E.g., if the user does: + // + // X = compose(A,B); + // Y = compose(C,D); + // Z = compose(X,Y); + // + // W = compose(A, B, C, D); + // + // Then there is no observable difference between + // Z and W, so we might prefer to have them be identical. + + // If there is only a single child, then we should + // just return that child rather than create a dummy composite. + // + if( childComponents.getCount() == 1 ) + { + return childComponents[0]; + } + + return new CompositeComponentType(linkage, childComponents); +} + + +CompositeComponentType::CompositeComponentType( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) + : ComponentType(linkage) + , m_childComponents(childComponents) +{ + HashSet<ComponentType*> requirementsSet; + for(auto child : childComponents ) + { + child->enumerateModules([&](Module* module) + { + requirementsSet.Add(module); + }); + } + + for(auto child : childComponents ) + { + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + auto childShaderParamCount = child->getShaderParamCount(); + for(Index pp = 0; pp < childShaderParamCount; ++pp) + { + m_shaderParams.add(child->getShaderParam(pp)); + } + + auto childSpecializationParamCount = child->getSpecializationParamCount(); + for(Index pp = 0; pp < childSpecializationParamCount; ++pp) + { + m_specializationParams.add(child->getSpecializationParam(pp)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencyList.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) + { + m_filePathDependencyList.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } + } + } +} + +Index CompositeComponentType::getEntryPointCount() +{ + return m_entryPoints.getCount(); +} + +RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) +{ + return m_entryPoints[index]; +} + +Index CompositeComponentType::getShaderParamCount() +{ + return m_shaderParams.getCount(); } -void Program::addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) +GlobalShaderParamInfo CompositeComponentType::getShaderParam(Index index) { - List<RefPtr<EntryPoint>> entryPoints; - entryPoints.add(entryPoint); + return m_shaderParams[index]; +} + +Index CompositeComponentType::getSpecializationParamCount() +{ + return m_specializationParams.getCount(); +} + +SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) +{ + return m_specializationParams[index]; +} + +Index CompositeComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +List<Module*> const& CompositeComponentType::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} - RefPtr<EntryPointGroup> entryPointGroup = EntryPointGroup::create(getLinkageImpl(), entryPoints, sink); +List<String> const& CompositeComponentType::getFilePathDependencies() +{ + return m_filePathDependencyList.getFilePathList(); +} - addEntryPointGroup(entryPointGroup); +void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); } -void Program::addEntryPointGroup(EntryPointGroup* entryPointGroup) + +RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) { - m_entryPointGroups.add(entryPointGroup); + SLANG_UNUSED(argCount); + + RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo(); - for(auto entryPoint : entryPointGroup->getEntryPoints()) + Index offset = 0; + for(auto child : m_childComponents) { - m_entryPoints.add(entryPoint); - for(auto module : entryPoint->getModuleDependencies()) + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, + sink); + + specializationInfo->childInfos.add(childInfo); + + offset += childParamCount; + } + return specializationInfo; +} + +// +// SpecializedComponentType +// + +SpecializedComponentType::SpecializedComponentType( + ComponentType* base, + ComponentType::SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_specializationInfo(specializationInfo) + , m_specializationArgs(specializationArgs) +{ + m_irModule = generateIRForSpecializedComponentType(this, sink); + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so taht we don't need to have front-end + // code worrying about this stuff. + // + for(auto arg : specializationArgs) + { + auto argType = as<Type>(arg.val); + if(!argType) + continue; + + auto taggedUnionType = as<TaggedUnionType>(argType); + if(!taggedUnionType) + continue; + + m_taggedUnionTypes.add(taggedUnionType); + } +} + +void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + SLANG_ASSERT(specializationInfo == nullptr); + SLANG_UNUSED(specializationInfo); + visitor->visitSpecialized(this); +} + +Index SpecializedComponentType::getRequirementCount() +{ + // TODO: A specialized component type may have *more* requirements + // than the original, because it also needs to include the module(s) + // that define the types used for specialization arguments. + + return m_base->getRequirementCount(); +} + +RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) +{ + return m_base->getRequirement(index); +} + +// +// LegacyProgram +// + +LegacyProgram::LegacyProgram( + Linkage* linkage, + List<RefPtr<TranslationUnitRequest>> const& translationUnits, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_translationUnits(translationUnits) +{ + HashSet<ComponentType*> requirementsSet; + + for(auto translationUnit : translationUnits ) + { + ComponentType* child = translationUnit->getModule(); + + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencies.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) { - addReferencedModule(module); + m_fileDependencies.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } } } + + _collectShaderParams(sink); } -RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink) +void LegacyProgram::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { - if(!m_irModule) + visitor->visitLegacy(this, as<CompositeComponentType::CompositeSpecializationInfo>(specializationInfo)); +} + +RefPtr<ComponentType::SpecializationInfo> LegacyProgram::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(argCount); + + RefPtr<CompositeComponentType::CompositeSpecializationInfo> info = new CompositeComponentType::CompositeSpecializationInfo(); + + Index offset = 0; + for(auto translationUnit : m_translationUnits) { - m_irModule = generateIRForProgram( - m_linkage->getSessionImpl(), - this, + ComponentType* child = translationUnit->getModule(); + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, sink); + + info->childInfos.add(childInfo); + + offset += childParamCount; } - return m_irModule; + return info; } +Index LegacyProgram::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> LegacyProgram::getRequirement(Index index) +{ + return m_requirements[index]; +} + +void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = composite->getChildComponentCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto child = composite->getChildComponent(ii); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} + +void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) +{ + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); +} + +void ComponentTypeVisitor::visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = legacy->getTranslationUnitCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto translationUnit = legacy->getTranslationUnit(ii); + ComponentType* child = translationUnit->getModule(); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} -TargetProgram* Program::getTargetProgram(TargetRequest* target) +TargetProgram* ComponentType::getTargetProgram(TargetRequest* target) { RefPtr<TargetProgram> targetProgram; if(!m_targetPrograms.TryGetValue(target, targetProgram)) @@ -1768,12 +2264,12 @@ TargetProgram* Program::getTargetProgram(TargetRequest* target) // TargetProgram::TargetProgram( - Program* program, + ComponentType* componentType, TargetRequest* targetReq) - : m_program(program) + : m_program(componentType) , m_targetReq(targetReq) { - m_entryPointResults.setCount(program->getEntryPoints().getCount()); + m_entryPointResults.setCount(componentType->getEntryPointCount()); } // @@ -2458,10 +2954,10 @@ SLANG_API SlangResult spSetGlobalGenericArgs( if (!request) return SLANG_FAIL; auto req = convert(request); - auto& genericArgStrings = req->globalGenericArgStrings; - genericArgStrings.clear(); + auto& argStrings = req->globalSpecializationArgStrings; + argStrings.clear(); for (int i = 0; i < genericArgCount; i++) - genericArgStrings.add(genericArgs[i]); + argStrings.add(genericArgs[i]); return SLANG_OK; } @@ -2477,7 +2973,7 @@ SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam( if(!typeName) return SLANG_FAIL; auto req = convert(request); - auto& typeArgStrings = req->globalExistentialSlotArgStrings; + auto& typeArgStrings = req->globalSpecializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2501,7 +2997,7 @@ SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam( return SLANG_FAIL; auto& entryPointInfo = req->entryPoints[entryPointIndex]; - auto& typeArgStrings = entryPointInfo.existentialArgStrings; + auto& typeArgStrings = entryPointInfo.specializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2569,7 +3065,7 @@ spGetDependencyFileCount( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return (int) program->getFilePathDependencies().getCount(); } @@ -2583,7 +3079,7 @@ spGetDependencyFilePath( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return program->getFilePathDependencies()[index].begin(); } @@ -2613,7 +3109,7 @@ SLANG_API void const* spGetEntryPointCode( using namespace Slang; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // TODO: We should really accept a target index in this API Index targetIndex = 0; @@ -2668,7 +3164,7 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); Index targetCount = linkage->targets.getCount(); if((targetIndex < 0) || (targetIndex >= targetCount)) @@ -2715,13 +3211,13 @@ SLANG_API void const* spGetCompileRequestCode( SLANG_API SlangResult spCompileRequest_getProgram( SlangCompileRequest* request, - slang::IProgram** outProgram) + slang::IComponentType** outProgram) { if( !request ) return SLANG_ERROR_INVALID_PARAMETER; auto req = convert(request); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); - *outProgram = Slang::ComPtr<slang::IProgram>(program).detach(); + *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); return SLANG_OK; } @@ -2731,7 +3227,7 @@ SLANG_API SlangReflection* spGetReflection( if( !request ) return 0; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection |
