無料で使えるシステムトレードフレームワーク「Jiji」 をリリースしました!

・OANDA Trade APIを利用した、オープンソースのシステムトレードフレームワークです。
・自分だけの取引アルゴリズムで、誰でも、いますぐ、かんたんに、自動取引を開始できます。

全銘柄、全期間の日ごとの各種平均値、標準偏差を計算するSQL

stockdb のデータをもとに、全銘柄、全期間の日ごとの各種平均値、標準偏差を計算するSQLを書いたのでメモ。

  • 終値、出来高、価格変動率のそれぞれについて、過去3,5,10,15,25,50,75日の平均値と標準偏差を一気に算出します。
  • それなりに時間はかかります。確認した環境だと1時間くらいでした。
  • Window関数すごい。
-- 市場が開いている日の一覧
CREATE MATERIALIZED VIEW days as (
  WITH ds AS (
    SELECT distinct date FROM rates
  ), x AS (
    SELECT
      date,
      row_number() OVER (ORDER BY date) AS index,
      lag(date,  1, null) OVER (ORDER BY date) AS prev,
      lag(date,  3, null) OVER (ORDER BY date) AS before_3_days,
      lag(date,  5, null) OVER (ORDER BY date) AS before_5_days,
      lag(date, 10, null) OVER (ORDER BY date) AS before_10_days,
      lag(date, 15, null) OVER (ORDER BY date) AS before_15_days,
      lag(date, 25, null) OVER (ORDER BY date) AS before_25_days,
      lag(date, 50, null) OVER (ORDER BY date) AS before_50_days,
      lag(date, 75, null) OVER (ORDER BY date) AS before_75_days
    FROM ds ORDER BY date desc
  )
  SELECT * FROM x WHERE prev IS NOT NULL
);


-- 出来高0(volume=0)のデータが抜けているのでそれをを補完したビューを作る
CREATE MATERIALIZED VIEW rates_filled AS (
  WITH all_stock_and_days as (
    SELECT d.*, s.id as stock_id FROM days as d, stocks as s
  ), x as (
    SELECT
      a.stock_id, a.date,
      CASE WHEN r.date IS NOT NULL THEN r.date
           ELSE (SELECT max(date) FROM rates WHERE stock_id = a.stock_id and date <= a.date )
      END  as actual
    FROM all_stock_and_days AS a
    LEFT JOIN rates as r ON a.stock_id = r.stock_id and a.date = r.date
  )
  SELECT
    x.stock_id, x.date,
    CASE WHEN x.actual = x.date THEN r.open   ELSE r.close END as open,
    CASE WHEN x.actual = x.date THEN r.close  ELSE r.close END as close,
    CASE WHEN x.actual = x.date THEN r.high   ELSE r.close END as high,
    CASE WHEN x.actual = x.date THEN r.low    ELSE r.close END as low,
    CASE WHEN x.actual = x.date THEN r.volume ELSE 0 END as volume
  FROM x
  LEFT JOIN rates as r ON x.stock_id = r.stock_id and x.actual = r.date
);
CREATE UNIQUE INDEX rates_filled_stock_id_date_index
  ON rates_filled (stock_id, date);


-- 前日からの価格変動率を計算
CREATE MATERIALIZED VIEW ratios AS (
  WITH x AS (
    SELECT d.*, s.id FROM days AS d, stocks AS s
  )
  SELECT
    x.id as stock_id, x.date,
    (r1.close-r2.close)/r2.close as ratio
  FROM x
  LEFT JOIN rates_filled as r1 on x.id = r1.stock_id AND r1.date = x.date
  LEFT JOIN rates_filled as r2 on x.id = r2.stock_id AND r2.date = x.prev
  WHERE r2.close IS NOT NULL
);
CREATE UNIQUE INDEX ratios_stock_id_date_index
  ON ratios (stock_id, date);


-- 全銘柄、全期間の日ごとの各種平均値、標準偏差をを計算
-- 終値、出来高、価格変動率のそれぞれについて、過去3,5,10,15,25,50,75日の平均値と標準偏差を算出する

CREATE MATERIALIZED VIEW ma AS (
SELECT
  r.stock_id,
  r.date,
  r.volume,

  avg(r.close)    OVER from_3_days_ago  as avg_close_3days,
  stddev(r.close) OVER from_3_days_ago  as sd_close_3days,
  avg(r.close)    OVER from_5_days_ago  as avg_close_5days,
  stddev(r.close) OVER from_5_days_ago  as sd_close_5days,
  avg(r.close)    OVER from_10_days_ago as avg_close_10days,
  stddev(r.close) OVER from_10_days_ago as sd_close_10days,
  avg(r.close)    OVER from_15_days_ago as avg_close_15days,
  stddev(r.close) OVER from_15_days_ago as sd_close_15days,
  avg(r.close)    OVER from_25_days_ago as avg_close_25days,
  stddev(r.close) OVER from_25_days_ago as sd_close_25days,
  avg(r.close)    OVER from_50_days_ago as avg_close_50days,
  stddev(r.close) OVER from_50_days_ago as sd_close_50days,
  avg(r.close)    OVER from_75_days_ago as avg_close_75days,
  stddev(r.close) OVER from_75_days_ago as sd_close_75days,

  avg(ra.ratio)    OVER from_3_days_ago  as avg_ratio_3days,
  stddev(ra.ratio) OVER from_3_days_ago  as sd_ratio_3days,
  avg(ra.ratio)    OVER from_5_days_ago  as avg_ratio_5days,
  stddev(ra.ratio) OVER from_5_days_ago  as sd_ratio_5days,
  avg(ra.ratio)    OVER from_10_days_ago as avg_ratio_10days,
  stddev(ra.ratio) OVER from_10_days_ago as sd_ratio_10days,
  avg(ra.ratio)    OVER from_15_days_ago as avg_ratio_15days,
  stddev(ra.ratio) OVER from_15_days_ago as sd_ratio_15days,
  avg(ra.ratio)    OVER from_25_days_ago as avg_ratio_25days,
  stddev(ra.ratio) OVER from_25_days_ago as sd_ratio_25days,
  avg(ra.ratio)    OVER from_50_days_ago as avg_ratio_50days,
  stddev(ra.ratio) OVER from_50_days_ago as sd_ratio_50days,
  avg(ra.ratio)    OVER from_75_days_ago as avg_ratio_75days,
  stddev(ra.ratio) OVER from_75_days_ago as sd_ratio_75days,

  avg(r.volume)    OVER from_3_days_ago  as avg_volume_3days,
  stddev(r.volume) OVER from_3_days_ago  as sd_volume_3days,
  avg(r.volume)    OVER from_5_days_ago  as avg_volume_5days,
  stddev(r.volume) OVER from_5_days_ago  as sd_volume_5days,
  avg(r.volume)    OVER from_10_days_ago as avg_volume_10days,
  stddev(r.volume) OVER from_10_days_ago as sd_volume_10days,
  avg(r.volume)    OVER from_15_days_ago as avg_volume_15days,
  stddev(r.volume) OVER from_15_days_ago as sd_volume_15days,
  avg(r.volume)    OVER from_25_days_ago as avg_volume_25days,
  stddev(r.volume) OVER from_25_days_ago as sd_volume_25days,
  avg(r.volume)    OVER from_50_days_ago as avg_volume_50days,
  stddev(r.volume) OVER from_50_days_ago as sd_volume_50days,
  avg(r.volume)    OVER from_75_days_ago as avg_volume_75days,
  stddev(r.volume) OVER from_75_days_ago as sd_volume_75days

FROM rates_filled as r
LEFT JOIN ratios AS ra ON ra.stock_id = r.stock_id AND ra.date = r.date
WINDOW from_3_days_ago  AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND  2 FOLLOWING ),
       from_5_days_ago  AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND  4 FOLLOWING ),
       from_10_days_ago AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND  9 FOLLOWING ),
       from_15_days_ago AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND 14 FOLLOWING ),
       from_25_days_ago AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND 24 FOLLOWING ),
       from_50_days_ago AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND 49 FOLLOWING ),
       from_75_days_ago AS ( PARTITION BY r.stock_id ORDER BY r.stock_id, r.date desc ROWS BETWEEN CURRENT ROW AND 74 FOLLOWING )
order by r.date desc
);
CREATE UNIQUE INDEX ma_stock_id_date_index
  ON ma(stock_id, date);

日本株の日足データをローカルのデータベースに取り込むツールを作った

日本株の日足データをローカルのデータベースに取り込むツールを作ってみました。

github.com

  • Quandl で公開されている日本株の日足データを取得して、ローカルのPostgreSQLに取り込みます。
  • Tokyo Stock Exchangeデータベースの全データを取り込むので、ETFのデータなども含まれます。
  • とりあえず、直近500日分のデータを取得するようにしています。 市場が開いている日のデータのみなので、大体2年分です。
  • 未インポートのデータのみ取り込むので、取りこぼしやエラーがあっても再実行すればOK。
  • 毎日cronで実行すれば、最新の日足データが使える状態になるはずです。
  • DBのテーブル構成は以下。

f:id:unageanu:20160322123029p:plain

事前準備

$ git --version
git version 1.8.3.1
$ docker -v
Docker version 1.10.2, build c3959b1
$ docker-compose -v
docker-compose version 1.6.2, build 4d72027

使い方

$ git clone https://github.com/unageanu/stock-db.git
$ cd stock-db
$ vi .env # POSTGRES_PASSWORD、QUANDL_API_KEYを設定します。
          # 以下は設定例
---
POSTGRES_USER=postgres
POSTGRES_PASSWORD=mysecretpassword
QUANDL_API_KEY=myquandlapikey
QUANDL_API_VERSION=2015-04-09
---
$ docker-compose up -d # PostgreSQLã‚’èµ·å‹•
$ bundle install
$ bundle exec ruby -I src ./src/importer.rb

QuandlのAPIキー取得はこちらを参照。 銘柄は4000弱あるので、取り込みには2,3時間かかります。

なお、いろいろ雑に書いているので、ご了承ください。

Clairをインストールして、Dockerイメージの脆弱性スキャンをする手順

Dockerイメージの脆弱性スキャンツール「Clair」 をインストールして、ローカルのイメージをチェックする手順です。微妙にはまったのでメモ。

0.環境

$ cat /etc/redhat-release
CentOS Linux release 7.2.1511 (Core) 
$ docker -v
Docker version 1.10.2, build c3959b1

1.PostgreSQLをインストールして起動

$ docker pull postgres:latest 
$ docker run --name postgres -p 5432:5432 -e POSTGRES_PASSWORD=<パスワード> -d postgres

2.clairをインストールして起動

$ mkdir ./clair_config
$ curl -L https://raw.githubusercontent.com/coreos/clair/master/config.example.yaml -o ./clair_config/config.yaml
$ vi ./clair_config/config.yaml
# database - source を以下の通り変更
---
database:
  # PostgreSQL Connection string
  # http://www.postgresql.org/docs/9.4/static/libpq-connect.html
  source: postgresql://postgres:<パスワード>@postgres:5432?sslmode=disable
---
$ docker run -p 6060-6061:6060-6061 --link postgres:postgres -v /tmp:/tmp -v $PWD/clair_config:/config quay.io/coreos/clair -config=/config/config.yaml

脆弱性データの読み込みが開始されるので、"updater: update finished" が表示されるまで待ちます。(1時間くらいかかりました・・・)

3.ローカルチェックツールをインストールして実行。

$ sudo yum -y install golang
$ export GOPATH=~/.go
$ go get -u github.com/coreos/clair/contrib/analyze-local-images
$ docker pull <チェックしたいイメージ>
$ sudo $GOPATH/bin/analyze-local-images <チェックしたいイメージ>

継続的に実行するにはどうしたらいいんだろう・・・。

追記: docker-compose.yml

docker-compose.yml も作ってみました。

version: '2'
services:
  postgres:
    container_name: clair_postgres
    image: postgres:latest
    environment:
      POSTGRES_PASSWORD: <パスワード>
    ports: 
      - "5432:5432"
    volumes:
      - ./data:/var/lib/postgresql/data

  clair:
    container_name: clair_clair
    image: quay.io/coreos/clair
    ports:
      - "6060-6061:6060-6061"
    links:
      - postgres 
    volumes:
      - /tmp:/tmp
      - ./config:/config
    command: [-config, /config/config.yaml]

GitHub の Dockerfile から Docker Image を自動ビルドする設定手順

Docker Hub の Automated Builds を使うと、GitHub または Bitbucket の Dockerfile の変更を検知して、Docker Image を自動ビルドすることができます。

  • レポジトリへのコミットを検知して、Docker Hub の Docker Image を自動でビルドします。
  • タグやブランチの追加を検出して、自動でタグ・ブランチ名をTAGとした Docker Image を作ることも可能。

一通り設定を試してみたので、メモです。

  • unagenau/docker-jiji2 に配置している Dockerfile を自動ビルドするようにしてみました。
  • Docker Hub のアカウントがない場合は、こちらから作成してください。

1. Docker Hub のアカウント と GitHubアカウントとを連携させる

まず初めに、Docker Hub のアカウント と GitHubアカウントとを連携させる必要があります。

Docker Hubにログインして、右上のメニューから Settings を選択します。

f:id:unageanu:20160207194520p:plain

続いて、 上のタブから Linked Accounts & Services を選択。

f:id:unageanu:20160207194521p:plain

GitHub を選択。

f:id:unageanu:20160207194522p:plain

リンク方法を選択します。 自動でhookを設定してくれるらしいので、 Public and Private を選択しました。

f:id:unageanu:20160207194523p:plain

連携確認の画面が表示されるので、連携を許可して完了。

2. 自動ビルドを行うレポジトリを作る

GitHubアカウントとの連携ができたら、自動ビルドを行うレポジトリを作ります。

右上のメニューから Create - Create Automated Build を選択。

f:id:unageanu:20160207194524p:plain

GitHub をクリック。

f:id:unageanu:20160207194525p:plain

Dockerfile をホストしているレポジトリを選択します。

f:id:unageanu:20160207194526p:plain

もろもろ設定します。

  • Short Description に説明を入力
  • Click here to customize をクリックすると、詳細設定画面が開きます。
  • Dockerfile をルートディレクトリに置いていなかったので、詳細画面の Dockerfile Locationで指定しています。
  • また、タグの追加を検出して、タグ名付のイメージを作成する設定も行ってみました。

f:id:unageanu:20160207194527p:plain

設定が完了後、Create をクリックすると、レポジトリが作成されます。

3. イメージをビルドしてみる

レポジトリができたら、GitHubレポジトリに変更をコミットする or タグを追加すると、自動でビルドが行われます。 また、UIから手動でビルドを実行することもできます。手動ビルドしたい場合は、Build Settings の Trigger ボタンをクリックすればOK。

f:id:unageanu:20160207194528p:plain

Build Details タブで、ビルドタスクの実行状況を確認できます。

f:id:unageanu:20160207194529p:plain

一旦 Quque に積まれた後、 10分程度待つとビルドされました。

レンジブレイク手法でのトレードをアシストするBotのサンプル

FXシステムトレードフレームワーク「Jiji」 のサンプルその3。
レンジブレイク手法を使ったトレードをアシストするBotを作ってみました。

FX Wroks さんのサイト に掲載されていた「レンジブレイクを狙うシンプルな順張り」手法を、そのままJijiに移植してみたものです。

動作

以下のような動作をします。

f:id:unageanu:20160123200914p:plain

  • 1) Botがレートを監視し、レンジブレイクをチェック
    • 条件は、サイトの内容と同等、8時間レートが100pips内で推移したあと、上or下に抜ける、としました。
    • 待つ期間やpipsは、パラメータで調整できるようにしています。
  • 2) レンジブレイクを検出したら、スマホに通知を送信します
    • ダマしが多いので、今回は通知を送って判断する形に。
  • 3) 通知を受けて最終判断を行い、トレードを実行。
    • 通知にあるボタンを押すことで、売or買で成行注文を実行できるようにしています。
    • 決済は、トレーリングストップで。

軽く動かしてみた感想

軽くテストしてみましたが、思ったよりもダマしに引っかかる感じですね。

f:id:unageanu:20160123200915p:plain

これは、まぁまぁ。

f:id:unageanu:20160123200917p:plain

これは、ブレイクと判定された時点で下げが終わっている・・。

f:id:unageanu:20160123200916p:plain

これは、一度上にブレイクしたあと、逆方向に進んでいます・・・。

ブレイクの条件を調整してみる、移動平均でのトレンドチェックと組み合わせるなど、カスタマイズして使ってみてください。

コード

# === レンジブレイクでトレードを行うエージェント
class RangeBreakAgent

  include Jiji::Model::Agents::Agent

  def self.description
    <<-STR
レンジブレイクでトレードを行うエージェント。
 - 指定期間(デフォルトは8時間)のレートが一定のpipsに収まっている状態から、
   レンジを抜けたタイミングで通知を送信。
 - 通知からトレード可否を判断し、取引を実行できます。
 - 決済はトレーリングストップで行います。
    STR
  end

  # UIから設定可能なプロパティの一覧
  def self.property_infos
    [
      Property.new('target_pair',  '対象とする通貨ペア',      'USDJPY'),
      Property.new('range_period', 'レンジを判定する期間(分)',   60 * 8),
      Property.new('range_pips',    'レンジ相場とみなす値幅(pips)', 100),
      Property.new('trailing_stop_pips',
        'トレールストップで決済する値幅(pips)',                       30),
      Property.new('trade_units',   '取引数量',                      1)
    ]
  end

  def post_create
    pair = broker.pairs.find { |p| p.name == @target_pair.to_sym }
    @checker = RangeBreakChecker.new(
      pair, @range_period.to_i, @range_pips.to_i)
  end

  def next_tick(tick)
    # レンジブレイクしたかどうかチェック
    result = @checker.check_range_break(tick)
    # ブレイクしていたら通知を送る
    send_notification(result) if result[:state] != :no
  end

  def execute_action(action)
    case action
    when 'range_break_buy'  then buy
    when 'range_break_sell' then sell
    else '不明なアクションです'
    end
  end

  def state
    { checker: @checker.state }
  end

  def restore_state(state)
    @checker.restore_state(state[:checker]) if state[:checker]
  end

  private

  def sell
    broker.sell(@target_pair.to_sym, @trade_units.to_i, :market, {
      trailing_stop: @trailing_stop_pips.to_i
    })
    '売注文を実行しました'
  end

  def buy
    broker.buy(@target_pair.to_sym, @trade_units.to_i, :market, {
      trailing_stop: @trailing_stop_pips.to_i
    })
    '買注文を実行しました'
  end

  def send_notification(result)
    message = "#{@target_pair} #{result[:price]}" \
      + ' がレンジブレイクしました。取引しますか?'
    @notifier.push_notification(message, [create_action(result)])
    logger.info "#{message} #{result[:state]} #{result[:time]}"
  end

  def create_action(result)
    if result[:state] == :break_high
      { 'label'  => '買注文を実行', 'action' => 'range_break_buy' }
    else
      { 'label'  => '売注文を実行', 'action' => 'range_break_sell' }
    end
  end

end

class RangeBreakChecker

  def initialize(pair, period, range_pips)
    @pair       = pair
    @range_pips = range_pips
    @candles    = Candles.new(period * 60)
  end

  def check_range_break(tick)
    tick_value = tick[@pair.name]
    result = check_state(tick_value, tick.timestamp)
    @candles.reset unless result == :no
    # 一度ブレイクしたら、一旦状態をリセットして次のブレイクを待つ
    @candles.update(tick_value, tick.timestamp)
    {
      state: result,
      price: tick_value.bid,
      time:  tick.timestamp
    }
  end

  def state
    @candles.state
  end

  def restore_state(state)
    @candles.restore_state(state)
  end

  private

  # レンジブレイクしているかどうか判定する
  def check_state(tick_value, time)
    highest = @candles.highest
    lowest  = @candles.lowest
    return :no if highest.nil? || lowest.nil?
    return :no unless over_period?(time)

    diff = highest - lowest
    return :no if diff >= @range_pips * @pair.pip
    calculate_state( tick_value, highest, diff )
  end

  def calculate_state( tick_value, highest, diff )
    center = highest - diff / 2
    pips = @range_pips / 2 * @pair.pip
    if tick_value.bid >= center + pips
      return :break_high
    elsif tick_value.bid <= center - pips
      return :break_low
    end
    :no
  end

  def over_period?(time)
    oldest_time = @candles.oldest_time
    return false unless oldest_time
    (time.to_i - oldest_time.to_i) >= @candles.period
  end

end

class Candles

  attr_reader :period

  def initialize(period)
    @candles     = []
    @period      = period
    @next_update = nil
  end

  def update(tick_value, time)
    time = Candles.normalize_time(time)
    if @next_update.nil? || time > @next_update
      new_candle(tick_value, time)
    else
      @candles.last.update(tick_value, time)
    end
  end

  def highest
    high = @candles.max_by { |c| c.high }
    high.nil? ? nil : BigDecimal.new(high.high, 10)
  end

  def lowest
    low = @candles.min_by { |c| c.low }
    low.nil? ? nil : BigDecimal.new(low.low, 10)
  end

  def oldest_time
    oldest = @candles.min_by { |c| c.time }
    oldest.nil? ? nil : oldest.time
  end

  def reset
    @candles     = []
    @next_update = nil
  end

  def new_candle(tick_value, time)
    limit = time - period
    @candles = @candles.reject { |c| c.time < limit }

    @candles << Candle.new
    @candles.last.update(tick_value, time)

    @next_update = time + (60 * 5)
  end

  def state
    {
      candles:     @candles.map { |c| c.to_h },
      next_update: @next_update
    }
  end

  def restore_state(state)
    @candles = state[:candles].map { |s| Candle.from_h(s) }
    @next_update = state[:next_update]
  end

  def self.normalize_time(time)
    Time.at((time.to_i / (60 * 5)).floor * 60 * 5)
  end

end

class Candle

  attr_reader :high, :low, :time

  def initialize(high = nil, low = nil, time = nil)
    @high = high
    @low  = low
    @time = time
  end

  def update(tick_value, time)
    price = extract_price(tick_value)
    @high = price if @high.nil? || @high < price
    @low  = price if @low.nil?  || @low > price
    @time = time  if @time.nil?
  end

  def to_h
    { high: @high, low: @low, time: @time }
  end

  def self.from_h(hash)
    Candle.new(hash[:high], hash[:low], hash[:time])
  end

  private

  def extract_price(tick_value)
    tick_value.bid
  end

end

機械学習手習い: サポートベクターマシン

「入門 機械学習」手習い、12日目。「12章 モデル比較」です。

www.amazon.co.jp

最後のアルゴリズム、サポートベクターマシン(SVM)を学び、最後に同じデータセットにロジスティック回帰やk近傍法など、今まで学んできたアルゴリズムを適用して比較します。

# 前準備
> setwd("12-Model_Comparison/")

サポートベクターマシン(SVM)

サポートベクターマシンは、分類モデルの一つで、ロジスティック回帰と違い、非線形の決定境界を持つデータもうまく分類することができます。

例えば、以下のようなデータ。決定境界が1つの線で表せないため、ロジスティック回帰ではうまく分類できません。

> library('ggplot2')

# データを読み込み
> df <- read.csv(file.path('data', 'df.csv'))
> head(df)
          X          Y Label
1 0.2655087 0.52601906     1
2 0.3721239 0.07333542     1
3 0.5728534 0.84974175     1
4 0.9082078 0.42305801     0

> ggplot(df, aes(x = X, y = Y, color = factor(Label))) + geom_point()
> ggsave(filename="plot01.png")

f:id:unageanu:20160121150757p:plain

試しに、ロジスティック回帰を使って、任意の点がどのラベルに属するか判定してみます。

> logit.fit <- glm(Label ~ X + Y,
    family = binomial(link = 'logit'), data = df)

> logit.predictions <- ifelse(predict(logit.fit) > 0, 1, 0)
> mean(with(df, logit.predictions == Label)) 
[1] 0.5156

精度は51.5%。クラスは2つかないので、ランダムに選んだ場合とほぼ変わらない精度です。

次に、SVMで分類した場合の精度も計測してみます。

> library('e1071')

> svm.fit <- svm(Label ~ X + Y, data = df)
> svm.predictions <- ifelse(predict(svm.fit) > 0, 1, 0)
> mean(with(df, svm.predictions == Label))
[1] 0.7204

72%になりました。ロジスティック回帰よりうまく分類できているようです。

各モデルがどのような判定を行っているか、図示してみます。

> library("reshape")
> df <- cbind(df, data.frame(Logit = ifelse(predict(logit.fit) > 0, 1, 0),
  SVM = ifelse(predict(svm.fit) > 0, 1, 0)))

> predictions <- melt(df, id.vars = c('X', 'Y'))

> ggplot(predictions, aes(x = X, y = Y, color = factor(value))) +
  geom_point() + facet_grid(variable ~ .)
> ggsave(filename="plot02.png")

f:id:unageanu:20160121150758p:plain

上から、分類対象のデータ、ロジスティック回帰で各点を分類した結果、SVMで各点を分類した結果、です。

ロジスティック回帰は全部0にしている・・・。一方、SVMだと元データの特徴を完全ではないですがうまく捉えているようです。

カーネルトリック

SVMでは、カーネルトリックという手法を使って、非線形の決定境界を生成できるようになっています。 カーネルトリックは、数学的な変換で元のデータセットを新しい空間に移動することで、決定境界を線形でも記述しやすくすします。

何種類かのカーネルで分類を試し、結果を比較してみます。 svm 関数では、引数で利用するカーネルを変更できるので、それを使います。

> df <- df[, c('X', 'Y', 'Label')]

# 線形カーネル
> linear.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'linear')

# 多項式カーネル
> polynomial.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'polynomial')

# ガウスカーネル
> radial.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'radial')

# シグモイドカーネル
> sigmoid.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'sigmoid')

> df <- cbind(df,
            data.frame(LinearSVM = ifelse(predict(linear.svm.fit) > 0, 1, 0),
                       PolynomialSVM = ifelse(predict(polynomial.svm.fit) > 0, 1, 0),
                       RadialSVM = ifelse(predict(radial.svm.fit) > 0, 1, 0),
                       SigmoidSVM = ifelse(predict(sigmoid.svm.fit) > 0, 1, 0)))

> predictions <- melt(df, id.vars = c('X', 'Y'))

> ggplot(predictions, aes(x = X, y = Y, color = factor(value))) +
  geom_point() + facet_grid(variable ~ .)
> ggsave(filename="plot03.png")

f:id:unageanu:20160121150759p:plain

線形カーネル、多項式カーネルはロジスティック回帰と同様、うまく分類できていない感じ。 ガウスカーネルは、正解に近い境界を生成できています。

今度は、パラメータを変えたパターンを試してみます。 多項式カーネルでは、次数をパラメータとして指定できるので、3,5,10,12のパターンで判定を行ってみます。

# 次数を変更したパターン
> polynomial.degree3.svm.fit <- svm(
  Label ~ X + Y, data = df, kernel = 'polynomial', degree = 3)
> polynomial.degree5.svm.fit <- svm(
  Label ~ X + Y, data = df, kernel = 'polynomial', degree = 5)
> polynomial.degree10.svm.fit <- svm(
  Label ~ X + Y, data = df, kernel = 'polynomial', degree = 10)
> polynomial.degree12.svm.fit <- svm(
  Label ~ X + Y, data = df, kernel = 'polynomial', degree = 12)

> df <- df[, c('X', 'Y', 'Label')]
> df <- cbind(df, data.frame(
  Degree3SVM = ifelse(predict(polynomial.degree3.svm.fit) > 0, 1, 0),
  Degree5SVM = ifelse(predict(polynomial.degree5.svm.fit) > 0, 1, 0),
  Degree10SVM = ifelse(predict(polynomial.degree10.svm.fit) > 0, 1, 0),
  Degree12SVM = ifelse(predict(polynomial.degree12.svm.fit) > 0, 1, 0)
))
> predictions <- melt(df, id.vars = c('X', 'Y'))
> ggplot(predictions, aes(x = X, y = Y, color = factor(value))) +
  geom_point() + facet_grid(variable ~ .)
> ggsave(filename="plot04.png")

f:id:unageanu:20160121150800p:plain

次数が3,5の場合は、指定なしの場合と変わりませんが、10,12だと少し判定できるようになっています。

次は、ガウスカーネルの cost パラメータを変えてみます。 cost は正則化の強さを示すパラメータで、値を大きくすると正則化が強く働き、訓練データに当てはまりにくくなります。

> radial.cost1.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'radial', cost = 1)
> radial.cost2.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'radial', cost = 2)
> radial.cost3.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'radial', cost = 3)
> radial.cost4.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'radial', cost = 4)

> df <- df[, c('X', 'Y', 'Label')]
> df <- cbind(df,
            data.frame(Cost1SVM = ifelse(predict(radial.cost1.svm.fit) > 0, 1, 0),
                       Cost2SVM = ifelse(predict(radial.cost2.svm.fit) > 0, 1, 0),
                       Cost3SVM = ifelse(predict(radial.cost3.svm.fit) > 0, 1, 0),
                       Cost4SVM = ifelse(predict(radial.cost4.svm.fit) > 0, 1, 0)))

> predictions <- melt(df, id.vars = c('X', 'Y'))
> ggplot(predictions, aes(x = X, y = Y, color = factor(value))) +
  geom_point() + facet_grid(variable ~ .)
> ggsave(filename="plot05.png")

f:id:unageanu:20160121150801p:plain

最後に、 シグモイドカーネルの gamma パラメータを試して終わり。

> sigmoid.gamma1.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'sigmoid', gamma = 1)
> sigmoid.gamma2.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'sigmoid', gamma = 2)
> sigmoid.gamma3.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'sigmoid', gamma = 3)
> sigmoid.gamma4.svm.fit <- svm(Label ~ X + Y, data = df, kernel = 'sigmoid', gamma = 4)


> df <- df[, c('X', 'Y', 'Label')]
> df <- cbind(df,
            data.frame(Gamma1SVM = ifelse(predict(sigmoid.gamma1.svm.fit) > 0, 1, 0),
                       Gamma2SVM = ifelse(predict(sigmoid.gamma2.svm.fit) > 0, 1, 0),
                       Gamma3SVM = ifelse(predict(sigmoid.gamma3.svm.fit) > 0, 1, 0),
                       Gamma4SVM = ifelse(predict(sigmoid.gamma4.svm.fit) > 0, 1, 0)))

> predictions <- melt(df, id.vars = c('X', 'Y'))
> ggplot(predictions, aes(x = X, y = Y, color = factor(value))) +
  geom_point() + facet_grid(variable ~ .)
> ggsave(filename="plot06.png")

f:id:unageanu:20160121150802p:plain

ガンマを変えると決定境界の形が変わります。(としか、書いてくれてない・・・)

アルゴリズムを比較する

SVNと、これまでに学習した、ロジスティック回帰/k近傍法を、同じデータセットに当てはめて比較してみます。

# 対象データの読み込みとクリーニング
> load(file.path('data', 'dtm.RData'))

> set.seed(1)

# 訓練データとテストデータに分割
> training.indices <- sort(sample(1:nrow(dtm), round(0.5 * nrow(dtm))))
> test.indices <- which(! 1:nrow(dtm) %in% training.indices)

> train.x <- dtm[training.indices, 3:ncol(dtm)]
> train.y <- dtm[training.indices, 1]

> test.x <- dtm[test.indices, 3:ncol(dtm)]
> test.y <- dtm[test.indices, 1]

> rm(dtm)

まずは、ロジスティック回帰。

> library('glmnet')
> regularized.logit.fit <- glmnet(train.x, train.y, family = c('binomial'))

最適な Lambda の値を、6章 と同様の手順で求めます。

> lambdas <- regularized.logit.fit$lambda
> performance <- data.frame()

> for (lambda in lambdas) {
  predictions <- predict(regularized.logit.fit, test.x, s = lambda)
  predictions <- as.numeric(predictions > 0)
  mse <- mean(predictions != test.y)
  performance <- rbind(performance, data.frame(Lambda = lambda, MSE = mse))
}

> ggplot(performance, aes(x = Lambda, y = MSE)) +
  geom_point() + scale_x_log10()
> ggsave(filename="lambda.png")

f:id:unageanu:20160121150756p:plain

1e-03 ぐらいがよさそう。 min を使ってMSEが最小になる Lambda を計算します。

# MSEが最小になる `Lambda` を計算。
# 今回のデータだと最小になるものが2つ存在するので、max を使って大きい(=より正則化が厳しい)方を選択している
> best.lambda <- with(performance, max(Lambda[which(MSE == min(MSE))]))

Lambda が決まったので、最終的なMSEを計算します。

> mse <- with(subset(performance, Lambda == best.lambda), MSE)
[1] 0.06830769

0.06 になりました。

次は、線形カーネルSVM。

> library('e1071')
> linear.svm.fit <- svm(train.x, train.y, kernel = 'linear')

> predictions <- predict(linear.svm.fit, test.x)
> predictions <- as.numeric(predictions > 0)
> 
> mse <- mean(predictions != test.y)
> mse
[1] 0.128

0.12 で、ロジスティックス回帰よりも悪い結果になりました。

次、ガウスカーネルSVM。

> linear.svm.fit <- svm(train.x, train.y, kernel = 'radial')

> predictions <- predict(linear.svm.fit, test.x)
> predictions <- as.numeric(predictions > 0)
 
> mse <- mean(predictions != test.y)
> mse
[1] 0.1421538

今回のデータセットでは、ロジスティック回帰や、線形カーネルSVMよりも悪い結果になりました。 これは、このデータでは理想的な決定境界が線形に近い可能性を示唆しています。

最後は、k近傍法です。

> library('class')

> knn.fit <- knn(train.x, test.x, train.y, k = 50)

> predictions <- as.numeric(as.character(knn.fit))

> mse <- mean(predictions != test.y)
> mse
[1] 0.1396923

kの値を5~50で変えて、一番良い値を使うようチューニングしてみます。

> performance <- data.frame()

# 5~50でkの値を変えて試行
> for (k in seq(5, 50, by = 5)) {
  knn.fit <- knn(train.x, test.x, train.y, k = k)
  predictions <- as.numeric(as.character(knn.fit))
  mse <- mean(predictions != test.y)
  performance <- rbind(performance, data.frame(K = k, MSE = mse))
}

# もっと良いkの値を取り出す。
> best.k <- with(performance, K[which(MSE == min(MSE))])

> best.mse <- with(subset(performance, K == best.k), MSE)
> best.mse
[1] 0.09169231

チューニングの結果、0.09まで改善しました。

ということで、この問題に対してはロジスティック回帰が一番適しているという結論になりました。

以上から、得られる教訓。

  • 実際のデータセットに取り組むときは、複数のアルゴリズムを試した方が良い。
  • 最適なアルゴリズムは、問題の構造に依存する。
  • モデルの性能は、パラメータにも依存する。良い結果を得たければパラメータの調整にも時間をかける

感想

  • 回帰分析とか、最適化について、入門くらいはできたかな。
  • 理論もいいけど、具体的にデータを触って、動かして学びたいプログラマ向けには悪くない本かと。(裏表紙にもそのようなコンセプトの本ですと書かれていますし)
  • ただ、Amazonの書評の通り、理論的説明はあまりないので、そこは別の本が必要な感じです。
    • この本の知識だけで Wikipedia とかに行くと ファッ てなります。
  • 機械学習自体は面白かった。これを使って何か作りたい。

おしまい。

機械学習手習い: ソーシャルグラフの分析

「入門 機械学習」手習い、11日目。「11章 ソーシャルグラフの分析」です。

www.amazon.co.jp

Twitterのソーシャルグラフの可視化をためし、グラフからおススメの友達を推薦するシステムを作ります。

# 前準備
> setwd("11-SNA/")

ローカルコミュニティ構造の可視化

最初の例では、ユーザー johnmyleswhite のフォロワーが、どのようなコミュニティ構造を持っているかを分析します。

ユーザー johnmyleswhite とそのユーザーが直接フォローしているユーザーのグラフ(ユーザーを中心とするエゴネットワークと呼びます)を読み込み、 各フォロワー間の距離を算出、これをもとに hclust で階層的クラスタリングを行います。

# グラフデータの読み込み
> user <- 'johnmyleswhite'
> user.ego <- read.graph("data/johnmyleswhite/johnmyleswhite_ego.graphml", format='graphml')

# ノード間の距離を算出
> user.sp <- shortest.paths(user.ego)
# 階層的クラスタリングで、フォロワーのコミュニティ構造を算出
> user.hc <- hclust(dist(user.sp))

# 可視化
> png(paste('../images/', user, '_dendrogram.png', sep=''), width=1680, height=1050)
> plot(user.hc)
> dev.off()

f:id:unageanu:20160120165216p:plain

グラフから、ざっくりと2つの大きなコミュニティがあり、その下にさらに小さなサブコミュニティがある構成になっていることがわかるかと。

グラフデータからおススメの友達を推薦する

「友達の友達」は友達になる確率が高い、の仮定のもと、グラフデータからおススメの友達を推薦してみます。

まずは、グラフデータの読み込み。

# おススメフォロワーを推薦する対象とするユーザー名
> user <- "drewconway"


# グラフデータの読み込み
> user.graph <- suppressWarnings(read.graph(paste("data/", user, "/", user, "_net.graphml", sep = ""), format = "graphml"))

グラフから、「友達の友達」を友達候補として取り出します。 「多くの友達が友達としている候補は適性が高い」とみなして順位づけして、ソート。

# "drewconway" がフォローしているユーザー(=友達)の一覧を取り出す
> friends <- V(user.graph)$name[neighbors(user.graph, user, mode = "out") + 1]
[1] "311nyc"       "aaronkoblin"  "abumuqawama"  "acroll"       "adamlaiacano"
[6] "aeromax" 

# グラフのエッジの一覧を取り出す
> user.el <- get.edgelist(user.graph)
> head(user.el)
     [,1]         [,2]          
[1,] "drewconway" "311nyc"      
[2,] "drewconway" "aaronkoblin" 
[3,] "drewconway" "abumuqawama" 
[4,] "drewconway" "acroll"      
[5,] "drewconway" "adamlaiacano"
[6,] "drewconway" "aeromax"     

# 友達の友達が2番目の要素(ターゲット)に含まれる行を取り出す。
# ただし、すでにフォロー済み(=友達)になっているユーザーは除く
> non.friends <- sapply(1:nrow(user.el), function(i) {
  ifelse(any(user.el[i,] == user | !user.el[i,1] %in% friends) | user.el[i,2] %in% friends, FALSE, TRUE)
})
> non.friends.el <- user.el[which(non.friends == TRUE),]
> head(non.friends.el)
     [,1]     [,2]          
[1,] "000988" "1000timesyes"
[2,] "000988" "10ch"        
[3,] "000988" "1mrankhan"   
[4,] "000988" "1ndus"       
[5,] "000988" "500startups" 
[6,] "000988" "_hoffman"   

# 友達候補ごとの友達の数を集計
> friends.count <- table(non.friends.el[,2])
> head(friends.count)
        ___emma   __damonwang__          __dave __davidflanagan         __iriss 
              1               1               2               3               1 
         __neha 
              1 

# データフレームに変換
> friends.followers <- data.frame(list(Twitter.Users = names(friends.count), 
    Friends.Following=as.numeric(friends.count)), stringsAsFactors = FALSE)
> head(friends.followers)
    Twitter.Users Friends.Following
1         ___emma                 1
2   __damonwang__                 1
3          __dave                 2
4 __davidflanagan                 3
5         __iriss                 1
6          __neha                 1

# 友達候補としての最適度を示す指標として、各友達候補をフォローしている友達比率を計算して使う。
# 多くの友達が友達としている候補は適性が高いとみなす。
> friends.followers$Friends.Norm <- friends.followers$Friends.Following / length(friends)
> head(friends.followers)                                                                                                                                        
    Twitter.Users Friends.Following Friends.Norm
1         ___emma                 1  0.003816794
2   __damonwang__                 1  0.003816794
3          __dave                 2  0.007633588
4 __davidflanagan                 3  0.011450382
5         __iriss                 1  0.003816794
6          __neha                 1  0.003816794

# お勧め度の指標でソート
> friends.followers <- friends.followers[with(friends.followers, order(-Friends.Norm)),]

データができたので、お勧め度順に上位6人を表示してみます。

# 上位6人を取得。
> head(friends.followers)                                                                                                                                        
      Twitter.Users Friends.Following Friends.Norm
13388       cshirky                80    0.3053435
21403    fredwilson                58    0.2213740
6950        bigdata                57    0.2175573
14062    dangerroom                57    0.2175573
55153 shitmydadsays                55    0.2099237
2025           al3x                54    0.2061069