機械学習ライブラリPyBrainでニューラルネットワークを作ってS3などで永続化したいときの話。
pickleを使う
pythonの標準の直列化モジュールpickle
を使うと、一見ニューラルネットワークが保存・復元され、ちゃんとpickle.load()
したオブジェクトに対してactivate()
もできる。
しかし、BackpropTrainer.train()
などによる学習が再開されないという問題に遭遇してハマった。pickleの詳細な挙動もわかっておらず、原因未調査。
NetworkWriterを使う
PyBrainのユーティリティとして提供されるNetworkWriter
を使うと上の問題が解消され、学習が再開される。使い方も簡単なのでこれでよい気がする。
インポート
from pybrain.tools.customxml import NetworkWriter, NetworkReader
書き出し
NetworkWriter.writeToFile(network, filename_local)
読み込み
network = NetworkReader.readFrom(filename_local)
ファイルはXMLで保存され、ニューラルやってるのがなんとなくわかるしpickle
のフォーマットよりなんとなく安心感がある。
<?xml version="1.0" ?>
<PyBrain>
<Network class="pybrain.structure.networks.feedforward.FeedForwardNetwork" name="FeedForwardNetwork-8">
<name val="u'FeedForwardNetwork-8'"/>
<Modules>
<LinearLayer class="pybrain.structure.modules.linearlayer.LinearLayer" inmodule="True" name="in">
<dim val="8"/>
<name val="'in'"/>
</LinearLayer>
<LinearLayer class="pybrain.structure.modules.linearlayer.LinearLayer" name="out" outmodule="True">
<dim val="1"/>
<name val="'out'"/>
</LinearLayer>
<BiasUnit class="pybrain.structure.modules.biasunit.BiasUnit" name="bias">
<name val="'bias'"/>
</BiasUnit>
<SigmoidLayer class="pybrain.structure.modules.sigmoidlayer.SigmoidLayer" name="hidden0">
<dim val="3"/>
<name val="'hidden0'"/>
</SigmoidLayer>
</Modules>
<Connections>
<FullConnection class="pybrain.structure.connections.full.FullConnection" name="FullConnection-6">
<inmod val="bias"/>
<outmod val="out"/>
<Parameters>[0.6554487520957738]</Parameters>
</FullConnection>
<FullConnection class="pybrain.structure.connections.full.FullConnection" name="FullConnection-7">
<inmod val="bias"/>
<outmod val="hidden0"/>
<Parameters>[0.8141201069100833, -1.9519540651889176, 0.3483014480876096]</Parameters>
</FullConnection>
<FullConnection class="pybrain.structure.connections.full.FullConnection" name="FullConnection-5">
<inmod val="in"/>
<outmod val="hidden0"/>
<Parameters>[0.32489279837910084, 0.34480786433574551, 0.75045803824057666, -0.58411948692771176, -0.12327324616272992, -0.41228675759787226, -0.85553671683893218, -1.3320521945223582, -1.0531422952418676, -0.40839301403900452, -2.7565756871565674, -1.6723188687051469, -1.3597994054921079, 0.24852450267525059, -0.40924881241151689, 0.54037857219934371, 1.0960673042273468, 1.3324258379470664, 0.29047259837334116, -0.022417631256966383, 0.44571376571760984, 0.6492450404233816, -0.29105564158278247, 1.2243353023571237]</Parameters>
</FullConnection>
<FullConnection class="pybrain.structure.connections.full.FullConnection" name="FullConnection-4">
<inmod val="hidden0"/>
<outmod val="out"/>
<Parameters>[0.25616738836523284, -2.2028123481048487, -0.11026881677981226]</Parameters>
</FullConnection>
</Connections>
</Network>
</PyBrain>