summaryrefslogtreecommitdiffstats
path: root/tools/gfx/vulkan
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/vulkan')
-rw-r--r--tools/gfx/vulkan/vk-command-encoder.cpp7
-rw-r--r--tools/gfx/vulkan/vk-shader-table.cpp18
2 files changed, 19 insertions, 6 deletions
diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp
index 568b49179..7f3110ea5 100644
--- a/tools/gfx/vulkan/vk-command-encoder.cpp
+++ b/tools/gfx/vulkan/vk-command-encoder.cpp
@@ -1480,11 +1480,10 @@ Result RayTracingCommandEncoder::dispatchRays(
hitSBT.stride = alignedHandleSize;
hitSBT.size = shaderTableImpl->m_hitTableSize;
- // TODO: Are callable shaders needed?
VkStridedDeviceAddressRegionKHR callableSBT;
- callableSBT.deviceAddress = 0;
- callableSBT.stride = 0;
- callableSBT.size = 0;
+ callableSBT.deviceAddress = hitSBT.deviceAddress + hitSBT.size;
+ callableSBT.stride = alignedHandleSize;
+ callableSBT.size = shaderTableImpl->m_callableTableSize;
vkApi.vkCmdTraceRaysKHR(
vkCommandBuffer,
diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp
index 0b6488465..beb826111 100644
--- a/tools/gfx/vulkan/vk-shader-table.cpp
+++ b/tools/gfx/vulkan/vk-shader-table.cpp
@@ -27,7 +27,8 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(
m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment);
- m_callableTableSize = 0; // TODO: Are callable shaders needed?
+ m_callableTableSize = (uint32_t)VulkanUtil::calcAligned(
+ m_callableShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize;
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
@@ -122,7 +123,20 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
}
subTablePtr += m_hitTableSize;
- // TODO: Callable shaders?
+ for (uint32_t i = 0; i < m_callableShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr =
+ pipelineImpl->shaderGroupNameToIndex.tryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_callableTableSize;
stagingBuffer->unmap(nullptr);
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);