diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-28 15:19:03 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-28 15:19:03 -0700 |
| commit | a61f089fbc4b944d058e6417d8a0d22d57ca5c92 (patch) | |
| tree | 4fa1a0c6370b8d34262d297653239f48aa004c71 /docs/user-guide | |
| parent | 8f03af5e5b580170fab3fd2fe6144f92038c7701 (diff) | |
Add slangpy doc, fix cuda prelude. (#2748)
* Add slangpy doc, fix cuda prelude.
* more bug fix.
* fix.
* fix.
* More fix.
* fix.
* f
* fix prelude.
* update prelude.
* update doc
* Update prelude.
* add zeros_like
* update doc.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'docs/user-guide')
| -rw-r--r-- | docs/user-guide/a1-02-slangpy.md | 266 | ||||
| -rw-r--r-- | docs/user-guide/a1-special-topics.md | 3 | ||||
| -rw-r--r-- | docs/user-guide/toc.html | 7 |
3 files changed, 275 insertions, 1 deletions
diff --git a/docs/user-guide/a1-02-slangpy.md b/docs/user-guide/a1-02-slangpy.md new file mode 100644 index 000000000..7bcc47389 --- /dev/null +++ b/docs/user-guide/a1-02-slangpy.md @@ -0,0 +1,266 @@ +--- +layout: user-guide +--- + +Using Slang to Write PyTorch Kernels +========================================================= + +If you are a PyTorch user looking for a way to write complex, high performance and automatically differentiated kernel functions in a per-thread instead of full-tensor style, give Slang a try. Slang is evolved from on a traditional shading language that were designed +to provide a simple way to define kernel functions that runs extremely fast in graphics applications. With the latest addition of +automatic differentiation and PyTorch interop features, Slang provides a streamlined solution to author auto-differentiated kernels +that runs at the speed of light with a strongly typed, per-thread programming model. + +## Getting Started with `slangpy` + +In this tutorial, we will use a simple example to walkthrough the steps to use Slang in your PyTorch project. + +### Writing a simple kernel function as a Slang module + +Assume we want to write a kernel function that computes `x*x` for each element in the input tensor in Slang. To do so, +we start by creating a `square.slang` file: + +```csharp +// square.slang +float square(float x) +{ + return x * x; +} +``` + +This function is self explanatory. To use it in PyTorch, we need to write a GPU kernel function (that maps to a +`__global__` CUDA function) that defines how to compute each element of the input tensor. So we continue to write +the following Slang function: + +```csharp +[CudaKernel] +void square_fwd_kernel(TensorView<float> input, TensorView<float> output) +{ + uint3 globalIdx = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); + + if (globalIdx.x > input.size(0) || globalIdx.x > input.size(1)) + return; + float result = square(input[globalIdx.xy]); + output[globalIdx.xy] = result; +} +``` + +This code follows the standard pattern of a typical CUDA kernel function. It takes as input +two tensors, `input` and `output`. +It first obtains the global dispatch index of the current thread and performs range check to make sure we don't read or write out +of the bounds of input and output tensors, and then calls `square()` to compute the per-element result, and +store it at the corresponding location in `output` tensor. + +With a kernel function defined, we then need to expose a CPU(host) function that defines how this kernel is dispatched: +```csharp +[TorchEntryPoint] +TorchTensor<float> square_fwd(TorchTensor<float> input) +{ + var result = TorchTensor<float>.zerosLike(input); + let blockCount = uint3(1); + let groupSize = uint3(result.size(0), result.size(1), 1); + __dispatch_kernel(square_fwd_kernel, blockCount, groupSize)(input, result); + return result; +} +``` +Here, we first call `TorchTensor<float>.alloc` to allocate a 2D-tensor that has the same size as the input. +This function returns a `TorchTensor<float>` object that represents a CPU handle of a PyTorch tensor. +Then we launch `square_fwd_kernel` with the `__dispatch_kernel` syntax. Note that we can directly pass +`TorchTensor<float>` arguments to a `TensorView<float>` parameter and the compiler will automatically convert +the type and obtain a view into the tensor that can be accessed by the GPU kernel function. + +### Calling Slang module from Python + +Next, let's see how we can call the `square_fwd` function we defined in the Slang module. +To do so, we use a python package called `slangpy`. You can obtain it with + +```bash +pip install slangpy +``` + +With that, you can use the following code code call Slang function from Python: + +```python +import torch +import slangpy + +m = slangpy.loadModule("square.slang") + +x = torch.randn(2,2) +print(f"X = {x}") +y = m.square_fwd(x) +print(f"Y = {y.cpu()}") +``` + +Result output: +``` +X = tensor([[ 0.1407, 0.6594], + [-0.8978, -1.7230]]) +Y = tensor([[0.0198, 0.4349], + [0.8060, 2.9688]]) +``` + +And that's it! `slangpy.loadModule` uses JIT compilation to compile your Slang source into CUDA binary. +It may take a little longer the first time you execute the script, but the result will be cached and as +long as the kernel code is not changed, future runs will not rebuild the CUDA kernel. + +Because the PyTorch JIT system requires `ninja`, you need to make sure `ninja` is installed on your system +and is discoverable from the current environment. + +### Exposing an automatically differentiated kernel to PyTorch + +The above example demonstrates how to write a simple kernel function in Slang and call it from Python. +Another major benefit of using Slang is that the Slang compiler support generating backward derivative +propagation functions automatically. + +In the following section, we walkthrough how to use Slang to generate a backward propagation function +for `square`, and expose it to PyTorch as an autograd function. + +First we need to tell Slang compiler that we need the `square` function to be considered a differentiable function so Slang compiler can generate a backward derivative propagation function for it: +```csharp +[BackwardDifferentiable] +float square(float x) +{ + return x * x; +} +``` +This is done by simply adding a `[BackwardDifferentiable]` attribute to our `square`function. + +With that, we can now define `square_bwd_kernel` that performance backward propagation as: + +```csharp +[CudaKernel] +void square_bwd_kernel(TensorView<float> input, TensorView<float> grad_out, TensorView<float> grad_propagated) +{ + uint3 globalIdx = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); + + if (globalIdx.x > input.size(0) || globalIdx.x > input.size(1)) + return; + + DifferentialPair<float> dpInput = diffPair(input[globalIdx.xy]); + var gradInElem = grad_out[globalIdx.xy]; + __bwd_diff(square)(dpInput, gradInElem); + grad_propagated[globalIdx.xy] = dpInput.d; +} +``` + +Note that the function follows the same structure of `square_fwd_kernel`, with the only difference being that +instead of calling into `square` to compute the forward value for each tensor element, we are calling `__bwd_diff(square)` +that represents the automatically generated backward propagation function of `squre`. +`__bwd_diff(squre)` will have the following signature: +```csharp +void __bwd_diff_squre(inout DifferentialPair<float> dpInput, float dOut); +``` + +Where the first parameter, `dpInput` represents a pair of original and derivative value for `input`, and the second parameter, +`dOut`, represents the initial derivative with regard to some latent variable that we wish to backprop through. The resulting +derivative will be stored in `dpInput.d`. For example: + +```csharp +// construct a pair where the primal value is 3, and derivative value is 0. +var dp = diffPair(3.0); +__bwd_diff(square)(dp, 1.0); +// dp.d is now 6.0 +``` + +Similarly to `squre_fwd`, we can define the host side function `square_bwd` as: + +```csharp +[TorchEntryPoint] +TorchTensor<float> square_bwd(TorchTensor<float> input, TorchTensor<float> grad_out) +{ + var grad_propagated = TorchTensor<float>.zerosLike(input); + let blockCount = uint3(1); + let groupSize = uint3(input.size(0), input.size(1), 1); + __dispatch_kernel(square_bwd_kernel, blockCount, groupSize)(input, grad_out, grad_propagated); + return grad_propagated; +} +``` + +You can refer [this documentation](07-autodiff.md) for a detailed reference of Slang's automatic differentiation system. + +With this, the python script `slangpy.loadModule("square.slang")` will now return +a scope that defines two functions, `square_fwd` and `square_bwd`. We can then use these +two functions to define a PyTorch autograd kernel class: + +```python +m = slangpy.loadModule("square.slang") + +class MySquareFuncInSlang(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return m.square_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + [input] = ctx.saved_tensors + return m.square_bwd(input, grad_output) +``` + +Now we can use the autograd function `MySquareFuncInSlang` in our python script: + +```python +x = torch.tensor([[3.0, 4.0],[0.0, 1.0]], requires_grad=True, device=cuda_device) +print(f"X = {x}") +y_pred = MySquareFuncInSlang.apply(x) +loss = y_pred.sum() +loss.backward() +print(f"dX = {x.grad.cpu()}") +``` + +Output: +``` +X = tensor([[3., 4.], + [0., 1.]], device='cuda:0', requires_grad=True) +dX = tensor([[6., 8.], + [0., 2.]]) +``` + + +## Builtin Types for PyTorch Interop + +As shown in previous tutorial, Slang has defined the `TorchTensor<T>` and `TensorView<T>` type for interop with PyTorch +tensors. The `TorchTensor<T>` represents the CPU view of a tensor and provides methods to allocate a new tensor object. +The `TensorView<T>` represents the GPU view of a tensor and provides accesors to read write tensor data. + +Following is a list of builtin methods provided by each type. + +### `static TorchTensor<T> TorchTensor<T>.alloc(uint x, uint y, ...)` +Allocates a new PyTorch tensor with the given dimensions. + +### `static TorchTensor<T> TorchTensor<T>.zerosLike(TorchTensor<T> other)` +Allocates a new PyTorch tensor that has the same dimensions as `other` and initialize it to zero. + +### `uint TorchTensor<T>.dims()` +Returns the tensor's dimension count. + +### `uint TorchTensor<T>.size(int dim)` +Returns the tensor's size (in number of elements) at `dim`. + +### `uint TorchTensor<T>.stride(int dim)` +Returns the tensor's stride (in bytes) at `dim`. + +### `TensorView<T>.operator[uint x, uint y, ...]` +Provide an accessor to data content in a tensor. + +### `TensorView<T>.operator[vector<uint, N> index]` +Provide an accessor to data content in a tensor, indexed by a uint vector. +`tensor[uint3(1,2,3)]` is equivalent to `tensor[1,2,3]`. + +### `uint TensorView<T>.dims()` +Returns the tensor's dimension count. + +### `uint TensorView<T>.size(int dim)` +Returns the tensor's size (in number of elements) at `dim`. + +### `uint TensorView<T>.stride(int dim)` +Returns the tensor's stride (in bytes) at `dim`. + +### `cudaThreadIdx()` +Returns the `threadIdx` variable in CUDA. + +### `cudaBlockIdx()` +Returns the `blockIdx` variable in CUDA. + +### `cudaBlockDim()` +Returns the `blockDim` variable in CUDA.
\ No newline at end of file diff --git a/docs/user-guide/a1-special-topics.md b/docs/user-guide/a1-special-topics.md index 3091621ea..33c863eec 100644 --- a/docs/user-guide/a1-special-topics.md +++ b/docs/user-guide/a1-special-topics.md @@ -8,4 +8,5 @@ Special Topics This chapter covers several additional topics on using Slang. These topics do not belong to any categories covered in previous chapters, but they address specific issues that developers may frequently encounter. In this chapter: -1. [Handling matrix layout differences on different platforms](a1-01-matrix-layout.md)
\ No newline at end of file +1. [Handling matrix layout differences on different platforms](a1-01-matrix-layout.md) +2. [Using Slang to write PyTorch kernels](a1-02-slangpy.md)
\ No newline at end of file diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html index 08f55ec0a..139d895fa 100644 --- a/docs/user-guide/toc.html +++ b/docs/user-guide/toc.html @@ -102,6 +102,13 @@ <li data-link="a1-01-matrix-layout#overriding-default-matrix-layout"><span>Overriding default matrix layout</span></li> </ul> </li> +<li data-link="a1-02-slangpy"><span>Using Slang to Write PyTorch Kernels</span> +<ul class="toc_list"> +<li data-link="a1-02-slangpy#writing-a-simple-kernel-function-as-a-slang-module"><span>Writing a simple kernel function as a Slang module</span></li> +<li data-link="a1-02-slangpy#calling-slang-module-from-python"><span>Calling Slang module from Python</span></li> +<li data-link="a1-02-slangpy#exposing-an-automatically-differentiated-kernel-to-pytorch"><span>Exposing an automatically differentiated kernel to PyTorch</span></li> +</ul> +</li> </ul> </li> </ul> |
