From 8d359fc6133fa49d2d3b7f8bb4b37916e719c344 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 9 Dec 2022 09:09:53 -0800 Subject: Add `diffPair` stdlib function. (#2560) --- source/slang/diff.meta.slang | 22 +++++++++++++++++----- tests/autodiff/matrix-arithmetic-fwd.slang | 6 +++--- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 248112810..f58648657 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -92,19 +92,30 @@ struct DifferentialPair : IDifferentiable } }; -[ForceInline] +__generic +__intrinsic_op($(kIROp_MakeDifferentialPair)) +DifferentialPair diffPair(T primal, T.Differential diff); + +__generic +[__unsafeForceInlineEarly] +DifferentialPair diffPair(T primal) +{ + return diffPair(primal, T.dzero()); +} + +[__unsafeForceInlineEarly] void updatePrimal(inout DifferentialPair p, T newPrimal) { p = DifferentialPair(newPrimal, p.d); } -[ForceInline] +[__unsafeForceInlineEarly] void updateDiff(inout DifferentialPair p, T.Differential newDiff) { p = DifferentialPair(p.p, newDiff); } -[ForceInline] +[__unsafeForceInlineEarly] void updatePair(inout DifferentialPair p, T newPrimal, T.Differential newDiff) { p = DifferentialPair(newPrimal, newDiff); @@ -112,8 +123,8 @@ void updatePair(inout DifferentialPair p, T newPrimal, T // vector-matrix __generic -__target_intrinsic(hlsl) -__target_intrinsic(glsl, "($1 * $0)") +[ForceInline] +[ForwardDerivativeOf(mul)] DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) { let primal = mul(left.p, right.p); @@ -135,6 +146,7 @@ DifferentialPair> mul(DifferentialPair> left, Differen // matrix-matrix __generic +[ForceInline] [ForwardDerivativeOf(mul)] DifferentialPair> mul(DifferentialPair> right, DifferentialPair> left) { diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang index 7a953cef8..ee909551a 100644 --- a/tests/autodiff/matrix-arithmetic-fwd.slang +++ b/tests/autodiff/matrix-arithmetic-fwd.slang @@ -27,8 +27,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float3x3 da = float3x3(1.0); outputBuffer[0] = __fwd_diff(g)( - DifferentialPair(a, da), - DifferentialPair(b, da)).d._11; // Expect: 8 + diffPair(a, da), + diffPair(b, da)).d._11; // Expect: 8 float2x2 l = float2x2(1.0, 2.0, 3.0, 4.0); float2x2 r = float2x2(10.0, 11.0, 12.0, 13.0); @@ -37,5 +37,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) //float2x2 epsilon = d * 0.001f; //outputBuffer[1] = (h(l + epsilon, r + epsilon) - h(l - epsilon, r - epsilon)) / (epsilon[0][0] * 2.0)); - outputBuffer[1] = __fwd_diff(h)(DifferentialPair(l, d), DifferentialPair(r, d)).d; // Expect 83.0 + outputBuffer[1] = __fwd_diff(h)(diffPair(l, d), diffPair(r, d)).d; // Expect 83.0 } -- cgit v1.2.3