summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-05-02 19:01:21 -0400
committerGitHub <noreply@github.com>2024-05-02 16:01:21 -0700
commitc763750a7305fbf12c1f5c177260294a32fe286d (patch)
treeec544212bd5bf0ac72cbe3c01d2ed85fc3f05cd9 /source/slang/slang-check-expr.cpp
parente5d49cf21db7a398afe6cfdb76f6b4a028e9eecb (diff)
Handle case where types can be used as their own `Differential` type. (#4057)
* Avoid synthesis for when types can be used as their own differenial + Add test * Add missing files.. * Fix issue with method synthesis for self-differential types + Add a generic test * Fix * Fix issue with out-of-date type resolution cache. Witness tables created during the conformance checking phase not being taken into account during the decl type resolution phase because the epoch is not updated after conformance checking. This leads to certain complex associated-type lookup chains (such as the one in tests/compute/assoctype-nested-lookup) not resolving properly and causing errors. * Delete self-differential-type-synthesis-extension.slang * Quick fix to repopulate stdlib cache for deferred stdlib loading * Update slang-check-decl.cpp
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp120
1 files changed, 94 insertions, 26 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a399ea389..ff74f9a62 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -598,37 +598,60 @@ namespace Slang
{
case BuiltinRequirementKind::DifferentialType:
{
- auto structDecl = m_astBuilder->create<StructDecl>();
- auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
- conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
- conformanceDecl->parentDecl = structDecl;
- structDecl->members.add(conformanceDecl);
- structDecl->parentDecl = parent;
-
- synthesizedDecl = structDecl;
- auto typeDef = m_astBuilder->create<TypeAliasDecl>();
- typeDef->nameAndLoc.name = getName("Differential");
- typeDef->parentDecl = structDecl;
-
- auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
-
- typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
- structDecl->members.add(typeDef);
+ if (!canStructBeUsedAsSelfDifferentialType(parent))
+ {
+ // Need to create a new struct type for the differential.
+ //
+ auto structDecl = m_astBuilder->create<StructDecl>();
+ auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
+ conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
+ conformanceDecl->parentDecl = structDecl;
+ structDecl->members.add(conformanceDecl);
+ structDecl->parentDecl = parent;
+
+ synthesizedDecl = structDecl;
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = getName("Differential");
+ typeDef->parentDecl = structDecl;
+
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
+
+ typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
+ structDecl->members.add(typeDef);
+
+ synthesizedDecl->parentDecl = parent;
+ synthesizedDecl->nameAndLoc.name = item.declRef.getName();
+ synthesizedDecl->loc = parent->loc;
+ parent->members.add(synthesizedDecl);
+ parent->invalidateMemberDictionary();
+
+ // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
+ // from user-provided definitions, and proceed to fill in its definition.
+ auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
+ addModifier(synthesizedDecl, toBeSynthesized);
+ }
+ else
+ {
+ // There's no need for a new struct decl.
+ // We can simply add a typealias to the existing concrete type.
+ //
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = item.declRef.getName();
+ typeDef->parentDecl = parent;
+ typeDef->type.type = subType;
+
+ synthesizedDecl = parent;
+
+ parent->members.add(typeDef);
+ parent->invalidateMemberDictionary();
+
+ markSelfDifferentialMembersOfType(parent, subType);
+ }
}
break;
default:
return nullptr;
}
- synthesizedDecl->parentDecl = parent;
- synthesizedDecl->nameAndLoc.name = item.declRef.getName();
- synthesizedDecl->loc = parent->loc;
- parent->members.add(synthesizedDecl);
- parent->invalidateMemberDictionary();
-
- // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
- // from user-provided definitions, and proceed to fill in its definition.
- auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
- addModifier(synthesizedDecl, toBeSynthesized);
auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
return ConstructDeclRefExpr(
@@ -1145,6 +1168,51 @@ namespace Slang
return nullptr;
}
+ bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl)
+ {
+ // A struct can be used as its own differential type if all its members are differentiable
+ // and their differential types are the same as the original types.
+ //
+ bool canBeUsed = true;
+ for (auto member : aggTypeDecl->members)
+ {
+ if (auto varDecl = as<VarDecl>(member))
+ {
+ // Try to get the differential type of the member.
+ Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType());
+ if (!diffType || !diffType->equals(varDecl->getType()))
+ {
+ canBeUsed = false;
+ break;
+ }
+ }
+ }
+ return canBeUsed;
+ }
+
+ void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type)
+ {
+ // TODO: Handle extensions.
+ // Add derivative member attributes to all the fields pointing to themselves.
+ for (auto member : parent->getMembersOfType<VarDeclBase>())
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = member->getType();
+
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = type;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type);
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+
+ fieldLookupExpr->declRef = makeDeclRef(member);
+
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
+ }
+
Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
{
auto result = tryGetDifferentialType(builder, type);