summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-propagate.h
blob: 71cb17a08105e23542747d8156b402b9f9b5329f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// slang-ir-autodiff-propagate.h
#pragma once

#include "slang-compiler.h"
#include "slang-ir-autodiff.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"

namespace Slang
{

/// Returns true when `inst` carries an `IRDifferentialInstDecoration`.
inline bool isDifferentialInst(IRInst* inst)
{
    auto diffDecoration = inst->findDecoration<IRDifferentialInstDecoration>();
    return diffDecoration != nullptr;
}

/// Returns true when `inst` is treated as primal: it either carries an
/// `IRPrimalInstDecoration` or is an `IRConstant` (constants always count as primal).
inline bool isPrimalInst(IRInst* inst)
{
    if (inst->findDecoration<IRPrimalInstDecoration>())
        return true;
    return as<IRConstant>(inst) != nullptr;
}

/// Returns true when `inst` carries an `IRMixedDifferentialInstDecoration`.
inline bool isMixedDifferentialInst(IRInst* inst)
{
    auto mixedDecoration = inst->findDecoration<IRMixedDifferentialInstDecoration>();
    return mixedDecoration != nullptr;
}

/// Pass that propagates `IRDifferentialInstDecoration` forward through the IR.
///
/// Seeds are the `DifferentialPairGetDifferential` insts found under a parent
/// inst. From there, any user that has at least one differential operand is
/// itself marked differential, and its users are examined in turn.
struct DiffPropagationPass : InstPassBase
{
    // Shared autodiff state supplied by the caller. Non-owning: this pass never frees it.
    AutoDiffSharedContext* autodiffContext;

    // Base subobjects are always initialized before data members, whatever order
    // is written here. So the base is listed first to match the real initialization
    // order (and avoid -Wreorder). The base initializer reads the constructor
    // *parameter* `autodiffContext`, which shadows the member of the same name.
    DiffPropagationPass(AutoDiffSharedContext* autodiffContext)
        : InstPassBase(autodiffContext->moduleInst->getModule()), autodiffContext(autodiffContext)
    {
    }

    /// Returns true if at least one operand of `inst` is marked differential.
    bool shouldInstBeMarkedDifferential(IRInst* inst)
    {
        for (UIndex ii = 0; ii < inst->getOperandCount(); ii++)
        {
            if (isDifferentialInst(inst->getOperand(ii)))
            {
                return true;
            }
        }

        return false;
    }

    /// Adds every user of `inst` that is not already marked differential to the
    /// work list inherited from `InstPassBase`.
    // NOTE(review): a user that references `inst` through several operands is
    // offered once per use; presumably `addToWorkList` deduplicates via
    // `workListSet` — confirm in InstPassBase.
    void addPendingUsersToWorkList(IRInst* inst)
    {
        auto use = inst->firstUse;
        while (use)
        {
            if (!isDifferentialInst(use->getUser()))
            {
                addToWorkList(use->getUser());
            }
            use = use->nextUse;
        }
    }

    /// Propagates `IRDifferentialInstDecoration` starting from the children of
    /// `instWithChildren`.
    ///
    /// 1. Every `DifferentialPairGetDifferential` inst visited by
    ///    `processChildInstsOfType` under `instWithChildren` is marked
    ///    differential and becomes a seed.
    /// 2. The inherited work list is reset and seeded with the users of those insts.
    /// 3. Each popped inst that is not yet differential but has a differential
    ///    operand is marked, and its own users are queued.
    ///
    /// @param builder           Used to attach the differential decoration.
    /// @param instWithChildren  Parent whose children are scanned for seeds.
    void propagateDiffInstDecoration(IRBuilder* builder, IRInst* instWithChildren)
    {
        List<IRInst*> initialList;
        // Mark 'GetDifferential' insts as differential.
        processChildInstsOfType<IRDifferentialPairGetDifferential>(
            kIROp_DifferentialPairGetDifferential,
            instWithChildren,
            [&](IRDifferentialPairGetDifferential* getDifferentialInst)
            {
                builder->markInstAsDifferential(getDifferentialInst);
                initialList.add(getDifferentialInst);
            });

        // Discard any state left over from a previous use of the inherited work list.
        workList.clear();
        workListSet.clear();

        // Seed the work list with the users of the marked insts.
        for (auto inst : initialList)
        {
            // Every seed was just marked above; the check is kept as a safeguard.
            if (isDifferentialInst(inst))
                addPendingUsersToWorkList(inst);
        }

        // Propagate to all users. Users may live outside `instWithChildren`;
        // nothing here restricts propagation to that subtree.
        while (workList.getCount() != 0)
        {
            IRInst* inst = pop();

            // Skip if this is already a differential inst.
            if (isDifferentialInst(inst))
            {
                continue;
            }

            if (shouldInstBeMarkedDifferential(inst))
            {
                builder->markInstAsDifferential(inst);
                addPendingUsersToWorkList(inst);
            }
        }
    }
};

} // namespace Slang