Continuation与Call/CC

和我同乘时光机。


本文要点

你将收获

  • 了解Continuation的概念。
  • 了解CPS(Continuation Passing Style)的概念与优势。
  • 有机会搭乘Call With Current Continuation这台时光机。

所需基础

  • 熟悉Haskell的Monad用法。

Continuation

什么是Continuation

为了了解何谓Continuation,我们先来考虑这样一个表达式:4 * 3
如果把焦点放在3上,在这里给出两种求值风格:

  • (4*) 3
  • ($3) (4*)

它们的结果都是一样的:12
前者,3被函数(4*)调用。后者,3选择让(4*)调用自己。
这期间微妙的区别,便是CPS(Continuation Passing Style)的来由。后者是典型的CPS,(4*)就是3Continuation

CPS的角度看,($3)要做的事情就是等待。它的类型是:(a -> r) -> r,意味着,它需要一个函数作为参数,才能求出最后的结果。它所等待的函数,类型为a -> r,被称为ContinuationContinuation指定了最终值应该如何被求得。

在实践中,我们可以简单地通过flip ($),将一个值转换成等待应用的形式。另外,可以通过传递id函数作为其参数,把原始值返回。


Continuation的好处

Continuation的内涵可不仅仅是这样简单地调换应用顺序,更重要的是,它带给了我们显式操控和动态选择程序控制流的可能。比如,实现程序的提前返回,异常的传递和处理可以通过Continuation实现 —— 一个Continuation用于处理正常情况,另一个Continuation用于处理异常情况,以及实现简单的并发(Concurrency)。

而且,当我们所有的函数都严格按照CPS来编写,那么所有的函数调用都会是尾调用(Tail Call)的形式。使用尾调用形式可以进行尾调用优化(TCO, Tail Call Optimization),使得程序不再需要运行时栈(Run-time Stack),在现在的许多解释器和编译器中,都能找到这样一种技术。


实例

想要采用Continuation的简单方法是修改所有函数,让函数返回待应用的形式。下面将演示两个实例。

pythagoras

1
2
3
4
5
6
7
8
9
10
-- 普通函数调用风格,计算平方和:
add :: Int -> Int -> Int
add x y = x + y
square :: Int -> Int
square x = x * x
pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

将返回值修改成待应用形式,pythagoras将变成:

1
2
3
4
5
6
7
8
9
10
11
12
13
-- 函数的CPS版本
add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (add x y)
square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (square x)
pythagoras_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_cps x y = \k ->
square_cps x $ \x_squared ->
square_cps y $ \y_squared ->
add_cps x_squared y_squared $ k

pythagoras_cps执行过程:

  • x求平方,把结果丢到\x_squared -> ...这个Continuation里。
  • y求平方,把结果丢到\y_squared -> ...这个Continuation里。
  • x_squaredy_squared加起来,把结果丢到k这个Continuation里。

在GHCi里运行一下,把id作为参数/Continuation传给最终程序:

1
2
λ> pythagoras_cps 3 4 id
25

从这里我们已经可以看出,CPS这种风格最突出的特征是:所有函数调用都要把未来需要做的事情显式地传递给它。函数不止关注自己的值如何求出,还关注求出的值如何被使用。

thrice

1
2
3
4
-- 一个高阶函数,无Continuation版本
thrice :: (a -> a) -> a -> a
thrice f x = f (f (f x))
1
2
λ> thrice tail "foobar"
"bar"

thrice这样的高阶函数,转换成CPS时,它接受的函数参数也要改成相应的CPS版本。因此f :: a -> a将变成f_cps :: a -> ((a -> r) -> r),最终类型变成thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)

1
2
3
4
5
6
7
-- 高阶函数带Continuation版本
thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)
thrice_cps f_cps x = \k ->
f_cps x $ \fx ->
f_cps fx $ \ffx ->
f_cps ffx $ k

Cont Monad

有了这些continuation-passing函数,下一步应该做的是提供一种更简洁的组合它们的方法,而不是像上面那样嵌套那么多lambda
先来尝试写出用于应用(apply)待应用函数的组合子(Combinator)。它的类型应该是这样的:

1
chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)

实现如下:

1
chainCPS s f = \k -> s $ \x -> f x $ k

调用者首先提供一个待应用形式的值s,通过嵌套的lambdaa类型的值拿到,让f应用于它,最后再用新的Continuationk把这一切包裹起来。

没错,相信聪明的读者已经看出来,上面的函数类型和Monad里的(>>=)函数类型相当相似。而且,flip ($)也可以充当return的角色。



一个新的Monad诞生了。

我们可以定义Cont r a来包裹这种待应用形式的值,再实现包装和解包函数:

1
2
cont :: ((a -> r) -> r) -> Cont r a
runCont :: Cont r a -> (a -> r) -> r

Monad的实现上面已经介绍过,不同的只是包装和解包:

1
2
3
instance Monad (Cont r) where
return x = cont ($ x)
s >>= f = cont $ \c -> runCont s $ \x -> runCont (f x) c

Monad实例把Continuation的传递隐藏起来,拥有了Monad绑定的种种好处,把上面的pythagoras重写一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
-- 使用transformers包里面的Cont Monad
import Control.Monad.Trans.Cont
add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)
square_cont :: Int -> Cont r Int
square_cont x = return (square x)
pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y = do
x_squared <- square_cont x
y_squared <- square_cont y
add_cont x_squared y_squared


callCC

实现了可爱的Monad固然是值得高兴的,可是别忘了我们想要CPS的目的:用Continuation精确操控程序控制流。如果把Continuation都隐藏在Monad背后,这种控制将不复存在。为了纠正这个问题,我们引入了callCC函数。

callCC是一个非常神奇的函数,举个例子:

1
2
3
4
5
6
7
-- Without callCC
square :: Int -> Cont r Int
square n = return (n ^ 2)
-- With callCC
squareCCC :: Int -> Cont r Int
squareCCC n = callCC $ \k -> k (n ^ 2)

传递给callCC的是一个函数,这个函数的返回值是一个待应用形式的值(即类型为Cont r a)。重点来了:让callCC如此特殊的原因就在于k,是k让精确控制程序成为可能。

那么k究竟是何方神圣?

事情是这样的:
程序调用callCC的瞬间,在此处就放置了一个传送门、或者说存档点、或者说时光机。业界所说的放置了,我认为不甚准确。k就是这个传送门callCC内部的另一端。一旦内部有任何一个值调用了kk就会无条件地把这个值传递回去,内部此后任何对k的调用都将被无情抛弃。
从另一个角度来看,k也可以说是调用callCC以后的所有计算,一旦callCC内部有值调用了kk就将外面的整个世界传递给这个值,毕竟callCC的全称是call with current continuation
随读者怎么理解,我更倾向于前者。

下面我们就来探索callCC究竟可以有怎样的可能性。


决定何时使用k

callCC允许我们决定让什么值传递回去,以及何时传递。下面举个例子:

1
2
3
4
5
6
7
-- 简单的callCC调用
foo :: Int -> Cont r String
foo x = callCC $ \k -> do
let y = x ^ 2 + 3
when (y > 20) $ k "over twenty"
return (show $ y - 4)

如果y的值大于20k就会带着"over twenty"回到Continuation,否则,直到最终才让show $ y - 4返回到Continuation中。
从这里看,k有点像命令式语言里的return,但实际上k远远强大于return,因为它是头等公民,你可以把k传递给其他函数,比如when,把k存储在Reader里等等。

同样,你可以在do里面嵌入callCC

1
2
3
4
5
6
7
8
bar :: Char -> String -> Cont r Int
bar c s = do
msg <- callCC $ \k -> do
let s0 = c : s
when (s0 == "hello") $ k "They say hello."
let s1 = show s0
return ("They appear to be saying " ++ s1)
return (length msg)

从这里看,k又有点像其他语言里的goto,调用了k就会返回到msg <- ...这一行。

k后面的无用的行:

1
2
3
4
5
quux :: Cont r Int
quux = callCC $ \k -> do
let n = 5
k n
return 25

这个函数会把5返回到Continuation中,return 25是完完全全没有作用的。


舞台背后

如此神奇的callCC究竟如何实现。经过上面的举例,我们可以得出callCC的类型为:

1
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

总体返回应该和参数函数的返回类型一致(即Cont r a),因为在没有调用k的情况下,得出的返回值应该一致。
那么k的类型呢?
正如上面所演示的,k的参数将被传到calCC调用点处。所以,k的参数类型是a
不过k的返回类型就有点意思了,b代表着任意类型。因为上面提过,k意味着跟在callCC以后的任何Continuation,所以k的返回类型是任意的。

注意:

1
2
3
4
5
6
7
-- 有错误的代码
quux :: Cont r Int
quux = callCC $ \k -> do
let n = 5
when True $ k n
k 25

由于whenk n的类型已经被约束为Cont r (),因此后面的k 25quux的返回类型不符合,应该改成return 25

实际实现:

1
callCC f = cont $ \h -> runCont (f (\a -> cont $ \_ -> h a)) h

慢慢看,慢慢理解。


更多实例

复杂控制结构

让我们来看看更现实的控制流操控的例子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
{- We use the continuation monad to perform "escapes" from code blocks.
This function implements a complicated control structure to process
numbers:
Input (n) Output List Shown
========= ====== ==========
0-9 n none
10-199 number of digits in (n/2) digits of (n/2)
200-19999 n digits of (n/2)
20000-1999999 (n/2) backwards none
>= 2000000 sum of digits of (n/2) digits of (n/2)
-}
fun :: Int -> String
fun n = (`runCont` id) $ do
str <- callCC $ \exit1 -> do -- define "exit1"
when (n < 10) (exit1 (show n))
let ns = map digitToInt (show (n `div` 2))
n' <- callCC $ \exit2 -> do -- define "exit2"
when ((length ns) < 3) (exit2 (length ns))
when ((length ns) < 5) (exit2 n)
when ((length ns) < 7) $ do
let ns' = map intToDigit (reverse ns)
exit1 (dropWhile (=='0') ns') --escape 2 levels
return $ sum ns
return $ "(ns = " ++ (show ns) ++ ") " ++ (show n')
return $ "Answer: " ++ str

fun是一个使用ContcallCC建立起来的控制结构,接受不同的整数值n,而落到不同范围,从而做出不同的事情。剖析一下:

  1. (`runCont` id)意味着把最终的计算结果Continuation通过id给拿出来。
  2. 我们将callCC的结果绑定到str中:
    1. 如果n小于10,直接退出,只显示n
    2. 如果不是,继续进行。构造一个list,叫ns,里面转载n `div` 2的数字字符。
    3. 内部callCC,结果绑定到n'
      1. 如果length ns < 3,带着length ns从内部do-block退出。
      2. 如果n `div` 2小于5个字符,带着n从内部do-block退出。
      3. 如果n `div` 2小于7个字符,带着一个String直接退出到外部的callCC
      4. 否则,最终带着n `div` 2sum退出。
    4. 最终带着字符串返回。
  3. 最终带着str返回。

异常

Continuation的一种用途是对异常进行建模。要做到这一点,我们要维护两个Continuation,一个带我们到handler处理函数以处理异常,一个在无异常的情况下带我们到处理后的代码。

这里有一个简单的函数,模拟除零异常:

1
2
3
4
5
6
7
8
9
10
11
divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler = callCC $ \ok -> do
err <- callCC $ \notOk -> do
when (y == 0) $ notOk "Denominator 0"
ok $ x `div` y
handler err
{- For example,
runCont (divExcpt 10 2 error) id --> 5
runCont (divExcpt 10 0 error) id --> *** Exception: Denominator 0
-}

从代码中可以看出,我们嵌套使用了两个callCC,一个是将在没出问题时使用的Continuation,一个是我们希望抛出异常时将使用到的Continuation。如果分母不是0,则okContinuation将直接返回到顶层,否则,err将带着"Denominator 0"交给handler处理。

下面介绍更通用的异常处理方法。第一个参数传递计算方法(准确地说,这个计算会得到一个throw函数,并在定义中决定是否使用)。另一个参数为错误处理程序。

1
2
3
4
5
6
7
8
import Control.Monad.Cont
tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h = callCC $ \ok -> do
err <- callCC $ \notOk -> do
x <- c notOk
ok x
h err

try用到action中:

1
2
3
4
5
6
7
8
9
data SqrtException = LessThanZero deriving (Show, Eq)
sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do
ln <- lift (putStr "Enter a number to sqrt: " >> readLn)
when (ln < 0) (throw LessThanZero)
lift $ print (sqrt ln)
main = runContT (tryCont sqrtIO (lift . print)) return

在这个例子中,throw意味着从封闭的callCC中逃脱(escape)了,这个throw原本在tryCont的内部callCC中。


参考

柯里化的前生今世(八):尾调用与CPS
https://en.wikipedia.org/wiki/Continuation-passing_style#CPS_in_Haskell
https://en.m.wikibooks.org/wiki/Haskell/Continuation_passing_style#callCC