summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-16 16:08:51 -0800
committerGitHub <noreply@github.com>2022-11-16 16:08:51 -0800
commite13d38b6a281f444203410f09dab8b127e678975 (patch)
treee8db1272ee8a729256515cc11a635c3c68752004 /source/slang/slang-check-expr.cpp
parent801aa3b44254341018a1acbe754f2ce3b0900e2a (diff)
Language server improvements for auto-diff. (#2521)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp26
1 files changed, 20 insertions, 6 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 311a5944b..b43a03150 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2115,9 +2115,14 @@ namespace Slang
resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType);
if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
{
- if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>())
+ auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl();
+ if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl()))
{
- for (auto param : funcDecl.getDecl()->getParameters())
+ funcDecl = as<CallableDecl>(genDecl->inner);
+ }
+ if (funcDecl)
+ {
+ for (auto param : funcDecl->getParameters())
{
resultDiffExpr->newParameterNames.add(param->getName());
}
@@ -2144,14 +2149,19 @@ namespace Slang
resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType);
if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
{
- if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>())
+ auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl();
+ if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl()))
+ {
+ funcDecl = as<CallableDecl>(genDecl->inner);
+ }
+ if (funcDecl)
{
- for (auto param : funcDecl.getDecl()->getParameters())
+ for (auto param : funcDecl->getParameters())
{
resultDiffExpr->newParameterNames.add(param->getName());
}
+ resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient"));
}
- resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient"));
}
}
};
@@ -2175,13 +2185,15 @@ namespace Slang
{
auto lookupResultExpr = semantics->ConstructLookupResultExpr(item,
nullptr,
- expr->loc,
+ overloadedExpr->loc,
nullptr);
auto candidateExpr = actions->createDifferentiateExpr(semantics);
actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr);
+ candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
result->type.type = astBuilder->getOverloadedType();
+ result->loc = expr->loc;
return result;
}
else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction))
@@ -2191,9 +2203,11 @@ namespace Slang
{
auto candidateExpr = actions->createDifferentiateExpr(semantics);
actions->fillDifferentiateExpr(candidateExpr, semantics, item);
+ candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
result->type.type = astBuilder->getOverloadedType();
+ result->loc = expr->loc;
return result;
}