From bee74b16eafa64ccc33bb386a1dc753cd6c41a82 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 26 Oct 2023 14:01:26 -0700 Subject: Add more diagnostics around use of custom derivatives. (#3291) Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f75f84e21..15831ba26 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7203,10 +7203,17 @@ namespace Slang template void checkDerivativeAttributeImpl( SemanticsVisitor* visitor, + Decl* funcDecl, TDerivativeAttr* attr, const List& imaginaryArguments, const List& expectedParamDirections) { + if (isInterfaceRequirement(funcDecl)) + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } + SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); @@ -7264,6 +7271,20 @@ namespace Slang // We'll detect both these incorrect cases here and issue an appropriate diagnostic. // auto funcType = as(calleeDeclRef->type); + if (!funcType) + { + // The best candidate does not have a function type. + // If we reach here, it means the function is a generic and we can't deduce the + // generic arguments from imaginary argument list. + // In this case we issue a diagnostic to ask the user to explicitly provide the arguments. + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } + if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl())) + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) { // Check if the resolved invoke argument type is an error type. @@ -7511,6 +7532,16 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + if (isInterfaceRequirement(calleeFunc)) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } + if (isInterfaceRequirement(funcDecl)) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } if (auto existingModifier = _findModifier(calleeFunc)) { @@ -7546,7 +7577,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) @@ -7557,7 +7588,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) @@ -7568,7 +7599,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } template -- cgit v1.2.3