## From the chain rule to automatic differentiation

In *What is automatic differentiation, and why does it work?*, I gave a semantic model that explains what automatic differentiation (AD) accomplishes. Correct implementations then flowed from that model, by applying the principle of type class morphisms. (An instance’s interpretation is the interpretation’s instance.)

I’ve had a nagging discomfort about the role of the chain rule in AD, with an intuition that the chain rule can play a more central role in the specification and implementation. This post gives a variation on the previous AD post that carries the chain rule further into the reasoning and implementation, leading to simpler correctness proofs and a nearly unaltered implementation.

Finally, as a bonus, I’ll show how GHC rewrite rules enable an even simpler and more modular implementation.

I’ve included some optional content, including exercises. You can see my answers to the exercises by examining the HTML.

As before, I’ll start with a limited form of differentiation that works for functions of a scalar (1D) domain, where one can identify derivative values with regular values:

``````deriv :: Num a => (a -> a) -> (a -> a) --  simplification
``````

The development below extends to higher-order derivatives and higher-dimensional domains.

### The chain rule

At the heart of AD is the chain rule:

``````deriv (g . f) x == deriv g (f x) * deriv f x
``````

Equivalently,

``````deriv (g . f) == (deriv g . f) * deriv f
``````

where this `(*)` is on functions: `(*) = liftA2 (*) == \ h k x -> h x * k x`.
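The function-level form of the chain rule can be spot-checked numerically. The sketch below is not part of the development: `numDeriv` is a hypothetical central-difference stand-in for the non-executable `deriv`, and `mulF` is just a monomorphic name for `liftA2 (*)`.

```haskell
import Control.Applicative (liftA2)

-- Hypothetical stand-in for `deriv`: a central-difference approximation.
-- The step size is an arbitrary choice for illustration.
numDeriv :: (Double -> Double) -> (Double -> Double)
numDeriv f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-5

-- Pointwise product of functions, i.e. the text's `(*)` on functions.
mulF :: (Double -> Double) -> (Double -> Double) -> (Double -> Double)
mulF = liftA2 (*)

-- Left and right sides of: deriv (g . f) == (deriv g . f) * deriv f
chainLHS, chainRHS :: (Double -> Double) -> (Double -> Double) -> Double -> Double
chainLHS g f = numDeriv (g . f)
chainRHS g f = (numDeriv g . f) `mulF` numDeriv f
```

For instance, `chainLHS sin (\x -> x * x) 0.7` and `chainRHS sin (\x -> x * x) 0.7` should agree up to the approximation error of the finite difference.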

The traditional forward AD formulation is based on the chain rule but is not as symmetric as the chain rule. In the function compositions `g . f` considered, `g` is always simple, while `f` may be arbitrarily complex. (I think the reverse is true for reverse-mode AD.) What might we find if we delay introducing this asymmetry?

### A direct implementation of the chain rule

The chain rule applies to functions and their derivatives, so let’s formulate a direct implementation. Start with a type for holding two functions:

``````data FD a = FD (a -> a) (a -> a)
``````

`FD` is used to hold functions and their derivatives:

``````toFD :: (a -> a) -> FD a
toFD f = FD f (deriv f)
``````

We do not have an implementation of `deriv`, so `toFD` here is part of the specification only, not the implementation.

Now we can specify a composition operator on `FD`:

``````(~.~) :: Num a => FD a -> FD a -> FD a
``````

We’ll want `(~.~)` to represent composition of functions, where we have access to the derivatives as well. That is, `(~.~)` must satisfy:

``````toFD (g . f) == toFD g ~.~ toFD f
``````

The implementation and its correctness follow from the chain rule:

``````FD g g' ~.~ FD f f' = FD (g . f) ((g' . f) * f')
``````

Exercise: Fill in the proof that `(~.~)` satisfies its specification.

(Exercise solutions are in the post’s HTML.)
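To see `(~.~)` at work without `deriv`, one can pair a few functions with their derivatives by hand. This is a minimal sketch: `sinFD`, `sqrFD`, and `sample` are illustrative names that do not appear in the post, and the product of functions is written with `liftA2 (*)` so that no `Num` instance for functions is required.

```haskell
import Control.Applicative (liftA2)

-- A function paired with its derivative, as in the text.
data FD a = FD (a -> a) (a -> a)

-- Composition via the chain rule. The Num constraint is needed for the
-- pointwise product of the two derivative functions.
(~.~) :: Num a => FD a -> FD a -> FD a
FD g g' ~.~ FD f f' = FD (g . f) (liftA2 (*) (g' . f) f')

-- Hand-written (function, derivative) pairs; illustrative names only.
sinFD, sqrFD :: FD Double
sinFD = FD sin cos
sqrFD = FD (\x -> x * x) (\x -> 2 * x)

-- Sample both components of an FD at a point.
sample :: FD a -> a -> (a, a)
sample (FD f f') x = (f x, f' x)
```

Here `sample (sinFD ~.~ sqrFD) x` yields the pair `(sin (x*x), cos (x*x) * (2*x))`, which is what the chain rule predicts for `sin . (\x -> x * x)`.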

### From function to values

The `FD` type and its composition operator implement the chain rule quite directly. However, they are not suitable for AD, which operates on the values (range) of a function and of its derivative.

Let’s start over with the usual AD value representation:

``````data D a = D a a
``````

Previously, I defined this ideal construction function:

``````toD :: (a -> a) -> a -> D a
toD f x = D (f x) (deriv f x)

-- or
toD f == liftA2 D f (deriv f)
``````

Instead, let’s now define `toD` in terms of `toFD`. First, how do the `FD` and `D` representations relate?

``````fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'
``````

Then we can define `toD`:

``````toD = fdAt . toFD
``````

Exercise: Show that these definitions of `toD` are equivalent.

Again, `toD` isn’t executable with this definition, because `toFD` isn’t (because `deriv` isn’t). As before, `toD` must be eliminated in our journey from specification to implementation.
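Although `toD` itself is not executable, `fdAt` is, so the relationship between the two representations can be tried out whenever a derivative is supplied by hand. In this sketch, `toDWith` and `squareAt` are hypothetical helpers, and the `deriving (Eq, Show)` clause is added only to make results easy to inspect.

```haskell
import Control.Applicative (liftA2)

data FD a = FD (a -> a) (a -> a)

-- Eq and Show are derived here only for inspection; the post's D has neither.
data D a = D a a deriving (Eq, Show)

-- Sample a function and its derivative at the same point.
fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'

-- Hypothetical stand-in for toD that takes the derivative as an argument
-- instead of calling the non-executable deriv.
toDWith :: (a -> a) -> (a -> a) -> a -> D a
toDWith f f' = fdAt (FD f f')

-- The squaring function together with its derivative.
squareAt :: Double -> D Double
squareAt = toDWith (\x -> x * x) (\x -> 2 * x)
```

Evaluating `squareAt 3` gives a `D` holding the value `9` and the derivative `6`.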

Optional: We can also define an odd sort of inverse for `toD`:

``````fromD :: D a -> a -> (a -> a)
``````

`fromD` must satisfy, for all `x`,

``````toD (fromD d x) x == d
``````

It’s more convenient to relate flipped versions of `toD` and `fromD`:

``````toD'   :: a -> (a -> a) ->   D a
fromD' :: a ->    D a   -> (a -> a)

toD'   = flip toD
fromD' = flip fromD
``````

Then `fromD` must satisfy, for all `x`,

``````toD' x . fromD' x == id
``````

Exercise: Give a simple definition for `fromD` and show that it’s correct (satisfies its specification).

### A general, value-friendly chain rule

In *What is AD …?*, I defined correctness of the numeric class instances for `D` by saying that `toD` must be a type class morphism for each of the numeric classes it implements. For example, let’s take the `sin` method. The other unary methods will work just like it. The morphism property:

``````toD (sin u) == sin (toD u)
``````

Because of numeric overloading on functions, this property is equivalent to a more explicit one:

``````toD (sin . u) == sin . toD u
``````

The `sin` on the left is on numbers, and the `sin` on the right is on `D a`.
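The numeric overloading on functions used here can be supplied by instances along the following lines. This is a sketch of one common pointwise formulation and is not necessarily the exact set of instances the post relies on; `sinOfSquare` and `sqrtDeriv` are illustrative names.

```haskell
import Control.Applicative (liftA2)

-- Numeric operations on functions, applied pointwise.
instance Num b => Num (a -> b) where
  (+)         = liftA2 (+)
  (-)         = liftA2 (-)
  (*)         = liftA2 (*)
  negate      = fmap negate
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = pure . fromInteger   -- literals become constant functions

instance Fractional b => Fractional (a -> b) where
  (/)          = liftA2 (/)
  recip        = fmap recip
  fromRational = pure . fromRational

instance Floating b => Floating (a -> b) where
  pi    = pure pi
  exp   = fmap exp
  log   = fmap log
  sqrt  = fmap sqrt
  sin   = fmap sin
  cos   = fmap cos
  asin  = fmap asin
  acos  = fmap acos
  atan  = fmap atan
  sinh  = fmap sinh
  cosh  = fmap cosh
  asinh = fmap asinh
  acosh = fmap acosh
  atanh = fmap atanh

-- With these instances, `sin u` means `sin . u`.
sinOfSquare :: Double -> Double
sinOfSquare = sin (\x -> x * x)

-- A function built from function-level arithmetic: \x -> recip (2 * sqrt x).
sqrtDeriv :: Double -> Double
sqrtDeriv = recip (2 * sqrt)
```

Since `fmap` on functions is composition, `sin u` and `sin . u` denote the same function, which is why the two forms of the morphism property coincide.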

Let’s suppose we have a function `adiff` (for automatic differentiation) such that for all `g` and `f`,

``````toD (g . f) == adiff g . toD f      -- specification of adiff

adiff :: Num a => (a -> a) -> (D a -> D a)
``````

Then our goal would become

``````adiff sin . toD u == sin . toD u
``````

and a correct definition of `sin` would be immediate, as would be the other definitions:

``````sin  = adiff sin
sqrt = adiff sqrt
...
``````

Note that the `adiff` specification above implies that for all `g`,

``````toD g == adiff g . toD id
``````

Exercise: Show that a necessary and sufficient definition for `adiff` satisfying its specification is

``````adiff g (D a a') = D (g a) (deriv g a * a')
``````

Derive this definition of `adiff` from its specification.

The `adiff` function satisfies a more symmetric property as well. It distributes over composition:

``````adiff (h . g) == adiff h . adiff g
``````

Exercise: Prove this property from the specification.

Moreover, `adiff` maps the identity to the identity: `adiff id = id`.
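These two properties can be spot-checked on concrete functions. Because `deriv` is not executable, the sketch below uses a hypothetical `adiffWith` that takes the derivative as an explicit second argument; the remaining names are illustrative as well. It is a numeric check, not a proof.

```haskell
-- Eq and Show are derived only so results can be compared and printed.
data D a = D a a deriving (Eq, Show)

-- Hypothetical variant of adiff with the derivative g' passed explicitly.
adiffWith :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
adiffWith g g' (D a a') = D (g a) (g' a * a')

sq :: Double -> Double
sq x = x * x

-- adiff (sin . sq), with the derivative of the composite written by hand ...
composedOnce :: D Double -> D Double
composedOnce = adiffWith (sin . sq) (\x -> cos (sq x) * (2 * x))

-- ... versus adiff sin . adiff sq, built from the two pieces.
composedInParts :: D Double -> D Double
composedInParts = adiffWith sin cos . adiffWith sq (\x -> 2 * x)

-- adiff id: the derivative of id is the constant function 1.
identityD :: D Double -> D Double
identityD = adiffWith id (const 1)
```

Applied to the same `D` value, `composedOnce` and `composedInParts` should agree up to floating-point rounding, and `identityD` should return its argument unchanged.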

Exercise: Show that for any definition of `adiff`, if for all `g`,

``````toD g == adiff g . toD id
``````

and if for all `h` and `g`

``````adiff (h . g) == adiff h . adiff g
``````

then our `adiff` specification holds, i.e., for all `g` and `f`,

``````adiff g . toD f == toD (g . f)
``````

### Back to an implementation

We’re still not quite done, since `adiff` depends on `deriv`, which doesn’t have an implementation. Let’s separate out the problematic `deriv` by refactoring `adiff`:

``````adiff g = g >-< deriv g
``````

where

``````infix  0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')
``````

After inlining this definition of `adiff`, the method definitions are

``````sin  = sin  >-< deriv sin
sqrt = sqrt >-< deriv sqrt
...
``````

Every remaining use of `deriv` is applied to a function whose derivative is known, so we can replace each use.

``````sin  = sin  >-< cos
sqrt = sqrt >-< recip (2 * sqrt)
...
``````

Now we have an executable implementation again. These method definitions and the definition of `(>-<)` are exactly as in *What is automatic differentiation, and why does it work?*.
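As a small runnable illustration, here is `(>-<)` with two of the definitions above. This is a sketch under a few assumptions: the post defines `sin` and `sqrt` as class methods on `D`, whereas the standalone names `sinD` and `sqrtD` are used here to avoid writing out full instances; `var` is a hypothetical helper for the independent variable; and the derivative of `sqrt` is written as a lambda so that no `Num` instance for functions is needed.

```haskell
-- Eq and Show are derived only so results can be inspected.
data D a = D a a deriving (Eq, Show)

infix 0 >-<
-- Pair a function with its derivative and apply the chain rule to a D value.
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')

-- The independent variable at x: value x, derivative 1 (hypothetical helper).
var :: Num a => a -> D a
var x = D x 1

-- Standalone counterparts of the post's method definitions.
sinD, sqrtD :: Floating a => D a -> D a
sinD  = sin  >-< cos
sqrtD = sqrt >-< (\x -> recip (2 * sqrt x))
```

For example, `sinD (sqrtD (var 4))` computes the value `sin 2` together with the derivative `cos 2 * 0.25`, that is, the derivative of `sin . sqrt` at `4`.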

### Fun with rules

Let’s back up to our more elegant method definitions using `adiff`:

``````sin  = adiff sin
sqrt = adiff sqrt
...
``````

We made these definitions executable in spite of their appeal to the non-executable `deriv` by (a) refactoring `adiff` to split the `deriv` from the residual function `(>-<)`, (b) inlining `adiff`, and (c) rewriting applications of `deriv` with known derivative rules.

Now let’s get GHC to do these steps for us.

List the derivatives of known functions:

``````{-# RULES

"deriv negate"  deriv negate  =  -1
"deriv abs"     deriv abs     =  signum
"deriv signum"  deriv signum  =  0
"deriv recip"   deriv recip   =  - sqr recip
"deriv exp"     deriv exp     =  exp
"deriv log"     deriv log     =  recip
"deriv sqrt"    deriv sqrt    =  recip (2 * sqrt)
"deriv sin"     deriv sin     =  cos
"deriv cos"     deriv cos     =  - sin
"deriv asin"    deriv asin    =  recip (sqrt (1-sqr))
"deriv acos"    deriv acos    =  recip (- sqrt (1-sqr))
"deriv atan"    deriv atan    =  recip (1+sqr)
"deriv sinh"    deriv sinh    =  cosh
"deriv cosh"    deriv cosh    =  sinh
"deriv asinh"   deriv asinh   =  recip (sqrt (1+sqr))
"deriv acosh"   deriv acosh   =  recip (sqrt (sqr-1))
"deriv atanh"   deriv atanh   =  recip (1-sqr)

#-}
``````

Notice that these definitions are simpler and more modular than the standard differentiation rules, as they do not have the chain rule mixed in. For instance, compare (a) `deriv sin = cos`, (b) `deriv (sin u) == cos u * deriv u`, and (c) `deriv (sin u) x == cos u x * deriv u x`.

Now we can use the incredibly simple `adiff`-based definitions of our methods, e.g., `asin = adiff asin`.

The definition of `adiff` must get inlined so as to reveal the `deriv` applications, which then get rewritten according to the rules. Fortunately, the `adiff` definition is tiny, which encourages its inlining. We could add an `INLINE` pragma as a reminder. GHC requires that a definition be given for `deriv`, even if all uses are rewritten away, so use the following:

``````deriv = error "deriv: undefined.  Missing rewrite rule?"
``````

### 4 Comments

1. #### newsham:

I don’t know whether to laugh or cry. I think I like it.

2. #### Ryan Ingram:

The use of rewrite rules is a cool hack, but it feels too fragile for production code. The “rule” I use for my code (apologies for the pun) is that rewrite rules should not be relied on for correctness, only performance.

3. #### conal:

Hi Ryan,

> … rewrite rules should not be relied on for correctness, …

That guideline is the conventional wisdom I hear. I hope that examples like the one above motivate a shift, so that rewrite rules can be counted on, considering the boost in modularity that results.

4. #### Ryan Ingram:

I see where you are going, but you’re relying on a lot of compiler machinery (correct inlining, rewrites, etc.) in order to get a modest gain (adiff g = g >-< deriv g).

I’d prefer a solution in which this reliance is made explicit at compile time; basically relying on Template Haskell or the like to enforce that the g passed to adiff is a compile-time-constant which can be safely passed to deriv.

In fact, I think TH could be useful for a lot of things if the syntax for invoking it wasn’t so ugly. I’d love for a “transparent” TH. Consider a macro application in LISP: in the source code, it looks just like a function call. Perhaps a TH function should get some form of context (the AST it is a part of) and be allowed to return parts of that context “unmodified” after picking out the data it expects, instead of needing to use the ugly splice syntax.

This would give you the same power that you get from rewrite rules but relying on the specification of the language rather than the internals of the optimizer.