From 7bbe7b4780345181cb586b03504ff63f9b8d5c4c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 7 Apr 2023 09:57:39 -0400 Subject: Fix crash on overloaded custom derivative function (#2782) * Fix issue with resolving overloaded custom forward derivative methods. * Add test --- source/slang/slang-check-decl.cpp | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 3ffb6c100..63c7d9741 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6923,20 +6923,34 @@ namespace Slang auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); - if (auto derivFuncDeclRef = as(checkedFuncExpr)->declRef) + if (auto declRefExpr = as(checkedFuncExpr)) { - visitor->ensureDecl(derivFuncDeclRef, DeclCheckState::TypesFullyResolved); - auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); - auto resolved = subVisitor.ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) + visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::TypesFullyResolved); + } + else if (auto overloadedExpr = as(checkedFuncExpr)) + { + for (auto candidate : overloadedExpr->lookupResult2.items) { - if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) - { - attr->funcExpr = calleeDeclRef; - return; - } + visitor->ensureDecl(candidate.declRef, DeclCheckState::TypesFullyResolved); } } + else + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); + return; + } + + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); } -- cgit v1.2.3