diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-24 13:22:07 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-24 13:22:07 -0800 |
| commit | 91694dacdb8d3ab7dd9783d7c0c43629bf11f578 (patch) | |
| tree | 19eb8db8845b22c379ebc3f114d610c1e401bd9d /source/slang/slang-check-expr.cpp | |
| parent | bd6306cdaa4a49344658bd026721b6532e103d09 (diff) | |
Fix differential type registration through non-differentiable type. (#2677)
* Fix differential type registration through non-differentiable type.
* More fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 98 |
1 files changed, 28 insertions, 70 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3567e2593..2803b5959 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -923,7 +923,7 @@ namespace Slang return result; } - void SemanticsVisitor::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness) + void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness) { SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); if (witness) @@ -944,49 +944,23 @@ namespace Slang return; } - // Check for special cases such as PtrTypeBase<T> or Array<T> - // This could potentially be handled later by simply defining extensions - // for Ptr<T:IDifferentiable> etc.. - // - if (auto ptrType = as<PtrTypeBase>(type)) - { - maybeRegisterDifferentiableType(builder, ptrType->getValueType()); - return; - } - - if (auto arrayType = as<ArrayExpressionType>(type)) - { - maybeRegisterDifferentiableType(builder, arrayType->getElementType()); - // Fall through to register the array type itself. - } - - if (auto declRefType = as<DeclRefType>(type)) - { - if (auto subtypeWitness = as<SubtypeWitness>( - tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) - { - registerDifferentiableType((DeclRefType*)type, subtypeWitness); - } - return; - } + maybeRegisterDifferentiableTypeImplRecursive(builder, type); } - void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet) + void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type) { - if (workingSet.contains(type)) + // Recursively visit the tree of type and register all differentiable types along the way. + + if (as<TypeType>(type)) return; - - if (!builder->isDifferentiableInterfaceAvailable()) - { + if (!type) return; - } - if (!m_parentDifferentiableAttr) - { + // Have we already registered this type? If so we can exit now. + if (m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.contains(type)) return; - } - workingSet.add(type); + m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.add(type); // Check for special cases such as PtrTypeBase<T> or Array<T> // This could potentially be handled later by simply defining extensions @@ -994,13 +968,13 @@ namespace Slang // if (auto ptrType = as<PtrTypeBase>(type)) { - maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet); + maybeRegisterDifferentiableTypeImplRecursive(builder, ptrType->getValueType()); return; } if (auto arrayType = as<ArrayExpressionType>(type)) { - maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet); + maybeRegisterDifferentiableTypeImplRecursive(builder, arrayType->getElementType()); // Fall through to register the array type itself. } @@ -1009,20 +983,20 @@ namespace Slang if (auto subtypeWitness = as<SubtypeWitness>( tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) { - registerDifferentiableType((DeclRefType*)type, subtypeWitness); - if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) - { - foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) - { - auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet); - }); - foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) - { - auto fieldType = getType(m_astBuilder, member); - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet); - }); - } + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + } + if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + { + foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) + { + auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType); + }); + foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) + { + auto fieldType = getType(m_astBuilder, member); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); + }); } for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer) { @@ -1032,35 +1006,19 @@ namespace Slang { if (auto typeArg = as<Type>(arg)) { - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); } } } else if (auto thisSubst = as<ThisTypeSubstitution>(subst)) { - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub); } } return; } } - void SemanticsVisitor::completeDifferentiableTypeDictionary() - { - ValSet workingSet; - for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness) - { - if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>()) - { - maybeRegisterDifferentiableTypeRecursive( - m_astBuilder, - m_astBuilder->getOrCreateDeclRefType( - aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions), - workingSet); - } - } - } - Expr* SemanticsVisitor::CheckTerm(Expr* term) { |
