What is automatic differentiation, and why does it work?
Bertrand Russell remarked that
Everything is vague to a degree you do not realize till you have tried to make it precise.
I’m mulling over automatic differentiation (AD) again, neatening up previous posts on derivatives and on linear maps, working them into a coherent whole for an ICFP submission. I understand the mechanics and some of the reasons for its correctness. After all, it’s "just the chain rule".
As usual, in the process of writing, I bumped up against Russell’s principle. I felt a growing uneasiness and realized that I didn’t understand AD in the way I like to understand software, namely,
- What does it mean, independently of implementation?
- How do the implementation and its correctness flow gracefully from that meaning?
- Where else might we go, guided by answers to the first two questions?
Ever since writing Simply efficient functional reactivity, the idea of type class morphisms keeps popping up for me as a framework in which to ask and answer these questions. To my delight, this framework gives me new and more satisfying insight into automatic differentiation.
What’s a derivative?
My first guess is that AD has something to do with derivatives, which raises the question of what a derivative is. For now, I'm going to substitute a popular but problematic answer to that question and say that
deriv ∷ ⋯ ⇒ (a → b) → (a → b) -- simplification
As discussed in What is a derivative, really?, the popular answer has limited usefulness, applying just to scalar (one-dimensional) domains. The real deal involves distinguishing the type b from the type a :-* b of linear maps from a to b.
deriv ∷ (VectorSpace u, VectorSpace v) ⇒ (u → v) → (u → (u :-* v))
Why care about derivatives?
Derivatives are useful in a variety of application areas, including root-finding, optimization, curve and surface tessellation, and computation of surface normals for 3D rendering. Considering the usefulness of derivatives, it is worthwhile to find software methods that are
- simple (to implement and verify),
- convenient,
- accurate,
- efficient, and
- general.
What isn't AD?
Numeric approximation
One differentiation method is numeric approximation, using simple finite differences. This method is based on the definition of the (scalar) derivative:
deriv f x ≡ lim_{h → 0} (f (x + h) - f x) / h
The left-hand side reads "the derivative of f at x".
To approximate the derivative, use
deriv f x ≈ (f (x + h) - f x) / h
for a small value of h. While very simple, this method is often inaccurate, because h is hard to choose well. Too large a value leaves a large truncation error (the difference quotient is far from its limit), while too small a value leads to rounding error (subtracting nearly equal floating-point numbers). More sophisticated variations improve accuracy while sacrificing simplicity.
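To make the trade-off concrete, here is a small Haskell sketch that measures the forward-difference error for sin at x = 1 over several step sizes. The names approxDeriv and errorAt are my own, and the exact error values depend on the floating-point arithmetic. With IEEE doubles one would expect the error to shrink as h decreases toward roughly 1e-8 and then grow again as rounding error takes over.

```haskell
-- Forward-difference approximation of the derivative of f at x, step h.
-- (approxDeriv is an illustrative name, not from the original post.)
approxDeriv :: Double -> (Double -> Double) -> Double -> Double
approxDeriv h f x = (f (x + h) - f x) / h

-- Absolute error against the exact derivative of sin at 1, namely cos 1.
errorAt :: Double -> Double
errorAt h = abs (approxDeriv h sin 1 - cos 1)

main :: IO ()
main = mapM_ report [1e-1, 1e-4, 1e-8, 1e-12, 1e-15]
  where
    -- Print each step size alongside its approximation error.
    report h = putStrLn ("h = " ++ show h ++ ", error = " ++ show (errorAt h))
```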
Symbolic differentiation
A second method is symbolic differentiation. Instead of using the definition of deriv directly, the symbolic method uses a collection of rules, such as those below:
deriv (u + v) ≡ deriv u + deriv v
deriv (u * v) ≡ deriv v * u + deriv u * v
deriv (- u) ≡ - deriv u
deriv (exp u) ≡ deriv u * exp u
deriv (log u) ≡ deriv u / u
deriv (sqrt u) ≡ deriv u / (2 * sqrt u)
deriv (sin u) ≡ deriv u * cos u
deriv (cos u) ≡ deriv u * (- sin u)
deriv (asin u) ≡ deriv u/(sqrt (1 - u^2))
deriv (acos u) ≡ - deriv u/(sqrt (1 - u^2))
deriv (atan u) ≡ deriv u / (u^2 + 1)
deriv (sinh u) ≡ deriv u * cosh u
deriv (cosh u) ≡ deriv u * sinh u
deriv (asinh u) ≡ deriv u / (sqrt (u^2 + 1))
deriv (acosh u) ≡ deriv u / (sqrt (u^2 - 1))
deriv (atanh u) ≡ deriv u / (1 - u^2)
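For comparison, here is a minimal sketch of the symbolic approach over a tiny, hypothetical expression type in one variable X. Expr, diffE and evalE are illustrative names, not from any library, and only a few of the rules above are covered. Each diffE case transcribes one rule. Notice how the product rule copies u and v into the result: that duplication is one source of the expression growth and redundant evaluation often attributed to symbolic methods.

```haskell
-- A tiny expression language in one variable X (hypothetical, for illustration).
data Expr = X | Const Double
          | Add Expr Expr | Mul Expr Expr | Neg Expr
          | Sin Expr | Cos Expr | Exp Expr
  deriving Show

-- Symbolic derivative with respect to X; each case mirrors one rule above.
diffE :: Expr -> Expr
diffE X         = Const 1
diffE (Const _) = Const 0
diffE (Add u v) = Add (diffE u) (diffE v)
diffE (Mul u v) = Add (Mul (diffE v) u) (Mul (diffE u) v)  -- u and v duplicated
diffE (Neg u)   = Neg (diffE u)
diffE (Sin u)   = Mul (diffE u) (Cos u)
diffE (Cos u)   = Mul (diffE u) (Neg (Sin u))
diffE (Exp u)   = Mul (diffE u) (Exp u)

-- Evaluate an expression at a given value of X.
evalE :: Double -> Expr -> Double
evalE x X         = x
evalE _ (Const c) = c
evalE x (Add u v) = evalE x u + evalE x v
evalE x (Mul u v) = evalE x u * evalE x v
evalE x (Neg u)   = negate (evalE x u)
evalE x (Sin u)   = sin (evalE x u)
evalE x (Cos u)   = cos (evalE x u)
evalE x (Exp u)   = exp (evalE x u)

main :: IO ()
main = do
  let e = Mul X (Sin X)        -- x * sin x
  print (diffE e)              -- the derivative as a new expression tree
  print (evalE 2 (diffE e))    -- mathematically, sin 2 + 2 * cos 2
```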
There are two main drawbacks to the symbolic approach to differentiation.
- As a symbolic method, it requires access to and transformation of source code, and it places restrictions on that source code.
- Implementations tend to be quite expensive and in particular perform redundant computation. (I wonder if this latter criticism is a straw man argument. Are symbolic methods necessarily expensive or just when implemented naïvely? For instance, can simply memoized symbolic differentiation be nearly as cheap as AD?)
What is AD and how does it work?
A third method is the topic of this post, namely automatic differentiation (also called "algorithmic differentiation"), or "AD". The idea of AD is to simultaneously manipulate values and derivatives. Overloading of the standard numerical operations (and literals) makes this combined manipulation as convenient and elegant as manipulating values without derivatives.
The implementation of AD can be quite simple, as shown below:
data D a = D a a deriving (Eq,Show)
instance Num a ⇒ Num (D a) where
  D x x' + D y y' = D (x+y) (x'+y')
  D x x' * D y y' = D (x*y) (y'*x + x'*y)
  fromInteger x   = D (fromInteger x) 0
  negate (D x x') = D (negate x) (negate x')
  signum (D x _ ) = D (signum x) 0
  abs (D x x')    = D (abs x) (x' * signum x)
instance Fractional x ⇒ Fractional (D x) where
  fromRational x = D (fromRational x) 0
  recip (D x x') = D (recip x) (- x' / sqr x)
sqr ∷ Num a ⇒ a → a
sqr x = x * x
instance Floating x ⇒ Floating (D x) where
  pi             = D pi 0
  exp  (D x x')  = D (exp x) (x' * exp x)
  log  (D x x')  = D (log x) (x' / x)
  sqrt (D x x')  = D (sqrt x) (x' / (2 * sqrt x))
  sin  (D x x')  = D (sin x) (x' * cos x)
  cos  (D x x')  = D (cos x) (x' * (- sin x))
  asin (D x x')  = D (asin x) (x' / sqrt (1 - sqr x))
  acos (D x x')  = D (acos x) (x' / (- sqrt (1 - sqr x)))
  -- ⋯
As an example, define
f1 ∷ Floating a ⇒ a → a
f1 z = sqrt (3 * sin z)
and try it out in GHCi:
*Main> f1 (D 2 1)
D 1.6516332160855343 (-0.3779412091869595)
To test correctness, here is a symbolically differentiated version:
f2 ∷ Floating a ⇒ a → D a
f2 x = D (f1 x) (3 * cos x / (2 * sqrt (3 * sin x)))
Try it out:
*Main> f2 2
D 1.6516332160855343 (-0.3779412091869595)
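The fragments above can be assembled into a self-contained module. The sketch below fills in the elided Floating methods using a small chain-rule helper, lift1, and adds a helper derivAt that seeds the derivative component with 1 and reads it back out. Both names are my own additions, not part of the original code.

```haskell
-- Dual numbers: a value paired with its derivative.
data D a = D a a deriving (Eq, Show)

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (y' * x + x' * y)
  fromInteger n   = D (fromInteger n) 0
  negate (D x x') = D (negate x) (negate x')
  signum (D x _)  = D (signum x) 0
  abs (D x x')    = D (abs x) (x' * signum x)

instance Fractional a => Fractional (D a) where
  fromRational r = D (fromRational r) 0
  recip (D x x') = D (recip x) (negate x' / (x * x))

-- Chain rule: lift a scalar function f given its derivative f' (hypothetical helper).
lift1 :: Num a => (a -> a) -> (a -> a) -> D a -> D a
lift1 f f' (D x x') = D (f x) (x' * f' x)

instance Floating a => Floating (D a) where
  pi    = D pi 0
  exp   = lift1 exp exp
  log   = lift1 log recip
  sqrt  = lift1 sqrt (\x -> recip (2 * sqrt x))
  sin   = lift1 sin cos
  cos   = lift1 cos (negate . sin)
  asin  = lift1 asin (\x -> recip (sqrt (1 - x * x)))
  acos  = lift1 acos (\x -> negate (recip (sqrt (1 - x * x))))
  atan  = lift1 atan (\x -> recip (1 + x * x))
  sinh  = lift1 sinh cosh
  cosh  = lift1 cosh sinh
  asinh = lift1 asinh (\x -> recip (sqrt (x * x + 1)))
  acosh = lift1 acosh (\x -> recip (sqrt (x * x - 1)))
  atanh = lift1 atanh (\x -> recip (1 - x * x))

-- Derivative of f at x: seed the derivative component with 1, then project it out.
derivAt :: Num a => (D a -> D a) -> a -> a
derivAt f x = let D _ x' = f (D x 1) in x'

f1 :: Floating a => a -> a
f1 z = sqrt (3 * sin z)

main :: IO ()
main = do
  print (f1 (D 2 1 :: D Double))                        -- value and derivative together
  print (derivAt f1 (2 :: Double))                      -- derivative only
  print (3 * cos 2 / (2 * sqrt (3 * sin 2)) :: Double)  -- hand-derived derivative
```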
The instances can also be made prettier, as in Beautiful differentiation. Add an operator that captures the chain rule, which is behind the differentiation laws listed above.
infix 0 >-<
(>-<) ∷ Num a ⇒ (a → a) → (a → a) → (D a → D a)
(f >-< f') (D a a') = D (f a) (a' * f' a)
Then, e.g.,
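instance Floating a ⇒ Floating (D a) where
  -- The derivative arguments (recip (2 * sqrt), - sin, ...) are functions, so
  -- these definitions assume pointwise numeric instances on functions,
  -- which the standard Prelude does not provide.
  pi   = D pi 0
  exp  = exp  >-< exp
  log  = log  >-< recip
  sqrt = sqrt >-< recip (2 * sqrt)
  sin  = sin  >-< cos
  cos  = cos  >-< - sin
  asin = asin >-< recip (sqrt (1-sqr))
  acos = acos >-< recip (- sqrt (1-sqr))
  -- ⋯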
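The chain-rule style relies on arithmetic on functions: recip (2 * sqrt) and - sin denote functions, not numbers. Since the standard Prelude has no such instances, here is a minimal sketch of pointwise Num and Fractional instances for functions, just enough to compile several of the definitions above. To keep the block short, I write them as standalone functions (expD, sinD and so on, hypothetical names) rather than a full Floating instance. Covering asin and acos as well would also need a pointwise Floating instance on functions, because there sqrt is applied to the function 1-sqr.

```haskell
-- Pointwise arithmetic on functions (an assumption; not in the standard Prelude).
instance Num b => Num (a -> b) where
  f + g         = \x -> f x + g x
  f * g         = \x -> f x * g x
  negate f      = negate . f
  abs f         = abs . f
  signum f      = signum . f
  fromInteger n = const (fromInteger n)   -- numeric literals become constant functions

instance Fractional b => Fractional (a -> b) where
  recip f        = recip . f
  fromRational r = const (fromRational r)

data D a = D a a deriving (Eq, Show)

-- The chain-rule operator from above.
infix 0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(f >-< f') (D a a') = D (f a) (a' * f' a)

-- Standalone chain-rule definitions; the second arguments use function arithmetic.
expD, logD, sqrtD, sinD, cosD :: Floating a => D a -> D a
expD  = exp  >-< exp
logD  = log  >-< recip
sqrtD = sqrt >-< recip (2 * sqrt)
sinD  = sin  >-< cos
cosD  = cos  >-< negate sin

main :: IO ()
main = do
  print (sqrtD (D 4 1 :: D Double))   -- D 2.0 0.25
  print (sinD  (D 0 1 :: D Double))   -- D 0.0 1.0
```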
This AD implementation satisfies most of our criteria very well:
- It is simple to implement and verify. Both the implementation and its correctness follow directly from the familiar laws given above.
- It is convenient to use, as shown with f1 above.
- It is accurate, as shown above, producing exactly the same result as the symbolically differentiated code (f2).
- It is efficient, involving no iteration or redundant computation.
The formulation above does less well with generality:
- It computes only first derivatives.
- It applies (correctly) only to functions over a scalar (one-dimensional) domain, excluding even complex numbers.
Both of these limitations are removed in the post Higher-dimensional, higher-order derivatives, functionally.
What is AD, really?
How do we know whether this AD implementation is correct? We can't begin to address this question until we first answer a more fundamental one: what does its correctness mean?
A model for AD
I'm pretty sure AD has something to do with calculating a function's values and derivative values simultaneously, so I'll start there.
withD ∷ ⋯ ⇒ (a → a) → (a → D a)
withD f x = D (f x) (deriv f x)
Or, in point-free form,
withD f = liftA2 D f (deriv f)
Since, on functions,
liftA2 h f g = λ x → h (f x) (g x)
We don't have an implementation of deriv, so this definition of withD
will serve as a specification, not an implementation.
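Even without a real deriv, the specification can be exercised numerically. The sketch below substitutes a central-difference approximation for deriv. This is an assumption made only for testing, and derivApprox, withDApprox, close and the sample polynomial u are illustrative names. It checks that the AD result u (D x 1) agrees with withD u x at a few points, and shows liftA2 at work on functions.

```haskell
import Control.Applicative (liftA2)

-- Dual numbers with just enough of Num for polynomials.
data D a = D a a deriving (Eq, Show)

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (y' * x + x' * y)
  fromInteger n   = D (fromInteger n) 0
  negate (D x x') = D (negate x) (negate x')
  signum (D x _)  = D (signum x) 0
  abs (D x x')    = D (abs x) (x' * signum x)

-- Stand-in for deriv: central differences (approximate, for testing only).
derivApprox :: (Double -> Double) -> Double -> Double
derivApprox f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-5

-- The specification withD f = liftA2 D f (deriv f), with deriv approximated.
-- On functions, liftA2 D f g = \x -> D (f x) (g x).
withDApprox :: (Double -> Double) -> Double -> D Double
withDApprox f = liftA2 D f (derivApprox f)

-- A sample polynomial, usable at Double and at D Double.
u :: Num a => a -> a
u x = x * x * x - 2 * x + 1

-- Agreement up to a tolerance that absorbs the finite-difference error.
close :: D Double -> D Double -> Bool
close (D a a') (D b b') = abs (a - b) < 1e-9 && abs (a' - b') < 1e-6

main :: IO ()
main = do
  print (u (D 1.5 1 :: D Double))  -- AD: D 1.375 4.75
  print (withDApprox u 1.5)        -- specification, approximately
  print (all (\x -> close (u (D x 1)) (withDApprox u x)) [-2, -0.5, 0, 1, 3])
```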
If AD is structured as type class instances, then I'd want there to be a compelling interpretation function that is faithful to each of those classes, as in the principle of type class morphisms, which is to say that the interpretation of each method corresponds to the same method for the interpretation.
For AD, the interpretation function is withD. It's turned around this time (mapping to instead of from our type), as is sometimes the case. The Num, Fractional, and Floating morphisms provide the specifications of the instances:
withD (u + v) ≡ withD u + withD v
withD (u * v) ≡ withD u * withD v
withD (sin u) ≡ sin (withD u)
⋯
Note here that the methods on the left are on a → a, and on the right are on a → D a.
These (morphism) properties exactly define correctness of any implementation of AD, answering my first question:
What does it mean, independently of implementation?
Deriving an AD implementation
Now that we have a simple, formal specification of AD (numeric type class morphisms), we can try to prove that the implementation above satisfies the specification. Better yet, let's do the reverse, and use the morphism properties to discover the implementation, and prove it correct in the process.
Addition
Here is the addition specification:
withD (u + v) ≡ withD u + withD v
Start with the left-hand side:
withD (u + v)
≡ {- def of withD -}
liftA2 D (u + v) (deriv (u + v))
≡ {- deriv rule for (+) -}
liftA2 D (u + v) (deriv u + deriv v)
≡ {- liftA2 on functions -}
λ x → D ((u + v) x) ((deriv u + deriv v) x)
≡ {- (+) on functions -}
λ x → D (u x + v x) (deriv u x + deriv v x)
Then start over with the right-hand side:
withD u + withD v
≡ {- (+) on functions -}
λ x → withD u x + withD v x
≡ {- def of withD -}
λ x → D (u x) (deriv u x) + D (v x) (deriv v x)
We need a definition of (+) on D that makes these two final forms equal, i.e.,
λ x → D (u x + v x) (deriv u x + deriv v x)
≡
λ x → D (u x) (deriv u x) + D (v x) (deriv v x)
An easy choice is
D a a' + D b b' = D (a + b) (a' + b')
This definition provides the missing link and that completes the proof that
withD (u + v) ≡ withD u + withD v
Multiplication
The specification:
withD (u * v) ≡ withD u * withD v
Reason similarly to the addition case. Begin with the left hand side:
withD (u * v)
≡ {- def of withD -}
liftA2 D (u * v) (deriv (u * v))
≡ {- deriv rule for (*) -}
liftA2 D (u * v) (deriv u * v + deriv v * u)
≡ {- liftA2 on functions -}
λ x → D ((u * v) x) ((deriv u * v + deriv v * u) x)
≡ {- (*) and (+) on functions -}
λ x → D (u x * v x) (deriv u x * v x + deriv v x * u x)
Then start over with the right-hand side:
withD u * withD v
≡ {- (*) on functions -}
λ x → withD u x * withD v x
≡ {- def of withD -}
λ x → D (u x) (deriv u x) * D (v x) (deriv v x)
Sufficient definition:
D a a' * D b b' = D (a * b) (a' * b + b' * a)
Sine
Specification:
withD (sin u) ≡ sin (withD u)
Begin with the left hand side:
withD (sin u)
≡ {- def of withD -}
liftA2 D (sin u) (deriv (sin u))
≡ {- deriv rule for sin -}
liftA2 D (sin u) (deriv u * cos u)
≡ {- liftA2 on functions -}
λ x → D ((sin u) x) ((deriv u * cos u) x)
≡ {- sin, (*) and cos on functions -}
λ x → D (sin (u x)) (deriv u x * cos (u x))
Then start over with the right-hand side:
sin (withD u)
≡ {- sin on functions -}
λ x → sin (withD u x)
≡ {- def of withD -}
λ x → sin (D (u x) (deriv u x))
Sufficient definition:
sin (D a a') = D (sin a) (a' * cos a)
Or, using the chain rule operator,
sin = sin >-< cos
The whole implementation can be derived in exactly this style, answering my second question:
How do the implementation and its correctness flow gracefully from that meaning?
Higher-order derivatives
Given answers to the first two questions, let's turn to the third:
Where else might we go, guided by answers to the first two questions?
Jerzy Karczmarczuk extended the D representation above to an infinite "lazy tower of derivatives", in the paper Functional Differentiation of Computer Programs.
data D a = D a (D a)
The withD function easily adapts to this new D type:
withD ∷ ⋯ ⇒ (a → a) → (a → D a)
withD f x = D (f x) (withD (deriv f) x)
or
withD f = liftA2 D f (withD (deriv f))
These definitions were not brilliant insights. I looked for the simplest, type-correct possibility (without using ⊥).
Similarly, I'll try tweaking the previous derivations and see what pops out.
Addition
Left-hand side:
withD (u + v)
≡ {- def of withD -}
liftA2 D (u + v) (withD (deriv (u + v)))
≡ {- deriv rule for (+) -}
liftA2 D (u + v) (withD (deriv u + deriv v))
≡ {- (fixed-point) induction for withD/(+) -}
liftA2 D (u + v) (withD (deriv u) + withD (deriv v))
≡ {- def of liftA2 and (+) on functions -}
λ x → D (u x + v x) (withD (deriv u) x + withD (deriv v) x)
Right-hand side:
withD u + withD v
≡ {- (+) on functions -}
λ x → withD u x + withD v x
≡ {- def of withD -}
λ x → D (u x) (withD (deriv u) x) + D (v x) (withD (deriv v) x)
Again, we need a definition of (+) on D that makes the LHS and RHS final forms equal, i.e.,
λ x → D (u x + v x) (withD (deriv u) x + withD (deriv v) x)
≡
λ x → D (u x) (withD (deriv u) x) + D (v x) (withD (deriv v) x)
Again, an easy choice is
D a a' + D b b' = D (a + b) (a' + b')
Multiplication
Left-hand side:
withD (u * v)
≡ {- def of withD -}
liftA2 D (u * v) (withD (deriv (u * v)))
≡ {- deriv rule for (*) -}
liftA2 D (u * v) (withD (deriv u * v + deriv v * u))
≡ {- induction for withD/(+) -}
liftA2 D (u * v) (withD (deriv u * v) + withD (deriv v * u))
≡ {- induction for withD/(*) -}
liftA2 D (u * v) (withD (deriv u) * withD v + withD (deriv v) * withD u)
≡ {- liftA2, (*), (+) on functions -}
λ x → D (u x * v x) (withD (deriv u) x * withD v x + withD (deriv v) x * withD u x)
Right-hand side:
withD u * withD v
≡ {- def of withD -}
liftA2 D u (withD (deriv u)) * liftA2 D v (withD (deriv v))
≡ {- liftA2 and (*) on functions -}
λ x → D (u x) (withD (deriv u) x) * D (v x) (withD (deriv v) x)
A sufficient definition:
a@(D a0 a') * b@(D b0 b') = D (a0 * b0) (a' * b + b' * a)
Because
withD u x ≡ D (u x) (withD (deriv u) x)
withD v x ≡ D (v x) (withD (deriv v) x)
Sine
Left-hand side:
withD (sin u)
≡ {- def of withD -}
liftA2 D (sin u) (withD (deriv (sin u)))
≡ {- deriv rule for sin -}
liftA2 D (sin u) (withD (deriv u * cos u))
≡ {- induction for withD/(*) -}
liftA2 D (sin u) (withD (deriv u) * withD (cos u))
≡ {- induction for withD/cos -}
liftA2 D (sin u) (withD (deriv u) * cos (withD u))
≡ {- liftA2, sin, cos and (*) on functions -}
λ x → D (sin (u x)) (withD (deriv u) x * cos (withD u x))
Right-hand side:
sin (withD u)
≡ {- def of withD -}
sin (liftA2 D u (withD (deriv u)))
≡ {- liftA2 and sin on functions -}
λ x → sin (D (u x) (withD (deriv u) x))
To make the LHS and RHS final forms equal, define
sin a@(D a0 a') = D (sin a0) (a' * cos a)
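Collecting the derived definitions gives a runnable lazy tower. The sketch below is deliberately partial: it defines Num for the tower plus standalone sinD and cosD (hypothetical names; sinD uses the rule just derived and cosD its analogue), along with helpers idD, constD and derivs that I added for seeding and inspecting towers.

```haskell
-- An infinite, lazy tower: a value followed by the tower of its derivative.
data D a = D a (D a)

-- The first n entries: value, first derivative, second derivative, ...
derivs :: Int -> D a -> [a]
derivs n (D a a')
  | n <= 0    = []
  | otherwise = a : derivs (n - 1) a'

-- A constant: every derivative is zero (0 here is itself a lazy tower).
constD :: Num a => a -> D a
constD x = D x 0

-- The identity function's tower at x: value x, derivative 1, then zeros.
idD :: Num a => a -> D a
idD x = D x 1

instance Num a => Num (D a) where
  fromInteger n             = constD (fromInteger n)
  D a a' + D b b'           = D (a + b) (a' + b')
  a@(D a0 a') * b@(D b0 b') = D (a0 * b0) (a' * b + b' * a)
  negate (D a a')           = D (negate a) (negate a')
  signum (D a _)            = constD (signum a)
  abs a@(D a0 a')           = D (abs a0) (a' * signum a)

-- sin as derived above; cos by the analogous rule.
sinD, cosD :: Floating a => D a -> D a
sinD a@(D a0 a') = D (sin a0) (a' * cosD a)
cosD a@(D a0 a') = D (cos a0) (negate (a' * sinD a))

main :: IO ()
main = do
  let x = idD (2 :: Double)
  print (derivs 5 (x * x * x))                  -- x^3 at 2: [8, 12, 12, 6, 0]
  print (derivs 4 (sinD (idD (0 :: Double))))   -- sin 0, cos 0, -sin 0, -cos 0
```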
Higher-dimensional derivatives
I'll save non-scalar ("multi-variate") differentiation for another time. In addition to the considerations above, the key ideas are in Higher-dimensional, higher-order derivatives, functionally and Simpler, more efficient, functional linear maps.
Raoul Duke:
i read the wikipedia entry
http://en.wikipedia.org/wiki/Automatic_differentiation
and found it not bad. in particular i thought the following really explained AD well vs. the other 2 options:
“AD exploits the fact that any computer program that implements a vector function y = F(x) (generally) can be decomposed into a sequence of elementary assignments, any one of which may be trivially differentiated by a simple table lookup. These elemental partial derivatives, evaluated at a particular argument, are combined in accordance with the chain rule from derivative calculus to form some derivative information for F (such as gradients, tangents, the Jacobian matrix, etc.). This process yields exact (to numerical accuracy) derivatives. Because the symbolic transformation occurs only at the most basic level, AD avoids the computational problems inherent in complex symbolic computation.”
28 January 2009, 4:11 pm
Simon PJ:
Conal, do you know Barak Pearlmutter? http://www-bcl.cs.nuim.ie/~barak/. He’s an expert on AD, and knows Haskell too. Well worth talking to.
My hind-brain memory is that he’s identified something to do with AD that you can nearly-but-not-quite Do Right in Haskell. I never quite got my brain around what that Thing was, but I’m sure he could tell you. It may be that some implementation hack would let you do it Right, but I’m not quite sure what the hack is. It’s cool stuff anyway.
Simon
2 February 2009, 3:39 am
conal:
Simon, Thanks for the tip. Yes, I know Barak (CMU classmates). The issue he raises is “nested AD”: derivatives of functions built out of derivatives, and the danger of confusing one infinitesimal (perturbation) with another. I’ve read his Nesting forward-mode AD in a functional framework a couple of times and haven’t yet gotten my head around whether the problem is inherent in AD or in a way of attacking it. A related post is Differentiating regions by Chung-chieh Shan.
2 February 2009, 11:29 am
Conal Elliott » Blog Archive » From the chain rule to automatic differentiation:
[…] What is automatic differentiation, and why does it work?, I gave a semantic model that explains what automatic differentiation (AD) accomplishes. Correct […]
8 February 2009, 2:14 pm
In Heaven Now Are Three « Changing the world:
[…] Short summary: I rant a little about how I found out about Automatic Differentiation from Conal Elliott and then from wikipedia. After this I speak a little about how this tool was useful once in my […]
9 February 2009, 11:30 am