diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 26 |
1 files changed, 20 insertions, 6 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 311a5944b..b43a03150 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2115,9 +2115,14 @@ namespace Slang resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl(); + if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl())) { - for (auto param : funcDecl.getDecl()->getParameters()) + funcDecl = as<CallableDecl>(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } @@ -2144,14 +2149,19 @@ namespace Slang resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl(); + if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl())) + { + funcDecl = as<CallableDecl>(genDecl->inner); + } + if (funcDecl) { - for (auto param : funcDecl.getDecl()->getParameters()) + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } - resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } } }; @@ -2175,13 +2185,15 @@ namespace Slang { auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, nullptr, - expr->loc, + overloadedExpr->loc, nullptr); auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction)) @@ -2191,9 +2203,11 @@ namespace Slang { auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, item); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } |
