summaryrefslogtreecommitdiff
path: root/source/slang-capture-replay/slang-module.cpp
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2024-05-17 11:43:08 -0500
committerGitHub <noreply@github.com>2024-05-17 09:43:08 -0700
commit9f786fdf71e90339e20979ef3ba8f073657f5a98 (patch)
treedac01aa41224373527d2ce793a82c34e142a7b05 /source/slang-capture-replay/slang-module.cpp
parent000396cdc4b00f7f8bf92a6569367e5cb9d6ba27 (diff)
capture/relay: Add capture interface classes (#4177)
* capture/relay: Add capture interface classes Add `ModuleCapture` class for capturing `IModule` - The `IModule` can only be created from -- `ISession::loadModule` -- `ISession::loadModuleFromIRBlob` -- `ISession::loadModuleFromSource` -- `ISession::loadModuleFromSourceString` so, we create the `ModuleCapture` at those methods in `SessionCapture` class. We use a hash map to store a map from `IModule` to `ModuleCapture` to avoid creating new `ModuleCapture` when there is already an old one. - In `SessionCapture::getLoadedModule`, we will assert on not finding a `ModuleCapture` instance. Add `EntryPointCapture` class for capturing `IEntryPoint`. - The `IEntryPoint` can only be created from: -- `IModule::findEntryPointByName` -- `IModule::findAndCheckEntryPoint` so, we create the `EntryPointCapture` at those methods in `ModuleCapture`. Similarly, we use a hash map to store a map from `IEntryPoint` to `EntryPointCapture`. - In `IModule::getDefinedEntryPoint`, we will assert on not finding a `EntryPointCapture` instance. Add `CompositeComponentTypeCapture` class for capturing CompositeComponentType, but since user is only exposed to `IComponentType`, so `CompositeComponentTypeCapture` just inherits from `IComponentType`. - `CompositeComponentType` can only be created from: -- ISession::createCompositeComponentType so create it here. Add `TypeConformanceCapture` class for capturing `ITypeConformance`. - The `ITypeConformance` can only be created from: -- `ISession::createTypeConformanceComponentType` so create it here. In addition, because `EntryPointCapture` and `ModuleCapture` share a some base class `IComponentType`, we generate the COM GUID for those two classes to differentiate them. * Fix the build issue * Add nullptr check for output parameter * define the SLANG_CAPTURE_ASSERT macro used in both debug and release build
Diffstat (limited to 'source/slang-capture-replay/slang-module.cpp')
-rw-r--r--source/slang-capture-replay/slang-module.cpp239
1 files changed, 239 insertions, 0 deletions
diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp
new file mode 100644
index 000000000..31aa6bd54
--- /dev/null
+++ b/source/slang-capture-replay/slang-module.cpp
@@ -0,0 +1,239 @@
+#include "capture_utility.h"
+#include "slang-module.h"
+
+namespace SlangCapture
+{
+ ModuleCapture::ModuleCapture(slang::IModule* module)
+ : m_actualModule(module)
+ {
+ SLANG_CAPTURE_ASSERT(m_actualModule != nullptr);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, module);
+ }
+
+ ModuleCapture::~ModuleCapture()
+ {
+ m_actualModule->release();
+ }
+
+ ISlangUnknown* ModuleCapture::getInterface(const Guid& guid)
+ {
+ if(guid == ModuleCapture::getTypeGuid())
+ return static_cast<ISlangUnknown*>(this);
+ else
+ return nullptr;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName(
+ char const* name,
+ slang::IEntryPoint** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+
+ SlangResult res = m_actualModule->findEntryPointByName(name, outEntryPoint);
+
+ if (SLANG_OK == res)
+ {
+ EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint);
+ *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture);
+ }
+ return res;
+ }
+
+ SLANG_NO_THROW SlangInt32 ModuleCapture::getDefinedEntryPointCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt32 res = m_actualModule->getDefinedEntryPointCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint)
+ {
+ // This call is to find the existing entry point, so it has been created already. Therefore, we don't create a new one
+ // and assert the error if it is not found in our map.
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getDefinedEntryPoint(index, outEntryPoint);
+
+ if (*outEntryPoint)
+ {
+ EntryPointCapture* entryPointCapture = m_mapEntryPointToCapture.tryGetValue(*outEntryPoint);
+ if (!entryPointCapture)
+ {
+ SLANG_CAPTURE_ASSERT(!"Entrypoint not found in mapEntryPointToCapture");
+ }
+ *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture);
+ }
+ else
+ *outEntryPoint = nullptr;
+
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::serialize(ISlangBlob** outSerializedBlob)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->serialize(outSerializedBlob);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::writeToFile(char const* fileName)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->writeToFile(fileName);
+ return res;
+ }
+
+ SLANG_NO_THROW const char* ModuleCapture::getName()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ const char* res = m_actualModule->getName();
+ return res;
+ }
+
+ SLANG_NO_THROW const char* ModuleCapture::getFilePath()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ const char* res = m_actualModule->getFilePath();
+ return res;
+ }
+
+ SLANG_NO_THROW const char* ModuleCapture::getUniqueIdentity()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ const char* res = m_actualModule->getUniqueIdentity();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::findAndCheckEntryPoint(
+ char const* name,
+ SlangStage stage,
+ slang::IEntryPoint** outEntryPoint,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+
+ SlangResult res = m_actualModule->findAndCheckEntryPoint(name, stage, outEntryPoint, outDiagnostics);
+
+ if (SLANG_OK == res)
+ {
+ EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint);
+ *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture);
+ }
+ return res;
+ }
+
+ SLANG_NO_THROW slang::ISession* ModuleCapture::getSession()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ISession* session = m_actualModule->getSession();
+ return session;
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* ModuleCapture::getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ProgramLayout* programLayout = m_actualModule->getLayout(targetIndex, outDiagnostics);
+ return programLayout;
+ }
+
+ SLANG_NO_THROW SlangInt ModuleCapture::getSpecializationParamCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt res = m_actualModule->getSpecializationParamCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem);
+ return res;
+ }
+
+ SLANG_NO_THROW void ModuleCapture::getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ m_actualModule->getEntryPointHash(entryPointIndex, targetIndex, outHash);
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::link(
+ IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->link(outLinkedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->renameEntryPoint(newName, outEntryPoint);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics);
+ return res;
+ }
+
+ EntryPointCapture* ModuleCapture::getEntryPointCapture(slang::IEntryPoint* entryPoint)
+ {
+ EntryPointCapture* entryPointCapture = nullptr;
+ entryPointCapture = m_mapEntryPointToCapture.tryGetValue(entryPoint);
+ if (!entryPointCapture)
+ {
+ entryPointCapture = new EntryPointCapture(entryPoint);
+ Slang::ComPtr<EntryPointCapture> result(entryPointCapture);
+ m_mapEntryPointToCapture.add(entryPoint, *result.detach());
+ }
+ return entryPointCapture;
+ }
+}