diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-20 23:53:01 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-20 20:53:01 -0700 |
| commit | 359e96c6d94b59950681d69fc7534d0bbbd8e9be (patch) | |
| tree | 445a523f8fd77824ad924718318dc64504f35156 /docs/proposals | |
| parent | f9f6a28df40f418ddd0c8ff3b9cacccdb085e202 (diff) | |
Proposal: A simpler and more flexible `IDifferentiable` system (#4865)
Designed to work for both value types as well as ref/ptr/buffer types
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'docs/proposals')
| -rw-r--r-- | docs/proposals/010-new-diff-type-system.md | 283 |
1 files changed, 283 insertions, 0 deletions
diff --git a/docs/proposals/010-new-diff-type-system.md b/docs/proposals/010-new-diff-type-system.md new file mode 100644 index 000000000..242f92de3 --- /dev/null +++ b/docs/proposals/010-new-diff-type-system.md @@ -0,0 +1,283 @@ +## Problem +Our current `IDifferentiable` system has some flaws. It works fine for value types, since we can assume that every input gets a corresponding output or 'return' value. It works poorly for buffer/pointer types, since we don't 'return' a buffer, but simply want the getters/setters to be differentiable, and the resulting type to have a second buffer/pointer for the differential data. + +Here's a demonstrative example with our current codebase when we use value types (like `float`) +```csharp +[Differentiable] +float add(float a, float b) +{ + return a + b; +} + +// Synthesized derivative: +[Differentiable] +void s_bwd_add(DifferentialPair<float> dpa, DifferentialPair<float> dpb, float.Differential d_out) +{ + // A backward derivative method is currently responsible for 'setting' the differential values. + dpa = DifferentialPair<float>(dpa.p, d_out); + dpb = DifferentialPair<float>(dpb.p, d_out); +} +``` + +Unfortunately, this makes little sense if we decide to use buffer or pointer types: +```csharp +struct DiffPtr<T> : IDifferentiable +{ + StructuredBuffer<T> bufferRef; + uint64 offset; + + [Differentiable] T get() { ... } + [Differentiable] void set(T t) { ... } + /* + Problem 1: + We use custom derivatives for get() and set() to backprop and + read gradients. If DiffPtr<T> is differentiable, then get() and + set() need to operate on the *pair* type and not this struct type. + There is no proper way to do this currently. + */ +}; + +[Differentiable] +void add(DiffPtr<float> a, DiffPtr<float> b, DiffPtr<float> output) +{ + output.set(a.get() + b.get()); +} + +// Synthesized derivative: +[Differentiable] +void s_bwd_add( + inout DifferentialPair<DiffPtr<float>> a, + inout DifferentialPair<DiffPtr<float>> b, + inout DifferentialPair<DiffPtr<float>> output) +{ + /* + Problem 2: + + Current backward mode semantics require that the method assume that the differentials + a.d and b.d are empty/zero, and it is the backward method's job to populate the result. + + It doesn't make sense to 'set' the differential part since it is a buffer ref. + Rather, we want the user to provide the differential pointer, and use custom derivatives of + the getters/setters to propagate derivatives. + + This also means methods like dzero(), dadd() and dmul() make no sense + in the context of pointer types. They cannot be initialized within a derivative method. + */ +} + +``` + +## Workarounds +At the moment the primary workaround is to use a **non-differentiable buffer type** with differentiable methods, and always initialize the object with two pointers for both the primal and differential buffers. This is how our `DiffTensorView<T>` object works. +Unfortunately, this is a rather hacky workaround with several drawbacks: +1. `DiffTensorView<T>` does not conform to `IDifferentiable`, but is used for derivatives. This makes our type system less useful as checks for `is_subtype` from applications using reflection need workarounds to account for corner cases like these. +2. `DiffTensorView<T>` always has two buffer pointers even when used in non-differentiable methods. This is extra data in the struct, and potentially extra tensor allocations (we explicitly handle this case in `slangtorch` by leaving the diff part uninitialized if a primal method is invoked) +3. Higher-order derivatives don't work well with this workaround. Differentiating a method twice needs a set of 4 pointers, but we need to account for this ahead of time by using new types like `DiffDiffTensorView` that worsens the problem of carrying around extra data where its not required. + + +## Solution + +We'll need to make the following 4 additions/changes: +### 1. `[deriv_method]` function decorator. +Intended for easy definition of custom derivatives for struct methods. It has the following properties: +1. Accesses to `this` within `[deriv_method]` are differential pairs. +2. Methods decorated with `[deriv_method]` cannot be called as regular methods (they can still be explicitly invoked with `bwd_diff(obj.method)`), and do not show up in the auto-complete list. + +See the next section for example uses of `[deriv_method]`. + +### 2. Split `IDifferentiable` interface: `IDifferentiableValueType` and `IDifferentiablePtrType` +This approach moves away from "type-driven" derivative semantics and towards more "function-driven" derivative semantics. +We no longer have a `dadd` , `dzero`, `dmul` etc.. we use default initialization instead of `dzero` and the backward derivative of the `use` method for `dadd` + +Further, `IDifferentiablePtrType` types don't have any of these properties. They do not need a way to 'add', and it is especially important that there is no default initializer. We never want the compiler to be able to create a new object of `IDifferentiablePtrType` since we want to get the user-provided pointers. + +Additionally, we can use `IDifferentiableValueType` as the current `IDifferentiable` for backwards compatibility (it should just work in 95% of cases, since no one really defines dadd/dzero/dmul explicitly anyway) + +Here's the new set of base interfaces: +```csharp +interface __IDifferentiableBase { } // Helper type for our implementation. +interface IDifferentiableValueType : __IDifferentiableBase +{ + associatedtype Differential : IDifferentiableValueType & IDefaultInitializable; + [Differentiable] This use(); // auto-synthesized +} + +interface IDifferentiablePtrType : __IDifferentiableBase +{ + associatedtype Differential : IDifferentiablePtrType; +} + +``` + +Some extras in stdlib allow us to constrain the diffpair type for things like `IArithmetic` +```csharp +// --- STDLIB EXTRAS --- + +interface ISelfDifferentiableValueType : IDifferentiableValueType +{ + // Force arithmetic types to be a differential pair of the same two types. + // Make it simple to define derivatives of arithmetic operations. + // + associatedtype Differential : This; +} + +extension IFloat : ISelfDifferentiableValueType +{ } + +extension float +{ + // trivial auto-synthesis (maybe we even prevent the user from overriding this) + float use() { return this; } + + // trivial auto-synthesis (maybe we even prevent the user from overriding this). + [ForwardDerivativeOf(use)] + [deriv_method] void use_fwd() { return this; } + + // auto-synthesized if necessary by invoking the use_bwd for all fields. + // we need to provide implementation for 'leaf' types. + [BackwardDerivativeOf(use)] + [deriv_method] [mutating] void use_bwd(float d) { this.d += d; } +} + +// The new system lets us define differentiable pointers easily. +// IDifferentiablePtrType'd values are simply treated as references, so they can be freely +// duplicated without requiring a `use()` for correctness. +// +struct DPtr<T : IDifferentiableValueType> : IDifferentiablePtrType +{ + typealias Differential = DPtr<T.Differential>; + + Buffer<T> buffer; + uint64 offset; + + [BackwardDerivative(get_bwd)] + [BackwardDerivative(get_fwd)] + T get() { return this.buffer[offset]; } + + [deriv_method] DifferentialPair<T> get_fwd() + { + return diffPair(this.p.buffer[offset], this.d().buffer[offset]); + } + + [deriv_method] void get_bwd(Differential d) + { + return this.d.InterlockedAdd(offset, d); + } + + DPtr<T> operator+(uint o) { return DPtr<T>{buffer, offset + o}; } +} + +// Or we can define a fancier differentiable pointer that does a hashgrid +struct DHashGridPtr<T : IDifferentiableValueType, let N: int> : IDifferentiablePtrType +{ + typealias Differential = DPtr<T.Differential>; + + Buffer<T> buffer; + uint64 offset; + + [BackwardDerivative(get_bwd)] + [BackwardDerivative(get_fwd)] + T get() { return this.buffer[offset]; } + + [deriv_method] DifferentialPair<T> get_fwd() + { + return diffPair(this.p().buffer[offset], this.d().buffer[offset]); + } + + [deriv_method] void get_bwd(Differential d) + { + return this.d().InterlockedAdd(offset * N + hash(get_thread_id()), d); + } +} +``` + +### 3. Every time we 'reuse' an object that conforms to `IDifferentiableValueType`, we split it with `use()` , and we use `__init__()` where necessary to initialize an accumulator. +Example: +```csharp +float f(float a) +{ + add(a, a); +} +float add(float a, float b) +{ + return a + b; +} + +// Synthesized derivatives +void add_bwd(inout DiffPair<float> dpa, inout DiffPair<float> dpb, float d_out) +{ + dpa = diffPair(dpa.p, d_out); + dpb = diffPair(dpb.p, d_out); +} + +// Preprocessed-f (before derivative generation) +float f_with_use_expansion(float a) +{ + DiffPair<float> a_extra = a.use(); + return add(a, a_extra); +} + +// After fwd-mode: +DiffPair<float> f_fwd(DiffPair<float> dpa) +{ + DiffPair<float> dpa_extra = dpa.use_fwd(); + return add_fwd(a, a_extra_fwd); +} + + +// bwd-mode: +void f_bwd(inout DiffPair<float> dpa, float d_out) +{ + // fwd-pass + + // split + DiffPair<float> dpa_extra = dpa.use_fwd(); + // ------- + + // bwd-pass + dpa_extra_bwd = DiffPair<float>(dpa_extra.p, float.Differential::__init__()); + add_bwd(dpa, dpa_extra, d_out); + + // merge + dpa.use_bwd(dpa_extra); +} +``` + +### 4. Objects that conform to `IDifferentiablePtrType` are used without splitting. They are simply not 'transposed' at all, because there is nothing to transpose. The fwd-mode pair is used as is. +Here's the same example above, but with the `DPtr` type defined above. + +```csharp +void f(DPtr<float> a, DPtr<float> output) +{ + add(a, a, output); +} + +void add(DPtr<float> a, DPtr<float> b, DPtr<float> output) +{ + output.set(a.get() + b.get()); +} + +// Synthesized derivatives +// (note: no inout req'd for IDifferentiablePtrType) +// important difference is that `ptr` types don't get transposed, only +// methods on the objects are. +// they DO NOT have a default initializer (the user must supply the differential part) +void add_bwd( + DifferentialPair<DPtr<float>> dpa, + DifferentialPair<DPtr<float>> dpb, + DifferentialPair<DPtr<float>> output) +{ + // forward pass. + var a_p = dpa.p.get(); + var b_p = dpb.p.get(); + // ---- + + // backward pass. + float.Differential d_val = DPtr<float>::set_bwd(output); // set_bwd works on the entire pair. + DifferentialPair<float> a_get_bwd = diffPair(a_p, float.Differential::__init__()); + DifferentialPair<float> b_get_bwd = diffPair(b_p, float.Differential::__init__()); + operator_float_add_bwd(a_get_result_bwd, b_get_result_bwd, d_val); + DPtr<float>::get_bwd(dpa); + DPtr<float>::get_bwd(dpb); +} +```
\ No newline at end of file |
