summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-07-09 11:25:29 -0500
committerGitHub <noreply@github.com>2025-07-09 09:25:29 -0700
commita670bafc121c20168624f70a388dbe8556402c7f (patch)
tree79b48a80e7abc0744193716e400bb57a6c026bad /source/slang/diff.meta.slang
parenta7cb36901ccaf8297136c58c1451d6e04420af73 (diff)
no_diff diagnostics improvement (#7655)
close #6286. This PR is to improve the diagnostics for no_diff usage. In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute. This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR. In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning. This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang8
1 files changed, 8 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 542983049..b22d91595 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -992,10 +992,12 @@ struct DiffTensorView
[BackwardDerivative(__load_backward)]
[ForwardDerivative(__load_forward)]
+ [NoDiffThis]
T load(uint i) { return primal.load(i); }
[BackwardDerivative(__load_backward)]
[ForwardDerivative(__load_forward)]
+ [NoDiffThis]
__generic<let N : int>
T load(vector<uint, N> i) { return primal.load(i); }
@@ -1026,10 +1028,12 @@ struct DiffTensorView
[BackwardDerivative(__store_backward)]
[ForwardDerivative(__store_forward)]
+ [NoDiffThis]
void store(uint x, T val) { primal.store(x, val); }
[BackwardDerivative(__store_backward)]
[ForwardDerivative(__store_forward)]
+ [NoDiffThis]
__generic<let N : int>
void store(vector<uint, N> x, T val) { primal.store(x, val); }
@@ -1135,10 +1139,12 @@ struct DiffTensorView
[BackwardDerivative(__loadOnce_backward)]
[ForwardDerivative(__loadOnce_forward)]
+ [NoDiffThis]
T loadOnce(uint i) { return primal.load(i); }
[BackwardDerivative(__loadOnce_backward)]
[ForwardDerivative(__loadOnce_forward)]
+ [NoDiffThis]
__generic<let N : int>
T loadOnce(vector<uint, N> i) { return primal.load(i); }
@@ -1168,10 +1174,12 @@ struct DiffTensorView
[BackwardDerivative(__storeOnce_backward)]
[ForwardDerivative(__storeOnce_forward)]
+ [NoDiffThis]
void storeOnce(uint x, T val) { primal.store(x, val); }
[BackwardDerivative(__storeOnce_backward)]
[ForwardDerivative(__storeOnce_forward)]
+ [NoDiffThis]
__generic<let N : int>
void storeOnce(vector<uint, N> x, T val) { primal.store(x, val); }