昨日Gibbs Sampler Algorithmをやってみたわけだが、Rの中でfor文を書いていて必要となるサンプル数が多くなると非常につらくなってくることは目に見えている。しかも、MCMCでは初期値依存となる期間のサンプルを捨てないといけない。そういうわけでじゃんじゃんサンプルを作っても大丈夫なような速度が必要。
Rで速度を上げようと思ったらapplyファミリーを使うとかベクトル単位での処理をするetcが常套手段*1。が、今回は本質的にfor文が必要なケースである。
で、困るわけだがRにはC、C++、fortranを使って拡張する機能がある。詳しくはこの辺に載っている。そういうわけでCのポインタもアドレスも理解していないid:syou6162がRが好きすぎたためにCを書いてみたという感じの内容。
#include <R.h> #include <Rinternals.h> SEXP r2norm(SEXP num) { int i; double prevX1,prevX2; int n = INTEGER(num)[0]; SEXP ans; prevX1 = 2.0; prevX2 = 1.0; PROTECT(ans = allocMatrix(REALSXP,2,n)); GetRNGstate(); for(i = 0;i < 2*n;i+=2){ prevX1 = 1.0 + 0.7 * (prevX2 - 2.0) + (norm_rand() * (1.0 - 0.7 * 0.7)); REAL(ans)[i] = prevX1; prevX2 = 2.0 + 0.7 * (prevX1 - 1.0) + (norm_rand() * (1.0 - 0.7 * 0.7)); REAL(ans)[i+1] = prevX2; } PutRNGstate(); UNPROTECT(1); return(ans); }
とりあえずソース。内容については後述。
Cのソースを次のようにコンパイルします。Rの必要なヘッダーファイルとかはRが適当にやってくれるようです。R++。
/tmp% R CMD SHLIB hoge.c gcc -std=gnu99 -I/Library/Frameworks/R.framework/Resources/include -I/usr/local/include -fPIC -g -O2 -c hoge.c -o hoge.o gcc -std=gnu99 -dynamiclib -Wl,-headerpad_max_install_names -undefined dynamic_lookup -single_module -multiply_defined suppress -L/usr/local/lib -o hoge.so hoge.o -F/Library/Frameworks/R.framework/.. -framework R -Wl,-framework -Wl,CoreFoundation
で、とりあえずちゃんとできていることを確認しよう。こんな感じでやると
dyn.load("/tmp/hoge.so") r2norm <- function(n){ data.frame(t(.Call("r2norm", as.integer(n)))) } r <- r2norm(1000) x1 <- r[,1] x2 <- r[,2]
ちゃんと相関係数0.7を持つ乱数系列が生成できていることが確認できた。
> cor(x1,x2) [1] 0.6902752 > plot(x2,x1)
速度を比較してみる
そういうわけでRでforを使って書いたものと、Cレベルのforを使って書いたものの速度比較をしてみる。適当に実行時間を比較する関数を用意しておく。r.for <- function(n){ B <- n x1 <- rep(NA,B+1) x2 <- rep(NA,B+1) x1[1] <- -2 x2[1] <- 1 for (i in 1:B){ x1[i+1] <- rnorm(1,1+0.7*(x2[i]-2),sqrt(1-0.7^2)) x2[i+1] <- rnorm(1,2+0.7*(x1[i+1]-1),sqrt(1-0.7^2)) } } print.time <- function(n){ cat("n =",n,fill=TRUE) print(system.time(r.for(n))) print(system.time(r2norm(n))) }
で、時間を計測すると、こんな感じにいいい!!!
> print.time(100) n = 100 user system elapsed 0.004 0.000 0.004 user system elapsed 0 0 0 > print.time(1000) n = 1000 user system elapsed 0.037 0.000 0.037 user system elapsed 0.001 0.000 0.000 > print.time(10000) n = 10000 user system elapsed 0.371 0.003 0.376 user system elapsed 0.005 0.002 0.007 > print.time(100000) n = 1e+05 user system elapsed 3.857 0.032 3.918 user system elapsed 0.085 0.016 0.102 > print.time(1000000) n = 1e+06 user system elapsed 36.130 0.283 36.667 user system elapsed 0.683 0.150 0.835
40倍以上早くなってるようです。Cすげえ、Rすげえ!!!
*1:outerとか使えるところだとすごく速度が違うのは前経験したことがある