diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-04-30 21:03:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-30 18:03:21 -0700 |
| commit | 2abd5bd5770b8052b8e4ceed86ff4ced883d2af7 (patch) | |
| tree | 88013a7a8df90a063286986aa0551c82e4fff19f /source/slang/slang-compiler.cpp | |
| parent | 52b91231cdadc048f93b224f5035759cf1a96eaa (diff) | |
Avoid classifying methods with `[numthreads]` as entry points for CUDA-related targets (#4063)
Diffstat (limited to 'source/slang/slang-compiler.cpp')
| -rw-r--r-- | source/slang/slang-compiler.cpp | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 851e3115b..1f1bed902 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2431,7 +2431,7 @@ namespace Slang EntryPoint* entryPoint, DiagnosticSink* sink); - void Module::_discoverEntryPoints(DiagnosticSink* sink) + void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets) { for (auto globalDecl : m_moduleDecl->members) { @@ -2466,12 +2466,30 @@ namespace Slang else { // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. + // + + bool allTargetsCUDARelated = true; + for (auto target : targets) + { + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } + } + + if (allTargetsCUDARelated && targets.getCount() > 0) + continue; + auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>(); if (numThreadsAttr) profile.setStage(Stage::Compute); else continue; + } RefPtr<EntryPoint> entryPoint = EntryPoint::create( |
