diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/cuda-kernel-export.slang | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang index 2700fb054..e16188abc 100644 --- a/tests/autodiff/cuda-kernel-export.slang +++ b/tests/autodiff/cuda-kernel-export.slang @@ -3,6 +3,22 @@ // Verify that we can output a cuda device function with [CudaDeviceExport]. // Disabled until we have FileCheck. +struct MySubType +{ + TorchTensor<float> array[2]; +} + +struct MyType +{ + float2 v; + MySubType sub[2]; +} + +struct MyInput +{ + TorchTensor<float> inValues; + float normalVal; +} [CudaKernel] void myKernel(TensorView<float> inValues, TensorView<float> outValues) @@ -13,9 +29,19 @@ void myKernel(TensorView<float> inValues, TensorView<float> outValues) } [TorchEntryPoint] -public __extern_cpp TorchTensor<float> runCompute(TorchTensor<float> inValues) +public __extern_cpp MyType runCompute(MyInput input) { + MyType rs; var outValues = TorchTensor<float>.alloc(1); + let inValues = input.inValues; + __dispatch_kernel(myKernel, uint3(1, 1, 1), uint3(32, 1, 1))(inValues, outValues); - return outValues; + + rs.v = float2(1.0, 2.0); + rs.sub[0].array[0] = outValues; + rs.sub[0].array[1] = inValues; + + rs.sub[1].array[0] = inValues; + rs.sub[1].array[1] = outValues; + return rs; }
\ No newline at end of file |
