GaLoreにおける低ランク行列への射影を改善する
- 2024/04/21
- 08:54
GaLoreによる勾配の低ランク行列への射影において、SVDの上位ランクではなく重み付き確率を用いることで、損失の低下が改善することを示します。
前提
GaLoreはメモリ効率のよい機械学習の手法として、2024/04に提案されました。
簡単に動作を説明します。
GaLoreでは、勾配はSVDによって特異値分解されたあと次元削減されます。
これによって、従来よりも小さな空間に対してAdamなどの最適化関数を適用できるため、メモリの使用量が削減されるという理屈です。
NNの効率的な学習手法としてはLoRAが有名ですが、LoRAとGaLoreは大まかには以下の点で異なります。
- GaLoreは低ランクの重みを新たに置くのではなく、勾配を低ランク行列に射影します。つまり、学習対象となるパラメタはモデル全体です。
- LoRAは特殊な例を除いて事前学習が済んでいることを前提としていますが、GaLoreは事前学習でもチューニングでも利用可能です。
GaLoreは理屈が明快かつ実用性が高いので、注目している人は多いのではないでしょうか?
本題
しかし、ここで一点気に入らないことがあります。
それは、勾配の次元削減時に常に上位ランクを使用しているために、微細な情報が捨てられ、表現力が損なわれる、あるいは過学習の原因となるのではないかという懸念です。
そこで、特異値上位の特異ベクトルで勾配を射影するのではなく、重み付き確率を用いる方法を考えてみましょう。
実装は簡単なので、コードを以下に示します。
実験
それではこの改良が効果を発揮するか実験を行います。
実験内容は、LLaMA 7bの学習において上位ランクを常に使用する方法と、重み付き確率を利用する方法の2種類を使用し、損失の推移を比較するという内容です。
学習用のコードは公式実装を利用します。
学習には、以下の設定を用いました。
eval lossの画像です。オレンジがtop rankを取った場合の結果で5.34、ピンクがSoftmaxを通した場合の結果で4.83となりました。
今回のケースに関しては、およそ10%の改善でした。
念の為、Training Lossも貼っておきます。
感想
- 思いつき程度のネタでしたが、この程度の改変でこんなに変わると思っていなかったのでびっくりしました。
- 今回は最も単純なSoftmaxを用いましたが、もう少しマシな取り出し方があるかもしれないです。(温度をつけてみたり、cosを使ったり、リニアに減少させたり、あるいは完全にランダムにしたり)
- うまく行かないケースの探索も含め追試したい気持ちはあるのですが、計算資源がね…
それでは。