summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2024-05-17 11:43:08 -0500
committerGitHub <noreply@github.com>2024-05-17 09:43:08 -0700
commit9f786fdf71e90339e20979ef3ba8f073657f5a98 (patch)
treedac01aa41224373527d2ce793a82c34e142a7b05 /source
parent000396cdc4b00f7f8bf92a6569367e5cb9d6ba27 (diff)
capture/relay: Add capture interface classes (#4177)
* capture/relay: Add capture interface classes Add `ModuleCapture` class for capturing `IModule` - The `IModule` can only be created from -- `ISession::loadModule` -- `ISession::loadModuleFromIRBlob` -- `ISession::loadModuleFromSource` -- `ISession::loadModuleFromSourceString` so, we create the `ModuleCapture` at those methods in `SessionCapture` class. We use a hash map to store a map from `IModule` to `ModuleCapture` to avoid creating new `ModuleCapture` when there is already an old one. - In `SessionCapture::getLoadedModule`, we will assert on not finding a `ModuleCapture` instance. Add `EntryPointCapture` class for capturing `IEntryPoint`. - The `IEntryPoint` can only be created from: -- `IModule::findEntryPointByName` -- `IModule::findAndCheckEntryPoint` so, we create the `EntryPointCapture` at those methods in `ModuleCapture`. Similarly, we use a hash map to store a map from `IEntryPoint` to `EntryPointCapture`. - In `IModule::getDefinedEntryPoint`, we will assert on not finding a `EntryPointCapture` instance. Add `CompositeComponentTypeCapture` class for capturing CompositeComponentType, but since user is only exposed to `IComponentType`, so `CompositeComponentTypeCapture` just inherits from `IComponentType`. - `CompositeComponentType` can only be created from: -- ISession::createCompositeComponentType so create it here. Add `TypeConformanceCapture` class for capturing `ITypeConformance`. - The `ITypeConformance` can only be created from: -- `ISession::createTypeConformanceComponentType` so create it here. In addition, because `EntryPointCapture` and `ModuleCapture` share a some base class `IComponentType`, we generate the COM GUID for those two classes to differentiate them. * Fix the build issue * Add nullptr check for output parameter * define the SLANG_CAPTURE_ASSERT macro used in both debug and release build
Diffstat (limited to 'source')
-rw-r--r--source/slang-capture-replay/capture_utility.h15
-rw-r--r--source/slang-capture-replay/slang-composite-component-type.cpp129
-rw-r--r--source/slang-capture-replay/slang-composite-component-type.h68
-rw-r--r--source/slang-capture-replay/slang-entrypoint.cpp128
-rw-r--r--source/slang-capture-replay/slang-entrypoint.h70
-rw-r--r--source/slang-capture-replay/slang-filesystem.cpp20
-rw-r--r--source/slang-capture-replay/slang-global-session.cpp56
-rw-r--r--source/slang-capture-replay/slang-module.cpp239
-rw-r--r--source/slang-capture-replay/slang-module.h91
-rw-r--r--source/slang-capture-replay/slang-session.cpp141
-rw-r--r--source/slang-capture-replay/slang-session.h14
-rw-r--r--source/slang-capture-replay/slang-type-conformance.cpp131
-rw-r--r--source/slang-capture-replay/slang-type-conformance.h70
-rwxr-xr-xsource/slang/slang-compiler.h17
-rw-r--r--source/slang/slang.cpp6
15 files changed, 1131 insertions, 64 deletions
diff --git a/source/slang-capture-replay/capture_utility.h b/source/slang-capture-replay/capture_utility.h
index db2939894..6edd2af9c 100644
--- a/source/slang-capture-replay/capture_utility.h
+++ b/source/slang-capture-replay/capture_utility.h
@@ -1,6 +1,12 @@
#ifndef CAPTURE_UTILITY_H
#define CAPTURE_UTILITY_H
+// in gcc and clang, __PRETTY_FUNCTION__ is the function signature,
+// while MSVC uses __FUNCSIG__
+#ifdef _MSC_VER
+#define __PRETTY_FUNCTION__ __FUNCSIG__
+#endif
+
namespace SlangCapture
{
enum LogLevel: unsigned int
@@ -15,4 +21,13 @@ namespace SlangCapture
void slangCaptureLog(LogLevel logLevel, const char* fmt, ...);
void setLogLevel();
}
+
+#define SLANG_CAPTURE_ASSERT(VALUE) \
+ do { \
+ if (!(VALUE)) { \
+ SlangCapture::slangCaptureLog(SlangCapture::LogLevel::Error, "Assertion failed: %s, %s, %d\n", #VALUE, __FILE__, __LINE__);\
+ std::abort(); \
+ } \
+ } while(0)
+
#endif // CAPTURE_UTILITY_H
diff --git a/source/slang-capture-replay/slang-composite-component-type.cpp b/source/slang-capture-replay/slang-composite-component-type.cpp
new file mode 100644
index 000000000..09fcec357
--- /dev/null
+++ b/source/slang-capture-replay/slang-composite-component-type.cpp
@@ -0,0 +1,129 @@
+#include "capture_utility.h"
+#include "slang-composite-component-type.h"
+
+namespace SlangCapture
+{
+ CompositeComponentTypeCapture::CompositeComponentTypeCapture(slang::IComponentType* componentType)
+ : m_actualCompositeComponentType(componentType)
+ {
+ SLANG_CAPTURE_ASSERT(m_actualCompositeComponentType != nullptr);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, componentType);
+ }
+
+ CompositeComponentTypeCapture::~CompositeComponentTypeCapture()
+ {
+ m_actualCompositeComponentType->release();
+ }
+
+ ISlangUnknown* CompositeComponentTypeCapture::getInterface(const Guid& guid)
+ {
+ if (guid == IComponentType::getTypeGuid())
+ {
+ return static_cast<ISlangUnknown*>(this);
+ }
+ return nullptr;
+ }
+
+ SLANG_NO_THROW slang::ISession* CompositeComponentTypeCapture::getSession()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ISession* res = m_actualCompositeComponentType->getSession();
+ return res;
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* CompositeComponentTypeCapture::getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ProgramLayout* res = m_actualCompositeComponentType->getLayout(targetIndex, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangInt CompositeComponentTypeCapture::getSpecializationParamCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt res = m_actualCompositeComponentType->getSpecializationParamCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem);
+ return res;
+ }
+
+ SLANG_NO_THROW void CompositeComponentTypeCapture::getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ m_actualCompositeComponentType->getEntryPointHash(entryPointIndex, targetIndex, outHash);
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->link(outLinkedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->renameEntryPoint(newName, outEntryPoint);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualCompositeComponentType->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics);
+ return res;
+ }
+}
diff --git a/source/slang-capture-replay/slang-composite-component-type.h b/source/slang-capture-replay/slang-composite-component-type.h
new file mode 100644
index 000000000..36cdd16a5
--- /dev/null
+++ b/source/slang-capture-replay/slang-composite-component-type.h
@@ -0,0 +1,68 @@
+#ifndef SLANG_COMPOSITE_COMPONENT_TYPE_H
+#define SLANG_COMPOSITE_COMPONENT_TYPE_H
+
+#include "../../slang-com-ptr.h"
+#include "../../slang.h"
+#include "../../slang-com-helper.h"
+#include "../core/slang-smart-pointer.h"
+#include "../core/slang-dictionary.h"
+#include "../slang/slang-compiler.h"
+
+namespace SlangCapture
+{
+ using namespace Slang;
+ class CompositeComponentTypeCapture: public slang::IComponentType, public RefObject
+ {
+ public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+ ISlangUnknown* getInterface(const Guid& guid);
+
+ explicit CompositeComponentTypeCapture(slang::IComponentType* componentType);
+ ~CompositeComponentTypeCapture();
+
+ // Interfaces for `IComponentType`
+ virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;
+ virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
+ SlangInt targetIndex = 0,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics = 0) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics = nullptr) override;
+
+ slang::IComponentType* getActualCompositeComponentType() const { return m_actualCompositeComponentType; }
+ private:
+ Slang::ComPtr<slang::IComponentType> m_actualCompositeComponentType;
+ };
+}
+#endif // SLANG_COMPOSITE_COMPONENT_TYPE_H
diff --git a/source/slang-capture-replay/slang-entrypoint.cpp b/source/slang-capture-replay/slang-entrypoint.cpp
new file mode 100644
index 000000000..05e25fcb0
--- /dev/null
+++ b/source/slang-capture-replay/slang-entrypoint.cpp
@@ -0,0 +1,128 @@
+#include "capture_utility.h"
+#include "slang-entrypoint.h"
+
+namespace SlangCapture
+{
+ EntryPointCapture::EntryPointCapture(slang::IEntryPoint* entryPoint)
+ : m_actualEntryPoint(entryPoint)
+ {
+ SLANG_CAPTURE_ASSERT(m_actualEntryPoint != nullptr);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, entryPoint);
+ }
+
+ EntryPointCapture::~EntryPointCapture()
+ {
+ m_actualEntryPoint->release();
+ }
+
+ ISlangUnknown* EntryPointCapture::getInterface(const Guid& guid)
+ {
+ if(guid == EntryPointCapture::getTypeGuid())
+ return static_cast<ISlangUnknown*>(this);
+ else
+ return nullptr;
+ }
+
+ SLANG_NO_THROW slang::ISession* EntryPointCapture::getSession()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ISession* res = m_actualEntryPoint->getSession();
+ return res;
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* EntryPointCapture::getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ProgramLayout* res = m_actualEntryPoint->getLayout(targetIndex, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangInt EntryPointCapture::getSpecializationParamCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt res = m_actualEntryPoint->getSpecializationParamCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem);
+ return res;
+ }
+
+ SLANG_NO_THROW void EntryPointCapture::getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ m_actualEntryPoint->getEntryPointHash(entryPointIndex, targetIndex, outHash);
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->link(outLinkedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->renameEntryPoint(newName, outEntryPoint);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult EntryPointCapture::linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualEntryPoint->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics);
+ return res;
+ }
+}
diff --git a/source/slang-capture-replay/slang-entrypoint.h b/source/slang-capture-replay/slang-entrypoint.h
new file mode 100644
index 000000000..2ff3b0718
--- /dev/null
+++ b/source/slang-capture-replay/slang-entrypoint.h
@@ -0,0 +1,70 @@
+#ifndef SLANG_ENTRY_POINT_H
+#define SLANG_ENTRY_POINT_H
+
+#include "../../slang-com-ptr.h"
+#include "../../slang.h"
+#include "../../slang-com-helper.h"
+#include "../core/slang-smart-pointer.h"
+#include "../core/slang-dictionary.h"
+#include "../slang/slang-compiler.h"
+
+namespace SlangCapture
+{
+ using namespace Slang;
+ class EntryPointCapture : public slang::IEntryPoint, public RefObject
+ {
+ public:
+ SLANG_COM_INTERFACE(0xf4c1e23d, 0xb321, 0x4931, { 0x8f, 0x37, 0xf1, 0x22, 0x6a, 0xf9, 0x20, 0x85 })
+
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+ ISlangUnknown* getInterface(const Guid& guid);
+
+ explicit EntryPointCapture(slang::IEntryPoint* entryPoint);
+ ~EntryPointCapture();
+
+ // Interfaces for `IComponentType`
+ virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;
+ virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
+ SlangInt targetIndex = 0,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics = 0) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics = nullptr) override;
+
+ slang::IEntryPoint* getActualEntryPoint() const { return m_actualEntryPoint; }
+ private:
+ Slang::ComPtr<slang::IEntryPoint> m_actualEntryPoint;
+ };
+}
+#endif // SLANG_ENTRY_POINT_H
diff --git a/source/slang-capture-replay/slang-filesystem.cpp b/source/slang-capture-replay/slang-filesystem.cpp
index 10afbdf10..fac6cf66e 100644
--- a/source/slang-capture-replay/slang-filesystem.cpp
+++ b/source/slang-capture-replay/slang-filesystem.cpp
@@ -6,8 +6,8 @@ namespace SlangCapture
FileSystemCapture::FileSystemCapture(ISlangFileSystemExt* fileSystem)
: m_actualFileSystem(fileSystem)
{
- assert(m_actualFileSystem);
- slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __func__, m_actualFileSystem.get());
+ SLANG_CAPTURE_ASSERT(m_actualFileSystem);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, m_actualFileSystem.get());
}
FileSystemCapture::~FileSystemCapture()
@@ -31,7 +31,7 @@ namespace SlangCapture
char const* path,
ISlangBlob** outBlob)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path);
SlangResult res = m_actualFileSystem->loadFile(path, outBlob);
return res;
}
@@ -40,7 +40,7 @@ namespace SlangCapture
const char* path,
ISlangBlob** outUniqueIdentity)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s :\"%s\"\n", m_actualFileSystem.get(), __func__, path);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s :\"%s\"\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path);
SlangResult res = m_actualFileSystem->getFileUniqueIdentity(path, outUniqueIdentity);
return res;
}
@@ -51,7 +51,7 @@ namespace SlangCapture
const char* path,
ISlangBlob** pathOut)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path);
SlangResult res = m_actualFileSystem->calcCombinedPath(fromPathType, fromPath, path, pathOut);
return res;
}
@@ -60,7 +60,7 @@ namespace SlangCapture
const char* path,
SlangPathType* pathTypeOut)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path);
SlangResult res = m_actualFileSystem->getPathType(path, pathTypeOut);
return res;
}
@@ -70,14 +70,14 @@ namespace SlangCapture
const char* path,
ISlangBlob** outPath)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path);
SlangResult res = m_actualFileSystem->getPath(kind, path, outPath);
return res;
}
SLANG_NO_THROW void FileSystemCapture::clearCache()
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__);
m_actualFileSystem->clearCache();
}
@@ -86,14 +86,14 @@ namespace SlangCapture
FileSystemContentsCallBack callback,
void* userData)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __func__, path);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s, :%s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__, path);
SlangResult res = m_actualFileSystem->enumeratePathContents(path, callback, userData);
return res;
}
SLANG_NO_THROW OSPathKind FileSystemCapture::getOSPathKind()
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualFileSystem.get(), __PRETTY_FUNCTION__);
OSPathKind pathKind = m_actualFileSystem->getOSPathKind();
return pathKind;
}
diff --git a/source/slang-capture-replay/slang-global-session.cpp b/source/slang-capture-replay/slang-global-session.cpp
index e4754b9d2..05a2b93ec 100644
--- a/source/slang-capture-replay/slang-global-session.cpp
+++ b/source/slang-capture-replay/slang-global-session.cpp
@@ -1,17 +1,17 @@
#include <vector>
#include "slang-global-session.h"
-#include "capture_utility.h"
#include "slang-session.h"
#include "slang-filesystem.h"
#include "../slang/slang-compiler.h"
+#include "capture_utility.h"
namespace SlangCapture
{
GlobalSessionCapture::GlobalSessionCapture(slang::IGlobalSession* session):
m_actualGlobalSession(session)
{
- assert(m_actualGlobalSession != nullptr);
+ SLANG_CAPTURE_ASSERT(m_actualGlobalSession != nullptr);
}
GlobalSessionCapture::~GlobalSessionCapture()
@@ -29,7 +29,7 @@ namespace SlangCapture
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::createSession(slang::SessionDesc const& desc, slang::ISession** outSession)
{
setLogLevel();
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
slang::ISession* actualSession = nullptr;
SlangResult res = m_actualGlobalSession->createSession(desc, &actualSession);
@@ -55,152 +55,152 @@ namespace SlangCapture
SLANG_NO_THROW SlangProfileID SLANG_MCALL GlobalSessionCapture::findProfile(char const* name)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangProfileID profileId = m_actualGlobalSession->findProfile(name);
return profileId;
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setDownstreamCompilerPath(SlangPassThrough passThrough, char const* path)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->setDownstreamCompilerPath(passThrough, path);
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->setDownstreamCompilerPrelude(inPassThrough, prelude);
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->getDownstreamCompilerPrelude(inPassThrough, outPrelude);
}
SLANG_NO_THROW const char* SLANG_MCALL GlobalSessionCapture::getBuildTagString()
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
const char* resStr = m_actualGlobalSession->getBuildTagString();
return resStr;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::setDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage, SlangPassThrough defaultCompiler)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->setDefaultDownstreamCompiler(sourceLanguage, defaultCompiler);
return res;
}
SLANG_NO_THROW SlangPassThrough SLANG_MCALL GlobalSessionCapture::getDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangPassThrough passThrough = m_actualGlobalSession->getDefaultDownstreamCompiler(sourceLanguage);
return passThrough;
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->setLanguagePrelude(inSourceLanguage, prelude);
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->getLanguagePrelude(inSourceLanguage, outPrelude);
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::createCompileRequest(slang::ICompileRequest** outCompileRequest)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->createCompileRequest(outCompileRequest);
return res;
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::addBuiltins(char const* sourcePath, char const* sourceString)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->addBuiltins(sourcePath, sourceString);
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setSharedLibraryLoader(ISlangSharedLibraryLoader* loader)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->setSharedLibraryLoader(loader);
}
SLANG_NO_THROW ISlangSharedLibraryLoader* SLANG_MCALL GlobalSessionCapture::getSharedLibraryLoader()
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
ISlangSharedLibraryLoader* loader = m_actualGlobalSession->getSharedLibraryLoader();
return loader;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::checkCompileTargetSupport(SlangCompileTarget target)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->checkCompileTargetSupport(target);
return res;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::checkPassThroughSupport(SlangPassThrough passThrough)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->checkPassThroughSupport(passThrough);
return res;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::compileStdLib(slang::CompileStdLibFlags flags)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->compileStdLib(flags);
return res;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->loadStdLib(stdLib, stdLibSizeInBytes);
return res;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::saveStdLib(SlangArchiveType archiveType, ISlangBlob** outBlob)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->saveStdLib(archiveType, outBlob);
return res;
}
SLANG_NO_THROW SlangCapabilityID SLANG_MCALL GlobalSessionCapture::findCapability(char const* name)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangCapabilityID capId = m_actualGlobalSession->findCapability(name);
return capId;
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::setDownstreamCompilerForTransition(SlangCompileTarget source, SlangCompileTarget target, SlangPassThrough compiler)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->setDownstreamCompilerForTransition(source, target, compiler);
}
SLANG_NO_THROW SlangPassThrough SLANG_MCALL GlobalSessionCapture::getDownstreamCompilerForTransition(SlangCompileTarget source, SlangCompileTarget target)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangPassThrough passThrough = m_actualGlobalSession->getDownstreamCompilerForTransition(source, target);
return passThrough;
}
SLANG_NO_THROW void SLANG_MCALL GlobalSessionCapture::getCompilerElapsedTime(double* outTotalTime, double* outDownstreamTime)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
m_actualGlobalSession->getCompilerElapsedTime(outTotalTime, outDownstreamTime);
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::setSPIRVCoreGrammar(char const* jsonPath)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->setSPIRVCoreGrammar(jsonPath);
return res;
}
@@ -208,14 +208,14 @@ namespace SlangCapture
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::parseCommandLineArguments(
int argc, const char* const* argv, slang::SessionDesc* outSessionDesc, ISlangUnknown** outAllocation)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->parseCommandLineArguments(argc, argv, outSessionDesc, outAllocation);
return res;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob)
{
- slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __func__);
+ slangCaptureLog(LogLevel::Verbose, "%p: %s\n", m_actualGlobalSession.get(), __PRETTY_FUNCTION__);
SlangResult res = m_actualGlobalSession->getSessionDescDigest(sessionDesc, outBlob);
return res;
}
diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp
new file mode 100644
index 000000000..31aa6bd54
--- /dev/null
+++ b/source/slang-capture-replay/slang-module.cpp
@@ -0,0 +1,239 @@
+#include "capture_utility.h"
+#include "slang-module.h"
+
+namespace SlangCapture
+{
+ ModuleCapture::ModuleCapture(slang::IModule* module)
+ : m_actualModule(module)
+ {
+ SLANG_CAPTURE_ASSERT(m_actualModule != nullptr);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, module);
+ }
+
+ ModuleCapture::~ModuleCapture()
+ {
+ m_actualModule->release();
+ }
+
+ ISlangUnknown* ModuleCapture::getInterface(const Guid& guid)
+ {
+ if(guid == ModuleCapture::getTypeGuid())
+ return static_cast<ISlangUnknown*>(this);
+ else
+ return nullptr;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName(
+ char const* name,
+ slang::IEntryPoint** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+
+ SlangResult res = m_actualModule->findEntryPointByName(name, outEntryPoint);
+
+ if (SLANG_OK == res)
+ {
+ EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint);
+ *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture);
+ }
+ return res;
+ }
+
+ SLANG_NO_THROW SlangInt32 ModuleCapture::getDefinedEntryPointCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt32 res = m_actualModule->getDefinedEntryPointCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint)
+ {
+ // This call is to find the existing entry point, so it has been created already. Therefore, we don't create a new one
+ // and assert the error if it is not found in our map.
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getDefinedEntryPoint(index, outEntryPoint);
+
+ if (*outEntryPoint)
+ {
+ EntryPointCapture* entryPointCapture = m_mapEntryPointToCapture.tryGetValue(*outEntryPoint);
+ if (!entryPointCapture)
+ {
+ SLANG_CAPTURE_ASSERT(!"Entrypoint not found in mapEntryPointToCapture");
+ }
+ *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture);
+ }
+ else
+ *outEntryPoint = nullptr;
+
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::serialize(ISlangBlob** outSerializedBlob)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->serialize(outSerializedBlob);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::writeToFile(char const* fileName)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->writeToFile(fileName);
+ return res;
+ }
+
+ SLANG_NO_THROW const char* ModuleCapture::getName()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ const char* res = m_actualModule->getName();
+ return res;
+ }
+
+ SLANG_NO_THROW const char* ModuleCapture::getFilePath()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ const char* res = m_actualModule->getFilePath();
+ return res;
+ }
+
+ SLANG_NO_THROW const char* ModuleCapture::getUniqueIdentity()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ const char* res = m_actualModule->getUniqueIdentity();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::findAndCheckEntryPoint(
+ char const* name,
+ SlangStage stage,
+ slang::IEntryPoint** outEntryPoint,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+
+ SlangResult res = m_actualModule->findAndCheckEntryPoint(name, stage, outEntryPoint, outDiagnostics);
+
+ if (SLANG_OK == res)
+ {
+ EntryPointCapture* entryPointCapture = getEntryPointCapture(*outEntryPoint);
+ *outEntryPoint = static_cast<slang::IEntryPoint*>(entryPointCapture);
+ }
+ return res;
+ }
+
+ SLANG_NO_THROW slang::ISession* ModuleCapture::getSession()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ISession* session = m_actualModule->getSession();
+ return session;
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* ModuleCapture::getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ProgramLayout* programLayout = m_actualModule->getLayout(targetIndex, outDiagnostics);
+ return programLayout;
+ }
+
+ SLANG_NO_THROW SlangInt ModuleCapture::getSpecializationParamCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt res = m_actualModule->getSpecializationParamCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem);
+ return res;
+ }
+
+ SLANG_NO_THROW void ModuleCapture::getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ m_actualModule->getEntryPointHash(entryPointIndex, targetIndex, outHash);
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::link(
+ IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->link(outLinkedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->renameEntryPoint(newName, outEntryPoint);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult ModuleCapture::linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualModule->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics);
+ return res;
+ }
+
+ EntryPointCapture* ModuleCapture::getEntryPointCapture(slang::IEntryPoint* entryPoint)
+ {
+ EntryPointCapture* entryPointCapture = nullptr;
+ entryPointCapture = m_mapEntryPointToCapture.tryGetValue(entryPoint);
+ if (!entryPointCapture)
+ {
+ entryPointCapture = new EntryPointCapture(entryPoint);
+ Slang::ComPtr<EntryPointCapture> result(entryPointCapture);
+ m_mapEntryPointToCapture.add(entryPoint, *result.detach());
+ }
+ return entryPointCapture;
+ }
+}
diff --git a/source/slang-capture-replay/slang-module.h b/source/slang-capture-replay/slang-module.h
new file mode 100644
index 000000000..d1180c828
--- /dev/null
+++ b/source/slang-capture-replay/slang-module.h
@@ -0,0 +1,91 @@
+#ifndef SLANG_MODULE_H
+#define SLANG_MODULE_H
+
+#include "../../slang-com-ptr.h"
+#include "../../slang.h"
+#include "../../slang-com-helper.h"
+#include "../core/slang-smart-pointer.h"
+#include "../slang/slang-compiler.h"
+#include "slang-entrypoint.h"
+
+namespace SlangCapture
+{
+ using namespace Slang;
+ class ModuleCapture : public slang::IModule, public RefObject
+ {
+ public:
+ SLANG_COM_INTERFACE(0xb1802991, 0x185a, 0x4a03, { 0xa7, 0x7e, 0x0c, 0x86, 0xe0, 0x68, 0x2a, 0xab })
+
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+ ISlangUnknown* getInterface(const Guid& guid);
+
+ explicit ModuleCapture(slang::IModule* module);
+ ~ModuleCapture();
+
+ // Interfaces for `IModule`
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL findEntryPointByName(
+ char const* name, slang::IEntryPoint** outEntryPoint) override;
+ virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL
+ getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override;
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override;
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override;
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint(
+ char const* name, SlangStage stage, slang::IEntryPoint** outEntryPoint, ISlangBlob** outDiagnostics) override;
+
+ // Interfaces for `IComponentType`
+ virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;
+ virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
+ SlangInt targetIndex = 0,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics = 0) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics = nullptr) override;
+
+ slang::IModule* getActualModule() const { return m_actualModule; }
+ private:
+ EntryPointCapture* getEntryPointCapture(slang::IEntryPoint* entryPoint);
+ Slang::ComPtr<slang::IModule> m_actualModule;
+
+ // `IEntryPoint` can only be created from 'IModule', so we need to capture it in
+ // this class, and create a map such that we don't create new `EntryPointCapture`
+ // for the same `IEntryPoint`.
+ Dictionary<slang::IEntryPoint*, EntryPointCapture> m_mapEntryPointToCapture;
+ };
+} // namespace SlangCapture
+
+#endif // SLANG_MODULE_H
diff --git a/source/slang-capture-replay/slang-session.cpp b/source/slang-capture-replay/slang-session.cpp
index 3d96af6d7..5cf029629 100644
--- a/source/slang-capture-replay/slang-session.cpp
+++ b/source/slang-capture-replay/slang-session.cpp
@@ -1,5 +1,8 @@
#include "capture_utility.h"
#include "slang-session.h"
+#include "slang-entrypoint.h"
+#include "slang-composite-component-type.h"
+#include "slang-type-conformance.h"
namespace SlangCapture
{
@@ -7,8 +10,8 @@ namespace SlangCapture
SessionCapture::SessionCapture(slang::ISession* session)
: m_actualSession(session)
{
- assert(m_actualSession);
- slangCaptureLog(LogLevel::Verbose, "%s: %p\n", "Session", session);
+ SLANG_CAPTURE_ASSERT(m_actualSession);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", "SessionCapture create:", session);
}
SessionCapture::~SessionCapture()
@@ -26,7 +29,7 @@ namespace SlangCapture
SLANG_NO_THROW slang::IGlobalSession* SessionCapture::getGlobalSession()
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::IGlobalSession* pGlobalSession = m_actualSession->getGlobalSession();
return pGlobalSession;
}
@@ -35,9 +38,10 @@ namespace SlangCapture
const char* moduleName,
slang::IBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::IModule* pModule = m_actualSession->loadModule(moduleName, outDiagnostics);
- return pModule;
+ ModuleCapture* pModuleCapture = getModuleCapture(pModule);
+ return static_cast<slang::IModule*>(pModuleCapture);
}
SLANG_NO_THROW slang::IModule* SessionCapture::loadModuleFromIRBlob(
@@ -46,9 +50,10 @@ namespace SlangCapture
slang::IBlob* source,
slang::IBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::IModule* pModule = m_actualSession->loadModuleFromIRBlob(moduleName, path, source, outDiagnostics);
- return pModule;
+ ModuleCapture* pModuleCapture = getModuleCapture(pModule);
+ return static_cast<slang::IModule*>(pModuleCapture);
}
SLANG_NO_THROW slang::IModule* SessionCapture::loadModuleFromSource(
@@ -57,9 +62,10 @@ namespace SlangCapture
slang::IBlob* source,
slang::IBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::IModule* pModule = m_actualSession->loadModuleFromSource(moduleName, path, source, outDiagnostics);
- return pModule;
+ ModuleCapture* pModuleCapture = getModuleCapture(pModule);
+ return static_cast<slang::IModule*>(pModuleCapture);
}
SLANG_NO_THROW slang::IModule* SessionCapture::loadModuleFromSourceString(
@@ -68,9 +74,10 @@ namespace SlangCapture
const char* string,
slang::IBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::IModule* pModule = m_actualSession->loadModuleFromSourceString(moduleName, path, string, outDiagnostics);
- return pModule;
+ ModuleCapture* pModuleCapture = getModuleCapture(pModule);
+ return static_cast<slang::IModule*>(pModuleCapture);
}
SLANG_NO_THROW SlangResult SessionCapture::createCompositeComponentType(
@@ -79,8 +86,26 @@ namespace SlangCapture
slang::IComponentType** outCompositeComponentType,
ISlangBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
- SlangResult result = m_actualSession->createCompositeComponentType(componentTypes, componentTypeCount, outCompositeComponentType, outDiagnostics);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+
+ Slang::List<slang::IComponentType*> componentTypeList;
+
+ // get the actual component types from our capture wrappers
+ if(SLANG_OK != getActualComponentTypes(componentTypes, componentTypeCount, componentTypeList))
+ {
+ SLANG_CAPTURE_ASSERT(!"Failed to get actual component types");
+ }
+
+ SlangResult result = m_actualSession->createCompositeComponentType(
+ componentTypeList.getBuffer(), componentTypeCount, outCompositeComponentType, outDiagnostics);
+
+ if (SLANG_OK == result)
+ {
+ CompositeComponentTypeCapture* compositeComponentTypeCapture = new CompositeComponentTypeCapture(*outCompositeComponentType);
+ Slang::ComPtr<CompositeComponentTypeCapture> resultCapture(compositeComponentTypeCapture);
+ *outCompositeComponentType = resultCapture.detach();
+ }
+
return result;
}
@@ -90,7 +115,7 @@ namespace SlangCapture
SlangInt specializationArgCount,
ISlangBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::TypeReflection* pTypeReflection = m_actualSession->specializeType(type, specializationArgs, specializationArgCount, outDiagnostics);
return pTypeReflection;
}
@@ -101,7 +126,7 @@ namespace SlangCapture
slang::LayoutRules rules,
ISlangBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::TypeLayoutReflection* pTypeLayoutReflection = m_actualSession->getTypeLayout(type, targetIndex, rules, outDiagnostics);
return pTypeLayoutReflection;
}
@@ -111,14 +136,14 @@ namespace SlangCapture
slang::ContainerType containerType,
ISlangBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::TypeReflection* pTypeReflection = m_actualSession->getContainerType(elementType, containerType, outDiagnostics);
return pTypeReflection;
}
SLANG_NO_THROW slang::TypeReflection* SessionCapture::getDynamicType()
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::TypeReflection* pTypeReflection = m_actualSession->getDynamicType();
return pTypeReflection;
}
@@ -127,7 +152,7 @@ namespace SlangCapture
slang::TypeReflection* type,
ISlangBlob** outNameBlob)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult result = m_actualSession->getTypeRTTIMangledName(type, outNameBlob);
return result;
}
@@ -137,7 +162,7 @@ namespace SlangCapture
slang::TypeReflection* interfaceType,
ISlangBlob** outNameBlob)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult result = m_actualSession->getTypeConformanceWitnessMangledName(type, interfaceType, outNameBlob);
return result;
}
@@ -147,7 +172,7 @@ namespace SlangCapture
slang::TypeReflection* interfaceType,
uint32_t* outId)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult result = m_actualSession->getTypeConformanceWitnessSequentialID(type, interfaceType, outId);
return result;
}
@@ -159,38 +184,104 @@ namespace SlangCapture
SlangInt conformanceIdOverride,
ISlangBlob** outDiagnostics)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+
SlangResult result = m_actualSession->createTypeConformanceComponentType(type, interfaceType, outConformance, conformanceIdOverride, outDiagnostics);
+
+ if (SLANG_OK != result)
+ {
+ TypeConformanceCapture* conformanceCapture = new TypeConformanceCapture(*outConformance);
+ Slang::ComPtr<TypeConformanceCapture> resultCapture(conformanceCapture);
+ *outConformance = resultCapture.detach();
+ }
+
return result;
}
SLANG_NO_THROW SlangResult SessionCapture::createCompileRequest(
SlangCompileRequest** outCompileRequest)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult result = m_actualSession->createCompileRequest(outCompileRequest);
return result;
}
SLANG_NO_THROW SlangInt SessionCapture::getLoadedModuleCount()
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangInt count = m_actualSession->getLoadedModuleCount();
return count;
}
SLANG_NO_THROW slang::IModule* SessionCapture::getLoadedModule(SlangInt index)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
slang::IModule* pModule = m_actualSession->getLoadedModule(index);
+
+ if (pModule)
+ {
+ ModuleCapture* moduleCapture = m_mapModuleToCapture.tryGetValue(pModule);
+ if (!moduleCapture)
+ {
+ SLANG_CAPTURE_ASSERT(!"Module not found in mapModuleToCapture");
+ }
+ return static_cast<slang::IModule*>(moduleCapture);
+ }
+
return pModule;
}
SLANG_NO_THROW bool SessionCapture::isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob)
{
- slangCaptureLog(LogLevel::Verbose, "%s\n", __func__);
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
bool result = m_actualSession->isBinaryModuleUpToDate(modulePath, binaryModuleBlob);
return result;
}
+ ModuleCapture* SessionCapture::getModuleCapture(slang::IModule* module)
+ {
+ ModuleCapture* moduleCapture = nullptr;
+ moduleCapture = m_mapModuleToCapture.tryGetValue(module);
+ if (!moduleCapture)
+ {
+ moduleCapture = new ModuleCapture(module);
+ Slang::ComPtr<ModuleCapture> result(moduleCapture);
+ m_mapModuleToCapture.add(module, *result.detach());
+ }
+ return moduleCapture;
+ }
+
+ SlangResult SessionCapture::getActualComponentTypes(
+ slang::IComponentType* const* componentTypes,
+ SlangInt componentTypeCount,
+ List<slang::IComponentType*>& outActualComponentTypes)
+ {
+ for (SlangInt i = 0; i < componentTypeCount; i++)
+ {
+ slang::IComponentType* const& componentType = componentTypes[i];
+ void* outObj = nullptr;
+
+ if (componentType->queryInterface(ModuleCapture::getTypeGuid(), &outObj) == SLANG_OK)
+ {
+ ModuleCapture* moduleCapture = static_cast<ModuleCapture*>(outObj);
+ outActualComponentTypes.add(moduleCapture->getActualModule());
+ }
+ else if (componentType->queryInterface(EntryPointCapture::getTypeGuid(), &outObj) == SLANG_OK)
+ {
+ EntryPointCapture* entrypointCapture = static_cast<EntryPointCapture*>(outObj);
+ outActualComponentTypes.add(entrypointCapture->getActualEntryPoint());
+ }
+ // will fall back to the actual component type, it means that we didn't capture this type.
+ else
+ {
+ outActualComponentTypes.add(componentType);
+ }
+ }
+
+ if (componentTypeCount == outActualComponentTypes.getCount())
+ {
+ return SLANG_OK;
+ }
+ return SLANG_FAIL;
+ }
} // namespace SlangCapture
diff --git a/source/slang-capture-replay/slang-session.h b/source/slang-capture-replay/slang-session.h
index 24903a1b2..f099108b0 100644
--- a/source/slang-capture-replay/slang-session.h
+++ b/source/slang-capture-replay/slang-session.h
@@ -5,7 +5,9 @@
#include "../../slang.h"
#include "../../slang-com-helper.h"
#include "../core/slang-smart-pointer.h"
+#include "../core/slang-dictionary.h"
#include "../slang/slang-compiler.h"
+#include "slang-module.h"
namespace SlangCapture
{
@@ -92,7 +94,19 @@ namespace SlangCapture
{
return static_cast<slang::ISession*>(session);
}
+
+ // The IComponentType object is the capture target, therefore `componentTypes` will not be
+ // the actual component types, we have to use the COM interface to get the actual objects.
+ SlangResult getActualComponentTypes(
+ slang::IComponentType* const* componentTypes,
+ SlangInt componentTypeCount,
+ List<slang::IComponentType*>& outActualComponentTypes);
+
+ ModuleCapture* getModuleCapture(slang::IModule* module);
+
Slang::ComPtr<slang::ISession> m_actualSession;
+
+ Dictionary<slang::IModule*, ModuleCapture> m_mapModuleToCapture;
};
}
diff --git a/source/slang-capture-replay/slang-type-conformance.cpp b/source/slang-capture-replay/slang-type-conformance.cpp
new file mode 100644
index 000000000..1ca9cf737
--- /dev/null
+++ b/source/slang-capture-replay/slang-type-conformance.cpp
@@ -0,0 +1,131 @@
+#include "capture_utility.h"
+#include "slang-type-conformance.h"
+
+namespace SlangCapture
+{
+ TypeConformanceCapture::TypeConformanceCapture(slang::ITypeConformance* typeConformance)
+ : m_actualTypeConformance(typeConformance)
+ {
+ SLANG_CAPTURE_ASSERT(m_actualTypeConformance != nullptr);
+ slangCaptureLog(LogLevel::Verbose, "%s: %p\n", __PRETTY_FUNCTION__, typeConformance);
+ }
+ TypeConformanceCapture::~TypeConformanceCapture()
+ {
+ m_actualTypeConformance->release();
+ }
+
+ ISlangUnknown* TypeConformanceCapture::getInterface(const Guid& guid)
+ {
+ if (guid == TypeConformanceCapture::getTypeGuid())
+ {
+ return static_cast<ISlangUnknown*>(this);
+ }
+ else
+ {
+ return nullptr;
+ }
+ }
+
+ SLANG_NO_THROW slang::ISession* TypeConformanceCapture::getSession()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ISession* res = m_actualTypeConformance->getSession();
+ return res;
+ }
+
+ SLANG_NO_THROW slang::ProgramLayout* TypeConformanceCapture::getLayout(
+ SlangInt targetIndex,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::ProgramLayout* res = m_actualTypeConformance->getLayout(targetIndex, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangInt TypeConformanceCapture::getSpecializationParamCount()
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangInt res = m_actualTypeConformance->getSpecializationParamCount();
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem);
+ return res;
+ }
+
+ SLANG_NO_THROW void TypeConformanceCapture::getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ m_actualTypeConformance->getEntryPointHash(entryPointIndex, targetIndex, outHash);
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->specialize(specializationArgs, specializationArgCount, outSpecializedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->link(outLinkedComponentType, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->renameEntryPoint(newName, outEntryPoint);
+ return res;
+ }
+
+ SLANG_NO_THROW SlangResult TypeConformanceCapture::linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics)
+ {
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ SlangResult res = m_actualTypeConformance->linkWithOptions(outLinkedComponentType, compilerOptionEntryCount, compilerOptionEntries, outDiagnostics);
+ return res;
+ }
+}
diff --git a/source/slang-capture-replay/slang-type-conformance.h b/source/slang-capture-replay/slang-type-conformance.h
new file mode 100644
index 000000000..e2a7b27c9
--- /dev/null
+++ b/source/slang-capture-replay/slang-type-conformance.h
@@ -0,0 +1,70 @@
+#ifndef SLANG_TYPE_CONFORMANCE_H
+#define SLANG_TYPE_CONFORMANCE_H
+
+#include "../../slang-com-ptr.h"
+#include "../../slang.h"
+#include "../../slang-com-helper.h"
+#include "../core/slang-smart-pointer.h"
+#include "../core/slang-dictionary.h"
+#include "../slang/slang-compiler.h"
+
+namespace SlangCapture
+{
+ using namespace Slang;
+ class TypeConformanceCapture: public slang::ITypeConformance, public RefObject
+ {
+ public:
+ SLANG_COM_INTERFACE(0x0e67d05d, 0xee0a, 0x41e1, { 0xb5, 0xa3, 0x23, 0xe3, 0xb0, 0xec, 0x33, 0xf1 })
+
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+ ISlangUnknown* getInterface(const Guid& guid);
+
+ explicit TypeConformanceCapture(slang::ITypeConformance* typeConformance);
+ ~TypeConformanceCapture();
+
+ // Interfaces for `IComponentType`
+ virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;
+ virtual SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout(
+ SlangInt targetIndex = 0,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outCode,
+ slang::IBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ ISlangMutableFileSystem** outFileSystem) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ slang::IBlob** outHash) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
+ slang::SpecializationArg const* specializationArgs,
+ SlangInt specializationArgCount,
+ slang::IComponentType** outSpecializedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL link(
+ slang::IComponentType** outLinkedComponentType,
+ ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
+ int entryPointIndex,
+ int targetIndex,
+ ISlangSharedLibrary** outSharedLibrary,
+ slang::IBlob** outDiagnostics = 0) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint(
+ const char* newName, IComponentType** outEntryPoint) override;
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions(
+ IComponentType** outLinkedComponentType,
+ uint32_t compilerOptionEntryCount,
+ slang::CompilerOptionEntry* compilerOptionEntries,
+ ISlangBlob** outDiagnostics = nullptr) override;
+
+ slang::ITypeConformance* getActualTypeConformance() const { return m_actualTypeConformance; }
+ private:
+ Slang::ComPtr<slang::ITypeConformance> m_actualTypeConformance;
+ };
+}
+#endif // SLANG_TYPE_CONFORMANCE_H
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index fc97a2f47..770449de7 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1335,6 +1335,11 @@ namespace Slang
char const* name,
slang::IEntryPoint** outEntryPoint) SLANG_OVERRIDE
{
+ if (outEntryPoint == nullptr)
+ {
+ return SLANG_E_INVALID_ARG;
+ }
+
ComPtr<slang::IEntryPoint> entryPoint(findEntryPointByName(UnownedStringSlice(name)));
if((!entryPoint))
return SLANG_FAIL;
@@ -1349,6 +1354,11 @@ namespace Slang
slang::IEntryPoint** outEntryPoint,
ISlangBlob** outDiagnostics) override
{
+ if (outEntryPoint == nullptr)
+ {
+ return SLANG_E_INVALID_ARG;
+ }
+
ComPtr<slang::IEntryPoint> entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics));
if ((!entryPoint))
return SLANG_FAIL;
@@ -1367,6 +1377,11 @@ namespace Slang
if (index < 0 || index >= m_entryPoints.getCount())
return SLANG_E_INVALID_ARG;
+ if (outEntryPoint == nullptr)
+ {
+ return SLANG_E_INVALID_ARG;
+ }
+
ComPtr<slang::IEntryPoint> entryPoint(m_entryPoints[index].Ptr());
*outEntryPoint = entryPoint.detach();
return SLANG_OK;
@@ -1541,7 +1556,7 @@ namespace Slang
// List of source files this module depends on
FileDependencyList m_fileDependencyList;
- // Entry points that were defined in thsi module
+ // Entry points that were defined in this module
//
// Note: the entry point defined in the module are *not*
// part of the memory image/layout of the module when
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 4d83823d2..c6af8b34d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1266,6 +1266,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType(
slang::IComponentType** outCompositeComponentType,
ISlangBlob** outDiagnostics)
{
+ if (outCompositeComponentType == nullptr)
+ return SLANG_E_INVALID_ARG;
+
SLANG_AST_BUILDER_RAII(getASTBuilder());
// Attempting to create a "composite" of just one component type should
@@ -1491,6 +1494,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy
SlangInt conformanceIdOverride,
ISlangBlob** outDiagnostics)
{
+ if (outConformanceComponentType == nullptr)
+ return SLANG_E_INVALID_ARG;
+
SLANG_AST_BUILDER_RAII(getASTBuilder());
RefPtr<TypeConformance> result;