// path: source/slang/slang-ir-float-non-uniform-resource-index.cpp
// blob: dbcb093c2ba3c631a746ea9c6dc8e0213d93fab7
#include "slang-ir-float-non-uniform-resource-index.h"

#include "slang-ir-util.h"

namespace Slang
{
void processNonUniformResourceIndex(
    IRInst* nonUniformResourceIndexInst,
    NonUniformResourceIndexFloatMode floatMode)
{
    // float `NonUniformResourceIndex()` to right before the access operation
    // by walking up the use-def chain
    // from nonUniformResource inst of an index to an array of buffer or
    // texture def all the way to the leaf operations. To be precise:
    // - go through GEP and see if it calls an intrinsic function,
    //   then decorate the address itself (GetElementPtr)
    // - go through GEP to identify the pointer access and the Loads that it
    //   accesses (GetElementPtr -> Load), then decorate the load instruction.
    // - go through IntCasts to deal with u32 -> i32 / vice-versa (IntCast)
    List<IRInst*> resWorkList;

    // Handle cases when `nonUniformResourceIndexInst` inst is wrapped around
    // an index in a nested fashion, i.e. nonUniform(nonUniform(index)) by
    // only adding the inner-most inst in the worklist, and work our way out.
    auto insti = nonUniformResourceIndexInst;
    while (insti->getOp() == kIROp_NonUniformResourceIndex)
    {
        if (resWorkList.getCount() != 0)
            resWorkList.removeLast();
        resWorkList.add(insti);
        insti = insti->getOperand(0);
    }

    // For all the users of a `nonUniformResourceIndexInst`, make them directly
    // use the underlying base inst that is wrapped by `nonUniformResourceIndex`
    // and finally wrap them with a `nonUniformResourceIndex`, and add back to the
    // worklist, and keep bubbling them up until it can.
    for (Index i = 0; i < resWorkList.getCount(); i++)
    {
        auto inst = resWorkList[i];
        traverseUses(
            inst,
            [&](IRUse* use)
            {
                auto user = use->getUser();
                IRBuilder builder(user);
                builder.setInsertBefore(user);

                IRInst* newUser = nullptr;
                switch (user->getOp())
                {
                case kIROp_IntCast:
                    // Replace intCast(nonUniformRes(x)), into nonUniformRes(intCast(x))
                    newUser = builder.emitCast(user->getFullType(), inst->getOperand(0));
                    break;
                case kIROp_CastDescriptorHandleToUInt2:
                    {
                        // Replace castBindlessToInt(nonUniformRes(x)), into
                        // nonUniformRes(castBindlessToInt(x))
                        auto operand = inst->getOperand(0);
                        newUser = builder.emitIntrinsicInst(
                            user->getFullType(),
                            kIROp_CastDescriptorHandleToUInt2,
                            1,
                            &operand);
                    }
                    break;
                case kIROp_GetElementPtr:
                    // Ignore when `NonUniformResourceIndex` is not on the index
                    if (floatMode != NonUniformResourceIndexFloatMode::SPIRV)
                        break;
                    if (user->getOperand(1) == inst)
                    {
                        // Replace gep(pArray, nonUniformRes(x)), into
                        // nonUniformRes(gep(pArray, x))
                        newUser = builder.emitElementAddress(
                            user->getFullType(),
                            user->getOperand(0),
                            inst->getOperand(0));
                    }
                    break;
                case kIROp_GetElement:
                    // Ignore when `NonUniformResourceIndex` is not on base
                    if (user->getOperand(0) == inst)
                    {
                        // Replace getElement(nonuniformRes(obj), i), into
                        // nonUniformRes(getElement(obj, i))
                        newUser = builder.emitElementExtract(
                            user->getFullType(),
                            inst->getOperand(0),
                            user->getOperand(1));
                    }
                    break;
                case kIROp_Swizzle:
                    // Ignore when `NonUniformResourceIndex` is not on base
                    if (user->getOperand(0) == inst)
                    {
                        // Replace getElement(nonuniformRes(obj), i), into
                        // nonUniformRes(getElement(obj, i))
                        ShortList<IRInst*> operands;
                        for (UInt i = 0; i < user->getOperandCount(); i++)
                            operands.add(user->getOperand(i));
                        operands[0] = inst->getOperand(0);
                        newUser = builder.emitIntrinsicInst(
                            user->getFullType(),
                            kIROp_Swizzle,
                            operands.getCount(),
                            operands.getArrayView().getBuffer());
                    }
                    break;
                case kIROp_NonUniformResourceIndex:
                    // Replace nonUniformRes(nonUniformRes(x)), into nonUniformRes(x)
                    newUser = inst->getOperand(0);
                    break;
                case kIROp_Load:
                    if (floatMode != NonUniformResourceIndexFloatMode::SPIRV)
                        break;
                    // Replace load(nonUniformRes(x)), into nonUniformRes(load(x))
                    newUser = builder.emitLoad(user->getFullType(), inst->getOperand(0));
                    break;
                default:
                    // Ignore for all other unknown insts.
                    break;
                };

                // Early exit when we could not process the `NonUniformResourceIndex` inst.
                if (!newUser)
                    return;

                auto nonuniformUser = builder.emitNonUniformResourceIndexInst(newUser);
                user->replaceUsesWith(nonuniformUser);

                // Update the worklist with the newly added `NonUniformResourceIndex` inst,
                // based on the base inst it was constructed around, in case we need to further
                // bubble up the `NonUniformResourceIndex` inst.
                switch (user->getOp())
                {
                case kIROp_IntCast:
                case kIROp_GetElementPtr:
                case kIROp_Load:
                case kIROp_NonUniformResourceIndex:
                case kIROp_CastDescriptorHandleToUInt2:
                case kIROp_GetElement:
                case kIROp_Swizzle:
                    resWorkList.add(nonuniformUser);
                    break;
                };

                // Clean up the base inst from the IR module, to avoid duplicate decorations.
                user->removeAndDeallocate();
            });
    }

    if (floatMode != NonUniformResourceIndexFloatMode::SPIRV)
        return;
    // Once all the `NonUniformResourceIndex` insts are visited, and the inst type is bubbled up
    // to the parent, a decoration is added to the operands of the insts.
    for (int i = 0; i < resWorkList.getCount(); ++i)
    {
        // It is only required to decorate the base inst, if the `NonUniformResourceIndex` inst
        // around it has any active uses.
        auto inst = resWorkList[i];
        if (!inst->hasUses())
        {
            inst->removeAndDeallocate();
            continue;
        }
        // For each of the `NonUniformResourceIndex` inst that remain, decorate the base inst
        // with a [NonUniformResource] decoration, which is the operand0 of the inst, only
        // when the type is a resource type, or a pointer to a resource type, or a pointer
        // in the Physical Storage buffer address space.
        auto operand = inst->getOperand(0);
        auto type = operand->getDataType();
        if (isResourceType(type) || isPointerToResourceType(type))
        {
            IRBuilder builder(operand);
            builder.addSPIRVNonUniformResourceDecoration(operand);
            if (operand->getOp() == kIROp_Load)
            {
                // If the inst is a load, then the addr inst itself should also be decorated
                // with the [NonUniformResource] decoration.
                auto addr = operand->getOperand(0);
                if (!addr->findDecoration<IRSPIRVNonUniformResourceDecoration>())
                    builder.addSPIRVNonUniformResourceDecoration(addr);
            }
        }
        inst->replaceUsesWith(operand);
        inst->removeAndDeallocate();
    }
}

/// Runs `processNonUniformResourceIndex` on every `NonUniformResourceIndex` inst found in
/// the blocks of each global value-with-code in `module`.
///
/// @param module    The IR module whose global insts are scanned.
/// @param floatMode Forwarded unchanged to `processNonUniformResourceIndex`.
void floatNonUniformResourceIndex(IRModule* module, NonUniformResourceIndexFloatMode floatMode)
{
    // Scratch list, reused (cleared) for each function so that candidates are collected
    // up front rather than while the blocks are being modified.
    List<IRInst*> pending;

    for (auto global : module->getGlobalInsts())
    {
        // Look through generics to the value they return; only values with code have blocks.
        auto code = as<IRGlobalValueWithCode>(getGenericReturnVal(global));
        if (code)
        {
            pending.clear();

            // Pass 1: gather every `NonUniformResourceIndex` inst in this function.
            for (auto block : code->getBlocks())
                for (auto child : block->getChildren())
                    if (child->getOp() == kIROp_NonUniformResourceIndex)
                        pending.add(child);

            // Pass 2: process the gathered insts, skipping any that no longer has a parent
            // (presumably because processing an earlier entry already removed it).
            for (auto candidate : pending)
            {
                if (candidate->getParent() == nullptr)
                    continue;
                processNonUniformResourceIndex(candidate, floatMode);
            }
        }
    }
}
} // namespace Slang