diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 37 |
1 files changed, 34 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f75f84e21..15831ba26 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7203,10 +7203,17 @@ namespace Slang template<typename TDerivativeAttr> void checkDerivativeAttributeImpl( SemanticsVisitor* visitor, + Decl* funcDecl, TDerivativeAttr* attr, const List<Expr*>& imaginaryArguments, const List<ParameterDirection>& expectedParamDirections) { + if (isInterfaceRequirement(funcDecl)) + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } + SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); @@ -7264,6 +7271,20 @@ namespace Slang // We'll detect both these incorrect cases here and issue an appropriate diagnostic. // auto funcType = as<FuncType>(calleeDeclRef->type); + if (!funcType) + { + // The best candidate does not have a function type. + // If we reach here, it means the function is a generic and we can't deduce the + // generic arguments from imaginary argument list. + // In this case we issue a diagnostic to ask the user to explicitly provide the arguments. + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } + if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl())) + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) { // Check if the resolved invoke argument type is an error type. @@ -7511,6 +7532,16 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + if (isInterfaceRequirement(calleeFunc)) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } + if (isInterfaceRequirement(funcDecl)) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) { @@ -7546,7 +7577,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) @@ -7557,7 +7588,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) @@ -7568,7 +7599,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } template<typename TDerivativeAttr, typename TDerivativeOfAttr> |
