summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-check-decl.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp205
1 files changed, 203 insertions, 2 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index b18e1c4da..2d6e20622 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -835,7 +835,15 @@ namespace Slang
// If `decl` is a container, then we want to ensure its children.
if(auto containerDecl = as<ContainerDecl>(decl))
- {
+ {
+ bool trackDiffTypes = (as<GenericDecl>(decl) != nullptr);
+ if (trackDiffTypes)
+ {
+ // Add a context to track differentiable types.
+ DifferentiableTypeSemanticContext subDiffTypeContext;
+ visitor->getShared()->pushDiffTypeContext(&subDiffTypeContext);
+ }
+
// NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here,
// because the visitor may add to `members` whilst iteration takes place, invalidating the iterator
// and likely a crash.
@@ -857,6 +865,21 @@ namespace Slang
_ensureAllDeclsRec(visitor, childDecl, state);
}
+
+ if (trackDiffTypes)
+ {
+ auto subDiffTypeContext = visitor->getShared()->popDiffTypeContext();
+
+ // If there were any differentiable types used in differentiable
+ // methods, generate a dictionary with the required info.
+ //
+ if (subDiffTypeContext->isDictionaryRequired())
+ {
+ auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder());
+ diffTypeDict->parentDecl = containerDecl;
+ containerDecl->members.add(diffTypeDict);
+ }
+ }
}
// Note: the "inner" declaration of a `GenericDecl` is currently
@@ -1234,6 +1257,49 @@ namespace Slang
}
}
+ void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
+ {
+ // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any
+ // request to check differentiable types.
+ //
+ if (!m_astBuilder->isDifferentiableInterfaceAvailable())
+ return;
+
+ auto diffInterface = m_astBuilder->getDifferentiableInterface();
+
+ DeclRefType* type = nullptr;
+
+ if (auto extensionDecl = as<ExtensionDecl>(decl))
+ {
+ // If this is an extension, use the provided target type.
+ type = as<DeclRefType>(extensionDecl->targetType.type);
+ }
+ else
+ {
+ // If this is a type declaration, create a decl ref without
+ // any substitutions.
+ //
+ auto declRef = makeDeclRef(decl);
+
+ // TODO: Strip substitutions from the declreftype
+ type = DeclRefType::create(m_astBuilder, declRef);
+ }
+
+ // Skip if the declaration is the interface itself.
+ if (type->declRef == diffInterface)
+ return;
+
+ // If the DeclRefType conforms to IDifferentiable, register it with the top-level
+ // context.
+ //
+ if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, diffInterface)))
+ {
+ // TODO: Temporarily disabled to move to new system. Fix later.
+ // context->registerDifferentiableType(type, witness);
+ }
+
+ }
+
void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
{
// TODO: are there any other validations we can do at this point?
@@ -1287,6 +1353,23 @@ namespace Slang
ensureDecl(constraint, DeclCheckState::ReadyForReference);
}
}
+
+ // TODO(sai): Is this the right checking stage to be doing this?
+ DifferentiableTypeSemanticContext diffTypeContext;
+
+ for (Index i = 0; i < members.getCount(); ++i)
+ {
+ Decl* m = members[i];
+
+ if (auto typeParam = as<GenericTypeParamDecl>(m))
+ {
+ tryAddDifferentiableConformanceToContext(typeParam, &diffTypeContext);
+ }
+ }
+
+ auto diffTypeDictionaryNode = diffTypeContext.makeDifferentiableTypeDictionaryNode(m_astBuilder);
+ diffTypeDictionaryNode->parentDecl = genericDecl;
+ genericDecl->members.add(diffTypeDictionaryNode);
}
void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
@@ -1322,6 +1405,7 @@ namespace Slang
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl)
{
checkAggTypeConformance(aggTypeDecl);
+ tryAddDifferentiableConformanceToContext(aggTypeDecl, getShared()->getDiffTypeContext());
}
// Conformances can also come via `extension` declarations, and
@@ -1330,6 +1414,7 @@ namespace Slang
void visitExtensionDecl(ExtensionDecl* extensionDecl)
{
checkExtensionConformance(extensionDecl);
+ tryAddDifferentiableConformanceToContext(extensionDecl, getShared()->getDiffTypeContext());
}
};
@@ -1486,6 +1571,32 @@ namespace Slang
// Furthermore, because a fully checked function will have checked
// its body, this also means that all function bodies and the
// declarations they contain should be fully checked.
+
+ // Generate a dictionary node to hold information about all
+ // available differentiable types in scope (including imports and stdlib)
+ //
+ if (getShared()->getDiffTypeContext()->isDictionaryRequired())
+ finishDifferentiableTypeDictionary(moduleDecl);
+ }
+
+ void SemanticsVisitor::finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl)
+ {
+ // Grab the differentiable type information from imported modules.
+ for(auto importedModule : getShared()->importedModulesList)
+ {
+ this->getShared()->getDiffTypeContext()->addImportedModule(importedModule);
+ }
+
+ // Grad the differentiable type information from the standard library modules.
+ for (auto stdLibModule : this->getSession()->stdlibModules)
+ {
+ this->getShared()->getDiffTypeContext()->addImportedModule(stdLibModule->getModuleDecl());
+ }
+
+ auto diffTypeDictNode = this->getShared()->getDiffTypeContext()->makeDifferentiableTypeDictionaryNode(m_astBuilder);
+ diffTypeDictNode->parentDecl = moduleDecl;
+
+ moduleDecl->members.add(diffTypeDictNode);
}
bool SemanticsVisitor::doesSignatureMatchRequirement(
@@ -4292,7 +4403,23 @@ namespace Slang
nullptr);
args.add(val);
}
- // TODO: need to handle constraints here?
+ }
+
+ // Add defaults for constraint parameters.
+ for (auto dd : genericDecl->members)
+ {
+ if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd))
+ {
+ // Convert the constraint to an appropriate witness.
+ auto witness = tryGetSubtypeWitness(constraintDecl->sub, constraintDecl->sup);
+
+ // Must be non-null since we know there's a constraint. If null, something is
+ // very wrong.
+ //
+ SLANG_ASSERT(witness);
+
+ args.add(witness);
+ }
}
GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(genericDecl, args, nullptr);
return subst;
@@ -4725,6 +4852,11 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
+ if (decl->findModifier<JVPDerivativeModifier>())
+ {
+ this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
+ }
+
for(auto paramDecl : decl->getParameters())
{
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
@@ -5594,6 +5726,75 @@ namespace Slang
m_candidateExtensionListsBuilt = false;
m_mapTypeDeclToCandidateExtensions.Clear();
}
+
+ void DifferentiableTypeSemanticContext::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness)
+ {
+ // Need to generate a type dictionary since we have a declaration that works with
+ // a differentiable type.
+ //
+ this->requireDifferentiableTypeDictionary();
+
+ m_mapTypeToIDifferentiableWitness.AddIfNotExists(DeclRefTypeKey(type), witness);
+ }
+
+ List<KeyValuePair<DeclRefType*, SubtypeWitness*>> DifferentiableTypeSemanticContext::getDifferentiableTypeConformanceList()
+ {
+ List<KeyValuePair<DeclRefType*, SubtypeWitness*>> diffConformances;
+ for (auto entry : m_mapTypeToIDifferentiableWitness)
+ {
+ diffConformances.add(KeyValuePair<DeclRefType*, SubtypeWitness*>(entry.Key.type, entry.Value));
+ }
+
+ return diffConformances;
+ }
+
+ DifferentiableTypeDictionary* DifferentiableTypeSemanticContext::makeDifferentiableTypeDictionaryNode(
+ ASTBuilder* builder)
+ {
+ auto dictionary = builder->create<DifferentiableTypeDictionary>();
+
+ for (auto item : m_mapTypeToIDifferentiableWitness)
+ {
+ auto entry = builder->create<DifferentiableTypeDictionaryItem>();
+ entry->baseType = item.Key.type;
+ entry->confWitness = item.Value;
+ entry->parentDecl = dictionary;
+
+ dictionary->members.add(entry);
+ }
+
+ for (auto item : m_importedDictionaries)
+ {
+ auto entry = builder->create<DifferentiableTypeDictionaryImportItem>();
+ entry->dictionaryRef = item;
+ entry->parentDecl = dictionary;
+
+ dictionary->members.add(entry);
+ }
+
+ return dictionary;
+ }
+
+ void DifferentiableTypeSemanticContext::addImportedModule(ModuleDecl* importedModuleDecl)
+ {
+ // TODO: This is a terribly slow way to find the diff type dictionary.
+ // Switch to lookUp() when possible (this might involve naming the dictionary something)
+ //
+ for (auto diffTypeDict : importedModuleDecl->getMembersOfType<DifferentiableTypeDictionary>())
+ {
+ m_importedDictionaries.add(makeDeclRef(diffTypeDict));
+ }
+ }
+
+ void DifferentiableTypeSemanticContext::requireDifferentiableTypeDictionary()
+ {
+ this->m_isTypeDictionaryRequired = true;
+ }
+
+ bool DifferentiableTypeSemanticContext::isDictionaryRequired()
+ {
+ return this->m_isTypeDictionaryRequired;
+ }
void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl)
{