From 4572976fd60817b9e2644b6fcadbf34511e770a9 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 25 Oct 2023 16:31:49 -0400 Subject: Update autodiff-texture example with one that auto-diffs a reference impl. (#3288) --- examples/autodiff-texture/train.slang | 141 ++++++++++++++-------------------- 1 file changed, 59 insertions(+), 82 deletions(-) (limited to 'examples') diff --git a/examples/autodiff-texture/train.slang b/examples/autodiff-texture/train.slang index 16f3be35c..e171d3d71 100644 --- a/examples/autodiff-texture/train.slang +++ b/examples/autodiff-texture/train.slang @@ -1,79 +1,47 @@ -// shaders.slang - -struct BwdTexture +// texture.slang + +// This class encapsulates a differentiable texture object that uses +// hardware sampling for the primal version, but substitutes a reference +// interpolation implementation to generate backward pass. +// +// This specific implementation also makes the choice to use fast fixed point +// atomics to accumulate the derivative (suitable for this example, but maybe +// not in general) +// +struct DifferentiableTexture { - RWStructuredBuffer accumulateBuffer; - Texture2D texture; + RWStructuredBuffer accumulateBuffer; // Per-mip-level accumulate buffer + Texture2D texture; // Hardware texture handle. 
float minLOD; - void writeDiffToTexel(uint offset, uint layerW, uint layerH, float x, float y, float3 val) - { - int4 uval = int4(int3(val * 65536), 1); - InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 0], uval.x); - InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 1], uval.y); - InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 2], uval.z); - InterlockedAdd(accumulateBuffer[offset + ((uint)y * layerW + (uint)x) * 4 + 3], uval.w); - } - void broadcastDiffToLayer(uint lod, float3 diff, float2 uv, uint w, uint h) + [BackwardDerivative(bwd_LoadTexel)] + float4 LoadTexel(int3 location, int2 offset, uint dLayerW, uint dMipOffset) { - var offset = mipOffset[lod / 4][lod % 4]; - w >>= lod; - h >>= lod; - uv = uv - floor(uv); - float2 loc = uv * float2(w, h) - float2(0.5); - float x0 = floor(loc.x); - float y0 = floor(loc.y); - float fracX = loc.x - x0; - float fracY = loc.y - y0; - float x1 = x0 + 1; - float y1 = y0 + 1; - if (x0 < 0) x0 += w; - if (y0 < 0) y0 += h; - if (x1 >= w) x1 -= w; - if (y1 >= h) y1 -= h; - float weight0 = 1.0f - fracY; - float weight1 = fracY; - float weight00 = weight0 * (1.0f - fracX); - float weight01 = weight0 * fracX; - float weight10 = weight1 * (1.0f - fracX); - float weight11 = weight1 * fracX; - - writeDiffToTexel(offset, w, h, x0, y0, weight00 * diff); - writeDiffToTexel(offset, w, h, x1, y0, weight01 * diff); - writeDiffToTexel(offset, w, h, x0, y1, weight10 * diff); - writeDiffToTexel(offset, w, h, x1, y1, weight11 * diff); + return texture.Load(location, offset); } - void sampleTexture_trilinear_bwd(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY, float4 dOut) + void bwd_LoadTexel(int3 location, int2 offset, uint dLayerW, uint dMipOffset, float4 val) { - dX = dX * float2(w, h); - dY = dY * float2(w, h); - - // Isotropic filter. 
- float lengthX = length(dX); - float lengthY = length(dY); - float LOD = log2(max(lengthX, lengthY)); - float maxLOD = levels - 1; - float clampedLOD = max(minLOD, (min(maxLOD, LOD))); - - float lodFrac = clampedLOD - floor(clampedLOD); - uint lod0 = (uint)floor(clampedLOD); - uint lod1 = min(levels - 1, lod0 + 1); - float weightLod0 = 1.0 - lodFrac; - float weightLod1 = lodFrac; - weightLod0 = 1.0; - broadcastDiffToLayer(lod0, dOut.xyz * weightLod0, uv, w, h); - broadcastDiffToLayer(lod1, dOut.xyz * weightLod1, uv, w, h); + // Ignore alpha dimension for this example.. + int4 uval = int4(int3(val.xyz * 65536), 1); + + // We'll use fast fixed point atomics instead of floats. + InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 0], uval.x); + InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 1], uval.y); + InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 2], uval.z); + InterlockedAdd(accumulateBuffer[dMipOffset + ((uint)location.y * dLayerW + (uint)location.x) * 4 + 3], uval.w); } + // Software reference implementation of linear filtering. 
+ [BackwardDifferentiable] float4 sampleTexture_linear(uint lod, float2 uv, uint w, uint h) { w >>= lod; h >>= lod; - uv = uv - floor(uv); + uv = uv - no_diff(floor(uv)); float2 loc = uv * float2(w, h) - float2(0.5); - float x0 = floor(loc.x); - float y0 = floor(loc.y); + float x0 = no_diff(floor(loc.x)); + float y0 = no_diff(floor(loc.y)); float fracX = loc.x - x0; float fracY = loc.y - y0; float x1 = x0 + 1; @@ -88,12 +56,17 @@ struct BwdTexture float weight01 = weight0 * fracX; float weight10 = weight1 * (1.0f - fracX); float weight11 = weight1 * fracX; - return texture.Load(int3(int(x0), int(y0), int(lod)), int2(0)) * weight00 + - texture.Load(int3(int(x1), int(y0), int(lod)), int2(0)) * weight01 + - texture.Load(int3(int(x0), int(y1), int(lod)), int2(0)) * weight10 + - texture.Load(int3(int(x1), int(y1), int(lod)), int2(0)) * weight11; + + uint dLayerW = w; /* w was already shifted down to this mip level's width at the top of the function; shifting again would give bwd_LoadTexel the wrong row stride for lod >= 1 */ + var offset = mipOffset[lod / 4][lod % 4]; + return LoadTexel(int3(int(x0), int(y0), int(lod)), int2(0), dLayerW, offset) * weight00 + + LoadTexel(int3(int(x1), int(y0), int(lod)), int2(0), dLayerW, offset) * weight01 + + LoadTexel(int3(int(x0), int(y1), int(lod)), int2(0), dLayerW, offset) * weight10 + + LoadTexel(int3(int(x1), int(y1), int(lod)), int2(0), dLayerW, offset) * weight11; } + // Software reference implementation of trilinear filtering. 
+ [BackwardDifferentiable] float4 sampleTexture_trilinear(uint w, uint h, uint levels, float2 uv, float2 dX, float2 dY) { dX = dX * float2(w, h); @@ -106,7 +79,7 @@ struct BwdTexture float maxLOD = levels - 1; float clampedLOD = max(minLOD, (min(maxLOD, LOD))); - float lodFrac = clampedLOD - floor(clampedLOD); + float lodFrac = clampedLOD - no_diff(floor(clampedLOD)); uint lod0 = (uint)floor(clampedLOD); uint lod1 = min(levels - 1, lod0 + 1); float weightLod0 = 1.0 - lodFrac; @@ -117,27 +90,25 @@ struct BwdTexture return v0 + v1; } - static float4 sample(BwdTexture t, SamplerState s, no_diff float2 uv) + // Note that there is no need to mark this [BackwardDifferentiable] since it has a substitute + // that is marked [BackwardDifferentiable]. The compiler automatically considers a call to + // sample() to be differentiable. + // + static float4 sample(DifferentiableTexture t, SamplerState s, float2 uv, float2 dX, float2 dY) { return t.texture.Sample(s, uv); } - [ForwardDerivativeOf(BwdTexture.sample)] - static float4 fwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv) + // Software reference implementation of DifferentiableTexture.sample (trilinear only in this example) + [PrimalSubstituteOf(DifferentiableTexture.sample)] + [BackwardDifferentiable] + static float4 sample_reference_impl(DifferentiableTexture t, SamplerState s, float2 uv, float2 dX, float2 dY) { - return float4(0.0); - } - - [BackwardDerivativeOf(BwdTexture.sample)] - static void bwd_sample(BwdTexture t, SamplerState s, no_diff float2 uv, float4 dOut) - { - float2 dX = ddx_coarse(uv); - float2 dY = ddy_coarse(uv); uint w; uint h; uint levels; t.texture.GetDimensions(0, w, h, levels); - t.sampleTexture_trilinear_bwd(w, h, levels, uv, dX, dY, dOut); + return t.sampleTexture_trilinear(w, h, levels, uv, dX, dY); } } @@ -148,7 +119,7 @@ cbuffer Uniforms Texture2D texRef; SamplerState sampler; - BwdTexture bwdTexture; + DifferentiableTexture bwdTexture; } struct AssembledVertex @@ -168,9 +139,15 @@ 
struct VertexStageOutput }; [BackwardDifferentiable] -float4 shadeFragment(no_diff float2 uv) +float4 shadeFragment(float2 uv) { - float3 color = BwdTexture.sample(bwdTexture, sampler, uv * 2).xyz; + uv = uv * 2; + + // Compute fragment differentials using shader intrinsics. + float2 dX = no_diff ddx_coarse(uv); + float2 dY = no_diff ddy_coarse(uv); + + float3 color = DifferentiableTexture.sample(bwdTexture, sampler, uv, dX, dY).xyz; return float4(color, 1.0); } -- cgit v1.2.3