summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/reverse-control-flow-3.slang
blob: 3b76b3b209f9803098f230ee86bcb29894ddf0e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
//DISABLE_TEST:SIMPLE(filecheck=CHK):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer

// Test output buffer (bound by the TEST_INPUT line above). The only writer is the
// custom backward derivative d_getParam, which accumulates roughness gradients here.
RWStructuredBuffer<float> outputBuffer;

/** Non-differentiable per-path bookkeeping.
    depth      Number of completed bounces (incremented by nextHit).
    terminated Set once the path should stop tracing.
*/
struct PathState
{
    uint depth;
    bool terminated;

    /// True once the path has been terminated.
    bool isTerminated() { return terminated; }

    /// True while the path is still live (the complement of isTerminated()).
    bool isHit() { return !isTerminated(); }
};

/** Differentiable per-path result; this is the value back-propagated through tracePath.
    thp  Multiplicative factor scaled by updatePathThroughput on every scatter.
    L    Overwritten by handleHit (thp * lightEval(depth)) and zeroed by handleMiss.
*/
struct PathResult : IDifferentiable
{
    float thp;
    float L;
}
/** Placeholder query type passed to handleHit (which ignores it).
    `test()` is declared but never defined or called in this file.
*/
struct VisibilityQuery
{
    bool test();
}

/** Placeholder query type passed to nextHit (which ignores it).
    `test()` is declared but never defined or called in this file.
*/
struct ClosestHitQuery
{
    bool test();
}
/** Initializes a fresh path: zero bounces and not yet terminated.
    \param[in] pathID Path index (not used by this implementation).
    \param[out] path Receives the initial path state.
*/
void generatePath(uint pathID, out PathState path)
{
    path.depth = 0;
    path.terminated = false;
}

/** Light contribution seen at a given path depth.
    \param[in] depth Current path depth.
    \return 5.0 at depth 1, 0.0 at every other depth. The input is a uint, so this
            function contributes no gradient of its own.
*/
[BackwardDifferentiable]
float lightEval(uint depth)
{
    // Two returning branches inside a differentiable function.
    if (depth == 1)
    {
        return 5.0f;
    }
    else
    {
        return 0.0f;
    }
}

/** Differentiable material parameters. getParam fixes roughness to 0.5, and the
    gradient w.r.t. roughness is written to outputBuffer by d_getParam.
*/
struct MaterialParam : IDifferentiable
{
    float roughness;
}

/** Returns the material parameters for material `id` (roughness is always 0.5).
    Its forward and backward derivatives are supplied explicitly by the d_getParam
    overloads rather than generated automatically.
    \param[in] id Material index; only used by the backward derivative.
*/
[BackwardDifferentiable]
MaterialParam getParam(uint id)
{
    MaterialParam material;
    material.roughness = 0.5f;
    return material;
}

/** Custom forward derivative of getParam.
    Returns the primal (roughness 0.5) paired with a tangent of 1.0 for roughness.
    NOTE(review): the tangent is a fixed test value; it is not derived from any input
    differential (`id` is a uint). The backward test in this file does not use it.
*/
[ForwardDerivativeOf(getParam)]
DifferentialPair<MaterialParam> d_getParam(uint id)
{
    MaterialParam primal;
    primal.roughness = 0.5f;

    MaterialParam.Differential tangent;
    tangent.roughness = 1.0f;

    return diffPair(primal, tangent);
}

/** Custom backward derivative of getParam.
    Adds the incoming roughness gradient to outputBuffer[id]; calls from several
    bounces therefore accumulate into the same slot.
    \param[in] id Material index, used as the output slot.
    \param[in] diff Gradient w.r.t. the returned MaterialParam.
*/
[BackwardDerivativeOf(getParam)]
void d_getParam(uint id, MaterialParam.Differential diff)
{
    float roughnessGrad = diff.roughness;
    outputBuffer[id] += roughnessGrad;
}

//CHK-DAG: note: checkpointing context of 8 bytes associated with function: 'updatePathThroughput'
//CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item:
/** Scales the path's `thp` factor by `weight`.
    \param[in,out] path Path result whose `thp` is updated.
    \param[in] weight Multiplicative weight (the BSDF sample value in this test).
*/
[BackwardDifferentiable]
void updatePathThroughput(inout PathResult path, const float weight)
{
    path.thp = path.thp * weight;
}

/** Differentiable BSDF sample. `val` is copied from the material roughness by
    bsdfGGXSample and used as the throughput weight by the 4-argument
    generateScatterRay overload.
*/
struct BSDFSample : IDifferentiable
{
    float val;
}

/** Stand-in for GGX importance sampling: the sample value is simply the roughness.
    \param[in] bsdfParams Material parameters at the shading point.
    \param[out] result Receives the sample.
    \return Always true (the sample is always reported valid).
*/
[BackwardDifferentiable]
bool bsdfGGXSample(const MaterialParam bsdfParams, out BSDFSample result)
{
    float sampledValue = bsdfParams.roughness;
    result.val = sampledValue;
    return true;
}

/** Samples the BSDF, then hands the sample and its validity flag to the
    5-argument generateScatterRay overload.
    \param[in] bsdfParams Material parameters at the shading point.
    \param[in,out] path The path state.
    \param[in,out] pathRes The differentiable path result.
    \return Whatever the 5-argument overload returns.
*/
[BackwardDifferentiable]
bool generateScatterRay(const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes)
{
    BSDFSample bsdfSample;
    bool sampleValid = bsdfGGXSample(bsdfParams, bsdfSample);
    return generateScatterRay(bsdfSample, bsdfParams, path, pathRes, sampleValid);
}

/** Generates a new scatter ray from an already-drawn BSDF sample, but only if
    that sample is valid.
    \param[in] bs BSDF sample.
    \param[in] bsdfParams Material parameters at the shading point.
    \param[in,out] path The path state.
    \param[in,out] pathRes The differentiable path result.
    \param[in] valid Whether `bs` is valid. When false, nothing is updated.
    \return False if `valid` was false; otherwise the result of the 4-argument
            overload (which always returns true in this file).
*/
[BackwardDifferentiable]
bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes, bool valid)
{
    // Conditional call on a runtime flag: the differentiable work happens on only one branch.
    if (valid) valid = generateScatterRay(bs, bsdfParams, path, pathRes);
    return valid;
}

/** Generates a new scatter ray given a valid BSDF sample.
    In this test the only effect is scaling the path's thp by the sample value.
    \param[in] bs BSDF sample (assumed to be valid).
    \param[in] bsdfParams Material parameters at the shading point (unused here).
    \param[in,out] path The path state (unused here).
    \param[in,out] pathRes The path result; its thp is multiplied by bs.val.
    \return Always true.
*/

//CHK-DAG: note: checkpointing context of 16 bytes associated with function: 'generateScatterRay'
[BackwardDifferentiable]
bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes)
{
    //CHK-DAG: note: 8 bytes (s_bwd_prop_updatePathThroughput_Intermediates_0) used to checkpoint the following item:
    //CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item:
    updatePathThroughput(pathRes, bs.val);
    return true;
}

/** Processes one surface hit.
    Fetches the material. If roughness exceeds 0.8 the path is terminated
    immediately; this never happens here, because getParam always yields 0.5.
    Otherwise the function scatters (scaling rs.thp), sets rs.L from the light at
    the current depth, and decides whether the path continues.
    \param[in,out] path Path state; `terminated` is written on every path through this function.
    \param[in,out] rs Path result; on the scatter path `thp` is scaled and `L` is overwritten.
    \param[in,out] vq Unused.
*/
[BackwardDifferentiable]
void handleHit(inout PathState path, inout PathResult rs, inout VisibilityQuery vq)
{
    var param = getParam(0);

    // Branch condition computed from a differentiable value; the comparison
    // itself carries no gradient.
    bool lastVertex = param.roughness > 0.8;
    if (lastVertex)
    {
        // Early return from the middle of a differentiable function.
        path.terminated = true;
        return;
    }
     
    generateScatterRay(param, path, rs);

    // Overwrites rs.L from any earlier bounce; lightEval is nonzero only at depth 1.
    rs.L = rs.thp * lightEval(path.depth);

    // Decide on next hit: keep tracing only while depth < 1.
    if (path.depth < 1)
        path.terminated = false;
    else
        path.terminated = true;
}

/** Returns the material roughness as the BSDF value.
    NOTE(review): not called anywhere in this file.
    \param[in] mparam Material parameters.
*/
[BackwardDifferentiable]
float bsdfEval(const MaterialParam mparam)
{
    float value = mparam.roughness;
    return value;
}

/** Advances the path to its next hit by incrementing the depth.
    \param[in,out] path Path state whose depth is incremented.
    \param[in,out] rs Unused.
    \param[in,out] cq Unused.
*/
[BackwardDifferentiable]
void nextHit(inout PathState path, inout PathResult rs, inout ClosestHitQuery cq)
{
    path.depth += 1;
}

/** Handles a miss: terminates the path and zeroes rs.L.
    \param[in,out] path Path state; marked terminated.
    \param[in,out] rs Path result; `L` is set to 0.
*/
[BackwardDifferentiable]
void handleMiss(inout PathState path, inout PathResult rs)
{
    path.terminated = true;
    rs.L = 0.0f;
}

/** Traces one path for at most 3 loop iterations, alternating hit handling and
    advancing to the next hit, and leaves the loop early once a hit terminates the path.
    \param[in] pathID Forwarded to generatePath (which does not use it).
    \param[out] path Final path state.
    \param[in,out] pathRes Differentiable path result.
    \return Always true.
*/
[BackwardDifferentiable]
bool tracePath(uint pathID, out PathState path, inout PathResult pathRes)
{
    generatePath(pathID, path);

    // NOTE(review): `thp` and `L` are never read afterwards; possibly kept on purpose
    // to exercise dead loads of differentiable values - confirm before removing.
    float thp = pathRes.thp;
    float L = pathRes.L;

    // Fixed-trip-count loop with a data-dependent `break`.
    for (int i = 0; i < 3; ++i)
    {
        if (path.isHit())
        {
            VisibilityQuery vq;
            handleHit(path, pathRes, vq);

            if (path.isTerminated()) break;

            ClosestHitQuery chq;
            nextHit(path, pathRes, chq);
        }
        else
        {
            handleMiss(path, pathRes);
        }
    }
    
    return true;
}

/** Compute entry point: back-propagates through tracePath with dL = 1 and dthp = 0.
    The only gradient sink is d_getParam, which accumulates into outputBuffer[id];
    handleHit always calls getParam(0), so the result lands in outputBuffer[0].
*/
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
    {
        PathResult pathRes;
        pathRes.L = 1.f;
        pathRes.thp = 1.f;

        // Seed the output gradient: d/dL = 1, d/dthp = 0.
        PathResult.Differential pathResD;
        pathResD.L = 1.0f;
        pathResD.thp = 0.f;

        var dpx = diffPair(pathRes, pathResD);
        // Expect: 5.0 in outputBuffer[0] (2.5 from each of the two bounces that reach getParam(0)).
        __bwd_diff(tracePath)(1, dpx);
    }
}

//CHK-NOT: note