summaryrefslogtreecommitdiffstats
path: root/tools/gfx-unit-test/link-time-type-layout.cpp
blob: 3493ea82dcfa9986e4f05a91a76878845506d5e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#include "core/slang-blob.h"
#include "gfx-test-util.h"
#include "slang-rhi.h"
#include "unit-test/slang-unit-test.h"

using namespace rhi;

namespace gfx_test
{

static void diagnoseIfNeeded(Slang::ComPtr<slang::IBlob>& diagnosticsBlob)
{
    if (diagnosticsBlob && diagnosticsBlob->getBufferSize() > 0)
    {
        fprintf(stderr, "%s\n", (const char*)diagnosticsBlob->getBufferPointer());
    }
}

static Slang::Result loadSpirvProgram(
    rhi::IDevice* device,
    Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram,
    slang::ProgramLayout*& slangReflection)
{
    // main.slang: declares the interface and extern struct S, and the vertex shader.
    const char* mainSrc = R"(
        public interface IFoo
        {
            public float4 getFoo();
        };
        public extern struct S : IFoo;

        [shader("vertex")]
        float4 vertexMain(S params) : SV_Position
        {
            return params.getFoo();
        }
    )";

    // foo.slang: defines S with its field layout and its implementation of getFoo().
    const char* fooSrc = R"(
        import main;

        export public struct S : IFoo
        {
            public float4 getFoo() { return this.foo; }
            float4 foo : POSITION;
        }
    )";

    Slang::ComPtr<slang::ISession> slangSession;
    SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
    Slang::ComPtr<slang::IBlob> diagnosticsBlob;

    // Create blobs for the two modules.
    auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc));
    auto fooBlob = Slang::UnownedRawBlob::create(fooSrc, strlen(fooSrc));

    // Load modules from source.
    slang::IModule* mainModule = slangSession->loadModuleFromSource("main", "main.slang", mainBlob);
    slang::IModule* fooModule = slangSession->loadModuleFromSource("foo", "foo.slang", fooBlob);

    // Find the entry point from main.slang
    Slang::ComPtr<slang::IEntryPoint> vsEntryPoint;
    SLANG_RETURN_ON_FAIL(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef()));

    // Compose the program from both modules and the entry point.
    Slang::List<slang::IComponentType*> componentTypes;
    componentTypes.add(mainModule);
    componentTypes.add(fooModule);
    componentTypes.add(vsEntryPoint);

    Slang::ComPtr<slang::IComponentType> composedProgram;
    SLANG_RETURN_ON_FAIL(slangSession->createCompositeComponentType(
        componentTypes.getBuffer(),
        componentTypes.getCount(),
        composedProgram.writeRef(),
        diagnosticsBlob.writeRef()));
    diagnoseIfNeeded(diagnosticsBlob);

    // Link the composite program.
    Slang::ComPtr<slang::IComponentType> linkedProgram;
    SLANG_RETURN_ON_FAIL(
        composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()));
    diagnoseIfNeeded(diagnosticsBlob);

    // Retrieve the reflection information.
    composedProgram = linkedProgram;
    slangReflection = composedProgram->getLayout();

    // Create a shader program that will generate SPIRV code.
    ShaderProgramDesc programDesc = {};
    programDesc.slangGlobalScope = composedProgram.get();
    auto shaderProgram = device->createShaderProgram(programDesc);
    outShaderProgram = shaderProgram;

    // Force SPIRV generation by explicitly requesting it
    Slang::ComPtr<slang::IBlob> spirvBlob;
    Slang::ComPtr<slang::IBlob> spirvDiagnostics;

    // Request SPIRV code generation for the vertex shader entry point
    auto targetIndex = 0;     // Assuming this is the first/only target
    auto entryPointIndex = 0; // Assuming this is the first/only entry point

    auto result = composedProgram->getEntryPointCode(
        entryPointIndex,
        targetIndex,
        spirvBlob.writeRef(),
        spirvDiagnostics.writeRef());

    if (SLANG_FAILED(result))
    {
        if (spirvDiagnostics && spirvDiagnostics->getBufferSize() > 0)
        {
            fprintf(
                stderr,
                "SPIRV generation failed: %s\n",
                (const char*)spirvDiagnostics->getBufferPointer());
        }
        return result;
    }

    // Verify we actually got SPIRV code
    if (!spirvBlob || spirvBlob->getBufferSize() == 0)
    {
        return SLANG_FAIL;
    }

    return SLANG_OK;
}

// Function to validate the type layout of struct S
static void validateStructSLayout(UnitTestContext* context, slang::ProgramLayout* slangReflection)
{
    // Check reflection is available
    SLANG_CHECK(slangReflection != nullptr);

    // Get the entry point layout for vertexMain
    auto entryPointCount = slangReflection->getEntryPointCount();
    slang::EntryPointLayout* entryPointLayout = nullptr;

    for (unsigned int i = 0; i < entryPointCount; i++)
    {
        auto currentEntryPoint = slangReflection->getEntryPointByIndex(i);
        const char* name = currentEntryPoint->getName();

        if (strcmp(name, "vertexMain") == 0)
        {
            entryPointLayout = currentEntryPoint;
            break;
        }
    }

    SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point");

    // Get the parameter count for the entry point
    auto paramCount = entryPointLayout->getParameterCount();
    SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters");

    // Get the first parameter, which should be of type S
    auto paramLayout = entryPointLayout->getParameterByIndex(0);
    SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout");

    // Get the type layout of the parameter
    auto typeLayout = paramLayout->getTypeLayout();
    SLANG_CHECK_MSG(typeLayout != nullptr, "Parameter has no type layout");

    // Check if it's a struct type
    auto kind = typeLayout->getKind();
    SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type");

    // Get the field count
    auto fieldCount = typeLayout->getFieldCount();
    SLANG_CHECK_MSG(fieldCount >= 1, "Struct has no fields");

    // Check for the 'foo' field
    bool foundFooField = false;
    for (unsigned int i = 0; i < fieldCount; i++)
    {
        auto fieldLayout = typeLayout->getFieldByIndex(i);
        const char* fieldName = fieldLayout->getName();

        if (fieldName && strcmp(fieldName, "foo") == 0)
        {
            foundFooField = true;

            // Check that it's a float4 type
            auto fieldTypeLayout = fieldLayout->getTypeLayout();
            auto fieldTypeKind = fieldTypeLayout->getKind();

            SLANG_CHECK_MSG(
                fieldTypeKind == slang::TypeReflection::Kind::Vector,
                "Field 'foo' is not a vector type");

            auto elementCount = fieldTypeLayout->getElementCount();
            SLANG_CHECK_MSG(elementCount == 4, "Field 'foo' is not a 4-element vector");

            break;
        }
    }

    SLANG_CHECK_MSG(foundFooField, "Could not find field 'foo' in struct S");
}

void linkTimeTypeLayoutImpl(rhi::IDevice* device, UnitTestContext* context)
{
    Slang::ComPtr<rhi::IShaderProgram> shaderProgram;
    slang::ProgramLayout* slangReflection = nullptr;

    auto result = loadSpirvProgram(device, shaderProgram, slangReflection);
    SLANG_CHECK(SLANG_SUCCEEDED(result));

    // Validate the struct S layout
    validateStructSLayout(context, slangReflection);

    // Create a graphics pipeline to verify SPIRV code generation works
    InputElementDesc inputElements[] = {
        {"POSITION", 0, Format::RGBA32Float, 0, 0}, // S struct as POSITION semantic (float4)
    };
    VertexStreamDesc vertexStreams[] = {
        {16, InputSlotClass::PerVertex, 0}, // sizeof(float4)
    };
    InputLayoutDesc inputLayoutDesc = {};
    inputLayoutDesc.inputElementCount = SLANG_COUNT_OF(inputElements);
    inputLayoutDesc.inputElements = inputElements;
    inputLayoutDesc.vertexStreamCount = SLANG_COUNT_OF(vertexStreams);
    inputLayoutDesc.vertexStreams = vertexStreams;
    auto inputLayout = device->createInputLayout(inputLayoutDesc);
    SLANG_CHECK(inputLayout != nullptr);

    RenderPipelineDesc pipelineDesc = {};
    pipelineDesc.program = shaderProgram.get();
    pipelineDesc.inputLayout = inputLayout;
    pipelineDesc.primitiveTopology = PrimitiveTopology::TriangleList;

    ComPtr<IRenderPipeline> pipelineState;
    auto pipelineResult = device->createRenderPipeline(pipelineDesc, pipelineState.writeRef());
    SLANG_CHECK(SLANG_SUCCEEDED(pipelineResult));
}

//
// This test verifies that type layout information correctly propagates through
// the Slang compilation pipeline when types are defined in modules other than where they are used.
// Specifically, it tests
// that when using an extern struct that's defined in a separate module:
//
// 1. The struct definition is properly linked across module boundaries
// 2. The complete type layout information is available in the reflection data
// 3. SPIRV code generation succeeds with the linked type information (this
// failed before when layout information was required during code generation)
//

SLANG_UNIT_TEST(linkTimeTypeLayout)
{
    runTestImpl(linkTimeTypeLayoutImpl, unitTestContext, DeviceType::Vulkan);
}

} // namespace gfx_test