From 8a71039475212fb1e1a6dd2fd2911d02769637ef Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 8 Apr 2021 21:10:30 -0700 Subject: Improve robustness of gfx lifetime management. (#1788) * Improve robustness of gfx lifetime management. * fix clang error * fix clang error * Fix clang warning --- tools/gfx/cpu/render-cpu.cpp | 46 ++-- tools/gfx/cuda/render-cuda.cpp | 115 +++++----- tools/gfx/d3d/d3d-swapchain.h | 9 +- tools/gfx/d3d11/render-d3d11.cpp | 117 +++++----- tools/gfx/d3d12/descriptor-heap-d3d12.h | 3 +- tools/gfx/d3d12/render-d3d12.cpp | 251 ++++++++++---------- tools/gfx/immediate-renderer-base.cpp | 69 +++--- tools/gfx/immediate-renderer-base.h | 32 ++- tools/gfx/open-gl/render-gl.cpp | 110 ++++----- tools/gfx/renderer-shared.cpp | 18 +- tools/gfx/renderer-shared.h | 191 ++++++++++++++-- tools/gfx/simple-render-pass-layout.h | 6 +- tools/gfx/simple-transient-resource-heap.h | 10 +- tools/gfx/transient-resource-heap-base.h | 9 +- tools/gfx/vulkan/render-vk.cpp | 354 ++++++++++++++++------------- tools/platform/window.h | 2 +- tools/render-test/render-test-main.cpp | 20 +- 17 files changed, 806 insertions(+), 556 deletions(-) (limited to 'tools') diff --git a/tools/gfx/cpu/render-cpu.cpp b/tools/gfx/cpu/render-cpu.cpp index 8dbe5460a..8dd1ccae6 100644 --- a/tools/gfx/cpu/render-cpu.cpp +++ b/tools/gfx/cpu/render-cpu.cpp @@ -191,6 +191,7 @@ public: {} ~CPUTextureResource() { + free(m_data); } Result init(ITextureResource::SubresourceData const* initData) @@ -341,7 +342,7 @@ public: void* m_data = nullptr; }; -class CPUResourceView : public IResourceView, public RefObject +class CPUResourceView : public IResourceView, public ComObject { public: enum class Kind @@ -350,7 +351,7 @@ public: Texture, }; - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -799,9 +800,9 @@ public: auto& bindingRange = layout->m_bindingRanges[bindingRangeIndex]; auto subObjectIndex = bindingRange.baseIndex + offset.bindingArrayIndex; - CPUShaderObject* subObject = m_objects[subObjectIndex]; + auto& subObject = m_objects[subObjectIndex]; - *outObject = ComPtr(subObject).detach(); + returnComPtr(outObject, subObject); return SLANG_OK; } @@ -1043,6 +1044,12 @@ public: class CPURootShaderObject : public CPUShaderObject { +public: + // Override default reference counting behavior to disable lifetime management via ComPtr. + // Root objects are managed by command buffer and does not need to be freed by the user. + SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } + public: SlangResult init(IDevice* device, CPUProgramLayout* programLayout); @@ -1056,7 +1063,7 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) override { - *outEntryPoint = ComPtr(m_entryPoints[index]).detach(); + returnComPtr(outEntryPoint, m_entryPoints[index]); return SLANG_OK; } virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override @@ -1083,7 +1090,7 @@ public: class CPUPipelineState : public PipelineStateBase { public: - CPUShaderProgram* getProgram() { return static_cast(m_program.get()); } + CPUShaderProgram* getProgram() { return static_cast(m_program.Ptr()); } void init(const ComputePipelineStateDesc& inDesc) { @@ -1194,11 +1201,14 @@ public: const ITextureResource::SubresourceData* initData, ITextureResource** outResource) override { - RefPtr texture = new CPUTextureResource(desc); + TextureResource::Desc srcDesc(desc); + srcDesc.setDefaults(initialUsage); + + RefPtr texture = new CPUTextureResource(srcDesc); SLANG_RETURN_ON_FAIL(texture->init(initData)); - *outResource = texture.detach(); + returnComPtr(outResource, texture); return SLANG_OK; } @@ -1214,7 +1224,7 @@ public: { SLANG_RETURN_ON_FAIL(resource->setData(0, desc.sizeInBytes, initData)); } - *outResource = resource.detach(); + returnComPtr(outResource, resource); return SLANG_OK; } @@ -1223,7 +1233,7 @@ public: { auto texture = static_cast(inTexture); RefPtr view = new CPUTextureView(desc, texture); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -1232,7 +1242,7 @@ public: { auto buffer = static_cast(inBuffer); RefPtr view = new CPUBufferView(desc, buffer); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -1241,7 +1251,7 @@ public: ShaderObjectLayoutBase** outLayout) override { RefPtr cpuLayout = new CPUShaderObjectLayout(this, typeLayout); - *outLayout = cpuLayout.detach(); + returnRefPtrMove(outLayout, cpuLayout); return SLANG_OK; } @@ -1254,19 +1264,19 @@ public: RefPtr result = new CPUShaderObject(); SLANG_RETURN_ON_FAIL(result->init(this, cpuLayout)); - *outObject = result.detach(); + returnComPtr(outObject, result); return SLANG_OK; } - virtual Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override + virtual Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) override { auto cpuProgram = static_cast(program); auto cpuProgramLayout = cpuProgram->layout; RefPtr result = new CPURootShaderObject(); SLANG_RETURN_ON_FAIL(result->init(this, cpuProgramLayout)); - *outObject = result.detach(); + returnRefPtrMove(outObject, result); return SLANG_OK; } @@ -1292,7 +1302,7 @@ public: cpuProgram->layout = cpuProgramLayout; } - *outProgram = cpuProgram.detach(); + returnComPtr(outProgram, cpuProgram); return SLANG_OK; } @@ -1301,7 +1311,7 @@ public: { RefPtr state = new CPUPipelineState(); state->init(desc); - *outState = state.detach(); + returnComPtr(outState, state); return Result(); } @@ -1410,7 +1420,7 @@ SlangResult SLANG_MCALL createCPUDevice(const IDevice::Desc* desc, IDevice** out { RefPtr result = new CPUDevice(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outDevice = result.detach(); + returnComPtr(outDevice, result); return SLANG_OK; } diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index dbee0c5f2..f871e6246 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -181,6 +181,13 @@ void _optixLogCallback(unsigned int level, const char* tag, const char* message, # endif +class CUDAContext : public RefObject +{ +public: + CUcontext m_context = nullptr; + ~CUDAContext() { cuCtxDestroy(m_context); } +}; + class MemoryCUDAResource : public BufferResource { public: @@ -199,6 +206,8 @@ public: uint64_t getBindlessHandle() { return (uint64_t)m_cudaMemory; } void* m_cudaMemory = nullptr; + + RefPtr m_cudaContext; }; class TextureCUDAResource : public TextureResource @@ -238,12 +247,14 @@ public: CUarray m_cudaArray = CUarray(); CUmipmappedArray m_cudaMipMappedArray = CUmipmappedArray(); + + RefPtr m_cudaContext; }; -class CUDAResourceView : public IResourceView, public RefObject +class CUDAResourceView : public IResourceView, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -451,7 +462,7 @@ public: ComPtr constantBuffer; SLANG_RETURN_ON_FAIL(device->createBufferResource( IResource::Usage::ConstantBuffer, bufferDesc, nullptr, constantBuffer.writeRef())); - bufferResource = dynamic_cast(constantBuffer.get()); + bufferResource = static_cast(constantBuffer.get()); return SLANG_OK; } @@ -514,8 +525,7 @@ public: *object = nullptr; return SLANG_OK; } - objects[subObjectIndex]->addRef(); - *object = objects[subObjectIndex].Ptr(); + returnComPtr(object, objects[subObjectIndex]); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL @@ -648,7 +658,7 @@ public: auto& bindingRange = layout->m_bindingRanges[bindingRangeIndex]; auto viewIndex = bindingRange.baseIndex + offset.bindingArrayIndex; - auto cudaView = dynamic_cast(resourceView); + auto cudaView = static_cast(resourceView); resources[viewIndex] = cudaView; @@ -720,7 +730,6 @@ public: // type, we need to make sure to use that type as the specialization argument. // TODO: need to implement the case where the field is an array of existential values. - SLANG_ASSERT(bindingRange.count == 1); ExtendedShaderObjectType specializedSubObjType; SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType)); args.add(specializedSubObjType); @@ -800,6 +809,12 @@ public: class CUDARootShaderObject : public CUDAShaderObject { +public: + // Override default reference counting behavior to disable lifetime management. + // Root objects are managed by command buffer and does not need to be freed by the user. + SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } + public: List> entryPointObjects; virtual SLANG_NO_THROW Result SLANG_MCALL @@ -808,8 +823,7 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) override { - *outEntryPoint = entryPointObjects[index].Ptr(); - entryPointObjects[index]->addRef(); + returnComPtr(outEntryPoint, entryPointObjects[index]); return SLANG_OK; } virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override @@ -830,6 +844,7 @@ public: CUfunction cudaKernel; String kernelName; RefPtr layout; + RefPtr cudaContext; ~CUDAShaderProgram() { if (cudaModule) @@ -969,7 +984,7 @@ private: private: int m_deviceIndex = -1; CUdevice m_device = 0; - CUcontext m_context = nullptr; + RefPtr m_context; DeviceInfo m_info; String m_adapterName; @@ -979,10 +994,10 @@ public: class CommandBufferImpl : public ICommandBuffer , public CommandWriter - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandBuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandBuffer) @@ -1010,7 +1025,7 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IComputeCommandEncoder) { *outObject = static_cast(this); @@ -1025,7 +1040,7 @@ public: public: CommandWriter* m_writer; CommandBufferImpl* m_commandBuffer; - ComPtr m_rootObject; + RefPtr m_rootObject; virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override {} void init(CommandBufferImpl* cmdBuffer) { @@ -1039,8 +1054,8 @@ public: m_writer->setPipelineState(state); PipelineStateBase* pipelineImpl = static_cast(state); SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->createRootShaderObject( - pipelineImpl->m_program, outRootObject)); - m_rootObject = *outRootObject; + pipelineImpl->m_program, m_rootObject.writeRef())); + returnComPtr(outRootObject, m_rootObject); return SLANG_OK; } @@ -1066,7 +1081,7 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IResourceCommandEncoder) { *outObject = static_cast(this); @@ -1118,10 +1133,10 @@ public: class CommandQueueImpl : public ICommandQueue - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandQueue* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue) @@ -1176,7 +1191,7 @@ public: currentPipeline = dynamic_cast(state); } - Result bindRootShaderObject(PipelineType pipelineType, IShaderObject* object) + Result bindRootShaderObject(IShaderObject* object) { currentRootObject = dynamic_cast(object); if (currentRootObject) @@ -1294,8 +1309,7 @@ public: break; case CommandName::BindRootShaderObject: bindRootShaderObject( - (PipelineType)cmd.operands[0], - commandBuffer->getObject(cmd.operands[1])); + commandBuffer->getObject(cmd.operands[0])); break; case CommandName::DispatchCompute: dispatchCompute( @@ -1324,13 +1338,6 @@ public: using TransientResourceHeapImpl = SimpleTransientResourceHeap; public: - ~CUDADevice() - { - if (m_context) - { - cuCtxDestroy(m_context); - } - } virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc) override { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1")); @@ -1342,15 +1349,12 @@ public: SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex)); SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(cudaSetDevice(m_deviceIndex), reportType); - if (m_context) - { - cuCtxDestroy(m_context); - m_context = nullptr; - } + m_context = new CUDAContext(); SLANG_CUDA_RETURN_ON_FAIL(cuDeviceGet(&m_device, m_deviceIndex)); - SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(cuCtxCreate(&m_context, 0, m_device), reportType); + SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL( + cuCtxCreate(&m_context->m_context, 0, m_device), reportType); // Initialize DeviceInfo { @@ -1375,7 +1379,12 @@ public: const ITextureResource::SubresourceData* initData, ITextureResource** outResource) override { - RefPtr tex = new TextureCUDAResource(desc); + TextureResource::Desc srcDesc(desc); + srcDesc.setDefaults(initialUsage); + + RefPtr tex = new TextureCUDAResource(srcDesc); + tex->m_cudaContext = m_context; + CUresourcetype resourceType; size_t elementSize = 0; @@ -1777,7 +1786,7 @@ public: cuTexObjectCreate(&tex->m_cudaTexObj, &resDesc, &texDesc, nullptr)); } - *outResource = tex.detach(); + returnComPtr(outResource, tex); return SLANG_OK; } @@ -1788,12 +1797,13 @@ public: IBufferResource** outResource) override { RefPtr resource = new MemoryCUDAResource(desc); + resource->m_cudaContext = m_context; SLANG_CUDA_RETURN_ON_FAIL(cudaMallocManaged(&resource->m_cudaMemory, desc.sizeInBytes)); if (initData) { SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(resource->m_cudaMemory, initData, desc.sizeInBytes, cudaMemcpyHostToDevice)); } - *outResource = resource.detach(); + returnComPtr(outResource, resource); return SLANG_OK; } @@ -1803,7 +1813,7 @@ public: RefPtr view = new CUDAResourceView(); view->desc = desc; view->textureResource = dynamic_cast(texture); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -1813,7 +1823,7 @@ public: RefPtr view = new CUDAResourceView(); view->desc = desc; view->memoryResource = dynamic_cast(buffer); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -1823,7 +1833,7 @@ public: { RefPtr cudaLayout; cudaLayout = new CUDAShaderObjectLayout(this, typeLayout); - *outLayout = cudaLayout.detach(); + returnRefPtrMove(outLayout, cudaLayout); return SLANG_OK; } @@ -1833,18 +1843,18 @@ public: { RefPtr result = new CUDAShaderObject(); SLANG_RETURN_ON_FAIL(result->init(this, dynamic_cast(layout))); - *outObject = result.detach(); + returnComPtr(outObject, result); return SLANG_OK; } - Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) + Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) { auto cudaProgram = dynamic_cast(program); auto cudaLayout = cudaProgram->layout; RefPtr result = new CUDARootShaderObject(); SLANG_RETURN_ON_FAIL(result->init(this, cudaLayout)); - *outObject = result.detach(); + returnRefPtrMove(outObject, result); return SLANG_OK; } @@ -1856,10 +1866,11 @@ public: // the shader object bindings. RefPtr cudaProgram = new CUDAShaderProgram(); cudaProgram->slangProgram = desc.slangProgram; + cudaProgram->cudaContext = m_context; if (desc.slangProgram->getSpecializationParamCount() != 0) { cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout()); - *outProgram = cudaProgram.detach(); + returnComPtr(outProgram, cudaProgram); return SLANG_OK; } @@ -1893,7 +1904,7 @@ public: cudaProgram->layout = cudaLayout; } - *outProgram = cudaProgram.detach(); + returnComPtr(outProgram, cudaProgram); return SLANG_OK; } @@ -1901,9 +1912,9 @@ public: const ComputePipelineStateDesc& desc, IPipelineState** outState) override { RefPtr state = new CUDAPipelineState(); - state->shaderProgram = dynamic_cast(desc.program); + state->shaderProgram = static_cast(desc.program); state->init(desc); - *outState = state.detach(); + returnComPtr(outState, state); return Result(); } @@ -1929,7 +1940,7 @@ public: { RefPtr result = new TransientResourceHeapImpl(); SLANG_RETURN_ON_FAIL(result->init(this, desc)); - *outHeap = result.detach(); + returnComPtr(outHeap, result); return SLANG_OK; } @@ -1938,7 +1949,7 @@ public: { RefPtr queue = new CommandQueueImpl(); queue->init(this); - *outQueue = queue.detach(); + returnComPtr(outQueue, queue); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL createSwapchain( @@ -2027,7 +2038,7 @@ public: (uint8_t*)bufferImpl->m_cudaMemory + offset, size, cudaMemcpyDefault); - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } }; @@ -2115,7 +2126,7 @@ SlangResult SLANG_MCALL createCUDADevice(const IDevice::Desc* desc, IDevice** ou { RefPtr result = new CUDADevice(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outDevice = result.detach(); + returnComPtr(outDevice, result); return SLANG_OK; } #else diff --git a/tools/gfx/d3d/d3d-swapchain.h b/tools/gfx/d3d/d3d-swapchain.h index 5a9ead876..11914cca8 100644 --- a/tools/gfx/d3d/d3d-swapchain.h +++ b/tools/gfx/d3d/d3d-swapchain.h @@ -10,10 +10,10 @@ namespace gfx { class D3DSwapchainBase : public ISwapchain - , public Slang::RefObject + , public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ISwapchain* getInterface(const Slang::Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISwapchain) @@ -82,8 +82,7 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getImage(uint32_t index, ITextureResource** outResource) override { - m_images[index]->addRef(); - *outResource = m_images[index].get(); + returnComPtr(outResource, m_images[index]); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL present() override @@ -148,7 +147,7 @@ public: ISwapchain::Desc m_desc; HANDLE m_swapChainWaitableObject = nullptr; ComPtr m_swapChain; - Slang::ShortList> m_images; + Slang::ShortList> m_images; }; } diff --git a/tools/gfx/d3d11/render-d3d11.cpp b/tools/gfx/d3d11/render-d3d11.cpp index a099d98c7..f5dbcc0a7 100644 --- a/tools/gfx/d3d11/render-d3d11.cpp +++ b/tools/gfx/d3d11/render-d3d11.cpp @@ -107,7 +107,7 @@ public: ShaderObjectLayoutBase** outLayout) override; virtual Result createShaderObject(ShaderObjectLayoutBase* layout, IShaderObject** outObject) override; - virtual Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) + virtual Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) override; virtual void bindRootShaderObject(IShaderObject* shaderObject) override; @@ -203,10 +203,10 @@ protected: }; - class SamplerStateImpl : public ISamplerState, public RefObject + class SamplerStateImpl : public ISamplerState, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ISamplerState* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISamplerState) @@ -218,10 +218,10 @@ protected: }; - class ResourceViewImpl : public IResourceView, public RefObject + class ResourceViewImpl : public IResourceView, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -267,10 +267,10 @@ protected: class FramebufferLayoutImpl : public IFramebufferLayout - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebufferLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebufferLayout) @@ -286,10 +286,10 @@ protected: class FramebufferImpl : public IFramebuffer - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebuffer) @@ -309,7 +309,7 @@ protected: public: ComPtr m_device; ComPtr m_dxgiFactory; - D3D11Device* m_renderer; + RefPtr m_renderer; Result init(D3D11Device* renderer, const ISwapchain::Desc& swapchainDesc, WindowHandle window) { m_renderer = renderer; @@ -330,11 +330,9 @@ protected: RefPtr image = new TextureResourceImpl(imageDesc, IResource::Usage::RenderTarget); image->m_resource = d3dResource; - ComPtr imageResourcePtr; - imageResourcePtr = image.Ptr(); for (uint32_t i = 0; i < m_desc.imageCount; i++) { - m_images.add(imageResourcePtr); + m_images.add(image); } } virtual IDXGIFactory* getDXGIFactory() override { return m_dxgiFactory; } @@ -347,10 +345,10 @@ protected: } }; - class InputLayoutImpl: public IInputLayout, public RefObject + class InputLayoutImpl: public IInputLayout, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IInputLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IInputLayout) @@ -561,7 +559,7 @@ protected: RefPtr(new ShaderObjectLayoutImpl()); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } }; @@ -654,7 +652,7 @@ protected: RefPtr layout = new RootShaderObjectLayoutImpl(); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -733,10 +731,10 @@ protected: ShaderObjectLayoutImpl* layout, ShaderObjectImpl** outShaderObject) { - auto object = ComPtr(new ShaderObjectImpl()); + auto object = RefPtr(new ShaderObjectImpl()); SLANG_RETURN_ON_FAIL(object->init(device, layout)); - *outShaderObject = object.detach(); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -908,9 +906,7 @@ protected: return SLANG_E_INVALID_ARG; auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); - auto object = m_objects[bindingRange.baseIndex + offset.bindingArrayIndex].Ptr(); - object->addRef(); - *outObject = object; + returnComPtr(outObject, m_objects[bindingRange.baseIndex + offset.bindingArrayIndex]); return SLANG_OK; } @@ -1310,7 +1306,7 @@ protected: { SLANG_RETURN_ON_FAIL(_createSpecializedLayout(m_specializedLayout.writeRef())); } - *outLayout = RefPtr(m_specializedLayout).detach(); + returnRefPtr(outLayout, m_specializedLayout); return SLANG_OK; } @@ -1324,10 +1320,11 @@ protected: SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&extendedType)); auto renderer = getRenderer(); - RefPtr layout; - SLANG_RETURN_ON_FAIL(renderer->getShaderObjectLayout(extendedType.slangType, layout.writeRef())); + RefPtr layout; + SLANG_RETURN_ON_FAIL(renderer->getShaderObjectLayout( + extendedType.slangType, (ShaderObjectLayoutBase**)layout.writeRef())); - *outLayout = static_cast(layout.detach()); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -1338,13 +1335,18 @@ protected: { typedef ShaderObjectImpl Super; + public: + // Override default reference counting behavior to disable lifetime management via ComPtr. + // Root objects are managed by command buffer and does not need to be freed by the user. + SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } public: static Result create(IDevice* device, RootShaderObjectLayoutImpl* layout, RootShaderObjectImpl** outShaderObject) { RefPtr object = new RootShaderObjectImpl(); SLANG_RETURN_ON_FAIL(object->init(device, layout)); - *outShaderObject = object.detach(); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -1353,8 +1355,7 @@ protected: UInt SLANG_MCALL getEntryPointCount() SLANG_OVERRIDE { return (UInt)m_entryPoints.getCount(); } SlangResult SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) SLANG_OVERRIDE { - *outEntryPoint = m_entryPoints[index]; - m_entryPoints[index]->addRef(); + returnComPtr(outEntryPoint, m_entryPoints[index]); return SLANG_OK; } @@ -1473,7 +1474,7 @@ protected: entryPointVars->m_specializedLayout = entryPointInfo.layout; } - *outLayout = specializedLayout.detach(); + returnRefPtrMove(outLayout, specializedLayout); return SLANG_OK; } @@ -1496,7 +1497,7 @@ protected: RefPtr m_currentFramebuffer; - ComPtr m_currentPipelineState; + RefPtr m_currentPipelineState; RootBindingState m_rootBindingState; @@ -1517,7 +1518,7 @@ SlangResult SLANG_MCALL createD3D11Device(const IDevice::Desc* desc, IDevice** o { RefPtr result = new D3D11Device(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outDevice = result.detach(); + returnComPtr(outDevice, result); return SLANG_OK; } @@ -1782,7 +1783,7 @@ Result D3D11Device::createSwapchain( { RefPtr swapchain = new SwapchainImpl(); SLANG_RETURN_ON_FAIL(swapchain->init(this, desc, window)); - *outSwapchain = swapchain.detach(); + returnComPtr(outSwapchain, swapchain); return SLANG_OK; } @@ -1805,7 +1806,7 @@ Result D3D11Device::createFramebufferLayout( { layout->m_hasDepthStencil = false; } - *outLayout = layout.detach(); + returnComPtr(outLayout, layout); return SLANG_OK; } @@ -1822,7 +1823,7 @@ Result D3D11Device::createFramebuffer( } framebuffer->depthStencilView = static_cast(desc.depthStencilView); framebuffer->d3dDepthStencilView = framebuffer->depthStencilView->m_dsv; - *outFramebuffer = framebuffer.detach(); + returnComPtr(outFramebuffer, framebuffer); return SLANG_OK; } @@ -1909,7 +1910,7 @@ SlangResult D3D11Device::readTextureResource( } // Make sure to unmap m_immediateContext->Unmap(stagingTexture, 0); - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } } @@ -2076,7 +2077,7 @@ Result D3D11Device::createTextureResource(IResource::Usage initialUsage, const I return SLANG_FAIL; } - *outResource = texture.detach(); + returnComPtr(outResource, texture); return SLANG_OK; } @@ -2167,7 +2168,7 @@ Result D3D11Device::createBufferResource(IResource::Usage initialUsage, const IB SLANG_RETURN_ON_FAIL(m_device->CreateBuffer(&bufDesc, nullptr, buffer->m_staging.writeRef())); } - *outResource = buffer.detach(); + returnComPtr(outResource, buffer); return SLANG_OK; } @@ -2287,7 +2288,7 @@ Result D3D11Device::createSamplerState(ISamplerState::Desc const& desc, ISampler RefPtr samplerImpl = new SamplerStateImpl(); samplerImpl->m_sampler = sampler; - *outSampler = samplerImpl.detach(); + returnComPtr(outSampler, samplerImpl); return SLANG_OK; } @@ -2312,7 +2313,7 @@ Result D3D11Device::createTextureView(ITextureResource* texture, IResourceView:: viewImpl->m_clearValue, &resourceImpl->getDesc()->optimalClearValue.color, sizeof(float) * 4); - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -2326,7 +2327,7 @@ Result D3D11Device::createTextureView(ITextureResource* texture, IResourceView:: viewImpl->m_type = ResourceViewImpl::Type::DSV; viewImpl->m_dsv = dsv; viewImpl->m_clearValue = resourceImpl->getDesc()->optimalClearValue.depthStencil; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -2339,7 +2340,7 @@ Result D3D11Device::createTextureView(ITextureResource* texture, IResourceView:: RefPtr viewImpl = new UnorderedAccessViewImpl(); viewImpl->m_type = ResourceViewImpl::Type::UAV; viewImpl->m_uav = uav; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -2352,7 +2353,7 @@ Result D3D11Device::createTextureView(ITextureResource* texture, IResourceView:: RefPtr viewImpl = new ShaderResourceViewImpl(); viewImpl->m_type = ResourceViewImpl::Type::SRV; viewImpl->m_srv = srv; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -2397,7 +2398,7 @@ Result D3D11Device::createBufferView(IBufferResource* buffer, IResourceView::Des RefPtr viewImpl = new UnorderedAccessViewImpl(); viewImpl->m_type = ResourceViewImpl::Type::UAV; viewImpl->m_uav = uav; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -2442,7 +2443,7 @@ Result D3D11Device::createBufferView(IBufferResource* buffer, IResourceView::Des RefPtr viewImpl = new ShaderResourceViewImpl(); viewImpl->m_type = ResourceViewImpl::Type::SRV; viewImpl->m_srv = srv; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -2512,7 +2513,7 @@ Result D3D11Device::createInputLayout(const InputElementDesc* inputElementsIn, U RefPtr impl = new InputLayoutImpl; impl->m_layout.swap(inputLayout); - *outLayout = impl.detach(); + returnComPtr(outLayout, impl); return SLANG_OK; } @@ -2661,7 +2662,7 @@ void D3D11Device::setPipelineState(IPipelineState* state) case PipelineType::Graphics: { auto stateImpl = (GraphicsPipelineStateImpl*) state; - auto programImpl = static_cast(stateImpl->m_program.get()); + auto programImpl = static_cast(stateImpl->m_program.Ptr()); // TODO: We could conceivably do some lightweight state // differencing here (e.g., check if `programImpl` is the @@ -2705,7 +2706,7 @@ void D3D11Device::setPipelineState(IPipelineState* state) case PipelineType::Compute: { auto stateImpl = (ComputePipelineStateImpl*) state; - auto programImpl = static_cast(stateImpl->m_program.get()); + auto programImpl = static_cast(stateImpl->m_program.Ptr()); // CS @@ -2739,7 +2740,7 @@ Result D3D11Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr // For a specializable program, we don't invoke any actual slang compilation yet. RefPtr shaderProgram = new ShaderProgramImpl(); shaderProgram->slangProgram = desc.slangProgram; - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -2802,7 +2803,7 @@ Result D3D11Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr SLANG_ASSERT(!"pipeline stage not implemented"); } } - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -2940,7 +2941,7 @@ Result D3D11Device::createShaderObjectLayout( RefPtr layout; SLANG_RETURN_ON_FAIL(ShaderObjectLayoutImpl::createForElementType( this, typeLayout, layout.writeRef())); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -2949,11 +2950,11 @@ Result D3D11Device::createShaderObject(ShaderObjectLayoutBase* layout, IShaderOb RefPtr shaderObject; SLANG_RETURN_ON_FAIL(ShaderObjectImpl::create(this, static_cast(layout), shaderObject.writeRef())); - *outObject = shaderObject.detach(); + returnComPtr(outObject, shaderObject); return SLANG_OK; } -Result D3D11Device::createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) +Result D3D11Device::createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) { auto programImpl = static_cast(program); RefPtr shaderObject; @@ -2962,7 +2963,7 @@ Result D3D11Device::createRootShaderObject(IShaderProgram* program, IShaderObjec this, programImpl->slangProgram, programImpl->slangProgram->getLayout(), rootLayout.writeRef())); SLANG_RETURN_ON_FAIL(RootShaderObjectImpl::create( this, rootLayout.Ptr(), shaderObject.writeRef())); - *outObject = shaderObject.detach(); + returnRefPtrMove(outObject, shaderObject); return SLANG_OK; } @@ -3123,7 +3124,7 @@ Result D3D11Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& state->m_blendColor[3] = 0; state->m_sampleMask = 0xFFFFFFFF; state->init(desc); - *outState = state.detach(); + returnComPtr(outState, state); return SLANG_OK; } @@ -3133,7 +3134,7 @@ Result D3D11Device::createComputePipelineState(const ComputePipelineStateDesc& i RefPtr state = new ComputePipelineStateImpl(); state->init(desc); - *outState = state.detach(); + returnComPtr(outState, state); return SLANG_OK; } @@ -3167,7 +3168,7 @@ void D3D11Device::_flushGraphicsState() m_framebufferBindingDirty = false; m_shaderBindingDirty = false; - auto pipelineState = static_cast(m_currentPipelineState.get()); + auto pipelineState = static_cast(m_currentPipelineState.Ptr()); auto rtvCount = (UINT)m_currentFramebuffer->renderTargetViews.getCount(); auto uavCount = (UINT)m_rootBindingState.uavBindings.getCount(); m_immediateContext->OMSetRenderTargetsAndUnorderedAccessViews( @@ -3182,7 +3183,7 @@ void D3D11Device::_flushGraphicsState() if (m_depthStencilStateDirty) { m_depthStencilStateDirty = false; - auto pipelineState = static_cast(m_currentPipelineState.get()); + auto pipelineState = static_cast(m_currentPipelineState.Ptr()); m_immediateContext->OMSetDepthStencilState( pipelineState->m_depthStencilState, m_stencilRef); } diff --git a/tools/gfx/d3d12/descriptor-heap-d3d12.h b/tools/gfx/d3d12/descriptor-heap-d3d12.h index 4b4a22cc3..c893551b0 100644 --- a/tools/gfx/d3d12/descriptor-heap-d3d12.h +++ b/tools/gfx/d3d12/descriptor-heap-d3d12.h @@ -5,6 +5,7 @@ #include #include "slang-com-ptr.h" +#include "core/slang-smart-pointer.h" #include "core/slang-list.h" #include "core/slang-virtual-object-pool.h" @@ -79,7 +80,7 @@ struct D3D12Descriptor /// Unlike the `D3D12DescriptorHeap` type, this class allows for both /// allocation and freeing of descriptors, by maintaining a free list. /// -class D3D12GeneralDescriptorHeap +class D3D12GeneralDescriptorHeap : public Slang::RefObject { ID3D12Device* m_device; int m_chunkSize; diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index b8ff47ad3..15acb680c 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -217,10 +217,11 @@ public: D3D12_RESOURCE_STATES m_defaultState; }; - class SamplerStateImpl : public ISamplerState, public RefObject + class SamplerStateImpl : public ISamplerState, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL + ISamplerState* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISamplerState) @@ -229,17 +230,17 @@ public: } public: D3D12Descriptor m_descriptor; - D3D12Device* m_renderer; + Slang::RefPtr m_allocator; ~SamplerStateImpl() { - m_renderer->m_cpuSamplerHeap.free(m_descriptor); + m_allocator->free(m_descriptor); } }; - class ResourceViewImpl : public IResourceView, public RefObject + class ResourceViewImpl : public IResourceView, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -247,9 +248,9 @@ public: return nullptr; } public: - RefPtr m_resource; - D3D12Descriptor m_descriptor; - D3D12GeneralDescriptorHeap* m_allocator; + RefPtr m_resource; + D3D12Descriptor m_descriptor; + RefPtr m_allocator; ~ResourceViewImpl() { m_allocator->free(m_descriptor); @@ -258,10 +259,10 @@ public: class FramebufferLayoutImpl : public IFramebufferLayout - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebufferLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebufferLayout) @@ -277,10 +278,10 @@ public: class FramebufferImpl : public IFramebuffer - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebuffer) @@ -289,8 +290,8 @@ public: } public: - ShortList> renderTargetViews; - ComPtr depthStencilView; + ShortList> renderTargetViews; + RefPtr depthStencilView; ShortList renderTargetDescriptors; struct Color4f { @@ -312,10 +313,10 @@ public: } }; - class InputLayoutImpl: public IInputLayout, public RefObject + class InputLayoutImpl: public IInputLayout, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IInputLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IInputLayout) @@ -476,7 +477,7 @@ public: public: ComPtr m_commandAllocator; List> m_d3dCommandListPool; - List> m_commandBufferPool; + List> m_commandBufferPool; uint32_t m_commandListAllocId = 0; // Wait values for each command queue. struct QueueWaitInfo @@ -834,7 +835,7 @@ public: auto layout = RefPtr(new ShaderObjectLayoutImpl()); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } }; @@ -921,7 +922,7 @@ public: RefPtr layout = new RootShaderObjectLayoutImpl(); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -1487,11 +1488,10 @@ public: ShaderObjectLayoutImpl* layout, ShaderObjectImpl** outShaderObject) { - auto object = ComPtr(new ShaderObjectImpl()); + auto object = RefPtr(new ShaderObjectImpl()); SLANG_RETURN_ON_FAIL( - object->init(device, layout, &device->m_cpuViewHeap, &device->m_cpuSamplerHeap)); - - *outShaderObject = object.detach(); + object->init(device, layout, device->m_cpuViewHeap.Ptr(), device->m_cpuSamplerHeap.Ptr())); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -1510,7 +1510,7 @@ public: } } - RendererBase* getDevice() { return m_layout->getDevice(); } + RendererBase* getDevice() { return m_device.get(); } SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() SLANG_OVERRIDE { return 0; } @@ -1683,9 +1683,7 @@ public: return SLANG_E_INVALID_ARG; auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); - auto object = m_objects[bindingRange.binding.index + offset.bindingArrayIndex].Ptr(); - object->addRef(); - *outObject = object; + returnComPtr(outObject, m_objects[bindingRange.binding.index + offset.bindingArrayIndex]); return SLANG_OK; } @@ -1842,6 +1840,8 @@ public: DescriptorHeapReference viewHeap, DescriptorHeapReference samplerHeap) { + m_device = device; + m_layout = layout; m_upToDateConstantBufferHeapVersion = 0; @@ -2261,7 +2261,7 @@ public: { SLANG_RETURN_ON_FAIL(_createSpecializedLayout(m_specializedLayout.writeRef())); } - *outLayout = RefPtr(m_specializedLayout).detach(); + returnRefPtr(outLayout, m_specializedLayout); return SLANG_OK; } @@ -2275,11 +2275,11 @@ public: SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&extendedType)); auto renderer = getRenderer(); - RefPtr layout; - SLANG_RETURN_ON_FAIL( - renderer->getShaderObjectLayout(extendedType.slangType, layout.writeRef())); + RefPtr layout; + SLANG_RETURN_ON_FAIL(renderer->getShaderObjectLayout( + extendedType.slangType, (ShaderObjectLayoutBase**)layout.writeRef())); - *outLayout = static_cast(layout.detach()); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -2291,7 +2291,7 @@ public: typedef ShaderObjectImpl Super; public: - // Override default reference counting behavior to disable lifetime management. + // Override default reference counting behavior to disable lifetime management via ComPtr. // Root objects are managed by command buffer and does not need to be freed by the user. SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } @@ -2308,8 +2308,7 @@ public: SlangResult SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) SLANG_OVERRIDE { - *outEntryPoint = m_entryPoints[index]; - m_entryPoints[index]->addRef(); + returnComPtr(outEntryPoint, m_entryPoints[index]); return SLANG_OK; } @@ -2456,7 +2455,7 @@ public: entryPointVars->m_specializedLayout = entryPointInfo.layout; } - *outLayout = specializedLayout.detach(); + returnRefPtrMove(outLayout, specializedLayout); return SLANG_OK; } @@ -2471,19 +2470,26 @@ public: class CommandBufferImpl : public ICommandBuffer - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + // There are a pair of cyclic references between a `TransientResourceHeap` and + // a `CommandBuffer` created from the heap. We need to break the cycle upon + // the public reference count of a command buffer dropping to 0. + SLANG_COM_OBJECT_IUNKNOWN_ALL + ICommandBuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandBuffer) return static_cast(this); return nullptr; } + virtual void comFree() override { m_transientHeap.breakStrongReference(); } public: ComPtr m_cmdList; - TransientResourceHeapImpl* m_transientHeap; + BreakableReference m_transientHeap; + // Weak reference is fine here since `m_transientHeap` already holds strong reference to + // device. D3D12Device* m_renderer; RootShaderObjectImpl m_rootShaderObject; @@ -2512,17 +2518,17 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IRenderCommandEncoder) { - *outObject = static_cast(this); + returnComPtr(outObject, static_cast(this)); return SLANG_OK; } *outObject = nullptr; return SLANG_E_NO_INTERFACE; } - virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() { return 1; } - virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } public: RefPtr m_renderPass; RefPtr m_framebuffer; @@ -2576,8 +2582,7 @@ public: // Transit resource states. { D3D12BarrierSubmitter submitter(m_d3dCmdList); - auto resourceViewImpl = - static_cast(framebuffer->renderTargetViews[i].get()); + auto resourceViewImpl = framebuffer->renderTargetViews[i].Ptr(); auto textureResource = static_cast(resourceViewImpl->m_resource.Ptr()); D3D12_RESOURCE_STATES initialState; @@ -2610,8 +2615,7 @@ public: // Transit resource states. { D3D12BarrierSubmitter submitter(m_d3dCmdList); - auto resourceViewImpl = - static_cast(framebuffer->depthStencilView.get()); + auto resourceViewImpl = framebuffer->depthStencilView.Ptr(); auto textureResource = static_cast(resourceViewImpl->m_resource.Ptr()); D3D12_RESOURCE_STATES initialState; @@ -2840,8 +2844,7 @@ public: // Transit resource states. { D3D12BarrierSubmitter submitter(m_d3dCmdList); - auto resourceViewImpl = static_cast( - m_framebuffer->renderTargetViews[i].get()); + auto resourceViewImpl = m_framebuffer->renderTargetViews[i].Ptr(); auto textureResource = static_cast(resourceViewImpl->m_resource.Ptr()); textureResource->m_resource.transition( @@ -2855,8 +2858,7 @@ public: { // Transit resource states. D3D12BarrierSubmitter submitter(m_d3dCmdList); - auto resourceViewImpl = - static_cast(m_framebuffer->depthStencilView.get()); + auto resourceViewImpl = m_framebuffer->depthStencilView.Ptr(); auto textureResource = static_cast(resourceViewImpl->m_resource.Ptr()); textureResource->m_resource.transition( @@ -2898,10 +2900,10 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IComputeCommandEncoder) { - *outObject = static_cast(this); + returnComPtr(outObject, static_cast(this)); return SLANG_OK; } *outObject = nullptr; @@ -2961,7 +2963,7 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IResourceCommandEncoder) { *outObject = static_cast(this); @@ -3026,19 +3028,20 @@ public: class CommandQueueImpl : public ICommandQueue - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandQueue* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue) return static_cast(this); return nullptr; } + void breakStrongReferenceToDevice() { m_renderer.breakStrongReference(); } public: - D3D12Device* m_renderer; + BreakableReference m_renderer; ComPtr m_device; ComPtr m_d3dQueue; ComPtr m_fence; @@ -3160,9 +3163,7 @@ public: RefPtr image = new TextureResourceImpl(imageDesc); image->m_resource.setResource(d3dResource.get()); image->m_defaultState = D3D12_RESOURCE_STATE_PRESENT; - ComPtr imageResourcePtr; - imageResourcePtr = image.Ptr(); - m_images.add(imageResourcePtr); + m_images.add(image); } for (auto evt : m_frameEvents) SetEvent(evt); @@ -3255,14 +3256,14 @@ public: RefPtr m_resourceCommandQueue; RefPtr m_resourceCommandTransientHeap; - D3D12GeneralDescriptorHeap m_rtvAllocator; - D3D12GeneralDescriptorHeap m_dsvAllocator; + RefPtr m_rtvAllocator; + RefPtr m_dsvAllocator; // Space in the GPU-visible heaps is precious, so we will also keep // around CPU-visible heaps for storing descriptors in a format // that is ready for copying into the GPU-visible heaps as needed. // - D3D12GeneralDescriptorHeap m_cpuViewHeap; ///< Cbv, Srv, Uav - D3D12GeneralDescriptorHeap m_cpuSamplerHeap; ///< Heap for samplers + RefPtr m_cpuViewHeap; ///< Cbv, Srv, Uav + RefPtr m_cpuSamplerHeap; ///< Heap for samplers // Dll entry points PFN_D3D12_GET_DEBUG_INTERFACE m_D3D12GetDebugInterface = nullptr; @@ -3294,12 +3295,11 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe if ((Index)m_commandListAllocId < m_commandBufferPool.getCount()) { auto result = static_cast( - m_commandBufferPool[m_commandListAllocId].get()); + m_commandBufferPool[m_commandListAllocId].Ptr()); m_d3dCommandListPool[m_commandListAllocId]->Reset(m_commandAllocator, nullptr); result->init(m_device, m_d3dCommandListPool[m_commandListAllocId], this); ++m_commandListAllocId; - result->addRef(); - *outCmdBuffer = result; + returnComPtr(outCmdBuffer, result); return SLANG_OK; } ComPtr cmdList; @@ -3313,11 +3313,9 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe m_d3dCommandListPool.add(cmdList); RefPtr cmdBuffer = new CommandBufferImpl(); cmdBuffer->init(m_device, cmdList, this); - ComPtr cmdBufferPtr; - *cmdBufferPtr.writeRef() = cmdBuffer.detach(); - m_commandBufferPool.add(cmdBufferPtr); + m_commandBufferPool.add(cmdBuffer); ++m_commandListAllocId; - *outCmdBuffer = cmdBufferPtr.detach(); + returnComPtr(outCmdBuffer, cmdBuffer); return SLANG_OK; } @@ -3329,7 +3327,7 @@ Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitte PipelineStateImpl* newPipelineImpl = static_cast(newPipeline.Ptr()); auto commandList = m_d3dCmdList; auto pipelineTypeIndex = (int)newPipelineImpl->desc.type; - auto programImpl = static_cast(newPipelineImpl->m_program.get()); + auto programImpl = static_cast(newPipelineImpl->m_program.Ptr()); commandList->SetPipelineState(newPipelineImpl->m_pipelineState); submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature); RefPtr specializedRootLayout; @@ -3382,7 +3380,7 @@ Result D3D12Device::createTransientResourceHeapImpl( ITransientResourceHeap::Desc desc = {}; desc.constantBufferSize = constantBufferSize; SLANG_RETURN_ON_FAIL(result->init(desc, this, viewDescriptors, samplerDescriptors)); - *outHeap = result.detach(); + returnRefPtrMove(outHeap, result); return SLANG_OK; } @@ -3395,7 +3393,7 @@ Result D3D12Device::createCommandQueueImpl(D3D12Device::CommandQueueImpl** outQu RefPtr queue = new D3D12Device::CommandQueueImpl(); SLANG_RETURN_ON_FAIL(queue->init(this, (uint32_t)queueIndex)); - *outQueue = queue.detach(); + returnRefPtrMove(outQueue, queue); return SLANG_OK; } @@ -3403,7 +3401,7 @@ SlangResult SLANG_MCALL createD3D12Device(const IDevice::Desc* desc, IDevice** o { RefPtr result = new D3D12Device(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outDevice = result.detach(); + returnComPtr(outDevice, result); return SLANG_OK; } @@ -3418,9 +3416,7 @@ SlangResult SLANG_MCALL createD3D12Device(const IDevice::Desc* desc, IDevice** o return proc; } -D3D12Device::~D3D12Device() -{ -} +D3D12Device::~D3D12Device() { m_shaderObjectLayoutCache = decltype(m_shaderObjectLayoutCache)(); } static void _initSrvDesc(IResource::Type resourceType, const ITextureResource::Desc& textureDesc, const D3D12_RESOURCE_DESC& desc, DXGI_FORMAT pixelFormat, D3D12_SHADER_RESOURCE_VIEW_DESC& descOut) { @@ -3628,7 +3624,7 @@ Result D3D12Device::captureTextureToSurface( resultBlob->m_data.setCount(bufferSize); memcpy(resultBlob->m_data.getBuffer(), data, bufferSize); dxResource->Unmap(0, nullptr); - *outBlob = resultBlob.detach(); + returnComPtr(outBlob, resultBlob); return SLANG_OK; } } @@ -3905,16 +3901,27 @@ Result D3D12Device::initialize(const Desc& desc) // Create a command queue for internal resource transfer operations. SLANG_RETURN_ON_FAIL(createCommandQueueImpl(m_resourceCommandQueue.writeRef())); + // `CommandQueueImpl` holds a back reference to `D3D12Device`, make it a weak reference here + // since this object is already owned by `D3D12Device`. + m_resourceCommandQueue->breakStrongReferenceToDevice(); + SLANG_RETURN_ON_FAIL(createTransientResourceHeapImpl(0, 8, 4, m_resourceCommandTransientHeap.writeRef())); + // `TransientResourceHeap` holds a back reference to `D3D12Device`, make it a weak reference here + // since this object is already owned by `D3D12Device`. + m_resourceCommandTransientHeap->breakStrongReferenceToDevice(); - SLANG_RETURN_ON_FAIL(m_cpuViewHeap.init( + m_cpuViewHeap = new D3D12GeneralDescriptorHeap(); + SLANG_RETURN_ON_FAIL(m_cpuViewHeap->init( m_device, 8192, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); - SLANG_RETURN_ON_FAIL(m_cpuSamplerHeap.init( + m_cpuSamplerHeap = new D3D12GeneralDescriptorHeap(); + SLANG_RETURN_ON_FAIL(m_cpuSamplerHeap->init( m_device, 1024, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); - SLANG_RETURN_ON_FAIL(m_rtvAllocator.init( + m_rtvAllocator = new D3D12GeneralDescriptorHeap(); + SLANG_RETURN_ON_FAIL(m_rtvAllocator->init( m_device, 16, D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); - SLANG_RETURN_ON_FAIL(m_dsvAllocator.init( + m_dsvAllocator = new D3D12GeneralDescriptorHeap(); + SLANG_RETURN_ON_FAIL(m_dsvAllocator->init( m_device, 16, D3D12_DESCRIPTOR_HEAP_TYPE_DSV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); ComPtr dxgiDevice; @@ -3937,7 +3944,7 @@ Result D3D12Device::createTransientResourceHeap( RefPtr heap; SLANG_RETURN_ON_FAIL( createTransientResourceHeapImpl(desc.constantBufferSize, 8192, 1024, heap.writeRef())); - *outHeap = heap.detach(); + returnComPtr(outHeap, heap); return SLANG_OK; } @@ -3945,7 +3952,7 @@ Result D3D12Device::createCommandQueue(const ICommandQueue::Desc& desc, ICommand { RefPtr queue; SLANG_RETURN_ON_FAIL(createCommandQueueImpl(queue.writeRef())); - *outQueue = queue.detach(); + returnComPtr(outQueue, queue); return SLANG_OK; } @@ -3954,7 +3961,7 @@ SLANG_NO_THROW Result SLANG_MCALL D3D12Device::createSwapchain( { RefPtr swapchain = new SwapchainImpl(); SLANG_RETURN_ON_FAIL(swapchain->init(this, desc, window)); - *outSwapchain = swapchain.detach(); + returnComPtr(outSwapchain, swapchain); return SLANG_OK; } @@ -4262,7 +4269,7 @@ Result D3D12Device::createTextureResource(IResource::Usage initialUsage, const I submitResourceCommandsAndWait(encodeInfo); } - *outResource = texture.detach(); + returnComPtr(outResource, texture); return SLANG_OK; } @@ -4287,7 +4294,7 @@ Result D3D12Device::createBufferResource(IResource::Usage initialUsage, const IB const D3D12_RESOURCE_STATES initialState = _calcResourceState(initialUsage); SLANG_RETURN_ON_FAIL(createBuffer(bufferDesc, initData, srcDesc.sizeInBytes, buffer->m_uploadResource, initialState, buffer->m_resource)); - *outResource = buffer.detach(); + returnComPtr(outResource, buffer); return SLANG_OK; } @@ -4400,7 +4407,7 @@ Result D3D12Device::createSamplerState(ISamplerState::Desc const& desc, ISampler dxDesc.MinLOD = desc.minLOD; dxDesc.MaxLOD = desc.maxLOD; - auto samplerHeap = &m_cpuSamplerHeap; + auto& samplerHeap = m_cpuSamplerHeap; D3D12Descriptor cpuDescriptor; samplerHeap->allocate(&cpuDescriptor); @@ -4411,9 +4418,9 @@ Result D3D12Device::createSamplerState(ISamplerState::Desc const& desc, ISampler // when we are done with a sampler we simply add it to the free list. // RefPtr samplerImpl = new SamplerStateImpl(); - samplerImpl->m_renderer = this; + samplerImpl->m_allocator = samplerHeap; samplerImpl->m_descriptor = cpuDescriptor; - *outSampler = samplerImpl.detach(); + returnComPtr(outSampler, samplerImpl); return SLANG_OK; } @@ -4431,8 +4438,8 @@ Result D3D12Device::createTextureView(ITextureResource* texture, IResourceView:: case IResourceView::Type::RenderTarget: { - SLANG_RETURN_ON_FAIL(m_rtvAllocator.allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = &m_rtvAllocator; + SLANG_RETURN_ON_FAIL(m_rtvAllocator->allocate(&viewImpl->m_descriptor)); + viewImpl->m_allocator = m_rtvAllocator; D3D12_RENDER_TARGET_VIEW_DESC rtvDesc = {}; rtvDesc.Format = D3DUtil::getMapFormat(desc.format); switch (desc.renderTarget.shape) @@ -4462,8 +4469,8 @@ Result D3D12Device::createTextureView(ITextureResource* texture, IResourceView:: case IResourceView::Type::DepthStencil: { - SLANG_RETURN_ON_FAIL(m_dsvAllocator.allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = &m_dsvAllocator; + SLANG_RETURN_ON_FAIL(m_dsvAllocator->allocate(&viewImpl->m_descriptor)); + viewImpl->m_allocator = m_dsvAllocator; D3D12_DEPTH_STENCIL_VIEW_DESC dsvDesc = {}; dsvDesc.Format = D3DUtil::getMapFormat(desc.format); switch (desc.renderTarget.shape) @@ -4489,16 +4496,16 @@ Result D3D12Device::createTextureView(ITextureResource* texture, IResourceView:: // TODO: need to support the separate "counter resource" for the case // of append/consume buffers with attached counters. - SLANG_RETURN_ON_FAIL(m_cpuViewHeap.allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = &m_cpuViewHeap; + SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); + viewImpl->m_allocator = m_cpuViewHeap; m_device->CreateUnorderedAccessView(resourceImpl->m_resource, nullptr, nullptr, viewImpl->m_descriptor.cpuHandle); } break; case IResourceView::Type::ShaderResource: { - SLANG_RETURN_ON_FAIL(m_cpuViewHeap.allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = &m_cpuViewHeap; + SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); + viewImpl->m_allocator = m_cpuViewHeap; // Need to construct the D3D12_SHADER_RESOURCE_VIEW_DESC because otherwise TextureCube is not accessed // appropriately (rather than just passing nullptr to CreateShaderResourceView) @@ -4513,7 +4520,7 @@ Result D3D12Device::createTextureView(ITextureResource* texture, IResourceView:: break; } - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } @@ -4557,8 +4564,8 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des // TODO: need to support the separate "counter resource" for the case // of append/consume buffers with attached counters. - SLANG_RETURN_ON_FAIL(m_cpuViewHeap.allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = &m_cpuViewHeap; + SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); + viewImpl->m_allocator = m_cpuViewHeap; m_device->CreateUnorderedAccessView(resourceImpl->m_resource, nullptr, &uavDesc, viewImpl->m_descriptor.cpuHandle); } break; @@ -4587,14 +4594,14 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des srvDesc.Buffer.NumElements = UINT(resourceDesc.sizeInBytes / gfxGetFormatSize(desc.format)); } - SLANG_RETURN_ON_FAIL(m_cpuViewHeap.allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = &m_cpuViewHeap; + SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); + viewImpl->m_allocator = m_cpuViewHeap; m_device->CreateShaderResourceView(resourceImpl->m_resource, &srvDesc, viewImpl->m_descriptor.cpuHandle); } break; } - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } @@ -4606,9 +4613,9 @@ Result D3D12Device::createFramebuffer(IFramebuffer::Desc const& desc, IFramebuff framebuffer->renderTargetClearValues.setCount(desc.renderTargetCount); for (uint32_t i = 0; i < desc.renderTargetCount; i++) { - framebuffer->renderTargetViews[i] = desc.renderTargetViews[i]; + framebuffer->renderTargetViews[i] = static_cast(desc.renderTargetViews[i]); framebuffer->renderTargetDescriptors[i] = - static_cast(desc.renderTargetViews[i])->m_descriptor.cpuHandle; + framebuffer->renderTargetViews[i]->m_descriptor.cpuHandle; auto clearValue = static_cast( static_cast(desc.renderTargetViews[i])->m_resource.Ptr()) @@ -4616,7 +4623,7 @@ Result D3D12Device::createFramebuffer(IFramebuffer::Desc const& desc, IFramebuff ->optimalClearValue.color; memcpy(&framebuffer->renderTargetClearValues[i], &clearValue, sizeof(ColorClearValue)); } - framebuffer->depthStencilView = desc.depthStencilView; + framebuffer->depthStencilView = static_cast(desc.depthStencilView); if (desc.depthStencilView) { framebuffer->depthStencilClearValue = @@ -4631,7 +4638,7 @@ Result D3D12Device::createFramebuffer(IFramebuffer::Desc const& desc, IFramebuff { framebuffer->depthStencilDescriptor.ptr = 0; } - *outFb = framebuffer.detach(); + returnComPtr(outFb, framebuffer); return SLANG_OK; } @@ -4654,7 +4661,7 @@ Result D3D12Device::createFramebufferLayout( { layout->m_hasDepthStencil = false; } - *outLayout = layout.detach(); + returnComPtr(outLayout, layout); return SLANG_OK; } @@ -4664,7 +4671,7 @@ Result D3D12Device::createRenderPassLayout( { RefPtr result = new RenderPassLayoutImpl(); result->init(desc); - *outRenderPassLayout = result.detach(); + returnComPtr(outRenderPassLayout, result); return SLANG_OK; } @@ -4711,7 +4718,7 @@ Result D3D12Device::createInputLayout(const InputElementDesc* inputElements, UIn dstEle.InstanceDataStepRate = 0; } - *outLayout = layout.detach(); + returnComPtr(outLayout, layout); return SLANG_OK; } @@ -4765,7 +4772,7 @@ Result D3D12Device::readBufferResource( stageBuf.getResource()->Unmap(0, nullptr); } - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } @@ -4782,7 +4789,7 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr if (desc.slangProgram->getSpecializationParamCount() != 0) { // For a specializable program, we don't invoke any actual slang compilation yet. - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } // For a fully specialized program, read and store its kernel code in `shaderProgram`. @@ -4820,7 +4827,7 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr reinterpret_cast(kernelCode->getBufferPointer()), (Index)kernelCode->getBufferSize()); } - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -4832,7 +4839,7 @@ Result D3D12Device::createShaderObjectLayout( SLANG_RETURN_ON_FAIL( ShaderObjectLayoutImpl::createForElementType( this, typeLayout, layout.writeRef())); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -4844,7 +4851,7 @@ Result D3D12Device::createShaderObject( SLANG_RETURN_ON_FAIL(ShaderObjectImpl::create( this, reinterpret_cast(layout), shaderObject.writeRef())); - *outObject = shaderObject.detach(); + returnComPtr(outObject, shaderObject); return SLANG_OK; } @@ -4857,7 +4864,7 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& { RefPtr pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -4956,7 +4963,7 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& RefPtr pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_pipelineState = pipelineState; pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -4969,7 +4976,7 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i { RefPtr pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -5024,7 +5031,7 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i RefPtr pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_pipelineState = pipelineState; pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } diff --git a/tools/gfx/immediate-renderer-base.cpp b/tools/gfx/immediate-renderer-base.cpp index 8c2f3b927..0ee2b7261 100644 --- a/tools/gfx/immediate-renderer-base.cpp +++ b/tools/gfx/immediate-renderer-base.cpp @@ -19,10 +19,10 @@ using Slang::Guid; namespace { -class CommandBufferImpl : public ICommandBuffer, public RefObject +class CommandBufferImpl : public ICommandBuffer, public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandBuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandBuffer) @@ -32,7 +32,7 @@ public: public: CommandWriter m_writer; - ImmediateRendererBase* m_renderer; + RefPtr m_renderer; RefPtr m_rootShaderObject; void init(ImmediateRendererBase* renderer) @@ -52,7 +52,8 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_IRenderCommandEncoder) + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || + uuid == GfxGUID::IID_IRenderCommandEncoder) { *outObject = static_cast(this); return SLANG_OK; @@ -110,8 +111,8 @@ public: m_writer->setPipelineState(state); auto stateImpl = static_cast(state); SLANG_RETURN_ON_FAIL(m_commandBuffer->m_renderer->createRootShaderObject( - stateImpl->m_program, outRootObject)); - *m_commandBuffer->m_rootShaderObject.writeRef() = static_cast(*outRootObject); + stateImpl->m_program, m_commandBuffer->m_rootShaderObject.writeRef())); + *outRootObject = m_commandBuffer->m_rootShaderObject.Ptr(); return SLANG_OK; } @@ -184,7 +185,8 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_IComputeCommandEncoder) + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || + uuid == GfxGUID::IID_IComputeCommandEncoder) { *outObject = static_cast(this); return SLANG_OK; @@ -215,8 +217,8 @@ public: m_writer->setPipelineState(state); auto stateImpl = static_cast(state); SLANG_RETURN_ON_FAIL(m_commandBuffer->m_renderer->createRootShaderObject( - stateImpl->m_program, outRootObject)); - *m_commandBuffer->m_rootShaderObject.writeRef() = static_cast(*outRootObject); + stateImpl->m_program, m_commandBuffer->m_rootShaderObject.writeRef())); + *outRootObject = m_commandBuffer->m_rootShaderObject.Ptr(); return SLANG_OK; } @@ -242,7 +244,8 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_IResourceCommandEncoder) + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || + uuid == GfxGUID::IID_IResourceCommandEncoder) { *outObject = static_cast(this); return SLANG_OK; @@ -381,33 +384,27 @@ public: } }; -class CommandQueueImpl - : public ICommandQueue - , public RefObject +class CommandQueueImpl : public ImmediateCommandQueueBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - ICommandQueue* getInterface(const Guid& guid) + ICommandQueue::Desc m_desc; + + ImmediateRendererBase* getRenderer() { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue) - return static_cast(this); - return nullptr; + return static_cast(m_renderer.get()); } -public: - ImmediateRendererBase* m_renderer; - ICommandQueue::Desc m_desc; - CommandQueueImpl(ImmediateRendererBase* renderer) - : m_renderer(renderer) { + // Don't establish strong reference to `Device` at start, because + // there will be only one instance of command queue and it will be + // owned by `Device`. We should establish a strong reference only + // when there are external references to the command queue. + m_renderer.setWeakReference(renderer); m_desc.type = ICommandQueue::QueueType::Graphics; } - ~CommandQueueImpl() - { - m_renderer->m_queueCreateCount--; - } + ~CommandQueueImpl() { getRenderer()->m_queueCreateCount--; } virtual SLANG_NO_THROW const Desc& SLANG_MCALL getDesc() override { return m_desc; } @@ -420,10 +417,7 @@ public: } } - virtual SLANG_NO_THROW void SLANG_MCALL wait() override - { - m_renderer->waitForGpu(); - } + virtual SLANG_NO_THROW void SLANG_MCALL wait() override { getRenderer()->waitForGpu(); } }; using TransientResourceHeapImpl = @@ -431,7 +425,8 @@ using TransientResourceHeapImpl = } -ImmediateRendererBase::ImmediateRendererBase() { +ImmediateRendererBase::ImmediateRendererBase() +{ m_queue = new CommandQueueImpl(this); } @@ -441,7 +436,7 @@ SLANG_NO_THROW Result SLANG_MCALL ImmediateRendererBase::createTransientResource { RefPtr result = new TransientResourceHeapImpl(); SLANG_RETURN_ON_FAIL(result->init(this, desc)); - *outHeap = result.detach(); + returnComPtr(outHeap, result); return SLANG_OK; } @@ -453,8 +448,8 @@ SLANG_NO_THROW Result SLANG_MCALL ImmediateRendererBase::createCommandQueue( // Only one queue is supported. if (m_queueCreateCount != 0) return SLANG_FAIL; - *outQueue = m_queue.get(); - m_queue->addRef(); + m_queue->establishStrongReferenceToDevice(); + returnComPtr(outQueue, m_queue); return SLANG_OK; } @@ -464,7 +459,7 @@ SLANG_NO_THROW Result SLANG_MCALL ImmediateRendererBase::createRenderPassLayout( { RefPtr renderPass = new SimpleRenderPassLayout(); renderPass->init(desc); - *outRenderPassLayout = renderPass.detach(); + returnComPtr(outRenderPassLayout, renderPass); return SLANG_OK; } @@ -492,7 +487,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ImmediateRendererBase::readBufferResource return SLANG_FAIL; memcpy(blob->m_data.getBuffer(), content + offset, size); unmap(buffer); - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } diff --git a/tools/gfx/immediate-renderer-base.h b/tools/gfx/immediate-renderer-base.h index 5f1770be0..6d4365cab 100644 --- a/tools/gfx/immediate-renderer-base.h +++ b/tools/gfx/immediate-renderer-base.h @@ -17,14 +17,36 @@ enum class MapFlavor WriteDiscard, }; -class ImmediateRendererBase : public RendererBase +class ImmediateCommandQueueBase + : public ICommandQueue + , public Slang::ComObject { -private: - ComPtr m_currentPipelineState; +public: + // Immediate device also holds a strong reference to an instance of `ImmediateCommandQueue`, + // forming a cyclic reference. Therefore we need a free-op here to break the cycle when + // the public reference count of the queue drops to 0. + SLANG_COM_OBJECT_IUNKNOWN_ALL + ICommandQueue* getInterface(const Slang::Guid& guid) + { + if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue) + return static_cast(this); + return nullptr; + } + virtual void comFree() override { breakStrongReferenceToDevice(); } +public: + BreakableReference m_renderer; + void breakStrongReferenceToDevice() { m_renderer.breakStrongReference(); } + void establishStrongReferenceToDevice() { m_renderer.establishStrongReference(); } +}; + +class ImmediateRendererBase : public RendererBase +{ public: // Immediate commands to be implemented by each target. - virtual Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) = 0; + virtual Result createRootShaderObject( + IShaderProgram* program, + ShaderObjectBase** outObject) = 0; virtual void bindRootShaderObject(IShaderObject* rootObject) = 0; virtual void setPipelineState(IPipelineState* state) = 0; virtual void setFramebuffer(IFramebuffer* frameBuffer) = 0; @@ -55,7 +77,7 @@ public: virtual void unmap(IBufferResource* buffer) = 0; public: - Slang::ComPtr m_queue; + Slang::RefPtr m_queue; uint32_t m_queueCreateCount = 0; ImmediateRendererBase(); diff --git a/tools/gfx/open-gl/render-gl.cpp b/tools/gfx/open-gl/render-gl.cpp index 9997fd5c8..c0b739096 100644 --- a/tools/gfx/open-gl/render-gl.cpp +++ b/tools/gfx/open-gl/render-gl.cpp @@ -134,7 +134,7 @@ public: slang::TypeLayoutReflection* typeLayout, ShaderObjectLayoutBase** outLayout) override; virtual Result createShaderObject(ShaderObjectLayoutBase* layout, IShaderObject** outObject) override; - virtual Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override; + virtual Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) override; virtual void bindRootShaderObject(IShaderObject* shaderObject) override; virtual SLANG_NO_THROW Result SLANG_MCALL @@ -201,17 +201,17 @@ public: GLsizei offset; }; - class InputLayoutImpl: public IInputLayout, public RefObject + class InputLayoutImpl: public IInputLayout, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IInputLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IInputLayout) return static_cast(this); return nullptr; } - public: + public: VertexAttributeDesc m_attributes[kMaxVertexStreams]; UInt m_attributeCount = 0; }; @@ -238,7 +238,7 @@ public: } Usage m_initialUsage; - RefPtr > m_renderer; + RefPtr> m_renderer; GLuint m_handle; GLenum m_target; UInt m_size; @@ -267,15 +267,15 @@ public: } Usage m_initialUsage; - RefPtr > m_renderer; + RefPtr> m_renderer; GLenum m_target; GLuint m_handle; }; - class SamplerStateImpl : public ISamplerState, public RefObject + class SamplerStateImpl : public ISamplerState, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ISamplerState* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISamplerState) @@ -286,10 +286,10 @@ public: GLuint m_samplerID; }; - class ResourceViewImpl : public IResourceView, public RefObject + class ResourceViewImpl : public IResourceView, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -331,10 +331,10 @@ public: class FramebufferLayoutImpl : public IFramebufferLayout - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebufferLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebufferLayout) @@ -350,10 +350,10 @@ public: class FramebufferImpl : public IFramebuffer - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebuffer) @@ -364,7 +364,7 @@ public: public: GLuint m_framebuffer; ShortList m_drawBuffers; - WeakSink* m_renderer; + RefPtr> m_renderer; ShortList> renderTargetViews; RefPtr depthStencilView; ShortList m_colorClearValues; @@ -428,10 +428,10 @@ public: class SwapchainImpl : public ISwapchain - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ISwapchain* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISwapchain) @@ -523,8 +523,7 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getImage(uint32_t index, ITextureResource** outResource) override { - m_images[index]->addRef(); - *outResource = m_images[index].Ptr(); + returnComPtr(outResource, m_images[index]); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL present() override @@ -570,7 +569,7 @@ public: } public: - WeakSink* m_renderer = nullptr; + RefPtr> m_renderer = nullptr; GLuint m_framebuffer; GLuint m_backBuffer; HGLRC m_glrc; @@ -600,7 +599,7 @@ public: } GLuint m_id; - RefPtr > m_renderer; + RefPtr> m_renderer; }; class PipelineStateImpl : public PipelineStateBase @@ -770,7 +769,7 @@ public: RefPtr(new ShaderObjectLayoutImpl()); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } }; @@ -858,7 +857,7 @@ public: RefPtr layout = new RootShaderObjectLayoutImpl(); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -937,10 +936,10 @@ public: ShaderObjectLayoutImpl* layout, ShaderObjectImpl** outShaderObject) { - auto object = ComPtr(new ShaderObjectImpl()); + auto object = RefPtr(new ShaderObjectImpl()); SLANG_RETURN_ON_FAIL(object->init(device, layout)); - *outShaderObject = object.detach(); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -1523,7 +1522,7 @@ public: { SLANG_RETURN_ON_FAIL(_createSpecializedLayout(m_specializedLayout.writeRef())); } - *outLayout = RefPtr(m_specializedLayout).detach(); + returnRefPtr(outLayout, m_specializedLayout); return SLANG_OK; } @@ -1537,10 +1536,11 @@ public: SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&extendedType)); auto renderer = getRenderer(); - RefPtr layout; - SLANG_RETURN_ON_FAIL(renderer->getShaderObjectLayout(extendedType.slangType, layout.writeRef())); + RefPtr layout; + SLANG_RETURN_ON_FAIL(renderer->getShaderObjectLayout( + extendedType.slangType, (ShaderObjectLayoutBase**)layout.writeRef())); - *outLayout = static_cast(layout.detach()); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -1551,13 +1551,19 @@ public: { typedef ShaderObjectImpl Super; + public: + // Override default reference counting behavior to disable lifetime management via ComPtr. + // Root objects are managed by command buffer and does not need to be freed by the user. + SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } + public: static Result create(IDevice* device, RootShaderObjectLayoutImpl* layout, RootShaderObjectImpl** outShaderObject) { RefPtr object = new RootShaderObjectImpl(); SLANG_RETURN_ON_FAIL(object->init(device, layout)); - *outShaderObject = object.detach(); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -1686,7 +1692,7 @@ public: entryPointVars->m_specializedLayout = entryPointInfo.layout; } - *outLayout = specializedLayout.detach(); + returnRefPtrMove(outLayout, specializedLayout); return SLANG_OK; } @@ -1737,7 +1743,7 @@ public: GLuint m_vao; RefPtr m_currentPipelineState; RefPtr m_currentFramebuffer; - RefPtr > m_weakRenderer; + RefPtr> m_weakRenderer; RootBindingState m_rootBindingState; @@ -1792,7 +1798,7 @@ SlangResult SLANG_MCALL createGLDevice(const IDevice::Desc* desc, IDevice** outR { RefPtr result = new GLDevice(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outRenderer = result.detach(); + returnComPtr(outRenderer, result); return SLANG_OK; } @@ -2282,7 +2288,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createSwapchain( { RefPtr swapchain = new SwapchainImpl(); SLANG_RETURN_ON_FAIL(swapchain->init(this, desc, window)); - *outSwapchain = swapchain.detach(); + returnComPtr(outSwapchain, swapchain); wglMakeCurrent(m_hdc, m_glContext); return SLANG_OK; } @@ -2306,7 +2312,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createFramebufferLayout( { layout->m_hasDepthStencil = false; } - *outLayout = layout.detach(); + returnComPtr(outLayout, layout); return SLANG_OK; } @@ -2322,7 +2328,7 @@ SLANG_NO_THROW Result SLANG_MCALL } framebuffer->depthStencilView = static_cast(desc.depthStencilView); framebuffer->createGLFramebuffer(); - *outFramebuffer = framebuffer.detach(); + returnComPtr(outFramebuffer, framebuffer); return SLANG_OK; } @@ -2380,7 +2386,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::readTextureResource( } } - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } @@ -2584,7 +2590,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createTextureResource( texture->m_target = target; - *outResource = texture.detach(); + returnComPtr(outResource, texture); return SLANG_OK; } @@ -2628,7 +2634,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createBufferResource( glBufferData(target, descIn.sizeInBytes, initData, usage); RefPtr resourceImpl = new BufferResourceImpl(initialUsage, desc, m_weakRenderer, bufferID, target); - *outResource = resourceImpl.detach(); + returnComPtr(outResource, resourceImpl); return SLANG_OK; } @@ -2640,7 +2646,7 @@ SLANG_NO_THROW Result SLANG_MCALL RefPtr samplerImpl = new SamplerStateImpl(); samplerImpl->m_samplerID = samplerID; - *outSampler = samplerImpl.detach(); + returnComPtr(outSampler, samplerImpl); return SLANG_OK; } @@ -2671,7 +2677,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createTextureView( viewImpl->layered = GL_TRUE; viewImpl->level = 0; viewImpl->layer = 0; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } @@ -2686,7 +2692,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createBufferView( viewImpl->type = ResourceViewImpl::Type::Buffer; viewImpl->m_resource = resourceImpl; viewImpl->m_bufferID = resourceImpl->m_handle; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } @@ -2706,7 +2712,7 @@ SLANG_NO_THROW Result SLANG_MCALL GLDevice::createInputLayout( glAttr.offset = (GLsizei)inputAttr.offset; } - *outLayout = inputLayout.detach(); + returnComPtr(outLayout, inputLayout); return SLANG_OK; } @@ -2827,7 +2833,7 @@ void GLDevice::setPipelineState(IPipelineState* state) m_currentPipelineState = pipelineStateImpl; - auto program = static_cast(pipelineStateImpl->m_program.get()); + auto program = static_cast(pipelineStateImpl->m_program.Ptr()); GLuint programID = program ? program->m_id : 0; glUseProgram(programID); } @@ -2863,7 +2869,7 @@ Result GLDevice::createProgram(const IShaderProgram::Desc& desc, IShaderProgram* // For a specializable program, we don't invoke any actual slang compilation yet. RefPtr shaderProgram = new ShaderProgramImpl(m_weakRenderer, 0); shaderProgram->slangProgram = desc.slangProgram; - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -2933,7 +2939,7 @@ Result GLDevice::createProgram(const IShaderProgram::Desc& desc, IShaderProgram* RefPtr program = new ShaderProgramImpl(m_weakRenderer, programID); program->slangProgram = desc.slangProgram; - *outProgram = program.detach(); + returnComPtr(outProgram, program); return SLANG_OK; } @@ -2947,7 +2953,7 @@ Result GLDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in RefPtr pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_inputLayout = inputLayoutImpl; pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -2960,7 +2966,7 @@ Result GLDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe RefPtr pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_program = programImpl; pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -2971,7 +2977,7 @@ Result GLDevice::createShaderObjectLayout( RefPtr layout; SLANG_RETURN_ON_FAIL(ShaderObjectLayoutImpl::createForElementType( this, typeLayout, layout.writeRef())); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -2980,11 +2986,11 @@ Result GLDevice::createShaderObject(ShaderObjectLayoutBase* layout, IShaderObjec RefPtr shaderObject; SLANG_RETURN_ON_FAIL(ShaderObjectImpl::create(this, static_cast(layout), shaderObject.writeRef())); - *outObject = shaderObject.detach(); + returnComPtr(outObject, shaderObject); return SLANG_OK; } -Result GLDevice::createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) +Result GLDevice::createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) { auto programImpl = static_cast(program); RefPtr shaderObject; @@ -2993,7 +2999,7 @@ Result GLDevice::createRootShaderObject(IShaderProgram* program, IShaderObject** this, programImpl->slangProgram, programImpl->slangProgram->getLayout(), rootLayout.writeRef())); SLANG_RETURN_ON_FAIL(RootShaderObjectImpl::create( this, rootLayout.Ptr(), shaderObject.writeRef())); - *outObject = shaderObject.detach(); + returnRefPtrMove(outObject, shaderObject); return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 1571e9abf..333f1df54 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -300,7 +300,7 @@ ShaderComponentID ShaderCache::getComponentId(ComponentKey key) return resultId; } -void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::ComPtr specializedPipeline) +void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::RefPtr specializedPipeline) { specializedPipelines[key] = specializedPipeline; } @@ -363,7 +363,7 @@ Result RendererBase::maybeSpecializePipeline( pipelineKey.specializationArgs.addRange(specializationArgs.componentIDs); pipelineKey.updateHash(); - ComPtr specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey); + RefPtr specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey); // Try to find specialized pipeline from shader cache. if (!specializedPipelineState) { @@ -393,30 +393,34 @@ Result RendererBase::maybeSpecializePipeline( SLANG_RETURN_ON_FAIL(createProgram(specializedProgramDesc, specializedProgram.writeRef())); // Create specialized pipeline state. + ComPtr specializedPipelineComPtr; switch (pipelineType) { case PipelineType::Compute: { auto pipelineDesc = currentPipeline->desc.compute; pipelineDesc.program = specializedProgram; - SLANG_RETURN_ON_FAIL(createComputePipelineState(pipelineDesc, specializedPipelineState.writeRef())); + SLANG_RETURN_ON_FAIL( + createComputePipelineState(pipelineDesc, specializedPipelineComPtr.writeRef())); break; } case PipelineType::Graphics: { auto pipelineDesc = currentPipeline->desc.graphics; pipelineDesc.program = specializedProgram; - SLANG_RETURN_ON_FAIL(createGraphicsPipelineState(pipelineDesc, specializedPipelineState.writeRef())); + SLANG_RETURN_ON_FAIL(createGraphicsPipelineState( + pipelineDesc, specializedPipelineComPtr.writeRef())); break; } default: break; } - auto specializedPipelineStateBase = static_cast(specializedPipelineState.get()); - specializedPipelineStateBase->unspecializedPipelineState = currentPipeline; + specializedPipelineState = + static_cast(specializedPipelineComPtr.get()); + specializedPipelineState->unspecializedPipelineState = currentPipeline; shaderCache.addSpecializedPipeline(pipelineKey, specializedPipelineState); } - outNewPipeline = static_cast(specializedPipelineState.get()); + outNewPipeline = static_cast(specializedPipelineState.Ptr()); } return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 79c965631..ec33a3054 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -3,6 +3,7 @@ #include "slang-gfx.h" #include "slang-context.h" #include "core/slang-basic.h" +#include "core/slang-com-object.h" namespace gfx { @@ -34,9 +35,151 @@ struct GfxGUID static const Slang::Guid IID_ICommandQueue; }; +// We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation. +// It is a common scenario where objects created from an `IDevice` implementation needs to hold +// a strong reference to the device object that creates them. For example, a `Buffer` or a +// `CommandQueue` needs to store a `m_device` member that points to the `IDevice`. At the same +// time, the device implementation may also hold a reference to some of the objects it created +// to represent the current device/binding state. Both parties would like to maintain a strong +// reference to each other to achieve robustness against arbitrary ordering of destruction that +// can be triggered by the user. However this creates cyclic reference situations that break +// the `RefPtr` recyling mechanism. To solve this problem, we instead make each object reference +// the device via a `BreakableReference` pointer. A breakable reference can be +// turned into a weak reference via its `breakStrongReference()` call. +// If we know there is a cyclic reference between an API object and the device/pool that creates it, +// we can break the cycle when there is no longer any public references that come from `ComPtr`s to +// the API object, by turning the reference to the device object from the API object to a weak +// reference. +// The following example illustrate how this mechanism works: +// Suppose we have +// ``` +// class DeviceImpl : IDevice { RefPtr m_currentObject; }; +// class ShaderObjectImpl : IShaderObject { BreakableReference m_device; }; +// ``` +// And the user creates a device and a shader object, then somehow having the device reference +// the shader object (this may not happen in actual implemetations, we just use it to illustrate +// the situation): +// ``` +// ComPtr device = createDevice(); +// ComPtr res = device->createResourceX(...); +// device->m_currentResource = res; +// ``` +// This setup is robust to any destruction ordering. If user releases reference to `device` first, +// then the device object will not be freed yet, since there is still a strong reference to the device +// implementation via `res->m_device`. Next when the user releases reference to `res`, the public +// reference count to `res` via `ComPtr`s will go to 0, therefore triggering the call to +// `res->m_device.breakStrongReference()`, releasing the remaining reference to device. This will cause +// `device` to start destruction, which will release its strong reference to `res` during execution of +// its destructor. Finally, this will triger the actual destruction of `res`. +// On the other hand, if the user releases reference to `res` first, then the strong reference to `device` +// will be broken immediately, but the actual destruction of `res` will not start. Next when the user +// releases `device`, there will no longer be any other references to `device`, so the destruction of +// `device` will start, causing the release of the internal reference to `res`, leading to its destruction. +// Note that the above logic only works if it is known that there is a cyclic reference. If there are no +// such cyclic reference, then it will be incorrect to break the strong reference to `IDevice` upon +// public reference counter dropping to 0. This is because the actual destructor of `res` take place +// after breaking the cycle, but if the resource's strong reference to the device is already the last reference, +// turning that reference to weak reference will immediately trigger destruction of `device`, after which +// we can no longer destruct `res` if the destructor needs `device`. Therefore we need to be careful +// when using `BreakableReference`, and make sure we only call `breakStrongReference` only when it is known +// that there is a cyclic reference. Luckily for all scenarios so far this is statically known. +template +class BreakableReference +{ +private: + Slang::RefPtr m_strongPtr; + T* m_weakPtr = nullptr; + +public: + BreakableReference() = default; + + BreakableReference(T* p) { *this = p; } + + BreakableReference(Slang::RefPtr const& p) { *this = p; } + + void setWeakReference(T* p) { m_weakPtr = p; m_strongPtr = nullptr; } + + T& operator*() const { return *get(); } + + T* operator->() const { return get(); } + + T* get() const { return m_weakPtr; } + + operator T*() const { return get(); } + + void operator=(Slang::RefPtr const& p) + { + m_strongPtr = p; + m_weakPtr = p.Ptr(); + } + + void operator=(T* p) + { + m_strongPtr = p; + m_weakPtr = p; + } + + void breakStrongReference() { m_strongPtr = nullptr; } + + void establishStrongReference() { m_strongPtr = m_weakPtr; } +}; + +// Helpers for returning an object implementation as COM pointer. +template +void returnComPtr(TInterface** outInterface, TImpl* rawPtr) +{ + static_assert( + !std::is_base_of::value, + "TInterface must be an interface type."); + rawPtr->addRef(); + *outInterface = rawPtr; +} + +template +void returnComPtr(TInterface** outInterface, const Slang::RefPtr& refPtr) +{ + static_assert( + !std::is_base_of::value, + "TInterface must be an interface type."); + refPtr->addRef(); + *outInterface = refPtr.Ptr(); +} + +template +void returnComPtr(TInterface** outInterface, Slang::ComPtr& comPtr) +{ + static_assert( + !std::is_base_of::value, + "TInterface must be an interface type."); + *outInterface = comPtr.detach(); +} + +// Helpers for returning an object implementation as RefPtr. +template +void returnRefPtr(TDest** outPtr, Slang::RefPtr& refPtr) +{ + static_assert( + std::is_base_of::value, "TDest must be a non-interface type."); + static_assert( + std::is_base_of::value, "TImpl must be a non-interface type."); + *outPtr = refPtr.Ptr(); + refPtr->addReference(); +} + +template +void returnRefPtrMove(TDest** outPtr, Slang::RefPtr& refPtr) +{ + static_assert( + std::is_base_of::value, "TDest must be a non-interface type."); + static_assert( + std::is_base_of::value, "TImpl must be a non-interface type."); + *outPtr = refPtr.detach(); +} + + gfx::StageType translateStage(SlangStage slangStage); -class Resource : public Slang::RefObject +class Resource : public Slang::ComObject { public: /// Get the type @@ -56,7 +199,7 @@ protected: class BufferResource : public IBufferResource, public Resource { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResource* getInterface(const Slang::Guid& guid); public: @@ -78,7 +221,7 @@ protected: class TextureResource : public ITextureResource, public Resource { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResource* getInterface(const Slang::Guid& guid); public: @@ -138,6 +281,9 @@ struct ExtendedShaderObjectTypeList class ShaderObjectLayoutBase : public Slang::RefObject { protected: + // We always use a weak reference to the `IDevice` object here. + // `ShaderObject` implementations will make sure to hold a strong reference to `IDevice` + // while a `ShaderObjectLayout` may still be used. RendererBase* m_renderer; slang::TypeLayoutReflection* m_elementTypeLayout = nullptr; ShaderComponentID m_componentID = 0; @@ -182,9 +328,13 @@ public: void initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout); }; -class ShaderObjectBase : public IShaderObject, public Slang::RefObject +class ShaderObjectBase : public IShaderObject, public Slang::ComObject { protected: + // A strong reference to `IDevice` to make sure the weak device reference in + // `ShaderObjectLayout`s are valid whenever they might be used. + BreakableReference m_device; + // The shader object layout used to create this shader object. Slang::RefPtr m_layout = nullptr; @@ -198,8 +348,9 @@ protected: Result _getSpecializedShaderObjectType(ExtendedShaderObjectType* outType); public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IShaderObject* getInterface(const Slang::Guid& guid); + void breakStrongReferenceToDevice() { m_device.breakStrongReference(); } public: ShaderComponentID getComponentID() @@ -235,21 +386,21 @@ public: virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) = 0; }; -class ShaderProgramBase : public IShaderProgram, public Slang::RefObject +class ShaderProgramBase : public IShaderProgram, public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - + SLANG_COM_OBJECT_IUNKNOWN_ALL IShaderProgram* getInterface(const Slang::Guid& guid); ComPtr slangProgram; }; -class PipelineStateBase : public IPipelineState, public Slang::RefObject +class PipelineStateBase + : public IPipelineState + , public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - + SLANG_COM_OBJECT_IUNKNOWN_ALL IPipelineState* getInterface(const Slang::Guid& guid); struct PipelineStateDesc @@ -270,10 +421,10 @@ public: // Indicates whether this is a specializable pipeline. A specializable // pipeline cannot be used directly and must be specialized first. bool isSpecializable = false; - ComPtr m_program; + Slang::RefPtr m_program; template TProgram* getProgram() { - return static_cast(m_program.get()); + return static_cast(m_program.Ptr()); } protected: @@ -360,14 +511,16 @@ public: ShaderComponentID getComponentId(Slang::UnownedStringSlice name); ShaderComponentID getComponentId(ComponentKey key); - Slang::ComPtr getSpecializedPipelineState(PipelineKey programKey) + Slang::RefPtr getSpecializedPipelineState(PipelineKey programKey) { - Slang::ComPtr result; + Slang::RefPtr result; if (specializedPipelines.TryGetValue(programKey, result)) return result; return nullptr; } - void addSpecializedPipeline(PipelineKey key, Slang::ComPtr specializedPipeline); + void addSpecializedPipeline( + PipelineKey key, + Slang::RefPtr specializedPipeline); void free() { specializedPipelines = decltype(specializedPipelines)(); @@ -376,16 +529,16 @@ public: protected: Slang::OrderedDictionary componentIds; - Slang::OrderedDictionary> specializedPipelines; + Slang::OrderedDictionary> specializedPipelines; }; // Renderer implementation shared by all platforms. // Responsible for shader compilation, specialization and caching. -class RendererBase : public Slang::RefObject, public IDevice +class RendererBase : public IDevice, public Slang::ComObject { friend class ShaderObjectBase; public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures( const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE; diff --git a/tools/gfx/simple-render-pass-layout.h b/tools/gfx/simple-render-pass-layout.h index 54d1e5649..14fffe37f 100644 --- a/tools/gfx/simple-render-pass-layout.h +++ b/tools/gfx/simple-render-pass-layout.h @@ -6,7 +6,7 @@ // concept. #include "slang-gfx.h" -#include "slang-com-helper.h" +#include "core/slang-com-object.h" #include "core/slang-basic.h" namespace gfx @@ -14,10 +14,10 @@ namespace gfx class SimpleRenderPassLayout : public IRenderPassLayout - , public Slang::RefObject + , public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IRenderPassLayout* getInterface(const Slang::Guid& guid); public: diff --git a/tools/gfx/simple-transient-resource-heap.h b/tools/gfx/simple-transient-resource-heap.h index 5f6c32451..55731ddd0 100644 --- a/tools/gfx/simple-transient-resource-heap.h +++ b/tools/gfx/simple-transient-resource-heap.h @@ -11,10 +11,10 @@ namespace gfx template class SimpleTransientResourceHeap : public ITransientResourceHeap - , public Slang::RefObject + , public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ITransientResourceHeap* getInterface(const Slang::Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ITransientResourceHeap) @@ -23,8 +23,8 @@ public: } public: - TDevice* m_device; - ComPtr m_constantBuffer; + Slang::RefPtr m_device; + Slang::ComPtr m_constantBuffer; public: Result init(TDevice* device, const ITransientResourceHeap::Desc& desc) @@ -43,7 +43,7 @@ public: { Slang::RefPtr newCmdBuffer = new TCommandBuffer(); newCmdBuffer->init(m_device); - *outCommandBuffer = newCmdBuffer.detach(); + returnComPtr(outCommandBuffer, newCmdBuffer); return SLANG_OK; } diff --git a/tools/gfx/transient-resource-heap-base.h b/tools/gfx/transient-resource-heap-base.h index 2376ab1ac..2db463e77 100644 --- a/tools/gfx/transient-resource-heap-base.h +++ b/tools/gfx/transient-resource-heap-base.h @@ -1,4 +1,4 @@ -#include "slang-gfx.h" +#include "renderer-shared.h" #include "source/core/slang-basic.h" namespace gfx @@ -6,19 +6,20 @@ namespace gfx template class TransientResourceHeapBase : public ITransientResourceHeap - , public Slang::RefObject + , public Slang::ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ITransientResourceHeap* getInterface(const Slang::Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ITransientResourceHeap) return static_cast(this); return nullptr; } + void breakStrongReferenceToDevice() { m_device.breakStrongReference(); } public: - TDevice* m_device; + BreakableReference m_device; Slang::List> m_constantBuffers; Slang::Index m_constantBufferAllocCounter = 0; size_t m_constantBufferOffsetAllocCounter = 0; diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 32ff0e7a2..2ce51790a 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -153,10 +153,10 @@ public: const VulkanApi* m_api; }; - class InputLayoutImpl : public IInputLayout, public RefObject + class InputLayoutImpl : public IInputLayout, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IInputLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IInputLayout) @@ -170,7 +170,7 @@ public: class BufferResourceImpl: public BufferResource { - public: + public: typedef BufferResource Parent; BufferResourceImpl(IResource::Usage initialUsage, const IBufferResource::Desc& desc, VKDevice* renderer): @@ -182,7 +182,7 @@ public: } IResource::Usage m_initialUsage; - VKDevice* m_renderer; + RefPtr m_renderer; Buffer m_buffer; Buffer m_uploadBuffer; }; @@ -192,24 +192,19 @@ public: public: typedef TextureResource Parent; - TextureResourceImpl(const Desc& desc, Usage initialUsage, const VulkanApi* api) : + TextureResourceImpl(const Desc& desc, Usage initialUsage, VKDevice* device) : Parent(desc), m_initialUsage(initialUsage), - m_api(api) + m_device(device) { } ~TextureResourceImpl() { - if (m_api) + auto& vkAPI = m_device->m_api; + if (!m_isWeakImageReference) { - if (m_imageMemory != VK_NULL_HANDLE) - { - m_api->vkFreeMemory(m_api->m_device, m_imageMemory, nullptr); - } - if (m_image != VK_NULL_HANDLE && !m_isWeakImageReference) - { - m_api->vkDestroyImage(m_api->m_device, m_image, nullptr); - } + vkAPI.vkFreeMemory(vkAPI.m_device, m_imageMemory, nullptr); + vkAPI.vkDestroyImage(vkAPI.m_device, m_image, nullptr); } } @@ -219,13 +214,13 @@ public: VkFormat m_vkformat = VK_FORMAT_R8G8B8A8_UNORM; VkDeviceMemory m_imageMemory = VK_NULL_HANDLE; bool m_isWeakImageReference = false; - const VulkanApi* m_api; + RefPtr m_device; }; - class SamplerStateImpl : public ISamplerState, public RefObject + class SamplerStateImpl : public ISamplerState, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ISamplerState* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISamplerState) @@ -234,20 +229,21 @@ public: } public: VkSampler m_sampler; - const VulkanApi* m_api; - SamplerStateImpl(const VulkanApi* api) - : m_api(api) - {} + RefPtr m_device; + SamplerStateImpl(VKDevice* device) + : m_device(device) + { + } ~SamplerStateImpl() { - m_api->vkDestroySampler(m_api->m_device, m_sampler, nullptr); + m_device->m_api.vkDestroySampler(m_device->m_api.m_device, m_sampler, nullptr); } }; - class ResourceViewImpl : public IResourceView, public RefObject + class ResourceViewImpl : public IResourceView, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -262,24 +258,25 @@ public: PlainBuffer, }; public: - ResourceViewImpl(ViewType viewType, const VulkanApi* api) - : m_type(viewType), m_api(api) + ResourceViewImpl(ViewType viewType, VKDevice* device) + : m_type(viewType) + , m_device(device) { } ViewType m_type; - const VulkanApi* m_api; + RefPtr m_device; }; class TextureResourceViewImpl : public ResourceViewImpl { public: - TextureResourceViewImpl(const VulkanApi* api) - : ResourceViewImpl(ViewType::Texture, api) + TextureResourceViewImpl(VKDevice* device) + : ResourceViewImpl(ViewType::Texture, device) { } ~TextureResourceViewImpl() { - m_api->vkDestroyImageView(m_api->m_device, m_view, nullptr); + m_device->m_api.vkDestroyImageView(m_device->m_api.m_device, m_view, nullptr); } RefPtr m_texture; VkImageView m_view; @@ -289,13 +286,13 @@ public: class TexelBufferResourceViewImpl : public ResourceViewImpl { public: - TexelBufferResourceViewImpl(const VulkanApi* api) - : ResourceViewImpl(ViewType::TexelBuffer, api) + TexelBufferResourceViewImpl(VKDevice* device) + : ResourceViewImpl(ViewType::TexelBuffer, device) { } ~TexelBufferResourceViewImpl() { - m_api->vkDestroyBufferView(m_api->m_device, m_view, nullptr); + m_device->m_api.vkDestroyBufferView(m_device->m_api.m_device, m_view, nullptr); } RefPtr m_buffer; VkBufferView m_view; @@ -304,8 +301,8 @@ public: class PlainBufferResourceViewImpl : public ResourceViewImpl { public: - PlainBufferResourceViewImpl(const VulkanApi* api) - : ResourceViewImpl(ViewType::PlainBuffer, api) + PlainBufferResourceViewImpl(VKDevice* device) + : ResourceViewImpl(ViewType::PlainBuffer, device) { } RefPtr m_buffer; @@ -315,10 +312,10 @@ public: class FramebufferLayoutImpl : public IFramebufferLayout - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebufferLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebufferLayout) @@ -328,12 +325,13 @@ public: public: VkRenderPass m_renderPass; - VKDevice* m_renderer; + RefPtr m_renderer; Array m_attachmentDescs; Array m_colorReferences; VkAttachmentReference m_depthReference; bool m_hasDepthStencilAttachment; uint32_t m_renderTargetCount; + public: ~FramebufferLayoutImpl() { @@ -428,10 +426,10 @@ public: class RenderPassLayoutImpl : public IRenderPassLayout - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IRenderPassLayout* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderPassLayout) @@ -441,8 +439,7 @@ public: public: VkRenderPass m_renderPass; - VKDevice* m_renderer; - + RefPtr m_renderer; ~RenderPassLayoutImpl() { m_renderer->m_api.vkDestroyRenderPass( @@ -537,10 +534,10 @@ public: class FramebufferImpl : public IFramebuffer - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IFramebuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IFramebuffer) @@ -554,7 +551,7 @@ public: ComPtr depthStencilView; uint32_t m_width; uint32_t m_height; - VKDevice* m_renderer; + RefPtr m_renderer; VkClearValue m_clearValues[kMaxAttachments]; RefPtr m_layout; public: @@ -644,18 +641,31 @@ public: class PipelineStateImpl : public PipelineStateBase { public: - PipelineStateImpl(const VulkanApi& api): - m_api(&api) + PipelineStateImpl(VKDevice* device) { + // Only weakly reference `device` at start. + // We make it a strong reference only when the pipeline state is exposed to the user. + // Note that `PipelineState`s may also be created via implicit specialization that + // happens behind the scenes, and the user will not have access to those specialized + // pipeline states. Only those pipeline states that are returned to the user needs to + // hold a strong reference to `device`. + m_device.setWeakReference(device); } ~PipelineStateImpl() { if (m_pipeline != VK_NULL_HANDLE) { - m_api->vkDestroyPipeline(m_api->m_device, m_pipeline, nullptr); + m_device->m_api.vkDestroyPipeline(m_device->m_api.m_device, m_pipeline, nullptr); } } + // Turns `m_device` into a strong reference. + // This method should be called before returning the pipeline state object to + // external users (i.e. via an `IPipelineState` pointer). + void establishStrongDeviceReference() { m_device.establishStrongReference(); } + + virtual void comFree() override { m_device.breakStrongReference(); } + void init(const GraphicsPipelineStateDesc& inDesc) { PipelineStateDesc pipelineDesc; @@ -671,9 +681,7 @@ public: initializeBase(pipelineDesc); } - const VulkanApi* m_api; - - RefPtr m_framebufferLayout; + BreakableReference m_device; VkPipeline m_pipeline = VK_NULL_HANDLE; }; @@ -835,16 +843,17 @@ public: Result setElementTypeLayout(slang::TypeLayoutReflection* typeLayout) { - typeLayout = _unwrapParameterGroups(typeLayout); - - m_elementTypeLayout = typeLayout; - // First we will use the Slang layout information to allocate // the descriptor set layout(s) required to store values // of the given type. // SLANG_RETURN_ON_FAIL(_addDescriptorSets(typeLayout)); + typeLayout = _unwrapParameterGroups(typeLayout); + + m_elementTypeLayout = typeLayout; + + // Next we will compute the binding ranges that are used to store // the logical contents of the object in memory. These will relate // to the descriptor ranges in the various sets, but not always @@ -944,7 +953,7 @@ public: auto layout = RefPtr(new ShaderObjectLayoutImpl()); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } }; @@ -1053,7 +1062,7 @@ public: RefPtr layout = new EntryPointLayout(); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -1125,7 +1134,7 @@ public: { RefPtr layout = new RootShaderObjectLayout(); SLANG_RETURN_ON_FAIL(layout->_init(this)); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -1299,14 +1308,14 @@ public: VkPipelineLayout m_pipelineLayout = VK_NULL_HANDLE; Array m_vkDescriptorSetLayouts; Array m_pushConstantRanges; - RefPtr m_renderer; + VKDevice* m_renderer = nullptr; }; class ShaderProgramImpl : public ShaderProgramBase { public: - ShaderProgramImpl(const VulkanApi& api, PipelineType pipelineType) - : m_api(&api) + ShaderProgramImpl(VKDevice* device, PipelineType pipelineType) + : m_device(device) , m_pipelineType(pipelineType) { for (auto& shaderModule : m_modules) @@ -1319,12 +1328,18 @@ public: { if (shaderModule != VK_NULL_HANDLE) { - m_api->vkDestroyShaderModule(m_api->m_device, shaderModule, nullptr); + m_device->m_api.vkDestroyShaderModule( + m_device->m_api.m_device, shaderModule, nullptr); } } } - const VulkanApi* m_api; + virtual void comFree() override + { + m_device.breakStrongReference(); + } + + BreakableReference m_device; PipelineType m_pipelineType; @@ -1478,7 +1493,7 @@ public: auto object = RefPtr(new ShaderObjectImpl()); SLANG_RETURN_ON_FAIL(object->init(device, layout)); - *outShaderObject = object.detach(); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -1641,7 +1656,7 @@ public: } } - return SLANG_E_NOT_IMPLEMENTED; + return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL @@ -1656,8 +1671,7 @@ public: auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); auto object = m_objects[bindingRange.baseIndex + offset.bindingArrayIndex].Ptr(); - object->addRef(); - *outObject = object; + returnComPtr(outObject, object); // auto& subObjectRange = // m_layout->getSubObjectRange(bindingRange.subObjectRangeIndex); *outObject = @@ -2488,7 +2502,7 @@ public: { SLANG_RETURN_ON_FAIL(_createSpecializedLayout(m_specializedLayout.writeRef())); } - *outLayout = RefPtr(m_specializedLayout).detach(); + returnRefPtr(outLayout, m_specializedLayout); return SLANG_OK; } @@ -2502,11 +2516,11 @@ public: SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&extendedType)); auto device = getDevice(); - RefPtr layout; - SLANG_RETURN_ON_FAIL( - device->getShaderObjectLayout(extendedType.slangType, layout.writeRef())); + RefPtr layout; + SLANG_RETURN_ON_FAIL(device->getShaderObjectLayout( + extendedType.slangType, (ShaderObjectLayoutBase**)layout.writeRef())); - *outLayout = static_cast(layout.detach()); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -2526,7 +2540,7 @@ public: RefPtr object = new EntryPointShaderObject(); SLANG_RETURN_ON_FAIL(object->init(device, layout)); - *outShaderObject = object.detach(); + returnRefPtrMove(outShaderObject, object); return SLANG_OK; } @@ -2595,8 +2609,7 @@ public: SlangResult SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) SLANG_OVERRIDE { - *outEntryPoint = m_entryPoints[index]; - m_entryPoints[index]->addRef(); + returnComPtr(outEntryPoint, m_entryPoints[index]); return SLANG_OK; } @@ -2723,7 +2736,7 @@ public: entryPointVars->m_specializedLayout = entryPointInfo.layout; } - *outLayout = specializedLayout.detach(); + returnRefPtrMove(outLayout, specializedLayout); return SLANG_OK; } @@ -2734,24 +2747,27 @@ public: class CommandBufferImpl : public ICommandBuffer - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + // There are a pair of cyclic references between a `TransientResourceHeap` and + // a `CommandBuffer` created from the heap. We need to break the cycle when + // the public reference count of a command buffer drops to 0. + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandBuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandBuffer) return static_cast(this); return nullptr; } - + virtual void comFree() override { m_transientHeap.breakStrongReference(); } public: VkCommandBuffer m_commandBuffer; VkCommandBuffer m_preCommandBuffer = VK_NULL_HANDLE; VkCommandPool m_pool; VkFence m_fence; VKDevice* m_renderer; - TransientResourceHeapImpl* m_transientHeap; + BreakableReference m_transientHeap; bool m_isPreCommandBufferEmpty = true; RootShaderObjectImpl m_rootObject; // Command buffers are deallocated by its command pool, @@ -2838,15 +2854,20 @@ public: VkIndexType m_boundIndexFormat; public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IRenderCommandEncoder* getInterface(const Guid& guid) + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + queryInterface(SlangUUID const& uuid, void** outObject) override { - if (guid == GfxGUID::IID_ISlangUnknown || - guid == GfxGUID::IID_IRenderCommandEncoder || - guid == GfxGUID::IID_ICommandEncoder) - return static_cast(this); - return nullptr; + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || + uuid == GfxGUID::IID_IRenderCommandEncoder) + { + *outObject = static_cast(this); + return SLANG_OK; + } + *outObject = nullptr; + return SLANG_E_NO_INTERFACE; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } void beginPass(IRenderPassLayout* renderPass, IFramebuffer* framebuffer) { @@ -3012,7 +3033,7 @@ public: void prepareDraw() { auto pipeline = static_cast(m_currentPipeline.Ptr()); - if (!pipeline || static_cast(pipeline->m_program.get()) + if (!pipeline || static_cast(pipeline->m_program.Ptr()) ->m_pipelineType != PipelineType::Graphics) { assert(!"Invalid render pipeline"); @@ -3074,7 +3095,6 @@ public: assert(!m_renderCommandEncoder->m_isOpen); m_renderCommandEncoder->beginPass(renderPass, framebuffer); *outEncoder = m_renderCommandEncoder.Ptr(); - m_renderCommandEncoder->addRef(); } class ComputeCommandEncoder @@ -3082,16 +3102,20 @@ public: , public PipelineCommandEncoder { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IComputeCommandEncoder* getInterface(const Guid& guid) + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + queryInterface(SlangUUID const& uuid, void** outObject) override { - if (guid == GfxGUID::IID_ISlangUnknown || - guid == GfxGUID::IID_IComputeCommandEncoder || - guid == GfxGUID::IID_ICommandEncoder) - return static_cast(this); - return nullptr; + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || + uuid == GfxGUID::IID_IComputeCommandEncoder) + { + *outObject = static_cast(this); + return SLANG_OK; + } + *outObject = nullptr; + return SLANG_E_NO_INTERFACE; } - + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } public: virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override { @@ -3108,7 +3132,7 @@ public: { auto pipeline = static_cast(m_currentPipeline.Ptr()); if (!pipeline || - static_cast(pipeline->m_program.get())->m_pipelineType != + static_cast(pipeline->m_program.Ptr())->m_pipelineType != PipelineType::Compute) { assert(!"Invalid compute pipeline"); @@ -3133,7 +3157,6 @@ public: } assert(!m_computeCommandEncoder->m_isOpen); *outEncoder = m_computeCommandEncoder.Ptr(); - m_computeCommandEncoder->addRef(); } class ResourceCommandEncoder @@ -3143,16 +3166,20 @@ public: public: CommandBufferImpl* m_commandBuffer; public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IResourceCommandEncoder* getInterface(const Guid& guid) + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + queryInterface(SlangUUID const& uuid, void** outObject) override { - if (guid == GfxGUID::IID_ISlangUnknown || - guid == GfxGUID::IID_IResourceCommandEncoder || - guid == GfxGUID::IID_ICommandEncoder) - return static_cast(this); - return nullptr; + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || + uuid == GfxGUID::IID_IResourceCommandEncoder) + { + *outObject = static_cast(this); + return SLANG_OK; + } + *outObject = nullptr; + return SLANG_E_NO_INTERFACE; } - + virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } public: virtual SLANG_NO_THROW void SLANG_MCALL copyBuffer( IBufferResource* dst, @@ -3230,7 +3257,6 @@ public: m_resourceCommandEncoder->init(this); } *outEncoder = m_resourceCommandEncoder.Ptr(); - m_resourceCommandEncoder->addRef(); } virtual SLANG_NO_THROW void SLANG_MCALL close() override @@ -3263,10 +3289,10 @@ public: class CommandQueueImpl : public ICommandQueue - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandQueue* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue) @@ -3383,6 +3409,7 @@ public: Result init(const ITransientResourceHeap::Desc& desc, VKDevice* device); ~TransientResourceHeapImpl() { + m_commandBufferPool = decltype(m_commandBufferPool)(); m_device->m_api.vkDestroyCommandPool(m_device->m_api.m_device, m_commandPool, nullptr); m_device->m_api.vkDestroyFence(m_device->m_api.m_device, m_fence, nullptr); m_descSetAllocator.close(); @@ -3395,10 +3422,10 @@ public: class SwapchainImpl : public ISwapchain - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ISwapchain* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ISwapchain) @@ -3576,7 +3603,7 @@ public: imageDesc.init2D( IResource::Type::Texture2D, m_desc.format, m_desc.width, m_desc.height, 1); RefPtr image = new TextureResourceImpl( - imageDesc, gfx::IResource::Usage::RenderTarget, m_api); + imageDesc, gfx::IResource::Usage::RenderTarget, m_renderer); image->m_image = vkImages[i]; image->m_imageMemory = 0; image->m_vkformat = m_vkformat; @@ -3700,8 +3727,7 @@ public: { if (m_images.getCount() <= (Index)index) return SLANG_FAIL; - *outResource = m_images[index]; - m_images[index]->addRef(); + returnComPtr(outResource, m_images[index]); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL resize(uint32_t width, uint32_t height) override @@ -3800,6 +3826,18 @@ public: DescriptorSetAllocator descriptorSetAllocator; uint32_t m_queueAllocCount; + + // A list to hold objects that may have a strong back reference to the device + // instance. Because of the pipeline cache in `RendererBase`, there could be a reference + // cycle among `VKDevice`->`PipelineStateImpl`->`ShaderProgramImpl`->`VkDevice`. + // Depending on whether a `PipelineState` objects gets stored in pipeline cache, there + // may or may not be such a reference cycle. + // We need to hold strong references to any objects that may become part of the reference + // cycle here, so that when objects like `ShaderProgramImpl` lost all public refernces, we + // can always safely break the strong reference in `ShaderProgramImpl::m_device` without + // worrying the `ShaderProgramImpl` object getting destroyed after the completion of + // `VKDevice::~VKDevice()'. + ChunkedList, 1024> m_deviceObjectsWithPotentialBackReferences; }; void VKDevice::PipelineCommandEncoder::init(CommandBufferImpl* commandBuffer) @@ -3894,7 +3932,7 @@ Result SLANG_MCALL createVKDevice(const IDevice::Desc* desc, IDevice** outRender { RefPtr result = new VKDevice(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outRenderer = result.detach(); + returnComPtr(outRenderer, result); return SLANG_OK; } @@ -3907,8 +3945,8 @@ VKDevice::~VKDevice() } m_shaderObjectLayoutCache = decltype(m_shaderObjectLayoutCache)(); - shaderCache.free(); + m_deviceObjectsWithPotentialBackReferences.clearAndDeallocate(); // Same as clear but, also dtors all elements, which clear does not m_deviceQueue.destroy(); @@ -4374,7 +4412,7 @@ Result VKDevice::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffer** auto result = m_commandBufferPool[m_commandBufferAllocId]; result->beginCommandBuffer(); m_commandBufferAllocId++; - *outCmdBuffer = result.detach(); + returnComPtr(outCmdBuffer, result); return SLANG_OK; } @@ -4383,7 +4421,7 @@ Result VKDevice::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffer** m_device, m_commandPool, m_fence, this)); m_commandBufferPool.add(commandBuffer); m_commandBufferAllocId++; - *outCmdBuffer = commandBuffer.detach(); + returnComPtr(outCmdBuffer, commandBuffer); return SLANG_OK; } @@ -4407,7 +4445,7 @@ Result VKDevice::createTransientResourceHeap( { RefPtr result = new TransientResourceHeapImpl(); SLANG_RETURN_ON_FAIL(result->init(desc, this)); - *outHeap = result.detach(); + returnComPtr(outHeap, result); return SLANG_OK; } @@ -4421,7 +4459,7 @@ Result VKDevice::createCommandQueue(const ICommandQueue::Desc& desc, ICommandQue m_api.vkGetDeviceQueue(m_api.m_device, queueFamilyIndex, 0, &vkQueue); RefPtr result = new CommandQueueImpl(); result->init(this, vkQueue, queueFamilyIndex); - *outQueue = result.detach(); + returnComPtr(outQueue, result); m_queueAllocCount++; return SLANG_OK; } @@ -4438,7 +4476,7 @@ Result VKDevice::createSwapchain( RefPtr sc = new SwapchainImpl(); SLANG_RETURN_ON_FAIL(sc->init(this, desc, window)); - *outSwapchain = sc.detach(); + returnComPtr(outSwapchain, sc); return SLANG_OK; } @@ -4446,7 +4484,7 @@ Result VKDevice::createFramebufferLayout(const IFramebufferLayout::Desc& desc, I { RefPtr layout = new FramebufferLayoutImpl(); SLANG_RETURN_ON_FAIL(layout->init(this, desc)); - *outLayout = layout.detach(); + returnComPtr(outLayout, layout); return SLANG_OK; } @@ -4456,7 +4494,7 @@ Result VKDevice::createRenderPassLayout( { RefPtr result = new RenderPassLayoutImpl(); SLANG_RETURN_ON_FAIL(result->init(this, desc)); - *outRenderPassLayout = result.detach(); + returnComPtr(outRenderPassLayout, result); return SLANG_OK; } @@ -4464,7 +4502,7 @@ Result VKDevice::createFramebuffer(const IFramebuffer::Desc& desc, IFramebuffer* { RefPtr fb = new FramebufferImpl(); SLANG_RETURN_ON_FAIL(fb->init(this, desc)); - *outFramebuffer = fb.detach(); + returnComPtr(outFramebuffer, fb); return SLANG_OK; } @@ -4520,7 +4558,7 @@ SlangResult VKDevice::readBufferResource( ::memcpy(blob->m_data.getBuffer(), mappedData, size); m_api.vkUnmapMemory(m_device, staging.m_memory); - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } @@ -4760,7 +4798,7 @@ Result VKDevice::createTextureResource(IResource::Usage initialUsage, const ITex const int arraySize = desc.calcEffectiveArraySize(); - RefPtr texture(new TextureResourceImpl(desc, initialUsage, &m_api)); + RefPtr texture(new TextureResourceImpl(desc, initialUsage, this)); texture->m_vkformat = format; // Create the image { @@ -4989,7 +5027,7 @@ Result VKDevice::createTextureResource(IResource::Usage initialUsage, const ITex } } m_deviceQueue.flushAndWait(); - *outResource = texture.detach(); + returnComPtr(outResource, texture); return SLANG_OK; } @@ -5043,7 +5081,7 @@ Result VKDevice::createBufferResource(IResource::Usage initialUsage, const IBuff m_deviceQueue.flush(); } - *outResource = buffer.detach(); + returnComPtr(outResource, buffer); return SLANG_OK; } @@ -5187,16 +5225,16 @@ Result VKDevice::createSamplerState(ISamplerState::Desc const& desc, ISamplerSta VkSampler sampler; SLANG_VK_RETURN_ON_FAIL(m_api.vkCreateSampler(m_device, &samplerInfo, nullptr, &sampler)); - RefPtr samplerImpl = new SamplerStateImpl(&m_api); + RefPtr samplerImpl = new SamplerStateImpl(this); samplerImpl->m_sampler = sampler; - *outSampler = samplerImpl.detach(); + returnComPtr(outSampler, samplerImpl); return SLANG_OK; } Result VKDevice::createTextureView(ITextureResource* texture, IResourceView::Desc const& desc, IResourceView** outView) { auto resourceImpl = static_cast(texture); - RefPtr view = new TextureResourceViewImpl(&m_api); + RefPtr view = new TextureResourceViewImpl(this); view->m_texture = resourceImpl; VkImageViewCreateInfo createInfo = {}; createInfo.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO; @@ -5266,7 +5304,7 @@ Result VKDevice::createTextureView(ITextureResource* texture, IResourceView::Des break; } m_api.vkCreateImageView(m_device, &createInfo, nullptr, &view->m_view); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -5310,11 +5348,11 @@ Result VKDevice::createBufferView(IBufferResource* buffer, IResourceView::Desc c { // Buffer usage that doesn't involve formatting doesn't // require a view in Vulkan. - RefPtr viewImpl = new PlainBufferResourceViewImpl(&m_api); + RefPtr viewImpl = new PlainBufferResourceViewImpl(this); viewImpl->m_buffer = resourceImpl; viewImpl->offset = 0; viewImpl->size = size; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } // @@ -5334,10 +5372,10 @@ Result VKDevice::createBufferView(IBufferResource* buffer, IResourceView::Desc c VkBufferView view; SLANG_VK_RETURN_ON_FAIL(m_api.vkCreateBufferView(m_device, &info, nullptr, &view)); - RefPtr viewImpl = new TexelBufferResourceViewImpl(&m_api); + RefPtr viewImpl = new TexelBufferResourceViewImpl(this); viewImpl->m_buffer = resourceImpl; viewImpl->m_view = view; - *outView = viewImpl.detach(); + returnComPtr(outView, viewImpl); return SLANG_OK; } break; @@ -5377,7 +5415,7 @@ Result VKDevice::createInputLayout(const InputElementDesc* elements, UInt numEle // Work out the overall size layout->m_vertexSize = int(vertexSize); - *outLayout = layout.detach(); + returnComPtr(outLayout, layout); return SLANG_OK; } @@ -5406,9 +5444,11 @@ static VkImageViewType _calcImageViewType(ITextureResource::Type type, const ITe Result VKDevice::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) { - RefPtr shaderProgram = new ShaderProgramImpl(m_api, desc.pipelineType); + RefPtr shaderProgram = new ShaderProgramImpl(this, desc.pipelineType); shaderProgram->m_pipelineType = desc.pipelineType; shaderProgram->slangProgram = desc.slangProgram; + m_deviceObjectsWithPotentialBackReferences.add(shaderProgram); + RootShaderObjectLayout::create( this, desc.slangProgram, @@ -5417,7 +5457,7 @@ Result VKDevice::createProgram(const IShaderProgram::Desc& desc, IShaderProgram* if (desc.slangProgram->getSpecializationParamCount() != 0) { // For a specializable program, we don't invoke any actual slang compilation yet. - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -5444,7 +5484,7 @@ Result VKDevice::createProgram(const IShaderProgram::Desc& desc, IShaderProgram* shaderModule)); shaderProgram->m_modules.add(shaderModule); } - *outProgram = shaderProgram.detach(); + returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -5455,7 +5495,7 @@ Result VKDevice::createShaderObjectLayout( RefPtr layout; SLANG_RETURN_ON_FAIL( ShaderObjectLayoutImpl::createForElementType(this, typeLayout, layout.writeRef())); - *outLayout = layout.detach(); + returnRefPtrMove(outLayout, layout); return SLANG_OK; } @@ -5464,7 +5504,7 @@ Result VKDevice::createShaderObject(ShaderObjectLayoutBase* layout, IShaderObjec RefPtr shaderObject; SLANG_RETURN_ON_FAIL(ShaderObjectImpl::create( this, static_cast(layout), shaderObject.writeRef())); - *outObject = shaderObject.detach(); + returnComPtr(outObject, shaderObject); return SLANG_OK; } @@ -5475,9 +5515,11 @@ Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in if (!programImpl->m_rootObjectLayout->m_pipelineLayout) { - RefPtr pipelineStateImpl = new PipelineStateImpl(m_api); + RefPtr pipelineStateImpl = new PipelineStateImpl(this); pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + pipelineStateImpl->establishStrongDeviceReference(); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -5628,12 +5670,12 @@ Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in VkPipeline pipeline = VK_NULL_HANDLE; SLANG_VK_CHECK(m_api.vkCreateGraphicsPipelines(m_device, pipelineCache, 1, &pipelineInfo, nullptr, &pipeline)); - RefPtr pipelineStateImpl = new PipelineStateImpl(m_api); + RefPtr pipelineStateImpl = new PipelineStateImpl(this); pipelineStateImpl->m_pipeline = pipeline; - pipelineStateImpl->m_framebufferLayout = - static_cast(desc.framebufferLayout); pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + pipelineStateImpl->establishStrongDeviceReference(); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -5643,9 +5685,11 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe auto programImpl = static_cast(desc.program); if (!programImpl->m_rootObjectLayout->m_pipelineLayout) { - RefPtr pipelineStateImpl = new PipelineStateImpl(m_api); + RefPtr pipelineStateImpl = new PipelineStateImpl(this); pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + pipelineStateImpl->establishStrongDeviceReference(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -5660,10 +5704,12 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe SLANG_VK_CHECK(m_api.vkCreateComputePipelines( m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &pipeline)); - RefPtr pipelineStateImpl = new PipelineStateImpl(m_api); + RefPtr pipelineStateImpl = new PipelineStateImpl(this); pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->init(desc); - *outState = pipelineStateImpl.detach(); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + pipelineStateImpl->establishStrongDeviceReference(); + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } diff --git a/tools/platform/window.h b/tools/platform/window.h index 884f28b0d..65e0508b6 100644 --- a/tools/platform/window.h +++ b/tools/platform/window.h @@ -244,7 +244,7 @@ public: # define GFX_DUMP_LEAK # endif # define PLATFORM_UI_MAIN(APPLICATION_ENTRY) \ - int __stdcall WinMain( \ + int __stdcall wWinMain( \ void* /*instance*/, \ void* /* prevInstance */, \ void* /* commandLine */, \ diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 6e45e5b24..43273acf6 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -117,7 +117,7 @@ protected: // variables for state to be used for rendering... uintptr_t m_constantBufferSize; - ComPtr m_device; + IDevice* m_device; ComPtr m_queue; ComPtr m_transientHeap; ComPtr m_renderPass; @@ -641,16 +641,6 @@ void RenderTestApp::runCompute(IComputeCommandEncoder* encoder) void RenderTestApp::finalize() { - m_inputLayout = nullptr; - m_vertexBuffer = nullptr; - m_shaderProgram = nullptr; - m_pipelineState = nullptr; - m_renderPass = nullptr; - m_framebuffer = nullptr; - m_framebufferLayout = nullptr; - m_colorBuffer = nullptr; - m_queue = nullptr; - m_device = nullptr; } Result RenderTestApp::writeBindingOutput(const char* fileName) @@ -840,7 +830,11 @@ static SlangResult _setSessionPrelude(const Options& options, const char* exePat SLANG_RETURN_ON_FAIL(TestToolUtil::getRootPath(exePath, rootPath)); String includePath; - SLANG_RETURN_ON_FAIL(TestToolUtil::getIncludePath(rootPath, "external/nvapi/nvHLSLExtns.h", includePath)); + if (TestToolUtil::getIncludePath(rootPath, "external/nvapi/nvHLSLExtns.h", includePath) != + SLANG_OK) + { + return SLANG_FAIL; + } StringBuilder buf; // We have to choose a slot that NVAPI will use. @@ -1129,8 +1123,8 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi app.update(); renderDocEndFrame(); app.finalize(); - return SLANG_OK; } + return SLANG_OK; } SLANG_TEST_TOOL_API SlangResult innerMain(Slang::StdWriters* stdWriters, SlangSession* sharedSession, int inArgc, const char*const* inArgv) -- cgit v1.2.3