記事一覧

VeLoRAによるMNISTと小さな改善

VeLoRAを実装してMNISTをします。
vはEMAを使って更新するとAccが改善することも示します。


前提

VeLoRAは簡単に説明すると、ニューラルネットの勾配計算に必要なメモリを削減する手法として、2024/05に提案されました。
前回の記事で説明したGaLoreと比較すると、GaLoreがSVDを用いて勾配を圧縮するのに対し、VeLoRAは勾配計算のために保存されるinputを圧縮することが特徴です。
論文では、SVDを用いないため、計算上のオーバーヘッドが生じにくいと主張しています。

VeLoRAの圧縮処理を示す式を置きます、動作はそんなに難しくないと思います。

キャプチャ
B×N×Dの入力Zがきたとき、ZはNDの任意の約数Mでサブトークンに分割されます。(この処理は単なるreshapeですね。)
さらに、M×1の行列vを掛けることで、zpはB×(ND/M)×1次元になります。
逆に、元に戻す処理はv.Tを掛けてからreshapeすればよいです。

問題はvをどのように初期化するかですが、論文によると最初のバッチのサブトークンの平均でよいそうです。
ホンマかいな。


本題

面白そうですが実装が無かったので、自前で実装しました。
ただし、PyTorchで実装するのはforward/backwardに手を入れるのが面倒そうだったので、みんな大好きゼロから本の実装を改造する形で実装しました。
実装はGitHubで公開しています。

Screenshot from 2024-05-31 18-47-34

次に、このモデルでMNISTをやります。
学習を試したかったらch05のtrain_velora.pyを実行してください。

モデルはL1=(784, 56)、L2=(56, 10)の2層Linearです。
MNISTのデータは(bs, 784)なので、ND=784と置いて実装しました。


実験結果

ひとまず論文通りの実装での学習結果。
notuse_ema.png

一応学習はできているものの、思ったより伸びないなぁという感想です。
デフォルトの2層Linearの実装でも90%は軽く超えてたと思うので、かなり悪く見えます。

やっぱりvを最初に定義してずっと固定しているのがまずいんじゃないでしょうか?
というわけで、vをEMAで更新するアプローチに改造してみます。

結果:
use_ema.png

だいぶマシになった。
計算コスト的に許されるのであれば更新した方がよさげです。
少なくとも今回の実験では。


感想

この手法とかGaLoreを見ていると、勾配計算って割とどんぶり勘定でいいんじゃね?という気分になってきます。

それでは。

コメント

コメントの投稿

非公開コメント

検索フォーム

プロフィール

birdMan

Author:birdMan
作ったもの(特に有志作成のmodや翻訳ファイル)を検索に引っかかるように置いてます.あとノウハウの共有備忘が目的です.

注意:
本ブログではアフィリエイトは一切使用していません.
何か連絡があったら下にSteamのプロフィールへのリンクを載せているのでそこへお願いします.