summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/path-tracer/pt-loop.slang
blob: 93a187666ce537bbe34b7aaf77422809c9860092 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
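// Tests automatic synthesis of the Differential type requirement (for PathData),
// together with forward- and backward-mode differentiation through a bounded
// path-tracing loop that calls a function with user-defined derivatives.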

//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -dx12
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
// Result slots written by this test:
//   [0] primal tracePath().x
//   [1] forward-mode derivative of tracePath().x
//   [2] gradient accumulated by __bwd_d_getAlbedo during the backward pass
RWStructuredBuffer<float> outputBuffer;

// Per-path state carried through the tracing loop in tracePath().
// Conforms to IDifferentiable without declaring a Differential type itself,
// so the compiler has to synthesize one.
// NOTE(review): presumably only `thp` contributes to the synthesized
// Differential, since the remaining fields are uint/bool — confirm.
struct PathData : IDifferentiable
{
    float3 thp;       // Throughput: starts at 1 and is multiplied by the albedo at each hit (see handleHit).
    uint length;      // Number of hits processed so far; incremented by handleHit.
    bool terminated;  // Once true, the loop in tracePath() stops.
    bool isHit;       // Result of the most recent (mock) traceRayInline() query.

    // Initial path state: unit throughput, zero length, not terminated, no hit.
    [BackwardDifferentiable]
    __init()
    {
        this.thp = float3(1.f);
        this.length = 0;
        this.terminated = false;
        this.isHit = false;
    }
}

// Mock ray query used by the test: reports a hit only while the path has
// fewer than two hits (length 0 and 1), and a miss from length 2 onward.
bool traceRayInline(uint length)
{
    return length < 2;
}

// Mock surface albedo. Returns a constant; `length` is accepted for
// signature compatibility with the custom derivatives but is not used.
float3 getAlbedo(uint length)
{
    const float3 albedo = float3(0.9f, 1.f, 1.f);
    return albedo;
}

// Fixed tangent used as the forward derivative of getAlbedo: a unit
// direction in the x component only. `length` is not used.
float3 getAlbedoDerivative(uint length)
{
    const float3 tangent = float3(1.f, 0.f, 0.f);
    return tangent;
}

// User-supplied forward-mode derivative of getAlbedo.
// Pairs the primal albedo with the fixed tangent from getAlbedoDerivative;
// the `no_diff` prefix marks that helper call as non-differentiable.
[ForwardDerivativeOf(getAlbedo)]
[TreatAsDifferentiable]
DifferentialPair<float3> __fwd_d_getAlbedo(uint length)
{
    float3 albedo = getAlbedo(length);
    float3 albedoTangent = no_diff getAlbedoDerivative(length);
    return DifferentialPair<float3>(albedo, albedoTangent);
}

// User-supplied backward-mode derivative of getAlbedo.
// `length` is a uint, so there is no input to propagate a gradient to.
// Instead, the x component of the incoming gradient is added to
// outputBuffer[2]. It runs once per getAlbedo call in the primal path, so
// that slot ends up holding the total gradient w.r.t. albedo.x.
[BackwardDerivativeOf(getAlbedo)]
[TreatAsDifferentiable]
void __bwd_d_getAlbedo(uint length, float3 dOut)
{
    outputBuffer[2] += dOut.x;
}

// Processes one hit for the path in place.
// If the path already has length >= 2, it is marked terminated and otherwise
// left unchanged. Otherwise the throughput is multiplied by
// getAlbedo(length), which is differentiated via its user-supplied
// derivatives, and the length is incremented.
// NOTE(review): with the current traceRayInline(), tracePath() never calls
// this with length >= 2. The early-return branch is still compiled and
// differentiated, but it is not taken at runtime.
[BackwardDifferentiable]
void handleHit(inout PathData pathData)
{
    if (pathData.length >= 2)
    {
        pathData.terminated = true;
        return;
    }

    float3 albedo = getAlbedo(pathData.length);
    pathData.thp *= albedo;
    pathData.length++;
}

// Differentiable mock path tracer; returns the final path throughput.
// Performs an initial mock ray query, then loops: each hit is handled by
// handleHit() and followed by another query, and the first miss terminates
// the path. With the current mock helpers the loop handles two hits
// (length 0 and 1) and terminates on its third iteration, returning
// thp = (0.81, 1, 1).
// [PreferRecompute] presumably asks Slang's backward-pass checkpointing to
// recompute this function's values rather than store them — confirm.
[BackwardDifferentiable]
[PreferRecompute]
float3 tracePath()
{
    PathData pathData = PathData();

    // Primary ray.
    if (traceRayInline(pathData.length))
    {
        pathData.isHit = true;
    }
    else
    {
        pathData.terminated = true;
        pathData.isHit = false;
    }

    // Upper bound on loop iterations for the differentiated loop
    // (3 iterations are actually executed with the current helpers).
    [MaxIters(4)]
    while (!pathData.terminated)
    {
        if (pathData.isHit)
        {
            handleHit(pathData);

            // NOTE(review): the direct assignment below is kept commented out
            // and an explicit if/else is used instead. This is presumably
            // deliberate, to exercise branching inside the differentiated
            // loop — confirm before simplifying.
            //pathData.isHit = traceRayInline(pathData.length);
            if (!traceRayInline(pathData.length)) pathData.isHit = false;
            else pathData.isHit = true;
        }
        else
        {
            // Previous query missed: stop the path.
            pathData.terminated = true;
        }
    }
    return pathData.thp;
}

// Test entry point.
// Runs tracePath() in forward mode and stores the primal and tangent x
// components. It then runs backward mode with seed gradient (1, 0, 0), which
// accumulates into outputBuffer[2] through __bwd_d_getAlbedo.
// Hand-derived values from the code above, assuming the buffer starts at
// zero: [0] = 0.81, [1] = 1.8 (= 1*0.9 + 0.9*1), [2] = 0.9 + 0.9 = 1.8.
// NOTE(review): should match the test's expected output file — confirm.
// dispathThreadID is unused.
[numthreads(1, 1, 1)]
void computeMain(uint3 dispathThreadID: SV_DispatchThreadID)
{
    DifferentialPair<float3> dpThp = __fwd_diff(tracePath)();
    outputBuffer[0] = dpThp.p.x;
    outputBuffer[1] = dpThp.d.x;

    __bwd_diff(tracePath)(float3(1.f, 0.f, 0.f));
}