summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-10-28 08:53:41 -0700
committerGitHub <noreply@github.com>2024-10-28 08:53:41 -0700
commit04329077988a2b1f7a87b1d116457599039e5e12 (patch)
tree86b29e773f04f0926cc7685766d0a0b135c90500 /source
parenta3276e2876be00fd4b0a69e47a66b1cff29765f2 (diff)
More wasm binding for playground. (#5420)
Diffstat (limited to 'source')
-rw-r--r--source/slang-wasm/slang-wasm-bindings.cpp25
-rw-r--r--source/slang-wasm/slang-wasm.cpp96
-rw-r--r--source/slang-wasm/slang-wasm.h15
-rw-r--r--source/slang/slang.cpp8
4 files changed, 130 insertions, 14 deletions
diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp
index 360dec6eb..56a472482 100644
--- a/source/slang-wasm/slang-wasm-bindings.cpp
+++ b/source/slang-wasm/slang-wasm-bindings.cpp
@@ -47,8 +47,14 @@ EMSCRIPTEN_BINDINGS(slang)
"getEntryPointCode",
&slang::wgsl::ComponentType::getEntryPointCode)
.function(
- "getEntryPointCodeSpirv",
- &slang::wgsl::ComponentType::getEntryPointCodeSpirv);
+ "getEntryPointCodeBlob",
+ &slang::wgsl::ComponentType::getEntryPointCodeBlob)
+ .function(
+ "getTargetCodeBlob",
+ &slang::wgsl::ComponentType::getTargetCodeBlob)
+ .function(
+ "getTargetCode",
+ &slang::wgsl::ComponentType::getTargetCode);
class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module")
.function(
@@ -58,14 +64,25 @@ EMSCRIPTEN_BINDINGS(slang)
.function(
"findAndCheckEntryPoint",
&slang::wgsl::Module::findAndCheckEntryPoint,
- return_value_policy::take_ownership());
+ return_value_policy::take_ownership())
+ .function(
+ "getDefinedEntryPoint",
+ &slang::wgsl::Module::getDefinedEntryPoint,
+ return_value_policy::take_ownership())
+ .function(
+ "getDefinedEntryPointCount",
+ &slang::wgsl::Module::getDefinedEntryPointCount);
value_object<slang::wgsl::Error>("Error")
.field("type", &slang::wgsl::Error::type)
.field("result", &slang::wgsl::Error::result)
.field("message", &slang::wgsl::Error::message);
- class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint");
+ class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint")
+ .function(
+ "getName",
+ &slang::wgsl::EntryPoint::getName,
+ allow_raw_pointers());
class_<slang::wgsl::CompileTargets>("CompileTargets")
.function(
diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp
index 886efc5f7..8948c1075 100644
--- a/source/slang-wasm/slang-wasm.cpp
+++ b/source/slang-wasm/slang-wasm.cpp
@@ -94,17 +94,15 @@ Session* GlobalSession::createSession(int compileTarget)
return new Session(session);
}
-Module* Session::loadModuleFromSource(const std::string& slangCode)
+Module* Session::loadModuleFromSource(const std::string& slangCode, const std::string& name, const std::string& path)
{
Slang::ComPtr<IModule> module;
{
- const char * name = "";
- const char * path = "";
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
Slang::ComPtr<ISlangBlob> slangCodeBlob = Slang::RawBlob::create(
slangCode.c_str(), slangCode.size());
module = m_interface->loadModuleFromSource(
- name, path, slangCodeBlob, diagnosticsBlob.writeRef());
+ name.c_str(), path.c_str(), slangCodeBlob, diagnosticsBlob.writeRef());
if (!module)
{
g_error.type = std::string("USER");
@@ -161,6 +159,38 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
return new EntryPoint(entryPoint);
}
+int Module::getDefinedEntryPointCount()
+{
+ return moduleInterface()->getDefinedEntryPointCount();
+}
+
+EntryPoint* Module::getDefinedEntryPoint(int index)
+{
+ if (moduleInterface()->getDefinedEntryPointCount() <= index)
+ return nullptr;
+
+ Slang::ComPtr<IEntryPoint> entryPoint;
+ {
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ SlangResult result = moduleInterface()->getDefinedEntryPoint(index, entryPoint.writeRef());
+ if (!SLANG_SUCCEEDED(result))
+ {
+ g_error.type = std::string("USER");
+ g_error.result = result;
+
+ if (diagnosticsBlob->getBufferSize())
+ {
+ char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
+ g_error.message = std::string(diagnostics);
+ }
+ return nullptr;
+ }
+ }
+
+ return new EntryPoint(entryPoint);
+}
+
+
ComponentType* Session::createCompositeComponentType(
const std::vector<ComponentType*>& components)
{
@@ -235,9 +265,9 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde
return {};
}
-// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val
+// Since result code is binary, we can't return it as a string, we will need to use emscripten::val
// to wrap it and return it to the javascript side.
-emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex)
+emscripten::val ComponentType::getEntryPointCodeBlob(int entryPointIndex, int targetIndex)
{
Slang::ComPtr<IBlob> kernelBlob;
Slang::ComPtr<ISlangBlob> diagnosticBlob;
@@ -262,6 +292,60 @@ emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int t
ptr));
}
+std::string ComponentType::getTargetCode(int targetIndex)
+{
+ {
+ Slang::ComPtr<IBlob> kernelBlob;
+ Slang::ComPtr<ISlangBlob> diagnosticBlob;
+ SlangResult result = interface()->getTargetCode(
+ targetIndex,
+ kernelBlob.writeRef(),
+ diagnosticBlob.writeRef());
+ if (result != SLANG_OK)
+ {
+ g_error.type = std::string("USER");
+ g_error.result = result;
+ g_error.message = std::string(
+ (char*)diagnosticBlob->getBufferPointer(),
+ (char*)diagnosticBlob->getBufferPointer() +
+ diagnosticBlob->getBufferSize());
+ return "";
+ }
+ std::string targetCode = std::string(
+ (char*)kernelBlob->getBufferPointer(),
+ (char*)kernelBlob->getBufferPointer() + kernelBlob->getBufferSize());
+ return targetCode;
+ }
+
+ return {};
+}
+
+// Since result code is binary, we can't return it as a string, we will need to use emscripten::val
+// to wrap it and return it to the javascript side.
+emscripten::val ComponentType::getTargetCodeBlob(int targetIndex)
+{
+ Slang::ComPtr<IBlob> kernelBlob;
+ Slang::ComPtr<ISlangBlob> diagnosticBlob;
+ SlangResult result = interface()->getTargetCode(
+ targetIndex,
+ kernelBlob.writeRef(),
+ diagnosticBlob.writeRef());
+ if (result != SLANG_OK)
+ {
+ g_error.type = std::string("USER");
+ g_error.result = result;
+ g_error.message = std::string(
+ (char*)diagnosticBlob->getBufferPointer(),
+ (char*)diagnosticBlob->getBufferPointer() +
+ diagnosticBlob->getBufferSize());
+ return {};
+ }
+
+ const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer();
+ return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(),
+ ptr));
+}
+
namespace lsp
{
Position translate(Slang::LanguageServerProtocol::Position p)
diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h
index 5a299453c..eb302119b 100644
--- a/source/slang-wasm/slang-wasm.h
+++ b/source/slang-wasm/slang-wasm.h
@@ -48,7 +48,9 @@ public:
ComponentType* link();
std::string getEntryPointCode(int entryPointIndex, int targetIndex);
- emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex);
+ emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex);
+ std::string getTargetCode(int targetIndex);
+ emscripten::val getTargetCodeBlob(int targetIndex);
slang::IComponentType* interface() const {return m_interface;}
@@ -62,9 +64,11 @@ private:
class EntryPoint : public ComponentType
{
public:
-
EntryPoint(slang::IEntryPoint* interface) : ComponentType(interface) {}
-
+ std::string getName() const
+ {
+ return entryPointInterface()->getFunctionReflection()->getName();
+ }
private:
slang::IEntryPoint* entryPointInterface() const {
@@ -80,6 +84,8 @@ public:
EntryPoint* findEntryPointByName(const std::string& name);
EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage);
+ EntryPoint* getDefinedEntryPoint(int index);
+ int getDefinedEntryPointCount();
slang::IModule* moduleInterface() const {
return static_cast<slang::IModule*>(interface());
@@ -93,7 +99,8 @@ public:
Session(slang::ISession* interface)
: m_interface(interface) {}
- Module* loadModuleFromSource(const std::string& slangCode);
+ Module* loadModuleFromSource(
+ const std::string& slangCode, const std::string& name, const std::string& path);
ComponentType* createCompositeComponentType(
const std::vector<ComponentType*>& components);
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index b77a3efc8..b3ac7f73d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -5040,13 +5040,21 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD
});
List<RefPtr<ComponentType>> components;
components.add(this);
+ bool entryPointsDiscovered = false;
for (auto module : modules)
{
for (auto entryPoint : module->getEntryPoints())
{
components.add(entryPoint);
+ entryPointsDiscovered = true;
}
}
+ // If no entry points were discovered, then we should return nullptr.
+ if (!entryPointsDiscovered)
+ {
+ return nullptr;
+ }
+
RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components);
ComPtr<IComponentType> linkedComponentType;
SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics));