blob: 92044be3c256ea62b7643f4a06c19be959447767 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
|
// slang-ir-diff-call.cpp
#include "slang-ir-diff-call.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
namespace Slang
{
struct DerivativeCallProcessContext
{
    // This type passes over the module and replaces
    // derivative calls (`IRJVPDifferentiate` instructions) with the
    // derivative function they resolve to.
    //

    // The module whose function bodies are scanned. Must be set before
    // `processModule()` is called.
    IRModule* module = nullptr;

    // Walk every global value with code in `module` and try to resolve
    // each `IRJVPDifferentiate` instruction found in its blocks.
    //
    // Returns `true` if every derivative call was resolved. Returns
    // `false` if no module is set, or if at least one derivative call
    // could not be resolved. Unresolved instructions are left in place.
    bool processModule()
    {
        if (!module)
            return false;

        bool allResolved = true;

        // Run through all the global-level instructions,
        // looking for callable blocks.
        for (auto inst : module->getGlobalInsts())
        {
            auto callable = as<IRGlobalValueWithCode>(inst);
            if (!callable)
                continue;

            // Iterate over each block in the callable.
            for (auto block : callable->getBlocks())
            {
                // Capture the successor before processing the current
                // instruction. A resolved differentiate instruction is
                // removed and deallocated, so it must not be touched
                // afterwards.
                for (auto child = block->getFirstInst(); child; )
                {
                    auto nextChild = child->getNextInst();
                    if (auto derivOf = as<IRJVPDifferentiate>(child))
                    {
                        if (!processDifferentiate(derivOf))
                            allResolved = false;
                    }
                    child = nextChild;
                }
            }
        }
        return allResolved;
    }

    // Perform forward-mode automatic differentiation on
    // the instructions: resolve a JVP differentiate instruction to the
    // derivative function referenced by its base function, redirect all
    // uses of the instruction to that function, then remove it.
    //
    // Returns `true` if the instruction was resolved and removed.
    // Returns `false` if no derivative function could be found. In that
    // case the instruction and its uses are left untouched, instead of
    // having every use rewritten to a null operand.
    bool processDifferentiate(IRJVPDifferentiate* derivOfInst)
    {
        auto baseFn = derivOfInst->base.get();
        if (!baseFn)
            return false;

        // Resolve the derivative function.
        //
        // Check for the 'JVPDerivativeReference' decorator on the
        // base function.
        auto jvpRefDecorator = baseFn->findDecoration<IRJVPDerivativeReferenceDecoration>();
        if (!jvpRefDecorator)
            return false;

        IRFunc* jvpFunc = jvpRefDecorator->getJVPFunc();
        if (!jvpFunc)
            return false;

        // Substitute all uses of the 'derivativeOf' operation
        // with the resolved derivative function. (This loop relies on
        // `set` moving each use off this instruction's use list.)
        while (auto use = derivOfInst->firstUse)
        {
            use->set(jvpFunc);
        }

        // Remove the 'derivativeOf' instruction, which now has no uses.
        derivOfInst->removeAndDeallocate();
        return true;
    }
};
// Public entry point for this pass: bind a fresh processing context to
// `module`, run it over the whole module, and hand back its result.
// The options argument is accepted but not currently consulted.
//
bool processDerivativeCalls(
    IRModule* module,
    IRDerivativeCallProcessOptions const& /*options*/)
{
    DerivativeCallProcessContext passContext;
    passContext.module = module;

    const bool result = passContext.processModule();
    return result;
}
}
|