前回の記事で無条件に使ってしまった関係(つまり『ゼロから作るDeep Learning』式(5.13))は、書中でも導出過程は省略されています。それについて深掘りしてみたいです。
ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装
- 作者: 斎藤康毅
- 出版社/メーカー: オライリージャパン
- 発売日: 2016/09/24
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (18件) を見る
式(5.13)は以下です。
\begin{align}
\frac{\partial L}{\partial \mathbf{X}} &= \frac{\partial L}{\partial \mathbf{Y}} \cdot \mathbf{W}^\mathrm{T} \\
\frac{\partial L}{\partial \mathbf{W}} &= \mathbf{X}^\mathrm{T} \cdot \frac{\partial L}{\partial \mathbf{Y}}
\end{align}
このうち、本質的に同じですから上の式についてのみ考えます。
ところで、どうもこの考え方、気持ち悪いです。私が知りたいのはむしろ、書中の乗算ノードの微分を考えたときと同様にとしたときのです。つまり、そういう演算があるかどうかはさておいて、
\begin{align}
\frac{\partial L}{\partial \mathbf{X}} &= \frac{\partial L}{\partial \mathbf{Y}} \cdot \frac{\partial \mathbf{Y}}{\partial \mathbf{X}}
\end{align}
から
\begin{align}
\frac{\partial \mathbf{Y}}{\partial \mathbf{X}} = \mathbf{W}^\mathrm{T}
\end{align}
を示せば(個人的に)納得できるというものです。なぜなら、成分ごとに考えれば
\begin{align}
\frac{\partial \mathbf{Z}}{\partial \mathbf{X}} = \frac{\partial \mathbf{Y}}{\partial \mathbf{X}}
\end{align}
となるからです。
さて、このような行列の微分の連鎖律(のようなもの)や、行列の行列での微分はこの文脈(problem statement)において成り立つものであって「一般に」ではないことに注意してください。例えば、行列の行列での微分は素直に考えれば4階の(?)テンソルになると思いますが、本問題では行列に落とし込んでいます。(でも、純粋な数学ならともかく、理工学なら考えている問題で合致すれば十分だとも思うのです。)
というわけで、
\begin{align}
\mathbf{X} &=
\begin{pmatrix}
x_{0, 0} & x_{0, 1} & \cdots & x_{0, J - 1} \\
x_{1, 0} & x_{1, 1} & \cdots & x_{1, J - 1} \\
\vdots & \vdots & \ddots & \vdots \\
x_{N - 1, 0} & x_{N - 1, 1} & \cdots & x_{N - 1, J - 1}
\end{pmatrix} \\
\mathbf{W} &=
\begin{pmatrix}
w_{0, 0} & w_{0, 1} & \cdots & w_{0, K - 1} \\
w_{1, 0} & w_{1, 1} & \cdots & w_{1, K - 1} \\
\vdots & \vdots & \ddots & \vdots \\
w_{J - 1, 0} & w_{J - 1, 1} & \cdots & w_{J - 1, K - 1}
\end{pmatrix}
\end{align}
に対して
\begin{align}
\mathbf{X} \cdot \mathbf{W} &= \mathbf{Z} \\
&=
\begin{pmatrix}
z_{0, 0} & z_{0, 1} & \cdots & z_{0, K - 1} \\
z_{1, 0} & z_{1, 1} & \cdots & z_{1, K - 1} \\
\vdots & \vdots & \ddots & \vdots \\
z_{N - 1, 0} & z_{N - 1, 1} & \cdots & z_{N - 1, K - 1}
\end{pmatrix} \\
&=
\begin{pmatrix}
\sum_j x_{0, j}w_{j, 0} & \sum_j x_{0, j}w_{j, 1} & \cdots & \sum_j x_{0, j}w_{j, K - 1} \\
\sum_j x_{1, j}w_{j, 0} & \sum_j x_{1, j}w_{j, 1} & \cdots & \sum_j x_{1, j}w_{j, K - 1} \\
\vdots & \vdots & \ddots & \vdots \\
\sum_j x_{N - 1, j}w_{j, 0} & \sum_j x_{N - 1, j}w_{j, 1} & \cdots & \sum_j x_{N - 1, j}w_{j, K - 1}
\end{pmatrix}
\end{align}
の微分
\begin{align}
\frac{\partial \mathbf{Z}}{\partial \mathbf{X}}
\end{align}
を考えます。
が微小変化したとき、に関与する成分を考えます。まず、については、の行目以外0になります。そこで、行目のみ取り出して縦ベクトルとして(後述)並べることにします。これをすべてのについて繰り返すと、
\begin{align}
\begin{pmatrix}
w_{0, 0} & w_{1, 0} & \cdots & w_{J - 1, 0} \\
w_{0, 1} & w_{1, 1} & \cdots & w_{J - 1, 1} \\
\vdots & \vdots & \ddots & \vdots \\
w_{0, K - 1} & w_{1, K - 1} & \cdots & w_{J - 1, K - 1}
\end{pmatrix}
= \mathbf{W}^\mathrm{T}
\end{align}
となります。これはによらず不変なので、これをということにします。
\begin{align}
\frac{\partial \mathbf{Z}}{\partial \mathbf{X}} = \mathbf{W}^\mathrm{T}
\end{align}
示せました。
なぜこんなことができたかというと、ドット積をとった時点で1つ、によらないということで1つ、テンソルの次元を節約できたからだと考えられます。つまり、
\begin{align}
\frac{\partial z_{n, k}}{\partial x_{n, j}}
\end{align}
からを消せて、のときと同様にベクトル、として次式のように考えることができたからです。
\begin{align}
\frac{\partial \mathbf{z}}{\partial \mathbf{x}}
\end{align}
ところで、先ほど縦ベクトルとして並べましたが、そのことに必然性はありません(強いて言えばにしたかったから)。つまり、
\begin{align}
\frac{\partial z_k}{\partial x_j} = \left(\frac{\partial \mathbf{z}}{\partial \mathbf{x}}\right)_{k, j}
\end{align}
とするか
\begin{align}
\frac{\partial z_k}{\partial x_j} = \left(\frac{\partial \mathbf{z}}{\partial \mathbf{x}}\right)_{j, k}
\end{align}
とするかは、一般に任意性があるようです。
以上、独りよがりな電波な自説となってしまいました。