From 0629b22bf09ae6b3c3689c5f98492df7577bf0d2 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 8 Mar 2024 18:08:24 -0800 Subject: Enhance link-time type test. (#3724) * Enhance link-time type test. * Fix. * Fix. --- source/slang/slang-check-decl.cpp | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7784100a6..8dee7b0c5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -311,6 +311,8 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); + }; template @@ -3660,9 +3662,12 @@ namespace Slang // the work of constructing our synthesized method. // + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); + // First, we check that the differentiabliity of the method matches the requirement, // and we don't attempt to synthesize a method if they don't match. - if (getShared()->getFuncDifferentiableLevel( + if (!isInWrapperType && + getShared()->getFuncDifferentiableLevel( as(lookupResult.item.declRef.getDecl())) < getShared()->getFuncDifferentiableLevel( as(requiredMemberDeclRef.getDecl()))) @@ -3689,7 +3694,7 @@ namespace Slang auto synBase = m_astBuilder->create(); synBase->name = requiredMemberDeclRef.getDecl()->getName(); - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { auto aggTypeDecl = as(context->parentDecl); synBase->lookupResult2 = lookUpMember( @@ -3701,6 +3706,10 @@ namespace Slang LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); addModifier(synFuncDecl, m_astBuilder->create()); + + synFuncDecl->parentDecl = aggTypeDecl; + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); } else { @@ -3714,7 +3723,7 @@ namespace Slang // if (synThis) { - if (isWrapperTypeDecl(context->parentDecl)) + if (isInWrapperType) { // If this is a wrapper type, then use the inner // object as the actual this parameter for the redirected @@ -3723,6 +3732,8 @@ namespace Slang innerExpr->scope = synThis->scope; innerExpr->name = getName("inner"); synBase->base = CheckExpr(innerExpr); + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type); } else { @@ -6066,7 +6077,7 @@ namespace Slang checkVisibility(decl); } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); if (newContext.getParentDifferentiableAttribute()) @@ -6086,7 +6097,12 @@ namespace Slang } m_parentDifferentiableAttr = oldAttr; } + return newContext; + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + auto newContext = registerDifferentiableTypesForFunc(decl); if (const auto body = decl->body) { checkStmt(decl->body, newContext); -- cgit v1.2.3