summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-24 22:16:21 -0800
committerGitHub <noreply@github.com>2023-01-24 22:16:21 -0800
commit951ad25e0a9c3b0089c6b996b8e821ac93cf5766 (patch)
tree7bed99484204611a4669d7c2c11019795e37f7cb /source/slang/slang-check-expr.cpp
parenta3b0eff62e59f3a05461bf3edee5e100e804e4d5 (diff)
Reimplement address elimination. (#2605)
* Reimplement address elimination pass. * Fix error. * Update test references. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp75
1 files changed, 74 insertions, 1 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 43124b535..2853c1eb9 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -922,7 +922,6 @@ namespace Slang
}
}
-
void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
{
if (!builder->isDifferentiableInterfaceAvailable())
@@ -962,6 +961,80 @@ namespace Slang
}
}
+ void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet)
+ {
+ if (workingSet.contains(type))
+ return;
+
+ if (!builder->isDifferentiableInterfaceAvailable())
+ {
+ return;
+ }
+
+ if (!m_parentDifferentiableAttr)
+ {
+ return;
+ }
+
+ workingSet.add(type);
+
+ // Check for special cases such as PtrTypeBase<T> or Array<T>
+ // This could potentially be handled later by simply defining extensions
+ // for Ptr<T:IDifferentiable> etc..
+ //
+ if (auto ptrType = as<PtrTypeBase>(type))
+ {
+ maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet);
+ return;
+ }
+
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet);
+ return;
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
+ {
+ registerDifferentiableType((DeclRefType*)type, subtypeWitness);
+ if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ {
+ foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
+ {
+ auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr);
+ maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet);
+ });
+ foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member)
+ {
+ auto fieldType = getType(m_astBuilder, member);
+ maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet);
+ });
+ }
+ }
+ return;
+ }
+ }
+
+ void SemanticsVisitor::completeDifferentiableTypeDictionary()
+ {
+ ValSet workingSet;
+ for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness)
+ {
+ if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>())
+ {
+ maybeRegisterDifferentiableTypeRecursive(
+ m_astBuilder,
+ m_astBuilder->getOrCreateDeclRefType(
+ aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions),
+ workingSet);
+ }
+ }
+ }
+
+
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
auto checkedTerm = _CheckTerm(term);