From a3ad4dd77bba6c87abad4f76b72055c9fed94bad Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 29 Jun 2023 20:39:40 -0400 Subject: Issue diagnostic for incorrect parameter types & directionality when defining custom derivatives (#2947) * Issue diagnostic for incorrect directionality when defining custom derivative * Better diagnostics on invalid custom derivatives * Avoid duplicating `getParameterDirection()` --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 128 ++++++++++++++++++++++++++++++----- source/slang/slang-diagnostic-defs.h | 2 + source/slang/slang-syntax.h | 2 + 3 files changed, 115 insertions(+), 17 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4e2f146a1..eeb54fc4e 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6918,12 +6918,18 @@ namespace Slang return val; } + struct ArgsWithDirectionInfo + { + List args; + List directions; + }; template void checkDerivativeAttributeImpl( SemanticsVisitor* visitor, TDerivativeAttr* attr, - const List& imaginaryArguments) + const List& imaginaryArguments, + const List& expectedParamDirections) { SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); @@ -6949,13 +6955,51 @@ namespace Slang visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); return; } - + + // If left value is true, then convert the + // inner type to an InOutType. + // + auto qualTypeToString = [&](QualType qualType) -> String + { + Type* type = qualType.type; + if (qualType.isLeftValue) + { + type = ctx.getASTBuilder()->getInOutType(type); + } + return type->toString(); + }; + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) { if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) { + // There are two ways to make it to this point.. a proper resolution, and a + // resolution that has failed due to type mismatch. + // Further, a proper resolution can still be invalid due to incorrect parameter + // directionality. + // We'll detect both these incorrect cases here and issue an appropriate diagnostic. + // + auto funcType = as(calleeDeclRef->type); + for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) + { + // Check if the resolved invoke argument type is an error type. + // If so, then we have a type mismatch. + // + if (resolvedInvoke->arguments[ii]->type.type->equals(ctx.getASTBuilder()->getErrorType()) || + funcType->getParamDirection(ii) != expectedParamDirections[ii]) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureMismatchAtPosition, + ii, + qualTypeToString(imaginaryArguments[ii]->type), + funcType->getParamType(ii)->toString()); + } + } + attr->funcExpr = calleeDeclRef; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -6963,7 +7007,23 @@ namespace Slang } } - visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + // Build the expected signature from imaginary args to diagnose + // when no matching function is found (this excludes the case handled above) + // + StringBuilder builder; + builder << "("; + for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) + { + if (ii != 0) + builder << ", "; + if (imaginaryArguments[ii]->type) + builder << qualTypeToString(imaginaryArguments[ii]->type); + else + builder << ""; + } + builder << ")"; + + visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeSignatureMismatch, builder.produceString()); } template @@ -6985,9 +7045,10 @@ namespace Slang return "PrimalSubstitute"; } - List getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) + ArgsWithDirectionInfo getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) { List imaginaryArguments; + List directions; for (auto param : func->getParameters()) { auto arg = astBuilder->create(); @@ -6996,11 +7057,12 @@ namespace Slang arg->type.type = param->getType(); arg->loc = loc; imaginaryArguments.add(arg); + directions.add(getParameterDirection(param)); } - return imaginaryArguments; + return { imaginaryArguments, directions }; } - List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List imaginaryArguments; for (auto param : originalFuncDecl->getParameters()) @@ -7019,12 +7081,22 @@ namespace Slang } imaginaryArguments.add(arg); } - return imaginaryArguments; + + // Copy parameter directions as is. + List expectedParamDirections; + for (auto param : originalFuncDecl->getParameters()) + { + expectedParamDirections.add(getParameterDirection(param)); + } + + return { imaginaryArguments, expectedParamDirections }; } - List getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List imaginaryArguments; + List expectedParamDirections; + auto isOutParam = [&](ParamDecl* param) { return param->findModifier() != nullptr @@ -7038,18 +7110,31 @@ namespace Slang arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; + + ParameterDirection direction = getParameterDirection(param); + bool isDiffParam = (!param->findModifier()); if (isDiffParam) { if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) { arg->type.type = pairType; + arg->type.isLeftValue = true; + if (isOutParam(param)) { - // out T -> in T.Differential + // out T : IDifferentiable -> in T.Differential arg->type.isLeftValue = false; arg->type.type = visitor->tryGetDifferentialType( visitor->getASTBuilder(), pairType->getPrimalType()); + + direction = ParameterDirection::kParameterDirection_In; + } + else + { + // in T : IDifferentiable -> inout DifferentialPair + // inout T : IDifferentiable -> inout DifferentialPair + direction = ParameterDirection::kParameterDirection_InOut; } } else @@ -7064,8 +7149,15 @@ namespace Slang // Skip non-differentiable out params. continue; } + + // no_diff inout T -> in T + // no_diff in T -> in T + // + direction = ParameterDirection::kParameterDirection_In; } + imaginaryArguments.add(arg); + expectedParamDirections.add(direction); } if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) { @@ -7074,8 +7166,10 @@ namespace Slang arg->type.type = diffReturnType; arg->loc = loc; imaginaryArguments.add(arg); + expectedParamDirections.add(ParameterDirection::kParameterDirection_In); } - return imaginaryArguments; + + return {imaginaryArguments, expectedParamDirections}; } // This helper function is needed to workaround a gcc bug. @@ -7105,7 +7199,7 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } - List imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + List imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc).args; auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); @@ -7171,8 +7265,8 @@ namespace Slang if (attr->funcExpr->type.type) return; - List imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) @@ -7182,8 +7276,8 @@ namespace Slang if (attr->funcExpr->type.type) return; - List imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) @@ -7193,8 +7287,8 @@ namespace Slang if (attr->funcExpr->type.type) return; - List imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); } template diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c01a6ddf1..d38cde4e6 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -362,6 +362,8 @@ DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative att DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function") +DIAGNOSTIC(31149, Error, customDerivativeSignatureMismatchAtPosition, "invalid custom derivative. parameter type mismatch at position $0. expected '$1', got '$2'") +DIAGNOSTIC(31150, Error, customDerivativeSignatureMismatch, "invalid custom derivative. could not resolve function with expected signature '$0'") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 52df3a0e0..65ad121a0 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -319,6 +319,8 @@ namespace Slang const Substitutions* substs, InterfaceDecl* interfaceDecl); + ParameterDirection getParameterDirection(VarDeclBase* varDecl); + enum class UserDefinedAttributeTargets { None = 0, -- cgit v1.2.3