From 6a465a4db65b924b03930261da3b64b1c792ef85 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 9 Apr 2024 12:47:03 -0700 Subject: Allow COM based API to discover and check entrypoints without [shader] attribute. (#3914) * Allow COM based API to discover and check entrypoints without [shader] attribute. * Undo changes. * More comments. --- source/slang/slang-check-impl.h | 3 +++ source/slang/slang-compiler.h | 15 +++++++++++++++ source/slang/slang.cpp | 42 ++++++++++++++++++++++++++++++++++------- 3 files changed, 53 insertions(+), 7 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 55edba6b9..e6e980fe8 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2773,4 +2773,7 @@ namespace Slang SemanticsDeclVisitorBase* visitor, Decl* decl, DeclCheckState state); + + RefPtr findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq); } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 14d4054c4..014b678f5 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1338,6 +1338,20 @@ namespace Slang return SLANG_OK; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) + { + ComPtr entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); + if ((!entryPoint)) + return SLANG_FAIL; + + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } + virtual SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override { return (SlangInt32)m_entryPoints.getCount(); @@ -1481,6 +1495,7 @@ namespace Slang }; RefPtr findEntryPointByName(UnownedStringSlice const& name); + RefPtr findAndCheckEntryPoint(UnownedStringSlice const& name, SlangStage stage, ISlangBlob** outDiagnostics); List>& getEntryPoints() { return m_entryPoints; } void _addEntryPoint(EntryPoint* entryPoint); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 6d40b46a2..0db833c7a 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3987,13 +3987,6 @@ void Module::setName(String name) RefPtr Module::findEntryPointByName(UnownedStringSlice const& name) { - // TODO: We should consider having this function be expanded to be able - // to look up and validate possible entry-point functions in teh module - // even if they were not marked with `[shader(...)]` in the source code. - // - // With such a change the function would probably need to accept a stage - // to use and a sink to write validation errors to. - for(auto entryPoint : m_entryPoints) { if(entryPoint->getName()->text.getUnownedSlice() == name) @@ -4003,6 +3996,41 @@ RefPtr Module::findEntryPointByName(UnownedStringSlice const& name) return nullptr; } + +RefPtr Module::findAndCheckEntryPoint( + UnownedStringSlice const& name, + SlangStage stage, + ISlangBlob** outDiagnostics) +{ + // If there is already an entrypoint marked with the [shader] attribute, + // we should just return that. + // + if (auto existingEntryPoint = findEntryPointByName(name)) + return existingEntryPoint; + + // If the function hasn't been marked as [shader], then it won't be discovered + // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` + // function. To do that we need to setup a FrontEndCompileRequest and a FrontEndEntryPointRequest. + // + DiagnosticSink sink(getLinkage()->getSourceManager(), DiagnosticSink::SourceLocationLexer()); + FrontEndCompileRequest frontEndRequest(getLinkage(), StdWriters::getSingleton(), &sink); + RefPtr tuRequest = new TranslationUnitRequest(&frontEndRequest); + tuRequest->module = this; + tuRequest->moduleName = m_name; + frontEndRequest.translationUnits.add(tuRequest); + FrontEndEntryPointRequest entryPointRequest( + &frontEndRequest, + 0, + getLinkage()->getNamePool()->getName(name), + Profile((Stage)stage)); + auto result = findAndValidateEntryPoint(&entryPointRequest); + if (outDiagnostics) + { + sink.getBlobIfNeeded(outDiagnostics); + } + return result; +} + void Module::_addEntryPoint(EntryPoint* entryPoint) { m_entryPoints.add(entryPoint); -- cgit v1.2.3