1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
|
#include "slang-ir-transform-params-to-constref.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir.h"
namespace Slang
{
struct TransformParamsToConstRefContext
{
IRModule* module;
DiagnosticSink* sink;
IRBuilder builder;
bool changed = false;
TransformParamsToConstRefContext(IRModule* module, DiagnosticSink* sink)
: module(module), sink(sink), builder(module)
{
}
// Check if a type should be transformed (struct, array, or other composite types)
bool shouldTransformParam(IRParam* param)
{
auto type = param->getDataType();
if (!type)
return false;
switch (type->getOp())
{
case kIROp_StructType:
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
case kIROp_VectorType:
case kIROp_MatrixType:
case kIROp_TupleType:
case kIROp_CoopVectorType:
// valid type, continue to check
break;
default:
return false;
}
return true;
}
void rewriteParamUseSitesToSupportConstRefUsage(HashSet<IRParam*>& updatedParams)
{
// Traverse the uses of our updated params to rewrite them.
// Assume a `in` parameter has been converted to a `constref` parameter.
for (auto param : updatedParams)
{
traverseUses(
param,
[&](IRUse* use)
{
auto user = use->getUser();
switch (user->getOp())
{
case kIROp_FieldExtract:
{
// Transform the IRFieldExtract into a IRFieldAddress
auto fieldExtract = as<IRFieldExtract>(user);
builder.setInsertBefore(fieldExtract);
auto fieldAddr = builder.emitFieldAddress(
fieldExtract->getBase(),
fieldExtract->getField());
auto loadInst = builder.emitLoad(fieldAddr);
fieldExtract->replaceUsesWith(loadInst);
fieldExtract->removeAndDeallocate();
break;
}
case kIROp_GetElement:
{
// Transform the IRGetElement into a IRGetElementPtr
auto getElement = as<IRGetElement>(user);
builder.setInsertBefore(getElement);
auto elemAddr = builder.emitElementAddress(
getElement->getBase(),
getElement->getIndex());
auto loadInst = builder.emitLoad(elemAddr);
getElement->replaceUsesWith(loadInst);
getElement->removeAndDeallocate();
break;
}
default:
{
// Insert a load before the user and replace the user with the load
builder.setInsertBefore(user);
auto loadInst = builder.emitLoad(param);
use->set(loadInst);
break;
}
}
});
}
}
// Update call sites to pass an address instead of value for each updated-param
void updateCallSites(IRFunc* func, HashSet<IRParam*>& updatedParams)
{
// Find all calls which use `func`.
List<IRCall*> callsToUpdate;
traverseUsers<IRCall>(func, [&](IRCall* call) { callsToUpdate.add(call); });
// Update each call site
for (auto call : callsToUpdate)
{
builder.setInsertBefore(call);
List<IRInst*> newArgs;
// Transform arguments to match the updated-parameter
UInt i = 0;
for (IRParam* param = func->getFirstParam(); param; param = param->getNextParam(), i++)
{
auto arg = call->getArg(i);
if (!updatedParams.contains(param))
{
newArgs.add(arg);
continue;
}
auto tempVar = builder.emitVar(arg->getFullType());
builder.emitStore(tempVar, arg);
newArgs.add(tempVar);
}
// Create new call with updated arguments
auto newCall = builder.emitCallInst(call->getFullType(), func, newArgs);
call->replaceUsesWith(newCall);
call->removeAndDeallocate();
}
}
// Check if function should be excluded from transformation
bool shouldProcessFunction(IRFunc* func)
{
// Skip functions without definitions
if (!func->isDefinition())
return false;
// Skip if we find any of these decorations
for (auto decoration : func->getDecorations())
{
// Skip functions with target intrinsic decorations.
// These functions cannot be properly legalized after
// transformation.
if (as<IRTargetIntrinsicDecoration>(decoration))
return false;
// Skip entry-point and pseudo-entry-point functions
// since we cannot legalize the input parameters.
if (as<IREntryPointDecoration>(decoration) || as<IRCudaKernelDecoration>(decoration) ||
as<IRAutoPyBindCudaDecoration>(decoration))
return false;
}
// Skip functions with `kIROp_GenericAsm` since
// these instructions inject target specific code
// using parameters in an unpredictable way, relying
// on assumptions that parameters do not change type.
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
{
if (!as<IRGenericAsm>(inst))
continue;
return false;
}
}
return true;
}
// Process a single function
void processFunc(IRFunc* func)
{
HashSet<IRParam*> updatedParams;
// First pass: Transform parameter types
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
{
if (shouldTransformParam(param))
{
// Our goal here is to transform `in T` parameters to const-ref.
// We are selective about what we will transform for a few reasons:
// 1. no reason to transform simple primitives like `int`.
// 2. not every type makes sense as constref. For example, `ParameterBlock`.
// 3. constref is not 100% stable, so we need to be selective on what we let
// transform into constref.
//
// This allows us to pass the address of variables directly into a function,
// giving us the choice to remove copies into a parameter.
auto paramType = param->getDataType();
auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
param->setFullType(constRefType);
changed = true;
updatedParams.add(param);
}
}
if (updatedParams.getCount() == 0)
{
return;
}
// Second pass: Update function body according to the new `constref` parameters
rewriteParamUseSitesToSupportConstRefUsage(updatedParams);
// Third pass: Update call sites
updateCallSites(func, updatedParams);
}
void addFuncsToCallListInTopologicalOrder(
IRFunc* root,
List<IRFunc*>& functionsToProcess,
HashSet<IRFunc*>& visitedCandidates)
{
// We added 'root' already, leave
if (visitedCandidates.contains(root))
return;
visitedCandidates.add(root);
for (auto block : root->getBlocks())
{
for (auto blockInst : block->getChildren())
{
auto call = as<IRCall>(blockInst);
if (!call)
continue;
auto callee = as<IRFunc>(call->getCallee());
if (!callee)
continue;
addFuncsToCallListInTopologicalOrder(callee, functionsToProcess, visitedCandidates);
}
}
if (!shouldProcessFunction(root))
return;
functionsToProcess.add(root);
}
SlangResult processModule()
{
// Collect all functions that need processing.
// Process all callee's before callers; otherwise we introduce bugs
HashSet<IRFunc*> visitedCandidates;
List<IRFunc*> functionsToProcess;
for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst())
{
auto func = as<IRFunc>(inst);
if (!func)
continue;
addFuncsToCallListInTopologicalOrder(func, functionsToProcess, visitedCandidates);
}
// Process each function
for (auto func : functionsToProcess)
{
processFunc(func);
}
return SLANG_OK;
}
};
// Entry point of the pass: builds a fresh transformation context for `module`
// (handing it `sink`) and returns the result of running it over the module.
SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink)
{
    return TransformParamsToConstRefContext(module, sink).processModule();
}
} // namespace Slang
|