y_uti のブログ

統計、機械学習、自然言語処理などに興味を持つエンジニアの技術ブログです

サポートベクトルマシンで MNIST 手書き数字データを分類する

サポートベクトルマシン (SVM) を用いて、MNIST 手書き数字データの分類を試してみます。

SVM の実装は広く使われているものがいくつかありますが*1、今回は LIBSVM を利用します。LIBSVM は以下のウェブサイトから入手できます。
LIBSVM -- A Library for Support Vector Machines

データの準備

まず、ダウンロードした圧縮ファイルを展開してビルドします。正常にビルドが終了すると、svm-scale, svm-train, svm-predict という 3 つの実行ファイルが生成されます。

$ tar zxf libsvm-3.18.tar.gz
$ cd libsvm-3.18
$ make

次に、MNIST 手書き数字データを LIBSVM の入力データ形式に変換します。テキスト形式に変換するところまでは、前回の記事 を参照してください。前回作成したテキスト形式のデータは、ピクセルデータとラベルデータのそれぞれについて 1 画像を 1 行として数字を列挙したものでした。ここから LIBSVM の入力データファイルを作成するには、以下のようにします。

$ cat train-images.txt |\
  awk '{
    for (i = 1; i <= NF; i++) if ($i > 0) printf("%d:%f%s", i, $i / 255, i < NF ? " " : "");
    printf("\n");
  }' |\
  paste -d' ' train-labels.txt - >train.svm

出力される train.svm ファイルの先頭行は以下のようになります*2。

$ head -n 1 train.svm
5 153:0.011765 154:0.070588 155:0.070588 ... 682:0.529412 683:0.517647 684:0.062745

LIBSVM の入力データ形式については、README ファイルに説明があります。詳細はそちらを参照してください。大まかに説明すると、1 データが 1 行、先頭列はラベルで、2 列目以降にデータの各次元の値を "<index>:<value>" の形で列挙します。値が 0 の次元は省略できます。また、LIBSVM では、データを適宜スケーリングして取り扱うことが推奨されているので、各ピクセルの値を 255 で割って 0 から 1 の範囲の値に変換しています*3。

作成した train.svm の学習には時間がかかるので、ここからサンプリングして小さな訓練データセットを作成します。訓練データのサンプリングには、LIBSVM に含まれている subset.py スクリプトを利用できます。以下の実行例では、train.svm から 1,000 データをサンプリングして train_subset.svm に保存しています。subset.py は、層化抽出法を用いて、元データにおける各ラベルの比率を保存するようにサンプリングしてくれます。

$ libsvm-3.18/tools/subset.py train.svm 1000 >train_subset.svm
手書き数字データの分類

train_subset.svm に対して、以下のように SVM の学習を実行します。パラメータの設定など、いろいろ検討できるのですが、ここではデフォルトの設定で実行してみます。実行結果は train_subset.model ファイルに出力されます。

$ libsvm-3.18/svm-train train_subset.svm train_subset.model
*
optimization finished, #iter = 57
nu = 0.432358
obj = -57.291330, rho = -0.347079
nSV = 93, nBSV = 86

...

optimization finished, #iter = 88
nu = 0.695029
obj = -90.474083, rho = 0.346886
nSV = 140, nBSV = 125
Total nSV = 905

学習されたモデルを使って、テストデータの分類を行えます。分類の対象とする t10k.svm は、train.svm と同様に作成しておきます。なお、train.svm と同じ手順で t10k.svm を作成すると t10k.svm にも先頭列に正解のラベルが付きますが、これは正解率の計算のために使われるだけで、もちろん分類の際には利用されません。

$ libsvm-3.18/svm-predict t10k.svm train_subset.model t10k.output
Accuracy = 85.62% (8562/10000) (classification)

実行の結果、t10k に含まれる 10,000 点の画像を分類した結果、正解率は 85.62% だったことが分かります。各画像の分類結果は t10k.output ファイルに出力されていますので、以下のように confusion matrix を出力して、そこから正解率を計算することもできます。以下の実行例では、2 列目が正解ラベル、3 列目が分類結果、1 列目がそのようなデータの個数になります。

$ cut -f1 -d' ' t10k.svm | paste -d' ' - t10k.output | sort | uniq -c | head
    896 0 0
      5 0 2
      3 0 3
      7 0 4
     38 0 5
     24 0 6
      2 0 7
      5 0 8
   1104 1 1
      2 1 2
パラメータチューニング

さて、ここまではデフォルトのパラメータ設定で実行しましたが、次に、学習の際のパラメータチューニングを試してみます。問題に合わせてパラメータを設定することで、正解率の向上を期待できます。

LIBSVM に含まれている grid.py というスクリプトを用いて、パラメータのチューニングを行えます。このスクリプトは、svm-train での主要なパラメータである cost と gamma について、それらの値を変えながら svm-train の学習を繰り返し、交差検定を行って正解率が最も高いものを教えてくれます。以下のように実行します。

$ libsvm-3.18/tools/grid.py train_subset.svm
[local] 5 -7 90.9 (best c=32.0, g=0.0078125, rate=90.9)
[local] -1 -7 88.1 (best c=32.0, g=0.0078125, rate=90.9)
[local] 5 -1 16.4 (best c=32.0, g=0.0078125, rate=90.9)
...
[local] 13 -9 88.9 (best c=8.0, g=0.03125, rate=92.0)
[local] 13 -3 61.0 (best c=8.0, g=0.03125, rate=92.0)
8.0 0.03125 92.0

実行の結果、cost = 8.0, gamma = 0.03125 が提案されました。このときの交差検定の正解率は 92.0% となっています。また、cost と gamma による正解率の変化を等高線グラフで出力してくれます*4。今回の実行では、以下のような様子になりました。
f:id:y_uti:20140725143812p:plain

グラフを見ると、正解率 92.0% になっている領域はかなり狭そうで、詳しい様子が分かりません。grid.py の実行時にパラメータの探索範囲を指定することで、特定の範囲をより詳しく調べられます。次の実行例では、cost は 0 から 4 までを 0.5 刻み、gamma は -4 から -8 までを -0.5 刻みで試すように指示しています。

$ libsvm-3.18/tools/grid.py -log2c 0,4,0.5 -log2g -4,-8,-0.5 train_subset.svm

実行結果は以下のようになり、cost = 1.414, gamma = 0.03125 がよいという提案になりました。
f:id:y_uti:20140725152950p:plain

提案されたパラメータでモデル学習、分類を実行してみます。92.69% という正解率が得られました。

$ libsvm-3.18/svm-train -c 1.41421356237 -g 0.03125 train_subset.svm train_subset.model
$ libsvm-3.18/svm-predict t10k.svm train_subset.model t10k.out
Accuracy = 92.69% (9269/10000) (classification)

ここまで、訓練データは 1,000 点として実行していましたが、より多くのデータを使うことで正解率も高くなります。パラメータの値は固定した状態で、利用するデータ数を変えてモデル学習、分類を繰り返したところ、正解率は以下のように変化しました。60,000 データすべてを使って学習したモデルでは、98.49% の正解率が得られました。MNIST のウェブページには、ガウシアンカーネルの SVM で誤識別率 1.4% とあり、今回の実験ではこれに近い結果が得られたことになります。
f:id:y_uti:20140725155339p:plain

*1:今回利用する LIBSVM のほかに、SVM-Light, TinySVM などがあります。また、機械学習ライブラリにも SVM を含むものが多くあるようです。

*2:長いので、途中を省略しています。

*3:svm-scale を使うと、訓練データの範囲を調べて自動的にデータをスケーリングできます。今回のデータでは、各ピクセルの値が 0 から 255 の範囲だと分かっているので、svm-scale を利用せずに直接変換しました。

*4:gnuplot が必要です。