summaryrefslogtreecommitdiffstats
path: root/tools/gfx
diff options
context:
space:
mode:
authorskallweitNV <64953474+skallweitNV@users.noreply.github.com>2024-05-28 02:05:12 +0200
committerGitHub <noreply@github.com>2024-05-27 17:05:12 -0700
commiteefdd4ab99fa99ed326b68cd2b0d4024347ed8fc (patch)
tree22fdde5317e337ea5307c3487038a5191db9b0f3 /tools/gfx
parentd9443d670ef8413971fe7c3f02368b60a7fc5904 (diff)
add support for callable shaders in gfx (#3460)
Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tools/gfx')
-rw-r--r--tools/gfx/d3d12/d3d12-command-encoder.cpp9
-rw-r--r--tools/gfx/d3d12/d3d12-shader-table.cpp12
-rw-r--r--tools/gfx/d3d12/d3d12-shader-table.h1
-rw-r--r--tools/gfx/renderer-shared.cpp17
-rw-r--r--tools/gfx/renderer-shared.h1
-rw-r--r--tools/gfx/vulkan/vk-command-encoder.cpp7
-rw-r--r--tools/gfx/vulkan/vk-shader-table.cpp18
7 files changed, 56 insertions, 9 deletions
diff --git a/tools/gfx/d3d12/d3d12-command-encoder.cpp b/tools/gfx/d3d12/d3d12-command-encoder.cpp
index ff50dcc5f..892c792fb 100644
--- a/tools/gfx/d3d12/d3d12-command-encoder.cpp
+++ b/tools/gfx/d3d12/d3d12-command-encoder.cpp
@@ -1409,6 +1409,15 @@ Result RayTracingCommandEncoderImpl::dispatchRays(
dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
}
+ if (shaderTableImpl->m_callableShaderCount > 0)
+ {
+ dispatchDesc.CallableShaderTable.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_callableTableOffset;
+ dispatchDesc.CallableShaderTable.SizeInBytes =
+ shaderTableImpl->m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.CallableShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ }
+
dispatchDesc.Width = (UINT)width;
dispatchDesc.Height = (UINT)height;
dispatchDesc.Depth = (UINT)depth;
diff --git a/tools/gfx/d3d12/d3d12-shader-table.cpp b/tools/gfx/d3d12/d3d12-shader-table.cpp
index 3e49350ab..f54b7c5e9 100644
--- a/tools/gfx/d3d12/d3d12-shader-table.cpp
+++ b/tools/gfx/d3d12/d3d12-shader-table.cpp
@@ -20,11 +20,14 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
uint32_t raygenTableSize = m_rayGenShaderCount * kRayGenRecordSize;
uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ uint32_t callableTableSize = m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
m_rayGenTableOffset = 0;
m_missTableOffset = raygenTableSize;
m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned(
m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
- uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize;
+ m_callableTableOffset = (uint32_t)D3DUtil::calcAligned(
+ m_hitGroupTableOffset + hitgroupTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
+ uint32_t tableSize = m_callableTableOffset + callableTableSize;
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
ComPtr<IBufferResource> bufferResource;
@@ -88,6 +91,13 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + i],
m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]);
}
+ for (uint32_t i = 0; i < m_callableShaderCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_callableTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i],
+ m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i]);
+ }
stagingBuffer->unmap(nullptr);
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);
diff --git a/tools/gfx/d3d12/d3d12-shader-table.h b/tools/gfx/d3d12/d3d12-shader-table.h
index 8cfd74874..6b39966b3 100644
--- a/tools/gfx/d3d12/d3d12-shader-table.h
+++ b/tools/gfx/d3d12/d3d12-shader-table.h
@@ -16,6 +16,7 @@ public:
uint32_t m_rayGenTableOffset;
uint32_t m_missTableOffset;
uint32_t m_hitGroupTableOffset;
+ uint32_t m_callableTableOffset;
DeviceImpl* m_device;
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 911d4712c..d4fda3183 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -1264,8 +1264,9 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
m_rayGenShaderCount = desc.rayGenShaderCount;
m_missShaderCount = desc.missShaderCount;
m_hitGroupCount = desc.hitGroupCount;
- m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
- m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
+ m_callableShaderCount = desc.callableShaderCount;
+ m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount + desc.callableShaderCount);
+ m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount + desc.callableShaderCount);
for (GfxIndex i = 0; i < desc.rayGenShaderCount; i++)
{
m_shaderGroupNames.add(desc.rayGenShaderEntryPointNames[i]);
@@ -1302,6 +1303,18 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
m_recordOverwrites.add(ShaderRecordOverwrite{});
}
}
+ for (GfxIndex i = 0; i < desc.callableShaderCount; i++)
+ {
+ m_shaderGroupNames.add(desc.callableShaderEntryPointNames[i]);
+ if (desc.callableShaderRecordOverwrites)
+ {
+ m_recordOverwrites.add(desc.callableShaderRecordOverwrites[i]);
+ }
+ else
+ {
+ m_recordOverwrites.add(ShaderRecordOverwrite{});
+ }
+ }
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 48d8d0e01..952beb2c1 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -1183,6 +1183,7 @@ public:
uint32_t m_rayGenShaderCount;
uint32_t m_missShaderCount;
uint32_t m_hitGroupCount;
+ uint32_t m_callableShaderCount;
Slang::Dictionary<PipelineStateBase*, Slang::RefPtr<BufferResource>> m_deviceBuffers;
diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp
index 568b49179..7f3110ea5 100644
--- a/tools/gfx/vulkan/vk-command-encoder.cpp
+++ b/tools/gfx/vulkan/vk-command-encoder.cpp
@@ -1480,11 +1480,10 @@ Result RayTracingCommandEncoder::dispatchRays(
hitSBT.stride = alignedHandleSize;
hitSBT.size = shaderTableImpl->m_hitTableSize;
- // TODO: Are callable shaders needed?
VkStridedDeviceAddressRegionKHR callableSBT;
- callableSBT.deviceAddress = 0;
- callableSBT.stride = 0;
- callableSBT.size = 0;
+ callableSBT.deviceAddress = hitSBT.deviceAddress + hitSBT.size;
+ callableSBT.stride = alignedHandleSize;
+ callableSBT.size = shaderTableImpl->m_callableTableSize;
vkApi.vkCmdTraceRaysKHR(
vkCommandBuffer,
diff --git a/tools/gfx/vulkan/vk-shader-table.cpp b/tools/gfx/vulkan/vk-shader-table.cpp
index 0b6488465..beb826111 100644
--- a/tools/gfx/vulkan/vk-shader-table.cpp
+++ b/tools/gfx/vulkan/vk-shader-table.cpp
@@ -27,7 +27,8 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(
m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment);
- m_callableTableSize = 0; // TODO: Are callable shaders needed?
+ m_callableTableSize = (uint32_t)VulkanUtil::calcAligned(
+ m_callableShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize;
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
@@ -122,7 +123,20 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
}
subTablePtr += m_hitTableSize;
- // TODO: Callable shaders?
+ for (uint32_t i = 0; i < m_callableShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr =
+ pipelineImpl->shaderGroupNameToIndex.tryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_callableTableSize;
stagingBuffer->unmap(nullptr);
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);