diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-08 21:52:34 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-08 21:52:34 -0800 |
| commit | 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch) | |
| tree | b4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/diff.meta.slang | |
| parent | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff) | |
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`.
* Fix
* Fix.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 85 |
1 files changed, 11 insertions, 74 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 54f927816..4301eda94 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -12,6 +12,9 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; + +__attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) @@ -20,6 +23,9 @@ attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute; + __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; @@ -1037,6 +1043,7 @@ void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair< // Sine and cosine __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(T x, out T s, out T c) { s = sin(x); @@ -1045,6 +1052,7 @@ void __sincos_impl(T x, out T s, out T c) __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) { s = sin(x); @@ -1053,62 +1061,18 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c) { s = sin(x); c = cos(x); } -__generic<T: __BuiltinFloatingPointType> -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic<T : __BuiltinFloatingPointType, let N : int> -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair<vector<T, N>> x, out DifferentialPair<vector<T, N>> s, out DifferentialPair<vector<T, N>> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix<T, N, M>> s, out DifferentialPair<matrix<T, N, M>> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic<T: __BuiltinFloatingPointType> -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} -__generic<T: __BuiltinFloatingPointType, let N : int> -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} - -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} // dst (obsolete) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PrimalSubstituteOf(dst)] vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1) { vector<T, 4> dest; @@ -1118,25 +1082,11 @@ vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1) dest.w = src1.w; ; return dest; } -__generic<T : __BuiltinFloatingPointType> -[ForwardDerivativeOf(dst)] -[ForceInline] -DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1) -{ - return __fwd_diff(__dst_impl)(src0, src1); -} -__generic<T : __BuiltinFloatingPointType> -[BackwardDerivativeOf(dst)] -[ForceInline] -void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut) -{ - __bwd_diff(__dst_impl)(src0, src1, dOut); -} // Legacy lighting function (obsolete) -__target_intrinsic(hlsl) [__readNone] [BackwardDifferentiable] +[PrimalSubstituteOf(lit)] float4 __lit_impl(float n_dot_l, float n_dot_h, float m) { let ambient = 1.0f; @@ -1144,19 +1094,6 @@ float4 __lit_impl(float n_dot_l, float n_dot_h, float m) let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); return float4(ambient, diffuse, specular, 1.0f); } -[ForwardDerivativeOf(lit)] -[ForceInline] -DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m) -{ - return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); -} -[BackwardDerivativeOf(lit)] -[ForceInline] -void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut) -{ - __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); -} - // Matrix determinant __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] |
