From the chain rule to automatic differentiation

In What is automatic differentiation, and why does it work?, I gave a semantic model that explains what automatic differentiation (AD) accomplishes. Correct implementations then flowed from that model by applying the principle of type class morphisms: an instance’s interpretation is the interpretation’s instance.

I’ve had a nagging discomfort about the role of the chain rule in AD, with an intuition that the chain rule can carry a more central role in the specification and implementation. This post gives a variation on the previous AD post that carries the chain rule further into the reasoning and implementation, leading to simpler correctness proofs and a nearly unaltered implementation.

Finally, as a bonus, I’ll show how GHC rewrite rules enable an even simpler and more modular implementation.

I’ve included some optional content, including exercises. You can see my answers to the exercises by examining the HTML.

As before, I’ll start with a limited form of differentiation that works for functions of a scalar (1D) domain, where one can identify derivative values with regular values:

deriv :: Num a => (a -> a) -> (a -> a) --  simplification

The development below extends to higher-order derivatives and higher-dimensional domains.

The chain rule

At the heart of AD is the chain rule:

deriv (g . f) x == deriv g (f x) * deriv f x

Equivalently,

deriv (g . f) == (deriv g . f) * deriv f

where this (*) is on functions: (*) = liftA2 (*) == \ h k x -> h x * k x.
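As a quick numeric sanity check of the function-level form, here is a minimal sketch. Since deriv has no implementation, the derivatives of the two example functions are written out by hand, and the names mulF, f, f', g, g', lhs, and rhs are illustrative choices, not part of the original development.

```haskell
import Control.Applicative (liftA2)

-- Pointwise product of functions: the (*) used in the function-level chain rule.
mulF :: Num b => (a -> b) -> (a -> b) -> (a -> b)
mulF = liftA2 (*)

-- Example functions with hand-written derivatives (assumed, not computed).
f, f', g, g' :: Double -> Double
f  x = x * x
f' x = 2 * x
g    = sin
g'   = cos

-- Right-hand side of the chain rule: (deriv g . f) * deriv f
rhs :: Double -> Double
rhs = (g' . f) `mulF` f'

-- Known closed form of deriv (g . f) for this example: cos (x^2) * 2x
lhs :: Double -> Double
lhs x = cos (x * x) * (2 * x)
```

Evaluating lhs and rhs at a few sample points in GHCi should give matching results, which is all this sketch is meant to show.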

The traditional forward AD formulation is based on the chain rule but is not as symmetric as the chain rule. In the function compositions g . f considered, g is always simple, while f may be arbitrarily complex. (I think the reverse is true for reverse-mode AD.) What might we find if we delay introducing this asymmetry?

A direct implementation of the chain rule

The chain rule applies to functions and their derivatives, so let’s formulate a direct implementation. Start with a type for holding two functions:

data FD a = FD (a -> a) (a -> a)

FD is used to hold functions and their derivatives:

toFD :: Num a => (a -> a) -> FD a
toFD f = FD f (deriv f)

We do not have an implementation of deriv, so toFD here is part of the specification only, not the implementation.

Now we can specify a composition operator on FD:

(~.~) :: Num a => FD a -> FD a -> FD a

We’ll want (~.~) to represent composition of functions, where we have access to the derivatives as well. That is, (~.~) must satisfy:

toFD (g . f) == toFD g ~.~ toFD f

The implementation and its correctness follow from the chain rule:

FD g g' ~.~ FD f f' = FD (g . f) ((g' . f) * f')

Exercise: Fill in the proof that (~.~) satisfies its specification.

(Exercise solutions are in the post’s HTML.)
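To see the composition operator at work, here is a small runnable sketch. It assumes a Num a constraint on (~.~), which the multiplication needs, and it builds FD values by hand in place of the non-executable toFD. The names sinFD, sqrFD, and sampleFD are hypothetical helpers introduced only for this illustration.

```haskell
import Control.Applicative (liftA2)

-- A function paired with its derivative, as in the text.
data FD a = FD (a -> a) (a -> a)

-- Chain-rule composition, with the function-level (*) spelled as liftA2 (*).
infixr 9 ~.~
(~.~) :: Num a => FD a -> FD a -> FD a
FD g g' ~.~ FD f f' = FD (g . f) (liftA2 (*) (g' . f) f')

-- Hand-built FD values standing in for toFD sin and toFD (\x -> x * x).
sinFD, sqrFD :: FD Double
sinFD = FD sin cos
sqrFD = FD (\x -> x * x) (\x -> 2 * x)

-- Sample both components of an FD at a point.
sampleFD :: FD a -> a -> (a, a)
sampleFD (FD h h') x = (h x, h' x)
```

For example, sampleFD (sinFD ~.~ sqrFD) x should agree with (sin (x * x), cos (x * x) * (2 * x)), up to floating-point rounding.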

From functions to values

The FD type and its composition operator implement the chain rule quite directly. However, they are not suitable for AD, which operates on the values (range) of a function and of its derivative.

Let’s start over with the usual AD value representation:

data D a = D a a

Previously, I defined this ideal construction function:

toD :: Num a => (a -> a) -> a -> D a
toD f x = D (f x) (deriv f x)

-- or
toD f == liftA2 D f (deriv f)

Instead, let’s now define toD in terms of toFD. First, how do the FD and D representations relate?

fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'

Then we can define toD:

toD = fdAt . toFD

Exercise: Show that these definitions of toD are equivalent.

Again, toD isn’t executable with this definition, because toFD isn’t (because deriv isn’t). As before, toD must be eliminated in our journey from specification to implementation.
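Although toD itself is not executable, fdAt is. The following sketch applies fdAt to a hand-built FD value, so the result plays the role of toD for one particular function. The names cubeFD and cubeD are hypothetical, and the derivative is supplied by hand rather than by deriv.

```haskell
import Control.Applicative (liftA2)

-- The two representations from the text.
data FD a = FD (a -> a) (a -> a)
data D a = D a a deriving (Eq, Show)

-- Sample a function and its derivative at the same point.
fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'

-- Hypothetical stand-in for toFD (\x -> x * x * x): derivative written by hand.
cubeFD :: FD Double
cubeFD = FD (\x -> x * x * x) (\x -> 3 * x * x)

-- Plays the role of toD (\x -> x * x * x), obtained via fdAt.
cubeD :: Double -> D Double
cubeD = fdAt cubeFD
```

Here cubeD 2 evaluates to D 8.0 12.0: the value 2^3 paired with the derivative 3 * 2^2.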

Optional: We can also define an odd sort of inverse for toD:

fromD :: D a -> a -> (a -> a)

fromD must satisfy, for all x,

toD (fromD d x) x == d

It’s more convenient to relate flipped versions of toD and fromD:

toD'   :: a -> (a -> a) ->   D a
fromD' :: a ->    D a   -> (a -> a)

toD'   = flip toD
fromD' = flip fromD

Then fromD must satisfy, for all x,

toD' x . fromD' x == id

Exercise: Give a simple definition for fromD and show that it’s correct (satisfies its specification).

A general, value-friendly chain rule

In What is AD …?, I defined correctness of the numeric class instances for D by saying that toD must be a type class morphism for each of the numeric classes it implements. For example, let’s take the sin method. The other unary methods will work just like it. The morphism property:

toD (sin u) == sin (toD u)

Because of numeric overloading on functions, this property is equivalent to a more explicit one:

toD (sin . u) == sin . toD u

The sin on the left is on numbers, and the sin on the right is on D a.

Let’s suppose we have a function adiff (for automatic differentiation) such that for all g and f,

toD (g . f) == adiff g . toD f      -- specification of adiff

adiff :: Num a => (a -> a) -> (D a -> D a)

Then our goal would become

adiff sin . toD u == sin . toD u

and a correct definition of sin would be immediate, as would be the other definitions:

sin  = adiff sin
sqrt = adiff sqrt
...

Note that the adiff specification above implies that for all g,

toD g == adiff g . toD id

Exercise: Show that a necessary and sufficient definition for adiff satisfying its specification is

adiff g (D a a') = D (g a) (deriv g a * a')

Derive this definition of adiff from its specification.

The adiff function satisfies a more symmetric property as well. It distributes over composition:

adiff (h . g) == adiff h . adiff g

Exercise: Prove this property from the specification.

Moreover, adiff maps the identity to the identity: adiff id = id.
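These two properties can be spot-checked numerically. Because adiff relies on the non-executable deriv, the sketch below uses a hypothetical variant, adiffWith, that takes the derivative as an explicit second argument. The example functions and their derivatives are chosen by hand, so this is an illustration under those assumptions rather than a proof.

```haskell
data D a = D a a deriving (Eq, Show)

-- adiff with the derivative passed explicitly (hypothetical helper; the
-- text's adiff obtains it from the non-executable deriv instead).
adiffWith :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
adiffWith g g' (D a a') = D (g a) (g' a * a')

-- Example with h = exp and g = sin, derivatives written out by hand.
viaCompose, direct :: D Double -> D Double
-- adiff h . adiff g
viaCompose = adiffWith exp exp . adiffWith sin cos
-- adiff (h . g), using deriv (exp . sin) x == exp (sin x) * cos x
direct     = adiffWith (exp . sin) (\x -> exp (sin x) * cos x)

-- Identity case: deriv id is the constant function 1.
identityD :: D Double -> D Double
identityD = adiffWith id (const 1)
```

Applied to the same D value, viaCompose and direct should agree up to floating-point rounding (the two multiply in a different order), and identityD should return its argument unchanged.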

Exercise: Show that for any definition of adiff, if for all g,

toD g == adiff g . toD id

and if for all h and g

adiff (h . g) == adiff h . adiff g

then our adiff specification holds, i.e., for all g and f,

adiff g . toD f == toD (g . f)

Back to an implementation

We’re still not quite done, since adiff depends on deriv, which doesn’t have an implementation. Let’s separate out the problematic deriv by refactoring adiff:

adiff g = g >-< deriv g

where

infix  0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')

After inlining this definition of adiff, the method definitions are

sin  = sin  >-< deriv sin
sqrt = sqrt >-< deriv sqrt
...

Every remaining use of deriv is applied to a function whose derivative is known, so we can replace each use.

sin  = sin  >-< cos
sqrt = sqrt >-< recip (2 * sqrt)
...

Now we have an executable implementation again. These method definitions and the definition of (>-<) are exactly as in What is automatic differentiation, and why does it work?.
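For a self-contained taste of the executable version, here is a minimal sketch. To avoid writing class instances, the lifted functions are given standalone names with a D suffix (sinD, sqrtD), and recip (2 * sqrt) is expanded into an explicit lambda so that no Num instance for functions is needed. Those names, plus idD and sinSqrt, are assumptions made for this example only.

```haskell
data D a = D a a deriving (Eq, Show)

-- Pair a function with its derivative and apply the chain rule to a D value.
infix 0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')

-- Lifted functions, named with a D suffix here (hypothetical names) instead
-- of being Floating instance methods.
sinD, sqrtD :: Floating a => D a -> D a
sinD  = sin  >-< cos
sqrtD = sqrt >-< (\x -> recip (2 * sqrt x))

-- The identity function sampled at x: value x, derivative 1.
idD :: Num a => a -> D a
idD x = D x 1

-- sin (sqrt x) together with its derivative at x.
sinSqrt :: Double -> D Double
sinSqrt = sinD . sqrtD . idD
```

At x = 4, sinSqrt should produce a value close to sin 2 and a derivative close to cos 2 / 4, matching the chain rule applied by hand.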

Fun with rules

Let’s back up to our more elegant method definitions using adiff:

sin  = adiff sin
sqrt = adiff sqrt
...

We made these definitions executable in spite of their appeal to the non-executable deriv by (a) refactoring adiff to split the deriv from the residual function (>-<), (b) inlining adiff, and (c) rewriting applications of deriv with known derivative rules.

Now let’s get GHC to do these steps for us.

List the derivatives of known functions:

{-# RULES

"deriv negate"  deriv negate  =  -1
"deriv abs"     deriv abs     =  signum
"deriv signum"  deriv signum  =  0
"deriv recip"   deriv recip   =  - sqr recip
"deriv exp"     deriv exp     =  exp
"deriv log"     deriv log     =  recip
"deriv sqrt"    deriv sqrt    =  recip (2 * sqrt)
"deriv sin"     deriv sin     =  cos
"deriv cos"     deriv cos     =  - sin
"deriv asin"    deriv asin    =  recip (sqrt (1-sqr))
"deriv acos"    deriv acos    =  recip (- sqrt (1-sqr))
"deriv atan"    deriv atan    =  recip (1+sqr)
"deriv sinh"    deriv sinh    =  cosh
"deriv cosh"    deriv cosh    =  sinh
"deriv asinh"   deriv asinh   =  recip (sqrt (1+sqr))
"deriv acosh"   deriv acosh   =  recip (sqrt (sqr-1))
"deriv atanh"   deriv atanh   =  recip (1-sqr)

 #-}

Notice that these definitions are simpler and more modular than the standard differentiation rules, as they do not have the chain rule mixed in. For instance, compare (a) deriv sin = cos, (b) deriv (sin u) == cos u * deriv u, and (c) deriv (sin u) x == cos u x * deriv u x.

Now we can use the incredibly simple adiff-based definitions of our methods, e.g., asin = adiff asin.

The definition of adiff must get inlined so as to reveal the deriv applications, which then get rewritten according to the rules. Fortunately, the adiff definition is tiny, which encourages its inlining. We could add an INLINE pragma as a reminder. GHC requires that a definition be given for deriv, even if all uses are rewritten away, so use the following:

deriv = error "deriv: undefined.  Missing rewrite rule?"

4 Comments

  1. newsham:

    I don’t know whether to laugh or cry. I think I like it.

  2. Ryan Ingram:

    The use of rewrite rules is a cool hack, but it feels too fragile for production code. The “rule” I use for my code (apologies for the pun) is that rewrite rules should not be relied on for correctness, only performance.

  3. conal:

    Hi Ryan,

    … rewrite rules should not be relied on for correctness, …

    That guideline is the conventional wisdom I hear. I hope that examples like the one above motivate a shift, so that rewrite rules can be counted on, considering the boost in modularity that results.

  4. Ryan Ingram:

    I see where you are going, but you’re relying on a lot of compiler machinery (correct inlining, rewrites, etc.) in order to get a modest gain (adiff g = g >-< deriv g).

    I’d prefer a solution in which this reliance is made explicit at compile time; basically relying on Template Haskell or the like to enforce that the g passed to adiff is a compile-time-constant which can be safely passed to deriv.

    In fact, I think TH could be useful for a lot of things if the syntax for invoking it wasn’t so ugly. I’d love for a “transparent” TH. Consider a macro application in LISP: in the source code, it looks just like a function call. Perhaps a TH function should get some form of context (the AST it is a part of) and be allowed to return parts of that context “unmodified” after picking out the data it expects, instead of needing to use the ugly splice syntax.

    This would give you the same power that you get from rewrite rules but relying on the specification of the language rather than the internals of the optimizer.
