From d8a40abba5223fbcb56c52b04ccb88c02bbaf79f Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 21 Mar 2023 21:29:13 -0700 Subject: [TreatAsDifferentiable] functions. (#2720) --- docs/user-guide/07-autodiff.md | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) (limited to 'docs') diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index 14b92774d..f2bf4dcc5 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -687,6 +687,40 @@ float f(float x) However, the `no_diff` keyword is not required in a call if a non-differentiable function does not take any differentiable parameters, or if the result of the differentiable function is not dependant on the derivative being propagated through the call. +### Treat Non-Differentiable Functions as Differentiable +Slang allows functions to be marked with a `[TreatAsDifferentiable]` attribute for them to be considered as differentiable functions by the type-system. When a function is marked as `[TreatAsDifferentiable]`, the compiler will not generate derivative propagation code from the original function body or perform any additional checking on the function definition. Instead, it will generate trivial forward and backward propagation functions that returns 0. + +This feature can be useful if the user marked an `interface` method as forward or backward differentiable, but only wish to provide non-trivial derivative propagation functions for a subset of types that implement the interface. For other types that does not actually need differentiation, the user can simply put `[TreatAsDifferentiable]` on the method implementations for them to satisfy the interface requirement. + +See the following code for an example of `[TreatAsDifferentiable]`: +```csharp +interface IFoo +{ + [BackwardDifferentiable] + float f(float v); +} + +struct B : IFoo +{ + [TreatAsDifferentiable] + float f(float v) + { + return v * v; + } +} + +[BackwardDifferentiable] +float use(IFoo o, float x) +{ + return o.f(x); +} + +// Test: +B obj; +float result = __fwd_diff(use)(obj, diffPair(2.0, 1.0)).d; +// result == 0.0, since `[TreatAsDifferentiable]` causes a trivial derivative implementation +// being generated regardless of the original code. +``` ## Higher Order Differentiation -- cgit v1.2.3