diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-09 09:09:53 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-09 09:09:53 -0800 |
| commit | 8d359fc6133fa49d2d3b7f8bb4b37916e719c344 (patch) | |
| tree | 9270ccd8ee1e9162a516c5c36e17e049fe8b6244 /source | |
| parent | 41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (diff) | |
Add `diffPair` stdlib function. (#2560)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 248112810..f58648657 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -92,19 +92,30 @@ struct DifferentialPair : IDifferentiable } }; -[ForceInline] +__generic<T: IDifferentiable> +__intrinsic_op($(kIROp_MakeDifferentialPair)) +DifferentialPair<T> diffPair(T primal, T.Differential diff); + +__generic<T: IDifferentiable> +[__unsafeForceInlineEarly] +DifferentialPair<T> diffPair(T primal) +{ + return diffPair(primal, T.dzero()); +} + +[__unsafeForceInlineEarly] void updatePrimal<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal) { p = DifferentialPair<T>(newPrimal, p.d); } -[ForceInline] +[__unsafeForceInlineEarly] void updateDiff<T : IDifferentiable>(inout DifferentialPair<T> p, T.Differential newDiff) { p = DifferentialPair<T>(p.p, newDiff); } -[ForceInline] +[__unsafeForceInlineEarly] void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T.Differential newDiff) { p = DifferentialPair<T>(newPrimal, newDiff); @@ -112,8 +123,8 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T // vector-matrix __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -__target_intrinsic(hlsl) -__target_intrinsic(glsl, "($1 * $0)") +[ForceInline] +[ForwardDerivativeOf(mul)] DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right) { let primal = mul(left.p, right.p); @@ -135,6 +146,7 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen // matrix-matrix __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> +[ForceInline] [ForwardDerivativeOf(mul)] DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left) { |
