| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2019-02-05 16:47:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-02-05 16:47:25 -0800 |
| commit | 60cc9f24c4bec54561bea873ee943aa3d0973dc2 (patch) | |
| tree | 16e404be181eba50d7d770f373d07cb17d9ac64d | |
| parent | 314795b5d8ff5845624f93e152face325659dd0c (diff) | |
Allow entry points to have explicit generic parameters (#826)
* Allow entry points to have explicit generic parameters
Prior to this change, the Slang implementation required users to use global `type_param` declarations in order to specialize a full shader. For example:
```hlsl
type_param L : ILight;
ParameterBlock<L> gLight;
[shader("fragment")]
float4 fs(...)
{ ... gLight.doSomething() ... }
```
With this change we can rewrite code like the above using explicit generics, plus the ability to have `uniform` entry-point parameters:
```hlsl
[shader("fragment")]
float4 fs<L : ILight>(
uniform ParameterBlock<L> light,
...)
{ ... light.doSomething() ... }
```
Having this support in place should make it possible for us to eliminate global generic type parameters and the complications they cause (both at a conceptual and implementation level).
The most central and visible piece of the change is that `EntryPointRequest` now holds a `DeclRef<FuncDecl>` instead of just a `RefPtr<FuncDecl>`, which allows it to refer to a specialization of a generic function.
Various places in the code that refer to the `EntryPointRequest::decl` member now use a `getFuncDecl()` or `getFuncDeclRef()` method as appropriate (see `compiler.h`).
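As a rough illustration of why a decl-ref is needed here, the sketch below uses heavily simplified, hypothetical stand-ins for the compiler's types (the real `DeclRef` carries substitutions rather than a list of strings, and `describe` is an invented helper). It only shows the split between "the declaration" and "the declaration plus its generic arguments":

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for the AST node of a function.
struct FuncDecl
{
    std::string name;
};

// A declaration plus the generic arguments it is specialized with.
// An empty argument list means "the declaration as written".
// (Assumption: the real compiler stores substitutions, not strings.)
template<typename T>
struct DeclRef
{
    T* decl = nullptr;
    std::vector<std::string> genericArgs;

    T* getDecl() const { return decl; }
};

struct EntryPointRequest
{
    DeclRef<FuncDecl> funcDeclRef;

    // Full reference, including any specialization arguments.
    DeclRef<FuncDecl> getFuncDeclRef() const { return funcDeclRef; }

    // Just the underlying declaration, for callers that only need
    // things like modifiers or a source location.
    FuncDecl* getFuncDecl() const { return getFuncDeclRef().getDecl(); }
};

// Invented helper: render a decl-ref for display, e.g. `fs<PointLight>`.
std::string describe(const DeclRef<FuncDecl>& ref)
{
    assert(ref.decl);
    std::string text = ref.decl->name;
    if (!ref.genericArgs.empty())
    {
        text += "<";
        for (std::size_t i = 0; i < ref.genericArgs.size(); ++i)
        {
            if (i != 0) text += ",";
            text += ref.genericArgs[i];
        }
        text += ">";
    }
    return text;
}
```

Code that only cares about the declaration keeps calling `getFuncDecl()`, while code that must distinguish one specialization from another (name mangling, for example) goes through `getFuncDeclRef()`.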
In order to fill in the new data, the `findAndValidateEntryPoint` function has been greatly overhauled.
The changes to its operation include:
* The by-name lookup step for the entry point function has been adapted to accept either a function or a generic function.
* The generic argument strings provided by API or command line are no longer parsed all the way to `Type`s, but instead just to `Expr`s in the first pass.
* There are now two cases for checking the global generic arguments against their matching parameters. The first case is the new one, where we plug the generic argument `Expr`s into the explicit generic parameters of an entry point (that case re-uses existing semantic checking logic). The second case is the pre-existing code for dealing with global generic type arguments.
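The lookup and the two-way split described in the list above can be sketched roughly as follows. This is a toy model, not the code in `check.cpp`: the `Decl` hierarchy, `findEntryPointFunc`, and `classifyGenericArgs` are all hypothetical names, and the real function also parses and checks the argument expressions.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified AST nodes.
struct Decl
{
    std::string name;
    Decl* parent = nullptr;
    virtual ~Decl() = default;
};
struct FuncDecl : Decl {};
struct GenericDecl : Decl
{
    Decl* inner = nullptr;   // the declaration wrapped by the generic
};

// Find the first function with the given name, automatically
// unwrapping an outer generic if there is one.
FuncDecl* findEntryPointFunc(const std::vector<Decl*>& decls, const std::string& name)
{
    for (Decl* decl : decls)
    {
        if (decl->name != name)
            continue;
        if (auto generic = dynamic_cast<GenericDecl*>(decl))
            decl = generic->inner;
        if (auto func = dynamic_cast<FuncDecl*>(decl))
            return func;
    }
    return nullptr;
}

// Which set of parameters the user-supplied generic arguments apply to.
enum class GenericArgTarget { EntryPointParams, GlobalTypeParams };

// The two cases are treated as mutually exclusive: a generic entry
// point consumes the arguments itself, otherwise they are matched
// against global type parameters.
GenericArgTarget classifyGenericArgs(const FuncDecl* func)
{
    assert(func);
    return dynamic_cast<const GenericDecl*>(func->parent)
        ? GenericArgTarget::EntryPointParams
        : GenericArgTarget::GlobalTypeParams;
}
```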
The `lower-to-ir.cpp` logic for handling entry points then had to be extended. Making it deal with a full `DeclRef` instead of just a `Decl` was the easy part (just call `emitDeclRef` instead of `ensureDecl`).
The more interesting bits were:
* We need to carefully add the `IREntryPointDecoration` to the nested function and not the generic in the case where we have a generic entry point. There is a handy `getResolvedInstForDecorations` that can extract the return value for an IR generic so that we can decorate the right thing.
* We need to make sure that in the case where we emit a `specialize` instruction (which normally wouldn't get a linkage decoration), we attach an `[export(...)]` decoration to it with the mangled name of the decl-ref, so that it can be found during the linking step.
The IR linking step is then slightly more complicated because the mangled entry point name could either refer directly to an `IRFunc` or to a `specialize` instruction for a generic entry point. The logic was refactored to first clone the entry point symbol without concern for which case it is (the old code was specific to functions), and then *if* the result is a `specialize` instruction, we attempt to run generic specialization on-demand.
That on-demand specialization is a bit of a kludge, but it deals with the fact that all the downstream passes only expect to see an `IRFunc`. A future cleanup might try to split out that specialization step into its own pass, which ends up being a limited form of the specialization pass.
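The "clone first, then specialize if needed" flow can be sketched as below. Everything here is a hypothetical toy model: the IR node types are reduced to a few fields, and `specializeOnDemand` and `resolveEntryPoint` are invented names that just build a new function name rather than instantiating a generic body.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified IR nodes.
struct IRInst { virtual ~IRInst() = default; };
struct IRFunc : IRInst { std::string name; };
struct IRGeneric : IRInst { std::string baseName; };
struct IRSpecialize : IRInst
{
    IRGeneric* generic = nullptr;
    std::vector<std::string> args;
};

// Owns every instruction created during linking.
struct IRModule
{
    std::vector<std::unique_ptr<IRInst>> insts;

    template<typename T>
    T* create()
    {
        auto owned = std::make_unique<T>();
        T* raw = owned.get();
        insts.push_back(std::move(owned));
        return raw;
    }
};

// Stand-in for specialization: produce a concrete function for one
// `specialize` instruction, without running a whole-module pass.
IRFunc* specializeOnDemand(IRModule& module, const IRSpecialize& spec)
{
    assert(spec.generic);
    IRFunc* func = module.create<IRFunc>();
    func->name = spec.generic->baseName + "<";
    for (std::size_t i = 0; i < spec.args.size(); ++i)
    {
        if (i != 0) func->name += ",";
        func->name += spec.args[i];
    }
    func->name += ">";
    return func;
}

// The cloned entry-point symbol may be a plain function or a
// `specialize` instruction; downstream passes only want an `IRFunc`.
IRFunc* resolveEntryPoint(IRModule& module, IRInst* clonedSymbol)
{
    if (auto func = dynamic_cast<IRFunc*>(clonedSymbol))
        return func;
    if (auto spec = dynamic_cast<IRSpecialize*>(clonedSymbol))
        return specializeOnDemand(module, *spec);
    return nullptr;
}
```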
Since I was already having to touch a lot of the code around IR linking, I went ahead and refactored the signature of the operations. I eliminated the need for the caller to create, pass in, and then destroy an `IRSpecializationState` (really an IR *linking* state), and replaced it with a structure local to the pass (that data structure was a remnant of an older approach in the compiler), and then also renamed the main operation to `linkIR` to reflect what it is doing in our conceptual flow.
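The shape of that signature change, reduced to a toy example: instead of the caller creating a state object, passing it in, and destroying it afterwards, the pass keeps its state local and returns a small result struct. The names `LinkedIR` fields aside, everything below (`linkSketch`, `LinkState`, the string-based "module") is a hypothetical stand-in and not the real `linkIR` signature.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an IR module: just a list of symbol names.
struct IRModule
{
    std::vector<std::string> symbols;
};

// Everything the caller needs after linking, returned by value.
struct LinkedIR
{
    std::unique_ptr<IRModule> module;
    std::string entryPoint;   // left empty if the entry point was not found
};

LinkedIR linkSketch(
    const std::vector<std::string>& availableSymbols,
    const std::string& entryPointName)
{
    assert(!entryPointName.empty());

    // State that is only needed while linking lives on the stack here,
    // so the caller never has to create or destroy it.
    struct LinkState
    {
        std::unique_ptr<IRModule> module = std::make_unique<IRModule>();
    } state;

    LinkedIR result;
    for (const auto& symbol : availableSymbols)
    {
        // Toy rule: only the entry point itself is copied into the
        // fresh module; real linking also pulls in its dependencies.
        if (symbol == entryPointName)
        {
            state.module->symbols.push_back(symbol);
            result.entryPoint = symbol;
        }
    }
    result.module = std::move(state.module);
    return result;
}
```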
Smaller changes made along the way include:
* Refactored `visitGenericAppExpr` to create a subroutine `checkGenericAppWithCheckedArgs` so that it can be used by the entry-point validation logic described above.
* Refactored the declarations around the IR passes in `emitEntryPoint()` (`emit.cpp`), to show that things are more self-contained than they used to be (e.g., that the `TypeLegalizationContext` is now only needed by one pass).
* Refactored the generic specialization code so that there is a stand-alone free function that can perform specialization on a `specialize` instruction without all the other context being required. This is only to support the limited specialization that needs to be done as part of linking.
* Updated the `global-type-param.slang` test to actually test entry-point generic parameters. In a later pass we can/should rework all the tests/examples for global type parameters over to use explicit entry-point generic parameters (at which point we should rename the tests as well). For now I am leaving things with just one test case, with the expectation that bugs will be found and ironed out as we expand to more tests.
* fixup
* Fixup: don't leave entry-point decorations on stuff we don't want to keep
The IR `[entryPoint]` decoration is effectively a "keep this alive" decoration, which means that attaching it to something we don't intend to keep around can lead to Bad Things.
The approach to generic entry points was attaching `[entryPoint]` to the underlying `IRFunc` because that seemed to make sense, but that meant that the `specialize` instruction at global scope could instantiate that generic and then keep it alive, even if the resulting function wouldn't be valid according to the language rules.
As a quick fix, I'm attaching `[entryPoint]` to the `specialize` instruction instead in such cases, and then re-attaching it to the result of explicit specialization during linking.
* Port most of the remaining tests and rename global type parameters
This change ports as many as possible of the existing tests for global type parameters over to use entry-point generic parameters instead. For the most part this is a mechanical change.
A few test cases remain using global generic parameters, as does the `model-viewer` example application.
The reason for this is that the shaders have one or both of the following features:
* A vertex and fragment shader that can/should agree on their parameters
* A type declaration (e.g., a `struct`) that is dependent on one of the generic type parameters
In these cases, it would really only make sense to switch to explicit parameters once we support shader entry points nested inside of a `struct` type, so that we can use an outer generic `struct` as a mechanism to scope the entry points and other type-dependent declarations.
Since global-scope type parameters need to persist for at least a bit longer, I went ahead and renamed all the use sites over to use `type_param` for consistency.
25 files changed, 668 insertions, 436 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 921ef61e6..c3f10fb4c 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -8326,12 +8326,8 @@ namespace Slang } } - RefPtr<Expr> visitGenericAppExpr(GenericAppExpr * genericAppExpr) + RefPtr<Expr> visitGenericAppExpr(GenericAppExpr* genericAppExpr) { - // We are applying a generic to arguments, but there might be multiple generic - // declarations with the same name, so this becomes a specialized case of - // overload resolution. - // Start by checking the base expression and arguments. auto& baseExpr = genericAppExpr->FunctionExpr; baseExpr = CheckTerm(baseExpr); @@ -8341,6 +8337,19 @@ namespace Slang arg = CheckTerm(arg); } + return checkGenericAppWithCheckedArgs(genericAppExpr); + } + + /// Check a generic application where the operands have already been checked. + RefPtr<Expr> checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr) + { + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. + + auto& baseExpr = genericAppExpr->FunctionExpr; + auto& args = genericAppExpr->Arguments; + // If there was an error in the base expression, or in any of // the arguments, then just bail. 
if (IsErrorExpr(baseExpr)) @@ -9228,7 +9237,7 @@ namespace Slang // if( entryPoint->getStage() == Stage::Unknown ) { - sink->diagnose(entryPoint->decl, Diagnostics::entryPointHasNoStage, entryPoint->name); + sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::entryPointHasNoStage, entryPoint->name); } if (entryPoint->getStage() == Stage::Hull) @@ -9236,7 +9245,7 @@ namespace Slang auto translationUnit = entryPoint->getTranslationUnit(); auto translationUnitSyntax = translationUnit->SyntaxNode; - auto attr = entryPoint->decl->FindModifier<PatchConstantFuncAttribute>(); + auto attr = entryPoint->getFuncDecl()->FindModifier<PatchConstantFuncAttribute>(); if (attr) { @@ -9278,6 +9287,11 @@ namespace Slang { // The first step in validating the entry point is to find // the (unique) function declaration that matches its name. + // + // TODO: We will eventually need to update this logic + // to work by parsing the provided `entryPoint->name` string + // as an expression, so that we can handle more complex + // names like `foo<int>` or `SomeType.vs`. auto translationUnit = entryPoint->getTranslationUnit(); auto sink = &entryPoint->compileRequest->mSink; @@ -9305,11 +9319,20 @@ namespace Slang // We'll walk the linked list of declarations with the same name, // to see what we find. Along the way we'll keep track of the // first function declaration we find, if any: + // FuncDecl* entryPointFuncDecl = nullptr; for(auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName) { + // We want to support the case where the declaration is + // a generic function, so we will automatically + // unwrap any outer `GenericDecl` we find here. + // + auto decl = ee; + if(auto genericDecl = as<GenericDecl>(decl)) + decl = genericDecl->inner; + // Is this declaration a function? 
- if (auto funcDecl = as<FuncDecl>(ee)) + if (auto funcDecl = as<FuncDecl>(decl)) { // Skip non-primary declarations, so that // we don't give an error when an entry @@ -9381,34 +9404,65 @@ namespace Slang // Phew, we have at least found a suitable decl. // Let's record that in the entry-point request so // that we don't have to re-do this effort again later. - entryPoint->decl = entryPointFuncDecl; + // + // Note: we may replace the decl-ref we store at this point + // later in this function, when we (potentially) specialize + // a generic entry point to generic arguments provided + // via the API. + // + entryPoint->funcDeclRef = makeDeclRef(entryPointFuncDecl); - // Lookup generic parameter types in global scope + // If the user specified generic arguments for the entry point, + // then we will want to parse those arguments as expressions + // in a scope that includes the tanslation unit that holds + // the entry point, as well as any other modules that got + // transitively loaded via `import`. + // + // TODO: This would be better handled by giving the user + // more explicit ways to parse/build types at the API level, + // rather than keeping things string-based this far along. + // + // TODO: Building a list of `scopesToTry` here shouldn't + // be required, since the `Scope` type itself has the ability + // for form chains for lookup purposes (e.g., the way that + // `import` is handled by modifying a scope). + // List<RefPtr<Scope>> scopesToTry; scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); for (auto & module : entryPoint->compileRequest->loadedModulesList) scopesToTry.Add(module->moduleDecl->scope); - List<RefPtr<Type>> globalGenericArgs; + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. 
+ // + SemanticsVisitor semantics( + &entryPoint->compileRequest->mSink, + entryPoint->compileRequest, + entryPoint->getTranslationUnit()); + + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + List<RefPtr<Expr>> genericArgs; for (auto name : entryPoint->genericArgStrings) { - // parse type name - RefPtr<Type> type; + RefPtr<Expr> argExpr; for (auto & s : scopesToTry) { - RefPtr<Expr> typeExpr = entryPoint->compileRequest->parseTypeString(entryPoint->getTranslationUnit(), - name, s); - type = checkProperType(translationUnit, TypeExp(typeExpr)); - if (type) + argExpr = entryPoint->compileRequest->parseTypeString( + entryPoint->getTranslationUnit(), + name, + s); + argExpr = semantics.CheckTerm(argExpr); + if( argExpr ) { break; } } - if (!type) - { - sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name); - return; - } // The following is a bit of a hack. // @@ -9424,165 +9478,257 @@ namespace Slang // as a (top-level) argument for a generic type parameter, so that we // can check for them here and cache them on the entry point request. // - if( auto taggedUnionType = as<TaggedUnionType>(type) ) + if( auto typeType = as<TypeType>(argExpr->type) ) { - entryPoint->taggedUnionTypes.Add(taggedUnionType); + auto type = typeType->type; + if( auto taggedUnionType = as<TaggedUnionType>(type) ) + { + entryPoint->taggedUnionTypes.Add(taggedUnionType); + } } - globalGenericArgs.Add(type); + genericArgs.Add(argExpr); } - // validate global type arguments only when we are generating code - if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + // There are two cases we care about here, and we are going to treat them + // as mutually exclusive for simplicity. 
+ // + // The first case is when the entry point function is itself generic, + // in which case we will assume that `genericArgs` lines up one-to-one + // with the explicit generic parameters of the entry point. + // + if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) ) { - // check that user-provided type arguments conforms to the generic type - // parameter declaration of this translation unit - - // collect global generic parameters from all imported modules - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - // add current translation unit first - { - auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - // add imported modules - for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) - { - auto moduleDecl = loadedModule->moduleDecl; - auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } + // We will construct a suitable `GenericAppExpr` to represent + // the user-specified `genericDecl` being applied to the + // supplied `genericArgs`, and then use the existing + // semantic checking logic that would apply to an explicit + // generic application like `F<A,B,C>` if it were + // encountered in the source code. - if (globalGenericParams.Count() != globalGenericArgs.Count()) - { - sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, - globalGenericParams.Count(), - globalGenericArgs.Count()); - return; - } + auto session = entryPoint->compileRequest->mSession; + auto genericDeclRef = makeDeclRef(genericDecl); - // We have an appropriate number of arguments for the global generic parameters, - // and now we need to check that the arguments conform to the declared constraints. + // The first pieces is a `VarExpr` that refers to `genericDecl`. 
// - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. + // TODO: This would not be needed if we instead parsed + // the supplied entry-point name into an expression + // earlier in this function. // - RefPtr<Substitutions> globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: + RefPtr<VarExpr> genericExpr = new VarExpr(); + genericExpr->declRef = genericDeclRef; + genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef); + + // Next we construct the actual `GenericAppExpr` // - // type_param A; - // type_param B : ISidekick<A>; + RefPtr<GenericAppExpr> genericAppExpr = new GenericAppExpr(); + genericAppExpr->FunctionExpr = genericExpr; + genericAppExpr->Arguments = genericArgs; + + // We use the semantics visitor to perform the + // actual checking logic (this might report + // errors) // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor<B>`). + auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr); + + // Now we need to extract an appropriate decl-ref for the entry + // point from the `checkedExpr`. // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. 
+ if( auto declRefExpr = checkedExpr.as<DeclRefExpr>() ) + { + // TODO: We should eventually check for the case + // where we have a `MemberExpr` or another case of + // `DeclRefExpr` that cannot be summarized as just + // its decl-ref. + // + // The basic `VarExpr` and `StaticMemberExpr` cases + // should be allow-able. + + entryPoint->funcDeclRef = declRefExpr->declRef.as<FuncDecl>(); + } + else if( semantics.IsErrorExpr(checkedExpr) ) + { + // Any semantic error that occured should have been + // reported already. + } + else + { + // The result of specializing a reference to a generic + // function should always be a `DeclRefExpr` + // + SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); + } + } + else + { + // The other case is when the entry point function is *not* itself + // generic, so we assume that any generic arguments must have been intended + // to match up with global generic parameters instead. // - // We will punt on this for now, and just check each constraint in isolation. + // We will only validate global generic type arguments when we are going + // to generate code, since in a no-codegen pass we will typically *not* + // have arguments to associate with the parameters. // - UInt argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) + if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) { - // Get the argument that matches this parameter. 
- UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < globalGenericArgs.Count()); - auto globalGenericArg = globalGenericArgs[argIndex]; + // check that user-provioded type arguments conforms to the generic type + // parameter declaration of this translation unit - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = as<DeclRefType>(globalGenericArg) ) + // collect global generic parameters from all imported modules + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + // add current translation unit first + { + auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + // add imported modules + for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) + { + auto moduleDecl = loadedModule->moduleDecl; + auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + + if (globalGenericParams.Count() != genericArgs.Count()) { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) + sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::mismatchEntryPointTypeArgument, + globalGenericParams.Count(), + genericArgs.Count()); + return; + } + + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. + // + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. 
+ // + RefPtr<Substitutions> globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; + // + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick<A>; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor<B>`). + // + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + UInt argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) + { + // Get the argument that matches this parameter. + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < genericArgs.Count()); + auto globalGenericArg = checkProperType(translationUnit, TypeExp(genericArgs[argIndex])); + if (!globalGenericArg) { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. 
- sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - else + sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, entryPoint->genericArgStrings[argIndex]); + return; + } + + // As a quick sanity check, see if the argument that is being supplied for a parameter + // is just the parameter itself, because this should always be an error: + // + if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) + { + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; + if(argGenericParamDeclRef.getDecl() == globalGenericParam) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToItself, + globalGenericParam->getName()); + sink->diagnose(entryPointFuncDecl, + Diagnostics::noteWhenCompilingEntryPoint, + entryPointFuncDecl->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. 
+ sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + globalGenericParam->getName(), + argGenericParamDeclRef.GetName()); + sink->diagnose(entryPointFuncDecl, + Diagnostics::noteWhenCompilingEntryPoint, + entryPointFuncDecl->getName()); + continue; + } } } - } + // Create a substitution for this parameter/argument. + RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; - // Create a substitution for this parameter/argument. - RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) + { + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + // Use our semantic-checking logic to search for a witness to the required conformance + SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. 
+ sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, + interfaceType); + } - // Use our semantic-checking logic to search for a witness to the required conformance - SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.Add(constraintArg); } - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.Add(constraintArg); - } + // Add the substitution for this parameter to the global substitution + // set that we are building. - // Add the substitution for this parameter to the global substitution - // set that we are building. + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; + } - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; + entryPoint->globalGenericSubst = globalGenericSubsts; } - - entryPoint->globalGenericSubst = globalGenericSubsts; } + + // If any errors occured while we were checking the generic arguments + // of the entry point, then we should bail out rather than try to + // perform the next step of validation. 
+ // if (sink->errorCount != 0) return; // Now that we've *found* the entry point, it is time to validate // that it actually meets the constraints for the chosen stage/profile. // - // TODO: This validation should be performed "under" any global generic + // TODO: This validation should (probably?) be performed "under" any global generic // parameter substitution we might have created, so that we can validate // based on knowledge of actual types. // @@ -9679,7 +9825,7 @@ namespace Slang RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest(); entryPointReq->compileRequest = compileRequest; entryPointReq->translationUnitIndex = int(tt); - entryPointReq->decl = funcDecl; + entryPointReq->funcDeclRef = makeDeclRef(funcDecl); entryPointReq->name = funcDecl->getName(); entryPointReq->profile = profile; diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 29f7a95d9..c93f3f70c 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -127,6 +127,17 @@ namespace Slang return compileRequest->translationUnits[translationUnitIndex].Ptr(); } + DeclRef<FuncDecl> EntryPointRequest::getFuncDeclRef() + { + return funcDeclRef; + } + + RefPtr<FuncDecl> EntryPointRequest::getFuncDecl() + { + return getFuncDeclRef().getDecl(); + } + + // Profile Profile::LookUp(char const* name) diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 9e8f146b9..39199a62f 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -152,7 +152,11 @@ namespace Slang // This will be filled in as part of semantic analysis; // it should not be assumed to be available in cases // where any errors were diagnosed. 
- RefPtr<FuncDecl> decl; + // + DeclRef<FuncDecl> funcDeclRef; + + DeclRef<FuncDecl> getFuncDeclRef(); + RefPtr<FuncDecl> getFuncDecl(); RefPtr<Substitutions> globalGenericSubst; diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index c0d31365e..0480eb934 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -175,7 +175,7 @@ SIMPLE_SYNTAX_CLASS(TypeAliasDecl, TypeDefDecl) SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl) END_SYNTAX_CLASS() -// A '__generic_param' declaration, which defines a generic +// A 'type_param' declaration, which defines a generic // entry-point parameter. Is a container of GenericTypeConstraintDecl SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl) END_SYNTAX_CLASS() diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 68790d4ab..9a8446a03 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -274,7 +274,7 @@ DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' // 303xx: interfaces and associated types DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") -DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.") +DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") DIAGNOSTIC(39999, Fatal, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index d48390d31..d551ba1b9 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -6623,26 +6623,25 @@ String emitEntryPoint( EmitVisitor visitor(&context); - // We are going to create a fresh IR module that we will use to - // clone any code needed by the user's entry point. 
- IRSpecializationState* irSpecializationState = createIRSpecializationState( - entryPoint, - programLayout, - target, - targetRequest); - { - IRModule* irModule = getIRModule(irSpecializationState); + { auto compileRequest = translationUnit->compileRequest; auto session = compileRequest->mSession; - TypeLegalizationContext typeLegalizationContext; - initialize(&typeLegalizationContext, - session, - irModule); - - auto irEntryPoint = specializeIRForEntryPoint( - irSpecializationState, - entryPoint); + // We start out by performing "linking" at the level of the IR. + // This step will create a fresh IR module to be used for + // code generation, and will copy in any IR definitions that + // the desired entry point requires. Along the way it will + // resolve references to imported/exported symbols across + // modules, and also select between the definitions of + // any "profile-overloaded" symbols. + // + auto linkedIR = linkIR( + entryPoint, + programLayout, + target, + targetRequest); + auto irModule = linkedIR.module; + auto irEntryPoint = linkedIR.entryPoint; #if 0 dumpIRIfEnabled(compileRequest, irModule, "CLONED"); @@ -6708,9 +6707,22 @@ String emitEntryPoint( // we need to ensure that the code only uses types // that are legal on the chosen target. // - legalizeTypes( - &typeLegalizationContext, - irModule); + { + // TODO: The presence of `TypeLegalizationContext` + // in the public API of the `legalizeTypes` function + // is a throwback to when there was AST-level + // type legalization and all the complications it + // created. The pass should be refactored to not + // expose these details. 
+ // + TypeLegalizationContext typeLegalizationContext; + initialize(&typeLegalizationContext, + session, + irModule); + legalizeTypes( + &typeLegalizationContext, + irModule); + } // Debugging output of legalization #if 0 @@ -6810,7 +6822,6 @@ String emitEntryPoint( // GlobalGenericParamSubstitution implementation may reference ir objects targetRequest->compileRequest->compiledModules.Add(irModule); } - destroyIRSpecializationState(irSpecializationState); // Deal with cases where a particular stage requires certain GLSL versions // and/or extensions. diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index dba4fc2d1..35e0f46b8 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -659,13 +659,29 @@ void cloneFunctionCommon( } } +// We will forward-declare the subroutine for eagerly specializing +// an IR-level generic to argument values, because `specializeIRForEntryPoint` +// needs to perform this operation even though it is logically part of +// the later generic specialization pass. +// +IRInst* specializeGeneric( + IRSpecialize* specializeInst); + IRFunc* specializeIRForEntryPoint( IRSpecContext* context, EntryPointRequest* entryPointRequest, EntryPointLayout* entryPointLayout) { - // Look up the IR symbol by name - auto mangledName = getMangledName(entryPointRequest->decl); + // We start by looking up the IR symbol that + // matches the mangled name given to the + // function we want to emit. + // + // Note: the function decl-ref may refer to + // a specialization of a generic function, + // so that the mangled name of the decl-ref is + // not the same as the mangled name of the decl. + // + auto mangledName = getMangledName(entryPointRequest->getFuncDeclRef()); RefPtr<IRSpecSymbol> sym; if (!context->getSymbols().TryGetValue(mangledName, sym)) { @@ -674,40 +690,68 @@ IRFunc* specializeIRForEntryPoint( } // TODO: deal with the case where we might - // have multiple versions... + // have multiple (profile-overloaded) versions... 
+ // + auto originalVal = sym->irGlobalValue; + + // We will start by cloning the entry point reference + // like any other global value. + // + auto clonedVal = cloneGlobalValue(context, originalVal); - auto globalValue = sym->irGlobalValue; - if (globalValue->op != kIROp_Func) + // In the case where the user is requesting a specialization + // of a generic entry point, we have a bit of a problem. + // + // This function is expected to return an `IRFunc` and + // subsequent passes expect to find, e.g., layout information + // attached to the parameters of such a func. + // + // In the generic case, the `clonedValue` won't be an + // `IRFunc`, but instead an `IRSpecialize`. + // + if(auto clonedSpec = as<IRSpecialize>(clonedVal)) { - SLANG_UNEXPECTED("expected an IR function"); + // The Right Thing to do here is to perform some + // amount of generic specialization, at least + // until we get back an `IRFunc`. + // + // The dangerous thing is that the generic specialization + // pass can, in principle, change the signature of + // functions, so that attaching parameter layout + // information *after* specialization might not work. + // + // The compromise we make here is to directly + // invoke the logic for specializing a generic. + // + // In theory this isn't valid, because there is no + // way we can register the specialized function we + // create so that it would be re-used by other instantiations + // with the same arguments (because we cannot be + // sure the generic arguments are themselves fully specialized) + // + // In practice this isn't really a problem, because + // we don't want to share the definition between + // an entry point and an ordinary function anyway. 
+ // + clonedVal = specializeGeneric(clonedSpec); + } + + auto clonedFunc = as<IRFunc>(clonedVal); + if(!clonedFunc) + { + SLANG_UNEXPECTED("expected entry point to be a function"); return nullptr; } - auto originalFunc = (IRFunc*)globalValue; - - // Create a clone for the IR function - auto clonedFunc = context->builder->createFunc(); - - // Note: we do *not* register this cloned declaration - // as the cloned value for the original symbol. - // This is kind of a kludge, but it ensures that - // in the unlikely case that the function is both - // used as an entry point and a callable function - // (yes, this would imply recursion...) we actually - // have two copies, which lets us arbitrarily - // transform the entry point to meet target requirements. - // - // TODO: The above statement is kind of bunk, though, - // because both versions of the function would have - // the same mangled name... :( - // We need to clone all the properties of the original - // function, including any blocks, their parameters, - // and their instructions. - cloneFunctionCommon(context, clonedFunc, originalFunc); + if( !clonedFunc->findDecorationImpl(kIROp_EntryPointDecoration) ) + { + context->builder->addEntryPointDecoration(clonedFunc); + } // We need to attach the layout information for // the entry point to this declaration, so that // we can use it to inform downstream code emit. 
+ // context->builder->addLayoutDecoration( clonedFunc, entryPointLayout); @@ -1166,13 +1210,14 @@ struct IRSpecializationState } }; -IRSpecializationState* createIRSpecializationState( +LinkedIR linkIR( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target, TargetRequest* targetReq) { - IRSpecializationState* state = new IRSpecializationState(); + IRSpecializationState stateStorage; + auto state = &stateStorage; state->programLayout = programLayout; state->target = target; @@ -1232,6 +1277,8 @@ IRSpecializationState* createIRSpecializationState( context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } + context->builder->setInsertInto(context->getModule()->getModuleInst()); + // for now, clone all unreferenced witness tables // // TODO: This step should *not* be needed with the current IR @@ -1242,39 +1289,9 @@ IRSpecializationState* createIRSpecializationState( if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } - return state; -} - -void destroyIRSpecializationState(IRSpecializationState* state) -{ - delete state; -} - -IRModule* getIRModule(IRSpecializationState* state) -{ - return state->irModule; -} - -IRFunc* specializeIRForEntryPoint( - IRSpecializationState* state, - EntryPointRequest* entryPointRequest) -{ - auto translationUnit = entryPointRequest->getTranslationUnit(); - auto originalIRModule = translationUnit->irModule; - if (!originalIRModule) - { - // We should already have emitted IR for the original - // translation unit, and it we don't have it, then - // we are now in trouble. 
- return nullptr; - } - - auto context = state->getContext(); - auto newProgramLayout = state->newProgramLayout; auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); - // Next, we make sure to clone the global value for // the entry point function itself, and rely on // this step to recursively copy over anything else @@ -1317,13 +1334,19 @@ IRFunc* specializeIRForEntryPoint( context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout); } - - // TODO: *technically* we should consider the case where // we have global variables with initializers, since // these should get run whether or not the entry point // references them. - return irEntryPoint; + + // Now that we've cloned the entry point and everything + // it refers to, we can package up the data we return + // to the caller. + // + LinkedIR linkedIR; + linkedIR.module = state->irModule; + linkedIR.entryPoint = irEntryPoint; + return linkedIR; } diff --git a/source/slang/ir-link.h b/source/slang/ir-link.h index 18940f825..4fcdb4618 100644 --- a/source/slang/ir-link.h +++ b/source/slang/ir-link.h @@ -5,29 +5,22 @@ namespace Slang { - // Interface to IR specialization for use when cloning target-specific - // IR as part of compiling an entry point. + struct LinkedIR + { + RefPtr<IRModule> module; + IRFunc* entryPoint; + }; - // `IRSpecializationState` is used as an opaque type to wrap up all - // the data needed to perform IR specialization, without exposing - // implementation details. - struct IRSpecializationState; - IRSpecializationState* createIRSpecializationState( - EntryPointRequest* entryPointRequest, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq); - void destroyIRSpecializationState(IRSpecializationState* state); - IRModule* getIRModule(IRSpecializationState* state); - - struct ExtensionUsageTracker; // Clone the IR values reachable from the given entry point // into the IR module associated with the specialization state. 
// When multiple definitions of a symbol are found, the one // that is best specialized for the given `targetReq` will be // used. - IRFunc* specializeIRForEntryPoint( - IRSpecializationState* state, - EntryPointRequest* entryPointRequest); + // + LinkedIR linkIR( + EntryPointRequest* entryPointRequest, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq); } diff --git a/source/slang/ir-specialize.cpp b/source/slang/ir-specialize.cpp index 144cc008f..0da06580f 100644 --- a/source/slang/ir-specialize.cpp +++ b/source/slang/ir-specialize.cpp @@ -29,6 +29,14 @@ namespace Slang // simplifications/specializations of one category can open // up opportunities for transformations in the other categories. +struct SpecializationContext; + +IRInst* specializeGenericImpl( + IRGeneric* genericVal, + IRSpecialize* specializeInst, + IRModule* module, + SpecializationContext* context); + struct SpecializationContext { // For convenience, we will keep a pointer to the module @@ -203,112 +211,26 @@ struct SpecializationContext // If no existing specialization is found, we need // to create the specialization instead. + // This mostly amounts to evaluating the generic as + // if it were a function being called. // - // Effectively this amounts to "calling" the generic - // on its concrete argument values and computing the - // result it returns. - // - // For now, all of our generics consist of a single - // basic block, so we can "call" them just by - // cloning the instructions in their single block - // into the global scope, using an environment for - // cloning that maps the generic parameters to - // the concrete arguments that were provided - // by the `specialize(...)` instruction. - // - IRCloneEnv env; - - // We will walk through the parameters of the generic and - // register the corresponding argument of the `specialize` - // instruction to be used as the "cloned" value for each - // parameter. 
- // - // Suppose we are looking at `specialize(g, a, b, c)` and `g` has - // three generic parameters: `T`, `U`, and `V`. Then we will - // be initializing our environment to map `T -> a`, `U -> b`, - // and `V -> c`. + // We will use a free function to do the actual work + // of evaluating the generic, so that the logic + // can be re-used in other cases that need to + // do one-off specialization. // - UInt argCounter = 0; - for( auto param : genericVal->getParams() ) - { - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < specializeInst->getArgCount()); - - IRInst* arg = specializeInst->getArg(argIndex); - - env.mapOldValToNew.Add(param, arg); - } + IRInst* specializedVal = specializeGenericImpl(genericVal, specializeInst, module, this); - // We will set up an IR builder for insertion - // into the global scope, at the same location - // as the original generic. - // - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(genericVal); - // Now we will run through the body of the generic and - // clone each of its instructions into the global scope, - // until we reach a `return` instruction. + // The value that was returned from evaluating + // the generic is the specialized value, and we + // need to remember it in our dictionary of + // specializations so that we don't instantiate + // this generic again for the same arguments. // - for( auto bb : genericVal->getBlocks() ) - { - // We expect a generic to only ever contain a single block. - // - SLANG_ASSERT(bb == genericVal->getFirstBlock()); - - // We will iterate over the non-parameter ("ordinary") - // instructions only, because parameters were dealt - // with explictly at an earlier point. - // - for( auto ii : bb->getOrdinaryInsts() ) - { - // The last block of the generic is expected to end with - // a `return` instruction for the specialized value that - // comes out of the abstraction. 
- // - // We thus use that cloned value as the result of the - // specialization step. - // - if( auto returnValInst = as<IRReturnVal>(ii) ) - { - auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); - - // The value that was returned from evaluating - // the generic is the specialized value, and we - // need to remember it in our dictionary of - // specializations so that we don't instantiate - // this generic again for the same arguments. - // - genericSpecializations.Add(key, specializedVal); - - return specializedVal; - } - - // For any instruction other than a `return`, we will - // simply clone it completely into the global scope. - // - IRInst* clonedInst = cloneInst(&env, builder, ii); - - // Any new instructions we create during cloning were - // not present when we initially built our work list, - // so we need to make sure to consider them now. - // - // This is important for the cases where one generic - // invokes another, because there will be `specialize` - // operations nested inside the first generic that refer - // to the second. - // - addToWorkList(clonedInst); - } - } + genericSpecializations.Add(key, specializedVal); - // If we reach this point, something went wrong, because we - // never encountered a `return` inside the body of the generic. - // - SLANG_UNEXPECTED("no return from generic"); - UNREACHABLE_RETURN(nullptr); + return specializedVal; } // The logic for generating a specialization of an IR generic @@ -1302,4 +1224,131 @@ void specializeModule( context.processModule(); } + +IRInst* specializeGenericImpl( + IRGeneric* genericVal, + IRSpecialize* specializeInst, + IRModule* module, + SpecializationContext* context) +{ + // Effectively, specializing a generic amounts to "calling" the generic + // on its concrete argument values and computing the + // result it returns. 
+ // + // For now, all of our generics consist of a single + // basic block, so we can "call" them just by + // cloning the instructions in their single block + // into the global scope, using an environment for + // cloning that maps the generic parameters to + // the concrete arguments that were provided + // by the `specialize(...)` instruction. + // + IRCloneEnv env; + + // We will walk through the parameters of the generic and + // register the corresponding argument of the `specialize` + // instruction to be used as the "cloned" value for each + // parameter. + // + // Suppose we are looking at `specialize(g, a, b, c)` and `g` has + // three generic parameters: `T`, `U`, and `V`. Then we will + // be initializing our environment to map `T -> a`, `U -> b`, + // and `V -> c`. + // + UInt argCounter = 0; + for( auto param : genericVal->getParams() ) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < specializeInst->getArgCount()); + + IRInst* arg = specializeInst->getArg(argIndex); + + env.mapOldValToNew.Add(param, arg); + } + + // We will set up an IR builder for insertion + // into the global scope, at the same location + // as the original generic. + // + SharedIRBuilder sharedBuilderStorage; + sharedBuilderStorage.module = module; + sharedBuilderStorage.session = module->getSession(); + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + builder->setInsertBefore(genericVal); + + // Now we will run through the body of the generic and + // clone each of its instructions into the global scope, + // until we reach a `return` instruction. + // + for( auto bb : genericVal->getBlocks() ) + { + // We expect a generic to only ever contain a single block. + // + SLANG_ASSERT(bb == genericVal->getFirstBlock()); + + // We will iterate over the non-parameter ("ordinary") + // instructions only, because parameters were dealt + // with explictly at an earlier point. 
+ // + for( auto ii : bb->getOrdinaryInsts() ) + { + // The last block of the generic is expected to end with + // a `return` instruction for the specialized value that + // comes out of the abstraction. + // + // We thus use that cloned value as the result of the + // specialization step. + // + if( auto returnValInst = as<IRReturnVal>(ii) ) + { + auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); + return specializedVal; + } + + // For any instruction other than a `return`, we will + // simply clone it completely into the global scope. + // + IRInst* clonedInst = cloneInst(&env, builder, ii); + + // Any new instructions we create during cloning were + // not present when we initially built our work list, + // so we need to make sure to consider them now. + // + // This is important for the cases where one generic + // invokes another, because there will be `specialize` + // operations nested inside the first generic that refer + // to the second. + // + if( context ) + { + context->addToWorkList(clonedInst); + } + } + } + + // If we reach this point, something went wrong, because we + // never encountered a `return` inside the body of the generic. 
+ // + SLANG_UNEXPECTED("no return from generic"); + UNREACHABLE_RETURN(nullptr); +} + +IRInst* specializeGeneric( + IRSpecialize* specializeInst) +{ + auto baseGeneric = as<IRGeneric>(specializeInst->getBase()); + SLANG_ASSERT(baseGeneric); + if(!baseGeneric) return specializeInst; + + auto module = specializeInst->getModule(); + SLANG_ASSERT(module); + if(!module) return specializeInst; + + return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); +} + + } // namespace Slang diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 3a5e0cd2e..665904969 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -6113,25 +6113,31 @@ static void lowerEntryPointToIR( EntryPointRequest* entryPointRequest) { // First, lower the entry point like an ordinary function - auto entryPointFuncDecl = entryPointRequest->decl; - if (!entryPointFuncDecl) - { - // Something must have gone wrong earlier, if we - // weren't able to associate a declaration with - // the entry point request. - return; - } - auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); + + auto session = context->getSession(); + auto entryPointFuncDeclRef = entryPointRequest->getFuncDeclRef(); + auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); + + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); + + auto loweredEntryPointFunc = getSimpleVal(context, + emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); // Attach a marker decoration so that we recognize // this as an entry point. 
- auto builder = context->irBuilder; - builder->addEntryPointDecoration(getSimpleVal(context, loweredEntryPointFunc)); + // + builder->addEntryPointDecoration(loweredEntryPointFunc); + + // + if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) + { + builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); + } // Now lower all the arguments supplied for global generic // type parameters. // - builder->setInsertInto(builder->getModule()->getModuleInst()); for (RefPtr<Substitutions> subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer) { auto gSubst = subst.as<GlobalGenericParamSubstitution>(); diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 2fb571db2..583d4fe54 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -1790,7 +1790,7 @@ static void collectGlobalScopeParameters( // First enumerate parameters at global scope // We collect two things here: // 1. A shader parameter, which is always a variable - // 2. A global entry-point generic parameter type (`__generic_param`), + // 2. 
A global entry-point generic parameter type (`type_param`), // which is a GlobalGenericParamDecl // We collect global generic type parameters in the first pass, // So we can fill in the correct index into ordinary type layouts @@ -2255,13 +2255,13 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( ParameterBindingContext* context, SubstitutionSet typeSubst, - RefPtr<ParamDecl> paramDecl, + DeclRef<ParamDecl> paramDeclRef, RefPtr<VarLayout> paramVarLayout, EntryPointParameterState& state) { - auto paramType = paramDecl->type.type->Substitute(typeSubst).as<Type>(); + auto paramType = GetType(paramDeclRef)->Substitute(typeSubst).as<Type>(); - if( paramDecl->HasModifier<HLSLUniformModifier>() ) + if( paramDeclRef.getDecl()->HasModifier<HLSLUniformModifier>() ) { // An entry-point parameter that is explicitly marked `uniform` represents // a uniform shader parameter passed via the implicitly-defined @@ -2283,21 +2283,24 @@ static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( state.directionMask = 0; // If it appears to be an input, process it as such. - if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) + if( paramDeclRef.getDecl()->HasModifier<InModifier>() + || paramDeclRef.getDecl()->HasModifier<InOutModifier>() + || !paramDeclRef.getDecl()->HasModifier<OutModifier>() ) { state.directionMask |= kEntryPointParameterDirection_Input; } // If it appears to be an output, process it as such. 
- if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) + if(paramDeclRef.getDecl()->HasModifier<OutModifier>() + || paramDeclRef.getDecl()->HasModifier<InOutModifier>()) { state.directionMask |= kEntryPointParameterDirection_Output; } return processEntryPointVaryingParameterDecl( context, - paramDecl.Ptr(), - paramDecl->type.type->Substitute(typeSubst).as<Type>(), + paramDeclRef.getDecl(), + paramType, state, paramVarLayout); } @@ -2460,22 +2463,14 @@ static void collectEntryPointParameters( EntryPointRequest* entryPoint, SubstitutionSet typeSubst) { - FuncDecl* entryPointFuncDecl = entryPoint->decl; - if (!entryPointFuncDecl) - { - // Something must have failed earlier, so that - // we didn't find a declaration to match this - // entry point request. - // - return; - } + DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); // We will take responsibility for creating and filling in // the `EntryPointLayout` object here. // RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout(); entryPointLayout->profile = entryPoint->profile; - entryPointLayout->entryPoint = entryPointFuncDecl; + entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); // The entry point layout must be added to the output // program layout so that it can be accessed by reflection. @@ -2522,12 +2517,12 @@ static void collectEntryPointParameters( scopeBuilder.beginLayout(context); auto paramsStructLayout = scopeBuilder.m_structLayout; - for( auto paramDecl : entryPointFuncDecl->getMembersOfType<ParamDecl>() ) + for( auto paramDeclRef : getMembersOfType<ParamDecl>(entryPointFuncDeclRef) ) { // Any error messages we emit during the process should // refer to the location of this parameter. 
// - state.loc = paramDecl->loc; + state.loc = paramDeclRef.getLoc(); // We are going to construct the variable layout for this // parameter *before* computing the type layout, because @@ -2536,13 +2531,13 @@ static void collectEntryPointParameters( // back onto the `VarLayout`. // RefPtr<VarLayout> paramVarLayout = new VarLayout(); - paramVarLayout->varDecl = makeDeclRef(paramDecl.Ptr()); + paramVarLayout->varDecl = paramDeclRef; paramVarLayout->stage = state.stage; auto paramTypeLayout = computeEntryPointParameterTypeLayout( context, typeSubst, - paramDecl, + paramDeclRef, paramVarLayout, state); paramVarLayout->typeLayout = paramTypeLayout; @@ -2598,10 +2593,10 @@ static void collectEntryPointParameters( // TODO: Ideally we should make the layout process more robust to empty/void // types and apply this logic unconditionally. // - auto resultType = entryPointFuncDecl->ReturnType.type; + auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as<Type>(); if( !resultType->Equals(resultType->getSession()->getVoidType()) ) { - state.loc = entryPointFuncDecl->loc; + state.loc = entryPointFuncDeclRef.getLoc(); state.directionMask = kEntryPointParameterDirection_Output; RefPtr<VarLayout> resultLayout = new VarLayout(); @@ -2609,7 +2604,7 @@ static void collectEntryPointParameters( auto resultTypeLayout = processEntryPointVaryingParameterDecl( context, - entryPointFuncDecl, + entryPointFuncDeclRef.getDecl(), resultType->Substitute(typeSubst).as<Type>(), state, resultLayout); diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 076111711..b9e3786a2 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -4781,8 +4781,7 @@ namespace Slang addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK) DECL(typedef, ParseTypeDef); DECL(associatedtype, parseAssocType); - DECL(__generic_param, parseGlobalGenericParamDecl); - DECL(type_param, parseGlobalGenericParamDecl); + DECL(type_param, parseGlobalGenericParamDecl); 
DECL(cbuffer, parseHLSLCBufferDecl); DECL(tbuffer, parseHLSLTBufferDecl); DECL(__generic, ParseGenericDecl); diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 93fe687fd..b1faf6776 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -178,7 +178,7 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) END_SYNTAX_CLASS() SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions) - // the __generic_param decl to be substituted + // the type_param decl to be substituted DECL_FIELD(GlobalGenericParamDecl*, paramDecl) // the actual type to substitute in diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 6b58809a6..0af318c9e 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1427,7 +1427,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<Substitutions> GlobalGenericParamSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) { - // if we find a GlobalGenericParamSubstitution in subst that references the same __generic_param decl + // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl // return a copy of that GlobalGenericParamSubstitution int diff = 0; @@ -1884,6 +1884,11 @@ void Type::accept(IValVisitor* visitor, void* extra) return decl->nameAndLoc.name; } + SourceLoc DeclRefBase::getLoc() const + { + return decl->loc; + } + DeclRefBase DeclRefBase::GetParent() const { // Want access to the free function (the 'as' method by default gets priority) diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 2151ed764..344e94ff9 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -482,6 +482,7 @@ namespace Slang // Convenience accessors for common properties of declarations Name* GetName() const; + SourceLoc getLoc() const; DeclRefBase GetParent() const; int GetHashCode() const; diff --git a/tests/bugs/gh-357.slang 
b/tests/bugs/gh-357.slang index be2ba95ed..043eebf17 100644 --- a/tests/bugs/gh-357.slang +++ b/tests/bugs/gh-357.slang @@ -25,11 +25,9 @@ struct AssocImpl : IAssoc typedef BaseImpl TBase; }; -__generic_param T : IAssoc; - - [numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain<T:IAssoc>( + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; T.TBase base; diff --git a/tests/compute/assoctype-generic-arg.slang b/tests/compute/assoctype-generic-arg.slang index 4bc77c925..dd183ea5d 100644 --- a/tests/compute/assoctype-generic-arg.slang +++ b/tests/compute/assoctype-generic-arg.slang @@ -25,11 +25,9 @@ struct AssocImpl : IAssoc typedef BaseImpl TBase; }; -__generic_param T : IAssoc; - - [numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain<T : IAssoc>( + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; T.TBase base; diff --git a/tests/compute/global-type-param-array.slang b/tests/compute/global-type-param-array.slang index 87236d8f6..d801efe2c 100644 --- a/tests/compute/global-type-param-array.slang +++ b/tests/compute/global-type-param-array.slang @@ -6,17 +6,16 @@ RWStructuredBuffer<float> outputBuffer; import globalTypeParamArrayShared; -__generic_param TImpl : IBase; - -ParameterBlock<TImpl> impl; - float doCompute<T:IBase>(T t) { return t.compute(1.0); } [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain< + TImpl : IBase>( + uniform ParameterBlock<TImpl> impl, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float outVal = doCompute<TImpl>(impl); diff --git a/tests/compute/global-type-param-in-entrypoint.slang b/tests/compute/global-type-param-in-entrypoint.slang index 4bcf4cbca..9a1e9b054 100644 --- a/tests/compute/global-type-param-in-entrypoint.slang +++ b/tests/compute/global-type-param-in-entrypoint.slang @@ -8,7 +8,7 @@ 
interface IVertInterpolant float4 getColor(); } -__generic_param TVertInterpolant : IVertInterpolant; +type_param TVertInterpolant : IVertInterpolant; struct VertImpl : IVertInterpolant { diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang index f177dcb1d..7621f8961 100644 --- a/tests/compute/global-type-param.slang +++ b/tests/compute/global-type-param.slang @@ -26,10 +26,8 @@ struct Impl : IBase } }; -__generic_param TImpl : IBase; - [numthreads(1, 1, 1)] -void computeMain( +void computeMain<TImpl:IBase>( uniform TImpl impl, uint3 dispatchThreadID : SV_DispatchThreadID) { diff --git a/tests/compute/global-type-param1.slang b/tests/compute/global-type-param1.slang index e16ffa9da..f33be8ec7 100644 --- a/tests/compute/global-type-param1.slang +++ b/tests/compute/global-type-param1.slang @@ -26,10 +26,6 @@ struct Impl : IBase } }; -__generic_param TImpl : IBase; - -ParameterBlock<TImpl> impl; - cbuffer C { float base0; // = 0.5 @@ -39,7 +35,10 @@ Texture2D tex1; // = 0.0 SamplerState sampler; [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain< + TImpl : IBase>( + uniform ParameterBlock<TImpl> impl, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5 diff --git a/tests/compute/global-type-param2.slang b/tests/compute/global-type-param2.slang index f29d01407..976a31df8 100644 --- a/tests/compute/global-type-param2.slang +++ b/tests/compute/global-type-param2.slang @@ -38,10 +38,6 @@ struct Impl : IBase } }; -__generic_param TImpl : IBase; - -ParameterBlock<TImpl> impl; - // at binding c0: cbuffer existingBuffer { @@ -51,7 +47,10 @@ Texture2D tex1; // = 0.0 SamplerState sampler; [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain< + TImpl : IBase>( + uniform ParameterBlock<TImpl> impl, + uint3 dispatchThreadID : 
SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5 diff --git a/tests/compute/int-generic.slang b/tests/compute/int-generic.slang index 6bb63df8c..d9eb85f82 100644 --- a/tests/compute/int-generic.slang +++ b/tests/compute/int-generic.slang @@ -29,14 +29,12 @@ struct Material<let A:int, let B: int> : IMaterial TBRDF getBRDF() { TBRDF a; a.c = 0; return a; } }; -type_param TMaterial : IMaterial; - -TMaterial material; - [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain<M : IMaterial>( + uniform M material, + uint3 dispatchThreadID : SV_DispatchThreadID) { - TMaterial.TBRDF brdf = material.getBRDF(); + M.TBRDF brdf = material.getBRDF(); int outVal = brdf.compute(); outputBuffer[dispatchThreadID.x] = outVal; }
\ No newline at end of file
diff --git a/tests/compute/tagged-union.slang b/tests/compute/tagged-union.slang
index 5089ec5a7..de69232f9 100644
--- a/tests/compute/tagged-union.slang
+++ b/tests/compute/tagged-union.slang
@@ -44,11 +44,13 @@ struct B : IFrobnicator
     }
 }
 
+[numthreads(4, 1, 1)]
+void computeMain
 // Now we will define the generic type parameter for our shader,
 // which will be constraints to be a type that implements our
 // `IFrobnicator` interface.
 //
-type_param T : IFrobnicator;
+    <T : IFrobnicator>
 //
 // For the actual test runner, we will instruct it to plug in
 // a tagged-union type over the two concrete implemetnations.
@@ -57,7 +59,7 @@ type_param T : IFrobnicator;
 // our intention when it is informed via the API.
 //
 //TEST_INPUT: type __TaggedUnion(A,B)
-
+    (
 // Next we need to pass in the actual parameter data for our
 // chosen `IFrobnicator` implementation. The decalration of
 // the constant buffer follows the conventional approach for
@@ -68,7 +70,7 @@ type_param T : IFrobnicator;
 // the `render-test` tool doesn't yet support code that
 // uses multiple descriptor tables/sets.
 //
-ConstantBuffer<T> gFrobnicator;
+    uniform ConstantBuffer<T> gFrobnicator,
 
 // Where things get interesting is when we go to provide the
 // data that will be used by the parameter block.
@@ -92,17 +94,15 @@ ConstantBuffer<T> gFrobnicator;
 //
 //TEST_INPUT: cbuffer(data=[16 9 1 0 0], stride=4):dxbinding(0),glbinding(0)
 
-int test(int val)
-{
-    return gFrobnicator.frobnicate(val);
-}
 
 //TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(1),out
-RWStructuredBuffer<int> gOutputBuffer;
-
-[numthreads(4, 1, 1)]
-void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+    uniform RWStructuredBuffer<int> gOutputBuffer,
+    uint3 dispatchThreadID : SV_DispatchThreadID)
 {
     uint tid = dispatchThreadID.x;
-    gOutputBuffer[tid] = test(tid);
+
+    int val = tid;
+    val = gFrobnicator.frobnicate(val);
+
+    gOutputBuffer[tid] = val;
 }
diff --git a/tests/reflection/global-type-params.slang b/tests/reflection/global-type-params.slang
index 74961b7cc..290e6353a 100644
--- a/tests/reflection/global-type-params.slang
+++ b/tests/reflection/global-type-params.slang
@@ -7,8 +7,8 @@
 interface IBase
 {};
 
-__generic_param TParam : IBase;
-__generic_param TParam2 : IBase;
+type_param TParam : IBase;
+type_param TParam2 : IBase;
 
 struct S
 {
