1 of 45

Derivatives Defunctionalize CPS

Or, what’s the derivative of a function?

Jan-Willem Maessen

jmaessen@google.com

2 of 45

Work In Progress

3 of 45

Thanks

Edwin Brady

Klaus Ostermann

Mark Miller

IFIP WG 2.16

4 of 45

A simple ordered set

data Set v = E | N (Set v) v (Set v)

insert v E = N E v E
insert v (N l n r)
  | v < n = N (insert v l) n r
  | n < v = N l n (insert v r)
  | otherwise = N l n r

fold e n s = rec s where
  rec E = e
  rec (N l v r) = n (rec l) v (rec r)
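
A minimal runnable sketch of the definitions above. The type signatures, the derived `Eq`/`Show` instances, and the helpers `fromList`, `toList`, and `example` are additions for illustration, not part of the slide.

```haskell
-- Unbalanced binary search tree, as on the slide.
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- Insert v; an element that is already present leaves the tree unchanged.
insert :: Ord v => v -> Set v -> Set v
insert v E = N E v E
insert v (N l n r)
  | v < n = N (insert v l) n r
  | n < v = N l n (insert v r)
  | otherwise = N l n r

-- Catamorphism: replace every E by e and every N by n.
fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = rec s
  where
    rec E = e
    rec (N l v r) = n (rec l) v (rec r)

-- Hypothetical helper: insert the elements from left to right.
fromList :: Ord v => [v] -> Set v
fromList = foldl (flip insert) E

-- Hypothetical helper: in-order traversal written as a fold.
toList :: Set v -> [v]
toList = fold [] (\l v r -> l ++ [v] ++ r)

example :: Set Int
example = fromList [4, 2, 6, 1, 3, 7]
```

Inserting 4, 2, 6, 1, 3, 7 in that order produces `N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))`. The tree is not balanced in general, since `insert` never rotates.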

5 of 45

Algebraic Types: Sums of Products

data Set v = E | N (Set v) v (Set v)

Set v = μs. 1 + s*v*s

data List v = Nil | Cons v (List v)

List v = μl. 1 + v*l
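
The μ notation can be read as a type-level least fixed point. A sketch under that reading, assuming a standard `Fix` newtype and hypothetical base functors `SetF` and `ListF` (none of these names appear on the slide): each constructor of a base functor is one summand, its fields are the factors of a product, and the recursive position is the bound variable.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- The least fixed point of a functor f: "mu x. f x".
newtype Fix f = In (f (Fix f))

-- Set v = mu s. 1 + s*v*s : a nullary summand and a three-way product.
data SetF v s = EF | NF s v s
  deriving Functor

-- List v = mu l. 1 + v*l
data ListF v l = NilF | ConsF v l
  deriving Functor

type SetMu v = Fix (SetF v)
type ListMu v = Fix (ListF v)

-- Smart constructors playing the roles of E, N, Nil and Cons.
setE :: SetMu v
setE = In EF

setN :: SetMu v -> v -> SetMu v -> SetMu v
setN l v r = In (NF l v r)

nil :: ListMu v
nil = In NilF

cons :: v -> ListMu v -> ListMu v
cons v l = In (ConsF v l)

-- One generic fold (catamorphism) that works for every fixed point.
cata :: Functor f => (f b -> b) -> Fix f -> b
cata alg (In x) = alg (fmap (cata alg) x)

sizeSet :: SetMu v -> Int
sizeSet = cata alg
  where
    alg EF = 0
    alg (NF l _ r) = l + 1 + r

lengthList :: ListMu v -> Int
lengthList = cata alg
  where
    alg NilF = 0
    alg (ConsF _ n) = 1 + n
```

The slide's `fold` for `Set` is `cata` specialized to `SetF`, with the two summands handled by the two equations of `alg`.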

6 of 45

7 of 45

Derivatives, Illustrated

[Figure: the example tree N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E)).]

8 of 45

Derivatives, Illustrated

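[Figure: the same tree with a hole at the empty left child of 6. The path from the hole back to the root is L 6 (keeping N 7), then R 4 (keeping the N 2 subtree), then T; as a value, L (R (N (N E 1 E) 2 (N E 3 E)) 4 T) 6 (N E 7 E).]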

9 of 45

Taking the derivative of a Set

Set v = μs. 1 + s*v*s

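$$
\begin{aligned}
\partial_s (1 + s \times v \times s)
  &= \partial_s 1 + \partial_s (s \times v \times s) \\
  &= 0 + 1 \times v \times s + s \times 0 \times s + s \times v \times 1 \\
  &= 1 \times v \times s + s \times v \times 1
\end{aligned}
$$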

Set’ v = μs’. 1 + s’* v * Set v + Set v * v * s’

data Set' v
  = T | L (Set' v) v (Set v) | R (Set v) v (Set' v)

10 of 45

Substitution into a Derivative

data Set' v
  = T | L (Set' v) v (Set v) | R (Set v) v (Set' v)

sub :: Set' v -> Set v -> Set v
sub T t' = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')
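
One way to check that `Set'` describes the one-hole contexts of a tree is to enumerate every split of a tree into a context and the subtree sitting in its hole, then plug each subtree back with `sub`. A sketch; `splits` and `example` are hypothetical helpers, not from the slides.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- One-hole contexts for subtrees: the derivative Set' from the slides.
data Set' v = T | L (Set' v) v (Set v) | R (Set v) v (Set' v)
  deriving (Eq, Show)

-- Plug a subtree into the hole, rebuilding the path up to the root.
sub :: Set' v -> Set v -> Set v
sub T t' = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')

-- Hypothetical helper: every (context, subtree) split of a tree,
-- visiting the root first, then the left side, then the right side.
splits :: Set v -> [(Set' v, Set v)]
splits = go T
  where
    go c t@E = [(c, t)]
    go c t@(N l v r) = (c, t) : go (L c v r) l ++ go (R l v c) r

example :: Set Int
example = N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))
```

The example tree has 6 nodes and 7 empty leaves, so it has 13 splits, and every one plugs back to the original tree. The split whose hole is the empty left child of 6 has context `L (R (N (N E 1 E) 2 (N E 3 E)) 4 T) 6 (N E 7 E)`.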

11 of 45

Local CPS conversion of insert

insert v t = insertK v t (\t' -> t')

insertK v E k = k (N E v E)
insertK v (N l n r) k
  | v < n = insertK v l (\l' -> k (N l' n r))
  | n < v = insertK v r (\r' -> k (N l n r'))
  | otherwise = k (N l n r)
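
A runnable check of the local CPS conversion, with the direct-style `insert` repeated so the block stands alone. The CPS entry point is renamed `insertCPS` here (an assumption, so both versions can coexist in one module); the two should build identical trees.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- Direct style, for comparison.
insert :: Ord v => v -> Set v -> Set v
insert v E = N E v E
insert v (N l n r)
  | v < n = N (insert v l) n r
  | n < v = N l n (insert v r)
  | otherwise = N l n r

-- CPS style: k receives the rebuilt subtree and rebuilds the rest of the path.
insertCPS :: Ord v => v -> Set v -> Set v
insertCPS v t = insertK v t id

insertK :: Ord v => v -> Set v -> (Set v -> Set v) -> Set v
insertK v E k = k (N E v E)
insertK v (N l n r) k
  | v < n = insertK v l (\l' -> k (N l' n r))
  | n < v = insertK v r (\r' -> k (N l n r'))
  | otherwise = k (N l n r)
```

Every recursive call to `insertK` is a tail call; the only growing structure is the chain of closures in `k`, which the next step replaces by data.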

12 of 45

Defunctionalize the continuation

insert v t = insertC v t T

insertC v E c = applyC c (N E v E)
insertC v (N l n r) c
  | v < n = insertC v l (L c n r)
  | n < v = insertC v r (R l n c)
  | otherwise = applyC c (N l n r)

applyC T = \t' -> t'
applyC (L c n r) = \l' -> applyC c (N l' n r)
applyC (R l n c) = \r' -> applyC c (N l n r')

13 of 45

Defunctionalize the continuation

insert v t = insertC v t T

insertC v E c = sub c (N E v E)
insertC v (N l n r) c
  | v < n = insertC v l (L c n r)
  | n < v = insertC v r (R l n c)
  | otherwise = sub c (N l n r)

insertC :: Ord v => v -> Set v -> Set' v -> Set v
applyC, sub :: Set' v -> Set v -> Set v
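
The final version in runnable form. To make concrete the claim that the defunctionalized continuation is a `Set'` value, the hypothetical helper `contextAt` descends exactly as `insertC` does but returns the accumulated context instead of calling `sub`.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

data Set' v = T | L (Set' v) v (Set v) | R (Set v) v (Set' v)
  deriving (Eq, Show)

sub :: Set' v -> Set v -> Set v
sub T t' = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')

insert :: Ord v => v -> Set v -> Set v
insert v t = insertC v t T

-- The continuation c is a one-hole context; "applying" it is sub.
insertC :: Ord v => v -> Set v -> Set' v -> Set v
insertC v E c = sub c (N E v E)
insertC v (N l n r) c
  | v < n = insertC v l (L c n r)
  | n < v = insertC v r (R l n c)
  | otherwise = sub c (N l n r)

-- Hypothetical helper: same descent as insertC, but return the context
-- reached at the empty leaf (or at the node already holding v).
contextAt :: Ord v => v -> Set v -> Set' v
contextAt v t = go t T
  where
    go E c = c
    go (N l n r) c
      | v < n = go l (L c n r)
      | n < v = go r (R l n c)
      | otherwise = c
```

On the six-element example tree, inserting 5 goes right at 4 and left at 6, so the context at the empty leaf is `L (R (N (N E 1 E) 2 (N E 3 E)) 4 T) 6 (N E 7 E)`, and `insert 5` is `sub` of that context applied to `N E 5 E`.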

14 of 45

Is this surprising?

  • The insert function:
    • Finds the context in t where v belongs
    • Inserts v there if it isn’t there already.
    • We trace a single path to a single context.

What about other functions on Set v?

15 of 45

Catamorphism (tree fold)

fold e n s = rec s where
  rec E = e
  rec (N l v r) = n (rec l) v (rec r)

setSum = fold 0 (\l v r -> l + v + r)
toList = fold [] (\l v r -> l ++ [v] ++ r)
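
Runnable versions of the two folds on the example tree, with type signatures the slide omits. `size` and `example` are hypothetical additions showing a third use of the same catamorphism.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = rec s
  where
    rec E = e
    rec (N l v r) = n (rec l) v (rec r)

setSum :: Num v => Set v -> v
setSum = fold 0 (\l v r -> l + v + r)

toList :: Set v -> [v]
toList = fold [] (\l v r -> l ++ [v] ++ r)

-- Hypothetical: node count, another instance of the same fold.
size :: Set v -> Int
size = fold 0 (\l _ r -> l + 1 + r)

example :: Set Int
example = N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))
```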

16 of 45

CPS-converted

fold e n s = recK s (\r -> r) where
  recK E k = k e
  recK (N l v r) k =
    recK l (\l' ->
      recK r (\r' -> k (n l' v r')))
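
A sketch checking that the CPS-converted fold agrees with the direct one. The name `foldK` is a hypothetical rename so both versions can live in one module. Every call to `recK` is now a tail call: the pending work for the right subtree lives in the closure passed as `k`, and the left subtree is always folded first.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- CPS fold: the continuation k receives the result for the current subtree.
foldK :: b -> (b -> v -> b -> b) -> Set v -> b
foldK e n s = recK s id
  where
    recK E k = k e
    recK (N l v r) k =
      recK l (\l' ->
        recK r (\r' -> k (n l' v r')))

-- Direct-style fold, for comparison.
fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = rec s
  where
    rec E = e
    rec (N l v r) = n (rec l) v (rec r)

example :: Set Int
example = N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))
```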

17 of 45

Defunctionalized

fold e n s = recC s T where
  recC E c = apply c e
  recC (N l v r) c = recC l (L c v r)
  apply T r = r
  apply (L c v r) l' = recC r (R l' v c)
  apply (R l' v c) r' = apply c (n l' v r')

data FoldC a v = T
  | L (FoldC a v) v (Set v) | R a v (FoldC a v)
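
The defunctionalized fold in runnable form, with the continuation type spelled `FoldC` (Haskell type names must be capitalized). The hypothetical `leafConts` runs the same machine but records the continuation each time it reaches an `E`, which makes the continuation values visible.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- T: done. L: left subtree in progress, right subtree still to fold.
-- R: left result a is finished, right subtree in progress.
data FoldC a v = T | L (FoldC a v) v (Set v) | R a v (FoldC a v)
  deriving (Eq, Show)

fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = recC s T
  where
    recC E c = apply c e
    recC (N l v r) c = recC l (L c v r)
    apply T r = r
    apply (L c v r) l' = recC r (R l' v c)
    apply (R l' v c) r' = apply c (n l' v r')

-- Hypothetical: the same machine, recording the continuation at each E.
leafConts :: b -> (b -> v -> b -> b) -> Set v -> [FoldC b v]
leafConts e n s = recC s T
  where
    recC E c = c : apply c e
    recC (N l v r) c = recC l (L c v r)
    apply T _ = []
    apply (L c v r) l' = recC r (R l' v c)
    apply (R l' v c) r' = apply c (n l' v r')

example :: Set Int
example = N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))
```

For `setSum`, the fifth recorded continuation is `L (R 6 4 T) 6 (N E 7 E)`: the left subtree of 4 has already been reduced to its partial sum 6 = 1 + 2 + 3, and the right subtree of 6 is still waiting to be summed.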

18 of 45

Defunctionalized sum continuation

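[Figure: the setSum continuation at the empty left child of 6, L (R 6 4 T) 6 (N E 7 E); the 6 inside R is the partial sum 1 + 2 + 3 of the left subtree of 4, and N 7 is still to be summed.]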

19 of 45

20 of 45

The directional derivative of a Set

Set v = μs. 1 + s*v*s

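Writing l for the type kept to the left of the hole and r for the type kept to the right:

$$
\begin{aligned}
\partial_s (1 + s \times v \times s)
  &= \partial_s 1 + \partial_s (s \times v \times s) \\
  &= 0 + 1 \times v \times r + l \times 0 \times r + l \times v \times 1 \\
  &= 1 \times v \times r + l \times v \times 1
\end{aligned}
$$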

SetD l r v = μs’. 1 + s’*v*r + l*v*s’

data SetD l r v
  = T | L (SetD l r v) v r | R l v (SetD l r v)

21 of 45

Using the directional derivative

data SetD l r v
  = T | L (SetD l r v) v r | R l v (SetD l r v)

type Set' v = SetD (Set v) (Set v) v

type FoldC a v = SetD a (Set v) v
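
A sketch showing that one `SetD` declaration serves both roles: the same constructors build a tree context of type `Set' Int` and a fold continuation of type `FoldC Int Int`. `treeCtx`, `sumCont`, and `example` are hypothetical sample values.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- l: type kept to the left of the hole; r: type kept to the right.
data SetD l r v = T | L (SetD l r v) v r | R l v (SetD l r v)
  deriving (Eq, Show)

type Set' v = SetD (Set v) (Set v) v
type FoldC a v = SetD a (Set v) v

-- Plugging a tree context (the slides' sub).
sub :: Set' v -> Set v -> Set v
sub T t' = t'
sub (L c v r) l' = sub c (N l' v r)
sub (R l v c) r' = sub c (N l v r')

-- Defunctionalized fold; its continuations have type FoldC b v.
fold :: b -> (b -> v -> b -> b) -> Set v -> b
fold e n s = recC s T
  where
    recC E c = apply c e
    recC (N l v r) c = recC l (L c v r)
    apply T r = r
    apply (L c v r) l' = recC r (R l' v c)
    apply (R l' v c) r' = apply c (n l' v r')

example :: Set Int
example = N (N (N E 1 E) 2 (N E 3 E)) 4 (N E 6 (N E 7 E))

-- Hypothetical sample values built from the same constructors.
treeCtx :: Set' Int
treeCtx = L (R (N (N E 1 E) 2 (N E 3 E)) 4 T) 6 (N E 7 E)

sumCont :: FoldC Int Int
sumCont = L (R 6 4 T) 6 (N E 7 E)
```

The two values differ only in what `R` stores on the left: a whole subtree for the tree context, a partial result for the fold.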

22 of 45

Why do we care?

  • Transform recursion into iteration
  • Identify derivative behavior for function types
  • Solidify the connection between two ideas that seem to have very similar goals:
    • CPS to expose exact control flow (e.g. argument evaluation order)
    • Derivatives to describe the shapes of iteration
  • Solid foundation for immutable iterators

23 of 45

QuickSort

import Data.List (partition)

qs as = qsa as [] where
  qsa [] ts = ts
  qsa (a:as) ts =
    case partition (< a) as of
      (ls, rs) -> qsa ls (a : qsa rs ts)
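
A runnable version with a type signature added. `qsa xs ts` sorts `xs` and prepends the result to the accumulator `ts`, which avoids `++`; elements equal to the pivot are not `< a`, so they land in `rs` and duplicates are kept.

```haskell
import Data.List (partition)

-- qsa xs ts: sort xs and prepend the result to ts.
qs :: Ord a => [a] -> [a]
qs as = qsa as []
  where
    qsa [] ts = ts
    qsa (a:as) ts =
      case partition (< a) as of
        (ls, rs) -> qsa ls (a : qsa rs ts)
```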

24 of 45

QuickSort: CPS

import Data.List (partition)

qs as = qsa as [] (\t' -> t') where
  qsa [] ts k = k ts
  qsa (a:as) ts k =
    case partition (< a) as of
      (ls, rs) ->
        qsa rs ts (\ts' -> qsa ls (a:ts') k)

25 of 45

QuickSort: Defunctionalized

import Data.List (partition)

qs as = qsa as [] Empty where
  qsa [] ts c = apply c ts
  qsa (a:as) ts c =
    case partition (< a) as of
      (ls, rs) ->
        qsa rs ts (NodeLike ls a c)
  apply Empty t' = t'
  apply (NodeLike ls a c) ts' = qsa ls (a:ts') c

data SetLike a =
  Empty | NodeLike [a] a (SetLike a)
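
The defunctionalized QuickSort in runnable form, with a type signature added. In this sketch each `NodeLike ls a c` records a pivot `a` whose smaller elements `ls` are still unsorted; one plausible reading of the name `SetLike` is that the chain of `NodeLike` values mirrors the right spine of the implicit tree of pivots.

```haskell
import Data.List (partition)

-- Defunctionalized continuation: pivots whose smaller halves are pending.
data SetLike a = Empty | NodeLike [a] a (SetLike a)
  deriving Show

qs :: Ord a => [a] -> [a]
qs as = qsa as [] Empty
  where
    -- Sort the larger elements first, remembering (ls, a) for later.
    qsa [] ts c = apply c ts
    qsa (a:as) ts c =
      case partition (< a) as of
        (ls, rs) -> qsa rs ts (NodeLike ls a c)
    -- Resume: put the pivot in front and sort its smaller elements.
    apply Empty t' = t'
    apply (NodeLike ls a c) ts' = qsa ls (a:ts') c
```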

26 of 45

Set Union

union a b = rec a b Nothing Nothing where
  rec (N l v r) b (Just lb) ub
    | v <= lb = rec r b (Just lb) ub
  rec (N l v r) b lb (Just ub)
    | ub <= v = rec l b lb (Just ub)
  rec (N l v r) b lb ub =
    N (rec b l lb (Just v)) v
      (rec b r (Just v) ub)
  rec E E lb ub = E
  rec E b lb ub = rec b E lb ub
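
A runnable sketch of `union`, with a signature and two hypothetical testing helpers (`toList`, `fromList`). Reading `rec t1 t2 lb ub` as "the union of `t1` and `t2`, restricted to the open interval (lb, ub)" explains the clauses: the first two discard a root that falls outside the interval, and the third emits the root and swaps the arguments so both trees get consumed.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- rec t1 t2 lb ub: union of t1 and t2 restricted to (lb, ub);
-- Nothing means "unbounded on that side".
union :: Ord v => Set v -> Set v -> Set v
union a b = rec a b Nothing Nothing
  where
    rec (N l v r) b (Just lb) ub
      | v <= lb = rec r b (Just lb) ub
    rec (N l v r) b lb (Just ub)
      | ub <= v = rec l b lb (Just ub)
    rec (N l v r) b lb ub =
      N (rec b l lb (Just v)) v
        (rec b r (Just v) ub)
    rec E E lb ub = E
    rec E b lb ub = rec b E lb ub

-- Hypothetical helpers for testing.
toList :: Set v -> [v]
toList E = []
toList (N l v r) = toList l ++ [v] ++ toList r

fromList :: Ord v => [v] -> Set v
fromList = foldl (flip ins) E
  where
    ins v E = N E v E
    ins v t@(N l n r)
      | v < n = N (ins v l) n r
      | n < v = N l n (ins v r)
      | otherwise = t
```

Because both subranges are open at `v`, a copy of `v` in the other tree is excluded from both recursive calls, so the result contains no duplicates.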

27 of 45

Set Union: CPS

union a b = rec a b Nothing Nothing (\s -> s) where
  rec (N l v r) b (Just lb) ub k
    | v <= lb = rec r b (Just lb) ub k
  rec (N l v r) b lb (Just ub) k
    | ub <= v = rec l b lb (Just ub) k
  rec (N l v r) b lb ub k =
    rec b l lb (Just v) (\l' ->
      rec b r (Just v) ub (\r' ->
        k (N l' v r')))
  rec E E lb ub k = k E
  rec E b lb ub k = rec b E lb ub k

28 of 45

Set Union: Defunctionalized

union a b = rec a b Nothing Nothing Id where
  rec (N l v r) b (Just lb) ub c
    | v <= lb = rec r b (Just lb) ub c
  rec (N l v r) b lb (Just ub) c
    | ub <= v = rec l b lb (Just ub) c
  rec (N l v r) b lb ub c =
    rec b l lb (Just v) (L c v (r, b, lb, ub))
  rec E E lb ub c = apply c E
  rec E b lb ub c = rec b E lb ub c
  apply Id s = s
  apply (L c v (r, b, lb, ub)) l' = rec b r (Just v) ub (R l' v c)
  apply (R l' v c) r' = apply c (N l' v r')
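
The defunctionalized `union` in runnable form. The slide leaves the continuation type implicit; the declaration `UnionC` below is an assumption reconstructed from how `Id`, `L`, and `R` are used, and `toList` is a hypothetical testing helper.

```haskell
data Set v = E | N (Set v) v (Set v)
  deriving (Eq, Show)

-- Hypothetical continuation type. L: the right-hand pieces (r, b) are still
-- to be merged; R: the left result is finished and waits for the right one.
data UnionC v
  = Id
  | L (UnionC v) v (Set v, Set v, Maybe v, Maybe v)
  | R (Set v) v (UnionC v)

union :: Ord v => Set v -> Set v -> Set v
union a b = rec a b Nothing Nothing Id
  where
    rec (N l v r) b (Just lb) ub c
      | v <= lb = rec r b (Just lb) ub c
    rec (N l v r) b lb (Just ub) c
      | ub <= v = rec l b lb (Just ub) c
    rec (N l v r) b lb ub c =
      rec b l lb (Just v) (L c v (r, b, lb, ub))
    rec E E lb ub c = apply c E
    rec E b lb ub c = rec b E lb ub c
    apply Id s = s
    apply (L c v (r, b, lb, ub)) l' = rec b r (Just v) ub (R l' v c)
    apply (R l' v c) r' = apply c (N l' v r')

toList :: Set v -> [v]
toList E = []
toList (N l v r) = toList l ++ [v] ++ toList r
```

The `lb` stored in `L` is never read by `apply`, since the right-hand call uses `Just v` as its lower bound; it is kept here only to match the slide.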

29 of 45

Current Work

  • Recursion on multiple data structures
    • Set Union
    • Sorting
  • Generalize to arbitrary functions
    • Theory (probably wrong): the derivative of a function is the type of its continuation

30 of 45

Things to look at

Brent Yorgey on combinatorial species

Work on reifying recursion structure for sort in Coq.

31 of 45

MergeSort

ms as = msl as (length as) where
  msl as n | n <= 1 = as
  msl as n = case splitAt half as of
      (xs, ys) -> merge (msl xs half) (msl ys (n - half))
    where half = n `div` 2
  merge (x:xs) (y:ys) | x <= y = x : merge xs (y:ys)
                      | otherwise = y : merge (x:xs) ys
  merge [] ys = ys
  merge xs [] = xs
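
A runnable sketch of this MergeSort with type signatures added and `merge` lifted to the top level (an arbitrary choice; it captures nothing from `ms`). Passing the length `n` down avoids recomputing it at every split.

```haskell
-- msl xs n sorts xs, where n is the length of xs.
ms :: Ord a => [a] -> [a]
ms as = msl as (length as)
  where
    msl as n | n <= 1 = as
    msl as n = case splitAt half as of
        (xs, ys) -> merge (msl xs half) (msl ys (n - half))
      where half = n `div` 2

-- Taking from the left list on ties keeps the sort stable.
merge :: Ord a => [a] -> [a] -> [a]
merge (x:xs) (y:ys)
  | x <= y = x : merge xs (y:ys)
  | otherwise = y : merge (x:xs) ys
merge [] ys = ys
merge xs [] = xs
```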

32 of 45

MergeSort: Independent Local CPS

ms as = msl as (length as) (\rs -> rs) where
  msl as n k | n <= 1 = k as
  msl as n k = case splitAt half as of
      (xs, ys) -> msl xs half (\xs' ->
                    msl ys (n - half) (\ys' ->
                      merge xs' ys' (\rs -> k rs)))
    where half = n `div` 2
  merge (x:xs) (y:ys) k | x <= y = merge xs (y:ys) (\rs -> k (x:rs))
                        | otherwise = merge (x:xs) ys (\rs -> k (y:rs))
  merge [] ys k = k ys
  merge xs [] k = k xs

33 of 45

MergeSort: Defunctionalization

ms as = msl as (length as) Id where
  msl as n c | n <= 1 = applyS c as
  msl as n c = case splitAt half as of
      (xs, ys) -> msl xs half (S c (n-half) ys)
    where half = n `div` 2
  merge (x:xs) (y:ys) c | x <= y = merge xs (y:ys) (M x c)
                        | otherwise = merge (x:xs) ys (M y c)
  merge [] ys c = applyM c ys
  merge xs [] c = applyM c xs
  applyS Id rs = rs
  applyS (S c n' ys) xs' = msl ys n' (ML c xs')
  applyS (ML c xs') ys' = merge xs' ys' (I c)
  applyM (I c) rs = applyS c rs
  applyM (M x c) rs = applyM c (x:rs)
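
A runnable reading of this two-continuation version. The sketch gives the continuation types hypothetical names, `SortC` and `MergeC`, and relies on three readings of the slide: the base case of `msl` hands its result to `applyS`, the second base case of `merge` hands its result to `applyM`, and the sort constructor that waits for the right half is called `ML` so that it does not clash with the merge constructor `M`.

```haskell
-- Sort continuations: Id (done), S (left half sorted: now sort ys of
-- length n), ML (right half sorted: merge it with the saved xs').
data SortC a = Id | S (SortC a) Int [a] | ML (SortC a) [a]

-- Merge continuations: M x c conses x onto the merged rest;
-- I c hands the finished merge back to the sort continuation c.
data MergeC a = I (SortC a) | M a (MergeC a)

ms :: Ord a => [a] -> [a]
ms as = msl as (length as) Id
  where
    msl as n c | n <= 1 = applyS c as
    msl as n c = case splitAt half as of
        (xs, ys) -> msl xs half (S c (n - half) ys)
      where half = n `div` 2
    merge (x:xs) (y:ys) c
      | x <= y = merge xs (y:ys) (M x c)
      | otherwise = merge (x:xs) ys (M y c)
    merge [] ys c = applyM c ys
    merge xs [] c = applyM c xs
    applyS Id rs = rs
    applyS (S c n' ys) xs' = msl ys n' (ML c xs')
    applyS (ML c xs') ys' = merge xs' ys' (I c)
    applyM (I c) rs = applyS c rs
    applyM (M x c) rs = applyM c (x:rs)
```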

34 of 45

InsertionSort (Recursive, stable)

is [] = []
is (a:as) = ins (is as) where
  ins [] = [a]
  ins (b:bs)
    | a <= b = a : b : bs
    | otherwise = b : ins bs
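
The slide calls this sort stable. A quick way to test that is to sort records that compare on a key only. A sketch, assuming a hypothetical `Rec` type whose `Eq`/`Ord` instances ignore the payload: with `a <= b` as the insertion test, records with equal keys keep their input order, whereas `a < b` would emit them in reverse input order.

```haskell
-- Hypothetical record type: compared on the key, payload rides along.
data Rec = Rec Int Char
  deriving Show

instance Eq Rec where
  Rec k _ == Rec k' _ = k == k'

instance Ord Rec where
  compare (Rec k _) (Rec k' _) = compare k k'

-- a is inserted into the sorted rest, ahead of any equal elements,
-- all of which came after a in the input.
is :: Ord a => [a] -> [a]
is [] = []
is (a:as) = ins (is as)
  where
    ins [] = [a]
    ins (b:bs)
      | a <= b = a : b : bs
      | otherwise = b : ins bs

-- Hypothetical helper: read back the payloads in output order.
payloads :: [Rec] -> String
payloads rs = [p | Rec _ p <- rs]
```

Sorting `[Rec 2 'a', Rec 1 'b', Rec 2 'c', Rec 1 'd']` yields the payloads `"bdac"`: both key-1 records, then both key-2 records, each pair in input order.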

35 of 45

InsertionSort: CPS

is as = isK as (\t' -> t')

isK [] k = k []
isK (a:as) k = isK as (\bs -> k (ins bs)) where
  ins [] = [a]
  ins (b:bs)
    | a <= b = a : b : bs
    | otherwise = b : ins bs

36 of 45

InsertionSort: Defunctionalized

is as = isC as []

isC [] c = apply c []
isC (a:as) c = isC as (a:c)

apply [] t' = t'
apply (a:c) bs = apply c (ins bs) where
  ins [] = [a]
  ins (b:bs)
    | a <= b = a : b : bs
    | otherwise = b : ins bs

37 of 45

InsertionSort: Renamed

is as = isC as []

isC [] c = isRev c []
isC (a:as) c = isC as (a:c)

isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs) where
  ins [] = [a]
  ins (b:bs)
    | a <= b = a : b : bs
    | otherwise = b : ins bs

38 of 45

InsertionSort: Refactored

is as = isRev (reverse as) []

isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs) where
  ins [] = [a]
  ins (b:bs)
    | a <= b = a : b : bs
    | otherwise = b : ins bs

39 of 45

InsertionSort: CPS-transform insert

is as = isRev (reverse as) []

isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs (\t' -> t'))
  where
    ins [] k = k [a]
    ins (b:bs) k
      | a <= b = k (a : b : bs)
      | otherwise = ins bs (\bs' -> k (b:bs'))

40 of 45

InsertionSort: Defunctionalize

is as = isRev (reverse as) []

isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs []) where
  ins [] c = apply c [a]
  ins (b:bs) c
    | a <= b = apply c (a : b : bs)
    | otherwise = ins bs (b:c)

apply [] bs = bs
apply (b:c) bs' = apply c (b:bs')

41 of 45

InsertionSort: Rename

is as = isRev (reverse as) []

isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs []) where
  ins [] c = revApp c [a]
  ins (b:bs) c
    | a <= b = revApp c (a : b : bs)
    | otherwise = ins bs (b:c)

revApp [] bs = bs
revApp (b:c) bs' = revApp c (b:bs')
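
The final insertion sort in runnable form, with signatures added and `<=` as the comparison so equal elements keep their input order. `revApp c bs` computes `reverse c ++ bs`, which is exactly what the defunctionalized list continuation was doing: `ins` walks past the elements smaller than `a`, stacking them in `c`, then puts them back in front.

```haskell
-- Reverse the input, then insert each element into the accumulator bs.
is :: Ord a => [a] -> [a]
is as = isRev (reverse as) []

isRev :: Ord a => [a] -> [a] -> [a]
isRev [] bs = bs
isRev (a:as) bs = isRev as (ins bs [])
  where
    -- c holds the elements already passed over (all smaller than a), reversed.
    ins [] c = revApp c [a]
    ins (b:bs) c
      | a <= b = revApp c (a : b : bs)
      | otherwise = ins bs (b:c)

-- revApp c bs == reverse c ++ bs
revApp :: [a] -> [a] -> [a]
revApp [] bs = bs
revApp (b:c) bs' = revApp c (b:bs')
```

Every call in this version is a tail call; the only state is the two accumulating lists.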

42 of 45

MergeSort

ms as = msl as (length as) where
  msl [] n = []
  msl [a] n = [a]
  msl as n =
    case splitAt half as of
      (xs, ys) ->
        merge (msl xs half) (msl ys (n - half))
    where half = n `div` 2
  merge (x:xs) (y:ys)
    | x <= y = x : merge xs (y:ys)
    | otherwise = y : merge (x:xs) ys
  merge [] ys = ys
  merge xs [] = xs

43 of 45

MergeSort: CPS

ms as = msl as (length as) (\rs -> rs) where
  msl [] n k = k []
  msl [a] n k = k [a]
  msl as n k =
    case splitAt half as of
      (xs, ys) ->
        msl xs half (\xs' ->
          msl ys (n - half) (\ys' ->
            merge xs' ys' k))
    where half = n `div` 2
  merge (x:xs) (y:ys) k
    | x <= y = merge xs (y:ys) (\rs -> k (x:rs))
    | otherwise = merge (x:xs) ys (\rs -> k (y:rs))
  merge [] ys k = k ys
  merge xs [] k = k xs

44 of 45

MergeSort: Defunctionalization

ms as = msl as (length as) Id where
  msl [] n c = apply c []
  msl [a] n c = apply c [a]
  msl as n c =
    case splitAt half as of
      (xs, ys) ->
        msl xs half (S c (n-half) ys)
    where half = n `div` 2
  merge (x:xs) (y:ys) c
    | x <= y = merge xs (y:ys) (MR x c)
    | otherwise = merge (x:xs) ys (MR y c)
  merge [] ys c = apply c ys
  merge xs [] c = apply c xs
  apply Id rs = rs
  apply (S c n' ys) xs' = msl ys n' (ML c xs')
  apply (ML c xs') ys' = merge xs' ys' c
  apply (MR x c) rs = apply c (x:rs)
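
A runnable sketch of this single-continuation version, using the `MSC` declaration from the final slide. It assumes a singleton clause `msl [a] n c = apply c [a]`: with only the `[]` base case, a one-element list splits into `[]` and itself and recurses forever.

```haskell
-- c: element type consed by MR; j: type of the saved left half in ML.
data MSC c j v
  = Id
  | S (MSC c j v) Int [v]
  | ML (MSC c j v) j
  | MR c (MSC c j v)

-- apply below has type MSCont v -> [v] -> [v].
type MSCont v = MSC v [v] v

ms :: Ord v => [v] -> [v]
ms as = msl as (length as) Id
  where
    msl [] n c = apply c []
    msl [a] n c = apply c [a]   -- assumed singleton base case
    msl as n c =
      case splitAt half as of
        (xs, ys) -> msl xs half (S c (n - half) ys)
      where half = n `div` 2
    merge (x:xs) (y:ys) c
      | x <= y = merge xs (y:ys) (MR x c)
      | otherwise = merge (x:xs) ys (MR y c)
    merge [] ys c = apply c ys
    merge xs [] c = apply c xs
    apply Id rs = rs
    apply (S c n' ys) xs' = msl ys n' (ML c xs')
    apply (ML c xs') ys' = merge xs' ys' c
    apply (MR x c) rs = apply c (x:rs)
```

Unlike the earlier two-type version, merge continuations (`MR`) and sort continuations (`S`, `ML`) share one stack here, so a finished merge simply pops the next frame with `apply`.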

45 of 45

MergeSort: Taking the Integral?

data MSC c j v
  = Id
  | S (MSC c j v) Int [v]
  | ML (MSC c j v) j
  | MR c (MSC c j v)

data MSTree v
  = Split (MSTree v) Int [v]
  | Merge (MSTree v) (MSTree v)
  | Leaf [v]

type MSCont v = MSC v [v] v