summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-11-08 16:19:31 -0800
committerGitHub <noreply@github.com>2024-11-08 16:19:31 -0800
commit5ca37c316a9ea1833907bae497734054dfa3c3cb (patch)
tree6aec7069ca410fd6d2426c36ef102303512c28e3 /source
parent7c414463063b979afb0b5184a48a13fcaf5b8af7 (diff)
Use automatic coarse grained memory management in wasm binding. (#5528)
Diffstat (limited to 'source')
-rw-r--r--source/slang-wasm/slang-wasm-bindings.cpp42
-rw-r--r--source/slang-wasm/slang-wasm.cpp151
-rw-r--r--source/slang-wasm/slang-wasm.h101
3 files changed, 158 insertions, 136 deletions
diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp
index e980cd608..346fe8a04 100644
--- a/source/slang-wasm/slang-wasm-bindings.cpp
+++ b/source/slang-wasm/slang-wasm-bindings.cpp
@@ -9,45 +9,36 @@ EMSCRIPTEN_BINDINGS(slang)
{
constant("SLANG_OK", SLANG_OK);
- function(
- "createGlobalSession",
- &slang::wgsl::createGlobalSession,
- return_value_policy::take_ownership());
-
function("getLastError", &slang::wgsl::getLastError);
- function(
- "getCompileTargets",
- &slang::wgsl::getCompileTargets,
- return_value_policy::take_ownership());
+ function("getCompileTargets", &slang::wgsl::getCompileTargets);
class_<slang::wgsl::GlobalSession>("GlobalSession")
.function(
"createSession",
&slang::wgsl::GlobalSession::createSession,
- return_value_policy::take_ownership());
+ allow_raw_pointers());
+
+ function("createGlobalSession", &slang::wgsl::createGlobalSession, allow_raw_pointers());
class_<slang::wgsl::Session>("Session")
.function(
"loadModuleFromSource",
&slang::wgsl::Session::loadModuleFromSource,
- return_value_policy::take_ownership())
+ allow_raw_pointers())
.function(
"createCompositeComponentType",
&slang::wgsl::Session::createCompositeComponentType,
- return_value_policy::take_ownership());
+ allow_raw_pointers());
class_<slang::wgsl::ComponentType>("ComponentType")
- .function("link", &slang::wgsl::ComponentType::link, return_value_policy::take_ownership())
+ .function("link", &slang::wgsl::ComponentType::link, allow_raw_pointers())
.function("getEntryPointCode", &slang::wgsl::ComponentType::getEntryPointCode)
.function("getEntryPointCodeBlob", &slang::wgsl::ComponentType::getEntryPointCodeBlob)
.function("getTargetCodeBlob", &slang::wgsl::ComponentType::getTargetCodeBlob)
.function("getTargetCode", &slang::wgsl::ComponentType::getTargetCode)
.function("getLayout", &slang::wgsl::ComponentType::getLayout, allow_raw_pointers())
- .function(
- "loadStrings",
- &slang::wgsl::ComponentType::loadStrings,
- return_value_policy::take_ownership());
+ .function("loadStrings", &slang::wgsl::ComponentType::loadStrings, allow_raw_pointers());
class_<slang::wgsl::TypeLayoutReflection>("TypeLayoutReflection")
.function(
@@ -85,15 +76,15 @@ EMSCRIPTEN_BINDINGS(slang)
.function(
"findEntryPointByName",
&slang::wgsl::Module::findEntryPointByName,
- return_value_policy::take_ownership())
+ allow_raw_pointers())
.function(
"findAndCheckEntryPoint",
&slang::wgsl::Module::findAndCheckEntryPoint,
- return_value_policy::take_ownership())
+ allow_raw_pointers())
.function(
"getDefinedEntryPoint",
&slang::wgsl::Module::getDefinedEntryPoint,
- return_value_policy::take_ownership())
+ allow_raw_pointers())
.function("getDefinedEntryPointCount", &slang::wgsl::Module::getDefinedEntryPointCount);
value_object<slang::wgsl::Error>("Error")
@@ -104,14 +95,6 @@ EMSCRIPTEN_BINDINGS(slang)
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint")
.function("getName", &slang::wgsl::EntryPoint::getName, allow_raw_pointers());
- class_<slang::wgsl::CompileTargets>("CompileTargets")
- .function(
- "findCompileTarget",
- &slang::wgsl::CompileTargets::findCompileTarget,
- return_value_policy::take_ownership());
-
- register_vector<slang::wgsl::ComponentType*>("ComponentTypeList");
-
register_vector<std::string>("StringList");
register_optional<std::vector<std::string>>();
@@ -251,7 +234,4 @@ EMSCRIPTEN_BINDINGS(slang)
"createLanguageServer",
&slang::wgsl::lsp::createLanguageServer,
return_value_policy::take_ownership());
-
- class_<slang::wgsl::HashedString>("HashedString")
- .function("getString", &slang::wgsl::HashedString::getString);
};
diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp
index f73e70ba0..bdc2e5e6e 100644
--- a/source/slang-wasm/slang-wasm.cpp
+++ b/source/slang-wasm/slang-wasm.cpp
@@ -16,7 +16,6 @@ namespace wgsl
{
Error g_error;
-CompileTargets g_compileTargets;
Error getLastError()
{
@@ -25,9 +24,30 @@ Error getLastError()
return currentError;
}
-CompileTargets* getCompileTargets()
+emscripten::val getCompileTargets()
{
- return &g_compileTargets;
+ struct TargetPair
+ {
+ const char* name;
+ SlangCompileTarget target;
+ };
+ static const TargetPair targets[] = {
+ {"GLSL", SLANG_GLSL},
+ {"HLSL", SLANG_HLSL},
+ {"WGSL", SLANG_WGSL},
+ {"SPIRV", SLANG_SPIRV},
+ {"METAL", SLANG_METAL},
+ };
+
+ std::vector<emscripten::val> result;
+ for (auto target : targets)
+ {
+ auto entry = emscripten::val::object();
+ entry.set("name", target.name);
+ entry.set("value", (int)target.target);
+ result.push_back(entry);
+ }
+ return emscripten::val::array(result);
}
GlobalSession* createGlobalSession()
@@ -46,35 +66,9 @@ GlobalSession* createGlobalSession()
return new GlobalSession(globalSession);
}
-CompileTargets::CompileTargets()
-{
-#define MAKE_PAIR(x) {#x, SLANG_##x}
-
- m_compileTargetMap = {
- MAKE_PAIR(GLSL),
- MAKE_PAIR(HLSL),
- MAKE_PAIR(WGSL),
- MAKE_PAIR(SPIRV),
- MAKE_PAIR(METAL),
- };
-}
-
-int CompileTargets::findCompileTarget(const std::string& name)
-{
- auto res = m_compileTargetMap.find(name);
- if (res != m_compileTargetMap.end())
- {
- return res->second;
- }
- else
- {
- return SLANG_TARGET_UNKNOWN;
- }
-}
-
Session* GlobalSession::createSession(int compileTarget)
{
- ISession* session = nullptr;
+ Slang::ComPtr<ISession> session;
{
SessionDesc sessionDesc = {};
sessionDesc.structureSize = sizeof(sessionDesc);
@@ -83,7 +77,7 @@ Session* GlobalSession::createSession(int compileTarget)
target.format = (SlangCompileTarget)compileTarget;
sessionDesc.targets = &target;
sessionDesc.targetCount = targetCount;
- SlangResult result = m_interface->createSession(sessionDesc, &session);
+ SlangResult result = m_interface->createSession(sessionDesc, session.writeRef());
if (result != SLANG_OK)
{
g_error.type = std::string("USER");
@@ -95,12 +89,19 @@ Session* GlobalSession::createSession(int compileTarget)
return new Session(session);
}
-Module* Session::loadModuleFromSource(
+Session::~Session()
+{
+ m_componentTypes = {};
+ auto refCount = static_cast<Slang::Linkage*>(m_interface.get())->debugGetReferenceCount();
+ m_interface = nullptr;
+}
+
+emscripten::val Session::loadModuleFromSource(
const std::string& slangCode,
const std::string& name,
const std::string& path)
{
- Slang::ComPtr<IModule> module;
+ IModule* module = nullptr;
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
Slang::ComPtr<ISlangBlob> slangCodeBlob =
@@ -116,14 +117,13 @@ Module* Session::loadModuleFromSource(
g_error.message = std::string(
(char*)diagnosticsBlob->getBufferPointer(),
(char*)diagnosticsBlob->getBufferPointer() + diagnosticsBlob->getBufferSize());
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new Module(module);
+ return emscripten::val(Module(module, this));
}
-EntryPoint* Module::findEntryPointByName(const std::string& name)
+emscripten::val Module::findEntryPointByName(const std::string& name)
{
Slang::ComPtr<IEntryPoint> entryPoint;
{
@@ -133,15 +133,15 @@ EntryPoint* Module::findEntryPointByName(const std::string& name)
{
g_error.type = std::string("USER");
g_error.result = result;
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new EntryPoint(entryPoint);
+ m_session->addComponentType(entryPoint.get());
+ return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}
-EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
+emscripten::val Module::findAndCheckEntryPoint(const std::string& name, int stage)
{
Slang::ComPtr<IEntryPoint> entryPoint;
{
@@ -161,11 +161,11 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new EntryPoint(entryPoint);
+ m_session->addComponentType(entryPoint.get());
+ return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}
int Module::getDefinedEntryPointCount()
@@ -173,10 +173,10 @@ int Module::getDefinedEntryPointCount()
return moduleInterface()->getDefinedEntryPointCount();
}
-EntryPoint* Module::getDefinedEntryPoint(int index)
+emscripten::val Module::getDefinedEntryPoint(int index)
{
if (moduleInterface()->getDefinedEntryPointCount() <= index)
- return nullptr;
+ return emscripten::val::null();
Slang::ComPtr<IEntryPoint> entryPoint;
{
@@ -192,21 +192,37 @@ EntryPoint* Module::getDefinedEntryPoint(int index)
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new EntryPoint(entryPoint);
+ m_session->addComponentType(entryPoint.get());
+ return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}
-ComponentType* Session::createCompositeComponentType(const std::vector<ComponentType*>& components)
+emscripten::val Session::createCompositeComponentType(emscripten::val components)
{
+ if (!components.isArray())
+ {
+ g_error.type = std::string("Slang WASM Bind");
+ g_error.message = std::string("createCompositeComponentType: Components must be an array");
+ return emscripten::val::null();
+ }
+ std::vector<emscripten::val> componentsArray =
+ emscripten::vecFromJSArray<emscripten::val>(components);
+
Slang::ComPtr<IComponentType> composite;
{
- std::vector<IComponentType*> nativeComponents(components.size());
- for (size_t i = 0U; i < components.size(); i++)
- nativeComponents[i] = components[i]->interface();
+ std::vector<IComponentType*> nativeComponents;
+ for (size_t i = 0U; i < componentsArray.size(); i++)
+ {
+ auto componentVal = componentsArray[i];
+ if (componentVal.instanceof (emscripten::val::module_property("ComponentType")))
+ {
+ auto componentType = componentVal.as<ComponentType>();
+ nativeComponents.push_back(componentType.interface());
+ }
+ }
SlangResult result = m_interface->createCompositeComponentType(
nativeComponents.data(),
(SlangInt)nativeComponents.size(),
@@ -215,14 +231,14 @@ ComponentType* Session::createCompositeComponentType(const std::vector<Component
{
g_error.type = std::string("USER");
g_error.result = result;
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new ComponentType(composite);
+ addComponentType(composite.get());
+ return emscripten::val(ComponentType(composite, this));
}
-ComponentType* ComponentType::link()
+emscripten::val ComponentType::link()
{
Slang::ComPtr<IComponentType> linkedProgram;
{
@@ -235,11 +251,11 @@ ComponentType* ComponentType::link()
g_error.message = std::string(
(char*)diagnosticBlob->getBufferPointer(),
(char*)diagnosticBlob->getBufferPointer() + diagnosticBlob->getBufferSize());
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new ComponentType(linkedProgram);
+ m_session->addComponentType(linkedProgram.get());
+ return emscripten::val(ComponentType(linkedProgram, m_session));
}
std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetIndex)
@@ -344,14 +360,14 @@ emscripten::val ComponentType::getTargetCodeBlob(int targetIndex)
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), ptr));
}
-HashedString* ComponentType::loadStrings()
+emscripten::val ComponentType::loadStrings()
{
slang::ProgramLayout* slangReflection = interface()->getLayout();
if (!slangReflection)
{
g_error.type = std::string("USER");
g_error.message = std::string("Failed to get reflection data");
- return nullptr;
+ return emscripten::val::null();
}
SlangUInt hashedStringCount = slangReflection->getHashedStringCount();
@@ -359,11 +375,11 @@ HashedString* ComponentType::loadStrings()
{
g_error.type = std::string("USER");
g_error.message = std::string("Warn: No reflection data found");
- return nullptr;
+ return emscripten::val::null();
}
size_t stringSize = 0;
- HashedString* hashedStrings = new HashedString();
+ std::vector<emscripten::val> result;
for (SlangUInt ii = 0; ii < hashedStringCount; ++ii)
{
// For each string we can fetch its bytes from the Slang
@@ -381,9 +397,12 @@ HashedString* ComponentType::loadStrings()
//
int hash = spComputeStringHash(stringData, stringSize);
- hashedStrings->insertString(hash, std::string(stringData));
+ emscripten::val entry = emscripten::val::object();
+ entry.set("hash", hash);
+ entry.set("string", std::string(stringData));
+ result.push_back(entry);
}
- return hashedStrings;
+ return emscripten::val::array(result);
}
ProgramLayout* ComponentType::getLayout(unsigned int targetIndex)
diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h
index 9e1e023a9..1f2be2fb5 100644
--- a/source/slang-wasm/slang-wasm.h
+++ b/source/slang-wasm/slang-wasm.h
@@ -5,6 +5,25 @@
#include <slang.h>
#include <unordered_map>
+/**
+The web assembly binding here is designed to make javascript code as simple and native as possible.
+The big issue being handled here is lifetime management of objects created in the Slang API.
+
+The idea here is to make lifetime management as coarse grained as possible from the javascript side.
+Only two types of objects need to be explicitly deleted by javascript: GlobalSession and Session.
+
+All the remaining objects returned by member functions of Session will have their lifetime managed
+by the owning session in the C++ side. This way, the javascript code will never need to worry about
+freeing small objects like ComponentType, EntryPoint, Module, TypeLayoutReflection,
+VariableLayoutReflection, ProgramLayout etc.
+
+When a Session is no longer needed, the javascript code should explicitly delete it, this will allow
+us to free all the objects we allocated from the session in one single explicit call.
+
+By making explicit memory management as coarse grained as possible, we are making memory management
+efficient, simple, and less error prone.
+*/
+
namespace Slang
{
class LanguageServerCore;
@@ -28,28 +47,9 @@ public:
Error getLastError();
-class CompileTargets
-{
-public:
- CompileTargets();
- int findCompileTarget(const std::string& name);
-
-private:
- std::unordered_map<std::string, SlangCompileTarget> m_compileTargetMap;
-};
-
-class HashedString
-{
-public:
- std::string getString(uint32_t hash) { return m_hashedStrings[(int)hash]; }
- void insertString(int hash, const std::string& str) { m_hashedStrings[hash] = str; }
-
-private:
- std::unordered_map<int, std::string> m_hashedStrings;
-};
-
-CompileTargets* getCompileTargets();
-
+// returns mapping of codegen target from string to SlangCompileTarget
+// in the form of [{name: STRING, value: INT}, ...].
+emscripten::val getCompileTargets();
class TypeLayoutReflection
{
@@ -84,37 +84,44 @@ public:
slang::ProgramLayout* interface() const { return (slang::ProgramLayout*)this; }
};
+class Session;
class ComponentType
{
public:
- ComponentType(slang::IComponentType* interface)
- : m_interface(interface)
+ IComponentType* m_interface;
+ Session* m_session;
+
+public:
+ ComponentType(slang::IComponentType* interface, Session* session)
+ : m_interface(interface), m_session(session)
{
}
- ComponentType* link();
+ // Returns ComponentType or null.
+ emscripten::val link();
std::string getEntryPointCode(int entryPointIndex, int targetIndex);
+
+ // Returns UInt8Array or null.
emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex);
std::string getTargetCode(int targetIndex);
+
+ // Returns UInt8Array or null.
emscripten::val getTargetCodeBlob(int targetIndex);
slang::wgsl::ProgramLayout* getLayout(unsigned int targetIndex);
slang::IComponentType* interface() const { return m_interface; }
- HashedString* loadStrings();
- virtual ~ComponentType() = default;
-
-private:
- Slang::ComPtr<slang::IComponentType> m_interface;
+ // returns [{hash: HASH, string: STRING}, ...]
+ emscripten::val loadStrings();
};
class EntryPoint : public ComponentType
{
public:
- EntryPoint(slang::IEntryPoint* interface)
- : ComponentType(interface)
+ EntryPoint(slang::IComponentType* interface, Session* session)
+ : ComponentType(interface, session)
{
}
std::string getName() const
@@ -132,14 +139,20 @@ private:
class Module : public ComponentType
{
public:
- Module(slang::IModule* interface)
- : ComponentType(interface)
+ Module(slang::IComponentType* interface, Session* session)
+ : ComponentType(interface, session)
{
}
- EntryPoint* findEntryPointByName(const std::string& name);
- EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage);
- EntryPoint* getDefinedEntryPoint(int index);
+ // Returns EntryPoint or null.
+ emscripten::val findEntryPointByName(const std::string& name);
+
+ // Returns EntryPoint or null.
+ emscripten::val findAndCheckEntryPoint(const std::string& name, int stage);
+
+ // Returns EntryPoint or null.
+ emscripten::val getDefinedEntryPoint(int index);
+
int getDefinedEntryPointCount();
slang::IModule* moduleInterface() const { return static_cast<slang::IModule*>(interface()); }
@@ -152,17 +165,27 @@ public:
: m_interface(interface)
{
}
+ ~Session();
- Module* loadModuleFromSource(
+ // Returns Module or null.
+ emscripten::val loadModuleFromSource(
const std::string& slangCode,
const std::string& name,
const std::string& path);
- ComponentType* createCompositeComponentType(const std::vector<ComponentType*>& components);
+ // `components` is a javascript array of ComponentType/Module/EntryPoint objects.
+ // Returns ComponentType or null.
+ emscripten::val createCompositeComponentType(emscripten::val components);
slang::ISession* interface() const { return m_interface; }
+ void addComponentType(slang::IComponentType* componentType)
+ {
+ m_componentTypes.push_back(Slang::ComPtr<slang::IComponentType>(componentType));
+ }
+
private:
+ std::vector<Slang::ComPtr<slang::IComponentType>> m_componentTypes;
Slang::ComPtr<slang::ISession> m_interface;
};