summaryrefslogtreecommitdiffstats
path: root/tools/gfx-unit-test/link-time-type-layout-cache.cpp
blob: b04733fb699c97bc4e5555d2b35160d19369657b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
#include "core/slang-blob.h"
#include "gfx-test-util.h"
#include "slang-rhi.h"
#include "unit-test/slang-unit-test.h"

using namespace rhi;

namespace gfx_test
{

// Writes any pending diagnostic text held by `diagnosticsBlob` to stderr,
// followed by a newline. A null or zero-sized blob produces no output.
static void diagnoseIfNeeded(Slang::ComPtr<slang::IBlob>& diagnosticsBlob)
{
    if (!diagnosticsBlob)
        return;

    const auto byteCount = diagnosticsBlob->getBufferSize();
    if (byteCount == 0)
        return;

    // The blob is only accessed through a pointer and a size here, so bound the
    // print by that size instead of assuming the text is NUL-terminated.
    fprintf(
        stderr,
        "%.*s\n",
        static_cast<int>(byteCount),
        static_cast<const char*>(diagnosticsBlob->getBufferPointer()));
}

// Locates the "vertexMain" entry point in `slangReflection` and checks that its
// first parameter is a struct with `expectedFieldCount` fields. When at least
// one field is expected, one of them must be named "foo" and be a 4-element
// vector.
//
// Every expectation is reported through SLANG_CHECK / SLANG_CHECK_MSG. Each
// pointer is additionally tested before it is dereferenced, so a failed
// expectation ends validation early instead of dereferencing null.
// NOTE(review): this assumes the check macros record a failure without
// leaving the function; if they do abort, the early returns are merely
// unreachable.
//
// context            - unit test context supplied by the caller (not named
//                      directly in this function body).
// slangReflection    - program layout to inspect; may be null, which is
//                      reported as a failed check.
// expectedFieldCount - number of fields struct S is expected to expose.
static void validateStructSLayout(
    UnitTestContext* context,
    slang::ProgramLayout* slangReflection,
    int expectedFieldCount)
{
    // Reflection must be available before anything else can be inspected.
    SLANG_CHECK(slangReflection != nullptr);
    if (!slangReflection)
        return;

    // Search the entry points for vertexMain.
    slang::EntryPointLayout* entryPointLayout = nullptr;
    const auto entryPointCount = slangReflection->getEntryPointCount();

    for (unsigned int i = 0; i < entryPointCount; i++)
    {
        auto currentEntryPoint = slangReflection->getEntryPointByIndex(i);
        if (!currentEntryPoint)
            continue;

        // Guard the name the same way the field loop below does; strcmp on a
        // null pointer is undefined behavior.
        const char* name = currentEntryPoint->getName();
        if (name && strcmp(name, "vertexMain") == 0)
        {
            entryPointLayout = currentEntryPoint;
            break;
        }
    }

    SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point");
    if (!entryPointLayout)
        return;

    // The entry point needs at least one parameter for index 0 to be valid.
    auto paramCount = entryPointLayout->getParameterCount();
    SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters");
    if (paramCount < 1)
        return;

    // The first parameter is the one expected to be of type S.
    auto paramLayout = entryPointLayout->getParameterByIndex(0);
    SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout");
    if (!paramLayout)
        return;

    auto typeLayout = paramLayout->getTypeLayout();
    SLANG_CHECK_MSG(typeLayout != nullptr, "Parameter has no type layout");
    if (!typeLayout)
        return;

    // The parameter must be reflected as a struct.
    auto kind = typeLayout->getKind();
    SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type");

    // The struct must expose exactly the expected number of fields.
    auto fieldCount = typeLayout->getFieldCount();
    SLANG_CHECK_MSG(fieldCount == expectedFieldCount, "Struct has unexpected number of fields");

    // Without expected fields there is nothing further to inspect.
    if (expectedFieldCount <= 0)
        return;

    bool foundFooField = false;
    for (unsigned int i = 0; i < fieldCount; i++)
    {
        auto fieldLayout = typeLayout->getFieldByIndex(i);
        if (!fieldLayout)
            continue;

        const char* fieldName = fieldLayout->getName();
        if (!fieldName || strcmp(fieldName, "foo") != 0)
            continue;

        foundFooField = true;

        // 'foo' is expected to be a float4, i.e. a vector of four elements.
        auto fieldTypeLayout = fieldLayout->getTypeLayout();
        SLANG_CHECK_MSG(fieldTypeLayout != nullptr, "Field 'foo' has no type layout");
        if (!fieldTypeLayout)
            break;

        auto fieldTypeKind = fieldTypeLayout->getKind();
        SLANG_CHECK_MSG(
            fieldTypeKind == slang::TypeReflection::Kind::Vector,
            "Field 'foo' is not a vector type");

        auto elementCount = fieldTypeLayout->getElementCount();
        SLANG_CHECK_MSG(elementCount == 4, "Field 'foo' is not a 4-element vector");

        break;
    }

    SLANG_CHECK_MSG(foundFooField, "Could not find field 'foo' in struct S");
}

// Checks how the reflected layout of `S` differs between two programs built in
// the same Slang session:
//   1. a program composed from `main` alone, where `S` is only declared
//      `extern` and is expected to reflect with zero fields;
//   2. a program that also includes `foo`, which defines `S`, where `S` is
//      expected to reflect with one field named "foo".
//
// Each step is reported through SLANG_CHECK / SLANG_CHECK_MSG. Objects obtained
// from the API are also tested before use, so a failed step ends the test
// instead of dereferencing null.
// NOTE(review): this assumes the check macros record a failure without
// leaving the function; if they do abort, the early returns are merely
// unreachable.
//
// device  - device whose Slang session is used for compilation.
// context - unit test context, forwarded to validateStructSLayout.
void linkTimeTypeLayoutCacheImpl(rhi::IDevice* device, UnitTestContext* context)
{
    // main.slang: declares the interface and extern struct S
    const char* mainSrc = R"(
        public interface IFoo
        {
            public float4 getFoo();
        };
        public extern struct S : IFoo;

        [shader("vertex")]
        float4 vertexMain(S params) : SV_Position
        {
            return params.getFoo();
        }
    )";

    // foo.slang: defines S with its field layout and its implementation of getFoo()
    const char* fooSrc = R"(
        import main;

        export public struct S : IFoo
        {
            public float4 getFoo() { return this.foo; }
            float4 foo;
        }
    )";

    Slang::ComPtr<slang::ISession> slangSession;
    SLANG_CHECK(SLANG_SUCCEEDED(device->getSlangSession(slangSession.writeRef())));
    if (!slangSession)
        return;

    Slang::ComPtr<slang::IBlob> diagnosticsBlob;

    // Wrap both sources in blobs; the blobs do not own the string literals.
    auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc));
    auto fooBlob = Slang::UnownedRawBlob::create(fooSrc, strlen(fooSrc));

    // STEP 1: Load just the main module. Diagnostics are captured so that a
    // load failure explains itself on stderr.
    slang::IModule* mainModule = slangSession->loadModuleFromSource(
        "main",
        "main.slang",
        mainBlob,
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_CHECK_MSG(mainModule != nullptr, "Failed to load main module");
    if (!mainModule)
        return;

    // Find the entry point from main.slang
    Slang::ComPtr<slang::IEntryPoint> vsEntryPoint;
    SLANG_CHECK(
        SLANG_SUCCEEDED(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef())));
    if (!vsEntryPoint)
        return;

    // Compose a program from the main module and its entry point only
    Slang::List<slang::IComponentType*> componentTypes;
    componentTypes.add(mainModule);
    componentTypes.add(vsEntryPoint);

    Slang::ComPtr<slang::IComponentType> composedProgram;
    SLANG_CHECK(SLANG_SUCCEEDED(slangSession->createCompositeComponentType(
        componentTypes.getBuffer(),
        componentTypes.getCount(),
        composedProgram.writeRef(),
        diagnosticsBlob.writeRef())));
    diagnoseIfNeeded(diagnosticsBlob);
    if (!composedProgram)
        return;

    // Link the main-only program
    Slang::ComPtr<slang::IComponentType> linkedProgram;
    SLANG_CHECK(SLANG_SUCCEEDED(
        composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef())));
    diagnoseIfNeeded(diagnosticsBlob);
    if (!linkedProgram)
        return;

    // Get the reflection information
    auto mainOnlyReflection = linkedProgram->getLayout();

    // Verify that struct S has no fields in the main-only program
    validateStructSLayout(context, mainOnlyReflection, 0);

    // STEP 2: Load the foo module and link it into the same program
    slang::IModule* fooModule = slangSession->loadModuleFromSource(
        "foo",
        "foo.slang",
        fooBlob,
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_CHECK_MSG(fooModule != nullptr, "Failed to load foo module");
    if (!fooModule)
        return;

    // Create a new composite program that includes the foo module
    componentTypes.clear();
    componentTypes.add(mainModule);
    componentTypes.add(fooModule);
    componentTypes.add(vsEntryPoint);

    composedProgram = nullptr;
    SLANG_CHECK(SLANG_SUCCEEDED(slangSession->createCompositeComponentType(
        componentTypes.getBuffer(),
        componentTypes.getCount(),
        composedProgram.writeRef(),
        diagnosticsBlob.writeRef())));
    diagnoseIfNeeded(diagnosticsBlob);
    if (!composedProgram)
        return;

    // Link the updated program
    linkedProgram = nullptr;
    SLANG_CHECK(SLANG_SUCCEEDED(
        composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef())));
    diagnoseIfNeeded(diagnosticsBlob);
    if (!linkedProgram)
        return;

    // Get the updated reflection information
    auto updatedReflection = linkedProgram->getLayout();

    // Verify that struct S now has one field in the updated program
    validateStructSLayout(context, updatedReflection, 1);
}

// Unit test declared through the SLANG_UNIT_TEST macro under the name
// linkTimeTypeLayoutCache. It passes linkTimeTypeLayoutCacheImpl to runTestImpl
// together with the test context and DeviceType::Vulkan.
// NOTE(review): `unitTestContext` is not declared in this file and is
// presumably introduced by the macro; runTestImpl presumably creates a Vulkan
// device and invokes the callback with it — confirm against the headers.
SLANG_UNIT_TEST(linkTimeTypeLayoutCache)
{
    runTestImpl(linkTimeTypeLayoutCacheImpl, unitTestContext, DeviceType::Vulkan);
}

} // namespace gfx_test