summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-call.cpp
blob: a574d6b7e33eab4f8c047d173e081d0b999c817d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// slang-ir-diff-call.cpp
#include "slang-ir-diff-call.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"

namespace Slang
{

struct DerivativeCallProcessContext
{
    // This type passes over the module and replaces
    // derivative calls with the processed derivative 
    // function.
    //
    IRModule*                       module;

    bool processModule()
    {
        // Run through all the global-level instructions, 
        // looking for callable blocks.
        for (auto inst : module->getGlobalInsts())
        {
            // If the instr is a callable, get all the basic blocks
            if (auto callable = as<IRGlobalValueWithCode>(inst))
            {
                // Iterate over each block in the callable
                for (auto block : callable->getBlocks())
                {
                    // Iterate over each child instruction.
                    auto child = block->getFirstInst();
                    if (!child) continue;

                    do 
                    {
                        auto nextChild = child->getNextInst();
                        // Look for IRForwardDifferentiate
                        if (auto derivOf = as<IRForwardDifferentiate>(child))
                        {
                            processDifferentiate(derivOf);
                        }
                        child = nextChild;
                    } 
                    while (child);
                }
            }
        }
        return true;
    }

    // Perform forward-mode automatic differentiation on 
    // the intstructions.
    void processDifferentiate(IRForwardDifferentiate* derivOfInst)
    {
        IRInst* jvpCallable = nullptr;

        // First get base function 
        auto origCallable = derivOfInst->getBaseFn();

        // Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc))
        // Check the specialize inst for a reference to the derivative fn.
        // 
        if (auto origSpecialize = as<IRSpecialize>(origCallable))
        {
            if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
            {
                jvpCallable = jvpSpecRefDecorator->getForwardDerivativeFunc();
            }
        }

        // Resolve the derivative function for an IRForwardDifferentiate(IRFunc)
        //
        // Check for the 'JVPDerivativeReference' decorator on the
        // base function.
        //
        if (auto jvpRefDecorator = origCallable->findDecoration<IRForwardDerivativeDecoration>())
        {
            jvpCallable = jvpRefDecorator->getForwardDerivativeFunc();
        }

        SLANG_ASSERT(jvpCallable);

        // Substitute all uses of the 'derivativeOf' operation 
        // with the resolved derivative function.
        derivOfInst->replaceUsesWith(jvpCallable);

        // Remove the 'derivativeOf' inst.
        derivOfInst->removeAndDeallocate();
    }
};

// Pass entry point: build a `DerivativeCallProcessContext` bound to
// `module` and run it over the whole module.
//
// The options argument is accepted but not currently consulted.
// Returns whatever `DerivativeCallProcessContext::processModule` reports.
//
bool processDerivativeCalls(
        IRModule*                               module,
        IRDerivativeCallProcessOptions const&   /* options */)
{
    DerivativeCallProcessContext pass;
    pass.module = module;

    const bool succeeded = pass.processModule();
    return succeeded;
}

}