From eefdd4ab99fa99ed326b68cd2b0d4024347ed8fc Mon Sep 17 00:00:00 2001 From: skallweitNV <64953474+skallweitNV@users.noreply.github.com> Date: Tue, 28 May 2024 02:05:12 +0200 Subject: add support for callable shaders in gfx (#3460) Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Co-authored-by: Yong He --- tools/gfx/d3d12/d3d12-command-encoder.cpp | 9 +++++++++ tools/gfx/d3d12/d3d12-shader-table.cpp | 12 +++++++++++- tools/gfx/d3d12/d3d12-shader-table.h | 1 + tools/gfx/renderer-shared.cpp | 17 +++++++++++++++-- tools/gfx/renderer-shared.h | 1 + tools/gfx/vulkan/vk-command-encoder.cpp | 7 +++---- tools/gfx/vulkan/vk-shader-table.cpp | 18 ++++++++++++++++-- 7 files changed, 56 insertions(+), 9 deletions(-) (limited to 'tools/gfx') diff --git a/tools/gfx/d3d12/d3d12-command-encoder.cpp b/tools/gfx/d3d12/d3d12-command-encoder.cpp index ff50dcc5f..892c792fb 100644 --- a/tools/gfx/d3d12/d3d12-command-encoder.cpp +++ b/tools/gfx/d3d12/d3d12-command-encoder.cpp @@ -1409,6 +1409,15 @@ Result RayTracingCommandEncoderImpl::dispatchRays( dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; } + if (shaderTableImpl->m_callableShaderCount > 0) + { + dispatchDesc.CallableShaderTable.StartAddress = + shaderTableAddr + shaderTableImpl->m_callableTableOffset; + dispatchDesc.CallableShaderTable.SizeInBytes = + shaderTableImpl->m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.CallableShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + } + dispatchDesc.Width = (UINT)width; dispatchDesc.Height = (UINT)height; dispatchDesc.Depth = (UINT)depth; diff --git a/tools/gfx/d3d12/d3d12-shader-table.cpp b/tools/gfx/d3d12/d3d12-shader-table.cpp index 3e49350ab..f54b7c5e9 100644 --- a/tools/gfx/d3d12/d3d12-shader-table.cpp +++ b/tools/gfx/d3d12/d3d12-shader-table.cpp @@ -20,11 +20,14 @@ RefPtr ShaderTableImpl::createDeviceBuffer( uint32_t raygenTableSize = m_rayGenShaderCount * kRayGenRecordSize; uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + uint32_t callableTableSize = m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; m_rayGenTableOffset = 0; m_missTableOffset = raygenTableSize; m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned( m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT); - uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize; + m_callableTableOffset = (uint32_t)D3DUtil::calcAligned( + m_hitGroupTableOffset + hitgroupTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT); + uint32_t tableSize = m_callableTableOffset + callableTableSize; auto pipelineImpl = static_cast(pipeline); ComPtr bufferResource; @@ -88,6 +91,13 @@ RefPtr ShaderTableImpl::createDeviceBuffer( m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + i], m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]); } + for (uint32_t i = 0; i < m_callableShaderCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_callableTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i], + m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i]); + } stagingBuffer->unmap(nullptr); encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize); diff --git a/tools/gfx/d3d12/d3d12-shader-table.h b/tools/gfx/d3d12/d3d12-shader-table.h index 8cfd74874..6b39966b3 100644 --- a/tools/gfx/d3d12/d3d12-shader-table.h +++ b/tools/gfx/d3d12/d3d12-shader-table.h @@ -16,6 +16,7 @@ public: uint32_t m_rayGenTableOffset; uint32_t m_missTableOffset; uint32_t m_hitGroupTableOffset; + uint32_t m_callableTableOffset; DeviceImpl* m_device; diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 911d4712c..d4fda3183 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -1264,8 +1264,9 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc) m_rayGenShaderCount = desc.rayGenShaderCount; m_missShaderCount = desc.missShaderCount; m_hitGroupCount = desc.hitGroupCount; - m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount); - m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount); + m_callableShaderCount = desc.callableShaderCount; + m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount + desc.callableShaderCount); + m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount + desc.callableShaderCount); for (GfxIndex i = 0; i < desc.rayGenShaderCount; i++) { m_shaderGroupNames.add(desc.rayGenShaderEntryPointNames[i]); @@ -1302,6 +1303,18 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc) m_recordOverwrites.add(ShaderRecordOverwrite{}); } } + for (GfxIndex i = 0; i < desc.callableShaderCount; i++) + { + m_shaderGroupNames.add(desc.callableShaderEntryPointNames[i]); + if (desc.callableShaderRecordOverwrites) + { + m_recordOverwrites.add(desc.callableShaderRecordOverwrites[i]); + } + else + { + m_recordOverwrites.add(ShaderRecordOverwrite{}); + } + } return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 48d8d0e01..952beb2c1 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -1183,6 +1183,7 @@ public: uint32_t m_rayGenShaderCount; uint32_t m_missShaderCount; uint32_t m_hitGroupCount; + uint32_t m_callableShaderCount; Slang::Dictionary> m_deviceBuffers; diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp index 568b49179..7f3110ea5 100644 --- a/tools/gfx/vulkan/vk-command-encoder.cpp +++ b/tools/gfx/vulkan/vk-command-encoder.cpp @@ -1480,11 +1480,10 @@ Result RayTracingCommandEncoder::dispatchRays( hitSBT.stride = alignedHandleSize; hitSBT.size = shaderTableImpl->m_hitTableSize; - // TODO: Are callable shaders needed? VkStridedDeviceAddressRegionKHR callableSBT; - callableSBT.deviceAddress = 0; - callableSBT.stride = 0; - callableSBT.size = 0; + callableSBT.deviceAddress = hitSBT.deviceAddress + hitSBT.size; + callableSBT.stride = alignedHandleSize; + callableSBT.size = shaderTableImpl->m_callableTableSize; vkApi.vkCmdTraceRaysKHR( vkCommandBuffer, diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp index 0b6488465..beb826111 100644 --- a/tools/gfx/vulkan/vk-shader-table.cpp +++ b/tools/gfx/vulkan/vk-shader-table.cpp @@ -27,7 +27,8 @@ RefPtr ShaderTableImpl::createDeviceBuffer( m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); m_hitTableSize = (uint32_t)VulkanUtil::calcAligned( m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment); - m_callableTableSize = 0; // TODO: Are callable shaders needed? + m_callableTableSize = (uint32_t)VulkanUtil::calcAligned( + m_callableShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize; auto pipelineImpl = static_cast(pipeline); @@ -122,7 +123,20 @@ RefPtr ShaderTableImpl::createDeviceBuffer( } subTablePtr += m_hitTableSize; - // TODO: Callable shaders? + for (uint32_t i = 0; i < m_callableShaderCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = + pipelineImpl->shaderGroupNameToIndex.tryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_callableTableSize; stagingBuffer->unmap(nullptr); encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize); -- cgit v1.2.3