注目のディープラーニングフレームワーク「PyTorch」入門

PyTorchに注目が集まっている

 最近「PyTorch」熱を感じます。日本のAIベンチャーの雄であるPFNが「Chainer」から「PyTorch」に移行しましたし、つい先日も、著名なAI研究団体のOpenAIがメインのフレームワークをPyTorchにすることを発表し、この勢いは続きそうですね。

 周りでもPyTorchを使い始めている人が多いように思います。それを裏付けるようにVengineerさんのアンケートでも「現在使っているディープラーニングフレームワーク」が堂々の1位となっていました(TensorFlowとKerasを合計すると、やはりTensorFlow+Kerasが強いのですが)。

 というわけで、ミーハーな私もPyTorchに入門してみました。

PyTorchでジャンケン画像認識

 ラズパイマガジン2月号のAI特集で、ジャンケンの手の形を題材に、画像認識のAIモデルを作成する記事を書きました。

 詳細は以下記事参照下さい。

 学習に関しては、Google Colaboratoryを活用する形になっています。書籍では、フレームワークとしてTensorFlow(Keras)を使ったのですが、今回はこのコードを自分の勉強のためにPyTorchを使って1から書き直してみました(完全移植はできていません)。

 今回は特別に(というわけでもないですが)公開しちゃいます。既存のデータセットを使わずに、カメラで集めた普通の画像を学習するチュートリアルは意外に少なかったりするので、それなりに需要はあるのではないかなと思います。

ラズパイマガジンAI特集(PyTorch版)

 PyTorchでのデータの扱いを中心に、コメントを多めにつけました。単体でも分かる人には分かると思いますが、さっぱり分からないという人は、是非ラズパイマガジン2月号を読みながら、コードを実行してみて下さい。流れ自体は、TensorFlowでの学習と全く同じにしていますので、対応するコードを確認しながら実行すると理解が深まるのではないかと思います。なお、PyTorch初学者なので、間違っているところや、変な書き方しているところがあったら、そっと教えていただけると助かります。

 ラズパイマガジン、紙の書籍も電子書籍もありますので、興味ある方は是非。

PyTorchでAIマリオのクリアにチャレンジ

 深層強化学習でのAIマリオの学習です。以下Qiita記事参照ください。

PyTorch入門に最適な情報

 自分がPyTorchを実践にするにあたり参考にしたネット情報や、今後試したいなと思っているものを、初心者向けのものを中心にメモがわりに紹介したいと思います。

公式情報

Welcome to PyTorch Tutorials — PyTorch Tutorials 1.13.1+cu117 documentation
 公式のチュートリアルが丁寧で分かりやすいです。公式チュートリアル日本語訳も公開されています!

[1912.01703] PyTorch: An Imperative Style, High-Performance Deep Learning Library
 PyTorchの論文

Practical Deep Learning for Coders - Practical Deep Learning
 fast.aiというPyTorchのラッパー。PyTorchはラッパーがたくさんあって、色々混在している印象ですが、fast.aiはその中でもTensorFlowのKeras的存在で、簡単にかける印象です。今回は使用しませんでしたが、こういったラッパーも使いこなすと更に楽にプログラムをかけそうですね。

実践例

 実践例です。

【深層距離学習】Center Lossを徹底解説 -Pytorchによる実践あり-|はやぶさの技術ノート
 深層距離学習をPyTorchで試せます。そのままめっちゃ動きました。

角度を用いた深層距離学習(deep metric learning)を徹底解説 -PytorchによるAdaCos実践あり-|はやぶさの技術ノート

PyTorchでシンプルな多層ニューラルネットワークを作ろう - Qiita

#1 Neural Networks : PyTorchチュートリアルをやってみた - Qiita

#2 Training a classifier : PyTorchチュートリアルをやってみた - Qiita

実践 PyTorch Lightning (2019/11/30 分析コンペLT会 #1) - Speaker Deck

pytorchによる画像分類入門 - Qiita

【詳細(?)】pytorch入門 〜CIFAR10をCNNする〜 - Qiita

Optunaでハイパーパラメータの自動チューニング -Pytorch Lightning編-|はやぶさの技術ノート

【PyTorch入門】Tensorの扱いから単回帰まで - HELLO CYBERNETICS

PyTorch object detection with pre-trained networks - PyImageSearch

TIPS・テクニック

pytorch超入門 - Qiita
 良い感じにまとまっています

Pytorchのニューラルネットの書き方 - HELLO CYBERNETICS

Pytorchでモデル構築するとき、torchsummaryがマジ使える件について - Qiita

backbone としての timm 入門

Tensorflow/Pytorch モデル移植のススメ - Speaker Deck

GitHubリポジトリ

GitHub - jfzhang95/pytorch-deeplab-xception: DeepLab v3+ model in PyTorch. Support different backbones.

GitHub - doiken23/DeepLab_pytorch: Repository for DeepLab family

GitHub - huggingface/pytorch-image-models: PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, EfficientNetV2, NFNet, Vision Transformer, MixNet, MobileNet-V3/V2, RegNet, DPN, CSPNet, and more

GitHub - kentaroy47/timm_speed_benchmark: Benchmarking the speed of timm models

PyTorch書籍

 書籍の紹介です。

つくりながら学ぶ! PyTorchによる発展ディープラーニング

 とりあえず購入した本です。

 画像判別からGANまで幅広く書かれていて良いなと思いました。書籍のサンプルコードは以下GitHubにあります。

GitHub - YutaroOgawa/pytorch_advanced: 書籍「つくりながら学ぶ! PyTorchによる発展ディープラーニング」の実装コードを配置したリポジトリです

 ただ、モデルはありものを転移学習して使うものがほとんどです。手っ取り早く応用を試したい人は良いのですが、自分のように1からモデルを学習させるようなコードを書くときは、あんまり参考にできませんでした。

まとめ

 PyTorchに入門して、簡単な画像判別のGoogle Colabのコードを書いて公開しました。PyTorch、個人的にはデータの扱いや可視化がTensorFlowより直感的で分かりやすいなと感じました。

 ただ、学習は自分でループ回すコードを書く必要があります。Keras的に「fitしたら終わり」ではなくちょっと煩雑ですね。Keras的なラッパーもあるのですが、色々混在していて若干カオスな感じです。最初は、自分の理解を深める上でも一度は自分で学習ループ書いてみるのが良いかなと感じました。

 今回、PyTorchで学習したモデルをJetson Nanoやラズパイで動かしてエッジコンピューティング的なことも簡単にできそうですね。そちらに関しては、また需要がありそうだったら詳しい記事を書こうかなと思います。

参考リンク

Deep Learningを学ぶための教材? - Vengineerの戯言

ディープラーニングモデル圧縮手法 Pruning を PyTorch でお試し - OPTiM TECH BLOG

GitHub - WindVoiceVox/PyTorchPractice: PyTorchで微分を計算する方法を説明することで、ニューラルネットの操作の一歩手前を理解する。

関連記事

変更履歴

  • 2020/12/24 日本語チュートリアルに関して追記、リンク切れ修正
  • 2020/05/08 微修正