diff options
| author | Yong He <yonghe@outlook.com> | 2024-09-04 13:25:37 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-04 13:25:37 -0700 |
| commit | ddd29057e48a5b309726750e3daf78bfd073038e (patch) | |
| tree | a054b99acb87d61ef4818dce5fa837ccfd050288 | |
| parent | 56a3c028a6725e13a2ae3a724eaee05ad9f4802a (diff) | |
Fix extension override behavior, and disallow extension on interface types. (#4977)
* Add a test to ensure extension does not override existing conformance.
* Fix doc.
* Update documentation.
* Fix doc.
* Add diagnostic test.
| -rw-r--r-- | docs/user-guide/06-interfaces-generics.md | 51 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-inheritance.cpp | 52 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 32 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | tests/diagnostics/interfaces/interface-extension.slang | 10 | ||||
| -rw-r--r-- | tests/language-feature/extensions/extension-override.slang | 65 |
8 files changed, 195 insertions, 68 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index 982b2ba8b..569b36e1d 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -742,7 +742,8 @@ See [if-let syntax](convenience-features.html#if_let-syntax) for more details. Extensions to Interfaces ----------------------------- -In addition to extending ordinary types, you can define extensions on interfaces as well: +In addition to extending ordinary types, you can define extensions on all types that conforms to some interface: + ```csharp // An example interface. interface IFoo @@ -750,9 +751,8 @@ interface IFoo int foo(); } -// Extending `IFoo` with a new method requirement -// with a default implementation. -extension IFoo +// Extend any type `T` that conforms to `IFoo` with a `bar` method. +extension<T:IFoo> T { int bar() { return 0; } } @@ -765,42 +765,47 @@ int use(IFoo foo) } ``` -Although the syntax of above listing suggests that we are extending an interface with additional requirements, this interpretation does not make logical sense in many ways. Consider a type `MyType` that exists before the extension is defined: -```csharp -struct MyType : IFoo -{ - int foo() { return 0; } -} -``` +Note that `interface` types cannot be extended, because extending an `interface` with new requirements would make all existing types that conforms +to the interface no longer valid. -If we extend the `IFoo` with new requirements, the existing `MyType` definition would become invalid since `MyType` no longer provides implementations to all interface requirements. Instead, what an `extension` on an interface `IFoo` means is that for all types that conforms to the `IFoo` interface and does not have a `bar` method defined, add a `bar` method defined in this extension to that type so that all `IFoo` typed values have a `bar` method defined. If a type already defines a matching `bar` method, then the existing method will always override the default method provided in the extension: +In the presence of extensions, it is possible for a type to have multiple ways to +conform to an interface. In this case, Slang will always prefer the more specific conformance +over the generic one. For example, the following code illustrates this behavior: ```csharp +interface IBase{} interface IFoo { int foo(); } -struct MyFoo1 : IFoo + +// MyObject directly implements IBase: +struct MyObject : IBase, IFoo { int foo() { return 0; } } -extension IFoo + +// Generic extension that applies to all types that conforms to `IBase`: +extension<T:IBase> T : IFoo { - int bar() { return 0; } + int foo() { return 1; } } -struct MyFoo2 : IFoo + +int helper<T:IFoo>(T obj) { - int foo() { return 0; } - int bar() { return 1; } + return obj.foo(); } -void test() + +int test() { - MyFoo1 f1; - MyFoo2 f2; - int a = f1.bar(); // a == 0, calling the method in the extension. - int b = f2.bar(); // b == 1, calling the existing method in `MyFoo2`. + MyObject obj; + + // Returns 0, the conformance defined directly by the type + // is preferred. + return helper(obj); } ``` + This feature is similar to extension traits in Rust. diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index e5551a875..3e3ed5297 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -83,16 +83,22 @@ namespace Slang Type* interfaceType) { // The most basic test here should be: does the type declare conformance to the trait. - if (isSubtype(type, interfaceType, constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None)) - return type; - - // If additional subtype witnesses are provided for `type` in `constraints`, - // try to use them to see if the interface is satisfied. + if (constraints->subTypeForAdditionalWitnesses == type) { + // If additional subtype witnesses are provided for `type` in `constraints`, + // try to use them to see if the interface is satisfied. if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType)) return type; } + else + { + if (isSubtype( + type, + interfaceType, + constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None)) + return type; + } // Just because `type` doesn't conform to the given `interfaceDeclRef`, that // doesn't necessarily indicate a failure. It is possible that we have a call @@ -653,18 +659,22 @@ namespace Slang } // Search for a witness that shows the constraint is satisfied. - auto subTypeWitness = isSubtype( - sub, - sup, - system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None); - if (!subTypeWitness) + SubtypeWitness* subTypeWitness = nullptr; + if (sub == system->subTypeForAdditionalWitnesses) { - if (sub == system->subTypeForAdditionalWitnesses) - { - // If no witness was found, try to find the witness from additional witness. - system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness); - } + // If we are trying to find the subtype info for a type whose inheritance info is + // being calculated, use what we have already known about the type. + system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness); } + else + { + // The general case is to initiate a subtype query. + subTypeWitness = isSubtype( + sub, + sup, + system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None); + } + if(subTypeWitness) { // We found a witness, so it will become an (implicit) argument. diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5654ac7a6..e5e8e8acc 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -8246,7 +8246,12 @@ namespace Slang if (auto targetDeclRefType = as<DeclRefType>(decl->targetType)) { // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>()) + if (targetDeclRefType->getDeclRef().as<InterfaceDecl>()) + { + getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnInterface, decl->targetType); + return; + } + else if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); @@ -8303,6 +8308,7 @@ namespace Slang // to extend. // decl->targetType = CheckProperType(decl->targetType); + _validateExtensionDeclTargetType(decl); _validateExtensionDeclMembers(decl); @@ -9188,13 +9194,13 @@ namespace Slang // look up extensions based on what would be visible to that // module. // - // We need to consider the extensions declared in the module itself, + // Extensions declared in the module itself should have already + // been registered when we check them, but we still need to bring // along with everything the module imported. // // Note: there is an implicit assumption here that the `importedModules` // member on the `SharedSemanticsContext` is accurate in this case. // - _addCandidateExtensionsFromModule(m_module->getModuleDecl()); for( auto moduleDecl : this->importedModulesList ) { _addCandidateExtensionsFromModule(moduleDecl); diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 0dc80cdc3..20f41c1bb 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -422,40 +422,42 @@ namespace Slang { considerExtension(directAggTypeDeclRef, nullptr); } - HashSet<Type*> supTypesConsideredForExtensionApplication; - Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses; - for (;;) + if (!declRef.as<ExtensionDecl>()) { - // After we flatten the list of bases, we may discover additional opportunities - // to apply extensions. - List<DeclRef<AggTypeDecl>> supTypeWorkList; - for (auto curFacet : directBaseFacets) + HashSet<Type*> supTypesConsideredForExtensionApplication; + Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses; + for (;;) { - if (!curFacet->subtypeWitness) - continue; - auto inheritanceInfo = getInheritanceInfo(curFacet->subtypeWitness->getSup(), circularityInfo); - for (auto facet : inheritanceInfo.facets) + // After we flatten the list of bases, we may discover additional opportunities + // to apply extensions. + List<DeclRef<AggTypeDecl>> supTypeWorkList; + auto base = directBases.begin(); + for (auto baseFacet = directBaseFacets.getHead(); baseFacet.getImpl(); baseFacet = baseFacet->next) { - if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>()) + for (auto facet : (*base)->facets) { - SubtypeWitness* transitiveWitness = curFacet->subtypeWitness; - transitiveWitness = astBuilder->getTransitiveSubtypeWitness(curFacet->subtypeWitness, facet->subtypeWitness); - additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness); - if (supTypesConsideredForExtensionApplication.add(facet->origin.type)) + if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>()) { - supTypeWorkList.add(interfaceDeclRef); + SubtypeWitness* transitiveWitness = baseFacet->subtypeWitness; + transitiveWitness = astBuilder->getTransitiveSubtypeWitness(baseFacet->subtypeWitness, facet->subtypeWitness); + additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness); + if (supTypesConsideredForExtensionApplication.add(facet->origin.type)) + { + supTypeWorkList.add(interfaceDeclRef); + } } } + ++base; } + bool canExit = true; + for (auto baseItem : supTypeWorkList) + { + if (considerExtension(baseItem, &additionalSubtypeWitnesses)) + canExit = false; + } + if (canExit) + break; } - bool canExit = true; - for (auto baseItem : supTypeWorkList) - { - if (considerExtension(baseItem, &additionalSubtypeWitnesses)) - canExit = false; - } - if (canExit) - break; } // At this point, the list of direct bases (each with its own linearization) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 6d376dba3..0e01eeed2 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1187,6 +1187,16 @@ namespace Slang return CountParameters(parentGeneric).required; } + DeclRef<Decl> getParentDeclRef(DeclRef<Decl> declRef) + { + auto parent = declRef.getParent(); + while (parent.as<GenericDecl>()) + { + parent = parent.getParent(); + } + return parent; + } + int SemanticsVisitor::CompareLookupResultItems( LookupResultItem const& left, LookupResultItem const& right) @@ -1204,13 +1214,31 @@ namespace Slang // directly (it is only visible through the requirement witness // information for inheritance declarations). // - auto leftDeclRefParent = left.declRef.getParent(); - auto rightDeclRefParent = right.declRef.getParent(); + auto leftDeclRefParent = getParentDeclRef(left.declRef); + auto rightDeclRefParent = getParentDeclRef(right.declRef); bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef.getDecl()); bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef.getDecl()); if(leftIsInterfaceRequirement != rightIsInterfaceRequirement) return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement); + // Prefer non-extension declarations over extension declarations. + bool leftIsExtension = as<ExtensionDecl>(leftDeclRefParent.getDecl()) != nullptr; + bool rightIsExtension = as<ExtensionDecl>(rightDeclRefParent.getDecl()) != nullptr; + if (leftIsExtension != rightIsExtension) + { + return int(leftIsExtension) - int(rightIsExtension); + } + else if (leftIsExtension) + { + // If both are declared in extensions, prefer the one that is least generic. + bool leftIsGeneric = leftDeclRefParent.getParent().as<GenericDecl>() != nullptr; + bool rightIsGeneric = rightDeclRefParent.getParent().as<GenericDecl>() != nullptr; + if (leftIsGeneric != rightIsGeneric) + { + return int(leftIsGeneric) - int(rightIsGeneric); + } + } + // Any decl is strictly better than a module decl. bool leftIsModule = (as<ModuleDeclarationDecl>(left.declRef) != nullptr); bool rightIsModule = (as<ModuleDeclarationDecl>(right.declRef) != nullptr); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 5285b5c6e..81170fac3 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -554,6 +554,7 @@ DIAGNOSTIC(30832, Error, invalidTypeForInheritance, "type '$0' cannot be used fo DIAGNOSTIC(30850, Error, invalidExtensionOnType, "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.") DIAGNOSTIC(30851, Error, invalidMemberTypeInExtension, "$0 cannot be a part of an `extension`") +DIAGNOSTIC(30852, Error, invalidExtensionOnInterface, "cannot extend interface type '$0'. consider using a generic extension: `extension<T:$0> T {...}`.") // 309xx: subscripts DIAGNOSTIC(30900, Error, multiDimensionalArrayNotSupported, "multi-dimensional array is not supported.") diff --git a/tests/diagnostics/interfaces/interface-extension.slang b/tests/diagnostics/interfaces/interface-extension.slang new file mode 100644 index 000000000..b63b454ab --- /dev/null +++ b/tests/diagnostics/interfaces/interface-extension.slang @@ -0,0 +1,10 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main -disable-specialization + +interface IFoo{} + + +// CHECK: ([[# @LINE+1]]): error 30852 +extension IFoo +{ + int f() { return 0; } +}
\ No newline at end of file diff --git a/tests/language-feature/extensions/extension-override.slang b/tests/language-feature/extensions/extension-override.slang new file mode 100644 index 000000000..30fa64965 --- /dev/null +++ b/tests/language-feature/extensions/extension-override.slang @@ -0,0 +1,65 @@ +// Test that the override behavior around extensions and generic extensions works as expected. + +// When there are multiple ways for a type to conform to an interface, then the expected behavior +// is that: +// 1. If the type directly implements an interface, use that conformance. +// 2. Otherwise, if there is a direct extension on the type that makes it conform to the interface, use that +// extension. +// 3. Otherwise, if there is a generic extension that makes the type conform to the interface, use that. + +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +interface IFoo +{ + int getVal(); +} + +interface IBar +{ + int getValPlusOne(); +} + +interface IBaz +{ + int getValPlusTwo(); +} + +struct MyInt +{ + int v; +} + +extension MyInt : IFoo +{ + int getVal() { return v; } +} + +extension MyInt : IBar +{ + int getValPlusOne() { return this.getVal() + 2; } +} + +extension<T: IFoo> T : IBar +{ + int getValPlusOne() { return this.getVal() + 1; } +} + +int helper1<T:IBar>(T v){ return v.getValPlusOne();} +int helper2<T:IFoo>(T v){ return v.getValPlusOne();} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + MyInt v = {1}; + + // CHECK: 3 + outputBuffer[0] = v.getValPlusOne(); // should call MyInt::ext::getValPlusOne(); + + // CHECK: 3 + outputBuffer[1] = helper1(v); // should call MyInt::ext::getValPlusOne(); + + // CHECK: 2 + outputBuffer[2] = helper2(v); // should call T::ext::getValPlusOne(); +}
\ No newline at end of file |
