From b40b711f54748145ed1340f2a3aa626dcb42b699 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 21 Jul 2023 16:28:22 -0400 Subject: Fix data-flow analysis not propagating diff property through differentiable calls (#3010) * Add test for nodiff diagnostic for non-diff call propagated through diff call * Add logic to disambiguate calls to differentiable and non-differentiable methods * Add expected results for test * Simplify test * Update slang-ir-check-differentiability.cpp * Added comments for TreatAsDifferentiableExpr flavors --------- Co-authored-by: Yong He --- tests/diagnostics/autodiff-data-flow-4.slang | 30 ++++++++++++++++++++++ .../autodiff-data-flow-4.slang.expected | 8 ++++++ 2 files changed, 38 insertions(+) create mode 100644 tests/diagnostics/autodiff-data-flow-4.slang create mode 100644 tests/diagnostics/autodiff-data-flow-4.slang.expected (limited to 'tests') diff --git a/tests/diagnostics/autodiff-data-flow-4.slang b/tests/diagnostics/autodiff-data-flow-4.slang new file mode 100644 index 000000000..08fabf954 --- /dev/null +++ b/tests/diagnostics/autodiff-data-flow-4.slang @@ -0,0 +1,30 @@ +//DIAGNOSTIC_TEST:SIMPLE: + +float nonDiff(float x) +{ + return x; +} + +[Differentiable] +float f(float x) +{ + return x * x; +} + +[Differentiable] +float h(float x) +{ + float val = 0; + + // call to non-differentiable method + // (should error if the data-flow propagation works properly.) + // + float y = nonDiff(x); + + // call to a differentiable method, using the + // result of a non-differentiable call. + // + val = f(y); + + return val; +} diff --git a/tests/diagnostics/autodiff-data-flow-4.slang.expected b/tests/diagnostics/autodiff-data-flow-4.slang.expected new file mode 100644 index 000000000..69f5d8707 --- /dev/null +++ b/tests/diagnostics/autodiff-data-flow-4.slang.expected @@ -0,0 +1,8 @@ +result code = -1 +standard error = { +tests/diagnostics/autodiff-data-flow-4.slang(29): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `nonDiff`, use 'no_diff' to clarify intention. + float y = nonDiff(x); + ^ +} +standard output = { +} -- cgit v1.2.3