From f9f8d3ec5c749bcbdab5a8fc2d2f919350f2423c Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 20 Jul 2021 10:22:20 -0700 Subject: Minor refactor to gfx D3D12 implementation. (#1913) * Minor refactor to gfx D3D12 implementation. - Allow more flexible collection of shader stages in a shader program. - Add `createRayTracingPipelineState` public interface. (no implementation). * Fix Vulkan initialization. Co-authored-by: Yong He --- tools/gfx/d3d12/render-d3d12.cpp | 70 ++++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 25 deletions(-) (limited to 'tools/gfx/d3d12/render-d3d12.cpp') diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 0e7e7d3ab..24a1fd93e 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -123,6 +123,8 @@ public: const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState( const ComputePipelineStateDesc& desc, IPipelineState** outState) override; + virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( + const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outState) override; @@ -1938,13 +1940,17 @@ public: // List m_gpuDescriptorSetInfos; }; + struct ShaderBinary + { + SlangStage stage; + List code; + }; + class ShaderProgramImpl : public ShaderProgramBase { public: PipelineType m_pipelineType; - List m_vertexShader; - List m_pixelShader; - List m_computeShader; + List m_shaders; RefPtr m_rootObjectLayout; }; @@ -5226,25 +5232,12 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr (char*)diagnostics->getBufferPointer()); } SLANG_RETURN_ON_FAIL(compileResult); - List* shaderCodeDestBuffer = nullptr; - switch (stage) - { - case SLANG_STAGE_COMPUTE: - shaderCodeDestBuffer = &shaderProgram->m_computeShader; - break; - case SLANG_STAGE_VERTEX: - shaderCodeDestBuffer = &shaderProgram->m_vertexShader; - break; - case SLANG_STAGE_FRAGMENT: - shaderCodeDestBuffer = &shaderProgram->m_pixelShader; - break; - default: - SLANG_ASSERT(!"unsupported shader stage."); - return SLANG_FAIL; - } - shaderCodeDestBuffer->addRange( + ShaderBinary shaderBin; + shaderBin.stage = stage; + shaderBin.code.addRange( reinterpret_cast(kernelCode->getBufferPointer()), (Index)kernelCode->getBufferSize()); + shaderProgram->m_shaders.add(_Move(shaderBin)); } returnComPtr(outProgram, shaderProgram); return SLANG_OK; @@ -5294,9 +5287,31 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {}; psoDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature; - - psoDesc.VS = { programImpl->m_vertexShader.getBuffer(), SIZE_T(programImpl->m_vertexShader.getCount()) }; - psoDesc.PS = { programImpl->m_pixelShader .getBuffer(), SIZE_T(programImpl->m_pixelShader .getCount()) }; + for (auto& shaderBin : programImpl->m_shaders) + { + switch (shaderBin.stage) + { + case SLANG_STAGE_VERTEX: + psoDesc.VS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) }; + break; + case SLANG_STAGE_FRAGMENT: + psoDesc.PS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) }; + break; + case SLANG_STAGE_DOMAIN: + psoDesc.DS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) }; + break; + case SLANG_STAGE_HULL: + psoDesc.HS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) }; + break; + case SLANG_STAGE_GEOMETRY: + psoDesc.GS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) }; + break; + default: + getDebugCallback()->handleMessage( + DebugMessageType::Error, DebugMessageSource::Layer, "Unsupported shader stage."); + return SLANG_E_NOT_AVAILABLE; + } + } psoDesc.InputLayout = { inputLayoutImpl->m_elements.getBuffer(), UINT(inputLayoutImpl->m_elements.getCount()) }; psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType); @@ -5407,8 +5422,8 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {}; computeDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature; computeDesc.CS = { - programImpl->m_computeShader.getBuffer(), - SIZE_T(programImpl->m_computeShader.getCount())}; + programImpl->m_shaders[0].code.getBuffer(), + SIZE_T(programImpl->m_shaders[0].code.getCount())}; #ifdef GFX_NVAPI if (m_nvapi) @@ -5454,6 +5469,11 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i return SLANG_OK; } +Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState) +{ + return SLANG_E_NOT_AVAILABLE; +} + Result D3D12Device::QueryPoolImpl::init(const IQueryPool::Desc& desc, D3D12Device* device) { // Translate query type. -- cgit v1.2.3