// ir-legalize-types.cpp // This file implements a pass that takes IR // that has been fully specialized (no more // generics/interfaces needing to be specialized // away) and replaces any types that can't actually // be used as-is on the target. // // The particular case we are focused on is // aggregate types (e.g., `struct` types) that // contain resources (textures, samplers, etc.) // or that mix resources and ordinary "uniform" // data. #include "ir.h" #include "ir-insts.h" namespace Slang { struct LegalTypeImpl : RefObject { }; struct ImplicitDerefType; struct TupleType; struct LegalType { enum class Flavor { // Nothing: a NULL type none, // A simple type that can be represented directly as a `Type` simple, // Logically, we have a pointer-like type, but we are // going to represnet it as the pointed-to type implicitDeref, tuple, }; Flavor flavor = Flavor::none; RefPtr obj; static LegalType simple(Type* type) { LegalType result; result.flavor = Flavor::simple; result.obj = type; return result; } RefPtr getSimple() { assert(flavor == Flavor::simple); return obj.As(); } static LegalType implicitDeref( LegalType const& valueType); RefPtr getImplicitDeref() { assert(flavor == Flavor::implicitDeref); return obj.As(); } static LegalType tuple( RefPtr tupleType); RefPtr getTuple() { assert(flavor == Flavor::tuple); return obj.As(); } }; struct ImplicitDerefType : LegalTypeImpl { LegalType valueType; }; LegalType LegalType::implicitDeref( LegalType const& valueType) { RefPtr obj = new ImplicitDerefType(); obj->valueType = valueType; LegalType result; result.flavor = Flavor::implicitDeref; result.obj = obj; return result; } struct TupleType : LegalTypeImpl { struct Element { DeclRef fieldDeclRef; LegalType type; }; List elements; }; LegalType LegalType::tuple( RefPtr tupleType) { LegalType result; result.flavor = Flavor::tuple; result.obj = tupleType; return result; } struct LegalValImpl : RefObject { }; struct TupleVal; struct LegalVal { enum class Flavor { none, simple, implicitDeref, tuple, }; Flavor flavor; RefPtr obj; IRValue* irValue; static LegalVal simple(IRValue* irValue) { LegalVal result; result.flavor = Flavor::simple; result.irValue = irValue; return result; } IRValue* getSimple() { assert(flavor == Flavor::simple); return irValue; } static LegalVal tuple(RefPtr tupleVal); RefPtr getTuple() { assert(flavor == Flavor::tuple); return obj.As(); } static LegalVal implicitDeref(LegalVal const& val); LegalVal getImplicitDeref(); }; struct TupleVal : LegalValImpl { struct Element { DeclRef fieldDeclRef; LegalVal val; }; List elements; }; LegalVal LegalVal::tuple(RefPtr tupleVal) { LegalVal result; result.flavor = LegalVal::Flavor::tuple; result.obj = tupleVal; return result; } struct ImplicitDerefVal : LegalValImpl { LegalVal val; }; LegalVal LegalVal::implicitDeref(LegalVal const& val) { RefPtr implicitDerefVal = new ImplicitDerefVal(); implicitDerefVal->val = val; LegalVal result; result.flavor = LegalVal::Flavor::implicitDeref; result.obj = implicitDerefVal; return result; } LegalVal LegalVal::getImplicitDeref() { assert(flavor == Flavor::implicitDeref); return obj.As()->val; } struct TypeLegalizationContext { Session* session; IRModule* module; IRBuilder* builder; // When inserting new globals, put them before this one. IRGlobalValue* insertBeforeGlobal = nullptr; // When inserting new parameters, put them before this one. IRParam* insertBeforeParam = nullptr; Dictionary mapValToLegalVal; IRVar* insertBeforeLocalVar = nullptr; // store local var instructions that have been replaced here, so we can free them // when legalization has done List oldLocalVars; }; static void registerLegalizedValue( TypeLegalizationContext* context, IRValue* irValue, LegalVal const& legalVal) { context->mapValToLegalVal.Add(irValue, legalVal); } static bool isResourceType(Type* type) { while (auto arrayType = type->As()) { type = arrayType->baseType; } if (auto textureTypeBase = type->As()) { return true; } else if (auto samplerType = type->As()) { return true; } // TODO: need more comprehensive coverage here return false; } // Legalize a type, including any nested types // that it transitively contains. static LegalType legalizeType( TypeLegalizationContext* context, Type* type) { if (auto parameterBlockType = type->As()) { // We basically legalize the `ParameterBlock` type // over to `T`. In order to represent this preoperly, // we need to be careful to wrap it up in a way that // tells us to eliminate downstream deferences... auto legalElementType = legalizeType(context, parameterBlockType->getElementType()); return LegalType::implicitDeref(legalElementType); } else if (isResourceType(type)) { // We assume that any resource types not handled above // are legal as-is. return LegalType::simple(type); } else if (type->As()) { return LegalType::simple(type); } else if (type->As()) { return LegalType::simple(type); } else if (type->As()) { return LegalType::simple(type); } else if (auto declRefType = type->As()) { auto declRef = declRefType->declRef; if (auto aggTypeDeclRef = declRef.As()) { // Look at the (non-static) fields, and // see if anything needs to be cleaned up. // We collect the legalized types for the fields, // along with whether we've seen anything non-simple. List legalizedElements; bool anyComplex = false; bool anyResource = false; for (auto ff : getMembersOfType(aggTypeDeclRef)) { if (ff.getDecl()->HasModifier()) continue; auto fieldType = GetType(ff); if (isResourceType(fieldType)) { anyResource = true; } auto legalFieldType = legalizeType(context, fieldType); TupleType::Element element; element.fieldDeclRef = ff; element.type = legalFieldType; legalizedElements.Add(element); switch (legalFieldType.flavor) { case LegalType::Flavor::simple: break; default: anyComplex = true; break; } } // If we didn't see anything that requires work, // we can conceivably just use the type as-is // // TODO: this might be a good place to turn // a reference to a generic `struct` type into // a concrete non-generic type so that downstream // codegen doesn't have to deal with generics... // // TODO: In fact, why not just fully replace // all aggregate types here with some structural // types defined in the IR? if (!anyComplex && !anyResource) { return LegalType::simple(type); } // Okay, we are going to have to generate a // "tuple" type. // // TODO: split out the "simple" fields into // their own sub-type? RefPtr tupleType = new TupleType(); tupleType->elements = legalizedElements; return LegalType::tuple(tupleType); } } return LegalType::simple(type); } // Represents the "chain" of declarations that // were followed to get to a variable that we // are now declaring as a leaf variable. struct LegalVarChain { LegalVarChain* next; VarLayout* varLayout; }; static LegalVal declareVars( TypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, LegalVarChain* varChain); // Legalize a type, and then expect it to // result in a simple type. static RefPtr legalizeSimpleType( TypeLegalizationContext* context, Type* type) { auto legalType = legalizeType(context, type); switch (legalType.flavor) { case LegalType::Flavor::simple: return legalType.getSimple(); default: // TODO: need to issue a diagnostic here. SLANG_UNEXPECTED("unexpected type case"); break; } } // Take a value that is being used as an operand, // and turn it into the equivalent legalized value. static LegalVal legalizeOperand( TypeLegalizationContext* context, IRValue* irValue) { LegalVal legalVal; if (context->mapValToLegalVal.TryGetValue(irValue, legalVal)) return legalVal; // For now, assume that anything not covered // by the mapping is legal as-is. return LegalVal::simple(irValue); } static void getArgumentValues( List & instArgs, LegalVal val) { switch (val.flavor) { case LegalVal::Flavor::simple: instArgs.Add(val.getSimple()); break; case LegalVal::Flavor::implicitDeref: getArgumentValues(instArgs, val.getImplicitDeref()); break; case LegalVal::Flavor::tuple: { for (auto elem : val.getTuple()->elements) getArgumentValues(instArgs, elem.val); } break; } } static LegalVal legalizeCall( TypeLegalizationContext* context, IRCall* callInst) { // TODO: implement legalization of non-simple return types auto retType = legalizeType(context, callInst->type); SLANG_ASSERT(retType.flavor == LegalType::Flavor::simple); List instArgs; for (auto i = 1u; i < callInst->argCount; i++) getArgumentValues(instArgs, legalizeOperand(context, callInst->getArg(i))); return LegalVal::simple(context->builder->emitCallInst(callInst->type, callInst->func.usedValue, instArgs.Count(), instArgs.Buffer())); } static LegalVal legalizeLoad( TypeLegalizationContext* context, LegalVal legalPtrVal) { switch (legalPtrVal.flavor) { case LegalVal::Flavor::simple: { return LegalVal::simple( context->builder->emitLoad(legalPtrVal.getSimple())); } break; case LegalVal::Flavor::implicitDeref: // We have turne a pointer(-like) type into its pointed-to (value) // type, and so the operation of loading goes away; we just use // the underlying value. return legalPtrVal.getImplicitDeref(); case LegalVal::Flavor::tuple: { // We need to emit a load for each element of // the tuple. RefPtr tupleVal = new TupleVal(); for (auto ee : legalPtrVal.getTuple()->elements) { TupleVal::Element element; element.fieldDeclRef = ee.fieldDeclRef; element.val = legalizeLoad(context, ee.val); tupleVal->elements.Add(element); } return LegalVal::tuple(tupleVal); } break; default: SLANG_UNEXPECTED("unhandled case"); break; } } static LegalVal legalizeStore( TypeLegalizationContext* context, LegalVal legalPtrVal, LegalVal legalVal) { switch (legalPtrVal.flavor) { case LegalVal::Flavor::simple: { context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple()); return legalVal; } break; case LegalVal::Flavor::implicitDeref: // TODO: what is the right behavior here? if (legalVal.flavor == LegalVal::Flavor::implicitDeref) return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal.getImplicitDeref()); else return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal); case LegalVal::Flavor::tuple: { // We need to emit a store for each element of // the tuple. auto destTuple = legalPtrVal.getTuple(); auto valTuple = legalVal.getTuple(); SLANG_ASSERT(destTuple->elements.Count() == valTuple->elements.Count()); for (UInt i = 0; i < valTuple->elements.Count(); i++) { legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); } return legalVal; } break; default: SLANG_UNEXPECTED("unhandled case"); break; } } static LegalVal legalizeFieldAddress( TypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, LegalVal legalFieldOperand) { auto builder = context->builder; // We don't expect any legalization to affect // the "field" argument. auto fieldOperand = legalFieldOperand.getSimple(); assert(fieldOperand->op == kIROp_decl_ref); auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; switch (legalPtrOperand.flavor) { case LegalVal::Flavor::simple: return LegalVal::simple( builder->emitFieldAddress( type.getSimple(), legalPtrOperand.getSimple(), fieldOperand)); case LegalVal::Flavor::tuple: { // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. for (auto ee : legalPtrOperand.getTuple()->elements) { if (ee.fieldDeclRef.Equals(fieldDeclRef)) { return ee.val; } } SLANG_UNEXPECTED("didn't find tuple element"); UNREACHABLE_RETURN(LegalVal()); } default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } static LegalVal legalizeInst( TypeLegalizationContext* context, IRInst* inst, LegalType type, LegalVal const* args) { switch (inst->op) { case kIROp_Load: return legalizeLoad(context, args[0]); case kIROp_FieldAddress: return legalizeFieldAddress(context, type, args[0], args[1]); case kIROp_Store: return legalizeStore(context, args[0], args[1]); case kIROp_Call: return legalizeCall(context, (IRCall*)inst); default: // TODO: produce a user-visible diagnostic here SLANG_UNEXPECTED("non-simple operand(s)!"); break; } } RefPtr findVarLayout(IRValue* value) { if (auto layoutDecoration = value->findDecoration()) return layoutDecoration->layout.As(); return nullptr; } static LegalVal legalizeLocalVar( TypeLegalizationContext* context, IRVar* irLocalVar) { // Legalize the type for the variable's value auto legalValueType = legalizeType( context, irLocalVar->getType()->getValueType()); RefPtr varLayout = findVarLayout(irLocalVar); RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; // If we've decided to do implicit deref on the type, // then go ahead and declare a value of the pointed-to type. LegalType maybeSimpleType = legalValueType; while (maybeSimpleType.flavor == LegalType::Flavor::implicitDeref) { maybeSimpleType = maybeSimpleType.getImplicitDeref()->valueType; } switch (maybeSimpleType.flavor) { case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. irLocalVar->type = context->session->getPtrType( maybeSimpleType.getSimple()); return LegalVal::simple(irLocalVar); default: { context->insertBeforeLocalVar = irLocalVar; LegalVarChain* varChain = nullptr; LegalVarChain varChainStorage; if (varLayout) { varChainStorage.next = nullptr; varChainStorage.varLayout = varLayout; varChain = &varChainStorage; } LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain); // Remove the old local var. irLocalVar->removeFromParent(); // add old local var to list context->oldLocalVars.Add(irLocalVar); return newVal; } break; } } static LegalVal legalizeInst( TypeLegalizationContext* context, IRInst* inst) { if (inst->op == kIROp_Var) return legalizeLocalVar(context, (IRVar*)inst); // Need to legalize all the operands. auto argCount = inst->getArgCount(); List legalArgs; bool anyComplex = false; for (UInt aa = 0; aa < argCount; ++aa) { auto oldArg = inst->getArg(aa); auto legalArg = legalizeOperand(context, oldArg); legalArgs.Add(legalArg); if (legalArg.flavor != LegalVal::Flavor::simple) anyComplex = true; } // Also legalize the type of the instruction LegalType legalType = legalizeType(context, inst->type); if (!anyComplex && legalType.flavor == LegalType::Flavor::simple) { // Nothing interesting happened to the operands, // so we seem to be okay, right? for (UInt aa = 0; aa < argCount; ++aa) { auto legalArg = legalArgs[aa]; inst->setArg(aa, legalArg.getSimple()); } inst->type = legalType.getSimple(); return LegalVal::simple(inst); } // We have at least one "complex" operand, and we // need to figure out what to do with it. The anwer // will, in general, depend on what we are doing. // We will set up the IR builder so that any new // instructions generated will be placed after // the location of the original instruct. auto builder = context->builder; builder->curBlock = inst->getParentBlock(); builder->insertBeforeInst = inst->getNextInst(); LegalVal legalVal = legalizeInst( context, inst, legalType, legalArgs.Buffer()); // After we are done, we will eliminate the // original instruction by removing it from // the IR. // // TODO: we need to add it to a list of // instructions to be cleaned up... inst->removeFromParent(); // The value to be used when referencing // the original instruction will now be // whatever value(s) we created to replace it. return legalVal; } static void addParamType(IRFuncType * ftype, LegalType t) { switch (t.flavor) { case LegalType::Flavor::simple: ftype->paramTypes.Add(t.obj.As()); break; case LegalType::Flavor::implicitDeref: { auto imp = t.obj.As(); addParamType(ftype, imp->valueType); break; } case LegalType::Flavor::tuple: { auto tup = t.obj.As(); for (auto & elem : tup->elements) addParamType(ftype, elem.type); } break; default: SLANG_ASSERT(false); } } static void legalizeFunc( TypeLegalizationContext* context, IRFunc* irFunc) { // Overwrite the function's type with // the result of legalization. auto newFuncType = new IRFuncType(); newFuncType->setSession(context->session); auto oldFuncType = irFunc->type.As(); newFuncType->resultType = legalizeSimpleType(context, oldFuncType->resultType); for (auto & paramType : oldFuncType->paramTypes) { auto legalParamType = legalizeType(context, paramType); addParamType(newFuncType, legalParamType); } irFunc->type = newFuncType; List paramVals; List oldParams; // we use this list to store replaced local var insts. // these old instructions will be freed when we are done. context->oldLocalVars.Clear(); // Go through the blocks of the function for (auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock()) { // Legalize the parameters of the block, which may // involve increasing the number of parameters for (auto pp = bb->getFirstParam(); pp; pp = pp->nextParam) { auto legalParamType = legalizeType(context, pp->getType()); if (legalParamType.flavor != LegalType::Flavor::simple) { context->insertBeforeParam = pp; context->builder->curBlock = nullptr; auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr); paramVals.Add(paramVal); if (pp == bb->getFirstParam()) { bb->firstParam = pp; while (bb->firstParam->prevParam) bb->firstParam = bb->firstParam->prevParam; } bb->lastParam = pp->prevParam; if (pp->prevParam) pp->prevParam->nextParam = pp->nextParam; if (pp->nextParam) pp->nextParam->prevParam = pp->prevParam; auto oldParam = pp; oldParams.Add(oldParam); registerLegalizedValue(context, oldParam, paramVal); } } // Now legalize the instructions inside the block IRInst* nextInst = nullptr; for (auto ii = bb->getFirstInst(); ii; ii = nextInst) { nextInst = ii->getNextInst(); LegalVal legalVal = legalizeInst(context, ii); registerLegalizedValue(context, ii, legalVal); } } for (auto & op : oldParams) { SLANG_ASSERT(op->firstUse == nullptr || op->firstUse->nextUse == nullptr); op->deallocate(); } for (auto & lv : context->oldLocalVars) lv->deallocate(); } static LegalVal declareSimpleVar( TypeLegalizationContext* context, IROp op, Type* type, TypeLayout* typeLayout, LegalVarChain* varChain) { RefPtr varLayout; if (typeLayout) { // We need to construct a layout for the new variable // that reflects both the type we have given it, as // well as all the offset information that has accumulated // along the chain of parent variables. varLayout = new VarLayout(); varLayout->typeLayout = typeLayout; for (auto rr : typeLayout->resourceInfos) { auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); for (auto vv = varChain; vv; vv = vv->next) { if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) { resInfo->index += parentResInfo->index; resInfo->space += parentResInfo->space; } } } // Some of the parent variables might actually contain offsets // to the `space` or `set` of the field, and we need to apply // those to all the nested resource infos. for (auto vv = varChain; vv; vv = vv->next) { auto parentSpaceInfo = vv->varLayout->findOrAddResourceInfo(LayoutResourceKind::RegisterSpace); if (!parentSpaceInfo) continue; for (auto& rr : varLayout->resourceInfos) { if (rr.kind == LayoutResourceKind::RegisterSpace) { rr.index += parentSpaceInfo->index; } else { rr.space += parentSpaceInfo->index; } } } } switch (op) { case kIROp_global_var: { IRBuilder* builder = context->builder; auto globalVar = builder->createGlobalVar(type); globalVar->removeFromParent(); globalVar->insertBefore(context->insertBeforeGlobal); if (varLayout) { builder->addLayoutDecoration(globalVar, varLayout); } return LegalVal::simple(globalVar); } break; case kIROp_Var: { IRBuilder* builder = context->builder; auto localVar = builder->emitVar(type); localVar->removeFromParent(); localVar->insertBefore(context->insertBeforeLocalVar); if (varLayout) { builder->addLayoutDecoration(localVar, varLayout); } return LegalVal::simple(localVar); } break; case kIROp_Param: { IRBuilder* builder = context->builder; auto param = builder->emitParam(type); if (context->insertBeforeParam->prevParam) context->insertBeforeParam->prevParam->nextParam = param; param->prevParam = context->insertBeforeParam->prevParam; param->nextParam = context->insertBeforeParam; context->insertBeforeParam->prevParam = param; if (varLayout) { builder->addLayoutDecoration(param, varLayout); } return LegalVal::simple(param); } break; default: SLANG_UNEXPECTED("unexpected IR opcode"); break; } } static RefPtr getDerefTypeLayout( TypeLayout* typeLayout) { if (!typeLayout) return nullptr; if (auto parameterGroupTypeLayout = dynamic_cast(typeLayout)) { return parameterGroupTypeLayout->elementTypeLayout; } return typeLayout; } static RefPtr getFieldLayout( TypeLayout* typeLayout, DeclRef fieldDeclRef) { if (!typeLayout) return nullptr; if (auto structTypeLayout = dynamic_cast(typeLayout)) { RefPtr fieldLayout; if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout)) return fieldLayout; } return nullptr; } static LegalVal declareVars( TypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, LegalVarChain* varChain) { switch (type.flavor) { case LegalType::Flavor::simple: return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain); break; case LegalType::Flavor::implicitDeref: { // Just declare a variable of the pointed-to type, // since we are removing the indirection. auto val = declareVars( context, op, type.getImplicitDeref()->valueType, getDerefTypeLayout(typeLayout), varChain); return LegalVal::implicitDeref(val); } break; case LegalType::Flavor::tuple: { // Declare one variable for each element of the tuple auto tupleType = type.getTuple(); RefPtr tupleVal = new TupleVal(); for (auto ee : tupleType->elements) { auto fieldLayout = getFieldLayout(typeLayout, ee.fieldDeclRef); RefPtr fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; // If we are processing layout information, then // we need to create a new link in the chain // of variables that will determine offsets // for the eventual leaf fields... LegalVarChain newVarChainStorage; LegalVarChain* newVarChain = varChain; if (fieldLayout) { newVarChainStorage.next = varChain; newVarChainStorage.varLayout = fieldLayout; newVarChain = &newVarChainStorage; } TupleVal::Element element; element.fieldDeclRef = ee.fieldDeclRef; element.val = declareVars( context, op, ee.type, fieldTypeLayout, newVarChain); tupleVal->elements.Add(element); } return LegalVal::tuple(tupleVal); } break; default: SLANG_UNEXPECTED("unhandled"); break; } } static void legalizeGlobalVar( TypeLegalizationContext* context, IRGlobalVar* irGlobalVar) { // Legalize the type for the variable's value auto legalValueType = legalizeType( context, irGlobalVar->getType()->getValueType()); RefPtr varLayout = findVarLayout(irGlobalVar); RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; switch (legalValueType.flavor) { case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. irGlobalVar->type = context->session->getPtrType( legalValueType.getSimple()); break; default: { context->insertBeforeGlobal = irGlobalVar->getNextValue(); LegalVarChain* varChain = nullptr; LegalVarChain varChainStorage; if (varLayout) { varChainStorage.next = nullptr; varChainStorage.varLayout = varLayout; varChain = &varChainStorage; } LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalVar, newVal); // Remove the old global from the module. irGlobalVar->removeFromParent(); // TODO: actually clean up the global! } break; } } static void legalizeGlobalValue( TypeLegalizationContext* context, IRGlobalValue* irValue) { switch (irValue->op) { case kIROp_witness_table: // Just skip these. break; case kIROp_Func: legalizeFunc(context, (IRFunc*)irValue); break; case kIROp_global_var: legalizeGlobalVar(context, (IRGlobalVar*)irValue); break; default: SLANG_UNEXPECTED("unknown global value type"); break; } } static void legalizeTypes( TypeLegalizationContext* context) { auto module = context->module; for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) { legalizeGlobalValue(context, gv); } } void legalizeTypes( IRModule* module) { auto session = module->session; SharedIRBuilder sharedBuilderStorage; auto sharedBuilder = &sharedBuilderStorage; sharedBuilder->session = session; sharedBuilder->module = module; IRBuilder builderStorage; auto builder = &builderStorage; builder->sharedBuilder = sharedBuilder; TypeLegalizationContext contextStorage; auto context = &contextStorage; context->session = session; context->module = module; context->builder = builder; legalizeTypes(context); } #if 0 typedef unsigned int TypeScalarizationFlags; enum TypeScalarizationFlag { anyResource = 0x1, anyNonResource = 0x2, anyAggregate = 0x4, }; bool isResourceType(Type* type) { while (auto arrayType = type->As()) { type = arrayType->baseType; } if (auto textureTypeBase = type->As()) { return true; } else if (auto samplerType = type->As()) { return true; } // TODO: need more comprehensive coverage here return false; } TypeScalarizationFlags getTypeScalarizationFlags( Session* session, Type* type) { // TODO: we should probably cache flags once // they are computed, to avoid O(N^2) sorts // of behavior. if (isResourceType(type)) return TypeScalarizationFlag::anyNonResource; if(type->As()) { return TypeScalarizationFlag::anyNonResource; } if(type->As()) { return TypeScalarizationFlag::anyNonResource; } if(type->As()) { return TypeScalarizationFlag::anyNonResource; } else if (auto declRefType = type->As()) { auto declRef = declRefType->declRef; if (auto structDeclRef = declRef.As()) { TypeScalarizationFlags flags = TypeScalarizationFlag::anyAggregate; // For structure types, the basic rule will be // that if the type contains *any* resource-type // fields, then it needs to be scalarized. // If it contains any non-resource-type fields, // then we should aggregate these into a single // new `struct` type with just the non-resource // fields. for (auto fieldDeclRef : getMembersOfType(structDeclRef)) { auto fieldType = GetType(fieldDeclRef); // TODO: we are making a recursive call here, so // this will break if/when we ever allowed a recursive type! auto fieldFlags = getTypeScalarizationFlags(session, fieldType); flags |= fieldFlags; } return flags; } } else if (auto arrayType = type->As()) { return getTypeScalarizationFlags( session, arrayType->baseType); } // Default behavior: assume we have a non-resource type return TypeScalarizationFlag::anyNonResource; } struct ArrayScalarizationInfo { ArrayScalarizationInfo* next; RefPtr elementCount; RefPtr typeLayout; }; struct SharedScalarizationContext { }; struct ScalarizationContext { SharedScalarizationContext* shared; IRBuilder* builder; IRGlobalVar* globalVar; VarLayout* globalVarLayout; IRGlobalValue* valueToInsertAfter; }; IRValue* emitSimpleScalarizedField( ScalarizationContext* context, Type* inType, VarLayout* fieldLayout, TypeLayout* inTypeLayout, ArrayScalarizationInfo* arrayInfo) { auto builder = context->builder; auto globalVar = context->globalVar; auto globalVarLayout = context->globalVarLayout; auto valueToInsertAfter = context->valueToInsertAfter; RefPtr type = inType; RefPtr typeLayout = inTypeLayout; // If we are turning an array-of-structs into // a struct-of-arrays, then we need to apply // all the appropriate array dimensions here. for (auto aa = arrayInfo; aa; aa = aa->next) { type = builder->getSession()->getArrayType(type, aa->elementCount); if (typeLayout) { RefPtr arrayTypeLayout = new ArrayTypeLayout(); arrayTypeLayout->elementTypeLayout = typeLayout; // TODO: fill in the other fields! typeLayout = arrayTypeLayout; } } RefPtr newVarLayout; if (typeLayout) { newVarLayout = new VarLayout(); newVarLayout->typeLayout = typeLayout; if (fieldLayout) { for (auto fieldResourceInfo : fieldLayout->resourceInfos) { auto newResourceInfo = newVarLayout->findOrAddResourceInfo(fieldResourceInfo.kind); if (globalVarLayout) { if (auto globalResourceInfo = globalVarLayout->FindResourceInfo(fieldResourceInfo.kind)) { newResourceInfo->index += globalResourceInfo->index; newResourceInfo->space += globalResourceInfo->space; } } newResourceInfo->index += fieldResourceInfo.index; newResourceInfo->space += fieldResourceInfo.space; } } } auto newGlobalVar = addGlobalVariable(builder->getModule(), type); builder->addLayoutDecoration(newGlobalVar, newVarLayout); newGlobalVar->removeFromParent(); newGlobalVar->insertAfter(valueToInsertAfter); context->valueToInsertAfter = newGlobalVar; return newGlobalVar; } void scalarizeGlobalVariable( ScalarizationContext* context, Type* valueType, TypeLayout* valueTypeLayout, ArrayScalarizationInfo* arrayInfo) { if (auto arrayType = valueType->As()) { // Okay, we need to recurse down and scalarize the // array element type, wrapping up each field in // an array declarator as needed. ArrayScalarizationInfo newArrayInfo; newArrayInfo.next = arrayInfo; newArrayInfo.elementCount = arrayType->ArrayLength; RefPtr elementTypeLayout; if (auto arrayTypeLayout = dynamic_cast(valueTypeLayout)) { newArrayInfo.typeLayout = arrayTypeLayout; elementTypeLayout = arrayTypeLayout->elementTypeLayout; } scalarizeGlobalVariable( context, arrayType->baseType, elementTypeLayout, &newArrayInfo); // Now we need to look at all uses of the variable, // and properly rework element-index operations // to instead index into the sub-arrays... } else if (auto declRefType = valueType->As()) { auto declRef = declRefType->declRef; if (auto aggTypeDeclRef = declRef.As()) { RefPtr structTypeLayout = dynamic_cast(valueTypeLayout); // Okay, we need to look through the fields, and // create a new variable for each of them. Dictionary fieldMap; UInt fieldCounter = 0; for (auto fieldDeclRef : getMembersOfType(aggTypeDeclRef)) { UInt fieldIndex = fieldCounter++; RefPtr fieldLayout; RefPtr fieldTypeLayout; if (structTypeLayout) { fieldLayout = structTypeLayout->fields[fieldIndex]; fieldTypeLayout = fieldLayout->typeLayout; } // Note: we do *not* try to deal with recursive // expansion of the fields here, and instead // prefer to handle those in further // simplification passes. auto fieldGlobalVar = emitSimpleScalarizedField( context, GetType(fieldDeclRef), fieldLayout, fieldTypeLayout, arrayInfo); fieldMap.Add(fieldDeclRef.getDecl(), fieldGlobalVar); } // Now we need to scan for uses of the original variable, // and replace them with uses of the individual fields. auto globalVar = context->globalVar; IRUse* nextUse = nullptr; for (IRUse* use = globalVar->firstUse; use; use = nextUse) { nextUse = use->nextUse; IRUser* user = use->user; switch (user->op) { case kIROp_FieldAddress: { // This should be the easy case: we are taking // the address of a field inside this global // value, so we can just return the adress // of the global value that replaced that field. IRFieldAddress* fieldAddressInst = (IRFieldAddress*)user; IRValue* fieldOperand = fieldAddressInst->getField(); assert(fieldOperand->op == kIROp_decl_ref); auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; auto fieldDecl = fieldDeclRef.getDecl(); IRValue* fieldVar = *fieldMap.TryGetValue(fieldDecl); fieldAddressInst->replaceUsesWith(fieldVar); } break; default: SLANG_UNEXPECTED("what to do?"); break; } } } else { SLANG_UNEXPECTED("not handled"); } } else { SLANG_UNEXPECTED("not handled"); } } void scalarizeGlobalVariable( SharedScalarizationContext* sharedContext, IRBuilder* builder, IRGlobalVar* globalVar, VarLayout* globalVarLayout, Type* valueType, TypeLayout* valueTypeLayout) { ScalarizationContext contextStorage; auto context = &contextStorage; context->shared = sharedContext; context->builder = builder; context->globalVar = globalVar; context->globalVarLayout = globalVarLayout; context->valueToInsertAfter = globalVar; scalarizeGlobalVariable( context, valueType, valueTypeLayout, nullptr); } RefPtr findVarLayout(IRValue* value) { if (auto layoutDecoration = value->findDecoration()) return layoutDecoration->layout.As(); return nullptr; } void scalarizeMixedResourceTypes( Session* session, IRModule* module) { SharedIRBuilder sharedBuilderStorage; auto sharedBuilder = &sharedBuilderStorage; sharedBuilder->session = session; sharedBuilder->module = module; IRBuilder builderStorage; auto builder = &builderStorage; builder->shared = sharedBuilder; SharedScalarizationContext sharedContextStorage; auto sharedContext = &sharedContextStorage; List workList; for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) { workList.Add(gv); } while (workList.Count()) { IRValue* value = workList[0]; workList.FastRemoveAt(0); switch (value->op) { case kIROp_Func: { // TODO: need to iterate over parameters of // the function (and its blocks) to make // sure that any types that need scalarization // are properly handled. } break; case kIROp_global_var: { IRGlobalVar* globalVar = (IRGlobalVar*)value; auto valueType = globalVar->getType()->getValueType(); auto flags = getTypeScalarizationFlags(session, valueType); if (!(flags & (TypeScalarizationFlag::anyNonResource | TypeScalarizationFlag::anyAggregate))) continue; auto varLayout = findVarLayout(globalVar); RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; // Okay, we have a variable of some composite type // that we need to scalarize. Since this is a global, // we also need to be careful to deal with any // layout information that has been attached. scalarizeGlobalVariable( sharedContext, builder, globalVar, varLayout, valueType, typeLayout); globalVar->removeFromParent(); // TODO: need to destroy this global! } break; default: { // TODO: look at the type of the value, // and if it needs scalarization, replace // it with a tuple here. } break; } } } #endif }