summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-18 12:37:27 -0800
committerGitHub <noreply@github.com>2022-11-18 12:37:27 -0800
commitd58e08f8237a1888ceaad53402d534679ea83b1a (patch)
treee66838e0dc31fc12ebd7c1acecbb5060e8808366 /source/slang/slang-check-decl.cpp
parent0a050a439fa91b66f2020421d4fec3e60aed4112 (diff)
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp137
1 files changed, 137 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index ffbc5a841..009d0a987 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4635,6 +4635,7 @@ namespace Slang
checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr);
attr->backDeclRef = fwdDerivativeAttr->funcExpr;
fwdDerivativeAttr->funcExpr = nullptr;
+ getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl);
return;
}
}
@@ -4684,6 +4685,22 @@ namespace Slang
if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
checkDerivativeAttribute(decl, derivativeAttr);
+ if (newContext.getParentDifferentiableAttribute())
+ {
+ // Register additional types outside the function body first.
+ auto oldAttr = m_parentDifferentiableAttr;
+ m_parentDifferentiableAttr = newContext.getParentDifferentiableAttribute();
+ for (auto param : decl->getParameters())
+ maybeRegisterDifferentiableType(m_astBuilder, param->type.type);
+ maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type);
+ if (as<ConstructorDecl>(decl) || !isEffectivelyStatic(decl))
+ {
+ auto thisType = calcThisType(makeDeclRef(decl));
+ maybeRegisterDifferentiableType(m_astBuilder, thisType);
+ }
+ m_parentDifferentiableAttr = oldAttr;
+ }
+
if (auto body = decl->body)
{
checkStmt(decl->body, newContext);
@@ -6379,6 +6396,126 @@ namespace Slang
}
}
+ /// Get a reference to the associated decl list for `decl` in the given dictionary
+ ///
+ /// Note: this function creates an empty list of candidates for the given type if
+ /// a matching entry doesn't exist already.
+ ///
+ static List<DeclAssociation>& _getDeclAssociationList(
+ Decl* decl,
+ OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations)
+ {
+ RefPtr<DeclAssociationList> entry;
+ if (!mapDeclToDeclarations.TryGetValue(decl, entry))
+ {
+ entry = new DeclAssociationList();
+ mapDeclToDeclarations.Add(decl, entry);
+ }
+ return entry->associations;
+ }
+
+ void SharedSemanticsContext::_addDeclAssociationsFromModule(ModuleDecl* moduleDecl)
+ {
+ for (auto& entry : moduleDecl->mapDeclToAssociatedDecls)
+ {
+ auto& list = _getDeclAssociationList(entry.Key, m_mapDeclToAssociatedDecls);
+ list.addRange(entry.Value->associations);
+ }
+ }
+
+ void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated)
+ {
+ auto moduleDecl = getModuleDecl(associated);
+ DeclAssociation assoc = {kind, associated};
+ _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc);
+
+ m_associatedDeclListsBuilt = false;
+ m_mapDeclToAssociatedDecls.Clear();
+ }
+
+ List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl)
+ {
+ // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`.
+ // Consider refactoring them into the same framework.
+ if (!m_associatedDeclListsBuilt)
+ {
+ m_associatedDeclListsBuilt = true;
+
+ for (auto module : getSession()->stdlibModules)
+ {
+ _addDeclAssociationsFromModule(module->getModuleDecl());
+ }
+
+ if (m_module)
+ {
+ _addDeclAssociationsFromModule(m_module->getModuleDecl());
+ for (auto moduleDecl : this->importedModulesList)
+ {
+ _addDeclAssociationsFromModule(moduleDecl);
+ }
+ }
+ else
+ {
+ for (auto module : m_linkage->loadedModulesList)
+ {
+ _addDeclAssociationsFromModule(module->getModuleDecl());
+ }
+ }
+ }
+ return _getDeclAssociationList(decl, m_mapDeclToAssociatedDecls);
+ }
+
+ bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func)
+ {
+ // A function is differentiable if it is marked as differentiable, or it
+ // has an associated derivative function.
+ if (func->findModifier<DifferentiableAttribute>())
+ return true;
+ for (auto assocDecl : getAssociatedDeclsForDecl(func))
+ {
+ switch (assocDecl.kind)
+ {
+ case DeclAssociationKind::ForwardDerivativeFunc:
+ case DeclAssociationKind::BackwardDerivativeFunc:
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+ }
+
+ bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func)
+ {
+ // A function is differentiable if it is marked as differentiable, or it
+ // has an associated derivative function.
+ if (func->findModifier<BackwardDifferentiableAttribute>())
+ return true;
+ for (auto assocDecl : getAssociatedDeclsForDecl(func))
+ {
+ switch (assocDecl.kind)
+ {
+ case DeclAssociationKind::BackwardDerivativeFunc:
+ return true;
+ default:
+ break;
+ }
+ }
+ if (auto builtinReq = func->findModifier<BuiltinRequirementModifier>())
+ {
+ switch (builtinReq->kind)
+ {
+ case BuiltinRequirementKind::DAddFunc:
+ case BuiltinRequirementKind::DMulFunc:
+ case BuiltinRequirementKind::DZeroFunc:
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+ }
+
List<ExtensionDecl*> const& getCandidateExtensions(
DeclRef<AggTypeDecl> const& declRef,
SemanticsVisitor* semantics)