summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-call.cpp
blob: 92044be3c256ea62b7643f4a06c19be959447767 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// slang-ir-diff-call.cpp
#include "slang-ir-diff-call.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"

namespace Slang
{

struct DerivativeCallProcessContext
{
    // This type passes over the module and replaces
    // derivative calls with the processed derivative 
    // function.
    //
    IRModule*                       module;

    bool processModule()
    {
        // Run through all the global-level instructions, 
        // looking for callable blocks.
        for (auto inst : module->getGlobalInsts())
        {
            // If the instr is a callable, get all the basic blocks
            if (auto callable = as<IRGlobalValueWithCode>(inst))
            {
                // Iterate over each block in the callable
                for (auto block : callable->getBlocks())
                {
                    // Iterate over each child instruction.
                    auto child = block->getFirstInst();
                    if (!child) continue;

                    do 
                    {
                        auto nextChild = child->getNextInst();
                        // Look for IRJVPDifferentiate
                        if (auto derivOf = as<IRJVPDifferentiate>(child))
                        {
                            processDifferentiate(derivOf);
                        }
                        child = nextChild;
                    } 
                    while (child);
                }
            }
        }
        return true;
    }

    // Perform forward-mode automatic differentiation on 
    // the intstructions.
    void processDifferentiate(IRJVPDifferentiate* derivOfInst)
    {
        IRFunc* jvpFunc = nullptr;

        // Resolve the derivative function.
        //
        // Check for the 'JVPDerivativeReference' decorator on the
        // base function.
        if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
        {
            jvpFunc = jvpRefDecorator->getJVPFunc();
        }
        
        // Substitute all uses of the 'derivativeOf' operation 
        // with the resolved derivative function.
        while (auto use = derivOfInst->firstUse)
        {
            use->set(jvpFunc);
        }

        // Remove the 'derivativeOf'
        derivOfInst->removeAndDeallocate();
    }
};

// Entry point for the derivative-call pass: builds a
// `DerivativeCallProcessContext` bound to `module` and runs it,
// forwarding its result. The options argument is not consulted.
//
bool processDerivativeCalls(
    IRModule*                               module,
    IRDerivativeCallProcessOptions const&   /* options (unused) */)
{
    DerivativeCallProcessContext passContext;
    passContext.module = module;

    const bool result = passContext.processModule();
    return result;
}

}