summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-25 07:45:23 -0700
committerGitHub <noreply@github.com>2023-10-25 22:45:23 +0800
commitf8bf75cf1ae0aeee155996a917c2925bc500f3e2 (patch)
tree07b418cfdc3fe106c492162624cfdaeb7a453be9
parentd8f4c9424c69a3d406fabf56a25dd3eda4bc7d51 (diff)
Support generic interfaces. (#3278)
* Initial support for generic interfaces. * Cleanup. * Add generic syntax for interfaces. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ast-base.cpp55
-rw-r--r--source/slang/slang-ast-base.h1
-rw-r--r--source/slang/slang-ast-builder.h1
-rw-r--r--source/slang/slang-ast-decl-ref.cpp10
-rw-r--r--source/slang/slang-ast-type.cpp4
-rw-r--r--source/slang/slang-ast-type.h4
-rw-r--r--source/slang/slang-check-decl.cpp18
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp143
-rw-r--r--source/slang/slang-mangle.cpp2
-rw-r--r--source/slang/slang-parser.cpp16
-rw-r--r--source/slang/slang-syntax.cpp2
-rw-r--r--tests/language-feature/generics/generic-interface-1.slang37
14 files changed, 204 insertions, 93 deletions
diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp
index 60be7a563..d4904d2de 100644
--- a/source/slang/slang-ast-base.cpp
+++ b/source/slang/slang-ast-base.cpp
@@ -3,34 +3,43 @@
namespace Slang
{
-void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
-{
+ void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
+ {
#ifdef _DEBUG
- SLANG_UNUSED(inAstNodeType);
- static int32_t uidCounter = 0;
- static int32_t breakValue = 0;
- uidCounter++;
- _debugUID = uidCounter;
- if (inAstBuilder->getId() == -1)
- _debugUID = -_debugUID;
- if (breakValue != 0 && _debugUID == breakValue)
- SLANG_BREAKPOINT(0)
+ SLANG_UNUSED(inAstNodeType);
+ static int32_t uidCounter = 0;
+ static int32_t breakValue = 0;
+ uidCounter++;
+ _debugUID = uidCounter;
+ if (inAstBuilder->getId() == -1)
+ _debugUID = -_debugUID;
+ if (breakValue != 0 && _debugUID == breakValue)
+ SLANG_BREAKPOINT(0)
#else
- SLANG_UNUSED(inAstNodeType);
- SLANG_UNUSED(inAstBuilder);
+ SLANG_UNUSED(inAstNodeType);
+ SLANG_UNUSED(inAstBuilder);
#endif
-}
-DeclRefBase* Decl::getDefaultDeclRef()
-{
- if (auto astBuilder = getCurrentASTBuilder())
+ }
+ DeclRefBase* Decl::getDefaultDeclRef()
{
- const Index currentEpoch = astBuilder->getEpoch();
- if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef)
+ if (auto astBuilder = getCurrentASTBuilder())
{
- m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this);
- m_defaultDeclRefEpoch = currentEpoch;
+ const Index currentEpoch = astBuilder->getEpoch();
+ if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef)
+ {
+ m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this);
+ m_defaultDeclRefEpoch = currentEpoch;
+ }
}
+ return m_defaultDeclRef;
}
- return m_defaultDeclRef;
-}
+
+ bool Decl::isChildOf(Decl* other) const
+ {
+ for (auto parent = parentDecl; parent; parent = parent->parentDecl)
+ if (parent == other)
+ return true;
+ return false;
+ }
+
}
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index 0170ca493..579bda73a 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -734,6 +734,7 @@ public:
SLANG_RELEASE_ASSERT(state >= checkState.getState());
checkState.setState(state);
}
+ bool isChildOf(Decl* other) const;
private:
SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index aff3088ab..1c6637c31 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -294,7 +294,6 @@ public:
interfaceDecl->addMember(thisDecl);
auto thisConstraint = create<ThisTypeConstraintDecl>();
thisConstraint->loc = loc;
- thisConstraint->base.type = DeclRefType::create(this, getDirectDeclRef(interfaceDecl));
thisDecl->addMember(thisConstraint);
return interfaceDecl;
}
diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp
index c77cf72ed..c9511e4e7 100644
--- a/source/slang/slang-ast-decl-ref.cpp
+++ b/source/slang/slang-ast-decl-ref.cpp
@@ -150,6 +150,11 @@ Val* LookupDeclRef::_resolveImplOverride()
DeclRefBase* LookupDeclRef::_getBaseOverride()
{
+ auto supType = getWitness()->getSup();
+ if (auto declRefType = as<DeclRefType>(supType))
+ {
+ return declRefType->getDeclRef();
+ }
return nullptr;
}
@@ -432,10 +437,13 @@ DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
ShortList<GenericDecl*> genericParentDecls;
auto lastSubstNode = SubstitutionSet(declRef).getInnerMostNodeWithSubstInfo();
auto lastGenApp = as<GenericAppDeclRef>(lastSubstNode);
+ auto lastLookup = as<LookupDeclRef>(lastSubstNode);
for (auto dd = declRef.getDecl()->parentDecl; dd; dd = dd->parentDecl)
{
if (lastGenApp && dd == lastGenApp->getGenericDecl())
break;
+ if (lastLookup && lastLookup->getDecl()->isChildOf(dd))
+ break;
if (auto gen = as<GenericDecl>(dd))
genericParentDecls.add(gen);
}
@@ -454,6 +462,8 @@ DeclRef<Decl> createDefaultSubstitutionsIfNeeded(
}
parentDeclRef = astBuilder->getGenericAppDeclRef(parentDeclRef.as<GenericDecl>(), args.getArrayView());
}
+ if (!parentDeclRef)
+ return declRef;
if (parentDeclRef.getDecl() == declRef.getDecl())
return parentDeclRef;
return astBuilder->getMemberDeclRef(parentDeclRef, declRef.getDecl());
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index a29ff9bb3..840aa4a67 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -636,9 +636,9 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder,
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-InterfaceDecl* ThisType::getInterfaceDecl()
+DeclRef<InterfaceDecl> ThisType::getInterfaceDeclRef()
{
- return dynamicCast<InterfaceDecl>(getDeclRefBase()->getDecl()->parentDecl);
+ return DeclRef<Decl>(getDeclRefBase()->getParent()).template as<InterfaceDecl>();
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 3c50b1899..638012652 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -55,7 +55,7 @@ class DeclRefType : public Type
{
SLANG_AST_CLASS(DeclRefType)
- static DeclRefType* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef);
+ static Type* create(ASTBuilder* astBuilder, DeclRef<Decl> declRef);
DeclRef<Decl> getDeclRef() const { return DeclRef<Decl>(as<DeclRefBase>(getOperand(0))); }
DeclRefBase* getDeclRefBase() const { return as<DeclRefBase>(getOperand(0)); }
@@ -786,7 +786,7 @@ class ThisType : public DeclRefType
ThisType(DeclRefBase* declRef) : DeclRefType(declRef) {}
- InterfaceDecl* getInterfaceDecl();
+ DeclRef<InterfaceDecl> getInterfaceDeclRef();
};
/// The type of `A & B` where `A` and `B` are types
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1e3c6a361..8df5ae618 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -169,6 +169,8 @@ namespace Slang
void visitInheritanceDecl(InheritanceDecl* inheritanceDecl);
+ void visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl);
+
/// Validate that `decl` isn't illegally inheriting from a type in another module.
///
/// This call checks a single `inheritanceDecl` to make sure that it either
@@ -1600,6 +1602,22 @@ namespace Slang
// based on the declaration that is doing the inheriting.
}
+ void SemanticsDeclBasesVisitor::visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl)
+ {
+ // Make sure IFoo<T>.This.ThisIsIFooConstraint.base.type is properly set
+ // to DeclRefType(IFoo<T>) with default generic arguments.
+ if (!thisTypeConstraintDecl->base.type)
+ {
+ auto parentTypeDecl = getParentDecl(getParentDecl(thisTypeConstraintDecl));
+ thisTypeConstraintDecl->base.type = DeclRefType::create(
+ m_astBuilder,
+ createDefaultSubstitutionsIfNeeded(
+ m_astBuilder,
+ this,
+ getDefaultDeclRef(parentTypeDecl)));
+ }
+ }
+
// Concretize interface conformances so that we have witnesses as required for lookup.
// for lookup.
struct SemanticsDeclConformancesVisitor
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index c04450b82..be6228f2d 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3226,7 +3226,7 @@ public:
IRType* getCapabilitySetType();
IRAssociatedType* getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes);
- IRThisType* getThisType(IRInterfaceType* interfaceType);
+ IRThisType* getThisType(IRType* interfaceType);
IRRawPointerType* getRawPointerType();
IRRTTIPointerType* getRTTIPointerType(IRInst* rttiPtr);
IRRTTIType* getRTTIType();
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index cf58e6cd4..2f603ac17 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2653,7 +2653,7 @@ namespace Slang
(IRInst**)constraintTypes.getBuffer());
}
- IRThisType* IRBuilder::getThisType(IRInterfaceType* interfaceType)
+ IRThisType* IRBuilder::getThisType(IRType* interfaceType)
{
return (IRThisType*)getType(kIROp_ThisType, interfaceType);
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 2813918b6..d75b66a9b 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -535,6 +535,22 @@ struct SharedIRGenContext
List<IRInst*> m_stringLiterals;
};
+struct IRGenContext;
+
+struct AstOrIRType
+{
+ Type* astType = nullptr;
+ IRInst* irType = nullptr;
+ IRInst* getIRType(IRGenContext* context);
+
+ AstOrIRType& operator=(Type* t) { astType = t; irType = nullptr; return *this; }
+ AstOrIRType& operator=(IRInst* t) { astType = nullptr; irType = t; return *this; }
+ explicit operator bool()
+ {
+ return astType || irType;
+ }
+};
+
struct IRGenContext
{
ASTBuilder* astBuilder;
@@ -558,7 +574,7 @@ struct IRGenContext
LoweredValInfo thisVal;
// The IRType value to lower into for `ThisType`.
- IRInst* thisType = nullptr;
+ AstOrIRType thisType;
// The IR witness value to use for `ThisType`
IRInst* thisTypeWitness = nullptr;
@@ -824,6 +840,14 @@ static IRType* lowerType(
return lowerType(context, type.type);
}
+IRInst* AstOrIRType::getIRType(IRGenContext* context)
+{
+ if (irType)
+ return irType;
+ irType = lowerType(context, astType);
+ return irType;
+}
+
// Given a `DeclRef` for something callable, along with a bunch of
// arguments, emit an appropriate call to it.
LoweredValInfo emitCallToDeclRef(
@@ -1984,9 +2008,17 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// Therefore, `context->thisType` should have been set to `IRThisType`
// in `visitInterfaceDecl`, and we can just use that value here.
//
- if (context->thisType != nullptr)
- return LoweredValInfo::simple(context->thisType);
- return emitDeclRef(context, makeDeclRef(type->getInterfaceDecl()), getBuilder()->getTypeKind());
+ if (context->thisType.irType)
+ {
+ return LoweredValInfo::simple(context->thisType.irType);
+ }
+ auto interfaceType = emitDeclRef(context, type->getInterfaceDeclRef(), getBuilder()->getTypeKind());
+ auto result = LoweredValInfo::simple(getBuilder()->getThisType((IRType*)getSimpleVal(context, interfaceType)));
+ if (context->thisType.astType == type)
+ {
+ context->thisType = getSimpleVal(context, result);
+ }
+ return result;
}
LoweredValInfo visitAndType(AndType* type)
@@ -2668,7 +2700,9 @@ static Type* _findReplacementThisParamType(
if (auto interfaceDeclRef = parentDeclRef.as<InterfaceDecl>())
{
- auto thisType = DeclRefType::create(context->astBuilder, interfaceDeclRef.getDecl()->getThisTypeDecl());
+ auto thisType = DeclRefType::create(
+ context->astBuilder,
+ context->astBuilder->getMemberDeclRef(interfaceDeclRef, interfaceDeclRef.getDecl()->getThisTypeDecl()));
return thisType;
}
@@ -2704,6 +2738,11 @@ Type* getThisParamTypeForCallable(
IRGenContext* context,
DeclRef<Decl> callableDeclRef)
{
+ if (auto lookup = as<LookupDeclRef>((callableDeclRef.declRefBase)))
+ {
+ return lookup->getLookupSource();
+ }
+
auto parentDeclRef = callableDeclRef.getParent();
if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>())
@@ -7751,13 +7790,19 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Allocate an IRInterfaceType with the `operandCount` operands.
IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr);
+ auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric);
// Add `irInterface` to decl mapping now to prevent cyclic lowering.
- context->setValue(decl, LoweredValInfo::simple(irInterface));
+ context->setValue(decl, LoweredValInfo::simple(finalVal));
+
+ subBuilder->setInsertBefore(irInterface);
// Setup subContext for proper lowering `ThisType`, associated types and
// the interface decl's self reference.
- auto thisType = getBuilder()->getThisType(irInterface);
+
+ auto thisType = DeclRefType::create(
+ context->astBuilder,
+ createDefaultSpecializedDeclRef(subContext, nullptr, decl->getThisTypeDecl()));
subContext->thisType = thisType;
// TODO: Need to add an appropriate stand-in witness here.
@@ -7880,14 +7925,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
subBuilder->setInsertInto(irInterface);
- // TODO: are there any interface members that should be
- // nested inside the interface type itself?
-
- irInterface->moveToEnd();
addTargetIntrinsicDecorations(subContext, irInterface, decl);
- auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric);
return LoweredValInfo::simple(finalVal);
}
@@ -7939,8 +7979,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo visitThisTypeDecl(ThisTypeDecl* decl)
{
- auto interfaceType = ensureDecl(context, decl->parentDecl).val;
- return LoweredValInfo::simple(context->irBuilder->getThisType(as<IRInterfaceType>(interfaceType)));
+ SLANG_UNUSED(decl);
+ return LoweredValInfo();
}
LoweredValInfo visitThisTypeConstraintDecl(ThisTypeConstraintDecl* decl)
@@ -7968,14 +8008,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
const bool isPublicType = decl->findModifier<PublicModifier>() != nullptr;
- // Given a declaration of a type, we need to make sure
- // to output "witness tables" for any interfaces this
- // type has declared conformance to.
- for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() )
- {
- ensureDecl(context, inheritanceDecl);
- }
-
// We are going to create nested IR building state
// to use when emitting the members of the type.
//
@@ -8001,11 +8033,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo::simple(subBuilder->getVoidType());
}
- const auto finishedVal = _getFinishOuterGenericsReturnValue(irAggType, outerGeneric);
+ auto finalFinishedVal = finishOuterGenerics(subBuilder, irAggType, outerGeneric);
// We add the decl now such that if there are Ptr or other references
// to this type they can still complete
- context->setValue(decl, LoweredValInfo::simple(finishedVal));
+ context->setValue(decl, LoweredValInfo::simple(finalFinishedVal));
+
+ subBuilder->setInsertBefore(irAggType);
+
+ // Given a declaration of a type, we need to make sure
+ // to output "witness tables" for any interfaces this
+ // type has declared conformance to.
+ for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
+ {
+ ensureDecl(subContext, inheritanceDecl);
+ }
addNameHint(context, irAggType, decl);
addLinkageDecoration(context, irAggType, decl);
@@ -8022,8 +8064,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() )
{
- if (isPublicType)
- ensureDecl(context, inheritanceDecl);
auto superType = inheritanceDecl->base;
if(auto superDeclRefType = as<DeclRefType>(superType))
{
@@ -8031,7 +8071,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
superDeclRefType->getDeclRef().as<ClassDecl>())
{
auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl));
- auto irSuperType = lowerType(context, superType.type);
+ auto irSuperType = lowerType(subContext, superType.type);
subBuilder->createStructField(
irAggType,
superKey,
@@ -8053,8 +8093,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Each ordinary field will need to turn into a struct "key"
// that is used for fetching the field.
- IRInst* fieldKeyInst = getSimpleVal(context,
- ensureDecl(context, fieldDecl));
+ IRInst* fieldKeyInst = getSimpleVal(subContext,
+ ensureDecl(subContext, fieldDecl));
auto fieldKey = as<IRStructKey>(fieldKeyInst);
SLANG_ASSERT(fieldKey);
@@ -8085,7 +8125,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Instead we will force emission of all children of aggregate
// type declarations later, from the top-level emit logic.
- irAggType->moveToEnd();
addTargetIntrinsicDecorations(subContext, irAggType, decl);
for (auto modifier : decl->modifiers)
{
@@ -8093,9 +8132,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subBuilder->addNonCopyableTypeDecoration(irAggType);
}
- auto finalFinishedVal = finishOuterGenerics(subBuilder, irAggType, outerGeneric);
- // Confirm that _getFinishOuterGenericsReturnValue above returned the same result
- SLANG_ASSERT(finalFinishedVal == finishedVal);
return LoweredValInfo::simple(finalFinishedVal);
}
@@ -8611,27 +8647,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return v;
}
- // This function matches the return value from finishOuterGenerics
- // so that we can create the target value without finishOuterGenerics having to be called.
- IRInst* _getFinishOuterGenericsReturnValue(
- IRInst* val,
- IRGeneric* parentGeneric)
- {
- IRInst* v = val;
- while (parentGeneric)
- {
- // There might be more outer generics,
- // so we need to loop until we run out.
- v = parentGeneric;
- auto parentBlock = as<IRBlock>(v->getParent());
- if (!parentBlock) break;
-
- parentGeneric = as<IRGeneric>(parentBlock->getParent());
- if (!parentGeneric) break;
-
- }
- return v;
- }
void addSpecializedForTargetDecorations(IRInst* inst, Decl* decl)
{
@@ -9700,6 +9715,26 @@ LoweredValInfo emitDeclRef(
const auto initialSubst = subst;
SLANG_UNUSED(initialSubst);
+
+ if (auto thisTypeDecl = as<ThisTypeDecl>(decl))
+ {
+ // A declref to ThisType decl should be lowered differently
+ // from other decls. In general, IFoo<T>.ThisType should lower to
+ // ThisType(specialize(IFoo,T)) instead of specialize(IFoo.ThisType, T).
+ SLANG_ASSERT(subst->getDecl() == decl);
+ IRType* parentInterfaceType = nullptr;
+ if (auto lookupDeclRef = as<LookupDeclRef>(subst))
+ {
+ parentInterfaceType = lowerType(context, lookupDeclRef->getWitness()->getSup());
+ }
+ else
+ {
+ parentInterfaceType = lowerType(context, DeclRefType::create(context->astBuilder, subst->getParent()));
+ }
+ auto thisType = context->irBuilder->getThisType(parentInterfaceType);
+ return LoweredValInfo::simple(thisType);
+ }
+
// We need to proceed by considering the specializations that
// have been put in place.
subst = SubstitutionSet(subst).getInnerMostNodeWithSubstInfo();
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index 59110ea05..3111ab132 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -223,7 +223,7 @@ namespace Slang
else if( auto thisType = dynamicCast<ThisType>(type) )
{
emitRaw(context, "t");
- emitQualifiedName(context, thisType->getInterfaceDecl());
+ emitQualifiedName(context, thisType->getInterfaceDeclRef());
}
else if (const auto errorType = dynamicCast<ErrorType>(type))
{
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 696575f8b..59aff4dc0 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -3061,7 +3061,7 @@ namespace Slang
parser->FillPosition(paramConstraint);
// substitution needs to be filled during check
- DeclRefType* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl));
+ Type* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl));
SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>();
paramTypeExpr->loc = decl->loc;
@@ -3128,12 +3128,14 @@ namespace Slang
AdvanceIf(parser, TokenType::CompletionRequest);
decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
-
- parseOptionalInheritanceClause(parser, decl);
-
- parseDeclBody(parser, decl);
-
- return decl;
+ return parseOptGenericDecl(parser, [&](GenericDecl*)
+ {
+ // We allow for an inheritance clause on a `struct`
+ // so that it can conform to interfaces.
+ parseOptionalInheritanceClause(parser, decl);
+ parseDeclBody(parser, decl);
+ return decl;
+ });
}
static NodeBase* parseNamespaceDecl(Parser* parser, void* /*userData*/)
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 8ed50510f..d24fd239d 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -407,7 +407,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// TODO: need to figure out how to unify this with the logic
// in the generic case...
- DeclRefType* DeclRefType::create(
+ Type* DeclRefType::create(
ASTBuilder* astBuilder,
DeclRef<Decl> declRef)
{
diff --git a/tests/language-feature/generics/generic-interface-1.slang b/tests/language-feature/generics/generic-interface-1.slang
new file mode 100644
index 000000000..217e7f06f
--- /dev/null
+++ b/tests/language-feature/generics/generic-interface-1.slang
@@ -0,0 +1,37 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
+
+interface IEqlTestable<T>
+{
+ bool testEql(T v1);
+}
+
+bool test<T>(IEqlTestable<T> v0, T v1)
+{
+ return v0.testEql(v1);
+}
+
+struct MyType : IEqlTestable<MyType>
+{
+ int val;
+ bool testEql(MyType v1)
+ {
+ return val == v1.val;
+ }
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(2, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int tid = dispatchThreadID.x;
+ MyType obj1, obj2;
+ obj1.val = tid;
+ obj2.val = 1;
+ let result = test(obj1, obj2);
+ outputBuffer[tid] = result ? 1 : 0;
+ // CHECK: 0
+ // CHECK: 1
+}