1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
|
#include "slang-ir-variable-scope-correction.h"
#include "slang-ir-clone.h"
#include "slang-ir-dominators.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir.h"
namespace Slang
{
bool isCPUTarget(TargetRequest* targetReq);
bool isCUDATarget(TargetRequest* targetReq);
namespace
{ // anonymous
/// Per-module state for the variable scope correction pass.
///
/// The pass looks for values that are defined inside a loop region and used in a block
/// dominated by that loop's break block. Such uses are rewritten either to load from a
/// function-level variable the value is stored into (storable types), or to use a copy of
/// the defining instruction placed at the use site (unstorable types).
struct VariableScopeCorrectionContext
{
    VariableScopeCorrectionContext(IRModule* module, TargetRequest* targetReq)
        : m_module(module), m_builder(module), m_targetReq(targetReq)
    {
    }

    /// Run the pass over every function in the module that has a body.
    void processModule();

    /// Process a function in the module
    void _processFunction(IRFunc* funcInst);

    /// Rewrite the out-of-scope uses of `originInst`. Additional instructions that need
    /// checking (operands of re-materialized instructions) are appended to `workList`.
    void _processInstruction(
        IRDominatorTree* dominatorTree,
        IRInst* instAfterParam,
        IRInst* originInst,
        const List<IRLoop*>& loopHeaderList,
        List<IRInst*>& workList);

    /// Spill `inst` to a variable declared before `insertLoc` and redirect `outOfScopeUses`
    /// to loads of that variable.
    void _processStorableInst(IRInst* insertLoc, IRInst* inst, const List<IRUse*>& outOfScopeUses);

    /// Redirect `outOfScopeUsers` to copies of `inst`.
    void _processUnstorableInst(IRInst* inst, const List<IRUse*>& outOfScopeUsers);

    /// Whether a value of `type` can be held in a variable on the current target.
    bool _isStorableType(IRType* type);

    /// Whether `userInst` lies in a block dominated by the break block of any loop in
    /// `loopHeaderList`.
    bool _isOutOfScopeUse(
        IRInst* userInst,
        IRDominatorTree* domTree,
        const List<IRLoop*>& loopHeaderList);

    /// Module being transformed.
    IRModule* m_module;
    /// Builder used to emit the variables, stores, loads and clones created by the pass.
    IRBuilder m_builder;
    /// Target being compiled for; decides which types count as storable.
    TargetRequest* m_targetReq;
};
void VariableScopeCorrectionContext::processModule()
{
IRModuleInst* moduleInst = m_module->getModuleInst();
for (IRInst* child : moduleInst->getChildren())
{
// We want to find all of the functions, and process them
if (auto funcInst = as<IRFunc>(child))
{
if (funcInst->getFirstBlock())
{
_processFunction(funcInst);
}
}
}
}
/// Correct the scope of values in `funcInst` that are defined inside a loop region but
/// used after that loop.
///
/// Every block is mapped to the loops whose region contains it, ordered from innermost to
/// outermost. Every non-`var` instruction in those blocks is then run through
/// `_processInstruction`.
void VariableScopeCorrectionContext::_processFunction(IRFunc* funcInst)
{
    IRDominatorTree* dominatorTree = m_module->findOrCreateDominatorTree(funcInst);

    // Walk the dominator chain of each block to find the loops enclosing it. A block is in a
    // loop region when it is dominated by the block ending in the `IRLoop` terminator but not
    // dominated by that loop's break block. Blocks outside every loop are not recorded:
    // with no enclosing loop, none of their uses can be out of scope.
    Dictionary<IRBlock*, List<IRLoop*>> loopHeaderMap;
    for (auto block : funcInst->getBlocks())
    {
        List<IRLoop*> loopHeaderList;
        for (auto dominatorBlock = dominatorTree->getImmediateDominator(block); dominatorBlock;
             dominatorBlock = dominatorTree->getImmediateDominator(dominatorBlock))
        {
            auto loopHeader = as<IRLoop>(dominatorBlock->getTerminator());
            if (!loopHeader)
                continue;

            if (!dominatorTree->dominates(loopHeader->getBreakBlock(), block))
                loopHeaderList.add(loopHeader);
        }
        if (loopHeaderList.getCount() != 0)
            loopHeaderMap.add(block, loopHeaderList);
    }

    // No block lies inside a loop, so there is nothing to correct.
    if (loopHeaderMap.getCount() == 0)
        return;

    // Seed the work list with every instruction defined inside a loop region.
    List<IRInst*> workList;
    for (auto block : funcInst->getBlocks())
    {
        if (!loopHeaderMap.containsKey(block))
            continue;

        for (auto inst : block->getChildren())
        {
            // Don't process the variable declaration instruction because the code is not
            // emitted for them unless there is a use.
            if (inst->getOp() == kIROp_Var)
                continue;

            workList.add(inst);
        }
    }

    auto instAfterParam = funcInst->getFirstBlock()->getFirstOrdinaryInst();

    // `_processInstruction` may append operands to `workList` while we walk it. Iterate by
    // index and re-read the count every time: a pointer/iterator into the list would dangle
    // once the list grows and reallocates, and appended entries must be visited too.
    for (Index i = 0; i < workList.getCount(); i++)
    {
        IRInst* inst = workList[i];
        if (auto loopHeaderList = loopHeaderMap.tryGetValue(getBlock(inst)))
        {
            _processInstruction(dominatorTree, instAfterParam, inst, *loopHeaderList, workList);
        }
    }
}
// Check if the instruction is used outside of the loop.
// The loopHeaderList contains the loop headers (`IRLoop` instructions) of all loops that enclose
// the block where the original instruction is defined. So if the block of the user instruction is
// dominated by the break block of one of those loops, the use comes after that loop and is
// therefore out of the loop's scope.
// Note that we check the whole loopHeaderList because loops can be nested, so every enclosing
// loop has to be checked, from the innermost to the outermost.
bool VariableScopeCorrectionContext::_isOutOfScopeUse(
    IRInst* userInst,
    IRDominatorTree* domTree,
    const List<IRLoop*>& loopHeaderList)
{
    // A user that is not inside any block cannot be placed relative to a loop.
    auto userBlock = getBlock(userInst);
    if (!userBlock)
        return false;

    // The use escapes as soon as any enclosing loop's break block dominates it.
    for (auto loop : loopHeaderList)
    {
        if (domTree->dominates(loop->getBreakBlock(), userBlock))
            return true;
    }
    return false;
}
/// Find the uses of `originInst` that lie after one of its enclosing loops and rewrite them.
/// Storable values are spilled to a variable declared before `instAfterParam`. Unstorable
/// values are copied to each escaping use, and their operands are appended to `workList`
/// because the copies may now reference loop-local values from outside the loop.
void VariableScopeCorrectionContext::_processInstruction(
    IRDominatorTree* dominatorTree,
    IRInst* instAfterParam,
    IRInst* originInst,
    const List<IRLoop*>& loopHeaderList,
    List<IRInst*>& workList)
{
    List<IRUse*> escapingUses;
    for (IRUse* use = originInst->firstUse; use != nullptr; use = use->nextUse)
    {
        if (_isOutOfScopeUse(use->getUser(), dominatorTree, loopHeaderList))
            escapingUses.add(use);
    }

    if (escapingUses.getCount() == 0)
        return;

    if (!_isStorableType(originInst->getDataType()))
    {
        _processUnstorableInst(originInst, escapingUses);

        // Queue the operands so they get the same treatment later.
        const UInt operandCount = originInst->getOperandCount();
        for (UInt operandIndex = 0; operandIndex < operandCount; ++operandIndex)
            workList.add(originInst->getOperand(operandIndex));
        return;
    }

    _processStorableInst(instAfterParam, originInst, escapingUses);
}
/// Make a storable value available after the loops that enclose it.
///
/// A new variable of the value's type is declared before `insertLoc`. The value is stored
/// into it right after it is defined. Every use in `outOfScopeUses` is redirected to a load
/// of that variable, emitted immediately before the using instruction.
void VariableScopeCorrectionContext::_processStorableInst(
    IRInst* insertLoc,
    IRInst* inst,
    const List<IRUse*>& outOfScopeUses)
{
    auto type = inst->getDataType();
    // store instruction must have a result type
    SLANG_ASSERT(type);

    // Declare the variable that will hold the value.
    m_builder.setInsertBefore(insertLoc);
    auto dstPtr = m_builder.emitVar(type);

    // Store the value right after it is defined. Block parameters must stay grouped at the
    // start of their block, so placing the store directly after an `IRParam` could wedge it
    // between two parameters. For parameters, store before the block's first ordinary
    // instruction instead.
    if (as<IRParam>(inst))
        m_builder.setInsertBefore(getBlock(inst)->getFirstOrdinaryInst());
    else
        m_builder.setInsertAfter(inst);
    m_builder.emitStore(dstPtr, inst);

    // `dstPtr` is a pointer, so each escaping use must read the value through a load placed
    // just before the user; substituting the pointer itself would produce wrong code.
    for (auto use : outOfScopeUses)
    {
        m_builder.setInsertBefore(use->getUser());
        auto loadInst = m_builder.emitLoad(type, dstPtr);
        m_builder.replaceOperand(use, loadInst);
    }
}
/// Make an unstorable value available after the loops that enclose it by re-materializing
/// the defining instruction at each escaping use.
///
/// Each use in `outOfScopeUsers` gets its own copy of `inst`, inserted immediately before
/// the using instruction, and the use is redirected to that copy.
void VariableScopeCorrectionContext::_processUnstorableInst(
    IRInst* inst,
    const List<IRUse*>& outOfScopeUsers)
{
    for (auto use : outOfScopeUsers)
    {
        // One clone per use: a single shared clone could only sit before one user, leaving
        // the other users referencing an instruction that does not dominate them.
        // A fresh clone environment makes the copy reference the original operands. The
        // caller queues those operands so their own escaping uses are corrected afterwards.
        IRCloneEnv cloneEnv;
        m_builder.setInsertBefore(use->getUser());
        auto clonedInst = cloneInst(&cloneEnv, &m_builder, inst);
        m_builder.replaceOperand(use, clonedInst);
    }
}
/// Decide whether a value of `type` can be kept in a local variable on the current target.
/// CPU and CUDA targets accept every type. Other targets accept basic types, vectors,
/// matrices, structs, and sized arrays of storable elements.
bool VariableScopeCorrectionContext::_isStorableType(IRType* type)
{
    if (type == nullptr)
        return false;

    // C/CPP/CUDA can store any type.
    if (isCPUTarget(m_targetReq) || isCUDATarget(m_targetReq))
        return true;

    if (as<IRBasicType>(type))
        return true;

    const auto op = type->getOp();
    if (op == kIROp_VectorType || op == kIROp_MatrixType || op == kIROp_StructType)
        return true;

    // A sized array is storable exactly when its element type is.
    if (op == kIROp_ArrayType)
    {
        auto arrayType = as<IRArrayTypeBase>(type);
        return arrayType && _isStorableType(arrayType->getElementType());
    }

    // Everything else, including unsized arrays, is treated as unstorable.
    return false;
}
} // namespace
/// Entry point of the pass: fix up every function in `module` so that no value defined inside
/// a loop is referenced after that loop, using `targetReq` to decide which types are storable.
void applyVariableScopeCorrection(IRModule* module, TargetRequest* targetReq)
{
    VariableScopeCorrectionContext{module, targetReq}.processModule();
}
} // namespace Slang
|