summaryrefslogtreecommitdiffstats
path: root/tools/gfx/vulkan
diff options
context:
space:
mode:
authorskallweitNV <64953474+skallweitNV@users.noreply.github.com>2024-05-28 02:05:12 +0200
committerGitHub <noreply@github.com>2024-05-27 17:05:12 -0700
commiteefdd4ab99fa99ed326b68cd2b0d4024347ed8fc (patch)
tree22fdde5317e337ea5307c3487038a5191db9b0f3 /tools/gfx/vulkan
parentd9443d670ef8413971fe7c3f02368b60a7fc5904 (diff)
add support for callable shaders in gfx (#3460)
Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tools/gfx/vulkan')
-rw-r--r--tools/gfx/vulkan/vk-command-encoder.cpp7
-rw-r--r--tools/gfx/vulkan/vk-shader-table.cpp18
2 files changed, 19 insertions, 6 deletions
diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp
index 568b49179..7f3110ea5 100644
--- a/tools/gfx/vulkan/vk-command-encoder.cpp
+++ b/tools/gfx/vulkan/vk-command-encoder.cpp
@@ -1480,11 +1480,10 @@ Result RayTracingCommandEncoder::dispatchRays(
hitSBT.stride = alignedHandleSize;
hitSBT.size = shaderTableImpl->m_hitTableSize;
- // TODO: Are callable shaders needed?
VkStridedDeviceAddressRegionKHR callableSBT;
- callableSBT.deviceAddress = 0;
- callableSBT.stride = 0;
- callableSBT.size = 0;
+ callableSBT.deviceAddress = hitSBT.deviceAddress + hitSBT.size;
+ callableSBT.stride = alignedHandleSize;
+ callableSBT.size = shaderTableImpl->m_callableTableSize;
vkApi.vkCmdTraceRaysKHR(
vkCommandBuffer,
diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp
index 0b6488465..beb826111 100644
--- a/tools/gfx/vulkan/vk-shader-table.cpp
+++ b/tools/gfx/vulkan/vk-shader-table.cpp
@@ -27,7 +27,8 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(
m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment);
- m_callableTableSize = 0; // TODO: Are callable shaders needed?
+ m_callableTableSize = (uint32_t)VulkanUtil::calcAligned(
+ m_callableShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize;
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
@@ -122,7 +123,20 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
}
subTablePtr += m_hitTableSize;
- // TODO: Callable shaders?
+ for (uint32_t i = 0; i < m_callableShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr =
+ pipelineImpl->shaderGroupNameToIndex.tryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_callableTableSize;
stagingBuffer->unmap(nullptr);
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);