Derivatives Defunctionalize CPS
Or, what’s the derivative of a function?
Jan-Willem Maessen
jmaessen@google.com
Work In Progress
Thanks
Edwin Brady
Klaus Ostermann
Mark Miller
IFIP WG 2.16
A simple ordered set
data Set v = E | N (Set v) v (Set v)
insert v E = N E v E
insert v (N l n r)
| v < n = N (insert v l) n r
| n < v = N l n (insert v r)
| otherwise = N l n r
fold e n s = rec s where
rec E = e
rec (N l v r) = n (rec l) v (rec r)
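Here is a runnable transcription of the slide above, offered as a sketch. It uses ASCII Haskell syntax and adds type signatures with an `Ord` constraint and derived `Show`/`Eq` instances. It also introduces a hypothetical `fromList` helper that is not part of the slides.

```haskell
-- The ordered set from the slide, in ASCII Haskell.
data Set v = E | N (Set v) v (Set v)
  deriving (Show, Eq)

-- Insert v, descending to the position that keeps the tree ordered.
insert :: Ord v => v -> Set v -> Set v
insert v E = N E v E
insert v (N l n r)
  | v < n     = N (insert v l) n r
  | n < v     = N l n (insert v r)
  | otherwise = N l n r            -- v already present: tree unchanged

-- Catamorphism: replace every E by e and every N by n.
fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = rec s where
  rec E = e
  rec (N l v r) = n (rec l) v (rec r)

-- Hypothetical helper (not in the slides): insert elements left to right.
fromList :: Ord v => [v] -> Set v
fromList = foldl (flip insert) E
```

For example, `fromList [4,2,6,1,3,7]` builds `N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))`. Folding that tree with `\l v r -> l ++ [v] ++ r` lists its elements in order.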
Algebraic Types: Sums of Products
data Set v = E | N (Set v) v (Set v)
Set v = μs. 1 + s*v*s
data List v = Nil | Cons v (List v)
List v = μl. 1 + v*l
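The following sketch makes the μ notation concrete. It encodes `Set v = μs. 1 + s*v*s` and `List v = μl. 1 + v*l` as fixpoints of base functors. The names `Fix`, `SetF`, `ListF`, `toFix`, `fromFix`, `listFix` and `cata` are illustrative assumptions and do not come from the slides. Over `SetF`, `cata` plays the role of the slide's `fold`.

```haskell
{-# LANGUAGE DeriveFunctor #-}
-- Explicit fixpoint of a base functor: Fix f unrolls to f (Fix f).
newtype Fix f = Fix (f (Fix f))

-- One layer of Set v = μs. 1 + s*v*s (s marks the recursive positions).
data SetF v s = EF | NF s v s deriving Functor

-- One layer of List v = μl. 1 + v*l.
data ListF v l = NilF | ConsF v l deriving Functor

data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)

toFix :: Set v -> Fix (SetF v)
toFix E         = Fix EF
toFix (N l v r) = Fix (NF (toFix l) v (toFix r))

fromFix :: Fix (SetF v) -> Set v
fromFix (Fix EF)         = E
fromFix (Fix (NF l v r)) = N (fromFix l) v (fromFix r)

listFix :: [v] -> Fix (ListF v)
listFix = foldr (\x acc -> Fix (ConsF x acc)) (Fix NilF)

-- Generic fold over a fixpoint; for SetF this is the slide's fold.
cata :: Functor f => (f b -> b) -> Fix f -> b
cata alg (Fix x) = alg (fmap (cata alg) x)
```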
Derivatives, Illustrated
[Figure: the ordered set N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E)), drawn as a tree]
Derivatives, Illustrated
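[Figure: the same tree with a hole at the empty left child of N 6, drawn as a Set’ v context. The path from the root is marked R 4 then L 6, the off-path subtrees (N 2 with N 1 and N 3; N 7) stay attached, and T closes the path]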
Taking the derivative of a Set
Set v = μs. 1 + s*v*s
∂s (1 + s*v*s)
= ∂s 1 + ∂s (s*v*s)
= 0 + ∂s*v*s + s*0*s + s*v*∂s
= ∂s*v*s + s*v*∂s
Set’ v = μs’. 1 + s’* v * Set v + Set v * v * s’
data Set’ v
= T | L (Set’ v) v (Set v) | R (Set v) v (Set’ v)
Substitution into a Derivative
data Set’ v
= T | L (Set’ v) v (Set v) | R (Set v) v (Set’ v)
sub :: Set’ v → Set v → Set v
sub T t’ = t’
sub (L c v r) l’ = sub c (N l’ v r)
sub (R l v c) r’ = sub c (N l v r’)
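The substitution can be checked directly. This sketch transcribes `Set’` and `sub` into ASCII Haskell as `Set'`. The context `ctx` is a hypothetical value reconstructed from the illustration, with the hole at the left child of 6.

```haskell
data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)

-- One-hole context: the path from the hole back up to the root (T).
data Set' v
  = T
  | L (Set' v) v (Set v)   -- hole is a left child: keep the value and right sibling
  | R (Set v) v (Set' v)   -- hole is a right child: keep the left sibling and value
  deriving (Show, Eq)

-- Plug a tree into the hole, rebuilding nodes outward to the root.
sub :: Set' v -> Set v -> Set v
sub T t'         = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')

-- Hypothetical example, reconstructed from the illustration:
-- the hole is the (empty) left child of 6.
ctx :: Set' Int
ctx = L (R (N (N E 1 E) 2 (N E 3 E)) 4 T) 6 (N E 7 E)
```

Plugging `E` back into `ctx` recovers the original tree. Plugging `N E 5 E` gives the tree after inserting 5.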
Local CPS conversion of insert
insert v t = insertK v t (λt’→ t’)
insertK v E k = k (N E v E)
insertK v (N l n r) k
| v < n = insertK v l (λl’→ k (N l’ n r))
| n < v = insertK v r (λr’→ k (N l n r’))
| otherwise = k (N l n r)
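Here is a runnable sketch of the CPS version, in ASCII syntax with an added signature. The answer type `r` of `insertK` is left polymorphic. `insert` instantiates it to `Set v` by passing the identity continuation.

```haskell
data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)

insert :: Ord v => v -> Set v -> Set v
insert v t = insertK v t (\t' -> t')   -- start with the identity continuation

-- k says how to rebuild the rest of the tree around the updated subtree.
insertK :: Ord v => v -> Set v -> (Set v -> r) -> r
insertK v E k = k (N E v E)
insertK v (N l n r) k
  | v < n     = insertK v l (\l' -> k (N l' n r))
  | n < v     = insertK v r (\r' -> k (N l n r'))
  | otherwise = k (N l n r)
```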
Defunctionalize the continuation
insert v t = insertC v t T
insertC v E c = applyC c (N E v E)
insertC v (N l n r) c
| v < n = insertC v l (L c n r)
| n < v = insertC v r (R l n c)
| otherwise = applyC c (N l n r)
applyC T = λt’→ t’
applyC (L c n r) = λl’→ applyC c (N l’ n r)
applyC (R l n c) = λr’→ applyC c (N l n r’)
Defunctionalize the continuation
insert v t = insertC v t T
insertC v E c = sub c (N E v E)
insertC v (N l n r) c
| v < n = insertC v l (L c n r)
| n < v = insertC v r (R l n c)
| otherwise = sub c (N l n r)
insertC :: Ord v ⇒ v → Set v → Set’ v → Set v
applyC, sub :: Set’ v → Set v → Set v
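This sketch combines the last two slides. It defines `insertC` in terms of `sub` and keeps the slide's `applyC` alongside it. That way one can check on examples that the defunctionalized continuation's `apply` coincides with substitution into the derivative. The type signatures and derived instances are additions.

```haskell
data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)
data Set' v = T | L (Set' v) v (Set v) | R (Set v) v (Set' v)
  deriving (Show, Eq)

-- Substitution into the derivative.
sub :: Set' v -> Set v -> Set v
sub T t'         = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')

-- The defunctionalized continuation's apply, as on the earlier slide.
applyC :: Set' v -> Set v -> Set v
applyC T         = \t' -> t'
applyC (L c n r) = \l' -> applyC c (N l' n r)
applyC (R l n c) = \r' -> applyC c (N l n r')

insert :: Ord v => v -> Set v -> Set v
insert v t = insertC v t T

-- The continuation argument is a one-hole context of type Set' v.
insertC :: Ord v => v -> Set v -> Set' v -> Set v
insertC v E c = sub c (N E v E)
insertC v (N l n r) c
  | v < n     = insertC v l (L c n r)
  | n < v     = insertC v r (R l n c)
  | otherwise = sub c (N l n r)
```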
Is this surprising?
What about other functions on Set v?
Catamorphism (tree fold)
fold e n s = rec s where
rec E = e
rec (N l v r) = n (rec l) v (rec r)
setSum = fold 0 (λl v r → l + v + r)
toList = fold [] (λl v r → l ++ [v] ++ r)
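Below is a sketch of the two instances above plus one more. `height` is a hypothetical extra example that is not in the slides. It is included to show that any function of the form `fold e n` fits the same pattern.

```haskell
data Set v = E | N (Set v) v (Set v)

fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = rec s where
  rec E = e
  rec (N l v r) = n (rec l) v (rec r)

setSum :: Num v => Set v -> v
setSum = fold 0 (\l v r -> l + v + r)

-- In-order traversal.
toList :: Set v -> [v]
toList = fold [] (\l v r -> l ++ [v] ++ r)

-- Hypothetical extra instance (not in the slides): tree height.
height :: Set v -> Int
height = fold 0 (\l _ r -> 1 + max l r)
```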
CPS-converted
fold e n s = recK s (λr → r) where
recK E k = k e
recK (N l v r) k =
recK l (λl'→
recK r (λr'→ k (n l' v r')))
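Here is the CPS-converted fold, transcribed into ASCII Haskell as a sketch. Every continuation has type `b -> b`. The traversal folds the left subtree, then the right subtree, and then combines the results.

```haskell
data Set v = E | N (Set v) v (Set v)

-- CPS fold: recK takes the continuation for the subtree's result.
fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = recK s (\r -> r) where
  recK E k = k e
  recK (N l v r) k =
    recK l (\l' ->          -- fold the left subtree first,
    recK r (\r' ->          -- then the right subtree,
    k (n l' v r')))         -- then combine and continue
```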
Defunctionalized
fold e n s = recC s T where
recC E c = apply c e
recC (N l v r) c = recC l (L c v r)
apply T r = r
apply (L c v r) l' = recC r (R l' v c)
apply (R l' v c) r' = apply c (n l' v r')
data FoldC a v = T
| L (FoldC a v) v (Set v) | R a v (FoldC a v)
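This is a runnable sketch of the defunctionalized fold. The continuation type is spelled `FoldC` because Haskell type names must be capitalized. As on the slide, `recC` and `apply` are mutually recursive local functions.

```haskell
data Set v = E | N (Set v) v (Set v)

-- Defunctionalized continuation: a is the type of partial results.
data FoldC a v
  = T                          -- done
  | L (FoldC a v) v (Set v)    -- left subtree in progress; v and right subtree pending
  | R a v (FoldC a v)          -- left result ready; right subtree in progress

fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = recC s T where
  recC E c = apply c e
  recC (N l v r) c = recC l (L c v r)
  apply T r = r
  apply (L c v r) l' = recC r (R l' v c)
  apply (R l' v c) r' = apply c (n l' v r')
```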
Defunctionalized sum continuation
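[Figure: the defunctionalized fold continuation for setSum at the left child of N 6. R 4 carries the partial sum 6 of the finished left subtree, L 6 carries the unvisited subtree N 7, and T closes the path]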
The directional derivative of a Set
Set v = μs. 1 + s*v*s
∂s (1 + s*v*s)
= ∂s 1 + ∂s (s*v*s)
= 0 + ∂s*v*r + l*0*r + l*v*∂s
= ∂s*v*r + l*v*∂s
SetD l r v = μs’. 1 + s’*v*r + l*v*s’
data SetD l r v
= T | L (SetD l r v) v r | R l v (SetD l r v)
Using the directional derivative
data SetD l r v
= T | L (SetD l r v) v r | R l v (SetD l r v)
type Set’ v = SetD (Set v) (Set v) v
type FoldC a v = SetD a (Set v) v
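This sketch checks that both continuation types are instances of the directional derivative. A single data declaration serves both: `sub` works on `Set' v = SetD (Set v) (Set v) v`, and the defunctionalized fold works on `FoldC a v = SetD a (Set v) v`. The ASCII names and signatures are assumptions of the sketch.

```haskell
data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)

-- Directional derivative: l is the type of left siblings along the path,
-- r the type of right siblings.
data SetD l r v = T | L (SetD l r v) v r | R l v (SetD l r v)

type Set' v    = SetD (Set v) (Set v) v   -- one-hole context
type FoldC a v = SetD a (Set v) v         -- fold continuation

sub :: Set' v -> Set v -> Set v
sub T t'         = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')

-- Here apply is inferred at type FoldC b v -> b -> b.
fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = recC s T where
  recC E c = apply c e
  recC (N l v r) c = recC l (L c v r)
  apply T r = r
  apply (L c v r) l' = recC r (R l' v c)
  apply (R l' v c) r' = apply c (n l' v r')
```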
Why do we care?
QuickSort
qs as = qsa as [] where
qsa [] ts = ts
qsa (a:as) ts =
case partition (< a) as of
(ls, rs) -> qsa ls (a : qsa rs ts)
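Here is a runnable version of the accumulator-passing QuickSort, as a sketch. It needs `partition` from `Data.List`, and the call `qsa as []` starts with an empty accumulator.

```haskell
import Data.List (partition)

qs :: Ord a => [a] -> [a]
qs as = qsa as [] where
  -- qsa xs ts: sort xs and prepend the result to the accumulator ts
  qsa [] ts = ts
  qsa (a:as) ts =
    case partition (< a) as of
      (ls, rs) -> qsa ls (a : qsa rs ts)
```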
QuickSort: CPS
qs as = qsa as [] (λt’→ t’) where
qsa [] ts k = k ts
qsa (a:as) ts k =
case partition (< a) as of
(ls, rs) ->
qsa rs ts (λts’→ qsa ls (a:ts’) k)
QuickSort: Defunctionalized
qs as = qsa as [] Empty where
qsa [] ts c = apply c ts
qsa (a:as) ts c =
case partition (< a) as of
(ls, rs) ->
qsa rs ts (NodeLike ls a c)
apply Empty t’ = t’
apply (NodeLike ls a c) ts’ = qsa ls (a:ts’) c
data SetLike a =
Empty | NodeLike [a] a (SetLike a)
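This sketch shows the defunctionalized QuickSort together with its continuation type. Each `NodeLike ls a c` frame holds a left partition still to be sorted and its pivot, much like the `L` frames of `Set’`. Apart from ASCII syntax and signatures, it follows the slide.

```haskell
import Data.List (partition)

-- Frames still to process: a left partition and its pivot.
data SetLike a = Empty | NodeLike [a] a (SetLike a)

qs :: Ord a => [a] -> [a]
qs as = qsa as [] Empty where
  -- sort the right partition first, remembering the left one in c
  qsa [] ts c = apply c ts
  qsa (a:as) ts c =
    case partition (< a) as of
      (ls, rs) -> qsa rs ts (NodeLike ls a c)
  -- resume: cons the pivot, then sort the left partition
  apply Empty t' = t'
  apply (NodeLike ls a c) ts' = qsa ls (a:ts') c
```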
Set Union
union a b = rec a b Nothing Nothing where
rec (N l v r) b (Just lb) ub
| v <= lb = rec r b (Just lb) ub
rec (N l v r) b lb (Just ub)
| ub <= v = rec l b lb (Just ub)
rec (N l v r) b lb ub =
N (rec b l lb (Just v)) v
(rec b r (Just v) ub)
rec E E lb ub = E
rec E b lb ub = rec b E lb ub
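This is a runnable sketch of `union`, with hypothetical `insert`, `fromList` and `toList` helpers for testing. `rec a b lb ub` computes the union of `a` and `b` restricted to the open interval between `lb` and `ub`, where `Nothing` means unbounded. At each node the two trees swap roles.

```haskell
data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)

-- rec a b lb ub: union of a and b, restricted to the open interval (lb, ub).
union :: Ord v => Set v -> Set v -> Set v
union a b = rec a b Nothing Nothing where
  rec (N l v r) b (Just lb) ub
    | v <= lb = rec r b (Just lb) ub   -- v and its left subtree are too small
  rec (N l v r) b lb (Just ub)
    | ub <= v = rec l b lb (Just ub)   -- v and its right subtree are too large
  rec (N l v r) b lb ub =
    N (rec b l lb (Just v)) v (rec b r (Just v) ub)
  rec E E lb ub = E
  rec E b lb ub = rec b E lb ub

-- Hypothetical test helpers, not from the slides.
insert :: Ord v => v -> Set v -> Set v
insert v E = N E v E
insert v t@(N l n r)
  | v < n     = N (insert v l) n r
  | n < v     = N l n (insert v r)
  | otherwise = t

fromList :: Ord v => [v] -> Set v
fromList = foldl (flip insert) E

toList :: Set v -> [v]
toList E         = []
toList (N l v r) = toList l ++ [v] ++ toList r
```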
Set Union: CPS
union a b = rec a b Nothing Nothing (λs->s) where
rec (N l v r) b (Just lb) ub k
| v <= lb = rec r b (Just lb) ub k
rec (N l v r) b lb (Just ub) k
| ub <= v = rec l b lb (Just ub) k
rec (N l v r) b lb ub k =
rec b l lb (Just v) (λl’ ->
rec b r (Just v) ub (λr’ ->
k (N l’ v r’)))
rec E E lb ub k = k E
rec E b lb ub k = rec b E lb ub k
Set Union: Defunctionalized
union a b = rec a b Nothing Nothing Id where
rec (N l v r) b (Just lb) ub c
| v <= lb = rec r b (Just lb) ub c
rec (N l v r) b lb (Just ub) c
| ub <= v = rec l b lb (Just ub) c
rec (N l v r) b lb ub c =
rec b l lb (Just v) (L c v (r, b, lb, ub))
rec E E lb ub c = apply c E
rec E b lb ub c = rec b E lb ub c
apply Id s = s
apply (L c v (r, b, lb, ub)) l’ = rec b r (Just v) ub (R l’ v c)
apply (R l’ v c) r’ = apply c (N l’ v r’)
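Here is a runnable sketch of the defunctionalized `union`. The slides do not name the continuation type, so `UnionC` is a hypothetical name. An `L` frame remembers everything needed to build the right half later, and an `R` frame holds a finished left half. The `insert`, `fromList` and `toList` helpers are hypothetical test aids.

```haskell
data Set v = E | N (Set v) v (Set v) deriving (Show, Eq)

-- Hypothetical name for the slide's continuation type.
data UnionC v
  = Id
  | L (UnionC v) v (Set v, Set v, Maybe v, Maybe v)  -- right half still to build
  | R (Set v) v (UnionC v)                           -- left half finished

union :: Ord v => Set v -> Set v -> Set v
union a b = rec a b Nothing Nothing Id where
  rec (N l v r) b (Just lb) ub c
    | v <= lb = rec r b (Just lb) ub c
  rec (N l v r) b lb (Just ub) c
    | ub <= v = rec l b lb (Just ub) c
  rec (N l v r) b lb ub c =
    rec b l lb (Just v) (L c v (r, b, lb, ub))
  rec E E lb ub c = apply c E
  rec E b lb ub c = rec b E lb ub c
  apply Id s = s
  apply (L c v (r, b, lb, ub)) l' = rec b r (Just v) ub (R l' v c)
  apply (R l' v c) r' = apply c (N l' v r')

insert :: Ord v => v -> Set v -> Set v
insert v E = N E v E
insert v t@(N l n r)
  | v < n     = N (insert v l) n r
  | n < v     = N l n (insert v r)
  | otherwise = t

fromList :: Ord v => [v] -> Set v
fromList = foldl (flip insert) E

toList :: Set v -> [v]
toList E         = []
toList (N l v r) = toList l ++ [v] ++ toList r
```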
Current Work
Things to look at
Brent Yorgey on combinatorial species
Work on reifying recursion structure for sort in Coq.
MergeSort
ms as = msl as (length as) where
msl as n | n <= 1 = as
msl as n = case splitAt half as of
(xs, ys) -> merge (msl xs half) (msl ys (n - half))
where half = n `div` 2
merge (x:xs) (y:ys) | x <= y = x : merge xs (y:ys)
| otherwise = y : merge (x:xs) ys
merge [] ys = ys
merge xs [] = xs
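Below is a runnable sketch of this MergeSort, with type signatures added. Passing the length `n` down avoids recomputing it. The base case `n <= 1` also covers the empty list.

```haskell
ms :: Ord a => [a] -> [a]
ms as = msl as (length as) where
  -- msl xs n sorts xs, where n == length xs
  msl as n | n <= 1 = as
  msl as n = case splitAt half as of
      (xs, ys) -> merge (msl xs half) (msl ys (n - half))
    where half = n `div` 2

-- Merge two sorted lists; <= takes from the left list on ties.
merge :: Ord a => [a] -> [a] -> [a]
merge (x:xs) (y:ys)
  | x <= y    = x : merge xs (y:ys)
  | otherwise = y : merge (x:xs) ys
merge [] ys = ys
merge xs [] = xs
```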
MergeSort: Independent Local CPS
ms as = msl as (length as) (λrs→ rs) where
msl as n k | n <= 1 = k as
msl as n k = case splitAt half as of
(xs, ys) -> msl xs half (λxs’→
msl ys (n - half) (λys’→
merge xs’ ys’ (λrs→ k rs)))
where half = n `div` 2
merge (x:xs) (y:ys) k | x <= y = merge xs (y:ys) (λrs→ k (x:rs))
| otherwise = merge (x:xs) ys (λrs→ k (y:rs))
merge [] ys k = k ys
merge xs [] k = k xs
MergeSort: Defunctionalization
ms as = msl as (length as) Id where
msl as n c | n <= 1 = applyS c as
msl as n c = case splitAt half as of
(xs, ys) -> msl xs half (S c (n-half) ys)
where half = n `div` 2
merge (x:xs) (y:ys) c | x <= y = merge xs (y:ys) (M x c)
| otherwise = merge (x:xs) ys (M y c)
merge [] ys c = applyM c ys
merge xs [] c = applyM c xs
applyS Id rs = rs
applyS (S c n’ ys) xs’ = msl ys n’ (ML c xs’)
applyS (ML c xs’) ys’ = merge xs’ ys’ (I c)
applyM (I c) rs = applyS c rs
applyM (M x c) rs = applyM c (x:rs)
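This sketch makes the defunctionalized MergeSort above runnable. It needs two continuation types, one for `msl` and one for `merge`. Their names, `SC` and `MC`, are hypothetical. The `msl` frame built by `applyS` is spelled `ML` so that it does not clash with the `merge` frame `M`.

```haskell
-- Continuation of msl: what to do with a sorted list.
data SC a = Id | S (SC a) Int [a] | ML (SC a) [a]
-- Continuation of merge: elements still to cons on, then back to msl's frames.
data MC a = I (SC a) | M a (MC a)

ms :: Ord a => [a] -> [a]
ms as = msl as (length as) Id

msl :: Ord a => [a] -> Int -> SC a -> [a]
msl as n c | n <= 1 = applyS c as
msl as n c = case splitAt half as of
    (xs, ys) -> msl xs half (S c (n - half) ys)
  where half = n `div` 2

merge :: Ord a => [a] -> [a] -> MC a -> [a]
merge (x:xs) (y:ys) c
  | x <= y    = merge xs (y:ys) (M x c)
  | otherwise = merge (x:xs) ys (M y c)
merge [] ys c = applyM c ys
merge xs [] c = applyM c xs

applyS :: Ord a => SC a -> [a] -> [a]
applyS Id rs = rs
applyS (S c n' ys) xs' = msl ys n' (ML c xs')   -- left half done: sort the right half
applyS (ML c xs') ys' = merge xs' ys' (I c)     -- both halves done: merge them

applyM :: Ord a => MC a -> [a] -> [a]
applyM (I c) rs = applyS c rs
applyM (M x c) rs = applyM c (x:rs)
```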
InsertionSort (Recursive, stable)
is [] = []
is (a:as) = ins (is as) where
ins [] = [a]
ins (b:bs)
| a <= b = a : b : bs
| otherwise = b : ins bs
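To make stability observable, this sketch generalizes the sort to compare on a key; `isBy` is a hypothetical name. With `<=`, an element is placed before any equal element that came later in the input, which is exactly what stability requires.

```haskell
-- Hypothetical generalization: sort by a key so ties can be told apart.
isBy :: Ord k => (a -> k) -> [a] -> [a]
isBy key [] = []
isBy key (a:as) = ins (isBy key as) where
  ins [] = [a]
  ins (b:bs)
    | key a <= key b = a : b : bs   -- <= keeps a ahead of later equal keys
    | otherwise      = b : ins bs

is :: Ord a => [a] -> [a]
is = isBy id
```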
InsertionSort: CPS
is as = isK as (λt’→ t’)
isK [] k = k []
isK (a:as) k = isK as (λbs → k (ins bs)) where
ins [] = [a]
ins (b:bs)
| a <= b = a : b : bs
| otherwise = b : ins bs
InsertionSort: Defunctionalized
is as = isC as []
isC [] c = apply c []
isC (a:as) c = isC as (a:c)
apply [] t’ = t’
apply (a:c) bs = apply c (ins bs) where
ins [] = [a]
ins (b:bs)
| a <= b = a : b : bs
| otherwise = b : ins bs
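Here is a runnable sketch of the defunctionalized insertion sort. The continuation is just an ordinary list of pending elements. That is why the next slides can rename `apply` and replace `isC` with `reverse`. The comparison is written `<=` here to keep the sort stable.

```haskell
is :: Ord a => [a] -> [a]
is as = isC as []

-- The defunctionalized continuation is a stack of elements still to insert.
isC :: Ord a => [a] -> [a] -> [a]
isC [] c = apply c []
isC (a:as) c = isC as (a:c)

apply :: Ord a => [a] -> [a] -> [a]
apply [] t' = t'
apply (a:c) bs = apply c (ins bs) where
  ins [] = [a]
  ins (b:bs)
    | a <= b    = a : b : bs
    | otherwise = b : ins bs
```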
InsertionSort: Renamed
is as = isC as []
isC [] c = isRev c []
isC (a:as) c = isC as (a:c)
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs) where
ins [] = [a]
ins (b:bs)
| a <= b = a : b : bs
| otherwise = b : ins bs
InsertionSort: Refactored
is as = isRev (reverse as) []
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs) where
ins [] = [a]
ins (b:bs)
| a <= b = a : b : bs
| otherwise = b : ins bs
InsertionSort: CPS-transform insert
is as = isRev (reverse as) []
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs (λt’→ t’))
where
ins [] k = k [a]
ins (b:bs) k
| a <= b = k (a : b : bs)
| otherwise = ins bs (λbs’→ k (b:bs’))
InsertionSort: Defunctionalize
is as = isRev (reverse as) []
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs []) where
ins [] c = apply c [a]
ins (b:bs) c
| a <= b = apply c (a : b : bs)
| otherwise = ins bs (b:c)
apply [] bs = bs
apply (b:c) bs’ = apply c (b:bs’)
InsertionSort: Rename
is as = isRev (reverse as) []
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs []) where
ins [] c = revApp c [a]
ins (b:bs) c
| a <= b = revApp c (a : b : bs)
| otherwise = ins bs (b:c)
revApp [] bs = bs
revApp (b:c) bs’ = revApp c (b:bs’)
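This sketch shows the final form. After `ins` is defunctionalized, its continuation is a reversed prefix, and `apply` becomes reverse-append. Type signatures are added. Note that `revApp c xs` equals `reverse c ++ xs`.

```haskell
is :: Ord a => [a] -> [a]
is as = isRev (reverse as) []

isRev :: Ord a => [a] -> [a] -> [a]
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs []) where
  -- ins walks past smaller elements, saving them (reversed) in c
  ins [] c = revApp c [a]
  ins (b:bs) c
    | a <= b    = revApp c (a : b : bs)
    | otherwise = ins bs (b:c)

-- revApp c xs == reverse c ++ xs
revApp :: [a] -> [a] -> [a]
revApp [] bs = bs
revApp (b:c) bs' = revApp c (b:bs')
```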
MergeSort
ms as = msl as (length as) where
msl as n | n <= 1 = as
msl as n =
case splitAt half as of
(xs, ys) ->
merge (msl xs half) (msl ys (n - half))
where half = n `div` 2
merge (x:xs) (y:ys)
| x <= y = x : merge xs (y:ys)
| otherwise = y : merge (x:xs) ys
merge [] ys = ys
merge xs [] = xs
MergeSort: CPS
ms as = msl as (length as) (λrs→ rs) where
msl as n k | n <= 1 = k as
msl as n k =
case splitAt half as of
(xs, ys) ->
msl xs half (λxs’→
msl ys (n - half) (λys’→
merge xs’ ys’ k))
where half = n `div` 2
merge (x:xs) (y:ys) k
| x <= y = merge xs (y:ys) (λrs→ k (x:rs))
| otherwise = merge (x:xs) ys (λrs→ k (y:rs))
merge [] ys k = k ys
merge xs [] k = k xs
MergeSort: Defunctionalization
ms as = msl as (length as) Id where
msl as n c | n <= 1 = apply c as
msl as n c =
case splitAt half as of
(xs, ys) ->
msl xs half (S c (n-half) ys)
where half = n `div` 2
merge (x:xs) (y:ys) c
| x <= y = merge xs (y:ys) (MR x c)
| otherwise = merge (x:xs) ys (MR y c)
merge [] ys c = apply c ys
merge xs [] c = apply c xs
apply Id rs = rs
apply (S c n’ ys) xs’ = msl ys n’ (ML c xs’)
apply (ML c xs’) ys’ = merge xs’ ys’ c
apply (MR x c) rs = apply c (x:rs)
MergeSort: Taking the Integral?
data MSC c j v
= Id
| S (MSC c j v) Int [v]
| ML (MSC c j v) j
| MR c (MSC c j v)
data MSTree v
= Split (MSTree v) Int [v]
| Merge (MSTree v) (MSTree v)
| Leaf [v]
type MSCont v = MSC v [v] v
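As a sketch of how the parameterized type is used, the defunctionalized MergeSort from the previous slides can be typed with `MSCont v = MSC v [v] v`. The `n <= 1` base case and the type signatures are assumptions of this sketch.

```haskell
data MSC c j v
  = Id
  | S (MSC c j v) Int [v]   -- left half in progress; length and right half pending
  | ML (MSC c j v) j        -- right half in progress; sorted left half saved
  | MR c (MSC c j v)        -- merge output element still to cons on

type MSCont v = MSC v [v] v

ms :: Ord v => [v] -> [v]
ms as = msl as (length as) Id

msl :: Ord v => [v] -> Int -> MSCont v -> [v]
msl as n c | n <= 1 = apply c as
msl as n c = case splitAt half as of
    (xs, ys) -> msl xs half (S c (n - half) ys)
  where half = n `div` 2

merge :: Ord v => [v] -> [v] -> MSCont v -> [v]
merge (x:xs) (y:ys) c
  | x <= y    = merge xs (y:ys) (MR x c)
  | otherwise = merge (x:xs) ys (MR y c)
merge [] ys c = apply c ys
merge xs [] c = apply c xs

apply :: Ord v => MSCont v -> [v] -> [v]
apply Id rs = rs
apply (S c n' ys) xs' = msl ys n' (ML c xs')
apply (ML c xs') ys' = merge xs' ys' c
apply (MR x c) rs = apply c (x:rs)
```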