diff options
Diffstat (limited to 'docs/user-guide')
| -rw-r--r-- | docs/user-guide/a1-02-slangpy.md | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/docs/user-guide/a1-02-slangpy.md b/docs/user-guide/a1-02-slangpy.md index 7c652228e..fff806229 100644 --- a/docs/user-guide/a1-02-slangpy.md +++ b/docs/user-guide/a1-02-slangpy.md @@ -217,6 +217,181 @@ dX = tensor([[6., 8.], [0., 2.]]) ``` +## Back-propagating Derivatives through Complex Access Patterns + +In most common scenarios, a kernel function will access input tensors in a complex pattern instead of mapping +1:1 from an input element to an output element, like the `square` example shown above. When you have a kernel +function that access many different elements from the input tensors and use them to compute an output element, +the derivatives of each input element can't be represented directly as a function parameter, like the `x` in `square(x)`. + +Consider a 3x3 box filtering kernel that computes for each pixel in a 2D image, the average value of its +surrounding 3x3 pixel block. We can write a Slang function that computes the value of an output pixel: +```csharp +float computeOutputPixel(TensorView<float> input, uint2 pixelLoc) +{ + int width = input.dim(0); + int height = input.dim(1); + + // Track the sum of neighboring pixels and the number + // of pixels currently accumulated. + int count = 0; + float sumValue = 0.0; + + // Iterate through the surrounding area. + for (int x = pixelLoc.x - 1; x <= pixelLoc.x + 1; x++) + { + // Skip out of bounds pixels. + if (x < 0 || x >= width) continue; + if (y < 0 || y >= height) continue; + + for (int y = pixelLoc.y - 1; y <= pixelLoc.y + 1; y++) + { + sumValue += input[x, y]; + count++; + } + } + + // Comptue the average value. + sumValue /= count; + + return sumValue; +} +``` + +We can define our kernel function to compute the entire output image by calling `computeOutputPixel`: + +```csharp +[CudaKernel] +void boxFilter_fwd(TensorView<float> input, TensorView<float> output) +{ + uint2 pixelLoc = (cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx()).xy; + int width = input.dim(0); + int height = input.dim(1); + if (pixelLoc.x >= width) return; + if (pixelLoc.y >= height) return; + + float outputValueAtPixel = computeOutputPixel(input, pixelLoc) + + // Write to output tensor. + output[pixelLoc] = outputValueAtPixel; +} +``` + +How do we define the backward derivative propagation kernel? Note that in this example, there +isn't a function like `square` that we can just mark as `[BackwardDifferentiable]` and +call `__bwd_diff(square)` to get back the derivative of an input parameter. + +In this example, the input comes from multiple elements in a tensor. How do we propagate the +derivatives to those input elements? + +The solution is to wrap tensor access with a custom function: +```csharp +float getInputElement( + TensorView<float> input, + TensorView<float> inputGradToPropagateTo, + uint2 loc) +{ + return input[loc]; +} +``` + +Note that the `getInputElement` function simply returns `input[loc]` and is not using the +`inputGradToPropagateTo` parameter. That is intended. The `inputGradToPropagateTo` parameter +is used to hold the backward propagated derivatives of each input element, and is reserved for later use. + +Now we can replace all direct accesses to `input` with a call to `getInputElement`. The +`computeOutputPixel` can be implemented as following: + +```csharp +[BackwardDifferentiable] +float computeOutputPixel( + TensorView<float> input, + TensorView<float> inputGradToPropagateTo, + uint2 pixelLoc) +{ + int width = input.dim(0); + int height = input.dim(1); + + // Track the sum of neighboring pixels and the number + // of pixels currently accumulated. + int count = 0; + float sumValue = 0.0; + + // Iterate through the surrounding area. + for (int x = pixelLoc.x - 1; x <= pixelLoc.x + 1; x++) + { + // Skip out of bounds pixels. + if (x < 0 || x >= width) continue; + if (y < 0 || y >= height) continue; + + for (int y = pixelLoc.y - 1; y <= pixelLoc.y + 1; y++) + { + sumValue += getInputElement(input, inputGradToPropagateTo, uint2(x, y)); + count++; + } + } + + // Comptue the average value. + sumValue /= count; + + return sumValue; +} +``` + +The main changes compared to our original version of `computeOutputPixel` are: +- Added a `inputGradToPropagateTo` parameter. +- Modified `input[x,y]` with a call to `getInputElement`. +- Added a `[BackwardDifferentiable]` attribute to the function. + +With that, we can define our backward kernel function: + +```csharp +[CudaKernel] +void boxFilter_bwd( + TensorView<float> input, + TensorView<float> resultGradToPropagateFrom, + TensorView<float> inputGradToPropagateTo) +{ + uint2 pixelLoc = (cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx()).xy; + int width = input.dim(0); + int height = input.dim(1); + if (pixelLoc.x >= width) return; + if (pixelLoc.y >= height) return; + + __bwd_diff(computeOutputPixel)(input, inputGradToPropagateTo, pixelLoc); +} +``` + +The kernel function simply calls `__bwd_diff(computeOutputPixel)` without taking any return values from the call +and without writing to any elements in the final `inputGradToPropagateTo` tensor. But when exactly does the proapgated +output get written to the output gradient tensor (`inputGradToPropagateTo`)? + +And that logic is defined in our final piece of code: +```csharp +[BackwardDerivativeOf(getInputElement)] +void getInputElement_bwd( + TensorView<float> input, + TensorView<float> inputGradToPropagateTo, + uint2 loc, + float derivative) +{ + inputGradToPropagateTo.InterlockedAdd(loc, derivative); +} +``` + +Here, we are providing a custom defined backward propagation function for `getInputElement`. +In this function, we simply add `derivative` to the element in `inputGradToPropagateTo` tensor. + +When we call `__bwd_diff(computeOutputPixel)` in `boxFilter_bwd`, the Slang compiler will automatically +differentiate all operations and function calls in `computeOutputPixel`. By wrapping the tensor element access +with `getInputElement` and by providing a custom backward propagation function of `getInputElement`, we are effectively +telling the compiler what to do when a derivative propagates to an input tensor element. Inside the body +of `getInputElement_bwd`, we define what to do then: atomically adds the derivative propagated to the input element +in the `inputGradToPropagateTo` tensor. Therefore after running `boxFilter_bwd`, the `inputGradToPropagateTo` tensor will contain all the +back propagated derivative values. + +Again, to understand all the details of the automatic differentiation system, please refer to the +[Automatic Differentiation](07-autodiff.md) chapter for a detailed explanation. ## Builtin Library Support for PyTorch Interop |
