summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/property.slang
blob: 2a626546ffd4bf38e065a5131657f59607380875 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cuda
// Index wrapper whose value is reachable only through a get-only property.
// GradientBuffer.store reads `idx.idx` inside a [Differentiable] method that
// repro back-propagates through with bwd_diff; that property access inside
// differentiated code is the construct this test file exercises.
public struct ReadOnlyIndex
{
    // Backing storage; private, so other code must go through `idx`.
    private int _idx;
    // ReadOnlyIndex(i) stores i (repro constructs ReadOnlyIndex(5)).
    __init(int i) { _idx = i; }
    // Get-only property (no setter) returning the stored index.
    public property int idx { get { return _idx; } }
}
// Pairs a writable buffer for primal values with a read-only buffer of
// precomputed gradients. `write` has a hand-written backward derivative
// (write_bwd) that takes its gradient from `grad` instead of deriving one.
struct GradientBuffer
{
    // Destination for primal values stored by `write`.
    RWStructuredBuffer<float> primal;
    // Gradients handed back by write_bwd, looked up with the same idx.
    StructuredBuffer<float> grad;

    // Stores v at primal[idx]. detach(v) strips v's derivative before the
    // store; the gradient for v is supplied by write_bwd below.
    [Differentiable]
    void write(int idx, float v) { primal[idx] = detach(v); }

    // Custom backward derivative registered for `write`: keeps d's primal
    // value and sets its derivative to grad[idx]. Only `v` (as `d`) carries a
    // derivative; idx is a plain int.
    [BackwardDerivativeOf(write)]
    void write_bwd(int idx, inout DifferentialPair<float> d) { d = diffPair(d.p, grad[idx]); }

    // Differentiable wrapper that reads the index through ReadOnlyIndex's
    // get-only `idx` property before delegating to `write`.
    [Differentiable]
    void store(ReadOnlyIndex idx, float v) { write(idx.idx, v); }
}
// Function differentiated by repro via bwd_diff(test). It forwards straight
// to GradientBuffer.store. repro supplies only `x` as a DifferentialPair;
// `buf` and `b` are passed as plain values.
[Differentiable]
void test(GradientBuffer buf, ReadOnlyIndex b, float x)
{
    buf.store(b, x);
}
// Runs the backward pass of `test` for x = 1.0 at index 5 and returns the
// derivative the pass writes back for x. Because write_bwd sets that
// derivative to grad[idx], the returned value is grad[5].
public float repro(RWStructuredBuffer<float> primal, StructuredBuffer<float> grad)
{
    GradientBuffer gradBuf = { primal, grad };
    ReadOnlyIndex readIdx = ReadOnlyIndex(5);

    // Primal value 1.0; bwd_diff updates this pair in place (inout).
    DifferentialPair<float> xPair = diffPair(1.0f);
    bwd_diff(test)(gradBuf, readIdx, xPair);

    return xPair.d;
}

// Result buffer: computeMain writes the gradient returned by repro to element 0.
//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float> output;

// Passed as GradientBuffer.primal; its contents are not compared by the
// expected-output line in computeMain.
//TEST_INPUT: set gPrimal = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4)
RWStructuredBuffer<float> gPrimal;
// Passed as GradientBuffer.grad; element i holds i, so write_bwd at index 5
// yields 5.0.
//TEST_INPUT: set gGrad = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4)
StructuredBuffer<float> gGrad;

// Single-thread entry point: records the gradient produced by repro in output[0].
[numthreads(1,1,1)]
void computeMain()
{
    // gGrad[5] is 5.0, which is the value expected in output[0].
    // CHECK: 5.0
    float dx = repro(gPrimal, gGrad);
    output[0] = dx;
}