diff options
Diffstat (limited to 'source/slang/slang-check.cpp')
| -rw-r--r-- | source/slang/slang-check.cpp | 300 |
1 files changed, 283 insertions, 17 deletions
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index 90947cf54..83b75b964 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -433,7 +433,7 @@ namespace Slang Session* getSession() { - return m_linkage->getSession(); + return m_linkage->getSessionImpl(); } public: @@ -9656,6 +9656,111 @@ namespace Slang } } +static bool shouldUseFalcorCustomSharedKeywordSemantics( + EntryPointGroup* entryPointGroup) +{ + if( !entryPointGroup->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics ) + return false; + + // As a sanity check, if we are being asked to lay out an + // empty entry-point group, then don't apply the convention. + // + if(entryPointGroup->getEntryPointCount() == 0) + return false; + + // Otherwise we will look at the first entry point in the group, + // and use that to determine whether it looks like we are compiling + // ray-tracing shaders or not. + // + switch( entryPointGroup->getEntryPoint(0)->getStage() ) + { + case Stage::AnyHit: + case Stage::Callable: + case Stage::ClosestHit: + case Stage::Intersection: + case Stage::Miss: + case Stage::RayGeneration: + return true; + + default: + return false; + } +} + + +static bool shouldUseFalcorCustomSharedKeywordSemantics( + Program* program) +{ + if( !program->getLinkageImpl()->m_useFalcorCustomSharedKeywordSemantics ) + return false; + + // As a sanity check, if we are being asked to lay out a program + // with *no* entry points, then we don't apply the convention. + // + if(program->getEntryPointGroupCount() == 0) + return false; + + // Otherwise we let the first entry-point group determine if we should + // apply the policy for the entire program. + // + // Note: this could lead to confusing results if a `Program` mixes + // entry point groups for RT and non-RT pipelines, but that isn't + // a scenario we expect to come up and this whole routine is handling + // "do what I mean" semantics for legacy behavior. + // + return shouldUseFalcorCustomSharedKeywordSemantics(program->getEntryPointGroup(0)); +} + + +void EntryPointGroup::_collectShaderParams(DiagnosticSink* sink) +{ + // If and only if we are in the special mode for Falcor support + // where non-`shared` global shader parameters are actually + // supposed to go into the "local root signature," then we + // will consider such parameters as if they were entry-point + // parameters, attached to the group. + // + if( shouldUseFalcorCustomSharedKeywordSemantics(this) ) + { + for( auto module : getModuleDependencies() ) + { + auto moduleDecl = module->getModuleDecl(); + for( auto globalVar : moduleDecl->getMembersOfType<VarDecl>() ) + { + // Don't consider globals that aren't shader parameters. + // + if(!isGlobalShaderParameter(globalVar)) + continue; + + // Don't consider global shader paramters that were marked + // `shared`, since that is how global-root-signature parameters + // are being specified. + // + if( globalVar->HasModifier<HLSLEffectSharedModifier>() ) + continue; + + auto paramDeclRef = makeDeclRef(globalVar.Ptr()); + + ShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = paramDeclRef; + + ExistentialTypeSlots slots; + _collectExistentialSlotsForShaderParam( + shaderParamInfo, + slots, + paramDeclRef); + + if( slots.paramTypes.getCount() != 0 ) + { + sink->diagnose(globalVar, Diagnostics::typeParametersNotAllowedOnEntryPointGlobal, globalVar); + } + + m_shaderParams.add(shaderParamInfo); + } + } + } +} + // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( @@ -10334,9 +10439,6 @@ static bool doesParameterMatch( return true; } - - - /// Enumerate the existential-type parameters of a `Program`. /// /// Any parameters found will be added to the list of existential slots on `this`. @@ -10364,9 +10466,34 @@ static bool doesParameterMatch( auto moduleDecl = module->getModuleDecl(); for( auto globalVar : moduleDecl->getMembersOfType<VarDecl>() ) { + // We do not want to consider global variable declarations + // that don't represents shader parameters. This includes + // things like `static` globals and `groupshared` variables. + // if(!isGlobalShaderParameter(globalVar)) continue; + // HACK: In order to support existing policy in the Falcor + // application, we support a custom mode where only + // global variables marked as `shared` should be considered + // as global shader parameters (that go in the "global + // root signature" for DXR). + // + // TODO: Eliminate this special case once all of the client + // application code has been ported to use more general-purpose + // Slang mechanisms (e.g., entry-point `uniform` parameters). + // + if( shouldUseFalcorCustomSharedKeywordSemantics(this) ) + { + if( !globalVar->HasModifier<HLSLEffectSharedModifier>() ) + { + // Skip a non-`shared` global for purposes of enumerating + // shader parameters. + // + continue; + } + } + // This declaration may represent the same logical parameter // as a declaration that came from a different translation unit. // If that is the case, we want to re-use the same `ShaderParamInfo` @@ -10474,7 +10601,13 @@ static bool doesParameterMatch( entryPointReq); if( entryPoint ) { - program->addEntryPoint(entryPoint); + // TODO: We need to implement an explicit policy + // for what should happen if the user specified + // entry points via the command-line (or API), + // but didn't specify any groups (since the current + // compilation API doesn't allow for grouping). + // + program->addEntryPoint(entryPoint, sink); entryPointReq->getTranslationUnit()->entryPoints.add(entryPoint); } } @@ -10540,7 +10673,17 @@ static bool doesParameterMatch( validateEntryPoint(entryPoint, sink); - program->addEntryPoint(entryPoint); + // Note: in the case that the user didn't explicitly + // specify entry points and we are instead compiling + // a shader "library," then we do not want to automatically + // combine the entry points into groups in the generated + // `Program`, since that would be slightly too magical. + // + // Instead, each entry point will end up in a singleton + // group, so that its entry-point parameters lay out + // independent of the others. + // + program->addEntryPoint(entryPoint, sink); translationUnit->entryPoints.add(entryPoint); } } @@ -10636,7 +10779,7 @@ static bool doesParameterMatch( // generic application like `F<A,B,C>` if it were // encountered in the source code. - auto session = linkage->getSession(); + auto session = linkage->getSessionImpl(); auto genericDeclRef = makeDeclRef(genericDecl); // The first pieces is a `VarExpr` that refers to `genericDecl`. @@ -10760,7 +10903,7 @@ static bool doesParameterMatch( List<RefPtr<Expr>> const& args, DiagnosticSink* sink) { - Slang::_specializeExistentialTypeParams(getLinkage(), m_globalExistentialSlots, args, sink); + Slang::_specializeExistentialTypeParams(getLinkageImpl(), m_globalExistentialSlots, args, sink); } Type* Linkage::specializeType( @@ -10799,14 +10942,34 @@ static bool doesParameterMatch( return specializedType; } - /// Specialize a program to global generic arguments - RefPtr<Program> createSpecializedProgram( + // Shared implementation logic for the `_createSpecializedProgram*` entry points. + static RefPtr<Program> _createSpecializedProgramImpl( Linkage* linkage, Program* unspecializedProgram, List<RefPtr<Expr>> const& globalGenericArgs, List<RefPtr<Expr>> const& globalExistentialArgs, DiagnosticSink* sink) { + // TODO: If there are no specialization arguments, + // then the the result of specialization should + // be the same as the input... *but* we are promising + // to return a program without any entry points, + // and that screws things up here. + // + // For now we just carefully avod this early-exit + // if we have any entry points. + // + // Eventually we should try to revise the model so + // that specialization of a program that includes + // entry points can make some kind of sense. + // + if( globalGenericArgs.getCount() == 0 + && globalExistentialArgs.getCount() == 0 + && unspecializedProgram->getEntryPointCount() == 0) + { + return unspecializedProgram; + } + // The given `unspecializedProgram` should be one that // was checked through the front-end, so that now we // only need to check if the given arguments can satisfy @@ -10967,6 +11130,37 @@ static bool doesParameterMatch( specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); + return specializedProgram; + } + + /// Create a specialized copy of `unspecializedProgram`. + /// + /// The specialized program will include the entry points + /// from the original program (whether thsoe entry points + /// are specialized or not). + /// + static RefPtr<Program> _createSpecializedProgram( + Linkage* linkage, + Program* unspecializedProgram, + List<RefPtr<Expr>> const& globalGenericArgs, + List<RefPtr<Expr>> const& globalExistentialArgs, + DiagnosticSink* sink) + { + auto specializedProgram = _createSpecializedProgramImpl( + linkage, + unspecializedProgram, + globalGenericArgs, + globalExistentialArgs, + sink); + + // We need to ensure that the specialized program has whatever + // entry points (and groups) the unspecialized one had. + // + for( auto entryPointGroup : unspecializedProgram->getEntryPointGroups() ) + { + specializedProgram->addEntryPointGroup(entryPointGroup); + } + // Now deal with the shader parameters and existential arguments // // Note: We should in theory be able to just copy over the shader @@ -10981,6 +11175,61 @@ static bool doesParameterMatch( return specializedProgram; } + SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::specializeProgram( + slang::IProgram* inUnspecializedProgram, + SlangInt specializationArgCount, + slang::SpecializationArg const* specializationArgs, + slang::IProgram** outSpecializedProgram, + ISlangBlob** outDiagnostics) + { + auto unspecializedProgram = asInternal(inUnspecializedProgram); + + if( specializationArgCount == 0 ) + { + *outSpecializedProgram = ComPtr<slang::IProgram>(asExternal(unspecializedProgram)).detach(); + return SLANG_OK; + } + + List<RefPtr<Expr>> globalGenericArgs; + + List<RefPtr<Expr>> globalExistentialArgs; + for( Int ii = 0; ii < specializationArgCount; ++ii ) + { + auto& specializationArg = specializationArgs[ii]; + switch( specializationArg.kind ) + { + case slang::SpecializationArg::Kind::Type: + { + auto typeArg = asInternal(specializationArg.type); + RefPtr<SharedTypeExpr> argExpr = new SharedTypeExpr(); + argExpr->base = TypeExp(typeArg); + argExpr->type = QualType(getTypeType(typeArg)); + + globalExistentialArgs.add(argExpr); + } + break; + + default: + return SLANG_E_INVALID_ARG; + } + } + + DiagnosticSink sink(getSourceManager()); + auto specializedProgram = _createSpecializedProgram( + this, + unspecializedProgram, + globalGenericArgs, + globalExistentialArgs, + &sink); + sink.getBlobIfNeeded(outDiagnostics); + + if(!specializedProgram) + return SLANG_FAIL; + + *outSpecializedProgram = ComPtr<slang::IProgram>(asExternal(specializedProgram)).detach(); + return SLANG_OK; + } + /// Specialize an entry point that was checked by the front-end, based on generic arguments. /// /// If the end-to-end compile request included generic argument strings @@ -11036,6 +11285,7 @@ static bool doesParameterMatch( // global or entry-point generic parameters. // auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + auto linkage = endToEndReq->getLinkage(); // First, let's parse the generic argument strings that were // provided via the API, so taht we can match them @@ -11058,12 +11308,13 @@ static bool doesParameterMatch( // applying the global generic arguments (if any) to the // unspecialized program. // - auto specializedProgram = createSpecializedProgram( + auto sink = endToEndReq->getSink(); + auto specializedProgram = _createSpecializedProgramImpl( endToEndReq->getLinkage(), unspecializedProgram, globalGenericArgs, globalExistentialArgs, - endToEndReq->getSink()); + sink); // If anything went wrong with the global generic // arguments, then bail out now. @@ -11090,15 +11341,30 @@ static bool doesParameterMatch( endToEndReq->entryPoints.setCount(entryPointCount); } - for( Index ii = 0; ii < entryPointCount; ++ii ) + Index entryPointCounter = 0; + + for( auto unspecializedEntryPointGroup : unspecializedProgram->getEntryPointGroups() ) { - auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii); - auto& entryPointInfo = endToEndReq->entryPoints[ii]; + List<RefPtr<EntryPoint>> specializedEntryPoints; + for( auto unspecializedEntryPoint : unspecializedEntryPointGroup->getEntryPoints() ) + { + Index entryPointIndex = entryPointCounter++; + auto& entryPointInfo = endToEndReq->entryPoints[entryPointIndex]; + + auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + specializedEntryPoints.add(specializedEntryPoint); + } - auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); - specializedProgram->addEntryPoint(specializedEntryPoint); + RefPtr<EntryPointGroup> specializedEntryPointGroup = EntryPointGroup::create(linkage, specializedEntryPoints, endToEndReq->getSink()); + specializedProgram->addEntryPointGroup(specializedEntryPointGroup); } + // Finalize the information for the specialized program, + // now that we have computed its entry point list, etc. + // + specializedProgram->_collectShaderParams(sink); + specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); + return specializedProgram; } |
