diff options
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 114 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 503 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 46 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 42 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 105 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 35 | ||||
| -rw-r--r-- | tests/diagnostics/interface-requirement-not-satisfied.slang.expected | 1 | ||||
| -rw-r--r-- | tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected | 1 | ||||
| -rw-r--r-- | tests/language-feature/parameters/generic-func-param-default-arg.slang | 47 | ||||
| -rw-r--r-- | tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt | 4 |
12 files changed, 770 insertions, 177 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 6697e878a..1921a101f 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -582,6 +582,107 @@ namespace Slang HashCode getHashCode() const; }; + /// An expression together with (optional) substutions to apply to it + /// + /// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + /// `SubstExprBase` exists primarily to provide a non-templated base type + /// for `SubstExpr<T>`. Code should prefer to use `SubstExpr<Expr>` instead + /// of `SubstExprBase` as often as possible. + /// + struct SubstExprBase + { + public: + /// Initialize as a null expression + SubstExprBase() + {} + + /// Initialize as the given `expr` with no subsitutions applied + SubstExprBase(Expr* expr) + : m_expr(expr) + {} + + /// Initialize as the given `expr` with the given `substs` applied + SubstExprBase(Expr* expr, SubstitutionSet const& substs) + : m_expr(expr) + , m_substs(substs) + {} + + /// Get the underlying expression without any substitutions + Expr* getExpr() const { return m_expr; } + + /// Get the subsitutions being applied, if any + SubstitutionSet const& getSubsts() const { return m_substs; } + + private: + Expr* m_expr = nullptr; + SubstitutionSet m_substs; + + typedef void (SubstExprBase::*SafeBool)(); + void SafeBoolTrue() {} + + public: + /// Test whether this is a non-null expression + operator SafeBool() + { + return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; + } + + /// Test whether this is a null expression + bool operator!() const { return m_expr == nullptr; } + + }; + + /// An expression together with (optional) substutions to apply to it + /// + /// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + template<typename T> + struct SubstExpr : SubstExprBase + { + private: + typedef SubstExprBase Super; + + public: + /// Initialize as a null expression + SubstExpr() + {} + + /// Initialize as the given `expr` with no subsitutions applied + SubstExpr(T* expr) + : Super(expr) + {} + + /// Initialize as the given `expr` with the given `substs` applied + SubstExpr(T* expr, SubstitutionSet const& substs) + : Super(expr, substs) + {} + + /// Initialize as a copy of the given `other` expression + template <typename U> + SubstExpr(SubstExpr<U> const& other, + typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : Super(other.getExpr(), other.getSubsts()) + { + } + + /// Get the underlying expression without any substitutions + T* getExpr() const { return (T*) Super::getExpr(); } + + /// Dynamic cast to an expression of type `U` + /// + /// Returns a null expression if the cast fails, or if this expression was null. + template<typename U> + SubstExpr<U> as() + { + return SubstExpr<U>(Slang::as<U>(getExpr()), getSubsts()); + } + }; + class ASTBuilder; template<typename T> @@ -623,10 +724,10 @@ namespace Slang DeclRefBase substitute(ASTBuilder* astBuilder, DeclRefBase declRef) const; // Apply substitutions to an expression - Expr* substitute(ASTBuilder* astBuilder, Expr* expr) const; + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; // Apply substitutions to this declaration reference - DeclRefBase substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRefBase substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; // Returns true if 'as' will return a valid cast template <typename T> @@ -698,7 +799,8 @@ namespace Slang { return DeclRefBase::substitute(astBuilder, type); } - Expr* substitute(ASTBuilder* astBuilder, Expr* expr) const + + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const { return DeclRefBase::substitute(astBuilder, expr); } @@ -711,7 +813,7 @@ namespace Slang } // Apply substitutions to this declaration reference - DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) + DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const { return DeclRef<T>::unsafeInit(DeclRefBase::substituteImpl(astBuilder, subst, ioDiff)); } @@ -722,6 +824,10 @@ namespace Slang } }; + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); + DeclRef<Decl> substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef<Decl> const& declRef); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase& declRef) { declRef.toText(io); return io; } template<typename T> diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a1c8369aa..726781fb0 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1374,8 +1374,33 @@ namespace Slang return false; } - // TODO: actually implement matching here. For now we'll - // just pretend that things are satisfied in order to make progress.. + // A signature matches the required one if it has the right number of parameters, + // and those parameters have the right types, and also the result/return type + // is the required one. + // + auto requiredParams = getParameters(requiredMemberDeclRef).toArray(); + auto satisfyingParams = getParameters(satisfyingMemberDeclRef).toArray(); + auto paramCount = requiredParams.getCount(); + if(satisfyingParams.getCount() != paramCount) + return false; + + for(Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) + { + auto requiredParam = requiredParams[paramIndex]; + auto satisfyingParam = satisfyingParams[paramIndex]; + + auto requiredParamType = getType(m_astBuilder, requiredParam); + auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); + + if(!requiredParamType->equals(satisfyingParamType)) + return false; + } + + auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); + auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); + if(!requiredResultType->equals(satisfyingResultType)) + return false; + witnessTable->add( requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); @@ -1491,58 +1516,234 @@ namespace Slang bool SemanticsVisitor::doesGenericSignatureMatchRequirement( - DeclRef<GenericDecl> genDecl, - DeclRef<GenericDecl> requirementGenDecl, + DeclRef<GenericDecl> satisfyingGenericDeclRef, + DeclRef<GenericDecl> requiredGenericDeclRef, RefPtr<WitnessTable> witnessTable) { - if (genDecl.getDecl()->members.getCount() != requirementGenDecl.getDecl()->members.getCount()) + // The signature of a generic is defiend by its members, and we need the + // satisfying value to have the same number of members for it to be an + // exact match. + // + auto memberCount = requiredGenericDeclRef.getDecl()->members.getCount(); + if(satisfyingGenericDeclRef.getDecl()->members.getCount() != memberCount) return false; - for (Index i = 0; i < genDecl.getDecl()->members.getCount(); i++) + + // We then want to check that pairwise members match, in order. + // + auto requiredMemberDeclRefs = getMembers(requiredGenericDeclRef); + auto satisfyingMemberDeclRefs = getMembers(satisfyingGenericDeclRef); + // + // We start by performing a superficial "structural" match of the parameters + // to ensure that the two generics have an equivalent mix of type, value, + // and constraint parameters in the same order. + // + // Note that in this step we do *not* make any checks on the actual types + // involved in constraints, or on the types of value parameters. The reason + // for this is that the types on those parameters could be dependent on + // type parameters in the generic parameter list, and thus there could be + // a mismatch at this point. For example, if we have: + // + // interface IBase { void doThing<T, U : IThing<T>>(); } + // struct Derived : IBase { void doThing<X, Y : IThing<X>>(); } + // + // We clearly have a signature match here, but the constraint parameters for + // `U : IThing<T>` and `Y : IThing<X>` have the problem that both the sub-type + // and super-type they reference are not equivalent without substititions. + // + // We will deal with this issue after the structural matching is checked, at + // which point we can actually verify things like types. + // + for (Index i = 0; i < memberCount; i++) { - auto genMbr = genDecl.getDecl()->members[i]; - auto requiredGenMbr = genDecl.getDecl()->members[i]; - if (auto genTypeMbr = as<GenericTypeParamDecl>(genMbr)) + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if (as<GenericTypeParamDecl>(requiredMemberDeclRef)) { - if (auto requiredGenTypeMbr = as<GenericTypeParamDecl>(requiredGenMbr)) + if (as<GenericTypeParamDecl>(satisfyingMemberDeclRef)) { } else return false; } - else if (auto genValMbr = as<GenericValueParamDecl>(genMbr)) + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) { - if (auto requiredGenValMbr = as<GenericValueParamDecl>(requiredGenMbr)) + if (auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>()) { - if (!genValMbr->type->equals(requiredGenValMbr->type)) - return false; } else return false; } - else if (auto genTypeConstraintMbr = as<GenericTypeConstraintDecl>(genMbr)) + else if (auto requiredConstraintDeclRef = requiredMemberDeclRef.as<GenericTypeConstraintDecl>()) { - if (auto requiredTypeConstraintMbr = as<GenericTypeConstraintDecl>(requiredGenMbr)) + if (auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>()) { - if (!genTypeConstraintMbr->sup->equals(requiredTypeConstraintMbr->sup)) - { - return false; - } } else return false; } } - // TODO: this isn't right, because we need to specialize the - // declarations of the generics to a common set of substitutions, - // so that their types are comparable (e.g., foo<T> and foo<U> - // need to have substitutions applies so that they are both foo<X>, - // after which uses of the type X in their parameter lists can - // be compared). + // In order to compare the inner declarations of the two generics, we need to + // align them so that they are expressed in terms of consistent type parameters. + // + // For example, we might have: + // + // interface IBase { void doThing<T>(T val); } + // struct Derived : IBase { void doThing<U>(U val); } + // + // If we directly compare the signatures of the inner `doThing` function declarations, + // we'd find a mismatch between the `T` and `U` types of the `val` parameter. + // + // We can get around this mismatch by constructing a specialized reference and + // then doing the comparison. For example `IBase::doThing<X>` and `Derived::doThing<X>` + // should both have the signature `X -> void`. + // + // The one big detail that we need to be careful about here is that when we + // recursively call `doesMemberSatisfyRequirement`, that will eventually store + // the satisfying `DeclRef` as the value for the given requirement key, and we don't + // want to store a specialized reference like `Derived::doThing<X>` - we need to + // somehow store the original declaration. + // + // The solution here is to specialize the *required* declaration to the parameters + // of the satisfying declaration. In the example above that means we are going to + // compare `Derived::doThing` against `IBase::doThing<U>` where the `U` there is + // the parameter of `Dervived::doThing`. + // + GenericSubstitution* requiredSubst = m_astBuilder->create<GenericSubstitution>(); + requiredSubst->genericDecl = requiredGenericDeclRef.getDecl(); + requiredSubst->outer = requiredGenericDeclRef.substitutions; + + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as<GenericTypeParamDecl>()) + { + auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as<GenericTypeParamDecl>(); + SLANG_ASSERT(satisfyingTypeParamDeclRef); + auto satisfyingType = DeclRefType::create(m_astBuilder, satisfyingTypeParamDeclRef); + + requiredSubst->args.add(satisfyingType); + } + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) + { + auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>(); + SLANG_ASSERT(satisfyingValueParamDeclRef); + auto satisfyingVal = m_astBuilder->create<GenericParamIntVal>(); + satisfyingVal->declRef = satisfyingValueParamDeclRef; + + requiredSubst->args.add(satisfyingVal); + } + } + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if(auto requiredConstraintDeclRef = requiredMemberDeclRef.as<GenericTypeConstraintDecl>()) + { + auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>(); + SLANG_ASSERT(satisfyingConstraintDeclRef); + + auto satisfyingWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + satisfyingWitness->sub = getSub(m_astBuilder, satisfyingConstraintDeclRef); + satisfyingWitness->sup = getSup(m_astBuilder, satisfyingConstraintDeclRef); + satisfyingWitness->declRef = satisfyingConstraintDeclRef; + + requiredSubst->args.add(satisfyingWitness); + } + } + + // Now that we have computed a set of specialization arguments that will + // specialize the generic requirement at the type parameters of the satisfying + // generic, we can construct a reference to that declaration and re-run some + // of the earlier checking logic with more type information usable. + // + auto specializedRequiredGenericDeclRef = DeclRef<GenericDecl>(requiredGenericDeclRef.getDecl(), requiredSubst); + auto specializedRequiredMemberDeclRefs = getMembers(specializedRequiredGenericDeclRef); + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = specializedRequiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as<GenericTypeParamDecl>()) + { + auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as<GenericTypeParamDecl>(); + SLANG_ASSERT(satisfyingTypeParamDeclRef); + + // There are no additional checks we need to make on plain old + // type parameters at this point. + // + // TODO: If we ever support having type parameters of higher kinds, + // then this is possibly where we'd want to check that the kinds of + // the two parameters match. + // + SLANG_UNUSED(satisfyingGenericDeclRef); + } + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) + { + auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>(); + SLANG_ASSERT(satisfyingValueParamDeclRef); + + // For a generic value parameter, we need to check that the required + // and satisfying declaration both agree on the type of the parameter. + // + auto requiredParamType = getType(m_astBuilder, requiredValueParamDeclRef); + auto satisfyingParamType = getType(m_astBuilder, satisfyingValueParamDeclRef); + if (!satisfyingParamType->equals(requiredParamType)) + return false; + } + else if(auto requiredConstraintDeclRef = requiredMemberDeclRef.as<GenericTypeConstraintDecl>()) + { + auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>(); + SLANG_ASSERT(satisfyingConstraintDeclRef); + + // For a generic constraint parameter, we need to check that the sub-type + // and super-type in the constraint both match. + // + // In current code the sub type will always be one of the generic type parameters, + // and the super-type will always be an interface, but there should be no + // need to make use of those additional details here. + + auto requiredSubType = getSub(m_astBuilder, requiredConstraintDeclRef); + auto satisfyingSubType = getSub(m_astBuilder, satisfyingConstraintDeclRef); + if (!satisfyingSubType->equals(requiredSubType)) + return false; + + auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); + auto satisfyingSuperType = getSup(m_astBuilder, satisfyingConstraintDeclRef); + if (!satisfyingSuperType->equals(requiredSuperType)) + return false; + } + } + + // Note: the above logic really only applies to the case of an exact match on signature, + // even down to the way that constraints were declared. We could potentially be more + // relaxed by taking advantage of the way that various different generic signatures will + // actually lower to the same IR generic signature. + // + // In theory, all we really care about when it comes to constraints is that the constraints + // on the required and satisfying declaration are *equivalent*. + // + // More generally, a satisfying generic could actually provide *looser* constraints and + // still work; all that matters is that it can be instantiated at any argument values/types + // that are valid for the requirement. + // + // We leave both of those issues up to the synthesis path: if we do not find a member that + // provides an exact match, then the compiler should try to synthesize one that is an exact + // match and makes use of existing declarations that might have require defaulting of arguments + // or type conversations to fit. + + // Once we've validated that the generic signatures are in an exact match, and devised type + // arguments for the requirement to make the two align, we can recursively check the inner + // declaration (whatever it is) for an exact match. + // return doesMemberSatisfyRequirement( - DeclRef<Decl>(genDecl.getDecl()->inner, genDecl.substitutions), - DeclRef<Decl>(requirementGenDecl.getDecl()->inner, requirementGenDecl.substitutions), + DeclRef<Decl>(satisfyingGenericDeclRef.getDecl()->inner, satisfyingGenericDeclRef.substitutions), + DeclRef<Decl>(requiredGenericDeclRef.getDecl()->inner, requiredSubst), witnessTable); } @@ -2375,13 +2576,15 @@ namespace Slang bool SemanticsVisitor::findWitnessForInterfaceRequirement( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<InterfaceDecl> superInterfaceDeclRef, DeclRef<Decl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) + RefPtr<WitnessTable> witnessTable, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness) { - SLANG_UNUSED(interfaceDeclRef) + SLANG_UNUSED(superInterfaceDeclRef) // The goal of this function is to find a suitable // value to satisfy the requirement. @@ -2415,18 +2618,40 @@ namespace Slang // // TODO: we *really* need a linearization step here!!!! - RefPtr<WitnessTable> satisfyingWitnessTable = checkConformanceToType( - context, - type, - requiredInheritanceDeclRef.getDecl(), - getBaseType(m_astBuilder, requiredInheritanceDeclRef)); + auto reqType = getBaseType(m_astBuilder, requiredInheritanceDeclRef); - if(!satisfyingWitnessTable) - return false; + DeclaredSubtypeWitness* interfaceIsReqWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + interfaceIsReqWitness->sub = superInterfaceType; + interfaceIsReqWitness->sup = reqType; + interfaceIsReqWitness->declRef = requiredInheritanceDeclRef; + // ... + + TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->create<TransitiveSubtypeWitness>(); + subIsReqWitness->sub = subType; + subIsReqWitness->sup = reqType; + subIsReqWitness->subToMid = subTypeConformsToSuperInterfaceWitness; + subIsReqWitness->midToSup = interfaceIsReqWitness; + // ... + + RefPtr<WitnessTable> satisfyingWitnessTable = new WitnessTable(); + satisfyingWitnessTable->witnessedType = subType; + satisfyingWitnessTable->baseType = reqType; witnessTable->add( requiredInheritanceDeclRef.getDecl(), RequirementWitness(satisfyingWitnessTable)); + + if( !checkConformanceToType( + context, + subType, + requiredInheritanceDeclRef.getDecl(), + reqType, + subIsReqWitness, + satisfyingWitnessTable) ) + { + return false; + } + return true; } @@ -2465,7 +2690,7 @@ namespace Slang // requests will be handled further down. For now we include // lookup results that might be usable, but not as-is. // - auto lookupResult = lookUpMember(m_astBuilder, this, name, type, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); + auto lookupResult = lookUpMember(m_astBuilder, this, name, subType, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); if(!lookupResult.isValid()) { @@ -2478,7 +2703,8 @@ namespace Slang // signatures of methods, as is done for Swift), we'd // need to revisit this step. // - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, type, requiredMemberDeclRef); + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); return false; } @@ -2521,26 +2747,29 @@ namespace Slang // and if nothing is found we print the candidates that made it // furthest in checking. // - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, type, requiredMemberDeclRef); + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); return false; } RefPtr<WitnessTable> SemanticsVisitor::checkInterfaceConformance( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef) + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitnes) { // Has somebody already checked this conformance, // and/or is in the middle of checking it? RefPtr<WitnessTable> witnessTable; - if(context->mapInterfaceToWitnessTable.TryGetValue(interfaceDeclRef, witnessTable)) + if(context->mapInterfaceToWitnessTable.TryGetValue(superInterfaceDeclRef, witnessTable)) return witnessTable; // We need to check the declaration of the interface // before we can check that we conform to it. // - ensureDecl(interfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); // We will construct the witness table, and register it // *before* we go about checking fine-grained requirements, @@ -2554,10 +2783,53 @@ namespace Slang if(!witnessTable) { witnessTable = new WitnessTable(); - witnessTable->baseType = DeclRefType::create(m_astBuilder, interfaceDeclRef); - witnessTable->witnessedType = type; + witnessTable->baseType = DeclRefType::create(m_astBuilder, superInterfaceDeclRef); + witnessTable->witnessedType = subType; } - context->mapInterfaceToWitnessTable.Add(interfaceDeclRef, witnessTable); + context->mapInterfaceToWitnessTable.Add(superInterfaceDeclRef, witnessTable); + + if(!checkInterfaceConformance(context, subType, superInterfaceType, inheritanceDecl, superInterfaceDeclRef, subTypeConformsToSuperInterfaceWitnes, witnessTable)) + return nullptr; + + return witnessTable; + } + + static bool isAssociatedTypeDecl(Decl* decl) + { + auto d = decl; + while(auto genericDecl = as<GenericDecl>(d)) + d = genericDecl->inner; + if(as<AssocTypeDecl>(d)) + return true; + return false; + } + + bool SemanticsVisitor::checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness, + WitnessTable* witnessTable) + { + // We need to check the declaration of the interface + // before we can check that we conform to it. + // + ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + + // When comparing things like signatures, we need to do so in the context + // of a this-type substitution that aligns the signatures in the interface + // with those in the concrete type. For example, we need to treat any uses + // of `This` in the interface as equivalent to the concrete type for the + // purpose of signature matching (and similarly for associated types). + // + ThisTypeSubstitution* thisTypeSubst = m_astBuilder->create<ThisTypeSubstitution>(); + thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); + thisTypeSubst->witness = subTypeConformsToSuperInterfaceWitness; + thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; + + auto specializedSuperInterfaceDeclRef = DeclRef<InterfaceDecl>(superInterfaceDeclRef.getDecl(), thisTypeSubst); bool result = true; @@ -2567,15 +2839,59 @@ namespace Slang // its (non-interface) base types already conforms to // that interface, so that all of the requirements are // already satisfied with inherited implementations... - for(auto requiredMemberDeclRef : getMembers(interfaceDeclRef)) + + // Note: we break this logic into two loops, where we first + // check conformance for all associated-type requirements + // and *then* check conformance for all other requirements. + // + // Checking associated-type requirements first ensures that + // we can make use of the identity of the associated types + // when checking other members. + // + // TODO: There could in theory be subtle cases involving + // circular or recursive dependency chains that make such + // a simple ordering impractical (e.g., associated type `A` + // is constrained to `IThing<This>` where `IThing<T>` requires + // that `T : IOtherThing where T.B == int` for another associated + // type `B`). + // + // The only robust solution long-term is probably to treat this + // as a type-inference problem by creating type variables to + // stand in for the associated-type requirements and then to discover + // constraints and solve for those type variables as part of the + // conformance-checking process. + // + for(auto requiredMemberDeclRef : getMembers(specializedSuperInterfaceDeclRef)) { + if(!isAssociatedTypeDecl(requiredMemberDeclRef)) + continue; + auto requirementSatisfied = findWitnessForInterfaceRequirement( context, - type, + subType, + superInterfaceType, inheritanceDecl, - interfaceDeclRef, + specializedSuperInterfaceDeclRef, requiredMemberDeclRef, - witnessTable); + witnessTable, + subTypeConformsToSuperInterfaceWitness); + + result = result && requirementSatisfied; + } + for(auto requiredMemberDeclRef : getMembers(specializedSuperInterfaceDeclRef)) + { + if(isAssociatedTypeDecl(requiredMemberDeclRef)) + continue; + + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + subType, + superInterfaceType, + inheritanceDecl, + specializedSuperInterfaceDeclRef, + requiredMemberDeclRef, + witnessTable, + subTypeConformsToSuperInterfaceWitness); result = result && requirementSatisfied; } @@ -2604,14 +2920,12 @@ namespace Slang // the time we are compiling and handle those, and punt on the larger issue // for a bit longer. // - for(auto candidateExt : getCandidateExtensions(interfaceDeclRef, this)) + for(auto candidateExt : getCandidateExtensions(specializedSuperInterfaceDeclRef, this)) { // We need to apply the extension to the interface type that our // concrete type is inheriting from. // - // TODO: need to decide if a this-type substitution is needed here. - // It probably it. - Type* targetType = DeclRefType::create(m_astBuilder, interfaceDeclRef); + Type* targetType = DeclRefType::create(m_astBuilder, specializedSuperInterfaceDeclRef); auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); if(!extDeclRef) continue; @@ -2621,65 +2935,66 @@ namespace Slang { auto requirementSatisfied = findWitnessForInterfaceRequirement( context, - type, + subType, + superInterfaceType, inheritanceDecl, - interfaceDeclRef, + specializedSuperInterfaceDeclRef, requiredInheritanceDeclRef, - witnessTable); + witnessTable, + subTypeConformsToSuperInterfaceWitness); result = result && requirementSatisfied; } } - // If we failed to satisfy any requirements along the way, - // then we don't actually want to keep the witness table - // we've been constructing, because the whole thing was a failure. - if(!result) - { - return nullptr; - } - - return witnessTable; + // The conformance was satisfied if all the requirements were satisfied. + // + return result; } - RefPtr<WitnessTable> SemanticsVisitor::checkConformanceToType( + bool SemanticsVisitor::checkConformanceToType( ConformanceCheckingContext* context, - Type* type, + Type* subType, InheritanceDecl* inheritanceDecl, - Type* baseType) + Type* superType, + SubtypeWitness* subIsSuperWitness, + WitnessTable* witnessTable) { - if (auto baseDeclRefType = as<DeclRefType>(baseType)) + if (auto supereclRefType = as<DeclRefType>(superType)) { - auto baseTypeDeclRef = baseDeclRefType->declRef; - if (auto baseInterfaceDeclRef = baseTypeDeclRef.as<InterfaceDecl>()) + auto superTypeDeclRef = supereclRefType->declRef; + if (auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>()) { // The type is stating that it conforms to an interface. // We need to check that it provides all of the members // required by that interface. return checkInterfaceConformance( context, - type, + subType, + superType, inheritanceDecl, - baseInterfaceDeclRef); + superInterfaceDeclRef, + subIsSuperWitness, + witnessTable); } - else if( auto structDeclRef = baseTypeDeclRef.as<StructDecl>() ) + else if( auto superStructDeclRef = superTypeDeclRef.as<StructDecl>() ) { // The type is saying it inherits from a `struct`, // which doesn't require any checking at present - return nullptr; + return true; } } getSink()->diagnose(inheritanceDecl, Diagnostics::unimplemented, "type not supported for inheritance"); - return nullptr; + return false; } bool SemanticsVisitor::checkConformance( - Type* type, + Type* subType, InheritanceDecl* inheritanceDecl, ContainerDecl* parentDecl) { - if( auto declRefType = as<DeclRefType>(type) ) + if( auto declRefType = as<DeclRefType>(subType) ) { auto declRef = declRefType->declRef; @@ -2709,16 +3024,32 @@ namespace Slang // Look at the type being inherited from, and validate // appropriately. - auto baseType = inheritanceDecl->base.type; + auto superType = inheritanceDecl->base.type; + + DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + subIsSuperWitness->declRef = makeDeclRef(inheritanceDecl); + subIsSuperWitness->sub = subType; + subIsSuperWitness->sup = superType; ConformanceCheckingContext context; - context.conformingType = type; + context.conformingType = subType; context.parentDecl = parentDecl; - RefPtr<WitnessTable> witnessTable = checkConformanceToType(&context, type, inheritanceDecl, baseType); + + + RefPtr<WitnessTable> witnessTable = inheritanceDecl->witnessTable; if(!witnessTable) + { + witnessTable = new WitnessTable(); + witnessTable->baseType = superType; + witnessTable->witnessedType = subType; + inheritanceDecl->witnessTable = witnessTable; + } + + if( !checkConformanceToType(&context, subType, inheritanceDecl, superType, subIsSuperWitness, witnessTable) ) + { return false; + } - inheritanceDecl->witnessTable = witnessTable; return true; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5c00cb64b..fbbfcd473 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -694,7 +694,7 @@ namespace Slang } IntVal* SemanticsVisitor::tryConstantFoldExpr( - InvokeExpr* invokeExpr, + SubstExpr<InvokeExpr> invokeExpr, ConstantFoldingCircularityInfo* circularityInfo) { // We need all the operands to the expression @@ -703,10 +703,10 @@ namespace Slang // // For right now we will look for calls to intrinsic functions, and then inspect // their names (this is bad and slow). - auto funcDeclRefExpr = as<DeclRefExpr>(invokeExpr->functionExpr); + auto funcDeclRefExpr = getBaseExpr(invokeExpr).as<DeclRefExpr>(); if (!funcDeclRefExpr) return nullptr; - auto funcDeclRef = funcDeclRefExpr->declRef; + auto funcDeclRef = getDeclRef(m_astBuilder, funcDeclRefExpr); auto intrinsicMod = funcDeclRef.getDecl()->findModifier<IntrinsicOpModifier>(); if (!intrinsicMod) { @@ -722,31 +722,31 @@ namespace Slang // Let's not constant-fold operations with more than a certain number of arguments, for simplicity static const int kMaxArgs = 8; - if (invokeExpr->arguments.getCount() > kMaxArgs) + auto argCount = getArgCount(invokeExpr); + if (argCount > kMaxArgs) return nullptr; // Before checking the operation name, let's look at the arguments IntVal* argVals[kMaxArgs]; IntegerLiteralValue constArgVals[kMaxArgs]; - int argCount = 0; bool allConst = true; - for (auto argExpr : invokeExpr->arguments) + for(Index a = 0; a < argCount; ++a) { + auto argExpr = getArg(invokeExpr, a); auto argVal = tryFoldIntegerConstantExpression(argExpr, circularityInfo); if (!argVal) return nullptr; - argVals[argCount] = argVal; + argVals[a] = argVal; if (auto constArgVal = as<ConstantIntVal>(argVal)) { - constArgVals[argCount] = constArgVal->value; + constArgVals[a] = constArgVal->value; } else { allConst = false; } - argCount++; } if (!allConst) @@ -866,25 +866,25 @@ namespace Slang } IntVal* SemanticsVisitor::tryConstantFoldExpr( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo) { // Unwrap any "identity" expressions - while (auto parenExpr = as<ParenExpr>(expr)) + while (auto parenExpr = expr.as<ParenExpr>()) { - expr = parenExpr->base; + expr = getBaseExpr(parenExpr); } // TODO(tfoley): more serious constant folding here - if (auto intLitExpr = as<IntegerLiteralExpr>(expr)) + if (auto intLitExpr = expr.as<IntegerLiteralExpr>()) { return getIntVal(intLitExpr); } // it is possible that we are referring to a generic value param - if (auto declRefExpr = as<DeclRefExpr>(expr)) + if (auto declRefExpr = expr.as<DeclRefExpr>()) { - auto declRef = declRefExpr->declRef; + auto declRef = getDeclRef(m_astBuilder, declRefExpr); if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { @@ -913,13 +913,13 @@ namespace Slang } } - if(auto castExpr = as<TypeCastExpr>(expr)) + if(auto castExpr = expr.as<TypeCastExpr>()) { - auto val = tryConstantFoldExpr(castExpr->arguments[0], circularityInfo); + auto val = tryConstantFoldExpr(getArg(castExpr, 0), circularityInfo); if(val) return val; } - else if (auto invokeExpr = as<InvokeExpr>(expr)) + else if (auto invokeExpr = expr.as<InvokeExpr>()) { auto val = tryConstantFoldExpr(invokeExpr, circularityInfo); if (val) @@ -930,12 +930,12 @@ namespace Slang } IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo) { // Check if type is acceptable for an integer constant expression // - if(!isScalarIntegerType(expr->type)) + if(!isScalarIntegerType(getType(m_astBuilder, expr))) return nullptr; // Consider operations that we might be able to constant-fold... @@ -2037,7 +2037,11 @@ namespace Slang { auto containerDecl = scope->containerDecl; - if( auto setterDecl = as<SetterDecl>(containerDecl) ) + if( auto ctorDecl = as<ConstructorDecl>(containerDecl) ) + { + expr->type.isLeftValue = true; + } + else if( auto setterDecl = as<SetterDecl>(containerDecl) ) { expr->type.isLeftValue = true; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ad9de3f4b..bc265b1b4 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -830,27 +830,42 @@ namespace Slang // bool findWitnessForInterfaceRequirement( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<InterfaceDecl> superInterfaceDeclRef, DeclRef<Decl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable); + RefPtr<WitnessTable> witnessTable, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness); // Check that the type declaration `typeDecl`, which // declares conformance to the interface `interfaceDeclRef`, // (via the given `inheritanceDecl`) actually provides // members to satisfy all the requirements in the interface. + bool checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness, + WitnessTable* witnessTable); + RefPtr<WitnessTable> checkInterfaceConformance( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef); + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness); - RefPtr<WitnessTable> checkConformanceToType( + bool checkConformanceToType( ConformanceCheckingContext* context, - Type* type, + Type* subType, InheritanceDecl* inheritanceDecl, - Type* baseType); + Type* superType, + SubtypeWitness* subIsSuperWitness, + WitnessTable* witnessTable); /// Check that `type` which has declared that it inherits from (and/or implements) /// another type via `inheritanceDecl` actually does what it needs to for that @@ -913,6 +928,11 @@ namespace Slang IntVal* getIntVal(IntegerLiteralExpr* expr); + inline IntVal* getIntVal(SubstExpr<IntegerLiteralExpr> expr) + { + return getIntVal(expr.getExpr()); + } + Name* getName(String const& text) { return getNamePool()->getName(text); @@ -938,12 +958,12 @@ namespace Slang /// Try to apply front-end constant folding to determine the value of `invokeExpr`. IntVal* tryConstantFoldExpr( - InvokeExpr* invokeExpr, + SubstExpr<InvokeExpr> invokeExpr, ConstantFoldingCircularityInfo* circularityInfo); /// Try to apply front-end constant folding to determine the value of `expr`. IntVal* tryConstantFoldExpr( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo); bool _checkForCircularityInConstantFolding( @@ -960,7 +980,7 @@ namespace Slang /// as an integer constant. /// IntVal* tryFoldIntegerConstantExpression( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo); // Enforce that an expression resolves to an integer constant, and get its value diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index c4e4cbde4..448cdeb88 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1051,19 +1051,19 @@ namespace Slang Index argCount = context.getArgCount(); Index paramCount = params.getCount(); - // Bail out on mismatch. - // TODO(tfoley): need more nuance here - if (argCount != paramCount) + // If there are too many arguments, we cannot possibly have a match. + // + // Note that if there are *too few* arguments, we might still have + // a match, because the other arguments might have default values + // that can be used. + // + if (argCount > paramCount) { return DeclRef<Decl>(nullptr, nullptr); } for (Index aa = 0; aa < argCount; ++aa) { -#if 0 - if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) - return DeclRef<Decl>(nullptr, nullptr); -#else // The question here is whether failure to "unify" an argument // and parameter should lead to immediate failure. // @@ -1083,7 +1083,6 @@ namespace Slang // unification step should be taken as an immediate failure... TryUnifyTypes(constraints, context.getArgTypeForInference(aa, this), getType(m_astBuilder, params[aa])); -#endif } } else diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index fb9fc70fd..3c9256178 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -633,7 +633,7 @@ LoweredValInfo emitCallToDeclRef( if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() ) { - if(!ctorDeclRef.getDecl()->body) + if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.decl)) { // HACK: For legacy reasons, all of the built-in initializers // in the standard library are declared without proper @@ -1114,7 +1114,6 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy } } -SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst); // struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo> @@ -3141,6 +3140,45 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + void _lowerSubstitutionArg(IRGenContext* subContext, GenericSubstitution* subst, Decl* paramDecl, Index argIndex) + { + SLANG_ASSERT(argIndex < subst->args.getCount()); + auto argVal = lowerVal(subContext, subst->args[argIndex]); + setValue(subContext, paramDecl, argVal); + } + + void _lowerSubstitutionEnv(IRGenContext* subContext, Substitutions* subst) + { + if(!subst) return; + _lowerSubstitutionEnv(subContext, subst->outer); + + if (auto genSubst = as<GenericSubstitution>(subst)) + { + auto genDecl = genSubst->genericDecl; + + Index argCounter = 0; + for( auto memberDecl: genDecl->members ) + { + if(auto typeParamDecl = as<GenericTypeParamDecl>(memberDecl) ) + { + _lowerSubstitutionArg(subContext, genSubst, typeParamDecl, argCounter++); + } + else if( auto valParamDecl = as<GenericValueParamDecl>(memberDecl) ) + { + _lowerSubstitutionArg(subContext, genSubst, valParamDecl, argCounter++); + } + } + for( auto memberDecl: genDecl->members ) + { + if(auto constraintDecl = as<GenericTypeConstraintDecl>(memberDecl) ) + { + _lowerSubstitutionArg(subContext, genSubst, constraintDecl, argCounter++); + } + } + } + // TODO: also need to handle this-type substitution here? + } + void addDirectCallArgs( InvokeExpr* expr, DeclRef<CallableDecl> funcDeclRef, @@ -3156,10 +3194,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> auto paramDirection = getParameterDirection(paramDecl); UInt argIndex = argCounter++; - Expr* argExpr = nullptr; if(argIndex < argCount) { - argExpr = expr->arguments[argIndex]; + auto argExpr = expr->arguments[argIndex]; + addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); } else { @@ -3167,11 +3205,31 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // but there are still parameters remaining. This must mean // that these parameters have default argument expressions // associated with them. - argExpr = getInitExpr(getASTBuilder(), paramDeclRef); - - // Assert that such an expression must have been present. + // + // Currently we simply extract the initial-value expression + // from the parameter declaration and then lower it in + // the context of the caller. + // + // Note that the expression could involve subsitutions because + // in the general case it could depend on the generic parameters + // used the specialize the callee. For now we do not handle that + // case, and simply ignore generic arguments. + // + SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef); SLANG_ASSERT(argExpr); + IRGenEnv subEnvStorage; + IRGenEnv* subEnv = &subEnvStorage; + subEnv->outer = context->env; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->env = subEnv; + + _lowerSubstitutionEnv(subContext, argExpr.getSubsts()); + + addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); + // TODO: The approach we are taking here to default arguments // is simplistic, and has consequences for the front-end as // well as binary serialization of modules. @@ -3186,9 +3244,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // // Each of these options involves trade-offs, and we need to // make a conscious decision at some point. - } - addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); + // Assert that such an expression must have been present. + } } } @@ -7347,35 +7405,6 @@ LoweredValInfo ensureDecl( return result; } -IRInst* lowerSubstitutionArg( - IRGenContext* context, - Val* val) -{ - if (auto type = dynamicCast<Type>(val)) - { - return lowerType(context, type); - } - else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val)) - { - // We need to look up the IR-level representation of the witness (which will be a witness table). - auto supType = lowerType( - context, - DeclRefType::create(context->astBuilder, declaredSubtypeWitness->declRef)); - auto irWitnessTable = getSimpleVal( - context, - emitDeclRef( - context, - declaredSubtypeWitness->declRef, - context->irBuilder->getWitnessTableType(supType))); - return irWitnessTable; - } - else - { - SLANG_UNIMPLEMENTED_X("value cases"); - UNREACHABLE_RETURN(nullptr); - } -} - // Can the IR lowered version of this declaration ever be an `IRGeneric`? bool canDeclLowerToAGeneric(Decl* decl) { diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 62de6599c..accbdb27e 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -641,17 +641,37 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return declRef.substituteImpl(astBuilder, substitutions, &diff); } - Expr* DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const + SubstExpr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const { - // No substitutions? Easy. - if (!substitutions) - return expr; + return SubstExpr<Expr>(expr, substitutions); + } + + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr) + { + return SubstExpr<Expr>(expr, substs); + } + + DeclRef<Decl> substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef<Decl> const& declRef) + { + if(!substs) + return declRef; + + int diff = 0; + auto declRefBase = declRef.substituteImpl(astBuilder, substs, &diff); + return DeclRef<Decl>(declRefBase.decl, declRefBase.substitutions); + } - SLANG_UNIMPLEMENTED_X("generic substitution into expressions"); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type) + { + if(!type) return nullptr; + if(!substs) return type; + + SLANG_ASSERT(type); - UNREACHABLE_RETURN(expr); + return Slang::as<Type>(type->substitute(astBuilder, substs)); } + void buildMemberDictionary(ContainerDecl* decl); InterfaceDecl* findOuterInterfaceDecl(Decl* decl) @@ -854,7 +874,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - DeclRefBase DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) + DeclRefBase DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const { // Nothing to do when we have no declaration. if(!decl) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index a0c54f914..9589f00fd 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -106,7 +106,7 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline Expr* getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) + inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->initExpr); } @@ -121,7 +121,7 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline Expr* getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) + inline SubstExpr<Expr> getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->tagExpr); } @@ -165,6 +165,37 @@ namespace Slang return declRef.getDecl()->inner; } + // + + inline Type* getType(ASTBuilder* astBuilder, SubstExpr<Expr> expr) + { + return substituteType(expr.getSubsts(), astBuilder, expr.getExpr()->type); + } + + inline SubstExpr<Expr> getBaseExpr(SubstExpr<ParenExpr> expr) + { + return substituteExpr(expr.getSubsts(), expr.getExpr()->base); + } + + inline SubstExpr<Expr> getBaseExpr(SubstExpr<InvokeExpr> expr) + { + return substituteExpr(expr.getSubsts(), expr.getExpr()->functionExpr); + } + + inline Index getArgCount(SubstExpr<InvokeExpr> expr) + { + return expr.getExpr()->arguments.getCount(); + } + + inline SubstExpr<Expr> getArg(SubstExpr<InvokeExpr> expr, Index index) + { + return substituteExpr(expr.getSubsts(), expr.getExpr()->arguments[index]); + } + + inline DeclRef<Decl> getDeclRef(ASTBuilder* astBuilder, SubstExpr<DeclRefExpr> expr) + { + return substituteDeclRef(expr.getSubsts(), astBuilder, expr.getExpr()->declRef); + } // diff --git a/tests/diagnostics/interface-requirement-not-satisfied.slang.expected b/tests/diagnostics/interface-requirement-not-satisfied.slang.expected index d7614186b..464ffde25 100644 --- a/tests/diagnostics/interface-requirement-not-satisfied.slang.expected +++ b/tests/diagnostics/interface-requirement-not-satisfied.slang.expected @@ -3,6 +3,7 @@ standard error = { tests/diagnostics/interface-requirement-not-satisfied.slang(10): error 38100: type 'T' does not provide required interface member 'bar' struct T : IFoo ^~~~ +tests/diagnostics/interface-requirement-not-satisfied.slang(7): note: see declaration of 'bar' } standard output = { } diff --git a/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected b/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected index 3c9ef58d5..cc4f310ad 100644 --- a/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected +++ b/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected @@ -3,6 +3,7 @@ standard error = { tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang(10): error 38100: type 'Counter' does not provide required interface member 'processValue' struct Counter : IThing ^~~~~~ +tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang(7): note: see declaration of 'processValue' } standard output = { } diff --git a/tests/language-feature/parameters/generic-func-param-default-arg.slang b/tests/language-feature/parameters/generic-func-param-default-arg.slang new file mode 100644 index 000000000..b7e8d6fa2 --- /dev/null +++ b/tests/language-feature/parameters/generic-func-param-default-arg.slang @@ -0,0 +1,47 @@ +// generic-func-param-default-arg.slang + +// Test that generic functions can have default argument values on their parameters. + +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST_DISABLED:SIMPLE:-target hlsl -entry computeMain -dump-ir + + +interface IValue +{ + __init(); + + This plusA(This other); + This plusB(int other); +} + +T sum<T : IValue>(T value, T other = T(), int extra = 0) +{ + return value.plusA(other).plusB(extra); +} + +struct Simple : IValue +{ + int val; + + __init() { val = 0; } + __init(int val) { this.val = val; } + + Simple plusA(Simple other) { return Simple(val + other.val); } + Simple plusB(int other) { return Simple(val + other); } +} + +int test(int val) +{ + let s = Simple(val); + return sum<Simple>(s).val + 16*sum(s).val; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gBuffer +RWStructuredBuffer<int> gBuffer; + +[shader("compute")] +[numthreads(4)] +void computeMain(int tid : SV_DispatchThreadID) +{ + gBuffer[tid] = test(tid); +} diff --git a/tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt b/tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt new file mode 100644 index 000000000..d4cb1cc00 --- /dev/null +++ b/tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt @@ -0,0 +1,4 @@ +0 +11 +22 +33 |
