summary refs log tree commit diff stats
path: root/tests/bugs/gh-5781.slang
blob: 33456f5001403a390b37165e8e77079cb7093023 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
//TEST:SIMPLE(filecheck=CHECK): -target spirv
// CHECK: OpEntryPoint

// Regression test, presumably for GitHub issue 5781 (per the file name
// gh-5781.slang). The two directives above compile this module for the
// SPIR-V target and only require that an OpEntryPoint instruction appears
// in the output, i.e. what is being verified is that compilation succeeds.
// The declarations below are the reproducer: keep their shape intact when
// editing, since the original failure may depend on it.
module test;

// Scoped material identifier whose underlying type is uint. The only named
// value is `invalid` (0xffffffff); other ids are obtained by explicit
// conversion from a uint, as main does with input[0].
public enum class MaterialID : uint { invalid = 0xffffffff };

// Per-material parameter record holding a single float. It conforms to
// IDifferentiable so it can be returned from [Differentiable] functions and
// carry derivatives through bwd_diff.
public struct Material : IDifferentiable
{
    float x;
}

// Hit record; carries only the id of the material that was hit. It does not
// conform to IDifferentiable, so it is a non-differentiable input to trace.
public struct Hit
{
    MaterialID material;
}

// Scene resources, bound as a uniform in main: the material parameters plus
// a writable buffer that receives their gradients.
public struct Scene
{
    // Material parameters, indexed by uint(MaterialID).
    StructuredBuffer<Material> materials;
    // Gradient accumulation target, written by accumulate().
    RWStructuredBuffer<Material> grads;

    // Plain buffer read of one material; marked [Differentiable] so it can be
    // called from differentiable code.
    [Differentiable]
    Material load(MaterialID id) { return materials[uint(id)]; }

    // Adds d.x into grads[id].x as a non-atomic read-modify-write.
    // NOTE(review): concurrent invocations sharing an id would race; this
    // appears harmless here because the file's directives only inspect
    // compiler output — revisit if this pattern is reused in a dispatched test.
    void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; }

    // Material lookup with a user-supplied backward derivative: the
    // BackwardDerivative attribute names _get_material_bwd as the reverse-mode
    // implementation rather than having load() differentiated automatically.
    [Differentiable, BackwardDerivative(_get_material_bwd)]
    public Material get_material(MaterialID id) { return load(id); }

    // Custom backward pass for get_material(MaterialID): routes the incoming
    // derivative d of the returned Material into the gradient buffer.
    public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); }

    // Overload taking a Hit; forwards to the MaterialID overload.
    // NOTE(review): a differentiable overload calling a same-named method that
    // has a custom derivative is likely the pattern this reproducer exercises
    // — confirm against the issue before restructuring.
    [Differentiable]
    public Material get_material(Hit hit) { return get_material(hit.material); }
}

// Differentiable scalar function: fetches the material that `hit` refers to
// from `scene` (taken as const) and returns its x field.
[Differentiable]
float trace(const Scene scene, Hit hit)
{
    Material m = scene.get_material(hit);
    return m.x;
}


// Compute entry point: evaluates trace once forward, then runs its backward
// pass.
// NOTE(review): no [numthreads] attribute is given, so this relies on the
// compiler's default thread-group size — confirm this is intended.
[shader("compute")]
void main(
    uniform Scene scene,
    uniform StructuredBuffer<uint> input,
    uniform RWStructuredBuffer<float> output,
    // Separate from scene.grads; only element 0 is read, as the backward seed.
    uniform RWStructuredBuffer<float> grads
)
{
    Hit hit;
    // input[0] supplies the material id to look up.
    hit.material = MaterialID(input[0]);
    // Forward pass.
    output[0] = trace(scene, hit);
    // Backward pass: grads[0] is passed as the derivative of trace's float
    // result. Through get_material(Hit) -> get_material(MaterialID), the custom
    // _get_material_bwd is expected to deposit the material gradient into
    // scene.grads via accumulate().
    bwd_diff(trace)(scene, hit, grads[0]);
}