diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-16 16:08:51 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-16 16:08:51 -0800 |
| commit | e13d38b6a281f444203410f09dab8b127e678975 (patch) | |
| tree | e8db1272ee8a729256515cc11a635c3c68752004 /source/slang/slang-check-expr.cpp | |
| parent | 801aa3b44254341018a1acbe754f2ce3b0900e2a (diff) | |
Language server improvements for auto-diff. (#2521)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 26 |
1 files changed, 20 insertions, 6 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 311a5944b..b43a03150 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2115,9 +2115,14 @@ namespace Slang resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl(); + if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl())) { - for (auto param : funcDecl.getDecl()->getParameters()) + funcDecl = as<CallableDecl>(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } @@ -2144,14 +2149,19 @@ namespace Slang resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl(); + if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl())) + { + funcDecl = as<CallableDecl>(genDecl->inner); + } + if (funcDecl) { - for (auto param : funcDecl.getDecl()->getParameters()) + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } - resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } } }; @@ -2175,13 +2185,15 @@ namespace Slang { auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, nullptr, - expr->loc, + overloadedExpr->loc, nullptr); auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction)) @@ -2191,9 +2203,11 @@ namespace Slang { auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, item); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } |
