diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2024-05-17 11:43:08 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-17 09:43:08 -0700 |
| commit | 9f786fdf71e90339e20979ef3ba8f073657f5a98 (patch) | |
| tree | dac01aa41224373527d2ce793a82c34e142a7b05 /source/slang-capture-replay/slang-module.cpp | |
| parent | 000396cdc4b00f7f8bf92a6569367e5cb9d6ba27 (diff) | |
capture/relay: Add capture interface classes (#4177)
* capture/relay: Add capture interface classes
Add `ModuleCapture` class for capturing `IModule`
- The `IModule` can only be created from
-- `ISession::loadModule`
-- `ISession::loadModuleFromIRBlob`
-- `ISession::loadModuleFromSource`
-- `ISession::loadModuleFromSourceString`
so, we create the `ModuleCapture` at those methods in `SessionCapture`
class. We use a hash map to store a map from `IModule` to `ModuleCapture`
to avoid creating new `ModuleCapture` when there is already an old one.
- In `SessionCapture::getLoadedModule`, we will assert on not finding
a `ModuleCapture` instance.
Add `EntryPointCapture` class for capturing `IEntryPoint`.
- The `IEntryPoint` can only be created from:
-- `IModule::findEntryPointByName`
-- `IModule::findAndCheckEntryPoint`
so, we create the `EntryPointCapture` at those methods in `ModuleCapture`.
Similarly, we use a hash map to store a map from `IEntryPoint` to
`EntryPointCapture`.
- In `IModule::getDefinedEntryPoint`, we will assert on not finding
a `EntryPointCapture` instance.
Add `CompositeComponentTypeCapture` class for capturing CompositeComponentType,
but since user is only exposed to `IComponentType`, so `CompositeComponentTypeCapture`
just inherits from `IComponentType`.
- `CompositeComponentType` can only be created from:
-- ISession::createCompositeComponentType
so create it here.
Add `TypeConformanceCapture` class for capturing `ITypeConformance`.
- The `ITypeConformance` can only be created from:
-- `ISession::createTypeConformanceComponentType`
so create it here.
In addition, because `EntryPointCapture` and `ModuleCapture` share a some
base class `IComponentType`, we generate the COM GUID for those two
classes to differentiate them.
* Fix the build issue
* Add nullptr check for output parameter
* define the SLANG_CAPTURE_ASSERT macro used in both debug and release build
Diffstat (limited to 'source/slang-capture-replay/slang-module.cpp')
| -rw-r--r-- | source/slang-capture-replay/slang-module.cpp | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp new file mode 100644 index 000000000..31aa6bd54 --- /dev/null +++ b/source/slang-capture-replay/slang-module.cpp @@ -0,0 +1,239 @@ +#include "capture_utility.h" +#include "slang-module.h" + +namespace SlangCapture +{ + ModuleCapture::ModuleCapture(slang::IModule* module) + : m_actualModule(module) + { + SLANG_CAPTURE_ASSERT(m_actualModule != nullptr); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, module); + } + + ModuleCapture::~ModuleCapture() + { + m_actualModule->release(); + } + + ISlangUnknown* ModuleCapture::getInterface(const Guid& guid) + { + if(guid == ModuleCapture::getTypeGuid()) + return static_cast<ISlangUnknown*>(this); + else + return nullptr; + } + + SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName( + char const* name, + slang::IEntryPoint** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + SlangResult res = m_actualModule->findEntryPointByName(name, outEntryPoint); + + if (SLANG_OK == res) + { + EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint); + *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture); + } + return res; + } + + SLANG_NO_THROW SlangInt32 ModuleCapture::getDefinedEntryPointCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt32 res = m_actualModule->getDefinedEntryPointCount(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) + { + // This call is to find the existing entry point, so it has been created already. Therefore, we don't create a new one + // and assert the error if it is not found in our map. + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getDefinedEntryPoint(index, outEntryPoint); + + if (*outEntryPoint) + { + EntryPointCapture* entryPointCapture = m_mapEntryPointToCapture.tryGetValue(*outEntryPoint); + if (!entryPointCapture) + { + SLANG_CAPTURE_ASSERT(!"Entrypoint not found in mapEntryPointToCapture"); + } + *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture); + } + else + *outEntryPoint = nullptr; + + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::serialize(ISlangBlob** outSerializedBlob) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->serialize(outSerializedBlob); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::writeToFile(char const* fileName) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->writeToFile(fileName); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getName() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getName(); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getFilePath() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getFilePath(); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getUniqueIdentity() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getUniqueIdentity(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + SlangResult res = m_actualModule->findAndCheckEntryPoint(name, stage, outEntryPoint, outDiagnostics); + + if (SLANG_OK == res) + { + EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint); + *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture); + } + return res; + } + + SLANG_NO_THROW slang::ISession* ModuleCapture::getSession() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ISession* session = m_actualModule->getSession(); + return session; + } + + SLANG_NO_THROW slang::ProgramLayout* ModuleCapture::getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ProgramLayout* programLayout = m_actualModule->getLayout(targetIndex, outDiagnostics); + return programLayout; + } + + SLANG_NO_THROW SlangInt ModuleCapture::getSpecializationParamCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt res = m_actualModule->getSpecializationParamCount(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + return res; + } + + SLANG_NO_THROW void ModuleCapture::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + m_actualModule->getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult ModuleCapture::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::link( + IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->link(outLinkedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->renameEntryPoint(newName, outEntryPoint); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics); + return res; + } + + EntryPointCapture* ModuleCapture::getEntryPointCapture(slang::IEntryPoint* entryPoint) + { + EntryPointCapture* entryPointCapture = nullptr; + entryPointCapture = m_mapEntryPointToCapture.tryGetValue(entryPoint); + if (!entryPointCapture) + { + entryPointCapture = new EntryPointCapture(entryPoint); + Slang::ComPtr<EntryPointCapture> result(entryPointCapture); + m_mapEntryPointToCapture.add(entryPoint, *result.detach()); + } + return entryPointCapture; + } +} |
