summaryrefslogtreecommitdiffstats
path: root/source/slang/lower-to-ir.cpp
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2019-01-16 12:48:11 -0800
committerGitHub <noreply@github.com>2019-01-16 12:48:11 -0800
commitaedf61784606406c090302efd8b7ac668ac997fc (patch)
treeb485fe5d7b027b269fcfa10503321288d10b9800 /source/slang/lower-to-ir.cpp
parent8e47a3802d4d74eb11620f147ef5b29b8e931d35 (diff)
Initial support for dynamic dispatch using "tagged union" types (#772)
* Initial support for dynamic dispatch using "tagged union" types Suppose a user declares some generic shader code, like the following: ```hlsl interface IFrobnicator { ... } type_param T : IFrobincator; ParameterBlock<T : IFrobnicator> gFrobnicator; ... gFrobincator.frobnicate(value); ``` and then they have some concrete implementations of the required interface: ```hlsl struct A : IFrobnicator { ... } struct B : IFrobnicator { ... } ``` The current Slang compiler allows them to generate distinct compiled kernels for the case of `T=A` and the case of `T=B`. This means that the decision of which implementation to use must be made at or before the time when a shader gets bound in the application. This change adds a new ability where the Slang compiler can generate code to handle the case where `T` might be *either* `A` or `B`, and which case it is will be determined dynamically at runtime. This means a single compiled kernel can handle both cases, and the decision about which code path to run can be made any time before the shader executes. This new option is supported by defining a *tagged union* type. Via the API, the user specifies that `T` should be specialized to `__TaggedUnion(A,B)` (the double underscore indicates that this is an experimental and unsupported feature at present). We refer to the types `A` and `B` here as the "case" types of the tagged union. Conceptually, the compiler synthesizes a type something like: ```hlsl struct TU { union { A a; B b; } payload; uint tag; } ``` The user can then allocate a constant buffer to hold their tagged union type, and when they pick a concrete type to use (say `B`), they fill in the first `sizeof(B)` bytes of their buffer with data describing a `B` instance, and then set the `tag` field to the appopriate 0-based index of the case type they chose (in this case the `B` case gets the tag value `1`). Actually implementing tagged unions takes a few main steps: * Type parsing was extended to special-case `__TaggedUnion` as a contextual keyword. This is really only intended to be used when parsing types from the API or command-line, and Bad Things are likely to happen if a user ever puts it directly in their code. Eventually construction of tagged unions should be an API feature and not part of the language syntax. * Semantic checking was extended to recognize that a tagged union like `__TaggedUnion(A,B)` shoud support an interface like `IFrobnicator` whenever all of the case types suport it, as long as the interface is "safe" for use with tagged unions (which means it doesn't use a few of the advancd langauge features like associated types). * The IR was extended with instructions to represent tagged union types and to extract their tag and the payload for the different cases as needed. * IR generation was extended to synthesize implementations of interface methods for any interface that a tagged union needs to support. Right now the implementation is simplistic and only handles simple method requirements, which it does by emitting a `switch` instruction to pick between the different cases. * A new IR pass was introduced to "desugar" any tagged union types used in the code. The downstream HLSL and GLSL compilers don't support `union`s, so we have to instead emit a tagged union as a "bag of bits" and implement loading the data for particular cases from it manually. * Final code emit mostly Just Works after the above steps, but we had to introduce an explicit IR instruction for bit-casting to handle the output of the desugaring pass. There are a bunch of gaps and caveats in this implementation, but that seems reasonable for something that is an experimental feature. The various `TODO` comments and assertion failures in unimplemented cases are intended, so that this work can be checked in even if it isn't feature-complete. * fixup: missing files * fixup: typos
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
-rw-r--r--source/slang/lower-to-ir.cpp290
1 files changed, 290 insertions, 0 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index b1bc63fa1..3e981dbc5 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -1121,6 +1121,259 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
requirementKey));
}
+ LoweredValInfo visitTaggedUnionSubtypeWitness(
+ TaggedUnionSubtypeWitness* val)
+ {
+ // The sub-type in this case is a tagged union `A | B | ...`,
+ // and the witness holds an array of witnesses showing that each
+ // "case" (`A`, `B`, etc.) is a subtype of the super-type.
+
+ // We will start by getting the IR-level representation of the
+ // sub type (the tagged union type).
+ //
+ auto irTaggedUnionType = lowerType(context, val->sub);
+
+ // We can turn each of those per-case witnesses into a witness
+ // table value:
+ //
+ auto caseCount = val->caseWitnesses.Count();
+ List<IRInst*> caseWitnessTables;
+ for( auto caseWitness : val->caseWitnesses )
+ {
+ auto caseWitnessTable = lowerSimpleVal(context, caseWitness);
+ caseWitnessTables.Add(caseWitnessTable);
+ }
+
+ // Now we need to synthesize a witness table for the tagged union
+ // value, showing how it can implement all of the requirements
+ // of the super type by delegating to the appropriate implementation
+ // on a per-case basis.
+ //
+ // We will assume here that the super-type is an interface, and it
+ // will be left to the front-end to ensure this property.
+ //
+ auto supDeclRefType = val->sup->As<DeclRefType>();
+ if(!supDeclRefType)
+ {
+ SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+ auto supInterfaceDeclRef = supDeclRefType->declRef.As<InterfaceDecl>();
+ if( !supInterfaceDeclRef )
+ {
+ SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+ auto irWitnessTable = getBuilder()->createWitnessTable();
+
+ // Now we will iterate over the requirements (members) of the
+ // interface and try to synthesize an appropriate value for each.
+ //
+ for( auto reqDeclRef : getMembers(supInterfaceDeclRef) )
+ {
+ // TODO: if there are any members we shouldn't process as a requirement,
+ // then we should detect and skip them here.
+ //
+
+ // Every interface requirement will have a unique key that is used
+ // when looking up the requirement in a concrete witness table.
+ //
+ auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl());
+
+ // We expect that each of the witness tables in `caseWitnessTables`
+ // will have an entry to match these keys. However, we may not
+ // have a concrete `IRWitnessTable` for each of the case types, either
+ // because they are a specialization of a generic (so that the witness
+ // table reference is a `specialize` instruction at this point), or
+ // they are a type external to this module (so that we have a declaration
+ // rather than a definition of the witness table).
+
+ // Our task is to create an IR value that can satisfy the interface
+ // requirement for the tagged union type, by appropriately delegating
+ // to the implementations of the same requirement in the case types.
+ //
+ IRInst* irSatisfyingVal = nullptr;
+
+
+
+ if(auto callableDeclRef = reqDeclRef.As<CallableDecl>())
+ {
+ // We have something callable, so we need to synthesize
+ // a function to satisfy it.
+ //
+ auto irFunc = getBuilder()->createFunc();
+ irSatisfyingVal = irFunc;
+
+ IRBuilder subBuilderStorage;
+ auto subBuilder = &subBuilderStorage;
+ subBuilder->sharedBuilder = getBuilder()->sharedBuilder;
+ subBuilder->setInsertInto(irFunc);
+
+ // We will start by setting up the function parameters,
+ // which live in the entry block of the IR function.
+ //
+ auto entryBlock = subBuilder->emitBlock();
+ subBuilder->setInsertInto(entryBlock);
+
+ // Create a `this` parameter of the tagged-union type.
+ //
+ // TODO: need to handle the `[mutating]` case here...
+ //
+ auto irThisType = irTaggedUnionType;
+ auto irThisParam = subBuilder->emitParam(irThisType);
+
+ List<IRType*> irParamTypes;
+ irParamTypes.Add(irThisType);
+
+ // Create the remaining parameters of the callable,
+ // using a decl-ref specialized to the tagged union
+ // type (so that things like associated types are
+ // mapped to the correct witness value).
+ //
+ List<IRParam*> irParams;
+ for( auto paramDeclRef : getMembersOfType<ParamDecl>(callableDeclRef) )
+ {
+ // TODO: need to handle `out` and `in out` here. Over all
+ // there is a lot of duplication here with the existing logic
+ // for emitting the signature of a `CallableDecl`, and we should
+ // try to re-use that if at all possible.
+ //
+ auto irParamType = lowerType(context, GetType(paramDeclRef));
+ auto irParam = subBuilder->emitParam(irParamType);
+
+ irParams.Add(irParam);
+ irParamTypes.Add(irParamType);
+ }
+
+ auto irResultType = lowerType(context, GetResultType(callableDeclRef));
+
+ auto irFuncType = subBuilder->getFuncType(
+ irParamTypes,
+ irResultType);
+ irFunc->setFullType(irFuncType);
+
+ // The first thing our function needs to do is extract the tag
+ // from the incoming `this` parameter.
+ //
+ auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam);
+
+ // Next we want to emit a `switch` on the tag value, but before we
+ // do that we need to generate the code for each of the cases so that
+ // our `switch` has somewhere to branch to.
+ //
+ List<IRInst*> switchCaseOperands;
+
+ IRBlock* defaultLabel = nullptr;
+
+ for( UInt ii = 0; ii < caseCount; ++ii )
+ {
+ auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii);
+
+ subBuilder->setInsertInto(irFunc);
+ auto caseLabel = subBuilder->emitBlock();
+
+ if(!defaultLabel)
+ defaultLabel = caseLabel;
+
+ switchCaseOperands.Add(caseTag);
+ switchCaseOperands.Add(caseLabel);
+
+ subBuilder->setInsertInto(caseLabel);
+
+ // We need to look up the satisfying value for this interface
+ // requirement on the witness table of the particular case value.
+ //
+ // We already have the witness table, and the requirement key is
+ // just `irReqKey`.
+ //
+ auto caseWitnessTable = caseWitnessTables[ii];
+
+ // The subtle bit here is determining the type we expect the
+ // satisfying value to have, since that depends on the actual
+ // type that is satisfying the requirement.
+ //
+ IRType* caseResultType = irResultType;
+ IRType* caseFuncType = nullptr;
+ auto caseFunc = subBuilder->emitLookupInterfaceMethodInst(
+ caseFuncType,
+ caseWitnessTable,
+ irReqKey);
+
+ // We are going to emit a `call` to the satisfying value
+ // for the case type, so we will collect the arguments for that call.
+ //
+ List<IRInst*> caseArgs;
+
+ // The `this` argument to the call will need to represent the
+ // appropriate field of our tagged union.
+ //
+ IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii);
+ auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload(
+ caseThisType,
+ irThisParam, caseTag);
+ caseArgs.Add(caseThisArg);
+
+ // The remaining arguments to the call will just be forwarded from
+ // the parameters of the wrapper functon.
+ //
+ // TODO: This would need to change if/when we started allowing `This` type
+ // or assocaited-type parameters to be used at call sites where a tagged
+ // union is used.
+ //
+ for( auto param : irParams )
+ {
+ caseArgs.Add(param);
+ }
+
+ auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs);
+
+ if( as<IRVoidType>(irResultType->getDataType()) )
+ {
+ subBuilder->emitReturn();
+ }
+ else
+ {
+ subBuilder->emitReturn(caseCall);
+ }
+ }
+
+ // We will create a block to represent the supposedly-unreachable
+ // code that will run if no `case` matches.
+ //
+ subBuilder->setInsertInto(irFunc);
+ auto invalidLabel = subBuilder->emitBlock();
+ subBuilder->setInsertInto(invalidLabel);
+ subBuilder->emitUnreachable();
+
+ if(!defaultLabel) defaultLabel = invalidLabel;
+
+ // Now we have enough information to go back and emit the `switch` instruction
+ // into the entry block.
+ subBuilder->setInsertInto(entryBlock);
+ subBuilder->emitSwitch(
+ irTagVal, // value to `switch` on
+ invalidLabel, // `break` label (block after the `switch` statement ends)
+ defaultLabel, // `default` label (where to go if no `case` matches)
+ switchCaseOperands.Count(),
+ switchCaseOperands.Buffer());
+ }
+ else
+ {
+ // TODO: We need to handle other cases of interface requirements.
+ SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+ // Once we've generating a value to satisfying the requirement, we install
+ // it into the witness table for our tagged-union type.
+ //
+ getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal);
+ }
+
+ return LoweredValInfo::simple(irWitnessTable);
+ }
+
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
{
// TODO: it is a bit messy here that the `ConstantIntVal` representation
@@ -1304,6 +1557,37 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal));
}
+ LoweredValInfo visitTaggedUnionType(TaggedUnionType* type)
+ {
+ // A tagged union type will lower into an IR `union` over the cases,
+ // along with an IR `struct` with a field for the union and a tag.
+ // (Note: we are placing the tag after the payload to avoid padding
+ // in the case where the payload is more aligned than the tag)
+ //
+ // TODO: should we be lowering directly like this, or have
+ // an IR-level representation of tagged unions?
+ //
+
+ List<IRType*> irCaseTypes;
+ for(auto caseType : type->caseTypes)
+ {
+ auto irCaseType = lowerType(context, caseType);
+ irCaseTypes.Add(irCaseType);
+ }
+
+ auto irType = getBuilder()->getTaggedUnionType(irCaseTypes);
+ if(!irType->findDecoration<IRLinkageDecoration>())
+ {
+ // We need a way for later passes to attach layout information
+ // to this type, so we will give it a mangled name here.
+ //
+ getBuilder()->addExportDecoration(
+ irType,
+ getMangledTypeName(type).getUnownedSlice());
+ }
+ return LoweredValInfo::simple(irType);
+ }
+
// We do not expect to encounter the following types in ASTs that have
// passed front-end semantic checking.
#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); }
@@ -2337,6 +2621,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/)
+ {
+ SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
LoweredValInfo visitAssignExpr(AssignExpr* expr)
{
// Because our representation of lowered "values"