summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-01 20:00:55 -0800
committerGitHub <noreply@github.com>2022-02-01 20:00:55 -0800
commit5deb82929d289d6341e1919ee95b18b10f6db789 (patch)
tree196daa08b9e2805e6c686759a10f186ef57b3763 /tools
parente59516fa8c3a16eb7b99a928c5b85b97bf44fd72 (diff)
GFX: Add API interception mechanism for pipeline creation. (#2115)
This allows the user application to intercept API calls to create pipeline states. This feature can be used to integrate NVAPI in the user application. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp35
-rw-r--r--tools/gfx/renderer-shared.cpp7
-rw-r--r--tools/gfx/renderer-shared.h3
3 files changed, 42 insertions, 3 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index b2227aa76..8555bb4ec 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
ComPtr<ID3D12PipelineState> pipelineState;
- SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
+ this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
pipelineStateImpl->m_pipelineState = pipelineState;
@@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
else
#endif
{
- SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
- &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState(
+ this,
+ programImpl->slangProgram.get(),
+ &computeDesc,
+ (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
+ &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
}
}
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
@@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
pipelineConfigSubobject.pDesc = &pipelineConfig;
subObjects.add(pipelineConfigSubobject);
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram);
+ }
+
D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
rtpsoDesc.pSubobjects = subObjects.getBuffer();
SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram);
+ }
+
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 52cf7ffac..7ba939530 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -33,6 +33,7 @@ const Slang::Guid GfxGUID::IID_IQueryPool = SLANG_UUID_IQueryPool;
const Slang::Guid GfxGUID::IID_IAccelerationStructure = SLANG_UUID_IAccelerationStructure;
const Slang::Guid GfxGUID::IID_IFence = SLANG_UUID_IFence;
const Slang::Guid GfxGUID::IID_IShaderTable = SLANG_UUID_IShaderTable;
+const Slang::Guid GfxGUID::IID_IPipelineCreationAPIDispatcher = SLANG_UUID_IPipelineCreationAPIDispatcher;
StageType translateStage(SlangStage slangStage)
@@ -296,6 +297,12 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid)
SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc)
{
+ if (desc.apiCommandDispatcher)
+ {
+ desc.apiCommandDispatcher->queryInterface(
+ GfxGUID::IID_IPipelineCreationAPIDispatcher,
+ (void**)m_pipelineCreationAPIDispatcher.writeRef());
+ }
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 6baade085..6bf86e28b 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -40,6 +40,7 @@ struct GfxGUID
static const Slang::Guid IID_IAccelerationStructure;
static const Slang::Guid IID_IFence;
static const Slang::Guid IID_IShaderTable;
+ static const Slang::Guid IID_IPipelineCreationAPIDispatcher;
};
// We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation.
@@ -1355,6 +1356,8 @@ protected:
virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc);
protected:
Slang::List<Slang::String> m_features;
+ Slang::ComPtr<IPipelineCreationAPIDispatcher> m_pipelineCreationAPIDispatcher;
+
public:
SlangContext slangContext;
ShaderCache shaderCache;