summaryrefslogtreecommitdiffstats
path: root/examples/autodiff-texture/train.slang
blob: 7126cbde9731ad3e8416aed7ae022f8f0ff72ddf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
// train.slang

// This class encapsulates a differentiable texture object that uses
// hardware sampling for the primal version, but substitutes a reference
// interpolation implementation to generate the backward pass.
//
// This specific implementation also makes the choice to use fast fixed-point
// atomics to accumulate the derivative (suitable for this example, but maybe
// not in general).
//
struct DifferentiableTexture
{
    RWStructuredBuffer<int> accumulateBuffer; // Per-mip-level accumulate buffer
    Texture2D texture;                        // Hardware texture handle.
    float minLOD;                             // Lower clamp applied to the computed LOD.

    // Loads a single texel from the hardware texture.
    //
    // `location` is (x, y, mip) and `offset` is forwarded to Texture2D.Load.
    // `dLayerW` and `dMipOffset` are not used by the primal load; they are
    // carried along so that bwd_LoadTexel can address the accumulate buffer.
    [BackwardDerivative(bwd_LoadTexel)]
    float4 LoadTexel(int3 location, constexpr int2 offset, uint dLayerW, uint dMipOffset)
    {
        return texture.Load(location, offset);
    }

    // Backward pass of LoadTexel: accumulates the incoming gradient `val`
    // into accumulateBuffer.
    //
    // Each texel owns four consecutive ints starting at
    //     dMipOffset + (location.y * dLayerW + location.x) * 4
    // so `dLayerW` must be the row pitch (in texels) of the mip being written.
    void bwd_LoadTexel(int3 location, constexpr int2 offset, uint dLayerW, uint dMipOffset, float4 val)
    {
        // Ignore alpha dimension for this example..
        // RGB is converted to fixed point by scaling with 65536; the fourth
        // slot receives a constant 1 per call.
        // NOTE(review): the fourth slot presumably acts as a contribution
        // counter -- confirm against the host code that reads the buffer.
        int4 uval = int4(int3(val.xyz * 65536), 1);

        // First of the four ints belonging to this texel.
        uint texelBase = dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4;

        // We'll use fast fixed point atomics instead of floats.
        InterlockedAdd(accumulateBuffer[texelBase + 0], uval.x);
        InterlockedAdd(accumulateBuffer[texelBase + 1], uval.y);
        InterlockedAdd(accumulateBuffer[texelBase + 2], uval.z);
        InterlockedAdd(accumulateBuffer[texelBase + 3], uval.w);
    }

    // Software reference implementation of linear filtering.
    //
    // `w` and `h` are the dimensions of mip 0; they are reduced to the
    // dimensions of mip `lod` before use. `uv` is wrapped by discarding its
    // integer part, and the four neighbouring texels wrap around the edges.
    [BackwardDifferentiable]
    float4 sampleTexture_linear(uint lod, float2 uv, uint w, uint h)
    {
        // From here on, w and h are the dimensions of the selected mip level.
        w >>= lod;
        h >>= lod;

        // Keep only the fractional part of uv; floor() is excluded from
        // differentiation.
        uv = uv - no_diff(floor(uv));

        // Texel-space position, shifted by half a texel.
        float2 loc = uv * float2(w, h) - float2(0.5);
        float x0 = no_diff(floor(loc.x));
        float y0 = no_diff(floor(loc.y));
        float fracX = loc.x - x0;
        float fracY = loc.y - y0;
        float x1 = x0 + 1;
        float y1 = y0 + 1;

        // Wrap neighbours that fall outside the mip.
        if (x0 < 0) x0 += w;
        if (y0 < 0) y0 += h;
        if (x1 >= w) x1 -= w;
        if (y1 >= h) y1 -= h;

        // Bilinear weights; the suffix is (row, column).
        float weight0 = 1.0f - fracY;
        float weight1 = fracY;
        float weight00 = weight0 * (1.0f - fracX);
        float weight01 = weight0 * fracX;
        float weight10 = weight1 * (1.0f - fracX);
        float weight11 = weight1 * fracX;

        // w is already the width of this mip (shifted above). Shifting it a
        // second time would make the row pitch smaller than the range of x
        // for lod > 0, so different texels would alias in the accumulate
        // buffer.
        uint dLayerW = w;

        // Per-mip offset table declared in the Uniforms cbuffer, packed four
        // offsets per uint4.
        var offset = mipOffset[lod / 4][lod % 4];
        return LoadTexel(int3(int(x0), int(y0), int(lod)), int2(0), dLayerW, offset) * weight00 +
               LoadTexel(int3(int(x1), int(y0), int(lod)), int2(0), dLayerW, offset) * weight01 +
               LoadTexel(int3(int(x0), int(y1), int(lod)), int2(0), dLayerW, offset) * weight10 +
               LoadTexel(int3(int(x1), int(y1), int(lod)), int2(0), dLayerW, offset) * weight11;
    }

    // Software reference implementation of trilinear filtering.
    //
    // `w`, `h` and `levels` describe the texture (mip-0 size and mip count);
    // `dX` and `dY` are the uv differentials. The LOD is
    // log2 of the longer differential measured in texels, clamped to
    // [minLOD, levels - 1], and the two nearest mips are blended linearly.
    [BackwardDifferentiable]
    float4 sampleTexture_trilinear(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY)
    {
        // Convert the differentials from uv units to texels.
        dX = dX * float2(w, h);
        dY = dY * float2(w, h);

        // Isotropic filter.
        float lengthX = length(dX);
        float lengthY = length(dY);
        float LOD = log2(max(lengthX, lengthY));
        float maxLOD = levels - 1;
        float clampedLOD = max(minLOD, (min(maxLOD, LOD)));

        // Split into an integer mip index and a blend factor.
        float lodFrac = clampedLOD - no_diff(floor(clampedLOD));
        uint lod0 = (uint)floor(clampedLOD);
        uint lod1 = min(levels - 1, lod0 + 1);
        float weightLod0 = 1.0 - lodFrac;
        float weightLod1 = lodFrac;

        let v0 = sampleTexture_linear(lod0, uv, w, h) * weightLod0;
        let v1 = sampleTexture_linear(lod1, uv, w, h) * weightLod1;
        return v0 + v1;
    }

    // Primal sampling path: a plain hardware Sample() call. `dX` and `dY`
    // are not used here; they exist so the signature matches the reference
    // implementation below.
    //
    // Note that there is no need to mark this [BackwardDifferentiable] since it has a substitute
    // that is marked [BackwardDifferentiable]. The compiler automatically considers a call to
    // sample() to be differentiable.
    //
    static float4 sample(DifferentiableTexture t, SamplerState s, float2 uv, float2 dX, float2 dY)
    {
        return t.texture.Sample(s, uv);
    }

    // Software reference implementation of DifferentiableTexture.sample (trilinear only in this example)
    //
    // Queries the texture's mip-0 dimensions and mip count, then defers to
    // sampleTexture_trilinear. The sampler `s` is not consulted.
    [PrimalSubstituteOf(DifferentiableTexture.sample)]
    [BackwardDifferentiable]
    static float4 sample_reference_impl(DifferentiableTexture t, SamplerState s, float2 uv, float2 dX, float2 dY)
    {
        uint w;
        uint h;
        uint levels;
        t.texture.GetDimensions(0, w, h, levels);
        return t.sampleTexture_trilinear(w, h, levels, uv, dX, dY);
    }
}

// Shader parameters read by the functions in this file.
cbuffer Uniforms
{
    // Matrix applied to the vertex position in vertexMain.
    float4x4 modelViewProjection;
    // Per-mip offsets into DifferentiableTexture.accumulateBuffer, packed
    // four per uint4: mip `lod` uses mipOffset[lod / 4][lod % 4].
    uint4 mipOffset[16];

    // Reference image; loss() loads mip 0 of it at the fragment's screen position.
    Texture2D texRef;
    // Sampler handed to DifferentiableTexture.sample by shadeFragment.
    SamplerState sampler;
    // Differentiable texture sampled by shadeFragment; its accumulateBuffer
    // receives the gradients written by the backward pass.
    DifferentiableTexture bwdTexture;
}

// Per-vertex input consumed by vertexMain.
struct AssembledVertex
{
    // Vertex position; vertexMain also reuses its xy as the texture coordinate.
    float3	position : POSITION;
};

// Holder for a single RGBA color.
// NOTE(review): not referenced anywhere in the visible part of this file.
struct Fragment
{
    float4 color;
};

// Output of vertexMain.
struct VertexStageOutput
{
    // Texture coordinate; vertexMain fills it with the vertex position's xy.
    float2 uv : UV;
    // Vertex position after multiplication by modelViewProjection.
    float4          sv_position     : SV_Position;
};

// Shades one fragment: samples bwdTexture at twice the incoming uv and
// returns its RGB with alpha forced to 1.
[BackwardDifferentiable]
float4 shadeFragment(float2 uv)
{
    // Scale the coordinate by two before sampling.
    uv = uv * 2;

    // Per-fragment uv differentials from the coarse derivative intrinsics;
    // they are excluded from differentiation.
    float2 uvDdx = no_diff ddx_coarse(uv);
    float2 uvDdy = no_diff ddy_coarse(uv);

    float4 texel = DifferentiableTexture.sample(bwdTexture, sampler, uv, uvDdx, uvDdy);
    float3 rgb = texel.xyz;
    return float4(rgb, 1.0);
}

// Per-channel squared difference between the shaded fragment and the
// reference image texel at the fragment's screen position.
// Neither argument is differentiated.
[BackwardDifferentiable]
float3 loss(no_diff float2 uv, no_diff float4 screenPos)
{
    // Reference texel at mip 0, addressed by the truncated screen position.
    int3 refTexel = int3(int2(screenPos.xy), 0);
    float3 target = (no_diff texRef.Load(refTexel)).xyz;

    float3 residual = shadeFragment(uv).xyz - target;
    return residual * residual;
}

// Vertex entry point: forwards position.xy as the texture coordinate and
// transforms the position by modelViewProjection.
[shader("vertex")]
VertexStageOutput vertexMain(
    AssembledVertex assembledVertex)
{
    float3 pos = assembledVertex.position;

    VertexStageOutput result;
    result.uv = pos.xy;
    result.sv_position = mul(modelViewProjection, float4(pos, 1.0));
    return result;
}

// Component-wise square of v.
float3 sqr(float3 v)
{
    return v * v;
}

// Forward-rendering fragment entry point: outputs the shaded color for uv.
[shader("fragment")]
float4 fragmentMain(
    float2 uv : UV) : SV_Target
{
    float4 shaded = shadeFragment(uv);
    return shaded;
}

// Training fragment entry point: runs the backward pass of loss() with a
// unit gradient on each channel, then outputs the loss value with alpha 1.
[shader("fragment")]
float4 diffFragmentMain(
    float2 uv : UV,
    float4 screenPos : SV_POSITION) : SV_Target
{
    // Backward pass, seeded with d(loss) = 1 per channel.
    float3 lossSeed = float3(1.0);
    __bwd_diff(loss)(uv, screenPos, lossSeed);

    // Forward evaluation for the value written to the render target.
    float3 lossValue = loss(uv, screenPos);
    return float4(lossValue, 1.0);
}