summaryrefslogtreecommitdiffstats
path: root/tools/gfx/d3d12
diff options
context:
space:
mode:
authorskallweitNV <64953474+skallweitNV@users.noreply.github.com>2024-05-28 02:05:12 +0200
committerGitHub <noreply@github.com>2024-05-27 17:05:12 -0700
commiteefdd4ab99fa99ed326b68cd2b0d4024347ed8fc (patch)
tree22fdde5317e337ea5307c3487038a5191db9b0f3 /tools/gfx/d3d12
parentd9443d670ef8413971fe7c3f02368b60a7fc5904 (diff)
add support for callable shaders in gfx (#3460)
Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tools/gfx/d3d12')
-rw-r--r--tools/gfx/d3d12/d3d12-command-encoder.cpp9
-rw-r--r--tools/gfx/d3d12/d3d12-shader-table.cpp12
-rw-r--r--tools/gfx/d3d12/d3d12-shader-table.h1
3 files changed, 21 insertions, 1 deletions
diff --git a/tools/gfx/d3d12/d3d12-command-encoder.cpp b/tools/gfx/d3d12/d3d12-command-encoder.cpp
index ff50dcc5f..892c792fb 100644
--- a/tools/gfx/d3d12/d3d12-command-encoder.cpp
+++ b/tools/gfx/d3d12/d3d12-command-encoder.cpp
@@ -1409,6 +1409,15 @@ Result RayTracingCommandEncoderImpl::dispatchRays(
dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
}
+ if (shaderTableImpl->m_callableShaderCount > 0)
+ {
+ dispatchDesc.CallableShaderTable.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_callableTableOffset;
+ dispatchDesc.CallableShaderTable.SizeInBytes =
+ shaderTableImpl->m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.CallableShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ }
+
dispatchDesc.Width = (UINT)width;
dispatchDesc.Height = (UINT)height;
dispatchDesc.Depth = (UINT)depth;
diff --git a/tools/gfx/d3d12/d3d12-shader-table.cpp b/tools/gfx/d3d12/d3d12-shader-table.cpp
index 3e49350ab..f54b7c5e9 100644
--- a/tools/gfx/d3d12/d3d12-shader-table.cpp
+++ b/tools/gfx/d3d12/d3d12-shader-table.cpp
@@ -20,11 +20,14 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
uint32_t raygenTableSize = m_rayGenShaderCount * kRayGenRecordSize;
uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ uint32_t callableTableSize = m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
m_rayGenTableOffset = 0;
m_missTableOffset = raygenTableSize;
m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned(
m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
- uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize;
+ m_callableTableOffset = (uint32_t)D3DUtil::calcAligned(
+ m_hitGroupTableOffset + hitgroupTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
+ uint32_t tableSize = m_callableTableOffset + callableTableSize;
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
ComPtr<IBufferResource> bufferResource;
@@ -88,6 +91,13 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + i],
m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]);
}
+ for (uint32_t i = 0; i < m_callableShaderCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_callableTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i],
+ m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i]);
+ }
stagingBuffer->unmap(nullptr);
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);
diff --git a/tools/gfx/d3d12/d3d12-shader-table.h b/tools/gfx/d3d12/d3d12-shader-table.h
index 8cfd74874..6b39966b3 100644
--- a/tools/gfx/d3d12/d3d12-shader-table.h
+++ b/tools/gfx/d3d12/d3d12-shader-table.h
@@ -16,6 +16,7 @@ public:
uint32_t m_rayGenTableOffset;
uint32_t m_missTableOffset;
uint32_t m_hitGroupTableOffset;
+ uint32_t m_callableTableOffset;
DeviceImpl* m_device;