summaryrefslogtreecommitdiff
path: root/tools/gfx/d3d12/render-d3d12.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-07-20 10:22:20 -0700
committerGitHub <noreply@github.com>2021-07-20 10:22:20 -0700
commitf9f8d3ec5c749bcbdab5a8fc2d2f919350f2423c (patch)
tree78d4fbb45e737fd6cccf8da419e4eae7b97bf7e2 /tools/gfx/d3d12/render-d3d12.cpp
parent6162950d9012833ef5d4f96b99c67a46bf97ce6d (diff)
Minor refactor to gfx D3D12 implementation. (#1913)
* Minor refactor to gfx D3D12 implementation. - Allow more flexible collection of shader stages in a shader program. - Add `createRayTracingPipelineState` public interface. (no implementation). * Fix Vulkan initialization. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp70
1 files changed, 45 insertions, 25 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 0e7e7d3ab..24a1fd93e 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -123,6 +123,8 @@ public:
const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState(
const ComputePipelineStateDesc& desc, IPipelineState** outState) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
+ const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool(
const IQueryPool::Desc& desc, IQueryPool** outState) override;
@@ -1938,13 +1940,17 @@ public:
// List<DescriptorSetInfo> m_gpuDescriptorSetInfos;
};
+ struct ShaderBinary
+ {
+ SlangStage stage;
+ List<uint8_t> code;
+ };
+
class ShaderProgramImpl : public ShaderProgramBase
{
public:
PipelineType m_pipelineType;
- List<uint8_t> m_vertexShader;
- List<uint8_t> m_pixelShader;
- List<uint8_t> m_computeShader;
+ List<ShaderBinary> m_shaders;
RefPtr<RootShaderObjectLayoutImpl> m_rootObjectLayout;
};
@@ -5226,25 +5232,12 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr
(char*)diagnostics->getBufferPointer());
}
SLANG_RETURN_ON_FAIL(compileResult);
- List<uint8_t>* shaderCodeDestBuffer = nullptr;
- switch (stage)
- {
- case SLANG_STAGE_COMPUTE:
- shaderCodeDestBuffer = &shaderProgram->m_computeShader;
- break;
- case SLANG_STAGE_VERTEX:
- shaderCodeDestBuffer = &shaderProgram->m_vertexShader;
- break;
- case SLANG_STAGE_FRAGMENT:
- shaderCodeDestBuffer = &shaderProgram->m_pixelShader;
- break;
- default:
- SLANG_ASSERT(!"unsupported shader stage.");
- return SLANG_FAIL;
- }
- shaderCodeDestBuffer->addRange(
+ ShaderBinary shaderBin;
+ shaderBin.stage = stage;
+ shaderBin.code.addRange(
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
(Index)kernelCode->getBufferSize());
+ shaderProgram->m_shaders.add(_Move(shaderBin));
}
returnComPtr(outProgram, shaderProgram);
return SLANG_OK;
@@ -5294,9 +5287,31 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {};
psoDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature;
-
- psoDesc.VS = { programImpl->m_vertexShader.getBuffer(), SIZE_T(programImpl->m_vertexShader.getCount()) };
- psoDesc.PS = { programImpl->m_pixelShader .getBuffer(), SIZE_T(programImpl->m_pixelShader .getCount()) };
+ for (auto& shaderBin : programImpl->m_shaders)
+ {
+ switch (shaderBin.stage)
+ {
+ case SLANG_STAGE_VERTEX:
+ psoDesc.VS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
+ break;
+ case SLANG_STAGE_FRAGMENT:
+ psoDesc.PS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
+ break;
+ case SLANG_STAGE_DOMAIN:
+ psoDesc.DS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
+ break;
+ case SLANG_STAGE_HULL:
+ psoDesc.HS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
+ break;
+ case SLANG_STAGE_GEOMETRY:
+ psoDesc.GS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
+ break;
+ default:
+ getDebugCallback()->handleMessage(
+ DebugMessageType::Error, DebugMessageSource::Layer, "Unsupported shader stage.");
+ return SLANG_E_NOT_AVAILABLE;
+ }
+ }
psoDesc.InputLayout = { inputLayoutImpl->m_elements.getBuffer(), UINT(inputLayoutImpl->m_elements.getCount()) };
psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
@@ -5407,8 +5422,8 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {};
computeDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature;
computeDesc.CS = {
- programImpl->m_computeShader.getBuffer(),
- SIZE_T(programImpl->m_computeShader.getCount())};
+ programImpl->m_shaders[0].code.getBuffer(),
+ SIZE_T(programImpl->m_shaders[0].code.getCount())};
#ifdef GFX_NVAPI
if (m_nvapi)
@@ -5454,6 +5469,11 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
return SLANG_OK;
}
+Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
+{
+ return SLANG_E_NOT_AVAILABLE;
+}
+
Result D3D12Device::QueryPoolImpl::init(const IQueryPool::Desc& desc, D3D12Device* device)
{
// Translate query type.