diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-24 22:16:21 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-24 22:16:21 -0800 |
| commit | 951ad25e0a9c3b0089c6b996b8e821ac93cf5766 (patch) | |
| tree | 7bed99484204611a4669d7c2c11019795e37f7cb /source/slang/slang-check-expr.cpp | |
| parent | a3b0eff62e59f3a05461bf3edee5e100e804e4d5 (diff) | |
Reimplement address elimination. (#2605)
* Reimplement address elimination pass.
* Fix error.
* Update test references.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 75 |
1 files changed, 74 insertions, 1 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 43124b535..2853c1eb9 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -922,7 +922,6 @@ namespace Slang } } - void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) { if (!builder->isDifferentiableInterfaceAvailable()) @@ -962,6 +961,80 @@ namespace Slang } } + void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet) + { + if (workingSet.contains(type)) + return; + + if (!builder->isDifferentiableInterfaceAvailable()) + { + return; + } + + if (!m_parentDifferentiableAttr) + { + return; + } + + workingSet.add(type); + + // Check for special cases such as PtrTypeBase<T> or Array<T> + // This could potentially be handled later by simply defining extensions + // for Ptr<T:IDifferentiable> etc.. + // + if (auto ptrType = as<PtrTypeBase>(type)) + { + maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet); + return; + } + + if (auto arrayType = as<ArrayExpressionType>(type)) + { + maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet); + return; + } + + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) + { + registerDifferentiableType((DeclRefType*)type, subtypeWitness); + if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + { + foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) + { + auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet); + }); + foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) + { + auto fieldType = getType(m_astBuilder, member); + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet); + }); + } + } + return; + } + } + + void SemanticsVisitor::completeDifferentiableTypeDictionary() + { + ValSet workingSet; + for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness) + { + if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>()) + { + maybeRegisterDifferentiableTypeRecursive( + m_astBuilder, + m_astBuilder->getOrCreateDeclRefType( + aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions), + workingSet); + } + } + } + + Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); |
