diff options
| -rw-r--r-- | slang-gfx.h | 27 | ||||
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 35 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 7 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 3 |
4 files changed, 69 insertions, 3 deletions
diff --git a/slang-gfx.h b/slang-gfx.h index dcd78fc5b..08573e821 100644 --- a/slang-gfx.h +++ b/slang-gfx.h @@ -1947,6 +1947,8 @@ public: int requiredFeatureCount = 0; // Array of required feature names, whose size is `requiredFeatureCount`. const char** requiredFeatures = nullptr; + // A command dispatcher object that intercepts and handles actual low-level API call. + ISlangUnknown* apiCommandDispatcher = nullptr; // The slot (typically UAV) used to identify NVAPI intrinsics. If >=0 NVAPI is required. int nvapiExtnSlot = -1; // The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object, @@ -2268,6 +2270,31 @@ public: 0x715bdf26, 0x5135, 0x11eb, { 0xAE, 0x93, 0x02, 0x42, 0xAC, 0x13, 0x00, 0x02 } \ } + +class IPipelineCreationAPIDispatcher : public ISlangUnknown +{ +public: + virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState( + IDevice* device, + slang::IComponentType* program, + void* pipelineDesc, + void** outPipelineState) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL createGraphicsPipelineState( + IDevice* device, + slang::IComponentType* program, + void* pipelineDesc, + void** outPipelineState) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL + beforeCreateRayTracingState(IDevice* device, slang::IComponentType* program) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL + afterCreateRayTracingState(IDevice* device, slang::IComponentType* program) = 0; +}; +#define SLANG_UUID_IPipelineCreationAPIDispatcher \ + { \ + 0xc3d5f782, 0xeae1, 0x4da6, { 0xab, 0x40, 0x75, 0x32, 0x31, 0x2, 0xb7, 0xdc } \ + } + + // Global public functions extern "C" diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index b2227aa76..8555bb4ec 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType); ComPtr<ID3D12PipelineState> pipelineState; - SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState( + this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef())); + } + else + { + SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_pipelineState = pipelineState; @@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i else #endif { - SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( - &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState( + this, + programImpl->slangProgram.get(), + &computeDesc, + (void**)pipelineState.writeRef())); + } + else + { + SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( + &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } } } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); @@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD pipelineConfigSubobject.pDesc = &pipelineConfig; subObjects.add(pipelineConfigSubobject); + if (m_pipelineCreationAPIDispatcher) + { + m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram); + } + D3D12_STATE_OBJECT_DESC rtpsoDesc = {}; rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE; rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount(); rtpsoDesc.pSubobjects = subObjects.getBuffer(); SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); + if (m_pipelineCreationAPIDispatcher) + { + m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram); + } + returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 52cf7ffac..7ba939530 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -33,6 +33,7 @@ const Slang::Guid GfxGUID::IID_IQueryPool = SLANG_UUID_IQueryPool; const Slang::Guid GfxGUID::IID_IAccelerationStructure = SLANG_UUID_IAccelerationStructure; const Slang::Guid GfxGUID::IID_IFence = SLANG_UUID_IFence; const Slang::Guid GfxGUID::IID_IShaderTable = SLANG_UUID_IShaderTable; +const Slang::Guid GfxGUID::IID_IPipelineCreationAPIDispatcher = SLANG_UUID_IPipelineCreationAPIDispatcher; StageType translateStage(SlangStage slangStage) @@ -296,6 +297,12 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid) SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc) { + if (desc.apiCommandDispatcher) + { + desc.apiCommandDispatcher->queryInterface( + GfxGUID::IID_IPipelineCreationAPIDispatcher, + (void**)m_pipelineCreationAPIDispatcher.writeRef()); + } return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 6baade085..6bf86e28b 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -40,6 +40,7 @@ struct GfxGUID static const Slang::Guid IID_IAccelerationStructure; static const Slang::Guid IID_IFence; static const Slang::Guid IID_IShaderTable; + static const Slang::Guid IID_IPipelineCreationAPIDispatcher; }; // We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation. @@ -1355,6 +1356,8 @@ protected: virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc); protected: Slang::List<Slang::String> m_features; + Slang::ComPtr<IPipelineCreationAPIDispatcher> m_pipelineCreationAPIDispatcher; + public: SlangContext slangContext; ShaderCache shaderCache; |
