summaryrefslogtreecommitdiffstats
path: root/docs
diff options
context:
space:
mode:
Diffstat (limited to 'docs')
-rw-r--r--docs/user-guide/a1-02-slangpy.md175
1 files changed, 175 insertions, 0 deletions
diff --git a/docs/user-guide/a1-02-slangpy.md b/docs/user-guide/a1-02-slangpy.md
index 7c652228e..fff806229 100644
--- a/docs/user-guide/a1-02-slangpy.md
+++ b/docs/user-guide/a1-02-slangpy.md
@@ -217,6 +217,181 @@ dX = tensor([[6., 8.],
[0., 2.]])
```
+## Back-propagating Derivatives through Complex Access Patterns
+
+In most common scenarios, a kernel function will access input tensors in a complex pattern instead of mapping
+1:1 from an input element to an output element, like the `square` example shown above. When you have a kernel
+function that access many different elements from the input tensors and use them to compute an output element,
+the derivatives of each input element can't be represented directly as a function parameter, like the `x` in `square(x)`.
+
+Consider a 3x3 box filtering kernel that computes for each pixel in a 2D image, the average value of its
+surrounding 3x3 pixel block. We can write a Slang function that computes the value of an output pixel:
+```csharp
+float computeOutputPixel(TensorView<float> input, uint2 pixelLoc)
+{
+ int width = input.dim(0);
+ int height = input.dim(1);
+
+ // Track the sum of neighboring pixels and the number
+ // of pixels currently accumulated.
+ int count = 0;
+ float sumValue = 0.0;
+
+ // Iterate through the surrounding area.
+ for (int x = pixelLoc.x - 1; x <= pixelLoc.x + 1; x++)
+ {
+ // Skip out of bounds pixels.
+ if (x < 0 || x >= width) continue;
+ if (y < 0 || y >= height) continue;
+
+ for (int y = pixelLoc.y - 1; y <= pixelLoc.y + 1; y++)
+ {
+ sumValue += input[x, y];
+ count++;
+ }
+ }
+
+ // Comptue the average value.
+ sumValue /= count;
+
+ return sumValue;
+}
+```
+
+We can define our kernel function to compute the entire output image by calling `computeOutputPixel`:
+
+```csharp
+[CudaKernel]
+void boxFilter_fwd(TensorView<float> input, TensorView<float> output)
+{
+ uint2 pixelLoc = (cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx()).xy;
+ int width = input.dim(0);
+ int height = input.dim(1);
+ if (pixelLoc.x >= width) return;
+ if (pixelLoc.y >= height) return;
+
+ float outputValueAtPixel = computeOutputPixel(input, pixelLoc)
+
+ // Write to output tensor.
+ output[pixelLoc] = outputValueAtPixel;
+}
+```
+
+How do we define the backward derivative propagation kernel? Note that in this example, there
+isn't a function like `square` that we can just mark as `[BackwardDifferentiable]` and
+call `__bwd_diff(square)` to get back the derivative of an input parameter.
+
+In this example, the input comes from multiple elements in a tensor. How do we propagate the
+derivatives to those input elements?
+
+The solution is to wrap tensor access with a custom function:
+```csharp
+float getInputElement(
+ TensorView<float> input,
+ TensorView<float> inputGradToPropagateTo,
+ uint2 loc)
+{
+ return input[loc];
+}
+```
+
+Note that the `getInputElement` function simply returns `input[loc]` and is not using the
+`inputGradToPropagateTo` parameter. That is intended. The `inputGradToPropagateTo` parameter
+is used to hold the backward propagated derivatives of each input element, and is reserved for later use.
+
+Now we can replace all direct accesses to `input` with a call to `getInputElement`. The
+`computeOutputPixel` can be implemented as following:
+
+```csharp
+[BackwardDifferentiable]
+float computeOutputPixel(
+ TensorView<float> input,
+ TensorView<float> inputGradToPropagateTo,
+ uint2 pixelLoc)
+{
+ int width = input.dim(0);
+ int height = input.dim(1);
+
+ // Track the sum of neighboring pixels and the number
+ // of pixels currently accumulated.
+ int count = 0;
+ float sumValue = 0.0;
+
+ // Iterate through the surrounding area.
+ for (int x = pixelLoc.x - 1; x <= pixelLoc.x + 1; x++)
+ {
+ // Skip out of bounds pixels.
+ if (x < 0 || x >= width) continue;
+ if (y < 0 || y >= height) continue;
+
+ for (int y = pixelLoc.y - 1; y <= pixelLoc.y + 1; y++)
+ {
+ sumValue += getInputElement(input, inputGradToPropagateTo, uint2(x, y));
+ count++;
+ }
+ }
+
+ // Comptue the average value.
+ sumValue /= count;
+
+ return sumValue;
+}
+```
+
+The main changes compared to our original version of `computeOutputPixel` are:
+- Added a `inputGradToPropagateTo` parameter.
+- Modified `input[x,y]` with a call to `getInputElement`.
+- Added a `[BackwardDifferentiable]` attribute to the function.
+
+With that, we can define our backward kernel function:
+
+```csharp
+[CudaKernel]
+void boxFilter_bwd(
+ TensorView<float> input,
+ TensorView<float> resultGradToPropagateFrom,
+ TensorView<float> inputGradToPropagateTo)
+{
+ uint2 pixelLoc = (cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx()).xy;
+ int width = input.dim(0);
+ int height = input.dim(1);
+ if (pixelLoc.x >= width) return;
+ if (pixelLoc.y >= height) return;
+
+ __bwd_diff(computeOutputPixel)(input, inputGradToPropagateTo, pixelLoc);
+}
+```
+
+The kernel function simply calls `__bwd_diff(computeOutputPixel)` without taking any return values from the call
+and without writing to any elements in the final `inputGradToPropagateTo` tensor. But when exactly does the proapgated
+output get written to the output gradient tensor (`inputGradToPropagateTo`)?
+
+And that logic is defined in our final piece of code:
+```csharp
+[BackwardDerivativeOf(getInputElement)]
+void getInputElement_bwd(
+ TensorView<float> input,
+ TensorView<float> inputGradToPropagateTo,
+ uint2 loc,
+ float derivative)
+{
+ inputGradToPropagateTo.InterlockedAdd(loc, derivative);
+}
+```
+
+Here, we are providing a custom defined backward propagation function for `getInputElement`.
+In this function, we simply add `derivative` to the element in `inputGradToPropagateTo` tensor.
+
+When we call `__bwd_diff(computeOutputPixel)` in `boxFilter_bwd`, the Slang compiler will automatically
+differentiate all operations and function calls in `computeOutputPixel`. By wrapping the tensor element access
+with `getInputElement` and by providing a custom backward propagation function of `getInputElement`, we are effectively
+telling the compiler what to do when a derivative propagates to an input tensor element. Inside the body
+of `getInputElement_bwd`, we define what to do then: atomically adds the derivative propagated to the input element
+in the `inputGradToPropagateTo` tensor. Therefore after running `boxFilter_bwd`, the `inputGradToPropagateTo` tensor will contain all the
+back propagated derivative values.
+
+Again, to understand all the details of the automatic differentiation system, please refer to the
+[Automatic Differentiation](07-autodiff.md) chapter for a detailed explanation.
## Builtin Library Support for PyTorch Interop