From 6c991942ac4ec2e2abf6abe73a2429183172af84 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 25 Sep 2023 18:30:34 -0400 Subject: Add test for vector-element contiguity error (#3235) --- source/slang/slang-emit-torch.cpp | 3 ++- .../autodiff/autopybind-vector-element-type.slang | 29 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 tests/autodiff/autopybind-vector-element-type.slang diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp index 54408aa80..bacc9d030 100644 --- a/source/slang/slang-emit-torch.cpp +++ b/source/slang/slang-emit-torch.cpp @@ -120,7 +120,8 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType()); m_writer->emit(", "); - if (as(inst->getOperand(0)->getDataType())) + auto tensorViewType = as(inst->getDataType()); + if (as(tensorViewType->getElementType())) m_writer->emit("true"); else m_writer->emit("false"); diff --git a/tests/autodiff/autopybind-vector-element-type.slang b/tests/autodiff/autopybind-vector-element-type.slang new file mode 100644 index 000000000..78c503466 --- /dev/null +++ b/tests/autodiff/autopybind-vector-element-type.slang @@ -0,0 +1,29 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +// CUDA: __global__ void __kernel__myKernel(TensorView inValues_[[#]], TensorView outValues_[[#]]) +[AutoPyBindCUDA] +[CudaKernel] +void myKernel(TensorView inValues, TensorView outValues) +{ + if (cudaThreadIdx().x > 0) + return; + outValues.store(cudaThreadIdx().x, sin(inValues.load(cudaThreadIdx().x))); +} + +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel(TensorView {{[[:alnum:]_]+}}, TensorView {{[[:alnum:]_]+}}); + +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel(std::tuple {{[[:alnum:]_]+}}, std::tuple {{[[:alnum:]_]+}}, torch::Tensor {{[[:alnum:]_]+}}, torch::Tensor {{[[:alnum:]_]+}}) + +// TORCH: TensorView {{[[:alnum:]_]+}} = make_tensor_view({{[[:alnum:]_]+}}, "outValues", torch::kFloat32, true); + +// TORCH: TensorView {{[[:alnum:]_]+}} = make_tensor_view({{[[:alnum:]_]+}}, "inValues", torch::kFloat32, false); + +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple, std::tuple, const char*, const char*> __funcinfo__myKernel() + +// TORCH: m.def("myKernel", &myKernel, "myKernel"); + +// TORCH: m.def("__funcinfo__myKernel", &__funcinfo__myKernel, "__funcinfo__myKernel"); \ No newline at end of file -- cgit v1.2.3