summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-26 13:59:11 -0700
committerGitHub <noreply@github.com>2023-03-26 13:59:11 -0700
commitd64ee86a3130f8eeb75d09193c38c621d7565eba (patch)
treefed25a0cc2a7372d26175774f5983bed693e6b64 /tests
parent666af0962b6ab41489a3a3287db83f77c2f6461a (diff)
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/cuda-kernel-export.slang38
1 files changed, 9 insertions, 29 deletions
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang
index 54442498b..2700fb054 100644
--- a/tests/autodiff/cuda-kernel-export.slang
+++ b/tests/autodiff/cuda-kernel-export.slang
@@ -3,39 +3,19 @@
// Verify that we can output a cuda device function with [CudaDeviceExport].
// Disabled until we have FileCheck.
-struct MixedType : IDifferentiable
-{
- no_diff float noDiffField;
- float field;
-}
-
-[BackwardDifferentiable]
-float f1(MixedType m)
-{
- return 2.0 * m.field;
-}
-
-[BackwardDifferentiable]
-float f(MixedType m)
-{
- MixedType m1 = { m.noDiffField, m.field };
- return f1(m1);
-}
-
-[CudaDeviceExport]
-void diffF(inout DifferentialPair<MixedType> m, float dout)
-{
- __bwd_diff(f)(m, dout);
-}
[CudaKernel]
-void myKernel(float* inValues, float* outValues)
+void myKernel(TensorView<float> inValues, TensorView<float> outValues)
{
- outValues[0] = sin(inValues[0]);
+ if (cudaThreadIdx().x > 0)
+ return;
+ outValues.store(cudaThreadIdx().x, sin(inValues.load(cudaThreadIdx().x)));
}
-[CudaHost]
-public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispathcSize)
+[TorchEntryPoint]
+public __extern_cpp TorchTensor<float> runCompute(TorchTensor<float> inValues)
{
- __dispatch_kernel(myKernel, uint3(128, 1, 1), dispathcSize)(inValues, outValues);
+ var outValues = TorchTensor<float>.alloc(1);
+ __dispatch_kernel(myKernel, uint3(1, 1, 1), uint3(32, 1, 1))(inValues, outValues);
+ return outValues;
} \ No newline at end of file