diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-18 12:37:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-18 12:37:27 -0800 |
| commit | d58e08f8237a1888ceaad53402d534679ea83b1a (patch) | |
| tree | e66838e0dc31fc12ebd7c1acecbb5060e8808366 /tests | |
| parent | 0a050a439fa91b66f2020421d4fec3e60aed4112 (diff) | |
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/differential-method-synthesis.slang | 2 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow.slang | 38 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow.slang.expected | 11 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff.slang | 27 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff.slang.expected | 14 |
5 files changed, 91 insertions, 1 deletions
diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 4c96779f9..3220976e7 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -31,7 +31,7 @@ A f(A a) aout.y = 2 * a.b.x; aout.b.x = 5 * a.b.x; - return nonDiff(aout); + return no_diff(nonDiff(aout)); } [numthreads(1, 1, 1)] diff --git a/tests/diagnostics/autodiff-data-flow.slang b/tests/diagnostics/autodiff-data-flow.slang new file mode 100644 index 000000000..93c76c07e --- /dev/null +++ b/tests/diagnostics/autodiff-data-flow.slang @@ -0,0 +1,38 @@ +//DIAGNOSTIC_TEST:SIMPLE: + +float nonDiff(float x) +{ + return x; +} + +[ForwardDifferentiable] +float f(float x) +{ + float val = 0; + if (x > 5) + val = x + 1; + else + val = nonDiff(x * 2.0f); + // Not all path propagates derivatives through. + return val; +} + +// error: function does not return a differentiable value. +[ForwardDifferentiable] +void g(float x) +{ + float val = 0; + if (x > 5) + val = x + 1; + return; +} + + +[ForwardDifferentiable] +float h(float x) +{ + float val = 0; + // no diagnostic by clarifying intention. + val = no_diff(nonDiff(x * 2.0f)); + return val; +} diff --git a/tests/diagnostics/autodiff-data-flow.slang.expected b/tests/diagnostics/autodiff-data-flow.slang.expected new file mode 100644 index 000000000..869ce42b3 --- /dev/null +++ b/tests/diagnostics/autodiff-data-flow.slang.expected @@ -0,0 +1,11 @@ +result code = -1 +standard error = { +tests/diagnostics/autodiff-data-flow.slang(15): error 41020: derivative cannot be propagated through call to non-differentiable function `nonDiff`, use 'no_diff' to clarify intention. + val = nonDiff(x * 2.0f); + ^ +tests/diagnostics/autodiff-data-flow.slang(22): error 41021: a differentiable function must have at least one differentiable output. +void g(float x) + ^ +} +standard output = { +} diff --git a/tests/diagnostics/autodiff.slang b/tests/diagnostics/autodiff.slang new file mode 100644 index 000000000..f9fed6753 --- /dev/null +++ b/tests/diagnostics/autodiff.slang @@ -0,0 +1,27 @@ +//DIAGNOSTIC_TEST:SIMPLE: + +float nonDiff(float x) +{ + return x; +} + +[ForwardDifferentiable] +float f(float x) +{ + float val = 0; + if (x > 5) + val = x + 1; + return val; +} + +[ForwardDifferentiable] +float m(float x) +{ + float x1 = no_diff x; // invalid use of no_diff here. + return no_diff f(x); // no_diff on a differentiable call has no meaning. +} + +float n(float x) +{ + return no_diff nonDiff(x); // no_diff in a non-differentiable function +}
\ No newline at end of file diff --git a/tests/diagnostics/autodiff.slang.expected b/tests/diagnostics/autodiff.slang.expected new file mode 100644 index 000000000..cd97bce76 --- /dev/null +++ b/tests/diagnostics/autodiff.slang.expected @@ -0,0 +1,14 @@ +result code = -1 +standard error = { +tests/diagnostics/autodiff.slang(20): error 38031: 'no_diff' can only be used to decorate a call. + float x1 = no_diff x; // invalid use of no_diff here. + ^~~~~~~ +tests/diagnostics/autodiff.slang(21): error 38032: use 'no_diff' on a call to a differentiable function has no meaning. + return no_diff f(x); // no_diff on a differentiable call has no meaning. + ^~~~~~~ +tests/diagnostics/autodiff.slang(26): error 38033: cannot use 'no_diff' in a non-differentiable function. + return no_diff nonDiff(x); // no_diff in a non-differentiable function + ^~~~~~~ +} +standard output = { +} |
