From 52b91231cdadc048f93b224f5035759cf1a96eaa Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:05:33 -0400 Subject: Added diagnostics & built-in type lowering for `[CUDAKernel]` functions (#4042) * Added diagnostics & built-in type lowering for `[CUDAKernel]` functions This PR adds - Diagnostics for non-void return from a cuda kernel entry point - Diagnostics for using differentiable types in a differentiable cuda kernel entry point - Logic for converting built-in types (float3, float3x3, etc..) to portable struct types and unpacks the parameter back into a built-in type on the CUDA side. This is because built-in types have different implementations in CUDA & CPP targets, which causes signature mis-match when linking. * Fix error codes * Add ability to lower structs and arrays that contain built-in types. + Added tests + Fix issue where the host-side was not marshalling data to lowered types. * Update slang-ir-pytorch-cpp-binding.cpp --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0f9da12c4..921bd38e9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7700,6 +7700,16 @@ namespace Slang } } } + + // If this method is intended to be a CUDA kernel, verify that the return type is void. + if (decl->findModifier()) + { + if (decl->returnType.type && !decl->returnType.type->equals(m_astBuilder->getVoidType())) + { + getSink()->diagnose(decl, Diagnostics::cudaKernelMustReturnVoid); + } + } + checkVisibility(decl); } @@ -9547,6 +9557,30 @@ namespace Slang checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } + static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*) + { + // If the method is also marked differentiable, check that the data types are either non-differentiable + // or marked with no_diff. + // + // Note: This is a temporary restriction until we have a more complete story for differentiability. + // + if (funcDecl->findModifier()) + { + for (auto paramDecl : funcDecl->getParameters()) + { + auto paramType = paramDecl->type; + + if (visitor->isTypeDifferentiable(paramType)) + { + if (!paramDecl->hasModifier()) + { + visitor->getSink()->diagnose(paramDecl, Diagnostics::differentiableKernelEntryPointCannotHaveDifferentiableParams); + } + } + } + } + } + template bool tryCheckDerivativeOfAttributeImpl( SemanticsVisitor* visitor, @@ -9747,6 +9781,8 @@ namespace Slang checkDerivativeAttribute(this, decl, bwdDerivativeAttr); else if (auto primalAttr = as(attr)) checkDerivativeAttribute(this, decl, primalAttr); + else if (auto cudaKernelAttr = as(attr)) + checkCudaKernelAttribute(this, decl, cudaKernelAttr); } } -- cgit v1.2.3