summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-05-04 16:43:40 -0400
committerGitHub <noreply@github.com>2023-05-04 13:43:40 -0700
commita9444925750da2498f456a626f1b164d68efedf1 (patch)
treeffdbf4c28b4ca01f50a74f45ef27834482b209c6 /source/slang/diff.meta.slang
parentab3ac985479132856613d6dcc8b43f4a4ef8c6b7 (diff)
Fix issue with out-of-order insts during type promotion when transposing code. (#2866)
* Bugfixes for warped-area reparameterization * Update slang-ir-autodiff-transpose.h * Update slang-ir-autodiff-transpose.h * Mark all stdlib methods backward differentiable * Update diff.meta.slang --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang27
1 files changed, 25 insertions, 2 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f2f1d0cc3..84a72a425 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -367,6 +367,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(transpose)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M>> m)
{
return DifferentialPair<matrix<T, M, N>>(transpose(m.p), transpose(m.d));
@@ -376,6 +377,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[BackwardDerivativeOf(transpose)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Differential dOut)
{
m = diffPair(m.p, transpose(dOut));
@@ -386,6 +388,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
{
let primal = mul(left.p, right.p);
@@ -396,13 +399,16 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut)
{
vector<T, N>.Differential left_d_result;
matrix<T, N, M>.Differential right_d_result;
+ [ForceUnroll]
for (int i = 0; i < N; ++i)
{
T sum = T(0);
+ [ForceUnroll]
for (int j = 0; j < M; ++j)
{
sum += right.p[i][j] * dOut[j];
@@ -419,6 +425,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right)
{
let primal = mul(left.p, right.p);
@@ -429,13 +436,16 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut)
{
matrix<T, N, M>.Differential left_d_result;
vector<T, M>.Differential right_d_result;
+ [ForceUnroll]
for (int j = 0; j < M; ++j)
{
T sum = T(0);
+ [ForceUnroll]
for (int i = 0; i < N; ++i)
{
sum += left.p[i][j] * dOut[i];
@@ -452,6 +462,7 @@ __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right)
{
let primal = mul(left.p, right.p);
@@ -462,22 +473,30 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, Differ
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[BackwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut)
{
matrix<T, R, N>.Differential left_d_result;
+ [ForceUnroll]
for (int r = 0; r < R; ++r)
+ [ForceUnroll]
for (int n = 0; n < N; ++n)
left_d_result[r][n] = T(0.0);
-
+
matrix<T, N, C>.Differential right_d_result;
+ [ForceUnroll]
for (int n = 0; n < N; ++n)
+ [ForceUnroll]
for (int c = 0; c < C; ++c)
right_d_result[n][c] = T(0.0);
-
+
+ [ForceUnroll]
for (int r = 0; r < R; ++r)
{
+ [ForceUnroll]
for (int c = 0; c < C; ++c)
{
+ [ForceUnroll]
for (int n = 0; n < N; ++n)
{
left_d_result[r][n] += right.p[n][c] * dOut[r][c];
@@ -493,6 +512,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(dot)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
T result = T(0);
@@ -510,6 +530,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(dot)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
{
vector<T, N>.Differential x_d_result, y_d_result;
@@ -527,6 +548,7 @@ void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<ve
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(cross)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, DifferentialPair<vector<T, 3>> b)
{
/*
@@ -555,6 +577,7 @@ DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, Diffe
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(cross)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<vector<T, 3>> b, vector<T, 3>.Differential dOut)
{
/*