diff options
| author | Yong He <yonghe@outlook.com> | 2023-10-25 07:45:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-25 22:45:23 +0800 |
| commit | f8bf75cf1ae0aeee155996a917c2925bc500f3e2 (patch) | |
| tree | 07b418cfdc3fe106c492162624cfdaeb7a453be9 | |
| parent | d8f4c9424c69a3d406fabf56a25dd3eda4bc7d51 (diff) | |
Support generic interfaces. (#3278)
* Initial support for generic interfaces.
* Cleanup.
* Add generic syntax for interfaces.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ast-base.cpp | 55 | ||||
| -rw-r--r-- | source/slang/slang-ast-base.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl-ref.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 143 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 2 | ||||
| -rw-r--r-- | tests/language-feature/generics/generic-interface-1.slang | 37 |
14 files changed, 204 insertions, 93 deletions
diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp index 60be7a563..d4904d2de 100644 --- a/source/slang/slang-ast-base.cpp +++ b/source/slang/slang-ast-base.cpp @@ -3,34 +3,43 @@ namespace Slang { -void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) -{ + void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) + { #ifdef _DEBUG - SLANG_UNUSED(inAstNodeType); - static int32_t uidCounter = 0; - static int32_t breakValue = 0; - uidCounter++; - _debugUID = uidCounter; - if (inAstBuilder->getId() == -1) - _debugUID = -_debugUID; - if (breakValue != 0 && _debugUID == breakValue) - SLANG_BREAKPOINT(0) + SLANG_UNUSED(inAstNodeType); + static int32_t uidCounter = 0; + static int32_t breakValue = 0; + uidCounter++; + _debugUID = uidCounter; + if (inAstBuilder->getId() == -1) + _debugUID = -_debugUID; + if (breakValue != 0 && _debugUID == breakValue) + SLANG_BREAKPOINT(0) #else - SLANG_UNUSED(inAstNodeType); - SLANG_UNUSED(inAstBuilder); + SLANG_UNUSED(inAstNodeType); + SLANG_UNUSED(inAstBuilder); #endif -} -DeclRefBase* Decl::getDefaultDeclRef() -{ - if (auto astBuilder = getCurrentASTBuilder()) + } + DeclRefBase* Decl::getDefaultDeclRef() { - const Index currentEpoch = astBuilder->getEpoch(); - if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef) + if (auto astBuilder = getCurrentASTBuilder()) { - m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this); - m_defaultDeclRefEpoch = currentEpoch; + const Index currentEpoch = astBuilder->getEpoch(); + if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef) + { + m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this); + m_defaultDeclRefEpoch = currentEpoch; + } } + return m_defaultDeclRef; } - return m_defaultDeclRef; -} + + bool Decl::isChildOf(Decl* other) const + { + for (auto parent = parentDecl; parent; parent = parent->parentDecl) + if (parent == other) + return true; + return false; + } + } diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 0170ca493..579bda73a 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -734,6 +734,7 @@ public: SLANG_RELEASE_ASSERT(state >= checkState.getState()); checkState.setState(state); } + bool isChildOf(Decl* other) const; private: SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr; diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index aff3088ab..1c6637c31 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -294,7 +294,6 @@ public: interfaceDecl->addMember(thisDecl); auto thisConstraint = create<ThisTypeConstraintDecl>(); thisConstraint->loc = loc; - thisConstraint->base.type = DeclRefType::create(this, getDirectDeclRef(interfaceDecl)); thisDecl->addMember(thisConstraint); return interfaceDecl; } diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index c77cf72ed..c9511e4e7 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -150,6 +150,11 @@ Val* LookupDeclRef::_resolveImplOverride() DeclRefBase* LookupDeclRef::_getBaseOverride() { + auto supType = getWitness()->getSup(); + if (auto declRefType = as<DeclRefType>(supType)) + { + return declRefType->getDeclRef(); + } return nullptr; } @@ -432,10 +437,13 @@ DeclRef<Decl> createDefaultSubstitutionsIfNeeded( ShortList<GenericDecl*> genericParentDecls; auto lastSubstNode = SubstitutionSet(declRef).getInnerMostNodeWithSubstInfo(); auto lastGenApp = as<GenericAppDeclRef>(lastSubstNode); + auto lastLookup = as<LookupDeclRef>(lastSubstNode); for (auto dd = declRef.getDecl()->parentDecl; dd; dd = dd->parentDecl) { if (lastGenApp && dd == lastGenApp->getGenericDecl()) break; + if (lastLookup && lastLookup->getDecl()->isChildOf(dd)) + break; if (auto gen = as<GenericDecl>(dd)) genericParentDecls.add(gen); } @@ -454,6 +462,8 @@ DeclRef<Decl> createDefaultSubstitutionsIfNeeded( } parentDeclRef = astBuilder->getGenericAppDeclRef(parentDeclRef.as<GenericDecl>(), args.getArrayView()); } + if (!parentDeclRef) + return declRef; if (parentDeclRef.getDecl() == declRef.getDecl()) return parentDeclRef; return astBuilder->getMemberDeclRef(parentDeclRef, declRef.getDecl()); diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index a29ff9bb3..840aa4a67 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -636,9 +636,9 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -InterfaceDecl* ThisType::getInterfaceDecl() +DeclRef<InterfaceDecl> ThisType::getInterfaceDeclRef() { - return dynamicCast<InterfaceDecl>(getDeclRefBase()->getDecl()->parentDecl); + return DeclRef<Decl>(getDeclRefBase()->getParent()).template as<InterfaceDecl>(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 3c50b1899..638012652 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -55,7 +55,7 @@ class DeclRefType : public Type { SLANG_AST_CLASS(DeclRefType) - static DeclRefType* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); + static Type* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); DeclRef<Decl> getDeclRef() const { return DeclRef<Decl>(as<DeclRefBase>(getOperand(0))); } DeclRefBase* getDeclRefBase() const { return as<DeclRefBase>(getOperand(0)); } @@ -786,7 +786,7 @@ class ThisType : public DeclRefType ThisType(DeclRefBase* declRef) : DeclRefType(declRef) {} - InterfaceDecl* getInterfaceDecl(); + DeclRef<InterfaceDecl> getInterfaceDeclRef(); }; /// The type of `A & B` where `A` and `B` are types diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1e3c6a361..8df5ae618 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -169,6 +169,8 @@ namespace Slang void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); + void visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl); + /// Validate that `decl` isn't illegally inheriting from a type in another module. /// /// This call checks a single `inheritanceDecl` to make sure that it either @@ -1600,6 +1602,22 @@ namespace Slang // based on the declaration that is doing the inheriting. } + void SemanticsDeclBasesVisitor::visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl) + { + // Make sure IFoo<T>.This.ThisIsIFooConstraint.base.type is properly set + // to DeclRefType(IFoo<T>) with default generic arguments. + if (!thisTypeConstraintDecl->base.type) + { + auto parentTypeDecl = getParentDecl(getParentDecl(thisTypeConstraintDecl)); + thisTypeConstraintDecl->base.type = DeclRefType::create( + m_astBuilder, + createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + getDefaultDeclRef(parentTypeDecl))); + } + } + // Concretize interface conformances so that we have witnesses as required for lookup. // for lookup. struct SemanticsDeclConformancesVisitor diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index c04450b82..be6228f2d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3226,7 +3226,7 @@ public: IRType* getCapabilitySetType(); IRAssociatedType* getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes); - IRThisType* getThisType(IRInterfaceType* interfaceType); + IRThisType* getThisType(IRType* interfaceType); IRRawPointerType* getRawPointerType(); IRRTTIPointerType* getRTTIPointerType(IRInst* rttiPtr); IRRTTIType* getRTTIType(); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index cf58e6cd4..2f603ac17 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2653,7 +2653,7 @@ namespace Slang (IRInst**)constraintTypes.getBuffer()); } - IRThisType* IRBuilder::getThisType(IRInterfaceType* interfaceType) + IRThisType* IRBuilder::getThisType(IRType* interfaceType) { return (IRThisType*)getType(kIROp_ThisType, interfaceType); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 2813918b6..d75b66a9b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -535,6 +535,22 @@ struct SharedIRGenContext List<IRInst*> m_stringLiterals; }; +struct IRGenContext; + +struct AstOrIRType +{ + Type* astType = nullptr; + IRInst* irType = nullptr; + IRInst* getIRType(IRGenContext* context); + + AstOrIRType& operator=(Type* t) { astType = t; irType = nullptr; return *this; } + AstOrIRType& operator=(IRInst* t) { astType = nullptr; irType = t; return *this; } + explicit operator bool() + { + return astType || irType; + } +}; + struct IRGenContext { ASTBuilder* astBuilder; @@ -558,7 +574,7 @@ struct IRGenContext LoweredValInfo thisVal; // The IRType value to lower into for `ThisType`. - IRInst* thisType = nullptr; + AstOrIRType thisType; // The IR witness value to use for `ThisType` IRInst* thisTypeWitness = nullptr; @@ -824,6 +840,14 @@ static IRType* lowerType( return lowerType(context, type.type); } +IRInst* AstOrIRType::getIRType(IRGenContext* context) +{ + if (irType) + return irType; + irType = lowerType(context, astType); + return irType; +} + // Given a `DeclRef` for something callable, along with a bunch of // arguments, emit an appropriate call to it. LoweredValInfo emitCallToDeclRef( @@ -1984,9 +2008,17 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // Therefore, `context->thisType` should have been set to `IRThisType` // in `visitInterfaceDecl`, and we can just use that value here. // - if (context->thisType != nullptr) - return LoweredValInfo::simple(context->thisType); - return emitDeclRef(context, makeDeclRef(type->getInterfaceDecl()), getBuilder()->getTypeKind()); + if (context->thisType.irType) + { + return LoweredValInfo::simple(context->thisType.irType); + } + auto interfaceType = emitDeclRef(context, type->getInterfaceDeclRef(), getBuilder()->getTypeKind()); + auto result = LoweredValInfo::simple(getBuilder()->getThisType((IRType*)getSimpleVal(context, interfaceType))); + if (context->thisType.astType == type) + { + context->thisType = getSimpleVal(context, result); + } + return result; } LoweredValInfo visitAndType(AndType* type) @@ -2668,7 +2700,9 @@ static Type* _findReplacementThisParamType( if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>()) { - auto thisType = DeclRefType::create(context->astBuilder, interfaceDeclRef.getDecl()->getThisTypeDecl()); + auto thisType = DeclRefType::create( + context->astBuilder, + context->astBuilder->getMemberDeclRef(interfaceDeclRef, interfaceDeclRef.getDecl()->getThisTypeDecl())); return thisType; } @@ -2704,6 +2738,11 @@ Type* getThisParamTypeForCallable( IRGenContext* context, DeclRef<Decl> callableDeclRef) { + if (auto lookup = as<LookupDeclRef>((callableDeclRef.declRefBase))) + { + return lookup->getLookupSource(); + } + auto parentDeclRef = callableDeclRef.getParent(); if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>()) @@ -7751,13 +7790,19 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Allocate an IRInterfaceType with the `operandCount` operands. IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); + auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - context->setValue(decl, LoweredValInfo::simple(irInterface)); + context->setValue(decl, LoweredValInfo::simple(finalVal)); + + subBuilder->setInsertBefore(irInterface); // Setup subContext for proper lowering `ThisType`, associated types and // the interface decl's self reference. - auto thisType = getBuilder()->getThisType(irInterface); + + auto thisType = DeclRefType::create( + context->astBuilder, + createDefaultSpecializedDeclRef(subContext, nullptr, decl->getThisTypeDecl())); subContext->thisType = thisType; // TODO: Need to add an appropriate stand-in witness here. @@ -7880,14 +7925,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } subBuilder->setInsertInto(irInterface); - // TODO: are there any interface members that should be - // nested inside the interface type itself? - - irInterface->moveToEnd(); addTargetIntrinsicDecorations(subContext, irInterface, decl); - auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric); return LoweredValInfo::simple(finalVal); } @@ -7939,8 +7979,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo visitThisTypeDecl(ThisTypeDecl* decl) { - auto interfaceType = ensureDecl(context, decl->parentDecl).val; - return LoweredValInfo::simple(context->irBuilder->getThisType(as<IRInterfaceType>(interfaceType))); + SLANG_UNUSED(decl); + return LoweredValInfo(); } LoweredValInfo visitThisTypeConstraintDecl(ThisTypeConstraintDecl* decl) @@ -7968,14 +8008,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> const bool isPublicType = decl->findModifier<PublicModifier>() != nullptr; - // Given a declaration of a type, we need to make sure - // to output "witness tables" for any interfaces this - // type has declared conformance to. - for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() ) - { - ensureDecl(context, inheritanceDecl); - } - // We are going to create nested IR building state // to use when emitting the members of the type. // @@ -8001,11 +8033,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(subBuilder->getVoidType()); } - const auto finishedVal = _getFinishOuterGenericsReturnValue(irAggType, outerGeneric); + auto finalFinishedVal = finishOuterGenerics(subBuilder, irAggType, outerGeneric); // We add the decl now such that if there are Ptr or other references // to this type they can still complete - context->setValue(decl, LoweredValInfo::simple(finishedVal)); + context->setValue(decl, LoweredValInfo::simple(finalFinishedVal)); + + subBuilder->setInsertBefore(irAggType); + + // Given a declaration of a type, we need to make sure + // to output "witness tables" for any interfaces this + // type has declared conformance to. + for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) + { + ensureDecl(subContext, inheritanceDecl); + } addNameHint(context, irAggType, decl); addLinkageDecoration(context, irAggType, decl); @@ -8022,8 +8064,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() ) { - if (isPublicType) - ensureDecl(context, inheritanceDecl); auto superType = inheritanceDecl->base; if(auto superDeclRefType = as<DeclRefType>(superType)) { @@ -8031,7 +8071,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> superDeclRefType->getDeclRef().as<ClassDecl>()) { auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl)); - auto irSuperType = lowerType(context, superType.type); + auto irSuperType = lowerType(subContext, superType.type); subBuilder->createStructField( irAggType, superKey, @@ -8053,8 +8093,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Each ordinary field will need to turn into a struct "key" // that is used for fetching the field. - IRInst* fieldKeyInst = getSimpleVal(context, - ensureDecl(context, fieldDecl)); + IRInst* fieldKeyInst = getSimpleVal(subContext, + ensureDecl(subContext, fieldDecl)); auto fieldKey = as<IRStructKey>(fieldKeyInst); SLANG_ASSERT(fieldKey); @@ -8085,7 +8125,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Instead we will force emission of all children of aggregate // type declarations later, from the top-level emit logic. - irAggType->moveToEnd(); addTargetIntrinsicDecorations(subContext, irAggType, decl); for (auto modifier : decl->modifiers) { @@ -8093,9 +8132,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subBuilder->addNonCopyableTypeDecoration(irAggType); } - auto finalFinishedVal = finishOuterGenerics(subBuilder, irAggType, outerGeneric); - // Confirm that _getFinishOuterGenericsReturnValue above returned the same result - SLANG_ASSERT(finalFinishedVal == finishedVal); return LoweredValInfo::simple(finalFinishedVal); } @@ -8611,27 +8647,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return v; } - // This function matches the return value from finishOuterGenerics - // so that we can create the target value without finishOuterGenerics having to be called. - IRInst* _getFinishOuterGenericsReturnValue( - IRInst* val, - IRGeneric* parentGeneric) - { - IRInst* v = val; - while (parentGeneric) - { - // There might be more outer generics, - // so we need to loop until we run out. - v = parentGeneric; - auto parentBlock = as<IRBlock>(v->getParent()); - if (!parentBlock) break; - - parentGeneric = as<IRGeneric>(parentBlock->getParent()); - if (!parentGeneric) break; - - } - return v; - } void addSpecializedForTargetDecorations(IRInst* inst, Decl* decl) { @@ -9700,6 +9715,26 @@ LoweredValInfo emitDeclRef( const auto initialSubst = subst; SLANG_UNUSED(initialSubst); + + if (auto thisTypeDecl = as<ThisTypeDecl>(decl)) + { + // A declref to ThisType decl should be lowered differently + // from other decls. In general, IFoo<T>.ThisType should lower to + // ThisType(specialize(IFoo,T)) instead of specialize(IFoo.ThisType, T). + SLANG_ASSERT(subst->getDecl() == decl); + IRType* parentInterfaceType = nullptr; + if (auto lookupDeclRef = as<LookupDeclRef>(subst)) + { + parentInterfaceType = lowerType(context, lookupDeclRef->getWitness()->getSup()); + } + else + { + parentInterfaceType = lowerType(context, DeclRefType::create(context->astBuilder, subst->getParent())); + } + auto thisType = context->irBuilder->getThisType(parentInterfaceType); + return LoweredValInfo::simple(thisType); + } + // We need to proceed by considering the specializations that // have been put in place. subst = SubstitutionSet(subst).getInnerMostNodeWithSubstInfo(); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 59110ea05..3111ab132 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -223,7 +223,7 @@ namespace Slang else if( auto thisType = dynamicCast<ThisType>(type) ) { emitRaw(context, "t"); - emitQualifiedName(context, thisType->getInterfaceDecl()); + emitQualifiedName(context, thisType->getInterfaceDeclRef()); } else if (const auto errorType = dynamicCast<ErrorType>(type)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 696575f8b..59aff4dc0 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3061,7 +3061,7 @@ namespace Slang parser->FillPosition(paramConstraint); // substitution needs to be filled during check - DeclRefType* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl)); + Type* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl)); SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); paramTypeExpr->loc = decl->loc; @@ -3128,12 +3128,14 @@ namespace Slang AdvanceIf(parser, TokenType::CompletionRequest); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - parseOptionalInheritanceClause(parser, decl); - - parseDeclBody(parser, decl); - - return decl; + return parseOptGenericDecl(parser, [&](GenericDecl*) + { + // We allow for an inheritance clause on a `struct` + // so that it can conform to interfaces. + parseOptionalInheritanceClause(parser, decl); + parseDeclBody(parser, decl); + return decl; + }); } static NodeBase* parseNamespaceDecl(Parser* parser, void* /*userData*/) diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 8ed50510f..d24fd239d 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -407,7 +407,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // TODO: need to figure out how to unify this with the logic // in the generic case... - DeclRefType* DeclRefType::create( + Type* DeclRefType::create( ASTBuilder* astBuilder, DeclRef<Decl> declRef) { diff --git a/tests/language-feature/generics/generic-interface-1.slang b/tests/language-feature/generics/generic-interface-1.slang new file mode 100644 index 000000000..217e7f06f --- /dev/null +++ b/tests/language-feature/generics/generic-interface-1.slang @@ -0,0 +1,37 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +interface IEqlTestable<T> +{ + bool testEql(T v1); +} + +bool test<T>(IEqlTestable<T> v0, T v1) +{ + return v0.testEql(v1); +} + +struct MyType : IEqlTestable<MyType> +{ + int val; + bool testEql(MyType v1) + { + return val == v1.val; + } +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(2, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + MyType obj1, obj2; + obj1.val = tid; + obj2.val = 1; + let result = test(obj1, obj2); + outputBuffer[tid] = result ? 1 : 0; + // CHECK: 0 + // CHECK: 1 +} |
