diff options
| author | Yong He <yonghe@outlook.com> | 2024-10-28 08:53:41 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-28 08:53:41 -0700 |
| commit | 04329077988a2b1f7a87b1d116457599039e5e12 (patch) | |
| tree | 86b29e773f04f0926cc7685766d0a0b135c90500 /source | |
| parent | a3276e2876be00fd4b0a69e47a66b1cff29765f2 (diff) | |
More wasm binding for playground. (#5420)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang-wasm/slang-wasm-bindings.cpp | 25 | ||||
| -rw-r--r-- | source/slang-wasm/slang-wasm.cpp | 96 | ||||
| -rw-r--r-- | source/slang-wasm/slang-wasm.h | 15 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 8 |
4 files changed, 130 insertions, 14 deletions
diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp index 360dec6eb..56a472482 100644 --- a/source/slang-wasm/slang-wasm-bindings.cpp +++ b/source/slang-wasm/slang-wasm-bindings.cpp @@ -47,8 +47,14 @@ EMSCRIPTEN_BINDINGS(slang) "getEntryPointCode", &slang::wgsl::ComponentType::getEntryPointCode) .function( - "getEntryPointCodeSpirv", - &slang::wgsl::ComponentType::getEntryPointCodeSpirv); + "getEntryPointCodeBlob", + &slang::wgsl::ComponentType::getEntryPointCodeBlob) + .function( + "getTargetCodeBlob", + &slang::wgsl::ComponentType::getTargetCodeBlob) + .function( + "getTargetCode", + &slang::wgsl::ComponentType::getTargetCode); class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module") .function( @@ -58,14 +64,25 @@ EMSCRIPTEN_BINDINGS(slang) .function( "findAndCheckEntryPoint", &slang::wgsl::Module::findAndCheckEntryPoint, - return_value_policy::take_ownership()); + return_value_policy::take_ownership()) + .function( + "getDefinedEntryPoint", + &slang::wgsl::Module::getDefinedEntryPoint, + return_value_policy::take_ownership()) + .function( + "getDefinedEntryPointCount", + &slang::wgsl::Module::getDefinedEntryPointCount); value_object<slang::wgsl::Error>("Error") .field("type", &slang::wgsl::Error::type) .field("result", &slang::wgsl::Error::result) .field("message", &slang::wgsl::Error::message); - class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint"); + class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint") + .function( + "getName", + &slang::wgsl::EntryPoint::getName, + allow_raw_pointers()); class_<slang::wgsl::CompileTargets>("CompileTargets") .function( diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp index 886efc5f7..8948c1075 100644 --- a/source/slang-wasm/slang-wasm.cpp +++ b/source/slang-wasm/slang-wasm.cpp @@ -94,17 +94,15 @@ Session* GlobalSession::createSession(int compileTarget) return new Session(session); } -Module* Session::loadModuleFromSource(const std::string& slangCode) +Module* Session::loadModuleFromSource(const std::string& slangCode, const std::string& name, const std::string& path) { Slang::ComPtr<IModule> module; { - const char * name = ""; - const char * path = ""; Slang::ComPtr<slang::IBlob> diagnosticsBlob; Slang::ComPtr<ISlangBlob> slangCodeBlob = Slang::RawBlob::create( slangCode.c_str(), slangCode.size()); module = m_interface->loadModuleFromSource( - name, path, slangCodeBlob, diagnosticsBlob.writeRef()); + name.c_str(), path.c_str(), slangCodeBlob, diagnosticsBlob.writeRef()); if (!module) { g_error.type = std::string("USER"); @@ -161,6 +159,38 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage) return new EntryPoint(entryPoint); } +int Module::getDefinedEntryPointCount() +{ + return moduleInterface()->getDefinedEntryPointCount(); +} + +EntryPoint* Module::getDefinedEntryPoint(int index) +{ + if (moduleInterface()->getDefinedEntryPointCount() <= index) + return nullptr; + + Slang::ComPtr<IEntryPoint> entryPoint; + { + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + SlangResult result = moduleInterface()->getDefinedEntryPoint(index, entryPoint.writeRef()); + if (!SLANG_SUCCEEDED(result)) + { + g_error.type = std::string("USER"); + g_error.result = result; + + if (diagnosticsBlob->getBufferSize()) + { + char* diagnostics = (char*)diagnosticsBlob->getBufferPointer(); + g_error.message = std::string(diagnostics); + } + return nullptr; + } + } + + return new EntryPoint(entryPoint); +} + + ComponentType* Session::createCompositeComponentType( const std::vector<ComponentType*>& components) { @@ -235,9 +265,9 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde return {}; } -// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val +// Since result code is binary, we can't return it as a string, we will need to use emscripten::val // to wrap it and return it to the javascript side. -emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex) +emscripten::val ComponentType::getEntryPointCodeBlob(int entryPointIndex, int targetIndex) { Slang::ComPtr<IBlob> kernelBlob; Slang::ComPtr<ISlangBlob> diagnosticBlob; @@ -262,6 +292,60 @@ emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int t ptr)); } +std::string ComponentType::getTargetCode(int targetIndex) +{ + { + Slang::ComPtr<IBlob> kernelBlob; + Slang::ComPtr<ISlangBlob> diagnosticBlob; + SlangResult result = interface()->getTargetCode( + targetIndex, + kernelBlob.writeRef(), + diagnosticBlob.writeRef()); + if (result != SLANG_OK) + { + g_error.type = std::string("USER"); + g_error.result = result; + g_error.message = std::string( + (char*)diagnosticBlob->getBufferPointer(), + (char*)diagnosticBlob->getBufferPointer() + + diagnosticBlob->getBufferSize()); + return ""; + } + std::string targetCode = std::string( + (char*)kernelBlob->getBufferPointer(), + (char*)kernelBlob->getBufferPointer() + kernelBlob->getBufferSize()); + return targetCode; + } + + return {}; +} + +// Since result code is binary, we can't return it as a string, we will need to use emscripten::val +// to wrap it and return it to the javascript side. +emscripten::val ComponentType::getTargetCodeBlob(int targetIndex) +{ + Slang::ComPtr<IBlob> kernelBlob; + Slang::ComPtr<ISlangBlob> diagnosticBlob; + SlangResult result = interface()->getTargetCode( + targetIndex, + kernelBlob.writeRef(), + diagnosticBlob.writeRef()); + if (result != SLANG_OK) + { + g_error.type = std::string("USER"); + g_error.result = result; + g_error.message = std::string( + (char*)diagnosticBlob->getBufferPointer(), + (char*)diagnosticBlob->getBufferPointer() + + diagnosticBlob->getBufferSize()); + return {}; + } + + const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer(); + return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), + ptr)); +} + namespace lsp { Position translate(Slang::LanguageServerProtocol::Position p) diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h index 5a299453c..eb302119b 100644 --- a/source/slang-wasm/slang-wasm.h +++ b/source/slang-wasm/slang-wasm.h @@ -48,7 +48,9 @@ public: ComponentType* link(); std::string getEntryPointCode(int entryPointIndex, int targetIndex); - emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex); + emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex); + std::string getTargetCode(int targetIndex); + emscripten::val getTargetCodeBlob(int targetIndex); slang::IComponentType* interface() const {return m_interface;} @@ -62,9 +64,11 @@ private: class EntryPoint : public ComponentType { public: - EntryPoint(slang::IEntryPoint* interface) : ComponentType(interface) {} - + std::string getName() const + { + return entryPointInterface()->getFunctionReflection()->getName(); + } private: slang::IEntryPoint* entryPointInterface() const { @@ -80,6 +84,8 @@ public: EntryPoint* findEntryPointByName(const std::string& name); EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage); + EntryPoint* getDefinedEntryPoint(int index); + int getDefinedEntryPointCount(); slang::IModule* moduleInterface() const { return static_cast<slang::IModule*>(interface()); @@ -93,7 +99,8 @@ public: Session(slang::ISession* interface) : m_interface(interface) {} - Module* loadModuleFromSource(const std::string& slangCode); + Module* loadModuleFromSource( + const std::string& slangCode, const std::string& name, const std::string& path); ComponentType* createCompositeComponentType( const std::vector<ComponentType*>& components); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b77a3efc8..b3ac7f73d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -5040,13 +5040,21 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD }); List<RefPtr<ComponentType>> components; components.add(this); + bool entryPointsDiscovered = false; for (auto module : modules) { for (auto entryPoint : module->getEntryPoints()) { components.add(entryPoint); + entryPointsDiscovered = true; } } + // If no entry points were discovered, then we should return nullptr. + if (!entryPointsDiscovered) + { + return nullptr; + } + RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components); ComPtr<IComponentType> linkedComponentType; SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics)); |
