diff options
| author | Yong He <yonghe@outlook.com> | 2024-09-05 11:24:19 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-05 11:24:19 -0700 |
| commit | d655302465457c5d3285ae5339201a0769cc38dc (patch) | |
| tree | 4c0946ba4ea4879831133370d2203f569c135c35 | |
| parent | a88055c6f5190ca62bb4aa853b4f0fa11546278f (diff) | |
Support `where` clause and type equality constraint. (#4986)
* Support `where` clause.
* Fix.
* Fix parser.
* Enhance test to cover traditional __generic syntax.
* Update user-guide.
* Support `where` clause on associatedtype.
* Fix.
* Put in more comments.
27 files changed, 739 insertions, 128 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index 569b36e1d..61bc43f89 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -51,13 +51,13 @@ Generics Generics can be used to eliminate duplicate code for shared logic that operates on different types. The following example shows how to define a generic method in Slang. ```csharp -int myGenericMethod<T: IFoo>(T arg) +int myGenericMethod<T>(T arg) where T : IFoo { return arg.myMethod(1.0); } ``` -The above listing defines a generic method named `myGenericMethod`, which accepts an argument that can be of any type `T` as long as `T` conforms to the `IFoo` interface. The `T` here is called a _generic type parameter_, and it is associated with an _type constraint_ that any type represented by `T` must conform to the interface `IFoo`. +The above listing defines a generic method named `myGenericMethod`, which accepts an argument that can be of any type `T` as long as `T` conforms to the `IFoo` interface. The `T` here is called a _generic type parameter_, and it is associated with an _type constraint_ in the following `where` clause to indicate that any type represented by `T` must conform to the interface `IFoo`. The following listing shows how to invoke a generic method: ```csharp @@ -83,6 +83,27 @@ void g2<let e : MyEnum>() { ... } void g3<let b : bool>() { ... } ``` +### Alternative Syntax + +Alternatively, you can use `__generic` keyword to define generic parameters before the method: +```csharp +__generic<typename T> // `typename` is optional. +int myGenericMethod(T arg) where T : IFoo +{ + return arg.myMethod(1.0); +} +``` + +The same method can be defined in an alternative simplified syntax without the `where` clause: +```csharp +int myGenericMethod<T:IFoo>(T arg) { ... } +``` + +Generic value parameters can also be defined using the traditional C-style syntax: +```csharp +void g1<typename T, int n>() { ... } +``` + Supported Constructs in Interface Definitions ----------------------------------------------------- @@ -103,7 +124,7 @@ The above listing declares that any conforming type must define a property named ```csharp interface IFoo { - int compute<T:IBar>(T val); + int compute<T>(T val) where T : IBar; } ``` The above listing declares that any conforming type must define a generic method named `compute` that has one generic type parameter conforming to the `IBar` interface. @@ -119,7 +140,7 @@ interface IFoo The above listing declares that any conforming type must define a static method named `compute`. This allows the following generic method to pass type-checking: ```csharp -void f<T:IFoo>() +void f<T>() where T : IFoo { T.compute(5); // OK, T has a static method `compute`. } @@ -394,7 +415,8 @@ struct MultiArrayFloatContainer : IFloatContainer In summary, an `asssociatedtype` requirement in an interface is similar to other types of requirements: a method requirement means that an implementation must provide a method matching the interface signature, while an `associatedtype` requirement means that an implementation must provide a type in its scope with the matching name and interface constraint. In general, when defining an interface that is producing and consuming an object whose actual type is implementation-dependent, the type of this object can often be modeled as an associated type in the interface. -### Comparison to the C++ Approach + +### Comparing Generics to C++ Templates Readers who are familiar with C++ could easily relate the `Iterator` example in previous subsection to the implementation of STL. In C++, the `sum` function can be easily written with templates: ```C++ template<typename TContainer> @@ -454,6 +476,38 @@ Note that the builtin `vector<float, N>` type also has an generic value paramete `vector<float, 1+1>` is allowed and considered equivalent to `vector<float, 2>`. +Type Equality Constraints +------------------------- + +In addition to type conformance constraints as in `where T : IFoo`, Slang also supports type equality constraints. This is mostly useful in specifying additional constraints for +associated types. For example: +```csharp +interface IFoo { associatedtype A; } + +// Access all T that conforms to IFoo, and T.A is `int`. +void foo<T>(T v) + where T : IFoo + where T.A == int +{ +} + +struct X : IFoo +{ + typealias A = int; +} + +struct Y : IFoo +{ + typealias A = float; +} + +void test() +{ + foo<X>(X()); // OK + foo<Y>(Y()); // Error, `Y` cannot be used for `T`. +} +``` + Interface-typed Values ------------------------------- @@ -809,6 +863,96 @@ int test() This feature is similar to extension traits in Rust. +Variadic Generics +------------------------- + +Slang supports variadic generic type parameters: +```csharp +struct MyType<each T> +{} +``` + +Here `each T` defines a generic type pack parameter that can be a list of zero or more types. Therefore, the following instantiation of `MyType` is valid: +``` +MyType // OK +MyType<int> // OK +MyType<int, float, void> // OK +``` + +A common use of variadic generics is to define `printf`: +```csharp +void printf<each T>(String message, expand each T args) { ... } +``` + +The type syntax `expand each T` represents a expansion of the type pack `T`. Therefore, the type of `args` parameter is an expanded type pack. +The `expand` expression can be thought of a map operation of a type pack. For example, +give type pack `T = int, float, bool`, `expand each T` evaluates to the type pack of the same types, i.e. `expand each T ==> int, float, bool`. +As a more interesting example, `expand S<each T>` will evaluate to `S<int>, S<float>, S<bool>`. + +You can use `expand` expression on tuple or type-pack values to compute an expression for each element of the tuple or type pack. +For example: + +```csharp +void printNumbers<each T>(expand each T args) where T == int +{ + // An single expression statement whose type will be `(void, void, ...)`. + // where each `void` is the result of evaluating expression `printf(...)` with + // each corresponding element in `args` passed as print operand. + // + expand printf("%d\n", each args); + + // The above statement is equivalent to: + // ``` + // (printf("%d\n", args[0]), printf("%d\n", args[1]), ..., printf("%d\n", args[n-1])); + // ``` +} +void compute<each T>(expand each T args) where T == int +{ + // Maps every element in `args` to `elementValue + 1`, and forward the + // new values as arguments to `printNumber`. + printNumber(expand (each args) + 1); + + // The above statement is equivalent to: + // ``` + // printNumber(args[0] + 1, args[1] + 1, ..., args[n-1] + 1); + // ``` +} +void test() +{ + compute(1,2,3); + // Prints: + // 2 + // 3 + // 4 +} +``` + +As another example, you can use `expand` expression to sum up elements in a variadic argument pack: +```csharp +void accumulateHelper(inout int dest, int value) { dest += value; } + +void sum<each T>(expand each T args) where T == int +{ + int result = 0; + expand accumulateHelper(result, each args); + + // The above statement is equivalent to: + // ``` + // (accumulateHelper(result, args[0]), accumulateHelper(result, args[1]), ..., accumulateHelper(result, args[n-1])); + // ``` + + return result; +} + +void test() +{ + int x = sum(1,2,3); // x == 6 +} +``` + +Note that a variadic type pack parameter must appear at the end of a parameter list. If a generic type contains more than one +type pack parameters, then each type pack must contain the same number of arguments at instantiation sites. + Builtin Interfaces ----------------------------- @@ -822,6 +966,10 @@ Slang supports the following builtin interfaces: - `IDifferentiable`, represents a value that is differentiable. - `IFloat`, represents a logical float that supports both `IArithmetic`, `ILogical` and `IDifferentiable` operations. Also provides methods to convert to and from `float`. Implemented by all builtin floating-point scalar, vector and matrix types. - `IArray<T>`, represents a logical array that supports retrieving an element of type `T` from an index. Implemented by array types, vectors and matrices. +- `IFunc<TResult, TParams...>` represent a callable object (with `operator()`) that returns `TResult` and takes `TParams...` as argument. +- `IMutatingFunc<TResult, TParams...>`, similar to `IFunc`, but the `operator()` method is `[mutating]`. +- `IDifferentiableFunc<TResult, TParams...>`, similar to `IFunc`, but the `operator()` method is `[Differentiable]`. +- `IDifferentiableMutatingFunc<TResult, TParams...>`, similar to `IFunc,` but the `operator()` method is `[Differentiable]` and `[mutating]`. - `__EnumType`, implemented by all enum types. - `__BuiltinIntegerType`, implemented by all integer scalar types. - `__BuiltinFloatingPointType`, implemented by all floating-point scalar types. diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html index e6c75b1f1..df1b4ba9d 100644 --- a/docs/user-guide/toc.html +++ b/docs/user-guide/toc.html @@ -79,10 +79,12 @@ <li data-link="interfaces-generics#supported-constructs-in-interface-definitions"><span>Supported Constructs in Interface Definitions</span></li> <li data-link="interfaces-generics#associated-types"><span>Associated Types</span></li> <li data-link="interfaces-generics#generic-value-parameters"><span>Generic Value Parameters</span></li> +<li data-link="interfaces-generics#type-equality-constraints"><span>Type Equality Constraints</span></li> <li data-link="interfaces-generics#interface-typed-values"><span>Interface-typed Values</span></li> <li data-link="interfaces-generics#extending-a-type-with-additional-interface-conformances"><span>Extending a Type with Additional Interface Conformances</span></li> <li data-link="interfaces-generics#is-and-as-operator"><span>`is` and `as` Operator</span></li> <li data-link="interfaces-generics#extensions-to-interfaces"><span>Extensions to Interfaces</span></li> +<li data-link="interfaces-generics#variadic-generics"><span>Variadic Generics</span></li> <li data-link="interfaces-generics#builtin-interfaces"><span>Builtin Interfaces</span></li> </ul> </li> diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 4ed38ef69..36aa3a313 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -583,6 +583,11 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl TypeExp sub; TypeExp sup; + // If this decl is defined in a where clause, store the source location of the where token. + SourceLoc whereTokenLoc = SourceLoc(); + + bool isEqualityConstraint = false; + // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return sup; } }; diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 40a943dc0..24f98391c 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -477,6 +477,18 @@ void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl) } else if (auto typeConstraint = as<TypeConstraintDecl>(decl)) { + if (auto genericTypeConstraint = as<GenericTypeConstraintDecl>(typeConstraint)) + { + // A generic constraint decl has a left hand side and right hand side expression + // for the base and super type of the constraint. + // In the case of a folded-in constraint syntax as in `Foo<T:IBar>`, + // the left hand side of the constraint is represented by the same token + // as the parameter decl itself, so we don't need to traverse into it. + // In the case of `Foo<T> where T:IBar`, the left hand side is its own + // expression so we do want to traverse it. + if (genericTypeConstraint->whereTokenLoc.isValid()) + visitExpr(genericTypeConstraint->sub.exp); + } visitExpr(typeConstraint->getSup().exp); } else if (auto typedefDecl = as<TypeDefDecl>(decl)) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 21948dc04..83a4cf353 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1510,14 +1510,6 @@ namespace Slang const RequirementDictionary& getRequirementDictionary() { - if (m_requirementDictionary.getCount() != m_requirements.getCount()) - { - for (Index i = m_requirementDictionary.getCount(); i < m_requirements.getCount(); i++) - { - auto& r = m_requirements[i]; - m_requirementDictionary.add(r.key, r.value); - } - } return m_requirementDictionary; } @@ -1532,11 +1524,8 @@ namespace Slang // Whether or not this witness table is an extern declaration. bool isExtern = false; - // Satisfying values of each requirement. - List<KeyValuePair<Decl*, RequirementWitness>> m_requirements; - // Cached dictionary for looking up satisfying values. - SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary; + RequirementDictionary m_requirementDictionary; RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 3a1318696..401d73e29 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -72,7 +72,7 @@ class DeclRefType : public Type }; template<typename T> -DeclRef<T> isDeclRefTypeOf(Type* type) +DeclRef<T> isDeclRefTypeOf(Val* type) { if (auto declRefType = as<DeclRefType>(type)) { diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 2599ce46a..8752195cb 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -626,6 +626,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness return as<DeclRefBase>(getOperand(2)); } + bool isEquality() + { + if (auto declRef = getDeclRef().as<GenericTypeConstraintDecl>()) + return declRef.getDecl()->isEqualityConstraint; + return false; + } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); @@ -895,4 +902,31 @@ void SubstitutionSet::forEachSubstitutionArg(F func) const } } } + +inline bool isTypeEqualityWitness(Val* witness) +{ + if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) + { + return declaredWitness->isEquality(); + } + else if (as<TypeEqualityWitness>(witness)) + { + return true; + } + else if (auto eachWitness = as<EachSubtypeWitness>(witness)) + { + return isTypeEqualityWitness(eachWitness->getPatternTypeWitness()); + } + else if (auto typePackWitness = as<TypePackSubtypeWitness>(witness)) + { + for (Index i = 0; i < typePackWitness->getCount(); i++) + { + if (!isTypeEqualityWitness(typePackWitness->getWitness(i))) + return false; + } + return true; + } + return false; +} + } // namespace Slang diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 3e3ed5297..90b0e44f5 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -648,7 +648,7 @@ namespace Slang else if (auto subEachType = as<EachType>(constraintDeclRef.getDecl()->sub.type)) constrainedGenericParams.add(as<DeclRefType>(subEachType->getElementType())->getDeclRef().getDecl()); - if (sub->equals(sup)) + if (sub->equals(sup) && isDeclRefTypeOf<InterfaceDecl>(sup)) { // We are trying to use an interface type itself to conform to the // type constraint. We can reach this case when the user code does @@ -674,6 +674,14 @@ namespace Slang sup, system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None); } + + if (constraintDecl->isEqualityConstraint) + { + // If constraint is an equality constraint, we need to make sure + // the witness is equality witness. + if (!isTypeEqualityWitness(subTypeWitness)) + subTypeWitness = nullptr; + } if(subTypeWitness) { diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 85040dc55..c0d7feaff 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1019,6 +1019,22 @@ namespace Slang *outCost = kConversionCost_CastToInterface; return true; } + else if (auto fromIsToWitness = tryGetSubtypeWitness(toType, fromType)) + { + // Is toType and fromType the same via some type equality witness? + // If so there is no need to do any conversion. + // + if (isTypeEqualityWitness(fromIsToWitness)) + { + if (outToExpr) + { + *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, fromIsToWitness); + } + if (outCost) + *outCost = 0; + return true; + } + } // Disallow converting to a ParameterGroupType. // diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ad3f94fc3..02e3241a9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2278,14 +2278,18 @@ namespace Slang markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable)) { - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. m_astBuilder->incrementEpoch(); return true; } + else + { + witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); + } // Something went wrong. return false; @@ -2471,19 +2475,16 @@ namespace Slang // conformance on the synthesized decl. checkAggTypeConformance(aggTypeDecl); - if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + if (!doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) { - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); - - // Incrase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. - m_astBuilder->incrementEpoch(); - return true; + // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. + // If not, there is something wrong with the code synthesis logic. For now we just return false + // instead of crashing so the user can work around the issues. + witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); + return false; } - - // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. - // If not, there is something wrong with the code synthesis logic. For now we just return false - // instead of crashing so the user can work around the issues. - return false; + return true; } void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) @@ -2497,7 +2498,7 @@ namespace Slang CheckConstraintSubType(decl->sub); decl->sub = TranslateTypeNodeForced(decl->sub); decl->sup = TranslateTypeNodeForced(decl->sup); - if (!isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sub.type)) + if (!decl->isEqualityConstraint && !isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sub.type)) { getSink()->diagnose(decl->sup.exp, Diagnostics::invalidTypeForConstraint, decl->sup); } @@ -3548,18 +3549,28 @@ namespace Slang bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable) { + SLANG_UNUSED(satisfyingType); + // We will enumerate the type constraints placed on the // associated type and see if they can be satisfied. // bool conformance = true; Val* witness = nullptr; - for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(m_astBuilder, requiredAssociatedTypeDeclRef)) + for (auto requiredConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, requiredAssociatedTypeDeclRef)) { // Grab the type we expect to conform to from the constraint. auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); + auto subType = getSub(m_astBuilder, requiredConstraintDeclRef); + // Perform a search for a witness to the subtype relationship. - witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); + witness = tryGetSubtypeWitness(subType, requiredSuperType); + if (witness) + { + auto genConstraint = as<GenericTypeConstraintDecl>(requiredConstraintDeclRef.getDecl()); + if (genConstraint && genConstraint->isEqualityConstraint && !isTypeEqualityWitness(witness)) + witness = nullptr; + } if (witness) { // If a subtype witness was found, then the conformance @@ -3588,6 +3599,15 @@ namespace Slang if (declRefType->getDeclRef().getDecl()->hasModifier<ToBeSynthesizedModifier>()) return false; } + + // Register the satisfying type to the witness table + // before checking the constraints, since the subtype of + // the constraints maybe referencing the satisfying type via + // witness lookups. + auto requirementWitness = RequirementWitness(satisfyingType->getCanonicalType()); + witnessTable->m_requirementDictionary[requiredAssociatedTypeDeclRef.getDecl()] + = requirementWitness; + // We need to confirm that the chosen type `satisfyingType`, // meets all the constraints placed on the associated type // requirement `requiredAssociatedTypeDeclRef`. @@ -3601,15 +3621,11 @@ namespace Slang // TODO: if any conformance check failed, we should probably include // that in an error message produced about not satisfying the requirement. - if(conformance) + if (!conformance) { - // If all the constraints were satisfied, then the chosen - // type can indeed satisfy the interface requirement. - witnessTable->add( - requiredAssociatedTypeDeclRef.getDecl(), - RequirementWitness(satisfyingType->getCanonicalType())); + witnessTable->m_requirementDictionary.remove(requiredAssociatedTypeDeclRef.getDecl()); } - + return conformance; } @@ -9667,7 +9683,8 @@ namespace Slang ASTNodeType::DeclRefType); auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); - assert(constraintTypes); + if (!constraintTypes) + continue; constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f39d5be1d..dc4568f8a 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -766,6 +766,12 @@ namespace Slang InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo); InheritanceInfo _calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo); + // Get the inner most generic decl that a decl-ref is dependent on. + // For example, `Foo<T>` depends on the generic decl that defines `T`. + // + DeclRef<GenericDecl> getDependentGenericParent(DeclRef<Decl> declRef); + void getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef); + struct DirectBaseInfo { FacetList facets; diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 20f41c1bb..0d3929901 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -97,6 +97,51 @@ namespace Slang return info; } + void SharedSemanticsContext::getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef) + { + auto mergeParent = [](DeclRef<GenericDecl>& currentParent, DeclRef<GenericDecl> newParent) + { + if (!currentParent) + { + currentParent = newParent; + return; + } + if (currentParent == newParent) + return; + if (newParent.getDecl()->isChildOf(currentParent.getDecl())) + currentParent = newParent; + }; + + if (declRef.as<GenericTypeParamDeclBase>()) + { + if (!genericParent) + mergeParent(genericParent, declRef.getParent().as<GenericDecl>()); + return; + } + else if (auto lookupDeclRef = as<LookupDeclRef>(declRef.declRefBase)) + { + if (auto lookupSourceDeclRef = isDeclRefTypeOf<Decl>(lookupDeclRef->getLookupSource())) + getDependentGenericParentImpl(genericParent, lookupSourceDeclRef); + } + else if (auto genericAppDeclRef = as<GenericAppDeclRef>(declRef.declRefBase)) + { + for (Index i = 0; i < genericAppDeclRef->getArgCount(); i++) + { + if (auto argDeclRef = isDeclRefTypeOf<Decl>(genericAppDeclRef->getArg(i))) + { + getDependentGenericParentImpl(genericParent, argDeclRef); + } + } + } + } + + DeclRef<GenericDecl> SharedSemanticsContext::getDependentGenericParent(DeclRef<Decl> declRef) + { + DeclRef<GenericDecl> genericParent; + getDependentGenericParentImpl(genericParent, declRef); + return genericParent; + } + InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) { // This method is the main engine for computing linearized inheritance @@ -305,39 +350,82 @@ namespace Slang // We now look at the structure of the declaration itself // to help us enumerate the direct bases. // - if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>()) + auto currentDeclRef = declRef; + for (; currentDeclRef;) { - // In the case where we have an aggregate type or `extension` - // declaration, we can use the explicit list of direct bases. - // - for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef)) + if (auto aggTypeDeclBaseRef = currentDeclRef.as<AggTypeDeclBase>()) { - visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); - - // Note: In certain cases something takes the *syntactic* form of an inheritance - // clause, but it is not actually something that should be treated as implying - // a subtype relationship. For example, an `enum` declaration can use what looks - // like an inheritance clause to indicate its underlying "tag type." - // - // We skip such pseudo-inheritance relationships for the purposes of determining - // the linearized list of bases. + // In the case where we have an aggregate type or `extension` + // declaration, we can use the explicit list of direct bases. // - if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) - continue; + for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef)) + { + // Note: In certain cases something takes the *syntactic* form of an inheritance + // clause, but it is not actually something that should be treated as implying + // a subtype relationship. For example, an `enum` declaration can use what looks + // like an inheritance clause to indicate its underlying "tag type." + // + // We skip such pseudo-inheritance relationships for the purposes of determining + // the linearized list of bases. + // + if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) + continue; - // The base type and subtype witness can easily be determined - // using the `InheritanceDecl`. - // - auto baseType = getSup(astBuilder, typeConstraintDeclRef); - auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( - selfType, - baseType, - typeConstraintDeclRef); + // The only case we will ever see a GenericTypeConstraintDecl inside a AggTypeDecl is when + // AggTypeDecl is a associatedtype decl. In this case, we will only lookup the type constraint + // if the constraint is on the associated type itself. + // + auto genericTypeConstraintDeclRef = typeConstraintDeclRef.as<GenericTypeConstraintDecl>(); + if (genericTypeConstraintDeclRef) + { + // If the base expr on the constraint isn't even a `VarExpr`, then it can't be referencing + // the associated type itself and we can skip this constraint. + if (!genericTypeConstraintDeclRef.getDecl()->sub.type + && !as<VarExpr>(genericTypeConstraintDeclRef.getDecl()->sub.exp)) + continue; + } + + visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); - addDirectBaseType(baseType, satisfyingWitness); + // For generic type constraint decls, always make sure it is about the type being checked. + // + if (genericTypeConstraintDeclRef) + { + auto subType = getSub(astBuilder, genericTypeConstraintDeclRef); + if (subType != selfType) + continue; + } + else if (currentDeclRef != declRef) + { + continue; + } + // The base type and subtype witness can easily be determined + // using the `InheritanceDecl`. + // + auto baseType = getSup(astBuilder, typeConstraintDeclRef); + auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( + selfType, + baseType, + typeConstraintDeclRef); + + addDirectBaseType(baseType, satisfyingWitness); + } } + if (currentDeclRef.as<AssocTypeDecl>()) + { + // If the current type is an associated type, continue inspecting the base/parent of the + // associatedtype to discover additional constraints defined on the parent associatedtype decls. + // + if (auto lookupDeclRef = as<LookupDeclRef>(currentDeclRef.declRefBase)) + { + currentDeclRef = isDeclRefTypeOf<Decl>(lookupDeclRef->getLookupSource()).as<AssocTypeDecl>(); + continue; + } + } + break; } - else if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDeclBase>()) + + if (auto genericDeclRef = getDependentGenericParent(declRef)) { // The constraints placed on a generic type parameter are siblings of that // parameter in its parent `GenericDecl`, so we need to enumerate all of @@ -349,13 +437,11 @@ namespace Slang // representation would need to take into account canonicalization of // constraints. - auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>(); - SLANG_ASSERT(genericDeclRef); ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); if (auto extensionDecl = as<ExtensionDecl>(genericDeclRef.getDecl()->inner)) { - if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == genericTypeParamDeclRef) + if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == declRef) { // If `T` is a generic parameter where the same generic is an extension on `T`, // then we need to add the extension itself as a facet. @@ -377,13 +463,10 @@ namespace Slang auto superType = getSup(astBuilder, constraintDeclRef); // We only consider constraints where the type represented - // by `genericTypeParamDeclRef` is the subtype, since those + // by `declRef` is the subtype, since those // constraints are the ones that give us information about // the declared supertypes. // - // TODO: consider whether other kinds of constraints could - // also apply here. - // auto subDeclRefType = as<DeclRefType>(subType); if (!subDeclRefType) { @@ -394,7 +477,7 @@ namespace Slang if (!subDeclRefType) continue; } - if (subDeclRefType->getDeclRef() != genericTypeParamDeclRef) + if (subDeclRefType->getDeclRef() != declRef) continue; // Because the constraint is a declared inheritance relationship, @@ -402,9 +485,9 @@ namespace Slang // as in all the preceding cases. // auto satisfyingWitness = _getASTBuilder()->getDeclaredSubtypeWitness( - selfType, - superType, - constraintDeclRef); + selfType, + superType, + constraintDeclRef); addDirectBaseType(superType, satisfyingWitness); } } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f308f340d..5afca48b3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -291,6 +291,10 @@ INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE) // A placeholder witness that ThisType implements the enclosing interface. // Used only in interface definitions. INST(ThisTypeWitness, thisTypeWitness, 1, 0) + +// A placeholder witness for the fact that two types are equal. +INST(TypeEqualityWitness, TypeEqualityWitness, 2, HOISTABLE) + INST(GlobalHashedStringLiterals, global_hashed_string_literals, 0, 0) INST(Module, module, 0, PARENT) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 8f648aabd..fc963697a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4075,6 +4075,8 @@ public: IRInst* createThisTypeWitness(IRType* interfaceType); + IRInst* getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2); + IRInterfaceRequirementEntry* createInterfaceRequirementEntry( IRInst* requirementKey, IRInst* requirementVal); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index acc6abd57..d02c01105 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4614,6 +4614,15 @@ namespace Slang return witness; } + IRInst* IRBuilder::getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2) + { + IRInst* operands[2] = { type1, type2 }; + return (IRType*)createIntrinsicInst( + witnessType, + kIROp_TypeEqualityWitness, + 2, + operands); + } IRStructType* IRBuilder::createStructType() { @@ -8347,6 +8356,7 @@ namespace Slang case kIROp_InterfaceRequirementEntry: case kIROp_Block: case kIROp_Each: + case kIROp_TypeEqualityWitness: return false; /// Liveness markers have no side effects diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 2d6ed2568..06b7c937f 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -717,6 +717,14 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) ASTLookupExprVisitor visitor(&context); if (visitor.dispatchIfNotNull(typeConstraint->getSup().exp)) return true; + if (auto genTypeConstraint = as<GenericTypeConstraintDecl>(node)) + { + if (genTypeConstraint->whereTokenLoc.isValid()) + { + if (visitor.dispatchIfNotNull(genTypeConstraint->sub.exp)) + return true; + } + } } else if (auto typedefDecl = as<TypeDefDecl>(node)) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f62bb631e..813467743 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1753,6 +1753,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower lowerType(context, val->getSup()))); } + LoweredValInfo visitTypeEqualityWitness(TypeEqualityWitness* val) + { + auto subType = lowerType(context, val->getSub()); + auto supType = lowerType(context, val->getSup()); + auto witnessType = context->irBuilder->getWitnessTableType( + lowerType(context, val->getSup())); + return LoweredValInfo::simple(context->irBuilder->getTypeEqualityWitness(witnessType, subType, supType)); + } + LoweredValInfo visitTransitiveSubtypeWitness( TransitiveSubtypeWitness* val) { @@ -4989,6 +4998,19 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> auto superType = lowerType(context, expr->type); auto value = lowerRValueExpr(context, expr->valueArg); + // First, we check if the witness is a type equality witness. + // If so, we can simply emit a bit cast to the target type that should eventually + // fold out to a no-op. + // Note: if we are going to equivalent but not identical types in the future, + // then the cast between equivalent types shouldn't be as simple as a bit cast + // and will require actual coercion logic between the two types. + // For now, we don't support type equivalence witness so this is safe for + // equal types. + if (isTypeEqualityWitness(expr->witnessArg)) + { + return LoweredValInfo::simple(getBuilder()->emitBitCast(superType, getSimpleVal(context, value))); + } + // The actual operation that we need to perform here // depends on the kind of subtype relationship we // are making use of. @@ -5039,7 +5061,6 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return emitCastToConcreteSuperTypeRec(value, superType, expr->witnessArg); } } - SLANG_UNEXPECTED("unexpected case of subtype relationship"); UNREACHABLE_RETURN(LoweredValInfo()); } @@ -8400,8 +8421,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> for (auto constraintDecl : decl->getMembersOfType<GenericTypeConstraintDecl>()) { auto baseType = lowerType(context, constraintDecl->sup.type); - SLANG_ASSERT(baseType && baseType->getOp() == kIROp_InterfaceType); - constraintInterfaces.add((IRInterfaceType*)baseType); + if (baseType && baseType->getOp() == kIROp_InterfaceType) + constraintInterfaces.add((IRInterfaceType*)baseType); } auto assocType = context->irBuilder->getAssociatedType( constraintInterfaces.getArrayView().arrayView); @@ -10686,7 +10707,8 @@ LoweredValInfo emitDeclRef( for (auto argVal : genericSubst->getArgs()) { auto irArgVal = lowerSimpleVal(context, argVal); - SLANG_ASSERT(irArgVal); + if (!irArgVal) + continue; // It is possible that some of the arguments to the generic // represent conformances to conjunction types like `A & B`. diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 121c68be7..5729adb29 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -148,6 +148,12 @@ namespace Slang resetLookupScope(); } + void PushScope(Scope* newScope) + { + currentScope = newScope; + resetLookupScope(); + } + void pushScopeAndSetParent(ContainerDecl* containerDecl) { containerDecl->parentDecl = currentScope->containerDecl; @@ -306,6 +312,8 @@ namespace Slang static Expr* _parseGenericArg(Parser* parser); + static Expr* parsePrefixExpr(Parser* parser); + // static void Unexpected( @@ -721,6 +729,16 @@ namespace Slang return false; } + bool AdvanceIf(Parser* parser, char const* text, Token* outToken) + { + if (parser->LookAheadToken(text)) + { + *outToken = parser->ReadToken(); + return true; + } + return false; + } + /// Information on how to parse certain pairs of matches tokens struct MatchedTokenInfo { @@ -1496,10 +1514,47 @@ namespace Slang } else { - // default case is a type parameter - paramDecl = parser->astBuilder->create<GenericTypeParamDecl>(); - parser->FillPosition(paramDecl); - paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + // Disambiguate between a type parameter and a value parameter. + // If next token is "typename", then it is a type parameter. + bool isTypeParam = AdvanceIf(parser, "typename"); + if (!isTypeParam) + { + // Otherwise, if the next token is an identifier, followed by a colon, comma, '=' or '>', then it is a type parameter. + isTypeParam = parser->LookAheadToken(TokenType::Identifier); + auto nextNextTokenType = peekTokenType(parser, 1); + switch (nextNextTokenType) + { + case TokenType::Colon: + case TokenType::Comma: + case TokenType::OpGreater: + case TokenType::OpAssign: + break; + default: + isTypeParam = false; + break; + } + } + + if (isTypeParam) + { + // Parse as a type parameter. + paramDecl = parser->astBuilder->create<GenericTypeParamDecl>(); + parser->FillPosition(paramDecl); + paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + } + else + { + // Parse as a traditional syntax value parameter in the form of `type paramName`. + auto valueParamDecl = parser->astBuilder->create<GenericValueParamDecl>(); + parser->FillPosition(valueParamDecl); + valueParamDecl->type = parser->ParseTypeExp(); + valueParamDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + if (AdvanceIf(parser, TokenType::OpAssign)) + { + valueParamDecl->initExpr = parser->ParseInitExpr(); + } + return valueParamDecl; + } } if (AdvanceIf(parser, TokenType::Colon)) { @@ -1600,7 +1655,43 @@ namespace Slang } else { - return parseInner(nullptr); + auto genericParent = parser->currentScope ? as<GenericDecl>(parser->currentScope->containerDecl) : nullptr; + return parseInner(genericParent); + } + } + + static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericParent) + { + if (!genericParent) + return; + Token whereToken; + while (AdvanceIf(parser, "where", &whereToken)) + { + auto subType = parser->ParseTypeExp(); + if (AdvanceIf(parser, TokenType::Colon)) + { + for (;;) + { + auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>(); + constraint->whereTokenLoc = whereToken.loc; + parser->FillPosition(constraint); + constraint->sub = subType; + constraint->sup = parser->ParseTypeExp(); + AddMember(genericParent, constraint); + if (!AdvanceIf(parser, TokenType::Comma)) + break; + } + } + else if (AdvanceIf(parser, TokenType::OpEql)) + { + auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>(); + constraint->whereTokenLoc = whereToken.loc; + constraint->isEqualityConstraint = true; + parser->FillPosition(constraint); + constraint->sub = subType; + constraint->sup = parser->ParseTypeExp(); + AddMember(genericParent, constraint); + } } } @@ -1702,7 +1793,7 @@ namespace Slang decl->loc = declaratorInfo.nameAndLoc.loc; decl->nameAndLoc = declaratorInfo.nameAndLoc; - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { // HACK: The return type of the function will already have been // parsed in a scope that didn't include the function's generic @@ -1731,6 +1822,12 @@ namespace Slang } _parseOptSemantics(parser, decl); + + auto funcScope = parser->currentScope; + parser->PopScope(); + maybeParseGenericConstraints(parser, genericParent); + parser->PushScope(funcScope); + decl->body = parseOptBody(parser); if (auto block = as<BlockStmt>(decl->body)) { @@ -2597,7 +2694,7 @@ namespace Slang } else if (parser->LookAheadToken("expand") || parser->LookAheadToken("each")) { - typeSpec.expr = parser->ParseExpression(); + typeSpec.expr = parsePrefixExpr(parser); return typeSpec; } // Uncomment should we decide to enable (a,b,c) tuple types @@ -3335,12 +3432,13 @@ namespace Slang static NodeBase* parseExtensionDecl(Parser* parser, void* /*userData*/) { - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { ExtensionDecl* decl = parser->astBuilder->create<ExtensionDecl>(); parser->FillPosition(decl); decl->targetType = parser->ParseTypeExp(); parseOptionalInheritanceClause(parser, decl); + maybeParseGenericConstraints(parser, genericParent); parseDeclBody(parser, decl); return decl; }); @@ -3357,16 +3455,27 @@ namespace Slang parser->FillPosition(paramConstraint); // substitution needs to be filled during check - Type* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl)); + Type* paramType = nullptr; + if (as<GenericTypeParamDeclBase>(decl)) + { + paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl)); - SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); - paramTypeExpr->loc = decl->loc; - paramTypeExpr->base.type = paramType; - paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType)); + SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); + paramTypeExpr->loc = decl->loc; + paramTypeExpr->base.type = paramType; + paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType)); - paramConstraint->sub = TypeExp(paramTypeExpr); - paramConstraint->sup = parser->ParseTypeExp(); + paramConstraint->sub = TypeExp(paramTypeExpr); + } + else if (as<AssocTypeDecl>(decl)) + { + auto varExpr = parser->astBuilder->create<VarExpr>(); + varExpr->scope = parser->currentScope; + varExpr->name = decl->getName(); + paramConstraint->sub.exp = varExpr; + } + paramConstraint->sup = parser->ParseTypeExp(); AddMember(decl, paramConstraint); } while (AdvanceIf(parser, TokenType::Comma)); } @@ -3380,6 +3489,7 @@ namespace Slang assocTypeDecl->nameAndLoc = NameLoc(nameToken); assocTypeDecl->loc = nameToken.loc; parseOptionalGenericConstraints(parser, assocTypeDecl); + maybeParseGenericConstraints(parser, assocTypeDecl); parser->ReadToken(TokenType::Semicolon); return assocTypeDecl; } @@ -3424,11 +3534,12 @@ namespace Slang AdvanceIf(parser, TokenType::CompletionRequest); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { // We allow for an inheritance clause on a `struct` // so that it can conform to interfaces. parseOptionalInheritanceClause(parser, decl); + maybeParseGenericConstraints(parser, genericParent); parseDeclBody(parser, decl); return decl; }); @@ -3661,7 +3772,7 @@ namespace Slang { ConstructorDecl* decl = parser->astBuilder->create<ConstructorDecl>(); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { // Note: we leave the source location of this decl as invalid, to // trigger the fallback logic that fills in the location of the @@ -3680,6 +3791,10 @@ namespace Slang decl->nameAndLoc.name = getName(parser, "$init"); parseParameterList(parser, decl); + auto funcScope = parser->currentScope; + parser->PopScope(); + maybeParseGenericConstraints(parser, genericParent); + parser->PushScope(funcScope); decl->body = parseOptBody(parser); @@ -3997,7 +4112,7 @@ namespace Slang parser->FillPosition(decl); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { parser->PushScope(decl); parseModernParamList(parser, decl); @@ -4009,6 +4124,10 @@ namespace Slang { decl->returnType = parser->ParseTypeExp(); } + auto funcScope = parser->currentScope; + parser->PopScope(); + maybeParseGenericConstraints(parser, genericParent); + parser->PushScope(funcScope); decl->body = parseOptBody(parser); if (auto blockStmt = as<BlockStmt>(decl->body)) decl->closingSourceLoc = blockStmt->closingSourceLoc; @@ -4025,8 +4144,9 @@ namespace Slang parser->FillPosition(decl); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { + maybeParseGenericConstraints(parser, genericParent); if( expect(parser, TokenType::OpAssign) ) { decl->type = parser->ParseTypeExp(); @@ -4860,7 +4980,7 @@ namespace Slang rs->nameAndLoc.name = generateName(this); rs->nameAndLoc.loc = rs->loc; } - return parseOptGenericDecl(this, [&](GenericDecl*) + return parseOptGenericDecl(this, [&](GenericDecl* genericParent) { // We allow for an inheritance clause on a `struct` // so that it can conform to interfaces. @@ -4878,6 +4998,7 @@ namespace Slang rs->hasBody = false; return rs; } + maybeParseGenericConstraints(this, genericParent); parseDeclBody(this, rs); return rs; }); @@ -4977,9 +5098,10 @@ namespace Slang addModifier(decl, parser->astBuilder->create<TransparentModifier>()); } - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { parseOptionalInheritanceClause(parser, decl); + maybeParseGenericConstraints(parser, genericParent); parser->ReadToken(TokenType::LBrace); Token closingToken; parser->pushScopeAndSetParent(decl); @@ -7583,7 +7705,7 @@ namespace Slang { EachExpr* eachExpr = parser->astBuilder->create<EachExpr>(); eachExpr->loc = loc; - eachExpr->baseExpr = parser->ParseLeafExpression(); + eachExpr->baseExpr = parsePostfixExpr(parser); return eachExpr; } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index a55f0eb1a..ec7169e04 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -338,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt RefPtr<WitnessTable> result = new WitnessTable(); result->baseType = as<Type>(newBaseType); result->witnessedType = as<Type>(newWitnessedType); - for (auto requirement : m_requirements) + for (auto requirement : m_requirementDictionary) { auto newRequirement = requirement.value.specialize(astBuilder, subst); result->add(requirement.key, newRequirement); @@ -500,7 +500,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt void WitnessTable::add(Decl* decl, RequirementWitness const& witness) { - m_requirements.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness)); m_requirementDictionary.add(decl, witness); } diff --git a/tests/bugs/eroneous-generic-parse.slang b/tests/bugs/eroneous-generic-parse.slang index 80a693456..18bcb3b1e 100644 --- a/tests/bugs/eroneous-generic-parse.slang +++ b/tests/bugs/eroneous-generic-parse.slang @@ -1,8 +1,9 @@ -//DIAGNOSTIC_TEST:SIMPLE: -target hlsl -entry computeMain -stage compute +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target hlsl -entry computeMain -stage compute //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=16):out RWStructuredBuffer<int> outputBuffer; +// CHECK: error 20001 // Previously this definition would lead to an infinite loop in parsing. int doThing<1>() { return 2; } diff --git a/tests/bugs/eroneous-generic-parse.slang.expected b/tests/bugs/eroneous-generic-parse.slang.expected deleted file mode 100644 index a6ee73d88..000000000 --- a/tests/bugs/eroneous-generic-parse.slang.expected +++ /dev/null @@ -1,8 +0,0 @@ -result code = -1 -standard error = { -tests/bugs/eroneous-generic-parse.slang(7): error 20001: unexpected integer literal, expected identifier -int doThing<1>() { return 2; } - ^ -} -standard output = { -} diff --git a/tests/bugs/parser-infinite-loop.slang b/tests/bugs/parser-infinite-loop.slang index 70abc9260..036202c4a 100644 --- a/tests/bugs/parser-infinite-loop.slang +++ b/tests/bugs/parser-infinite-loop.slang @@ -1,5 +1,6 @@ -//DIAGNOSTIC_TEST:SIMPLE: +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): +// CHECK: error 20001: struct test { float3 field; diff --git a/tests/bugs/parser-infinite-loop.slang.expected b/tests/bugs/parser-infinite-loop.slang.expected deleted file mode 100644 index df1d731bc..000000000 --- a/tests/bugs/parser-infinite-loop.slang.expected +++ /dev/null @@ -1,11 +0,0 @@ -result code = -1 -standard error = { -tests/bugs/parser-infinite-loop.slang(10): error 20001: unexpected integer literal, expected identifier - vector<int,2> v; - ^ -tests/bugs/parser-infinite-loop.slang(10): error 20001: unexpected identifier, expected '(' - vector<int,2> v; - ^ -} -standard output = { -} diff --git a/tests/language-feature/generics/where-1.slang b/tests/language-feature/generics/where-1.slang new file mode 100644 index 000000000..904e03d52 --- /dev/null +++ b/tests/language-feature/generics/where-1.slang @@ -0,0 +1,20 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that we can use `where` clause to constrain generic type parameters. + +T process<T, int N>(vector<T, N> v) where T : IFloat +{ + return v[0]; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + float a = 1.0; + outputBuffer[0] = process(a); + // CHECK: 1.0 +} diff --git a/tests/language-feature/generics/where-2.slang b/tests/language-feature/generics/where-2.slang new file mode 100644 index 000000000..b3e9bb86a --- /dev/null +++ b/tests/language-feature/generics/where-2.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that we can use `where` clause to constrain associatedtype of generic type parameters. + +interface IFoo +{ + associatedtype TA; +} + +struct FooImpl : IFoo +{ + typealias TA = int; +} + +__generic<typename T> +T.TA process(T v) + where T : IFoo + where T.TA == int +{ + return 1; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + FooImpl fooImpl; + outputBuffer[0] = process(fooImpl); + // CHECK: 1.0 +} diff --git a/tests/language-feature/generics/where-3.slang b/tests/language-feature/generics/where-3.slang new file mode 100644 index 000000000..730373b76 --- /dev/null +++ b/tests/language-feature/generics/where-3.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that we can use `where` clause to constrain the type of a type pack. + +interface IFoo +{ + associatedtype TA; +} + +struct FooImpl : IFoo +{ + typealias TA = int; +} + +void add(inout int a, int b) +{ + a += b; +} + +int process<each T>(T v) where T == int +{ + int result = 0; + expand add(result, each v); + return result; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + outputBuffer[0] = process(1,2,3); + // CHECK: 6.0 +} diff --git a/tests/language-feature/generics/where-4.slang b/tests/language-feature/generics/where-4.slang new file mode 100644 index 000000000..9f86ca7dc --- /dev/null +++ b/tests/language-feature/generics/where-4.slang @@ -0,0 +1,52 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that we can use `where` clause to constrain an associatedtype. + +interface IBar +{ + associatedtype TB; + TB get(); +} + +interface IFoo +{ + associatedtype TA : IBar where TA.TB == int; + + TA getVal(); +} + +struct BarImpl : IBar +{ + typealias TB = int; + int x; + int get() { return x; } +} + +struct FooImpl : IFoo +{ + typealias TA = BarImpl; + TA getVal() { TA a; a.x = 1; return a; } +} + +int helper<T : IFoo>(T foo) +{ + // foo.getVal().get() has type `T.TA.TB`, + // because there is a type equality constraint defined on + // `IFoo.TA` such that `IFoo::TA.TB == int`, we should be able + // to conclude that `T.TA.TB` is `int` and the `return` here + // should type check. + return foo.getVal().get(); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + FooImpl foo; + + outputBuffer[0] = helper(foo); + // CHECK: 1.0 +} |
