diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2019-01-16 12:48:11 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-01-16 12:48:11 -0800 |
| commit | aedf61784606406c090302efd8b7ac668ac997fc (patch) | |
| tree | b485fe5d7b027b269fcfa10503321288d10b9800 /source/slang | |
| parent | 8e47a3802d4d74eb11620f147ef5b29b8e931d35 (diff) | |
Initial support for dynamic dispatch using "tagged union" types (#772)
* Initial support for dynamic dispatch using "tagged union" types
Suppose a user declares some generic shader code, like the following:
```hlsl
interface IFrobnicator { ... }
type_param T : IFrobincator;
ParameterBlock<T : IFrobnicator> gFrobnicator;
...
gFrobincator.frobnicate(value);
```
and then they have some concrete implementations of the required interface:
```hlsl
struct A : IFrobnicator { ... }
struct B : IFrobnicator { ... }
```
The current Slang compiler allows them to generate distinct compiled kernels for the case of `T=A` and the case of `T=B`. This means that the decision of which implementation to use must be made at or before the time when a shader gets bound in the application.
This change adds a new ability where the Slang compiler can generate code to handle the case where `T` might be *either* `A` or `B`, and which case it is will be determined dynamically at runtime. This means a single compiled kernel can handle both cases, and the decision about which code path to run can be made any time before the shader executes.
This new option is supported by defining a *tagged union* type. Via the API, the user specifies that `T` should be specialized to `__TaggedUnion(A,B)` (the double underscore indicates that this is an experimental and unsupported feature at present). We refer to the types `A` and `B` here as the "case" types of the tagged union. Conceptually, the compiler synthesizes a type something like:
```hlsl
struct TU { union { A a; B b; } payload; uint tag; }
```
The user can then allocate a constant buffer to hold their tagged union type, and when they pick a concrete type to use (say `B`), they fill in the first `sizeof(B)` bytes of their buffer with data describing a `B` instance, and then set the `tag` field to the appopriate 0-based index of the case type they chose (in this case the `B` case gets the tag value `1`).
Actually implementing tagged unions takes a few main steps:
* Type parsing was extended to special-case `__TaggedUnion` as a contextual keyword. This is really only intended to be used when parsing types from the API or command-line, and Bad Things are likely to happen if a user ever puts it directly in their code. Eventually construction of tagged unions should be an API feature and not part of the language syntax.
* Semantic checking was extended to recognize that a tagged union like `__TaggedUnion(A,B)` shoud support an interface like `IFrobnicator` whenever all of the case types suport it, as long as the interface is "safe" for use with tagged unions (which means it doesn't use a few of the advancd langauge features like associated types).
* The IR was extended with instructions to represent tagged union types and to extract their tag and the payload for the different cases as needed.
* IR generation was extended to synthesize implementations of interface methods for any interface that a tagged union needs to support. Right now the implementation is simplistic and only handles simple method requirements, which it does by emitting a `switch` instruction to pick between the different cases.
* A new IR pass was introduced to "desugar" any tagged union types used in the code. The downstream HLSL and GLSL compilers don't support `union`s, so we have to instead emit a tagged union as a "bag of bits" and implement loading the data for particular cases from it manually.
* Final code emit mostly Just Works after the above steps, but we had to introduce an explicit IR instruction for bit-casting to handle the output of the desugaring pass.
There are a bunch of gaps and caveats in this implementation, but that seems reasonable for something that is an experimental feature. The various `TODO` comments and assertion failures in unimplemented cases are intended, so that this work can be checked in even if it isn't feature-complete.
* fixup: missing files
* fixup: typos
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/check.cpp | 163 | ||||
| -rw-r--r-- | source/slang/compiler.h | 8 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 98 | ||||
| -rw-r--r-- | source/slang/expr-defs.h | 11 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 56 | ||||
| -rw-r--r-- | source/slang/ir-link.cpp | 24 | ||||
| -rw-r--r-- | source/slang/ir-ssa.cpp | 11 | ||||
| -rw-r--r-- | source/slang/ir-union.cpp | 776 | ||||
| -rw-r--r-- | source/slang/ir-union.h | 18 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 64 | ||||
| -rw-r--r-- | source/slang/ir.h | 12 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 290 | ||||
| -rw-r--r-- | source/slang/mangle.cpp | 9 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 14 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 4 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 6 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 153 | ||||
| -rw-r--r-- | source/slang/type-defs.h | 18 | ||||
| -rw-r--r-- | source/slang/type-layout.cpp | 95 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 28 | ||||
| -rw-r--r-- | source/slang/val-defs.h | 17 |
25 files changed, 1883 insertions, 27 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 2fa4e9bc7..9aa5eb689 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -9,6 +9,8 @@ namespace Slang { + RefPtr<TypeType> getTypeType( + Type* type); /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( @@ -677,7 +679,7 @@ namespace Slang auto baseExprType = baseExpr->type.type; RefPtr<SharedTypeExpr> baseTypeExpr = new SharedTypeExpr(); baseTypeExpr->base.type = baseExprType; - baseTypeExpr->type = new TypeType(baseExprType); + baseTypeExpr->type.type = getTypeType(baseExprType); auto expr = new StaticMemberExpr(); expr->loc = loc; @@ -2071,9 +2073,7 @@ namespace Slang { RefPtr<TypeCastExpr> castExpr = createImplicitCastExpr(); - auto typeType = new TypeType(); - typeType->setSession(getSession()); - typeType->type = toType; + auto typeType = getTypeType(toType); auto typeExpr = new SharedTypeExpr(); typeExpr->type.type = typeType; @@ -5547,6 +5547,60 @@ namespace Slang return witness; } + /// Is the given interface one that a tagged-union type can conform to? + /// + /// If a tagged union type `__TaggedUnion(A,B)` is going to be + /// plugged in for a type parameter `T : IFoo` then we need to + /// be sure that the interface `IFoo` doesn't have anything + /// that could lead to unsafe/unsound behavior. This function + /// checks that all the requirements on the interfaceare safe ones. + /// + bool isInterfaceSafeForTaggedUnion( + DeclRef<InterfaceDecl> interfaceDeclRef) + { + for( auto memberDeclRef : getMembers(interfaceDeclRef) ) + { + if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) + return false; + } + + return true; + } + + /// Is the given interface requirement one that a tagged-union type can satisfy? + /// + /// Unsafe requirements include any `static` requirements, + /// any associated types, and also any requirements that make + /// use of the `This` type (once we support it). + /// + bool isInterfaceRequirementSafeForTaggedUnion( + DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<Decl> requirementDeclRef) + { + if(auto callableDeclRef = requirementDeclRef.As<CallableDecl>()) + { + // A `static` method requirement can't be satisfied by a + // tagged union, because there is no tag to dispatch on. + // + if(requirementDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + return false; + + // TODO: We will eventually want to check that any callable + // requirements do not use the `This` type or any associated + // types in ways that could lead to errors. + // + // For now we are disallowing interfaces that have associated + // types completely, and we haven't implemented the `This` + // type, so we should be safe. + + return true; + } + else + { + return false; + } + } + bool doesTypeConformToInterfaceImpl( RefPtr<Type> originalType, RefPtr<Type> type, @@ -5661,6 +5715,69 @@ namespace Slang } } } + else if(auto taggedUnionType = type->As<TaggedUnionType>()) + { + // A tagged union type conforms to an interface if all of + // the constituent types in the tagged union conform. + // + // We will iterate over the "case" types in the tagged + // union, and check if they conform to the interface. + // Along the way we will collect the conformance witness + // values *if* we are being asked to produce a witness + // value for the tagged union itself (that is, if + // `outWitness` is non-null). + // + List<RefPtr<Val>> caseWitnesses; + for(auto caseType : taggedUnionType->caseTypes) + { + RefPtr<Val> caseWitness; + + if(!doesTypeConformToInterfaceImpl( + caseType, + caseType, + interfaceDeclRef, + outWitness ? &caseWitness : nullptr, + nullptr)) + { + return false; + } + + if(outWitness) + { + caseWitnesses.Add(caseWitness); + } + } + + // We also need to validate the requirements on + // the interface to make sure that they are suitable for + // use with a tagged-union type. + // + // For example, if the interface includes a `static` method + // (which can therefore be called without a particular instance), + // then we wouldn't know what implementation of that method + // to use because there is no tag value to dispatch on. + // + // We will start out being conservative about what we accept + // here, just to keep things simple. + // + if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef)) + return false; + + // If we reach this point then we have a concrete + // witness for each of the case types, and that is + // enough to build a witness for the tagged union. + // + if(outWitness) + { + RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = new TaggedUnionSubtypeWitness(); + taggedUnionWitness->sub = taggedUnionType; + taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef); + taggedUnionWitness->caseWitnesses.SwapWith(caseWitnesses); + + *outWitness = taggedUnionWitness; + } + return true; + } // default is failure return false; @@ -8090,6 +8207,23 @@ namespace Slang return expr; } + RefPtr<Expr> visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) + { + // We have an expression of the form `__TaggedUnion(A, B, ...)` + // which will evaluate to a tagged-union type over `A`, `B`, etc. + // + RefPtr<TaggedUnionType> type = new TaggedUnionType(); + expr->type = QualType(getTypeType(type)); + + for( auto& caseTypeExpr : expr->caseTypes ) + { + caseTypeExpr = CheckProperType(caseTypeExpr); + type->caseTypes.Add(caseTypeExpr.type); + } + + return expr; + } + @@ -9039,7 +9173,7 @@ namespace Slang scopesToTry.Add(module->moduleDecl->scope); List<RefPtr<Type>> globalGenericArgs; - for (auto name : entryPoint->genericParameterTypeNames) + for (auto name : entryPoint->genericArgStrings) { // parse type name RefPtr<Type> type; @@ -9059,6 +9193,25 @@ namespace Slang return; } + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a module (and the stuff it imports). + // + // The Right Way to handle this would probably be to have each `ModuleDecl` track + // any tagged union types that get created in the context of that module, and + // then combine those lists later. + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point request. + // + if( auto taggedUnionType = type->As<TaggedUnionType>() ) + { + entryPoint->taggedUnionTypes.Add(taggedUnionType); + } + globalGenericArgs.Add(type); } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index d2072387e..41ba027c6 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -122,9 +122,8 @@ namespace Slang // The name of the entry point function (e.g., `main`) Name* name; - // The type names we want to substitute into the - // global generic type parameters - List<String> genericParameterTypeNames; + /// Source code for the generic arguments to use for the generic parameters of the entry point. + List<String> genericArgStrings; // The profile that the entry point will be compiled for // (this is a combination of the target stage, and also @@ -156,6 +155,9 @@ namespace Slang RefPtr<FuncDecl> decl; RefPtr<Substitutions> globalGenericSubst; + + /// Any tagged union types that were referenced by the generic arguments of the entry point. + List<RefPtr<TaggedUnionType>> taggedUnionTypes; }; enum class PassThroughMode : SlangPassThrough diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 76e59efa3..4a43cf8e4 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -347,7 +347,7 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") -DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.") +DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.") DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 2e801f0e3..0aebf6153 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -12,6 +12,7 @@ #include "ir-specialize.h" #include "ir-specialize-resources.h" #include "ir-ssa.h" +#include "ir-union.h" #include "ir-validate.h" #include "legalize-types.h" #include "lower-to-ir.h" @@ -3922,6 +3923,72 @@ struct EmitVisitor } break; + case kIROp_BitCast: + { + // TODO: we can simplify the logic for arbitrary bitcasts + // by always bitcasting the source to a `uint*` type (if it + // isn't already) and then bitcasting that to the destination + // type (if it isn't already `uint*`. + // + // For now we are assuming the source type is *already* + // a `uint*` type of the appropriate size. + // +// auto fromType = extractBaseType(inst->getOperand(0)->getDataType()); + auto toType = extractBaseType(inst->getDataType()); + switch(getTarget(ctx)) + { + case CodeGenTarget::GLSL: + switch(toType) + { + default: + emit("/* unhandled */"); + break; + + case BaseType::UInt: + break; + + case BaseType::Int: + emitIRType(ctx, inst->getDataType()); + break; + + case BaseType::Float: + emit("uintBitsToFloat("); + break; + } + break; + + case CodeGenTarget::HLSL: + switch(toType) + { + default: + emit("/* unhandled */"); + break; + + case BaseType::UInt: + break; + case BaseType::Int: + emit("("); + emitIRType(ctx, inst->getDataType()); + emit(")"); + break; + case BaseType::Float: + emit("asfloat"); + break; + } + break; + + + default: + SLANG_UNEXPECTED("unhandled codegen target"); + break; + } + + emit("("); + emitIROperand(ctx, inst->getOperand(0), mode, kEOp_General); + emit(")"); + } + break; + default: emit("/* unhandled */"); break; @@ -3929,6 +3996,27 @@ struct EmitVisitor maybeCloseParens(needClose); } + BaseType extractBaseType(IRType* inType) + { + auto type = inType; + for(;;) + { + if(auto irBaseType = as<IRBasicType>(type)) + { + return irBaseType->getBaseType(); + } + else if(auto vecType = as<IRVectorType>(type)) + { + type = vecType->getElementType(); + continue; + } + else + { + return BaseType::Void; + } + } + } + void emitIRInst( EmitContext* ctx, IRInst* inst, @@ -6565,6 +6653,14 @@ String emitEntryPoint( #endif validateIRModuleIfEnabled(compileRequest, irModule); + // Desguar any union types, since these will be illegal on + // various targets. + // + desugarUnionTypes(irModule); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "UNIONS DESUGARED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); // Any code that makes use of existential (interface) types @@ -6595,8 +6691,6 @@ String emitEntryPoint( // specializeGenerics(irModule); - - // Debugging code for IR transformations... #if 0 dumpIRIfEnabled(compileRequest, irModule, "SPECIALIZED"); diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h index fb29f64de..bd4ba5038 100644 --- a/source/slang/expr-defs.h +++ b/source/slang/expr-defs.h @@ -193,3 +193,14 @@ RAW( DeclRef<VarDeclBase> declRef; ) END_SYNTAX_CLASS() + + /// A type expression of the form `__TaggedUnion(A, ...)`. + /// + /// An expression of this form will resolve to a `TaggedUnionType` + /// when checked. + /// +SYNTAX_CLASS(TaggedUnionTypeExpr, Expr) +RAW( + List<TypeExp> caseTypes; +) +END_SYNTAX_CLASS()
\ No newline at end of file diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 8d7a647f4..b6f8ce547 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -41,6 +41,8 @@ INST(Nop, nop, 0, 0) INST(VectorType, Vec, 2, 0) INST(MatrixType, Mat, 3, 0) + INST(TaggedUnionType, TaggedUnion, 0, 0) + /* Rate */ INST(ConstExprRate, ConstExpr, 0, 0) INST(GroupSharedRate, GroupShared, 0, 0) @@ -406,6 +408,11 @@ INST(ExtractExistentialValue, extractExistentialValue, 1, 0) INST(ExtractExistentialType, extractExistentialType, 1, 0) INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0) +INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) +INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) + +INST(BitCast, bitCast, 1, 0) + PSEUDO_INST(Pos) PSEUDO_INST(PreInc) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 0c56e8244..8662569ba 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -663,7 +663,6 @@ struct SharedIRBuilder Dictionary<IRInstKey, IRInst*> globalValueNumberingMap; Dictionary<IRConstantKey, IRConstant*> constantMap; - Dictionary<Name*, IRWitnessTable*> witnessTableMap; }; struct IRBuilderSourceLocRAII; @@ -753,6 +752,13 @@ struct IRBuilder IRType* const* paramTypes, IRType* resultType); + IRFuncType* getFuncType( + List<IRType*> const& paramTypes, + IRType* resultType) + { + return getFuncType(paramTypes.Count(), paramTypes.Buffer(), resultType); + } + IRConstExprRate* getConstExprRate(); IRGroupSharedRate* getGroupSharedRate(); @@ -760,6 +766,16 @@ struct IRBuilder IRRate* rate, IRType* dataType); + IRType* getTaggedUnionType( + UInt caseCount, + IRType* const* caseTypes); + + IRType* getTaggedUnionType( + List<IRType*> const& caseTypes) + { + return getTaggedUnionType(caseTypes.Count(), caseTypes.Buffer()); + } + // Set the data type of an instruction, while preserving // its rate, if any. void setDataType(IRInst* inst, IRType* dataType); @@ -794,6 +810,14 @@ struct IRBuilder UInt argCount, IRInst* const* args); + IRInst* emitCallInst( + IRType* type, + IRInst* func, + List<IRInst*> const& args) + { + return emitCallInst(type, func, args.Count(), args.Buffer()); + } + IRInst* createIntrinsicInst( IRType* type, IROp op, @@ -816,6 +840,13 @@ struct IRBuilder UInt argCount, IRInst* const* args); + IRInst* emitMakeVector( + IRType* type, + List<IRInst*> const& args) + { + return emitMakeVector(type, args.Count(), args.Buffer()); + } + IRInst* emitMakeMatrix( IRType* type, UInt argCount, @@ -831,6 +862,13 @@ struct IRBuilder UInt argCount, IRInst* const* args); + IRInst* emitMakeStruct( + IRType* type, + List<IRInst*> const& args) + { + return emitMakeStruct(type, args.Count(), args.Buffer()); + } + IRInst* emitMakeExistential( IRType* type, IRInst* value, @@ -1040,6 +1078,22 @@ struct IRBuilder IRInst* param, IRInst* val); + IRInst* emitExtractTaggedUnionTag( + IRInst* val); + + IRInst* emitExtractTaggedUnionPayload( + IRType* type, + IRInst* val, + IRInst* tag); + + IRInst* emitBitCast( + IRType* type, + IRInst* val); + + // + // Decorations + // + IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount); IRDecoration* addDecoration(IRInst* value, IROp op) diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index 25d3b40b6..231ee81d2 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -1284,6 +1284,30 @@ IRFunc* specializeIRForEntryPoint( cloneValue(context, bindInst); } + // HACK: we need to ensure that any tagged union types + // in the IR module have layout information copied over to them. + // + // Note that we do this *after* cloning the `bindGlobalGenericParam` + // instructions, since we expected the tagged union type(s) to + // be referenced by them. + // + for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts ) + { + auto taggedUnionType = taggedUnionTypeLayout->getType(); + auto mangledName = getMangledTypeName(taggedUnionType); + + RefPtr<IRSpecSymbol> sym; + if(!context->getSymbols().TryGetValue(mangledName, sym)) + continue; + + IRInst* clonedType = findClonedValue(context, sym->irGlobalValue); + if(!clonedType) + continue; + + context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout); + } + + // TODO: *technically* we should consider the case where // we have global variables with initializers, since diff --git a/source/slang/ir-ssa.cpp b/source/slang/ir-ssa.cpp index d893137de..8c5db68fe 100644 --- a/source/slang/ir-ssa.cpp +++ b/source/slang/ir-ssa.cpp @@ -1068,16 +1068,9 @@ void constructSSA(ConstructSSAContext* context) newArgCount, newArgs.Buffer()); - // Swap decorations (all children, really) over to the new instruction + // Transfer decorations (a terminator should have no children) over to the new instruction. // - // TODO: We might want to encapsualte this in a reusable subroutine if - // we often need to copy decorations from one instruction to another. - // - while( auto firstChild = oldTerminator->getFirstDecoration() ) - { - firstChild->removeFromParent(); - firstChild->insertAtEnd(newTerminator); - } + oldTerminator->transferDecorationsTo(newTerminator); // A terminator better not have uses, so we shouldn't have // to replace them. diff --git a/source/slang/ir-union.cpp b/source/slang/ir-union.cpp new file mode 100644 index 000000000..c50d669e1 --- /dev/null +++ b/source/slang/ir-union.cpp @@ -0,0 +1,776 @@ +// ir-union.cpp +#include "ir-union.h" + +#include "ir.h" +#include "ir-insts.h" + +namespace Slang { + +// This file will implement a pass to replace any union types (currently +// just tagged unions) with plain `struct` types that attempt to provide +// equivalent semantics. This will necessarily be a bit fragile, and there +// will be fundamental limits to what the translation can support without +// improved features in the target shading languages/ILs. + +struct DesugarUnionTypesContext +{ + // We'll start with some basic state that we need to get the job done. + // + // This includes the IR module we are to process, as well as IR building + // state that we will initialize once and then use throughout the pass. + // + IRModule* module; + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + IRBuilder* getBuilder() { return &builderStorage; } + + // Because we will be replacing instructions that refer to unions with + // different logic, we'll want to remove the original instructions. + // However, we need to be careful about modifying the IR tree while also + // iterating it, and to keep things simple for ourselves we'll go ahead + // and build up a list of instruction to remove along the way, and then + // remove them all at the end. + // + List<IRInst*> instsToRemove; + + // The overall flow of the pass is pretty simple, so we will walk through it now. + // + void processModule() + { + // We start by initializing our IR building state. + // + sharedBuilderStorage.session = module->session; + sharedBuilderStorage.module = module; + builderStorage.sharedBuilder = &sharedBuilderStorage; + + // Next, we will search for any instruction that create or use + // union types, and process them accordingingly (usually by + // constructing a new instruction to replace them). + // + processInstRec(module->getModuleInst()); + + // Along the way we will build up a list of the tagged union + // types that we encountered, but we will refrain from replacing + // them until we are done (so that we always know that the instructions + // we process above refer to the original type, and not its + // replacement. + // + for( auto info : taggedUnionInfos ) + { + auto taggedUnionType = info->taggedUnionType; + auto replacementInst = info->replacementInst; + + // TODO: We should consider transferring decorations from the source + // type to the destination, but doing so carelessly could create + // problems, since an IR struct type shouldn't have, e.g., a + // `TaggedUnionTypeLayout` attached to it. + + taggedUnionType->replaceUsesWith(replacementInst); + taggedUnionType->removeAndDeallocate(); + } + + // As described previously, we build up the `instsToRemove` list as + // we iterate so that we can remove them all here and not risk + // modifying the IR tree while also walking it. + // + // TODO: This might be overkill and we could conceivably just be + // a bit careful in `processInstRec`. + // + for(auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } + } + + // In order to replace a (tagged) union type, we will need to know + // something about it, and we will use the `TaggedUnionInfo` type + // to collect all the relevant information. + // + struct TaggedUnionInfo : public RefObject + { + // We obviously need to know the tagged union itself, and + // we will also use this structure to track the instruction + // (an IR struct type) that will replace it. + // + IRTaggedUnionType* taggedUnionType; + IRInst* replacementInst; + + // In order to compute a suitable layout for the replacement + // `struct` type we need to know how the tagged union itself + // would be laid out in memory, so we require that all tagged + // unions in the generated IR have an associated (target-specific) + // layout. + // + TaggedUnionTypeLayout* taggedUnionTypeLayout; + + // The basic approach we will use 16-byte chunks (represented as an array + // of `uint4`s) to reprent the "bulk" of a type, and then use a single field + // that could be up to 12 bytes to represent the "rest" of the type. + // + // Note that there are deeply ingrained assumptions here that all types + // are at least four bytes in size (so that unions cannot easily + // accomodate `half` value), and that any types *larger* than four bytes + // will need to be loaded/stored via multiple 4-byte loads/stores. + // + // With the basic idea out of the way, we need an IR level field + // in our struct to hold the bulk data, which comprises a "key" for + // looking up the field, and the type of the field itself. We also + // keep track of how many bytes we put in our bulk storage. + // + // The bulk field might be: + // + // - null, if none of the case types was 16 bytes or more + // - a single `uint4` for between 16 and 31 (inclusive) bytes + // - an array of `uint4`s for 32 or more bytes + // + UInt64 bulkSize = 0; + IRInst* bulkFieldKey = nullptr; + IRType* bulkFieldType = nullptr; + + // The same basic idea then applies to the rest of the data. + // + // The "rest" field will be either be absent (if the size of the + // type was evently divisible by 16), a scalar `uint`, or else + // a 2- or 3-component vector of `uint`. + // + UInt64 restSize = 0; + IRInst* restFieldKey = nullptr; + IRType* restFieldType = nullptr; + + // Finally, since we are currently working with tagged unions, + // we need a field to hold the tag, which will always be allocated + // after the fields that hold the bulk/rest of the payload. + // + // This field is always a single `uint`. + // + // TODO: if/when we support untagged unions, they could be handled + // by having this field be null. + // + IRInst* tagFieldKey; + }; + + // We will build up a list of all the tagged union types we encounter, + // so that we can replace them with the synthesized types when we are done. + // + List<RefPtr<TaggedUnionInfo>> taggedUnionInfos; + + // It is possible that we will see the same tagged union type referenced + // many times in the IR, but we only want to synthesize the information + // above (including the various IR structures) once, so we also maintain + // a map from the original IR type to the corresponding information. + // + Dictionary<IRInst*, TaggedUnionInfo*> mapIRTypeToTaggedUnionInfo; + + // We will process all instructions in the module in a single recursive walk. + // + void processInstRec(IRInst* inst) + { + processInst(inst); + + for( auto child : inst->getChildren() ) + { + processInstRec(child); + } + } + // + // At each instruction, we will check if it is one of the union-related instructions + // we need to replace, and process it accordingly. + // + void processInst(IRInst* inst) + { + switch( inst->op ) + { + default: + // Any instruction not listed below either doesn't involve union types, + // or handles them in a hands-off fashion that we don't need to care about. + // + // E.g., a `load` of a union type from a constant buffer will turn into + // a load of the replacement `struct` type once we are done, and nothing + // needs to be done to the `load` instruction. + // + break; + + case kIROp_TaggedUnionType: + { + // We clearly need to process the tagged union type itself, but the actual + // work is handled by other functions. All we need to do here is ensure + // that the information for this type gets generated, and then we can + // rely on the main `processModule` function to do the actual replacement later. + // + auto type = cast<IRTaggedUnionType>(inst); + getTaggedUnionInfo(type); + } + break; + + case kIROp_ExtractTaggedUnionTag: + { + // The case of extracting the tag from a tagged union is relatively + // simple, because the replacement type will have a dedicated field or it. + // + // We start by finding the tagged union value the instruction is operating + // on, and then looking up the information for its type (which had + // better be a tagged union type). + // + auto taggedUnionVal = inst->getOperand(0); + auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); + + // Because the replacement type will have an explicit field for the tag, + // we can simply emit a single field-extract instruction to read its value + // out. + // + auto builder = getBuilder(); + builder->setInsertBefore(inst); + auto replacement = builder->emitFieldExtract( + inst->getFullType(), + taggedUnionVal, + taggedUnionInfo->tagFieldKey); + + // Now we can replace anything that used the original instruction with + // the new field-extract operation, and add this instruction to the + // list for later removal. + // + inst->replaceUsesWith(replacement); + instsToRemove.Add(inst); + } + break; + + case kIROp_ExtractTaggedUnionPayload: + { + // The most interesting case is when we are trying to extract a particular + // payload (one of the case types) from a union. We may need to extract + // one or more fields from the data stored in the union's replacement + // type (the bulk/rest fields), and we may also have to convert them + // to the type expected via bit-casts. + + // We can start things off easily enough by extracting the tagged union + // value being operated on, as well as the information for its type. + // + auto taggedUnionVal = inst->getOperand(0); + auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); + + // Next we need to figure out which case is being extracted from the union. + // The operand for the case tag should be a literal by construction. + // + auto caseTagVal = inst->getOperand(1); + auto caseTagConst = as<IRIntLit>(caseTagVal); + SLANG_ASSERT(caseTagConst); + + // The case type we are extracting will be the result type of the instruciton. + // + auto caseType = inst->getDataType(); + // + // The tag value itself will be the index of the case type in the union + // type (and its layout). + // + auto caseTagIndex = UInt(caseTagConst->getValue()); + + // We can use the case tag value to look up the layout for the particular + // case type we are extracting (this will allow us to resolve byte offsets + // for fields, etc.). + // + auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout; + SLANG_ASSERT(caseTagIndex < taggedUnionTypeLayout->caseTypeLayouts.Count()); + auto caseTypeLayout = taggedUnionTypeLayout->caseTypeLayouts[caseTagIndex]; + + // At this point we know the type we are trying to extract, as well + // as its layout. We will defer the actual implementation of extraction + // to a (recursive) subroutine that can extract a (sub-)field from the + // union at a given byte offset. Since we are extracting a full case + // right now, the byte offset will be zero. + // + auto payloadVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + caseType, + caseTypeLayout, + 0); + + // TODO: There is a significant flaw in the above approach when + // the case type might be (or contain) an array. If we have a setup + // like the following: + // + // union SomeUnion { float someCase[100]; ... } + // ... + // float result = someUnion.someCase[someIndex]; + // + // The current logic would desugar this into something like: + // + // struct SomeUnion { uint4 bulk[100]; ... } + // ... + // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... } + // float result = tmp[someIndex]; + // + // The result is that we copy an entire 100-element array into local memory + // just to fetch a single element, when it would be much nicer to just do: + // + // float result = asfloat(someUnion.bulk[someIndex].x); + // + // Achieving the latter code requires that rather than blindly translate + // the `extractTaggedUnionPayload` instruction into a semantically equiavlent + // value (which might lead to a big copy in the end), we should transitively + // chase down any "access chains" off of `inst` and see what leaf values are + // actually needed, and generated more tailored extraction logic for just + // the elements/fields that actually get referenced. + // + // The more refined approach can be built on top of many of the same primitives, + // so for now we will resign ourselves to the simpler but potentially less + // efficient approach. + + // Now that we've extracted the value for the payload from the fields of + // the replacement struct, we can use that extracted value to replace + // this instruction, and schedule the original instruction for removal. + // + inst->replaceUsesWith(payloadVal); + instsToRemove.Add(inst); + } + break; + } + } + + // The `extractPayload` operation is the most important bit of translation we + // need to do to make unions work. We have as input the following: + // + IRInst* extractPayload( + + // - Information about a tagged union type and its layout. + TaggedUnionInfo* taggedUnionInfo, + + // - A single value of that tagged unon type. + IRInst* taggedUnionVal, + + // - Type type of some "payload" field we want to extract from the union. + IRType* payloadType, + + // - The memory layout of that payload type. + TypeLayout* payloadTypeLayout, + + // - The byte offset at which we want to fetch the payload. + UInt64 payloadOffset) + { + // We are going to be building some IR code no matter what. + // + auto builder = getBuilder(); + + // The basic approach here will be to look at the type we + // are trying to extract from the union, and whenever possible + // recursively walk its structure so that we can express things + // in terms of extraction of smaller/simpler types. + // + if( auto irStructType = as<IRStructType>(payloadType) ) + { + // A structure type is a nice recursive case: we simply + // want to extract each of its field recursively, and + // then construct a fresh value of the `struct` type. + + // In all of the cases of this function we expect/require + // there to be complete type layout information for the + // types involved. + // + auto structTypeLayout = dynamic_cast<StructTypeLayout*>(payloadTypeLayout); + SLANG_ASSERT(structTypeLayout); + + // We are going to emit code to extract each of the fields + // and collect them to use as operands to a `makeStruct`. + // + List<IRInst*> fieldVals; + + // We need to walk over the fields in the order the IR expects them + UInt fieldCounter = 0; + for( auto irField : irStructType->getFields() ) + { + IRType* fieldType = irField->getFieldType(); + + // TODO: We need to confirm/enforce that the fields of the + // IR struct and the fields of the layout still align. + // + UInt fieldIndex = fieldCounter++; + auto fieldLayout = structTypeLayout->fields[fieldIndex]; + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + + // The offset of the field can be computed from the base + // offset passed in, plus the reflection data for the field. + // + UInt64 fieldOffset = payloadOffset; + if(auto resInfo = fieldLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + fieldOffset += resInfo->index; + + // We make a recursive call to extract each field, expecting + // that this will bottom out eventually. + // + IRInst* fieldVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + fieldType, + fieldTypeLayout, + fieldOffset); + fieldVals.Add(fieldVal); + } + + // The final value is then just a new struct constructed from + // the extracted field values. + // + auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals); + return payloadVal; + } + else if( auto vecType = as<IRVectorType>(payloadType) ) + { + auto elementType = vecType->getElementType(); + + // We expect that by the time we are desugaring union types + // all vector types have literal constant values for their + // element count. + // + auto elementCountVal = vecType->getElementCount(); + auto elementCountConst = as<IRIntLit>(elementCountVal); + SLANG_ASSERT(elementCountConst); + UInt elementCount = UInt(elementCountConst->getValue()); + + // HACK: There is currently no `VectorTypeLayout` and thus + // no way to query the layout of the elements of a vector + // type. Until that gets added we will kludge things here. + // + TypeLayout* elementTypeLayout = nullptr; + size_t elementSize = 0; + if(auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + elementSize = resInfo->count.getFiniteValue() / elementCount; + + // Similar to the `struct` case above, we will extract a + // value for each element of the vector, and then use + // `makeVector` to construct the result value. + // + List<IRInst*> elementVals; + for(UInt ii = 0; ii < elementCount; ++ii) + { + auto elementVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + elementType, + elementTypeLayout, + payloadOffset + ii*elementSize); + elementVals.Add(elementVal); + } + return builder->emitMakeVector(vecType, elementVals); + } + else if( auto matType = as<IRMatrixType>(payloadType) ) + { + SLANG_UNIMPLEMENTED_X("matrix in union type"); + } + else if( auto arrayType = as<IRArrayType>(payloadType) ) + { + SLANG_UNIMPLEMENTED_X("array in union type"); + } + else + { + // If none of the above cases match, then we assume that + // we have an individual scalar field that we need to fetch. + // + UInt64 payloadSize = 0; + if( auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + // TODO: somebody before this point should generate an error if + // we have a `union` type that contains a potentially unbounded + // amount of data. + // + payloadSize = resInfo->count.getFiniteValue(); + } + + if( payloadSize != 4 ) + { + // TODO: We should handle the case of 64-bit fields by fetching + // two `uint` values to form a `uint2`, and then using an + // appropriate bit-cast to get from `uint2` to, e.g., `double`. + // + // The case of 16-bit and smaller fields is more troublesome, but + // in the worst case we can load a `uint` and then use bitwise + // ops to extract what we need before bitcasting. + // + // The right long-term solution is for downstream languages to have + // better support for raw memory addressing. + + SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes"); + } + + // We know that we want to fetch a value of size `payloadSize`, and + // we have a known base value and an initial offset into it. + // + IRInst* baseVal = taggedUnionVal; + UInt64 offset = payloadOffset; + + // We are going to refine our `baseVal` and `offset` as we go, by + // trying to narrow down the data we will access in the `struct` + // type that will provide storage for the union. + // + // The first thing we want to check is if the value sits in the + // "bulk" part of the storage, or the "rest." + // + UInt64 bulkSize = taggedUnionInfo->bulkSize; + if( offset < bulkSize ) + { + // If the value starts in the bulk area, then the whole + // thing had better fit in the bulk area. The 16-byte + // granularity rules for constant buffers should ensure + // this property for us on current targets. + // + SLANG_ASSERT(offset + payloadSize <= bulkSize); + + // Since we know we'll be accessing the bulk storage, + // we will extract it here. The extracted field will + // be our new base value, but the `offset` doesn't need + // to be updated since the bulk field sits at offset 0. + // + baseVal = builder->emitFieldExtract( + taggedUnionInfo->bulkFieldType, + baseVal, + taggedUnionInfo->bulkFieldKey); + + // The bulk storage could be an array, if there are 32 + // or more bytes of bulk storage. + // + if( auto baseArrayType = as<IRArrayType>(baseVal->getDataType()) ) + { + // If an array was allocated for bulk storage then + // our leaf value resides entirely within a single + // element (due to constant buffer layout rules), + // and so we will fetch the appropriate element here. + // + // We will change our `baseVal` to the extracted element, + // and then also adjust our `offset` to be relative + // to that element. + // + size_t bulkElementSize = 16; + auto index = offset / bulkElementSize; + baseVal = builder->emitElementExtract( + baseArrayType->getElementType(), + baseVal, + builder->getIntValue(builder->getIntType(), index)); + offset -= index*bulkElementSize; + } + } + else + { + // If the offset of the field we want is past the end of + // the bulk field then it must sit inside of the rest field, + // and we'll extract it here. This establishes a new + // base value, and we adjust the `offset` to be relative + // to the rest field (which starts at an offset equal to `bulkSize`). + // + baseVal = builder->emitFieldExtract( + taggedUnionInfo->restFieldType, + baseVal, + taggedUnionInfo->restFieldKey); + offset -= bulkSize; + } + + // We've now extracted a field that could be either a scalar or + // a vector, and we have an offset into it. In the case where + // the base value is a vector, we will extract out the appropriate + // element. + // + if( auto baseVecType = as<IRVectorType>(baseVal->getDataType()) ) + { + size_t vecElementSize = 4; + auto index = offset / vecElementSize; + baseVal = builder->emitElementExtract( + baseVecType->getElementType(), + baseVal, + builder->getIntValue(builder->getIntType(), index)); + offset -= index*vecElementSize; + } + + // At this point, our `baseVal` should be a single `uint`, and + // it should provide the storage for the exact thing we wanted + // to access (under the assumption that we always fetch 4 bytes + // on 4-byte alignment). + // + IRInst* payloadVal = baseVal; + SLANG_ASSERT(offset == 0); + + // TODO: we could imagine adding logic here to handle types less + // than 4 bytes in size by shifting and masking the value we + // just loaded. + + // The payload field we were trying to extract might have a type + // other than `uint`, and to handle that case we need to employ + // a bit-cast to get to the desired type. + // + if( payloadVal->getDataType() != payloadType ) + { + payloadVal = builder->emitBitCast( + payloadType, + payloadVal); + } + return payloadVal; + } + } + + // All of the logic so far as assumed we can just call `getTaggedUnionInfo` + // and have easy access to all the required information and the + // synthesized replacement type. + // + TaggedUnionInfo* getTaggedUnionInfo(IRType* type) + { + // The big picture is fairly simple: we will lazily build and + // memoize the information about tagged unions. + // + { + TaggedUnionInfo* info = nullptr; + if(mapIRTypeToTaggedUnionInfo.TryGetValue(type, info)) + return info; + } + + // When we don't find information in our memo-cache, we + // will construct it and add it to both the memo-cache + // *and* a global list of all tagged unions encountered, + // so that we can replacement them later. + // + auto info = createTaggedUnionInfo(type); + mapIRTypeToTaggedUnionInfo.Add(type, info.Ptr()); + taggedUnionInfos.Add(info); + + return info; + } + + // The actual logic for creating a `TaggedUnionInfo` is relatively + // straightforward once we've decided what information we need. + // + RefPtr<TaggedUnionInfo> createTaggedUnionInfo(IRType* type) + { + // We expect that any type used as an operation to one of the + // `extractTaggedUnion*` operations must be an IR tagged union. + // + // Note: If/when we ever expose `union`s to user and allow + // then to create *generic* tagged union types it might appear + // that this needs to be changed to account for a `specialize` + // instruction in place of a concrete tagged union, but in + // practice this pass needs to be performed late enough that + // any such generic should be fully specialized. + // + auto taggedUnionType = as<IRTaggedUnionType>(type); + SLANG_ASSERT(taggedUnionType); + + RefPtr<TaggedUnionInfo> info = new TaggedUnionInfo(); + info->taggedUnionType = taggedUnionType; + + // We are going to create an instruction to replace `type`, + // and thus will be placing it into the same parent. + // + auto builder = getBuilder(); + builder->setInsertBefore(type); + + // A tagged union type will be replaced with an ordinary + // `struct` type with fields to store all the relevant + // data from any of the cases, plus a tag field. + // + auto structType = builder->createStructType(); + info->replacementInst = structType; + + // We require/expect the earlier code generation steps to have + // assocaited a layout with every tagged union that appears in + // the code. + // + auto layoutDecoration = type->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(layoutDecoration); + auto layout = layoutDecoration->getLayout(); + SLANG_ASSERT(layout); + auto taggedUnionTypeLayout = dynamic_cast<TaggedUnionTypeLayout*>(layout); + SLANG_ASSERT(taggedUnionTypeLayout); + + info->taggedUnionTypeLayout = taggedUnionTypeLayout; + + // The size of the "payload" for the different cases (everything but + // the tag) is taken to be the offset of the tag itself. + // + // TODO: this might be inaccurate if the payload size isn't a multiple + // of the tag's alignment. We should deal with that when/if we support + // types smaller than 4 bytes in unions. + // + auto payloadSize = taggedUnionTypeLayout->tagOffset.getFiniteValue(); + + // We are going to be construction IR code that makes use of the `int` + // and `uint` types in several cases, so we go ahead and get a pointer + // to those types here. + // + auto intType = getBuilder()->getIntType(); + auto uintType = getBuilder()->getBasicType(BaseType::UInt); + + // For now we will use a simple stragegy for how we encode a union, + // which depends only on the total number of bytes needed, and not + // on the makeup of the values being stored. + // + // We will start by allocating one or more `uint4` values (in an + // array for the "or more" case) to hold the bulk of any large + // payload value. + // + size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets + auto bulkVectorCount = payloadSize / bulkVectorSize; + auto bulkFieldSize = bulkVectorCount * bulkVectorSize; + if( bulkVectorCount ) + { + IRType* bulkFieldType = builder->getVectorType( + uintType, + builder->getIntValue(intType, 4)); + + if( bulkVectorCount > 1 ) + { + bulkFieldType = builder->getArrayType( + bulkFieldType, + builder->getIntValue(intType, bulkVectorCount)); + } + + auto bulkFieldKey = builder->createStructKey(); + builder->createStructField(structType, bulkFieldKey, bulkFieldType); + + info->bulkFieldKey = bulkFieldKey; + info->bulkFieldType = bulkFieldType; + } + info->bulkSize = bulkFieldSize; + + // The rest of the data (anything that doesn't fit in the bulk field), + // will get allocated into a single scalar or vector of `uint`. + // + auto restSize = payloadSize - bulkFieldSize; + if( restSize ) + { + size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets + auto restElementCount = restSize / restElementSize; + auto restFieldSize = restElementSize * restElementCount; + SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity + + IRType* restFieldType = uintType; + if( restElementCount > 1 ) + { + restFieldType = builder->getVectorType( + restFieldType, + builder->getIntValue(intType, restElementCount)); + } + + auto restFieldKey = builder->createStructKey(); + builder->createStructField(structType, restFieldKey, restFieldType); + + info->restFieldKey = restFieldKey; + info->restFieldType = restFieldType; + info->restSize = restFieldSize; + } + + // Finally, we add a field to represent the tag. + // + auto tagFieldType = uintType; + auto tagFieldKey = builder->createStructKey(); + builder->createStructField(structType, tagFieldKey, tagFieldType); + + info->tagFieldKey = tagFieldKey; + + return info; + } +}; + +void desugarUnionTypes( + IRModule* module) +{ + DesugarUnionTypesContext context; + context.module = module; + + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/ir-union.h b/source/slang/ir-union.h new file mode 100644 index 000000000..58de4e81e --- /dev/null +++ b/source/slang/ir-union.h @@ -0,0 +1,18 @@ +// ir-union.h +#pragma once + +namespace Slang { + +struct IRModule; + + /// Desugar any unions types, and code using them, in `module` + /// + /// Union types will be replaced with ordinary `struct` types that store + /// the data of the underlying type as a "bag of bits" and references + /// to cases of the union will be replaced with logic to extract the + /// relevant bits. + /// +void desugarUnionTypes( + IRModule* module); + +} // namespace Slang diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 48716bd87..8f33034b7 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -1731,6 +1731,18 @@ namespace Slang operands); } + IRType* IRBuilder::getTaggedUnionType( + UInt caseCount, + IRType* const* caseTypes) + { + return (IRType*) findOrEmitHoistableInst( + this, + getTypeKind(), + kIROp_TaggedUnionType, + caseCount, + (IRInst* const*) caseTypes); + } + void IRBuilder::setDataType(IRInst* inst, IRType* dataType) { if (auto oldRateQualifiedType = as<IRRateQualifiedType>(inst->getFullType())) @@ -2684,6 +2696,50 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitExtractTaggedUnionTag( + IRInst* val) + { + auto inst = createInst<IRInst>( + this, + kIROp_ExtractTaggedUnionTag, + getBasicType(BaseType::UInt), + val); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitExtractTaggedUnionPayload( + IRType* type, + IRInst* val, + IRInst* tag) + { + auto inst = createInst<IRInst>( + this, + kIROp_ExtractTaggedUnionPayload, + type, + val, + tag); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBitCast( + IRType* type, + IRInst* val) + { + auto inst = createInst<IRInst>( + this, + kIROp_BitCast, + type, + val); + addInst(inst); + return inst; + } + + // + // Decorations + // + IRDecoration* IRBuilder::addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount) { auto decoration = createInstWithTrailingArgs<IRDecoration>( @@ -3796,6 +3852,14 @@ namespace Slang } } + void IRInst::transferDecorationsTo(IRInst* target) + { + while( auto decoration = getFirstDecoration() ) + { + decoration->removeFromParent(); + decoration->insertAtStart(target); + } + } bool IRInst::mightHaveSideEffects() { diff --git a/source/slang/ir.h b/source/slang/ir.h index de644a709..e62bb2249 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -373,6 +373,9 @@ struct IRInst // for those values. void removeArguments(); + /// Transfer any decorations of this instruction to the `target` instruction. + void transferDecorationsTo(IRInst* target); + /// Does this instruction have any uses? bool hasUses() const { return firstUse != nullptr; } @@ -959,18 +962,23 @@ struct IRStructField : IRInst // *not* contain the keys, because code needs to be able to // reference the keys from scopes outside of the struct. // -struct IRStructType : IRInst +struct IRStructType : IRType { IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); } IR_LEAF_ISA(StructType) }; -struct IRInterfaceType : IRInst +struct IRInterfaceType : IRType { IR_LEAF_ISA(InterfaceType) }; +struct IRTaggedUnionType : IRType +{ + IR_LEAF_ISA(TaggedUnionType) +}; + /// @brief A global value that potentially holds executable code. /// struct IRGlobalValueWithCode : IRInst diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index b1bc63fa1..3e981dbc5 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1121,6 +1121,259 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower requirementKey)); } + LoweredValInfo visitTaggedUnionSubtypeWitness( + TaggedUnionSubtypeWitness* val) + { + // The sub-type in this case is a tagged union `A | B | ...`, + // and the witness holds an array of witnesses showing that each + // "case" (`A`, `B`, etc.) is a subtype of the super-type. + + // We will start by getting the IR-level representation of the + // sub type (the tagged union type). + // + auto irTaggedUnionType = lowerType(context, val->sub); + + // We can turn each of those per-case witnesses into a witness + // table value: + // + auto caseCount = val->caseWitnesses.Count(); + List<IRInst*> caseWitnessTables; + for( auto caseWitness : val->caseWitnesses ) + { + auto caseWitnessTable = lowerSimpleVal(context, caseWitness); + caseWitnessTables.Add(caseWitnessTable); + } + + // Now we need to synthesize a witness table for the tagged union + // value, showing how it can implement all of the requirements + // of the super type by delegating to the appropriate implementation + // on a per-case basis. + // + // We will assume here that the super-type is an interface, and it + // will be left to the front-end to ensure this property. + // + auto supDeclRefType = val->sup->As<DeclRefType>(); + if(!supDeclRefType) + { + SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + auto supInterfaceDeclRef = supDeclRefType->declRef.As<InterfaceDecl>(); + if( !supInterfaceDeclRef ) + { + SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + auto irWitnessTable = getBuilder()->createWitnessTable(); + + // Now we will iterate over the requirements (members) of the + // interface and try to synthesize an appropriate value for each. + // + for( auto reqDeclRef : getMembers(supInterfaceDeclRef) ) + { + // TODO: if there are any members we shouldn't process as a requirement, + // then we should detect and skip them here. + // + + // Every interface requirement will have a unique key that is used + // when looking up the requirement in a concrete witness table. + // + auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl()); + + // We expect that each of the witness tables in `caseWitnessTables` + // will have an entry to match these keys. However, we may not + // have a concrete `IRWitnessTable` for each of the case types, either + // because they are a specialization of a generic (so that the witness + // table reference is a `specialize` instruction at this point), or + // they are a type external to this module (so that we have a declaration + // rather than a definition of the witness table). + + // Our task is to create an IR value that can satisfy the interface + // requirement for the tagged union type, by appropriately delegating + // to the implementations of the same requirement in the case types. + // + IRInst* irSatisfyingVal = nullptr; + + + + if(auto callableDeclRef = reqDeclRef.As<CallableDecl>()) + { + // We have something callable, so we need to synthesize + // a function to satisfy it. + // + auto irFunc = getBuilder()->createFunc(); + irSatisfyingVal = irFunc; + + IRBuilder subBuilderStorage; + auto subBuilder = &subBuilderStorage; + subBuilder->sharedBuilder = getBuilder()->sharedBuilder; + subBuilder->setInsertInto(irFunc); + + // We will start by setting up the function parameters, + // which live in the entry block of the IR function. + // + auto entryBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(entryBlock); + + // Create a `this` parameter of the tagged-union type. + // + // TODO: need to handle the `[mutating]` case here... + // + auto irThisType = irTaggedUnionType; + auto irThisParam = subBuilder->emitParam(irThisType); + + List<IRType*> irParamTypes; + irParamTypes.Add(irThisType); + + // Create the remaining parameters of the callable, + // using a decl-ref specialized to the tagged union + // type (so that things like associated types are + // mapped to the correct witness value). + // + List<IRParam*> irParams; + for( auto paramDeclRef : getMembersOfType<ParamDecl>(callableDeclRef) ) + { + // TODO: need to handle `out` and `in out` here. Over all + // there is a lot of duplication here with the existing logic + // for emitting the signature of a `CallableDecl`, and we should + // try to re-use that if at all possible. + // + auto irParamType = lowerType(context, GetType(paramDeclRef)); + auto irParam = subBuilder->emitParam(irParamType); + + irParams.Add(irParam); + irParamTypes.Add(irParamType); + } + + auto irResultType = lowerType(context, GetResultType(callableDeclRef)); + + auto irFuncType = subBuilder->getFuncType( + irParamTypes, + irResultType); + irFunc->setFullType(irFuncType); + + // The first thing our function needs to do is extract the tag + // from the incoming `this` parameter. + // + auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam); + + // Next we want to emit a `switch` on the tag value, but before we + // do that we need to generate the code for each of the cases so that + // our `switch` has somewhere to branch to. + // + List<IRInst*> switchCaseOperands; + + IRBlock* defaultLabel = nullptr; + + for( UInt ii = 0; ii < caseCount; ++ii ) + { + auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii); + + subBuilder->setInsertInto(irFunc); + auto caseLabel = subBuilder->emitBlock(); + + if(!defaultLabel) + defaultLabel = caseLabel; + + switchCaseOperands.Add(caseTag); + switchCaseOperands.Add(caseLabel); + + subBuilder->setInsertInto(caseLabel); + + // We need to look up the satisfying value for this interface + // requirement on the witness table of the particular case value. + // + // We already have the witness table, and the requirement key is + // just `irReqKey`. + // + auto caseWitnessTable = caseWitnessTables[ii]; + + // The subtle bit here is determining the type we expect the + // satisfying value to have, since that depends on the actual + // type that is satisfying the requirement. + // + IRType* caseResultType = irResultType; + IRType* caseFuncType = nullptr; + auto caseFunc = subBuilder->emitLookupInterfaceMethodInst( + caseFuncType, + caseWitnessTable, + irReqKey); + + // We are going to emit a `call` to the satisfying value + // for the case type, so we will collect the arguments for that call. + // + List<IRInst*> caseArgs; + + // The `this` argument to the call will need to represent the + // appropriate field of our tagged union. + // + IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii); + auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload( + caseThisType, + irThisParam, caseTag); + caseArgs.Add(caseThisArg); + + // The remaining arguments to the call will just be forwarded from + // the parameters of the wrapper functon. + // + // TODO: This would need to change if/when we started allowing `This` type + // or assocaited-type parameters to be used at call sites where a tagged + // union is used. + // + for( auto param : irParams ) + { + caseArgs.Add(param); + } + + auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs); + + if( as<IRVoidType>(irResultType->getDataType()) ) + { + subBuilder->emitReturn(); + } + else + { + subBuilder->emitReturn(caseCall); + } + } + + // We will create a block to represent the supposedly-unreachable + // code that will run if no `case` matches. + // + subBuilder->setInsertInto(irFunc); + auto invalidLabel = subBuilder->emitBlock(); + subBuilder->setInsertInto(invalidLabel); + subBuilder->emitUnreachable(); + + if(!defaultLabel) defaultLabel = invalidLabel; + + // Now we have enough information to go back and emit the `switch` instruction + // into the entry block. + subBuilder->setInsertInto(entryBlock); + subBuilder->emitSwitch( + irTagVal, // value to `switch` on + invalidLabel, // `break` label (block after the `switch` statement ends) + defaultLabel, // `default` label (where to go if no `case` matches) + switchCaseOperands.Count(), + switchCaseOperands.Buffer()); + } + else + { + // TODO: We need to handle other cases of interface requirements. + SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + // Once we've generating a value to satisfying the requirement, we install + // it into the witness table for our tagged-union type. + // + getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal); + } + + return LoweredValInfo::simple(irWitnessTable); + } + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { // TODO: it is a bit messy here that the `ConstantIntVal` representation @@ -1304,6 +1557,37 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal)); } + LoweredValInfo visitTaggedUnionType(TaggedUnionType* type) + { + // A tagged union type will lower into an IR `union` over the cases, + // along with an IR `struct` with a field for the union and a tag. + // (Note: we are placing the tag after the payload to avoid padding + // in the case where the payload is more aligned than the tag) + // + // TODO: should we be lowering directly like this, or have + // an IR-level representation of tagged unions? + // + + List<IRType*> irCaseTypes; + for(auto caseType : type->caseTypes) + { + auto irCaseType = lowerType(context, caseType); + irCaseTypes.Add(irCaseType); + } + + auto irType = getBuilder()->getTaggedUnionType(irCaseTypes); + if(!irType->findDecoration<IRLinkageDecoration>()) + { + // We need a way for later passes to attach layout information + // to this type, so we will give it a mangled name here. + // + getBuilder()->addExportDecoration( + irType, + getMangledTypeName(type).getUnownedSlice()); + } + return LoweredValInfo::simple(irType); + } + // We do not expect to encounter the following types in ASTs that have // passed front-end semantic checking. #define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); } @@ -2337,6 +2621,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + LoweredValInfo visitAssignExpr(AssignExpr* expr) { // Because our representation of lowered "values" diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index d78f3321a..8ad0bc9f5 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -144,6 +144,15 @@ namespace Slang emitSimpleIntVal(context, arrType->ArrayLength); emitType(context, arrType->baseType); } + else if( auto taggedUnionType = dynamic_cast<TaggedUnionType*>(type) ) + { + emitRaw(context, "u"); + for( auto caseType : taggedUnionType->caseTypes ) + { + emitType(context, caseType); + } + emitRaw(context, "U"); + } else { SLANG_UNEXPECTED("unimplemented case in mangling"); diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index b8df39ade..3f5c01d09 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -2429,10 +2429,20 @@ static void collectEntryPointParameters( entryPointLayout->entryPoint = entryPointFuncDecl; context->entryPointLayout = entryPointLayout; - - context->shared->programLayout->entryPoints.Add(entryPointLayout); + // Note: this isn't really the best place for this logic to sit, + // but it is the simplest place where we have a direct correspondance + // between a single `EntryPointRequest` and its matching `EntryPointLayout`, + // so we'll use it. + // + for( auto taggedUnionType : entryPoint->taggedUnionTypes ) + { + auto substType = taggedUnionType->Substitute(typeSubst).As<Type>(); + auto typeLayout = CreateTypeLayout(context->layoutContext, substType); + entryPointLayout->taggedUnionTypeLayouts.Add(typeLayout); + } + // Okay, we seemingly have an entry-point function, and now we need to collect info on its parameters too // // TODO: Long-term we probably want complete information on all inputs/outputs of an entry point, diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index d0e965970..1afc8ebcf 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -1769,6 +1769,25 @@ namespace Slang return typeExpr; } + static RefPtr<Expr> parseTaggedUnionType(Parser* parser) + { + RefPtr<TaggedUnionTypeExpr> taggedUnionType = new TaggedUnionTypeExpr(); + + parser->ReadToken(TokenType::LParent); + while(!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto caseType = parser->ParseTypeExp(); + taggedUnionType->caseTypes.Add(caseType); + + if(AdvanceIf(parser, TokenType::RParent)) + break; + + parser->ReadToken(TokenType::Comma); + } + + return taggedUnionType; + } + static TypeSpec parseTypeSpec(Parser* parser) { TypeSpec typeSpec; @@ -1812,6 +1831,11 @@ namespace Slang typeSpec.expr = createDeclRefType(parser, decl); return typeSpec; } + else if(AdvanceIf(parser, "__TaggedUnion")) + { + typeSpec.expr = parseTaggedUnionType(parser); + return typeSpec; + } Token typeName = parser->ReadToken(TokenType::Identifier); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index d516d0cb2..38b417960 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -827,7 +827,7 @@ int CompileRequest::addEntryPoint( entryPoint->profile = entryPointProfile; entryPoint->translationUnitIndex = translationUnitIndex; for (auto typeName : genericTypeNames) - entryPoint->genericParameterTypeNames.Add(typeName); + entryPoint->genericArgStrings.Add(typeName); auto translationUnit = translationUnits[translationUnitIndex].Ptr(); translationUnit->entryPoints.Add(entryPoint); diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 132a56ea5..42b1a449d 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -198,6 +198,7 @@ <ClInclude Include="ir-specialize-resources.h" /> <ClInclude Include="ir-specialize.h" /> <ClInclude Include="ir-ssa.h" /> + <ClInclude Include="ir-union.h" /> <ClInclude Include="ir-validate.h" /> <ClInclude Include="ir.h" /> <ClInclude Include="legalize-types.h" /> @@ -252,6 +253,7 @@ <ClCompile Include="ir-specialize-resources.cpp" /> <ClCompile Include="ir-specialize.cpp" /> <ClCompile Include="ir-ssa.cpp" /> + <ClCompile Include="ir-union.cpp" /> <ClCompile Include="ir-validate.cpp" /> <ClCompile Include="ir.cpp" /> <ClCompile Include="legalize-types.cpp" /> @@ -314,4 +316,4 @@ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <ImportGroup Label="ExtensionTargets"> </ImportGroup> -</Project>
\ No newline at end of file +</Project>
\ No newline at end of file diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 390c0cc5f..aeb7a1b3e 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -93,6 +93,9 @@ <ClInclude Include="ir-ssa.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="ir-union.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="ir-validate.h"> <Filter>Header Files</Filter> </ClInclude> @@ -251,6 +254,9 @@ <ClCompile Include="ir-ssa.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="ir-union.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="ir-validate.cpp"> <Filter>Source Files</Filter> </ClCompile> diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 320fc576b..fe8b2c4fe 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -2539,7 +2539,160 @@ void Type::accept(IValVisitor* visitor, void* extra) return substValue; } + // + // TaggedUnionType + // + + String TaggedUnionType::ToString() + { + String result; + result.append("__TaggedUnion("); + bool first = true; + for( auto caseType : caseTypes ) + { + if(!first) result.append(", "); + first = false; + + result.append(caseType->ToString()); + } + result.append(")"); + return result; + } + + bool TaggedUnionType::EqualsImpl(Type* type) + { + auto taggedUnion = type->As<TaggedUnionType>(); + if(!taggedUnion) + return false; + + auto caseCount = caseTypes.Count(); + if(caseCount != taggedUnion->caseTypes.Count()) + return false; + + for( UInt ii = 0; ii < caseCount; ++ii ) + { + if(!caseTypes[ii]->Equals(taggedUnion->caseTypes[ii])) + return false; + } + return true; + } + + int TaggedUnionType::GetHashCode() + { + int hashCode = 0; + for( auto caseType : caseTypes ) + { + hashCode = combineHash(hashCode, caseType->GetHashCode()); + } + return hashCode; + } + + RefPtr<Type> TaggedUnionType::CreateCanonicalType() + { + RefPtr<TaggedUnionType> canType = new TaggedUnionType(); + canType->setSession(getSession()); + + for( auto caseType : caseTypes ) + { + auto canCaseType = caseType->GetCanonicalType(); + canType->caseTypes.Add(canCaseType); + } + + return canType; + } + + RefPtr<Val> TaggedUnionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + + List<RefPtr<Type>> substCaseTypes; + for( auto caseType : caseTypes ) + { + substCaseTypes.Add(caseType->SubstituteImpl(subst, &diff).As<Type>()); + } + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr<TaggedUnionType> substType = new TaggedUnionType(); + substType->setSession(getSession()); + substType->caseTypes.SwapWith(substCaseTypes); + return substType; + } + +// +// TaggedUnionSubtypeWitness +// + + +bool TaggedUnionSubtypeWitness::EqualsVal(Val* val) +{ + auto taggedUnionWitness = val->dynamicCast<TaggedUnionSubtypeWitness>(); + if(!taggedUnionWitness) + return false; + + auto caseCount = caseWitnesses.Count(); + if(caseCount != taggedUnionWitness->caseWitnesses.Count()) + return false; + + for(UInt ii = 0; ii < caseCount; ++ii) + { + if(!caseWitnesses[ii]->EqualsVal(taggedUnionWitness->caseWitnesses[ii])) + return false; + } + + return true; +} + +String TaggedUnionSubtypeWitness::ToString() +{ + String result; + result.append("TaggedUnionSubtypeWitness("); + bool first = true; + for( auto caseWitness : caseWitnesses ) + { + if(!first) result.append(", "); + first = false; + + result.append(caseWitness->ToString()); + } + return result; +} + +int TaggedUnionSubtypeWitness::GetHashCode() +{ + int hash = 0; + for( auto caseWitness : caseWitnesses ) + { + hash = combineHash(hash, caseWitness->GetHashCode()); + } + return hash; +} + +RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substSub = sub->SubstituteImpl(subst, &diff).As<Type>(); + auto substSup = sup->SubstituteImpl(subst, &diff).As<Type>(); + + List<RefPtr<Val>> substCaseWitnesses; + for( auto caseWitness : caseWitnesses ) + { + substCaseWitnesses.Add(caseWitness->SubstituteImpl(subst, &diff)); + } + + if(!diff) + return this; + (*ioDiff)++; + RefPtr<TaggedUnionSubtypeWitness> substWitness = new TaggedUnionSubtypeWitness(); + substWitness->sub = substSub; + substWitness->sup = substSup; + substWitness->caseWitnesses.SwapWith(substCaseWitnesses); + return substWitness; } +} // namespace Slang diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index d9bfecc03..2318c6933 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -457,3 +457,21 @@ RAW( virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; ) END_SYNTAX_CLASS() + + /// A tagged union of zero or more other types. +SYNTAX_CLASS(TaggedUnionType, Type) +RAW( + /// The distinct "cases" the tagged union can store. + /// + /// For each type in this array, the array index is the + /// tag value for that case. + /// + List<RefPtr<Type>> caseTypes; + + virtual String ToString() override; + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual RefPtr<Type> CreateCanonicalType() override; + virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; +) +END_SYNTAX_CLASS() diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index f5cd518b8..05c61e706 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -2349,6 +2349,101 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } + else if( auto taggedUnionType = type->As<TaggedUnionType>() ) + { + // A tagged union type needs to be laid out as the maximum + // size of any constituent type. + // + // In practice, only a tagged union of uniform data will + // work, but for now we will compute the maximum usage + // for each resource kind for generality. + // + // For the uniform data we will start with a size + // of zero and an alignment of one for our base case + // (this is what a tagged union of no cases would consume). + // + UniformLayoutInfo info(0, 1); + + // If we are being asked to construct a full `TypeLayout` + // object, then we'll allocate it up front. + // + RefPtr<TaggedUnionTypeLayout> taggedUnionLayout; + if( outTypeLayout ) + { + taggedUnionLayout = new TaggedUnionTypeLayout(); + taggedUnionLayout->type = type; + taggedUnionLayout->rules = rules; + *outTypeLayout = taggedUnionLayout; + } + + // Now we iterate over the case types and see if they + // change our computed maximum size/alignement. + // + for( auto caseType : taggedUnionType->caseTypes ) + { + RefPtr<TypeLayout> caseTypeLayout; + UniformLayoutInfo caseTypeInfo = GetLayoutImpl(context, caseType, outTypeLayout ? &caseTypeLayout : nullptr).getUniformLayout(); + + info.size = maximum(info.size, caseTypeInfo.size); + info.alignment = std::max(info.alignment, caseTypeInfo.alignment); + + // If we are building a full `TypeLayout` we need to + // do a few more steps for each case type. + // + if( outTypeLayout ) + { + // We need to remember the layout of the case type + // on the final `TaggedUnionTypeLayout`. + // + taggedUnionLayout->caseTypeLayouts.Add(caseTypeLayout); + + // We also need to consider contributions for other + // resource kinds beyond uniform data. + // + for( auto caseResInfo : caseTypeLayout->resourceInfos ) + { + auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind); + unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count); + } + } + } + + // After we've computed the size required to hold all the + // case types, we will allocate space for the tag field. + // + // TODO: This assumes the tag will always be allocated out + // of uniform storage, which means we can't support a tagged + // union as part of a varying input/output signature. That is + // probably a valid limitation, but it should get enforced + // somewhere along the way. + // + { + // The tag is always a `uint` for now. + // + auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt); + info.size = RoundToAlignment(info.size, tagInfo.alignment); + + if( outTypeLayout ) + { + taggedUnionLayout->tagOffset = info.size; + } + + info.size += tagInfo.size; + info.alignment = std::max(info.alignment, tagInfo.alignment); + } + + // As a final step, if we are computing a full `TypeLayout` + // we will make sure that its information on uniform layout + // matches what we've computed in the `UniformLayoutInfo` we return. + // + if( outTypeLayout ) + { + taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size; + taggedUnionLayout->uniformAlignment = info.alignment; + } + + return info; + } // catch-all case in case nothing matched SLANG_ASSERT(!"unimplemented"); diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index c326ae989..da2f0e4f7 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -553,6 +553,28 @@ public: int paramIndex = 0; }; + /// Layout information for a tagged union type. +class TaggedUnionTypeLayout : public TypeLayout +{ +public: + /// The layouts of each of the case types. + /// + /// The order of entries in this array matches + /// the order of case types on the original + /// `TaggedUnionType`, and the index of a case + /// type is also the tag value for that case. + /// + List<RefPtr<TypeLayout>> caseTypeLayouts; + + /// The byte offset for the tag field. + /// + /// The tag field will always be allocted as + /// a `uint`, so we don't store a separate layout + /// for it. + /// + LayoutSize tagOffset; +}; + // Layout information for a single shader entry point // within a program // @@ -579,6 +601,12 @@ public: usesAnySampleRateInput = 0x1, }; unsigned flags = 0; + + /// Layouts for all tagged union types required by this entry point. + /// + /// These are any tagged union types used by the generic + /// arguments that this entry point is being compiled with. + List<RefPtr<TypeLayout>> taggedUnionTypeLayouts; }; class GenericParamLayout : public Layout diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index f96ee026e..f5b099079 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -136,3 +136,20 @@ RAW( ) END_SYNTAX_CLASS() +// A witness that `sub : sup`, because `sub` is a tagged union +// of the form `A | B | C | ...` and each of `A : sup`, +// `B : sup`, `C : sup`, etc. +// +SYNTAX_CLASS(TaggedUnionSubtypeWitness, SubtypeWitness) +RAW( + // Witnesses that each of the "case" types in the union + // is a subtype of `sup`. + // + List<RefPtr<Val>> caseWitnesses; + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() |
