summaryrefslogtreecommitdiff
path: root/tools/gfx-unit-test/gfx-test-util.cpp
blob: 1b23047a120ae204e19040d13a735098129547db (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "gfx-test-util.h"
#include "tools/unit-test/slang-unit-test.h"

#include <slang-com-ptr.h>

using Slang::ComPtr;

namespace gfx_test
{
    /// Forwards compiler diagnostics (if any) to the unit-test reporter as an
    /// informational message.
    ///
    /// @param diagnosticsBlob  Diagnostics produced by a Slang API call. May be null,
    ///                         in which case nothing is reported.
    void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob)
    {
        if (diagnosticsBlob != nullptr)
        {
            // NOTE(review): the buffer is passed on as a C string, which assumes Slang
            // diagnostic blobs are null-terminated text — confirm.
            getTestReporter()->message(
                TestMessageType::Info,
                static_cast<const char*>(diagnosticsBlob->getBufferPointer()));
        }
    }

    /// Loads a Slang module, composes it with one of its entry points and creates a
    /// compute shader program on `device`.
    ///
    /// @param device            Device whose Slang session is used to load the module.
    /// @param outShaderProgram  Receives the created program. Assigned only on success.
    /// @param shaderModuleName  Name of the Slang module passed to ISession::loadModule.
    /// @param entryPointName    Name of the entry point looked up in that module.
    /// @param slangReflection   Receives the layout of the composed (module + entry point)
    ///                          program.
    /// @return SLANG_OK on success. Otherwise SLANG_FAIL when the module cannot be loaded
    ///         or the program cannot be created, or the failing result of any other
    ///         Slang call. Diagnostics from module loading and composition are forwarded
    ///         to the test reporter.
    Slang::Result loadComputeProgram(
        gfx::IDevice* device,
        Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
        const char* shaderModuleName,
        const char* entryPointName,
        slang::ProgramLayout*& slangReflection)
    {
        Slang::ComPtr<slang::ISession> slangSession;
        SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));

        Slang::ComPtr<slang::IBlob> diagnosticsBlob;
        slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
        diagnoseIfNeeded(diagnosticsBlob);
        if (!module)
            return SLANG_FAIL;

        ComPtr<slang::IEntryPoint> computeEntryPoint;
        SLANG_RETURN_ON_FAIL(
            module->findEntryPointByName(entryPointName, computeEntryPoint.writeRef()));

        // Link the module together with the selected entry point.
        Slang::List<slang::IComponentType*> componentTypes;
        componentTypes.add(module);
        componentTypes.add(computeEntryPoint);

        Slang::ComPtr<slang::IComponentType> composedProgram;
        SlangResult result = slangSession->createCompositeComponentType(
            componentTypes.getBuffer(),
            componentTypes.getCount(),
            composedProgram.writeRef(),
            diagnosticsBlob.writeRef());
        // Report diagnostics before checking the result, so that failures are explained.
        diagnoseIfNeeded(diagnosticsBlob);
        SLANG_RETURN_ON_FAIL(result);

        // NOTE(review): this layout belongs to composedProgram, a local that is released
        // on return. It presumably stays valid only while the created shader program
        // keeps a reference to the composed program — confirm.
        slangReflection = composedProgram->getLayout();

        gfx::IShaderProgram::Desc programDesc = {};
        programDesc.pipelineType = gfx::PipelineType::Compute;
        programDesc.slangProgram = composedProgram.get();

        auto shaderProgram = device->createProgram(programDesc);
        // Treat a null program as a failure, so callers never receive SLANG_OK together
        // with an empty outShaderProgram.
        if (!shaderProgram)
            return SLANG_FAIL;

        outShaderProgram = shaderProgram;
        return SLANG_OK;
    }

    /// Reads back the first `expectedBufferSize` bytes of `buffer` and compares them
    /// byte-for-byte with `expectedResult`. The outcome is recorded on the test reporter
    /// as a single Pass or Fail.
    ///
    /// @param device              Device used to read the buffer back.
    /// @param buffer              GPU buffer holding the compute output (read from offset 0).
    /// @param expectedResult      Expected bytes. Must hold at least `expectedBufferSize` bytes.
    /// @param expectedBufferSize  Number of bytes to read back and compare.
    void compareComputeResult(gfx::IDevice* device, gfx::IBufferResource* buffer, uint8_t* expectedResult, size_t expectedBufferSize)
    {
        // Read back the results. GFX_CHECK_CALL_ABORT comes from gfx-test-util.h and
        // presumably aborts the test when the read-back fails.
        ComPtr<ISlangBlob> resultBlob;
        GFX_CHECK_CALL_ABORT(device->readBufferResource(
            buffer, 0, expectedBufferSize, resultBlob.writeRef()));

        // A missing or short read-back cannot match the expected data.
        if (!resultBlob || resultBlob->getBufferSize() < expectedBufferSize)
        {
            getTestReporter()->addResult(TestResult::Fail);
            return;
        }

        // Compare the results. The index is size_t to match expectedBufferSize: an int
        // index is a signed/unsigned comparison and overflows past INT_MAX.
        auto result = static_cast<const uint8_t*>(resultBlob->getBufferPointer());
        for (size_t i = 0; i < expectedBufferSize; ++i)
        {
            if (expectedResult[i] != result[i])
            {
                getTestReporter()->addResult(TestResult::Fail);
                return;
            }
        }
        getTestReporter()->addResult(TestResult::Pass);
    }
}