summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-undo-param-copy.cpp
blob: d8aac7201b251157cc51b5e0527818c3bac72388 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include "slang-ir-undo-param-copy.h"

#include "slang-ir-dce.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"

namespace Slang
{
// This pass transforms variables decorated with TempCallArgVarDecoration
// by replacing them with direct references to the original parameters.
// This is important for CUDA/OptiX targets where functions like 'IgnoreHit'
// can prevent copy-back operations from executing.
struct UndoParameterCopyVisitor
{
    IRBuilder builder;
    IRModule* module;
    bool changed = false;

    // Track instructions to remove
    List<IRInst*> instsToRemove;

    UndoParameterCopyVisitor(IRModule* module)
        : module(module)
    {
        builder.setInsertInto(module);
    }

    // Process the entire module
    void processModule()
    {
        // Process all functions in the module
        for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst())
        {
            if (auto func = as<IRFunc>(inst))
            {
                processFunc(func);
            }
        }
    }

    // Process a single function
    void processFunc(IRFunc* func)
    {
        instsToRemove.clear();
        HashSet<IRInst*> originalPtrsForCopyBackCandidates; // Tracks original params that might
                                                            // have a redundant copy-back

        // Single pass to identify temps, replace uses, and identify redundant copy-back stores.
        for (auto block = func->getFirstBlock(); block; block = block->getNextBlock())
        {
            for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
            {
                if (auto varInst = as<IRVar>(inst))
                {
                    if (varInst->findDecoration<IRTempCallArgVarDecoration>())
                    {
                        IRStore* initializingStore = nullptr;
                        IRInst* originalParamPtr = nullptr;

                        // Scan for the store that initializes this varInst
                        // This store should be in the same block, after varInst.
                        // The value stored should be an IRLoad from the original parameter pointer.
                        for (auto scanInst = varInst->getNextInst(); scanInst;
                             scanInst = scanInst->getNextInst())
                        {
                            if (auto storeInst = as<IRStore>(scanInst))
                            {
                                if (storeInst->getPtr() == varInst)
                                {
                                    initializingStore = storeInst;
                                    if (auto loadInst = as<IRLoad>(storeInst->getVal()))
                                    {
                                        originalParamPtr = loadInst->getPtr();

                                        // Found the pattern: var, store(var, load(originalParam))
                                        this->changed = true;

                                        // Replace uses of varInst with originalParamPtr immediately
                                        varInst->replaceUsesWith(originalParamPtr);

                                        // Mark for removal
                                        instsToRemove.add(initializingStore);
                                        instsToRemove.add(varInst);

                                        // Record originalParamPtr for copy-back optimization check
                                        originalPtrsForCopyBackCandidates.add(originalParamPtr);
                                    }
                                    break; // Found the initializing store for varInst
                                }
                            }
                            // Stop scanning if another var declaration or a call is encountered
                            if (as<IRVar>(scanInst) || as<IRCall>(scanInst))
                            {
                                break;
                            }
                        }
                    }
                }
                else if (auto storeInst = as<IRStore>(inst))
                {
                    // Check for redundant copy-back: store(originalParam, load(originalParam))
                    IRInst* destPtr = storeInst->getPtr();
                    if (originalPtrsForCopyBackCandidates.contains(destPtr))
                    {
                        if (auto loadVal = as<IRLoad>(storeInst->getVal()))
                        {
                            if (loadVal->getPtr() == destPtr)
                            {
                                // This is a redundant copy-back store
                                instsToRemove.add(storeInst);
                                this->changed = true;
                            }
                        }
                    }
                }
            }
        }

        // Removal pass
        for (auto& inst : instsToRemove)
        {
            if (inst->getParent())
            {
                inst->removeAndDeallocate();
            }
        }
    }
};

// Entry point of the pass: runs UndoParameterCopyVisitor over `module` and, only
// if the visitor reports that it rewrote something, follows up with dead-code
// elimination to sweep away instructions the rewrite left unused.
void undoParameterCopy(IRModule* module)
{
    UndoParameterCopyVisitor pass(module);
    pass.processModule();

    // Nothing was rewritten, so there is nothing new for DCE to clean up.
    if (!pass.changed)
        return;

    eliminateDeadCode(module);
}
} // namespace Slang