Dropout

データサイエンスについて勉強したことを書いていきます。機械学習、解釈性、因果推論など。

XGBoostの論文を読んだのでGBDTについてまとめた

はじめに

今更ですが、XGboostの論文を読んだので、2章GBDT部分のまとめ記事を書こうと思います。*1
この記事を書くにあたって、できるだけ数式の解釈を書くように心がけました。数式の意味をひとつひとつ追っていくことは、実際にXGBoost(またはLightGBMやCatBoostなどのGBDT実装)を使う際にも役立つと考えています。たとえばハイパーパラメータがどこに効いているかを理解することでチューニングを効率化したり、モデルを理解することでよりモデルに合った特徴量のエンジニアリングができるのではないかと思います。

また、この記事に限りませんが、記述に間違いや不十分な点などあればご指摘頂ければ嬉しいです。

XGBoost論文

目的関数の設定


一般的な状況として、サンプルサイズが Iで特徴量の数が Mのデータ \mathcal{D} = \left\{ (\mathbf{x}_i, y_i) \right\}(i \in\mathcal{I} = \{1, \dots, I\}, \; \mathbf{x}_i \in \mathbb{R}^M, \; y_i \in\mathbb{R})に対する予測モデルを構築することを想定しましょう。*2
今回はツリーをアンサンブルした予測モデルを構築します。
 \mathcal{K}  =\{1, \dots, K\}のツリーを加法的に組み合わせた予測モデルは以下のように定式化できます。*3


\begin{align}
\hat{y}_{i} &= \phi\left(\mathbf{x}_{i}\right)=\sum_{k\in\mathcal{K}} f_{k}\left(\mathbf{x}_{i}\right),\\
\text{where}\quad f_{k} \in \mathcal{F} &= \left\{f(\mathbf{x})=w_{q(\mathbf{x})}\right\}\left(q : \mathbb{R}^{m} \rightarrow \mathcal{T},\; \mathcal{T} = \{1, \dots, T\}, \; w \in \mathbb{R}^{T}\right)
\end{align}

ここで、 f_kはひとつひとつのツリーを表しています。ツリー f(\mathbf{x})は特徴量 \mathbf{x} が与えられると、それを q(\mathbf{x})に従って各ノード t = 1, \dots, Tに紐づけ、それぞれのノードに対応する予測値 w_{q(\mathbf{x})}を返します。そして、ひとつひとつのツリーの予測値を足し合わせることで、最終的な予測結果 \hat{y}_iとします。



では、具体的にツリーをどうやって作っていくかを決めるために、最小化したい目的関数 \mathcal{L}(\phi)を設定します。

\begin{align}
\mathcal{L}(\phi) &= \sum_{i\in \mathcal{I}} l\left(y_{i}, \hat{y}_{i}\right)+\sum_{k\in \mathcal{K}} \Omega\left(f_{k}\right), \\
\text{where}\quad\Omega(f) &= \gamma T+\frac{1}{2} \lambda\|w\|^{2} = \gamma T+\frac{1}{2} \lambda \sum_{t \in \mathcal{T}} w_{t}^{2}
\end{align}

ここで、 l(y_{i}, \hat{y}_{i})は損失関数で、たとえば二乗誤差になります。ただし、単に二乗誤差を最小化するのではなく、過適合を回避して汎化性能を上げるために正則化 \Omega(f)が追加されています。なお、 \gamma \lambda はハイパーパラメータであり、交差検証などで最適な値を探索する必要があります。

  •  \Omega(f)の第一項 \gamma Tはツリーのノードの数に応じてペナルティが課されるようになっています。ハイパーパラメータ \gammaを大きくするとよりノード数少ないツリーが好まれるようになります。ツリーの大きさに制限をかけることで過適合を回避することが目的です。
  •  \Omega(f)の第二項 \frac{1}{2}\lambda\|w\|^{2}は各ノードが返す値の大きさに対してペナルティがかかることを意味しています。ハイパーパラメータ \lambdaを大きくすると、(絶対値で見て)より小さい wが好まれるようになります。 wが小さいということは最終的な出力を決める \sum_{k\in\mathcal{K}} f_{k}部分で足し合わされる値が小さくなるので、過適合を避けることに繋がります。

勾配ブースティング

さて、目的関数 \mathcal{L}(\phi)を最小化するような K個のツリー構築したいわけですが、 K個ツリーを同時に構築して最適化するのではなく、 k個目のツリーを作る際には、 k-1個目までに構築したツリーを所与として、目的関数を最小化するようなツリーを作ることにしましょう。*4


\begin{align}
\min_{f_k}\;\mathcal{L}^{(k)}=\sum_{i\in\mathcal{I}} l\left(y_{i}, \hat{y}_i^{(k - 1)}+f_{k}\left(\mathbf{x}_{i}\right)\right)+\Omega\left(f_{k}\right)
\end{align}

このステップで作成する k個目のツリーを合わせた予測値は \hat{y}_i^{(k)} = \hat{y}_{i}^{(k - 1)}+f_{k}\left(\mathbf{x}_{i}\right)であり、 k-1個目までのツリーではうまく予測できていない部分に対してフィットするように新しいツリーを構築すると解釈できます。このように残差に対してフィットするツリーを逐次的に作成していく手法をブースティングと呼びます。



さて、損失関数 \sum_{i\in\mathcal{I}} l\left(y_{i}, \hat{y}_{i}^{(k - 1)}+f_{k}\left(\mathbf{x}_{i}\right)\right)を直接最適化するのではなく、その2階近似を最適化することにしましょう。後にわかるように、2次近似によって解析的に解を求めることができます。 f_k = 0の周りで2階のテイラー展開を行うと、目的関数 \mathcal{L}^{(k)}は以下で近似できます。

\begin{align}
\mathcal{L}^{(k)} &\approx \sum_{i\in\mathcal{I}}\left[l\left(y_{i}, \hat{y}^{(k - 1)}\right)+g_{i} f_{k}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{k}^{2}\left(\mathbf{x}_{i}\right)\right] + \Omega\left(f_{k}\right),\\
\text{where} \quad g_i &= \frac{\partial }{\partial \hat{y}^{(k - 1)}}l\left(y_{i}, \hat{y}^{(k - 1)}\right),\\
h_i &= \frac{\partial^2 }{\partial \left(\hat{y}^{(k - 1)}\right)^2}l\left(y_{i}, \hat{y}^{(k - 1)}\right)
\end{align}

ここで、 g_i h_iはそれぞれ損失関数の1階と2階の勾配情報になります。勾配情報を使ったブースティングなので勾配ブースティングと呼ばれています。*5
今回 f_kを動かすことで目的関数を最小化するので、 f_kと関係ない第一項は無視できます。

\begin{align}
\tilde{\mathcal{L}}^{(k)} &=\sum_{i\in\mathcal{I}}\left[g_{i} f_{k}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{k}^{2}\left(\mathbf{x}_{i}\right)\right]+\gamma T+\frac{1}{2} \lambda \sum_{t \in \mathcal{T}} w_{t}^{2} \\
&= \sum_{t \in \mathcal{T}}\left[\sum_{i \in \mathcal{I}_{t}} g_{i} f_{k}\left(\mathbf{x}_{i}\right)+\frac{1}{2} \sum_{i \in \mathcal{I}_{t}} h_{i} f_{k}^{2}\left(\mathbf{x}_{i}\right)\right]+\gamma T +\frac{1}{2} \lambda \sum_{t \in \mathcal{T}} w_{t}^{2}\\
&=\sum_{t \in \mathcal{T}}\left[\left(\sum_{i \in \mathcal{I}_{t}} g_{i}\right) w_{t}+\frac{1}{2}\left(\lambda + \sum_{i \in \mathcal{I}_{t}} h_{i}\right) w_{t}^{2}\right]+\gamma T
\end{align}

  • 1行目の式では、 f_kと関係ない第一項を取り除き、 \Omega(f_k)の中身を書き下しました。
  • 2行目への変換ですが、全ての i\in\mathcal{I}について足し合わせている部分を、まずノード tに所属する部分 i \in \mathcal{I}_t (\mathcal{I}_t = \{i | q(\mathbf{x}_i) = t \}) を足し合わせてから、全てのノード t \in \mathcal{T}について足し合わせるように分解しています。
  • 3行目への変換では、同じノードに所属する f_k(\mathbf{x}_i)は全て w_tを返すというツリーの性質を利用しています。また、 w^2_tの共通部分をくくっています。

さて、 \tilde{\mathcal{L}}^{(k)} w_tに関しての2次式なので、解析的に解くことができます。

\begin{align}
w_{t}^{*}=-\frac{\sum_{i \in \mathcal{I}_{t}} g_{i}}{\lambda + \sum_{i \in \mathcal{I}_{t}} h_{i}}
\end{align}

以上で、 k個目のツリーに関して、各ノードが返すべき値 w_t^*が解析的に求まりました。*6
この式からもハイパーパラメータ \lambdaを大きくすると w^*_tが(絶対値で見て)小さくなることが見て取れます。この w^*_tを元の目的関数に代入してあげることで

\begin{align}
\tilde{\mathcal{L}}^{(k)}(q)=-\frac{1}{2} \sum_{t\in\mathcal{T}} \frac{\left(\sum_{i \in \mathcal{I}_{t}} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_{t}} h_{i}}+\gamma T
\end{align}

を得ます。あとはツリーの構造 q、言い換えれば特徴量の分割ルールを決める必要があります。たとえば、一番シンプルなケースとして、全く分割を行わない場合( \mathcal{I})と一度だけ分割を行う場合( \mathcal{I}_L, \mathcal{I}_Rに分割)を比較しましょう。分割による目的関数の値の減少分は

\begin{align}
\mathcal{L}_{\text{split}}&= -\frac{1}{2} \frac{\left(\sum_{i \in \mathcal{I}} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}} h_{i}}+\gamma - \left(-\frac{1}{2} \frac{\left(\sum_{i \in \mathcal{I}_L} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_L} h_{i}}-\frac{1}{2} \frac{\left(\sum_{i \in \mathcal{I}_R} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_R} h_{i}}+2\gamma \right)\\
&= \frac{1}{2}\left(\frac{\left(\sum_{i \in \mathcal{I}_L} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_L} h_{i}}+\frac{\left(\sum_{i \in \mathcal{I}_R} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_R} h_{i}} - \frac{\left(\sum_{i \in \mathcal{I}} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}} h_{i}}\right) - \gamma
\end{align}


であり、これがプラスなら分割を行い、マイナスなら分割を行わないということになります。上式からも、ハイパーパラメータ \gamma を大きくするとより分割が行われなくなることが見て取れます。

ところで、そもそもどの特徴量のどの値で分割するべきかなのでしょうか?一番ナイーブな考え方は、全ての変数に対して全ての分割点を考慮して、一番目的関数の値を減少させるような分割を選ぶというものがあります。ただし、この方法は膨大な計算量が必要になるため、XGBoostでは近似手法が提供されており、3章に記述されています。さらに、4章移行では並列計算や比較実験などが記されています。


まとめ

XGboostの論文を読んだので、自身の理解を深めるために2章GBDT部分のまとめ記事を書きました。
今までなんとなく使っていたハイパーパラメータが具体的にどの部分に効いているのか学ぶことができて、とても有意義だったと感じています。

*1:なお、元論文のノーテーションがおかしい/統一的でないように感じたので、一部表記を変更しています。

*2:元論文では |\mathcal{D}| = Nですが、 i Nが混じるとややこしい気もするので Iにしました。

*3:元論文の q : \mathbb{R}^{m} \rightarrow Tとなっているのですが、 qはインデックス 1, \dots, Tを返す関数なので、タイポかと思われます。

*4:元論文の添字 tがツリーの数 Tとややこしいので kのまま進めることにしました。

*5:たぶん

*6:論文にはこれがimpuityみたいなものと書かれているのですが、不勉強で理解できませんでした。