From 2abd5bd5770b8052b8e4ceed86ff4ced883d2af7 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:03:21 -0400 Subject: Avoid classifying methods with `[numthreads]` as entry points for CUDA-related targets (#4063) --- source/slang/slang-compiler.cpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-compiler.cpp') diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 851e3115b..1f1bed902 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2431,7 +2431,7 @@ namespace Slang EntryPoint* entryPoint, DiagnosticSink* sink); - void Module::_discoverEntryPoints(DiagnosticSink* sink) + void Module::_discoverEntryPoints(DiagnosticSink* sink, const List>& targets) { for (auto globalDecl : m_moduleDecl->members) { @@ -2466,12 +2466,30 @@ namespace Slang else { // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. + // + + bool allTargetsCUDARelated = true; + for (auto target : targets) + { + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } + } + + if (allTargetsCUDARelated && targets.getCount() > 0) + continue; + auto numThreadsAttr = funcDecl->findModifier(); if (numThreadsAttr) profile.setStage(Stage::Compute); else continue; + } RefPtr entryPoint = EntryPoint::create( -- cgit v1.2.3