summaryrefslogtreecommitdiffstats
path: root/tools/gfx-unit-test/link-time-type-layout-nested.cpp
blob: 2c2c83a947fb20271bab8002f5f6ae118dadb1be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#include "core/slang-blob.h"
#include "gfx-test-util.h"
#include "slang-rhi.h"
#include "unit-test/slang-unit-test.h"

using namespace rhi;

namespace gfx_test
{

// Prints the contents of a Slang diagnostics blob to stderr, if it holds any.
//
// The text is printed with an explicit length rather than as a C string: the
// blob interface only exposes a buffer pointer and a byte count, and nothing in
// that contract guarantees the buffer is NUL-terminated.
//
// @param diagnosticsBlob  Blob filled in by a Slang API call; may be null or empty,
//                         in which case nothing is printed.
static void diagnoseIfNeeded(Slang::ComPtr<slang::IBlob>& diagnosticsBlob)
{
    if (!diagnosticsBlob)
        return;

    const size_t size = diagnosticsBlob->getBufferSize();
    if (size == 0)
        return;

    fprintf(
        stderr,
        "%.*s\n",
        static_cast<int>(size),
        static_cast<const char*>(diagnosticsBlob->getBufferPointer()));
}

// Compiles the two-module test program, links it, and creates an RHI shader
// program from the linked result.
//
// main.slang declares `extern struct Inner` and uses it inside `Outer`;
// inner.slang supplies the exported definition of Inner. The two are only
// connected when the composite component type is linked.
//
// Compiler diagnostics from every step are printed to stderr before the step's
// result is checked, so failures are never silent.
//
// @param device            Device whose Slang session is used for compilation.
// @param outShaderProgram  Receives the created shader program on success.
// @param slangReflection   Receives the linked program's layout on success; left
//                          untouched on failure. The layout is owned by the linked
//                          component type. NOTE(review): presumably it stays alive
//                          because the shader program retains `slangGlobalScope`;
//                          confirm.
// @return SLANG_OK on success, or a failure code if loading either module,
//         finding the entry point, composing, linking, or creating the shader
//         program fails.
static Slang::Result loadProgram(
    rhi::IDevice* device,
    Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram,
    slang::ProgramLayout*& slangReflection)
{
    // main.slang: declares the interface, extern struct Inner, and Outer struct with Inner field
    const char* mainSrc = R"(
        // Define an interface
        public interface IFoo
        {
            public float4 getFoo();
        };

        // Define an extern struct that implements the interface
        public extern struct Inner : IFoo;

        // Define a regular struct that contains an Inner field
        public struct Outer
        {
            float2 position;
            Inner innerData;
            float2 texCoord;
        };

        // Vertex shader entry point that takes an Outer parameter
        [shader("vertex")]
        float4 vertexMain(Outer params) : SV_Position
        {
            return float4(params.position, 0.0f, 1.0f) + params.innerData.getFoo();
        }
    )";

    // inner.slang: defines Inner with its field layout and its implementation of getFoo()
    const char* innerSrc = R"(
        import main;

        // Define the implementation of Inner with its field layout
        export public struct Inner : IFoo
        {
            public float4 getFoo() { return this.data; }
            float4 data;
        }
    )";

    Slang::ComPtr<slang::ISession> slangSession;
    SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
    Slang::ComPtr<slang::IBlob> diagnosticsBlob;

    // Create blobs for the two modules
    auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc));
    auto innerBlob = Slang::UnownedRawBlob::create(innerSrc, strlen(innerSrc));

    // Load modules from source. loadModuleFromSource signals failure by returning
    // null, so report its diagnostics and bail out instead of dereferencing null.
    // main must be loaded first because inner.slang does `import main;`.
    slang::IModule* mainModule = slangSession->loadModuleFromSource(
        "main",
        "main.slang",
        mainBlob,
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    if (!mainModule)
        return SLANG_FAIL;

    slang::IModule* innerModule = slangSession->loadModuleFromSource(
        "inner",
        "inner.slang",
        innerBlob,
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    if (!innerModule)
        return SLANG_FAIL;

    // Find the entry point from main.slang
    Slang::ComPtr<slang::IEntryPoint> vsEntryPoint;
    SLANG_RETURN_ON_FAIL(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef()));

    // Compose the program from both modules and the entry point
    Slang::List<slang::IComponentType*> componentTypes;
    componentTypes.add(mainModule);
    componentTypes.add(innerModule);
    componentTypes.add(vsEntryPoint);

    // Diagnostics are printed BEFORE the result is checked; returning first would
    // discard the compiler's explanation of the failure.
    Slang::ComPtr<slang::IComponentType> composedProgram;
    Slang::Result result = slangSession->createCompositeComponentType(
        componentTypes.getBuffer(),
        componentTypes.getCount(),
        composedProgram.writeRef(),
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_RETURN_ON_FAIL(result);

    // Link the composite program; this is where extern Inner gets bound to the
    // exported definition from inner.slang.
    Slang::ComPtr<slang::IComponentType> linkedProgram;
    result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_RETURN_ON_FAIL(result);

    // Create a shader program from the linked program
    ShaderProgramDesc programDesc = {};
    programDesc.slangGlobalScope = linkedProgram.get();
    outShaderProgram = device->createShaderProgram(programDesc);
    if (!outShaderProgram)
        return SLANG_FAIL;

    // Publish the reflection only once everything succeeded, so callers never
    // receive a layout whose owning program has already been released.
    slangReflection = linkedProgram->getLayout();

    return SLANG_OK;
}

// Function to validate the type layout of Outer struct with nested Inner struct
static void validateNestedExternStructLayout(
    UnitTestContext* context,
    slang::ProgramLayout* slangReflection)
{
    // Check reflection is available
    SLANG_CHECK(slangReflection != nullptr);

    // Get the entry point layout for vertexMain
    slang::EntryPointLayout* entryPointLayout = slangReflection->findEntryPointByName("vertexMain");

    SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point");

    // Get the parameter count for the entry point
    auto paramCount = entryPointLayout->getParameterCount();
    SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters");

    // Get the first parameter, which should be of type Outer
    auto paramLayout = entryPointLayout->getParameterByIndex(0);
    SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout");

    // Get the type layout of the parameter
    auto outerTypeLayout = paramLayout->getTypeLayout();
    SLANG_CHECK_MSG(outerTypeLayout != nullptr, "Parameter has no type layout");

    // Check if it's a struct type
    auto kind = outerTypeLayout->getKind();
    SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type");

    // Verify Outer has 3 fields: position, innerData, texCoord
    auto fieldCount = outerTypeLayout->getFieldCount();
    SLANG_CHECK_MSG(fieldCount == 3, "Outer struct does not have 3 fields");

    // Find and check the innerData field
    slang::VariableLayoutReflection* innerDataField = nullptr;
    for (unsigned int i = 0; i < fieldCount; i++)
    {
        auto fieldLayout = outerTypeLayout->getFieldByIndex(i);
        const char* fieldName = fieldLayout->getName();

        if (fieldName && strcmp(fieldName, "innerData") == 0)
        {
            innerDataField = fieldLayout;
            break;
        }
    }

    SLANG_CHECK_MSG(innerDataField != nullptr, "Could not find innerData field in Outer struct");

    // Get the type layout of the innerData field
    auto innerTypeLayout = innerDataField->getTypeLayout();
    SLANG_CHECK_MSG(innerTypeLayout != nullptr, "innerData field has no type layout");

    // Verify Inner is a struct type
    kind = innerTypeLayout->getKind();
    SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Inner is not a struct type");

    // Verify Inner has 1 field (data)
    fieldCount = innerTypeLayout->getFieldCount();
    SLANG_CHECK_MSG(fieldCount == 1, "Inner struct does not have 1 field");

    // Find and check the data field in Inner
    bool foundDataField = false;
    for (unsigned int i = 0; i < fieldCount; i++)
    {
        auto fieldLayout = innerTypeLayout->getFieldByIndex(i);
        const char* fieldName = fieldLayout->getName();

        if (fieldName && strcmp(fieldName, "data") == 0)
        {
            foundDataField = true;

            // Check that it's a float4 type
            auto fieldTypeLayout = fieldLayout->getTypeLayout();
            auto fieldTypeKind = fieldTypeLayout->getKind();

            SLANG_CHECK_MSG(
                fieldTypeKind == slang::TypeReflection::Kind::Vector,
                "Field 'data' is not a vector type");

            auto elementCount = fieldTypeLayout->getElementCount();
            SLANG_CHECK_MSG(elementCount == 4, "Field 'data' is not a 4-element vector");

            break;
        }
    }

    SLANG_CHECK_MSG(foundDataField, "Could not find field 'data' in Inner struct");
}

// Test body for one device: builds the nested extern-struct program, validates
// its reflected layout, and checks that a render pipeline can be created from it.
//
// @param device   Device under test.
// @param context  Unit-test context passed through to the validation helpers.
void linkTimeTypeLayoutNestedImpl(rhi::IDevice* device, UnitTestContext* context)
{
    Slang::ComPtr<rhi::IShaderProgram> shaderProgram;
    slang::ProgramLayout* slangReflection = nullptr;

    auto result = loadProgram(device, shaderProgram, slangReflection);
    SLANG_CHECK(SLANG_SUCCEEDED(result));
    // SLANG_CHECK only records the failure. Without a program there is no valid
    // reflection to inspect and no program to build a pipeline from, so stop here
    // rather than passing null pointers on.
    if (SLANG_FAILED(result))
        return;

    // Validate the nested struct layout
    validateNestedExternStructLayout(context, slangReflection);

    // Create a graphics pipeline to verify the linked program is usable end to end
    RenderPipelineDesc pipelineDesc = {};
    pipelineDesc.program = shaderProgram.get();
    pipelineDesc.primitiveTopology = PrimitiveTopology::TriangleList;

    ComPtr<IRenderPipeline> pipelineState;
    auto pipelineResult = device->createRenderPipeline(pipelineDesc, pipelineState.writeRef());
    SLANG_CHECK(SLANG_SUCCEEDED(pipelineResult));
}

//
// This test verifies that type layout information propagates correctly through
// the Slang compilation pipeline when a regular struct contains a field whose
// type is an extern struct. Here Outer (in main.slang) holds an `extern struct
// Inner`, and Inner's definition is exported from a separate module
// (inner.slang) and bound only when the composite program is linked.
// Specifically, it checks that:
//
// 1. The Outer struct includes the Inner extern struct as a field (innerData).
// 2. After linking, Inner's layout is resolved to its exported definition,
//    including its single float4 field (data).
// 3. The complete type layout is available through the linked program's
//    reflection data.
// 4. A render pipeline can be created from the linked program.
//
// The test is registered for the Vulkan device type only.
//

SLANG_UNIT_TEST(linkTimeTypeLayoutNested)
{
    runTestImpl(linkTimeTypeLayoutNestedImpl, unitTestContext, DeviceType::Vulkan);
}

} // namespace gfx_test