// Source: tests/autodiff/diff-ptr-type-loop.slang (blob 712f35bb2f6b9ad6193578a1d11521f19b1b7583)
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cuda -compute -shaderobj -output-using-type

// Four-float buffer supplied by the harness directive below, zero-filled
// (data=[0 0 0 0]) and marked as the test output. computeMain writes the
// primal values into elements [0..1] and uses elements [2..3] as the
// gradient accumulator.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

// ----- MyPtrType definition -----
// A pointer-like view into a float buffer: `buffer` is the backing storage and
// `offset` is the index of the view's first element. The type conforms to
// IDifferentiablePtrType and names itself as its own Differential, so the
// gradient side of a pair is just another view.
//
// NOTE: member order matters; the view is built with positional initializers
// of the form { buffer, offset }.
struct MyPtrType : IDifferentiablePtrType
{
    typealias Differential = MyPtrType;

    // Backing storage shared by every view created over it.
    RWStructuredBuffer<float> buffer;
    // Index in `buffer` of this view's element 0.
    uint offset;

    // Returns element `idx` of the view.
    float load(uint idx)
    {
        uint slot = offset + idx;
        return buffer[slot];
    }

    // Adds `value` to element `idx` of the view, in place.
    void accumulate(uint idx, float value)
    {
        uint slot = offset + idx;
        buffer[slot] = buffer[slot] + value;
    }
}

// Free-function wrapper around MyPtrType.load. The attribute registers
// load_bwd as this function's user-provided backward derivative.
[BackwardDerivative(load_bwd)]
float load(MyPtrType b, uint idx)
{
    float element = b.load(idx);
    return element;
}

// Backward derivative of `load`: adds the incoming gradient `grad` to element
// `idx` of the differential half (`b.d`) of the pointer pair. The primal half
// is not touched.
void load_bwd(DifferentialPtrPair<MyPtrType> b, uint idx, float grad)
{
    MyPtrType gradView = b.d;
    gradView.accumulate(idx, grad);
}


// ------
// Sums elements [0, num) of the view `a`, reading each one through the
// differentiable `load` wrapper.
[Differentiable]
float reduce(MyPtrType a, uint num)
{
    float total = 0.0f;

    // The loop carries an explicit iteration bound of 3; this test calls
    // reduce with num = 2.
    [MaxIters(3)]
    for (uint k = 0; k < num; ++k)
    {
        total = total + load(a, k);
    }

    return total;
}

// Differentiable entry point handed to bwd_diff in computeMain; it simply
// forwards to `reduce`, so the derivative has to propagate through a nested
// differentiable call.
[Differentiable]
float test(MyPtrType b, uint num)
{
    float reduced = reduce(b, num);
    return reduced;
}

// Single-thread compute entry point. Seeds the primal values, runs the
// backward pass of `test` with an output gradient of 1.5, and leaves the
// accumulated gradients in the second half of outputBuffer.
[numthreads(1, 1, 1)]
void computeMain(uint id : SV_DispatchThreadID)
{
    // Primal inputs; the trailing directives pin the values expected at [0] and [1].
    outputBuffer[0] = 1.0f; // CHECK: 1
    outputBuffer[1] = 2.0f; // CHECK: 2

    // Both halves of the pair alias outputBuffer: elements [0..1] form the
    // primal view and elements [2..3] form the derivative view.
    MyPtrType primalView = { outputBuffer, 0 };
    MyPtrType gradView = { outputBuffer, 2 };
    let ptrPair = DifferentialPtrPair<MyPtrType>(primalView, gradView);

    // Each of the two loaded elements contributes to the sum once, so each
    // gradient slot (initially 0) receives the output gradient 1.5 once.
    bwd_diff(test)(ptrPair, 2, 1.5f);

    // Check locations [2] and [3] in the buffer
    // CHECK: 1.5
    // CHECK: 1.5
}