summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-09-04 13:25:37 -0700
committerGitHub <noreply@github.com>2024-09-04 13:25:37 -0700
commitddd29057e48a5b309726750e3daf78bfd073038e (patch)
treea054b99acb87d61ef4818dce5fa837ccfd050288
parent56a3c028a6725e13a2ae3a724eaee05ad9f4802a (diff)
Fix extension override behavior, and disallow extension on interface types. (#4977)
* Add a test to ensure extension does not override existing conformance. * Fix doc. * Update documentation. * Fix doc. * Add diagnostic test.
-rw-r--r--docs/user-guide/06-interfaces-generics.md51
-rw-r--r--source/slang/slang-check-constraint.cpp40
-rw-r--r--source/slang/slang-check-decl.cpp12
-rw-r--r--source/slang/slang-check-inheritance.cpp52
-rw-r--r--source/slang/slang-check-overload.cpp32
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--tests/diagnostics/interfaces/interface-extension.slang10
-rw-r--r--tests/language-feature/extensions/extension-override.slang65
8 files changed, 195 insertions, 68 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md
index 982b2ba8b..569b36e1d 100644
--- a/docs/user-guide/06-interfaces-generics.md
+++ b/docs/user-guide/06-interfaces-generics.md
@@ -742,7 +742,8 @@ See [if-let syntax](convenience-features.html#if_let-syntax) for more details.
Extensions to Interfaces
-----------------------------
-In addition to extending ordinary types, you can define extensions on interfaces as well:
+In addition to extending ordinary types, you can define extensions on all types that conforms to some interface:
+
```csharp
// An example interface.
interface IFoo
@@ -750,9 +751,8 @@ interface IFoo
int foo();
}
-// Extending `IFoo` with a new method requirement
-// with a default implementation.
-extension IFoo
+// Extend any type `T` that conforms to `IFoo` with a `bar` method.
+extension<T:IFoo> T
{
int bar() { return 0; }
}
@@ -765,42 +765,47 @@ int use(IFoo foo)
}
```
-Although the syntax of above listing suggests that we are extending an interface with additional requirements, this interpretation does not make logical sense in many ways. Consider a type `MyType` that exists before the extension is defined:
-```csharp
-struct MyType : IFoo
-{
- int foo() { return 0; }
-}
-```
+Note that `interface` types cannot be extended, because extending an `interface` with new requirements would make all existing types that conforms
+to the interface no longer valid.
-If we extend the `IFoo` with new requirements, the existing `MyType` definition would become invalid since `MyType` no longer provides implementations to all interface requirements. Instead, what an `extension` on an interface `IFoo` means is that for all types that conforms to the `IFoo` interface and does not have a `bar` method defined, add a `bar` method defined in this extension to that type so that all `IFoo` typed values have a `bar` method defined. If a type already defines a matching `bar` method, then the existing method will always override the default method provided in the extension:
+In the presence of extensions, it is possible for a type to have multiple ways to
+conform to an interface. In this case, Slang will always prefer the more specific conformance
+over the generic one. For example, the following code illustrates this behavior:
```csharp
+interface IBase{}
interface IFoo
{
int foo();
}
-struct MyFoo1 : IFoo
+
+// MyObject directly implements IBase:
+struct MyObject : IBase, IFoo
{
int foo() { return 0; }
}
-extension IFoo
+
+// Generic extension that applies to all types that conforms to `IBase`:
+extension<T:IBase> T : IFoo
{
- int bar() { return 0; }
+ int foo() { return 1; }
}
-struct MyFoo2 : IFoo
+
+int helper<T:IFoo>(T obj)
{
- int foo() { return 0; }
- int bar() { return 1; }
+ return obj.foo();
}
-void test()
+
+int test()
{
- MyFoo1 f1;
- MyFoo2 f2;
- int a = f1.bar(); // a == 0, calling the method in the extension.
- int b = f2.bar(); // b == 1, calling the existing method in `MyFoo2`.
+ MyObject obj;
+
+ // Returns 0, the conformance defined directly by the type
+ // is preferred.
+ return helper(obj);
}
```
+
This feature is similar to extension traits in Rust.
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index e5551a875..3e3ed5297 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -83,16 +83,22 @@ namespace Slang
Type* interfaceType)
{
// The most basic test here should be: does the type declare conformance to the trait.
- if (isSubtype(type, interfaceType, constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None))
- return type;
-
- // If additional subtype witnesses are provided for `type` in `constraints`,
- // try to use them to see if the interface is satisfied.
+
if (constraints->subTypeForAdditionalWitnesses == type)
{
+ // If additional subtype witnesses are provided for `type` in `constraints`,
+ // try to use them to see if the interface is satisfied.
if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType))
return type;
}
+ else
+ {
+ if (isSubtype(
+ type,
+ interfaceType,
+ constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None))
+ return type;
+ }
// Just because `type` doesn't conform to the given `interfaceDeclRef`, that
// doesn't necessarily indicate a failure. It is possible that we have a call
@@ -653,18 +659,22 @@ namespace Slang
}
// Search for a witness that shows the constraint is satisfied.
- auto subTypeWitness = isSubtype(
- sub,
- sup,
- system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None);
- if (!subTypeWitness)
+ SubtypeWitness* subTypeWitness = nullptr;
+ if (sub == system->subTypeForAdditionalWitnesses)
{
- if (sub == system->subTypeForAdditionalWitnesses)
- {
- // If no witness was found, try to find the witness from additional witness.
- system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness);
- }
+ // If we are trying to find the subtype info for a type whose inheritance info is
+ // being calculated, use what we have already known about the type.
+ system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness);
}
+ else
+ {
+ // The general case is to initiate a subtype query.
+ subTypeWitness = isSubtype(
+ sub,
+ sup,
+ system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None);
+ }
+
if(subTypeWitness)
{
// We found a witness, so it will become an (implicit) argument.
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5654ac7a6..e5e8e8acc 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -8246,7 +8246,12 @@ namespace Slang
if (auto targetDeclRefType = as<DeclRefType>(decl->targetType))
{
// Attach our extension to that type as a candidate...
- if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
+ if (targetDeclRefType->getDeclRef().as<InterfaceDecl>())
+ {
+ getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnInterface, decl->targetType);
+ return;
+ }
+ else if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
{
auto aggTypeDecl = aggTypeDeclRef.getDecl();
@@ -8303,6 +8308,7 @@ namespace Slang
// to extend.
//
decl->targetType = CheckProperType(decl->targetType);
+
_validateExtensionDeclTargetType(decl);
_validateExtensionDeclMembers(decl);
@@ -9188,13 +9194,13 @@ namespace Slang
// look up extensions based on what would be visible to that
// module.
//
- // We need to consider the extensions declared in the module itself,
+ // Extensions declared in the module itself should have already
+ // been registered when we check them, but we still need to bring
// along with everything the module imported.
//
// Note: there is an implicit assumption here that the `importedModules`
// member on the `SharedSemanticsContext` is accurate in this case.
//
- _addCandidateExtensionsFromModule(m_module->getModuleDecl());
for( auto moduleDecl : this->importedModulesList )
{
_addCandidateExtensionsFromModule(moduleDecl);
diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp
index 0dc80cdc3..20f41c1bb 100644
--- a/source/slang/slang-check-inheritance.cpp
+++ b/source/slang/slang-check-inheritance.cpp
@@ -422,40 +422,42 @@ namespace Slang
{
considerExtension(directAggTypeDeclRef, nullptr);
}
- HashSet<Type*> supTypesConsideredForExtensionApplication;
- Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses;
- for (;;)
+ if (!declRef.as<ExtensionDecl>())
{
- // After we flatten the list of bases, we may discover additional opportunities
- // to apply extensions.
- List<DeclRef<AggTypeDecl>> supTypeWorkList;
- for (auto curFacet : directBaseFacets)
+ HashSet<Type*> supTypesConsideredForExtensionApplication;
+ Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses;
+ for (;;)
{
- if (!curFacet->subtypeWitness)
- continue;
- auto inheritanceInfo = getInheritanceInfo(curFacet->subtypeWitness->getSup(), circularityInfo);
- for (auto facet : inheritanceInfo.facets)
+ // After we flatten the list of bases, we may discover additional opportunities
+ // to apply extensions.
+ List<DeclRef<AggTypeDecl>> supTypeWorkList;
+ auto base = directBases.begin();
+ for (auto baseFacet = directBaseFacets.getHead(); baseFacet.getImpl(); baseFacet = baseFacet->next)
{
- if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>())
+ for (auto facet : (*base)->facets)
{
- SubtypeWitness* transitiveWitness = curFacet->subtypeWitness;
- transitiveWitness = astBuilder->getTransitiveSubtypeWitness(curFacet->subtypeWitness, facet->subtypeWitness);
- additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness);
- if (supTypesConsideredForExtensionApplication.add(facet->origin.type))
+ if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>())
{
- supTypeWorkList.add(interfaceDeclRef);
+ SubtypeWitness* transitiveWitness = baseFacet->subtypeWitness;
+ transitiveWitness = astBuilder->getTransitiveSubtypeWitness(baseFacet->subtypeWitness, facet->subtypeWitness);
+ additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness);
+ if (supTypesConsideredForExtensionApplication.add(facet->origin.type))
+ {
+ supTypeWorkList.add(interfaceDeclRef);
+ }
}
}
+ ++base;
}
+ bool canExit = true;
+ for (auto baseItem : supTypeWorkList)
+ {
+ if (considerExtension(baseItem, &additionalSubtypeWitnesses))
+ canExit = false;
+ }
+ if (canExit)
+ break;
}
- bool canExit = true;
- for (auto baseItem : supTypeWorkList)
- {
- if (considerExtension(baseItem, &additionalSubtypeWitnesses))
- canExit = false;
- }
- if (canExit)
- break;
}
// At this point, the list of direct bases (each with its own linearization)
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 6d376dba3..0e01eeed2 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1187,6 +1187,16 @@ namespace Slang
return CountParameters(parentGeneric).required;
}
+ DeclRef<Decl> getParentDeclRef(DeclRef<Decl> declRef)
+ {
+ auto parent = declRef.getParent();
+ while (parent.as<GenericDecl>())
+ {
+ parent = parent.getParent();
+ }
+ return parent;
+ }
+
int SemanticsVisitor::CompareLookupResultItems(
LookupResultItem const& left,
LookupResultItem const& right)
@@ -1204,13 +1214,31 @@ namespace Slang
// directly (it is only visible through the requirement witness
// information for inheritance declarations).
//
- auto leftDeclRefParent = left.declRef.getParent();
- auto rightDeclRefParent = right.declRef.getParent();
+ auto leftDeclRefParent = getParentDeclRef(left.declRef);
+ auto rightDeclRefParent = getParentDeclRef(right.declRef);
bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef.getDecl());
bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef.getDecl());
if(leftIsInterfaceRequirement != rightIsInterfaceRequirement)
return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement);
+ // Prefer non-extension declarations over extension declarations.
+ bool leftIsExtension = as<ExtensionDecl>(leftDeclRefParent.getDecl()) != nullptr;
+ bool rightIsExtension = as<ExtensionDecl>(rightDeclRefParent.getDecl()) != nullptr;
+ if (leftIsExtension != rightIsExtension)
+ {
+ return int(leftIsExtension) - int(rightIsExtension);
+ }
+ else if (leftIsExtension)
+ {
+ // If both are declared in extensions, prefer the one that is least generic.
+ bool leftIsGeneric = leftDeclRefParent.getParent().as<GenericDecl>() != nullptr;
+ bool rightIsGeneric = rightDeclRefParent.getParent().as<GenericDecl>() != nullptr;
+ if (leftIsGeneric != rightIsGeneric)
+ {
+ return int(leftIsGeneric) - int(rightIsGeneric);
+ }
+ }
+
// Any decl is strictly better than a module decl.
bool leftIsModule = (as<ModuleDeclarationDecl>(left.declRef) != nullptr);
bool rightIsModule = (as<ModuleDeclarationDecl>(right.declRef) != nullptr);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 5285b5c6e..81170fac3 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -554,6 +554,7 @@ DIAGNOSTIC(30832, Error, invalidTypeForInheritance, "type '$0' cannot be used fo
DIAGNOSTIC(30850, Error, invalidExtensionOnType, "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.")
DIAGNOSTIC(30851, Error, invalidMemberTypeInExtension, "$0 cannot be a part of an `extension`")
+DIAGNOSTIC(30852, Error, invalidExtensionOnInterface, "cannot extend interface type '$0'. consider using a generic extension: `extension<T:$0> T {...}`.")
// 309xx: subscripts
DIAGNOSTIC(30900, Error, multiDimensionalArrayNotSupported, "multi-dimensional array is not supported.")
diff --git a/tests/diagnostics/interfaces/interface-extension.slang b/tests/diagnostics/interfaces/interface-extension.slang
new file mode 100644
index 000000000..b63b454ab
--- /dev/null
+++ b/tests/diagnostics/interfaces/interface-extension.slang
@@ -0,0 +1,10 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main -disable-specialization
+
+interface IFoo{}
+
+
+// CHECK: ([[# @LINE+1]]): error 30852
+extension IFoo
+{
+ int f() { return 0; }
+} \ No newline at end of file
diff --git a/tests/language-feature/extensions/extension-override.slang b/tests/language-feature/extensions/extension-override.slang
new file mode 100644
index 000000000..30fa64965
--- /dev/null
+++ b/tests/language-feature/extensions/extension-override.slang
@@ -0,0 +1,65 @@
+// Test that the override behavior around extensions and generic extensions works as expected.
+
+// When there are multiple ways for a type to conform to an interface, then the expected behavior
+// is that:
+// 1. If the type directly implements an interface, use that conformance.
+// 2. Otherwise, if there is a direct extension on the type that makes it conform to the interface, use that
+// extension.
+// 3. Otherwise, if there is a generic extension that makes the type conform to the interface, use that.
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+interface IFoo
+{
+ int getVal();
+}
+
+interface IBar
+{
+ int getValPlusOne();
+}
+
+interface IBaz
+{
+ int getValPlusTwo();
+}
+
+struct MyInt
+{
+ int v;
+}
+
+extension MyInt : IFoo
+{
+ int getVal() { return v; }
+}
+
+extension MyInt : IBar
+{
+ int getValPlusOne() { return this.getVal() + 2; }
+}
+
+extension<T: IFoo> T : IBar
+{
+ int getValPlusOne() { return this.getVal() + 1; }
+}
+
+int helper1<T:IBar>(T v){ return v.getValPlusOne();}
+int helper2<T:IFoo>(T v){ return v.getValPlusOne();}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ MyInt v = {1};
+
+ // CHECK: 3
+ outputBuffer[0] = v.getValPlusOne(); // should call MyInt::ext::getValPlusOne();
+
+ // CHECK: 3
+ outputBuffer[1] = helper1(v); // should call MyInt::ext::getValPlusOne();
+
+ // CHECK: 2
+ outputBuffer[2] = helper2(v); // should call T::ext::getValPlusOne();
+} \ No newline at end of file