summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-02-20 21:35:03 -0800
committerGitHub <noreply@github.com>2024-02-20 21:35:03 -0800
commit2ee05c1257c916e5c804a6b565a2a6aa362050e0 (patch)
treeeab2a825afa5b8b48f62b5e6fd513f16a9e754e7 /source
parenta62be597990966b9516995650baf750ee6a0146b (diff)
Add wrapper type syntax for link time specialization. (#3606)
* Add wrapper type syntax for link time specialization. * Cleanup.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-check-decl.cpp366
-rw-r--r--source/slang/slang-check-impl.h17
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp151
-rw-r--r--source/slang/slang-parser.cpp8
6 files changed, 462 insertions, 84 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 7a4d46947..c516b15c7 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -144,6 +144,9 @@ class AggTypeDecl : public AggTypeDeclBase
TypeTag typeTags = TypeTag::None;
+ // Used if this type declaration is a wrapper, i.e. struct FooWrapper:IFoo = Foo;
+ TypeExp wrappedType;
+
void unionTagsWith(TypeTag other);
void addTag(TypeTag tag);
bool hasTag(TypeTag tag);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index dc888ccda..1c964ab88 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1738,6 +1738,24 @@ namespace Slang
{
structDecl->addTag(TypeTag::Incomplete);
}
+
+ // Slang supports a convenient syntax to create a wrapper type from
+ // an existing type that implements a given interface. For example,
+ // the user can write: struct FooWrapper:IFoo = Foo;
+ // In this case we will synthesize the FooWrapper type with an inner
+ // member of type `Foo`, and use it to implement all requirements of
+ // IFoo.
+ // If this is a wrapper struct, synthesize the inner member now.
+ if (structDecl->wrappedType.exp)
+ {
+ structDecl->wrappedType = CheckProperType(structDecl->wrappedType);
+ auto member = m_astBuilder->create<VarDecl>();
+ member->type = structDecl->wrappedType;
+ member->nameAndLoc.name = getName("inner");
+ member->nameAndLoc.loc = structDecl->wrappedType.exp->loc;
+ member->loc = member->nameAndLoc.loc;
+ structDecl->addMember(member);
+ }
checkVisibility(structDecl);
}
@@ -1839,6 +1857,14 @@ namespace Slang
getSink()->diagnose(varDecl->type.exp->loc, Diagnostics::incompleteTypeCannotBeUsedInBuffer, elementType);
}
}
+ else if (varDecl->findModifier<HLSLUniformModifier>())
+ {
+ auto varType = varDecl->getType();
+ if (doesTypeHaveTag(varType, TypeTag::Incomplete))
+ {
+ getSink()->diagnose(varDecl->type.exp->loc, Diagnostics::incompleteTypeCannotBeUsedInUniformParameter, varType);
+ }
+ }
maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType());
}
@@ -3499,6 +3525,16 @@ namespace Slang
witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef));
}
+ static bool isWrapperTypeDecl(Decl* decl)
+ {
+ if (auto aggTypeDecl = as<AggTypeDecl>(decl))
+ {
+ if (aggTypeDecl->wrappedType)
+ return true;
+ }
+ return false;
+ }
+
bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
ConformanceCheckingContext* context,
LookupResult const& lookupResult,
@@ -3576,14 +3612,47 @@ namespace Slang
//
auto synBase = m_astBuilder->create<OverloadedExpr>();
synBase->name = requiredMemberDeclRef.getDecl()->getName();
- synBase->lookupResult2 = lookupResult;
+
+ if (isWrapperTypeDecl(context->parentDecl))
+ {
+ auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl);
+ synBase->lookupResult2 = lookUpMember(
+ m_astBuilder,
+ this,
+ synBase->name,
+ aggTypeDecl->wrappedType.type,
+ aggTypeDecl->ownedScope,
+ LookupMask::Default,
+ LookupOptions::IgnoreBaseInterfaces);
+ addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>());
+ }
+ else
+ {
+ synBase->lookupResult2 = lookupResult;
+ }
// If `synThis` is non-null, then we will use it as the base of
// the overloaded expression, so that we have an overloaded
// member reference, and not just an overloaded reference to some
// static definitions.
//
- synBase->base = synThis;
+ if (synThis)
+ {
+ if (isWrapperTypeDecl(context->parentDecl))
+ {
+ // If this is a wrapper type, then use the inner
+ // object as the actual this parameter for the redirected
+ // call.
+ auto innerExpr = m_astBuilder->create<VarExpr>();
+ innerExpr->scope = synThis->scope;
+ innerExpr->name = getName("inner");
+ synBase->base = CheckExpr(innerExpr);
+ }
+ else
+ {
+ synBase->base = synThis;
+ }
+ }
// We now have the reference to the overload group we plan to call,
// and we already built up the argument list, so we can construct
@@ -3695,6 +3764,9 @@ namespace Slang
DeclRef<PropertyDecl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable)
{
+ if (isWrapperTypeDecl(context->parentDecl))
+ return trySynthesizeWrapperTypePropertyRequirementWitness(context, requiredMemberDeclRef, witnessTable);
+
// The situation here is that the context of an inheritance
// declaration didn't provide an exact match for a required
// property. E.g.:
@@ -3745,16 +3817,10 @@ namespace Slang
//
auto synPropertyDecl = m_astBuilder->create<PropertyDecl>();
- // For now our synthesized property will use the name and source
- // location of the requirement we are trying to satisfy.
- //
- // TODO: as it stands right now our syntesized property and its
- // accesors will get mangled names, which we don't actually want.
- // Leaving out the name here doesn't help matters, becaues then
- // *all* synthesized members on a given type would share the same
- // mangled name.
- //
+ // Synthesize the property name with a prefix to avoid name clashing.
synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc;
+ synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName()));
+
// The type of our synthesized property will be the expected type
// of the interface requirement.
@@ -4041,6 +4107,233 @@ namespace Slang
return true;
}
+ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<PropertyDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // We are synthesizing a property requirement for a wrapper type:
+ //
+ // interface IFoo { property value : int { get; set; } }
+ // struct Foo : IFoo = FooImpl;
+ //
+ // We need to synthesize Foo to:
+ //
+ // struct Foo : IFoo
+ // {
+ // FooImpl inner;
+ // property value : int { get { return inner.value; }
+ // set { inner.value = newValue; }
+ // }
+ // }
+ //
+ // To do so, we need to grab the witness table of FooImpl:IFoo, and create
+ // wrapper property in Foo that forwards the accessors to the inner object.
+ //
+ // We get started by constructing a synthesized `PropertyDecl`.
+ //
+ auto synPropertyDecl = m_astBuilder->create<PropertyDecl>();
+ synPropertyDecl->parentDecl = context->parentDecl;
+
+ // Synthesize the property name with a prefix to avoid name clashing.
+ //
+ synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc;
+ synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName()));
+
+ // Find the witness that FooImpl : IFoo.
+ auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl);
+ auto innerType = aggTypeDecl->wrappedType.type;
+ DeclRef<Decl> innerProperty;
+ auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType);
+ if (!innerWitness)
+ return false;
+
+ for (auto requiredAccessorDeclRef : getMembersOfType<AccessorDecl>(m_astBuilder, requiredMemberDeclRef))
+ {
+ auto innerEntry = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredAccessorDeclRef.getDecl());
+ if (innerEntry.getFlavor() != RequirementWitness::Flavor::declRef)
+ return false;
+ auto innerAccessorDeclRef = as<AccessorDecl>(innerEntry.getDeclRef());
+ if (!innerAccessorDeclRef)
+ return false;
+
+ // The synthesized accessor will be an AST node of the same class as
+ // the required accessor.
+ //
+ auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType);
+ synAccessorDecl->ownedScope = m_astBuilder->create<Scope>();
+ synAccessorDecl->ownedScope->containerDecl = synAccessorDecl;
+ synAccessorDecl->ownedScope->parent = getScope(context->parentDecl);
+
+ // The return type should be the same as the inner object's accessor return type.
+ //
+ synAccessorDecl->returnType.type = getResultType(m_astBuilder, innerAccessorDeclRef);
+
+ // Similarly, our synthesized accessor will have parameters matching those of the inner accessor.
+ //
+ List<Expr*> synArgs;
+ for (auto innerParamDeclRef : getParameters(m_astBuilder, innerAccessorDeclRef))
+ {
+ auto paramType = getType(m_astBuilder, innerParamDeclRef);
+
+ // The synthesized parameter will ahve the same name and
+ // type as the parameter of the requirement.
+ //
+ auto synParamDecl = m_astBuilder->create<ParamDecl>();
+ synParamDecl->nameAndLoc = innerParamDeclRef.getDecl()->nameAndLoc;
+ synParamDecl->type.type = paramType;
+
+ // We need to add the parameter as a child declaration of
+ // the accessor we are building.
+ //
+ synParamDecl->parentDecl = synAccessorDecl;
+ synAccessorDecl->members.add(synParamDecl);
+
+ // For each paramter, we will create an argument expression
+ // to represent it in the body of the accessor.
+ //
+ auto synArg = m_astBuilder->create<VarExpr>();
+ synArg->declRef = makeDeclRef(synParamDecl);
+ synArg->type = paramType;
+ synArgs.add(synArg);
+ }
+
+ // Now synthesize the body of the property accessor.
+ // The body of the accessor will depend on the class of the accessor
+ // we are synthesizing (e.g., `get` vs. `set`).
+ //
+ Stmt* synBodyStmt = nullptr;
+ auto propertyRef = m_astBuilder->create<MemberExpr>();
+ propertyRef->scope = synAccessorDecl->ownedScope;
+ auto base = m_astBuilder->create<VarExpr>();
+ base->scope = propertyRef->scope;
+ base->name = getName("inner");
+ propertyRef->baseExpression = base;
+ innerProperty = innerAccessorDeclRef.getParent();
+ propertyRef->name = getParentDecl(innerAccessorDeclRef.getDecl())->getName();
+ auto checkedPropertyRefExpr = CheckExpr(propertyRef);
+
+ if (as<GetterDecl>(requiredAccessorDeclRef))
+ {
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = checkedPropertyRefExpr;
+
+ synBodyStmt = synReturn;
+ }
+ else if (as<SetterDecl>(requiredAccessorDeclRef))
+ {
+ auto synAssign = m_astBuilder->create<AssignExpr>();
+ synAssign->left = checkedPropertyRefExpr;
+ synAssign->right = synArgs[0];
+
+ auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign);
+
+ auto synExprStmt = m_astBuilder->create<ExpressionStmt>();
+ synExprStmt->expression = synCheckedAssign;
+
+ synBodyStmt = synExprStmt;
+ }
+ else
+ {
+ // While there are other kinds of accessors than `get` and `set`,
+ // those are currently only reserved for stdlib-internal use.
+ // We will not bother with synthesis for those cases.
+ //
+ return false;
+ }
+
+ addModifier(synAccessorDecl, m_astBuilder->create<ForceInlineAttribute>());
+ synAccessorDecl->body = synBodyStmt;
+
+ synAccessorDecl->parentDecl = synPropertyDecl;
+ synPropertyDecl->members.add(synAccessorDecl);
+
+ // Register the synthesized accessor.
+ //
+ witnessTable->add(requiredAccessorDeclRef.getDecl(), RequirementWitness(makeDeclRef(synAccessorDecl)));
+ }
+
+ // The type of our synthesized property will be the same as the inner property.
+ //
+ auto propertyType = getType(m_astBuilder, as<PropertyDecl>(innerProperty));
+ synPropertyDecl->type.type = propertyType;
+
+ // The visibility of synthesized decl should be the same as the inner requirement
+ if (innerProperty.getDecl()->findModifier<VisibilityModifier>())
+ {
+ auto vis = getDeclVisibility(innerProperty.getDecl());
+ addVisibilityModifier(m_astBuilder, synPropertyDecl, vis);
+ }
+
+ context->parentDecl->addMember(synPropertyDecl);
+ witnessTable->add(requiredMemberDeclRef.getDecl(),
+ RequirementWitness(makeDeclRef(synPropertyDecl)));
+ return true;
+ }
+
+ bool SemanticsVisitor::trySynthesizeAssociatedTypeRequirementWitness(
+ ConformanceCheckingContext* context,
+ LookupResult const& inLookupResult,
+ DeclRef<AssocTypeDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ SLANG_UNUSED(inLookupResult);
+
+ // The only case we can synthesize for now is when the conformant type
+ // is a wrapper type.
+ if (!isWrapperTypeDecl(context->parentDecl))
+ return false;
+ auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl);
+ auto lookupResult = lookUpMember(
+ m_astBuilder,
+ this,
+ requiredMemberDeclRef.getName(),
+ aggTypeDecl->wrappedType.type,
+ aggTypeDecl->ownedScope,
+ LookupMask::Default,
+ LookupOptions::IgnoreBaseInterfaces);
+ if (!lookupResult.isValid() || lookupResult.isOverloaded())
+ return false;
+ auto assocType = DeclRefType::create(m_astBuilder, lookupResult.item.declRef);
+ witnessTable->add(requiredMemberDeclRef.getDecl(), assocType);
+ for (auto typeConstraintDecl : getMembersOfType<TypeConstraintDecl>(m_astBuilder, requiredMemberDeclRef))
+ {
+ auto witness = tryGetSubtypeWitness(assocType, getSup(m_astBuilder, typeConstraintDecl));
+ if (!witness)
+ return false;
+ witnessTable->add(typeConstraintDecl.getDecl(), witness);
+ }
+ return true;
+ }
+
+ bool SemanticsVisitor::trySynthesizeAssociatedConstantRequirementWitness(
+ ConformanceCheckingContext* context,
+ LookupResult const& inLookupResult,
+ DeclRef<VarDeclBase> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ SLANG_UNUSED(inLookupResult);
+
+ // The only case we can synthesize for now is when the conformant type
+ // is a wrapper type, i.e.
+ // struct Foo:IFoo = FooImpl;
+ if (!isWrapperTypeDecl(context->parentDecl))
+ return false;
+
+ // Find the witness that FooImpl : IFoo.
+ auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl);
+ auto innerType = aggTypeDecl->wrappedType.type;
+ DeclRef<Decl> innerProperty;
+ auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType);
+ if (!innerWitness)
+ return false;
+
+ auto witness = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredMemberDeclRef.getDecl());
+ if (witness.getFlavor() != RequirementWitness::Flavor::val)
+ return false;
+ witnessTable->add(requiredMemberDeclRef.getDecl(), witness.getVal());
+ return true;
+ }
bool SemanticsVisitor::trySynthesizeRequirementWitness(
ConformanceCheckingContext* context,
@@ -4118,6 +4411,23 @@ namespace Slang
witnessTable);
}
}
+ else
+ {
+ return trySynthesizeAssociatedTypeRequirementWitness(
+ context,
+ lookupResult,
+ requiredAssocTypeDeclRef,
+ witnessTable);
+ }
+ }
+
+ if (auto requiredConstantDeclRef = requiredMemberDeclRef.as<VarDeclBase>())
+ {
+ return trySynthesizeAssociatedConstantRequirementWitness(
+ context,
+ lookupResult,
+ requiredConstantDeclRef,
+ witnessTable);
}
// TODO: There are other kinds of requirements for which synthesis should
@@ -4522,21 +4832,25 @@ namespace Slang
// requests will be handled further down. For now we include
// lookup results that might be usable, but not as-is.
//
- auto lookupResult = lookUpMember(m_astBuilder, this, name, subType, nullptr, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces);
-
- if(!lookupResult.isValid())
+ LookupResult lookupResult;
+ if (!isWrapperTypeDecl(context->parentDecl))
{
- // If we failed to look up a member with the name of the
- // requirement, it may be possible that we can still synthesis the
- // implementation if this is one of the known builtin requirements.
- // Otherwise, report diagnostic now.
- if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() &&
- !(requiredMemberDeclRef.as<GenericDecl>() &&
- getInner(requiredMemberDeclRef.as<GenericDecl>())->hasModifier<BuiltinRequirementModifier>()))
+ lookupResult = lookUpMember(m_astBuilder, this, name, subType, nullptr, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces);
+
+ if (!lookupResult.isValid())
{
- getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
- getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
- return false;
+ // If we failed to look up a member with the name of the
+ // requirement, it may be possible that we can still synthesis the
+ // implementation if this is one of the known builtin requirements.
+ // Otherwise, report diagnostic now.
+ if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() &&
+ !(requiredMemberDeclRef.as<GenericDecl>() &&
+ getInner(requiredMemberDeclRef.as<GenericDecl>())->hasModifier<BuiltinRequirementModifier>()))
+ {
+ getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
+ getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
+ return false;
+ }
}
}
@@ -4575,6 +4889,10 @@ namespace Slang
// used to synthesize an exact-match witness, by generating the
// code required to handle all the conversions that might be
// required on `this`.
+ //
+ // Another situation that will get us here is that we are dealing with
+ // a wrapper type (struct Foo:IFoo=FooImpl), and we will synthesize
+ // wrappers that redirects the call into the inner element.
//
if( trySynthesizeRequirementWitness(context, lookupResult, requiredMemberDeclRef, witnessTable) )
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ba8e2f4cd..8abf06d6f 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1604,6 +1604,23 @@ namespace Slang
DeclRef<PropertyDecl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);
+ bool trySynthesizeWrapperTypePropertyRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<PropertyDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable);
+
+ bool trySynthesizeAssociatedTypeRequirementWitness(
+ ConformanceCheckingContext* context,
+ LookupResult const& lookupResult,
+ DeclRef<AssocTypeDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable);
+
+ bool trySynthesizeAssociatedConstantRequirementWitness(
+ ConformanceCheckingContext* context,
+ LookupResult const& lookupResult,
+ DeclRef<VarDeclBase> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable);
+
/// Attempt to synthesize a declartion that can satisfy `requiredMemberDeclRef` using `lookupResult`.
///
/// On success, installs the syntethesized declaration in `witnessTable` and returns `true`.
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 722085744..c90dc12e8 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -440,6 +440,7 @@ DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.
DIAGNOSTIC(31202, Error, duplicateModifier, "modifier '$0' is redundant or conflicting with existing modifier '$1'")
DIAGNOSTIC(31203, Error, cannotExportIncompleteType, "cannot export incomplete type '$0'")
DIAGNOSTIC(31204, Error, incompleteTypeCannotBeUsedInBuffer, "incomplete type '$0' cannot be used in a buffer")
+DIAGNOSTIC(31205, Error, incompleteTypeCannotBeUsedInUniformParameter, "incomplete type '$0' cannot be used in a uniform parameter")
// Enums
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b560bb156..416a6671b 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7872,8 +7872,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
UInt operandCount = 0;
for (auto requirementDecl : decl->members)
{
+ if (as<SubscriptDecl>(requirementDecl) || as<PropertyDecl>(requirementDecl))
+ {
+ for (auto accessorDecl : as<ContainerDecl>(requirementDecl)->members)
+ {
+ if (as<AccessorDecl>(accessorDecl))
+ operandCount++;
+ }
+ }
if (!shouldDeclBeTreatedAsInterfaceRequirement(requirementDecl))
+ {
continue;
+ }
operandCount++;
// As a special case, any type constraints placed
@@ -7911,29 +7921,26 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
UInt entryIndex = 0;
-
- for (auto requirementDecl : decl->members)
- {
- auto requirementKey = getInterfaceRequirementKey(requirementDecl);
- if (!requirementKey) continue;
- auto entry = subBuilder->createInterfaceRequirementEntry(
- requirementKey,
- nullptr);
- if (auto inheritance = as<InheritanceDecl>(requirementDecl))
- {
- auto irBaseType = lowerType(context, inheritance->base.type);
- auto irWitnessTableType = subBuilder->getWitnessTableType(irBaseType);
- entry->setRequirementVal(irWitnessTableType);
- }
- else
+ auto addEntry = [&](IRStructKey* requirementKey, Decl* requirementDecl)
{
- IRInst* requirementVal = ensureDecl(subContext, requirementDecl).val;
- if (requirementVal)
+ auto entry = subBuilder->createInterfaceRequirementEntry(
+ requirementKey,
+ nullptr);
+ if (auto inheritance = as<InheritanceDecl>(requirementDecl))
+ {
+ auto irBaseType = lowerType(context, inheritance->base.type);
+ auto irWitnessTableType = subBuilder->getWitnessTableType(irBaseType);
+ entry->setRequirementVal(irWitnessTableType);
+ }
+ else
{
- switch (requirementVal->getOp())
+ IRInst* requirementVal = ensureDecl(subContext, requirementDecl).val;
+ if (requirementVal)
{
- case kIROp_Func:
- case kIROp_Generic:
+ switch (requirementVal->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Generic:
{
// Remove lowered `IRFunc`s since we only care about
// function types.
@@ -7941,58 +7948,82 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
entry->setRequirementVal(reqType);
break;
}
- default:
- entry->setRequirementVal(requirementVal);
- break;
- }
- if (requirementDecl->findModifier<HLSLStaticModifier>())
- {
- getBuilder()->addStaticRequirementDecoration(requirementKey);
+ default:
+ entry->setRequirementVal(requirementVal);
+ break;
+ }
+ if (requirementDecl->findModifier<HLSLStaticModifier>())
+ {
+ getBuilder()->addStaticRequirementDecoration(requirementKey);
+ }
}
}
- }
- irInterface->setOperand(entryIndex, entry);
- entryIndex++;
- // Add addtional requirements for type constraints placed
- // on an associated types.
- if (auto associatedTypeDecl = as<AssocTypeDecl>(requirementDecl))
- {
- for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>())
+ irInterface->setOperand(entryIndex, entry);
+ entryIndex++;
+ // Add addtional requirements for type constraints placed
+ // on an associated types.
+ if (auto associatedTypeDecl = as<AssocTypeDecl>(requirementDecl))
{
- auto constraintKey = getInterfaceRequirementKey(constraintDecl);
- auto constraintInterfaceType =
- lowerType(context, constraintDecl->getSup().type);
- auto witnessTableType =
- getBuilder()->getWitnessTableType(constraintInterfaceType);
+ for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>())
+ {
+ auto constraintKey = getInterfaceRequirementKey(constraintDecl);
+ auto constraintInterfaceType =
+ lowerType(context, constraintDecl->getSup().type);
+ auto witnessTableType =
+ getBuilder()->getWitnessTableType(constraintInterfaceType);
- auto constraintEntry = subBuilder->createInterfaceRequirementEntry(constraintKey,
+ auto constraintEntry = subBuilder->createInterfaceRequirementEntry(constraintKey,
witnessTableType);
- irInterface->setOperand(entryIndex, constraintEntry);
- entryIndex++;
+ irInterface->setOperand(entryIndex, constraintEntry);
+ entryIndex++;
- context->setValue(constraintDecl, LoweredValInfo::simple(constraintEntry));
+ context->setValue(constraintDecl, LoweredValInfo::simple(constraintEntry));
+ }
}
- }
- else
- {
- CallableDecl* callableDecl = nullptr;
- if (auto genDecl = as<GenericDecl>(requirementDecl))
- callableDecl = as<CallableDecl>(genDecl->inner);
else
- callableDecl = as<CallableDecl>(requirementDecl);
- if (callableDecl)
{
- // Differentiable functions has additional requirements for the derivatives.
- for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ CallableDecl* callableDecl = nullptr;
+ if (auto genDecl = as<GenericDecl>(requirementDecl))
+ callableDecl = as<CallableDecl>(genDecl->inner);
+ else
+ callableDecl = as<CallableDecl>(requirementDecl);
+ if (callableDecl)
+ {
+ // Differentiable functions has additional requirements for the derivatives.
+ for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ {
+ auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl);
+ insertRequirementKeyAssociation(diffDecl->referencedDecl, requirementKey, diffKey);
+ }
+ }
+ // Add lowered requirement entry to current decl mapping to prevent
+ // the function requirements from being lowered again when we get to
+ // `ensureAllDeclsRec`.
+ context->setValue(requirementDecl, LoweredValInfo::simple(entry));
+ }
+ };
+ for (auto requirementDecl : decl->members)
+ {
+ auto requirementKey = getInterfaceRequirementKey(requirementDecl);
+ if (!requirementKey)
+ {
+ if (as<PropertyDecl>(requirementDecl) || as<SubscriptDecl>(requirementDecl))
+ {
+ for (auto member : as<ContainerDecl>(requirementDecl)->members)
{
- auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl);
- insertRequirementKeyAssociation(diffDecl->referencedDecl, requirementKey, diffKey);
+ if (auto accessorDecl = as<AccessorDecl>(member))
+ {
+ auto accessorKey = getInterfaceRequirementKey(accessorDecl);
+ if (accessorKey)
+ addEntry(accessorKey, accessorDecl);
+ }
}
}
- // Add lowered requirement entry to current decl mapping to prevent
- // the function requirements from being lowered again when we get to
- // `ensureAllDeclsRec`.
- context->setValue(requirementDecl, LoweredValInfo::simple(entry));
+ continue;
+ }
+ else
+ {
+ addEntry(requirementKey, requirementDecl);
}
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index b89b93138..b208e1098 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -4733,6 +4733,14 @@ namespace Slang
// We allow for an inheritance clause on a `struct`
// so that it can conform to interfaces.
parseOptionalInheritanceClause(this, rs);
+ if (AdvanceIf(this, TokenType::OpAssign))
+ {
+ rs->wrappedType = ParseTypeExp();
+ PushScope(rs);
+ PopScope();
+ ReadToken(TokenType::Semicolon);
+ return rs;
+ }
if (AdvanceIf(this, TokenType::Semicolon))
return rs;
parseDeclBody(this, rs);