summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-26 08:32:24 -0700
committerGitHub <noreply@github.com>2022-10-26 08:32:24 -0700
commit939be44ca23476e622dfb24a592383fe2a1da61f (patch)
tree7f45645897fe5735d58a7687290552d479e4d6fc /source/slang/slang-check-decl.cpp
parent4fc34b18da2f83ee6b4f094067503a66cab3d0b5 (diff)
Auto synthesis of Differential type (#2466)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp126
1 files changed, 123 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 356105e4f..fa05dde11 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -981,7 +981,7 @@ namespace Slang
VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr)
{
auto memberType = checkProperType(getLinkage(), varDecl->type, getSink());
- auto diffType = _getDifferential(m_astBuilder, memberType);
+ auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc);
if (as<ErrorType>(diffType))
{
getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType);
@@ -994,7 +994,7 @@ namespace Slang
Diagnostics::
derivativeMemberAttributeCanOnlyBeUsedOnMembers);
}
- auto diffThisType = _getDifferential(m_astBuilder, thisType);
+ auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc);
if (!thisType)
{
getSink()->diagnose(
@@ -1359,6 +1359,104 @@ namespace Slang
}
}
+ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<Decl> requirementDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // We currently can't handle generic types.
+ if (GetOuterGeneric(context->parentDecl) != nullptr)
+ {
+ return false;
+ }
+
+ Decl* existingDecl = nullptr;
+ AggTypeDecl* aggTypeDecl = nullptr;
+ if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl))
+ {
+ aggTypeDecl = as<AggTypeDecl>(existingDecl);
+ SLANG_RELEASE_ASSERT(aggTypeDecl);
+
+ // Remove the `ToBeSynthesizedModifier`.
+ if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first))
+ {
+ aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next;
+ }
+ }
+ else
+ {
+ aggTypeDecl = m_astBuilder->create<StructDecl>();
+ aggTypeDecl->parentDecl = context->parentDecl;
+ context->parentDecl->members.add((aggTypeDecl));
+ aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName();
+ aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc;
+ context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl);
+ }
+
+ // TODO: if we want to make the synthesized type itself to be differentiable,
+ // add an inheritance decl here. Need to be careful to avoid infinite recursion
+ // trying to synthesize the higher order differential types.
+
+ // Helper function to add a `diffType` field into the synthesized type for the original
+ // `member`.
+ auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc);
+ auto addDiffMember = [&](Decl* member, Type* diffMemberType)
+ {
+ // If the field is differentiable, add a corresponding field in the associated Differential type.
+ auto diffField = m_astBuilder->create<VarDecl>();
+ diffField->nameAndLoc = member->nameAndLoc;
+ diffField->type.type = diffMemberType;
+ diffField->checkState = DeclCheckState::SignatureChecked;
+ diffField->parentDecl = aggTypeDecl;
+ aggTypeDecl->members.add(diffField);
+
+ // Inject a `DerivativeMember` modifier on the original decl.
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffMemberType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->create<TypeType>();
+ baseTypeType->type = differentialType;
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(diffField);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ };
+
+ // Go through super types.
+ for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>())
+ {
+ if (auto baseDeclRefType = as<DeclRefType>(inheritance->base.type))
+ {
+ // Skip interface super types.
+ if (baseDeclRefType->declRef.as<InterfaceDecl>())
+ continue;
+ if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType))
+ {
+ addDiffMember(inheritance, superDiffType);
+ }
+ }
+ }
+
+ // We go through all members and generate their differential counterparts.
+ for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>())
+ {
+ auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type);
+ if (!diffType)
+ continue;
+ addDiffMember(member, diffType);
+ }
+
+ // In the future when the Differential type itself needs to conform to some interface,
+ // this is the place to synthesize requirements for them.
+ addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
+ auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr);
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
+ return true;
+ }
+
void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
{
// If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any
@@ -2146,6 +2244,13 @@ namespace Slang
DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
RefPtr<WitnessTable> witnessTable)
{
+ if (auto declRefType = as<DeclRefType>(satisfyingType))
+ {
+ // If we are seeing a placeholder that awaits synthesis, return false now to trigger
+ // auto synthesis.
+ if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
+ return false;
+ }
// We need to confirm that the chosen type `satisfyingType`,
// meets all the constraints placed on the associated type
// requirement `requiredAssociatedTypeDeclRef`.
@@ -2947,6 +3052,21 @@ namespace Slang
witnessTable);
}
+ if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>())
+ {
+ if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>())
+ {
+ switch (builtinAttr->kind)
+ {
+ case BuiltinAssociatedTypeRequirementKind::Differential:
+ return trySynthesizeDifferentialAssociatedTypeRequirementWitness(
+ context,
+ requiredAssocTypeDeclRef,
+ witnessTable);
+ }
+ }
+ }
+
// TODO: There are other kinds of requirements for which synthesis should
// be possible:
//
@@ -4876,7 +4996,7 @@ namespace Slang
// We will now look for other declarations with
// the same name in the same parent/container.
//
- buildMemberDictionary(parentDecl);
+ parentDecl->buildMemberDictionary();
for (auto oldDecl = newDecl->nextInContainerWithSameName; oldDecl; oldDecl = oldDecl->nextInContainerWithSameName)
{
// For each matching declaration, we will check