summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-deduplicate.cpp
blob: 51a6776273b434e31987f02249b7d3ad1b912038 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "slang-ir-insts.h"

namespace Slang
{
    /// Canonicalizes global constants, types and `IRSpecialize` insts against the maps owned by a
    /// `SharedIRBuilder` (its constant map and its global value numbering map).
    ///
    /// Every `add*` method returns the canonical inst for its argument: a previously registered
    /// inst whose key compares equal, or the argument itself, which is then registered. Before the
    /// key is built, the inst's full type (and, for types/specializations, its operands) is
    /// canonicalized and written back into the inst, so registered keys only ever refer to
    /// canonical operands.
    ///
    /// NOTE(review): there is no cycle detection. Termination relies on struct and interface types
    /// being returned without visiting their operands — presumably the only types that can take
    /// part in reference cycles; confirm if new self-referential type ops are added.
    struct DeduplicateContext
    {
        SharedIRBuilder* builder;

        /// Returns the canonical inst for an arbitrary operand/type `value`.
        /// Types and specializations go through `addTypeValue`, constants through
        /// `addConstantValue`; anything else (including null) is returned unchanged.
        IRInst* addValue(IRInst* value)
        {
            if (!value) return nullptr;
            // Must accept the same kinds of inst that the module-level driver deduplicates
            // (types and `IRSpecialize`). Otherwise a type whose operand is a duplicate
            // specialization is keyed on that duplicate, and the driver's later
            // `replaceUsesWith` on it rewrites an inst that is already stored in the map.
            if (as<IRType>(value) || as<IRSpecialize>(value))
                return addTypeValue(value);
            if (auto constValue = as<IRConstant>(value))
                return addConstantValue(constValue);
            return value;
        }

        /// Canonicalizes the constant's full type, then returns the constant already registered
        /// under an equal `IRConstantKey`, or registers and returns `value` itself.
        IRInst* addConstantValue(IRConstant* value)
        {
            IRConstantKey key = { value };
            value->setFullType((IRType*)addValue(value->getFullType()));
            if (auto newValue = builder->getConstantMap().TryGetValue(key))
                return *newValue;
            builder->getConstantMap()[key] = value;
            return value;
        }

        /// Canonicalizes a type or specialization: operands and full type are canonicalized
        /// first, then the inst registered under an equal `IRInstKey` is returned, or `value` is
        /// registered and returned itself.
        IRInst* addTypeValue(IRInst* value)
        {
            // Do not deduplicate struct or interface types.
            switch (value->getOp())
            {
            case kIROp_StructType:
            case kIROp_InterfaceType:
                return value;
            default:
                break;
            }

            // Fast path: only insts whose operands have already been canonicalized are ever
            // registered, so if `value` as it currently stands already matches an entry, its
            // operands are canonical and a full walk would return that same entry. Skipping the
            // walk avoids re-traversing shared sub-terms once per path through the type graph,
            // which is exponential in the depth of such a graph.
            {
                IRInstKey existingKey = { value };
                if (auto existing = builder->getGlobalValueNumberingMap().TryGetValue(existingKey))
                    return *existing;
            }

            // Canonicalize operands and full type first so the key below is computed over
            // canonical insts.
            for (UInt i = 0; i < value->getOperandCount(); i++)
            {
                value->setOperand(i, addValue(value->getOperand(i)));
            }
            value->setFullType((IRType*)addValue(value->getFullType()));

            IRInstKey key = { value };
            if (auto newValue = builder->getGlobalValueNumberingMap().TryGetValue(key))
                return *newValue;
            builder->getGlobalValueNumberingMap()[key] = value;
            return value;
        }
    };
    void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap()
    {
        DeduplicateContext context;
        context.builder = this;
        m_constantMap.Clear();
        m_globalValueNumberingMap.Clear();
        List<IRInst*> instToRemove;
        for (auto inst : m_module->getGlobalInsts())
        {
            if (auto constVal = as<IRConstant>(inst))
            {
                auto newConst = context.addConstantValue(constVal);
                if (newConst != constVal)
                {
                    constVal->replaceUsesWith(newConst);
                    instToRemove.add(constVal);
                }
            }
        }
        for (auto inst : m_module->getGlobalInsts())
        {
            if (as<IRType>(inst) || as<IRSpecialize>(inst))
            {
                auto newInst = context.addTypeValue(inst);
                if (newInst != inst)
                {
                    inst->replaceUsesWith(newInst);
                    instToRemove.add(inst);
                }
            }
        }
        for (auto inst : instToRemove)
            inst->removeAndDeallocate();
    }

    /// Redirects every use of `oldInst` to `newInst`, then removes and deallocates `oldInst`.
    ///
    /// The numbering maps look insts up by content-based keys (`IRInstKey` / `IRConstantKey`;
    /// deduplication can find a *different* inst for the same key). Two situations therefore
    /// leave the maps stale and force `deduplicateAndRebuildGlobalNumberingMap()`:
    ///  - a user of `oldInst` is registered: rewriting its operand changes its key while it is
    ///    stored in the map;
    ///  - `oldInst` itself is registered: the entry would refer to a deallocated inst.
    /// Both checks are made before any operand is rewritten. After `IRUse::set`, a user's key
    /// no longer matches the one it was stored under, so a lookup can no longer be trusted to
    /// find it.
    void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst)
    {
        // Replacing an inst with itself must not deallocate it while it still has uses.
        if (oldInst == newInst)
            return;

        // If `oldInst` is registered in either map, the entry must not outlive the inst.
        bool shouldUpdateGlobalNumberedCache = isGloballyNumberedInst(oldInst);
        if (auto oldConstant = as<IRConstant>(oldInst))
        {
            if (m_constantMap.ContainsKey(IRConstantKey{ oldConstant }))
                shouldUpdateGlobalNumberedCache = true;
        }

        // Snapshot the uses before rewriting them: `IRUse::set` rewires the very use list we
        // would otherwise be walking. Users are classified here, while their operands (and
        // therefore their keys) are still unchanged.
        List<IRUse*> uses;
        for (auto use = oldInst->firstUse; use; use = use->nextUse)
        {
            uses.add(use);
            if (isGloballyNumberedInst(use->getUser()))
                shouldUpdateGlobalNumberedCache = true;
        }

        for (auto use : uses)
            use->set(newInst);

        oldInst->removeAndDeallocate();
        if (shouldUpdateGlobalNumberedCache)
        {
            deduplicateAndRebuildGlobalNumberingMap();
        }
    }

    /// Returns true when `inst` is a direct child of the module and the global value numbering
    /// map holds an entry whose key compares equal to `IRInstKey{inst}`. Because the map's keys
    /// are content-based, the matching entry is not necessarily `inst` itself.
    bool SharedIRBuilder::isGloballyNumberedInst(IRInst* inst)
    {
        auto parent = inst->getParent();
        const bool isModuleLevel = parent && parent->getOp() == kIROp_Module;
        if (!isModuleLevel)
            return false;

        IRInstKey key = { inst };
        return m_globalValueNumberingMap.ContainsKey(key);
    }
}