diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/check.cpp | 163 | ||||
| -rw-r--r-- | source/slang/compiler.h | 8 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 98 | ||||
| -rw-r--r-- | source/slang/expr-defs.h | 11 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 56 | ||||
| -rw-r--r-- | source/slang/ir-link.cpp | 24 | ||||
| -rw-r--r-- | source/slang/ir-ssa.cpp | 11 | ||||
| -rw-r--r-- | source/slang/ir-union.cpp | 776 | ||||
| -rw-r--r-- | source/slang/ir-union.h | 18 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 64 | ||||
| -rw-r--r-- | source/slang/ir.h | 12 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 290 | ||||
| -rw-r--r-- | source/slang/mangle.cpp | 9 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 14 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 4 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 6 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 153 | ||||
| -rw-r--r-- | source/slang/type-defs.h | 18 | ||||
| -rw-r--r-- | source/slang/type-layout.cpp | 95 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 28 | ||||
| -rw-r--r-- | source/slang/val-defs.h | 17 |
25 files changed, 1883 insertions, 27 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 2fa4e9bc7..9aa5eb689 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -9,6 +9,8 @@ namespace Slang { + RefPtr<TypeType> getTypeType( + Type* type); /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( @@ -677,7 +679,7 @@ namespace Slang auto baseExprType = baseExpr->type.type; RefPtr<SharedTypeExpr> baseTypeExpr = new SharedTypeExpr(); baseTypeExpr->base.type = baseExprType; - baseTypeExpr->type = new TypeType(baseExprType); + baseTypeExpr->type.type = getTypeType(baseExprType); auto expr = new StaticMemberExpr(); expr->loc = loc; @@ -2071,9 +2073,7 @@ namespace Slang { RefPtr<TypeCastExpr> castExpr = createImplicitCastExpr(); - auto typeType = new TypeType(); - typeType->setSession(getSession()); - typeType->type = toType; + auto typeType = getTypeType(toType); auto typeExpr = new SharedTypeExpr(); typeExpr->type.type = typeType; @@ -5547,6 +5547,60 @@ namespace Slang return witness; } + /// Is the given interface one that a tagged-union type can conform to? + /// + /// If a tagged union type `__TaggedUnion(A,B)` is going to be + /// plugged in for a type parameter `T : IFoo` then we need to + /// be sure that the interface `IFoo` doesn't have anything + /// that could lead to unsafe/unsound behavior. This function + /// checks that all the requirements on the interfaceare safe ones. + /// + bool isInterfaceSafeForTaggedUnion( + DeclRef<InterfaceDecl> interfaceDeclRef) + { + for( auto memberDeclRef : getMembers(interfaceDeclRef) ) + { + if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) + return false; + } + + return true; + } + + /// Is the given interface requirement one that a tagged-union type can satisfy? + /// + /// Unsafe requirements include any `static` requirements, + /// any associated types, and also any requirements that make + /// use of the `This` type (once we support it). + /// + bool isInterfaceRequirementSafeForTaggedUnion( + DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<Decl> requirementDeclRef) + { + if(auto callableDeclRef = requirementDeclRef.As<CallableDecl>()) + { + // A `static` method requirement can't be satisfied by a + // tagged union, because there is no tag to dispatch on. + // + if(requirementDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + return false; + + // TODO: We will eventually want to check that any callable + // requirements do not use the `This` type or any associated + // types in ways that could lead to errors. + // + // For now we are disallowing interfaces that have associated + // types completely, and we haven't implemented the `This` + // type, so we should be safe. + + return true; + } + else + { + return false; + } + } + bool doesTypeConformToInterfaceImpl( RefPtr<Type> originalType, RefPtr<Type> type, @@ -5661,6 +5715,69 @@ namespace Slang } } } + else if(auto taggedUnionType = type->As<TaggedUnionType>()) + { + // A tagged union type conforms to an interface if all of + // the constituent types in the tagged union conform. + // + // We will iterate over the "case" types in the tagged + // union, and check if they conform to the interface. + // Along the way we will collect the conformance witness + // values *if* we are being asked to produce a witness + // value for the tagged union itself (that is, if + // `outWitness` is non-null). + // + List<RefPtr<Val>> caseWitnesses; + for(auto caseType : taggedUnionType->caseTypes) + { + RefPtr<Val> caseWitness; + + if(!doesTypeConformToInterfaceImpl( + caseType, + caseType, + interfaceDeclRef, + outWitness ? &caseWitness : nullptr, + nullptr)) + { + return false; + } + + if(outWitness) + { + caseWitnesses.Add(caseWitness); + } + } + + // We also need to validate the requirements on + // the interface to make sure that they are suitable for + // use with a tagged-union type. + // + // For example, if the interface includes a `static` method + // (which can therefore be called without a particular instance), + // then we wouldn't know what implementation of that method + // to use because there is no tag value to dispatch on. + // + // We will start out being conservative about what we accept + // here, just to keep things simple. + // + if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef)) + return false; + + // If we reach this point then we have a concrete + // witness for each of the case types, and that is + // enough to build a witness for the tagged union. + // + if(outWitness) + { + RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = new TaggedUnionSubtypeWitness(); + taggedUnionWitness->sub = taggedUnionType; + taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef); + taggedUnionWitness->caseWitnesses.SwapWith(caseWitnesses); + + *outWitness = taggedUnionWitness; + } + return true; + } // default is failure return false; @@ -8090,6 +8207,23 @@ namespace Slang return expr; } + RefPtr<Expr> visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) + { + // We have an expression of the form `__TaggedUnion(A, B, ...)` + // which will evaluate to a tagged-union type over `A`, `B`, etc. + // + RefPtr<TaggedUnionType> type = new TaggedUnionType(); + expr->type = QualType(getTypeType(type)); + + for( auto& caseTypeExpr : expr->caseTypes ) + { + caseTypeExpr = CheckProperType(caseTypeExpr); + type->caseTypes.Add(caseTypeExpr.type); + } + + return expr; + } + @@ -9039,7 +9173,7 @@ namespace Slang scopesToTry.Add(module->moduleDecl->scope); List<RefPtr<Type>> globalGenericArgs; - for (auto name : entryPoint->genericParameterTypeNames) + for (auto name : entryPoint->genericArgStrings) { // parse type name RefPtr<Type> type; @@ -9059,6 +9193,25 @@ namespace Slang return; } + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a module (and the stuff it imports). + // + // The Right Way to handle this would probably be to have each `ModuleDecl` track + // any tagged union types that get created in the context of that module, and + // then combine those lists later. + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point request. + // + if( auto taggedUnionType = type->As<TaggedUnionType>() ) + { + entryPoint->taggedUnionTypes.Add(taggedUnionType); + } + globalGenericArgs.Add(type); } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index d2072387e..41ba027c6 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -122,9 +122,8 @@ namespace Slang // The name of the entry point function (e.g., `main`) Name* name; - // The type names we want to substitute into the - // global generic type parameters - List<String> genericParameterTypeNames; + /// Source code for the generic arguments to use for the generic parameters of the entry point. + List<String> genericArgStrings; // The profile that the entry point will be compiled for // (this is a combination of the target stage, and also @@ -156,6 +155,9 @@ namespace Slang RefPtr<FuncDecl> decl; RefPtr<Substitutions> globalGenericSubst; + + /// Any tagged union types that were referenced by the generic arguments of the entry point. + List<RefPtr<TaggedUnionType>> taggedUnionTypes; }; enum class PassThroughMode : SlangPassThrough diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 76e59efa3..4a43cf8e4 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -347,7 +347,7 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") -DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.") +DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.") DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 2e801f0e3..0aebf6153 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -12,6 +12,7 @@ #include "ir-specialize.h" #include "ir-specialize-resources.h" #include "ir-ssa.h" +#include "ir-union.h" #include "ir-validate.h" #include "legalize-types.h" #include "lower-to-ir.h" @@ -3922,6 +3923,72 @@ struct EmitVisitor } break; + case kIROp_BitCast: + { + // TODO: we can simplify the logic for arbitrary bitcasts + // by always bitcasting the source to a `uint*` type (if it + // isn't already) and then bitcasting that to the destination + // type (if it isn't already `uint*`. + // + // For now we are assuming the source type is *already* + // a `uint*` type of the appropriate size. + // +// auto fromType = extractBaseType(inst->getOperand(0)->getDataType()); + auto toType = extractBaseType(inst->getDataType()); + switch(getTarget(ctx)) + { + case CodeGenTarget::GLSL: + switch(toType) + { + default: + emit("/* unhandled */"); + break; + + case BaseType::UInt: + break; + + case BaseType::Int: + emitIRType(ctx, inst->getDataType()); + break; + + case BaseType::Float: + emit("uintBitsToFloat("); + break; + } + break; + + case CodeGenTarget::HLSL: + switch(toType) + { + default: + emit("/* unhandled */"); + break; + + case BaseType::UInt: + break; + case BaseType::Int: + emit("("); + emitIRType(ctx, inst->getDataType()); + emit(")"); + break; + case BaseType::Float: + emit("asfloat"); + break; + } + break; + + + default: + SLANG_UNEXPECTED("unhandled codegen target"); + break; + } + + emit("("); + emitIROperand(ctx, inst->getOperand(0), mode, kEOp_General); + emit(")"); + } + break; + default: emit("/* unhandled */"); break; @@ -3929,6 +3996,27 @@ struct EmitVisitor maybeCloseParens(needClose); } + BaseType extractBaseType(IRType* inType) + { + auto type = inType; + for(;;) + { + if(auto irBaseType = as<IRBasicType>(type)) + { + return irBaseType->getBaseType(); + } + else if(auto vecType = as<IRVectorType>(type)) + { + type = vecType->getElementType(); + continue; + } + else + { + return BaseType::Void; + } + } + } + void emitIRInst( EmitContext* ctx, IRInst* inst, @@ -6565,6 +6653,14 @@ String emitEntryPoint( #endif validateIRModuleIfEnabled(compileRequest, irModule); + // Desguar any union types, since these will be illegal on + // various targets. + // + desugarUnionTypes(irModule); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "UNIONS DESUGARED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); // Any code that makes use of existential (interface) types @@ -6595,8 +6691,6 @@ String emitEntryPoint( // specializeGenerics(irModule); - - // Debugging code for IR transformations... #if 0 dumpIRIfEnabled(compileRequest, irModule, "SPECIALIZED"); diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h index fb29f64de..bd4ba5038 100644 --- a/source/slang/expr-defs.h +++ b/source/slang/expr-defs.h @@ -193,3 +193,14 @@ RAW( DeclRef<VarDeclBase> declRef; ) END_SYNTAX_CLASS() + + /// A type expression of the form `__TaggedUnion(A, ...)`. + /// + /// An expression of this form will resolve to a `TaggedUnionType` + /// when checked. + /// +SYNTAX_CLASS(TaggedUnionTypeExpr, Expr) +RAW( + List<TypeExp> caseTypes; +) +END_SYNTAX_CLASS()
\ No newline at end of file diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 8d7a647f4..b6f8ce547 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -41,6 +41,8 @@ INST(Nop, nop, 0, 0) INST(VectorType, Vec, 2, 0) INST(MatrixType, Mat, 3, 0) + INST(TaggedUnionType, TaggedUnion, 0, 0) + /* Rate */ INST(ConstExprRate, ConstExpr, 0, 0) INST(GroupSharedRate, GroupShared, 0, 0) @@ -406,6 +408,11 @@ INST(ExtractExistentialValue, extractExistentialValue, 1, 0) INST(ExtractExistentialType, extractExistentialType, 1, 0) INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0) +INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) +INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) + +INST(BitCast, bitCast, 1, 0) + PSEUDO_INST(Pos) PSEUDO_INST(PreInc) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 0c56e8244..8662569ba 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -663,7 +663,6 @@ struct SharedIRBuilder Dictionary<IRInstKey, IRInst*> globalValueNumberingMap; Dictionary<IRConstantKey, IRConstant*> constantMap; - Dictionary<Name*, IRWitnessTable*> witnessTableMap; }; struct IRBuilderSourceLocRAII; @@ -753,6 +752,13 @@ struct IRBuilder IRType* const* paramTypes, IRType* resultType); + IRFuncType* getFuncType( + List<IRType*> const& paramTypes, + IRType* resultType) + { + return getFuncType(paramTypes.Count(), paramTypes.Buffer(), resultType); + } + IRConstExprRate* getConstExprRate(); IRGroupSharedRate* getGroupSharedRate(); @@ -760,6 +766,16 @@ struct IRBuilder IRRate* rate, IRType* dataType); + IRType* getTaggedUnionType( + UInt caseCount, + IRType* const* caseTypes); + + IRType* getTaggedUnionType( + List<IRType*> const& caseTypes) + { + return getTaggedUnionType(caseTypes.Count(), caseTypes.Buffer()); + } + // Set the data type of an instruction, while preserving // its rate, if any. void setDataType(IRInst* inst, IRType* dataType); @@ -794,6 +810,14 @@ struct IRBuilder UInt argCount, IRInst* const* args); + IRInst* emitCallInst( + IRType* type, + IRInst* func, + List<IRInst*> const& args) + { + return emitCallInst(type, func, args.Count(), args.Buffer()); + } + IRInst* createIntrinsicInst( IRType* type, IROp op, @@ -816,6 +840,13 @@ struct IRBuilder UInt argCount, IRInst* const* args); + IRInst* emitMakeVector( + IRType* type, + List<IRInst*> const& args) + { + return emitMakeVector(type, args.Count(), args.Buffer()); + } + IRInst* emitMakeMatrix( IRType* type, UInt argCount, @@ -831,6 +862,13 @@ struct IRBuilder UInt argCount, IRInst* const* args); + IRInst* emitMakeStruct( + IRType* type, + List<IRInst*> const& args) + { + return emitMakeStruct(type, args.Count(), args.Buffer()); + } + IRInst* emitMakeExistential( IRType* type, IRInst* value, @@ -1040,6 +1078,22 @@ struct IRBuilder IRInst* param, IRInst* val); + IRInst* emitExtractTaggedUnionTag( + IRInst* val); + + IRInst* emitExtractTaggedUnionPayload( + IRType* type, + IRInst* val, + IRInst* tag); + + IRInst* emitBitCast( + IRType* type, + IRInst* val); + + // + // Decorations + // + IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount); IRDecoration* addDecoration(IRInst* value, IROp op) diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index 25d3b40b6..231ee81d2 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -1284,6 +1284,30 @@ IRFunc* specializeIRForEntryPoint( cloneValue(context, bindInst); } + // HACK: we need to ensure that any tagged union types + // in the IR module have layout information copied over to them. + // + // Note that we do this *after* cloning the `bindGlobalGenericParam` + // instructions, since we expected the tagged union type(s) to + // be referenced by them. + // + for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts ) + { + auto taggedUnionType = taggedUnionTypeLayout->getType(); + auto mangledName = getMangledTypeName(taggedUnionType); + + RefPtr<IRSpecSymbol> sym; + if(!context->getSymbols().TryGetValue(mangledName, sym)) + continue; + + IRInst* clonedType = findClonedValue(context, sym->irGlobalValue); + if(!clonedType) + continue; + + context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout); + } + + // TODO: *technically* we should consider the case where // we have global variables with initializers, since diff --git a/source/slang/ir-ssa.cpp b/source/slang/ir-ssa.cpp index d893137de..8c5db68fe 100644 --- a/source/slang/ir-ssa.cpp +++ b/source/slang/ir-ssa.cpp @@ -1068,16 +1068,9 @@ void constructSSA(ConstructSSAContext* context) newArgCount, newArgs.Buffer()); - // Swap decorations (all children, really) over to the new instruction + // Transfer decorations (a terminator should have no children) over to the new instruction. // - // TODO: We might want to encapsualte this in a reusable subroutine if - // we often need to copy decorations from one instruction to another. - // - while( auto firstChild = oldTerminator->getFirstDecoration() ) - { - firstChild->removeFromParent(); - firstChild->insertAtEnd(newTerminator); - } + oldTerminator->transferDecorationsTo(newTerminator); // A terminator better not have uses, so we shouldn't have // to replace them. diff --git a/source/slang/ir-union.cpp b/source/slang/ir-union.cpp new file mode 100644 index 000000000..c50d669e1 --- /dev/null +++ b/source/slang/ir-union.cpp @@ -0,0 +1,776 @@ +// ir-union.cpp +#include "ir-union.h" + +#include "ir.h" +#include "ir-insts.h" + +namespace Slang { + +// This file will implement a pass to replace any union types (currently +// just tagged unions) with plain `struct` types that attempt to provide +// equivalent semantics. This will necessarily be a bit fragile, and there +// will be fundamental limits to what the translation can support without +// improved features in the target shading languages/ILs. + +struct DesugarUnionTypesContext +{ + // We'll start with some basic state that we need to get the job done. + // + // This includes the IR module we are to process, as well as IR building + // state that we will initialize once and then use throughout the pass. + // + IRModule* module; + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + IRBuilder* getBuilder() { return &builderStorage; } + + // Because we will be replacing instructions that refer to unions with + // different logic, we'll want to remove the original instructions. + // However, we need to be careful about modifying the IR tree while also + // iterating it, and to keep things simple for ourselves we'll go ahead + // and build up a list of instruction to remove along the way, and then + // remove them all at the end. + // + List<IRInst*> instsToRemove; + + // The overall flow of the pass is pretty simple, so we will walk through it now. + // + void processModule() + { + // We start by initializing our IR building state. + // + sharedBuilderStorage.session = module->session; + sharedBuilderStorage.module = module; + builderStorage.sharedBuilder = &sharedBuilderStorage; + + // Next, we will search for any instruction that create or use + // union types, and process them accordingingly (usually by + // constructing a new instruction to replace them). + // + processInstRec(module->getModuleInst()); + + // Along the way we will build up a list of the tagged union + // types that we encountered, but we will refrain from replacing + // them until we are done (so that we always know that the instructions + // we process above refer to the original type, and not its + // replacement. + // + for( auto info : taggedUnionInfos ) + { + auto taggedUnionType = info->taggedUnionType; + auto replacementInst = info->replacementInst; + + // TODO: We should consider transferring decorations from the source + // type to the destination, but doing so carelessly could create + // problems, since an IR struct type shouldn't have, e.g., a + // `TaggedUnionTypeLayout` attached to it. + + taggedUnionType->replaceUsesWith(replacementInst); + taggedUnionType->removeAndDeallocate(); + } + + // As described previously, we build up the `instsToRemove` list as + // we iterate so that we can remove them all here and not risk + // modifying the IR tree while also walking it. + // + // TODO: This might be overkill and we could conceivably just be + // a bit careful in `processInstRec`. + // + for(auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } + } + + // In order to replace a (tagged) union type, we will need to know + // something about it, and we will use the `TaggedUnionInfo` type + // to collect all the relevant information. + // + struct TaggedUnionInfo : public RefObject + { + // We obviously need to know the tagged union itself, and + // we will also use this structure to track the instruction + // (an IR struct type) that will replace it. + // + IRTaggedUnionType* taggedUnionType; + IRInst* replacementInst; + + // In order to compute a suitable layout for the replacement + // `struct` type we need to know how the tagged union itself + // would be laid out in memory, so we require that all tagged + // unions in the generated IR have an associated (target-specific) + // layout. + // + TaggedUnionTypeLayout* taggedUnionTypeLayout; + + // The basic approach we will use 16-byte chunks (represented as an array + // of `uint4`s) to reprent the "bulk" of a type, and then use a single field + // that could be up to 12 bytes to represent the "rest" of the type. + // + // Note that there are deeply ingrained assumptions here that all types + // are at least four bytes in size (so that unions cannot easily + // accomodate `half` value), and that any types *larger* than four bytes + // will need to be loaded/stored via multiple 4-byte loads/stores. + // + // With the basic idea out of the way, we need an IR level field + // in our struct to hold the bulk data, which comprises a "key" for + // looking up the field, and the type of the field itself. We also + // keep track of how many bytes we put in our bulk storage. + // + // The bulk field might be: + // + // - null, if none of the case types was 16 bytes or more + // - a single `uint4` for between 16 and 31 (inclusive) bytes + // - an array of `uint4`s for 32 or more bytes + // + UInt64 bulkSize = 0; + IRInst* bulkFieldKey = nullptr; + IRType* bulkFieldType = nullptr; + + // The same basic idea then applies to the rest of the data. + // + // The "rest" field will be either be absent (if the size of the + // type was evently divisible by 16), a scalar `uint`, or else + // a 2- or 3-component vector of `uint`. + // + UInt64 restSize = 0; + IRInst* restFieldKey = nullptr; + IRType* restFieldType = nullptr; + + // Finally, since we are currently working with tagged unions, + // we need a field to hold the tag, which will always be allocated + // after the fields that hold the bulk/rest of the payload. + // + // This field is always a single `uint`. + // + // TODO: if/when we support untagged unions, they could be handled + // by having this field be null. + // + IRInst* tagFieldKey; + }; + + // We will build up a list of all the tagged union types we encounter, + // so that we can replace them with the synthesized types when we are done. + // + List<RefPtr<TaggedUnionInfo>> taggedUnionInfos; + + // It is possible that we will see the same tagged union type referenced + // many times in the IR, but we only want to synthesize the information + // above (including the various IR structures) once, so we also maintain + // a map from the original IR type to the corresponding information. + // + Dictionary<IRInst*, TaggedUnionInfo*> mapIRTypeToTaggedUnionInfo; + + // We will process all instructions in the module in a single recursive walk. + // + void processInstRec(IRInst* inst) + { + processInst(inst); + + for( auto child : inst->getChildren() ) + { + processInstRec(child); + } + } + // + // At each instruction, we will check if it is one of the union-related instructions + // we need to replace, and process it accordingly. + // + void processInst(IRInst* inst) + { + switch( inst->op ) + { + default: + // Any instruction not listed below either doesn't involve union types, + // or handles them in a hands-off fashion that we don't need to care about. + // + // E.g., a `load` of a union type from a constant buffer will turn into + // a load of the replacement `struct` type once we are done, and nothing + // needs to be done to the `load` instruction. + // + break; + + case kIROp_TaggedUnionType: + { + // We clearly need to process the tagged union type itself, but the actual + // work is handled by other functions. All we need to do here is ensure + // that the information for this type gets generated, and then we can + // rely on the main `processModule` function to do the actual replacement later. + // + auto type = cast<IRTaggedUnionType>(inst); + getTaggedUnionInfo(type); + } + break; + + case kIROp_ExtractTaggedUnionTag: + { + // The case of extracting the tag from a tagged union is relatively + // simple, because the replacement type will have a dedicated field or it. + // + // We start by finding the tagged union value the instruction is operating + // on, and then looking up the information for its type (which had + // better be a tagged union type). + // + auto taggedUnionVal = inst->getOperand(0); + auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); + + // Because the replacement type will have an explicit field for the tag, + // we can simply emit a single field-extract instruction to read its value + // out. + // + auto builder = getBuilder(); + builder->setInsertBefore(inst); + auto replacement = builder->emitFieldExtract( + inst->getFullType(), + taggedUnionVal, + taggedUnionInfo->tagFieldKey); + + // Now we can replace anything that used the original instruction with + // the new field-extract operation, and add this instruction to the + // list for later removal. + // + inst->replaceUsesWith(replacement); + instsToRemove.Add(inst); + } + break; + + case kIROp_ExtractTaggedUnionPayload: + { + // The most interesting case is when we are trying to extract a particular + // payload (one of the case types) from a union. We may need to extract + // one or more fields from the data stored in the union's replacement + // type (the bulk/rest fields), and we may also have to convert them + // to the type expected via bit-casts. + + // We can start things off easily enough by extracting the tagged union + // value being operated on, as well as the information for its type. + // + auto taggedUnionVal = inst->getOperand(0); + auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); + + // Next we need to figure out which case is being extracted from the union. + // The operand for the case tag should be a literal by construction. + // + auto caseTagVal = inst->getOperand(1); + auto caseTagConst = as<IRIntLit>(caseTagVal); + SLANG_ASSERT(caseTagConst); + + // The case type we are extracting will be the result type of the instruciton. + // + auto caseType = inst->getDataType(); + // + // The tag value itself will be the index of the case type in the union + // type (and its layout). + // + auto caseTagIndex = UInt(caseTagConst->getValue()); + + // We can use the case tag value to look up the layout for the particular + // case type we are extracting (this will allow us to resolve byte offsets + // for fields, etc.). + // + auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout; + SLANG_ASSERT(caseTagIndex < taggedUnionTypeLayout->caseTypeLayouts.Count()); + auto caseTypeLayout = taggedUnionTypeLayout->caseTypeLayouts[caseTagIndex]; + + // At this point we know the type we are trying to extract, as well + // as its layout. We will defer the actual implementation of extraction + // to a (recursive) subroutine that can extract a (sub-)field from the + // union at a given byte offset. Since we are extracting a full case + // right now, the byte offset will be zero. + // + auto payloadVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + caseType, + caseTypeLayout, + 0); + + // TODO: There is a significant flaw in the above approach when + // the case type might be (or contain) an array. If we have a setup + // like the following: + // + // union SomeUnion { float someCase[100]; ... } + // ... + // float result = someUnion.someCase[someIndex]; + // + // The current logic would desugar this into something like: + // + // struct SomeUnion { uint4 bulk[100]; ... } + // ... + // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... } + // float result = tmp[someIndex]; + // + // The result is that we copy an entire 100-element array into local memory + // just to fetch a single element, when it would be much nicer to just do: + // + // float result = asfloat(someUnion.bulk[someIndex].x); + // + // Achieving the latter code requires that rather than blindly translate + // the `extractTaggedUnionPayload` instruction into a semantically equiavlent + // value (which might lead to a big copy in the end), we should transitively + // chase down any "access chains" off of `inst` and see what leaf values are + // actually needed, and generated more tailored extraction logic for just + // the elements/fields that actually get referenced. + // + // The more refined approach can be built on top of many of the same primitives, + // so for now we will resign ourselves to the simpler but potentially less + // efficient approach. + + // Now that we've extracted the value for the payload from the fields of + // the replacement struct, we can use that extracted value to replace + // this instruction, and schedule the original instruction for removal. + // + inst->replaceUsesWith(payloadVal); + instsToRemove.Add(inst); + } + break; + } + } + + // The `extractPayload` operation is the most important bit of translation we + // need to do to make unions work. We have as input the following: + // + IRInst* extractPayload( + + // - Information about a tagged union type and its layout. + TaggedUnionInfo* taggedUnionInfo, + + // - A single value of that tagged unon type. + IRInst* taggedUnionVal, + + // - Type type of some "payload" field we want to extract from the union. + IRType* payloadType, + + // - The memory layout of that payload type. + TypeLayout* payloadTypeLayout, + + // - The byte offset at which we want to fetch the payload. + UInt64 payloadOffset) + { + // We are going to be building some IR code no matter what. + // + auto builder = getBuilder(); + + // The basic approach here will be to look at the type we + // are trying to extract from the union, and whenever possible + // recursively walk its structure so that we can express things + // in terms of extraction of smaller/simpler types. + // + if( auto irStructType = as<IRStructType>(payloadType) ) + { + // A structure type is a nice recursive case: we simply + // want to extract each of its field recursively, and + // then construct a fresh value of the `struct` type. + + // In all of the cases of this function we expect/require + // there to be complete type layout information for the + // types involved. + // + auto structTypeLayout = dynamic_cast<StructTypeLayout*>(payloadTypeLayout); + SLANG_ASSERT(structTypeLayout); + + // We are going to emit code to extract each of the fields + // and collect them to use as operands to a `makeStruct`. + // + List<IRInst*> fieldVals; + + // We need to walk over the fields in the order the IR expects them + UInt fieldCounter = 0; + for( auto irField : irStructType->getFields() ) + { + IRType* fieldType = irField->getFieldType(); + + // TODO: We need to confirm/enforce that the fields of the + // IR struct and the fields of the layout still align. + // + UInt fieldIndex = fieldCounter++; + auto fieldLayout = structTypeLayout->fields[fieldIndex]; + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + + // The offset of the field can be computed from the base + // offset passed in, plus the reflection data for the field. + // + UInt64 fieldOffset = payloadOffset; + if(auto resInfo = fieldLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + fieldOffset += resInfo->index; + + // We make a recursive call to extract each field, expecting + // that this will bottom out eventually. + // + IRInst* fieldVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + fieldType, + fieldTypeLayout, + fieldOffset); + fieldVals.Add(fieldVal); + } + + // The final value is then just a new struct constructed from + // the extracted field values. + // + auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals); + return payloadVal; + } + else if( auto vecType = as<IRVectorType>(payloadType) ) + { + auto elementType = vecType->getElementType(); + + // We expect that by the time we are desugaring union types + // all vector types have literal constant values for their + // element count. + // + auto elementCountVal = vecType->getElementCount(); + auto elementCountConst = as<IRIntLit>(elementCountVal); + SLANG_ASSERT(elementCountConst); + UInt elementCount = UInt(elementCountConst->getValue()); + + // HACK: There is currently no `VectorTypeLayout` and thus + // no way to query the layout of the elements of a vector + // type. Until that gets added we will kludge things here. + // + TypeLayout* elementTypeLayout = nullptr; + size_t elementSize = 0; + if(auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + elementSize = resInfo->count.getFiniteValue() / elementCount; + + // Similar to the `struct` case above, we will extract a + // value for each element of the vector, and then use + // `makeVector` to construct the result value. + // + List<IRInst*> elementVals; + for(UInt ii = 0; ii < elementCount; ++ii) + { + auto elementVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + elementType, + elementTypeLayout, + payloadOffset + ii*elementSize); + elementVals.Add(elementVal); + } + return builder->emitMakeVector(vecType, elementVals); + } + else if( auto matType = as<IRMatrixType>(payloadType) ) + { + SLANG_UNIMPLEMENTED_X("matrix in union type"); + } + else if( auto arrayType = as<IRArrayType>(payloadType) ) + { + SLANG_UNIMPLEMENTED_X("array in union type"); + } + else + { + // If none of the above cases match, then we assume that + // we have an individual scalar field that we need to fetch. + // + UInt64 payloadSize = 0; + if( auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + // TODO: somebody before this point should generate an error if + // we have a `union` type that contains a potentially unbounded + // amount of data. + // + payloadSize = resInfo->count.getFiniteValue(); + } + + if( payloadSize != 4 ) + { + // TODO: We should handle the case of 64-bit fields by fetching + // two `uint` values to form a `uint2`, and then using an + // appropriate bit-cast to get from `uint2` to, e.g., `double`. + // + // The case of 16-bit and smaller fields is more troublesome, but + // in the worst case we can load a `uint` and then use bitwise + // ops to extract what we need before bitcasting. + // + // The right long-term solution is for downstream languages to have + // better support for raw memory addressing. + + SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes"); + } + + // We know that we want to fetch a value of size `payloadSize`, and + // we have a known base value and an initial offset into it. + // + IRInst* baseVal = taggedUnionVal; + UInt64 offset = payloadOffset; + + // We are going to refine our `baseVal` and `offset` as we go, by + // trying to narrow down the data we will access in the `struct` + // type that will provide storage for the union. + // + // The first thing we want to check is if the value sits in the + // "bulk" part of the storage, or the "rest." + // + UInt64 bulkSize = taggedUnionInfo->bulkSize; + if( offset < bulkSize ) + { + // If the value starts in the bulk area, then the whole + // thing had better fit in the bulk area. The 16-byte + // granularity rules for constant buffers should ensure + // this property for us on current targets. + // + SLANG_ASSERT(offset + payloadSize <= bulkSize); + + // Since we know we'll be accessing the bulk storage, + // we will extract it here. The extracted field will + // be our new base value, but the `offset` doesn't need + // to be updated since the bulk field sits at offset 0. + // + baseVal = builder->emitFieldExtract( + taggedUnionInfo->bulkFieldType, + baseVal, + taggedUnionInfo->bulkFieldKey); + + // The bulk storage could be an array, if there are 32 + // or more bytes of bulk storage. + // + if( auto baseArrayType = as<IRArrayType>(baseVal->getDataType()) ) + { + // If an array was allocated for bulk storage then + // our leaf value resides entirely within a single + // element (due to constant buffer layout rules), + // and so we will fetch the appropriate element here. + // + // We will change our `baseVal` to the extracted element, + // and then also adjust our `offset` to be relative + // to that element. + // + size_t bulkElementSize = 16; + auto index = offset / bulkElementSize; + baseVal = builder->emitElementExtract( + baseArrayType->getElementType(), + baseVal, + builder->getIntValue(builder->getIntType(), index)); + offset -= index*bulkElementSize; + } + } + else + { + // If the offset of the field we want is past the end of + // the bulk field then it must sit inside of the rest field, + // and we'll extract it here. This establishes a new + // base value, and we adjust the `offset` to be relative + // to the rest field (which starts at an offset equal to `bulkSize`). + // + baseVal = builder->emitFieldExtract( + taggedUnionInfo->restFieldType, + baseVal, + taggedUnionInfo->restFieldKey); + offset -= bulkSize; + } + + // We've now extracted a field that could be either a scalar or + // a vector, and we have an offset into it. In the case where + // the base value is a vector, we will extract out the appropriate + // element. + // + if( auto baseVecType = as<IRVectorType>(baseVal->getDataType()) ) + { + size_t vecElementSize = 4; + auto index = offset / vecElementSize; + baseVal = builder->emitElementExtract( + baseVecType->getElementType(), + baseVal, + builder->getIntValue(builder->getIntType(), index)); + offset -= index*vecElementSize; + } + + // At this point, our `baseVal` should be a single `uint`, and + // it should provide the storage for the exact thing we wanted + // to access (under the assumption that we always fetch 4 bytes + // on 4-byte alignment). + // + IRInst* payloadVal = baseVal; + SLANG_ASSERT(offset == 0); + + // TODO: we could imagine adding logic here to handle types less + // than 4 bytes in size by shifting and masking the value we + // just loaded. + + // The payload field we were trying to extract might have a type + // other than `uint`, and to handle that case we need to employ + // a bit-cast to get to the desired type. + // + if( payloadVal->getDataType() != payloadType ) + { + payloadVal = builder->emitBitCast( + payloadType, + payloadVal); + } + return payloadVal; + } + } + + // All of the logic so far as assumed we can just call `getTaggedUnionInfo` + // and have easy access to all the required information and the + // synthesized replacement type. + // + TaggedUnionInfo* getTaggedUnionInfo(IRType* type) + { + // The big picture is fairly simple: we will lazily build and + // memoize the information about tagged unions. + // + { + TaggedUnionInfo* info = nullptr; + if(mapIRTypeToTaggedUnionInfo.TryGetValue(type, info)) + return info; + } + + // When we don't find information in our memo-cache, we + // will construct it and add it to both the memo-cache + // *and* a global list of all tagged unions encountered, + // so that we can replacement them later. + // + auto info = createTaggedUnionInfo(type); + mapIRTypeToTaggedUnionInfo.Add(type, info.Ptr()); + taggedUnionInfos.Add(info); + + return info; + } + + // The actual logic for creating a `TaggedUnionInfo` is relatively + // straightforward once we've decided what information we need. + // + RefPtr<TaggedUnionInfo> createTaggedUnionInfo(IRType* type) + { + // We expect that any type used as an operation to one of the + // `extractTaggedUnion*` operations must be an IR tagged union. + // + // Note: If/when we ever expose `union`s to user and allow + // then to create *generic* tagged union types it might appear + // that this needs to be changed to account for a `specialize` + // instruction in place of a concrete tagged union, but in + // practice this pass needs to be performed late enough that + // any such generic should be fully specialized. + // + auto taggedUnionType = as<IRTaggedUnionType>(type); + SLANG_ASSERT(taggedUnionType); + + RefPtr<TaggedUnionInfo> info = new TaggedUnionInfo(); + info->taggedUnionType = taggedUnionType; + + // We are going to create an instruction to replace `type`, + // and thus will be placing it into the same parent. + // + auto builder = getBuilder(); + builder->setInsertBefore(type); + + // A tagged union type will be replaced with an ordinary + // `struct` type with fields to store all the relevant + // data from any of the cases, plus a tag field. + // + auto structType = builder->createStructType(); + info->replacementInst = structType; + + // We require/expect the earlier code generation steps to have + // assocaited a layout with every tagged union that appears in + // the code. + // + auto layoutDecoration = type->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(layoutDecoration); + auto layout = layoutDecoration->getLayout(); + SLANG_ASSERT(layout); + auto taggedUnionTypeLayout = dynamic_cast<TaggedUnionTypeLayout*>(layout); + SLANG_ASSERT(taggedUnionTypeLayout); + + info->taggedUnionTypeLayout = taggedUnionTypeLayout; + + // The size of the "payload" for the different cases (everything but + // the tag) is taken to be the offset of the tag itself. + // + // TODO: this might be inaccurate if the payload size isn't a multiple + // of the tag's alignment. We should deal with that when/if we support + // types smaller than 4 bytes in unions. + // + auto payloadSize = taggedUnionTypeLayout->tagOffset.getFiniteValue(); + + // We are going to be construction IR code that makes use of the `int` + // and `uint` types in several cases, so we go ahead and get a pointer + // to those types here. + // + auto intType = getBuilder()->getIntType(); + auto uintType = getBuilder()->getBasicType(BaseType::UInt); + + // For now we will use a simple stragegy for how we encode a union, + // which depends only on the total number of bytes needed, and not + // on the makeup of the values being stored. + // + // We will start by allocating one or more `uint4` values (in an + // array for the "or more" case) to hold the bulk of any large + // payload value. + // + size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets + auto bulkVectorCount = payloadSize / bulkVectorSize; + auto bulkFieldSize = bulkVectorCount * bulkVectorSize; + if( bulkVectorCount ) + { + IRType* bulkFieldType = builder->getVectorType( + uintType, + builder->getIntValue(intType, 4)); + + if( bulkVectorCount > 1 ) + { + bulkFieldType = builder->getArrayType( + bulkFieldType, + builder->getIntValue(intType, bulkVectorCount)); + } + + auto bulkFieldKey = builder->createStructKey(); + builder->createStructField(structType, bulkFieldKey, bulkFieldType); + + info->bulkFieldKey = bulkFieldKey; + info->bulkFieldType = bulkFieldType; + } + info->bulkSize = bulkFieldSize; + + // The rest of the data (anything that doesn't fit in the bulk field), + // will get allocated into a single scalar or vector of `uint`. + // + auto restSize = payloadSize - bulkFieldSize; + if( restSize ) + { + size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets + auto restElementCount = restSize / restElementSize; + auto restFieldSize = restElementSize * restElementCount; + SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity + + IRType* restFieldType = uintType; + if( restElementCount > 1 ) + { + restFieldType = builder->getVectorType( + restFieldType, + builder->getIntValue(intType, restElementCount)); + } + + auto restFieldKey = builder->createStructKey(); + builder->createStructField(structType, restFieldKey, restFieldType); + + info->restFieldKey = restFieldKey; + info->restFieldType = restFieldType; + info->restSize = restFieldSize; + } + + // Finally, we add a field to represent the tag. + // + auto tagFieldType = uintType; + auto tagFieldKey = builder->createStructKey(); + builder->createStructField(structType, tagFieldKey, tagFieldType); + + info->tagFieldKey = tagFieldKey; + + return info; + } +}; + +void desugarUnionTypes( + IRModule* module) +{ + DesugarUnionTypesContext context; + context.module = module; + + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/ir-union.h b/source/slang/ir-union.h new file mode 100644 index 000000000..58de4e81e --- /dev/null +++ b/source/slang/ir-union.h @@ -0,0 +1,18 @@ +// ir-union.h +#pragma once + +namespace Slang { + +struct IRModule; + + /// Desugar any unions types, and code using them, in `module` + /// + /// Union types will be replaced with ordinary `struct` types that store + /// the data of the underlying type as a "bag of bits" and references + /// to cases of the union will be replaced with logic to extract the + /// relevant bits. + /// +void desugarUnionTypes( + IRModule* module); + +} // namespace Slang diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 48716bd87..8f33034b7 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -1731,6 +1731,18 @@ namespace Slang operands); } + IRType* IRBuilder::getTaggedUnionType( + UInt caseCount, + IRType* const* caseTypes) + { + return (IRType*) findOrEmitHoistableInst( + this, + getTypeKind(), + kIROp_TaggedUnionType, + caseCount, + (IRInst* const*) caseTypes); + } + void IRBuilder::setDataType(IRInst* inst, IRType* dataType) { if (auto oldRateQualifiedType = as<IRRateQualifiedType>(inst->getFullType())) @@ -2684,6 +2696,50 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitExtractTaggedUnionTag( + IRInst* val) + { + auto inst = createInst<IRInst>( + this, + kIROp_ExtractTaggedUnionTag, + getBasicType(BaseType::UInt), + val); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitExtractTaggedUnionPayload( + IRType* type, + IRInst* val, + IRInst* tag) + { + auto inst = createInst<IRInst>( + this, + kIROp_ExtractTaggedUnionPayload, + type, + val, + tag); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBitCast( + IRType* type, + IRInst* val) + { + auto inst = createInst<IRInst>( + this, + kIROp_BitCast, + type, + val); + addInst(inst); + return inst; + } + + // + // Decorations + // + IRDecoration* IRBuilder::addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount) { auto decoration = createInstWithTrailingArgs<IRDecoration>( @@ -3796,6 +3852,14 @@ namespace Slang } } + void IRInst::transferDecorationsTo(IRInst* target) + { + while( auto decoration = getFirstDecoration() ) + { + decoration->removeFromParent(); + decoration->insertAtStart(target); + } + } bool IRInst::mightHaveSideEffects() { diff --git a/source/slang/ir.h b/source/slang/ir.h index de644a709..e62bb2249 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -373,6 +373,9 @@ struct IRInst // for those values. void removeArguments(); + /// Transfer any decorations of this instruction to the `target` instruction. + void transferDecorationsTo(IRInst* target); + /// Does this instruction have any uses? bool hasUses() const { return firstUse != nullptr; } @@ -959,18 +962,23 @@ struct IRStructField : IRInst // *not* contain the keys, because code needs to be able to // reference the keys from scopes outside of the struct. // -struct IRStructType : IRInst +struct IRStructType : IRType { IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); } IR_LEAF_ISA(StructType) }; -struct IRInterfaceType : IRInst +struct IRInterfaceType : IRType { IR_LEAF_ISA(InterfaceType) }; +struct IRTaggedUnionType : IRType +{ + IR_LEAF_ISA(TaggedUnionType) +}; + /// @brief A global value that potentially holds executable code. /// struct IRGlobalValueWithCode : IRInst diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index b1bc63fa1..3e981dbc5 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1121,6 +1121,259 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower requirementKey)); } + LoweredValInfo visitTaggedUnionSubtypeWitness( + TaggedUnionSubtypeWitness* val) + { + // The sub-type in this case is a tagged union `A | B | ...`, + // and the witness holds an array of witnesses showing that each + // "case" (`A`, `B`, etc.) is a subtype of the super-type. + + // We will start by getting the IR-level representation of the + // sub type (the tagged union type). + // + auto irTaggedUnionType = lowerType(context, val->sub); + + // We can turn each of those per-case witnesses into a witness + // table value: + // + auto caseCount = val->caseWitnesses.Count(); + List<IRInst*> caseWitnessTables; + for( auto caseWitness : val->caseWitnesses ) + { + auto caseWitnessTable = lowerSimpleVal(context, caseWitness); + caseWitnessTables.Add(caseWitnessTable); + } + + // Now we need to synthesize a witness table for the tagged union + // value, showing how it can implement all of the requirements + // of the super type by delegating to the appropriate implementation + // on a per-case basis. + // + // We will assume here that the super-type is an interface, and it + // will be left to the front-end to ensure this property. + // + auto supDeclRefType = val->sup->As<DeclRefType>(); + if(!supDeclRefType) + { + SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + auto supInterfaceDeclRef = supDeclRefType->declRef.As<InterfaceDecl>(); + if( !supInterfaceDeclRef ) + { + SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + auto irWitnessTable = getBuilder()->createWitnessTable(); + + // Now we will iterate over the requirements (members) of the + // interface and try to synthesize an appropriate value for each. + // + for( auto reqDeclRef : getMembers(supInterfaceDeclRef) ) + { + // TODO: if there are any members we shouldn't process as a requirement, + // then we should detect and skip them here. + // + + // Every interface requirement will have a unique key that is used + // when looking up the requirement in a concrete witness table. + // + auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl()); + + // We expect that each of the witness tables in `caseWitnessTables` + // will have an entry to match these keys. However, we may not + // have a concrete `IRWitnessTable` for each of the case types, either + // because they are a specialization of a generic (so that the witness + // table reference is a `specialize` instruction at this point), or + // they are a type external to this module (so that we have a declaration + // rather than a definition of the witness table). + + // Our task is to create an IR value that can satisfy the interface + // requirement for the tagged union type, by appropriately delegating + // to the implementations of the same requirement in the case types. + // + IRInst* irSatisfyingVal = nullptr; + + + + if(auto callableDeclRef = reqDeclRef.As<CallableDecl>()) + { + // We have something callable, so we need to synthesize + // a function to satisfy it. + // + auto irFunc = getBuilder()->createFunc(); + irSatisfyingVal = irFunc; + + IRBuilder subBuilderStorage; + auto subBuilder = &subBuilderStorage; + subBuilder->sharedBuilder = getBuilder()->sharedBuilder; + subBuilder->setInsertInto(irFunc); + + // We will start by setting up the function parameters, + // which live in the entry block of the IR function. + // + auto entryBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(entryBlock); + + // Create a `this` parameter of the tagged-union type. + // + // TODO: need to handle the `[mutating]` case here... + // + auto irThisType = irTaggedUnionType; + auto irThisParam = subBuilder->emitParam(irThisType); + + List<IRType*> irParamTypes; + irParamTypes.Add(irThisType); + + // Create the remaining parameters of the callable, + // using a decl-ref specialized to the tagged union + // type (so that things like associated types are + // mapped to the correct witness value). + // + List<IRParam*> irParams; + for( auto paramDeclRef : getMembersOfType<ParamDecl>(callableDeclRef) ) + { + // TODO: need to handle `out` and `in out` here. Over all + // there is a lot of duplication here with the existing logic + // for emitting the signature of a `CallableDecl`, and we should + // try to re-use that if at all possible. + // + auto irParamType = lowerType(context, GetType(paramDeclRef)); + auto irParam = subBuilder->emitParam(irParamType); + + irParams.Add(irParam); + irParamTypes.Add(irParamType); + } + + auto irResultType = lowerType(context, GetResultType(callableDeclRef)); + + auto irFuncType = subBuilder->getFuncType( + irParamTypes, + irResultType); + irFunc->setFullType(irFuncType); + + // The first thing our function needs to do is extract the tag + // from the incoming `this` parameter. + // + auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam); + + // Next we want to emit a `switch` on the tag value, but before we + // do that we need to generate the code for each of the cases so that + // our `switch` has somewhere to branch to. + // + List<IRInst*> switchCaseOperands; + + IRBlock* defaultLabel = nullptr; + + for( UInt ii = 0; ii < caseCount; ++ii ) + { + auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii); + + subBuilder->setInsertInto(irFunc); + auto caseLabel = subBuilder->emitBlock(); + + if(!defaultLabel) + defaultLabel = caseLabel; + + switchCaseOperands.Add(caseTag); + switchCaseOperands.Add(caseLabel); + + subBuilder->setInsertInto(caseLabel); + + // We need to look up the satisfying value for this interface + // requirement on the witness table of the particular case value. + // + // We already have the witness table, and the requirement key is + // just `irReqKey`. + // + auto caseWitnessTable = caseWitnessTables[ii]; + + // The subtle bit here is determining the type we expect the + // satisfying value to have, since that depends on the actual + // type that is satisfying the requirement. + // + IRType* caseResultType = irResultType; + IRType* caseFuncType = nullptr; + auto caseFunc = subBuilder->emitLookupInterfaceMethodInst( + caseFuncType, + caseWitnessTable, + irReqKey); + + // We are going to emit a `call` to the satisfying value + // for the case type, so we will collect the arguments for that call. + // + List<IRInst*> caseArgs; + + // The `this` argument to the call will need to represent the + // appropriate field of our tagged union. + // + IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii); + auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload( + caseThisType, + irThisParam, caseTag); + caseArgs.Add(caseThisArg); + + // The remaining arguments to the call will just be forwarded from + // the parameters of the wrapper functon. + // + // TODO: This would need to change if/when we started allowing `This` type + // or assocaited-type parameters to be used at call sites where a tagged + // union is used. + // + for( auto param : irParams ) + { + caseArgs.Add(param); + } + + auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs); + + if( as<IRVoidType>(irResultType->getDataType()) ) + { + subBuilder->emitReturn(); + } + else + { + subBuilder->emitReturn(caseCall); + } + } + + // We will create a block to represent the supposedly-unreachable + // code that will run if no `case` matches. + // + subBuilder->setInsertInto(irFunc); + auto invalidLabel = subBuilder->emitBlock(); + subBuilder->setInsertInto(invalidLabel); + subBuilder->emitUnreachable(); + + if(!defaultLabel) defaultLabel = invalidLabel; + + // Now we have enough information to go back and emit the `switch` instruction + // into the entry block. + subBuilder->setInsertInto(entryBlock); + subBuilder->emitSwitch( + irTagVal, // value to `switch` on + invalidLabel, // `break` label (block after the `switch` statement ends) + defaultLabel, // `default` label (where to go if no `case` matches) + switchCaseOperands.Count(), + switchCaseOperands.Buffer()); + } + else + { + // TODO: We need to handle other cases of interface requirements. + SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + // Once we've generating a value to satisfying the requirement, we install + // it into the witness table for our tagged-union type. + // + getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal); + } + + return LoweredValInfo::simple(irWitnessTable); + } + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { // TODO: it is a bit messy here that the `ConstantIntVal` representation @@ -1304,6 +1557,37 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal)); } + LoweredValInfo visitTaggedUnionType(TaggedUnionType* type) + { + // A tagged union type will lower into an IR `union` over the cases, + // along with an IR `struct` with a field for the union and a tag. + // (Note: we are placing the tag after the payload to avoid padding + // in the case where the payload is more aligned than the tag) + // + // TODO: should we be lowering directly like this, or have + // an IR-level representation of tagged unions? + // + + List<IRType*> irCaseTypes; + for(auto caseType : type->caseTypes) + { + auto irCaseType = lowerType(context, caseType); + irCaseTypes.Add(irCaseType); + } + + auto irType = getBuilder()->getTaggedUnionType(irCaseTypes); + if(!irType->findDecoration<IRLinkageDecoration>()) + { + // We need a way for later passes to attach layout information + // to this type, so we will give it a mangled name here. + // + getBuilder()->addExportDecoration( + irType, + getMangledTypeName(type).getUnownedSlice()); + } + return LoweredValInfo::simple(irType); + } + // We do not expect to encounter the following types in ASTs that have // passed front-end semantic checking. #define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); } @@ -2337,6 +2621,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + LoweredValInfo visitAssignExpr(AssignExpr* expr) { // Because our representation of lowered "values" diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index d78f3321a..8ad0bc9f5 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -144,6 +144,15 @@ namespace Slang emitSimpleIntVal(context, arrType->ArrayLength); emitType(context, arrType->baseType); } + else if( auto taggedUnionType = dynamic_cast<TaggedUnionType*>(type) ) + { + emitRaw(context, "u"); + for( auto caseType : taggedUnionType->caseTypes ) + { + emitType(context, caseType); + } + emitRaw(context, "U"); + } else { SLANG_UNEXPECTED("unimplemented case in mangling"); diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index b8df39ade..3f5c01d09 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -2429,10 +2429,20 @@ static void collectEntryPointParameters( entryPointLayout->entryPoint = entryPointFuncDecl; context->entryPointLayout = entryPointLayout; - - context->shared->programLayout->entryPoints.Add(entryPointLayout); + // Note: this isn't really the best place for this logic to sit, + // but it is the simplest place where we have a direct correspondance + // between a single `EntryPointRequest` and its matching `EntryPointLayout`, + // so we'll use it. + // + for( auto taggedUnionType : entryPoint->taggedUnionTypes ) + { + auto substType = taggedUnionType->Substitute(typeSubst).As<Type>(); + auto typeLayout = CreateTypeLayout(context->layoutContext, substType); + entryPointLayout->taggedUnionTypeLayouts.Add(typeLayout); + } + // Okay, we seemingly have an entry-point function, and now we need to collect info on its parameters too // // TODO: Long-term we probably want complete information on all inputs/outputs of an entry point, diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index d0e965970..1afc8ebcf 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -1769,6 +1769,25 @@ namespace Slang return typeExpr; } + static RefPtr<Expr> parseTaggedUnionType(Parser* parser) + { + RefPtr<TaggedUnionTypeExpr> taggedUnionType = new TaggedUnionTypeExpr(); + + parser->ReadToken(TokenType::LParent); + while(!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto caseType = parser->ParseTypeExp(); + taggedUnionType->caseTypes.Add(caseType); + + if(AdvanceIf(parser, TokenType::RParent)) + break; + + parser->ReadToken(TokenType::Comma); + } + + return taggedUnionType; + } + static TypeSpec parseTypeSpec(Parser* parser) { TypeSpec typeSpec; @@ -1812,6 +1831,11 @@ namespace Slang typeSpec.expr = createDeclRefType(parser, decl); return typeSpec; } + else if(AdvanceIf(parser, "__TaggedUnion")) + { + typeSpec.expr = parseTaggedUnionType(parser); + return typeSpec; + } Token typeName = parser->ReadToken(TokenType::Identifier); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index d516d0cb2..38b417960 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -827,7 +827,7 @@ int CompileRequest::addEntryPoint( entryPoint->profile = entryPointProfile; entryPoint->translationUnitIndex = translationUnitIndex; for (auto typeName : genericTypeNames) - entryPoint->genericParameterTypeNames.Add(typeName); + entryPoint->genericArgStrings.Add(typeName); auto translationUnit = translationUnits[translationUnitIndex].Ptr(); translationUnit->entryPoints.Add(entryPoint); diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 132a56ea5..42b1a449d 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -198,6 +198,7 @@ <ClInclude Include="ir-specialize-resources.h" /> <ClInclude Include="ir-specialize.h" /> <ClInclude Include="ir-ssa.h" /> + <ClInclude Include="ir-union.h" /> <ClInclude Include="ir-validate.h" /> <ClInclude Include="ir.h" /> <ClInclude Include="legalize-types.h" /> @@ -252,6 +253,7 @@ <ClCompile Include="ir-specialize-resources.cpp" /> <ClCompile Include="ir-specialize.cpp" /> <ClCompile Include="ir-ssa.cpp" /> + <ClCompile Include="ir-union.cpp" /> <ClCompile Include="ir-validate.cpp" /> <ClCompile Include="ir.cpp" /> <ClCompile Include="legalize-types.cpp" /> @@ -314,4 +316,4 @@ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <ImportGroup Label="ExtensionTargets"> </ImportGroup> -</Project>
\ No newline at end of file +</Project>
\ No newline at end of file diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 390c0cc5f..aeb7a1b3e 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -93,6 +93,9 @@ <ClInclude Include="ir-ssa.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="ir-union.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="ir-validate.h"> <Filter>Header Files</Filter> </ClInclude> @@ -251,6 +254,9 @@ <ClCompile Include="ir-ssa.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="ir-union.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="ir-validate.cpp"> <Filter>Source Files</Filter> </ClCompile> diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 320fc576b..fe8b2c4fe 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -2539,7 +2539,160 @@ void Type::accept(IValVisitor* visitor, void* extra) return substValue; } + // + // TaggedUnionType + // + + String TaggedUnionType::ToString() + { + String result; + result.append("__TaggedUnion("); + bool first = true; + for( auto caseType : caseTypes ) + { + if(!first) result.append(", "); + first = false; + + result.append(caseType->ToString()); + } + result.append(")"); + return result; + } + + bool TaggedUnionType::EqualsImpl(Type* type) + { + auto taggedUnion = type->As<TaggedUnionType>(); + if(!taggedUnion) + return false; + + auto caseCount = caseTypes.Count(); + if(caseCount != taggedUnion->caseTypes.Count()) + return false; + + for( UInt ii = 0; ii < caseCount; ++ii ) + { + if(!caseTypes[ii]->Equals(taggedUnion->caseTypes[ii])) + return false; + } + return true; + } + + int TaggedUnionType::GetHashCode() + { + int hashCode = 0; + for( auto caseType : caseTypes ) + { + hashCode = combineHash(hashCode, caseType->GetHashCode()); + } + return hashCode; + } + + RefPtr<Type> TaggedUnionType::CreateCanonicalType() + { + RefPtr<TaggedUnionType> canType = new TaggedUnionType(); + canType->setSession(getSession()); + + for( auto caseType : caseTypes ) + { + auto canCaseType = caseType->GetCanonicalType(); + canType->caseTypes.Add(canCaseType); + } + + return canType; + } + + RefPtr<Val> TaggedUnionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + + List<RefPtr<Type>> substCaseTypes; + for( auto caseType : caseTypes ) + { + substCaseTypes.Add(caseType->SubstituteImpl(subst, &diff).As<Type>()); + } + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr<TaggedUnionType> substType = new TaggedUnionType(); + substType->setSession(getSession()); + substType->caseTypes.SwapWith(substCaseTypes); + return substType; + } + +// +// TaggedUnionSubtypeWitness +// + + +bool TaggedUnionSubtypeWitness::EqualsVal(Val* val) +{ + auto taggedUnionWitness = val->dynamicCast<TaggedUnionSubtypeWitness>(); + if(!taggedUnionWitness) + return false; + + auto caseCount = caseWitnesses.Count(); + if(caseCount != taggedUnionWitness->caseWitnesses.Count()) + return false; + + for(UInt ii = 0; ii < caseCount; ++ii) + { + if(!caseWitnesses[ii]->EqualsVal(taggedUnionWitness->caseWitnesses[ii])) + return false; + } + + return true; +} + +String TaggedUnionSubtypeWitness::ToString() +{ + String result; + result.append("TaggedUnionSubtypeWitness("); + bool first = true; + for( auto caseWitness : caseWitnesses ) + { + if(!first) result.append(", "); + first = false; + + result.append(caseWitness->ToString()); + } + return result; +} + +int TaggedUnionSubtypeWitness::GetHashCode() +{ + int hash = 0; + for( auto caseWitness : caseWitnesses ) + { + hash = combineHash(hash, caseWitness->GetHashCode()); + } + return hash; +} + +RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substSub = sub->SubstituteImpl(subst, &diff).As<Type>(); + auto substSup = sup->SubstituteImpl(subst, &diff).As<Type>(); + + List<RefPtr<Val>> substCaseWitnesses; + for( auto caseWitness : caseWitnesses ) + { + substCaseWitnesses.Add(caseWitness->SubstituteImpl(subst, &diff)); + } + + if(!diff) + return this; + (*ioDiff)++; + RefPtr<TaggedUnionSubtypeWitness> substWitness = new TaggedUnionSubtypeWitness(); + substWitness->sub = substSub; + substWitness->sup = substSup; + substWitness->caseWitnesses.SwapWith(substCaseWitnesses); + return substWitness; } +} // namespace Slang diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index d9bfecc03..2318c6933 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -457,3 +457,21 @@ RAW( virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; ) END_SYNTAX_CLASS() + + /// A tagged union of zero or more other types. +SYNTAX_CLASS(TaggedUnionType, Type) +RAW( + /// The distinct "cases" the tagged union can store. + /// + /// For each type in this array, the array index is the + /// tag value for that case. + /// + List<RefPtr<Type>> caseTypes; + + virtual String ToString() override; + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual RefPtr<Type> CreateCanonicalType() override; + virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; +) +END_SYNTAX_CLASS() diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index f5cd518b8..05c61e706 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -2349,6 +2349,101 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } + else if( auto taggedUnionType = type->As<TaggedUnionType>() ) + { + // A tagged union type needs to be laid out as the maximum + // size of any constituent type. + // + // In practice, only a tagged union of uniform data will + // work, but for now we will compute the maximum usage + // for each resource kind for generality. + // + // For the uniform data we will start with a size + // of zero and an alignment of one for our base case + // (this is what a tagged union of no cases would consume). + // + UniformLayoutInfo info(0, 1); + + // If we are being asked to construct a full `TypeLayout` + // object, then we'll allocate it up front. + // + RefPtr<TaggedUnionTypeLayout> taggedUnionLayout; + if( outTypeLayout ) + { + taggedUnionLayout = new TaggedUnionTypeLayout(); + taggedUnionLayout->type = type; + taggedUnionLayout->rules = rules; + *outTypeLayout = taggedUnionLayout; + } + + // Now we iterate over the case types and see if they + // change our computed maximum size/alignement. + // + for( auto caseType : taggedUnionType->caseTypes ) + { + RefPtr<TypeLayout> caseTypeLayout; + UniformLayoutInfo caseTypeInfo = GetLayoutImpl(context, caseType, outTypeLayout ? &caseTypeLayout : nullptr).getUniformLayout(); + + info.size = maximum(info.size, caseTypeInfo.size); + info.alignment = std::max(info.alignment, caseTypeInfo.alignment); + + // If we are building a full `TypeLayout` we need to + // do a few more steps for each case type. + // + if( outTypeLayout ) + { + // We need to remember the layout of the case type + // on the final `TaggedUnionTypeLayout`. + // + taggedUnionLayout->caseTypeLayouts.Add(caseTypeLayout); + + // We also need to consider contributions for other + // resource kinds beyond uniform data. + // + for( auto caseResInfo : caseTypeLayout->resourceInfos ) + { + auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind); + unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count); + } + } + } + + // After we've computed the size required to hold all the + // case types, we will allocate space for the tag field. + // + // TODO: This assumes the tag will always be allocated out + // of uniform storage, which means we can't support a tagged + // union as part of a varying input/output signature. That is + // probably a valid limitation, but it should get enforced + // somewhere along the way. + // + { + // The tag is always a `uint` for now. + // + auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt); + info.size = RoundToAlignment(info.size, tagInfo.alignment); + + if( outTypeLayout ) + { + taggedUnionLayout->tagOffset = info.size; + } + + info.size += tagInfo.size; + info.alignment = std::max(info.alignment, tagInfo.alignment); + } + + // As a final step, if we are computing a full `TypeLayout` + // we will make sure that its information on uniform layout + // matches what we've computed in the `UniformLayoutInfo` we return. + // + if( outTypeLayout ) + { + taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size; + taggedUnionLayout->uniformAlignment = info.alignment; + } + + return info; + } // catch-all case in case nothing matched SLANG_ASSERT(!"unimplemented"); diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index c326ae989..da2f0e4f7 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -553,6 +553,28 @@ public: int paramIndex = 0; }; + /// Layout information for a tagged union type. +class TaggedUnionTypeLayout : public TypeLayout +{ +public: + /// The layouts of each of the case types. + /// + /// The order of entries in this array matches + /// the order of case types on the original + /// `TaggedUnionType`, and the index of a case + /// type is also the tag value for that case. + /// + List<RefPtr<TypeLayout>> caseTypeLayouts; + + /// The byte offset for the tag field. + /// + /// The tag field will always be allocted as + /// a `uint`, so we don't store a separate layout + /// for it. + /// + LayoutSize tagOffset; +}; + // Layout information for a single shader entry point // within a program // @@ -579,6 +601,12 @@ public: usesAnySampleRateInput = 0x1, }; unsigned flags = 0; + + /// Layouts for all tagged union types required by this entry point. + /// + /// These are any tagged union types used by the generic + /// arguments that this entry point is being compiled with. + List<RefPtr<TypeLayout>> taggedUnionTypeLayouts; }; class GenericParamLayout : public Layout diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index f96ee026e..f5b099079 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -136,3 +136,20 @@ RAW( ) END_SYNTAX_CLASS() +// A witness that `sub : sup`, because `sub` is a tagged union +// of the form `A | B | C | ...` and each of `A : sup`, +// `B : sup`, `C : sup`, etc. +// +SYNTAX_CLASS(TaggedUnionSubtypeWitness, SubtypeWitness) +RAW( + // Witnesses that each of the "case" types in the union + // is a subtype of `sup`. + // + List<RefPtr<Val>> caseWitnesses; + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() |
