From e13d38b6a281f444203410f09dab8b127e678975 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 16 Nov 2022 16:08:51 -0800 Subject: Language server improvements for auto-diff. (#2521) --- source/slang/slang-check-expr.cpp | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 311a5944b..b43a03150 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2115,9 +2115,14 @@ namespace Slang resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as()) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) { - for (auto param : funcDecl.getDecl()->getParameters()) + funcDecl = as(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } @@ -2144,14 +2149,19 @@ namespace Slang resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as()) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) + { + funcDecl = as(genDecl->inner); + } + if (funcDecl) { - for (auto param : funcDecl.getDecl()->getParameters()) + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } - resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } } }; @@ -2175,13 +2185,15 @@ namespace Slang { auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, nullptr, - expr->loc, + overloadedExpr->loc, nullptr); auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } else if (auto overloadedExpr2 = as(expr->baseFunction)) @@ -2191,9 +2203,11 @@ namespace Slang { auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, item); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } -- cgit v1.2.3