diff options
| author | Anders Leino <aleino@nvidia.com> | 2025-02-11 02:42:08 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-10 16:42:08 -0800 |
| commit | 3c2d46aa1c8575dc046d7457793e77c7a4789093 (patch) | |
| tree | b333f6f799b975ca4e18fadd3cec0a144d8ec082 /docs/proposals/010-new-diff-type-system.md | |
| parent | 133bd259c00984c6a01869f71951a7feb919463a (diff) | |
Remove the docs/proposals directory (#6313)
* Remove the docs/proposals directory
This directory will get added to the spec repository in the following PR:
https://github.com/shader-slang/spec/pull/6
This closes #6155.
* Remove entry from .github/CODEOWNERS file
* Redirect some proposal references
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'docs/proposals/010-new-diff-type-system.md')
| -rw-r--r-- | docs/proposals/010-new-diff-type-system.md | 285 |
1 files changed, 0 insertions, 285 deletions
diff --git a/docs/proposals/010-new-diff-type-system.md b/docs/proposals/010-new-diff-type-system.md deleted file mode 100644 index 61157819e..000000000 --- a/docs/proposals/010-new-diff-type-system.md +++ /dev/null @@ -1,285 +0,0 @@ -# SP #010: New Differentiable Type System - -## Problem -Our current `IDifferentiable` system has some flaws. It works fine for value types, since we can assume that every input gets a corresponding output or 'return' value. It works poorly for buffer/pointer types, since we don't 'return' a buffer, but simply want the getters/setters to be differentiable, and the resulting type to have a second buffer/pointer for the differential data. - -Here's a demonstrative example with our current codebase when we use value types (like `float`) -```csharp -[Differentiable] -float add(float a, float b) -{ - return a + b; -} - -// Synthesized derivative: -[Differentiable] -void s_bwd_add(DifferentialPair<float> dpa, DifferentialPair<float> dpb, float.Differential d_out) -{ - // A backward derivative method is currently responsible for 'setting' the differential values. - dpa = DifferentialPair<float>(dpa.p, d_out); - dpb = DifferentialPair<float>(dpb.p, d_out); -} -``` - -Unfortunately, this makes little sense if we decide to use buffer or pointer types: -```csharp -struct DiffPtr<T> : IDifferentiable -{ - StructuredBuffer<T> bufferRef; - uint64 offset; - - [Differentiable] T get() { ... } - [Differentiable] void set(T t) { ... } - /* - Problem 1: - We use custom derivatives for get() and set() to backprop and - read gradients. If DiffPtr<T> is differentiable, then get() and - set() need to operate on the *pair* type and not this struct type. - There is no proper way to do this currently. - */ -}; - -[Differentiable] -void add(DiffPtr<float> a, DiffPtr<float> b, DiffPtr<float> output) -{ - output.set(a.get() + b.get()); -} - -// Synthesized derivative: -[Differentiable] -void s_bwd_add( - inout DifferentialPair<DiffPtr<float>> a, - inout DifferentialPair<DiffPtr<float>> b, - inout DifferentialPair<DiffPtr<float>> output) -{ - /* - Problem 2: - - Current backward mode semantics require that the method assume that the differentials - a.d and b.d are empty/zero, and it is the backward method's job to populate the result. - - It doesn't make sense to 'set' the differential part since it is a buffer ref. - Rather, we want the user to provide the differential pointer, and use custom derivatives of - the getters/setters to propagate derivatives. - - This also means methods like dzero(), dadd() and dmul() make no sense - in the context of pointer types. They cannot be initialized within a derivative method. - */ -} - -``` - -## Workarounds -At the moment the primary workaround is to use a **non-differentiable buffer type** with differentiable methods, and always initialize the object with two pointers for both the primal and differential buffers. This is how our `DiffTensorView<T>` object works. -Unfortunately, this is a rather hacky workaround with several drawbacks: -1. `DiffTensorView<T>` does not conform to `IDifferentiable`, but is used for derivatives. This makes our type system less useful as checks for `is_subtype` from applications using reflection need workarounds to account for corner cases like these. -2. `DiffTensorView<T>` always has two buffer pointers even when used in non-differentiable methods. This is extra data in the struct, and potentially extra tensor allocations (we explicitly handle this case in `slangtorch` by leaving the diff part uninitialized if a primal method is invoked) -3. Higher-order derivatives don't work well with this workaround. Differentiating a method twice needs a set of 4 pointers, but we need to account for this ahead of time by using new types like `DiffDiffTensorView` that worsens the problem of carrying around extra data where its not required. - - -## Solution - -We'll need to make the following 4 additions/changes: -### 1. `[deriv_method]` function decorator. -Intended for easy definition of custom derivatives for struct methods. It has the following properties: -1. Accesses to `this` within `[deriv_method]` are differential pairs. -2. Methods decorated with `[deriv_method]` cannot be called as regular methods (they can still be explicitly invoked with `bwd_diff(obj.method)`), and do not show up in the auto-complete list. - -See the next section for example uses of `[deriv_method]`. - -### 2. Split `IDifferentiable` interface: `IDifferentiableValueType` and `IDifferentiablePtrType` -This approach moves away from "type-driven" derivative semantics and towards more "function-driven" derivative semantics. -We no longer have a `dadd` , `dzero`, `dmul` etc.. we use default initialization instead of `dzero` and the backward derivative of the `use` method for `dadd` - -Further, `IDifferentiablePtrType` types don't have any of these properties. They do not need a way to 'add', and it is especially important that there is no default initializer. We never want the compiler to be able to create a new object of `IDifferentiablePtrType` since we want to get the user-provided pointers. - -Additionally, we can use `IDifferentiableValueType` as the current `IDifferentiable` for backwards compatibility (it should just work in 95% of cases, since no one really defines dadd/dzero/dmul explicitly anyway) - -Here's the new set of base interfaces: -```csharp -interface __IDifferentiableBase { } // Helper type for our implementation. -interface IDifferentiableValueType : __IDifferentiableBase -{ - associatedtype Differential : IDifferentiableValueType & IDefaultInitializable; - [Differentiable] This use(); // auto-synthesized -} - -interface IDifferentiablePtrType : __IDifferentiableBase -{ - associatedtype Differential : IDifferentiablePtrType; -} - -``` - -Some extras in the core module allow us to constrain the diffpair type for things like `IArithmetic` -```csharp -// --- CORE MODULE EXTRAS --- - -interface ISelfDifferentiableValueType : IDifferentiableValueType -{ - // Force arithmetic types to be a differential pair of the same two types. - // Make it simple to define derivatives of arithmetic operations. - // - associatedtype Differential : This; -} - -extension IFloat : ISelfDifferentiableValueType -{ } - -extension float -{ - // trivial auto-synthesis (maybe we even prevent the user from overriding this) - float use() { return this; } - - // trivial auto-synthesis (maybe we even prevent the user from overriding this). - [ForwardDerivativeOf(use)] - [deriv_method] void use_fwd() { return this; } - - // auto-synthesized if necessary by invoking the use_bwd for all fields. - // we need to provide implementation for 'leaf' types. - [BackwardDerivativeOf(use)] - [deriv_method] [mutating] void use_bwd(float d) { this.d += d; } -} - -// The new system lets us define differentiable pointers easily. -// IDifferentiablePtrType'd values are simply treated as references, so they can be freely -// duplicated without requiring a `use()` for correctness. -// -struct DPtr<T : IDifferentiableValueType> : IDifferentiablePtrType -{ - typealias Differential = DPtr<T.Differential>; - - Buffer<T> buffer; - uint64 offset; - - [BackwardDerivative(get_bwd)] - [BackwardDerivative(get_fwd)] - T get() { return this.buffer[offset]; } - - [deriv_method] DifferentialPair<T> get_fwd() - { - return diffPair(this.p.buffer[offset], this.d().buffer[offset]); - } - - [deriv_method] void get_bwd(Differential d) - { - return this.d.InterlockedAdd(offset, d); - } - - DPtr<T> operator+(uint o) { return DPtr<T>{buffer, offset + o}; } -} - -// Or we can define a fancier differentiable pointer that does a hashgrid -struct DHashGridPtr<T : IDifferentiableValueType, let N: int> : IDifferentiablePtrType -{ - typealias Differential = DPtr<T.Differential>; - - Buffer<T> buffer; - uint64 offset; - - [BackwardDerivative(get_bwd)] - [BackwardDerivative(get_fwd)] - T get() { return this.buffer[offset]; } - - [deriv_method] DifferentialPair<T> get_fwd() - { - return diffPair(this.p().buffer[offset], this.d().buffer[offset]); - } - - [deriv_method] void get_bwd(Differential d) - { - return this.d().InterlockedAdd(offset * N + hash(get_thread_id()), d); - } -} -``` - -### 3. Every time we 'reuse' an object that conforms to `IDifferentiableValueType`, we split it with `use()` , and we use `__init__()` where necessary to initialize an accumulator. -Example: -```csharp -float f(float a) -{ - add(a, a); -} -float add(float a, float b) -{ - return a + b; -} - -// Synthesized derivatives -void add_bwd(inout DiffPair<float> dpa, inout DiffPair<float> dpb, float d_out) -{ - dpa = diffPair(dpa.p, d_out); - dpb = diffPair(dpb.p, d_out); -} - -// Preprocessed-f (before derivative generation) -float f_with_use_expansion(float a) -{ - DiffPair<float> a_extra = a.use(); - return add(a, a_extra); -} - -// After fwd-mode: -DiffPair<float> f_fwd(DiffPair<float> dpa) -{ - DiffPair<float> dpa_extra = dpa.use_fwd(); - return add_fwd(a, a_extra_fwd); -} - - -// bwd-mode: -void f_bwd(inout DiffPair<float> dpa, float d_out) -{ - // fwd-pass - - // split - DiffPair<float> dpa_extra = dpa.use_fwd(); - // ------- - - // bwd-pass - dpa_extra_bwd = DiffPair<float>(dpa_extra.p, float.Differential::__init__()); - add_bwd(dpa, dpa_extra, d_out); - - // merge - dpa.use_bwd(dpa_extra); -} -``` - -### 4. Objects that conform to `IDifferentiablePtrType` are used without splitting. They are simply not 'transposed' at all, because there is nothing to transpose. The fwd-mode pair is used as is. -Here's the same example above, but with the `DPtr` type defined above. - -```csharp -void f(DPtr<float> a, DPtr<float> output) -{ - add(a, a, output); -} - -void add(DPtr<float> a, DPtr<float> b, DPtr<float> output) -{ - output.set(a.get() + b.get()); -} - -// Synthesized derivatives -// (note: no inout req'd for IDifferentiablePtrType) -// important difference is that `ptr` types don't get transposed, only -// methods on the objects are. -// they DO NOT have a default initializer (the user must supply the differential part) -void add_bwd( - DifferentialPair<DPtr<float>> dpa, - DifferentialPair<DPtr<float>> dpb, - DifferentialPair<DPtr<float>> output) -{ - // forward pass. - var a_p = dpa.p.get(); - var b_p = dpb.p.get(); - // ---- - - // backward pass. - float.Differential d_val = DPtr<float>::set_bwd(output); // set_bwd works on the entire pair. - DifferentialPair<float> a_get_bwd = diffPair(a_p, float.Differential::__init__()); - DifferentialPair<float> b_get_bwd = diffPair(b_p, float.Differential::__init__()); - operator_float_add_bwd(a_get_result_bwd, b_get_result_bwd, d_val); - DPtr<float>::get_bwd(dpa); - DPtr<float>::get_bwd(dpb); -} -``` |
