diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-26 13:59:11 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-26 13:59:11 -0700 |
| commit | d64ee86a3130f8eeb75d09193c38c621d7565eba (patch) | |
| tree | fed25a0cc2a7372d26175774f5983bed693e6b64 /tests | |
| parent | 666af0962b6ab41489a3a3287db83f77c2f6461a (diff) | |
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation.
* fix
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/cuda-kernel-export.slang | 38 |
1 files changed, 9 insertions, 29 deletions
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang index 54442498b..2700fb054 100644 --- a/tests/autodiff/cuda-kernel-export.slang +++ b/tests/autodiff/cuda-kernel-export.slang @@ -3,39 +3,19 @@ // Verify that we can output a cuda device function with [CudaDeviceExport]. // Disabled until we have FileCheck. -struct MixedType : IDifferentiable -{ - no_diff float noDiffField; - float field; -} - -[BackwardDifferentiable] -float f1(MixedType m) -{ - return 2.0 * m.field; -} - -[BackwardDifferentiable] -float f(MixedType m) -{ - MixedType m1 = { m.noDiffField, m.field }; - return f1(m1); -} - -[CudaDeviceExport] -void diffF(inout DifferentialPair<MixedType> m, float dout) -{ - __bwd_diff(f)(m, dout); -} [CudaKernel] -void myKernel(float* inValues, float* outValues) +void myKernel(TensorView<float> inValues, TensorView<float> outValues) { - outValues[0] = sin(inValues[0]); + if (cudaThreadIdx().x > 0) + return; + outValues.store(cudaThreadIdx().x, sin(inValues.load(cudaThreadIdx().x))); } -[CudaHost] -public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispathcSize) +[TorchEntryPoint] +public __extern_cpp TorchTensor<float> runCompute(TorchTensor<float> inValues) { - __dispatch_kernel(myKernel, uint3(128, 1, 1), dispathcSize)(inValues, outValues); + var outValues = TorchTensor<float>.alloc(1); + __dispatch_kernel(myKernel, uint3(1, 1, 1), uint3(32, 1, 1))(inValues, outValues); + return outValues; }
\ No newline at end of file |
