summary | refs | log | tree | commit | diff | stats
path: root/examples
diff options
context:
space:
mode:
author Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> 2023-10-25 16:31:49 -0400
committer GitHub <noreply@github.com> 2023-10-25 13:31:49 -0700
commit 4572976fd60817b9e2644b6fcadbf34511e770a9 (patch)
tree b4eea5274c3a98d2b4c1b05e6f1eecb2ffd450a3 /examples
parent 1a8216b7cd6f272253e7381bc520c65b7dd38b24 (diff)
Update autodiff-texture example with one that auto-diffs a reference impl. (#3288)
Diffstat (limited to 'examples')
-rw-r--r-- examples/autodiff-texture/train.slang 141
1 files changed, 59 insertions, 82 deletions
diff --git a/examples/autodiff-texture/train.slang b/examples/autodiff-texture/train.slang
index 16f3be35c..e171d3d71 100644
--- a/examples/autodiff-texture/train.slang
+++ b/examples/autodiff-texture/train.slang
@@ -1,79 +1,47 @@
-// shaders.slang
-
-struct BwdTexture
+// texture.slang
+
+// This struct encapsulates a differentiable texture object that uses
+// hardware sampling for the primal version, but substitutes a software reference
+// interpolation implementation to generate the backward pass.
+//
+// This specific implementation also chooses to use fast fixed-point
+// atomics to accumulate the derivative (suitable for this example, but maybe
+// not in general).
+//
+struct DifferentiableTexture
{
- RWStructuredBuffer<int> accumulateBuffer;
- Texture2D texture;
+ RWStructuredBuffer<int> accumulateBuffer; // Per-mip-level accumulate buffer
+ Texture2D texture; // Hardware texture handle.
float minLOD;
- void writeDiffToTexel(uint offset, uint layerW, uint layerH, float x, float y, float3 val)
- {
- int4 uval = int4(int3(val * 65536), 1);
- InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 0], uval.x);
- InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 1], uval.y);
- InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 2], uval.z);
- InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 3], uval.w);
- }
- void broadcastDiffToLayer(uint lod, float3 diff, float2 uv, uint w, uint h)
+ [BackwardDerivative(bwd_LoadTexel)]
+ float4 LoadTexel(int3 location, int2 offset, uint dLayerW, uint dMipOffset)
{
- var offset = mipOffset[lod / 4][lod % 4];
- w >>= lod;
- h >>= lod;
- uv = uv - floor(uv);
- float2 loc = uv * float2(w, h) - float2(0.5);
- float x0 = floor(loc.x);
- float y0 = floor(loc.y);
- float fracX = loc.x - x0;
- float fracY = loc.y - y0;
- float x1 = x0 + 1;
- float y1 = y0 + 1;
- if (x0 < 0) x0 += w;
- if (y0 < 0) y0 += h;
- if (x1 >= w) x1 -= w;
- if (y1 >= h) y1 -= h;
- float weight0 = 1.0f - fracY;
- float weight1 = fracY;
- float weight00 = weight0 * (1.0f - fracX);
- float weight01 = weight0 * fracX;
- float weight10 = weight1 * (1.0f - fracX);
- float weight11 = weight1 * fracX;
-
- writeDiffToTexel(offset, w, h, x0, y0, weight00 * diff);
- writeDiffToTexel(offset, w, h, x1, y0, weight01 * diff);
- writeDiffToTexel(offset, w, h, x0, y1, weight10 * diff);
- writeDiffToTexel(offset, w, h, x1, y1, weight11 * diff);
+ return texture.Load(location, offset);
}
- void sampleTexture_trilinear_bwd(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY, float4 dOut)
+ void bwd_LoadTexel(int3 location, int2 offset, uint dLayerW, uint dMipOffset, float4 val)
{
- dX = dX * float2(w, h);
- dY = dY * float2(w, h);
-
- // Isotropic filter.
- float lengthX = length(dX);
- float lengthY = length(dY);
- float LOD = log2(max(lengthX, lengthY));
- float maxLOD = levels - 1;
- float clampedLOD = max(minLOD, (min(maxLOD, LOD)));
-
- float lodFrac = clampedLOD - floor(clampedLOD);
- uint lod0 = (uint)floor(clampedLOD);
- uint lod1 = min(levels - 1, lod0 + 1);
- float weightLod0 = 1.0 - lodFrac;
- float weightLod1 = lodFrac;
- weightLod0 = 1.0;
- broadcastDiffToLayer(lod0, dOut.xyz * weightLod0, uv, w, h);
- broadcastDiffToLayer(lod1, dOut.xyz * weightLod1, uv, w, h);
+ // Ignore the alpha dimension for this example.
+ int4 uval = int4(int3(val.xyz * 65536), 1);
+
+ // We'll use fast fixed-point atomics instead of floats.
+ InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 0], uval.x);
+ InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 1], uval.y);
+ InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 2], uval.z);
+ InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 3], uval.w);
}
+ // Software reference implementation of linear filtering.
+ [BackwardDifferentiable]
float4 sampleTexture_linear(uint lod, float2 uv, uint w, uint h)
{
w >>= lod;
h >>= lod;
- uv = uv - floor(uv);
+ uv = uv - no_diff(floor(uv));
float2 loc = uv * float2(w, h) - float2(0.5);
- float x0 = floor(loc.x);
- float y0 = floor(loc.y);
+ float x0 = no_diff(floor(loc.x));
+ float y0 = no_diff(floor(loc.y));
float fracX = loc.x - x0;
float fracY = loc.y - y0;
float x1 = x0 + 1;
@@ -88,12 +56,17 @@ struct BwdTexture
float weight01 = weight0 * fracX;
float weight10 = weight1 * (1.0f - fracX);
float weight11 = weight1 * fracX;
- return texture.Load(int3(int(x0), int(y0), int(lod)), int2(0)) * weight00 +
- texture.Load(int3(int(x1), int(y0), int(lod)), int2(0)) * weight01 +
- texture.Load(int3(int(x0), int(y1), int(lod)), int2(0)) * weight10 +
- texture.Load(int3(int(x1), int(y1), int(lod)), int2(0)) * weight11;
+
+ uint dLayerW = w >>= lod;
+ var offset = mipOffset[lod / 4][lod % 4];
+ return LoadTexel(int3(int(x0), int(y0), int(lod)), int2(0), dLayerW, offset) * weight00 +
+ LoadTexel(int3(int(x1), int(y0), int(lod)), int2(0), dLayerW, offset) * weight01 +
+ LoadTexel(int3(int(x0), int(y1), int(lod)), int2(0), dLayerW, offset) * weight10 +
+ LoadTexel(int3(int(x1), int(y1), int(lod)), int2(0), dLayerW, offset) * weight11;
}
+ // Software reference implementation of trilinear filtering.
+ [BackwardDifferentiable]
float4 sampleTexture_trilinear(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY)
{
dX = dX * float2(w, h);
@@ -106,7 +79,7 @@ struct BwdTexture
float maxLOD = levels - 1;
float clampedLOD = max(minLOD, (min(maxLOD, LOD)));
- float lodFrac = clampedLOD - floor(clampedLOD);
+ float lodFrac = clampedLOD - no_diff(floor(clampedLOD));
uint lod0 = (uint)floor(clampedLOD);
uint lod1 = min(levels - 1, lod0 + 1);
float weightLod0 = 1.0 - lodFrac;
@@ -117,27 +90,25 @@ struct BwdTexture
return v0 + v1;
}
- static float4 sample(BwdTexture t, SamplerState s, no_diff float2 uv)
+ // Note that there is no need to mark this [BackwardDifferentiable] since it has a substitute
+ // that is marked [BackwardDifferentiable]. The compiler automatically considers a call to
+ // sample() to be differentiable.
+ //
+ static float4 sample(DifferentiableTexture t, SamplerState s, float2 uv, float2 dX, float2 dY)
{
return t.texture.Sample(s, uv);
}
- [ForwardDerivativeOf(BwdTexture.sample)]
- static float4 fwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv)
+ // Software reference implementation of DifferentiableTexture.sample (trilinear only in this example)
+ [PrimalSubstituteOf(DifferentiableTexture.sample)]
+ [BackwardDifferentiable]
+ static float4 sample_reference_impl(DifferentiableTexture t, SamplerState s, float2 uv, float2 dX, float2 dY)
{
- return float4(0.0);
- }
-
- [BackwardDerivativeOf(BwdTexture.sample)]
- static void bwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv, float4 dOut)
- {
- float2 dX = ddx_coarse(uv);
- float2 dY = ddy_coarse(uv);
uint w;
uint h;
uint levels;
t.texture.GetDimensions(0, w, h, levels);
- t.sampleTexture_trilinear_bwd(w, h, levels, uv, dX, dY, dOut);
+ return t.sampleTexture_trilinear(w, h, levels, uv, dX, dY);
}
}
@@ -148,7 +119,7 @@ cbuffer Uniforms
Texture2D texRef;
SamplerState sampler;
- BwdTexture bwdTexture;
+ DifferentiableTexture bwdTexture;
}
struct AssembledVertex
@@ -168,9 +139,15 @@ struct VertexStageOutput
};
[BackwardDifferentiable]
-float4 shadeFragment(no_diff float2 uv)
+float4 shadeFragment(float2 uv)
{
- float3 color = BwdTexture.sample(bwdTexture, sampler, uv * 2).xyz;
+ uv = uv * 2;
+
+ // Compute fragment differentials using shader intrinsics.
+ float2 dX = no_diff ddx_coarse(uv);
+ float2 dY = no_diff ddy_coarse(uv);
+
+ float3 color = DifferentiableTexture.sample(bwdTexture, sampler, uv, dX, dY).xyz;
return float4(color, 1.0);
}