diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-08 18:08:24 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-08 18:08:24 -0800 |
| commit | 0629b22bf09ae6b3c3689c5f98492df7577bf0d2 (patch) | |
| tree | 286eaf6268986b1ecb3cc19e8f3b72495e881d78 /source/slang/slang-check-decl.cpp | |
| parent | 21502874666c282a3c5fa1f802deff27fab4e93b (diff) | |
Enhance link-time type test. (#3724)
* Enhance link-time type test.
* Fix.
* Fix.
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7784100a6..8dee7b0c5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -311,6 +311,8 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); + }; template<typename VisitorType> @@ -3660,9 +3662,12 @@ namespace Slang // the work of constructing our synthesized method. // + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); + // First, we check that the differentiabliity of the method matches the requirement, // and we don't attempt to synthesize a method if they don't match. - if (getShared()->getFuncDifferentiableLevel( + if (!isInWrapperType && + getShared()->getFuncDifferentiableLevel( as<FunctionDeclBase>(lookupResult.item.declRef.getDecl())) < getShared()->getFuncDifferentiableLevel( as<FunctionDeclBase>(requiredMemberDeclRef.getDecl()))) @@ -3689,7 +3694,7 @@ namespace Slang auto synBase = m_astBuilder->create<OverloadedExpr>(); synBase->name = requiredMemberDeclRef.getDecl()->getName(); - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); synBase->lookupResult2 = lookUpMember( @@ -3701,6 +3706,10 @@ namespace Slang LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>()); + + synFuncDecl->parentDecl = aggTypeDecl; + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); } else { @@ -3714,7 +3723,7 @@ namespace Slang // if (synThis) { - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { // If this is a wrapper type, then use the inner // object as the actual this parameter for the redirected @@ -3723,6 +3732,8 @@ namespace Slang innerExpr->scope = synThis->scope; innerExpr->name = getName("inner"); synBase->base = CheckExpr(innerExpr); + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type); } else { @@ -6066,7 +6077,7 @@ namespace Slang checkVisibility(decl); } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); if (newContext.getParentDifferentiableAttribute()) @@ -6086,7 +6097,12 @@ namespace Slang } m_parentDifferentiableAttr = oldAttr; } + return newContext; + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + auto newContext = registerDifferentiableTypesForFunc(decl); if (const auto body = decl->body) { checkStmt(decl->body, newContext); |
