title | date
---|---
A Gentle Run-through of Continuation Passing Style and Its Use Cases | 2020-01-02
This is yet another CPS introduction post, with a focus on the use cases of CPS. This is where I find some other CPS articles to be lacking: for example, when I first read the CPS entry on Wikibooks, I understood what CPS is, but it wasn't quite clear to me in what different ways CPS can profitably be used in practice. Besides, CPS is a fairly obscure, convoluted and counter-intuitive thing, so I reckon another post that explains it is always beneficial.
I'm aiming to write this post in a way that would have helped me understand CPS better when I first learned it. It can be considered a supplement to the above Wikibooks entry.
I'll first go over CPS basics, followed by discussing some use cases of CPS, and only in the end will I briefly touch upon the `Cont` monad and `callCC`. I personally find this flow to be more natural and digestible, since it keeps things more concrete and "bare metal" for the most part of the post, without the extra layer of abstraction `Cont` and `callCC` introduce.
Besides the CPS entry on Wikibooks, some examples and ideas in this post come from two excellent books: Design Concepts in Programming Languages (DCPL) by Franklyn Turbak and David Gifford with Mark Sheldon, and Compiling with Continuations (CwC) by Andrew W. Appel.
Normally a function takes some input, returns some value, and then the caller of the function continues to do something with that returned value. In CPS, instead of returning some value, the function takes a continuation, which represents what the caller would do with the returned value, as a first-class parameter, and calls the continuation at the point where a normal function would return the result.
So, a function of type `a -> b` would become `a -> (b -> r) -> r` in CPS, where `b -> r` is the continuation. These two types are isomorphic due to the isomorphism between `b` and `forall r. (b -> r) -> r`:
toCPS :: a -> (forall r. (a -> r) -> r)
toCPS = flip ($)
fromCPS :: (forall r. (a -> r) -> r) -> a
fromCPS = ($ id)
For a simple example, take a look at the `pythagoras` and `pythagoras_cps` functions in the Wikibooks entry.
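For convenience, here is roughly what those functions look like (adapted from the Wikibooks entry; the code there may differ in small details):
add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

-- Direct style
pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

-- CPS: each function takes a continuation instead of returning a value
add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (add x y)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (square x)

pythagoras_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_cps x y = \k ->
  square_cps x $ \x_squared ->
  square_cps y $ \y_squared ->
  add_cps x_squared y_squared k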
This is pretty much all there is to say about what CPS is. At this point, a pressing question would be: why is CPS useful? After all, it makes the types longer, makes the code more obscure, and doesn't seem to be able to do anything regular functions can't do.
To answer this question, recognize that CPS has the following characteristics:
1. Every function call is a tail call. You don't ever call a function, then do something with the returned value. Instead, that "something" is passed to the function as the continuation. Functions don't return, but pass their results to the continuations. Therefore there are no non-tail calls.
2. CPS makes a number of concepts more explicit, including return address, call stack, evaluation order, intermediate results, etc.
- When a function is called, a return address needs to be made available, which is where the callee should return to when it is done. For regular functions, return addresses are transparently handled by the subroutine mechanism: a return address is pushed onto the call stack just before entering a function. This is not something programmers implementing the functions need to care about. In CPS, continuations can be viewed as explicit return addresses: a function "returns" by calling its continuation.
- It is similar for the call stack. Take a look at the `pythagoras_cps` example in the Wikibooks entry. The nested continuations can be regarded as representing the call stack. Each time a function is called, a new continuation is created, which corresponds to the new stack frame for the function call.
- Evaluation order and intermediate results also become explicit. In the `pythagoras` example in the Wikibooks entry, it is unclear which is evaluated first, `square x` or `square y`. Indeed, the compiler is free to do either. And the results of `square x` and `square y` are not named. In `pythagoras_cps`, the evaluation order and the names of intermediate results are both explicitly spelled out.
3. You have more power when you have continuations as first-class functions. When implementing a regular function, you can only return the result at the end of the function. But when you have continuations explicitly passed to you as first-class functions, you can get more creative: for instance, calling the continuation of an enclosing function, passing continuations to other functions, or storing them in some data structure and using them at a later time.
Because of these characteristics, there are a number of use cases for which CPS is a good fit. Next I'll discuss some of these use cases in more detail.
Since all function calls in CPS are tail calls, one would naturally wonder if CPS can be used to convert a stack unsafe function into a stack safe one, and it can. Here's a plain-vanilla, stack unsafe factorial function, and the corresponding CPS version:
-- Non-CPS, stack unsafe
fac :: Integer -> Integer
fac 0 = 1
fac n = n * fac (n-1)
-- CPS
facCPS :: Integer -> (Integer -> r) -> r
facCPS 0 k = k 1
facCPS n k = facCPS (n-1) $ k . (*n)
Two things about `facCPS` are worth pointing out:
- I didn't bother to convert the `-` and `*` operators into CPS. To do so, you'd need to create functions `subCPS, multCPS :: Integer -> Integer -> (Integer -> r) -> r`, similar to the `add_cps` function in the `pythagoras` example in the Wikibooks entry. I'm simply treating `-` and `*` as primitive operators, rather than regular functions, because to implement `subCPS` and `multCPS`, you'd still need to use `-` and `*` as primitive operators, which kinda defeats the purpose.
- This `facCPS` implementation is in fact still not stack safe, because of Haskell's laziness. The `(*n)` will keep building up a large thunk before evaluating it, and evaluating the thunk may cause a stack overflow for large `n`. To make it stack safe we can replace `k . (*n)` with ``\x -> x `seq` k (x * n)``. But this is a separate issue, so for simplicity I'm ignoring this laziness problem and just pretending that `k . (*n)` is fine.
Now `facCPS id` is the factorial function, and it is stack safe as long as the programming language we use supports tail-call optimization (TCO), which Haskell does.
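For example, here are a couple of quick sanity checks (my own examples, not from the original post):
facCheck1 :: Integer
facCheck1 = facCPS 5 id    -- 120

facCheck2 :: String
facCheck2 = facCPS 5 show  -- "120"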
It is still stack unsafe in a non-TCO language. For example, `facCPS`, when translated into Scala, is not stack safe, because Scala only optimizes simple self-recursion and doesn't have TCO in general. In other words, Scala is not a properly tail-recursive language. In particular, function composition in Scala is not stack safe: if you compose the identity function with itself a million times, the result is a function that causes a stack overflow for all inputs. In Scala your options are trampolining the recursive function, or implementing it using tail recursion or iteration.
It is also worth pointing out that, even though `facCPS` is stack safe in Haskell, it is not the best way to implement the factorial function. The following iterative implementation is much more efficient:
facIterative :: Integer -> Integer
facIterative = go 1
where
go !acc 0 = acc
go !acc n = go (acc*n) (n-1)
(This is called an iterative implementation because it is a simple self-recursive implementation: it works very much like a `while` loop in an imperative language, and can easily be translated into one.)
`facIterative` is more efficient mainly because the implementation is simple enough that GHC, with `-O2` enabled, is able to compile it into a simple loop. `facIterative` and `facCPS id` have similar performance without `-O2`, but with `-O2`, the former is much faster and allocates much less memory.
Non-local exit refers to a "jump" in which a nested function doesn't return to its immediate caller, but returns to some outer level of control. This can be achieved using `goto` in many programming languages. `goto` is generally avoided, but the same effect can be achieved with a combination of `break`, `continue` and `return`. Haskell, of course, has none of these things; instead, CPS is a nice way to implement non-local exits.
As an example, consider the following leaf-valued, non-empty tree data type:
data Tree = Branch Tree Tree | Leaf Int
Suppose we are tasked to implement a function that sums up the leaf values of a tree. This is fairly straightforward:
leafSum :: Tree -> Int
leafSum (Leaf x) = x
leafSum (Branch l r) = leafSum l + leafSum r
The CPS version is:
leafSumCPS :: Tree -> (Int -> r) -> r
leafSumCPS (Leaf x) k = k x
leafSumCPS (Branch l r) k =
leafSumCPS l $ \vl ->
leafSumCPS r $ \vr ->
k (vl + vr)
Now, suppose a (completely arbitrary and useless, admittedly) requirement is added to our task: if any leaf's value is 6, return 1000. Otherwise, still return the sum of all leaf values.
How do we modify our programs to accommodate this new requirement? It's not completely straightforward any more. We have a few options:
The naive approach is to traverse the tree twice: in the first traversal we check whether there's any `Leaf 6`, and if not, we traverse the tree again to compute the leaf sum (a sketch is shown below). This works, but the problem is obvious: the tree is traversed twice.
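A possible sketch of the two-pass version (the helper names `hasLeaf6` and `leafSumTwoPass` are my own):
hasLeaf6 :: Tree -> Bool
hasLeaf6 (Leaf x)     = x == 6
hasLeaf6 (Branch l r) = hasLeaf6 l || hasLeaf6 r

leafSumTwoPass :: Tree -> Int
leafSumTwoPass t = if hasLeaf6 t then 1000 else leafSum t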
Instead of traversing the tree twice, we can fuse the two traversals by making the inner function `go` return a pair. The first component of the pair indicates whether we have found a `Leaf 6`, and the second component is the leaf sum of the current subtree.
leafSumFused :: Tree -> Int
leafSumFused = snd . go False
where
go True _ = (True, 1000)
go False (Leaf 6) = (True, 1000)
go False (Leaf x) = (False, x)
go False (Branch l r) =
let (bl, resl) = go False l
(br, resr) = go False r
in if bl || br
then (True, 1000)
else (False, resl + resr)
This reduces the number of traversals to one, but it's kinda ugly, and is definitely not the most readable and extensible code one could ever wish to see.
Instead of a recursive tree traversal, we can write an iterative one. Unlike the factorial function, where the iterative implementation is straightforward, an iterative tree traversal requires explicitly managing a stack. The implementation looks like this:
leafSumIterative :: Tree -> Int
leafSumIterative tree = go [tree] 0
where
go [] !acc = acc
go (top:rest) !acc = case top of
Branch l r -> go (l:r:rest) acc
Leaf 6 -> 1000
Leaf x -> go rest (acc+x)
This works nicely, but iterative implementations are often less natural, and harder to read and write, compared to recursive implementations. In this particular case, the iterative implementation needs a stack (the first argument to `go`). In the general case, the iterative implementation can be much more complex.
This problem can be solved gracefully using CPS:
leafSumCPS' :: Tree -> (Int -> r) -> r
leafSumCPS' tree k = go tree k
where
go (Leaf 6) _ = k 1000
go (Leaf x) k' = k' x
go (Branch l r) k' =
go l $ \vl ->
go r $ \vr ->
k' (vl + vr)
Here, `k` is the outer continuation, which expects the final result; `k'` is the inner continuation, which expects the current local result. In the `Leaf 6` case, we pass 1000 to the outer continuation `k`, which effectively returns 1000 as the final result. Put another way, `k` here serves a similar purpose to the `return` keyword in many programming languages: calling the final continuation with a value is equivalent to returning that value.
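As a quick sanity check (the example tree below is my own, not part of the original task):
exampleTree :: Tree
exampleTree = Branch (Leaf 1) (Branch (Leaf 6) (Leaf 2))

-- leafSumCPS' exampleTree id                 evaluates to 1000 (short-circuits on Leaf 6)
-- leafSumCPS  (Branch (Leaf 1) (Leaf 2)) id  evaluates to 3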
The takeaway of this example is that with CPS it is much easier to jump to a certain part of the code, as long as we have the continuations we need at our disposal and invoke the right continuation at the right time. The next use case, backtracking, further demonstrates this point.
Backtracking is a class of algorithms for finding solutions to computational problems by traversing a search space. It can be especially useful for NP-complete problems, where an asymptotically more efficient algorithm may not exist. Intuitively, in a tree- or graph-shaped search space, backtracking involves making a guess at a certain node, continuing the search, and upon hitting a wall, going back to that node and taking a different branch.
So when we explore the search space using backtracking, we need to jump to certain nodes at certain times. And for the same reason CPS is a good fit for non-local exits, it is also a good fit for backtracking. By calling the right continuation at the right time, you can jump around the search space however you want.
The following example, stolen from the DCPL book, illustrates CPS-based backtracking. It uses backtracking to solve the SAT problem, which is NP-complete. The following data type is used for boolean formulae:
data BF = Var String | Not BF | And BF BF | Or BF BF
and the function we want to implement has the following type and expected behavior:
solve :: BF -> Maybe (Map String Bool)
test1 = Var "a" `And` (Not (Var "b") `And` Var "c")
test2 = Var "a" `And` Not (Var "a")
test3 = Var "a" `Or` (Not (Var "b") `And` Var "c")
test4 = (Var "a" `Or` (Var "b" `Or` Var "c")) `And` (Not (Var "a") `And` Not (Var "b"))
solve test1 `shouldBe` Just (Map.fromList [("a",True),("b",False),("c",True)])
solve test2 `shouldBe` Nothing
solve test3 `shouldBe` Just (Map.fromList [("a",True)])
solve test4 `shouldBe` Just (Map.fromList [("a",False),("b",False),("c",True)])
Note that the solution for `test3` only contains `("a",True)` since the values of `b` and `c` don't matter.
Here's the implementation of `solve`:
import Data.Map (Map, (!?))
import qualified Data.Map as Map
import Prelude hiding (fail, succ)
sat :: BF
-> Map String Bool
-> (Bool -> Map String Bool -> r -> r)
-> r
-> r
sat bf asst succ fail = case bf of
Var v ->
case asst !? v of
Just b -> succ b asst fail
Nothing ->
let asstT = Map.insert v True asst
asstF = Map.insert v False asst
tryT = succ True asstT tryF
tryF = succ False asstF fail
in tryT
Not bf' ->
let succNot = succ . not
in sat bf' asst succNot fail
And l r ->
let succAnd True asstAnd failAnd = sat r asstAnd succ failAnd
succAnd False asstAnd failAnd = succ False asstAnd failAnd
in sat l asst succAnd fail
Or l r ->
let succOr True asstOr failOr = succ True asstOr failOr
succOr False asstOr failOr = sat r asstOr succ failOr
in sat l asst succOr fail
solve :: BF -> Maybe (Map String Bool)
solve bf =
sat
bf
Map.empty
(\b asst fail -> if b then Just asst else fail)
Nothing
The `Not`, `And` and `Or` cases are fairly routine. The key of the implementation is the `Var` case, where, if the current variable is unassigned, we first try setting it to `True`, then continue with the exploration by calling the `succ` continuation. If it eventually fails, the `tryF` continuation will be called, which essentially goes back to the current node, resets the variable to `False`, and then continues with the same exploration by calling the same `succ` continuation. If it fails again, then it means the variable in question can be neither `True` nor `False`, suggesting that there's no solution to the Boolean formula. The top-level `fail` continuation is then called, causing the entire calculation to return `Nothing`.
There are several continuations in this implementation, which may take some time to sort out. Specifically, `sat` takes a `succ` continuation of type `Bool -> Map String Bool -> r -> r` and a (global) `fail` continuation of type `r` (consider it a simplification of `() -> r`), and the `succ` continuation itself also takes a (local) `fail` continuation of type `r`. It is a good exercise to attempt to fully understand how this implementation works.
Backtracking using a `succ` continuation and a `fail` continuation is the machinery behind the logict library.
Left-associative `<>` is asymptotically expensive for some monoids, most notably the free monoid (i.e., cons-lists), and left-associative `>>=` is asymptotically expensive for some monads, most notably the free monad.
One way to turn the tables and change the associativity of `<>` and `>>=` is by using continuations.
First of all, CPS computations can be composed like this:
chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)
chainCPS s f = s . flip f
By the way, this suggests that `(_ -> r) -> r` is a monad (i.e., the `Cont` monad). By turning `<>` and `>>=` into `chainCPS`, we can change left-associative operations into right-associative ones.
Every monoidal value `a` can be embedded into, and projected from, a CPS computation of type `(() -> a) -> a`:
monoidToCPS :: Monoid a => a -> (() -> a) -> a
monoidToCPS a = (a <>) . ($ ())
monoidFromCPS :: Monoid a => ((() -> a) -> a) -> a
monoidFromCPS cps = cps (const mempty)
Now, given a left-associative `<>`, we can turn each operand into a CPS computation, substitute `chainCPS` for `<>`, and finally project the monoid out of the result. For instance:
-- left-associative <>
sumL = ([1,2,3] <> [4,5,6]) <> [7,8,9]
-- right-associative <>
sumR = monoidFromCPS $
monoidToCPS [1,2,3]
`chainCPS` (\_ -> monoidToCPS [4,5,6])
`chainCPS` (\_ -> monoidToCPS [7,8,9])
If we expand `chainCPS` in the definition of `sumR` and simplify, we get
sumR = [1,2,3] <> ([4,5,6] <> [7,8,9])
which is precisely the right-associative version of `sumL`. This is the mechanism behind DList.
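For reference, here is a minimal sketch of the DList idea (the actual dlist package has more operations and slightly different internals and names):
-- A difference list is the function-composition view of a list.
newtype DList a = DList ([a] -> [a])

toDList :: [a] -> DList a
toDList xs = DList (xs ++)

fromDList :: DList a -> [a]
fromDList (DList f) = f []

instance Semigroup (DList a) where
  -- appending is function composition, which reassociates the underlying (++) to the right
  DList f <> DList g = DList (f . g)

instance Monoid (DList a) where
  mempty = DList id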
In a completely analogous fashion, every monadic value `m a` can be embedded into, and projected from, a CPS computation of type `forall r. (a -> m r) -> m r`:
monadToCPS :: Monad m => m a -> (forall r. (a -> m r) -> m r)
monadToCPS ma = (ma >>=)
monadFromCPS :: Monad m => (forall r. (a -> m r) -> m r) -> m a
monadFromCPS cps = cps pure
Given a left-associative `>>=`, we can turn each operand into a CPS computation, substitute `chainCPS` for `>>=`, and project the monadic value out of the result. For instance:
-- left-associative >>=
resL = [1,2,3] >>= (\x -> [x+1]) >>= (\y -> [y+2])
-- right-associative >>=
resR = monadFromCPS $
monadToCPS [1,2,3]
`chainCPS` (\x -> monadToCPS [x+1])
`chainCPS` (\y -> monadToCPS [y+2])
If we expand `chainCPS` in the definition of `resR` and simplify, we get
resR = [1,2,3] >>= (\x -> [x+1] >>= (\y -> [y+2]))
which is precisely the right-associative version of `resL`. This is the mechanism behind Codensity.
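Similarly, here is a minimal sketch of the Codensity idea (the actual definition lives in the kan-extensions package and differs in details):
{-# LANGUAGE RankNTypes #-}

-- A newtype around the CPS type used above; its Monad instance is chainCPS in disguise.
newtype Codensity m a = Codensity { runCodensity :: forall r. (a -> m r) -> m r }

-- These mirror monadToCPS and monadFromCPS.
liftCodensity :: Monad m => m a -> Codensity m a
liftCodensity ma = Codensity (ma >>=)

lowerCodensity :: Monad m => Codensity m a -> m a
lowerCodensity (Codensity cps) = cps pure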
Using a similar idea, we can express `foldl` in terms of `foldr` and vice versa. Take a look at the implementation and see if you can spot the continuations involved. Hint: `a -> a` and `(() -> a) -> a` are isomorphic types.
The CwC book gives an excellent, detailed explanation of the usage of CPS in the Standard ML of New Jersey compiler. To very briefly summarize why CPS is useful as a compiler IR: since CPS makes a number of things more explicit, such as return addresses, the call stack, intermediate results, as well as every aspect of the control flow, a CPS'ed program is usually much closer to the assembly code that the compiler eventually generates than the original program is. For example, the variables in a CPS'ed program correspond quite closely to the registers of the target machine. This explicitness is often the opposite of what a programmer writing programs wants, but it is highly relevant to the compiler implementation.
CPS computations can be composed, as demonstrated by `chainCPS`. This gives rise to the `Cont` monad. `Cont r a` is simply a newtype wrapper around `(a -> r) -> r`.
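A minimal standalone sketch of what this looks like (the real definition in the transformers package goes through `ContT`, but the idea is the same):
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont s) = Cont $ \k -> s (k . f)

instance Applicative (Cont r) where
  pure a = Cont ($ a)
  Cont sf <*> Cont sa = Cont $ \k -> sf (\f -> sa (k . f))

instance Monad (Cont r) where
  -- this is chainCPS from earlier, under the newtype wrapper
  Cont s >>= f = Cont $ \k -> s (\a -> runCont (f a) k)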
The continuation `a -> r` is "hidden" by `Cont r a`, so how do we manipulate continuations using `Cont r a` like we did in the Non-Local Exit and Backtracking examples? The answer is `callCC`. `callCC` stands for "call with current continuation", and it brings the current continuation, i.e., the continuation that expects the final result of the current computation, into scope.
If we adapt the `leafSumCPS'` function (from the non-local exits example) to use the `Cont` monad and `callCC`, it would look like this:
leafSumCPS'' :: Tree -> Cont r Int
leafSumCPS'' tree = callCC $ \k ->
let
go (Leaf 6) _ = k 1000
go (Leaf x) k' = k' x
go (Branch l r) k' =
go l $ \vl ->
go r $ \vr ->
k' (vl + vr)
in
go tree k
The `go` function is exactly the same as before. We just need to add a `callCC` at the beginning of `leafSumCPS''`. More examples of `callCC` can be found in the Wikibooks entry.
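For example, we can run the adapted version with `runCont` (my own test expressions; `runCont` comes from `Control.Monad.Cont`):
demo1, demo2 :: Int
demo1 = runCont (leafSumCPS'' (Branch (Leaf 1) (Branch (Leaf 6) (Leaf 2)))) id  -- 1000
demo2 = runCont (leafSumCPS'' (Branch (Leaf 1) (Leaf 2))) id                    -- 3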
The implementation of `callCC`, after stripping out the `cont` and `runCont` wrappers, is
callCC :: ((a -> (b -> r) -> r) -> (a -> r) -> r)
-> (a -> r) -> r
callCC f h = f (\a _ -> h a) h
It is a good exercise to try to decipher what the type means and what the implementation does. Hint: note the ignored argument (`_`) in the implementation. Its type is `b -> r`, and it is a continuation that represents the rest of the computation given the result of calling `k` (which can be of any type `b`; it doesn't matter). The entire rest of the computation is ignored, and so we are effectively returning whatever value is passed to `k`.
To make this clearer, think about what happens if we never call the "exit function" `k` versus if we do call `k`. For example, if we never call `k`, which is the first argument to `f`, then it is the same as if `f` ignores its first argument, say `f = const res` where `res :: (a -> r) -> r`. In this case `res` is also what `callCC f` returns, i.e., it returns the same result as if we hadn't used `callCC` in the first place.
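As a tiny illustration of both cases (my own examples, using `runCont` and `callCC` from `Control.Monad.Cont`):
-- Never calling the exit function: callCC has no effect.
noExit :: Int
noExit = runCont (callCC $ \_ -> return 1) id            -- 1

-- Calling the exit function aborts the rest of the computation.
earlyExit :: Int
earlyExit = runCont (callCC $ \k -> k 2 >> return 1) id  -- 2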
This post is a gentle introduction to CPS and some of its use cases. Once again, CPS is obscure, labyrinthine, and counter-intuitive in many ways. Overusing it is hardly ever a compelling idea (some say it's like abusing `goto` statements in many programming languages). It can be highly rewarding to understand CPS, but it's probably advisable to resist the urge to use it except for the killer use cases.
This post is adapted from a presentation I made in one of Formation's Haskell study group sessions. I received helpful feedback on both the presentation and a draft of this post from coworker Ian-Woo Kim.