summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-10-28 15:47:58 -0400
committerGitHub <noreply@github.com>2024-10-28 15:47:58 -0400
commitb61be5e6fb7fe1c4ec8228cdf73f49f11e5a0ac9 (patch)
tree0e392546b41a55d36a874f2a82110867e6dd422c /source/slang/diff.meta.slang
parent0557a199d2eb205bf133c8fc111cce3a19336fde (diff)
Assorted auto-diff enhancements for increased performance & more streamlined auto-diff results (#5394)
* Various AD enhancements * Fix issue with pt-loop test * Update pt-loop.slang * More fixes for perf. Final minimal context test now passes. * Fix issue with loop-elimination pass not running after dce * Try fix wgpu test by removing select operator * Disable wgpu * Delete out.wgsl * Remove comments * Update slang-ir-util.cpp * Fix header relative paths for slang-embed * Disbale wgpu for a few other tests * Better way of determining which params to ignore for side-effects * Update slang-ir-dce.cpp * Fix issue with circular reference from previous AD pass being left behind for the next AD pass * Update slang-ir-dce.cpp
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang9
1 files changed, 9 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index b39d91494..6042ff5cc 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -855,17 +855,20 @@ struct DiffTensorView
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
DifferentialPair<T> __load_forward(vector<uint, N> x)
{
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T, N>(x)));
}
+ [ForceInline]
void __load_backward(uint x, T.Differential dOut)
{
diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
+ [ForceInline]
__generic<let N : int>
void __load_backward(vector<uint, N> x, T.Differential dOut)
{
@@ -894,11 +897,13 @@ struct DiffTensorView
diff.store_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
+ [ForceInline]
void __store_backward(uint x, inout DifferentialPair<T> dpval)
{
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
void __store_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{
@@ -999,11 +1004,13 @@ struct DiffTensorView
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T, N>(x)));
}
+ [ForceInline]
void __loadOnce_backward(uint x, T.Differential dOut)
{
diff.loadOnce_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
+ [ForceInline]
__generic<let N : int>
void __loadOnce_backward(vector<uint, N> x, T.Differential dOut)
{
@@ -1032,11 +1039,13 @@ struct DiffTensorView
diff.storeOnce_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
+ [ForceInline]
void __storeOnce_backward(uint x, inout DifferentialPair<T> dpval)
{
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
void __storeOnce_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{