diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-04 09:36:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-04 09:36:23 -0700 |
| commit | c6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch) | |
| tree | 6db694b5b4bf94ce48678c73921676f9d305614d /source/slang/slang-check-decl.cpp | |
| parent | 015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff) | |
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 137 |
1 files changed, 112 insertions, 25 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7140d541a..333e9d973 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -254,6 +254,8 @@ namespace Slang void visitFunctionDeclBase(FunctionDeclBase* funcDecl); void visitParamDecl(ParamDecl* paramDecl); + + void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context); }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -1433,6 +1435,22 @@ namespace Slang synth.pushScopeForContainer(aggTypeDecl); } + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution + // from requirementDeclRef to get the generic substitution for outer generic parameters, and + // apply it to the newly synthesized decl. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + // Helper function to add a `diffType` field into the synthesized type for the original // `member`. auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl)); @@ -1462,6 +1480,22 @@ namespace Slang addModifier(member, derivativeMemberModifier); }; + // Make the Differential type itself conform to `IDifferential` interface. + auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>(); + inheritanceIDiffernetiable->base.type = + DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface()); + inheritanceIDiffernetiable->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(inheritanceIDiffernetiable); + + // The `Differential` type of a `Differential` type is always itself. + auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = satisfyingType; + assocTypeDef->parentDecl = aggTypeDecl; + assocTypeDef->setCheckState(DeclCheckState::Checked); + aggTypeDecl->members.add(assocTypeDef); + + // Go through all members and collect their differential types. // Go through super types. for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>()) { @@ -1476,8 +1510,7 @@ namespace Slang } } } - - // We go through all members and generate their differential counterparts. + // Go through all var members. for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>()) { auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); @@ -1488,22 +1521,9 @@ namespace Slang addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); - // If `This` is nested inside a generic, we need to form a complete declref type to the - // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution - // from requirementDeclRef to get the generic substitution for outer generic parameters, and - // apply it to the newly synthesized decl. - SubstitutionSet substSet; - if (auto thisTypeSusbt = findThisTypeSubstitution( - requirementDeclRef.substitutions, - as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) - { - if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) - { - substSet = declRefType->declRef.substitutions; - } - } - - auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + // Synthesize the rest of IDifferential method conformances by recursively checking + // conformance on the synthesized decl. + checkAggTypeConformance(aggTypeDecl); if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) { @@ -1616,6 +1636,50 @@ namespace Slang } }; + // Check that types used as `Differential` type use themselves as their own `Differential` type. + struct SemanticsDeclDifferentialConformanceVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclDifferentialConformanceVisitor> + { + SemanticsDeclDifferentialConformanceVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + if (as<InterfaceDecl>(inheritanceDecl->parentDecl)) + return; + + if (!inheritanceDecl->witnessTable) + return; + auto baseType = as<DeclRefType>(inheritanceDecl->witnessTable->baseType); + if (!baseType) + return; + if (baseType->declRef.getDecl() != m_astBuilder->getDifferentiableInterface().getDecl()) + return; + RequirementWitness witnessValue; + auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); + if (!inheritanceDecl->witnessTable->requirementDictionary.TryGetValue(requirementDecl, witnessValue)) + return; + + // A type used as differential type must have itself as its own differential type. + if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) + return; + auto differentialType = as<DeclRefType>(witnessValue.getVal()); + if (!differentialType) + return; + auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); + if (!differentialType->equals(diffDiffType)) + { + SourceLoc sourceLoc = differentialType->declRef.getDecl()->loc; + getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType); + getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup()); + } + } + }; + /// Recursively register any builtin declarations that need to be attached to the `session`. /// /// This function should only be needed for declarations in the standard library. @@ -1632,7 +1696,10 @@ namespace Slang { sharedASTBuilder->registerMagicDecl(decl, magicMod); } - + if (auto builtinRequirement = decl->findModifier<BuiltinRequirementModifier>()) + { + sharedASTBuilder->registerBuiltinRequirementDecl(decl, builtinRequirement); + } if(auto containerDecl = as<ContainerDecl>(decl)) { for(auto childDecl : containerDecl->members) @@ -2217,13 +2284,14 @@ namespace Slang // associated type and see if they can be satisfied. // bool conformance = true; + Val* witness = nullptr; for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(requiredAssociatedTypeDeclRef)) { // Grab the type we expect to conform to from the constraint. auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); // Perform a search for a witness to the subtype relationship. - auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); + witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); if (witness) { // If a subtype witness was found, then the conformance @@ -3040,7 +3108,7 @@ namespace Slang witnessTable)) return true; - if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>()) { switch (builtinAttr->kind) { @@ -3067,7 +3135,7 @@ namespace Slang if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { - if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>()) { switch (builtinAttr->kind) { @@ -3160,7 +3228,7 @@ namespace Slang bool hasDifferentialAssocType = false; for (auto existingEntry : witnessTable->requirementList) { - if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementModifier>()) { if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none) @@ -3401,7 +3469,7 @@ namespace Slang // requirement, it may be possible that we can still synthesis the // implementation if this is one of the known builtin requirements. // Otherwise, report diagnostic now. - if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>()) + if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>()) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); @@ -4499,11 +4567,29 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } + void SemanticsDeclBodyVisitor::_maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context) + { + auto parentDifferentiableAttr = context.getParentDifferentiableAttribute(); + if (parentDifferentiableAttr) + { + auto diffBottomType = m_astBuilder->getDifferentialBottomType(); + auto idifferentiable = DeclRef<InterfaceDecl>(m_astBuilder->getDifferentiableInterface(), nullptr); + auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(diffBottomType, idifferentiable)); + SLANG_ASSERT(witness); + parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.Add( + as<DeclRefType>(diffBottomType)->declRef, + witness); + } + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { + auto newContext = withParentFunc(decl); + _maybeRegisterDifferentialBottomTypeConformance(newContext); + if (auto body = decl->body) { - checkBodyStmt(body, decl); + checkStmt(decl->body, newContext); } } @@ -6234,6 +6320,7 @@ namespace Slang case DeclCheckState::TypesFullyResolved: SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); + SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); break; case DeclCheckState::Checked: |
