blob: 71cb17a08105e23542747d8156b402b9f9b5329f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
// slang-ir-autodiff-propagate.h
#pragma once
#include "slang-compiler.h"
#include "slang-ir-autodiff.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
namespace Slang
{
// Returns true when `inst` has been tagged with an IRDifferentialInstDecoration.
inline bool isDifferentialInst(IRInst* inst)
{
    auto diffDecoration = inst->findDecoration<IRDifferentialInstDecoration>();
    return diffDecoration != nullptr;
}
// Returns true when `inst` is considered primal: either it carries an
// IRPrimalInstDecoration, or it is an IRConstant (constants count as primal
// even without the decoration).
inline bool isPrimalInst(IRInst* inst)
{
    if (inst->findDecoration<IRPrimalInstDecoration>() != nullptr)
        return true;
    return as<IRConstant>(inst) != nullptr;
}
// Returns true when `inst` has been tagged with an IRMixedDifferentialInstDecoration.
inline bool isMixedDifferentialInst(IRInst* inst)
{
    auto mixedDecoration = inst->findDecoration<IRMixedDifferentialInstDecoration>();
    return mixedDecoration != nullptr;
}
// Pass that spreads the differential marking (IRDifferentialInstDecoration)
// forward along def-use edges: starting from DifferentialPairGetDifferential
// insts, any inst that has a differential-marked operand is itself marked
// differential, until no new insts get marked.
struct DiffPropagationPass : InstPassBase
{
    // Shared autodiff state; the module this pass works on comes from its moduleInst.
    AutoDiffSharedContext* autodiffContext;

    // Base classes are always initialized before data members, so InstPassBase
    // is listed first to match the real initialization order (avoids -Wreorder).
    // Inside the initializers, `autodiffContext` names the constructor
    // parameter, not the (not-yet-initialized) member.
    DiffPropagationPass(AutoDiffSharedContext* autodiffContext)
        : InstPassBase(autodiffContext->moduleInst->getModule())
        , autodiffContext(autodiffContext)
    {
    }

    // Returns true if at least one operand of `inst` is marked differential.
    bool shouldInstBeMarkedDifferential(IRInst* inst)
    {
        for (UIndex ii = 0; ii < inst->getOperandCount(); ii++)
        {
            if (isDifferentialInst(inst->getOperand(ii)))
            {
                return true;
            }
        }
        return false;
    }

    // Queues every user of `inst` that is not yet marked differential, so the
    // main loop can re-examine it.
    void addPendingUsersToWorkList(IRInst* inst)
    {
        auto use = inst->firstUse;
        while (use)
        {
            if (!isDifferentialInst(use->getUser()))
            {
                addToWorkList(use->getUser());
            }
            use = use->nextUse;
        }
    }

    // Propagate IRDifferentialInstDecoration for all children of instWithChildren.
    //
    // Seeds: every DifferentialPairGetDifferential child is marked differential.
    // Then, worklist-style, any user of a marked inst that has a differential
    // operand is marked as well, and its own users are queued in turn.
    void propagateDiffInstDecoration(IRBuilder* builder, IRInst* instWithChildren)
    {
        List<IRInst*> initialList;

        // Mark 'GetDifferential' insts as differential.
        processChildInstsOfType<IRDifferentialPairGetDifferential>(
            kIROp_DifferentialPairGetDifferential,
            instWithChildren,
            [&](IRDifferentialPairGetDifferential* getDifferentialInst)
            {
                builder->markInstAsDifferential(getDifferentialInst);
                initialList.add(getDifferentialInst);
            });

        // Start propagation from an empty work list / membership set.
        // NOTE(review): presumably processChildInstsOfType uses the same
        // InstPassBase work list internally; confirm before removing this reset.
        workList.clear();
        workListSet.clear();

        // Add the users of the seed insts to the work list.
        for (auto inst : initialList)
        {
            // Look for insts marked as differential.
            if (isDifferentialInst(inst))
                addPendingUsersToWorkList(inst);
        }

        // Propagate to all users until a fixed point is reached.
        while (workList.getCount() != 0)
        {
            IRInst* inst = pop();

            // Skip if this is already a differential inst.
            if (isDifferentialInst(inst))
            {
                continue;
            }

            if (shouldInstBeMarkedDifferential(inst))
            {
                builder->markInstAsDifferential(inst);
                addPendingUsersToWorkList(inst);
            }
        }
    }
};
} // namespace Slang
|