diff options
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
| -rw-r--r-- | tools/gfx/cuda/render-cuda.cpp | 115 |
1 files changed, 63 insertions, 52 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index dbee0c5f2..f871e6246 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -181,6 +181,13 @@ void _optixLogCallback(unsigned int level, const char* tag, const char* message, # endif +class CUDAContext : public RefObject +{ +public: + CUcontext m_context = nullptr; + ~CUDAContext() { cuCtxDestroy(m_context); } +}; + class MemoryCUDAResource : public BufferResource { public: @@ -199,6 +206,8 @@ public: uint64_t getBindlessHandle() { return (uint64_t)m_cudaMemory; } void* m_cudaMemory = nullptr; + + RefPtr<CUDAContext> m_cudaContext; }; class TextureCUDAResource : public TextureResource @@ -238,12 +247,14 @@ public: CUarray m_cudaArray = CUarray(); CUmipmappedArray m_cudaMipMappedArray = CUmipmappedArray(); + + RefPtr<CUDAContext> m_cudaContext; }; -class CUDAResourceView : public IResourceView, public RefObject +class CUDAResourceView : public IResourceView, public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL IResourceView* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView) @@ -451,7 +462,7 @@ public: ComPtr<IBufferResource> constantBuffer; SLANG_RETURN_ON_FAIL(device->createBufferResource( IResource::Usage::ConstantBuffer, bufferDesc, nullptr, constantBuffer.writeRef())); - bufferResource = dynamic_cast<MemoryCUDAResource*>(constantBuffer.get()); + bufferResource = static_cast<MemoryCUDAResource*>(constantBuffer.get()); return SLANG_OK; } @@ -514,8 +525,7 @@ public: *object = nullptr; return SLANG_OK; } - objects[subObjectIndex]->addRef(); - *object = objects[subObjectIndex].Ptr(); + returnComPtr(object, objects[subObjectIndex]); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL @@ -648,7 +658,7 @@ public: auto& bindingRange = layout->m_bindingRanges[bindingRangeIndex]; auto viewIndex = bindingRange.baseIndex + offset.bindingArrayIndex; - auto cudaView = dynamic_cast<CUDAResourceView*>(resourceView); + auto cudaView = static_cast<CUDAResourceView*>(resourceView); resources[viewIndex] = cudaView; @@ -720,7 +730,6 @@ public: // type, we need to make sure to use that type as the specialization argument. // TODO: need to implement the case where the field is an array of existential values. - SLANG_ASSERT(bindingRange.count == 1); ExtendedShaderObjectType specializedSubObjType; SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType)); args.add(specializedSubObjType); @@ -801,6 +810,12 @@ public: class CUDARootShaderObject : public CUDAShaderObject { public: + // Override default reference counting behavior to disable lifetime management. + // Root objects are managed by command buffer and does not need to be freed by the user. + SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } + SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } + +public: List<RefPtr<CUDAEntryPointShaderObject>> entryPointObjects; virtual SLANG_NO_THROW Result SLANG_MCALL init(IDevice* device, CUDAShaderObjectLayout* typeLayout) override; @@ -808,8 +823,7 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) override { - *outEntryPoint = entryPointObjects[index].Ptr(); - entryPointObjects[index]->addRef(); + returnComPtr(outEntryPoint, entryPointObjects[index]); return SLANG_OK; } virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override @@ -830,6 +844,7 @@ public: CUfunction cudaKernel; String kernelName; RefPtr<CUDAProgramLayout> layout; + RefPtr<CUDAContext> cudaContext; ~CUDAShaderProgram() { if (cudaModule) @@ -969,7 +984,7 @@ private: private: int m_deviceIndex = -1; CUdevice m_device = 0; - CUcontext m_context = nullptr; + RefPtr<CUDAContext> m_context; DeviceInfo m_info; String m_adapterName; @@ -979,10 +994,10 @@ public: class CommandBufferImpl : public ICommandBuffer , public CommandWriter - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandBuffer* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandBuffer) @@ -1010,7 +1025,7 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IComputeCommandEncoder) { *outObject = static_cast<IComputeCommandEncoder*>(this); @@ -1025,7 +1040,7 @@ public: public: CommandWriter* m_writer; CommandBufferImpl* m_commandBuffer; - ComPtr<IShaderObject> m_rootObject; + RefPtr<ShaderObjectBase> m_rootObject; virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override {} void init(CommandBufferImpl* cmdBuffer) { @@ -1039,8 +1054,8 @@ public: m_writer->setPipelineState(state); PipelineStateBase* pipelineImpl = static_cast<PipelineStateBase*>(state); SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->createRootShaderObject( - pipelineImpl->m_program, outRootObject)); - m_rootObject = *outRootObject; + pipelineImpl->m_program, m_rootObject.writeRef())); + returnComPtr(outRootObject, m_rootObject); return SLANG_OK; } @@ -1066,7 +1081,7 @@ public: virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) override { - if (uuid == GfxGUID::IID_ISlangUnknown || + if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder || uuid == GfxGUID::IID_IResourceCommandEncoder) { *outObject = static_cast<IResourceCommandEncoder*>(this); @@ -1118,10 +1133,10 @@ public: class CommandQueueImpl : public ICommandQueue - , public RefObject + , public ComObject { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ALL ICommandQueue* getInterface(const Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue) @@ -1176,7 +1191,7 @@ public: currentPipeline = dynamic_cast<CUDAPipelineState*>(state); } - Result bindRootShaderObject(PipelineType pipelineType, IShaderObject* object) + Result bindRootShaderObject(IShaderObject* object) { currentRootObject = dynamic_cast<CUDARootShaderObject*>(object); if (currentRootObject) @@ -1294,8 +1309,7 @@ public: break; case CommandName::BindRootShaderObject: bindRootShaderObject( - (PipelineType)cmd.operands[0], - commandBuffer->getObject<IShaderObject>(cmd.operands[1])); + commandBuffer->getObject<IShaderObject>(cmd.operands[0])); break; case CommandName::DispatchCompute: dispatchCompute( @@ -1324,13 +1338,6 @@ public: using TransientResourceHeapImpl = SimpleTransientResourceHeap<CUDADevice, CommandBufferImpl>; public: - ~CUDADevice() - { - if (m_context) - { - cuCtxDestroy(m_context); - } - } virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc) override { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1")); @@ -1342,15 +1349,12 @@ public: SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex)); SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(cudaSetDevice(m_deviceIndex), reportType); - if (m_context) - { - cuCtxDestroy(m_context); - m_context = nullptr; - } + m_context = new CUDAContext(); SLANG_CUDA_RETURN_ON_FAIL(cuDeviceGet(&m_device, m_deviceIndex)); - SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(cuCtxCreate(&m_context, 0, m_device), reportType); + SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL( + cuCtxCreate(&m_context->m_context, 0, m_device), reportType); // Initialize DeviceInfo { @@ -1375,7 +1379,12 @@ public: const ITextureResource::SubresourceData* initData, ITextureResource** outResource) override { - RefPtr<TextureCUDAResource> tex = new TextureCUDAResource(desc); + TextureResource::Desc srcDesc(desc); + srcDesc.setDefaults(initialUsage); + + RefPtr<TextureCUDAResource> tex = new TextureCUDAResource(srcDesc); + tex->m_cudaContext = m_context; + CUresourcetype resourceType; size_t elementSize = 0; @@ -1777,7 +1786,7 @@ public: cuTexObjectCreate(&tex->m_cudaTexObj, &resDesc, &texDesc, nullptr)); } - *outResource = tex.detach(); + returnComPtr(outResource, tex); return SLANG_OK; } @@ -1788,12 +1797,13 @@ public: IBufferResource** outResource) override { RefPtr<MemoryCUDAResource> resource = new MemoryCUDAResource(desc); + resource->m_cudaContext = m_context; SLANG_CUDA_RETURN_ON_FAIL(cudaMallocManaged(&resource->m_cudaMemory, desc.sizeInBytes)); if (initData) { SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(resource->m_cudaMemory, initData, desc.sizeInBytes, cudaMemcpyHostToDevice)); } - *outResource = resource.detach(); + returnComPtr(outResource, resource); return SLANG_OK; } @@ -1803,7 +1813,7 @@ public: RefPtr<CUDAResourceView> view = new CUDAResourceView(); view->desc = desc; view->textureResource = dynamic_cast<TextureCUDAResource*>(texture); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -1813,7 +1823,7 @@ public: RefPtr<CUDAResourceView> view = new CUDAResourceView(); view->desc = desc; view->memoryResource = dynamic_cast<MemoryCUDAResource*>(buffer); - *outView = view.detach(); + returnComPtr(outView, view); return SLANG_OK; } @@ -1823,7 +1833,7 @@ public: { RefPtr<CUDAShaderObjectLayout> cudaLayout; cudaLayout = new CUDAShaderObjectLayout(this, typeLayout); - *outLayout = cudaLayout.detach(); + returnRefPtrMove(outLayout, cudaLayout); return SLANG_OK; } @@ -1833,18 +1843,18 @@ public: { RefPtr<CUDAShaderObject> result = new CUDAShaderObject(); SLANG_RETURN_ON_FAIL(result->init(this, dynamic_cast<CUDAShaderObjectLayout*>(layout))); - *outObject = result.detach(); + returnComPtr(outObject, result); return SLANG_OK; } - Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) + Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) { auto cudaProgram = dynamic_cast<CUDAShaderProgram*>(program); auto cudaLayout = cudaProgram->layout; RefPtr<CUDARootShaderObject> result = new CUDARootShaderObject(); SLANG_RETURN_ON_FAIL(result->init(this, cudaLayout)); - *outObject = result.detach(); + returnRefPtrMove(outObject, result); return SLANG_OK; } @@ -1856,10 +1866,11 @@ public: // the shader object bindings. RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram(); cudaProgram->slangProgram = desc.slangProgram; + cudaProgram->cudaContext = m_context; if (desc.slangProgram->getSpecializationParamCount() != 0) { cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout()); - *outProgram = cudaProgram.detach(); + returnComPtr(outProgram, cudaProgram); return SLANG_OK; } @@ -1893,7 +1904,7 @@ public: cudaProgram->layout = cudaLayout; } - *outProgram = cudaProgram.detach(); + returnComPtr(outProgram, cudaProgram); return SLANG_OK; } @@ -1901,9 +1912,9 @@ public: const ComputePipelineStateDesc& desc, IPipelineState** outState) override { RefPtr<CUDAPipelineState> state = new CUDAPipelineState(); - state->shaderProgram = dynamic_cast<CUDAShaderProgram*>(desc.program); + state->shaderProgram = static_cast<CUDAShaderProgram*>(desc.program); state->init(desc); - *outState = state.detach(); + returnComPtr(outState, state); return Result(); } @@ -1929,7 +1940,7 @@ public: { RefPtr<TransientResourceHeapImpl> result = new TransientResourceHeapImpl(); SLANG_RETURN_ON_FAIL(result->init(this, desc)); - *outHeap = result.detach(); + returnComPtr(outHeap, result); return SLANG_OK; } @@ -1938,7 +1949,7 @@ public: { RefPtr<CommandQueueImpl> queue = new CommandQueueImpl(); queue->init(this); - *outQueue = queue.detach(); + returnComPtr(outQueue, queue); return SLANG_OK; } virtual SLANG_NO_THROW Result SLANG_MCALL createSwapchain( @@ -2027,7 +2038,7 @@ public: (uint8_t*)bufferImpl->m_cudaMemory + offset, size, cudaMemcpyDefault); - *outBlob = blob.detach(); + returnComPtr(outBlob, blob); return SLANG_OK; } }; @@ -2115,7 +2126,7 @@ SlangResult SLANG_MCALL createCUDADevice(const IDevice::Desc* desc, IDevice** ou { RefPtr<CUDADevice> result = new CUDADevice(); SLANG_RETURN_ON_FAIL(result->initialize(*desc)); - *outDevice = result.detach(); + returnComPtr(outDevice, result); return SLANG_OK; } #else |
