diff options
Diffstat (limited to 'tools/gfx/vulkan')
| -rw-r--r-- | tools/gfx/vulkan/vk-command-encoder.cpp | 7 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-shader-table.cpp | 18 |
2 files changed, 19 insertions, 6 deletions
diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp index 568b49179..7f3110ea5 100644 --- a/tools/gfx/vulkan/vk-command-encoder.cpp +++ b/tools/gfx/vulkan/vk-command-encoder.cpp @@ -1480,11 +1480,10 @@ Result RayTracingCommandEncoder::dispatchRays( hitSBT.stride = alignedHandleSize; hitSBT.size = shaderTableImpl->m_hitTableSize; - // TODO: Are callable shaders needed? VkStridedDeviceAddressRegionKHR callableSBT; - callableSBT.deviceAddress = 0; - callableSBT.stride = 0; - callableSBT.size = 0; + callableSBT.deviceAddress = hitSBT.deviceAddress + hitSBT.size; + callableSBT.stride = alignedHandleSize; + callableSBT.size = shaderTableImpl->m_callableTableSize; vkApi.vkCmdTraceRaysKHR( vkCommandBuffer, diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp index 0b6488465..beb826111 100644 --- a/tools/gfx/vulkan/vk-shader-table.cpp +++ b/tools/gfx/vulkan/vk-shader-table.cpp @@ -27,7 +27,8 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer( m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); m_hitTableSize = (uint32_t)VulkanUtil::calcAligned( m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment); - m_callableTableSize = 0; // TODO: Are callable shaders needed? + m_callableTableSize = (uint32_t)VulkanUtil::calcAligned( + m_callableShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize; auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); @@ -122,7 +123,20 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer( } subTablePtr += m_hitTableSize; - // TODO: Callable shaders? + for (uint32_t i = 0; i < m_callableShaderCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = + pipelineImpl->shaderGroupNameToIndex.tryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_callableTableSize; stagingBuffer->unmap(nullptr); encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize); |
