summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-com-interface.cpp
blob: 1bcf3d2b60c9aa900fd5d2024df3059cb0ca90dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// slang-ir-com-interface.cpp
#include "slang-ir-com-interface.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"

namespace Slang
{

struct ComInterfaceLoweringContext
{
    IRModule* module;
    DiagnosticSink* diagnosticSink;

    SharedIRBuilder sharedBuilder;

    Dictionary<IRInterfaceType*, IRComPtrType*> comPtrTypes;

    void replaceTypeUses(IRInst* inst, IRInst* newValue)
    {
        List<IRUse*> uses;
        for (auto use = inst->firstUse; use; use = use->nextUse)
        {
            uses.add(use);
        }
        for (auto use : uses)
        {
            switch (use->getUser()->getOp())
            {
            case kIROp_WitnessTableIDType:
            case kIROp_WitnessTableType:
            case kIROp_ThisType:
            case kIROp_RTTIPointerType:
            case kIROp_RTTIHandleType:
            case kIROp_ComPtrType:
                continue;
            default:
                break;
            }
            use->set(newValue);
        }
    }

    IRComPtrType* processInterfaceType(IRInterfaceType* type)
    {
        if (!type->findDecoration<IRComInterfaceDecoration>())
            return nullptr;

        IRComPtrType* result = nullptr;

        if (comPtrTypes.TryGetValue(type, result))
            return result;

        IRBuilder builder(sharedBuilder);
        builder.setInsertInto(module->getModuleInst());
        result = builder.getComPtrType(type);

        replaceTypeUses(type, result);
        return result;
    }

    void processThisType(IRThisType* type)
    {
        auto comPtrType = processInterfaceType(as<IRInterfaceType>(type->getConstraintType()));
        if (!comPtrType)
            return;
        replaceTypeUses(type, comPtrType);
    }

    void processModule()
    {
        for (auto child : module->getGlobalInsts())
        {
            switch (child->getOp())
            {
            case kIROp_InterfaceType:
                processInterfaceType(as<IRInterfaceType>(child));
                break;
            case kIROp_ThisType:
                processThisType(as<IRThisType>(child));
                break;
            default:
                break;
            }
        }
    }
};

/// Entry point: lower the COM interface types in `module` (and the `This`
/// types constrained by them) to `ComPtr` types wrapping the interface.
///
/// @param module The IR module, rewritten in place.
/// @param sink   Stored on the lowering context. The pass code in this file
///               does not report any diagnostics through it.
void lowerComInterfaces(IRModule* module, DiagnosticSink* sink)
{
    ComInterfaceLoweringContext context;
    context.module = module;
    context.diagnosticSink = sink;
    context.sharedBuilder.init(module);
    context.processModule();
}

}