CでRの拡張したら速すぎて(40〜50倍)吹いたwww

昨日Gibbs Sampler Algorithmをやってみたわけだが、Rの中でfor文を書いていて必要となるサンプル数が多くなると非常につらくなってくることは目に見えている。しかも、MCMCでは初期値依存となる期間のサンプルを捨てないといけない。そういうわけでじゃんじゃんサンプルを作っても大丈夫なような速度が必要。

Rで速度を上げようと思ったらapplyファミリーを使うとかベクトル単位での処理をするetcが常套手段*1。が、今回は本質的にfor文が必要なケースである。

で、困るわけだがRにはC、C++、fortranを使って拡張する機能がある。詳しくはこの辺に載っている。そういうわけでCのポインタもアドレスも理解していないid:syou6162がRが好きすぎたためにCを書いてみたという感じの内容。

#include <R.h> 
#include <Rinternals.h> 

SEXP r2norm(SEXP num)
{
  int i;
  double prevX1,prevX2;
  int n = INTEGER(num)[0];
  SEXP ans;
  prevX1 = 2.0;
  prevX2 = 1.0;
  PROTECT(ans = allocMatrix(REALSXP,2,n));
  GetRNGstate();
  for(i = 0;i < 2*n;i+=2){
    prevX1 = 1.0 + 0.7 * (prevX2 - 2.0) + (norm_rand() * (1.0 - 0.7 * 0.7));
    REAL(ans)[i] = prevX1;
    prevX2 = 2.0 + 0.7 * (prevX1 - 1.0) + (norm_rand() * (1.0 - 0.7 * 0.7));
    REAL(ans)[i+1] = prevX2;
  }
  PutRNGstate();
  UNPROTECT(1);
  return(ans);
}

とりあえずソース。内容については後述。

Cのソースを次のようにコンパイルします。Rの必要なヘッダーファイルとかはRが適当にやってくれるようです。R++。

/tmp% R CMD SHLIB hoge.c 
gcc -std=gnu99 -I/Library/Frameworks/R.framework/Resources/include  -I/usr/local/include    -fPIC  -g -O2 -c hoge.c -o hoge.o
gcc -std=gnu99 -dynamiclib -Wl,-headerpad_max_install_names  -undefined dynamic_lookup -single_module -multiply_defined suppress -L/usr/local/lib -o hoge.so hoge.o   -F/Library/Frameworks/R.framework/.. -framework R -Wl,-framework -Wl,CoreFoundation

で、とりあえずちゃんとできていることを確認しよう。こんな感じでやると

dyn.load("/tmp/hoge.so")
r2norm <- function(n){
  data.frame(t(.Call("r2norm", as.integer(n))))
}

r <- r2norm(1000)
x1 <- r[,1]
x2 <- r[,2]

ちゃんと相関係数0.7を持つ乱数系列が生成できていることが確認できた。

> cor(x1,x2)
[1] 0.6902752
> plot(x2,x1)

f:id:syou6162:20161031152057p:plain

速度を比較してみる

そういうわけでRでforを使って書いたものと、Cレベルのforを使って書いたものの速度比較をしてみる。適当に実行時間を比較する関数を用意しておく。
r.for <- function(n){
  B <- n
  x1 <- rep(NA,B+1)
  x2 <- rep(NA,B+1) 
  x1[1] <- -2
  x2[1] <- 1 
  for (i in 1:B){ 
    x1[i+1] <- rnorm(1,1+0.7*(x2[i]-2),sqrt(1-0.7^2)) 
    x2[i+1] <- rnorm(1,2+0.7*(x1[i+1]-1),sqrt(1-0.7^2)) 
  }
}

print.time <- function(n){
  cat("n =",n,fill=TRUE)
  print(system.time(r.for(n)))
  print(system.time(r2norm(n)))
}

で、時間を計測すると、こんな感じにいいい!!!

> print.time(100)
n = 100
   user  system elapsed 
  0.004   0.000   0.004 
   user  system elapsed 
      0       0       0 
> print.time(1000)
n = 1000
   user  system elapsed 
  0.037   0.000   0.037 
   user  system elapsed 
  0.001   0.000   0.000 
> print.time(10000)
n = 10000
   user  system elapsed 
  0.371   0.003   0.376 
   user  system elapsed 
  0.005   0.002   0.007 
> print.time(100000)
n = 1e+05
   user  system elapsed 
  3.857   0.032   3.918 
   user  system elapsed 
  0.085   0.016   0.102 
> print.time(1000000)
n = 1e+06
   user  system elapsed 
 36.130   0.283  36.667 
   user  system elapsed 
  0.683   0.150   0.835 

40倍以上早くなってるようです。Cすげえ、Rすげえ!!!

*1:outerとか使えるところだとすごく速度が違うのは前経験したことがある