diff options
| author | Yong He <yonghe@outlook.com> | 2021-04-08 21:10:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-04-08 21:10:30 -0700 |
| commit | 8a71039475212fb1e1a6dd2fd2911d02769637ef (patch) | |
| tree | 0faa6e773d6b40c3dcbf0eed08217c629f8ebccf /tools/gfx/renderer-shared.h | |
| parent | d27557d9b770810402a0bf99bcd891c145a1a69d (diff) | |
Improve robustness of gfx lifetime management. (#1788)
* Improve robustness of gfx lifetime management.
* fix clang error
* fix clang error
* Fix clang warning
Diffstat (limited to 'tools/gfx/renderer-shared.h')
| -rw-r--r-- | tools/gfx/renderer-shared.h | 191 |
1 files changed, 172 insertions, 19 deletions
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 79c965631..ec33a3054 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -3,6 +3,7 @@ #include "slang-gfx.h" #include "slang-context.h" #include "core/slang-basic.h" +#include "core/slang-com-object.h" namespace gfx { @@ -34,9 +35,151 @@ struct GfxGUID static const Slang::Guid IID_ICommandQueue; }; +// We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation. +// It is a common scenario where objects created from an `IDevice` implementation needs to hold +// a strong reference to the device object that creates them. For example, a `Buffer` or a +// `CommandQueue` needs to store a `m_device` member that points to the `IDevice`. At the same +// time, the device implementation may also hold a reference to some of the objects it created +// to represent the current device/binding state. Both parties would like to maintain a strong +// reference to each other to achieve robustness against arbitrary ordering of destruction that +// can be triggered by the user. However this creates cyclic reference situations that break +// the `RefPtr` recyling mechanism. To solve this problem, we instead make each object reference +// the device via a `BreakableReference<TDeviceImpl>` pointer. A breakable reference can be +// turned into a weak reference via its `breakStrongReference()` call. +// If we know there is a cyclic reference between an API object and the device/pool that creates it, +// we can break the cycle when there is no longer any public references that come from `ComPtr`s to +// the API object, by turning the reference to the device object from the API object to a weak +// reference. +// The following example illustrate how this mechanism works: +// Suppose we have +// ``` +// class DeviceImpl : IDevice { RefPtr<ShaderObject> m_currentObject; }; +// class ShaderObjectImpl : IShaderObject { BreakableReference<DeviceImpl> m_device; }; +// ``` +// And the user creates a device and a shader object, then somehow having the device reference +// the shader object (this may not happen in actual implemetations, we just use it to illustrate +// the situation): +// ``` +// ComPtr<IDevice> device = createDevice(); +// ComPtr<ISomeResource> res = device->createResourceX(...); +// device->m_currentResource = res; +// ``` +// This setup is robust to any destruction ordering. If user releases reference to `device` first, +// then the device object will not be freed yet, since there is still a strong reference to the device +// implementation via `res->m_device`. Next when the user releases reference to `res`, the public +// reference count to `res` via `ComPtr`s will go to 0, therefore triggering the call to +// `res->m_device.breakStrongReference()`, releasing the remaining reference to device. This will cause +// `device` to start destruction, which will release its strong reference to `res` during execution of +// its destructor. Finally, this will triger the actual destruction of `res`. +// On the other hand, if the user releases reference to `res` first, then the strong reference to `device` +// will be broken immediately, but the actual destruction of `res` will not start. Next when the user +// releases `device`, there will no longer be any other references to `device`, so the destruction of +// `device` will start, causing the release of the internal reference to `res`, leading to its destruction. +// Note that the above logic only works if it is known that there is a cyclic reference. If there are no +// such cyclic reference, then it will be incorrect to break the strong reference to `IDevice` upon +// public reference counter dropping to 0. This is because the actual destructor of `res` take place +// after breaking the cycle, but if the resource's strong reference to the device is already the last reference, +// turning that reference to weak reference will immediately trigger destruction of `device`, after which +// we can no longer destruct `res` if the destructor needs `device`. Therefore we need to be careful +// when using `BreakableReference`, and make sure we only call `breakStrongReference` only when it is known +// that there is a cyclic reference. Luckily for all scenarios so far this is statically known. +template<typename T> +class BreakableReference +{ +private: + Slang::RefPtr<T> m_strongPtr; + T* m_weakPtr = nullptr; + +public: + BreakableReference() = default; + + BreakableReference(T* p) { *this = p; } + + BreakableReference(Slang::RefPtr<T> const& p) { *this = p; } + + void setWeakReference(T* p) { m_weakPtr = p; m_strongPtr = nullptr; } + + T& operator*() const { return *get(); } + + T* operator->() const { return get(); } + + T* get() const { return m_weakPtr; } + + operator T*() const { return get(); } + + void operator=(Slang::RefPtr<T> const& p) + { + m_strongPtr = p; + m_weakPtr = p.Ptr(); + } + + void operator=(T* p) + { + m_strongPtr = p; + m_weakPtr = p; + } + + void breakStrongReference() { m_strongPtr = nullptr; } + + void establishStrongReference() { m_strongPtr = m_weakPtr; } +}; + +// Helpers for returning an object implementation as COM pointer. +template<typename TInterface, typename TImpl> +void returnComPtr(TInterface** outInterface, TImpl* rawPtr) +{ + static_assert( + !std::is_base_of<Slang::RefObject, TInterface>::value, + "TInterface must be an interface type."); + rawPtr->addRef(); + *outInterface = rawPtr; +} + +template <typename TInterface, typename TImpl> +void returnComPtr(TInterface** outInterface, const Slang::RefPtr<TImpl>& refPtr) +{ + static_assert( + !std::is_base_of<Slang::RefObject, TInterface>::value, + "TInterface must be an interface type."); + refPtr->addRef(); + *outInterface = refPtr.Ptr(); +} + +template <typename TInterface, typename TImpl> +void returnComPtr(TInterface** outInterface, Slang::ComPtr<TImpl>& comPtr) +{ + static_assert( + !std::is_base_of<Slang::RefObject, TInterface>::value, + "TInterface must be an interface type."); + *outInterface = comPtr.detach(); +} + +// Helpers for returning an object implementation as RefPtr. +template <typename TDest, typename TImpl> +void returnRefPtr(TDest** outPtr, Slang::RefPtr<TImpl>& refPtr) +{ + static_assert( + std::is_base_of<Slang::RefObject, TDest>::value, "TDest must be a non-interface type."); + static_assert( + std::is_base_of<Slang::RefObject, TImpl>::value, "TImpl must be a non-interface type."); + *outPtr = refPtr.Ptr(); + refPtr->addReference(); +} + +template <typename TDest, typename TImpl> +void returnRefPtrMove(TDest** outPtr, Slang::RefPtr<TImpl>& refPtr) +{ + static_assert( + std::is_base_of<Slang::RefObject, TDest>::value, "TDest must be a non-interface type."); + static_assert( + std::is_base_of<Slang::RefObject, TImpl>::value, "TImpl must be a non-interface type."); + *outPtr = refPtr.detach(); +} + + gfx::StageType translateStage(SlangStage slangStage); -class Resource : public Slang::RefObject +class Resource : public Slang::ComObject { public: /// Get the type @@ -56,7 +199,7 @@ protected: class BufferResource : public IBufferResource, public Resource { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResource* getInterface(const Slang::Guid& guid); public: @@ -78,7 +221,7 @@ protected: class TextureResource : public ITextureResource, public Resource { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResource* getInterface(const Slang::Guid& guid); public: @@ -138,6 +281,9 @@ struct ExtendedShaderObjectTypeList class ShaderObjectLayoutBase : public Slang::RefObject { protected: + // We always use a weak reference to the `IDevice` object here. + // `ShaderObject` implementations will make sure to hold a strong reference to `IDevice` + // while a `ShaderObjectLayout` may still be used. RendererBase* m_renderer; slang::TypeLayoutReflection* m_elementTypeLayout = nullptr; ShaderComponentID m_componentID = 0; @@ -182,9 +328,13 @@ public: void initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout); }; -class ShaderObjectBase : public IShaderObject, public Slang::RefObject +class ShaderObjectBase : public IShaderObject, public Slang::ComObject { protected: + // A strong reference to `IDevice` to make sure the weak device reference in + // `ShaderObjectLayout`s are valid whenever they might be used. + BreakableReference<RendererBase> m_device; + // The shader object layout used to create this shader object. Slang::RefPtr<ShaderObjectLayoutBase> m_layout = nullptr; @@ -198,8 +348,9 @@ protected: Result _getSpecializedShaderObjectType(ExtendedShaderObjectType* outType); public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IShaderObject* getInterface(const Slang::Guid& guid); + void breakStrongReferenceToDevice() { m_device.breakStrongReference(); } public: ShaderComponentID getComponentID() @@ -235,21 +386,21 @@ public: virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) = 0; }; -class ShaderProgramBase : public IShaderProgram, public Slang::RefObject +class ShaderProgramBase : public IShaderProgram, public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - + SLANG_COM_OBJECT_IUNKNOWN_ALL IShaderProgram* getInterface(const Slang::Guid& guid); ComPtr<slang::IComponentType> slangProgram; }; -class PipelineStateBase : public IPipelineState, public Slang::RefObject +class PipelineStateBase + : public IPipelineState + , public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - + SLANG_COM_OBJECT_IUNKNOWN_ALL IPipelineState* getInterface(const Slang::Guid& guid); struct PipelineStateDesc @@ -270,10 +421,10 @@ public: // Indicates whether this is a specializable pipeline. A specializable // pipeline cannot be used directly and must be specialized first. bool isSpecializable = false; - ComPtr<IShaderProgram> m_program; + Slang::RefPtr<ShaderProgramBase> m_program; template <typename TProgram> TProgram* getProgram() { - return static_cast<TProgram*>(m_program.get()); + return static_cast<TProgram*>(m_program.Ptr()); } protected: @@ -360,14 +511,16 @@ public: ShaderComponentID getComponentId(Slang::UnownedStringSlice name); ShaderComponentID getComponentId(ComponentKey key); - Slang::ComPtr<IPipelineState> getSpecializedPipelineState(PipelineKey programKey) + Slang::RefPtr<PipelineStateBase> getSpecializedPipelineState(PipelineKey programKey) { - Slang::ComPtr<IPipelineState> result; + Slang::RefPtr<PipelineStateBase> result; if (specializedPipelines.TryGetValue(programKey, result)) return result; return nullptr; } - void addSpecializedPipeline(PipelineKey key, Slang::ComPtr<IPipelineState> specializedPipeline); + void addSpecializedPipeline( + PipelineKey key, + Slang::RefPtr<PipelineStateBase> specializedPipeline); void free() { specializedPipelines = decltype(specializedPipelines)(); @@ -376,16 +529,16 @@ public: protected: Slang::OrderedDictionary<OwningComponentKey, ShaderComponentID> componentIds; - Slang::OrderedDictionary<PipelineKey, Slang::ComPtr<IPipelineState>> specializedPipelines; + Slang::OrderedDictionary<PipelineKey, Slang::RefPtr<PipelineStateBase>> specializedPipelines; }; // Renderer implementation shared by all platforms. // Responsible for shader compilation, specialization and caching. -class RendererBase : public Slang::RefObject, public IDevice +class RendererBase : public IDevice, public Slang::ComObject { friend class ShaderObjectBase; public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures( const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE; |
