summaryrefslogtreecommitdiff
path: root/tools/gfx/d3d12/render-d3d12.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-01 20:00:55 -0800
committerGitHub <noreply@github.com>2022-02-01 20:00:55 -0800
commit5deb82929d289d6341e1919ee95b18b10f6db789 (patch)
tree196daa08b9e2805e6c686759a10f186ef57b3763 /tools/gfx/d3d12/render-d3d12.cpp
parente59516fa8c3a16eb7b99a928c5b85b97bf44fd72 (diff)
GFX: Add API interception mechanism for pipeline creation. (#2115)
This allows the user application to intercept API calls to create pipeline states. This feature can be used to integrate NVAPI in the user application. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp35
1 files changed, 32 insertions, 3 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index b2227aa76..8555bb4ec 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
ComPtr<ID3D12PipelineState> pipelineState;
- SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
+ this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
pipelineStateImpl->m_pipelineState = pipelineState;
@@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
else
#endif
{
- SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
- &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState(
+ this,
+ programImpl->slangProgram.get(),
+ &computeDesc,
+ (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
+ &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
}
}
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
@@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
pipelineConfigSubobject.pDesc = &pipelineConfig;
subObjects.add(pipelineConfigSubobject);
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram);
+ }
+
D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
rtpsoDesc.pSubobjects = subObjects.getBuffer();
SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram);
+ }
+
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}