summaryrefslogtreecommitdiff
path: root/tools/gfx/d3d12/render-d3d12.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp35
1 files changed, 32 insertions, 3 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index b2227aa76..8555bb4ec 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
ComPtr<ID3D12PipelineState> pipelineState;
- SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
+ this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
pipelineStateImpl->m_pipelineState = pipelineState;
@@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
else
#endif
{
- SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
- &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState(
+ this,
+ programImpl->slangProgram.get(),
+ &computeDesc,
+ (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
+ &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
}
}
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
@@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
pipelineConfigSubobject.pDesc = &pipelineConfig;
subObjects.add(pipelineConfigSubobject);
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram);
+ }
+
D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
rtpsoDesc.pSubobjects = subObjects.getBuffer();
SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram);
+ }
+
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}