diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-10-28 15:47:58 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-28 15:47:58 -0400 |
| commit | b61be5e6fb7fe1c4ec8228cdf73f49f11e5a0ac9 (patch) | |
| tree | 0e392546b41a55d36a874f2a82110867e6dd422c /source/slang/diff.meta.slang | |
| parent | 0557a199d2eb205bf133c8fc111cce3a19336fde (diff) | |
Assorted auto-diff enhancements for increased performance & more streamlined auto-diff results (#5394)
* Various AD enhancements
* Fix issue with pt-loop test
* Update pt-loop.slang
* More fixes for perf. Final minimal context test now passes.
* Fix issue with loop-elimination pass not running after dce
* Try fix wgpu test by removing select operator
* Disable wgpu
* Delete out.wgsl
* Remove comments
* Update slang-ir-util.cpp
* Fix header relative paths for slang-embed
* Disbale wgpu for a few other tests
* Better way of determining which params to ignore for side-effects
* Update slang-ir-dce.cpp
* Fix issue with circular reference from previous AD pass being left behind for the next AD pass
* Update slang-ir-dce.cpp
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index b39d91494..6042ff5cc 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -855,17 +855,20 @@ struct DiffTensorView return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x))); } + [ForceInline] __generic<let N : int> DifferentialPair<T> __load_forward(vector<uint, N> x) { return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T, N>(x))); } + [ForceInline] void __load_backward(uint x, T.Differential dOut) { diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut)); } + [ForceInline] __generic<let N : int> void __load_backward(vector<uint, N> x, T.Differential dOut) { @@ -894,11 +897,13 @@ struct DiffTensorView diff.store_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d)); } + [ForceInline] void __store_backward(uint x, inout DifferentialPair<T> dpval) { dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x))); } + [ForceInline] __generic<let N : int> void __store_backward(vector<uint, N> x, inout DifferentialPair<T> dpval) { @@ -999,11 +1004,13 @@ struct DiffTensorView return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T, N>(x))); } + [ForceInline] void __loadOnce_backward(uint x, T.Differential dOut) { diff.loadOnce_backward<T>(x, reinterpret<T, T.Differential>(dOut)); } + [ForceInline] __generic<let N : int> void __loadOnce_backward(vector<uint, N> x, T.Differential dOut) { @@ -1032,11 +1039,13 @@ struct DiffTensorView diff.storeOnce_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d)); } + [ForceInline] void __storeOnce_backward(uint x, inout DifferentialPair<T> dpval) { dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T>(x))); } + [ForceInline] __generic<let N : int> void __storeOnce_backward(vector<uint, N> x, inout DifferentialPair<T> dpval) { |
