diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-04-30 21:03:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-30 18:03:21 -0700 |
| commit | 2abd5bd5770b8052b8e4ceed86ff4ced883d2af7 (patch) | |
| tree | 88013a7a8df90a063286986aa0551c82e4fff19f | |
| parent | 52b91231cdadc048f93b224f5035759cf1a96eaa (diff) | |
Avoid classifying methods with `[numthreads]` as entry points for CUDA-related targets (#4063)
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-compiler.cpp | 22 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 2 |
4 files changed, 23 insertions, 5 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index d4e8bda6f..2c1f8651c 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -918,7 +918,7 @@ namespace Slang for (Index tt = 0; tt < translationUnitCount; ++tt) { auto translationUnit = translationUnits[tt]; - translationUnit->getModule()->_discoverEntryPoints(sink); + translationUnit->getModule()->_discoverEntryPoints(sink, this->getLinkage()->targets); } } } diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 851e3115b..1f1bed902 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2431,7 +2431,7 @@ namespace Slang EntryPoint* entryPoint, DiagnosticSink* sink); - void Module::_discoverEntryPoints(DiagnosticSink* sink) + void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets) { for (auto globalDecl : m_moduleDecl->members) { @@ -2466,12 +2466,30 @@ namespace Slang else { // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. + // + + bool allTargetsCUDARelated = true; + for (auto target : targets) + { + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } + } + + if (allTargetsCUDARelated && targets.getCount() > 0) + continue; + auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>(); if (numThreadsAttr) profile.setStage(Stage::Compute); else continue; + } RefPtr<EntryPoint> entryPoint = EntryPoint::create( diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index f9fa24622..cd4e7fdd5 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1483,7 +1483,7 @@ namespace Slang /// void _collectShaderParams(); - void _discoverEntryPoints(DiagnosticSink* sink); + void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets); class ModuleSpecializationInfo : public SpecializationInfo { diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 4a446a351..654295211 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -5335,7 +5335,7 @@ void Linkage::prepareDeserializedModule(SerialContainerData::Module& moduleEntry module->setPathInfo(filePathInfo); module->setDigest(moduleEntry.digest); module->_collectShaderParams(); - module->_discoverEntryPoints(sink); + module->_discoverEntryPoints(sink, targets); // Hook up fileDecl's scope to module's scope. auto moduleDecl = module->getModuleDecl(); |
