summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
blob: e604140ae05238a80e57941b8e7dbcd9dfce0577 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

/// Modifier to mark a function for forward-mode differentiation,
/// i.e. the compiler will automatically generate a new function
/// that computes the Jacobian-vector product (JVP) of the original.
///
/// Declares the `__differentiate_jvp` keyword and maps it to the
/// `JVPDerivativeModifier` modifier class.
syntax __differentiate_jvp : JVPDerivativeModifier;

// Custom JVP function reference.
//
// Declares the `[__custom_jvp(function)]` attribute. It may only be
// attached to function declarations (`__attributeTarget(FuncDecl)`), takes
// a single argument naming another function, and maps to the
// `CustomJVPAttribute` class.
// NOTE(review): presumably the named function is used as the forward-mode
// (JVP) derivative in place of an automatically generated one — confirm
// against the compiler's handling of `CustomJVPAttribute`.
__attributeTarget(FuncDecl)
attribute_syntax [__custom_jvp(function)]   : CustomJVPAttribute;

//@ public:

    /// Interface to denote types as differentiable.
    ///
    /// Allows for user-specified differential types as well as
    /// automatic generation, for when the associated type hasn't
    /// been declared explicitly.
    ///
    /// Registered with the compiler under the magic-type name
    /// `DifferentiableType`.
__magic_type(DifferentiableType)
interface IDifferentiable
{
    /// Type of the differential value that accompanies a primal value of
    /// the conforming type (e.g. the second component of
    /// `__DifferentialPair<T>`). The built-in `float` conformances use the
    /// type itself.
    associatedtype Differential;
};

    /// Pair type that serves to wrap the primal and
    /// differential values of an arbitrary type `T`.
    ///
    /// Holds a primal value of type `T` together with a differential value
    /// of type `T.Differential`. The type itself and its constructor and
    /// short-form accessors (`p`, `d`) are compiler intrinsics: each is
    /// emitted as the IR opcode named in its `__intrinsic_type` /
    /// `__intrinsic_op` modifier. `getPrimal` / `getDifferential` are
    /// ordinary wrapper methods over those accessors.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
struct __DifferentialPair
{

    /// Builds a pair from a primal value and its differential.
    __intrinsic_op($(kIROp_MakeDifferentialPair))
    __init(T _primal, T.Differential _differential);

    /// Returns the differential component of the pair.
    __intrinsic_op($(kIROp_DifferentialPairGetDifferential))
    T.Differential d();

    /// Returns the differential component of the pair; same as `d()`.
    T.Differential getDifferential()
    {
        return d();
    }

    /// Returns the primal component of the pair.
    __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
    T p();

    /// Returns the primal component of the pair; same as `p()`.
    T getPrimal()
    {
        return p();
    }
};

// Add `IDifferentiable` conformances for the standard floating-point types.
//
// Each extension declares the type to be its own `Differential`, so the
// differential of a `float` (or `float` vector) value has the same type as
// the value itself. Only `float` scalars and 2-, 3- and 4-component `float`
// vectors are declared differentiable here; other scalar types (e.g. `half`,
// `double`) and matrices get no conformance in this file.
extension float : IDifferentiable
{
    typedef float Differential;
}

// 3-component float vector: differential is also a `vector<float, 3>`.
extension vector<float, 3> : IDifferentiable
{
    typedef vector<float, 3> Differential;
}

// 2-component float vector: differential is also a `vector<float, 2>`.
extension vector<float, 2> : IDifferentiable
{
    typedef vector<float, 2> Differential;
}

// 4-component float vector: differential is also a `vector<float, 4>`.
extension vector<float, 4> : IDifferentiable
{
    typedef vector<float, 4> Differential;
}