diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2021-03-10 15:18:06 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-03-10 15:18:06 -0800 |
| commit | 6cbd9d68a03f0a22305d4e224a3da7633b23de38 (patch) | |
| tree | de436717081a9b2b7ddd3644f2e7ada130951141 | |
| parent | 6ef4054f8a8aea4ec61481057fa7e16aaecde6d7 (diff) | |
A bunch of overlapping semantic-checking fixes (#1743)
This change originally started with the simple goal of allowing generic functions with default argument values on their parameters to work:
```
void someFunction<T>(T value, int optional = 0);
```
The core problem there was that the compiler code was (correctly) anticipate the case where the default argument value for a parameter depends on a generic parameter, such as:
```
interface IDefaultable { static This getDefault(); }
void anotherFunction<T : IDefaultable>(T first, T second = T.getDefault());
```
Supporting this latter case requires some kind of ability to apply subsitutions to an `Expr`, but our compiler logic simply errored out in that case. The first major fix that went into this change was to add a new `SubstExpr<T>` type that behaves a lot like `DeclRef<T>` in that it stores a `T*` plus a set of substititions that need to be applied to it.
In addition, it was found that even if `anotherFunction<ConcreteType>(...)` might work, when generic argument inference was used for just `anotherFunction(...)` would fail because it includes a strict match on the number of arguments/parameters in the call expression.
The next problem that arose was that the test I'd created used an interace with an `__init` requirement, and it appeared that our code generation didn't work for that case:
```
interface IStuff { __init(int val); }
void f<T : IStuff>(T x = T(0));
```
In this case, the `T(0)` initialization would get compiled to `(ConcreteType) 0` in the output rather than calling the function generated for the `__init` inside `ConcreteType`. The basic problem there was a bit of crufty old logic we have in place to work around the large number of `__init` declarations in the stdlib that don't have proper `__intrinsic_op` modifiers on them. We really need to fix the underlying problem there, but I worked around it by having the IR lowering pass only do its workaround magic on stdlib declarations.
The next problem down this line was that my test had two different `__init` declarations in the concrete type and the logic for checking interface conformance was picking the wrong one to satisfying an interface requirement despite it being obviously wrong (not even the right number of parameter).
This last problem led me down the rabbit-hole of trying to actually get our semantic checking for interface requirements right. There were a few pieces to that work:
* Actually checking that the parameter and result types for two callables match is the simple part. If that was all that would be required we would have implement this logic a long time ago.
* Next we have to deal with functions that make use of the `This` type, associated types, etc. We have to know that when the interface uses `This`, we want to treat that as equivalent to `ConcreteType`, and similarly for associated types. Getting that working is mostly a matter of setting up a this-type subsitution for the interface member being checked.
* Finally, when comparing generic declarations like `IBase::doThing<T>` and `Derived::doThing<U>` we need to deal with the way that `T` and `U` represent the "same" logical type parameter, but are distinct `Decl`s. This is handled by specializing the base declaration to the parameters of the derived one (e.g., forming `IBase::doThing<U>` using the `U` from `Derived::doThing`).
The result seems to be passing our tests, but there are still a few gotchas lurking, I'm sure.
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 114 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 503 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 46 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 42 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 105 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 35 | ||||
| -rw-r--r-- | tests/diagnostics/interface-requirement-not-satisfied.slang.expected | 1 | ||||
| -rw-r--r-- | tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected | 1 | ||||
| -rw-r--r-- | tests/language-feature/parameters/generic-func-param-default-arg.slang | 47 | ||||
| -rw-r--r-- | tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt | 4 |
12 files changed, 770 insertions, 177 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 6697e878a..1921a101f 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -582,6 +582,107 @@ namespace Slang HashCode getHashCode() const; }; + /// An expression together with (optional) substutions to apply to it + /// + /// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + /// `SubstExprBase` exists primarily to provide a non-templated base type + /// for `SubstExpr<T>`. Code should prefer to use `SubstExpr<Expr>` instead + /// of `SubstExprBase` as often as possible. + /// + struct SubstExprBase + { + public: + /// Initialize as a null expression + SubstExprBase() + {} + + /// Initialize as the given `expr` with no subsitutions applied + SubstExprBase(Expr* expr) + : m_expr(expr) + {} + + /// Initialize as the given `expr` with the given `substs` applied + SubstExprBase(Expr* expr, SubstitutionSet const& substs) + : m_expr(expr) + , m_substs(substs) + {} + + /// Get the underlying expression without any substitutions + Expr* getExpr() const { return m_expr; } + + /// Get the subsitutions being applied, if any + SubstitutionSet const& getSubsts() const { return m_substs; } + + private: + Expr* m_expr = nullptr; + SubstitutionSet m_substs; + + typedef void (SubstExprBase::*SafeBool)(); + void SafeBoolTrue() {} + + public: + /// Test whether this is a non-null expression + operator SafeBool() + { + return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; + } + + /// Test whether this is a null expression + bool operator!() const { return m_expr == nullptr; } + + }; + + /// An expression together with (optional) substutions to apply to it + /// + /// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. + /// Conceptually it represents the result of applying the substitutions, + /// recursively, to the given expression. + /// + template<typename T> + struct SubstExpr : SubstExprBase + { + private: + typedef SubstExprBase Super; + + public: + /// Initialize as a null expression + SubstExpr() + {} + + /// Initialize as the given `expr` with no subsitutions applied + SubstExpr(T* expr) + : Super(expr) + {} + + /// Initialize as the given `expr` with the given `substs` applied + SubstExpr(T* expr, SubstitutionSet const& substs) + : Super(expr, substs) + {} + + /// Initialize as a copy of the given `other` expression + template <typename U> + SubstExpr(SubstExpr<U> const& other, + typename EnableIf<IsConvertible<T*, U*>::Value, void>::type* = 0) + : Super(other.getExpr(), other.getSubsts()) + { + } + + /// Get the underlying expression without any substitutions + T* getExpr() const { return (T*) Super::getExpr(); } + + /// Dynamic cast to an expression of type `U` + /// + /// Returns a null expression if the cast fails, or if this expression was null. + template<typename U> + SubstExpr<U> as() + { + return SubstExpr<U>(Slang::as<U>(getExpr()), getSubsts()); + } + }; + class ASTBuilder; template<typename T> @@ -623,10 +724,10 @@ namespace Slang DeclRefBase substitute(ASTBuilder* astBuilder, DeclRefBase declRef) const; // Apply substitutions to an expression - Expr* substitute(ASTBuilder* astBuilder, Expr* expr) const; + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const; // Apply substitutions to this declaration reference - DeclRefBase substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRefBase substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; // Returns true if 'as' will return a valid cast template <typename T> @@ -698,7 +799,8 @@ namespace Slang { return DeclRefBase::substitute(astBuilder, type); } - Expr* substitute(ASTBuilder* astBuilder, Expr* expr) const + + SubstExpr<Expr> substitute(ASTBuilder* astBuilder, Expr* expr) const { return DeclRefBase::substitute(astBuilder, expr); } @@ -711,7 +813,7 @@ namespace Slang } // Apply substitutions to this declaration reference - DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) + DeclRef<T> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const { return DeclRef<T>::unsafeInit(DeclRefBase::substituteImpl(astBuilder, subst, ioDiff)); } @@ -722,6 +824,10 @@ namespace Slang } }; + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr); + DeclRef<Decl> substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef<Decl> const& declRef); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase& declRef) { declRef.toText(io); return io; } template<typename T> diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a1c8369aa..726781fb0 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1374,8 +1374,33 @@ namespace Slang return false; } - // TODO: actually implement matching here. For now we'll - // just pretend that things are satisfied in order to make progress.. + // A signature matches the required one if it has the right number of parameters, + // and those parameters have the right types, and also the result/return type + // is the required one. + // + auto requiredParams = getParameters(requiredMemberDeclRef).toArray(); + auto satisfyingParams = getParameters(satisfyingMemberDeclRef).toArray(); + auto paramCount = requiredParams.getCount(); + if(satisfyingParams.getCount() != paramCount) + return false; + + for(Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) + { + auto requiredParam = requiredParams[paramIndex]; + auto satisfyingParam = satisfyingParams[paramIndex]; + + auto requiredParamType = getType(m_astBuilder, requiredParam); + auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); + + if(!requiredParamType->equals(satisfyingParamType)) + return false; + } + + auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); + auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); + if(!requiredResultType->equals(satisfyingResultType)) + return false; + witnessTable->add( requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); @@ -1491,58 +1516,234 @@ namespace Slang bool SemanticsVisitor::doesGenericSignatureMatchRequirement( - DeclRef<GenericDecl> genDecl, - DeclRef<GenericDecl> requirementGenDecl, + DeclRef<GenericDecl> satisfyingGenericDeclRef, + DeclRef<GenericDecl> requiredGenericDeclRef, RefPtr<WitnessTable> witnessTable) { - if (genDecl.getDecl()->members.getCount() != requirementGenDecl.getDecl()->members.getCount()) + // The signature of a generic is defiend by its members, and we need the + // satisfying value to have the same number of members for it to be an + // exact match. + // + auto memberCount = requiredGenericDeclRef.getDecl()->members.getCount(); + if(satisfyingGenericDeclRef.getDecl()->members.getCount() != memberCount) return false; - for (Index i = 0; i < genDecl.getDecl()->members.getCount(); i++) + + // We then want to check that pairwise members match, in order. + // + auto requiredMemberDeclRefs = getMembers(requiredGenericDeclRef); + auto satisfyingMemberDeclRefs = getMembers(satisfyingGenericDeclRef); + // + // We start by performing a superficial "structural" match of the parameters + // to ensure that the two generics have an equivalent mix of type, value, + // and constraint parameters in the same order. + // + // Note that in this step we do *not* make any checks on the actual types + // involved in constraints, or on the types of value parameters. The reason + // for this is that the types on those parameters could be dependent on + // type parameters in the generic parameter list, and thus there could be + // a mismatch at this point. For example, if we have: + // + // interface IBase { void doThing<T, U : IThing<T>>(); } + // struct Derived : IBase { void doThing<X, Y : IThing<X>>(); } + // + // We clearly have a signature match here, but the constraint parameters for + // `U : IThing<T>` and `Y : IThing<X>` have the problem that both the sub-type + // and super-type they reference are not equivalent without substititions. + // + // We will deal with this issue after the structural matching is checked, at + // which point we can actually verify things like types. + // + for (Index i = 0; i < memberCount; i++) { - auto genMbr = genDecl.getDecl()->members[i]; - auto requiredGenMbr = genDecl.getDecl()->members[i]; - if (auto genTypeMbr = as<GenericTypeParamDecl>(genMbr)) + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if (as<GenericTypeParamDecl>(requiredMemberDeclRef)) { - if (auto requiredGenTypeMbr = as<GenericTypeParamDecl>(requiredGenMbr)) + if (as<GenericTypeParamDecl>(satisfyingMemberDeclRef)) { } else return false; } - else if (auto genValMbr = as<GenericValueParamDecl>(genMbr)) + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) { - if (auto requiredGenValMbr = as<GenericValueParamDecl>(requiredGenMbr)) + if (auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>()) { - if (!genValMbr->type->equals(requiredGenValMbr->type)) - return false; } else return false; } - else if (auto genTypeConstraintMbr = as<GenericTypeConstraintDecl>(genMbr)) + else if (auto requiredConstraintDeclRef = requiredMemberDeclRef.as<GenericTypeConstraintDecl>()) { - if (auto requiredTypeConstraintMbr = as<GenericTypeConstraintDecl>(requiredGenMbr)) + if (auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>()) { - if (!genTypeConstraintMbr->sup->equals(requiredTypeConstraintMbr->sup)) - { - return false; - } } else return false; } } - // TODO: this isn't right, because we need to specialize the - // declarations of the generics to a common set of substitutions, - // so that their types are comparable (e.g., foo<T> and foo<U> - // need to have substitutions applies so that they are both foo<X>, - // after which uses of the type X in their parameter lists can - // be compared). + // In order to compare the inner declarations of the two generics, we need to + // align them so that they are expressed in terms of consistent type parameters. + // + // For example, we might have: + // + // interface IBase { void doThing<T>(T val); } + // struct Derived : IBase { void doThing<U>(U val); } + // + // If we directly compare the signatures of the inner `doThing` function declarations, + // we'd find a mismatch between the `T` and `U` types of the `val` parameter. + // + // We can get around this mismatch by constructing a specialized reference and + // then doing the comparison. For example `IBase::doThing<X>` and `Derived::doThing<X>` + // should both have the signature `X -> void`. + // + // The one big detail that we need to be careful about here is that when we + // recursively call `doesMemberSatisfyRequirement`, that will eventually store + // the satisfying `DeclRef` as the value for the given requirement key, and we don't + // want to store a specialized reference like `Derived::doThing<X>` - we need to + // somehow store the original declaration. + // + // The solution here is to specialize the *required* declaration to the parameters + // of the satisfying declaration. In the example above that means we are going to + // compare `Derived::doThing` against `IBase::doThing<U>` where the `U` there is + // the parameter of `Dervived::doThing`. + // + GenericSubstitution* requiredSubst = m_astBuilder->create<GenericSubstitution>(); + requiredSubst->genericDecl = requiredGenericDeclRef.getDecl(); + requiredSubst->outer = requiredGenericDeclRef.substitutions; + + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as<GenericTypeParamDecl>()) + { + auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as<GenericTypeParamDecl>(); + SLANG_ASSERT(satisfyingTypeParamDeclRef); + auto satisfyingType = DeclRefType::create(m_astBuilder, satisfyingTypeParamDeclRef); + + requiredSubst->args.add(satisfyingType); + } + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) + { + auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>(); + SLANG_ASSERT(satisfyingValueParamDeclRef); + auto satisfyingVal = m_astBuilder->create<GenericParamIntVal>(); + satisfyingVal->declRef = satisfyingValueParamDeclRef; + + requiredSubst->args.add(satisfyingVal); + } + } + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if(auto requiredConstraintDeclRef = requiredMemberDeclRef.as<GenericTypeConstraintDecl>()) + { + auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>(); + SLANG_ASSERT(satisfyingConstraintDeclRef); + + auto satisfyingWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + satisfyingWitness->sub = getSub(m_astBuilder, satisfyingConstraintDeclRef); + satisfyingWitness->sup = getSup(m_astBuilder, satisfyingConstraintDeclRef); + satisfyingWitness->declRef = satisfyingConstraintDeclRef; + + requiredSubst->args.add(satisfyingWitness); + } + } + + // Now that we have computed a set of specialization arguments that will + // specialize the generic requirement at the type parameters of the satisfying + // generic, we can construct a reference to that declaration and re-run some + // of the earlier checking logic with more type information usable. + // + auto specializedRequiredGenericDeclRef = DeclRef<GenericDecl>(requiredGenericDeclRef.getDecl(), requiredSubst); + auto specializedRequiredMemberDeclRefs = getMembers(specializedRequiredGenericDeclRef); + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = specializedRequiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as<GenericTypeParamDecl>()) + { + auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as<GenericTypeParamDecl>(); + SLANG_ASSERT(satisfyingTypeParamDeclRef); + + // There are no additional checks we need to make on plain old + // type parameters at this point. + // + // TODO: If we ever support having type parameters of higher kinds, + // then this is possibly where we'd want to check that the kinds of + // the two parameters match. + // + SLANG_UNUSED(satisfyingGenericDeclRef); + } + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as<GenericValueParamDecl>()) + { + auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>(); + SLANG_ASSERT(satisfyingValueParamDeclRef); + + // For a generic value parameter, we need to check that the required + // and satisfying declaration both agree on the type of the parameter. + // + auto requiredParamType = getType(m_astBuilder, requiredValueParamDeclRef); + auto satisfyingParamType = getType(m_astBuilder, satisfyingValueParamDeclRef); + if (!satisfyingParamType->equals(requiredParamType)) + return false; + } + else if(auto requiredConstraintDeclRef = requiredMemberDeclRef.as<GenericTypeConstraintDecl>()) + { + auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as<GenericTypeConstraintDecl>(); + SLANG_ASSERT(satisfyingConstraintDeclRef); + + // For a generic constraint parameter, we need to check that the sub-type + // and super-type in the constraint both match. + // + // In current code the sub type will always be one of the generic type parameters, + // and the super-type will always be an interface, but there should be no + // need to make use of those additional details here. + + auto requiredSubType = getSub(m_astBuilder, requiredConstraintDeclRef); + auto satisfyingSubType = getSub(m_astBuilder, satisfyingConstraintDeclRef); + if (!satisfyingSubType->equals(requiredSubType)) + return false; + + auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); + auto satisfyingSuperType = getSup(m_astBuilder, satisfyingConstraintDeclRef); + if (!satisfyingSuperType->equals(requiredSuperType)) + return false; + } + } + + // Note: the above logic really only applies to the case of an exact match on signature, + // even down to the way that constraints were declared. We could potentially be more + // relaxed by taking advantage of the way that various different generic signatures will + // actually lower to the same IR generic signature. + // + // In theory, all we really care about when it comes to constraints is that the constraints + // on the required and satisfying declaration are *equivalent*. + // + // More generally, a satisfying generic could actually provide *looser* constraints and + // still work; all that matters is that it can be instantiated at any argument values/types + // that are valid for the requirement. + // + // We leave both of those issues up to the synthesis path: if we do not find a member that + // provides an exact match, then the compiler should try to synthesize one that is an exact + // match and makes use of existing declarations that might have require defaulting of arguments + // or type conversations to fit. + + // Once we've validated that the generic signatures are in an exact match, and devised type + // arguments for the requirement to make the two align, we can recursively check the inner + // declaration (whatever it is) for an exact match. + // return doesMemberSatisfyRequirement( - DeclRef<Decl>(genDecl.getDecl()->inner, genDecl.substitutions), - DeclRef<Decl>(requirementGenDecl.getDecl()->inner, requirementGenDecl.substitutions), + DeclRef<Decl>(satisfyingGenericDeclRef.getDecl()->inner, satisfyingGenericDeclRef.substitutions), + DeclRef<Decl>(requiredGenericDeclRef.getDecl()->inner, requiredSubst), witnessTable); } @@ -2375,13 +2576,15 @@ namespace Slang bool SemanticsVisitor::findWitnessForInterfaceRequirement( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<InterfaceDecl> superInterfaceDeclRef, DeclRef<Decl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) + RefPtr<WitnessTable> witnessTable, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness) { - SLANG_UNUSED(interfaceDeclRef) + SLANG_UNUSED(superInterfaceDeclRef) // The goal of this function is to find a suitable // value to satisfy the requirement. @@ -2415,18 +2618,40 @@ namespace Slang // // TODO: we *really* need a linearization step here!!!! - RefPtr<WitnessTable> satisfyingWitnessTable = checkConformanceToType( - context, - type, - requiredInheritanceDeclRef.getDecl(), - getBaseType(m_astBuilder, requiredInheritanceDeclRef)); + auto reqType = getBaseType(m_astBuilder, requiredInheritanceDeclRef); - if(!satisfyingWitnessTable) - return false; + DeclaredSubtypeWitness* interfaceIsReqWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + interfaceIsReqWitness->sub = superInterfaceType; + interfaceIsReqWitness->sup = reqType; + interfaceIsReqWitness->declRef = requiredInheritanceDeclRef; + // ... + + TransitiveSubtypeWitness* subIsReqWitness = m_astBuilder->create<TransitiveSubtypeWitness>(); + subIsReqWitness->sub = subType; + subIsReqWitness->sup = reqType; + subIsReqWitness->subToMid = subTypeConformsToSuperInterfaceWitness; + subIsReqWitness->midToSup = interfaceIsReqWitness; + // ... + + RefPtr<WitnessTable> satisfyingWitnessTable = new WitnessTable(); + satisfyingWitnessTable->witnessedType = subType; + satisfyingWitnessTable->baseType = reqType; witnessTable->add( requiredInheritanceDeclRef.getDecl(), RequirementWitness(satisfyingWitnessTable)); + + if( !checkConformanceToType( + context, + subType, + requiredInheritanceDeclRef.getDecl(), + reqType, + subIsReqWitness, + satisfyingWitnessTable) ) + { + return false; + } + return true; } @@ -2465,7 +2690,7 @@ namespace Slang // requests will be handled further down. For now we include // lookup results that might be usable, but not as-is. // - auto lookupResult = lookUpMember(m_astBuilder, this, name, type, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); + auto lookupResult = lookUpMember(m_astBuilder, this, name, subType, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); if(!lookupResult.isValid()) { @@ -2478,7 +2703,8 @@ namespace Slang // signatures of methods, as is done for Swift), we'd // need to revisit this step. // - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, type, requiredMemberDeclRef); + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); return false; } @@ -2521,26 +2747,29 @@ namespace Slang // and if nothing is found we print the candidates that made it // furthest in checking. // - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, type, requiredMemberDeclRef); + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); return false; } RefPtr<WitnessTable> SemanticsVisitor::checkInterfaceConformance( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef) + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitnes) { // Has somebody already checked this conformance, // and/or is in the middle of checking it? RefPtr<WitnessTable> witnessTable; - if(context->mapInterfaceToWitnessTable.TryGetValue(interfaceDeclRef, witnessTable)) + if(context->mapInterfaceToWitnessTable.TryGetValue(superInterfaceDeclRef, witnessTable)) return witnessTable; // We need to check the declaration of the interface // before we can check that we conform to it. // - ensureDecl(interfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); // We will construct the witness table, and register it // *before* we go about checking fine-grained requirements, @@ -2554,10 +2783,53 @@ namespace Slang if(!witnessTable) { witnessTable = new WitnessTable(); - witnessTable->baseType = DeclRefType::create(m_astBuilder, interfaceDeclRef); - witnessTable->witnessedType = type; + witnessTable->baseType = DeclRefType::create(m_astBuilder, superInterfaceDeclRef); + witnessTable->witnessedType = subType; } - context->mapInterfaceToWitnessTable.Add(interfaceDeclRef, witnessTable); + context->mapInterfaceToWitnessTable.Add(superInterfaceDeclRef, witnessTable); + + if(!checkInterfaceConformance(context, subType, superInterfaceType, inheritanceDecl, superInterfaceDeclRef, subTypeConformsToSuperInterfaceWitnes, witnessTable)) + return nullptr; + + return witnessTable; + } + + static bool isAssociatedTypeDecl(Decl* decl) + { + auto d = decl; + while(auto genericDecl = as<GenericDecl>(d)) + d = genericDecl->inner; + if(as<AssocTypeDecl>(d)) + return true; + return false; + } + + bool SemanticsVisitor::checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness, + WitnessTable* witnessTable) + { + // We need to check the declaration of the interface + // before we can check that we conform to it. + // + ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + + // When comparing things like signatures, we need to do so in the context + // of a this-type substitution that aligns the signatures in the interface + // with those in the concrete type. For example, we need to treat any uses + // of `This` in the interface as equivalent to the concrete type for the + // purpose of signature matching (and similarly for associated types). + // + ThisTypeSubstitution* thisTypeSubst = m_astBuilder->create<ThisTypeSubstitution>(); + thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); + thisTypeSubst->witness = subTypeConformsToSuperInterfaceWitness; + thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; + + auto specializedSuperInterfaceDeclRef = DeclRef<InterfaceDecl>(superInterfaceDeclRef.getDecl(), thisTypeSubst); bool result = true; @@ -2567,15 +2839,59 @@ namespace Slang // its (non-interface) base types already conforms to // that interface, so that all of the requirements are // already satisfied with inherited implementations... - for(auto requiredMemberDeclRef : getMembers(interfaceDeclRef)) + + // Note: we break this logic into two loops, where we first + // check conformance for all associated-type requirements + // and *then* check conformance for all other requirements. + // + // Checking associated-type requirements first ensures that + // we can make use of the identity of the associated types + // when checking other members. + // + // TODO: There could in theory be subtle cases involving + // circular or recursive dependency chains that make such + // a simple ordering impractical (e.g., associated type `A` + // is constrained to `IThing<This>` where `IThing<T>` requires + // that `T : IOtherThing where T.B == int` for another associated + // type `B`). + // + // The only robust solution long-term is probably to treat this + // as a type-inference problem by creating type variables to + // stand in for the associated-type requirements and then to discover + // constraints and solve for those type variables as part of the + // conformance-checking process. + // + for(auto requiredMemberDeclRef : getMembers(specializedSuperInterfaceDeclRef)) { + if(!isAssociatedTypeDecl(requiredMemberDeclRef)) + continue; + auto requirementSatisfied = findWitnessForInterfaceRequirement( context, - type, + subType, + superInterfaceType, inheritanceDecl, - interfaceDeclRef, + specializedSuperInterfaceDeclRef, requiredMemberDeclRef, - witnessTable); + witnessTable, + subTypeConformsToSuperInterfaceWitness); + + result = result && requirementSatisfied; + } + for(auto requiredMemberDeclRef : getMembers(specializedSuperInterfaceDeclRef)) + { + if(isAssociatedTypeDecl(requiredMemberDeclRef)) + continue; + + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + subType, + superInterfaceType, + inheritanceDecl, + specializedSuperInterfaceDeclRef, + requiredMemberDeclRef, + witnessTable, + subTypeConformsToSuperInterfaceWitness); result = result && requirementSatisfied; } @@ -2604,14 +2920,12 @@ namespace Slang // the time we are compiling and handle those, and punt on the larger issue // for a bit longer. // - for(auto candidateExt : getCandidateExtensions(interfaceDeclRef, this)) + for(auto candidateExt : getCandidateExtensions(specializedSuperInterfaceDeclRef, this)) { // We need to apply the extension to the interface type that our // concrete type is inheriting from. // - // TODO: need to decide if a this-type substitution is needed here. - // It probably it. - Type* targetType = DeclRefType::create(m_astBuilder, interfaceDeclRef); + Type* targetType = DeclRefType::create(m_astBuilder, specializedSuperInterfaceDeclRef); auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); if(!extDeclRef) continue; @@ -2621,65 +2935,66 @@ namespace Slang { auto requirementSatisfied = findWitnessForInterfaceRequirement( context, - type, + subType, + superInterfaceType, inheritanceDecl, - interfaceDeclRef, + specializedSuperInterfaceDeclRef, requiredInheritanceDeclRef, - witnessTable); + witnessTable, + subTypeConformsToSuperInterfaceWitness); result = result && requirementSatisfied; } } - // If we failed to satisfy any requirements along the way, - // then we don't actually want to keep the witness table - // we've been constructing, because the whole thing was a failure. - if(!result) - { - return nullptr; - } - - return witnessTable; + // The conformance was satisfied if all the requirements were satisfied. + // + return result; } - RefPtr<WitnessTable> SemanticsVisitor::checkConformanceToType( + bool SemanticsVisitor::checkConformanceToType( ConformanceCheckingContext* context, - Type* type, + Type* subType, InheritanceDecl* inheritanceDecl, - Type* baseType) + Type* superType, + SubtypeWitness* subIsSuperWitness, + WitnessTable* witnessTable) { - if (auto baseDeclRefType = as<DeclRefType>(baseType)) + if (auto supereclRefType = as<DeclRefType>(superType)) { - auto baseTypeDeclRef = baseDeclRefType->declRef; - if (auto baseInterfaceDeclRef = baseTypeDeclRef.as<InterfaceDecl>()) + auto superTypeDeclRef = supereclRefType->declRef; + if (auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>()) { // The type is stating that it conforms to an interface. // We need to check that it provides all of the members // required by that interface. return checkInterfaceConformance( context, - type, + subType, + superType, inheritanceDecl, - baseInterfaceDeclRef); + superInterfaceDeclRef, + subIsSuperWitness, + witnessTable); } - else if( auto structDeclRef = baseTypeDeclRef.as<StructDecl>() ) + else if( auto superStructDeclRef = superTypeDeclRef.as<StructDecl>() ) { // The type is saying it inherits from a `struct`, // which doesn't require any checking at present - return nullptr; + return true; } } getSink()->diagnose(inheritanceDecl, Diagnostics::unimplemented, "type not supported for inheritance"); - return nullptr; + return false; } bool SemanticsVisitor::checkConformance( - Type* type, + Type* subType, InheritanceDecl* inheritanceDecl, ContainerDecl* parentDecl) { - if( auto declRefType = as<DeclRefType>(type) ) + if( auto declRefType = as<DeclRefType>(subType) ) { auto declRef = declRefType->declRef; @@ -2709,16 +3024,32 @@ namespace Slang // Look at the type being inherited from, and validate // appropriately. - auto baseType = inheritanceDecl->base.type; + auto superType = inheritanceDecl->base.type; + + DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->create<DeclaredSubtypeWitness>(); + subIsSuperWitness->declRef = makeDeclRef(inheritanceDecl); + subIsSuperWitness->sub = subType; + subIsSuperWitness->sup = superType; ConformanceCheckingContext context; - context.conformingType = type; + context.conformingType = subType; context.parentDecl = parentDecl; - RefPtr<WitnessTable> witnessTable = checkConformanceToType(&context, type, inheritanceDecl, baseType); + + + RefPtr<WitnessTable> witnessTable = inheritanceDecl->witnessTable; if(!witnessTable) + { + witnessTable = new WitnessTable(); + witnessTable->baseType = superType; + witnessTable->witnessedType = subType; + inheritanceDecl->witnessTable = witnessTable; + } + + if( !checkConformanceToType(&context, subType, inheritanceDecl, superType, subIsSuperWitness, witnessTable) ) + { return false; + } - inheritanceDecl->witnessTable = witnessTable; return true; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5c00cb64b..fbbfcd473 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -694,7 +694,7 @@ namespace Slang } IntVal* SemanticsVisitor::tryConstantFoldExpr( - InvokeExpr* invokeExpr, + SubstExpr<InvokeExpr> invokeExpr, ConstantFoldingCircularityInfo* circularityInfo) { // We need all the operands to the expression @@ -703,10 +703,10 @@ namespace Slang // // For right now we will look for calls to intrinsic functions, and then inspect // their names (this is bad and slow). - auto funcDeclRefExpr = as<DeclRefExpr>(invokeExpr->functionExpr); + auto funcDeclRefExpr = getBaseExpr(invokeExpr).as<DeclRefExpr>(); if (!funcDeclRefExpr) return nullptr; - auto funcDeclRef = funcDeclRefExpr->declRef; + auto funcDeclRef = getDeclRef(m_astBuilder, funcDeclRefExpr); auto intrinsicMod = funcDeclRef.getDecl()->findModifier<IntrinsicOpModifier>(); if (!intrinsicMod) { @@ -722,31 +722,31 @@ namespace Slang // Let's not constant-fold operations with more than a certain number of arguments, for simplicity static const int kMaxArgs = 8; - if (invokeExpr->arguments.getCount() > kMaxArgs) + auto argCount = getArgCount(invokeExpr); + if (argCount > kMaxArgs) return nullptr; // Before checking the operation name, let's look at the arguments IntVal* argVals[kMaxArgs]; IntegerLiteralValue constArgVals[kMaxArgs]; - int argCount = 0; bool allConst = true; - for (auto argExpr : invokeExpr->arguments) + for(Index a = 0; a < argCount; ++a) { + auto argExpr = getArg(invokeExpr, a); auto argVal = tryFoldIntegerConstantExpression(argExpr, circularityInfo); if (!argVal) return nullptr; - argVals[argCount] = argVal; + argVals[a] = argVal; if (auto constArgVal = as<ConstantIntVal>(argVal)) { - constArgVals[argCount] = constArgVal->value; + constArgVals[a] = constArgVal->value; } else { allConst = false; } - argCount++; } if (!allConst) @@ -866,25 +866,25 @@ namespace Slang } IntVal* SemanticsVisitor::tryConstantFoldExpr( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo) { // Unwrap any "identity" expressions - while (auto parenExpr = as<ParenExpr>(expr)) + while (auto parenExpr = expr.as<ParenExpr>()) { - expr = parenExpr->base; + expr = getBaseExpr(parenExpr); } // TODO(tfoley): more serious constant folding here - if (auto intLitExpr = as<IntegerLiteralExpr>(expr)) + if (auto intLitExpr = expr.as<IntegerLiteralExpr>()) { return getIntVal(intLitExpr); } // it is possible that we are referring to a generic value param - if (auto declRefExpr = as<DeclRefExpr>(expr)) + if (auto declRefExpr = expr.as<DeclRefExpr>()) { - auto declRef = declRefExpr->declRef; + auto declRef = getDeclRef(m_astBuilder, declRefExpr); if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { @@ -913,13 +913,13 @@ namespace Slang } } - if(auto castExpr = as<TypeCastExpr>(expr)) + if(auto castExpr = expr.as<TypeCastExpr>()) { - auto val = tryConstantFoldExpr(castExpr->arguments[0], circularityInfo); + auto val = tryConstantFoldExpr(getArg(castExpr, 0), circularityInfo); if(val) return val; } - else if (auto invokeExpr = as<InvokeExpr>(expr)) + else if (auto invokeExpr = expr.as<InvokeExpr>()) { auto val = tryConstantFoldExpr(invokeExpr, circularityInfo); if (val) @@ -930,12 +930,12 @@ namespace Slang } IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo) { // Check if type is acceptable for an integer constant expression // - if(!isScalarIntegerType(expr->type)) + if(!isScalarIntegerType(getType(m_astBuilder, expr))) return nullptr; // Consider operations that we might be able to constant-fold... @@ -2037,7 +2037,11 @@ namespace Slang { auto containerDecl = scope->containerDecl; - if( auto setterDecl = as<SetterDecl>(containerDecl) ) + if( auto ctorDecl = as<ConstructorDecl>(containerDecl) ) + { + expr->type.isLeftValue = true; + } + else if( auto setterDecl = as<SetterDecl>(containerDecl) ) { expr->type.isLeftValue = true; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ad9de3f4b..bc265b1b4 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -830,27 +830,42 @@ namespace Slang // bool findWitnessForInterfaceRequirement( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef, + DeclRef<InterfaceDecl> superInterfaceDeclRef, DeclRef<Decl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable); + RefPtr<WitnessTable> witnessTable, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness); // Check that the type declaration `typeDecl`, which // declares conformance to the interface `interfaceDeclRef`, // (via the given `inheritanceDecl`) actually provides // members to satisfy all the requirements in the interface. + bool checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness, + WitnessTable* witnessTable); + RefPtr<WitnessTable> checkInterfaceConformance( ConformanceCheckingContext* context, - Type* type, + Type* subType, + Type* superInterfaceType, InheritanceDecl* inheritanceDecl, - DeclRef<InterfaceDecl> interfaceDeclRef); + DeclRef<InterfaceDecl> superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness); - RefPtr<WitnessTable> checkConformanceToType( + bool checkConformanceToType( ConformanceCheckingContext* context, - Type* type, + Type* subType, InheritanceDecl* inheritanceDecl, - Type* baseType); + Type* superType, + SubtypeWitness* subIsSuperWitness, + WitnessTable* witnessTable); /// Check that `type` which has declared that it inherits from (and/or implements) /// another type via `inheritanceDecl` actually does what it needs to for that @@ -913,6 +928,11 @@ namespace Slang IntVal* getIntVal(IntegerLiteralExpr* expr); + inline IntVal* getIntVal(SubstExpr<IntegerLiteralExpr> expr) + { + return getIntVal(expr.getExpr()); + } + Name* getName(String const& text) { return getNamePool()->getName(text); @@ -938,12 +958,12 @@ namespace Slang /// Try to apply front-end constant folding to determine the value of `invokeExpr`. IntVal* tryConstantFoldExpr( - InvokeExpr* invokeExpr, + SubstExpr<InvokeExpr> invokeExpr, ConstantFoldingCircularityInfo* circularityInfo); /// Try to apply front-end constant folding to determine the value of `expr`. IntVal* tryConstantFoldExpr( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo); bool _checkForCircularityInConstantFolding( @@ -960,7 +980,7 @@ namespace Slang /// as an integer constant. /// IntVal* tryFoldIntegerConstantExpression( - Expr* expr, + SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo); // Enforce that an expression resolves to an integer constant, and get its value diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index c4e4cbde4..448cdeb88 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1051,19 +1051,19 @@ namespace Slang Index argCount = context.getArgCount(); Index paramCount = params.getCount(); - // Bail out on mismatch. - // TODO(tfoley): need more nuance here - if (argCount != paramCount) + // If there are too many arguments, we cannot possibly have a match. + // + // Note that if there are *too few* arguments, we might still have + // a match, because the other arguments might have default values + // that can be used. + // + if (argCount > paramCount) { return DeclRef<Decl>(nullptr, nullptr); } for (Index aa = 0; aa < argCount; ++aa) { -#if 0 - if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) - return DeclRef<Decl>(nullptr, nullptr); -#else // The question here is whether failure to "unify" an argument // and parameter should lead to immediate failure. // @@ -1083,7 +1083,6 @@ namespace Slang // unification step should be taken as an immediate failure... TryUnifyTypes(constraints, context.getArgTypeForInference(aa, this), getType(m_astBuilder, params[aa])); -#endif } } else diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index fb9fc70fd..3c9256178 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -633,7 +633,7 @@ LoweredValInfo emitCallToDeclRef( if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() ) { - if(!ctorDeclRef.getDecl()->body) + if(!ctorDeclRef.getDecl()->body && isFromStdLib(ctorDeclRef.decl)) { // HACK: For legacy reasons, all of the built-in initializers // in the standard library are declared without proper @@ -1114,7 +1114,6 @@ void getGenericTypeConformances(IRGenContext* context, ShortList<IRType*>& supTy } } -SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst); // struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo> @@ -3141,6 +3140,45 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + void _lowerSubstitutionArg(IRGenContext* subContext, GenericSubstitution* subst, Decl* paramDecl, Index argIndex) + { + SLANG_ASSERT(argIndex < subst->args.getCount()); + auto argVal = lowerVal(subContext, subst->args[argIndex]); + setValue(subContext, paramDecl, argVal); + } + + void _lowerSubstitutionEnv(IRGenContext* subContext, Substitutions* subst) + { + if(!subst) return; + _lowerSubstitutionEnv(subContext, subst->outer); + + if (auto genSubst = as<GenericSubstitution>(subst)) + { + auto genDecl = genSubst->genericDecl; + + Index argCounter = 0; + for( auto memberDecl: genDecl->members ) + { + if(auto typeParamDecl = as<GenericTypeParamDecl>(memberDecl) ) + { + _lowerSubstitutionArg(subContext, genSubst, typeParamDecl, argCounter++); + } + else if( auto valParamDecl = as<GenericValueParamDecl>(memberDecl) ) + { + _lowerSubstitutionArg(subContext, genSubst, valParamDecl, argCounter++); + } + } + for( auto memberDecl: genDecl->members ) + { + if(auto constraintDecl = as<GenericTypeConstraintDecl>(memberDecl) ) + { + _lowerSubstitutionArg(subContext, genSubst, constraintDecl, argCounter++); + } + } + } + // TODO: also need to handle this-type substitution here? + } + void addDirectCallArgs( InvokeExpr* expr, DeclRef<CallableDecl> funcDeclRef, @@ -3156,10 +3194,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> auto paramDirection = getParameterDirection(paramDecl); UInt argIndex = argCounter++; - Expr* argExpr = nullptr; if(argIndex < argCount) { - argExpr = expr->arguments[argIndex]; + auto argExpr = expr->arguments[argIndex]; + addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); } else { @@ -3167,11 +3205,31 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // but there are still parameters remaining. This must mean // that these parameters have default argument expressions // associated with them. - argExpr = getInitExpr(getASTBuilder(), paramDeclRef); - - // Assert that such an expression must have been present. + // + // Currently we simply extract the initial-value expression + // from the parameter declaration and then lower it in + // the context of the caller. + // + // Note that the expression could involve subsitutions because + // in the general case it could depend on the generic parameters + // used the specialize the callee. For now we do not handle that + // case, and simply ignore generic arguments. + // + SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef); SLANG_ASSERT(argExpr); + IRGenEnv subEnvStorage; + IRGenEnv* subEnv = &subEnvStorage; + subEnv->outer = context->env; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->env = subEnv; + + _lowerSubstitutionEnv(subContext, argExpr.getSubsts()); + + addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); + // TODO: The approach we are taking here to default arguments // is simplistic, and has consequences for the front-end as // well as binary serialization of modules. @@ -3186,9 +3244,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // // Each of these options involves trade-offs, and we need to // make a conscious decision at some point. - } - addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); + // Assert that such an expression must have been present. + } } } @@ -7347,35 +7405,6 @@ LoweredValInfo ensureDecl( return result; } -IRInst* lowerSubstitutionArg( - IRGenContext* context, - Val* val) -{ - if (auto type = dynamicCast<Type>(val)) - { - return lowerType(context, type); - } - else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val)) - { - // We need to look up the IR-level representation of the witness (which will be a witness table). - auto supType = lowerType( - context, - DeclRefType::create(context->astBuilder, declaredSubtypeWitness->declRef)); - auto irWitnessTable = getSimpleVal( - context, - emitDeclRef( - context, - declaredSubtypeWitness->declRef, - context->irBuilder->getWitnessTableType(supType))); - return irWitnessTable; - } - else - { - SLANG_UNIMPLEMENTED_X("value cases"); - UNREACHABLE_RETURN(nullptr); - } -} - // Can the IR lowered version of this declaration ever be an `IRGeneric`? bool canDeclLowerToAGeneric(Decl* decl) { diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 62de6599c..accbdb27e 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -641,17 +641,37 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return declRef.substituteImpl(astBuilder, substitutions, &diff); } - Expr* DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const + SubstExpr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const { - // No substitutions? Easy. - if (!substitutions) - return expr; + return SubstExpr<Expr>(expr, substitutions); + } + + SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr) + { + return SubstExpr<Expr>(expr, substs); + } + + DeclRef<Decl> substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef<Decl> const& declRef) + { + if(!substs) + return declRef; + + int diff = 0; + auto declRefBase = declRef.substituteImpl(astBuilder, substs, &diff); + return DeclRef<Decl>(declRefBase.decl, declRefBase.substitutions); + } - SLANG_UNIMPLEMENTED_X("generic substitution into expressions"); + Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type) + { + if(!type) return nullptr; + if(!substs) return type; + + SLANG_ASSERT(type); - UNREACHABLE_RETURN(expr); + return Slang::as<Type>(type->substitute(astBuilder, substs)); } + void buildMemberDictionary(ContainerDecl* decl); InterfaceDecl* findOuterInterfaceDecl(Decl* decl) @@ -854,7 +874,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - DeclRefBase DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) + DeclRefBase DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const { // Nothing to do when we have no declaration. if(!decl) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index a0c54f914..9589f00fd 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -106,7 +106,7 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline Expr* getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) + inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->initExpr); } @@ -121,7 +121,7 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } - inline Expr* getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) + inline SubstExpr<Expr> getTagExpr(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> const& declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->tagExpr); } @@ -165,6 +165,37 @@ namespace Slang return declRef.getDecl()->inner; } + // + + inline Type* getType(ASTBuilder* astBuilder, SubstExpr<Expr> expr) + { + return substituteType(expr.getSubsts(), astBuilder, expr.getExpr()->type); + } + + inline SubstExpr<Expr> getBaseExpr(SubstExpr<ParenExpr> expr) + { + return substituteExpr(expr.getSubsts(), expr.getExpr()->base); + } + + inline SubstExpr<Expr> getBaseExpr(SubstExpr<InvokeExpr> expr) + { + return substituteExpr(expr.getSubsts(), expr.getExpr()->functionExpr); + } + + inline Index getArgCount(SubstExpr<InvokeExpr> expr) + { + return expr.getExpr()->arguments.getCount(); + } + + inline SubstExpr<Expr> getArg(SubstExpr<InvokeExpr> expr, Index index) + { + return substituteExpr(expr.getSubsts(), expr.getExpr()->arguments[index]); + } + + inline DeclRef<Decl> getDeclRef(ASTBuilder* astBuilder, SubstExpr<DeclRefExpr> expr) + { + return substituteDeclRef(expr.getSubsts(), astBuilder, expr.getExpr()->declRef); + } // diff --git a/tests/diagnostics/interface-requirement-not-satisfied.slang.expected b/tests/diagnostics/interface-requirement-not-satisfied.slang.expected index d7614186b..464ffde25 100644 --- a/tests/diagnostics/interface-requirement-not-satisfied.slang.expected +++ b/tests/diagnostics/interface-requirement-not-satisfied.slang.expected @@ -3,6 +3,7 @@ standard error = { tests/diagnostics/interface-requirement-not-satisfied.slang(10): error 38100: type 'T' does not provide required interface member 'bar' struct T : IFoo ^~~~ +tests/diagnostics/interface-requirement-not-satisfied.slang(7): note: see declaration of 'bar' } standard output = { } diff --git a/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected b/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected index 3c9ef58d5..cc4f310ad 100644 --- a/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected +++ b/tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang.expected @@ -3,6 +3,7 @@ standard error = { tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang(10): error 38100: type 'Counter' does not provide required interface member 'processValue' struct Counter : IThing ^~~~~~ +tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang(7): note: see declaration of 'processValue' } standard output = { } diff --git a/tests/language-feature/parameters/generic-func-param-default-arg.slang b/tests/language-feature/parameters/generic-func-param-default-arg.slang new file mode 100644 index 000000000..b7e8d6fa2 --- /dev/null +++ b/tests/language-feature/parameters/generic-func-param-default-arg.slang @@ -0,0 +1,47 @@ +// generic-func-param-default-arg.slang + +// Test that generic functions can have default argument values on their parameters. + +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST_DISABLED:SIMPLE:-target hlsl -entry computeMain -dump-ir + + +interface IValue +{ + __init(); + + This plusA(This other); + This plusB(int other); +} + +T sum<T : IValue>(T value, T other = T(), int extra = 0) +{ + return value.plusA(other).plusB(extra); +} + +struct Simple : IValue +{ + int val; + + __init() { val = 0; } + __init(int val) { this.val = val; } + + Simple plusA(Simple other) { return Simple(val + other.val); } + Simple plusB(int other) { return Simple(val + other); } +} + +int test(int val) +{ + let s = Simple(val); + return sum<Simple>(s).val + 16*sum(s).val; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gBuffer +RWStructuredBuffer<int> gBuffer; + +[shader("compute")] +[numthreads(4)] +void computeMain(int tid : SV_DispatchThreadID) +{ + gBuffer[tid] = test(tid); +} diff --git a/tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt b/tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt new file mode 100644 index 000000000..d4cb1cc00 --- /dev/null +++ b/tests/language-feature/parameters/generic-func-param-default-arg.slang.expected.txt @@ -0,0 +1,4 @@ +0 +11 +22 +33 |
