summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-09-05 11:53:14 -0700
committerGitHub <noreply@github.com>2024-09-05 11:53:14 -0700
commit879ee3d187e577189eba9aed7bc6326b740cb627 (patch)
tree2317bf727e7958efacea24a3bcf6534a44c1827f /source
parenta3b25ceb4021811d481c9c4a07a8d029329f01f3 (diff)
Support entrypoints defined in a namespace. (#5011)
* Support entrypoints defined in a namespace. * Fix test.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-shader.cpp72
-rw-r--r--source/slang/slang-compiler.cpp18
-rwxr-xr-xsource/slang/slang-compiler.h2
-rw-r--r--source/slang/slang.cpp9
4 files changed, 30 insertions, 71 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 3a1e4c7f6..1718c3afd 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -235,75 +235,17 @@ namespace Slang
Name* name,
DiagnosticSink* sink)
{
- auto translationUnitSyntax = translationUnit->getModuleDecl();
- FuncDecl* entryPointFuncDecl = nullptr;
+ auto declRef = translationUnit->findDeclFromString(getText(name), sink);
+ FuncDecl* entryPointFuncDecl = declRef.as<FuncDecl>().getDecl();
- for (auto globalScope = translationUnit->getModuleDecl()->ownedScope; globalScope; globalScope = globalScope->nextSibling)
- {
- if (globalScope->containerDecl != translationUnitSyntax && globalScope->containerDecl->parentDecl != translationUnitSyntax)
- continue; // Skip scopes that aren't part of the current module.
-
- // We will look up any global-scope declarations in the translation
- // unit that match the name of our entry point.
- Decl* firstDeclWithName = nullptr;
- if (!globalScope->containerDecl->getMemberDictionary().tryGetValue(name, firstDeclWithName))
- {
- // If there doesn't appear to be any such declaration, then we are done with this scope.
- continue;
- }
-
- // We found at least one global-scope declaration with the right name,
- // but (1) it might not be a function, and (2) there might be
- // more than one function.
- //
- // We'll walk the linked list of declarations with the same name,
- // to see what we find. Along the way we'll keep track of the
- // first function declaration we find, if any:
- for (auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName)
- {
- // Is this declaration a function?
- if (auto funcDecl = as<FuncDecl>(ee))
- {
- // Skip non-primary declarations, so that
- // we don't give an error when an entry
- // point is forward-declared.
- if (!isPrimaryDecl(funcDecl))
- continue;
-
- // is this the first one we've seen?
- if (!entryPointFuncDecl)
- {
- // If so, this is a candidate to be
- // the entry point function.
- entryPointFuncDecl = funcDecl;
- }
- else
- {
- // Uh-oh! We've already seen a function declaration with this
- // name before, so the whole thing is ambiguous. We need
- // to diagnose and bail out.
-
- sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, name);
-
- // List all of the declarations that the user *might* mean
- for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName)
- {
- if (auto candidate = as<FuncDecl>(ff))
- {
- sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName());
- }
- }
-
- // Bail out.
- return nullptr;
- }
- }
- }
- }
+ if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit)
+ entryPointFuncDecl = nullptr;
if (!entryPointFuncDecl)
+ {
+ auto translationUnitSyntax = translationUnit->getModuleDecl();
sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name);
-
+ }
return entryPointFuncDecl;
}
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index a5a09204b..7d2dbed0d 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2535,8 +2535,11 @@ namespace Slang
{
if (m_entryPoints.getCount() > 0)
return;
-
- for (auto globalDecl : m_moduleDecl->members)
+ _discoverEntryPointsImpl(m_moduleDecl, sink, targets);
+ }
+ void Module::_discoverEntryPointsImpl(ContainerDecl* containerDecl, DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets)
+ {
+ for (auto globalDecl : containerDecl->members)
{
auto maybeFuncDecl = globalDecl;
if (auto genericDecl = as<GenericDecl>(maybeFuncDecl))
@@ -2544,13 +2547,19 @@ namespace Slang
maybeFuncDecl = genericDecl->inner;
}
+ if (as<NamespaceDeclBase>(globalDecl) || as<FileDecl>(globalDecl) || as<StructDecl>(globalDecl))
+ {
+ _discoverEntryPointsImpl(as<ContainerDecl>(globalDecl), sink, targets);
+ continue;
+ }
+
auto funcDecl = as<FuncDecl>(maybeFuncDecl);
if (!funcDecl)
continue;
Profile profile;
bool resolvedStageOfProfileWithEntryPoint = resolveStageOfProfileWithEntryPoint(profile, getLinkage()->m_optionSet, targets, funcDecl, sink);
- if(!resolvedStageOfProfileWithEntryPoint)
+ if (!resolvedStageOfProfileWithEntryPoint)
{
// If there isn't a [shader] attribute, look for a [numthreads] attribute
// since that implicitly means a compute shader. We'll not do this when compiling for
@@ -2560,7 +2569,7 @@ namespace Slang
bool allTargetsCUDARelated = true;
for (auto target : targets)
{
- if (!isCUDATarget(target) &&
+ if (!isCUDATarget(target) &&
target->getTarget() != CodeGenTarget::PyTorchCppBinding)
{
allTargetsCUDARelated = false;
@@ -2614,6 +2623,5 @@ namespace Slang
_addEntryPoint(entryPoint);
}
}
-
}
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 54f43b382..267fdeaf5 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1568,6 +1568,8 @@ namespace Slang
void _collectShaderParams();
void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets);
+ void _discoverEntryPointsImpl(ContainerDecl* containerDecl, DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets);
+
class ModuleSpecializationInfo : public SpecializationInfo
{
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 4a6d33363..b0898b1a0 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -2328,7 +2328,14 @@ DeclRef<Decl> ComponentType::findDeclFromString(
{
result = declRefExpr->declRef;
}
-
+ else if (auto overloadedExpr = as<OverloadedExpr>(checkedExpr))
+ {
+ sink->diagnose(SourceLoc(), Diagnostics::ambiguousReference, name);
+ for (auto candidate : overloadedExpr->lookupResult2)
+ {
+ sink->diagnose(candidate.declRef.getDecl(), Diagnostics::overloadCandidate, candidate.declRef);
+ }
+ }
m_decls[name] = result;
return result;
}