From the chain rule to automatic differentiation
In What is automatic differentiation, and why does it work?, I gave a semantic model that explains what automatic differentiation (AD) accomplishes. Correct implementations then flowed from that model, by applying the principle of type class morphisms. (An instance’s interpretation is the interpretation’s instance.)
I’ve had a nagging discomfort about the role of the chain rule in AD, with an intuition that the chain rule can play a more central role in the specification and implementation. This post gives a variation on the previous AD post that carries the chain rule further into the reasoning and implementation, leading to simpler correctness proofs and a nearly unaltered implementation.
Finally, as a bonus, I’ll show how GHC rewrite rules enable an even simpler and more modular implementation.
I’ve included some optional content, including exercises. You can see my answers to the exercises by examining the HTML.
As before, I’ll start with a limited form of differentiation that works for functions of a scalar (1D) domain, where one can identify derivative values with regular values:
deriv :: Num a => (a -> a) -> (a -> a) -- simplification
The development below extends to higher-order derivatives and higher-dimensional domains.
The chain rule
At the heart of AD is the chain rule:
deriv (g . f) x == deriv g (f x) * deriv f x
Equivalently,
deriv (g . f) == (deriv g . f) * deriv f
where this (*) is on functions: (*) = liftA2 (*) == \ h k x -> h x * k x.
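To make the function-level form concrete, here is a small numerical sanity check of the chain rule. It is a sketch under two assumptions not in the original: a pointwise Num instance for functions (the same idea as (*) = liftA2 (*)), and a hypothetical numDeriv, a central-difference approximation standing in for the non-executable deriv. The example functions f and g are arbitrary.

```haskell
import Control.Applicative (liftA2)

-- Pointwise numeric operations on functions (an assumption of this sketch),
-- so that (h * k) x == h x * k x, as in the text.
instance Num b => Num (a -> b) where
  (+)         = liftA2 (+)
  (*)         = liftA2 (*)
  (-)         = liftA2 (-)
  negate      = fmap negate
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = const . fromInteger

-- Hypothetical stand-in for deriv: a central-difference approximation.
numDeriv :: (Double -> Double) -> Double -> Double
numDeriv f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-5

-- Approximate equality, since numDeriv is only approximate.
approxEq :: Double -> Double -> Bool
approxEq a b = abs (a - b) < 1e-6

-- Arbitrary example functions.
f, g :: Double -> Double
f x = x * x + 1
g = sin

-- Both sides of: deriv (g . f) == (deriv g . f) * deriv f
chainLHS, chainRHS :: Double -> Double
chainLHS = numDeriv (g . f)
chainRHS = (numDeriv g . f) * numDeriv f

main :: IO ()
main = mapM_ (\x -> print (x, chainLHS x, chainRHS x)) [0, 0.5, 1.3]
```

The two columns agree up to the error of the finite-difference approximation; this is a spot check, not a proof.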
The traditional forward AD formulation is based on the chain rule but is not as symmetric as the chain rule.
In the function compositions g . f considered, g is always simple, while f may be arbitrarily complex.
(I think the reverse is true for reverse-mode AD.)
What might we find if we delay introducing this asymmetry?
A direct implementation of the chain rule
The chain rule applies to functions and their derivatives, so let’s formulate a direct implementation. Start with a type for holding two functions:
data FD a = FD (a -> a) (a -> a)
FD is used to hold functions and their derivatives:
toFD :: (a -> a) -> FD a
toFD f = FD f (deriv f)
We do not have an implementation of deriv, so toFD here is part of the specification only, not the implementation.
Now we can specify a composition operator on FD:
(~.~) :: Num a => FD a -> FD a -> FD a
We’ll want (~.~) to represent composition of functions, where we have access to the derivatives as well. That is, (~.~) must satisfy:
toFD (g . f) == toFD g ~.~ toFD f
The implementation and its correctness follow from the chain rule:
FD g g' ~.~ FD f f' = FD (g . f) ((g' . f) * f')
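A runnable sketch of FD and (~.~) follows. Because toFD is not executable, the FD values here are built by hand from known function/derivative pairs. The names sinFD, sqrFD, valueAt and derivAt are hypothetical, and the pointwise Num instance on functions is assumed so that (*) on functions works as in the text.

```haskell
import Control.Applicative (liftA2)

-- Pointwise Num instance on functions (assumed, matching the text's (*)).
instance Num b => Num (a -> b) where
  (+)         = liftA2 (+)
  (*)         = liftA2 (*)
  (-)         = liftA2 (-)
  negate      = fmap negate
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = const . fromInteger

-- A function paired with its derivative.
data FD a = FD (a -> a) (a -> a)

-- Composition that also composes derivatives, by the chain rule.
(~.~) :: Num a => FD a -> FD a -> FD a
FD g g' ~.~ FD f f' = FD (g . f) ((g' . f) * f')

-- Hypothetical observers for an FD.
valueAt, derivAt :: FD a -> a -> a
valueAt (FD f _)  = f
derivAt (FD _ f') = f'

-- Hand-written function/derivative pairs (toFD itself cannot run).
sinFD, sqrFD :: FD Double
sinFD = FD sin cos
sqrFD = FD (\x -> x * x) (\x -> 2 * x)

main :: IO ()
main = do
  let h = sinFD ~.~ sqrFD  -- represents \x -> sin (x * x)
  print (valueAt h 1.5)    -- sin 2.25
  print (derivAt h 1.5)    -- cos 2.25 * 3
```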
Exercise: Fill in the proof that (~.~) satisfies its specification.
(Exercise solutions are in the post’s HTML.)
From functions to values
The FD type and its composition operator implement the chain rule quite directly.
However, they are not suitable for AD, which operates on the values (range) of a function and of its derivative.
Let’s start over with the usual AD value representation:
data D a = D a a
Previously, I defined this ideal construction function:
toD :: (a -> a) -> a -> D a
toD f x = D (f x) (deriv f x)
-- or
toD f == liftA2 D f (deriv f)
Instead, let’s now define toD in terms of toFD.
First, how do the FD and D representations relate?
fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'
Then we can define toD:
toD = fdAt . toFD
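In the function applicative, liftA2 D f f' x is just D (f x) (f' x). The sketch below checks this with a hand-built FD pair, since toFD (and hence this toD) cannot run; the Show and Eq instances on D are added only for printing and testing.

```haskell
import Control.Applicative (liftA2)

-- A function paired with its derivative.
data FD a = FD (a -> a) (a -> a)

-- A value paired with a derivative value (Show/Eq added for this sketch).
data D a = D a a deriving (Show, Eq)

-- Sample an FD at a point: liftA2 D f f' x == D (f x) (f' x).
fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'

main :: IO ()
main = print (fdAt (FD sin cos) (0 :: Double))
```

Running it prints D 0.0 1.0, the value and derivative of sin at 0.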
Exercise: Show that these definitions of toD are equivalent.
Again, toD isn’t executable with this definition, because toFD isn’t (because deriv isn’t).
As before, toD
must be eliminated in our journey from specification to implementation.
Optional: We can also define an odd sort of inverse for toD:
fromD :: D a -> a -> (a -> a)
fromD must satisfy, for all x,
toD (fromD d x) x == d
It’s more convenient to relate flipped versions of toD and fromD:
toD' :: a -> (a -> a) -> D a
fromD' :: a -> D a -> (a -> a)
toD' = flip toD
fromD' = flip fromD
Then fromD must satisfy, for all x,
toD' x . fromD' x == id
Exercise: Give a simple definition for fromD and show that it’s correct (satisfies its specification).
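One candidate fromD, offered as a sketch rather than as the author’s answer, is the affine function through the point (x, a) with slope a'. The helper toDApprox is hypothetical: it replaces deriv with a central-difference approximation so that the specification toD (fromD d x) x == d can be checked numerically.

```haskell
-- A value paired with a derivative value.
data D a = D a a deriving Show

-- Candidate fromD: the affine function through (x, a) with slope a'.
fromD :: Num a => D a -> a -> (a -> a)
fromD (D a a') x = \y -> a + a' * (y - x)

-- Hypothetical numeric stand-in for toD: central differences replace deriv.
toDApprox :: (Double -> Double) -> Double -> D Double
toDApprox f x = D (f x) ((f (x + h) - f (x - h)) / (2 * h))
  where h = 1e-5

main :: IO ()
main = print (toDApprox (fromD (D 3 2) 1.5) 1.5)  -- approximately D 3.0 2.0
```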
A general, value-friendly chain rule
In What is AD …?, I defined correctness of the numeric class instances for D by saying that toD must be a type class morphism for each of the numeric classes D implements.
For example, let’s take the sin method. The other unary methods will work just like it.
The morphism property:
toD (sin u) == sin (toD u)
Because of numeric overloading on functions, this property is equivalent to a more explicit one:
toD (sin . u) == sin . toD u
The sin on the left is on numbers, and the sin on the right is on D a.
Let’s suppose we have a function adiff (for automatic differentiation) such that for all g and f,
toD (g . f) == adiff g . toD f -- specification of adiff
adiff :: Num a => (a -> a) -> (D a -> D a)
Then our goal would become
adiff sin . toD u == sin . toD u
and a correct definition of sin would be immediate, as would be the other definitions:
sin = adiff sin
sqrt = adiff sqrt
...
Note that the adiff specification above implies that for all g,
toD g == adiff g . toD id
Exercise: Show that a necessary and sufficient definition for adiff satisfying its specification is
adiff g (D a a') = D (g a) (deriv g a * a')
Derive this definition of adiff from its specification.
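The specification of adiff can be spot-checked numerically. In this sketch a hypothetical numDeriv (a central-difference approximation) stands in for deriv, both in toD and in the definition of adiff above, and closeD compares D values up to a tolerance. The example compares both sides of toD (g . f) == adiff g . toD f for g = exp and an arbitrary f, and also the special case toD g == adiff g . toD id.

```haskell
-- A value paired with a derivative value.
data D a = D a a deriving Show

-- Hypothetical numeric stand-in for deriv (central differences).
numDeriv :: (Double -> Double) -> Double -> Double
numDeriv f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-5

-- toD, with numDeriv in place of deriv.
toD :: (Double -> Double) -> Double -> D Double
toD f x = D (f x) (numDeriv f x)

-- The definition of adiff from the text, with numDeriv in place of deriv.
adiff :: (Double -> Double) -> D Double -> D Double
adiff g (D a a') = D (g a) (numDeriv g a * a')

-- Approximate equality of D values.
closeD :: D Double -> D Double -> Bool
closeD (D a a') (D b b') = abs (a - b) < 1e-6 && abs (a' - b') < 1e-6

-- An arbitrary inner function for the example.
f :: Double -> Double
f x = x * x + 1

main :: IO ()
main = do
  -- toD (g . f) versus adiff g . toD f, with g = exp
  print (toD (exp . f) 0.8, (adiff exp . toD f) 0.8)
  -- toD g versus adiff g . toD id, with g = sin
  print (toD sin 0.4, adiff sin (toD id 0.4))
```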
The adiff function satisfies a more symmetric property as well. It distributes over composition:
adiff (h . g) == adiff h . adiff g
Exercise: Prove this property from the specification.
Moreover, adiff maps the identity to the identity: adiff id = id.
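The same numeric stand-in gives a quick spot check (not a proof) of both properties: adiff distributing over composition and adiff id acting as the identity. Again numDeriv is a hypothetical central-difference substitute for deriv, so the equalities hold only approximately.

```haskell
-- A value paired with a derivative value.
data D a = D a a deriving Show

-- Hypothetical numeric stand-in for deriv (central differences).
numDeriv :: (Double -> Double) -> Double -> Double
numDeriv f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-5

-- adiff as defined in the text, with numDeriv in place of deriv.
adiff :: (Double -> Double) -> D Double -> D Double
adiff g (D a a') = D (g a) (numDeriv g a * a')

-- Approximate equality of D values.
closeD :: D Double -> D Double -> Bool
closeD (D a a') (D b b') = abs (a - b) < 1e-6 && abs (a' - b') < 1e-6

main :: IO ()
main = do
  let d = D 0.3 2.0
  -- adiff (h . g) versus adiff h . adiff g, with h = sin and g = exp
  print (adiff (sin . exp) d, (adiff sin . adiff exp) d)
  -- adiff id versus id
  print (adiff id d, d)
```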
Exercise: Show that for any definition of adiff, if for all g,
toD g == adiff g . toD id
and if for all h and g,
adiff (h . g) == adiff h . adiff g
then our adiff specification holds, i.e., for all g and f,
adiff g . toD f == toD (g . f)
Back to an implementation
We’re still not quite done, since adiff depends on deriv, which doesn’t have an implementation.
Let’s separate out the problematic deriv by refactoring adiff:
adiff g = g >-< deriv g
where
infix 0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')
After inlining this definition of adiff, the method definitions are
sin = sin >-< deriv sin
sqrt = sqrt >-< deriv sqrt
...
Every remaining use of deriv is applied to a function whose derivative is known, so we can replace each use with that known derivative:
sin = sin >-< cos
sqrt = sqrt >-< recip (2 * sqrt)
...
Now we have an executable implementation again.
These method definitions and the definition of (>-<) are exactly as in What is automatic differentiation, and why does it work?
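Putting the pieces together, here is a self-contained sketch of the executable implementation with (>-<). It assumes pointwise Num and Fractional instances for functions (so that expressions like recip (2 * sqrt) and - sqr recip denote functions), a helper sqr x = x * x, and a hypothetical diffAt that seeds the derivative with 1. Only some Floating methods are given; the rest follow the same pattern.

```haskell
import Control.Applicative (liftA2)

-- Pointwise numeric operations on functions (an assumption of this sketch).
instance Num b => Num (a -> b) where
  (+)         = liftA2 (+)
  (*)         = liftA2 (*)
  (-)         = liftA2 (-)
  negate      = fmap negate
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = const . fromInteger

instance Fractional b => Fractional (a -> b) where
  (/)          = liftA2 (/)
  recip        = fmap recip
  fromRational = const . fromRational

-- Squaring; applied to a function, it squares pointwise.
sqr :: Num a => a -> a
sqr x = x * x

-- A value paired with its derivative.
data D a = D a a deriving Show

-- Lift a function and its derivative to D, applying the chain rule.
infix 0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')

instance Num a => Num (D a) where
  D a a' + D b b' = D (a + b) (a' + b')
  D a a' - D b b' = D (a - b) (a' - b')
  D a a' * D b b' = D (a * b) (a' * b + a * b')  -- product rule
  negate          = negate >-< (-1)
  abs             = abs    >-< signum
  signum          = signum >-< 0
  fromInteger n   = D (fromInteger n) 0

instance Fractional a => Fractional (D a) where
  recip          = recip >-< (- sqr recip)
  fromRational r = D (fromRational r) 0
  -- (/) uses the class default: x / y = x * recip y

instance Floating a => Floating (D a) where
  pi   = D pi 0
  exp  = exp  >-< exp
  log  = log  >-< recip
  sqrt = sqrt >-< recip (2 * sqrt)
  sin  = sin  >-< cos
  cos  = cos  >-< negate sin
  -- Remaining methods (asin, sinh, ...) omitted; GHC warns about them.

-- Hypothetical helper: evaluate at x with derivative seed 1.
diffAt :: Num a => (D a -> D a) -> a -> D a
diffAt f x = f (D x 1)

main :: IO ()
main = do
  print (diffAt sqrt (4 :: Double))
  print (diffAt (\x -> sin (x * x)) (1.5 :: Double))
```

For instance, diffAt sqrt 4 yields D 2.0 0.25, and the second line combines the product rule for (*) with the chain rule inside sin.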
Fun with rules
Let’s back up to our more elegant method definitions using adiff:
sin = adiff sin
sqrt = adiff sqrt
...
We made these definitions executable in spite of their appeal to the non-executable deriv by (a) refactoring adiff to split the deriv from the residual function (>-<), (b) inlining adiff, and (c) rewriting applications of deriv with known derivative rules.
Now let’s get GHC to do these steps for us.
List the derivatives of known functions:
{-# RULES
"deriv negate" deriv negate = -1
"deriv abs" deriv abs = signum
"deriv signum" deriv signum = 0
"deriv recip" deriv recip = - sqr recip
"deriv exp" deriv exp = exp
"deriv log" deriv log = recip
"deriv sqrt" deriv sqrt = recip (2 * sqrt)
"deriv sin" deriv sin = cos
"deriv cos" deriv cos = - sin
"deriv asin" deriv asin = recip (sqrt (1-sqr))
"deriv acos" deriv acos = recip (- sqrt (1-sqr))
"deriv atan" deriv atan = recip (1+sqr)
"deriv sinh" deriv sinh = cosh
"deriv cosh" deriv cosh = sinh
"deriv asinh" deriv asinh = recip (sqrt (1+sqr))
"deriv acosh" deriv acosh = recip (sqrt (sqr-1))
"deriv atanh" deriv atanh = recip (1-sqr)
#-}
Notice that these definitions are simpler and more modular than the standard differentiation rules, as they do not have the chain rule mixed in.
For instance, compare (a) deriv sin = cos, (b) deriv (sin u) == cos u * deriv u, and (c) deriv (sin u) x == cos u x * deriv u x.
Now we can use the incredibly simple adiff-based definitions of our methods, e.g., asin = adiff asin.
The definition of adiff must get inlined so as to reveal the deriv applications, which then get rewritten according to the rules.
Fortunately, the adiff definition is tiny, which encourages its inlining. We could add an INLINE pragma as a reminder.
GHC requires that deriv be given a definition, even if all uses are rewritten away, so use the following:
deriv = error "deriv: undefined. Missing rewrite rule?"
newsham:
I don’t know whether to laugh or cry. I think I like it.
9 February 2009, 8:51 pm
Ryan Ingram:
The use of rewrite rules is a cool hack, but it feels too fragile for production code. The “rule” I use for my code (apologies for the pun) is that rewrite rules should not be relied on for correctness, only performance.
11 February 2009, 2:39 pm
conal:
Hi Ryan,
That guideline is the conventional wisdom I hear. I hope that examples like the one above motivate a shift, so that rewrite rules can be counted on, considering the boost in modularity that results.
11 February 2009, 3:24 pm
Ryan Ingram:
I see where you are going, but you’re relying on a lot of compiler machinery (correct inlining, rewrites, etc.) in order to get a modest gain (adiff g = g >-< deriv g).
I’d prefer a solution in which this reliance is made explicit at compile time; basically relying on Template Haskell or the like to enforce that the g passed to adiff is a compile-time-constant which can be safely passed to deriv.
In fact, I think TH could be useful for a lot of things if the syntax for invoking it wasn’t so ugly. I’d love for a “transparent” TH. Consider a macro application in LISP: in the source code, it looks just like a function call. Perhaps a TH function should get some form of context (the AST it is a part of) and be allowed to return parts of that context “unmodified” after picking out the data it expects, instead of needing to use the ugly splice syntax.
This would give you the same power that you get from rewrite rules but relying on the specification of the language rather than the internals of the optimizer.
12 February 2009, 6:29 pm