diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-28 15:19:03 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-28 15:19:03 -0700 |
| commit | a61f089fbc4b944d058e6417d8a0d22d57ca5c92 (patch) | |
| tree | 4fa1a0c6370b8d34262d297653239f48aa004c71 | |
| parent | 8f03af5e5b580170fab3fd2fe6144f92038c7701 (diff) | |
Add slangpy doc, fix cuda prelude. (#2748)
* Add slangpy doc, fix cuda prelude.
* more bug fix.
* fix.
* fix.
* More fix.
* fix.
* f
* fix prelude.
* update prelude.
* update doc
* Update prelude.
* add zeros_like
* update doc.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | docs/user-guide/a1-02-slangpy.md | 266 | ||||
| -rw-r--r-- | docs/user-guide/a1-special-topics.md | 3 | ||||
| -rw-r--r-- | docs/user-guide/toc.html | 7 | ||||
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 22 | ||||
| -rw-r--r-- | prelude/slang-torch-prelude.h | 14 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 44 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-torch.cpp | 141 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 7 |
11 files changed, 440 insertions, 73 deletions
diff --git a/.gitignore b/.gitignore index 8cf116d3b..efbfbde24 100644 --- a/.gitignore +++ b/.gitignore @@ -76,3 +76,4 @@ build/**/*.recipe build/**/*.log *.dll *.dxil +/tests/serialization/*.map diff --git a/docs/user-guide/a1-02-slangpy.md b/docs/user-guide/a1-02-slangpy.md new file mode 100644 index 000000000..7bcc47389 --- /dev/null +++ b/docs/user-guide/a1-02-slangpy.md @@ -0,0 +1,266 @@ +--- +layout: user-guide +--- + +Using Slang to Write PyTorch Kernels +========================================================= + +If you are a PyTorch user looking for a way to write complex, high performance and automatically differentiated kernel functions in a per-thread instead of full-tensor style, give Slang a try. Slang is evolved from on a traditional shading language that were designed +to provide a simple way to define kernel functions that runs extremely fast in graphics applications. With the latest addition of +automatic differentiation and PyTorch interop features, Slang provides a streamlined solution to author auto-differentiated kernels +that runs at the speed of light with a strongly typed, per-thread programming model. + +## Getting Started with `slangpy` + +In this tutorial, we will use a simple example to walkthrough the steps to use Slang in your PyTorch project. + +### Writing a simple kernel function as a Slang module + +Assume we want to write a kernel function that computes `x*x` for each element in the input tensor in Slang. To do so, +we start by creating a `square.slang` file: + +```csharp +// square.slang +float square(float x) +{ + return x * x; +} +``` + +This function is self explanatory. To use it in PyTorch, we need to write a GPU kernel function (that maps to a +`__global__` CUDA function) that defines how to compute each element of the input tensor. So we continue to write +the following Slang function: + +```csharp +[CudaKernel] +void square_fwd_kernel(TensorView<float> input, TensorView<float> output) +{ + uint3 globalIdx = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); + + if (globalIdx.x > input.size(0) || globalIdx.x > input.size(1)) + return; + float result = square(input[globalIdx.xy]); + output[globalIdx.xy] = result; +} +``` + +This code follows the standard pattern of a typical CUDA kernel function. It takes as input +two tensors, `input` and `output`. +It first obtains the global dispatch index of the current thread and performs range check to make sure we don't read or write out +of the bounds of input and output tensors, and then calls `square()` to compute the per-element result, and +store it at the corresponding location in `output` tensor. + +With a kernel function defined, we then need to expose a CPU(host) function that defines how this kernel is dispatched: +```csharp +[TorchEntryPoint] +TorchTensor<float> square_fwd(TorchTensor<float> input) +{ + var result = TorchTensor<float>.zerosLike(input); + let blockCount = uint3(1); + let groupSize = uint3(result.size(0), result.size(1), 1); + __dispatch_kernel(square_fwd_kernel, blockCount, groupSize)(input, result); + return result; +} +``` +Here, we first call `TorchTensor<float>.alloc` to allocate a 2D-tensor that has the same size as the input. +This function returns a `TorchTensor<float>` object that represents a CPU handle of a PyTorch tensor. +Then we launch `square_fwd_kernel` with the `__dispatch_kernel` syntax. Note that we can directly pass +`TorchTensor<float>` arguments to a `TensorView<float>` parameter and the compiler will automatically convert +the type and obtain a view into the tensor that can be accessed by the GPU kernel function. + +### Calling Slang module from Python + +Next, let's see how we can call the `square_fwd` function we defined in the Slang module. +To do so, we use a python package called `slangpy`. You can obtain it with + +```bash +pip install slangpy +``` + +With that, you can use the following code code call Slang function from Python: + +```python +import torch +import slangpy + +m = slangpy.loadModule("square.slang") + +x = torch.randn(2,2) +print(f"X = {x}") +y = m.square_fwd(x) +print(f"Y = {y.cpu()}") +``` + +Result output: +``` +X = tensor([[ 0.1407, 0.6594], + [-0.8978, -1.7230]]) +Y = tensor([[0.0198, 0.4349], + [0.8060, 2.9688]]) +``` + +And that's it! `slangpy.loadModule` uses JIT compilation to compile your Slang source into CUDA binary. +It may take a little longer the first time you execute the script, but the result will be cached and as +long as the kernel code is not changed, future runs will not rebuild the CUDA kernel. + +Because the PyTorch JIT system requires `ninja`, you need to make sure `ninja` is installed on your system +and is discoverable from the current environment. + +### Exposing an automatically differentiated kernel to PyTorch + +The above example demonstrates how to write a simple kernel function in Slang and call it from Python. +Another major benefit of using Slang is that the Slang compiler support generating backward derivative +propagation functions automatically. + +In the following section, we walkthrough how to use Slang to generate a backward propagation function +for `square`, and expose it to PyTorch as an autograd function. + +First we need to tell Slang compiler that we need the `square` function to be considered a differentiable function so Slang compiler can generate a backward derivative propagation function for it: +```csharp +[BackwardDifferentiable] +float square(float x) +{ + return x * x; +} +``` +This is done by simply adding a `[BackwardDifferentiable]` attribute to our `square`function. + +With that, we can now define `square_bwd_kernel` that performance backward propagation as: + +```csharp +[CudaKernel] +void square_bwd_kernel(TensorView<float> input, TensorView<float> grad_out, TensorView<float> grad_propagated) +{ + uint3 globalIdx = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); + + if (globalIdx.x > input.size(0) || globalIdx.x > input.size(1)) + return; + + DifferentialPair<float> dpInput = diffPair(input[globalIdx.xy]); + var gradInElem = grad_out[globalIdx.xy]; + __bwd_diff(square)(dpInput, gradInElem); + grad_propagated[globalIdx.xy] = dpInput.d; +} +``` + +Note that the function follows the same structure of `square_fwd_kernel`, with the only difference being that +instead of calling into `square` to compute the forward value for each tensor element, we are calling `__bwd_diff(square)` +that represents the automatically generated backward propagation function of `squre`. +`__bwd_diff(squre)` will have the following signature: +```csharp +void __bwd_diff_squre(inout DifferentialPair<float> dpInput, float dOut); +``` + +Where the first parameter, `dpInput` represents a pair of original and derivative value for `input`, and the second parameter, +`dOut`, represents the initial derivative with regard to some latent variable that we wish to backprop through. The resulting +derivative will be stored in `dpInput.d`. For example: + +```csharp +// construct a pair where the primal value is 3, and derivative value is 0. +var dp = diffPair(3.0); +__bwd_diff(square)(dp, 1.0); +// dp.d is now 6.0 +``` + +Similarly to `squre_fwd`, we can define the host side function `square_bwd` as: + +```csharp +[TorchEntryPoint] +TorchTensor<float> square_bwd(TorchTensor<float> input, TorchTensor<float> grad_out) +{ + var grad_propagated = TorchTensor<float>.zerosLike(input); + let blockCount = uint3(1); + let groupSize = uint3(input.size(0), input.size(1), 1); + __dispatch_kernel(square_bwd_kernel, blockCount, groupSize)(input, grad_out, grad_propagated); + return grad_propagated; +} +``` + +You can refer [this documentation](07-autodiff.md) for a detailed reference of Slang's automatic differentiation system. + +With this, the python script `slangpy.loadModule("square.slang")` will now return +a scope that defines two functions, `square_fwd` and `square_bwd`. We can then use these +two functions to define a PyTorch autograd kernel class: + +```python +m = slangpy.loadModule("square.slang") + +class MySquareFuncInSlang(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return m.square_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + [input] = ctx.saved_tensors + return m.square_bwd(input, grad_output) +``` + +Now we can use the autograd function `MySquareFuncInSlang` in our python script: + +```python +x = torch.tensor([[3.0, 4.0],[0.0, 1.0]], requires_grad=True, device=cuda_device) +print(f"X = {x}") +y_pred = MySquareFuncInSlang.apply(x) +loss = y_pred.sum() +loss.backward() +print(f"dX = {x.grad.cpu()}") +``` + +Output: +``` +X = tensor([[3., 4.], + [0., 1.]], device='cuda:0', requires_grad=True) +dX = tensor([[6., 8.], + [0., 2.]]) +``` + + +## Builtin Types for PyTorch Interop + +As shown in previous tutorial, Slang has defined the `TorchTensor<T>` and `TensorView<T>` type for interop with PyTorch +tensors. The `TorchTensor<T>` represents the CPU view of a tensor and provides methods to allocate a new tensor object. +The `TensorView<T>` represents the GPU view of a tensor and provides accesors to read write tensor data. + +Following is a list of builtin methods provided by each type. + +### `static TorchTensor<T> TorchTensor<T>.alloc(uint x, uint y, ...)` +Allocates a new PyTorch tensor with the given dimensions. + +### `static TorchTensor<T> TorchTensor<T>.zerosLike(TorchTensor<T> other)` +Allocates a new PyTorch tensor that has the same dimensions as `other` and initialize it to zero. + +### `uint TorchTensor<T>.dims()` +Returns the tensor's dimension count. + +### `uint TorchTensor<T>.size(int dim)` +Returns the tensor's size (in number of elements) at `dim`. + +### `uint TorchTensor<T>.stride(int dim)` +Returns the tensor's stride (in bytes) at `dim`. + +### `TensorView<T>.operator[uint x, uint y, ...]` +Provide an accessor to data content in a tensor. + +### `TensorView<T>.operator[vector<uint, N> index]` +Provide an accessor to data content in a tensor, indexed by a uint vector. +`tensor[uint3(1,2,3)]` is equivalent to `tensor[1,2,3]`. + +### `uint TensorView<T>.dims()` +Returns the tensor's dimension count. + +### `uint TensorView<T>.size(int dim)` +Returns the tensor's size (in number of elements) at `dim`. + +### `uint TensorView<T>.stride(int dim)` +Returns the tensor's stride (in bytes) at `dim`. + +### `cudaThreadIdx()` +Returns the `threadIdx` variable in CUDA. + +### `cudaBlockIdx()` +Returns the `blockIdx` variable in CUDA. + +### `cudaBlockDim()` +Returns the `blockDim` variable in CUDA.
\ No newline at end of file diff --git a/docs/user-guide/a1-special-topics.md b/docs/user-guide/a1-special-topics.md index 3091621ea..33c863eec 100644 --- a/docs/user-guide/a1-special-topics.md +++ b/docs/user-guide/a1-special-topics.md @@ -8,4 +8,5 @@ Special Topics This chapter covers several additional topics on using Slang. These topics do not belong to any categories covered in previous chapters, but they address specific issues that developers may frequently encounter. In this chapter: -1. [Handling matrix layout differences on different platforms](a1-01-matrix-layout.md)
\ No newline at end of file +1. [Handling matrix layout differences on different platforms](a1-01-matrix-layout.md) +2. [Using Slang to write PyTorch kernels](a1-02-slangpy.md)
\ No newline at end of file diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html index 08f55ec0a..139d895fa 100644 --- a/docs/user-guide/toc.html +++ b/docs/user-guide/toc.html @@ -102,6 +102,13 @@ <li data-link="a1-01-matrix-layout#overriding-default-matrix-layout"><span>Overriding default matrix layout</span></li> </ul> </li> +<li data-link="a1-02-slangpy"><span>Using Slang to Write PyTorch Kernels</span> +<ul class="toc_list"> +<li data-link="a1-02-slangpy#writing-a-simple-kernel-function-as-a-slang-module"><span>Writing a simple kernel function as a Slang module</span></li> +<li data-link="a1-02-slangpy#calling-slang-module-from-python"><span>Calling Slang module from Python</span></li> +<li data-link="a1-02-slangpy#exposing-an-automatically-differentiated-kernel-to-pytorch"><span>Exposing an automatically differentiated kernel to PyTorch</span></li> +</ul> +</li> </ul> </li> </ul> diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 240dfb3e7..a1c49e108 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -6,6 +6,14 @@ #define SLANG_CUDA_RTC 0 #endif +#if SLANG_CUDA_RTC + +#else + +#include <cstdint> + +#endif + // Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support. // For this to work NVRTC needs to have the path to the CUDA SDK. // @@ -162,6 +170,8 @@ typedef int2 bool2; typedef int3 bool3; typedef int4 bool4; +#if SLANG_CUDA_RTC + typedef signed char int8_t; typedef short int16_t; typedef int int32_t; @@ -172,6 +182,8 @@ typedef unsigned short uint16_t; typedef unsigned int uint32_t; typedef unsigned long long uint64_t; +#endif + typedef long long longlong; typedef unsigned long long ulonglong; @@ -2099,27 +2111,27 @@ struct TensorView } template<typename T> - __device__ T load(uint32_t x) + __device__ T& load(uint32_t x) { return *reinterpret_cast<T*>(data + strides[0] * x); } template<typename T> - __device__ T load(uint32_t x, uint32_t y) + __device__ T& load(uint32_t x, uint32_t y) { return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y); } template<typename T> - __device__ T load(uint32_t x, uint32_t y, uint32_t z) + __device__ T& load(uint32_t x, uint32_t y, uint32_t z) { return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z); } template<typename T> - __device__ T load(uint32_t x, uint32_t y, uint32_t z, uint32_t w) + __device__ T& load(uint32_t x, uint32_t y, uint32_t z, uint32_t w) { return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w); } template<typename T> - __device__ T load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4) + __device__ T& load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4) { return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4); } diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index f2accc149..4844e9248 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -4,6 +4,8 @@ #include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAUtils.h> #include <vector> +#include <stdexcept> +#include <string> #ifndef SLANG_NO_THROW # define SLANG_NO_THROW @@ -54,7 +56,7 @@ struct CudaTaskMemoryAllocator uint32_t* allocUIntArray(uint32_t size) { void* ptr = nullptr; - cudaMallocManaged(&ptr, size * sizeof(uint32_t)); + cudaMallocHost(&ptr, size * sizeof(uint32_t)); AT_CUDA_CHECK(cudaGetLastError()); return (uint32_t*)ptr; } @@ -66,15 +68,18 @@ struct CudaTaskMemoryAllocator } }; -TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val) +TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val, const char* name) { - val = val.to(torch::kCUDA); + if (!val.device().is_cuda()) + val = val.to(torch::kCUDA); + TensorView res = {}; res.dimensionCount = val.dim(); res.strides = allocator->allocUIntArray(val.dim()); res.sizes = allocator->allocUIntArray(val.dim()); res.data = nullptr; size_t elementSize = 4; + switch (val.scalar_type()) { case torch::kInt8: @@ -107,11 +112,14 @@ TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor va res.data = (uint8_t*)val.data_ptr<int64_t>(); break; } + for (int i = 0; i < val.dim(); ++i) { res.strides[i] = val.stride(i) * elementSize; res.sizes[i] = val.size(i); } + if (!res.data) + throw std::runtime_error(std::string(name).append(": data pointer is invalid.").c_str()); return res; } diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index d5dc7842a..51cf1cdb7 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -76,6 +76,47 @@ struct TensorView __target_intrinsic(cuda, "$0.strides[$1]") [__readNone] uint stride(uint i); + + __subscript(uint index) -> T + { + [ForceInline] [__readNone] get { return load(index); } + [ForceInline] set { store(index, newValue); } + } + __subscript(uint i1, uint i2) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2); } + [ForceInline] set { store(i1, i2, newValue); } + } + __subscript(uint2 i) -> T + { + [ForceInline] [__readNone] get { return load(i.x, i.y); } + [ForceInline] set { store(i.x, i.y, newValue); } + } + __subscript(uint i1, uint i2, uint i3) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2, i3); } + [ForceInline] set { store(i1, i2, i3, newValue); } + } + __subscript(uint3 i) -> T + { + [ForceInline] [__readNone] get { return load(i.x, i.y, i.z); } + [ForceInline] set { store(i.x, i.y, i.z, newValue); } + } + __subscript(uint i1, uint i2, uint i3, uint i4) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2, i3, i4); } + [ForceInline] set { store(i1, i2, i3, i4, newValue); } + } + __subscript(uint4 i) -> T + { + [__readNone][ForceInline] get { return load(i.x, i.y, i.z, i.w); } + [ForceInline] set { store(i.x, i.y, i.z, i.w, newValue); } + } + __subscript(uint i1, uint i2, uint i3, uint i4, uint i5) -> T + { + [ForceInline] [__readNone] get { return load(i1, i2, i3, i4, i5); } + [ForceInline] set { store(i1, i2, i3, i4, i5, newValue); } + } } __generic<T> @@ -119,6 +160,9 @@ struct TorchTensor __intrinsic_op($(kIROp_AllocateTorchTensor)) static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> zerosLike(TorchTensor<T> other); } __generic<T: IDifferentiable> diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 2bc739142..08ba050db 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -945,6 +945,13 @@ String CLikeSourceEmitter::getName(IRInst* inst) return name; } +String CLikeSourceEmitter::getUnmangledName(IRInst* inst) +{ + if (auto nameHintDecor = inst->findDecoration<IRNameHintDecoration>()) + return nameHintDecor->getName(); + return getName(inst); +} + void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst) { switch(inst->getOp()) diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 9426531a8..8046fa633 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -324,6 +324,7 @@ public: virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor); String getName(IRInst* inst); + String getUnmangledName(IRInst* inst); void emitSimpleValue(IRInst* inst) { emitSimpleValueImpl(inst); } diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index 877c1dc03..4511039e3 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -24,6 +24,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitStringLiteral(getUnmangledName(inst->getOperand(1))); m_writer->emit(")"); return true; } @@ -67,72 +69,87 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& } case kIROp_AllocateTorchTensor: { - /* - Emit something like: - ``` - torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, - torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); - ``` - */ - m_writer->emit("torch::empty({ "); - for (UInt i = 0; i < inst->getOperandCount(); i++) + if (as<IRTorchTensorType>(inst->getOperand(0)->getDataType())) { - if (i > 0) - m_writer->emit(", "); - auto arg = inst->getOperand(i); - emitOperand(arg, getInfo(EmitOp::General)); + /* + Emit something like: + ``` + torch::Tensor out = torch::zeros_like(other); + ``` + */ + m_writer->emit("torch::zeros_like("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); } - m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); - - // Get the element type of the tensor. - auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); - - // If instType is a vector type, then we need to get the element type. - if (auto vectorType = as<IRVectorType>(instType)) + else { - instType = vectorType->getElementType(); + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); + + // Get the element type of the tensor. + auto instType = as<IRTorchTensorType>(inst->getDataType())->getOperand(0); + + // If instType is a vector type, then we need to get the element type. + if (auto vectorType = as<IRVectorType>(instType)) + { + instType = vectorType->getElementType(); + } + + switch (instType->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } + m_writer->emit("))"); } - - switch (instType->getOp()) - { - case kIROp_FloatType: - m_writer->emit("kFloat32"); - break; - case kIROp_HalfType: - m_writer->emit("kFloat16"); - break; - case kIROp_DoubleType: - m_writer->emit("kFloat64"); - break; - case kIROp_UInt8Type: - m_writer->emit("kUInt8"); - break; - case kIROp_UInt16Type: - m_writer->emit("kUInt16"); - break; - case kIROp_UIntType: - m_writer->emit("kUInt32"); - break; - case kIROp_UInt64Type: - m_writer->emit("kUInt64"); - break; - case kIROp_Int8Type: - m_writer->emit("kInt8"); - break; - case kIROp_Int16Type: - m_writer->emit("kInt16"); - break; - case kIROp_IntType: - m_writer->emit("kInt32"); - break; - case kIROp_Int64Type: - m_writer->emit("kInt64"); - break; - default: - SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); - break; - } - m_writer->emit("))"); return true; } } diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index eb81bfd8c..971e87a6f 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -245,6 +245,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) return; } auto newParam = builder.emitParam(newParamType); + param->transferDecorationsTo(newParam); newParams.add(newParam); } @@ -361,14 +362,16 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) // Remove all [TorchEntryPoint] functions when emitting CUDA source. void removeTorchKernels(IRModule* module) { + List<IRInst*> toRemove; for (auto globalInst : module->getGlobalInsts()) { if (!as<IRFunc>(globalInst)) continue; if (globalInst->findDecoration<IRTorchEntryPointDecoration>()) - globalInst->removeAndDeallocate(); + toRemove.add(globalInst); } - + for (auto inst : toRemove) + inst->removeAndDeallocate(); } } |
