| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2021-05-04 16:59:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-05-04 16:59:54 -0700 |
| commit | 85632e8db19916a2f348fe356a91119d2fde2929 (patch) | |
| tree | 9b5468ec7420cefb37b6777d94f985b7ac5c7cc8 | |
| parent | 731f1fc6b26659dc8f62fbc1969c076b78ada24f (diff) | |
Add support for returning structures that contain opaque types (#1835)
Introduction
============
Several of our target platforms share a concept of "opaque" types, including resources (`Texture2D`) and samplers (`SamplerState`), which are restricted in how they can be used. GLSL and SPIR-V are especially restrictive, in that opaque types cannot be used as the type of:
* (mutable) local variables
* (mutable) global variables
* structure fields
* function results/return values
* `out` or `inout` parameters
The HLSL language allows all of these cases, but with the practical caveat that the compiler front-end must be able to statically analyze how opaque types have been used and "optimize away" all of the above cases. For example, it is legal to have a local variable of an opaque type, but at any point where the variable gets used it must be statically known which top-level shader parameter the variable refers to.
Existing Work
=============
In the Slang compiler we need to implement our own passes to detect these "illegal" uses of opaque types and legalize them. The work is basically broken into two distinct steps:
* The existing `legalizeResourceTypes()` pass detects illegal types (e.g., a `struct` that has a field of type `Texture2D`) and replaces them with legal types, sometimes by splitting apart declarations (e.g., a parameter using such a `struct` type gets split into multiple parameters). At a high level, we can think of this as "exposing" opaque types so that they are not hidden inside of nested structures.
* Next, the `specializeResourceOutputs()` pass detects calls to functions that output opaque types (whether by the function return value or `out` / `inout` parameters). The pass analyzes the body of such functions, and tries to isolate the logic that determines their resource-type outputs and hoist that logic into call sites (so that the opaque-type outputs can then be eliminated).
This Change
===========
One important missing case was that the type legalization step was incapable of legalizing types that appear in the result/return type of functions. The existing logic would simply diagnose an internal/unimplemented error if it encountered a non-simple type in the return position.
At a high level, supporting this case seems simple enough. Given a function signature like:
```
struct Things { int a; Texture2D b; }
Things myFunc(int x) { ... }
```
we want to split the result type into an "ordinary" result type and then `out` parameters for any opaque-type fields:
```
struct Things_Legal { int a; }
Things_Legal myFunc(int x, out Texture2D result_b) { ... }
```
Similarly, at a call site to a function like this:
```
Things t = myFunc(99);
```
we split the function result into ordinary and opaque-type parts, and pass the latter as `out` parameters:
```
Texture2D t_b;
Things_Legal t = myFunc(99, /*out*/ t_b);
```
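The signature split described above can be sketched as a toy model in C++. This is only an illustration under simplifying assumptions: `Field`, `LegalSignature`, and `splitResultType` are hypothetical names that do not exist in the Slang code base, and types are modeled as plain strings rather than IR nodes.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy model of a struct field; not a Slang compiler type.
struct Field
{
    std::string type;
    std::string name;
    bool isOpaque; // true for resource/sampler types such as Texture2D
};

// Hypothetical description of a legalized function result.
struct LegalSignature
{
    std::vector<Field> ordinaryFields;  // remain in the returned struct
    std::vector<std::string> outParams; // appended after the original parameters
};

// Split the fields of a result struct into an ordinary part (still returned
// by value) and a list of `out` parameters for the opaque-type fields.
LegalSignature splitResultType(const std::vector<Field>& fields)
{
    LegalSignature sig;
    for (const auto& f : fields)
    {
        assert(!f.name.empty());
        if (f.isOpaque)
            sig.outParams.push_back("out " + f.type + " result_" + f.name);
        else
            sig.ordinaryFields.push_back(f);
    }
    return sig;
}
```

Applied to the `Things` example, this model keeps `a` in the returned struct and produces one extra parameter, `out Texture2D result_b`, which corresponds to the legalized signature shown earlier.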
The main place where things get tricky is when dealing with `return` sites within the body of a function that needs legalization:
```
Things myFunc(int x) {
...
Things things = ...;
...
return things;
}
```
In theory the answer is simple: a `return` translates into writes to the `out` parameters for any opaque-type data, followed by a return of the ordinary-type part:
```
Things_Legal myFunc(int x, out Texture2D result_b) {
...
Things_Legal things = ...;
Texture2D things_b = ...;
...
result_b = things_b;
return things;
}
```
The sticking point here is that this step requires tracking data between the legalization of the parameter list for `myFunc` and legalization of the `return`s in its body, so that we can identify the `result_b` parameter to be able to write to it. The existing type legalization pass was not built with the idea that such communication is commonly needed; it assumes that each instruction can be legalized in isolation, so long as dependencies are respected.
This change adds logic such that the `legalizeFunc()` step sets up a data structure that is used to represent information about how a function (and its parameter list) got legalized, so that the logic for a `return` can make use of that legalized information. Right now the information we track consists of just the list of parameters that were introduced to represent a return/result type.
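The communication between signature legalization and `return` legalization can be sketched roughly as follows. This is a simplified stand-in, not the code in this change: `FuncInfo`, `Context`, `legalizeSignature`, and `legalizeReturn` are hypothetical names, functions are keyed by name instead of by IR node, and "emitted code" is just a list of strings.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical record of how one function's result type was legalized.
struct FuncInfo
{
    // `out` parameters introduced for the result, in declaration order.
    std::vector<std::string> resultParams;
};

// Hypothetical legalization context shared by both steps.
struct Context
{
    std::unordered_map<std::string, FuncInfo> funcToInfo;
};

// Step 1: while legalizing the signature, remember the added parameters.
void legalizeSignature(
    Context& ctx,
    const std::string& func,
    const std::vector<std::string>& opaqueFields)
{
    FuncInfo info;
    for (const auto& field : opaqueFields)
        info.resultParams.push_back("result_" + field);
    ctx.funcToInfo[func] = info;
}

// Step 2: while legalizing a `return`, look the parameters up again and
// write each opaque value to the next parameter in order, then return
// the ordinary part.
std::vector<std::string> legalizeReturn(
    const Context& ctx,
    const std::string& func,
    const std::vector<std::string>& opaqueVals,
    const std::string& ordinaryVal)
{
    const FuncInfo& info = ctx.funcToInfo.at(func);
    // Both traversals must visit the opaque leaves in the same order.
    assert(opaqueVals.size() == info.resultParams.size());

    std::vector<std::string> code;
    for (std::size_t i = 0; i < opaqueVals.size(); ++i)
        code.push_back(info.resultParams[i] + " = " + opaqueVals[i] + ";");
    code.push_back("return " + ordinaryVal + ";");
    return code;
}
```

The sketch relies on the same ordering assumption as the description above: the stores are matched to parameters purely by position, so the two steps have to walk the result type in the same order.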
Testing
=======
In order to confirm what features do/don't work, I added a set of tests that cover a cross-product of opaque type use cases:
* The opaque type can be used in the function result type, an `out` parameter, or an `inout` parameter
* The opaque type can be used "directly" or nested inside a `struct`.
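As a rough illustration of the cross-product above, the combinations can be enumerated like this. The labels are hypothetical and are not the names of the actual test files; the sketch only shows that three positions times two nesting modes gives six cases.

```cpp
#include <string>
#include <vector>

// Enumerate every (position, nesting) combination for an opaque-type output.
// The label strings are illustrative only.
std::vector<std::string> enumerateTestCases()
{
    const std::vector<std::string> positions = {"result", "out", "inout"};
    const std::vector<std::string> nestings = {"direct", "struct"};

    std::vector<std::string> cases;
    for (const auto& position : positions)
        for (const auto& nesting : nestings)
            cases.push_back(position + "-" + nesting);
    return cases;
}
```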
These tests are helpful to make sure we handle the most important cases, but it is worth noting that the coverage is still lacking in that we do not sufficiently test all the options for what the function body might do. An opaque-type function result could be derived from many different sources:
* It could be a global shader parameter
* It could be an `in` or `inout` parameter of the function itself
* It could be wrapped up in one or more structure types
* It could be wrapped up in one or more array types (such that the output of specialization needs to pass around array indices)
* It could involve use of the type as a local variable (including passing it into other functions with result/`out`/`inout` outputs of opaque types)
This change makes it so that we can handle the simplest cases involving result/return types with a wrapper `struct`, and adds test cases that confirm we handle several other cases for `out` and `inout` parameters. Gaining confidence that we cover all the cases that arise in practical shaders will require more work in follow-up changes.
15 files changed, 1095 insertions, 121 deletions
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 1cd94decb..d88166f83 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -70,7 +70,7 @@ LegalVal LegalVal::implicitDeref(LegalVal const& val) return result; } -LegalVal LegalVal::getImplicitDeref() +LegalVal LegalVal::getImplicitDeref() const { SLANG_ASSERT(flavor == Flavor::implicitDeref); return as<ImplicitDerefVal>(obj)->val; @@ -194,90 +194,515 @@ static LegalVal legalizeOperand( return LegalVal::simple(irValue); } -static void getArgumentValues( - List<IRInst*> & instArgs, - LegalVal val) + /// Helper type for legalization an IR `call` instruction +struct LegalCallBuilder { - switch (val.flavor) + LegalCallBuilder( + IRTypeLegalizationContext* context, + IRCall* call) + : m_context(context) + , m_call(call) + {} + + /// The context for legalization + IRTypeLegalizationContext* m_context = nullptr; + + /// The `call` instruction we are legalizing + IRCall* m_call = nullptr; + + /// The legalized arguments for the call + List<IRInst*> m_args; + + /// Add a logical argument to the call (which may map to zero or mmore actual arguments) + void addArg( + LegalVal const& val) { - case LegalVal::Flavor::none: - break; + // In order to add the argument(s) for `val`, + // we will recurse over its structure. 
- case LegalVal::Flavor::simple: - instArgs.add(val.getSimple()); - break; + switch (val.flavor) + { + case LegalVal::Flavor::none: + break; - case LegalVal::Flavor::implicitDeref: - getArgumentValues(instArgs, val.getImplicitDeref()); - break; + case LegalVal::Flavor::simple: + m_args.add(val.getSimple()); + break; - case LegalVal::Flavor::pair: + case LegalVal::Flavor::implicitDeref: + addArg(val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = val.getPair(); + addArg(pairVal->ordinaryVal); + addArg(pairVal->specialVal); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tuplePsuedoVal = val.getTuple(); + for (auto elem : val.getTuple()->elements) + { + addArg(elem.val); + } + } + break; + + default: + SLANG_UNEXPECTED("uhandled val flavor"); + break; + } + } + + /// Build a new call based on the original, given the expected `resultType`. + /// + /// Returns a value representing the result of the call. + LegalVal build(LegalType const& resultType) + { + // We can recursively decompose the cases for + // how to legalize a call based on the expected + // result type. + // + switch (resultType.flavor) { - auto pairVal = val.getPair(); - getArgumentValues(instArgs, pairVal->ordinaryVal); - getArgumentValues(instArgs, pairVal->specialVal); + case LegalType::Flavor::simple: + // In the case where the result type is simple, + // we can directly emit the `call` instruction + // and use the result as our single result value. + // + return LegalVal::simple(_emitCall(resultType.getSimple())); + + case LegalType::Flavor::none: + // In the case where there is no result type, + // that is equivalent to the call returning `void`. + // + // We directly emit the call and then return an + // empty value to represent the result. 
+ // + _emitCall(m_context->builder->getVoidType()); + return LegalVal(); + + case LegalVal::Flavor::implicitDeref: + // An `implicitDeref` wraps a single value, so we can simply + // unwrap, recurse on the innter value, and then wrap up + // the result. + // + return LegalVal::implicitDeref(build(resultType.getImplicitDeref()->valueType)); + + case LegalVal::Flavor::pair: + { + // A `pair` type consists of both an ordinary part and a special part. + // + auto pairType = resultType.getPair(); + + // The ordinary part will be used as the direct result of the call, + // while the special part will need to be returned via an `out` + // argument. + // + // We will start by emitting the declaration(s) needed for those + // `out` arguments that represent the special part, and adding + // them to the argument list. Basically this step is declaring + // local variables that will hold the special part of the result, + // and it returns a value that repsents those variables. + // + auto specialVal = _addOutArg(pairType->specialType); + + // Once the argument values for the special part are set up, + // we can recurse on the ordinary part and emit the actual + // call operation (which will include our new arguments). + // + auto ordinaryVal = build(pairType->ordinaryType); + + // The resulting value will be a pair of the ordinary value + // (returned from the `call` instruction) and the special value + // (declared as zero or more local variables). + // + RefPtr<PairPseudoVal> pairVal = new PairPseudoVal(); + pairVal->pairInfo = pairType->pairInfo; + pairVal->ordinaryVal = ordinaryVal; + pairVal->specialVal = specialVal; + return LegalVal::pair(pairVal); + } + break; + + case LegalVal::Flavor::tuple: + { + // A `tuple` value consists of zero or more elements + // that are each of a "special" type. We will handle + // *all* of those values as `out` arguments akin to + // what we did for the special half of a pair type + // above. 
+ // + auto resultVal = _addOutArg(resultType); + + // In this case there was no "ordinary" part to the + // result type of the function, so we know that + // the legalization funciton/call will use a `void` + // result type. + // + _emitCall(m_context->builder->getVoidType()); + return resultVal; + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRCall."); } - break; + } - case LegalVal::Flavor::tuple: +private: + + /// Add an `out` argument to the call, to capture the given `resultType`. + LegalVal _addOutArg(LegalType const& resultType) + { + switch (resultType.flavor) { - auto tuplePsuedoVal = val.getTuple(); - for (auto elem : val.getTuple()->elements) + case LegalType::Flavor::simple: + { + // In the leaf case we have a simple type, and + // we just want to declare a local variable based on it. + // + auto simpleType = resultType.getSimple(); + auto builder = m_context->builder; + + // Recall that a local variable in our IR represents a *pointer* + // to storage of the appropriate type. + // + auto varPtr = builder->emitVar(simpleType); + + // We need to pass that pointer as an argument to our new + // `call` instruction, so that it can receive the value + // written by the callee. + // + m_args.add(varPtr); + + // Note: Because `varPtr` is a pointer to the value we want, + // we have the small problem of needing to return a `LegalVal` + // that has dereferenced the value after the call. + // + // We solve this problem by inserting as `load` from our + // new variable immediately after the call, before going + // and resetting the insertion point to continue inserting + // stuff before the call (which is where we wnat the local + // variable declarations to go). + // + // TODO: Confirm that this logic can't go awry if (somehow) + // there is no instruction after `m_call`. 
That should not + // be possible inside of a function body, but it could in + // theory be a problem if we ever have top-level module-scope + // code representing initialization of constants and/or globals. + // + builder->setInsertBefore(m_call->getNextInst()); + auto val = builder->emitLoad(simpleType, varPtr); + builder->setInsertBefore(m_call); + + return LegalVal::simple(val); + } + break; + + // The remaining cases are a straightforward structural recursion + // on top of the base case above. + + case LegalType::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::implicitDeref: + return LegalVal::implicitDeref(_addOutArg(resultType.getImplicitDeref()->valueType)); + + case LegalVal::Flavor::pair: { - getArgumentValues(instArgs, elem.val); + auto pairType = resultType.getPair(); + auto specialVal = _addOutArg(pairType->specialType); + auto ordinaryVal = _addOutArg(pairType->ordinaryType); + + RefPtr<PairPseudoVal> pairVal = new PairPseudoVal(); + pairVal->pairInfo = pairType->pairInfo; + pairVal->ordinaryVal = ordinaryVal; + pairVal->specialVal = specialVal; + + return LegalVal::pair(pairVal); } + break; + + case LegalVal::Flavor::tuple: + { + auto tuplePsuedoType = resultType.getTuple(); + + RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal(); + for (auto typeElement : tuplePsuedoType->elements) + { + TuplePseudoVal::Element valElement; + valElement.key = typeElement.key; + valElement.val = _addOutArg(typeElement.type); + tupleVal->elements.add(valElement); + } + + return LegalVal::tuple(tupleVal); + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRCall."); } - break; + } - default: - SLANG_UNEXPECTED("uhandled val flavor"); - break; + /// Emit the actual `call` instruction given an IR result type + IRInst* _emitCall(IRType* resultType) + { + // The generated call will include all of the arguments that have + // been added up to this point, which 
includes those that were + // added to represent legalized parts of the result type. + // + return m_context->builder->emitCallInst( + resultType, + m_call->getCallee(), + m_args.getCount(), + m_args.getBuffer()); } -} +}; + static LegalVal legalizeCall( IRTypeLegalizationContext* context, IRCall* callInst) { - auto retType = legalizeType(context, callInst->getFullType()); - IRType* retIRType = nullptr; - switch (retType.flavor) + LegalCallBuilder builder(context, callInst); + + auto argCount = callInst->getArgCount(); + for( UInt i = 0; i < argCount; i++ ) { - case LegalType::Flavor::simple: - retIRType = retType.getSimple(); - break; - case LegalType::Flavor::none: - retIRType = context->builder->getVoidType(); - break; - default: - // TODO: implement legalization of non-simple return types - SLANG_UNEXPECTED("unimplemented legalized return type for IRInstCall."); + auto legalArg = legalizeOperand(context, callInst->getArg(i)); + builder.addArg(legalArg); } - List<IRInst*> instArgs; - for (auto i = 1u; i < callInst->getOperandCount(); i++) - getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i))); - - return LegalVal::simple(context->builder->emitCallInst( - retIRType, - callInst->getCallee(), - instArgs.getCount(), - instArgs.getBuffer())); + auto legalResultType = legalizeType(context, callInst->getFullType()); + return builder.build(legalResultType); } -static LegalVal legalizeRetVal(IRTypeLegalizationContext* context, - LegalVal retVal) + /// Helper type for legalizing a `returnVal` instruction +struct LegalReturnBuilder { - switch (retVal.flavor) + LegalReturnBuilder(IRTypeLegalizationContext* context, IRReturn* returnInst) + : m_context(context) + , m_returnInst(returnInst) + {} + + /// Emit code to perform a return of `val` + void returnVal(LegalVal val) { - case LegalVal::Flavor::simple: - return LegalVal::simple(context->builder->emitReturn(retVal.getSimple())); - case LegalVal::Flavor::none: - return 
LegalVal::simple(context->builder->emitReturn()); - default: - // TODO: implement legalization of non-simple return types - SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + auto builder = m_context->builder; + + switch (val.flavor) + { + case LegalVal::Flavor::simple: + // The case of a simple value is easy: just emit a `returnVal`. + // + builder->emitReturn(val.getSimple()); + break; + + case LegalVal::Flavor::none: + // The case of an empty/void value is also easy: emit a `return`. + // + builder->emitReturn(); + break; + + case LegalVal::Flavor::implicitDeref: + returnVal(val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + // The case for a pair value is the main interesting one. + // We need to write the special part of the return value + // to the `out` parameters that were declared to capture + // it, and then return the ordinary part of the value + // like normal. + // + // Note that the order here matters, because we need to + // emit the code that writes to the `out` parameters + // before the `return` instruction. + // + auto pairVal = val.getPair(); + _writeResultParam(pairVal->specialVal); + returnVal(pairVal->ordinaryVal); + } + break; + + case LegalVal::Flavor::tuple: + { + // The tuple case is kind of a degenerate combination + // of the `pair` and `none` cases: we need to emit + // writes to the `out` parameters declared to capture + // the tuple (all of it), and then we do a `return` + // of `void` because there is no ordinary result to + // capture. 
+ // + _writeResultParam(val); + builder->emitReturn(); + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + } + } + +private: + + /// Write `val` to the `out` parameters of the enclosing function + void _writeResultParam(LegalVal const& val) + { + switch (val.flavor) + { + case LegalVal::Flavor::simple: + { + // The leaf case here is the interesting one. + // + // We know that if we are writing to `out` parameters to + // represent the function result then the function must + // have been legalized in a way that introduced those parameters. + // We thus need to look up the information on how the + // function got legalized so that we can identify the + // new parameters. + // + // TODO: One detail worth confirming here is whether there + // could ever be a case where a `return` instruction gets legalized + // before its outer function does. + // + if( !m_parentFuncInfo ) + { + // We start by searching for the ancestor instruction + // that represents the function (or other code-bearing value) + // that holds this instruction. + // + auto p = m_returnInst->getParent(); + while( p && !as<IRGlobalValueWithCode>(p) ) + { + p = p->parent; + } + + // We expect that the parent is actually an IR function. + // + // TODO: What about the case where we have an `IRGlobalVar` + // of a type that needs legalization, and teh variable has + // an initializer? For now, I believe that case is disallowed + // in the legalization for global variables. + // + auto parentFunc = as<IRFunc>(p); + SLANG_ASSERT(parentFunc); + if(!parentFunc) + return; + + // We also expect that extended legalization information was + // recorded for the function. 
+ // + RefPtr<LegalFuncInfo> parentFuncInfo; + if( !m_context->mapFuncToInfo.TryGetValue(parentFunc, parentFuncInfo) ) + { + // If we fail to find the extended information then either: + // + // * The parent function has not been legalized yet. This would + // be a violation of our assumption about ordering of legalization. + // + // * The parent function was legalized, but didn't require any + // additional IR parameters to represent its result. This would + // be a violation of our assumption that the declared result type + // of a function and the type at `return` sites inside the function + // need to match. + // + SLANG_ASSERT(parentFuncInfo); + return; + } + + // If we find the extended information, then this is the first + // leaf parameter we are dealing with, so we set up to read through + // the parameters starting at index zero. + // + m_parentFuncInfo = parentFuncInfo; + m_resultParamCounter = 0; + } + SLANG_ASSERT(m_parentFuncInfo); + + // The recursion through the result `val` will iterate over the + // leaf parameters in the same order they should have been declared, + // so the parameter we need to write to will be the next one in order. + // + // We expect that the parameter index must be in range, beacuse otherwise + // the recursion here and the recursion that declared the parameters are + // mismatched in terms of how they traversed the hierarchical representation + // of `LegalVal` / `LegalType`. + // + Index resultParamIndex = m_resultParamCounter++; + SLANG_ASSERT(resultParamIndex >= 0); + SLANG_ASSERT(resultParamIndex < m_parentFuncInfo->resultParamVals.getCount()); + + // Once we've identified the right parameter, we can emit a `store` + // to write the value that the function wants to output. + // + // Note that an `out` parameter is represented with a pointer type + // in the IR, so that the `IRParam` here represents a pointer to + // the value that will receive the result. 
+ // + auto resultParamPtr = m_parentFuncInfo->resultParamVals[resultParamIndex]; + m_context->builder->emitStore(resultParamPtr, val.getSimple()); + } + break; + + // The remaining cases are just a straightforward recursion + // over the structure of the `val`. + + case LegalVal::Flavor::none: + break; + + case LegalVal::Flavor::implicitDeref: + _writeResultParam(val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = val.getPair(); + _writeResultParam(pairVal->ordinaryVal); + _writeResultParam(pairVal->specialVal); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tupleVal = val.getTuple(); + for (auto element : tupleVal->elements) + { + _writeResultParam(element.val); + } + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + } } + + IRTypeLegalizationContext* m_context = nullptr; + IRReturn* m_returnInst = nullptr; + + RefPtr<LegalFuncInfo> m_parentFuncInfo; + Index m_resultParamCounter = 0; +}; + +static LegalVal legalizeRetVal( + IRTypeLegalizationContext* context, + LegalVal retVal, + IRReturnVal* returnInst) +{ + LegalReturnBuilder builder(context, returnInst); + builder.returnVal(retVal); + return LegalVal(); } static LegalVal legalizeLoad( @@ -1232,7 +1657,7 @@ static LegalVal legalizeInst( case kIROp_Call: return legalizeCall(context, (IRCall*)inst); case kIROp_ReturnVal: - return legalizeRetVal(context, args[0]); + return legalizeRetVal(context, args[0], (IRReturnVal*)inst); case kIROp_makeStruct: return legalizeMakeStruct( context, @@ -1487,82 +1912,315 @@ static LegalVal legalizeInst( return legalVal; } -static void addParamType(List<IRType*>& ioParamTypes, LegalType t) + /// Helper type for legalizing the signature of an `IRFunc` +struct LegalFuncBuilder { - switch (t.flavor) + LegalFuncBuilder(IRTypeLegalizationContext* context) + : m_context(context) + {} + + /// Construct a legalized value to 
represent `oldFunc` + LegalVal build(IRFunc* oldFunc) { - case LegalType::Flavor::none: - break; + // We can start by computing what the type signature of the + // legalized function should be, based on the type signature + // of the original. + // + IRFuncType* oldFuncType = oldFunc->getDataType(); - case LegalType::Flavor::simple: - ioParamTypes.add(t.getSimple()); - break; + // Each parameter of the original function will translate into + // zero or more parameters in the legalized function signature. + // + UInt oldParamCount = oldFuncType->getParamCount(); + for (UInt pp = 0; pp < oldParamCount; ++pp) + { + auto legalParamType = legalizeType(m_context, oldFuncType->getParamType(pp)); + _addParam(legalParamType); + } - case LegalType::Flavor::implicitDeref: - { - auto imp = t.getImplicitDeref(); - addParamType(ioParamTypes, imp->valueType); - break; - } - case LegalType::Flavor::pair: + // We will record how many parameters resulted from + // legalization of the original / "base" parameter list. + // This number will help us in computing how many parameters + // were added to capture the result type of the function. + // + Index baseLegalParamCount = m_paramTypes.getCount(); + + // Next we add a result type to the function based on the + // legalized result type of the original function. + // + // It is possible that this process will had one or more + // `out` parameters to represent parts of the result type + // that couldn't be passed via the ordinary function result. + // + auto legalResultType = legalizeType(m_context, oldFuncType->getResultType()); + _addResult(legalResultType); + + // If any part of the result type required new function parameters + // to be introduced, then we want to know how many there were. + // These additional function paameters will always come after the original + // parameters, so that they don't shift around call sites too much. 
+ // + // TODO: Where we put the added `out` parameters in the signature may + // have performance implications when it starts interacting with ABI + // (e.g., most ABIs assign parameters to registers from left to right, + // so parameters later in the list are more likely to be passed through + // memory; we'd need to decide whether the base parameters or the + // legalized result parameters should be prioritized for register + // allocation). + // + Index resultParamCount = m_paramTypes.getCount() - baseLegalParamCount; + + // If we didn't bottom out on a result type for the legalized function, + // then we should default to returning `void`. + // + auto irBuilder = m_context->builder; + if( !m_resultType ) { - auto pairInfo = t.getPair(); - addParamType(ioParamTypes, pairInfo->ordinaryType); - addParamType(ioParamTypes, pairInfo->specialType); + m_resultType = irBuilder->getVoidType(); } - break; - case LegalType::Flavor::tuple: - { - auto tup = t.getTuple(); - for (auto & elem : tup->elements) - addParamType(ioParamTypes, elem.type); - } - break; - default: - SLANG_UNEXPECTED("unknown legalized type flavor"); + + // We will compute the new IR type for the function and install it + // as the data type of original function. + // + // Note: This is one of the few cases where the legalization pass + // prefers to modify an IR node in-place rather than create a distinct + // legalized copy of it. + // + auto newFuncType = irBuilder->getFuncType( + m_paramTypes.getCount(), + m_paramTypes.getBuffer(), + m_resultType); + irBuilder->setDataType(oldFunc, newFuncType); + + // If the function required any new parameters to be created + // to represent the result/return type, then we need to + // actually add the appropriate IR parameters to represent + // that stuff as well. + // + if( resultParamCount != 0 ) + { + // Only a function with a body will need this additonal + // step, since the function parameters are stored on the + // first block of the body. 
+ // + auto firstBlock = oldFunc->getFirstBlock(); + if( firstBlock ) + { + // Because legalization of this function required us + // to introduce new parameters, we need to allocate + // a data structure to record the identities of those + // new parameters so that they can be looked up when + // legalizing the body of the function. + // + // In particular, we will use this information when + // legalizing `return` instructions in the function body, + // since those will need to store at least part of + // the reuslt value into the newly-declared parameter(s). + // + RefPtr<LegalFuncInfo> funcInfo = new LegalFuncInfo(); + m_context->mapFuncToInfo.Add(oldFunc, funcInfo); + + // We know that our new parameters need to come after + // those that were declared for the "base" parameters + // of the original function. + // + auto firstResultParamIndex = baseLegalParamCount; + auto firstOrdinaryInst = firstBlock->getFirstOrdinaryInst(); + for( Index i = 0; i < resultParamCount; ++i ) + { + // Note: The parameter types that were added to + // the `m_paramTypes` array already account for the + // fact that these are `out` parameters, since that + // impacts the function type signature as well. + // We do *not* need to wrap `paramType` in an `Out<...>` + // type here. + // + auto paramType = m_paramTypes[firstResultParamIndex + i]; + auto param = irBuilder->createParam(paramType); + param->insertBefore(firstOrdinaryInst); + + funcInfo->resultParamVals.add(param); + } + } + } + + // Note: at this point we do *not* apply legalization to the parameters + // of the function or its body; those are left for the recursive part + // of the overall legalization pass to handle. + + return LegalVal::simple(oldFunc); } -} -static LegalVal legalizeFunc( - IRTypeLegalizationContext* context, - IRFunc* irFunc) -{ - // Overwrite the function's type with the result of legalization. 
- IRFuncType* oldFuncType = irFunc->getDataType(); - UInt oldParamCount = oldFuncType->getParamCount(); +private: + IRTypeLegalizationContext* m_context = nullptr;; + + /// The types of the parameters of the legalized function + List<IRType*> m_paramTypes; + + /// The result type of the legalized function (can be null to represent `void`) + IRType* m_resultType = nullptr; - // TODO: we should give an error message when the result type of a function - // can't be legalized (e.g., trying to return a texture, or a structue that - // contains one). - auto legalReturnType = legalizeType(context, oldFuncType->getResultType()); - IRType* newResultType = nullptr; - switch (legalReturnType.flavor) + /// Add a parameter of type `t` to the function signature + void _addParam(LegalType t) { - case LegalType::Flavor::simple: - newResultType = legalReturnType.getSimple(); - break; - case LegalType::Flavor::none: - newResultType = context->builder->getVoidType(); - break; - default: - SLANG_UNEXPECTED("unknown legalized function return type."); + // This logic is a simple recursion over the structure of `t`, + // with the leaf case adding parameters of simple IR type. 
+ + switch (t.flavor) + { + case LegalType::Flavor::none: + break; + + case LegalType::Flavor::simple: + m_paramTypes.add(t.getSimple()); + break; + + case LegalType::Flavor::implicitDeref: + { + auto imp = t.getImplicitDeref(); + _addParam(imp->valueType); + } + break; + case LegalType::Flavor::pair: + { + auto pairInfo = t.getPair(); + _addParam(pairInfo->ordinaryType); + _addParam(pairInfo->specialType); + } + break; + case LegalType::Flavor::tuple: + { + auto tup = t.getTuple(); + for (auto & elem : tup->elements) + _addParam(elem.type); + } + break; + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } } - List<IRType*> newParamTypes; - for (UInt pp = 0; pp < oldParamCount; ++pp) + + /// Set the logical result type of the legalized function to `t` + void _addResult(LegalType t) { - auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp)); - addParamType(newParamTypes, legalParamType); + switch (t.flavor) + { + case LegalType::Flavor::simple: + // The simple case is when the result type is a simple IR + // type, and we can use it directly as the return type. + // + m_resultType = t.getSimple(); + break; + + + case LegalType::Flavor::none: + // The case where we have no result type is also simple, + // becaues we can leave `m_resultType` as null to represent + // a `void` result type. + break; + + case LegalType::Flavor::implicitDeref: + { + // An `implicitDeref` is a wrapper around another legal + // type, so we can simply set the result type to the + // unwrapped inner type. + // + auto imp = t.getImplicitDeref(); + _addResult(imp->valueType); + } + break; + + case LegalType::Flavor::pair: + { + // The `pair` case is the first interesting one. + // + // We will set the actual result type of the operation + // to the ordinary side of the pair, while any special + // part of the pair will be returned via fresh `out` + // parameters insteqad. 
+ // + auto pairInfo = t.getPair(); + _addResult(pairInfo->ordinaryType); + _addOutParam(pairInfo->specialType); + } + break; + + case LegalType::Flavor::tuple: + { + // In the `tuple` case we have zero or more types, + // and there is no distinguished primary one that + // should become the result type of the legalized function. + // + // We will instead declare fresh `out` parameters to + // capture all the outputs in the tuple. + // + auto tup = t.getTuple(); + for( auto & elem : tup->elements ) + { + _addOutParam(elem.type); + } + } + break; + + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } } - auto newFuncType = context->builder->getFuncType( - newParamTypes.getCount(), - newParamTypes.getBuffer(), - newResultType); + /// Add a single `out` parameter based on type `t`. + void _addOutParam(LegalType t) + { + switch (t.flavor) + { + case LegalType::Flavor::simple: + // The simple case here is almost the same as `_addParam()`, + // except we wrap the simple type in `Out<...>` to indicate + // that we are producing an `out` parameter. + // + m_paramTypes.add(m_context->builder->getOutType(t.getSimple())); + break; + + // The remaining cases are all simple recursion on the + // structure of `t`. 
- context->builder->setDataType(irFunc, newFuncType); + case LegalType::Flavor::none: + break; - return LegalVal::simple(irFunc); + case LegalType::Flavor::implicitDeref: + { + auto imp = t.getImplicitDeref(); + _addOutParam(imp->valueType); + } + break; + case LegalType::Flavor::pair: + { + auto pairInfo = t.getPair(); + _addOutParam(pairInfo->ordinaryType); + _addOutParam(pairInfo->specialType); + } + break; + case LegalType::Flavor::tuple: + { + auto tup = t.getTuple(); + for( auto & elem : tup->elements ) + { + _addOutParam(elem.type); + } + } + break; + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } + } +}; + +static LegalVal legalizeFunc( + IRTypeLegalizationContext* context, + IRFunc* irFunc) +{ + LegalFuncBuilder builder(context); + return builder.build(irFunc); } static LegalVal declareSimpleVar( diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 23e68182e..c7398fe23 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -171,7 +171,7 @@ struct ResourceOutputSpecializationPass oldFunc, newFunc); - // At first `newFunc` is a directclone of `oldFunc`, and thus doesn't + // At first `newFunc` is a direct clone of `oldFunc`, and thus doesn't // solve any of our problems. We will traverse `oldFunc` and specialize // it as needed, while also collecting information that will allow // us to rewrite call sites. @@ -468,8 +468,29 @@ struct ResourceOutputSpecializationPass // // Any failures along the way cause the whole process to fail. - for( auto param : func->getParams() ) + // Note: We are introducing new parameters at the same time as we + // iterate over the parameter list, so we cannot just use the + // `func->getParams()` convenience accessor. Instead, we manually + // iterate over the parameters in a way that avoids invalidation + // if we remove the `param` we are working on. 
+ // + // Note: it might seem odd that we are modifying `func` but will + // still bail out on any errors. You might ask: isn't there a chance + // that we will end up with the function in a partially-modified state? + // + // The important thing to remember is that `func` is a *copy* of the + // original function, so any modifications we make to it do not + // affect the original, so that if we *do* have to bail out we can + // leave any call sites intact as calls to the original. The result + // is that bailing out here may leave the new/copied function in + // a state where it isn't useful, but it also won't have any uses, + // and can be eliminated later. + // + IRParam* nextParam = nullptr; + for( IRParam* param = func->getFirstParam(); param; param = nextParam ) { + nextParam = param->getNextParam(); + ParamInfo paramInfo; SLANG_RETURN_ON_FAIL(maybeSpecializeParam(param, paramInfo, outFuncInfo)); outFuncInfo.oldParams.add(paramInfo); diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h index 8f2a7572f..600f1d7a7 100644 --- a/source/slang/slang-legalize-types.h +++ b/source/slang/slang-legalize-types.h @@ -509,7 +509,7 @@ struct LegalVal } static LegalVal implicitDeref(LegalVal const& val); - LegalVal getImplicitDeref(); + LegalVal getImplicitDeref() const; static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); static LegalVal pair( @@ -568,6 +568,30 @@ struct WrappedBufferPseudoVal : LegalValImpl // + /// Information about a function that has been legalized + /// + /// This type is used to track any information about the function + /// and its signature that might be relevant to the legalization + /// of instructions inside the function body. + /// +struct LegalFuncInfo : RefObject +{ + /// Any parameters that were added to the function signature + /// to represent the function result after legalization. 
+ /// + /// It is possible that the result type of a function needed + /// to be split into multiple types, and as a result a single + /// function result couldn't return all of them. + /// + /// This array is a list of `out` parameters created to represent + /// additional function results. Because they are `out` parameters, + /// each is a *pointer* to a value of the relevant type. + /// + List<IRInst*> resultParamVals; +}; + +// + /// Context that drives type legalization /// /// This type is an abstract base class, and there are @@ -601,6 +625,14 @@ struct IRTypeLegalizationContext Dictionary<IRType*, LegalType> mapTypeToLegalType; + /// Map a function to information about how it was legalized. + /// + /// Note that entries are only created if there is something for them + /// to represent, so many functions may lack entries in this map even + /// after legalization. + /// + Dictionary<IRFunc*, RefPtr<LegalFuncInfo>> mapFuncToInfo; + IRBuilder* getBuilder() { return builder; } /// Customization point to decide what types are "special." 
diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang new file mode 100644 index 000000000..ea94e6ffa --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang @@ -0,0 +1,55 @@ +// inout-param-opaque-type-in-struct.slang + +// Test that a function/method can have an `out` parameter of +// aggregate type that includes an opaque type + +//TEST(compute):COMPARE_COMPUTE: + +struct Things +{ + int first; + RWStructuredBuffer<int> rest; +} + +//TEST_INPUT:set C = new { {1, ubuffer(data=[2 3 4 5], stride=4)}, {6, ubuffer(data=[7 8 9 10], stride=4)} } +cbuffer C +{ + Things gX; + Things gY; +} + +void swap( + inout Things a, + inout Things b) +{ + Things t = a; + a = b; + b = t; +} + +int eval(Things t, int val) +{ + return t.first*256 + t.rest[val]; +} + +int test(int val) +{ + Things f = gX; + Things g = gY; + + swap(f, g); + + return (eval(f,val) << 16) + eval(g,val); +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt new file mode 100644 index 000000000..43533c76f --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +6070102 +6080103 +6090104 +60A0105 diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type.slang b/tests/language-feature/types/opaque/inout-param-opaque-type.slang new file mode 100644 index 000000000..682f89fd0 --- /dev/null +++ 
b/tests/language-feature/types/opaque/inout-param-opaque-type.slang @@ -0,0 +1,42 @@ +// inout-param-opaque-type.slang + +// Test that a function/method can have an `out` parameter of opaque type + +//TEST(compute):COMPARE_COMPUTE: + +//TEST_INPUT:set gX = ubuffer(data=[16 17 18 19], stride=4) +RWStructuredBuffer<int> gX; + +//TEST_INPUT:set gY = ubuffer(data=[3 6 9 12], stride=4) +RWStructuredBuffer<int> gY; + +void swap( + inout RWStructuredBuffer<int> a, + inout RWStructuredBuffer<int> b) +{ + RWStructuredBuffer<int> t = a; + a = b; + b = t; +} + +int test(int val) +{ + RWStructuredBuffer<int> f = gX; + RWStructuredBuffer<int> g = gY; + + swap(f, g); + + return f[val] * 256 + g[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt new file mode 100644 index 000000000..81cf98393 --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt @@ -0,0 +1,4 @@ +310 +611 +912 +C13 diff --git a/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang new file mode 100644 index 000000000..a6c645c01 --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang @@ -0,0 +1,39 @@ +// out-opaque-type-in-struct.slang + +// Test that a function/method can have an `out` parameter of +// aggregate type that includes an opaque type + +//TEST(compute):COMPARE_COMPUTE: + +struct Things +{ + int first; + RWStructuredBuffer<int> rest; +} + +//TEST_INPUT:set gThings = new Things { 1, ubuffer(data=[2 3 4 
5], stride=4) } +ConstantBuffer<Things> gThings; + +void getThings(out Things outThings) +{ + outThings = gThings; +} + +int test(int val) +{ + Things things; + getThings(things); + return things.first * (16 << val) + things.rest[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt new file mode 100644 index 000000000..553843b5d --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +12 +23 +44 +85 diff --git a/tests/language-feature/types/opaque/out-param-opaque-type.slang b/tests/language-feature/types/opaque/out-param-opaque-type.slang new file mode 100644 index 000000000..3ac7c0d6f --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type.slang @@ -0,0 +1,33 @@ +// out-opaque-type.slang + +// Test that a function/method can have an `out` parameter of opaque type + +//TEST(compute):COMPARE_COMPUTE: + +//TEST_INPUT:set gThings = ubuffer(data=[16 17 18 19], stride=4) +RWStructuredBuffer<int> gThings; + + +void getThings(out RWStructuredBuffer<int> things) +{ + things = gThings; +} + +int test(int val) +{ + RWStructuredBuffer<int> t; + getThings(t); + return t[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git 
a/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt new file mode 100644 index 000000000..a0d427709 --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt @@ -0,0 +1,4 @@ +10 +11 +12 +13 diff --git a/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang new file mode 100644 index 000000000..2687af1c3 --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang @@ -0,0 +1,38 @@ +// return-opaque-type-in-struct.slang + +// Test that a function/method can return a value of +// aggregate type that includes an opaque type + +//TEST(compute):COMPARE_COMPUTE: + +struct Things +{ + int first; + RWStructuredBuffer<int> rest; +} + +//TEST_INPUT:set gThings = new Things { 1, ubuffer(data=[2 3 4 5], stride=4) } +ConstantBuffer<Things> gThings; + +Things getThings() +{ + return gThings; +} + +int test(int val) +{ + let things = getThings(); + return things.first * (16 << val) + things.rest[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt new file mode 100644 index 000000000..553843b5d --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +12 +23 +44 +85 diff --git a/tests/language-feature/types/opaque/return-opaque-type.slang b/tests/language-feature/types/opaque/return-opaque-type.slang new file mode 100644 index 
000000000..83d4376ba --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type.slang @@ -0,0 +1,32 @@ +// return-opaque-type.slang + +// Test that a function/method can return a value of an opaque type. + +//TEST(compute):COMPARE_COMPUTE: + +struct Stuff +{ + RWStructuredBuffer<int> things; + + RWStructuredBuffer<int> getThings() { return things; } +} + +//TEST_INPUT:set gStuff = new Stuff { ubuffer(data=[16 17 18 19], stride=4) } +ConstantBuffer<Stuff> gStuff; + +int test(int val) +{ + return gStuff.getThings()[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt new file mode 100644 index 000000000..a0d427709 --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt @@ -0,0 +1,4 @@ +10 +11 +12 +13 |
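Illustrative sketch (not part of the commit): the `_addResult()` / `_addOutParam()` recursion in the `LegalFuncBuilder` hunk above can be modeled outside the compiler roughly as follows. Everything here is a hypothetical stand-in: `ToyLegalType`, `ToySignature`, `addResult`, `addOutParam`, and the use of strings for types are assumptions made for illustration, and none of them is Slang's actual API. Only the shape of the recursion follows the diff: the ordinary side of a `pair` becomes the return type, while special parts and `tuple` elements become fresh `out` parameters.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for `LegalType`; types are modeled as strings.
struct ToyLegalType
{
    enum class Flavor { None, Simple, ImplicitDeref, Pair, Tuple };

    Flavor flavor = Flavor::None;
    std::string simpleName;              // used only by Simple
    std::vector<ToyLegalType> children;  // ImplicitDeref: 1, Pair: 2 (ordinary, special), Tuple: N

    static ToyLegalType simple(std::string name)
    {
        ToyLegalType t;
        t.flavor = Flavor::Simple;
        t.simpleName = std::move(name);
        return t;
    }
    static ToyLegalType implicitDeref(ToyLegalType inner)
    {
        ToyLegalType t;
        t.flavor = Flavor::ImplicitDeref;
        t.children.push_back(std::move(inner));
        return t;
    }
    static ToyLegalType pair(ToyLegalType ordinary, ToyLegalType special)
    {
        ToyLegalType t;
        t.flavor = Flavor::Pair;
        t.children.push_back(std::move(ordinary));
        t.children.push_back(std::move(special));
        return t;
    }
    static ToyLegalType tuple(std::vector<ToyLegalType> elems)
    {
        ToyLegalType t;
        t.flavor = Flavor::Tuple;
        t.children = std::move(elems);
        return t;
    }
};

// Hypothetical model of a legalized signature.
struct ToySignature
{
    std::vector<std::string> paramTypes;
    std::string resultType = "void";  // stays "void" when nothing simple is returned
};

// Every simple leaf of `t` becomes one `Out<...>` parameter.
void addOutParam(ToySignature& sig, const ToyLegalType& t)
{
    switch (t.flavor)
    {
    case ToyLegalType::Flavor::None:
        break;
    case ToyLegalType::Flavor::Simple:
        sig.paramTypes.push_back("Out<" + t.simpleName + ">");
        break;
    case ToyLegalType::Flavor::ImplicitDeref:
    case ToyLegalType::Flavor::Pair:
    case ToyLegalType::Flavor::Tuple:
        for (const ToyLegalType& child : t.children)
            addOutParam(sig, child);
        break;
    }
}

// The ordinary part of `t` becomes the return type; the rest become `out` parameters.
void addResult(ToySignature& sig, const ToyLegalType& t)
{
    switch (t.flavor)
    {
    case ToyLegalType::Flavor::None:
        break;
    case ToyLegalType::Flavor::Simple:
        sig.resultType = t.simpleName;
        break;
    case ToyLegalType::Flavor::ImplicitDeref:
        assert(t.children.size() == 1);
        addResult(sig, t.children[0]);
        break;
    case ToyLegalType::Flavor::Pair:
        assert(t.children.size() == 2);
        addResult(sig, t.children[0]);    // ordinary side
        addOutParam(sig, t.children[1]);  // special side
        break;
    case ToyLegalType::Flavor::Tuple:
        for (const ToyLegalType& child : t.children)
            addOutParam(sig, child);
        break;
    }
}
```

Under these assumptions, a result type shaped like the `Things` struct from the tests (an `int` field plus an `RWStructuredBuffer<int>` field), modeled as a `pair` of an ordinary struct and a one-element `tuple`, would come out with the ordinary struct as the return type and a single `Out<RWStructuredBuffer<int>>` parameter. The name used for the ordinary struct in such a model is arbitrary.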

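Illustrative sketch (not part of the commit): the `nextParam` loop introduced in `slang-ir-specialize-resources.cpp` is an instance of a common pattern for walking an intrusive linked list while the current node may be unlinked. The version below uses hypothetical `ToyParam` / `ToyParamList` types rather than Slang's `IRParam`, and assumes that removing a node clears its links, which is what would make a naive `param = param->next` step unsafe.

```cpp
#include <cassert>
#include <vector>

// Hypothetical intrusive list node; stands in for an IR parameter.
struct ToyParam
{
    int id = 0;
    ToyParam* prev = nullptr;
    ToyParam* next = nullptr;
};

// Hypothetical non-owning list; nodes are owned by the caller.
struct ToyParamList
{
    ToyParam* first = nullptr;
    ToyParam* last = nullptr;

    void append(ToyParam* p)
    {
        p->prev = last;
        p->next = nullptr;
        if (last) last->next = p; else first = p;
        last = p;
    }

    void remove(ToyParam* p)
    {
        if (p->prev) p->prev->next = p->next; else first = p->next;
        if (p->next) p->next->prev = p->prev; else last = p->prev;
        p->prev = nullptr;  // links are cleared, so `p->next` is no longer usable
        p->next = nullptr;
    }
};

// Visits every parameter once and removes those with an even id.
// Returns the ids in visitation order.
std::vector<int> removeEvenParams(ToyParamList& list)
{
    std::vector<int> visited;
    ToyParam* nextParam = nullptr;
    for (ToyParam* param = list.first; param; param = nextParam)
    {
        // Capture the successor before `param` can be unlinked.
        nextParam = param->next;

        visited.push_back(param->id);
        if (param->id % 2 == 0)
            list.remove(param);
    }
    return visited;
}
```

Capturing the successor first keeps the traversal valid whether or not the current node survives the loop body. In this toy version, nodes appended at the tail during the walk would also be visited, as long as they are appended before the walk reaches the end.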