diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-18 12:37:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-18 12:37:27 -0800 |
| commit | d58e08f8237a1888ceaad53402d534679ea83b1a (patch) | |
| tree | e66838e0dc31fc12ebd7c1acecbb5060e8808366 /source/slang/slang-check-decl.cpp | |
| parent | 0a050a439fa91b66f2020421d4fec3e60aed4112 (diff) | |
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ffbc5a841..009d0a987 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4635,6 +4635,7 @@ namespace Slang checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr); attr->backDeclRef = fwdDerivativeAttr->funcExpr; fwdDerivativeAttr->funcExpr = nullptr; + getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl); return; } } @@ -4684,6 +4685,22 @@ namespace Slang if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) checkDerivativeAttribute(decl, derivativeAttr); + if (newContext.getParentDifferentiableAttribute()) + { + // Register additional types outside the function body first. + auto oldAttr = m_parentDifferentiableAttr; + m_parentDifferentiableAttr = newContext.getParentDifferentiableAttribute(); + for (auto param : decl->getParameters()) + maybeRegisterDifferentiableType(m_astBuilder, param->type.type); + maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type); + if (as<ConstructorDecl>(decl) || !isEffectivelyStatic(decl)) + { + auto thisType = calcThisType(makeDeclRef(decl)); + maybeRegisterDifferentiableType(m_astBuilder, thisType); + } + m_parentDifferentiableAttr = oldAttr; + } + if (auto body = decl->body) { checkStmt(decl->body, newContext); @@ -6379,6 +6396,126 @@ namespace Slang } } + /// Get a reference to the associated decl list for `decl` in the given dictionary + /// + /// Note: this function creates an empty list of candidates for the given type if + /// a matching entry doesn't exist already. + /// + static List<DeclAssociation>& _getDeclAssociationList( + Decl* decl, + OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations) + { + RefPtr<DeclAssociationList> entry; + if (!mapDeclToDeclarations.TryGetValue(decl, entry)) + { + entry = new DeclAssociationList(); + mapDeclToDeclarations.Add(decl, entry); + } + return entry->associations; + } + + void SharedSemanticsContext::_addDeclAssociationsFromModule(ModuleDecl* moduleDecl) + { + for (auto& entry : moduleDecl->mapDeclToAssociatedDecls) + { + auto& list = _getDeclAssociationList(entry.Key, m_mapDeclToAssociatedDecls); + list.addRange(entry.Value->associations); + } + } + + void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated) + { + auto moduleDecl = getModuleDecl(associated); + DeclAssociation assoc = {kind, associated}; + _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); + + m_associatedDeclListsBuilt = false; + m_mapDeclToAssociatedDecls.Clear(); + } + + List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) + { + // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. + // Consider refactoring them into the same framework. + if (!m_associatedDeclListsBuilt) + { + m_associatedDeclListsBuilt = true; + + for (auto module : getSession()->stdlibModules) + { + _addDeclAssociationsFromModule(module->getModuleDecl()); + } + + if (m_module) + { + _addDeclAssociationsFromModule(m_module->getModuleDecl()); + for (auto moduleDecl : this->importedModulesList) + { + _addDeclAssociationsFromModule(moduleDecl); + } + } + else + { + for (auto module : m_linkage->loadedModulesList) + { + _addDeclAssociationsFromModule(module->getModuleDecl()); + } + } + } + return _getDeclAssociationList(decl, m_mapDeclToAssociatedDecls); + } + + bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func) + { + // A function is differentiable if it is marked as differentiable, or it + // has an associated derivative function. + if (func->findModifier<DifferentiableAttribute>()) + return true; + for (auto assocDecl : getAssociatedDeclsForDecl(func)) + { + switch (assocDecl.kind) + { + case DeclAssociationKind::ForwardDerivativeFunc: + case DeclAssociationKind::BackwardDerivativeFunc: + return true; + default: + break; + } + } + return false; + } + + bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func) + { + // A function is differentiable if it is marked as differentiable, or it + // has an associated derivative function. + if (func->findModifier<BackwardDifferentiableAttribute>()) + return true; + for (auto assocDecl : getAssociatedDeclsForDecl(func)) + { + switch (assocDecl.kind) + { + case DeclAssociationKind::BackwardDerivativeFunc: + return true; + default: + break; + } + } + if (auto builtinReq = func->findModifier<BuiltinRequirementModifier>()) + { + switch (builtinReq->kind) + { + case BuiltinRequirementKind::DAddFunc: + case BuiltinRequirementKind::DMulFunc: + case BuiltinRequirementKind::DZeroFunc: + return true; + default: + break; + } + } + return false; + } + List<ExtensionDecl*> const& getCandidateExtensions( DeclRef<AggTypeDecl> const& declRef, SemanticsVisitor* semantics) |
