summaryrefslogtreecommitdiffstats
path: root/tools/gfx/cuda/render-cuda.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-04-08 21:10:30 -0700
committerGitHub <noreply@github.com>2021-04-08 21:10:30 -0700
commit8a71039475212fb1e1a6dd2fd2911d02769637ef (patch)
tree0faa6e773d6b40c3dcbf0eed08217c629f8ebccf /tools/gfx/cuda/render-cuda.cpp
parentd27557d9b770810402a0bf99bcd891c145a1a69d (diff)
Improve robustness of gfx lifetime management. (#1788)
* Improve robustness of gfx lifetime management. * fix clang error * fix clang error * Fix clang warning
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
-rw-r--r--tools/gfx/cuda/render-cuda.cpp115
1 files changed, 63 insertions, 52 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index dbee0c5f2..f871e6246 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -181,6 +181,13 @@ void _optixLogCallback(unsigned int level, const char* tag, const char* message,
# endif
+// Owning wrapper around a CUDA driver context handle (CUcontext).
+// Derives from RefObject so that other objects can keep the context alive by
+// holding a RefPtr<CUDAContext>; the handle is released in the destructor,
+// i.e. when this object itself is destroyed.
+class CUDAContext : public RefObject
+{
+public:
+    // Driver context handle. Stays null until a caller stores a context
+    // created with cuCtxCreate into it.
+    CUcontext m_context = nullptr;
+
+    // Destroys the wrapped context, if any.
+    // The wrapper is allocated before cuCtxCreate is called, so it can be
+    // destroyed while m_context is still null (e.g. when context creation
+    // fails). Skip the driver call in that case instead of handing it a
+    // null handle.
+    ~CUDAContext()
+    {
+        if (m_context)
+        {
+            cuCtxDestroy(m_context);
+        }
+    }
+};
+
class MemoryCUDAResource : public BufferResource
{
public:
@@ -199,6 +206,8 @@ public:
uint64_t getBindlessHandle() { return (uint64_t)m_cudaMemory; }
void* m_cudaMemory = nullptr;
+
+ RefPtr<CUDAContext> m_cudaContext;
};
class TextureCUDAResource : public TextureResource
@@ -238,12 +247,14 @@ public:
CUarray m_cudaArray = CUarray();
CUmipmappedArray m_cudaMipMappedArray = CUmipmappedArray();
+
+ RefPtr<CUDAContext> m_cudaContext;
};
-class CUDAResourceView : public IResourceView, public RefObject
+class CUDAResourceView : public IResourceView, public ComObject
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
IResourceView* getInterface(const Guid& guid)
{
if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResourceView)
@@ -451,7 +462,7 @@ public:
ComPtr<IBufferResource> constantBuffer;
SLANG_RETURN_ON_FAIL(device->createBufferResource(
IResource::Usage::ConstantBuffer, bufferDesc, nullptr, constantBuffer.writeRef()));
- bufferResource = dynamic_cast<MemoryCUDAResource*>(constantBuffer.get());
+ bufferResource = static_cast<MemoryCUDAResource*>(constantBuffer.get());
return SLANG_OK;
}
@@ -514,8 +525,7 @@ public:
*object = nullptr;
return SLANG_OK;
}
- objects[subObjectIndex]->addRef();
- *object = objects[subObjectIndex].Ptr();
+ returnComPtr(object, objects[subObjectIndex]);
return SLANG_OK;
}
virtual SLANG_NO_THROW Result SLANG_MCALL
@@ -648,7 +658,7 @@ public:
auto& bindingRange = layout->m_bindingRanges[bindingRangeIndex];
auto viewIndex = bindingRange.baseIndex + offset.bindingArrayIndex;
- auto cudaView = dynamic_cast<CUDAResourceView*>(resourceView);
+ auto cudaView = static_cast<CUDAResourceView*>(resourceView);
resources[viewIndex] = cudaView;
@@ -720,7 +730,6 @@ public:
// type, we need to make sure to use that type as the specialization argument.
// TODO: need to implement the case where the field is an array of existential values.
- SLANG_ASSERT(bindingRange.count == 1);
ExtendedShaderObjectType specializedSubObjType;
SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType));
args.add(specializedSubObjType);
@@ -801,6 +810,12 @@ public:
class CUDARootShaderObject : public CUDAShaderObject
{
public:
+ // Override default reference counting behavior to disable lifetime management.
+ // Root objects are managed by the command buffer and do not need to be freed by the user.
+ SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; }
+ SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; }
+
+public:
List<RefPtr<CUDAEntryPointShaderObject>> entryPointObjects;
virtual SLANG_NO_THROW Result SLANG_MCALL
init(IDevice* device, CUDAShaderObjectLayout* typeLayout) override;
@@ -808,8 +823,7 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL
getEntryPoint(UInt index, IShaderObject** outEntryPoint) override
{
- *outEntryPoint = entryPointObjects[index].Ptr();
- entryPointObjects[index]->addRef();
+ returnComPtr(outEntryPoint, entryPointObjects[index]);
return SLANG_OK;
}
virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
@@ -830,6 +844,7 @@ public:
CUfunction cudaKernel;
String kernelName;
RefPtr<CUDAProgramLayout> layout;
+ RefPtr<CUDAContext> cudaContext;
~CUDAShaderProgram()
{
if (cudaModule)
@@ -969,7 +984,7 @@ private:
private:
int m_deviceIndex = -1;
CUdevice m_device = 0;
- CUcontext m_context = nullptr;
+ RefPtr<CUDAContext> m_context;
DeviceInfo m_info;
String m_adapterName;
@@ -979,10 +994,10 @@ public:
class CommandBufferImpl
: public ICommandBuffer
, public CommandWriter
- , public RefObject
+ , public ComObject
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
ICommandBuffer* getInterface(const Guid& guid)
{
if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandBuffer)
@@ -1010,7 +1025,7 @@ public:
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
queryInterface(SlangUUID const& uuid, void** outObject) override
{
- if (uuid == GfxGUID::IID_ISlangUnknown ||
+ if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder ||
uuid == GfxGUID::IID_IComputeCommandEncoder)
{
*outObject = static_cast<IComputeCommandEncoder*>(this);
@@ -1025,7 +1040,7 @@ public:
public:
CommandWriter* m_writer;
CommandBufferImpl* m_commandBuffer;
- ComPtr<IShaderObject> m_rootObject;
+ RefPtr<ShaderObjectBase> m_rootObject;
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override {}
void init(CommandBufferImpl* cmdBuffer)
{
@@ -1039,8 +1054,8 @@ public:
m_writer->setPipelineState(state);
PipelineStateBase* pipelineImpl = static_cast<PipelineStateBase*>(state);
SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->createRootShaderObject(
- pipelineImpl->m_program, outRootObject));
- m_rootObject = *outRootObject;
+ pipelineImpl->m_program, m_rootObject.writeRef()));
+ returnComPtr(outRootObject, m_rootObject);
return SLANG_OK;
}
@@ -1066,7 +1081,7 @@ public:
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
queryInterface(SlangUUID const& uuid, void** outObject) override
{
- if (uuid == GfxGUID::IID_ISlangUnknown ||
+ if (uuid == GfxGUID::IID_ISlangUnknown || uuid == GfxGUID::IID_ICommandEncoder ||
uuid == GfxGUID::IID_IResourceCommandEncoder)
{
*outObject = static_cast<IResourceCommandEncoder*>(this);
@@ -1118,10 +1133,10 @@ public:
class CommandQueueImpl
: public ICommandQueue
- , public RefObject
+ , public ComObject
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
ICommandQueue* getInterface(const Guid& guid)
{
if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ICommandQueue)
@@ -1176,7 +1191,7 @@ public:
currentPipeline = dynamic_cast<CUDAPipelineState*>(state);
}
- Result bindRootShaderObject(PipelineType pipelineType, IShaderObject* object)
+ Result bindRootShaderObject(IShaderObject* object)
{
currentRootObject = dynamic_cast<CUDARootShaderObject*>(object);
if (currentRootObject)
@@ -1294,8 +1309,7 @@ public:
break;
case CommandName::BindRootShaderObject:
bindRootShaderObject(
- (PipelineType)cmd.operands[0],
- commandBuffer->getObject<IShaderObject>(cmd.operands[1]));
+ commandBuffer->getObject<IShaderObject>(cmd.operands[0]));
break;
case CommandName::DispatchCompute:
dispatchCompute(
@@ -1324,13 +1338,6 @@ public:
using TransientResourceHeapImpl = SimpleTransientResourceHeap<CUDADevice, CommandBufferImpl>;
public:
- ~CUDADevice()
- {
- if (m_context)
- {
- cuCtxDestroy(m_context);
- }
- }
virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc) override
{
SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1"));
@@ -1342,15 +1349,12 @@ public:
SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex));
SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(cudaSetDevice(m_deviceIndex), reportType);
- if (m_context)
- {
- cuCtxDestroy(m_context);
- m_context = nullptr;
- }
+ m_context = new CUDAContext();
SLANG_CUDA_RETURN_ON_FAIL(cuDeviceGet(&m_device, m_deviceIndex));
- SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(cuCtxCreate(&m_context, 0, m_device), reportType);
+ SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(
+ cuCtxCreate(&m_context->m_context, 0, m_device), reportType);
// Initialize DeviceInfo
{
@@ -1375,7 +1379,12 @@ public:
const ITextureResource::SubresourceData* initData,
ITextureResource** outResource) override
{
- RefPtr<TextureCUDAResource> tex = new TextureCUDAResource(desc);
+ TextureResource::Desc srcDesc(desc);
+ srcDesc.setDefaults(initialUsage);
+
+ RefPtr<TextureCUDAResource> tex = new TextureCUDAResource(srcDesc);
+ tex->m_cudaContext = m_context;
+
CUresourcetype resourceType;
size_t elementSize = 0;
@@ -1777,7 +1786,7 @@ public:
cuTexObjectCreate(&tex->m_cudaTexObj, &resDesc, &texDesc, nullptr));
}
- *outResource = tex.detach();
+ returnComPtr(outResource, tex);
return SLANG_OK;
}
@@ -1788,12 +1797,13 @@ public:
IBufferResource** outResource) override
{
RefPtr<MemoryCUDAResource> resource = new MemoryCUDAResource(desc);
+ resource->m_cudaContext = m_context;
SLANG_CUDA_RETURN_ON_FAIL(cudaMallocManaged(&resource->m_cudaMemory, desc.sizeInBytes));
if (initData)
{
SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(resource->m_cudaMemory, initData, desc.sizeInBytes, cudaMemcpyHostToDevice));
}
- *outResource = resource.detach();
+ returnComPtr(outResource, resource);
return SLANG_OK;
}
@@ -1803,7 +1813,7 @@ public:
RefPtr<CUDAResourceView> view = new CUDAResourceView();
view->desc = desc;
view->textureResource = dynamic_cast<TextureCUDAResource*>(texture);
- *outView = view.detach();
+ returnComPtr(outView, view);
return SLANG_OK;
}
@@ -1813,7 +1823,7 @@ public:
RefPtr<CUDAResourceView> view = new CUDAResourceView();
view->desc = desc;
view->memoryResource = dynamic_cast<MemoryCUDAResource*>(buffer);
- *outView = view.detach();
+ returnComPtr(outView, view);
return SLANG_OK;
}
@@ -1823,7 +1833,7 @@ public:
{
RefPtr<CUDAShaderObjectLayout> cudaLayout;
cudaLayout = new CUDAShaderObjectLayout(this, typeLayout);
- *outLayout = cudaLayout.detach();
+ returnRefPtrMove(outLayout, cudaLayout);
return SLANG_OK;
}
@@ -1833,18 +1843,18 @@ public:
{
RefPtr<CUDAShaderObject> result = new CUDAShaderObject();
SLANG_RETURN_ON_FAIL(result->init(this, dynamic_cast<CUDAShaderObjectLayout*>(layout)));
- *outObject = result.detach();
+ returnComPtr(outObject, result);
return SLANG_OK;
}
- Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject)
+ Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject)
{
auto cudaProgram = dynamic_cast<CUDAShaderProgram*>(program);
auto cudaLayout = cudaProgram->layout;
RefPtr<CUDARootShaderObject> result = new CUDARootShaderObject();
SLANG_RETURN_ON_FAIL(result->init(this, cudaLayout));
- *outObject = result.detach();
+ returnRefPtrMove(outObject, result);
return SLANG_OK;
}
@@ -1856,10 +1866,11 @@ public:
// the shader object bindings.
RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram();
cudaProgram->slangProgram = desc.slangProgram;
+ cudaProgram->cudaContext = m_context;
if (desc.slangProgram->getSpecializationParamCount() != 0)
{
cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout());
- *outProgram = cudaProgram.detach();
+ returnComPtr(outProgram, cudaProgram);
return SLANG_OK;
}
@@ -1893,7 +1904,7 @@ public:
cudaProgram->layout = cudaLayout;
}
- *outProgram = cudaProgram.detach();
+ returnComPtr(outProgram, cudaProgram);
return SLANG_OK;
}
@@ -1901,9 +1912,9 @@ public:
const ComputePipelineStateDesc& desc, IPipelineState** outState) override
{
RefPtr<CUDAPipelineState> state = new CUDAPipelineState();
- state->shaderProgram = dynamic_cast<CUDAShaderProgram*>(desc.program);
+ state->shaderProgram = static_cast<CUDAShaderProgram*>(desc.program);
state->init(desc);
- *outState = state.detach();
+ returnComPtr(outState, state);
return Result();
}
@@ -1929,7 +1940,7 @@ public:
{
RefPtr<TransientResourceHeapImpl> result = new TransientResourceHeapImpl();
SLANG_RETURN_ON_FAIL(result->init(this, desc));
- *outHeap = result.detach();
+ returnComPtr(outHeap, result);
return SLANG_OK;
}
@@ -1938,7 +1949,7 @@ public:
{
RefPtr<CommandQueueImpl> queue = new CommandQueueImpl();
queue->init(this);
- *outQueue = queue.detach();
+ returnComPtr(outQueue, queue);
return SLANG_OK;
}
virtual SLANG_NO_THROW Result SLANG_MCALL createSwapchain(
@@ -2027,7 +2038,7 @@ public:
(uint8_t*)bufferImpl->m_cudaMemory + offset,
size,
cudaMemcpyDefault);
- *outBlob = blob.detach();
+ returnComPtr(outBlob, blob);
return SLANG_OK;
}
};
@@ -2115,7 +2126,7 @@ SlangResult SLANG_MCALL createCUDADevice(const IDevice::Desc* desc, IDevice** ou
{
RefPtr<CUDADevice> result = new CUDADevice();
SLANG_RETURN_ON_FAIL(result->initialize(*desc));
- *outDevice = result.detach();
+ returnComPtr(outDevice, result);
return SLANG_OK;
}
#else