diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 120 |
1 files changed, 94 insertions, 26 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a399ea389..ff74f9a62 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -598,37 +598,60 @@ namespace Slang { case BuiltinRequirementKind::DifferentialType: { - auto structDecl = m_astBuilder->create<StructDecl>(); - auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); - conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); - conformanceDecl->parentDecl = structDecl; - structDecl->members.add(conformanceDecl); - structDecl->parentDecl = parent; - - synthesizedDecl = structDecl; - auto typeDef = m_astBuilder->create<TypeAliasDecl>(); - typeDef->nameAndLoc.name = getName("Differential"); - typeDef->parentDecl = structDecl; - - auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); - - typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); - structDecl->members.add(typeDef); + if (!canStructBeUsedAsSelfDifferentialType(parent)) + { + // Need to create a new struct type for the differential. + // + auto structDecl = m_astBuilder->create<StructDecl>(); + auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); + conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); + conformanceDecl->parentDecl = structDecl; + structDecl->members.add(conformanceDecl); + structDecl->parentDecl = parent; + + synthesizedDecl = structDecl; + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = getName("Differential"); + typeDef->parentDecl = structDecl; + + auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); + + typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); + structDecl->members.add(typeDef); + + synthesizedDecl->parentDecl = parent; + synthesizedDecl->nameAndLoc.name = item.declRef.getName(); + synthesizedDecl->loc = parent->loc; + parent->members.add(synthesizedDecl); + parent->invalidateMemberDictionary(); + + // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it + // from user-provided definitions, and proceed to fill in its definition. + auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); + addModifier(synthesizedDecl, toBeSynthesized); + } + else + { + // There's no need for a new struct decl. + // We can simply add a typealias to the existing concrete type. + // + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = item.declRef.getName(); + typeDef->parentDecl = parent; + typeDef->type.type = subType; + + synthesizedDecl = parent; + + parent->members.add(typeDef); + parent->invalidateMemberDictionary(); + + markSelfDifferentialMembersOfType(parent, subType); + } } break; default: return nullptr; } - synthesizedDecl->parentDecl = parent; - synthesizedDecl->nameAndLoc.name = item.declRef.getName(); - synthesizedDecl->loc = parent->loc; - parent->members.add(synthesizedDecl); - parent->invalidateMemberDictionary(); - - // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it - // from user-provided definitions, and proceed to fill in its definition. - auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); - addModifier(synthesizedDecl, toBeSynthesized); auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl); return ConstructDeclRefExpr( @@ -1145,6 +1168,51 @@ namespace Slang return nullptr; } + bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl) + { + // A struct can be used as its own differential type if all its members are differentiable + // and their differential types are the same as the original types. + // + bool canBeUsed = true; + for (auto member : aggTypeDecl->members) + { + if (auto varDecl = as<VarDecl>(member)) + { + // Try to get the differential type of the member. + Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType()); + if (!diffType || !diffType->equals(varDecl->getType())) + { + canBeUsed = false; + break; + } + } + } + return canBeUsed; + } + + void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type) + { + // TODO: Handle extensions. + // Add derivative member attributes to all the fields pointing to themselves. + for (auto member : parent->getMembersOfType<VarDeclBase>()) + { + auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); + auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>(); + fieldLookupExpr->type.type = member->getType(); + + auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = type; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + + fieldLookupExpr->declRef = makeDeclRef(member); + + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } + } + Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) { auto result = tryGetDifferentialType(builder, type); |
