summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-04-30 21:03:21 -0400
committerGitHub <noreply@github.com>2024-04-30 18:03:21 -0700
commit2abd5bd5770b8052b8e4ceed86ff4ced883d2af7 (patch)
tree88013a7a8df90a063286986aa0551c82e4fff19f /source
parent52b91231cdadc048f93b224f5035759cf1a96eaa (diff)
Avoid classifying methods with `[numthreads]` as entry points for CUDA-related targets (#4063)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-shader.cpp2
-rw-r--r--source/slang/slang-compiler.cpp22
-rwxr-xr-xsource/slang/slang-compiler.h2
-rw-r--r--source/slang/slang.cpp2
4 files changed, 23 insertions, 5 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index d4e8bda6f..2c1f8651c 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -918,7 +918,7 @@ namespace Slang
for (Index tt = 0; tt < translationUnitCount; ++tt)
{
auto translationUnit = translationUnits[tt];
- translationUnit->getModule()->_discoverEntryPoints(sink);
+ translationUnit->getModule()->_discoverEntryPoints(sink, this->getLinkage()->targets);
}
}
}
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 851e3115b..1f1bed902 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2431,7 +2431,7 @@ namespace Slang
EntryPoint* entryPoint,
DiagnosticSink* sink);
- void Module::_discoverEntryPoints(DiagnosticSink* sink)
+ void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets)
{
for (auto globalDecl : m_moduleDecl->members)
{
@@ -2466,12 +2466,30 @@ namespace Slang
else
{
// If there isn't a [shader] attribute, look for a [numthreads] attribute
- // since that implicitly means a compute shader.
+ // since that implicitly means a compute shader. We'll not do this when compiling for
+ // CUDA/Torch since [numthreads] attributes are utilized differently for those targets.
+ //
+
+ bool allTargetsCUDARelated = true;
+ for (auto target : targets)
+ {
+ if (!isCUDATarget(target) &&
+ target->getTarget() != CodeGenTarget::PyTorchCppBinding)
+ {
+ allTargetsCUDARelated = false;
+ break;
+ }
+ }
+
+ if (allTargetsCUDARelated && targets.getCount() > 0)
+ continue;
+
auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>();
if (numThreadsAttr)
profile.setStage(Stage::Compute);
else
continue;
+
}
RefPtr<EntryPoint> entryPoint = EntryPoint::create(
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index f9fa24622..cd4e7fdd5 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1483,7 +1483,7 @@ namespace Slang
///
void _collectShaderParams();
- void _discoverEntryPoints(DiagnosticSink* sink);
+ void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets);
class ModuleSpecializationInfo : public SpecializationInfo
{
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 4a446a351..654295211 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -5335,7 +5335,7 @@ void Linkage::prepareDeserializedModule(SerialContainerData::Module& moduleEntry
module->setPathInfo(filePathInfo);
module->setDigest(moduleEntry.digest);
module->_collectShaderParams();
- module->_discoverEntryPoints(sink);
+ module->_discoverEntryPoints(sink, targets);
// Hook up fileDecl's scope to module's scope.
auto moduleDecl = module->getModuleDecl();