summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-07 11:22:32 -0800
committerGitHub <noreply@github.com>2023-03-07 11:22:32 -0800
commit257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch)
tree87e444746f353d69a365380904f3f8caf15fbfec /source/slang/slang-check-decl.cpp
parent6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff)
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp192
1 files changed, 54 insertions, 138 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index a1d5acfb0..7c42c1892 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4663,7 +4663,8 @@ namespace Slang
TDerivativeAttr* attr,
const List<Expr*>& imaginaryArguments)
{
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor);
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
auto resolved = visitor->ResolveInvoke(invokeExpr);
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
{
@@ -4690,38 +4691,34 @@ namespace Slang
return "BackwardDerivative";
}
- List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
- for (auto param : originalFuncDecl->getParameters())
+ for (auto param : func->getParameters())
{
- auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ auto arg = astBuilder->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
- {
- arg->type.type = pairType;
- }
imaginaryArguments.add(arg);
}
return imaginaryArguments;
}
- List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc)
+ List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
- for (auto param : fwdDiffFunc->getParameters())
+ for (auto param : originalFuncDecl->getParameters())
{
- auto arg = astBuilder->create<VarExpr>();
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(param->getType()))
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
{
- arg->type.type = pairType->getPrimalType();
+ arg->type.type = pairType;
}
imaginaryArguments.add(arg);
}
@@ -4731,6 +4728,11 @@ namespace Slang
List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
+ auto isOutParam = [&](ParamDecl* param)
+ {
+ return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr;
+ };
+
for (auto param : originalFuncDecl->getParameters())
{
auto arg = visitor->getASTBuilder()->create<VarExpr>();
@@ -4738,16 +4740,23 @@ namespace Slang
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
{
arg->type.type = pairType;
- if (auto diffPairType = as<DifferentialPairType>(pairType))
+ if (isOutParam(param))
{
- if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
- {
- arg->type.isLeftValue = false;
- arg->type.type = diffPairType->getPrimalType();
- }
+ // out T -> in T.Differential
+ arg->type.isLeftValue = false;
+ arg->type.type = visitor->tryGetDifferentialType(
+ visitor->getASTBuilder(), pairType->getPrimalType());
+ }
+ }
+ else
+ {
+ if (isOutParam(param))
+ {
+ // Skip non-differentiable out params.
+ continue;
}
}
imaginaryArguments.add(arg);
@@ -4763,38 +4772,6 @@ namespace Slang
return imaginaryArguments;
}
- List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc)
- {
- // Note: it isn't always possible to construct original arguments from
- // backward propagation arguments because backward propagation function
- // may drop certain parameters.
- List<Expr*> imaginaryArguments;
- for (auto param : bwdDiffFunc->getParameters())
- {
- auto arg = astBuilder->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(param->getType()))
- {
- if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
- {
- arg->type.isLeftValue = false;
- }
- arg->type.type = pairType->getPrimalType();
- }
- imaginaryArguments.add(arg);
- }
- // Assume the last parameter is `dOut`.
- // This is not true if the function returns a non-differentiable value.
- // However in that uncommon case we just fail the overload resolution
- // and require the user to provide disambiguate themselves.
- if (imaginaryArguments.getCount())
- imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1);
- return imaginaryArguments;
- }
-
// This helper function is needed to workaround a gcc bug.
// Remove when we upgrade to a newer version of gcc.
template <typename T>
@@ -4803,76 +4780,41 @@ namespace Slang
return decl->findModifier<T>();
}
- template <typename TDerivativeAttr, typename TDerivativeOfAttr>
+ template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr>
void checkDerivativeOfAttributeImpl(
SemanticsVisitor* visitor,
FunctionDeclBase* funcDecl,
TDerivativeOfAttr* derivativeOfAttr,
- DeclAssociationKind assocKind,
- const List<Expr*>& imaginaryArgsToOriginal)
+ DeclAssociationKind assocKind)
{
DeclRef<Decl> calleeDeclRef;
- auto calleeDeclRefExpr = as<DeclRefExpr>(derivativeOfAttr->funcExpr);
- if (!calleeDeclRefExpr)
+ DeclRefExpr* calleeDeclRefExpr = nullptr;
+ DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
+ diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
+ diffFuncExpr->loc = derivativeOfAttr->loc;
+ Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor);
+ if (!checkedDiffFuncExpr)
{
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
- {
- calleeDeclRefExpr = as<DeclRefExpr>(resolvedInvoke->functionExpr);
- }
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr);
+ if (resolvedDiffFuncExpr)
+ calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction);
}
+
if (!calleeDeclRefExpr)
{
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
return;
}
calleeDeclRef = calleeDeclRefExpr->declRef;
- if (auto calleeGenDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
- {
- auto parentGenericDecl = as<GenericDecl>(funcDecl->parentDecl);
- if (!parentGenericDecl)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- FunctionDeclBase* funcReturnVal = nullptr;
- List<Val*> args;
- for (auto mm : parentGenericDecl->members)
- {
- if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm))
- {
- args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef<Decl>(genericTypeParamDecl, nullptr)));
- }
- else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm))
- {
- args.add(visitor->getASTBuilder()->getOrCreate<GenericParamIntVal>(
- genericValueParamDecl->getType(),
- genericValueParamDecl, nullptr));
- }
- }
- auto funcs = calleeGenDecl->getMembersOfType<FunctionDeclBase>();
- if (funcs.isEmpty())
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- funcReturnVal = funcs.getFirst();
- if (funcReturnVal)
- {
- auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr);
- calleeDeclRef.decl = funcReturnVal;
- calleeDeclRef.substitutions = subst;
- calleeDeclRefExpr = as<DeclRefExpr>(visitor->ConstructDeclRefExpr(
- calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr));
- }
- else
- {
- calleeDeclRef = DeclRef<Decl>();
- calleeDeclRefExpr = nullptr;
- }
- }
-
+
auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
if (!calleeFunc)
{
@@ -4953,9 +4895,8 @@ namespace Slang
if (!attr)
return;
- List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
- checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>(
- this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal);
+ checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc);
}
void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
@@ -4964,33 +4905,8 @@ namespace Slang
if (!attr)
return;
- List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
-
- // The tricky part here is that we can't easily derive the arguments to original func just
- // from the definition of a backward derivative function, because we don't know if the last
- // parameter is just a normal parameter of the original func, or if it is the additional
- // derivative of the return value. The solution here is to try to resolve the original
- // function with or without the last argument. However if the type of the last argument
- // isn't differentiable, we know that it can't possibly be the result derivative.
-
- if (imaginaryArguments.getCount() == 0 ||
- !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type))
- {
- checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
- this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
- return;
- }
-
- // Otherwise, try resolve with all the arguments, if failed, resolve without the last
- // argument.
- if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments))
- {
- return;
- }
-
- imaginaryArguments.removeLast();
- checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
- this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc);
}
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)