summaryrefslogtreecommitdiffstats
path: root/tools/gfx/d3d12/d3d12-shader-table.cpp
blob: 2773578b8682a1fa9a5df330e344d5c94de11f03 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// d3d12-shader-table.cpp
#include "d3d12-shader-table.h"

#include "d3d12-device.h"
#include "d3d12-pipeline-state.h"
#include "d3d12-transient-heap.h"

namespace gfx
{
namespace d3d12
{

using namespace Slang;

/// Builds the GPU-side shader table for `pipeline`.
///
/// Layout (offsets are recorded in m_rayGenTableOffset / m_missTableOffset /
/// m_hitGroupTableOffset):
///   [ ray-gen records | miss records | padding | hit-group records ]
/// Ray-gen records are kRayGenRecordSize bytes apart; miss and hit-group records
/// are D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES apart. The hit-group table start is
/// explicitly aligned to D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT; the miss
/// table starts directly after the ray-gen table.
/// NOTE(review): the miss-table start is only aligned if kRayGenRecordSize is a
/// multiple of D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT — confirm in the header.
///
/// The records are written into a transient upload-heap staging allocation and
/// then copied into a new device-local buffer through `encoder`, after which the
/// device buffer is transitioned CopyDestination -> ShaderResource.
///
/// @param pipeline      Ray tracing pipeline (cast to RayTracingPipelineStateImpl).
/// @param transientHeap Transient heap (cast to TransientResourceHeapImpl) that
///                      provides the staging memory.
/// @param encoder       Encoder that receives the copy and barrier commands.
/// @return The device buffer holding the table, or an empty RefPtr if the buffer,
///         the state-object properties, or the staging memory could not be obtained.
RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
    PipelineStateBase* pipeline,
    TransientResourceHeapBase* transientHeap,
    IResourceCommandEncoder* encoder)
{
    uint32_t raygenTableSize = m_rayGenShaderCount * kRayGenRecordSize;
    uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
    uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
    m_rayGenTableOffset = 0;
    m_missTableOffset = raygenTableSize;
    m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned(
        m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
    uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize;

    auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
    ComPtr<IBufferResource> bufferResource;
    IBufferResource::Desc bufferDesc = {};
    bufferDesc.memoryType = gfx::MemoryType::DeviceLocal;
    bufferDesc.defaultState = ResourceState::General;
    bufferDesc.type = IResource::Type::Buffer;
    bufferDesc.sizeInBytes = tableSize;
    m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef());
    // On failure bufferResource stays null; the copy/barrier below would then
    // dereference it, so bail out instead.
    if (!bufferResource.get())
        return RefPtr<BufferResource>();

    ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
    pipelineImpl->m_stateObject->QueryInterface(stateObjectProperties.writeRef());
    if (!stateObjectProperties.get())
        return RefPtr<BufferResource>();

    TransientResourceHeapImpl* transientHeapImpl =
        static_cast<TransientResourceHeapImpl*>(transientHeap);

    IBufferResource* stagingBuffer = nullptr;
    Offset stagingBufferOffset = 0;
    transientHeapImpl->allocateStagingBuffer(
        tableSize, stagingBuffer, stagingBufferOffset, MemoryType::Upload);

    // Debug builds stop here; release builds fail gracefully instead of
    // dereferencing a null staging buffer.
    assert(stagingBuffer);
    if (!stagingBuffer)
        return RefPtr<BufferResource>();

    void* stagingPtr = nullptr;
    stagingBuffer->map(nullptr, &stagingPtr);
    if (!stagingPtr)
        return RefPtr<BufferResource>();

    // Writes one shader record at `dest`: the shader identifier exported under
    // `name` (or an all-zero identifier when `name` is empty or unknown), then
    // applies `overwrite` on top. `recordSize` is the stride of the table being
    // written and bounds how far an overwrite may reach.
    auto copyShaderIdInto =
        [&](void* dest, uint32_t recordSize, String& name, const ShaderRecordOverwrite& overwrite)
    {
        void* shaderId = nullptr;
        if (name.getLength())
        {
            // GetShaderIdentifier returns nullptr when `name` is not an export of
            // the state object; memcpy from that pointer would be undefined behavior.
            shaderId = stateObjectProperties->GetShaderIdentifier(name.toWString().begin());
            assert(shaderId && "shader group name not found in ray tracing state object");
        }
        if (shaderId)
        {
            memcpy(dest, shaderId, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES);
        }
        else
        {
            memset(dest, 0, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES);
        }
        if (overwrite.size)
        {
            // An overwrite reaching past this record would clobber the next record,
            // or run off the end of the staging allocation for the last one.
            assert(overwrite.offset + overwrite.size <= recordSize);
            memcpy((uint8_t*)dest + overwrite.offset, overwrite.data, overwrite.size);
        }
    };

    uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr + stagingBufferOffset;

    // m_shaderGroupNames / m_recordOverwrites hold ray-gen entries first, then
    // miss entries, then hit-group entries; `groupIndex` walks that flat list as
    // each table is written in turn.
    uint32_t groupIndex = 0;
    auto writeTable = [&](uint32_t tableOffset, uint32_t recordSize, uint32_t recordCount)
    {
        for (uint32_t i = 0; i < recordCount; i++, groupIndex++)
        {
            copyShaderIdInto(
                stagingBufferPtr + tableOffset + recordSize * i,
                recordSize,
                m_shaderGroupNames[groupIndex],
                m_recordOverwrites[groupIndex]);
        }
    };
    writeTable(m_rayGenTableOffset, kRayGenRecordSize, m_rayGenShaderCount);
    writeTable(m_missTableOffset, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES, m_missShaderCount);
    writeTable(m_hitGroupTableOffset, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES, m_hitGroupCount);

    stagingBuffer->unmap(nullptr);
    encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);
    encoder->bufferBarrier(
        1,
        bufferResource.readRef(),
        gfx::ResourceState::CopyDestination,
        gfx::ResourceState::ShaderResource);
    RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get());
    return _Move(resultPtr);
}

} // namespace d3d12
} // namespace gfx