summaryrefslogtreecommitdiff
path: root/tools/gfx
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-20 14:37:41 -0800
committerGitHub <noreply@github.com>2022-02-20 14:37:41 -0800
commitc4790309ec46ae2f4f7c49eb50699a950ee7a9a4 (patch)
treee89f2a4a0a8a3fee16ebde5ce5b05ceb1d473398 /tools/gfx
parente272aec6a9ddb8b0af82f72c061f5393f2b2bdab (diff)
gfx: defer downstream shader compilation until draw/dispatch. (#2139)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx')
-rw-r--r--tools/gfx/d3d/d3d-swapchain.h40
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp724
-rw-r--r--tools/gfx/renderer-shared.cpp6
-rw-r--r--tools/gfx/renderer-shared.h70
-rw-r--r--tools/gfx/vulkan/render-vk.cpp2
5 files changed, 484 insertions, 358 deletions
diff --git a/tools/gfx/d3d/d3d-swapchain.h b/tools/gfx/d3d/d3d-swapchain.h
index 99343aaf4..1c29b2039 100644
--- a/tools/gfx/d3d/d3d-swapchain.h
+++ b/tools/gfx/d3d/d3d-swapchain.h
@@ -48,17 +48,44 @@ public:
swapChainDesc.OutputWindow = (HWND)window.handleValues[0];
swapChainDesc.SampleDesc.Count = 1;
swapChainDesc.Windowed = TRUE;
-
if (!desc.enableVSync)
{
swapChainDesc.Flags |= DXGI_SWAP_CHAIN_FLAG_FRAME_LATENCY_WAITABLE_OBJECT;
}
// Swap chain needs the queue so that it can force a flush on it.
- ComPtr<IDXGISwapChain> swapChain;
- SLANG_RETURN_ON_FAIL(
- getDXGIFactory()->CreateSwapChain(getOwningDevice(), &swapChainDesc, swapChain.writeRef()));
- SLANG_RETURN_ON_FAIL(swapChain->QueryInterface(m_swapChain.writeRef()));
+ ComPtr<IDXGIFactory2> dxgiFactory2;
+ getDXGIFactory()->QueryInterface(IID_PPV_ARGS(dxgiFactory2.writeRef()));
+ if (!dxgiFactory2)
+ {
+ ComPtr<IDXGISwapChain> swapChain;
+ SLANG_RETURN_ON_FAIL(getDXGIFactory()->CreateSwapChain(
+ getOwningDevice(), &swapChainDesc, swapChain.writeRef()));
+ SLANG_RETURN_ON_FAIL(getDXGIFactory()->MakeWindowAssociation(
+ (HWND)window.handleValues[0], DXGI_MWA_NO_ALT_ENTER));
+ SLANG_RETURN_ON_FAIL(swapChain->QueryInterface(m_swapChain.writeRef()));
+ }
+ else
+ {
+ DXGI_SWAP_CHAIN_DESC1 desc1 = {};
+ desc1.BufferCount = swapChainDesc.BufferCount;
+ desc1.BufferUsage = swapChainDesc.BufferUsage;
+ desc1.Flags = swapChainDesc.Flags;
+ desc1.Format = swapChainDesc.BufferDesc.Format;
+ desc1.Height = swapChainDesc.BufferDesc.Height;
+ desc1.Width = swapChainDesc.BufferDesc.Width;
+ desc1.SampleDesc = swapChainDesc.SampleDesc;
+ desc1.SwapEffect = swapChainDesc.SwapEffect;
+ ComPtr<IDXGISwapChain1> swapChain1;
+ SLANG_RETURN_ON_FAIL(dxgiFactory2->CreateSwapChainForHwnd(
+ getOwningDevice(),
+ (HWND)window.handleValues[0],
+ &desc1,
+ nullptr,
+ nullptr,
+ swapChain1.writeRef()));
+ SLANG_RETURN_ON_FAIL(swapChain1->QueryInterface(m_swapChain.writeRef()));
+ }
if (!desc.enableVSync)
{
@@ -74,9 +101,6 @@ public:
m_swapChain->SetMaximumFrameLatency(maxLatency);
}
- SLANG_RETURN_ON_FAIL(getDXGIFactory()->MakeWindowAssociation(
- (HWND)window.handleValues[0], DXGI_MWA_NO_ALT_ENTER));
-
createSwapchainBufferImages();
return SLANG_OK;
}
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 29335ed1d..aa008ea81 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -478,6 +478,10 @@ public:
class PipelineStateImpl : public PipelineStateBase
{
public:
+ PipelineStateImpl(D3D12Device* device)
+ : m_device(device)
+ {}
+ D3D12Device* m_device;
ComPtr<ID3D12PipelineState> m_pipelineState;
void init(const GraphicsPipelineStateDesc& inDesc)
{
@@ -495,10 +499,12 @@ public:
}
virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override
{
+ SLANG_RETURN_ON_FAIL(ensureAPIPipelineStateCreated());
outHandle->api = InteropHandleAPI::D3D12;
outHandle->handleValue = reinterpret_cast<uint64_t>(m_pipelineState.get());
return SLANG_OK;
}
+ virtual Result ensureAPIPipelineStateCreated() override;
};
#if SLANG_GFX_HAS_DXR_SUPPORT
@@ -506,19 +512,25 @@ public:
{
public:
ComPtr<ID3D12StateObject> m_stateObject;
+ D3D12Device* m_device;
+ RayTracingPipelineStateImpl(D3D12Device* device)
+ : m_device(device)
+ {}
void init(const RayTracingPipelineStateDesc& inDesc)
{
PipelineStateDesc pipelineDesc;
pipelineDesc.type = PipelineType::RayTracing;
- pipelineDesc.rayTracing = inDesc;
+ pipelineDesc.rayTracing.set(inDesc);
initializeBase(pipelineDesc);
}
virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override
{
+ SLANG_RETURN_ON_FAIL(ensureAPIPipelineStateCreated());
outHandle->api = InteropHandleAPI::D3D12;
outHandle->handleValue = reinterpret_cast<uint64_t>(m_stateObject.get());
return SLANG_OK;
}
+ virtual Result ensureAPIPipelineStateCreated() override;
};
#endif
@@ -2464,6 +2476,62 @@ public:
public:
List<ShaderBinary> m_shaders;
RefPtr<RootShaderObjectLayoutImpl> m_rootObjectLayout;
+ Result compileShaders()
+ {
+ // For a fully specialized program, read and store its kernel code in `shaderProgram`.
+ auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
+ slang::IComponentType* entryPointComponent,
+ SlangInt entryPointIndex)
+ {
+ auto stage = entryPointInfo->getStage();
+ ComPtr<ISlangBlob> kernelCode;
+ ComPtr<ISlangBlob> diagnostics;
+ auto compileResult = entryPointComponent->getEntryPointCode(
+ entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef());
+ if (diagnostics)
+ {
+ getDebugCallback()->handleMessage(
+ compileResult == SLANG_OK ? DebugMessageType::Warning
+ : DebugMessageType::Error,
+ DebugMessageSource::Slang,
+ (char*)diagnostics->getBufferPointer());
+ }
+ SLANG_RETURN_ON_FAIL(compileResult);
+ ShaderBinary shaderBin;
+ shaderBin.stage = stage;
+ shaderBin.entryPointInfo = entryPointInfo;
+ shaderBin.code.addRange(
+ reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
+ (Index)kernelCode->getBufferSize());
+ m_shaders.add(_Move(shaderBin));
+ return SLANG_OK;
+ };
+
+ if (linkedEntryPoints.getCount() == 0)
+ {
+ // If the user does not explicitly specify entry point components, find them from
+ // `linkedEntryPoints`.
+ auto programReflection = linkedProgram->getLayout();
+ for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
+ {
+ SLANG_RETURN_ON_FAIL(compileShader(
+ programReflection->getEntryPointByIndex(i),
+ linkedProgram,
+ (SlangInt)i));
+ }
+ }
+ else
+ {
+ // If the user specifies entry point components via the separated entry point array,
+ // compile code from there.
+ for (auto& entryPoint : linkedEntryPoints)
+ {
+ SLANG_RETURN_ON_FAIL(compileShader(
+ entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
+ }
+ }
+ return SLANG_OK;
+ }
};
class ShaderObjectImpl
@@ -5241,21 +5309,24 @@ public:
auto cmdImpl = static_cast<CommandBufferImpl*>(commandBuffers[i]);
commandLists.add(cmdImpl->m_cmdList);
}
- m_d3dQueue->ExecuteCommandLists((UINT)count, commandLists.getArrayView().getBuffer());
+ if (count > 0)
+ {
+ m_d3dQueue->ExecuteCommandLists((UINT)count, commandLists.getArrayView().getBuffer());
- m_fenceValue++;
+ m_fenceValue++;
- for (uint32_t i = 0; i < count; i++)
- {
- if (i > 0 && commandBuffers[i] == commandBuffers[i - 1])
- continue;
- auto cmdImpl = static_cast<CommandBufferImpl*>(commandBuffers[i]);
- auto transientHeap = cmdImpl->m_transientHeap;
- auto& waitInfo = transientHeap->getQueueWaitInfo(m_queueIndex);
- waitInfo.waitValue = m_fenceValue;
- waitInfo.fence = m_fence;
+ for (uint32_t i = 0; i < count; i++)
+ {
+ if (i > 0 && commandBuffers[i] == commandBuffers[i - 1])
+ continue;
+ auto cmdImpl = static_cast<CommandBufferImpl*>(commandBuffers[i]);
+ auto transientHeap = cmdImpl->m_transientHeap;
+ auto& waitInfo = transientHeap->getQueueWaitInfo(m_queueIndex);
+ waitInfo.waitValue = m_fenceValue;
+ waitInfo.fence = m_fence;
+ }
+ m_d3dQueue->Signal(m_fence, m_fenceValue);
}
- m_d3dQueue->Signal(m_fence, m_fenceValue);
if (fence)
{
@@ -5539,6 +5610,7 @@ Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitte
auto commandList = m_d3dCmdList;
auto pipelineTypeIndex = (int)newPipelineImpl->desc.type;
auto programImpl = static_cast<ShaderProgramImpl*>(newPipelineImpl->m_program.Ptr());
+ newPipelineImpl->ensureAPIPipelineStateCreated();
submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature);
submitter->setPipelineState(newPipelineImpl);
RootShaderObjectLayoutImpl* rootLayoutImpl = programImpl->m_rootObjectLayout;
@@ -7670,64 +7742,6 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr
}
return rootShaderLayoutResult;
}
- if (shaderProgram->isSpecializable())
- {
- // For a specializable program, we don't invoke any actual slang compilation yet.
- returnComPtr(outProgram, shaderProgram);
- return SLANG_OK;
- }
- // For a fully specialized program, read and store its kernel code in `shaderProgram`.
- auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
- slang::IComponentType* entryPointComponent,
- SlangInt entryPointIndex)
- {
- auto stage = entryPointInfo->getStage();
- ComPtr<ISlangBlob> kernelCode;
- ComPtr<ISlangBlob> diagnostics;
- auto compileResult = entryPointComponent->getEntryPointCode(
- entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef());
- if (diagnostics)
- {
- getDebugCallback()->handleMessage(
- compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error,
- DebugMessageSource::Slang,
- (char*)diagnostics->getBufferPointer());
- if (outDiagnosticBlob)
- returnComPtr(outDiagnosticBlob, diagnostics);
- }
- SLANG_RETURN_ON_FAIL(compileResult);
- ShaderBinary shaderBin;
- shaderBin.stage = stage;
- shaderBin.entryPointInfo = entryPointInfo;
- shaderBin.code.addRange(
- reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
- (Index)kernelCode->getBufferSize());
- shaderProgram->m_shaders.add(_Move(shaderBin));
- return SLANG_OK;
- };
-
- if (shaderProgram->linkedEntryPoints.getCount() == 0)
- {
- // If the user does not explicitly specify entry point components, find them from `linkedEntryPoints`.
- auto programReflection = shaderProgram->linkedProgram->getLayout();
- for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
- {
- SLANG_RETURN_ON_FAIL(compileShader(
- programReflection->getEntryPointByIndex(i),
- shaderProgram->linkedProgram,
- (SlangInt)i));
- }
- }
- else
- {
- // If the user specifies entry point components via the separated entry point array, compile code
- // from there.
- for (auto& entryPoint : shaderProgram->linkedEntryPoints)
- {
- SLANG_RETURN_ON_FAIL(
- compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
- }
- }
returnComPtr(outProgram, shaderProgram);
return SLANG_OK;
}
@@ -7785,251 +7799,17 @@ Result D3D12Device::createShaderTable(const IShaderTable::Desc& desc, IShaderTab
return SLANG_OK;
}
-Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState)
+Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& desc, IPipelineState** outState)
{
- GraphicsPipelineStateDesc desc = inDesc;
- auto programImpl = (ShaderProgramImpl*) desc.program;
-
- if (!programImpl->m_rootObjectLayout->m_rootSignature)
- {
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
- pipelineStateImpl->init(desc);
- returnComPtr(outState, pipelineStateImpl);
- return SLANG_OK;
- }
-
- // Only actually create a D3D12 pipeline state if the pipeline is fully specialized.
- auto inputLayoutImpl = (InputLayoutImpl*) desc.inputLayout;
-
- // Describe and create the graphics pipeline state object (PSO)
- D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {};
-
- psoDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature;
-
- for (auto& shaderBin : programImpl->m_shaders)
- {
- switch (shaderBin.stage)
- {
- case SLANG_STAGE_VERTEX:
- psoDesc.VS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
- break;
- case SLANG_STAGE_FRAGMENT:
- psoDesc.PS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
- break;
- case SLANG_STAGE_DOMAIN:
- psoDesc.DS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
- break;
- case SLANG_STAGE_HULL:
- psoDesc.HS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
- break;
- case SLANG_STAGE_GEOMETRY:
- psoDesc.GS = { shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount()) };
- break;
- default:
- getDebugCallback()->handleMessage(
- DebugMessageType::Error, DebugMessageSource::Layer, "Unsupported shader stage.");
- return SLANG_E_NOT_AVAILABLE;
- }
- }
-
- if (inputLayoutImpl)
- {
- psoDesc.InputLayout = {
- inputLayoutImpl->m_elements.getBuffer(), UINT(inputLayoutImpl->m_elements.getCount())};
- }
-
- psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
-
- {
- auto framebufferLayout = static_cast<FramebufferLayoutImpl*>(desc.framebufferLayout);
- const int numRenderTargets = int(framebufferLayout->m_renderTargets.getCount());
-
- if (framebufferLayout->m_hasDepthStencil)
- {
- psoDesc.DSVFormat = D3DUtil::getMapFormat(framebufferLayout->m_depthStencil.format);
- psoDesc.SampleDesc.Count = framebufferLayout->m_depthStencil.sampleCount;
- }
- else
- {
- psoDesc.DSVFormat = DXGI_FORMAT_UNKNOWN;
- if (framebufferLayout->m_renderTargets.getCount())
- {
- psoDesc.SampleDesc.Count = framebufferLayout->m_renderTargets[0].sampleCount;
- }
- }
- psoDesc.NumRenderTargets = numRenderTargets;
- for (Int i = 0; i < numRenderTargets; i++)
- {
- psoDesc.RTVFormats[i] =
- D3DUtil::getMapFormat(framebufferLayout->m_renderTargets[i].format);
- }
-
- psoDesc.SampleDesc.Quality = 0;
- psoDesc.SampleMask = UINT_MAX;
- }
-
- {
- auto& rs = psoDesc.RasterizerState;
- rs.FillMode = D3DUtil::getFillMode(desc.rasterizer.fillMode);
- rs.CullMode = D3DUtil::getCullMode(desc.rasterizer.cullMode);
- rs.FrontCounterClockwise =
- desc.rasterizer.frontFace == gfx::FrontFaceMode::CounterClockwise ? TRUE : FALSE;
- rs.DepthBias = desc.rasterizer.depthBias;
- rs.DepthBiasClamp = desc.rasterizer.depthBiasClamp;
- rs.SlopeScaledDepthBias = desc.rasterizer.slopeScaledDepthBias;
- rs.DepthClipEnable = desc.rasterizer.depthClipEnable ? TRUE : FALSE;
- rs.MultisampleEnable = desc.rasterizer.multisampleEnable ? TRUE : FALSE;
- rs.AntialiasedLineEnable = desc.rasterizer.antialiasedLineEnable ? TRUE : FALSE;
- rs.ForcedSampleCount = desc.rasterizer.forcedSampleCount;
- rs.ConservativeRaster = desc.rasterizer.enableConservativeRasterization
- ? D3D12_CONSERVATIVE_RASTERIZATION_MODE_ON
- : D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF;
- }
-
- {
- D3D12_BLEND_DESC& blend = psoDesc.BlendState;
- blend.IndependentBlendEnable = FALSE;
- blend.AlphaToCoverageEnable = desc.blend.alphaToCoverageEnable ? TRUE : FALSE;
- blend.RenderTarget[0].RenderTargetWriteMask = (uint8_t)RenderTargetWriteMask::EnableAll;
- for (uint32_t i = 0; i < desc.blend.targetCount; i++)
- {
- auto& d3dDesc = blend.RenderTarget[i];
- d3dDesc.BlendEnable = desc.blend.targets[i].enableBlend ? TRUE : FALSE;
- d3dDesc.BlendOp = D3DUtil::getBlendOp(desc.blend.targets[i].color.op);
- d3dDesc.BlendOpAlpha = D3DUtil::getBlendOp(desc.blend.targets[i].alpha.op);
- d3dDesc.DestBlend = D3DUtil::getBlendFactor(desc.blend.targets[i].color.dstFactor);
- d3dDesc.DestBlendAlpha = D3DUtil::getBlendFactor(desc.blend.targets[i].alpha.dstFactor);
- d3dDesc.LogicOp = D3D12_LOGIC_OP_NOOP;
- d3dDesc.LogicOpEnable = FALSE;
- d3dDesc.RenderTargetWriteMask = desc.blend.targets[i].writeMask;
- d3dDesc.SrcBlend = D3DUtil::getBlendFactor(desc.blend.targets[i].color.srcFactor);
- d3dDesc.SrcBlendAlpha = D3DUtil::getBlendFactor(desc.blend.targets[i].alpha.srcFactor);
- }
- for (uint32_t i = 1; i < desc.blend.targetCount; i++)
- {
- if (memcmp(&desc.blend.targets[i], &desc.blend.targets[0], sizeof(desc.blend.targets[0])) != 0)
- {
- blend.IndependentBlendEnable = TRUE;
- break;
- }
- }
- for (uint32_t i = (uint32_t)desc.blend.targetCount; i < D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT; ++i)
- {
- blend.RenderTarget[i] = blend.RenderTarget[0];
- }
- }
-
- {
- auto& ds = psoDesc.DepthStencilState;
-
- ds.DepthEnable = inDesc.depthStencil.depthTestEnable;
- ds.DepthWriteMask = inDesc.depthStencil.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL
- : D3D12_DEPTH_WRITE_MASK_ZERO;
- ds.DepthFunc = D3DUtil::getComparisonFunc(inDesc.depthStencil.depthFunc);
- ds.StencilEnable = inDesc.depthStencil.stencilEnable;
- ds.StencilReadMask = (UINT8)inDesc.depthStencil.stencilReadMask;
- ds.StencilWriteMask = (UINT8)inDesc.depthStencil.stencilWriteMask;
- ds.FrontFace = D3DUtil::translateStencilOpDesc(inDesc.depthStencil.frontFace);
- ds.BackFace = D3DUtil::translateStencilOpDesc(inDesc.depthStencil.backFace);
- }
-
- psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.primitiveType);
-
- ComPtr<ID3D12PipelineState> pipelineState;
- if (m_pipelineCreationAPIDispatcher)
- {
- SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
- this, programImpl->linkedProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
- }
- else
- {
- SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
- }
-
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
- pipelineStateImpl->m_pipelineState = pipelineState;
+ RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
pipelineStateImpl->init(desc);
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
-Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& inDesc, IPipelineState** outState)
+Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& desc, IPipelineState** outState)
{
- ComputePipelineStateDesc desc = inDesc;
-
- auto programImpl = (ShaderProgramImpl*) desc.program;
- if (!programImpl->m_rootObjectLayout->m_rootSignature)
- {
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
- pipelineStateImpl->init(desc);
- returnComPtr(outState, pipelineStateImpl);
- return SLANG_OK;
- }
-
- // Only actually create a D3D12 pipeline state if the pipeline is fully specialized.
- ComPtr<ID3D12PipelineState> pipelineState;
- if (!programImpl->isSpecializable())
- {
- // Describe and create the compute pipeline state object
- D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {};
- computeDesc.pRootSignature =
- desc.d3d12RootSignatureOverride
- ? static_cast<ID3D12RootSignature*>(desc.d3d12RootSignatureOverride)
- : programImpl->m_rootObjectLayout->m_rootSignature;
- computeDesc.CS = {
- programImpl->m_shaders[0].code.getBuffer(),
- SIZE_T(programImpl->m_shaders[0].code.getCount())};
-
-#ifdef GFX_NVAPI
- if (m_nvapi)
- {
- // Also fill the extension structure.
- // Use the same UAV slot index and register space that are declared in the shader.
-
- // For simplicities sake we just use u0
- NVAPI_D3D12_PSO_SET_SHADER_EXTENSION_SLOT_DESC extensionDesc;
- extensionDesc.baseVersion = NV_PSO_EXTENSION_DESC_VER;
- extensionDesc.version = NV_SET_SHADER_EXTENSION_SLOT_DESC_VER;
- extensionDesc.uavSlot = 0;
- extensionDesc.registerSpace = 0;
-
- // Put the pointer to the extension into an array - there can be multiple extensions
- // enabled at once.
- const NVAPI_D3D12_PSO_EXTENSION_DESC* extensions[] = {&extensionDesc};
-
- // Now create the PSO.
- const NvAPI_Status nvapiStatus = NvAPI_D3D12_CreateComputePipelineState(
- m_device,
- &computeDesc,
- SLANG_COUNT_OF(extensions),
- extensions,
- pipelineState.writeRef());
-
- if (nvapiStatus != NVAPI_OK)
- {
- return SLANG_FAIL;
- }
- }
- else
-#endif
- {
- if (m_pipelineCreationAPIDispatcher)
- {
- SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState(
- this,
- programImpl->linkedProgram.get(),
- &computeDesc,
- (void**)pipelineState.writeRef()));
- }
- else
- {
- SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(
- &computeDesc, IID_PPV_ARGS(pipelineState.writeRef())));
- }
- }
- }
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
- pipelineStateImpl->m_pipelineState = pipelineState;
+ RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
pipelineStateImpl->init(desc);
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
@@ -8466,25 +8246,15 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
m_commandBuffer->m_cmdList4->DispatchRays(&dispatchDesc);
}
-Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
+Result D3D12Device::RayTracingPipelineStateImpl::ensureAPIPipelineStateCreated()
{
- if (!m_device5)
- {
- return SLANG_E_NOT_AVAILABLE;
- }
-
- RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl();
- pipelineStateImpl->init(inDesc);
+ if (m_stateObject)
+ return SLANG_OK;
- auto program = static_cast<ShaderProgramImpl*>(inDesc.program);
+ auto program = static_cast<ShaderProgramImpl*>(m_program.Ptr());
auto slangGlobalScope = program->linkedProgram;
auto programLayout = slangGlobalScope->getLayout();
- if (!program->m_rootObjectLayout->m_rootSignature)
- {
- returnComPtr(outState, pipelineStateImpl);
- return SLANG_OK;
- }
List<D3D12_STATE_SUBOBJECT> subObjects;
ChunkedList<D3D12_DXIL_LIBRARY_DESC> dxilLibraries;
ChunkedList<D3D12_HIT_GROUP_DESC> hitGroups;
@@ -8548,28 +8318,30 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
}
}
-
- for (int i = 0; i < inDesc.hitGroupCount; i++)
+
+ for (Index i = 0; i < desc.rayTracing.hitGroupDescs.getCount(); i++)
{
- auto hitGroup = inDesc.hitGroups[i];
+ auto& hitGroup = desc.rayTracing.hitGroups[i];
D3D12_HIT_GROUP_DESC hitGroupDesc = {};
- hitGroupDesc.Type = hitGroup.intersectionEntryPoint == nullptr
+ hitGroupDesc.Type = hitGroup.intersectionEntryPoint.getLength() == 0
? D3D12_HIT_GROUP_TYPE_TRIANGLES
: D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE;
- if (hitGroup.anyHitEntryPoint)
+ if (hitGroup.anyHitEntryPoint.getLength())
{
- hitGroupDesc.AnyHitShaderImport = getWStr(hitGroup.anyHitEntryPoint);
+ hitGroupDesc.AnyHitShaderImport = getWStr(hitGroup.anyHitEntryPoint.getBuffer());
}
- if (hitGroup.closestHitEntryPoint)
+ if (hitGroup.closestHitEntryPoint.getLength())
{
- hitGroupDesc.ClosestHitShaderImport = getWStr(hitGroup.closestHitEntryPoint);
+ hitGroupDesc.ClosestHitShaderImport =
+ getWStr(hitGroup.closestHitEntryPoint.getBuffer());
}
- if (hitGroup.intersectionEntryPoint)
+ if (hitGroup.intersectionEntryPoint.getLength())
{
- hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint);
+ hitGroupDesc.IntersectionShaderImport =
+ getWStr(hitGroup.intersectionEntryPoint.getBuffer());
}
- hitGroupDesc.HitGroupExport = getWStr(hitGroup.hitGroupName);
+ hitGroupDesc.HitGroupExport = getWStr(hitGroup.hitGroupName.getBuffer());
D3D12_STATE_SUBOBJECT hitGroupSubObject = {};
hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP;
@@ -8578,10 +8350,10 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
}
D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {};
- // According to DXR spec, fixed function triangle intersections must use float2 as ray attributes
- // that defines the barycentric coordinates at intersection.
- shaderConfig.MaxAttributeSizeInBytes = inDesc.maxAttributeSizeInBytes;
- shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize;
+ // According to DXR spec, fixed function triangle intersections must use float2 as ray
+ // attributes that defines the barycentric coordinates at intersection.
+ shaderConfig.MaxAttributeSizeInBytes = desc.rayTracing.maxAttributeSizeInBytes;
+ shaderConfig.MaxPayloadSizeInBytes = desc.rayTracing.maxRayPayloadSize;
D3D12_STATE_SUBOBJECT shaderConfigSubObject = {};
shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG;
shaderConfigSubObject.pDesc = &shaderConfig;
@@ -8595,28 +8367,42 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
subObjects.add(globalSignatureSubobject);
D3D12_RAYTRACING_PIPELINE_CONFIG pipelineConfig = {};
- pipelineConfig.MaxTraceRecursionDepth = inDesc.maxRecursion;
+ pipelineConfig.MaxTraceRecursionDepth = desc.rayTracing.maxRecursion;
D3D12_STATE_SUBOBJECT pipelineConfigSubobject = {};
pipelineConfigSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG;
pipelineConfigSubobject.pDesc = &pipelineConfig;
subObjects.add(pipelineConfigSubobject);
- if (m_pipelineCreationAPIDispatcher)
+ if (m_device->m_pipelineCreationAPIDispatcher)
{
- m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangGlobalScope);
+ m_device->m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(
+ m_device, slangGlobalScope);
}
D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
rtpsoDesc.pSubobjects = subObjects.getBuffer();
- SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+ SLANG_RETURN_ON_FAIL(m_device->m_device5->CreateStateObject(
+ &rtpsoDesc, IID_PPV_ARGS(m_stateObject.writeRef())));
- if (m_pipelineCreationAPIDispatcher)
+ if (m_device->m_pipelineCreationAPIDispatcher)
{
- m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangGlobalScope);
+ m_device->m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(
+ m_device, slangGlobalScope);
}
+ return SLANG_OK;
+}
+Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
+{
+ if (!m_device5)
+ {
+ return SLANG_E_NOT_AVAILABLE;
+ }
+
+ RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl(this);
+ pipelineStateImpl->init(inDesc);
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
@@ -8823,4 +8609,254 @@ Result D3D12Device::ShaderObjectImpl::setResource(ShaderOffset const& offset, IR
return SLANG_OK;
}
+Result D3D12Device::PipelineStateImpl::ensureAPIPipelineStateCreated()
+{
+ if (m_pipelineState)
+ return SLANG_OK;
+
+ auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr());
+ if (programImpl->m_shaders.getCount() == 0)
+ {
+ SLANG_RETURN_ON_FAIL(programImpl->compileShaders());
+ }
+ if (desc.type == PipelineType::Graphics)
+ {
+ // Only actually create a D3D12 pipeline state if the pipeline is fully specialized.
+ auto inputLayoutImpl = (InputLayoutImpl*)desc.graphics.inputLayout;
+
+ // Describe and create the graphics pipeline state object (PSO)
+ D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {};
+
+ psoDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature;
+
+ for (auto& shaderBin : programImpl->m_shaders)
+ {
+ switch (shaderBin.stage)
+ {
+ case SLANG_STAGE_VERTEX:
+ psoDesc.VS = {shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount())};
+ break;
+ case SLANG_STAGE_FRAGMENT:
+ psoDesc.PS = {shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount())};
+ break;
+ case SLANG_STAGE_DOMAIN:
+ psoDesc.DS = {shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount())};
+ break;
+ case SLANG_STAGE_HULL:
+ psoDesc.HS = {shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount())};
+ break;
+ case SLANG_STAGE_GEOMETRY:
+ psoDesc.GS = {shaderBin.code.getBuffer(), SIZE_T(shaderBin.code.getCount())};
+ break;
+ default:
+ getDebugCallback()->handleMessage(
+ DebugMessageType::Error,
+ DebugMessageSource::Layer,
+ "Unsupported shader stage.");
+ return SLANG_E_NOT_AVAILABLE;
+ }
+ }
+
+ if (inputLayoutImpl)
+ {
+ psoDesc.InputLayout = {
+ inputLayoutImpl->m_elements.getBuffer(),
+ UINT(inputLayoutImpl->m_elements.getCount())};
+ }
+
+ psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.graphics.primitiveType);
+
+ {
+ auto framebufferLayout = static_cast<FramebufferLayoutImpl*>(desc.graphics.framebufferLayout);
+ const int numRenderTargets = int(framebufferLayout->m_renderTargets.getCount());
+
+ if (framebufferLayout->m_hasDepthStencil)
+ {
+ psoDesc.DSVFormat = D3DUtil::getMapFormat(framebufferLayout->m_depthStencil.format);
+ psoDesc.SampleDesc.Count = framebufferLayout->m_depthStencil.sampleCount;
+ }
+ else
+ {
+ psoDesc.DSVFormat = DXGI_FORMAT_UNKNOWN;
+ if (framebufferLayout->m_renderTargets.getCount())
+ {
+ psoDesc.SampleDesc.Count = framebufferLayout->m_renderTargets[0].sampleCount;
+ }
+ }
+ psoDesc.NumRenderTargets = numRenderTargets;
+ for (Int i = 0; i < numRenderTargets; i++)
+ {
+ psoDesc.RTVFormats[i] =
+ D3DUtil::getMapFormat(framebufferLayout->m_renderTargets[i].format);
+ }
+
+ psoDesc.SampleDesc.Quality = 0;
+ psoDesc.SampleMask = UINT_MAX;
+ }
+
+ {
+ auto& rs = psoDesc.RasterizerState;
+ rs.FillMode = D3DUtil::getFillMode(desc.graphics.rasterizer.fillMode);
+ rs.CullMode = D3DUtil::getCullMode(desc.graphics.rasterizer.cullMode);
+ rs.FrontCounterClockwise =
+ desc.graphics.rasterizer.frontFace == gfx::FrontFaceMode::CounterClockwise ? TRUE
+ : FALSE;
+ rs.DepthBias = desc.graphics.rasterizer.depthBias;
+ rs.DepthBiasClamp = desc.graphics.rasterizer.depthBiasClamp;
+ rs.SlopeScaledDepthBias = desc.graphics.rasterizer.slopeScaledDepthBias;
+ rs.DepthClipEnable = desc.graphics.rasterizer.depthClipEnable ? TRUE : FALSE;
+ rs.MultisampleEnable = desc.graphics.rasterizer.multisampleEnable ? TRUE : FALSE;
+ rs.AntialiasedLineEnable =
+ desc.graphics.rasterizer.antialiasedLineEnable ? TRUE : FALSE;
+ rs.ForcedSampleCount = desc.graphics.rasterizer.forcedSampleCount;
+ rs.ConservativeRaster = desc.graphics.rasterizer.enableConservativeRasterization
+ ? D3D12_CONSERVATIVE_RASTERIZATION_MODE_ON
+ : D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF;
+ }
+
+ {
+ D3D12_BLEND_DESC& blend = psoDesc.BlendState;
+ blend.IndependentBlendEnable = FALSE;
+ blend.AlphaToCoverageEnable = desc.graphics.blend.alphaToCoverageEnable ? TRUE : FALSE;
+ blend.RenderTarget[0].RenderTargetWriteMask = (uint8_t)RenderTargetWriteMask::EnableAll;
+ for (uint32_t i = 0; i < desc.graphics.blend.targetCount; i++)
+ {
+ auto& d3dDesc = blend.RenderTarget[i];
+ d3dDesc.BlendEnable = desc.graphics.blend.targets[i].enableBlend ? TRUE : FALSE;
+ d3dDesc.BlendOp = D3DUtil::getBlendOp(desc.graphics.blend.targets[i].color.op);
+ d3dDesc.BlendOpAlpha = D3DUtil::getBlendOp(desc.graphics.blend.targets[i].alpha.op);
+ d3dDesc.DestBlend =
+ D3DUtil::getBlendFactor(desc.graphics.blend.targets[i].color.dstFactor);
+ d3dDesc.DestBlendAlpha =
+ D3DUtil::getBlendFactor(desc.graphics.blend.targets[i].alpha.dstFactor);
+ d3dDesc.LogicOp = D3D12_LOGIC_OP_NOOP;
+ d3dDesc.LogicOpEnable = FALSE;
+ d3dDesc.RenderTargetWriteMask = desc.graphics.blend.targets[i].writeMask;
+ d3dDesc.SrcBlend =
+ D3DUtil::getBlendFactor(desc.graphics.blend.targets[i].color.srcFactor);
+ d3dDesc.SrcBlendAlpha =
+ D3DUtil::getBlendFactor(desc.graphics.blend.targets[i].alpha.srcFactor);
+ }
+ for (uint32_t i = 1; i < desc.graphics.blend.targetCount; i++)
+ {
+ if (memcmp(
+ &desc.graphics.blend.targets[i],
+ &desc.graphics.blend.targets[0],
+ sizeof(desc.graphics.blend.targets[0])) != 0)
+ {
+ blend.IndependentBlendEnable = TRUE;
+ break;
+ }
+ }
+ for (uint32_t i = (uint32_t)desc.graphics.blend.targetCount;
+ i < D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT;
+ ++i)
+ {
+ blend.RenderTarget[i] = blend.RenderTarget[0];
+ }
+ }
+
+ {
+ auto& ds = psoDesc.DepthStencilState;
+
+ ds.DepthEnable = desc.graphics.depthStencil.depthTestEnable;
+ ds.DepthWriteMask = desc.graphics.depthStencil.depthWriteEnable
+ ? D3D12_DEPTH_WRITE_MASK_ALL
+ : D3D12_DEPTH_WRITE_MASK_ZERO;
+ ds.DepthFunc = D3DUtil::getComparisonFunc(desc.graphics.depthStencil.depthFunc);
+ ds.StencilEnable = desc.graphics.depthStencil.stencilEnable;
+ ds.StencilReadMask = (UINT8)desc.graphics.depthStencil.stencilReadMask;
+ ds.StencilWriteMask = (UINT8)desc.graphics.depthStencil.stencilWriteMask;
+ ds.FrontFace = D3DUtil::translateStencilOpDesc(desc.graphics.depthStencil.frontFace);
+ ds.BackFace = D3DUtil::translateStencilOpDesc(desc.graphics.depthStencil.backFace);
+ }
+
+ psoDesc.PrimitiveTopologyType = D3DUtil::getPrimitiveType(desc.graphics.primitiveType);
+
+ if (m_device->m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(
+ m_device->m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
+ m_device,
+ programImpl->linkedProgram.get(),
+ &psoDesc,
+ (void**)m_pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->m_device->CreateGraphicsPipelineState(
+ &psoDesc, IID_PPV_ARGS(m_pipelineState.writeRef())));
+ }
+ }
+ else
+ {
+
+ // Only actually create a D3D12 pipeline state if the pipeline is fully specialized.
+ ComPtr<ID3D12PipelineState> pipelineState;
+ if (!programImpl->isSpecializable())
+ {
+ // Describe and create the compute pipeline state object
+ D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {};
+ computeDesc.pRootSignature =
+ desc.compute.d3d12RootSignatureOverride
+ ? static_cast<ID3D12RootSignature*>(desc.compute.d3d12RootSignatureOverride)
+ : programImpl->m_rootObjectLayout->m_rootSignature;
+ computeDesc.CS = {
+ programImpl->m_shaders[0].code.getBuffer(),
+ SIZE_T(programImpl->m_shaders[0].code.getCount())};
+
+#ifdef GFX_NVAPI
+ if (m_nvapi)
+ {
+ // Also fill the extension structure.
+ // Use the same UAV slot index and register space that are declared in the shader.
+
+ // For simplicities sake we just use u0
+ NVAPI_D3D12_PSO_SET_SHADER_EXTENSION_SLOT_DESC extensionDesc;
+ extensionDesc.baseVersion = NV_PSO_EXTENSION_DESC_VER;
+ extensionDesc.version = NV_SET_SHADER_EXTENSION_SLOT_DESC_VER;
+ extensionDesc.uavSlot = 0;
+ extensionDesc.registerSpace = 0;
+
+ // Put the pointer to the extension into an array - there can be multiple extensions
+ // enabled at once.
+ const NVAPI_D3D12_PSO_EXTENSION_DESC* extensions[] = {&extensionDesc};
+
+ // Now create the PSO.
+ const NvAPI_Status nvapiStatus = NvAPI_D3D12_CreateComputePipelineState(
+ m_device->m_device,
+ &computeDesc,
+ SLANG_COUNT_OF(extensions),
+ extensions,
+ m_pipelineState.writeRef());
+
+ if (nvapiStatus != NVAPI_OK)
+ {
+ return SLANG_FAIL;
+ }
+ }
+ else
+#endif
+ {
+ if (m_device->m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(
+ m_device->m_pipelineCreationAPIDispatcher->createComputePipelineState(
+ m_device,
+ programImpl->linkedProgram.get(),
+ &computeDesc,
+ (void**)m_pipelineState.writeRef()));
+ }
+ else
+ {
+ SLANG_RETURN_ON_FAIL(m_device->m_device->CreateComputePipelineState(
+ &computeDesc, IID_PPV_ARGS(m_pipelineState.writeRef())));
+ }
+ }
+ }
+ }
+
+ return SLANG_OK;
+}
+
} // renderer_test
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 244146ca3..92d263ed6 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -873,7 +873,7 @@ Result RendererBase::maybeSpecializePipeline(
case PipelineType::Graphics:
{
auto pipelineDesc = currentPipeline->desc.graphics;
- pipelineDesc.program = specializedProgram;
+ pipelineDesc.program = static_cast<ShaderProgramBase*>(specializedProgram.get());
SLANG_RETURN_ON_FAIL(createGraphicsPipelineState(
pipelineDesc, specializedPipelineComPtr.writeRef()));
break;
@@ -881,9 +881,9 @@ Result RendererBase::maybeSpecializePipeline(
case PipelineType::RayTracing:
{
auto pipelineDesc = currentPipeline->desc.rayTracing;
- pipelineDesc.program = specializedProgram;
+ pipelineDesc.program = static_cast<ShaderProgramBase*>(specializedProgram.get());
SLANG_RETURN_ON_FAIL(createRayTracingPipelineState(
- pipelineDesc, specializedPipelineComPtr.writeRef()));
+ pipelineDesc.get(), specializedPipelineComPtr.writeRef()));
break;
}
default:
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 8136d6735..a78113552 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -902,6 +902,72 @@ enum class PipelineType
CountOf,
};
+struct OwnedHitGroupDesc
+{
+ Slang::String hitGroupName;
+ Slang::String closestHitEntryPoint;
+ Slang::String anyHitEntryPoint;
+ Slang::String intersectionEntryPoint;
+
+ void set(const HitGroupDesc& desc)
+ {
+ hitGroupName = desc.hitGroupName;
+ closestHitEntryPoint = desc.closestHitEntryPoint;
+ anyHitEntryPoint = desc.anyHitEntryPoint;
+ intersectionEntryPoint = desc.intersectionEntryPoint;
+ }
+
+ HitGroupDesc get()
+ {
+ HitGroupDesc desc;
+ desc.hitGroupName = hitGroupName.getBuffer();
+ desc.closestHitEntryPoint = closestHitEntryPoint.getBuffer();
+ desc.anyHitEntryPoint = anyHitEntryPoint.getBuffer();
+ desc.intersectionEntryPoint = intersectionEntryPoint.getBuffer();
+ return desc;
+ }
+};
+
+struct OwnedRayTracingPipelineStateDesc
+{
+ Slang::RefPtr<ShaderProgramBase> program;
+ Slang::List<OwnedHitGroupDesc> hitGroups;
+ Slang::List<HitGroupDesc> hitGroupDescs;
+ int maxRecursion = 0;
+ int maxRayPayloadSize = 0;
+ int maxAttributeSizeInBytes = 8;
+ RayTracingPipelineFlags::Enum flags = RayTracingPipelineFlags::None;
+
+ RayTracingPipelineStateDesc get()
+ {
+ RayTracingPipelineStateDesc desc;
+ desc.program = program.Ptr();
+ desc.hitGroupCount = (int32_t)hitGroupDescs.getCount();
+ desc.hitGroups = hitGroupDescs.getBuffer();
+ desc.maxRecursion = maxRecursion;
+ desc.maxRayPayloadSize = maxRayPayloadSize;
+ desc.maxAttributeSizeInBytes = maxAttributeSizeInBytes;
+ desc.flags = flags;
+ return desc;
+ }
+
+ void set(const RayTracingPipelineStateDesc& inDesc)
+ {
+ program = static_cast<ShaderProgramBase*>(inDesc.program);
+ for (int32_t i = 0; i < inDesc.hitGroupCount; i++)
+ {
+ OwnedHitGroupDesc ownedHitGroupDesc;
+ ownedHitGroupDesc.set(inDesc.hitGroups[i]);
+ hitGroups.add(ownedHitGroupDesc);
+ hitGroupDescs.add(ownedHitGroupDesc.get());
+ }
+ maxRecursion = inDesc.maxRecursion;
+ maxRayPayloadSize = inDesc.maxRayPayloadSize;
+ maxAttributeSizeInBytes = inDesc.maxAttributeSizeInBytes;
+ flags = inDesc.flags;
+ }
+};
+
class PipelineStateBase
: public IPipelineState
, public Slang::ComObject
@@ -915,7 +981,7 @@ public:
PipelineType type;
GraphicsPipelineStateDesc graphics;
ComputePipelineStateDesc compute;
- RayTracingPipelineStateDesc rayTracing;
+ OwnedRayTracingPipelineStateDesc rayTracing;
ShaderProgramBase* getProgram()
{
switch (type)
@@ -950,6 +1016,7 @@ public:
}
virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override;
+ virtual Result ensureAPIPipelineStateCreated() { return SLANG_OK; };
protected:
void initializeBase(const PipelineStateDesc& inDesc);
@@ -1439,5 +1506,4 @@ Result ShaderObjectBaseImpl<TShaderObjectImpl, TShaderObjectLayoutImpl, TShaderO
}
return SLANG_OK;
}
-
}
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index ebf25513b..d1713ae01 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -1023,7 +1023,7 @@ public:
{
PipelineStateDesc pipelineDesc;
pipelineDesc.type = PipelineType::RayTracing;
- pipelineDesc.rayTracing = inDesc;
+ pipelineDesc.rayTracing.set(inDesc);
initializeBase(pipelineDesc);
}