diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2019-01-16 12:48:11 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-01-16 12:48:11 -0800 |
| commit | aedf61784606406c090302efd8b7ac668ac997fc (patch) | |
| tree | b485fe5d7b027b269fcfa10503321288d10b9800 /source/slang/check.cpp | |
| parent | 8e47a3802d4d74eb11620f147ef5b29b8e931d35 (diff) | |
Initial support for dynamic dispatch using "tagged union" types (#772)
* Initial support for dynamic dispatch using "tagged union" types
Suppose a user declares some generic shader code, like the following:
```hlsl
interface IFrobnicator { ... }
type_param T : IFrobincator;
ParameterBlock<T : IFrobnicator> gFrobnicator;
...
gFrobincator.frobnicate(value);
```
and then they have some concrete implementations of the required interface:
```hlsl
struct A : IFrobnicator { ... }
struct B : IFrobnicator { ... }
```
The current Slang compiler allows them to generate distinct compiled kernels for the case of `T=A` and the case of `T=B`. This means that the decision of which implementation to use must be made at or before the time when a shader gets bound in the application.
This change adds a new ability where the Slang compiler can generate code to handle the case where `T` might be *either* `A` or `B`, and which case it is will be determined dynamically at runtime. This means a single compiled kernel can handle both cases, and the decision about which code path to run can be made any time before the shader executes.
This new option is supported by defining a *tagged union* type. Via the API, the user specifies that `T` should be specialized to `__TaggedUnion(A,B)` (the double underscore indicates that this is an experimental and unsupported feature at present). We refer to the types `A` and `B` here as the "case" types of the tagged union. Conceptually, the compiler synthesizes a type something like:
```hlsl
struct TU { union { A a; B b; } payload; uint tag; }
```
The user can then allocate a constant buffer to hold their tagged union type, and when they pick a concrete type to use (say `B`), they fill in the first `sizeof(B)` bytes of their buffer with data describing a `B` instance, and then set the `tag` field to the appopriate 0-based index of the case type they chose (in this case the `B` case gets the tag value `1`).
Actually implementing tagged unions takes a few main steps:
* Type parsing was extended to special-case `__TaggedUnion` as a contextual keyword. This is really only intended to be used when parsing types from the API or command-line, and Bad Things are likely to happen if a user ever puts it directly in their code. Eventually construction of tagged unions should be an API feature and not part of the language syntax.
* Semantic checking was extended to recognize that a tagged union like `__TaggedUnion(A,B)` shoud support an interface like `IFrobnicator` whenever all of the case types suport it, as long as the interface is "safe" for use with tagged unions (which means it doesn't use a few of the advancd langauge features like associated types).
* The IR was extended with instructions to represent tagged union types and to extract their tag and the payload for the different cases as needed.
* IR generation was extended to synthesize implementations of interface methods for any interface that a tagged union needs to support. Right now the implementation is simplistic and only handles simple method requirements, which it does by emitting a `switch` instruction to pick between the different cases.
* A new IR pass was introduced to "desugar" any tagged union types used in the code. The downstream HLSL and GLSL compilers don't support `union`s, so we have to instead emit a tagged union as a "bag of bits" and implement loading the data for particular cases from it manually.
* Final code emit mostly Just Works after the above steps, but we had to introduce an explicit IR instruction for bit-casting to handle the output of the desugaring pass.
There are a bunch of gaps and caveats in this implementation, but that seems reasonable for something that is an experimental feature. The various `TODO` comments and assertion failures in unimplemented cases are intended, so that this work can be checked in even if it isn't feature-complete.
* fixup: missing files
* fixup: typos
Diffstat (limited to 'source/slang/check.cpp')
| -rw-r--r-- | source/slang/check.cpp | 163 |
1 files changed, 158 insertions, 5 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 2fa4e9bc7..9aa5eb689 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -9,6 +9,8 @@ namespace Slang { + RefPtr<TypeType> getTypeType( + Type* type); /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( @@ -677,7 +679,7 @@ namespace Slang auto baseExprType = baseExpr->type.type; RefPtr<SharedTypeExpr> baseTypeExpr = new SharedTypeExpr(); baseTypeExpr->base.type = baseExprType; - baseTypeExpr->type = new TypeType(baseExprType); + baseTypeExpr->type.type = getTypeType(baseExprType); auto expr = new StaticMemberExpr(); expr->loc = loc; @@ -2071,9 +2073,7 @@ namespace Slang { RefPtr<TypeCastExpr> castExpr = createImplicitCastExpr(); - auto typeType = new TypeType(); - typeType->setSession(getSession()); - typeType->type = toType; + auto typeType = getTypeType(toType); auto typeExpr = new SharedTypeExpr(); typeExpr->type.type = typeType; @@ -5547,6 +5547,60 @@ namespace Slang return witness; } + /// Is the given interface one that a tagged-union type can conform to? + /// + /// If a tagged union type `__TaggedUnion(A,B)` is going to be + /// plugged in for a type parameter `T : IFoo` then we need to + /// be sure that the interface `IFoo` doesn't have anything + /// that could lead to unsafe/unsound behavior. This function + /// checks that all the requirements on the interfaceare safe ones. + /// + bool isInterfaceSafeForTaggedUnion( + DeclRef<InterfaceDecl> interfaceDeclRef) + { + for( auto memberDeclRef : getMembers(interfaceDeclRef) ) + { + if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) + return false; + } + + return true; + } + + /// Is the given interface requirement one that a tagged-union type can satisfy? + /// + /// Unsafe requirements include any `static` requirements, + /// any associated types, and also any requirements that make + /// use of the `This` type (once we support it). + /// + bool isInterfaceRequirementSafeForTaggedUnion( + DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<Decl> requirementDeclRef) + { + if(auto callableDeclRef = requirementDeclRef.As<CallableDecl>()) + { + // A `static` method requirement can't be satisfied by a + // tagged union, because there is no tag to dispatch on. + // + if(requirementDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + return false; + + // TODO: We will eventually want to check that any callable + // requirements do not use the `This` type or any associated + // types in ways that could lead to errors. + // + // For now we are disallowing interfaces that have associated + // types completely, and we haven't implemented the `This` + // type, so we should be safe. + + return true; + } + else + { + return false; + } + } + bool doesTypeConformToInterfaceImpl( RefPtr<Type> originalType, RefPtr<Type> type, @@ -5661,6 +5715,69 @@ namespace Slang } } } + else if(auto taggedUnionType = type->As<TaggedUnionType>()) + { + // A tagged union type conforms to an interface if all of + // the constituent types in the tagged union conform. + // + // We will iterate over the "case" types in the tagged + // union, and check if they conform to the interface. + // Along the way we will collect the conformance witness + // values *if* we are being asked to produce a witness + // value for the tagged union itself (that is, if + // `outWitness` is non-null). + // + List<RefPtr<Val>> caseWitnesses; + for(auto caseType : taggedUnionType->caseTypes) + { + RefPtr<Val> caseWitness; + + if(!doesTypeConformToInterfaceImpl( + caseType, + caseType, + interfaceDeclRef, + outWitness ? &caseWitness : nullptr, + nullptr)) + { + return false; + } + + if(outWitness) + { + caseWitnesses.Add(caseWitness); + } + } + + // We also need to validate the requirements on + // the interface to make sure that they are suitable for + // use with a tagged-union type. + // + // For example, if the interface includes a `static` method + // (which can therefore be called without a particular instance), + // then we wouldn't know what implementation of that method + // to use because there is no tag value to dispatch on. + // + // We will start out being conservative about what we accept + // here, just to keep things simple. + // + if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef)) + return false; + + // If we reach this point then we have a concrete + // witness for each of the case types, and that is + // enough to build a witness for the tagged union. + // + if(outWitness) + { + RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = new TaggedUnionSubtypeWitness(); + taggedUnionWitness->sub = taggedUnionType; + taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef); + taggedUnionWitness->caseWitnesses.SwapWith(caseWitnesses); + + *outWitness = taggedUnionWitness; + } + return true; + } // default is failure return false; @@ -8090,6 +8207,23 @@ namespace Slang return expr; } + RefPtr<Expr> visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) + { + // We have an expression of the form `__TaggedUnion(A, B, ...)` + // which will evaluate to a tagged-union type over `A`, `B`, etc. + // + RefPtr<TaggedUnionType> type = new TaggedUnionType(); + expr->type = QualType(getTypeType(type)); + + for( auto& caseTypeExpr : expr->caseTypes ) + { + caseTypeExpr = CheckProperType(caseTypeExpr); + type->caseTypes.Add(caseTypeExpr.type); + } + + return expr; + } + @@ -9039,7 +9173,7 @@ namespace Slang scopesToTry.Add(module->moduleDecl->scope); List<RefPtr<Type>> globalGenericArgs; - for (auto name : entryPoint->genericParameterTypeNames) + for (auto name : entryPoint->genericArgStrings) { // parse type name RefPtr<Type> type; @@ -9059,6 +9193,25 @@ namespace Slang return; } + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a module (and the stuff it imports). + // + // The Right Way to handle this would probably be to have each `ModuleDecl` track + // any tagged union types that get created in the context of that module, and + // then combine those lists later. + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point request. + // + if( auto taggedUnionType = type->As<TaggedUnionType>() ) + { + entryPoint->taggedUnionTypes.Add(taggedUnionType); + } + globalGenericArgs.Add(type); } |
