summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-24 13:22:07 -0800
committerGitHub <noreply@github.com>2023-02-24 13:22:07 -0800
commit91694dacdb8d3ab7dd9783d7c0c43629bf11f578 (patch)
tree19eb8db8845b22c379ebc3f114d610c1e401bd9d /source/slang/slang-check-expr.cpp
parentbd6306cdaa4a49344658bd026721b6532e103d09 (diff)
Fix differential type registration through non-differentiable type. (#2677)
* Fix differential type registration through non-differentiable type. * More fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp98
1 files changed, 28 insertions, 70 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3567e2593..2803b5959 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -923,7 +923,7 @@ namespace Slang
return result;
}
- void SemanticsVisitor::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness)
+ void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness)
{
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
if (witness)
@@ -944,49 +944,23 @@ namespace Slang
return;
}
- // Check for special cases such as PtrTypeBase<T> or Array<T>
- // This could potentially be handled later by simply defining extensions
- // for Ptr<T:IDifferentiable> etc..
- //
- if (auto ptrType = as<PtrTypeBase>(type))
- {
- maybeRegisterDifferentiableType(builder, ptrType->getValueType());
- return;
- }
-
- if (auto arrayType = as<ArrayExpressionType>(type))
- {
- maybeRegisterDifferentiableType(builder, arrayType->getElementType());
- // Fall through to register the array type itself.
- }
-
- if (auto declRefType = as<DeclRefType>(type))
- {
- if (auto subtypeWitness = as<SubtypeWitness>(
- tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
- {
- registerDifferentiableType((DeclRefType*)type, subtypeWitness);
- }
- return;
- }
+ maybeRegisterDifferentiableTypeImplRecursive(builder, type);
}
- void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet)
+ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type)
{
- if (workingSet.contains(type))
+ // Recursively visit the tree of type and register all differentiable types along the way.
+
+ if (as<TypeType>(type))
return;
-
- if (!builder->isDifferentiableInterfaceAvailable())
- {
+ if (!type)
return;
- }
- if (!m_parentDifferentiableAttr)
- {
+ // Have we already registered this type? If so we can exit now.
+ if (m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.contains(type))
return;
- }
- workingSet.add(type);
+ m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.add(type);
// Check for special cases such as PtrTypeBase<T> or Array<T>
// This could potentially be handled later by simply defining extensions
@@ -994,13 +968,13 @@ namespace Slang
//
if (auto ptrType = as<PtrTypeBase>(type))
{
- maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(builder, ptrType->getValueType());
return;
}
if (auto arrayType = as<ArrayExpressionType>(type))
{
- maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(builder, arrayType->getElementType());
// Fall through to register the array type itself.
}
@@ -1009,20 +983,20 @@ namespace Slang
if (auto subtypeWitness = as<SubtypeWitness>(
tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
{
- registerDifferentiableType((DeclRefType*)type, subtypeWitness);
- if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
- {
- foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
- {
- auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr);
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet);
- });
- foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member)
- {
- auto fieldType = getType(m_astBuilder, member);
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet);
- });
- }
+ addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ }
+ if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ {
+ foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
+ {
+ auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType);
+ });
+ foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member)
+ {
+ auto fieldType = getType(m_astBuilder, member);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType);
+ });
}
for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer)
{
@@ -1032,35 +1006,19 @@ namespace Slang
{
if (auto typeArg = as<Type>(arg))
{
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg);
}
}
}
else if (auto thisSubst = as<ThisTypeSubstitution>(subst))
{
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub);
}
}
return;
}
}
- void SemanticsVisitor::completeDifferentiableTypeDictionary()
- {
- ValSet workingSet;
- for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness)
- {
- if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>())
- {
- maybeRegisterDifferentiableTypeRecursive(
- m_astBuilder,
- m_astBuilder->getOrCreateDeclRefType(
- aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions),
- workingSet);
- }
- }
- }
-
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{