From 6a465a4db65b924b03930261da3b64b1c792ef85 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 9 Apr 2024 12:47:03 -0700 Subject: Allow COM based API to discover and check entrypoints without [shader] attribute. (#3914) * Allow COM based API to discover and check entrypoints without [shader] attribute. * Undo changes. * More comments. --- .../slang-unit-test-tool.vcxproj | 1 + .../slang-unit-test-tool.vcxproj.filters | 3 + slang.h | 7 +++ source/slang/slang-check-impl.h | 3 + source/slang/slang-compiler.h | 15 +++++ source/slang/slang.cpp | 42 +++++++++++--- .../unit-test-find-check-entrypoint.cpp | 65 ++++++++++++++++++++++ 7 files changed, 129 insertions(+), 7 deletions(-) create mode 100644 tools/slang-unit-test/unit-test-find-check-entrypoint.cpp diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj index 2798b80e4..c1d6f395d 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj @@ -294,6 +294,7 @@ + diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters index 3fd04c077..f3a9c85f8 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters @@ -38,6 +38,9 @@ Source Files + + Source Files + Source Files diff --git a/slang.h b/slang.h index 4014be401..4ed37d88c 100644 --- a/slang.h +++ b/slang.h @@ -4949,6 +4949,13 @@ namespace slang /// Get the unique identity of the module. virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() = 0; + /// Find and validate an entry point by name, even if the function is + /// not marked with the `[shader("...")]` attribute. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) = 0; }; #define SLANG_UUID_IModule IModule::getTypeGuid() diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 55edba6b9..e6e980fe8 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2773,4 +2773,7 @@ namespace Slang SemanticsDeclVisitorBase* visitor, Decl* decl, DeclCheckState state); + + RefPtr findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq); } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 14d4054c4..014b678f5 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1338,6 +1338,20 @@ namespace Slang return SLANG_OK; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) + { + ComPtr entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); + if ((!entryPoint)) + return SLANG_FAIL; + + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } + virtual SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override { return (SlangInt32)m_entryPoints.getCount(); @@ -1481,6 +1495,7 @@ namespace Slang }; RefPtr findEntryPointByName(UnownedStringSlice const& name); + RefPtr findAndCheckEntryPoint(UnownedStringSlice const& name, SlangStage stage, ISlangBlob** outDiagnostics); List>& getEntryPoints() { return m_entryPoints; } void _addEntryPoint(EntryPoint* entryPoint); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 6d40b46a2..0db833c7a 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3987,13 +3987,6 @@ void Module::setName(String name) RefPtr Module::findEntryPointByName(UnownedStringSlice const& name) { - // TODO: We should consider having this function be expanded to be able - // to look up and validate possible entry-point functions in teh module - // even if they were not marked with `[shader(...)]` in the source code. - // - // With such a change the function would probably need to accept a stage - // to use and a sink to write validation errors to. - for(auto entryPoint : m_entryPoints) { if(entryPoint->getName()->text.getUnownedSlice() == name) @@ -4003,6 +3996,41 @@ RefPtr Module::findEntryPointByName(UnownedStringSlice const& name) return nullptr; } + +RefPtr Module::findAndCheckEntryPoint( + UnownedStringSlice const& name, + SlangStage stage, + ISlangBlob** outDiagnostics) +{ + // If there is already an entrypoint marked with the [shader] attribute, + // we should just return that. + // + if (auto existingEntryPoint = findEntryPointByName(name)) + return existingEntryPoint; + + // If the function hasn't been marked as [shader], then it won't be discovered + // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` + // function. To do that we need to setup a FrontEndCompileRequest and a FrontEndEntryPointRequest. + // + DiagnosticSink sink(getLinkage()->getSourceManager(), DiagnosticSink::SourceLocationLexer()); + FrontEndCompileRequest frontEndRequest(getLinkage(), StdWriters::getSingleton(), &sink); + RefPtr tuRequest = new TranslationUnitRequest(&frontEndRequest); + tuRequest->module = this; + tuRequest->moduleName = m_name; + frontEndRequest.translationUnits.add(tuRequest); + FrontEndEntryPointRequest entryPointRequest( + &frontEndRequest, + 0, + getLinkage()->getNamePool()->getName(name), + Profile((Stage)stage)); + auto result = findAndValidateEntryPoint(&entryPointRequest); + if (outDiagnostics) + { + sink.getBlobIfNeeded(outDiagnostics); + } + return result; +} + void Module::_addEntryPoint(EntryPoint* entryPoint) { m_entryPoints.add(entryPoint); diff --git a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp new file mode 100644 index 000000000..371c8ae81 --- /dev/null +++ b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp @@ -0,0 +1,65 @@ +// unit-test-translation-unit-import.cpp + +#include "../../slang.h" + +#include +#include + +#include "tools/unit-test/slang-unit-test.h" +#include "../../slang-com-ptr.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" + +using namespace Slang; + +// Test that the IModule::findAndCheckEntryPoint API supports discovering +// entrypoints without a [shader] attribute. + +SLANG_UNIT_TEST(findAndCheckEntryPoint) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + float4 fragMain(float4 pos:SV_Position) : SV_Position + { + return pos; + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr entryPoint; + module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(entryPoint != nullptr); + + ComPtr compositeProgram; + slang::IComponentType* components[] = { module, entryPoint.get() }; + session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(compositeProgram != nullptr); + + ComPtr linkedProgram; + compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(linkedProgram != nullptr); + + ComPtr code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(code != nullptr); + + auto codeSrc = UnownedStringSlice((const char*)code->getBufferPointer()); + SLANG_CHECK(codeSrc.indexOf(toSlice("fragMain")) != -1); +} + -- cgit v1.2.3