Deriving parallel tree scans

The post Deriving list scans explored folds and scans on lists and showed how the usual, efficient scan implementations can be derived from simpler specifications.

Let’s see now how to apply the same techniques to scans over trees.

This new post is one of a series leading toward algorithms optimized for execution on massively parallel, consumer hardware, using CUDA or OpenCL.

Edits:

  • 2011-03-01: Added clarification about "∅" and "(⊕)".
  • 2011-03-23: corrected "linear-time" to "linear-work" in two places.

Trees

Our trees will be non-empty and binary:

data T a = Leaf a | Branch (T a) (T a)

instance Show a ⇒ Show (T a) where
  show (Leaf a) = show a
  show (Branch s t) = "("++show s++","++show t++")"

Nothing surprising in the instances:

instance Functor T where
  fmap f (Leaf a) = Leaf (f a)
  fmap f (Branch s t) = Branch (fmap f s) (fmap f t)

instance Foldable T where
  fold (Leaf a) = a
  fold (Branch s t) = fold s ⊕ fold t
  foldMap f = fold ∘ fmap f   -- fold alone is not a complete Foldable instance

instance Traversable T where
  sequenceA (Leaf a) = fmap Leaf a
  sequenceA (Branch s t) =
    liftA2 Branch (sequenceA s) (sequenceA t)

BTW, my type-setting software uses "∅" and "(⊕)" for Haskell’s "mempty" and "mappend".

Also handy will be extracting the first and last (i.e., leftmost and rightmost) leaves in a tree:

headT ∷ T a → a
headT (Leaf a) = a
headT (s `Branch` _) = headT s

lastT ∷ T a → a
lastT (Leaf a) = a
lastT (_ `Branch` t) = lastT t

Exercise: Prove that

headT ∘ fmap f ≡ f ∘ headT
lastT ∘ fmap f ≡ f ∘ lastT

Answer:

Consider the Leaf and Branch cases separately:

   headT (fmap f (Leaf a))
≡ {- fmap on T -}
   headT (Leaf (f a))
≡ {- headT def -}
   f a
≡ {- headT def -}
   f (headT (Leaf a))

   headT (fmap f (Branch s t))
≡ {- fmap on T -}
   headT (Branch (fmap f s) (fmap f t))
≡ {- headT def -}
   headT (fmap f s)
≡ {- induction -}
   f (headT s)
≡ {- headT def -}
   f (headT (Branch s t))

Similarly for lastT.

From lists to trees and back

We can flatten trees into lists:

flatten ∷ T a → [a]
flatten = fold ∘ fmap (:[])

Equivalently, using foldMap:

flatten = foldMap (:[])

Alternatively, we could define fold via flatten:

instance Foldable T where fold = fold ∘ flatten

flatten ∷ T a → [a]
flatten (Leaf a) = [a]
flatten (Branch s t) = flatten s ++ flatten t

We can also "unflatten" lists into balanced trees:

unflatten ∷ [a] → T a
unflatten [] = error "unflatten: Oops! Empty list"
unflatten [a] = Leaf a
unflatten xs = Branch (unflatten prefix) (unflatten suffix)
  where
    (prefix,suffix) = splitAt (length xs `div` 2) xs

Both flatten and unflatten can be implemented more efficiently.
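One way to do so, sketched here in plain Haskell under the assumption that only base is available: flatten with an accumulating parameter instead of (++), and unflatten by computing the length once and threading the unconsumed suffix through the recursion. The names flattenAcc, unflattenN and unflatten' are hypothetical. The resulting tree shapes match the post's unflatten, since the left subtree again receives length xs `div` 2 elements.

```haskell
-- Tree type as in the post (deriving Eq/Show only for convenience).
data T a = Leaf a | Branch (T a) (T a)
  deriving (Eq, Show)

-- Flatten by prepending leaves onto an accumulator; avoids repeated (++).
flattenAcc :: T a -> [a]
flattenAcc t = go t []
  where
    go (Leaf a)     rest = a : rest
    go (Branch s u) rest = go s (go u rest)

-- Build a balanced tree from the first n elements of xs and
-- return it together with the elements that were not consumed.
unflattenN :: Int -> [a] -> (T a, [a])
unflattenN n xs
  | n <= 0 = error "unflattenN: need at least one element"
  | n == 1 =
      case xs of
        (a : rest) -> (Leaf a, rest)
        []         -> error "unflattenN: list shorter than n"
  | otherwise = (Branch s u, rest2)
  where
    k          = n `div` 2              -- same split point as the post's unflatten
    (s, rest1) = unflattenN k xs
    (u, rest2) = unflattenN (n - k) rest1

-- Compute the length once, then build the tree in a single pass.
unflatten' :: [a] -> T a
unflatten' [] = error "unflatten': empty list"
unflatten' xs = fst (unflattenN (length xs) xs)
```

Both functions now do work proportional to the number of leaves, whereas the (++)-based flatten and the length-per-level unflatten do O(n log n) work on balanced trees.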

For instance,

t1, t2 ∷ T Int
t1 = unflatten [1..3]
t2 = unflatten [1..16]

*T> t1
(1,(2,3))
*T> t2
((((1,2),(3,4)),((5,6),(7,8))),(((9,10),(11,12)),((13,14),(15,16))))

Specifying tree scans

Prefixes and suffixes

The post Deriving list scans gave specifications for list scanning in terms of inits and tails. One consequence of this specification is that the output of scanning has one more element than the input. Alternatively, we could use non-empty variants of inits and tails, so that the input & output are in one-to-one correspondence.

inits' ∷ [a] → [[a]]
inits' [] = []
inits' (x:xs) = map (x:) ([] : inits' xs)

The cons case can also be written as

inits' (x:xs) = [x] : map (x:) (inits' xs)

tails' ∷ [a] → [[a]]
tails' [] = []
tails' xs@(_:xs') = xs : tails' xs'

For instance,

*T> inits' "abcd"
["a","ab","abc","abcd"]
*T> tails' "abcd"
["abcd","bcd","cd","d"]
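These non-empty variants are closely related to the standard Data.List functions: inits' drops the empty first prefix that inits produces, and tails' drops the empty last suffix that tails produces. A small sketch (initsViaList and tailsViaList are illustrative names, not from the post):

```haskell
import Data.List (inits, tails)

-- Non-empty prefixes and suffixes, as defined in the post.
inits' :: [a] -> [[a]]
inits' []     = []
inits' (x:xs) = [x] : map (x:) (inits' xs)

tails' :: [a] -> [[a]]
tails' []         = []
tails' xs@(_:xs') = xs : tails' xs'

-- Same results via Data.List: inits always starts with [],
-- and tails always ends with [], so tail/init are safe here.
initsViaList, tailsViaList :: [a] -> [[a]]
initsViaList = tail . inits
tailsViaList = init . tails
```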

Our tree functor has a symmetric definition, so we get more symmetry in the counterparts to inits' and tails':

initTs ∷ T a → T (T a)
initTs (Leaf a) = Leaf (Leaf a)
initTs (s `Branch` t) =
  Branch (initTs s) (fmap (s `Branch`) (initTs t))

tailTs ∷ T a → T (T a)
tailTs (Leaf a) = Leaf (Leaf a)
tailTs (s `Branch` t) =
  Branch (fmap (`Branch` t) (tailTs s)) (tailTs t)

Try it:

*T> t1
(1,(2,3))
*T> initTs t1
(1,((1,2),(1,(2,3))))
*T> tailTs t1
((1,(2,3)),((2,3),3))

*T> unflatten [1..5]
((1,2),(3,(4,5)))
*T> initTs (unflatten [1..5])
((1,(1,2)),(((1,2),3),(((1,2),(3,4)),((1,2),(3,(4,5))))))
*T> tailTs (unflatten [1..5])
((((1,2),(3,(4,5))),(2,(3,(4,5)))),((3,(4,5)),((4,5),5)))

Exercise: Prove that

lastT ∘ initTs ≡ id
headT ∘ tailTs ≡ id

Answer:

   lastT (initTs (Leaf a))
≡ {- initTs def -}
   lastT (Leaf (Leaf a))
≡ {- lastT def -}
   Leaf a

   lastT (initTs (s `Branch` t))
≡ {- initTs def -}
   lastT (Branch (⋯) (fmap (s `Branch`) (initTs t)))
≡ {- lastT def -}
   lastT (fmap (s `Branch`) (initTs t))
≡ {- lastT ∘ fmap f ≡ f ∘ lastT -}
   (s `Branch`) (lastT (initTs t))
≡ {- trivial -}
   s `Branch` lastT (initTs t)
≡ {- induction -}
   s `Branch` t

The proof of headT ∘ tailTs ≡ id is symmetric.

Scan specification

Now we can specify prefix & suffix scanning:

scanlT, scanrT ∷ Monoid a ⇒ T a → T a
scanlT = fmap fold ∘ initTs
scanrT = fmap fold ∘ tailTs

Try it out:

t3 ∷ T String
t3 = fmap (:[]) (unflatten "abcde")
*T> t3
(("a","b"),("c",("d","e")))
*T> scanlT t3
(("a","ab"),("abc",("abcd","abcde")))
*T> scanrT t3
(("abcde","bcde"),("cde",("de","e")))
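For experimenting outside the author's Unicode setup, here is an ASCII transcription of the definitions so far. It is a sketch assuming only base: (<>) stands for (⊕), mempty for ∅, (.) for (∘), and Foldable is given via foldMap so that the instance is complete. The tree-building and scan code follows the post's definitions.

```haskell
import Data.Foldable (fold)

data T a = Leaf a | Branch (T a) (T a)

-- Same pair notation as the post's Show instance.
instance Show a => Show (T a) where
  show (Leaf a)     = show a
  show (Branch s t) = "(" ++ show s ++ "," ++ show t ++ ")"

instance Functor T where
  fmap f (Leaf a)     = Leaf (f a)
  fmap f (Branch s t) = Branch (fmap f s) (fmap f t)

-- foldMap is part of Foldable's minimal definition; fold then comes for free.
instance Foldable T where
  foldMap f (Leaf a)     = f a
  foldMap f (Branch s t) = foldMap f s <> foldMap f t

-- Balanced tree from a non-empty list.
unflatten :: [a] -> T a
unflatten []  = error "unflatten: empty list"
unflatten [a] = Leaf a
unflatten xs  = Branch (unflatten prefix) (unflatten suffix)
  where (prefix, suffix) = splitAt (length xs `div` 2) xs

-- Each leaf is replaced by the non-empty prefix (initTs)
-- or suffix (tailTs) of the tree that ends/starts there.
initTs, tailTs :: T a -> T (T a)
initTs (Leaf a)     = Leaf (Leaf a)
initTs (Branch s t) = Branch (initTs s) (fmap (Branch s) (initTs t))
tailTs (Leaf a)     = Leaf (Leaf a)
tailTs (Branch s t) = Branch (fmap (`Branch` t) (tailTs s)) (tailTs t)

-- The specifications: fold every prefix / suffix independently.
scanlT, scanrT :: Monoid a => T a -> T a
scanlT = fmap fold . initTs
scanrT = fmap fold . tailTs
```

Loading this into GHCi and evaluating scanlT (fmap (:[]) (unflatten "abcde")) should reproduce the session above.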

To test on numbers, I’ll use a handy notation from Matt Hellige to add pre- and post-processing:

(↝) ∷ (a' → a) → (b → b') → ((a → b) → (a' → b'))
(f ↝ h) g = h ∘ g ∘ f

And a version specialized to functors:

(↝*) ∷ Functor f ⇒ (a' → a) → (b → b')
     → ((f a → f b) → (f a' → f b'))
f ↝* g = fmap f ↝ fmap g

t4 ∷ T Integer
t4 = unflatten [1..6]

t5 ∷ T Integer
t5 = (Sum ↝* getSum) scanlT t4

Try it:

*T> t4
((1,(2,3)),(4,(5,6)))
*T> initTs t4
((1,((1,2),(1,(2,3)))),(((1,(2,3)),4),(((1,(2,3)),(4,5)),((1,(2,3)),(4,(5,6))))))
*T> t5
((1,(3,6)),(10,(15,21)))
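Hellige's notation can be tried in plain Haskell with ASCII operator names. In this sketch, (~>) and (~>*) are hypothetical stand-ins for (↝) and (↝*), and the example applies the wrapper to the list function scanl1 rather than scanlT, so that no tree type is needed.

```haskell
import Data.Monoid (Sum (..))

-- Pre-process with f, post-process with h.
(~>) :: (a' -> a) -> (b -> b') -> ((a -> b) -> (a' -> b'))
(f ~> h) g = h . g . f

-- The same, lifted through a functor.
(~>*) :: Functor f => (a' -> a) -> (b -> b') -> ((f a -> f b) -> (f a' -> f b'))
f ~>* g = fmap f ~> fmap g

-- A monoidal scan on lists, used at the Sum monoid:
-- wrap each number in Sum, scan with (<>), unwrap again.
sumScan :: [Integer] -> [Integer]
sumScan = (Sum ~>* getSum) (scanl1 (<>))
```

With these definitions, sumScan [1..6] gives the same running sums as t5 above, namely [1,3,6,10,15,21].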

Exercise: Prove that we have properties similar to the ones relating fold, scanl, and scanr on lists:

fold ≡ lastT ∘ scanlT
fold ≡ headT ∘ scanrT

Answer:

   lastT ∘ scanlT
≡ {- scanlT spec -}
   lastT ∘ fmap fold ∘ initTs
≡ {- lastT ∘ fmap f -}
   fold ∘ lastT ∘ initTs
≡ {- lastT ∘ initTs -}
   fold

   headT ∘ scanrT
≡ {- scanrT def -}
   headT ∘ fmap fold ∘ tailTs
≡ {- headT ∘ fmap f -}
   fold ∘ headT ∘ tailTs
≡ {- headT ∘ tailTs -}
   fold

For instance,

*T> fold t3
"abcde"
*T> (lastT ∘ scanlT) t3
"abcde"
*T> (headT ∘ scanrT) t3
"abcde"

Deriving faster scans

Recall the specifications:

scanlT = fmap fold ∘ initTs
scanrT = fmap fold ∘ tailTs

To derive more efficient implementations, proceed as in Deriving list scans. Start with prefix scan (scanlT), and consider the Leaf and Branch cases separately.

   scanlT (Leaf a)
≡ {- scanlT spec -}
   fmap fold (initTs (Leaf a))
≡ {- initTs def -}
   fmap fold (Leaf (Leaf a))
≡ {- fmap def -}
   Leaf (fold (Leaf a))
≡ {- fold def -}
   Leaf a

   scanlT (s `Branch` t)
≡ {- scanlT spec -}
   fmap fold (initTs (s `Branch` t))
≡ {- initTs def -}
   fmap fold (Branch (initTs s) (fmap (s `Branch`) (initTs t)))
≡ {- fmap def -}
   Branch (fmap fold (initTs s)) (fmap fold (fmap (s `Branch`) (initTs t)))
≡ {- scanlT spec -}
   Branch (scanlT s) (fmap fold (fmap (s `Branch`) (initTs t)))
≡ {- functor law -}
   Branch (scanlT s) (fmap (fold ∘ (s `Branch`)) (initTs t))
≡ {- rework as λ -}
   Branch (scanlT s) (fmap (λ t' → fold (s `Branch` t')) (initTs t))
≡ {- fold def -}
   Branch (scanlT s) (fmap (λ t' → fold s ⊕ fold t') (initTs t))
≡ {- rework λ -}
   Branch (scanlT s) (fmap ((fold s ⊕) ∘ fold) (initTs t))
≡ {- functor law -}
   Branch (scanlT s) (fmap (fold s ⊕) (fmap fold (initTs t)))
≡ {- scanlT spec -}
   Branch (scanlT s) (fmap (fold s ⊕) (scanlT t))
≡ {- lastT ∘ scanlT ≡ fold -}
   Branch (scanlT s) (fmap (lastT (scanlT s) ⊕) (scanlT t))
≡ {- factor out defs -}
   Branch s' (fmap (lastT s' ⊕) t')
     where s' = scanlT s
           t' = scanlT t

Suffix scan has a similar derivation.

   scanrT (Leaf a)
≡ {- scanrT spec -}
   fmap fold (tailTs (Leaf a))
≡ {- tailTs def -}
   fmap fold (Leaf (Leaf a))
≡ {- fmap on T -}
   Leaf (fold (Leaf a))
≡ {- fold def -}
   Leaf a

   scanrT (s `Branch` t)
≡ {- scanrT spec -}
   fmap fold (tailTs (s `Branch` t))
≡ {- tailTs def -}
   fmap fold (Branch (fmap (`Branch` t) (tailTs s)) (tailTs t))
≡ {- fmap def -}
   Branch (fmap fold (fmap (`Branch` t) (tailTs s))) (fmap fold (tailTs t))
≡ {- scanrT spec -}
   Branch (fmap fold (fmap (`Branch` t) (tailTs s))) (scanrT t)
≡ {- functor law -}
   Branch (fmap (fold ∘ (`Branch` t)) (tailTs s)) (scanrT t)
≡ {- rework as λ -}
   Branch (fmap (λ s' → fold (s' `Branch` t)) (tailTs s)) (scanrT t)
≡ {- fold def -}
   Branch (fmap (λ s' → fold s' ⊕ fold t) (tailTs s)) (scanrT t)
≡ {- rework λ -}
   Branch (fmap ((⊕ fold t) ∘ fold) (tailTs s)) (scanrT t)
≡ {- functor law -}
   Branch (fmap (⊕ fold t) (fmap fold (tailTs s))) (scanrT t)
≡ {- scanrT spec -}
   Branch (fmap (⊕ fold t) (scanrT s)) (scanrT t)
≡ {- headT ∘ scanrT ≡ fold -}
   Branch (fmap (⊕ headT (scanrT t)) (scanrT s)) (scanrT t)
≡ {- factor out defs -}
   Branch (fmap (⊕ headT t') s') t'
     where s' = scanrT s
           t' = scanrT t

Extract code from these derivations:

scanlT' ∷ Monoid a ⇒ T a → T a
scanlT' (Leaf a) = Leaf a
scanlT' (s `Branch` t) =
  Branch s' (fmap (lastT s' ⊕) t')
  where s' = scanlT' s
        t' = scanlT' t

scanrT' ∷ Monoid a ⇒ T a → T a
scanrT' (Leaf a) = Leaf a
scanrT' (s `Branch` t) =
  Branch (fmap (⊕ headT t') s') t'
  where s' = scanrT' s
        t' = scanrT' t

Try it:

*T> t3
(("a","b"),("c",("d","e")))
*T> scanlT' t3
(("a","ab"),("abc",("abcd","abcde")))
*T> scanrT' t3
(("abcde","bcde"),("cde",("de","e")))
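The derived versions can also be checked against the specifications mechanically. The following self-contained sketch (an ASCII transcription, with (<>) for (⊕)) compares scanlT' and scanrT' with scanlT and scanrT on balanced trees of 1 to 20 leaves. It is a spot check under those assumptions, not a substitute for the derivation.

```haskell
import Data.Foldable (fold, toList)

data T a = Leaf a | Branch (T a) (T a)

instance Functor T where
  fmap f (Leaf a)     = Leaf (f a)
  fmap f (Branch s t) = Branch (fmap f s) (fmap f t)

instance Foldable T where
  foldMap f (Leaf a)     = f a
  foldMap f (Branch s t) = foldMap f s <> foldMap f t

unflatten :: [a] -> T a
unflatten []  = error "unflatten: empty list"
unflatten [a] = Leaf a
unflatten xs  = Branch (unflatten prefix) (unflatten suffix)
  where (prefix, suffix) = splitAt (length xs `div` 2) xs

headT, lastT :: T a -> a
headT (Leaf a)     = a
headT (Branch s _) = headT s
lastT (Leaf a)     = a
lastT (Branch _ t) = lastT t

initTs, tailTs :: T a -> T (T a)
initTs (Leaf a)     = Leaf (Leaf a)
initTs (Branch s t) = Branch (initTs s) (fmap (Branch s) (initTs t))
tailTs (Leaf a)     = Leaf (Leaf a)
tailTs (Branch s t) = Branch (fmap (`Branch` t) (tailTs s)) (tailTs t)

-- Specifications, kept for comparison.
scanlT, scanrT :: Monoid a => T a -> T a
scanlT = fmap fold . initTs
scanrT = fmap fold . tailTs

-- Derived divide-and-conquer scans: the two recursive calls are
-- independent (parallelizable in principle); one half is then adjusted.
scanlT', scanrT' :: Monoid a => T a -> T a
scanlT' (Leaf a)     = Leaf a
scanlT' (Branch s t) = Branch s' (fmap (lastT s' <>) t')
  where s' = scanlT' s
        t' = scanlT' t
scanrT' (Leaf a)     = Leaf a
scanrT' (Branch s t) = Branch (fmap (<> headT t') s') t'
  where s' = scanrT' s
        t' = scanrT' t
```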

Efficiency

Although I was just following my nose, without trying to get anywhere in particular, this result is exactly the algorithm I first thought of when considering how to parallelize tree scanning.

Let’s now consider the running time of this algorithm. Assume that the tree is balanced, to maximize parallelism. (I think balancing is optimal for parallelism here, but I’m not certain.)

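For a tree with n leaves, the work W(n) is constant when n = 1 and 2 W(n/2) + Θ(n) when n > 1: the two recursive scans each handle n/2 leaves, and fmap (lastT s' ⊕) then touches every leaf of the right half. Using the Master Theorem (explained more here), W(n) = Θ(n log n).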
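Unrolling the recurrence shows where the log factor comes from. This sketch assumes n is a power of two and writes the linear term as c n for some constant c, with c₀ the cost of a single leaf; each of the log₂ n levels of recursion contributes c n work.

```latex
\begin{aligned}
W(1) &= c_0, \qquad W(n) = 2\,W(n/2) + c\,n \quad (n = 2^k > 1)\\
W(n) &= 2^j\,W(n/2^j) + j\,c\,n \qquad (0 \le j \le k)\\
     &= n\,c_0 + c\,n\log_2 n \qquad (j = k = \log_2 n)\\
     &= \Theta(n \log n)
\end{aligned}
```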

This result is disappointing, since scanning can be done with linear work by threading a single accumulator while traversing the input tree and building up the output tree.
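For comparison, such a sequential linear-work scan can be written with mapAccumL and mapAccumR from Data.Traversable, which thread an accumulator through any Traversable. This sketch assumes a Traversable instance for T (written with traverse rather than the post's sequenceA); scanlSeq and scanrSeq are illustrative names. Each leaf costs one (<>), but the accumulator forces a strictly sequential order.

```haskell
import Data.Foldable (toList)
import Data.Traversable (mapAccumL, mapAccumR)

data T a = Leaf a | Branch (T a) (T a)

instance Functor T where
  fmap f (Leaf a)     = Leaf (f a)
  fmap f (Branch s t) = Branch (fmap f s) (fmap f t)

instance Foldable T where
  foldMap f (Leaf a)     = f a
  foldMap f (Branch s t) = foldMap f s <> foldMap f t

instance Traversable T where
  traverse f (Leaf a)     = Leaf <$> f a
  traverse f (Branch s t) = Branch <$> traverse f s <*> traverse f t

-- Prefix scan: visit leaves left to right, carrying the running total.
scanlSeq :: Monoid a => T a -> T a
scanlSeq = snd . mapAccumL step mempty
  where step acc x = let acc' = acc <> x in (acc', acc')

-- Suffix scan: visit leaves right to left; the new element goes on the left.
scanrSeq :: Monoid a => T a -> T a
scanrSeq = snd . mapAccumR step mempty
  where step acc x = let acc' = x <> acc in (acc', acc')
```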

I’m using the term "work" instead of "time" here, since I’m not assuming sequential execution.

We have a parallel algorithm that performs n log n work, and a sequential program that performs linear work. Can we construct a linear-work parallel algorithm?

Yes. Guy Blelloch came up with a clever linear-work parallel algorithm, which I’ll derive in another post.

Generalizing head and last

Can we replace the ad hoc (tree-specific) headT and lastT functions with general versions that work on all foldables? I’d want the generalization to also generalize the list functions head and last or, rather, to total variants (ones that cannot error due to empty list). For totality, provide a default value for when there are no elements.

headF, lastF ∷ Foldable f ⇒ a → f a → a

I also want these functions to be as efficient on lists as head and last and as efficient on trees as headT and lastT.

The First and Last monoids provide left-biased and right-biased choice. They’re implemented as newtype wrappers around Maybe:

newtype First a = First { getFirst ∷ Maybe a }

instance Monoid (First a) where
  ∅ = First Nothing
  r@(First (Just _)) ⊕ _ = r
  First Nothing ⊕ r = r

newtype Last a = Last { getLast ∷ Maybe a }

instance Monoid (Last a) where
  ∅ = Last Nothing
  _ ⊕ r@(Last (Just _)) = r
  r ⊕ Last Nothing = r

For headF, embed all of the elements into the First monoid (via First ∘ Just), fold over the result, and extract the result, using the provided default value in case there are no elements. Similarly for lastF.

headF dflt = fromMaybe dflt ∘ getFirst ∘ foldMap (First ∘ Just)
lastF dflt = fromMaybe dflt ∘ getLast ∘ foldMap (Last ∘ Just)

For instance,

*T> headF 3 [1,2,4,8]
1
*T> headF 3 []
3
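The same definitions work in plain Haskell with the First and Last newtypes that Data.Monoid already provides, which have the left- and right-biased behavior described above. This is a sketch with (.) and mempty in place of (∘) and ∅; the monoid-valued variants follow, with lastFM defined via lastF.

```haskell
import Data.Maybe (fromMaybe)
import Data.Monoid (First (..), Last (..))

-- Total head/last for any Foldable, given a default for "no elements".
headF, lastF :: Foldable f => a -> f a -> a
headF dflt = fromMaybe dflt . getFirst . foldMap (First . Just)
lastF dflt = fromMaybe dflt . getLast  . foldMap (Last  . Just)

-- For monoid-valued elements, mempty serves as the default.
headFM, lastFM :: (Foldable f, Monoid m) => f m -> m
headFM = headF mempty
lastFM = lastF mempty
```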

When our elements belong to a monoid, we can use ∅ as the default:

headFM ∷ (Foldable f, Monoid m) ⇒ f m → m
headFM = headF ∅

lastFM ∷ (Foldable f, Monoid m) ⇒ f m → m
lastFM = lastF ∅

For instance,

*T> lastFM ([] ∷ [String])
""

Using headFM and lastFM in place of headT and lastT, we can easily handle the addition of an Empty case to the tree functor in this post. The key choice is that fold Empty ≡ ∅ and fmap _ Empty ≡ Empty. Then headFM will choose the first leaf, and lastFM the last one.
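A sketch of that extension, assuming a hypothetical three-constructor tree type and ASCII names. Because Empty folds to mempty, the First Nothing and Last Nothing contributed by empty subtrees are skipped, so headFM and lastFM return the first and last leaves (or mempty for a tree with no leaves).

```haskell
import Data.Maybe (fromMaybe)
import Data.Monoid (First (..), Last (..))

-- Hypothetical tree type extended with an Empty case.
data T a = Empty | Leaf a | Branch (T a) (T a)

instance Functor T where
  fmap _ Empty        = Empty              -- fmap _ Empty ≡ Empty
  fmap f (Leaf a)     = Leaf (f a)
  fmap f (Branch s t) = Branch (fmap f s) (fmap f t)

instance Foldable T where
  foldMap _ Empty        = mempty          -- hence fold Empty ≡ mempty
  foldMap f (Leaf a)     = f a
  foldMap f (Branch s t) = foldMap f s <> foldMap f t

headF, lastF :: Foldable f => a -> f a -> a
headF dflt = fromMaybe dflt . getFirst . foldMap (First . Just)
lastF dflt = fromMaybe dflt . getLast  . foldMap (Last  . Just)

headFM, lastFM :: (Foldable f, Monoid m) => f m -> m
headFM = headF mempty
lastFM = lastF mempty
```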

What about efficiency? Because headF and lastF are defined via foldMap, which is a composition of fold and fmap, one might think that we have to traverse the entire structure when used with functors like [] or T.

Laziness saves us, however, and we can even extract the head of an infinite list or a partially defined one. For instance,

   foldMap (First ∘ Just) [5 ..]
≡ foldMap (First ∘ Just) (5 : [6 ..])
≡ First (Just 5) ⊕ foldMap (First ∘ Just) [6 ..]
≡ First (Just 5)

So

   headF d [5 ..]
≡ fromMaybe d (getFirst (foldMap (First ∘ Just) [5 ..]))
≡ fromMaybe d (getFirst (First (Just 5)))
≡ fromMaybe d (Just 5)
≡ 5

And, sure enough,

*T> foldMap (First ∘ Just) [5 ..]
First {getFirst = Just 5}
*T> headF ⊥ [5 ..]
5

Where to go from here?

  • As mentioned above, the derived scanning implementations perform asymptotically more work than necessary. Future posts explore how to derive parallel-friendly, linear-work algorithms. Then we’ll see how to transform the parallel-friendly algorithms so that they work destructively, overwriting their input as they go, and hence are suitable for execution entirely in CUDA or OpenCL.
  • The functions initTs and tailTs are still tree-specific. To generalize the specification and derivation of list and tree scanning, find a way to generalize these two functions. The types of initTs and tailTs fit with the duplicate method on comonads. Moreover, tails is the usual definition of duplicate on lists, and I think inits would be extend for "snoc lists". For trees, however, I don’t think the correspondence holds. Am I missing something?
  • In particular, I want to extend the derivation to depth-typed, perfectly balanced trees, of the sort I played with in A trie for length-typed vectors and From tries to trees. The functions initTs and tailTs make unbalanced trees out of balanced ones, so I don’t know how to adapt the specifications given here to the setting of depth-typed balanced trees. Maybe I could just fill up the to-be-ignored elements with ∅.

5 Comments

  1. Jake McArthur:

    The correspondence of tailsTs with duplicate comes up if you change your tree from a free monad to a cofree comonad.

  2. Conal Elliott » Blog Archive » Composable parallel scanning:

    [...] About « Deriving parallel tree scans [...]

  3. Russell O'Connor:

    Your theorem headT ∘ fmap f ≡ f ∘ headT is the free theorem for headT (and lastT), so you don’t even have to do anything to prove it. It follows from the type.

  4. conal:

    Thanks for the tip, Russell! I’ll keep an eye out for these free theorems.

  5. Conal Elliott » Blog Archive » Parallel tree scanning by composition:

    [...] final form is as in Deriving parallel tree scans, changed for the new scan interface. The derivation saved some work in wrapping & unwrapping [...]
