summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/local-redecl-custom-jvp.slang
blob: 52525075cf79fa66cefcd8898c39896d1fb561c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// Test-harness directives: run this compute shader once with the default
// Slang target (-slang) and once with CUDA (-cuda).
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type

// Output buffer: four float slots initialized to 0; computeMain writes only
// indices 0 and 1.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

// dpfloat: a float primal value paired with its derivative (read via .p / .d).
typedef DifferentialPair<float> dpfloat;
// dfloat: the differential (tangent) type of float; not referenced in the
// code visible here.
typedef float.Differential dfloat;

import test_intrinsics;

// Custom forward-mode derivative (JVP) for f(x, n) = x^n.
//
// Partial derivatives:
//   df/dx = n * x^(n-1)
//   df/dn = x^n * ln(x)
// The returned tangent is  x.d * df/dx + n.d * df/dn.
//
// x : primal/tangent pair for the base.
// n : primal/tangent pair for the exponent.
// Returns the pair (x^n, tangent of x^n).
dpfloat my_pow_jvp(dpfloat x, dpfloat n)
{
    float base     = x.p;
    float exponent = n.p;
    float value    = pow(base, exponent);

    // Keep the same left-to-right multiplication order as the plain
    // expression so float rounding is unchanged.
    float viaBase     = x.d * exponent * pow(base, exponent - 1);
    float viaExponent = n.d * value * log(base);

    return dpfloat(value, viaBase + viaExponent);
}

// Declaration of `_pow` with no body; its forward derivative is supplied by
// `my_pow_jvp` through the [ForwardDerivative] attribute. In this file it is
// only invoked via `__fwd_diff(_pow)`, so presumably no primal body is needed.
[ForwardDerivative(my_pow_jvp)]
float _pow(float, float);

// Exercises the custom forward derivative of `_pow` at x = 5, n = 2 and
// stores the two partial-derivative tangents into outputBuffer[0..1].
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    {
        dpfloat base     = dpfloat(5.0, 1.0);
        dpfloat exponent = dpfloat(2.0, 0.0);

        // Seed only the base: d/dx x^n = n * x^(n-1) = 2 * 5 = 10.
        outputBuffer[0] = __fwd_diff(_pow)(base, exponent).d;   // Expect: 10.0

        // Seed only the exponent: d/dn x^n = x^n * ln(x) = 25 * ln(5) ~= 40.23595.
        dpfloat constBase   = dpfloat(base.p, 0.0);
        dpfloat seededPower = dpfloat(exponent.p, 1.0);
        outputBuffer[1] = __fwd_diff(_pow)(constBase, seededPower).d; // Expect: 40.23595
    }
}