From d58e08f8237a1888ceaad53402d534679ea83b1a Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 18 Nov 2022 12:37:27 -0800 Subject: Data flow validation pass for diagnosing derivative loss. (#2523) --- tests/diagnostics/autodiff-data-flow.slang | 38 ++++++++++++++++++++++ .../diagnostics/autodiff-data-flow.slang.expected | 11 +++++++ tests/diagnostics/autodiff.slang | 27 +++++++++++++++ tests/diagnostics/autodiff.slang.expected | 14 ++++++++ 4 files changed, 90 insertions(+) create mode 100644 tests/diagnostics/autodiff-data-flow.slang create mode 100644 tests/diagnostics/autodiff-data-flow.slang.expected create mode 100644 tests/diagnostics/autodiff.slang create mode 100644 tests/diagnostics/autodiff.slang.expected (limited to 'tests/diagnostics') diff --git a/tests/diagnostics/autodiff-data-flow.slang b/tests/diagnostics/autodiff-data-flow.slang new file mode 100644 index 000000000..93c76c07e --- /dev/null +++ b/tests/diagnostics/autodiff-data-flow.slang @@ -0,0 +1,38 @@ +//DIAGNOSTIC_TEST:SIMPLE: + +float nonDiff(float x) +{ + return x; +} + +[ForwardDifferentiable] +float f(float x) +{ + float val = 0; + if (x > 5) + val = x + 1; + else + val = nonDiff(x * 2.0f); + // Not all path propagates derivatives through. + return val; +} + +// error: function does not return a differentiable value. +[ForwardDifferentiable] +void g(float x) +{ + float val = 0; + if (x > 5) + val = x + 1; + return; +} + + +[ForwardDifferentiable] +float h(float x) +{ + float val = 0; + // no diagnostic by clarifying intention. + val = no_diff(nonDiff(x * 2.0f)); + return val; +} diff --git a/tests/diagnostics/autodiff-data-flow.slang.expected b/tests/diagnostics/autodiff-data-flow.slang.expected new file mode 100644 index 000000000..869ce42b3 --- /dev/null +++ b/tests/diagnostics/autodiff-data-flow.slang.expected @@ -0,0 +1,11 @@ +result code = -1 +standard error = { +tests/diagnostics/autodiff-data-flow.slang(15): error 41020: derivative cannot be propagated through call to non-differentiable function `nonDiff`, use 'no_diff' to clarify intention. + val = nonDiff(x * 2.0f); + ^ +tests/diagnostics/autodiff-data-flow.slang(22): error 41021: a differentiable function must have at least one differentiable output. +void g(float x) + ^ +} +standard output = { +} diff --git a/tests/diagnostics/autodiff.slang b/tests/diagnostics/autodiff.slang new file mode 100644 index 000000000..f9fed6753 --- /dev/null +++ b/tests/diagnostics/autodiff.slang @@ -0,0 +1,27 @@ +//DIAGNOSTIC_TEST:SIMPLE: + +float nonDiff(float x) +{ + return x; +} + +[ForwardDifferentiable] +float f(float x) +{ + float val = 0; + if (x > 5) + val = x + 1; + return val; +} + +[ForwardDifferentiable] +float m(float x) +{ + float x1 = no_diff x; // invalid use of no_diff here. + return no_diff f(x); // no_diff on a differentiable call has no meaning. +} + +float n(float x) +{ + return no_diff nonDiff(x); // no_diff in a non-differentiable function +} \ No newline at end of file diff --git a/tests/diagnostics/autodiff.slang.expected b/tests/diagnostics/autodiff.slang.expected new file mode 100644 index 000000000..cd97bce76 --- /dev/null +++ b/tests/diagnostics/autodiff.slang.expected @@ -0,0 +1,14 @@ +result code = -1 +standard error = { +tests/diagnostics/autodiff.slang(20): error 38031: 'no_diff' can only be used to decorate a call. + float x1 = no_diff x; // invalid use of no_diff here. + ^~~~~~~ +tests/diagnostics/autodiff.slang(21): error 38032: use 'no_diff' on a call to a differentiable function has no meaning. + return no_diff f(x); // no_diff on a differentiable call has no meaning. + ^~~~~~~ +tests/diagnostics/autodiff.slang(26): error 38033: cannot use 'no_diff' in a non-differentiable function. + return no_diff nonDiff(x); // no_diff in a non-differentiable function + ^~~~~~~ +} +standard output = { +} -- cgit v1.2.3