From 9f786fdf71e90339e20979ef3ba8f073657f5a98 Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Fri, 17 May 2024 11:43:08 -0500 Subject: capture/relay: Add capture interface classes (#4177) * capture/relay: Add capture interface classes Add `ModuleCapture` class for capturing `IModule` - The `IModule` can only be created from -- `ISession::loadModule` -- `ISession::loadModuleFromIRBlob` -- `ISession::loadModuleFromSource` -- `ISession::loadModuleFromSourceString` so, we create the `ModuleCapture` at those methods in `SessionCapture` class. We use a hash map to store a map from `IModule` to `ModuleCapture` to avoid creating new `ModuleCapture` when there is already an old one. - In `SessionCapture::getLoadedModule`, we will assert on not finding a `ModuleCapture` instance. Add `EntryPointCapture` class for capturing `IEntryPoint`. - The `IEntryPoint` can only be created from: -- `IModule::findEntryPointByName` -- `IModule::findAndCheckEntryPoint` so, we create the `EntryPointCapture` at those methods in `ModuleCapture`. Similarly, we use a hash map to store a map from `IEntryPoint` to `EntryPointCapture`. - In `IModule::getDefinedEntryPoint`, we will assert on not finding a `EntryPointCapture` instance. Add `CompositeComponentTypeCapture` class for capturing CompositeComponentType, but since user is only exposed to `IComponentType`, so `CompositeComponentTypeCapture` just inherits from `IComponentType`. - `CompositeComponentType` can only be created from: -- ISession::createCompositeComponentType so create it here. Add `TypeConformanceCapture` class for capturing `ITypeConformance`. - The `ITypeConformance` can only be created from: -- `ISession::createTypeConformanceComponentType` so create it here. In addition, because `EntryPointCapture` and `ModuleCapture` share a some base class `IComponentType`, we generate the COM GUID for those two classes to differentiate them. * Fix the build issue * Add nullptr check for output parameter * define the SLANG_CAPTURE_ASSERT macro used in both debug and release build --- source/slang-capture-replay/slang-module.cpp | 239 +++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 source/slang-capture-replay/slang-module.cpp (limited to 'source/slang-capture-replay/slang-module.cpp') diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp new file mode 100644 index 000000000..31aa6bd54 --- /dev/null +++ b/source/slang-capture-replay/slang-module.cpp @@ -0,0 +1,239 @@ +#include "capture_utility.h" +#include "slang-module.h" + +namespace SlangCapture +{ + ModuleCapture::ModuleCapture(slang::IModule* module) + : m_actualModule(module) + { + SLANG_CAPTURE_ASSERT(m_actualModule != nullptr); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, module); + } + + ModuleCapture::~ModuleCapture() + { + m_actualModule->release(); + } + + ISlangUnknown* ModuleCapture::getInterface(const Guid& guid) + { + if(guid == ModuleCapture::getTypeGuid()) + return static_cast(this); + else + return nullptr; + } + + SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName( + char const* name, + slang::IEntryPoint** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + SlangResult res = m_actualModule->findEntryPointByName(name, outEntryPoint); + + if (SLANG_OK == res) + { + EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint); + *outEntryPoint = static_cast(entryPointCapture); + } + return res; + } + + SLANG_NO_THROW SlangInt32 ModuleCapture::getDefinedEntryPointCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt32 res = m_actualModule->getDefinedEntryPointCount(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) + { + // This call is to find the existing entry point, so it has been created already. Therefore, we don't create a new one + // and assert the error if it is not found in our map. + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getDefinedEntryPoint(index, outEntryPoint); + + if (*outEntryPoint) + { + EntryPointCapture* entryPointCapture = m_mapEntryPointToCapture.tryGetValue(*outEntryPoint); + if (!entryPointCapture) + { + SLANG_CAPTURE_ASSERT(!"Entrypoint not found in mapEntryPointToCapture"); + } + *outEntryPoint = static_cast(entryPointCapture); + } + else + *outEntryPoint = nullptr; + + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::serialize(ISlangBlob** outSerializedBlob) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->serialize(outSerializedBlob); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::writeToFile(char const* fileName) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->writeToFile(fileName); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getName() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getName(); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getFilePath() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getFilePath(); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getUniqueIdentity() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getUniqueIdentity(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + SlangResult res = m_actualModule->findAndCheckEntryPoint(name, stage, outEntryPoint, outDiagnostics); + + if (SLANG_OK == res) + { + EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint); + *outEntryPoint = static_cast(entryPointCapture); + } + return res; + } + + SLANG_NO_THROW slang::ISession* ModuleCapture::getSession() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ISession* session = m_actualModule->getSession(); + return session; + } + + SLANG_NO_THROW slang::ProgramLayout* ModuleCapture::getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ProgramLayout* programLayout = m_actualModule->getLayout(targetIndex, outDiagnostics); + return programLayout; + } + + SLANG_NO_THROW SlangInt ModuleCapture::getSpecializationParamCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt res = m_actualModule->getSpecializationParamCount(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + return res; + } + + SLANG_NO_THROW void ModuleCapture::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + m_actualModule->getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult ModuleCapture::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::link( + IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->link(outLinkedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->renameEntryPoint(newName, outEntryPoint); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics); + return res; + } + + EntryPointCapture* ModuleCapture::getEntryPointCapture(slang::IEntryPoint* entryPoint) + { + EntryPointCapture* entryPointCapture = nullptr; + entryPointCapture = m_mapEntryPointToCapture.tryGetValue(entryPoint); + if (!entryPointCapture) + { + entryPointCapture = new EntryPointCapture(entryPoint); + Slang::ComPtr result(entryPointCapture); + m_mapEntryPointToCapture.add(entryPoint, *result.detach()); + } + return entryPointCapture; + } +} -- cgit v1.2.3