summaryrefslogtreecommitdiff
path: root/tools/gfx/d3d12/render-d3d12.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-02-05 14:36:07 -0800
committerGitHub <noreply@github.com>2021-02-05 14:36:07 -0800
commitdf7548ef62c02b9ab1cc5addecaa6b6c150f2750 (patch)
tree17081a8d5de3fd3292043aae6761d0c8960e6783 /tools/gfx/d3d12/render-d3d12.cpp
parent5fbaccfc1d4ac7d17d528de894d1f276e41d9ce1 (diff)
Shader-Object example (#1694)
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp82
1 files changed, 51 insertions, 31 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index de7cbd2e2..0ab07c262 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -2960,6 +2960,15 @@ void D3D12Renderer::setDescriptorSet(PipelineType pipelineType, IPipelineLayout*
Result D3D12Renderer::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram)
{
+ if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0)
+ {
+ // For a specializable program, we don't invoke any actual slang compilation yet.
+ RefPtr<ShaderProgramImpl> shaderProgram = new ShaderProgramImpl();
+ initProgramCommon(shaderProgram, desc);
+ *outProgram = shaderProgram.detach();
+ return SLANG_OK;
+ }
+
if( desc.kernelCount == 0 )
{
return createProgramFromSlang(this, desc, outProgram);
@@ -3740,43 +3749,54 @@ Result D3D12Renderer::createComputePipelineState(const ComputePipelineStateDesc&
auto pipelineLayoutImpl = (PipelineLayoutImpl*) desc.pipelineLayout;
auto programImpl = (ShaderProgramImpl*) desc.program;
- // Describe and create the compute pipeline state object
- D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {};
- computeDesc.pRootSignature = pipelineLayoutImpl->m_rootSignature;
- computeDesc.CS = { programImpl->m_computeShader.getBuffer(), SIZE_T(programImpl->m_computeShader.getCount()) };
-
+ // Only actually create a D3D12 pipeline state if the pipeline is fully specialized.
ComPtr<ID3D12PipelineState> pipelineState;
-
-#ifdef GFX_NVAPI
- if (m_nvapi)
+ if (!programImpl->slangProgram || programImpl->slangProgram->getSpecializationParamCount() == 0)
{
- // Also fill the extension structure.
- // Use the same UAV slot index and register space that are declared in the shader.
-
- // For simplicities sake we just use u0
- NVAPI_D3D12_PSO_SET_SHADER_EXTENSION_SLOT_DESC extensionDesc;
- extensionDesc.baseVersion = NV_PSO_EXTENSION_DESC_VER;
- extensionDesc.version = NV_SET_SHADER_EXTENSION_SLOT_DESC_VER;
- extensionDesc.uavSlot = 0;
- extensionDesc.registerSpace = 0;
-
- // Put the pointer to the extension into an array - there can be multiple extensions enabled at once.
- const NVAPI_D3D12_PSO_EXTENSION_DESC* extensions[] = { &extensionDesc };
+ // Describe and create the compute pipeline state object
+ D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {};
+ computeDesc.pRootSignature = pipelineLayoutImpl->m_rootSignature;
+ computeDesc.CS = {
+ programImpl->m_computeShader.getBuffer(),
+ SIZE_T(programImpl->m_computeShader.getCount())};
- // Now create the PSO.
- const NvAPI_Status nvapiStatus = NvAPI_D3D12_CreateComputePipelineState(m_device, &computeDesc, SLANG_COUNT_OF(extensions), extensions, pipelineState.writeRef());
-
- if (nvapiStatus != NVAPI_OK)
- {
- return SLANG_FAIL;
+#ifdef GFX_NVAPI
+ if (m_nvapi)
+ {
+ // Also fill the extension structure.
+ // Use the same UAV slot index and register space that are declared in the shader.
+
+ // For simplicities sake we just use u0
+ NVAPI_D3D12_PSO_SET_SHADER_EXTENSION_SLOT_DESC extensionDesc;
+ extensionDesc.baseVersion = NV_PSO_EXTENSION_DESC_VER;
+ extensionDesc.version = NV_SET_SHADER_EXTENSION_SLOT_DESC_VER;
+ extensionDesc.uavSlot = 0;
+ extensionDesc.registerSpace = 0;
+
+ // Put the pointer to the extension into an array - there can be multiple extensions
+ // enabled at once.
+ const NVAPI_D3D12_PSO_EXTENSION_DESC* extensions[] = {&extensionDesc};
+
+ // Now create the PSO.
+ const NvAPI_Status nvapiStatus = NvAPI_D3D12_CreateComputePipelineState(
+ m_device,
+ &computeDesc,
+ SLANG_COUNT_OF(extensions),
+ extensions,
+ pipelineState.writeRef());
+
+ if (nvapiStatus != NVAPI_OK)
+ {
+ return SLANG_FAIL;
+ }
}
- }
- else
+ else
#endif
- {
- SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(&computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ {
+ SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
+ &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
+ }
}
-
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
pipelineStateImpl->m_pipelineState = pipelineState;