diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 37 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 8 | ||||
| -rw-r--r-- | tests/diagnostics/custom-derivative-generic.slang | 51 |
3 files changed, 92 insertions, 4 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f75f84e21..15831ba26 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7203,10 +7203,17 @@ namespace Slang template<typename TDerivativeAttr> void checkDerivativeAttributeImpl( SemanticsVisitor* visitor, + Decl* funcDecl, TDerivativeAttr* attr, const List<Expr*>& imaginaryArguments, const List<ParameterDirection>& expectedParamDirections) { + if (isInterfaceRequirement(funcDecl)) + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } + SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); @@ -7264,6 +7271,20 @@ namespace Slang // We'll detect both these incorrect cases here and issue an appropriate diagnostic. // auto funcType = as<FuncType>(calleeDeclRef->type); + if (!funcType) + { + // The best candidate does not have a function type. + // If we reach here, it means the function is a generic and we can't deduce the + // generic arguments from imaginary argument list. + // In this case we issue a diagnostic to ask the user to explicitly provide the arguments. + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } + if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl())) + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) { // Check if the resolved invoke argument type is an error type. @@ -7511,6 +7532,16 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + if (isInterfaceRequirement(calleeFunc)) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } + if (isInterfaceRequirement(funcDecl)) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) { @@ -7546,7 +7577,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) @@ -7557,7 +7588,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) @@ -7568,7 +7599,7 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } template<typename TDerivativeAttr, typename TDerivativeOfAttr> diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 656e28701..79889a39d 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -393,7 +393,13 @@ DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot res DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function") DIAGNOSTIC(31149, Error, customDerivativeSignatureMismatchAtPosition, "invalid custom derivative. parameter type mismatch at position $0. expected '$1', got '$2'") DIAGNOSTIC(31150, Error, customDerivativeSignatureMismatch, "invalid custom derivative. could not resolve function with expected signature '$0'") - +DIAGNOSTIC(31151, Error, cannotResolveGenericArgumentForDerivativeFunction, + "The generic arguments to the derivative function cannot be deduced from the parameter list of the original function. " + "Consider using [ForwardDerivative], [BackwardDerivative] or [PrimalSubstitute] attributes on the primal function" + " with explicit generic arguments to associate it with a generic derivative function. Note that [ForwardDerivativeOf], " + "[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the generic arguments to the derivatives cannot be automatically deduced.") +DIAGNOSTIC(31152, Error, cannotAssociateInterfaceRequirementWithDerivative, "cannot associate an interface requirement with a derivative.") +DIAGNOSTIC(31153, Error, cannotUseInterfaceRequirementAsDerivative, "cannot use an interface requirement as a derivative.") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") // Enums diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang new file mode 100644 index 000000000..5f2cd9951 --- /dev/null +++ b/tests/diagnostics/custom-derivative-generic.slang @@ -0,0 +1,51 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +interface IFoo +{ + static float bar1(float x); + + // CHECK-DAG: {{.*}}(13): error 31152 + [PrimalSubstitute(bar1)] + static float bar(float x); + + static DifferentialPair<float> dd(DifferentialPair<float> x); +} + +__generic<let N:int> +float f(float x) +{ + return N*x*x; +} + +// CHECK-DAG: {{.*}}(26): error 31153 +[ForwardDerivative(IFoo.dd)] +float bbb(float x); + +// CHECK-DAG: {{.*}}(30): error 31152 +[ForwardDerivativeOf(IFoo.bar)] +DifferentialPair<float> dd1(DifferentialPair<float> x) +{ + return x; +} + +// CHECK-DAG: {{.*}}(37): error 31151 +[BackwardDerivative(f)] +DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut) +{ + var primal = x.p * x.p; + var diff = 2 * x.p * x.d * N; + return DifferentialPair<float>(primal, diff); +} +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(3.0, 1.0); + outputBuffer[1] = __fwd_diff(f<3>)(dpa).d; // Expect: 6.0 + } +} |
