From a9444925750da2498f456a626f1b164d68efedf1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 4 May 2023 16:43:40 -0400 Subject: Fix issue with out-of-order insts during type promotion when transposing code. (#2866) * Bugfixes for warped-area reparameterization * Update slang-ir-autodiff-transpose.h * Update slang-ir-autodiff-transpose.h * Mark all stdlib methods backward differentiable * Update diff.meta.slang --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f2f1d0cc3..84a72a425 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -367,6 +367,7 @@ __generic [ForceInline] [ForwardDerivativeOf(transpose)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair> __d_transpose(DifferentialPair> m) { return DifferentialPair>(transpose(m.p), transpose(m.d)); @@ -376,6 +377,7 @@ __generic [ForceInline] [BackwardDerivativeOf(transpose)] [PreferRecompute] +[BackwardDifferentiable] void __d_transpose(inout DifferentialPair> m, matrix.Differential dOut) { m = diffPair(m.p, transpose(dOut)); @@ -386,6 +388,7 @@ __generic [ForceInline] [ForwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) { let primal = mul(left.p, right.p); @@ -396,13 +399,16 @@ DifferentialPair> mul(DifferentialPair> left, Differen __generic [BackwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] void __d_mul(inout DifferentialPair> left, inout DifferentialPair> right, vector.Differential dOut) { vector.Differential left_d_result; matrix.Differential right_d_result; + [ForceUnroll] for (int i = 0; i < N; ++i) { T sum = T(0); + [ForceUnroll] for (int j = 0; j < M; ++j) { sum += right.p[i][j] * dOut[j]; @@ -419,6 +425,7 @@ __generic [ForceInline] [ForwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) { let primal = mul(left.p, right.p); @@ -429,13 +436,16 @@ DifferentialPair> mul(DifferentialPair> left, Differen __generic [BackwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] void __d_mul(inout DifferentialPair> left, inout DifferentialPair> right, vector.Differential dOut) { matrix.Differential left_d_result; vector.Differential right_d_result; + [ForceUnroll] for (int j = 0; j < M; ++j) { T sum = T(0); + [ForceUnroll] for (int i = 0; i < N; ++i) { sum += left.p[i][j] * dOut[i]; @@ -452,6 +462,7 @@ __generic [ForceInline] [ForwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) { let primal = mul(left.p, right.p); @@ -462,22 +473,30 @@ DifferentialPair> mul(DifferentialPair> left, Differ __generic [BackwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] void mul(inout DifferentialPair> left, inout DifferentialPair> right, matrix.Differential dOut) { matrix.Differential left_d_result; + [ForceUnroll] for (int r = 0; r < R; ++r) + [ForceUnroll] for (int n = 0; n < N; ++n) left_d_result[r][n] = T(0.0); - + matrix.Differential right_d_result; + [ForceUnroll] for (int n = 0; n < N; ++n) + [ForceUnroll] for (int c = 0; c < C; ++c) right_d_result[n][c] = T(0.0); - + + [ForceUnroll] for (int r = 0; r < R; ++r) { + [ForceUnroll] for (int c = 0; c < C; ++c) { + [ForceUnroll] for (int n = 0; n < N; ++n) { left_d_result[r][n] += right.p[n][c] * dOut[r][c]; @@ -493,6 +512,7 @@ void mul(inout DifferentialPair> left, inout DifferentialPair [ForwardDerivativeOf(dot)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair __d_dot(DifferentialPair> dpx, DifferentialPair> dpy) { T result = T(0); @@ -510,6 +530,7 @@ DifferentialPair __d_dot(DifferentialPair> dpx, DifferentialPair __generic [BackwardDerivativeOf(dot)] [PreferRecompute] +[BackwardDifferentiable] void __d_dot(inout DifferentialPair> dpx, inout DifferentialPair> dpy, T.Differential dOut) { vector.Differential x_d_result, y_d_result; @@ -527,6 +548,7 @@ void __d_dot(inout DifferentialPair> dpx, inout DifferentialPair [ForwardDerivativeOf(cross)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair> __d_cross(DifferentialPair> a, DifferentialPair> b) { /* @@ -555,6 +577,7 @@ DifferentialPair> __d_cross(DifferentialPair> a, Diffe __generic [BackwardDerivativeOf(cross)] [PreferRecompute] +[BackwardDifferentiable] void __d_cross(inout DifferentialPair> a, inout DifferentialPair> b, vector.Differential dOut) { /* -- cgit v1.2.3