summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-single-return.cpp
blob: df60098ce362cbe7f3a072a9cd87fe84555d289c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// slang-ir-single-return.cpp
#include "slang-ir-single-return.h"

#include "slang-ir-clone.h"
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-insts.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir.h"

namespace Slang
{

/// Pass that rewrites a function so that it contains exactly one `return`.
///
/// Strategy: the whole body is wrapped in a trivial loop whose break target
/// (`breakBlock`) branches to a new block (`returnBlock`) holding the single
/// `return`. Every original `return` is replaced by a branch to `breakBlock`,
/// forwarding the returned value as a block argument for non-void functions.
///
/// NOTE(review): returns that sat inside inner loops become multi-level breaks.
/// This pass does not invoke the multi-level-break elimination itself (the
/// header is included but nothing from it is called here); presumably callers
/// run it afterwards — confirm against call sites.
struct SingleReturnContext : public InstPassBase
{
    SingleReturnContext(IRModule* inModule)
        : InstPassBase(inModule)
    {
    }

    /// Converts `func` to single-return form in place.
    void processFunc(IRGlobalValueWithCode* func)
    {
        IRBuilder builder(module);
        simplifyCFG(func, CFGSimplificationOptions::getFast());

        // We make use of the `eliminate-multi-level-break` pass to implement the transformation.
        // To be able to do that, we need to prepare `func` so that the entire function body
        // is wrapped in a trivial loop and turn all `return`s into `break`s out of the outermost
        // loop.
        builder.setInsertInto(func);
        auto breakBlock = builder.emitBlock();
        auto returnBlock = builder.emitBlock();
        builder.setInsertInto(breakBlock);
        auto resultType = as<IRFuncType>(func->getDataType())->getResultType();
        const bool returnsVoid = resultType->getOp() == kIROp_VoidType;

        // For non-void functions the break block receives the returned value as a
        // block parameter, which is then forwarded to the single `return`.
        IRInst* retValParam = nullptr;
        if (!returnsVoid)
        {
            retValParam = builder.emitParam(resultType);
        }
        builder.emitBranch(returnBlock);

        auto originalStartBlock = func->getFirstBlock();
        auto loopHeaderBlock = builder.createBlock();
        loopHeaderBlock->insertBefore(originalStartBlock);
        builder.setInsertInto(loopHeaderBlock);

        // Move all params into `loopHeaderBlock`, which is now the entry block.
        // Params are collected first so the list is not mutated while iterating it.
        List<IRParam*> params;
        for (auto param : originalStartBlock->getParams())
        {
            params.add(param);
        }
        for (auto param : params)
        {
            loopHeaderBlock->addParam(param);
        }

        // The original entry block serves as both the loop target and the continue block.
        builder.emitLoop(originalStartBlock, breakBlock, originalStartBlock);

        // Now replace all return insts as break insts.
        // This must run before the final `return` is emitted into `returnBlock`,
        // otherwise that `return` would be rewritten too.
        processChildInstsOfType<IRReturn>(
            kIROp_Return,
            func,
            [&](IRReturn* returnInst)
            {
                // Only look up a value when one is actually forwarded; for void
                // functions this avoids materializing an unused void constant.
                IRInst* retVal = nullptr;
                if (!returnsVoid)
                {
                    if (returnInst->getOperandCount() == 0)
                        retVal = builder.getVoidValue();
                    else
                        retVal = returnInst->getVal();
                }
                builder.setInsertBefore(returnInst);
                if (returnsVoid)
                {
                    builder.emitBranch(breakBlock);
                }
                else
                {
                    builder.emitBranch(breakBlock, 1, &retVal);
                }
                returnInst->removeAndDeallocate();
            });

        builder.setInsertInto(returnBlock);
        if (retValParam)
            builder.emitReturn(retValParam);
        else
            builder.emitReturn();
    }
};

/// Rewrites `func` (belonging to `irModule`) in place so that it ends up with a
/// single `return` instruction, using a one-shot `SingleReturnContext`.
void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* func)
{
    SingleReturnContext{irModule}.processFunc(func);
}

/// Counts the `return` instructions that are direct children of the blocks of `func`.
/// Every child of every block is inspected, not only the block terminators.
int getReturnCount(IRGlobalValueWithCode* func)
{
    int total = 0;
    for (auto bb : func->getBlocks())
        for (auto child : bb->getChildren())
            total += (child->getOp() == kIROp_Return) ? 1 : 0;
    return total;
}

/// True when `func` has exactly one `return` instruction, as counted by `getReturnCount`.
bool isSingleReturnFunc(IRGlobalValueWithCode* func)
{
    const int numReturns = getReturnCount(func);
    return numReturns == 1;
}

} // namespace Slang