From 257733f328f38a763c8b0c8830ff4c0d34ec9491 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 7 Mar 2023 11:22:32 -0800 Subject: Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688) * Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 192 +++++++++++--------------------------- 1 file changed, 54 insertions(+), 138 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a1d5acfb0..7c42c1892 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4663,7 +4663,8 @@ namespace Slang TDerivativeAttr* attr, const List& imaginaryArguments) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); auto resolved = visitor->ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as(resolved)) { @@ -4690,38 +4691,34 @@ namespace Slang return "BackwardDerivative"; } - List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + List getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) { List imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) + for (auto param : func->getParameters()) { - auto arg = visitor->getASTBuilder()->create(); + auto arg = astBuilder->create(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } imaginaryArguments.add(arg); } return imaginaryArguments; } - List getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc) + List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List imaginaryArguments; - for (auto param : fwdDiffFunc->getParameters()) + for (auto param : originalFuncDecl->getParameters()) { - auto arg = astBuilder->create(); + auto arg = visitor->getASTBuilder()->create(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = as(param->getType())) + if (auto pairType = visitor->getDifferentialPairType(param->getType())) { - arg->type.type = pairType->getPrimalType(); + arg->type.type = pairType; } imaginaryArguments.add(arg); } @@ -4731,6 +4728,11 @@ namespace Slang List getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier() != nullptr && param->findModifier() == nullptr; + }; + for (auto param : originalFuncDecl->getParameters()) { auto arg = visitor->getASTBuilder()->create(); @@ -4738,16 +4740,23 @@ namespace Slang arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) + if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) { arg->type.type = pairType; - if (auto diffPairType = as(pairType)) + if (isOutParam(param)) { - if (param->findModifier() != nullptr && param->findModifier() == nullptr) - { - arg->type.isLeftValue = false; - arg->type.type = diffPairType->getPrimalType(); - } + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; } } imaginaryArguments.add(arg); @@ -4763,38 +4772,6 @@ namespace Slang return imaginaryArguments; } - List getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) - { - // Note: it isn't always possible to construct original arguments from - // backward propagation arguments because backward propagation function - // may drop certain parameters. - List imaginaryArguments; - for (auto param : bwdDiffFunc->getParameters()) - { - auto arg = astBuilder->create(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as(param->getType())) - { - if (param->findModifier() != nullptr && param->findModifier() == nullptr) - { - arg->type.isLeftValue = false; - } - arg->type.type = pairType->getPrimalType(); - } - imaginaryArguments.add(arg); - } - // Assume the last parameter is `dOut`. - // This is not true if the function returns a non-differentiable value. - // However in that uncommon case we just fail the overload resolution - // and require the user to provide disambiguate themselves. - if (imaginaryArguments.getCount()) - imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1); - return imaginaryArguments; - } - // This helper function is needed to workaround a gcc bug. // Remove when we upgrade to a newer version of gcc. template @@ -4803,76 +4780,41 @@ namespace Slang return decl->findModifier(); } - template + template void checkDerivativeOfAttributeImpl( SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List& imaginaryArgsToOriginal) + DeclAssociationKind assocKind) { DeclRef calleeDeclRef; - auto calleeDeclRefExpr = as(derivativeOfAttr->funcExpr); - if (!calleeDeclRefExpr) + DeclRefExpr* calleeDeclRefExpr = nullptr; + DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create(); + diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + diffFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); + if (!checkedDiffFuncExpr) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) - { - calleeDeclRefExpr = as(resolvedInvoke->functionExpr); - } + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + auto resolvedDiffFuncExpr = as(resolvedInvoke->functionExpr); + if (resolvedDiffFuncExpr) + calleeDeclRefExpr = as(resolvedDiffFuncExpr->baseFunction); } + if (!calleeDeclRefExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } calleeDeclRef = calleeDeclRefExpr->declRef; - if (auto calleeGenDecl = as(calleeDeclRef.getDecl())) - { - auto parentGenericDecl = as(funcDecl->parentDecl); - if (!parentGenericDecl) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - FunctionDeclBase* funcReturnVal = nullptr; - List args; - for (auto mm : parentGenericDecl->members) - { - if (auto genericTypeParamDecl = as(mm)) - { - args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef(genericTypeParamDecl, nullptr))); - } - else if (auto genericValueParamDecl = as(mm)) - { - args.add(visitor->getASTBuilder()->getOrCreate( - genericValueParamDecl->getType(), - genericValueParamDecl, nullptr)); - } - } - auto funcs = calleeGenDecl->getMembersOfType(); - if (funcs.isEmpty()) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - funcReturnVal = funcs.getFirst(); - if (funcReturnVal) - { - auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr); - calleeDeclRef.decl = funcReturnVal; - calleeDeclRef.substitutions = subst; - calleeDeclRefExpr = as(visitor->ConstructDeclRefExpr( - calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr)); - } - else - { - calleeDeclRef = DeclRef(); - calleeDeclRefExpr = nullptr; - } - } - + auto calleeFunc = as(calleeDeclRef.getDecl()); if (!calleeFunc) { @@ -4953,9 +4895,8 @@ namespace Slang if (!attr) return; - List imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal); + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); } void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) @@ -4964,33 +4905,8 @@ namespace Slang if (!attr) return; - List imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - - // The tricky part here is that we can't easily derive the arguments to original func just - // from the definition of a backward derivative function, because we don't know if the last - // parameter is just a normal parameter of the original func, or if it is the additional - // derivative of the return value. The solution here is to try to resolve the original - // function with or without the last argument. However if the type of the last argument - // isn't differentiable, we know that it can't possibly be the result derivative. - - if (imaginaryArguments.getCount() == 0 || - !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type)) - { - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); - return; - } - - // Otherwise, try resolve with all the arguments, if failed, resolve without the last - // argument. - if (tryCheckDerivativeOfAttributeImpl(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments)) - { - return; - } - - imaginaryArguments.removeLast(); - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); } void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) -- cgit v1.2.3