summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-legalize-types.cpp894
-rw-r--r--source/slang/slang-ir-specialize-resources.cpp25
-rw-r--r--source/slang/slang-legalize-types.h34
-rw-r--r--tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang55
-rw-r--r--tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt4
-rw-r--r--tests/language-feature/types/opaque/inout-param-opaque-type.slang42
-rw-r--r--tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt4
-rw-r--r--tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang39
-rw-r--r--tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt4
-rw-r--r--tests/language-feature/types/opaque/out-param-opaque-type.slang33
-rw-r--r--tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt4
-rw-r--r--tests/language-feature/types/opaque/return-opaque-type-in-struct.slang38
-rw-r--r--tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt4
-rw-r--r--tests/language-feature/types/opaque/return-opaque-type.slang32
-rw-r--r--tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt4
15 files changed, 1095 insertions, 121 deletions
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 1cd94decb..d88166f83 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -70,7 +70,7 @@ LegalVal LegalVal::implicitDeref(LegalVal const& val)
return result;
}
-LegalVal LegalVal::getImplicitDeref()
+LegalVal LegalVal::getImplicitDeref() const
{
SLANG_ASSERT(flavor == Flavor::implicitDeref);
return as<ImplicitDerefVal>(obj)->val;
@@ -194,90 +194,515 @@ static LegalVal legalizeOperand(
return LegalVal::simple(irValue);
}
-static void getArgumentValues(
- List<IRInst*> & instArgs,
- LegalVal val)
+ /// Helper type for legalization an IR `call` instruction
+struct LegalCallBuilder
{
- switch (val.flavor)
+ LegalCallBuilder(
+ IRTypeLegalizationContext* context,
+ IRCall* call)
+ : m_context(context)
+ , m_call(call)
+ {}
+
+ /// The context for legalization
+ IRTypeLegalizationContext* m_context = nullptr;
+
+ /// The `call` instruction we are legalizing
+ IRCall* m_call = nullptr;
+
+ /// The legalized arguments for the call
+ List<IRInst*> m_args;
+
+ /// Add a logical argument to the call (which may map to zero or mmore actual arguments)
+ void addArg(
+ LegalVal const& val)
{
- case LegalVal::Flavor::none:
- break;
+ // In order to add the argument(s) for `val`,
+ // we will recurse over its structure.
- case LegalVal::Flavor::simple:
- instArgs.add(val.getSimple());
- break;
+ switch (val.flavor)
+ {
+ case LegalVal::Flavor::none:
+ break;
- case LegalVal::Flavor::implicitDeref:
- getArgumentValues(instArgs, val.getImplicitDeref());
- break;
+ case LegalVal::Flavor::simple:
+ m_args.add(val.getSimple());
+ break;
- case LegalVal::Flavor::pair:
+ case LegalVal::Flavor::implicitDeref:
+ addArg(val.getImplicitDeref());
+ break;
+
+ case LegalVal::Flavor::pair:
+ {
+ auto pairVal = val.getPair();
+ addArg(pairVal->ordinaryVal);
+ addArg(pairVal->specialVal);
+ }
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ auto tuplePsuedoVal = val.getTuple();
+ for (auto elem : val.getTuple()->elements)
+ {
+ addArg(elem.val);
+ }
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("uhandled val flavor");
+ break;
+ }
+ }
+
+ /// Build a new call based on the original, given the expected `resultType`.
+ ///
+ /// Returns a value representing the result of the call.
+ LegalVal build(LegalType const& resultType)
+ {
+ // We can recursively decompose the cases for
+ // how to legalize a call based on the expected
+ // result type.
+ //
+ switch (resultType.flavor)
{
- auto pairVal = val.getPair();
- getArgumentValues(instArgs, pairVal->ordinaryVal);
- getArgumentValues(instArgs, pairVal->specialVal);
+ case LegalType::Flavor::simple:
+ // In the case where the result type is simple,
+ // we can directly emit the `call` instruction
+ // and use the result as our single result value.
+ //
+ return LegalVal::simple(_emitCall(resultType.getSimple()));
+
+ case LegalType::Flavor::none:
+ // In the case where there is no result type,
+ // that is equivalent to the call returning `void`.
+ //
+ // We directly emit the call and then return an
+ // empty value to represent the result.
+ //
+ _emitCall(m_context->builder->getVoidType());
+ return LegalVal();
+
+ case LegalVal::Flavor::implicitDeref:
+ // An `implicitDeref` wraps a single value, so we can simply
+ // unwrap, recurse on the innter value, and then wrap up
+ // the result.
+ //
+ return LegalVal::implicitDeref(build(resultType.getImplicitDeref()->valueType));
+
+ case LegalVal::Flavor::pair:
+ {
+ // A `pair` type consists of both an ordinary part and a special part.
+ //
+ auto pairType = resultType.getPair();
+
+ // The ordinary part will be used as the direct result of the call,
+ // while the special part will need to be returned via an `out`
+ // argument.
+ //
+ // We will start by emitting the declaration(s) needed for those
+ // `out` arguments that represent the special part, and adding
+ // them to the argument list. Basically this step is declaring
+ // local variables that will hold the special part of the result,
+ // and it returns a value that repsents those variables.
+ //
+ auto specialVal = _addOutArg(pairType->specialType);
+
+ // Once the argument values for the special part are set up,
+ // we can recurse on the ordinary part and emit the actual
+ // call operation (which will include our new arguments).
+ //
+ auto ordinaryVal = build(pairType->ordinaryType);
+
+ // The resulting value will be a pair of the ordinary value
+ // (returned from the `call` instruction) and the special value
+ // (declared as zero or more local variables).
+ //
+ RefPtr<PairPseudoVal> pairVal = new PairPseudoVal();
+ pairVal->pairInfo = pairType->pairInfo;
+ pairVal->ordinaryVal = ordinaryVal;
+ pairVal->specialVal = specialVal;
+ return LegalVal::pair(pairVal);
+ }
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ // A `tuple` value consists of zero or more elements
+ // that are each of a "special" type. We will handle
+ // *all* of those values as `out` arguments akin to
+ // what we did for the special half of a pair type
+ // above.
+ //
+ auto resultVal = _addOutArg(resultType);
+
+ // In this case there was no "ordinary" part to the
+ // result type of the function, so we know that
+ // the legalization funciton/call will use a `void`
+ // result type.
+ //
+ _emitCall(m_context->builder->getVoidType());
+ return resultVal;
+ }
+ break;
+
+ default:
+ // TODO: implement legalization of non-simple return types
+ SLANG_UNEXPECTED("unimplemented legalized return type for IRCall.");
}
- break;
+ }
- case LegalVal::Flavor::tuple:
+private:
+
+ /// Add an `out` argument to the call, to capture the given `resultType`.
+ LegalVal _addOutArg(LegalType const& resultType)
+ {
+ switch (resultType.flavor)
{
- auto tuplePsuedoVal = val.getTuple();
- for (auto elem : val.getTuple()->elements)
+ case LegalType::Flavor::simple:
+ {
+ // In the leaf case we have a simple type, and
+ // we just want to declare a local variable based on it.
+ //
+ auto simpleType = resultType.getSimple();
+ auto builder = m_context->builder;
+
+ // Recall that a local variable in our IR represents a *pointer*
+ // to storage of the appropriate type.
+ //
+ auto varPtr = builder->emitVar(simpleType);
+
+ // We need to pass that pointer as an argument to our new
+ // `call` instruction, so that it can receive the value
+ // written by the callee.
+ //
+ m_args.add(varPtr);
+
+ // Note: Because `varPtr` is a pointer to the value we want,
+ // we have the small problem of needing to return a `LegalVal`
+ // that has dereferenced the value after the call.
+ //
+ // We solve this problem by inserting as `load` from our
+ // new variable immediately after the call, before going
+ // and resetting the insertion point to continue inserting
+ // stuff before the call (which is where we wnat the local
+ // variable declarations to go).
+ //
+ // TODO: Confirm that this logic can't go awry if (somehow)
+ // there is no instruction after `m_call`. That should not
+ // be possible inside of a function body, but it could in
+ // theory be a problem if we ever have top-level module-scope
+ // code representing initialization of constants and/or globals.
+ //
+ builder->setInsertBefore(m_call->getNextInst());
+ auto val = builder->emitLoad(simpleType, varPtr);
+ builder->setInsertBefore(m_call);
+
+ return LegalVal::simple(val);
+ }
+ break;
+
+ // The remaining cases are a straightforward structural recursion
+ // on top of the base case above.
+
+ case LegalType::Flavor::none:
+ return LegalVal();
+
+ case LegalVal::Flavor::implicitDeref:
+ return LegalVal::implicitDeref(_addOutArg(resultType.getImplicitDeref()->valueType));
+
+ case LegalVal::Flavor::pair:
{
- getArgumentValues(instArgs, elem.val);
+ auto pairType = resultType.getPair();
+ auto specialVal = _addOutArg(pairType->specialType);
+ auto ordinaryVal = _addOutArg(pairType->ordinaryType);
+
+ RefPtr<PairPseudoVal> pairVal = new PairPseudoVal();
+ pairVal->pairInfo = pairType->pairInfo;
+ pairVal->ordinaryVal = ordinaryVal;
+ pairVal->specialVal = specialVal;
+
+ return LegalVal::pair(pairVal);
}
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ auto tuplePsuedoType = resultType.getTuple();
+
+ RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal();
+ for (auto typeElement : tuplePsuedoType->elements)
+ {
+ TuplePseudoVal::Element valElement;
+ valElement.key = typeElement.key;
+ valElement.val = _addOutArg(typeElement.type);
+ tupleVal->elements.add(valElement);
+ }
+
+ return LegalVal::tuple(tupleVal);
+ }
+ break;
+
+ default:
+ // TODO: implement legalization of non-simple return types
+ SLANG_UNEXPECTED("unimplemented legalized return type for IRCall.");
}
- break;
+ }
- default:
- SLANG_UNEXPECTED("uhandled val flavor");
- break;
+ /// Emit the actual `call` instruction given an IR result type
+ IRInst* _emitCall(IRType* resultType)
+ {
+ // The generated call will include all of the arguments that have
+ // been added up to this point, which includes those that were
+ // added to represent legalized parts of the result type.
+ //
+ return m_context->builder->emitCallInst(
+ resultType,
+ m_call->getCallee(),
+ m_args.getCount(),
+ m_args.getBuffer());
}
-}
+};
+
static LegalVal legalizeCall(
IRTypeLegalizationContext* context,
IRCall* callInst)
{
- auto retType = legalizeType(context, callInst->getFullType());
- IRType* retIRType = nullptr;
- switch (retType.flavor)
+ LegalCallBuilder builder(context, callInst);
+
+ auto argCount = callInst->getArgCount();
+ for( UInt i = 0; i < argCount; i++ )
{
- case LegalType::Flavor::simple:
- retIRType = retType.getSimple();
- break;
- case LegalType::Flavor::none:
- retIRType = context->builder->getVoidType();
- break;
- default:
- // TODO: implement legalization of non-simple return types
- SLANG_UNEXPECTED("unimplemented legalized return type for IRInstCall.");
+ auto legalArg = legalizeOperand(context, callInst->getArg(i));
+ builder.addArg(legalArg);
}
- List<IRInst*> instArgs;
- for (auto i = 1u; i < callInst->getOperandCount(); i++)
- getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i)));
-
- return LegalVal::simple(context->builder->emitCallInst(
- retIRType,
- callInst->getCallee(),
- instArgs.getCount(),
- instArgs.getBuffer()));
+ auto legalResultType = legalizeType(context, callInst->getFullType());
+ return builder.build(legalResultType);
}
-static LegalVal legalizeRetVal(IRTypeLegalizationContext* context,
- LegalVal retVal)
+ /// Helper type for legalizing a `returnVal` instruction
+struct LegalReturnBuilder
{
- switch (retVal.flavor)
+ LegalReturnBuilder(IRTypeLegalizationContext* context, IRReturn* returnInst)
+ : m_context(context)
+ , m_returnInst(returnInst)
+ {}
+
+ /// Emit code to perform a return of `val`
+ void returnVal(LegalVal val)
{
- case LegalVal::Flavor::simple:
- return LegalVal::simple(context->builder->emitReturn(retVal.getSimple()));
- case LegalVal::Flavor::none:
- return LegalVal::simple(context->builder->emitReturn());
- default:
- // TODO: implement legalization of non-simple return types
- SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal.");
+ auto builder = m_context->builder;
+
+ switch (val.flavor)
+ {
+ case LegalVal::Flavor::simple:
+ // The case of a simple value is easy: just emit a `returnVal`.
+ //
+ builder->emitReturn(val.getSimple());
+ break;
+
+ case LegalVal::Flavor::none:
+ // The case of an empty/void value is also easy: emit a `return`.
+ //
+ builder->emitReturn();
+ break;
+
+ case LegalVal::Flavor::implicitDeref:
+ returnVal(val.getImplicitDeref());
+ break;
+
+ case LegalVal::Flavor::pair:
+ {
+ // The case for a pair value is the main interesting one.
+ // We need to write the special part of the return value
+ // to the `out` parameters that were declared to capture
+ // it, and then return the ordinary part of the value
+ // like normal.
+ //
+ // Note that the order here matters, because we need to
+ // emit the code that writes to the `out` parameters
+ // before the `return` instruction.
+ //
+ auto pairVal = val.getPair();
+ _writeResultParam(pairVal->specialVal);
+ returnVal(pairVal->ordinaryVal);
+ }
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ // The tuple case is kind of a degenerate combination
+ // of the `pair` and `none` cases: we need to emit
+ // writes to the `out` parameters declared to capture
+ // the tuple (all of it), and then we do a `return`
+ // of `void` because there is no ordinary result to
+ // capture.
+ //
+ _writeResultParam(val);
+ builder->emitReturn();
+ }
+ break;
+
+ default:
+ // TODO: implement legalization of non-simple return types
+ SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal.");
+ }
+ }
+
+private:
+
+ /// Write `val` to the `out` parameters of the enclosing function
+ void _writeResultParam(LegalVal const& val)
+ {
+ switch (val.flavor)
+ {
+ case LegalVal::Flavor::simple:
+ {
+ // The leaf case here is the interesting one.
+ //
+ // We know that if we are writing to `out` parameters to
+ // represent the function result then the function must
+ // have been legalized in a way that introduced those parameters.
+ // We thus need to look up the information on how the
+ // function got legalized so that we can identify the
+ // new parameters.
+ //
+ // TODO: One detail worth confirming here is whether there
+ // could ever be a case where a `return` instruction gets legalized
+ // before its outer function does.
+ //
+ if( !m_parentFuncInfo )
+ {
+ // We start by searching for the ancestor instruction
+ // that represents the function (or other code-bearing value)
+ // that holds this instruction.
+ //
+ auto p = m_returnInst->getParent();
+ while( p && !as<IRGlobalValueWithCode>(p) )
+ {
+ p = p->parent;
+ }
+
+ // We expect that the parent is actually an IR function.
+ //
+ // TODO: What about the case where we have an `IRGlobalVar`
+ // of a type that needs legalization, and teh variable has
+ // an initializer? For now, I believe that case is disallowed
+ // in the legalization for global variables.
+ //
+ auto parentFunc = as<IRFunc>(p);
+ SLANG_ASSERT(parentFunc);
+ if(!parentFunc)
+ return;
+
+ // We also expect that extended legalization information was
+ // recorded for the function.
+ //
+ RefPtr<LegalFuncInfo> parentFuncInfo;
+ if( !m_context->mapFuncToInfo.TryGetValue(parentFunc, parentFuncInfo) )
+ {
+ // If we fail to find the extended information then either:
+ //
+ // * The parent function has not been legalized yet. This would
+ // be a violation of our assumption about ordering of legalization.
+ //
+ // * The parent function was legalized, but didn't require any
+ // additional IR parameters to represent its result. This would
+ // be a violation of our assumption that the declared result type
+ // of a function and the type at `return` sites inside the function
+ // need to match.
+ //
+ SLANG_ASSERT(parentFuncInfo);
+ return;
+ }
+
+ // If we find the extended information, then this is the first
+ // leaf parameter we are dealing with, so we set up to read through
+ // the parameters starting at index zero.
+ //
+ m_parentFuncInfo = parentFuncInfo;
+ m_resultParamCounter = 0;
+ }
+ SLANG_ASSERT(m_parentFuncInfo);
+
+ // The recursion through the result `val` will iterate over the
+ // leaf parameters in the same order they should have been declared,
+ // so the parameter we need to write to will be the next one in order.
+ //
+ // We expect that the parameter index must be in range, beacuse otherwise
+ // the recursion here and the recursion that declared the parameters are
+ // mismatched in terms of how they traversed the hierarchical representation
+ // of `LegalVal` / `LegalType`.
+ //
+ Index resultParamIndex = m_resultParamCounter++;
+ SLANG_ASSERT(resultParamIndex >= 0);
+ SLANG_ASSERT(resultParamIndex < m_parentFuncInfo->resultParamVals.getCount());
+
+ // Once we've identified the right parameter, we can emit a `store`
+ // to write the value that the function wants to output.
+ //
+ // Note that an `out` parameter is represented with a pointer type
+ // in the IR, so that the `IRParam` here represents a pointer to
+ // the value that will receive the result.
+ //
+ auto resultParamPtr = m_parentFuncInfo->resultParamVals[resultParamIndex];
+ m_context->builder->emitStore(resultParamPtr, val.getSimple());
+ }
+ break;
+
+ // The remaining cases are just a straightforward recursion
+ // over the structure of the `val`.
+
+ case LegalVal::Flavor::none:
+ break;
+
+ case LegalVal::Flavor::implicitDeref:
+ _writeResultParam(val.getImplicitDeref());
+ break;
+
+ case LegalVal::Flavor::pair:
+ {
+ auto pairVal = val.getPair();
+ _writeResultParam(pairVal->ordinaryVal);
+ _writeResultParam(pairVal->specialVal);
+ }
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ auto tupleVal = val.getTuple();
+ for (auto element : tupleVal->elements)
+ {
+ _writeResultParam(element.val);
+ }
+ }
+ break;
+
+ default:
+ // TODO: implement legalization of non-simple return types
+ SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal.");
+ }
}
+
+ IRTypeLegalizationContext* m_context = nullptr;
+ IRReturn* m_returnInst = nullptr;
+
+ RefPtr<LegalFuncInfo> m_parentFuncInfo;
+ Index m_resultParamCounter = 0;
+};
+
+static LegalVal legalizeRetVal(
+ IRTypeLegalizationContext* context,
+ LegalVal retVal,
+ IRReturnVal* returnInst)
+{
+ LegalReturnBuilder builder(context, returnInst);
+ builder.returnVal(retVal);
+ return LegalVal();
}
static LegalVal legalizeLoad(
@@ -1232,7 +1657,7 @@ static LegalVal legalizeInst(
case kIROp_Call:
return legalizeCall(context, (IRCall*)inst);
case kIROp_ReturnVal:
- return legalizeRetVal(context, args[0]);
+ return legalizeRetVal(context, args[0], (IRReturnVal*)inst);
case kIROp_makeStruct:
return legalizeMakeStruct(
context,
@@ -1487,82 +1912,315 @@ static LegalVal legalizeInst(
return legalVal;
}
-static void addParamType(List<IRType*>& ioParamTypes, LegalType t)
+ /// Helper type for legalizing the signature of an `IRFunc`
+struct LegalFuncBuilder
{
- switch (t.flavor)
+ LegalFuncBuilder(IRTypeLegalizationContext* context)
+ : m_context(context)
+ {}
+
+ /// Construct a legalized value to represent `oldFunc`
+ LegalVal build(IRFunc* oldFunc)
{
- case LegalType::Flavor::none:
- break;
+ // We can start by computing what the type signature of the
+ // legalized function should be, based on the type signature
+ // of the original.
+ //
+ IRFuncType* oldFuncType = oldFunc->getDataType();
- case LegalType::Flavor::simple:
- ioParamTypes.add(t.getSimple());
- break;
+ // Each parameter of the original function will translate into
+ // zero or more parameters in the legalized function signature.
+ //
+ UInt oldParamCount = oldFuncType->getParamCount();
+ for (UInt pp = 0; pp < oldParamCount; ++pp)
+ {
+ auto legalParamType = legalizeType(m_context, oldFuncType->getParamType(pp));
+ _addParam(legalParamType);
+ }
- case LegalType::Flavor::implicitDeref:
- {
- auto imp = t.getImplicitDeref();
- addParamType(ioParamTypes, imp->valueType);
- break;
- }
- case LegalType::Flavor::pair:
+ // We will record how many parameters resulted from
+ // legalization of the original / "base" parameter list.
+ // This number will help us in computing how many parameters
+ // were added to capture the result type of the function.
+ //
+ Index baseLegalParamCount = m_paramTypes.getCount();
+
+ // Next we add a result type to the function based on the
+ // legalized result type of the original function.
+ //
+ // It is possible that this process will had one or more
+ // `out` parameters to represent parts of the result type
+ // that couldn't be passed via the ordinary function result.
+ //
+ auto legalResultType = legalizeType(m_context, oldFuncType->getResultType());
+ _addResult(legalResultType);
+
+ // If any part of the result type required new function parameters
+ // to be introduced, then we want to know how many there were.
+ // These additional function paameters will always come after the original
+ // parameters, so that they don't shift around call sites too much.
+ //
+ // TODO: Where we put the added `out` parameters in the signature may
+ // have performance implications when it starts interacting with ABI
+ // (e.g., most ABIs assign parameters to registers from left to right,
+ // so parameters later in the list are more likely to be passed through
+ // memory; we'd need to decide whether the base parameters or the
+ // legalized result parameters should be prioritized for register
+ // allocation).
+ //
+ Index resultParamCount = m_paramTypes.getCount() - baseLegalParamCount;
+
+ // If we didn't bottom out on a result type for the legalized function,
+ // then we should default to returning `void`.
+ //
+ auto irBuilder = m_context->builder;
+ if( !m_resultType )
{
- auto pairInfo = t.getPair();
- addParamType(ioParamTypes, pairInfo->ordinaryType);
- addParamType(ioParamTypes, pairInfo->specialType);
+ m_resultType = irBuilder->getVoidType();
}
- break;
- case LegalType::Flavor::tuple:
- {
- auto tup = t.getTuple();
- for (auto & elem : tup->elements)
- addParamType(ioParamTypes, elem.type);
- }
- break;
- default:
- SLANG_UNEXPECTED("unknown legalized type flavor");
+
+ // We will compute the new IR type for the function and install it
+ // as the data type of original function.
+ //
+ // Note: This is one of the few cases where the legalization pass
+ // prefers to modify an IR node in-place rather than create a distinct
+ // legalized copy of it.
+ //
+ auto newFuncType = irBuilder->getFuncType(
+ m_paramTypes.getCount(),
+ m_paramTypes.getBuffer(),
+ m_resultType);
+ irBuilder->setDataType(oldFunc, newFuncType);
+
+ // If the function required any new parameters to be created
+ // to represent the result/return type, then we need to
+ // actually add the appropriate IR parameters to represent
+ // that stuff as well.
+ //
+ if( resultParamCount != 0 )
+ {
+ // Only a function with a body will need this additonal
+ // step, since the function parameters are stored on the
+ // first block of the body.
+ //
+ auto firstBlock = oldFunc->getFirstBlock();
+ if( firstBlock )
+ {
+ // Because legalization of this function required us
+ // to introduce new parameters, we need to allocate
+ // a data structure to record the identities of those
+ // new parameters so that they can be looked up when
+ // legalizing the body of the function.
+ //
+ // In particular, we will use this information when
+ // legalizing `return` instructions in the function body,
+ // since those will need to store at least part of
+ // the reuslt value into the newly-declared parameter(s).
+ //
+ RefPtr<LegalFuncInfo> funcInfo = new LegalFuncInfo();
+ m_context->mapFuncToInfo.Add(oldFunc, funcInfo);
+
+ // We know that our new parameters need to come after
+ // those that were declared for the "base" parameters
+ // of the original function.
+ //
+ auto firstResultParamIndex = baseLegalParamCount;
+ auto firstOrdinaryInst = firstBlock->getFirstOrdinaryInst();
+ for( Index i = 0; i < resultParamCount; ++i )
+ {
+ // Note: The parameter types that were added to
+ // the `m_paramTypes` array already account for the
+ // fact that these are `out` parameters, since that
+ // impacts the function type signature as well.
+ // We do *not* need to wrap `paramType` in an `Out<...>`
+ // type here.
+ //
+ auto paramType = m_paramTypes[firstResultParamIndex + i];
+ auto param = irBuilder->createParam(paramType);
+ param->insertBefore(firstOrdinaryInst);
+
+ funcInfo->resultParamVals.add(param);
+ }
+ }
+ }
+
+ // Note: at this point we do *not* apply legalization to the parameters
+ // of the function or its body; those are left for the recursive part
+ // of the overall legalization pass to handle.
+
+ return LegalVal::simple(oldFunc);
}
-}
-static LegalVal legalizeFunc(
- IRTypeLegalizationContext* context,
- IRFunc* irFunc)
-{
- // Overwrite the function's type with the result of legalization.
- IRFuncType* oldFuncType = irFunc->getDataType();
- UInt oldParamCount = oldFuncType->getParamCount();
+private:
+ IRTypeLegalizationContext* m_context = nullptr;;
+
+ /// The types of the parameters of the legalized function
+ List<IRType*> m_paramTypes;
+
+ /// The result type of the legalized function (can be null to represent `void`)
+ IRType* m_resultType = nullptr;
- // TODO: we should give an error message when the result type of a function
- // can't be legalized (e.g., trying to return a texture, or a structue that
- // contains one).
- auto legalReturnType = legalizeType(context, oldFuncType->getResultType());
- IRType* newResultType = nullptr;
- switch (legalReturnType.flavor)
+ /// Add a parameter of type `t` to the function signature
+ void _addParam(LegalType t)
{
- case LegalType::Flavor::simple:
- newResultType = legalReturnType.getSimple();
- break;
- case LegalType::Flavor::none:
- newResultType = context->builder->getVoidType();
- break;
- default:
- SLANG_UNEXPECTED("unknown legalized function return type.");
+ // This logic is a simple recursion over the structure of `t`,
+ // with the leaf case adding parameters of simple IR type.
+
+ switch (t.flavor)
+ {
+ case LegalType::Flavor::none:
+ break;
+
+ case LegalType::Flavor::simple:
+ m_paramTypes.add(t.getSimple());
+ break;
+
+ case LegalType::Flavor::implicitDeref:
+ {
+ auto imp = t.getImplicitDeref();
+ _addParam(imp->valueType);
+ }
+ break;
+ case LegalType::Flavor::pair:
+ {
+ auto pairInfo = t.getPair();
+ _addParam(pairInfo->ordinaryType);
+ _addParam(pairInfo->specialType);
+ }
+ break;
+ case LegalType::Flavor::tuple:
+ {
+ auto tup = t.getTuple();
+ for (auto & elem : tup->elements)
+ _addParam(elem.type);
+ }
+ break;
+ default:
+ SLANG_UNEXPECTED("unknown legalized type flavor");
+ }
}
- List<IRType*> newParamTypes;
- for (UInt pp = 0; pp < oldParamCount; ++pp)
+
+ /// Set the logical result type of the legalized function to `t`
+ void _addResult(LegalType t)
{
- auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp));
- addParamType(newParamTypes, legalParamType);
+ switch (t.flavor)
+ {
+ case LegalType::Flavor::simple:
+ // The simple case is when the result type is a simple IR
+ // type, and we can use it directly as the return type.
+ //
+ m_resultType = t.getSimple();
+ break;
+
+
+ case LegalType::Flavor::none:
+ // The case where we have no result type is also simple,
+ // becaues we can leave `m_resultType` as null to represent
+ // a `void` result type.
+ break;
+
+ case LegalType::Flavor::implicitDeref:
+ {
+ // An `implicitDeref` is a wrapper around another legal
+ // type, so we can simply set the result type to the
+ // unwrapped inner type.
+ //
+ auto imp = t.getImplicitDeref();
+ _addResult(imp->valueType);
+ }
+ break;
+
+ case LegalType::Flavor::pair:
+ {
+ // The `pair` case is the first interesting one.
+ //
+ // We will set the actual result type of the operation
+ // to the ordinary side of the pair, while any special
+ // part of the pair will be returned via fresh `out`
+ // parameters insteqad.
+ //
+ auto pairInfo = t.getPair();
+ _addResult(pairInfo->ordinaryType);
+ _addOutParam(pairInfo->specialType);
+ }
+ break;
+
+ case LegalType::Flavor::tuple:
+ {
+ // In the `tuple` case we have zero or more types,
+ // and there is no distinguished primary one that
+ // should become the result type of the legalized function.
+ //
+ // We will instead declare fresh `out` parameters to
+ // capture all the outputs in the tuple.
+ //
+ auto tup = t.getTuple();
+ for( auto & elem : tup->elements )
+ {
+ _addOutParam(elem.type);
+ }
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unknown legalized type flavor");
+ }
}
- auto newFuncType = context->builder->getFuncType(
- newParamTypes.getCount(),
- newParamTypes.getBuffer(),
- newResultType);
+ /// Add a single `out` parameter based on type `t`.
+ void _addOutParam(LegalType t)
+ {
+ switch (t.flavor)
+ {
+ case LegalType::Flavor::simple:
+ // The simple case here is almost the same as `_addParam()`,
+ // except we wrap the simple type in `Out<...>` to indicate
+ // that we are producing an `out` parameter.
+ //
+ m_paramTypes.add(m_context->builder->getOutType(t.getSimple()));
+ break;
+
+ // The remaining cases are all simple recursion on the
+ // structure of `t`.
- context->builder->setDataType(irFunc, newFuncType);
+ case LegalType::Flavor::none:
+ break;
- return LegalVal::simple(irFunc);
+ case LegalType::Flavor::implicitDeref:
+ {
+ auto imp = t.getImplicitDeref();
+ _addOutParam(imp->valueType);
+ }
+ break;
+ case LegalType::Flavor::pair:
+ {
+ auto pairInfo = t.getPair();
+ _addOutParam(pairInfo->ordinaryType);
+ _addOutParam(pairInfo->specialType);
+ }
+ break;
+ case LegalType::Flavor::tuple:
+ {
+ auto tup = t.getTuple();
+ for( auto & elem : tup->elements )
+ {
+ _addOutParam(elem.type);
+ }
+ }
+ break;
+ default:
+ SLANG_UNEXPECTED("unknown legalized type flavor");
+ }
+ }
+};
+
+static LegalVal legalizeFunc(
+ IRTypeLegalizationContext* context,
+ IRFunc* irFunc)
+{
+ LegalFuncBuilder builder(context);
+ return builder.build(irFunc);
}
static LegalVal declareSimpleVar(
diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp
index 23e68182e..c7398fe23 100644
--- a/source/slang/slang-ir-specialize-resources.cpp
+++ b/source/slang/slang-ir-specialize-resources.cpp
@@ -171,7 +171,7 @@ struct ResourceOutputSpecializationPass
oldFunc,
newFunc);
- // At first `newFunc` is a directclone of `oldFunc`, and thus doesn't
+ // At first `newFunc` is a direct clone of `oldFunc`, and thus doesn't
// solve any of our problems. We will traverse `oldFunc` and specialize
// it as needed, while also collecting information that will allow
// us to rewrite call sites.
@@ -468,8 +468,29 @@ struct ResourceOutputSpecializationPass
//
// Any failures along the way cause the whole process to fail.
- for( auto param : func->getParams() )
+ // Note: We are introducing new parameters at the same time as we
+ // iterate over the parameter list, so we cannot just use the
+ // `func->getParams()` convenience accessor. Instead, we manually
+ // iterate over the parameters in a way that avoids invalidation
+ // if we remove the `param` we are working on.
+ //
+ // Note: it might seem odd that we are modifying `func` but will
+ // still bail out on any errors. You might ask: isn't there a chance
+ // that we will end up with the function in a partially-modified state?
+ //
+ // The important thing to remember is that `func` is *copy* of the
+ // original function, so any modifications we make to it do not
+ // affect the original, so that if we *do* have to bail out we can
+ // leave any call sites intact as calls to the original. The result
+ // is that bailing out here may leave the new/copied function in
+ // a state where it isn't useful, but it also won't have any uses,
+ // and can be eliminated later.
+ //
+ IRParam* nextParam = nullptr;
+ for( IRParam* param = func->getFirstParam(); param; param = nextParam )
{
+ nextParam = param->getNextParam();
+
ParamInfo paramInfo;
SLANG_RETURN_ON_FAIL(maybeSpecializeParam(param, paramInfo, outFuncInfo));
outFuncInfo.oldParams.add(paramInfo);
diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h
index 8f2a7572f..600f1d7a7 100644
--- a/source/slang/slang-legalize-types.h
+++ b/source/slang/slang-legalize-types.h
@@ -509,7 +509,7 @@ struct LegalVal
}
static LegalVal implicitDeref(LegalVal const& val);
- LegalVal getImplicitDeref();
+ LegalVal getImplicitDeref() const;
static LegalVal pair(RefPtr<PairPseudoVal> pairInfo);
static LegalVal pair(
@@ -568,6 +568,30 @@ struct WrappedBufferPseudoVal : LegalValImpl
//
+ /// Information about a function that has been legalized
+ ///
+ /// This type is used to track any information about the function
+ /// and its signature that might be relevant to the legalization
+ /// of instructions inside the function body.
+ ///
+struct LegalFuncInfo : RefObject
+{
+ /// Any parameters that were added to the function signature
+ /// to represent the function result after legalization.
+ ///
+ /// It is possible that the result type of a function needed
+ /// to be split into multiple types, and as a result a single
+ /// function result couldn't return all of them.
+ ///
+ /// This array is a list of `out` parameters created to represent
+ /// additional function results. Because they are `out` parameters,
+ /// each is a *pointer* to a value of the relevant type.
+ ///
+ List<IRInst*> resultParamVals;
+};
+
+//
+
/// Context that drives type legalization
///
/// This type is an abstract base class, and there are
@@ -601,6 +625,14 @@ struct IRTypeLegalizationContext
Dictionary<IRType*, LegalType> mapTypeToLegalType;
+ /// Map a function to information about how it was legalized.
+ ///
+ /// Note that entries are only created if there is somehting for them
+ /// to represent, so many functions may lack entries in this map even
+ /// after legalization.
+ ///
+ Dictionary<IRFunc*, RefPtr<LegalFuncInfo>> mapFuncToInfo;
+
IRBuilder* getBuilder() { return builder; }
/// Customization point to decide what types are "special."
diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang
new file mode 100644
index 000000000..ea94e6ffa
--- /dev/null
+++ b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang
@@ -0,0 +1,55 @@
+// inout-param-opaque-type-in-struct.slang
+
+// Test that a function/method can have an `out` parameter of
+// aggregate type that includes an opaque type
+
+//TEST(compute):COMPARE_COMPUTE:
+
+struct Things
+{
+ int first;
+ RWStructuredBuffer<int> rest;
+}
+
+//TEST_INPUT:set C = new { {1, ubuffer(data=[2 3 4 5], stride=4)}, {6, ubuffer(data=[7 8 9 10], stride=4)} }
+cbuffer C
+{
+ Things gX;
+ Things gY;
+}
+
+void swap(
+ inout Things a,
+ inout Things b)
+{
+ Things t = a;
+ a = b;
+ b = t;
+}
+
+int eval(Things t, int val)
+{
+ return t.first*256 + t.rest[val];
+}
+
+int test(int val)
+{
+ Things f = gX;
+ Things g = gY;
+
+ swap(f, g);
+
+ return (eval(f,val) << 16) + eval(g,val);
+}
+
+//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> gOutput;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ gOutput[tid] = outVal;
+}
diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt
new file mode 100644
index 000000000..43533c76f
--- /dev/null
+++ b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt
@@ -0,0 +1,4 @@
+6070102
+6080103
+6090104
+60A0105
diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type.slang b/tests/language-feature/types/opaque/inout-param-opaque-type.slang
new file mode 100644
index 000000000..682f89fd0
--- /dev/null
+++ b/tests/language-feature/types/opaque/inout-param-opaque-type.slang
@@ -0,0 +1,42 @@
+// inout-param-opaque-type.slang
+
+// Test that a function/method can have an `out` parameter of opaque type
+
+//TEST(compute):COMPARE_COMPUTE:
+
+//TEST_INPUT:set gX = ubuffer(data=[16 17 18 19], stride=4)
+RWStructuredBuffer<int> gX;
+
+//TEST_INPUT:set gY = ubuffer(data=[3 6 9 12], stride=4)
+RWStructuredBuffer<int> gY;
+
+void swap(
+ inout RWStructuredBuffer<int> a,
+ inout RWStructuredBuffer<int> b)
+{
+ RWStructuredBuffer<int> t = a;
+ a = b;
+ b = t;
+}
+
+int test(int val)
+{
+ RWStructuredBuffer<int> f = gX;
+ RWStructuredBuffer<int> g = gY;
+
+ swap(f, g);
+
+ return f[val] * 256 + g[val];
+}
+
+//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> gOutput;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ gOutput[tid] = outVal;
+}
diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt
new file mode 100644
index 000000000..81cf98393
--- /dev/null
+++ b/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt
@@ -0,0 +1,4 @@
+310
+611
+912
+C13
diff --git a/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang
new file mode 100644
index 000000000..a6c645c01
--- /dev/null
+++ b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang
@@ -0,0 +1,39 @@
+// out-opaque-type-in-struct.slang
+
+// Test that a function/method can have an `out` parameter of
+// aggregate type that includes an opaque type
+
+//TEST(compute):COMPARE_COMPUTE:
+
+struct Things
+{
+ int first;
+ RWStructuredBuffer<int> rest;
+}
+
+//TEST_INPUT:set gThings = new Things { 1, ubuffer(data=[2 3 4 5], stride=4) }
+ConstantBuffer<Things> gThings;
+
+void getThings(out Things outThings)
+{
+ outThings = gThings;
+}
+
+int test(int val)
+{
+ Things things;
+ getThings(things);
+ return things.first * (16 << val) + things.rest[val];
+}
+
+//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> gOutput;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ gOutput[tid] = outVal;
+}
diff --git a/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt
new file mode 100644
index 000000000..553843b5d
--- /dev/null
+++ b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt
@@ -0,0 +1,4 @@
+12
+23
+44
+85
diff --git a/tests/language-feature/types/opaque/out-param-opaque-type.slang b/tests/language-feature/types/opaque/out-param-opaque-type.slang
new file mode 100644
index 000000000..3ac7c0d6f
--- /dev/null
+++ b/tests/language-feature/types/opaque/out-param-opaque-type.slang
@@ -0,0 +1,33 @@
+// out-opaque-type.slang
+
+// Test that a function/method can have an `out` parameter of opaque type
+
+//TEST(compute):COMPARE_COMPUTE:
+
+//TEST_INPUT:set gThings = ubuffer(data=[16 17 18 19], stride=4)
+RWStructuredBuffer<int> gThings;
+
+
+void getThings(out RWStructuredBuffer<int> things)
+{
+ things = gThings;
+}
+
+int test(int val)
+{
+ RWStructuredBuffer<int> t;
+ getThings(t);
+ return t[val];
+}
+
+//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> gOutput;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ gOutput[tid] = outVal;
+}
diff --git a/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt
new file mode 100644
index 000000000..a0d427709
--- /dev/null
+++ b/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt
@@ -0,0 +1,4 @@
+10
+11
+12
+13
diff --git a/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang
new file mode 100644
index 000000000..2687af1c3
--- /dev/null
+++ b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang
@@ -0,0 +1,38 @@
+// return-opaque-type-in-struct.slang
+
+// Test that a function/method can return a value of
+// aggregate type that includes an opaque type
+
+//TEST(compute):COMPARE_COMPUTE:
+
+struct Things
+{
+ int first;
+ RWStructuredBuffer<int> rest;
+}
+
+//TEST_INPUT:set gThings = new Things { 1, ubuffer(data=[2 3 4 5], stride=4) }
+ConstantBuffer<Things> gThings;
+
+Things getThings()
+{
+ return gThings;
+}
+
+int test(int val)
+{
+ let things = getThings();
+ return things.first * (16 << val) + things.rest[val];
+}
+
+//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> gOutput;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ gOutput[tid] = outVal;
+}
diff --git a/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt
new file mode 100644
index 000000000..553843b5d
--- /dev/null
+++ b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt
@@ -0,0 +1,4 @@
+12
+23
+44
+85
diff --git a/tests/language-feature/types/opaque/return-opaque-type.slang b/tests/language-feature/types/opaque/return-opaque-type.slang
new file mode 100644
index 000000000..83d4376ba
--- /dev/null
+++ b/tests/language-feature/types/opaque/return-opaque-type.slang
@@ -0,0 +1,32 @@
+// return-opaque-type.slang
+
+// Test that a function/method can return a value of an opaque type.
+
+//TEST(compute):COMPARE_COMPUTE:
+
+struct Stuff
+{
+ RWStructuredBuffer<int> things;
+
+ RWStructuredBuffer<int> getThings() { return things; }
+}
+
+//TEST_INPUT:set gStuff = new Stuff { ubuffer(data=[16 17 18 19], stride=4) }
+ConstantBuffer<Stuff> gStuff;
+
+int test(int val)
+{
+ return gStuff.getThings()[val];
+}
+
+//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> gOutput;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ gOutput[tid] = outVal;
+}
diff --git a/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt
new file mode 100644
index 000000000..a0d427709
--- /dev/null
+++ b/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt
@@ -0,0 +1,4 @@
+10
+11
+12
+13