summaryrefslogtreecommitdiffstats
path: root/tools/gfx-unit-test/buffer-barrier-test.cpp
blob: 1d965628afba802f6c30a8c045464bda07cc12ae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include "core/slang-basic.h"
#include "gfx-test-util.h"
#include "slang-rhi.h"
#include "unit-test/slang-unit-test.h"

#include <slang-rhi/shader-cursor.h>

using namespace rhi;

namespace gfx_test
{
// Everything needed to run one compute entry point in this test: the loaded program,
// its reflection, the pipeline description built from it, and the created pipeline.
struct Shader
{
    // Filled in by loadComputeProgram.
    ComPtr<IShaderProgram> program;
    // Reflection returned alongside `program`; a raw pointer that this struct never releases.
    slang::ProgramLayout* reflection{nullptr};
    // Its `program` field is pointed at `program` before the pipeline is created.
    ComputePipelineDesc pipelineDesc{};
    // Created from `pipelineDesc`.
    ComPtr<IComputePipeline> pipeline;
};

// Groups a buffer object with a descriptor and a view.
// NOTE(review): this struct is not referenced anywhere in this file, and `view` is typed as an
// ITextureView even though it sits next to an IBuffer. Confirm whether it is still needed (it may
// be a leftover) before relying on it. Also check that no other test TU defines a different
// `gfx_test::Buffer`, since that would be an ODR violation.
struct Buffer
{
    BufferDesc desc;         // presumably the descriptor `buffer` was created with — TODO confirm
    ComPtr<IBuffer> buffer;
    ComPtr<ITextureView> view;
};

// Creates a device-local buffer holding `elementCount` floats.
//
// The buffer is usable as a shader resource and as a copy source/destination. When
// `unorderedAccess` is true it additionally gets UnorderedAccess usage (shader-writable).
// If `initialData` is non-null it must point to at least `elementCount` floats, which are
// passed to the device as the buffer's initial contents; the data is only read, never written.
// A creation failure is reported through GFX_CHECK_CALL_ABORT.
ComPtr<IBuffer> createFloatBuffer(
    IDevice* device,
    bool unorderedAccess,
    size_t elementCount,
    const float* initialData = nullptr)
{
    BufferDesc desc = {};
    desc.size = elementCount * sizeof(float);
    // Each element is a single float; no typed format is used.
    desc.elementSize = sizeof(float);
    desc.format = Format::Undefined;
    desc.memoryType = MemoryType::DeviceLocal;
    desc.usage =
        BufferUsage::ShaderResource | BufferUsage::CopyDestination | BufferUsage::CopySource;
    if (unorderedAccess)
        desc.usage |= BufferUsage::UnorderedAccess;

    ComPtr<IBuffer> buffer;
    // createBuffer takes its initial data as `const void*`, so the const float pointer converts
    // implicitly; no cast (and no casting away of const) is needed.
    GFX_CHECK_CALL_ABORT(device->createBuffer(desc, initialData, buffer.writeRef()));
    return buffer;
}

// Records two dependent compute passes in one command encoder and checks the final buffer.
//
// Pass 1 runs `computeA` with the source buffer bound as `inBuffer` and an intermediate buffer
// bound as `outBuffer`. Pass 2 runs `computeB` with that intermediate buffer as `inBuffer` and
// the destination buffer as `outBuffer`. No explicit barrier is recorded between the passes;
// the required resource transition is left to the backend, which is what this test exercises.
void barrierTestImpl(IDevice* device, UnitTestContext* context)
{
    Shader firstStage;
    Shader secondStage;

    GFX_CHECK_CALL_ABORT(loadComputeProgram(
        device, firstStage.program, "buffer-barrier-test", "computeA", firstStage.reflection));
    GFX_CHECK_CALL_ABORT(loadComputeProgram(
        device, secondStage.program, "buffer-barrier-test", "computeB", secondStage.reflection));

    firstStage.pipelineDesc.program = firstStage.program.get();
    secondStage.pipelineDesc.program = secondStage.program.get();

    GFX_CHECK_CALL_ABORT(device->createComputePipeline(
        firstStage.pipelineDesc, firstStage.pipeline.writeRef()));
    GFX_CHECK_CALL_ABORT(device->createComputePipeline(
        secondStage.pipelineDesc, secondStage.pipeline.writeRef()));

    float seedValues[] = {1.0f, 2.0f, 3.0f, 4.0f};
    ComPtr<IBuffer> srcBuffer = createFloatBuffer(device, false, 4, seedValues);
    ComPtr<IBuffer> midBuffer = createFloatBuffer(device, true, 4, nullptr);
    ComPtr<IBuffer> dstBuffer = createFloatBuffer(device, true, 4, nullptr);

    // Queue and encoder are scoped so they are released before the result is read back.
    {
        auto queue = device->getQueue(QueueType::Graphics);
        auto encoder = queue->createCommandEncoder();

        // Records one compute pass: bind `pipeline`, bind `source` as inBuffer and
        // `destination` as outBuffer on entry point 0, and dispatch a single group.
        auto recordPass = [&encoder](auto& pipeline, auto& source, auto& destination)
        {
            auto pass = encoder->beginComputePass();
            auto root = pass->bindPipeline(pipeline);
            ShaderCursor entryPoint(root->getEntryPoint(0));
            entryPoint["inBuffer"].setBinding(source);
            entryPoint["outBuffer"].setBinding(destination);
            pass->dispatchCompute(1, 1, 1);
            pass->end();
        };

        recordPass(firstStage.pipeline, srcBuffer, midBuffer);
        // Pass 2 reads what pass 1 wrote; the transition between them is handled automatically.
        recordPass(secondStage.pipeline, midBuffer, dstBuffer);

        queue->submit(encoder->finish());
        queue->waitOnHost();
    }

    compareComputeResult(device, dstBuffer, makeArray<float>(11.0f, 12.0f, 13.0f, 14.0f));
}

// Creates a testing device of `deviceType` and runs the barrier test on it. If no such device
// can be created, the test is reported as ignored instead.
void barrierTestAPI(UnitTestContext* context, DeviceType deviceType)
{
    // Candidate directories for the test's shader source; the relative entries presumably cover
    // running the tests from different working directories.
    Slang::List<const char*> shaderDirs = {"", "../../tools/gfx-unit-test", "tools/gfx-unit-test"};

    if (auto testDevice = createTestingDevice(context, deviceType, shaderDirs))
    {
        barrierTestImpl(testDevice.get(), context);
    }
    else
    {
        SLANG_IGNORE_TEST
    }
}

// Runs the buffer barrier test on the Vulkan backend. barrierTestAPI reports it as ignored when
// no Vulkan device can be created. `unitTestContext` is presumably provided by the
// SLANG_UNIT_TEST macro.
SLANG_UNIT_TEST(bufferBarrierVulkan)
{
    barrierTestAPI(unitTestContext, DeviceType::Vulkan);
}

} // namespace gfx_test