summaryrefslogtreecommitdiff
path: root/tools/gfx/renderer-shared.h
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-04-08 21:10:30 -0700
committerGitHub <noreply@github.com>2021-04-08 21:10:30 -0700
commit8a71039475212fb1e1a6dd2fd2911d02769637ef (patch)
tree0faa6e773d6b40c3dcbf0eed08217c629f8ebccf /tools/gfx/renderer-shared.h
parentd27557d9b770810402a0bf99bcd891c145a1a69d (diff)
Improve robustness of gfx lifetime management. (#1788)
* Improve robustness of gfx lifetime management. * fix clang error * fix clang error * Fix clang warning
Diffstat (limited to 'tools/gfx/renderer-shared.h')
-rw-r--r--tools/gfx/renderer-shared.h191
1 files changed, 172 insertions, 19 deletions
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 79c965631..ec33a3054 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -3,6 +3,7 @@
#include "slang-gfx.h"
#include "slang-context.h"
#include "core/slang-basic.h"
+#include "core/slang-com-object.h"
namespace gfx
{
@@ -34,9 +35,151 @@ struct GfxGUID
static const Slang::Guid IID_ICommandQueue;
};
+// We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation.
+// It is a common scenario where objects created from an `IDevice` implementation needs to hold
+// a strong reference to the device object that creates them. For example, a `Buffer` or a
+// `CommandQueue` needs to store a `m_device` member that points to the `IDevice`. At the same
+// time, the device implementation may also hold a reference to some of the objects it created
+// to represent the current device/binding state. Both parties would like to maintain a strong
+// reference to each other to achieve robustness against arbitrary ordering of destruction that
+// can be triggered by the user. However this creates cyclic reference situations that break
+// the `RefPtr` recyling mechanism. To solve this problem, we instead make each object reference
+// the device via a `BreakableReference<TDeviceImpl>` pointer. A breakable reference can be
+// turned into a weak reference via its `breakStrongReference()` call.
+// If we know there is a cyclic reference between an API object and the device/pool that creates it,
+// we can break the cycle when there is no longer any public references that come from `ComPtr`s to
+// the API object, by turning the reference to the device object from the API object to a weak
+// reference.
+// The following example illustrate how this mechanism works:
+// Suppose we have
+// ```
+// class DeviceImpl : IDevice { RefPtr<ShaderObject> m_currentObject; };
+// class ShaderObjectImpl : IShaderObject { BreakableReference<DeviceImpl> m_device; };
+// ```
+// And the user creates a device and a shader object, then somehow having the device reference
+// the shader object (this may not happen in actual implemetations, we just use it to illustrate
+// the situation):
+// ```
+// ComPtr<IDevice> device = createDevice();
+// ComPtr<ISomeResource> res = device->createResourceX(...);
+// device->m_currentResource = res;
+// ```
+// This setup is robust to any destruction ordering. If user releases reference to `device` first,
+// then the device object will not be freed yet, since there is still a strong reference to the device
+// implementation via `res->m_device`. Next when the user releases reference to `res`, the public
+// reference count to `res` via `ComPtr`s will go to 0, therefore triggering the call to
+// `res->m_device.breakStrongReference()`, releasing the remaining reference to device. This will cause
+// `device` to start destruction, which will release its strong reference to `res` during execution of
+// its destructor. Finally, this will triger the actual destruction of `res`.
+// On the other hand, if the user releases reference to `res` first, then the strong reference to `device`
+// will be broken immediately, but the actual destruction of `res` will not start. Next when the user
+// releases `device`, there will no longer be any other references to `device`, so the destruction of
+// `device` will start, causing the release of the internal reference to `res`, leading to its destruction.
+// Note that the above logic only works if it is known that there is a cyclic reference. If there are no
+// such cyclic reference, then it will be incorrect to break the strong reference to `IDevice` upon
+// public reference counter dropping to 0. This is because the actual destructor of `res` take place
+// after breaking the cycle, but if the resource's strong reference to the device is already the last reference,
+// turning that reference to weak reference will immediately trigger destruction of `device`, after which
+// we can no longer destruct `res` if the destructor needs `device`. Therefore we need to be careful
+// when using `BreakableReference`, and make sure we only call `breakStrongReference` only when it is known
+// that there is a cyclic reference. Luckily for all scenarios so far this is statically known.
+template<typename T>
+class BreakableReference
+{
+private:
+ Slang::RefPtr<T> m_strongPtr;
+ T* m_weakPtr = nullptr;
+
+public:
+ BreakableReference() = default;
+
+ BreakableReference(T* p) { *this = p; }
+
+ BreakableReference(Slang::RefPtr<T> const& p) { *this = p; }
+
+ void setWeakReference(T* p) { m_weakPtr = p; m_strongPtr = nullptr; }
+
+ T& operator*() const { return *get(); }
+
+ T* operator->() const { return get(); }
+
+ T* get() const { return m_weakPtr; }
+
+ operator T*() const { return get(); }
+
+ void operator=(Slang::RefPtr<T> const& p)
+ {
+ m_strongPtr = p;
+ m_weakPtr = p.Ptr();
+ }
+
+ void operator=(T* p)
+ {
+ m_strongPtr = p;
+ m_weakPtr = p;
+ }
+
+ void breakStrongReference() { m_strongPtr = nullptr; }
+
+ void establishStrongReference() { m_strongPtr = m_weakPtr; }
+};
+
+// Helpers for returning an object implementation as COM pointer.
+template<typename TInterface, typename TImpl>
+void returnComPtr(TInterface** outInterface, TImpl* rawPtr)
+{
+ static_assert(
+ !std::is_base_of<Slang::RefObject, TInterface>::value,
+ "TInterface must be an interface type.");
+ rawPtr->addRef();
+ *outInterface = rawPtr;
+}
+
+template <typename TInterface, typename TImpl>
+void returnComPtr(TInterface** outInterface, const Slang::RefPtr<TImpl>& refPtr)
+{
+ static_assert(
+ !std::is_base_of<Slang::RefObject, TInterface>::value,
+ "TInterface must be an interface type.");
+ refPtr->addRef();
+ *outInterface = refPtr.Ptr();
+}
+
+template <typename TInterface, typename TImpl>
+void returnComPtr(TInterface** outInterface, Slang::ComPtr<TImpl>& comPtr)
+{
+ static_assert(
+ !std::is_base_of<Slang::RefObject, TInterface>::value,
+ "TInterface must be an interface type.");
+ *outInterface = comPtr.detach();
+}
+
+// Helpers for returning an object implementation as RefPtr.
+template <typename TDest, typename TImpl>
+void returnRefPtr(TDest** outPtr, Slang::RefPtr<TImpl>& refPtr)
+{
+ static_assert(
+ std::is_base_of<Slang::RefObject, TDest>::value, "TDest must be a non-interface type.");
+ static_assert(
+ std::is_base_of<Slang::RefObject, TImpl>::value, "TImpl must be a non-interface type.");
+ *outPtr = refPtr.Ptr();
+ refPtr->addReference();
+}
+
+template <typename TDest, typename TImpl>
+void returnRefPtrMove(TDest** outPtr, Slang::RefPtr<TImpl>& refPtr)
+{
+ static_assert(
+ std::is_base_of<Slang::RefObject, TDest>::value, "TDest must be a non-interface type.");
+ static_assert(
+ std::is_base_of<Slang::RefObject, TImpl>::value, "TImpl must be a non-interface type.");
+ *outPtr = refPtr.detach();
+}
+
+
gfx::StageType translateStage(SlangStage slangStage);
-class Resource : public Slang::RefObject
+class Resource : public Slang::ComObject
{
public:
/// Get the type
@@ -56,7 +199,7 @@ protected:
class BufferResource : public IBufferResource, public Resource
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
IResource* getInterface(const Slang::Guid& guid);
public:
@@ -78,7 +221,7 @@ protected:
class TextureResource : public ITextureResource, public Resource
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
IResource* getInterface(const Slang::Guid& guid);
public:
@@ -138,6 +281,9 @@ struct ExtendedShaderObjectTypeList
class ShaderObjectLayoutBase : public Slang::RefObject
{
protected:
+ // We always use a weak reference to the `IDevice` object here.
+ // `ShaderObject` implementations will make sure to hold a strong reference to `IDevice`
+ // while a `ShaderObjectLayout` may still be used.
RendererBase* m_renderer;
slang::TypeLayoutReflection* m_elementTypeLayout = nullptr;
ShaderComponentID m_componentID = 0;
@@ -182,9 +328,13 @@ public:
void initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout);
};
-class ShaderObjectBase : public IShaderObject, public Slang::RefObject
+class ShaderObjectBase : public IShaderObject, public Slang::ComObject
{
protected:
+ // A strong reference to `IDevice` to make sure the weak device reference in
+ // `ShaderObjectLayout`s are valid whenever they might be used.
+ BreakableReference<RendererBase> m_device;
+
// The shader object layout used to create this shader object.
Slang::RefPtr<ShaderObjectLayoutBase> m_layout = nullptr;
@@ -198,8 +348,9 @@ protected:
Result _getSpecializedShaderObjectType(ExtendedShaderObjectType* outType);
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
IShaderObject* getInterface(const Slang::Guid& guid);
+ void breakStrongReferenceToDevice() { m_device.breakStrongReference(); }
public:
ShaderComponentID getComponentID()
@@ -235,21 +386,21 @@ public:
virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) = 0;
};
-class ShaderProgramBase : public IShaderProgram, public Slang::RefObject
+class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
-
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
IShaderProgram* getInterface(const Slang::Guid& guid);
ComPtr<slang::IComponentType> slangProgram;
};
-class PipelineStateBase : public IPipelineState, public Slang::RefObject
+class PipelineStateBase
+ : public IPipelineState
+ , public Slang::ComObject
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
-
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
IPipelineState* getInterface(const Slang::Guid& guid);
struct PipelineStateDesc
@@ -270,10 +421,10 @@ public:
// Indicates whether this is a specializable pipeline. A specializable
// pipeline cannot be used directly and must be specialized first.
bool isSpecializable = false;
- ComPtr<IShaderProgram> m_program;
+ Slang::RefPtr<ShaderProgramBase> m_program;
template <typename TProgram> TProgram* getProgram()
{
- return static_cast<TProgram*>(m_program.get());
+ return static_cast<TProgram*>(m_program.Ptr());
}
protected:
@@ -360,14 +511,16 @@ public:
ShaderComponentID getComponentId(Slang::UnownedStringSlice name);
ShaderComponentID getComponentId(ComponentKey key);
- Slang::ComPtr<IPipelineState> getSpecializedPipelineState(PipelineKey programKey)
+ Slang::RefPtr<PipelineStateBase> getSpecializedPipelineState(PipelineKey programKey)
{
- Slang::ComPtr<IPipelineState> result;
+ Slang::RefPtr<PipelineStateBase> result;
if (specializedPipelines.TryGetValue(programKey, result))
return result;
return nullptr;
}
- void addSpecializedPipeline(PipelineKey key, Slang::ComPtr<IPipelineState> specializedPipeline);
+ void addSpecializedPipeline(
+ PipelineKey key,
+ Slang::RefPtr<PipelineStateBase> specializedPipeline);
void free()
{
specializedPipelines = decltype(specializedPipelines)();
@@ -376,16 +529,16 @@ public:
protected:
Slang::OrderedDictionary<OwningComponentKey, ShaderComponentID> componentIds;
- Slang::OrderedDictionary<PipelineKey, Slang::ComPtr<IPipelineState>> specializedPipelines;
+ Slang::OrderedDictionary<PipelineKey, Slang::RefPtr<PipelineStateBase>> specializedPipelines;
};
// Renderer implementation shared by all platforms.
// Responsible for shader compilation, specialization and caching.
-class RendererBase : public Slang::RefObject, public IDevice
+class RendererBase : public IDevice, public Slang::ComObject
{
friend class ShaderObjectBase;
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures(
const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE;