summaryrefslogtreecommitdiffstats
path: root/source/slang-llvm/slang-llvm-filecheck.cpp
blob: 492b0e0d419551f8f838f60069e389ac5535b815 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
// This file contains a definition of LLVMFileCheck, an implementation of the
// IFileCheck interface backed by LLVM's FileCheck library.

#include "slang-com-helper.h"
#include "slang-com-ptr.h"
#include "slang.h"

#include <core/slang-com-object.h>
#include <llvm/ADT/SmallString.h>
#include <llvm/FileCheck/FileCheck.h>
#include <llvm/Support/raw_ostream.h>
#include <slang-test/filecheck.h>

namespace slang_llvm
{

using namespace llvm;
using namespace Slang;

/// COM-style object implementing IFileCheck on top of LLVM's FileCheck.
///
/// The base classes are inherited publicly on purpose. With the `class`
/// keyword the default would be private inheritance, which hides the
/// IFileCheck / ComBaseObject bases from all code outside this class (e.g. an
/// implicit LLVMFileCheck* -> IFileCheck* conversion would not compile).
class LLVMFileCheck final : public IFileCheck, public ComBaseObject
{
public:
    // ICastable
    virtual SLANG_NO_THROW void* SLANG_MCALL castAs(const Guid& guid) override;

    // IUnknown
    SLANG_COM_BASE_IUNKNOWN_ALL
    /// Returns this object as an IFileCheck* for the supported interface GUIDs,
    /// otherwise nullptr.
    void* getInterface(const Guid& guid);
    /// Concrete-object lookup consulted by castAs when getInterface fails.
    void* getObject(const Guid& guid);

    // IFileCheck
    virtual TestResult SLANG_MCALL performTest(
        const char* programName,
        const char* rulesFilePath,
        const char* fileCheckPrefix,
        const char* stringToCheck,
        const char* stringToCheckName,
        ReportDiagnostic testReporter,
        void* reporterData,
        bool colorDiagnosticOutput) noexcept override;

private:
    // Everything we need to pass through LLVM back to our diagnostic handler.
    // NOTE: member order matters, performTest builds this with aggregate
    // initialization.
    struct ReporterData
    {
        // Callback supplied by the caller of performTest.
        ReportDiagnostic reportFun;
        // User data from the caller of performTest, handed back to reportFun.
        void* data;
        // Passed to enable_colors() on the stream diagnostics are rendered into.
        bool colorDiagnosticOutput;
        // Passed to SMDiagnostic::print as the program name.
        const char* programName;
        // Message type for diagnostics; performTest starts it as RunError and
        // switches it to TestFailure once the check rules have been read.
        TestMessageType testMessageType;
    };

    // SourceMgr diagnostic callback; `reporterData` points at a ReporterData.
    static void fileCheckDiagHandler(const SMDiagnostic& diag, void* reporterData);
};

class DisplayedStringOStream : public raw_string_ostream
{
public:
    DisplayedStringOStream(std::string& s)
        : raw_string_ostream(s)
    {
    }
    virtual bool is_displayed() const override { return true; };
};

/// SourceMgr diagnostic callback: renders an LLVM diagnostic to text and
/// forwards it to the reporter that was passed to performTest.
///
/// @param diag     Diagnostic emitted through the SourceMgr.
/// @param dataPtr  Context registered with SourceMgr::setDiagHandler; must point
///                 at a live ReporterData.
void LLVMFileCheck::fileCheckDiagHandler(const SMDiagnostic& diag, void* dataPtr)
{
    const ReporterData& reporterData = *static_cast<const ReporterData*>(dataPtr);

    std::string s;
    DisplayedStringOStream o(s);
    o.enable_colors(reporterData.colorDiagnosticOutput);
    diag.print(reporterData.programName, o);
    // Ensure everything written to the stream has reached `s` before reading it.
    o.flush();

    // Report with the session's current message type rather than a fixed one:
    // performTest keeps it at RunError while the check rules are parsed and
    // switches it to TestFailure before checking the input.
    reporterData.reportFun(reporterData.data, reporterData.testMessageType, s.c_str());
}

TestResult LLVMFileCheck::performTest(
    const char* const programName,
    const char* const rulesFilePath,
    const char* const fileCheckPrefix,
    const char* const stringToCheck,
    const char* const stringToCheckName,
    const ReportDiagnostic testReporter,
    void* const userReporterData,
    const bool colorDiagnosticOutput) noexcept
{
    // A FileCheck session that recognizes directives with exactly one prefix.
    FileCheckRequest request;
    request.CheckPrefixes = {fileCheckPrefix};
    FileCheck checker(request);

    // Both the rules and the input text are registered here so that LLVM
    // diagnostics can refer to locations inside them.
    SourceMgr sources;

    // Canonicalizes `original` into `storage`, registers the canonical text with
    // `sources` under `bufferName`, and returns a view of that text. The view
    // points into `storage`, which therefore has to outlive it.
    const auto canonicalizeAndRegister =
        [&checker, &sources](MemoryBuffer& original, SmallVectorImpl<char>& storage, StringRef bufferName)
    {
        const StringRef canonical = checker.CanonicalizeFile(original, storage);
        sources.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(canonical, bufferName), SMLoc());
        return canonical;
    };

    // Load the rules file. A read failure is a run error and is reported
    // directly, since no SourceMgr diagnostic handler is installed yet.
    auto loadedRules = MemoryBuffer::getFile(rulesFilePath, true);
    if (const std::error_code loadError = loadedRules.getError())
    {
        const std::string message = "Unable to load FileCheck rules file: " + loadError.message();
        testReporter(userReporterData, TestMessageType::RunError, message.c_str());
        return TestResult::Fail;
    }
    SmallString<4096> rulesStorage;
    const StringRef rulesText = canonicalizeAndRegister(**loadedRules, rulesStorage, rulesFilePath);

    // Wrap the caller's text (the trailing `false` disables MemoryBuffer's
    // null-terminator requirement) and register its canonical form.
    const auto rawInput =
        MemoryBuffer::getMemBuffer(StringRef(stringToCheck), stringToCheckName, false);
    SmallString<4096> inputStorage;
    const StringRef inputText = canonicalizeAndRegister(*rawInput, inputStorage, stringToCheckName);

    // Context for fileCheckDiagHandler. The message type starts as RunError and
    // is lowered to TestFailure once the rules have been accepted.
    ReporterData reporterData{
        testReporter,
        userReporterData,
        colorDiagnosticOutput,
        programName,
        TestMessageType::RunError};
    sources.setDiagHandler(fileCheckDiagHandler, &reporterData);

    auto prefixRegex = checker.buildCheckPrefixRegex();
    // readCheckFile returns true when the rules cannot be used; the reason has
    // already been routed to the diagnostic handler.
    if (checker.readCheckFile(sources, rulesText, prefixRegex))
    {
        return TestResult::Fail;
    }

    // Setup succeeded: from now on, diagnostics describe ordinary mismatches.
    reporterData.testMessageType = TestMessageType::TestFailure;
    const bool inputMatches = checker.checkInput(sources, inputText);
    return inputMatches ? TestResult::Pass : TestResult::Fail;
}

void* LLVMFileCheck::castAs(const Guid& guid)
{
    // Interface lookup wins; the concrete-object lookup is only consulted when
    // no interface matches.
    void* const intf = getInterface(guid);
    return intf ? intf : getObject(guid);
}

void* LLVMFileCheck::getInterface(const Guid& guid)
{
    if (guid == ISlangUnknown::getTypeGuid() || guid == ICastable::getTypeGuid() ||
        guid == IFileCheck::getTypeGuid())
    {
        return static_cast<IFileCheck*>(this);
    }
    return nullptr;
}

void* LLVMFileCheck::getObject(const Guid& guid)
{
    // No concrete (non-interface) object types are exposed through castAs, so
    // every request is declined.
    static_cast<void>(guid);
    return nullptr;
}

} // namespace slang_llvm

/// DLL entry point: creates an LLVMFileCheck and returns it through `out` as the
/// interface identified by `intfGuid`.
///
/// @param intfGuid  Interface requested by the caller.
/// @param out       Receives the interface pointer on success; set to nullptr
///                  when the interface is not supported.
/// @return SLANG_OK on success, SLANG_E_INVALID_ARG if `out` is null, or
///         SLANG_E_NO_INTERFACE if the object does not support `intfGuid`.
extern "C" SLANG_DLL_EXPORT SlangResult
createLLVMFileCheck_V1(const SlangUUID& intfGuid, void** out)
{
    if (!out)
    {
        return SLANG_E_INVALID_ARG;
    }
    // COM convention: never leave the out-parameter unset, even on failure.
    *out = nullptr;

    // The local ComPtr owns the new object. On success, detach() hands that
    // ownership to the caller; otherwise the ComPtr destructor presumably
    // releases (and destroys) it.
    Slang::ComPtr<slang_llvm::LLVMFileCheck> fileCheck(new slang_llvm::LLVMFileCheck);

    if (auto ptr = fileCheck->castAs(intfGuid))
    {
        fileCheck.detach();
        *out = ptr;
        return SLANG_OK;
    }

    return SLANG_E_NO_INTERFACE;
}