summaryrefslogtreecommitdiff
path: root/tests/autodiff/get-offset-ptr.slang
blob: 517acb54dfc6723c6fd351bca847d69de46da754 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none

//CHECK: struct s_bwd_prop_function_Intermediates{{[_0-9]+}}
//CHECK: {
//CHECK:     MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
//CHECK:     MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
//CHECK: };

// Shared float buffer. MyDiffPtr::foo reads primal values from it and
// MyDiffPtr::__bwd_foo writes gradients into it. The test harness binds it
// with four zero-initialised floats.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name outputBuffer
RWStructuredBuffer<float> outputBuffer;

// Pointer-like handle into `outputBuffer`.
// `offset` selects the primal value that foo() reads.
// `d_offset` selects the slot that __bwd_foo() writes the gradient into.
// The struct declares no IDifferentiable conformance here. Its method foo()
// is differentiated through the user-supplied backward derivative below.
struct MyDiffPtr
{
    uint offset;   // index of the primal value in outputBuffer
    uint d_offset; // index in outputBuffer that receives the gradient

    // Returns x * x, where x = outputBuffer[offset].
    [BackwardDerivative(__bwd_foo)]
    float foo()
    {
        return outputBuffer[offset] * outputBuffer[offset];
    }

    // Custom backward derivative of foo(): d(x*x)/dx = 2x, scaled by the
    // incoming output gradient `grad`.
    // The contribution is accumulated (+=) rather than stored. Reverse-mode
    // gradients must sum when the same value is used more than once, or when
    // several handles share a d_offset. A plain store would make the last
    // writer win and drop the other contributions.
    void __bwd_foo(float grad)
    {
        outputBuffer[d_offset] += 2.f * outputBuffer[offset] * grad;
    }
};

// Differentiable function over a pointer parameter: returns
// i[0].foo() + i[1].foo().
// Assumes `i` points to at least two MyDiffPtr elements; the caller in main
// passes a two-element local array.
// The CHECK lines at the top of this file expect the generated
// backward-propagation intermediates struct for this function to contain two
// MyDiffPtr fields. Presumably these are the loaded values of i[0] and i[1],
// which are needed later to invoke their custom backward derivatives.
// Keep this body's shape as-is so that expectation keeps holding.
[Differentiable]
float function(MyDiffPtr *i)
{
    return i[0].foo() + i[1].foo();
}

// Compute entry point (single thread). Runs the backward derivative of
// `function` with an output gradient of 1.0.
[numthreads(1, 1, 1), shader("compute")]
void main(uint3 dispatchThreadID: SV_DispatchThreadID)
{
    // Fields are {offset, d_offset}:
    //   s[0] reads outputBuffer[0], gradient goes to outputBuffer[2];
    //   s[1] reads outputBuffer[1], gradient goes to outputBuffer[3].
    MyDiffPtr s[2] = {{0, 2}, {1, 3}};
    // Pass a pointer to the first element of the local array. This
    // pointer-typed argument to __bwd_diff is what the test exercises.
    // dispatchThreadID is unused.
    __bwd_diff(function)(&s[0], 1.0f);
}