summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2018-01-16 10:52:10 -0800
committerTim Foley <tfoleyNV@users.noreply.github.com>2018-01-16 10:52:10 -0800
commit59691aeeb013c5bb7cdaa31a6fc572eebd8be610 (patch)
tree310754847c4c83ffa8fd97fcaadc7cdf7b14c253
parent68fd4485708bf98c66e27e330692138f3eb6f289 (diff)
Allow extension on interface (#369)
This completes item 5 in issue #361. The interesting change is that when checking for interface conformance, we include the requirements (include transitive interfaces) defined in extensions as well. (check.cpp line 1946) All the other changes are for one thing: reoder the semantic checkings to two explicit stages: check header and check body. In check header phase, we check everything except function bodies, register all extensions with their target decls, then check interface conformances for all concrete types. In body checking phase, we look inside the function bodies and check concrete statements/expressions. This change ensures that we take extension into consideration in all places where it should be.
-rw-r--r--source/slang/check.cpp545
-rw-r--r--source/slang/diagnostic-defs.h1
-rw-r--r--source/slang/diagnostics.cpp5
-rw-r--r--source/slang/lookup.cpp17
-rw-r--r--source/slang/syntax.h36
-rw-r--r--tests/compute/extension-on-interface.slang49
-rw-r--r--tests/compute/extension-on-interface.slang.expected.txt4
7 files changed, 412 insertions, 245 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 3103a7908..f7bb2ae1f 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -44,11 +44,24 @@ namespace Slang
return name;
}
+ enum class CheckingPhase
+ {
+ Header, Body
+ };
+
struct SemanticsVisitor
: ExprVisitor<SemanticsVisitor, RefPtr<Expr>>
, StmtVisitor<SemanticsVisitor>
, DeclVisitor<SemanticsVisitor>
{
+ CheckingPhase checkingPhase = CheckingPhase::Header;
+ DeclCheckState getCheckedState()
+ {
+ if (checkingPhase == CheckingPhase::Body)
+ return DeclCheckState::Checked;
+ else
+ return DeclCheckState::CheckedHeader;
+ }
DiagnosticSink* sink = nullptr;
DiagnosticSink* getSink()
{
@@ -462,13 +475,13 @@ namespace Slang
// Make sure a declaration has been checked, so we can refer to it.
// Note that this may lead to us recursively invoking checking,
// so this may not be the best way to handle things.
- void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader)
+ void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state)
{
if (decl->IsChecked(state)) return;
if (decl->checkState == DeclCheckState::CheckingHeader)
{
// We tried to reference the same declaration while checking it!
- throw "circularity";
+ sink->diagnose(decl, Diagnostics::cyclicReference, decl);
}
if (DeclCheckState::CheckingHeader > decl->checkState)
@@ -478,13 +491,12 @@ namespace Slang
// Use visitor pattern to dispatch to correct case
DeclVisitor::dispatch(decl);
-
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(state);
}
void EnusreAllDeclsRec(RefPtr<Decl> decl)
{
- EnsureDecl(decl, DeclCheckState::Checked);
+ checkDecl(decl);
if (auto containerDecl = decl.As<ContainerDecl>())
{
for (auto m : containerDecl->Members)
@@ -537,7 +549,7 @@ namespace Slang
//
auto genericDeclRef = genericDeclRefType->GetDeclRef();
- EnsureDecl(genericDeclRef.decl);
+ checkDecl(genericDeclRef.decl);
List<RefPtr<Expr>> args;
for (RefPtr<Decl> member : genericDeclRef.getDecl()->Members)
{
@@ -1314,7 +1326,7 @@ namespace Slang
// of the parameters of the generic.
if (decl->checkState == DeclCheckState::Unchecked)
{
- decl->checkState = DeclCheckState::Checked;
+ decl->checkState = getCheckedState();
CheckConstraintSubType(decl->sub);
decl->sub = TranslateTypeNodeForced(decl->sub);
decl->sup = TranslateTypeNodeForced(decl->sup);
@@ -1323,11 +1335,13 @@ namespace Slang
void checkDecl(Decl* decl)
{
- EnsureDecl(decl, DeclCheckState::Checked);
+ EnsureDecl(decl, checkingPhase == CheckingPhase::Header ? DeclCheckState::CheckedHeader : DeclCheckState::Checked);
}
- void visitGenericDecl(GenericDecl* genericDecl)
+ void checkGenericDeclHeader(GenericDecl* genericDecl)
{
+ if (genericDecl->IsChecked(DeclCheckState::CheckedHeader))
+ return;
// check the parameters
for (auto m : genericDecl->Members)
{
@@ -1340,21 +1354,29 @@ namespace Slang
// TODO: some real checking here...
CheckVarDeclCommon(valParam);
}
- else if(auto constraint = m.As<GenericTypeConstraintDecl>())
+ else if (auto constraint = m.As<GenericTypeConstraintDecl>())
{
CheckGenericConstraintDecl(constraint.Ptr());
}
}
genericDecl->SetCheckState(DeclCheckState::CheckedHeader);
+ }
+
+ void visitGenericDecl(GenericDecl* genericDecl)
+ {
+ checkGenericDeclHeader(genericDecl);
// check the nested declaration
// TODO: this needs to be done in an appropriate environment...
checkDecl(genericDecl->inner);
+ genericDecl->SetCheckState(getCheckedState());
}
void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl * genericConstraintDecl)
{
+ if (genericConstraintDecl->IsChecked(DeclCheckState::CheckedHeader))
+ return;
// check the type being inherited from
auto base = genericConstraintDecl->sup;
base = TranslateTypeNode(base);
@@ -1363,6 +1385,8 @@ namespace Slang
void visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
{
+ if (inheritanceDecl->IsChecked(DeclCheckState::CheckedHeader))
+ return;
// check the type being inherited from
auto base = inheritanceDecl->base;
CheckConstraintSubType(base);
@@ -1556,64 +1580,98 @@ namespace Slang
// anything else, to make sure that scoping works.
for(auto& importDecl : programNode->getMembersOfType<ImportDecl>())
{
- EnsureDecl(importDecl);
+ checkDecl(importDecl);
}
-
- //
-
- for (auto & s : programNode->getMembersOfType<TypeDefDecl>())
- checkDecl(s.Ptr());
- for (auto & s : programNode->getMembersOfType<StructDecl>())
- {
- checkDecl(s.Ptr());
- }
- for (auto & s : programNode->getMembersOfType<ClassDecl>())
- {
- checkDecl(s.Ptr());
- }
- // HACK(tfoley): Visiting all generic declarations here,
- // because otherwise they won't get visited.
+ // register all extensions
+ for (auto & s : programNode->getMembersOfType<ExtensionDecl>())
+ registerExtension(s);
for (auto & g : programNode->getMembersOfType<GenericDecl>())
{
- checkDecl(g.Ptr());
+ if (auto extDecl = g->inner->As<ExtensionDecl>())
+ {
+ checkGenericDeclHeader(g);
+ registerExtension(extDecl);
+ }
}
+ // check types
+ for (auto & s : programNode->getMembersOfType<TypeDefDecl>())
+ checkDecl(s.Ptr());
- for (auto & func : programNode->getMembersOfType<FuncDecl>())
+ for (int pass = 0; pass < 2; pass++)
{
- if (!func->IsChecked(DeclCheckState::Checked))
+ checkingPhase = pass == 0 ? CheckingPhase::Header : CheckingPhase::Body;
+
+ for (auto & s : programNode->getMembersOfType<AggTypeDecl>())
{
- VisitFunctionDeclaration(func.Ptr());
+ checkDecl(s.Ptr());
+ }
+ // HACK(tfoley): Visiting all generic declarations here,
+ // because otherwise they won't get visited.
+ for (auto & g : programNode->getMembersOfType<GenericDecl>())
+ {
+ checkDecl(g.Ptr());
}
- }
- for (auto & func : programNode->getMembersOfType<FuncDecl>())
- {
- EnsureDecl(func);
- }
-
- if (sink->GetErrorCount() != 0)
- return;
-
- // Force everything to be fully checked, just in case
- // Note that we don't just call this on the program,
- // because we'd end up recursing into this very code path...
- for (auto d : programNode->Members)
- {
- EnusreAllDeclsRec(d);
- }
- // Do any semantic checking required on modifiers?
- for (auto d : programNode->Members)
- {
- checkModifiers(d.Ptr());
+ // before checking conformance, make sure we check all the extension bodies
+ // generic extension decls are already checked by the loop above
+ for (auto & s : programNode->getMembersOfType<ExtensionDecl>())
+ checkDecl(s);
+
+ for (auto & func : programNode->getMembersOfType<FuncDecl>())
+ {
+ if (!func->IsChecked(getCheckedState()))
+ {
+ VisitFunctionDeclaration(func.Ptr());
+ }
+ }
+ for (auto & func : programNode->getMembersOfType<FuncDecl>())
+ {
+ checkDecl(func);
+ }
+
+ if (sink->GetErrorCount() != 0)
+ return;
+
+ // Force everything to be fully checked, just in case
+ // Note that we don't just call this on the program,
+ // because we'd end up recursing into this very code path...
+ for (auto d : programNode->Members)
+ {
+ EnusreAllDeclsRec(d);
+ }
+
+ // Do any semantic checking required on modifiers?
+ for (auto d : programNode->Members)
+ {
+ checkModifiers(d.Ptr());
+ }
+
+ if (pass == 0)
+ {
+ // now we can check all interface conformances
+ for (auto & s : programNode->getMembersOfType<AggTypeDecl>())
+ checkAggTypeConformance(s);
+ for (auto & s : programNode->getMembersOfType<ExtensionDecl>())
+ checkExtensionConformance(s);
+ for (auto & g : programNode->getMembersOfType<GenericDecl>())
+ {
+ if (auto innerAggDecl = g->inner->As<AggTypeDecl>())
+ checkAggTypeConformance(innerAggDecl);
+ else if (auto innerExtDecl = g->inner->As<ExtensionDecl>())
+ checkExtensionConformance(innerExtDecl);
+ }
+ }
}
}
void visitStructField(StructField* field)
{
+ if (field->IsChecked(DeclCheckState::CheckedHeader))
+ return;
// TODO: bottleneck through general-case variable checking
field->type = CheckUsableType(field->type);
- field->SetCheckState(DeclCheckState::Checked);
+ field->SetCheckState(getCheckedState());
}
bool doesSignatureMatchRequirement(
@@ -1702,7 +1760,7 @@ namespace Slang
// a typedef, a `struct`, etc.
auto checkSubTypeMember = [&](DeclRef<ContainerDecl> subStructTypeDeclRef) -> bool
{
- EnsureDecl(subStructTypeDeclRef.getDecl());
+ checkDecl(subStructTypeDeclRef.getDecl());
// this is a sub type (e.g. nested struct declaration) in an aggregate type
// check if this sub type declaration satisfies the constraints defined by the associated type
if (auto requiredTypeDeclRef = requiredMemberDeclRef.As<AssocTypeDecl>())
@@ -1877,7 +1935,7 @@ namespace Slang
// We need to check the declaration of the interface
// before we can check that we conform to it.
- EnsureDecl(interfaceDeclRef.getDecl());
+ checkDecl(interfaceDeclRef.getDecl());
// TODO: If we ever allow for implementation inheritance,
// then we will need to consider the case where a type
@@ -1885,8 +1943,8 @@ namespace Slang
// its (non-interface) base types already conforms to
// that interface, so that all of the requirements are
// already satisfied with inherited implementations...
-
- for (auto requiredMemberDeclRef : getMembers(interfaceDeclRef))
+ auto allMembers = getMembersWithExt(interfaceDeclRef);
+ for (auto requiredMemberDeclRef : allMembers)
{
// Some members of the interface don't actually represent
// things that we required of the implementing type.
@@ -1971,37 +2029,23 @@ namespace Slang
return checkConformance(DeclRef<AggTypeDeclBase>(typeDecl, SubstitutionSet()), inheritanceDecl);
}
- void visitAggTypeDecl(AggTypeDecl* decl)
+ void checkExtensionConformance(ExtensionDecl* decl)
{
- if (decl->IsChecked(DeclCheckState::Checked))
- return;
-
- // TODO: we should check inheritance declarations
- // first, since they need to be validated before
- // we can make use of the type (e.g., you need
- // to know that `A` inherits from `B` in order
- // to check an expression like `aValue.bMethod()`
- // where `aValue` is of type `A` but `bMethod`
- // is defined in type `B`.
- //
- // TODO: We should also add a pass that takes
- // all the stated inheritance relationships,
- // expands them to include implicitic inheritance,
- // and then linearizes them. This would allow
- // later passes that need to know everything
- // a type inherits from to proceed linearly
- // through the list, rather than having to
- // recurse (and potentially see the same interface
- // more than once).
-
- decl->SetCheckState(DeclCheckState::CheckedHeader);
-
- // Now check all of the member declarations.
- for (auto member : decl->Members)
+ DeclRef<AggTypeDecl> aggTypeDeclRef;
+ if (auto targetDeclRefType = decl->targetType->As<DeclRefType>())
{
- checkDecl(member);
+ if (aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
+ {
+ for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
+ {
+ checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl);
+ }
+ }
}
+ }
+ void checkAggTypeConformance(AggTypeDecl* decl)
+ {
// After we've checked members, we need to go through
// any inheritance clauses on the type itself, and
// confirm that the type actually provides whatever
@@ -2026,14 +2070,45 @@ namespace Slang
// be required to implement all interface requirements,
// just with `abstract` methods that replicate things?
// (That's what C# does).
-
for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
{
checkConformance(decl, inheritanceDecl);
}
}
+ }
+
+ void visitAggTypeDecl(AggTypeDecl* decl)
+ {
+ if (decl->IsChecked(getCheckedState()))
+ return;
+
+ // TODO: we should check inheritance declarations
+ // first, since they need to be validated before
+ // we can make use of the type (e.g., you need
+ // to know that `A` inherits from `B` in order
+ // to check an expression like `aValue.bMethod()`
+ // where `aValue` is of type `A` but `bMethod`
+ // is defined in type `B`.
+ //
+ // TODO: We should also add a pass that takes
+ // all the stated inheritance relationships,
+ // expands them to include implicitic inheritance,
+ // and then linearizes them. This would allow
+ // later passes that need to know everything
+ // a type inherits from to proceed linearly
+ // through the list, rather than having to
+ // recurse (and potentially see the same interface
+ // more than once).
+
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+
+ // Now check all of the member declarations.
+ for (auto member : decl->Members)
+ {
+ checkDecl(member);
+ }
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
void visitDeclGroup(DeclGroup* declGroup)
@@ -2046,45 +2121,52 @@ namespace Slang
void visitTypeDefDecl(TypeDefDecl* decl)
{
- if (decl->IsChecked(DeclCheckState::Checked)) return;
-
- decl->SetCheckState(DeclCheckState::CheckingHeader);
- decl->type = CheckProperType(decl->type);
- decl->SetCheckState(DeclCheckState::Checked);
+ if (decl->IsChecked(getCheckedState())) return;
+ if (checkingPhase == CheckingPhase::Header)
+ {
+ decl->type = CheckProperType(decl->type);
+ }
+ decl->SetCheckState(getCheckedState());
}
void visitGlobalGenericParamDecl(GlobalGenericParamDecl * decl)
{
- if (decl->IsChecked(DeclCheckState::Checked)) return;
- decl->SetCheckState(DeclCheckState::CheckedHeader);
- // global generic param only allowed in global scope
- auto program = decl->ParentDecl->As<ModuleDecl>();
- if (!program)
- getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly);
- // Now check all of the member declarations.
- for (auto member : decl->Members)
+ if (decl->IsChecked(getCheckedState())) return;
+ if (checkingPhase == CheckingPhase::Header)
{
- checkDecl(member);
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+ // global generic param only allowed in global scope
+ auto program = decl->ParentDecl->As<ModuleDecl>();
+ if (!program)
+ getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly);
+ // Now check all of the member declarations.
+ for (auto member : decl->Members)
+ {
+ checkDecl(member);
+ }
}
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
void visitAssocTypeDecl(AssocTypeDecl* decl)
{
- if (decl->IsChecked(DeclCheckState::Checked)) return;
- decl->SetCheckState(DeclCheckState::CheckedHeader);
+ if (decl->IsChecked(getCheckedState())) return;
+ if (checkingPhase == CheckingPhase::Header)
+ {
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
- // assoctype only allowed in an interface
- auto interfaceDecl = decl->ParentDecl->As<InterfaceDecl>();
- if (!interfaceDecl)
- getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
+ // assoctype only allowed in an interface
+ auto interfaceDecl = decl->ParentDecl->As<InterfaceDecl>();
+ if (!interfaceDecl)
+ getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
- // Now check all of the member declarations.
- for (auto member : decl->Members)
- {
- checkDecl(member);
+ // Now check all of the member declarations.
+ for (auto member : decl->Members)
+ {
+ checkDecl(member);
+ }
}
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
void checkStmt(Stmt* stmt)
@@ -2095,20 +2177,26 @@ namespace Slang
void visitFuncDecl(FuncDecl *functionNode)
{
- if (functionNode->IsChecked(DeclCheckState::Checked))
+ if (functionNode->IsChecked(getCheckedState()))
return;
- VisitFunctionDeclaration(functionNode);
+ if (checkingPhase == CheckingPhase::Header)
+ {
+ VisitFunctionDeclaration(functionNode);
+ }
// TODO: This should really only set "checked header"
- functionNode->SetCheckState(DeclCheckState::Checked);
+ functionNode->SetCheckState(getCheckedState());
- // TODO: should put the checking of the body onto a "work list"
- // to avoid recursion here.
- if (functionNode->Body)
+ if (checkingPhase == CheckingPhase::Body)
{
- this->function = functionNode;
- checkStmt(functionNode->Body);
- this->function = nullptr;
+ // TODO: should put the checking of the body onto a "work list"
+ // to avoid recursion here.
+ if (functionNode->Body)
+ {
+ this->function = functionNode;
+ checkStmt(functionNode->Body);
+ this->function = nullptr;
+ }
}
}
@@ -2888,57 +2976,62 @@ namespace Slang
void visitVariable(Variable* varDecl)
{
- TypeExp typeExp = CheckUsableType(varDecl->type);
-#if 0
- if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable)
+ if (function || checkingPhase == CheckingPhase::Header)
{
- // We don't want to allow bindable resource types as local variables (at least for now).
- auto parentDecl = varDecl->ParentDecl;
- if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl))
+ TypeExp typeExp = CheckUsableType(varDecl->type);
+#if 0
+ if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable)
{
- getSink()->diagnose(varDecl->type, Diagnostics::invalidTypeForLocalVariable);
+ // We don't want to allow bindable resource types as local variables (at least for now).
+ auto parentDecl = varDecl->ParentDecl;
+ if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl))
+ {
+ getSink()->diagnose(varDecl->type, Diagnostics::invalidTypeForLocalVariable);
+ }
}
- }
#endif
- varDecl->type = typeExp;
- if (varDecl->type.Equals(getSession()->getVoidType()))
- {
- if (!isRewriteMode())
+ varDecl->type = typeExp;
+ if (varDecl->type.Equals(getSession()->getVoidType()))
{
- getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid);
+ if (!isRewriteMode())
+ {
+ getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid);
+ }
}
}
- if(auto initExpr = varDecl->initExpr)
+ if (checkingPhase == CheckingPhase::Body)
{
- initExpr = CheckTerm(initExpr);
- varDecl->initExpr = initExpr;
- }
+ if (auto initExpr = varDecl->initExpr)
+ {
+ initExpr = CheckTerm(initExpr);
+ varDecl->initExpr = initExpr;
+ }
- // If this is an array variable, then we first want to give
- // it a chance to infer an array size from its initializer
- //
- // TODO(tfoley): May need to extend this to handle the
- // multi-dimensional case...
- maybeInferArraySizeForVariable(varDecl);
- //
- // Next we want to make sure that the declared (or inferred)
- // size for the array meets whatever language-specific
- // constraints we want to enforce (e.g., disallow empty
- // arrays in specific cases)
- ValidateArraySizeForVariable(varDecl);
+ // If this is an array variable, then we first want to give
+ // it a chance to infer an array size from its initializer
+ //
+ // TODO(tfoley): May need to extend this to handle the
+ // multi-dimensional case...
+ maybeInferArraySizeForVariable(varDecl);
+ //
+ // Next we want to make sure that the declared (or inferred)
+ // size for the array meets whatever language-specific
+ // constraints we want to enforce (e.g., disallow empty
+ // arrays in specific cases)
+ ValidateArraySizeForVariable(varDecl);
- if(auto initExpr = varDecl->initExpr)
- {
- // TODO(tfoley): should coercion of initializer lists be special-cased
- // here, or handled as a general case for coercion?
+ if (auto initExpr = varDecl->initExpr)
+ {
+ // TODO(tfoley): should coercion of initializer lists be special-cased
+ // here, or handled as a general case for coercion?
- initExpr = Coerce(varDecl->type.Ptr(), initExpr);
- varDecl->initExpr = initExpr;
+ initExpr = Coerce(varDecl->type.Ptr(), initExpr);
+ varDecl->initExpr = initExpr;
+ }
}
-
- varDecl->SetCheckState(DeclCheckState::Checked);
+ varDecl->SetCheckState(getCheckedState());
}
void visitWhileStmt(WhileStmt *stmt)
@@ -3469,15 +3562,14 @@ namespace Slang
return expr;
}
-
- //
-
- void visitExtensionDecl(ExtensionDecl* decl)
+ void registerExtension(ExtensionDecl* decl)
{
- if (decl->IsChecked(DeclCheckState::Checked)) return;
+ if (decl->IsChecked(DeclCheckState::CheckedHeader))
+ return;
decl->SetCheckState(DeclCheckState::CheckingHeader);
decl->targetType = CheckProperType(decl->targetType);
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
// TODO: need to check that the target type names a declaration...
@@ -3490,44 +3582,32 @@ namespace Slang
auto aggTypeDecl = aggTypeDeclRef.getDecl();
decl->nextCandidateExtension = aggTypeDecl->candidateExtensions;
aggTypeDecl->candidateExtensions = decl;
- }
- else
- {
- if (!isRewriteMode())
- {
- getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here");
- }
+ return;
}
}
- else if (decl->targetType->Equals(getSession()->getErrorType()))
+ if (!isRewriteMode())
{
- // there was an error, so ignore
+ getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here");
}
- else
+ }
+
+ void visitExtensionDecl(ExtensionDecl* decl)
+ {
+ if (decl->IsChecked(getCheckedState())) return;
+
+ if (!decl->targetType->As<DeclRefType>())
{
if (!isRewriteMode())
{
getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here");
}
}
-
- decl->SetCheckState(DeclCheckState::CheckedHeader);
-
// now check the members of the extension
for (auto m : decl->Members)
{
- EnsureDecl(m);
- }
-
- if (aggTypeDeclRef)
- {
- for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
- {
- checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl);
- }
+ checkDecl(m);
}
-
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
// Figure out what type an initializer/constructor declaration
@@ -3572,31 +3652,31 @@ namespace Slang
void visitConstructorDecl(ConstructorDecl* decl)
{
- if (decl->IsChecked(DeclCheckState::Checked)) return;
- decl->SetCheckState(DeclCheckState::CheckingHeader);
-
- for (auto& paramDecl : decl->GetParameters())
+ if (decl->IsChecked(getCheckedState())) return;
+ if (checkingPhase == CheckingPhase::Header)
{
- paramDecl->type = CheckUsableType(paramDecl->type);
- }
-
- // We need to compute the result tyep for this declaration,
- // since it wasn't filled in for us.
- decl->ReturnType.type = findResultTypeForConstructorDecl(decl);
-
+ decl->SetCheckState(DeclCheckState::CheckingHeader);
- decl->SetCheckState(DeclCheckState::CheckedHeader);
+ for (auto& paramDecl : decl->GetParameters())
+ {
+ paramDecl->type = CheckUsableType(paramDecl->type);
+ }
- // TODO(tfoley): check body
- decl->SetCheckState(DeclCheckState::Checked);
+ // We need to compute the result tyep for this declaration,
+ // since it wasn't filled in for us.
+ decl->ReturnType.type = findResultTypeForConstructorDecl(decl);
+ }
+ else
+ {
+ // TODO(tfoley): check body
+ }
+ decl->SetCheckState(getCheckedState());
}
void visitSubscriptDecl(SubscriptDecl* decl)
{
- if (decl->IsChecked(DeclCheckState::Checked)) return;
- decl->SetCheckState(DeclCheckState::CheckingHeader);
-
+ if (decl->IsChecked(getCheckedState())) return;
for (auto& paramDecl : decl->GetParameters())
{
paramDecl->type = CheckUsableType(paramDecl->type);
@@ -3604,8 +3684,6 @@ namespace Slang
decl->ReturnType = CheckUsableType(decl->ReturnType);
- decl->SetCheckState(DeclCheckState::CheckedHeader);
-
// If we have a subscript declaration with no accessor declarations,
// then we should create a single `GetterDecl` to represent
// the implicit meaning of their declaration, so:
@@ -3637,31 +3715,34 @@ namespace Slang
checkDecl(mm);
}
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
void visitAccessorDecl(AccessorDecl* decl)
{
- // An acessor must appear nested inside a subscript declaration (today),
- // or a property declaration (when we add them). It will derive
- // its return type from the outer declaration, so we handle both
- // of these checks at the same place.
- auto parent = decl->ParentDecl;
- if(auto parentSubscript = dynamic_cast<SubscriptDecl*>(parent))
+ if (checkingPhase == CheckingPhase::Header)
{
- decl->ReturnType = parentSubscript->ReturnType;
+ // An acessor must appear nested inside a subscript declaration (today),
+ // or a property declaration (when we add them). It will derive
+ // its return type from the outer declaration, so we handle both
+ // of these checks at the same place.
+ auto parent = decl->ParentDecl;
+ if (auto parentSubscript = dynamic_cast<SubscriptDecl*>(parent))
+ {
+ decl->ReturnType = parentSubscript->ReturnType;
+ }
+ // TODO: when we add "property" declarations, check for them here
+ else
+ {
+ getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty);
+ }
+
}
- // TODO: when we add "property" declarations, check for them here
else
{
- getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty);
+ // TODO: check the body!
}
-
- decl->SetCheckState(DeclCheckState::CheckedHeader);
-
- // TODO: check the body!
-
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
@@ -3815,7 +3896,7 @@ namespace Slang
{
for( auto inheritanceDeclRef : getMembersOfTypeWithExt<InheritanceDecl>(aggTypeDeclRef))
{
- EnsureDecl(inheritanceDeclRef.getDecl());
+ checkDecl(inheritanceDeclRef.getDecl());
// Here we will recursively look up conformance on the type
// that is being inherited from. This is dangerous because
@@ -3848,7 +3929,7 @@ namespace Slang
// if an inheritance decl is not found, try to find a GenericTypeConstraintDecl
for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef))
{
- EnsureDecl(genConstraintDeclRef.getDecl());
+ checkDecl(genConstraintDeclRef.getDecl());
auto inheritedType = GetSup(genConstraintDeclRef);
TypeWitnessBreadcrumb breadcrumb;
breadcrumb.prev = inBreadcrumbs;
@@ -4978,7 +5059,7 @@ namespace Slang
OverloadResolveContext& context)
{
auto funcDecl = funcDeclRef.getDecl();
- EnsureDecl(funcDecl);
+ checkDecl(funcDecl);
// If this function is a redeclaration,
// then we don't want to include it multiple times,
@@ -5040,7 +5121,7 @@ namespace Slang
OverloadResolveContext& context,
RefPtr<Type> resultType)
{
- EnsureDecl(ctorDeclRef.getDecl());
+ checkDecl(ctorDeclRef.getDecl());
// `typeItem` refers to the type being constructed (the thing
// that was applied as a function) so we need to construct
@@ -5381,7 +5462,7 @@ namespace Slang
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context)
{
- EnsureDecl(genericDeclRef.getDecl());
+ checkDecl(genericDeclRef.getDecl());
ConstraintSystem constraints;
constraints.genericDecl = genericDeclRef.getDecl();
@@ -5980,7 +6061,7 @@ namespace Slang
{
if (auto genericDeclRef = baseItem.declRef.As<GenericDecl>())
{
- EnsureDecl(genericDeclRef.getDecl());
+ checkDecl(genericDeclRef.getDecl());
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Generic;
@@ -6168,8 +6249,6 @@ namespace Slang
{
// check the base expression first
expr->FunctionExpr = CheckExpr(expr->FunctionExpr);
-
-
// Next check the argument expressions
for (auto & arg : expr->Arguments)
{
@@ -6629,7 +6708,7 @@ namespace Slang
void visitImportDecl(ImportDecl* decl)
{
- if(decl->IsChecked(DeclCheckState::Checked))
+ if(decl->IsChecked(DeclCheckState::CheckedHeader))
return;
// We need to look for a module with the specified name
@@ -6653,7 +6732,7 @@ namespace Slang
importModuleIntoScope(scope.Ptr(), importedModuleDecl.Ptr());
- decl->SetCheckState(DeclCheckState::Checked);
+ decl->SetCheckState(getCheckedState());
}
// Perform semantic checking of an object-oriented `this`
@@ -6669,7 +6748,7 @@ namespace Slang
auto containerDecl = scope->containerDecl;
if (auto aggTypeDecl = containerDecl->As<AggTypeDecl>())
{
- EnsureDecl(aggTypeDecl);
+ checkDecl(aggTypeDecl);
// Okay, we are using `this` in the context of an
// aggregate type, so the expression should be
@@ -6681,7 +6760,7 @@ namespace Slang
}
else if (auto extensionDecl = containerDecl->As<ExtensionDecl>())
{
- EnsureDecl(extensionDecl);
+ checkDecl(extensionDecl);
// When `this` is used in the context of an `extension`
// declaration, then it should refer to an instance of
@@ -6918,7 +6997,7 @@ namespace Slang
{
if( sema )
{
- sema->EnsureDecl(declRef.getDecl());
+ sema->checkDecl(declRef.getDecl());
}
// We need to insert an appropriate type for the expression, based on
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 24e8bc713..7227156fa 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -198,6 +198,7 @@ DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of
DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.")
// TODO: need to assign numbers to all these extra diagnostics...
+DIAGNOSTIC(39999, Error, cyclicReference, "cyclic reference '$0'.")
DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')")
DIAGNOSTIC(39999, Error, expectedIntegerConstantNotConstant, "expression does not evaluate to a compile-time constant")
diff --git a/source/slang/diagnostics.cpp b/source/slang/diagnostics.cpp
index 4f4a33e60..64713072d 100644
--- a/source/slang/diagnostics.cpp
+++ b/source/slang/diagnostics.cpp
@@ -62,7 +62,10 @@ void printDiagnosticArg(StringBuilder& sb, TypeExp const& type)
void printDiagnosticArg(StringBuilder& sb, QualType const& type)
{
- sb << type.type->ToString();
+ if (type.type)
+ sb << type.type->ToString();
+ else
+ sb << "<null>";
}
void printDiagnosticArg(StringBuilder& sb, TokenType tokenType)
diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp
index 0791c508b..5f925a1c2 100644
--- a/source/slang/lookup.cpp
+++ b/source/slang/lookup.cpp
@@ -302,15 +302,26 @@ void DoLocalLookupImpl(
// for interface decls, also lookup in the base interfaces
if (request.semantics)
{
- if (auto interfaceDeclRef = containerDeclRef.As<InterfaceDecl>())
+ bool isInterface = containerDeclRef.As<InterfaceDecl>() ? true : false;
+ // if we are looking at an extension, find the target decl that we are extending
+ if (auto extDeclRef = containerDeclRef.As<ExtensionDecl>())
{
- auto baseInterfaces = getMembersOfType<InheritanceDecl>(interfaceDeclRef);
+ auto targetDeclRefType = extDeclRef.getDecl()->targetType->AsDeclRefType();
+ SLANG_ASSERT(targetDeclRefType);
+ int diff = 0;
+ auto targetDeclRef = targetDeclRefType->declRef.As<ContainerDecl>().SubstituteImpl(containerDeclRef.substitutions, &diff);
+ isInterface = targetDeclRef.As<InterfaceDecl>() ? true : false;
+ }
+ // if we are looking inside an interface decl, try find in the interfaces it inherits from
+ if (isInterface)
+ {
+ auto baseInterfaces = getMembersOfType<InheritanceDecl>(containerDeclRef);
for (auto inheritanceDeclRef : baseInterfaces)
{
auto baseType = inheritanceDeclRef.getDecl()->base.type.As<DeclRefType>();
SLANG_ASSERT(baseType);
int diff = 0;
- auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(interfaceDeclRef.substitutions, &diff);
+ auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(containerDeclRef.substitutions, &diff);
DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As<ContainerDecl>(), request, result, inBreadcrumbs);
}
}
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index ab26b1f6d..93e421977 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -1098,32 +1098,52 @@ namespace Slang
// Declarations
//
+ inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
+ {
+ return declRef.getDecl()->candidateExtensions;
+ }
+
inline FilteredMemberRefList<Decl> getMembers(DeclRef<ContainerDecl> const& declRef)
{
return FilteredMemberRefList<Decl>(declRef.getDecl()->Members, declRef.substitutions);
}
- template<typename T>
- inline FilteredMemberRefList<T> getMembersOfType(DeclRef<ContainerDecl> const& declRef)
+ // TODO: change this to return a lazy list instead of constructing actual list
+ inline List<DeclRef<Decl>> getMembersWithExt(DeclRef<ContainerDecl> const& declRef)
{
- return FilteredMemberRefList<T>(declRef.getDecl()->Members, declRef.substitutions);
+ List<DeclRef<Decl>> rs;
+ for (auto d : FilteredMemberRefList<Decl>(declRef.getDecl()->Members, declRef.substitutions))
+ rs.Add(d);
+ if (auto aggDeclRef = declRef.As<AggTypeDecl>())
+ {
+ for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension)
+ {
+ for (auto mbr : getMembers(DeclRef<ContainerDecl>(ext, declRef.substitutions)))
+ rs.Add(mbr);
+ }
+ }
+ return rs;
}
- inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
+ template<typename T>
+ inline FilteredMemberRefList<T> getMembersOfType(DeclRef<ContainerDecl> const& declRef)
{
- return declRef.getDecl()->candidateExtensions;
+ return FilteredMemberRefList<T>(declRef.getDecl()->Members, declRef.substitutions);
}
template<typename T>
- inline FilteredMemberRefList<T> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef)
+ inline List<DeclRef<T>> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef)
{
- auto rs = getMembersOfType<T>(declRef);
+ List<DeclRef<T>> rs;
+ for (auto d : getMembersOfType<T>(declRef))
+ rs.Add(d);
if (auto aggDeclRef = declRef.As<AggTypeDecl>())
{
for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension)
{
auto extMembers = getMembersOfType<T>(DeclRef<ContainerDecl>(ext, declRef.substitutions));
- const_cast<List<RefPtr<Decl>>&>(rs.decls).AddRange(extMembers.decls);
+ for (auto mbr : extMembers)
+ rs.Add(mbr);
}
}
return rs;
diff --git a/tests/compute/extension-on-interface.slang b/tests/compute/extension-on-interface.slang
new file mode 100644
index 000000000..1d3fb5e30
--- /dev/null
+++ b/tests/compute/extension-on-interface.slang
@@ -0,0 +1,49 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IOp
+{
+ float addf(float u, float v);
+}
+
+interface ISub
+{
+ float subf(float u, float v);
+}
+
+extension IOp : ISub
+{
+}
+
+struct Simple : IOp
+{
+ float base;
+ float addf(float u, float v)
+ {
+ return u+v;
+ }
+};
+
+__extension Simple : ISub
+{
+ float subf(float u, float v)
+ {
+ return base+u-v;
+ }
+};
+
+float testAddSub<T:IOp>(T t)
+{
+ return t.subf(t.addf(1.0, 1.0), 1.0);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ Simple s;
+ s.base = 0.0;
+ float outVal = testAddSub(s);
+ outputBuffer[dispatchThreadID.x] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/extension-on-interface.slang.expected.txt b/tests/compute/extension-on-interface.slang.expected.txt
new file mode 100644
index 000000000..cc5e55ab6
--- /dev/null
+++ b/tests/compute/extension-on-interface.slang.expected.txt
@@ -0,0 +1,4 @@
+3F800000
+3F800000
+3F800000
+3F800000