// ir-legalize-types.cpp // This file implements type legalization for the IR. // It uses the core legalization logic in // `legalize-types.{h,cpp}` to decide what to do with // the types, while this file handles the actual // rewriting of the IR to use the new types. // // This pass should only be applied to IR that has been // fully specialized (no more generics/interfaces), so // that the concrete type of everything is known. #include "ir.h" #include "ir-insts.h" #include "legalize-types.h" #include "mangle.h" #include "name.h" namespace Slang { LegalVal LegalVal::tuple(RefPtr tupleVal) { SLANG_ASSERT(tupleVal->elements.Count()); LegalVal result; result.flavor = LegalVal::Flavor::tuple; result.obj = tupleVal; return result; } LegalVal LegalVal::pair(RefPtr pairInfo) { LegalVal result; result.flavor = LegalVal::Flavor::pair; result.obj = pairInfo; return result; } LegalVal LegalVal::pair( LegalVal const& ordinaryVal, LegalVal const& specialVal, RefPtr pairInfo) { if (ordinaryVal.flavor == LegalVal::Flavor::none) return specialVal; if (specialVal.flavor == LegalVal::Flavor::none) return ordinaryVal; RefPtr obj = new PairPseudoVal(); obj->ordinaryVal = ordinaryVal; obj->specialVal = specialVal; obj->pairInfo = pairInfo; return LegalVal::pair(obj); } LegalVal LegalVal::implicitDeref(LegalVal const& val) { RefPtr implicitDerefVal = new ImplicitDerefVal(); implicitDerefVal->val = val; LegalVal result; result.flavor = LegalVal::Flavor::implicitDeref; result.obj = implicitDerefVal; return result; } LegalVal LegalVal::getImplicitDeref() { SLANG_ASSERT(flavor == Flavor::implicitDeref); return as(obj)->val; } // IRTypeLegalizationContext::IRTypeLegalizationContext( IRModule* inModule) { session = inModule->getSession(); module = inModule; auto sharedBuilder = &sharedBuilderStorage; sharedBuilder->session = session; sharedBuilder->module = module; builder = &builderStorage; builder->sharedBuilder = sharedBuilder; } static void registerLegalizedValue( IRTypeLegalizationContext* context, IRInst* irValue, LegalVal const& legalVal) { context->mapValToLegalVal[irValue] = legalVal; } struct IRGlobalNameInfo { IRInst* globalVar; UInt counter; }; static LegalVal declareVars( IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, LegalVarChain* varChain, UnownedStringSlice nameHint, IRGlobalNameInfo* globalNameInfo); // Take a value that is being used as an operand, // and turn it into the equivalent legalized value. static LegalVal legalizeOperand( IRTypeLegalizationContext* context, IRInst* irValue) { LegalVal legalVal; if (context->mapValToLegalVal.TryGetValue(irValue, legalVal)) return legalVal; // For now, assume that anything not covered // by the mapping is legal as-is. return LegalVal::simple(irValue); } static void getArgumentValues( List & instArgs, LegalVal val) { switch (val.flavor) { case LegalVal::Flavor::none: break; case LegalVal::Flavor::simple: instArgs.Add(val.getSimple()); break; case LegalVal::Flavor::implicitDeref: getArgumentValues(instArgs, val.getImplicitDeref()); break; case LegalVal::Flavor::pair: { auto pairVal = val.getPair(); getArgumentValues(instArgs, pairVal->ordinaryVal); getArgumentValues(instArgs, pairVal->specialVal); } break; case LegalVal::Flavor::tuple: { auto tuplePsuedoVal = val.getTuple(); for (auto elem : val.getTuple()->elements) { getArgumentValues(instArgs, elem.val); } } break; default: SLANG_UNEXPECTED("uhandled val flavor"); break; } } static LegalVal legalizeCall( IRTypeLegalizationContext* context, IRCall* callInst) { auto retType = legalizeType(context, callInst->getFullType()); IRType* retIRType = nullptr; switch (retType.flavor) { case LegalType::Flavor::simple: retIRType = retType.getSimple(); break; case LegalType::Flavor::none: retIRType = context->builder->getVoidType(); break; default: // TODO: implement legalization of non-simple return types SLANG_UNEXPECTED("unimplemented legalized return type for IRInstCall."); } List instArgs; for (auto i = 1u; i < callInst->getOperandCount(); i++) getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i))); return LegalVal::simple(context->builder->emitCallInst( retIRType, callInst->getCallee(), instArgs.Count(), instArgs.Buffer())); } static LegalVal legalizeRetVal(IRTypeLegalizationContext* context, LegalVal retVal) { switch (retVal.flavor) { case LegalVal::Flavor::simple: return LegalVal::simple(context->builder->emitReturn(retVal.getSimple())); case LegalVal::Flavor::none: return LegalVal::simple(context->builder->emitReturn()); default: // TODO: implement legalization of non-simple return types SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); } } static LegalVal legalizeLoad( IRTypeLegalizationContext* context, LegalVal legalPtrVal) { switch (legalPtrVal.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: { return LegalVal::simple( context->builder->emitLoad(legalPtrVal.getSimple())); } break; case LegalVal::Flavor::implicitDeref: // We have turne a pointer(-like) type into its pointed-to (value) // type, and so the operation of loading goes away; we just use // the underlying value. return legalPtrVal.getImplicitDeref(); case LegalVal::Flavor::pair: { auto ptrPairVal = legalPtrVal.getPair(); auto ordinaryVal = legalizeLoad(context, ptrPairVal->ordinaryVal); auto specialVal = legalizeLoad(context, ptrPairVal->specialVal); return LegalVal::pair(ordinaryVal, specialVal, ptrPairVal->pairInfo); } case LegalVal::Flavor::tuple: { // We need to emit a load for each element of // the tuple. auto ptrTupleVal = legalPtrVal.getTuple(); RefPtr tupleVal = new TuplePseudoVal(); for (auto ee : legalPtrVal.getTuple()->elements) { TuplePseudoVal::Element element; element.key = ee.key; element.val = legalizeLoad(context, ee.val); tupleVal->elements.Add(element); } return LegalVal::tuple(tupleVal); } break; default: SLANG_UNEXPECTED("unhandled case"); break; } } static LegalVal legalizeStore( IRTypeLegalizationContext* context, LegalVal legalPtrVal, LegalVal legalVal) { switch (legalPtrVal.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: { context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple()); return legalVal; } break; case LegalVal::Flavor::implicitDeref: // TODO: what is the right behavior here? // // The crux of the problem is that we may legalize a pointer-to-pointer // type in cases where one of the two needs to become an implicit-deref, // so that we have `PtrA>` become, say, `PtrA` with // an `implicitDeref` wrapper. When we encounter a store to that // wrapped value, we seemingly need to know whether the original code // meant to store to `*ptrPtr` or `**ptrPtr`, and need to legalize // the result accordingly... // if( legalVal.flavor == LegalVal::Flavor::implicitDeref ) return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal.getImplicitDeref()); else return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal); case LegalVal::Flavor::pair: { auto destPair = legalPtrVal.getPair(); auto valPair = legalVal.getPair(); legalizeStore(context, destPair->ordinaryVal, valPair->ordinaryVal); legalizeStore(context, destPair->specialVal, valPair->specialVal); return LegalVal(); } case LegalVal::Flavor::tuple: { // We need to emit a store for each element of // the tuple. auto destTuple = legalPtrVal.getTuple(); auto valTuple = legalVal.getTuple(); SLANG_ASSERT(destTuple->elements.Count() == valTuple->elements.Count()); for (UInt i = 0; i < valTuple->elements.Count(); i++) { legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); } return legalVal; } break; default: SLANG_UNEXPECTED("unhandled case"); break; } } static LegalVal legalizeFieldExtract( IRTypeLegalizationContext* context, LegalType type, LegalVal legalStructOperand, IRStructKey* fieldKey) { auto builder = context->builder; if (type.flavor == LegalType::Flavor::none) return LegalVal(); switch (legalStructOperand.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: return LegalVal::simple( builder->emitFieldExtract( type.getSimple(), legalStructOperand.getSimple(), fieldKey)); case LegalVal::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairVal = legalStructOperand.getPair(); auto pairInfo = pairVal->pairInfo; auto pairElement = pairInfo->findElement(fieldKey); if (!pairElement) { SLANG_UNEXPECTED("didn't find tuple element"); UNREACHABLE_RETURN(LegalVal()); } // If the field we are extracting has a pair type, // that means it exists on both the ordinary and // special sides. RefPtr fieldPairInfo; LegalType ordinaryType = type; LegalType specialType = type; if (type.flavor == LegalType::Flavor::pair) { auto fieldPairType = type.getPair(); fieldPairInfo = fieldPairType->pairInfo; ordinaryType = fieldPairType->ordinaryType; specialType = fieldPairType->specialType; } LegalVal ordinaryVal; LegalVal specialVal; if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { ordinaryVal = legalizeFieldExtract( context, ordinaryType, pairVal->ordinaryVal, fieldKey); } if (pairElement->flags & PairInfo::kFlag_hasSpecial) { specialVal = legalizeFieldExtract( context, specialType, pairVal->specialVal, fieldKey); } return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); } break; case LegalVal::Flavor::tuple: { // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. auto ptrTupleInfo = legalStructOperand.getTuple(); for (auto ee : ptrTupleInfo->elements) { if (ee.key == fieldKey) { return ee.val; } } // TODO: we can legally reach this case now // when the field is "ordinary". SLANG_UNEXPECTED("didn't find tuple element"); UNREACHABLE_RETURN(LegalVal()); } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeFieldExtract( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, LegalVal legalFieldOperand) { // We don't expect any legalization to affect // the "field" argument. auto fieldKey = legalFieldOperand.getSimple(); return legalizeFieldExtract( context, type, legalPtrOperand, (IRStructKey*) fieldKey); } /// Take a value of some buffer/pointer type and wrap it according to provided info. static LegalVal wrapBufferValue( IRTypeLegalizationContext* context, LegalVal legalPtrOperand, LegalElementWrapping const& elementInfo) { // The `elementInfo` tells us how a non-simple element // type was wrapped up into a new structure types used // as the element type of the buffer. // // This function will recurse through the structure of // `elementInfo` to pull out all the required data from // the buffer represented by `legalPtrOperand`. switch( elementInfo.flavor ) { default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); break; case LegalElementWrapping::Flavor::none: return LegalVal(); case LegalElementWrapping::Flavor::simple: { // In the leaf case, we just had to store some // data of a simple type in the buffer. We can // produce a valid result by computing the // address of the field used to represent the // element, and then returning *that* as if // it were the buffer type itself. // // (Basically instead of `someBuffer` we will // end up with `&(someBuffer->field)`. // auto builder = context->getBuilder(); auto simpleElementInfo = elementInfo.getSimple(); auto valPtr = builder->emitFieldAddress( simpleElementInfo->type, legalPtrOperand.getSimple(), simpleElementInfo->key); return LegalVal::simple(valPtr); } case LegalElementWrapping::Flavor::implicitDeref: { // If the element type was logically `ImplicitDeref`, // then we declared actual fields based on `T`, and // we need to extract references to those fields and // wrap them up in an `implicitDeref` value. // auto derefField = elementInfo.getImplicitDeref(); auto baseVal = wrapBufferValue(context, legalPtrOperand, derefField->field); return LegalVal::implicitDeref(baseVal); } case LegalElementWrapping::Flavor::pair: { // If the element type was logically a `Pair` // then we encoded fields for both `O` and `S` into // the actual element type, and now we need to // extract references to both and pair them up. // auto pairField = elementInfo.getPair(); auto pairInfo = pairField->pairInfo; auto ordinaryVal = wrapBufferValue(context, legalPtrOperand, pairField->ordinary); auto specialVal = wrapBufferValue(context, legalPtrOperand, pairField->special); return LegalVal::pair(ordinaryVal, specialVal, pairInfo); } case LegalElementWrapping::Flavor::tuple: { // If the element type was logically a `Tuple` // then we encoded fields for each of the `Ei` and // need to extract references to all of them and // encode them as a tuple. // auto tupleField = elementInfo.getTuple(); RefPtr obj = new TuplePseudoVal(); for( auto ee : tupleField->elements ) { auto elementVal = wrapBufferValue( context, legalPtrOperand, ee.field); TuplePseudoVal::Element element; element.key = ee.key; element.val = wrapBufferValue( context, legalPtrOperand, ee.field); obj->elements.Add(element); } return LegalVal::tuple(obj); } } } static IRType* getPointedToType( IRTypeLegalizationContext* context, IRType* ptrType) { auto valueType = tryGetPointedToType(context->builder, ptrType); if( !valueType ) { SLANG_UNEXPECTED("expected a pointer type during type legalization"); } return valueType; } static LegalType getPointedToType( IRTypeLegalizationContext* context, LegalType type) { switch( type.flavor ) { case LegalType::Flavor::none: return LegalType(); case LegalType::Flavor::simple: return LegalType::simple(getPointedToType(context, type.getSimple())); case LegalType::Flavor::implicitDeref: return type.getImplicitDeref()->valueType; case LegalType::Flavor::pair: { auto pairType = type.getPair(); auto ordinary = getPointedToType(context, pairType->ordinaryType); auto special = getPointedToType(context, pairType->specialType); return LegalType::pair(ordinary, special, pairType->pairInfo); } case LegalType::Flavor::tuple: { auto tupleType = type.getTuple(); RefPtr resultTuple = new TuplePseudoType(); for( auto ee : tupleType->elements ) { TuplePseudoType::Element resultElement; resultElement.key = ee.key; resultElement.type = getPointedToType(context, ee.type); resultTuple->elements.Add(resultElement); } return LegalType::tuple(resultTuple); } default: SLANG_UNEXPECTED("unhandled case in type legalization"); UNREACHABLE_RETURN(LegalType()); } } static LegalVal legalizeFieldAddress( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, IRStructKey* fieldKey) { auto builder = context->builder; if (type.flavor == LegalType::Flavor::none) return LegalVal(); switch (legalPtrOperand.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: switch( type.flavor ) { case LegalType::Flavor::implicitDeref: // TODO: Should this case be needed? return legalizeFieldAddress( context, type.getImplicitDeref()->valueType, legalPtrOperand, fieldKey); default: return LegalVal::simple( builder->emitFieldAddress( type.getSimple(), legalPtrOperand.getSimple(), fieldKey)); } case LegalVal::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairVal = legalPtrOperand.getPair(); auto pairInfo = pairVal->pairInfo; auto pairElement = pairInfo->findElement(fieldKey); if (!pairElement) { SLANG_UNEXPECTED("didn't find tuple element"); UNREACHABLE_RETURN(LegalVal()); } // If the field we are extracting has a pair type, // that means it exists on both the ordinary and // special sides. RefPtr fieldPairInfo; LegalType ordinaryType = type; LegalType specialType = type; if (type.flavor == LegalType::Flavor::pair) { auto fieldPairType = type.getPair(); fieldPairInfo = fieldPairType->pairInfo; ordinaryType = fieldPairType->ordinaryType; specialType = fieldPairType->specialType; } LegalVal ordinaryVal; LegalVal specialVal; if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { ordinaryVal = legalizeFieldAddress( context, ordinaryType, pairVal->ordinaryVal, fieldKey); } if (pairElement->flags & PairInfo::kFlag_hasSpecial) { specialVal = legalizeFieldAddress( context, specialType, pairVal->specialVal, fieldKey); } return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); } break; case LegalVal::Flavor::tuple: { // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. auto ptrTupleInfo = legalPtrOperand.getTuple(); for (auto ee : ptrTupleInfo->elements) { if (ee.key == fieldKey) { return ee.val; } } // TODO: we can legally reach this case now // when the field is "ordinary". SLANG_UNEXPECTED("didn't find tuple element"); UNREACHABLE_RETURN(LegalVal()); } case LegalVal::Flavor::implicitDeref: { // The original value had a level of indirection // that is now being removed, so should not be // able to get at the *address* of the field any // more, and need to resign ourselves to just // getting at the field *value* and then // adding an implicit dereference on top of that. // auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); auto valueType = getPointedToType(context, type); return LegalVal::implicitDeref(legalizeFieldExtract(context, valueType, implicitDerefVal, fieldKey)); } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeFieldAddress( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, LegalVal legalFieldOperand) { // We don't expect any legalization to affect // the "field" argument. auto fieldKey = legalFieldOperand.getSimple(); return legalizeFieldAddress( context, type, legalPtrOperand, (IRStructKey*) fieldKey); } static LegalVal legalizeGetElement( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, IRInst* indexOperand) { auto builder = context->builder; switch (legalPtrOperand.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: return LegalVal::simple( builder->emitElementExtract( type.getSimple(), legalPtrOperand.getSimple(), indexOperand)); case LegalVal::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairVal = legalPtrOperand.getPair(); auto pairInfo = pairVal->pairInfo; LegalType ordinaryType = type; LegalType specialType = type; if (type.flavor == LegalType::Flavor::pair) { auto pairType = type.getPair(); ordinaryType = pairType->ordinaryType; specialType = pairType->specialType; } LegalVal ordinaryVal = legalizeGetElement( context, ordinaryType, pairVal->ordinaryVal, indexOperand); LegalVal specialVal = legalizeGetElement( context, specialType, pairVal->specialVal, indexOperand); return LegalVal::pair(ordinaryVal, specialVal, pairInfo); } break; case LegalVal::Flavor::tuple: { // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. auto ptrTupleInfo = legalPtrOperand.getTuple(); RefPtr resTupleInfo = new TuplePseudoVal(); auto tupleType = type.getTuple(); SLANG_ASSERT(tupleType); auto elemCount = ptrTupleInfo->elements.Count(); SLANG_ASSERT(elemCount == tupleType->elements.Count()); for(UInt ee = 0; ee < elemCount; ++ee) { auto ptrElem = ptrTupleInfo->elements[ee]; auto elemType = tupleType->elements[ee].type; TuplePseudoVal::Element resElem; resElem.key = ptrElem.key; resElem.val = legalizeGetElement( context, elemType, ptrElem.val, indexOperand); resTupleInfo->elements.Add(resElem); } return LegalVal::tuple(resTupleInfo); } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeGetElement( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, LegalVal legalIndexOperand) { // We don't expect any legalization to affect // the "index" argument. auto indexOperand = legalIndexOperand.getSimple(); return legalizeGetElement( context, type, legalPtrOperand, indexOperand); } static LegalVal legalizeGetElementPtr( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, IRInst* indexOperand) { auto builder = context->builder; switch (legalPtrOperand.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: return LegalVal::simple( builder->emitElementAddress( type.getSimple(), legalPtrOperand.getSimple(), indexOperand)); case LegalVal::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairVal = legalPtrOperand.getPair(); auto pairInfo = pairVal->pairInfo; LegalType ordinaryType = type; LegalType specialType = type; if (type.flavor == LegalType::Flavor::pair) { auto pairType = type.getPair(); ordinaryType = pairType->ordinaryType; specialType = pairType->specialType; } LegalVal ordinaryVal = legalizeGetElementPtr( context, ordinaryType, pairVal->ordinaryVal, indexOperand); LegalVal specialVal = legalizeGetElementPtr( context, specialType, pairVal->specialVal, indexOperand); return LegalVal::pair(ordinaryVal, specialVal, pairInfo); } break; case LegalVal::Flavor::tuple: { // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. auto ptrTupleInfo = legalPtrOperand.getTuple(); RefPtr resTupleInfo = new TuplePseudoVal(); auto tupleType = type.getTuple(); SLANG_ASSERT(tupleType); auto elemCount = ptrTupleInfo->elements.Count(); SLANG_ASSERT(elemCount == tupleType->elements.Count()); for(UInt ee = 0; ee < elemCount; ++ee) { auto ptrElem = ptrTupleInfo->elements[ee]; auto elemType = tupleType->elements[ee].type; TuplePseudoVal::Element resElem; resElem.key = ptrElem.key; resElem.val = legalizeGetElementPtr( context, elemType, ptrElem.val, indexOperand); resTupleInfo->elements.Add(resElem); } return LegalVal::tuple(resTupleInfo); } case LegalVal::Flavor::implicitDeref: { // The original value used to be a pointer to an array, // and somebody is trying to get at an element pointer. // Now we just have an array (wrapped with an implicit // dereference) and need to just fetch the chosen element // instead (and then wrap the element value with an // implicit dereference). // // The result type for our `getElement` instruction needs // to be the type *pointed to* by `type`, and not `type. // auto valueType = getPointedToType(context, type); auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); return LegalVal::implicitDeref(legalizeGetElement( context, valueType, implicitDerefVal, indexOperand)); } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeGetElementPtr( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, LegalVal legalIndexOperand) { // We don't expect any legalization to affect // the "index" argument. auto indexOperand = legalIndexOperand.getSimple(); return legalizeGetElementPtr( context, type, legalPtrOperand, indexOperand); } static LegalVal legalizeMakeStruct( IRTypeLegalizationContext* context, LegalType legalType, LegalVal const* legalArgs, UInt argCount) { auto builder = context->builder; switch(legalType.flavor) { case LegalType::Flavor::none: return LegalVal(); case LegalType::Flavor::simple: { List args; for(UInt aa = 0; aa < argCount; ++aa) { // Note: we assume that all the arguments // must be simple here, because otherwise // the `struct` type with them as fields // would not be simple... // args.Add(legalArgs[aa].getSimple()); } return LegalVal::simple( builder->emitMakeStruct( legalType.getSimple(), argCount, args.Buffer())); } case LegalType::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairType = legalType.getPair(); auto pairInfo = pairType->pairInfo; LegalType ordinaryType = pairType->ordinaryType; LegalType specialType = pairType->specialType; List ordinaryArgs; List specialArgs; UInt argCounter = 0; for(auto ee : pairInfo->elements) { UInt argIndex = argCounter++; LegalVal arg = legalArgs[argIndex]; if( arg.flavor == LegalVal::Flavor::pair ) { // The argument is itself a pair auto argPair = arg.getPair(); ordinaryArgs.Add(argPair->ordinaryVal); specialArgs.Add(argPair->specialVal); } else if(ee.flags & Slang::PairInfo::kFlag_hasOrdinary) { ordinaryArgs.Add(arg); } else if(ee.flags & Slang::PairInfo::kFlag_hasSpecial) { specialArgs.Add(arg); } } LegalVal ordinaryVal = legalizeMakeStruct( context, ordinaryType, ordinaryArgs.Buffer(), ordinaryArgs.Count()); LegalVal specialVal = legalizeMakeStruct( context, specialType, specialArgs.Buffer(), specialArgs.Count()); return LegalVal::pair(ordinaryVal, specialVal, pairInfo); } break; case LegalType::Flavor::tuple: { // We are constructing a tuple of values from // the individual fields. We need to identify // for each tuple element what field it uses, // and then extract that field's value. auto tupleType = legalType.getTuple(); RefPtr resTupleInfo = new TuplePseudoVal(); UInt argCounter = 0; for(auto typeElem : tupleType->elements) { auto elemKey = typeElem.key; UInt argIndex = argCounter++; SLANG_ASSERT(argIndex < argCount); LegalVal argVal = legalArgs[argIndex]; TuplePseudoVal::Element resElem; resElem.key = elemKey; resElem.val = argVal; resTupleInfo->elements.Add(resElem); } return LegalVal::tuple(resTupleInfo); } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeConstruct(IRTypeLegalizationContext* context, LegalType type) { switch (type.flavor) { case LegalType::Flavor::none: return LegalVal(); case LegalType::Flavor::simple: return LegalVal::simple(context->builder->emitConstructorInst(type.getSimple(), 0, nullptr)); default: SLANG_UNEXPECTED("unhandled legalization case for construct inst."); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeInst( IRTypeLegalizationContext* context, IRInst* inst, LegalType type, LegalVal const* args) { switch (inst->op) { case kIROp_Load: return legalizeLoad(context, args[0]); case kIROp_FieldAddress: return legalizeFieldAddress(context, type, args[0], args[1]); case kIROp_FieldExtract: return legalizeFieldExtract(context, type, args[0], args[1]); case kIROp_getElement: return legalizeGetElement(context, type, args[0], args[1]); case kIROp_getElementPtr: return legalizeGetElementPtr(context, type, args[0], args[1]); case kIROp_Store: return legalizeStore(context, args[0], args[1]); case kIROp_Call: return legalizeCall(context, (IRCall*)inst); case kIROp_ReturnVal: return legalizeRetVal(context, args[0]); case kIROp_makeStruct: return legalizeMakeStruct( context, type, args, inst->getOperandCount()); case kIROp_Construct: return legalizeConstruct(context, type); case kIROp_undefined: return LegalVal(); default: // TODO: produce a user-visible diagnostic here SLANG_UNEXPECTED("non-simple operand(s)!"); break; } } RefPtr findVarLayout(IRInst* value) { if (auto layoutDecoration = value->findDecoration()) return as(layoutDecoration->getLayout()); return nullptr; } static UnownedStringSlice findNameHint(IRInst* inst) { if( auto nameHintDecoration = inst->findDecoration() ) { return nameHintDecoration->getName(); } return UnownedStringSlice(); } static LegalVal legalizeLocalVar( IRTypeLegalizationContext* context, IRVar* irLocalVar) { // Legalize the type for the variable's value auto legalValueType = legalizeType( context, irLocalVar->getDataType()->getValueType()); auto originalRate = irLocalVar->getRate(); RefPtr varLayout = findVarLayout(irLocalVar); RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; // If we've decided to do implicit deref on the type, // then go ahead and declare a value of the pointed-to type. LegalType maybeSimpleType = legalValueType; while (maybeSimpleType.flavor == LegalType::Flavor::implicitDeref) { maybeSimpleType = maybeSimpleType.getImplicitDeref()->valueType; } switch (maybeSimpleType.flavor) { case LegalType::Flavor::simple: { // Easy case: the type is usable as-is, and we // should just do that. auto type = maybeSimpleType.getSimple(); type = context->builder->getPtrType(type); if( originalRate ) { type = context->builder->getRateQualifiedType( originalRate, type); } irLocalVar->setFullType(type); return LegalVal::simple(irLocalVar); } default: { // TODO: We don't handle rates in this path. context->insertBeforeLocalVar = irLocalVar; LegalVarChain* varChain = nullptr; LegalVarChain varChainStorage; if (varLayout) { varChainStorage.next = nullptr; varChainStorage.varLayout = varLayout; varChain = &varChainStorage; } UnownedStringSlice nameHint = findNameHint(irLocalVar); LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, nullptr); // Remove the old local var. irLocalVar->removeFromParent(); // add old local var to list context->replacedInstructions.Add(irLocalVar); return newVal; } break; } } static LegalVal legalizeParam( IRTypeLegalizationContext* context, IRParam* originalParam) { auto legalParamType = legalizeType(context, originalParam->getFullType()); if (legalParamType.flavor == LegalType::Flavor::simple) { // Simple case: things were legalized to a simple type, // so we can just use the original parameter as-is. originalParam->setFullType(legalParamType.getSimple()); return LegalVal::simple(originalParam); } else { // Complex case: we need to insert zero or more new parameters, // which will replace the old ones. context->insertBeforeParam = originalParam; UnownedStringSlice nameHint = findNameHint(originalParam); auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr, nameHint, nullptr); originalParam->removeFromParent(); context->replacedInstructions.Add(originalParam); return newVal; } } static LegalVal legalizeFunc( IRTypeLegalizationContext* context, IRFunc* irFunc); static LegalVal legalizeGlobalVar( IRTypeLegalizationContext* context, IRGlobalVar* irGlobalVar); static LegalVal legalizeGlobalConstant( IRTypeLegalizationContext* context, IRGlobalConstant* irGlobalConstant); static LegalVal legalizeGlobalParam( IRTypeLegalizationContext* context, IRGlobalParam* irGlobalParam); static LegalVal legalizeInst( IRTypeLegalizationContext* context, IRInst* inst) { // Special-case certain operations switch (inst->op) { case kIROp_Var: return legalizeLocalVar(context, cast(inst)); case kIROp_Param: return legalizeParam(context, cast(inst)); case kIROp_WitnessTable: // Just skip these. break; case kIROp_Func: return legalizeFunc(context, cast(inst)); case kIROp_GlobalVar: return legalizeGlobalVar(context, cast(inst)); case kIROp_GlobalConstant: return legalizeGlobalConstant(context, cast(inst)); case kIROp_GlobalParam: return legalizeGlobalParam(context, cast(inst)); default: break; } // Need to legalize all the operands. auto argCount = inst->getOperandCount(); List legalArgs; bool anyComplex = false; for (UInt aa = 0; aa < argCount; ++aa) { auto oldArg = inst->getOperand(aa); auto legalArg = legalizeOperand(context, oldArg); legalArgs.Add(legalArg); if (legalArg.flavor != LegalVal::Flavor::simple) anyComplex = true; } // Also legalize the type of the instruction LegalType legalType = legalizeType(context, inst->getFullType()); if (!anyComplex && legalType.flavor == LegalType::Flavor::simple) { // Nothing interesting happened to the operands, // so we seem to be okay, right? for (UInt aa = 0; aa < argCount; ++aa) { auto legalArg = legalArgs[aa]; inst->setOperand(aa, legalArg.getSimple()); } inst->setFullType(legalType.getSimple()); return LegalVal::simple(inst); } // We have at least one "complex" operand, and we // need to figure out what to do with it. The anwer // will, in general, depend on what we are doing. // We will set up the IR builder so that any new // instructions generated will be placed before // the location of the original instruction. auto builder = context->builder; builder->setInsertBefore(inst); LegalVal legalVal = legalizeInst( context, inst, legalType, legalArgs.Buffer()); // After we are done, we will eliminate the // original instruction by removing it from // the IR. // inst->removeFromParent(); context->replacedInstructions.Add(inst); // The value to be used when referencing // the original instruction will now be // whatever value(s) we created to replace it. return legalVal; } static void addParamType(List& ioParamTypes, LegalType t) { switch (t.flavor) { case LegalType::Flavor::none: break; case LegalType::Flavor::simple: ioParamTypes.Add(t.getSimple()); break; case LegalType::Flavor::implicitDeref: { auto imp = t.getImplicitDeref(); addParamType(ioParamTypes, imp->valueType); break; } case LegalType::Flavor::pair: { auto pairInfo = t.getPair(); addParamType(ioParamTypes, pairInfo->ordinaryType); addParamType(ioParamTypes, pairInfo->specialType); } break; case LegalType::Flavor::tuple: { auto tup = t.getTuple(); for (auto & elem : tup->elements) addParamType(ioParamTypes, elem.type); } break; default: SLANG_UNEXPECTED("unknown legalized type flavor"); } } static void legalizeInstsInParent( IRTypeLegalizationContext* context, IRInst* parent) { IRInst* nextChild = nullptr; for(auto child = parent->getFirstDecorationOrChild(); child; child = nextChild) { nextChild = child->getNextInst(); if (auto block = as(child)) { legalizeInstsInParent(context, block); } else { LegalVal legalVal = legalizeInst(context, child); registerLegalizedValue(context, child, legalVal); } } } static LegalVal legalizeFunc( IRTypeLegalizationContext* context, IRFunc* irFunc) { // Overwrite the function's type with the result of legalization. IRFuncType* oldFuncType = irFunc->getDataType(); UInt oldParamCount = oldFuncType->getParamCount(); // TODO: we should give an error message when the result type of a function // can't be legalized (e.g., trying to return a texture, or a structue that // contains one). auto legalReturnType = legalizeType(context, oldFuncType->getResultType()); IRType* newResultType = nullptr; switch (legalReturnType.flavor) { case LegalType::Flavor::simple: newResultType = legalReturnType.getSimple(); break; case LegalType::Flavor::none: newResultType = context->builder->getVoidType(); break; default: SLANG_UNEXPECTED("unknown legalized function return type."); } List newParamTypes; for (UInt pp = 0; pp < oldParamCount; ++pp) { auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp)); addParamType(newParamTypes, legalParamType); } auto newFuncType = context->builder->getFuncType( newParamTypes.Count(), newParamTypes.Buffer(), newResultType); context->builder->setDataType(irFunc, newFuncType); legalizeInstsInParent(context, irFunc); return LegalVal::simple(irFunc); } static LegalVal declareSimpleVar( IRTypeLegalizationContext* context, IROp op, IRType* type, TypeLayout* typeLayout, LegalVarChain* varChain, UnownedStringSlice nameHint, IRGlobalNameInfo* globalNameInfo) { SLANG_UNUSED(globalNameInfo); RefPtr varLayout = createVarLayout(varChain, typeLayout); DeclRef varDeclRef; if (varChain) { varDeclRef = varChain->varLayout->varDecl; } IRBuilder* builder = context->builder; IRInst* irVar = nullptr; LegalVal legalVarVal; switch (op) { case kIROp_GlobalVar: { auto globalVar = builder->createGlobalVar(type); globalVar->removeFromParent(); globalVar->insertBefore(context->insertBeforeGlobal); irVar = globalVar; legalVarVal = LegalVal::simple(irVar); } break; case kIROp_GlobalConstant: { auto globalConst = builder->createGlobalConstant(type); globalConst->removeFromParent(); globalConst->insertBefore(context->insertBeforeGlobal); irVar = globalConst; legalVarVal = LegalVal::simple(globalConst); } break; case kIROp_GlobalParam: { auto globalParam = builder->createGlobalParam(type); globalParam->removeFromParent(); globalParam->insertBefore(context->insertBeforeGlobal); irVar = globalParam; legalVarVal = LegalVal::simple(globalParam); } break; case kIROp_Var: { builder->setInsertBefore(context->insertBeforeLocalVar); auto localVar = builder->emitVar(type); irVar = localVar; legalVarVal = LegalVal::simple(irVar); } break; case kIROp_Param: { auto param = builder->emitParam(type); param->insertBefore(context->insertBeforeParam); irVar = param; legalVarVal = LegalVal::simple(irVar); } break; default: SLANG_UNEXPECTED("unexpected IR opcode"); break; } if (irVar) { if (varLayout) { builder->addLayoutDecoration(irVar, varLayout); } if (varDeclRef) { builder->addHighLevelDeclDecoration(irVar, varDeclRef.getDecl()); } if( nameHint.size() ) { context->builder->addNameHintDecoration(irVar, nameHint); } } return legalVarVal; } static LegalVal declareVars( IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, LegalVarChain* varChain, UnownedStringSlice nameHint, IRGlobalNameInfo* globalNameInfo) { switch (type.flavor) { case LegalType::Flavor::none: return LegalVal(); case LegalType::Flavor::simple: return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain, nameHint, globalNameInfo); break; case LegalType::Flavor::implicitDeref: { // Just declare a variable of the pointed-to type, // since we are removing the indirection. auto val = declareVars( context, op, type.getImplicitDeref()->valueType, typeLayout, varChain, nameHint, globalNameInfo); return LegalVal::implicitDeref(val); } break; case LegalType::Flavor::pair: { auto pairType = type.getPair(); auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, globalNameInfo); auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, globalNameInfo); return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); } case LegalType::Flavor::tuple: { // Declare one variable for each element of the tuple auto tupleType = type.getTuple(); RefPtr tupleVal = new TuplePseudoVal(); for (auto ee : tupleType->elements) { // Fields are currently required to have linkage, since we use // their mangled name to look up field layout information. // auto fieldLinkage = ee.key->findDecoration(); SLANG_ASSERT(fieldLinkage); auto fieldLayout = getFieldLayout(typeLayout, fieldLinkage->getMangledName()); RefPtr fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; // If we are processing layout information, then // we need to create a new link in the chain // of variables that will determine offsets // for the eventual leaf fields... LegalVarChain newVarChainStorage; LegalVarChain* newVarChain = varChain; if (fieldLayout) { newVarChainStorage.next = varChain; newVarChainStorage.varLayout = fieldLayout; newVarChain = &newVarChainStorage; } UnownedStringSlice fieldNameHint; String joinedNameHintStorage; if( nameHint.size() ) { if( auto fieldNameHintDecoration = ee.key->findDecoration() ) { joinedNameHintStorage.append(nameHint); joinedNameHintStorage.append("."); joinedNameHintStorage.append(fieldNameHintDecoration->getName()); fieldNameHint = joinedNameHintStorage.getUnownedSlice(); } } LegalVal fieldVal = declareVars( context, op, ee.type, fieldTypeLayout, newVarChain, fieldNameHint, globalNameInfo); TuplePseudoVal::Element element; element.key = ee.key; element.val = fieldVal; tupleVal->elements.Add(element); } return LegalVal::tuple(tupleVal); } break; case LegalType::Flavor::wrappedBuffer: { auto wrappedBuffer = type.getWrappedBuffer(); auto innerVal = declareSimpleVar( context, op, wrappedBuffer->simpleType, typeLayout, varChain, nameHint, globalNameInfo); auto wrappedVal = wrapBufferValue( context, innerVal, wrappedBuffer->elementInfo); return wrappedVal; } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); break; } } static LegalVal legalizeGlobalVar( IRTypeLegalizationContext* context, IRGlobalVar* irGlobalVar) { // Legalize the type for the variable's value auto legalValueType = legalizeType( context, irGlobalVar->getDataType()->getValueType()); switch (legalValueType.flavor) { case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. context->builder->setDataType( irGlobalVar, context->builder->getPtrType( legalValueType.getSimple())); return LegalVal::simple(irGlobalVar); default: { context->insertBeforeGlobal = irGlobalVar->getNextInst(); IRGlobalNameInfo globalNameInfo; globalNameInfo.globalVar = irGlobalVar; globalNameInfo.counter = 0; UnownedStringSlice nameHint = findNameHint(irGlobalVar); LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, nullptr, nameHint, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalVar, newVal); // Remove the old global from the module. irGlobalVar->removeFromParent(); context->replacedInstructions.Add(irGlobalVar); return newVal; } break; } } static LegalVal legalizeGlobalConstant( IRTypeLegalizationContext* context, IRGlobalConstant* irGlobalConstant) { // Legalize the type for the variable's value auto legalValueType = legalizeType( context, irGlobalConstant->getFullType()); switch (legalValueType.flavor) { case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. irGlobalConstant->setFullType(legalValueType.getSimple()); return LegalVal::simple(irGlobalConstant); default: { context->insertBeforeGlobal = irGlobalConstant->getNextInst(); IRGlobalNameInfo globalNameInfo; globalNameInfo.globalVar = irGlobalConstant; globalNameInfo.counter = 0; // TODO: need to handle initializer here! UnownedStringSlice nameHint = findNameHint(irGlobalConstant); LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, nullptr, nameHint, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalConstant, newVal); // Remove the old global from the module. irGlobalConstant->removeFromParent(); context->replacedInstructions.Add(irGlobalConstant); return newVal; } break; } } static LegalVal legalizeGlobalParam( IRTypeLegalizationContext* context, IRGlobalParam* irGlobalParam) { // Legalize the type for the variable's value auto legalValueType = legalizeType( context, irGlobalParam->getFullType()); RefPtr varLayout = findVarLayout(irGlobalParam); RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; switch (legalValueType.flavor) { case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. irGlobalParam->setFullType(legalValueType.getSimple()); return LegalVal::simple(irGlobalParam); default: { context->insertBeforeGlobal = irGlobalParam->getNextInst(); LegalVarChain* varChain = nullptr; LegalVarChain varChainStorage; if (varLayout) { varChainStorage.next = nullptr; varChainStorage.varLayout = varLayout; varChain = &varChainStorage; } IRGlobalNameInfo globalNameInfo; globalNameInfo.globalVar = irGlobalParam; globalNameInfo.counter = 0; // TODO: need to handle initializer here! UnownedStringSlice nameHint = findNameHint(irGlobalParam); LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalParam, newVal); // Remove the old global from the module. irGlobalParam->removeFromParent(); context->replacedInstructions.Add(irGlobalParam); return newVal; } break; } } static void legalizeTypes( IRTypeLegalizationContext* context) { // Legalize all the top-level instructions in the module auto module = context->module; legalizeInstsInParent(context, module->moduleInst); // Clean up after any instructions we replaced along the way. for (auto& lv : context->replacedInstructions) { lv->removeAndDeallocate(); } } // We use the same basic type legalization machinery for both simplifying // away resource-type fields nested in `struct`s and for shuffling around // exisential-box fields to get the layout right. // // The differences between the two passes come down to some very small // distinctions about what types each pass considers "special" (e.g., // resources in one case and existential boxes in the other), along // with what they want to do when a uniform/constant buffer needs to // be made where the element type is non-simple (that is, includes // some fields of "special" type). // // The resource case is then the simpler one: // struct IRResourceTypeLegalizationContext : IRTypeLegalizationContext { IRResourceTypeLegalizationContext(IRModule* module) : IRTypeLegalizationContext(module) {} bool isSpecialType(IRType* type) override { // For resource type legalization, the "special" types // we are working with are resource types. // return isResourceType(type); } LegalType createLegalUniformBufferType( IROp op, LegalType legalElementType) override { // The appropriate strategy for legalizing uniform buffers // with resources inside already exists, so we can delegate to it. // return createLegalUniformBufferTypeForResources( this, op, legalElementType); } }; // The case for legalizing existential box types is then similar. // struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext { IRExistentialTypeLegalizationContext(IRModule* module) : IRTypeLegalizationContext(module) {} bool isSpecialType(IRType* inType) override { // The "special" types for our purposes are existential // boxes, or arrays thereof. // auto type = unwrapArray(inType); return as(type) != nullptr; } LegalType createLegalUniformBufferType( IROp op, LegalType legalElementType) override { // We'll delegate the logic for creating uniform buffers // over a mix of ordinary and existential-box types to // a subroutine so it can live near the resource case. // // TODO: We should eventually try to refactor this code // so that related functionality is grouped together. // return createLegalUniformBufferTypeForExistentials( this, op, legalElementType); } }; // The main entry points that are used when transforming IR code // to get it ready for lower-level codegen are then simple // wrappers around `legalizeTypes()` that pick an appropriately // specialized context type to use to get the job done. void legalizeResourceTypes( IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); IRResourceTypeLegalizationContext context(module); legalizeTypes(&context); } void legalizeExistentialTypeLayout( IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(module); SLANG_UNUSED(sink); IRExistentialTypeLegalizationContext context(module); legalizeTypes(&context); } }