summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp120
1 files changed, 94 insertions, 26 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a399ea389..ff74f9a62 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -598,37 +598,60 @@ namespace Slang
{
case BuiltinRequirementKind::DifferentialType:
{
- auto structDecl = m_astBuilder->create<StructDecl>();
- auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
- conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
- conformanceDecl->parentDecl = structDecl;
- structDecl->members.add(conformanceDecl);
- structDecl->parentDecl = parent;
-
- synthesizedDecl = structDecl;
- auto typeDef = m_astBuilder->create<TypeAliasDecl>();
- typeDef->nameAndLoc.name = getName("Differential");
- typeDef->parentDecl = structDecl;
-
- auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
-
- typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
- structDecl->members.add(typeDef);
+ if (!canStructBeUsedAsSelfDifferentialType(parent))
+ {
+ // Need to create a new struct type for the differential.
+ //
+ auto structDecl = m_astBuilder->create<StructDecl>();
+ auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
+ conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
+ conformanceDecl->parentDecl = structDecl;
+ structDecl->members.add(conformanceDecl);
+ structDecl->parentDecl = parent;
+
+ synthesizedDecl = structDecl;
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = getName("Differential");
+ typeDef->parentDecl = structDecl;
+
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
+
+ typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
+ structDecl->members.add(typeDef);
+
+ synthesizedDecl->parentDecl = parent;
+ synthesizedDecl->nameAndLoc.name = item.declRef.getName();
+ synthesizedDecl->loc = parent->loc;
+ parent->members.add(synthesizedDecl);
+ parent->invalidateMemberDictionary();
+
+ // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
+ // from user-provided definitions, and proceed to fill in its definition.
+ auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
+ addModifier(synthesizedDecl, toBeSynthesized);
+ }
+ else
+ {
+ // There's no need for a new struct decl.
+ // We can simply add a typealias to the existing concrete type.
+ //
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = item.declRef.getName();
+ typeDef->parentDecl = parent;
+ typeDef->type.type = subType;
+
+ synthesizedDecl = parent;
+
+ parent->members.add(typeDef);
+ parent->invalidateMemberDictionary();
+
+ markSelfDifferentialMembersOfType(parent, subType);
+ }
}
break;
default:
return nullptr;
}
- synthesizedDecl->parentDecl = parent;
- synthesizedDecl->nameAndLoc.name = item.declRef.getName();
- synthesizedDecl->loc = parent->loc;
- parent->members.add(synthesizedDecl);
- parent->invalidateMemberDictionary();
-
- // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
- // from user-provided definitions, and proceed to fill in its definition.
- auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
- addModifier(synthesizedDecl, toBeSynthesized);
auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
return ConstructDeclRefExpr(
@@ -1145,6 +1168,51 @@ namespace Slang
return nullptr;
}
+ bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl)
+ {
+ // A struct can be used as its own differential type if all its members are differentiable
+ // and their differential types are the same as the original types.
+ //
+ bool canBeUsed = true;
+ for (auto member : aggTypeDecl->members)
+ {
+ if (auto varDecl = as<VarDecl>(member))
+ {
+ // Try to get the differential type of the member.
+ Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType());
+ if (!diffType || !diffType->equals(varDecl->getType()))
+ {
+ canBeUsed = false;
+ break;
+ }
+ }
+ }
+ return canBeUsed;
+ }
+
+ void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type)
+ {
+ // TODO: Handle extensions.
+ // Add derivative member attributes to all the fields pointing to themselves.
+ for (auto member : parent->getMembersOfType<VarDeclBase>())
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = member->getType();
+
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = type;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type);
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+
+ fieldLookupExpr->declRef = makeDeclRef(member);
+
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
+ }
+
Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
{
auto result = tryGetDifferentialType(builder, type);