問題を解決するつもりでキッチリ型を付けた先にある高い壁


過ぎたるは猶及ばざるがごとし.

最近null安全?だかの話のからみで,(静的な)型で契約云々を表現してシアワセになれるんだぜーと言うのをチラホラ見聞きする.たとえば,pythonで統計なり機械学習なりやっててnumpy弄るような人が,ndarray(多次元配列)のshape(多次元配列の形)が合わずエラーで落ちたりとかそういうアレについて云々という.なるほど型があれば実行前に止めることができ,実行時,エラー*1になってファーみたいなことは避けられるだろう.

しかし,これが天国へ続く道かどうかはまた別の話.(依存)型で舗装しているのは地獄への道かもしれないのだ.冒頭ツイートの通り陰腹召してでも諫めておかねば,その希望,容易に絶望に反転し得る.

では実際に見ていこう.内容としては,numpy ndarrayとその上での操作をいくつか取り上げ,配列要素だけでなくshapeまで型レベルで扱うことで,間違えたshapeを持つ操作を禁止できるようにしてみるというもの.サンプルコードはコチラ.全体の流れとしてはndarrayを定義,reshape,dot,transposeを実装,それらを使ってtensordotを実装してみるとどうなるかをみていく.

前回の記事に続き,今回も型レベルでアレコレするのでアタマのほうはこんな感じ.singletonsも使っていく.

{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-}

import Data.Singletons
import Data.Singletons.Prelude.Enum
import Data.Singletons.Prelude.List
import Data.Singletons.Prelude.Num
import GHC.TypeLits

shapeと中身の型aを型レベルに持つNDArray型を定義する.今回は中身の計算には興味が無いので,shapeに関する情報だけ抱えておくことにする.

data NDArray (shape :: [Nat]) a = NDArray (Sing shape) -- 中身はshape意外省略

まずは,numpy.reshapeを実装してみる.

reshape :: Product from ~ Product to => Sing to -> NDArray from a -> NDArray to a
reshape = const . NDArray

要素数が同じでshapeの違う別のndarrayに変換する操作だが,別段不可思議なところは無い.コンテキスト部分により要素数が同じという条件を付けており,以下のように確かに条件に合わない場合は型検査で失敗する.

-- reshape可
reshapeable :: NDArray '[2,3,4] a -> NDArray '[3,8] a
reshapeable = reshape sing

-- 型検査でreshape不可 (次元の積が合わない)
-- unreshapeable :: NDArray '[2,3,4] a -> NDArray '[3,3,4] a
-- unreshapeable = reshape sing -- Couldn't match type ‘24’ with ‘36’

次に,numpy.dotを実装する.

dot :: (Last xs ~ Last (Init ys), Num a) =>
       NDArray xs a -> NDArray ys a -> NDArray (Init xs :++ Init (Init ys) :++ '[Last ys]) a
NDArray xs `dot` NDArray ys = NDArray (sInit xs %:++ sInit (sInit ys) %:++ SCons (sLast ys) SNil)

dot積を取る2つのndarrayのshapeから結果のshapeが決定する.1つ目のshape最後の要素と,2つ目のshape最後から2番目の要素が同じという条件が付いている.これも,以下のように条件に合わない場合と結果の型が合わない場合は型検査で失敗する.

-- dot可
dottable :: Num a => NDArray '[2,3,4] a -> NDArray '[3,4,2] a -> NDArray '[2,3,3,2] a
dottable = dot

-- 型検査でdot不可 (結果の型が合わない)
-- undottable1 :: Num a => NDArray '[2,3,4] a -> NDArray '[3,4,4] a -> NDArray '[2,3,3,2] a
-- undottable1 = dot -- Couldn't match type ‘4’ with ‘2’

-- 型検査でdot不可 (引数の型が合わない)
-- undottable2 :: Num a => NDArray '[2,3,4] a -> NDArray '[4,3,2] a -> NDArray '[2,3,4,2] a
-- undottable2 = dot -- Couldn't match type ‘4’ with ‘3’

最後に,numpy.transposeを実装する.

transpose :: Sort axes ~ EnumFromTo 0 (Length shape - 1) =>
             Sing axes -> NDArray shape a ->
             NDArray (Map ((:!!$$) shape) axes) a
transpose axes (NDArray shape) =
  NDArray (sMap (singFun1 (toProxy shape) (shape %:!!)) axes) where
    toProxy :: Sing (shape :: [Nat]) -> Proxy (Apply (:!!$) shape)
    toProxy _ = Proxy

条件は少し複雑になり,軸(次元)の転置方法に相当するパラメータaxesが,ちゃんとshape全体のpermutationを意味するパラメータになっていなければならない.これも,次のように各理由で整合しない場合は型検査で失敗する.

transposable :: Sing '[1,0,2] -> NDArray '[2,3,4] a -> NDArray '[3,2,4] a
transposable = transpose

-- transpose不可 (axesにshapeの長さ以上のものが含まれる)
-- untransposable1 :: Sing '[1,0,3] -> NDArray '[2,3,4] a -> NDArray '[3,2,4] a
-- untransposable1 = transpose -- Couldn't match type ‘3’ with ‘2’

-- transpose不可 (axesに同じ要素が2つ以上含まれる)
-- untransposable2 :: Sing '[1,0,0] -> NDArray '[2,3,4] a -> NDArray '[3,2,2] a
-- untransposable2 = transpose -- Couldn't match type ‘1’ with ‘2’

-- transpose不可 (axesの長さとshapeの長さが一致しない)
-- untransposable3 :: Sing '[1,0] -> NDArray '[2,3,4] a -> NDArray '[3,2] a
-- untransposable3 = transpose -- Couldn't match type ‘'[]’ with ‘'[2]’

-- transpose不可 (結果の型が合わない)
-- untransposable4 :: Sing '[1,0,2] -> NDArray '[2,3,4] a -> NDArray '[3,2,5] a
-- untransposable4 = transpose -- Couldn't match type ‘4’ with ‘5’

さて,reshape,dot,transposeが定義されたので,これらを使ってnumpy.tensordotが定義できる.はずである.実際に「こうすればいいよね」という感覚に従って実装してみよう.

tensordot :: (Num a, ns ~ Nub ns, ms ~ Nub ms,
              Map ((:!!$$) xs) ns ~ Map ((:!!$$) ys) ms) =>
             NDArray xs a -> NDArray ys a -> (Sing ns, Sing ms) ->
             NDArray (Map ((:!!$$) xs) (EnumFromTo 0 (Length xs - 1) :\\ ns) :++
                      Map ((:!!$$) ys) (EnumFromTo 0 (Length ys - 1) :\\ ms)) a
tensordot x@(NDArray xs) y@(NDArray ys) (ns, ms) = result where
  range n = sEnumFromTo (sing :: Sing 0) (n %:- (sing :: Sing 1))
  notinns = range (sLength xs) %:\\ ns
  notinms = range (sLength ys) %:\\ ms
  tx = transpose (notinns %:++ ns) x
  ty = transpose (ms %:++ notinms) y
  dimsIn xs = sMap (singFun1 (toProxy xs) (xs %:!!)) where
    toProxy :: Sing (shape :: [Nat]) -> Proxy (Apply (:!!$) shape)
    toProxy _ = Proxy
  (oldxs, oldys) = (dimsIn xs notinns, dimsIn ys notinms) where
  rtx = reshape (SCons (sProduct oldxs) $ SCons (sProduct $ dimsIn xs ns) SNil) tx
  rty = reshape (SCons (sProduct $ dimsIn ys ms) $ SCons (sProduct oldys) SNil) ty
  result = reshape (oldxs %:++ oldys) (rtx `dot` rty)

とりあえずだが,コンテキストによる条件の表現としては,tensordotを取る2つのndarrayのshapeに対し,dotを取る部分の軸指定がpermutationの部分列になっており,かつ,指定された軸について同じ形をしているという条件相当を付けておいた.

このtensordotの定義は型検査に失敗し,5件くらいしか無いクセにトータル1200行くらいの長大なエラーメッセージを吐く.もちろん実装者も吐く.そして泣く.これは,ニンゲンには自明なことであっても,機械(型検査器)にはわからないことがあるためだ.それは,

  • transpose時の条件,転置方法が正しくpermutationに相当する値(型レベルの)になっているかがわからない.(130行のエラー x txとtyについて2箇所)
    • ニンゲンにとって,ある列の中からいくつかを取り出し(notinnsã‚„notinms),取り出した後の最初や最後にくっつけて(transposeの引数)も,順番が変わるだけで中身は一緒(=permutation)であることはわかる.
    • けど,機械(型検査器)にとっては自明ではない.
  • reshape時の条件,reshape前後の要素数が(型レベルで)同じであるかどうかがわからない.(280行のエラー x rtxとrtyについて2箇所 + 400行のエラー x resultについて1箇所)
    • ニンゲンにとって,ある正数の列を2グループに分け(xsに対するoldxsとそれ以外,ysに対するoldysとそれ以外),それぞれの積を取ったものの積は,元の列の積と同じことはわかる.
    • けど,機械(型検査器)にとっては自明ではない.

ことからきている.dotの条件についてのエラーが出ていないが,この条件についてはわかっているのかというと,それはtensordot自体の条件(Mapのやつ)から機械(型検査器)にもわかる.また,reshapableやtransposable等で使ったとき問題無かったのは,具体的な型レベルでの値(型レベル自然数リテラル)が入っているためであり,対して今回のtensordotの定義中では(型)変数のままなので,機械(型検査器)が計算を進められるかの状況が異なる.

このエラーを解決するためには,恐らく2つ方法がある.

  • 無視する.といってもエラーではどうしようもないので前回の記事と同じような方法を使い,transposeã‚„reshapeを使ってもエラーにならないケースを作り出す.ここで無視するのは以下2点だ,
    • (ニンゲンとって自明にunreachableな)エラーケースがコード上発生すること
    • これらを使う人には条件を示すことを要請する癖に,自分が使うときは自明だからいいだろと無視するカッコ悪さ
  • 使える情報から機械(型検査器)がわかる情報にニンゲンが変換して教えてあげる.

前者について,ニンゲンにとって自明とは言うものの,複雑な条件になると本当に自明かどうかはニンゲンにもわかるか怪しいし,そもそも自明だと思ったこと自体も錯誤かもしれない.わざわざ型レベルでどうにかしようという話自体,そういったニンゲンによるミスをどうにかしようというモチベーションだった筈であり,解決の方向性としてはどうなのソレという感がある.

後者が正攻法となるが,実は,それがこの界隈で何と呼ばれているか私達は既に知っている.それは定理証明と呼ばれるヤツである.こうなるとワンチャン証明器からextractしたほうが使い易いまである.Haskellでも実際GHC.TypeLits.Natはペアノ数ではない*2ようなので,そのままではInductionが効かず証明相当のサギョウがやりにくいことこの上無かったりする.


まとめると,各種契約を型で云々~に期待されるようガチガチに型レベルで設計されたものは,言語にも使う側にも型レベルでガチな取り組みを求められることがある.型レベルへ条件を持ち上げることは「人はミスせずに作業ができるか」という問題を「人は定理証明ができるか」という問題へと変換することなのだ.となると,最近null安全?とかの話から複雑な契約も型でヤッター!と言ってる方々はたぶん皆このへんについても「覚悟完了!当方に迎撃の準備有り!」なので,定理証明についても何でも*3聞いて大丈夫ということになる.やったぜ

*1:ChainerやらTensorFlowやらで時間をかけた学習後セーブもせず喰わせるデータ(のshape)間違えたりとか

*2:そもそもなんでこうなってるんだっけ?型レベル自然数リテラルはペアノ数へのエイリアスでよかったようにも思うんだけど

*3:ん?今