diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/cuda-kernel-export.slang | 38 |
1 files changed, 9 insertions, 29 deletions
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang index 54442498b..2700fb054 100644 --- a/tests/autodiff/cuda-kernel-export.slang +++ b/tests/autodiff/cuda-kernel-export.slang @@ -3,39 +3,19 @@ // Verify that we can output a cuda device function with [CudaDeviceExport]. // Disabled until we have FileCheck. -struct MixedType : IDifferentiable -{ - no_diff float noDiffField; - float field; -} - -[BackwardDifferentiable] -float f1(MixedType m) -{ - return 2.0 * m.field; -} - -[BackwardDifferentiable] -float f(MixedType m) -{ - MixedType m1 = { m.noDiffField, m.field }; - return f1(m1); -} - -[CudaDeviceExport] -void diffF(inout DifferentialPair<MixedType> m, float dout) -{ - __bwd_diff(f)(m, dout); -} [CudaKernel] -void myKernel(float* inValues, float* outValues) +void myKernel(TensorView<float> inValues, TensorView<float> outValues) { - outValues[0] = sin(inValues[0]); + if (cudaThreadIdx().x > 0) + return; + outValues.store(cudaThreadIdx().x, sin(inValues.load(cudaThreadIdx().x))); } -[CudaHost] -public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispathcSize) +[TorchEntryPoint] +public __extern_cpp TorchTensor<float> runCompute(TorchTensor<float> inValues) { - __dispatch_kernel(myKernel, uint3(128, 1, 1), dispathcSize)(inValues, outValues); + var outValues = TorchTensor<float>.alloc(1); + __dispatch_kernel(myKernel, uint3(1, 1, 1), uint3(32, 1, 1))(inValues, outValues); + return outValues; }
\ No newline at end of file |
