summaryrefslogtreecommitdiffstats
path: root/tools/gfx/vulkan/vk-shader-table.cpp
diff options
context:
space:
mode:
authorlucy96chen <47800040+lucy96chen@users.noreply.github.com>2022-05-26 10:54:35 -0700
committerGitHub <noreply@github.com>2022-05-26 10:54:35 -0700
commit43e1b7cdc70b2fcac8a3e8ee72f5bc91726f4ec5 (patch)
tree1e4701b4ab324a199b81e1f6c671f6660f1050c5 /tools/gfx/vulkan/vk-shader-table.cpp
parent5ff4f42c636a67724523e4fe60697cfac64908cd (diff)
Split render-vk.h/.cpp into a set of smaller files (#2244)
* Some preliminary work on splitting render-vk * render-vk split, tests currently crash on null reference * fixed circular include
Diffstat (limited to 'tools/gfx/vulkan/vk-shader-table.cpp')
-rw-r--r--tools/gfx/vulkan/vk-shader-table.cpp135
1 files changed, 135 insertions, 0 deletions
diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp
new file mode 100644
index 000000000..89225f7c2
--- /dev/null
+++ b/tools/gfx/vulkan/vk-shader-table.cpp
@@ -0,0 +1,135 @@
+// vk-shader-table.cpp
+#include "vk-shader-table.h"
+
+#include "vk-device.h"
+#include "vk-transient-heap.h"
+
+#include "vk-helper-functions.h"
+
+namespace gfx
+{
+
+using namespace Slang;
+
+namespace vk
+{
+
+RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
+ PipelineStateBase* pipeline,
+ TransientResourceHeapBase* transientHeap,
+ IResourceCommandEncoder* encoder)
+{
+ auto vkApi = m_device->m_api;
+ auto rtProps = vkApi.m_rtProperties;
+ uint32_t handleSize = rtProps.shaderGroupHandleSize;
+ m_raygenTableSize = (uint32_t)VulkanUtil::calcAligned(
+ m_rayGenShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
+ m_missTableSize = (uint32_t)VulkanUtil::calcAligned(
+ m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
+ m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(
+ m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment);
+ m_callableTableSize = 0; // TODO: Are callable shaders needed?
+ uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize;
+
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ ComPtr<IBufferResource> bufferResource;
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+ bufferDesc.defaultState = ResourceState::General;
+ bufferDesc.allowedStates =
+ ResourceStateSet(ResourceState::General, ResourceState::CopyDestination);
+ bufferDesc.type = IResource::Type::Buffer;
+ bufferDesc.sizeInBytes = tableSize;
+ m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef());
+
+ TransientResourceHeapImpl* transientHeapImpl =
+ static_cast<TransientResourceHeapImpl*>(transientHeap);
+
+ IBufferResource* stagingBuffer = nullptr;
+ Offset stagingBufferOffset = 0;
+ transientHeapImpl->allocateStagingBuffer(
+ tableSize, stagingBuffer, stagingBufferOffset, MemoryType::Upload);
+
+ assert(stagingBuffer);
+ void* stagingPtr = nullptr;
+ stagingBuffer->map(nullptr, &stagingPtr);
+
+ List<uint8_t> handles;
+ auto handleCount = pipelineImpl->shaderGroupCount;
+ auto totalHandleSize = handleSize * handleCount;
+ handles.setCount(totalHandleSize);
+ auto result = vkApi.vkGetRayTracingShaderGroupHandlesKHR(
+ m_device->m_device,
+ pipelineImpl->m_pipeline,
+ 0,
+ (uint32_t)handleCount,
+ totalHandleSize,
+ handles.getBuffer());
+
+ uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr + stagingBufferOffset;
+ auto subTablePtr = stagingBufferPtr;
+ Int shaderTableEntryCounter = 0;
+
+ // Each loop calculates the copy source and destination locations by fetching the name
+ // of the shader group from the list of shader group names and getting its corresponding
+ // index in the buffer of handles.
+ for (uint32_t i = 0; i < m_rayGenShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr =
+ pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_raygenTableSize;
+
+ for (uint32_t i = 0; i < m_missShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr =
+ pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_missTableSize;
+
+ for (uint32_t i = 0; i < m_hitGroupCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr =
+ pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_hitTableSize;
+
+ // TODO: Callable shaders?
+
+ stagingBuffer->unmap(nullptr);
+ encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);
+ encoder->bufferBarrier(
+ 1,
+ bufferResource.readRef(),
+ gfx::ResourceState::CopyDestination,
+ gfx::ResourceState::ShaderResource);
+ RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get());
+ return _Move(resultPtr);
+}
+
+} // namespace vk
+} // namespace gfx