diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/emit.cpp | 81 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 1455 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 1 | ||||
| -rw-r--r-- | source/slang/lower.cpp | 15 | ||||
| -rw-r--r-- | source/slang/mangle.cpp | 23 | ||||
| -rw-r--r-- | source/slang/mangle.h | 3 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 113 | ||||
| -rw-r--r-- | source/slang/type-defs.h | 36 |
8 files changed, 1183 insertions, 544 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index dccb12f53..5b7a42ad7 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -4,6 +4,7 @@ #include "ir-insts.h" #include "lower.h" #include "lower-to-ir.h" +#include "mangle.h" #include "name.h" #include "syntax.h" #include "type-layout.h" @@ -99,6 +100,8 @@ struct SharedEmitContext HashSet<Decl*> irDeclsVisited; Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead; + + HashSet<String> irTupleTypes; }; struct EmitContext @@ -1230,6 +1233,13 @@ struct EmitVisitor emitTypeImpl(type->valueType, arg.declarator); } + void visitFilteredTupleType(FilteredTupleType* type, TypeEmitArg const& arg) + { + auto declarator = arg.declarator; + emit(getMangledTypeName(type)); + EmitDeclarator(declarator); + } + void EmitType( RefPtr<Type> type, SourceLoc const& typeLoc, @@ -4199,7 +4209,10 @@ emitDeclImpl(decl, nullptr); return getText(reflectionNameMod->nameAndLoc.name); } - return getIRName(decl); + if ((context->shared->entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_MANGLING)) + { + return getIRName(decl); + } } switch (inst->op) @@ -6271,6 +6284,26 @@ emitDeclImpl(decl, nullptr); } } } + else if (auto filteredTupleType = elementType->As<FilteredTupleType>()) + { + auto structTypeLayout = typeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + for (auto ee : filteredTupleType->elements) + { + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(ee.fieldDeclRef, fieldLayout); + + emitIRVarModifiers(ctx, fieldLayout); + + auto fieldType = ee.type; + emitIRType(ctx, fieldType, getIRName(ee.fieldDeclRef)); + + emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); + + emit(";\n"); + } + } else { emit("/* unexpected */"); @@ -6586,6 +6619,43 @@ emitDeclImpl(decl, nullptr); ensureStructDecl(ctx, structDeclRef); } } + else if (auto filteredTupleType = type->As<FilteredTupleType>()) + { + // First, ensure that the element types are ready: + for (auto ee : filteredTupleType->elements) + { + if (ee.type) + { + emitIRUsedType(ctx, ee.type); + } + } + + // Now, we want to ensure we've emitted a + // matching `struct` type declaration. + + String mangledName = getMangledTypeName(filteredTupleType); + if (!ctx->shared->irTupleTypes.Contains(mangledName)) + { + ctx->shared->irTupleTypes.Add(mangledName); + + // Emit the damn `struct` decl... + + Emit("struct "); + emit(mangledName); + Emit("\n{\n"); + for( auto ee : filteredTupleType->elements ) + { + if (!ee.type) + continue; + + emitIRType(ctx, ee.type, getIRName(ee.fieldDeclRef)); + + emit(";\n"); + } + Emit("};\n"); + + } + } else {} } @@ -6840,7 +6910,7 @@ String emitEntryPoint( // Debugging code for IR transformations... #if 0 - fprintf(stderr, "###\n"); + fprintf(stderr, "### SPECIALIZED:\n"); dumpIR(lowered); fprintf(stderr, "###\n"); #endif @@ -6852,6 +6922,13 @@ String emitEntryPoint( // legalizeTypes(lowered); + // Debugging output of legalization +#if 0 + fprintf(stderr, "### LEGALIZED:\n"); + dumpIR(lowered); + fprintf(stderr, "###\n"); +#endif + // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 2719f336b..37df5f698 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -22,7 +22,9 @@ struct LegalTypeImpl : RefObject { }; struct ImplicitDerefType; -struct TupleType; +struct TuplePseudoType; +struct PairPseudoType; +struct PairInfo; struct LegalType { @@ -38,7 +40,14 @@ struct LegalType // going to represnet it as the pointed-to type implicitDeref, + // A compound type was broken apart into its constituent fields, + // so a tuple "pseduo-type" is being used to collect + // those fields together. tuple, + + // A type has to get split into "ordinary" and "special" parts, + // each of which will be represented with its own `LegalType`. + pair, }; Flavor flavor = Flavor::none; @@ -68,12 +77,26 @@ struct LegalType } static LegalType tuple( - RefPtr<TupleType> tupleType); + RefPtr<TuplePseudoType> tupleType); - RefPtr<TupleType> getTuple() + RefPtr<TuplePseudoType> getTuple() { assert(flavor == Flavor::tuple); - return obj.As<TupleType>(); + return obj.As<TuplePseudoType>(); + } + + static LegalType pair( + RefPtr<PairPseudoType> pairType); + + static LegalType pair( + RefPtr<Type> ordinaryType, + LegalType const& specialType, + RefPtr<PairInfo> pairInfo); + + RefPtr<PairPseudoType> getPair() + { + assert(flavor == Flavor::pair); + return obj.As<PairPseudoType>(); } }; @@ -94,19 +117,36 @@ LegalType LegalType::implicitDeref( return result; } -struct TupleType : LegalTypeImpl +// Represents the pseudo-type for a compound type +// that had to be broken apart because it contained +// one or more fields of types that shouldn't be +// allowed in aggregates. +// +// A tuple pseduo-type will have an element for +// each field of the original type, that represents +// the legalization of that field's type. +// +// It optionally also contains an "ordinary" type +// that packs together any per-field data that +// itself has (or contains) an ordinary type. +struct TuplePseudoType : LegalTypeImpl { + // Represents one element of the tuple pseudo-type struct Element { + // The field that this element replaces DeclRef<VarDeclBase> fieldDeclRef; + + // The legalized type of the element LegalType type; }; - List<Element> elements; + // All of the elements of the tuple pseduo-type. + List<Element> elements; }; LegalType LegalType::tuple( - RefPtr<TupleType> tupleType) + RefPtr<TuplePseudoType> tupleType) { LegalType result; result.flavor = Flavor::tuple; @@ -114,10 +154,108 @@ LegalType LegalType::tuple( return result; } +struct PairInfo : RefObject +{ + typedef unsigned int Flags; + enum + { + kFlag_hasOrdinary = 0x1, + kFlag_hasSpecial = 0x2, + }; + + struct Element + { + // The field the element represents + DeclRef<Decl> fieldDeclRef; + + // The conceptual type of the field. + // If both the `hasOrdinary` and + // `hasSpecial` bits are set, then + // this is expected to be a + // `LegalType::Flavor::pair` + LegalType type; + + // Is the value represented on + // the ordinary side, the special + // side, or both? + Flags flags; + }; + + // For a pair type or value, we need to track + // which fields are on which side(s). + List<Element> elements; + + Element* findElement(DeclRef<Decl> const& fieldDeclRef) + { + for (auto& ee : elements) + { + if(ee.fieldDeclRef.Equals(fieldDeclRef)) + return ⅇ + } + return nullptr; + } +}; + +struct PairPseudoType : LegalTypeImpl +{ + // Any field(s) with ordinary types will + // get captured here, as a completely + // standard AST-level type. + RefPtr<Type> ordinaryType; + + // Any fields with "special" (not ordinary) + // types will get captured here (usually + // with a tuple). + LegalType specialType; + + RefPtr<PairInfo> pairInfo; +}; + +LegalType LegalType::pair( + RefPtr<PairPseudoType> pairType) +{ + LegalType result; + result.flavor = Flavor::pair; + result.obj = pairType; + return result; +} + +LegalType LegalType::pair( + RefPtr<Type> ordinaryType, + LegalType const& specialType, + RefPtr<PairInfo> pairInfo) +{ + // Handle some special cases for when + // one or the other of the types isn't + // actually used. + + if (!ordinaryType) + { + // There was nothing ordinary. + return specialType; + } + + if (specialType.flavor == LegalType::Flavor::none) + { + return LegalType::simple(ordinaryType); + } + + // There were both ordinary and special fields, + // and so we need to handle them here. + + RefPtr<PairPseudoType> obj = new PairPseudoType(); + obj->ordinaryType = ordinaryType; + obj->specialType = specialType; + obj->pairInfo = pairInfo; + return LegalType::pair(obj); +} + + struct LegalValImpl : RefObject { }; -struct TupleVal; +struct TuplePseudoVal; +struct PairPseudoVal; struct LegalVal { @@ -127,11 +265,12 @@ struct LegalVal simple, implicitDeref, tuple, + pair, }; - Flavor flavor; + Flavor flavor = Flavor::none; RefPtr<RefObject> obj; - IRValue* irValue; + IRValue* irValue = nullptr; static LegalVal simple(IRValue* irValue) { @@ -147,30 +286,42 @@ struct LegalVal return irValue; } - static LegalVal tuple(RefPtr<TupleVal> tupleVal); + static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal); - RefPtr<TupleVal> getTuple() + RefPtr<TuplePseudoVal> getTuple() { assert(flavor == Flavor::tuple); - return obj.As<TupleVal>(); + return obj.As<TuplePseudoVal>(); } static LegalVal implicitDeref(LegalVal const& val); LegalVal getImplicitDeref(); + + static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); + static LegalVal pair( + LegalVal const& ordinaryVal, + LegalVal const& specialVal, + RefPtr<PairInfo> pairInfo); + + RefPtr<PairPseudoVal> getPair() + { + assert(flavor == Flavor::pair); + return obj.As<PairPseudoVal>(); + } }; -struct TupleVal : LegalValImpl +struct TuplePseudoVal : LegalValImpl { struct Element { - DeclRef<VarDeclBase> fieldDeclRef; - LegalVal val; + DeclRef<VarDeclBase> fieldDeclRef; + LegalVal val; }; - List<Element> elements; + List<Element> elements; }; -LegalVal LegalVal::tuple(RefPtr<TupleVal> tupleVal) +LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal) { LegalVal result; result.flavor = LegalVal::Flavor::tuple; @@ -178,6 +329,44 @@ LegalVal LegalVal::tuple(RefPtr<TupleVal> tupleVal) return result; } +struct PairPseudoVal : LegalValImpl +{ + LegalVal ordinaryVal; + LegalVal specialVal; + + // The info to tell us which fields + // are on which side(s) + RefPtr<PairInfo> pairInfo; +}; + +LegalVal LegalVal::pair(RefPtr<PairPseudoVal> pairInfo) +{ + LegalVal result; + result.flavor = LegalVal::Flavor::pair; + result.obj = pairInfo; + return result; +} + +LegalVal LegalVal::pair( + LegalVal const& ordinaryVal, + LegalVal const& specialVal, + RefPtr<PairInfo> pairInfo) +{ + if (ordinaryVal.flavor == LegalVal::Flavor::none) + return specialVal; + + if (specialVal.flavor == LegalVal::Flavor::none) + return ordinaryVal; + + + RefPtr<PairPseudoVal> obj = new PairPseudoVal(); + obj->ordinaryVal = ordinaryVal; + obj->specialVal = specialVal; + obj->pairInfo = pairInfo; + + return LegalVal::pair(obj); +} + struct ImplicitDerefVal : LegalValImpl { LegalVal val; @@ -251,6 +440,442 @@ static bool isResourceType(Type* type) return false; } +static LegalType legalizeType( + TypeLegalizationContext* context, + Type* type); + +// Helper type for legalization of aggregate types +// that might need to be turned into tuple pseudo-types. +struct TupleTypeBuilder +{ + TypeLegalizationContext* context; + RefPtr<Type> type; + + List<FilteredTupleType::Element> ordinaryElements; + List<TuplePseudoType::Element> specialElements; + + List<PairInfo::Element> pairElements; + + // Did we have any fields that forced us to change + // the actual type away from the declared type? + bool anyComplex = false; + + // Did we have any fields that actually required + // storage in the "special" part of things? + bool anySpecial = false; + + // Did we have any fields that actually used ordinary storage? + bool anyOrdinary = false; + + // Add a field to the (pseudo-)type we are building + void addField( + DeclRef<VarDeclBase> fieldDeclRef, + LegalType legalFieldType, + LegalType legalLeafType, + bool isResource) + { + RefPtr<Type> ordinaryType; + LegalType specialType; + RefPtr<PairInfo> elementPairInfo; + switch (legalLeafType.flavor) + { + case LegalType::Flavor::simple: + { + // We need to add an actual field, but we need + // to check if it is a resource type to know + // whether it should go in the "ordinary" list or not. + if (!isResource) + { + ordinaryType = legalLeafType.getSimple(); + } + else + { + specialType = legalFieldType; + } + } + break; + + case LegalType::Flavor::implicitDeref: + { + // TODO: we may want to say that any use + // of `implicitDeref` puts the entire thing + // into the "special" category, rather than + // try to look under the hood... + + anyComplex = true; + + // We want to recursively add data + // based on the unwrapped type. + // + // Note: this assumes we can't have a tuple + // or a pair "under" an `implicitDeref`, so + // we'll need to ensure that elsewhere. + addField( + fieldDeclRef, + legalFieldType, + legalLeafType.getImplicitDeref()->valueType, + isResource); + return; + } + break; + + case LegalType::Flavor::pair: + { + // The field's type had both special and non-special parts + auto pairType = legalLeafType.getPair(); + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + elementPairInfo = pairType->pairInfo; + } + break; + + case LegalType::Flavor::tuple: + { + // A tuple always represents "special" data + specialType = legalFieldType; + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + break; + } + + + PairInfo::Element pairElement; + pairElement.flags = 0; + pairElement.fieldDeclRef = fieldDeclRef; + + if (ordinaryType) + { + anyOrdinary = true; + pairElement.flags |= PairInfo::kFlag_hasOrdinary; + + FilteredTupleType::Element ordinaryElement; + ordinaryElement.fieldDeclRef = fieldDeclRef; + ordinaryElement.type = ordinaryType; + ordinaryElements.Add(ordinaryElement); + } + + if (specialType.flavor != LegalType::Flavor::none) + { + anySpecial = true; + anyComplex = true; + pairElement.flags |= PairInfo::kFlag_hasSpecial; + + TuplePseudoType::Element specialElement; + specialElement.fieldDeclRef = fieldDeclRef; + specialElement.type = specialType; + specialElements.Add(specialElement); + } + + pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); + pairElements.Add(pairElement); + } + + // Add a field to the (pseudo-)type we are building + void addField( + DeclRef<VarDeclBase> fieldDeclRef) + { + // Skip `static` fields. + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + return; + + auto fieldType = GetType(fieldDeclRef); + + bool isResourceField = isResourceType(fieldType); + + auto legalFieldType = legalizeType(context, fieldType); + addField( + fieldDeclRef, + legalFieldType, + legalFieldType, + isResourceField); + } + + LegalType getResult() + { + // If we didn't see anything "special" + // then we can use the type as-is. + // we can conceivably just use the type as-is + // + // TODO: this might be a good place to turn + // a reference to a generic `struct` type into + // a concrete non-generic type so that downstream + // codegen doesn't have to deal with generics... + // + // TODO: In fact, why not just fully replace + // all aggregate types here with some structural + // types defined in the IR? + if (!anyComplex) + { + return LegalType::simple(type); + } + + // If there were any "ordinary" fields along the way, + // then we need to collect them into a type to + // represent the ordinary part of things. + // + RefPtr<Type> ordinaryType; + if (anyOrdinary) + { + RefPtr<FilteredTupleType> ordinaryTypeImpl = new FilteredTupleType(); + ordinaryTypeImpl->setSession(context->session); + ordinaryTypeImpl->originalType = type; + ordinaryTypeImpl->elements = ordinaryElements; + ordinaryType = ordinaryTypeImpl; + } + + LegalType specialType; + if (anySpecial) + { + RefPtr<TuplePseudoType> specialTuple = new TuplePseudoType(); + specialTuple->elements = specialElements; + specialType = LegalType::tuple(specialTuple); + } + + RefPtr<PairInfo> pairInfo; + if (anyOrdinary && anySpecial) + { + pairInfo = new PairInfo(); + pairInfo->elements = pairElements; + } + + return LegalType::pair(ordinaryType, specialType, pairInfo); + } + +}; + +static RefPtr<Type> createBuiltinGenericType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + RefPtr<Type> elementType) +{ + // We are going to take the type for the original + // decl-ref and construct a new one that uses + // our new element type as its parameter. + // + // TODO: we should have library code to make + // manipulations like this way easier. + + RefPtr<GenericSubstitution> oldGenericSubst = getGenericSubstitution( + typeDeclRef.substitutions); + SLANG_ASSERT(oldGenericSubst); + + RefPtr<GenericSubstitution> newGenericSubst = new GenericSubstitution(); + + newGenericSubst->outer = oldGenericSubst->outer; + newGenericSubst->genericDecl = oldGenericSubst->genericDecl; + newGenericSubst->args = oldGenericSubst->args; + newGenericSubst->args[0] = elementType; + + auto newDeclRef = DeclRef<Decl>( + typeDeclRef.getDecl(), + newGenericSubst); + + auto newType = DeclRefType::Create( + context->session, + newDeclRef); + + return newType; +} + +// Create a uniform buffer type with a given legalized +// element type. +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + LegalType legalElementType) +{ + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + typeDeclRef, + legalElementType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // This is actually an annoying case, because + // we are being asked to convert, e.g.,: + // + // cbuffer Foo { ParameterBlock<Bar> bar; } + // + // into the equivalent of: + // + // cbuffer Foo { Bar bar; } + // + // Which would really require a new `LegalType` that + // would reprerent a resource type with a modified + // element type. + // + // I'm going to attempt to hack this for now. + return LegalType::implicitDeref(createLegalUniformBufferType( + context, + typeDeclRef, + legalElementType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We assume that the "ordinary" part of things + // will get wrapped in a constant-buffer type, + // and the "special" part needs to be wrapped + // with an `implicitDeref`. + auto pairType = legalElementType.getPair(); + + auto ordinaryType = createBuiltinGenericType( + context, + typeDeclRef, + pairType->ordinaryType); + auto specialType = LegalType::implicitDeref(pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // if we have a tuple type, then it must be representing + // the fields that can't be stored in a buffer anyway, + // so we just need to wrap each of them in an `implicitDeref` + + auto elementPseudoTupleType = legalElementType.getTuple(); + + RefPtr<TuplePseudoType> bufferPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : elementPseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.type = LegalType::implicitDeref(ee.type); + + bufferPseudoTupleType->elements.Add(newElement); + } + + return LegalType::tuple(bufferPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + UniformParameterGroupType* uniformBufferType, + LegalType legalElementType) +{ + return createLegalUniformBufferType( + context, + uniformBufferType->declRef, + legalElementType); +} + +// Create a pointer type with a given legalized value type. +static LegalType createLegalPtrType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + LegalType legalValueType) +{ + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + typeDeclRef, + legalValueType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // We are being asked to create a pointer type to something + // that is implicitly dereferenced, meaning we had: + // + // Ptr(PtrLink(T)) + // + // and now are being asked to make: + // + // Ptr(implicitDeref(LegalT)) + // + // So it seems like we can just create: + // + // implicitDeref(Ptr(LegalT)) + // + // and nobody should really be able to tell the difference, right? + return LegalType::implicitDeref(createLegalPtrType( + context, + typeDeclRef, + legalValueType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalValueType.getPair(); + + auto ordinaryType = createBuiltinGenericType( + context, + typeDeclRef, + pairType->ordinaryType); + auto specialType = createLegalPtrType( + context, + typeDeclRef, + pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto valuePseudoTupleType = legalValueType.getTuple(); + + RefPtr<TuplePseudoType> ptrPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : valuePseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.type = createLegalPtrType( + context, + typeDeclRef, + ee.type); + + ptrPseudoTupleType->elements.Add(newElement); + } + + return LegalType::tuple(ptrPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + // Legalize a type, including any nested types // that it transitively contains. static LegalType legalizeType( @@ -268,6 +893,30 @@ static LegalType legalizeType( parameterBlockType->getElementType()); return LegalType::implicitDeref(legalElementType); } + else if (auto uniformBufferType = type->As<UniformParameterGroupType>()) + { + // We have a `ConstantBuffer<T>` or `TextureBuffer<T>` or + // other pointer-like type that represents uniform parameters. + // We need to pull any resource-type fields out of it, but + // leave the non-resource fields where they are. + + // Legalize the element type to see what we are working with. + auto legalElementType = legalizeType(context, + uniformBufferType->getElementType()); + + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + return LegalType::simple(type); + + default: + return createLegalUniformBufferType( + context, + uniformBufferType, + legalElementType); + } + + } else if (isResourceType(type)) { // We assume that any resource types not handled above @@ -286,6 +935,11 @@ static LegalType legalizeType( { return LegalType::simple(type); } + else if (auto ptrType = type->As<PtrTypeBase>()) + { + auto legalValueType = legalizeType(context, ptrType->getValueType()); + return createLegalPtrType(context, ptrType->declRef, legalValueType); + } else if (auto declRefType = type->As<DeclRefType>()) { auto declRef = declRefType->declRef; @@ -293,69 +947,72 @@ static LegalType legalizeType( { // Look at the (non-static) fields, and // see if anything needs to be cleaned up. + // The things that need to be "cleaned up" for + // our purposes are: + // + // - Fields of resource type, or any other future + // type we run into that isn't allowed in + // aggregates for at least some targets + // + // - Fields with types that themselves had to + // get legalized. + // + // If we don't run into any of these, we + // can just use the type as-is. Hooray! + // + // Otherwise, we are effectively going to split + // the type apart and create a `TuplePseudoType`. + // Every field of the original type will be + // represented as an element of this pseudo-type. + // Each element will record its `LegalType`, + // and the original field that it was created from. + // An element will also track whether it contains + // any "ordinary" data, and if so, it will remember + // an element index in a real (AST-level, non-pseudo) + // `TupleType` that is used to bundle together + // such fields. + // + // Storing all the simple fields together like this + // obviously adds complexity to the legalization + // pass, but it has important benefits: + // + // - It avoids creating functions with a very large + // number of parameters (when passing a structure + // with many fields), which might confuse downstream + // compilers. + // + // - It avoids applying AOS->SOA conversion to fields + // that don't actually need it, which is basically + // required if we want type layout to work. + // + // - It ensures that we can actually construct a + // constant-buffer type that wraps a legalized + // aggregate type; the ordinary fields will get + // placed inside a new constant-buffer type, + // while the special ones will get left outside. + // - // We collect the legalized types for the fields, - // along with whether we've seen anything non-simple. - List<TupleType::Element> legalizedElements; - bool anyComplex = false; - bool anyResource = false; - - for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) - { - if (ff.getDecl()->HasModifier<HLSLStaticModifier>()) - continue; - - auto fieldType = GetType(ff); - if (isResourceType(fieldType)) - { - anyResource = true; - } - - auto legalFieldType = legalizeType(context, fieldType); - - TupleType::Element element; - element.fieldDeclRef = ff; - element.type = legalFieldType; - legalizedElements.Add(element); - - switch (legalFieldType.flavor) - { - case LegalType::Flavor::simple: - break; + TupleTypeBuilder builder; + builder.context = context; + builder.type = type; - default: - anyComplex = true; - break; - } - } - // If we didn't see anything that requires work, - // we can conceivably just use the type as-is - // - // TODO: this might be a good place to turn - // a reference to a generic `struct` type into - // a concrete non-generic type so that downstream - // codegen doesn't have to deal with generics... - // - // TODO: In fact, why not just fully replace - // all aggregate types here with some structural - // types defined in the IR? - if (!anyComplex && !anyResource) + for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) { - return LegalType::simple(type); + builder.addField(ff); } - // Okay, we are going to have to generate a - // "tuple" type. - // - // TODO: split out the "simple" fields into - // their own sub-type? - - RefPtr<TupleType> tupleType = new TupleType(); - tupleType->elements = legalizedElements; - - return LegalType::tuple(tupleType); + return builder.getResult(); } + + // TODO: for other declaration-reference types, we really + // need to legalize the types used in substitutions, and + // signal an error if any of them turn out to be non-simple. + // + // The limited cases of types that can handle having non-simple + // types as generic arguments all need to be special-cased here. + // (For example, we can't handle `Texture2D<SomeStructWithTexturesInIt>`. + // } return LegalType::simple(type); @@ -418,18 +1075,38 @@ static void getArgumentValues( { switch (val.flavor) { + case LegalVal::Flavor::none: + break; + case LegalVal::Flavor::simple: instArgs.Add(val.getSimple()); break; + case LegalVal::Flavor::implicitDeref: getArgumentValues(instArgs, val.getImplicitDeref()); break; + + case LegalVal::Flavor::pair: + { + auto pairVal = val.getPair(); + getArgumentValues(instArgs, pairVal->ordinaryVal); + getArgumentValues(instArgs, pairVal->specialVal); + } + break; + case LegalVal::Flavor::tuple: { + auto tuplePsuedoVal = val.getTuple(); for (auto elem : val.getTuple()->elements) + { getArgumentValues(instArgs, elem.val); + } } break; + + default: + SLANG_UNEXPECTED("uhandled val flavor"); + break; } } @@ -445,7 +1122,11 @@ static LegalVal legalizeCall( for (auto i = 1u; i < callInst->argCount; i++) getArgumentValues(instArgs, legalizeOperand(context, callInst->getArg(i))); - return LegalVal::simple(context->builder->emitCallInst(callInst->type, callInst->func.usedValue, instArgs.Count(), instArgs.Buffer())); + return LegalVal::simple(context->builder->emitCallInst( + callInst->type, + callInst->func.usedValue, + instArgs.Count(), + instArgs.Buffer())); } static LegalVal legalizeLoad( @@ -454,6 +1135,9 @@ static LegalVal legalizeLoad( { switch (legalPtrVal.flavor) { + case LegalVal::Flavor::none: + return LegalVal(); + case LegalVal::Flavor::simple: { return LegalVal::simple( @@ -467,18 +1151,28 @@ static LegalVal legalizeLoad( // the underlying value. return legalPtrVal.getImplicitDeref(); + case LegalVal::Flavor::pair: + { + auto ptrPairVal = legalPtrVal.getPair(); + + auto ordinaryVal = legalizeLoad(context, ptrPairVal->ordinaryVal); + auto specialVal = legalizeLoad(context, ptrPairVal->specialVal); + return LegalVal::pair(ordinaryVal, specialVal, ptrPairVal->pairInfo); + } + case LegalVal::Flavor::tuple: { // We need to emit a load for each element of // the tuple. - RefPtr<TupleVal> tupleVal = new TupleVal(); + auto ptrTupleVal = legalPtrVal.getTuple(); + RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal(); + for (auto ee : legalPtrVal.getTuple()->elements) { - TupleVal::Element element; + TuplePseudoVal::Element element; element.fieldDeclRef = ee.fieldDeclRef; element.val = legalizeLoad(context, ee.val); - tupleVal->elements.Add(element); } return LegalVal::tuple(tupleVal); @@ -498,6 +1192,9 @@ static LegalVal legalizeStore( { switch (legalPtrVal.flavor) { + case LegalVal::Flavor::none: + return LegalVal(); + case LegalVal::Flavor::simple: { context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple()); @@ -512,20 +1209,30 @@ static LegalVal legalizeStore( else return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal); + case LegalVal::Flavor::pair: + { + auto destPair = legalPtrVal.getPair(); + auto valPair = legalVal.getPair(); + legalizeStore(context, destPair->ordinaryVal, valPair->ordinaryVal); + legalizeStore(context, destPair->specialVal, valPair->specialVal); + return LegalVal(); + } + case LegalVal::Flavor::tuple: - { - // We need to emit a store for each element of - // the tuple. - auto destTuple = legalPtrVal.getTuple(); - auto valTuple = legalVal.getTuple(); - SLANG_ASSERT(destTuple->elements.Count() == valTuple->elements.Count()); - for (UInt i = 0; i < valTuple->elements.Count(); i++) { - legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); + // We need to emit a store for each element of + // the tuple. + auto destTuple = legalPtrVal.getTuple(); + auto valTuple = legalVal.getTuple(); + SLANG_ASSERT(destTuple->elements.Count() == valTuple->elements.Count()); + + for (UInt i = 0; i < valTuple->elements.Count(); i++) + { + legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); + } + return legalVal; } - return legalVal; - } - break; + break; default: SLANG_UNEXPECTED("unhandled case"); @@ -556,6 +1263,49 @@ static LegalVal legalizeFieldAddress( legalPtrOperand.getSimple(), fieldOperand)); + case LegalVal::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairVal = legalPtrOperand.getPair(); + auto pairInfo = pairVal->pairInfo; + auto pairElement = pairInfo->findElement(fieldDeclRef); + if (!pairElement) + { + SLANG_UNEXPECTED("didn't find tuple element"); + UNREACHABLE_RETURN(LegalVal()); + } + + // If the field we are extracting has a pair type, + // that means it exists on both the ordinary and + // special sides. + RefPtr<PairInfo> fieldPairInfo; + LegalType ordinaryType = type; + LegalType specialType = type; + if (type.flavor == LegalType::Flavor::pair) + { + auto fieldPairType = type.getPair(); + fieldPairInfo = fieldPairType->pairInfo; + ordinaryType = LegalType::simple(fieldPairType->ordinaryType); + specialType = fieldPairType->specialType; + } + + LegalVal ordinaryVal; + LegalVal specialVal; + + if (pairElement->flags & PairInfo::kFlag_hasOrdinary) + { + ordinaryVal = legalizeFieldAddress(context, ordinaryType, pairVal->ordinaryVal, legalFieldOperand); + } + + if (pairElement->flags & PairInfo::kFlag_hasSpecial) + { + specialVal = legalizeFieldAddress(context, specialType, pairVal->specialVal, legalFieldOperand); + } + return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); + } + break; + case LegalVal::Flavor::tuple: { // The operand is a tuple of pointer-like @@ -563,13 +1313,18 @@ static LegalVal legalizeFieldAddress( // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. - for (auto ee : legalPtrOperand.getTuple()->elements) + auto ptrTupleInfo = legalPtrOperand.getTuple(); + for (auto ee : ptrTupleInfo->elements) { if (ee.fieldDeclRef.Equals(fieldDeclRef)) { return ee.val; } } + + // TODO: we can legally reach this case now + // when the field is "ordinary". + SLANG_UNEXPECTED("didn't find tuple element"); UNREACHABLE_RETURN(LegalVal()); } @@ -743,6 +1498,8 @@ static void addParamType(IRFuncType * ftype, LegalType t) { switch (t.flavor) { + case LegalType::Flavor::none: + break; case LegalType::Flavor::simple: ftype->paramTypes.Add(t.obj.As<Type>()); break; @@ -752,9 +1509,16 @@ static void addParamType(IRFuncType * ftype, LegalType t) addParamType(ftype, imp->valueType); break; } + case LegalType::Flavor::pair: + { + auto pairInfo = t.getPair(); + addParamType(ftype, LegalType::simple(pairInfo->ordinaryType)); + addParamType(ftype, pairInfo->specialType); + } + break; case LegalType::Flavor::tuple: { - auto tup = t.obj.As<TupleType>(); + auto tup = t.obj.As<TuplePseudoType>(); for (auto & elem : tup->elements) addParamType(ftype, elem.type); } @@ -878,7 +1642,7 @@ static LegalVal declareSimpleVar( // those to all the nested resource infos. for (auto vv = varChain; vv; vv = vv->next) { - auto parentSpaceInfo = vv->varLayout->findOrAddResourceInfo(LayoutResourceKind::RegisterSpace); + auto parentSpaceInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace); if (!parentSpaceInfo) continue; @@ -896,59 +1660,75 @@ static LegalVal declareSimpleVar( } } + DeclRef<VarDeclBase> varDeclRef; + if (varChain) + { + varDeclRef = varChain->varLayout->varDecl; + } + + IRBuilder* builder = context->builder; + + IRValue* irVar = nullptr; + LegalVal legalVarVal; + switch (op) { case kIROp_global_var: { - IRBuilder* builder = context->builder; - auto globalVar = builder->createGlobalVar(type); globalVar->removeFromParent(); globalVar->insertBefore(context->insertBeforeGlobal); - if (varLayout) - { - builder->addLayoutDecoration(globalVar, varLayout); - } - - return LegalVal::simple(globalVar); + irVar = globalVar; + legalVarVal = LegalVal::simple(irVar); } break; - case kIROp_Var: - { - IRBuilder* builder = context->builder; - auto localVar = builder->emitVar(type); - localVar->removeFromParent(); - localVar->insertBefore(context->insertBeforeLocalVar); - if (varLayout) + case kIROp_Var: { - builder->addLayoutDecoration(localVar, varLayout); + auto localVar = builder->emitVar(type); + localVar->removeFromParent(); + localVar->insertBefore(context->insertBeforeLocalVar); + + irVar = localVar; + legalVarVal = LegalVal::simple(irVar); + } - return LegalVal::simple(localVar); - } - break; + break; + case kIROp_Param: { - IRBuilder* builder = context->builder; auto param = builder->emitParam(type); if (context->insertBeforeParam->prevParam) context->insertBeforeParam->prevParam->nextParam = param; param->prevParam = context->insertBeforeParam->prevParam; param->nextParam = context->insertBeforeParam; context->insertBeforeParam->prevParam = param; - if (varLayout) - { - builder->addLayoutDecoration(param, varLayout); - } - return LegalVal::simple(param); + irVar = param; + legalVarVal = LegalVal::simple(irVar); } break; + default: SLANG_UNEXPECTED("unexpected IR opcode"); break; } + + if (irVar) + { + if (varLayout) + { + builder->addLayoutDecoration(irVar, varLayout); + } + + if (varDeclRef) + { + builder->addHighLevelDeclDecoration(irVar, varDeclRef.getDecl()); + } + } + + return legalVarVal; } static RefPtr<TypeLayout> getDerefTypeLayout( @@ -991,6 +1771,9 @@ static LegalVal declareVars( { switch (type.flavor) { + case LegalType::Flavor::none: + return LegalVal(); + case LegalType::Flavor::simple: return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain); break; @@ -1010,12 +1793,20 @@ static LegalVal declareVars( } break; + case LegalType::Flavor::pair: + { + auto pairType = type.getPair(); + auto ordinaryVal = declareVars(context, op, LegalType::simple(pairType->ordinaryType), typeLayout, varChain); + auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain); + return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); + } + case LegalType::Flavor::tuple: { // Declare one variable for each element of the tuple auto tupleType = type.getTuple(); - RefPtr<TupleVal> tupleVal = new TupleVal(); + RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal(); for (auto ee : tupleType->elements) { @@ -1035,14 +1826,16 @@ static LegalVal declareVars( newVarChain = &newVarChainStorage; } - TupleVal::Element element; - element.fieldDeclRef = ee.fieldDeclRef; - element.val = declareVars( + LegalVal fieldVal = declareVars( context, op, ee.type, fieldTypeLayout, newVarChain); + + TuplePseudoVal::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.val = fieldVal; tupleVal->elements.Add(element); } @@ -1166,424 +1959,4 @@ void legalizeTypes( } -#if 0 - typedef unsigned int TypeScalarizationFlags; - enum TypeScalarizationFlag - { - anyResource = 0x1, - anyNonResource = 0x2, - anyAggregate = 0x4, - }; - - bool isResourceType(Type* type) - { - while (auto arrayType = type->As<ArrayExpressionType>()) - { - type = arrayType->baseType; - } - - if (auto textureTypeBase = type->As<TextureTypeBase>()) - { - return true; - } - else if (auto samplerType = type->As<SamplerStateType>()) - { - return true; - } - - // TODO: need more comprehensive coverage here - - return false; - } - - TypeScalarizationFlags getTypeScalarizationFlags( - Session* session, - Type* type) - { - // TODO: we should probably cache flags once - // they are computed, to avoid O(N^2) sorts - // of behavior. - - if (isResourceType(type)) - return TypeScalarizationFlag::anyNonResource; - - if(type->As<BasicExpressionType>()) - { - return TypeScalarizationFlag::anyNonResource; - } - if(type->As<VectorExpressionType>()) - { - return TypeScalarizationFlag::anyNonResource; - } - if(type->As<MatrixExpressionType>()) - { - return TypeScalarizationFlag::anyNonResource; - } - else if (auto declRefType = type->As<DeclRefType>()) - { - auto declRef = declRefType->declRef; - if (auto structDeclRef = declRef.As<StructDecl>()) - { - TypeScalarizationFlags flags = TypeScalarizationFlag::anyAggregate; - - // For structure types, the basic rule will be - // that if the type contains *any* resource-type - // fields, then it needs to be scalarized. - // If it contains any non-resource-type fields, - // then we should aggregate these into a single - // new `struct` type with just the non-resource - // fields. - for (auto fieldDeclRef : getMembersOfType<StructField>(structDeclRef)) - { - auto fieldType = GetType(fieldDeclRef); - - // TODO: we are making a recursive call here, so - // this will break if/when we ever allowed a recursive type! - auto fieldFlags = getTypeScalarizationFlags(session, fieldType); - flags |= fieldFlags; - - } - - return flags; - } - } - else if (auto arrayType = type->As<ArrayExpressionType>()) - { - return getTypeScalarizationFlags( - session, - arrayType->baseType); - } - - // Default behavior: assume we have a non-resource type - return TypeScalarizationFlag::anyNonResource; - } - - struct ArrayScalarizationInfo - { - ArrayScalarizationInfo* next; - RefPtr<IntVal> elementCount; - RefPtr<ArrayTypeLayout> typeLayout; - }; - - struct SharedScalarizationContext - { - - }; - - struct ScalarizationContext - { - SharedScalarizationContext* shared; - - IRBuilder* builder; - IRGlobalVar* globalVar; - VarLayout* globalVarLayout; - - IRGlobalValue* valueToInsertAfter; - }; - - IRValue* emitSimpleScalarizedField( - ScalarizationContext* context, - Type* inType, - VarLayout* fieldLayout, - TypeLayout* inTypeLayout, - ArrayScalarizationInfo* arrayInfo) - { - auto builder = context->builder; - auto globalVar = context->globalVar; - auto globalVarLayout = context->globalVarLayout; - auto valueToInsertAfter = context->valueToInsertAfter; - - RefPtr<Type> type = inType; - RefPtr<TypeLayout> typeLayout = inTypeLayout; - - // If we are turning an array-of-structs into - // a struct-of-arrays, then we need to apply - // all the appropriate array dimensions here. - for (auto aa = arrayInfo; aa; aa = aa->next) - { - type = builder->getSession()->getArrayType(type, aa->elementCount); - - if (typeLayout) - { - RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->elementTypeLayout = typeLayout; - - // TODO: fill in the other fields! - - typeLayout = arrayTypeLayout; - } - } - - RefPtr<VarLayout> newVarLayout; - if (typeLayout) - { - newVarLayout = new VarLayout(); - newVarLayout->typeLayout = typeLayout; - - if (fieldLayout) - { - for (auto fieldResourceInfo : fieldLayout->resourceInfos) - { - auto newResourceInfo = newVarLayout->findOrAddResourceInfo(fieldResourceInfo.kind); - - if (globalVarLayout) - { - if (auto globalResourceInfo = globalVarLayout->FindResourceInfo(fieldResourceInfo.kind)) - { - newResourceInfo->index += globalResourceInfo->index; - newResourceInfo->space += globalResourceInfo->space; - } - } - - newResourceInfo->index += fieldResourceInfo.index; - newResourceInfo->space += fieldResourceInfo.space; - } - } - } - - auto newGlobalVar = addGlobalVariable(builder->getModule(), type); - builder->addLayoutDecoration(newGlobalVar, newVarLayout); - - newGlobalVar->removeFromParent(); - newGlobalVar->insertAfter(valueToInsertAfter); - - context->valueToInsertAfter = newGlobalVar; - - return newGlobalVar; - } - - void scalarizeGlobalVariable( - ScalarizationContext* context, - Type* valueType, - TypeLayout* valueTypeLayout, - ArrayScalarizationInfo* arrayInfo) - { - if (auto arrayType = valueType->As<ArrayExpressionType>()) - { - // Okay, we need to recurse down and scalarize the - // array element type, wrapping up each field in - // an array declarator as needed. - - ArrayScalarizationInfo newArrayInfo; - newArrayInfo.next = arrayInfo; - newArrayInfo.elementCount = arrayType->ArrayLength; - - RefPtr<TypeLayout> elementTypeLayout; - if (auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(valueTypeLayout)) - { - newArrayInfo.typeLayout = arrayTypeLayout; - elementTypeLayout = arrayTypeLayout->elementTypeLayout; - } - - scalarizeGlobalVariable( - context, - arrayType->baseType, - elementTypeLayout, - &newArrayInfo); - - // Now we need to look at all uses of the variable, - // and properly rework element-index operations - // to instead index into the sub-arrays... - } - else if (auto declRefType = valueType->As<DeclRefType>()) - { - auto declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) - { - RefPtr<StructTypeLayout> structTypeLayout = dynamic_cast<StructTypeLayout*>(valueTypeLayout); - - // Okay, we need to look through the fields, and - // create a new variable for each of them. - Dictionary<Decl*, IRValue*> fieldMap; - UInt fieldCounter = 0; - for (auto fieldDeclRef : getMembersOfType<StructField>(aggTypeDeclRef)) - { - UInt fieldIndex = fieldCounter++; - - RefPtr<VarLayout> fieldLayout; - RefPtr<TypeLayout> fieldTypeLayout; - if (structTypeLayout) - { - fieldLayout = structTypeLayout->fields[fieldIndex]; - fieldTypeLayout = fieldLayout->typeLayout; - } - - // Note: we do *not* try to deal with recursive - // expansion of the fields here, and instead - // prefer to handle those in further - // simplification passes. - - auto fieldGlobalVar = emitSimpleScalarizedField( - context, - GetType(fieldDeclRef), - fieldLayout, - fieldTypeLayout, - arrayInfo); - - fieldMap.Add(fieldDeclRef.getDecl(), fieldGlobalVar); - } - - // Now we need to scan for uses of the original variable, - // and replace them with uses of the individual fields. - auto globalVar = context->globalVar; - IRUse* nextUse = nullptr; - for (IRUse* use = globalVar->firstUse; use; use = nextUse) - { - nextUse = use->nextUse; - - IRUser* user = use->user; - switch (user->op) - { - case kIROp_FieldAddress: - { - // This should be the easy case: we are taking - // the address of a field inside this global - // value, so we can just return the adress - // of the global value that replaced that field. - IRFieldAddress* fieldAddressInst = (IRFieldAddress*)user; - - IRValue* fieldOperand = fieldAddressInst->getField(); - assert(fieldOperand->op == kIROp_decl_ref); - auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; - auto fieldDecl = fieldDeclRef.getDecl(); - - IRValue* fieldVar = *fieldMap.TryGetValue(fieldDecl); - - fieldAddressInst->replaceUsesWith(fieldVar); - } - break; - - default: - SLANG_UNEXPECTED("what to do?"); - break; - } - } - } - else - { - SLANG_UNEXPECTED("not handled"); - } - } - else - { - SLANG_UNEXPECTED("not handled"); - } - } - - void scalarizeGlobalVariable( - SharedScalarizationContext* sharedContext, - IRBuilder* builder, - IRGlobalVar* globalVar, - VarLayout* globalVarLayout, - Type* valueType, - TypeLayout* valueTypeLayout) - { - ScalarizationContext contextStorage; - auto context = &contextStorage; - - context->shared = sharedContext; - context->builder = builder; - context->globalVar = globalVar; - context->globalVarLayout = globalVarLayout; - context->valueToInsertAfter = globalVar; - - scalarizeGlobalVariable( - context, - valueType, - valueTypeLayout, - nullptr); - } - - RefPtr<VarLayout> findVarLayout(IRValue* value) - { - if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>()) - return layoutDecoration->layout.As<VarLayout>(); - return nullptr; - } - - void scalarizeMixedResourceTypes( - Session* session, - IRModule* module) - { - SharedIRBuilder sharedBuilderStorage; - auto sharedBuilder = &sharedBuilderStorage; - - sharedBuilder->session = session; - sharedBuilder->module = module; - - IRBuilder builderStorage; - auto builder = &builderStorage; - - builder->shared = sharedBuilder; - - SharedScalarizationContext sharedContextStorage; - auto sharedContext = &sharedContextStorage; - - - List<IRValue*> workList; - for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) - { - workList.Add(gv); - } - - while (workList.Count()) - { - IRValue* value = workList[0]; - workList.FastRemoveAt(0); - - switch (value->op) - { - case kIROp_Func: - { - // TODO: need to iterate over parameters of - // the function (and its blocks) to make - // sure that any types that need scalarization - // are properly handled. - } - break; - - case kIROp_global_var: - { - IRGlobalVar* globalVar = (IRGlobalVar*)value; - auto valueType = globalVar->getType()->getValueType(); - - auto flags = getTypeScalarizationFlags(session, valueType); - if (!(flags & (TypeScalarizationFlag::anyNonResource | TypeScalarizationFlag::anyAggregate))) - continue; - - auto varLayout = findVarLayout(globalVar); - RefPtr<TypeLayout> typeLayout = varLayout ? varLayout->typeLayout : nullptr; - - // Okay, we have a variable of some composite type - // that we need to scalarize. Since this is a global, - // we also need to be careful to deal with any - // layout information that has been attached. - - scalarizeGlobalVariable( - sharedContext, - builder, - globalVar, - varLayout, - valueType, - typeLayout); - - globalVar->removeFromParent(); - // TODO: need to destroy this global! - } - break; - - default: - { - // TODO: look at the type of the value, - // and if it needs scalarization, replace - // it with a tuple here. - } - break; - } - } - } - - -#endif - } diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 92bcb6707..9ecafbc7d 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -960,6 +960,7 @@ namespace Slang if( !ptrType ) { // Bad! + SLANG_ASSERT(ptrType); return nullptr; } diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index a15104d6a..b375fa80e 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -778,6 +778,21 @@ struct LoweringVisitor translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); } + RefPtr<Type> visitFilteredTupleType(FilteredTupleType* type) + { + RefPtr<FilteredTupleType> loweredType = new FilteredTupleType(); + loweredType->setSession(type->getSession()); + loweredType->originalType = lowerType(type->originalType); + for (auto ee : type->elements) + { + FilteredTupleType::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.type = lowerType(ee.type); + loweredType->elements.Add(element); + } + return loweredType; + } + RefPtr<Type> visitTypeType(TypeType* type) { return getTypeType(lowerType(type->type)); diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index 68fa7f31b..e2db1b456 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -24,6 +24,13 @@ namespace Slang context->sb.append(value); } + void emit( + ManglingContext* context, + String const& value) + { + context->sb.append(value); + } + void emitName( ManglingContext* context, Name* name) @@ -117,6 +124,14 @@ namespace Slang { emitQualifiedName(context, declRefType->declRef); } + else if (auto tupleType = dynamic_cast<FilteredTupleType*>(type)) + { + // TODO: this doesn't handle the possibility of multiple different + // filtered versions of the same type... + emitRaw(context, "t"); + emitType(context, tupleType->originalType); + emitRaw(context, "_"); + } else { SLANG_UNEXPECTED("unimplemented case in mangling"); @@ -398,4 +413,12 @@ namespace Slang return context.sb.ProduceString(); } + String getMangledTypeName(Type* type) + { + ManglingContext context; + emitType(&context, type); + return context.sb.ProduceString(); + } + + } diff --git a/source/slang/mangle.h b/source/slang/mangle.h index 29101a926..65afea741 100644 --- a/source/slang/mangle.h +++ b/source/slang/mangle.h @@ -11,10 +11,13 @@ namespace Slang String getMangledName(Decl* decl); String getMangledName(DeclRef<Decl> const & declRef); String getMangledName(DeclRefBase const & declRef); + String mangleSpecializedFuncName(String baseName, RefPtr<Substitutions> subst); String getMangledNameForConformanceWitness( Type* sub, Type* sup); + + String getMangledTypeName(Type* type); } #endif
\ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 81df49713..e5fc8dfa3 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1713,4 +1713,117 @@ void Type::accept(IValVisitor* visitor, void* extra) return nullptr; } + // FilteredTupleType + + String FilteredTupleType::ToString() + { + StringBuilder sb; + sb.append(originalType->ToString()); + sb.append("{"); + bool first = true; + for (auto ee : elements) + { + if (!ee.type) + continue; + + if (!first) sb.append(", "); + + sb.append(ee.fieldDeclRef.GetName()->text); + sb.append(":"); + sb.append(ee.type->ToString()); + + first = false; + } + sb.append("}"); + return sb.ProduceString(); + } + + RefPtr<Val> FilteredTupleType::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + int diff = 0; + auto substOriginalType = originalType->SubstituteImpl(subst, &diff).As<Type>(); + + List<Element> substElements; + for (auto ee : elements) + { + Element substElement; + substElement.fieldDeclRef = ee.fieldDeclRef.SubstituteImpl(subst, &diff); + substElement.type = ee.type->SubstituteImpl(subst, &diff).As<Type>(); + substElements.Add(substElement); + } + + if (!diff) + return this; + + (*ioDiff)++; + RefPtr<FilteredTupleType> substType = new FilteredTupleType(); + substType->setSession(session); + substType->originalType = substOriginalType; + substType->elements = substElements; + return substType; + } + + bool FilteredTupleType::EqualsImpl(Type * type) + { + auto tupleType = type->As<FilteredTupleType>(); + if (!tupleType) + return false; + + if (!originalType->Equals(tupleType->originalType)) + return false; + + auto elementCount = elements.Count(); + if (tupleType->elements.Count() != elementCount) + return false; + + for (UInt ee = 0; ee < elementCount; ee++) + { + if (!elements[ee].type || !tupleType->elements[ee].type) + { + if (!elements[ee].type != !tupleType->elements[ee].type) + return false; + + continue; + } + + if (!elements[ee].fieldDeclRef.Equals(tupleType->elements[ee].fieldDeclRef)) + return false; + + if (!elements[ee].type->Equals(tupleType->elements[ee].type)) + return false; + } + return true; + } + + int FilteredTupleType::GetHashCode() + { + int hash = (int)(typeid(this).hash_code()); + hash = combineHash(hash, + originalType->GetHashCode()); + for (auto ee : elements) + { + hash = combineHash(hash, + ee.fieldDeclRef.GetHashCode()); + hash = combineHash(hash, + ee.type->GetHashCode()); + } + return hash; + } + + Type* FilteredTupleType::CreateCanonicalType() + { + RefPtr<FilteredTupleType> canTupleType = new FilteredTupleType(); + canTupleType->setSession(session); + canTupleType->originalType = originalType->GetCanonicalType(); + for (auto ee : elements) + { + Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.type = ee.type ? ee.type->GetCanonicalType() : nullptr; + + canTupleType->elements.Add(element); + } + getSession()->canonicalTypes.Add(canTupleType); + return canTupleType; + } } diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 34c5b5936..72bf6fe4c 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -492,4 +492,38 @@ protected: virtual int GetHashCode() override; virtual Type* CreateCanonicalType() override; ) -END_SYNTAX_CLASS()
\ No newline at end of file +END_SYNTAX_CLASS() + +// A type created to represent the result of filtering +// the fields of an aggregate type. +SYNTAX_CLASS(FilteredTupleType, Type) +RAW( + struct Element + { + // The original field this element represents + DeclRef<VarDeclBase> fieldDeclRef; + + // The type being used for the new field + RefPtr<Type> type; + }; +) + + FIELD(RefPtr<Type>, originalType); + FIELD(List<Element>, elements); + +RAW( + FilteredTupleType() + {} + + RefPtr<Type> getOriginalType() const { return originalType; } + List<Element> const& getElements() const { return elements; } + virtual String ToString() override; + +protected: + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual Type* CreateCanonicalType() override; +) + +END_SYNTAX_CLASS() |
