1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
#include "slang-ir-specialize-dispatch.h"
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
namespace Slang
{
// Look up the value that `table` provides for the requirement `key`.
//
// Entries are scanned in order and the requirement key is compared by
// pointer identity; the satisfying value of the first matching entry is
// returned. Returns nullptr when `table` has no entry for `key`.
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
{
    IRInst* satisfyingVal = nullptr;
    for (auto candidate : table->getEntries())
    {
        if (candidate->getRequirementKey() != key)
            continue;
        satisfyingVal = candidate->getSatisfyingVal();
        break;
    }
    return satisfyingVal;
}
// Rewrite `dispatchFunc` so that, instead of looking up an interface
// requirement in its witness-table argument, it compares that argument
// against every witness table of the matching type in the module and calls
// the corresponding concrete implementation directly.
//
// Expected input shape: a single block whose first parameter is the witness
// table, containing
//     call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
// followed by a return. The remaining block parameters are forwarded
// unchanged as `args` to the concrete callee.
//
// Generated shape: a cascade of `if` statements. Every candidate except the
// last is guarded by a pointer-equality test on the witness table; the last
// candidate is reached unconditionally on the final `else` path.
//
// If the module contains no witness table of the parameter's type, the
// function is left unchanged.
void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
    auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);

    // Collect all witness tables of `witnessTableType` in current module.
    // Types are compared by pointer identity.
    List<IRWitnessTable*> witnessTables;
    for (auto globalInst : sharedContext->module->getGlobalInsts())
    {
        if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType)
        {
            witnessTables.add(cast<IRWitnessTable>(globalInst));
        }
    }

    // With no candidates there is nothing to dispatch to. Rewriting anyway
    // would delete the call and return without emitting a replacement,
    // leaving the block without a terminator, so keep the original body.
    const Index tableCount = witnessTables.getCount();
    if (tableCount == 0)
        return;

    SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
    auto block = dispatchFunc->getFirstBlock();

    // The dispatch function before modification must be in the form of
    // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
    // We now find the relevant instructions.
    IRCall* callInst = nullptr;
    IRLookupWitnessMethod* lookupInst = nullptr;
    IRReturn* returnInst = nullptr;
    for (auto inst : block->getOrdinaryInsts())
    {
        switch (inst->op)
        {
        case kIROp_Call:
            callInst = cast<IRCall>(inst);
            break;
        case kIROp_lookup_interface_method:
            lookupInst = cast<IRLookupWitnessMethod>(inst);
            break;
        case kIROp_ReturnVal:
        case kIROp_ReturnVoid:
            returnInst = cast<IRReturn>(inst);
            break;
        default:
            break;
        }
    }
    SLANG_ASSERT(callInst && lookupInst && returnInst);

    IRBuilder builderStorage;
    auto builder = &builderStorage;
    builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
    builder->setInsertBefore(callInst);

    auto witnessTableParam = block->getFirstParam();
    auto requirementKey = lookupInst->getRequirementKey();

    // Every parameter after the witness table is forwarded to the callee.
    List<IRInst*> params;
    for (auto param = witnessTableParam->getNextParam(); param; param = param->getNextParam())
    {
        params.add(param);
    }

    // Only non-last candidates need a comparison. Convert the parameter once,
    // here in the entry block: it dominates every block created below, so all
    // comparisons can share the converted value.
    IRInst* witnessTableParamAsUInt64 = nullptr;
    if (tableCount > 1)
        witnessTableParamAsUInt64 = builder->emitBitCast(builder->getUInt64Type(), witnessTableParam);

    // Emit cascaded if statements to call the correct concrete function based on
    // the witness table pointer passed in.
    auto ifBlock = block;
    for (Index i = 0; i < tableCount; i++)
    {
        auto witnessTable = witnessTables[i];
        bool isLast = (i == tableCount - 1);
        IRInst* condition = nullptr;
        IRBlock* trueBlock = nullptr;
        if (!isLast)
        {
            IRInst* cmpArgs[] =
            {
                witnessTableParamAsUInt64,
                builder->emitBitCast(builder->getUInt64Type(), static_cast<IRInst*>(witnessTable))
            };
            condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs);
            // `emitBlock` also redirects the builder into the new block, so
            // the call/return below land in the `true` branch.
            trueBlock = builder->emitBlock();
        }
        auto callee = findWitnessTableEntry(witnessTable, requirementKey);
        SLANG_ASSERT(callee);
        auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
        if (callInst->getDataType()->op == kIROp_VoidType)
            builder->emitReturn();
        else
            builder->emitReturn(specializedCallInst);
        if (!isLast)
        {
            // Terminate the block that performed the comparison with the
            // branch, then continue the cascade in the `false` block.
            auto falseBlock = builder->emitBlock();
            builder->setInsertInto(ifBlock);
            builder->emitIf(condition, trueBlock, falseBlock);
            builder->setInsertInto(falseBlock);
            ifBlock = falseBlock;
        }
    }

    // Remove old implementation. Users go before the values they use:
    // the return may use the call's result, and the call uses the lookup as
    // its callee. Removing in the opposite order would deallocate a value
    // while a live instruction still references it.
    returnInst->removeAndDeallocate();
    callInst->removeAndDeallocate();
    lookupInst->removeAndDeallocate();
}
// Specialize every dispatch function registered in
// `sharedContext->mapInterfaceRequirementKeyToDispatchMethods`.
//
// The shared builder's global numbering map is deduplicated and rebuilt
// first. NOTE(review): presumably this is so the pointer comparison of
// witness-table types inside `specializeDispatchFunction` sees a single
// canonical instance per type; confirm against the builder's documentation.
void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
{
    auto& sharedBuilder = sharedContext->sharedBuilderStorage;
    sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();

    for (const auto& keyAndDispatchFunc : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
        specializeDispatchFunction(sharedContext, keyAndDispatchFunc.Value);
}
} // namespace Slang
|