summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-25 11:49:07 -0700
committerGitHub <noreply@github.com>2024-07-25 11:49:07 -0700
commit1343ab79fcd0ff9e5ffebbcf95414e51ab19e9cd (patch)
treeccb40045b16e83aaac126b7a47cff46b0f02ecf4 /source/slang
parent3c03d279ee4ccf4796901f4ea6640787d341d11d (diff)
Fix around extensions and `IDifferentiable` requirement synthesis. (#4729)
* Check extensions before function parameters. Fix decl ref formation for synthesized differentiable requirements that are inside an extension. * Fix clang errors. * More clang fix. * Fix warnings. * Fix build error. * Fix. * Fix typo.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-builder.h6
-rw-r--r--source/slang/slang-ast-type.cpp2
-rw-r--r--source/slang/slang-check-decl.cpp187
3 files changed, 134 insertions, 61 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 029c24216..5b4ec5538 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -336,7 +336,7 @@ public:
case ASTNodeType::ThisTypeDecl:
case ASTNodeType::ExtensionDecl:
case ASTNodeType::AssocTypeDecl:
- return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl);
+ return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl).template as<T>();
default:
break;
}
@@ -396,13 +396,13 @@ public:
return getOrCreate<GenericAppDeclRef>(innerDecl, genericDeclRef, args);
}
- LookupDeclRef* getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup)
+ DeclRef<Decl> getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup)
{
auto result = getOrCreate<LookupDeclRef>(declToLookup, base, subtypeWitness);
return result;
}
- LookupDeclRef* getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup)
+ DeclRef<Decl> getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup)
{
return getLookupDeclRef(subtypeWitness->getSub(), subtypeWitness, declToLookup);
}
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 2dc746e09..47cd68b9e 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -559,7 +559,7 @@ DeclRef<ThisTypeDecl> ExtractExistentialType::getThisTypeDeclRef()
}
SLANG_ASSERT(thisTypeDecl);
- DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl);
+ DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl).as<ThisTypeDecl>();
this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef;
return specialiedInterfaceDeclRef;
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 969c87981..cd25e9d66 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2247,16 +2247,35 @@ namespace Slang
// from requirementDeclRef to get the generic arguments for the outer generic, and
// apply it to the newly synthesized decl.
SubstitutionSet substSet;
+ Type* thisType = nullptr;
if (auto thisWitness = findThisTypeWitness(
SubstitutionSet(requirementDeclRef),
as<InterfaceDecl>(requirementDeclRef.getParent()).getDecl()))
{
- if (auto declRefType = as<DeclRefType>(thisWitness->getSub()))
+ thisType = thisWitness->getSub();
+ if (auto declRefType = as<DeclRefType>(thisType))
{
substSet = SubstitutionSet(declRefType->getDeclRef());
}
}
- auto satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl));
+ if (!substSet.declRef)
+ return false;
+ Type* satisfyingType = nullptr;
+ if (substSet.declRef->getDecl() == context->parentDecl)
+ {
+ // The type we are synthesizing conformance for is direct inside a type itself.
+ // We need to copy the outer generic arguments to the synthesized type.
+ satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl));
+ }
+ else if (auto parentExtDecl = as<ExtensionDecl>(context->parentDecl))
+ {
+ // The type is defined in an extension, we need to form a declref to the parent
+ // extension from the requirementDeclRef.
+ auto extDeclRef = applyExtensionToType(parentExtDecl, thisType);
+ satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(extDeclRef, aggTypeDecl));
+ }
+ if (!satisfyingType)
+ return false;
// Helper function to add a `diffType` field into the synthesized type for the original
// `member`.
@@ -2683,6 +2702,23 @@ namespace Slang
_registerBuiltinDeclsRec(session, decl);
}
+ void discoverExtensionDecls(List<ExtensionDecl*>& decls, Decl* parent)
+ {
+ if (auto extDecl = as<ExtensionDecl>(parent))
+ decls.add(extDecl);
+ if (auto containerDecl = as<ContainerDecl>(parent))
+ {
+ for (auto child : containerDecl->members)
+ {
+ discoverExtensionDecls(decls, child);
+ }
+ }
+ if (auto genericDecl = as<GenericDecl>(parent))
+ {
+ discoverExtensionDecls(decls, genericDecl->inner);
+ }
+ }
+
void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl)
{
// When we are dealing with code from the standard library,
@@ -2824,6 +2860,23 @@ namespace Slang
DeclCheckState::DefinitionChecked,
DeclCheckState::CapabilityChecked,
};
+
+ // Discover and check all extension decls before anything else.
+ List<ExtensionDecl*> extensionDecls;
+ discoverExtensionDecls(extensionDecls, moduleDecl);
+ for (auto s : states)
+ {
+ for (auto extensionDecl : extensionDecls)
+ {
+ ensureDecl(extensionDecl, s);
+ }
+ // We only need to check extension decls up to ReadyForLookup
+ // so they are properly registered in type inheritance infos.
+ if (s == DeclCheckState::ReadyForLookup)
+ break;
+ }
+
+ // With extensions taken care of, we can now check the remaining decls.
for(auto s : states)
{
// When advancing to state `s` we will recursively
@@ -5183,12 +5236,12 @@ namespace Slang
}
else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>())
{
- synFunc = as<FuncDecl>(synthesizeMethodSignatureForRequirementWitness(
- context, funcDeclRef, synArgs, synThis));
+ synFunc = as<FuncDecl>(synthesizeMethodSignatureForRequirementWitness(
+ context, funcDeclRef, synArgs, synThis));
}
-
+
SLANG_ASSERT(synFunc);
-
+
addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>());
if (synGeneric)
@@ -5231,49 +5284,49 @@ namespace Slang
switch (pattern)
{
- case SynthesisPattern::AllInductive:
+ case SynthesisPattern::AllInductive:
+ {
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+
+ memberExpr->name = derivMemberName;
+
+ paramFields.add(memberExpr);
+ inductiveArgMask.add(true);
+ }
+ break;
+ }
+ case SynthesisPattern::FixedFirstArg:
+ {
+ int paramIndex = 0;
+ for (auto arg : synArgs)
{
- for (auto arg : synArgs)
+ if (paramIndex == 0)
+ {
+ paramFields.add(arg);
+ inductiveArgMask.add(false);
+
+ paramIndex++;
+ }
+ else
{
auto memberExpr = m_astBuilder->create<MemberExpr>();
memberExpr->baseExpression = arg;
memberExpr->name = derivMemberName;
-
paramFields.add(memberExpr);
inductiveArgMask.add(true);
- }
- break;
- }
- case SynthesisPattern::FixedFirstArg:
- {
- int paramIndex = 0;
- for (auto arg : synArgs)
- {
- if (paramIndex == 0)
- {
- paramFields.add(arg);
- inductiveArgMask.add(false);
-
- paramIndex++;
- }
- else
- {
- auto memberExpr = m_astBuilder->create<MemberExpr>();
- memberExpr->baseExpression = arg;
-
- memberExpr->name = derivMemberName;
- paramFields.add(memberExpr);
- inductiveArgMask.add(true);
- paramIndex++;
- }
+ paramIndex++;
}
- break;
}
- default:
- SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern");
- break;
+ break;
+ }
+ default:
+ SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern");
+ break;
}
// Invoke the method for the field and assign the value to resultVar.
@@ -5294,8 +5347,9 @@ namespace Slang
auto synReturn = m_astBuilder->create<ReturnStmt>();
synReturn->expression = resultVarExpr;
seqStmt->stmts.add(synReturn);
-
- context->parentDecl->members.add(synFunc);
+
+ Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc;
+ context->parentDecl->members.add(witnessDecl);
context->parentDecl->invalidateMemberDictionary();
addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>());
@@ -5313,21 +5367,29 @@ namespace Slang
substSet = SubstitutionSet(declRefType->getDeclRef());
}
}
- if (auto outerGeneric = GetOuterGeneric(context->parentDecl))
+ if (!substSet.declRef)
+ return false;
+ DeclRef<Decl> synthesizedWitnessDeclRef;
+ if (auto parentExtDecl = as<ExtensionDecl>(context->parentDecl))
{
- // If the context->parentDecl is not the same as ThisType represented by genApp, then it must be an extension
- // to ThisType. In this case, we need to form a new GenericAppDeclRef to specailizethe outer parent extension
- // decl. Note that the extension might be a partial extension with some generic arguments missing, and
- // we can't support that case right now. For now we can just assume the extension will have the same set
- // of generic parameters as the target type.
- auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, outerGeneric);
- auto specializedParent = m_astBuilder->getGenericAppDeclRef(makeDeclRef(outerGeneric), defaultArgs.getArrayView());
- auto specializedFunc = m_astBuilder->getMemberDeclRef(specializedParent, synFunc);
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(specializedFunc));
- return true;
- }
+ // If the conformance is declared on an extension to ThisType,
+ // we need to form a new proper decl ref to the parent extension decl
+ // with the correct specialization arguments.
+ //
+ if (GetOuterGeneric(context->parentDecl))
+ {
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc)));
+ auto extDeclRef = applyExtensionToType(parentExtDecl, context->conformingType);
+ synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(extDeclRef, witnessDecl);
+ }
+ }
+ else
+ {
+ synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(substSet.declRef, witnessDecl);
+ }
+ if (!synthesizedWitnessDeclRef)
+ synthesizedWitnessDeclRef = m_astBuilder->getDirectDeclRef(witnessDecl);
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(synthesizedWitnessDeclRef));
return true;
}
@@ -5351,6 +5413,15 @@ namespace Slang
// with the same name in the type declaration and
// its (known) extensions.
+ // The exception to that is when the requiredMemberDeclRef is already
+ // resolved to the actual satisfying decl, in which case we simply return
+ // true without any further lookup.
+ if (!as<InterfaceDecl>(requiredMemberDeclRef.getParent().getDecl()))
+ return true;
+
+ // If `requiredMemberDeclRef` is a lookup decl ref for an interface requirement
+ // we attempt to do the loopkup through witness tables.
+ //
// As a first pass, lets check if we already have a
// witness in the table for the requirement, so
// that we can bail out early.
@@ -5655,7 +5726,7 @@ namespace Slang
subType,
superInterfaceType,
inheritanceDecl,
- thisTypeDeclRef,
+ superInterfaceDeclRef,
requiredMemberDeclRef,
witnessTable,
subTypeConformsToSuperInterfaceWitness);
@@ -5674,7 +5745,7 @@ namespace Slang
subType,
superInterfaceType,
inheritanceDecl,
- thisTypeDeclRef,
+ superInterfaceDeclRef,
requiredMemberDeclRef,
witnessTable,
subTypeConformsToSuperInterfaceWitness);
@@ -5726,7 +5797,7 @@ namespace Slang
subType,
superInterfaceType,
inheritanceDecl,
- thisTypeDeclRef,
+ superInterfaceDeclRef,
requiredInheritanceDeclRef,
witnessTable,
subTypeConformsToSuperInterfaceWitness);
@@ -8551,7 +8622,9 @@ namespace Slang
// Looks like we have a match in the types,
// now let's see if `type`'s declref starts with a Lookup.
targetType = type;
- extDeclRef = m_astBuilder->getLookupDeclRef(thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl());
+ extDeclRef = m_astBuilder->getLookupDeclRef(
+ thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl())
+ .as<ExtensionDecl>();
}
}
}