2014年8月6日水曜日

AdaGradが12倍速くなる魔法

AdaGradは学習率を自動調整してくれる勾配法の亜種で、いろんなが絶賛しています。 勾配を足し込む時に、各次元ごとに今までの勾配の2乗和をとっておいて、その平方根で割ってあげるだけと、恐ろしくシンプルです。

Adaptive Subgradient Methods for Online Learning and Stochastic Optimization
John Duchi, Elad Hazan, Yoram Singer. JMLR 2011.

丁度、 @echizen_tm さんがブログを書いてました。

AdaGrad+RDAを実装しました。

通常のSGDなどは学習率をだんだん減衰させながら勾配を足していくわけですが、どの様に減衰させるかという問題にいつも頭を悩ませます。 AdaGradでは最初の学習率こそ外から与えますが、減衰のさせ方や減衰率といったハイパーパラメータから開放されます。 今日の話は、そんなAdaGradがちょっとの工夫で12倍速になった話です。

さて、本題はここから。 AdaGradの部分を高速化したいなー、と思って更新部分だけ抜き出して速度を測ってみました。 更新部分を抜粋。

void ada_grad(
    float eta,
    const std::vector<float>& g,
    std::vector<float>& sum,
    std::vector<float>& x) {
  for (std::size_t i = 0, size = x.size(); i < size; ++i) {
    sum[i] += g[i] * g[i];
    x[i] += eta / std::sqrt(sum[i]) * g[i];
  }
}

コードの全体は以下のgistにあります。

https://gist.github.com/unnonouno/9ca83eefb31a6474caf6

試しに速度を測ってみましょう。

$ gcc -O2 adagrad_test.cpp
$ ./a.out
7744.53 msec
200.099

まぁ、こんなもんです。 下の200.099は計算結果を適当に出力してるだけ。 さて、ここで一工夫。 なんてことはなくて、コンパイルオプションを変えます。 では、もう1度。

$ gcc -Ofast -march=native adagrad_test.cpp
$ ./a.out
540.679 msec
200.099
!!?

驚異的に速くなりましたね。 -O2の12倍です。 ちょっと何いってるのかわからないレベルで速くなります。 もちろんコンパイラが頑張るからなんですが、何でこんなに速くなるんでしょう。 アセンブリコードをよんでみましょう。 まずは、-O2から。 該当部分だけ。

.L5:
     movss   (%rcx,%rbx,4), %xmm1
     movq    (%rsi), %rax
     mulss   %xmm1, %xmm1
     leaq    0(,%rbx,4), %rbp
     addq    %rbp, %rax
     addss   (%rax), %xmm1
     sqrtss  %xmm1, %xmm0
     ucomiss %xmm0, %xmm0
     movss   %xmm1, (%rax)
     jp      .L10
.L3:
     movaps  %xmm3, %xmm1
     movq    (%rdx), %rax
     addq    $1, %rbx
     divss   %xmm0, %xmm1
     addq    %rbp, %rax
     cmpq    %r12, %rbx
     mulss   (%rcx,%rbp), %xmm1
     addss   (%rax), %xmm1
     movss   %xmm1, (%rax)
     jne     .L5

まぁこんなものかな。 そして、-Ofastの方。

.L4:
     vmovups (%rdi,%rax), %xmm1
     addq    $1, %r8
     vmovups (%rsi,%rax), %xmm2
     vinsertf128     $0x1, 16(%rdi,%rax), %ymm1, %ymm1
     vmulps  %ymm1, %ymm1, %ymm1
     vinsertf128     $0x1, 16(%rsi,%rax), %ymm2, %ymm2
     vaddps  %ymm2, %ymm1, %ymm1
     vmovups %xmm1, (%rsi,%rax)
     vextractf128    $0x1, %ymm1, 16(%rsi,%rax)
     vrsqrtps        %ymm1, %ymm2
     vmulps  %ymm1, %ymm2, %ymm1
     vmulps  %ymm2, %ymm1, %ymm1
     vaddps  %ymm5, %ymm1, %ymm1
     vmovups (%rcx,%rax), %xmm3
     vmulps  %ymm4, %ymm2, %ymm2
     vmulps  %ymm2, %ymm1, %ymm2
     vmovups (%rdi,%rax), %xmm1
     vinsertf128     $0x1, 16(%rcx,%rax), %ymm3, %ymm3
     vmulps  %ymm6, %ymm2, %ymm2
     vinsertf128     $0x1, 16(%rdi,%rax), %ymm1, %ymm1
     vmulps  %ymm1, %ymm2, %ymm1
     vaddps  %ymm3, %ymm1, %ymm1
     vmovups %xmm1, (%rcx,%rax)
     vextractf128    $0x1, %ymm1, 16(%rcx,%rax)
     addq    $32, %rax
     cmpq    %r8, %r10
     ja      .L4

劇的に変わりました。 まず、全体的にvのつく命令に置き換わっています。 そう、AVX命令に置き換わってます。 AVXでは、8つの単精度浮動少数を1度に扱えるので、これで8倍速です。 ループ内は掛けたり割ったりしてますが、分岐がないのでベクトル化できちゃう。 すごい。

でももっと速くなってますね。 何ででしょう。 ボトルネックになりそうなsqrt周りを見てみましょう。 実は-O2でも、sqrtssというCPU命令を使って1発で計算しているのでそれなりに高速です。 ところが、-Ofastでは、vrsqrtpsという何やら怪しげな命令が。 rsqrtというのは、-1/2乗を計算する命令です。 そのお陰で、-Ofast側ではdivが無くなっていることに気づきます。

ところで、何でrsqrtという命令が有るかというと、1/2乗よりも-1/2乗の方が簡単だからと思われます。 これは学生実験の時に教わった気がします。 -1/2をニュートン法で計算するときに、わり算が必要ないのです。

http://www.riken.jp/brict/Ijiri/study/fastsqrt.html

加えて、 rsqrtは精度を落として計算しているため高速なようです だいたい、初期学習率が適当なので、今のケースで精度はさしていらないでしょう。 ちょっと調べてみると、sqrt 命令が8〜105サイクルなのに rsqrt 命令が驚きの1サイクルと書かれています。 ということは、sqrtよりも高効率で、しかもdivも消えちゃう。 ちなみにperfで見てみると、rsqrtは全体の5%しかCPU時間を食っておらず、その他の掛け算1回分と同じコストでした。 極めて効率的に計算できるようです。

もはやsqrtがボトルネックになる時代ではないんですね。 対象とするCPUがどのような機能を持っていて、利用するコンパイラがどのような最適化をかけてくれるのかということを把握することは大事です。 実は最初実装した時に、分母が0のときに割らないように分岐していたのですが、εを足すように変えました。 これにしないと最適化がかかりません。 そもそも、掛けたり足したりしてるし、ベクトル2つ使ってるし、sqrtもでてくるし、流石にベクトル化できないだろうと思ってたら甘かったです。 ちなみにこのとき、1回の学習に8時間掛かるという試算になっていたのですが、上記の最適化とOpenMPによる4倍速で、計50倍速、最終的に10分で終わるようになりました。 遅いなー、と思って10台のマシンで計算するよりも、コンパイラの最適化しやすいように工夫したほうが効果が高かったりするんですね。

0 件のコメント:

コメントを投稿

'},ClipboardSwf:null,Version:'1.5.1'}};dp.SyntaxHighlighter=dp.sh;dp.sh.Toolbar.Commands={ExpandSource:{label:'+ expand source',check:function(highlighter){return highlighter.collapse;},func:function(sender,highlighter) {sender.parentNode.removeChild(sender);highlighter.div.className=highlighter.div.className.replace('collapsed','');}},ViewSource:{label:'view plain',func:function(sender,highlighter) {var code=dp.sh.Utils.FixForBlogger(highlighter.originalCode).replace(/'+code+'');wnd.document.close();}},CopyToClipboard:{label:'copy to clipboard',check:function(){return window.clipboardData!=null||dp.sh.ClipboardSwf!=null;},func:function(sender,highlighter) {var code=dp.sh.Utils.FixForBlogger(highlighter.originalCode).replace(/</g,'<').replace(/>/g,'>').replace(/&/g,'&');if(window.clipboardData) {window.clipboardData.setData('text',code);} else if(dp.sh.ClipboardSwf!=null) {var flashcopier=highlighter.flashCopier;if(flashcopier==null) {flashcopier=document.createElement('div');highlighter.flashCopier=flashcopier;highlighter.div.appendChild(flashcopier);} flashcopier.innerHTML='';} alert('The code is in your clipboard now');}},PrintSource:{label:'print',func:function(sender,highlighter) {var iframe=document.createElement('IFRAME');var doc=null;iframe.style.cssText='position:absolute;width:0px;height:0px;left:-500px;top:-500px;';document.body.appendChild(iframe);doc=iframe.contentWindow.document;dp.sh.Utils.CopyStyles(doc,window.document);doc.write('

'+highlighter.div.innerHTML+'

');doc.close();iframe.contentWindow.focus();iframe.contentWindow.print();alert('Printing...');document.body.removeChild(iframe);}},About:{label:'?',func:function(highlighter) {var wnd=window.open('','_blank','dialog,width=300,height=150,scrollbars=0');var doc=wnd.document;dp.sh.Utils.CopyStyles(doc,window.document);doc.write(dp.sh.Strings.AboutDialog.replace('{V}',dp.sh.Version));doc.close();wnd.focus();}}};dp.sh.Toolbar.Create=function(highlighter) {var div=document.createElement('DIV');div.className='tools';for(var name in dp.sh.Toolbar.Commands) {var cmd=dp.sh.Toolbar.Commands[name];if(cmd.check!=null&&!cmd.check(highlighter)) continue;div.innerHTML+=''+cmd.label+'';} return div;} dp.sh.Toolbar.Command=function(name,sender) {var n=sender;while(n!=null&&n.className.indexOf('dp-highlighter')==-1) n=n.parentNode;if(n!=null) dp.sh.Toolbar.Commands[name].func(sender,n.highlighter);} dp.sh.Utils.CopyStyles=function(destDoc,sourceDoc) {var links=sourceDoc.getElementsByTagName('link');for(var i=0;i');} dp.sh.Utils.FixForBlogger=function(str) {return(dp.sh.isBloggerMode==true)?str.replace(/
|<br\s*\/?>/gi,'\n'):str;} dp.sh.RegexLib={MultiLineCComments:new RegExp('/\\*[\\s\\S]*?\\*/','gm'),SingleLineCComments:new RegExp('//.*$','gm'),SingleLinePerlComments:new RegExp('#.*$','gm'),DoubleQuotedString:new RegExp('"(?:\\.|(\\\\\\")|[^\\""\\n])*"','g'),SingleQuotedString:new RegExp("'(?:\\.|(\\\\\\')|[^\\''\\n])*'",'g')};dp.sh.Match=function(value,index,css) {this.value=value;this.index=index;this.length=value.length;this.css=css;} dp.sh.Highlighter=function() {this.noGutter=false;this.addControls=true;this.collapse=false;this.tabsToSpaces=true;this.wrapColumn=80;this.showColumns=true;} dp.sh.Highlighter.SortCallback=function(m1,m2) {if(m1.indexm2.index) return 1;else {if(m1.lengthm2.length) return 1;} return 0;} dp.sh.Highlighter.prototype.CreateElement=function(name) {var result=document.createElement(name);result.highlighter=this;return result;} dp.sh.Highlighter.prototype.GetMatches=function(regex,css) {var index=0;var match=null;while((match=regex.exec(this.code))!=null) this.matches[this.matches.length]=new dp.sh.Match(match[0],match.index,css);} dp.sh.Highlighter.prototype.AddBit=function(str,css) {if(str==null||str.length==0) return;var span=this.CreateElement('SPAN');str=str.replace(/ /g,' ');str=str.replace(/');if(css!=null) {if((/br/gi).test(str)) {var lines=str.split(' 
');for(var i=0;ic.index)&&(match.index/gi,'\n');var lines=html.split('\n');if(this.addControls==true) this.bar.appendChild(dp.sh.Toolbar.Create(this));if(this.showColumns) {var div=this.CreateElement('div');var columns=this.CreateElement('div');var showEvery=10;var i=1;while(i<=150) {if(i%showEvery==0) {div.innerHTML+=i;i+=(i+'').length;} else {div.innerHTML+='·';i++;}} columns.className='columns';columns.appendChild(div);this.bar.appendChild(columns);} for(var i=0,lineIndex=this.firstLine;i0;i++) {if(Trim(lines[i]).length==0) continue;var matches=regex.exec(lines[i]);if(matches!=null&&matches.length>0) min=Math.min(matches[0].length,min);} if(min>0) for(var i=0;i