のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

Chainerの学習の様子をリモートで確認するExtensionを作った

Sponsored Links

皆さんこんにちは
お元気ですか。私はGWでリフレッシュして、生き返りました。

Kaggleをやっているとき(特に画像などの長い場合)にリモートで
今学習されているかどうか、誤差はどうかなどのモデルの
様子が気になることはありませんか?

私は画像認識系のコンペを実際に行っている時に、気になることがあります。
これどうしようかと考えていたのですが、歩いている時にふと思いついたので実装しました。
このアイデアの実装のために、新しいChainerのExtensionを開発しました。(Trainerを使う想定です)

アイデア

Slackであれば外出中も見れると考えました。
そのため、学習の途中経過(lossなど)を投稿すれば見れる!
実装イメージは次の図に掲載しました。

f:id:tereka:20170509235100j:plain

コードを見た限りだと、Extensionで実装できそうだったので、トライしました。

Extensionの実装方法

Extensionの実装ですが、先に必要な情報を__init__に実装します。
そして、__call__が呼び出され、処理をする仕組みとなっています。

今回は既に実装されているPrintReportを参考に実装します。

前準備

SlackのWeb APIのIncoming Webhooksと投稿するChannelとusernameが必要です。
SlackのAPIに必要な情報は以下のurlから遷移して取得してください。

api.slack.com

コード

早速Extensionを実装しました。
表示情報は既に実装されているPrintReportと同じにしました。
_throw_slackにSlackに投稿する部分を定義しています。
__init__に初期設定で必要な情報、__call__に表示する部分を記載しました。

# coding:utf-8
import os
import sys
from chainer.training import extension
from chainer.training.extensions import log_report as log_report_module
import requests
import json


class SlackReport(extension.Extension):
    def __init__(self, entries, log_report='LogReport', username="", url="",channel="",out=sys.stdout):
        self._entries = entries
        self._log_report = log_report
        self._log_len = 0  # number of observations already printed
        self._out = out

        # format information
        entry_widths = [max(10, len(s)) for s in entries]

        header = '  '.join(('{:%d}' % w for w in entry_widths)).format(
            *entries) + '\n'
        self._header = header  # printed at the first call

        templates = []
        for entry, w in zip(entries, entry_widths):
            templates.append((entry, '{:<%dg}  ' % w, '  ' * (w + 2)))
        self._templates = templates
        self.username = username
        self.url = url
        self.channel = channel

    def __call__(self, trainer):
        if self._header:
            self._throw_slack(self._header)
            self._header = None

        log_report = self._log_report
        if isinstance(log_report, str):
            log_report = trainer.get_extension(log_report)
        elif isinstance(log_report, log_report_module.LogReport):
            log_report(trainer)  # update the log report
        else:
            raise TypeError('log report has a wrong type %s' %
                            type(log_report))

        log = log_report.log
        log_len = self._log_len
        while len(log) > log_len:
            self._observation_throw_slack(log[log_len])
            log_len += 1
        self._log_len = log_len

    def _throw_slack(self, text):
        try:
            payload_dic = {
                "text": text,
                "username": self.username,
                "channel": self.channel,
            }
            requests.post(self.url, data=json.dumps(payload_dic))
        except:
            self._out.write("error!")

    def serialize(self, serializer):
        log_report = self._log_report
        if isinstance(log_report, log_report_module.LogReport):
            log_report.serialize(serializer['_log_report'])

    def _observation_throw_slack(self, observation):
        text = ""
        for entry, template, empty in self._templates:
            if entry in observation:
                text += template.format(observation[entry])
            else:
                text += empty
        self._throw_slack(text)

使い方

APIに必要な情報の定義とExtensionにSlackReportを追加するのみです。
例えば、公式のMNISTサンプルでは次のExtensionを追記してください。

username,channel,urlはそれぞれ自分の設定を指定してください。

    trainer.extend(SlackReport(['epoch', 'main/loss', 'validation/main/loss',
                                'main/accuracy', 'validation/main/accuracy', 'elapsed_time'],
                               username="YOUR USER NAME",
                               channel="YOUR SLACK CHANNEL",
                               url="SLACK API URL",
                               ))

Slackへの投稿結果

ExampleのMNISTで実験したら、こんな感じになります。ちょっとガタガタなのはご愛嬌。
外出時に気になる場合は皆さんも使ってみましょう。

設定が間違っているときの処理は…、結構適当です。

f:id:tereka:20170509233656p:plain