From a670bafc121c20168624f70a388dbe8556402c7f Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:25:29 -0500 Subject: no_diff diagnostics improvement (#7655) close #6286. This PR is to improve the diagnostics for no_diff usage. In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute. This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR. In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning. This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses. --- ...to-nodiff-function-diagnostic-improvement.slang | 40 ++++++++++++++++++ ...o-nodiff-function-diagnostic-improvement1.slang | 48 ++++++++++++++++++++++ tests/diagnostics/force-no-diff-this.slang | 42 +++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100644 tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang create mode 100644 tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang create mode 100644 tests/diagnostics/force-no-diff-this.slang (limited to 'tests/diagnostics') diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang new file mode 100644 index 000000000..961ac75d5 --- /dev/null +++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang @@ -0,0 +1,40 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +float someNoDiffFunc(float x, no_diff float y) +{ + return x * x + y * y; +} + +// Previously, when we call a no-diff function side a differntiable function, we will have to use no_diff to tell compiler that this is intended. +// However, if the parameter is just a constant, there is no need to use no_diff, because constant won't carry any derivative information. +// Therefore, this test is to check we won't report any error when the parameter is a constant in this case. +[Differentiable] +float eval(float x) +{ + // CHECK-NOT: ([[# @LINE+1]]): error 41020 + return exp(x) - someNoDiffFunc(1.0f, x); +} + +[Differentiable] +float eval1(float x) +{ + // CHECK: ([[# @LINE+1]]): error 41020 + return exp(x) - someNoDiffFunc(x, 1.0); +} + +RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + var x = diffPair(2.0f); + bwd_diff(eval)(x, 1.0f); + + output[0] = x.d; + + var x1 = diffPair(2.0f); + bwd_diff(eval1)(x1, 1.0f); + output[1] = x1.d; +} + diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang new file mode 100644 index 000000000..f27c6ec6b --- /dev/null +++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang @@ -0,0 +1,48 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + + +// Similar to const-to-nodiff-function-diagnostic-improvement.slang, but with a CoopVec type +// to reproduce a more realistic scenario. +extension CoopVec : IDifferentiable +{ + typealias Differential = CoopVec; +}; + +[BackwardDerivativeOf(exp)] +void exp_BackwardAutoDiff(inout DifferentialPair> p0, CoopVec.Differential dResult) +{ + p0 = diffPair(p0.p, dResult * exp(p0.p)); +} + +[Differentiable] +CoopVec eval(CoopVec x) +{ + // CHECK-NOT: ([[# @LINE+1]]): error 41020 + return exp(x) - CoopVec(1.); +} + +[Differentiable] +CoopVec eval1(CoopVec x) +{ + // test.slang(25): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `CoopVec.$init`, use 'no_diff' to clarify intention. + // CHECK: ([[# @LINE+1]]): error 41020 + return exp(x) - CoopVec(x[0]); +} + + +RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + var x = diffPair(CoopVec(2.0f), CoopVec(1.0f)); + bwd_diff(eval)(x, CoopVec(1.0f)); + + output[0] = x.d[0]; + + var x1 = diffPair(CoopVec(2.0f), CoopVec(1.0f)); + bwd_diff(eval1)(x1, CoopVec(1.0f)); + output[1] = x1.d[1]; +} + diff --git a/tests/diagnostics/force-no-diff-this.slang b/tests/diagnostics/force-no-diff-this.slang new file mode 100644 index 000000000..ae1464ffb --- /dev/null +++ b/tests/diagnostics/force-no-diff-this.slang @@ -0,0 +1,42 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +struct MyStruct where T: __BuiltinFloatingPointType +{ + float a; + __init(float a) { this.a = a;} + + [Differentiable] + T eval(T x) + { + //CHECK: ([[# @LINE+1]]): warning 31159 + return exp(x * T(a) * T(a)); + } + + [Differentiable] + [NoDiffThis] + T eval1(T x) + { + //CHECK-NOT: ([[# @LINE+1]]): warning 31159 + return exp(x * T(a) * T(a)); + } +}; + +[Differentiable] +float evalFunc(float x) +{ + MyStruct s = {x}; + return s.eval(x) + s.eval1(x); +} + +RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + var x = diffPair(2.0f); + bwd_diff(evalFunc)(x, 1.0f); + + output[0] = x.d; +} + -- cgit v1.2.3