diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-24 19:50:51 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-24 16:50:51 -0700 |
| commit | 7292edbd3eba3da7e8490ad19169a7d18283057a (patch) | |
| tree | b49fb1ba6a76d9775f788057d91b22b88b4fc19c /source/slang | |
| parent | e794de0d63e6de9be564c971fd40486ecf631293 (diff) | |
Added `[BackwardDifferentiable]` tags for intrinsic + builtin methods (#2732)
* Added higher-order differentiability decorators for built-ins + preliminary tests
* Update diff.meta.slang
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 58 |
1 files changed, 53 insertions, 5 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 26a673512..bbe94dbc2 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -292,6 +292,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define VECTOR_MATRIX_BINARY_DIFF_IMPL(NAME) \ __generic<T : __BuiltinFloatingPointType, let N : int> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector( \ DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) \ @@ -309,6 +310,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<vector<T, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \ @@ -327,6 +329,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<matrix<T, M, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, \ @@ -346,6 +349,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpy = diffPair(dpy.p, right_d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ inout DifferentialPair<matrix<T, M, N>> dpx, \ @@ -368,6 +372,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \ __generic<T : __BuiltinFloatingPointType, let N : int> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector( \ DifferentialPair<vector<T, N>> dpx, \ @@ -388,8 +393,9 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<vector<T, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ - DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ DifferentialPair<matrix<T, M, N>> dpx, \ DifferentialPair<matrix<T, M, N>> dpy, \ DifferentialPair<matrix<T, M, N>> dpz) \ @@ -409,6 +415,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<matrix<T, M, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, \ @@ -433,6 +440,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpz = diffPair(dpz.p, right_d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ inout DifferentialPair<matrix<T, M, N>> dpx, \ @@ -460,12 +468,14 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ __generic<T : __BuiltinFloatingPointType> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \ { \ return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \ { \ @@ -481,6 +491,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<vector<T, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpx) \ { \ @@ -498,12 +509,14 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<matrix<T, M, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \ { \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ @@ -518,6 +531,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ inout DifferentialPair<matrix<T, M, N>> dpx, matrix<T, M, N>.Differential dOut) \ @@ -581,6 +595,7 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(atan, T(1.0) / (T(1.0) + dpx.p * dpx.p)) // Atan2 __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(atan2)] DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx) { @@ -592,6 +607,7 @@ DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx) } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(atan2)] void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut) { @@ -603,12 +619,14 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2) // fmod __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(fmod)] DifferentialPair<T> __d_fmod(DifferentialPair<T> x, DifferentialPair<T> y) { return DifferentialPair<T>(fmod(x.p, y.p), x.d); } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(fmod)] void __d_fmod(inout DifferentialPair<T> x, inout DifferentialPair<T> y, T.Differential dOut) { @@ -619,6 +637,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(fmod) // Raise to a power __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(pow)] DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { @@ -638,6 +657,7 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(pow)] void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { @@ -663,6 +683,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(pow) // Maximum __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(max)] DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { @@ -673,6 +694,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy) } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(max)] void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { @@ -684,6 +706,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(max) // Minimum __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(min)] DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { @@ -694,6 +717,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy) } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(min)] void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { @@ -705,6 +729,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(min) // Lerp __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(lerp)] DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dps) { @@ -714,6 +739,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D ); } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(lerp)] void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut) { @@ -725,6 +751,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp) // Clamp __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(clamp)] DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin, DifferentialPair<T> dpMax) { @@ -733,6 +760,7 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin dpx.p < dpMin.p ? (dpx.p > dpMax.p ? dpMax.d : dpx.d) : dpMin.d); } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(clamp)] void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut) { @@ -743,6 +771,7 @@ void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, i VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp) // fma +[BackwardDifferentiable] [ForwardDerivativeOf(fma)] DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<double> dpy, DifferentialPair<double> dpz) { @@ -750,6 +779,7 @@ DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair< fma(dpx.p, dpy.p, dpz.p), dpy.p * dpx.d + dpx.p * dpy.d + dpz.d); } +[BackwardDifferentiable] [BackwardDerivativeOf(fma)] void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpy, inout DifferentialPair<double> dpz, double dOut) { @@ -758,6 +788,7 @@ void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpz = diffPair(dpz.p, dOut); } __generic<let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(fma)] DifferentialPair<vector<double, N>> __d_fma_vector( DifferentialPair<vector<double, N>> dpx, @@ -778,6 +809,7 @@ DifferentialPair<vector<double, N>> __d_fma_vector( return DifferentialPair<vector<double, N>>(result, d_result); } __generic<let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(fma)] void __d_fma_vector( inout DifferentialPair<vector<double, N>> dpx, @@ -803,6 +835,7 @@ void __d_fma_vector( // mad __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(mad)] DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dpz) { @@ -811,6 +844,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d)); } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(mad)] void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut) { @@ -829,12 +863,14 @@ T __smoothstep_impl(T minVal, T maxVal, T x) return t * t * (T(3.0) - T(2.0) * t); } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [ForwardDerivativeOf(smoothstep)] DifferentialPair<T> __d_smoothstep(DifferentialPair<T> minVal, DifferentialPair<T> maxVal, DifferentialPair<T> x) { return __fwd_diff(__smoothstep_impl)(minVal, maxVal, x); } __generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] [BackwardDerivativeOf(smoothstep)] void __d_smoothstep(inout DifferentialPair<T> minVal, inout DifferentialPair<T> maxVal, inout DifferentialPair<T> x, T.Differential dOut) { @@ -856,6 +892,7 @@ T __length_impl(vector<T, N> x) } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(length)] [ForceInline] DifferentialPair<T> __d_length(DifferentialPair<vector<T, N>> x) @@ -864,6 +901,7 @@ DifferentialPair<T> __d_length(DifferentialPair<vector<T, N>> x) } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(length)] [ForceInline] void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut) @@ -872,13 +910,14 @@ void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut) } // Vector distance -__generic<T : __BuiltinFloatingPointType, let N : int> +__generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] T __distance_impl(vector<T, N> x, vector<T, N> y) { return length(y - x); } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(distance)] [ForceInline] DifferentialPair<T> __d_distance(DifferentialPair<vector<T, N>> x, DifferentialPair<vector<T, N>> y) @@ -887,6 +926,7 @@ DifferentialPair<T> __d_distance(DifferentialPair<vector<T, N>> x, DifferentialP } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(distance)] [ForceInline] void __d_distance(inout DifferentialPair<vector<T, N>> x, inout DifferentialPair<vector<T, N>> y, T.Differential dOut) @@ -903,13 +943,15 @@ vector<T, N> __normalize_impl(vector<T, N> x) return x * r; } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(normalize)] [ForceInline] DifferentialPair<vector<T, N>> __d_normalize(DifferentialPair<vector<T, N>> x) { return __fwd_diff(__normalize_impl)(x); } -__generic<T: __BuiltinFloatingPointType, let N : int> +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(normalize)] [ForceInline] void __d_distance(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dOut) @@ -924,14 +966,16 @@ vector<T, N> __reflect_impl(vector<T, N> i, vector<T, N> n) { return i - n * (T(2.0) * dot(i, n)); } -__generic<T: __BuiltinFloatingPointType, let N : int> +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(reflect)] [ForceInline] DifferentialPair<vector<T, N>> __d_reflect(DifferentialPair<vector<T, N>> i, DifferentialPair<vector<T, N>> n) { return __fwd_diff(__reflect_impl)(i, n); } -__generic<T: __BuiltinFloatingPointType, let N : int> +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(reflect)] [ForceInline] void __d_reflect(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<vector<T, N>> n, vector<T, N>.Differential dOut) @@ -948,6 +992,7 @@ vector<T, N> __refract_impl(vector<T, N> i, vector<T, N> n, T eta) return (k < T(0.0)) ? vector<T, N>(T(0.0)) : eta * i - (eta * dot(n, i) + sqrt(max(T(0.0),k))) * n; } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(refract)] [ForceInline] DifferentialPair<vector<T, N>> __d_refract(DifferentialPair<vector<T, N>> i, DifferentialPair<vector<T, N>> n, DifferentialPair<T> eta) @@ -955,6 +1000,7 @@ DifferentialPair<vector<T, N>> __d_refract(DifferentialPair<vector<T, N>> i, Dif return __fwd_diff(__refract_impl)(i, n, eta); } __generic<T: __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(refract)] [ForceInline] void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<vector<T, N>> n, inout DifferentialPair<T> eta, vector<T, N>.Differential dOut) @@ -1053,6 +1099,7 @@ T __determinant_impl(matrix<T,N,N> m) return result; } __generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [ForwardDerivativeOf(determinant)] [ForceInline] DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m) @@ -1060,6 +1107,7 @@ DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m) return __fwd_diff(__determinant_impl)(m); } __generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] [BackwardDerivativeOf(determinant)] [ForceInline] void __d_determinant(inout DifferentialPair<matrix<T,N,N>> m, T.Differential dOut) |
