summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-09 09:09:53 -0800
committerGitHub <noreply@github.com>2022-12-09 09:09:53 -0800
commit8d359fc6133fa49d2d3b7f8bb4b37916e719c344 (patch)
tree9270ccd8ee1e9162a516c5c36e17e049fe8b6244 /source
parent41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (diff)
Add `diffPair` stdlib function. (#2560)
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang22
1 files changed, 17 insertions, 5 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 248112810..f58648657 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -92,19 +92,30 @@ struct DifferentialPair : IDifferentiable
}
};
-[ForceInline]
+__generic<T: IDifferentiable>
+__intrinsic_op($(kIROp_MakeDifferentialPair))
+DifferentialPair<T> diffPair(T primal, T.Differential diff);
+
+__generic<T: IDifferentiable>
+[__unsafeForceInlineEarly]
+DifferentialPair<T> diffPair(T primal)
+{
+ return diffPair(primal, T.dzero());
+}
+
+[__unsafeForceInlineEarly]
void updatePrimal<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal)
{
p = DifferentialPair<T>(newPrimal, p.d);
}
-[ForceInline]
+[__unsafeForceInlineEarly]
void updateDiff<T : IDifferentiable>(inout DifferentialPair<T> p, T.Differential newDiff)
{
p = DifferentialPair<T>(p.p, newDiff);
}
-[ForceInline]
+[__unsafeForceInlineEarly]
void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T.Differential newDiff)
{
p = DifferentialPair<T>(newPrimal, newDiff);
@@ -112,8 +123,8 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T
// vector-matrix
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
-__target_intrinsic(hlsl)
-__target_intrinsic(glsl, "($1 * $0)")
+[ForceInline]
+[ForwardDerivativeOf(mul)]
DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
{
let primal = mul(left.p, right.p);
@@ -135,6 +146,7 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen
// matrix-matrix
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
+[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left)
{