summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp37
1 files changed, 34 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f75f84e21..15831ba26 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7203,10 +7203,17 @@ namespace Slang
template<typename TDerivativeAttr>
void checkDerivativeAttributeImpl(
SemanticsVisitor* visitor,
+ Decl* funcDecl,
TDerivativeAttr* attr,
const List<Expr*>& imaginaryArguments,
const List<ParameterDirection>& expectedParamDirections)
{
+ if (isInterfaceRequirement(funcDecl))
+ {
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative);
+ return;
+ }
+
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
@@ -7264,6 +7271,20 @@ namespace Slang
// We'll detect both these incorrect cases here and issue an appropriate diagnostic.
//
auto funcType = as<FuncType>(calleeDeclRef->type);
+ if (!funcType)
+ {
+ // The best candidate does not have a function type.
+ // If we reach here, it means the function is a generic and we can't deduce the
+ // generic arguments from imaginary argument list.
+ // In this case we issue a diagnostic to ask the user to explicitly provide the arguments.
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
+ return;
+ }
+ if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl()))
+ {
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative);
+ return;
+ }
for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii)
{
// Check if the resolved invoke argument type is an error type.
@@ -7511,6 +7532,16 @@ namespace Slang
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
return;
}
+ if (isInterfaceRequirement(calleeFunc))
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative);
+ return;
+ }
+ if (isInterfaceRequirement(funcDecl))
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotUseInterfaceRequirementAsDerivative);
+ return;
+ }
if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
{
@@ -7546,7 +7577,7 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
@@ -7557,7 +7588,7 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr)
@@ -7568,7 +7599,7 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
template<typename TDerivativeAttr, typename TDerivativeOfAttr>