MCMCサンプルを{dplyr}で操る
RからStanやJAGSを実行して得られるMCMCサンプルは、一般的に iterationの数×chainの数×パラメータの次元 のようなオブジェクトとなっており、凝った操作をしようとするとかなりややこしいです。
『StanとRでベイズ統計モデリング (Wonderful R)』のなかでは、複雑なデータ加工部分は場合によりけりなので深入りしないで、GitHub上でソースコードを提供しています。そこでは、ユーザが新しく覚えることをなるべく少なくするため、Rの標準的な関数であるapply
関数群を使っていろいろ算出しています。しかし、apply
関数群は慣れていない人には習得しづらい欠点があります。
一方で、Rのデータ加工パッケージとして、%>%
によるパイプ処理・{dplyr}
パッケージ・{tidyr}
パッケージがここ最近よく使われており、僕も重い腰を上げてやっと使い始めたのですが、これが凄く使いやすい。%>%
、select
、filter
、mutate
、group_by
、summarize
、*_join
、pivot_longer
、pivot_wider
だけをまずは覚えればほとんど不自由しませんでした。これらがないともう他の言語に移れないレベルです。これらのパッケージの練習のおかげで、ややこしいMCMCサンプルの処理についても、こんな感じでやれば毎回ウンウン唸らずに統一的に操作できそうかなぁ、というところまで来ましたので簡単にメモします。
* * *
手始めに以下の図を描いてみます。
この図はパラメータごとにMCMCサンプルの中央値と95%CIを表示した図です。{ggmcmc}
パッケージや{bayesplot}
パッケージに含まれる関数を使うと一撃で描くこともできます。しかし、練習のため自分で算出して作図します。
library(rstan) library(ggmcmc) library(dplyr) data <- list(J=8, y=c(28, 8, -3, 7, -1, 1, 18, 12), sigma=c(15, 10, 16, 11, 9, 11, 10, 18)) model_code <- readr::read_file(url('https://raw.githubusercontent.com/wiki/stan-dev/rstan/8schools.stan')) fit <- stan(model_code=model_code, data=data, seed=1234) d_mcmc <- ggs(fit) d_qua <- d_mcmc %>% filter(grepl('^theta\\[\\d+\\]$', Parameter)) %>% group_by(Parameter) %>% summarize(`2.5%` = quantile(value, probs=.025), `50%` = quantile(value, probs=.5), `97.5%`= quantile(value, probs=.975)) %>% ungroup() p <- ggplot() + geom_pointrange(data=d_qua, mapping=aes(x=forcats::fct_rev(Parameter), y=`50%`, ymin=`2.5%`, ymax=`97.5%`)) + coord_flip() + labs(x='Parameter', y='Value') ggsave(p, file='fig1.png', dpi=300, w=4, h=3)
- 6行目:Web上から
8schools.stan
を読み込んで文字列としています。RStanの公式ページの例題で使われているモデルファイルです。 - 7行目:
stan
関数はmodel_code
引数でモデルを書いた文字列も指定できます(本来はファイル名を直接指定できればよかったのですがよくわかりませんでした)。 - 9行目:
{ggmcmc}
パッケージのggs
関数でtidyなデータにしておきます。tidyなデータについては西原さんの記事を参照。{ggmcmc}
パッケージに含まれる関数を使うとd_mcmc
からいろいろな図が描けます。詳しくは『StanとRでベイズ統計モデリング (Wonderful R)』の4章に書きましたので読んでいただけるとうれしいです。
なお、d_mcmc
は以下のようなデータフレームになります。
> d_mcmc # A tibble: 72,000 × 4 Iteration Chain Parameter value <dbl> <int> <fctr> <dbl> 1 1 1 mu -1.3476 2 2 1 mu -0.9601 3 3 1 mu 7.0919 4 4 1 mu 15.0782 5 5 1 mu 20.0110 6 6 1 mu 20.4483 7 7 1 mu 13.0249 8 8 1 mu 11.8232 9 9 1 mu 15.9213 10 10 1 mu 17.9294 # ... with 71,990 more rows
- 11行目:まずはモデルに含まれる
theta[数字]
というパラメータだけ残しています。grepl
関数でパラメータ名がマッチするか判定する際に正規表現を使う必要があります。ここが正規表現に慣れていない人は少し厳しいかもしれません。 - 12~15行目:
{dplyr}
パッケージの典型的な使い方です。Parameter
列ごとに要約量を算出します。列名が数字で始まる場合はバッククォートで囲む必要があります。1つずつ分位点を算出するのではなく、do
関数で一行で算出する方法もあるのですが、分かりにくく、現状issueとして検討中のようです(ここ→ここ)。 - 18行目:
ggplot2
でcoord_flip
すると下から上に向かってfactor
が並びますので、{forcats}
パッケージのfct_rev
関数で逆順にしています。
* * *
次に以下の図を描いてみます。久保本11章の図に相当します。
library(rstan) library(ggmcmc) library(dplyr) Y <- read.csv(url('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap12/input/data-kubo11a.txt'))$Y I <- length(Y) d <- data.frame(X=1:I, Y=Y) data <- list(I=I, Y=Y) model_code <- readr::read_file(url('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap12/model/model12-11.stan')) fit <- stan(model_code=model_code, data=data, seed=1234) d_mcmc <- ggs(fit) d_qua <- d_mcmc %>% filter(grepl('^Y_mean\\[\\d+\\]$', Parameter)) %>% tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% group_by(Parameter, x) %>% summarize(`2.5%` = quantile(value, probs=.025), `10%` = quantile(value, probs=.1), `50%` = quantile(value, probs=.5), `90%` = quantile(value, probs=.9), `97.5%`= quantile(value, probs=.975)) %>% ungroup() p <- ggplot() + geom_ribbon(data=d_qua, mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) + geom_ribbon(data=d_qua, mapping=aes(x=x, ymin=`10%`, ymax=`90%`), alpha=2/6) + geom_line(data=d_qua, mapping=aes(x=x, y=`50%`)) + geom_point(data=d, aes(x=X, y=Y), shape=1, size=2) + labs(x='i', y='Y[i]') + ylim(0, 22) ggsave(p, file='fig2.png', dpi=300, w=4, h=3)
- 15行目:
{tidyr}
パッケージのseparate
関数を使って、Y_mean[20]
をY_mean
と20
という2つの列に分解しています。 - 16行目:あとは集計の単位である
group_by
の単位が場面によって多少変わるぐらいで、特に悩まずに色々な量が算出できます。
この記事ではMCMCのChain
ごとに何かを算出することは取り上げませんでしたが、ggs
関数で作ったd_mcmc
はChain
列も含んでいますので自由自在です。
2017.07.16 追記
vector[D] mu[T]
のようにD次元vectorがT(時点の数)個並べたような配列に対し、各t,dのmu
の値の分位点を算出したい場合は以下のようにしました。tidyr::separate
でまず角括弧の切れ目で分けておいてから、カンマで切るのがポイントです。
d_qua <- d_mcmc %>% filter(grepl('^mu\\[\\d+,\\d+\\]$', Parameter)) %>% tidyr::separate(Parameter, into=c('Parameter', 't'), sep='[\\[\\]]') %>% tidyr::separate(t, into=c('t', 'd'), sep=',', convert=TRUE) %>% group_by(Parameter, t, d) %>% summarize(`2.5%` = quantile(value, probs=.025), `10%` = quantile(value, probs=.1), `50%` = quantile(value, probs=.5), `90%` = quantile(value, probs=.9), `97.5%`= quantile(value, probs=.975)) %>% ungroup()