summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-hoist-local-types.cpp
blob: 756a25c498f466cc5975cd7400e923ccf78931ba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "slang-ir-hoist-local-types.h"

#include "slang-ir-insts.h"
#include "slang-ir.h"

namespace Slang
{
struct HoistLocalTypesContext
{
    IRModule* module;
    DiagnosticSink* sink;

    SharedIRBuilder sharedBuilderStorage;

    List<IRInst*> workList;
    HashSet<IRInst*> workListSet;

    void addToWorkList(IRInst* inst)
    {
        for (auto ii = inst->getParent(); ii; ii = ii->getParent())
        {
            if (as<IRGeneric>(ii))
                return;
        }

        if (workListSet.Contains(inst))
            return;

        workList.add(inst);
        workListSet.Add(inst);
    }

    void processInst(IRInst* inst)
    {
        auto sharedBuilder = &sharedBuilderStorage;
        if (!as<IRType>(inst))
            return;
        if (inst->getParent() == module->getModuleInst())
            return;
        IRInstKey key = {inst};
        if (auto value = sharedBuilder->getGlobalValueNumberingMap().TryGetValue(key))
        {
            inst->replaceUsesWith(*value);
            inst->removeAndDeallocate();
            return;
        }
        IRBuilder builder(sharedBuilder);
        builder.setInsertInto(module->getModuleInst());
        bool hoistable = true;
        ShortList<IRInst*> mappedOperands;
        for (UInt i = 0; i < inst->getOperandCount(); i++)
        {
            IRInstKey opKey = {inst->getOperand(i)};
            if (auto value = sharedBuilder->getGlobalValueNumberingMap().TryGetValue(opKey))
            {
                mappedOperands.add(*value);
            }
            else
            {
                hoistable = false;
                break;
            }
        }
        if (hoistable)
        {
            auto newType = builder.getType(
                inst->getOp(), mappedOperands.getCount(), mappedOperands.getArrayView().getBuffer());
            inst->transferDecorationsTo(newType);
            inst->replaceUsesWith(newType);
            inst->removeAndDeallocate();
        }
    }

    void processModule()
    {
        SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
        sharedBuilder->init(module);

        // Deduplicate equivalent types and build numbering map for global types.
        sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();

        addToWorkList(module->getModuleInst());

        while (workList.getCount() != 0)
        {
            IRInst* inst = workList.getLast();

            workList.removeLast();
            workListSet.Remove(inst);

            processInst(inst);

            for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
            {
                addToWorkList(child);
            }
        }
    }
};

/// Entry point of the pass: moves local type instructions of `module` to
/// module scope where possible. All transient state (work list, shared
/// builder) lives in a context object scoped to this single call.
void hoistLocalTypes(IRModule* module, DiagnosticSink* sink)
{
    HoistLocalTypesContext hoister;
    hoister.sink = sink;
    hoister.module = module;
    hoister.processModule();
}

} // namespace Slang