diff options
| author | Yong He <yonghe@outlook.com> | 2024-11-08 16:19:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-08 16:19:31 -0800 |
| commit | 5ca37c316a9ea1833907bae497734054dfa3c3cb (patch) | |
| tree | 6aec7069ca410fd6d2426c36ef102303512c28e3 /source | |
| parent | 7c414463063b979afb0b5184a48a13fcaf5b8af7 (diff) | |
Use automatic coarse grained memory management in wasm binding. (#5528)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang-wasm/slang-wasm-bindings.cpp | 42 | ||||
| -rw-r--r-- | source/slang-wasm/slang-wasm.cpp | 151 | ||||
| -rw-r--r-- | source/slang-wasm/slang-wasm.h | 101 |
3 files changed, 158 insertions, 136 deletions
diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp index e980cd608..346fe8a04 100644 --- a/source/slang-wasm/slang-wasm-bindings.cpp +++ b/source/slang-wasm/slang-wasm-bindings.cpp @@ -9,45 +9,36 @@ EMSCRIPTEN_BINDINGS(slang) { constant("SLANG_OK", SLANG_OK); - function( - "createGlobalSession", - &slang::wgsl::createGlobalSession, - return_value_policy::take_ownership()); - function("getLastError", &slang::wgsl::getLastError); - function( - "getCompileTargets", - &slang::wgsl::getCompileTargets, - return_value_policy::take_ownership()); + function("getCompileTargets", &slang::wgsl::getCompileTargets); class_<slang::wgsl::GlobalSession>("GlobalSession") .function( "createSession", &slang::wgsl::GlobalSession::createSession, - return_value_policy::take_ownership()); + allow_raw_pointers()); + + function("createGlobalSession", &slang::wgsl::createGlobalSession, allow_raw_pointers()); class_<slang::wgsl::Session>("Session") .function( "loadModuleFromSource", &slang::wgsl::Session::loadModuleFromSource, - return_value_policy::take_ownership()) + allow_raw_pointers()) .function( "createCompositeComponentType", &slang::wgsl::Session::createCompositeComponentType, - return_value_policy::take_ownership()); + allow_raw_pointers()); class_<slang::wgsl::ComponentType>("ComponentType") - .function("link", &slang::wgsl::ComponentType::link, return_value_policy::take_ownership()) + .function("link", &slang::wgsl::ComponentType::link, allow_raw_pointers()) .function("getEntryPointCode", &slang::wgsl::ComponentType::getEntryPointCode) .function("getEntryPointCodeBlob", &slang::wgsl::ComponentType::getEntryPointCodeBlob) .function("getTargetCodeBlob", &slang::wgsl::ComponentType::getTargetCodeBlob) .function("getTargetCode", &slang::wgsl::ComponentType::getTargetCode) .function("getLayout", &slang::wgsl::ComponentType::getLayout, allow_raw_pointers()) - .function( - "loadStrings", - &slang::wgsl::ComponentType::loadStrings, - return_value_policy::take_ownership()); + .function("loadStrings", &slang::wgsl::ComponentType::loadStrings, allow_raw_pointers()); class_<slang::wgsl::TypeLayoutReflection>("TypeLayoutReflection") .function( @@ -85,15 +76,15 @@ EMSCRIPTEN_BINDINGS(slang) .function( "findEntryPointByName", &slang::wgsl::Module::findEntryPointByName, - return_value_policy::take_ownership()) + allow_raw_pointers()) .function( "findAndCheckEntryPoint", &slang::wgsl::Module::findAndCheckEntryPoint, - return_value_policy::take_ownership()) + allow_raw_pointers()) .function( "getDefinedEntryPoint", &slang::wgsl::Module::getDefinedEntryPoint, - return_value_policy::take_ownership()) + allow_raw_pointers()) .function("getDefinedEntryPointCount", &slang::wgsl::Module::getDefinedEntryPointCount); value_object<slang::wgsl::Error>("Error") @@ -104,14 +95,6 @@ EMSCRIPTEN_BINDINGS(slang) class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint") .function("getName", &slang::wgsl::EntryPoint::getName, allow_raw_pointers()); - class_<slang::wgsl::CompileTargets>("CompileTargets") - .function( - "findCompileTarget", - &slang::wgsl::CompileTargets::findCompileTarget, - return_value_policy::take_ownership()); - - register_vector<slang::wgsl::ComponentType*>("ComponentTypeList"); - register_vector<std::string>("StringList"); register_optional<std::vector<std::string>>(); @@ -251,7 +234,4 @@ EMSCRIPTEN_BINDINGS(slang) "createLanguageServer", &slang::wgsl::lsp::createLanguageServer, return_value_policy::take_ownership()); - - class_<slang::wgsl::HashedString>("HashedString") - .function("getString", &slang::wgsl::HashedString::getString); }; diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp index f73e70ba0..bdc2e5e6e 100644 --- a/source/slang-wasm/slang-wasm.cpp +++ b/source/slang-wasm/slang-wasm.cpp @@ -16,7 +16,6 @@ namespace wgsl { Error g_error; -CompileTargets g_compileTargets; Error getLastError() { @@ -25,9 +24,30 @@ Error getLastError() return currentError; } -CompileTargets* getCompileTargets() +emscripten::val getCompileTargets() { - return &g_compileTargets; + struct TargetPair + { + const char* name; + SlangCompileTarget target; + }; + static const TargetPair targets[] = { + {"GLSL", SLANG_GLSL}, + {"HLSL", SLANG_HLSL}, + {"WGSL", SLANG_WGSL}, + {"SPIRV", SLANG_SPIRV}, + {"METAL", SLANG_METAL}, + }; + + std::vector<emscripten::val> result; + for (auto target : targets) + { + auto entry = emscripten::val::object(); + entry.set("name", target.name); + entry.set("value", (int)target.target); + result.push_back(entry); + } + return emscripten::val::array(result); } GlobalSession* createGlobalSession() @@ -46,35 +66,9 @@ GlobalSession* createGlobalSession() return new GlobalSession(globalSession); } -CompileTargets::CompileTargets() -{ -#define MAKE_PAIR(x) {#x, SLANG_##x} - - m_compileTargetMap = { - MAKE_PAIR(GLSL), - MAKE_PAIR(HLSL), - MAKE_PAIR(WGSL), - MAKE_PAIR(SPIRV), - MAKE_PAIR(METAL), - }; -} - -int CompileTargets::findCompileTarget(const std::string& name) -{ - auto res = m_compileTargetMap.find(name); - if (res != m_compileTargetMap.end()) - { - return res->second; - } - else - { - return SLANG_TARGET_UNKNOWN; - } -} - Session* GlobalSession::createSession(int compileTarget) { - ISession* session = nullptr; + Slang::ComPtr<ISession> session; { SessionDesc sessionDesc = {}; sessionDesc.structureSize = sizeof(sessionDesc); @@ -83,7 +77,7 @@ Session* GlobalSession::createSession(int compileTarget) target.format = (SlangCompileTarget)compileTarget; sessionDesc.targets = ⌖ sessionDesc.targetCount = targetCount; - SlangResult result = m_interface->createSession(sessionDesc, &session); + SlangResult result = m_interface->createSession(sessionDesc, session.writeRef()); if (result != SLANG_OK) { g_error.type = std::string("USER"); @@ -95,12 +89,19 @@ Session* GlobalSession::createSession(int compileTarget) return new Session(session); } -Module* Session::loadModuleFromSource( +Session::~Session() +{ + m_componentTypes = {}; + auto refCount = static_cast<Slang::Linkage*>(m_interface.get())->debugGetReferenceCount(); + m_interface = nullptr; +} + +emscripten::val Session::loadModuleFromSource( const std::string& slangCode, const std::string& name, const std::string& path) { - Slang::ComPtr<IModule> module; + IModule* module = nullptr; { Slang::ComPtr<slang::IBlob> diagnosticsBlob; Slang::ComPtr<ISlangBlob> slangCodeBlob = @@ -116,14 +117,13 @@ Module* Session::loadModuleFromSource( g_error.message = std::string( (char*)diagnosticsBlob->getBufferPointer(), (char*)diagnosticsBlob->getBufferPointer() + diagnosticsBlob->getBufferSize()); - return nullptr; + return emscripten::val::null(); } } - - return new Module(module); + return emscripten::val(Module(module, this)); } -EntryPoint* Module::findEntryPointByName(const std::string& name) +emscripten::val Module::findEntryPointByName(const std::string& name) { Slang::ComPtr<IEntryPoint> entryPoint; { @@ -133,15 +133,15 @@ EntryPoint* Module::findEntryPointByName(const std::string& name) { g_error.type = std::string("USER"); g_error.result = result; - return nullptr; + return emscripten::val::null(); } } - - return new EntryPoint(entryPoint); + m_session->addComponentType(entryPoint.get()); + return emscripten::val(EntryPoint(entryPoint.get(), m_session)); } -EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage) +emscripten::val Module::findAndCheckEntryPoint(const std::string& name, int stage) { Slang::ComPtr<IEntryPoint> entryPoint; { @@ -161,11 +161,11 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage) char* diagnostics = (char*)diagnosticsBlob->getBufferPointer(); g_error.message = std::string(diagnostics); } - return nullptr; + return emscripten::val::null(); } } - - return new EntryPoint(entryPoint); + m_session->addComponentType(entryPoint.get()); + return emscripten::val(EntryPoint(entryPoint.get(), m_session)); } int Module::getDefinedEntryPointCount() @@ -173,10 +173,10 @@ int Module::getDefinedEntryPointCount() return moduleInterface()->getDefinedEntryPointCount(); } -EntryPoint* Module::getDefinedEntryPoint(int index) +emscripten::val Module::getDefinedEntryPoint(int index) { if (moduleInterface()->getDefinedEntryPointCount() <= index) - return nullptr; + return emscripten::val::null(); Slang::ComPtr<IEntryPoint> entryPoint; { @@ -192,21 +192,37 @@ EntryPoint* Module::getDefinedEntryPoint(int index) char* diagnostics = (char*)diagnosticsBlob->getBufferPointer(); g_error.message = std::string(diagnostics); } - return nullptr; + return emscripten::val::null(); } } - - return new EntryPoint(entryPoint); + m_session->addComponentType(entryPoint.get()); + return emscripten::val(EntryPoint(entryPoint.get(), m_session)); } -ComponentType* Session::createCompositeComponentType(const std::vector<ComponentType*>& components) +emscripten::val Session::createCompositeComponentType(emscripten::val components) { + if (!components.isArray()) + { + g_error.type = std::string("Slang WASM Bind"); + g_error.message = std::string("createCompositeComponentType: Components must be an array"); + return emscripten::val::null(); + } + std::vector<emscripten::val> componentsArray = + emscripten::vecFromJSArray<emscripten::val>(components); + Slang::ComPtr<IComponentType> composite; { - std::vector<IComponentType*> nativeComponents(components.size()); - for (size_t i = 0U; i < components.size(); i++) - nativeComponents[i] = components[i]->interface(); + std::vector<IComponentType*> nativeComponents; + for (size_t i = 0U; i < componentsArray.size(); i++) + { + auto componentVal = componentsArray[i]; + if (componentVal.instanceof (emscripten::val::module_property("ComponentType"))) + { + auto componentType = componentVal.as<ComponentType>(); + nativeComponents.push_back(componentType.interface()); + } + } SlangResult result = m_interface->createCompositeComponentType( nativeComponents.data(), (SlangInt)nativeComponents.size(), @@ -215,14 +231,14 @@ ComponentType* Session::createCompositeComponentType(const std::vector<Component { g_error.type = std::string("USER"); g_error.result = result; - return nullptr; + return emscripten::val::null(); } } - - return new ComponentType(composite); + addComponentType(composite.get()); + return emscripten::val(ComponentType(composite, this)); } -ComponentType* ComponentType::link() +emscripten::val ComponentType::link() { Slang::ComPtr<IComponentType> linkedProgram; { @@ -235,11 +251,11 @@ ComponentType* ComponentType::link() g_error.message = std::string( (char*)diagnosticBlob->getBufferPointer(), (char*)diagnosticBlob->getBufferPointer() + diagnosticBlob->getBufferSize()); - return nullptr; + return emscripten::val::null(); } } - - return new ComponentType(linkedProgram); + m_session->addComponentType(linkedProgram.get()); + return emscripten::val(ComponentType(linkedProgram, m_session)); } std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetIndex) @@ -344,14 +360,14 @@ emscripten::val ComponentType::getTargetCodeBlob(int targetIndex) return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), ptr)); } -HashedString* ComponentType::loadStrings() +emscripten::val ComponentType::loadStrings() { slang::ProgramLayout* slangReflection = interface()->getLayout(); if (!slangReflection) { g_error.type = std::string("USER"); g_error.message = std::string("Failed to get reflection data"); - return nullptr; + return emscripten::val::null(); } SlangUInt hashedStringCount = slangReflection->getHashedStringCount(); @@ -359,11 +375,11 @@ HashedString* ComponentType::loadStrings() { g_error.type = std::string("USER"); g_error.message = std::string("Warn: No reflection data found"); - return nullptr; + return emscripten::val::null(); } size_t stringSize = 0; - HashedString* hashedStrings = new HashedString(); + std::vector<emscripten::val> result; for (SlangUInt ii = 0; ii < hashedStringCount; ++ii) { // For each string we can fetch its bytes from the Slang @@ -381,9 +397,12 @@ HashedString* ComponentType::loadStrings() // int hash = spComputeStringHash(stringData, stringSize); - hashedStrings->insertString(hash, std::string(stringData)); + emscripten::val entry = emscripten::val::object(); + entry.set("hash", hash); + entry.set("string", std::string(stringData)); + result.push_back(entry); } - return hashedStrings; + return emscripten::val::array(result); } ProgramLayout* ComponentType::getLayout(unsigned int targetIndex) diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h index 9e1e023a9..1f2be2fb5 100644 --- a/source/slang-wasm/slang-wasm.h +++ b/source/slang-wasm/slang-wasm.h @@ -5,6 +5,25 @@ #include <slang.h> #include <unordered_map> +/** +The web assembly binding here is designed to make javascript code as simple and native as possible. +The big issue being handled here is lifetime management of objects created in the Slang API. + +The idea here is to make lifetime management as coarse grained as possible from the javascript side. +Only two types of objects need to be explicitly deleted by javascript: GlobalSession and Session. + +All the remaining objects returned by member functions of Session will have their lifetime managed +by the owning session in the C++ side. This way, the javascript code will never need to worry about +freeing small objects like ComponentType, EntryPoint, Module, TypeLayoutReflection, +VariableLayoutReflection, ProgramLayout etc. + +When a Session is no longer needed, the javascript code should explicitly delete it, this will allow +us to free all the objects we allocated from the session in one single explicit call. + +By making explicit memory management as coarse grained as possible, we are making memory management +efficient, simple, and less error prone. +*/ + namespace Slang { class LanguageServerCore; @@ -28,28 +47,9 @@ public: Error getLastError(); -class CompileTargets -{ -public: - CompileTargets(); - int findCompileTarget(const std::string& name); - -private: - std::unordered_map<std::string, SlangCompileTarget> m_compileTargetMap; -}; - -class HashedString -{ -public: - std::string getString(uint32_t hash) { return m_hashedStrings[(int)hash]; } - void insertString(int hash, const std::string& str) { m_hashedStrings[hash] = str; } - -private: - std::unordered_map<int, std::string> m_hashedStrings; -}; - -CompileTargets* getCompileTargets(); - +// returns mapping of codegen target from string to SlangCompileTarget +// in the form of [{name: STRING, value: INT}, ...]. +emscripten::val getCompileTargets(); class TypeLayoutReflection { @@ -84,37 +84,44 @@ public: slang::ProgramLayout* interface() const { return (slang::ProgramLayout*)this; } }; +class Session; class ComponentType { public: - ComponentType(slang::IComponentType* interface) - : m_interface(interface) + IComponentType* m_interface; + Session* m_session; + +public: + ComponentType(slang::IComponentType* interface, Session* session) + : m_interface(interface), m_session(session) { } - ComponentType* link(); + // Returns ComponentType or null. + emscripten::val link(); std::string getEntryPointCode(int entryPointIndex, int targetIndex); + + // Returns UInt8Array or null. emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex); std::string getTargetCode(int targetIndex); + + // Returns UInt8Array or null. emscripten::val getTargetCodeBlob(int targetIndex); slang::wgsl::ProgramLayout* getLayout(unsigned int targetIndex); slang::IComponentType* interface() const { return m_interface; } - HashedString* loadStrings(); - virtual ~ComponentType() = default; - -private: - Slang::ComPtr<slang::IComponentType> m_interface; + // returns [{hash: HASH, string: STRING}, ...] + emscripten::val loadStrings(); }; class EntryPoint : public ComponentType { public: - EntryPoint(slang::IEntryPoint* interface) - : ComponentType(interface) + EntryPoint(slang::IComponentType* interface, Session* session) + : ComponentType(interface, session) { } std::string getName() const @@ -132,14 +139,20 @@ private: class Module : public ComponentType { public: - Module(slang::IModule* interface) - : ComponentType(interface) + Module(slang::IComponentType* interface, Session* session) + : ComponentType(interface, session) { } - EntryPoint* findEntryPointByName(const std::string& name); - EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage); - EntryPoint* getDefinedEntryPoint(int index); + // Returns EntryPoint or null. + emscripten::val findEntryPointByName(const std::string& name); + + // Returns EntryPoint or null. + emscripten::val findAndCheckEntryPoint(const std::string& name, int stage); + + // Returns EntryPoint or null. + emscripten::val getDefinedEntryPoint(int index); + int getDefinedEntryPointCount(); slang::IModule* moduleInterface() const { return static_cast<slang::IModule*>(interface()); } @@ -152,17 +165,27 @@ public: : m_interface(interface) { } + ~Session(); - Module* loadModuleFromSource( + // Returns Module or null. + emscripten::val loadModuleFromSource( const std::string& slangCode, const std::string& name, const std::string& path); - ComponentType* createCompositeComponentType(const std::vector<ComponentType*>& components); + // `components` is a javascript array of ComponentType/Module/EntryPoint objects. + // Returns ComponentType or null. + emscripten::val createCompositeComponentType(emscripten::val components); slang::ISession* interface() const { return m_interface; } + void addComponentType(slang::IComponentType* componentType) + { + m_componentTypes.push_back(Slang::ComPtr<slang::IComponentType>(componentType)); + } + private: + std::vector<Slang::ComPtr<slang::IComponentType>> m_componentTypes; Slang::ComPtr<slang::ISession> m_interface; }; |
