summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2022-07-18 15:44:29 -0400
committerGitHub <noreply@github.com>2022-07-18 15:44:29 -0400
commit2e4b5770fa7e6dbf56845382706b33a22d6a025b (patch)
treee58bcb7c446fcbffc8e1dd725636dbd737d946fc /source
parentda8f050f9865635e5778589e1e3883b614d73266 (diff)
Atomic ref counting for ISlangSharedLibrary (#2332)
* #include an absolute path didn't work - because paths were taken to always be relative. * Make ISlangSharedLibrary atomic ref counted. Update docs to say most COM interfaces are *not* atomic ref counted. * Upgrade slang-llvm to use version that atomic ref counts ISlangSharedLibrary. * Fix some typos in docs. * Fix ref count typo. * Fix missing 'override'
Diffstat (limited to 'source')
-rw-r--r--source/compiler-core/slang-downstream-compiler.cpp13
-rw-r--r--source/core/slang-com-object.h68
-rw-r--r--source/core/slang-shared-library.cpp5
-rw-r--r--source/core/slang-shared-library.h12
4 files changed, 86 insertions, 12 deletions
diff --git a/source/compiler-core/slang-downstream-compiler.cpp b/source/compiler-core/slang-downstream-compiler.cpp
index 85726557e..2e30c23bc 100644
--- a/source/compiler-core/slang-downstream-compiler.cpp
+++ b/source/compiler-core/slang-downstream-compiler.cpp
@@ -360,11 +360,16 @@ SlangResult CommandLineDownstreamCompileResult::getHostCallableSharedLibrary(Com
{
return SLANG_FAIL;
}
- // The shared library needs to keep temp files in scope
- RefPtr<TemporarySharedLibrary> sharedLib(new TemporarySharedLibrary(handle, m_moduleFilePath));
- sharedLib->m_temporaryFileSet = m_temporaryFiles;
+
+ {
+ // The shared library needs to keep temp files in scope
+ auto temporarySharedLibrary = new TemporarySharedLibrary(handle, m_moduleFilePath);
+ // Make sure it gets a ref count
+ m_hostCallableSharedLibrary = temporarySharedLibrary;
+ // Set any additional info on the non COM pointer
+ temporarySharedLibrary->m_temporaryFileSet = m_temporaryFiles;
+ }
- m_hostCallableSharedLibrary = sharedLib;
outLibrary = m_hostCallableSharedLibrary;
return SLANG_OK;
}
diff --git a/source/core/slang-com-object.h b/source/core/slang-com-object.h
index 50dac5ba2..617b7ccca 100644
--- a/source/core/slang-com-object.h
+++ b/source/core/slang-com-object.h
@@ -6,6 +6,74 @@
namespace Slang
{
+
+/// A base class for COM interfaces that require atomic ref counting
+/// and are *NOT* derived from RefObject
+class ComBaseObject
+{
+public:
+
+ /// If assigned the the ref count is *NOT* copied
+ ComBaseObject& operator=(const ComBaseObject&) { return *this; }
+
+ /// Copy Ctor, does not copy ref count
+ ComBaseObject(const ComBaseObject&) :
+ m_refCount(0)
+ {}
+
+ /// Default Ctor sets with no refs
+ ComBaseObject()
+ : m_refCount(0)
+ {}
+
+ /// Dtor needs to be virtual to avoid needing to
+ /// Implement release for all derived types.
+ virtual ~ComBaseObject()
+ {}
+
+protected:
+ inline uint32_t _releaseImpl();
+
+ std::atomic<uint32_t> m_refCount;
+};
+
+// ------------------------------------------------------------------
+inline uint32_t ComBaseObject::_releaseImpl()
+{
+ // Check there is a ref count to avoid underflow
+ SLANG_ASSERT(m_refCount != 0);
+ const uint32_t count = --m_refCount;
+ if (count == 0)
+ {
+ delete this;
+ }
+ return count;
+}
+
+#define SLANG_COM_BASE_IUNKNOWN_QUERY_INTERFACE \
+ SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) \
+ SLANG_OVERRIDE \
+ { \
+ void* intf = getInterface(uuid); \
+ if (intf) \
+ { \
+ ++m_refCount; \
+ *outObject = intf; \
+ return SLANG_OK; \
+ } \
+ return SLANG_E_NO_INTERFACE; \
+ }
+#define SLANG_COM_BASE_IUNKNOWN_ADD_REF \
+ SLANG_NO_THROW uint32_t SLANG_MCALL addRef() SLANG_OVERRIDE { return ++m_refCount; }
+#define SLANG_COM_BASE_IUNKNOWN_RELEASE \
+ SLANG_NO_THROW uint32_t SLANG_MCALL release() SLANG_OVERRIDE { return _releaseImpl(); }
+#define SLANG_COM_BASE_IUNKNOWN_ALL \
+ SLANG_COM_BASE_IUNKNOWN_QUERY_INTERFACE \
+ SLANG_COM_BASE_IUNKNOWN_ADD_REF \
+ SLANG_COM_BASE_IUNKNOWN_RELEASE
+
+
+/// COM object that derives from RefObject
class ComObject : public RefObject
{
protected:
diff --git a/source/core/slang-shared-library.cpp b/source/core/slang-shared-library.cpp
index 1513b420f..f31c7d689 100644
--- a/source/core/slang-shared-library.cpp
+++ b/source/core/slang-shared-library.cpp
@@ -78,6 +78,8 @@ TemporarySharedLibrary::~TemporarySharedLibrary()
SLANG_NO_THROW SlangResult SLANG_MCALL DefaultSharedLibrary::queryInterface(SlangUUID const& uuid, void** outObject)
{
+ // Mechanism to cast to underlying type.
+ // NOTE! Purposefully does not ref count
if (uuid == DefaultSharedLibrary::getTypeGuid())
{
*outObject = this;
@@ -86,7 +88,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL DefaultSharedLibrary::queryInterface(Slan
if (uuid == ISlangUnknown::getTypeGuid() || uuid == ISlangSharedLibrary::getTypeGuid())
{
- addReference();
+ ++m_refCount;
*outObject = static_cast<ISlangSharedLibrary*>(this);
return SLANG_OK;
}
@@ -106,7 +108,6 @@ void* DefaultSharedLibrary::findSymbolAddressByName(char const* name)
return SharedLibrary::findSymbolAddressByName(m_sharedLibraryHandle, name);
}
-
String SharedLibraryUtils::getSharedLibraryFileName(void* symbolInLib)
{
#if defined(_WIN32)
diff --git a/source/core/slang-shared-library.h b/source/core/slang-shared-library.h
index 452379d68..44adb1ac6 100644
--- a/source/core/slang-shared-library.h
+++ b/source/core/slang-shared-library.h
@@ -5,6 +5,7 @@
#include "../../slang-com-helper.h"
#include "../../slang-com-ptr.h"
+#include "../core/slang-com-object.h"
#include "../core/slang-io.h"
#include "../core/slang-platform.h"
#include "../core/slang-common.h"
@@ -45,16 +46,16 @@ private:
static DefaultSharedLibraryLoader s_singleton;
};
-class DefaultSharedLibrary : public ISlangSharedLibrary, public RefObject
+class DefaultSharedLibrary : public ISlangSharedLibrary, public ComBaseObject
{
public:
SLANG_CLASS_GUID(0xe7f2597b, 0xf803, 0x4b6e, { 0xaf, 0x8b, 0xcb, 0xe3, 0xa2, 0x21, 0xfd, 0x5a })
// ISlangUnknown
- virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE;
- SLANG_REF_OBJECT_IUNKNOWN_ADD_REF
- SLANG_REF_OBJECT_IUNKNOWN_RELEASE
-
+ SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE;
+ SLANG_COM_BASE_IUNKNOWN_ADD_REF
+ SLANG_COM_BASE_IUNKNOWN_RELEASE
+
// ISlangSharedLibrary
virtual SLANG_NO_THROW void* SLANG_MCALL findSymbolAddressByName(char const* name) SLANG_OVERRIDE;
@@ -69,7 +70,6 @@ class DefaultSharedLibrary : public ISlangSharedLibrary, public RefObject
virtual ~DefaultSharedLibrary();
protected:
- ISlangUnknown* getInterface(const Guid& guid);
SharedLibrary::Handle m_sharedLibraryHandle = nullptr;
};