diff options
Diffstat (limited to 'source/slang/slang-compiler.cpp')
| -rw-r--r-- | source/slang/slang-compiler.cpp | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 851e3115b..1f1bed902 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2431,7 +2431,7 @@ namespace Slang EntryPoint* entryPoint, DiagnosticSink* sink); - void Module::_discoverEntryPoints(DiagnosticSink* sink) + void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets) { for (auto globalDecl : m_moduleDecl->members) { @@ -2466,12 +2466,30 @@ namespace Slang else { // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. + // + + bool allTargetsCUDARelated = true; + for (auto target : targets) + { + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } + } + + if (allTargetsCUDARelated && targets.getCount() > 0) + continue; + auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>(); if (numThreadsAttr) profile.setStage(Stage::Compute); else continue; + } RefPtr<EntryPoint> entryPoint = EntryPoint::create( |
