summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2024-10-24 13:42:28 -0500
committerGitHub <noreply@github.com>2024-10-24 11:42:28 -0700
commit46b8ab8353966f2590ed2667028b220b57f963ae (patch)
treec260ae05a92afd9ea5000c66819f6ec8258e6b88 /source
parentee709cffe520df3cf082dc7923609c42dd14cabc (diff)
wasm: Add compile target option when creating slang session (#5403)
* wasm: Add compile target option when creating slang session Also add a new interface to return spirv code which is binary, because 'std::string ComponentType::getEntryPointCode' is not suitable for returning the binary data. We use a more standard way that wrap the binary data by using emscripten::val as the return type. * Add target of metal
Diffstat (limited to 'source')
-rw-r--r--source/slang-wasm/slang-wasm-bindings.cpp16
-rw-r--r--source/slang-wasm/slang-wasm.cpp63
-rw-r--r--source/slang-wasm/slang-wasm.h16
3 files changed, 91 insertions, 4 deletions
diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp
index f8175180a..d033f3846 100644
--- a/source/slang-wasm/slang-wasm-bindings.cpp
+++ b/source/slang-wasm/slang-wasm-bindings.cpp
@@ -17,6 +17,11 @@ EMSCRIPTEN_BINDINGS(slang)
"getLastError",
&slang::wgsl::getLastError);
+ function(
+ "getCompileTargets",
+ &slang::wgsl::getCompileTargets,
+ return_value_policy::take_ownership());
+
class_<slang::wgsl::GlobalSession>("GlobalSession")
.function(
"createSession",
@@ -40,7 +45,10 @@ EMSCRIPTEN_BINDINGS(slang)
return_value_policy::take_ownership())
.function(
"getEntryPointCode",
- &slang::wgsl::ComponentType::getEntryPointCode);
+ &slang::wgsl::ComponentType::getEntryPointCode)
+ .function(
+ "getEntryPointCodeSpirv",
+ &slang::wgsl::ComponentType::getEntryPointCodeSpirv);
class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module")
.function(
@@ -59,5 +67,11 @@ EMSCRIPTEN_BINDINGS(slang)
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint");
+ class_<slang::wgsl::CompileTargets>("CompileTargets")
+ .function(
+ "findCompileTarget",
+ &slang::wgsl::CompileTargets::findCompileTarget,
+ return_value_policy::take_ownership());
+
register_vector<slang::wgsl::ComponentType*>("ComponentTypeList");
}
diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp
index a679a5f3d..6fbe2dc6c 100644
--- a/source/slang-wasm/slang-wasm.cpp
+++ b/source/slang-wasm/slang-wasm.cpp
@@ -14,6 +14,7 @@ namespace wgsl
{
Error g_error;
+CompileTargets g_compileTargets;
Error getLastError()
{
@@ -22,6 +23,11 @@ Error getLastError()
return currentError;
}
+CompileTargets* getCompileTargets()
+{
+ return &g_compileTargets;
+}
+
GlobalSession* createGlobalSession()
{
IGlobalSession* globalSession = nullptr;
@@ -38,7 +44,33 @@ GlobalSession* createGlobalSession()
return new GlobalSession(globalSession);
}
-Session* GlobalSession::createSession()
+CompileTargets::CompileTargets()
+{
+#define MAKE_PAIR(x) { #x, SLANG_##x }
+
+ m_compileTargetMap = {
+ MAKE_PAIR(GLSL),
+ MAKE_PAIR(HLSL),
+ MAKE_PAIR(WGSL),
+ MAKE_PAIR(SPIRV),
+ MAKE_PAIR(METAL),
+ };
+}
+
+int CompileTargets::findCompileTarget(const std::string& name)
+{
+ auto res = m_compileTargetMap.find(name);
+ if ( res != m_compileTargetMap.end())
+ {
+ return res->second;
+ }
+ else
+ {
+ return SLANG_TARGET_UNKNOWN;
+ }
+}
+
+Session* GlobalSession::createSession(int compileTarget)
{
ISession* session = nullptr;
{
@@ -46,7 +78,7 @@ Session* GlobalSession::createSession()
sessionDesc.structureSize = sizeof(sessionDesc);
constexpr SlangInt targetCount = 1;
TargetDesc target = {};
- target.format = SLANG_WGSL;
+ target.format = (SlangCompileTarget)compileTarget;
sessionDesc.targets = &target;
sessionDesc.targetCount = targetCount;
SlangResult result = m_interface->createSession(sessionDesc, &session);
@@ -202,5 +234,32 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde
return {};
}
+// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val
+// to wrap it and return it to the javascript side.
+emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex)
+{
+ Slang::ComPtr<IBlob> kernelBlob;
+ Slang::ComPtr<ISlangBlob> diagnosticBlob;
+ SlangResult result = interface()->getEntryPointCode(
+ entryPointIndex,
+ targetIndex,
+ kernelBlob.writeRef(),
+ diagnosticBlob.writeRef());
+ if (result != SLANG_OK)
+ {
+ g_error.type = std::string("USER");
+ g_error.result = result;
+ g_error.message = std::string(
+ (char*)diagnosticBlob->getBufferPointer(),
+ (char*)diagnosticBlob->getBufferPointer() +
+ diagnosticBlob->getBufferSize());
+ return {};
+ }
+
+ const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer();
+ return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(),
+ ptr));
+}
+
} // namespace wgsl
} // namespace slang
diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h
index c329716e8..a54cfe1ea 100644
--- a/source/slang-wasm/slang-wasm.h
+++ b/source/slang-wasm/slang-wasm.h
@@ -1,6 +1,8 @@
#pragma once
#include <slang.h>
+#include <unordered_map>
+#include <emscripten/val.h>
namespace slang
{
@@ -20,6 +22,17 @@ public:
Error getLastError();
+class CompileTargets
+{
+public:
+ CompileTargets();
+ int findCompileTarget(const std::string& name);
+private:
+ std::unordered_map<std::string, SlangCompileTarget> m_compileTargetMap;
+};
+
+CompileTargets* getCompileTargets();
+
class ComponentType
{
public:
@@ -30,6 +43,7 @@ public:
ComponentType* link();
std::string getEntryPointCode(int entryPointIndex, int targetIndex);
+ emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex);
slang::IComponentType* interface() const {return m_interface;}
@@ -93,7 +107,7 @@ public:
GlobalSession(slang::IGlobalSession* interface)
: m_interface(interface) {}
- Session* createSession();
+ Session* createSession(int compileTarget);
slang::IGlobalSession* interface() const {return m_interface;}