summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-18 12:37:27 -0800
committerGitHub <noreply@github.com>2022-11-18 12:37:27 -0800
commitd58e08f8237a1888ceaad53402d534679ea83b1a (patch)
treee66838e0dc31fc12ebd7c1acecbb5060e8808366 /tests
parent0a050a439fa91b66f2020421d4fec3e60aed4112 (diff)
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/differential-method-synthesis.slang2
-rw-r--r--tests/diagnostics/autodiff-data-flow.slang38
-rw-r--r--tests/diagnostics/autodiff-data-flow.slang.expected11
-rw-r--r--tests/diagnostics/autodiff.slang27
-rw-r--r--tests/diagnostics/autodiff.slang.expected14
5 files changed, 91 insertions, 1 deletions
diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang
index 4c96779f9..3220976e7 100644
--- a/tests/autodiff/differential-method-synthesis.slang
+++ b/tests/autodiff/differential-method-synthesis.slang
@@ -31,7 +31,7 @@ A f(A a)
aout.y = 2 * a.b.x;
aout.b.x = 5 * a.b.x;
- return nonDiff(aout);
+ return no_diff(nonDiff(aout));
}
[numthreads(1, 1, 1)]
diff --git a/tests/diagnostics/autodiff-data-flow.slang b/tests/diagnostics/autodiff-data-flow.slang
new file mode 100644
index 000000000..93c76c07e
--- /dev/null
+++ b/tests/diagnostics/autodiff-data-flow.slang
@@ -0,0 +1,38 @@
+//DIAGNOSTIC_TEST:SIMPLE:
+
+float nonDiff(float x)
+{
+ return x;
+}
+
+[ForwardDifferentiable]
+float f(float x)
+{
+ float val = 0;
+ if (x > 5)
+ val = x + 1;
+ else
+ val = nonDiff(x * 2.0f);
+ // Not all path propagates derivatives through.
+ return val;
+}
+
+// error: function does not return a differentiable value.
+[ForwardDifferentiable]
+void g(float x)
+{
+ float val = 0;
+ if (x > 5)
+ val = x + 1;
+ return;
+}
+
+
+[ForwardDifferentiable]
+float h(float x)
+{
+ float val = 0;
+ // no diagnostic by clarifying intention.
+ val = no_diff(nonDiff(x * 2.0f));
+ return val;
+}
diff --git a/tests/diagnostics/autodiff-data-flow.slang.expected b/tests/diagnostics/autodiff-data-flow.slang.expected
new file mode 100644
index 000000000..869ce42b3
--- /dev/null
+++ b/tests/diagnostics/autodiff-data-flow.slang.expected
@@ -0,0 +1,11 @@
+result code = -1
+standard error = {
+tests/diagnostics/autodiff-data-flow.slang(15): error 41020: derivative cannot be propagated through call to non-differentiable function `nonDiff`, use 'no_diff' to clarify intention.
+ val = nonDiff(x * 2.0f);
+ ^
+tests/diagnostics/autodiff-data-flow.slang(22): error 41021: a differentiable function must have at least one differentiable output.
+void g(float x)
+ ^
+}
+standard output = {
+}
diff --git a/tests/diagnostics/autodiff.slang b/tests/diagnostics/autodiff.slang
new file mode 100644
index 000000000..f9fed6753
--- /dev/null
+++ b/tests/diagnostics/autodiff.slang
@@ -0,0 +1,27 @@
+//DIAGNOSTIC_TEST:SIMPLE:
+
+float nonDiff(float x)
+{
+ return x;
+}
+
+[ForwardDifferentiable]
+float f(float x)
+{
+ float val = 0;
+ if (x > 5)
+ val = x + 1;
+ return val;
+}
+
+[ForwardDifferentiable]
+float m(float x)
+{
+ float x1 = no_diff x; // invalid use of no_diff here.
+ return no_diff f(x); // no_diff on a differentiable call has no meaning.
+}
+
+float n(float x)
+{
+ return no_diff nonDiff(x); // no_diff in a non-differentiable function
+} \ No newline at end of file
diff --git a/tests/diagnostics/autodiff.slang.expected b/tests/diagnostics/autodiff.slang.expected
new file mode 100644
index 000000000..cd97bce76
--- /dev/null
+++ b/tests/diagnostics/autodiff.slang.expected
@@ -0,0 +1,14 @@
+result code = -1
+standard error = {
+tests/diagnostics/autodiff.slang(20): error 38031: 'no_diff' can only be used to decorate a call.
+ float x1 = no_diff x; // invalid use of no_diff here.
+ ^~~~~~~
+tests/diagnostics/autodiff.slang(21): error 38032: use 'no_diff' on a call to a differentiable function has no meaning.
+ return no_diff f(x); // no_diff on a differentiable call has no meaning.
+ ^~~~~~~
+tests/diagnostics/autodiff.slang(26): error 38033: cannot use 'no_diff' in a non-differentiable function.
+ return no_diff nonDiff(x); // no_diff in a non-differentiable function
+ ^~~~~~~
+}
+standard output = {
+}