From 2abd5bd5770b8052b8e4ceed86ff4ced883d2af7 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:03:21 -0400 Subject: Avoid classifying methods with `[numthreads]` as entry points for CUDA-related targets (#4063) --- source/slang/slang-check-shader.cpp | 2 +- source/slang/slang-compiler.cpp | 22 ++++++++++++++++++++-- source/slang/slang-compiler.h | 2 +- source/slang/slang.cpp | 2 +- 4 files changed, 23 insertions(+), 5 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index d4e8bda6f..2c1f8651c 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -918,7 +918,7 @@ namespace Slang for (Index tt = 0; tt < translationUnitCount; ++tt) { auto translationUnit = translationUnits[tt]; - translationUnit->getModule()->_discoverEntryPoints(sink); + translationUnit->getModule()->_discoverEntryPoints(sink, this->getLinkage()->targets); } } } diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 851e3115b..1f1bed902 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2431,7 +2431,7 @@ namespace Slang EntryPoint* entryPoint, DiagnosticSink* sink); - void Module::_discoverEntryPoints(DiagnosticSink* sink) + void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets) { for (auto globalDecl : m_moduleDecl->members) { @@ -2466,12 +2466,30 @@ namespace Slang else { // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. 
+ // + + bool allTargetsCUDARelated = true; + for (auto target : targets) + { + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } + } + + if (allTargetsCUDARelated && targets.getCount() > 0) + continue; + auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>(); if (numThreadsAttr) profile.setStage(Stage::Compute); else continue; + } RefPtr<EntryPoint> entryPoint = EntryPoint::create( diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index f9fa24622..cd4e7fdd5 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1483,7 +1483,7 @@ namespace Slang /// void _collectShaderParams(); - void _discoverEntryPoints(DiagnosticSink* sink); + void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets); class ModuleSpecializationInfo : public SpecializationInfo { diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 4a446a351..654295211 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -5335,7 +5335,7 @@ void Linkage::prepareDeserializedModule(SerialContainerData::Module& moduleEntry module->setPathInfo(filePathInfo); module->setDigest(moduleEntry.digest); module->_collectShaderParams(); - module->_discoverEntryPoints(sink); + module->_discoverEntryPoints(sink, targets); // Hook up fileDecl's scope to module's scope. auto moduleDecl = module->getModuleDecl(); -- cgit v1.2.3