summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-compiler.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-04-30 21:03:21 -0400
committerGitHub <noreply@github.com>2024-04-30 18:03:21 -0700
commit2abd5bd5770b8052b8e4ceed86ff4ced883d2af7 (patch)
tree88013a7a8df90a063286986aa0551c82e4fff19f /source/slang/slang-compiler.cpp
parent52b91231cdadc048f93b224f5035759cf1a96eaa (diff)
Avoid classifying methods with `[numthreads]` as entry points for CUDA-related targets (#4063)
Diffstat (limited to 'source/slang/slang-compiler.cpp')
-rw-r--r--source/slang/slang-compiler.cpp22
1 files changed, 20 insertions, 2 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 851e3115b..1f1bed902 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2431,7 +2431,7 @@ namespace Slang
EntryPoint* entryPoint,
DiagnosticSink* sink);
- void Module::_discoverEntryPoints(DiagnosticSink* sink)
+ void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets)
{
for (auto globalDecl : m_moduleDecl->members)
{
@@ -2466,12 +2466,30 @@ namespace Slang
else
{
// If there isn't a [shader] attribute, look for a [numthreads] attribute
- // since that implicitly means a compute shader.
+ // since that implicitly means a compute shader. We'll not do this when compiling for
+ // CUDA/Torch since [numthreads] attributes are utilized differently for those targets.
+ //
+
+ bool allTargetsCUDARelated = true;
+ for (auto target : targets)
+ {
+ if (!isCUDATarget(target) &&
+ target->getTarget() != CodeGenTarget::PyTorchCppBinding)
+ {
+ allTargetsCUDARelated = false;
+ break;
+ }
+ }
+
+ if (allTargetsCUDARelated && targets.getCount() > 0)
+ continue;
+
auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>();
if (numThreadsAttr)
profile.setStage(Stage::Compute);
else
continue;
+
}
RefPtr<EntryPoint> entryPoint = EntryPoint::create(