From 43e1b7cdc70b2fcac8a3e8ee72f5bc91726f4ec5 Mon Sep 17 00:00:00 2001 From: lucy96chen <47800040+lucy96chen@users.noreply.github.com> Date: Thu, 26 May 2022 10:54:35 -0700 Subject: Split render-vk.h/.cpp into a set of smaller files (#2244) * Some preliminary work on splitting render-vk * render-vk split, tests currently crash on null reference * fixed circular include --- tools/gfx/vulkan/vk-shader-table.cpp | 135 +++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tools/gfx/vulkan/vk-shader-table.cpp (limited to 'tools/gfx/vulkan/vk-shader-table.cpp') diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp new file mode 100644 index 000000000..89225f7c2 --- /dev/null +++ b/tools/gfx/vulkan/vk-shader-table.cpp @@ -0,0 +1,135 @@ +// vk-shader-table.cpp +#include "vk-shader-table.h" + +#include "vk-device.h" +#include "vk-transient-heap.h" + +#include "vk-helper-functions.h" + +namespace gfx +{ + +using namespace Slang; + +namespace vk +{ + +RefPtr ShaderTableImpl::createDeviceBuffer( + PipelineStateBase* pipeline, + TransientResourceHeapBase* transientHeap, + IResourceCommandEncoder* encoder) +{ + auto vkApi = m_device->m_api; + auto rtProps = vkApi.m_rtProperties; + uint32_t handleSize = rtProps.shaderGroupHandleSize; + m_raygenTableSize = (uint32_t)VulkanUtil::calcAligned( + m_rayGenShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); + m_missTableSize = (uint32_t)VulkanUtil::calcAligned( + m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); + m_hitTableSize = (uint32_t)VulkanUtil::calcAligned( + m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment); + m_callableTableSize = 0; // TODO: Are callable shaders needed? + uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize; + + auto pipelineImpl = static_cast(pipeline); + ComPtr bufferResource; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.memoryType = MemoryType::DeviceLocal; + bufferDesc.defaultState = ResourceState::General; + bufferDesc.allowedStates = + ResourceStateSet(ResourceState::General, ResourceState::CopyDestination); + bufferDesc.type = IResource::Type::Buffer; + bufferDesc.sizeInBytes = tableSize; + m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef()); + + TransientResourceHeapImpl* transientHeapImpl = + static_cast(transientHeap); + + IBufferResource* stagingBuffer = nullptr; + Offset stagingBufferOffset = 0; + transientHeapImpl->allocateStagingBuffer( + tableSize, stagingBuffer, stagingBufferOffset, MemoryType::Upload); + + assert(stagingBuffer); + void* stagingPtr = nullptr; + stagingBuffer->map(nullptr, &stagingPtr); + + List handles; + auto handleCount = pipelineImpl->shaderGroupCount; + auto totalHandleSize = handleSize * handleCount; + handles.setCount(totalHandleSize); + auto result = vkApi.vkGetRayTracingShaderGroupHandlesKHR( + m_device->m_device, + pipelineImpl->m_pipeline, + 0, + (uint32_t)handleCount, + totalHandleSize, + handles.getBuffer()); + + uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr + stagingBufferOffset; + auto subTablePtr = stagingBufferPtr; + Int shaderTableEntryCounter = 0; + + // Each loop calculates the copy source and destination locations by fetching the name + // of the shader group from the list of shader group names and getting its corresponding + // index in the buffer of handles. + for (uint32_t i = 0; i < m_rayGenShaderCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = + pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_raygenTableSize; + + for (uint32_t i = 0; i < m_missShaderCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = + pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_missTableSize; + + for (uint32_t i = 0; i < m_hitGroupCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = + pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_hitTableSize; + + // TODO: Callable shaders? + + stagingBuffer->unmap(nullptr); + encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize); + encoder->bufferBarrier( + 1, + bufferResource.readRef(), + gfx::ResourceState::CopyDestination, + gfx::ResourceState::ShaderResource); + RefPtr resultPtr = static_cast(bufferResource.get()); + return _Move(resultPtr); +} + +} // namespace vk +} // namespace gfx -- cgit v1.2.3