summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/check.cpp163
-rw-r--r--source/slang/compiler.h8
-rw-r--r--source/slang/diagnostic-defs.h2
-rw-r--r--source/slang/emit.cpp98
-rw-r--r--source/slang/expr-defs.h11
-rw-r--r--source/slang/ir-inst-defs.h7
-rw-r--r--source/slang/ir-insts.h56
-rw-r--r--source/slang/ir-link.cpp24
-rw-r--r--source/slang/ir-ssa.cpp11
-rw-r--r--source/slang/ir-union.cpp776
-rw-r--r--source/slang/ir-union.h18
-rw-r--r--source/slang/ir.cpp64
-rw-r--r--source/slang/ir.h12
-rw-r--r--source/slang/lower-to-ir.cpp290
-rw-r--r--source/slang/mangle.cpp9
-rw-r--r--source/slang/parameter-binding.cpp14
-rw-r--r--source/slang/parser.cpp24
-rw-r--r--source/slang/slang.cpp2
-rw-r--r--source/slang/slang.vcxproj4
-rw-r--r--source/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/syntax.cpp153
-rw-r--r--source/slang/type-defs.h18
-rw-r--r--source/slang/type-layout.cpp95
-rw-r--r--source/slang/type-layout.h28
-rw-r--r--source/slang/val-defs.h17
25 files changed, 1883 insertions, 27 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 2fa4e9bc7..9aa5eb689 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -9,6 +9,8 @@
namespace Slang
{
+ RefPtr<TypeType> getTypeType(
+ Type* type);
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
bool isEffectivelyStatic(
@@ -677,7 +679,7 @@ namespace Slang
auto baseExprType = baseExpr->type.type;
RefPtr<SharedTypeExpr> baseTypeExpr = new SharedTypeExpr();
baseTypeExpr->base.type = baseExprType;
- baseTypeExpr->type = new TypeType(baseExprType);
+ baseTypeExpr->type.type = getTypeType(baseExprType);
auto expr = new StaticMemberExpr();
expr->loc = loc;
@@ -2071,9 +2073,7 @@ namespace Slang
{
RefPtr<TypeCastExpr> castExpr = createImplicitCastExpr();
- auto typeType = new TypeType();
- typeType->setSession(getSession());
- typeType->type = toType;
+ auto typeType = getTypeType(toType);
auto typeExpr = new SharedTypeExpr();
typeExpr->type.type = typeType;
@@ -5547,6 +5547,60 @@ namespace Slang
return witness;
}
+ /// Is the given interface one that a tagged-union type can conform to?
+ ///
+ /// If a tagged union type `__TaggedUnion(A,B)` is going to be
+ /// plugged in for a type parameter `T : IFoo` then we need to
+ /// be sure that the interface `IFoo` doesn't have anything
+ /// that could lead to unsafe/unsound behavior. This function
+ /// checks that all the requirements on the interfaceare safe ones.
+ ///
+ bool isInterfaceSafeForTaggedUnion(
+ DeclRef<InterfaceDecl> interfaceDeclRef)
+ {
+ for( auto memberDeclRef : getMembers(interfaceDeclRef) )
+ {
+ if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef))
+ return false;
+ }
+
+ return true;
+ }
+
+ /// Is the given interface requirement one that a tagged-union type can satisfy?
+ ///
+ /// Unsafe requirements include any `static` requirements,
+ /// any associated types, and also any requirements that make
+ /// use of the `This` type (once we support it).
+ ///
+ bool isInterfaceRequirementSafeForTaggedUnion(
+ DeclRef<InterfaceDecl> interfaceDeclRef,
+ DeclRef<Decl> requirementDeclRef)
+ {
+ if(auto callableDeclRef = requirementDeclRef.As<CallableDecl>())
+ {
+ // A `static` method requirement can't be satisfied by a
+ // tagged union, because there is no tag to dispatch on.
+ //
+ if(requirementDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
+ return false;
+
+ // TODO: We will eventually want to check that any callable
+ // requirements do not use the `This` type or any associated
+ // types in ways that could lead to errors.
+ //
+ // For now we are disallowing interfaces that have associated
+ // types completely, and we haven't implemented the `This`
+ // type, so we should be safe.
+
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
bool doesTypeConformToInterfaceImpl(
RefPtr<Type> originalType,
RefPtr<Type> type,
@@ -5661,6 +5715,69 @@ namespace Slang
}
}
}
+ else if(auto taggedUnionType = type->As<TaggedUnionType>())
+ {
+ // A tagged union type conforms to an interface if all of
+ // the constituent types in the tagged union conform.
+ //
+ // We will iterate over the "case" types in the tagged
+ // union, and check if they conform to the interface.
+ // Along the way we will collect the conformance witness
+ // values *if* we are being asked to produce a witness
+ // value for the tagged union itself (that is, if
+ // `outWitness` is non-null).
+ //
+ List<RefPtr<Val>> caseWitnesses;
+ for(auto caseType : taggedUnionType->caseTypes)
+ {
+ RefPtr<Val> caseWitness;
+
+ if(!doesTypeConformToInterfaceImpl(
+ caseType,
+ caseType,
+ interfaceDeclRef,
+ outWitness ? &caseWitness : nullptr,
+ nullptr))
+ {
+ return false;
+ }
+
+ if(outWitness)
+ {
+ caseWitnesses.Add(caseWitness);
+ }
+ }
+
+ // We also need to validate the requirements on
+ // the interface to make sure that they are suitable for
+ // use with a tagged-union type.
+ //
+ // For example, if the interface includes a `static` method
+ // (which can therefore be called without a particular instance),
+ // then we wouldn't know what implementation of that method
+ // to use because there is no tag value to dispatch on.
+ //
+ // We will start out being conservative about what we accept
+ // here, just to keep things simple.
+ //
+ if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef))
+ return false;
+
+ // If we reach this point then we have a concrete
+ // witness for each of the case types, and that is
+ // enough to build a witness for the tagged union.
+ //
+ if(outWitness)
+ {
+ RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = new TaggedUnionSubtypeWitness();
+ taggedUnionWitness->sub = taggedUnionType;
+ taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef);
+ taggedUnionWitness->caseWitnesses.SwapWith(caseWitnesses);
+
+ *outWitness = taggedUnionWitness;
+ }
+ return true;
+ }
// default is failure
return false;
@@ -8090,6 +8207,23 @@ namespace Slang
return expr;
}
+ RefPtr<Expr> visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr)
+ {
+ // We have an expression of the form `__TaggedUnion(A, B, ...)`
+ // which will evaluate to a tagged-union type over `A`, `B`, etc.
+ //
+ RefPtr<TaggedUnionType> type = new TaggedUnionType();
+ expr->type = QualType(getTypeType(type));
+
+ for( auto& caseTypeExpr : expr->caseTypes )
+ {
+ caseTypeExpr = CheckProperType(caseTypeExpr);
+ type->caseTypes.Add(caseTypeExpr.type);
+ }
+
+ return expr;
+ }
+
@@ -9039,7 +9173,7 @@ namespace Slang
scopesToTry.Add(module->moduleDecl->scope);
List<RefPtr<Type>> globalGenericArgs;
- for (auto name : entryPoint->genericParameterTypeNames)
+ for (auto name : entryPoint->genericArgStrings)
{
// parse type name
RefPtr<Type> type;
@@ -9059,6 +9193,25 @@ namespace Slang
return;
}
+ // The following is a bit of a hack.
+ //
+ // Back-end code generation relies on us having computed layouts for all tagged
+ // unions that end up being used in the code, which means we need a way to find
+ // all such types that get used in a module (and the stuff it imports).
+ //
+ // The Right Way to handle this would probably be to have each `ModuleDecl` track
+ // any tagged union types that get created in the context of that module, and
+ // then combine those lists later.
+ //
+ // For now we are assuming a tagged union type only comes into existence
+ // as a (top-level) argument for a generic type parameter, so that we
+ // can check for them here and cache them on the entry point request.
+ //
+ if( auto taggedUnionType = type->As<TaggedUnionType>() )
+ {
+ entryPoint->taggedUnionTypes.Add(taggedUnionType);
+ }
+
globalGenericArgs.Add(type);
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index d2072387e..41ba027c6 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -122,9 +122,8 @@ namespace Slang
// The name of the entry point function (e.g., `main`)
Name* name;
- // The type names we want to substitute into the
- // global generic type parameters
- List<String> genericParameterTypeNames;
+ /// Source code for the generic arguments to use for the generic parameters of the entry point.
+ List<String> genericArgStrings;
// The profile that the entry point will be compiled for
// (this is a combination of the target stage, and also
@@ -156,6 +155,9 @@ namespace Slang
RefPtr<FuncDecl> decl;
RefPtr<Substitutions> globalGenericSubst;
+
+ /// Any tagged union types that were referenced by the generic arguments of the entry point.
+ List<RefPtr<TaggedUnionType>> taggedUnionTypes;
};
enum class PassThroughMode : SlangPassThrough
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 76e59efa3..4a43cf8e4 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -347,7 +347,7 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl
DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration")
DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.")
-DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.")
+DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.")
DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself")
DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')")
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 2e801f0e3..0aebf6153 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -12,6 +12,7 @@
#include "ir-specialize.h"
#include "ir-specialize-resources.h"
#include "ir-ssa.h"
+#include "ir-union.h"
#include "ir-validate.h"
#include "legalize-types.h"
#include "lower-to-ir.h"
@@ -3922,6 +3923,72 @@ struct EmitVisitor
}
break;
+ case kIROp_BitCast:
+ {
+ // TODO: we can simplify the logic for arbitrary bitcasts
+ // by always bitcasting the source to a `uint*` type (if it
+ // isn't already) and then bitcasting that to the destination
+ // type (if it isn't already `uint*`.
+ //
+ // For now we are assuming the source type is *already*
+ // a `uint*` type of the appropriate size.
+ //
+// auto fromType = extractBaseType(inst->getOperand(0)->getDataType());
+ auto toType = extractBaseType(inst->getDataType());
+ switch(getTarget(ctx))
+ {
+ case CodeGenTarget::GLSL:
+ switch(toType)
+ {
+ default:
+ emit("/* unhandled */");
+ break;
+
+ case BaseType::UInt:
+ break;
+
+ case BaseType::Int:
+ emitIRType(ctx, inst->getDataType());
+ break;
+
+ case BaseType::Float:
+ emit("uintBitsToFloat(");
+ break;
+ }
+ break;
+
+ case CodeGenTarget::HLSL:
+ switch(toType)
+ {
+ default:
+ emit("/* unhandled */");
+ break;
+
+ case BaseType::UInt:
+ break;
+ case BaseType::Int:
+ emit("(");
+ emitIRType(ctx, inst->getDataType());
+ emit(")");
+ break;
+ case BaseType::Float:
+ emit("asfloat");
+ break;
+ }
+ break;
+
+
+ default:
+ SLANG_UNEXPECTED("unhandled codegen target");
+ break;
+ }
+
+ emit("(");
+ emitIROperand(ctx, inst->getOperand(0), mode, kEOp_General);
+ emit(")");
+ }
+ break;
+
default:
emit("/* unhandled */");
break;
@@ -3929,6 +3996,27 @@ struct EmitVisitor
maybeCloseParens(needClose);
}
+ BaseType extractBaseType(IRType* inType)
+ {
+ auto type = inType;
+ for(;;)
+ {
+ if(auto irBaseType = as<IRBasicType>(type))
+ {
+ return irBaseType->getBaseType();
+ }
+ else if(auto vecType = as<IRVectorType>(type))
+ {
+ type = vecType->getElementType();
+ continue;
+ }
+ else
+ {
+ return BaseType::Void;
+ }
+ }
+ }
+
void emitIRInst(
EmitContext* ctx,
IRInst* inst,
@@ -6565,6 +6653,14 @@ String emitEntryPoint(
#endif
validateIRModuleIfEnabled(compileRequest, irModule);
+ // Desguar any union types, since these will be illegal on
+ // various targets.
+ //
+ desugarUnionTypes(irModule);
+#if 0
+ dumpIRIfEnabled(compileRequest, irModule, "UNIONS DESUGARED");
+#endif
+ validateIRModuleIfEnabled(compileRequest, irModule);
// Any code that makes use of existential (interface) types
@@ -6595,8 +6691,6 @@ String emitEntryPoint(
//
specializeGenerics(irModule);
-
-
// Debugging code for IR transformations...
#if 0
dumpIRIfEnabled(compileRequest, irModule, "SPECIALIZED");
diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h
index fb29f64de..bd4ba5038 100644
--- a/source/slang/expr-defs.h
+++ b/source/slang/expr-defs.h
@@ -193,3 +193,14 @@ RAW(
DeclRef<VarDeclBase> declRef;
)
END_SYNTAX_CLASS()
+
+ /// A type expression of the form `__TaggedUnion(A, ...)`.
+ ///
+ /// An expression of this form will resolve to a `TaggedUnionType`
+ /// when checked.
+ ///
+SYNTAX_CLASS(TaggedUnionTypeExpr, Expr)
+RAW(
+ List<TypeExp> caseTypes;
+)
+END_SYNTAX_CLASS() \ No newline at end of file
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index 8d7a647f4..b6f8ce547 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -41,6 +41,8 @@ INST(Nop, nop, 0, 0)
INST(VectorType, Vec, 2, 0)
INST(MatrixType, Mat, 3, 0)
+ INST(TaggedUnionType, TaggedUnion, 0, 0)
+
/* Rate */
INST(ConstExprRate, ConstExpr, 0, 0)
INST(GroupSharedRate, GroupShared, 0, 0)
@@ -406,6 +408,11 @@ INST(ExtractExistentialValue, extractExistentialValue, 1, 0)
INST(ExtractExistentialType, extractExistentialType, 1, 0)
INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0)
+INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0)
+INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
+
+INST(BitCast, bitCast, 1, 0)
+
PSEUDO_INST(Pos)
PSEUDO_INST(PreInc)
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 0c56e8244..8662569ba 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -663,7 +663,6 @@ struct SharedIRBuilder
Dictionary<IRInstKey, IRInst*> globalValueNumberingMap;
Dictionary<IRConstantKey, IRConstant*> constantMap;
- Dictionary<Name*, IRWitnessTable*> witnessTableMap;
};
struct IRBuilderSourceLocRAII;
@@ -753,6 +752,13 @@ struct IRBuilder
IRType* const* paramTypes,
IRType* resultType);
+ IRFuncType* getFuncType(
+ List<IRType*> const& paramTypes,
+ IRType* resultType)
+ {
+ return getFuncType(paramTypes.Count(), paramTypes.Buffer(), resultType);
+ }
+
IRConstExprRate* getConstExprRate();
IRGroupSharedRate* getGroupSharedRate();
@@ -760,6 +766,16 @@ struct IRBuilder
IRRate* rate,
IRType* dataType);
+ IRType* getTaggedUnionType(
+ UInt caseCount,
+ IRType* const* caseTypes);
+
+ IRType* getTaggedUnionType(
+ List<IRType*> const& caseTypes)
+ {
+ return getTaggedUnionType(caseTypes.Count(), caseTypes.Buffer());
+ }
+
// Set the data type of an instruction, while preserving
// its rate, if any.
void setDataType(IRInst* inst, IRType* dataType);
@@ -794,6 +810,14 @@ struct IRBuilder
UInt argCount,
IRInst* const* args);
+ IRInst* emitCallInst(
+ IRType* type,
+ IRInst* func,
+ List<IRInst*> const& args)
+ {
+ return emitCallInst(type, func, args.Count(), args.Buffer());
+ }
+
IRInst* createIntrinsicInst(
IRType* type,
IROp op,
@@ -816,6 +840,13 @@ struct IRBuilder
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeVector(
+ IRType* type,
+ List<IRInst*> const& args)
+ {
+ return emitMakeVector(type, args.Count(), args.Buffer());
+ }
+
IRInst* emitMakeMatrix(
IRType* type,
UInt argCount,
@@ -831,6 +862,13 @@ struct IRBuilder
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeStruct(
+ IRType* type,
+ List<IRInst*> const& args)
+ {
+ return emitMakeStruct(type, args.Count(), args.Buffer());
+ }
+
IRInst* emitMakeExistential(
IRType* type,
IRInst* value,
@@ -1040,6 +1078,22 @@ struct IRBuilder
IRInst* param,
IRInst* val);
+ IRInst* emitExtractTaggedUnionTag(
+ IRInst* val);
+
+ IRInst* emitExtractTaggedUnionPayload(
+ IRType* type,
+ IRInst* val,
+ IRInst* tag);
+
+ IRInst* emitBitCast(
+ IRType* type,
+ IRInst* val);
+
+ //
+ // Decorations
+ //
+
IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount);
IRDecoration* addDecoration(IRInst* value, IROp op)
diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp
index 25d3b40b6..231ee81d2 100644
--- a/source/slang/ir-link.cpp
+++ b/source/slang/ir-link.cpp
@@ -1284,6 +1284,30 @@ IRFunc* specializeIRForEntryPoint(
cloneValue(context, bindInst);
}
+ // HACK: we need to ensure that any tagged union types
+ // in the IR module have layout information copied over to them.
+ //
+ // Note that we do this *after* cloning the `bindGlobalGenericParam`
+ // instructions, since we expected the tagged union type(s) to
+ // be referenced by them.
+ //
+ for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts )
+ {
+ auto taggedUnionType = taggedUnionTypeLayout->getType();
+ auto mangledName = getMangledTypeName(taggedUnionType);
+
+ RefPtr<IRSpecSymbol> sym;
+ if(!context->getSymbols().TryGetValue(mangledName, sym))
+ continue;
+
+ IRInst* clonedType = findClonedValue(context, sym->irGlobalValue);
+ if(!clonedType)
+ continue;
+
+ context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout);
+ }
+
+
// TODO: *technically* we should consider the case where
// we have global variables with initializers, since
diff --git a/source/slang/ir-ssa.cpp b/source/slang/ir-ssa.cpp
index d893137de..8c5db68fe 100644
--- a/source/slang/ir-ssa.cpp
+++ b/source/slang/ir-ssa.cpp
@@ -1068,16 +1068,9 @@ void constructSSA(ConstructSSAContext* context)
newArgCount,
newArgs.Buffer());
- // Swap decorations (all children, really) over to the new instruction
+ // Transfer decorations (a terminator should have no children) over to the new instruction.
//
- // TODO: We might want to encapsualte this in a reusable subroutine if
- // we often need to copy decorations from one instruction to another.
- //
- while( auto firstChild = oldTerminator->getFirstDecoration() )
- {
- firstChild->removeFromParent();
- firstChild->insertAtEnd(newTerminator);
- }
+ oldTerminator->transferDecorationsTo(newTerminator);
// A terminator better not have uses, so we shouldn't have
// to replace them.
diff --git a/source/slang/ir-union.cpp b/source/slang/ir-union.cpp
new file mode 100644
index 000000000..c50d669e1
--- /dev/null
+++ b/source/slang/ir-union.cpp
@@ -0,0 +1,776 @@
+// ir-union.cpp
+#include "ir-union.h"
+
+#include "ir.h"
+#include "ir-insts.h"
+
+namespace Slang {
+
+// This file will implement a pass to replace any union types (currently
+// just tagged unions) with plain `struct` types that attempt to provide
+// equivalent semantics. This will necessarily be a bit fragile, and there
+// will be fundamental limits to what the translation can support without
+// improved features in the target shading languages/ILs.
+
+struct DesugarUnionTypesContext
+{
+ // We'll start with some basic state that we need to get the job done.
+ //
+ // This includes the IR module we are to process, as well as IR building
+ // state that we will initialize once and then use throughout the pass.
+ //
+ IRModule* module;
+ SharedIRBuilder sharedBuilderStorage;
+ IRBuilder builderStorage;
+ IRBuilder* getBuilder() { return &builderStorage; }
+
+ // Because we will be replacing instructions that refer to unions with
+ // different logic, we'll want to remove the original instructions.
+ // However, we need to be careful about modifying the IR tree while also
+ // iterating it, and to keep things simple for ourselves we'll go ahead
+ // and build up a list of instruction to remove along the way, and then
+ // remove them all at the end.
+ //
+ List<IRInst*> instsToRemove;
+
+ // The overall flow of the pass is pretty simple, so we will walk through it now.
+ //
+ void processModule()
+ {
+ // We start by initializing our IR building state.
+ //
+ sharedBuilderStorage.session = module->session;
+ sharedBuilderStorage.module = module;
+ builderStorage.sharedBuilder = &sharedBuilderStorage;
+
+ // Next, we will search for any instruction that create or use
+ // union types, and process them accordingingly (usually by
+ // constructing a new instruction to replace them).
+ //
+ processInstRec(module->getModuleInst());
+
+ // Along the way we will build up a list of the tagged union
+ // types that we encountered, but we will refrain from replacing
+ // them until we are done (so that we always know that the instructions
+ // we process above refer to the original type, and not its
+ // replacement.
+ //
+ for( auto info : taggedUnionInfos )
+ {
+ auto taggedUnionType = info->taggedUnionType;
+ auto replacementInst = info->replacementInst;
+
+ // TODO: We should consider transferring decorations from the source
+ // type to the destination, but doing so carelessly could create
+ // problems, since an IR struct type shouldn't have, e.g., a
+ // `TaggedUnionTypeLayout` attached to it.
+
+ taggedUnionType->replaceUsesWith(replacementInst);
+ taggedUnionType->removeAndDeallocate();
+ }
+
+ // As described previously, we build up the `instsToRemove` list as
+ // we iterate so that we can remove them all here and not risk
+ // modifying the IR tree while also walking it.
+ //
+ // TODO: This might be overkill and we could conceivably just be
+ // a bit careful in `processInstRec`.
+ //
+ for(auto inst : instsToRemove)
+ {
+ inst->removeAndDeallocate();
+ }
+ }
+
+ // In order to replace a (tagged) union type, we will need to know
+ // something about it, and we will use the `TaggedUnionInfo` type
+ // to collect all the relevant information.
+ //
+ struct TaggedUnionInfo : public RefObject
+ {
+ // We obviously need to know the tagged union itself, and
+ // we will also use this structure to track the instruction
+ // (an IR struct type) that will replace it.
+ //
+ IRTaggedUnionType* taggedUnionType;
+ IRInst* replacementInst;
+
+ // In order to compute a suitable layout for the replacement
+ // `struct` type we need to know how the tagged union itself
+ // would be laid out in memory, so we require that all tagged
+ // unions in the generated IR have an associated (target-specific)
+ // layout.
+ //
+ TaggedUnionTypeLayout* taggedUnionTypeLayout;
+
+ // The basic approach we will use 16-byte chunks (represented as an array
+ // of `uint4`s) to reprent the "bulk" of a type, and then use a single field
+ // that could be up to 12 bytes to represent the "rest" of the type.
+ //
+ // Note that there are deeply ingrained assumptions here that all types
+ // are at least four bytes in size (so that unions cannot easily
+ // accomodate `half` value), and that any types *larger* than four bytes
+ // will need to be loaded/stored via multiple 4-byte loads/stores.
+ //
+ // With the basic idea out of the way, we need an IR level field
+ // in our struct to hold the bulk data, which comprises a "key" for
+ // looking up the field, and the type of the field itself. We also
+ // keep track of how many bytes we put in our bulk storage.
+ //
+ // The bulk field might be:
+ //
+ // - null, if none of the case types was 16 bytes or more
+ // - a single `uint4` for between 16 and 31 (inclusive) bytes
+ // - an array of `uint4`s for 32 or more bytes
+ //
+ UInt64 bulkSize = 0;
+ IRInst* bulkFieldKey = nullptr;
+ IRType* bulkFieldType = nullptr;
+
+ // The same basic idea then applies to the rest of the data.
+ //
+ // The "rest" field will be either be absent (if the size of the
+ // type was evently divisible by 16), a scalar `uint`, or else
+ // a 2- or 3-component vector of `uint`.
+ //
+ UInt64 restSize = 0;
+ IRInst* restFieldKey = nullptr;
+ IRType* restFieldType = nullptr;
+
+ // Finally, since we are currently working with tagged unions,
+ // we need a field to hold the tag, which will always be allocated
+ // after the fields that hold the bulk/rest of the payload.
+ //
+ // This field is always a single `uint`.
+ //
+ // TODO: if/when we support untagged unions, they could be handled
+ // by having this field be null.
+ //
+ IRInst* tagFieldKey;
+ };
+
+ // We will build up a list of all the tagged union types we encounter,
+ // so that we can replace them with the synthesized types when we are done.
+ //
+ List<RefPtr<TaggedUnionInfo>> taggedUnionInfos;
+
+ // It is possible that we will see the same tagged union type referenced
+ // many times in the IR, but we only want to synthesize the information
+ // above (including the various IR structures) once, so we also maintain
+ // a map from the original IR type to the corresponding information.
+ //
+ Dictionary<IRInst*, TaggedUnionInfo*> mapIRTypeToTaggedUnionInfo;
+
+ // We will process all instructions in the module in a single recursive walk.
+ //
+ void processInstRec(IRInst* inst)
+ {
+ processInst(inst);
+
+ for( auto child : inst->getChildren() )
+ {
+ processInstRec(child);
+ }
+ }
+ //
+ // At each instruction, we will check if it is one of the union-related instructions
+ // we need to replace, and process it accordingly.
+ //
+ void processInst(IRInst* inst)
+ {
+ switch( inst->op )
+ {
+ default:
+ // Any instruction not listed below either doesn't involve union types,
+ // or handles them in a hands-off fashion that we don't need to care about.
+ //
+ // E.g., a `load` of a union type from a constant buffer will turn into
+ // a load of the replacement `struct` type once we are done, and nothing
+ // needs to be done to the `load` instruction.
+ //
+ break;
+
+ case kIROp_TaggedUnionType:
+ {
+ // We clearly need to process the tagged union type itself, but the actual
+ // work is handled by other functions. All we need to do here is ensure
+ // that the information for this type gets generated, and then we can
+ // rely on the main `processModule` function to do the actual replacement later.
+ //
+ auto type = cast<IRTaggedUnionType>(inst);
+ getTaggedUnionInfo(type);
+ }
+ break;
+
+ case kIROp_ExtractTaggedUnionTag:
+ {
+ // The case of extracting the tag from a tagged union is relatively
+ // simple, because the replacement type will have a dedicated field or it.
+ //
+ // We start by finding the tagged union value the instruction is operating
+ // on, and then looking up the information for its type (which had
+ // better be a tagged union type).
+ //
+ auto taggedUnionVal = inst->getOperand(0);
+ auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType());
+
+ // Because the replacement type will have an explicit field for the tag,
+ // we can simply emit a single field-extract instruction to read its value
+ // out.
+ //
+ auto builder = getBuilder();
+ builder->setInsertBefore(inst);
+ auto replacement = builder->emitFieldExtract(
+ inst->getFullType(),
+ taggedUnionVal,
+ taggedUnionInfo->tagFieldKey);
+
+ // Now we can replace anything that used the original instruction with
+ // the new field-extract operation, and add this instruction to the
+ // list for later removal.
+ //
+ inst->replaceUsesWith(replacement);
+ instsToRemove.Add(inst);
+ }
+ break;
+
+ case kIROp_ExtractTaggedUnionPayload:
+ {
+ // The most interesting case is when we are trying to extract a particular
+ // payload (one of the case types) from a union. We may need to extract
+ // one or more fields from the data stored in the union's replacement
+ // type (the bulk/rest fields), and we may also have to convert them
+ // to the type expected via bit-casts.
+
+ // We can start things off easily enough by extracting the tagged union
+ // value being operated on, as well as the information for its type.
+ //
+ auto taggedUnionVal = inst->getOperand(0);
+ auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType());
+
+ // Next we need to figure out which case is being extracted from the union.
+ // The operand for the case tag should be a literal by construction.
+ //
+ auto caseTagVal = inst->getOperand(1);
+ auto caseTagConst = as<IRIntLit>(caseTagVal);
+ SLANG_ASSERT(caseTagConst);
+
+ // The case type we are extracting will be the result type of the instruciton.
+ //
+ auto caseType = inst->getDataType();
+ //
+ // The tag value itself will be the index of the case type in the union
+ // type (and its layout).
+ //
+ auto caseTagIndex = UInt(caseTagConst->getValue());
+
+ // We can use the case tag value to look up the layout for the particular
+ // case type we are extracting (this will allow us to resolve byte offsets
+ // for fields, etc.).
+ //
+ auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout;
+ SLANG_ASSERT(caseTagIndex < taggedUnionTypeLayout->caseTypeLayouts.Count());
+ auto caseTypeLayout = taggedUnionTypeLayout->caseTypeLayouts[caseTagIndex];
+
+ // At this point we know the type we are trying to extract, as well
+ // as its layout. We will defer the actual implementation of extraction
+ // to a (recursive) subroutine that can extract a (sub-)field from the
+ // union at a given byte offset. Since we are extracting a full case
+ // right now, the byte offset will be zero.
+ //
+ auto payloadVal = extractPayload(
+ taggedUnionInfo,
+ taggedUnionVal,
+ caseType,
+ caseTypeLayout,
+ 0);
+
+ // TODO: There is a significant flaw in the above approach when
+ // the case type might be (or contain) an array. If we have a setup
+ // like the following:
+ //
+ // union SomeUnion { float someCase[100]; ... }
+ // ...
+ // float result = someUnion.someCase[someIndex];
+ //
+ // The current logic would desugar this into something like:
+ //
+ // struct SomeUnion { uint4 bulk[100]; ... }
+ // ...
+ // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... }
+ // float result = tmp[someIndex];
+ //
+ // The result is that we copy an entire 100-element array into local memory
+ // just to fetch a single element, when it would be much nicer to just do:
+ //
+ // float result = asfloat(someUnion.bulk[someIndex].x);
+ //
+ // Achieving the latter code requires that rather than blindly translate
+ // the `extractTaggedUnionPayload` instruction into a semantically equiavlent
+ // value (which might lead to a big copy in the end), we should transitively
+ // chase down any "access chains" off of `inst` and see what leaf values are
+ // actually needed, and generated more tailored extraction logic for just
+ // the elements/fields that actually get referenced.
+ //
+ // The more refined approach can be built on top of many of the same primitives,
+ // so for now we will resign ourselves to the simpler but potentially less
+ // efficient approach.
+
+ // Now that we've extracted the value for the payload from the fields of
+ // the replacement struct, we can use that extracted value to replace
+ // this instruction, and schedule the original instruction for removal.
+ //
+ inst->replaceUsesWith(payloadVal);
+ instsToRemove.Add(inst);
+ }
+ break;
+ }
+ }
+
+ // The `extractPayload` operation is the most important bit of translation we
+ // need to do to make unions work. We have as input the following:
+ //
+ IRInst* extractPayload(
+
+ // - Information about a tagged union type and its layout.
+ TaggedUnionInfo* taggedUnionInfo,
+
+ // - A single value of that tagged unon type.
+ IRInst* taggedUnionVal,
+
+ // - Type type of some "payload" field we want to extract from the union.
+ IRType* payloadType,
+
+ // - The memory layout of that payload type.
+ TypeLayout* payloadTypeLayout,
+
+ // - The byte offset at which we want to fetch the payload.
+ UInt64 payloadOffset)
+ {
+ // We are going to be building some IR code no matter what.
+ //
+ auto builder = getBuilder();
+
+ // The basic approach here will be to look at the type we
+ // are trying to extract from the union, and whenever possible
+ // recursively walk its structure so that we can express things
+ // in terms of extraction of smaller/simpler types.
+ //
+ if( auto irStructType = as<IRStructType>(payloadType) )
+ {
+ // A structure type is a nice recursive case: we simply
+ // want to extract each of its field recursively, and
+ // then construct a fresh value of the `struct` type.
+
+ // In all of the cases of this function we expect/require
+ // there to be complete type layout information for the
+ // types involved.
+ //
+ auto structTypeLayout = dynamic_cast<StructTypeLayout*>(payloadTypeLayout);
+ SLANG_ASSERT(structTypeLayout);
+
+ // We are going to emit code to extract each of the fields
+ // and collect them to use as operands to a `makeStruct`.
+ //
+ List<IRInst*> fieldVals;
+
+ // We need to walk over the fields in the order the IR expects them
+ UInt fieldCounter = 0;
+ for( auto irField : irStructType->getFields() )
+ {
+ IRType* fieldType = irField->getFieldType();
+
+ // TODO: We need to confirm/enforce that the fields of the
+ // IR struct and the fields of the layout still align.
+ //
+ UInt fieldIndex = fieldCounter++;
+ auto fieldLayout = structTypeLayout->fields[fieldIndex];
+ auto fieldTypeLayout = fieldLayout->getTypeLayout();
+
+ // The offset of the field can be computed from the base
+ // offset passed in, plus the reflection data for the field.
+ //
+ UInt64 fieldOffset = payloadOffset;
+ if(auto resInfo = fieldLayout->FindResourceInfo(LayoutResourceKind::Uniform))
+ fieldOffset += resInfo->index;
+
+ // We make a recursive call to extract each field, expecting
+ // that this will bottom out eventually.
+ //
+ IRInst* fieldVal = extractPayload(
+ taggedUnionInfo,
+ taggedUnionVal,
+ fieldType,
+ fieldTypeLayout,
+ fieldOffset);
+ fieldVals.Add(fieldVal);
+ }
+
+ // The final value is then just a new struct constructed from
+ // the extracted field values.
+ //
+ auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals);
+ return payloadVal;
+ }
+ else if( auto vecType = as<IRVectorType>(payloadType) )
+ {
+ auto elementType = vecType->getElementType();
+
+ // We expect that by the time we are desugaring union types
+ // all vector types have literal constant values for their
+ // element count.
+ //
+ auto elementCountVal = vecType->getElementCount();
+ auto elementCountConst = as<IRIntLit>(elementCountVal);
+ SLANG_ASSERT(elementCountConst);
+ UInt elementCount = UInt(elementCountConst->getValue());
+
+ // HACK: There is currently no `VectorTypeLayout` and thus
+ // no way to query the layout of the elements of a vector
+ // type. Until that gets added we will kludge things here.
+ //
+ TypeLayout* elementTypeLayout = nullptr;
+ size_t elementSize = 0;
+ if(auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform))
+ elementSize = resInfo->count.getFiniteValue() / elementCount;
+
+ // Similar to the `struct` case above, we will extract a
+ // value for each element of the vector, and then use
+ // `makeVector` to construct the result value.
+ //
+ List<IRInst*> elementVals;
+ for(UInt ii = 0; ii < elementCount; ++ii)
+ {
+ auto elementVal = extractPayload(
+ taggedUnionInfo,
+ taggedUnionVal,
+ elementType,
+ elementTypeLayout,
+ payloadOffset + ii*elementSize);
+ elementVals.Add(elementVal);
+ }
+ return builder->emitMakeVector(vecType, elementVals);
+ }
+ else if( auto matType = as<IRMatrixType>(payloadType) )
+ {
+ SLANG_UNIMPLEMENTED_X("matrix in union type");
+ }
+ else if( auto arrayType = as<IRArrayType>(payloadType) )
+ {
+ SLANG_UNIMPLEMENTED_X("array in union type");
+ }
+ else
+ {
+ // If none of the above cases match, then we assume that
+ // we have an individual scalar field that we need to fetch.
+ //
+ UInt64 payloadSize = 0;
+ if( auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) )
+ {
+ // TODO: somebody before this point should generate an error if
+ // we have a `union` type that contains a potentially unbounded
+ // amount of data.
+ //
+ payloadSize = resInfo->count.getFiniteValue();
+ }
+
+ if( payloadSize != 4 )
+ {
+ // TODO: We should handle the case of 64-bit fields by fetching
+ // two `uint` values to form a `uint2`, and then using an
+ // appropriate bit-cast to get from `uint2` to, e.g., `double`.
+ //
+ // The case of 16-bit and smaller fields is more troublesome, but
+ // in the worst case we can load a `uint` and then use bitwise
+ // ops to extract what we need before bitcasting.
+ //
+ // The right long-term solution is for downstream languages to have
+ // better support for raw memory addressing.
+
+ SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes");
+ }
+
+ // We know that we want to fetch a value of size `payloadSize`, and
+ // we have a known base value and an initial offset into it.
+ //
+ IRInst* baseVal = taggedUnionVal;
+ UInt64 offset = payloadOffset;
+
+ // We are going to refine our `baseVal` and `offset` as we go, by
+ // trying to narrow down the data we will access in the `struct`
+ // type that will provide storage for the union.
+ //
+ // The first thing we want to check is if the value sits in the
+ // "bulk" part of the storage, or the "rest."
+ //
+ UInt64 bulkSize = taggedUnionInfo->bulkSize;
+ if( offset < bulkSize )
+ {
+ // If the value starts in the bulk area, then the whole
+ // thing had better fit in the bulk area. The 16-byte
+ // granularity rules for constant buffers should ensure
+ // this property for us on current targets.
+ //
+ SLANG_ASSERT(offset + payloadSize <= bulkSize);
+
+ // Since we know we'll be accessing the bulk storage,
+ // we will extract it here. The extracted field will
+ // be our new base value, but the `offset` doesn't need
+ // to be updated since the bulk field sits at offset 0.
+ //
+ baseVal = builder->emitFieldExtract(
+ taggedUnionInfo->bulkFieldType,
+ baseVal,
+ taggedUnionInfo->bulkFieldKey);
+
+ // The bulk storage could be an array, if there are 32
+ // or more bytes of bulk storage.
+ //
+ if( auto baseArrayType = as<IRArrayType>(baseVal->getDataType()) )
+ {
+ // If an array was allocated for bulk storage then
+ // our leaf value resides entirely within a single
+ // element (due to constant buffer layout rules),
+ // and so we will fetch the appropriate element here.
+ //
+ // We will change our `baseVal` to the extracted element,
+ // and then also adjust our `offset` to be relative
+ // to that element.
+ //
+ size_t bulkElementSize = 16;
+ auto index = offset / bulkElementSize;
+ baseVal = builder->emitElementExtract(
+ baseArrayType->getElementType(),
+ baseVal,
+ builder->getIntValue(builder->getIntType(), index));
+ offset -= index*bulkElementSize;
+ }
+ }
+ else
+ {
+ // If the offset of the field we want is past the end of
+ // the bulk field then it must sit inside of the rest field,
+ // and we'll extract it here. This establishes a new
+ // base value, and we adjust the `offset` to be relative
+ // to the rest field (which starts at an offset equal to `bulkSize`).
+ //
+ baseVal = builder->emitFieldExtract(
+ taggedUnionInfo->restFieldType,
+ baseVal,
+ taggedUnionInfo->restFieldKey);
+ offset -= bulkSize;
+ }
+
+ // We've now extracted a field that could be either a scalar or
+ // a vector, and we have an offset into it. In the case where
+ // the base value is a vector, we will extract out the appropriate
+ // element.
+ //
+ if( auto baseVecType = as<IRVectorType>(baseVal->getDataType()) )
+ {
+ size_t vecElementSize = 4;
+ auto index = offset / vecElementSize;
+ baseVal = builder->emitElementExtract(
+ baseVecType->getElementType(),
+ baseVal,
+ builder->getIntValue(builder->getIntType(), index));
+ offset -= index*vecElementSize;
+ }
+
+ // At this point, our `baseVal` should be a single `uint`, and
+ // it should provide the storage for the exact thing we wanted
+ // to access (under the assumption that we always fetch 4 bytes
+ // on 4-byte alignment).
+ //
+ IRInst* payloadVal = baseVal;
+ SLANG_ASSERT(offset == 0);
+
+ // TODO: we could imagine adding logic here to handle types less
+ // than 4 bytes in size by shifting and masking the value we
+ // just loaded.
+
+ // The payload field we were trying to extract might have a type
+ // other than `uint`, and to handle that case we need to employ
+ // a bit-cast to get to the desired type.
+ //
+ if( payloadVal->getDataType() != payloadType )
+ {
+ payloadVal = builder->emitBitCast(
+ payloadType,
+ payloadVal);
+ }
+ return payloadVal;
+ }
+ }
+
+ // All of the logic so far as assumed we can just call `getTaggedUnionInfo`
+ // and have easy access to all the required information and the
+ // synthesized replacement type.
+ //
+ TaggedUnionInfo* getTaggedUnionInfo(IRType* type)
+ {
+ // The big picture is fairly simple: we will lazily build and
+ // memoize the information about tagged unions.
+ //
+ {
+ TaggedUnionInfo* info = nullptr;
+ if(mapIRTypeToTaggedUnionInfo.TryGetValue(type, info))
+ return info;
+ }
+
+ // When we don't find information in our memo-cache, we
+ // will construct it and add it to both the memo-cache
+ // *and* a global list of all tagged unions encountered,
+ // so that we can replacement them later.
+ //
+ auto info = createTaggedUnionInfo(type);
+ mapIRTypeToTaggedUnionInfo.Add(type, info.Ptr());
+ taggedUnionInfos.Add(info);
+
+ return info;
+ }
+
+ // The actual logic for creating a `TaggedUnionInfo` is relatively
+ // straightforward once we've decided what information we need.
+ //
+ RefPtr<TaggedUnionInfo> createTaggedUnionInfo(IRType* type)
+ {
+ // We expect that any type used as an operation to one of the
+ // `extractTaggedUnion*` operations must be an IR tagged union.
+ //
+ // Note: If/when we ever expose `union`s to user and allow
+ // then to create *generic* tagged union types it might appear
+ // that this needs to be changed to account for a `specialize`
+ // instruction in place of a concrete tagged union, but in
+ // practice this pass needs to be performed late enough that
+ // any such generic should be fully specialized.
+ //
+ auto taggedUnionType = as<IRTaggedUnionType>(type);
+ SLANG_ASSERT(taggedUnionType);
+
+ RefPtr<TaggedUnionInfo> info = new TaggedUnionInfo();
+ info->taggedUnionType = taggedUnionType;
+
+ // We are going to create an instruction to replace `type`,
+ // and thus will be placing it into the same parent.
+ //
+ auto builder = getBuilder();
+ builder->setInsertBefore(type);
+
+ // A tagged union type will be replaced with an ordinary
+ // `struct` type with fields to store all the relevant
+ // data from any of the cases, plus a tag field.
+ //
+ auto structType = builder->createStructType();
+ info->replacementInst = structType;
+
+ // We require/expect the earlier code generation steps to have
+ // assocaited a layout with every tagged union that appears in
+ // the code.
+ //
+ auto layoutDecoration = type->findDecoration<IRLayoutDecoration>();
+ SLANG_ASSERT(layoutDecoration);
+ auto layout = layoutDecoration->getLayout();
+ SLANG_ASSERT(layout);
+ auto taggedUnionTypeLayout = dynamic_cast<TaggedUnionTypeLayout*>(layout);
+ SLANG_ASSERT(taggedUnionTypeLayout);
+
+ info->taggedUnionTypeLayout = taggedUnionTypeLayout;
+
+ // The size of the "payload" for the different cases (everything but
+ // the tag) is taken to be the offset of the tag itself.
+ //
+ // TODO: this might be inaccurate if the payload size isn't a multiple
+ // of the tag's alignment. We should deal with that when/if we support
+ // types smaller than 4 bytes in unions.
+ //
+ auto payloadSize = taggedUnionTypeLayout->tagOffset.getFiniteValue();
+
+ // We are going to be construction IR code that makes use of the `int`
+ // and `uint` types in several cases, so we go ahead and get a pointer
+ // to those types here.
+ //
+ auto intType = getBuilder()->getIntType();
+ auto uintType = getBuilder()->getBasicType(BaseType::UInt);
+
+ // For now we will use a simple stragegy for how we encode a union,
+ // which depends only on the total number of bytes needed, and not
+ // on the makeup of the values being stored.
+ //
+ // We will start by allocating one or more `uint4` values (in an
+ // array for the "or more" case) to hold the bulk of any large
+ // payload value.
+ //
+ size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets
+ auto bulkVectorCount = payloadSize / bulkVectorSize;
+ auto bulkFieldSize = bulkVectorCount * bulkVectorSize;
+ if( bulkVectorCount )
+ {
+ IRType* bulkFieldType = builder->getVectorType(
+ uintType,
+ builder->getIntValue(intType, 4));
+
+ if( bulkVectorCount > 1 )
+ {
+ bulkFieldType = builder->getArrayType(
+ bulkFieldType,
+ builder->getIntValue(intType, bulkVectorCount));
+ }
+
+ auto bulkFieldKey = builder->createStructKey();
+ builder->createStructField(structType, bulkFieldKey, bulkFieldType);
+
+ info->bulkFieldKey = bulkFieldKey;
+ info->bulkFieldType = bulkFieldType;
+ }
+ info->bulkSize = bulkFieldSize;
+
+ // The rest of the data (anything that doesn't fit in the bulk field),
+ // will get allocated into a single scalar or vector of `uint`.
+ //
+ auto restSize = payloadSize - bulkFieldSize;
+ if( restSize )
+ {
+ size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets
+ auto restElementCount = restSize / restElementSize;
+ auto restFieldSize = restElementSize * restElementCount;
+ SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity
+
+ IRType* restFieldType = uintType;
+ if( restElementCount > 1 )
+ {
+ restFieldType = builder->getVectorType(
+ restFieldType,
+ builder->getIntValue(intType, restElementCount));
+ }
+
+ auto restFieldKey = builder->createStructKey();
+ builder->createStructField(structType, restFieldKey, restFieldType);
+
+ info->restFieldKey = restFieldKey;
+ info->restFieldType = restFieldType;
+ info->restSize = restFieldSize;
+ }
+
+ // Finally, we add a field to represent the tag.
+ //
+ auto tagFieldType = uintType;
+ auto tagFieldKey = builder->createStructKey();
+ builder->createStructField(structType, tagFieldKey, tagFieldType);
+
+ info->tagFieldKey = tagFieldKey;
+
+ return info;
+ }
+};
+
+void desugarUnionTypes(
+ IRModule* module)
+{
+ DesugarUnionTypesContext context;
+ context.module = module;
+
+ context.processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/ir-union.h b/source/slang/ir-union.h
new file mode 100644
index 000000000..58de4e81e
--- /dev/null
+++ b/source/slang/ir-union.h
@@ -0,0 +1,18 @@
+// ir-union.h
+#pragma once
+
+namespace Slang {
+
+struct IRModule;
+
+ /// Desugar any unions types, and code using them, in `module`
+ ///
+ /// Union types will be replaced with ordinary `struct` types that store
+ /// the data of the underlying type as a "bag of bits" and references
+ /// to cases of the union will be replaced with logic to extract the
+ /// relevant bits.
+ ///
+void desugarUnionTypes(
+ IRModule* module);
+
+} // namespace Slang
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 48716bd87..8f33034b7 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -1731,6 +1731,18 @@ namespace Slang
operands);
}
+ IRType* IRBuilder::getTaggedUnionType(
+ UInt caseCount,
+ IRType* const* caseTypes)
+ {
+ return (IRType*) findOrEmitHoistableInst(
+ this,
+ getTypeKind(),
+ kIROp_TaggedUnionType,
+ caseCount,
+ (IRInst* const*) caseTypes);
+ }
+
void IRBuilder::setDataType(IRInst* inst, IRType* dataType)
{
if (auto oldRateQualifiedType = as<IRRateQualifiedType>(inst->getFullType()))
@@ -2684,6 +2696,50 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitExtractTaggedUnionTag(
+ IRInst* val)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_ExtractTaggedUnionTag,
+ getBasicType(BaseType::UInt),
+ val);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitExtractTaggedUnionPayload(
+ IRType* type,
+ IRInst* val,
+ IRInst* tag)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_ExtractTaggedUnionPayload,
+ type,
+ val,
+ tag);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitBitCast(
+ IRType* type,
+ IRInst* val)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_BitCast,
+ type,
+ val);
+ addInst(inst);
+ return inst;
+ }
+
+ //
+ // Decorations
+ //
+
IRDecoration* IRBuilder::addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount)
{
auto decoration = createInstWithTrailingArgs<IRDecoration>(
@@ -3796,6 +3852,14 @@ namespace Slang
}
}
+ void IRInst::transferDecorationsTo(IRInst* target)
+ {
+ while( auto decoration = getFirstDecoration() )
+ {
+ decoration->removeFromParent();
+ decoration->insertAtStart(target);
+ }
+ }
bool IRInst::mightHaveSideEffects()
{
diff --git a/source/slang/ir.h b/source/slang/ir.h
index de644a709..e62bb2249 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -373,6 +373,9 @@ struct IRInst
// for those values.
void removeArguments();
+ /// Transfer any decorations of this instruction to the `target` instruction.
+ void transferDecorationsTo(IRInst* target);
+
/// Does this instruction have any uses?
bool hasUses() const { return firstUse != nullptr; }
@@ -959,18 +962,23 @@ struct IRStructField : IRInst
// *not* contain the keys, because code needs to be able to
// reference the keys from scopes outside of the struct.
//
-struct IRStructType : IRInst
+struct IRStructType : IRType
{
IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); }
IR_LEAF_ISA(StructType)
};
-struct IRInterfaceType : IRInst
+struct IRInterfaceType : IRType
{
IR_LEAF_ISA(InterfaceType)
};
+struct IRTaggedUnionType : IRType
+{
+ IR_LEAF_ISA(TaggedUnionType)
+};
+
/// @brief A global value that potentially holds executable code.
///
struct IRGlobalValueWithCode : IRInst
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index b1bc63fa1..3e981dbc5 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -1121,6 +1121,259 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
requirementKey));
}
+ LoweredValInfo visitTaggedUnionSubtypeWitness(
+ TaggedUnionSubtypeWitness* val)
+ {
+ // The sub-type in this case is a tagged union `A | B | ...`,
+ // and the witness holds an array of witnesses showing that each
+ // "case" (`A`, `B`, etc.) is a subtype of the super-type.
+
+ // We will start by getting the IR-level representation of the
+ // sub type (the tagged union type).
+ //
+ auto irTaggedUnionType = lowerType(context, val->sub);
+
+ // We can turn each of those per-case witnesses into a witness
+ // table value:
+ //
+ auto caseCount = val->caseWitnesses.Count();
+ List<IRInst*> caseWitnessTables;
+ for( auto caseWitness : val->caseWitnesses )
+ {
+ auto caseWitnessTable = lowerSimpleVal(context, caseWitness);
+ caseWitnessTables.Add(caseWitnessTable);
+ }
+
+ // Now we need to synthesize a witness table for the tagged union
+ // value, showing how it can implement all of the requirements
+ // of the super type by delegating to the appropriate implementation
+ // on a per-case basis.
+ //
+ // We will assume here that the super-type is an interface, and it
+ // will be left to the front-end to ensure this property.
+ //
+ auto supDeclRefType = val->sup->As<DeclRefType>();
+ if(!supDeclRefType)
+ {
+ SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+ auto supInterfaceDeclRef = supDeclRefType->declRef.As<InterfaceDecl>();
+ if( !supInterfaceDeclRef )
+ {
+ SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+ auto irWitnessTable = getBuilder()->createWitnessTable();
+
+ // Now we will iterate over the requirements (members) of the
+ // interface and try to synthesize an appropriate value for each.
+ //
+ for( auto reqDeclRef : getMembers(supInterfaceDeclRef) )
+ {
+ // TODO: if there are any members we shouldn't process as a requirement,
+ // then we should detect and skip them here.
+ //
+
+ // Every interface requirement will have a unique key that is used
+ // when looking up the requirement in a concrete witness table.
+ //
+ auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl());
+
+ // We expect that each of the witness tables in `caseWitnessTables`
+ // will have an entry to match these keys. However, we may not
+ // have a concrete `IRWitnessTable` for each of the case types, either
+ // because they are a specialization of a generic (so that the witness
+ // table reference is a `specialize` instruction at this point), or
+ // they are a type external to this module (so that we have a declaration
+ // rather than a definition of the witness table).
+
+ // Our task is to create an IR value that can satisfy the interface
+ // requirement for the tagged union type, by appropriately delegating
+ // to the implementations of the same requirement in the case types.
+ //
+ IRInst* irSatisfyingVal = nullptr;
+
+
+
+ if(auto callableDeclRef = reqDeclRef.As<CallableDecl>())
+ {
+ // We have something callable, so we need to synthesize
+ // a function to satisfy it.
+ //
+ auto irFunc = getBuilder()->createFunc();
+ irSatisfyingVal = irFunc;
+
+ IRBuilder subBuilderStorage;
+ auto subBuilder = &subBuilderStorage;
+ subBuilder->sharedBuilder = getBuilder()->sharedBuilder;
+ subBuilder->setInsertInto(irFunc);
+
+ // We will start by setting up the function parameters,
+ // which live in the entry block of the IR function.
+ //
+ auto entryBlock = subBuilder->emitBlock();
+ subBuilder->setInsertInto(entryBlock);
+
+ // Create a `this` parameter of the tagged-union type.
+ //
+ // TODO: need to handle the `[mutating]` case here...
+ //
+ auto irThisType = irTaggedUnionType;
+ auto irThisParam = subBuilder->emitParam(irThisType);
+
+ List<IRType*> irParamTypes;
+ irParamTypes.Add(irThisType);
+
+ // Create the remaining parameters of the callable,
+ // using a decl-ref specialized to the tagged union
+ // type (so that things like associated types are
+ // mapped to the correct witness value).
+ //
+ List<IRParam*> irParams;
+ for( auto paramDeclRef : getMembersOfType<ParamDecl>(callableDeclRef) )
+ {
+ // TODO: need to handle `out` and `in out` here. Over all
+ // there is a lot of duplication here with the existing logic
+ // for emitting the signature of a `CallableDecl`, and we should
+ // try to re-use that if at all possible.
+ //
+ auto irParamType = lowerType(context, GetType(paramDeclRef));
+ auto irParam = subBuilder->emitParam(irParamType);
+
+ irParams.Add(irParam);
+ irParamTypes.Add(irParamType);
+ }
+
+ auto irResultType = lowerType(context, GetResultType(callableDeclRef));
+
+ auto irFuncType = subBuilder->getFuncType(
+ irParamTypes,
+ irResultType);
+ irFunc->setFullType(irFuncType);
+
+ // The first thing our function needs to do is extract the tag
+ // from the incoming `this` parameter.
+ //
+ auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam);
+
+ // Next we want to emit a `switch` on the tag value, but before we
+ // do that we need to generate the code for each of the cases so that
+ // our `switch` has somewhere to branch to.
+ //
+ List<IRInst*> switchCaseOperands;
+
+ IRBlock* defaultLabel = nullptr;
+
+ for( UInt ii = 0; ii < caseCount; ++ii )
+ {
+ auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii);
+
+ subBuilder->setInsertInto(irFunc);
+ auto caseLabel = subBuilder->emitBlock();
+
+ if(!defaultLabel)
+ defaultLabel = caseLabel;
+
+ switchCaseOperands.Add(caseTag);
+ switchCaseOperands.Add(caseLabel);
+
+ subBuilder->setInsertInto(caseLabel);
+
+ // We need to look up the satisfying value for this interface
+ // requirement on the witness table of the particular case value.
+ //
+ // We already have the witness table, and the requirement key is
+ // just `irReqKey`.
+ //
+ auto caseWitnessTable = caseWitnessTables[ii];
+
+ // The subtle bit here is determining the type we expect the
+ // satisfying value to have, since that depends on the actual
+ // type that is satisfying the requirement.
+ //
+ IRType* caseResultType = irResultType;
+ IRType* caseFuncType = nullptr;
+ auto caseFunc = subBuilder->emitLookupInterfaceMethodInst(
+ caseFuncType,
+ caseWitnessTable,
+ irReqKey);
+
+ // We are going to emit a `call` to the satisfying value
+ // for the case type, so we will collect the arguments for that call.
+ //
+ List<IRInst*> caseArgs;
+
+ // The `this` argument to the call will need to represent the
+ // appropriate field of our tagged union.
+ //
+ IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii);
+ auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload(
+ caseThisType,
+ irThisParam, caseTag);
+ caseArgs.Add(caseThisArg);
+
+ // The remaining arguments to the call will just be forwarded from
+ // the parameters of the wrapper functon.
+ //
+ // TODO: This would need to change if/when we started allowing `This` type
+ // or assocaited-type parameters to be used at call sites where a tagged
+ // union is used.
+ //
+ for( auto param : irParams )
+ {
+ caseArgs.Add(param);
+ }
+
+ auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs);
+
+ if( as<IRVoidType>(irResultType->getDataType()) )
+ {
+ subBuilder->emitReturn();
+ }
+ else
+ {
+ subBuilder->emitReturn(caseCall);
+ }
+ }
+
+ // We will create a block to represent the supposedly-unreachable
+ // code that will run if no `case` matches.
+ //
+ subBuilder->setInsertInto(irFunc);
+ auto invalidLabel = subBuilder->emitBlock();
+ subBuilder->setInsertInto(invalidLabel);
+ subBuilder->emitUnreachable();
+
+ if(!defaultLabel) defaultLabel = invalidLabel;
+
+ // Now we have enough information to go back and emit the `switch` instruction
+ // into the entry block.
+ subBuilder->setInsertInto(entryBlock);
+ subBuilder->emitSwitch(
+ irTagVal, // value to `switch` on
+ invalidLabel, // `break` label (block after the `switch` statement ends)
+ defaultLabel, // `default` label (where to go if no `case` matches)
+ switchCaseOperands.Count(),
+ switchCaseOperands.Buffer());
+ }
+ else
+ {
+ // TODO: We need to handle other cases of interface requirements.
+ SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+ // Once we've generating a value to satisfying the requirement, we install
+ // it into the witness table for our tagged-union type.
+ //
+ getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal);
+ }
+
+ return LoweredValInfo::simple(irWitnessTable);
+ }
+
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
{
// TODO: it is a bit messy here that the `ConstantIntVal` representation
@@ -1304,6 +1557,37 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal));
}
+ LoweredValInfo visitTaggedUnionType(TaggedUnionType* type)
+ {
+ // A tagged union type will lower into an IR `union` over the cases,
+ // along with an IR `struct` with a field for the union and a tag.
+ // (Note: we are placing the tag after the payload to avoid padding
+ // in the case where the payload is more aligned than the tag)
+ //
+ // TODO: should we be lowering directly like this, or have
+ // an IR-level representation of tagged unions?
+ //
+
+ List<IRType*> irCaseTypes;
+ for(auto caseType : type->caseTypes)
+ {
+ auto irCaseType = lowerType(context, caseType);
+ irCaseTypes.Add(irCaseType);
+ }
+
+ auto irType = getBuilder()->getTaggedUnionType(irCaseTypes);
+ if(!irType->findDecoration<IRLinkageDecoration>())
+ {
+ // We need a way for later passes to attach layout information
+ // to this type, so we will give it a mangled name here.
+ //
+ getBuilder()->addExportDecoration(
+ irType,
+ getMangledTypeName(type).getUnownedSlice());
+ }
+ return LoweredValInfo::simple(irType);
+ }
+
// We do not expect to encounter the following types in ASTs that have
// passed front-end semantic checking.
#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); }
@@ -2337,6 +2621,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/)
+ {
+ SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
LoweredValInfo visitAssignExpr(AssignExpr* expr)
{
// Because our representation of lowered "values"
diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp
index d78f3321a..8ad0bc9f5 100644
--- a/source/slang/mangle.cpp
+++ b/source/slang/mangle.cpp
@@ -144,6 +144,15 @@ namespace Slang
emitSimpleIntVal(context, arrType->ArrayLength);
emitType(context, arrType->baseType);
}
+ else if( auto taggedUnionType = dynamic_cast<TaggedUnionType*>(type) )
+ {
+ emitRaw(context, "u");
+ for( auto caseType : taggedUnionType->caseTypes )
+ {
+ emitType(context, caseType);
+ }
+ emitRaw(context, "U");
+ }
else
{
SLANG_UNEXPECTED("unimplemented case in mangling");
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index b8df39ade..3f5c01d09 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -2429,10 +2429,20 @@ static void collectEntryPointParameters(
entryPointLayout->entryPoint = entryPointFuncDecl;
context->entryPointLayout = entryPointLayout;
-
-
context->shared->programLayout->entryPoints.Add(entryPointLayout);
+ // Note: this isn't really the best place for this logic to sit,
+ // but it is the simplest place where we have a direct correspondance
+ // between a single `EntryPointRequest` and its matching `EntryPointLayout`,
+ // so we'll use it.
+ //
+ for( auto taggedUnionType : entryPoint->taggedUnionTypes )
+ {
+ auto substType = taggedUnionType->Substitute(typeSubst).As<Type>();
+ auto typeLayout = CreateTypeLayout(context->layoutContext, substType);
+ entryPointLayout->taggedUnionTypeLayouts.Add(typeLayout);
+ }
+
// Okay, we seemingly have an entry-point function, and now we need to collect info on its parameters too
//
// TODO: Long-term we probably want complete information on all inputs/outputs of an entry point,
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index d0e965970..1afc8ebcf 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -1769,6 +1769,25 @@ namespace Slang
return typeExpr;
}
+ static RefPtr<Expr> parseTaggedUnionType(Parser* parser)
+ {
+ RefPtr<TaggedUnionTypeExpr> taggedUnionType = new TaggedUnionTypeExpr();
+
+ parser->ReadToken(TokenType::LParent);
+ while(!AdvanceIfMatch(parser, TokenType::RParent))
+ {
+ auto caseType = parser->ParseTypeExp();
+ taggedUnionType->caseTypes.Add(caseType);
+
+ if(AdvanceIf(parser, TokenType::RParent))
+ break;
+
+ parser->ReadToken(TokenType::Comma);
+ }
+
+ return taggedUnionType;
+ }
+
static TypeSpec parseTypeSpec(Parser* parser)
{
TypeSpec typeSpec;
@@ -1812,6 +1831,11 @@ namespace Slang
typeSpec.expr = createDeclRefType(parser, decl);
return typeSpec;
}
+ else if(AdvanceIf(parser, "__TaggedUnion"))
+ {
+ typeSpec.expr = parseTaggedUnionType(parser);
+ return typeSpec;
+ }
Token typeName = parser->ReadToken(TokenType::Identifier);
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index d516d0cb2..38b417960 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -827,7 +827,7 @@ int CompileRequest::addEntryPoint(
entryPoint->profile = entryPointProfile;
entryPoint->translationUnitIndex = translationUnitIndex;
for (auto typeName : genericTypeNames)
- entryPoint->genericParameterTypeNames.Add(typeName);
+ entryPoint->genericArgStrings.Add(typeName);
auto translationUnit = translationUnits[translationUnitIndex].Ptr();
translationUnit->entryPoints.Add(entryPoint);
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
index 132a56ea5..42b1a449d 100644
--- a/source/slang/slang.vcxproj
+++ b/source/slang/slang.vcxproj
@@ -198,6 +198,7 @@
<ClInclude Include="ir-specialize-resources.h" />
<ClInclude Include="ir-specialize.h" />
<ClInclude Include="ir-ssa.h" />
+ <ClInclude Include="ir-union.h" />
<ClInclude Include="ir-validate.h" />
<ClInclude Include="ir.h" />
<ClInclude Include="legalize-types.h" />
@@ -252,6 +253,7 @@
<ClCompile Include="ir-specialize-resources.cpp" />
<ClCompile Include="ir-specialize.cpp" />
<ClCompile Include="ir-ssa.cpp" />
+ <ClCompile Include="ir-union.cpp" />
<ClCompile Include="ir-validate.cpp" />
<ClCompile Include="ir.cpp" />
<ClCompile Include="legalize-types.cpp" />
@@ -314,4 +316,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
-</Project> \ No newline at end of file
+</Project> \ No newline at end of file
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
index 390c0cc5f..aeb7a1b3e 100644
--- a/source/slang/slang.vcxproj.filters
+++ b/source/slang/slang.vcxproj.filters
@@ -93,6 +93,9 @@
<ClInclude Include="ir-ssa.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="ir-union.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="ir-validate.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -251,6 +254,9 @@
<ClCompile Include="ir-ssa.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="ir-union.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="ir-validate.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 320fc576b..fe8b2c4fe 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -2539,7 +2539,160 @@ void Type::accept(IValVisitor* visitor, void* extra)
return substValue;
}
+ //
+ // TaggedUnionType
+ //
+
+ String TaggedUnionType::ToString()
+ {
+ String result;
+ result.append("__TaggedUnion(");
+ bool first = true;
+ for( auto caseType : caseTypes )
+ {
+ if(!first) result.append(", ");
+ first = false;
+
+ result.append(caseType->ToString());
+ }
+ result.append(")");
+ return result;
+ }
+
+ bool TaggedUnionType::EqualsImpl(Type* type)
+ {
+ auto taggedUnion = type->As<TaggedUnionType>();
+ if(!taggedUnion)
+ return false;
+
+ auto caseCount = caseTypes.Count();
+ if(caseCount != taggedUnion->caseTypes.Count())
+ return false;
+
+ for( UInt ii = 0; ii < caseCount; ++ii )
+ {
+ if(!caseTypes[ii]->Equals(taggedUnion->caseTypes[ii]))
+ return false;
+ }
+ return true;
+ }
+
+ int TaggedUnionType::GetHashCode()
+ {
+ int hashCode = 0;
+ for( auto caseType : caseTypes )
+ {
+ hashCode = combineHash(hashCode, caseType->GetHashCode());
+ }
+ return hashCode;
+ }
+
+ RefPtr<Type> TaggedUnionType::CreateCanonicalType()
+ {
+ RefPtr<TaggedUnionType> canType = new TaggedUnionType();
+ canType->setSession(getSession());
+
+ for( auto caseType : caseTypes )
+ {
+ auto canCaseType = caseType->GetCanonicalType();
+ canType->caseTypes.Add(canCaseType);
+ }
+
+ return canType;
+ }
+
+ RefPtr<Val> TaggedUnionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff)
+ {
+ int diff = 0;
+
+ List<RefPtr<Type>> substCaseTypes;
+ for( auto caseType : caseTypes )
+ {
+ substCaseTypes.Add(caseType->SubstituteImpl(subst, &diff).As<Type>());
+ }
+ if(!diff)
+ return this;
+
+ (*ioDiff)++;
+
+ RefPtr<TaggedUnionType> substType = new TaggedUnionType();
+ substType->setSession(getSession());
+ substType->caseTypes.SwapWith(substCaseTypes);
+ return substType;
+ }
+
+//
+// TaggedUnionSubtypeWitness
+//
+
+
+bool TaggedUnionSubtypeWitness::EqualsVal(Val* val)
+{
+ auto taggedUnionWitness = val->dynamicCast<TaggedUnionSubtypeWitness>();
+ if(!taggedUnionWitness)
+ return false;
+
+ auto caseCount = caseWitnesses.Count();
+ if(caseCount != taggedUnionWitness->caseWitnesses.Count())
+ return false;
+
+ for(UInt ii = 0; ii < caseCount; ++ii)
+ {
+ if(!caseWitnesses[ii]->EqualsVal(taggedUnionWitness->caseWitnesses[ii]))
+ return false;
+ }
+
+ return true;
+}
+
+String TaggedUnionSubtypeWitness::ToString()
+{
+ String result;
+ result.append("TaggedUnionSubtypeWitness(");
+ bool first = true;
+ for( auto caseWitness : caseWitnesses )
+ {
+ if(!first) result.append(", ");
+ first = false;
+
+ result.append(caseWitness->ToString());
+ }
+ return result;
+}
+
+int TaggedUnionSubtypeWitness::GetHashCode()
+{
+ int hash = 0;
+ for( auto caseWitness : caseWitnesses )
+ {
+ hash = combineHash(hash, caseWitness->GetHashCode());
+ }
+ return hash;
+}
+
+RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+
+ auto substSub = sub->SubstituteImpl(subst, &diff).As<Type>();
+ auto substSup = sup->SubstituteImpl(subst, &diff).As<Type>();
+
+ List<RefPtr<Val>> substCaseWitnesses;
+ for( auto caseWitness : caseWitnesses )
+ {
+ substCaseWitnesses.Add(caseWitness->SubstituteImpl(subst, &diff));
+ }
+
+ if(!diff)
+ return this;
+ (*ioDiff)++;
+ RefPtr<TaggedUnionSubtypeWitness> substWitness = new TaggedUnionSubtypeWitness();
+ substWitness->sub = substSub;
+ substWitness->sup = substSup;
+ substWitness->caseWitnesses.SwapWith(substCaseWitnesses);
+ return substWitness;
}
+} // namespace Slang
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index d9bfecc03..2318c6933 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -457,3 +457,21 @@ RAW(
virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override;
)
END_SYNTAX_CLASS()
+
+ /// A tagged union of zero or more other types.
+SYNTAX_CLASS(TaggedUnionType, Type)
+RAW(
+ /// The distinct "cases" the tagged union can store.
+ ///
+ /// For each type in this array, the array index is the
+ /// tag value for that case.
+ ///
+ List<RefPtr<Type>> caseTypes;
+
+ virtual String ToString() override;
+ virtual bool EqualsImpl(Type * type) override;
+ virtual int GetHashCode() override;
+ virtual RefPtr<Type> CreateCanonicalType() override;
+ virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override;
+)
+END_SYNTAX_CLASS()
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
index f5cd518b8..05c61e706 100644
--- a/source/slang/type-layout.cpp
+++ b/source/slang/type-layout.cpp
@@ -2349,6 +2349,101 @@ SimpleLayoutInfo GetLayoutImpl(
rules,
outTypeLayout);
}
+ else if( auto taggedUnionType = type->As<TaggedUnionType>() )
+ {
+ // A tagged union type needs to be laid out as the maximum
+ // size of any constituent type.
+ //
+ // In practice, only a tagged union of uniform data will
+ // work, but for now we will compute the maximum usage
+ // for each resource kind for generality.
+ //
+ // For the uniform data we will start with a size
+ // of zero and an alignment of one for our base case
+ // (this is what a tagged union of no cases would consume).
+ //
+ UniformLayoutInfo info(0, 1);
+
+ // If we are being asked to construct a full `TypeLayout`
+ // object, then we'll allocate it up front.
+ //
+ RefPtr<TaggedUnionTypeLayout> taggedUnionLayout;
+ if( outTypeLayout )
+ {
+ taggedUnionLayout = new TaggedUnionTypeLayout();
+ taggedUnionLayout->type = type;
+ taggedUnionLayout->rules = rules;
+ *outTypeLayout = taggedUnionLayout;
+ }
+
+ // Now we iterate over the case types and see if they
+ // change our computed maximum size/alignement.
+ //
+ for( auto caseType : taggedUnionType->caseTypes )
+ {
+ RefPtr<TypeLayout> caseTypeLayout;
+ UniformLayoutInfo caseTypeInfo = GetLayoutImpl(context, caseType, outTypeLayout ? &caseTypeLayout : nullptr).getUniformLayout();
+
+ info.size = maximum(info.size, caseTypeInfo.size);
+ info.alignment = std::max(info.alignment, caseTypeInfo.alignment);
+
+ // If we are building a full `TypeLayout` we need to
+ // do a few more steps for each case type.
+ //
+ if( outTypeLayout )
+ {
+ // We need to remember the layout of the case type
+ // on the final `TaggedUnionTypeLayout`.
+ //
+ taggedUnionLayout->caseTypeLayouts.Add(caseTypeLayout);
+
+ // We also need to consider contributions for other
+ // resource kinds beyond uniform data.
+ //
+ for( auto caseResInfo : caseTypeLayout->resourceInfos )
+ {
+ auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind);
+ unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count);
+ }
+ }
+ }
+
+ // After we've computed the size required to hold all the
+ // case types, we will allocate space for the tag field.
+ //
+ // TODO: This assumes the tag will always be allocated out
+ // of uniform storage, which means we can't support a tagged
+ // union as part of a varying input/output signature. That is
+ // probably a valid limitation, but it should get enforced
+ // somewhere along the way.
+ //
+ {
+ // The tag is always a `uint` for now.
+ //
+ auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt);
+ info.size = RoundToAlignment(info.size, tagInfo.alignment);
+
+ if( outTypeLayout )
+ {
+ taggedUnionLayout->tagOffset = info.size;
+ }
+
+ info.size += tagInfo.size;
+ info.alignment = std::max(info.alignment, tagInfo.alignment);
+ }
+
+ // As a final step, if we are computing a full `TypeLayout`
+ // we will make sure that its information on uniform layout
+ // matches what we've computed in the `UniformLayoutInfo` we return.
+ //
+ if( outTypeLayout )
+ {
+ taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size;
+ taggedUnionLayout->uniformAlignment = info.alignment;
+ }
+
+ return info;
+ }
// catch-all case in case nothing matched
SLANG_ASSERT(!"unimplemented");
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
index c326ae989..da2f0e4f7 100644
--- a/source/slang/type-layout.h
+++ b/source/slang/type-layout.h
@@ -553,6 +553,28 @@ public:
int paramIndex = 0;
};
+ /// Layout information for a tagged union type.
+class TaggedUnionTypeLayout : public TypeLayout
+{
+public:
+ /// The layouts of each of the case types.
+ ///
+ /// The order of entries in this array matches
+ /// the order of case types on the original
+ /// `TaggedUnionType`, and the index of a case
+ /// type is also the tag value for that case.
+ ///
+ List<RefPtr<TypeLayout>> caseTypeLayouts;
+
+ /// The byte offset for the tag field.
+ ///
+ /// The tag field will always be allocted as
+ /// a `uint`, so we don't store a separate layout
+ /// for it.
+ ///
+ LayoutSize tagOffset;
+};
+
// Layout information for a single shader entry point
// within a program
//
@@ -579,6 +601,12 @@ public:
usesAnySampleRateInput = 0x1,
};
unsigned flags = 0;
+
+ /// Layouts for all tagged union types required by this entry point.
+ ///
+ /// These are any tagged union types used by the generic
+ /// arguments that this entry point is being compiled with.
+ List<RefPtr<TypeLayout>> taggedUnionTypeLayouts;
};
class GenericParamLayout : public Layout
diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h
index f96ee026e..f5b099079 100644
--- a/source/slang/val-defs.h
+++ b/source/slang/val-defs.h
@@ -136,3 +136,20 @@ RAW(
)
END_SYNTAX_CLASS()
+// A witness that `sub : sup`, because `sub` is a tagged union
+// of the form `A | B | C | ...` and each of `A : sup`,
+// `B : sup`, `C : sup`, etc.
+//
+SYNTAX_CLASS(TaggedUnionSubtypeWitness, SubtypeWitness)
+RAW(
+ // Witnesses that each of the "case" types in the union
+ // is a subtype of `sup`.
+ //
+ List<RefPtr<Val>> caseWitnesses;
+
+ virtual bool EqualsVal(Val* val) override;
+ virtual String ToString() override;
+ virtual int GetHashCode() override;
+ virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int * ioDiff) override;
+)
+END_SYNTAX_CLASS()