From b61be5e6fb7fe1c4ec8228cdf73f49f11e5a0ac9 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:47:58 -0400 Subject: Assorted auto-diff enhancements for increased performance & more streamlined auto-diff results (#5394) * Various AD enhancements * Fix issue with pt-loop test * Update pt-loop.slang * More fixes for perf. Final minimal context test now passes. * Fix issue with loop-elimination pass not running after dce * Try fix wgpu test by removing select operator * Disable wgpu * Delete out.wgsl * Remove comments * Update slang-ir-util.cpp * Fix header relative paths for slang-embed * Disbale wgpu for a few other tests * Better way of determining which params to ignore for side-effects * Update slang-ir-dce.cpp * Fix issue with circular reference from previous AD pass being left behind for the next AD pass * Update slang-ir-dce.cpp --- source/slang/diff.meta.slang | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index b39d91494..6042ff5cc 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -855,17 +855,20 @@ struct DiffTensorView return diffPair(primal.load(x), reinterpret(diff.load_forward(x))); } + [ForceInline] __generic DifferentialPair __load_forward(vector x) { return diffPair(primal.load(x), reinterpret(diff.load_forward(x))); } + [ForceInline] void __load_backward(uint x, T.Differential dOut) { diff.load_backward(x, reinterpret(dOut)); } + [ForceInline] __generic void __load_backward(vector x, T.Differential dOut) { @@ -894,11 +897,13 @@ struct DiffTensorView diff.store_forward(x, reinterpret(dpval.d)); } + [ForceInline] void __store_backward(uint x, inout DifferentialPair dpval) { dpval = diffPair(dpval.p, reinterpret(diff.store_backward(x))); } + [ForceInline] __generic void __store_backward(vector x, inout DifferentialPair dpval) { @@ -999,11 +1004,13 @@ struct DiffTensorView return diffPair(primal.load(x), reinterpret(diff.loadOnce_forward(x))); } + [ForceInline] void __loadOnce_backward(uint x, T.Differential dOut) { diff.loadOnce_backward(x, reinterpret(dOut)); } + [ForceInline] __generic void __loadOnce_backward(vector x, T.Differential dOut) { @@ -1032,11 +1039,13 @@ struct DiffTensorView diff.storeOnce_forward(x, reinterpret(dpval.d)); } + [ForceInline] void __storeOnce_backward(uint x, inout DifferentialPair dpval) { dpval = diffPair(dpval.p, reinterpret(diff.storeOnce_backward(x))); } + [ForceInline] __generic void __storeOnce_backward(vector x, inout DifferentialPair dpval) { -- cgit v1.2.3