summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-04-09 12:47:03 -0700
committerGitHub <noreply@github.com>2024-04-09 12:47:03 -0700
commit6a465a4db65b924b03930261da3b64b1c792ef85 (patch)
tree8d90c1864fc47e2ed08ded8000a3eadb41ef8f60
parent957b2fbb67efa82d778052c0d63d4de339e89e6f (diff)
Allow COM based API to discover and check entrypoints without [shader] attribute. (#3914)
* Allow COM based API to discover and check entrypoints without [shader] attribute. * Undo changes. * More comments.
-rw-r--r--build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj1
-rw-r--r--build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters3
-rw-r--r--slang.h7
-rw-r--r--source/slang/slang-check-impl.h3
-rwxr-xr-xsource/slang/slang-compiler.h15
-rw-r--r--source/slang/slang.cpp42
-rw-r--r--tools/slang-unit-test/unit-test-find-check-entrypoint.cpp65
7 files changed, 129 insertions, 7 deletions
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
index 2798b80e4..c1d6f395d 100644
--- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
+++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
@@ -294,6 +294,7 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-crypto.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-default-matrix-layout.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp" />
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-check-entrypoint.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-free-list.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-io.cpp" />
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
index 3fd04c077..f3a9c85f8 100644
--- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
+++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
@@ -38,6 +38,9 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-check-entrypoint.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/slang.h b/slang.h
index 4014be401..4ed37d88c 100644
--- a/slang.h
+++ b/slang.h
@@ -4949,6 +4949,13 @@ namespace slang
/// Get the unique identity of the module.
virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() = 0;
+ /// Find and validate an entry point by name, even if the function is
+ /// not marked with the `[shader("...")]` attribute.
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint(
+ char const* name,
+ SlangStage stage,
+ IEntryPoint** outEntryPoint,
+ ISlangBlob** outDiagnostics) = 0;
};
#define SLANG_UUID_IModule IModule::getTypeGuid()
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 55edba6b9..e6e980fe8 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2773,4 +2773,7 @@ namespace Slang
SemanticsDeclVisitorBase* visitor,
Decl* decl,
DeclCheckState state);
+
+ RefPtr<EntryPoint> findAndValidateEntryPoint(
+ FrontEndEntryPointRequest* entryPointReq);
}
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 14d4054c4..014b678f5 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1338,6 +1338,20 @@ namespace Slang
return SLANG_OK;
}
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint(
+ char const* name,
+ SlangStage stage,
+ slang::IEntryPoint** outEntryPoint,
+ ISlangBlob** outDiagnostics)
+ {
+ ComPtr<slang::IEntryPoint> entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics));
+ if ((!entryPoint))
+ return SLANG_FAIL;
+
+ *outEntryPoint = entryPoint.detach();
+ return SLANG_OK;
+ }
+
virtual SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override
{
return (SlangInt32)m_entryPoints.getCount();
@@ -1481,6 +1495,7 @@ namespace Slang
};
RefPtr<EntryPoint> findEntryPointByName(UnownedStringSlice const& name);
+ RefPtr<EntryPoint> findAndCheckEntryPoint(UnownedStringSlice const& name, SlangStage stage, ISlangBlob** outDiagnostics);
List<RefPtr<EntryPoint>>& getEntryPoints() { return m_entryPoints; }
void _addEntryPoint(EntryPoint* entryPoint);
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 6d40b46a2..0db833c7a 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -3987,13 +3987,6 @@ void Module::setName(String name)
RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name)
{
- // TODO: We should consider having this function be expanded to be able
- // to look up and validate possible entry-point functions in teh module
- // even if they were not marked with `[shader(...)]` in the source code.
- //
- // With such a change the function would probably need to accept a stage
- // to use and a sink to write validation errors to.
-
for(auto entryPoint : m_entryPoints)
{
if(entryPoint->getName()->text.getUnownedSlice() == name)
@@ -4003,6 +3996,41 @@ RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name)
return nullptr;
}
+
+RefPtr<EntryPoint> Module::findAndCheckEntryPoint(
+ UnownedStringSlice const& name,
+ SlangStage stage,
+ ISlangBlob** outDiagnostics)
+{
+ // If there is already an entrypoint marked with the [shader] attribute,
+ // we should just return that.
+ //
+ if (auto existingEntryPoint = findEntryPointByName(name))
+ return existingEntryPoint;
+
+ // If the function hasn't been marked as [shader], then it won't be discovered
+ // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint`
+ // function. To do that we need to setup a FrontEndCompileRequest and a FrontEndEntryPointRequest.
+ //
+ DiagnosticSink sink(getLinkage()->getSourceManager(), DiagnosticSink::SourceLocationLexer());
+ FrontEndCompileRequest frontEndRequest(getLinkage(), StdWriters::getSingleton(), &sink);
+ RefPtr<TranslationUnitRequest> tuRequest = new TranslationUnitRequest(&frontEndRequest);
+ tuRequest->module = this;
+ tuRequest->moduleName = m_name;
+ frontEndRequest.translationUnits.add(tuRequest);
+ FrontEndEntryPointRequest entryPointRequest(
+ &frontEndRequest,
+ 0,
+ getLinkage()->getNamePool()->getName(name),
+ Profile((Stage)stage));
+ auto result = findAndValidateEntryPoint(&entryPointRequest);
+ if (outDiagnostics)
+ {
+ sink.getBlobIfNeeded(outDiagnostics);
+ }
+ return result;
+}
+
void Module::_addEntryPoint(EntryPoint* entryPoint)
{
m_entryPoints.add(entryPoint);
diff --git a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
new file mode 100644
index 000000000..371c8ae81
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
@@ -0,0 +1,65 @@
+// unit-test-translation-unit-import.cpp
+
+#include "../../slang.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tools/unit-test/slang-unit-test.h"
+#include "../../slang-com-ptr.h"
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+
+using namespace Slang;
+
+// Test that the IModule::findAndCheckEntryPoint API supports discovering
+// entrypoints without a [shader] attribute.
+
+SLANG_UNIT_TEST(findAndCheckEntryPoint)
+{
+ // Source for a module that contains an undecorated entrypoint.
+ const char* userSourceBody = R"(
+ float4 fragMain(float4 pos:SV_Position) : SV_Position
+ {
+ return pos;
+ }
+ )";
+
+ auto moduleName = "moduleG" + String(Process::getId());
+ String userSource = "import " + moduleName + ";\n" + userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_HLSL;
+ targetDesc.profile = globalSession->findProfile("sm_5_0");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(entryPoint != nullptr);
+
+ ComPtr<slang::IComponentType> compositeProgram;
+ slang::IComponentType* components[] = { module, entryPoint.get() };
+ session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(compositeProgram != nullptr);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(linkedProgram != nullptr);
+
+ ComPtr<slang::IBlob> code;
+ linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(code != nullptr);
+
+ auto codeSrc = UnownedStringSlice((const char*)code->getBufferPointer());
+ SLANG_CHECK(codeSrc.indexOf(toSlice("fragMain")) != -1);
+}
+