diff options
| -rw-r--r-- | source/slang/slang-ast-val.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-legalize-types.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lookup.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 213 | ||||
| -rw-r--r-- | tests/language-feature/extensions/interface-extension.slang | 50 | ||||
| -rw-r--r-- | tests/language-feature/extensions/interface-extension.slang.expected.txt | 4 |
6 files changed, 229 insertions, 43 deletions
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index bb5bed1bc..4926643d3 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -200,6 +200,7 @@ class TaggedUnionSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; + /// A witness of the fact that `ThisType(someInterface) : someInterface` class ThisTypeSubtypeWitness : public SubtypeWitness { SLANG_CLASS(ThisTypeSubtypeWitness) diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index c6cd0f387..a6100f45c 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -1175,7 +1175,7 @@ LegalType legalizeTypeImpl( else if( auto existentialPtrType = as<IRExistentialBoxType>(type)) { // We want to transform an `ExistentialBox<T>` into just - // a `T`, with an `iplicitDeref` to make sure that any + // a `T`, with an `implicitDeref` to make sure that any // pointer-related operations on the box Just Work. // // Note: the logic here doesn't have to deal with moving diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index 3aab22724..b54b09d63 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -634,7 +634,7 @@ static void _lookUpMembersInSuperTypeImpl( interfaceType, superIsInterfaceWitness); - _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, thisType->interfaceDeclRef, request, ioResult, inBreadcrumbs); + _lookUpMembersInSuperType(astBuilder, name, leafType, interfaceType, leafIsInterfaceWitness, request, ioResult, inBreadcrumbs); } } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 2f1511444..6361c135a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -404,6 +404,9 @@ struct IRGenContext // The IRType value to lower into for `ThisType`. IRInst* thisType = nullptr; + // The IR witness value to use for `ThisType` + IRInst* thisTypeWitness = nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared) , astBuilder(inAstBuilder) @@ -1416,6 +1419,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(irWitnessTable); } + LoweredValInfo visitThisTypeSubtypeWitness(ThisTypeSubtypeWitness* val) + { + SLANG_UNUSED(val); + return LoweredValInfo::simple(context->thisTypeWitness); + } + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { // TODO: it is a bit messy here that the `ConstantIntVal` representation @@ -2233,6 +2242,31 @@ DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl) return declRef.as<D>(); } +static Type* _findReplacementThisParamType( + IRGenContext* context, + DeclRef<Decl> parentDeclRef) +{ + if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() ) + { + auto targetType = getTargetType(context->astBuilder, extensionDeclRef); + if(auto targetDeclRefType = as<DeclRefType>(targetType)) + { + if(auto replacementType = _findReplacementThisParamType(context, targetDeclRefType->declRef)) + return replacementType; + } + return targetType; + } + + if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>()) + { + auto thisType = context->astBuilder->create<ThisType>(); + thisType->interfaceDeclRef = interfaceDeclRef; + return thisType; + } + + return nullptr; +} + /// Get the type of the `this` parameter introduced by `parentDeclRef`, or null. /// /// E.g., if `parentDeclRef` is a `struct` declaration, then this will @@ -2247,20 +2281,13 @@ Type* getThisParamTypeForContainer( IRGenContext* context, DeclRef<Decl> parentDeclRef) { - if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>()) - { - auto thisType = context->astBuilder->create<ThisType>(); - thisType->interfaceDeclRef = interfaceDeclRef; - return thisType; - } - else if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() ) + if(auto replacementType = _findReplacementThisParamType(context, parentDeclRef)) + return replacementType; + + if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() ) { return DeclRefType::create(context->astBuilder, aggTypeDeclRef); } - else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() ) - { - return getTargetType(context->astBuilder, extensionDeclRef); - } return nullptr; } @@ -5692,6 +5719,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subContextStorage.env = &subEnvStorage; subContextStorage.thisType = outerContext->thisType; + subContextStorage.thisTypeWitness = outerContext->thisTypeWitness; } IRBuilder* getBuilder() { return &subBuilderStorage; } @@ -5962,6 +5990,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto thisType = getBuilder()->getThisType(irInterface); subContext->thisType = thisType; + // TODO: Need to add an appropriate stand-in witness here. + subContext->thisTypeWitness = nullptr; + // Lower associated types first, so they can be referred to when lowering functions. for (auto assocTypeDecl : decl->getMembersOfType<AssocTypeDecl>()) { @@ -6303,6 +6334,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return irGeneric; } + IRGeneric* emitOuterInterfaceGeneric( + IRGenContext* subContext, + ContainerDecl* parentDecl, + DeclRefType* interfaceType, + Decl* leafDecl) + { + auto subBuilder = subContext->irBuilder; + + // Of course, a generic might itself be nested inside of other generics... + emitOuterGenerics(subContext, parentDecl, leafDecl); + + // We need to create an IR generic + + auto irGeneric = subBuilder->emitGeneric(); + subBuilder->setInsertInto(irGeneric); + + auto irBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(irBlock); + + // The generic needs two parameters: one to represent the + // `ThisType`, and one to represent a witness that the + // `ThisType` conforms to the interface itself. + // + auto irThisTypeParam = subBuilder->emitParam(subBuilder->getTypeType()); + + auto irInterfaceType = lowerType(context, interfaceType); + auto irWitnessTableParam = subBuilder->emitParam(subBuilder->getWitnessTableType(irInterfaceType)); + subBuilder->addTypeConstraintDecoration(irThisTypeParam, irInterfaceType); + + // Now we need to wire up the IR parameters + // we created to be used as the `ThisType` in + // the body of the code. + // + subContext->thisType = irThisTypeParam; + subContext->thisTypeWitness = irWitnessTableParam; + + return irGeneric; + } + // If the given `decl` is enclosed in any generic declarations, then // emit IR-level generics to represent them. // The `leafDecl` represents the inner-most declaration we are actually @@ -6316,6 +6386,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { return emitOuterGeneric(subContext, genericAncestor, leafDecl); } + + // We introduce IR generics in one other case, where the input + // code wasn't visibly using generics: when a concrete member + // is defined on an interface type. In that case, the resulting + // definition needs to be generic on a parameter to represent + // the `ThisType` of the interface. + // + if(auto extensionAncestor = as<ExtensionDecl>(pp)) + { + if(auto targetDeclRefType = as<DeclRefType>(extensionAncestor->targetType)) + { + if(auto interfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>()) + { + return emitOuterInterfaceGeneric(subContext, extensionAncestor, targetDeclRefType, leafDecl); + } + } + } } return nullptr; @@ -7112,6 +7199,20 @@ bool canDeclLowerToAGeneric(Decl* decl) return false; } +static bool isInterfaceRequirement(Decl* decl) +{ + auto ancestor = decl->parentDecl; + for(; ancestor; ancestor = ancestor->parentDecl ) + { + if(as<InterfaceDecl>(ancestor)) + return true; + + if(as<ExtensionDecl>(ancestor)) + return false; + } + return false; +} + LoweredValInfo emitDeclRef( IRGenContext* context, Decl* decl, @@ -7204,36 +7305,66 @@ LoweredValInfo emitDeclRef( return lowerType(context, thisTypeSubst->witness->sub); } - // Somebody is trying to look up an interface requirement - // "through" some concrete type. We need to lower this decl-ref - // as a lookup of the corresponding member in a witness table. - // - // The witness table itself is referenced by the this-type - // substitution, so we can just lower that. - // - // Note: unlike the case for generics above, in the interface-lookup - // case, we don't end up caring about any further outer substitutions. - // That is because even if we are naming `ISomething<Foo>.doIt()`, - // a method inside a generic interface, we don't actually care - // about the substitution of `Foo` for the parameter `T` of - // `ISomething<T>`. That is because we really care about the - // witness table for the concrete type that conforms to `ISomething<Foo>`. - // - auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness); - // - // The key to use for looking up the interface member is - // derived from the declaration. - // - auto irRequirementKey = getInterfaceRequirementKey(context, decl); - // - // Those two pieces of information tell us what we need to - // do in order to look up the value that satisfied the requirement. - // - auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst( - type, - irWitnessTable, - irRequirementKey); - return LoweredValInfo::simple(irSatisfyingVal); + if(isInterfaceRequirement(decl)) + { + // Somebody is trying to look up an interface requirement + // "through" some concrete type. We need to lower this decl-ref + // as a lookup of the corresponding member in a witness table. + // + // The witness table itself is referenced by the this-type + // substitution, so we can just lower that. + // + // Note: unlike the case for generics above, in the interface-lookup + // case, we don't end up caring about any further outer substitutions. + // That is because even if we are naming `ISomething<Foo>.doIt()`, + // a method inside a generic interface, we don't actually care + // about the substitution of `Foo` for the parameter `T` of + // `ISomething<T>`. That is because we really care about the + // witness table for the concrete type that conforms to `ISomething<Foo>`. + // + auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness); + // + // The key to use for looking up the interface member is + // derived from the declaration. + // + auto irRequirementKey = getInterfaceRequirementKey(context, decl); + // + // Those two pieces of information tell us what we need to + // do in order to look up the value that satisfied the requirement. + // + auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst( + type, + irWitnessTable, + irRequirementKey); + return LoweredValInfo::simple(irSatisfyingVal); + } + else + { + // This case is a reference to a member declaration of the interface + // (or added by an extension of the interface) that does *not* + // represent a requirement of the interface. + // + // Our policy is that concrete methods/members on an interface type + // are lowered as generics, where the generic parameter represents + // the `ThisType`. + // + auto genericVal = emitDeclRef(context, decl, thisTypeSubst->outer, context->irBuilder->getGenericKind()); + auto irGenericVal = getSimpleVal(context, genericVal); + + // In order to reference the member for a particular type, we + // specialize the generic for that type. + // + IRInst* irSubType = lowerType(context, thisTypeSubst->witness->sub); + IRInst* irSubTypeWitness = lowerSimpleVal(context, thisTypeSubst->witness); + + IRInst* irSpecializeArgs[] = { irSubType, irSubTypeWitness }; + auto irSpecializedVal = context->irBuilder->emitSpecializeInst( + type, + irGenericVal, + 2, + irSpecializeArgs); + return LoweredValInfo::simple(irSpecializedVal); + } } else { diff --git a/tests/language-feature/extensions/interface-extension.slang b/tests/language-feature/extensions/interface-extension.slang new file mode 100644 index 000000000..824aa3450 --- /dev/null +++ b/tests/language-feature/extensions/interface-extension.slang @@ -0,0 +1,50 @@ +// interface-extension.slang + +// Test that an `extension` applied to an interface type works as users expect + +//TEST(compute):COMPARE_COMPUTE: + +interface ICounter +{ + [mutating] void add(int value); +} + +struct MyCounter : ICounter +{ + int _state = 0; + + [mutating] void add(int value) { _state += value; } +} + +extension ICounter +{ + [mutating] void increment() + { + this.add(1); + } +} + +void helper<T : ICounter>(in out T counter) +{ + counter.increment(); +} + +int test(int value) +{ + MyCounter counter = { value }; + counter.increment(); + helper(counter); + return counter._state; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/language-feature/extensions/interface-extension.slang.expected.txt b/tests/language-feature/extensions/interface-extension.slang.expected.txt new file mode 100644 index 000000000..f8affbc14 --- /dev/null +++ b/tests/language-feature/extensions/interface-extension.slang.expected.txt @@ -0,0 +1,4 @@ +2 +3 +4 +5 |
