diff options
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 35 |
1 files changed, 32 insertions, 3 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index b2227aa76..8555bb4ec 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType); ComPtr<ID3D12PipelineState> pipelineState; - SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState( + this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef())); + } + else + { + SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_pipelineState = pipelineState; @@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i else #endif { - SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( - &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState( + this, + programImpl->slangProgram.get(), + &computeDesc, + (void**)pipelineState.writeRef())); + } + else + { + SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( + &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } } } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); @@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD pipelineConfigSubobject.pDesc = &pipelineConfig; subObjects.add(pipelineConfigSubobject); + if (m_pipelineCreationAPIDispatcher) + { + m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram); + } + D3D12_STATE_OBJECT_DESC rtpsoDesc = {}; rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE; rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount(); rtpsoDesc.pSubobjects = subObjects.getBuffer(); SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram); + } + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } |
