From 939be44ca23476e622dfb24a592383fe2a1da61f Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 26 Oct 2022 08:32:24 -0700 Subject: Auto synthesis of Differential type (#2466) --- source/slang/slang-check-decl.cpp | 126 +++++++++++++++++++++++++++++++++++++- 1 file changed, 123 insertions(+), 3 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 356105e4f..fa05dde11 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -981,7 +981,7 @@ namespace Slang VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) { auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); - auto diffType = _getDifferential(m_astBuilder, memberType); + auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); if (as(diffType)) { getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType); @@ -994,7 +994,7 @@ namespace Slang Diagnostics:: derivativeMemberAttributeCanOnlyBeUsedOnMembers); } - auto diffThisType = _getDifferential(m_astBuilder, thisType); + auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); if (!thisType) { getSink()->diagnose( @@ -1359,6 +1359,104 @@ namespace Slang } } + bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requirementDeclRef, + RefPtr witnessTable) + { + // We currently can't handle generic types. + if (GetOuterGeneric(context->parentDecl) != nullptr) + { + return false; + } + + Decl* existingDecl = nullptr; + AggTypeDecl* aggTypeDecl = nullptr; + if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl)) + { + aggTypeDecl = as(existingDecl); + SLANG_RELEASE_ASSERT(aggTypeDecl); + + // Remove the `ToBeSynthesizedModifier`. + if (as(aggTypeDecl->modifiers.first)) + { + aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next; + } + } + else + { + aggTypeDecl = m_astBuilder->create(); + aggTypeDecl->parentDecl = context->parentDecl; + context->parentDecl->members.add((aggTypeDecl)); + aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); + aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; + context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl); + } + + // TODO: if we want to make the synthesized type itself to be differentiable, + // add an inheritance decl here. Need to be careful to avoid infinite recursion + // trying to synthesize the higher order differential types. + + // Helper function to add a `diffType` field into the synthesized type for the original + // `member`. + auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc); + auto addDiffMember = [&](Decl* member, Type* diffMemberType) + { + // If the field is differentiable, add a corresponding field in the associated Differential type. + auto diffField = m_astBuilder->create(); + diffField->nameAndLoc = member->nameAndLoc; + diffField->type.type = diffMemberType; + diffField->checkState = DeclCheckState::SignatureChecked; + diffField->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(diffField); + + // Inject a `DerivativeMember` modifier on the original decl. + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->create(); + baseTypeType->type = differentialType; + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + }; + + // Go through super types. + for (auto inheritance : context->parentDecl->getMembersOfType()) + { + if (auto baseDeclRefType = as(inheritance->base.type)) + { + // Skip interface super types. + if (baseDeclRefType->declRef.as()) + continue; + if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType)) + { + addDiffMember(inheritance, superDiffType); + } + } + } + + // We go through all members and generate their differential counterparts. + for (auto member : context->parentDecl->getMembersOfType()) + { + auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); + if (!diffType) + continue; + addDiffMember(member, diffType); + } + + // In the future when the Differential type itself needs to conform to some interface, + // this is the place to synthesize requirements for them. + addModifier(aggTypeDecl, m_astBuilder->create()); + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + return true; + } + void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) { // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any @@ -2146,6 +2244,13 @@ namespace Slang DeclRef requiredAssociatedTypeDeclRef, RefPtr witnessTable) { + if (auto declRefType = as(satisfyingType)) + { + // If we are seeing a placeholder that awaits synthesis, return false now to trigger + // auto synthesis. + if (declRefType->declRef.getDecl()->hasModifier()) + return false; + } // We need to confirm that the chosen type `satisfyingType`, // meets all the constraints placed on the associated type // requirement `requiredAssociatedTypeDeclRef`. @@ -2947,6 +3052,21 @@ namespace Slang witnessTable); } + if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as()) + { + if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier()) + { + switch (builtinAttr->kind) + { + case BuiltinAssociatedTypeRequirementKind::Differential: + return trySynthesizeDifferentialAssociatedTypeRequirementWitness( + context, + requiredAssocTypeDeclRef, + witnessTable); + } + } + } + // TODO: There are other kinds of requirements for which synthesis should // be possible: // @@ -4876,7 +4996,7 @@ namespace Slang // We will now look for other declarations with // the same name in the same parent/container. // - buildMemberDictionary(parentDecl); + parentDecl->buildMemberDictionary(); for (auto oldDecl = newDecl->nextInContainerWithSameName; oldDecl; oldDecl = oldDecl->nextInContainerWithSameName) { // For each matching declaration, we will check -- cgit v1.2.3