summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-26 14:01:26 -0700
committerGitHub <noreply@github.com>2023-10-26 14:01:26 -0700
commitbee74b16eafa64ccc33bb386a1dc753cd6c41a82 (patch)
tree199a4575cfe6297bb494d33f47d68ecd1a35776d
parent927d176be9ba03be161375b8695de1f0a37f1785 (diff)
Add more diagnostics around use of custom derivatives. (#3291)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-check-decl.cpp37
-rw-r--r--source/slang/slang-diagnostic-defs.h8
-rw-r--r--tests/diagnostics/custom-derivative-generic.slang51
3 files changed, 92 insertions, 4 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f75f84e21..15831ba26 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7203,10 +7203,17 @@ namespace Slang
template<typename TDerivativeAttr>
void checkDerivativeAttributeImpl(
SemanticsVisitor* visitor,
+ Decl* funcDecl,
TDerivativeAttr* attr,
const List<Expr*>& imaginaryArguments,
const List<ParameterDirection>& expectedParamDirections)
{
+ if (isInterfaceRequirement(funcDecl))
+ {
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative);
+ return;
+ }
+
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
@@ -7264,6 +7271,20 @@ namespace Slang
// We'll detect both these incorrect cases here and issue an appropriate diagnostic.
//
auto funcType = as<FuncType>(calleeDeclRef->type);
+ if (!funcType)
+ {
+ // The best candidate does not have a function type.
+ // If we reach here, it means the function is a generic and we can't deduce the
+ // generic arguments from imaginary argument list.
+ // In this case we issue a diagnostic to ask the user to explicitly provide the arguments.
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
+ return;
+ }
+ if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl()))
+ {
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative);
+ return;
+ }
for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii)
{
// Check if the resolved invoke argument type is an error type.
@@ -7511,6 +7532,16 @@ namespace Slang
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
return;
}
+ if (isInterfaceRequirement(calleeFunc))
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative);
+ return;
+ }
+ if (isInterfaceRequirement(funcDecl))
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotUseInterfaceRequirementAsDerivative);
+ return;
+ }
if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
{
@@ -7546,7 +7577,7 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
@@ -7557,7 +7588,7 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr)
@@ -7568,7 +7599,7 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
template<typename TDerivativeAttr, typename TDerivativeOfAttr>
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 656e28701..79889a39d 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -393,7 +393,13 @@ DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot res
DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function")
DIAGNOSTIC(31149, Error, customDerivativeSignatureMismatchAtPosition, "invalid custom derivative. parameter type mismatch at position $0. expected '$1', got '$2'")
DIAGNOSTIC(31150, Error, customDerivativeSignatureMismatch, "invalid custom derivative. could not resolve function with expected signature '$0'")
-
+DIAGNOSTIC(31151, Error, cannotResolveGenericArgumentForDerivativeFunction,
+ "The generic arguments to the derivative function cannot be deduced from the parameter list of the original function. "
+ "Consider using [ForwardDerivative], [BackwardDerivative] or [PrimalSubstitute] attributes on the primal function"
+ " with explicit generic arguments to associate it with a generic derivative function. Note that [ForwardDerivativeOf], "
+ "[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the generic arguments to the derivatives cannot be automatically deduced.")
+DIAGNOSTIC(31152, Error, cannotAssociateInterfaceRequirementWithDerivative, "cannot associate an interface requirement with a derivative.")
+DIAGNOSTIC(31153, Error, cannotUseInterfaceRequirementAsDerivative, "cannot use an interface requirement as a derivative.")
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
// Enums
diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang
new file mode 100644
index 000000000..5f2cd9951
--- /dev/null
+++ b/tests/diagnostics/custom-derivative-generic.slang
@@ -0,0 +1,51 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+interface IFoo
+{
+ static float bar1(float x);
+
+ // CHECK-DAG: {{.*}}(13): error 31152
+ [PrimalSubstitute(bar1)]
+ static float bar(float x);
+
+ static DifferentialPair<float> dd(DifferentialPair<float> x);
+}
+
+__generic<let N:int>
+float f(float x)
+{
+ return N*x*x;
+}
+
+// CHECK-DAG: {{.*}}(26): error 31153
+[ForwardDerivative(IFoo.dd)]
+float bbb(float x);
+
+// CHECK-DAG: {{.*}}(30): error 31152
+[ForwardDerivativeOf(IFoo.bar)]
+DifferentialPair<float> dd1(DifferentialPair<float> x)
+{
+ return x;
+}
+
+// CHECK-DAG: {{.*}}(37): error 31151
+[BackwardDerivative(f)]
+DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut)
+{
+ var primal = x.p * x.p;
+ var diff = 2 * x.p * x.d * N;
+ return DifferentialPair<float>(primal, diff);
+}
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(3.0, 1.0);
+ outputBuffer[1] = __fwd_diff(f<3>)(dpa).d; // Expect: 6.0
+ }
+}