blob: e604140ae05238a80e57941b8e7dbcd9dfce0577 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
|
/// Modifier to mark a function for forward-mode differentiation,
/// i.e. the compiler will automatically generate a new function
/// that computes the Jacobian-vector product (JVP) of the original.
/// The `__differentiate_jvp` keyword maps to the `JVPDerivativeModifier` syntax class.
syntax __differentiate_jvp : JVPDerivativeModifier;
// Custom JVP function reference.
// Declares the `[__custom_jvp(function)]` attribute, which maps to
// `CustomJVPAttribute` and is restricted to function declarations (`FuncDecl`).
// NOTE(review): presumably `function` names a user-provided JVP implementation
// to use for the annotated function -- confirm against the compiler's handling
// of `CustomJVPAttribute`.
__attributeTarget(FuncDecl)
attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
//@ public:
/// Interface to denote types as differentiable.
/// Allows for user-specified differential types as
/// well as automatic generation, for when the associated type
/// hasn't been declared explicitly.
/// Tagged as the `DifferentiableType` magic type so the compiler can identify it.
__magic_type(DifferentiableType)
interface IDifferentiable
{
/// The type used to represent a differential of the conforming type.
associatedtype Differential;
};
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
/// `T` must conform to `IDifferentiable`; the differential half of the
/// pair has type `T.Differential`.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
struct __DifferentialPair
{
/// Constructs a pair from a primal value and its differential.
/// Implemented directly by the `MakeDifferentialPair` IR op.
__intrinsic_op($(kIROp_MakeDifferentialPair))
__init(T _primal, T.Differential _differential);
/// Returns the differential half of the pair.
/// Implemented directly by the `DifferentialPairGetDifferential` IR op.
__intrinsic_op($(kIROp_DifferentialPairGetDifferential))
T.Differential d();
/// Long-form alias for `d()`.
T.Differential getDifferential()
{
return d();
}
/// Returns the primal half of the pair.
/// Implemented directly by the `DifferentialPairGetPrimal` IR op.
__intrinsic_op($(kIROp_DifferentialPairGetPrimal))
T p();
/// Long-form alias for `p()`.
T getPrimal()
{
return p();
}
};
// Add extensions for the standard types
/// Makes `float` conform to `IDifferentiable`; its differential type is `float` itself.
extension float : IDifferentiable
{
typedef float Differential;
}
/// Makes `vector<float, 3>` conform to `IDifferentiable`; its differential
/// type is the same 3-component float vector.
extension vector<float, 3> : IDifferentiable
{
typedef vector<float, 3> Differential;
}
/// Makes `vector<float, 2>` conform to `IDifferentiable`; its differential
/// type is the same 2-component float vector.
extension vector<float, 2> : IDifferentiable
{
typedef vector<float, 2> Differential;
}
/// Makes `vector<float, 4>` conform to `IDifferentiable`; its differential
/// type is the same 4-component float vector.
extension vector<float, 4> : IDifferentiable
{
typedef vector<float, 4> Differential;
}
|