diff options
| author | Yong He <yonghe@outlook.com> | 2024-11-08 16:19:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-08 16:19:31 -0800 |
| commit | 5ca37c316a9ea1833907bae497734054dfa3c3cb (patch) | |
| tree | 6aec7069ca410fd6d2426c36ef102303512c28e3 /source/slang-wasm/slang-wasm.cpp | |
| parent | 7c414463063b979afb0b5184a48a13fcaf5b8af7 (diff) | |
Use automatic coarse grained memory management in wasm binding. (#5528)
Diffstat (limited to 'source/slang-wasm/slang-wasm.cpp')
| -rw-r--r-- | source/slang-wasm/slang-wasm.cpp | 151 |
1 files changed, 85 insertions, 66 deletions
diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp index f73e70ba0..bdc2e5e6e 100644 --- a/source/slang-wasm/slang-wasm.cpp +++ b/source/slang-wasm/slang-wasm.cpp @@ -16,7 +16,6 @@ namespace wgsl { Error g_error; -CompileTargets g_compileTargets; Error getLastError() { @@ -25,9 +24,30 @@ Error getLastError() return currentError; } -CompileTargets* getCompileTargets() +emscripten::val getCompileTargets() { - return &g_compileTargets; + struct TargetPair + { + const char* name; + SlangCompileTarget target; + }; + static const TargetPair targets[] = { + {"GLSL", SLANG_GLSL}, + {"HLSL", SLANG_HLSL}, + {"WGSL", SLANG_WGSL}, + {"SPIRV", SLANG_SPIRV}, + {"METAL", SLANG_METAL}, + }; + + std::vector<emscripten::val> result; + for (auto target : targets) + { + auto entry = emscripten::val::object(); + entry.set("name", target.name); + entry.set("value", (int)target.target); + result.push_back(entry); + } + return emscripten::val::array(result); } GlobalSession* createGlobalSession() @@ -46,35 +66,9 @@ GlobalSession* createGlobalSession() return new GlobalSession(globalSession); } -CompileTargets::CompileTargets() -{ -#define MAKE_PAIR(x) {#x, SLANG_##x} - - m_compileTargetMap = { - MAKE_PAIR(GLSL), - MAKE_PAIR(HLSL), - MAKE_PAIR(WGSL), - MAKE_PAIR(SPIRV), - MAKE_PAIR(METAL), - }; -} - -int CompileTargets::findCompileTarget(const std::string& name) -{ - auto res = m_compileTargetMap.find(name); - if (res != m_compileTargetMap.end()) - { - return res->second; - } - else - { - return SLANG_TARGET_UNKNOWN; - } -} - Session* GlobalSession::createSession(int compileTarget) { - ISession* session = nullptr; + Slang::ComPtr<ISession> session; { SessionDesc sessionDesc = {}; sessionDesc.structureSize = sizeof(sessionDesc); @@ -83,7 +77,7 @@ Session* GlobalSession::createSession(int compileTarget) target.format = (SlangCompileTarget)compileTarget; sessionDesc.targets = ⌖ sessionDesc.targetCount = targetCount; - SlangResult result = m_interface->createSession(sessionDesc, &session); + SlangResult result = m_interface->createSession(sessionDesc, session.writeRef()); if (result != SLANG_OK) { g_error.type = std::string("USER"); @@ -95,12 +89,19 @@ Session* GlobalSession::createSession(int compileTarget) return new Session(session); } -Module* Session::loadModuleFromSource( +Session::~Session() +{ + m_componentTypes = {}; + auto refCount = static_cast<Slang::Linkage*>(m_interface.get())->debugGetReferenceCount(); + m_interface = nullptr; +} + +emscripten::val Session::loadModuleFromSource( const std::string& slangCode, const std::string& name, const std::string& path) { - Slang::ComPtr<IModule> module; + IModule* module = nullptr; { Slang::ComPtr<slang::IBlob> diagnosticsBlob; Slang::ComPtr<ISlangBlob> slangCodeBlob = @@ -116,14 +117,13 @@ Module* Session::loadModuleFromSource( g_error.message = std::string( (char*)diagnosticsBlob->getBufferPointer(), (char*)diagnosticsBlob->getBufferPointer() + diagnosticsBlob->getBufferSize()); - return nullptr; + return emscripten::val::null(); } } - - return new Module(module); + return emscripten::val(Module(module, this)); } -EntryPoint* Module::findEntryPointByName(const std::string& name) +emscripten::val Module::findEntryPointByName(const std::string& name) { Slang::ComPtr<IEntryPoint> entryPoint; { @@ -133,15 +133,15 @@ EntryPoint* Module::findEntryPointByName(const std::string& name) { g_error.type = std::string("USER"); g_error.result = result; - return nullptr; + return emscripten::val::null(); } } - - return new EntryPoint(entryPoint); + m_session->addComponentType(entryPoint.get()); + return emscripten::val(EntryPoint(entryPoint.get(), m_session)); } -EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage) +emscripten::val Module::findAndCheckEntryPoint(const std::string& name, int stage) { Slang::ComPtr<IEntryPoint> entryPoint; { @@ -161,11 +161,11 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage) char* diagnostics = (char*)diagnosticsBlob->getBufferPointer(); g_error.message = std::string(diagnostics); } - return nullptr; + return emscripten::val::null(); } } - - return new EntryPoint(entryPoint); + m_session->addComponentType(entryPoint.get()); + return emscripten::val(EntryPoint(entryPoint.get(), m_session)); } int Module::getDefinedEntryPointCount() @@ -173,10 +173,10 @@ int Module::getDefinedEntryPointCount() return moduleInterface()->getDefinedEntryPointCount(); } -EntryPoint* Module::getDefinedEntryPoint(int index) +emscripten::val Module::getDefinedEntryPoint(int index) { if (moduleInterface()->getDefinedEntryPointCount() <= index) - return nullptr; + return emscripten::val::null(); Slang::ComPtr<IEntryPoint> entryPoint; { @@ -192,21 +192,37 @@ EntryPoint* Module::getDefinedEntryPoint(int index) char* diagnostics = (char*)diagnosticsBlob->getBufferPointer(); g_error.message = std::string(diagnostics); } - return nullptr; + return emscripten::val::null(); } } - - return new EntryPoint(entryPoint); + m_session->addComponentType(entryPoint.get()); + return emscripten::val(EntryPoint(entryPoint.get(), m_session)); } -ComponentType* Session::createCompositeComponentType(const std::vector<ComponentType*>& components) +emscripten::val Session::createCompositeComponentType(emscripten::val components) { + if (!components.isArray()) + { + g_error.type = std::string("Slang WASM Bind"); + g_error.message = std::string("createCompositeComponentType: Components must be an array"); + return emscripten::val::null(); + } + std::vector<emscripten::val> componentsArray = + emscripten::vecFromJSArray<emscripten::val>(components); + Slang::ComPtr<IComponentType> composite; { - std::vector<IComponentType*> nativeComponents(components.size()); - for (size_t i = 0U; i < components.size(); i++) - nativeComponents[i] = components[i]->interface(); + std::vector<IComponentType*> nativeComponents; + for (size_t i = 0U; i < componentsArray.size(); i++) + { + auto componentVal = componentsArray[i]; + if (componentVal.instanceof (emscripten::val::module_property("ComponentType"))) + { + auto componentType = componentVal.as<ComponentType>(); + nativeComponents.push_back(componentType.interface()); + } + } SlangResult result = m_interface->createCompositeComponentType( nativeComponents.data(), (SlangInt)nativeComponents.size(), @@ -215,14 +231,14 @@ ComponentType* Session::createCompositeComponentType(const std::vector<Component { g_error.type = std::string("USER"); g_error.result = result; - return nullptr; + return emscripten::val::null(); } } - - return new ComponentType(composite); + addComponentType(composite.get()); + return emscripten::val(ComponentType(composite, this)); } -ComponentType* ComponentType::link() +emscripten::val ComponentType::link() { Slang::ComPtr<IComponentType> linkedProgram; { @@ -235,11 +251,11 @@ ComponentType* ComponentType::link() g_error.message = std::string( (char*)diagnosticBlob->getBufferPointer(), (char*)diagnosticBlob->getBufferPointer() + diagnosticBlob->getBufferSize()); - return nullptr; + return emscripten::val::null(); } } - - return new ComponentType(linkedProgram); + m_session->addComponentType(linkedProgram.get()); + return emscripten::val(ComponentType(linkedProgram, m_session)); } std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetIndex) @@ -344,14 +360,14 @@ emscripten::val ComponentType::getTargetCodeBlob(int targetIndex) return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), ptr)); } -HashedString* ComponentType::loadStrings() +emscripten::val ComponentType::loadStrings() { slang::ProgramLayout* slangReflection = interface()->getLayout(); if (!slangReflection) { g_error.type = std::string("USER"); g_error.message = std::string("Failed to get reflection data"); - return nullptr; + return emscripten::val::null(); } SlangUInt hashedStringCount = slangReflection->getHashedStringCount(); @@ -359,11 +375,11 @@ HashedString* ComponentType::loadStrings() { g_error.type = std::string("USER"); g_error.message = std::string("Warn: No reflection data found"); - return nullptr; + return emscripten::val::null(); } size_t stringSize = 0; - HashedString* hashedStrings = new HashedString(); + std::vector<emscripten::val> result; for (SlangUInt ii = 0; ii < hashedStringCount; ++ii) { // For each string we can fetch its bytes from the Slang @@ -381,9 +397,12 @@ HashedString* ComponentType::loadStrings() // int hash = spComputeStringHash(stringData, stringSize); - hashedStrings->insertString(hash, std::string(stringData)); + emscripten::val entry = emscripten::val::object(); + entry.set("hash", hash); + entry.set("string", std::string(stringData)); + result.push_back(entry); } - return hashedStrings; + return emscripten::val::array(result); } ProgramLayout* ComponentType::getLayout(unsigned int targetIndex) |
