From 2552217b76c0bd83e18fceba1d35a367bf569eca Mon Sep 17 00:00:00 2001
From: Tim Foley
Date: Thu, 8 Aug 2019 11:22:32 -0700
Subject: Revise new COM-lite API (#1007)

* Revise new COM-lite API

This change revises the "COM-lite" API that was recently introduced to try to streamline it and introduce some missing central/base concepts.

The central new abstraction in the API is the notion of a "component type," which is a unit of shader code composition. A component type can have:

* IR code for some number of functions/types/etc.
* Zero or more global shader parameters
* Zero or more "entry point" functions at which execution can start
* Zero or more "specialization" parameters (types or values that must be filled in before kernel code can be generated)
* Zero or more "requirements" (dependencies on other component types that must be satisfied before kernel code can be generated)

Both individual compiled modules and validated entry points are then examples of component types, and we additionally define a few services that apply to all component types:

* We can take N component types and compose them to create a new component type that combines their code, shader parameters, entry points, and specialization parameters. A composed component type may also include requirements from the sub-component types, but it is also possible that by composing things we satisfy requirements (if `A` requires `B`, and we compose `A` and `B`, then the requirement is now satisfied, and doesn't appear on the composite).
* We can take a component type with N specialization parameters, and specialize it by giving N compatible specialization arguments. The result of specialization is a new component type with zero specialization parameters. Under the right circumstances the specialized component type will be layout compatible with the unspecialized one.
* One more example that isn't exposed in the public API today is that we can take a component with requirements and "complete" it by automatically composing it with component types that satisfy those requirements. This can be seen as a kind of linking step that pulls together the transitive closure of dependencies.
* We can query the layout for the shader parameters and entry points of a component type, for a specific target.
* We can query compiled kernel code for an entry point in a component type (for a specific target). This only works for component types with zero specialization parameters and zero requirements.

The idea is that by giving users a fairly general algebra of operations on component types, they can compose final programs in ways that meet their requirements. For example, it becomes possible to incrementally "grow" a component type to represent the global root signature for ray tracing shaders as new entry points are added, in such a way that it always stays layout-compatible with kernels that have already been compiled.

Much of the implementation work here is in implementing the unifying component type abstraction, and in particular re-writing code that used to assume a program consisted of a flat list of modules and entry points to work with a hierarchical representation that reflects the underlying algebra (e.g., with types to represent composite and specialized component types). There's also a hidden "legacy" case of a component type to deal with some legacy compiler behaviors that can't be directly modeled on top of the simple algebra with modules and entry points.

This API is by no means feature-complete or fully developed. It is expected that we will flesh it out more when bringing up application code (e.g., Falcor) on top of the revamped API.

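As a rough illustration of the algebra described above, here is a toy model of composition and specialization. This is only a sketch: `ToyComponent`, `compose`, `specialize`, and `canGenerateCode` are hypothetical names invented for this note and are not part of the Slang API or its implementation. The model tracks only names, requirements, and a count of specialization parameters.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy model only: hypothetical types, NOT the Slang API.
struct ToyComponent
{
    std::vector<std::string> names;        // modules/entry points this component contains
    std::vector<std::string> requirements; // names of components it still depends on
    int specializationParamCount = 0;      // parameters that must be filled in before codegen
};

// Compose N components: contents and specialization parameters are combined,
// and a requirement is dropped if some part of the composite provides it.
ToyComponent compose(std::vector<ToyComponent> const& parts)
{
    ToyComponent result;
    std::set<std::string> provided;
    for (auto const& part : parts)
    {
        for (auto const& name : part.names)
        {
            provided.insert(name);
            result.names.push_back(name);
        }
        result.specializationParamCount += part.specializationParamCount;
    }

    // Keep only unsatisfied requirements, and keep each one once
    // (the first occurrence wins, preserving the order of `parts`).
    std::set<std::string> seen;
    for (auto const& part : parts)
        for (auto const& req : part.requirements)
            if (!provided.count(req) && seen.insert(req).second)
                result.requirements.push_back(req);
    return result;
}

// Specialize: exactly N arguments must be given for N parameters,
// and the result has zero specialization parameters.
bool specialize(ToyComponent const& in, int argCount, ToyComponent& out)
{
    if (argCount != in.specializationParamCount)
        return false;
    out = in;
    out.specializationParamCount = 0;
    return true;
}

// Kernel code is only available when fully specialized and fully linked.
bool canGenerateCode(ToyComponent const& c)
{
    return c.specializationParamCount == 0 && c.requirements.empty();
}
```

In this model, composing `A` (which requires `B`) with `B` yields a composite with no requirements, and kernel code becomes available only once the composite has also been specialized.
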
One notable thing that went away in this change is explicit support for "entry point groups" and notions of local root signatures (especially the Falcor-specific handling of the `shared` keyword, which a previous change turned into an explicitly supported feature). With the new "building blocks" approach, it should be possible for a DXR application to deal with local root signatures as a matter of policy (on top of the API we provide). If/when we need to provide some kind of emulation of local root signatures for Vulkan (and/or if Vulkan is extended with an explicit notion of local root signatures), we might need to revisit this choice.

* Fix debug build

There was invalid code inside an `assert()`, so the release build didn't catch it.

* fixup: warnings

* fixup: more warnings-as-errors

* fixup: review notes

* fixup: use component type visitors in place of dynamic casting
---
 slang.h | 298 +++---
 source/slang/slang-check.cpp | 1312 ++++++++++++++-------------
 source/slang/slang-compiler.cpp | 144 ++-
 source/slang/slang-compiler.h | 920 +++++++++++++------
 source/slang/slang-diagnostic-defs.h | 9 +-
 source/slang/slang-emit-c-like.cpp | 8 +-
 source/slang/slang-emit-glsl.cpp | 6 +-
 source/slang/slang-emit-hlsl.cpp | 4 +-
 source/slang/slang-emit.cpp | 54 +-
 source/slang/slang-ir-bind-existentials.cpp | 92 +-
 source/slang/slang-ir-inst-defs.h | 1 -
 source/slang/slang-ir-insts.h | 4 -
 source/slang/slang-ir-link.cpp | 189 ++--
 source/slang/slang-ir.cpp | 16 -
 source/slang/slang-lower-to-ir.cpp | 216 +++--
 source/slang/slang-lower-to-ir.h | 26 +-
 source/slang/slang-parameter-binding.cpp | 1043 +++++++++++++++------
 source/slang/slang-reflection.cpp | 181 ++--
 source/slang/slang-syntax.cpp | 73 +-
 source/slang/slang-syntax.h | 38 +-
 source/slang/slang-type-defs.h | 2 +-
 source/slang/slang-type-layout.cpp | 112 ++-
 source/slang/slang-type-layout.h | 108 ++-
 source/slang/slang.cpp | 700 +++++++++++---
 24 files changed, 3587 insertions(+), 1969 deletions(-)

diff --git a/slang.h
b/slang.h index 55c454d0d..583072686 100644 --- a/slang.h +++ b/slang.h @@ -1571,10 +1571,7 @@ extern "C" typedef struct SlangProgramLayout SlangProgramLayout; typedef struct SlangEntryPoint SlangEntryPoint; typedef struct SlangEntryPointLayout SlangEntryPointLayout; - typedef struct SlangEntryPointGroupLayout SlangEntryPointGroupLayout; -// typedef struct SlangReflection SlangReflection; -// typedef struct SlangReflectionEntryPoint SlangReflectionEntryPoint; typedef struct SlangReflectionModifier SlangReflectionModifier; typedef struct SlangReflectionType SlangReflectionType; typedef struct SlangReflectionTypeLayout SlangReflectionTypeLayout; @@ -1900,6 +1897,12 @@ extern "C" SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( SlangReflectionEntryPoint* entryPoint); + SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getVarLayout( + SlangReflectionEntryPoint* entryPoint); + + SLANG_API int spReflectionEntryPoint_hasDefaultConstantBuffer( + SlangReflectionEntryPoint* entryPoint); + // SlangReflectionTypeParameter SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter* typeParam); SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter* typeParam); @@ -1922,9 +1925,6 @@ extern "C" SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* reflection, SlangUInt index); SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangReflection* reflection, char const* name); - SLANG_API SlangInt spReflection_getEntryPointGroupCount(SlangReflection* reflection); - SLANG_API SlangEntryPointGroupLayout* spReflection_getEntryPointGroupByIndex(SlangReflection* reflection, SlangInt index); - SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection); SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection); @@ -1935,16 +1935,6 @@ extern "C" SlangReflectionType* const* specializationArgs, 
ISlangBlob** outDiagnostics); - // Entry point group reflection - - SLANG_API SlangInt spEntryPointGroupLayout_getEntryPointCount(SlangEntryPointGroupLayout* group); - SLANG_API SlangReflectionEntryPoint* spEntryPointGroupLayout_getEntryPointByIndex(SlangEntryPointGroupLayout* group, SlangInt index); - SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getVarLayout(SlangEntryPointGroupLayout* group); - - SLANG_API SlangInt spEntryPointGroupLayout_getParameterCount(SlangEntryPointGroupLayout* group); - SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getParameterByIndex(SlangEntryPointGroupLayout* group, SlangInt index); - - #ifdef __cplusplus } @@ -2448,36 +2438,23 @@ namespace slang { return 0 != spReflectionEntryPoint_usesAnySampleRateInput((SlangReflectionEntryPoint*) this); } - }; - typedef EntryPointReflection EntryPointLayout; - - struct EntryPointGroupLayout - { - SlangInt getEntryPointCount() - { - return spEntryPointGroupLayout_getEntryPointCount((SlangEntryPointGroupLayout*) this); - } - - EntryPointReflection* getEntryPointByIndex(SlangInt index) - { - return (EntryPointReflection*) spEntryPointGroupLayout_getEntryPointByIndex((SlangEntryPointGroupLayout*) this, index); - } VariableLayoutReflection* getVarLayout() { - return (VariableLayoutReflection*) spEntryPointGroupLayout_getVarLayout((SlangEntryPointGroupLayout*) this); + return (VariableLayoutReflection*) spReflectionEntryPoint_getVarLayout((SlangReflectionEntryPoint*) this); } - SlangInt getParameterCount() + TypeLayoutReflection* getTypeLayout() { - return spEntryPointGroupLayout_getParameterCount((SlangEntryPointGroupLayout*) this); + return getVarLayout()->getTypeLayout(); } - VariableLayoutReflection* getParameterByIndex(SlangInt index) + bool hasDefaultConstantBuffer() { - return (VariableLayoutReflection*) spEntryPointGroupLayout_getParameterByIndex((SlangEntryPointGroupLayout*) this, index); + return 
spReflectionEntryPoint_hasDefaultConstantBuffer((SlangReflectionEntryPoint*) this) != 0; } }; + typedef EntryPointReflection EntryPointLayout; struct TypeParameterReflection { @@ -2548,16 +2525,6 @@ namespace slang return (EntryPointReflection*) spReflection_getEntryPointByIndex((SlangReflection*) this, index); } - SlangInt getEntryPointGroupCount() - { - return spReflection_getEntryPointGroupCount((SlangReflection*) this); - } - - EntryPointGroupLayout* getEntryPointGroupByIndex(SlangInt index) - { - return (EntryPointGroupLayout*) spReflection_getEntryPointGroupByIndex((SlangReflection*) this, index); - } - SlangUInt getGlobalConstantBufferBinding() { return spReflection_getGlobalConstantBufferBinding((SlangReflection*)this); @@ -2609,13 +2576,12 @@ namespace slang typedef ISlangBlob IBlob; + struct IComponentType; struct IGlobalSession; struct IModule; - struct IProgram; struct ISession; struct SessionDesc; - struct ProgramDesc; struct SpecializationArg; struct TargetDesc; @@ -2759,20 +2725,39 @@ namespace slang const char* moduleName, IBlob** outDiagnostics = nullptr) = 0; - /** Create a program out of existing compiled items. - */ - virtual SLANG_NO_THROW SlangResult SLANG_MCALL createProgram( - ProgramDesc const& desc, - IProgram** outProgram) = 0; + /** Combine multiple component types to create a composite component type. + + The `componentTypes` array must contain `componentTypeCount` pointers + to component types that were loaded or created using the same session. + + The shader parameters and specialization parameters of the composite will + be the union of those in `componentTypes`. The relative order of child + component types is significant, and will affect the order in which + parameters are reflected and laid out. + + The entry-point functions of the composite will be the union of those in + `componentTypes`, and will follow the ordering of `componentTypes`. 
+ + The requirements of the composite component type will be a subset of + those in `componentTypes`. If an entry in `componentTypes` has a requirement + that can be satisfied by another entry, then the composition will + satisfy the requirement and it will not appear as a requirement of + the composite. If multiple entries in `componentTypes` have a requirement + for the same type, then only the first such requirement will be retained + on the composite. The relative ordering of requirements on the composite + will otherwise match that of `componentTypes`. + + If any diagnostics are generated during creation of the composite, they + will be written to `outDiagnostics`. If an error is encountered, the + function will return null. - /** Specialize a program based on type arguments. + It is an error to create a composite component type that recursively + aggregates the a single module more than once. */ - virtual SLANG_NO_THROW SlangResult SLANG_MCALL specializeProgram( - IProgram* program, - SlangInt specializationArgCount, - SpecializationArg const* specializationArgs, - IProgram** outSpecializedProgram, - ISlangBlob** outDiagnostics = nullptr) = 0; + virtual SLANG_NO_THROW IComponentType* SLANG_MCALL createCompositeComponentType( + IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics = nullptr) = 0; /** Specialize a type based on type arguments. */ @@ -2799,76 +2784,95 @@ namespace slang #define SLANG_UUID_ISession { 0x67618701, 0xd116, 0x468f, { 0xab, 0x3b, 0x47, 0x4b, 0xed, 0xce, 0xe, 0x3d } } + /** A component type is a unit of shader code layout, reflection, and linking. - /** A module is the granularity of shader code compilation and loading. + A component type is a unit of shader code that can be included into + a linked and compiled shader program. Each component type may have: - In most cases a module corresponds to a single compile "translation unit." 
- This will often be a single `.slang` or `.hlsl` file and everything it - `#include`s. + * Zero or more uniform shader parameters, representing textures, + buffers, etc. that the code in the component depends on. - Notably, a module `M` does *not* include the things it `import`s, as these - as distinct modules that `M` depends on. There is a directed graph of - module dependencies, and all modules in the graph must belong to the - same session (`ISession`). - */ - struct IModule : public ISlangUnknown - { - public: - }; - - #define SLANG_UUID_IModule { 0xc720e64, 0x8722, 0x4d31, { 0x89, 0x90, 0x63, 0x8a, 0x98, 0xb1, 0xc2, 0x79 } } + * Zero or more *specialization* parameters, which are type or + value parameters that can be used to synthesize specialized + versions of the component type. + * Zero or more entry points, which are the individually invocable + kernels that can have final code generated. - /** Argument used for specialization to types/values. - */ - struct SpecializationArg - { - enum class Kind : int32_t - { - Unknown, - Type, - }; - Kind kind; - union - { - TypeReflection* type; - }; - }; + * Zero or more *requirements*, which are other component + types on which the component type depends. - /** Description of a program to be created. - */ - struct ProgramDesc - { - struct Item - { - enum class Kind : int32_t - { - Program, - Module, - }; - Kind kind; - union - { - IProgram* program; - IModule* module; - }; - }; + One example of a component type is a module of Slang code: - Item const* items; - SlangInt itemCount; - }; + * The global-scope shader parameters declared in the module are + the parameters when considered as a component type. + + * Any global-scope generic or interface type parameters introduce + specialization parameters for the module. + + * A module does not by default include any entry points when + considered as a component type (although the code of the + module might *declare* some entry points). 
+ + * Any other modules that are `import`ed in the source code + become requirements of the module, when considered as a + component type. + + An entry point is another example of a component type: + + * The `uniform` parameters of the entry point function are + its shader parameters when considered as a component type. + + * Any generic or interface-type parameters of the entry point + introduce specialization parameters. + + * An entry point component type exposes a single entry point (itself). + + * An entry point has one requirement for the module in which + it was defined. + + Component types can be manipulated in a few ways: + + * Multiple component types can be combined into a composite, which + combines all of their code, parameters, etc. + * A component type can be specialized, by "plugging in" types and + values for its specialization parameters. + + * A component type can be laid out for a particular target, giving + offsets/bindings to the shader parameters it contains. + + * Generated kernel code can be requested for entry points. - /** A program comprises zero or more modules, entry points, etc. that have been linked together. */ - struct IProgram : public ISlangUnknown + struct IComponentType : public ISlangUnknown { - public: - /** Get the runtime session that this program belongs to. + /** Get the runtime session that this component type belongs to. */ virtual SLANG_NO_THROW ISession* SLANG_MCALL getSession() = 0; - /** Get the layout for this program for the chosen `targetIndex` + /** Get the layout for this program for the chosen `targetIndex`. + + The resulting layout will establish offsets/bindings for all + of the global and entry-point shader parameters in the + component type. 
+ + If this component type has specialization parameters (that is, + it is not fully specialized), then the resulting layout may + be incomplete, and plugging in arguments for generic specialization + parameters may result in a component type that doesn't have + a compatible layout. If the component type only uses + interface-type specialization parameters, then the layout + for a specialization should be compatible with an unspecialized + layout (all parameters in the unspecialized layout will have + the same offset/binding in the specialized layout). + + If this component type is combined into a composite, then + the absolute offsets/bindings of parameters may not stay the same. + If the shader parameters in a component type don't make + use of explicit binding annotations (e.g., `register(...)`), + then the *relative* offset of shader parameters will stay + the same when it is used in a composition. */ virtual SLANG_NO_THROW ProgramLayout* SLANG_MCALL getLayout( SlangInt targetIndex = 0, @@ -2876,6 +2880,10 @@ namespace slang /** Get the compiled code for the entry point at `entryPointIndex` for the chosen `targetIndex` + Entry point code can only be computed for a component type that + has no specialization parameters (it must be fully specialized) + and that has no requirements (it must be fully linked). + If code has not already been generated for the given entry point and target, then a compilation error may be detected, in which case `outDiagnostics` (if non-null) will be filled in with a blob of messages diagnosing the error. @@ -2885,9 +2893,65 @@ namespace slang SlangInt targetIndex, IBlob** outCode, IBlob** outDiagnostics = nullptr) = 0; + + /** Specialize the component by binding its specialization parameters to concrete arguments. + + The `specializationArgs` array must have `specializationArgCount` entries, and + this must match the number of specialization parameters on this component type. 
+ + If the specialization arguments are not valid, then the function will return null. + + If any diagnostics (error or warnings) are produced, they will be written to `outDiagnostics`. + */ + virtual SLANG_NO_THROW IComponentType* SLANG_MCALL specialize( + SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics = nullptr) = 0; + }; + #define SLANG_UUID_IComponentType { 0x5bc42be8, 0x5c50, 0x4929, { 0x9e, 0x5e, 0xd1, 0x5e, 0x7c, 0x24, 0x1, 0x5f } }; + + /** A module is the granularity of shader code compilation and loading. + + In most cases a module corresponds to a single compile "translation unit." + This will often be a single `.slang` or `.hlsl` file and everything it + `#include`s. + + Notably, a module `M` does *not* include the things it `import`s, as these + as distinct modules that `M` depends on. There is a directed graph of + module dependencies, and all modules in the graph must belong to the + same session (`ISession`). + + A module establishes a namespace for looking up types, functions, etc. + */ + struct IModule : public IComponentType + { + public: + /** Note: eventually operations for looking up types or entry + points by name should appear here. + */ }; - #define SLANG_UUID_IProgram { 0x5bc42be8, 0x5c50, 0x4929, { 0x9e, 0x5e, 0xd1, 0x5e, 0x7c, 0x24, 0x1, 0x5f } }; + #define SLANG_UUID_IModule { 0xc720e64, 0x8722, 0x4d31, { 0x89, 0x90, 0x63, 0x8a, 0x98, 0xb1, 0xc2, 0x79 } } + + + /** Argument used for specialization to types/values. + */ + struct SpecializationArg + { + enum class Kind : int32_t + { + Unknown, /**< An invalid specialization argument. */ + Type, /**< Specialize to a type. */ + }; + + /** The kind of specialization argument. */ + Kind kind; + union + { + /** A type specialization argument, used for `Kind::Type`. 
*/ + TypeReflection* type; + }; + }; } #define SLANG_API_VERSION 0 @@ -2909,7 +2973,7 @@ namespace slang */ SLANG_API SlangResult spCompileRequest_getProgram( SlangCompileRequest* request, - slang::IProgram** outProgram); + slang::IComponentType** outProgram); #endif diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index edf41fff0..b7c1ccdc2 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -7107,8 +7107,7 @@ namespace Slang { if(context.mode != OverloadResolveContext::Mode::JustTrying) { - // TODO: diagnose a problem here - getSink()->diagnose(context.loc, Diagnostics::unimplemented, "generic constraint not satisfied"); + getSink()->diagnose(context.loc, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); } return false; } @@ -9663,15 +9662,15 @@ namespace Slang } } - /// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`. - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, + /// Recursively walk `paramDeclRef` and add any existential/interface specialization parameters to `ioSpecializationParams`. + static void _collectExistentialSpecializationParamsRec( + SpecializationParams& ioSpecializationParams, DeclRef paramDeclRef); - /// Recursively walk `type` and discover any required existential type parameters. - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, - Type* type) + /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`. + static void _collectExistentialSpecializationParamsRec( + SpecializationParams& ioSpecializationParams, + Type* type) { // Whether or not something is an array does not affect // the number of existential slots it introduces. 
@@ -9683,7 +9682,9 @@ namespace Slang if( auto parameterGroupType = as(type) ) { - _collectExistentialTypeParamsRec(ioSlots, parameterGroupType->getElementType()); + _collectExistentialSpecializationParamsRec( + ioSpecializationParams, + parameterGroupType->getElementType()); return; } @@ -9692,9 +9693,14 @@ namespace Slang auto typeDeclRef = declRefType->declRef; if( auto interfaceDeclRef = typeDeclRef.as() ) { - // Each leaf parameter of interface type adds one slot. + // Each leaf parameter of interface type adds a specialization + // parameter, which determines the concrete type(s) that may + // be provided as arguments for that parameter. // - ioSlots.paramTypes.add(type); + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::ExistentialType; + specializationParam.object = type; + ioSpecializationParams.add(specializationParam); } else if( auto structDeclRef = typeDeclRef.as() ) { @@ -9706,7 +9712,9 @@ namespace Slang if(fieldDeclRef.getDecl()->HasModifier()) continue; - _collectExistentialTypeParamsRec(ioSlots, fieldDeclRef); + _collectExistentialSpecializationParamsRec( + ioSpecializationParams, + fieldDeclRef); } } } @@ -9716,26 +9724,58 @@ namespace Slang // element types. 
} - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, + static void _collectExistentialSpecializationParamsRec( + SpecializationParams& ioSpecializationParams, DeclRef paramDeclRef) { - _collectExistentialTypeParamsRec(ioSlots, GetType(paramDeclRef)); + _collectExistentialSpecializationParamsRec( + ioSpecializationParams, + GetType(paramDeclRef)); } - /// Add information about a shader parameter to `ioParams` and `ioSlots` - static void _collectExistentialSlotsForShaderParam( + /// Collect any interface/existential specialization parameters for `paramDeclRef` into `ioParamInfo` and `ioSpecializationParams` + static void _collectExistentialSpecializationParamsForShaderParam( ShaderParamInfo& ioParamInfo, - ExistentialTypeSlots& ioSlots, + SpecializationParams& ioSpecializationParams, DeclRef paramDeclRef) { - Index startSlot = ioSlots.paramTypes.getCount(); - _collectExistentialTypeParamsRec(ioSlots, paramDeclRef); - Index endSlot = ioSlots.paramTypes.getCount(); + Index beginParamIndex = ioSpecializationParams.getCount(); + _collectExistentialSpecializationParamsRec(ioSpecializationParams, paramDeclRef); + Index endParamIndex = ioSpecializationParams.getCount(); - ioParamInfo.firstExistentialTypeSlot = UInt(startSlot); - ioParamInfo.existentialTypeSlotCount = UInt(endSlot - startSlot);; + ioParamInfo.firstSpecializationParamIndex = beginParamIndex; + ioParamInfo.specializationParamCount = endParamIndex - beginParamIndex; + } + + void EntryPoint::_collectGenericSpecializationParamsRec(Decl* decl) + { + if(!decl) + return; + + _collectGenericSpecializationParamsRec(decl->ParentDecl); + + auto genericDecl = as(decl); + if(!genericDecl) + return; + + for(auto m : genericDecl->Members) + { + if(auto genericTypeParam = as(m)) + { + SpecializationParam param; + param.flavor = SpecializationParam::Flavor::GenericType; + param.object = genericTypeParam; + m_genericSpecializationParams.add(param); + } + else if(auto genericValParam = as(m)) + { + 
SpecializationParam param; + param.flavor = SpecializationParam::Flavor::GenericValue; + param.object = genericValParam; + m_genericSpecializationParams.add(param); + } + } } /// Enumerate the existential-type parameters of an `EntryPoint`. @@ -9744,6 +9784,24 @@ namespace Slang /// void EntryPoint::_collectShaderParams() { + // We don't currently treat an entry point as having any + // *global* shader parameters. + // + // TODO: We could probably clean up the code a bit by treating + // an entry point as introducing a global shader parameter + // that is based on the implicit "parameters struct" type + // of the entry point itself. + + // We collect the generic parameters of the entry point, + // along with those of any outer generics first. + // + _collectGenericSpecializationParamsRec(getFuncDecl()); + + // After geneic specialization parameters have been collected, + // we look through the value parameters of the entry point + // function and see if any of them introduce existential/interface + // specialization parameters. + // // Note: we defensively test whether there is a function decl-ref // because this routine gets called from the constructor, and // a "dummy" entry point will have a null pointer for the function. @@ -9755,9 +9813,9 @@ namespace Slang ShaderParamInfo shaderParamInfo; shaderParamInfo.paramDeclRef = paramDeclRef; - _collectExistentialSlotsForShaderParam( + _collectExistentialSpecializationParamsForShaderParam( shaderParamInfo, - m_existentialSlots, + m_existentialSpecializationParams, paramDeclRef); m_shaderParams.add(shaderParamInfo); @@ -9765,111 +9823,6 @@ namespace Slang } } -static bool shouldUseFalcorCustomSharedKeywordSemantics( - EntryPointGroup* entryPointGroup) -{ - if( !entryPointGroup->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics ) - return false; - - // As a sanity check, if we are being asked to lay out an - // empty entry-point group, then don't apply the convention. 
- // - if(entryPointGroup->getEntryPointCount() == 0) - return false; - - // Otherwise we will look at the first entry point in the group, - // and use that to determine whether it looks like we are compiling - // ray-tracing shaders or not. - // - switch( entryPointGroup->getEntryPoint(0)->getStage() ) - { - case Stage::AnyHit: - case Stage::Callable: - case Stage::ClosestHit: - case Stage::Intersection: - case Stage::Miss: - case Stage::RayGeneration: - return true; - - default: - return false; - } -} - - -static bool shouldUseFalcorCustomSharedKeywordSemantics( - Program* program) -{ - if( !program->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics ) - return false; - - // As a sanity check, if we are being asked to lay out a program - // with *no* entry points, then we don't apply the convention. - // - if(program->getEntryPointGroupCount() == 0) - return false; - - // Otherwise we let the first entry-point group determine if we should - // apply the policy for the entire program. - // - // Note: this could lead to confusing results if a `Program` mixes - // entry point groups for RT and non-RT pipelines, but that isn't - // a scenario we expect to come up and this whole routine is handling - // "do what I mean" semantics for legacy behavior. - // - return shouldUseFalcorCustomSharedKeywordSemantics(program->getEntryPointGroup(0)); -} - - -void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) -{ - // If and only if we are in the special mode for Falcor support - // where non-`shared` global shader parameters are actually - // supposed to go into the "local root signature," then we - // will consider such parameters as if they were entry-point - // parameters, attached to the group. 
- // - if( shouldUseFalcorCustomSharedKeywordSemantics(this) ) - { - for( auto module : getModuleDependencies() ) - { - auto moduleDecl = module->getModuleDecl(); - for( auto globalVar : moduleDecl->getMembersOfType() ) - { - // Don't consider globals that aren't shader parameters. - // - if(!isGlobalShaderParameter(globalVar)) - continue; - - // Don't consider global shader paramters that were marked - // `shared`, since that is how global-root-signature parameters - // are being specified. - // - if( globalVar->HasModifier() ) - continue; - - auto paramDeclRef = makeDeclRef(globalVar.Ptr()); - - ShaderParamInfo shaderParamInfo; - shaderParamInfo.paramDeclRef = paramDeclRef; - - ExistentialTypeSlots slots; - _collectExistentialSlotsForShaderParam( - shaderParamInfo, - slots, - paramDeclRef); - - if( slots.paramTypes.getCount() != 0 ) - { - sink->diagnose(globalVar, Diagnostics::typeParametersNotAllowedOnEntryPointGlobal, globalVar); - } - - m_shaderParams.add(shaderParamInfo); - } - } - } -} - // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( @@ -10015,6 +9968,7 @@ void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) // auto compileRequest = entryPointReq->getCompileRequest(); auto translationUnit = entryPointReq->getTranslationUnit(); + auto linkage = compileRequest->getLinkage(); auto sink = compileRequest->getSink(); auto translationUnitSyntax = translationUnit->getModuleDecl(); @@ -10133,8 +10087,8 @@ void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) // a more uniform representation in the AST? } - RefPtr entryPoint = EntryPoint::create( + linkage, makeDeclRef(entryPointFuncDecl), entryPointProfile); @@ -10548,11 +10502,97 @@ static bool doesParameterMatch( return true; } - /// Enumerate the existential-type parameters of a `Program`. - /// - /// Any parameters found will be added to the list of existential slots on `this`. 
- /// - void Program::_collectShaderParams(DiagnosticSink* sink) + void Module::_collectShaderParams() + { + auto moduleDecl = m_moduleDecl; + + // We are going to walk the global declarations in the body of the + // module, and use those to build up our lists of: + // + // * Global shader parameters + // * Specialization parameters (both generic and interface/existential) + // * Requirements (`import`ed modules) + // + // For requirements, we want to be careful to only + // add each required module once (in case the same + // module got `import`ed multiple times), so we + // will keep a set of the modules we've already + // seen and processed. + // + HashSet requiredModuleSet; + + for( auto globalDecl : moduleDecl->Members ) + { + if(auto globalVar = globalDecl.as()) + { + // We do not want to consider global variable declarations + // that don't represents shader parameters. This includes + // things like `static` globals and `groupshared` variables. + // + if(!isGlobalShaderParameter(globalVar)) + continue; + + // At this point we know we have a global shader parameter. + + GlobalShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); + + // We need to consider what specialization parameters + // are introduced by this shader parameter. This step + // fills in fields on `shaderParamInfo` so that we + // can assocaite specialization arguments supplied later + // with the correct parameter. + // + _collectExistentialSpecializationParamsForShaderParam( + shaderParamInfo, + m_specializationParams, + makeDeclRef(globalVar.Ptr())); + + m_shaderParams.add(shaderParamInfo); + } + else if( auto globalGenericParam = as(globalDecl) ) + { + // A global generic type parameter declaration introduces + // a suitable specialization parameter. 
+ // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::GenericType; + specializationParam.object = globalGenericParam; + m_specializationParams.add(specializationParam); + } + else if( auto importDecl = as(globalDecl) ) + { + // An `import` declaration creates a requirement dependency + // from this module to another module. + // + auto importedModule = getModule(importDecl->importedModuleDecl); + if(!requiredModuleSet.Contains(importedModule)) + { + requiredModuleSet.Add(importedModule); + m_requirements.add(importedModule); + } + } + } + } + + Index Module::getRequirementCount() + { + return m_requirements.getCount(); + } + + RefPtr Module::getRequirement(Index index) + { + return m_requirements[index]; + } + + void Module::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + { + visitor->visitModule(this, as(specializationInfo)); + } + + + /// Enumerate the parameters of a `LegacyProgram`. + void LegacyProgram::_collectShaderParams(DiagnosticSink* sink) { // We need to collect all of the global shader parameters // referenced by the compile request, and for each we @@ -10568,10 +10608,22 @@ static bool doesParameterMatch( // To deal with the first issue, we will maintain a map from a parameter // name to the index of an existing parameter with that name. // + // TODO: Eventually we should deprecate support for the + // deduplication feature of `LegacyProgram`, at which point + // this entire type and all its complications can be eliminated + // from the code (that includes a lot of support in the "parameter + // binding" step for shader parameters with multiple declarations). + // Until that point this type will have a fair amount of duplication + // with stuff in `Module` and `CompositeComponentType`. + + // We use a dictionary to keep track of any shader parameter + // we've already collected with a given name.
+ // Dictionary mapNameToParamIndex; - for( auto module : getModuleDependencies() ) + for( auto translationUnit : m_translationUnits ) { + auto module = translationUnit->getModule(); auto moduleDecl = module->getModuleDecl(); for( auto globalVar : moduleDecl->getMembersOfType() ) { @@ -10582,27 +10634,6 @@ static bool doesParameterMatch( if(!isGlobalShaderParameter(globalVar)) continue; - // HACK: In order to support existing policy in the Falcor - // application, we support a custom mode where only - // global variables marked as `shared` should be considered - // as global shader parameters (that go in the "global - // root signature" for DXR). - // - // TODO: Eliminate this special case once all of the client - // application code has been ported to use more general-purpose - // Slang mechanisms (e.g., entry-point `uniform` parameters). - // - if( shouldUseFalcorCustomSharedKeywordSemantics(this) ) - { - if( !globalVar->HasModifier() ) - { - // Skip a non-`shared` global for purposes of enumerating - // shader parameters. - // - continue; - } - } - // This declaration may represent the same logical parameter // as a declaration that came from a different translation unit. // If that is the case, we want to re-use the same `ShaderParamInfo` @@ -10646,9 +10677,9 @@ static bool doesParameterMatch( GlobalShaderParamInfo shaderParamInfo; shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); - _collectExistentialSlotsForShaderParam( + _collectExistentialSpecializationParamsForShaderParam( shaderParamInfo, - m_globalExistentialSlots, + m_specializationParams, makeDeclRef(globalVar.Ptr())); m_shaderParams.add(shaderParamInfo); @@ -10656,13 +10687,59 @@ static bool doesParameterMatch( } } - /// Create a `Program` to represent the compiled code. + /// Create a new component type based on `inComponentType`, but with all its requirements filled.
+ RefPtr fillRequirements( + ComponentType* inComponentType) + { + auto linkage = inComponentType->getLinkage(); + + // We are going to simplify things by solving the problem iteratively. + // If the current `componentType` has requirements for `A`, `B`, ... etc. + // then we will create a composite of `componentType`, `A`, `B`, ... + // and then see if the resulting composite has any requirements. + // + // This avoids the problem of trying to compute the transitive closure + // of the requirements relationship (while dealing with deduplication, + // etc.) + + RefPtr componentType = inComponentType; + for(;;) + { + auto requirementCount = componentType->getRequirementCount(); + if(requirementCount == 0) + break; + + List> allComponents; + allComponents.add(componentType); + + for(Index rr = 0; rr < requirementCount; ++rr) + { + auto requirement = componentType->getRequirement(rr); + allComponents.add(requirement); + } + + componentType = CompositeComponentType::create( + linkage, + allComponents); + } + return componentType; + } + + /// Create a component type to represent the "global scope" of a compile request. /// - /// The created program will comprise all of the translation - /// units that were compiled as part of the request, as - /// well as any entry points in those translation units. + /// This component type will include all the modules and their global + /// parameters from the compile request, but not anything specific + /// to any entry point functions. /// - RefPtr createUnspecializedProgram( + /// The layout for this component type will thus represent the things that + /// a user is likely to want to have stay the same across all compiled + /// entry points. + /// + /// The component type that this function creates is unspecialized, in + /// that it doesn't take into account any specialization arguments + /// that might have been supplied as part of the compile request.
+ /// + RefPtr createUnspecializedGlobalComponentType( FrontEndCompileRequest* compileRequest) { // We want our resulting program to depend on @@ -10677,16 +10754,51 @@ static bool doesParameterMatch( // auto linkage = compileRequest->getLinkage(); auto sink = compileRequest->getSink(); - auto program = new Program(linkage); - for(auto translationUnit : compileRequest->translationUnits ) + + RefPtr globalComponentType; + if(compileRequest->translationUnits.getCount() == 1) { - program->addReferencedLeafModule(translationUnit->getModule()); + // The common case is that a compilation only uses + // a single translation unit, and thus results in + // a single `Module`. We can then use that module + // as the component type that represents the global scope. + // + globalComponentType = compileRequest->translationUnits[0]->getModule(); } - for(auto translationUnit : compileRequest->translationUnits ) + else { - program->addReferencedModule(translationUnit->getModule()); + globalComponentType = new LegacyProgram( + linkage, + compileRequest->translationUnits, + sink); } + return fillRequirements(globalComponentType); + } + + /// Create a component type that represents the global scope for a compile request, + /// along with any entry point functions. + /// + /// The resulting component type will include the global-scope information + /// first, so its layout will be compatible with the result of + /// `createUnspecializedGlobalComponentType`. + /// + /// The new component type will also add on any entry-point functions + /// that were requested and will thus include space for their `uniform` parameters. + /// If multiple entry points were requested then they will be given non-overlapping + /// parameter bindings, consistent with them being used together in + /// a single pipeline state, hit group, etc. + /// + /// The result of this function is unspecialized and doesn't take into + /// account any specialization arguments the user might have supplied. 
+ /// + RefPtr createUnspecializedGlobalAndEntryPointsComponentType( + FrontEndCompileRequest* compileRequest) + { + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + + auto globalComponentType = compileRequest->getGlobalComponentType(); // The validation of entry points here will be modal, and controlled // by whether the user specified any entry points directly via @@ -10699,6 +10811,9 @@ static bool doesParameterMatch( // bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; + List> allComponentTypes; + allComponentTypes.add(globalComponentType); + if( anyExplicitEntryPoints ) { // If there were any explicit requests for entry points to be @@ -10716,8 +10831,9 @@ static bool doesParameterMatch( // but didn't specify any groups (since the current // compilation API doesn't allow for grouping). // - program->addEntryPoint(entryPoint, sink); entryPointReq->getTranslationUnit()->entryPoints.add(entryPoint); + + allComponentTypes.add(entryPoint); } } @@ -10777,6 +10893,7 @@ static bool doesParameterMatch( profile.setStage(entryPointAttr->stage); RefPtr entryPoint = EntryPoint::create( + linkage, makeDeclRef(funcDecl), profile); @@ -10792,179 +10909,340 @@ static bool doesParameterMatch( // group, so that its entry-point parameters lay out // independent of the others. 
// - program->addEntryPoint(entryPoint, sink); translationUnit->entryPoints.add(entryPoint); + + allComponentTypes.add(entryPoint); } } } - program->_collectShaderParams(sink); - - return program; + if(allComponentTypes.getCount() > 1) + { + auto composite = CompositeComponentType::create( + linkage, + allComponentTypes); + return composite; + } + else + { + return globalComponentType; + } } - static void _specializeExistentialTypeParams( - Linkage* linkage, - ExistentialTypeSlots& ioSlots, - List> const& args, + RefPtr Module::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, DiagnosticSink* sink) { - Index slotCount = ioSlots.paramTypes.getCount(); - Index argCount = args.getCount(); + SLANG_ASSERT(argCount == getSpecializationParamCount()); - if( slotCount != argCount ) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchExistentialSlotArgCount, slotCount, argCount); - return; - } + SemanticsVisitor visitor(getLinkage(), sink); - SemanticsVisitor visitor(linkage, sink); + RefPtr specializationInfo = new Module::ModuleSpecializationInfo(); - for( Index ii = 0; ii < slotCount; ++ii ) + for( Index ii = 0; ii < argCount; ++ii ) { - auto slotType = ioSlots.paramTypes[ii]; - auto argExpr = args[ii]; + auto& arg = args[ii]; + auto& param = m_specializationParams[ii]; - auto argType = checkProperType(linkage, TypeExp(argExpr), sink); - if(!argType) + auto argType = arg.val.as(); + SLANG_ASSERT(argType); + + switch( param.flavor ) { - // TODO: Each slot should track a source location and/or a `VarDeclBase` - // that names the parameter that the slot corresponds to. 
+ case SpecializationParam::Flavor::GenericType: + { + auto genericTypeParamDecl = param.object.as(); + SLANG_ASSERT(genericTypeParamDecl); - sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgNotAType, ii); - return; + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor`). + // + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + + // As a quick sanity check, see if the argument that is being supplied for a + // global generic type parameter is a reference to *another* global generic + // type parameter, since that should always be an error. + // + if( auto argDeclRefType = argType.as() ) + { + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as()) + { + if(argGenericParamDeclRef.getDecl() == genericTypeParamDecl) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(genericTypeParamDecl, + Diagnostics::cannotSpecializeGlobalGenericToItself, + genericTypeParamDecl->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. 
+ sink->diagnose(genericTypeParamDecl, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + genericTypeParamDecl->getName(), + argGenericParamDeclRef.GetName()); + continue; + } + } + } + + ModuleSpecializationInfo::GenericArgInfo genericArgInfo; + genericArgInfo.paramDecl = genericTypeParamDecl; + genericArgInfo.argVal = argType; + specializationInfo->genericArgs.add(genericArgInfo); + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraintDecl : genericTypeParamDecl->getMembersOfType()) + { + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef(constraintDecl, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance + auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(genericTypeParamDecl, + Diagnostics::typeArgumentForGenericParameterDoesNotConformToInterface, + argType, + genericTypeParamDecl->nameAndLoc.name, + interfaceType); + } + + ModuleSpecializationInfo::GenericArgInfo constraintArgInfo; + constraintArgInfo.paramDecl = constraintDecl; + constraintArgInfo.argVal = witness; + specializationInfo->genericArgs.add(constraintArgInfo); + } + } + break; + + case SpecializationParam::Flavor::ExistentialType: + { + auto interfaceType = param.object.as(); + SLANG_ASSERT(interfaceType); + + auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. 
+ sink->diagnose(SourceLoc(), + Diagnostics::typeArgumentDoesNotConformToInterface, + argType, + interfaceType); + } + + ExpandedSpecializationArg expandedArg; + expandedArg.val = argType; + expandedArg.witness = witness; + + specializationInfo->existentialArgs.add(expandedArg); + } + break; + + default: + SLANG_UNEXPECTED("unhandled specialization parameter flavor"); } + } + return specializationInfo; + } - auto witness = visitor.tryGetSubtypeWitness(argType, slotType); - if (!witness) + + static void _extractSpecializationArgs( + ComponentType* componentType, + List> const& argExprs, + List& outArgs, + DiagnosticSink* sink) + { + auto linkage = componentType->getLinkage(); + + auto argCount = argExprs.getCount(); + for(Index ii = 0; ii < argCount; ++ii ) + { + auto argExpr = argExprs[ii]; + auto paramInfo = componentType->getSpecializationParam(ii); + + // TODO: We should support non-type arguments here + + auto argType = checkProperType(linkage, TypeExp(argExpr), sink); + if( !argType ) { // If no witness was found, then we will be unable to satisfy // the conformances required. - sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgDoesNotConform, ii, slotType); - return; + sink->diagnose(argExpr, + Diagnostics::expectedAType, + argExpr->type); + continue; } - ExistentialTypeSlots::Arg arg; - arg.type = argType; - arg.witness = witness; - ioSlots.args.add(arg); + SpecializationArg arg; + arg.val = argType; + outArgs.add(arg); } } - void EntryPoint::_specializeExistentialTypeParams( - List> const& args, + RefPtr EntryPoint::_validateSpecializationArgsImpl( + SpecializationArg const* inArgs, + Index inArgCount, DiagnosticSink* sink) { - Slang::_specializeExistentialTypeParams(getLinkage(), m_existentialSlots, args, sink); - } + auto args = inArgs; + auto argCount = inArgCount; - /// Create a specialization an existing entry point based on generic arguments. 
- RefPtr createSpecializedEntryPoint( - EntryPoint* unspecializedEntryPoint, - List> const& genericArgs, - List> const& existentialArgs, - DiagnosticSink* sink) - { - auto linkage = unspecializedEntryPoint->getLinkage(); + SemanticsVisitor visitor(getLinkage(), sink); - // TODO: Need to be careful in case entry point already has a decl-ref, - // pertaining to outer specializations (e.g., when entry point was - // nested in a generic type. + // The first N arguments will be for the explicit generic parameters + // of the entry point (if it has any). // - auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + auto genericSpecializationParamCount = getGenericSpecializationParamCount(); + SLANG_ASSERT(argCount >= genericSpecializationParamCount); - SemanticsVisitor semantics( - linkage, - sink); + Result result = SLANG_OK; - DeclRef entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl.Ptr()); - if( auto genericDecl = as(entryPointFuncDecl->ParentDecl) ) + RefPtr info = new EntryPointSpecializationInfo(); + + DeclRef specializedFuncDeclRef = m_funcDeclRef; + if(genericSpecializationParamCount) { - // We will construct a suitable `GenericAppExpr` to represent - // the user-specified `genericDecl` being applied to the - // supplied `genericArgs`, and then use the existing - // semantic checking logic that would apply to an explicit - // generic application like `F` if it were - // encountered in the source code. + // We need to construct a generic application and use + // the semantic checking machinery to expand out + // the rest of the arguments via inference... - auto session = linkage->getSessionImpl(); - auto genericDeclRef = makeDeclRef(genericDecl); + auto genericDeclRef = m_funcDeclRef.GetParent().as(); + SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - // The first pieces is a `VarExpr` that refers to `genericDecl`. 
- // - // TODO: This would not be needed if we instead parsed - // the supplied entry-point name into an expression - // earlier in this function. - // - RefPtr genericExpr = new VarExpr(); - genericExpr->declRef = genericDeclRef; - genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef); - - // Next we construct the actual `GenericAppExpr` - // - RefPtr genericAppExpr = new GenericAppExpr(); - genericAppExpr->FunctionExpr = genericExpr; - genericAppExpr->Arguments = genericArgs; + RefPtr genericSubst = new GenericSubstitution(); + genericSubst->outer = genericDeclRef.substitutions.substitutions; + genericSubst->genericDecl = genericDeclRef.getDecl(); - // We use the semantics visitor to perform the - // actual checking logic (this might report - // errors) - // - auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr); - - // Now we need to extract an appropriate decl-ref for the entry - // point from the `checkedExpr`. - // - if( auto declRefExpr = checkedExpr.as() ) + for(Index ii = 0; ii < genericSpecializationParamCount; ++ii) { - // TODO: We should eventually check for the case - // where we have a `MemberExpr` or another case of - // `DeclRefExpr` that cannot be summarized as just - // its decl-ref. - // - // The basic `VarExpr` and `StaticMemberExpr` cases - // should be allow-able. - - entryPointFuncDeclRef = declRefExpr->declRef.as(); + auto specializationArg = args[ii]; + genericSubst->args.add(specializationArg.val); } - else if( semantics.IsErrorExpr(checkedExpr) ) + + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) { - // Any semantic error that occured should have been - // reported already. 
- return nullptr; + auto constraintSubst = genericDeclRef.substitutions; + constraintSubst.substitutions = genericSubst; + + DeclRef constraintDeclRef( + constraintDecl, constraintSubst); + + auto sub = GetSub(constraintDeclRef); + auto sup = GetSup(constraintDeclRef); + + auto subTypeWitness = visitor.tryGetSubtypeWitness(sub, sup); + if(subTypeWitness) + { + genericSubst->args.add(subTypeWitness); + } + else + { + // TODO: diagnose a problem here + sink->diagnose(constraintDecl, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); + result = SLANG_FAIL; + continue; + } } - else + + specializedFuncDeclRef.substitutions.substitutions = genericSubst; + } + + info->specializedFuncDeclRef = specializedFuncDeclRef; + + // Once the generic parameters (if any) have been dealt with, + // any remaining specialization arguments are for existential/interface + // specialization parameters, attached to the value parameters + // of the entry point. + // + args += genericSpecializationParamCount; + argCount -= genericSpecializationParamCount; + + auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); + SLANG_ASSERT(argCount == existentialSpecializationParamCount); + + for( Index ii = 0; ii < existentialSpecializationParamCount; ++ii ) + { + auto& param = m_existentialSpecializationParams[ii]; + auto& specializationArg = args[ii]; + + // TODO: We need to handle all the cases of "flavor" for the `param`s (not just types) + + auto paramType = param.object.as(); + auto argType = specializationArg.val.as(); + + auto witness = visitor.tryGetSubtypeWitness(argType, paramType); + if (!witness) { - // The result of specializing a reference to a generic - // function should always be a `DeclRefExpr` - // - SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); - UNREACHABLE_RETURN(nullptr); + // If no witness was found, then we will be unable to satisfy + // the conformances required. 
+ sink->diagnose(SourceLoc(), Diagnostics::typeArgumentDoesNotConformToInterface, argType, paramType); + result = SLANG_FAIL; + continue; + } + + ExpandedSpecializationArg expandedArg; + expandedArg.val = specializationArg.val; + expandedArg.witness = witness; + info->existentialSpecializationArgs.add(expandedArg); } - RefPtr specializedEntryPoint = EntryPoint::create( - entryPointFuncDeclRef, - unspecializedEntryPoint->getProfile()); + return info; + } - // Next we need to validate the existential arguments. - specializedEntryPoint->_specializeExistentialTypeParams(existentialArgs, sink); + /// Create a specialization of an existing entry point based on specialization argument expressions. + RefPtr createSpecializedEntryPoint( + EntryPoint* unspecializedEntryPoint, + List> const& argExprs, + DiagnosticSink* sink) + { + // We need to convert all of the `Expr` arguments + // into `SpecializationArg`s, so that we can bottleneck + // through the shared logic. + // + List args; + _extractSpecializationArgs(unspecializedEntryPoint, argExprs, args, sink); + if(sink->GetErrorCount()) + return nullptr; - return specializedEntryPoint; + return unspecializedEntryPoint->specialize( + args.getBuffer(), + args.getCount(), + sink); } - /// Parse an array of strings as generic arguments. + /// Parse an array of strings as specialization arguments. /// /// Names in the strings will be parsed in the context of /// the code loaded into the given compile request.
/// - void parseGenericArgStrings( + void parseSpecializationArgStrings( EndToEndCompileRequest* endToEndReq, List const& genericArgStrings, List>& outGenericArgs) { - auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); + auto unspecialiedProgram = endToEndReq->getUnspecializedGlobalComponentType(); // TODO: Building a list of `scopesToTry` here shouldn't // be required, since the `Scope` type itself has the ability @@ -11004,351 +11282,109 @@ static bool doesParameterMatch( } } + if(!argExpr) + { + sink->diagnose(SourceLoc(), Diagnostics::internalCompilerError, "couldn't parse specialization argument"); + return; + } + outGenericArgs.add(argExpr); } } - void Program::_specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink) - { - Slang::_specializeExistentialTypeParams(getLinkageImpl(), m_globalExistentialSlots, args, sink); - } - Type* Linkage::specializeType( Type* unspecializedType, Int argCount, Type* const* args, DiagnosticSink* sink) { + SLANG_ASSERT(unspecializedType); + // TODO: We should cache and re-use specialized types // when the exact same arguments are provided again later. 
SemanticsVisitor visitor(this, sink); + SpecializationParams specializationParams; + _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType); - ExistentialTypeSlots slots; - _collectExistentialTypeParamsRec(slots, unspecializedType); - - assert(slots.paramTypes.getCount() == argCount); + assert(specializationParams.getCount() == argCount); + ExpandedSpecializationArgs specializationArgs; for( Int aa = 0; aa < argCount; ++aa ) { + auto paramType = specializationParams[aa].object.as(); auto argType = args[aa]; - ExistentialTypeSlots::Arg arg; - arg.type = argType; - arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]); - slots.args.add(arg); + ExpandedSpecializationArg arg; + arg.val = argType; + arg.witness = visitor.tryGetSubtypeWitness(argType, paramType); + specializationArgs.add(arg); } RefPtr specializedType = new ExistentialSpecializedType(); specializedType->baseType = unspecializedType; - specializedType->slots = slots; + specializedType->args = specializationArgs; m_specializedTypes.add(specializedType); return specializedType; } - // Shared implementation logic for the `_createSpecializedProgram*` entry points. - static RefPtr _createSpecializedProgramImpl( + /// Shared implementation logic for the `_createSpecializedProgram*` entry points. + static RefPtr _createSpecializedProgramImpl( Linkage* linkage, - Program* unspecializedProgram, - List> const& globalGenericArgs, - List> const& globalExistentialArgs, + ComponentType* unspecializedProgram, + List> const& specializationArgExprs, DiagnosticSink* sink) { - // TODO: If there are no specialization arguments, + // If there are no specialization arguments, // then the the result of specialization should - // be the same as the input... *but* we are promising - // to return a program without any entry points, - // and that screws things up here. - // - // For now we just carefully avod this early-exit - // if we have any entry points. 
+ // be the same as the input. // - // Eventually we should try to revise the model so - // that specialization of a program that includes - // entry points can make some kind of sense. - // - if( globalGenericArgs.getCount() == 0 - && globalExistentialArgs.getCount() == 0 - && unspecializedProgram->getEntryPointCount() == 0) + auto specializationArgCount = specializationArgExprs.getCount(); + if( specializationArgCount == 0 ) { return unspecializedProgram; } - // The given `unspecializedProgram` should be one that - // was checked through the front-end, so that now we - // only need to check if the given arguments can satisfy - // the requirements of the global generic parameters. - // - // The new program needs to start off with the same - // module dependency list as the original. - // - RefPtr specializedProgram = new Program(linkage); - for(auto module : unspecializedProgram->getModuleDependencies()) - { - specializedProgram->addReferencedLeafModule(module); - } - - - // We will collect all the global generic parameters - // defined in the modules being referenced, to find - // the global generic parameter signature of the - // program. - // - // TODO: Note that this doesn't handle the case where one - // or more of the type *arguments* that we are specifying - // ends up requiring additional modules to be referenced, - // which might in turn introduce new global generic parameters. - // - List> globalGenericParams; - for(auto module : unspecializedProgram->getModuleDependencies()) + auto specializationParamCount = unspecializedProgram->getSpecializationParamCount(); + if(specializationArgCount != specializationParamCount ) { - for(auto param : module->getModuleDecl()->getMembersOfType()) - globalGenericParams.add(param); - } - - // Next, we will check whether the supplied arguments can - // satisfy those parameters. - // - // An easy early-out case will be if the number of - // arguments isn't correct. 
- // - if (globalGenericParams.getCount() != globalGenericArgs.getCount()) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, - globalGenericParams.getCount(), - globalGenericArgs.getCount()); + sink->diagnose(SourceLoc(), Diagnostics::mismatchSpecializationArguments, + specializationParamCount, + specializationArgCount); return nullptr; } - // We have an appropriate number of arguments for the global generic parameters, + // We have an appropriate number of arguments for the global specialization parameters, // and now we need to check that the arguments conform to the declared constraints. // SemanticsVisitor visitor(linkage, sink); - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. - // - RefPtr globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - // - Index argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) - { - // Get the argument that matches this parameter. 
- Index argIndex = argCounter++; - SLANG_ASSERT(argIndex < globalGenericArgs.getCount()); - auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); - if (!globalGenericArg) - { - sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); - return nullptr; - } - - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = globalGenericArg.as() ) - { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as()) - { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - continue; - } - } - } - - // Create a substitution for this parameter/argument. - RefPtr subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; - - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. 
- for(auto constraint : globalGenericParam->getMembersOfType()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef(constraint, nullptr)); - - // Use our semantic-checking logic to search for a witness to the required conformance - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); - } - - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.add(constraintArg); - } - - // Add the substitution for this parameter to the global substitution - // set that we are building. - - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; - } + List specializationArgs; + _extractSpecializationArgs(unspecializedProgram, specializationArgExprs, specializationArgs, sink); if(sink->GetErrorCount()) return nullptr; - specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - - return specializedProgram; - } - - /// Create a specialized copy of `unspecializedProgram`. - /// - /// The specialized program will include the entry points - /// from the original program (whether thsoe entry points - /// are specialized or not). 
- /// - static RefPtr _createSpecializedProgram( - Linkage* linkage, - Program* unspecializedProgram, - List> const& globalGenericArgs, - List> const& globalExistentialArgs, - DiagnosticSink* sink) - { - auto specializedProgram = _createSpecializedProgramImpl( - linkage, - unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, + auto specializedProgram = unspecializedProgram->specialize( + specializationArgs.getBuffer(), + specializationArgs.getCount(), sink); - // We need to ensure that the specialized program has whatever - // entry points (and groups) the unspecialized one had. - // - for( auto entryPointGroup : unspecializedProgram->getEntryPointGroups() ) - { - specializedProgram->addEntryPointGroup(entryPointGroup); - } - - // Now deal with the shader parameters and existential arguments - // - // Note: We should in theory be able to just copy over the shader - // parameters and existential slot information from the unspecialized - // program. This could save some time, but it would also mean that - // the only way to create a specialized program is by creating an - // unspecialized on first, which is maybe not always desirable. 
- // - specializedProgram->_collectShaderParams(sink); - specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); - return specializedProgram; } - SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::specializeProgram( - slang::IProgram* inUnspecializedProgram, - SlangInt specializationArgCount, - slang::SpecializationArg const* specializationArgs, - slang::IProgram** outSpecializedProgram, - ISlangBlob** outDiagnostics) - { - auto unspecializedProgram = asInternal(inUnspecializedProgram); - - if( specializationArgCount == 0 ) - { - *outSpecializedProgram = ComPtr(asExternal(unspecializedProgram)).detach(); - return SLANG_OK; - } - - List> globalGenericArgs; - - List> globalExistentialArgs; - for( Int ii = 0; ii < specializationArgCount; ++ii ) - { - auto& specializationArg = specializationArgs[ii]; - switch( specializationArg.kind ) - { - case slang::SpecializationArg::Kind::Type: - { - auto typeArg = asInternal(specializationArg.type); - RefPtr argExpr = new SharedTypeExpr(); - argExpr->base = TypeExp(typeArg); - argExpr->type = QualType(getTypeType(typeArg)); - - globalExistentialArgs.add(argExpr); - } - break; - - default: - return SLANG_E_INVALID_ARG; - } - } - - DiagnosticSink sink(getSourceManager()); - auto specializedProgram = _createSpecializedProgram( - this, - unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, - &sink); - sink.getBlobIfNeeded(outDiagnostics); - - if(!specializedProgram) - return SLANG_FAIL; - - *outSpecializedProgram = ComPtr(asExternal(specializedProgram)).detach(); - return SLANG_OK; - } - - /// Specialize an entry point that was checked by the front-end, based on generic arguments. + /// Specialize an entry point that was checked by the front-end, based on specialization arguments. 
/// - /// If the end-to-end compile request included generic argument strings + /// If the end-to-end compile request included specialization argument strings /// for this entry point, then they will be parsed, checked, and used /// as arguments to the generic entry point. /// /// Returns a specialized entry point if everything worked as expected. /// Returns null and diagnoses errors if anything goes wrong. /// - RefPtr createSpecializedEntryPoint( + RefPtr createSpecializedEntryPoint( EndToEndCompileRequest* endToEndReq, EntryPoint* unspecializedEntryPoint, EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) @@ -11359,33 +11395,35 @@ static bool doesParameterMatch( // If the user specified generic arguments for the entry point, // then we will need to parse the arguments first. // - List> genericArgs; - parseGenericArgStrings( - endToEndReq, - entryPointInfo.genericArgStrings, - genericArgs); - - List> existentialArgs; - parseGenericArgStrings( + List> specializationArgExprs; + parseSpecializationArgStrings( endToEndReq, - entryPointInfo.existentialArgStrings, - existentialArgs); + entryPointInfo.specializationArgStrings, + specializationArgExprs); // Next we specialize the entry point function given the parsed // generic argument expressions. // auto entryPoint = createSpecializedEntryPoint( unspecializedEntryPoint, - genericArgs, - existentialArgs, + specializationArgExprs, sink); return entryPoint; } - /// Create a specialized program based on the given compile request. + /// Create a specialized component type for the global scope of the given compile request. + /// + /// The specialized program will be consistent with that created by + /// `createUnspecializedGlobalComponentType`, and will simply fill in + /// its specialization parameters with the arguments (if any) supplied + /// as part of the end-to-end compile request.
+ /// + /// The layout of the new component type will be consistent with that + /// of the original *if* there are no global generic type parameters + /// (only interface/existential parameters). /// - RefPtr createSpecializedProgram( + RefPtr createSpecializedGlobalComponentType( EndToEndCompileRequest* endToEndReq) { // The compile request must have already completed front-end processing, @@ -11393,36 +11431,32 @@ static bool doesParameterMatch( // to parse and check any generic arguments that are being supplied for // global or entry-point generic parameters. // - auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + auto unspecializedProgram = endToEndReq->getUnspecializedGlobalComponentType(); auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); - // First, let's parse the generic argument strings that were - // provided via the API, so taht we can match them + // First, let's parse the specialization argument strings that were + // provided via the API, so that we can match them // against what was declared in the program. // - List> globalGenericArgs; - parseGenericArgStrings( + List> globalSpecializationArgs; + parseSpecializationArgStrings( endToEndReq, - endToEndReq->globalGenericArgStrings, - globalGenericArgs); + endToEndReq->globalSpecializationArgStrings, + globalSpecializationArgs); - // Also handle global existential type arguments. - List> globalExistentialArgs; - parseGenericArgStrings( - endToEndReq, - endToEndReq->globalExistentialSlotArgStrings, - globalExistentialArgs); + // Don't proceed further if anything failed to parse. + if(sink->GetErrorCount()) + return nullptr; // Now we create the initial specialized program by // applying the global generic arguments (if any) to the // unspecialized program. 
// - auto sink = endToEndReq->getSink(); auto specializedProgram = _createSpecializedProgramImpl( - endToEndReq->getLinkage(), + linkage, unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, + globalSpecializationArgs, sink); // If anything went wrong with the global generic @@ -11450,33 +11484,69 @@ static bool doesParameterMatch( endToEndReq->entryPoints.setCount(entryPointCount); } - Index entryPointCounter = 0; + return specializedProgram; + } - for( auto unspecializedEntryPointGroup : unspecializedProgram->getEntryPointGroups() ) - { - List> specializedEntryPoints; - for( auto unspecializedEntryPoint : unspecializedEntryPointGroup->getEntryPoints() ) - { - Index entryPointIndex = entryPointCounter++; - auto& entryPointInfo = endToEndReq->entryPoints[entryPointIndex]; + /// Create a specialized program based on the given compile request. + /// + /// The specialized program created here includes both the global + /// scope for all the translation units involved and all the entry + /// points, and it also includes any specialization arguments + /// that were supplied. + /// + /// It is important to note that this function specializes + /// the global scope and the entry points in isolation and then + /// composes them, and that this can lead to different layout + /// from the result of `createUnspecializedGlobalAndEntryPointsComponentType`. + /// + /// If we have a module `M` with entry point `E`, and each has one + /// specialization parameter, then `createUnspecialized...` will yield: + /// + /// compose(M,E) + /// + /// That composed type will have two specialization parameters (the one + /// from `M` plus the one from `E`) and so we might specialize it to get: + /// + /// specialize(compose(M,E), X, Y) + /// + /// while if we use `createSpecialized...` we will get: + /// + /// compose(specialize(M,X), specialize(E,Y)) + /// + /// While these options are semantically equivalent, they would not lay + /// out the same way in memory. 
+ /// + /// There are many reasons why an application might prefer one over the + /// other, and an application that cares should use the more explicit + /// APIs to construct what they want. The behavior of this function + /// is just to provide a reasonable default for use by end-to-end + /// compilation (e.g., from the command line). + /// + RefPtr createSpecializedGlobalAndEntryPointsComponentType( + EndToEndCompileRequest* endToEndReq) + { + auto specializedGlobalComponentType = endToEndReq->getSpecializedGlobalComponentType(); - auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); - specializedEntryPoints.add(specializedEntryPoint); - } + List> allComponentTypes; + allComponentTypes.add(specializedGlobalComponentType); - RefPtr specializedEntryPointGroup = EntryPointGroup::create(linkage, specializedEntryPoints, endToEndReq->getSink()); - specializedProgram->addEntryPointGroup(specializedEntryPointGroup); - } + auto unspecializedGlobalAndEntryPointsComponentType = endToEndReq->getUnspecializedGlobalAndEntryPointsComponentType(); + auto entryPointCount = unspecializedGlobalAndEntryPointsComponentType->getEntryPointCount(); - // Finalize the information for the specialized program, - // now that we have computed its entry point list, etc. 
- // - specializedProgram->_collectShaderParams(sink); - specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); + for(Index ii = 0; ii < entryPointCount; ++ii) + { + auto& entryPointInfo = endToEndReq->entryPoints[ii]; + auto unspecializedEntryPoint = unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii); - return specializedProgram; + auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + allComponentTypes.add(specializedEntryPoint); + } + + RefPtr composed = CompositeComponentType::create(endToEndReq->getLinkage(), allComponentTypes); + return composed; } + void checkTranslationUnit( TranslationUnitRequest* translationUnit) { @@ -11488,6 +11558,8 @@ static bool doesParameterMatch( // checking that is required on all declarations // in the translation unit. visitor.checkDecl(translationUnit->getModuleDecl()); + + translationUnit->getModule()->_collectShaderParams(); } diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 64afef136..1c0bed065 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -199,10 +199,12 @@ namespace Slang // RefPtr EntryPoint::create( + Linkage* linkage, DeclRef funcDeclRef, Profile profile) { RefPtr entryPoint = new EntryPoint( + linkage, funcDeclRef.GetName(), profile, funcDeclRef); @@ -210,10 +212,12 @@ namespace Slang } RefPtr EntryPoint::createDummyForPassThrough( + Linkage* linkage, Name* name, Profile profile) { RefPtr entryPoint = new EntryPoint( + linkage, name, profile, DeclRef()); @@ -221,103 +225,93 @@ namespace Slang } EntryPoint::EntryPoint( + Linkage* linkage, Name* name, Profile profile, DeclRef funcDeclRef) - : m_name(name) + : ComponentType(linkage) + , m_name(name) , m_profile(profile) , m_funcDeclRef(funcDeclRef) { - // In order for later code generation to work, we need to track what - // modules each entry point depends on. 
We will build up the dependency - // list here when an `EntryPoint` gets created. - // - // We know an entry point depends on the module that declared the - // entry-point function itself. - // - // Note: we are carefully handling the case where `module` could - // be null, becase of "dummy" entry points created for pass-through - // compilation. + // Collect any specialization parameters used by the entry point // - if(auto module = getModule()) + _collectShaderParams(); + } + + Module* EntryPoint::getModule() + { + return Slang::getModule(getFuncDecl()); + } + + Index EntryPoint::getSpecializationParamCount() + { + return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount(); + } + + SpecializationParam const& EntryPoint::getSpecializationParam(Index index) + { + auto genericParamCount = m_genericSpecializationParams.getCount(); + if(index < genericParamCount) { - m_dependencyList.addDependency(module); + return m_genericSpecializationParams[index]; } - // - // TODO: We also need to include the modules needed by any generic - // arguments in the dependency list, since in the general case they - // might come from modules other than the one defining the entry point. + else + { + return m_existentialSpecializationParams[index - genericParamCount]; + } + } - // The following is a bit of a hack. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a program (and the stuff it imports). - // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point. + Index EntryPoint::getRequirementCount() + { + // The only requirement of an entry point is the module that contains it. 
// - // A longer-term strategy might need to consider any (tagged or untagged) - // union types that get used inside of a module, and also take - // those lists into account. + // TODO: We will eventually want to support the case of an entry + // point nested in a `struct` type, in which case there should be + // a single requirement representing that outer type (so that multiple + // entry points nested under the same type can share the storage + // for parameters at that scope). + + // Note: the defensive coding is here because the + // "dummy" entry points we create for pass-through + // compilation will not have an associated module. // - // An even longer-term strategy would be to allow type layout to - // be performed on IR types, so taht we don't need to have front-end - // code worrying about this stuff. - // - for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer ) + if( auto module = getModule() ) { - if( auto genericSubst = as(subst) ) - { - for( auto arg : genericSubst->args ) - { - if( auto taggedUnionType = as(arg) ) - { - m_taggedUnionTypes.add(taggedUnionType); - } - } - } + return 1; } - - // Collect any existential-type parameters used by the entry point - // - _collectShaderParams(); + return 0; } - Module* EntryPoint::getModule() + RefPtr EntryPoint::getRequirement(Index index) { - return Slang::getModule(getFuncDecl()); + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + SLANG_ASSERT(getModule()); + return getModule(); } - Linkage* EntryPoint::getLinkage() + void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { - return getModule()->getLinkage(); + visitor->visitEntryPoint(this, as(specializationInfo)); } - // - // EntryPointGroup - // - - RefPtr EntryPointGroup::create( - Linkage* linkage, - List> const& entryPoints, - DiagnosticSink* sink) + List const& EntryPoint::getModuleDependencies() { - RefPtr group = new EntryPointGroup(linkage); - - for( auto 
entryPoint : entryPoints ) - { - for( auto module : entryPoint->getModuleDependencies() ) - { - group->m_dependencyList.addDependency(module); - } - group->m_entryPoints.add(entryPoint); - } + if(auto module = getModule()) + return module->getModuleDependencies(); - group->_collectShaderParams(sink); + static List empty; + return empty; + } - return group; + List const& EntryPoint::getFilePathDependencies() + { + if(auto module = getModule()) + return getModule()->getFilePathDependencies(); + + static List empty; + return empty; } // @@ -1949,7 +1943,7 @@ SlangResult dissassembleDXILUsingDXC( TargetRequest* targetReq, Int entryPointIndex) { - auto program = compileRequest->getSpecializedProgram(); + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); auto targetProgram = program->getTargetProgram(targetReq); auto backEndReq = compileRequest->getBackEndReq(); @@ -2020,7 +2014,7 @@ SlangResult dissassembleDXILUsingDXC( return result; RefPtr backEndRequest = new BackEndCompileRequest( - m_program->getLinkageImpl(), + m_program->getLinkage(), sink, m_program); @@ -2083,7 +2077,7 @@ SlangResult dissassembleDXILUsingDXC( if (compileRequest->isCommandLineCompile) { auto linkage = compileRequest->getLinkage(); - auto program = compileRequest->getSpecializedProgram(); + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); for (auto targetReq : linkage->targets) { Index entryPointCount = program->getEntryPointCount(); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b7f62ae5b..46b257fe3 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -116,7 +116,6 @@ namespace Slang class Linkage; class Module; - class Program; class FrontEndCompileRequest; class BackEndCompileRequest; class EndToEndCompileRequest; @@ -146,13 +145,19 @@ namespace Slang struct ShaderParamInfo { DeclRef paramDeclRef; - UInt firstExistentialTypeSlot = 0; - UInt existentialTypeSlotCount = 
0; + Int firstSpecializationParamIndex = 0; + Int specializationParamCount = 0; }; /// Extended information specific to global shader parameters struct GlobalShaderParamInfo : ShaderParamInfo { + // TODO: This type should be eliminated if/when we remove + // support for compilation with multiple translation units + // that all declare the "same" shader parameter (e.g., a + // `cbuffer`) and expect those duplicate declarations + // to get the same parameter binding/layout. + // Additional global-scope declarations that are conceptually // declaring the "same" parameter as the `paramDeclRef`. List> additionalParamDeclRefs; @@ -203,7 +208,7 @@ namespace Slang { public: /// Get the list of modules that are depended on. - List> const& getModuleList() { return m_moduleList; } + List const& getModuleList() { return m_moduleList; } /// Add a module and everything it depends on to the list. void addDependency(Module* module); @@ -214,8 +219,8 @@ namespace Slang private: void _addDependency(Module* module); - List> m_moduleList; - HashSet m_moduleSet; + List m_moduleList; + HashSet m_moduleSet; }; /// Tracks an unordered list of filesystem paths that something depends on @@ -245,6 +250,375 @@ namespace Slang HashSet m_filePathSet; }; + class EntryPoint; + + class ComponentType; + class ComponentTypeVisitor; + + /// Base class for "component types" that represent the pieces a final + /// shader program gets linked together from. 
+ /// + class ComponentType : public RefObject, public slang::IComponentType + { + public: + // + // ISlangUnknown interface + // + + SLANG_REF_OBJECT_IUNKNOWN_ALL; + ISlangUnknown* getInterface(Guid const& guid); + + // + // slang::IComponentType interface + // + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE; + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW IComponentType* SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + + /// Get the linkage (aka "session" in the public API) for this component type. + Linkage* getLinkage() { return m_linkage; } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Get the number of entry points linked into this component type. + virtual Index getEntryPointCount() = 0; + + /// Get one of the entry points linked into this component type. + virtual RefPtr getEntryPoint(Index index) = 0; + + /// Get the number of global shader parameters linked into this component type. + virtual Index getShaderParamCount() = 0; + + /// Get one of the global shader parameters linked into this component type. + virtual GlobalShaderParamInfo getShaderParam(Index index) = 0; + + /// Get the number of (unspecialized) specialization parameters for the component type. + virtual Index getSpecializationParamCount() = 0; + + /// Get the specialization parameter at `index`.
+ virtual SpecializationParam const& getSpecializationParam(Index index) = 0; + + /// Get the number of "requirements" that this component type has. + /// + /// A requirement represents another component type that this component + /// needs in order to function correctly. For example, the dependency + /// of one module on another module that it `import`s is represented + /// as a requirement, as is the dependency of an entry point on the + /// module that defines it. + /// + virtual Index getRequirementCount() = 0; + + /// Get the requirement at `index`. + virtual RefPtr getRequirement(Index index) = 0; + + /// Parse a type from a string, in the context of this component type. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + /// TODO: This function shouldn't be on the base class, since + /// it only really makes sense on `Module` and (as a compatibility + /// feature) on `LegacyProgram`. + /// + Type* getTypeFromString( + String const& typeStr, + DiagnosticSink* sink); + + /// Get a list of modules that this component type depends on. + /// + virtual List const& getModuleDependencies() = 0; + + /// Get the full list of filesystem paths this component type depends on. + /// + virtual List const& getFilePathDependencies() = 0; + + /// Callback for use with `enumerateIRModules` + typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type. + void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type. 
+ template + void enumerateIRModules(F const& callback) + { + struct Helper + { + static void helper(IRModule* irModule, void* userData) + { + (*(F*)userData)(irModule); + } + }; + enumerateIRModules(&Helper::helper, (void*)&callback); + } + + /// Callback for use with `enumerateModules` + typedef void (*EnumerateModulesCallback)(Module* module, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component type. + void enumerateModules(EnumerateModulesCallback callback, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component type. + template + void enumerateModules(F const& callback) + { + struct Helper + { + static void helper(Module* module, void* userData) + { + (*(F*)userData)(module); + } + }; + enumerateModules(&Helper::helper, (void*)&callback); + } + + /// Side-band information generated when specializing this component type. + /// + /// Different subclasses of `ComponentType` are expected to create their + /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`. + /// Later, whenever we want to use a specialized component type we will + /// also have the `SpecializationInfo` available and will expect it to + /// have the correct (subclass-specific) type. + /// + class SpecializationInfo : public RefObject + { + }; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`.
+ /// + virtual RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) = 0; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + RefPtr _validateSpecializationArgs( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) + { + if(argCount == 0) return nullptr; + return _validateSpecializationArgsImpl(args, argCount, sink); + } + + /// Specialize this component type given `specializationArgs` + /// + /// Any diagnostics will be reported to `sink`, which can be used + /// to determine if the operation was successful. It is allowed + /// for this operation to have a non-null return even when an + /// error is encountered. + /// + RefPtr specialize( + SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink); + + /// Invoke `visitor` on this component type, using the appropriate dynamic type. + /// + /// This function implements the "visitor pattern" for `ComponentType`. + /// + /// If the `specializationInfo` argument is non-null, it must be specialization + /// information generated for this specific component type by `_validateSpecializationArgs`. + /// In that case, appropriately-typed specialization information will be passed + /// when invoking the `visitor`. + /// + virtual void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) = 0; + + protected: + ComponentType(Linkage* linkage); + + private: + RefPtr m_linkage; + + // Cache of target-specific programs for each target.
+ Dictionary> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + // + // TODO: Remove this. Type lookup should only be supported on `Module`s. + // + Dictionary> m_types; + }; + + /// A component type built up from other component types. + class CompositeComponentType : public ComponentType + { + public: + static RefPtr create( + Linkage* linkage, + List> const& childComponents); + + List> const& getChildComponents() { return m_childComponents; }; + Index getChildComponentCount() { return m_childComponents.getCount(); } + RefPtr getChildComponent(Index index) { return m_childComponents[index]; } + + Index getEntryPointCount() SLANG_OVERRIDE; + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE; + + Index getShaderParamCount() SLANG_OVERRIDE; + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; + + Index getSpecializationParamCount() SLANG_OVERRIDE; + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; + + List const& getModuleDependencies() SLANG_OVERRIDE; + List const& getFilePathDependencies() SLANG_OVERRIDE; + + class CompositeSpecializationInfo : public SpecializationInfo + { + public: + List> childInfos; + }; + + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + + private: + CompositeComponentType( + Linkage* linkage, + List> const& childComponents); + + List> m_childComponents; + + // The following arrays hold the concatenated entry points, parameters, + // etc. from the child components. 
This approach allows for reasonably + // fast (constant time) access through operations like `getShaderParam`, + // but means that the memory usage of a composite is proportional to + // the sum of the memory usage of the children, rather than being fixed + // by the number of children (as it would be if we just stored + // `m_childComponents`). + // + // TODO: We could conceivably build some O(numChildren) arrays that + // support binary-search to provide logarithmic-time access to entry + // points, parameters, etc. while giving a better overall memory usage. + // + List m_entryPoints; + List m_shaderParams; + List m_specializationParams; + List m_requirements; + + ModuleDependencyList m_moduleDependencyList; + FilePathDependencyList m_filePathDependencyList; + }; + + /// A component type created by specializing another component type. + class SpecializedComponentType : public ComponentType + { + public: + SpecializedComponentType( + ComponentType* base, + SpecializationInfo* specializationInfo, + List const& specializationArgs, + DiagnosticSink* sink); + + /// Get the base (unspecialized) component type that is being specialized. + RefPtr getBaseComponentType() { return m_base; } + + RefPtr getSpecializationInfo() { return m_specializationInfo; } + + /// Get the number of arguments supplied for existential type parameters. + /// + /// Note that the number of arguments may not match the number of parameters. + /// In particular, an unspecialized entry point may have many parameters, but zero arguments. + Index getSpecializationArgCount() { return m_specializationArgs.getCount(); } + + /// Get the existential type argument (type and witness table) at `index`. + SpecializationArg const& getSpecializationArg(Index index) { return m_specializationArgs[index]; } + + /// Get an array of all existential type arguments. 
+ SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); } + + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); } + + Index getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); static SpecializationParam dummy; return dummy; } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; + + /// TODO: These should include requirements/dependencies for the types + /// referenced in the specialization arguments... + List const& getModuleDependencies() SLANG_OVERRIDE { return m_base->getModuleDependencies(); } + List const& getFilePathDependencies() SLANG_OVERRIDE { return m_base->getFilePathDependencies(); } + + /// Get a list of tagged-union types referenced by the specialization parameters. + List> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + + RefPtr getIRModule() { return m_irModule; } + + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + protected: + + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE + { + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; + } + + private: + RefPtr m_base; + RefPtr m_specializationInfo; + SpecializationArgs m_specializationArgs; + RefPtr m_irModule; + + // Any tagged union types that were referenced by the specialization arguments. 
+ List> m_taggedUnionTypes; + + }; + /// Describes an entry point for the purposes of layout and code generation. /// /// This class also tracks any generic arguments to the entry point, @@ -255,11 +629,12 @@ namespace Slang /// `getName()` and `getProfile()` methods should be expected to /// return useful data on pass-through entry points. /// - class EntryPoint : public RefObject + class EntryPoint : public ComponentType { public: /// Create an entry point that refers to the given function. static RefPtr create( + Linkage* linkage, DeclRef funcDeclRef, Profile profile); @@ -286,55 +661,66 @@ namespace Slang /// Get the module that contains the entry point. Module* getModule(); - /// Get the linkage that contains the module for this entry point. - Linkage* getLinkage(); - /// Get a list of modules that this entry point depends on. /// /// This will include the module that defines the entry point (see `getModule()`), /// but may also include modules that are required by its generic type arguments. /// - List> getModuleDependencies() { return m_dependencyList.getModuleList(); } - - /// Get a list of tagged-union types referenced by the entry point's generic parameters. - List> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + List const& getModuleDependencies() SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } + List const& getFilePathDependencies() SLANG_OVERRIDE; // { return getModule()->getFilePathDependencies(); } /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. static RefPtr createDummyForPassThrough( + Linkage* linkage, Name* name, Profile profile); /// Get the number of existential type parameters for the entry point. - Index getExistentialTypeParamCount() { return m_existentialSlots.paramTypes.getCount(); } + Index getSpecializationParamCount() SLANG_OVERRIDE; /// Get the existential type parameter at `index`. 
- Type* getExistentialTypeParam(Index index) { return m_existentialSlots.paramTypes[index]; } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized entry point may have many parameters, but zero arguments. - Index getExistentialTypeArgCount() { return m_existentialSlots.args.getCount(); } + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; - /// Get the existential type argument (type and witness table) at `index`. - ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_existentialSlots.args[index]; } + SpecializationParams const& getExistentialSpecializationParams() { return m_existentialSpecializationParams; } - /// Get an array of all existential type arguments. - ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_existentialSlots.args.getBuffer(); } + Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); } + Index getExistentialSpecializationParamCount() { return m_existentialSpecializationParams.getCount(); } /// Get an array of all entry-point shader parameters. 
List const& getShaderParams() { return m_shaderParams; } - void _specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink); + Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; } + + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return GlobalShaderParamInfo(); } + + class EntryPointSpecializationInfo : public SpecializationInfo + { + public: + DeclRef specializedFuncDeclRef; + List existentialSpecializationArgs; + }; + + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; private: EntryPoint( + Linkage* linkage, Name* name, Profile profile, DeclRef funcDeclRef); + void _collectGenericSpecializationParamsRec(Decl* decl); void _collectShaderParams(); // The name of the entry point function (e.g., `main`) @@ -345,8 +731,8 @@ namespace Slang // DeclRef m_funcDeclRef; - /// The existential/interface slots associated with the entry point parameter scope. - ExistentialTypeSlots m_existentialSlots; + SpecializationParams m_genericSpecializationParams; + SpecializationParams m_existentialSpecializationParams; /// Information about entry-point parameters List m_shaderParams; @@ -362,61 +748,6 @@ namespace Slang // intrinsic to the entry point. // Profile m_profile; - - // Any tagged union types that were referenced by the generic arguments of the entry point. - List> m_taggedUnionTypes; - - // Modules the entry point depends on. 
- ModuleDependencyList m_dependencyList; - }; - - class EntryPointGroup : public RefObject - { - public: - static RefPtr create( - Linkage* linkage, - List> const& entryPoints, - DiagnosticSink* sink); - - Linkage* getLinkageImpl() { return m_linkage; } - - /// Get the number of entry points in the group - Index getEntryPointCount() { return m_entryPoints.getCount(); } - - /// Get the entry point at the given `index`. - RefPtr getEntryPoint(Index index) { return m_entryPoints[index]; } - - /// Get the full ist of entry points in the group. - List> const& getEntryPoints() { return m_entryPoints; } - - /// Get a list of modules that this entry point group depends on. - /// - /// This will include the dependencies of all of the entry points in the group. - /// - List> getModuleDependencies() { return m_dependencyList.getModuleList(); } - - /// Get an array of all entry-point-group shader parameters. - List const& getShaderParams() { return m_shaderParams; } - - private: - EntryPointGroup(Linkage* linkage) - : m_linkage(linkage) - {} - - void _collectShaderParams(DiagnosticSink* sink); - - Linkage* m_linkage; - List> m_entryPoints; - - /// Information about shader parameters to be associated with the entry-point group itself. - /// - /// This list captures parameters that logically belong to the group itself, rather than - /// to any specific entry point in the group. - /// - List m_shaderParams; - - /// Modules the entry point group depends on. - ModuleDependencyList m_dependencyList; }; enum class PassThroughMode : SlangPassThrough @@ -439,19 +770,52 @@ namespace Slang /// may span multiple Slang source files), and provides access /// to both the AST and IR representations of that code. 
/// - class Module : public RefObject, public slang::IModule + class Module : public ComponentType, public slang::IModule { + typedef ComponentType Super; + public: SLANG_REF_OBJECT_IUNKNOWN_ALL ISlangUnknown* getInterface(const Guid& guid); + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW IComponentType* SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize(specializationArgs, specializationArgCount, outDiagnostics); + } + + // + /// Create a module (initially empty). Module(Linkage* linkage); - /// Get the parent linkage of this module. 
- Linkage* getLinkage() { return m_linkage; } - /// Get the AST for the module (if it has been parsed) ModuleDecl* getModuleDecl() { return m_moduleDecl; } @@ -459,7 +823,7 @@ namespace Slang IRModule* getIRModule() { return m_irModule; } /// Get the list of other modules this module depends on - List> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } + List const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } /// Get the list of filesystem paths this module depends on List const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } @@ -474,7 +838,7 @@ namespace Slang /// /// This should only be called once, during creation of the module. /// - void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; } + void setModuleDecl(ModuleDecl* moduleDecl);// { m_moduleDecl = moduleDecl; } /// Set the IR for this module. /// @@ -482,16 +846,65 @@ namespace Slang /// void setIRModule(IRModule* irModule) { m_irModule = irModule; } - private: - // The parent linkage - Linkage* m_linkage = nullptr; + Index getEntryPointCount() SLANG_OVERRIDE { return 0; } + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } + + Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; + + List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); } + List const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); } + + /// 
Collect information on the shader parameters of the module. + /// + /// This method should only be called once, after the core + /// structure of the module (its AST and IR) has been created, + /// and before any of the `ComponentType` APIs are used. + /// + /// TODO: We might eventually consider a non-stateful approach + /// to constructing a `Module`. + /// + void _collectShaderParams(); + class ModuleSpecializationInfo : public SpecializationInfo + { + public: + struct GenericArgInfo + { + RefPtr paramDecl; + RefPtr argVal; + }; + + List genericArgs; + List existentialArgs; + }; + + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + + private: // The AST for the module RefPtr m_moduleDecl; // The IR for the module RefPtr m_irModule = nullptr; + List m_shaderParams; + SpecializationParams m_specializationParams; + + List m_requirements; + // List of modules this module depends on ModuleDependencyList m_moduleDependencyList; @@ -650,15 +1063,10 @@ namespace Slang SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModule( const char* moduleName, slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createProgram( - slang::ProgramDesc const& desc, - slang::IProgram** outProgram) override; - SLANG_NO_THROW SlangResult SLANG_MCALL specializeProgram( - slang::IProgram* program, - SlangInt specializationArgCount, - slang::SpecializationArg const* specializationArgs, - slang::IProgram** outSpecializedProgram, - ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::IComponentType* SLANG_MCALL createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics = nullptr) override; SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType(
slang::TypeReflection* type, slang::SpecializationArg const* specializationArgs, @@ -980,202 +1388,133 @@ namespace Slang int translationUnitIndex, String const& path); - Program* getProgram() { return m_program; } + /// Get a component type that represents the global scope of the compile request. + ComponentType* getGlobalComponentType() { return m_globalComponentType; } + + /// Get a component type that represents the global scope of the compile request, plus the requested entry points. + ComponentType* getGlobalAndEntryPointsComponentType() { return m_globalAndEntryPointsComponentType; } private: - RefPtr m_program; + /// A component type that includes only the global scopes of the translation unit(s) that were compiled. + RefPtr m_globalComponentType; + + /// A component type that extends the global scopes with all of the entry points that were specified. + RefPtr m_globalAndEntryPointsComponentType; }; - /// A collection of code modules and entry points that are intended to be used together. + /// A "legacy" program composes multiple translation units from a single compile request, + /// and takes care to treat global declarations of the same name from different translation + /// units as representing the "same" parameter. /// - /// A `Program` establishes that certain pieces of code are intended - /// to be used togehter so that, e.g., layout can make sure to allocate - /// space for the global shader parameters in all referenced modules. + /// TODO: This type only exists to support a single requirement: that multiple translation + /// units can be compiled in one pass and be guaranteed that the "same" parameter declared + /// in different translation units (hence different modules) will get the same layout. + /// This feature should be deprecated and removed as soon as possible, since the complexity + /// it creates in the codebase is not justified by its limited utility. 
/// - class Program : public RefObject, public slang::IProgram + class LegacyProgram : public ComponentType { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL; - ISlangUnknown* getInterface(Guid const& guid); + LegacyProgram( + Linkage* linkage, + List> const& translationUnits, + DiagnosticSink* sink); - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; + Index getTranslationUnitCount() { return m_translationUnits.getCount(); } + RefPtr getTranslationUnit(Index index) { return m_translationUnits[index]; } - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( - SlangInt targetIndex, - slang::IBlob** outDiagnostics) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) override; - - - /// Create a new program, initially empty. - /// - /// All code loaded into the program must come - /// from the given `linkage`. - Program( - Linkage* linkage); - - /// Get the linkage that this program uses. - Linkage* getLinkageImpl() { return m_linkage; } - - /// Get the number of entry points added to the program - Index getEntryPointCount() { return m_entryPoints.getCount(); } - - /// Get the entry point at the given `index`. - RefPtr getEntryPoint(Index index) { return m_entryPoints[index]; } - - /// Get the full ist of entry points on the program. - List> const& getEntryPoints() { return m_entryPoints; } - - - Index getEntryPointGroupCount() { return m_entryPointGroups.getCount(); } - RefPtr getEntryPointGroup(Index index) { return m_entryPointGroups[index]; } - List> const& getEntryPointGroups() { return m_entryPointGroups; } - - - /// Get the substitution (if any) that represents how global generics are specialized. 
- RefPtr getGlobalGenericSubstitution() { return m_globalGenericSubst; } - - /// Get the full list of modules this program depends on - List> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); } - - /// Get the full list of filesystem paths this program depends on - List getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); } - - /// Get the target-specific version of this program for the given `target`. - /// - /// The `target` must be a target on the `Linkage` that was used to create this program. - TargetProgram* getTargetProgram(TargetRequest* target); - - /// Add a module (and everything it depends on) to the list of references - void addReferencedModule(Module* module); - - /// Add a module (but not the things it depends on) to the list of references - /// - /// This is a compatiblity hack for legacy compiler behavior. - void addReferencedLeafModule(Module* module); + Index getEntryPointCount() SLANG_OVERRIDE { return 0; } + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } + Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } + GlobalShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } - /// Add an entry point to the program - /// - /// This also adds everything the entry point depends on to the list of references. - /// - void addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink); + Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; } - /// Add an entry point group to the program - /// - /// This also adds everything the entry point group depends on to the list of references. 
- /// - void addEntryPointGroup(EntryPointGroup* entryPointGroup); + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; - /// Set the global generic argument substitution to use. - void setGlobalGenericSubsitution(RefPtr subst) - { - m_globalGenericSubst = subst; - } + List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies.getModuleList(); } + List const& getFilePathDependencies() SLANG_OVERRIDE { return m_fileDependencies.getFilePathList(); } - /// Parse a type from a string, in the context of this program. - /// - /// Any names in the string will be resolved using the modules - /// referenced by the program. - /// - /// On an error, returns null and reports diagnostic messages - /// to the provided `sink`. - /// - Type* getTypeFromString(String typeStr, DiagnosticSink* sink); - - /// Get the IR module that represents this program and its entry points. - /// - /// The IR module for a program tries to be minimal, and in the - /// common case will only include symbols with `[import]` declarations - /// for the entry point(s) of the program, and any types they - /// depend on. - /// - /// This IR module is intended to be linked against the IR modules - /// for all of the dependencies (see `getModuleDependencies()`) to - /// provide complete code. - /// - RefPtr getOrCreateIRModule(DiagnosticSink* sink); - - /// Get the number of existential type parameters for the program. - Index getExistentialTypeParamCount() { return m_globalExistentialSlots.paramTypes.getCount(); } - - /// Get the existential type parameter at `index`. - Type* getExistentialTypeParam(Index index) { return m_globalExistentialSlots.paramTypes[index]; } - - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized program may have many parameters, but zero arguments. 
- Index getExistentialTypeArgCount() { return m_globalExistentialSlots.args.getCount(); } - - /// Get the existential type argument (type and witness table) at `index`. - ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_globalExistentialSlots.args[index]; } - - /// Get an array of all existential type arguments. - ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_globalExistentialSlots.args.getBuffer(); } - - /// Get an array of all global shader parameters. - List const& getShaderParams() { return m_shaderParams; } + protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; - void _collectShaderParams(DiagnosticSink* sink); - void _specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink); + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; private: + void _collectShaderParams(DiagnosticSink* sink); - // The linakge this program is associated with. - // - // Note that a `Program` keeps its associated linkage alive, - // and not vice versa. - // - RefPtr m_linkage; - - // Tracking data for the list of modules dependend on - ModuleDependencyList m_moduleDependencyList; - - // Tracking data for the list of filesystem paths dependend on - FilePathDependencyList m_filePathDependencyList; - - // Entry points that are part of the program. - List > m_entryPoints; - - // Entry points that are part of the program. - List > m_entryPointGroups; - - // Specializations for global generic parameters (if any) - RefPtr m_globalGenericSubst; - - // The existential/interface slots associated with the global scope. 
- ExistentialTypeSlots m_globalExistentialSlots; + List> m_translationUnits; - /// Information about global shader parameters + List m_entryPoints; List m_shaderParams; + List m_requirements; + SpecializationParams m_specializationParams; + ModuleDependencyList m_moduleDependencies; + FilePathDependencyList m_fileDependencies; + }; - // Generated IR for this program. - RefPtr m_irModule; - - // Cache of target-specific programs for each target. - Dictionary> m_targetPrograms; + /// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses. + class ComponentTypeVisitor + { + public: + // The following methods should be overridden in a concrete subclass + // to customize how it acts on each of the concrete types of component. + // + // In cases where the application wants to simply "recurse" on a + // composite, specialized, or legacy component type it can use + // the `visitChildren` methods below. + // + virtual void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; + virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0; + virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; + virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; + virtual void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; - // Any types looked up dynamically using `getTypeFromString` - Dictionary> m_types; + protected: + // These helpers can be used to recurse into the logical children of a + // component type, and are useful for the common case where a visitor + // only cares about a few leaf cases. + // + // Note that for a `LegacyProgram` the "children" in this case are the + // `Module`s of the translation units that make up the legacy program.
+ // In some cases this is what is desired, but in others it is incorrect + // to treat a legacy program as a composition of modules, and instead + // it should be treated directly as a leaf case. Clients should make + // an informed decision based on an understanding of what `LegacyProgram` is used for. + // + void visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo); + void visitChildren(SpecializedComponentType* specialized); + void visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo); }; - /// A `Program` specialized for a particular `TargetRequest` + /// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest` + /// + /// TODO: This should probably be renamed to `TargetComponentType`. + /// + /// By binding a component type to a specific target, a `TargetProgram` allows + /// for things like layout to be computed, that fundamentally depend on + /// the choice of target. + /// + /// A `TargetProgram` handles requests for compiled kernel code for + /// entry point functions. In practice, kernel code can only be + /// correctly generated when the underlying `ComponentType` is "fully linked" + /// (has no remaining unsatisfied requirements).
+ /// class TargetProgram : public RefObject { public: TargetProgram( - Program* program, + ComponentType* componentType, TargetRequest* targetReq); /// Get the underlying program - Program* getProgram() { return m_program; } + ComponentType* getProgram() { return m_program; } /// Get the underlying target TargetRequest* getTargetReq() { return m_targetReq; } @@ -1232,7 +1571,7 @@ namespace Slang private: // The program being compiled or laid out - Program* m_program; + ComponentType* m_program; // The target that code/layout will be generated for TargetRequest* m_targetReq; @@ -1253,7 +1592,7 @@ namespace Slang BackEndCompileRequest( Linkage* linkage, DiagnosticSink* sink, - Program* program = nullptr); + ComponentType* program = nullptr); // Should we dump intermediate results along the way, for debugging? bool shouldDumpIntermediates = false; @@ -1263,8 +1602,8 @@ namespace Slang LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; } - Program* getProgram() { return m_program; } - void setProgram(Program* program) { m_program = program; } + ComponentType* getProgram() { return m_program; } + void setProgram(ComponentType* program) { m_program = program; } // Should R/W images without explicit formats be assumed to have "unknown" format? // @@ -1273,7 +1612,7 @@ namespace Slang bool useUnknownImageFormatAsDefault = false; private: - RefPtr m_program; + RefPtr m_program; }; /// A compile request that spans the front and back ends of the compiler @@ -1311,11 +1650,8 @@ namespace Slang // Should we just pass the input to another compiler? PassThroughMode passThrough = PassThroughMode::None; - /// Source code for the generic arguments to use for the global generic parameters of the program. - List globalGenericArgStrings; - - /// Types to use to fill global existential "slots" - List globalExistentialSlotArgStrings; + /// Source code for the specialization arguments to use for the global specialization parameters of the program. 
+ List globalSpecializationArgStrings; bool shouldSkipCodegen = false; @@ -1331,11 +1667,8 @@ namespace Slang class EntryPointInfo : public RefObject { public: - /// Source code for the generic arguments to use for the generic parameters of the entry point. - List genericArgStrings; - - /// Source code for the type arguments to plug into the existential type "slots" of the entry point - List existentialArgStrings; + /// Source code for the specialization arguments to use for the specialization parameters of the entry point. + List specializationArgStrings; }; List entryPoints; @@ -1370,8 +1703,15 @@ namespace Slang FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } BackEndCompileRequest* getBackEndReq() { return m_backEndReq; } - Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); } - Program* getSpecializedProgram() { return m_specializedProgram; } + + ComponentType* getUnspecializedGlobalComponentType() { return getFrontEndReq()->getGlobalComponentType(); } + ComponentType* getUnspecializedGlobalAndEntryPointsComponentType() + { + return getFrontEndReq()->getGlobalAndEntryPointsComponentType(); + } + + ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; } + ComponentType* getSpecializedGlobalAndEntryPointsComponentType() { return m_specializedGlobalAndEntryPointsComponentType; } private: void init(); @@ -1380,8 +1720,8 @@ namespace Slang RefPtr m_linkage; DiagnosticSink m_sink; RefPtr m_frontEndReq; - RefPtr m_unspecializedProgram; - RefPtr m_specializedProgram; + RefPtr m_specializedGlobalComponentType; + RefPtr m_specializedGlobalAndEntryPointsComponentType; RefPtr m_backEndReq; // For output @@ -1623,14 +1963,14 @@ inline slang::IModule* asExternal(Module* module) return static_cast(module); } -inline Program* asInternal(slang::IProgram* module) +inline ComponentType* asInternal(slang::IComponentType* componentType) { - return static_cast(module); + return 
static_cast(componentType); } -inline slang::IProgram* asExternal(Program* module) +inline slang::IComponentType* asExternal(ComponentType* componentType) { - return static_cast(module); + return static_cast(componentType); } static inline slang::ProgramLayout* asExternal(ProgramLayout* programLayout) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e8114dc95..802c89f8b 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -240,6 +240,8 @@ DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an imm DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") +DIAGNOSTIC(30060, Error, expectedAType, "expected a type got a '$0'") + DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'") DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body") @@ -361,21 +363,22 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") -DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.") +DIAGNOSTIC(38021, Error, typeArgumentForGenericParameterDoesNotConformToInterface, "type argument `$0` for generic parameter `$1` does not conform to interface `$2`.") DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be 
specialized using another global type parameter ('$1')") + DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0"); DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") -DIAGNOSTIC(38025, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)") +DIAGNOSTIC(38025, Error, mismatchSpecializationArguments, "expected $0 specialization arguments ($1 provided)") DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") DIAGNOSTIC(38028, Error, existentialSlotArgNotAType, "existential slot argument $0 was not a type") -DIAGNOSTIC(38029, Error, existentialSlotArgDoesNotConform, "existential slot argument $0 does not conform to the required interface '$1'") +DIAGNOSTIC(38029, Error,typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index c7b5b602b..26af7b9f4 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2459,13 +2459,13 @@ String CLikeSourceEmitter::getFuncName(IRFunc* func) // name for an entry-point function, but other // targets should try to use the original name. // - // TODO: always use `main`, and have any code - // that wraps this know to use `main` instead - // of the original entry-point name... 
+ // TODO: always use the original name, and + // use the appropriate options for glslang to + // make it support a non-`main` name. // if (getSourceStyle() != SourceStyle::GLSL) { - return getText(entryPointLayout->entryPoint->getName()); + return getText(entryPointLayout->getFuncDecl()->getName()); } // diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 3b9023703..d3c852fd6 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -687,20 +687,20 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL break; case Stage::Geometry: { - if (auto attrib = entryPointLayout->entryPoint->FindModifier()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) { m_writer->emit("layout(max_vertices = "); m_writer->emit(attrib->value); m_writer->emit(") out;\n"); } - if (auto attrib = entryPointLayout->entryPoint->FindModifier()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) { m_writer->emit("layout(invocations = "); m_writer->emit(attrib->value); m_writer->emit(") in;\n"); } - for (auto pp : entryPointLayout->entryPoint->GetParameters()) + for (auto pp : entryPointLayout->getFuncDecl()->GetParameters()) { if (auto inputPrimitiveTypeModifier = pp->FindModifier()) { diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 9c0f7f02b..4abd692f8 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -294,13 +294,13 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint break; case Stage::Geometry: { - if (auto attrib = entryPointLayout->entryPoint->FindModifier()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) { m_writer->emit("[maxvertexcount("); m_writer->emit(attrib->value); m_writer->emit(")]\n"); } - if (auto attrib = entryPointLayout->entryPoint->FindModifier()) + if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) { 
m_writer->emit("[instance("); m_writer->emit(attrib->value); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c51040b98..d5dcaca98 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -49,36 +49,34 @@ enum class BuiltInCOp EntryPointLayout* findEntryPointLayout( ProgramLayout* programLayout, - EntryPoint* entryPoint, - EntryPointGroupLayout** outEntryPointGroupLayout = nullptr) + EntryPoint* entryPoint) { - for( auto entryPointGroupLayout : programLayout->entryPointGroups ) + // TODO: This function shouldn't need to exist, and it + // somewhat hampers the capabilities of the compiler (e.g., + // it isn't supported to have a single program contain + // two different "instances" of the same entry point). + // + // Code that cares about layouts should be looking up + // the entry point layout by index on a `ProgramLayout`, + // knowing that those indices will align with the order + // of entry points on the `ComponentType` for the program. + + for( auto entryPointLayout : programLayout->entryPoints ) { - for( auto entryPointLayout : entryPointGroupLayout->entryPoints ) - { - if(entryPointLayout->entryPoint->getName() != entryPoint->getName()) - continue; - - // TODO: We need to be careful about this check, since it relies on - // the profile information in the layout matching that in the request. - // - // What we really seem to want here is some dictionary mapping the - // `EntryPoint` directly to the `EntryPointLayout`, and maybe - // that is precisely what we should build... - // - if(entryPointLayout->profile != entryPoint->getProfile()) - continue; - - // TODO: can't easily filter on translation unit here... - // Ideally the `EntryPoint` should get filled in with a pointer - // the specific function declaration that represents the entry point. 
- - if( outEntryPointGroupLayout ) - { - *outEntryPointGroupLayout = entryPointGroupLayout; - } - return entryPointLayout; - } + if(entryPointLayout->entryPoint.GetName() != entryPoint->getName()) + continue; + + // TODO: We need to be careful about this check, since it relies on + // the profile information in the layout matching that in the request. + // + // What we really seem to want here is some dictionary mapping the + // `EntryPoint` directly to the `EntryPointLayout`, and maybe + // that is precisely what we should build... + // + if(entryPointLayout->profile != entryPoint->getProfile()) + continue; + + return entryPointLayout; } return nullptr; diff --git a/source/slang/slang-ir-bind-existentials.cpp b/source/slang/slang-ir-bind-existentials.cpp index e426e6e92..a99da3410 100644 --- a/source/slang/slang-ir-bind-existentials.cpp +++ b/source/slang/slang-ir-bind-existentials.cpp @@ -69,24 +69,8 @@ struct BindExistentialSlots void processGlobalExistentialSlots() { - // If there are any global existential slots, we will expect - // to find a `bindGlobalExistentialSlots` instruction at module scope. - // - // We will start out by finding that instruction, if it exists. - // - IRInst* bindGlobalExistentialSlotsInst = nullptr; - for( auto inst : module->getGlobalInsts() ) - { - if( inst->op == kIROp_BindGlobalExistentialSlots ) - { - bindGlobalExistentialSlotsInst = inst; - break; - } - } - - // Now we will start looking for global shader parameters that make - // use of existential slots (we can determine this from their - // layout). + // We will search for global shader parameters that make + // use of existential specialization parameters. // for( auto inst : module->getGlobalInsts() ) { @@ -96,23 +80,27 @@ struct BindExistentialSlots if(!globalParam) continue; - // We will delegate to a subroutine for the meat - // of the work, since much of it can be shared - // with the case for entry-point existential - // parameters. 
+ // We only care about global shader parameters + // that have existential specialization parameters, + // and we expect all such parameters to have a + // `[bindExistentialSlots(...)]` decoration that + // was added during IR linking. // - processParameter(globalParam, bindGlobalExistentialSlotsInst); - } + auto bindSlotsInst = globalParam->findDecorationImpl(kIROp_BindExistentialSlotsDecoration); + if(!bindSlotsInst) + continue; - // Once we are done looping over global shader parameters, - // all of the relevant information from the - // `bindGlobalExistentialSlots` instruction will have - // been moved to the parameters themselves, so we - // can eliminate the binding instruction. - // - if( bindGlobalExistentialSlotsInst ) - { - bindGlobalExistentialSlotsInst->removeAndDeallocate(); + replaceTypeUsingExistentialSlots( + globalParam, + bindSlotsInst->getOperandCount(), + bindSlotsInst->getOperands()); + + // Once we have propagated the information from + // the `[bindExistentialSlots(...)]` decoration + // down into the parameter's type, we no longer + // need the decoration. + // + bindSlotsInst->removeAndDeallocate(); } } @@ -150,9 +138,25 @@ struct BindExistentialSlots // We then need to process each of the entry-point // parameters just like we did for global parameters. // + // Because the existential slot arguments for *all* of the parameters + // are attached in a single `[bindExistentialSlots(...)]` decoration, + // we need to carve them up appropriately across the parameters. + // The way we do this is a bit of a kludge, in that we track a + // single `slotOffset` and increment it for each parameter by the + // number of arguments it consumed. + // + // Note: a better approach here might rely on the layout information + // for the parameters, which should directly encode an offset for + // the existential specialization parameters it uses. 
The challenge + // with this is that we'd need to correctly interpret the offset + // relative to any global-scope specialization parameters or + // generic specialization parameters of the entry point. + // Ultimately the simplistic counter approach is less complicated. + // + Index slotOffset = 0; for( auto param : func->getParams() ) { - processParameter(param, bindEntryPointExistentialSlotsInst); + processEntryPointParameter(param, bindEntryPointExistentialSlotsInst, slotOffset); } // TODO: We would need to consider what to do if @@ -181,9 +185,10 @@ struct BindExistentialSlots // function `param` and a `[bindExistentialSlots(...)]` // decoration; both use the same subroutine. // - void processParameter( + void processEntryPointParameter( IRInst* param, - IRInst* bindSlotsInst) + IRInst* bindSlotsInst, + Index& ioSlotOperandOffset) { // We expect all shader parameters to have layout information, // but to be defensive we will skip any that don't. @@ -206,7 +211,6 @@ struct BindExistentialSlots // find out the stating slot, and the information on // the type to find out the number of slots. // - UInt firstSlot = resInfo->index; UInt slotCount = 0; if(auto typeResInfo = varLayout->getTypeLayout()->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) slotCount = UInt(typeResInfo->count.getFiniteValue()); @@ -233,7 +237,8 @@ struct BindExistentialSlots // this parameter. // UInt bindOperandCount = bindSlotsInst->getOperandCount(); - if( 2*(firstSlot + slotCount) > bindOperandCount ) + UInt slotOperandCount = 2*slotCount; + if( (ioSlotOperandOffset + slotOperandCount) > bindOperandCount ) { sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter); return; @@ -244,7 +249,7 @@ struct BindExistentialSlots // keeping in mind that each slot accounts for two // operands. 
// - auto operandsForInst = bindSlotsInst->getOperands() + firstSlot; + auto operandsForInst = bindSlotsInst->getOperands() + ioSlotOperandOffset; // Once we've found the operands that are relevent to // the slots used by `param`, we will defer to a routine @@ -253,17 +258,17 @@ struct BindExistentialSlots // replaceTypeUsingExistentialSlots( param, - slotCount, + slotOperandCount, operandsForInst); + + ioSlotOperandOffset += slotOperandCount; } void replaceTypeUsingExistentialSlots( IRInst* inst, - UInt slotCount, + UInt slotOperandCount, IRUse const* slotArgs) { - SLANG_UNUSED(slotCount); - // We are going to alter the type of the // given `inst` based on information in // the `slotArgs`. @@ -282,7 +287,6 @@ struct BindExistentialSlots // a witness table, so the total number of operands // is twice the number of slots we are filling. // - UInt slotOperandCount = slotCount*2; List slotOperands; for(UInt ii = 0; ii < slotOperandCount; ++ii) slotOperands.add(slotArgs[ii].get()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 0c218f2aa..f9c157465 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -206,7 +206,6 @@ INST(Specialize, specialize, 2, 0) INST(lookup_interface_method, lookup_interface_method, 2, 0) INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) -INST(BindGlobalExistentialSlots, bindGlobalExistentialSlots, 0, 0) INST(Construct, construct, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 359c8e98d..cd59630d8 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1175,10 +1175,6 @@ struct IRBuilder IRInst* param, IRInst* val); - IRInst* emitBindGlobalExistentialSlots( - UInt argCount, - IRInst* const* args); - IRDecoration* addBindExistentialSlotsDecoration( IRInst* value, UInt argCount, diff --git a/source/slang/slang-ir-link.cpp 
b/source/slang/slang-ir-link.cpp index 7e2ac1f98..56b06a499 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -8,14 +8,13 @@ namespace Slang { -// Needed for lookup up entry-point layouts. -// -// TODO: maybe arrange so that codegen is driven from the layout layer -// instead of the input/request layer. + /// Find a suitable layout for `entryPoint` in `programLayout`. + /// + /// TODO: This function should be eliminated. See its body + /// for an explanation of the problems. EntryPointLayout* findEntryPointLayout( ProgramLayout* programLayout, - EntryPoint* entryPoint, - EntryPointGroupLayout** outEntryPointGroupLayout); + EntryPoint* entryPoint); struct IRSpecSymbol : RefObject { @@ -339,9 +338,10 @@ IRType* cloneType( } void cloneGlobalValueWithCodeCommon( - IRSpecContextBase* context, - IRGlobalValueWithCode* clonedValue, - IRGlobalValueWithCode* originalValue); + IRSpecContextBase* context, + IRGlobalValueWithCode* clonedValue, + IRGlobalValueWithCode* originalValue, + IROriginalValuesForClone const& originalValues); IRRate* cloneRate( IRSpecContextBase* context, @@ -382,11 +382,58 @@ IRGlobalVar* cloneGlobalVarImpl( cloneGlobalValueWithCodeCommon( context, clonedVar, - originalVar); + originalVar, + originalValues); return clonedVar; } + /// Clone certain special decorations for `clonedInst` from its (potentially multiple) definitions. + /// + /// In most cases, once we've decided on the "best" definition to use for an IR instruction, + /// we only want the linking process to use the decorations from the single best definition. + /// In some cases, though, the canonical best definition might not have all the information. + /// + /// A concrete example is the `[bindExistentialSlots(...)]` decorations for global shader + /// parameters and entry points. These decorations are only generated as part of the IR + /// associated with a specialization of a program, and not the original IR for the modules + /// of the program. 
+ /// + /// This function scans through all the `originalValues` that were considered for `clonedInst`, + /// and copies over any decorations that are allowed to come from a non-"best" definition. + /// For a given decoration opcode, only one such decoration will ever be copied, and nothing + /// will be copied if the instruction already has a matching decoration (that was cloned + /// from the "best" definition). + /// +static void cloneExtraDecorations( + IRSpecContextBase* context, + IRInst* clonedInst, + IROriginalValuesForClone const& originalValues) +{ + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedInst); + + for(auto sym = originalValues.sym; sym; sym = sym->nextWithSameName) + { + for(auto decoration : sym->irGlobalValue->getDecorations()) + { + switch(decoration->op) + { + default: + break; + + case kIROp_BindExistentialSlotsDecoration: + if(!clonedInst->findDecorationImpl(decoration->op)) + { + cloneInst(context, builder, decoration); + } + break; + } + } + } +} + void cloneSimpleGlobalValueImpl( IRSpecContextBase* context, IRInst* originalInst, @@ -407,6 +454,12 @@ void cloneSimpleGlobalValueImpl( { cloneInst(context, builder, child); } + + // Also clone certain decorations if they appear on *any* + // definition of the symbol (not necessarily the one + // we picked as the primary/best). 
+ // + cloneExtraDecorations(context, clonedInst, originalValues); } IRGlobalParam* cloneGlobalParamImpl( @@ -469,7 +522,8 @@ IRGeneric* cloneGenericImpl( cloneGlobalValueWithCodeCommon( context, clonedVal, - originalVal); + originalVal, + originalValues); return clonedVal; } @@ -545,7 +599,8 @@ IRInterfaceType* cloneInterfaceTypeImpl( void cloneGlobalValueWithCodeCommon( IRSpecContextBase* context, IRGlobalValueWithCode* clonedValue, - IRGlobalValueWithCode* originalValue) + IRGlobalValueWithCode* originalValue, + IROriginalValuesForClone const& originalValues) { // Next we are going to clone the actual code. IRBuilder builderStorage = *context->builder; @@ -553,6 +608,7 @@ void cloneGlobalValueWithCodeCommon( builder->setInsertInto(clonedValue); cloneDecorations(context, clonedValue, originalValue); + cloneExtraDecorations(context, clonedValue, originalValues); // We will walk through the blocks of the function, and clone each of them. // @@ -629,10 +685,11 @@ void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const } void cloneFunctionCommon( - IRSpecContextBase* context, - IRFunc* clonedFunc, - IRFunc* originalFunc, - bool checkDuplicate = true) + IRSpecContextBase* context, + IRFunc* clonedFunc, + IRFunc* originalFunc, + IROriginalValuesForClone const& originalValues, + bool checkDuplicate = true) { // First clone all the simple properties. clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); @@ -640,7 +697,8 @@ void cloneFunctionCommon( cloneGlobalValueWithCodeCommon( context, clonedFunc, - originalFunc); + originalFunc, + originalValues); // Shuffle the function to the end of the list, because // it needs to follow its dependencies. 
@@ -667,7 +725,6 @@ IRInst* specializeGeneric( IRFunc* specializeIRForEntryPoint( IRSpecContext* context, - EntryPoint* entryPoint, EntryPointLayout* entryPointLayout) { // We start by looking up the IR symbol that @@ -679,7 +736,7 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - auto mangledName = getMangledName(entryPoint->getFuncDeclRef()); + auto mangledName = getMangledName(entryPointLayout->getFuncDeclRef()); RefPtr sym; if (!context->getSymbols().TryGetValue(mangledName, sym)) { @@ -947,7 +1004,7 @@ IRFunc* cloneFuncImpl( { auto clonedFunc = builder->createFunc(); registerClonedValue(context, clonedFunc, originalValues); - cloneFunctionCommon(context, clonedFunc, originalFunc); + cloneFunctionCommon(context, clonedFunc, originalFunc, originalValues); return clonedFunc; } @@ -1227,7 +1284,9 @@ LinkedIR linkIR( CodeGenTarget target, TargetRequest* targetReq) { - auto sink = compileRequest->getSink(); + // TODO: We need to make sure that the program we are being asked + // to compile has been "resolved" so that it has no outstanding + // unsatisfied requirements. IRSpecializationState stateStorage; auto state = &stateStorage; @@ -1252,12 +1311,11 @@ LinkedIR linkIR( // accelerate lookup, we will create a symbol table for looking // up IR definitions by their mangled name. 
// - auto originalProgramIRModule = program->getOrCreateIRModule(sink); - insertGlobalValueSymbols(sharedContext, originalProgramIRModule); - for (auto module : program->getModuleDependencies()) + program->enumerateIRModules([&](IRModule* irModule) { - insertGlobalValueSymbols(sharedContext, module->getIRModule()); - } + insertGlobalValueSymbols(sharedContext, irModule); + }); + auto context = state->getContext(); context->shared = sharedContext; @@ -1282,32 +1340,9 @@ LinkedIR linkIR( context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } - EntryPointGroupLayout* entryPointGroupLayout = nullptr; - auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint, &entryPointGroupLayout); - - auto offsetEntryPointLayout = entryPointLayout->getAbsoluteLayout(entryPointGroupLayout); + auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint); - // Note: when we are doing the compatibility approach for Falcor, we - // can have global-scope symbols that are actually part of the - // local root signature (entry point group), so we need to make - // sure to apply those layouts appropriately. - auto entryPointGroupStructLayout = getScopeStructLayout(entryPointGroupLayout); - for(auto entry : entryPointGroupStructLayout->mapVarToLayout) - { - if(!entry.Key) - continue; - - auto mangledName = getMangledName(entry.Key); - auto groupVarLayout = entry.Value; - - // We need to "adjust" the layout that was computed for the parameter - // because it will be relative to the start of the entry-point group, - // rather than absolute. 
- // - auto absoluteVarLayout = groupVarLayout->getAbsoluteLayout(entryPointGroupLayout->parametersLayout); - - context->globalVarLayouts.AddIfNotExists(mangledName, absoluteVarLayout); - } + auto offsetEntryPointLayout = entryPointLayout; context->builder->setInsertInto(context->getModule()->getModuleInst()); @@ -1316,7 +1351,7 @@ LinkedIR linkIR( // TODO: This step should *not* be needed with the current IR // specialization approach, so we should consider removing it. // - for (auto sym :context->getSymbols()) + for (auto sym : context->getSymbols()) { if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); @@ -1327,28 +1362,31 @@ LinkedIR linkIR( // the entry point function itself, and rely on // this step to recursively copy over anything else // it might reference. - auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, offsetEntryPointLayout); + auto irEntryPoint = specializeIRForEntryPoint(context, offsetEntryPointLayout); - // HACK: right now the bindings for global generic parameters are coming in - // as part of the original IR module, and we need to make sure these get - // copied over, even if they aren't referenced. + // Bindings for global generic parameters are currently represented + // as stand-alone global-scope instructions in the IR module for + // `SpecializedComponentType`s. These instructions are required for + // correct codegen, and so we must make sure to copy them all over, + // even though they are not directly referenced. // - for(auto inst : originalProgramIRModule->getGlobalInsts()) - { - auto bindInst = as(inst); - if(!bindInst) - continue; - - cloneValue(context, bindInst); - } - - for(auto inst : originalProgramIRModule->getGlobalInsts()) + // TODO: We should change these to decorations, akin to how + // `[bindExistentialSlots(...)]` works, so that they can be attached + // to the relevant parameters and cloned via `cloneExtraDecorations`. 
+ // In the long run we do not want to *ever* iterate over all the + // instructions in all the input modules. + // + program->enumerateIRModules([&](IRModule* irModule) { - if(inst->op != kIROp_BindGlobalExistentialSlots) - continue; + for(auto inst : irModule->getGlobalInsts()) + { + auto bindInst = as(inst); + if(!bindInst) + continue; - cloneValue(context, inst); - } + cloneValue(context, bindInst); + } + }); // HACK: we need to ensure that any tagged union types // in the IR module have layout information copied over to them. @@ -1357,7 +1395,7 @@ LinkedIR linkIR( // instructions, since we expected the tagged union type(s) to // be referenced by them. // - for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts ) + for( auto taggedUnionTypeLayout : programLayout->taggedUnionTypeLayouts ) { auto taggedUnionType = taggedUnionTypeLayout->getType(); auto mangledName = getMangledTypeName(taggedUnionType); @@ -1377,6 +1415,15 @@ LinkedIR linkIR( // we have global variables with initializers, since // these should get run whether or not the entry point // references them. + // + // Or alternatively we can define by fiat that the initializers + // on global variables get run at an unspecified time between + // program startup and the first access to a given global. + // Such a definition gives us the freedom to eliminate globals + // that are never accessed, while still doing "eager" + // initialization for globals that are referenced (instead of + // having to add the overhead of lazy initialization a la + // function-`static` variables). 
// Now that we've cloned the entry point and everything // it refers to, we can package up the data we return diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ed86e4d6d..2e8de5e37 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2880,22 +2880,6 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitBindGlobalExistentialSlots( - UInt argCount, - IRInst* const* args) - { - auto inst = createInstWithTrailingArgs( - this, - kIROp_BindGlobalExistentialSlots, - getVoidType(), - 0, - nullptr, - argCount, - args); - addInst(inst); - return inst; - } - IRDecoration* IRBuilder::addBindExistentialSlotsDecoration( IRInst* value, UInt argCount, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index eef69b1b5..e41be0e33 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1613,13 +1613,16 @@ struct ValLoweringVisitor : ValVisitorbaseType); List slotArgs; - for(auto arg : type->slots.args) + for(auto arg : type->args) { - auto irArgType = lowerType(context, arg.type); - auto irArgWitness = lowerSimpleVal(context, arg.witness); + auto irArgVal = lowerSimpleVal(context, arg.val); + slotArgs.add(irArgVal); - slotArgs.add(irArgType); - slotArgs.add(irArgWitness); + if(auto witness = arg.witness) + { + auto irArgWitness = lowerSimpleVal(context, witness); + slotArgs.add(irArgWitness); + } } auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer()); @@ -6265,13 +6268,18 @@ static void lowerFrontEndEntryPointToIR( } static void lowerProgramEntryPointToIR( - IRGenContext* context, - EntryPoint* entryPoint) + IRGenContext* context, + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) { + auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + if(specializationInfo) + entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef; + // First, lower the entry point like 
an ordinary function + auto session = context->getSession(); - auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); auto builder = context->irBuilder; @@ -6280,7 +6288,6 @@ static void lowerProgramEntryPointToIR( auto loweredEntryPointFunc = getSimpleVal(context, emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); - // if(!loweredEntryPointFunc->findDecoration()) { builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); @@ -6289,26 +6296,24 @@ static void lowerProgramEntryPointToIR( // We may have shader parameters of interface/existential type, // which need us to supply concrete type information for specialization. // - auto existentialTypeArgCount = entryPoint->getExistentialTypeArgCount(); - if( existentialTypeArgCount ) + if(specializationInfo && specializationInfo->existentialSpecializationArgs.getCount() != 0) { List existentialSlotArgs; - for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) + for(auto arg : specializationInfo->existentialSpecializationArgs ) { - auto arg = entryPoint->getExistentialTypeArg(ii); - - auto irArgType = lowerType(context, arg.type); - auto irWitnessTable = lowerSimpleVal(context, arg.witness); + auto irArgType = lowerSimpleVal(context, arg.val); existentialSlotArgs.add(irArgType); - existentialSlotArgs.add(irWitnessTable); + + if(auto witness = arg.witness) + { + auto irWitnessTable = lowerSimpleVal(context, witness); + existentialSlotArgs.add(irWitnessTable); + } } builder->addBindExistentialSlotsDecoration(loweredEntryPointFunc, existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); } - - - } /// Ensure that `decl` and all relevant declarations under it get emitted. 
@@ -6451,97 +6456,136 @@ IRModule* generateIRForTranslationUnit( return module; } -RefPtr generateIRForProgram( - Session* session, - Program* program, - DiagnosticSink* sink) + /// Context for generating IR code to represent a `SpecializedComponentType` +struct SpecializedComponentTypeIRGenContext : ComponentTypeVisitor { -// auto compileRequest = translationUnit->compileRequest; + DiagnosticSink* sink; + Linkage* linkage; + Session* session; + IRGenContext* context; + IRBuilder* builder; - SharedIRGenContext sharedContextStorage( - session, - sink); - SharedIRGenContext* sharedContext = &sharedContextStorage; + RefPtr process( + SpecializedComponentType* componentType, + DiagnosticSink* inSink) + { + sink = inSink; - IRGenContext contextStorage(sharedContext); - IRGenContext* context = &contextStorage; + linkage = componentType->getLinkage(); + session = linkage->getSessionImpl(); - SharedIRBuilder sharedBuilderStorage; - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = nullptr; - sharedBuilder->session = session; + SharedIRGenContext sharedContextStorage( + session, + sink); + SharedIRGenContext* sharedContext = &sharedContextStorage; - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; + IRGenContext contextStorage(sharedContext); + context = &contextStorage; - RefPtr module = builder->createModule(); - sharedBuilder->module = module; + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; - context->irBuilder = builder; + IRBuilder builderStorage; + builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; - // We need to emit symbols for all of the entry - // points in the program; this is especially - // important in the case where a generic entry - // point is being specialized. 
- // - for(auto entryPoint : program->getEntryPoints()) - { - lowerProgramEntryPointToIR(context, entryPoint); - } + RefPtr module = builder->createModule(); + sharedBuilder->module = module; + builder->setInsertInto(module->getModuleInst()); - // Now lower all the arguments supplied for global generic - // type parameters. - // - for (RefPtr subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer) - { - auto gSubst = subst.as(); - if(!gSubst) - continue; - - IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); - IRType* typeVal = lowerType(context, gSubst->actualType); + context->irBuilder = builder; - // bind `typeParam` to `typeVal` - builder->emitBindGlobalGenericParam(typeParam, typeVal); + componentType->acceptVisitor(this, nullptr); - for (auto& constraintArg : gSubst->constraintArgs) - { - IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); - IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + return module; + } - // bind `constraintParam` to `constraintVal` - builder->emitBindGlobalGenericParam(constraintParam, constraintVal); - } + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // We need to emit symbols for all of the entry + // points in the program; this is especially + // important in the case where a generic entry + // point is being specialized. + // + lowerProgramEntryPointToIR(context, entryPoint, specializationInfo); } - // We may have shader parameters of interface/existential type, - // which need us to supply concrete type information for specialization. 
- // - auto existentialTypeArgCount = program->getExistentialTypeArgCount(); - if( existentialTypeArgCount ) + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE { - List existentialSlotArgs; - for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) + // We've hit a leaf module, so we should be able to bind any global + // generic type parameters here... + // + if( specializationInfo ) { - auto arg = program->getExistentialTypeArg(ii); + for( auto genericArgInfo : specializationInfo->genericArgs ) + { + IRInst* irParam = getSimpleVal(context, ensureDecl(context, genericArgInfo.paramDecl)); + IRInst* irVal = lowerSimpleVal(context, genericArgInfo.argVal); - auto irArgType = lowerType(context, arg.type); - auto irWitnessTable = lowerSimpleVal(context, arg.witness); + // bind `irParam` to `irVal` + builder->emitBindGlobalGenericParam(irParam, irVal); + } - existentialSlotArgs.add(irArgType); - existentialSlotArgs.add(irWitnessTable); + auto shaderParamCount = module->getShaderParamCount(); + Index existentialArgOffset = 0; + + for( Index ii = 0; ii < shaderParamCount; ++ii ) + { + auto shaderParam = module->getShaderParam(ii); + auto specializationArgCount = shaderParam.specializationParamCount; + + IRInst* irParam = getSimpleVal(context, ensureDecl(context, shaderParam.paramDeclRef)); + List irSlotArgs; + for( Index jj = 0; jj < specializationArgCount; ++jj ) + { + auto& specializationArg = specializationInfo->existentialArgs[existentialArgOffset++]; + + auto irType = lowerSimpleVal(context, specializationArg.val); + auto irWitness = lowerSimpleVal(context, specializationArg.witness); + + irSlotArgs.add(irType); + irSlotArgs.add(irWitness); + } + + builder->addBindExistentialSlotsDecoration( + irParam, + irSlotArgs.getCount(), + irSlotArgs.getBuffer()); + } } + } - builder->emitBindGlobalExistentialSlots(existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); + void visitComposite(CompositeComponentType* 
composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); } + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } - // TODO: Should we apply any of the validation or - // mandatory optimization passes here? + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // TODO: This case should be akin to the `Module` case, + // and deal with global-scope specialization parameters + // directly. + // + SLANG_UNUSED(legacy); + SLANG_UNUSED(specializationInfo); + SLANG_UNIMPLEMENTED_X("legacy program case"); + } +}; - return module; +RefPtr generateIRForSpecializedComponentType( + SpecializedComponentType* componentType, + DiagnosticSink* sink) +{ + SpecializedComponentTypeIRGenContext context; + return context.process(componentType, sink); } } // namespace Slang diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h index 060efb88b..33dfd9d27 100644 --- a/source/slang/slang-lower-to-ir.h +++ b/source/slang/slang-lower-to-ir.h @@ -15,14 +15,32 @@ namespace Slang { class EntryPoint; class ProgramLayout; + class SpecializedComponentType; class TranslationUnitRequest; + /// Generate an IR module to represent the code in the given `translationUnit`. + /// + /// The generated module will include IR definitions for any functions/types + /// in `translationUnit`, but it is *not* guaranteed to contain any definitions + /// from modules that are `import`ed into `translationUnit`. The resulting IR + /// module must be linked against other IR modules that define any symbols + /// that are imported before code generation can be performed. 
+ /// IRModule* generateIRForTranslationUnit( TranslationUnitRequest* translationUnit); - RefPtr generateIRForProgram( - Session* session, - Program* program, - DiagnosticSink* sink); + /// Generate an IR module to represent the specializations applied by `componentType`. + /// + /// The generated IR will encode how `componentType` specializes global or + /// entry-point specialization parameters to concrete arguments (e.g., types). + /// + /// The generated IR module is *not* guaranteed to contain anything more, such + /// as the actual definitions of functions or types being specialized. The + /// resulting IR module must be linked against other IR modules that define + /// those symbols before code generation can be performed. + /// + RefPtr generateIRForSpecializedComponentType( + SpecializedComponentType* componentType, + DiagnosticSink* sink); } #endif diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index dc99f55e2..722725af7 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -674,16 +674,22 @@ static RefPtr processEntryPointVaryingParameter( EntryPointParameterState const& state, RefPtr varLayout); -// Collect a single declaration into our set of parameters -static void collectGlobalGenericParameter( - ParameterBindingContext* context, - RefPtr paramDecl) +static RefPtr _createVarLayout( + TypeLayout* typeLayout, + DeclRef varDeclRef) { - RefPtr layout = new GenericParamLayout(); - layout->decl = paramDecl; - layout->index = (int)context->shared->programLayout->globalGenericParams.getCount(); - context->shared->programLayout->globalGenericParams.add(layout); - context->shared->programLayout->globalGenericParamsMap[layout->decl->getName()->text] = layout.Ptr(); + RefPtr varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + varLayout->varDecl = varDeclRef; + + if(auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout) + { + RefPtr 
pendingVarLayout = new VarLayout(); + pendingVarLayout->typeLayout = pendingDataTypeLayout; + varLayout->pendingVarLayout = pendingVarLayout; + } + + return varLayout; } // Collect a single declaration into our set of parameters @@ -712,9 +718,7 @@ static void collectGlobalScopeParameter( return; // Now create a variable layout that we can use - RefPtr varLayout = new VarLayout(); - varLayout->typeLayout = typeLayout; - varLayout->varDecl = varDeclRef; + RefPtr varLayout = _createVarLayout(typeLayout, varDeclRef); // The logic in `check.cpp` that created the `GlobalShaderParamInfo` // will have identified any cases where there might be multiple @@ -748,6 +752,7 @@ static void collectGlobalScopeParameter( RefPtr additionalVarLayout = new VarLayout(); additionalVarLayout->typeLayout = typeLayout; additionalVarLayout->varDecl = additionalVarDeclRef; + additionalVarLayout->pendingVarLayout = varLayout->pendingVarLayout; parameterInfo->varLayouts.add(additionalVarLayout); } @@ -1770,15 +1775,40 @@ static RefPtr processEntryPointVaryingParameter( return structLayout; } - else if (auto globalGenericParam = declRef.as()) + else if (auto globalGenericParamDecl = declRef.as()) { - auto genParamTypeLayout = new GenericParamTypeLayout(); - // we should have already populated ProgramLayout::genericEntryPointParams list at this point, - // so we can find the index of this generic param decl in the list - genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl()); - genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; - return genParamTypeLayout; + auto& layoutContext = context->layoutContext; + + if( auto concreteType = findGlobalGenericSpecializationArg( + layoutContext, + globalGenericParamDecl) ) + { + // If we know what concrete type has been used to specialize + // the global generic type parameter, then we should use + // the 
concrete type instead. + // + // Note: it should be illegal for the user to use a generic + // type parameter in a varying parameter list without giving + // it an explicit user-defined semantic. Otherwise, it would be possible + // that the concrete type that gets plugged in is a user-defined + // `struct` that uses some `SV_` semantics in its definition, + // so that any static information about what system values + // the entry point uses would be incorrect. + // + return processEntryPointVaryingParameter(context, concreteType, state, varLayout); + } + else + { + // If we don't know a concrete type, then we aren't generating final + // code, so the reflection information should show the generic + // type parameter. + // + // We don't make any attempt to assign varying parameter resources + // to the generic type, since we can't know how many "slots" + // of varying input/output it would consume. + // + return createTypeLayoutForGlobalGenericTypeParam(layoutContext, type, globalGenericParamDecl); + } } else if (auto associatedTypeParam = declRef.as()) { @@ -1804,15 +1834,12 @@ static RefPtr processEntryPointVaryingParameter( /// Compute the type layout for a parameter declared directly on an entry point. 
static RefPtr computeEntryPointParameterTypeLayout( ParameterBindingContext* context, - SubstitutionSet typeSubst, DeclRef paramDeclRef, RefPtr paramVarLayout, EntryPointParameterState& state) { - auto paramDeclRefType = GetType(paramDeclRef); - SLANG_ASSERT(paramDeclRefType); - - auto paramType = paramDeclRefType->Substitute(typeSubst).as(); + auto paramType = GetType(paramDeclRef); + SLANG_ASSERT(paramType); if( paramDeclRef.getDecl()->HasModifier() ) { @@ -1940,6 +1967,12 @@ struct ScopeLayoutBuilder { m_structLayout->mapVarToLayout.Add(firstVarLayout->varDecl.getDecl(), firstVarLayout); } + } + + void addParameter( + RefPtr varLayout) + { + _addParameter(varLayout, nullptr); // Any "pending" items on a field type become "pending" items // on the overall `struct` type layout. @@ -1948,33 +1981,17 @@ struct ScopeLayoutBuilder // `struct` layout logic in `type-layout.cpp`. If this gets any // more complicated we should see if there is a way to share it. // - if( auto fieldPendingDataTypeLayout = firstVarLayout->typeLayout->pendingDataTypeLayout ) + if( auto fieldPendingDataTypeLayout = varLayout->typeLayout->pendingDataTypeLayout ) { m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules); - auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(firstVarLayout->varDecl, fieldPendingDataTypeLayout); + auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(varLayout->varDecl, fieldPendingDataTypeLayout); m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); - if( parameterInfo ) - { - for( auto& varLayout : parameterInfo->varLayouts ) - { - varLayout->pendingVarLayout = fieldPendingDataVarLayout; - } - } - else - { - firstVarLayout->pendingVarLayout = fieldPendingDataVarLayout; - } + varLayout->pendingVarLayout = fieldPendingDataVarLayout; } } - void addParameter( - RefPtr varLayout) - { - _addParameter(varLayout, nullptr); - } - void addParameter( ParameterInfo* parameterInfo) { @@ 
-1982,8 +1999,40 @@ struct ScopeLayoutBuilder auto firstVarLayout = parameterInfo->varLayouts.getFirst(); _addParameter(firstVarLayout, parameterInfo); - } + // Global parameters will have their non-ordinary/uniform + // pending data handled by the main parameter binding + // logic, but we still need to construct a layout + // that includes any pending data. + // + if(auto fieldPendingVarLayout = firstVarLayout->pendingVarLayout) + { + auto fieldPendingTypeLayout = fieldPendingVarLayout->typeLayout; + + m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules); + m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); + + auto fieldUniformLayoutInfo = fieldPendingTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + LayoutSize fieldUniformSize = fieldUniformLayoutInfo ? fieldUniformLayoutInfo->count : 0; + if( fieldUniformSize != 0 ) + { + // Make sure uniform fields get laid out properly... + + UniformLayoutInfo fieldInfo( + fieldUniformSize, + fieldPendingTypeLayout->uniformAlignment); + + LayoutSize uniformOffset = m_rules->AddStructField( + m_pendingDataTypeLayoutBuilder.getStructLayoutInfo(), + fieldInfo); + + fieldPendingVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + + m_pendingDataTypeLayoutBuilder.getTypeLayout()->fields.add(fieldPendingVarLayout); + } + + } // Add a "simple" parameter that cannot have any user-defined // register or binding modifiers, so that its layout computation @@ -2088,21 +2137,48 @@ static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( return info; } + /// Remove resource usage from `typeLayout` that should only be stored per-entry-point. + /// + /// This is used when constructing the overall layout for an entry point, to make sure + /// that certain kinds of resource usage from the entry point don't "leak" into + /// the resource usage of the overall program. 
+ /// +static void removePerEntryPointParameterKinds( + TypeLayout* typeLayout) +{ + typeLayout->removeResourceUsage(LayoutResourceKind::VaryingInput); + typeLayout->removeResourceUsage(LayoutResourceKind::VaryingOutput); + typeLayout->removeResourceUsage(LayoutResourceKind::ShaderRecord); + typeLayout->removeResourceUsage(LayoutResourceKind::HitAttributes); + typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialObjectParam); + typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialTypeParam); +} + /// Iterate over the parameters of an entry point to compute its requirements. /// static RefPtr collectEntryPointParameters( - ParameterBindingContext* context, - EntryPoint* entryPoint, - SubstitutionSet typeSubst) + ParameterBindingContext* context, + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) { DeclRef entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + // If specialization was applied to the entry point, then the side-band + // information that was generated will have a more specialized reference + // to the entry point with generic parameters filled in. We should + // use that version if it is available. + // + if(specializationInfo) + entryPointFuncDeclRef = specializationInfo->specializedFuncDeclRef; + + auto entryPointType = DeclRefType::Create(context->getLinkage()->getSessionImpl(), entryPointFuncDeclRef); + // We will take responsibility for creating and filling in // the `EntryPointLayout` object here. // RefPtr entryPointLayout = new EntryPointLayout(); entryPointLayout->profile = entryPoint->getProfile(); - entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); + entryPointLayout->entryPoint = entryPointFuncDeclRef; // The entry point layout must be added to the output // program layout so that it can be accessed by reflection. 
@@ -2114,19 +2190,6 @@ static RefPtr collectEntryPointParameters( // context->entryPointLayout = entryPointLayout; - // Note: this isn't really the best place for this logic to sit, - // but it is the simplest place where we have a direct correspondence - // between a single `EntryPoint` and its matching `EntryPointLayout`, - // so we'll use it. - // - for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() ) - { - SLANG_ASSERT(taggedUnionType); - auto substType = taggedUnionType->Substitute(typeSubst).as(); - auto typeLayout = createTypeLayout(context->layoutContext, substType); - entryPointLayout->taggedUnionTypeLayouts.add(typeLayout); - } - // We are going to iterate over the entry-point parameters, // and while we do so we will go ahead and perform layout/binding // assignment for two cases: @@ -2149,22 +2212,33 @@ static RefPtr collectEntryPointParameters( ScopeLayoutBuilder scopeBuilder; scopeBuilder.beginLayout(context); auto paramsStructLayout = scopeBuilder.m_structLayout; + paramsStructLayout->type = entryPointType; for( auto& shaderParamInfo : entryPoint->getShaderParams() ) { auto paramDeclRef = shaderParamInfo.paramDeclRef; + // Any generic specialization applied to the entry-point function + // must also be applied to its parameters. + paramDeclRef.substitutions = entryPointFuncDeclRef.substitutions; + // When computing layout for an entry-point parameter, // we want to make sure that the layout context has access // to the existential type arguments (if any) that were // provided for the entry-point existential type parameters (if any). 
// - context->layoutContext= context->layoutContext - .withExistentialTypeArgs( - entryPoint->getExistentialTypeArgCount(), - entryPoint->getExistentialTypeArgs()) - .withExistentialTypeSlotsOffsetBy( - shaderParamInfo.firstExistentialTypeSlot); + if(specializationInfo) + { + auto& existentialSpecializationArgs = specializationInfo->existentialSpecializationArgs; + auto genericSpecializationParamCount = entryPoint->getGenericSpecializationParamCount(); + + context->layoutContext = context->layoutContext + .withSpecializationArgs( + existentialSpecializationArgs.getBuffer(), + existentialSpecializationArgs.getCount()) + .withSpecializationArgsOffsetBy( + shaderParamInfo.firstSpecializationParamIndex - genericSpecializationParamCount); + } // Any error messages we emit during the process should // refer to the location of this parameter. @@ -2183,7 +2257,6 @@ static RefPtr collectEntryPointParameters( auto paramTypeLayout = computeEntryPointParameterTypeLayout( context, - typeSubst, paramDeclRef, paramVarLayout, state); @@ -2204,6 +2277,26 @@ static RefPtr collectEntryPointParameters( // scopeBuilder.addSimpleParameter(paramVarLayout); } + + // We don't want certain kinds of resource usage within an entry + // point to "leak" into the overall resource usage of the entry + // point and thus lead to offsetting of successive entry points. + // + // For example if we have a vertex and a fragment entry point + // in the same program, and each has one varying input, then + // both the vertex and fragment varying inputs should have + // a location/index of zero. It would be bad if the fragment + // input (or whichever entry point comes second in the global + // ordering) started at location one, because then it wouldn't + // line up correctly with any vertex stage outputs. 
+ // + // We handle this with a bit of a kludge, by removing the + // particular `LayoutResourceKind`s that are susceptible to + // this problem from the overall resource usage of the entry + // point. + // + removePerEntryPointParameterKinds(scopeBuilder.m_structLayout); + entryPointLayout->parametersLayout = scopeBuilder.endLayout(); // For an entry point with a non-`void` return type, we need to process the @@ -2212,7 +2305,7 @@ static RefPtr collectEntryPointParameters( // TODO: Ideally we should make the layout process more robust to empty/void // types and apply this logic unconditionally. // - auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as(); + auto resultType = GetResultType(entryPointFuncDeclRef); SLANG_ASSERT(resultType); if( !resultType->Equals(resultType->getSession()->getVoidType()) ) @@ -2226,7 +2319,7 @@ static RefPtr collectEntryPointParameters( auto resultTypeLayout = processEntryPointVaryingParameterDecl( context, entryPointFuncDeclRef.getDecl(), - resultType->Substitute(typeSubst).as(), + resultType, state, resultLayout); @@ -2248,136 +2341,257 @@ static RefPtr collectEntryPointParameters( return entryPointLayout; } - /// Remove resource usage from `typeLayout` that should only be stored per-entry-point. - /// - /// This is used when constructing the layout for an entry point group, to make sure - /// that certain kinds of resource usage from the entry point don't "leak" into - /// the resource usage of the group. 
- /// -static void removePerEntryPointParameterKinds( - TypeLayout* typeLayout) + /// Visitor used by `collectGlobalGenericArguments` +struct CollectGlobalGenericArgumentsVisitor : ComponentTypeVisitor { - typeLayout->removeResourceUsage(LayoutResourceKind::VaryingInput); - typeLayout->removeResourceUsage(LayoutResourceKind::VaryingOutput); - typeLayout->removeResourceUsage(LayoutResourceKind::ShaderRecord); - typeLayout->removeResourceUsage(LayoutResourceKind::HitAttributes); - typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialObjectParam); - typeLayout->removeResourceUsage(LayoutResourceKind::ExistentialTypeParam); -} + CollectGlobalGenericArgumentsVisitor( + ParameterBindingContext* context) + : m_context(context) + {} -static void collectParameters( - ParameterBindingContext* inContext, - Program* program) -{ - // All of the parameters in translation units directly - // referenced in the compile request are part of one - // logical namespace/"linkage" so that two parameters - // with the same name should represent the same - // parameter, and get the same binding(s) + ParameterBindingContext* m_context; - ParameterBindingContext contextData = *inContext; - auto context = &contextData; - context->stage = Stage::Unknown; + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(specializationInfo); + } - auto globalGenericSubst = program->getGlobalGenericSubstitution(); + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(module); - // We will start by looking for any global generic type parameters. 
+ if(!specializationInfo) + return; - for(RefPtr module : program->getModuleDependencies()) - { - for( auto genParamDecl : module->getModuleDecl()->getMembersOfType() ) + for(auto& globalGenericArg : specializationInfo->genericArgs) { - collectGlobalGenericParameter(context, genParamDecl); + if(auto globalGenericTypeParamDecl = as(globalGenericArg.paramDecl)) + { + m_context->shared->programLayout->globalGenericArgs.Add(globalGenericTypeParamDecl, globalGenericArg.argVal); + } } } - // Once we have enumerated global generic type parameters, we can - // begin enumerating shader parameters, starting at the global scope. - // - // Because we have already enumerated the global generic type parameters, - // we will be able to look up the index of a global generic type parameter - // when we see it referenced in the type of one of the shader parameters. + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } - for(auto& globalParamInfo : program->getShaderParams() ) + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE { - // When computing layout for a global shader parameter, - // we want to make sure that the layout context has access - // to the existential type arguments (if any) that were - // provided for the global existential type parameters (if any). 
- // - context->layoutContext= context->layoutContext - .withExistentialTypeArgs( - program->getExistentialTypeArgCount(), - program->getExistentialTypeArgs()) - .withExistentialTypeSlotsOffsetBy( - globalParamInfo.firstExistentialTypeSlot); + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); + } - collectGlobalScopeParameter(context, globalParamInfo, globalGenericSubst); + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // TODO: Need to do something in this case... + SLANG_UNUSED(legacy); + SLANG_UNUSED(specializationInfo); } +}; + + /// Collect an ordered list of all the specialization arguments given for global generic specialization parameters in `program`. + /// + /// This information is used to accelerate the process of mapping a global generic type + /// to its definition during type layout. + /// +static void collectGlobalGenericArguments( + ParameterBindingContext* context, + ComponentType* program) +{ + CollectGlobalGenericArgumentsVisitor visitor(context); + program->acceptVisitor(&visitor, nullptr); +} - // Next consider parameters for entry points - for( auto entryPointGroup : program->getEntryPointGroups() ) + /// Collect information about the (unspecialized) specialization parameters of `program` into `context`. + /// + /// This function computes the reflection/layout for the specialization parameters, so + /// that they can be exposed to the API user. 
+ /// +static void collectSpecializationParams( + ParameterBindingContext* context, + ComponentType* program) +{ + auto specializationParamCount = program->getSpecializationParamCount(); + for(Index ii = 0; ii < specializationParamCount; ++ii) { - RefPtr entryPointGroupLayout = new EntryPointGroupLayout(); - entryPointGroupLayout->group = entryPointGroup; + auto specializationParam = program->getSpecializationParam(ii); + switch(specializationParam.flavor) + { + case SpecializationParam::Flavor::GenericType: + case SpecializationParam::Flavor::GenericValue: + { + RefPtr paramLayout = new GenericSpecializationParamLayout(); + paramLayout->decl = specializationParam.object.as(); + context->shared->programLayout->specializationParams.add(paramLayout); + } + break; - context->shared->programLayout->entryPointGroups.add(entryPointGroupLayout); + case SpecializationParam::Flavor::ExistentialType: + case SpecializationParam::Flavor::ExistentialValue: + { + RefPtr paramLayout = new ExistentialSpecializationParamLayout(); + paramLayout->type = specializationParam.object.as(); + context->shared->programLayout->specializationParams.add(paramLayout); + } + break; + default: + SLANG_UNEXPECTED("unhandled specialization parameter flavor"); + break; + } + } +} + + /// Visitor used by `collectParameters()` +struct CollectParametersVisitor : ComponentTypeVisitor +{ + CollectParametersVisitor( + ParameterBindingContext* context) + : m_context(context) + {} - ScopeLayoutBuilder scopeBuilder; - scopeBuilder.beginLayout(context); - auto entryPointGroupParamsStructLayout = scopeBuilder.m_structLayout; + ParameterBindingContext* m_context; - // First lay out any shader parameters that belong to the group - // itself, rather than to its nested entry points. 
+ void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // The parameters of a composite component type can + // be determined by just visiting its children in order. // - // This ensures that looking up one of the parameters of the - // group by index in its parameters truct will Just Work. + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + // The parameters of a specialized component type + // are just those of its base component type, with + // appropriate specialization information passed + // along. // - for( auto groupParam : entryPointGroup->getShaderParams() ) + visitChildren(specialized); + + // While we are at it, we will also make note of any + // tagged-union types that were used as part of the + // specialization arguments, since we need to make + // sure that their layout information is computed + // and made available for IR code generation. + // + // Note: this isn't really the best place for this logic to sit, + // but it is the simplest place where we can collect all the tagged + // union types that get referenced by a program. 
+ // + for( auto taggedUnionType : specialized->getTaggedUnionTypes() ) { - auto paramDeclRef = groupParam.paramDeclRef; - auto paramType = GetType(paramDeclRef); + SLANG_ASSERT(taggedUnionType); + auto substType = taggedUnionType; + auto typeLayout = createTypeLayout(m_context->layoutContext, substType); + m_context->shared->programLayout->taggedUnionTypeLayouts.add(typeLayout); + } + } - RefPtr paramVarLayout = new VarLayout(); - paramVarLayout->varDecl = paramDeclRef; - auto paramTypeLayout = createTypeLayout( - context->layoutContext.with(context->getRulesFamily()->getConstantBufferRules()), - paramType); - paramVarLayout->typeLayout = paramTypeLayout; + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // An entry point is a leaf case. + // + // In our current model an entry point does not introduce + // any global shader parameters, but in practice it effectively + // acts a lot like a single global shader parameter named after + // the entry point and with a `struct` type that combines + // all the `uniform` entry point parameters. + // + // Later passes will need to make sure that the entry point + // gets enumerated in the right order relative to any global + // shader parameters. + // - scopeBuilder.addSimpleParameter(paramVarLayout); - } + ParameterBindingContext contextData = *m_context; + auto context = &contextData; + context->stage = entryPoint->getStage(); - for(auto entryPoint : entryPointGroup->getEntryPoints()) - { - // Note: we do not want the entry point group to accumulate - // locations for varying input/output parameters: those - // should be specific to each entry point. - // - // We address this issue by manually removing any - // layout information for the relevant resource kinds - // from the group's layout before adding the parameters - // of any entry point for layout. 
- // - removePerEntryPointParameterKinds(scopeBuilder.m_structLayout); + collectEntryPointParameters(context, entryPoint, specializationInfo); + } - context->stage = entryPoint->getStage(); - auto entryPointLayout = collectEntryPointParameters(context, entryPoint, globalGenericSubst); + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // A single module represents a leaf case for layout. + // + // We will enumerate the (global) shader parameters declared + // in the module and add each to our canonical ordering. + // + auto paramCount = module->getShaderParamCount(); - auto entryPointParamsLayout = entryPointLayout->parametersLayout; - auto entryPointParamsTypeLayout = entryPointParamsLayout->typeLayout; + ExpandedSpecializationArg* specializationArgs = specializationInfo + ? specializationInfo->existentialArgs.getBuffer() + : nullptr; - scopeBuilder.addSimpleParameter(entryPointParamsLayout); + for(Index pp = 0; pp < paramCount; ++pp) + { + auto shaderParamInfo = module->getShaderParam(pp); + if(specializationArgs) + { + m_context->layoutContext = m_context->layoutContext.withSpecializationArgs( + specializationArgs, + shaderParamInfo.specializationParamCount); + specializationArgs += shaderParamInfo.specializationParamCount; + } - entryPointGroupLayout->entryPoints.add(entryPointLayout); + collectGlobalScopeParameter(m_context, shaderParamInfo, SubstitutionSet()); } - removePerEntryPointParameterKinds(scopeBuilder.m_structLayout); + } + + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // A legacy program is also a leaf case, and we + // can enumerate its parameters directly. 
+ // + // Note: there is a mismatch here where we really + // ought to be tracking specialization arguments + // for a `LegacyProgram` akin to how they are + // tracked for a `Module`, but right now we try + // to do it like a `CompositeComponentType`. + // As a result we are just ignoring specialization + // information here, which will lead to incorrect + // results if somebody ever uses specialization + // together with the "legacy" program case. + // + // TODO: eliminate this problem by getting rid of + // `LegacyProgram`, rather than spend time trying + // to make this corner case actually work. + // + SLANG_UNUSED(specializationInfo); - entryPointGroupLayout->parametersLayout = scopeBuilder.endLayout(); + auto paramCount = legacy->getShaderParamCount(); + for(Index pp = 0; pp < paramCount; ++pp) + { + collectGlobalScopeParameter(m_context, legacy->getShaderParam(pp), SubstitutionSet()); + } } - context->entryPointLayout = nullptr; +}; + + /// Recursively collect the global shader parameters and entry points in `program`. + /// + /// This function is used to establish the global ordering of parameters and + /// entry points used for layout. + /// +static void collectParameters( + ParameterBindingContext* inContext, + ComponentType* program) +{ + // All of the parameters in translation units directly + // referenced in the compile request are part of one + // logical namespace/"linkage" so that two parameters + // with the same name should represent the same + // parameter, and get the same binding(s) + + ParameterBindingContext contextData = *inContext; + auto context = &contextData; + context->stage = Stage::Unknown; + + CollectParametersVisitor visitor(context); + program->acceptVisitor(&visitor, nullptr); } /// Emit a diagnostic about a uniform parameter at global scope. 
@@ -2418,6 +2632,302 @@ static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingConte return numUsed; } + /// Keep track of the running global counter for entry points and global parameters visited. + /// + /// Because of explicit `register` and `[[vk::binding(...)]]` support, parameter binding + /// needs to proceed in multiple passes, and each pass must both visit the things that + /// need layout (parameters and entry points) in the same order in each pass, and must + /// also be able to look up the side-band information that flows between passes. + /// + /// Currently the `ParameterBindingContext` keeps separate arrays for global shader + /// parameters and entry points, but in the global ordering for layout they can be + /// interleaved. There is also no simple tracking structure that relates a global + /// parameter or entry point to its index in those arrays. Instead, we just keep + /// running counters during our passes over the program so that we can easily + /// compute the linear index of each entry point and global parameter as it + /// is encountered. + /// +struct ParameterBindingVisitorCounters +{ + Index entryPointCounter = 0; + Index globalParamCounter = 0; +}; + + /// Recursive routine to "complete" all binding for parameters and entry points in `componentType`. + /// + /// This includes allocation of as-yet-unused register/binding ranges to parameters (which + /// will then affect the ranges of registers/bindings that are available to subsequent + /// parameters), and importantly *also* includes allocation of space to any "pending" + /// data for interface/existential type parameters/fields. + /// +static void _completeBindings( + ParameterBindingContext* context, + ComponentType* componentType, + ParameterBindingVisitorCounters* ioCounters); + + /// A visitor used by `_completeBindings`. 
+ /// + /// This visitor walks the structure of a `ComponentType` to ensure that + /// any shader parameters (and entry points) it contains that *don't* + /// have explicit bindings on them get allocated registers/bindings + /// as appropriate. + /// + /// The main complication of this visitor is how it handles the + /// `SpecializedComponentType` case, because a specialized component + /// type needs to be handled as an atomic unit that lays out the + /// same in all contexts. + /// +struct CompleteBindingsVisitor : ComponentTypeVisitor +{ + CompleteBindingsVisitor(ParameterBindingContext* context, ParameterBindingVisitorCounters* counters) + : m_context(context) + , m_counters(counters) + {} + + ParameterBindingContext* m_context; + ParameterBindingVisitorCounters* m_counters; + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(specializationInfo); + + // We compute the index of the entry point in the global ordering, + // so we can look up the tracking data in our context. As a result + // we don't actually make use of the parameters that were passed in. + // + auto globalEntryPointIndex = m_counters->entryPointCounter++; + auto globalEntryPointInfo = m_context->shared->programLayout->entryPoints[globalEntryPointIndex]; + + + // We mostly treat an entry point like a single shader parameter that + // uses its `parametersLayout`. + // + completeBindingsForParameter(m_context, globalEntryPointInfo->parametersLayout); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + // A module is a leaf case: we just want to visit each parameter. 
+ visitLeafParams(module); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + // A legacy program is a leaf case: we just want to visit each parameter. + visitLeafParams(legacy); + } + + void visitLeafParams(ComponentType* componentType) + { + auto paramCount = componentType->getShaderParamCount(); + for(Index ii = 0; ii < paramCount; ++ii) + { + auto globalParamIndex = m_counters->globalParamCounter++; + auto globalParamInfo = m_context->shared->parameters[globalParamIndex]; + + completeBindingsForParameter(m_context, globalParamInfo); + } + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + // We just want to recurse on the children of the composite in order. + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + // The handling of a specialized component type here is subtle. + // + // We do *not* simply recurse on the base component type. + // Doing so would ensure that the parameters would get + // registers/bindings allocated to them, but it wouldn't + // allocate space for the "pending" data related to + // existential/interface parameters. + // + // Instead, we recurse through `_completeBindings`, + // which has the job of allocating space for the parameters, + // and then for any "pending" data required. + // + // Handling things this way ensures that a particular + // `SpecializedComponentType` gets laid out exactly + // the same wherever it gets used, rather than + // getting laid out differently when it is placed + // into different compositions. + // + auto base = specialized->getBaseComponentType(); + _completeBindings(m_context, base, m_counters); + } +}; + + /// A visitor used by `_completeBindings`. 
+ /// + /// This visitor is used to follow up after the `CompleteBindingsVisitor` + /// and ensure that any "pending" data required by the parameters that + /// got laid out now gets a location. + /// + /// To make a concrete example: + /// + /// Texture2D a; + /// IThing b; + /// Texture2D c; + /// + /// If these parameters were laid out with `b` specialized to a type + /// that contains a single `Texture2D`, then the `CompleteBindingsVisitor` + /// would visit `a`, `b`, and then `c` in order. It would give `a` the + /// first register/binding available (say, `t0`). It would then make + /// a note that due to specialization, `b`, needs a `t` register as well, + /// but it *cannot* be allocated just yet, because doing so would change + /// the location of `c`, so it is marked as "pending." Then `c` would + /// be visited and get `t1`. As a result the registers given to `a` + /// and `c` are independent of how `b` gets specialized. + /// + /// Next, the `FlushPendingDataVisitor` comes through and applies to + /// the parameters again. For `a` there is no pending data, but for + /// `b` there is a pending request for a `t` register, so it gets allocated + /// now (getting `t2`). The `c` parameter then has no pending data, so + /// we are done. + /// + /// *When* the pending data gets flushed is then significant. In general, + /// the order in which modules get composed and specialized is significant. + /// The module above (let's call it `M`) has one specialization parameter + /// (for `b`), and if we want to compose it with another module `N` that + /// has no specialization parameters, we could compute either: + /// + /// compose(specialize(M, SomeType), N) + /// + /// or: + /// + /// specialize(compose(M,N), SomeType) + /// + /// In the first case, the "pending" data for `M` gets flushed right after `M`, + /// so that `specialize(M,SomeType)` can have a consistent layout + /// regardless of how it is used. 
In the second case, the pending data for + /// `M` only gets flushed after `N`'s parameters are allocated, thus guaranteeing + /// that the `compose(M,N)` part has a consistent layout regardless of what + /// type gets plugged in during specialization. + /// + /// There are trade-offs to be made by an application about which approach + /// to prefer, and the compiler supports either policy choice. + /// +struct FlushPendingDataVisitor : ComponentTypeVisitor +{ + FlushPendingDataVisitor(ParameterBindingContext* context, ParameterBindingVisitorCounters* counters) + : m_context(context) + , m_counters(counters) + {} + + ParameterBindingContext* m_context; + ParameterBindingVisitorCounters* m_counters; + + void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(specializationInfo); + + auto globalEntryPointIndex = m_counters->entryPointCounter++; + auto globalEntryPointInfo = m_context->shared->programLayout->entryPoints[globalEntryPointIndex]; + + // We need to allocate space for any "pending" data that + // appeared in the entry-point parameter list. + // + _allocateBindingsForPendingData(m_context, globalEntryPointInfo->parametersLayout->pendingVarLayout); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + visitLeafParams(module); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(specializationInfo); + visitLeafParams(legacy); + } + + void visitLeafParams(ComponentType* componentType) + { + // In the "leaf" case we just allocate space for any + // pending data in the parameters, in order. 
+ // + auto paramCount = componentType->getShaderParamCount(); + for(Index ii = 0; ii < paramCount; ++ii) + { + auto globalParamIndex = m_counters->globalParamCounter++; + auto globalParamInfo = m_context->shared->parameters[globalParamIndex]; + auto firstVarLayout = globalParamInfo->varLayouts[0]; + + _allocateBindingsForPendingData(m_context, firstVarLayout->pendingVarLayout); + } + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + // Because `SpecializedComponentType` was a special case for `CompleteBindingsVisitor`, + // it ends up being a special case here too. + // + // The `CompleteBindings...` pass treated a `SpecializedComponentType` + // as an atomic unit. Any "pending" data that came from its parameters + // will already have been dealt with, so it would be incorrect for + // us to recurse into `specialized`. + // + // Instead, we just need to *skip* `specialized`, since it was + // completely handled already. This isn't quite as simple + // as just doing nothing, because our passes are using + // some global counters to find the absolute/linear index + // of each parameter and entry point as it is encountered. + // We will simply bump those counters by the number of + // parameters and entry points contained under `specialized`, + // which is luckily provided by the `ComponentType` API. 
+ // + m_counters->globalParamCounter += specialized->getShaderParamCount(); + m_counters->entryPointCounter += specialized->getEntryPointCount(); + } + +}; + +static void _completeBindings( + ParameterBindingContext* context, + ComponentType* componentType, + ParameterBindingVisitorCounters* ioCounters) +{ + ParameterBindingVisitorCounters savedCounters = *ioCounters; + + CompleteBindingsVisitor completeBindingsVisitor(context, ioCounters); + componentType->acceptVisitor(&completeBindingsVisitor, nullptr); + + FlushPendingDataVisitor flushVisitor(context, &savedCounters); + componentType->acceptVisitor(&flushVisitor, nullptr); +} + + /// "Complete" binding of parameters in the given `program`. + /// + /// Completing binding involves both assigning registers/bindings + /// to any parameters that didn't get explicit locations, and then + /// also providing locations to any "pending" data that needed + /// space allocated (used for existential/interface type parameters). + /// +static void _completeBindings( + ParameterBindingContext* context, + ComponentType* program) +{ + // The process of completing binding has a recursive structure, + // so we will immediately delegate to a subroutine that handles + // the recursion. + // + ParameterBindingVisitorCounters counters; + _completeBindings(context, program, &counters); +} + RefPtr generateParameterBindings( TargetProgram* targetProgram, DiagnosticSink* sink) @@ -2450,20 +2960,67 @@ RefPtr generateParameterBindings( context.shared = &sharedContext; context.layoutContext = layoutContext; - // Walk through AST to discover all the parameters + // We want to start by finding out what (if anything) has + // been bound to the global generic parameters of the + // program, since we need to know these types to compute + // layout for parameters that use the generic type parameters. 
+ // + collectGlobalGenericArguments(&context, program); + + // Next we want to collect a full listing of all the shader + // parameters that need to be considered for layout, along + // with all of the entry points, which also need their + // parameters laid out and thus act pretty much like global + // parameters themselves. + // collectParameters(&context, program); - // Now walk through the parameters to generate initial binding information + // We will also collect basic information on the specialization + // parameters exposed by the program. + // + // Whereas `collectGlobalGenericArguments` was collecting the + // concrete types that have been plugged into specialization + // parameters, this step is about collecting the *unspecialized* + // parameters (if any) for the purposes of reflection. + // + collectSpecializationParams(&context, program); + + // Once we have a canonical list of all the shader parameters + // (and entry points) in need of layout, we will walk through + // the parameters that might have explicit binding annotations, + // and "reserve" the registers/bindings/etc. that those parameters + // declare so that subsequent automatic layout steps do not try to + // overlap them. + // + // Along the way we will issue diagnostics if there appear to + // be overlapping, conflicting, or inconsistent explicit bindings. + // + // Note that we do *not* support explicit binding annotations + // on entry point parameters, so we only consider global shader + // parameters here. + // + // (Also note that explicit bindings end up being the main + // source of complexity in the layout system, and we could greatly + // simplify this file by eliminating support for explicit + // binding in the future) + // for( auto& parameter : sharedContext.parameters ) { generateParameterBindings(&context, parameter); } - // Determine if there are any global-scope parameters that use `Uniform` - // resources, and thus need to get packaged into a constant buffer. 
+ // Once we have a canonical list of all the parameters, we can + // detect if there are any global-scope parameters that make use + // of `LayoutResourceKind::Uniform`, since such parameters would + // need to be packaged into a "default" constant buffer. + // The fxc/dxc compilers support this step, and in reflection + // refer to the generated constant buffer as `$Globals`. + // + // Note that this logic doesn't account for the existence of + // "legacy" (non-buffer-bound) uniforms in GLSL for OpenGL. + // If we wanted to support legacy uniforms we would probably + // want to do so through a different feature. // - // Note: this doesn't account for GLSL's support for "legacy" uniforms - // at global scope, which don't get assigned a CB. bool needDefaultConstantBuffer = false; for( auto& parameterInfo : sharedContext.parameters ) { @@ -2525,6 +3082,32 @@ RefPtr generateParameterBindings( break; } } + + // We also need a default space for any entry-point parameters + // that consume appropriate resource kinds. + // + for(auto& entryPoint : sharedContext.programLayout->entryPoints) + { + auto paramsLayout = entryPoint->parametersLayout; + for(auto resInfo : paramsLayout->resourceInfos ) + { + switch(resInfo.kind) + { + default: + break; + + case LayoutResourceKind::RegisterSpace: + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + case LayoutResourceKind::HitAttributes: + case LayoutResourceKind::RayPayload: + continue; + } + + needDefaultSpace = true; + break; + } + } } // If we need a space for default bindings, then allocate it here. @@ -2575,12 +3158,14 @@ RefPtr generateParameterBindings( &context, needDefaultConstantBuffer); - // Now walk through again to actually give everything - // ranges of registers... 
- for( auto& parameter : sharedContext.parameters ) - { - completeBindingsForParameter(&context, parameter); - } + // Now that all of the explicit bindings have been dealt with + // and we've also allocated any space/buffer that is required + // for global-scope parameters, we will go through the + // shader parameters and entry points yet again, in order + // to actually allocate specific bindings/registers to + // parameters and entry points that need them. + // + _completeBindings(&context, program); // Next we need to create a type layout to reflect the information // we have collected, and we will use the `ScopeLayoutBuilder` @@ -2602,104 +3187,6 @@ RefPtr generateParameterBindings( cbInfo->index = globalConstantBufferBinding.index; } - // After we have laid out all the ordinary global parameters, - // we need to "flush" out any pending data that was associated with - // the global scope as part of dealing with interface-type parameters. - // - _allocateBindingsForPendingData(&context, globalScopeVarLayout->pendingVarLayout); - - // After we have allocated registers/bindings to everything - // in the global scope we will process the parameters - // of the entry points. - // - // Note: at the moment we are laying out *all* information related to global-scope - // parameters (including pending data from interface-type parameters) before - // anything pertaining to entry points. This is a crucial design choice to - // get right, and we might want to revisit it based on experience. - - // In some cases, a user will want to ensure that all the - // entry points they compile get non-overlapping - // registers/bindings. E.g., if you have a vertex and fragment - // shader being compiled together for Vulkan, you probably want distinct - // bindings for their entry-point `uniform` parametres, so - // that they can be used together. - // - // In other cases, however, a user probably doesn't want us - // to conservatively allocate non-overlapping bindings. 
- // E.g., if they have a bunch of compute shaders in a single - // file, then they probably want each compute shader to - // compute its parameter layout "from scratch" as if the - // others don't exist. - // - // The way we handle this is by putting the entry points of a - // `Program` into groups, and ensuring that within each group - // we allocate parameters that don't overlap, but we don't - // worry about overlap across groups. - // - for( auto entryPointGroup : sharedContext.programLayout->entryPointGroups ) - { - // We save off the allocation state as it was before the entry-point - // group, so that we can restore it to this state after each group. - // - // TODO: We probably ought to wrap all the state relevant to allocation - // of registers/bindings into a single struct/field so that we only - // have one thing to save/restore here even if new state gets added. - // - auto savedGlobalSpaceUsedRangeSets = sharedContext.globalSpaceUsedRangeSets; - auto savedUsedSpaces = sharedContext.usedSpaces; - - // The group will have been allocated a layout that combines the - // usage of all of the contained entry-points, so we just need to - // allocate the entry-point group to be placed after all the global-scope - // parameters. - // - auto entryPointGroupParamsLayout = entryPointGroup->parametersLayout; - completeBindingsForParameter(&context, entryPointGroupParamsLayout); - - _allocateBindingsForPendingData(&context, entryPointGroupParamsLayout->pendingVarLayout); - - // TODO: Should we add the offset information from the group to - // the layout information for each entry point (and thence to its parameters)? - // - // This seems important if we want to allow clients to conveniently - // ignore groups when doing their reflection queries. - - // Once we've allocated bindigns for the parameters of entry points - // in the group, we restore the state for tracking what register/bindings - // are used to where it was before. 
- // - sharedContext.globalSpaceUsedRangeSets = savedGlobalSpaceUsedRangeSets; - sharedContext.usedSpaces = savedUsedSpaces; - } - - // HACK: we want global parameters to not have to deal with offsetting - // by the `VarLayout` stored in `globalScopeVarLayout`, so we will scan - // through and for any global parameter that used "pending" data, we will manually - // offset all of its resource infos to account for where the global pending data - // got placed. - // - // TODO: A more appropriate solution would be to pass the `globalScopeVarLayout` - // down into the pass that puts layout information onto global parameters in - // the IR, and apply the offsetting there. - // - for( auto& parameterInfo : sharedContext.parameters ) - { - for( auto varLayout : parameterInfo->varLayouts ) - { - auto pendingVarLayout = varLayout->pendingVarLayout; - if(!pendingVarLayout) continue; - - for( auto& resInfo : pendingVarLayout->resourceInfos ) - { - if( auto globalResInfo = globalScopeVarLayout->pendingVarLayout->FindResourceInfo(resInfo.kind) ) - { - resInfo.index += globalResInfo->index; - resInfo.space += globalResInfo->space; - } - } - } - } - programLayout->parametersLayout = globalScopeVarLayout; { @@ -2723,7 +3210,7 @@ ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink) } void generateParameterBindings( - Program* program, + ComponentType* program, TargetRequest* targetReq, DiagnosticSink* sink) { diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp index 3871e2ccc..5e6fd10cd 100644 --- a/source/slang/slang-reflection.cpp +++ b/source/slang/slang-reflection.cpp @@ -51,9 +51,9 @@ static inline SlangReflectionTypeLayout* convert(TypeLayout* type) return (SlangReflectionTypeLayout*) type; } -static inline GenericParamLayout* convert(SlangReflectionTypeParameter * typeParam) +static inline SpecializationParamLayout* convert(SlangReflectionTypeParameter * typeParam) { - return (GenericParamLayout*)typeParam; + return 
(SpecializationParamLayout*) typeParam; } static inline VarDeclBase* convert(SlangReflectionVariable* var) @@ -86,16 +86,6 @@ static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint) return (SlangReflectionEntryPoint*) entryPoint; } -static inline EntryPointGroupLayout* convert(SlangEntryPointGroupLayout* entryPointGroup) -{ - return (EntryPointGroupLayout*) entryPointGroup; -} - -static inline SlangEntryPointGroupLayout* convert(EntryPointGroupLayout* entryPointGroup) -{ - return (SlangEntryPointGroupLayout*) entryPointGroup; -} - static inline ProgramLayout* convert(SlangReflection* program) { return (ProgramLayout*) program; @@ -865,7 +855,7 @@ SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLay if(auto genericParamTypeLayout = as(typeLayout)) { - return genericParamTypeLayout->paramIndex; + return (int) genericParamTypeLayout->paramIndex; } else { @@ -1222,7 +1212,7 @@ SLANG_API char const* spReflectionEntryPoint_getName( auto entryPointLayout = convert(inEntryPoint); if(!entryPointLayout) return 0; - return getText(entryPointLayout->entryPoint->getName()).begin(); + return getText(entryPointLayout->entryPoint.GetName()).begin(); } SLANG_API unsigned spReflectionEntryPoint_getParameterCount( @@ -1270,7 +1260,7 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( SlangUInt sizeAlongAxis[3] = { 1, 1, 1 }; // First look for the HLSL case, where we have an attribute attached to the entry point function - auto numThreadsAttribute = entryPointFunc->FindModifier(); + auto numThreadsAttribute = entryPointFunc.getDecl()->FindModifier(); if (numThreadsAttribute) { sizeAlongAxis[0] = numThreadsAttribute->x; @@ -1281,7 +1271,7 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( { // Fall back to the GLSL case, which requires a search over global-scope declarations // to look for as with the `local_size_*` qualifier - auto module = as(entryPointFunc->ParentDecl); + auto module = 
as(entryPointFunc.getDecl()->ParentDecl); if (module) { for (auto dd : module->Members) @@ -1323,11 +1313,43 @@ SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( return (entryPointLayout->flags & EntryPointLayout::Flag::usesAnySampleRateInput) != 0; } +SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getVarLayout( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) + return nullptr; + + return convert(entryPointLayout->parametersLayout); +} + +static bool hasDefaultConstantBuffer(ScopeLayout* layout) +{ + auto typeLayout = layout->parametersLayout->getTypeLayout(); + return as(typeLayout) != nullptr; +} + +SLANG_API int spReflectionEntryPoint_hasDefaultConstantBuffer( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) + return 0; + + return hasDefaultConstantBuffer(entryPointLayout); +} + + // SlangReflectionTypeParameter SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter * inTypeParam) { - auto typeParam = convert(inTypeParam); - return typeParam->decl->getName()->text.getBuffer(); + auto specializationParam = convert(inTypeParam); + if( auto genericParamLayout = as(specializationParam) ) + { + return genericParamLayout->decl->getName()->text.getBuffer(); + } + // TODO: Add case for existential type parameter? They don't have as simple of a notion of "name" as the generic case... 
+ return nullptr; } SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter * inTypeParam) @@ -1338,16 +1360,34 @@ SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParamet SLANG_API unsigned int spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* inTypeParam) { - auto typeParam = convert(inTypeParam); - auto constraints = typeParam->decl->getMembersOfType(); - return (unsigned int)constraints.getCount(); + auto specializationParam = convert(inTypeParam); + if(auto genericParamLayout = as(specializationParam)) + { + if( auto globalGenericParamDecl = as(genericParamLayout->decl) ) + { + auto constraints = globalGenericParamDecl->getMembersOfType(); + return (unsigned int)constraints.getCount(); + } + // TODO: Add case for entry-point generic parameters. + } + // TODO: Add case for existential type parameters. + return 0; } SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter * inTypeParam, unsigned index) { - auto typeParam = convert(inTypeParam); - auto constraints = typeParam->decl->getMembersOfType(); - return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr(); + auto specializationParam = convert(inTypeParam); + if(auto genericParamLayout = as(specializationParam)) + { + if( auto globalGenericParamDecl = as(genericParamLayout->decl) ) + { + auto constraints = globalGenericParamDecl->getMembersOfType(); + return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr(); + } + // TODO: Add case for entry-point generic parameters. + } + // TODO: Add case for existential type parameters. 
+ return 0; } // Shader Reflection @@ -1379,22 +1419,32 @@ SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflec SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection * reflection) { auto program = convert(reflection); - return (unsigned int)program->globalGenericParams.getCount(); + return (unsigned int) program->specializationParams.getCount(); } SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection * reflection, unsigned int index) { auto program = convert(reflection); - return (SlangReflectionTypeParameter*)program->globalGenericParams[index].Ptr(); + return (SlangReflectionTypeParameter*) program->specializationParams[index].Ptr(); } SLANG_API SlangReflectionTypeParameter * spReflection_FindTypeParameter(SlangReflection * inProgram, char const * name) { auto program = convert(inProgram); if (!program) return nullptr; - GenericParamLayout * result = nullptr; - program->globalGenericParamsMap.TryGetValue(name, result); - return (SlangReflectionTypeParameter*)result; + for( auto& param : program->specializationParams ) + { + auto genericParamLayout = as(param); + if(!genericParamLayout) + continue; + + if(getText(genericParamLayout->decl->getName()) != UnownedTerminatedStringSlice(name)) + continue; + + return (SlangReflectionTypeParameter*) genericParamLayout; + } + + return 0; } SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram) @@ -1421,7 +1471,7 @@ SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangRefl // TODO: improve on naive linear search for(auto ep : program->entryPoints) { - if(ep->entryPoint->getName()->text == name) + if(ep->entryPoint.GetName()->text == name) { return convert(ep); } @@ -1430,75 +1480,6 @@ SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangRefl return nullptr; } - -SLANG_API SlangInt spReflection_getEntryPointGroupCount(SlangReflection* inProgram) -{ - auto program = 
convert(inProgram); - if(!program) return 0; - - return program->entryPointGroups.getCount(); -} - -SLANG_API SlangEntryPointGroupLayout* spReflection_getEntryPointGroupByIndex(SlangReflection* inProgram, SlangInt index) -{ - auto program = convert(inProgram); - if(!program) return 0; - - if(index < 0) return nullptr; - if(index >= program->entryPointGroups.getCount()) return nullptr; - - return convert(program->entryPointGroups[(int) index].Ptr()); -} - -SLANG_API SlangInt spEntryPointGroupLayout_getEntryPointCount(SlangEntryPointGroupLayout* inGroup) -{ - auto group = convert(inGroup); - if(!group) return 0; - - return group->entryPoints.getCount(); -} - -SLANG_API SlangReflectionEntryPoint* spEntryPointGroupLayout_getEntryPointByIndex(SlangEntryPointGroupLayout* inGroup, SlangInt index) -{ - auto group = convert(inGroup); - if(!group) return 0; - - if(index < 0) return nullptr; - if(index >= group->entryPoints.getCount()) return nullptr; - - return convert(group->entryPoints[(int) index].Ptr()); -} - -SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getVarLayout(SlangEntryPointGroupLayout* inGroup) -{ - auto group = convert(inGroup); - if(!group) return 0; - - return convert(group->parametersLayout); -} - -SLANG_API SlangInt spEntryPointGroupLayout_getParameterCount(SlangEntryPointGroupLayout* inGroup) -{ - auto groupLayout = convert(inGroup); - if(!groupLayout) return 0; - - auto& params = groupLayout->group->getShaderParams(); - return params.getCount(); -} - -SLANG_API SlangReflectionVariableLayout* spEntryPointGroupLayout_getParameterByIndex(SlangEntryPointGroupLayout* inGroup, SlangInt index) -{ - auto groupLayout = convert(inGroup); - if(!groupLayout) return nullptr; - - auto& params = groupLayout->group->getShaderParams(); - if(index < 0) return nullptr; - if(index >= params.getCount()) return nullptr; - - return convert(getScopeStructLayout(groupLayout)->fields[index]); -} - - SLANG_API SlangUInt 
spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram) { auto program = convert(inProgram); @@ -1531,7 +1512,7 @@ SLANG_API SlangReflectionType* spReflection_specializeType( auto unspecializedType = convert(inType); if(!unspecializedType) return nullptr; - auto linkage = programLayout->getProgram()->getLinkageImpl(); + auto linkage = programLayout->getProgram()->getLinkage(); DiagnosticSink sink(linkage->getSourceManager()); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 08d671241..c4152d78c 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -2766,10 +2766,10 @@ String ExistentialSpecializedType::ToString() String result; result.append("__ExistentialSpecializedType("); result.append(baseType->ToString()); - for( auto arg : slots.args ) + for( auto arg : args ) { result.append(", "); - result.append(arg.type->ToString()); + result.append(arg.val->ToString()); } result.append(")"); return result; @@ -2784,16 +2784,19 @@ bool ExistentialSpecializedType::EqualsImpl(Type * type) if(!baseType->Equals(other->baseType)) return false; - auto argCount = slots.args.getCount(); - if(argCount != other->slots.args.getCount()) + auto argCount = args.getCount(); + if(argCount != other->args.getCount()) return false; for( Index ii = 0; ii < argCount; ++ii ) { - if(!slots.args[ii].type->Equals(other->slots.args[ii].type)) + auto arg = args[ii]; + auto otherArg = other->args[ii]; + + if(!arg.val->EqualsVal(otherArg.val)) return false; - if(!slots.args[ii].witness->EqualsVal(other->slots.args[ii].witness)) + if(!areValsEqual(arg.witness, otherArg.witness)) return false; } return true; @@ -2803,51 +2806,63 @@ int ExistentialSpecializedType::GetHashCode() { Hasher hasher; hasher.hashObject(baseType); - for(auto arg : slots.args) + for(auto arg : args) { - hasher.hashObject(arg.type); - hasher.hashObject(arg.witness); + hasher.hashObject(arg.val); + if(auto witness = arg.witness) + 
hasher.hashObject(witness); } return hasher.getResult(); } +RefPtr getCanonicalValue(Val* val) +{ + if(!val) + return nullptr; + if(auto type = as(val)) + { + return type->GetCanonicalType(); + } + // TODO: We may eventually need/want some sort of canonicalization + // for non-type values, but for now there is nothing to do. + return val; +} + RefPtr ExistentialSpecializedType::CreateCanonicalType() { RefPtr canType = new ExistentialSpecializedType(); canType->setSession(getSession()); canType->baseType = baseType->GetCanonicalType(); - for( auto paramType : slots.paramTypes ) + for( auto arg : args ) { - canType->slots.paramTypes.add( paramType->GetCanonicalType() ); - } - for( auto arg : slots.args ) - { - ExistentialTypeSlots::Arg canArg; - canArg.type = arg.type->GetCanonicalType(); - canArg.witness = arg.witness; - canType->slots.args.add(canArg); + ExpandedSpecializationArg canArg; + canArg.val = getCanonicalValue(arg.val); + canArg.witness = getCanonicalValue(arg.witness); + canType->args.add(canArg); } return canType; } +RefPtr substituteImpl(Val* val, SubstitutionSet subst, int* ioDiff) +{ + if(!val) return nullptr; + return val->SubstituteImpl(subst, ioDiff); +} + RefPtr ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) { int diff = 0; auto substBaseType = baseType->SubstituteImpl(subst, &diff).as(); - ExistentialTypeSlots substSlots; - for( auto paramType : slots.paramTypes ) - { - substSlots.paramTypes.add( paramType->SubstituteImpl(subst, &diff).as() ); - } - for( auto arg : slots.args ) + ExpandedSpecializationArgs substArgs; + for( auto arg : args ) { - ExistentialTypeSlots::Arg substArg; - substArg.type = arg.type->SubstituteImpl(subst, &diff).as(); - substArg.witness = arg.witness->SubstituteImpl(subst, &diff); - substSlots.args.add(substArg); + ExpandedSpecializationArg substArg; + substArg.val = Slang::substituteImpl(arg.val, subst, &diff); + substArg.witness = Slang::substituteImpl(arg.witness, subst, &diff); + 
substArgs.add(substArg); } if(!diff) @@ -2858,7 +2873,7 @@ RefPtr ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, in RefPtr substType = new ExistentialSpecializedType(); substType->setSession(getSession()); substType->baseType = substBaseType; - substType->slots = substSlots; + substType->args = substArgs; return substType; } diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index d156750be..88a2ca847 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -1132,31 +1132,31 @@ namespace Slang typedef Dictionary> AttributeArgumentValueDict; - /// Collects information about existential type parameters and their arguments. - struct ExistentialTypeSlots + struct SpecializationParam { - /// For each type parameter, holds the interface/existential type that constrains it. - List> paramTypes; - - /// An argument for an existential type parameter. - /// - /// Comprises a concrete type and a witness for its conformance to the desired - /// interface/existential type for the corresponding parameter. - /// - struct Arg + enum class Flavor { - RefPtr type; - RefPtr witness; + GenericType, + GenericValue, + ExistentialType, + ExistentialValue, }; + Flavor flavor; + RefPtr object; + }; + typedef List SpecializationParams; - /// Any arguments provided for the existential type parameters. - /// - /// It is possible for `args` to be empty even if `paramTypes` is non-empty; - /// that situation represents an unspecialized program or entry point. 
- /// - List args; + struct SpecializationArg + { + RefPtr val; }; + typedef List SpecializationArgs; + struct ExpandedSpecializationArg : SpecializationArg + { + RefPtr witness; + }; + typedef List ExpandedSpecializationArgs; // Generate class definition for all syntax classes #define SYNTAX_FIELD(TYPE, NAME) TYPE NAME; diff --git a/source/slang/slang-type-defs.h b/source/slang/slang-type-defs.h index d9907bafe..7afc23411 100644 --- a/source/slang/slang-type-defs.h +++ b/source/slang/slang-type-defs.h @@ -479,7 +479,7 @@ END_SYNTAX_CLASS() SYNTAX_CLASS(ExistentialSpecializedType, Type) RAW( RefPtr baseType; - ExistentialTypeSlots slots; + ExpandedSpecializationArgs args; virtual String ToString() override; virtual bool EqualsImpl(Type * type) override; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 9b6571372..f76b29a51 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1222,19 +1222,11 @@ EntryPointLayout* EntryPointLayout::getAbsoluteLayout( { adjustedLayout->resultLayout = baseResultLayout->getAbsoluteLayout(parentLayout); } - adjustedLayout->taggedUnionTypeLayouts = this->taggedUnionTypeLayouts; m_absoluteLayout = adjustedLayout; return adjustedLayout; } -EntryPointLayout* EntryPointLayout::getAbsoluteLayout(EntryPointGroupLayout* parentGroup) -{ - SLANG_ASSERT(parentGroup); - return getAbsoluteLayout(parentGroup->parametersLayout); -} - - VarLayout* VarLayout::getAbsoluteLayout(VarLayout* parentAbsoluteLayout) { if( !m_absoluteLayout ) @@ -1933,9 +1925,31 @@ static TypeLayoutResult _createTypeLayout( return _createTypeLayout(subContext, type); } -int findGenericParam(List> & genericParameters, GlobalGenericParamDecl * decl) +RefPtr findGlobalGenericSpecializationArg( + TypeLayoutContext const& context, + GlobalGenericParamDecl* decl) { - return (int)genericParameters.findFirstIndex([=](RefPtr & x) {return x->decl.Ptr() == decl; }); + RefPtr arg; + 
context.programLayout->globalGenericArgs.TryGetValue(decl, arg); + return arg.as(); +} + +Index findGlobalGenericSpecializationParamIndex( + ComponentType* type, + GlobalGenericParamDecl* decl) +{ + Index paramCount = type->getSpecializationParamCount(); + for( Index pp = 0; pp < paramCount; ++pp ) + { + auto param = type->getSpecializationParam(pp); + if(param.flavor != SpecializationParam::Flavor::GenericType) + continue; + if(param.object.Ptr() != decl) + continue; + + return pp; + } + return -1; } // When constructing a new var layout from an existing one, @@ -2393,6 +2407,37 @@ TypeLayoutResult StructTypeLayoutBuilder::getTypeLayoutResult() return TypeLayoutResult(m_typeLayout, m_info); } +static TypeLayoutResult _createTypeLayoutForGlobalGenericTypeParam( + TypeLayoutContext const& context, + Type* type, + GlobalGenericParamDecl* globalGenericParamDecl) +{ + SimpleLayoutInfo info; + info.alignment = 0; + info.size = 0; + info.kind = LayoutResourceKind::GenericResource; + + RefPtr typeLayout = new GenericParamTypeLayout(); + // we should have already populated ProgramLayout::genericEntryPointParams list at this point, + // so we can find the index of this generic param decl in the list + typeLayout->type = type; + typeLayout->paramIndex = findGlobalGenericSpecializationParamIndex( + context.programLayout->getProgram(), + globalGenericParamDecl); + typeLayout->rules = context.rules; + typeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; + + return TypeLayoutResult(typeLayout, info); +} + +RefPtr createTypeLayoutForGlobalGenericTypeParam( + TypeLayoutContext const& context, + Type* type, + GlobalGenericParamDecl* globalGenericParamDecl) +{ + return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl).layout; +} + static TypeLayoutResult _createTypeLayout( TypeLayoutContext const& context, Type* type) @@ -2825,7 +2870,7 @@ static TypeLayoutResult _createTypeLayout( // to all the incoming specialized type 
slots that haven't already // been consumed/claimed by preceding fields. // - auto fieldLayoutContext = context.withExistentialTypeSlotsOffsetBy(baseExistentialSlotIndex); + auto fieldLayoutContext = context.withSpecializationArgsOffsetBy(baseExistentialSlotIndex); TypeLayoutResult fieldResult = _createTypeLayout( fieldLayoutContext, @@ -2863,22 +2908,25 @@ static TypeLayoutResult _createTypeLayout( return typeLayoutBuilder.getTypeLayoutResult(); } - else if (auto globalGenParam = declRef.as()) + else if (auto globalGenericParamDecl = declRef.as()) { - SimpleLayoutInfo info; - info.alignment = 0; - info.size = 0; - info.kind = LayoutResourceKind::GenericResource; - - auto genParamTypeLayout = new GenericParamTypeLayout(); - // we should have already populated ProgramLayout::genericEntryPointParams list at this point, - // so we can find the index of this generic param decl in the list - genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); - genParamTypeLayout->rules = rules; - genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; - - return TypeLayoutResult(genParamTypeLayout, info); + if( auto concreteType = findGlobalGenericSpecializationArg( + context, + globalGenericParamDecl) ) + { + // If we know what concrete type has been used to specialize + // the global generic type parameter, then we should use + // the concrete type instead. + // + return _createTypeLayout(context, concreteType); + } + else + { + // Otherwise we must create a type layout that represents + // the generic type parameter itself. 
+ // + return _createTypeLayoutForGlobalGenericTypeParam(context, type, globalGenericParamDecl); + } } else if (auto assocTypeParam = declRef.as()) { @@ -2934,9 +2982,11 @@ static TypeLayoutResult _createTypeLayout( // If there are any concrete types available, the first one will be // the value that should be plugged into the slot we just introduced. // - if( context.existentialTypeArgCount ) + if( context.specializationArgCount ) { - RefPtr concreteType = context.existentialTypeArgs[0].type; + auto& specializationArg = context.specializationArgs[0]; + RefPtr concreteType = specializationArg.val.as(); + SLANG_ASSERT(concreteType); RefPtr concreteTypeLayout = createTypeLayout(context, concreteType); @@ -3046,9 +3096,9 @@ static TypeLayoutResult _createTypeLayout( } else if( auto existentialSpecializedType = as(type) ) { - TypeLayoutContext subContext = context.withExistentialTypeArgs( - existentialSpecializedType->slots.args.getCount(), - existentialSpecializedType->slots.args.getBuffer()); + TypeLayoutContext subContext = context.withSpecializationArgs( + existentialSpecializedType->args.getBuffer(), + existentialSpecializedType->args.getCount()); auto baseTypeLayoutResult = _createTypeLayout( subContext, diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index e066c2700..b7b3c3207 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -622,7 +622,7 @@ class GenericParamTypeLayout : public TypeLayout { public: RefPtr getGlobalGenericParamDecl(); - int paramIndex = 0; + Index paramIndex = 0; }; /// Layout information for a tagged union type. 
@@ -667,8 +667,6 @@ public: StructTypeLayout* getScopeStructLayout( ScopeLayout* programLayout); -class EntryPointGroupLayout; - // Layout information for a single shader entry point // within a program // @@ -682,7 +680,10 @@ class EntryPointLayout : public ScopeLayout { public: // The corresponding function declaration - RefPtr entryPoint; + DeclRef entryPoint; + + DeclRef getFuncDeclRef() { return entryPoint; } + FuncDecl* getFuncDecl() { return entryPoint.getDecl(); } // The shader profile that was used to compile the entry point Profile profile; @@ -696,30 +697,38 @@ public: }; unsigned flags = 0; - /// Layouts for all tagged union types required by this entry point. - /// - /// These are any tagged union types used by the generic - /// arguments that this entry point is being compiled with. - List> taggedUnionTypeLayouts; - EntryPointLayout* getAbsoluteLayout(VarLayout* parentLayout); - EntryPointLayout* getAbsoluteLayout(EntryPointGroupLayout* parentGroup); RefPtr m_absoluteLayout; }; -class EntryPointGroupLayout : public ScopeLayout + /// Reflection/layout information about a specialization parameter +class SpecializationParamLayout : public Layout { public: - RefPtr group; - List> entryPoints; + Index index; }; -class GenericParamLayout : public Layout + /// Reflection/layout information about a generic specialization parameter +class GenericSpecializationParamLayout : public SpecializationParamLayout { public: - RefPtr decl; - int index; + /// The declaration of the generic parameter. + /// + /// Could be any subclass of `Decl` that represents a generic value or type parameter. + RefPtr decl; +}; + + /// Reflection/layout information about an existential/interface specialization parameter. +class ExistentialSpecializationParamLayout : public SpecializationParamLayout +{ +public: + /// The type that needs to be specialized. + /// + /// Currently, this will be an `interface` type that any concrete + /// type argument getting plugged in must conform to. 
+ /// + RefPtr type; }; // Layout information for the global scope of a program @@ -748,7 +757,7 @@ public: TargetProgram* getTargetProgram() { return targetProgram; } TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); } - Program* getProgram() { return targetProgram->getProgram(); } + ComponentType* getProgram() { return targetProgram->getProgram(); } // We catalog the requested entry points here, @@ -756,12 +765,21 @@ public: // will (eventually) belong there... List> entryPoints; - // Entry points can also be grouped for layout purposes (e.g., to form - // ray-tracing hit groups), so this array represents those groups - List> entryPointGroups; + /// Reflection information on (unspecialized) specialization parameters. + List> specializationParams; - List> globalGenericParams; - Dictionary globalGenericParamsMap; + /// Concrete argument values that were provided to specific global generic parameters. + /// + /// Not useful for reflection, but valuable for code generation. + /// + Dictionary> globalGenericArgs; + + /// Layouts for all tagged union types required by this program + /// + /// These are any tagged union types used by the specialization + /// arguments that have been used to specialize the program. + /// + List> taggedUnionTypeLayouts; }; StructTypeLayout* getGlobalStructLayout( @@ -908,8 +926,6 @@ struct LayoutRulesFamilyImpl virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0; }; -typedef List> GenericParamLayouts; - struct TypeLayoutContext { // The layout rules to use (e.g., we compute @@ -930,10 +946,10 @@ struct TypeLayoutContext MatrixLayoutMode matrixLayoutMode; // The concrete types (if any) to plug into the currently in-scope - // existential type slots. + // specialization params. 
// - Int existentialTypeArgCount = 0; - ExistentialTypeSlots::Arg const* existentialTypeArgs = nullptr; + Int specializationArgCount = 0; + ExpandedSpecializationArg const* specializationArgs = nullptr; LayoutRulesImpl* getRules() { return rules; } LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); } @@ -952,29 +968,29 @@ struct TypeLayoutContext return result; } - TypeLayoutContext withExistentialTypeArgs( - Int argCount, - ExistentialTypeSlots::Arg const* args) const + TypeLayoutContext withSpecializationArgs( + ExpandedSpecializationArg const* args, + Int argCount) const { TypeLayoutContext result = *this; - result.existentialTypeArgCount = argCount; - result.existentialTypeArgs = args; + result.specializationArgCount = argCount; + result.specializationArgs = args; return result; } - TypeLayoutContext withExistentialTypeSlotsOffsetBy( + TypeLayoutContext withSpecializationArgsOffsetBy( Int offset) const { TypeLayoutContext result = *this; - if( existentialTypeArgCount > offset ) + if( specializationArgCount > offset ) { - result.existentialTypeArgCount = existentialTypeArgCount - offset; - result.existentialTypeArgs = existentialTypeArgs + offset; + result.specializationArgCount = specializationArgCount - offset; + result.specializationArgs = specializationArgs + offset; } else { - result.existentialTypeArgCount = 0; - result.existentialTypeArgs = nullptr; + result.specializationArgCount = 0; + result.specializationArgs = nullptr; } return result; @@ -1061,6 +1077,8 @@ public: /// TypeLayoutResult getTypeLayoutResult(); + UniformLayoutInfo* getStructLayoutInfo() { return &m_info; } + private: /// The layout rules being used, if layout has begun. 
LayoutRulesImpl* m_rules = nullptr; @@ -1139,8 +1157,16 @@ createStructuredBufferTypeLayout( RefPtr structuredBufferType, RefPtr elementType); -int findGenericParam(List> & genericParameters, GlobalGenericParamDecl * decl); -// + /// Create a type layout for an unspecialized `globalGenericParamDecl`. +RefPtr createTypeLayoutForGlobalGenericTypeParam( + TypeLayoutContext const& context, + Type* type, + GlobalGenericParamDecl* globalGenericParamDecl); + + /// Find the concrete type (if any) that was plugged in for the global generic type parameter `decl`. +RefPtr findGlobalGenericSpecializationArg( + TypeLayoutContext const& context, + GlobalGenericParamDecl* decl); // Given an existing type layout `oldTypeLayout`, apply offsets // to any contained fields based on the resource infos in `offsetVarLayout`. diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b18e4d4d9..b030e5cf9 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -436,44 +436,31 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule( return asExternal(module); } -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createProgram( - slang::ProgramDesc const& desc, - slang::IProgram** outProgram) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL Linkage::createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + ISlangBlob** outDiagnostics) { - RefPtr program = new Program(this); - - auto itemCount = desc.itemCount; - for(SlangInt ii = 0; ii < itemCount; ++ii) - { - auto& item = desc.items[ii]; - switch(item.kind) - { - case slang::ProgramDesc::Item::Kind::Program: - { - Program* existingProgram = asInternal(item.program); - for(auto referencedModule : existingProgram->getModuleDependencies()) - { - program->addReferencedLeafModule(referencedModule); - } - - // TODO: Need to decide whether to include the entry points as well... 
- } - break; + // Attempting to create a "composite" of just one component type should + // just return the component type itself, to avoid redundant work. + // + if( componentTypeCount == 1) + return componentTypes[0]; - case slang::ProgramDesc::Item::Kind::Module: - { - Module* module = asInternal(item.module); - program->addReferencedModule(module); - } - break; + DiagnosticSink sink(getSourceManager()); - default: - return SLANG_E_INVALID_ARG; - } + List> childComponents; + for( Int cc = 0; cc < componentTypeCount; ++cc ) + { + childComponents.add(asInternal(componentTypes[cc])); } - *outProgram = asExternal(program.detach()); - return SLANG_OK; + RefPtr composite = CompositeComponentType::create( + this, + childComponents); + + sink.getBlobIfNeeded(outDiagnostics); + return asExternal(composite.detach()); } SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( @@ -781,7 +768,9 @@ RefPtr checkProperType( TypeExp typeExp, DiagnosticSink* sink); -Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) +Type* ComponentType::getTypeFromString( + String const& typeStr, + DiagnosticSink* sink) { // If we've looked up this type name before, // then we can re-use it. 
@@ -803,7 +792,7 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) for(auto module : getModuleDependencies()) scopesToTry.add(module->getModuleDecl()->scope); - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); for(auto& s : scopesToTry) { RefPtr typeExpr = linkage->parseTypeString( @@ -899,10 +888,16 @@ void FrontEndCompileRequest::parseTranslationUnit( } } -RefPtr createUnspecializedProgram( +RefPtr createUnspecializedGlobalComponentType( + FrontEndCompileRequest* compileRequest); + +RefPtr createUnspecializedGlobalAndEntryPointsComponentType( FrontEndCompileRequest* compileRequest); -RefPtr createSpecializedProgram( +RefPtr createSpecializedGlobalComponentType( + EndToEndCompileRequest* endToEndReq); + +RefPtr createSpecializedGlobalAndEntryPointsComponentType( EndToEndCompileRequest* endToEndReq); void FrontEndCompileRequest::checkAllTranslationUnits() @@ -1032,7 +1027,11 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // Look up all the entry points that are expected, // and use them to populate the `program` member. 
// - m_program = createUnspecializedProgram(this); + m_globalComponentType = createUnspecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_globalAndEntryPointsComponentType = createUnspecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1052,7 +1051,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() // for(auto targetReq : getLinkage()->targets) { - auto targetProgram = m_program->getTargetProgram(targetReq); + auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1064,7 +1063,7 @@ SlangResult FrontEndCompileRequest::executeActionsInner() BackEndCompileRequest::BackEndCompileRequest( Linkage* linkage, DiagnosticSink* sink, - Program* program) + ComponentType* program) : CompileRequestBase(linkage, sink) , m_program(program) {} @@ -1146,7 +1145,8 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // that was computed in the front-end for all subsequent // reflection queries, etc. 
// - m_specializedProgram = getUnspecializedProgram(); + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = getUnspecializedGlobalAndEntryPointsComponentType(); return SLANG_OK; } @@ -1156,7 +1156,11 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // if (passThrough == PassThroughMode::None) { - m_specializedProgram = createSpecializedProgram(this); + m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + m_specializedGlobalAndEntryPointsComponentType = createSpecializedGlobalAndEntryPointsComponentType(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1166,7 +1170,7 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // for (auto targetReq : getLinkage()->targets) { - auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + auto targetProgram = m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq); targetProgram->getOrCreateLayout(getSink()); } if (getSink()->GetErrorCount() != 0) @@ -1178,20 +1182,27 @@ SlangResult EndToEndCompileRequest::executeActionsInner() // to make sure that the logic in `generateOutput` // sees something worth processing. 
// - auto specializedProgram = new Program(getLinkage()); - m_specializedProgram = specializedProgram; + List> dummyEntryPoints; for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - RefPtr entryPoint = EntryPoint::createDummyForPassThrough( + RefPtr dummyEntryPoint = EntryPoint::createDummyForPassThrough( + getLinkage(), entryPointReq->getName(), entryPointReq->getProfile()); - specializedProgram->addEntryPoint(entryPoint, getSink()); + dummyEntryPoints.add(dummyEntryPoint); } + + RefPtr composedProgram = CompositeComponentType::create( + getLinkage(), + dummyEntryPoints); + + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = composedProgram; } // Generate output code, in whatever format was requested - getBackEndReq()->setProgram(getSpecializedProgram()); + getBackEndReq()->setProgram(getSpecializedGlobalAndEntryPointsComponentType()); generateOutput(this); if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; @@ -1324,7 +1335,7 @@ int EndToEndCompileRequest::addEntryPoint( EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPointInfo.genericArgStrings.add(typeName); + entryPointInfo.specializationArgStrings.add(typeName); Index result = entryPoints.getCount(); entryPoints.add(_Move(entryPointInfo)); @@ -1617,8 +1628,10 @@ void FilePathDependencyList::addDependency(Module* module) // Module::Module(Linkage* linkage) - : m_linkage(linkage) -{} + : ComponentType(linkage) +{ + addModuleDependency(this); +} ISlangUnknown* Module::getInterface(const Guid& guid) { @@ -1638,35 +1651,40 @@ void Module::addFilePathDependency(String const& path) m_filePathDependencyList.addDependency(path); } -// Program +void Module::setModuleDecl(ModuleDecl* moduleDecl) +{ + m_moduleDecl = moduleDecl; +} + +// ComponentType -static const Guid IID_IProgram = SLANG_UUID_IProgram; +static const Guid IID_IComponentType = SLANG_UUID_IComponentType; -Program::Program(Linkage* 
linkage) +ComponentType::ComponentType(Linkage* linkage) : m_linkage(linkage) {} -ISlangUnknown* Program::getInterface(Guid const& guid) +ISlangUnknown* ComponentType::getInterface(Guid const& guid) { if(guid == IID_ISlangUnknown - || guid == IID_IProgram) + || guid == IID_IComponentType) { - return static_cast(this); + return static_cast(this); } return nullptr; } -SLANG_NO_THROW slang::ISession* SLANG_MCALL Program::getSession() +SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession() { return m_linkage; } -SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( +SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL ComponentType::getLayout( Int targetIndex, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return nullptr; auto target = linkage->targets[targetIndex]; @@ -1678,13 +1696,13 @@ SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL Program::getLayout( return asExternal(programLayout); } -SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( SlangInt entryPointIndex, Int targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) { - auto linkage = getLinkageImpl(); + auto linkage = getLinkage(); if(targetIndex < 0 || targetIndex >= linkage->targets.getCount()) return SLANG_E_INVALID_ARG; auto target = linkage->targets[targetIndex]; @@ -1702,57 +1720,535 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Program::getEntryPointCode( return SLANG_OK; } +RefPtr ComponentType::specialize( + SpecializationArg const* inSpecializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink) +{ + List specializationArgs; + specializationArgs.addRange( + inSpecializationArgs, + specializationArgCount); + + // We next need to validate that the specialization arguments + // make sense, and also expand them to include any derived data + // 
(e.g., interface conformance witnesses) that doesn't get + // passed explicitly through the API interface. + // + RefPtr specializationInfo = _validateSpecializationArgs( + specializationArgs.getBuffer(), + specializationArgCount, + sink); + + return new SpecializedComponentType( + this, + specializationInfo, + specializationArgs, + sink); +} -void Program::addReferencedModule(Module* module) +SLANG_NO_THROW slang::IComponentType* SLANG_MCALL ComponentType::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) { - m_moduleDependencyList.addDependency(module); - m_filePathDependencyList.addDependency(module); + DiagnosticSink sink(getLinkage()->getSourceManager()); + + // First let's check if the number of arguments given matches + // the number of parameters that are present on this component type. + // + auto specializationParamCount = getSpecializationParamCount(); + if( specializationArgCount != specializationParamCount ) + { + // TODO: diagnose + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + + List expandedArgs; + for( Int aa = 0; aa < specializationArgCount; ++aa ) + { + auto apiArg = specializationArgs[aa]; + + SpecializationArg expandedArg; + switch(apiArg.kind) + { + case slang::SpecializationArg::Kind::Type: + expandedArg.val = asInternal(apiArg.type); + break; + + default: + sink.getBlobIfNeeded(outDiagnostics); + return nullptr; + } + expandedArgs.add(expandedArg); + } + + auto specializedComponentType = specialize( + expandedArgs.getBuffer(), + expandedArgs.getCount(), + &sink); + + sink.getBlobIfNeeded(outDiagnostics); + + return specializedComponentType; } -void Program::addReferencedLeafModule(Module* module) + /// Visitor used by `ComponentType::enumerateModules` +struct EnumerateModulesVisitor : ComponentTypeVisitor { - m_moduleDependencyList.addLeafDependency(module); - m_filePathDependencyList.addDependency(module); + 
EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module, m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + + +void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData) +{ + EnumerateModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + + /// Visitor used by `ComponentType::enumerateIRModules` +struct EnumerateIRModulesVisitor : ComponentTypeVisitor +{ + EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData) + : m_callback(callback) + , m_userData(userData) + {} + + ComponentType::EnumerateIRModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module->getIRModule(), m_userData); + } + + void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* 
specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + + m_callback(specialized->getIRModule(), m_userData); + } + + void visitLegacy(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(legacy, specializationInfo); + } +}; + +void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) +{ + EnumerateIRModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + +// +// CompositeComponentType +// + +RefPtr CompositeComponentType::create( + Linkage* linkage, + List> const& childComponents) +{ + // TODO: We should ideally be caching the results of + // composition on the `linkage`, so that if we get + // asked for the same composite again later we re-use + // it rather than re-create it. + // + // Similarly, we might want to do some amount of + // work to "canonicalize" the input for composition. + // E.g., if the user does: + // + // X = compose(A,B); + // Y = compose(C,D); + // Z = compose(X,Y); + // + // W = compose(A, B, C, D); + // + // Then there is no observable difference between + // Z and W, so we might prefer to have them be identical. + + // If there is only a single child, then we should + // just return that child rather than create a dummy composite. 
+ // + if( childComponents.getCount() == 1 ) + { + return childComponents[0]; + } + + return new CompositeComponentType(linkage, childComponents); +} + + +CompositeComponentType::CompositeComponentType( + Linkage* linkage, + List> const& childComponents) + : ComponentType(linkage) + , m_childComponents(childComponents) +{ + HashSet requirementsSet; + for(auto child : childComponents ) + { + child->enumerateModules([&](Module* module) + { + requirementsSet.Add(module); + }); + } + + for(auto child : childComponents ) + { + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + auto childShaderParamCount = child->getShaderParamCount(); + for(Index pp = 0; pp < childShaderParamCount; ++pp) + { + m_shaderParams.add(child->getShaderParam(pp)); + } + + auto childSpecializationParamCount = child->getSpecializationParamCount(); + for(Index pp = 0; pp < childSpecializationParamCount; ++pp) + { + m_specializationParams.add(child->getSpecializationParam(pp)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencyList.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) + { + m_filePathDependencyList.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } + } + } +} + +Index CompositeComponentType::getEntryPointCount() +{ + return m_entryPoints.getCount(); +} + +RefPtr CompositeComponentType::getEntryPoint(Index index) +{ + return m_entryPoints[index]; +} + +Index CompositeComponentType::getShaderParamCount() +{ + return m_shaderParams.getCount(); } -void Program::addEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) 
+GlobalShaderParamInfo CompositeComponentType::getShaderParam(Index index) { - List> entryPoints; - entryPoints.add(entryPoint); + return m_shaderParams[index]; +} + +Index CompositeComponentType::getSpecializationParamCount() +{ + return m_specializationParams.getCount(); +} + +SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) +{ + return m_specializationParams[index]; +} + +Index CompositeComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr CompositeComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +List const& CompositeComponentType::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} - RefPtr entryPointGroup = EntryPointGroup::create(getLinkageImpl(), entryPoints, sink); +List const& CompositeComponentType::getFilePathDependencies() +{ + return m_filePathDependencyList.getFilePathList(); +} - addEntryPointGroup(entryPointGroup); +void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + visitor->visitComposite(this, as(specializationInfo)); } -void Program::addEntryPointGroup(EntryPointGroup* entryPointGroup) + +RefPtr CompositeComponentType::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) { - m_entryPointGroups.add(entryPointGroup); + SLANG_UNUSED(argCount); + + RefPtr specializationInfo = new CompositeSpecializationInfo(); - for(auto entryPoint : entryPointGroup->getEntryPoints()) + Index offset = 0; + for(auto child : m_childComponents) { - m_entryPoints.add(entryPoint); - for(auto module : entryPoint->getModuleDependencies()) + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, + sink); + + specializationInfo->childInfos.add(childInfo); + + offset += 
childParamCount; + } + return specializationInfo; +} + +// +// SpecializedComponentType +// + +SpecializedComponentType::SpecializedComponentType( + ComponentType* base, + ComponentType::SpecializationInfo* specializationInfo, + List const& specializationArgs, + DiagnosticSink* sink) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_specializationInfo(specializationInfo) + , m_specializationArgs(specializationArgs) +{ + m_irModule = generateIRForSpecializedComponentType(this, sink); + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so that we don't need to have front-end + // code worrying about this stuff.
+ // + for(auto arg : specializationArgs) + { + auto argType = as(arg.val); + if(!argType) + continue; + + auto taggedUnionType = as(argType); + if(!taggedUnionType) + continue; + + m_taggedUnionTypes.add(taggedUnionType); + } +} + +void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + SLANG_ASSERT(specializationInfo == nullptr); + SLANG_UNUSED(specializationInfo); + visitor->visitSpecialized(this); +} + +Index SpecializedComponentType::getRequirementCount() +{ + // TODO: A specialized component type may have *more* requirements + // than the original, because it also needs to include the module(s) + // that define the types used for specialization arguments. + + return m_base->getRequirementCount(); +} + +RefPtr SpecializedComponentType::getRequirement(Index index) +{ + return m_base->getRequirement(index); +} + +// +// LegacyProgram +// + +LegacyProgram::LegacyProgram( + Linkage* linkage, + List> const& translationUnits, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_translationUnits(translationUnits) +{ + HashSet requirementsSet; + + for(auto translationUnit : translationUnits ) + { + ComponentType* child = translationUnit->getModule(); + + auto childEntryPointCount = child->getEntryPointCount(); + for(Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + } + + for(auto module : child->getModuleDependencies()) + { + m_moduleDependencies.addDependency(module); + } + for(auto filePath : child->getFilePathDependencies()) { - addReferencedModule(module); + m_fileDependencies.addDependency(filePath); + } + + auto childRequirementCount = child->getRequirementCount(); + for(Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if(!requirementsSet.Contains(childRequirement)) + { + requirementsSet.Add(childRequirement); + m_requirements.add(childRequirement); + } } } + + _collectShaderParams(sink); 
} -RefPtr Program::getOrCreateIRModule(DiagnosticSink* sink) +void LegacyProgram::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) { - if(!m_irModule) + visitor->visitLegacy(this, as(specializationInfo)); +} + +RefPtr LegacyProgram::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(argCount); + + RefPtr info = new CompositeComponentType::CompositeSpecializationInfo(); + + Index offset = 0; + for(auto translationUnit : m_translationUnits) { - m_irModule = generateIRForProgram( - m_linkage->getSessionImpl(), - this, + ComponentType* child = translationUnit->getModule(); + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs( + args + offset, + childParamCount, sink); + + info->childInfos.add(childInfo); + + offset += childParamCount; } - return m_irModule; + return info; } +Index LegacyProgram::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr LegacyProgram::getRequirement(Index index) +{ + return m_requirements[index]; +} + +void ComponentTypeVisitor::visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = composite->getChildComponentCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto child = composite->getChildComponent(ii); + auto childSpecializationInfo = specializationInfo + ? 
specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} + +void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) +{ + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); +} + +void ComponentTypeVisitor::visitChildren(LegacyProgram* legacy, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = legacy->getTranslationUnitCount(); + for(Index ii = 0; ii < childCount; ++ii) + { + auto translationUnit = legacy->getTranslationUnit(ii); + ComponentType* child = translationUnit->getModule(); + auto childSpecializationInfo = specializationInfo + ? specializationInfo->childInfos[ii] + : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} -TargetProgram* Program::getTargetProgram(TargetRequest* target) +TargetProgram* ComponentType::getTargetProgram(TargetRequest* target) { RefPtr targetProgram; if(!m_targetPrograms.TryGetValue(target, targetProgram)) @@ -1768,12 +2264,12 @@ TargetProgram* Program::getTargetProgram(TargetRequest* target) // TargetProgram::TargetProgram( - Program* program, + ComponentType* componentType, TargetRequest* targetReq) - : m_program(program) + : m_program(componentType) , m_targetReq(targetReq) { - m_entryPointResults.setCount(program->getEntryPoints().getCount()); + m_entryPointResults.setCount(componentType->getEntryPointCount()); } // @@ -2458,10 +2954,10 @@ SLANG_API SlangResult spSetGlobalGenericArgs( if (!request) return SLANG_FAIL; auto req = convert(request); - auto& genericArgStrings = req->globalGenericArgStrings; - genericArgStrings.clear(); + auto& argStrings = req->globalSpecializationArgStrings; + argStrings.clear(); for (int i = 0; i < genericArgCount; i++) - genericArgStrings.add(genericArgs[i]); + argStrings.add(genericArgs[i]); return SLANG_OK; } @@ -2477,7 +2973,7 @@ SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam( if(!typeName) 
return SLANG_FAIL; auto req = convert(request); - auto& typeArgStrings = req->globalExistentialSlotArgStrings; + auto& typeArgStrings = req->globalSpecializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2501,7 +2997,7 @@ SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam( return SLANG_FAIL; auto& entryPointInfo = req->entryPoints[entryPointIndex]; - auto& typeArgStrings = entryPointInfo.existentialArgStrings; + auto& typeArgStrings = entryPointInfo.specializationArgStrings; if(Index(slotIndex) >= typeArgStrings.getCount()) typeArgStrings.setCount(slotIndex+1); typeArgStrings[slotIndex] = String(typeName); @@ -2569,7 +3065,7 @@ spGetDependencyFileCount( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return (int) program->getFilePathDependencies().getCount(); } @@ -2583,7 +3079,7 @@ spGetDependencyFilePath( if(!request) return 0; auto req = convert(request); auto frontEndReq = req->getFrontEndReq(); - auto program = frontEndReq->getProgram(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); return program->getFilePathDependencies()[index].begin(); } @@ -2613,7 +3109,7 @@ SLANG_API void const* spGetEntryPointCode( using namespace Slang; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // TODO: We should really accept a target index in this API Index targetIndex = 0; @@ -2668,7 +3164,7 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); Index targetCount = 
linkage->targets.getCount(); if((targetIndex < 0) || (targetIndex >= targetCount)) @@ -2715,13 +3211,13 @@ SLANG_API void const* spGetCompileRequestCode( SLANG_API SlangResult spCompileRequest_getProgram( SlangCompileRequest* request, - slang::IProgram** outProgram) + slang::IComponentType** outProgram) { if( !request ) return SLANG_ERROR_INVALID_PARAMETER; auto req = convert(request); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); - *outProgram = Slang::ComPtr(program).detach(); + *outProgram = Slang::ComPtr(program).detach(); return SLANG_OK; } @@ -2731,7 +3227,7 @@ SLANG_API SlangReflection* spGetReflection( if( !request ) return 0; auto req = convert(request); auto linkage = req->getLinkage(); - auto program = req->getSpecializedProgram(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection -- cgit v1.2.3
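
Reviewer note: the requirement handling in the `CompositeComponentType` constructor can be hard to follow in diff form, so here is a minimal standalone sketch of the same idea. It is not the Slang implementation. `Component`, `compose`, and the use of plain strings as module identities are hypothetical stand-ins chosen for illustration. The sketch assumes, as the patch appears to, that a module contained in any child satisfies a requirement on that module, and that the remaining requirements are kept once each, in first-seen order.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a component type. Modules are identified by
// name here; the real code uses Module pointers.
struct Component
{
    std::vector<std::string> modules;       // code this component contains
    std::vector<std::string> requirements;  // modules that must still be supplied
};

// Compose children into one component, dropping any requirement that is
// satisfied by a module contained in one of the children.
Component compose(std::vector<Component> const& children)
{
    assert(!children.empty());

    Component result;
    std::unordered_set<std::string> seen;

    // Pass 1: every module contained in a child counts as "already present".
    for (auto const& child : children)
    {
        for (auto const& moduleName : child.modules)
        {
            if (seen.insert(moduleName).second)
                result.modules.push_back(moduleName);
        }
    }

    // Pass 2: keep only requirements that are neither contained in a child
    // nor already recorded, preserving first-seen order.
    for (auto const& child : children)
    {
        for (auto const& requirement : child.requirements)
        {
            if (seen.insert(requirement).second)
                result.requirements.push_back(requirement);
        }
    }

    return result;
}
```

With a component `A` that requires `B` and `C`, and a component `B` that requires `C`, composing the two leaves a single outstanding requirement on `C`. This matches the commit message's description that composing `A` and `B` satisfies `A`'s requirement on `B`. The two-pass structure matters: a single pass would wrongly record a requirement on `B` if `A` were visited before `B`.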