summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-08 18:08:24 -0800
committerGitHub <noreply@github.com>2024-03-08 18:08:24 -0800
commit0629b22bf09ae6b3c3689c5f98492df7577bf0d2 (patch)
tree286eaf6268986b1ecb3cc19e8f3b72495e881d78 /source/slang/slang-check-decl.cpp
parent21502874666c282a3c5fa1f802deff27fab4e93b (diff)
Enhance link-time type test. (#3724)
* Enhance link-time type test. * Fix. * Fix.
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp24
1 files changed, 20 insertions, 4 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7784100a6..8dee7b0c5 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -311,6 +311,8 @@ namespace Slang
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl);
+ SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl);
+
};
template<typename VisitorType>
@@ -3660,9 +3662,12 @@ namespace Slang
// the work of constructing our synthesized method.
//
+ bool isInWrapperType = isWrapperTypeDecl(context->parentDecl);
+
// First, we check that the differentiabliity of the method matches the requirement,
// and we don't attempt to synthesize a method if they don't match.
- if (getShared()->getFuncDifferentiableLevel(
+ if (!isInWrapperType &&
+ getShared()->getFuncDifferentiableLevel(
as<FunctionDeclBase>(lookupResult.item.declRef.getDecl()))
< getShared()->getFuncDifferentiableLevel(
as<FunctionDeclBase>(requiredMemberDeclRef.getDecl())))
@@ -3689,7 +3694,7 @@ namespace Slang
auto synBase = m_astBuilder->create<OverloadedExpr>();
synBase->name = requiredMemberDeclRef.getDecl()->getName();
- if (isWrapperTypeDecl(context->parentDecl))
+ if (isInWrapperType)
{
auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl);
synBase->lookupResult2 = lookUpMember(
@@ -3701,6 +3706,10 @@ namespace Slang
LookupMask::Default,
LookupOptions::IgnoreBaseInterfaces);
addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>());
+
+ synFuncDecl->parentDecl = aggTypeDecl;
+ SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
+ bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl);
}
else
{
@@ -3714,7 +3723,7 @@ namespace Slang
//
if (synThis)
{
- if (isWrapperTypeDecl(context->parentDecl))
+ if (isInWrapperType)
{
// If this is a wrapper type, then use the inner
// object as the actual this parameter for the redirected
@@ -3723,6 +3732,8 @@ namespace Slang
innerExpr->scope = synThis->scope;
innerExpr->name = getName("inner");
synBase->base = CheckExpr(innerExpr);
+ SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
+ bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, synBase->base->type);
}
else
{
@@ -6066,7 +6077,7 @@ namespace Slang
checkVisibility(decl);
}
- void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
+ SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
if (newContext.getParentDifferentiableAttribute())
@@ -6086,7 +6097,12 @@ namespace Slang
}
m_parentDifferentiableAttr = oldAttr;
}
+ return newContext;
+ }
+ void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ auto newContext = registerDifferentiableTypesForFunc(decl);
if (const auto body = decl->body)
{
checkStmt(decl->body, newContext);