再帰関数のスタックオーバーフローを倒す話 その1

再帰関数のスタックオーバーフローを倒す話を何回かに分けてします。

連載目次

はじめに

継続渡しスタイルもしくは継続渡し形式(Continuation Passing Style、以降CPS)という言葉を聞いたことがあるでしょうか。 今日はCPSの話をします。 前提知識は、F#のみです。

継続とは

CPSの前に、まずは継続の話です。 継続と言っても、継続的インテグレーションとか継続的デリバリとはまったく、これっぽっちも関係ありませんのでそういう話題を期待した人は回れ右。 これらの文脈では継続は「繰り返し」とかそんな風な意味を含んでいますが、今回扱う継続は「続き」とかそんな意味ととらえてください。

続きったって何の続きだ、となるわけですが、ざっくり説明すると、

「プログラムのある瞬間を考えたときに、その瞬間より後に実行される処理」

が継続です。

プログラムのデバッグでブレークポイントを貼って処理をブレークしたときに、「そのあとに実行される処理」ってあるじゃないですか。 あれをプログラミングの対象にしてしまおう、というような話だと思ってください。

let x = f 42 (* ここでブレークして、fから戻ってきた状態(fは実行済み) *)
printfn "%d" x

コメントに書いたような状態だと思ってください。 このときに、継続は

+---------+
| let x = | f 42
|         +------+
| printfn "%d" x |
+----------------+

この枠で囲われた部分です。 =の左側が計算されてからその結果が右側の変数xに設定されるので、let x =の部分も継続に含めています。

さて、これをプログラミングの対象にするにはどうすればいいでしょう?

継続を無名関数で表す

一つの方法として、プログラムを変形して継続を無名関数で表す、というのがあります。 やってみましょう。

f 42 |> (fun x ->
printfn "%d" x)

上のコードとこのコードが同じ動作をすることは分かるでしょうか。

先ほどはletの一部分が継続に含まれていたので、プログラミングの対象に出来なさそうでした。 それに対して、このコードでは先ほどの例と同じ継続は

          +-----------+
  f 42 |> | (fun x -> |
+---------+           |
| printfn "%d" x)     |
+---------------------+

この枠で囲われた部分です(|>演算子は使わないこともできるので含めていません)。 これなら、この関数自体をプログラミングの対象に出来ますね!

(* もはや値としての関数でしかない *)
let cont = (fun x -> printfn "%d" x)

これだけだとありがたみがさっぱりですが、継続は関数として表せる、ということがわかりました。

無名関数でletを代用する

先ほどの変形によって、letが消えたのに気づいたでしょうか? letによって導入していた変数は、継続を表す無名関数の引数に変わりました。 letを無名関数で表すことは後で重要になってくるので、もう少し詳しく見てみましょう。

このようなプログラムがあったとします。

let x = f 42
let y = g x
let z = h y
printfn "%A" (x, y, z)

これをletを使わずに無名関数だけで書いてみます。

f 42 |> (fun x ->
g x |> (fun y ->
h y |> (fun z ->
printfn "%A" (x, y, z) )))

f 42より後ろを表す継続はf 42の戻り値をxとして引数で受け取ります。 そして、g xより後ろを表す継続はg xの戻り値をyとして引数で受け取ります。 h yより以下略。

このように、継続を起動した(継続を表す関数を呼び出した)側の結果は、引数として受け取り、その継続の中で使えます。

継続渡しスタイル(CPS)

fとgとhがそれぞれこのような関数だったとしましょう。

let f x = x / 2    // int -> int (intを受け取ってintを返す関数)
let g x = x + 10   // 同上
let h x = string x // int -> string (intを受け取ってstringを返す関数)

これを元にして、各関数が「自身の処理をした後の継続cont*1」を受け取れるようにしてみます。

let fCont x cont = x / 2 |> cont    // int -> (int -> 'a) -> 'a (intと「intを受け取って何か('a)を返す関数」を受け取って何か('a)を返す関数)
                                    // 元の関数での戻り値は、第二引数で渡される関数の引数になっている
let gCont x cont = x + 10 |> cont   // 同上
let hCont x cont = string x |> cont // int -> (string -> 'a) -> 'a (intと「stringを受け取って何か('a)を返す関数」を受け取って何か('a)を返す関数)
                                    // 元の関数での戻り値は、第二引数で渡される関数の引数になっている

このように、「自身の処理をした後の継続」を受け取る関数のことを、「継続渡しスタイルの関数」と言います。

元の関数では結果はそのまま呼び出し元に返していましたが、このバージョンではcontに結果を渡しています。 contは「自身を処理した後の継続」ですから、それに結果を渡すことによって、contの中で結果が使えるようにするためです。

fContはどうやって使えばいいでしょうか? fであれば、例えばこのように使っていました。

let res = f 10
...

fContはこのように使います。

fCont 10 (fun res ->
...)

「無名関数でletを代用する」で見たような書き方になっていますね。 「無名関数でletを代用する」では、|>演算子を使って順番を入れ替えていましたが、継続渡しスタイルの関数を使う場合は不要です。

このように、継続渡しスタイルの関数を使って継続を渡すプログラミングスタイルが、「継続渡しスタイル(CPS)」です。

fCont 42 (fun x ->
gCont x (fun y ->
hCont y (fun z ->
printfn "%A" (x, y, z) )))

ここからは継続が関数として表せると何が嬉しいかを説明するための準備となることを説明します。

末尾呼び出し

末尾呼び出しというのは、関数を呼び出した後に結果を戻す以外にすることがないような関数呼び出しのことを言います*2。 さて、ではf1, f2, f3の中で末尾呼び出しされている関数はどれでしょうか?

let example x =
  if f1 x then f2 x
          else 10 + f3 x

答えは、f2だけです。

f1が末尾呼び出しじゃないというのは、f1を呼び出した後にthen節かelse節を実行する必要があることから分かります。

f2の後にelseがあるように思えるかもしれませんが、then節とelse節は二者択一であり、then節が選ばれたときにはelse節は実行されません。 then節ではf2を呼び出した後は何もすることなくその結果を戻すだけなので、f2は末尾呼び出しです。

f3の呼び出しは、その結果を使って10と加算するという処理がf3から戻ってきたときに必要です。 そのため、f3は末尾呼び出しではありません。

何が「末尾」になるのかは今回は横道なので深入りはしませんが、別の機会に(F#については)まとめようと思います。

末尾呼び出しの最適化

末尾呼び出しは「関数から戻ってきた後に結果を戻す以外にすることがないような関数呼び出し」でした。 何もすることがないのなら、関数呼び出しじゃなくて、単なるジャンプ命令に置き換えてしまえばスタックを消費しなくなっていいよね! というのが末尾呼び出しの最適化です*3。

これが嬉しいのは、例えば再帰関数が末尾呼び出しになっている場合です*4。 このような再帰を末尾再帰と言ったりします。 末尾呼び出しが最適化されないと、再帰の回数が積み重なるとスタックオーバーフローを起こしてしまいます。 末尾呼び出しが最適化されることで、再帰の回数が積み重なってもスタックオーバーフローが起こらなくなるため、再帰の回数が多くなり得る関数は末尾呼び出しの最適化がかかるように末尾再帰の形に変形することがあります。 式木の変形など、単純に書くと末尾再帰にならない再帰は山のようにあるので、末尾再帰の形に変形する方法は重要です。

あ、一応言っておくと、末尾呼び出しの最適化がかかるかどうかは言語や処理系によって違いますので、 末尾再帰に変形したからと言ってスタックオーバーフローが起きなくなることが保証されるわけではありません。 自分の好きなあの言語、あの処理系、末尾呼び出しの最適化がかかるかどうか調べておくといいでしょう。

CPS変換による末尾再帰関数への変換

さて、話を継続に戻します。 CPSに変形(CPS変換)することで、自動的に末尾再帰の関数が手に入るのです! なぜそうなるのかを見てみましょう。

継続渡しスタイルの関数と、それを使うプログラムです。

let fCont x cont = x / 2 |> cont
let gCont x cont = x + 10 |> cont
let hCont x cont = string x |> cont

let program () =
  fCont 42 (fun x ->
  gCont x (fun y ->
  hCont y (fun z ->
  printfn "%A" (x, y, z) )))

継続渡しスタイルの関数は、継続を末尾呼び出ししているのが一目で分かります。 では、継続渡しスタイルの関数を使っている側はどうでしょうか。 こちらも、それぞれの関数は末尾呼び出しになっています。 インデントを追加するとわかりやすいでしょう。

let program () =
  // fContの呼び出しは、program関数の末尾で行われている
  // gContなどの呼び出しは、関数でくるまれた中にいるその場では呼び出されない
  fCont 42 (fun x ->
    // gContの呼び出しは、fContの継続の末尾で行われている
    // hContなどの呼び出しは、関数でくるまれた中にいるのでその場では呼び出されない
    gCont x (fun y ->
      // hContの呼び出しは、gContの継続の末尾で行われている
      hCont y (fun z ->
        printfn "%A" (x, y, z)
      )
    )
  )

継続渡しスタイルの関数では、関数の最後は「継続を表す関数に結果を渡す」ことになりますし*5、 継続渡しスタイルの関数を呼び出す場合もやはり末尾呼び出しになります。 そのため、再帰部分を継続渡しスタイルで書けば自動的に末尾呼び出しになるのです。

つまり、末尾再帰ではない再帰関数をCPS変換したら末尾再帰関数になり、末尾呼び出しの最適化がかかります。 ようやく、CPS変換のうれしさが分かるところまで来ました。 では、末尾呼び出しになっていない再帰関数をCPS変換してみましょう。

階乗をCPS変換

簡単な例として、階乗からやってみます。 まずは、末尾再帰ではないfactの定義です。

let rec fact = function
| n when n = 0I -> 1I
| n -> n * (fact (n - 1I))

bigintが定数パターンとして使えないのでwhenを使っているのがちょっと残念ですが、それ以外は普通のコードです。 この関数は、再帰呼び出しをした後にその結果とnの値を掛けているため、末尾再帰になっていません。 そのため、この関数に50000Iを渡すとスタックオーバーフローが起きました。

これを、まずは再帰呼び出し部分をletを使った形に書き換えます。 letを使った形にするとCPS変換しやすくなるので、慣れないうちはまずはletを使った形に変形するところから始めるといいでしょう。

let rec fact n =
  if n = 0I then 1I
  else
    let pre = fact (n - 1I)
    n * pre

次に、これをCPSに書き換えます。 まずは、継続を引数contとして受け取るようにします。

(* 変換途中 *)
let rec fact' n cont =
  if n = 0I then 1I
  else
    let pre = fact' (n - 1I)
    n * pre

contは継続なので、fact'の処理の結果を渡してあげることでfact'の後ろの処理を実行します。 こうでしょうか?

(* 変換途中: elseがおかしい *)
let rec fact' n cont =
  if n = 0I then 1I |> cont
  else
    let pre = fact' (n - 1I)
    n * pre |> cont

これはコンパイルが通りません。 fact'は第二引数として継続を受け取るため、preはfact'の結果ではなく関数になってしまっています。 そこで、fact'を呼び出した後の処理(n * pre |> cont)をfact'に渡す無名関数の中に入れてしまいます。

(* 変換完了! *)
let rec fact' n cont =
  if n = 0I then 1I |> cont
  else
    fact' (n - 1I) (fun pre ->
    n * pre |> cont)

letで導入される変数を無名関数の引数として導入する形にするのは、今まで何回か見てきているので大丈夫でしょう。 これで無事、CPS変換できました! しかしこのままでは元の関数と同じ使い方ができません。 「スタックオーバーフローしなくなりましたが、代償として継続を渡す必要ができました!」では駄目でしょう。 そこで、CPSな関数をラップする関数を用意します。

CPS版のfact'をラップする

さて、fact'を外から呼び出す場合、contには何を渡せばいいでしょうか? それを考える前に、fact'のシグネチャを確認してみましょう。

val fact' :
  n:System.Numerics.BigInteger ->
    cont:(System.Numerics.BigInteger -> 'a) -> 'a

System.Numerics.BigIntgerの別名としてbigintがあるので、これを使って書き直すと、

val fact' : n:bigint -> cont:(bigint -> 'a) -> 'a

こうです。 ここから分かるのは、

  1. 継続を表す関数contには、fact'が計算した結果が渡される
  2. 継続を表す関数contは、任意の結果型を返せる
  3. 継続を表す関数contが返した型が、fact'全体の結果型になる

です。 1つ目は、今まで見てきた通りのことです。継続には結果が渡されます。 2つ目と3つ目に注目してください。 今まで、一番外側(一番深い部分)の継続では、printfnによる出力を行っていました。

fact' 5 (fun res ->
printfn "%d" res)

今まで通りならこんな感じです。 これを上の3つに当てはめてみると、

  1. resにはfact'が計算した結果が入っている
  2. printfn "%d" resはfact'が計算した結果を出力して、unitを返す
  3. fact'に渡した継続がunitを返すので、fact'の呼び出し全体としてもunitを返す

となります。 ということは、CPS変換された関数から値を取り出すには、継続に渡された結果をそのまま返せばいいということになります。 これは、継続としてid関数を渡せばいい、ということですね。

let res = fact' 5 id
printfn "%d" res

つまりこれを関数化すれば、factのユーザは中でCPS変換された関数に実装が変わってもなにも気にしなくていいわけです。

let fact n = fact' n id

fact'を外から使わせないようにするために、関数内関数にしてもいいでしょう。

let fact n =
  let rec fact' n cont =
    if n = 0I then 1I |> cont
    else
      fact' (n - 1I) (fun pre ->
      n * pre |> cont)
  fact' n id

これで変換完了です。 実際にこれを試したい人は、プロジェクトのプロパティから「末尾呼び出しの生成」をオンにしてください(Releaseモードであればデフォルトでオンのはずです)。 また、fsiであれば設定不要で試せます。 この関数には、50000Iを渡してもスタックオーバーフローは起こしません。 CPS変換をしたことによって、末尾再帰になり、末尾呼び出しの最適化がかかったようです。

スタックオーバーフローするような再帰を書いてしまったときに、CPS変換を行えばスタックオーバーフローを回避できるようになります。 他にも回避する方法はあります*6が、 CPS変換は慣れてしまえばほとんど機械的に行えるので、自分の道具箱に入れておいてもいいでしょう。

その2はコンピュテーション式の話になる予定です。

おまけ

ここからはおまけです。もしくはボーナスステージ。 色々な関数をCPS変換してみましょう。

sum関数

オリジナル

let rec sum = function
| [] -> 0
| x::xs -> x + (sum xs)

letで書き換え

let rec sum xs =
  match xs with
  | [] -> 0
  | x::xs ->
      let pre = sum xs
      x + pre

CPS!

let rec sum xs cont =
  match xs with
  | [] -> 0 |> cont
  | x::xs ->
      sum xs (fun pre ->
      x + pre |> cont)

あ、id渡すラッパー関数は自明なので書きません。

max関数をCPS変換

オリジナル

let rec max = function
| [x] -> x
| x::xs ->
    let pre = max xs
    if pre < x then x
               else pre

letで書き換え

letで書き換え自体は不要だけど、functionをmatchにしておく。

let rec max xs =
  match xs with
  | [x] -> x
  | x::xs ->
      let pre = max xs
      if pre < x then x
                 else pre

CPS!

let rec max xs cont =
  match xs with
  | [x] -> x |> cont
  | x::xs ->
      max xs (fun pre ->
      if pre < x then x   |> cont
                 else pre |> cont)

find関数をCPS変換

オリジナル

let rec find pred = function
| [] -> failwith "not found."
| x::xs -> if pred x then x
                     else find pred xs

letで書き換え

let rec find pred xs =
  match xs with
  | [] -> failwith "not found."
  | x::xs ->
      if pred x then x
                else
                  let res = find pred xs
                  res

CPS!

let rec find pred xs cont =
  match xs with
  | [] -> failwith "not found."
  | x::xs ->
      if pred x then x |> cont
                else find pred xs cont (* (fun res -> res |> cont)なので、単にcontを渡せばいい *)

map関数をCPS変換

オリジナル

let rec map f = function
| [] -> []
| x::xs -> (f x) :: (map f xs)

letで書き換え

let rec map f xs =
  match xs with
  | [] -> []
  | x::xs ->
      let y = f x
      let ys = map f xs
      y::ys

CPS!

let rec map f xs cont =
  match xs with
  | [] -> [] |> cont
  | x::xs ->
      let y = f x
      map f xs (fun ys ->
      y::ys |> cont)

これは、map自体のCPS変換です。 fがCPS変換された関数の場合は、

let rec map f xs cont =
  match xs with
  | [] -> [] |> cont
  | x::xs ->
      f x (fun y ->
      map f xs (fun ys ->
      y::ys |> cont))

こうですね。

フィボナッチ関数をCPS変換

オリジナル

let rec fib = function
| 0 | 1 -> 1
| n -> fib (n - 1) + fib (n - 2)

letで書き換え

let rec fib n =
  match n with
  | 0 | 1 -> 1
  | n ->
      let pre1 = fib (n - 1)
      let pre2 = fib (n - 2)
      pre1 + pre2

CPS!

let rec fib n cont =
  match n with
  | 0 | 1 -> 1 |> cont
  | n ->
      fib (n - 1) (fun pre1 ->
      fib (n - 2) (fun pre2 ->
      pre1 + pre2 |> cont))

*1:contはcontinuationの略です。継続を表す変数名には他にもkなどが使われたりします。

*2:再帰関数のことを扱う場合が多いですが、再帰関数でなくとも末尾呼び出しと言えます。

*3:自分自身のスタックを再利用したり、ループに変形したりというやり方もありますが、どの方法でもスタックを消費しないという効果は同じです。

*4:他にも、Chain of Responsibilityパターンを適用した際に大量のオブジェクトがchainを構成する場合など、再帰しない場合でもうれしい場面はあります。

*5:例外を投げるとか、継続を捨てるとかは無視します。

*6:アキュムレータ変数を使う方法や、ループに書き換える方法などが使えます。