diff options
15 files changed, 1095 insertions, 121 deletions
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 1cd94decb..d88166f83 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -70,7 +70,7 @@ LegalVal LegalVal::implicitDeref(LegalVal const& val) return result; } -LegalVal LegalVal::getImplicitDeref() +LegalVal LegalVal::getImplicitDeref() const { SLANG_ASSERT(flavor == Flavor::implicitDeref); return as<ImplicitDerefVal>(obj)->val; @@ -194,90 +194,515 @@ static LegalVal legalizeOperand( return LegalVal::simple(irValue); } -static void getArgumentValues( - List<IRInst*> & instArgs, - LegalVal val) + /// Helper type for legalization an IR `call` instruction +struct LegalCallBuilder { - switch (val.flavor) + LegalCallBuilder( + IRTypeLegalizationContext* context, + IRCall* call) + : m_context(context) + , m_call(call) + {} + + /// The context for legalization + IRTypeLegalizationContext* m_context = nullptr; + + /// The `call` instruction we are legalizing + IRCall* m_call = nullptr; + + /// The legalized arguments for the call + List<IRInst*> m_args; + + /// Add a logical argument to the call (which may map to zero or mmore actual arguments) + void addArg( + LegalVal const& val) { - case LegalVal::Flavor::none: - break; + // In order to add the argument(s) for `val`, + // we will recurse over its structure. - case LegalVal::Flavor::simple: - instArgs.add(val.getSimple()); - break; + switch (val.flavor) + { + case LegalVal::Flavor::none: + break; - case LegalVal::Flavor::implicitDeref: - getArgumentValues(instArgs, val.getImplicitDeref()); - break; + case LegalVal::Flavor::simple: + m_args.add(val.getSimple()); + break; - case LegalVal::Flavor::pair: + case LegalVal::Flavor::implicitDeref: + addArg(val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = val.getPair(); + addArg(pairVal->ordinaryVal); + addArg(pairVal->specialVal); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tuplePsuedoVal = val.getTuple(); + for (auto elem : val.getTuple()->elements) + { + addArg(elem.val); + } + } + break; + + default: + SLANG_UNEXPECTED("uhandled val flavor"); + break; + } + } + + /// Build a new call based on the original, given the expected `resultType`. + /// + /// Returns a value representing the result of the call. + LegalVal build(LegalType const& resultType) + { + // We can recursively decompose the cases for + // how to legalize a call based on the expected + // result type. + // + switch (resultType.flavor) { - auto pairVal = val.getPair(); - getArgumentValues(instArgs, pairVal->ordinaryVal); - getArgumentValues(instArgs, pairVal->specialVal); + case LegalType::Flavor::simple: + // In the case where the result type is simple, + // we can directly emit the `call` instruction + // and use the result as our single result value. + // + return LegalVal::simple(_emitCall(resultType.getSimple())); + + case LegalType::Flavor::none: + // In the case where there is no result type, + // that is equivalent to the call returning `void`. + // + // We directly emit the call and then return an + // empty value to represent the result. + // + _emitCall(m_context->builder->getVoidType()); + return LegalVal(); + + case LegalVal::Flavor::implicitDeref: + // An `implicitDeref` wraps a single value, so we can simply + // unwrap, recurse on the innter value, and then wrap up + // the result. + // + return LegalVal::implicitDeref(build(resultType.getImplicitDeref()->valueType)); + + case LegalVal::Flavor::pair: + { + // A `pair` type consists of both an ordinary part and a special part. + // + auto pairType = resultType.getPair(); + + // The ordinary part will be used as the direct result of the call, + // while the special part will need to be returned via an `out` + // argument. + // + // We will start by emitting the declaration(s) needed for those + // `out` arguments that represent the special part, and adding + // them to the argument list. Basically this step is declaring + // local variables that will hold the special part of the result, + // and it returns a value that repsents those variables. + // + auto specialVal = _addOutArg(pairType->specialType); + + // Once the argument values for the special part are set up, + // we can recurse on the ordinary part and emit the actual + // call operation (which will include our new arguments). + // + auto ordinaryVal = build(pairType->ordinaryType); + + // The resulting value will be a pair of the ordinary value + // (returned from the `call` instruction) and the special value + // (declared as zero or more local variables). + // + RefPtr<PairPseudoVal> pairVal = new PairPseudoVal(); + pairVal->pairInfo = pairType->pairInfo; + pairVal->ordinaryVal = ordinaryVal; + pairVal->specialVal = specialVal; + return LegalVal::pair(pairVal); + } + break; + + case LegalVal::Flavor::tuple: + { + // A `tuple` value consists of zero or more elements + // that are each of a "special" type. We will handle + // *all* of those values as `out` arguments akin to + // what we did for the special half of a pair type + // above. + // + auto resultVal = _addOutArg(resultType); + + // In this case there was no "ordinary" part to the + // result type of the function, so we know that + // the legalization funciton/call will use a `void` + // result type. + // + _emitCall(m_context->builder->getVoidType()); + return resultVal; + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRCall."); } - break; + } - case LegalVal::Flavor::tuple: +private: + + /// Add an `out` argument to the call, to capture the given `resultType`. + LegalVal _addOutArg(LegalType const& resultType) + { + switch (resultType.flavor) { - auto tuplePsuedoVal = val.getTuple(); - for (auto elem : val.getTuple()->elements) + case LegalType::Flavor::simple: + { + // In the leaf case we have a simple type, and + // we just want to declare a local variable based on it. + // + auto simpleType = resultType.getSimple(); + auto builder = m_context->builder; + + // Recall that a local variable in our IR represents a *pointer* + // to storage of the appropriate type. + // + auto varPtr = builder->emitVar(simpleType); + + // We need to pass that pointer as an argument to our new + // `call` instruction, so that it can receive the value + // written by the callee. + // + m_args.add(varPtr); + + // Note: Because `varPtr` is a pointer to the value we want, + // we have the small problem of needing to return a `LegalVal` + // that has dereferenced the value after the call. + // + // We solve this problem by inserting as `load` from our + // new variable immediately after the call, before going + // and resetting the insertion point to continue inserting + // stuff before the call (which is where we wnat the local + // variable declarations to go). + // + // TODO: Confirm that this logic can't go awry if (somehow) + // there is no instruction after `m_call`. That should not + // be possible inside of a function body, but it could in + // theory be a problem if we ever have top-level module-scope + // code representing initialization of constants and/or globals. + // + builder->setInsertBefore(m_call->getNextInst()); + auto val = builder->emitLoad(simpleType, varPtr); + builder->setInsertBefore(m_call); + + return LegalVal::simple(val); + } + break; + + // The remaining cases are a straightforward structural recursion + // on top of the base case above. + + case LegalType::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::implicitDeref: + return LegalVal::implicitDeref(_addOutArg(resultType.getImplicitDeref()->valueType)); + + case LegalVal::Flavor::pair: { - getArgumentValues(instArgs, elem.val); + auto pairType = resultType.getPair(); + auto specialVal = _addOutArg(pairType->specialType); + auto ordinaryVal = _addOutArg(pairType->ordinaryType); + + RefPtr<PairPseudoVal> pairVal = new PairPseudoVal(); + pairVal->pairInfo = pairType->pairInfo; + pairVal->ordinaryVal = ordinaryVal; + pairVal->specialVal = specialVal; + + return LegalVal::pair(pairVal); } + break; + + case LegalVal::Flavor::tuple: + { + auto tuplePsuedoType = resultType.getTuple(); + + RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal(); + for (auto typeElement : tuplePsuedoType->elements) + { + TuplePseudoVal::Element valElement; + valElement.key = typeElement.key; + valElement.val = _addOutArg(typeElement.type); + tupleVal->elements.add(valElement); + } + + return LegalVal::tuple(tupleVal); + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRCall."); } - break; + } - default: - SLANG_UNEXPECTED("uhandled val flavor"); - break; + /// Emit the actual `call` instruction given an IR result type + IRInst* _emitCall(IRType* resultType) + { + // The generated call will include all of the arguments that have + // been added up to this point, which includes those that were + // added to represent legalized parts of the result type. + // + return m_context->builder->emitCallInst( + resultType, + m_call->getCallee(), + m_args.getCount(), + m_args.getBuffer()); } -} +}; + static LegalVal legalizeCall( IRTypeLegalizationContext* context, IRCall* callInst) { - auto retType = legalizeType(context, callInst->getFullType()); - IRType* retIRType = nullptr; - switch (retType.flavor) + LegalCallBuilder builder(context, callInst); + + auto argCount = callInst->getArgCount(); + for( UInt i = 0; i < argCount; i++ ) { - case LegalType::Flavor::simple: - retIRType = retType.getSimple(); - break; - case LegalType::Flavor::none: - retIRType = context->builder->getVoidType(); - break; - default: - // TODO: implement legalization of non-simple return types - SLANG_UNEXPECTED("unimplemented legalized return type for IRInstCall."); + auto legalArg = legalizeOperand(context, callInst->getArg(i)); + builder.addArg(legalArg); } - List<IRInst*> instArgs; - for (auto i = 1u; i < callInst->getOperandCount(); i++) - getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i))); - - return LegalVal::simple(context->builder->emitCallInst( - retIRType, - callInst->getCallee(), - instArgs.getCount(), - instArgs.getBuffer())); + auto legalResultType = legalizeType(context, callInst->getFullType()); + return builder.build(legalResultType); } -static LegalVal legalizeRetVal(IRTypeLegalizationContext* context, - LegalVal retVal) + /// Helper type for legalizing a `returnVal` instruction +struct LegalReturnBuilder { - switch (retVal.flavor) + LegalReturnBuilder(IRTypeLegalizationContext* context, IRReturn* returnInst) + : m_context(context) + , m_returnInst(returnInst) + {} + + /// Emit code to perform a return of `val` + void returnVal(LegalVal val) { - case LegalVal::Flavor::simple: - return LegalVal::simple(context->builder->emitReturn(retVal.getSimple())); - case LegalVal::Flavor::none: - return LegalVal::simple(context->builder->emitReturn()); - default: - // TODO: implement legalization of non-simple return types - SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + auto builder = m_context->builder; + + switch (val.flavor) + { + case LegalVal::Flavor::simple: + // The case of a simple value is easy: just emit a `returnVal`. + // + builder->emitReturn(val.getSimple()); + break; + + case LegalVal::Flavor::none: + // The case of an empty/void value is also easy: emit a `return`. + // + builder->emitReturn(); + break; + + case LegalVal::Flavor::implicitDeref: + returnVal(val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + // The case for a pair value is the main interesting one. + // We need to write the special part of the return value + // to the `out` parameters that were declared to capture + // it, and then return the ordinary part of the value + // like normal. + // + // Note that the order here matters, because we need to + // emit the code that writes to the `out` parameters + // before the `return` instruction. + // + auto pairVal = val.getPair(); + _writeResultParam(pairVal->specialVal); + returnVal(pairVal->ordinaryVal); + } + break; + + case LegalVal::Flavor::tuple: + { + // The tuple case is kind of a degenerate combination + // of the `pair` and `none` cases: we need to emit + // writes to the `out` parameters declared to capture + // the tuple (all of it), and then we do a `return` + // of `void` because there is no ordinary result to + // capture. + // + _writeResultParam(val); + builder->emitReturn(); + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + } + } + +private: + + /// Write `val` to the `out` parameters of the enclosing function + void _writeResultParam(LegalVal const& val) + { + switch (val.flavor) + { + case LegalVal::Flavor::simple: + { + // The leaf case here is the interesting one. + // + // We know that if we are writing to `out` parameters to + // represent the function result then the function must + // have been legalized in a way that introduced those parameters. + // We thus need to look up the information on how the + // function got legalized so that we can identify the + // new parameters. + // + // TODO: One detail worth confirming here is whether there + // could ever be a case where a `return` instruction gets legalized + // before its outer function does. + // + if( !m_parentFuncInfo ) + { + // We start by searching for the ancestor instruction + // that represents the function (or other code-bearing value) + // that holds this instruction. + // + auto p = m_returnInst->getParent(); + while( p && !as<IRGlobalValueWithCode>(p) ) + { + p = p->parent; + } + + // We expect that the parent is actually an IR function. + // + // TODO: What about the case where we have an `IRGlobalVar` + // of a type that needs legalization, and teh variable has + // an initializer? For now, I believe that case is disallowed + // in the legalization for global variables. + // + auto parentFunc = as<IRFunc>(p); + SLANG_ASSERT(parentFunc); + if(!parentFunc) + return; + + // We also expect that extended legalization information was + // recorded for the function. + // + RefPtr<LegalFuncInfo> parentFuncInfo; + if( !m_context->mapFuncToInfo.TryGetValue(parentFunc, parentFuncInfo) ) + { + // If we fail to find the extended information then either: + // + // * The parent function has not been legalized yet. This would + // be a violation of our assumption about ordering of legalization. + // + // * The parent function was legalized, but didn't require any + // additional IR parameters to represent its result. This would + // be a violation of our assumption that the declared result type + // of a function and the type at `return` sites inside the function + // need to match. + // + SLANG_ASSERT(parentFuncInfo); + return; + } + + // If we find the extended information, then this is the first + // leaf parameter we are dealing with, so we set up to read through + // the parameters starting at index zero. + // + m_parentFuncInfo = parentFuncInfo; + m_resultParamCounter = 0; + } + SLANG_ASSERT(m_parentFuncInfo); + + // The recursion through the result `val` will iterate over the + // leaf parameters in the same order they should have been declared, + // so the parameter we need to write to will be the next one in order. + // + // We expect that the parameter index must be in range, beacuse otherwise + // the recursion here and the recursion that declared the parameters are + // mismatched in terms of how they traversed the hierarchical representation + // of `LegalVal` / `LegalType`. + // + Index resultParamIndex = m_resultParamCounter++; + SLANG_ASSERT(resultParamIndex >= 0); + SLANG_ASSERT(resultParamIndex < m_parentFuncInfo->resultParamVals.getCount()); + + // Once we've identified the right parameter, we can emit a `store` + // to write the value that the function wants to output. + // + // Note that an `out` parameter is represented with a pointer type + // in the IR, so that the `IRParam` here represents a pointer to + // the value that will receive the result. + // + auto resultParamPtr = m_parentFuncInfo->resultParamVals[resultParamIndex]; + m_context->builder->emitStore(resultParamPtr, val.getSimple()); + } + break; + + // The remaining cases are just a straightforward recursion + // over the structure of the `val`. + + case LegalVal::Flavor::none: + break; + + case LegalVal::Flavor::implicitDeref: + _writeResultParam(val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = val.getPair(); + _writeResultParam(pairVal->ordinaryVal); + _writeResultParam(pairVal->specialVal); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tupleVal = val.getTuple(); + for (auto element : tupleVal->elements) + { + _writeResultParam(element.val); + } + } + break; + + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + } } + + IRTypeLegalizationContext* m_context = nullptr; + IRReturn* m_returnInst = nullptr; + + RefPtr<LegalFuncInfo> m_parentFuncInfo; + Index m_resultParamCounter = 0; +}; + +static LegalVal legalizeRetVal( + IRTypeLegalizationContext* context, + LegalVal retVal, + IRReturnVal* returnInst) +{ + LegalReturnBuilder builder(context, returnInst); + builder.returnVal(retVal); + return LegalVal(); } static LegalVal legalizeLoad( @@ -1232,7 +1657,7 @@ static LegalVal legalizeInst( case kIROp_Call: return legalizeCall(context, (IRCall*)inst); case kIROp_ReturnVal: - return legalizeRetVal(context, args[0]); + return legalizeRetVal(context, args[0], (IRReturnVal*)inst); case kIROp_makeStruct: return legalizeMakeStruct( context, @@ -1487,82 +1912,315 @@ static LegalVal legalizeInst( return legalVal; } -static void addParamType(List<IRType*>& ioParamTypes, LegalType t) + /// Helper type for legalizing the signature of an `IRFunc` +struct LegalFuncBuilder { - switch (t.flavor) + LegalFuncBuilder(IRTypeLegalizationContext* context) + : m_context(context) + {} + + /// Construct a legalized value to represent `oldFunc` + LegalVal build(IRFunc* oldFunc) { - case LegalType::Flavor::none: - break; + // We can start by computing what the type signature of the + // legalized function should be, based on the type signature + // of the original. + // + IRFuncType* oldFuncType = oldFunc->getDataType(); - case LegalType::Flavor::simple: - ioParamTypes.add(t.getSimple()); - break; + // Each parameter of the original function will translate into + // zero or more parameters in the legalized function signature. + // + UInt oldParamCount = oldFuncType->getParamCount(); + for (UInt pp = 0; pp < oldParamCount; ++pp) + { + auto legalParamType = legalizeType(m_context, oldFuncType->getParamType(pp)); + _addParam(legalParamType); + } - case LegalType::Flavor::implicitDeref: - { - auto imp = t.getImplicitDeref(); - addParamType(ioParamTypes, imp->valueType); - break; - } - case LegalType::Flavor::pair: + // We will record how many parameters resulted from + // legalization of the original / "base" parameter list. + // This number will help us in computing how many parameters + // were added to capture the result type of the function. + // + Index baseLegalParamCount = m_paramTypes.getCount(); + + // Next we add a result type to the function based on the + // legalized result type of the original function. + // + // It is possible that this process will had one or more + // `out` parameters to represent parts of the result type + // that couldn't be passed via the ordinary function result. + // + auto legalResultType = legalizeType(m_context, oldFuncType->getResultType()); + _addResult(legalResultType); + + // If any part of the result type required new function parameters + // to be introduced, then we want to know how many there were. + // These additional function paameters will always come after the original + // parameters, so that they don't shift around call sites too much. + // + // TODO: Where we put the added `out` parameters in the signature may + // have performance implications when it starts interacting with ABI + // (e.g., most ABIs assign parameters to registers from left to right, + // so parameters later in the list are more likely to be passed through + // memory; we'd need to decide whether the base parameters or the + // legalized result parameters should be prioritized for register + // allocation). + // + Index resultParamCount = m_paramTypes.getCount() - baseLegalParamCount; + + // If we didn't bottom out on a result type for the legalized function, + // then we should default to returning `void`. + // + auto irBuilder = m_context->builder; + if( !m_resultType ) { - auto pairInfo = t.getPair(); - addParamType(ioParamTypes, pairInfo->ordinaryType); - addParamType(ioParamTypes, pairInfo->specialType); + m_resultType = irBuilder->getVoidType(); } - break; - case LegalType::Flavor::tuple: - { - auto tup = t.getTuple(); - for (auto & elem : tup->elements) - addParamType(ioParamTypes, elem.type); - } - break; - default: - SLANG_UNEXPECTED("unknown legalized type flavor"); + + // We will compute the new IR type for the function and install it + // as the data type of original function. + // + // Note: This is one of the few cases where the legalization pass + // prefers to modify an IR node in-place rather than create a distinct + // legalized copy of it. + // + auto newFuncType = irBuilder->getFuncType( + m_paramTypes.getCount(), + m_paramTypes.getBuffer(), + m_resultType); + irBuilder->setDataType(oldFunc, newFuncType); + + // If the function required any new parameters to be created + // to represent the result/return type, then we need to + // actually add the appropriate IR parameters to represent + // that stuff as well. + // + if( resultParamCount != 0 ) + { + // Only a function with a body will need this additonal + // step, since the function parameters are stored on the + // first block of the body. + // + auto firstBlock = oldFunc->getFirstBlock(); + if( firstBlock ) + { + // Because legalization of this function required us + // to introduce new parameters, we need to allocate + // a data structure to record the identities of those + // new parameters so that they can be looked up when + // legalizing the body of the function. + // + // In particular, we will use this information when + // legalizing `return` instructions in the function body, + // since those will need to store at least part of + // the reuslt value into the newly-declared parameter(s). + // + RefPtr<LegalFuncInfo> funcInfo = new LegalFuncInfo(); + m_context->mapFuncToInfo.Add(oldFunc, funcInfo); + + // We know that our new parameters need to come after + // those that were declared for the "base" parameters + // of the original function. + // + auto firstResultParamIndex = baseLegalParamCount; + auto firstOrdinaryInst = firstBlock->getFirstOrdinaryInst(); + for( Index i = 0; i < resultParamCount; ++i ) + { + // Note: The parameter types that were added to + // the `m_paramTypes` array already account for the + // fact that these are `out` parameters, since that + // impacts the function type signature as well. + // We do *not* need to wrap `paramType` in an `Out<...>` + // type here. + // + auto paramType = m_paramTypes[firstResultParamIndex + i]; + auto param = irBuilder->createParam(paramType); + param->insertBefore(firstOrdinaryInst); + + funcInfo->resultParamVals.add(param); + } + } + } + + // Note: at this point we do *not* apply legalization to the parameters + // of the function or its body; those are left for the recursive part + // of the overall legalization pass to handle. + + return LegalVal::simple(oldFunc); } -} -static LegalVal legalizeFunc( - IRTypeLegalizationContext* context, - IRFunc* irFunc) -{ - // Overwrite the function's type with the result of legalization. - IRFuncType* oldFuncType = irFunc->getDataType(); - UInt oldParamCount = oldFuncType->getParamCount(); +private: + IRTypeLegalizationContext* m_context = nullptr;; + + /// The types of the parameters of the legalized function + List<IRType*> m_paramTypes; + + /// The result type of the legalized function (can be null to represent `void`) + IRType* m_resultType = nullptr; - // TODO: we should give an error message when the result type of a function - // can't be legalized (e.g., trying to return a texture, or a structue that - // contains one). - auto legalReturnType = legalizeType(context, oldFuncType->getResultType()); - IRType* newResultType = nullptr; - switch (legalReturnType.flavor) + /// Add a parameter of type `t` to the function signature + void _addParam(LegalType t) { - case LegalType::Flavor::simple: - newResultType = legalReturnType.getSimple(); - break; - case LegalType::Flavor::none: - newResultType = context->builder->getVoidType(); - break; - default: - SLANG_UNEXPECTED("unknown legalized function return type."); + // This logic is a simple recursion over the structure of `t`, + // with the leaf case adding parameters of simple IR type. + + switch (t.flavor) + { + case LegalType::Flavor::none: + break; + + case LegalType::Flavor::simple: + m_paramTypes.add(t.getSimple()); + break; + + case LegalType::Flavor::implicitDeref: + { + auto imp = t.getImplicitDeref(); + _addParam(imp->valueType); + } + break; + case LegalType::Flavor::pair: + { + auto pairInfo = t.getPair(); + _addParam(pairInfo->ordinaryType); + _addParam(pairInfo->specialType); + } + break; + case LegalType::Flavor::tuple: + { + auto tup = t.getTuple(); + for (auto & elem : tup->elements) + _addParam(elem.type); + } + break; + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } } - List<IRType*> newParamTypes; - for (UInt pp = 0; pp < oldParamCount; ++pp) + + /// Set the logical result type of the legalized function to `t` + void _addResult(LegalType t) { - auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp)); - addParamType(newParamTypes, legalParamType); + switch (t.flavor) + { + case LegalType::Flavor::simple: + // The simple case is when the result type is a simple IR + // type, and we can use it directly as the return type. + // + m_resultType = t.getSimple(); + break; + + + case LegalType::Flavor::none: + // The case where we have no result type is also simple, + // becaues we can leave `m_resultType` as null to represent + // a `void` result type. + break; + + case LegalType::Flavor::implicitDeref: + { + // An `implicitDeref` is a wrapper around another legal + // type, so we can simply set the result type to the + // unwrapped inner type. + // + auto imp = t.getImplicitDeref(); + _addResult(imp->valueType); + } + break; + + case LegalType::Flavor::pair: + { + // The `pair` case is the first interesting one. + // + // We will set the actual result type of the operation + // to the ordinary side of the pair, while any special + // part of the pair will be returned via fresh `out` + // parameters insteqad. + // + auto pairInfo = t.getPair(); + _addResult(pairInfo->ordinaryType); + _addOutParam(pairInfo->specialType); + } + break; + + case LegalType::Flavor::tuple: + { + // In the `tuple` case we have zero or more types, + // and there is no distinguished primary one that + // should become the result type of the legalized function. + // + // We will instead declare fresh `out` parameters to + // capture all the outputs in the tuple. + // + auto tup = t.getTuple(); + for( auto & elem : tup->elements ) + { + _addOutParam(elem.type); + } + } + break; + + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } } - auto newFuncType = context->builder->getFuncType( - newParamTypes.getCount(), - newParamTypes.getBuffer(), - newResultType); + /// Add a single `out` parameter based on type `t`. + void _addOutParam(LegalType t) + { + switch (t.flavor) + { + case LegalType::Flavor::simple: + // The simple case here is almost the same as `_addParam()`, + // except we wrap the simple type in `Out<...>` to indicate + // that we are producing an `out` parameter. + // + m_paramTypes.add(m_context->builder->getOutType(t.getSimple())); + break; + + // The remaining cases are all simple recursion on the + // structure of `t`. - context->builder->setDataType(irFunc, newFuncType); + case LegalType::Flavor::none: + break; - return LegalVal::simple(irFunc); + case LegalType::Flavor::implicitDeref: + { + auto imp = t.getImplicitDeref(); + _addOutParam(imp->valueType); + } + break; + case LegalType::Flavor::pair: + { + auto pairInfo = t.getPair(); + _addOutParam(pairInfo->ordinaryType); + _addOutParam(pairInfo->specialType); + } + break; + case LegalType::Flavor::tuple: + { + auto tup = t.getTuple(); + for( auto & elem : tup->elements ) + { + _addOutParam(elem.type); + } + } + break; + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } + } +}; + +static LegalVal legalizeFunc( + IRTypeLegalizationContext* context, + IRFunc* irFunc) +{ + LegalFuncBuilder builder(context); + return builder.build(irFunc); } static LegalVal declareSimpleVar( diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 23e68182e..c7398fe23 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -171,7 +171,7 @@ struct ResourceOutputSpecializationPass oldFunc, newFunc); - // At first `newFunc` is a directclone of `oldFunc`, and thus doesn't + // At first `newFunc` is a direct clone of `oldFunc`, and thus doesn't // solve any of our problems. We will traverse `oldFunc` and specialize // it as needed, while also collecting information that will allow // us to rewrite call sites. @@ -468,8 +468,29 @@ struct ResourceOutputSpecializationPass // // Any failures along the way cause the whole process to fail. - for( auto param : func->getParams() ) + // Note: We are introducing new parameters at the same time as we + // iterate over the parameter list, so we cannot just use the + // `func->getParams()` convenience accessor. Instead, we manually + // iterate over the parameters in a way that avoids invalidation + // if we remove the `param` we are working on. + // + // Note: it might seem odd that we are modifying `func` but will + // still bail out on any errors. You might ask: isn't there a chance + // that we will end up with the function in a partially-modified state? + // + // The important thing to remember is that `func` is *copy* of the + // original function, so any modifications we make to it do not + // affect the original, so that if we *do* have to bail out we can + // leave any call sites intact as calls to the original. The result + // is that bailing out here may leave the new/copied function in + // a state where it isn't useful, but it also won't have any uses, + // and can be eliminated later. + // + IRParam* nextParam = nullptr; + for( IRParam* param = func->getFirstParam(); param; param = nextParam ) { + nextParam = param->getNextParam(); + ParamInfo paramInfo; SLANG_RETURN_ON_FAIL(maybeSpecializeParam(param, paramInfo, outFuncInfo)); outFuncInfo.oldParams.add(paramInfo); diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h index 8f2a7572f..600f1d7a7 100644 --- a/source/slang/slang-legalize-types.h +++ b/source/slang/slang-legalize-types.h @@ -509,7 +509,7 @@ struct LegalVal } static LegalVal implicitDeref(LegalVal const& val); - LegalVal getImplicitDeref(); + LegalVal getImplicitDeref() const; static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); static LegalVal pair( @@ -568,6 +568,30 @@ struct WrappedBufferPseudoVal : LegalValImpl // + /// Information about a function that has been legalized + /// + /// This type is used to track any information about the function + /// and its signature that might be relevant to the legalization + /// of instructions inside the function body. + /// +struct LegalFuncInfo : RefObject +{ + /// Any parameters that were added to the function signature + /// to represent the function result after legalization. + /// + /// It is possible that the result type of a function needed + /// to be split into multiple types, and as a result a single + /// function result couldn't return all of them. + /// + /// This array is a list of `out` parameters created to represent + /// additional function results. Because they are `out` parameters, + /// each is a *pointer* to a value of the relevant type. + /// + List<IRInst*> resultParamVals; +}; + +// + /// Context that drives type legalization /// /// This type is an abstract base class, and there are @@ -601,6 +625,14 @@ struct IRTypeLegalizationContext Dictionary<IRType*, LegalType> mapTypeToLegalType; + /// Map a function to information about how it was legalized. + /// + /// Note that entries are only created if there is somehting for them + /// to represent, so many functions may lack entries in this map even + /// after legalization. + /// + Dictionary<IRFunc*, RefPtr<LegalFuncInfo>> mapFuncToInfo; + IRBuilder* getBuilder() { return builder; } /// Customization point to decide what types are "special." diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang new file mode 100644 index 000000000..ea94e6ffa --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang @@ -0,0 +1,55 @@ +// inout-param-opaque-type-in-struct.slang + +// Test that a function/method can have an `out` parameter of +// aggregate type that includes an opaque type + +//TEST(compute):COMPARE_COMPUTE: + +struct Things +{ + int first; + RWStructuredBuffer<int> rest; +} + +//TEST_INPUT:set C = new { {1, ubuffer(data=[2 3 4 5], stride=4)}, {6, ubuffer(data=[7 8 9 10], stride=4)} } +cbuffer C +{ + Things gX; + Things gY; +} + +void swap( + inout Things a, + inout Things b) +{ + Things t = a; + a = b; + b = t; +} + +int eval(Things t, int val) +{ + return t.first*256 + t.rest[val]; +} + +int test(int val) +{ + Things f = gX; + Things g = gY; + + swap(f, g); + + return (eval(f,val) << 16) + eval(g,val); +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt new file mode 100644 index 000000000..43533c76f --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +6070102 +6080103 +6090104 +60A0105 diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type.slang b/tests/language-feature/types/opaque/inout-param-opaque-type.slang new file mode 100644 index 000000000..682f89fd0 --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type.slang @@ -0,0 +1,42 @@ +// inout-param-opaque-type.slang + +// Test that a function/method can have an `out` parameter of opaque type + +//TEST(compute):COMPARE_COMPUTE: + +//TEST_INPUT:set gX = ubuffer(data=[16 17 18 19], stride=4) +RWStructuredBuffer<int> gX; + +//TEST_INPUT:set gY = ubuffer(data=[3 6 9 12], stride=4) +RWStructuredBuffer<int> gY; + +void swap( + inout RWStructuredBuffer<int> a, + inout RWStructuredBuffer<int> b) +{ + RWStructuredBuffer<int> t = a; + a = b; + b = t; +} + +int test(int val) +{ + RWStructuredBuffer<int> f = gX; + RWStructuredBuffer<int> g = gY; + + swap(f, g); + + return f[val] * 256 + g[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt new file mode 100644 index 000000000..81cf98393 --- /dev/null +++ b/tests/language-feature/types/opaque/inout-param-opaque-type.slang.expected.txt @@ -0,0 +1,4 @@ +310 +611 +912 +C13 diff --git a/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang new file mode 100644 index 000000000..a6c645c01 --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang @@ -0,0 +1,39 @@ +// out-opaque-type-in-struct.slang + +// Test that a function/method can have an `out` parameter of +// aggregate type that includes an opaque type + +//TEST(compute):COMPARE_COMPUTE: + +struct Things +{ + int first; + RWStructuredBuffer<int> rest; +} + +//TEST_INPUT:set gThings = new Things { 1, ubuffer(data=[2 3 4 5], stride=4) } +ConstantBuffer<Things> gThings; + +void getThings(out Things outThings) +{ + outThings = gThings; +} + +int test(int val) +{ + Things things; + getThings(things); + return things.first * (16 << val) + things.rest[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt new file mode 100644 index 000000000..553843b5d --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +12 +23 +44 +85 diff --git a/tests/language-feature/types/opaque/out-param-opaque-type.slang b/tests/language-feature/types/opaque/out-param-opaque-type.slang new file mode 100644 index 000000000..3ac7c0d6f --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type.slang @@ -0,0 +1,33 @@ +// out-opaque-type.slang + +// Test that a function/method can have an `out` parameter of opaque type + +//TEST(compute):COMPARE_COMPUTE: + +//TEST_INPUT:set gThings = ubuffer(data=[16 17 18 19], stride=4) +RWStructuredBuffer<int> gThings; + + +void getThings(out RWStructuredBuffer<int> things) +{ + things = gThings; +} + +int test(int val) +{ + RWStructuredBuffer<int> t; + getThings(t); + return t[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt new file mode 100644 index 000000000..a0d427709 --- /dev/null +++ b/tests/language-feature/types/opaque/out-param-opaque-type.slang.expected.txt @@ -0,0 +1,4 @@ +10 +11 +12 +13 diff --git a/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang new file mode 100644 index 000000000..2687af1c3 --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang @@ -0,0 +1,38 @@ +// return-opaque-type-in-struct.slang + +// Test that a function/method can return a value of +// aggregate type that includes an opaque type + +//TEST(compute):COMPARE_COMPUTE: + +struct Things +{ + int first; + RWStructuredBuffer<int> rest; +} + +//TEST_INPUT:set gThings = new Things { 1, ubuffer(data=[2 3 4 5], stride=4) } +ConstantBuffer<Things> gThings; + +Things getThings() +{ + return gThings; +} + +int test(int val) +{ + let things = getThings(); + return things.first * (16 << val) + things.rest[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt new file mode 100644 index 000000000..553843b5d --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type-in-struct.slang.expected.txt @@ -0,0 +1,4 @@ +12 +23 +44 +85 diff --git a/tests/language-feature/types/opaque/return-opaque-type.slang b/tests/language-feature/types/opaque/return-opaque-type.slang new file mode 100644 index 000000000..83d4376ba --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type.slang @@ -0,0 +1,32 @@ +// return-opaque-type.slang + +// Test that a function/method can return a value of an opaque type. + +//TEST(compute):COMPARE_COMPUTE: + +struct Stuff +{ + RWStructuredBuffer<int> things; + + RWStructuredBuffer<int> getThings() { return things; } +} + +//TEST_INPUT:set gStuff = new Stuff { ubuffer(data=[16 17 18 19], stride=4) } +ConstantBuffer<Stuff> gStuff; + +int test(int val) +{ + return gStuff.getThings()[val]; +} + +//TEST_INPUT:set gOutput = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<int> gOutput; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + gOutput[tid] = outVal; +} diff --git a/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt b/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt new file mode 100644 index 000000000..a0d427709 --- /dev/null +++ b/tests/language-feature/types/opaque/return-opaque-type.slang.expected.txt @@ -0,0 +1,4 @@ +10 +11 +12 +13 |
