2015年10月23日金曜日

Chainerのtype_check

だいぶ前ですが、Chainerのver. 1.1.0から、型のチェック機構が入りました。 この機能は、各Functionが呼ばれた時に、動作条件を満たしているか確認するものです。 Pythonだから型のチェックがなくて大変だと思われがちですが、実際には行列サイズに対する制約の方が多く、典型的な静的型付け言語の型システムだけでこれらを弾くのは難しいです。

最初に設計しているときから、この機能が必須だろうと思っていて、メインで作っていたのでその話を書きます。

例えば以下のnumpyのコードを考えてみましょう。 当然動きません。

>>> x = np.array([1, 2, 3])
>>> y = np.array([1, 2])
>>> z = x * y
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  ValueError: operands could not be broadcast together with shapes (3,) (2,)

要素ごとの積は、2つの引数の形が同じである必要があります。 numpyの基本的なところですが、numpyの多次元配列は shape というプロパティーを持っています。 これは、各次元のサイズをタプルにした値をです。 例えば、長さ3の1次元配列なら、 shape == (3,) です。

ベクトル計算系ライブラリで難しいのは、このように単に「 int 型でなければならない」、という様な条件というよりも、複数の引数間の関係によって条件が決まることが多いということです。 ですので既存のプログラミング言語の静的型でも、十分に条件を書ききるのは難しいのではないかなと思っています(もちろん、それくらいリッチに条件をかける型システムも存在するんだとおもいます)。

実用性や可読性などを考えると、こうした条件を宣言的に書くよりも、非常に愚直に手続き的に記述する、つまりチェックするコードを明示的に書いてしまった方が良いだろうと思いました。 その結果、Chainerのコード中では、この条件は以下のようなPythonコードで書かれています。

type_check.expect(in_types[0].shape == in_types[1].shape)

ここで、 in_types[0] というのは、入力される1番目の引数の型のことです。 同様に in_types[1] は2番目の引数の型ですね。 両者の shape が一致するという条件になるわけです。

ちなみに、上記条件のチェックに失敗すると、Chainerは以下の様なエラーメッセージをはきます。

Expect: in_types[0].shape == in_types[1].shape
Actual: (3,) != (2,)

このエラーが言わんとすることは、1番目の引数と2番めの引数の shape が一致することが期待されていたが、実際には (3,) と (2,) だったので一致しなかったということです。 実はここが一番頑張ったところでした。 Expectのところに、Pythonのコードによる条件が書かれているように見えます。 ふつうにPythonのコードを書くと、TrueかFalseしか返さないので、このような文字列を出力できません。 実は、 in_types[0].shape はタプルではなくて、抽象構文木を表すオブジェクトを返します。 このオブジェクトは __eq__ なども同じように構文木オブジェクトを返すため、 in_types[0].shape == in_types[1].shape 自体が構文木を返します。 type_check.expect は抽象構文木を受け取って、評価し、条件が成り立てば何もせず、成り立たなければ構文木の文字列表現と、評価後の値を使って、エラーメッセージを作って、例外にくるんで投げるという寸法です。 そのため、ほとんどのコードはインタプリタを書いているような状態になって、Python DSLでPythonインタプリタを書くみたいな謎な状況になりました。 結果として以下の様な条件も書けます。 ちなみに実際にこれはLSTMの型チェックに書かれているコードです(LSTMは第2引数の次元が、第1引数の4倍である必要あがある)。

expect(in_types[1].shape[1] == 4 * in_types[0].shape[1])

また、上の様なコードですので、各関数が満たすべき条件は、コードを見れば一目瞭然になっています。 個人的にPythonは、結局ソースを読む言語だと思っていて(標準ライブラリでも結局ソースを読む)、ソースをみれば何が条件なのかわかりやすくなっているのは良いなぁ、と思っています。

こうした shape に対する制約は、numpyがチェックしてくれるんだからいらないんではと思われがちです。 以下のコードを見てみましょう。 これは動きます。

>>> x = np.array([1, 2, 3])
>>> y = np.array([1])
>>> z = x * y

numpyのbroadcastという機能が働くためです。 動くんならいいじゃないかと思われるかもしれませんが、broadcastが起こっているということは、broadcast用の逆伝搬の処理をしないといけないということです。 broadcastの逆伝搬は、broadcastが起こった次元に対してsumを取る必要があります。 そのため、うっかりbroadcastが起こってしまうと逆伝搬の処理で間違ってしまうのです。 これを防ぐために、Chainerでは厳し目の条件でbroadcastが起こらないようになっています。

実はbroadcastが発生しても、 shape によっては、順伝搬と逆伝搬の両方でbroadcastが起こって一見すると動くことがありますが、概ね間違った処理になります。 実際に初期のバージョンではこの問題にだいぶ悩まされました。 一件、逆伝搬がうまくいくのですが、次のFunctionまで伝搬したとき、あるいはもっと後になってからエラーが吐かれます。 結果的にデバッグが非常に困難になりました。

とはいっても、broadcast処理をして欲しい時があります。 次の1.4.0では、明示的にfunctions.broadcast関数を呼ぶことでbroadcastできるようになりました。 この関数は逆伝搬でsumを実行するため、正しく逆伝搬するようになります。

Chainerの中で1つどうしても実現したかったのが、「順伝搬処理が正しく動けば逆伝搬は動く」という性質です。 これはもちろん(?)、OCamlの「ビルドが通ればだいたい動く」という感覚を再現したかったのです。 行列演算系のライブラリは、少なからず挙動に対する理解があやふやになるものです。 numpyのあらゆる仕様を詳細に理解している人が、この世に何人いるでしょう? その為、順伝搬のコードでもどうにかこうにか、やっと動くという事になることが多いわけですね。 その上で逆伝搬でエラーが起こると、デバッグは極めて大変、生産性は極めて落ちてしまいます。 加えて、コードは簡素でわかりやすく、エラーメッセージは可読で、困ったらソースを読めばわかる、加えて開発者は十分条件を書きやすい、そうしたものを実現したかったのでした。 型のエラーは確かに面倒ですが、意図通りに動くこと、意図がはっきりしないときは明示させること、意図しないことが起こらないことが非常に重要だと思っています。

0 件のコメント:

コメントを投稿

'},ClipboardSwf:null,Version:'1.5.1'}};dp.SyntaxHighlighter=dp.sh;dp.sh.Toolbar.Commands={ExpandSource:{label:'+ expand source',check:function(highlighter){return highlighter.collapse;},func:function(sender,highlighter) {sender.parentNode.removeChild(sender);highlighter.div.className=highlighter.div.className.replace('collapsed','');}},ViewSource:{label:'view plain',func:function(sender,highlighter) {var code=dp.sh.Utils.FixForBlogger(highlighter.originalCode).replace(/'+code+'');wnd.document.close();}},CopyToClipboard:{label:'copy to clipboard',check:function(){return window.clipboardData!=null||dp.sh.ClipboardSwf!=null;},func:function(sender,highlighter) {var code=dp.sh.Utils.FixForBlogger(highlighter.originalCode).replace(/</g,'<').replace(/>/g,'>').replace(/&/g,'&');if(window.clipboardData) {window.clipboardData.setData('text',code);} else if(dp.sh.ClipboardSwf!=null) {var flashcopier=highlighter.flashCopier;if(flashcopier==null) {flashcopier=document.createElement('div');highlighter.flashCopier=flashcopier;highlighter.div.appendChild(flashcopier);} flashcopier.innerHTML='';} alert('The code is in your clipboard now');}},PrintSource:{label:'print',func:function(sender,highlighter) {var iframe=document.createElement('IFRAME');var doc=null;iframe.style.cssText='position:absolute;width:0px;height:0px;left:-500px;top:-500px;';document.body.appendChild(iframe);doc=iframe.contentWindow.document;dp.sh.Utils.CopyStyles(doc,window.document);doc.write('

'+highlighter.div.innerHTML+'

');doc.close();iframe.contentWindow.focus();iframe.contentWindow.print();alert('Printing...');document.body.removeChild(iframe);}},About:{label:'?',func:function(highlighter) {var wnd=window.open('','_blank','dialog,width=300,height=150,scrollbars=0');var doc=wnd.document;dp.sh.Utils.CopyStyles(doc,window.document);doc.write(dp.sh.Strings.AboutDialog.replace('{V}',dp.sh.Version));doc.close();wnd.focus();}}};dp.sh.Toolbar.Create=function(highlighter) {var div=document.createElement('DIV');div.className='tools';for(var name in dp.sh.Toolbar.Commands) {var cmd=dp.sh.Toolbar.Commands[name];if(cmd.check!=null&&!cmd.check(highlighter)) continue;div.innerHTML+=''+cmd.label+'';} return div;} dp.sh.Toolbar.Command=function(name,sender) {var n=sender;while(n!=null&&n.className.indexOf('dp-highlighter')==-1) n=n.parentNode;if(n!=null) dp.sh.Toolbar.Commands[name].func(sender,n.highlighter);} dp.sh.Utils.CopyStyles=function(destDoc,sourceDoc) {var links=sourceDoc.getElementsByTagName('link');for(var i=0;i');} dp.sh.Utils.FixForBlogger=function(str) {return(dp.sh.isBloggerMode==true)?str.replace(/
|<br\s*\/?>/gi,'\n'):str;} dp.sh.RegexLib={MultiLineCComments:new RegExp('/\\*[\\s\\S]*?\\*/','gm'),SingleLineCComments:new RegExp('//.*$','gm'),SingleLinePerlComments:new RegExp('#.*$','gm'),DoubleQuotedString:new RegExp('"(?:\\.|(\\\\\\")|[^\\""\\n])*"','g'),SingleQuotedString:new RegExp("'(?:\\.|(\\\\\\')|[^\\''\\n])*'",'g')};dp.sh.Match=function(value,index,css) {this.value=value;this.index=index;this.length=value.length;this.css=css;} dp.sh.Highlighter=function() {this.noGutter=false;this.addControls=true;this.collapse=false;this.tabsToSpaces=true;this.wrapColumn=80;this.showColumns=true;} dp.sh.Highlighter.SortCallback=function(m1,m2) {if(m1.indexm2.index) return 1;else {if(m1.lengthm2.length) return 1;} return 0;} dp.sh.Highlighter.prototype.CreateElement=function(name) {var result=document.createElement(name);result.highlighter=this;return result;} dp.sh.Highlighter.prototype.GetMatches=function(regex,css) {var index=0;var match=null;while((match=regex.exec(this.code))!=null) this.matches[this.matches.length]=new dp.sh.Match(match[0],match.index,css);} dp.sh.Highlighter.prototype.AddBit=function(str,css) {if(str==null||str.length==0) return;var span=this.CreateElement('SPAN');str=str.replace(/ /g,' ');str=str.replace(/');if(css!=null) {if((/br/gi).test(str)) {var lines=str.split(' 
');for(var i=0;ic.index)&&(match.index/gi,'\n');var lines=html.split('\n');if(this.addControls==true) this.bar.appendChild(dp.sh.Toolbar.Create(this));if(this.showColumns) {var div=this.CreateElement('div');var columns=this.CreateElement('div');var showEvery=10;var i=1;while(i<=150) {if(i%showEvery==0) {div.innerHTML+=i;i+=(i+'').length;} else {div.innerHTML+='·';i++;}} columns.className='columns';columns.appendChild(div);this.bar.appendChild(columns);} for(var i=0,lineIndex=this.firstLine;i0;i++) {if(Trim(lines[i]).length==0) continue;var matches=regex.exec(lines[i]);if(matches!=null&&matches.length>0) min=Math.min(matches[0].length,min);} if(min>0) for(var i=0;i