diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-07 11:22:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-07 11:22:32 -0800 |
| commit | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch) | |
| tree | 87e444746f353d69a365380904f3f8caf15fbfec /source/slang/slang-check-decl.cpp | |
| parent | 6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff) | |
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs.
* Add diff implementation matrix versions of binary and ternary intrinsics.
* Add diff impl for legacy intrinsics.
* Fix diagnostics of using non-differentiable function in a diff operator.
* Add diff implementation for `determinant`.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 192 |
1 files changed, 54 insertions, 138 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a1d5acfb0..7c42c1892 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4663,7 +4663,8 @@ namespace Slang TDerivativeAttr* attr, const List<Expr*>& imaginaryArguments) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); auto resolved = visitor->ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as<InvokeExpr>(resolved)) { @@ -4690,38 +4691,34 @@ namespace Slang return "BackwardDerivative"; } - List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) { List<Expr*> imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) + for (auto param : func->getParameters()) { - auto arg = visitor->getASTBuilder()->create<VarExpr>(); + auto arg = astBuilder->create<VarExpr>(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } imaginaryArguments.add(arg); } return imaginaryArguments; } - List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc) + List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List<Expr*> imaginaryArguments; - for (auto param : fwdDiffFunc->getParameters()) + for (auto param : originalFuncDecl->getParameters()) { - auto arg = astBuilder->create<VarExpr>(); + auto arg = visitor->getASTBuilder()->create<VarExpr>(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(param->getType())) + if (auto pairType = visitor->getDifferentialPairType(param->getType())) { - arg->type.type = pairType->getPrimalType(); + arg->type.type = pairType; } imaginaryArguments.add(arg); } @@ -4731,6 +4728,11 @@ namespace Slang List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List<Expr*> imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; + }; + for (auto param : originalFuncDecl->getParameters()) { auto arg = visitor->getASTBuilder()->create<VarExpr>(); @@ -4738,16 +4740,23 @@ namespace Slang arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) + if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) { arg->type.type = pairType; - if (auto diffPairType = as<DifferentialPairType>(pairType)) + if (isOutParam(param)) { - if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr) - { - arg->type.isLeftValue = false; - arg->type.type = diffPairType->getPrimalType(); - } + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; } } imaginaryArguments.add(arg); @@ -4763,38 +4772,6 @@ namespace Slang return imaginaryArguments; } - List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) - { - // Note: it isn't always possible to construct original arguments from - // backward propagation arguments because backward propagation function - // may drop certain parameters. - List<Expr*> imaginaryArguments; - for (auto param : bwdDiffFunc->getParameters()) - { - auto arg = astBuilder->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(param->getType())) - { - if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr) - { - arg->type.isLeftValue = false; - } - arg->type.type = pairType->getPrimalType(); - } - imaginaryArguments.add(arg); - } - // Assume the last parameter is `dOut`. - // This is not true if the function returns a non-differentiable value. - // However in that uncommon case we just fail the overload resolution - // and require the user to provide disambiguate themselves. - if (imaginaryArguments.getCount()) - imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1); - return imaginaryArguments; - } - // This helper function is needed to workaround a gcc bug. // Remove when we upgrade to a newer version of gcc. template <typename T> @@ -4803,76 +4780,41 @@ namespace Slang return decl->findModifier<T>(); } - template <typename TDerivativeAttr, typename TDerivativeOfAttr> + template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr> void checkDerivativeOfAttributeImpl( SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List<Expr*>& imaginaryArgsToOriginal) + DeclAssociationKind assocKind) { DeclRef<Decl> calleeDeclRef; - auto calleeDeclRefExpr = as<DeclRefExpr>(derivativeOfAttr->funcExpr); - if (!calleeDeclRefExpr) + DeclRefExpr* calleeDeclRefExpr = nullptr; + DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); + diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + diffFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); + if (!checkedDiffFuncExpr) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) - { - calleeDeclRefExpr = as<DeclRefExpr>(resolvedInvoke->functionExpr); - } + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr); + if (resolvedDiffFuncExpr) + calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction); } + if (!calleeDeclRefExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } calleeDeclRef = calleeDeclRefExpr->declRef; - if (auto calleeGenDecl = as<GenericDecl>(calleeDeclRef.getDecl())) - { - auto parentGenericDecl = as<GenericDecl>(funcDecl->parentDecl); - if (!parentGenericDecl) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - FunctionDeclBase* funcReturnVal = nullptr; - List<Val*> args; - for (auto mm : parentGenericDecl->members) - { - if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm)) - { - args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef<Decl>(genericTypeParamDecl, nullptr))); - } - else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm)) - { - args.add(visitor->getASTBuilder()->getOrCreate<GenericParamIntVal>( - genericValueParamDecl->getType(), - genericValueParamDecl, nullptr)); - } - } - auto funcs = calleeGenDecl->getMembersOfType<FunctionDeclBase>(); - if (funcs.isEmpty()) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - funcReturnVal = funcs.getFirst(); - if (funcReturnVal) - { - auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr); - calleeDeclRef.decl = funcReturnVal; - calleeDeclRef.substitutions = subst; - calleeDeclRefExpr = as<DeclRefExpr>(visitor->ConstructDeclRefExpr( - calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr)); - } - else - { - calleeDeclRef = DeclRef<Decl>(); - calleeDeclRefExpr = nullptr; - } - } - + auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); if (!calleeFunc) { @@ -4953,9 +4895,8 @@ namespace Slang if (!attr) return; - List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal); + checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); } void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) @@ -4964,33 +4905,8 @@ namespace Slang if (!attr) return; - List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - - // The tricky part here is that we can't easily derive the arguments to original func just - // from the definition of a backward derivative function, because we don't know if the last - // parameter is just a normal parameter of the original func, or if it is the additional - // derivative of the return value. The solution here is to try to resolve the original - // function with or without the last argument. However if the type of the last argument - // isn't differentiable, we know that it can't possibly be the result derivative. - - if (imaginaryArguments.getCount() == 0 || - !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type)) - { - checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); - return; - } - - // Otherwise, try resolve with all the arguments, if failed, resolve without the last - // argument. - if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments)) - { - return; - } - - imaginaryArguments.removeLast(); - checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); + checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); } void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) |
