VeLoRAによるMNISTと小さな改善
- 2024/05/31
- 19:20
VeLoRAを実装してMNISTをします。
vはEMAを使って更新するとAccが改善することも示します。
前提
VeLoRAは簡単に説明すると、ニューラルネットの勾配計算に必要なメモリを削減する手法として、2024/05に提案されました。
論文では、SVDを用いないため、計算上のオーバーヘッドが生じにくいと主張しています。
VeLoRAの圧縮処理を示す式を置きます、動作はそんなに難しくないと思います。
B×N×Dの入力Zがきたとき、ZはNDの任意の約数Mでサブトークンに分割されます。(この処理は単なるreshapeですね。)
さらに、M×1の行列vを掛けることで、zpはB×(ND/M)×1次元になります。
逆に、元に戻す処理はv.Tを掛けてからreshapeすればよいです。
問題はvをどのように初期化するかですが、論文によると最初のバッチのサブトークンの平均でよいそうです。
ホンマかいな。
本題
面白そうですが実装が無かったので、自前で実装しました。
ただし、PyTorchで実装するのはforward/backwardに手を入れるのが面倒そうだったので、みんな大好きゼロから本の実装を改造する形で実装しました。
実装はGitHubで公開しています。
次に、このモデルでMNISTをやります。
学習を試したかったらch05のtrain_velora.pyを実行してください。
モデルはL1=(784, 56)、L2=(56, 10)の2層Linearです。
MNISTのデータは(bs, 784)なので、ND=784と置いて実装しました。
実験結果
ひとまず論文通りの実装での学習結果。
一応学習はできているものの、思ったより伸びないなぁという感想です。
デフォルトの2層Linearの実装でも90%は軽く超えてたと思うので、かなり悪く見えます。
やっぱりvを最初に定義してずっと固定しているのがまずいんじゃないでしょうか?
というわけで、vをEMAで更新するアプローチに改造してみます。
結果:
だいぶマシになった。
計算コスト的に許されるのであれば更新した方がよさげです。
少なくとも今回の実験では。
感想
この手法とかGaLoreを見ていると、勾配計算って割とどんぶり勘定でいいんじゃね?という気分になってきます。
それでは。