summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang-gfx.h27
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp35
-rw-r--r--tools/gfx/renderer-shared.cpp7
-rw-r--r--tools/gfx/renderer-shared.h3
4 files changed, 69 insertions, 3 deletions
diff --git a/slang-gfx.h b/slang-gfx.h
index dcd78fc5b..08573e821 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -1947,6 +1947,8 @@ public:
int requiredFeatureCount = 0;
// Array of required feature names, whose size is `requiredFeatureCount`.
const char** requiredFeatures = nullptr;
+ // A command dispatcher object that intercepts and handles actual low-level API call.
+ ISlangUnknown* apiCommandDispatcher = nullptr;
// The slot (typically UAV) used to identify NVAPI intrinsics. If >=0 NVAPI is required.
int nvapiExtnSlot = -1;
// The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object,
@@ -2268,6 +2270,31 @@ public:
0x715bdf26, 0x5135, 0x11eb, { 0xAE, 0x93, 0x02, 0x42, 0xAC, 0x13, 0x00, 0x02 } \
}
+
+class IPipelineCreationAPIDispatcher : public ISlangUnknown
+{
+public:
+ virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState(
+ IDevice* device,
+ slang::IComponentType* program,
+ void* pipelineDesc,
+ void** outPipelineState) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL createGraphicsPipelineState(
+ IDevice* device,
+ slang::IComponentType* program,
+ void* pipelineDesc,
+ void** outPipelineState) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ beforeCreateRayTracingState(IDevice* device, slang::IComponentType* program) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ afterCreateRayTracingState(IDevice* device, slang::IComponentType* program) = 0;
+};
+#define SLANG_UUID_IPipelineCreationAPIDispatcher \
+ { \
+ 0xc3d5f782, 0xeae1, 0x4da6, { 0xab, 0x40, 0x75, 0x32, 0x31, 0x2, 0xb7, 0xdc } \
+ }
+
+
// Global public functions
extern "C"
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index b2227aa76..8555bb4ec 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -7222,7 +7222,15 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
ComPtr<ID3D12PipelineState> pipelineState;
- SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
+ this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
pipelineStateImpl->m_pipelineState = pipelineState;
@@ -7288,8 +7296,19 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
else
#endif
{
- SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
- &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState(
+ this,
+ programImpl->slangProgram.get(),
+ &computeDesc,
+ (void**)pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
+ &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
}
}
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
@@ -7843,12 +7862,22 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
pipelineConfigSubobject.pDesc = &pipelineConfig;
subObjects.add(pipelineConfigSubobject);
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram);
+ }
+
D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
rtpsoDesc.pSubobjects = subObjects.getBuffer();
SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+ if (m_pipelineCreationAPIDispatcher)
+ {
+ m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram);
+ }
+
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 52cf7ffac..7ba939530 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -33,6 +33,7 @@ const Slang::Guid GfxGUID::IID_IQueryPool = SLANG_UUID_IQueryPool;
const Slang::Guid GfxGUID::IID_IAccelerationStructure = SLANG_UUID_IAccelerationStructure;
const Slang::Guid GfxGUID::IID_IFence = SLANG_UUID_IFence;
const Slang::Guid GfxGUID::IID_IShaderTable = SLANG_UUID_IShaderTable;
+const Slang::Guid GfxGUID::IID_IPipelineCreationAPIDispatcher = SLANG_UUID_IPipelineCreationAPIDispatcher;
StageType translateStage(SlangStage slangStage)
@@ -296,6 +297,12 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid)
SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc)
{
+ if (desc.apiCommandDispatcher)
+ {
+ desc.apiCommandDispatcher->queryInterface(
+ GfxGUID::IID_IPipelineCreationAPIDispatcher,
+ (void**)m_pipelineCreationAPIDispatcher.writeRef());
+ }
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 6baade085..6bf86e28b 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -40,6 +40,7 @@ struct GfxGUID
static const Slang::Guid IID_IAccelerationStructure;
static const Slang::Guid IID_IFence;
static const Slang::Guid IID_IShaderTable;
+ static const Slang::Guid IID_IPipelineCreationAPIDispatcher;
};
// We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation.
@@ -1355,6 +1356,8 @@ protected:
virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc);
protected:
Slang::List<Slang::String> m_features;
+ Slang::ComPtr<IPipelineCreationAPIDispatcher> m_pipelineCreationAPIDispatcher;
+
public:
SlangContext slangContext;
ShaderCache shaderCache;