From 6e24244832d9032f0993cb088af625238096b723 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 17 Dec 2024 21:57:58 -0800 Subject: Support specializing generic entrypoints in `findAndCheckEntryPoint`. (#5890) --- source/slang/slang-check-shader.cpp | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 50382f9c1..c9e190469 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -229,29 +229,25 @@ bool isPrimaryDecl(CallableDecl* decl) return (!decl->primaryDecl) || (decl == decl->primaryDecl); } -FuncDecl* findFunctionDeclByName(Module* translationUnit, Name* name, DiagnosticSink* sink) +DeclRef findFunctionDeclByName(Module* translationUnit, Name* name, DiagnosticSink* sink) { - FuncDecl* entryPointFuncDecl = nullptr; + DeclRef entryPointFuncDeclRef; auto expr = translationUnit->findDeclFromString(getText(name), sink); if (auto declRefExpr = as(expr)) { - auto declRef = declRefExpr->declRef; - entryPointFuncDecl = declRef.as().getDecl(); + entryPointFuncDeclRef = declRefExpr->declRef.as(); - if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) - entryPointFuncDecl = nullptr; + if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit) + entryPointFuncDeclRef = DeclRef(); } - if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) - entryPointFuncDecl = nullptr; - - if (!entryPointFuncDecl) + if (!entryPointFuncDeclRef) { auto translationUnitSyntax = translationUnit->getModuleDecl(); sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name); } - return entryPointFuncDecl; + return entryPointFuncDeclRef; } // Is a entry pointer parmaeter of `type` always a uniform parameter? @@ -409,8 +405,8 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) // make up the entry point. // Name* name = linkage->getNamePool()->getName(stringLit->value); - FuncDecl* patchConstantFuncDecl = findFunctionDeclByName(module, name, sink); - if (!patchConstantFuncDecl) + DeclRef patchConstantFuncDeclRef = findFunctionDeclByName(module, name, sink); + if (!patchConstantFuncDeclRef) { sink->diagnose( expr, @@ -420,7 +416,7 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) return; } - attr->patchConstantFuncDecl = patchConstantFuncDecl; + attr->patchConstantFuncDecl = patchConstantFuncDeclRef.getDecl(); } } else if (stage == Stage::Compute) @@ -648,11 +644,11 @@ RefPtr findAndValidateEntryPoint(FrontEndEntryPointRequest* entryPoi auto sink = compileRequest->getSink(); auto entryPointName = entryPointReq->getName(); - FuncDecl* entryPointFuncDecl = + DeclRef entryPointFuncDeclRef = findFunctionDeclByName(translationUnit->getModule(), entryPointName, sink); // Did we find a function declaration in our search? - if (!entryPointFuncDecl) + if (!entryPointFuncDeclRef) { return nullptr; } @@ -673,14 +669,14 @@ RefPtr findAndValidateEntryPoint(FrontEndEntryPointRequest* entryPoi entryPointProfile, linkage->m_optionSet, linkage->targets, - entryPointFuncDecl, + entryPointFuncDeclRef.getDecl(), sink); // TODO: Should we attach a `[shader(...)]` attribute to an // entry point that didn't have one, so that we can have // a more uniform representation in the AST? RefPtr entryPoint = - EntryPoint::create(linkage, makeDeclRef(entryPointFuncDecl), entryPointProfile); + EntryPoint::create(linkage, entryPointFuncDeclRef, entryPointProfile); // Now that we've *found* the entry point, it is time to validate // that it actually meets the constraints for the chosen stage/profile. -- cgit v1.2.3