summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-specialize-dispatch.cpp
blob: 0c519427d27ae97c235307a053bf0d916d4e4c10 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include "slang-ir-specialize-dispatch.h"

#include "slang-ir-generics-lowering-context.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"

namespace Slang
{
/// Searches `table` for the entry whose requirement key is `key`.
///
/// Entries are visited in the order `getEntries()` yields them; the satisfying
/// value of the first match is returned, or `nullptr` when no entry matches.
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
{
    IRInst* satisfyingVal = nullptr;
    for (auto tableEntry : table->getEntries())
    {
        if (tableEntry->getRequirementKey() != key)
            continue;

        // First match wins; stop scanning.
        satisfyingVal = tableEntry->getSatisfyingVal();
        break;
    }
    return satisfyingVal;
}

/// Rewrites the body of `dispatchFunc` so that it no longer looks its callee up
/// through the witness table passed as its first parameter. Instead, the
/// parameter is compared against every witness table of the matching type found
/// among the module's global instructions, and the concrete function registered
/// in the matching table is called directly.
///
/// The function is expected to consist of a single block of the form
///     call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
/// followed by a return. After the rewrite the block holds a chain of
/// `if (param == table_i) return callee_i(args...)` tests, with the last table
/// handled unconditionally as the final `else` branch.
///
/// If no witness table of the required type exists, the function is left
/// unchanged.
///
/// @param sharedContext  Provides the module to scan and the shared IR builder state.
/// @param dispatchFunc   The dispatch function to rewrite in place.
void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
    auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);

    // Collect all witness tables of `witnessTableType` in current module.
    List<IRWitnessTable*> witnessTables;
    for (auto globalInst : sharedContext->module->getGlobalInsts())
    {
        if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType)
        {
            witnessTables.add(cast<IRWitnessTable>(globalInst));
        }
    }

    // Without any candidate table the loop below would emit nothing, yet the
    // original lookup/call/return would still be deleted, leaving the body
    // with no call and no return. Keep the existing implementation instead.
    const Index tableCount = witnessTables.getCount();
    if (tableCount == 0)
        return;

    SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
    auto block = dispatchFunc->getFirstBlock();

    // The dispatch function before modification must be in the form of
    // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
    // We now find the relevant instructions.
    IRCall* callInst = nullptr;
    IRLookupWitnessMethod* lookupInst = nullptr;
    IRReturn* returnInst = nullptr;
    for (auto inst : block->getOrdinaryInsts())
    {
        switch (inst->op)
        {
        case kIROp_Call:
            callInst = cast<IRCall>(inst);
            break;
        case kIROp_lookup_interface_method:
            lookupInst = cast<IRLookupWitnessMethod>(inst);
            break;
        case kIROp_ReturnVal:
        case kIROp_ReturnVoid:
            returnInst = cast<IRReturn>(inst);
            break;
        default:
            break;
        }
    }
    SLANG_ASSERT(callInst && lookupInst && returnInst);

    IRBuilder builderStorage;
    auto builder = &builderStorage;
    builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
    builder->setInsertBefore(callInst);

    // The first block parameter is the witness table; every parameter after it
    // is forwarded unchanged to the concrete callee.
    auto witnessTableParam = block->getFirstParam();
    auto requirementKey = lookupInst->getRequirementKey();
    List<IRInst*> params;
    for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam())
    {
        params.add(param);
    }

    // The result type of the original call does not change while we emit the
    // replacement code, so decide once whether the returns carry a value.
    const bool returnsVoid = (callInst->getDataType()->op == kIROp_VoidType);

    // Emit cascaded if statements to call the correct concrete function based on
    // the witness table pointer passed in.
    auto ifBlock = block;
    for (Index i = 0; i < tableCount; i++)
    {
        auto witnessTable = witnessTables[i];
        const bool isLast = (i == tableCount - 1);
        IRInst* condition = nullptr;
        IRBlock* trueBlock = nullptr;
        if (!isLast)
        {
            // Compare the incoming table against this candidate as 64-bit
            // integers. The operands are only materialized when a comparison
            // is actually emitted, so the unconditional last case leaves no
            // unused bit-casts behind.
            IRInst* cmpArgs[] =
            {
                builder->emitBitCast(builder->getUInt64Type(), witnessTableParam),
                builder->emitBitCast(builder->getUInt64Type(), static_cast<IRInst*>(witnessTable))
            };
            condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs);
            // `emitBlock` also moves the insertion point into the new block,
            // so the call/return below land in the "true" branch.
            trueBlock = builder->emitBlock();
        }
        auto callee = findWitnessTableEntry(witnessTable, requirementKey);
        SLANG_ASSERT(callee);
        auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
        if (returnsVoid)
            builder->emitReturn();
        else
            builder->emitReturn(specializedCallInst);
        if (!isLast)
        {
            // Terminate the block holding the comparison with the branch, then
            // continue emitting the remaining cases inside the "false" block.
            auto falseBlock = builder->emitBlock();
            builder->setInsertInto(ifBlock);
            builder->emitIf(condition, trueBlock, falseBlock);
            builder->setInsertInto(falseBlock);
            ifBlock = falseBlock;
        }
    }

    // Remove old implementation.
    lookupInst->removeAndDeallocate();
    callInst->removeAndDeallocate();
    returnInst->removeAndDeallocate();
}

/// Specializes every dispatch function registered in
/// `sharedContext->mapInterfaceRequirementKeyToDispatchMethods`.
///
/// The shared builder's global numbering map is deduplicated and rebuilt once
/// up front, before any dispatch function is rewritten.
void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
{
    auto& sharedBuilder = sharedContext->sharedBuilderStorage;
    sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();

    for (auto keyAndFunc : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
        specializeDispatchFunction(sharedContext, keyAndFunc.Value);
}
} // namespace Slang