diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-03 16:44:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-03 16:44:33 -0800 |
| commit | 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (patch) | |
| tree | ff357f4aaed2dab25ae9e3665a97a7f3e6be32ef /source/slang/diff.meta.slang | |
| parent | ee49a62083d28353812185fd0f0c04fb50ca6be0 (diff) | |
Overhaul `transposeParameterBlock` to support `inout` params. (#2621)
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* More bug fixes.,
* Add test coverage on inout param.
* Fix language server hinting for transcribed mutable params.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index adbf8ae48..055c44135 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -313,6 +313,24 @@ DifferentialPair<vector<T, N>> __d_sin_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(sin)] +void __d_sin(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(cos(dpx.p), dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(sin)] +void __d_sin_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector<T, N>.dmul(cos(dpx.p), dOut)); +} + // Cosine __generic<T : __BuiltinFloatingPointType> @@ -331,6 +349,24 @@ DifferentialPair<vector<T, N>> __d_cos_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(cos)] +void __d_cos(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(-sin(dpx.p), dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(cos)] +void __d_cos_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector<T, N>.dmul(-sin(dpx.p), dOut)); +} + // Base-e logarithm __generic<T : __BuiltinFloatingPointType> |
