diff options
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 35 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 7 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 3 |
3 files changed, 42 insertions, 3 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index b2227aa76..8555bb4ec 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType); ComPtr<ID3D12PipelineState> pipelineState; - SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState( + this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef())); + } + else + { + SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_pipelineState = pipelineState; @@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i else #endif { - SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( - &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState( + this, + programImpl->slangProgram.get(), + &computeDesc, + (void**)pipelineState.writeRef())); + } + else + { + SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( + &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } } } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); @@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD pipelineConfigSubobject.pDesc = &pipelineConfig; subObjects.add(pipelineConfigSubobject); + if (m_pipelineCreationAPIDispatcher) + { + m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram); + } + D3D12_STATE_OBJECT_DESC rtpsoDesc = {}; rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE; rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount(); rtpsoDesc.pSubobjects = subObjects.getBuffer(); SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram); + } + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 52cf7ffac..7ba939530 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -33,6 +33,7 @@ const Slang::Guid GfxGUID::IID_IQueryPool = SLANG_UUID_IQueryPool; const Slang::Guid GfxGUID::IID_IAccelerationStructure = SLANG_UUID_IAccelerationStructure; const Slang::Guid GfxGUID::IID_IFence = SLANG_UUID_IFence; const Slang::Guid GfxGUID::IID_IShaderTable = SLANG_UUID_IShaderTable; +const Slang::Guid GfxGUID::IID_IPipelineCreationAPIDispatcher = SLANG_UUID_IPipelineCreationAPIDispatcher; StageType translateStage(SlangStage slangStage) @@ -296,6 +297,12 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid) SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc) { + if (desc.apiCommandDispatcher) + { + desc.apiCommandDispatcher->queryInterface( + GfxGUID::IID_IPipelineCreationAPIDispatcher, + (void**)m_pipelineCreationAPIDispatcher.writeRef()); + } return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 6baade085..6bf86e28b 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -40,6 +40,7 @@ struct GfxGUID static const Slang::Guid IID_IAccelerationStructure; static const Slang::Guid IID_IFence; static const Slang::Guid IID_IShaderTable; + static const Slang::Guid IID_IPipelineCreationAPIDispatcher; }; // We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation. @@ -1355,6 +1356,8 @@ protected: virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc); protected: Slang::List<Slang::String> m_features; + Slang::ComPtr<IPipelineCreationAPIDispatcher> m_pipelineCreationAPIDispatcher; + public: SlangContext slangContext; ShaderCache shaderCache; |
