diff options
| author | Yong He <yonghe@outlook.com> | 2024-04-09 12:47:03 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-09 12:47:03 -0700 |
| commit | 6a465a4db65b924b03930261da3b64b1c792ef85 (patch) | |
| tree | 8d90c1864fc47e2ed08ded8000a3eadb41ef8f60 | |
| parent | 957b2fbb67efa82d778052c0d63d4de339e89e6f (diff) | |
Allow COM based API to discover and check entrypoints without [shader] attribute. (#3914)
* Allow COM based API to discover and check entrypoints without [shader] attribute.
* Undo changes.
* More comments.
| -rw-r--r-- | build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj | 1 | ||||
| -rw-r--r-- | build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters | 3 | ||||
| -rw-r--r-- | slang.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 3 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 15 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 42 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-find-check-entrypoint.cpp | 65 |
7 files changed, 129 insertions, 7 deletions
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj index 2798b80e4..c1d6f395d 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj @@ -294,6 +294,7 @@ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-crypto.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-default-matrix-layout.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp" />
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-check-entrypoint.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-free-list.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-io.cpp" />
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters index 3fd04c077..f3a9c85f8 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters @@ -38,6 +38,9 @@ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-check-entrypoint.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -4949,6 +4949,13 @@ namespace slang /// Get the unique identity of the module. virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() = 0; + /// Find and validate an entry point by name, even if the function is + /// not marked with the `[shader("...")]` attribute. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) = 0; }; #define SLANG_UUID_IModule IModule::getTypeGuid() diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 55edba6b9..e6e980fe8 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2773,4 +2773,7 @@ namespace Slang SemanticsDeclVisitorBase* visitor, Decl* decl, DeclCheckState state); + + RefPtr<EntryPoint> findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq); } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 14d4054c4..014b678f5 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1338,6 +1338,20 @@ namespace Slang return SLANG_OK; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) + { + ComPtr<slang::IEntryPoint> entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); + if ((!entryPoint)) + return SLANG_FAIL; + + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } + virtual SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override { return (SlangInt32)m_entryPoints.getCount(); @@ -1481,6 +1495,7 @@ namespace Slang }; RefPtr<EntryPoint> findEntryPointByName(UnownedStringSlice const& name); + RefPtr<EntryPoint> findAndCheckEntryPoint(UnownedStringSlice const& name, SlangStage stage, ISlangBlob** outDiagnostics); List<RefPtr<EntryPoint>>& getEntryPoints() { return m_entryPoints; } void _addEntryPoint(EntryPoint* entryPoint); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 6d40b46a2..0db833c7a 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3987,13 +3987,6 @@ void Module::setName(String name) RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name) { - // TODO: We should consider having this function be expanded to be able - // to look up and validate possible entry-point functions in teh module - // even if they were not marked with `[shader(...)]` in the source code. - // - // With such a change the function would probably need to accept a stage - // to use and a sink to write validation errors to. - for(auto entryPoint : m_entryPoints) { if(entryPoint->getName()->text.getUnownedSlice() == name) @@ -4003,6 +3996,41 @@ RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name) return nullptr; } + +RefPtr<EntryPoint> Module::findAndCheckEntryPoint( + UnownedStringSlice const& name, + SlangStage stage, + ISlangBlob** outDiagnostics) +{ + // If there is already an entrypoint marked with the [shader] attribute, + // we should just return that. + // + if (auto existingEntryPoint = findEntryPointByName(name)) + return existingEntryPoint; + + // If the function hasn't been marked as [shader], then it won't be discovered + // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` + // function. To do that we need to setup a FrontEndCompileRequest and a FrontEndEntryPointRequest. + // + DiagnosticSink sink(getLinkage()->getSourceManager(), DiagnosticSink::SourceLocationLexer()); + FrontEndCompileRequest frontEndRequest(getLinkage(), StdWriters::getSingleton(), &sink); + RefPtr<TranslationUnitRequest> tuRequest = new TranslationUnitRequest(&frontEndRequest); + tuRequest->module = this; + tuRequest->moduleName = m_name; + frontEndRequest.translationUnits.add(tuRequest); + FrontEndEntryPointRequest entryPointRequest( + &frontEndRequest, + 0, + getLinkage()->getNamePool()->getName(name), + Profile((Stage)stage)); + auto result = findAndValidateEntryPoint(&entryPointRequest); + if (outDiagnostics) + { + sink.getBlobIfNeeded(outDiagnostics); + } + return result; +} + void Module::_addEntryPoint(EntryPoint* entryPoint) { m_entryPoints.add(entryPoint); diff --git a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp new file mode 100644 index 000000000..371c8ae81 --- /dev/null +++ b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp @@ -0,0 +1,65 @@ +// unit-test-translation-unit-import.cpp + +#include "../../slang.h" + +#include <stdio.h> +#include <stdlib.h> + +#include "tools/unit-test/slang-unit-test.h" +#include "../../slang-com-ptr.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" + +using namespace Slang; + +// Test that the IModule::findAndCheckEntryPoint API supports discovering +// entrypoints without a [shader] attribute. + +SLANG_UNIT_TEST(findAndCheckEntryPoint) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + float4 fragMain(float4 pos:SV_Position) : SV_Position + { + return pos; + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(entryPoint != nullptr); + + ComPtr<slang::IComponentType> compositeProgram; + slang::IComponentType* components[] = { module, entryPoint.get() }; + session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(compositeProgram != nullptr); + + ComPtr<slang::IComponentType> linkedProgram; + compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(linkedProgram != nullptr); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(code != nullptr); + + auto codeSrc = UnownedStringSlice((const char*)code->getBufferPointer()); + SLANG_CHECK(codeSrc.indexOf(toSlice("fragMain")) != -1); +} + |
