DRFとNuxtを使って画像分類(機械学習)をする①
はじめに
インターンしている小林です.この記事では,DRF(Djangoのいい感じのフレームワーク)を使って,APIを作るまで行います.記事は二編構成とし,一編はDRFによるAPI作成,二編はNuxtを用いてユーザが実際に入力することを想定してフロント作成します.具体的には,PyTorchのresnetを用いて,入力フォームから受け付けられた画像を推論して上位10位までの結果を表示させます.一編では,詳細な機械学習のアルゴリズムは説明せずに,APIを作る工程に重きを向けます.読者の対象はDRFを初めたての人が対象であり,機械学習の画像処理をある程度把握している人が対象となります.
構築したAPIは以下のような感じになります.
結果で返しているのはresnet-18に入力した画像を推論させ,確率値が高い上位10個を表示させています.用いてるモデルはImageNetの学習済みモデルです.
DRFについて
「Django」は Python で Webアプリケーションを作成するためフレームワークですが、「Django REST Framework」という Django のためのパッケージを使うことで、RESTful な API バックエンドを簡単に構築することができます。実際の現場では、SPA(シングルページアプリケーション)やスマホアプリのバックエンドとしてよく利用されています(引用:現場で使える Django REST Framework の教科書 (Django の教科書シリーズ))。とのことですが,基本的にはDjangoで足りない所を補ってやりたいというのが,これを使っている理由です.ただし,色々な機能があるため少し重たいファイルであることはデメリットですが,それを超える良い機能が複数あるので慣れると使いやすいものかと思われます.
目次
環境構築
環境構築は以下を基本として構築しました.
https://qiita.com/michio-k/items/371881a6b8ecfa768606
ファイル構成は以下のようになります.
home |- backend | |- core(Djangoのプロジェクトが入る) | |- app(APIを作成) | |- Dockerfile | |- requirements.txt |- front | |- nuxt (フロントのプロジェクト) | |- Dockerfile |- .gitignore |- docker-compose.yml |- README.md
今回操作するのは上記のbackendの方となります.二編目でfrontの方をいじっていきます.画像のAPIを作成する上で,いくつかインストールする必要があるモジュールがあるので,記していきます.
FROM python:3.7 ENV PYTHONUNBUFFERED 1 RUN mkdir /code WORKDIR /code RUN apt-get update && apt-get install -y \ libblas-dev \ liblapack-dev\ libatlas-base-dev \ libsm6 \ libxext6 \ libxrender-dev ADD requirements.txt /code/ RUN pip install --upgrade pip RUN pip install --no-cache-dir -r requirements.txt
インストールするpythonモジュールです.ここでは,django-jsonfieldを使って辞書のデータを受け付けるようにします.実務的になると,PostgresやMySQLを使う方がいいと思われるので,そちらをデータベースとして参照する方が望ましいです.今回は簡易的なものなので,これを使わず,sqlite(デフォルトの設定)でやっていきます.これらのことは後に後述します.
* 補足として,DjangoでPostgresやMySQLなどでデータベースを使用したいときは,以下のサイトを参考にしてください. qiita.com
*
画像分類の処理では今回はPyTorchを使っていきます.pillowはDRF(Django)が画像を読み込む時に必要となるのでインストールしておきます.
# backend/requirements.txt #Django django djangorestframework django-filter django-cors-headers django-jsonfield #Extra numpy pillow opencv-python torch torchvision
次にdocker-compose.ymlを記述します.以下の通りになります.
version: '3' services: #front: #container_name: front #build: ./front #tty: true #ports: # - '3000:3000' #volumes: #- ./front/:/usr/src/app # command: [sh, -c, "cd nuxt/ && npm run dev"] backend: container_name: backend build: ./backend tty: true ports: - '8000:8000' volumes: - ./backend:/code # command: python manage.py runserver 0.0.0.0:8000
これにて環境構築が終わりです.補足ですが,Dockerを使わなくてもrequirements.txtに記述してあるモジュールをインストールしている環境であるならば,これからやることはできます.また,下のコードで先程構築した環境に入ることができます.
docker-compose build #Dockerfileの環境を立ち上げる docker-compose up -d #起動 docker exec -it backend bash #containerの中に入る
DRFのモジュール作成
DRFのプロジェクトとアプリの作成
これからDjangoのプロジェクトを作成していきます.この方法は通常のDjangoのやり方と変わりません.
django-admin startproject core . python manage.py startapp app
にてファイルを作成します.次にcore内のsettings.pyに今回作ったファイルとDRFを読み込ませます.また,各種必要なものを記述しておきます.MEDIAは画像の保存先を指定するために必要となりますので追記してください.
#python:core/settings.py ALLOWED_HOSTS = ["localhost"] INSTALLED_APPS = [ 'django.contrib.admin', 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', "rest_framework", #add "app", #add ] #add MEDIA_URL = "/media/" MEDIA_ROOT = os.path.join(BASE_DIR,"media")
resnetのファイル
この記事での推論はresnetを使用します.また,簡易的なものであるため,自前の学習済みモデルを使用せずネット上に公開されているFinetuned-modelを利用します(ImageNetです).そのため,以下にしているファイルを事前にダウンロードしてください.
- https://download.pytorch.org/models/resnet18-5c106cde.pth
- https://github.com/raghakot/keras-vis/blob/master/resources/imagenet_class_index.json
以上で初期に用意するファイル一式の準備はできました.上記のファイルとresnet用に用意するファイルは以下のように作成してください.
backend ├── Dockerfile ├── app │ ├── __init__.py │ ├── admin.py │ ├── apps.py │ ├── models.py │ ├── resnet #add │ │ ├── config │ │ │ ├── imagenet_class_index.json │ │ │ └── resnet18-5c106cde.pth │ │ ├── model.py │ │ └── predict.py │ ├── tests.py │ └── views.py ├── core │ ├── __init__.py │ ├── asgi.py │ ├── settings.py │ ├── urls.py │ └── wsgi.py ├── manage.py └── requirements.txt
上記のようなファイル構成になっていると大丈夫です!
modelsの作成
今回APIとして必要になるのは以下の通りです.
- 入力:入力した画像の名前と入力する画像
- 出力:確率値が高い上位10までのラベル一覧と確率値
となります.そのため受け付けるフィールドは三つとなります.
#python:app/models.py from django.db import models import jsonfield # from django.contrib.postgres.fields import JSONField Jsonを受け付ける class ImageModel(models.Model): name = models.CharField(max_length = 128,null=True,default="unknown") image = models.ImageField(upload_to="media") predict = jsonfield.JSONField()
ここではjsonfieldというモジュールを使って,出力するための値を受け付けます.出力する値はjson形式にしたいのですが,Django内で提供されているJSONFieldはデータベースがpostgresやMySQLなどに対応しており,設定を変更しなければ使用できません.これはDjangoのデフォルトのデータベースがsqliteであり,対応していないためエラーが起こります.この問題を解決するために,今回は設定を省略し,jsonfieldというもので簡単にsqliteが受け付けられるようにしました.
また,今回作成したデータベースを登録するためにadminの内容を変更します.以下のように記述してください.
#app/admin.py from django.contrib import admin from .models import ImageModel @admin.register(ImageModel) class ImageModel(admin.ModelAdmin): pass
serializerの作成
続いてserializerの作成です.serializerはDRF特有のものであり,通常のDjangoにはありません.詳細は記事は以下のものが参考になるかと思いますので,乗せて起きます.
- https://note.crohaco.net/2018/django-rest-framework-serializer/
- https://www.django-rest-framework.org/api-guide/serializers/
- https://qiita.com/cohey0727/items/39308b67044391103d7f
やっていることは,入力されたデータの値がModelの中身で定義した型と一緒なのか?ということをやったり,Json形式で入力されたものをPythonで読み込めるようにしたりとそんなことをやっています.
appの直下にserializers.pyのファイルを作成し,以下のように記述します.
#app/serializers.py from rest_framework import serializers from .models import ImageModel class ImageSerializer(serializers.ModelSerializer): class Meta: model = ImageModel fields = ("id","name","image","predict") read_only_fields = ('predict',"id")
上記のように今回は書きました.入力としてはnameとimageのみなので,入力に必要ないものは外しています.また,上記のMetaに関する情報はhttps://teratail.com/questions/87695 が参考になるかと思いますので適宜参考にしてみてください.
viewsの作成
今回はDRFのビューはクラスベースビューを用いて,ModelViewSetを使って見ました.また,postを受け付けるコードをactionで対応するようにしました(これはdef post()メソッドを用いてもらっても大丈夫です).
#app/views.py from rest_framework import viewsets from rest_framework.response import Response from rest_framework.decorators import action from rest_framework import status from .models import ImageModel from .serializers import ImageSerializer from .resnet.predict import predict #resnetの予測 class ImageViewSet(viewsets.ModelViewSet): queryset = ImageModel.objects.all() serializer_class = ImageSerializer #check ユーザのどんなクエリを受け付けるか @action(detail=False,methods=["post"]) def classification(self,request): serializer = self.serializer_class(data = request.data) serializer.is_valid(raise_exception=True) img = request.data["image"] name = request.data["name"] res = predict(img) # 保存 item = ImageModel(name=name, image=img,predict = res) item.save() return Response(res, status=status.HTTP_200_OK)
serializer_classで受け付ける入力フォームを決めています.
@action(detail=False,methods=["post"]) def classification(self,request): serializer = self.serializer_class(data = request.data) serializer.is_valid(raise_exception=True) img = request.data["image"] name = request.data["name"] res = predict(img) # 保存 item = ImageModel(name=name, image=img,predict = res) item.save() return Response(res, status=status.HTTP_200_OK)
serializer = self.serializer_class(data = request.data)は入力されたデータが正しいかどうかを検証するためにいれています.極端な話,PDFのファイルが入力された時,エラーを出力してくれます.serializer.is_valid(raise_exception=True)を記述するとこの時点でエラーのデータがあるとエラーの文章で値が返されます.また,記述方法として,
if serializer.is_valid(): img = request.data["image"] name = request.data["name"] return Response(serializer.data,status=status.HTTP_200_OK) else: return Response(serializer.errors,status=status.HTTP_400_BAD_REQUEST)
があります.これは,if文章でTrueかFalseを処理して中身を実行するかどうかを判断しています.しかし,この書き方は若干情緒でもあるのでこれを省略しserializer.is_valid(raise_exception=True)だけを記述することで,上記のことと全く同じようにしてくれます.
ifの中身は,predict(img)でdictの結果を返してresで受け付けています.このresは先ほどの確率値の上位10個が入っている値の一覧が格納されています.これらのデータをItemModel()に入れ,保存しています.
urlの繋ぎこみ
繋ぎこみをします.app以下にurls.pyのファイルを作成して以下のように記述してください.
#app/urls.py from rest_framework import routers from .views import ImageViewSet router = routers.DefaultRouter() router.register(r"^image",ImageViewSet)
#core/urls.py from django.contrib import admin from django.urls import path,include from django.conf import settings from django.conf.urls.static import static from app.urls import router as router urlpatterns = [ path('admin/', admin.site.urls), path("",include(router.urls)) ] if settings.DEBUG: urlpatterns += static(settings.MEDIA_URL,document_root = settings.MEDIA_ROOT)
繋ぎこみの際にファイル名を注意してください.これでlocalhost:8000/image/classificationとURLを入れた時に,APIをPOSTできるようになります.
これでDRFで記述すべきことは終わりました.次に,resnetの方を記述していきます.
resnetの作成
resnetのディレクトリの中のmodel.pyを作成します.以下のように記述してください.
#app/resnet/model.py import torch from torchvision import models def resnet_model(): MODEL_PATH = "./app/resnet/config/resnet18-5c106cde.pth" model = models.resnet18(pretrained=False) model.load_state_dict(torch.load(MODEL_PATH)) model.eval() return model
今回は簡易的に作っているため,MODEL_PATHをこんな風にPathを書くことはナンセンスだと思うので注意してください(笑).また,model.eval()を忘れないでください(これを書くの忘れて何時間も悩んだのは裏の話).
次にpredict.pyを作成します.以下のように記述してください.
#app/resnet/predict.py from PIL import Image import json import cv2 import numpy as np import torch import torch.nn as nn from torchvision import transforms from .model import resnet_model preprocess = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) softmax = torch.nn.Softmax(dim=1) model = resnet_model() def predict(img): # json with open("./app/resnet/config/imagenet_class_index.json", 'r') as f: image_dict = json.load(f) # img img = Image.open(img) img = img.convert('RGB') img = np.array(img) img = preprocess(transforms.ToPILImage()(img)).unsqueeze(0) # prediction predict = model(img).data prob = softmax(predict)[0].tolist() best_ten = np.argsort(prob)[::-1][:10] response = [] for i,rank in enumerate(best_ten,1): label = image_dict[str(rank)][1] response.append({"rank":i,"prob":prob[rank],"label":label}) return response
入力された画像はImage.open(img)にて読み込みます.cv.imread(img)で読み込むことはできないので注意してください.また,
preprocess = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])
この部分はImageNetの学習方法と同じようにしているので,これがなければ出力で欲しいラベルが返って来なくなります pytorch.org.
returnで返しているのは配列であり,中身は辞書型になっています.probの確率値はtolist()でリスト型にしています.これはデータベースにデータが保存されるときに,データの型がnumpyであるとエラーの原因になるためです(これはちょっとはまり所でした).そこで,tolist()で通常の数値型に変更し,エラーの原因を未然に防ぎます.参照したのは以下のページです.
APIを使って推論
ここまででAPIは作れたので,実際の画面で確認していきます.
python manage.py makemigrations python manage.py migrate
をしてください.これはDjangoの定型文みたいなものなので,そうなんだーみたいな感じでやってください(これはデータベースを作ってくれたりしている).また,Modelsの中身を変更したり,追加する場合は上記のコードをもう一度入力してください.すると,更新されます.
python manage.py runserver 0.0.0.0:8000
を実行しlocalhost:8000/imageのurlを検索すると以下のような画面が出てくると思います.
ここに保存されたデータの一覧が出力されるようになります.localhost:8000/image/classificationのurlを検索するとAPIをPOSTできる画面が出力されるのでやって見てください.
まとめ
今回は画像を入力として受け付けて,上位10個の確率値とラベルを出力するAPIを作成しました.モデルはresnetを使用し,ImageNetの学習済みモデルで値を出力させました.これらは適宜自分のモデルの差し替えが可能なので他のものでも試していきたいと思います.
次回
次はNuxtを利用して,front側を作成していきたいと思います.入力フォームを作成し,実際にユーザが画像を投稿するようなイメージで作成します.その画像をAPIに投げ,値が返ってくるところまでを実装し,画面に表示させます.