y_uti のブログ

統計、機械学習、自然言語処理などに興味を持つエンジニアの技術ブログです

DTW Barycenter Averaging で時系列データの平均を求める

DTW Barycenter Averaging (DBA) の手法を用いて、時系列データの平均を求めてみます。前回の記事*1と同様に台風の経路情報を題材として、複数の台風の経路の「平均」を計算しました。結果は以下のとおりです。気象庁のウェブサイト*2で公開されているデータから 10 個の台風を選び、DBA を適用して平均の経路を求めたもので、青線が各台風の経路、赤線が DBA による平均です。

データの準備

利用するデータは、前回の記事で作成した過去の台風のデータから、2012 年の台風 17 号に DTW 距離の近い 10 個を選択しました。DTW を計算するプログラムを用いて次のように実行します。

$ for f in T*.csv; do echo $f $(php dtw.php T1217.csv $f); done | sort -nsk2,2 | head -n 10
T1217.csv 0
T1101.csv 152.08961628221
T0315.csv 175.12050614974
T1304.csv 198.65671463983
T1506.csv 239.88518007505
T1102.csv 240.21226800402
T0306.csv 280.40652975797
T0715.csv 305.72729625626
T1203.csv 309.22264174911
T0423.csv 320.9641416908

この 10 個の台風について距離行列を計算して medoid を求めると、T1101 (2011 年の台風 1 号) になります。次の図は T1101 を赤線として 10 個の台風を描いたものです。Medoid は DTW 距離の総和を最小にするものなので、図のように、経路の一部は全体の平均から大きく外れることもあります。

DBA のアルゴリズム

DBA は時系列データの平均を求めるアルゴリズムです。次の論文に擬似コードが掲載されています。論文は第一著者のウェブページ (http://www.francois-petitjean.com/Research/) で公開されています。

IEEE ICDM 2014: F. Petitjean, G. Forestier, G. Webb, A. Nicholson, Y. Chen and E. Keogh, "Dynamic Time Warping Averaging of Time Series allows Faster and more Accurate Classification."

なお、DBA のアルゴリズム自体は同じ第一著者による以下の論文で発表されているようです。こちらは本文ファイルへのリンクが見つからなかったので、今回の記事では上記の論文のコードを参照しました。

Pattern Recognition: F. Petitjean, A. Ketterlin & P. Gançarski, "A global averaging method for dynamic time warping, with applications to clustering," 2011.

DBA は、適当な仮の「平均」を初期値として、それを反復的に更新することで時系列データの平均を求めます。次の図の例で DBA の動作を説明します。図の青線は T1217 と T1203、赤線は T1101 です。赤線の系列を初期値として、青線の 2 つの系列の平均を求めます。なお、ここでは簡単のため 2 系列の平均を求めますが、3 系列以上の場合でも同様に計算できます。

まず、各系列 (青線) と平均の系列 (赤線) との間で DTW をそれぞれ計算します。このとき、DTW の計算過程で得られる、要素間の対応付けを保持しておきます。下図は、左が T1217 と平均系列との DTW の結果、右が T1203 と平均系列との DTW の結果です。緑色の線が要素間の対応付けを表しています。

次に、平均系列の各要素を、それぞれ対応付けられた要素の重心に移動します。この操作を下図に示します。左は DTW の結果をまとめて日本列島付近を拡大したもので、右は、ここから平均系列の各要素を移動したものです。右図の桃色の線は移動前の状態を表しています。比較すると、重心に移動することで平均系列が改善されている様子が分かります。

以上の手順で平均の系列を移動したことにより、各系列との DTW の結果が変化します。DBA は、これを繰り返して平均の系列を順次更新していきます。例に示した 2 系列の平均では、9 回の反復処理で収束しました。得られた結果は以下のとおりです。

DBA の実装

それでは、DBA のアルゴリズムを実装してみます。いつものように PHP で実装します。

まず、DTW を計算して対応付けを求める関数を次のように実装しました。$a と $b を入力の時系列として、関数の前半では、二次元配列 $d に累積の DTW 距離を格納していきます。これと同時に、直前のインデックスを $p に格納していきます。後半では $p を最終要素から逆順に辿って対応付けを構築し、得られた $path を戻します。

<?php
function dtwpath($a, $b, $distance = 'euclid') {
    $p = array_fill(0, count($a), array_fill(0, count($b), false));
    $d = array_fill(0, count($a), array_fill(0, count($b), false));
    $d[0][0] = $distance($a[0], $b[0]);
    for ($i = 1; $i < count($a); ++$i) {
        $p[$i][0] = array($i - 1, 0);
        $d[$i][0] = $d[$i - 1][0] + $distance($a[$i], $b[0]);
    }
    for ($j = 1; $j < count($b); ++$j) {
        $p[0][$j] = array(0, $j - 1);
        $d[0][$j] = $d[0][$j - 1] + $distance($a[0], $b[$j]);
    }
    for ($i = 1; $i < count($a); ++$i) {
        for ($j = 1; $j < count($b); ++$j) {
            $prev = array($i - 1, $j - 1);
            if ($d[$i][$j - 1] < $d[$prev[0]][$prev[1]]) {
                $prev = array($i, $j - 1);
            }
            if ($d[$i - 1][$j] < $d[$prev[0]][$prev[1]]) {
                $prev = array($i - 1, $j);
            }
            $p[$i][$j] = $prev;
            $d[$i][$j] = $d[$prev[0]][$prev[1]] + $distance($a[$i], $b[$j]);
        }
    }
    $path = array();
    list ($i, $j) = array(count($a) - 1, count($b) - 1);
    while ($i > 0 || $j > 0) {
        $path[] = array($i, $j);
        list ($i, $j) = $p[$i][$j];
    }
    $path[] = array(0, 0);
    return array_reverse($path);
}

function euclid($a, $b) {
    return sqrt(array_sum(array_map(function ($x, $y) {
        return pow($x - $y, 2);
    }, $a, $b)));
}

これを用いて、DBA の一回の反復処理を次のように実装しました。$sequences が時系列データの配列、$seed が移動前の平均系列です。foreach のループで、平均系列と各系列との DTW の対応付けを求め、平均系列の各要素に対応付けられる要素を $vertices[$ci] に追加していきます。最後に、average 関数で平均系列の各要素を重心に移動して戻します。

<?php
require_once __DIR__ . '/dtwpath.php';

function dba($sequences, $seed) {
    $vertices = array_fill(0, count($seed) - 1, array());
    foreach ($sequences as $s) {
        $path = dtwpath($seed, $s);
        foreach ($path as list ($ci, $si)) {
            $vertices[$ci][] = $s[$si];
        }
    }
    return array_map('average', $vertices);
}

function average($vectors) {
    $n = count($vectors);
    $d = count($vectors[0]);
    $result = array_fill(0, $d, 0);
    for ($i = 0; $i < $d; ++$i) {
        foreach ($vectors as $v) {
            $result[$i] += $v[$i];
        }
        $result[$i] /= $n;
    }
    return $result;
}

この反復処理を繰り返して時系列データの平均を計算する処理は、次のように実装しました。第一引数を初期値として、第二引数以降の時系列データの平均を求めます。平均の系列が変化しなくなるか、反復回数が 100 回に達するまで処理を繰り返します。

<?php
require_once __DIR__ . '/dba.php';

function readCsv($filename) {
    return array_map(
        function ($line) { return explode(',', $line); },
        file($filename, FILE_IGNORE_NEW_LINES));
}

$average = readCsv($argv[1]);
$sequences = array();
for ($i = 2; $i < $argc; ++$i) {
    $sequence = readCsv($argv[$i]);
    $sequences[] = $sequence;
}

for ($i = 0; $i < 100; ++$i) {
    $next_average = dba($sequences, $average);
    if ($next_average === $average) {
        break;
    }
    $average = $next_average;
}

foreach ($average as $values) {
    echo implode(',', $values) . "\n";
}

このプログラムを次のように実行すると、10 個の台風を平均した経路が得られます。冒頭の図は、この出力結果を描画したものです。

$ php dba_main.php T1101.csv T*.csv | tee result.csv
10.647916666667,139.06458333333
12.328571428571,132.61785714286
14,129.96666666667
12.676,129.036
14.236363636364,128.3
  ...
42.7,161.17272727273
44.507692307692,165.56153846154
49.178260869565,169.88695652174
47.17,173.59
48.441666666667,177.85
初期値による DBA の結果の違い

ここからは、DBA で得られる時系列について少し調べてみます。まず、初期値による結果の違いを確認します。

アルゴリズムを考えるとわかるように、DBA では、初期値として与えた系列の長さは最後まで変わりません。論文を読む限りでは、このアルゴリズムは長さの等しい時系列データの集合に適用することを意図しているようにも思えるのですが*3、今回の記事のように長さがばらばらの時系列に適用するときには、初期値の系列長が特に問題になりそうです。直感的には、極端に短い時系列では良い結果が得られないように思われます。今回の実験では 10 個の台風の medoid として T1101 を初期値としましたが、確認してみると、これは全体の中で 2 番目に短いデータでした。

$ wc -l T*.csv
  61 T0306.csv
  50 T0315.csv
  57 T0423.csv
  59 T0715.csv
  36 T1101.csv
  61 T1102.csv
  53 T1203.csv
  69 T1217.csv
  34 T1304.csv
  60 T1506.csv
 540 合計

10 個の台風の中で一番長い T1217 を初期値にすると、次の結果になりました。赤線が T1217 を初期値としたときの結果です。比較のため、T1101 を初期値とした結果 (冒頭の図) も桃色で示しています。思ったほどの違いはありませんでしたが、少し異なる軌跡を描いていることがわかります。

系列ごとに平均を取る

今回利用した台風のデータでは、DBA で得られた平均の経路が全体的にでこぼこしていて、やや不自然な印象があります。この原因として、DTW で一対多の対応付けが得られるために重心の計算で特定の系列に強く引き寄せられるのではないかと考えました*4。

そこで、DBA のアルゴリズムを少し変更して、まず各系列の中で平均してから、系列ごとに同じ重さで重心を求めるようにしてみました。変更後の dba 関数の実装は次のとおりです。

<?php
function dba($sequences, $seed) {
    $vertices = array_fill(0, count($seed) - 1, array());
    foreach ($sequences as $s) {
        $path = dtwpath($seed, $s);
        $seqavg = array_fill(0, count($seed) - 1, array());
        foreach ($path as list ($ci, $si)) {
            $seqavg[$ci][] = $s[$si];
        }
        $seqavg = array_map('average', $seqavg);
        for ($i = 0; $i < count($seqavg); ++$i) {
            $vertices[$i][] = $seqavg[$i];
        }
    }
    return array_map('average', $vertices);
}

このように変更して実行した結果が以下のとおりです。初期値は T1101 としています。赤線が変更後のプログラムによるもので、桃色はオリジナルのプログラムによるものです。今回利用したデータでは、期待どおり滑らかな結果が得られました。その一方で、時系列の最初の点が西側にずれており、東京の南から南東方向で発生している台風の様子を捕らえられなくなっています。このあたりは用途に応じて工夫の余地がありそうです。

*1:k-medoids 法と DTW による時系列データのクラスタリング - y_uti のブログ

*2:http://www.data.jma.go.jp/fcd/yoho/typhoon/index.html, 今回の記事の図は、ここからダウンロードできる PDF ファイルのデータを加工して作成したものです。

*3:論文 III 節の A. Definitions で時系列の長さを定数 L として議論が進んでいます。また、Table III でもデータセットごとに系列長が一定のようです。

*4:論文に書かれている内容ではなく私の勝手な思いつきなので注意してください。