summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-check-decl.cpp
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp614
1 files changed, 346 insertions, 268 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7c42c1892..5cd7fba45 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -34,6 +34,26 @@ namespace Slang
}
};
+ struct SemanticsDeclAttributesVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclAttributesVisitor>
+ {
+ SemanticsDeclAttributesVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ void visitFunctionDeclBase(FunctionDeclBase* decl);
+
+ void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr);
+
+ void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr);
+
+ void checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr);
+ };
+
struct SemanticsDeclHeaderVisitor
: public SemanticsDeclVisitorBase
, public DeclVisitor<SemanticsDeclHeaderVisitor>
@@ -258,10 +278,6 @@ namespace Slang
void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
void visitParamDecl(ParamDecl* paramDecl);
-
- void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
-
- void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -4657,270 +4673,10 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
- template<typename TDerivativeAttr>
- void checkDerivativeAttributeImpl(
- SemanticsVisitor* visitor,
- TDerivativeAttr* attr,
- const List<Expr*>& imaginaryArguments)
- {
- auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor);
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
- {
- if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
- {
- attr->funcExpr = calleeDeclRef;
- return;
- }
- }
- visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
- }
-
- template<typename TDerivativeAttr>
- const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); }
-
- template<>
- const char* getDerivativeAttrName<ForwardDerivativeAttribute>()
- {
- return "ForwardDerivative";
- }
- template<>
- const char* getDerivativeAttrName<BackwardDerivativeAttribute>()
- {
- return "BackwardDerivative";
- }
-
- List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
- {
- List<Expr*> imaginaryArguments;
- for (auto param : func->getParameters())
- {
- auto arg = astBuilder->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- imaginaryArguments.add(arg);
- }
- return imaginaryArguments;
- }
-
- List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
- {
- List<Expr*> imaginaryArguments;
- for (auto param : originalFuncDecl->getParameters())
- {
- auto arg = visitor->getASTBuilder()->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
- {
- arg->type.type = pairType;
- }
- imaginaryArguments.add(arg);
- }
- return imaginaryArguments;
- }
-
- List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
- {
- List<Expr*> imaginaryArguments;
- auto isOutParam = [&](ParamDecl* param)
- {
- return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr;
- };
-
- for (auto param : originalFuncDecl->getParameters())
- {
- auto arg = visitor->getASTBuilder()->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
- {
- arg->type.type = pairType;
- if (isOutParam(param))
- {
- // out T -> in T.Differential
- arg->type.isLeftValue = false;
- arg->type.type = visitor->tryGetDifferentialType(
- visitor->getASTBuilder(), pairType->getPrimalType());
- }
- }
- else
- {
- if (isOutParam(param))
- {
- // Skip non-differentiable out params.
- continue;
- }
- }
- imaginaryArguments.add(arg);
- }
- if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
- {
- auto arg = visitor->getASTBuilder()->create<InitializerListExpr>();
- arg->type.isLeftValue = false;
- arg->type.type = diffReturnType;
- arg->loc = loc;
- imaginaryArguments.add(arg);
- }
- return imaginaryArguments;
- }
-
- // This helper function is needed to workaround a gcc bug.
- // Remove when we upgrade to a newer version of gcc.
- template <typename T>
- static T* _findModifier(Decl* decl)
- {
- return decl->findModifier<T>();
- }
-
- template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr>
- void checkDerivativeOfAttributeImpl(
- SemanticsVisitor* visitor,
- FunctionDeclBase* funcDecl,
- TDerivativeOfAttr* derivativeOfAttr,
- DeclAssociationKind assocKind)
- {
- DeclRef<Decl> calleeDeclRef;
- DeclRefExpr* calleeDeclRefExpr = nullptr;
- DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
- diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
- diffFuncExpr->loc = derivativeOfAttr->loc;
- Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor);
- if (!checkedDiffFuncExpr)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
- {
- auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr);
- if (resolvedDiffFuncExpr)
- calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction);
- }
-
- if (!calleeDeclRefExpr)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- calleeDeclRef = calleeDeclRefExpr->declRef;
-
- auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
- if (!calleeFunc)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
-
- if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
- {
- // The primal function already has a `[*Derivative]` attribute, this is invalid.
- visitor->getSink()->diagnose(
- derivativeOfAttr,
- Diagnostics::declAlreadyHasAttribute,
- calleeDeclRef,
- getDerivativeAttrName<TDerivativeAttr>());
- visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl());
- }
- derivativeOfAttr->funcExpr = calleeDeclRefExpr;
- auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
- derivativeAttr->loc = derivativeOfAttr->loc;
- auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
- auto declRef =
- DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
- auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
- declRefExpr->type.type = nullptr;
- derivativeAttr->args.add(declRefExpr);
- derivativeAttr->funcExpr = declRefExpr;
- checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr);
- derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
- derivativeAttr->funcExpr = nullptr;
- visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl);
- }
-
- static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
- {
- if (!attr->funcExpr)
- return;
- if (attr->funcExpr->type.type)
- return;
-
- List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
- }
-
- static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
- {
- if (!attr->funcExpr)
- return;
- if (attr->funcExpr->type.type)
- return;
-
- List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
- }
-
- template<typename TDerivativeAttr, typename TDerivativeOfAttr>
- bool tryCheckDerivativeOfAttributeImpl(
- SemanticsVisitor* visitor,
- FunctionDeclBase* funcDecl,
- TDerivativeOfAttr* derivativeOfAttr,
- DeclAssociationKind assocKind,
- const List<Expr*>& imaginaryArgsToOriginal)
- {
- DiagnosticSink tempSink(visitor->getSourceManager(), nullptr);
- SemanticsVisitor subVisitor(visitor->withSink(&tempSink));
- checkDerivativeOfAttributeImpl<TDerivativeAttr>(
- &subVisitor,
- funcDecl,
- derivativeOfAttr,
- assocKind,
- imaginaryArgsToOriginal);
- return tempSink.getErrorCount() == 0;
- }
-
- void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
- {
- auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
- if (!attr)
- return;
-
- checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>(
- this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc);
- }
-
- void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
- {
- auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>();
- if (!attr)
- return;
-
- checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>(
- this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc);
- }
-
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
- // Run checking on attributes that can't be fully checked in header checking stage.
- checkForwardDerivativeOfAttribute(decl);
- if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
- checkDerivativeAttribute(this, decl, derivativeAttr);
- checkBackwardDerivativeOfAttribute(decl);
- if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>())
- checkDerivativeAttribute(this, decl, derivativeAttr);
-
if (newContext.getParentDifferentiableAttribute())
{
// Register additional types outside the function body first.
@@ -6762,7 +6518,7 @@ namespace Slang
/// Note: this function creates an empty list of candidates for the given type if
/// a matching entry doesn't exist already.
///
- static List<DeclAssociation>& _getDeclAssociationList(
+ static List<RefPtr<DeclAssociation>>& _getDeclAssociationList(
Decl* decl,
OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations)
{
@@ -6787,14 +6543,16 @@ namespace Slang
void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated)
{
auto moduleDecl = getModuleDecl(associated);
- DeclAssociation assoc = {kind, associated};
+ RefPtr<DeclAssociation> assoc = new DeclAssociation();
+ assoc->kind = kind;
+ assoc->decl = associated;
_getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc);
m_associatedDeclListsBuilt = false;
m_mapDeclToAssociatedDecls.Clear();
}
- List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl)
+ List<RefPtr<DeclAssociation>> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl)
{
// This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`.
// Consider refactoring them into the same framework.
@@ -6838,6 +6596,23 @@ namespace Slang
FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func)
{
+ return _getFuncDifferentiableLevelImpl(func, 1);
+ }
+
+ FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit)
+ {
+ if (recurseLimit > 0)
+ {
+ if (auto primalSubst = func->findModifier<PrimalSubstituteAttribute>())
+ {
+ if (auto declRefExpr = as<DeclRefExpr>(primalSubst->funcExpr))
+ {
+ if (auto primalSubstFunc = declRefExpr->declRef.as<FunctionDeclBase>())
+ return _getFuncDifferentiableLevelImpl(primalSubstFunc, recurseLimit - 1);
+ }
+ }
+ }
+
if (func->findModifier<BackwardDifferentiableAttribute>())
return FunctionDifferentiableLevel::Backward;
if (func->findModifier<BackwardDerivativeAttribute>())
@@ -6849,13 +6624,19 @@ namespace Slang
for (auto assocDecl : getAssociatedDeclsForDecl(func))
{
- switch (assocDecl.kind)
+ switch (assocDecl->kind)
{
case DeclAssociationKind::BackwardDerivativeFunc:
return FunctionDifferentiableLevel::Backward;
case DeclAssociationKind::ForwardDerivativeFunc:
diffLevel = FunctionDifferentiableLevel::Forward;
break;
+ case DeclAssociationKind::PrimalSubstituteFunc:
+ if (auto assocFunc = as<FunctionDeclBase>(assocDecl->decl))
+ {
+ return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1);
+ }
+ break;
default:
break;
}
@@ -6971,6 +6752,10 @@ namespace Slang
SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl);
break;
+ case DeclCheckState::AttributesChecked:
+ SemanticsDeclAttributesVisitor(shared).dispatch(decl);
+ break;
+
case DeclCheckState::Checked:
SemanticsDeclBodyVisitor(shared).dispatch(decl);
break;
@@ -7058,4 +6843,297 @@ namespace Slang
return val;
}
+
+ template<typename TDerivativeAttr>
+ void checkDerivativeAttributeImpl(
+ SemanticsVisitor* visitor,
+ TDerivativeAttr* attr,
+ const List<Expr*>& imaginaryArguments)
+ {
+ SemanticsContext::ExprLocalScope scope;
+ auto ctx = visitor->withExprLocalScope(&scope);
+ auto subVisitor = SemanticsVisitor(ctx);
+ auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
+ auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
+ auto resolved = subVisitor.ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+ visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
+ template<typename TDerivativeAttr>
+ const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); }
+
+ template<>
+ const char* getDerivativeAttrName<ForwardDerivativeAttribute>()
+ {
+ return "ForwardDerivative";
+ }
+ template<>
+ const char* getDerivativeAttrName<BackwardDerivativeAttribute>()
+ {
+ return "BackwardDerivative";
+ }
+ template<>
+ const char* getDerivativeAttrName<PrimalSubstituteAttribute>()
+ {
+ return "PrimalSubstitute";
+ }
+
+ List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : func->getParameters())
+ {
+ auto arg = astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : originalFuncDecl->getParameters())
+ {
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ auto isOutParam = [&](ParamDecl* param)
+ {
+ return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr;
+ };
+
+ for (auto param : originalFuncDecl->getParameters())
+ {
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
+ {
+ arg->type.type = pairType;
+ if (isOutParam(param))
+ {
+ // out T -> in T.Differential
+ arg->type.isLeftValue = false;
+ arg->type.type = visitor->tryGetDifferentialType(
+ visitor->getASTBuilder(), pairType->getPrimalType());
+ }
+ }
+ else
+ {
+ if (isOutParam(param))
+ {
+ // Skip non-differentiable out params.
+ continue;
+ }
+ }
+ imaginaryArguments.add(arg);
+ }
+ if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
+ {
+ auto arg = visitor->getASTBuilder()->create<InitializerListExpr>();
+ arg->type.isLeftValue = false;
+ arg->type.type = diffReturnType;
+ arg->loc = loc;
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ // This helper function is needed to workaround a gcc bug.
+ // Remove when we upgrade to a newer version of gcc.
+ template <typename T>
+ static T* _findModifier(Decl* decl)
+ {
+ return decl->findModifier<T>();
+ }
+
+ template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr>
+ void checkDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind)
+ {
+ DeclRef<Decl> calleeDeclRef;
+ DeclRefExpr* calleeDeclRefExpr = nullptr;
+ HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
+ higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
+ higherOrderFuncExpr->loc = derivativeOfAttr->loc;
+ Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor);
+ if (!checkedHigherOrderFuncExpr)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs);
+ SemanticsContext::ExprLocalScope scope;
+ auto ctx = visitor->withExprLocalScope(&scope);
+ auto subVisitor = SemanticsVisitor(ctx);
+ auto resolved = subVisitor.ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ auto resolvedFuncExpr = as<HigherOrderInvokeExpr>(resolvedInvoke->functionExpr);
+ if (resolvedFuncExpr)
+ calleeDeclRefExpr = as<DeclRefExpr>(resolvedFuncExpr->baseFunction);
+ }
+
+ if (!calleeDeclRefExpr)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ calleeDeclRef = calleeDeclRefExpr->declRef;
+
+ auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
+ if (!calleeFunc)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+
+ if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
+ {
+ // The primal function already has a `[*Derivative]` attribute, this is invalid.
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::declAlreadyHasAttribute,
+ calleeDeclRef,
+ getDerivativeAttrName<TDerivativeAttr>());
+ visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl());
+ }
+
+ derivativeOfAttr->funcExpr = calleeDeclRefExpr;
+ auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
+ derivativeAttr->loc = derivativeOfAttr->loc;
+ auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ derivativeAttr->args.add(declRefExpr);
+ derivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr);
+ derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
+ derivativeAttr->funcExpr = nullptr;
+ visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ template<typename TDerivativeAttr, typename TDerivativeOfAttr>
+ bool tryCheckDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind,
+ const List<Expr*>& imaginaryArgsToOriginal)
+ {
+ DiagnosticSink tempSink(visitor->getSourceManager(), nullptr);
+ SemanticsVisitor subVisitor(visitor->withSink(&tempSink));
+ checkDerivativeOfAttributeImpl<TDerivativeAttr>(
+ &subVisitor,
+ funcDecl,
+ derivativeOfAttr,
+ assocKind,
+ imaginaryArgsToOriginal);
+ return tempSink.getErrorCount() == 0;
+ }
+
+ void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr)
+ {
+ checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc);
+ }
+
+ void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr)
+ {
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc);
+ }
+
+ void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr)
+ {
+ checkDerivativeOfAttributeImpl<PrimalSubstituteAttribute, PrimalSubstituteExpr>(
+ this, funcDecl, attr, DeclAssociationKind::PrimalSubstituteFunc);
+ }
+
+ void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ // Run checking on attributes that can't be fully checked in header checking stage.
+ for (auto attr : decl->modifiers)
+ {
+ if (auto fwdDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr))
+ checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr);
+ else if (auto bwdDerivativeOfAttr = as<BackwardDerivativeOfAttribute>(attr))
+ checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr);
+ else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr))
+ checkPrimalSubstituteOfAttribute(decl, primalOfAttr);
+ else if (auto fwdDerivativeAttr = as<ForwardDerivativeAttribute>(attr))
+ checkDerivativeAttribute(this, decl, fwdDerivativeAttr);
+ else if (auto bwdDerivativeAttr = as<BackwardDerivativeAttribute>(attr))
+ checkDerivativeAttribute(this, decl, bwdDerivativeAttr);
+ else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr))
+ checkDerivativeAttribute(this, decl, primalAttr);
+ }
+ }
}