From baf194e7456ba4568dcf11249896af35b3ce18cc Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Wed, 11 Apr 2018 16:18:29 -0700 Subject: Introduce an IR-level type system (#481) * Introduce an IR-level type system Up to this point, the Slang IR has used the front-end type system to represent types in the IR. As a result (but ultimately more importantly) the IR representation of generics and specialization has used AST-level concepts embedded in the IR. For example, to express the specialization of `vector` to a concrete type `float` for `T`, we needed an IR operation that could represent the specialization, with operands that somehow represented the type argument `float`. The whole thing was very complicated. The big idea of this change is to introduce a new representation in which types in the IR are just ordinary instructions, so that using them as operands makes sense. The hierarchy of IR types closely mirrors the AST-side hierarchy for now, and that will probably be something we should maintain going forward. In order to make these changes work, though, I also had to do major overhauls of things like the way substitutions are performed, how we check interface conformances, the way lookup through interface types is done, etc. etc. This is a big change, and unfortunately any attempt to summarize it in the commit message wouldn't do it justice. * Fix 64-bit build warning * Fix up some clang warnings/errors --- source/slang/bytecode.cpp | 92 +- source/slang/check.cpp | 971 +++++++---- source/slang/compiler.h | 23 +- source/slang/core.meta.slang | 12 +- source/slang/core.meta.slang.h | 24 +- source/slang/decl-defs.h | 2 +- source/slang/emit.cpp | 1540 ++++++++-------- source/slang/glsl.meta.slang | 202 --- source/slang/hlsl.meta.slang | 70 +- source/slang/hlsl.meta.slang.h | 106 +- source/slang/ir-constexpr.cpp | 80 +- source/slang/ir-inst-defs.h | 247 ++- source/slang/ir-insts.h | 215 ++- source/slang/ir-legalize-types.cpp | 302 ++-- source/slang/ir-ssa.cpp | 14 +- source/slang/ir-validate.cpp | 3 + source/slang/ir.cpp | 3394 +++++++++++++++++------------------- source/slang/ir.h | 405 ++++- source/slang/legalize-types.cpp | 449 +++-- source/slang/legalize-types.h | 83 +- source/slang/lookup.cpp | 203 ++- source/slang/lower-to-ir.cpp | 1777 ++++++++++++------- source/slang/mangle.cpp | 136 +- source/slang/mangle.h | 5 +- source/slang/modifier-defs.h | 7 - source/slang/parameter-binding.cpp | 63 +- source/slang/parser.cpp | 2 +- source/slang/slang-stdlib.cpp | 18 - source/slang/slang.cpp | 8 - source/slang/slang.natvis | 6 +- source/slang/slang.vcxproj | 19 - source/slang/slang.vcxproj.filters | 1 - source/slang/syntax-base-defs.h | 55 +- source/slang/syntax.cpp | 1270 +++++++++----- source/slang/syntax.h | 140 +- source/slang/type-defs.h | 111 +- source/slang/type-system-shared.h | 34 +- source/slang/val-defs.h | 33 +- source/slang/vm.cpp | 32 +- 39 files changed, 6753 insertions(+), 5401 deletions(-) delete mode 100644 source/slang/glsl.meta.slang (limited to 'source') diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp index 8a062faaa..63af9512a 100644 --- a/source/slang/bytecode.cpp +++ b/source/slang/bytecode.cpp @@ -107,7 +107,7 @@ struct SharedBytecodeGenerationContext // Types that have been emitted List> bcTypes; - Dictionary mapTypeToID; + Dictionary mapTypeToID; // Compile-time constant values that need // to be emitted... @@ -308,7 +308,7 @@ void encodeOperand( uint32_t getTypeID( BytecodeGenerationContext* context, - Type* type); + IRType* type); void encodeOperand( BytecodeGenerationContext* context, @@ -326,11 +326,8 @@ bool opHasResult(IRInst* inst) // the function returns the distinguished `Void` type, // since that is conceptually the same as "not returning // a value." - if (auto basicType = dynamic_cast(type)) - { - if (basicType->baseType == BaseType::Void) - return false; - } + if(type->op == kIROp_VoidType) + return false; return true; } @@ -465,7 +462,7 @@ void generateBytecodeForInst( BytecodeGenerationPtr emitBCType( BytecodeGenerationContext* context, - Type* type, + IRType* type, IROp op, BytecodeGenerationPtr const* args, UInt argCount) @@ -498,7 +495,7 @@ BytecodeGenerationPtr emitBCType( BytecodeGenerationPtr emitBCVarArgType( BytecodeGenerationContext* context, - Type* type, + IRType* type, IROp op, List> args) { @@ -507,7 +504,7 @@ BytecodeGenerationPtr emitBCVarArgType( BytecodeGenerationPtr emitBCType( BytecodeGenerationContext* context, - Type* type, + IRType* type, IROp op) { return emitBCType(context, type, op, nullptr, 0); @@ -515,12 +512,12 @@ BytecodeGenerationPtr emitBCType( BytecodeGenerationPtr emitBCType( BytecodeGenerationContext* context, - Type* type); + IRType* type); // Emit a `BCType` representation for the given `Type` BytecodeGenerationPtr emitBCTypeImpl( BytecodeGenerationContext* context, - Type* type) + IRType* type) { // A NULL type is interpreted as equivalent to `Void` for now. if( !type ) @@ -528,65 +525,20 @@ BytecodeGenerationPtr emitBCTypeImpl( return emitBCType(context, type, kIROp_VoidType); } - if( auto basicType = type->As() ) + List> operands; + UInt operandCount = type->getOperandCount(); + for (UInt ii = 0; ii < operandCount; ++ii) { - switch(basicType->baseType) - { - case BaseType::Void: return emitBCType(context, type, kIROp_VoidType); - case BaseType::Bool: return emitBCType(context, type, kIROp_BoolType); - case BaseType::Int: return emitBCType(context, type, kIROp_Int32Type); - case BaseType::UInt: return emitBCType(context, type, kIROp_UInt32Type); - case BaseType::UInt64: return emitBCType(context, type, kIROp_UInt64Type); - case BaseType::Half: return emitBCType(context, type, kIROp_Float16Type); - case BaseType::Float: return emitBCType(context, type, kIROp_Float32Type); - case BaseType::Double: return emitBCType(context, type, kIROp_Float64Type); - - default: - break; - } + operands.Add(emitBCType(context, (IRType*) type->getOperand(ii)).bitCast()); } - else if( auto funcType = type->As() ) - { - List> operands; - - operands.Add(emitBCType(context, funcType->resultType).bitCast()); - UInt paramCount = funcType->getParamCount(); - for(UInt pp = 0; pp < paramCount; ++pp) - { - operands.Add(emitBCType(context, funcType->getParamType(pp)).bitCast()); - } - - return emitBCVarArgType(context, type, kIROp_FuncType, operands); - } - else if( auto ptrType = type->As() ) - { - List> operands; - operands.Add(emitBCType(context, ptrType->getValueType()).bitCast()); - return emitBCVarArgType(context, type, kIROp_PtrType, operands); - } - else if( auto rwStructuredBufferType = type->As() ) - { - List> operands; - operands.Add(emitBCType(context, rwStructuredBufferType->elementType).bitCast()); - return emitBCVarArgType(context, type, kIROp_readWriteStructuredBufferType, operands); - } - else if( auto structuredBufferType = type->As() ) - { - List> operands; - operands.Add(emitBCType(context, structuredBufferType->elementType).bitCast()); - return emitBCVarArgType(context, type, kIROp_structuredBufferType, operands); - } - - - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(BytecodeGenerationPtr()); + return emitBCVarArgType(context, type, type->op, operands); } BytecodeGenerationPtr emitBCType( BytecodeGenerationContext* context, - Type* type) + IRType* type) { - auto canonical = type->GetCanonicalType(); + auto canonical = type->getCanonicalType(); UInt id = 0; if(context->shared->mapTypeToID.TryGetValue(canonical, id)) { @@ -599,7 +551,7 @@ BytecodeGenerationPtr emitBCType( uint32_t getTypeID( BytecodeGenerationContext* context, - Type* type) + IRType* type) { // We have a type, and we need to emit it (if we haven't // already) and return its index in the global type table. @@ -821,7 +773,7 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( bcRegs[localID+1].op = ii->op; bcRegs[localID+1].previousVarIndexPlusOne = (uint32_t)localID+1; bcRegs[localID+1].typeID = getTypeID(context, - (ii->getDataType()->As())->getValueType()); + (as(ii->getDataType()))->getValueType()); } break; } @@ -902,13 +854,13 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( } break; - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: { auto bcVar = allocate(context); bcVar->op = inst->op; - bcVar->typeID = getTypeID(context, inst->type); + bcVar->typeID = getTypeID(context, inst->getFullType()); // TODO: actually need to intialize with body instructions @@ -1003,7 +955,7 @@ BytecodeGenerationPtr generateBytecodeForModule( { auto irConstant = (IRConstant*) context->shared->constants[cc]; bcConstants[cc].op = irConstant->op; - bcConstants[cc].typeID = getTypeID(context, irConstant->type); + bcConstants[cc].typeID = getTypeID(context, irConstant->getFullType()); switch(irConstant->op) { diff --git a/source/slang/check.cpp b/source/slang/check.cpp index eb15d0889..67b628596 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -168,64 +168,54 @@ namespace Slang RefPtr baseExpr, SourceLoc loc) { + // Compute the type that this declaration reference will have in context. + // + auto type = GetTypeForDeclRef(declRef); + + // Construct an appropriate expression based on teh structured of + // the declaration reference. + // if (baseExpr) { - RefPtr expr; - DeclRef *declRefOut; + // If there was a base expression, we will have some kind of + // member expression. + // if (baseExpr->type->As()) { - auto sexpr = new StaticMemberExpr(); - sexpr->loc = loc; - sexpr->BaseExpression = baseExpr; - sexpr->name = declRef.GetName(); - sexpr->declRef = declRef; - declRefOut = &sexpr->declRef; - expr = sexpr; + // If the base expression was a type, then that means we + // are constructing a static member reference. + // + auto expr = new StaticMemberExpr(); + expr->loc = loc; + expr->type = type; + expr->BaseExpression = baseExpr; + expr->name = declRef.GetName(); + expr->declRef = declRef; + return expr; } else { - auto sexpr = new MemberExpr(); - sexpr->loc = loc; - sexpr->BaseExpression = baseExpr; - sexpr->name = declRef.GetName(); - sexpr->declRef = declRef; - declRefOut = &sexpr->declRef; - expr = sexpr; - } - - RefPtr baseThisTypeSubst; - if (auto baseDeclRefExpr = baseExpr->As()) - { - baseThisTypeSubst = getThisTypeSubst(baseDeclRefExpr->declRef, false); - } - if (declRef.As()) - { - // if this is a reference to type constraint, insert a this-type substitution - RefPtr expType; - expType = baseExpr->type; - if (auto baseExprTT = baseExpr->type->As()) - expType = baseExprTT->type; - auto thisTypeSubst = getNewThisTypeSubst(*declRefOut); - thisTypeSubst->sourceType = expType; - baseThisTypeSubst = nullptr; - } - // propagate "this-type" substitutions - if (baseThisTypeSubst) - { - if (auto declRefExpr = expr.As()) - { - getNewThisTypeSubst(declRefExpr->declRef)->sourceType = baseThisTypeSubst->sourceType; - } + // If the base expression wasn't a type, then this + // is a normal member expression. + // + auto expr = new MemberExpr(); + expr->loc = loc; + expr->type = type; + expr->BaseExpression = baseExpr; + expr->name = declRef.GetName(); + expr->declRef = declRef; + return expr; } - expr->type = GetTypeForDeclRef(*declRefOut); - return expr; } else { + // If there is no base expression, then the result must + // be an ordinary variable expression. + // auto expr = new VarExpr(); expr->loc = loc; expr->name = declRef.GetName(); - expr->type = GetTypeForDeclRef(declRef); + expr->type = type; expr->declRef = declRef; return expr; } @@ -444,12 +434,12 @@ namespace Slang // The arguments should already be checked against // the declaration. RefPtr InstantiateGenericType( - DeclRef genericDeclRef, - List> const& args) + DeclRef genericDeclRef, + List> const& args) { RefPtr subst = new GenericSubstitution(); subst->genericDecl = genericDeclRef.getDecl(); - subst->outer = genericDeclRef.substitutions.genericSubstitutions; + subst->outer = genericDeclRef.substitutions.substitutions; for (auto argExpr : args) { @@ -458,8 +448,7 @@ namespace Slang DeclRef innerDeclRef; innerDeclRef.decl = GetInner(genericDeclRef); - innerDeclRef.substitutions = SubstitutionSet(subst, genericDeclRef.substitutions.thisTypeSubstitution, - genericDeclRef.substitutions.globalGenParamSubstitutions); + innerDeclRef.substitutions = SubstitutionSet(subst); return DeclRefType::Create( getSession(), @@ -874,7 +863,7 @@ namespace Slang auto arg = fromInitializerListExpr->args[argIndex++]; - // + // RefPtr coercedArg; ConversionCost argCost; @@ -1066,7 +1055,7 @@ namespace Slang overloadContext.baseExpr = nullptr; overloadContext.mode = OverloadResolveContext::Mode::JustTrying; - + AddTypeOverloadCandidates(toType, overloadContext, toType); if(overloadContext.bestCandidates.Count() != 0) @@ -1821,7 +1810,7 @@ namespace Slang for (int pass = 0; pass < 2; pass++) { checkingPhase = pass == 0 ? CheckingPhase::Header : CheckingPhase::Body; - + for (auto & s : programNode->getMembersOfType()) { checkDecl(s.Ptr()); @@ -1866,7 +1855,7 @@ namespace Slang { checkModifiers(d.Ptr()); } - + if (pass == 0) { // now we can check all interface conformances @@ -1896,20 +1885,22 @@ namespace Slang } bool doesSignatureMatchRequirement( - DeclRef memberDecl, + DeclRef satisfyingMemberDeclRef, DeclRef requiredMemberDeclRef, - Dictionary, DeclRef> & requirementDict) + RefPtr witnessTable) { // TODO: actually implement matching here. For now we'll // just pretend that things are satisfied in order to make progress.. - requirementDict.AddIfNotExists(requiredMemberDeclRef, memberDecl); + witnessTable->requirementDictionary.Add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(satisfyingMemberDeclRef)); return true; } bool doesGenericSignatureMatchRequirement( - DeclRef genDecl, - DeclRef requirementGenDecl, - Dictionary, DeclRef> & requirementDict) + DeclRef genDecl, + DeclRef requirementGenDecl, + RefPtr witnessTable) { if (genDecl.getDecl()->Members.Count() != requirementGenDecl.getDecl()->Members.Count()) return false; @@ -1948,20 +1939,81 @@ namespace Slang return false; } } - return doesMemberSatisfyRequirement(DeclRef(genDecl.getDecl()->inner.Ptr(), genDecl.substitutions), + + // TODO: this isn't right, because we need to specialize the + // declarations of the generics to a common set of substitutions, + // so that their types are comparable (e.g., foo and foo + // need to have substutition applies so that they are both foo, + // after which uses of the type X in their parameter lists can + // be compared). + + return doesMemberSatisfyRequirement( + DeclRef(genDecl.getDecl()->inner.Ptr(), genDecl.substitutions), DeclRef(requirementGenDecl.getDecl()->inner.Ptr(), requirementGenDecl.substitutions), - requirementDict); + witnessTable); + } + + bool doesTypeSatisfyAssociatedTypeRequirement( + RefPtr satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable) + { + // We need to confirm that the chosen type `satisfyingType`, + // meets all the constraints placed on the associated type + // requirement `requiredAssociatedTypeDeclRef`. + // + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = true; + for (auto requiredConstraintDeclRef : getMembersOfType(requiredAssociatedTypeDeclRef)) + { + // Grab the type we expect to conform to from the constraint. + auto requiredSuperType = GetSup(requiredConstraintDeclRef); + + // Perform a search for a witness to the subtype relationship. + auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); + if(witness) + { + // If a subtype witness was found, then the conformance + // appears to hold, and we can satisfy that requirement. + witnessTable->requirementDictionary.Add(requiredConstraintDeclRef, RequirementWitness(witness)); + } + else + { + // If a witness couldn't be found, then the conformance + // seems like it will fail. + conformance = false; + } + } + + // TODO: if any conformance check failed, we should probably include + // that in an error message produced about not satisfying the requirement. + + if(conformance) + { + // If all the constraints were satsified, then the chosen + // type can indeed satisfy the interface requirement. + witnessTable->requirementDictionary.Add( + requiredAssociatedTypeDeclRef.getDecl(), + RequirementWitness(satisfyingType)); + } + + return conformance; } // Does the given `memberDecl` work as an implementation // to satisfy the requirement `requiredMemberDeclRef` // from an interface? + // + // If it does, then inserts a witness into `witnessTable` + // and returns `true`, otherwise returns `false` bool doesMemberSatisfyRequirement( - DeclRef memberDeclRef, - DeclRef requiredMemberDeclRef, - Dictionary, DeclRef> & requirementDictionary) + DeclRef memberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) { - // At a high level, we want to chack that the + // At a high level, we want to check that the // `memberDecl` and the `requiredMemberDeclRef` // have the same AST node class, and then also // check that their signatures match. @@ -1979,34 +2031,7 @@ namespace Slang // An associated type requirement should be allowed // to be satisfied by any type declaration: // a typedef, a `struct`, etc. - auto checkSubTypeMember = [&](DeclRef subStructTypeDeclRef) -> bool - { - checkDecl(subStructTypeDeclRef.getDecl()); - // this is a sub type (e.g. nested struct declaration) in an aggregate type - // check if this sub type declaration satisfies the constraints defined by the associated type - if (auto requiredTypeDeclRef = requiredMemberDeclRef.As()) - { - bool conformance = true; - auto inheritanceReqDeclRefs = getMembersOfType(requiredTypeDeclRef); - for (auto inheritanceReqDeclRef : inheritanceReqDeclRefs) - { - auto interfaceDeclRefType = inheritanceReqDeclRef.getDecl()->getSup().type.As(); - SLANG_ASSERT(interfaceDeclRefType); - auto interfaceDeclRef = interfaceDeclRefType->declRef.As(); - SLANG_ASSERT(interfaceDeclRef); - RefPtr declRefType = new DeclRefType(); - declRefType->declRef = subStructTypeDeclRef; - auto witness = tryGetInterfaceConformanceWitness(declRefType, - interfaceDeclRef).As(); - if (witness) - requirementDictionary.Add(inheritanceReqDeclRef, witness->getLastStepDeclRef()); - else - conformance = false; - } - return conformance; - } - return false; - }; + // if (auto memberFuncDecl = memberDeclRef.As()) { if (auto requiredFuncDeclRef = requiredMemberDeclRef.As()) @@ -2015,7 +2040,7 @@ namespace Slang return doesSignatureMatchRequirement( memberFuncDecl, requiredFuncDeclRef, - requirementDictionary); + witnessTable); } } else if (auto memberInitDecl = memberDeclRef.As()) @@ -2026,19 +2051,35 @@ namespace Slang return doesSignatureMatchRequirement( memberInitDecl, requiredInitDecl, - requirementDictionary); + witnessTable); } } else if (auto genDecl = memberDeclRef.As()) { + // For a generic member, we will check if it can satisfy + // a generic requirement in the interface. + // + // TODO: we could also conceivably check that the generic + // could be *specialized* to satisfy the requirement, + // and then install a specialization of the generic into + // the witness table. Actually doing this would seem + // to require performing something akin to overload + // resolution as part of requirement satisfaction. + // if (auto requiredGenDeclRef = requiredMemberDeclRef.As()) { - return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, requirementDictionary); + return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, witnessTable); } } - else if (auto subStructTypeDeclRef = memberDeclRef.As()) + else if (auto subAggTypeDeclRef = memberDeclRef.As()) { - return checkSubTypeMember(subStructTypeDeclRef); + if(auto requiredTypeDeclRef = requiredMemberDeclRef.As()) + { + checkDecl(subAggTypeDeclRef.getDecl()); + + auto satisfyingType = DeclRefType::Create(getSession(), subAggTypeDeclRef); + return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); + } } else if (auto typedefDeclRef = memberDeclRef.As()) { @@ -2046,28 +2087,25 @@ namespace Slang // check if the specified type satisfies the constraints defined by the associated type if (auto requiredTypeDeclRef = requiredMemberDeclRef.As()) { - auto declRefType = GetType(typedefDeclRef)->GetCanonicalType()->As(); - if (!declRefType) - return false; - - if (auto genTypeParamDeclRef = declRefType->declRef.As()) - { - // TODO: check generic type parameter satisfies constraints - return true; - } - - - auto containerDeclRef = declRefType->declRef.As(); - if (!containerDeclRef) - return false; + checkDecl(typedefDeclRef.getDecl()); - return checkSubTypeMember(containerDeclRef); + auto satisfyingType = getNamedType(getSession(), typedefDeclRef); + return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); } } // Default: just assume that thing aren't being satisfied. return false; } + // State used while checking if a declaration (either a type declaration + // or an extension of that type) conforms to the interfaces it claims + // via its inheritance clauses. + // + struct ConformanceCheckingContext + { + Dictionary, RefPtr> mapInterfaceToWitnessTable; + }; + // Find the appropriate member of a declared type to // satisfy a requirement of an interface the type // claims to conform to. @@ -2076,13 +2114,56 @@ namespace Slang // conforms to the interface `interfaceDeclRef`, and // `requiredMemberDeclRef` is a required member of // the interface. - RefPtr findWitnessForInterfaceRequirement( + // + // If a satisfying value is found, registers it in + // `witnessTable` and returns `true`, otherwise + // returns `false`. + // + bool findWitnessForInterfaceRequirement( + ConformanceCheckingContext* context, DeclRef typeDeclRef, - InheritanceDecl* inheritanceDecl, - DeclRef interfaceDeclRef, - DeclRef requiredMemberDeclRef, - Dictionary, DeclRef> & requirementWitness) + InheritanceDecl* inheritanceDecl, + DeclRef interfaceDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) { + // The goal of this function is to find a suitable + // value to satisfy the requirement. + // + // The 99% case is that the requirement is a named member + // of the interface, and we need to search for a member + // with the same name in the type declaration and + // its (known) extensions. + + // An important exception to the above is that an + // inheritance declaration in the interface is not going + // to be satisfied by an inheritance declaration in the + // conforming type, but rather by a full "witness table" + // full of the satisfying values for each requirement + // in the inherited-from interface. + // + if( auto requiredInheritanceDeclRef = requiredMemberDeclRef.As() ) + { + // Recursively check that the type conforms + // to the inherited interface. + // + // TODO: we *really* need a linearization step here!!!! + + RefPtr satisfyingWitnessTable = checkConformanceToType( + context, + typeDeclRef, + requiredInheritanceDeclRef.getDecl(), + getBaseType(requiredInheritanceDeclRef)); + + if(!satisfyingWitnessTable) + return false; + + witnessTable->requirementDictionary.Add( + requiredInheritanceDeclRef.getDecl(), + RequirementWitness(satisfyingWitnessTable)); + return true; + } + // We will look up members with the same name, // since only same-name members will be able to // satisfy the requirement. @@ -2117,21 +2198,21 @@ namespace Slang // Make sure that by-name lookup is possible. buildMemberDictionary(typeDeclRef.getDecl()); auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef); - + if (!lookupResult.isValid()) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); - return nullptr; + return false; } // Iterate over the members and look for one that matches // the expected signature for the requirement. for (auto member : lookupResult) { - if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, requirementWitness)) - return member.declRef.getDecl(); + if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, witnessTable)) + return true; } - + // No suitable member found, although there were candidates. // // TODO: Eventually we might want something akin to the current @@ -2140,83 +2221,125 @@ namespace Slang // and if nothing is found we print the candidates getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); - return nullptr; + return false; } // Check that the type declaration `typeDecl`, which // declares conformance to the interface `interfaceDeclRef`, // (via the given `inheritanceDecl`) actually provides // members to satisfy all the requirements in the interface. - bool checkInterfaceConformance( - HashSet> & checkedInterfaceDeclRef, - DeclRef typeDeclRef, - InheritanceDecl* inheritanceDecl, - DeclRef interfaceDeclRef) - { - if (!checkedInterfaceDeclRef.Contains(interfaceDeclRef)) - checkedInterfaceDeclRef.Add(interfaceDeclRef); - else - return true; - - bool result = true; + RefPtr checkInterfaceConformance( + ConformanceCheckingContext* context, + DeclRef typeDeclRef, + InheritanceDecl* inheritanceDecl, + DeclRef interfaceDeclRef) + { + // Has somebody already checked this conformance, + // and/or is in the middle of checking it? + RefPtr witnessTable; + if(context->mapInterfaceToWitnessTable.TryGetValue(interfaceDeclRef, witnessTable)) + return witnessTable; // We need to check the declaration of the interface // before we can check that we conform to it. checkDecl(interfaceDeclRef.getDecl()); + // We will construct the witness table, and register it + // *before* we go about checking fine-grained requirements, + // in order to short-circuit any potential for infinite recursion. + + witnessTable = new WitnessTable(); + context->mapInterfaceToWitnessTable.Add(interfaceDeclRef, witnessTable); + + bool result = true; + // TODO: If we ever allow for implementation inheritance, // then we will need to consider the case where a type // declares that it conforms to an interface, but one of // its (non-interface) base types already conforms to // that interface, so that all of the requirements are // already satisfied with inherited implementations... - auto allMembers = getMembersWithExt(interfaceDeclRef); - for (auto requiredMemberDeclRef : allMembers) - { - // Some members of the interface don't actually represent - // things that we required of the implementing type. - // For example, when the interface declares that - // it inherits from another interface, we don't look for - // a matching inheritance clause on the type, but - // instead require that it also conforms to that - // interface. - if (auto requiredInheritanceDeclRef = requiredMemberDeclRef.As()) - { - // Recursively check that the type conforms - // to the inherited interface. - // - // TODO: we *really* need a linearization step here!!!! - result = result && checkConformanceToType( - checkedInterfaceDeclRef, - typeDeclRef, - inheritanceDecl, - getBaseType(requiredInheritanceDeclRef)); - continue; - } - - // Look for a member in the type that can satisfy the - // interface requirement. - auto isConformanceSatisfied = findWitnessForInterfaceRequirement( + for(auto requiredMemberDeclRef : getMembers(interfaceDeclRef)) + { + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, typeDeclRef, inheritanceDecl, interfaceDeclRef, requiredMemberDeclRef, - inheritanceDecl->requirementWitnesses); + witnessTable); - if (!isConformanceSatisfied) - { - result = false; + result = result && requirementSatisfied; + } + + // Extensions that apply to the interface type can create new conformances + // for the concrete types that inherit from the interface. + // + // These new conformances should not be able to introduce new *requirements* + // for an implementing interface (although they currently can), but we + // still need to go through this logic to find the appropriate value + // that will satisfy the requirement in these cases, and also to put + // the required entry into the witness table for the interface itself. + // + // TODO: This logic is a bit slippery, and we need to figure out what + // it means in the context of separate compilation. If module A defines + // an interface IA, module B defines a type C that conforms to IA, and then + // module C defines an extension that makes IA conform to IC, then it is + // unreasonable to expect the {B:IA} witness table to contain an entry + // corresponding to {IA:IC}. + // + // The simple answer then would be that the {IA:IC} conformance should be + // fixed, with a single witness table for {IA:IC}, but then what should + // happen in B explicitly conformed to IC already? + // + // For now we will just walk through the extensions that are known at + // the time we are compiling and handle those, and punt on the larger issue + // for abit longer. + for(auto candidateExt = interfaceDeclRef.getDecl()->candidateExtensions; candidateExt; candidateExt = candidateExt->nextCandidateExtension) + { + // We need to apply the extension to the interface type that our + // concrete type is inheriting from. + // + // TODO: need to decide if a this-type substitution is needed here. + // It probably it. + RefPtr targetType = DeclRefType::Create( + getSession(), + interfaceDeclRef); + auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); + if(!extDeclRef) continue; + + // Only inheritance clauses from the extension matter right now. + for(auto requiredInheritanceDeclRef : getMembersOfType(extDeclRef)) + { + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + typeDeclRef, + inheritanceDecl, + interfaceDeclRef, + requiredInheritanceDeclRef, + witnessTable); + + result = result && requirementSatisfied; } } - return result; + + // If we failed to satisfy any requirements along the way, + // then we don't actually want to keep the witness table + // we've been constructing, because the whole thing was a failure. + if(!result) + { + return nullptr; + } + + return witnessTable; } - bool checkConformanceToType( - HashSet>& checkedInterfaceDeclRefs, - DeclRef typeDeclRef, - InheritanceDecl* inheritanceDecl, - Type* baseType) + RefPtr checkConformanceToType( + ConformanceCheckingContext* context, + DeclRef typeDeclRef, + InheritanceDecl* inheritanceDecl, + Type* baseType) { if (auto baseDeclRefType = baseType->As()) { @@ -2227,7 +2350,7 @@ namespace Slang // We need to check that it provides all of the members // required by that interface. return checkInterfaceConformance( - checkedInterfaceDeclRefs, + context, typeDeclRef, inheritanceDecl, baseInterfaceDeclRef); @@ -2235,41 +2358,65 @@ namespace Slang } getSink()->diagnose(inheritanceDecl, Diagnostics::unimplemented, "type not supported for inheritance"); - return false; + return nullptr; } - // Check that the type declaration `typeDecl`, which - // declares that it inherits from another type via + // Check that the type (or extension) declaration `declRef`, + // which declares that it inherits from another type via // `inheritanceDecl` actually does what it needs to // for that inheritance to be valid. bool checkConformance( - DeclRef typeDecl, + DeclRef declRef, InheritanceDecl* inheritanceDecl) { + declRef = createDefaultSubstitutionsIfNeeded(getSession(), declRef).As(); + + // Don't check conformances for abstract types that + // are being used to express *required* conformances. + if (auto assocTypeDeclRef = declRef.As()) + { + // An associated type declaration represents a requirement + // in an outer interface declaration, and its members + // (type constraints) represent additional requirements. + return true; + } + else if (auto interfaceDeclRef = declRef.As()) + { + // HACK: Our semantics as they stand today are that an + // `extension` of an interface that adds a new inheritance + // clause acts *as if* that inheritnace clause had been + // attached to the original `interface` decl: that is, + // it adds additional requirements. + // + // This is *not* a reasonable semantic to keep long-term, + // but it is required for some of our current example + // code to work. + return true; + } + + // Look at the type being inherited from, and validate // appropriately. auto baseType = inheritanceDecl->base.type; - HashSet> checkdInterfaceDeclRefs; - return checkConformanceToType(checkdInterfaceDeclRefs, typeDecl, inheritanceDecl, baseType.As()); - } - bool checkConformance( - AggTypeDeclBase* typeDecl, - InheritanceDecl* inheritanceDecl) - { - return checkConformance(DeclRef(typeDecl, SubstitutionSet()), inheritanceDecl); + ConformanceCheckingContext context; + RefPtr witnessTable = checkConformanceToType(&context, declRef, inheritanceDecl, baseType); + if(!witnessTable) + return false; + + inheritanceDecl->witnessTable = witnessTable; + return true; } void checkExtensionConformance(ExtensionDecl* decl) { - DeclRef aggTypeDeclRef; if (auto targetDeclRefType = decl->targetType->As()) { - if (aggTypeDeclRef = targetDeclRefType->declRef.As()) + if (auto aggTypeDeclRef = targetDeclRefType->declRef.As()) { for (auto inheritanceDecl : decl->getMembersOfType()) { - checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl); + checkConformance(aggTypeDeclRef, inheritanceDecl); } } } @@ -2303,7 +2450,7 @@ namespace Slang // (That's what C# does). for (auto inheritanceDecl : decl->getMembersOfType()) { - checkConformance(decl, inheritanceDecl); + checkConformance(makeDeclRef(decl), inheritanceDecl); } } } @@ -2708,7 +2855,7 @@ namespace Slang // generic. // subst->genericDecl = prevGenericDecl; - prevFuncDeclRef.substitutions.genericSubstitutions = subst; + prevFuncDeclRef.substitutions.substitutions = subst; // // One way to think about it is that if we have these // declarations (ignore the name differences...): @@ -3481,6 +3628,7 @@ namespace Slang switch(getSourceLanguage()) { + default: case SourceLanguage::Slang: case SourceLanguage::HLSL: // HLSL: `static const` is used to mark compile-time constant expressions @@ -3626,7 +3774,7 @@ namespace Slang auto vectorGenericDecl = findMagicDecl( session, "Vector").As(); auto vectorTypeDecl = vectorGenericDecl->inner; - + auto substitutions = new GenericSubstitution(); substitutions->genericDecl = vectorGenericDecl.Ptr(); substitutions->args.Add(elementType); @@ -3815,11 +3963,10 @@ namespace Slang // TODO: need to check that the target type names a declaration... - DeclRef aggTypeDeclRef; if (auto targetDeclRefType = decl->targetType->As()) { // Attach our extension to that type as a candidate... - if (aggTypeDeclRef = targetDeclRefType->declRef.As()) + if (auto aggTypeDeclRef = targetDeclRefType->declRef.As()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; @@ -4034,7 +4181,7 @@ namespace Slang // Crete a subtype witness based on the declared relationship // found in a single breadcrumb - RefPtr createSimplSubtypeWitness( + RefPtr createSimpleSubtypeWitness( TypeWitnessBreadcrumb* breadcrumb) { RefPtr witness = new DeclaredSubtypeWitness(); @@ -4052,7 +4199,7 @@ namespace Slang if(!inBreadcrumbs) { // We need to construct a witness to the fact - // that `type` has been proven to be equal + // that `type` has been proven to be *equal* // to `interfaceDeclRef`. // SLANG_UNEXPECTED("reflexive type witness"); @@ -4061,44 +4208,74 @@ namespace Slang // We might have one or more steps in the breadcrumb trail, e.g.: // - // (A : B) (B : C) (C : D) + // {A : B} {B : C} {C : D} // // The chain is stored as a reversed linked list, so that // the first entry would be the `(C : D)` relationship // above. // - // We are going to walk the list and build up a suitable - // subtype witness. + // We need to walk the list and build up a suitable witness, + // which in the above case would look like: + // + // Transitive( + // Transitive( + // Declared({A : B}), + // {B : C}), + // {C : D}) + // + // Because of the ordering of the breadcrumb trail, along + // with the way the `Transitive` case nests, we will be + // building these objects outside-in, and keeping + // track of the "hole" where the next step goes. + // auto bb = inBreadcrumbs; - // Create a witness for the last step in the chain - RefPtr witness = createSimplSubtypeWitness(bb); - bb = bb->prev; + // `witness` here will hold the first (outer-most) object + // we create, which is the overall result. + RefPtr witness; - // Now, as long as we have more entries to deal with, - // we'll be in a situation like: - // - // ... (B : C) - // - // and we want to wrap up one more link in our chain. + // `link` will point at the remaining "hole" in the + // data structure, to be filled in. + RefPtr* link = &witness; - while (bb) + // As long as there is more than one breadcrumb, we + // need to be creating transitie witnesses. + while(bb->prev) { - // Create simple witness for the step in the chain - RefPtr link = createSimplSubtypeWitness(bb); - - // Now join the link onto the existing chain represented - // by `witness`. + // On the first iteration when processing the list + // above, the breadcrumb would be for `{ C : D }`, + // and so we'd create: + // + // Transitive( + // [...], + // { C : D}) + // + // where `[...]` represents the "hole" we leave + // open to fill in next. + // RefPtr transitiveWitness = new TransitiveSubtypeWitness(); - transitiveWitness->sub = link->sub; - transitiveWitness->sup = witness->sup; - transitiveWitness->subToMid = link; - transitiveWitness->midToSup = witness; + transitiveWitness->sub = bb->sub; + transitiveWitness->sup = bb->sup; + transitiveWitness->midToSup = bb->declRef; + + // Fill in the current hole, and then set the + // hole to point into the node we just created. + *link = transitiveWitness; + link = &transitiveWitness->subToMid; - witness = transitiveWitness; + // Move on with the list. bb = bb->prev; } + // If we exit the loop, then there is only one breadcrumb left. + // In our running example this would be `{ A : B }`. We create + // a simple (declared) subtype witness for it, and plug the + // final hole, after which there shouldn't be a hole to deal with. + RefPtr declaredWitness = createSimpleSubtypeWitness(bb); + *link = declaredWitness; + + // We now know that our original `witness` variable has been + // filled in, and there are no other holes. return witness; } @@ -4325,7 +4502,7 @@ namespace Slang { if( auto leftInterfaceRef = leftDeclRefType->declRef.As() ) { - // + // return TryJoinTypeWithInterface(right, leftInterfaceRef); } } @@ -4333,7 +4510,7 @@ namespace Slang { if( auto rightInterfaceRef = rightDeclRefType->declRef.As() ) { - // + // return TryJoinTypeWithInterface(left, rightInterfaceRef); } } @@ -4481,9 +4658,9 @@ namespace Slang RefPtr solvedSubst = new GenericSubstitution(); solvedSubst->genericDecl = genericDeclRef.getDecl(); - solvedSubst->outer = genericDeclRef.substitutions.genericSubstitutions; + solvedSubst->outer = genericDeclRef.substitutions.substitutions; solvedSubst->args = args; - resultSubst.genericSubstitutions = solvedSubst; + resultSubst.substitutions = solvedSubst; for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) { @@ -4959,12 +5136,12 @@ namespace Slang assert(subst); subst->genericDecl = genericDeclRef.getDecl(); - subst->outer = genericDeclRef.substitutions.genericSubstitutions; + subst->outer = genericDeclRef.substitutions.substitutions; for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) { auto subset = genericDeclRef.substitutions; - subset.genericSubstitutions = subst; + subset.substitutions = subst; DeclRef constraintDeclRef( constraintDecl, subset); @@ -5039,7 +5216,7 @@ namespace Slang } subst->genericDecl = baseGenericRef.getDecl(); - subst->outer = baseGenericRef.substitutions.genericSubstitutions; + subst->outer = baseGenericRef.substitutions.substitutions; DeclRef innerDeclRef(GetInner(baseGenericRef), subst); @@ -5305,7 +5482,6 @@ namespace Slang } } - OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Func; candidate.item = item; @@ -5429,7 +5605,7 @@ namespace Slang auto constraintDecl2 = sndWit->declRef.As(); assert(constraintDecl1); assert(constraintDecl2); - return TryUnifyTypes(constraints, + return TryUnifyTypes(constraints, constraintDecl1.getDecl()->getSup().type, constraintDecl2.getDecl()->getSup().type); } @@ -5440,15 +5616,40 @@ namespace Slang // default: fail return false; } - - bool TryUnifySubstitutions( - ConstraintSystem& constraints, - RefPtr fst, - RefPtr snd) + + bool tryUnifySubstitutions( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) { // They must both be NULL or non-NULL if (!fst || !snd) - return fst == snd; + return !fst && !snd; + + if(auto fstGeneric = fst.As()) + { + if(auto sndGeneric = snd.As()) + { + return tryUnifyGenericSubstitutions( + constraints, + fstGeneric, + sndGeneric); + } + } + + // TODO: need to handle other cases here + + return false; + } + + bool tryUnifyGenericSubstitutions( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) + { + SLANG_ASSERT(fst); + SLANG_ASSERT(snd); + auto fstGen = fst; auto sndGen = snd; // They must be specializing the same generic @@ -5468,7 +5669,7 @@ namespace Slang } // Their "base" specializations must unify - if (!TryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer)) + if (!tryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer)) { okay = false; } @@ -5554,10 +5755,10 @@ namespace Slang // next we need to unify the substitutions applied // to each decalration reference. - if (!TryUnifySubstitutions( + if (!tryUnifySubstitutions( constraints, - fstDeclRef.substitutions.genericSubstitutions, - sndDeclRef.substitutions.genericSubstitutions)) + fstDeclRef.substitutions.substitutions, + sndDeclRef.substitutions.substitutions)) { return false; } @@ -5648,41 +5849,117 @@ namespace Slang // Is the candidate extension declaration actually applicable to the given type DeclRef ApplyExtensionToType( - ExtensionDecl* extDecl, - RefPtr type) + ExtensionDecl* extDecl, + RefPtr type) { + DeclRef extDeclRef = makeDeclRef(extDecl); + + // If the extension is a generic extension, then we + // need to infer type argumenst that will give + // us a target type that matches `type`. + // if (auto extGenericDecl = GetOuterGeneric(extDecl)) { ConstraintSystem constraints; constraints.genericDecl = extGenericDecl; if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) - return DeclRef().As(); + return DeclRef(); auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As()); if (!constraintSubst) { - return DeclRef().As(); + return DeclRef(); } // Consruct a reference to the extension with our constraint variables // set as they were found by solving the constraint system. - DeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As(); + extDeclRef = DeclRef(extDecl, constraintSubst).As(); + } - // We expect/require that the result of unification is such that - // the target types are now equal - SLANG_ASSERT(GetTargetType(extDeclRef)->Equals(type)); + // Now extract the target type from our (possibly specialized) extension decl-ref. + RefPtr targetType = GetTargetType(extDeclRef); - return extDeclRef; - } - else + // As a bit of a kludge here, if the target type of the extension is + // an interface, and the `type` we are trying to match up has a this-type + // substitution for that interface, then we want to attach a matching + // substitution to the extension decl-ref. + if(auto targetDeclRefType = targetType->As()) { - // The easy case is when the extension isn't generic: - // either it applies to the type or not. - if (!type->Equals(extDecl->targetType)) - return DeclRef().As(); - return DeclRef(extDecl, nullptr).As(); + if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.As()) + { + // Okay, the target type is an interface. + // + // Is the type we want to apply to also an interface? + if(auto appDeclRefType = type->As()) + { + if(auto appInterfaceDeclRef = appDeclRefType->declRef.As()) + { + if(appInterfaceDeclRef.getDecl() == targetInterfaceDeclRef.getDecl()) + { + // Looks like we have a match in the types, + // now let's see if we have a this-type substitution. + if(auto appThisTypeSubst = appInterfaceDeclRef.substitutions.substitutions.As()) + { + if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl()) + { + // The type we want to apply to has a this-type substitution, + // and (by construction) the target type currently does not. + // + SLANG_ASSERT(!targetInterfaceDeclRef.substitutions.substitutions.As()); + + // We will create a new substitution to apply to the target type. + RefPtr newTargetSubst = new ThisTypeSubstitution(); + newTargetSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; + newTargetSubst->witness = appThisTypeSubst->witness; + newTargetSubst->outer = targetInterfaceDeclRef.substitutions.substitutions; + + targetType = DeclRefType::Create(getSession(), + DeclRef(targetInterfaceDeclRef.getDecl(), newTargetSubst)); + + // Note: we are constructing a this-type substitution that + // we will apply to the extension declaration as well. + // This is not strictly allowed by our current representation + // choices, but we need it in order to make sure that + // references to the target type of the extension + // declaration have a chance to resolve the way we want them to. + + RefPtr newExtSubst = new ThisTypeSubstitution(); + newExtSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; + newExtSubst->witness = appThisTypeSubst->witness; + newExtSubst->outer = extDeclRef.substitutions.substitutions; + + extDeclRef = DeclRef( + extDeclRef.getDecl(), + newExtSubst); + + // TODO: Ideally we should also apply the chosen specialization to + // the decl-ref for the extension, so that subsequent lookup through + // the members of this extension will retain that substitution and + // be able to apply it. + // + // E.g., if an extension method returns a value of an associated + // type, then we'd want that to become specialized to a concrete + // type when using the extension method on a value of concrete type. + // + // The challenge here that makes me reluctant to just staple on + // such a substitution is that it wouldn't follow our implicit + // rules about where `ThisTypeSubstitution`s can appear. + } + } + } + } + } + } } + + // In order for this extension to apply to the given type, we + // need to have a match on the target types. + if (!type->Equals(targetType)) + return DeclRef(); + + + return extDeclRef; } #if 0 @@ -6033,8 +6310,8 @@ namespace Slang // signature if( parentGenericDeclRef ) { - SLANG_RELEASE_ASSERT(declRef.substitutions); - auto genSubst = declRef.substitutions.genericSubstitutions; + auto genSubst = declRef.substitutions.substitutions.As(); + SLANG_RELEASE_ASSERT(genSubst); SLANG_RELEASE_ASSERT(genSubst->genericDecl == parentGenericDeclRef.getDecl()); sb << "<"; @@ -7166,8 +7443,10 @@ namespace Slang scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); for (auto & module : entryPoint->compileRequest->loadedModulesList) scopesToTry.Add(module->moduleDecl->scope); + + List> globalGenericArgs; for (auto name : entryPoint->genericParameterTypeNames) - { + { // parse type name RefPtr type; for (auto & s : scopesToTry) @@ -7185,9 +7464,10 @@ namespace Slang sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name); return; } - entryPoint->genericParameterTypes.Add(type); + + globalGenericArgs.Add(type); } - + // validate global type arguments only when we are generating code if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) { @@ -7210,38 +7490,102 @@ namespace Slang for (auto p : globalGenParams) globalGenericParams.Add(p); } - if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count()) + + if (globalGenericParams.Count() != globalGenericArgs.Count()) { - sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(), - entryPoint->genericParameterTypes.Count()); + sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, + globalGenericParams.Count(), + globalGenericArgs.Count()); return; } - // if entry-point type arguments matches parameters, try find - // SubtypeWitness for each argument - int index = 0; - for (auto & gParam : globalGenericParams) + + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. + // + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. + // + RefPtr globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; + // + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor`). + // + // The only way to check things corectly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + UInt argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) { - for (auto constraint : gParam->getMembersOfType()) + // Get the argument that matches this parameter. + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < globalGenericArgs.Count()); + auto globalGenericArg = globalGenericArgs[argIndex]; + + // Create a substitution for this parameter/argument. + RefPtr subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType()) { + // Get the type that the constraint is enforcing conformance to auto interfaceType = GetSup(DeclRef(constraint, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType); + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); if (!witness) { - sink->diagnose(gParam, - Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index], + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, interfaceType); } - entryPoint->genericParameterWitnesses.Add(witness); + + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.Add(constraintArg); } - index++; + + // Add the substitution for this parameter to the global substitution + // set that we are building. + + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; } + + entryPoint->globalGenericSubst = globalGenericSubsts; } if (sink->errorCount != 0) return; // Now that we've *found* the entry point, it is time to validate // that it actually meets the constraints for the chosen stage/profile. + // + // TODO: This validation should be performed "under" any global generic + // parameter substitution we might have created, so that we can validate + // based on knowledge of actual types. + // validateEntryPoint(entryPoint); } @@ -7453,6 +7797,43 @@ namespace Slang return semantics->ApplyExtensionToType(extDecl, type); } + RefPtr createDefaultSubsitutionsForGeneric( + Session* session, + GenericDecl* genericDecl, + RefPtr outerSubst) + { + RefPtr genericSubst = new GenericSubstitution(); + genericSubst->genericDecl = genericDecl; + genericSubst->outer = outerSubst; + + for( auto mm : genericDecl->Members ) + { + if( auto genericTypeParamDecl = mm.As() ) + { + genericSubst->args.Add(DeclRefType::Create(session, DeclRef(genericTypeParamDecl.Ptr(), outerSubst))); + } + else if( auto genericValueParamDecl = mm.As() ) + { + genericSubst->args.Add(new GenericParamIntVal(DeclRef(genericValueParamDecl.Ptr(), outerSubst))); + } + } + + // create default substitution arguments for constraints + for (auto mm : genericDecl->Members) + { + if (auto genericTypeConstraintDecl = mm.As()) + { + RefPtr witness = new DeclaredSubtypeWitness(); + witness->declRef = DeclRef(genericTypeConstraintDecl.Ptr(), outerSubst); + witness->sub = genericTypeConstraintDecl->sub.type; + witness->sup = genericTypeConstraintDecl->sup.type; + genericSubst->args.Add(witness); + } + } + + return genericSubst; + } + // Sometimes we need to refer to a declaration the way that it would be specialized // inside the context where it is declared (e.g., with generic parameters filled in // using their archetypes). @@ -7460,53 +7841,25 @@ namespace Slang SubstitutionSet createDefaultSubstitutions( Session* session, Decl* decl, - SubstitutionSet parentSubst) + SubstitutionSet outerSubstSet) { - SubstitutionSet resultSubst = parentSubst; - if (auto interfaceDecl = dynamic_cast(decl)) - { - resultSubst.thisTypeSubstitution = new ThisTypeSubstitution(); - } auto dd = decl->ParentDecl; if( auto genericDecl = dynamic_cast(dd) ) { // We don't want to specialize references to anything // other than the "inner" declaration itself. if(decl != genericDecl->inner) - return resultSubst; + return outerSubstSet; - RefPtr subst = new GenericSubstitution(); - subst->genericDecl = genericDecl; - subst->outer = parentSubst.genericSubstitutions; - resultSubst.genericSubstitutions = subst; - SubstitutionSet outerSubst = resultSubst; - outerSubst.genericSubstitutions = outerSubst.genericSubstitutions?outerSubst.genericSubstitutions->outer:nullptr; - for( auto mm : genericDecl->Members ) - { - if( auto genericTypeParamDecl = mm.As() ) - { - subst->args.Add(DeclRefType::Create(session, DeclRef(genericTypeParamDecl.Ptr(), outerSubst))); - } - else if( auto genericValueParamDecl = mm.As() ) - { - subst->args.Add(new GenericParamIntVal(DeclRef(genericValueParamDecl.Ptr(), outerSubst))); - } - } + RefPtr genericSubst = createDefaultSubsitutionsForGeneric( + session, + genericDecl, + outerSubstSet.substitutions); - // create default substitution arguments for constraints - for (auto mm : genericDecl->Members) - { - if (auto genericTypeConstraintDecl = mm.As()) - { - RefPtr witness = new DeclaredSubtypeWitness(); - witness->declRef = DeclRef(genericTypeConstraintDecl.Ptr(), outerSubst); - witness->sub = genericTypeConstraintDecl->sub.type; - witness->sup = genericTypeConstraintDecl->sup.type; - subst->args.Add(witness); - } - } + return SubstitutionSet(genericSubst); } - return resultSubst; + + return outerSubstSet; } SubstitutionSet createDefaultSubstitutions( diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 7ab47e6b3..703991e36 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -152,10 +152,7 @@ namespace Slang // where any errors were diagnosed. RefPtr decl; - // The declaration of the global generic parameter types - // This will be filled in as part of semantic analysis. - List> genericParameterTypes; - List> genericParameterWitnesses; + RefPtr globalGenericSubst; }; enum class PassThroughMode : SlangPassThrough @@ -453,7 +450,6 @@ namespace Slang RefPtr coreLanguageScope; RefPtr hlslLanguageScope; RefPtr slangLanguageScope; - RefPtr glslLanguageScope; List> loadedModuleCode; @@ -481,7 +477,6 @@ namespace Slang String getStdlibPath(); String getCoreLibraryCode(); String getHLSLLibraryCode(); - String getGLSLLibraryCode(); // Basic types that we don't want to re-create all the time RefPtr errorType; @@ -508,20 +503,6 @@ namespace Slang Type* getErrorType(); Type* getStringType(); - Type* getConstExprRate(); - RefPtr getRateQualifiedType( - Type* rate, - Type* valueType); - - RefPtr getConstExprType( - Type* valueType) - { - return getRateQualifiedType(getConstExprRate(), valueType); - } - - // Should not be used in front-end code - Type* getIRBasicBlockType(); - // Construct the type `Ptr`, where `Ptr` // is looked up as a builtin type. RefPtr getPtrType(RefPtr valueType); @@ -544,8 +525,6 @@ namespace Slang Type* elementType, IntVal* elementCount); - RefPtr getGroupSharedType(RefPtr valueType); - SyntaxClass findSyntaxClass(Name* name); Dictionary > mapNameToSyntaxClass; diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 785ef4406..35ad77f4f 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -101,20 +101,24 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) __generic __magic_type(PtrType) +__intrinsic_type($(kIROp_PtrType)) struct Ptr {}; __generic __magic_type(OutType) +__intrinsic_type($(kIROp_OutType)) struct Out {}; __generic __magic_type(InOutType) +__intrinsic_type($(kIROp_InOutType)) struct InOut {}; __magic_type(StringType) +__intrinsic_type($(kIROp_StringType)) struct String {}; @@ -181,6 +185,7 @@ sb << "__intrinsic_type(" << kIROp_TextureBufferType << ")\n"; sb << "__magic_type(TextureBuffer) struct TextureBuffer {};\n"; sb << "__generic\n"; +sb << "__intrinsic_type(" << kIROp_ParameterBlockType << ")\n"; sb << "__magic_type(ParameterBlockType) struct ParameterBlock {};\n"; static const char* kComponentNames[]{ "x", "y", "z", "w" }; @@ -313,11 +318,11 @@ for( int C = 2; C <= 4; ++C ) sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerState) << ")\n"; -sb << "__intrinsic_type(" << kIROp_SamplerType << ", " << int(SamplerStateFlavor::SamplerState) << ")\n"; +sb << "__intrinsic_type(" << kIROp_SamplerStateType << ")\n"; sb << "struct SamplerState {};"; sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerComparisonState) << ")\n"; -sb << "__intrinsic_type(" << kIROp_SamplerType << ", " << int(SamplerStateFlavor::SamplerComparisonState) << ")\n"; +sb << "__intrinsic_type(" << kIROp_SamplerComparisonStateType << ")\n"; sb << "struct SamplerComparisonState {};"; // TODO(tfoley): Need to handle `RW*` variants of texture types as well... @@ -377,6 +382,7 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) sb << "__generic "; sb << "__magic_type(TextureSampler," << int(flavor) << ")\n"; + sb << "__intrinsic_type(" << (kIROp_FirstTextureSamplerType + flavor) << ")\n"; sb << "struct Sampler"; sb << kBaseTextureAccessLevels[accessLevel].name; sb << name; @@ -434,7 +440,7 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) sb << "__generic "; sb << "__magic_type(Texture," << int(flavor) << ")\n"; - sb << "__intrinsic_type(" << kIROp_TextureType << ", " << flavor << ")\n"; + sb << "__intrinsic_type(" << (kIROp_FirstTextureType + flavor) << ")\n"; sb << "struct "; sb << kBaseTextureAccessLevels[accessLevel].name; sb << name; diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h index bc0fbb53d..bbb258d15 100644 --- a/source/slang/core.meta.slang.h +++ b/source/slang/core.meta.slang.h @@ -101,20 +101,36 @@ SLANG_RAW("\n") SLANG_RAW("\n") SLANG_RAW("__generic\n") SLANG_RAW("__magic_type(PtrType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_PtrType +) +SLANG_RAW(")\n") SLANG_RAW("struct Ptr\n") SLANG_RAW("{};\n") SLANG_RAW("\n") SLANG_RAW("__generic\n") SLANG_RAW("__magic_type(OutType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_OutType +) +SLANG_RAW(")\n") SLANG_RAW("struct Out\n") SLANG_RAW("{};\n") SLANG_RAW("\n") SLANG_RAW("__generic\n") SLANG_RAW("__magic_type(InOutType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_InOutType +) +SLANG_RAW(")\n") SLANG_RAW("struct InOut\n") SLANG_RAW("{};\n") SLANG_RAW("\n") SLANG_RAW("__magic_type(StringType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_StringType +) +SLANG_RAW(")\n") SLANG_RAW("struct String\n") SLANG_RAW("{};\n") SLANG_RAW("\n") @@ -181,6 +197,7 @@ sb << "__intrinsic_type(" << kIROp_TextureBufferType << ")\n"; sb << "__magic_type(TextureBuffer) struct TextureBuffer {};\n"; sb << "__generic\n"; +sb << "__intrinsic_type(" << kIROp_ParameterBlockType << ")\n"; sb << "__magic_type(ParameterBlockType) struct ParameterBlock {};\n"; static const char* kComponentNames[]{ "x", "y", "z", "w" }; @@ -313,11 +330,11 @@ for( int C = 2; C <= 4; ++C ) sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerState) << ")\n"; -sb << "__intrinsic_type(" << kIROp_SamplerType << ", " << int(SamplerStateFlavor::SamplerState) << ")\n"; +sb << "__intrinsic_type(" << kIROp_SamplerStateType << ")\n"; sb << "struct SamplerState {};"; sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerComparisonState) << ")\n"; -sb << "__intrinsic_type(" << kIROp_SamplerType << ", " << int(SamplerStateFlavor::SamplerComparisonState) << ")\n"; +sb << "__intrinsic_type(" << kIROp_SamplerComparisonStateType << ")\n"; sb << "struct SamplerComparisonState {};"; // TODO(tfoley): Need to handle `RW*` variants of texture types as well... @@ -377,6 +394,7 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) sb << "__generic "; sb << "__magic_type(TextureSampler," << int(flavor) << ")\n"; + sb << "__intrinsic_type(" << (kIROp_FirstTextureSamplerType + flavor) << ")\n"; sb << "struct Sampler"; sb << kBaseTextureAccessLevels[accessLevel].name; sb << name; @@ -434,7 +452,7 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) sb << "__generic "; sb << "__magic_type(Texture," << int(flavor) << ")\n"; - sb << "__intrinsic_type(" << kIROp_TextureType << ", " << flavor << ")\n"; + sb << "__intrinsic_type(" << (kIROp_FirstTextureType + flavor) << ")\n"; sb << "struct "; sb << kBaseTextureAccessLevels[accessLevel].name; sb << name; diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index 76480e64b..2f4f5abd3 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -108,7 +108,7 @@ SYNTAX_CLASS(InheritanceDecl, TypeConstraintDecl) // required by the base type to their concrete // implementations in the type that contains // this inheritance declaration. - Dictionary, DeclRef> requirementWitnesses; + RefPtr witnessTable; virtual TypeExp& getSup() override { return base; diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 15f295740..28fb0b551 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -282,7 +282,6 @@ static EOpInfo const* const kInfixOpInfos[] = &kEOp_Mod, }; - // // represents a declarator for use in emitting types @@ -302,16 +301,10 @@ struct EDeclarator SourceLoc loc; // Used for `Flavor::Array` - IntVal* elementCount; -}; - -struct TypeEmitArg -{ - EDeclarator* declarator; + IRInst* elementCount; }; struct EmitVisitor - : TypeVisitorWithArg { EmitContext* context; EmitVisitor(EmitContext* context) @@ -466,23 +459,6 @@ struct EmitVisitor emitName(name, SourceLoc()); } - void emitName( - Decl* decl, - SourceLoc const& loc) - { - if(auto name = decl->getName()) - emitName(name, loc); - - Emit("_S"); - Emit(getID(decl)); - } - - void emitName( - Decl* decl) - { - emitName(decl, SourceLoc()); - } - void Emit(IntegerLiteralValue value) { char buffer[32]; @@ -752,22 +728,6 @@ struct EmitVisitor // Types // - void Emit(RefPtr val) - { - if(auto constantIntVal = val.As()) - { - Emit(constantIntVal->value); - } - else if(auto varRefVal = val.As()) - { - EmitDeclRef(varRefVal->declRef); - } - else - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unknown type of integer constant value"); - } - } - void EmitDeclarator(EDeclarator* declarator) { if (!declarator) return; @@ -785,7 +745,7 @@ struct EmitVisitor Emit("["); if(auto elementCount = declarator->elementCount) { - Emit(elementCount); + EmitVal(elementCount); } Emit("]"); break; @@ -802,41 +762,35 @@ struct EmitVisitor } void emitGLSLTypePrefix( - RefPtr type) + IRType* type) { - if(auto basicElementType = type->As()) + switch (type->op) { - switch (basicElementType->baseType) - { - case BaseType::Float: - // no prefix - break; + case kIROp_FloatType: + // no prefix + break; - case BaseType::Int: Emit("i"); break; - case BaseType::UInt: Emit("u"); break; - case BaseType::Bool: Emit("b"); break; - case BaseType::Double: Emit("d"); break; - default: - SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled GLSL type prefix"); - break; - } - } - else if(auto vectorType = type->As()) - { - emitGLSLTypePrefix(vectorType->elementType); - } - else if(auto matrixType = type->As()) - { - emitGLSLTypePrefix(matrixType->getElementType()); - } - else - { + case kIROp_IntType: Emit("i"); break; + case kIROp_UIntType: Emit("u"); break; + case kIROp_BoolType: Emit("b"); break; + case kIROp_DoubleType: Emit("d"); break; + + case kIROp_VectorType: + emitGLSLTypePrefix(cast(type)->getElementType()); + break; + + case kIROp_MatrixType: + emitGLSLTypePrefix(cast(type)->getElementType()); + break; + + default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled GLSL type prefix"); + break; } } void emitHLSLTextureType( - RefPtr texType) + IRTextureTypeBase* texType) { switch(texType->getAccess()) { @@ -885,15 +839,15 @@ struct EmitVisitor Emit("Array"); } Emit("<"); - EmitType(texType->elementType); + EmitType(texType->getElementType()); Emit(" >"); } void emitGLSLTextureOrTextureSamplerType( - RefPtr type, - char const* baseName) + IRTextureTypeBase* type, + char const* baseName) { - emitGLSLTypePrefix(type->elementType); + emitGLSLTypePrefix(type->getElementType()); Emit(baseName); switch (type->GetBaseShape()) @@ -919,7 +873,7 @@ struct EmitVisitor } void emitGLSLTextureType( - RefPtr texType) + IRTextureType* texType) { switch(texType->getAccess()) { @@ -935,19 +889,19 @@ struct EmitVisitor } void emitGLSLTextureSamplerType( - RefPtr type) + IRTextureSamplerType* type) { emitGLSLTextureOrTextureSamplerType(type, "sampler"); } void emitGLSLImageType( - RefPtr type) + IRGLSLImageType* type) { emitGLSLTextureOrTextureSamplerType(type, "image"); } void emitTextureType( - RefPtr texType) + IRTextureType* texType) { switch(context->shared->target) { @@ -966,7 +920,7 @@ struct EmitVisitor } void emitTextureSamplerType( - RefPtr type) + IRTextureSamplerType* type) { switch(context->shared->target) { @@ -981,7 +935,7 @@ struct EmitVisitor } void emitImageType( - RefPtr type) + IRGLSLImageType* type) { switch(context->shared->target) { @@ -999,79 +953,27 @@ struct EmitVisitor } } - void emitTypeImpl(RefPtr type, EDeclarator* declarator) - { - TypeEmitArg arg; - arg.declarator = declarator; - - TypeVisitorWithArg::dispatch(type, arg); - } - -#define UNEXPECTED(NAME) \ - void visit##NAME(NAME*, TypeEmitArg const& arg) \ - { Emit(#NAME); EmitDeclarator(arg.declarator); } - - UNEXPECTED(ErrorType); - UNEXPECTED(OverloadGroupType); - UNEXPECTED(FuncType); - UNEXPECTED(TypeType); - UNEXPECTED(GenericDeclRefType); - UNEXPECTED(InitializerListType); - - UNEXPECTED(IRBasicBlockType); - UNEXPECTED(PtrType); - -#undef UNEXPECTED - - void visitNamedExpressionType(NamedExpressionType* type, TypeEmitArg const& arg) - { - // We will always emit the actual type referenced by - // a named type declaration, rather than try to produce - // equivalent `typedef` declarations in the output. - - emitTypeImpl(GetType(type->declRef), arg.declarator); - } - - void visitBasicExpressionType(BasicExpressionType* basicType, TypeEmitArg const& arg) - { - auto declarator = arg.declarator; - switch (basicType->baseType) - { - case BaseType::Void: Emit("void"); break; - case BaseType::Int: Emit("int"); break; - case BaseType::Float: Emit("float"); break; - case BaseType::UInt: Emit("uint"); break; - case BaseType::Bool: Emit("bool"); break; - case BaseType::Double: Emit("double"); break; - default: - SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled scalar type"); - break; - } - - EmitDeclarator(declarator); - } - void visitVectorExpressionType(VectorExpressionType* vecType, TypeEmitArg const& arg) + void emitVectorTypeImpl(IRVectorType* vecType) { - auto declarator = arg.declarator; switch(context->shared->target) { case CodeGenTarget::GLSL: case CodeGenTarget::GLSL_Vulkan: case CodeGenTarget::GLSL_Vulkan_OneDesc: { - emitGLSLTypePrefix(vecType->elementType); + emitGLSLTypePrefix(vecType->getElementType()); Emit("vec"); - Emit(vecType->elementCount); + EmitVal(vecType->getElementCount()); } break; case CodeGenTarget::HLSL: // TODO(tfoley): should really emit these with sugar Emit("vector<"); - EmitType(vecType->elementType); + EmitType(vecType->getElementType()); Emit(","); - Emit(vecType->elementCount); + EmitVal(vecType->getElementCount()); Emit(">"); break; @@ -1079,13 +981,10 @@ struct EmitVisitor SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled code generation target"); break; } - - EmitDeclarator(declarator); } - void visitMatrixExpressionType(MatrixExpressionType* matType, TypeEmitArg const& arg) + void emitMatrixTypeImpl(IRMatrixType* matType) { - auto declarator = arg.declarator; switch(context->shared->target) { case CodeGenTarget::GLSL: @@ -1094,11 +993,11 @@ struct EmitVisitor { emitGLSLTypePrefix(matType->getElementType()); Emit("mat"); - Emit(matType->getRowCount()); + EmitVal(matType->getRowCount()); // TODO(tfoley): only emit the next bit // for non-square matrix Emit("x"); - Emit(matType->getColumnCount()); + EmitVal(matType->getColumnCount()); } break; @@ -1107,9 +1006,9 @@ struct EmitVisitor Emit("matrix<"); EmitType(matType->getElementType()); Emit(","); - Emit(matType->getRowCount()); + EmitVal(matType->getRowCount()); Emit(","); - Emit(matType->getColumnCount()); + EmitVal(matType->getColumnCount()); Emit("> "); break; @@ -1117,42 +1016,18 @@ struct EmitVisitor SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled code generation target"); break; } - - EmitDeclarator(declarator); - } - - void visitTextureType(TextureType* texType, TypeEmitArg const& arg) - { - auto declarator = arg.declarator; - emitTextureType(texType); - EmitDeclarator(declarator); - } - - void visitTextureSamplerType(TextureSamplerType* textureSamplerType, TypeEmitArg const& arg) - { - auto declarator = arg.declarator; - emitTextureSamplerType(textureSamplerType); - EmitDeclarator(declarator); } - void visitGLSLImageType(GLSLImageType* imageType, TypeEmitArg const& arg) + void emitSamplerStateType(IRSamplerStateTypeBase* samplerStateType) { - auto declarator = arg.declarator; - emitImageType(imageType); - EmitDeclarator(declarator); - } - - void visitSamplerStateType(SamplerStateType* samplerStateType, TypeEmitArg const& arg) - { - auto declarator = arg.declarator; switch(context->shared->target) { case CodeGenTarget::HLSL: default: - switch (samplerStateType->flavor) + switch (samplerStateType->op) { - case SamplerStateFlavor::SamplerState: Emit("SamplerState"); break; - case SamplerStateFlavor::SamplerComparisonState: Emit("SamplerComparisonState"); break; + case kIROp_SamplerStateType: Emit("SamplerState"); break; + case kIROp_SamplerComparisonStateType: Emit("SamplerComparisonState"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled sampler state flavor"); break; @@ -1160,10 +1035,10 @@ struct EmitVisitor break; case CodeGenTarget::GLSL: - switch (samplerStateType->flavor) + switch (samplerStateType->op) { - case SamplerStateFlavor::SamplerState: Emit("sampler"); break; - case SamplerStateFlavor::SamplerComparisonState: Emit("samplerShadow"); break; + case kIROp_SamplerStateType: Emit("sampler"); break; + case kIROp_SamplerComparisonStateType: Emit("samplerShadow"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled sampler state flavor"); break; @@ -1171,69 +1046,217 @@ struct EmitVisitor break; break; } + } + + void emitStructuredBufferType(IRHLSLStructuredBufferTypeBase* type) + { + switch(context->shared->target) + { + case CodeGenTarget::HLSL: + default: + { + switch (type->op) + { + case kIROp_HLSLStructuredBufferType: Emit("StructuredBuffer"); break; + case kIROp_HLSLRWStructuredBufferType: Emit("RWStructuredBuffer"); break; + case kIROp_HLSLAppendStructuredBufferType: Emit("AppendStructuredBuffer"); break; + case kIROp_HLSLConsumeStructuredBufferType: Emit("ConsumeStructuredBuffer"); break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled structured buffer type"); + break; + } - EmitDeclarator(declarator); + Emit("<"); + EmitType(type->getElementType()); + Emit(" >"); + } + break; + + case CodeGenTarget::GLSL: + // TODO: We desugar global variables with structured-buffer type into GLSL + // `buffer` declarations, but we don't currently handle structured-buffer types + // in other contexts (e.g., as function parameters). The simplest thing to do + // would be to emit a `StructuredBuffer` as `Foo[]` and `RWStructuredBuffer` + // as `in out Foo[]`, but that is starting to get into the realm of transformations + // that should really be handled during legalization, rather than during emission. + // + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "structured buffer type used unexpectedly"); + break; + } } - void visitDeclRefType(DeclRefType* declRefType, TypeEmitArg const& arg) + void emitUntypedBufferType(IRUntypedBufferResourceType* type) { - auto declarator = arg.declarator; - EmitDeclRef(declRefType->declRef); - EmitDeclarator(declarator); + switch(context->shared->target) + { + case CodeGenTarget::HLSL: + default: + { + switch (type->op) + { + case kIROp_HLSLByteAddressBufferType: Emit("ByteAddressBuffer"); break; + case kIROp_HLSLRWByteAddressBufferType: Emit("RWByteAddressBuffer"); break; + case kIROp_RaytracingAccelerationStructureType: Emit("RaytracingAccelerationStructureType"); break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); + break; + } + } + break; + + case CodeGenTarget::GLSL: + { + switch (type->op) + { + case kIROp_HLSLByteAddressBufferType: Emit("ByteAddressBuffer"); break; + case kIROp_HLSLRWByteAddressBufferType: Emit("RWByteAddressBuffer"); break; + case kIROp_RaytracingAccelerationStructureType: Emit("RaytracingAccelerationStructureType"); break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); + break; + } + } + break; + } } - void visitArrayExpressionType(ArrayExpressionType* arrayType, TypeEmitArg const& arg) + void emitSimpleTypeImpl(IRType* type) { - auto declarator = arg.declarator; + switch (type->op) + { + default: + break; - EDeclarator arrayDeclarator; - arrayDeclarator.next = declarator; + case kIROp_VoidType: Emit("void"); return; + case kIROp_IntType: Emit("int"); return; + case kIROp_UIntType: Emit("uint"); return; + case kIROp_BoolType: Emit("bool"); return; + case kIROp_HalfType: Emit("half"); return; + case kIROp_FloatType: Emit("float"); return; + case kIROp_DoubleType: Emit("double"); return; - if(arrayType->ArrayLength) + case kIROp_VectorType: + emitVectorTypeImpl((IRVectorType*)type); + return; + + case kIROp_MatrixType: + emitMatrixTypeImpl((IRMatrixType*)type); + return; + + case kIROp_SamplerStateType: + case kIROp_SamplerComparisonStateType: + emitSamplerStateType(cast(type)); + return; + + case kIROp_StructType: + emit(getIRName(type)); + return; + } + + // TODO: Ideally the following should be data-driven, + // based on meta-data attached to the definitions of + // each of these IR opcodes. + + if (auto texType = as(type)) { - arrayDeclarator.flavor = EDeclarator::Flavor::Array; - arrayDeclarator.elementCount = arrayType->ArrayLength.Ptr(); + emitTextureType(texType); + return; } - else + else if (auto textureSamplerType = as(type)) + { + emitTextureSamplerType(textureSamplerType); + return; + } + else if (auto imageType = as(type)) + { + emitImageType(imageType); + return; + } + else if (auto structuredBufferType = as(type)) + { + emitStructuredBufferType(structuredBufferType); + return; + } + else if(auto untypedBufferType = as(type)) { - arrayDeclarator.flavor = EDeclarator::Flavor::UnsizedArray; + emitUntypedBufferType(untypedBufferType); + return; } + // HACK: As a fallback for HLSL targets, assume that the name of the + // instruction being used is the same as the name of the HLSL type. + if(context->shared->target == CodeGenTarget::HLSL) + { + auto opInfo = getIROpInfo(type->op); + emit(opInfo.name); + UInt operandCount = type->getOperandCount(); + if(operandCount) + { + emit("<"); + for(UInt ii = 0; ii < operandCount; ++ii) + { + if(ii != 0) emit(", "); + EmitVal(type->getOperand(ii)); + } + emit(" >"); + } + + return; + } - emitTypeImpl(arrayType->baseType, &arrayDeclarator); + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled type"); } - void visitRateQualifiedType(RateQualifiedType* type, TypeEmitArg const& arg) + void emitArrayTypeImpl(IRArrayType* arrayType, EDeclarator* declarator) { - emitTypeImpl(type->valueType, arg.declarator); + EDeclarator arrayDeclarator; + arrayDeclarator.flavor = EDeclarator::Flavor::Array; + arrayDeclarator.next = declarator; + arrayDeclarator.elementCount = arrayType->getElementCount(); + + emitTypeImpl(arrayType->getElementType(), &arrayDeclarator); } - void visitConstExprRate(ConstExprRate* /*rate*/, TypeEmitArg const& /*arg*/) + void emitUnsizedArrayTypeImpl(IRUnsizedArrayType* arrayType, EDeclarator* declarator) { - // This should never appear as a data type - SLANG_UNEXPECTED("Rates not expected during emit"); + EDeclarator arrayDeclarator; + arrayDeclarator.flavor = EDeclarator::Flavor::UnsizedArray; + arrayDeclarator.next = declarator; + + emitTypeImpl(arrayType->getElementType(), &arrayDeclarator); } - void visitGroupSharedType(GroupSharedType* type, TypeEmitArg const& arg) + void emitTypeImpl(IRType* type, EDeclarator* declarator) { - switch(getTarget(context)) + switch (type->op) { - case CodeGenTarget::HLSL: - Emit("groupshared "); + default: + emitSimpleTypeImpl(type); + EmitDeclarator(declarator); break; - case CodeGenTarget::GLSL: - Emit("shared "); + case kIROp_RateQualifiedType: + { + auto rateQualifiedType = cast(type); + emitTypeImpl(rateQualifiedType->getValueType(), declarator); + } + + case kIROp_ArrayType: + emitArrayTypeImpl(cast(type), declarator); break; - default: + case kIROp_UnsizedArrayType: + emitUnsizedArrayTypeImpl(cast(type), declarator); break; } - emitTypeImpl(type->valueType, arg.declarator); + } void EmitType( - RefPtr type, + IRType* type, SourceLoc const& typeLoc, Name* name, SourceLoc const& nameLoc) @@ -1247,12 +1270,12 @@ struct EmitVisitor emitTypeImpl(type, &nameDeclarator); } - void EmitType(RefPtr type, Name* name) + void EmitType(IRType* type, Name* name) { EmitType(type, SourceLoc(), name, SourceLoc()); } - void EmitType(RefPtr type, String const& name) + void EmitType(IRType* type, String const& name) { // HACK: the rest of the code wants a `Name`, // so we'll create one for a bit... @@ -1263,7 +1286,7 @@ struct EmitVisitor } - void EmitType(RefPtr type) + void EmitType(IRType* type) { emitTypeImpl(type, nullptr); } @@ -1300,6 +1323,20 @@ struct EmitVisitor } } + void EmitType(IRType* type, Name* name, SourceLoc const& nameLoc) + { + EmitType( + type, + SourceLoc(), + name, + nameLoc); + } + + void EmitType(IRType* type, NameLoc const& nameAndLoc) + { + EmitType(type, nameAndLoc.name, nameAndLoc.loc); + } + bool isTargetIntrinsicModifierApplicable( IRTargetIntrinsicDecoration* decoration) { @@ -1407,80 +1444,18 @@ struct EmitVisitor } } - // - // Declaration References - // - - void EmitVal(RefPtr val) + void EmitVal(IRInst* val) { - if (auto type = val.As()) + if(auto type = as(val)) { EmitType(type); } - else if (auto intVal = val.As()) - { - Emit(intVal); - } else { - // Note(tfoley): ignore unhandled cases for semantics for now... - // assert(!"unimplemented"); + emitIRInstExpr(context, val, IREmitMode::Default); } } - bool isBuiltinDecl(Decl* decl) - { - for (auto dd = decl; dd; dd = dd->ParentDecl) - { - if (dd->FindModifier()) - return true; - } - return false; - } - - void EmitDeclRef(DeclRef declRef) - { - // When refering to anything other than a builtin, use its IR-facing name - if (!isBuiltinDecl(declRef.getDecl())) - { - emit(getIRName(declRef)); - return; - } - - - // TODO: need to qualify a declaration name based on parent scopes/declarations - - // Emit the name for the declaration itself - emitName(declRef.GetName()); - - // If the declaration is nested directly in a generic, then - // we need to output the generic arguments here - auto parentDeclRef = declRef.GetParent(); - if (auto genericDeclRef = parentDeclRef.As()) - { - // Only do this for declarations of appropriate flavors - if(auto funcDeclRef = declRef.As()) - { - // Don't emit generic arguments for functions, because HLSL doesn't allow them - return; - } - - GenericSubstitution* subst = declRef.substitutions.genericSubstitutions; - if (!subst) - return; - - Emit("<"); - UInt argCount = subst->args.Count(); - for (UInt aa = 0; aa < argCount; ++aa) - { - if (aa != 0) Emit(","); - EmitVal(subst->args[aa]); - } - Emit(" >"); - } - - } - typedef unsigned int ESemanticMask; enum { @@ -1491,50 +1466,6 @@ struct EmitVisitor kESemanticMask_Default = kESemanticMask_NoPackOffset, }; - void EmitSemantic(RefPtr semantic, ESemanticMask /*mask*/) - { - if (auto simple = semantic.As()) - { - Emit(" : "); - emit(simple->name.Content); - } - else if(auto registerSemantic = semantic.As()) - { - // Don't print out semantic from the user, since we are going to print the same thing our own way... - } - else if(auto packOffsetSemantic = semantic.As()) - { - // Don't print out semantic from the user, since we are going to print the same thing our own way... - } - else - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), semantic->loc, "unhandled kind of semantic"); - } - } - - - void EmitSemantics(RefPtr decl, ESemanticMask mask = kESemanticMask_Default ) - { - // Don't emit semantics if we aren't translating down to HLSL - switch (context->shared->target) - { - case CodeGenTarget::HLSL: - break; - - default: - return; - } - - for (auto mod = decl->modifiers.first; mod; mod = mod->next) - { - auto semantic = mod.As(); - if (!semantic) - continue; - - EmitSemantic(semantic, mask); - } - } - // A chain of variables to use for emitting semantic/layout info struct EmitVarChain { @@ -1851,7 +1782,6 @@ struct EmitVisitor } } - void emitGLSLVersionDirective( ModuleDecl* /*program*/) { @@ -1949,19 +1879,6 @@ struct EmitVisitor return context->shared->uniqueIDCounter++; } - UInt getID(Decl* decl) - { - auto& mapDeclToID = context->shared->mapDeclToID; - - UInt id = 0; - if(mapDeclToID.TryGetValue(decl, id)) - return id; - - id = allocateUniqueID(); - mapDeclToID.Add(decl, id); - return id; - } - // IR-level emit logc UInt getID(IRInst* value) @@ -1977,105 +1894,25 @@ struct EmitVisitor return id; } - String getIRName(Decl* decl) - { - // TODO: need a flag to get rid of the step that adds - // a prefix here, so that we can get "clean" output - // when needed. - // - - String name; - if (!(context->shared->entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_MANGLING)) - { - name.append("_s"); - } - name.append(getText(decl->getName())); - return name; - } - - String getIRName(DeclRefBase const& declRef) - { - // In general, when referring to a declaration that has been lowered - // via the IR, we want to use its mangled name. - // - // There are two main exceptions to this: - // - // 1. For debugging, we accept the `-no-mangle` flag which basically - // instructs us to try to use the original name of all declarations, - // to make the output more like what is expected to come out of - // fxc pass-through. This case should get deprecated some day. - // - // 2. It is really annoying to have the fields of a `struct` type - // get ridiculously lengthy mangled names, and this also messes - // up stuff like specialization (since the mangled name of a field - // would then include the mangled name of the outer type). - // - - String name; - if (context->shared->entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_MANGLING) - { - // Special case (1): - name.append(getText(declRef.GetName())); - return name; - } - - // Special case (2) - if (declRef.GetParent().decl->As()) - { - name.append(declRef.decl->nameAndLoc.name->text); - return name; - } - // General case: - name.append(getMangledName(declRef)); - return name; - } - String getIRName( IRInst* inst) { - switch(inst->op) - { - case kIROp_decl_ref: - { - auto irDeclRef = (IRDeclRef*) inst; - return getIRName(irDeclRef->declRef); - } - break; - - default: - break; - } - - if(auto decoration = inst->findDecoration()) - { - auto decl = decoration->decl; - if (auto reflectionNameMod = decl->FindModifier()) - { - return getText(reflectionNameMod->nameAndLoc.name); - } - - if ((context->shared->entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_MANGLING)) - { - return getIRName(decl); - } - } - - switch (inst->op) + // If the instruction has a mangled name, then emit using that. + if (auto globalValue = as(inst)) { - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_Func: + auto mangledName = globalValue->mangledName; + if (mangledName) { - auto& mangledName = ((IRGlobalValue*)inst)->mangledName; - if(getText(mangledName).Length() != 0) + auto mangledNameText = getText(mangledName); + if (mangledNameText.Length() != 0) + { return getText(mangledName); + } } - break; - - default: - break; } + // Otherwise fall back to a construct temporary name + // for the instruction. StringBuilder sb; sb << "_S"; sb << getID(inst); @@ -2180,8 +2017,8 @@ struct EmitVisitor break; case kIROp_Var: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: case kIROp_Param: return false; @@ -2190,7 +2027,7 @@ struct EmitVisitor case kIROp_boolConst: case kIROp_FieldAddress: case kIROp_getElementPtr: - case kIROp_specialize: + case kIROp_Specialize: case kIROp_BufferElementRef: return true; } @@ -2204,23 +2041,23 @@ struct EmitVisitor // variables. auto type = inst->getDataType(); - while (auto ptrType = type->As()) + while (auto ptrType = as(type)) { type = ptrType->getValueType(); } - if(type->As()) + if(as(type)) { // TODO: we need to be careful here, because // HLSL shader model 6 allows these as explicit // types. return true; } - else if (type->As()) + else if (as(type)) { return true; } - else if (type->As()) + else if (as(type)) { return true; } @@ -2231,15 +2068,15 @@ struct EmitVisitor // to fold them into their use sites in all cases if (getTarget(ctx) == CodeGenTarget::GLSL) { - if(type->As()) + if(as(type)) { return true; } - else if(type->As()) + else if(as(type)) { return true; } - else if(type->As()) + else if(as(type)) { return true; } @@ -2255,7 +2092,7 @@ struct EmitVisitor { auto type = inst->getDataType(); - if(type->As() && !type->As()) + if(as(type) && !as(type)) { // TODO: we need to be careful here, because // HLSL shader model 6 allows these as explicit @@ -2332,11 +2169,11 @@ struct EmitVisitor void emitIRRateQualifiers( EmitContext* ctx, - Type* rate) + IRRate* rate) { if(!rate) return; - if( auto constExprRate = rate->As() ) + if(as(rate)) { switch( getTarget(ctx) ) { @@ -2348,6 +2185,23 @@ struct EmitVisitor break; } } + + if (as(rate)) + { + switch( getTarget(ctx) ) + { + case CodeGenTarget::HLSL: + Emit("groupshared "); + break; + + case CodeGenTarget::GLSL: + Emit("shared "); + break; + + default: + break; + } + } } void emitIRRateQualifiers( @@ -2366,7 +2220,7 @@ struct EmitVisitor if(!type) return; - if (type->Equals(getSession()->getVoidType())) + if (as(type)) return; emitIRRateQualifiers(ctx, inst); @@ -2708,13 +2562,13 @@ struct EmitVisitor auto textureArg = args[0].get(); auto samplerArg = args[1].get(); - if (auto baseTextureType = textureArg->type->As()) + if (auto baseTextureType = as(textureArg->getDataType())) { emitGLSLTextureOrTextureSamplerType(baseTextureType, "sampler"); - if (auto samplerType = samplerArg->type->As()) + if (auto samplerType = as(samplerArg->getDataType())) { - if (samplerType->flavor == SamplerStateFlavor::SamplerComparisonState) + if (as(samplerType)) { Emit("Shadow"); } @@ -2746,7 +2600,7 @@ struct EmitVisitor // We are going to hack this *hard* for now. auto textureArg = args[0].get(); - if (auto baseTextureType = textureArg->type->As()) + if (auto baseTextureType = as(textureArg->getDataType())) { emitGLSLTextureOrTextureSamplerType(baseTextureType, "sampler"); Emit("("); @@ -2772,18 +2626,18 @@ struct EmitVisitor SLANG_RELEASE_ASSERT(argCount >= 1); auto textureArg = args[0].get(); - if (auto baseTextureType = textureArg->type->As()) + if (auto baseTextureType = as(textureArg->getDataType())) { - auto elementType = baseTextureType->elementType; - if (auto basicType = elementType->As()) + auto elementType = baseTextureType->getElementType(); + if (auto basicType = as(elementType)) { // A scalar result is expected Emit(".x"); } - else if (auto vectorType = elementType->As()) + else if (auto vectorType = as(elementType)) { // A vector result is expected - auto elementCount = GetIntVal(vectorType->elementCount); + auto elementCount = GetIntVal(vectorType->getElementCount()); if (elementCount < 4) { @@ -2813,9 +2667,9 @@ struct EmitVisitor SLANG_RELEASE_ASSERT(argCount > argIndex); auto vectorArg = args[argIndex].get(); - if (auto vectorType = vectorArg->type->As()) + if (auto vectorType = as(vectorArg->getDataType())) { - auto elementCount = GetIntVal(vectorType->elementCount); + auto elementCount = GetIntVal(vectorType->getElementCount()); Emit(elementCount); } else @@ -2850,7 +2704,7 @@ struct EmitVisitor UInt operandIndex = 1; - // + // if (auto targetIntrinsicDecoration = findTargetIntrinsicDecoration(ctx, func)) { emitTargetIntrinsicCallExpr( @@ -2869,7 +2723,29 @@ struct EmitVisitor // be better strategies (including just stuffing // a pointer to the original decl onto the callee). - UnmangleContext um(getText(func->mangledName)); + // If the intrinsic the user is calling is a generic, + // then the mangled name will have been set on the + // outer-most generic, and not on the leaf value + // (which is `func` above), so we need to walk + // upwards to find it. + // + IRGlobalValue* valueForName = func; + for(;;) + { + auto parentBlock = as(valueForName->parent); + if(!parentBlock) + break; + + auto parentGeneric = as(parentBlock->parent); + if(!parentGeneric) + break; + + valueForName = parentGeneric; + } + + // We will use the `UnmangleContext` utility to + // help us split the original name into its pieces. + UnmangleContext um(getText(valueForName->mangledName)); um.startUnmangling(); // We'll read through the qualified name of the @@ -3075,8 +2951,8 @@ struct EmitVisitor case kIROp_Mul: // Are we targetting GLSL, and are both operands matrices? if(getTarget(ctx) == CodeGenTarget::GLSL - && inst->getOperand(0)->type->As() - && inst->getOperand(1)->type->As()) + && as(inst->getOperand(0)->getDataType()) + && as(inst->getOperand(1)->getDataType())) { emit("matrixCompMult("); emitIROperand(ctx, inst->getOperand(0), mode); @@ -3096,7 +2972,7 @@ struct EmitVisitor case kIROp_Not: { - if (inst->getDataType()->Equals(getSession()->getBoolType())) + if (as(inst->getDataType())) { emit("!"); } @@ -3248,7 +3124,7 @@ struct EmitVisitor } break; - case kIROp_specialize: + case kIROp_Specialize: { emitIROperand(ctx, inst->getOperand(0), mode); } @@ -3322,8 +3198,8 @@ struct EmitVisitor case kIROp_Var: { - auto ptrType = inst->getDataType(); - auto valType = ((PtrType*)ptrType)->getValueType(); + auto ptrType = cast(inst->getDataType()); + auto valType = ptrType->getValueType(); auto name = getIRName(inst); emitIRType(ctx, valType, name); @@ -3379,7 +3255,22 @@ struct EmitVisitor emitIROperand(ctx, inst->getOperand(1), mode); emit(";\n"); } - break; + break; + } + } + + void emitIRSemantics( + EmitContext*, + VarLayout* varLayout) + { + if(varLayout->flags & VarLayoutFlag::HasSemantic) + { + Emit(" : "); + emit(varLayout->semanticName); + if(varLayout->semanticIndex) + { + Emit(varLayout->semanticIndex); + } } } @@ -3397,31 +3288,24 @@ struct EmitVisitor return; } - if(auto layoutDecoration = inst->findDecoration()) + if (auto semanticDecoration = inst->findDecoration()) { - if(auto varLayout = layoutDecoration->layout.As()) - { - if(varLayout->flags & VarLayoutFlag::HasSemantic) - { - Emit(" : "); - emit(varLayout->semanticName); - if(varLayout->semanticIndex) - { - Emit(varLayout->semanticIndex); - } - - return; - } - } + Emit(" : "); + emit(semanticDecoration->semanticName); + return; } - // TODO(tfoley): should we ever need to use the high-level declaration - // for this? It seems like the wrong approach... - - auto decoration = inst->findDecoration(); - if( decoration ) + if(auto layoutDecoration = inst->findDecoration()) { - EmitSemantics(decoration->decl); + auto layout = layoutDecoration->layout; + if(auto varLayout = layout.As()) + { + emitIRSemantics(ctx, varLayout); + } + else if (auto entryPointLayout = layout.As()) + { + emitIRSemantics(ctx, entryPointLayout->resultLayout); + } } } @@ -3502,7 +3386,7 @@ struct EmitVisitor // may exit this region with operations that do *not* branch // to `end`, but such non-local control flow will hopefully // be captured. - // + // void emitIRStmtsForBlocks( EmitContext* ctx, IRBlock* begin, @@ -4003,7 +3887,7 @@ struct EmitVisitor return getText(entryPointLayout->entryPoint->getName()); } - // + // return "main"; } @@ -4250,7 +4134,7 @@ struct EmitVisitor auto name = getIRFuncName(func); - emitIRType(ctx, resultType, name); + EmitType(resultType, name); emit("("); auto firstParam = func->getFirstParam(); @@ -4312,19 +4196,19 @@ struct EmitVisitor void emitIRParamType( EmitContext* ctx, - Type* type, + IRType* type, String const& name) { // An `out` or `inout` parameter will have been // encoded as a parameter of pointer type, so // we need to decode that here. // - if( auto outType = type->As() ) + if( auto outType = as(type)) { emit("out "); type = outType->getValueType(); } - else if( auto inOutType = type->As() ) + else if( auto inOutType = as(type)) { emit("inout "); type = inOutType->getValueType(); @@ -4333,16 +4217,29 @@ struct EmitVisitor emitIRType(ctx, type, name); } + IRInst* getSpecializedValue(IRSpecialize* specInst) + { + auto base = specInst->getBase(); + auto baseGeneric = as(base); + if (!baseGeneric) + return base; + + auto lastBlock = baseGeneric->getLastBlock(); + if (!lastBlock) + return base; + + auto returnInst = as(lastBlock->getTerminator()); + if (!returnInst) + return base; + + return returnInst->getVal(); + } + void emitIRFuncDecl( EmitContext* ctx, IRFunc* func) { - // We don't want to declare generic functions, - // because none of our targets actually support them. - if(func->getGenericDecl()) - return; - - // We also don't want to emit declarations for operations + // We don't want to emit declarations for operations // that only appear in the IR as stand-ins for built-in // operations on that target. if (isTargetIntrinsic(ctx, func)) @@ -4361,7 +4258,7 @@ struct EmitVisitor // and as a result it *also* doesn't have the IR `param` instructions, // so we need to emit a declaration entirely from the type. - auto funcType = func->getType(); + auto funcType = func->getDataType(); auto resultType = func->getResultType(); auto name = getIRFuncName(func); @@ -4432,9 +4329,9 @@ struct EmitVisitor if(!value) return nullptr; - if(value->op == kIROp_specialize) + while (auto specInst = as(value)) { - value = ((IRSpecialize*) value)->genericVal.get(); + value = getSpecializedValue(specInst); } if(value->op != kIROp_Func) @@ -4451,11 +4348,6 @@ struct EmitVisitor EmitContext* ctx, IRFunc* func) { - if(func->getGenericDecl()) - { - return; - } - if(!isDefinition(func)) { // This is just a function declaration, @@ -4479,27 +4371,39 @@ struct EmitVisitor } } -#if 0 void emitIRStruct( - EmitContext* context, - IRStructDecl* structType) + EmitContext* ctx, + IRStructType* structType) { emit("struct "); - emit(getName(structType)); + emit(getIRName(structType)); emit("\n{\n"); + indent(); - for(auto ff = structType->getFirstField(); ff; ff = ff->getNextField()) + for(auto ff : structType->getFields()) { + auto fieldKey = ff->getKey(); auto fieldType = ff->getFieldType(); - emitIRType(context, fieldType, getName(ff)); - emitIRSemantics(context, ff); + // Filter out fields with `void` type that might + // have been introduced by legalization. + if(as(fieldType)) + continue; + + // Note: GLSL doesn't support interpolation modifiers on `struct` fields + if( ctx->shared->target != CodeGenTarget::GLSL ) + { + emitInterpolationModifiers(ctx, fieldKey, fieldType); + } + emitIRType(ctx, fieldType, getIRName(fieldKey)); + emitIRSemantics(ctx, fieldKey); emit(";\n"); } + + dedent(); emit("};\n"); } -#endif void emitIRMatrixLayoutModifiers( EmitContext* ctx, @@ -4552,7 +4456,7 @@ struct EmitVisitor default: break; } - + } } @@ -4561,26 +4465,22 @@ struct EmitVisitor // of the variable is an integer type. void maybeEmitGLSLFlatModifier( EmitContext*, - Type* valueType) + IRType* valueType) { auto tt = valueType; - if(auto vecType = tt->As()) - tt = vecType->elementType; - if(auto vecType = tt->As()) + if(auto vecType = as(tt)) + tt = vecType->getElementType(); + if(auto vecType = as(tt)) tt = vecType->getElementType(); - auto baseType = tt->As(); - if(!baseType) - return; - - switch(baseType->baseType) + switch(tt->op) { default: break; - case BaseType::Int: - case BaseType::UInt: - case BaseType::UInt64: + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_UInt64Type: Emit("flat "); break; } @@ -4588,36 +4488,51 @@ struct EmitVisitor void emitInterpolationModifiers( EmitContext* ctx, - VarDeclBase* decl, - Type* valueType) + IRInst* varInst, + IRType* valueType) { bool isGLSL = (ctx->shared->target == CodeGenTarget::GLSL); bool anyModifiers = false; - if(decl->FindModifier()) - { - anyModifiers = true; - Emit(isGLSL ? "flat " : "nointerpolation "); - } - else if(decl->FindModifier()) - { - anyModifiers = true; - Emit("noperspective "); - } - else if(decl->FindModifier()) - { - anyModifiers = true; - Emit(isGLSL ? "smooth " : "linear "); - } - else if(decl->FindModifier()) - { - anyModifiers = true; - Emit("sample "); - } - else if(decl->FindModifier()) + anyModifiers = true; + for(auto dd = varInst->firstDecoration; dd; dd = dd->next) { - anyModifiers = true; - Emit("centroid "); + if(dd->op != kIRDecorationOp_InterpolationMode) + continue; + + auto decoration = (IRInterpolationModeDecoration*)dd; + auto mode = decoration->mode; + + switch(mode) + { + case IRInterpolationMode::NoInterpolation: + anyModifiers = true; + Emit(isGLSL ? "flat " : "nointerpolation "); + break; + + case IRInterpolationMode::NoPerspective: + anyModifiers = true; + Emit("noperspective "); + break; + + case IRInterpolationMode::Linear: + anyModifiers = true; + Emit(isGLSL ? "smooth " : "linear "); + break; + + case IRInterpolationMode::Sample: + anyModifiers = true; + Emit("sample "); + break; + + case IRInterpolationMode::Centroid: + anyModifiers = true; + Emit("centroid "); + break; + + default: + break; + } } // If the user didn't explicitly qualify a varying @@ -4629,18 +4544,11 @@ struct EmitVisitor } } - void emitInterpolationModifiers( - EmitContext* ctx, - VarLayout* layout, - Type* valueType) - { - emitInterpolationModifiers(ctx, layout->varDecl, valueType); - } - void emitIRVarModifiers( EmitContext* ctx, VarLayout* layout, - Type* valueType) + IRInst* varDecl, + IRType* varType) { if (!layout) return; @@ -4651,7 +4559,7 @@ struct EmitVisitor // for an HLSL `RWTexture*` then we need to emit a `format` layout qualifier. if(getTarget(context) == CodeGenTarget::GLSL) { - if(auto resourceType = unwrapArray(valueType).As()) + if(auto resourceType = as(unwrapArray(varType))) { switch(resourceType->getAccess()) { @@ -4676,6 +4584,12 @@ struct EmitVisitor } } + if(layout->FindResourceInfo(LayoutResourceKind::VaryingInput) + || layout->FindResourceInfo(LayoutResourceKind::VaryingOutput)) + { + emitInterpolationModifiers(ctx, varDecl, varType); + } + if (ctx->shared->target == CodeGenTarget::GLSL) { // Layout-related modifiers need to come before the declaration, @@ -4696,20 +4610,12 @@ struct EmitVisitor case LayoutResourceKind::VaryingInput: { emit("in "); - if(layout->stage == Stage::Fragment) - { - maybeEmitGLSLFlatModifier(ctx, valueType); - } } break; - case LayoutResourceKind::FragmentOutput: + case LayoutResourceKind::VaryingOutput: { emit("out "); - if(layout->stage != Stage::Fragment) - { - maybeEmitGLSLFlatModifier(ctx, valueType); - } } break; @@ -4723,9 +4629,9 @@ struct EmitVisitor } void emitHLSLParameterBlock( - EmitContext* ctx, - IRGlobalVar* varDecl, - ParameterBlockType* type) + EmitContext* ctx, + IRGlobalVar* varDecl, + IRParameterBlockType* type) { emit("cbuffer "); @@ -4768,11 +4674,11 @@ struct EmitVisitor } void emitHLSLParameterGroup( - EmitContext* ctx, - IRGlobalVar* varDecl, - UniformParameterGroupType* type) + EmitContext* ctx, + IRGlobalVar* varDecl, + IRUniformParameterGroupType* type) { - if(auto parameterBlockType = type->As()) + if(auto parameterBlockType = as(type)) { emitHLSLParameterBlock(ctx, varDecl, parameterBlockType); return; @@ -4805,45 +4711,52 @@ struct EmitVisitor auto elementType = type->getElementType(); - - if(auto declRefType = elementType->As()) + if(auto structType = as(elementType)) { - if(auto structDeclRef = declRefType->declRef.As()) + auto structTypeLayout = typeLayout.As(); + assert(structTypeLayout); + + UInt fieldIndex = 0; + for(auto ff : structType->getFields()) { - auto structTypeLayout = typeLayout.As(); - assert(structTypeLayout); + // TODO: need a plan to deal with the case where the IR-level + // `struct` type might not match the high-level type, so that + // the numbering of fields is different. + // + // The right plan is probably to require that the lowering pass + // create a fresh layout for any type/variable that it splits + // in this fashion, so that the layout information it attaches + // can always be assumed to apply to the actual instruciton. + // - UInt fieldIndex = 0; - for(auto ff : GetFields(structDeclRef)) - { - // TODO: need a plan to deal with the case where the IR-level - // `struct` type might not match the high-level type, so that - // the numbering of fields is different. - // - // The right plan is probably to require that the lowering pass - // create a fresh layout for any type/variable that it splits - // in this fashion, so that the layout information it attaches - // can always be assumed to apply to the actual instruciton. - // + auto fieldLayout = structTypeLayout->fields[fieldIndex++]; - auto fieldLayout = structTypeLayout->fields[fieldIndex++]; + auto fieldKey = ff->getKey(); + auto fieldType = ff->getFieldType(); - auto fieldType = GetType(ff); - if(fieldType->Equals(getSession()->getVoidType())) - continue; + // Fields of `void` type aren't valid in HLSL/GLSL. + // + // TODO: legalization should get rid of any fields that have + // empty, or effectively empty types (e.g., emptry structs + // should be translated over to `void`). + if(as(fieldType)) + continue; - emitIRVarModifiers(ctx, fieldLayout, fieldType); + emitIRVarModifiers(ctx, fieldLayout, fieldKey, fieldType); - emitIRType(ctx, fieldType, getIRName(ff)); + emitIRType(ctx, fieldType, getIRName(fieldKey)); - emitHLSLParameterGroupFieldLayoutSemantics(fieldLayout, &elementChain); + emitHLSLParameterGroupFieldLayoutSemantics(fieldLayout, &elementChain); - emit(";\n"); - } + emit(";\n"); } } else { + // TODO: during legalization we should turn `ParameterGroup` where `X` + // is not a `struct` type into `ParameterGroup` where `S` is defined + // as something like `struct S { X _; };` + // emit("/* unexpected */"); } @@ -4852,9 +4765,9 @@ struct EmitVisitor } void emitGLSLParameterBlock( - EmitContext* ctx, - IRGlobalVar* varDecl, - ParameterBlockType* type) + EmitContext* ctx, + IRGlobalVar* varDecl, + IRParameterBlockType* type) { auto varLayout = getVarLayout(ctx, varDecl); assert(varLayout); @@ -4893,11 +4806,11 @@ struct EmitVisitor } void emitGLSLParameterGroup( - EmitContext* ctx, - IRGlobalVar* varDecl, - UniformParameterGroupType* type) + EmitContext* ctx, + IRGlobalVar* varDecl, + IRUniformParameterGroupType* type) { - if(auto parameterBlockType = type->As()) + if(auto parameterBlockType = as(type)) { emitGLSLParameterBlock(ctx, varDecl, parameterBlockType); return; @@ -4922,7 +4835,7 @@ struct EmitVisitor emitGLSLLayoutQualifier(LayoutResourceKind::DescriptorTableSlot, &containerChain); - if(type->As()) + if(as(type)) { emit("layout(std430) buffer "); } @@ -4939,52 +4852,50 @@ struct EmitVisitor auto elementType = type->getElementType(); - if(auto declRefType = elementType->As()) + if(auto structType = as(elementType)) { - if(auto structDeclRef = declRefType->declRef.As()) - { - auto structTypeLayout = typeLayout.As(); - assert(structTypeLayout); + auto structTypeLayout = typeLayout.As(); + assert(structTypeLayout); - UInt fieldIndex = 0; - for(auto ff : GetFields(structDeclRef)) - { - // TODO: need a plan to deal with the case where the IR-level - // `struct` type might not match the high-level type, so that - // the numbering of fields is different. - // - // The right plan is probably to require that the lowering pass - // create a fresh layout for any type/variable that it splits - // in this fashion, so that the layout information it attaches - // can always be assumed to apply to the actual instruciton. - // + UInt fieldIndex = 0; + for(auto ff : structType->getFields()) + { + // TODO: need a plan to deal with the case where the IR-level + // `struct` type might not match the high-level type, so that + // the numbering of fields is different. + // + // The right plan is probably to require that the lowering pass + // create a fresh layout for any type/variable that it splits + // in this fashion, so that the layout information it attaches + // can always be assumed to apply to the actual instruciton. + // - auto fieldLayout = structTypeLayout->fields[fieldIndex++]; + auto fieldLayout = structTypeLayout->fields[fieldIndex++]; - auto fieldType = GetType(ff); - if(fieldType->Equals(getSession()->getVoidType())) - continue; + auto fieldKey = ff->getKey(); + auto fieldType = ff->getFieldType(); + if(as(fieldType)) + continue; - // Note: we will emit matrix-layout modifiers here, but - // we will refrain from emitting other modifiers that - // might not be appropriate to the context (e.g., we - // shouldn't go emitting `uniform` just because these - // things are uniform...). - // - // TODO: we need a more refined set of modifiers that - // we should allow on fields, because we might end - // up supporting layout that isn't the default for - // the given block type (e.g., something other than - // `std140` for a uniform block). - // - emitIRMatrixLayoutModifiers(ctx, fieldLayout); + // Note: we will emit matrix-layout modifiers here, but + // we will refrain from emitting other modifiers that + // might not be appropriate to the context (e.g., we + // shouldn't go emitting `uniform` just because these + // things are uniform...). + // + // TODO: we need a more refined set of modifiers that + // we should allow on fields, because we might end + // up supporting layout that isn't the default for + // the given block type (e.g., something other than + // `std140` for a uniform block). + // + emitIRMatrixLayoutModifiers(ctx, fieldLayout); - emitIRType(ctx, fieldType, getIRName(ff)); + emitIRType(ctx, fieldType, getIRName(fieldKey)); // emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); - emit(";\n"); - } + emit(";\n"); } } else @@ -5002,9 +4913,9 @@ struct EmitVisitor } void emitIRParameterGroup( - EmitContext* ctx, - IRGlobalVar* varDecl, - UniformParameterGroupType* type) + EmitContext* ctx, + IRGlobalVar* varDecl, + IRUniformParameterGroupType* type) { switch (ctx->shared->target) { @@ -5042,8 +4953,8 @@ struct EmitVisitor // Need to emit appropriate modifiers here. auto layout = getVarLayout(ctx, varDecl); - - emitIRVarModifiers(ctx, layout, varType); + + emitIRVarModifiers(ctx, layout, varDecl, varType); #if 0 switch (addressSpace) @@ -5067,12 +4978,12 @@ struct EmitVisitor emit(";\n"); } - RefPtr unwrapArray(Type* type) + IRType* unwrapArray(IRType* type) { - Type* t = type; - while( auto arrayType = t->As() ) + IRType* t = type; + while( auto arrayType = as(t) ) { - t = arrayType->baseType; + t = arrayType->getElementType(); } return t; } @@ -5080,7 +4991,7 @@ struct EmitVisitor void emitIRStructuredBuffer_GLSL( EmitContext* ctx, IRGlobalVar* varDecl, - HLSLStructuredBufferTypeBase* structuredBufferType) + IRHLSLStructuredBufferTypeBase* structuredBufferType) { // Shader storage buffer is an OpenGL 430 feature // @@ -5145,7 +5056,7 @@ struct EmitVisitor // Emit a blank line so that the formatting is nicer. emit("\n"); - if (auto paramBlockType = varType->As()) + if (auto paramBlockType = as(varType)) { emitIRParameterGroup( ctx, @@ -5158,7 +5069,7 @@ struct EmitVisitor { // When outputting GLSL, we need to transform any declaration of // a `*StructuredBuffer` into an ordinary `buffer` declaration. - if( auto structuredBufferType = unwrapArray(varType)->As() ) + if( auto structuredBufferType = as(unwrapArray(varType)) ) { emitIRStructuredBuffer_GLSL( ctx, @@ -5205,7 +5116,7 @@ struct EmitVisitor } } - emitIRVarModifiers(ctx, layout, varType); + emitIRVarModifiers(ctx, layout, varDecl, varType); emitIRType(ctx, varType, getIRName(varDecl)); @@ -5282,11 +5193,11 @@ struct EmitVisitor emitIRFunc(ctx, (IRFunc*) inst); break; - case kIROp_global_var: + case kIROp_GlobalVar: emitIRGlobalVar(ctx, (IRGlobalVar*) inst); break; - case kIROp_global_constant: + case kIROp_GlobalConstant: emitIRGlobalConstant(ctx, (IRGlobalConstant*) inst); break; @@ -5294,202 +5205,158 @@ struct EmitVisitor emitIRVar(ctx, (IRVar*) inst); break; + case kIROp_StructType: + emitIRStruct(ctx, cast(inst)); + break; + default: break; } } - void ensureStructDecl( - EmitContext* ctx, - DeclRef declRef) + // An action to be performed during code emit. + struct EmitAction { - auto mangledName = getMangledName(declRef); - if(ctx->shared->irDeclsVisited.Contains(mangledName)) - return; - - ctx->shared->irDeclsVisited.Add(mangledName); - - // First emit any types used by fields of this type - for( auto ff : GetFields(declRef) ) + enum Level { - if(ff.getDecl()->HasModifier()) - continue; + ForwardDeclaration, + Definition, + }; + Level level; + IRInst* inst; + }; - auto fieldType = GetType(ff); - emitIRUsedType(ctx, fieldType); - } + struct ComputeEmitActionsContext + { + IRInst* moduleInst; + HashSet openInsts; + Dictionary mapInstToLevel; + List* actions; + }; - // Don't emit declarations for types that should be built-in on the target. - // - // TODO: This should really be checking if the type is a target intrinsic - // for the chosen target, and not just whether it is globally declared - // as a builtin (so that we can have types that are builtin in some cases, - // but not others). - if(declRef.getDecl()->HasModifier()) - return; + void ensureInstOperand( + ComputeEmitActionsContext* ctx, + IRInst* inst, + EmitAction::Level requiredLevel = EmitAction::Level::Definition) + { + if(!inst) return; - Emit("\nstruct "); - EmitDeclRef(declRef); - Emit("\n{\n"); - indent(); - for( auto ff : GetFields(declRef) ) + if(inst->getParent() == ctx->moduleInst) { - if(ff.getDecl()->HasModifier()) - continue; - - auto fieldType = GetType(ff); - - // Skip `void` fields that might have been created by legalization. - if(fieldType->Equals(getSession()->getVoidType())) - continue; - - // Note: GLSL doesn't support interpolation modifiers on `struct` fields - if( ctx->shared->target != CodeGenTarget::GLSL ) - { - emitInterpolationModifiers(ctx, ff.getDecl(), fieldType); - } - emitIRType(ctx, fieldType, getIRName(ff)); - - EmitSemantics(ff.getDecl()); - - emit(";\n"); + ensureGlobalInst(ctx, inst, requiredLevel); } - dedent(); - Emit("};\n"); } - void emitIRUsedDeclRef( - EmitContext* ctx, - DeclRef declRef) + void ensureInstOperandsRec( + ComputeEmitActionsContext* ctx, + IRInst* inst) { - auto decl = declRef.getDecl(); + ensureInstOperand(ctx, inst->getFullType()); - if(decl->HasModifier() - || decl->HasModifier()) + UInt operandCount = inst->operandCount; + for(UInt ii = 0; ii < operandCount; ++ii) { - return; + // TODO: there are some special cases we can add here, + // to avoid outputting full definitions in cases that + // can get by with forward declarations. + // + // For example, true pointer types should (in principle) + // only need the type they point to to be forward-declared. + // Similarly, a `call` instruction only needs the callee + // to be forward-declared, etc. + + ensureInstOperand(ctx, inst->getOperand(ii)); } - if( auto structDeclRef = declRef.As() ) + if(auto parentInst = as(inst)) { - // - ensureStructDecl(ctx, structDeclRef); + for(auto child : parentInst->getChildren()) + { + ensureInstOperandsRec(ctx, child); + } } } - // A type is going to be used by the IR, so - // make sure that we have emitted whatever - // it needs. - void emitIRUsedType( - EmitContext* ctx, - Type* type) + void ensureGlobalInst( + ComputeEmitActionsContext* ctx, + IRInst* inst, + EmitAction::Level requiredLevel) { - if(type->As()) - {} - else if(type->As()) - {} - else if(type->As()) - {} - else if(auto arrayType = type->As()) - { - emitIRUsedType(ctx, arrayType->baseType); - } - else if( auto textureType = type->As() ) - { - emitIRUsedType(ctx, textureType->elementType); - } - else if( auto genericType = type->As() ) - { - emitIRUsedType(ctx, genericType->elementType); - } - else if( auto ptrType = type->As() ) - { - emitIRUsedType(ctx, ptrType->getValueType()); - } - else if(type->As() ) - { - } - else if( auto declRefType = type->As() ) + // Skip certain instrutions, since they + // don't affect output. + switch(inst->op) { - auto declRef = declRefType->declRef; - emitIRUsedDeclRef(ctx, declRef); + case kIROp_WitnessTable: + case kIROp_Generic: + return; + + default: + break; } - else - {} - } - void emitIRUsedTypesForGlobalValueWithCode( - EmitContext* ctx, - IRGlobalValueWithCode* value) - { - for( auto bb = value->getFirstBlock(); bb; bb = bb->getNextBlock() ) + // Have we already processed this instruction? + EmitAction::Level existingLevel; + if(ctx->mapInstToLevel.TryGetValue(inst, existingLevel)) { - for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() ) - { - emitIRUsedTypesForValue(ctx, pp); - } - - for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) - { - emitIRUsedTypesForValue(ctx, ii); - } + // If we've already emitted it suitably, + // then don't worry about it. + if(existingLevel >= requiredLevel) + return; } - } - void emitIRUsedTypesForValue( - EmitContext* ctx, - IRInst* value) - { - if(!value) return; - switch( value->op ) + EmitAction action; + action.level = requiredLevel; + action.inst = inst; + + if(requiredLevel == EmitAction::Level::Definition) { - case kIROp_Func: + if(ctx->openInsts.Contains(inst)) { - auto irFunc = (IRFunc*) value; + SLANG_UNEXPECTED("circularity during codegen"); + return; + } - // Don't emit anything for a generic function, - // since we only care about the types used by - // the actual specializations. - if (irFunc->getGenericDecl()) - return; + ctx->openInsts.Add(inst); - emitIRUsedType(ctx, irFunc->getResultType()); + ensureInstOperandsRec(ctx, inst); - emitIRUsedTypesForGlobalValueWithCode(ctx, irFunc); - } - break; + ctx->openInsts.Remove(inst); + } - case kIROp_global_var: - { - auto irGlobal = (IRGlobalVar*) value; - emitIRUsedType(ctx, irGlobal->type); - emitIRUsedTypesForGlobalValueWithCode(ctx, irGlobal); - } - break; + ctx->mapInstToLevel[inst] = requiredLevel; + ctx->actions->Add(action); + } - case kIROp_global_constant: - { - auto irGlobal = (IRGlobalConstant*) value; - emitIRUsedType(ctx, irGlobal->type); - emitIRUsedTypesForGlobalValueWithCode(ctx, irGlobal); - } - break; + void computeIREmitActions( + IRModule* module, + List& ioActions) + { + ComputeEmitActionsContext ctx; + ctx.moduleInst = module->getModuleInst(); + ctx.actions = &ioActions; - default: - { - emitIRUsedType(ctx, value->type); - } - break; + for(auto inst : module->getGlobalInsts()) + { + ensureGlobalInst(&ctx, inst, EmitAction::Level::Definition); } } - void emitIRUsedTypesForModule( - EmitContext* ctx, - IRModule* module) + void executeIREmitActions( + EmitContext* ctx, + List const& actions) { - for(auto ii : module->getGlobalInsts()) + for(auto action : actions) { - emitIRUsedTypesForValue(ctx, ii); + switch(action.level) + { + case EmitAction::Level::ForwardDeclaration: + emitIRFuncDecl(ctx, cast(action.inst)); + break; + + case EmitAction::Level::Definition: + emitIRGlobalInst(ctx, action.inst); + break; + } } } @@ -5497,27 +5364,16 @@ struct EmitVisitor EmitContext* ctx, IRModule* module) { - emitIRUsedTypesForModule(ctx, module); - - // Before we emit code, we need to forward-declare - // all of our functions so that we don't have to - // sort them by dependencies. - for(auto ii : module->getGlobalInsts()) - { - if(ii->op != kIROp_Func) - continue; + // The IR will usually come in an order that respects + // dependencies between global declarations, but this + // isn't guaranteed, so we need to be careful about + // the order in which we emit things. - auto func = (IRFunc*) ii; - emitIRFuncDecl(ctx, func); - } + List actions; - for(auto ii : module->getGlobalInsts()) - { - emitIRGlobalInst(ctx, ii); - } + computeIREmitActions(module, actions); + executeIREmitActions(ctx, actions); } - - }; // @@ -5614,7 +5470,7 @@ String emitEntryPoint( TargetRequest* targetRequest) { auto translationUnit = entryPoint->getTranslationUnit(); - + SharedEmitContext sharedContext; sharedContext.target = target; sharedContext.finalTarget = targetRequest->target; @@ -5651,19 +5507,26 @@ String emitEntryPoint( target, targetRequest); { - TypeLegalizationContext typeLegalizationContext; - typeLegalizationContext.session = entryPoint->compileRequest->mSession; - IRModule* irModule = getIRModule(irSpecializationState); auto compileRequest = translationUnit->compileRequest; + auto session = compileRequest->mSession; - typeLegalizationContext.irModule = irModule; + TypeLegalizationContext typeLegalizationContext; + initialize(&typeLegalizationContext, + session, + irModule); specializeIRForEntryPoint( irSpecializationState, entryPoint, &sharedContext.extensionUsageTracker); +#if 0 + fprintf(stderr, "### CLONED:\n"); + dumpIR(irModule); + fprintf(stderr, "###\n"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); // If the user specified the flag that they want us to dump @@ -5685,15 +5548,16 @@ String emitEntryPoint( // Debugging code for IR transformations... #if 0 fprintf(stderr, "### SPECIALIZED:\n"); - dumpIR(lowered); + dumpIR(irModule); fprintf(stderr, "###\n"); #endif + validateIRModuleIfEnabled(compileRequest, irModule); // After we've fully specialized all generics, and // "devirtualized" all the calls through interfaces, // we need to ensure that the code only uses types // that are legal on the chosen target. - // + // legalizeTypes( &typeLegalizationContext, irModule); @@ -5701,9 +5565,10 @@ String emitEntryPoint( // Debugging output of legalization #if 0 fprintf(stderr, "### LEGALIZED:\n"); - dumpIR(lowered); + dumpIR(irModule); fprintf(stderr, "###\n"); #endif + validateIRModuleIfEnabled(compileRequest, irModule); // Once specialization and type legalization have been performed, // we should perform some of our basic optimization steps again, @@ -5712,6 +5577,11 @@ String emitEntryPoint( // so that we can work with the individual fields). constructSSA(irModule); +#if 0 + fprintf(stderr, "### AFTER SSA:\n"); + dumpIR(irModule); + fprintf(stderr, "###\n"); +#endif validateIRModuleIfEnabled(compileRequest, irModule); // After all of the required optimization and legalization @@ -5721,9 +5591,9 @@ String emitEntryPoint( // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? visitor.emitIRModule(&context, irModule); - + // retain the specialized ir module, because the current - // GlobalGenericParamSubstitution implementation may reference ir objects + // GlobalGenericParamSubstitution implementation may reference ir objects targetRequest->compileRequest->compiledModules.Add(irModule); } destroyIRSpecializationState(irSpecializationState); @@ -5755,7 +5625,7 @@ String emitEntryPoint( finalResultBuilder << code; String finalResult = finalResultBuilder.ProduceString(); - + return finalResult; } diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang deleted file mode 100644 index a1ee2d9cf..000000000 --- a/source/slang/glsl.meta.slang +++ /dev/null @@ -1,202 +0,0 @@ -// Slang GLSL compatibility library - -${{{{ -static const struct { - char const* name; - char const* glslPrefix; -} kTypes[] = -{ - {"float", ""}, - {"int", "i"}, - {"uint", "u"}, - {"bool", "b"}, -}; -static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); - -for( int tt = 0; tt < kTypeCount; ++tt ) -{ - // Declare GLSL aliases for HLSL types - for (int vv = 2; vv <= 4; ++vv) - { - sb << "typedef vector<" << kTypes[tt].name << "," << vv << "> " << kTypes[tt].glslPrefix << "vec" << vv << ";\n"; - sb << "typedef matrix<" << kTypes[tt].name << "," << vv << "," << vv << "> " << kTypes[tt].glslPrefix << "mat" << vv << ";\n"; - } - for (int rr = 2; rr <= 4; ++rr) - for (int cc = 2; cc <= 4; ++cc) - { - sb << "typedef matrix<" << kTypes[tt].name << "," << rr << "," << cc << "> " << kTypes[tt].glslPrefix << "mat" << rr << "x" << cc << ";\n"; - } -} - -// Multiplication operations for vectors + matrices - -// scalar-vector and vector-scalar -sb << "__generic __intrinsic_op(mul) vector operator*(vector x, T y);\n"; -sb << "__generic __intrinsic_op(mul) vector operator*(T x, vector y);\n"; - -// scalar-matrix and matrix-scalar -sb << "__generic __intrinsic_op(mul) matrix operator*(matrix x, T y);\n"; -sb << "__generic __intrinsic_op(mul) matrix operator*(T x, matrix y);\n"; - -// vector-vector (dot product) -sb << "__generic __intrinsic_op(dot) T operator*(vector x, vector y);\n"; - -// vector-matrix -sb << "__generic __intrinsic_op(mul) vector operator*(vector x, matrix y);\n"; - -// matrix-vector -sb << "__generic __intrinsic_op(mul) vector operator*(matrix x, vector y);\n"; - -// matrix-matrix -sb << "__generic __intrinsic_op(mul) matrix operator*(matrix x, matrix y);\n"; - - - -// - -// TODO(tfoley): Need to handle `RW*` variants of texture types as well... -static const struct { - char const* name; - TextureFlavor::Shape baseShape; - int coordCount; -} kBaseTextureTypes[] = { - { "1D", TextureFlavor::Shape::Shape1D, 1 }, - { "2D", TextureFlavor::Shape::Shape2D, 2 }, - { "3D", TextureFlavor::Shape::Shape3D, 3 }, - { "Cube", TextureFlavor::Shape::ShapeCube, 3 }, - { "Buffer", TextureFlavor::Shape::ShapeBuffer, 1 }, -}; -static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); - - -static const struct { - char const* name; - SlangResourceAccess access; -} kBaseTextureAccessLevels[] = { - { "", SLANG_RESOURCE_ACCESS_READ }, - { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, - { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, -}; -static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); - -for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) -{ - char const* shapeName = kBaseTextureTypes[tt].name; - TextureFlavor::Shape baseShape = kBaseTextureTypes[tt].baseShape; - - for (int isArray = 0; isArray < 2; ++isArray) - { - // Arrays of 3D textures aren't allowed - if (isArray && baseShape == TextureFlavor::Shape::Shape3D) continue; - - for (int isMultisample = 0; isMultisample < 2; ++isMultisample) - { - auto readAccess = SLANG_RESOURCE_ACCESS_READ; - auto readWriteAccess = SLANG_RESOURCE_ACCESS_READ_WRITE; - - // TODO: any constraints to enforce on what gets to be multisampled? - - - unsigned flavor = baseShape; - if (isArray) flavor |= TextureFlavor::ArrayFlag; - if (isMultisample) flavor |= TextureFlavor::MultisampleFlag; -// if (isShadow) flavor |= TextureFlavor::ShadowFlag; - - - - unsigned readFlavor = flavor | (readAccess << 8); - unsigned readWriteFlavor = flavor | (readWriteAccess << 8); - - StringBuilder nameBuilder; - nameBuilder << shapeName; - if (isMultisample) nameBuilder << "MS"; - if (isArray) nameBuilder << "Array"; - auto name = nameBuilder.ProduceString(); - - sb << "__generic "; - sb << "__magic_type(TextureSampler," << int(readFlavor) << ") struct "; - sb << "__sampler" << name; - sb << " {};\n"; - - sb << "__generic "; - sb << "__magic_type(Texture," << int(readFlavor) << ") struct "; - sb << "__texture" << name; - sb << " {};\n"; - - sb << "__generic "; - sb << "__magic_type(GLSLImageType," << int(readWriteFlavor) << ") struct "; - sb << "__image" << name; - sb << " {};\n"; - - // TODO(tfoley): flesh this out for all the available prefixes - static const struct - { - char const* prefix; - char const* elementType; - } kTextureElementTypes[] = { - { "", "vec4" }, - { "i", "ivec4" }, - { "u", "uvec4" }, - { nullptr, nullptr }, - }; - for( auto ee = kTextureElementTypes; ee->prefix; ++ee ) - { - sb << "typedef __sampler" << name << "<" << ee->elementType << "> " << ee->prefix << "sampler" << name << ";\n"; - sb << "typedef __texture" << name << "<" << ee->elementType << "> " << ee->prefix << "texture" << name << ";\n"; - sb << "typedef __image" << name << "<" << ee->elementType << "> " << ee->prefix << "image" << name << ";\n"; - } - } - } -} - -sb << "__generic __magic_type(GLSLInputParameterGroupType) struct __GLSLInputParameterGroup {};\n"; -sb << "__generic __magic_type(GLSLOutputParameterGroupType) struct __GLSLOutputParameterGroup {};\n"; -sb << "__generic __magic_type(GLSLShaderStorageBufferType) struct __GLSLShaderStorageBuffer {};\n"; - -sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerState) << ") struct sampler {};"; - -sb << "__magic_type(GLSLInputAttachmentType) struct subpassInput {};"; - -// Define additional keywords - -sb << "syntax buffer : GLSLBufferModifier;\n"; - -// [GLSL 4.3] Storage Qualifiers - -// TODO: need to support `shared` here with its GLSL meaning - -sb << "syntax patch : GLSLPatchModifier;\n"; -// `centroid` and `sample` handled centrally - -// [GLSL 4.5] Interpolation Qualifiers -sb << "syntax smooth : SimpleModifier;\n"; -sb << "syntax flat : SimpleModifier;\n"; -sb << "syntax noperspective : SimpleModifier;\n"; - - -// [GLSL 4.3.2] Constant Qualifier - -// We need to handle GLSL `const` separately from HLSL `const`, -// since they mean such different things. - -// [GLSL 4.7.2] Precision Qualifiers -sb << "syntax highp : SimpleModifier;\n"; -sb << "syntax mediump : SimpleModifier;\n"; -sb << "syntax lowp : SimpleModifier;\n"; - -// [GLSL 4.8.1] The Invariant Qualifier - -sb << "syntax invariant : SimpleModifier;\n"; - -// [GLSL 4.10] Memory Qualifiers - -sb << "syntax coherent : SimpleModifier;\n"; -sb << "syntax volatile : SimpleModifier;\n"; -sb << "syntax restrict : SimpleModifier;\n"; -sb << "syntax readonly : GLSLReadOnlyModifier;\n"; -sb << "syntax writeonly : GLSLWriteOnlyModifier;\n"; - -// We will treat `subroutine` as a qualifier for now -sb << "syntax subroutine : SimpleModifier;\n"; -}}}} - diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 75d9a3d33..977cb54b9 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2,7 +2,10 @@ typedef uint UINT; -__generic __magic_type(HLSLAppendStructuredBufferType) struct AppendStructuredBuffer +__generic +__magic_type(HLSLAppendStructuredBufferType) +__intrinsic_type($(kIROp_HLSLAppendStructuredBufferType)) +struct AppendStructuredBuffer { void Append(T value); @@ -11,7 +14,9 @@ __generic __magic_type(HLSLAppendStructuredBufferType) struct AppendStructure out uint stride); }; -__magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer +__magic_type(HLSLByteAddressBufferType) +__intrinsic_type($(kIROp_HLSLByteAddressBufferType)) +struct ByteAddressBuffer { void GetDimensions( out uint dim); @@ -31,7 +36,7 @@ __magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer __generic __magic_type(HLSLStructuredBufferType) -__intrinsic_type($(kIROp_structuredBufferType)) +__intrinsic_type($(kIROp_HLSLStructuredBufferType)) struct StructuredBuffer { void GetDimensions( @@ -44,7 +49,10 @@ struct StructuredBuffer __subscript(uint index) -> T { __intrinsic_op(bufferLoad) get; }; }; -__generic __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer +__generic +__magic_type(HLSLConsumeStructuredBufferType) +__intrinsic_type($(kIROp_HLSLConsumeStructuredBufferType)) +struct ConsumeStructuredBuffer { T Consume(); @@ -53,17 +61,25 @@ __generic __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructu out uint stride); }; -__generic __magic_type(HLSLInputPatchType) struct InputPatch +__generic +__magic_type(HLSLInputPatchType) +__intrinsic_type($(kIROp_HLSLInputPatchType)) +struct InputPatch { __subscript(uint index) -> T; }; -__generic __magic_type(HLSLOutputPatchType) struct OutputPatch +__generic +__magic_type(HLSLOutputPatchType) +__intrinsic_type($(kIROp_HLSLOutputPatchType)) +struct OutputPatch { __subscript(uint index) -> T; }; -__magic_type(HLSLRWByteAddressBufferType) struct RWByteAddressBuffer +__magic_type(HLSLRWByteAddressBufferType) +__intrinsic_type($(kIROp_HLSLRWByteAddressBufferType)) +struct RWByteAddressBuffer { // Note(tfoley): supports alll operations from `ByteAddressBuffer` // TODO(tfoley): can this be made a sub-type? @@ -178,7 +194,7 @@ __magic_type(HLSLRWByteAddressBufferType) struct RWByteAddressBuffer __generic __magic_type(HLSLRWStructuredBufferType) -__intrinsic_type($(kIROp_readWriteStructuredBufferType)) +__intrinsic_type($(kIROp_HLSLRWStructuredBufferType)) struct RWStructuredBuffer { uint DecrementCounter(); @@ -199,7 +215,10 @@ struct RWStructuredBuffer } }; -__generic __magic_type(HLSLPointStreamType) struct PointStream +__generic +__magic_type(HLSLPointStreamType) +__intrinsic_type($(kIROp_HLSLPointStreamType)) +struct PointStream { __target_intrinsic(glsl, "EmitVertex()") void Append(T value); @@ -208,7 +227,10 @@ __generic __magic_type(HLSLPointStreamType) struct PointStream void RestartStrip(); }; -__generic __magic_type(HLSLLineStreamType) struct LineStream +__generic +__magic_type(HLSLLineStreamType) +__intrinsic_type($(kIROp_HLSLLineStreamType)) +struct LineStream { __target_intrinsic(glsl, "EmitVertex()") void Append(T value); @@ -217,7 +239,10 @@ __generic __magic_type(HLSLLineStreamType) struct LineStream void RestartStrip(); }; -__generic __magic_type(HLSLTriangleStreamType) struct TriangleStream +__generic +__magic_type(HLSLTriangleStreamType) +__intrinsic_type($(kIROp_HLSLTriangleStreamType)) +struct TriangleStream { __target_intrinsic(glsl, "EmitVertex()") void Append(T value); @@ -1098,10 +1123,11 @@ static const int kBaseBufferAccessLevelCount = sizeof(kBaseBufferAccessLevels) / for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa) { - - sb << "__generic __magic_type(Texture, "; - sb << TextureFlavor::create(TextureFlavor::Shape::ShapeBuffer, kBaseBufferAccessLevels[aa].access).flavor; - sb << ") struct "; + auto flavor = TextureFlavor::create(TextureFlavor::Shape::ShapeBuffer, kBaseBufferAccessLevels[aa].access).flavor; + sb << "__generic\n"; + sb << "__magic_type(Texture," << int(flavor) << ")\n"; + sb << "__intrinsic_type(" << (kIROp_FirstTextureType + flavor) << ")\n"; + sb << "struct "; sb << kBaseBufferAccessLevels[aa].name; sb << "Buffer {\n"; @@ -1151,7 +1177,10 @@ static const RAY_FLAG RAY_FLAG_CULL_NON_OPAQUE = 0x80; // 10.1.2 - Ray Description Structure -__builtin struct RayDesc +__builtin +__magic_type(RayDescType) +__intrinsic_type($(kIROp_RayDescType)) +struct RayDesc { float3 Origin; float TMin; @@ -1161,7 +1190,9 @@ __builtin struct RayDesc // 10.1.3 - Ray Acceleration Structure -__builtin __magic_type(UntypedBufferResourceType) +__builtin +__magic_type(RaytracingAccelerationStructureType) +__intrinsic_type($(kIROp_RaytracingAccelerationStructureType)) struct RaytracingAccelerationStructure {}; // 10.1.4 - Subobject Definitions @@ -1173,7 +1204,10 @@ struct RaytracingAccelerationStructure {}; // 10.1.5 - Intersection Attributes Structure -__builtin struct BuiltInTriangleIntersectionAttributes +__builtin +__magic_type(BuiltInTriangleIntersectionAttributesType) +__intrinsic_type($(kIROp_BuiltInTriangleIntersectionAttributesType)) +struct BuiltInTriangleIntersectionAttributes { float2 barycentrics; }; diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h index 4d241041b..7e79eccf6 100644 --- a/source/slang/hlsl.meta.slang.h +++ b/source/slang/hlsl.meta.slang.h @@ -2,7 +2,13 @@ SLANG_RAW("// Slang HLSL compatibility library\n") SLANG_RAW("\n") SLANG_RAW("typedef uint UINT;\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLAppendStructuredBufferType) struct AppendStructuredBuffer\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLAppendStructuredBufferType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLAppendStructuredBufferType +) +SLANG_RAW(")\n") +SLANG_RAW("struct AppendStructuredBuffer\n") SLANG_RAW("{\n") SLANG_RAW(" void Append(T value);\n") SLANG_RAW("\n") @@ -11,7 +17,12 @@ SLANG_RAW(" out uint numStructs,\n") SLANG_RAW(" out uint stride);\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer\n") +SLANG_RAW("__magic_type(HLSLByteAddressBufferType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLByteAddressBufferType +) +SLANG_RAW(")\n") +SLANG_RAW("struct ByteAddressBuffer\n") SLANG_RAW("{\n") SLANG_RAW(" void GetDimensions(\n") SLANG_RAW(" out uint dim);\n") @@ -32,7 +43,7 @@ SLANG_RAW("\n") SLANG_RAW("__generic\n") SLANG_RAW("__magic_type(HLSLStructuredBufferType)\n") SLANG_RAW("__intrinsic_type(") -SLANG_SPLICE(kIROp_structuredBufferType +SLANG_SPLICE(kIROp_HLSLStructuredBufferType ) SLANG_RAW(")\n") SLANG_RAW("struct StructuredBuffer\n") @@ -47,7 +58,13 @@ SLANG_RAW("\n") SLANG_RAW(" __subscript(uint index) -> T { __intrinsic_op(bufferLoad) get; };\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLConsumeStructuredBufferType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLConsumeStructuredBufferType +) +SLANG_RAW(")\n") +SLANG_RAW("struct ConsumeStructuredBuffer\n") SLANG_RAW("{\n") SLANG_RAW(" T Consume();\n") SLANG_RAW("\n") @@ -56,17 +73,34 @@ SLANG_RAW(" out uint numStructs,\n") SLANG_RAW(" out uint stride);\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLInputPatchType) struct InputPatch\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLInputPatchType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLInputPatchType +) +SLANG_RAW(")\n") +SLANG_RAW("struct InputPatch\n") SLANG_RAW("{\n") SLANG_RAW(" __subscript(uint index) -> T;\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLOutputPatchType) struct OutputPatch\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLOutputPatchType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLOutputPatchType +) +SLANG_RAW(")\n") +SLANG_RAW("struct OutputPatch\n") SLANG_RAW("{\n") SLANG_RAW(" __subscript(uint index) -> T;\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__magic_type(HLSLRWByteAddressBufferType) struct RWByteAddressBuffer\n") +SLANG_RAW("__magic_type(HLSLRWByteAddressBufferType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLRWByteAddressBufferType +) +SLANG_RAW(")\n") +SLANG_RAW("struct RWByteAddressBuffer\n") SLANG_RAW("{\n") SLANG_RAW(" // Note(tfoley): supports alll operations from `ByteAddressBuffer`\n") SLANG_RAW(" // TODO(tfoley): can this be made a sub-type?\n") @@ -182,7 +216,7 @@ SLANG_RAW("\n") SLANG_RAW("__generic\n") SLANG_RAW("__magic_type(HLSLRWStructuredBufferType)\n") SLANG_RAW("__intrinsic_type(") -SLANG_SPLICE(kIROp_readWriteStructuredBufferType +SLANG_SPLICE(kIROp_HLSLRWStructuredBufferType ) SLANG_RAW(")\n") SLANG_RAW("struct RWStructuredBuffer\n") @@ -205,7 +239,13 @@ SLANG_RAW(" ref;\n") SLANG_RAW("\t}\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLPointStreamType) struct PointStream\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLPointStreamType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLPointStreamType +) +SLANG_RAW(")\n") +SLANG_RAW("struct PointStream\n") SLANG_RAW("{\n") SLANG_RAW(" __target_intrinsic(glsl, \"EmitVertex()\")\n") SLANG_RAW(" void Append(T value);\n") @@ -214,7 +254,13 @@ SLANG_RAW(" __target_intrinsic(glsl, \"EndPrimitive()\")\n") SLANG_RAW(" void RestartStrip();\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLLineStreamType) struct LineStream\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLLineStreamType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLLineStreamType +) +SLANG_RAW(")\n") +SLANG_RAW("struct LineStream\n") SLANG_RAW("{\n") SLANG_RAW(" __target_intrinsic(glsl, \"EmitVertex()\")\n") SLANG_RAW(" void Append(T value);\n") @@ -223,7 +269,13 @@ SLANG_RAW(" __target_intrinsic(glsl, \"EndPrimitive()\")\n") SLANG_RAW(" void RestartStrip();\n") SLANG_RAW("};\n") SLANG_RAW("\n") -SLANG_RAW("__generic __magic_type(HLSLTriangleStreamType) struct TriangleStream\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__magic_type(HLSLTriangleStreamType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_HLSLTriangleStreamType +) +SLANG_RAW(")\n") +SLANG_RAW("struct TriangleStream\n") SLANG_RAW("{\n") SLANG_RAW(" __target_intrinsic(glsl, \"EmitVertex()\")\n") SLANG_RAW(" void Append(T value);\n") @@ -1104,10 +1156,11 @@ static const int kBaseBufferAccessLevelCount = sizeof(kBaseBufferAccessLevels) / for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa) { - - sb << "__generic __magic_type(Texture, "; - sb << TextureFlavor::create(TextureFlavor::Shape::ShapeBuffer, kBaseBufferAccessLevels[aa].access).flavor; - sb << ") struct "; + auto flavor = TextureFlavor::create(TextureFlavor::Shape::ShapeBuffer, kBaseBufferAccessLevels[aa].access).flavor; + sb << "__generic\n"; + sb << "__magic_type(Texture," << int(flavor) << ")\n"; + sb << "__intrinsic_type(" << (kIROp_FirstTextureType + flavor) << ")\n"; + sb << "struct "; sb << kBaseBufferAccessLevels[aa].name; sb << "Buffer {\n"; @@ -1157,7 +1210,13 @@ SLANG_RAW("static const RAY_FLAG RAY_FLAG_CULL_NON_OPAQUE = 0x8 SLANG_RAW("\n") SLANG_RAW("// 10.1.2 - Ray Description Structure\n") SLANG_RAW("\n") -SLANG_RAW("__builtin struct RayDesc\n") +SLANG_RAW("__builtin\n") +SLANG_RAW("__magic_type(RayDescType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_RayDescType +) +SLANG_RAW(")\n") +SLANG_RAW("struct RayDesc\n") SLANG_RAW("{\n") SLANG_RAW(" float3 Origin;\n") SLANG_RAW(" float TMin;\n") @@ -1167,7 +1226,12 @@ SLANG_RAW("};\n") SLANG_RAW("\n") SLANG_RAW("// 10.1.3 - Ray Acceleration Structure\n") SLANG_RAW("\n") -SLANG_RAW("__builtin __magic_type(UntypedBufferResourceType)\n") +SLANG_RAW("__builtin\n") +SLANG_RAW("__magic_type(RaytracingAccelerationStructureType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_RaytracingAccelerationStructureType +) +SLANG_RAW(")\n") SLANG_RAW("struct RaytracingAccelerationStructure {};\n") SLANG_RAW("\n") SLANG_RAW("// 10.1.4 - Subobject Definitions\n") @@ -1179,7 +1243,13 @@ SLANG_RAW("// for this stuff comes across as a kludge rather than the best possi SLANG_RAW("\n") SLANG_RAW("// 10.1.5 - Intersection Attributes Structure\n") SLANG_RAW("\n") -SLANG_RAW("__builtin struct BuiltInTriangleIntersectionAttributes\n") +SLANG_RAW("__builtin\n") +SLANG_RAW("__magic_type(BuiltInTriangleIntersectionAttributesType)\n") +SLANG_RAW("__intrinsic_type(") +SLANG_SPLICE(kIROp_BuiltInTriangleIntersectionAttributesType +) +SLANG_RAW(")\n") +SLANG_RAW("struct BuiltInTriangleIntersectionAttributes\n") SLANG_RAW("{\n") SLANG_RAW(" float2 barycentrics;\n") SLANG_RAW("};\n") diff --git a/source/slang/ir-constexpr.cpp b/source/slang/ir-constexpr.cpp index ca64f5f04..0cd35161d 100644 --- a/source/slang/ir-constexpr.cpp +++ b/source/slang/ir-constexpr.cpp @@ -26,12 +26,12 @@ struct PropagateConstExprContext DiagnosticSink* getSink() { return sink; } }; -bool isConstExpr(Type* type) +bool isConstExpr(IRType* fullType) { - if( auto rateQualifiedType = type->As() ) + if( auto rateQualifiedType = as(fullType)) { - auto rate = rateQualifiedType->rate; - if(auto constExprRate = rate->As()) + auto rate = rateQualifiedType->getRate(); + if(auto constExprRate = as(rate)) return true; } @@ -101,7 +101,7 @@ void markConstExpr( PropagateConstExprContext* context, IRInst* value) { - Slang::markConstExpr(context->getSession(), value); + Slang::markConstExpr(context->getBuilder(), value); } @@ -285,49 +285,79 @@ bool propagateConstExprBackward( UInt callArgCount = operandCount - firstCallArg; auto callee = callInst->getOperand(0); - while( callee->op == kIROp_specialize ) + + // If we are calling a generic operation, then + // try to follow through the `specialize` chain + // and find the callee. + // + // TODO: This probably shouldn't be required, + // since we can hopefully use the type of the + // callee in all cases. + // + while(auto specInst = as(callee)) { - callee = ((IRSpecialize*) callee)->getOperand(0); + auto genericInst = as(specInst->getBase()); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + callee = returnVal; } - if( callee->op == kIROp_Func ) + + auto calleeFunc = as(callee); + if(calleeFunc && isDefinition(calleeFunc)) { - auto calleeFunc = (IRFunc*) callee; - auto calleeFuncType = calleeFunc->getType(); + // We have an IR-level function definition we are calling, + // and thus we can propagate `constexpr` information + // through its `IRParam`s. + + auto calleeFuncType = calleeFunc->getDataType(); UInt callParamCount = calleeFuncType->getParamCount(); SLANG_RELEASE_ASSERT(callParamCount == callArgCount); // If the callee has a definition, then we can read `constexpr` // information off of the parameters of its first IR block. - if( auto calleeFirstBlock = calleeFunc->getFirstBlock() ) + if(auto calleeFirstBlock = calleeFunc->getFirstBlock()) { UInt paramCounter = 0; - for( auto pp = calleeFirstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) + for(auto pp = calleeFirstBlock->getFirstParam(); pp; pp = pp->getNextParam()) { UInt paramIndex = paramCounter++; auto param = pp; auto arg = callInst->getOperand(firstCallArg + paramIndex); - if( isConstExpr(param) ) + if(isConstExpr(param)) { - if( maybeMarkConstExpr(context, arg) ) + if(maybeMarkConstExpr(context, arg)) { changedThisIteration = true; } } } } - else + } + else + { + // If we don't have a concrete callee function + // definition, then we need to extract the + // type of the callee instruction, and try to work + // with that. + // + // Note that this does not allow us to propagate + // `constexpr` information from the body of a callee + // back to call sites. + auto calleeType = callee->getDataType(); + if(auto caleeFuncType = as(calleeType)) { - // If we don't have the definition/body for the callee, - // then we have to glean `constexpr` information from its - // type instead. - auto calleeType = calleeFunc->getType(); - auto paramCount = calleeType->getParamCount(); + auto paramCount = caleeFuncType->getParamCount(); for( UInt pp = 0; pp < paramCount; ++pp ) { - auto paramType = calleeType->getParamType(pp); + auto paramType = caleeFuncType->getParamType(pp); auto arg = callInst->getOperand(firstCallArg + pp); if( isConstExpr(paramType) ) { @@ -474,8 +504,8 @@ void propagateConstExpr( break; case kIROp_Func: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: { IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv; @@ -511,8 +541,8 @@ void propagateConstExpr( break; case kIROp_Func: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: { IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) ii; validateConstExpr(&context, code); diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index fbb3912d8..3e37259ea 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -8,59 +8,135 @@ #define INST_RANGE(BASE, FIRST, LAST) /* empty */ #endif +#ifndef MANUAL_INST_RANGE +#define MANUAL_INST_RANGE(NAME, START, COUNT) /* empty */ +#endif + #ifndef PSEUDO_INST #define PSEUDO_INST(ID) /* empty */ #endif #define PARENT kIROpFlag_Parent -// Invalid operation: should not appear in valid code INST(Nop, nop, 0, 0) -INST(TypeType, Type, 0, 0) -INST(VoidType, Void, 0, 0) -INST(BlockType, Block, 0, 0) -INST(VectorType, Vec, 2, 0) -INST(MatrixType, Mat, 3, 0) -INST(arrayType, Array, 2, 0) +/* Types */ + + /* Basic Types */ + + #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, 0) + FOREACH_BASE_TYPE(DEFINE_BASE_TYPE_INST) + #undef DEFINE_BASE_TYPE_INST + INST(AfterBaseType, afterBaseType, 0, 0) + + INST_RANGE(BasicType, VoidType, AfterBaseType) + + INST(StringType, String, 0, 0) + INST(RayDescType, RayDesc, 0, 0) + INST(BuiltInTriangleIntersectionAttributesType, BuiltInTriangleIntersectionAttributes, 0, 0) + + /* ArrayTypeBase */ + INST(ArrayType, Array, 2, 0) + INST(UnsizedArrayType, UnsizedArray, 1, 0) + INST_RANGE(ArrayTypeBase, ArrayType, UnsizedArrayType) + + INST(FuncType, Func, 0, 0) + INST(BasicBlockType, BasicBlock, 0, 0) + + INST(VectorType, Vec, 2, 0) + INST(MatrixType, Mat, 3, 0) + + /* Rate */ + INST(ConstExprRate, ConstExpr, 0, 0) + INST(GroupSharedRate, GroupShared, 0, 0) + INST_RANGE(Rate, ConstExprRate, GroupSharedRate) + + INST(RateQualifiedType, RateQualified, 2, 0) + + // Kinds represent the "types of types." + // They should not really be nested under `IRType` + // in the overall hierarchy, but we can fix that later. + // + /* Kind */ + INST(TypeKind, Type, 0, 0) + INST(RateKind, Rate, 0, 0) + INST(GenericKind, Generic, 0, 0) + INST_RANGE(Kind, TypeKind, GenericKind) + + /* PtrTypeBase */ + INST(PtrType, Ptr, 1, 0) + /* OutTypeBase */ + INST(OutType, Out, 1, 0) + INST(InOutType, InOut, 1, 0) + INST_RANGE(OutTypeBase, OutType, InOutType) + INST_RANGE(PtrTypeBase, PtrType, InOutType) + + /* SamplerStateTypeBase */ + INST(SamplerStateType, SamplerState, 0, 0) + INST(SamplerComparisonStateType, SamplerComparisonState, 0, 0) + INST_RANGE(SamplerStateTypeBase, SamplerStateType, SamplerComparisonStateType) + + // TODO: Why do we have all this hierarchy here, when everything + // that actually matters is currently nested under `TextureTypeBase`? + /* ResourceTypeBase */ + /* ResourceType */ + /* TextureTypeBase */ + /* TextureType */ + MANUAL_INST_RANGE(TextureType, 0x10000, TextureFlavor::Count) + /* TextureSamplerType */ + MANUAL_INST_RANGE(TextureSamplerType, 0x20000, TextureFlavor::Count) + /* GLSLImageType */ + MANUAL_INST_RANGE(GLSLImageType, 0x30000, TextureFlavor::Count) + INST_RANGE(TextureTypeBase, FirstTextureType, LastGLSLImageType) + INST_RANGE(ResourceType, FirstTextureType, LastGLSLImageType) + INST_RANGE(ResourceTypeBase, FirstTextureType, LastGLSLImageType) + + /* UntypedBufferResourceType */ + INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, 0) + INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, 0) + INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, 0) + INST_RANGE(UntypedBufferResourceType, HLSLByteAddressBufferType, RaytracingAccelerationStructureType) + + /* HLSLPatchType */ + INST(HLSLInputPatchType, InputPatch, 2, 0) + INST(HLSLOutputPatchType, OutputPatch, 2, 0) + INST_RANGE(HLSLPatchType, HLSLInputPatchType, HLSLOutputPatchType) + + INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, 0) + + /* BuiltinGenericType */ + /* HLSLStreamOutputType */ + INST(HLSLPointStreamType, PointStream, 1, 0) + INST(HLSLLineStreamType, LineStream, 1, 0) + INST(HLSLTriangleStreamType, TriangleStream, 1, 0) + INST_RANGE(HLSLStreamOutputType, HLSLPointStreamType, HLSLTriangleStreamType) + + /* HLSLStructuredBufferTypeBase */ + INST(HLSLStructuredBufferType, StructuredBuffer, 0, 0) + INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, 0) + INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, 0) + INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, 0) + INST_RANGE(HLSLStructuredBufferTypeBase, HLSLStructuredBufferType, HLSLConsumeStructuredBufferType) + + /* PointerLikeType */ + /* ParameterGroupType */ + /* UniformParameterGroupType */ + INST(ConstantBufferType, ConstantBuffer, 1, 0) + INST(TextureBufferType, TextureBuffer, 1, 0) + INST(ParameterBlockType, ParameterBlock, 1, 0) + INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, 0) + INST_RANGE(UniformParameterGroupType, ConstantBufferType, GLSLShaderStorageBufferType) + + /* VaryingParameterGroupType */ + INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, 0) + INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, 0) + INST_RANGE(VaryingParameterGroupType, GLSLInputParameterGroupType, GLSLOutputParameterGroupType) + INST_RANGE(ParameterGroupType, ConstantBufferType, GLSLOutputParameterGroupType) + INST_RANGE(PointerLikeType, ConstantBufferType, GLSLOutputParameterGroupType) + INST_RANGE(BuiltinGenericType, HLSLPointStreamType, GLSLOutputParameterGroupType) -INST(BoolType, Bool, 0, 0) -INST(Float16Type, Float16, 0, 0) -INST(Float32Type, Float32, 0, 0) -INST(Float64Type, Float64, 0, 0) -// Signed integer types. -// Note that `IntPtr` represents a pointer-sized integer type, -// and will end up being equivalent to either `Int32` or `Int64` -// when it comes time to actually generate code. -// -INST(Int8Type, Int8, 0, 0) -INST(Int16Type, Int16, 0, 0) -INST(Int32Type, Int32, 0, 0) -INST(IntPtrType, IntPtr, 0, 0) -INST(Int64Type, Int64, 0, 0) - -// Unlike a lot of other IRs, we retain a distinction between -// signed and unsigned integer types, simply because many of -// the target languages we need to generate code for also -// keep this distinction, and it will help us generate variable -// declarations that will be friendly to debuggers. -// -// TODO: We may want to reconsider this choice simply because -// some targets (e.g., those based on C++) may have undefined -// behavior around operations on signed integers that are -// well-defined (two's complement) on unsigned integers. In -// those cases we either want to default to unsigned integers, -// and then cast around the few ops that care about the difference, -// or else we want to keep using the orignal types, but need -// to cast around any ordinary math operations on signed types. -// -INST(UInt8Type, Int8, 0, 0) -INST(UInt16Type, Int16, 0, 0) -INST(UInt32Type, Int32, 0, 0) -INST(UIntPtrType, IntPtr, 0, 0) -INST(UInt64Type, Int64, 0, 0) // A user-defined structure declaration at the IR level. // Unlike in the AST where there is a distinction between @@ -71,40 +147,53 @@ INST(UInt64Type, Int64, 0, 0) // This is a parent instruction that holds zero or more // `field` instructions. // -INST(StructType, Struct, 0, PARENT) +// Note: we are being a bit slippery here, because a `struct` +// instruction is really an `IRParentInst`, but we want it +// to also be caught in any dynamic cast to `IRType`, so we +// ensure that it comes at the *end* of the range for `IRType`, +// and the start of the range for `IRParentInst` (and `IRGlobalValue`) +INST(StructType, struct, 0, PARENT) -INST(FuncType, Func, 0, 0) -INST(PtrType, Ptr, 1, 0) -INST(TextureType, Texture, 2, 0) -INST(SamplerType, SamplerState, 1, 0) -INST(ConstantBufferType, ConstantBuffer, 1, 0) -INST(TextureBufferType, TextureBuffer, 1, 0) +INST_RANGE(Type, VoidType, StructType) -INST(structuredBufferType, StructuredBuffer, 1, 0) -INST(readWriteStructuredBufferType, RWStructuredBuffer, 1, 0) +/*IRParentInst*/ -// A type use to represent an earlier generic parameter in -// a signature. For example, given an AST declaration like: -// -// func Foo(int a, T b) -> U; -// -// The lowered function type would be something like: -// -// T U a b -// (Type, Type, Int32, GenericParameterType<0>) -> GenericParameterType<1> -// -INST(GenericParameterType, GenericParameterType, 1, 0) + /*IRGlobalValue*/ + + /*IRGlobalValueWithCode*/ + /* IRGlobalValueWIthParams*/ + INST(Func, func, 0, PARENT) + INST(Generic, generic, 0, PARENT) + INST_RANGE(GlobalValueWithParams, Func, Generic) -INST(boolConst, boolConst, 0, 0) -INST(IntLit, integer_constant, 0, 0) -INST(FloatLit, float_constant, 0, 0) -INST(decl_ref, decl_ref, 0, 0) + INST(GlobalVar, global_var, 0, 0) + INST(GlobalConstant, global_constant, 0, 0) + INST_RANGE(GlobalValueWithCode, Func, GlobalConstant) + + INST(StructKey, key, 0, 0) + INST(GlobalGenericParam, global_generic_param, 0, 0) + INST(WitnessTable, witness_table, 0, 0) + + INST_RANGE(GlobalValue, StructType, WitnessTable) + + INST(Module, module, 0, PARENT) + + INST(Block, block, 0, PARENT) + +INST_RANGE(ParentInst, StructType, Block) + +/* IRConstant */ + INST(boolConst, boolConst, 0, 0) + INST(IntLit, integer_constant, 0, 0) + INST(FloatLit, float_constant, 0, 0) +INST_RANGE(Constant, boolConst, FloatLit) INST(undefined, undefined, 0, 0) -INST(specialize, specialize, 2, 0) +INST(Specialize, specialize, 2, 0) INST(lookup_interface_method, lookup_interface_method, 2, 0) INST(lookup_witness_table, lookup_witness_table, 2, 0) +INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) INST(Construct, construct, 0, 0) @@ -115,30 +204,11 @@ INST(makeStruct, makeStruct, 0, 0) INST(Call, call, 1, 0) -/*IRParentInst*/ - - INST(Module, module, 0, PARENT) - - INST(Block, block, 0, PARENT) - - /*IRGlobalValue*/ - - /*IRGlobalValueWithCode*/ - INST(Func, func, 0, PARENT) - INST(global_var, global_var, 0, 0) - INST(global_constant, global_constant, 0, 0) - INST_RANGE(GlobalValueWithCode, Func, global_constant) - - INST(witness_table, witness_table, 0, 0) - - INST_RANGE(GlobalValue, Func, witness_table) - -INST_RANGE(ParentInst, Module, witness_table) -INST(witness_table_entry, witness_table_entry, 2, 0) +INST(WitnessTableEntry, witness_table_entry, 2, 0) INST(Param, param, 0, 0) -INST(StructField, field, 0, 0) +INST(StructField, field, 2, 0) INST(Var, var, 0, 0) INST(Load, load, 1, 0) @@ -287,6 +357,7 @@ PSEUDO_INST(Or) #undef PSEUDO_INST #undef PARENT +#undef MANUAL_INST_RANGE #undef INST_RANGE #undef INST diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 231330a28..6b8a8b21e 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -88,6 +88,39 @@ struct IRGLSLOuterArrayDecoration : IRDecoration char const* outerArrayName; }; +// A decoration that marks a field key as having been associated +// with a particular simple semantic (e.g., `COLOR` or `SV_Position`, +// but not a `register` semantic). +// +// This is currently needed so that we can round-trip HLSL `struct` +// types that get used for varying input/output. This is an unfortunate +// case where some amount of "layout" information can't just come +// in via the `TypeLayout` part of things. +// +struct IRSemanticDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_Semantic }; + + Name* semanticName; +}; + +enum class IRInterpolationMode +{ + Linear, + NoPerspective, + NoInterpolation, + + Centroid, + Sample, +}; + +struct IRInterpolationModeDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_InterpolationMode }; + + IRInterpolationMode mode; +}; + // // An IR node to represent a reference to an AST-level @@ -108,8 +141,16 @@ struct IRDeclRef : IRInst // struct IRSpecialize : IRInst { - IRUse genericVal; - IRUse specDeclRefVal; + // The "base" for the call is the generic to be specialized + IRUse base; + IRInst* getBase() { return getOperand(0); } + + // after the generic value come the arguments + UInt getArgCount() { return getOperandCount() - 1; } + IRInst* getArg(UInt index) { return getOperand(index + 1); } + + IR_LEAF_ISA(Specialize) + }; // An instruction that looks up the implementation @@ -119,7 +160,10 @@ struct IRSpecialize : IRInst struct IRLookupWitnessMethod : IRInst { IRUse witnessTable; - IRUse requirementDeclRef; + IRUse requirementKey; + + IRInst* getWitnessTable() { return witnessTable.get(); } + IRInst* getRequirementKey() { return requirementKey.get(); } }; struct IRLookupWitnessTable : IRInst @@ -314,9 +358,9 @@ struct IRSwizzleSet : IRReturn // a stack allocation of some memory. struct IRVar : IRInst { - PtrType* getDataType() + IRPtrType* getDataType() { - return (PtrType*) IRInst::getDataType(); + return cast(IRInst::getDataType()); } static bool isaImpl(IROp op) { return op == kIROp_Var; } @@ -330,9 +374,9 @@ struct IRVar : IRInst /// blocks nested inside this value. struct IRGlobalVar : IRGlobalValueWithCode { - PtrType* getDataType() + IRPtrType* getDataType() { - return (PtrType*) IRInst::getDataType(); + return cast(IRInst::getDataType()); } }; @@ -343,6 +387,7 @@ struct IRGlobalVar : IRGlobalValueWithCode /// the code in the basic block(s) nested in this value. struct IRGlobalConstant : IRGlobalValueWithCode { + IR_LEAF_ISA(GlobalConstant) }; // An entry in a witness table (see below) @@ -353,6 +398,8 @@ struct IRWitnessTableEntry : IRInst // The IR-level value that satisfies the requirement IRUse satisfyingVal; + + IR_LEAF_ISA(WitnessTableEntry) }; // A witness table is a global value that stores @@ -367,16 +414,7 @@ struct IRWitnessTable : IRGlobalValue return IRInstList(getChildren()); } - RefPtr genericDecl; - DeclRef subTypeDeclRef, supTypeDeclRef; - - virtual void dispose() override - { - IRGlobalValue::dispose(); - genericDecl = decltype(genericDecl)(); - subTypeDeclRef = decltype(subTypeDeclRef)(); - supTypeDeclRef = decltype(supTypeDeclRef)(); - } + IR_LEAF_ISA(WitnessTable) }; // An instruction that yields an undefined value. @@ -388,6 +426,23 @@ struct IRUndefined : IRInst { }; +// A global-scope generic parameter (a type parameter, a +// constraint parameter, etc.) +struct IRGlobalGenericParam : IRGlobalValue +{ + IR_LEAF_ISA(GlobalGenericParam) +}; + +// An instruction that binds a global generic parameter +// to a particular value. +struct IRBindGlobalGenericParam : IRInst +{ + IRGlobalGenericParam* getParam() { return cast(getOperand(0)); } + IRInst* getVal() { return getOperand(1); } + + IR_LEAF_ISA(BindGlobalGenericParam) +}; + // Description of an instruction to be used for global value numbering struct IRInstKey { @@ -463,49 +518,81 @@ struct IRBuilder IRInst* getIntValue(IRType* type, IRIntegerValue value); IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); - IRInst* getDeclRefVal( - DeclRefBase const& declRef); - IRInst* getTypeVal(IRType* type); // create an IR value that represents a type - IRInst* emitSpecializeInst( - IRType* type, - IRInst* genericVal, - IRInst* specDeclRef); + IRBasicType* getBasicType(BaseType baseType); + IRBasicType* getVoidType(); + IRBasicType* getBoolType(); + IRBasicType* getIntType(); + IRBasicBlockType* getBasicBlockType(); + IRType* getWitnessTableType() { return nullptr; } + IRType* getKeyType() { return nullptr; } - IRInst* emitSpecializeInst( - IRType* type, - IRInst* genericVal, - DeclRef specDeclRef); + IRTypeKind* getTypeKind(); + IRGenericKind* getGenericKind(); - IRInst* emitLookupInterfaceMethodInst( - IRType* type, - IRInst* witnessTableVal, - IRInst* interfaceMethodVal); + IRPtrType* getPtrType(IRType* valueType); + IROutType* getOutType(IRType* valueType); + IRInOutType* getInOutType(IRType* valueType); + IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); - IRInst* emitLookupInterfaceMethodInst( - IRType* type, - DeclRef witnessTableDeclRef, - DeclRef interfaceMethodDeclRef); + IRArrayTypeBase* getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount); - IRInst* emitLookupInterfaceMethodInst( + IRArrayType* getArrayType( + IRType* elementType, + IRInst* elementCount); + + IRUnsizedArrayType* getUnsizedArrayType( + IRType* elementType); + + IRVectorType* getVectorType( + IRType* elementType, + IRInst* elementCount); + + IRMatrixType* getMatrixType( + IRType* elementType, + IRInst* rowCount, + IRInst* columnCount); + + IRFuncType* getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType); + + IRConstExprRate* getConstExprRate(); + IRGroupSharedRate* getGroupSharedRate(); + + IRRateQualifiedType* getRateQualifiedType( + IRRate* rate, + IRType* dataType); + + // Set the data type of an instruction, while preserving + // its rate, if any. + void setDataType(IRInst* inst, IRType* dataType); + + IRInst* emitSpecializeInst( IRType* type, - IRInst* witnessTableVal, - DeclRef interfaceMethodDeclRef); + IRInst* genericVal, + UInt argCount, + IRInst* const* args); - IRInst* emitFindWitnessTable( - DeclRef baseTypeDeclRef, - IRType* interfaceType); + IRInst* emitLookupInterfaceMethodInst( + IRType* type, + IRInst* witnessTableVal, + IRInst* interfaceMethodVal); IRInst* emitCallInst( IRType* type, - IRInst* func, + IRInst* func, UInt argCount, - IRInst* const* args); + IRInst* const* args); IRInst* emitIntrinsicInst( IRType* type, IROp op, UInt argCount, - IRInst* const* args); + IRInst* const* args); IRInst* emitConstructorInst( IRType* type, @@ -532,7 +619,7 @@ struct IRBuilder IRModule* createModule(); - + IRFunc* createFunc(); IRGlobalVar* createGlobalVar( IRType* valueType); @@ -543,6 +630,32 @@ struct IRBuilder IRWitnessTable* witnessTable, IRInst* requirementKey, IRInst* satisfyingVal); + + // Create an initially empty `struct` type. + IRStructType* createStructType(); + + // Create a global "key" to use for indexing into a `struct` type. + IRStructKey* createStructKey(); + + // Create a field nested in a struct type, declaring that + // the specified field key maps to a field with the specified type. + IRStructField* createStructField( + IRStructType* structType, + IRStructKey* fieldKey, + IRType* fieldType); + + IRGeneric* createGeneric(); + IRGeneric* emitGeneric(); + + // Low-level operation for creating a type. + IRType* getType( + IROp op, + UInt operandCount, + IRInst* const* operands); + IRType* getType( + IROp op); + + IRWitnessTable* lookupWitnessTable(Name* mangledName); void registerWitnessTable(IRWitnessTable* table); IRBlock* createBlock(); @@ -660,6 +773,12 @@ struct IRBuilder UInt caseArgCount, IRInst* const* caseArgs); + IRGlobalGenericParam* emitGlobalGenericParam(); + + IRBindGlobalGenericParam* emitBindGlobalGenericParam( + IRInst* param, + IRInst* val); + template T* addDecoration(IRInst* value, IRDecorationOp op) { @@ -667,7 +786,7 @@ struct IRBuilder auto decorationSize = sizeof(T); auto decoration = (T*)getModule()->memoryPool.allocZero(decorationSize); new(decoration)T(); - + decoration->op = op; decoration->next = value->firstDecoration; @@ -757,7 +876,7 @@ void specializeGenerics( // void markConstExpr( - Session* session, + IRBuilder* builder, IRInst* irValue); // diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 7e380e237..20efc02b1 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -98,28 +98,11 @@ struct IRTypeLegalizationContext }; static void registerLegalizedValue( - IRTypeLegalizationContext* context, - IRInst* irValue, - LegalVal const& legalVal) -{ - context->mapValToLegalVal.Add(irValue, legalVal); -} - -static void maybeRegisterLegalizedGlobal( IRTypeLegalizationContext* context, - IRGlobalValue* irGlobalVar, + IRInst* irValue, LegalVal const& legalVal) { - // Check the mangled name of the symbol and don't register - // symbols that don't have an external name (currently - // indicated by them having an empty name string). - if (getText(irGlobalVar->mangledName).Length() == 0) - return; - - // Otherwise, register the legalized value for this symbol - // under its mangled name, so that other code can still - // find the right value(s) to use after legalization. - context->typeLegalizationContext->mapMangledNameToLegalIRValue.AddIfNotExists(irGlobalVar->mangledName, legalVal); + context->mapValToLegalVal[irValue] = legalVal; } struct IRGlobalNameInfo @@ -138,16 +121,16 @@ static LegalVal declareVars( static LegalType legalizeType( IRTypeLegalizationContext* context, - Type* type) + IRType* type) { return legalizeType(context->typeLegalizationContext, type); } // Legalize a type, and then expect it to // result in a simple type. -static RefPtr legalizeSimpleType( +static IRType* legalizeSimpleType( IRTypeLegalizationContext* context, - Type* type) + IRType* type) { auto legalType = legalizeType(context, type); switch (legalType.flavor) @@ -179,7 +162,7 @@ static LegalVal legalizeOperand( } static void getArgumentValues( - List & instArgs, + List & instArgs, LegalVal val) { switch (val.flavor) @@ -224,15 +207,15 @@ static LegalVal legalizeCall( IRCall* callInst) { // TODO: implement legalization of non-simple return types - auto retType = legalizeType(context, callInst->type); + auto retType = legalizeType(context, callInst->getFullType()); SLANG_ASSERT(retType.flavor == LegalType::Flavor::simple); - + List instArgs; for (auto i = 1u; i < callInst->getOperandCount(); i++) getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i))); return LegalVal::simple(context->builder->emitCallInst( - callInst->type, + callInst->getFullType(), callInst->func.get(), instArgs.Count(), instArgs.Buffer())); @@ -279,7 +262,7 @@ static LegalVal legalizeLoad( for (auto ee : legalPtrVal.getTuple()->elements) { TuplePseudoVal::Element element; - element.mangledName = ee.mangledName; + element.key = ee.key; element.val = legalizeLoad(context, ee.val); tupleVal->elements.Add(element); @@ -353,7 +336,7 @@ static LegalVal legalizeFieldAddress( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, - DeclRef fieldDeclRef) + IRStructKey* fieldKey) { auto builder = context->builder; @@ -364,17 +347,15 @@ static LegalVal legalizeFieldAddress( builder->emitFieldAddress( type.getSimple(), legalPtrOperand.getSimple(), - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case LegalVal::Flavor::pair: { - String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); - // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairVal = legalPtrOperand.getPair(); auto pairInfo = pairVal->pairInfo; - auto pairElement = pairInfo->findElement(mangledFieldName); + auto pairElement = pairInfo->findElement(fieldKey); if (!pairElement) { SLANG_UNEXPECTED("didn't find tuple element"); @@ -400,18 +381,11 @@ static LegalVal legalizeFieldAddress( if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { - // Note: the ordinary side of the pair is expected - // to be a filtered `struct` type, and so it will - // have different field declarations than the - // oridinal type. The element of the `PairInfo` - // structure stores the correct field decl-ref to use - // as `ordinaryFieldDeclRef`. - ordinaryVal = legalizeFieldAddress( context, ordinaryType, pairVal->ordinaryVal, - pairElement->ordinaryFieldDeclRef); + fieldKey); } if (pairElement->flags & PairInfo::kFlag_hasSpecial) @@ -420,7 +394,7 @@ static LegalVal legalizeFieldAddress( context, specialType, pairVal->specialVal, - fieldDeclRef); + fieldKey); } return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); } @@ -428,8 +402,6 @@ static LegalVal legalizeFieldAddress( case LegalVal::Flavor::tuple: { - String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); - // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle @@ -438,7 +410,7 @@ static LegalVal legalizeFieldAddress( auto ptrTupleInfo = legalPtrOperand.getTuple(); for (auto ee : ptrTupleInfo->elements) { - if (ee.mangledName == mangledFieldName) + if (ee.key == fieldKey) { return ee.val; } @@ -465,15 +437,13 @@ static LegalVal legalizeFieldAddress( { // We don't expect any legalization to affect // the "field" argument. - auto fieldOperand = legalFieldOperand.getSimple(); - assert(fieldOperand->op == kIROp_decl_ref); - auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; + auto fieldKey = legalFieldOperand.getSimple(); return legalizeFieldAddress( context, type, legalPtrOperand, - fieldDeclRef); + (IRStructKey*) fieldKey); } static LegalVal legalizeGetElementPtr( @@ -548,7 +518,7 @@ static LegalVal legalizeGetElementPtr( auto elemType = tupleType->elements[ee].type; TuplePseudoVal::Element resElem; - resElem.mangledName = ptrElem.mangledName; + resElem.key = ptrElem.key; resElem.val = legalizeGetElementPtr( context, elemType, @@ -646,8 +616,8 @@ static LegalVal legalizeLocalVar( case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. - irLocalVar->type = context->session->getPtrType( - maybeSimpleType.getSimple()); + irLocalVar->setFullType(context->builder->getPtrType( + maybeSimpleType.getSimple())); return LegalVal::simple(irLocalVar); default: @@ -684,7 +654,7 @@ static LegalVal legalizeParam( { // Simple case: things were legalized to a simple type, // so we can just use the original parameter as-is. - originalParam->type = legalParamType.getSimple(); + originalParam->setFullType(legalParamType.getSimple()); return LegalVal::simple(originalParam); } else @@ -702,6 +672,17 @@ static LegalVal legalizeParam( } } +static LegalVal legalizeFunc( + IRTypeLegalizationContext* context, + IRFunc* irFunc); + +static LegalVal legalizeGlobalVar( + IRTypeLegalizationContext* context, + IRGlobalVar* irGlobalVar); + +static LegalVal legalizeGlobalConstant( + IRTypeLegalizationContext* context, + IRGlobalConstant* irGlobalConstant); static LegalVal legalizeInst( @@ -717,6 +698,19 @@ static LegalVal legalizeInst( case kIROp_Param: return legalizeParam(context, cast(inst)); + case kIROp_WitnessTable: + // Just skip these. + break; + + case kIROp_Func: + return legalizeFunc(context, cast(inst)); + + case kIROp_GlobalVar: + return legalizeGlobalVar(context, cast(inst)); + + case kIROp_GlobalConstant: + return legalizeGlobalConstant(context, cast(inst)); + default: break; } @@ -736,7 +730,7 @@ static LegalVal legalizeInst( } // Also legalize the type of the instruction - LegalType legalType = legalizeType(context, inst->type); + LegalType legalType = legalizeType(context, inst->getFullType()); if (!anyComplex && legalType.flavor == LegalType::Flavor::simple) { @@ -749,7 +743,7 @@ static LegalVal legalizeInst( inst->setOperand(aa, legalArg.getSimple()); } - inst->type = legalType.getSimple(); + inst->setFullType(legalType.getSimple()); return LegalVal::simple(inst); } @@ -774,9 +768,8 @@ static LegalVal legalizeInst( // original instruction by removing it from // the IR. // - // TODO: we need to add it to a list of - // instructions to be cleaned up... inst->removeFromParent(); + context->replacedInstructions.Add(inst); // The value to be used when referencing // the original instruction will now be @@ -784,33 +777,35 @@ static LegalVal legalizeInst( return legalVal; } -static void addParamType(IRFuncType * ftype, LegalType t) +static void addParamType(List& ioParamTypes, LegalType t) { switch (t.flavor) { case LegalType::Flavor::none: break; + case LegalType::Flavor::simple: - ftype->paramTypes.Add(t.obj.As()); + ioParamTypes.Add(t.getSimple()); break; + case LegalType::Flavor::implicitDeref: { - auto imp = t.obj.As(); - addParamType(ftype, imp->valueType); + auto imp = t.getImplicitDeref(); + addParamType(ioParamTypes, imp->valueType); break; } case LegalType::Flavor::pair: { auto pairInfo = t.getPair(); - addParamType(ftype, pairInfo->ordinaryType); - addParamType(ftype, pairInfo->specialType); + addParamType(ioParamTypes, pairInfo->ordinaryType); + addParamType(ioParamTypes, pairInfo->specialType); } break; case LegalType::Flavor::tuple: { - auto tup = t.obj.As(); + auto tup = t.getTuple(); for (auto & elem : tup->elements) - addParamType(ftype, elem.type); + addParamType(ioParamTypes, elem.type); } break; default: @@ -818,54 +813,63 @@ static void addParamType(IRFuncType * ftype, LegalType t) } } -static void legalizeFunc( - IRTypeLegalizationContext* context, - IRFunc* irFunc) +static void legalizeInstsInParent( + IRTypeLegalizationContext* context, + IRParentInst* parent) { - // Overwrite the function's type with - // the result of legalization. - auto newFuncType = new IRFuncType(); - newFuncType->setSession(context->session); - auto oldFuncType = irFunc->type.As(); - newFuncType->resultType = legalizeSimpleType(context, oldFuncType->resultType); - for (auto & paramType : oldFuncType->paramTypes) - { - auto legalParamType = legalizeType(context, paramType); - addParamType(newFuncType, legalParamType); - } - irFunc->type = newFuncType; - - // we use this list to store replaced local var insts. - // these old instructions will be freed when we are done. - context->replacedInstructions.Clear(); - - // Go through the blocks of the function - for (auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock()) + IRInst* nextChild = nullptr; + for(auto child = parent->getFirstChild(); child; child = nextChild) { - // Legalize the instructions inside the block - IRInst* nextInst = nullptr; - for (auto ii = bb->getFirstInst(); ii; ii = nextInst) - { - nextInst = ii->getNextInst(); - - LegalVal legalVal = legalizeInst(context, ii); + nextChild = child->getNextInst(); - registerLegalizedValue(context, ii, legalVal); + if (auto block = as(child)) + { + legalizeInstsInParent(context, block); + } + else + { + LegalVal legalVal = legalizeInst(context, child); + registerLegalizedValue(context, child, legalVal); } - } +} - // Clean up after any instructions we replaced along the way. - for (auto & lv : context->replacedInstructions) +static LegalVal legalizeFunc( + IRTypeLegalizationContext* context, + IRFunc* irFunc) +{ + // Overwrite the function's type with the result of legalization. + + IRFuncType* oldFuncType = irFunc->getDataType(); + UInt oldParamCount = oldFuncType->getParamCount(); + + // TODO: we should give an error message when the result type of a function + // can't be legalized (e.g., trying to return a texture, or a structue that + // contains one). + IRType* newResultType = legalizeSimpleType(context, oldFuncType->getResultType()); + List newParamTypes; + for (UInt pp = 0; pp < oldParamCount; ++pp) { - lv->deallocate(); + auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp)); + addParamType(newParamTypes, legalParamType); } + + auto newFuncType = context->builder->getFuncType( + newParamTypes.Count(), + newParamTypes.Buffer(), + newResultType); + + context->builder->setDataType(irFunc, newFuncType); + + legalizeInstsInParent(context, irFunc); + + return LegalVal::simple(irFunc); } static LegalVal declareSimpleVar( - IRTypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, - Type* type, + IRType* type, TypeLayout* typeLayout, LegalVarChain* varChain, IRGlobalNameInfo* globalNameInfo) @@ -885,7 +889,7 @@ static LegalVal declareSimpleVar( switch (op) { - case kIROp_global_var: + case kIROp_GlobalVar: { auto globalVar = builder->createGlobalVar(type); globalVar->removeFromParent(); @@ -907,7 +911,7 @@ static LegalVal declareSimpleVar( globalVar->mangledName = context->session->getNameObj(mangledNameStr); } } - + irVar = globalVar; @@ -1008,7 +1012,7 @@ static LegalVal declareVars( for (auto ee : tupleType->elements) { - auto fieldLayout = getFieldLayout(typeLayout, ee.mangledName); + auto fieldLayout = getFieldLayout(typeLayout, getText(ee.key->mangledName)); RefPtr fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; // If we are processing layout information, then @@ -1033,7 +1037,7 @@ static LegalVal declareVars( globalNameInfo); TuplePseudoVal::Element element; - element.mangledName = ee.mangledName; + element.key = ee.key; element.val = fieldVal; tupleVal->elements.Add(element); } @@ -1048,7 +1052,7 @@ static LegalVal declareVars( } } -static void legalizeGlobalVar( +static LegalVal legalizeGlobalVar( IRTypeLegalizationContext* context, IRGlobalVar* irGlobalVar) { @@ -1065,9 +1069,11 @@ static void legalizeGlobalVar( case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. - irGlobalVar->type = context->session->getPtrType( - legalValueType.getSimple()); - break; + context->builder->setDataType( + irGlobalVar, + context->builder->getPtrType( + legalValueType.getSimple())); + return LegalVal::simple(irGlobalVar); default: { @@ -1086,23 +1092,22 @@ static void legalizeGlobalVar( globalNameInfo.globalVar = irGlobalVar; globalNameInfo.counter = 0; - LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain, &globalNameInfo); + LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, typeLayout, varChain, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalVar, newVal); - // Also register the variable according to its mangled name, if any. - maybeRegisterLegalizedGlobal(context, irGlobalVar, newVal); - // Remove the old global from the module. irGlobalVar->removeFromParent(); - // TODO: actually clean up the global! + context->replacedInstructions.Add(irGlobalVar); + + return newVal; } break; } } -static void legalizeGlobalConstant( +static LegalVal legalizeGlobalConstant( IRTypeLegalizationContext* context, IRGlobalConstant* irGlobalConstant) { @@ -1116,8 +1121,8 @@ static void legalizeGlobalConstant( case LegalType::Flavor::simple: // Easy case: the type is usable as-is, and we // should just do that. - irGlobalConstant->type = legalValueType.getSimple(); - break; + irGlobalConstant->setFullType(legalValueType.getSimple()); + return LegalVal::simple(irGlobalConstant); default: { @@ -1128,46 +1133,17 @@ static void legalizeGlobalConstant( globalNameInfo.counter = 0; // TODO: need to handle initializer here! - LegalVal newVal = declareVars(context, kIROp_global_constant, legalValueType, nullptr, nullptr, &globalNameInfo); + LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, nullptr, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalConstant, newVal); - // Also register the variable according to its mangled name, if any. - maybeRegisterLegalizedGlobal(context, irGlobalConstant, newVal); - // Remove the old global from the module. irGlobalConstant->removeFromParent(); - // TODO: actually clean up the global! - } - break; - } -} - -static void legalizeGlobalValue( - IRTypeLegalizationContext* context, - IRGlobalValue* irValue) -{ - switch (irValue->op) - { - case kIROp_witness_table: - // Just skip these. - break; - - case kIROp_Func: - legalizeFunc(context, (IRFunc*)irValue); - break; - - case kIROp_global_var: - legalizeGlobalVar(context, (IRGlobalVar*)irValue); - break; + context->replacedInstructions.Add(irGlobalConstant); - case kIROp_global_constant: - legalizeGlobalConstant(context, (IRGlobalConstant*)irValue); - break; - - default: - SLANG_UNEXPECTED("unknown global value type"); + return newVal; + } break; } } @@ -1175,19 +1151,14 @@ static void legalizeGlobalValue( static void legalizeTypes( IRTypeLegalizationContext* context) { + // Legalize all the top-level instructions in the module auto module = context->module; - IRInst* next = nullptr; - for(auto ii = module->getGlobalInsts().getFirst(); ii; ii = next) + legalizeInstsInParent(context, module->moduleInst); + + // Clean up after any instructions we replaced along the way. + for (auto& lv : context->replacedInstructions) { - next = ii->getNextInst(); - - // TODO: Once we start having global-scope instructions that - // aren't `IRGlobalValue`s, we'll actually want to handle those - // here too. - auto gv = as(ii); - if (!gv) - continue; - legalizeGlobalValue(context, gv); + lv->deallocate(); } } @@ -1221,6 +1192,17 @@ void legalizeTypes( legalizeTypes(context); + // Clean up after any type instructions we removed (e.g., + // global `struct` types). + // + // TODO: this logic should probably get paired up with + // the case for `IRTypeLegalizationContext::replacedInstructions`, + // but we haven't yet folded all the legalization logic into + // the IR legalization pass (since it used to apply to the AST too). + for (auto& oldInst : typeLegalizationContext->instsToRemove) + { + oldInst->removeAndDeallocate(); + } } } diff --git a/source/slang/ir-ssa.cpp b/source/slang/ir-ssa.cpp index 60ecddfbd..1d049c685 100644 --- a/source/slang/ir-ssa.cpp +++ b/source/slang/ir-ssa.cpp @@ -84,6 +84,9 @@ struct ConstructSSAContext // IR building state to use during the operation SharedIRBuilder sharedBuilder; + IRBuilder builder; + IRBuilder* getBuilder() { return &builder; } + Dictionary> phiInfos; @@ -211,7 +214,7 @@ PhiInfo* addPhi( auto valueType = var->getDataType()->getValueType(); if( auto rate = var->getRate() ) { - valueType = context->sharedBuilder.getSession()->getRateQualifiedType(rate, valueType); + valueType = context->getBuilder()->getRateQualifiedType(rate, valueType); } IRParam* phi = builder->createParam(valueType); @@ -843,7 +846,7 @@ void constructSSA(ConstructSSAContext* context) } IRTerminatorInst* newTerminator = (IRTerminatorInst*)blockInfo->builder.emitIntrinsicInst( - oldTerminator->type, + oldTerminator->getFullType(), oldTerminator->op, newArgCount, newArgs.Buffer()); @@ -878,6 +881,9 @@ void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) context.sharedBuilder.module = module; context.sharedBuilder.session = module->session; + context.builder.sharedBuilder = &context.sharedBuilder; + context.builder.setInsertInto(module->moduleInst); + constructSSA(&context); } @@ -886,8 +892,8 @@ void constructSSA(IRModule* module, IRInst* globalVal) switch (globalVal->op) { case kIROp_Func: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: constructSSA(module, (IRGlobalValueWithCode*)globalVal); default: diff --git a/source/slang/ir-validate.cpp b/source/slang/ir-validate.cpp index 95b8f2dff..1e36322f4 100644 --- a/source/slang/ir-validate.cpp +++ b/source/slang/ir-validate.cpp @@ -129,6 +129,9 @@ namespace Slang IRValidateContext* context, IRInst* inst) { + if(inst->getFullType()) + validateIRInstOperand(context, inst, &inst->typeUse); + UInt operandCount = inst->getOperandCount(); for (UInt ii = 0; ii < operandCount; ++ii) { diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 75f43453a..2615c1c07 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -14,38 +14,34 @@ namespace Slang Name* mangledName, IRGlobalValue* originalVal); - - static const IROpInfo kIROpInfos[] = + struct IROpMapEntry { -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { #MNEMONIC, ARG_COUNT, FLAGS, }, -#include "ir-inst-defs.h" + IROp op; + IROpInfo info; }; - // - - IROp findIROp(char const* name) + // TODO: We should ideally be speeding up the name->inst + // mapping by using a dictionary, or even by pre-computing + // a hash table to be stored as a `static const` array. + static const IROpMapEntry kIROps[] = { - // TODO: need to make this faster by using a dictionary... - - static const struct { - char const* mnemonic; - IROp op; - } kOps[] = { + { kIROp_Invalid, { "invalid", 0, 0 } }, #define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { #MNEMONIC, kIROp_##ID }, - + { kIROp_##ID, { #MNEMONIC, ARG_COUNT, FLAGS, } }, #define PSEUDO_INST(ID) \ - { #ID, kIRPseudoOp_##ID }, - + { kIRPseudoOp_##ID, { #ID, 0, 0 } }, #include "ir-inst-defs.h" - }; + }; + + // - for (auto ee : kOps) + IROp findIROp(char const* name) + { + for (auto ee : kIROps) { - if (strcmp(name, ee.mnemonic) == 0) + if (strcmp(name, ee.info.name) == 0) return ee.op; } @@ -54,7 +50,13 @@ namespace Slang IROpInfo getIROpInfo(IROp op) { - return kIROpInfos[op]; + for (auto ee : kIROps) + { + if (ee.op == op) + return ee.info; + } + + return kIROps[0].info; } // @@ -65,7 +67,6 @@ namespace Slang auto uv = this->usedValue; if(!uv) { - assert(!user); assert(!nextUse); assert(!prevLink); return; @@ -160,6 +161,22 @@ namespace Slang return nullptr; } + // IRConstant + + IRIntegerValue GetIntVal(IRInst* inst) + { + switch (inst->op) + { + default: + SLANG_UNEXPECTED("needed a known integer value"); + UNREACHABLE_RETURN(0); + + case kIROp_IntLit: + return ((IRConstant*)inst)->u.intVal; + break; + } + } + // IRParam IRParam* IRParam::getNextParam() @@ -167,6 +184,17 @@ namespace Slang return as(getNextInst()); } + // IRArrayTypeBase + + IRInst* IRArrayTypeBase::getElementCount() + { + if (auto arrayType = as(this)) + return arrayType->getElementCount(); + + return nullptr; + } + + // IRBlock IRParam* IRBlock::getLastParam() @@ -416,13 +444,7 @@ namespace Slang return (IRBlock*)use->get(); } - // IRFunc - - IRType* IRFunc::getResultType() { return getType()->getResultType(); } - UInt IRFunc::getParamCount() { return getType()->getParamCount(); } - IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); } - - IRParam* IRFunc::getFirstParam() + IRParam* IRGlobalValueWithParams::getFirstParam() { auto entryBlock = getFirstBlock(); if(!entryBlock) return nullptr; @@ -430,6 +452,12 @@ namespace Slang return entryBlock->getFirstParam(); } + // IRFunc + + IRType* IRFunc::getResultType() { return getDataType()->getResultType(); } + UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); } + IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); } + void IRGlobalValueWithCode::addBlock(IRBlock* block) { block->insertAtEnd(this); @@ -589,7 +617,7 @@ namespace Slang { if (rr == leftNonBlock) { - SLANG_ASSERT(!parentNonBlock); + SLANG_ASSERT(!parentNonBlock || parentNonBlock == leftNonBlock); parentNonBlock = rightNonBlock; break; } @@ -677,6 +705,9 @@ namespace Slang for (UInt ii = 0; ii < operandCount; ++ii) { auto operand = inst->getOperand(ii); + if (!operand) + continue; + auto operandParent = operand->getParent(); parent = mergeCandidateParentsForHoistableInst(parent, operandParent); @@ -727,22 +758,6 @@ namespace Slang value->sourceLoc = sourceLocInfo->sourceLoc; } - template - static T* createValue( - IRBuilder* builder, - IROp op, - IRType* type) - { - assert(builder->getModule()); - T* value = (T*)builder->getModule()->memoryPool.allocZero(sizeof(T)); - new(value)T(); - value->op = op; - value->type = type; - builder->getModule()->irObjectsToFree.Add(value); - return value; - } - - // Create an IR instruction/value and initialize it. // // In this case `argCount` and `args` represnt the @@ -752,23 +767,39 @@ namespace Slang static T* createInstImpl( IRModule* module, IRBuilder* builder, - UInt size, IROp op, IRType* type, UInt fixedArgCount, IRInst* const* fixedArgs, - UInt varArgCount = 0, - IRInst* const* varArgs = nullptr) + UInt varArgListCount, + UInt const* listArgCounts, + IRInst* const* const* listArgs) { + UInt varArgCount = 0; + for (UInt ii = 0; ii < varArgListCount; ++ii) + { + varArgCount += listArgCounts[ii]; + } + + UInt size = sizeof(IRInst) + (fixedArgCount + varArgCount) * sizeof(IRUse); + if (sizeof(T) > size) + { + size = sizeof(T); + } + assert(module); T* inst = (T*)module->memoryPool.allocZero(size); new(inst)T(); + inst->operandCount = (uint32_t)(fixedArgCount + varArgCount); inst->op = op; - inst->type = type; + if (type) + { + inst->typeUse.init(inst, type); + } maybeSetSourceLoc(builder, inst); @@ -783,13 +814,21 @@ namespace Slang operand++; } - for( UInt aa = 0; aa < varArgCount; ++aa ) + for (UInt ii = 0; ii < varArgListCount; ++ii) { - if (varArgs) + UInt listArgCount = listArgCounts[ii]; + for (UInt jj = 0; jj < listArgCount; ++jj) { - operand->init(inst, varArgs[aa]); + if (listArgs[ii]) + { + operand->init(inst, listArgs[ii][jj]); + } + else + { + operand->init(inst, nullptr); + } + operand++; } - operand++; } module->irObjectsToFree.Add(inst); return inst; @@ -798,24 +837,46 @@ namespace Slang template static T* createInstImpl( IRBuilder* builder, - UInt size, IROp op, IRType* type, UInt fixedArgCount, IRInst* const* fixedArgs, - UInt varArgCount = 0, + UInt varArgCount = 0, IRInst* const* varArgs = nullptr) { return createInstImpl( builder->getModule(), builder, - size, op, type, fixedArgCount, fixedArgs, - varArgCount, - varArgs); + 1, + &varArgCount, + &varArgs); + } + + template + static T* createInstImpl( + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgListCount, + UInt const* listArgCount, + IRInst* const* const* listArgs) + { + return createInstImpl( + builder->getModule(), + builder, + op, + type, + fixedArgCount, + fixedArgs, + varArgListCount, + listArgCount, + listArgs); } template @@ -828,7 +889,6 @@ namespace Slang { return createInstImpl( builder, - sizeof(T), op, type, argCount, @@ -843,7 +903,6 @@ namespace Slang { return createInstImpl( builder, - sizeof(T), op, type, 0, @@ -859,7 +918,6 @@ namespace Slang { return createInstImpl( builder, - sizeof(T), op, type, 1, @@ -877,7 +935,6 @@ namespace Slang IRInst* args[] = { arg1, arg2 }; return createInstImpl( builder, - sizeof(T), op, type, 2, @@ -894,7 +951,6 @@ namespace Slang { return createInstImpl( builder, - sizeof(T) + argCount * sizeof(IRUse), op, type, argCount, @@ -913,7 +969,6 @@ namespace Slang { return createInstImpl( builder, - sizeof(T) + varArgCount * sizeof(IRUse), op, type, fixedArgCount, @@ -936,7 +991,6 @@ namespace Slang return createInstImpl( builder, - sizeof(T) + varArgCount * sizeof(IRUse), op, type, fixedArgCount, @@ -949,7 +1003,7 @@ namespace Slang bool operator==(IRInstKey const& left, IRInstKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->parent != right.inst->parent) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; if(left.inst->operandCount != right.inst->operandCount) return false; auto argCount = left.inst->operandCount; @@ -967,7 +1021,7 @@ namespace Slang int IRInstKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->parent)); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); code = combineHash(code, Slang::GetHashCode(inst->getOperandCount())); auto argCount = inst->getOperandCount(); @@ -984,7 +1038,7 @@ namespace Slang bool operator==(IRConstantKey const& left, IRConstantKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->type != right.inst->type) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false; if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false; return true; @@ -993,7 +1047,7 @@ namespace Slang int IRConstantKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->type)); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[0])); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[1])); return code; @@ -1009,7 +1063,7 @@ namespace Slang IRConstant keyInst; memset(&keyInst, 0, sizeof(keyInst)); keyInst.op = op; - keyInst.type = type; + keyInst.typeUse.usedValue = type; memcpy(&keyInst.u, value, valueSize); IRConstantKey key; @@ -1029,7 +1083,7 @@ namespace Slang // way: we will construct a temporary instruction and // then use it to look up in a cache of instructions. - irValue = createValue(builder, op, type); + irValue = createInst(builder, op, type); memcpy(&irValue->u, value, valueSize); key.inst = irValue; @@ -1049,7 +1103,7 @@ namespace Slang return findOrEmitConstant( this, kIROp_boolConst, - getSession()->getBoolType(), + getBoolType(), sizeof(value), &value); } @@ -1074,72 +1128,330 @@ namespace Slang &value); } - IRUndefined* IRBuilder::emitUndefined(IRType* type) + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandListCount, + UInt const* listOperandCounts, + IRInst* const* const* listOperands) { - auto inst = createInst( - this, - kIROp_undefined, - type); + UInt operandCount = 0; + for (UInt ii = 0; ii < operandListCount; ++ii) + { + operandCount += listOperandCounts[ii]; + } + + // We are going to create a dummy instruction on the stack, + // which will be used as a key for lookup, so see if we + // already have an equivalent instruction available to use. + + size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); + IRInst* keyInst = (IRInst*) malloc(keySize); + memset(keyInst, 0, keySize); + + new(keyInst) IRInst(); + keyInst->op = op; + keyInst->typeUse.usedValue = type; + keyInst->operandCount = (uint32_t) operandCount; + + IRUse* operand = keyInst->getOperands(); + for (UInt ii = 0; ii < operandListCount; ++ii) + { + UInt listOperandCount = listOperandCounts[ii]; + for (UInt jj = 0; jj < listOperandCount; ++jj) + { + operand->usedValue = listOperands[ii][jj]; + operand++; + } + } + + IRInstKey key; + key.inst = keyInst; + + IRInst* foundInst = nullptr; + bool found = builder->sharedBuilder->globalValueNumberingMap.TryGetValue(key, foundInst); + + free((void*)keyInst); + + if (found) + { + return foundInst; + } + + // If no instruction was found, then we need to emit it. + + IRInst* inst = createInstImpl( + builder, + op, + type, + 0, + nullptr, + operandListCount, + listOperandCounts, + listOperands); + addHoistableInst(builder, inst); + + key.inst = inst; + builder->sharedBuilder->globalValueNumberingMap.Add(key, inst); - addInst(inst); - return inst; } - IRInst* IRBuilder::getDeclRefVal( - DeclRefBase const& declRef) + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandCount, + IRInst* const* operands) + { + return findOrEmitHoistableInst( + builder, + type, + op, + 1, + &operandCount, + &operands); + } + + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + IRInst* operand, + UInt operandCount, + IRInst* const* operands) + { + UInt counts[] = { 1, operandCount }; + IRInst* const* lists[] = { &operand, operands }; + + return findOrEmitHoistableInst( + builder, + type, + op, + 2, + counts, + lists); + } + + + IRType* IRBuilder::getType( + IROp op, + UInt operandCount, + IRInst* const* operands) { - // TODO: we should cache these... - auto irValue = createValue( + return (IRType*) findOrEmitHoistableInst( this, - kIROp_decl_ref, - nullptr); - irValue->declRef = DeclRef(declRef.decl, declRef.substitutions); + nullptr, + op, + operandCount, + operands); + } - addHoistableInst(this, irValue); + IRType* IRBuilder::getType( + IROp op) + { + return getType(op, 0, nullptr); + } - return irValue; + IRBasicType* IRBuilder::getBasicType(BaseType baseType) + { + return (IRBasicType*)getType( + IROp((UInt)kIROp_FirstBasicType + (UInt)baseType)); + } + + IRBasicType* IRBuilder::getVoidType() + { + return (IRVoidType*)getType(kIROp_VoidType); + } + + IRBasicType* IRBuilder::getBoolType() + { + return (IRBoolType*)getType(kIROp_BoolType); + } + + IRBasicType* IRBuilder::getIntType() + { + return (IRBasicType*)getType(kIROp_IntType); + } + + IRBasicBlockType* IRBuilder::getBasicBlockType() + { + return (IRBasicBlockType*)getType(kIROp_BasicBlockType); + } + + IRTypeKind* IRBuilder::getTypeKind() + { + return (IRTypeKind*)getType(kIROp_TypeKind); + } + + IRGenericKind* IRBuilder::getGenericKind() + { + return (IRGenericKind*)getType(kIROp_GenericKind); + } + + IRPtrType* IRBuilder::getPtrType(IRType* valueType) + { + return (IRPtrType*) getPtrType(kIROp_PtrType, valueType); + } + + IROutType* IRBuilder::getOutType(IRType* valueType) + { + return (IROutType*) getPtrType(kIROp_OutType, valueType); + } + + IRInOutType* IRBuilder::getInOutType(IRType* valueType) + { + return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); + } + + IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType) + { + IRInst* operands[] = { valueType }; + return (IRPtrTypeBase*) getType( + op, + 1, + operands); + } + + IRArrayTypeBase* IRBuilder::getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayTypeBase*)getType( + op, + op == kIROp_ArrayType ? 2 : 1, + operands); + } + + IRArrayType* IRBuilder::getArrayType( + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayType*)getType( + kIROp_ArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRUnsizedArrayType* IRBuilder::getUnsizedArrayType( + IRType* elementType) + { + IRInst* operands[] = { elementType }; + return (IRUnsizedArrayType*)getType( + kIROp_UnsizedArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRVectorType* IRBuilder::getVectorType( + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRVectorType*)getType( + kIROp_VectorType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRMatrixType* IRBuilder::getMatrixType( + IRType* elementType, + IRInst* rowCount, + IRInst* columnCount) + { + IRInst* operands[] = { elementType, rowCount, columnCount }; + return (IRMatrixType*)getType( + kIROp_MatrixType, + sizeof(operands) / sizeof(operands[0]), + operands); } - IRInst* IRBuilder::getTypeVal(IRType * type) + IRFuncType* IRBuilder::getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType) { - auto irValue = createValue( + return (IRFuncType*) findOrEmitHoistableInst( this, - kIROp_TypeType, - nullptr); - irValue->type = type; - if (auto typetype = dynamic_cast(type)) - irValue->type = typetype->type; - return irValue; + nullptr, + kIROp_FuncType, + resultType, + paramCount, + (IRInst* const*) paramTypes); } - IRInst* IRBuilder::emitSpecializeInst( - Type* type, - IRInst* genericVal, - IRInst* specDeclRef) + IRConstExprRate* IRBuilder::getConstExprRate() + { + return (IRConstExprRate*)getType(kIROp_ConstExprRate); + } + + IRGroupSharedRate* IRBuilder::getGroupSharedRate() + { + return (IRGroupSharedRate*)getType(kIROp_GroupSharedRate); + } + + IRRateQualifiedType* IRBuilder::getRateQualifiedType( + IRRate* rate, + IRType* dataType) + { + IRInst* operands[] = { rate, dataType }; + return (IRRateQualifiedType*)getType( + kIROp_RateQualifiedType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + void IRBuilder::setDataType(IRInst* inst, IRType* dataType) + { + if (auto oldRateQualifiedType = as(inst->getFullType())) + { + // Construct a new rate-qualified type using the same rate. + + auto newRateQualifiedType = getRateQualifiedType( + oldRateQualifiedType->getRate(), + dataType); + + inst->setFullType(newRateQualifiedType); + } + else + { + // No rate? Just clobber the data type. + inst->setFullType(dataType); + } + } + + + IRUndefined* IRBuilder::emitUndefined(IRType* type) { - auto inst = createInst( + auto inst = createInst( this, - kIROp_specialize, - type, - genericVal, - specDeclRef); + kIROp_undefined, + type); + addInst(inst); + return inst; } IRInst* IRBuilder::emitSpecializeInst( - Type* type, + IRType* type, IRInst* genericVal, - DeclRef specDeclRef) + UInt argCount, + IRInst* const* args) { - auto specDeclRefVal = getDeclRefVal(specDeclRef); - auto inst = createInst( + auto inst = createInstWithTrailingArgs( this, - kIROp_specialize, + kIROp_Specialize, type, - genericVal, - specDeclRefVal); + 1, + &genericVal, + argCount, + args); + addInst(inst); return inst; } @@ -1155,45 +1467,7 @@ namespace Slang type, witnessTableVal, interfaceMethodVal); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - DeclRef witnessTableDeclRef, - DeclRef interfaceMethodDeclRef) - { - auto witnessTableVal = getDeclRefVal(witnessTableDeclRef); - DeclRef removeSubstDeclRef = interfaceMethodDeclRef; - removeSubstDeclRef.substitutions = SubstitutionSet(); - auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); - return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); - } - - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - IRInst* witnessTableVal, - DeclRef interfaceMethodDeclRef) - { - DeclRef removeSubstDeclRef = interfaceMethodDeclRef; - removeSubstDeclRef.substitutions = SubstitutionSet(); - auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); - return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); - } - IRInst* IRBuilder::emitFindWitnessTable( - DeclRef baseTypeDeclRef, - IRType* interfaceType) - { - auto interfaceTypeDeclRef = interfaceType->AsDeclRefType(); - SLANG_ASSERT(interfaceTypeDeclRef); - auto inst = createInst( - this, - kIROp_lookup_witness_table, - interfaceType, - getDeclRefVal(baseTypeDeclRef), - getDeclRefVal(interfaceTypeDeclRef->declRef)); addInst(inst); return inst; } @@ -1279,10 +1553,12 @@ namespace Slang auto moduleInst = createInstImpl( module, this, - sizeof(IRModuleInst), kIROp_Module, nullptr, 0, + nullptr, + 0, + nullptr, nullptr); module->moduleInst = moduleInst; @@ -1290,58 +1566,103 @@ namespace Slang } void addGlobalValue( - IRModule* module, + IRBuilder* builder, IRGlobalValue* value) { - if(!module) - return; - - value->insertAtEnd(module->moduleInst); - } + // Try to find a suitable parent for the + // global value we are emitting. + // + // We will start out search at the current + // parent instruction for the builder, and + // possibly work our way up. + // + auto parent = builder->insertIntoParent; + while(parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as(parent)) + break; + + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as(parent)) + { + if (as(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + // If we somehow ran out of parents (possibly + // because an instruction wasn't linked into + // the full hierarchy yet), then we will + // fall back to inserting into the overall module. + if (!parent) + { + parent = builder->getModule()->getModuleInst(); + } + + // If it turns out that we are inserting into the + // current "insert into" parent for the builder, then + // we need to respect its "insert before" setting + // as well. + if (parent == builder->insertIntoParent + && builder->insertBeforeInst) + { + value->insertBefore(builder->insertBeforeInst); + } + else + { + value->insertAtEnd(parent); + } + } IRFunc* IRBuilder::createFunc() { - IRFunc* rsFunc = createValue( + IRFunc* rsFunc = createInst( this, kIROp_Func, nullptr); maybeSetSourceLoc(this, rsFunc); - addGlobalValue(getModule(), rsFunc); + addGlobalValue(this, rsFunc); return rsFunc; } IRGlobalVar* IRBuilder::createGlobalVar( IRType* valueType) { - auto ptrType = getSession()->getPtrType(valueType); - IRGlobalVar* globalVar = createValue( + auto ptrType = getPtrType(valueType); + IRGlobalVar* globalVar = createInst( this, - kIROp_global_var, + kIROp_GlobalVar, ptrType); maybeSetSourceLoc(this, globalVar); - addGlobalValue(getModule(), globalVar); + addGlobalValue(this, globalVar); return globalVar; } IRGlobalConstant* IRBuilder::createGlobalConstant( IRType* valueType) { - IRGlobalConstant* globalConstant = createValue( + IRGlobalConstant* globalConstant = createInst( this, - kIROp_global_constant, + kIROp_GlobalConstant, valueType); maybeSetSourceLoc(this, globalConstant); - addGlobalValue(getModule(), globalConstant); + addGlobalValue(this, globalConstant); return globalConstant; } IRWitnessTable* IRBuilder::createWitnessTable() { - IRWitnessTable* witnessTable = createValue( + IRWitnessTable* witnessTable = createInst( this, - kIROp_witness_table, + kIROp_WitnessTable, nullptr); - addGlobalValue(getModule(), witnessTable); + addGlobalValue(this, witnessTable); return witnessTable; } @@ -1352,7 +1673,7 @@ namespace Slang { IRWitnessTableEntry* entry = createInst( this, - kIROp_witness_table_entry, + kIROp_WitnessTableEntry, nullptr, requirementKey, satisfyingVal); @@ -1365,6 +1686,68 @@ namespace Slang return entry; } + IRStructType* IRBuilder::createStructType() + { + IRStructType* structType = createInst( + this, + kIROp_StructType, + nullptr); + addGlobalValue(this, structType); + return structType; + } + + IRStructKey* IRBuilder::createStructKey() + { + IRStructKey* structKey = createInst( + this, + kIROp_StructKey, + nullptr); + addGlobalValue(this, structKey); + return structKey; + } + + // Create a field nested in a struct type, declaring that + // the specified field key maps to a field with the specified type. + IRStructField* IRBuilder::createStructField( + IRStructType* structType, + IRStructKey* fieldKey, + IRType* fieldType) + { + IRInst* operands[] = { fieldKey, fieldType }; + IRStructField* field = (IRStructField*) createInstWithTrailingArgs( + this, + kIROp_StructField, + nullptr, + 0, + nullptr, + 2, + operands); + + if (structType) + { + field->insertAtEnd(structType); + } + + return field; + } + + IRGeneric* IRBuilder::createGeneric() + { + IRGeneric* irGeneric = createInst( + this, + kIROp_Generic, + nullptr); + return irGeneric; + } + + IRGeneric* IRBuilder::emitGeneric() + { + auto irGeneric = createGeneric(); + addGlobalValue(this, irGeneric); + return irGeneric; + } + + IRWitnessTable * IRBuilder::lookupWitnessTable(Name* mangledName) { IRWitnessTable * result; @@ -1381,10 +1764,10 @@ namespace Slang IRBlock* IRBuilder::createBlock() { - return createValue( + return createInst( this, kIROp_Block, - getSession()->getIRBasicBlockType()); + getBasicBlockType()); } IRBlock* IRBuilder::emitBlock() @@ -1409,7 +1792,7 @@ namespace Slang IRParam* IRBuilder::createParam( IRType* type) { - auto param = createValue( + auto param = createInst( this, kIROp_Param, type); @@ -1430,7 +1813,7 @@ namespace Slang IRVar* IRBuilder::emitVar( IRType* type) { - auto allocatedType = getSession()->getPtrType(type); + auto allocatedType = getPtrType(type); auto inst = createInst( this, kIROp_Var, @@ -1449,12 +1832,12 @@ namespace Slang // results) at the "default" rate of the parent function, // unless a subsequent analysis pass constraints it. - RefPtr valueType; - if(auto ptrType = ptr->getDataType()->As()) + IRType* valueType = nullptr; + if(auto ptrType = as(ptr->getDataType())) { valueType = ptrType->getValueType(); } - else if(auto ptrLikeType = ptr->getDataType()->As()) + else if(auto ptrLikeType = as(ptr->getDataType())) { valueType = ptrLikeType->getElementType(); } @@ -1465,15 +1848,20 @@ namespace Slang return nullptr; } - // Ugly special case: the result of loading from `groupshared` - // memory should not itself be `groupshared`. + // Ugly special case: if the front-end created a variable with + // type `Ptr<@R T>` instead of `@R Ptr`, then the above + // logic will yield `@R T` instead of `T`, and we need to + // try and fix that up here. + // + // TODO: Lowering to the IR should be fixed to never create + // that case: rate-qualified types should only be allowed + // to appear as the type of an instruction, and should not + // be allowed as operands to type constructors (except + // in special cases we decide to allow). // - // TODO: This special case will go away once `GroupSharedType` - // is replaced by a `GroupSharedRate` that gets used together - // with `RateQualifiedType`. - if(auto rateType = valueType->As()) + if(auto rateType = as(valueType)) { - valueType = rateType->valueType; + valueType = rateType->getValueType(); } auto inst = createInst( @@ -1589,7 +1977,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getSession()->getBuiltinType(BaseType::Int); + auto intType = getBasicType(BaseType::Int); IRInst* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1631,7 +2019,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getSession()->getBuiltinType(BaseType::Int); + auto intType = getBasicType(BaseType::Int); IRInst* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1802,6 +2190,30 @@ namespace Slang return inst; } + IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam() + { + IRGlobalGenericParam* irGenericParam = createInst( + this, + kIROp_GlobalGenericParam, + nullptr); + addGlobalValue(this, irGenericParam); + return irGenericParam; + } + + IRBindGlobalGenericParam* IRBuilder::emitBindGlobalGenericParam( + IRInst* param, + IRInst* val) + { + auto inst = createInst( + this, + kIROp_BindGlobalGenericParam, + nullptr, + param, + val); + addInst(inst); + return inst; + } + IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl) { auto decoration = addDecoration(inst, kIRDecorationOp_HighLevelDecl); @@ -1873,6 +2285,11 @@ namespace Slang bool opHasResult(IRInst* inst); + bool instHasUses(IRInst* inst) + { + return inst->firstUse != nullptr; + } + static UInt getID( IRDumpContext* context, IRInst* value) @@ -1881,7 +2298,7 @@ namespace Slang if (context->mapValueToID.TryGetValue(value, id)) return id; - if (opHasResult(value)) + if (opHasResult(value) || instHasUses(value)) { id = context->idCounter++; } @@ -1900,33 +2317,30 @@ namespace Slang return; } - switch(inst->op) + if (auto globalValue = as(inst)) { - case kIROp_Func: - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_witness_table: - { - auto irFunc = (IRFunc*) inst; - dump(context, "@"); - dump(context, getText(irFunc->mangledName).Buffer()); - } - break; - - default: + auto mangledName = globalValue->mangledName; + if(mangledName) { - UInt id = getID(context, inst); - if (id) - { - dump(context, "%"); - dump(context, id); - } - else + auto mangledNameText = getText(mangledName); + if (mangledNameText.Length() > 0) { - dump(context, "_"); + dump(context, "@"); + dump(context, mangledNameText.Buffer()); + return; } } - break; + } + + UInt id = getID(context, inst); + if (id) + { + dump(context, "%"); + dump(context, id); + } + else + { + dump(context, "_"); } } @@ -1945,7 +2359,7 @@ namespace Slang // TODO: we should have a dedicated value for the `undef` case if (!inst) { - dump(context, "undef"); + dumpID(context, inst); return; } @@ -1963,16 +2377,6 @@ namespace Slang dump(context, ((IRConstant*)inst)->u.intVal ? "true" : "false"); return; - case kIROp_TypeType: - dumpType(context, (IRType*)inst); - return; - - case kIROp_decl_ref: - dump(context, "$\""); - dumpDeclRef(context, ((IRDeclRef*)inst)->declRef); - dump(context, "\""); - return; - default: break; } @@ -1980,123 +2384,6 @@ namespace Slang dumpID(context, inst); } - static void dump( - IRDumpContext* context, - Name* name) - { - dump(context, getText(name).Buffer()); - } - - static void dumpVal( - IRDumpContext* context, - Val* val) - { - if(auto type = dynamic_cast(val)) - { - dumpType(context, type); - } - else if(auto constIntVal = dynamic_cast(val)) - { - dump(context, constIntVal->value); - } - else if(auto genericParamVal = dynamic_cast(val)) - { - dumpDeclRef(context, genericParamVal->declRef); - } - else if(auto declaredSubtypeWitness = dynamic_cast(val)) - { - dump(context, "DeclaredSubtypeWitness("); - dumpType(context, declaredSubtypeWitness->sub); - dump(context, ", "); - dumpType(context, declaredSubtypeWitness->sup); - dump(context, ", "); - dumpDeclRef(context, declaredSubtypeWitness->declRef); - dump(context, ")"); - } - else if (auto proxyVal = dynamic_cast(val)) - { - dumpOperand(context, proxyVal->inst.get()); - } - else - { - dump(context, "???"); - } - } - - static void dumpDeclRef( - IRDumpContext* context, - DeclRef const& declRef) - { - auto decl = declRef.getDecl(); - - auto parentDeclRef = declRef.GetParent(); - auto genericParentDeclRef = parentDeclRef.As(); - if (genericParentDeclRef) - { - if (genericParentDeclRef.getDecl()->inner.Ptr() == decl) - { - parentDeclRef = genericParentDeclRef.GetParent(); - } - else - { - genericParentDeclRef = DeclRef(); - } - } - - if(parentDeclRef.As()) - { - parentDeclRef = DeclRef(); - } - else if(parentDeclRef.As()) - { - parentDeclRef = DeclRef(); - } - - if(parentDeclRef) - { - dumpDeclRef(context, parentDeclRef); - dump(context, "."); - } - dump(context, decl->getName()); - if (auto genericTypeConstraintDecl = dynamic_cast(decl)) - { - dump(context, "{"); - dumpType(context, genericTypeConstraintDecl->sub); - dump(context, " : "); - dumpType(context, genericTypeConstraintDecl->sup); - dump(context, "}"); - } - else if (auto inheritanceDecl = dynamic_cast(decl)) - { - dump(context, "{ _ : "); - dumpType(context, inheritanceDecl->base); - dump(context, "}"); - } - - if(genericParentDeclRef) - { - auto subst = declRef.substitutions.genericSubstitutions; - if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() ) - { - // No actual substitutions in place here - dump(context, "<>"); - } - else - { - auto args = subst->args; - bool first = true; - dump(context, "<"); - for(auto aa : args) - { - if(!first) dump(context, ","); - dumpVal(context, aa); - first = false; - } - dump(context, ">"); - } - } - } - static void dumpType( IRDumpContext* context, IRType* type) @@ -2107,84 +2394,10 @@ namespace Slang return; } - if(auto funcType = type->As()) - { - UInt paramCount = funcType->getParamCount(); - dump(context, "("); - for( UInt pp = 0; pp < paramCount; ++pp ) - { - if(pp != 0) dump(context, ", "); - dumpType(context, funcType->getParamType(pp)); - } - dump(context, ") -> "); - dumpType(context, funcType->getResultType()); - } - else if(auto arrayType = type->As()) - { - dumpType(context, arrayType->baseType); - dump(context, "["); - if(auto elementCount = arrayType->ArrayLength) - { - dumpVal(context, elementCount); - } - dump(context, "]"); - } - else if(auto declRefType = type->As()) - { - dumpDeclRef(context, declRefType->declRef); - } - else if(auto groupSharedType = type->As()) - { - dump(context, "@ThreadGroup "); - dumpType(context, groupSharedType->valueType); - } - else if(auto rateQualifiedType = type->As()) - { - dump(context, "@"); - dumpType(context, rateQualifiedType->rate); - dump(context, " "); - dumpType(context, rateQualifiedType->valueType); - } - else if(auto constExprRate = type->As()) - { - dump(context, "ConstExpr"); - } - else - { - // Need a default case here - dump(context, "???"); - } - -#if 0 - auto op = type->op; - auto opInfo = kIROpInfos[op]; - - switch (op) - { - case kIROp_StructType: - dumpID(context, type); - break; - - default: - { - dump(context, opInfo.name); - UInt argCount = type->getArgCount(); - - if (argCount > 1) - { - dump(context, "<"); - for (UInt aa = 1; aa < argCount; ++aa) - { - if (aa != 1) dump(context, ","); - dumpOperand(context, type->getArg(aa)); - - } - dump(context, ">"); - } - } - break; - } -#endif + // TODO: we should consider some special-case printing + // for types, so that the IR doesn't get too hard to read + // (always having to back-reference for what a type expands to) + dumpOperand(context, type); } static void dumpInstTypeClause( @@ -2245,60 +2458,11 @@ namespace Slang } } - void dumpGenericSignature( - IRDumpContext* context, - GenericDecl* genericDecl) - { - for( auto pp = genericDecl->ParentDecl; pp; pp = pp->ParentDecl ) - { - if( auto genericAncestor = dynamic_cast(pp) ) - { - dumpGenericSignature(context, genericAncestor); - break; - } - } - - dump(context, " <"); - bool first = true; - for (auto mm : genericDecl->Members) - { - - if( auto typeParamDecl = mm.As() ) - { - if (!first) dump(context, ", "); - dumpDeclRef(context, makeDeclRef(typeParamDecl.Ptr())); - first = false; - } - else if( auto valueParamDecl = mm.As() ) - { - if (!first) dump(context, ", "); - dumpDeclRef(context, makeDeclRef(valueParamDecl.Ptr())); - first = false; - } - } - first = true; - for (auto mm : genericDecl->Members) - { - if( auto constraintDecl = mm.As() ) - { - if (!first) dump(context, ", "); - else dump(context, " where "); - - dumpType(context, constraintDecl->sub); - dump(context, " : "); - dumpType(context, constraintDecl->sup); - first = false; - } - } - dump(context, ">"); - } - - void dumpIRFunc( + void dumpIRDecorations( IRDumpContext* context, - IRFunc* func) + IRInst* inst) { - - for( auto dd = func->firstDecoration; dd; dd = dd->next ) + for( auto dd = inst->firstDecoration; dd; dd = dd->next ) { switch( dd->op ) { @@ -2316,21 +2480,26 @@ namespace Slang } } + } + + void dumpIRGlobalValueWithCode( + IRDumpContext* context, + IRGlobalValueWithCode* code) + { + // TODO: should apply this to all instructions + dumpIRDecorations(context, code); + + auto opInfo = getIROpInfo(code->op); dump(context, "\n"); dumpIndent(context); - dump(context, "ir_func "); - dumpID(context, func); - - if (func->getGenericDecl()) - { - dump(context, " "); - dumpGenericSignature(context, func->getGenericDecl()); - } + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, code); - dumpInstTypeClause(context, func->getType()); + dumpInstTypeClause(context, code->getFullType()); - if (!func->getFirstBlock()) + if (!code->getFirstBlock()) { // Just a declaration. dump(context, ";\n"); @@ -2343,9 +2512,9 @@ namespace Slang dump(context, "{\n"); context->indent++; - for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + for (auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock()) { - if (bb != func->getFirstBlock()) + if (bb != code->getFirstBlock()) dump(context, "\n"); dumpBlock(context, bb); } @@ -2360,43 +2529,11 @@ namespace Slang IRDumpContext dumpContext; StringBuilder sbDump; dumpContext.builder = &sbDump; - dumpIRFunc(&dumpContext, func); + dumpIRGlobalValueWithCode(&dumpContext, func); auto strFunc = sbDump.ToString(); return strFunc; } - void dumpIRGlobalVar( - IRDumpContext* context, - IRGlobalVar* var) - { - dump(context, "\n"); - dumpIndent(context); - dump(context, "ir_global_var "); - dumpID(context, var); - dumpInstTypeClause(context, var->getFullType()); - - // TODO: deal with the case where a global - // might have embedded initialization logic. - - dump(context, ";\n"); - } - - void dumpIRGlobalConstant( - IRDumpContext* context, - IRGlobalConstant* val) - { - dump(context, "\n"); - dumpIndent(context); - dump(context, "ir_global_constant "); - dumpID(context, val); - dumpInstTypeClause(context, val->getFullType()); - - // TODO: deal with the case where a global - // might have embedded initialization logic. - - dump(context, ";\n"); - } - void dumpIRWitnessTableEntry( IRDumpContext* context, IRWitnessTableEntry* entry) @@ -2408,25 +2545,64 @@ namespace Slang dump(context, ")\n"); } - void dumpIRWitnessTable( + void dumpIRParentInst( IRDumpContext* context, - IRWitnessTable* witnessTable) + IRParentInst* inst) { + // TODO: should apply this to all instructions + dumpIRDecorations(context, inst); + + auto opInfo = getIROpInfo(inst->op); + dump(context, "\n"); dumpIndent(context); - dump(context, "ir_witness_table "); - dumpID(context, witnessTable); - dump(context, "\n{\n"); - context->indent++; + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, inst); - for (auto ii : witnessTable->getChildren()) + dumpInstTypeClause(context, inst->getFullType()); + + if (!inst->getFirstChild()) { - dumpInst(context, ii); + // Empty. + dump(context, ";\n"); + return; } - context->indent--; - dump(context, "}\n"); - } + dump(context, "\n"); + + dumpIndent(context); + dump(context, "{\n"); + context->indent++; + + for (auto child = inst->getFirstChild(); child; child = child->getNextInst()) + { + dumpInst(context, child); + } + + context->indent--; + dump(context, "}\n"); + } + + void dumpIRGeneric( + IRDumpContext* context, + IRGeneric* witnessTable) + { + dump(context, "\n"); + dumpIndent(context); + dump(context, "ir_witness_table "); + dumpID(context, witnessTable); + dump(context, "\n{\n"); + context->indent++; + + for (auto ii : witnessTable->getChildren()) + { + dumpInst(context, ii); + } + + context->indent--; + dump(context, "}\n"); + } static void dumpInst( IRDumpContext* context, @@ -2447,22 +2623,18 @@ namespace Slang switch (op) { case kIROp_Func: - dumpIRFunc(context, (IRFunc*)inst); - return; - - case kIROp_global_var: - dumpIRGlobalVar(context, (IRGlobalVar*)inst); - return; - - case kIROp_global_constant: - dumpIRGlobalConstant(context, (IRGlobalConstant*)inst); + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + case kIROp_Generic: + dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst); return; - case kIROp_witness_table: - dumpIRWitnessTable(context, (IRWitnessTable*)inst); + case kIROp_WitnessTable: + case kIROp_StructType: + dumpIRParentInst(context, (IRWitnessTable*)inst); return; - case kIROp_witness_table_entry: + case kIROp_WitnessTableEntry: dumpIRWitnessTableEntry(context, (IRWitnessTableEntry*)inst); return; @@ -2473,31 +2645,30 @@ namespace Slang // Okay, we have a seemingly "ordinary" op now dumpIndent(context); - auto opInfo = &kIROpInfos[op]; - auto type = inst->getFullType(); + auto opInfo = getIROpInfo(op); auto dataType = inst->getDataType(); + auto rate = inst->getRate(); - if (!dataType) + if(rate) { - // No result, okay... + dump(context, "@"); + dumpOperand(context, rate); + dump(context, " "); + } + + if(opHasResult(inst) || instHasUses(inst)) + { + dump(context, "let "); + dumpID(context, inst); + dumpInstTypeClause(context, dataType); + dump(context, "\t= "); } else { - auto basicType = dataType->As(); - if (basicType && basicType->baseType == BaseType::Void) - { - // No result, okay... - } - else - { - dump(context, "let "); - dumpID(context, inst); - dumpInstTypeClause(context, type); - dump(context, "\t= "); - } + // No result, okay... } - dump(context, opInfo->name); + dump(context, opInfo.name); UInt argCount = inst->getOperandCount(); UInt ii = 0; @@ -2531,7 +2702,6 @@ namespace Slang case kIROp_IntLit: case kIROp_FloatLit: case kIROp_boolConst: - case kIROp_decl_ref: dumpOperand(context, inst); break; @@ -2596,24 +2766,29 @@ namespace Slang // // - Type* IRInst::getRate() + IRRate* IRInst::getRate() { - if(auto rateQualifiedType = type->As()) - return rateQualifiedType->rate; + if(auto rateQualifiedType = as(getFullType())) + return rateQualifiedType->getRate(); return nullptr; } - Type* IRInst::getDataType() + IRType* IRInst::getDataType() { - if(auto rateQualifiedType = type->As()) - return rateQualifiedType->valueType; + auto type = getFullType(); + if(auto rateQualifiedType = as(type)) + return rateQualifiedType->getValueType(); return type; } void IRInst::replaceUsesWith(IRInst* other) { + // Safety check: don't try to replace something with itself. + if(other == this) + return; + // We will walk through the list of uses for the current // instruction, and make them point to the other inst. IRUse* ff = firstUse; @@ -2683,7 +2858,6 @@ namespace Slang void IRInst::dispose() { IRObject::dispose(); - type = decltype(type)(); } // Insert this instruction into the same basic block @@ -2862,7 +3036,7 @@ namespace Slang IRGlobalVar* addGlobalVariable( IRModule* module, - Type* valueType) + IRType* valueType) { auto session = module->session; @@ -2872,9 +3046,6 @@ namespace Slang IRBuilder builder; builder.sharedBuilder = &shared; - - RefPtr ptrType = session->getPtrType(valueType); - return builder.createGlobalVar(valueType); } @@ -2965,11 +3136,11 @@ namespace Slang { struct Element { + IRStructKey* key; ScalarizedVal val; - DeclRef declRef; }; - RefPtr type; + IRType* type; List elements; }; @@ -2978,8 +3149,8 @@ namespace Slang struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl { ScalarizedVal val; - RefPtr actualType; // the actual type of `val` - RefPtr pretendType; // the type this value pretends to have + IRType* actualType; // the actual type of `val` + IRType* pretendType; // the type this value pretends to have }; struct GlobalVaryingDeclarator @@ -2990,21 +3161,21 @@ namespace Slang }; Flavor flavor; - IntVal* elementCount; + IRInst* elementCount; GlobalVaryingDeclarator* next; }; struct GLSLSystemValueInfo { // The name of the built-in GLSL variable - char const* name; + char const* name; // The name of an outer array that wraps // the variable, in the case of a GS input char const* outerArrayName; // The required type of the built-in variable - RefPtr requiredType; + IRType* requiredType; }; void requireGLSLVersionImpl( @@ -3041,6 +3212,9 @@ namespace Slang { return sink; } + + IRBuilder* builder; + IRBuilder* getBuilder() { return builder; } }; GLSLSystemValueInfo* getGLSLSystemValueInfo( @@ -3059,7 +3233,7 @@ namespace Slang auto semanticName = semanticNameSpelling.ToLower(); - RefPtr requiredType; + IRType* requiredType = nullptr; if(semanticName == "sv_position") { @@ -3190,7 +3364,7 @@ namespace Slang } name = "gl_Layer"; - requiredType = context->session->getBuiltinType(BaseType::Int); + requiredType = context->getBuilder()->getBasicType(BaseType::Int); } else if (semanticName == "sv_sampleindex") { @@ -3262,7 +3436,7 @@ namespace Slang ScalarizedVal createSimpleGLSLGlobalVarying( GLSLLegalizationContext* context, IRBuilder* builder, - Type* inType, + IRType* inType, VarLayout* inVarLayout, TypeLayout* inTypeLayout, LayoutResourceKind kind, @@ -3279,7 +3453,7 @@ namespace Slang stage, &systemValueInfoStorage); - RefPtr type = inType; + IRType* type = inType; // A system-value semantic might end up needing to override the type // that the user specified. @@ -3295,12 +3469,12 @@ namespace Slang { assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - RefPtr arrayType = builder->getSession()->getArrayType( + auto arrayType = builder->getArrayType( type, dd->elementCount); RefPtr arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->type = arrayType; +// arrayTypeLayout->type = arrayType; arrayTypeLayout->rules = typeLayout->rules; arrayTypeLayout->originalElementTypeLayout = typeLayout; arrayTypeLayout->elementTypeLayout = typeLayout; @@ -3355,7 +3529,7 @@ namespace Slang // the actual type of the GLSL global. auto toType = inType; - if( !fromType->Equals(toType) ) + if( fromType != toType ) { RefPtr typeAdapter = new ScalarizedTypeAdapterValImpl; typeAdapter->actualType = systemValueInfo->requiredType; @@ -3381,7 +3555,7 @@ namespace Slang ScalarizedVal createGLSLGlobalVaryingsImpl( GLSLLegalizationContext* context, IRBuilder* builder, - Type* type, + IRType* type, VarLayout* varLayout, TypeLayout* typeLayout, LayoutResourceKind kind, @@ -3389,31 +3563,31 @@ namespace Slang UInt bindingIndex, GlobalVaryingDeclarator* declarator) { - if( type->As() ) + if( as(type) ) { return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( type->As() ) + else if( as(type) ) { return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( type->As() ) + else if( as(type) ) { // TODO: a matrix-type varying should probably be handled like an array of rows return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( auto arrayType = type->As() ) + else if( auto arrayType = as(type) ) { // We will need to SOA-ize any nested types. - auto elementType = arrayType->baseType; - auto elementCount = arrayType->ArrayLength; + auto elementType = arrayType->getElementType(); + auto elementCount = arrayType->getElementCount(); auto arrayLayout = dynamic_cast(typeLayout); SLANG_ASSERT(arrayLayout); auto elementTypeLayout = arrayLayout->elementTypeLayout; @@ -3434,7 +3608,7 @@ namespace Slang bindingIndex, &arrayDeclarator); } - else if( auto streamType = type->As() ) + else if( auto streamType = as(type)) { auto elementType = streamType->getElementType(); auto streamLayout = dynamic_cast(typeLayout); @@ -3452,66 +3626,60 @@ namespace Slang bindingIndex, declarator); } - else if( auto declRefType = type->As() ) + else if(auto structType = as(type)) { - auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.As() ) - { - // This is either a user-defined struct, or a builtin type. - // TODO: exclude resource types here. + // We need to recurse down into the individual fields, + // and generate a variable for each of them. - // We need to recurse down into the individual fields, - // and generate a variable for each of them. + auto structTypeLayout = dynamic_cast(typeLayout); + SLANG_ASSERT(structTypeLayout); + RefPtr tupleValImpl = new ScalarizedTupleValImpl(); - // Note: we can use the presence of a `StructTypeLayout` as - // a quick way to reject a bunch of types that aren't actually `struct`s - auto structTypeLayout = dynamic_cast(typeLayout); - if( structTypeLayout ) - { - RefPtr tupleValImpl = new ScalarizedTupleValImpl(); + // Construct the actual type for the tuple (including any outer arrays) + IRType* fullType = type; + for( auto dd = declarator; dd; dd = dd->next ) + { + assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); + fullType = builder->getArrayType( + fullType, + dd->elementCount); + } - // Construct the actual type for the tuple (including any outer arrays) - RefPtr fullType = type; - for( auto dd = declarator; dd; dd = dd->next ) - { - assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - fullType = builder->getSession()->getArrayType( - fullType, - dd->elementCount); - } + tupleValImpl->type = fullType; - tupleValImpl->type = fullType; + // Okay, we want to walk through the fields here, and + // generate one variable for each. + UInt fieldCounter = 0; + for(auto field : structType->getFields()) + { + UInt fieldIndex = fieldCounter++; - // Okay, we want to walk through the fields here, and - // generate one variable for each. - for( auto ff : structTypeLayout->fields ) - { - UInt fieldBindingIndex = bindingIndex; - if(auto fieldResInfo = ff->FindResourceInfo(kind)) - fieldBindingIndex += fieldResInfo->index; + auto fieldLayout = structTypeLayout->fields[fieldIndex]; - auto fieldVal = createGLSLGlobalVaryingsImpl( - context, - builder, - ff->typeLayout->type, - ff, - ff->typeLayout, - kind, - stage, - fieldBindingIndex, - declarator); - - ScalarizedTupleValImpl::Element element; - element.val = fieldVal; - element.declRef = ff->varDecl; - - tupleValImpl->elements.Add(element); - } + UInt fieldBindingIndex = bindingIndex; + if(auto fieldResInfo = fieldLayout->FindResourceInfo(kind)) + fieldBindingIndex += fieldResInfo->index; - return ScalarizedVal::tuple(tupleValImpl); - } + auto fieldVal = createGLSLGlobalVaryingsImpl( + context, + builder, + field->getFieldType(), + fieldLayout, + fieldLayout->typeLayout, + kind, + stage, + fieldBindingIndex, + declarator); + + ScalarizedTupleValImpl::Element element; + element.val = fieldVal; + element.key = field->getKey(); + + tupleValImpl->elements.Add(element); } + + return ScalarizedVal::tuple(tupleValImpl); } // Default case is to fall back on the simple behavior @@ -3523,7 +3691,7 @@ namespace Slang ScalarizedVal createGLSLGlobalVaryings( GLSLLegalizationContext* context, IRBuilder* builder, - Type* type, + IRType* type, VarLayout* layout, LayoutResourceKind kind, Stage stage) @@ -3536,27 +3704,44 @@ namespace Slang builder, type, layout, layout->typeLayout, kind, stage, bindingIndex, nullptr); } + IRType* getFieldType( + IRType* baseType, + IRStructKey* fieldKey) + { + if(auto structType = as(baseType)) + { + for(auto ff : structType->getFields()) + { + if(ff->getKey() == fieldKey) + return ff->getFieldType(); + } + } + + SLANG_UNEXPECTED("no such field"); + UNREACHABLE_RETURN(nullptr); + } + ScalarizedVal extractField( IRBuilder* builder, ScalarizedVal const& val, UInt fieldIndex, - DeclRef fieldDeclRef) + IRStructKey* fieldKey) { switch( val.flavor ) { case ScalarizedVal::Flavor::value: return ScalarizedVal::value( builder->emitFieldExtract( - GetType(fieldDeclRef.As()), + getFieldType(val.irValue->getDataType(), fieldKey), val.irValue, - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case ScalarizedVal::Flavor::address: return ScalarizedVal::address( builder->emitFieldAddress( - GetType(fieldDeclRef.As()), + getFieldType(val.irValue->getDataType(), fieldKey), val.irValue, - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case ScalarizedVal::Flavor::tuple: { @@ -3574,8 +3759,8 @@ namespace Slang ScalarizedVal adaptType( IRBuilder* builder, IRInst* val, - Type* toType, - Type* /*fromType*/) + IRType* toType, + IRType* /*fromType*/) { // TODO: actually consider what needs to go on here... return ScalarizedVal::value(builder->emitConstructorInst( @@ -3587,8 +3772,8 @@ namespace Slang ScalarizedVal adaptType( IRBuilder* builder, ScalarizedVal const& val, - Type* toType, - Type* fromType) + IRType* toType, + IRType* fromType) { switch( val.flavor ) { @@ -3647,7 +3832,7 @@ namespace Slang builder, left, ee, - rightElement.declRef); + rightElement.key); assign(builder, leftElementVal, rightElement.val); } } @@ -3672,7 +3857,7 @@ namespace Slang builder, right, ee, - leftTupleVal->elements[ee].declRef); + leftTupleVal->elements[ee].key); assign(builder, leftTupleVal->elements[ee].val, rightElementVal); } } @@ -3699,7 +3884,7 @@ namespace Slang ScalarizedVal getSubscriptVal( IRBuilder* builder, - Type* elementType, + IRType* elementType, ScalarizedVal val, IRInst* indexVal) { @@ -3715,7 +3900,7 @@ namespace Slang case ScalarizedVal::Flavor::address: return ScalarizedVal::address( builder->emitElementAddress( - builder->getSession()->getPtrType(elementType), + builder->getPtrType(elementType), val.irValue, indexVal)); @@ -3729,18 +3914,10 @@ namespace Slang UInt elementCount = inputTuple->elements.Count(); UInt elementCounter = 0; - auto declRefType = dynamic_cast(elementType); - SLANG_RELEASE_ASSERT(declRefType); - - auto aggTypeDeclRef = declRefType->declRef.As(); - SLANG_RELEASE_ASSERT(aggTypeDeclRef); - - for(auto fieldDeclRef : getMembersOfType(aggTypeDeclRef)) + auto structType = as(elementType); + for(auto field : structType->getFields()) { - if(fieldDeclRef.getDecl()->HasModifier()) - continue; - - auto tupleElementType = GetType(fieldDeclRef); + auto tupleElementType = field->getFieldType(); UInt elementIndex = elementCounter++; @@ -3748,7 +3925,7 @@ namespace Slang auto inputElement = inputTuple->elements[elementIndex]; ScalarizedTupleValImpl::Element resultElement; - resultElement.declRef = inputElement.declRef; + resultElement.key = inputElement.key; resultElement.val = getSubscriptVal( builder, tupleElementType, @@ -3770,7 +3947,7 @@ namespace Slang ScalarizedVal getSubscriptVal( IRBuilder* builder, - Type* elementType, + IRType* elementType, ScalarizedVal val, UInt index) { @@ -3779,7 +3956,7 @@ namespace Slang elementType, val, builder->getIntValue( - builder->getSession()->getIntType(), + builder->getIntType(), index)); } @@ -3797,7 +3974,7 @@ namespace Slang UInt elementCount = tupleVal->elements.Count(); auto type = tupleVal->type; - if( auto arrayType = type.As() ) + if( auto arrayType = as(type)) { // The tuple represent an array, which means that the // individual elements are expected to yield arrays as well. @@ -3806,13 +3983,13 @@ namespace Slang // then use these to construct our result. List arrayElementVals; - UInt arrayElementCount = (UInt) GetIntVal(arrayType->ArrayLength); + UInt arrayElementCount = (UInt) GetIntVal(arrayType->getElementCount()); for( UInt ii = 0; ii < arrayElementCount; ++ii ) { auto arrayElementPseudoVal = getSubscriptVal( builder, - arrayType->baseType, + arrayType->getElementType(), val, ii); @@ -3945,6 +4122,8 @@ namespace Slang builder.sharedBuilder = &shared; builder.setInsertInto(func); + context.builder = &builder; + // We will start by looking at the return type of the // function, because that will enable us to do an // early-out check to avoid more work. @@ -3953,7 +4132,7 @@ namespace Slang // a `void` return type, because there is no work // to be done on its return value in that case. auto resultType = func->getResultType(); - if( resultType->Equals(session->getVoidType()) ) + if(as(resultType)) { // In this case, the function doesn't return a value // so we don't need to transform its `return` sites. @@ -4060,10 +4239,10 @@ namespace Slang // don't fit into the standard varying model. // For right now we are only doing special-case handling // of geometry shader output streams. - if( auto paramPtrType = paramType->As() ) + if( auto paramPtrType = as(paramType) ) { auto valueType = paramPtrType->getValueType(); - if( auto gsStreamType = valueType->As() ) + if( auto gsStreamType = as(valueType) ) { // An output stream type like `TriangleStream` should // more or less translate into `out Foo` (plus scalarization). @@ -4097,7 +4276,7 @@ namespace Slang // Is it calling the append operation? auto callee = ii->getOperand(0); - while( callee->op == kIROp_specialize ) + while( callee->op == kIROp_Specialize ) { callee = ((IRSpecialize*) callee)->getOperand(0); } @@ -4132,7 +4311,7 @@ namespace Slang // Is the parameter type a special pointer type // that indicates the parameter is used for `out` // or `inout` access? - if(auto paramPtrType = paramType->As() ) + if(auto paramPtrType = as(paramType) ) { // Okay, we have the more interesting case here, // where the parameter was being passed by reference. @@ -4145,7 +4324,7 @@ namespace Slang auto localVariable = builder.emitVar(valueType); auto localVal = ScalarizedVal::address(localVariable); - if( auto inOutType = paramPtrType->As() ) + if( auto inOutType = as(paramPtrType) ) { // In the `in out` case we need to declare two // sets of global variables: one for the `in` @@ -4236,10 +4415,11 @@ namespace Slang // Finally, we need to patch up the type of the entry point, // because it is no longer accurate. - RefPtr voidFuncType = new FuncType(); - voidFuncType->setSession(session); - voidFuncType->resultType = session->getVoidType(); - func->type = voidFuncType; + IRFuncType* voidFuncType = builder.getFuncType( + 0, + nullptr, + builder.getVoidType()); + func->setFullType(voidFuncType); // TODO: we should technically be constructing // a new `EntryPointLayout` here to reflect @@ -4260,6 +4440,15 @@ namespace Slang RefPtr nextWithSameName; }; + struct IRSpecEnv + { + IRSpecEnv* parent = nullptr; + + // A map from original values to their cloned equivalents. + typedef Dictionary ClonedValueDictionary; + ClonedValueDictionary clonedValues; + }; + struct IRSharedSpecContext { // The code-generation target in use @@ -4277,16 +4466,38 @@ namespace Slang typedef Dictionary> SymbolDictionary; SymbolDictionary symbols; - // A map from values in the original IR module - // to their equivalent in the cloned module. - typedef Dictionary ClonedValueDictionary; - ClonedValueDictionary clonedValues; - SharedIRBuilder sharedBuilderStorage; IRBuilder builderStorage; - // Non-generic functions to be processed (for generic specialization context) - List workList; + // The "global" specialization environment. + IRSpecEnv globalEnv; + }; + + struct IRSharedGenericSpecContext : IRSharedSpecContext + { + // Instructions to be processed (for generic specialization context) + List workList; + HashSet workListSet; + void addToWorkList(IRInst* inst) + { + if(!workListSet.Contains(inst)) + { + workList.Add(inst); + workListSet.Add(inst); + } + } + IRInst* popWorkList() + { + UInt count = workList.Count(); + if(count != 0) + { + IRInst* inst = workList[count - 1]; + workList.FastRemoveAt(count - 1); + workListSet.Remove(inst); + return inst; + } + return nullptr; + } }; struct IRSpecContextBase @@ -4305,13 +4516,23 @@ namespace Slang IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } - IRSharedSpecContext::ClonedValueDictionary& getClonedValues() { return getShared()->clonedValues; } + // The current specialization environment to use. + IRSpecEnv* env = nullptr; + IRSpecEnv* getEnv() + { + // TODO: need to actually establish environments on contexts we create. + // + // Or more realistically we need to change the whole approach + // to specialization and cloning so that we don't try to share + // logic between two very different cases. + + + return env; + } // The IR builder to use for creating nodes IRBuilder* builder; - SubstitutionSet subst; - // A callback to be used when a value that is not registerd in `clonedValues` // is needed during cloning. This gives the subtype a chance to intercept // the operation and clone (or not) as needed. @@ -4319,24 +4540,6 @@ namespace Slang { return originalVal; } - - // A callback used to clone (or not) types. - virtual RefPtr maybeCloneType(Type* originalType) - { - return originalType; - } - - // A callback used to clone (or not) a declaration reference - virtual DeclRef maybeCloneDeclRef(DeclRef const& declRef) - { - return declRef; - } - - // A callback used to clone (or not) a Val - virtual RefPtr maybeCloneVal(Val* val) - { - return val; - } }; void registerClonedValue( @@ -4347,19 +4550,12 @@ namespace Slang if(!originalValue) return; - // Note: setting the entry direclty here rather than - // using `Add` or `AddIfNotExists` because we can conceivably - // clone the same value (e.g., a basic block inside a generic - // function) multiple times, and that is okay, and we really - // just need to keep track of the most recent value. - - // TODO: The same thing could potentially be handled more - // cleanly by having a notion of scoping for these cloned-value - // mappings, so that we register cloned values for things - // inside of a function to a temporary mapping that we - // throw away after the function is done. - - context->getClonedValues()[originalValue] = clonedValue; + // TODO: now that things are scoped using environments, we + // shouldn't be running into the cases where a value with + // the same key already exists. This should be changed to + // an `Add()` call. + // + context->getEnv()->clonedValues[originalValue] = clonedValue; } // Information on values to use when registering a cloned value @@ -4425,6 +4621,22 @@ namespace Slang } break; + case kIRDecorationOp_Semantic: + { + auto originalDecoration = (IRSemanticDecoration*)dd; + auto newDecoration = context->builder->addDecoration(clonedValue); + newDecoration->semanticName = originalDecoration->semanticName; + } + break; + + case kIRDecorationOp_InterpolationMode: + { + auto originalDecoration = (IRInterpolationModeDecoration*)dd; + auto newDecoration = context->builder->addDecoration(clonedValue); + newDecoration->mode = originalDecoration->mode; + } + break; + default: // Don't clone any decorations we don't understand. break; @@ -4435,46 +4647,37 @@ namespace Slang clonedValue->sourceLoc = originalValue->sourceLoc; } + // We use an `IRSpecContext` for the case where we are cloning + // code from one or more input modules to create a "linked" output + // module. Along the way, we will resolve profile-specific functions + // to the best definition for a given target. + // struct IRSpecContext : IRSpecContextBase { // Override the "maybe clone" logic so that we always clone virtual IRInst* maybeCloneValue(IRInst* originalVal) override; - - // Override teh "maybe clone" logic so that we carefully - // clone any IR proxy values inside substitutions - virtual DeclRef maybeCloneDeclRef(DeclRef const& declRef) override; - - virtual RefPtr maybeCloneType(Type* originalType) override; - virtual RefPtr maybeCloneVal(Val* val) override; }; IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal); - RefPtr cloneSubstitutions( - IRSpecContext* context, - Substitutions* subst); - RefPtr IRSpecContext::maybeCloneType(Type* originalType) - { - return originalType->Substitute(subst).As(); - } - - RefPtr IRSpecContext::maybeCloneVal(Val * val) - { - return val->Substitute(subst); - } + IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue); + IRType* cloneType( + IRSpecContextBase* context, + IRType* originalType); IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) { - switch (originalValue->op) + if (auto globalValue = as(originalValue)) { - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_Func: - case kIROp_witness_table: - return cloneGlobalValue(this, (IRGlobalValue*) originalValue); + return cloneGlobalValue(this, globalValue); + } + switch (originalValue->op) + { case kIROp_boolConst: { IRConstant* c = (IRConstant*)originalValue; @@ -4486,70 +4689,43 @@ namespace Slang case kIROp_IntLit: { IRConstant* c = (IRConstant*)originalValue; - return builder->getIntValue(c->type, c->u.intVal); + return builder->getIntValue(cloneType(this, c->getDataType()), c->u.intVal); } break; case kIROp_FloatLit: { IRConstant* c = (IRConstant*)originalValue; - return builder->getFloatValue(c->type, c->u.floatVal); + return builder->getFloatValue(cloneType(this, c->getDataType()), c->u.floatVal); } break; - case kIROp_decl_ref: + default: { - IRDeclRef* od = (IRDeclRef*)originalValue; - auto newDeclRef = od->declRef; + // In the deafult case, assume that we have some sort of "hoistable" + // instruction that requires us to create a clone of it. - // if the declRef is one of the __generic_param decl being substituted by subst - // return the substituted decl - if (subst.globalGenParamSubstitutions) + UInt argCount = originalValue->getOperandCount(); + IRInst* clonedValue = createInstWithTrailingArgs( + builder, + originalValue->op, + cloneType(this, originalValue->getFullType()), + 0, nullptr, + argCount, nullptr); + registerClonedValue(this, clonedValue, originalValue); + for (UInt aa = 0; aa < argCount; ++aa) { - int diff = 0; - newDeclRef = od->declRef.SubstituteImpl(subst, &diff); - for (auto globalGenSubst = subst.globalGenParamSubstitutions; globalGenSubst; globalGenSubst = globalGenSubst->outer) - { - if (!globalGenSubst) - continue; - if (newDeclRef.getDecl() == globalGenSubst->paramDecl) - return builder->getTypeVal(globalGenSubst->actualType.As()); - else if (auto genConstraint = newDeclRef.As()) - { - // a decl-ref to GenericTypeConstraintDecl as a result of - // referencing a generic parameter type should be replaced with - // the actual witness table - if (genConstraint.getDecl()->ParentDecl == globalGenSubst->paramDecl) - { - // find the witness table from subst - for (auto witness : globalGenSubst->witnessTables) - { - if (witness.Key->EqualsVal(GetSup(genConstraint))) - { - auto proxyVal = witness.Value.As(); - SLANG_ASSERT(proxyVal); - return proxyVal->inst.get(); - } - } - } - } - } + IRInst* originalArg = originalValue->getOperand(aa); + IRInst* clonedArg = cloneValue(this, originalArg); + clonedValue->getOperands()[aa].init(clonedValue, clonedArg); } - auto declRef = maybeCloneDeclRef(newDeclRef); - return builder->getDeclRefVal(declRef); - } - break; - case kIROp_TypeType: - { - IRInst* od = (IRInst*)originalValue; - int ioDiff = 0; - auto newType = od->type->SubstituteImpl(subst, &ioDiff); - return builder->getTypeVal(newType.As()); + cloneDecorations(this, clonedValue, originalValue); + + addHoistableInst(builder, clonedValue); + + return clonedValue; } break; - default: - SLANG_UNEXPECTED("no value registered for IR value"); - UNREACHABLE_RETURN(nullptr); } } @@ -4557,102 +4733,41 @@ namespace Slang IRSpecContextBase* context, IRInst* originalValue); - RefPtr cloneSubstitutionArg( - IRSpecContext* context, - Val* val) + // Find a pre-existing cloned value, or return null if none is available. + IRInst* findClonedValue( + IRSpecContextBase* context, + IRInst* originalValue) { - if (auto proxyVal = dynamic_cast(val)) - { - auto newIRVal = cloneValue(context, proxyVal->inst.get()); - - RefPtr newProxyVal = new IRProxyVal(); - newProxyVal->inst.init(nullptr, newIRVal); - return newProxyVal; - } - else if (auto type = dynamic_cast(val)) - { - return context->maybeCloneType(type); - } - else + IRInst* clonedValue = nullptr; + for (auto env = context->getEnv(); env; env = env->parent) { - return context->maybeCloneVal(val); + if (env->clonedValues.TryGetValue(originalValue, clonedValue)) + { + return clonedValue; + } } - } - - RefPtr cloneGenericSubst(IRSpecContext* context, GenericSubstitution* genSubst) - { - if (!genSubst) - return nullptr; - RefPtr newSubst = new GenericSubstitution(); - newSubst->outer = cloneGenericSubst(context, genSubst->outer); - newSubst->genericDecl = genSubst->genericDecl; - - for (auto arg : genSubst->args) - { - auto newArg = cloneSubstitutionArg(context, arg); - newSubst->args.Add(newArg); - } - return newSubst; + return nullptr; } - RefPtr cloneGlobalGenericSubst(IRSpecContext* context, GlobalGenericParamSubstitution* subst) + IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue) { - if (!subst) + if (!originalValue) return nullptr; - auto newSubst = new GlobalGenericParamSubstitution(); - newSubst->actualType = subst->actualType; - newSubst->paramDecl = subst->paramDecl; - newSubst->witnessTables = subst->witnessTables; - newSubst->outer = cloneGlobalGenericSubst(context, subst->outer); - return newSubst; - } - - SubstitutionSet cloneSubstitutions( - IRSpecContext* context, - SubstitutionSet subst) - { - SubstitutionSet rs; - if (!subst) - return rs; - rs.genericSubstitutions = cloneGenericSubst(context, subst.genericSubstitutions); - rs.globalGenParamSubstitutions = cloneGlobalGenericSubst(context, subst.globalGenParamSubstitutions); - if (auto thisSubst = subst.thisTypeSubstitution) - { - RefPtr newSubst = new ThisTypeSubstitution(); - newSubst->sourceType = thisSubst->sourceType; - rs.thisTypeSubstitution = newSubst; - } - return rs; - } - DeclRef IRSpecContext::maybeCloneDeclRef(DeclRef const& declRef) - { - // Un-specialized decl? Nothing to do. - if (!declRef.substitutions) - return declRef; - - DeclRef newDeclRef = declRef; - - // Scan through substitutions and clone as needed. - // - // TODO: this is wasteful since we clone *everything* - newDeclRef.substitutions = cloneSubstitutions(this, declRef.substitutions); + if (IRInst* clonedValue = findClonedValue(context, originalValue)) + return clonedValue; - return newDeclRef; + return context->maybeCloneValue(originalValue); } - IRInst* cloneValue( + IRType* cloneType( IRSpecContextBase* context, - IRInst* originalValue) + IRType* originalType) { - IRInst* clonedValue = nullptr; - if (context->getClonedValues().TryGetValue(originalValue, clonedValue)) - { - return clonedValue; - } - - return context->maybeCloneValue(originalValue); + return (IRType*)cloneValue(context, originalType); } IRInst* maybeCloneValueWithMangledName( @@ -4670,50 +4785,19 @@ namespace Slang } return cloneValue(context, originalValue); } - - void cloneInst( + + IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues); + + IRInst* cloneInst( IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalInst) + IRBuilder* builder, + IRInst* originalInst) { - switch (originalInst->op) - { - // TODO: are there any instruction types that need to be handled - // specially here? That would be anything that has more state - // than is visible in its operand list... - case 0: // nothing yet - default: - { - // The common case is that we just need to construct a cloned - // instruction with the right number of operands, intialize - // it, and then add it to the sequence. - UInt argCount = originalInst->getOperandCount(); - IRInst* clonedInst = createInstWithTrailingArgs( - builder, originalInst->op, - context->maybeCloneType(originalInst->type), - 0, nullptr, - argCount, nullptr); - registerClonedValue(context, clonedInst, originalInst); - auto oldBuilder = context->builder; - context->builder = builder; - for (UInt aa = 0; aa < argCount; ++aa) - { - IRInst* originalArg = originalInst->getOperand(aa); - IRInst* clonedArg; - if (originalArg->op == kIROp_witness_table) - clonedArg = cloneGlobalValueWithMangledName((IRSpecContext*)context, - ((IRGlobalValue*)originalArg)->mangledName, (IRGlobalValue*)originalArg); - else - clonedArg = cloneValue(context, originalArg); - clonedInst->getOperands()[aa].init(clonedInst, clonedArg); - } - builder->addInst(clonedInst); - context->builder = oldBuilder; - cloneDecorations(context, clonedInst, originalInst); - } - - break; - } + return cloneInst(context, builder, originalInst, originalInst); } void cloneGlobalValueWithCodeCommon( @@ -4722,17 +4806,18 @@ namespace Slang IRGlobalValueWithCode* originalValue); IRGlobalVar* cloneGlobalVarImpl( - IRSpecContext* context, - IRGlobalVar* originalVar, + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalVar* originalVar, IROriginalValuesForClone const& originalValues) { - auto clonedVar = context->builder->createGlobalVar( - context->maybeCloneType(originalVar->getDataType()->getValueType())); + auto clonedVar = builder->createGlobalVar( + cloneType(context, originalVar->getDataType()->getValueType())); if(auto rate = originalVar->getRate() ) { - clonedVar->type = context->builder->getSession()->getRateQualifiedType( - rate, clonedVar->type); + clonedVar->setFullType(builder->getRateQualifiedType( + rate, clonedVar->getFullType())); } registerClonedValue(context, clonedVar, originalValues); @@ -4745,7 +4830,7 @@ namespace Slang VarLayout* layout = nullptr; if (context->globalVarLayouts.TryGetValue(mangledName, layout)) { - context->builder->addLayoutDecoration(clonedVar, layout); + builder->addLayoutDecoration(clonedVar, layout); } // Clone any code in the body of the variable, since this @@ -4759,11 +4844,13 @@ namespace Slang } IRGlobalConstant* cloneGlobalConstantImpl( - IRSpecContext* context, - IRGlobalConstant* originalVal, + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalConstant* originalVal, IROriginalValuesForClone const& originalValues) { - auto clonedVal = context->builder->createGlobalConstant(context->maybeCloneType(originalVal->getFullType())); + auto clonedVal = builder->createGlobalConstant( + cloneType(context, originalVal->getFullType())); registerClonedValue(context, clonedVal, originalValues); auto mangledName = originalVal->mangledName; @@ -4781,48 +4868,111 @@ namespace Slang return clonedVal; } - IRWitnessTable* cloneWitnessTableImpl( - IRSpecContextBase* context, - IRWitnessTable* originalTable, - IROriginalValuesForClone const& originalValues, - IRWitnessTable* dstTable = nullptr, - bool registerValue = true) + IRGeneric* cloneGenericImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGeneric* originalVal, + IROriginalValuesForClone const& originalValues) { - auto clonedTable = dstTable ? dstTable : context->builder->createWitnessTable(); - if (registerValue) - registerClonedValue(context, clonedTable, originalValues); + auto clonedVal = builder->emitGeneric(); + registerClonedValue(context, clonedVal, originalValues); - auto mangledName = originalTable->mangledName; - - clonedTable->mangledName = mangledName; - clonedTable->genericDecl = originalTable->genericDecl; - clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef; - clonedTable->supTypeDeclRef = originalTable->supTypeDeclRef; - cloneDecorations(context, clonedTable, originalTable); + auto mangledName = originalVal->mangledName; + clonedVal->mangledName = mangledName; - // Clone the entries in the witness table as well - for(auto originalEntry : originalTable->getEntries() ) - { - auto clonedKey = cloneValue(context, originalEntry->requirementKey.get()); - - // if a global val with the mangled name already exists, don't clone again - auto clonedVal = maybeCloneValueWithMangledName(context, (IRGlobalValue*)(originalEntry->satisfyingVal.get())); + cloneDecorations(context, clonedVal, originalVal); - /*auto clonedEntry = */context->builder->createWitnessTableEntry( - clonedTable, - clonedKey, - clonedVal); - } + // Clone any code in the body of the generic, since this + // computes its result value. + cloneGlobalValueWithCodeCommon( + context, + clonedVal, + originalVal); - return clonedTable; + return clonedVal; } - IRWitnessTable* cloneWitnessTableWithoutRegistering( + void cloneSimpleGlobalValueImpl( + IRSpecContextBase* context, + IRGlobalValue* originalInst, + IROriginalValuesForClone const& originalValues, + IRGlobalValue* clonedInst, + bool registerValue = true) + { + if (registerValue) + registerClonedValue(context, clonedInst, originalValues); + + auto mangledName = originalInst->mangledName; + clonedInst->mangledName = mangledName; + + cloneDecorations(context, clonedInst, originalInst); + + // Set up an IR builder for inserting into the inst + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedInst); + + // Clone any children of the instruction + for (auto child : originalInst->getChildren()) + { + cloneInst(context, builder, child); + } + } + + IRStructKey* cloneStructKeyImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructKey* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->createStructKey(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; + } + + IRGlobalGenericParam* cloneGlobalGenericParamImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalGenericParam* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->emitGlobalGenericParam(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; + } + + + IRWitnessTable* cloneWitnessTableImpl( IRSpecContextBase* context, + IRBuilder* builder, + IRWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues, + IRWitnessTable* dstTable = nullptr, + bool registerValue = true) + { + auto clonedTable = dstTable ? dstTable : builder->createWitnessTable(); + cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); + return clonedTable; + } + + IRWitnessTable* cloneWitnessTableWithoutRegistering( + IRSpecContextBase* context, + IRBuilder* builder, IRWitnessTable* originalTable, IRWitnessTable* dstTable = nullptr) { - return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable, false); + return cloneWitnessTableImpl(context, builder, originalTable, IROriginalValuesForClone(), dstTable, false); + } + + IRStructType* cloneStructTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructType* originalStruct, + IROriginalValuesForClone const& originalValues) + { + auto clonedStruct = builder->createStructType(); + cloneSimpleGlobalValueImpl(context, originalStruct, originalValues, clonedStruct); + return clonedStruct; } void cloneGlobalValueWithCodeCommon( @@ -4887,11 +5037,14 @@ namespace Slang } - void checkIRDuplicate(IRParentInst* moduleInst, Name* mangledName) + void checkIRDuplicate(IRInst* inst, IRParentInst* moduleInst, Name* mangledName) { #ifdef _DEBUG for (auto child : moduleInst->getChildren()) { + if (child == inst) + continue; + if (child->op == kIROp_Func) { auto extName = ((IRGlobalValue*)child)->mangledName; @@ -4902,6 +5055,7 @@ namespace Slang } } #else + SLANG_UNREFERENCED_PARAMETER(inst); SLANG_UNREFERENCED_PARAMETER(moduleInst); SLANG_UNREFERENCED_PARAMETER(mangledName); #endif @@ -4915,9 +5069,7 @@ namespace Slang { // First clone all the simple properties. clonedFunc->mangledName = originalFunc->mangledName; - clonedFunc->genericDecls = originalFunc->genericDecls; - clonedFunc->specializedGenericLevel = originalFunc->specializedGenericLevel; - clonedFunc->type = context->maybeCloneType(originalFunc->type); + clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); cloneDecorations(context, clonedFunc, originalFunc); @@ -4930,10 +5082,9 @@ namespace Slang // it needs to follow its dependencies. // // TODO: This isn't really a good requirement to place on the IR... - clonedFunc->removeFromParent(); + clonedFunc->moveToEnd(); if (checkDuplicate) - checkIRDuplicate(context->getModule()->getModuleInst(), clonedFunc->mangledName); - clonedFunc->insertAtEnd(context->getModule()->getModuleInst()); + checkIRDuplicate(clonedFunc, context->getModule()->getModuleInst(), clonedFunc->mangledName); } IRFunc* specializeIRForEntryPoint( @@ -5072,17 +5223,51 @@ namespace Slang return result; } + IRInst* findGenericReturnVal(IRGeneric* generic) + { + auto lastBlock = generic->getLastBlock(); + if (!lastBlock) + return nullptr; + + auto returnInst = as(lastBlock->getTerminator()); + if (!returnInst) + return nullptr; + + auto val = returnInst->getVal(); + return val; + } + bool isDefinition( - IRGlobalValue* val) + IRGlobalValue* inVal) { + IRInst* val = inVal; + // unwrap any generic declarations to see + // the value they return. + for(;;) + { + auto genericInst = as(val); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + val = returnVal; + } + switch (val->op) { - case kIROp_witness_table: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_WitnessTable: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: case kIROp_Func: + case kIROp_Generic: return ((IRParentInst*)val)->getFirstChild() != nullptr; + case kIROp_StructType: + return true; + default: return false; } @@ -5146,51 +5331,92 @@ namespace Slang } IRFunc* cloneFuncImpl( - IRSpecContext* context, - IRFunc* originalFunc, + IRSpecContextBase* context, + IRBuilder* builder, + IRFunc* originalFunc, IROriginalValuesForClone const& originalValues) { - auto clonedFunc = context->builder->createFunc(); + auto clonedFunc = builder->createFunc(); registerClonedValue(context, clonedFunc, originalValues); cloneFunctionCommon(context, clonedFunc, originalFunc); return clonedFunc; } - // Directly clone a global value, based on a single definition/declaration, `originalVal`. - // The symbol `sym` will thread together other declarations of the same value, and - // we will register the new value as the cloned version of all of those. - IRGlobalValue* cloneGlobalValueImpl( - IRSpecContext* context, - IRGlobalValue* originalVal, - IRSpecSymbol* sym) - { - if( !originalVal ) - { - SLANG_UNEXPECTED("cloning a null value"); - UNREACHABLE_RETURN(nullptr); - } - switch( originalVal->op ) + IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues) + { + switch (originalInst->op) { + // We need to special-case any instruction that is not + // allocated like an ordinary `IRInst` with trailing args. case kIROp_Func: - return cloneFuncImpl(context, (IRFunc*) originalVal, sym); + return cloneFuncImpl(context, builder, cast(originalInst), originalValues); - case kIROp_global_var: - return cloneGlobalVarImpl(context, (IRGlobalVar*)originalVal, sym); + case kIROp_GlobalVar: + return cloneGlobalVarImpl(context, builder, cast(originalInst), originalValues); - case kIROp_global_constant: - return cloneGlobalConstantImpl(context, (IRGlobalConstant*)originalVal, sym); + case kIROp_GlobalConstant: + return cloneGlobalConstantImpl(context, builder, cast(originalInst), originalValues); - case kIROp_witness_table: - return cloneWitnessTableImpl(context, (IRWitnessTable*)originalVal, sym); + case kIROp_WitnessTable: + return cloneWitnessTableImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_StructType: + return cloneStructTypeImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_Generic: + return cloneGenericImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_StructKey: + return cloneStructKeyImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_GlobalGenericParam: + return cloneGlobalGenericParamImpl(context, builder, cast(originalInst), originalValues); default: - SLANG_UNEXPECTED("unknown global value kind"); - UNREACHABLE_RETURN(nullptr); + break; } + // The common case is that we just need to construct a cloned + // instruction with the right number of operands, intialize + // it, and then add it to the sequence. + UInt argCount = originalInst->getOperandCount(); + IRInst* clonedInst = createInstWithTrailingArgs( + builder, originalInst->op, + cloneType(context, originalInst->getFullType()), + 0, nullptr, + argCount, nullptr); + registerClonedValue(context, clonedInst, originalValues); + auto oldBuilder = context->builder; + context->builder = builder; + for (UInt aa = 0; aa < argCount; ++aa) + { + IRInst* originalArg = originalInst->getOperand(aa); + IRInst* clonedArg = cloneValue(context, originalArg); + clonedInst->getOperands()[aa].init(clonedInst, clonedArg); + } + builder->addInst(clonedInst); + context->builder = oldBuilder; + cloneDecorations(context, clonedInst, originalInst); + + return clonedInst; + } + + IRGlobalValue* cloneGlobalValueImpl( + IRSpecContext* context, + IRGlobalValue* originalInst, + IROriginalValuesForClone const& originalValues) + { + auto clonedValue = cloneInst(context, &context->shared->builderStorage, originalInst, originalValues); + clonedValue->moveToEnd(); + return cast(clonedValue); } + // Clone a global value, which has the given `mangledName`. // The `originalVal` is a known global IR value with that name, if one is available. // (It is okay for this parameter to be null). @@ -5202,7 +5428,7 @@ namespace Slang // If the global value being cloned is already in target module, don't clone // Why checking this? // When specializing a generic function G (which is already in target module), - // where G calls a normal function F (which is already in target module), + // where G calls a normal function F (which is already in target module), // then when we are making a copy of G via cloneFuncCommom(), it will recursively clone F, // however we don't want to make a duplicate of F in the target module. if (originalVal->getParent() == context->getModule()->getModuleInst()) @@ -5210,17 +5436,19 @@ namespace Slang // Check if we've already cloned this value, for the case where // an original value has already been established. - IRInst* clonedVal = nullptr; - if( originalVal && context->getClonedValues().TryGetValue(originalVal, clonedVal) ) + if (originalVal) { - return (IRGlobalValue*) clonedVal; + if (IRInst* clonedVal = findClonedValue(context, originalVal)) + { + return cast(clonedVal); + } } if(getText(mangledName).Length() == 0) { // If there is no mangled name, then we assume this is a local symbol, // and it can't possibly have multiple declarations. - return cloneGlobalValueImpl(context, originalVal, nullptr); + return cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone()); } // @@ -5236,7 +5464,7 @@ namespace Slang // This shouldn't happen! SLANG_UNEXPECTED("no matching values registered"); - UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); + UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone())); } // We will try to track the "best" declaration we can find. @@ -5256,12 +5484,15 @@ namespace Slang // Check if we've already cloned this value, for the case where // we didn't have an original value (just a name), but we've // now found a representative value. - if( !originalVal && context->getClonedValues().TryGetValue(bestVal, clonedVal) ) + if (!originalVal) { - return (IRGlobalValue*) clonedVal; + if (IRInst* clonedVal = findClonedValue(context, bestVal)) + { + return cast(clonedVal); + } } - return cloneGlobalValueImpl(context, bestVal, sym); + return cloneGlobalValueImpl(context, bestVal, IROriginalValuesForClone(sym)); } IRGlobalValue* cloneGlobalValueWithMangledName(IRSpecContext* context, Name* mangledName) @@ -5365,11 +5596,6 @@ namespace Slang ProgramLayout* programLayout, SubstitutionSet typeSubst); - RefPtr createGlobalGenericParamSubstitution( - EntryPointRequest * entryPointRequest, - ProgramLayout * programLayout, - IRSpecContext* context); - struct IRSpecializationState { ProgramLayout* programLayout; @@ -5382,8 +5608,16 @@ namespace Slang IRSharedSpecContext sharedContextStorage; IRSpecContext contextStorage; + IRSpecEnv globalEnv; + IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } IRSpecContext* getContext() { return &contextStorage; } + + IRSpecializationState() + { + contextStorage.env = &globalEnv; + } + ~IRSpecializationState() { newProgramLayout = nullptr; @@ -5429,19 +5663,27 @@ namespace Slang auto context = state->getContext(); context->shared = sharedContext; context->builder = &sharedContext->builderStorage; - // Create the GlobalGenericParamSubstitution for substituting global generic types - // into user-provided type arguments - auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context); - context->subst.globalGenParamSubstitutions = globalParamSubst; - - // now specailize the program layout using the substitution - RefPtr newProgramLayout = specializeProgramLayout(targetReq, programLayout, context->subst); + // Now specialize the program layout using the substitution + // + // TODO: The specialization of the layout is conceptually an AST-level operations, + // and shouldn't be done here in the IR at all. + // + RefPtr newProgramLayout = specializeProgramLayout( + targetReq, + programLayout, + SubstitutionSet(entryPointRequest->globalGenericSubst)); + + // TODO: we need to register the (IR-level) arguments of the global generic parameters as the + // substitutions for the generic parameters in the original IR. + + // applyGlobalGenericParamSubsitution(...); + state->newProgramLayout = newProgramLayout; // Next, we want to optimize lookup for layout infromation - // associated with global declarations, so that we can + // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) auto globalStructLayout = getGlobalStructLayout(newProgramLayout); for (auto globalVarLayout : globalStructLayout->fields) @@ -5453,7 +5695,7 @@ namespace Slang // for now, clone all unreferenced witness tables for (auto sym :context->getSymbols()) { - if (sym.Value->irGlobalValue->op == kIROp_witness_table) + if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } return state; @@ -5526,6 +5768,20 @@ namespace Slang // it might reference. auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout); + // HACK: right now the bindings for global generic parameters are coming in + // as part of the original IR module, and we need to make sure these get + // copied over, even if they aren't referenced. + // + for(auto inst : originalIRModule->getGlobalInsts()) + { + auto bindInst = as(inst); + if(!bindInst) + continue; + + cloneValue(context, bindInst); + } + + // TODO: *technically* we should consider the case where // we have global variables with initializers, since // these should get run whether or not the entry point @@ -5551,7 +5807,7 @@ namespace Slang break; } } - + struct IRGenericSpecContext : IRSpecContextBase { IRSpecContextBase* parent = nullptr; @@ -5560,383 +5816,69 @@ namespace Slang // Override the "maybe clone" logic so that we always clone virtual IRInst* maybeCloneValue(IRInst* originalVal) override; - - virtual RefPtr maybeCloneType(Type* originalType) override; - virtual RefPtr maybeCloneVal(Val* val) override; }; - // Convert a type-level value into an IR-level equivalent. - IRInst* getIRValue( - IRGenericSpecContext* context, - Val* val) + IRInst* IRGenericSpecContext::maybeCloneValue(IRInst* originalVal) { - if( auto subtypeWitness = dynamic_cast(val) ) - { - auto mangledName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness( - subtypeWitness->sub, - subtypeWitness->sup)); - RefPtr symbol; - - if (context->getSymbols().TryGetValue(mangledName, symbol)) - { - // Note: the symbols always come from the source module, - // not the destination module, so we may need to clone - // them if we are doing an initialize specialization pass. - return cloneValue(context, symbol->irGlobalValue); - } - else - { - // we don't have the required witness table yet, - // try to emit a specialize instruction to get one - auto subDeclRef = subtypeWitness->sub->AsDeclRefType(); - auto subDeclRefGen = DeclRef(subDeclRef->declRef.decl, - createDefaultSubstitutions(context->builder->getSession(), subDeclRef->declRef.decl)); - - auto genericName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness( - subDeclRefGen, - subtypeWitness->sup)); - if (context->getSymbols().TryGetValue(genericName, symbol)) - { - auto clonedSymbol = cloneValue(context, symbol->irGlobalValue); - auto specInst = context->builder->emitSpecializeInst(subtypeWitness->sup, clonedSymbol, subDeclRef->declRef); - return specInst; - } - else - { - SLANG_UNEXPECTED("witness table not exist"); - UNREACHABLE_RETURN(nullptr); - } - } - } - else if (auto intVal = dynamic_cast(val)) - { - return context->builder->getIntValue(context->shared->originalModule->session->getBuiltinType(BaseType::Int), intVal->value); - } - else if (auto proxyVal = dynamic_cast(val)) + if (parent) { - // The type-level value actually references an IR-level value, - // so we need to make sure to emit as if we were referencing - // the pointed-to value and not the proxy type-level `Val` - // instead. - - return context->maybeCloneValue(proxyVal->inst.get()); + return parent->maybeCloneValue(originalVal); } else { - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(nullptr); + return originalVal; } } - IRInst* getSubstValue( - IRGenericSpecContext* context, - DeclRef declRef) + // See the work list for the generic spec context with + // every relevant instruction from `inst` through its + // descendents. + void addToSpecializationWorkListRec( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) { - auto subst = context->subst.genericSubstitutions; - SLANG_ASSERT(subst); - auto genericDecl = subst->genericDecl; - - UInt orinaryParamCount = 0; - for( auto mm : genericDecl->Members ) + if(auto genericInst = as(inst)) { - if(mm.As()) - orinaryParamCount++; - else if(mm.As()) - orinaryParamCount++; + // We do *not* consider generics, or instructions nested under them. + return; } - - if( auto constraintDeclRef = declRef.As() ) + else if(auto parentInst = as(inst)) { - // We have a constraint, but we need to find its index in the - // argument list of the substitutions. - UInt constraintIndex = 0; - bool found = false; - for( auto cd : genericDecl->getMembersOfType() ) - { - if( cd.Ptr() == constraintDeclRef.getDecl() ) - { - found = true; - break; - } + // For a parent instruction, we will scan through its contents, + // since that will be where the `specialize` instructions are - constraintIndex++; - } - assert(found); - - UInt argIndex = orinaryParamCount + constraintIndex; - assert(argIndex < subst->args.Count()); - - return getIRValue(context, subst->args[argIndex]); - } - else if (auto valDeclRef = declRef.As()) - { - // We have a constraint, but we need to find its index in the - // argument list of the substitutions. - UInt argIdx = 0; - bool found = false; - for (auto cd : genericDecl->Members) + for(auto child : parentInst->children) { - if (cd.Ptr() == valDeclRef.getDecl()) - { - found = true; - break; - } - if (cd.As()) - argIdx++; - else if (cd.As()) - argIdx++; + addToSpecializationWorkListRec(sharedContext, child); } - assert(found); - - assert(argIdx < subst->args.Count()); - - return getIRValue(context, subst->args[argIdx]); } else { - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(nullptr); - } - } - - IRInst* IRGenericSpecContext::maybeCloneValue(IRInst* originalVal) - { - switch( originalVal->op ) - { - case kIROp_decl_ref: - { - auto declRefVal = (IRDeclRef*) originalVal; - auto declRef = declRefVal->declRef; - auto genSubst = subst.genericSubstitutions; - SLANG_ASSERT(genSubst); - // We may have a direct reference to one of the parameters - // of the generic we are specializing, and in that case - // we nee to translate it over to the equiavalent of - // the `Val` we have been given. - if(declRef.getDecl()->ParentDecl == genSubst->genericDecl && - (declRef.As() || declRef.As()|| - declRef.As())) - { - if (auto substVal = getSubstValue(this, declRef)) - return substVal; - } - int diff = 0; - auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff); - if(!diff) - return originalVal; - - return builder->getDeclRefVal(substDeclRef); - } - break; - - default: - if (parent) - { - return parent->maybeCloneValue(originalVal); - } - else - { - return originalVal; - } - } - } - - RefPtr IRGenericSpecContext::maybeCloneType(Type* originalType) - { - return originalType->Substitute(subst).As(); - } - - RefPtr IRGenericSpecContext::maybeCloneVal(Val * val) - { - return val->Substitute(subst); - } - - // Given a list of substitutions, return the inner-most - // generic substitution in the list, or NULL if there - // are no generic substitutions. - RefPtr getInnermostGenericSubst( - SubstitutionSet inSubst) - { - return inSubst.genericSubstitutions; - } - - RefPtr getInnermostGenericDecl( - Decl* inDecl) - { - auto decl = inDecl; - while( decl ) - { - GenericDecl* genericDecl = dynamic_cast(decl); - if(genericDecl) - return genericDecl; - - decl = decl->ParentDecl; + // Default case: consider this instruction for specialization. + sharedContext->addToWorkList(inst); } - return nullptr; } - // This function takes a list of substitutions that we'd - // like to apply, but which might apply to a different - // declaration in cases where we have got target-specific - // overloads in the mix, and produces a new set of - // substitutiosn without this issue. - RefPtr cloneSubstitutionsForSpecialization( - IRSharedSpecContext* sharedContext, - RefPtr oldSubst, - Decl* newDecl) - { - // We will "peel back" layers of substitutions until - // we find our first generic subsitution. - auto oldGenericSubst = oldSubst; - if(!oldGenericSubst) - return nullptr; - - auto innerGenericName = oldGenericSubst->genericDecl->inner->getName(); - - // We will also peel back layers of declarations until - // we find our first generic decl. - GenericDecl* newGenericDecl = nullptr; - - for (Decl* d = newDecl; d; d = d->ParentDecl) - { - if (auto gd = dynamic_cast(d)) - { - if (gd->inner->getName() == innerGenericName) - { - newGenericDecl = gd; - break; - } - } - } - - if( !newGenericDecl ) - { - if(auto gd = dynamic_cast(newDecl)) - { - if( auto ed = gd->inner.As() ) - { - // TODO: we should confirm that it is an extension for the correct type... - - newGenericDecl = gd; - } - } - } - - SLANG_ASSERT(newGenericDecl); - - RefPtr newSubst = new GenericSubstitution(); - newSubst->genericDecl = newGenericDecl; - newSubst->args = oldGenericSubst->args; - - newSubst->outer = cloneSubstitutionsForSpecialization( - sharedContext, - oldGenericSubst->outer, - newGenericDecl->ParentDecl); - - return newSubst; - } - - IRFunc* getSpecializedFunc( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRFunc* genericFunc, - DeclRef specDeclRef); - - IRWitnessTable* specializeWitnessTable( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRWitnessTable* originalTable, - DeclRef specDeclRef, - IRWitnessTable* dstTable) + IRInst* specializeGeneric( + IRSharedGenericSpecContext* sharedContext, + IRSpecContextBase* parentContext, + IRGeneric* genericVal, + IRSpecialize* specializeInst) { // First, we want to see if an existing specialization // has already been made. To do that we will need to - // compute the mangled name of the specialized function, + // compute the mangled name of the specialized value, // so that we can look for existing declarations. - String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef), - specDeclRef.Substitute(originalTable->supTypeDeclRef)); - - if (dstTable && getText(dstTable->mangledName).Length()) - specializedMangledName = getText(dstTable->mangledName); - - // TODO: This is a terrible linear search, and we should - // avoid it by building a dictionary ahead of time, - // as is being done for the `IRSpecContext` used above. - // We can probalby use the same basic context, actually. - if (!dstTable) - { - auto module = sharedContext->module; - for(auto ii : module->getGlobalInsts()) - { - auto gv = as(ii); - if (!gv) - continue; - - if (getText(gv->mangledName) == specializedMangledName) - return (IRWitnessTable*)gv; - } - } - RefPtr newSubst = cloneSubstitutionsForSpecialization( - sharedContext, - specDeclRef.substitutions.genericSubstitutions, - originalTable->genericDecl); - - IRGenericSpecContext context; - context.shared = sharedContext; - context.parent = parentContext; - context.builder = &sharedContext->builderStorage; - context.subst = specDeclRef.substitutions; - context.subst.genericSubstitutions = newSubst; - // TODO: other initialization is needed here... - - auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable, dstTable); - - // Set up the clone to recognize that it is no longer generic - specTable->mangledName = context.getModule()->session->getNameObj(specializedMangledName); - specTable->genericDecl = nullptr; - - // Specialization of witness tables should trigger cascading specializations - // of involved functions. - for (auto entry : specTable->getEntries()) - { - if (entry->satisfyingVal.get()->op == kIROp_Func) - { - IRFunc* func = (IRFunc*)entry->satisfyingVal.get(); - auto specFunc = getSpecializedFunc(sharedContext, parentContext, func, specDeclRef); - entry->satisfyingVal.set(specFunc); - insertGlobalValueSymbol(sharedContext, specFunc); - } - - } - // We also need to make sure that we register this specialized - // function under its mangled name, so that later lookup - // steps will find it. - insertGlobalValueSymbol(sharedContext, specTable); - - return specTable; - } - - IRFunc* getSpecializedFunc( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRFunc* genericFunc, - DeclRef specDeclRef) - { - // First, we want to see if an existing specialization - // has already been made. To do that we will need to - // compute the mangled name of the specialized function, - // so that we can look for existing declarations. - String specMangledName; - if (genericFunc->getGenericDecl() == specDeclRef.decl) - specMangledName = getMangledName(specDeclRef); - else - specMangledName = mangleSpecializedFuncName(getText(genericFunc->mangledName), specDeclRef.substitutions); + String specMangledName = mangleSpecializedFuncName(getText(genericVal->mangledName), specializeInst); auto specMangledNameObj = sharedContext->module->session->getNameObj(specMangledName); + + // Now look up an existing symbol with a matching name RefPtr symb; if (sharedContext->symbols.TryGetValue(specMangledNameObj, symb)) { - return (IRFunc*)(symb->irGlobalValue); + return symb->irGlobalValue; } + // TODO: This is a terrible linear search, and we should // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. @@ -5948,104 +5890,285 @@ namespace Slang continue; if (gv->mangledName == specMangledNameObj) - return (IRFunc*) gv; + return gv; } // If we get to this point, then we need to construct a - // new `IRFunc` to represent the result of specialization. + // new IR value to represent the result of specialization. - // The substitutions we are applying might have been created - // using a different overload of a target-specific function, - // so we need to create a dummy substitution here, to make - // sure it used the correct generic. - RefPtr newSubst = cloneSubstitutionsForSpecialization( - sharedContext, - specDeclRef.substitutions.genericSubstitutions, - genericFunc->getGenericDecl()); + // We need to establish a new mapping from inst->inst to + // handle the specialization, because we don't want the + // clones we register in this pass to cause confusion + // in later steps that might clone the same code. - if (!newSubst) - return genericFunc; + IRSpecEnv env; + env.parent = &sharedContext->globalEnv; + if (parentContext) + { + env.parent = parentContext->getEnv(); + } + + // The result of specialization should be inserted + // into the global scope, at the same location as + // the original generic. + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(genericVal); IRGenericSpecContext context; context.shared = sharedContext; context.parent = parentContext; - context.builder = &sharedContext->builderStorage; - context.subst = specDeclRef.substitutions; - context.subst.genericSubstitutions = newSubst; + context.builder = builder; + context.env = &env; - // TODO: other initialization is needed here... + // Register the arguments of the `specialize` instruction to be used + // as the "cloned" value for each of the parameters of the generic. + // + UInt argCounter = 0; + for (auto param = genericVal->getFirstParam(); param; param = param->getNextParam()) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < specializeInst->getArgCount()); - auto specFunc = cloneSimpleFuncWithoutRegistering(&context, genericFunc); + IRInst* arg = specializeInst->getArg(argIndex); - specFunc->mangledName = context.getModule()->session->getNameObj(specMangledName); - - // reduce specialized generic level by 1 - if (specFunc->specializedGenericLevel >= 0) - specFunc->specializedGenericLevel--; + registerClonedValue(&context, arg, param); + } - // Put the function into the global sequence right after - // the function it specializes. - // - // TODO: This shouldn't be needed, if we introduce a sorting - // step before we emit code. - //specFunc->removeFromParent(); - //specFunc->insertAfter(genericFunc); + // Okay, now we want to run through the body of the generic + // and clone stuff into the parent scope (which had + // better be the global scope). + for (auto bb : genericVal->getBlocks()) + { + // We expect a generic to only ever contain a single block. + SLANG_ASSERT(bb == genericVal->getFirstBlock()); - // At this point we've created a new non-generic function, - // which means we should add it to our work list for - // subsequent processing. - if (specFunc->specializedGenericLevel == -1) - sharedContext->workList.Add(specFunc); + for (auto ii : bb->getChildren()) + { + // Skip parameters, since they were handled earlier. + if (auto param = as(ii)) + continue; - // We also need to make sure that we register this specialized - // function under its mangled name, so that later lookup - // steps will find it. - insertGlobalValueSymbol(sharedContext, specFunc); + // The last block of the generic is expected to end with + // a `return` instruction for the specialized value that + // comes out of the abstraction. + // + // We thus use that cloned value as the result of the + // specialization step. + if (auto returnValInst = as(ii)) + { + auto clonedResult = cloneValue(&context, returnValInst->getVal()); + if (auto clonedGlobalValue = as(clonedResult)) + { + clonedGlobalValue->mangledName = specMangledNameObj; + + // TODO: create a symbol for it and add it to the map. + } - return specFunc; + return clonedResult; + } + + // Otherwise, clone the instruction into the global scope + IRInst* clonedInst = cloneInst(&context, context.builder, ii); + + // Now that we've cloned the instruction to a location outside + // of a generic, we should consider whether it can now be specialized. + addToSpecializationWorkListRec(sharedContext, clonedInst); + } + } + + // If we reach this point, something went wrong, because we + // never encountered a `return` inside the body of the generic. + SLANG_UNEXPECTED("no return from generic"); + UNREACHABLE_RETURN(nullptr); } // Find the value in the given witness table that // satisfies the given requirement (or return // null if not found). IRInst* findWitnessVal( - IRWitnessTable* witnessTable, - DeclRef const& requirementDeclRef) + IRWitnessTable* witnessTable, + IRInst* requirementKey) { // For now we will do a dumb linear search for( auto entry : witnessTable->getEntries() ) { - // We expect the key on the entry to be a decl-ref, - // but lets go ahead and check, just to be sure. - auto requirementKey = entry->requirementKey.get(); - if(requirementKey->op != kIROp_decl_ref) + // If the keys matched, then we use the value from this entry. + if (requirementKey == entry->requirementKey.get()) + { + auto satisfyingVal = entry->satisfyingVal.get(); + return satisfyingVal; + } + } + + // No matching entry found. + return nullptr; + } + + static bool canSpecializeGeneric( + IRGeneric* generic) + { + IRGeneric* g = generic; + for(;;) + { + auto val = findGenericReturnVal(g); + if(!val) + return false; + + if (auto nestedGeneric = as(val)) + { + // The outer generic returns an *inner* generic + // (so that multiple calls to `specialize` are + // needed to resolve it). We should look at + // what the nested generic returns to figure + // out whether specialization is allowed. + g = nestedGeneric; continue; - auto keyDeclRef = ((IRDeclRef*) requirementKey)->declRef; + } + + // We've found the leaf value that will be produced after + // all of the specialization is done. Now we want to know + // if that is a value suitable for actually specializing - // If the keys don't match, continue with the next entry. - if (!keyDeclRef.Equals(requirementDeclRef)) + if (auto globalValue = as(val)) { - // requirementDeclRef may be pointing to the inner decl of a generic decl - // in this case we compare keyDeclRef against the parent decl of requiredDeclRef - if (auto genRequiredDeclRef = requirementDeclRef.GetParent().As()) + if (isDefinition(globalValue)) + return true; + return false; + } + else + { + // There might be other cases with a declaration-vs-definition + // thing that we need to handle. + + return true; + } + } + } + + // Add any instruction that uses `inst` to the work list, + // so that it can be evaluated (or re-evaluated) for specialization. + void addUsesToWorkList( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + for(auto u = inst->firstUse; u; u = u->nextUse) + { + sharedContext->addToWorkList(u->getUser()); + } + } + + void specializeGenericsForInst( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + switch(inst->op) + { + default: + // The default behavior is to do nothing. + // An instruction is specialize-able once its operands + // are specialized, and after that it is also safe + // to consider the instruction specialized. + break; + + case kIROp_Specialize: + { + // We have a `specialize` instruction, so lets see + // whether we have an opportunity to perform the + // specialization here and now. + IRSpecialize* specInst = cast(inst); + + // Look at the base of the `specialize`, and see if + // it directly names a generic, so that we can apply + // specialization here and now. + auto baseVal = specInst->getBase(); + if(auto genericVal = as(baseVal)) { - if (!keyDeclRef.Equals(genRequiredDeclRef)) + if (canSpecializeGeneric(genericVal)) { - continue; + // Okay, we have a candidate for specialization here. + // + // We will apply the specialization logic to the body of the generic, + // which will yield, e.g., a specialized `IRFunc`. + // + auto specializedVal = specializeGeneric(sharedContext, nullptr, genericVal, specInst); + // + // Then we will replace the use sites for the `specialize` + // instruction with uses of the specialized value. + // + addUsesToWorkList(sharedContext, specInst); + specInst->replaceUsesWith(specializedVal); + specInst->removeAndDeallocate(); } } - else - continue; } + break; + + case kIROp_lookup_interface_method: + { + // We have a `lookup_interface_method` instruction, + // so let's see whether it is a lookup in a known + // witness table. + IRLookupWitnessMethod* lookupInst = cast(inst); + + // We only want to deal with the case where the witness-table + // argument points to a concrete global table (and not, e.g., a + // `specialize` instruction that will yield a table) + auto witnessTable = as(lookupInst->witnessTable.get()); + if(!witnessTable) + break; - // If the keys matched, then we use the value from - // this entry. - auto satisfyingVal = entry->satisfyingVal.get(); - return satisfyingVal; + // Use the witness table to look up the value that + // satisfies the requirement. + auto requirementKey = lookupInst->getRequirementKey(); + auto satisfyingVal = findWitnessVal(witnessTable, requirementKey); + // We expect to always find something, but lets just + // be careful here. + if(!satisfyingVal) + break; + + // If we get through all of the above checks, then we + // have a (more) concrete method that implements the interface, + // and so we should dispatch to that directly, rather than + // use the `lookup_interface_method` instruction. + addUsesToWorkList(sharedContext, lookupInst); + lookupInst->replaceUsesWith(satisfyingVal); + lookupInst->removeAndDeallocate(); + } + break; } + } - // No matching entry found. - return nullptr; + static bool isInstSpecialized( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + // If an instruction is still on our work list, then + // it isn't specialized, and conversely we say that + // if it *isn't* on the work list, it must be specialized. + // + // Note: if we end up with bugs in this logic, we could + // maintain an explicit set of specialized insts instead. + // + return !sharedContext->workListSet.Contains(inst); + } + + static bool canSpecializeInst( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + // We can specialize an instruction once all its + // operands are specialized. + + UInt operandCount = inst->getOperandCount(); + for(UInt ii = 0; ii < operandCount; ++ii) + { + IRInst* operand = inst->getOperand(ii); + if(!isInstSpecialized(sharedContext, operand)) + return false; + } + return true; } // Go through the code in the module and try to identify @@ -6056,7 +6179,7 @@ namespace Slang IRModule* module, CodeGenTarget target) { - IRSharedSpecContext sharedContextStorage; + IRSharedGenericSpecContext sharedContextStorage; auto sharedContext = &sharedContextStorage; initializeSharedSpecContext( @@ -6066,351 +6189,127 @@ namespace Slang module, target); - // Our goal here is to find `specialize` instructions that - // can be replaced with references to a suitably sepcialized - // funciton. As a simplification, we will only consider `specialize` - // calls that are inside of non-generic functions, since we assume - // that these will allow us to fully specialize the referenced - // function. - // - // We start by building up a work list of non-generic functions. - for(auto ii : module->getGlobalInsts()) - { - auto gv = as(ii); - if (!gv) - continue; + auto moduleInst = module->getModuleInst(); - // Is it a function? If not, skip. - if(gv->op != kIROp_Func) + // First things first, let's deal with any bindings for global generic parameters. + for(auto inst : moduleInst->getChildren()) + { + auto bindInst = as(inst); + if(!bindInst) continue; - auto func = (IRFunc*) gv; - // Is it generic? If so, skip. - if(func->getGenericDecl()) - continue; + auto param = bindInst->getParam(); + auto val = bindInst->getVal(); - sharedContext->workList.Add(func); + param->replaceUsesWith(val); } - - // Build dictionary for witness tables - Dictionary witnessTables; - for(auto ii : module->getGlobalInsts()) { - auto gv = as(ii); - if (!gv) - continue; - - if (gv->op == kIROp_witness_table) - witnessTables.AddIfNotExists(gv->mangledName, (IRWitnessTable*)gv); - } - - // Now that we have our work list, we are going to - // process it until it goes empty. Along the way - // we may specialize a function and thus create - // a new non-generic function, and in that case - // we will add the new function to the work list. - auto& workList = sharedContext->workList; - while( auto count = workList.Count() ) - { - // We will process the last entry in the - // work list, which amounts to treating - // it like a stack when we have recursive - // specialization to perform. - auto func = workList[count-1]; - workList.RemoveAt(count-1); - - // We are going to go ahead and walk through - // all the instructions in this function, - // and look for `specialize` operations. - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + // Now we will do a second pass to clean up the + // generic parameters and their bindings. + IRInst* next = nullptr; + for(auto inst = moduleInst->getFirstChild(); inst; inst = next) { - // We need to be careful when iterating over the instructions, - // because we might end up removing the "current" instruction, - // so that accessing `ii->next` would crash. - IRInst* nextInst = nullptr; - for( auto ii = bb->getFirstInst(); ii; ii = nextInst ) - { - nextInst = ii->getNextInst(); + next = inst->getNextInst(); - // We want to handle both `specialize` instructions, - // which trigger specialization, and also `lookup_interface_method` - // instructions, which may allow us to "de-virtualize" - // calls. - - switch( ii->op ) - { - default: - // Most instructions are ones we don't care about here. - continue; - - case kIROp_specialize: - { - // We have a `specialize` instruction, so lets see - // whether we have an opportunity to perform the - // specialization here and now. - IRSpecialize* specInst = (IRSpecialize*) ii; - - // Now we extract the specialized decl-ref that will - // tell us how to specialize things. - auto specDeclRefVal = (IRDeclRef*)specInst->specDeclRefVal.get(); - auto specDeclRef = specDeclRefVal->declRef; - - // We need to specialize functions and witness tables - auto genericVal = specInst->genericVal.get(); - if (genericVal->op == kIROp_Func) - { - auto genericFunc = (IRFunc*)genericVal; - if (!genericFunc->getGenericDecl()) - continue; - - // Okay, we have a candidate for specialization here. - // - // We will first find or construct a specialized version - // of the callee funciton/ - auto specFunc = getSpecializedFunc(sharedContext, nullptr, genericFunc, specDeclRef); - // - // Then we will replace the use sites for the `specialize` - // instruction with uses of the specialized function. - // - specInst->replaceUsesWith(specFunc); - - specInst->removeAndDeallocate(); - } - else if (genericVal->op == kIROp_witness_table) - { - // specialize a witness table - auto originalTable = (IRWitnessTable*)genericVal; - auto specWitnessTable = specializeWitnessTable(sharedContext, nullptr, originalTable, specDeclRef, nullptr); - witnessTables.AddIfNotExists(specWitnessTable->mangledName, specWitnessTable); - specInst->replaceUsesWith(specWitnessTable); - specInst->removeAndDeallocate(); - } - } - break; - case kIROp_lookup_witness_table: - { - // try find concrete witness table from global scope - IRLookupWitnessTable* lookupInst = (IRLookupWitnessTable*)ii; - IRWitnessTable* witnessTable = nullptr; - auto srcDeclRef = ((IRDeclRef*)lookupInst->sourceType.get())->declRef; - auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.get())->declRef; - auto mangledName = module->session->getNameObj(getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef)); - witnessTables.TryGetValue(mangledName, witnessTable); - - if (!witnessTable) - { - // try specialize the witness table - auto genDeclRef = srcDeclRef; - genDeclRef.substitutions = createDefaultSubstitutions(module->session, genDeclRef.decl); - auto genName = module->session->getNameObj(getMangledNameForConformanceWitness(genDeclRef, interfaceDeclRef)); - IRWitnessTable* genTable = nullptr; - if (witnessTables.TryGetValue(genName, genTable)) - { - witnessTable = specializeWitnessTable(sharedContext, nullptr, genTable, srcDeclRef, nullptr); - witnessTables.AddIfNotExists(witnessTable->mangledName, witnessTable); - } - } - if (witnessTable) - { - lookupInst->replaceUsesWith(witnessTable); - lookupInst->removeAndDeallocate(); - } - } - break; - case kIROp_lookup_interface_method: - { - // We have a `lookup_interface_method` instruction, - // so let's see whether it is a lookup in a known - // witness table. - IRLookupWitnessMethod* lookupInst = (IRLookupWitnessMethod*) ii; - - // We only want to deal with the case where the witness-table - // argument points to a concrete global table. - auto witnessTableArg = lookupInst->witnessTable.get(); - if(witnessTableArg->op != kIROp_witness_table) - continue; - IRWitnessTable* witnessTable = (IRWitnessTable*)witnessTableArg; - - // We also need to be sure that the requirement we - // are trying to look up is identified via a decl-ref: - auto requirementArg = lookupInst->requirementDeclRef.get(); - if(requirementArg->op != kIROp_decl_ref) - continue; - auto requirementDeclRef = ((IRDeclRef*) requirementArg)->declRef; - - // Use the witness table to look up the value that - // satisfies the requirement. - auto satisfyingVal = findWitnessVal(witnessTable, requirementDeclRef); - // We expect to always find something, but lets just - // be careful here. - if(!satisfyingVal) - continue; - - // If we get through all of the above checks, then we - // have a (more) concrete method that implements the interface, - // and so we should dispatch to that directly, rather than - // use the `lookup_interface_method` instruction. - lookupInst->replaceUsesWith(satisfyingVal); - lookupInst->removeAndDeallocate(); - } - break; - } - - - // We only care about `specialize` instructions. - if(ii->op != kIROp_specialize) - continue; + switch(inst->op) + { + default: + break; + case kIROp_GlobalGenericParam: + case kIROp_BindGlobalGenericParam: + // A "bind" instruction should have no uses in the + // first place, and all the global generic parameters + // should have had their uses replaced. + SLANG_ASSERT(!inst->firstUse); + inst->removeAndDeallocate(); + break; } } } - // Once the work list has gone dry, we should have the invariant - // that there are no `specialize` instructions inside of non-generic - // functions that in turn reference a generic function. - } - - RefPtr createGlobalGenericParamSubstitution( - EntryPointRequest * entryPointRequest, - ProgramLayout * programLayout, - IRSpecContext* context) - { - RefPtr globalParamSubst; - GlobalGenericParamSubstitution * curTailSubst = nullptr; - - // Because we can't currently put `specialize` instructions inside - // witness tables, or at the global scope, we will track a set of - // witness tables that we need to clone, and then specialize - // from the original module(s) to get what we need. + // Our goal here is to find `specialize` instructions that + // can be replaced with references to, e.g., a suitably + // specialized function, and to resolve any `lookup_interface_method` + // instructions to the concrete value fetched from a witness + // table. + // + // We need to be careful of a few things: + // + // * It would not in general make sense to consider specialize-able + // instructions under an `IRGeneric`, since that could mean "specialziing" + // code to parameter values that are still unknown. + // + // * We *also* need to be careful not to specialize something when one + // or more of its inputs is also a `specialize` or `lookup_interface_method` + // instruction, because then we'd be propagating through non-concrete + // values. + // + // The approach we use here is to build a work list of instructions + // that *can* become fully specialized, but aren't yet. Any + // instruction on the work list will be considered to be "unspecialized" + // and any instruction not on the work list is considered specialized. + // + // We will start by recursively walking all the instructions to add + // the appropriate ones to our work list: + // + addToSpecializationWorkListRec(sharedContext, moduleInst); - struct WitnessTableCloneWorkItem + // Now we are going to repeatedly walk our work list, and filter + // it to create a new work list. + List workListCopy; + for(;;) { - IRWitnessTable* dstTable; - IRWitnessTable* originalTable; - }; - List witnessTablesToClone; + // Swap out the work list on the context so we can + // process it here without worrying about concurrent + // modifications. + workListCopy.Clear(); + workListCopy.SwapWith(sharedContext->workList); - struct WitnessTableSpecializationWorkItem - { - IRWitnessTable* dstTable; - IRWitnessTable* srcTable; - DeclRef specDeclRef; - }; - List witnessTablesToSpecailize; - - Dictionary witnessTablesByName; - auto namePool = entryPointRequest->compileRequest->getNamePool(); - - for (auto param : programLayout->globalGenericParams) - { - auto paramSubst = new GlobalGenericParamSubstitution(); - if (!globalParamSubst) - globalParamSubst = paramSubst; - if (curTailSubst) - curTailSubst->outer = paramSubst; - curTailSubst = paramSubst; - paramSubst->paramDecl = param->decl; - SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count()); - paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index]; - // find witness tables - for (auto witness : entryPointRequest->genericParameterWitnesses) + if(workListCopy.Count() == 0) + break; + + for(auto inst : workListCopy) { - if (auto subtypeWitness = witness.As()) + // We need to check whether it is possible to specialize + // the instruction yet (it might not be because its + // operands haven't been specialized) + if(!canSpecializeInst(sharedContext, inst)) { - if (subtypeWitness->sub->EqualsVal(paramSubst->actualType)) - { - auto witnessTableName = namePool->getName(getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup)); - auto findWitnessTableByName = [&](Name* name) -> IRWitnessTable* - { - RefPtr symbol; - if (!context->getSymbols().TryGetValue(name, symbol)) - return nullptr; - - return (IRWitnessTable*) symbol->irGlobalValue; - }; - - auto findCloneOfWitnessTableByName = [&](Name* name) -> IRWitnessTable* - { - IRWitnessTable* clonedTable = nullptr; - if (witnessTablesByName.TryGetValue(name, clonedTable)) - return clonedTable; - - IRWitnessTable* originalTable = findWitnessTableByName(name); - if (!originalTable) - return nullptr; - - clonedTable = context->builder->createWitnessTable(); - - WitnessTableCloneWorkItem cloneWorkItem; - cloneWorkItem.originalTable = originalTable; - cloneWorkItem.dstTable = clonedTable; - witnessTablesToClone.Add(cloneWorkItem); - - return clonedTable; - }; - - // First look for a non-generic witness table that matches - auto table = findCloneOfWitnessTableByName(witnessTableName); - if (!table) - { - // If we didn't find a non-generic table, then maybe we are looking at - // a specialization of a generic witness table. - if (auto subDeclRefType = subtypeWitness->sub.As()) - { - auto defaultSubst = createDefaultSubstitutions(entryPointRequest->compileRequest->mSession, subDeclRefType->declRef.getDecl()); - auto genericWitnessTableName = namePool->getName( - getMangledNameForConformanceWitness(DeclRef(subDeclRefType->declRef.getDecl(), defaultSubst), subtypeWitness->sup)); - - IRWitnessTable* genericTable = findCloneOfWitnessTableByName(genericWitnessTableName); - SLANG_ASSERT(genericTable); - - WitnessTableSpecializationWorkItem specializeWorkItem; - specializeWorkItem.srcTable = genericTable; - specializeWorkItem.dstTable = context->builder->createWitnessTable(); - specializeWorkItem.dstTable->mangledName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness(subDeclRefType->declRef, subtypeWitness->sup)); - specializeWorkItem.specDeclRef = subDeclRefType->declRef; - - witnessTablesToSpecailize.Add(specializeWorkItem); - table = specializeWorkItem.dstTable; - } - } - // We expect to find the table no matter what. - SLANG_ASSERT(table); + // Put it back on the fresh work list, so that + // we can re-consider it in another iteration. + sharedContext->workList.Add(inst); + } + else + { + // Okay, perform any specialization step on this + // instruction that makes sense (which might be + // doing nothing). + specializeGenericsForInst(sharedContext, inst); - IRProxyVal * tableVal = new IRProxyVal(); - tableVal->inst.init(nullptr, table); - paramSubst->witnessTables.Add(KeyValuePair, RefPtr>(subtypeWitness->sup, tableVal)); - } + // Remove the instruction from consideration. + sharedContext->workListSet.Remove(inst); } } } - for (auto workItem : witnessTablesToClone) - { - cloneWitnessTableWithoutRegistering( - context, - workItem.originalTable, - workItem.dstTable); - } - - for (auto workItem : witnessTablesToSpecailize) - { - int diff = 0; - specializeWitnessTable( - context->shared, - context, - workItem.srcTable, - workItem.specDeclRef.SubstituteImpl(SubstitutionSet(nullptr, nullptr, globalParamSubst), &diff), - workItem.dstTable); - } + // Once the work list has gone dry, we should have the invariant + // that there are no `specialize` instructions inside of non-generic + // functions that in turn reference a generic function, *except* + // in the case where that generic is for a builtin function, in + // which case we wouldn't want to specialize it anyway. + } - return globalParamSubst; + void applyGlobalGenericParamSubstitution( + IRSpecContext* /*context*/) + { + // TODO: we need to figure out how to apply this } - + void markConstExpr( - Session* session, - IRInst* irValue) + IRBuilder* builder, + IRInst* irValue) { // We will take an IR value with type `T`, // and turn it into one with type `@ConstExpr T`. @@ -6418,6 +6317,9 @@ namespace Slang // TODO: need to be careful if the value already has a rate // qualifier set. - irValue->type = session->getConstExprType(irValue->getDataType()); + irValue->setFullType( + builder->getRateQualifiedType( + builder->getConstExprRate(), + irValue->getDataType())); } } diff --git a/source/slang/ir.h b/source/slang/ir.h index 3119f2aaa..4a393cae0 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -11,6 +11,7 @@ #include "source-loc.h" #include "memory_pool.h" +#include "type-system-shared.h" namespace Slang { @@ -35,11 +36,14 @@ enum : IROpFlags kIROpFlag_Parent = 1 << 0, }; -enum IROp : int16_t +enum IROp : int32_t { #define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ kIROp_##ID, +#define MANUAL_INST_RANGE(ID, START, COUNT) \ + kIROp_First##ID = START, kIROp_Last##ID = kIROp_First##ID + ((COUNT) - 1), + #include "ir-inst-defs.h" kIROpCount, @@ -119,9 +123,11 @@ enum IRDecorationOp : uint16_t kIRDecorationOp_Target, kIRDecorationOp_TargetIntrinsic, kIRDecorationOp_GLSLOuterArray, + kIRDecorationOp_Semantic, + kIRDecorationOp_InterpolationMode, }; -// represents an object allocated in an IR memory pool +// represents an object allocated in an IR memory pool struct IRObject { bool isDestroyed = false; @@ -146,12 +152,10 @@ struct IRDecoration : public IRObject IRDecorationOp op; }; -// Use AST-level types directly to represent the -// types of IR instructions/values -typedef Type IRType; - struct IRBlock; struct IRParentInst; +struct IRRate; +struct IRType; // Every value in the IR is an instruction (even things // like literal values). @@ -209,12 +213,14 @@ struct IRInst : public IRObject // The type of the result value of this instruction, // or `null` to indicate that the instruction has // no value. - RefPtr type; + IRUse typeUse; + + IRType* getFullType() { return (IRType*) typeUse.get(); } + void setFullType(IRType* type) { typeUse.init(this, (IRInst*) type); } - Type* getFullType() { return type; } + IRRate* getRate(); - Type* getRate(); - Type* getDataType(); + IRType* getDataType(); // After the type, we have data that is specific to // the subtype of `IRInst`. In most cases, this is @@ -277,6 +283,8 @@ struct IRInst : public IRObject // for those values. void removeArguments(); + // RTTI support + static bool isaImpl(IROp) { return true; } }; // `dynamic_cast` equivalent @@ -380,6 +388,43 @@ struct IRInstList : IRInstListBase Iterator end() { return Iterator(last ? last->next : nullptr); } }; +// Types + +#define IR_LEAF_ISA(NAME) static bool isaImpl(IROp op) { return op == kIROp_##NAME; } +#define IR_PARENT_ISA(NAME) static bool isaImpl(IROp op) { return op >= kIROp_First##NAME && op <= kIROp_Last##NAME; } + +#define SIMPLE_IR_TYPE(NAME, BASE) struct IR##NAME : IR##BASE { IR_LEAF_ISA(NAME) }; +#define SIMPLE_IR_PARENT_TYPE(NAME, BASE) struct IR##NAME : IR##BASE { IR_PARENT_ISA(NAME) }; + + +// All types in the IR are represented as instructions which conceptually +// execute before run time. +struct IRType : IRInst +{ + IRType* getCanonicalType() { return this; } + + IR_PARENT_ISA(Type) +}; + +struct IRBasicType : IRType +{ + BaseType getBaseType() { return BaseType(op - kIROp_FirstBasicType); } + + IR_PARENT_ISA(BasicType) +}; + +struct IRVoidType : IRBasicType +{ + IR_LEAF_ISA(VoidType) +}; + +struct IRBoolType : IRBasicType +{ + IR_LEAF_ISA(BoolType) +}; + +// Constant Instructions + typedef int64_t IRIntegerValue; typedef double IRFloatingPointValue; @@ -393,15 +438,25 @@ struct IRConstant : IRInst // HACK: allows us to hash the value easily void* ptrData[2]; } u; + + IR_PARENT_ISA(Constant) +}; + +struct IRIntLit : IRConstant +{ + IRIntegerValue getValue() { return u.intVal; } + + IR_LEAF_ISA(IntLit); }; +// Get the compile-time constant integer value of an instruction, +// if it has one, and assert-fail otherwise. +IRIntegerValue GetIntVal(IRInst* inst); + // A instruction that ends a basic block (usually because of control flow) struct IRTerminatorInst : IRInst { - static bool isaImpl(IROp op) - { - return (op >= kIROp_FirstTerminatorInst) && (op <= kIROp_LastTerminatorInst); - } + IR_PARENT_ISA(TerminatorInst) }; // A function parameter is owned by a basic block, and represents @@ -417,7 +472,7 @@ struct IRParam : IRInst IRParam* getNextParam(); IRParam* getPrevParam(); - static bool isaImpl(IROp op) { return op == kIROp_Param; } + IR_LEAF_ISA(Param) }; // A "parent" instruction is one that contains other instructions @@ -433,10 +488,7 @@ struct IRParentInst : IRInst IRInst* getLastChild() { return children.last; } IRInstListBase getChildren() { return children; } - static bool isaImpl(IROp op) - { - return (op >= kIROp_FirstParentInst) && (op <= kIROp_LastParentInst); - } + IR_PARENT_ISA(ParentInst) }; // A basic block is a parent instruction that adds the constraint @@ -510,7 +562,7 @@ struct IRBlock : IRParentInst // by the terminator instruction of the block. // The `getPredecessors()` and `getSuccessors()` functions // make this more precise. - // + // struct PredecessorList { PredecessorList(IRUse* begin) : b(begin) {} @@ -573,15 +625,204 @@ struct IRBlock : IRParentInst // - static bool isaImpl(IROp op) { return op == kIROp_Block; } + IR_LEAF_ISA(Block) +}; + +SIMPLE_IR_TYPE(BasicBlockType, Type) + +struct IRResourceTypeBase : IRType +{ + TextureFlavor getFlavor() const + { + return TextureFlavor(op & 0xFFFF); + } + + TextureFlavor::Shape GetBaseShape() const + { + return getFlavor().GetBaseShape(); + } + bool isMultisample() const { return getFlavor().isMultisample(); } + bool isArray() const { return getFlavor().isArray(); } + SlangResourceShape getShape() const { return getFlavor().getShape(); } + SlangResourceAccess getAccess() const { return getFlavor().getAccess(); } + + IR_PARENT_ISA(ResourceTypeBase); +}; + +struct IRResourceType : IRResourceTypeBase +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_PARENT_ISA(ResourceType) +}; + +struct IRTextureTypeBase : IRResourceType +{ + IR_PARENT_ISA(TextureTypeBase) +}; + +struct IRTextureType : IRTextureTypeBase +{ + IR_PARENT_ISA(TextureType) +}; + +struct IRTextureSamplerType : IRTextureTypeBase +{ + IR_PARENT_ISA(TextureSamplerType) +}; + +struct IRGLSLImageType : IRTextureTypeBase +{ + IR_PARENT_ISA(GLSLImageType) +}; + +struct IRSamplerStateTypeBase : IRType +{ + IR_PARENT_ISA(SamplerStateTypeBase) +}; + +SIMPLE_IR_TYPE(SamplerStateType, SamplerStateTypeBase) +SIMPLE_IR_TYPE(SamplerComparisonStateType, SamplerStateTypeBase) + +struct IRBuiltinGenericType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_PARENT_ISA(BuiltinGenericType) +}; + +SIMPLE_IR_PARENT_TYPE(PointerLikeType, BuiltinGenericType); +SIMPLE_IR_PARENT_TYPE(HLSLStructuredBufferTypeBase, BuiltinGenericType) +SIMPLE_IR_TYPE(HLSLStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_IR_TYPE(HLSLRWStructuredBufferType, HLSLStructuredBufferTypeBase) +// TODO: need raster-ordered case here + +SIMPLE_IR_PARENT_TYPE(UntypedBufferResourceType, Type) +SIMPLE_IR_TYPE(HLSLByteAddressBufferType, UntypedBufferResourceType) +SIMPLE_IR_TYPE(HLSLRWByteAddressBufferType, UntypedBufferResourceType) + +SIMPLE_IR_TYPE(HLSLAppendStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_IR_TYPE(HLSLConsumeStructuredBufferType, HLSLStructuredBufferTypeBase) + +struct IRHLSLPatchType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getElementCount() { return getOperand(1); } + + IR_PARENT_ISA(HLSLPatchType) +}; + +SIMPLE_IR_TYPE(HLSLInputPatchType, HLSLPatchType) +SIMPLE_IR_TYPE(HLSLOutputPatchType, HLSLPatchType) + +SIMPLE_IR_PARENT_TYPE(HLSLStreamOutputType, BuiltinGenericType) +SIMPLE_IR_TYPE(HLSLPointStreamType, HLSLStreamOutputType) +SIMPLE_IR_TYPE(HLSLLineStreamType, HLSLStreamOutputType) +SIMPLE_IR_TYPE(HLSLTriangleStreamType, HLSLStreamOutputType) + +SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type) +SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType) +SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType) +SIMPLE_IR_PARENT_TYPE(VaryingParameterGroupType, ParameterGroupType) +SIMPLE_IR_TYPE(ConstantBufferType, UniformParameterGroupType) +SIMPLE_IR_TYPE(TextureBufferType, UniformParameterGroupType) +SIMPLE_IR_TYPE(GLSLInputParameterGroupType, VaryingParameterGroupType) +SIMPLE_IR_TYPE(GLSLOutputParameterGroupType, VaryingParameterGroupType) +SIMPLE_IR_TYPE(GLSLShaderStorageBufferType, UniformParameterGroupType) +SIMPLE_IR_TYPE(ParameterBlockType, UniformParameterGroupType) + +struct IRArrayTypeBase : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + // Returns the element count for an `IRArrayType`, and null + // for an `IRUnsizedArrayType`. + IRInst* getElementCount(); + + IR_PARENT_ISA(ArrayTypeBase) }; -// For right now, we will represent the type of -// an IR function using the type of the AST -// function from which it was created. +struct IRArrayType: IRArrayTypeBase +{ + IRInst* getElementCount() { return getOperand(1); } + + IR_LEAF_ISA(ArrayType) +}; + +SIMPLE_IR_TYPE(UnsizedArrayType, ArrayTypeBase) + +SIMPLE_IR_PARENT_TYPE(Rate, Type) +SIMPLE_IR_TYPE(ConstExprRate, Rate) +SIMPLE_IR_TYPE(GroupSharedRate, Rate) + +struct IRRateQualifiedType : IRType +{ + IRRate* getRate() { return (IRRate*) getOperand(0); } + IRType* getValueType() { return (IRType*) getOperand(1); } + + IR_LEAF_ISA(RateQualifiedType) +}; + + +// Unlike the AST-level type system where `TypeType` tracks the +// underlying type, the "type of types" in the IR is a simple +// value with no operands, so that all type nodes have the +// same type. +SIMPLE_IR_PARENT_TYPE(Kind, Type); +SIMPLE_IR_TYPE(TypeKind, Kind); + +// The kind of any and all generics. +// +// A more complete type system would include "arrow kinds" to +// be able to track the domain and range of generics (e.g., +// the `vector` generic maps a type and an integer to a type). +// This is only really needed if we ever wanted to support +// "higher-kinded" generics (e.g., a generic that takes another +// generic as a parameter). // -// TODO: need to do this better. -typedef FuncType IRFuncType; +SIMPLE_IR_TYPE(GenericKind, Kind) + +struct IRVectorType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getElementCount() { return getOperand(1); } + + IR_LEAF_ISA(VectorType) +}; + +struct IRMatrixType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getRowCount() { return getOperand(1); } + IRInst* getColumnCount() { return getOperand(2); } + + IR_LEAF_ISA(MatrixType) +}; + +struct IRPtrTypeBase : IRType +{ + IRType* getValueType() { return (IRType*)getOperand(0); } + + IR_PARENT_ISA(PtrTypeBase) +}; + +struct IRPtrType : IRPtrTypeBase +{ + IR_LEAF_ISA(PtrType) +}; + +SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase) +SIMPLE_IR_TYPE(OutType, OutTypeBase) +SIMPLE_IR_TYPE(InOutType, OutTypeBase) + +struct IRFuncType : IRType +{ + IRType* getResultType() { return (IRType*) getOperand(0); } + UInt getParamCount() { return getOperandCount() - 1; } + IRType* getParamType(UInt index) { return (IRType*)getOperand(1 + index); } + + IR_LEAF_ISA(FuncType) +}; // A "global value" is an instruction that might have // linkage, so that it can be declared in one module @@ -607,12 +848,55 @@ struct IRGlobalValue : IRParentInst void moveToEnd(); #endif - static bool isaImpl(IROp op) - { - return (op >= kIROp_FirstGlobalValue) && (op <= kIROp_LastGlobalValue); - } + IR_PARENT_ISA(GlobalValue) +}; + +bool isDefinition( + IRGlobalValue* inVal); + + +// A structure type is represented as a parent instruction, +// where the child instructions represent the fields of the +// struct. +// +// The space of fields that a given struct type supports +// are defined as its "keys", which are global values +// (that is, they have mangled names that can be used +// for linkage). +// +struct IRStructKey : IRGlobalValue +{ + IR_LEAF_ISA(StructKey) +}; +// +// The fields of the struct are then defined as mappings +// from those keys to the associated type (in the case of +// the struct type) or to values (when lookup up a field). +// +// A struct field thus has two operands: the key, and the +// type of the field. +// +struct IRStructField : IRInst +{ + IRStructKey* getKey() { return cast(getOperand(0)); } + IRType* getFieldType() { return cast(getOperand(1)); } + + IR_LEAF_ISA(StructField) +}; +// +// The struct type is then represented as a parent instruction +// that contains the various fields. Note that a struct does +// *not* contain the keys, because code needs to be able to +// reference the keys from scopes outside of the struct. +// +struct IRStructType : IRGlobalValue +{ + IRInstList getFields() { return IRInstList(getChildren()); } + + IR_LEAF_ISA(StructType) }; + /// @brief A global value that potentially holds executable code. /// struct IRGlobalValueWithCode : IRGlobalValue @@ -628,48 +912,53 @@ struct IRGlobalValueWithCode : IRGlobalValue // Add a block to the end of this function. void addBlock(IRBlock* block); + + IR_PARENT_ISA(GlobalValueWithCode) +}; + +// A value that has parameters so that it can conceptually be called. +struct IRGlobalValueWithParams : IRGlobalValueWithCode +{ + // Convenience accessor for the IR parameters, + // which are actually the parameters of the first + // block. + IRParam* getFirstParam(); + + IR_PARENT_ISA(GlobalValueWithParams) }; // A function is a parent to zero or more blocks of instructions. // // A function is itself a value, so that it can be a direct operand of // an instruction (e.g., a call). -struct IRFunc : IRGlobalValueWithCode +struct IRFunc : IRGlobalValueWithParams { // The type of the IR-level function - IRFuncType* getType() { return (IRFuncType*) type.Ptr(); } - - // If this function is generic, then we store a reference - // to the AST-level generic that defines its parameters - // and their constraints. - List> genericDecls; - int specializedGenericLevel = -1; + IRFuncType* getDataType() { return (IRFuncType*) IRInst::getDataType(); } - GenericDecl* getGenericDecl() - { - if (specializedGenericLevel != -1) - return genericDecls[specializedGenericLevel].Ptr(); - return nullptr; - } - - // Convenience accessors for working with the + // Convenience accessors for working with the // function's type. - Type* getResultType(); + IRType* getResultType(); UInt getParamCount(); - Type* getParamType(UInt index); + IRType* getParamType(UInt index); - // Convenience accessor for the IR parameters, - // which are actually the parameters of the first - // block. - IRParam* getFirstParam(); + IR_LEAF_ISA(Func) +}; - virtual void dispose() override - { - IRGlobalValueWithCode::dispose(); - genericDecls = decltype(genericDecls)(); - } +// A generic is akin to a function, but is conceptually executed +// before runtime, to specialize the code nested within. +// +// In practice, a generic always holds only a single block, and ends +// with a `return` instruction for the value that the generic yields. +struct IRGeneric : IRGlobalValueWithParams +{ + IR_LEAF_ISA(Generic) }; +// Find the value that is returned from a generic, so that +// a pass can glean information from it. +IRInst* findGenericReturnVal(IRGeneric* generic); + // The IR module itself is represented as an instruction, which // serves at the root of the tree of all instructions in the module. struct IRModuleInst : IRParentInst @@ -680,6 +969,8 @@ struct IRModuleInst : IRParentInst IRModule* module; IRInstListBase getGlobalInsts() { return getChildren(); } + + IR_LEAF_ISA(Module) }; struct IRModule : RefObject diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp index 0b8f49b0c..51a7af314 100644 --- a/source/slang/legalize-types.cpp +++ b/source/slang/legalize-types.cpp @@ -1,6 +1,7 @@ // legalize-types.cpp #include "legalize-types.h" +#include "ir-insts.h" #include "mangle.h" namespace Slang @@ -68,30 +69,30 @@ LegalType LegalType::pair( // -static bool isResourceType(Type* type) +static bool isResourceType(IRType* type) { - while (auto arrayType = type->As()) + while (auto arrayType = as(type)) { - type = arrayType->baseType; + type = arrayType->getElementType(); } - if (auto resourceTypeBase = type->As()) + if (auto resourceTypeBase = as(type)) { return true; } - else if (auto builtinGenericType = type->As()) + else if (auto builtinGenericType = as(type)) { return true; } - else if (auto pointerLikeType = type->As()) + else if (auto pointerLikeType = as(type)) { return true; } - else if (auto samplerType = type->As()) + else if (auto samplerType = as(type)) { return true; } - else if(auto untypedBufferType = type->As()) + else if(auto untypedBufferType = as(type)) { return true; } @@ -118,13 +119,13 @@ ModuleDecl* findModuleForDecl( struct TupleTypeBuilder { TypeLegalizationContext* context; - RefPtr type; - DeclRef typeDeclRef; + IRType* type; + IRStructType* originalStructType; struct OrdinaryElement { - DeclRef fieldDeclRef; - RefPtr type; + IRStructKey* fieldKey = nullptr; + IRType* type = nullptr; }; @@ -146,10 +147,10 @@ struct TupleTypeBuilder // Add a field to the (pseudo-)type we are building void addField( - DeclRef fieldDeclRef, - LegalType legalFieldType, - LegalType legalLeafType, - bool isResource) + IRStructKey* fieldKey, + LegalType legalFieldType, + LegalType legalLeafType, + bool isResource) { LegalType ordinaryType; LegalType specialType; @@ -188,7 +189,7 @@ struct TupleTypeBuilder // or a pair "under" an `implicitDeref`, so // we'll need to ensure that elsewhere. addField( - fieldDeclRef, + fieldKey, legalFieldType, legalLeafType.getImplicitDeref()->valueType, isResource); @@ -232,11 +233,11 @@ struct TupleTypeBuilder break; } - String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); +// String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); PairInfo::Element pairElement; pairElement.flags = 0; - pairElement.mangledName = mangledFieldName; + pairElement.key = fieldKey; pairElement.fieldPairInfo = elementPairInfo; // We will always add a field to the "ordinary" @@ -244,7 +245,7 @@ struct TupleTypeBuilder // data, just to keep the list of fields aligned // with the original type. OrdinaryElement ordinaryElement; - ordinaryElement.fieldDeclRef = fieldDeclRef; + ordinaryElement.fieldKey = fieldKey; if (ordinaryType.flavor != LegalType::Flavor::none) { anyOrdinary = true; @@ -273,7 +274,7 @@ struct TupleTypeBuilder pairElement.flags |= PairInfo::kFlag_hasSpecial; TuplePseudoType::Element specialElement; - specialElement.mangledName = mangledFieldName; + specialElement.key = fieldKey; specialElement.type = specialType; specialElements.Add(specialElement); } @@ -284,19 +285,15 @@ struct TupleTypeBuilder // Add a field to the (pseudo-)type we are building void addField( - DeclRef fieldDeclRef) + IRStructField* field) { - // Skip `static` fields. - if (fieldDeclRef.getDecl()->HasModifier()) - return; - - auto fieldType = GetType(fieldDeclRef); + auto fieldType = field->getFieldType(); bool isResourceField = isResourceType(fieldType); - auto legalFieldType = legalizeType(context, fieldType); + addField( - fieldDeclRef, + field->getKey(), legalFieldType, legalFieldType, isResourceField); @@ -328,69 +325,37 @@ struct TupleTypeBuilder LegalType ordinaryType; if (anyOrdinary) { - // We are going to create a new `struct` type declaration that clones - // the fields we care about from the original `struct` type. Note that - // these fields may have different types from what they did before, + // We are going to create an new IR `struct` type that contains + // the "ordinary" fields from the original type. Note that these + // fields may have different types from what they did before, // because the fields themselves might have been legalized. // - // Our new declaration will have the same name as the old one, so + // The new type will have the same mangled name as the old one, so // downstream code is going to need to be careful not to emit declarations // for both of them. This should be okay, though, because the original // type was illegal (that was the whole point) and so it shouldn't be - // allowed in the output anyway. - RefPtr ordinaryStructDecl = new StructDecl(); - ordinaryStructDecl->loc = typeDeclRef.getDecl()->loc; - ordinaryStructDecl->nameAndLoc = typeDeclRef.getDecl()->nameAndLoc; - - auto typeLegalizedModifier = new LegalizedModifier(); - typeLegalizedModifier->originalMangledName = getMangledName(typeDeclRef); - addModifier(ordinaryStructDecl, typeLegalizedModifier); - - // We will do something a bit unsavory here, by setting the logical - // parent of the new `struct` type to be the same as the orignal type - // (All of this helps ensure it gets the same mangled name). + // referenced in the output anyway. // - ordinaryStructDecl->ParentDecl = typeDeclRef.getDecl()->ParentDecl; - - if (context->mainModuleDecl) - { - // If the declaration we are lowering belongs to the AST-based - // module being lowered (rather than translated to IR), then we - // need to add any new declaration we create to that output. - - // If we are *not* outputting an IR module as well, then - // everything needs to wind up in a single AST module. - if (!context->irModule) - { - context->outputModuleDecl->Members.Add(ordinaryStructDecl); - } - else - { - // Otherwise, check if this declaration belongs to the main - // module (which is being lowered via the AST-to-AST pass), - // and add it to the output if needed. - // - // TODO: This won't work correctly if a type from the AST - // module is used to specialize a generic in the IR module, - // since the declaration would need to precede the specialized - // func... - auto parentModule = findModuleForDecl(typeDeclRef.getDecl()); - if (parentModule && (parentModule == context->mainModuleDecl)) - { - context->outputModuleDecl->Members.Add(ordinaryStructDecl); - } - } - } - - // For memory management reasons, we need to keep a reference to - // the declaration live, no matter what. - context->createdDecls.Add(ordinaryStructDecl); + IRBuilder* builder = context->getBuilder(); + IRStructType* ordinaryStructType = builder->createStructType(); + ordinaryStructType->sourceLoc = originalStructType->sourceLoc; + ordinaryStructType->mangledName = originalStructType->mangledName; + + // The new struct type will appear right after the original in the IR, + // so that we can be sure any instruction that could reference the + // original can also reference the new one. + ordinaryStructType->insertAfter(originalStructType); + + // Mark the original type for removal once all the other legalization + // activity is completed. This is necessary because both the original + // and replacement type have the same mangled name, so they would + // collide. + // + // (Also, the original type wasn't legal - that was the whole point...) + context->instsToRemove.Add(originalStructType); - UInt elementCounter = 0; for(auto ee : ordinaryElements) { - UInt elementIndex = elementCounter++; - // We will ensure that all the original fields are represented, // although they may have different types (due to legalization). // For fields that have *no* ordinary data, we will give them @@ -401,32 +366,23 @@ struct TupleTypeBuilder // and modified type will have the same number of fields, so // we can continue to look up field layouts by index in the // emit logic) - RefPtr fieldType = ee.type; + // + // TODO: we should scrap that, and layout lookup should just + // be based on mangled field names in all cases. + // + IRType* fieldType = ee.type; if(!fieldType) - fieldType = context->session->getVoidType(); + fieldType = context->getBuilder()->getVoidType(); // TODO: shallow clone of modifiers, etc. - RefPtr fieldDecl = new StructField(); - fieldDecl->loc = ee.fieldDeclRef.getDecl()->loc; - fieldDecl->nameAndLoc = ee.fieldDeclRef.getDecl()->nameAndLoc; - fieldDecl->type.type = fieldType; - - fieldDecl->ParentDecl = ordinaryStructDecl; - ordinaryStructDecl->Members.Add(fieldDecl); - - pairElements[elementIndex].ordinaryFieldDeclRef = makeDeclRef(fieldDecl.Ptr()); - - auto fieldLegalizedModifier = new LegalizedModifier(); - fieldLegalizedModifier->originalMangledName = getMangledName(ee.fieldDeclRef); - addModifier(fieldDecl, fieldLegalizedModifier); + builder->createStructField( + ordinaryStructType, + ee.fieldKey, + fieldType); } - RefPtr ordinaryStructType = DeclRefType::Create( - context->session, - makeDeclRef(ordinaryStructDecl.Ptr())); - - ordinaryType = LegalType::simple(ordinaryStructType); + ordinaryType = LegalType::simple((IRType*) ordinaryStructType); } LegalType specialType; @@ -449,44 +405,23 @@ struct TupleTypeBuilder }; -static RefPtr createBuiltinGenericType( +static IRType* createBuiltinGenericType( TypeLegalizationContext* context, - DeclRef const& typeDeclRef, - RefPtr elementType) + IROp op, + IRType* elementType) { - // We are going to take the type for the original - // decl-ref and construct a new one that uses - // our new element type as its parameter. - // - // TODO: we should have library code to make - // manipulations like this way easier. - - RefPtr oldGenericSubst = typeDeclRef.substitutions.genericSubstitutions; - SLANG_ASSERT(oldGenericSubst); - - RefPtr newGenericSubst = new GenericSubstitution(); - - newGenericSubst->outer = oldGenericSubst->outer; - newGenericSubst->genericDecl = oldGenericSubst->genericDecl; - newGenericSubst->args = oldGenericSubst->args; - newGenericSubst->args[0] = elementType; - - auto newDeclRef = DeclRef( - typeDeclRef.getDecl(), - newGenericSubst); - - auto newType = DeclRefType::Create( - context->session, - newDeclRef); - - return newType; + IRInst* operands[] = { elementType }; + return context->getBuilder()->getType( + op, + 1, + operands); } // Create a uniform buffer type with a given legalized // element type. static LegalType createLegalUniformBufferType( TypeLegalizationContext* context, - DeclRef const& typeDeclRef, + IROp op, LegalType legalElementType) { switch (legalElementType.flavor) @@ -497,7 +432,7 @@ static LegalType createLegalUniformBufferType( // so we want to create a uniform buffer that wraps it. return LegalType::simple(createBuiltinGenericType( context, - typeDeclRef, + op, legalElementType.getSimple())); } break; @@ -520,7 +455,7 @@ static LegalType createLegalUniformBufferType( // I'm going to attempt to hack this for now. return LegalType::implicitDeref(createLegalUniformBufferType( context, - typeDeclRef, + op, legalElementType.getImplicitDeref()->valueType)); } break; @@ -535,7 +470,7 @@ static LegalType createLegalUniformBufferType( auto ordinaryType = createLegalUniformBufferType( context, - typeDeclRef, + op, pairType->ordinaryType); auto specialType = LegalType::implicitDeref(pairType->specialType); @@ -558,7 +493,7 @@ static LegalType createLegalUniformBufferType( { TuplePseudoType::Element newElement; - newElement.mangledName = ee.mangledName; + newElement.key = ee.key; newElement.type = LegalType::implicitDeref(ee.type); bufferPseudoTupleType->elements.Add(newElement); @@ -576,20 +511,20 @@ static LegalType createLegalUniformBufferType( } static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - UniformParameterGroupType* uniformBufferType, - LegalType legalElementType) + TypeLegalizationContext* context, + IRUniformParameterGroupType* uniformBufferType, + LegalType legalElementType) { return createLegalUniformBufferType( context, - uniformBufferType->declRef, + uniformBufferType->op, legalElementType); } // Create a pointer type with a given legalized value type. static LegalType createLegalPtrType( TypeLegalizationContext* context, - DeclRef const& typeDeclRef, + IROp op, LegalType legalValueType) { switch (legalValueType.flavor) @@ -600,7 +535,7 @@ static LegalType createLegalPtrType( // so we want to create a uniform buffer that wraps it. return LegalType::simple(createBuiltinGenericType( context, - typeDeclRef, + op, legalValueType.getSimple())); } break; @@ -610,7 +545,7 @@ static LegalType createLegalPtrType( // We are being asked to create a pointer type to something // that is implicitly dereferenced, meaning we had: // - // Ptr(PtrLink(T)) + // Ptr(PtrLike(T)) // // and now are being asked to make: // @@ -621,9 +556,12 @@ static LegalType createLegalPtrType( // implicitDeref(Ptr(LegalT)) // // and nobody should really be able to tell the difference, right? + // + // TODO: invetigate whether there are situations where this + // will matter. return LegalType::implicitDeref(createLegalPtrType( context, - typeDeclRef, + op, legalValueType.getImplicitDeref()->valueType)); } break; @@ -635,11 +573,11 @@ static LegalType createLegalPtrType( auto ordinaryType = createLegalPtrType( context, - typeDeclRef, + op, pairType->ordinaryType); auto specialType = createLegalPtrType( context, - typeDeclRef, + op, pairType->specialType); return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); @@ -658,10 +596,10 @@ static LegalType createLegalPtrType( { TuplePseudoType::Element newElement; - newElement.mangledName = ee.mangledName; + newElement.key = ee.key; newElement.type = createLegalPtrType( context, - typeDeclRef, + op, ee.type); ptrPseudoTupleType->elements.Add(newElement); @@ -680,30 +618,31 @@ static LegalType createLegalPtrType( struct LegalTypeWrapper { - virtual LegalType wrap(TypeLegalizationContext* context, Type* type) = 0; + virtual LegalType wrap(TypeLegalizationContext* context, IRType* type) = 0; }; struct ArrayLegalTypeWrapper : LegalTypeWrapper { - ArrayExpressionType* arrayType; + IRArrayTypeBase* arrayType; - LegalType wrap(TypeLegalizationContext* context, Type* type) + LegalType wrap(TypeLegalizationContext* context, IRType* type) { - return LegalType::simple(context->session->getArrayType( + return LegalType::simple(context->getBuilder()->getArrayTypeBase( + arrayType->op, type, - arrayType->ArrayLength)); + arrayType->getElementCount())); } }; struct BuiltinGenericLegalTypeWrapper : LegalTypeWrapper { - DeclRef declRef; + IROp op; - LegalType wrap(TypeLegalizationContext* context, Type* type) + LegalType wrap(TypeLegalizationContext* context, IRType* type) { return LegalType::simple(createBuiltinGenericType( context, - declRef, + op, type)); } }; @@ -711,7 +650,7 @@ struct BuiltinGenericLegalTypeWrapper : LegalTypeWrapper struct ImplicitDerefLegalTypeWrapper : LegalTypeWrapper { - LegalType wrap(TypeLegalizationContext*, Type* type) + LegalType wrap(TypeLegalizationContext*, IRType* type) { return LegalType::implicitDeref(LegalType::simple(type)); } @@ -773,7 +712,7 @@ static LegalType wrapLegalType( { TuplePseudoType::Element element; - element.mangledName = ee.mangledName; + element.key = ee.key; element.type = wrapLegalType( context, ee.type, @@ -794,14 +733,14 @@ static LegalType wrapLegalType( } } - // Legalize a type, including any nested types // that it transitively contains. -LegalType legalizeType( +LegalType legalizeTypeImpl( TypeLegalizationContext* context, - Type* type) + IRType* type) { - if (auto uniformBufferType = type->As()) + + if (auto uniformBufferType = as(type)) { // We have one of: // @@ -840,111 +779,99 @@ LegalType legalizeType( // are legal as-is. return LegalType::simple(type); } - else if (type->As()) + else if (as(type)) { return LegalType::simple(type); } - else if (type->As()) + else if (as(type)) { return LegalType::simple(type); } - else if (type->As()) + else if (as(type)) { return LegalType::simple(type); } - else if (auto ptrType = type->As()) + else if (auto ptrType = as(type)) { auto legalValueType = legalizeType(context, ptrType->getValueType()); - return createLegalPtrType(context, ptrType->declRef, legalValueType); + return createLegalPtrType(context, ptrType->op, legalValueType); } - else if (auto declRefType = type->As()) + else if(auto structType = as(type)) { - auto declRef = declRefType->declRef; - - LegalType legalType; - if(context->mapDeclRefToLegalType.TryGetValue(declRef, legalType)) - return legalType; - + // Look at the (non-static) fields, and + // see if anything needs to be cleaned up. + // The things that need to be "cleaned up" for + // our purposes are: + // + // - Fields of resource type, or any other future + // type we run into that isn't allowed in + // aggregates for at least some targets + // + // - Fields with types that themselves had to + // get legalized. + // + // If we don't run into any of these, we + // can just use the type as-is. Hooray! + // + // Otherwise, we are effectively going to split + // the type apart and create a `TuplePseudoType`. + // Every field of the original type will be + // represented as an element of this pseudo-type. + // Each element will record its `LegalType`, + // and the original field that it was created from. + // An element will also track whether it contains + // any "ordinary" data, and if so, it will remember + // an element index in a real (AST-level, non-pseudo) + // `TupleType` that is used to bundle together + // such fields. + // + // Storing all the simple fields together like this + // obviously adds complexity to the legalization + // pass, but it has important benefits: + // + // - It avoids creating functions with a very large + // number of parameters (when passing a structure + // with many fields), which might confuse downstream + // compilers. + // + // - It avoids applying AOS->SOA conversion to fields + // that don't actually need it, which is basically + // required if we want type layout to work. + // + // - It ensures that we can actually construct a + // constant-buffer type that wraps a legalized + // aggregate type; the ordinary fields will get + // placed inside a new constant-buffer type, + // while the special ones will get left outside. + // - if (auto aggTypeDeclRef = declRef.As()) + // TODO: there is a risk here that we might recursively + // invole `legalizeType` on the type that we are + // currently trying to legalize. We need to detect that + // situation somehow, by inserting a sentinel value + // into `mapTypeToLegalType` during the per-field + // legalization process, and then if we ever see that + // sentinel in a call to `legalizeType`, we need + // to construct some kind of proxy type to help resolve + // the problem. + + TupleTypeBuilder builder; + builder.context = context; + builder.type = type; + builder.originalStructType = structType; + + for (auto ff : structType->getFields()) { - // Look at the (non-static) fields, and - // see if anything needs to be cleaned up. - // The things that need to be "cleaned up" for - // our purposes are: - // - // - Fields of resource type, or any other future - // type we run into that isn't allowed in - // aggregates for at least some targets - // - // - Fields with types that themselves had to - // get legalized. - // - // If we don't run into any of these, we - // can just use the type as-is. Hooray! - // - // Otherwise, we are effectively going to split - // the type apart and create a `TuplePseudoType`. - // Every field of the original type will be - // represented as an element of this pseudo-type. - // Each element will record its `LegalType`, - // and the original field that it was created from. - // An element will also track whether it contains - // any "ordinary" data, and if so, it will remember - // an element index in a real (AST-level, non-pseudo) - // `TupleType` that is used to bundle together - // such fields. - // - // Storing all the simple fields together like this - // obviously adds complexity to the legalization - // pass, but it has important benefits: - // - // - It avoids creating functions with a very large - // number of parameters (when passing a structure - // with many fields), which might confuse downstream - // compilers. - // - // - It avoids applying AOS->SOA conversion to fields - // that don't actually need it, which is basically - // required if we want type layout to work. - // - // - It ensures that we can actually construct a - // constant-buffer type that wraps a legalized - // aggregate type; the ordinary fields will get - // placed inside a new constant-buffer type, - // while the special ones will get left outside. - // - - TupleTypeBuilder builder; - builder.context = context; - builder.type = type; - builder.typeDeclRef = aggTypeDeclRef; - - - for (auto ff : getMembersOfType(aggTypeDeclRef)) - { - builder.addField(ff); - } - - legalType = builder.getResult(); - context->mapDeclRefToLegalType.AddIfNotExists(declRef, legalType); - return legalType; + builder.addField(ff); } - // TODO: for other declaration-reference types, we really - // need to legalize the types used in substitutions, and - // signal an error if any of them turn out to be non-simple. - // - // The limited cases of types that can handle having non-simple - // types as generic arguments all need to be special-cased here. - // (For example, we can't handle `Texture2D`. - // + return builder.getResult(); } - else if(auto arrayType = type->As()) + else if(auto arrayType = as(type)) { auto legalElementType = legalizeType( context, - arrayType->baseType); + arrayType->getElementType()); switch (legalElementType.flavor) { @@ -972,6 +899,34 @@ LegalType legalizeType( return LegalType::simple(type); } +void initialize( + TypeLegalizationContext* context, + Session* session, + IRModule* module) +{ + context->session = session; + context->irModule = module; + + context->sharedBuilder.session = session; + context->sharedBuilder.module = module; + + context->builder.sharedBuilder = &context->sharedBuilder; + context->builder.setInsertInto(module->moduleInst); +} + +LegalType legalizeType( + TypeLegalizationContext* context, + IRType* type) +{ + LegalType legalType; + if(context->mapTypeToLegalType.TryGetValue(type, legalType)) + return legalType; + + legalType = legalizeTypeImpl(context, type); + context->mapTypeToLegalType[type] = legalType; + return legalType; +} + // RefPtr getDerefTypeLayout( diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h index 8958c683d..887f263f8 100644 --- a/source/slang/legalize-types.h +++ b/source/slang/legalize-types.h @@ -24,6 +24,7 @@ // and some extra tuple-ified fields. #include "../core/basic.h" +#include "ir-insts.h" #include "syntax.h" #include "type-layout.h" #include "name.h" @@ -31,6 +32,8 @@ namespace Slang { +struct IRBuilder; + struct LegalTypeImpl : RefObject { }; @@ -65,19 +68,20 @@ struct LegalType Flavor flavor = Flavor::none; RefPtr obj; + IRType* irType; - static LegalType simple(Type* type) + static LegalType simple(IRType* type) { LegalType result; result.flavor = Flavor::simple; - result.obj = type; + result.irType = type; return result; } - RefPtr getSimple() const + IRType* getSimple() const { assert(flavor == Flavor::simple); - return obj.As(); + return irType; } static LegalType implicitDeref( @@ -139,16 +143,18 @@ struct TuplePseudoType : LegalTypeImpl struct Element { // The field that this element replaces - String mangledName; + IRStructKey* key; // The legalized type of the element - LegalType type; + LegalType type; }; // All of the elements of the tuple pseduo-type. List elements; }; +struct IRStructKey; + struct PairInfo : RefObject { typedef unsigned int Flags; @@ -159,10 +165,11 @@ struct PairInfo : RefObject kFlag_hasOrdinaryAndSpecial = kFlag_hasOrdinary | kFlag_hasSpecial, }; + struct Element { // The original field the element represents - String mangledName; + IRStructKey* key; // The conceptual type of the field. // If both the `hasOrdinary` and @@ -182,22 +189,17 @@ struct PairInfo : RefObject // then this is the `PairInfo` for that // pair type: RefPtr fieldPairInfo; - - // The actual field decl-ref that needs - // to be used for looking up this element - // in the ordinary type. - DeclRef ordinaryFieldDeclRef; }; // For a pair type or value, we need to track // which fields are on which side(s). List elements; - Element* findElement(String const& mangledName) + Element* findElement(IRStructKey* key) { for (auto& ee : elements) { - if(ee.mangledName == mangledName) + if(ee.key == key) return ⅇ } return nullptr; @@ -322,8 +324,8 @@ struct TuplePseudoVal : LegalValImpl { struct Element { - String mangledName; - LegalVal val; + IRStructKey* key; + LegalVal val; }; List elements; @@ -348,48 +350,31 @@ struct ImplicitDerefVal : LegalValImpl struct TypeLegalizationContext { - /// The overall compilation session (used when - /// constructing types). + /// The overall compilation session.. Session* session; - // If the type we are legalizing comes from an - // AST module being lowered via AST-to-AST translation, - // then we want to add any new declaration we create - // to represent it to the appropriate output module. - // We store some fields here to enable that: - RefPtr mainModuleDecl; - RefPtr outputModuleDecl; - - // We also need to know whether the IR is involved - // at all, because if it is, then it will own certain - // declarations instead. - // - // We do this in a slightly silly way by storing a pointer - // to the IR module (if any), and assume that its presence - // or absence is the indicator we need. IRModule* irModule = nullptr; - /// A list to retain any AST objects created during type legalization. - List> createdDecls; - - /// A mapping from declaration references to the resulting - /// legalized type. - /// - /// For declaration-reference types, this map can be used - /// to cache a legalization so that it will be re-used - /// for equivalent declaration references (and so avoid - /// emitting declarations of legalized `struct` types - /// multiple times). - Dictionary, LegalType> mapDeclRefToLegalType; - - // - Dictionary mapMangledNameToLegalIRValue; + SharedIRBuilder sharedBuilder; + IRBuilder builder; + + IRBuilder* getBuilder() { return &builder; } + + Dictionary mapTypeToLegalType; + + // Intstructions to be removed when legalization is done + HashSet instsToRemove; }; +void initialize( + TypeLegalizationContext* context, + Session* session, + IRModule* module); + LegalType legalizeType( TypeLegalizationContext* context, - Type* type); + IRType* type); /// Try to find the module that (recursively) contains a given declaration. ModuleDecl* findModuleForDecl( diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index eebef6503..2735bc6ba 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -222,6 +222,67 @@ void DoMemberLookupImpl( name, baseType, request, ioResult, breadcrumbs); } +// If we are about to perform lookup through an interface, then +// we need to specialize the decl-ref to that interface to include +// a "this type" subtitution. This function applies that substition +// when it is required, and returns the existing `declRef` otherwise. +DeclRef maybeSpecializeInterfaceDeclRef( + RefPtr subType, + RefPtr superType, + DeclRef superTypeDeclRef, // The decl-ref we are going to perform lookup in + DeclRef constraintDeclRef) // The type constraint that told us our type is a subtype +{ + if (auto superInterfaceDeclRef = superTypeDeclRef.As()) + { + // Create a subtype witness value to note the subtype relationship + // that makes this specialization valid. + // + // Note: this is to ensure that we can specialize the subtype witness + // later (e.g., by replacing a subtype witness that represents a generic + // constraint paraqmeter with the concrete generic arguments that + // are used at a particular call site to the generic). + RefPtr subtypeWitness = new DeclaredSubtypeWitness(); + subtypeWitness->declRef = constraintDeclRef; + subtypeWitness->sub = subType; + subtypeWitness->sup = superType; + + RefPtr thisTypeSubst = new ThisTypeSubstitution(); + thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); + thisTypeSubst->witness = subtypeWitness; + thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; + + auto specializedInterfaceDeclRef = DeclRef(superInterfaceDeclRef.getDecl(), thisTypeSubst); + return specializedInterfaceDeclRef; + } + + return superTypeDeclRef; +} + +// Same as the above, but we are specializing a type instead of a decl-ref +RefPtr maybeSpecializeInterfaceDeclRef( + Session* session, + RefPtr subType, + RefPtr superType, // The type we are going to perform lookup in + DeclRef constraintDeclRef) // The type constraint that told us our type is a subtype +{ + if (auto superDeclRefType = superType->As()) + { + if (auto superInterfaceDeclRef = superDeclRefType->declRef.As()) + { + auto specializedInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( + subType, + superType, + superInterfaceDeclRef, + constraintDeclRef); + auto specializedInterfaceType = DeclRefType::Create(session, specializedInterfaceDeclRef); + return specializedInterfaceType; + } + } + + return superType; +} + + // Look for members of the given name in the given container for declarations void DoLocalLookupImpl( Session* session, @@ -313,27 +374,53 @@ void DoLocalLookupImpl( // for interface decls, also lookup in the base interfaces if (request.semantics) { - bool isInterface = containerDeclRef.As() ? true : false; + // TODO: + // The logic here is a bit gross, because it tries to work in terms of + // decl-refs instead of types (e.g., it asserts that the target type + // for an `extension` declaration must be a decl-ref type). + // + // This code should be converted to do a type-based lookup + // through declared bases for *any* aggregate type declaration. + // I think that logic is present in the type-bsed lookup path, but + // it would be needed here for when doing lookup from inside an + // aggregate declaration. + // if we are looking at an extension, find the target decl that we are extending + DeclRef targetDeclRef = containerDeclRef; + RefPtr targetDeclRefType; if (auto extDeclRef = containerDeclRef.As()) { - auto targetDeclRefType = extDeclRef.getDecl()->targetType->AsDeclRefType(); + targetDeclRefType = extDeclRef.getDecl()->targetType->AsDeclRefType(); SLANG_ASSERT(targetDeclRefType); int diff = 0; - auto targetDeclRef = targetDeclRefType->declRef.As().SubstituteImpl(containerDeclRef.substitutions, &diff); - isInterface = targetDeclRef.As() ? true : false; + targetDeclRef = targetDeclRefType->declRef.As().SubstituteImpl(containerDeclRef.substitutions, &diff); } + // if we are looking inside an interface decl, try find in the interfaces it inherits from + bool isInterface = targetDeclRef.As() ? true : false; if (isInterface) { + if(!targetDeclRefType) + { + targetDeclRefType = DeclRefType::Create(session, targetDeclRef); + } + auto baseInterfaces = getMembersOfType(containerDeclRef); for (auto inheritanceDeclRef : baseInterfaces) { checkDecl(request.semantics, inheritanceDeclRef.decl); + auto baseType = inheritanceDeclRef.getDecl()->base.type.As(); SLANG_ASSERT(baseType); int diff = 0; auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(containerDeclRef.substitutions, &diff); + + baseInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( + targetDeclRefType, + baseType, + baseInterfaceDeclRef, + inheritanceDeclRef); + DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As(), request, result, inBreadcrumbs); } } @@ -456,6 +543,68 @@ LookupResult lookUpLocal( return result; } +void lookUpMemberImpl( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* type, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs, + LookupMask mask); + +// Perform lookup "through" the given constraint decl-ref, +// which should show that `subType` is a sub-type of some +// super-type (e.g., an interface). +// +void lookUpThroughConstraint( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* subType, + DeclRef constraintDeclRef, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs, + LookupMask mask) +{ + // The super-type in the constraint (e.g., `Foo` in `T : Foo`) + // will tell us a type we should use for lookup. + // + auto superType = GetSup(constraintDeclRef); + // + // We will go ahead and perform lookup using `superType`, + // after dealing with some details. + + // If we are looking up through an interface type, then + // we need to be sure that we add an appropriate + // "this type" substitution here, since that needs to + // be applied to any members we look up. + // + superType = maybeSpecializeInterfaceDeclRef( + session, + subType, + superType, + constraintDeclRef); + + // We need to track the indirection we took in lookup, + // so that we can construct an approrpiate AST on the other + // side that includes the "upcase" from sub-type to super-type. + // + BreadcrumbInfo breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; + breadcrumb.declRef = constraintDeclRef; + + // TODO: Need to consider case where this might recurse infinitely (e.g., + // if an inheritance clause does something like `Bad : Bad>`. + // + // TODO: The even simpler thing we need to worry about here is that if + // there is ever a "diamond" relationship in the inheritance hierarchy, + // we might end up seeing the same interface via diffrent "paths" and + // we wouldn't want that to lead to overload-resolution failure. + // + lookUpMemberImpl(session, semantics, name, superType, ioResult, &breadcrumb, mask); +} + void lookUpMemberImpl( Session* session, SemanticsVisitor* semantics, @@ -472,20 +621,15 @@ void lookUpMemberImpl( { for (auto constraintDeclRef : getMembersOfType(declRef.As())) { - // The super-type in the constraint (e.g., `Foo` in `T : Foo`) - // will tell us a type we should use for lookup. - auto bound = GetSup(constraintDeclRef); - - // Go ahead and use the target type, with an appropriate breadcrumb - // to indicate that we indirected through a type constraint. - - BreadcrumbInfo breadcrumb; - breadcrumb.prev = inBreadcrumbs; - breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; - breadcrumb.declRef = constraintDeclRef; - - // TODO: Need to consider case where this might recurse infinitely. - lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb, mask); + lookUpThroughConstraint( + session, + semantics, + name, + type, + constraintDeclRef, + ioResult, + inBreadcrumbs, + mask); } } else if (auto aggTypeDeclRef = declRef.As()) @@ -514,20 +658,15 @@ void lookUpMemberImpl( if(!subDeclRefType->declRef.Equals(genericTypeParamDeclRef)) continue; - // The super-type in the constraint (e.g., `Foo` in `T : Foo`) - // will tell us a type we should use for lookup. - auto bound = GetSup(constraintDeclRef); - - // Go ahead and use the target type, with an appropriate breadcrumb - // to indicate that we indirected through a type constraint. - - BreadcrumbInfo breadcrumb; - breadcrumb.prev = inBreadcrumbs; - breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; - breadcrumb.declRef = constraintDeclRef; - - // TODO: Need to consider case where this might recurse infinitely. - lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb, mask); + lookUpThroughConstraint( + session, + semantics, + name, + type, + constraintDeclRef, + ioResult, + inBreadcrumbs, + mask); } } diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 5f8428698..4f5e8bceb 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -82,8 +82,8 @@ struct SubscriptInfo : ExtendedValueInfo struct BoundSubscriptInfo : ExtendedValueInfo { DeclRef declRef; - RefPtr type; - List args; + IRType* type; + List args; }; // Some cases of `ExtendedValueInfo` need to @@ -141,6 +141,12 @@ struct LoweredValInfo val = nullptr; } + LoweredValInfo(IRType* t) + { + flavor = Flavor::Simple; + val = t; + } + static LoweredValInfo simple(IRInst* v) { LoweredValInfo info; @@ -212,7 +218,7 @@ struct BoundMemberInfo : ExtendedValueInfo DeclRef declRef; // The type of this value - RefPtr type; + IRType* type; }; // Represents the result of a swizzle operation in @@ -224,7 +230,7 @@ struct BoundMemberInfo : ExtendedValueInfo struct SwizzledLValueInfo : ExtendedValueInfo { // The type of the expression. - RefPtr type; + IRType* type; // The base expression (this should be an l-value) LoweredValInfo base; @@ -272,12 +278,36 @@ LoweredValInfo LoweredValInfo::swizzledLValue( return info; } +// An "environment" for mapping AST declarations to IR values. +// +// This is required because in some cases we might lower the +// same AST declaration to the IR multiple times (e.g., when +// a generic transitively contains multiple functions, we +// will emit a distinct IR generic for each function, with +// its own copies of the generic parameters). +// +struct IRGenEnv +{ + // Map an AST-level declaration to the IR-level value that represents it. + Dictionary mapDeclToValue; + + // The next outer env around this one + IRGenEnv* outer = nullptr; +}; + struct SharedIRGenContext { CompileRequest* compileRequest; ModuleDecl* mainModuleDecl; - Dictionary declValues; + // The "global" environment for mapping declarations to their IR values. + IRGenEnv globalEnv; + + // Map an AST-level declaration of an interface + // requirement to the IR-level "key" that + // is used to fetch that requirement from a + // witness table. + Dictionary interfaceRequirementKeys; // Arrays we keep around strictly for memory-management purposes: @@ -297,8 +327,13 @@ struct SharedIRGenContext struct IRGenContext { + // Shared state for the IR generation process SharedIRGenContext* shared; + // environment for mapping AST decls to IR values + IRGenEnv* env; + + // IR builder to use when building code under this context IRBuilder* irBuilder; // The value to use for any `this` expressions @@ -310,12 +345,33 @@ struct IRGenContext // might be insufficient. LoweredValInfo thisVal; + explicit IRGenContext(SharedIRGenContext* inShared) + : shared(inShared) + , env(&inShared->globalEnv) + , irBuilder(nullptr) + {} + Session* getSession() { return shared->compileRequest->mSession; } }; +void setGlobalValue(SharedIRGenContext* sharedContext, Decl* decl, LoweredValInfo value) +{ + sharedContext->globalEnv.mapDeclToValue[decl] = value; +} + +void setGlobalValue(IRGenContext* context, Decl* decl, LoweredValInfo value) +{ + setGlobalValue(context->shared, decl, value); +} + +void setValue(IRGenContext* context, Decl* decl, LoweredValInfo value) +{ + context->env->mapDeclToValue[decl] = value; +} + // Ensure that a version of the given declaration has been emitted to the IR LoweredValInfo ensureDecl( IRGenContext* context, @@ -325,15 +381,8 @@ LoweredValInfo ensureDecl( // any needed specializations in place. LoweredValInfo emitDeclRef( IRGenContext* context, - DeclRef declRef); - -// Emit necessary `specialize` instruction needed by a declRef. -// This is currently used by emitDeclRef() and emitFuncRef() -LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context, - LoweredValInfo loweredDecl, // the lowered value of the inner decl - DeclRef declRef // the full decl ref containing substitutions -); - + DeclRef declRef, + IRType* type); IRInst* getSimpleVal(IRGenContext* context, LoweredValInfo lowered); @@ -402,23 +451,22 @@ IRInst* getOneValOfType( IRGenContext* context, IRType* type) { - if (auto basicType = dynamic_cast(type)) + switch(type->op) { - switch (basicType->baseType) - { - case BaseType::Int: - case BaseType::UInt: - case BaseType::UInt64: - return context->irBuilder->getIntValue(type, 1); + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_UInt64Type: + return context->irBuilder->getIntValue(type, 1); - case BaseType::Float: - case BaseType::Double: - return context->irBuilder->getFloatValue(type, 1.0); + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + return context->irBuilder->getFloatValue(type, 1.0); - default: - break; - } + default: + break; } + // TODO: should make sure to handle vector and matrix types here SLANG_UNEXPECTED("inc/dec type"); @@ -473,103 +521,19 @@ LoweredValInfo emitPostOp( return LoweredValInfo::ptr(argPtr); } -IRInst* findWitnessTable( +LoweredValInfo lowerRValueExpr( IRGenContext* context, - DeclRef declRef); - -LoweredValInfo emitWitnessTableRef( - IRGenContext* context, - Expr* expr) -{ - if (auto mbrExpr = dynamic_cast(expr)) - { - if (auto typeConstraintDeclRef = mbrExpr->declRef.As()) - { - if (mbrExpr->declRef.getDecl()->ParentDecl->As() - || mbrExpr->declRef.getDecl()->ParentDecl->As()) - { - RefPtr exprType = nullptr; - if (auto tt = mbrExpr->BaseExpression->type->As()) - exprType = tt->type; - else - exprType = mbrExpr->BaseExpression->type; - auto declRefType = exprType->GetCanonicalType()->AsDeclRefType(); - SLANG_ASSERT(declRefType); - IRInst* witnessTableVal = nullptr; - DeclRef srcDeclRef = declRefType->declRef; - if (!declRefType->declRef.As()) - { - // if we are referring to an actual type, don't include substitution - // and generate specialize instruction - srcDeclRef.substitutions = SubstitutionSet(); - } - witnessTableVal = context->irBuilder->emitFindWitnessTable(srcDeclRef, mbrExpr->declRef.As().getDecl()->getSup().type); - return maybeEmitSpecializeInst(context, LoweredValInfo::simple(witnessTableVal), declRefType->declRef); - } - } - if (auto inheritanceDecl = mbrExpr->declRef.As()) - { - if (mbrExpr->declRef.getDecl()->ParentDecl->As()) - { - return LoweredValInfo::simple(findWitnessTable(context, mbrExpr->declRef)); - } - } + Expr* expr); - if (auto genConstraintDeclRef = mbrExpr->declRef.As()) - { - return LoweredValInfo::simple(context->irBuilder->getDeclRefVal(genConstraintDeclRef)); - } - } - SLANG_UNEXPECTED("unknown witness table expression"); -} +IRType* lowerType( + IRGenContext* context, + Type* type); -// Emit a reference to a function, where we have concluded -// that the original AST referenced `funcDeclRef`. The -// optional expression `funcExpr` can provide additional -// detail that might modify how we go about looking up -// the actual value to call. -LoweredValInfo emitFuncRef( +static IRType* lowerType( IRGenContext* context, - DeclRef funcDeclRef, - Expr* funcExpr) + QualType const& type) { - if( !funcExpr ) - { - return emitDeclRef(context, funcDeclRef); - } - - // Let's look at the expression to see what additional - // information it gives us. - - if(auto funcMemberExpr = dynamic_cast(funcExpr)) - { - auto baseExpr = funcMemberExpr->BaseExpression; - if(auto baseMemberExpr = baseExpr.As()) - { - auto baseMemberDeclRef = baseMemberExpr->declRef; - if(auto baseConstraintDeclRef = baseMemberDeclRef.As()) - { - // We are calling a method "through" a generic type - // parameter that was constrained to some type. - // That means `funcDeclRef` is a reference to the method - // on the `interface` type (which doesn't actually have - // a body, so we don't want to emit or call it), and - // we actually want to perform a lookup step to - // find the corresponding member on our chosen type. - RefPtr type = funcExpr->type; - auto loweredBaseWitnessTable = emitWitnessTableRef(context, baseMemberExpr); - auto loweredVal = LoweredValInfo::simple(context->irBuilder->emitLookupInterfaceMethodInst( - type, - loweredBaseWitnessTable.val, - funcDeclRef)); - return maybeEmitSpecializeInst(context, loweredVal, funcDeclRef); - } - } - } - - // We didn't trigger a special case, so just emit a reference - // to the function itself. - return emitDeclRef(context, funcDeclRef); + return lowerType(context, type.type); } // Given a `DeclRef` for something callable, along with a bunch of @@ -578,7 +542,7 @@ LoweredValInfo emitCallToDeclRef( IRGenContext* context, IRType* type, DeclRef funcDeclRef, - Expr* funcExpr, + IRType* funcType, UInt argCount, IRInst* const* args) { @@ -587,7 +551,7 @@ LoweredValInfo emitCallToDeclRef( if (auto subscriptDeclRef = funcDeclRef.As()) { - // A reference to a subscript declaration is a special case, + // A reference to a subscript declaration is a special case, // because it is not possible to call a subscript directly; // we must call one of its accessors. // @@ -605,7 +569,7 @@ LoweredValInfo emitCallToDeclRef( { // The `ref` accessor will return a pointer to the value, so // we need to reflect that in the type of our `call` instruction. - RefPtr ptrType = context->getSession()->getPtrType(type); + IRType* ptrType = context->irBuilder->getPtrType(type); // Rather than call `emitCallToVal` here, we make a recursive call // to `emitCallToDeclRef` so that it can handle things like intrinsic-op @@ -614,7 +578,7 @@ LoweredValInfo emitCallToDeclRef( context, ptrType, refAccessorDeclRef, - funcExpr, + funcType, argCount, args); @@ -744,7 +708,16 @@ LoweredValInfo emitCallToDeclRef( } // Fallback case is to emit an actual call. - LoweredValInfo funcVal = emitFuncRef(context, funcDeclRef, funcExpr); + if(!funcType) + { + List argTypes; + for(UInt ii = 0; ii < argCount; ++ii) + { + argTypes.Add(args[ii]->getDataType()); + } + funcType = builder->getFuncType(argCount, argTypes.Buffer(), type); + } + LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef, funcType); return emitCallToVal(context, type, funcVal, argCount, args); } @@ -752,15 +725,22 @@ LoweredValInfo emitCallToDeclRef( IRGenContext* context, IRType* type, DeclRef funcDeclRef, - Expr* funcExpr, - List const& args) + IRType* funcType, + List const& args) { - return emitCallToDeclRef(context, type, funcDeclRef, funcExpr, args.Count(), args.Buffer()); + return emitCallToDeclRef(context, type, funcDeclRef, funcType, args.Count(), args.Buffer()); +} + +IRInst* getFieldKey( + IRGenContext* context, + DeclRef field) +{ + return getSimpleVal(context, emitDeclRef(context, field, context->irBuilder->getKeyType())); } LoweredValInfo extractField( IRGenContext* context, - Type* fieldType, + IRType* fieldType, LoweredValInfo base, DeclRef field) { @@ -775,7 +755,7 @@ LoweredValInfo extractField( builder->emitFieldExtract( fieldType, irBase, - builder->getDeclRefVal(field))); + getFieldKey(context, field))); } break; @@ -803,9 +783,9 @@ LoweredValInfo extractField( IRInst* irBasePtr = base.val; return LoweredValInfo::ptr( builder->emitFieldAddress( - context->getSession()->getPtrType(fieldType), + builder->getPtrType(fieldType), irBasePtr, - builder->getDeclRefVal(field))); + getFieldKey(context, field))); } break; } @@ -871,7 +851,7 @@ top: case LoweredValInfo::Flavor::SwizzledLValue: { auto swizzleInfo = lowered.getSwizzledLValueInfo(); - + return LoweredValInfo::simple(builder->emitSwizzle( swizzleInfo->type, getSimpleVal(context, swizzleInfo->base), @@ -911,45 +891,6 @@ IRInst* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) } } -struct LoweredTypeInfo -{ - enum class Flavor - { - None, - Simple, - }; - - RefPtr type; - Flavor flavor; - - LoweredTypeInfo() - { - flavor = Flavor::None; - } - - LoweredTypeInfo(IRType* t) - { - flavor = Flavor::Simple; - type = t; - } -}; - -RefPtr getSimpleType(LoweredTypeInfo lowered) -{ - switch(lowered.flavor) - { - case LoweredTypeInfo::Flavor::None: - return nullptr; - - case LoweredTypeInfo::Flavor::Simple: - return lowered.type; - - default: - SLANG_UNEXPECTED("unhandled value flavor"); - UNREACHABLE_RETURN(nullptr); - } -} - LoweredValInfo lowerVal( IRGenContext* context, Val* val); @@ -962,42 +903,10 @@ IRInst* lowerSimpleVal( return getSimpleVal(context, lowered); } -LoweredTypeInfo lowerType( - IRGenContext* context, - Type* type); - -static LoweredTypeInfo lowerType( - IRGenContext* context, - QualType const& type) -{ - return lowerType(context, type.type); -} - -// Lower a type and expect the result to be simple -RefPtr lowerSimpleType( - IRGenContext* context, - Type* type) -{ - auto lowered = lowerType(context, type); - return getSimpleType(lowered); -} - -RefPtr lowerSimpleType( - IRGenContext* context, - QualType const& type) -{ - auto lowered = lowerType(context, type); - return getSimpleType(lowered); -} - LoweredValInfo lowerLValueExpr( IRGenContext* context, Expr* expr); -LoweredValInfo lowerRValueExpr( - IRGenContext* context, - Expr* expr); - void assign( IRGenContext* context, LoweredValInfo const& left, @@ -1014,29 +923,41 @@ LoweredValInfo lowerDecl( IRType* getIntType( IRGenContext* context) { - return context->getSession()->getBuiltinType(BaseType::Int); + return context->irBuilder->getBasicType(BaseType::Int); } -RefPtr getFuncType( - IRGenContext* context, - UInt paramCount, - RefPtr const* paramTypes, - IRType* resultType) +IRStructKey* getInterfaceRequirementKey( + IRGenContext* context, + Decl* requirementDecl) { - RefPtr funcType = new FuncType(); - funcType->setSession(context->getSession()); - funcType->resultType = resultType; - for (UInt pp = 0; pp < paramCount; ++pp) + IRStructKey* requirementKey = nullptr; + if(context->shared->interfaceRequirementKeys.TryGetValue(requirementDecl, requirementKey)) { - funcType->paramTypes.Add(paramTypes[pp]); + return requirementKey; } - return funcType; + + IRBuilder builderStorage = *context->irBuilder; + auto builder = &builderStorage; + + builder->setInsertInto(builder->sharedBuilder->module->getModuleInst()); + + // Construct a key to serve as the representation of + // this requirement in the IR, and to allow lookup + // into the declaration. + requirementKey = builder->createStructKey(); + requirementKey->mangledName = context->getSession()->getNameObj( + getMangledName(requirementDecl)); + + context->shared->interfaceRequirementKeys.Add(requirementDecl, requirementKey); + + return requirementKey; } + SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst); // -struct ValLoweringVisitor : ValVisitor +struct ValLoweringVisitor : ValVisitor { IRGenContext* context; @@ -1047,6 +968,42 @@ struct ValLoweringVisitor : ValVisitordeclRef, + lowerType(context, GetType(val->declRef))); + } + + LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) + { + return emitDeclRef(context, val->declRef, + context->irBuilder->getWitnessTableType()); + } + + LoweredValInfo visitTransitiveSubtypeWitness( + TransitiveSubtypeWitness* val) + { + // The base (subToMid) will turn into a value with + // witness-table type. + IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid); + + // The next step should map to an interface requirement + // that is itself an interface conformance, so the result + // of lowering this value should be a "key" that we can + // use to look up a witness table. + IRInst* requirementKey = getInterfaceRequirementKey(context, val->midToSup.getDecl()); + + // TODO: There are some ugly cases here if `midToSup` is allowed + // to be an arbitrary witness, rather than just a declared one, + // and we should probably change the front-end representation + // to reflect the right constraints. + + return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( + nullptr, + baseWitnessTable, + requirementKey)); + } + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { // TODO: it is a bit messy here that the `ConstantIntVal` representation @@ -1056,70 +1013,135 @@ struct ValLoweringVisitor : ValVisitorgetIntValue(type, val->value)); } - LoweredTypeInfo visitType(Type* type) + IRFuncType* visitFuncType(FuncType* type) { - // TODO(tfoley): Now that we use the AST types directly in the IR, there - // isn't much to do in the "lowering" step. Still, there might be cases - // where certain kinds of legalization need to take place, so this - // visitor setup might still be needed in the long run. - return LoweredTypeInfo(type); -// SLANG_UNIMPLEMENTED_X("type lowering"); + IRType* resultType = lowerType(context, type->getResultType()); + UInt paramCount = type->getParamCount(); + List paramTypes; + for (UInt pp = 0; pp < paramCount; ++pp) + { + paramTypes.Add(lowerType(context, type->getParamType(pp))); + } + return getBuilder()->getFuncType( + paramCount, + paramTypes.Buffer(), + resultType); } - LoweredTypeInfo visitFuncType(FuncType* type) + IRType* visitDeclRefType(DeclRefType* type) { - return LoweredTypeInfo(type); + return (IRType*) getSimpleVal( + context, + emitDeclRef(context, type->declRef, + context->irBuilder->getTypeKind())); } - void addGenericArgs(List* ioArgs, DeclRefBase declRef) + IRType* visitNamedExpressionType(NamedExpressionType* type) { - auto subs = declRef.substitutions.genericSubstitutions; - while(subs) - { - for (auto aa : subs->args) - { - (*ioArgs).Add(getSimpleVal(context, lowerVal(context, aa))); - } - subs = subs->outer; - } + return (IRType*) getSimpleVal(context, + emitDeclRef(context, type->declRef, + context->irBuilder->getTypeKind())); + } + + IRType* visitBasicExpressionType(BasicExpressionType* type) + { + return getBuilder()->getBasicType( + type->baseType); + } + + IRType* visitVectorExpressionType(VectorExpressionType* type) + { + auto elementType = lowerType(context, type->elementType); + auto elementCount = lowerSimpleVal(context, type->elementCount); + + return getBuilder()->getVectorType( + elementType, + elementCount); } - LoweredTypeInfo visitDeclRefType(DeclRefType* type) + IRType* visitMatrixExpressionType(MatrixExpressionType* type) { - // If the type in question comes from the module we are - // trying to lower, then we need to make sure to - // emit everything relevant to its declaration. + auto elementType = lowerType(context, type->getElementType()); + auto rowCount = lowerSimpleVal(context, type->getRowCount()); + auto columnCount = lowerSimpleVal(context, type->getColumnCount()); + + return getBuilder()->getMatrixType( + elementType, + rowCount, + columnCount); + } - // TODO: actually test what module the type is coming from. + IRType* visitArrayExpressionType(ArrayExpressionType* type) + { + auto elementType = lowerType(context, type->baseType); + if (type->ArrayLength) + { + auto elementCount = lowerSimpleVal(context, type->ArrayLength); + return getBuilder()->getArrayType( + elementType, + elementCount); + } + else + { + return getBuilder()->getUnsizedArrayType( + elementType); + } + } - lowerDecl(context, type->declRef); - return LoweredTypeInfo(type); + // Lower a type where the type declaration being referenced is assumed + // to be an intrinsic type, which can thus be lowered to a simple IR + // type with the appropriate opcode. + IRType* lowerSimpleIntrinsicType(DeclRefType* type) + { + auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); + SLANG_ASSERT(intrinsicTypeModifier); + IROp op = IROp(intrinsicTypeModifier->irOp); + return getBuilder()->getType(op); } - LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type) + // Lower a type where the type declaration being referenced is assumed + // to be an intrinsic type with a single generic type parameter, and + // which can thus be lowered to a simple IR type with the appropriate opcode. + IRType* lowerGenericIntrinsicType(DeclRefType* type, Type* elementType) { - return LoweredTypeInfo(type); + auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); + SLANG_ASSERT(intrinsicTypeModifier); + IROp op = IROp(intrinsicTypeModifier->irOp); + IRInst* irElementType = lowerType(context, elementType); + return getBuilder()->getType( + op, + 1, + &irElementType); } - LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type) + IRType* visitResourceType(ResourceType* type) { - return LoweredTypeInfo(type); + return lowerGenericIntrinsicType(type, type->elementType); } - LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type) + IRType* visitSamplerStateType(SamplerStateType* type) { - return LoweredTypeInfo(type); + return lowerSimpleIntrinsicType(type); } - LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type) + IRType* visitBuiltinGenericType(BuiltinGenericType* type) { - return LoweredTypeInfo(type); + return lowerGenericIntrinsicType(type, type->elementType); } - LoweredTypeInfo visitIRBasicBlockType(IRBasicBlockType* type) + IRType* visitUntypedBufferResourceType(UntypedBufferResourceType* type) { - return LoweredTypeInfo(type); + return lowerSimpleIntrinsicType(type); } + + // We do not expect to encounter the following types in ASTs that have + // passed front-end semantic checking. +#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); } + UNEXPECTED_CASE(GenericDeclRefType) + UNEXPECTED_CASE(TypeType) + UNEXPECTED_CASE(ErrorType) + UNEXPECTED_CASE(InitializerListType) + UNEXPECTED_CASE(OverloadGroupType) }; LoweredValInfo lowerVal( @@ -1131,18 +1153,51 @@ LoweredValInfo lowerVal( return visitor.dispatch(val); } -LoweredTypeInfo lowerType( +IRType* lowerType( IRGenContext* context, Type* type) { ValLoweringVisitor visitor; visitor.context = context; - return visitor.dispatchType(type); + return (IRType*) getSimpleVal(context, visitor.dispatchType(type)); +} + +void addVarDecorations( + IRGenContext* context, + IRInst* inst, + Decl* decl) +{ + auto builder = context->irBuilder; + for(RefPtr mod : decl->modifiers) + { + if(mod.As()) + { + builder->addDecoration(inst)->mode = IRInterpolationMode::NoInterpolation; + } + else if(mod.As()) + { + builder->addDecoration(inst)->mode = IRInterpolationMode::NoPerspective; + } + else if(mod.As()) + { + builder->addDecoration(inst)->mode = IRInterpolationMode::Linear; + } + else if(mod.As()) + { + builder->addDecoration(inst)->mode = IRInterpolationMode::Sample; + } + else if(mod.As()) + { + builder->addDecoration(inst)->mode = IRInterpolationMode::Centroid; + } + + // TODO: what are other modifiers we need to propagate through? + } } LoweredValInfo createVar( IRGenContext* context, - RefPtr type, + IRType* type, Decl* decl = nullptr) { auto builder = context->irBuilder; @@ -1150,6 +1205,8 @@ LoweredValInfo createVar( if (decl) { + addVarDecorations(context, irAlloc, decl); + builder->addHighLevelDeclDecoration(irAlloc, decl); } @@ -1198,7 +1255,10 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitVarExpr(VarExpr* expr) { - LoweredValInfo info = emitDeclRef(context, expr->declRef); + LoweredValInfo info = emitDeclRef( + context, + expr->declRef, + lowerType(context, expr->type)); return info; } @@ -1263,7 +1323,6 @@ struct ExprLoweringVisitorBase : ExprVisitor // as an l-value, since that is the easiest way to handle it. LoweredValInfo visitDerefExpr(DerefExpr* expr) { - auto loweredType = lowerType(context, expr->type); auto loweredBase = lowerRValueExpr(context, expr->base); // TODO: handle tupel-type for `base` @@ -1273,10 +1332,10 @@ struct ExprLoweringVisitorBase : ExprVisitor // need to extract the value type from that pointer here. // IRInst* loweredBaseVal = getSimpleVal(context, loweredBase); - RefPtr loweredBaseType = loweredBaseVal->getDataType(); + IRType* loweredBaseType = loweredBaseVal->getDataType(); - if (loweredBaseType->As() - || loweredBaseType->As()) + if (as(loweredBaseType) + || as(loweredBaseType)) { // Note that we do *not* perform an actual `load` operation // here, but rather just use the pointer value to construct @@ -1305,7 +1364,8 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) { // Allocate a temporary of the given type - RefPtr type = lowerSimpleType(context, expr->type); + auto type = expr->type; + IRType* irType = lowerType(context, type); List args; UInt argCount = expr->args.Count(); @@ -1315,7 +1375,6 @@ struct ExprLoweringVisitorBase : ExprVisitor if (auto arrayType = type->As()) { UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); - auto elementType = lowerType(context, arrayType->baseType); for (UInt ee = 0; ee < elementCount; ++ee) { @@ -1332,12 +1391,10 @@ struct ExprLoweringVisitorBase : ExprVisitor } return LoweredValInfo::simple( - getBuilder()->emitMakeArray(type, args.Count(), args.Buffer())); + getBuilder()->emitMakeArray(irType, args.Count(), args.Buffer())); } else if (auto vectorType = type->As()) { - auto elementType = lowerType(context, vectorType->elementType); - UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); UInt argCounter = 0; @@ -1357,7 +1414,7 @@ struct ExprLoweringVisitorBase : ExprVisitor } return LoweredValInfo::simple( - getBuilder()->emitMakeVector(type, args.Count(), args.Buffer())); + getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer())); } else if (auto declRefType = type->As()) { @@ -1384,7 +1441,7 @@ struct ExprLoweringVisitorBase : ExprVisitor } return LoweredValInfo::simple( - getBuilder()->emitMakeStruct(type, args.Count(), args.Buffer())); + getBuilder()->emitMakeStruct(irType, args.Count(), args.Buffer())); } else { @@ -1406,13 +1463,13 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitIntegerLiteralExpr(IntegerLiteralExpr* expr) { - auto type = lowerSimpleType(context, expr->type); + auto type = lowerType(context, expr->type); return LoweredValInfo::simple(context->irBuilder->getIntValue(type, expr->value)); } LoweredValInfo visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) { - auto type = lowerSimpleType(context, expr->type); + auto type = lowerType(context, expr->type); return LoweredValInfo::simple(context->irBuilder->getFloatValue(type, expr->value)); } @@ -1450,7 +1507,7 @@ struct ExprLoweringVisitorBase : ExprVisitor for (auto paramDeclRef : getMembersOfType(funcDeclRef)) { auto paramDecl = paramDeclRef.getDecl(); - RefPtr paramType = lowerSimpleType(context, GetType(paramDeclRef)); + IRType* paramType = lowerType(context, GetType(paramDeclRef)); UInt argIndex = argCounter++; RefPtr argExpr; @@ -1656,7 +1713,7 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { - auto type = lowerSimpleType(context, expr->type); + auto type = lowerType(context, expr->type); // We are going to look at the syntactic form of // the "function" expression, so that we can avoid @@ -1704,12 +1761,13 @@ struct ExprLoweringVisitorBase : ExprVisitor // These may include `out` and `inout` arguments that // require "fixup" work on the other side. // + auto funcType = lowerType(context, funcExpr->type); addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); auto result = emitCallToDeclRef( context, type, funcDeclRef, - funcExpr, + funcType, irArgs); applyOutArgumentFixups(argFixups); return result; @@ -1733,9 +1791,9 @@ struct ExprLoweringVisitorBase : ExprVisitor } LoweredValInfo subscriptValue( - LoweredTypeInfo type, + IRType* type, LoweredValInfo baseVal, - IRInst* indexVal) + IRInst* indexVal) { auto builder = getBuilder(); switch (baseVal.flavor) @@ -1743,14 +1801,14 @@ struct ExprLoweringVisitorBase : ExprVisitor case LoweredValInfo::Flavor::Simple: return LoweredValInfo::simple( builder->emitElementExtract( - getSimpleType(type), + type, getSimpleVal(context, baseVal), indexVal)); case LoweredValInfo::Flavor::Ptr: return LoweredValInfo::ptr( builder->emitElementAddress( - context->getSession()->getPtrType(getSimpleType(type)), + context->irBuilder->getPtrType(type), baseVal.val, indexVal)); @@ -1762,16 +1820,17 @@ struct ExprLoweringVisitorBase : ExprVisitor } LoweredValInfo extractField( - LoweredTypeInfo fieldType, + IRType* fieldType, LoweredValInfo base, DeclRef field) { - return Slang::extractField(context, getSimpleType(fieldType), base, field); + return Slang::extractField(context, fieldType, base, field); } LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr) { - return emitDeclRef(context, expr->declRef); + return emitDeclRef(context, expr->declRef, + lowerType(context, expr->type)); } LoweredValInfo visitGenericAppExpr(GenericAppExpr* /*expr*/) @@ -1809,7 +1868,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBasetype); + auto irType = lowerType(context, expr->type); auto loweredBase = lowerRValueExpr(context, expr->base); RefPtr swizzledLValue = new SwizzledLValueInfo(); @@ -1835,7 +1894,7 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBasetype); + auto irType = lowerType(context, expr->type); auto irBase = getSimpleVal(context, lowerRValueExpr(context, expr->base)); auto builder = getBuilder(); @@ -1923,7 +1982,17 @@ struct StmtLoweringVisitor : StmtVisitor return; auto varDecl = stmt->varDecl; - auto varType = varDecl->type; + auto varType = lowerType(context, varDecl->type); + + IRGenEnv subEnvStorage; + IRGenEnv* subEnv = &subEnvStorage; + subEnv->outer = context->env; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->env = subEnv; + + for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) { @@ -1931,9 +2000,9 @@ struct StmtLoweringVisitor : StmtVisitor varType, ii); - context->shared->declValues[varDecl] = LoweredValInfo::simple(constVal); + subEnv->mapDeclToValue[varDecl] = LoweredValInfo::simple(constVal); - lowerStmt(context, stmt->body); + lowerStmt(subContext, stmt->body); } } @@ -2666,7 +2735,6 @@ top: // try to handle everything uniformly. // auto swizzleInfo = left.getSwizzledLValueInfo(); - auto type = swizzleInfo->type; auto loweredBase = swizzleInfo->base; // Load from the base value: @@ -2700,19 +2768,18 @@ top: // When storing to such a value, we need to emit a call // to the appropriate builtin "setter" accessor. auto subscriptInfo = left.getBoundSubscriptInfo(); - auto type = subscriptInfo->type; // Search for an appropriate "setter" declaration auto setters = getMembersOfType(subscriptInfo->declRef); if (setters.Count()) { auto allArgs = subscriptInfo->args; - + addArgs(context, &allArgs, right); emitCallToDeclRef( context, - context->getSession()->getVoidType(), + builder->getVoidType(), *setters.begin(), nullptr, allArgs); @@ -2780,11 +2847,13 @@ struct DeclLoweringVisitor : DeclVisitor LoweredValInfo visitDeclBase(DeclBase* /*decl*/) { SLANG_UNIMPLEMENTED_X("decl catch-all"); + UNREACHABLE_RETURN(LoweredValInfo()); } LoweredValInfo visitDecl(Decl* /*decl*/) { SLANG_UNIMPLEMENTED_X("decl catch-all"); + UNREACHABLE_RETURN(LoweredValInfo()); } LoweredValInfo visitExtensionDecl(ExtensionDecl* decl) @@ -2814,9 +2883,33 @@ struct DeclLoweringVisitor : DeclVisitor return LoweredValInfo(); } - LoweredValInfo visitTypeDefDecl(TypeDefDecl * decl) + LoweredValInfo visitTypeDefDecl(TypeDefDecl* decl) { - return LoweredValInfo::simple(context->irBuilder->getTypeVal(decl->type.type)); + // A type alias declaration may be generic, if it is + // nested under a generic type/function/etc. + // + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + IRGeneric* outerGeneric = emitOuterGenerics(subBuilder, decl, decl); + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->irBuilder = subBuilder; + + // TODO: if a type alias declaration can have linkage, + // we will need to lower it to some kind of global + // value in the IR so that we can attach a name to it. + // + // For now, we can only attach a name *if* the type + // alias is somehow generic. + if(outerGeneric) + { + setMangledName(outerGeneric, getMangledName(decl)); + } + + auto type = lowerType(subContext, decl->type.type); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, type)); } LoweredValInfo visitGenericTypeParamDecl(GenericTypeParamDecl* /*decl*/) @@ -2824,118 +2917,219 @@ struct DeclLoweringVisitor : DeclVisitor return LoweredValInfo(); } - void walkInheritanceHierarchyAndCreateWitnessTableCopies(IRWitnessTable* witnessTable, Type* subType, InheritanceDecl* inheritanceDecl) + LoweredValInfo visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) { - auto baseDeclRef = inheritanceDecl->base.type.As(); - if (auto baseInterfaceDeclRef = baseDeclRef->declRef.As()) + // This might be a type constraint on an associated type, + // in which case it should lower as the key for that + // interface requirement. + if(auto assocTypeDecl = decl->ParentDecl->As()) { - for (auto subInheritanceDeclRef : getMembersOfType(baseInterfaceDeclRef)) - { - auto cpyMangledName = context->getSession()->getNameObj(getMangledNameForConformanceWitness(subType, subInheritanceDeclRef.getDecl()->getSup().type)); - if (!witnessTablesDictionary.ContainsKey(cpyMangledName)) - { - auto cpyTable = context->irBuilder->createWitnessTable(); - cpyTable->mangledName = cpyMangledName; - context->irBuilder->createWitnessTableEntry(witnessTable, - context->irBuilder->getDeclRefVal(subInheritanceDeclRef), cpyTable); + // TODO: might need extra steps if we ever allow + // generic associated types. - // We need to copy all the entries from the original table to this new table. - for (auto entry : witnessTable->getEntries()) - { - context->irBuilder->createWitnessTableEntry(cpyTable, - entry->requirementKey.get(), - entry->satisfyingVal.get()); - } - witnessTablesDictionary.Add(cpyTable->mangledName, cpyTable); - walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, subType, subInheritanceDeclRef.getDecl()); - } + if(auto interfaceDecl = assocTypeDecl->ParentDecl->As()) + { + // Okay, this seems to be an interface rquirement, and + // we should lower it as such. + return LoweredValInfo::simple(getInterfaceRequirementKey(decl)); } } + + if(auto globalGenericParamDecl = decl->ParentDecl->As()) + { + // This is a constraint on a global generic type parameters, + // and so it should lower as a parameter of its own. + + auto inst = getBuilder()->emitGlobalGenericParam(); + setMangledName(inst, getMangledName(decl)); + return LoweredValInfo::simple(inst); + } + + // Otherwise we really don't expect to see a type constraint + // declaration like this during lowering, because a generic + // should have set up a parameter for any constraints as + // part of being lowered. + + SLANG_UNEXPECTED("generic type constraint during lowering"); + UNREACHABLE_RETURN(LoweredValInfo()); } - Dictionary witnessTablesDictionary; + LoweredValInfo visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) + { + auto inst = getBuilder()->emitGlobalGenericParam(); + setMangledName(inst, getMangledName(decl)); + return LoweredValInfo::simple(inst); + } - LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + void lowerWitnessTable( + IRGenContext* subContext, + WitnessTable* astWitnessTable, + IRWitnessTable* irWitnessTable, + Dictionary mapASTToIRWitnessTable) { - // Construct a type for the parent declaration. - // - // TODO: if this inheritance declaration is under an extension, - // then we should construct the type that is being extended, - // and not a reference to the extension itself. + auto subBuilder = subContext->irBuilder; - auto parentDecl = inheritanceDecl->ParentDecl; - RefPtr type; - if (auto extParentDecl = dynamic_cast(parentDecl)) + for(auto entry : astWitnessTable->requirementDictionary) { - type = extParentDecl->targetType.type; - if (auto declRefType = type.As()) + auto requiredMemberDecl = entry.Key; + auto satisfyingWitness = entry.Value; + + auto irRequirementKey = getInterfaceRequirementKey(requiredMemberDecl); + IRInst* irSatisfyingVal = nullptr; + + switch(satisfyingWitness.getFlavor()) { - if (auto aggTypeDecl = declRefType->declRef.As()) - parentDecl = aggTypeDecl.getDecl(); - } - } - else - { - type = DeclRefType::Create( - context->getSession(), - makeDeclRef(parentDecl)); + case RequirementWitness::Flavor::declRef: + { + auto satisfyingDeclRef = satisfyingWitness.getDeclRef(); + irSatisfyingVal = getSimpleVal(subContext, + emitDeclRef(subContext, satisfyingDeclRef, + // TODO: we need to know what type to plug in here... + nullptr)); + } + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = satisfyingWitness.getVal(); + irSatisfyingVal = lowerSimpleVal(subContext, satisfyingVal); + } + break; + + case RequirementWitness::Flavor::witnessTable: + { + auto astReqWitnessTable = satisfyingWitness.getWitnessTable(); + IRWitnessTable* irSatisfyingWitnessTable = nullptr; + if(!mapASTToIRWitnessTable.TryGetValue(astReqWitnessTable, irSatisfyingWitnessTable)) + { + // Need to construct a sub-witness-table + irSatisfyingWitnessTable = subBuilder->createWitnessTable(); + + // Recursively lower the sub-table. + lowerWitnessTable( + subContext, + astReqWitnessTable, + irSatisfyingWitnessTable, + mapASTToIRWitnessTable); + + irSatisfyingWitnessTable->moveToEnd(); + } + irSatisfyingVal = irSatisfyingWitnessTable; + } + break; + + default: + SLANG_UNEXPECTED("handled requirement witness case"); + break; + } + + + subBuilder->createWitnessTableEntry( + irWitnessTable, + irRequirementKey, + irSatisfyingVal); + } + } + + LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + // An inheritance clause inside of an `interface` + // declaration should not give rise to a witness + // table, because it represents something the + // interface requires, and not what it provides. + // + auto parentDecl = inheritanceDecl->ParentDecl; + if (auto parentInterfaceDecl = parentDecl->As()) + { + return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); + } + // + // We also need to cover the case where an `extension` + // declaration is being used to add a conformance to + // an existing `interface`: + // + if(auto parentExtensionDecl = parentDecl->As()) + { + auto targetType = parentExtensionDecl->targetType; + if(auto targetDeclRefType = targetType->As()) + { + if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.As()) + { + return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); + } + } + } + + // Find the type that is doing the inheriting. + // Under normal circumstances it is the type declaration that + // is the parent for the inheritance declaration, but if + // the inheritance declaration is on an `extension` declaration, + // then we need to identify the type being extended. + // + RefPtr subType; + if (auto extParentDecl = dynamic_cast(parentDecl)) + { + subType = extParentDecl->targetType.type; + } + else + { + subType = DeclRefType::Create( + context->getSession(), + makeDeclRef(parentDecl)); } + // What is the super-type that we have declared we inherit from? RefPtr superType = inheritanceDecl->base.type; // Construct the mangled name for the witness table, which depends // on the type that is conforming, and the type that it conforms to. - auto mangledName = context->getSession()->getNameObj(getMangledNameForConformanceWitness(type, superType)); - - // Build an IR level witness table, which will represent the - // conformance of the type to its super-type. - auto witnessTable = context->irBuilder->createWitnessTable(); - witnessTable->mangledName = mangledName; - - witnessTablesDictionary.Add(mangledName, witnessTable); - - if (parentDecl->ParentDecl) - witnessTable->genericDecl = dynamic_cast(parentDecl->ParentDecl); - witnessTable->subTypeDeclRef = makeDeclRef(parentDecl); - witnessTable->subTypeDeclRef.substitutions = createDefaultSubstitutions(context->getSession(), parentDecl); - witnessTable->supTypeDeclRef = inheritanceDecl->base.type->AsDeclRefType()->declRef; - - // Register the value now, rather than later, to avoid - // infinite recursion. - context->shared->declValues[inheritanceDecl] = LoweredValInfo::simple(witnessTable); - - - // Semantic checking will have filled in a dictionary of - // witnesses for requirements in the interface, and we - // will now navigate that dictionary to fill in the witness table. - for (auto entry : inheritanceDecl->requirementWitnesses) - { - auto requiredMemberDeclRef = entry.Key; - auto satisfyingMemberDeclRef = entry.Value; - - auto irRequirement = context->irBuilder->getDeclRefVal(requiredMemberDeclRef); - IRInst* irSatisfyingVal = nullptr; - if (satisfyingMemberDeclRef.As()) - irSatisfyingVal = context->irBuilder->getDeclRefVal(satisfyingMemberDeclRef); - else - irSatisfyingVal = getSimpleVal(context, ensureDecl(context, satisfyingMemberDeclRef)); + // + // TODO: This approach doesn't really make sense for generic `extension` conformances. + auto mangledName = context->getSession()->getNameObj( + getMangledNameForConformanceWitness(subType, superType)); - context->irBuilder->createWitnessTableEntry( - witnessTable, - irRequirement, - irSatisfyingVal); - } + // A witness table may need to be generic, if the outer + // declaration (either a type declaration or an `extension`) + // is generic. + // + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + emitOuterGenerics(subBuilder, inheritanceDecl, inheritanceDecl); - witnessTable->moveToEnd(); - walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, type, inheritanceDecl); + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->irBuilder = subBuilder; - // A direct reference to this inheritance relationship (e.g., - // as a subtype witness) will take the form of a reference to - // the witness table in the IR. - return LoweredValInfo::simple(witnessTable); - } + // Lower the super-type to force its declaration to be lowered. + // + // Note: we are using the "sub-context" here because the + // type being inherited from could reference generic parameters, + // and we need those parameters to lower as references to + // the parameters of our IR-level generic. + // + lowerType(subContext, superType); + + // Create the IR-level witness table + auto irWitnessTable = subBuilder->createWitnessTable(); + setMangledName(irWitnessTable, mangledName); + + // Register the value now, rather than later, to avoid any possible infinite recursion. + setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable)); + // Make sure that all the entries in the witness table have been filled in, + // including any cases where there are sub-witness-tables for conformances + Dictionary mapASTToIRWitnessTable; + lowerWitnessTable( + subContext, + inheritanceDecl->witnessTable, + irWitnessTable, + mapASTToIRWitnessTable); + + irWitnessTable->moveToEnd(); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irWitnessTable)); + } LoweredValInfo visitDeclGroup(DeclGroup* declGroup) { @@ -2996,19 +3190,23 @@ struct DeclLoweringVisitor : DeclVisitor LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl) { - RefPtr varType = lowerSimpleType(context, decl->getType()); + IRType* varType = lowerType(context, decl->getType()); if (decl->HasModifier()) { - varType = context->getSession()->getGroupSharedType(varType); + // TODO: here we are applying the rate qualifier to + // the *data type* of the variable, when we really + // should be applying the rate to the variable itself. + // + // This ends up making a distinction between + // `Ptr<@GroupShared X>` and `@GroupShared Ptr`. + // The latter is more technically correct, but the + // code generation logic currently looks for the former. + + varType = getBuilder()->getRateQualifiedType( + getBuilder()->getGroupSharedRate(), + varType); } - // TODO: There might be other cases of storage qualifiers - // that should translate into "rate-qualified" types - // for the variable's storage. - // - // TODO: Also worth asking whether we should have semantic - // checking be responsible for applying qualifiers applied - // to a variable over to its type, when it makes sense. auto builder = getBuilder(); @@ -3035,8 +3233,7 @@ struct DeclLoweringVisitor : DeclVisitor // A global variable's SSA value is a *pointer* to // the underlying storage. - context->shared->declValues[ - DeclRef(decl, nullptr)] = globalVal; + setGlobalValue(context, decl, globalVal); if (isImportedDecl(decl)) { @@ -3064,12 +3261,15 @@ struct DeclLoweringVisitor : DeclVisitor subContext->irBuilder->emitReturn(getSimpleVal(subContext, initVal)); } + irGlobal->moveToEnd(); + return globalVal; } LoweredValInfo visitGenericValueParamDecl(GenericValueParamDecl* decl) { - return LoweredValInfo::simple(context->irBuilder->getDeclRefVal(DeclRefBase(decl))); + return emitDeclRef(context, makeDeclRef(decl), + lowerType(context, decl->type)); } LoweredValInfo visitVarDeclBase(VarDeclBase* decl) @@ -3092,7 +3292,7 @@ struct DeclLoweringVisitor : DeclVisitor // emit an SSA value in this common case. // - RefPtr varType = lowerSimpleType(context, decl->getType()); + IRType* varType = lowerType(context, decl->getType()); // TODO: If the variable is marked `static` then we need to // deal with it specially: we should move its allocation out @@ -3125,7 +3325,9 @@ struct DeclLoweringVisitor : DeclVisitor { // TODO: This logic is duplicated with the global-variable // case. We should seek to share it. - varType = context->getSession()->getGroupSharedType(varType); + varType = getBuilder()->getRateQualifiedType( + getBuilder()->getGroupSharedRate(), + varType); } LoweredValInfo varVal = createVar(context, varType, decl); @@ -3137,14 +3339,97 @@ struct DeclLoweringVisitor : DeclVisitor assign(context, varVal, initVal); } - context->shared->declValues[ - DeclRef(decl, nullptr)] = varVal; + setGlobalValue(context, decl, varVal); return varVal; } + IRStructKey* getInterfaceRequirementKey(Decl* requirementDecl) + { + return Slang::getInterfaceRequirementKey(context, requirementDecl); + } + + LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl) + { + // The interface decl is not itself a type in the IR + // (yet), so the only thing we need to do here is + // enumerate the requirements that the interface + // imposes on implementations. + // + // These members will turn into the keys that will + // be used for lookup operations into witness + // tables that promise conformance to the interface. + // + // TODO: we don't handle the case here of an interface + // with concrete/default implementations for any + // of its members. + // + // TODO: If we want to support using an interface as + // an existential type, then we might need to emit + // a witness table for the interface type's conformance + // to its own interface. + // + for (auto requirementDecl : decl->Members) + { + getInterfaceRequirementKey(requirementDecl); + + // As a special case, any type constraints placed + // on an associated type will *also* need to be turned + // into requirement keys for this interface. + if (auto associatedTypeDecl = requirementDecl.As()) + { + for (auto constraintDecl : associatedTypeDecl->getMembersOfType()) + { + getInterfaceRequirementKey(constraintDecl); + } + } + } + + return LoweredValInfo(); + } + + IRGeneric* getOuterGeneric(IRGlobalValue* gv) + { + auto parentBlock = as(gv->getParent()); + if (!parentBlock) return nullptr; + + auto parentGeneric = as(parentBlock->getParent()); + return parentGeneric; + } + + void setMangledName(IRGlobalValue* inst, Name* name) + { + // If the instruction is nested inside one or more generics, + // then the mangled name should really apply to the outer-most + // generic, and not the declaration nested inside. + + IRGlobalValue* gv = inst; + while (auto outerGeneric = getOuterGeneric(gv)) + { + gv = outerGeneric; + } + + gv->mangledName = name; + } + + void setMangledName(IRGlobalValue* inst, String const& name) + { + setMangledName(inst, context->getSession()->getNameObj(name)); + } + LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) { + // Don't generate an IR `struct` for intrinsic types + if(decl->FindModifier() || decl->FindModifier()) + { + return LoweredValInfo(); + } + + if(getMangledName(decl) == "_ST03int") + { + decl = decl; + } + // Given a declaration of a type, we need to make sure // to output "witness tables" for any interfaces this // type has declared conformance to. @@ -3153,13 +3438,92 @@ struct DeclLoweringVisitor : DeclVisitor ensureDecl(context, inheritanceDecl); } - // TODO: we currently store a Decl* in the witness table, which causes this function - // being invoked to translate the witness table entry into an IRInst. - // We should really allow a witness table entry to represent a type and not having to - // construct the type here. The current implementation will not work when the struct type - // is defined in a generic parent (we lose the environmental substitutions). - return LoweredValInfo::simple(context->irBuilder->getTypeVal(DeclRefType::Create(context->getSession(), - DeclRef(decl, nullptr)))); + // We are going to create nested IR building state + // to use when emitting the members of the type. + // + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + + // Emit any generics that should wrap the actual type. + emitOuterGenerics(subBuilder, decl, decl); + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->irBuilder = subBuilder; + + IRStructType* irStruct = subBuilder->createStructType(); + + setMangledName(irStruct, getMangledName(decl)); + + subBuilder->setInsertInto(irStruct); + + for (auto fieldDecl : decl->getMembersOfType()) + { + if (fieldDecl->HasModifier()) + { + // A `static` field is actually a global variable, + // and we should emit it as such. + ensureDecl(context, fieldDecl); + continue; + } + + // Each ordinary field will need to turn into a struct "key" + // that is used for fetching the field. + IRInst* fieldKeyInst = getSimpleVal(context, + ensureDecl(context, fieldDecl)); + auto fieldKey = as(fieldKeyInst); + assert(fieldKey); + + // Note: we lower the type of the field in the "sub" + // context, so that any generic parameters that were + // set up for the type can be referenced by the field type. + IRType* fieldType = lowerType( + subContext, + fieldDecl->getType()); + + // Then, the parent `struct` instruction itself will have + // a "field" instruction. + subBuilder->createStructField( + irStruct, + fieldKey, + fieldType); + } + + // TODO: we should enumerate the non-field members of the type + // as well, and ensure those have been emitted (e.g., any + // member functions). + + irStruct->moveToEnd(); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irStruct)); + } + + LoweredValInfo visitStructField(StructField* fieldDecl) + { + // Each field declaration in the AST translates into + // a "key" that can be used to extract field values + // from instances of struct types that contain the field. + // + // It is correct to say struct *types* because a `struct` + // nested under a generic can be used to realize a number + // of different concrete types, but all of these types + // will use the same space of keys. + + auto builder = getBuilder(); + auto irFieldKey = builder->createStructKey(); + + addVarDecorations(context, irFieldKey, fieldDecl); + + irFieldKey->mangledName = context->getSession()->getNameObj( + getMangledName(fieldDecl)); + + if (auto semanticModifier = fieldDecl->FindModifier()) + { + auto semanticDecoration = builder->addDecoration(irFieldKey); + semanticDecoration->semanticName = semanticModifier->name.getName(); + } + + return LoweredValInfo::simple(irFieldKey); } @@ -3227,7 +3591,7 @@ struct DeclLoweringVisitor : DeclVisitor struct ParameterInfo { // This AST-level type of the parameter - Type* type; + RefPtr type; // The direction (`in` vs `out` vs `in out`) ParameterDirection direction; @@ -3283,7 +3647,6 @@ struct DeclLoweringVisitor : DeclVisitor struct ParameterLists { List params; - List genericParams; }; // // Because there might be a `static` declaration somewhere @@ -3381,7 +3744,7 @@ struct DeclLoweringVisitor : DeclVisitor // we need to specialize it for any generic parameters // that are in scope here. auto declRef = createDefaultSpecializedDeclRef(typeDecl); - auto type = DeclRefType::Create(context->getSession(), declRef); + RefPtr type = DeclRefType::Create(context->getSession(), declRef); addThisParameter( type, ioParameterLists); @@ -3441,51 +3804,6 @@ struct DeclLoweringVisitor : DeclVisitor } } } - else if( auto genericDecl = dynamic_cast(decl) ) - { - for( auto memberDecl : genericDecl->Members ) - { - if( auto genericTypeParamDecl = memberDecl.As() ) - { - ioParameterLists->genericParams.Add(genericTypeParamDecl); - } - else if( auto genericValueParamDecl = memberDecl.As() ) - { - ioParameterLists->genericParams.Add(genericValueParamDecl); - } - else if( auto genericConstraintDel = memberDecl.As() ) - { - // When lowering to the IR we need to reify the constraints on - // a generic parameter as concrete parameters of their own. - // These parameter will usually be satisfied by passing a "witness" - // as the argument to correspond to the parameter. - // - // TODO: it is possible that all witness parameters should come - // after the other generic parameters, and thus should be collected - // in a third list. - // - ioParameterLists->genericParams.Add(genericConstraintDel); - } - } - } - - } - - void trySetMangledName( - IRFunc* irFunc, - Decl* decl) - { - // We want to generate a mangled name for the given declaration and attach - // it to the instruction. - // - // TODO: we probably want to start be doing an early-exit in cases - // where it doesn't make sense to attach a mangled name (e.g., because - // the declaration in question shouldn't have linkage). - // - - String mangledName = getMangledName(decl); - - irFunc->mangledName = context->getSession()->getNameObj(mangledName); } ModuleDecl* findModuleDecl(Decl* decl) @@ -3545,18 +3863,148 @@ struct DeclLoweringVisitor : DeclVisitor return false; } - RefPtr maybeGetConstExprType(Type* type, Decl* decl) + IRType* maybeGetConstExprType(IRType* type, Decl* decl) { if(isConstExprVar(decl)) { - return context->getSession()->getConstExprType(type); + return getBuilder()->getRateQualifiedType( + getBuilder()->getConstExprRate(), + type); } return type; } + IRGeneric* emitOuterGeneric( + IRBuilder* subBuilder, + GenericDecl* genericDecl, + Decl* leafDecl) + { + // Of course, a generic might itself be nested inside of other generics... + auto nextOuterGeneric = emitOuterGenerics(subBuilder, genericDecl, leafDecl); + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->irBuilder = subBuilder; + + + // We need to create an IR generic + + auto irGeneric = subBuilder->emitGeneric(); + subBuilder->setInsertInto(irGeneric); + + if (!nextOuterGeneric) + { + // If this is the outer-most generic, then it will be the + // global symbol that gets the mangled name from the inner + // declaration actually being lowered. + irGeneric->mangledName = context->getSession()->getNameObj(getMangledName(leafDecl)); + } + + auto irBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(irBlock); + + // Now emit any parameters of the generic + // + // First we start with type and value parameters, + // in the order they were declared. + for (auto member : genericDecl->Members) + { + if (auto typeParamDecl = member.As()) + { + // TODO: use a `TypeKind` to represent the + // classifier of the parameter. + auto param = subBuilder->emitParam(nullptr); + setValue(subContext, typeParamDecl, LoweredValInfo::simple(param)); + } + else if (auto valDecl = member.As()) + { + auto paramType = lowerType(subContext, valDecl->getType()); + auto param = subBuilder->emitParam(paramType); + setValue(subContext, valDecl, LoweredValInfo::simple(param)); + } + } + // Then we emit constraint parameters, again in + // declaration order. + for (auto member : genericDecl->Members) + { + if (auto constraintDecl = member.As()) + { + // TODO: use a `WitnessTableKind` to represent the + // classifier of the parameter. + auto param = subBuilder->emitParam(nullptr); + setValue(subContext, constraintDecl, LoweredValInfo::simple(param)); + } + } + + return irGeneric; + } + + // If the given `decl` is enclosed in any generic declarations, then + // emit IR-level generics to represent them. + // The `leafDecl` represents the inner-most declaration we are actually + // trying to emit, which is the one that should receive the mangled name. + // + IRGeneric* emitOuterGenerics(IRBuilder* subBuilder, Decl* decl, Decl* leafDecl) + { + for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) + { + if(auto genericAncestor = dynamic_cast(pp)) + { + return emitOuterGeneric(subBuilder, genericAncestor, leafDecl); + } + } + + return nullptr; + } + + // If any generic declarations have been created by `emitOuterGenerics`, + // then finish them off by emitting `return` instructions for the + // values that they should produce. + // + // Return the outer-most generic (if there is one), or the original + // value (if there were no generics), which should be the IR-level + // representation of the original declaration. + // + IRInst* finishOuterGenerics( + IRBuilder* subBuilder, + IRInst* val) + { + IRInst* v = val; + for(;;) + { + auto parentBlock = as(v->getParent()); + if (!parentBlock) break; + + auto parentGeneric = as(parentBlock->getParent()); + if (!parentGeneric) break; + + subBuilder->setInsertInto(parentBlock); + subBuilder->emitReturn(v); + parentGeneric->moveToEnd(); + + // There might be more outer generics, + // so we need to loop until we run out. + v = parentGeneric; + } + return v; + } + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) { + // We are going to use a nested builder, because we will + // change the parent node that things get nested into. + + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + + + // The actual `IRFunction` that we emit needs to be nested + // inside of one `IRGeneric` for every outer `GenericDecl` + // in the declaration hierarchy. + + emitOuterGenerics(subBuilder, decl, decl); + // Collect the parameter lists we will use for our new function. ParameterLists parameterLists; collectParameterLists(decl, ¶meterLists, kParameterListCollectMode_Default); @@ -3584,9 +4032,6 @@ struct DeclLoweringVisitor : DeclVisitor } - IRBuilder subBuilderStorage = *getBuilder(); - IRBuilder* subBuilder = &subBuilderStorage; - IRGenContext subContextStorage = *context; IRGenContext* subContext = &subContextStorage; subContext->irBuilder = subBuilder; @@ -3594,27 +4039,14 @@ struct DeclLoweringVisitor : DeclVisitor // need to create an IR function here IRFunc* irFunc = subBuilder->createFunc(); - subBuilder->setInsertInto(irFunc); - - trySetMangledName(irFunc, decl); - List> paramTypes; + setMangledName(irFunc, getMangledName(decl)); - // We first need to walk the generic parameters (if any) - // because these will influence the declared type of - // the function. + List paramTypes; - for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) - { - if(auto genericAncestor = dynamic_cast(pp)) - { - irFunc->genericDecls.Add(genericAncestor); - } - } - irFunc->specializedGenericLevel = (int)irFunc->genericDecls.Count() - 1; for( auto paramInfo : parameterLists.params ) { - RefPtr irParamType = lowerSimpleType(context, paramInfo.type); + IRType* irParamType = lowerType(subContext, paramInfo.type); switch( paramInfo.direction ) { @@ -3627,10 +4059,10 @@ struct DeclLoweringVisitor : DeclVisitor // the IR, but we will use a specialized pointer // type that encodes the parameter direction information. case kParameterDirection_Out: - irParamType = context->getSession()->getOutType(irParamType); + irParamType = subBuilder->getOutType(irParamType); break; case kParameterDirection_InOut: - irParamType = context->getSession()->getInOutType(irParamType); + irParamType = subBuilder->getInOutType(irParamType); break; default: @@ -3649,7 +4081,7 @@ struct DeclLoweringVisitor : DeclVisitor paramTypes.Add(irParamType); } - auto irResultType = lowerSimpleType(context, declForReturnType->ReturnType); + auto irResultType = lowerType(subContext, declForReturnType->ReturnType); if (auto setterDecl = dynamic_cast(decl)) { @@ -3663,22 +4095,23 @@ struct DeclLoweringVisitor : DeclVisitor // Instead, a setter always returns `void` // - irResultType = context->getSession()->getVoidType(); + irResultType = subBuilder->getVoidType(); } if( auto refAccessorDecl = dynamic_cast(decl) ) { // A `ref` accessor needs to return a *pointer* to the value // being accessed, rather than a simple value. - irResultType = context->getSession()->getPtrType(irResultType); + irResultType = subBuilder->getPtrType(irResultType); } - auto irFuncType = getFuncType( - context, + auto irFuncType = subBuilder->getFuncType( paramTypes.Count(), paramTypes.Buffer(), irResultType); - irFunc->type = irFuncType; + irFunc->setFullType(irFuncType); + + subBuilder->setInsertInto(irFunc); if (isImportedDecl(decl)) { @@ -3788,8 +4221,7 @@ struct DeclLoweringVisitor : DeclVisitor if( auto paramDecl = paramInfo.decl ) { - DeclRef paramDeclRef = makeDeclRef(paramDecl); - subContext->shared->declValues[paramDeclRef] = paramVal; + setValue(subContext, paramDecl, paramVal); } if (paramInfo.isThisParam) @@ -3816,7 +4248,7 @@ struct DeclLoweringVisitor : DeclVisitor // of the body, in case the user didn't do so. if (!subContext->irBuilder->getBlock()->getTerminator()) { - if (irResultType->Equals(context->getSession()->getVoidType())) + if(as(irResultType)) { // `void`-returning function can get an implicit // return on exit of the body statement. @@ -3872,7 +4304,7 @@ struct DeclLoweringVisitor : DeclVisitor // of global values. irFunc->moveToEnd(); - return LoweredValInfo::simple(irFunc); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irFunc)); } LoweredValInfo visitGenericDecl(GenericDecl * genDecl) @@ -3937,8 +4369,15 @@ LoweredValInfo lowerDecl( { IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, decl->loc); + IRGenEnv subEnv; + subEnv.outer = context->env; + + IRGenContext subContext = *context; + subContext.env = &subEnv; + + DeclLoweringVisitor visitor; - visitor.context = context; + visitor.context = &subContext; return visitor.dispatch(decl); } @@ -3950,11 +4389,21 @@ LoweredValInfo ensureDecl( auto shared = context->shared; LoweredValInfo result; - if(shared->declValues.TryGetValue(decl, result)) - return result; + + // Look for an existing value installed in this context + auto env = context->env; + while(env) + { + if(env->mapDeclToValue.TryGetValue(decl, result)) + return result; + + env = env->outer; + } + IRBuilder subIRBuilder; subIRBuilder.sharedBuilder = context->irBuilder->sharedBuilder; + subIRBuilder.setInsertInto(subIRBuilder.sharedBuilder->module->getModuleInst()); IRGenContext subContext = *context; @@ -3962,225 +4411,189 @@ LoweredValInfo ensureDecl( result = lowerDecl(&subContext, decl); - shared->declValues[decl] = result; + // By default assume that any value we are lowering represents + // something that should be installed globally. + setGlobalValue(shared, decl, result); return result; } -IRInst* findWitnessTable( +IRInst* lowerSubstitutionArg( IRGenContext* context, - DeclRef declRef) + Val* val) { - IRInst* irVal = getSimpleVal(context, emitDeclRef(context, declRef)); - if (!irVal) + if (auto type = dynamic_cast(val)) { - SLANG_UNEXPECTED("expected a witness table"); - return nullptr; + return lowerType(context, type); } - - if (irVal->op == kIROp_specialize) + else if (auto declaredSubtypeWitness = dynamic_cast(val)) { - return irVal; + // We need to look up the IR-level representation of the witness (which will be a witness table). + auto irWitnessTable = getSimpleVal( + context, + emitDeclRef( + context, + declaredSubtypeWitness->declRef, + context->irBuilder->getWitnessTableType())); + return irWitnessTable; } - - if (irVal->op != kIROp_witness_table) + else { - SLANG_UNEXPECTED("expected a witness table"); - return nullptr; + SLANG_UNIMPLEMENTED_X("value cases"); } - - return (IRWitnessTable*)irVal; } -RefPtr lowerSubstitutionArg( - IRGenContext* context, - Val* val) +// Can the IR lowered version of this declaration ever be an `IRGeneric`? +bool canDeclLowerToAGeneric(RefPtr decl) { - if (auto type = dynamic_cast(val)) - { - return lowerSimpleType(context, type); - } - else if (auto declaredSubtypeWitness = dynamic_cast(val)) - { - // We do not have a concrete witness table yet for a GenericTypeConstraintDecl witness + // A callable decl lowers to an `IRFunc`, and can be generic + if(decl.As()) return true; - if (declaredSubtypeWitness->declRef.As()) - return val; + // An aggregate type decl lowers to an `IRStruct`, and can be generic + if(decl.As()) return true; - // We need to look up the IR-level representation of the witness - // (which is a witness table). + // An inheritance decl lowers to an `IRWitnessTable`, and can be generic + if(decl.As()) return true; - auto irWitnessTable = findWitnessTable(context, declaredSubtypeWitness->declRef); + // A `typedef` declaration nested under a generic will turn into + // a generic that returns a type (a simple type-level function). + if(decl.As()) return true; - // We have an IR-level value, but we need to embed it into an AST-level - // type, so we will use a proxy `Val` that wraps up an `IRInst` as - // an AST-level value. - // - // TODO: This proxy value currently doesn't enter into use-def chaining, - // and so Bad Things could happen quite easily. We need to fix that - // up in a reasonably clean fashion. - // - RefPtr proxyVal = new IRProxyVal(); - proxyVal->inst.init(nullptr, irWitnessTable); - return proxyVal; - } - else - { - // For now, jsut assume that all other values - // lower to themselves. - // - // TODO: we should probably handle the case of - // a `Val` that references an AST-level `constexpr` - // variable, since that would need to be lowered - // to a `Val` that references the IR equivalent. - return val; - } + return false; } -// Given a set of substitutions, make sure that we have -// lowered the arguments being used into a form that -// is suitable for use in the IR. -RefPtr lowerGenericSubstitutions( - IRGenContext* context, - GenericSubstitution* genSubst) +LoweredValInfo emitDeclRef( + IRGenContext* context, + RefPtr decl, + RefPtr subst, + IRType* type) { - if(!genSubst) - return nullptr; - RefPtr result; - RefPtr newSubst = new GenericSubstitution(); - newSubst->genericDecl = genSubst->genericDecl; + // We need to proceed by considering the specializations that + // have been put in place. - for (auto arg : genSubst->args) - { - auto newArg = lowerSubstitutionArg(context, arg); - newSubst->args.Add(newArg); - } + // Ignore any global generic type substitutions during lowering. + // Really, we don't even expect these to appear. + while(auto globalGenericSubst = subst.As()) + subst = globalGenericSubst->outer; - result = newSubst; - if (genSubst->outer) + // If the declaration would not get wrapped in a `IRGeneric`, + // even if it is nested inside of an AST `GenericDecl`, then + // we should also ignore any generic substiuttions. + if(!canDeclLowerToAGeneric(decl)) { - result->outer = lowerGenericSubstitutions( - context, - genSubst->outer); + while(auto genericSubst = subst.As()) + subst = genericSubst->outer; } - return result; -} -RefPtr lowerGlobalGenericSubstitutions( - IRGenContext* context, - GlobalGenericParamSubstitution* genSubst) -{ - if (!genSubst) - return nullptr; - RefPtr result; - RefPtr newSubst = new GlobalGenericParamSubstitution(); - newSubst->actualType = lowerSubstitutionArg(context, genSubst->actualType); - newSubst->paramDecl = genSubst->paramDecl; - for (auto & tbl : genSubst->witnessTables) + // In the simplest case, there is no specialization going + // on, and the decl-ref turns into a reference to the + // lowered IR value for the declaration. + if(!subst) { - auto ntbl = tbl; - ntbl.Value = lowerSubstitutionArg(context, tbl.Value); - newSubst->witnessTables.Add(ntbl); + LoweredValInfo loweredDecl = ensureDecl(context, decl); + return loweredDecl; } - result = newSubst; - if (genSubst->outer) + + // Otherwise, we look at the kind of substitution, and let it guide us. + if(auto genericSubst = subst.As()) { - result->outer = lowerGlobalGenericSubstitutions( + // A generic substitution means we will need to output + // a `specialize` instruction to specialize the generic. + // + // First we want to emit the value without generic specialization + // applied, to get a correct value for it. + // + // Note: we only "unwrap" a single layer from the + // substitutions here, because the underlying declaration + // might be nested in multiple generics, or it might + // come from an interface. + // + LoweredValInfo genericVal = emitDeclRef( context, - genSubst->outer); - } - return result; -} - -RefPtr lowerThisTypeSubstitution( - IRGenContext* context, - ThisTypeSubstitution* thisSubst) -{ - if (!thisSubst) - return nullptr; - RefPtr newSubst = new ThisTypeSubstitution(); - newSubst->sourceType = lowerSubstitutionArg(context, thisSubst->sourceType); - return newSubst; -} - -SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst) -{ - SubstitutionSet rs; - rs.genericSubstitutions = lowerGenericSubstitutions(context, subst.genericSubstitutions); - rs.thisTypeSubstitution = lowerThisTypeSubstitution(context, subst.thisTypeSubstitution); - rs.globalGenParamSubstitutions = lowerGlobalGenericSubstitutions(context, subst.globalGenParamSubstitutions); - return rs; -} + decl, + genericSubst->outer, + context->irBuilder->getGenericKind()); -LoweredValInfo emitDeclRef( - IRGenContext* context, - DeclRef declRef) -{ - // First we need to construct an IR value representing the - // unspecialized declaration. - LoweredValInfo loweredDecl = ensureDecl(context, declRef.getDecl()); - - return maybeEmitSpecializeInst(context, loweredDecl, declRef); -} - -LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context, - LoweredValInfo loweredDecl, - DeclRef declRef) -{ - // If this declaration reference doesn't involve any specializations, - // then we are done at this point. - if (!declRef.substitutions.genericSubstitutions) - return loweredDecl; - - // There's no reason to specialize something that maps to a NULL pointer. - if (loweredDecl.flavor == LoweredValInfo::Flavor::None) - return loweredDecl; + // There's no reason to specialize something that maps to a NULL pointer. + if (genericVal.flavor == LoweredValInfo::Flavor::None) + return LoweredValInfo(); - if (!declRef.As() && !declRef.As()) - return loweredDecl; + // We can only really specialize things that map to single values. + // It would be an error if we got a non-`None` value that + // wasn't somehow a single value. + auto irGenericVal = getSimpleVal(context, genericVal); - auto val = getSimpleVal(context, loweredDecl); + // We have the IR value for the generic we'd like to specialize, + // and now we need to get the value for the arguments. + List irArgs; + for (auto argVal : genericSubst->args) + { + auto irArgVal = lowerSimpleVal(context, argVal); + SLANG_ASSERT(irArgVal); + irArgs.Add(irArgVal); + } + // Once we have both the generic and its arguments, + // we can emit a `specialize` instruction and use + // its value as the result. + auto irSpecializedVal = context->irBuilder->emitSpecializeInst( + type, + irGenericVal, + irArgs.Count(), + irArgs.Buffer()); - RefPtr outterMostSubst, secondOutterMostSubst; - for (auto subst = declRef.substitutions.genericSubstitutions; subst; subst = subst->outer) - { - if (!subst->outer) - outterMostSubst = subst; - else - secondOutterMostSubst = subst; + return LoweredValInfo::simple(irSpecializedVal); } - auto newSubst = outterMostSubst; - // We have the "raw" substitutions from the AST, but we may - // need to walk through those and replace things in - // cases where the `Val`s used for substitution should - // lower to something other than their original form. - SubstitutionSet oldSubst = declRef.substitutions; - oldSubst.genericSubstitutions = newSubst; - auto lowedNewSubst = lowerSubstitutions(context, oldSubst); - DeclRef newDeclRef = DeclRef(declRef.decl, lowedNewSubst); - - RefPtr type; - if (auto declType = val->getDataType()) + else if(auto thisTypeSubst = subst.As()) { - type = declType->Substitute(newDeclRef.substitutions).As(); + // Somebody is trying to look up an interface requirement + // "through" some concrete type. We need to lower this decl-ref + // as a lookup of the corresponding member in a witness table. + // + // The witness table itself is referenced by the this-type + // substitution, so we can just lower that. + // + // Note: unlike the case for generics above, in the interface-lookup + // case, we don't end up caring about any further outer substitutions. + // That is because even if we are naming `ISomething.doIt()`, + // a method insided a generic interface, we don't actually care + // about the substitution of `Foo` for the parameter `T` of + // `ISomething`. That is because we really care about the + // witness table for the concrete type that conforms to `ISomething`. + // + auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness); + // + // The key to use for looking up the interface member is + // derived from the declaration. + // + auto irRequirementKey = getInterfaceRequirementKey(context, decl); + // + // Those two pieces of information tell us what we need to + // do in order to look up the value that satisfied the requirement. + // + auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst( + type, + irWitnessTable, + irRequirementKey); + return LoweredValInfo::simple(irSatisfyingVal); } - - // Otherwise, we need to construct a specialization of the - // given declaration. - auto specializedVal = LoweredValInfo::simple((IRInst*)context->irBuilder->emitSpecializeInst( - type, - val, - newDeclRef)); - if (secondOutterMostSubst) + else { - newDeclRef.substitutions.genericSubstitutions = new GenericSubstitution(*secondOutterMostSubst); - newDeclRef.substitutions.genericSubstitutions->outer = nullptr; - return maybeEmitSpecializeInst(context, specializedVal, newDeclRef); + SLANG_UNEXPECTED("uhandled substitution type"); } - return specializedVal; } +LoweredValInfo emitDeclRef( + IRGenContext* context, + DeclRef declRef, + IRType* type) +{ + return emitDeclRef( + context, + declRef.decl, + declRef.substitutions.substitutions, + type); +} static void lowerEntryPointToIR( IRGenContext* context, @@ -4195,10 +4608,34 @@ static void lowerEntryPointToIR( // the entry point request. return; } - // we need to lower all global type arguments as well auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); - for (auto arg : entryPointRequest->genericParameterTypes) - lowerType(context, arg); + + // Now lower all the arguments supplied for global generic + // type parameters. + // + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); + for (RefPtr subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer) + { + auto gSubst = subst.As(); + if(!gSubst) + continue; + + IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); + IRType* typeVal = lowerType(context, gSubst->actualType); + + // bind `typeParam` to `typeVal` + builder->emitBindGlobalGenericParam(typeParam, typeVal); + + for (auto& constraintArg : gSubst->constraintArgs) + { + IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); + IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + + // bind `constraintParam` to `constraintVal` + builder->emitBindGlobalGenericParam(constraintParam, constraintVal); + } + } } IRModule* generateIRForTranslationUnit( @@ -4212,11 +4649,9 @@ IRModule* generateIRForTranslationUnit( sharedContext->compileRequest = compileRequest; sharedContext->mainModuleDecl = translationUnit->SyntaxNode; - IRGenContext contextStorage; + IRGenContext contextStorage(sharedContext); IRGenContext* context = &contextStorage; - context->shared = sharedContext; - SharedIRBuilder sharedBuilderStorage; SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->module = nullptr; @@ -4251,6 +4686,12 @@ IRModule* generateIRForTranslationUnit( ensureDecl(context, decl); } +#if 0 + fprintf(stderr, "### GENERATED\n"); + dumpIR(module); + fprintf(stderr, "###\n"); +#endif + validateIRModuleIfEnabled(compileRequest, module); // We will perform certain "mandatory" optimization passes now. diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index f2bf279a2..7a50903a0 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -1,6 +1,7 @@ #include "mangle.h" #include "name.h" +#include "ir-insts.h" #include "syntax.h" namespace Slang @@ -159,12 +160,6 @@ namespace Slang // to mangle in the constraints even when // the whole thing is specialized... } - else if (auto proxyVal = dynamic_cast(val)) - { - // This is a proxy standing in for some IR-level - // value, so we certainly don't want to include - // it in the mangling. - } else if( auto genericParamIntVal = dynamic_cast(val) ) { // TODO: we shouldn't be including the names of generic parameters @@ -190,16 +185,89 @@ namespace Slang } } - // TODO: this needs to be centralized - RefPtr getOutermostGenericSubst( - RefPtr inSubst) + void emitIRVal( + ManglingContext* context, + IRInst* inst); + + void emitIRSimpleIntVal( + ManglingContext* context, + IRInst* inst) + { + if (auto intLit = as(inst)) + { + auto cVal = intLit->getValue(); + if(cVal >= 0 && cVal <= 9 ) + { + emit(context, (UInt)cVal); + return; + } + } + + // Fallback: + emitIRVal(context, inst); + } + + void emitIRVal( + ManglingContext* context, + IRInst* inst) { - for (auto subst = inSubst; subst; subst = subst->outer) + switch (inst->op) + { + case kIROp_VoidType: emitRaw(context, "V"); return; + case kIROp_BoolType: emitRaw(context, "b"); return; + case kIROp_IntType: emitRaw(context, "i"); return; + case kIROp_UIntType: emitRaw(context, "u"); return; + case kIROp_UInt64Type: emitRaw(context, "U"); return; + case kIROp_HalfType: emitRaw(context, "h"); return; + case kIROp_FloatType: emitRaw(context, "f"); return; + case kIROp_DoubleType: emitRaw(context, "d"); return; + + default: + break; + } + + if (auto globalVal = as(inst)) + { + // If it is a global value, it has its own mangled name. + emit(context, getText(globalVal->mangledName)); + } + // TODO: need to handle various type cases here + else if (auto intLit = as(inst)) + { + // TODO: need to figure out what prefix/suffix is needed + // to allow demangling later. + emitRaw(context, "k"); + emit(context, (UInt) intLit->getValue()); + } + // Note: the cases here handling types really should match + // the cases above that handle AST-level `Type`s. This + // seems to be a weakness in the way we mangle names, because + // we may mangle in both IR-level and AST-level types. + else if (auto vecType = as(inst)) + { + emitRaw(context, "v"); + emitIRSimpleIntVal(context, vecType->getElementCount()); + emitIRVal(context, vecType->getElementType()); + + } + else if( auto matType = as(inst) ) + { + emitRaw(context, "m"); + emitIRSimpleIntVal(context, matType->getRowCount()); + emitRaw(context, "x"); + emitIRSimpleIntVal(context, matType->getColumnCount()); + emitIRVal(context, matType->getElementType()); + } + else if (auto arrType = as(inst)) + { + emitRaw(context, "a"); + emitIRSimpleIntVal(context, arrType->getElementCount()); + emitIRVal(context, arrType->getElementCount()); + } + else { - if (auto genericSubst = subst.As()) - return genericSubst; + SLANG_UNEXPECTED("unimplemented case in mangling"); } - return nullptr; } void emitQualifiedName( @@ -231,6 +299,29 @@ namespace Slang return; } + // Inheritance declarations don't have meaningful names, + // and so we should emit them based on the type + // that is doing the inheriting. + if(auto inheritanceDeclRef = declRef.As()) + { + emit(context, "I"); + emitType(context, GetSup(inheritanceDeclRef)); + return; + } + + // Similarly, an extension doesn't have a name worth + // emitting, and we should base things on its target + // type instead. + if(auto extensionDeclRef = declRef.As()) + { + // TODO: as a special case, an "unconditional" extension + // that is in the same module as the type it extends should + // be treated as equivalent to the type itself. + emit(context, "X"); + emitType(context, GetTargetType(extensionDeclRef)); + return; + } + emitName(context, declRef.GetName()); // Are we the "inner" declaration beneath a generic decl? @@ -239,7 +330,7 @@ namespace Slang // There are two cases here: either we have specializations // in place for the parent generic declaration, or we don't. - auto subst = getOutermostGenericSubst(declRef.substitutions.genericSubstitutions); + auto subst = findInnerMostGenericSubstitution(declRef.substitutions); if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() ) { // This is the case where we *do* have substitutions. @@ -373,13 +464,6 @@ namespace Slang String getMangledName(DeclRef const& declRef) { - // Special case: if a declaration is the result of a type legalization - // transformation, then it should just get the mangled name of the - // original declaration, and not the one that would be computed - // for it otherwise. - if(auto legalizedModifier = declRef.getDecl()->FindModifier()) - return legalizedModifier->originalMangledName; - ManglingContext context; mangleName(&context, declRef); return context.sb.ProduceString(); @@ -391,16 +475,18 @@ namespace Slang DeclRef(declRef.decl, declRef.substitutions)); } - String mangleSpecializedFuncName(String baseName, SubstitutionSet subst) + String mangleSpecializedFuncName(String baseName, IRSpecialize* specializeInst) { ManglingContext context; emitRaw(&context, baseName.Buffer()); emitRaw(&context, "_G"); - if (auto genSubst = subst.genericSubstitutions) + + UInt argCount = specializeInst->getArgCount(); + for (UInt aa = 0; aa < argCount; ++aa) { - for (auto a : genSubst->args) - emitVal(&context, a); + emitIRVal(&context, specializeInst->getArg(aa)); } + return context.sb.ProduceString(); } diff --git a/source/slang/mangle.h b/source/slang/mangle.h index 8f4c6d1d0..b6f7587ad 100644 --- a/source/slang/mangle.h +++ b/source/slang/mangle.h @@ -8,11 +8,14 @@ namespace Slang { + struct IRSpecialize; + String getMangledName(Decl* decl); String getMangledName(DeclRef const & declRef); String getMangledName(DeclRefBase const & declRef); - String mangleSpecializedFuncName(String baseName, SubstitutionSet subst); + String mangleSpecializedFuncName(String baseName, IRSpecialize* specializeInst); + String getMangledNameForConformanceWitness( Type* sub, Type* sup); diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index baa08a160..6212a244e 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -398,10 +398,3 @@ SYNTAX_CLASS(ImplicitConversionModifier, Modifier) // The conversion cost, used to rank conversions FIELD(ConversionCost, cost) END_SYNTAX_CLASS() - -// A marker modifier used to indicate that a declaration was created as -// part of type legalization. -SYNTAX_CLASS(LegalizedModifier, Modifier) - FIELD(String, originalMangledName) -END_SYNTAX_CLASS() - diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 572235280..4378cb06b 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -548,23 +548,72 @@ static bool validateGenericSubstitutionsMatch( return true; } +static bool validateThisTypeSubstitutionsMatch( + ParameterBindingContext* /*context*/, + ThisTypeSubstitution* /*left*/, + ThisTypeSubstitution* /*right*/, + StructuralTypeMatchStack* /*stack*/) +{ + // TODO: actual checking. + return true; +} + static bool validateSpecializationsMatch( ParameterBindingContext* context, SubstitutionSet left, SubstitutionSet right, StructuralTypeMatchStack* stack) { - if(!validateGenericSubstitutionsMatch( - context, - left.genericSubstitutions, - right.genericSubstitutions, - stack)) + auto ll = left.substitutions; + auto rr = right.substitutions; + for(;;) { + // Skip any global generic substitutions. + if(auto leftGlobalGeneric = ll.As()) + { + ll = leftGlobalGeneric->outer; + continue; + } + if(auto rightGlobalGeneric = rr.As()) + { + rr = rightGlobalGeneric->outer; + continue; + } + + // If either ran out, then we expect both to have run out. + if(!ll || !rr) + return !ll && !rr; + + auto leftSubst = ll; + auto rightSubst = rr; + + ll = ll->outer; + rr = rr->outer; + + if(auto leftGeneric = leftSubst.As()) + { + if(auto rightGeneric = rightSubst.As()) + { + if(validateGenericSubstitutionsMatch(context, leftGeneric, rightGeneric, stack)) + { + continue; + } + } + } + else if(auto leftThisType = leftSubst.As()) + { + if(auto rightThisType = rightSubst.As()) + { + if(validateThisTypeSubstitutionsMatch(context, leftThisType, rightThisType, stack)) + { + continue; + } + } + } + return false; } - // TODO: anything else to match? - return true; } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index ea8e567b4..5d1b254d7 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -992,7 +992,7 @@ namespace Slang else { // default case is a type parameter - auto paramDecl = new GenericTypeParamDecl(); + RefPtr paramDecl = new GenericTypeParamDecl(); parser->FillPosition(paramDecl); paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); if (AdvanceIf(parser, TokenType::Colon)) diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp index 69ae36a3f..bd2ce2561 100644 --- a/source/slang/slang-stdlib.cpp +++ b/source/slang/slang-stdlib.cpp @@ -269,22 +269,4 @@ namespace Slang hlslLibraryCode = sb.ProduceString(); return hlslLibraryCode; } - - - // GLSL-specific library code - - String Session::getGLSLLibraryCode() - { - if(glslLibraryCode.Length() != 0) - return glslLibraryCode; - - String path = getStdlibPath(); - - StringBuilder sb; - - #include "glsl.meta.slang.h" - - glslLibraryCode = sb.ProduceString(); - return glslLibraryCode; - } } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 2861b82ca..5df180f46 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -66,12 +66,8 @@ Session::Session() slangLanguageScope = new Scope(); slangLanguageScope->nextSibling = hlslLanguageScope; - glslLanguageScope = new Scope(); - glslLanguageScope->nextSibling = coreLanguageScope; - addBuiltinSource(coreLanguageScope, "core", getCoreLibraryCode()); addBuiltinSource(hlslLanguageScope, "hlsl", getHLSLLibraryCode()); - addBuiltinSource(glslLanguageScope, "glsl", getGLSLLibraryCode()); } struct IncludeHandlerImpl : IncludeHandler @@ -255,10 +251,6 @@ void CompileRequest::parseTranslationUnit( languageScope = mSession->hlslLanguageScope; break; - case SourceLanguage::GLSL: - languageScope = mSession->glslLanguageScope; - break; - case SourceLanguage::Slang: default: languageScope = mSession->slangLanguageScope; diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index 489005620..bb3d3a16c 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -81,14 +81,14 @@ {{{op}}} op - type + typeUse.usedValue {{count = {operandCount}}} operandCount operandCount - (IRUse*)(this + 1) + (IRUse*)(&(typeUse) + 1) @@ -108,7 +108,7 @@ {{{op}}} op - type + typeUse.usedValue diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index f7e4ed5b2..09990889d 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -275,25 +275,6 @@ %(Identity).cpp $(OutDir)slang-generate.exe - - Document - $(OutDir)slang-generate.exe %(Identity) - $(OutDir)slang-generate.exe %(Identity) - $(OutDir)slang-generate.exe %(Identity) - $(OutDir)slang-generate.exe %(Identity) - slang-generate %(Identity) - slang-generate %(Identity) - slang-generate %(Identity) - slang-generate %(Identity) - %(Identity).cpp - %(Identity).cpp - %(Identity).cpp - %(Identity).cpp - $(OutDir)slang-generate.exe - $(OutDir)slang-generate.exe - $(OutDir)slang-generate.exe - $(OutDir)slang-generate.exe - Document $(OutDir)slang-generate.exe %(Identity) diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 55140a4da..82fc6ac87 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -88,7 +88,6 @@ - \ No newline at end of file diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 4fded014e..acc795d8b 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -81,8 +81,6 @@ public: Session* getSession() { return this->session; } void setSession(Session* s) { this->session = s; } - virtual String ToString() = 0; - bool Equals(Type * type); bool Equals(RefPtr type); @@ -131,10 +129,12 @@ END_SYNTAX_CLASS() // A substitution represents a binding of certain // type-level variables to concrete argument values ABSTRACT_SYNTAX_CLASS(Substitutions, RefObject) + // The next outer that this one refines. + FIELD(RefPtr, outer) RAW( // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) = 0; + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) = 0; // Check if these are equivalent substitutiosn to another set virtual bool Equals(Substitutions* subst) = 0; @@ -151,12 +151,9 @@ SYNTAX_CLASS(GenericSubstitution, Substitutions) // The actual values of the arguments SYNTAX_FIELD(List>, args) - // Any further substitutions, relating to outer generic declarations - SYNTAX_FIELD(RefPtr, outer) - RAW( // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; // Check if these are equivalent substitutiosn to another set virtual bool Equals(Substitutions* subst) override; @@ -178,11 +175,17 @@ SYNTAX_CLASS(GenericSubstitution, Substitutions) END_SYNTAX_CLASS() SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) + // The declaration of the interface that we are specializing + FIELD_INIT(InterfaceDecl*, interfaceDecl, nullptr) + + // A witness that shows that the concrete type used to + // specialize the interface conforms to the interface. + FIELD(RefPtr, witness) + // The actual type that provides the lookup scope for an associated type - SYNTAX_FIELD(RefPtr, sourceType) RAW( // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; // Check if these are equivalent substitutiosn to another set virtual bool Equals(Substitutions* subst) override; @@ -190,25 +193,31 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) { return Equals(const_cast(&subst)); } - virtual int GetHashCode() const override - { - if (sourceType) - return sourceType->GetHashCode(); - return 0; - } + virtual int GetHashCode() const override; ) END_SYNTAX_CLASS() SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions) // the __generic_param decl to be substituted DECL_FIELD(GlobalGenericParamDecl*, paramDecl) + // the actual type to substitute in - SYNTAX_FIELD(RefPtr, actualType) - // Any further global type parameter substitutions - SYNTAX_FIELD(RefPtr, outer) + SYNTAX_FIELD(RefPtr, actualType) + + RAW( + struct ConstraintArg + { + RefPtr decl; + RefPtr val; + }; + ) + + // the values that satisfy any constraints on the type parameter + SYNTAX_FIELD(List, constraintArgs) + RAW( // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; // Check if these are equivalent substitutiosn to another set virtual bool Equals(Substitutions* subst) override; @@ -219,17 +228,13 @@ RAW( virtual int GetHashCode() const override { int rs = actualType->GetHashCode(); - for (auto && v : witnessTables) + for (auto && a : constraintArgs) { - rs = combineHash(rs, v.Key->GetHashCode()); - rs = combineHash(rs, v.Value->GetHashCode()); + rs = combineHash(rs, a.val->GetHashCode()); } return rs; } - typedef List, RefPtr>> WitnessTableLookupTable; ) - // The witness tables for each interface this actual type implements - SYNTAX_FIELD(WitnessTableLookupTable, witnessTables) END_SYNTAX_CLASS() ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase) diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 70e230f33..9d29e7d21 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -228,12 +228,6 @@ void Type::accept(IValVisitor* visitor, void* extra) overloadedType = new OverloadGroupType(); overloadedType->setSession(this); - - irBasicBlockType = new IRBasicBlockType(); - irBasicBlockType->setSession(this); - - constExprRate = new ConstExprRate(); - constExprRate->setSession(this); } Type* Session::getBoolType() @@ -286,33 +280,12 @@ void Type::accept(IValVisitor* visitor, void* extra) return errorType; } - Type* Session::getIRBasicBlockType() - { - return irBasicBlockType; - } - - Type* Session::getConstExprRate() - { - return constExprRate; - } - Type* Session::getStringType() { auto stringTypeDecl = findMagicDecl(this, "StringType"); return DeclRefType::Create(this, makeDeclRef(stringTypeDecl)); } - RefPtr Session::getRateQualifiedType( - Type* rate, - Type* valueType) - { - RefPtr rateQualifiedType = new RateQualifiedType(); - rateQualifiedType->setSession(this); - rateQualifiedType->rate = rate; - rateQualifiedType->valueType = valueType; - return rateQualifiedType; - } - RefPtr Session::getPtrType( RefPtr valueType) { @@ -363,16 +336,6 @@ void Type::accept(IValVisitor* visitor, void* extra) return arrayType; } - - RefPtr Session::getGroupSharedType(RefPtr valueType) - { - RefPtr groupSharedType = new GroupSharedType(); - groupSharedType->setSession(this); - groupSharedType->valueType = valueType; - return groupSharedType; - } - - SyntaxClass Session::findSyntaxClass(Name* name) { SyntaxClass syntaxClass; @@ -432,142 +395,147 @@ void Type::accept(IValVisitor* visitor, void* extra) return baseType->ToString() + "[]"; } - // RateQualifiedType - - Slang::String RateQualifiedType::ToString() - { - return "@" + rate->ToString() + " " + valueType->ToString(); - } - - bool RateQualifiedType::EqualsImpl(Type * type) - { - auto rateQualifiedType = type->As(); - if(!rateQualifiedType) - return false; - - return rate->Equals(rateQualifiedType->rate) - && valueType->Equals(rateQualifiedType->valueType); - } - - RefPtr RateQualifiedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto substRate = rate->SubstituteImpl(subst, &diff).As(); - auto substValueType = valueType->SubstituteImpl(subst, &diff).As(); - if(!diff) - return this; - - (*ioDiff)++; - - return getSession()->getRateQualifiedType(substRate, substValueType); - } - - RefPtr RateQualifiedType::CreateCanonicalType() - { - RefPtr canRate = rate->GetCanonicalType(); - RefPtr canValueType = valueType->GetCanonicalType(); - - RefPtr canRateQualifiedType = new RateQualifiedType(); - canRateQualifiedType->setSession(session); - canRateQualifiedType->rate = canRate; - canRateQualifiedType->valueType = valueType; - return canRateQualifiedType; - } + // DeclRefType - int RateQualifiedType::GetHashCode() + String DeclRefType::ToString() { - auto hash = (int)(typeid(this).hash_code()); - hash = combineHash(hash, rate->GetHashCode()); - hash = combineHash(hash, valueType->GetHashCode()); - return hash; + return declRef.toString(); } - // ConstExprRate - - Slang::String ConstExprRate::ToString() + int DeclRefType::GetHashCode() { - return "ConstExpr"; + return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); } - bool ConstExprRate::EqualsImpl(Type * type) + bool DeclRefType::EqualsImpl(Type * type) { - auto constExprRate = type->As(); - if(!constExprRate) - return false; - - return true; + if (auto declRefType = type->AsDeclRefType()) + { + return declRef.Equals(declRefType->declRef); + } + return false; } - RefPtr ConstExprRate::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) + RefPtr DeclRefType::CreateCanonicalType() { + // A declaration reference is already canonical return this; } - RefPtr ConstExprRate::CreateCanonicalType() - { - return this; - } + // + // RequirementWitness + // - int ConstExprRate::GetHashCode() - { - auto hash = (int)(typeid(this).hash_code()); - return hash; - } + RequirementWitness::RequirementWitness(RefPtr val) + : m_flavor(Flavor::val) + , m_obj(val) + {} - // GroupSharedType - Slang::String GroupSharedType::ToString() - { - return "@ThreadGroup " + valueType->ToString(); - } + RequirementWitness::RequirementWitness(RefPtr witnessTable) + : m_flavor(Flavor::witnessTable) + , m_obj(witnessTable) + {} - bool GroupSharedType::EqualsImpl(Type * type) + RefPtr RequirementWitness::getWitnessTable() { - auto t = type->As(); - if (!t) - return false; - return valueType->Equals(t->valueType); + SLANG_ASSERT(getFlavor() == Flavor::witnessTable); + return m_obj.As(); } - RefPtr GroupSharedType::CreateCanonicalType() - { - auto canonicalValueType = valueType->GetCanonicalType(); - auto canonicalGroupSharedType = getSession()->getGroupSharedType(canonicalValueType); - return canonicalGroupSharedType; - } - int GroupSharedType::GetHashCode() + RequirementWitness RequirementWitness::specialize(SubstitutionSet const& subst) { - return combineHash( - valueType->GetHashCode(), - (int)(typeid(this).hash_code())); - } - - // DeclRefType + switch(getFlavor()) + { + default: + SLANG_UNEXPECTED("unknown requirement witness flavor"); + case RequirementWitness::Flavor::none: + return RequirementWitness(); - String DeclRefType::ToString() - { - return declRef.toString(); - } + case RequirementWitness::Flavor::declRef: + { + int diff = 0; + return RequirementWitness( + getDeclRef().SubstituteImpl(subst, &diff)); + } - int DeclRefType::GetHashCode() - { - return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); + case RequirementWitness::Flavor::val: + return RequirementWitness( + getVal()->Substitute(subst)); + } } - bool DeclRefType::EqualsImpl(Type * type) + RequirementWitness tryLookUpRequirementWitness( + SubtypeWitness* subtypeWitness, + Decl* requirementKey) { - if (auto declRefType = type->AsDeclRefType()) + if(auto declaredSubtypeWitness = dynamic_cast(subtypeWitness)) { - return declRef.Equals(declRefType->declRef); + if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.As()) + { + // A conformance that was declared as part of an inheritance clause + // will have built up a dictionary of the satisfying declarations + // for each of its requirements. + RequirementWitness requirementWitness; + auto witnessTable = inheritanceDeclRef.getDecl()->witnessTable; + if(witnessTable && witnessTable->requirementDictionary.TryGetValue(requirementKey, requirementWitness)) + { + // The `inheritanceDeclRef` has substitutions applied to it that + // *aren't* present in the `requirementWitness`, because it was + // derived by the front-end when looking at the `InheritanceDecl` alone. + // + // We need to apply these substitutions here for the result to make sense. + // + // E.g., if we have a case like: + // + // interface ISidekick { associatedtype Hero; void follow(Hero hero); } + // struct Sidekick : ISidekick { typedef H Hero; void follow(H hero) {} }; + // + // void followHero(S s, S.Hero h) + // { + // s.follow(h); + // } + // + // Batman batman; + // Sidekick robin; + // followHero>(robin, batman); + // + // The second argument to `followHero` is `batman`, which has type `Batman`. + // The parameter declaration lists the type `S.Hero`, which is a reference + // to an associated type. The front end will expand this into something + // like `S.{S:ISidekick}.Hero` - that is, we'll end up with a declaration + // reference to `ISidekick.Hero` with a this-type substitution that references + // the `{S:ISidekick}` declaration as a witness. + // + // The front-end will expand the generic appliation `followHero>` + // to `followHero, {Sidekick:ISidekick}[H->Batman]>` + // (that is, the hidden second parameter will reference the inheritance + // clause on `Sidekick`, with a substitution to map `H` to `Batman`. + // + // This step should map the `{S:ISidekick}` declaration over to the + // concrete `{Sidekick:ISidekick}[H->Batman]` inheritance declaration. + // At that point `tryLookupRequirementWitness` might be called, because + // we want to look up the witness for the key `ISidekick.Hero` in the + // inheritance decl-ref that is `{Sidekick:ISidekick}[H->Batman]`. + // + // That lookup will yield us a reference to the typedef `Sidekick.Hero`, + // *without* any substitution for `H` (or rather, with a default one that + // maps `H` to `H`. + // + // So, in order to get the *right* end result, we need to apply + // the substitutions from the inheritance decl-ref to the witness. + // + requirementWitness = requirementWitness.specialize(inheritanceDeclRef.substitutions); + + return requirementWitness; + } + } } - return false; - } - RefPtr DeclRefType::CreateCanonicalType() - { - // A declaration reference is already canonical - return this; + // TODO: should handle the transitive case here too + + return RequirementWitness(); } RefPtr DeclRefType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) @@ -579,9 +547,12 @@ void Type::accept(IValVisitor* visitor, void* extra) if (auto genericTypeParamDecl = dynamic_cast(declRef.getDecl())) { // search for a substitution that might apply to us - for (auto s = subst.genericSubstitutions; s; s = s->outer.Ptr()) + for(auto s = subst.substitutions; s; s = s->outer) { - auto genericSubst = s; + auto genericSubst = s.As(); + if(!genericSubst) + continue; + // the generic decl associated with the substitution list must be // the generic decl that declared this parameter auto genericDecl = genericSubst->genericDecl; @@ -611,50 +582,15 @@ void Type::accept(IValVisitor* visitor, void* extra) } } } - // the second case we care about is when this decl type refers to an associatedtype decl - // we want to replace it with the actual associated type - else if (auto assocTypeDecl = dynamic_cast(declRef.getDecl())) - { - auto thisSubst = getThisTypeSubst(declRef, false); - auto oldSubstSrc = thisSubst ? thisSubst->sourceType : nullptr; - bool restore = false; - if (thisSubst && thisSubst->sourceType.Ptr() == dynamic_cast(this)) - thisSubst->sourceType = nullptr; - auto newSubst = substituteSubstitutions(declRef.substitutions, subst, ioDiff); - if (restore) - thisSubst->sourceType = oldSubstSrc; - if (auto thisTypeSubst = newSubst.thisTypeSubstitution) - { - if (thisTypeSubst->sourceType) - { - if (auto aggTypeDeclRef = thisTypeSubst->sourceType.As()->declRef.As()) - { - Decl * targetType = nullptr; - if (aggTypeDeclRef.getDecl()->memberDictionary.TryGetValue(assocTypeDecl->getName(), targetType)) - { - if (auto typeDefDecl = dynamic_cast(targetType)) - { - DeclRef targetTypeDeclRef(typeDefDecl, aggTypeDeclRef.substitutions); - return GetType(targetTypeDeclRef); - } - else if (auto targetAggType = dynamic_cast(targetType)) - { - return DeclRefType::Create(getSession(), DeclRef(targetAggType, aggTypeDeclRef.substitutions)); - } - else - { - SLANG_UNIMPLEMENTED_X("unknown assoctype implementation type."); - } - } - } - } - } - } else if (auto globalGenParam = dynamic_cast(declRef.getDecl())) { // search for a substitution that might apply to us - for (auto genericSubst = subst.globalGenParamSubstitutions; genericSubst; genericSubst = genericSubst->outer.Ptr()) + for(auto s = subst.substitutions; s; s = s->outer) { + auto genericSubst = s.As(); + if(!genericSubst) + continue; + if (genericSubst->paramDecl == globalGenParam) { (*ioDiff)++; @@ -671,6 +607,45 @@ void Type::accept(IValVisitor* visitor, void* extra) // Make sure to record the difference! *ioDiff += diff; + // If this type is a reference to an associated type declaration, + // and the substitutions provide a "this type" substitution for + // the outer interface, then try to replace the type with the + // actual value of the associated type for the given implementation. + // + if(auto substAssocTypeDecl = substDeclRef.decl->As()) + { + for(auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) + { + auto thisSubst = s.As(); + if(!thisSubst) + continue; + + if(auto interfaceDecl = substAssocTypeDecl->ParentDecl->As()) + { + if(thisSubst->interfaceDecl == interfaceDecl) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substAssocTypeDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisSubst->witness, requirementKey); + switch(requirementWitness.getFlavor()) + { + default: + // No usable value was found, so there is nothing we can do. + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + break; + } + } + } + } + } + // Re-construct the type in case we are using a specialized sub-class return DeclRefType::Create(getSession(), substDeclRef); } @@ -689,9 +664,7 @@ void Type::accept(IValVisitor* visitor, void* extra) return intVal; } - // TODO: need to figure out how to unify this with the logic - // in the generic case... - DeclRefType* DeclRefType::Create( + DeclRef createDefaultSubstitutionsIfNeeded( Session* session, DeclRef declRef) { @@ -701,30 +674,81 @@ void Type::accept(IValVisitor* visitor, void* extra) // within its own member functions). To handle this case, // we will construct a default specialization at the use // site if needed. + // + // This same logic should also apply to declarations nested + // more than one level inside of a generic (e.g., a `typdef` + // inside of a generic `struct`). + // + // Similarly, it needs to work for multiple levels of + // nested generics. + // + + // We are going to build up a list of substitutions that need + // to be applied to the decl-ref to make it specialized. + RefPtr substsToApply; + RefPtr* link = &substsToApply; - if (auto genericParent = declRef.GetParent().As()) + RefPtr dd = declRef.getDecl(); + for(;;) { - auto subst = declRef.substitutions; - // try find a substitution targeting this generic decl - bool substFound = false; - for (auto genSubst = subst.genericSubstitutions; genSubst; genSubst = genSubst->outer) + RefPtr childDecl = dd; + RefPtr parentDecl = dd->ParentDecl; + if(!parentDecl) + break; + + dd = parentDecl; + + if(auto genericParentDecl = parentDecl.As()) { - if (genSubst->genericDecl == genericParent.decl) + // Don't specialize any parameters of a generic. + if(childDecl != genericParentDecl->inner) + break; + + // We have a generic ancestor, but do we have an substitutions for it? + RefPtr foundSubst; + for(auto s = declRef.substitutions.substitutions; s; s = s->outer) { - substFound = true; + auto genSubst = s.As(); + if(!genSubst) + continue; + + if(genSubst->genericDecl != genericParentDecl) + continue; + + // Okay, we found a matching substitution, + // so there is nothing to be done. + foundSubst = genSubst; break; } - } - // we did not find an existing substituion, create a default one - if (!substFound) - { - declRef.substitutions = createDefaultSubstitutions( - session, - declRef.decl, - subst); + + if(!foundSubst) + { + RefPtr newSubst = createDefaultSubsitutionsForGeneric( + session, + genericParentDecl, + nullptr); + + *link = newSubst; + link = &newSubst->outer; + } } } + if(!substsToApply) + return declRef; + + int diff = 0; + return declRef.SubstituteImpl(substsToApply, &diff); + } + + // TODO: need to figure out how to unify this with the logic + // in the generic case... + DeclRefType* DeclRefType::Create( + Session* session, + DeclRef declRef) + { + declRef = createDefaultSubstitutionsIfNeeded(session, declRef); + if (auto builtinMod = declRef.getDecl()->FindModifier()) { auto type = new BasicExpressionType(builtinMod->tag); @@ -734,7 +758,15 @@ void Type::accept(IValVisitor* visitor, void* extra) } else if (auto magicMod = declRef.getDecl()->FindModifier()) { - GenericSubstitution* subst = declRef.substitutions.genericSubstitutions.Ptr(); + GenericSubstitution* subst = nullptr; + for(auto s = declRef.substitutions.substitutions; s; s = s->outer) + { + if(auto genericSubst = s.As()) + { + subst = genericSubst; + break; + } + } if (magicMod->name == "SamplerState") { @@ -910,28 +942,6 @@ void Type::accept(IValVisitor* visitor, void* extra) return (int)(int64_t)(void*)this; } - // IRBasicBlockType - - String IRBasicBlockType::ToString() - { - return "Block"; - } - - bool IRBasicBlockType::EqualsImpl(Type * /*type*/) - { - return false; - } - - RefPtr IRBasicBlockType::CreateCanonicalType() - { - return this; - } - - int IRBasicBlockType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } - // InitializerListType String InitializerListType::ToString() @@ -1196,6 +1206,18 @@ void Type::accept(IValVisitor* visitor, void* extra) return elementType->AsBasicType(); } + // + + RefPtr findInnerMostGenericSubstitution(Substitutions* subst) + { + for(RefPtr s = subst; s; s = s->outer) + { + if(auto genericSubst = s.As()) + return genericSubst; + } + return nullptr; + } + // MatrixExpressionType String MatrixExpressionType::ToString() @@ -1212,24 +1234,24 @@ void Type::accept(IValVisitor* visitor, void* extra) Type* MatrixExpressionType::getElementType() { - return this->declRef.substitutions.genericSubstitutions->args[0].As().Ptr(); + return findInnerMostGenericSubstitution(declRef.substitutions)->args[0].As().Ptr(); } IntVal* MatrixExpressionType::getRowCount() { - return this->declRef.substitutions.genericSubstitutions->args[1].As().Ptr(); + return findInnerMostGenericSubstitution(declRef.substitutions)->args[1].As().Ptr(); } IntVal* MatrixExpressionType::getColumnCount() { - return this->declRef.substitutions.genericSubstitutions->args[2].As().Ptr(); + return findInnerMostGenericSubstitution(declRef.substitutions)->args[2].As().Ptr(); } // PtrTypeBase Type* PtrTypeBase::getValueType() { - return this->declRef.substitutions.genericSubstitutions->args[0].As().Ptr(); + return findInnerMostGenericSubstitution(declRef.substitutions)->args[0].As().Ptr(); } // GenericParamIntVal @@ -1256,9 +1278,13 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr GenericParamIntVal::SubstituteImpl(SubstitutionSet subst, int* ioDiff) { // search for a substitution that might apply to us - for (auto genSubst = subst.genericSubstitutions; genSubst; genSubst = genSubst->outer.Ptr()) + for(auto s = subst.substitutions; s; s = s->outer) { - // the generic decl associated with the substitution list must be + auto genSubst = s.As(); + if(!genSubst) + continue; + + // the generic decl associated with the substitution list must be // the generic decl that declared this parameter auto genericDecl = genSubst->genericDecl; if (genericDecl != declRef.getDecl()->ParentDecl) @@ -1293,17 +1319,18 @@ void Type::accept(IValVisitor* visitor, void* extra) // Substitutions - RefPtr GenericSubstitution::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr GenericSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) { if (!this) return nullptr; int diff = 0; - auto outerSubst = outer ? outer->SubstituteImpl(subst, &diff) : nullptr; + + if(substOuter != outer) diff++; List> substArgs; for (auto a : args) { - substArgs.Add(a->SubstituteImpl(subst, &diff)); + substArgs.Add(a->SubstituteImpl(substSet, &diff)); } if (!diff) return this; @@ -1312,7 +1339,7 @@ void Type::accept(IValVisitor* visitor, void* extra) auto substSubst = new GenericSubstitution(); substSubst->genericDecl = genericDecl; substSubst->args = substArgs; - substSubst->outer = outerSubst.As(); + substSubst->outer = substOuter; return substSubst; } @@ -1344,75 +1371,72 @@ void Type::accept(IValVisitor* visitor, void* extra) return true; } - RefPtr ThisTypeSubstitution::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr ThisTypeSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) { if (!this) return nullptr; int diff = 0; - RefPtr newSourceType; - if (sourceType) - newSourceType = sourceType->SubstituteImpl(subst, &diff); - else - { - // this_type is a free variable, use this_type from subst - if (subst.thisTypeSubstitution) - { - if (subst.thisTypeSubstitution->sourceType != sourceType) - { - newSourceType = subst.thisTypeSubstitution->sourceType; - diff = 1; - } - } - } + + if(substOuter != outer) diff++; + auto substWitness = witness->SubstituteImpl(substSet, &diff).As(); + if (!diff) return this; (*ioDiff)++; auto substSubst = new ThisTypeSubstitution(); - substSubst->sourceType = newSourceType; + substSubst->interfaceDecl = interfaceDecl; + substSubst->witness = substWitness; + substSubst->outer = substOuter; return substSubst; } bool ThisTypeSubstitution::Equals(Substitutions* subst) { if (!subst) - return true; + return this == nullptr; if (auto thisTypeSubst = dynamic_cast(subst)) { - if (!sourceType || !thisTypeSubst->sourceType) - return true; - return sourceType->EqualsVal(thisTypeSubst->sourceType); + return witness->EqualsVal(thisTypeSubst->witness); } return false; } - RefPtr GlobalGenericParamSubstitution::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + int ThisTypeSubstitution::GetHashCode() const + { + return witness->GetHashCode(); + } + + RefPtr GlobalGenericParamSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) { // if we find a GlobalGenericParamSubstitution in subst that references the same __generic_param decl // return a copy of that GlobalGenericParamSubstitution int diff = 0; - RefPtr outerSubst = outer ? outer->SubstituteImpl(subst, &diff) : nullptr; - for (auto gSubst = subst.globalGenParamSubstitutions; gSubst; gSubst = gSubst->outer) - { - if (gSubst->paramDecl == paramDecl) - { - // substitute only if we are really different - if (!gSubst->actualType->EqualsVal(actualType)) - { - RefPtr rs = new GlobalGenericParamSubstitution(*gSubst); - rs->outer = outerSubst.As(); - return rs; - } - } - } - if (diff) + if(substOuter != outer) diff++; + + auto substActualType = actualType->SubstituteImpl(substSet, &diff).As(); + + List substConstraintArgs; + for(auto constraintArg : constraintArgs) { - *ioDiff++; - RefPtr rs = new GlobalGenericParamSubstitution(*this); - rs->outer = outerSubst.As(); - return rs; + ConstraintArg substConstraintArg; + substConstraintArg.decl = constraintArg.decl; + substConstraintArg.val = constraintArg.val->SubstituteImpl(substSet, &diff); + + substConstraintArgs.Add(substConstraintArg); } - return this; + + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substSubst = new GlobalGenericParamSubstitution(); + substSubst->paramDecl = paramDecl; + substSubst->actualType = substActualType; + substSubst->constraintArgs = substConstraintArgs; + substSubst->outer = substOuter; + return substSubst; } bool GlobalGenericParamSubstitution::Equals(Substitutions* subst) @@ -1425,13 +1449,11 @@ void Type::accept(IValVisitor* visitor, void* extra) return false; if (!actualType->EqualsVal(genSubst->actualType)) return false; - if (witnessTables.Count() != genSubst->witnessTables.Count()) + if (constraintArgs.Count() != genSubst->constraintArgs.Count()) return false; - for (UInt i = 0; i < witnessTables.Count(); i++) + for (UInt i = 0; i < constraintArgs.Count(); i++) { - if (!witnessTables[i].Key->Equals(genSubst->witnessTables[i].Key)) - return false; - if (!witnessTables[i].Value->EqualsVal(genSubst->witnessTables[i].Value)) + if (!constraintArgs[i].val->EqualsVal(genSubst->constraintArgs[i].val)) return false; } return true; @@ -1474,74 +1496,354 @@ void Type::accept(IValVisitor* visitor, void* extra) UNREACHABLE_RETURN(expr); } - bool hasGlobalGenericSubst(SubstitutionSet destSubst, GlobalGenericParamSubstitution * genSubst) + void buildMemberDictionary(ContainerDecl* decl); + + InterfaceDecl* findOuterInterfaceDecl(Decl* decl) { - for (auto subst = destSubst.globalGenParamSubstitutions; subst; subst = subst->outer) + Decl* dd = decl; + while(dd) { - if (subst->paramDecl == genSubst->paramDecl) - return true; + if(auto interfaceDecl = dd->As()) + return interfaceDecl; + + dd = dd->ParentDecl; } - return false; + return nullptr; } - void insertGlobalGenericSubstitutions(SubstitutionSet & destSubst, SubstitutionSet srcSubst, int * ioDiff) + + RefPtr findGlobalGenericSubst( + RefPtr substs, + GlobalGenericParamDecl* paramDecl) { - int diff = 0; - - if (auto globalGenSubst = srcSubst.globalGenParamSubstitutions) + for(auto s = substs; s; s = s->outer) { - if (!hasGlobalGenericSubst(destSubst, globalGenSubst)) - { - RefPtr cpyGlobalGenSubst = new GlobalGenericParamSubstitution(*globalGenSubst); - cpyGlobalGenSubst->outer = destSubst.globalGenParamSubstitutions; - destSubst.globalGenParamSubstitutions = cpyGlobalGenSubst; - diff = 1; - } + auto gSubst = s.As(); + if(!gSubst) + continue; + + if(gSubst->paramDecl != paramDecl) + continue; + + return gSubst; } - *ioDiff += diff; + + return nullptr; } - void buildMemberDictionary(ContainerDecl* decl); + RefPtr specializeSubstitutionsShallow( + RefPtr substToSpecialize, + RefPtr substsToApply, + RefPtr restSubst, + int* ioDiff) + { + return substToSpecialize->applySubstitutionsShallow(substsToApply, restSubst, ioDiff); + } - DeclRefBase DeclRefBase::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + RefPtr specializeGlobalGenericSubstitutions( + Decl* declToSpecialize, + RefPtr substsToSpecialize, + RefPtr substsToApply, + int* ioDiff, + HashSet& ioParametersFound) { - int diff = 0; - auto substSubst = substituteSubstitutions(substitutions, subst, &diff); - if (!diff) - return *this; + // Any existing global-generic substitutions will trigger + // a recursive case that skips the rest of the function. + for(auto specSubst = substsToSpecialize; specSubst; specSubst = specSubst->outer) + { + auto specGlobalGenericSubst = specSubst.As(); + if(!specGlobalGenericSubst) + continue; - *ioDiff += diff; + ioParametersFound.Add(specGlobalGenericSubst->paramDecl); - DeclRefBase substDeclRef; - substDeclRef.decl = decl; - substDeclRef.substitutions = substSubst; - - // if this is a AssocTypeDecl, try lookup the actual associated type - if (auto assocTypeDecl = substDeclRef.decl->As()) + int diff = 0; + auto restSubst = specializeGlobalGenericSubstitutions( + declToSpecialize, + specSubst->outer, + substsToApply, + &diff, + ioParametersFound); + + auto firstSubst = specializeSubstitutionsShallow( + specGlobalGenericSubst, + substsToApply, + restSubst, + &diff); + + *ioDiff += diff; + return firstSubst; + } + + // No more existing substitutions, so we know we can apply + // our global generic substitutions without any special work. + + // We expect global generic substitutions to come at + // the end of the list in all cases, so lets advance + // until we see them. + RefPtr appGlobalGenericSubsts = substsToApply; + while(appGlobalGenericSubsts && !appGlobalGenericSubsts.As()) + appGlobalGenericSubsts = appGlobalGenericSubsts->outer; + + + // If there is nothing to apply, then we are done + if(!appGlobalGenericSubsts) + return nullptr; + + // Otherwise, it seems like something has to change. + (*ioDiff)++; + + // If there were no parameters bound by the existing substitution, + // then we can safely use the global generics from the to-apply set. + if(ioParametersFound.Count() == 0) + return appGlobalGenericSubsts; + + RefPtr resultSubst; + RefPtr* link = &resultSubst; + for(auto appSubst = appGlobalGenericSubsts; appSubst; appSubst = appSubst->outer) { - auto thisSubst = getThisTypeSubst(substDeclRef, false); - if (thisSubst) + auto appGlobalGenericSubst = appSubst.As(); + if(!appSubst) + continue; + + // Don't include substitutions for parameters already handled. + if(ioParametersFound.Contains(appGlobalGenericSubst->paramDecl)) + continue; + + RefPtr newSubst = new GlobalGenericParamSubstitution(); + newSubst->paramDecl = appGlobalGenericSubst->paramDecl; + newSubst->actualType = appGlobalGenericSubst->actualType; + newSubst->constraintArgs = appGlobalGenericSubst->constraintArgs; + + *link = newSubst; + link = &newSubst->outer; + } + + return resultSubst; + } + + RefPtr specializeGlobalGenericSubstitutions( + Decl* declToSpecialize, + RefPtr substsToSpecialize, + RefPtr substsToApply, + int* ioDiff) + { + // Keep track of any parameters already present in the + // existing substitution. + HashSet parametersFound; + return specializeGlobalGenericSubstitutions(declToSpecialize, substsToSpecialize, substsToApply, ioDiff, parametersFound); + } + + + // Construct new substitutions to apply to a declaration, + // based on a provided substituion set to be applied + RefPtr specializeSubstitutions( + Decl* declToSpecialize, + RefPtr substsToSpecialize, + RefPtr substsToApply, + int* ioDiff) + { + // No declaration? Then nothing to specialize. + if(!declToSpecialize) + return nullptr; + + // No (remaining) substitutions to apply? Then we are done. + if(!substsToApply) + return substsToSpecialize; + + // Walk the hierarchy of the declaration to determine what specializations might apply. + // We assume that the `substsToSpecialize` must be aligned with the ancestor + // hierarchy of `declToSpecialize` such that if, e.g., the `declToSpecialize` is + // nested directly in a generic, then `substToSpecialize` will either start with + // the corresponding `GenericSubstitution` or there will be *no* generic substitutions + // corresponding to that decl. + for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->ParentDecl) + { + if(auto ancestorGenericDecl = ancestorDecl->As()) { - if (auto declRefType = thisSubst->sourceType.As()) + // The declaration is nested inside a generic. + // Does it already have a specialization for that generic? + if(auto specGenericSubst = substsToSpecialize.As()) { - if (auto aggDeclRef = declRefType->declRef.As()) + if(specGenericSubst->genericDecl == ancestorGenericDecl) { - Decl* subTypeDecl = nullptr; - buildMemberDictionary(aggDeclRef.getDecl()); - SLANG_ASSERT(aggDeclRef.getDecl()->memberDictionaryIsValid); - aggDeclRef.getDecl()->memberDictionary.TryGetValue(assocTypeDecl->getName(), subTypeDecl); - if (auto typeDefDecl = subTypeDecl->As()) - { - auto t = GetType(DeclRef(typeDefDecl, aggDeclRef.substitutions)); - auto canonicalType = t->GetCanonicalType()->AsDeclRefType(); - SLANG_ASSERT(canonicalType); - return canonicalType->declRef; - } - SLANG_ASSERT(subTypeDecl); - return DeclRefBase(subTypeDecl, aggDeclRef.substitutions); + // Yes. We have an existing specialization, so we will + // keep one matching it in place. + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorGenericDecl->ParentDecl, + specGenericSubst->outer, + substsToApply, + &diff); + + auto firstSubst = specializeSubstitutionsShallow( + specGenericSubst, + substsToApply, + restSubst, + &diff); + + *ioDiff += diff; + return firstSubst; + } + } + + // If the declaration is not already specialized + // for the given generic, then see if we are trying + // to *apply* such specializations to it. + // + // TODO: The way we handle things right now with + // "default" specializations, this case shouldn't + // actually come up. + // + for(auto s = substsToApply; s; s = s->outer) + { + auto appGenericSubst = s.As(); + if(!appGenericSubst) + continue; + + if(appGenericSubst->genericDecl != ancestorGenericDecl) + continue; + + // The substitutions we are applying are trying + // to specialize this generic, but we don't already + // have a generic substitution in place. + // We will need to create one. + + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorGenericDecl->ParentDecl, + substsToSpecialize, + substsToApply, + &diff); + + RefPtr firstSubst = new GenericSubstitution(); + firstSubst->genericDecl = ancestorGenericDecl; + firstSubst->args = appGenericSubst->args; + firstSubst->outer = restSubst; + + (*ioDiff)++; + return firstSubst; + } + } + else if(auto ancestorInterfaceDecl = ancestorDecl->As()) + { + // The task is basically the same as for the generic case: + // We want to see if there is any existing substitution that + // applies to this declaration, and use that if possible. + + // The declaration is nested inside a generic. + // Does it already have a specialization for that generic? + if(auto specThisTypeSubst = substsToSpecialize.As()) + { + if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl) + { + // Yes. We have an existing specialization, so we will + // keep one matching it in place. + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorInterfaceDecl->ParentDecl, + specThisTypeSubst->outer, + substsToApply, + &diff); + + auto firstSubst = specializeSubstitutionsShallow( + specThisTypeSubst, + substsToApply, + restSubst, + &diff); + + *ioDiff += diff; + return firstSubst; } } + + // Otherwise, check if we are trying to apply + // a this-type substitution to the given interface + // + for(auto s = substsToApply; s; s = s->outer) + { + auto appThisTypeSubst = s.As(); + if(!appThisTypeSubst) + continue; + + if(appThisTypeSubst->interfaceDecl != ancestorInterfaceDecl) + continue; + + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorInterfaceDecl->ParentDecl, + substsToSpecialize, + substsToApply, + &diff); + + RefPtr firstSubst = new ThisTypeSubstitution(); + firstSubst->interfaceDecl = ancestorInterfaceDecl; + firstSubst->witness = appThisTypeSubst->witness; + firstSubst->outer = restSubst; + + (*ioDiff)++; + return firstSubst; + } } } + + // If we reach here then we've walked the full hierarchy up from + // `declToSpecialize` and either didn't run into an generic/interface + // declarations, or we didn't find any attempt to specialize them + // in either substitution. + // + // As an invariant, there should *not* be any generic or this-type + // substitutiosn in `substToSpecialize`, because otherwise they + // would be specializations that don't actually apply to the given + // declaration. + // + // The remaining substitutions to apply, if any, should thus be + // global-generic substitutions. And similarly, those are the + // only remaining substitutions we really care about in + // `substsToApply`. + // + // Note: this does *not* mean that `substsToApply` doesn't have + // any generic or this-type substitutions; it just means that none + // of them were applicable. + // + return specializeGlobalGenericSubstitutions( + declToSpecialize, + substsToSpecialize, + substsToApply, + ioDiff); + } + + DeclRefBase DeclRefBase::SubstituteImpl(SubstitutionSet substSet, int* ioDiff) + { + // Nothing to do when we have no declaration. + if(!decl) + return *this; + + // Apply the given substitutions to any specializations + // that have already been applied to this declaration. + int diff = 0; + + auto substSubst = specializeSubstitutions( + decl, + substitutions.substitutions, + substSet.substitutions, + &diff); + + if (!diff) + return *this; + + *ioDiff += diff; + + DeclRefBase substDeclRef; + substDeclRef.decl = decl; + substDeclRef.substitutions = substSubst; + + // TODO: The old code here used to try to translate a decl-ref + // to an associated type in a decl-ref for the concrete type + // in a paarticular implementation. + // + // I have only kept that logic in `DeclRefType::SubstituteImpl`, + // but it may turn out it is needed here too. + return substDeclRef; } @@ -1569,32 +1871,45 @@ void Type::accept(IValVisitor* visitor, void* extra) if (!parentDecl) return DeclRefBase(); - if (auto parentGeneric = dynamic_cast(parentDecl)) + // Default is to apply the same set of substitutions/specializations + // to the parent declaration as were applied to the child. + RefPtr substToApply = substitutions.substitutions; + + if(auto interfaceDecl = dynamic_cast(decl)) { - auto genSubst = substitutions.genericSubstitutions; - if (genSubst && genSubst->genericDecl == parentDecl) - { - // We strip away the specializations that were applied to - // the parent, since we were asked for a reference *to* the parent. - return DeclRefBase(parentGeneric, SubstitutionSet(genSubst->outer, substitutions.thisTypeSubstitution, - substitutions.globalGenParamSubstitutions)); - } - else + // The declaration being referenced is an `interface` declaration, + // and there might be a this-type substitution in place. + // A reference to the parent of the interface declaration + // should not include that substitution. + if(auto thisTypeSubst = substToApply.As()) { - // Either we don't have specializations, or the inner-most - // specializations didn't apply to the parent decl. This - // can happen if we are looking at an unspecialized - // declaration that is a child of a generic. - return DeclRefBase(parentGeneric, substitutions); + if(thisTypeSubst->interfaceDecl == interfaceDecl) + { + // Strip away that specializations that apply to the interface. + substToApply = thisTypeSubst->outer; + } } } - else + + if (auto parentGenericDecl = dynamic_cast(parentDecl)) { - // If the parent isn't a generic, then it must - // use the same specializations as this declaration - return DeclRefBase(parentDecl, substitutions); + // The parent of this declaration is a generic, which means + // that the decl-ref to the current declaration might include + // substitutiosn that specialize the generic parameters. + // A decl-ref to the parent generic should *not* include + // those substitutions. + // + if(auto genericSubst = substToApply.As()) + { + if(genericSubst->genericDecl == parentGenericDecl) + { + // Strip away the specializations that were applied to the parent. + substToApply = genericSubst->outer; + } + } } + return DeclRefBase(parentDecl, substToApply); } int DeclRefBase::GetHashCode() const @@ -1706,12 +2021,12 @@ void Type::accept(IValVisitor* visitor, void* extra) Type* HLSLPatchType::getElementType() { - return this->declRef.substitutions.genericSubstitutions->args[0].As().Ptr(); + return findInnerMostGenericSubstitution(declRef.substitutions)->args[0].As().Ptr(); } IntVal* HLSLPatchType::getElementCount() { - return this->declRef.substitutions.genericSubstitutions->args[1].As().Ptr(); + return findInnerMostGenericSubstitution(declRef.substitutions)->args[1].As().Ptr(); } // Constructors for types @@ -1742,7 +2057,9 @@ void Type::accept(IValVisitor* visitor, void* extra) Session* session, DeclRef const& declRef) { - auto namedType = new NamedExpressionType(declRef); + DeclRef specializedDeclRef = createDefaultSubstitutionsIfNeeded(session, declRef).As(); + + auto namedType = new NamedExpressionType(specializedDeclRef); namedType->setSession(session); return namedType; } @@ -1828,64 +2145,141 @@ void Type::accept(IValVisitor* visitor, void* extra) && declRef.Equals(otherWitness->declRef); } + RefPtr findThisTypeSubstitution( + Substitutions* substs, + InterfaceDecl* interfaceDecl) + { + for(RefPtr s = substs; s; s = s->outer) + { + auto thisTypeSubst = s.As(); + if(!thisTypeSubst) + continue; + + if(thisTypeSubst->interfaceDecl != interfaceDecl) + continue; + + return thisTypeSubst; + } + + return nullptr; + } + RefPtr DeclaredSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) { - if (auto genConstraintDecl = declRef.As()) + if (auto genConstraintDeclRef = declRef.As()) { + auto genConstraintDecl = genConstraintDeclRef.getDecl(); + // search for a substitution that might apply to us - for (auto genericSubst = subst.genericSubstitutions; genericSubst; genericSubst = genericSubst->outer.Ptr()) + for(auto s = subst.substitutions; s; s = s->outer) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->genericDecl; - if (genericDecl != genConstraintDecl.getDecl()->ParentDecl) - continue; - bool found = false; - UInt index = 0; - for (auto m : genericDecl->Members) + if(auto genericSubst = s.As()) { - if (auto constraintParam = m.As()) + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genConstraintDecl->ParentDecl) + continue; + + bool found = false; + UInt index = 0; + for (auto m : genericDecl->Members) { - if (constraintParam.Ptr() == declRef.getDecl()) + if (auto constraintParam = m.As()) { - found = true; - break; + if (constraintParam.Ptr() == declRef.getDecl()) + { + found = true; + break; + } + index++; } - index++; + } + if (found) + { + (*ioDiff)++; + auto ordinaryParamCount = genericDecl->getMembersOfType().Count() + + genericDecl->getMembersOfType().Count(); + SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.Count()); + return genericSubst->args[index + ordinaryParamCount]; } } - if (found) + else if(auto globalGenericSubst = s.As()) { - (*ioDiff)++; - auto ordinaryParamCount = genericDecl->getMembersOfType().Count() + - genericDecl->getMembersOfType().Count(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.Count()); - return genericSubst->args[index + ordinaryParamCount]; + // check if the substitution is really about this global generic type parameter + if (globalGenericSubst->paramDecl != genConstraintDecl->ParentDecl) + continue; + + for(auto constraintArg : globalGenericSubst->constraintArgs) + { + if(constraintArg.decl.Ptr() != genConstraintDecl) + continue; + + (*ioDiff)++; + return constraintArg.val; + } } } - for (auto globalGenParamSubst = subst.globalGenParamSubstitutions; globalGenParamSubst; globalGenParamSubst = globalGenParamSubst->outer.Ptr()) - { - // we have a GlobalGenericParamSubstitution, this substitution will provide - // a concrete IRWitnessTable for a generic global variable - auto supType = GetSup(genConstraintDecl); + } - // check if the substitution is really about this global generic type parameter - if (globalGenParamSubst->paramDecl != genConstraintDecl.getDecl()->ParentDecl) - continue; + // Perform substitution on the constituent elements. + int diff = 0; + auto substSub = sub->SubstituteImpl(subst, &diff).As(); + auto substSup = sup->SubstituteImpl(subst, &diff).As(); + auto substDeclRef = declRef.SubstituteImpl(subst, &diff); + if (!diff) + return this; + + (*ioDiff)++; - // find witness table for the required interface - for (auto witness : globalGenParamSubst->witnessTables) - if (witness.Key->EqualsVal(supType)) + // If we have a reference to a type constraint for an + // associated type declaration, then we can replace it + // with the concrete conformance witness for a concrete + // type implementing the outer interface. + // + // TODO: It is a bit gross that we use `GenericTypeConstraintDecl` for + // associated types, when they aren't really generic type *parameters*, + // so we'll need to change this location in the code if we ever clean + // up the hierarchy. + // + if (auto substTypeConstraintDecl = substDeclRef.decl->As()) + { + if (auto substAssocTypeDecl = substTypeConstraintDecl->ParentDecl->As()) + { + if (auto interfaceDecl = substAssocTypeDecl->ParentDecl->As()) + { + // At this point we have a constraint decl for an associated type, + // and we nee to see if we are dealing with a concrete substitution + // for the interface around that associated type. + if(auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.substitutions, interfaceDecl)) { - (*ioDiff)++; - return witness.Value; + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substTypeConstraintDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisTypeSubst->witness, requirementKey); + switch(requirementWitness.getFlavor()) + { + default: + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + } } + } } } + + + + RefPtr rs = new DeclaredSubtypeWitness(); - rs->sub = sub->SubstituteImpl(subst, ioDiff).As(); - rs->sup = sup->SubstituteImpl(subst, ioDiff).As(); - rs->declRef = declRef.SubstituteImpl(subst, ioDiff); + rs->sub = substSub; + rs->sup = substSup; + rs->declRef = substDeclRef; return rs; } @@ -1918,7 +2312,7 @@ void Type::accept(IValVisitor* visitor, void* extra) return sub->Equals(otherWitness->sub) && sup->Equals(otherWitness->sup) && subToMid->EqualsVal(otherWitness->subToMid) - && midToSup->EqualsVal(otherWitness->midToSup); + && midToSup.Equals(otherWitness->midToSup); } RefPtr TransitiveSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) @@ -1928,7 +2322,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr substSub = sub->SubstituteImpl(subst, &diff).As(); RefPtr substSup = sup->SubstituteImpl(subst, &diff).As(); RefPtr substSubToMid = subToMid->SubstituteImpl(subst, &diff).As(); - RefPtr substMidToSup = midToSup->SubstituteImpl(subst, &diff).As(); + DeclRef substMidToSup = midToSup.SubstituteImpl(subst, &diff); // If nothing changed, then we can bail out early. if (!diff) @@ -1971,7 +2365,7 @@ void Type::accept(IValVisitor* visitor, void* extra) sb << "TransitiveSubtypeWitness("; sb << this->subToMid->ToString(); sb << ", "; - sb << this->midToSup->ToString(); + sb << this->midToSup.toString(); sb << ")"; return sb.ProduceString(); } @@ -1981,29 +2375,7 @@ void Type::accept(IValVisitor* visitor, void* extra) auto hash = sub->GetHashCode(); hash = combineHash(hash, sup->GetHashCode()); hash = combineHash(hash, subToMid->GetHashCode()); - hash = combineHash(hash, midToSup->GetHashCode()); - return hash; - } - - // IRProxyVal - - bool IRProxyVal::EqualsVal(Val* val) - { - auto otherProxy = dynamic_cast(val); - if(!otherProxy) - return false; - - return this->inst.get() == otherProxy->inst.get(); - } - - String IRProxyVal::ToString() - { - return "IRProxyVal(...)"; - } - - int IRProxyVal::GetHashCode() - { - auto hash = Slang::GetHashCode(inst.get()); + hash = combineHash(hash, midToSup.GetHashCode()); return hash; } @@ -2020,77 +2392,19 @@ void Type::accept(IValVisitor* visitor, void* extra) return name->text; } - RefPtr getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry) - { - RefPtr thisSubst = declRef.substitutions.thisTypeSubstitution; - if (!thisSubst) - { - thisSubst = new ThisTypeSubstitution(); - if (insertSubstEntry) - { - declRef.substitutions.thisTypeSubstitution = thisSubst; - } - } - return thisSubst; - } - - RefPtr getNewThisTypeSubst(DeclRefBase & declRef) - { - declRef.substitutions.thisTypeSubstitution = new ThisTypeSubstitution(); - return declRef.substitutions.thisTypeSubstitution; - } - - SubstitutionSet substituteSubstitutions(SubstitutionSet oldSubst, SubstitutionSet subst, int * ioDiff) - { - return oldSubst.substituteImpl(subst, ioDiff); - } - bool SubstitutionSet::Equals(SubstitutionSet substSet) const { - if (genericSubstitutions) - { - if (!genericSubstitutions->Equals(substSet.genericSubstitutions)) - return false; - } - else - { - if (substSet.genericSubstitutions) - return false; - } - if (thisTypeSubstitution) - { - if (!thisTypeSubstitution->Equals(substSet.thisTypeSubstitution)) - return false; - } - else - { - if (substSet.thisTypeSubstitution && substSet.thisTypeSubstitution->sourceType != nullptr) - return false; - } - return true; - } - SubstitutionSet SubstitutionSet::substituteImpl(SubstitutionSet subst, int * ioDiff) - { - SubstitutionSet rs; - if (genericSubstitutions) - rs.genericSubstitutions = genericSubstitutions->SubstituteImpl(subst, ioDiff).As(); - if (globalGenParamSubstitutions) - rs.globalGenParamSubstitutions = globalGenParamSubstitutions->SubstituteImpl(subst, ioDiff).As(); - if (thisTypeSubstitution) - rs.thisTypeSubstitution = thisTypeSubstitution->SubstituteImpl(subst, ioDiff).As(); + if(!substitutions || !substSet.substitutions) + return substitutions == substSet.substitutions; - insertGlobalGenericSubstitutions(rs, subst, ioDiff); - return rs; + return substitutions->Equals(substSet.substitutions); } + int SubstitutionSet::GetHashCode() const { int rs = 0; - if (genericSubstitutions) - rs = combineHash(rs, genericSubstitutions->GetHashCode()); - if (thisTypeSubstitution) - rs = combineHash(rs, thisTypeSubstitution->GetHashCode()); - if (globalGenParamSubstitutions) - rs = combineHash(rs, globalGenParamSubstitutions->GetHashCode()); + if (substitutions) + rs = combineHash(rs, substitutions->GetHashCode()); return rs; } } diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 0f23492d6..ebb9d814b 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -400,23 +400,18 @@ namespace Slang struct SubstitutionSet { - RefPtr genericSubstitutions; - RefPtr thisTypeSubstitution; - RefPtr globalGenParamSubstitutions; - operator bool() const + RefPtr substitutions; + operator Substitutions*() const { - return genericSubstitutions || thisTypeSubstitution || globalGenParamSubstitutions; + return substitutions; } + SubstitutionSet() {} - SubstitutionSet(RefPtr genSubst, RefPtr inThisTypeSubst, - RefPtr globalSubst) + SubstitutionSet(RefPtr subst) + : substitutions(subst) { - genericSubstitutions = genSubst; - thisTypeSubstitution = inThisTypeSubst; - globalGenParamSubstitutions = globalSubst; } bool Equals(SubstitutionSet substSet) const; - SubstitutionSet substituteImpl(SubstitutionSet subst, int * ioDiff); int GetHashCode() const; }; // A reference to a declaration, which may include @@ -444,11 +439,9 @@ namespace Slang substitutions(subst) {} - DeclRefBase(Decl* decl, RefPtr genSubstitutions, - RefPtr thisTypeSubst = nullptr, - RefPtr globalSubst = nullptr) - : decl(decl), - substitutions(genSubstitutions, thisTypeSubst, globalSubst) + DeclRefBase(Decl* decl, RefPtr subst) + : decl(decl) + , substitutions(subst) {} // Apply substitutions to a type or ddeclaration @@ -492,8 +485,8 @@ namespace Slang : DeclRefBase(decl, subst) {} - DeclRef(T* decl, RefPtr genSubst) - : DeclRefBase(decl, SubstitutionSet(genSubst, nullptr, nullptr)) + DeclRef(T* decl, RefPtr subst) + : DeclRefBase(decl, SubstitutionSet(subst)) {} template @@ -1004,6 +997,67 @@ namespace Slang LookupMask mask = LookupMask::Default; }; + struct WitnessTable; + + // A value that witnesses the satisfaction of an interface + // requirement by a particular declaration or value. + struct RequirementWitness + { + RequirementWitness() + : m_flavor(Flavor::none) + {} + + RequirementWitness(DeclRef declRef) + : m_flavor(Flavor::declRef) + , m_declRef(declRef) + {} + + RequirementWitness(RefPtr val); + + RequirementWitness(RefPtr witnessTable); + + enum class Flavor + { + none, + declRef, + val, + witnessTable, + }; + + Flavor getFlavor() + { + return m_flavor; + } + + DeclRef getDeclRef() + { + SLANG_ASSERT(getFlavor() == Flavor::declRef); + return m_declRef; + } + + RefPtr getVal() + { + SLANG_ASSERT(getFlavor() == Flavor::val); + return m_obj.As(); + } + + RefPtr getWitnessTable(); + + RequirementWitness specialize(SubstitutionSet const& subst); + + Flavor m_flavor; + DeclRef m_declRef; + RefPtr m_obj; + + }; + + typedef Dictionary RequirementDictionary; + + struct WitnessTable : RefObject + { + RequirementDictionary requirementDictionary; + }; + // Generate class definition for all syntax classes #define SYNTAX_FIELD(TYPE, NAME) TYPE NAME; #define FIELD(TYPE, NAME) TYPE NAME; @@ -1096,23 +1150,6 @@ namespace Slang return FilteredMemberRefList(declRef.getDecl()->Members, declRef.substitutions); } - // TODO: change this to return a lazy list instead of constructing actual list - inline List> getMembersWithExt(DeclRef const& declRef) - { - List> rs; - for (auto d : FilteredMemberRefList(declRef.getDecl()->Members, declRef.substitutions)) - rs.Add(d); - if (auto aggDeclRef = declRef.As()) - { - for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension) - { - for (auto mbr : getMembers(DeclRef(ext, declRef.substitutions))) - rs.Add(mbr); - } - } - return rs; - } - template inline FilteredMemberRefList getMembersOfType(DeclRef const& declRef) { @@ -1245,29 +1282,16 @@ namespace Slang Session* session, Decl* decl); - void insertSubstAtBottom(RefPtr & substHead, RefPtr substToInsert); - RefPtr getNewThisTypeSubst(DeclRefBase & declRef); - RefPtr getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry); - void removeSubstitution(DeclRefBase & declRef, RefPtr subst); - bool hasGenericSubstitutions(RefPtr subst); - RefPtr getGenericSubstitution(RefPtr subst); - - // This function substitutes the type arguments referenced in a linked list of substitutions - // which head is at `substHead` using the substitutions specified by `subst`. If the linked - // list `substHead` does not contain `GlobalGenericParamSubstitution` entries, they will be - // added to the bottom(outter most) of the linked list. - // Note that this function should be called when `substHead` is known to be the head of - // substitution linked list because the existance of `GlobalGenericPaaramSubstitution` is - // detected assuming the linked lists starts at `substHead`. If a substitution that is not - // the head of a substitution linked list is passed in, duplicate - // `GlobalGenericParamSubstitution`s could be appended to the linked list. - // This means that this function should * not* be called in places like - // `GenericSubstitution::SubstitutionImpl()` for its outer substitutions, because `outer` is - // obviously not the head of the linked list. Instead, use this function to substitution the - // substitution lists of `DeclRef` etc. to replace the call of - // `declRef.substitutions->SubstituteImpl()`, because the head to the linked list is known as a - // member of that class there. - SubstitutionSet substituteSubstitutions(SubstitutionSet oldSubst, SubstitutionSet subst, int * ioDiff); + DeclRef createDefaultSubstitutionsIfNeeded( + Session* session, + DeclRef declRef); + + RefPtr createDefaultSubsitutionsForGeneric( + Session* session, + GenericDecl* genericDecl, + RefPtr outerSubst); + + RefPtr findInnerMostGenericSubstitution(Substitutions* subst); } // namespace Slang #endif \ No newline at end of file diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 433c5e15c..14e9c0066 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -42,20 +42,6 @@ protected: ) END_SYNTAX_CLASS() -// The type of a reference to a basic block -// in our IR -SYNTAX_CLASS(IRBasicBlockType, Type) -RAW( -public: - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - // A type that takes the form of a reference to some declaration SYNTAX_CLASS(DeclRefType, Type) DECL_FIELD(DeclRef, declRef) @@ -107,9 +93,20 @@ protected: ) END_SYNTAX_CLASS() -// Base type for things we think of as "resources" -ABSTRACT_SYNTAX_CLASS(ResourceTypeBase, DeclRefType) +// Base type for things that are built in to the compiler, +// and will usually have special behavior or a custom +// mapping to the IR level. +ABSTRACT_SYNTAX_CLASS(BuiltinType, DeclRefType) +END_SYNTAX_CLASS() + +// Resources that contain "elements" that can be fetched +ABSTRACT_SYNTAX_CLASS(ResourceType, BuiltinType) + // The type that results from fetching an element from this resource + SYNTAX_FIELD(RefPtr, elementType) + + // Shape and access level information for this resource type FIELD(TextureFlavor, flavor) + RAW( TextureFlavor::Shape GetBaseShape() { @@ -123,12 +120,6 @@ ABSTRACT_SYNTAX_CLASS(ResourceTypeBase, DeclRefType) ) END_SYNTAX_CLASS() -// Resources that contain "elements" that can be fetched -ABSTRACT_SYNTAX_CLASS(ResourceType, ResourceTypeBase) - // The type that results from fetching an element from this resource - SYNTAX_FIELD(RefPtr, elementType) -END_SYNTAX_CLASS() - ABSTRACT_SYNTAX_CLASS(TextureTypeBase, ResourceType) RAW( TextureTypeBase() @@ -182,13 +173,13 @@ RAW( ) END_SYNTAX_CLASS() -SYNTAX_CLASS(SamplerStateType, DeclRefType) +SYNTAX_CLASS(SamplerStateType, BuiltinType) // What flavor of sampler state is this FIELD(SamplerStateFlavor, flavor) END_SYNTAX_CLASS() // Other cases of generic types known to the compiler -SYNTAX_CLASS(BuiltinGenericType, DeclRefType) +SYNTAX_CLASS(BuiltinGenericType, BuiltinType) SYNTAX_FIELD(RefPtr, elementType) RAW(Type* getElementType() { return elementType; }) @@ -206,14 +197,18 @@ SIMPLE_SYNTAX_CLASS(HLSLStructuredBufferType, HLSLStructuredBufferTypeBase) SIMPLE_SYNTAX_CLASS(HLSLRWStructuredBufferType, HLSLStructuredBufferTypeBase) // TODO: need raster-ordered case here -SIMPLE_SYNTAX_CLASS(UntypedBufferResourceType, DeclRefType) +SIMPLE_SYNTAX_CLASS(UntypedBufferResourceType, BuiltinType) SIMPLE_SYNTAX_CLASS(HLSLByteAddressBufferType, UntypedBufferResourceType) SIMPLE_SYNTAX_CLASS(HLSLRWByteAddressBufferType, UntypedBufferResourceType) +SIMPLE_SYNTAX_CLASS(RaytracingAccelerationStructureType, UntypedBufferResourceType) SIMPLE_SYNTAX_CLASS(HLSLAppendStructuredBufferType, HLSLStructuredBufferTypeBase) SIMPLE_SYNTAX_CLASS(HLSLConsumeStructuredBufferType, HLSLStructuredBufferTypeBase) -SYNTAX_CLASS(HLSLPatchType, DeclRefType) +SIMPLE_SYNTAX_CLASS(RayDescType, BuiltinType) +SIMPLE_SYNTAX_CLASS(BuiltInTriangleIntersectionAttributesType, BuiltinType) + +SYNTAX_CLASS(HLSLPatchType, BuiltinType) RAW( Type* getElementType(); IntVal* getElementCount(); @@ -231,7 +226,7 @@ SIMPLE_SYNTAX_CLASS(HLSLLineStreamType, HLSLStreamOutputType) SIMPLE_SYNTAX_CLASS(HLSLTriangleStreamType, HLSLStreamOutputType) // -SIMPLE_SYNTAX_CLASS(GLSLInputAttachmentType, DeclRefType) +SIMPLE_SYNTAX_CLASS(GLSLInputAttachmentType, BuiltinType) // Base class for types used when desugaring parameter block // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. @@ -272,64 +267,6 @@ protected: ) END_SYNTAX_CLASS() -// A type that has a rate qualifier applied. Conceptually `@R T` where `R` -// represents a rate, and `T` represents a data type. -SYNTAX_CLASS(RateQualifiedType, Type) - - // The rate `R` at which the value is computed/stored - SYNTAX_FIELD(RefPtr, rate); - - // The underlying data type `T` of the value - SYNTAX_FIELD(RefPtr, valueType); - -RAW( - virtual Slang::String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - virtual int GetHashCode() override; - ) -END_SYNTAX_CLASS() - -// A representation of the `ConstExpr` rate, to be used -// in defining `@ConstExpr T` for particular data types `T` -SYNTAX_CLASS(ConstExprRate, Type) - -RAW( - virtual Slang::String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - virtual int GetHashCode() override; - ) -END_SYNTAX_CLASS() - -// The effective type of a variable declared with `groupshared` storage qualifier. -// -// TODO: this should be converted to a `GroupSharedRate`, which then gets used -// in conjunction with `RateQualifiedType`. -SYNTAX_CLASS(GroupSharedType, Type) - SYNTAX_FIELD(RefPtr, valueType); - -RAW( - virtual ~GroupSharedType() - { - } - - virtual Slang::String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; - ) - -END_SYNTAX_CLASS() - // The "type" of an expression that resolves to a type. // For example, in the expression `float(2)` the sub-expression, // `float` would have the type `TypeType(float)`. @@ -389,11 +326,11 @@ protected: END_SYNTAX_CLASS() // The built-in `String` type -SIMPLE_SYNTAX_CLASS(StringType, DeclRefType) +SIMPLE_SYNTAX_CLASS(StringType, BuiltinType) // Base class for types that map down to // simple pointers as part of code generation. -SYNTAX_CLASS(PtrTypeBase, DeclRefType) +SYNTAX_CLASS(PtrTypeBase, BuiltinType) RAW( // Get the type of the pointed-to value. Type* getValueType(); diff --git a/source/slang/type-system-shared.h b/source/slang/type-system-shared.h index 5316dfa6e..61e0ebac7 100644 --- a/source/slang/type-system-shared.h +++ b/source/slang/type-system-shared.h @@ -5,16 +5,22 @@ namespace Slang { +#define FOREACH_BASE_TYPE(X) \ + X(Void) \ + X(Bool) \ + X(Int) \ + X(UInt) \ + X(UInt64) \ + X(Half) \ + X(Float) \ + X(Double) \ +/* end */ + enum class BaseType { - Void = 0, - Bool, - Int, - UInt, - UInt64, - Half, - Float, - Double, +#define DEFINE_BASE_TYPE(NAME) NAME, +FOREACH_BASE_TYPE(DEFINE_BASE_TYPE) +#undef DEFINE_BASE_TYPE }; struct TextureFlavor @@ -22,7 +28,7 @@ namespace Slang enum { // Mask for the overall "shape" of the texture - ShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, + BaseShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, // Flag for whether the shape has "array-ness" ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG, @@ -50,9 +56,17 @@ namespace Slang ShapeCubeArray = ShapeCube | ArrayFlag, }; + enum + { + // This the total number of expressible flavors, + // which is *not* to say that every expressible + // flavor is actual valid. + Count = 0x10000, + }; + uint16_t flavor; - Shape GetBaseShape() const { return Shape(flavor & ShapeMask); } + Shape GetBaseShape() const { return Shape(flavor & BaseShapeMask); } bool isArray() const { return (flavor & ArrayFlag) != 0; } bool isMultisample() const { return (flavor & MultisampleFlag) != 0; } // bool isShadow() const { return (flavor & ShadowFlag) != 0; } diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index d83cda85c..1a277c60c 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -85,9 +85,6 @@ END_SYNTAX_CLASS() ABSTRACT_SYNTAX_CLASS(SubtypeWitness, Witness) FIELD(RefPtr, sub) FIELD(RefPtr, sup) - RAW( - virtual DeclRef getLastStepDeclRef() = 0; - ) END_SYNTAX_CLASS() SYNTAX_CLASS(TypeEqualityWitness, SubtypeWitness) @@ -96,10 +93,6 @@ RAW( virtual String ToString() override; virtual int GetHashCode() override; virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; - virtual DeclRef getLastStepDeclRef() override - { - return DeclRef(); - } ) END_SYNTAX_CLASS() // A witness that one type is a subtype of another @@ -111,10 +104,6 @@ RAW( virtual String ToString() override; virtual int GetHashCode() override; virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; - virtual DeclRef getLastStepDeclRef() override - { - return declRef; - } ) END_SYNTAX_CLASS() @@ -124,31 +113,11 @@ SYNTAX_CLASS(TransitiveSubtypeWitness, SubtypeWitness) FIELD(RefPtr, subToMid); // Witness that `mid : sup` - FIELD(RefPtr, midToSup); + FIELD(DeclRef, midToSup); RAW( virtual bool EqualsVal(Val* val) override; virtual String ToString() override; virtual int GetHashCode() override; virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; - virtual DeclRef getLastStepDeclRef() override - { - return midToSup->getLastStepDeclRef(); - } -) -END_SYNTAX_CLASS() - -// A value that is used as a proxy when we need to -// put an IR-level value into AST types -SYNTAX_CLASS(IRProxyVal, Val) - FIELD(IRUse, inst) -RAW( - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - ~IRProxyVal() override - { - inst.clear(); - } ) END_SYNTAX_CLASS() - diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp index 38083d631..802c8476b 100644 --- a/source/slang/vm.cpp +++ b/source/slang/vm.cpp @@ -257,7 +257,7 @@ VMSizeAlign getVMSymbolSize(BCSymbol* symbol) SLANG_UNEXPECTED("op"); break; - case kIROp_TypeType: + case kIROp_TypeKind: break; case kIROp_Func: @@ -409,16 +409,16 @@ void dumpVMFrame(VMFrame* vmFrame) { switch (regType.impl->op) { - case kIROp_TypeType: + case kIROp_TypeKind: // TODO: we could recursively go and print types... fprintf(stderr, ": Type = ???"); break; - case kIROp_readWriteStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: fprintf(stderr, ": RWStructuredBuffer = ???"); break; - case kIROp_structuredBufferType: + case kIROp_HLSLStructuredBufferType: fprintf(stderr, ": StructuredBuffer = ???"); break; @@ -426,11 +426,11 @@ void dumpVMFrame(VMFrame* vmFrame) fprintf(stderr, ": Bool = %s", *(bool*)regData ? "true" : "false"); break; - case kIROp_Int32Type: + case kIROp_IntType: fprintf(stderr, ": Int32 = %d", *(int32_t*)regData); break; - case kIROp_UInt32Type: + case kIROp_UIntType: fprintf(stderr, ": UInt32 = %u", *(uint32_t*)regData); break; @@ -499,16 +499,16 @@ void computeTypeSizeAlign( size = 1; break; - case kIROp_Int32Type: - case kIROp_UInt32Type: - case kIROp_Float32Type: + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_FloatType: size = 4; break; case kIROp_FuncType: case kIROp_PtrType: - case kIROp_readWriteStructuredBufferType: - case kIROp_structuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + case kIROp_HLSLStructuredBufferType: size = sizeof(void*); break; @@ -632,7 +632,7 @@ void* loadVMSymbol( switch(bcSymbol->op) { - case kIROp_global_var: + case kIROp_GlobalVar: { auto type = getType(vmModule, bcSymbol->typeID); assert(type.impl->op == kIROp_PtrType); @@ -650,7 +650,7 @@ void* loadVMSymbol( } break; - case kIROp_global_constant: + case kIROp_GlobalConstant: { auto type = getType(vmModule, bcSymbol->typeID); void* valPtr = allocate(vm, type); @@ -1094,7 +1094,7 @@ void resumeThread( switch (type.impl->op) { - case kIROp_Int32Type: + case kIROp_IntType: *destPtr = *(int32_t*)leftPtr > *(int32_t*)rightPtr; break; @@ -1116,7 +1116,7 @@ void resumeThread( switch (type.impl->op) { - case kIROp_Int32Type: + case kIROp_IntType: *(int32_t*)destPtr = *(int32_t*)leftPtr * *(int32_t*)rightPtr; break; @@ -1138,7 +1138,7 @@ void resumeThread( switch (type.impl->op) { - case kIROp_Int32Type: + case kIROp_IntType: *(int32_t*)destPtr = *(int32_t*)leftPtr - *(int32_t*)rightPtr; break; -- cgit v1.2.3