summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-val.h1
-rw-r--r--source/slang/slang-legalize-types.cpp2
-rw-r--r--source/slang/slang-lookup.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp213
-rw-r--r--tests/language-feature/extensions/interface-extension.slang50
-rw-r--r--tests/language-feature/extensions/interface-extension.slang.expected.txt4
6 files changed, 229 insertions, 43 deletions
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index bb5bed1bc..4926643d3 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -200,6 +200,7 @@ class TaggedUnionSubtypeWitness : public SubtypeWitness
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
+ /// A witness of the fact that `ThisType(someInterface) : someInterface`
class ThisTypeSubtypeWitness : public SubtypeWitness
{
SLANG_CLASS(ThisTypeSubtypeWitness)
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index c6cd0f387..a6100f45c 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -1175,7 +1175,7 @@ LegalType legalizeTypeImpl(
else if( auto existentialPtrType = as<IRExistentialBoxType>(type))
{
// We want to transform an `ExistentialBox<T>` into just
- // a `T`, with an `iplicitDeref` to make sure that any
+ // a `T`, with an `implicitDeref` to make sure that any
// pointer-related operations on the box Just Work.
//
// Note: the logic here doesn't have to deal with moving
diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp
index 3aab22724..b54b09d63 100644
--- a/source/slang/slang-lookup.cpp
+++ b/source/slang/slang-lookup.cpp
@@ -634,7 +634,7 @@ static void _lookUpMembersInSuperTypeImpl(
interfaceType,
superIsInterfaceWitness);
- _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, thisType->interfaceDeclRef, request, ioResult, inBreadcrumbs);
+ _lookUpMembersInSuperType(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, request, ioResult, inBreadcrumbs);
}
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 2f1511444..6361c135a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -404,6 +404,9 @@ struct IRGenContext
// The IRType value to lower into for `ThisType`.
IRInst* thisType = nullptr;
+ // The IR witness value to use for `ThisType`
+ IRInst* thisTypeWitness = nullptr;
+
explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder)
: shared(inShared)
, astBuilder(inAstBuilder)
@@ -1416,6 +1419,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(irWitnessTable);
}
+ LoweredValInfo visitThisTypeSubtypeWitness(ThisTypeSubtypeWitness* val)
+ {
+ SLANG_UNUSED(val);
+ return LoweredValInfo::simple(context->thisTypeWitness);
+ }
+
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
{
// TODO: it is a bit messy here that the `ConstantIntVal` representation
@@ -2233,6 +2242,31 @@ DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl)
return declRef.as<D>();
}
+static Type* _findReplacementThisParamType(
+ IRGenContext* context,
+ DeclRef<Decl> parentDeclRef)
+{
+ if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() )
+ {
+ auto targetType = getTargetType(context->astBuilder, extensionDeclRef);
+ if(auto targetDeclRefType = as<DeclRefType>(targetType))
+ {
+ if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->declRef))
+ return replacementType;
+ }
+ return targetType;
+ }
+
+ if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>())
+ {
+ auto thisType = context->astBuilder->create<ThisType>();
+ thisType->interfaceDeclRef = interfaceDeclRef;
+ return thisType;
+ }
+
+ return nullptr;
+}
+
/// Get the type of the `this` parameter introduced by `parentDeclRef`, or null.
///
/// E.g., if `parentDeclRef` is a `struct` declaration, then this will
@@ -2247,20 +2281,13 @@ Type* getThisParamTypeForContainer(
IRGenContext* context,
DeclRef<Decl> parentDeclRef)
{
- if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>())
- {
- auto thisType = context->astBuilder->create<ThisType>();
- thisType->interfaceDeclRef = interfaceDeclRef;
- return thisType;
- }
- else if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() )
+ if(auto replacementType = _findReplacementThisParamType(context, parentDeclRef))
+ return replacementType;
+
+ if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() )
{
return DeclRefType::create(context->astBuilder, aggTypeDeclRef);
}
- else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() )
- {
- return getTargetType(context->astBuilder, extensionDeclRef);
- }
return nullptr;
}
@@ -5692,6 +5719,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContextStorage.env = &subEnvStorage;
subContextStorage.thisType = outerContext->thisType;
+ subContextStorage.thisTypeWitness = outerContext->thisTypeWitness;
}
IRBuilder* getBuilder() { return &subBuilderStorage; }
@@ -5962,6 +5990,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto thisType = getBuilder()->getThisType(irInterface);
subContext->thisType = thisType;
+ // TODO: Need to add an appropriate stand-in witness here.
+ subContext->thisTypeWitness = nullptr;
+
// Lower associated types first, so they can be referred to when lowering functions.
for (auto assocTypeDecl : decl->getMembersOfType<AssocTypeDecl>())
{
@@ -6303,6 +6334,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return irGeneric;
}
+ IRGeneric* emitOuterInterfaceGeneric(
+ IRGenContext* subContext,
+ ContainerDecl* parentDecl,
+ DeclRefType* interfaceType,
+ Decl* leafDecl)
+ {
+ auto subBuilder = subContext->irBuilder;
+
+ // Of course, a generic might itself be nested inside of other generics...
+ emitOuterGenerics(subContext, parentDecl, leafDecl);
+
+ // We need to create an IR generic
+
+ auto irGeneric = subBuilder->emitGeneric();
+ subBuilder->setInsertInto(irGeneric);
+
+ auto irBlock = subBuilder->emitBlock();
+ subBuilder->setInsertInto(irBlock);
+
+ // The generic needs two parameters: one to represent the
+ // `ThisType`, and one to represent a witness that the
+ // `ThisType` conforms to the interface itself.
+ //
+ auto irThisTypeParam = subBuilder->emitParam(subBuilder->getTypeType());
+
+ auto irInterfaceType = lowerType(context, interfaceType);
+ auto irWitnessTableParam = subBuilder->emitParam(subBuilder->getWitnessTableType(irInterfaceType));
+ subBuilder->addTypeConstraintDecoration(irThisTypeParam, irInterfaceType);
+
+ // Now we need to wire up the IR parameters
+ // we created to be used as the `ThisType` in
+ // the body of the code.
+ //
+ subContext->thisType = irThisTypeParam;
+ subContext->thisTypeWitness = irWitnessTableParam;
+
+ return irGeneric;
+ }
+
// If the given `decl` is enclosed in any generic declarations, then
// emit IR-level generics to represent them.
// The `leafDecl` represents the inner-most declaration we are actually
@@ -6316,6 +6386,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
return emitOuterGeneric(subContext, genericAncestor, leafDecl);
}
+
+ // We introduce IR generics in one other case, where the input
+ // code wasn't visibly using generics: when a concrete member
+ // is defined on an interface type. In that case, the resulting
+ // definition needs to be generic on a parameter to represent
+ // the `ThisType` of the interface.
+ //
+ if(auto extensionAncestor = as<ExtensionDecl>(pp))
+ {
+ if(auto targetDeclRefType = as<DeclRefType>(extensionAncestor->targetType))
+ {
+ if(auto interfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>())
+ {
+ return emitOuterInterfaceGeneric(subContext, extensionAncestor, targetDeclRefType, leafDecl);
+ }
+ }
+ }
}
return nullptr;
@@ -7112,6 +7199,20 @@ bool canDeclLowerToAGeneric(Decl* decl)
return false;
}
+static bool isInterfaceRequirement(Decl* decl)
+{
+ auto ancestor = decl->parentDecl;
+ for(; ancestor; ancestor = ancestor->parentDecl )
+ {
+ if(as<InterfaceDecl>(ancestor))
+ return true;
+
+ if(as<ExtensionDecl>(ancestor))
+ return false;
+ }
+ return false;
+}
+
LoweredValInfo emitDeclRef(
IRGenContext* context,
Decl* decl,
@@ -7204,36 +7305,66 @@ LoweredValInfo emitDeclRef(
return lowerType(context, thisTypeSubst->witness->sub);
}
- // Somebody is trying to look up an interface requirement
- // "through" some concrete type. We need to lower this decl-ref
- // as a lookup of the corresponding member in a witness table.
- //
- // The witness table itself is referenced by the this-type
- // substitution, so we can just lower that.
- //
- // Note: unlike the case for generics above, in the interface-lookup
- // case, we don't end up caring about any further outer substitutions.
- // That is because even if we are naming `ISomething<Foo>.doIt()`,
- // a method inside a generic interface, we don't actually care
- // about the substitution of `Foo` for the parameter `T` of
- // `ISomething<T>`. That is because we really care about the
- // witness table for the concrete type that conforms to `ISomething<Foo>`.
- //
- auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness);
- //
- // The key to use for looking up the interface member is
- // derived from the declaration.
- //
- auto irRequirementKey = getInterfaceRequirementKey(context, decl);
- //
- // Those two pieces of information tell us what we need to
- // do in order to look up the value that satisfied the requirement.
- //
- auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst(
- type,
- irWitnessTable,
- irRequirementKey);
- return LoweredValInfo::simple(irSatisfyingVal);
+ if(isInterfaceRequirement(decl))
+ {
+ // Somebody is trying to look up an interface requirement
+ // "through" some concrete type. We need to lower this decl-ref
+ // as a lookup of the corresponding member in a witness table.
+ //
+ // The witness table itself is referenced by the this-type
+ // substitution, so we can just lower that.
+ //
+ // Note: unlike the case for generics above, in the interface-lookup
+ // case, we don't end up caring about any further outer substitutions.
+ // That is because even if we are naming `ISomething<Foo>.doIt()`,
+ // a method inside a generic interface, we don't actually care
+ // about the substitution of `Foo` for the parameter `T` of
+ // `ISomething<T>`. That is because we really care about the
+ // witness table for the concrete type that conforms to `ISomething<Foo>`.
+ //
+ auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness);
+ //
+ // The key to use for looking up the interface member is
+ // derived from the declaration.
+ //
+ auto irRequirementKey = getInterfaceRequirementKey(context, decl);
+ //
+ // Those two pieces of information tell us what we need to
+ // do in order to look up the value that satisfied the requirement.
+ //
+ auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst(
+ type,
+ irWitnessTable,
+ irRequirementKey);
+ return LoweredValInfo::simple(irSatisfyingVal);
+ }
+ else
+ {
+ // This case is a reference to a member declaration of the interface
+ // (or added by an extension of the interface) that does *not*
+ // represent a requirement of the interface.
+ //
+ // Our policy is that concrete methods/members on an interface type
+ // are lowered as generics, where the generic parameter represents
+ // the `ThisType`.
+ //
+ auto genericVal = emitDeclRef(context, decl, thisTypeSubst->outer, context->irBuilder->getGenericKind());
+ auto irGenericVal = getSimpleVal(context, genericVal);
+
+ // In order to reference the member for a particular type, we
+ // specialize the generic for that type.
+ //
+ IRInst* irSubType = lowerType(context, thisTypeSubst->witness->sub);
+ IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->witness);
+
+ IRInst* irSpecializeArgs[] = { irSubType, irSubTypeWitness };
+ auto irSpecializedVal = context->irBuilder->emitSpecializeInst(
+ type,
+ irGenericVal,
+ 2,
+ irSpecializeArgs);
+ return LoweredValInfo::simple(irSpecializedVal);
+ }
}
else
{
diff --git a/tests/language-feature/extensions/interface-extension.slang b/tests/language-feature/extensions/interface-extension.slang
new file mode 100644
index 000000000..824aa3450
--- /dev/null
+++ b/tests/language-feature/extensions/interface-extension.slang
@@ -0,0 +1,50 @@
+// interface-extension.slang
+
+// Test that an `extension` applied to an interface type works as users expect
+
+//TEST(compute):COMPARE_COMPUTE:
+
+interface ICounter
+{
+ [mutating] void add(int value);
+}
+
+struct MyCounter : ICounter
+{
+ int _state = 0;
+
+ [mutating] void add(int value) { _state += value; }
+}
+
+extension ICounter
+{
+ [mutating] void increment()
+ {
+ this.add(1);
+ }
+}
+
+void helper<T : ICounter>(in out T counter)
+{
+ counter.increment();
+}
+
+int test(int value)
+{
+ MyCounter counter = { value };
+ counter.increment();
+ helper(counter);
+ return counter._state;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/language-feature/extensions/interface-extension.slang.expected.txt b/tests/language-feature/extensions/interface-extension.slang.expected.txt
new file mode 100644
index 000000000..f8affbc14
--- /dev/null
+++ b/tests/language-feature/extensions/interface-extension.slang.expected.txt
@@ -0,0 +1,4 @@
+2
+3
+4
+5