1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
|
// slang-ir-lower-generic-call.cpp
#include "slang-ir-lower-generic-call.h"
#include "slang-ir-generics-lowering-context.h"
namespace Slang
{
// Per-pass state for lowering calls whose callee is generic.
//
// The pass rewrites two call shapes:
//   - `call(specialize(gFunc, Targs...), args...)` becomes
//     `call(gFunc, args..., Targs'...)`, where each type argument is replaced by a
//     handle to an RTTI object for that type.
//   - `call(lookup_interface_method(witnessTable, key), args...)` becomes
//     `call(dispatchFunc, witnessTable, args...)`, where `dispatchFunc` is a
//     per-requirement dispatch function created (and cached) by this pass.
// Along the way, arguments are packed into `AnyValue` where the lowered callee
// expects it. Results (the return value and `out`/`inout` arguments) are unpacked
// back into their concrete types.
struct GenericCallLoweringContext
{
    // State shared with the other generics-lowering passes: the work list, the shared
    // IR builder storage, RTTI objects and the dispatch-method cache. Not owned by
    // this struct.
    SharedGenericsLoweringContext* sharedContext;

    // Represents a work item for unpacking `inout` or `out` arguments after a generic call.
    struct ArgumentUnpackWorkItem
    {
        // Concrete typed destination (a pointer to the caller's original variable).
        IRInst* dstArg = nullptr;
        // Temporary variable holding the packed `AnyValue` passed to the callee.
        IRInst* packedArg = nullptr;
    };

    // Packs `arg` into an `IRAnyValue` if necessary, so that it can be fed into a
    // parameter of type `paramType`.
    //
    // - If the parameter's value type is `AnyValue` and the argument's value type is
    //   not, the argument value is loaded (if `arg` is a pointer) and packed.
    // - If the parameter itself is a pointer (`out`/`inout`), the packed value is
    //   stored into a fresh temporary variable and that variable is returned instead.
    //   `unpackAfterCall` is then filled in, so the caller knows it must copy the
    //   temporary back into `arg` after the call.
    // - Otherwise `arg` is returned unchanged, `unpackAfterCall` is cleared, and no
    //   instructions are emitted.
    //
    // Instructions are emitted at `builder`'s current insertion point.
    IRInst* maybePackArgument(
        IRBuilder* builder,
        IRType* paramType,
        IRInst* arg,
        ArgumentUnpackWorkItem& unpackAfterCall)
    {
        unpackAfterCall.dstArg = nullptr;
        unpackAfterCall.packedArg = nullptr;

        // If either paramType or argType is a pointer type (because of `inout` or
        // `out` modifiers), look through it to the underlying value type.
        IRType* paramValType = paramType;
        auto paramPtrType = as<IRPtrTypeBase>(paramType);
        if (paramPtrType)
            paramValType = paramPtrType->getValueType();

        IRType* argValType = arg->getDataType();
        auto argPtrType = as<IRPtrTypeBase>(argValType);
        if (argPtrType)
            argValType = argPtrType->getValueType();

        // Packing is only needed when the parameter expects AnyValue but `arg` is
        // not an AnyValue.
        if (!as<IRAnyValueType>(paramValType) || as<IRAnyValueType>(argValType))
            return arg;

        // Load through `arg` only now that its value is actually needed. Loading
        // unconditionally would leave a dead load before every pointer argument.
        IRInst* argVal = argPtrType ? builder->emitLoad(arg) : arg;
        auto packedArgVal = builder->emitPackAnyValue(paramValType, argVal);

        if (!paramPtrType)
            return packedArgVal;

        // The parameter expects a pointer: store the packed value into a temporary
        // variable and pass a pointer to that variable instead. The temporary needs
        // to be unpacked into the original variable after the call.
        auto tempVar = builder->emitVar(paramValType);
        builder->emitStore(tempVar, packedArgVal);
        unpackAfterCall.dstArg = arg;
        unpackAfterCall.packedArg = tempVar;
        return tempVar;
    }

    // If `value` has type `actualType` == AnyValue but the consumer expects a
    // non-AnyValue `expectedType`, emits an unpack and returns it. Otherwise returns
    // `value` unchanged.
    IRInst* maybeUnpackValue(IRBuilder* builder, IRType* expectedType, IRType* actualType, IRInst* value)
    {
        if (as<IRAnyValueType>(actualType) && !as<IRAnyValueType>(expectedType))
            return builder->emitUnpackAnyValue(expectedType, value);
        return value;
    }

    // Creates a dispatch function for an interface method.
    //
    // The function takes the witness table as its first parameter, followed by the
    // parameters of the requirement's function type. Its body looks the requirement
    // up in the witness table and calls the result, returning the call's value
    // (if non-void).
    //
    // On CPU, the dispatch function is implemented as a witness table lookup followed by
    // a function-pointer call.
    // On GPU targets, we can modify the body of the dispatch function in a follow-up
    // pass to implement it with a `switch` statement based on the type ID.
    //
    // Note: this leaves `builder` positioned inside the new function's body.
    IRFunc* _createInterfaceDispatchMethod(
        IRBuilder* builder,
        IRInterfaceType* interfaceType,
        IRInst* requirementKey,
        IRInst* requirementVal)
    {
        auto func = builder->createFunc();
        if (auto linkage = requirementKey->findDecoration<IRLinkageDecoration>())
        {
            builder->addNameHintDecoration(func, linkage->getMangledName());
        }

        // Signature: (WitnessTable(interfaceType), <requirement params>...) -> <requirement result>
        auto reqFuncType = cast<IRFuncType>(requirementVal);
        List<IRType*> paramTypes;
        paramTypes.add(builder->getWitnessTableType(interfaceType));
        for (UInt i = 0; i < reqFuncType->getParamCount(); i++)
        {
            paramTypes.add(reqFuncType->getParamType(i));
        }
        auto dispatchFuncType = builder->getFuncType(paramTypes, reqFuncType->getResultType());
        func->setFullType(dispatchFuncType);

        builder->setInsertInto(func);
        builder->emitBlock();

        // The witness table parameter is emitted first; the remaining parameters are
        // forwarded verbatim to the looked-up method.
        IRParam* witnessTableParam = builder->emitParam(paramTypes[0]);
        List<IRInst*> params;
        for (Index i = 1; i < paramTypes.getCount(); i++)
        {
            params.add(builder->emitParam(paramTypes[i]));
        }

        auto callee = builder->emitLookupInterfaceMethodInst(
            reqFuncType, witnessTableParam, requirementKey);
        auto call = builder->emitCallInst(reqFuncType->getResultType(), callee, params);
        if (call->getDataType()->getOp() == kIROp_VoidType)
            builder->emitReturn();
        else
            builder->emitReturn(call);
        return func;
    }

    // Returns the dispatch method cached in the shared context for `requirementKey`,
    // creating and caching it first if it does not exist yet.
    IRFunc* getOrCreateInterfaceDispatchMethod(
        IRBuilder* builder,
        IRInterfaceType* interfaceType,
        IRInst* requirementKey,
        IRInst* requirementVal)
    {
        if (auto func = sharedContext->mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey))
            return *func;

        auto dispatchFunc =
            _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal);
        sharedContext->mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists(
            requirementKey, dispatchFunc);
        return dispatchFunc;
    }

    // Translates `callInst` into a call of `newCallee`, respecting the new `funcType`.
    // If `newCallee` is a lowered generic function, `specializeInst` contains the type
    // arguments used to specialize the callee; they are appended as extra arguments.
    // `callInst` is replaced and deallocated.
    void translateCallInst(
        IRCall* callInst,
        IRFuncType* funcType,
        IRInst* newCallee,
        IRSpecialize* specializeInst)
    {
        List<IRType*> paramTypes;
        for (UInt i = 0; i < funcType->getParamCount(); i++)
            paramTypes.add(funcType->getParamType(i));

        IRBuilder builderStorage;
        auto builder = &builderStorage;
        builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
        builder->setInsertBefore(callInst);

        // Process the argument list of the call.
        // For each argument, we test whether it needs to be packed into an `AnyValue`
        // for the call. `out` and `inout` arguments may also need to be unpacked after
        // the call; such arguments are recorded in `argsToUnpack` so they can be
        // processed after the new call inst is emitted.
        List<IRInst*> args;
        List<ArgumentUnpackWorkItem> argsToUnpack;
        for (UInt i = 0; i < callInst->getArgCount(); i++)
        {
            auto arg = callInst->getArg(i);
            ArgumentUnpackWorkItem unpackWorkItem;
            auto newArg = maybePackArgument(builder, paramTypes[i], arg, unpackWorkItem);
            args.add(newArg);
            if (unpackWorkItem.packedArg)
                argsToUnpack.add(unpackWorkItem);
        }

        // Append the specialization arguments after the ordinary arguments.
        if (specializeInst)
        {
            for (UInt i = 0; i < specializeInst->getArgCount(); i++)
            {
                auto arg = specializeInst->getArg(i);
                if (as<IRType>(arg))
                {
                    // We are using a simple type to specialize a callee.
                    // Generate an RTTI object for this type and pass a handle to it.
                    auto rttiObject = sharedContext->maybeEmitRTTIObject(arg);
                    arg = builder->emitGetAddress(
                        builder->getRTTIHandleType(),
                        rttiObject);
                }
                else if (arg->getOp() == kIROp_Specialize)
                {
                    // The type argument used to specialize a callee is itself a
                    // specialization of some generic type.
                    // TODO: generate RTTI object for specializations of generic types.
                    SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
                }
                else if (arg->getOp() == kIROp_RTTIObject)
                {
                    // We are inside a generic function and using a generic parameter
                    // to specialize another callee. The generic parameter of the caller
                    // has already been translated into an RTTI object, so we just need
                    // to pass this object down.
                }
                args.add(arg);
            }
        }

        // If the callee returns `AnyValue` but we are expecting a concrete value, unpack it.
        auto calleeRetType = funcType->getResultType();
        auto newCall = builder->emitCallInst(calleeRetType, newCallee, args);
        auto callInstType = callInst->getDataType();
        auto unpackInst = maybeUnpackValue(builder, callInstType, calleeRetType, newCall);

        // Copy packed `out`/`inout` temporaries back into the caller's variables.
        for (auto& item : argsToUnpack)
        {
            auto packedVal = builder->emitLoad(item.packedArg);
            auto originalValType = cast<IRPtrTypeBase>(item.dstArg->getDataType())->getValueType();
            auto unpackedVal = builder->emitUnpackAnyValue(originalValType, packedVal);
            builder->emitStore(item.dstArg, unpackedVal);
        }

        callInst->replaceUsesWith(unpackInst);
        callInst->removeAndDeallocate();
    }

    // Follows a chain of nested `specialize` instructions and returns the base of
    // the innermost one.
    IRInst* findInnerMostSpecializingBase(IRSpecialize* inst)
    {
        auto result = inst->getBase();
        while (auto specialize = as<IRSpecialize>(result))
            result = specialize->getBase();
        return result;
    }

    // Lowers `call(specialize(gFunc, Targs), args)` into `call(gFunc, args, Targs)`.
    void lowerCallToSpecializedFunc(IRCall* callInst, IRSpecialize* specializeInst)
    {
        auto loweredFunc = specializeInst->getBase();

        // All callees should have already been lowered in the lower-generic-functions
        // pass. Intrinsic generic functions are left as is, and we also need to ignore
        // them here.
        if (loweredFunc->getOp() == kIROp_Generic)
        {
            return;
        }
        else if (loweredFunc->getOp() == kIROp_Specialize)
        {
            // All nested generic functions are supposed to be flattened before this pass.
            // If they are not, they represent an intrinsic function that should not be
            // modified in this pass.
            auto innerMostFunc = findInnerMostSpecializingBase(static_cast<IRSpecialize*>(loweredFunc));
            if (innerMostFunc && innerMostFunc->getOp() == kIROp_Generic)
            {
                innerMostFunc =
                    findInnerMostGenericReturnVal(static_cast<IRGeneric*>(innerMostFunc));
            }
            // Guard against a null innermost value so we report the unexpected case
            // instead of dereferencing null.
            if (innerMostFunc && innerMostFunc->findDecoration<IRTargetIntrinsicDecoration>())
                return;
            SLANG_UNEXPECTED("Nested generics specialization.");
        }
        else if (loweredFunc->getOp() == kIROp_lookup_interface_method)
        {
            lowerCallToInterfaceMethod(
                callInst, cast<IRLookupWitnessMethod>(loweredFunc), specializeInst);
            return;
        }

        IRFuncType* funcType = cast<IRFuncType>(loweredFunc->getDataType());
        translateCallInst(callInst, funcType, loweredFunc, specializeInst);
    }

    // Lowers `call(lookup_interface_method(witnessTable, key), args)` into a call to the
    // dispatch function for `key`, passing `witnessTable` as the first argument.
    // Calls on builtin interfaces are left untouched.
    void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst, IRSpecialize* specializeInst)
    {
        auto interfaceType = cast<IRInterfaceType>(
            cast<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())
                ->getConformanceType());
        if (isBuiltin(interfaceType))
            return;

        IRBuilder builderStorage;
        auto builder = &builderStorage;
        builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
        builder->setInsertBefore(callInst);

        // Create the interface dispatch method that bottlenecks the dispatch logic.
        auto requirementKey = lookupInst->getRequirementKey();
        auto requirementVal =
            sharedContext->findInterfaceRequirementVal(interfaceType, requirementKey);
        auto dispatchFunc = getOrCreateInterfaceDispatchMethod(
            builder, interfaceType, requirementKey, requirementVal);

        // The dispatch function's own body contains a call through
        // `lookup_interface_method`; that call must not be redirected to itself.
        auto parentFunc = getParentFunc(callInst);
        if (parentFunc == dispatchFunc)
            return;

        // Creating the dispatch function may have moved the builder's insertion point
        // into its body, so reset it before emitting the replacement call.
        builder->setInsertBefore(callInst);

        // Replace `callInst` with a call to `dispatchFunc`, with the witness table as
        // the first argument.
        List<IRInst*> newArgs;
        newArgs.add(lookupInst->getWitnessTable());
        for (UInt i = 0; i < callInst->getArgCount(); i++)
            newArgs.add(callInst->getArg(i));
        auto newCall =
            cast<IRCall>(builder->emitCallInst(callInst->getFullType(), dispatchFunc, newArgs));
        callInst->replaceUsesWith(newCall);
        callInst->removeAndDeallocate();

        // Translate the new call inst as normal, taking care of packing/unpacking inputs
        // and outputs.
        translateCallInst(
            newCall,
            cast<IRFuncType>(dispatchFunc->getFullType()),
            dispatchFunc,
            specializeInst);
    }

    // Dispatches a call according to the shape of its callee. Calls to anything other
    // than `specialize` or `lookup_interface_method` are left alone.
    void lowerCall(IRCall* callInst)
    {
        if (auto specializeInst = as<IRSpecialize>(callInst->getCallee()))
            lowerCallToSpecializedFunc(callInst, specializeInst);
        else if (auto lookupInst = as<IRLookupWitnessMethod>(callInst->getCallee()))
            lowerCallToInterfaceMethod(callInst, lookupInst, nullptr);
    }

    // Processes a single instruction from the work list; only calls are rewritten.
    void processInst(IRInst* inst)
    {
        if (auto callInst = as<IRCall>(inst))
        {
            lowerCall(callInst);
        }
    }

    // Walks every instruction reachable from the module inst via the shared work list
    // and lowers each generic call encountered.
    void processModule()
    {
        // We start by initializing our shared IR building state,
        // since we will re-use that state for any code we
        // generate along the way.
        //
        SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
        sharedBuilder->module = sharedContext->module;
        sharedBuilder->session = sharedContext->module->session;

        sharedContext->addToWorkList(sharedContext->module->getModuleInst());
        while (sharedContext->workList.getCount() != 0)
        {
            IRInst* inst = sharedContext->workList.getLast();
            sharedContext->workList.removeLast();
            sharedContext->workListSet.Remove(inst);

            processInst(inst);

            // Push children last-to-first so they are popped in source order.
            for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
            {
                sharedContext->addToWorkList(child);
            }
        }
    }
};
// Entry point of the pass: binds a fresh per-pass context to the shared
// generics-lowering state and lowers every generic call in the module.
void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext)
{
    GenericCallLoweringContext callLowering;
    callLowering.sharedContext = sharedContext;
    callLowering.processModule();
}
}
|