diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2024-05-17 11:43:08 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-17 09:43:08 -0700 |
| commit | 9f786fdf71e90339e20979ef3ba8f073657f5a98 (patch) | |
| tree | dac01aa41224373527d2ce793a82c34e142a7b05 /source | |
| parent | 000396cdc4b00f7f8bf92a6569367e5cb9d6ba27 (diff) | |
capture/relay: Add capture interface classes (#4177)
* capture/relay: Add capture interface classes
Add `ModuleCapture` class for capturing `IModule`
- The `IModule` can only be created from
-- `ISession::loadModule`
-- `ISession::loadModuleFromIRBlob`
-- `ISession::loadModuleFromSource`
-- `ISession::loadModuleFromSourceString`
so, we create the `ModuleCapture` at those methods in `SessionCapture`
class. We use a hash map to store a map from `IModule` to `ModuleCapture`
to avoid creating new `ModuleCapture` when there is already an old one.
- In `SessionCapture::getLoadedModule`, we will assert on not finding
a `ModuleCapture` instance.
Add `EntryPointCapture` class for capturing `IEntryPoint`.
- The `IEntryPoint` can only be created from:
-- `IModule::findEntryPointByName`
-- `IModule::findAndCheckEntryPoint`
so, we create the `EntryPointCapture` at those methods in `ModuleCapture`.
Similarly, we use a hash map to store a map from `IEntryPoint` to
`EntryPointCapture`.
- In `IModule::getDefinedEntryPoint`, we will assert on not finding
a `EntryPointCapture` instance.
Add `CompositeComponentTypeCapture` class for capturing CompositeComponentType,
but since user is only exposed to `IComponentType`, so `CompositeComponentTypeCapture`
just inherits from `IComponentType`.
- `CompositeComponentType` can only be created from:
-- ISession::createCompositeComponentType
so create it here.
Add `TypeConformanceCapture` class for capturing `ITypeConformance`.
- The `ITypeConformance` can only be created from:
-- `ISession::createTypeConformanceComponentType`
so create it here.
In addition, because `EntryPointCapture` and `ModuleCapture` share a some
base class `IComponentType`, we generate the COM GUID for those two
classes to differentiate them.
* Fix the build issue
* Add nullptr check for output parameter
* define the SLANG_CAPTURE_ASSERT macro used in both debug and release build
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang-capture-replay/capture_utility.h | 15 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-composite-component-type.cpp | 129 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-composite-component-type.h | 68 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-entrypoint.cpp | 128 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-entrypoint.h | 70 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-filesystem.cpp | 20 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-global-session.cpp | 56 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-module.cpp | 239 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-module.h | 91 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-session.cpp | 141 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-session.h | 14 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-type-conformance.cpp | 131 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-type-conformance.h | 70 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 17 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 6 |
15 files changed, 1131 insertions, 64 deletions
diff --git a/source/slang-capture-replay/capture_utility.h b/source/slang-capture-replay/capture_utility.h index db2939894..6edd2af9c 100644 --- a/source/slang-capture-replay/capture_utility.h +++ b/source/slang-capture-replay/capture_utility.h @@ -1,6 +1,12 @@ #ifndef CAPTURE_UTILITY_H #define CAPTURE_UTILITY_H +// in gcc and clang, __PRETTY_FUNCTION__ is the function signature, +// while MSVC uses __FUNCSIG__ +#ifdef _MSC_VER +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + namespace SlangCapture { enum LogLevel: unsigned int @@ -15,4 +21,13 @@ namespace SlangCapture void slangCaptureLog(LogLevel logLevel, const char* fmt, ...); void setLogLevel(); } + +#define SLANG_CAPTURE_ASSERT(VALUE) \ + do { \ + if (!(VALUE)) { \ + SlangCapture::slangCaptureLog(SlangCapture::LogLevel::Error, "Assertion failed: %s, %s, %d\n", #VALUE, __FILE__, __LINE__);\ + std::abort(); \ + } \ + } while(0) + #endif // CAPTURE_UTILITY_H diff --git a/source/slang-capture-replay/slang-composite-component-type.cpp b/source/slang-capture-replay/slang-composite-component-type.cpp new file mode 100644 index 000000000..09fcec357 --- /dev/null +++ b/source/slang-capture-replay/slang-composite-component-type.cpp @@ -0,0 +1,129 @@ +#include "capture_utility.h" +#include "slang-composite-component-type.h" + +namespace SlangCapture +{ + CompositeComponentTypeCapture::CompositeComponentTypeCapture(slang::IComponentType* componentType) + : m_actualCompositeComponentType(componentType) + { + SLANG_CAPTURE_ASSERT(m_actualCompositeComponentType != nullptr); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, componentType); + } + + CompositeComponentTypeCapture::~CompositeComponentTypeCapture() + { + m_actualCompositeComponentType->release(); + } + + ISlangUnknown* CompositeComponentTypeCapture::getInterface(const Guid& guid) + { + if (guid == IComponentType::getTypeGuid()) + { + return static_cast<ISlangUnknown*>(this); + } + return nullptr; + } + + SLANG_NO_THROW slang::ISession* CompositeComponentTypeCapture::getSession() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ISession* res = m_actualCompositeComponentType->getSession(); + return res; + } + + SLANG_NO_THROW slang::ProgramLayout* CompositeComponentTypeCapture::getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ProgramLayout* res = m_actualCompositeComponentType->getLayout(targetIndex, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangInt CompositeComponentTypeCapture::getSpecializationParamCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt res = m_actualCompositeComponentType->getSpecializationParamCount(); + return res; + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + return res; + } + + SLANG_NO_THROW void CompositeComponentTypeCapture::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + m_actualCompositeComponentType->getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->link(outLinkedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->renameEntryPoint(newName, outEntryPoint); + return res; + } + + SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualCompositeComponentType->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics); + return res; + } +} diff --git a/source/slang-capture-replay/slang-composite-component-type.h b/source/slang-capture-replay/slang-composite-component-type.h new file mode 100644 index 000000000..36cdd16a5 --- /dev/null +++ b/source/slang-capture-replay/slang-composite-component-type.h @@ -0,0 +1,68 @@ +#ifndef SLANG_COMPOSITE_COMPONENT_TYPE_H +#define SLANG_COMPOSITE_COMPONENT_TYPE_H + +#include "../../slang-com-ptr.h" +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../core/slang-smart-pointer.h" +#include "../core/slang-dictionary.h" +#include "../slang/slang-compiler.h" + +namespace SlangCapture +{ + using namespace Slang; + class CompositeComponentTypeCapture: public slang::IComponentType, public RefObject + { + public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + ISlangUnknown* getInterface(const Guid& guid); + + explicit CompositeComponentTypeCapture(slang::IComponentType* componentType); + ~CompositeComponentTypeCapture(); + + // Interfaces for `IComponentType` + virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; + virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex = 0, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) override; + virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics = 0) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics = nullptr) override; + + slang::IComponentType* getActualCompositeComponentType() const { return m_actualCompositeComponentType; } + private: + Slang::ComPtr<slang::IComponentType> m_actualCompositeComponentType; + }; +} +#endif // SLANG_COMPOSITE_COMPONENT_TYPE_H diff --git a/source/slang-capture-replay/slang-entrypoint.cpp b/source/slang-capture-replay/slang-entrypoint.cpp new file mode 100644 index 000000000..05e25fcb0 --- /dev/null +++ b/source/slang-capture-replay/slang-entrypoint.cpp @@ -0,0 +1,128 @@ +#include "capture_utility.h" +#include "slang-entrypoint.h" + +namespace SlangCapture +{ + EntryPointCapture::EntryPointCapture(slang::IEntryPoint* entryPoint) + : m_actualEntryPoint(entryPoint) + { + SLANG_CAPTURE_ASSERT(m_actualEntryPoint != nullptr); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, entryPoint); + } + + EntryPointCapture::~EntryPointCapture() + { + m_actualEntryPoint->release(); + } + + ISlangUnknown* EntryPointCapture::getInterface(const Guid& guid) + { + if(guid == EntryPointCapture::getTypeGuid()) + return static_cast<ISlangUnknown*>(this); + else + return nullptr; + } + + SLANG_NO_THROW slang::ISession* EntryPointCapture::getSession() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ISession* res = m_actualEntryPoint->getSession(); + return res; + } + + SLANG_NO_THROW slang::ProgramLayout* EntryPointCapture::getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ProgramLayout* res = m_actualEntryPoint->getLayout(targetIndex, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangInt EntryPointCapture::getSpecializationParamCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt res = m_actualEntryPoint->getSpecializationParamCount(); + return res; + } + + SLANG_NO_THROW SlangResult EntryPointCapture::getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult EntryPointCapture::getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + return res; + } + + SLANG_NO_THROW void EntryPointCapture::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + m_actualEntryPoint->getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult EntryPointCapture::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult EntryPointCapture::link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->link(outLinkedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult EntryPointCapture::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult EntryPointCapture::renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->renameEntryPoint(newName, outEntryPoint); + return res; + } + + SLANG_NO_THROW SlangResult EntryPointCapture::linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualEntryPoint->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics); + return res; + } +} diff --git a/source/slang-capture-replay/slang-entrypoint.h b/source/slang-capture-replay/slang-entrypoint.h new file mode 100644 index 000000000..2ff3b0718 --- /dev/null +++ b/source/slang-capture-replay/slang-entrypoint.h @@ -0,0 +1,70 @@ +#ifndef SLANG_ENTRY_POINT_H +#define SLANG_ENTRY_POINT_H + +#include "../../slang-com-ptr.h" +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../core/slang-smart-pointer.h" +#include "../core/slang-dictionary.h" +#include "../slang/slang-compiler.h" + +namespace SlangCapture +{ + using namespace Slang; + class EntryPointCapture : public slang::IEntryPoint, public RefObject + { + public: + SLANG_COM_INTERFACE(0xf4c1e23d, 0xb321, 0x4931, { 0x8f, 0x37, 0xf1, 0x22, 0x6a, 0xf9, 0x20, 0x85 }) + + SLANG_REF_OBJECT_IUNKNOWN_ALL + ISlangUnknown* getInterface(const Guid& guid); + + explicit EntryPointCapture(slang::IEntryPoint* entryPoint); + ~EntryPointCapture(); + + // Interfaces for `IComponentType` + virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; + virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex = 0, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) override; + virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics = 0) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics = nullptr) override; + + slang::IEntryPoint* getActualEntryPoint() const { return m_actualEntryPoint; } + private: + Slang::ComPtr<slang::IEntryPoint> m_actualEntryPoint; + }; +} +#endif // SLANG_ENTRY_POINT_H diff --git a/source/slang-capture-replay/slang-filesystem.cpp b/source/slang-capture-replay/slang-filesystem.cpp index 10afbdf10..fac6cf66e 100644 --- a/source/slang-capture-replay/slang-filesystem.cpp +++ b/source/slang-capture-replay/slang-filesystem.cpp @@ -6,8 +6,8 @@ namespace SlangCapture FileSystemCapture::FileSystemCapture(ISlangFileSystemExt* fileSystem) : m_actualFileSystem(fileSystem) { - assert(m_actualFileSystem); - slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __func__, m_actualFileSystem.get()); + SLANG_CAPTURE_ASSERT(m_actualFileSystem); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, m_actualFileSystem.get()); } FileSystemCapture::~FileSystemCapture() @@ -31,7 +31,7 @@ namespace SlangCapture char const* path, ISlangBlob** outBlob) { - slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path); + slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path); SlangResult res = m_actualFileSystem->loadFile(path, outBlob); return res; } @@ -40,7 +40,7 @@ namespace SlangCapture const char* path, ISlangBlob** outUniqueIdentity) { - slangCaptureLog(LogLevel::Verbose, "%p: %s :\"%s\"\n", m_actualFileSystem.get(), __func__, path); + slangCaptureLog(LogLevel::Verbose, "%p: %s :\"%s\"\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path); SlangResult res = m_actualFileSystem->getFileUniqueIdentity(path, outUniqueIdentity); return res; } @@ -51,7 +51,7 @@ namespace SlangCapture const char* path, ISlangBlob** pathOut) { - slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path); + slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path); SlangResult res = m_actualFileSystem->calcCombinedPath(fromPathType, fromPath, path, pathOut); return res; } @@ -60,7 +60,7 @@ namespace SlangCapture const char* path, SlangPathType* pathTypeOut) { - slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path); + slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path); SlangResult res = m_actualFileSystem->getPathType(path, pathTypeOut); return res; } @@ -70,14 +70,14 @@ namespace SlangCapture const char* path, ISlangBlob** outPath) { - slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path); + slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path); SlangResult res = m_actualFileSystem->getPath(kind, path, outPath); return res; } SLANG_NO_THROW void FileSystemCapture::clearCache() { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__); m_actualFileSystem->clearCache(); } @@ -86,14 +86,14 @@ namespace SlangCapture FileSystemContentsCallBack callback, void* userData) { - slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path); + slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path); SlangResult res = m_actualFileSystem->enumeratePathContents(path, callback, userData); return res; } SLANG_NO_THROW OSPathKind FileSystemCapture::getOSPathKind() { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__); OSPathKind pathKind = m_actualFileSystem->getOSPathKind(); return pathKind; } diff --git a/source/slang-capture-replay/slang-global-session.cpp b/source/slang-capture-replay/slang-global-session.cpp index e4754b9d2..05a2b93ec 100644 --- a/source/slang-capture-replay/slang-global-session.cpp +++ b/source/slang-capture-replay/slang-global-session.cpp @@ -1,17 +1,17 @@ #include <vector> #include "slang-global-session.h" -#include "capture_utility.h" #include "slang-session.h" #include "slang-filesystem.h" #include "../slang/slang-compiler.h" +#include "capture_utility.h" namespace SlangCapture { GlobalSessionCapture::GlobalSessionCapture(slang::IGlobalSession* session): m_actualGlobalSession(session) { - assert(m_actualGlobalSession != nullptr); + SLANG_CAPTURE_ASSERT(m_actualGlobalSession != nullptr); } GlobalSessionCapture::~GlobalSessionCapture() @@ -29,7 +29,7 @@ namespace SlangCapture SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::createSession(slang::SessionDesc const& desc, slang::ISession** outSession) { setLogLevel(); - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); slang::ISession* actualSession = nullptr; SlangResult res = m_actualGlobalSession->createSession(desc, &actualSession); @@ -55,152 +55,152 @@ namespace SlangCapture SLANG_NO_THROW SlangProfileID SLANG_MCALL GlobalSessionCapture::findProfile(char const* name) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangProfileID profileId = m_actualGlobalSession->findProfile(name); return profileId; } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setDownstreamCompilerPath(SlangPassThrough passThrough, char const* path) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->setDownstreamCompilerPath(passThrough, path); } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->setDownstreamCompilerPrelude(inPassThrough, prelude); } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->getDownstreamCompilerPrelude(inPassThrough, outPrelude); } SLANG_NO_THROW const char* SLANG_MCALL GlobalSessionCapture::getBuildTagString() { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); const char* resStr = m_actualGlobalSession->getBuildTagString(); return resStr; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::setDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage, SlangPassThrough defaultCompiler) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->setDefaultDownstreamCompiler(sourceLanguage, defaultCompiler); return res; } SLANG_NO_THROW SlangPassThrough SLANG_MCALL GlobalSessionCapture::getDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangPassThrough passThrough = m_actualGlobalSession->getDefaultDownstreamCompiler(sourceLanguage); return passThrough; } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->setLanguagePrelude(inSourceLanguage, prelude); } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->getLanguagePrelude(inSourceLanguage, outPrelude); } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::createCompileRequest(slang::ICompileRequest** outCompileRequest) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->createCompileRequest(outCompileRequest); return res; } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::addBuiltins(char const* sourcePath, char const* sourceString) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->addBuiltins(sourcePath, sourceString); } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->setSharedLibraryLoader(loader); } SLANG_NO_THROW ISlangSharedLibraryLoader* SLANG_MCALL GlobalSessionCapture::getSharedLibraryLoader() { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); ISlangSharedLibraryLoader* loader = m_actualGlobalSession->getSharedLibraryLoader(); return loader; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::checkCompileTargetSupport(SlangCompileTarget target) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->checkCompileTargetSupport(target); return res; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::checkPassThroughSupport(SlangPassThrough passThrough) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->checkPassThroughSupport(passThrough); return res; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::compileStdLib(slang::CompileStdLibFlags flags) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->compileStdLib(flags); return res; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->loadStdLib(stdLib, stdLibSizeInBytes); return res; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::saveStdLib(SlangArchiveType archiveType, ISlangBlob** outBlob) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->saveStdLib(archiveType, outBlob); return res; } SLANG_NO_THROW SlangCapabilityID SLANG_MCALL GlobalSessionCapture::findCapability(char const* name) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangCapabilityID capId = m_actualGlobalSession->findCapability(name); return capId; } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setDownstreamCompilerForTransition(SlangCompileTarget source, SlangCompileTarget target, SlangPassThrough compiler) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->setDownstreamCompilerForTransition(source, target, compiler); } SLANG_NO_THROW SlangPassThrough SLANG_MCALL GlobalSessionCapture::getDownstreamCompilerForTransition(SlangCompileTarget source, SlangCompileTarget target) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangPassThrough passThrough = m_actualGlobalSession->getDownstreamCompilerForTransition(source, target); return passThrough; } SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::getCompilerElapsedTime(double* outTotalTime, double* outDownstreamTime) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); m_actualGlobalSession->getCompilerElapsedTime(outTotalTime, outDownstreamTime); } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::setSPIRVCoreGrammar(char const* jsonPath) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->setSPIRVCoreGrammar(jsonPath); return res; } @@ -208,14 +208,14 @@ namespace SlangCapture SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::parseCommandLineArguments( int argc, const char* const* argv, slang::SessionDesc* outSessionDesc, ISlangUnknown** outAllocation) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->parseCommandLineArguments(argc, argv, outSessionDesc, outAllocation); return res; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) { - slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__); + slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__); SlangResult res = m_actualGlobalSession->getSessionDescDigest(sessionDesc, outBlob); return res; } diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp new file mode 100644 index 000000000..31aa6bd54 --- /dev/null +++ b/source/slang-capture-replay/slang-module.cpp @@ -0,0 +1,239 @@ +#include "capture_utility.h" +#include "slang-module.h" + +namespace SlangCapture +{ + ModuleCapture::ModuleCapture(slang::IModule* module) + : m_actualModule(module) + { + SLANG_CAPTURE_ASSERT(m_actualModule != nullptr); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, module); + } + + ModuleCapture::~ModuleCapture() + { + m_actualModule->release(); + } + + ISlangUnknown* ModuleCapture::getInterface(const Guid& guid) + { + if(guid == ModuleCapture::getTypeGuid()) + return static_cast<ISlangUnknown*>(this); + else + return nullptr; + } + + SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName( + char const* name, + slang::IEntryPoint** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + SlangResult res = m_actualModule->findEntryPointByName(name, outEntryPoint); + + if (SLANG_OK == res) + { + EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint); + *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture); + } + return res; + } + + SLANG_NO_THROW SlangInt32 ModuleCapture::getDefinedEntryPointCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt32 res = m_actualModule->getDefinedEntryPointCount(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) + { + // This call is to find the existing entry point, so it has been created already. Therefore, we don't create a new one + // and assert the error if it is not found in our map. + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getDefinedEntryPoint(index, outEntryPoint); + + if (*outEntryPoint) + { + EntryPointCapture* entryPointCapture = m_mapEntryPointToCapture.tryGetValue(*outEntryPoint); + if (!entryPointCapture) + { + SLANG_CAPTURE_ASSERT(!"Entrypoint not found in mapEntryPointToCapture"); + } + *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture); + } + else + *outEntryPoint = nullptr; + + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::serialize(ISlangBlob** outSerializedBlob) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->serialize(outSerializedBlob); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::writeToFile(char const* fileName) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->writeToFile(fileName); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getName() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getName(); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getFilePath() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getFilePath(); + return res; + } + + SLANG_NO_THROW const char* ModuleCapture::getUniqueIdentity() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + const char* res = m_actualModule->getUniqueIdentity(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + SlangResult res = m_actualModule->findAndCheckEntryPoint(name, stage, outEntryPoint, outDiagnostics); + + if (SLANG_OK == res) + { + EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint); + *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture); + } + return res; + } + + SLANG_NO_THROW slang::ISession* ModuleCapture::getSession() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ISession* session = m_actualModule->getSession(); + return session; + } + + SLANG_NO_THROW slang::ProgramLayout* ModuleCapture::getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ProgramLayout* programLayout = m_actualModule->getLayout(targetIndex, outDiagnostics); + return programLayout; + } + + SLANG_NO_THROW SlangInt ModuleCapture::getSpecializationParamCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt res = m_actualModule->getSpecializationParamCount(); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + return res; + } + + SLANG_NO_THROW void ModuleCapture::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + m_actualModule->getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult ModuleCapture::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::link( + IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->link(outLinkedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->renameEntryPoint(newName, outEntryPoint); + return res; + } + + SLANG_NO_THROW SlangResult ModuleCapture::linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics); + return res; + } + + EntryPointCapture* ModuleCapture::getEntryPointCapture(slang::IEntryPoint* entryPoint) + { + EntryPointCapture* entryPointCapture = nullptr; + entryPointCapture = m_mapEntryPointToCapture.tryGetValue(entryPoint); + if (!entryPointCapture) + { + entryPointCapture = new EntryPointCapture(entryPoint); + Slang::ComPtr<EntryPointCapture> result(entryPointCapture); + m_mapEntryPointToCapture.add(entryPoint, *result.detach()); + } + return entryPointCapture; + } +} diff --git a/source/slang-capture-replay/slang-module.h b/source/slang-capture-replay/slang-module.h new file mode 100644 index 000000000..d1180c828 --- /dev/null +++ b/source/slang-capture-replay/slang-module.h @@ -0,0 +1,91 @@ +#ifndef SLANG_MODULE_H +#define SLANG_MODULE_H + +#include "../../slang-com-ptr.h" +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../core/slang-smart-pointer.h" +#include "../slang/slang-compiler.h" +#include "slang-entrypoint.h" + +namespace SlangCapture +{ + using namespace Slang; + class ModuleCapture : public slang::IModule, public RefObject + { + public: + SLANG_COM_INTERFACE(0xb1802991, 0x185a, 0x4a03, { 0xa7, 0x7e, 0x0c, 0x86, 0xe0, 0x68, 0x2a, 0xab }) + + SLANG_REF_OBJECT_IUNKNOWN_ALL + ISlangUnknown* getInterface(const Guid& guid); + + explicit ModuleCapture(slang::IModule* module); + ~ModuleCapture(); + + // Interfaces for `IModule` + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findEntryPointByName( + char const* name, slang::IEntryPoint** outEntryPoint) override; + virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override; + virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override; + virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override; + virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, SlangStage stage, slang::IEntryPoint** outEntryPoint, ISlangBlob** outDiagnostics) override; + + // Interfaces for `IComponentType` + virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; + virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex = 0, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) override; + virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics = 0) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics = nullptr) override; + + slang::IModule* getActualModule() const { return m_actualModule; } + private: + EntryPointCapture* getEntryPointCapture(slang::IEntryPoint* entryPoint); + Slang::ComPtr<slang::IModule> m_actualModule; + + // `IEntryPoint` can only be created from 'IModule', so we need to capture it in + // this class, and create a map such that we don't create new `EntryPointCapture` + // for the same `IEntryPoint`. + Dictionary<slang::IEntryPoint*, EntryPointCapture> m_mapEntryPointToCapture; + }; +} // namespace SlangCapture + +#endif // SLANG_MODULE_H diff --git a/source/slang-capture-replay/slang-session.cpp b/source/slang-capture-replay/slang-session.cpp index 3d96af6d7..5cf029629 100644 --- a/source/slang-capture-replay/slang-session.cpp +++ b/source/slang-capture-replay/slang-session.cpp @@ -1,5 +1,8 @@ #include "capture_utility.h" #include "slang-session.h" +#include "slang-entrypoint.h" +#include "slang-composite-component-type.h" +#include "slang-type-conformance.h" namespace SlangCapture { @@ -7,8 +10,8 @@ namespace SlangCapture SessionCapture::SessionCapture(slang::ISession* session) : m_actualSession(session) { - assert(m_actualSession); - slangCaptureLog(LogLevel::Verbose, "%s: %p\n", "Session", session); + SLANG_CAPTURE_ASSERT(m_actualSession); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", "SessionCapture create:", session); } SessionCapture::~SessionCapture() @@ -26,7 +29,7 @@ namespace SlangCapture SLANG_NO_THROW slang::IGlobalSession* SessionCapture::getGlobalSession() { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::IGlobalSession* pGlobalSession = m_actualSession->getGlobalSession(); return pGlobalSession; } @@ -35,9 +38,10 @@ namespace SlangCapture const char* moduleName, slang::IBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::IModule* pModule = m_actualSession->loadModule(moduleName, outDiagnostics); - return pModule; + ModuleCapture* pModuleCapture = getModuleCapture(pModule); + return static_cast<slang::IModule*>(pModuleCapture); } SLANG_NO_THROW slang::IModule* SessionCapture::loadModuleFromIRBlob( @@ -46,9 +50,10 @@ namespace SlangCapture slang::IBlob* source, slang::IBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::IModule* pModule = m_actualSession->loadModuleFromIRBlob(moduleName, path, source, outDiagnostics); - return pModule; + ModuleCapture* pModuleCapture = getModuleCapture(pModule); + return static_cast<slang::IModule*>(pModuleCapture); } SLANG_NO_THROW slang::IModule* SessionCapture::loadModuleFromSource( @@ -57,9 +62,10 @@ namespace SlangCapture slang::IBlob* source, slang::IBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::IModule* pModule = m_actualSession->loadModuleFromSource(moduleName, path, source, outDiagnostics); - return pModule; + ModuleCapture* pModuleCapture = getModuleCapture(pModule); + return static_cast<slang::IModule*>(pModuleCapture); } SLANG_NO_THROW slang::IModule* SessionCapture::loadModuleFromSourceString( @@ -68,9 +74,10 @@ namespace SlangCapture const char* string, slang::IBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::IModule* pModule = m_actualSession->loadModuleFromSourceString(moduleName, path, string, outDiagnostics); - return pModule; + ModuleCapture* pModuleCapture = getModuleCapture(pModule); + return static_cast<slang::IModule*>(pModuleCapture); } SLANG_NO_THROW SlangResult SessionCapture::createCompositeComponentType( @@ -79,8 +86,26 @@ namespace SlangCapture slang::IComponentType** outCompositeComponentType, ISlangBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); - SlangResult result = m_actualSession->createCompositeComponentType(componentTypes, componentTypeCount, outCompositeComponentType, outDiagnostics); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + + Slang::List<slang::IComponentType*> componentTypeList; + + // get the actual component types from our capture wrappers + if(SLANG_OK != getActualComponentTypes(componentTypes, componentTypeCount, componentTypeList)) + { + SLANG_CAPTURE_ASSERT(!"Failed to get actual component types"); + } + + SlangResult result = m_actualSession->createCompositeComponentType( + componentTypeList.getBuffer(), componentTypeCount, outCompositeComponentType, outDiagnostics); + + if (SLANG_OK == result) + { + CompositeComponentTypeCapture* compositeComponentTypeCapture = new CompositeComponentTypeCapture(*outCompositeComponentType); + Slang::ComPtr<CompositeComponentTypeCapture> resultCapture(compositeComponentTypeCapture); + *outCompositeComponentType = resultCapture.detach(); + } + return result; } @@ -90,7 +115,7 @@ namespace SlangCapture SlangInt specializationArgCount, ISlangBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::TypeReflection* pTypeReflection = m_actualSession->specializeType(type, specializationArgs, specializationArgCount, outDiagnostics); return pTypeReflection; } @@ -101,7 +126,7 @@ namespace SlangCapture slang::LayoutRules rules, ISlangBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::TypeLayoutReflection* pTypeLayoutReflection = m_actualSession->getTypeLayout(type, targetIndex, rules, outDiagnostics); return pTypeLayoutReflection; } @@ -111,14 +136,14 @@ namespace SlangCapture slang::ContainerType containerType, ISlangBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::TypeReflection* pTypeReflection = m_actualSession->getContainerType(elementType, containerType, outDiagnostics); return pTypeReflection; } SLANG_NO_THROW slang::TypeReflection* SessionCapture::getDynamicType() { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::TypeReflection* pTypeReflection = m_actualSession->getDynamicType(); return pTypeReflection; } @@ -127,7 +152,7 @@ namespace SlangCapture slang::TypeReflection* type, ISlangBlob** outNameBlob) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); SlangResult result = m_actualSession->getTypeRTTIMangledName(type, outNameBlob); return result; } @@ -137,7 +162,7 @@ namespace SlangCapture slang::TypeReflection* interfaceType, ISlangBlob** outNameBlob) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); SlangResult result = m_actualSession->getTypeConformanceWitnessMangledName(type, interfaceType, outNameBlob); return result; } @@ -147,7 +172,7 @@ namespace SlangCapture slang::TypeReflection* interfaceType, uint32_t* outId) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); SlangResult result = m_actualSession->getTypeConformanceWitnessSequentialID(type, interfaceType, outId); return result; } @@ -159,38 +184,104 @@ namespace SlangCapture SlangInt conformanceIdOverride, ISlangBlob** outDiagnostics) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult result = m_actualSession->createTypeConformanceComponentType(type, interfaceType, outConformance, conformanceIdOverride, outDiagnostics); + + if (SLANG_OK != result) + { + TypeConformanceCapture* conformanceCapture = new TypeConformanceCapture(*outConformance); + Slang::ComPtr<TypeConformanceCapture> resultCapture(conformanceCapture); + *outConformance = resultCapture.detach(); + } + return result; } SLANG_NO_THROW SlangResult SessionCapture::createCompileRequest( SlangCompileRequest** outCompileRequest) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); SlangResult result = m_actualSession->createCompileRequest(outCompileRequest); return result; } SLANG_NO_THROW SlangInt SessionCapture::getLoadedModuleCount() { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); SlangInt count = m_actualSession->getLoadedModuleCount(); return count; } SLANG_NO_THROW slang::IModule* SessionCapture::getLoadedModule(SlangInt index) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); slang::IModule* pModule = m_actualSession->getLoadedModule(index); + + if (pModule) + { + ModuleCapture* moduleCapture = m_mapModuleToCapture.tryGetValue(pModule); + if (!moduleCapture) + { + SLANG_CAPTURE_ASSERT(!"Module not found in mapModuleToCapture"); + } + return static_cast<slang::IModule*>(moduleCapture); + } + return pModule; } SLANG_NO_THROW bool SessionCapture::isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) { - slangCaptureLog(LogLevel::Verbose, "%s\n", __func__); + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); bool result = m_actualSession->isBinaryModuleUpToDate(modulePath, binaryModuleBlob); return result; } + ModuleCapture* SessionCapture::getModuleCapture(slang::IModule* module) + { + ModuleCapture* moduleCapture = nullptr; + moduleCapture = m_mapModuleToCapture.tryGetValue(module); + if (!moduleCapture) + { + moduleCapture = new ModuleCapture(module); + Slang::ComPtr<ModuleCapture> result(moduleCapture); + m_mapModuleToCapture.add(module, *result.detach()); + } + return moduleCapture; + } + + SlangResult SessionCapture::getActualComponentTypes( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + List<slang::IComponentType*>& outActualComponentTypes) + { + for (SlangInt i = 0; i < componentTypeCount; i++) + { + slang::IComponentType* const& componentType = componentTypes[i]; + void* outObj = nullptr; + + if (componentType->queryInterface(ModuleCapture::getTypeGuid(), &outObj) == SLANG_OK) + { + ModuleCapture* moduleCapture = static_cast<ModuleCapture*>(outObj); + outActualComponentTypes.add(moduleCapture->getActualModule()); + } + else if (componentType->queryInterface(EntryPointCapture::getTypeGuid(), &outObj) == SLANG_OK) + { + EntryPointCapture* entrypointCapture = static_cast<EntryPointCapture*>(outObj); + outActualComponentTypes.add(entrypointCapture->getActualEntryPoint()); + } + // will fall back to the actual component type, it means that we didn't capture this type. + else + { + outActualComponentTypes.add(componentType); + } + } + + if (componentTypeCount == outActualComponentTypes.getCount()) + { + return SLANG_OK; + } + return SLANG_FAIL; + } } // namespace SlangCapture diff --git a/source/slang-capture-replay/slang-session.h b/source/slang-capture-replay/slang-session.h index 24903a1b2..f099108b0 100644 --- a/source/slang-capture-replay/slang-session.h +++ b/source/slang-capture-replay/slang-session.h @@ -5,7 +5,9 @@ #include "../../slang.h" #include "../../slang-com-helper.h" #include "../core/slang-smart-pointer.h" +#include "../core/slang-dictionary.h" #include "../slang/slang-compiler.h" +#include "slang-module.h" namespace SlangCapture { @@ -92,7 +94,19 @@ namespace SlangCapture { return static_cast<slang::ISession*>(session); } + + // The IComponentType object is the capture target, therefore `componentTypes` will not be + // the actual component types, we have to use the COM interface to get the actual objects. + SlangResult getActualComponentTypes( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + List<slang::IComponentType*>& outActualComponentTypes); + + ModuleCapture* getModuleCapture(slang::IModule* module); + Slang::ComPtr<slang::ISession> m_actualSession; + + Dictionary<slang::IModule*, ModuleCapture> m_mapModuleToCapture; }; } diff --git a/source/slang-capture-replay/slang-type-conformance.cpp b/source/slang-capture-replay/slang-type-conformance.cpp new file mode 100644 index 000000000..1ca9cf737 --- /dev/null +++ b/source/slang-capture-replay/slang-type-conformance.cpp @@ -0,0 +1,131 @@ +#include "capture_utility.h" +#include "slang-type-conformance.h" + +namespace SlangCapture +{ + TypeConformanceCapture::TypeConformanceCapture(slang::ITypeConformance* typeConformance) + : m_actualTypeConformance(typeConformance) + { + SLANG_CAPTURE_ASSERT(m_actualTypeConformance != nullptr); + slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, typeConformance); + } + TypeConformanceCapture::~TypeConformanceCapture() + { + m_actualTypeConformance->release(); + } + + ISlangUnknown* TypeConformanceCapture::getInterface(const Guid& guid) + { + if (guid == TypeConformanceCapture::getTypeGuid()) + { + return static_cast<ISlangUnknown*>(this); + } + else + { + return nullptr; + } + } + + SLANG_NO_THROW slang::ISession* TypeConformanceCapture::getSession() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ISession* res = m_actualTypeConformance->getSession(); + return res; + } + + SLANG_NO_THROW slang::ProgramLayout* TypeConformanceCapture::getLayout( + SlangInt targetIndex, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::ProgramLayout* res = m_actualTypeConformance->getLayout(targetIndex, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangInt TypeConformanceCapture::getSpecializationParamCount() + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangInt res = m_actualTypeConformance->getSpecializationParamCount(); + return res; + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + return res; + } + + SLANG_NO_THROW void TypeConformanceCapture::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + m_actualTypeConformance->getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->link(outLinkedComponentType, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); + return res; + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->renameEntryPoint(newName, outEntryPoint); + return res; + } + + SLANG_NO_THROW SlangResult TypeConformanceCapture::linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics) + { + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualTypeConformance->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics); + return res; + } +} diff --git a/source/slang-capture-replay/slang-type-conformance.h b/source/slang-capture-replay/slang-type-conformance.h new file mode 100644 index 000000000..e2a7b27c9 --- /dev/null +++ b/source/slang-capture-replay/slang-type-conformance.h @@ -0,0 +1,70 @@ +#ifndef SLANG_TYPE_CONFORMANCE_H +#define SLANG_TYPE_CONFORMANCE_H + +#include "../../slang-com-ptr.h" +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../core/slang-smart-pointer.h" +#include "../core/slang-dictionary.h" +#include "../slang/slang-compiler.h" + +namespace SlangCapture +{ + using namespace Slang; + class TypeConformanceCapture: public slang::ITypeConformance, public RefObject + { + public: + SLANG_COM_INTERFACE(0x0e67d05d, 0xee0a, 0x41e1, { 0xb5, 0xa3, 0x23, 0xe3, 0xb0, 0xec, 0x33, 0xf1 }) + + SLANG_REF_OBJECT_IUNKNOWN_ALL + ISlangUnknown* getInterface(const Guid& guid); + + explicit TypeConformanceCapture(slang::ITypeConformance* typeConformance); + ~TypeConformanceCapture(); + + // Interfaces for `IComponentType` + virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; + virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( + SlangInt targetIndex = 0, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) override; + virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics = 0) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( + const char* newName, IComponentType** outEntryPoint) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + IComponentType** outLinkedComponentType, + uint32_t compilerOptionEntryCount, + slang::CompilerOptionEntry* compilerOptionEntries, + ISlangBlob** outDiagnostics = nullptr) override; + + slang::ITypeConformance* getActualTypeConformance() const { return m_actualTypeConformance; } + private: + Slang::ComPtr<slang::ITypeConformance> m_actualTypeConformance; + }; +} +#endif // SLANG_TYPE_CONFORMANCE_H diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index fc97a2f47..770449de7 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1335,6 +1335,11 @@ namespace Slang char const* name, slang::IEntryPoint** outEntryPoint) SLANG_OVERRIDE { + if (outEntryPoint == nullptr) + { + return SLANG_E_INVALID_ARG; + } + ComPtr<slang::IEntryPoint> entryPoint(findEntryPointByName(UnownedStringSlice(name))); if((!entryPoint)) return SLANG_FAIL; @@ -1349,6 +1354,11 @@ namespace Slang slang::IEntryPoint** outEntryPoint, ISlangBlob** outDiagnostics) override { + if (outEntryPoint == nullptr) + { + return SLANG_E_INVALID_ARG; + } + ComPtr<slang::IEntryPoint> entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); if ((!entryPoint)) return SLANG_FAIL; @@ -1367,6 +1377,11 @@ namespace Slang if (index < 0 || index >= m_entryPoints.getCount()) return SLANG_E_INVALID_ARG; + if (outEntryPoint == nullptr) + { + return SLANG_E_INVALID_ARG; + } + ComPtr<slang::IEntryPoint> entryPoint(m_entryPoints[index].Ptr()); *outEntryPoint = entryPoint.detach(); return SLANG_OK; @@ -1541,7 +1556,7 @@ namespace Slang // List of source files this module depends on FileDependencyList m_fileDependencyList; - // Entry points that were defined in thsi module + // Entry points that were defined in this module // // Note: the entry point defined in the module are *not* // part of the memory image/layout of the module when diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 4d83823d2..c6af8b34d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1266,6 +1266,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType( slang::IComponentType** outCompositeComponentType, ISlangBlob** outDiagnostics) { + if (outCompositeComponentType == nullptr) + return SLANG_E_INVALID_ARG; + SLANG_AST_BUILDER_RAII(getASTBuilder()); // Attempting to create a "composite" of just one component type should @@ -1491,6 +1494,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy SlangInt conformanceIdOverride, ISlangBlob** outDiagnostics) { + if (outConformanceComponentType == nullptr) + return SLANG_E_INVALID_ARG; + SLANG_AST_BUILDER_RAII(getASTBuilder()); RefPtr<TypeConformance> result; |
