summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-05-02 19:01:21 -0400
committerGitHub <noreply@github.com>2024-05-02 16:01:21 -0700
commitc763750a7305fbf12c1f5c177260294a32fe286d (patch)
treeec544212bd5bf0ac72cbe3c01d2ed85fc3f05cd9 /source
parente5d49cf21db7a398afe6cfdb76f6b4a028e9eecb (diff)
Handle case where types can be used as their own `Differential` type. (#4057)
* Avoid synthesis for when types can be used as their own differenial + Add test * Add missing files.. * Fix issue with method synthesis for self-differential types + Add a generic test * Fix * Fix issue with out-of-date type resolution cache. Witness tables created during the conformance checking phase not being taken into account during the decl type resolution phase because the epoch is not updated after conformance checking. This leads to certain complex associated-type lookup chains (such as the one in tests/compute/assoctype-nested-lookup) not resolving properly and causing errors. * Delete self-differential-type-synthesis-extension.slang * Quick fix to repopulate stdlib cache for deferred stdlib loading * Update slang-check-decl.cpp
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp44
-rw-r--r--source/slang/slang-check-expr.cpp120
-rw-r--r--source/slang/slang-check-impl.h3
-rw-r--r--source/slang/slang-options.cpp4
4 files changed, 144 insertions, 27 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 921bd38e9..0d089874e 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2144,7 +2144,40 @@ namespace Slang
SLANG_RELEASE_ASSERT(aggTypeDecl);
synth.pushContainerScope(aggTypeDecl);
}
- else
+
+ // If we did not find an existing empty struct, we may need to synthesize one.
+ // But first, we check if the parent type can be used as its own differential type.
+ //
+ if (!aggTypeDecl
+ && as<AggTypeDecl>(context->parentDecl)
+ && canStructBeUsedAsSelfDifferentialType(as<AggTypeDecl>(context->parentDecl)))
+ {
+ // If the parent type can be used as its own differential type, we will create a typealias
+ // to itself as the differential type.
+ //
+ auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
+ assocTypeDef->nameAndLoc.name = getName("Differential");
+ assocTypeDef->type.type = context->conformingType;
+ assocTypeDef->parentDecl = context->parentDecl;
+ assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
+ context->parentDecl->members.add(assocTypeDef);
+
+ markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType);
+
+ if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable))
+ {
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType));
+
+ // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types.
+ m_astBuilder->incrementEpoch();
+ return true;
+ }
+
+ // Something went wrong.
+ return false;
+ }
+
+ if (!aggTypeDecl)
{
aggTypeDecl = m_astBuilder->create<StructDecl>();
aggTypeDecl->parentDecl = context->parentDecl;
@@ -5741,6 +5774,15 @@ namespace Slang
{
checkConformance(type, inheritanceDecl, decl);
}
+
+ // Successful conformance checking may have created new witness tables.
+ // Increment epoch to invalidate the cache, so subsequent canonical types are
+ // re-calculated.
+ //
+ // TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate only the
+ // types that are affected by these interface decls.
+ //
+ astBuilder->incrementEpoch();
}
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a399ea389..ff74f9a62 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -598,37 +598,60 @@ namespace Slang
{
case BuiltinRequirementKind::DifferentialType:
{
- auto structDecl = m_astBuilder->create<StructDecl>();
- auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
- conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
- conformanceDecl->parentDecl = structDecl;
- structDecl->members.add(conformanceDecl);
- structDecl->parentDecl = parent;
-
- synthesizedDecl = structDecl;
- auto typeDef = m_astBuilder->create<TypeAliasDecl>();
- typeDef->nameAndLoc.name = getName("Differential");
- typeDef->parentDecl = structDecl;
-
- auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
-
- typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
- structDecl->members.add(typeDef);
+ if (!canStructBeUsedAsSelfDifferentialType(parent))
+ {
+ // Need to create a new struct type for the differential.
+ //
+ auto structDecl = m_astBuilder->create<StructDecl>();
+ auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
+ conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
+ conformanceDecl->parentDecl = structDecl;
+ structDecl->members.add(conformanceDecl);
+ structDecl->parentDecl = parent;
+
+ synthesizedDecl = structDecl;
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = getName("Differential");
+ typeDef->parentDecl = structDecl;
+
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
+
+ typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
+ structDecl->members.add(typeDef);
+
+ synthesizedDecl->parentDecl = parent;
+ synthesizedDecl->nameAndLoc.name = item.declRef.getName();
+ synthesizedDecl->loc = parent->loc;
+ parent->members.add(synthesizedDecl);
+ parent->invalidateMemberDictionary();
+
+ // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
+ // from user-provided definitions, and proceed to fill in its definition.
+ auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
+ addModifier(synthesizedDecl, toBeSynthesized);
+ }
+ else
+ {
+ // There's no need for a new struct decl.
+ // We can simply add a typealias to the existing concrete type.
+ //
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = item.declRef.getName();
+ typeDef->parentDecl = parent;
+ typeDef->type.type = subType;
+
+ synthesizedDecl = parent;
+
+ parent->members.add(typeDef);
+ parent->invalidateMemberDictionary();
+
+ markSelfDifferentialMembersOfType(parent, subType);
+ }
}
break;
default:
return nullptr;
}
- synthesizedDecl->parentDecl = parent;
- synthesizedDecl->nameAndLoc.name = item.declRef.getName();
- synthesizedDecl->loc = parent->loc;
- parent->members.add(synthesizedDecl);
- parent->invalidateMemberDictionary();
-
- // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
- // from user-provided definitions, and proceed to fill in its definition.
- auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
- addModifier(synthesizedDecl, toBeSynthesized);
auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
return ConstructDeclRefExpr(
@@ -1145,6 +1168,51 @@ namespace Slang
return nullptr;
}
+ bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl)
+ {
+ // A struct can be used as its own differential type if all its members are differentiable
+ // and their differential types are the same as the original types.
+ //
+ bool canBeUsed = true;
+ for (auto member : aggTypeDecl->members)
+ {
+ if (auto varDecl = as<VarDecl>(member))
+ {
+ // Try to get the differential type of the member.
+ Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType());
+ if (!diffType || !diffType->equals(varDecl->getType()))
+ {
+ canBeUsed = false;
+ break;
+ }
+ }
+ }
+ return canBeUsed;
+ }
+
+ void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type)
+ {
+ // TODO: Handle extensions.
+ // Add derivative member attributes to all the fields pointing to themselves.
+ for (auto member : parent->getMembersOfType<VarDeclBase>())
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = member->getType();
+
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = type;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type);
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+
+ fieldLookupExpr->declRef = makeDeclRef(member);
+
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
+ }
+
Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
{
auto result = tryGetDifferentialType(builder, type);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index d90c3c4b0..fc87c680b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1332,6 +1332,9 @@ namespace Slang
Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
Type* tryGetDifferentialType(ASTBuilder* builder, Type* type);
+ // Helper function to check if a struct can be used as its own differential type.
+ bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl);
+ void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type);
public:
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index 857f4272c..805ea0fff 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -1707,6 +1707,10 @@ SlangResult OptionsParser::_parse(
ScopedAllocation contents;
SLANG_RETURN_ON_FAIL(File::readAllBytes(fileName.value, contents));
SLANG_RETURN_ON_FAIL(m_session->loadStdLib(contents.getData(), contents.getSizeInBytes()));
+
+ // Ensure that the linkage's AST builder is up-to-date.
+ linkage->getASTBuilder()->m_cachedNodes = asInternal(m_session)->getGlobalASTBuilder()->m_cachedNodes;
+
break;
}
case OptionKind::CompileStdLib: m_compileStdLib = true; break;