summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-04 09:36:23 -0700
committerGitHub <noreply@github.com>2022-11-04 09:36:23 -0700
commitc6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch)
tree6db694b5b4bf94ce48678c73921676f9d305614d /source/slang/slang-check-decl.cpp
parent015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff)
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp137
1 files changed, 112 insertions, 25 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7140d541a..333e9d973 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -254,6 +254,8 @@ namespace Slang
void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
void visitParamDecl(ParamDecl* paramDecl);
+
+ void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context);
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -1433,6 +1435,22 @@ namespace Slang
synth.pushScopeForContainer(aggTypeDecl);
}
+ // If `This` is nested inside a generic, we need to form a complete declref type to the
+ // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
+ // from requirementDeclRef to get the generic substitution for outer generic parameters, and
+ // apply it to the newly synthesized decl.
+ SubstitutionSet substSet;
+ if (auto thisTypeSusbt = findThisTypeSubstitution(
+ requirementDeclRef.substitutions,
+ as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
+ {
+ if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ {
+ substSet = declRefType->declRef.substitutions;
+ }
+ }
+ auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet);
+
// Helper function to add a `diffType` field into the synthesized type for the original
// `member`.
auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl));
@@ -1462,6 +1480,22 @@ namespace Slang
addModifier(member, derivativeMemberModifier);
};
+ // Make the Differential type itself conform to `IDifferential` interface.
+ auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>();
+ inheritanceIDiffernetiable->base.type =
+ DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface());
+ inheritanceIDiffernetiable->parentDecl = aggTypeDecl;
+ aggTypeDecl->members.add(inheritanceIDiffernetiable);
+
+ // The `Differential` type of a `Differential` type is always itself.
+ auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
+ assocTypeDef->nameAndLoc.name = getName("Differential");
+ assocTypeDef->type.type = satisfyingType;
+ assocTypeDef->parentDecl = aggTypeDecl;
+ assocTypeDef->setCheckState(DeclCheckState::Checked);
+ aggTypeDecl->members.add(assocTypeDef);
+
+ // Go through all members and collect their differential types.
// Go through super types.
for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>())
{
@@ -1476,8 +1510,7 @@ namespace Slang
}
}
}
-
- // We go through all members and generate their differential counterparts.
+ // Go through all var members.
for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>())
{
auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type);
@@ -1488,22 +1521,9 @@ namespace Slang
addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
- // If `This` is nested inside a generic, we need to form a complete declref type to the
- // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
- // from requirementDeclRef to get the generic substitution for outer generic parameters, and
- // apply it to the newly synthesized decl.
- SubstitutionSet substSet;
- if (auto thisTypeSusbt = findThisTypeSubstitution(
- requirementDeclRef.substitutions,
- as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
- {
- if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
- {
- substSet = declRefType->declRef.substitutions;
- }
- }
-
- auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet);
+ // Synthesize the rest of IDifferential method conformances by recursively checking
+ // conformance on the synthesized decl.
+ checkAggTypeConformance(aggTypeDecl);
if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
{
@@ -1616,6 +1636,50 @@ namespace Slang
}
};
+ // Check that types used as `Differential` type use themselves as their own `Differential` type.
+ struct SemanticsDeclDifferentialConformanceVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclDifferentialConformanceVisitor>
+ {
+ SemanticsDeclDifferentialConformanceVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ void visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
+ {
+ if (as<InterfaceDecl>(inheritanceDecl->parentDecl))
+ return;
+
+ if (!inheritanceDecl->witnessTable)
+ return;
+ auto baseType = as<DeclRefType>(inheritanceDecl->witnessTable->baseType);
+ if (!baseType)
+ return;
+ if (baseType->declRef.getDecl() != m_astBuilder->getDifferentiableInterface().getDecl())
+ return;
+ RequirementWitness witnessValue;
+ auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType);
+ if (!inheritanceDecl->witnessTable->requirementDictionary.TryGetValue(requirementDecl, witnessValue))
+ return;
+
+ // A type used as differential type must have itself as its own differential type.
+ if (witnessValue.getFlavor() != RequirementWitness::Flavor::val)
+ return;
+ auto differentialType = as<DeclRefType>(witnessValue.getVal());
+ if (!differentialType)
+ return;
+ auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType);
+ if (!differentialType->equals(diffDiffType))
+ {
+ SourceLoc sourceLoc = differentialType->declRef.getDecl()->loc;
+ getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType);
+ getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup());
+ }
+ }
+ };
+
/// Recursively register any builtin declarations that need to be attached to the `session`.
///
/// This function should only be needed for declarations in the standard library.
@@ -1632,7 +1696,10 @@ namespace Slang
{
sharedASTBuilder->registerMagicDecl(decl, magicMod);
}
-
+ if (auto builtinRequirement = decl->findModifier<BuiltinRequirementModifier>())
+ {
+ sharedASTBuilder->registerBuiltinRequirementDecl(decl, builtinRequirement);
+ }
if(auto containerDecl = as<ContainerDecl>(decl))
{
for(auto childDecl : containerDecl->members)
@@ -2217,13 +2284,14 @@ namespace Slang
// associated type and see if they can be satisfied.
//
bool conformance = true;
+ Val* witness = nullptr;
for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(requiredAssociatedTypeDeclRef))
{
// Grab the type we expect to conform to from the constraint.
auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef);
// Perform a search for a witness to the subtype relationship.
- auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
+ witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
if (witness)
{
// If a subtype witness was found, then the conformance
@@ -3040,7 +3108,7 @@ namespace Slang
witnessTable))
return true;
- if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
{
switch (builtinAttr->kind)
{
@@ -3067,7 +3135,7 @@ namespace Slang
if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>())
{
- if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
{
switch (builtinAttr->kind)
{
@@ -3160,7 +3228,7 @@ namespace Slang
bool hasDifferentialAssocType = false;
for (auto existingEntry : witnessTable->requirementList)
{
- if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementModifier>())
{
if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType &&
existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none)
@@ -3401,7 +3469,7 @@ namespace Slang
// requirement, it may be possible that we can still synthesis the
// implementation if this is one of the known builtin requirements.
// Otherwise, report diagnostic now.
- if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>())
+ if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>())
{
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
@@ -4499,11 +4567,29 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
+ void SemanticsDeclBodyVisitor::_maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context)
+ {
+ auto parentDifferentiableAttr = context.getParentDifferentiableAttribute();
+ if (parentDifferentiableAttr)
+ {
+ auto diffBottomType = m_astBuilder->getDifferentialBottomType();
+ auto idifferentiable = DeclRef<InterfaceDecl>(m_astBuilder->getDifferentiableInterface(), nullptr);
+ auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(diffBottomType, idifferentiable));
+ SLANG_ASSERT(witness);
+ parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.Add(
+ as<DeclRefType>(diffBottomType)->declRef,
+ witness);
+ }
+ }
+
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
+ auto newContext = withParentFunc(decl);
+ _maybeRegisterDifferentialBottomTypeConformance(newContext);
+
if (auto body = decl->body)
{
- checkBodyStmt(body, decl);
+ checkStmt(decl->body, newContext);
}
}
@@ -6234,6 +6320,7 @@ namespace Slang
case DeclCheckState::TypesFullyResolved:
SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl);
+ SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl);
break;
case DeclCheckState::Checked: