はじめに
Pyroで確率モデリングを書くときには「確率モデリング自体を知ること」と「Pyroの書き方を知ること」の両方が必要です。今回はPyroの書き方に重点をおいて、とある確率モデルを記述するためのPyroでのコード例を適当に記載します。
約束事として、観測変数(データ) $x$ に対して、このデータの生成にまつわるパラメータをすべてひっくるめて $\theta$ と記述することにします。例えば、 $x$ が正規分布由来とするならば $\theta = (\mu, \sigma)$ ということですが、扱う分布が複雑になるにつれ列挙するのが面倒になるので、この記事では $\theta$ にパラメータを表す記号として活躍してもらうことにします。
また基本的にはパラメータ $\theta$ に対しては事前分布が用意され、ベイズ的に取り扱う形になります。事前分布に必要な数値は、ハイパーパラメータとして具体的な値を与えて書きます。
またモジュールを下記のようにインポートしておきます。
import numpy as np import pandas as pd import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F import pyro import pyro.distributions as dist import pyro.infer as infer import pyro.optim as optim from pyro import sample, param, plate plt.style.use("seaborn")
単一の分布を使ったモデル
正規分布
同時分布の設計
$x$ をデータとします。正規分布のパラメータは $\theta = (\mu, \sigma)$ です。まずはパラメータとデータの同時分布を下記のように設定することにします。
$$ p(x, \theta) = p(\mu)p(\sigma)p(x \mid \mu, \sigma) $$
ここで、例えばパラメータ $\mu, \sigma$ に関しては独立に事前分布を与えるのではなく $p(\mu, \sigma)$ という2つの変数を同時に生成する同時分布を事前分布として準備しても構いません。それは平均 $\mu$ の値に応じて標準偏差が $\sigma$ 変わる、すなわち相関を持っているような事を想定する場合に実施します。
同時分布からのサンプリング
設計した同時分布からのデータとパラメータのサンプリングは事前分布を用いると実際に実施ができます。下記のように事前分布を含め統計モデルを準備することとしましょう。
$$ \begin{align} p(\mu) & = {\rm Normal}(\mu \mid 0, 5)\\ p(\sigma) & = {\rm InvGamma}(\sigma \mid 2, 2)\\ p(x _ i \mid \theta) &= {\rm Normal}(x _ i\mid \mu, \sigma) \end{align} $$
平均パラメータ $\mu$ は正規分布を事前分布としました。標準偏差パラメータ $\sigma$ は逆ガンマ分布を事前分布としました。逆ガンマ分布は非負の実数を生成する分布なので採用しました。実際には負の値を生成するような分布からサンプリングを行い、$\exp$ や $\rm softplus$ などで正の値に変換して使うという方法を取ることもできます。
Pyroコード
Pyroでのコードは下記になります。
ここで plate
コンテキストは、コンテキスト内の sample
が(条件付き)独立に得られることを表します。
plate(name_str, num_of_data)
という形式で、モデルを関数として呼び出したときに num_of_data
サイズのサンプリングを実施します。sample(obs=data)
とすることで、これは観測変数であるということをPyroに伝えることができ、円滑に $p(\theta \mid x)$ という観測データの条件付き確率(事後分布を)を計算することができます(この表記を用いなくても、pyroには既に書かれているモデルから条件付き確率を作り出すモジュールも準備されている)。
def gaussian(data=None): mu = sample("mu", dist.Normal(0., 5.)) sigma = sample("sigma", dist.InverseGamma(2., 2.)) num_of_data = len(data) if data is not None else 100 with plate("data", num_of_data): obs = sample("obs", dist.Normal(mu, sigma), obs=data) return obs obs = gaussian() plt.hist(obs)
ベルヌーイ分布
同時分布の設計
下記のように同時分布を設計します。$r$ はベルヌーイ分布のパラメータです。これは単に条件付き確率の定義に従っているだけになります。
$$ p(x, \theta) = p(r)p(x\mid r) $$
同時分布からのサンプリング
$$ \begin{align} p(r) & = {\rm Beta}(r \mid 1, 1)\\ p(x _ i \mid r) &= {\rm Bern}(x _ i\mid r) \end{align} $$
ベルヌーイ分布のパラメータの事前分布にはベータ分布を利用しました。ベータ分布は $0$ から $1$ の値を返す確率分布です。
Pyroコード
def bern(data=None): r = sample("r", dist.Beta(1, 1)) num_of_data = len(data) if data is not None else 100 with plate("data", num_of_data): obs = sample("obs", dist.Bernoulli(r), obs=data) return obs
カテゴリ分布
同時分布の設計
$\alpha$ をカテゴリ分布のパラメータとします。
$$ p(x, \theta) = p(\alpha) p(x \mid \mathbf \alpha) $$
同時分布からのサンプリング
$\alpha$ は各クラスの出現確率を表す成分を持つベクトルで、カテゴリ分布が3クラスを扱うとすると $\alpha$ も3次元になります。各成分の和は $1$ となっている必要があり、そのようなベクトルを生成する分布としてはディリクレ分布が使えます(当然、NNの常套手段である適当なベクトルをソフトマックス関数で正規化する方法を用いることもできる)。
$$ \begin{align} p(\mathbf \alpha) &= {\rm Dir}(\mathbf \alpha \mid (1, 1, 1)) \\ p(x \mid \mathbf \alpha) &= {\rm Categori} (x \mid \mathbf \alpha) \end{align} $$
pyroコード
Pyro
def categori(data=None): alpha = sample("alpha", dist.Dirichlet(torch.tensor([1., 1., 1.]))) num_of_data = len(data) if data is not None else 100 with plate("data", num_of_data): obs = sample("obs", dist.Categorical(alpha)) return obs obs = categori() plt.hist(obs)
混合モデル
次にベイズ統計モデリングを実施する場合に最もベーシックな応用事例となる混合モデルを扱います。 既にここまで読んでいると、同時分布の設計を行うことが同時分布からのサンプリングを考えることに相当しており、わざわざ同時分布を書き下すのは数式で解析的な議論をしない限りは特に不要そうであるので、以下では省略します。
ガウス混合モデル
ガウス混合モデルは潜在変数としてクラスタ $c$ があり、クラスタ $c$ 毎に異なる正規分布に従った観測データ $x$ が得られるというようなモデルになっています。数式で見るとなかなかに厄介に見えるのですが、実際に生成される過程を見るとそれ程難しくないように思えるはずです。
同時分布からのサンプリング
ここではクラスタの数を3つということにしておきましょう。
$\pi$ は負担率と呼ばれるベクトルでクラスタの数と同じ次元を持っています。各成分が、あるクラスタの出現する割合を格納しているイメージです。そしてその負担率 $\pi$ によってクラスタ $c$ が決まっています。クラスタの数毎に正規分布があり、データ $x$ はクラスタ $c$ の平均 $\mu _ c$ と標準偏差 $\sigma _ c$ をパラメータとする正規分布から生起します。
$$ \begin{align} p(\pi) & = {\rm Dir}(\pi \mid (1, 1, 1))\\ p(\mu _ k) & = {\rm Normal} (\mu _ k \mid 0, 10) \\ p(\sigma _ k) & = {\rm InvGamma} (\sigma _ k \mid 2, 2) \\ & (k = 1, 2, 3) \\ p(c _ k \mid \pi) &= {\rm Categori}(c _ k \mid \pi) \\ p(x _ i \mid \mu _ {c}, \sigma _ {c}) &= {\rm Normal}(x _ i \mid \mu _ c, \sigma _ c) \end{align} $$
Pyroコード
今回カテゴリの数は $3$ ということにしておきます。 正規分布やカテゴリ分布を単一で記述した際の知識をフルに使うので良い練習問題でしょう。
@infer.config_enumerate
は離散変数を確率変数として扱っている際のおまじないになります(リリース版では特に不要かもしれない。要確認)。
K = 3 @infer.config_enumerate def gmm(data=None): pi = pyro.sample('pi', dist.Dirichlet(torch.ones(K))) with pyro.plate('components', K): locs = pyro.sample('locs', dist.Normal(0., 10.)) scales = pyro.sample('scales', dist.InverseGamma(2., 2.)) num_of_data = len(data) if data is not None else 1000 with pyro.plate('data', num_of_data): c = pyro.sample('c', dist.Categorical(pi)) obs = pyro.sample('obs', dist.Normal(locs[c], scales[c]), obs=data) return obs obs = gmm() plt.hist(obs, bins=20)
ディリクレ過程混合モデル(某折過程モデル)
ディリクレ過程が何者なのかは他の解説に譲るとして、ディリクレ過程を用いるとクラスタの数をデータから適応的にモデル自身に決定させることができます。このことからノンパラメトリックベイズクラスタリングとしても知られています(後に見るように、本当にパラメータを何も設定しないわけではない。当然のことながら、事前分布の根っこまで行けばハイパーパラメータは必ず存在する。ここではクラスタリングにとって重要なパラメータであった「クラスタ数」をモデル自身に任せられるという意味で、その点でノンパラメトリックなのである。ただし、考えうるクラスタ数の上限に相当する値は与えなければならない)。
同時分布からのサンプリング
ここではクラスタの最大数を10ということにします。 もしもデータを見てもクラスタの数がそれ程多くないということであっても、多めに設定しておけばよいです。このモデルではもしもクラスタ数を多めに設定してしまったとしても推論の中でそのようなクラスタからデータが生成される確率が極めて低いという推論結果になります。
ここでは各クラスタはガウス分布に従っているとします。するガウス混合モデルでの負担率 $\pi$ をディリクレ過程から獲得することにすればやりたいことができます。ディリクレ過程の実現方法は幾つかありますが、今回は某折過程と呼ばれるものを用います。
$$ \begin{align} p(\beta _ k) & = {\rm Beta}(\beta _ k \mid 1, 1) \\ p(\mu _ k) & = {\rm Normal}(\mu _ k \mid 0, 10)\\ p(\sigma _ k) & = {\rm InvGamma}(\sigma _ k \mid 2, 2)\\ & (k = 1, ..., 10) \\ \pi _ k (\beta _ {1:10}) &= \beta _ k \prod _ {l < k} (1 - \beta _ l) \\ p(c _ i \mid \pi) &= {\rm Categori} (c _ i \mid \pi) \\ p(x _ i \mid \mu _ {c}, \sigma _ {c}) &= {\rm Normal}(x _ i \mid \mu _ c, \sigma _ c) \end{align} $$
Pyroコード
K = 10 def mix_weights(beta): beta1m_cumprod = (1 - beta).cumprod(-1) return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1) def dpmm(data=None): with pyro.plate("beta_plate", K-1): beta = pyro.sample("beta", dist.Beta(1, 1)) with pyro.plate("mu_plate", K): mu = pyro.sample("mu", dist.Normal(0., 10.)) sigma = pyro.sample("sigma", dist.InverseGamma(2., 2.)) num_of_data = len(data) if data is not None else 1000 with pyro.plate("data", num_of_data): z = pyro.sample("z", dist.Categorical(mix_weights(beta))) obs = pyro.sample("obs", dist.Normal(mu[z], sigma[z]), obs=data) return obs obs = dpmm() plt.hist(obs, bins=20)
最後に
余力があれば、上記のモデルらを用いて実際に推論を回してみると良いと思います。