summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/generic-custom-jvp.slang
blob: 6e9e863bb44733aae930eefbf408f952dae575f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
//TEST_IGNORE_FILE:

// Result buffer for this test. The directive below describes four zero-valued
// entries (stride=4) and names `outputBuffer` as the output target; computeMain
// writes entries 0 and 1 only.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

// Constraint shorthand: a type that conforms to both IFloat and IDifferentiable.
typealias IDFloat = IFloat & IDifferentiable;

// Generic alias for a differential pair over any type satisfying IDFloat.
// NOTE(review): `dfloat` is not referenced anywhere else in the visible file; the
// code below uses `dpfloat`, which is not declared here -- confirm whether it is
// supplied by `test_intrinsics` or whether this alias was meant to be used instead.
__generic<T : IDFloat>
typedef DifferentialPair<T> dfloat;

import test_intrinsics;

// Custom forward-mode derivative (JVP) for pow(x, n) on differential pairs.
// Each argument's `.p` is used as its value and `.d` as its tangent.
//   value   : x^n
//   tangent : dx * n * x^(n-1)  +  dn * x^n * ln(x)
dpfloat my_pow_jvp(dpfloat x, dpfloat n)
{
    // Contribution from perturbing the base.
    let baseTerm = x.d * n.p * pow(x.p, n.p - 1);

    // Contribution from perturbing the exponent.
    let exponentTerm = n.d * pow(x.p, n.p) * log(x.p);

    return dpfloat(pow(x.p, n.p), baseTerm + exponentTerm);
}

// Body-less declaration of `_pow`. The attribute names `my_pow_jvp` as its
// forward derivative, which is what `__fwd_diff(_pow)` in computeMain exercises.
[ForwardDerivative(my_pow_jvp)]
float _pow(float, float);

// Entry point: evaluates the forward derivative of `_pow` at (x = 5, n = 2),
// once with the tangent seeded on x and once with it seeded on n, and stores
// the two resulting tangents in outputBuffer[0] and outputBuffer[1].
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    // Tangent seeded on the base only: d/dx.
    dpfloat baseSeeded = dpfloat(5.0, 1.0);
    dpfloat exponentUnseeded = dpfloat(2, 0.0);
    outputBuffer[0] = __fwd_diff(_pow)(baseSeeded, exponentUnseeded).d;    // Expect: 10.0

    // Tangent seeded on the exponent only: d/dn.
    dpfloat baseUnseeded = dpfloat(baseSeeded.p, 0.0);
    dpfloat exponentSeeded = dpfloat(exponentUnseeded.p, 1.0);
    outputBuffer[1] = __fwd_diff(_pow)(baseUnseeded, exponentSeeded).d;    // Expect: 40.23595
}