summaryrefslogtreecommitdiff
path: root/source/slang-wasm/slang-wasm.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-11-08 16:19:31 -0800
committerGitHub <noreply@github.com>2024-11-08 16:19:31 -0800
commit5ca37c316a9ea1833907bae497734054dfa3c3cb (patch)
tree6aec7069ca410fd6d2426c36ef102303512c28e3 /source/slang-wasm/slang-wasm.cpp
parent7c414463063b979afb0b5184a48a13fcaf5b8af7 (diff)
Use automatic coarse grained memory management in wasm binding. (#5528)
Diffstat (limited to 'source/slang-wasm/slang-wasm.cpp')
-rw-r--r--source/slang-wasm/slang-wasm.cpp151
1 files changed, 85 insertions, 66 deletions
diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp
index f73e70ba0..bdc2e5e6e 100644
--- a/source/slang-wasm/slang-wasm.cpp
+++ b/source/slang-wasm/slang-wasm.cpp
@@ -16,7 +16,6 @@ namespace wgsl
{
Error g_error;
-CompileTargets g_compileTargets;
Error getLastError()
{
@@ -25,9 +24,30 @@ Error getLastError()
return currentError;
}
-CompileTargets* getCompileTargets()
+emscripten::val getCompileTargets()
{
- return &g_compileTargets;
+ struct TargetPair
+ {
+ const char* name;
+ SlangCompileTarget target;
+ };
+ static const TargetPair targets[] = {
+ {"GLSL", SLANG_GLSL},
+ {"HLSL", SLANG_HLSL},
+ {"WGSL", SLANG_WGSL},
+ {"SPIRV", SLANG_SPIRV},
+ {"METAL", SLANG_METAL},
+ };
+
+ std::vector<emscripten::val> result;
+ for (auto target : targets)
+ {
+ auto entry = emscripten::val::object();
+ entry.set("name", target.name);
+ entry.set("value", (int)target.target);
+ result.push_back(entry);
+ }
+ return emscripten::val::array(result);
}
GlobalSession* createGlobalSession()
@@ -46,35 +66,9 @@ GlobalSession* createGlobalSession()
return new GlobalSession(globalSession);
}
-CompileTargets::CompileTargets()
-{
-#define MAKE_PAIR(x) {#x, SLANG_##x}
-
- m_compileTargetMap = {
- MAKE_PAIR(GLSL),
- MAKE_PAIR(HLSL),
- MAKE_PAIR(WGSL),
- MAKE_PAIR(SPIRV),
- MAKE_PAIR(METAL),
- };
-}
-
-int CompileTargets::findCompileTarget(const std::string& name)
-{
- auto res = m_compileTargetMap.find(name);
- if (res != m_compileTargetMap.end())
- {
- return res->second;
- }
- else
- {
- return SLANG_TARGET_UNKNOWN;
- }
-}
-
Session* GlobalSession::createSession(int compileTarget)
{
- ISession* session = nullptr;
+ Slang::ComPtr<ISession> session;
{
SessionDesc sessionDesc = {};
sessionDesc.structureSize = sizeof(sessionDesc);
@@ -83,7 +77,7 @@ Session* GlobalSession::createSession(int compileTarget)
target.format = (SlangCompileTarget)compileTarget;
sessionDesc.targets = &target;
sessionDesc.targetCount = targetCount;
- SlangResult result = m_interface->createSession(sessionDesc, &session);
+ SlangResult result = m_interface->createSession(sessionDesc, session.writeRef());
if (result != SLANG_OK)
{
g_error.type = std::string("USER");
@@ -95,12 +89,19 @@ Session* GlobalSession::createSession(int compileTarget)
return new Session(session);
}
-Module* Session::loadModuleFromSource(
+Session::~Session()
+{
+ m_componentTypes = {};
+ auto refCount = static_cast<Slang::Linkage*>(m_interface.get())->debugGetReferenceCount();
+ m_interface = nullptr;
+}
+
+emscripten::val Session::loadModuleFromSource(
const std::string& slangCode,
const std::string& name,
const std::string& path)
{
- Slang::ComPtr<IModule> module;
+ IModule* module = nullptr;
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
Slang::ComPtr<ISlangBlob> slangCodeBlob =
@@ -116,14 +117,13 @@ Module* Session::loadModuleFromSource(
g_error.message = std::string(
(char*)diagnosticsBlob->getBufferPointer(),
(char*)diagnosticsBlob->getBufferPointer() + diagnosticsBlob->getBufferSize());
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new Module(module);
+ return emscripten::val(Module(module, this));
}
-EntryPoint* Module::findEntryPointByName(const std::string& name)
+emscripten::val Module::findEntryPointByName(const std::string& name)
{
Slang::ComPtr<IEntryPoint> entryPoint;
{
@@ -133,15 +133,15 @@ EntryPoint* Module::findEntryPointByName(const std::string& name)
{
g_error.type = std::string("USER");
g_error.result = result;
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new EntryPoint(entryPoint);
+ m_session->addComponentType(entryPoint.get());
+ return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}
-EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
+emscripten::val Module::findAndCheckEntryPoint(const std::string& name, int stage)
{
Slang::ComPtr<IEntryPoint> entryPoint;
{
@@ -161,11 +161,11 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new EntryPoint(entryPoint);
+ m_session->addComponentType(entryPoint.get());
+ return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}
int Module::getDefinedEntryPointCount()
@@ -173,10 +173,10 @@ int Module::getDefinedEntryPointCount()
return moduleInterface()->getDefinedEntryPointCount();
}
-EntryPoint* Module::getDefinedEntryPoint(int index)
+emscripten::val Module::getDefinedEntryPoint(int index)
{
if (moduleInterface()->getDefinedEntryPointCount() <= index)
- return nullptr;
+ return emscripten::val::null();
Slang::ComPtr<IEntryPoint> entryPoint;
{
@@ -192,21 +192,37 @@ EntryPoint* Module::getDefinedEntryPoint(int index)
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new EntryPoint(entryPoint);
+ m_session->addComponentType(entryPoint.get());
+ return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}
-ComponentType* Session::createCompositeComponentType(const std::vector<ComponentType*>& components)
+emscripten::val Session::createCompositeComponentType(emscripten::val components)
{
+ if (!components.isArray())
+ {
+ g_error.type = std::string("Slang WASM Bind");
+ g_error.message = std::string("createCompositeComponentType: Components must be an array");
+ return emscripten::val::null();
+ }
+ std::vector<emscripten::val> componentsArray =
+ emscripten::vecFromJSArray<emscripten::val>(components);
+
Slang::ComPtr<IComponentType> composite;
{
- std::vector<IComponentType*> nativeComponents(components.size());
- for (size_t i = 0U; i < components.size(); i++)
- nativeComponents[i] = components[i]->interface();
+ std::vector<IComponentType*> nativeComponents;
+ for (size_t i = 0U; i < componentsArray.size(); i++)
+ {
+ auto componentVal = componentsArray[i];
+ if (componentVal.instanceof (emscripten::val::module_property("ComponentType")))
+ {
+ auto componentType = componentVal.as<ComponentType>();
+ nativeComponents.push_back(componentType.interface());
+ }
+ }
SlangResult result = m_interface->createCompositeComponentType(
nativeComponents.data(),
(SlangInt)nativeComponents.size(),
@@ -215,14 +231,14 @@ ComponentType* Session::createCompositeComponentType(const std::vector<Component
{
g_error.type = std::string("USER");
g_error.result = result;
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new ComponentType(composite);
+ addComponentType(composite.get());
+ return emscripten::val(ComponentType(composite, this));
}
-ComponentType* ComponentType::link()
+emscripten::val ComponentType::link()
{
Slang::ComPtr<IComponentType> linkedProgram;
{
@@ -235,11 +251,11 @@ ComponentType* ComponentType::link()
g_error.message = std::string(
(char*)diagnosticBlob->getBufferPointer(),
(char*)diagnosticBlob->getBufferPointer() + diagnosticBlob->getBufferSize());
- return nullptr;
+ return emscripten::val::null();
}
}
-
- return new ComponentType(linkedProgram);
+ m_session->addComponentType(linkedProgram.get());
+ return emscripten::val(ComponentType(linkedProgram, m_session));
}
std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetIndex)
@@ -344,14 +360,14 @@ emscripten::val ComponentType::getTargetCodeBlob(int targetIndex)
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), ptr));
}
-HashedString* ComponentType::loadStrings()
+emscripten::val ComponentType::loadStrings()
{
slang::ProgramLayout* slangReflection = interface()->getLayout();
if (!slangReflection)
{
g_error.type = std::string("USER");
g_error.message = std::string("Failed to get reflection data");
- return nullptr;
+ return emscripten::val::null();
}
SlangUInt hashedStringCount = slangReflection->getHashedStringCount();
@@ -359,11 +375,11 @@ HashedString* ComponentType::loadStrings()
{
g_error.type = std::string("USER");
g_error.message = std::string("Warn: No reflection data found");
- return nullptr;
+ return emscripten::val::null();
}
size_t stringSize = 0;
- HashedString* hashedStrings = new HashedString();
+ std::vector<emscripten::val> result;
for (SlangUInt ii = 0; ii < hashedStringCount; ++ii)
{
// For each string we can fetch its bytes from the Slang
@@ -381,9 +397,12 @@ HashedString* ComponentType::loadStrings()
//
int hash = spComputeStringHash(stringData, stringSize);
- hashedStrings->insertString(hash, std::string(stringData));
+ emscripten::val entry = emscripten::val::object();
+ entry.set("hash", hash);
+ entry.set("string", std::string(stringData));
+ result.push_back(entry);
}
- return hashedStrings;
+ return emscripten::val::array(result);
}
ProgramLayout* ComponentType::getLayout(unsigned int targetIndex)