FastAPI入門 - モダンなPythonフレームワークの特性をチュートリアルで手軽に学ぶ
PythonのWebフレームワークとしていま注目を集めるFastAPIは、シンプルにコードが書けるだけでなく、パフォーマンスが高いWebアプリケーションのバックエンドサーバーが構築可能です。同フレームワークの勘所をPythonスペシャリストの杜世橋さんが、初心者向けのハンズオン、そしてより実践的な画像への自動タグ付けサービス実装をとおして解説します。
FastAPIはいま非常に注目されているPythonのWebフレームワークの1つです。Flaskのようにシンプルに書ける一方でPythonのType Hintの機能をうまく活用し、HTTPのリクエスト/レスポンスをPythonの関数の引数/戻り値とシームレスにマッピングして非常に効率的に開発ができるのが最大の特徴です。非同期処理にも対応していてその名の通りとてもパフォーマンスが高いWebアプリケーションのバックエンドサーバー(以降WebAPIと表記)を書くことができます。
本稿はこうしたFastAPIの特性をスピーディに学べるよう、特に重要度が高いと思われる要素にフォーカスした4章で構成しています。まず、1~2章ではWebAPIサーバーを開発するうえで大事なHTTPのリクエストの解析とレスポンスの作成、およびDBとの連携について解説していきます。続く第3章ではハイパフォーマンスなWebAPIサーバーを開発するうえで重要な、非同期処理について説明していきます。そして最後の第4章ではFastAPIとPythonの機械学習のライブラリを組み合わせ、簡単なタグ付きの画像保存サービスを開発していきます。
開発環境
本稿では実行するOS環境はとくに指定しません。使用するライブラリはWindows/Mac/Linuxいずれでも提供されていますので問題なくインストールできるはずです。筆者はWindows 11のWSLを用いてUbuntu 22.04をインストールし、その中のデフォルトのPython 3.10で動作確認を行っています。ライブラリのインストールにはpip
を利用してもいいですが、筆者はPDMというパッケージマネージャを利用しています。FastAPIなどの実際にプログラムに含めるライブラリ以外にもblack
, ruff
, pyright
などのformatterとlinterを使用しています。なお、エディタはVSCodeを利用し、以下の3つの拡張を有効にします。
ms-python.black-formatter
ms-python.python
charliermarsh.ruff
formatterとlinter、そしてVSCodeの設定についてはこちらのリポジトリにあるpyproject.toml
とvscode_template
を参照してください。
入門編1:FastAPIのポイントをスピーディに学ぶ
まずはPDMを利用してFastAPIの開発環境を構築していきます。PDMをインストール後、今回の開発のためのディレクトリを作成し、以下のような構成にします。
/ src/ webapp/ __init__.py
このディレクトリのトップで以下のコマンドを入力し、プロジェクトを初期化します。
pdm init
途中でいくつか質問されますが、Pythonのバージョンとして3.10を利用すること以外は基本的にはデフォルトの選択肢で問題ないでしょう。その後、formatterとlinterを以下のコマンドでインストールします。types-SQLAlchemy
はSQLAlchemyをlinterに正しく認識させるための追加の型情報の補助パッケージです。
pdm add -d black ruff pyright types-SQLAlchemy
最後にFastAPIとUvicornをインストールします。UvicornはFastAPIで開発したアプリを実際に実行するアプリケーションサーバーです。Flaskなどで利用するgunicornと同様の物ですが、Pythonの非同期WebアプリインターフェースのASGIに対応しているのが違いです。
pdm add fastapi uvicorn
これでインストールは完了です。__init__.py
に以下のようにhelloを返すだけの処理を書いてみます。
from fastapi import FastAPI app = FastAPI() @app.get("/") def hello(): return "hello"
@app.get("/")
はFlaskなどでもおなじみの書き方で、デコレータを用いて関数とルートの紐づけを行っています。本記事ではこのデコレータ部分をルートの登録、登録される関数をハンドラー関数と呼びます。
このFastAPIアプリをUvicornで起動してみます。
pdm run uvicorn webapp:app --reload --no-server-header --no-date-header
pdm run
の部分はPDMで管理している仮想環境で実行するためのコマンドです。webapp:app
はモジュール名:変数名
の順番で書きます。今回はwebapp.__init__
にあるapp
を実行したいのでwebapp:app
という表記になっています。複数階層のとき、例えばPythonのimport文でx.y.app
と書けるときにはx.y:app
となります。--reload
はソースコードを変更したときに自動的にリロードするオプションで開発時によく使用します。起動すると以下のように表示されるはずです。デフォルトでは127.0.0.8:8000
にbindします。
$ pdm run uvicorn webapp:app --reload --no-server-header --no-date-header INFO: Started server process [9328] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
この状態でもう一つ別のターミナルを立ち上げ、HTTPのクライアントでリクエストを送ってみましょう。筆者はxhをよく利用します。
$ xh :8000 HTTP/1.1 200 OK Content-Length: 7 Content-Type: application/json "hello"
上記のようにhello
と返ってくれば成功です。
リクエストとレスポンスの処理
FastAPIではHTTPのリクエストとレスポンスをそのままPythonの関数の引数と戻り値にマッピングできますが、これこそ最も素晴らしい特徴といって過言ではありません。他のフレームワークでは事前にリクエストをある程度パースしてヘッダーやquery params、bodyなどを辞書型に格納し、開発者はキーを手動で指定し、そこから値を取得するという方式が一般です。これに対してFastAPIでは、関数の引数にマッピングするので必要なリクエストの内容がすぐにわかる可読性が高いコードになります。また、PythonのType Hint機能と合わせることで自動で型の変換やチェックなども行えます。
この章では、こうしたFastAPIの魅力的な機能を体験してみます。以下のライブラリを使用しますのでPDMやpipでインストールしてください。
python-multipart
また、以下のような構成のディレクトリを作成し、
/ src/ webapp/ __init__.py chapter1/ __init__.py
実行する際には以下のコマンドを使用します。
uvicorn webapp.chapter1.app --reload --no-server-header --no-date-header
Queryパラメータの取得
まずはHTTPのGETリクエストでよく利用されるQueryパラメータの取得の仕方について見ていきます。ここではa
とb
という2つのキーで数値型の値を受け取り、足し算をした結果を返すAPIを作ってみます。
from fastapi import FastAPI app = FastAPI() @app.get("/add") def add(a: int, b: int): return {"result": a + b}
この状態でアプリ本体を実行し、/add
にGETリクエストを送ってみます。
$ xh ":8000/add?a=1&b=2" HTTP/1.1 200 OK Content-Length: 12 Content-Type: application/json { "result": 3 }
上記のように、ちゃんとQueryパラメータのa
とb
が関数add
の引数にマッピングされ、文字列を整数に変換して足し算が正しく計算されたことがわかります。もしa
またはb
のいずれかが指定されなかったり、あるいは整数に変換できないデータが指定された場合はどうなるでしょうか。
$ xh ":8000/add?a=1" HTTP/1.1 422 Unprocessable Entity Content-Length: 86 Content-Type: application/json { "detail": [ { "loc": [ "query", "b" ], "msg": "field required", "type": "value_error.missing" } ] } $ xh ":8000/add?a=1&b=text" HTTP/1.1 422 Unprocessable Entity Content-Length: 99 Content-Type: application/json { "detail": [ { "loc": [ "query", "b" ], "msg": "value is not a valid integer", "type": "type_error.integer" } ] }
上記の2つの例のように、パラメータが足りなかったり、あるいは目的の型に変換できないデータが入っていた場合には自動的にHTTPの422エラーが返され、さらにどのパラメータが問題なのかまでエラーメッセージで教えてくれるというとても親切な設計になっています。
BodyのJSONのパース
POST/PUT/PATCHメソッドなどでBodyにJSONとしてリクエストパラメータが送られてくる場合、事前にリクエストのスキーマを定義しておく必要があります。FastAPIではこうした定義のためにPydanticを利用します。実はPythonの標準機能であるdataclass
で定義してもFastAPIは自動的にPydanticのモデルに変換して使用してくれますが、一部の機能に制限があるので本稿ではPydanticを使用していきます。なお、Pydanticは型以外にもさまざまなデータバリデーションを行えるライブラリですが、この名前は規則にこだわりすぎるという意味を表すpedanticに由来していると筆者は思います。
まずここでは以下のようなUserBase
クラスを定義します。
python-multipart
import enum import pydantic from fastapi import status class ProgrammingLanguage(str, enum.Enum): CPP = "c++" Python = "python" Java = "java" JavaScript = "javascript" Rust = "rust" Go = "go" Other = "other" class UserBase(pydantic.BaseModel): name: str age: int = pydantic.Field(ge=0) favorite_language = ProgrammingLanguage class UserIn(UserBase): pass class User(UserBase): id: int = pydantic.Field(ge=0)
名前、年齢、好きなプログラミング言語の3要素からなるクラスです。名前は通常の文字列型ですが、年齢はただの整数ではなく、Pydanticの機能を利用して0以上という条件をつけています。プログラミング言語は合計7つの選択肢があるとし、文字列をそのまま使うのではなくEnumを利用します。さらに、これらクラスを「インプットとして使う」ということを明確にするため、これを継承して別名をつけただけのUserIn
も定義します。また、実際にDBに保存する際にはidも付与するのでこれもUserBase
を継承したUser
として定義します。これでスキーマの準備完了です。
次にAPIを作成します。ここでは外部データベースを利用せず、メモリ内のリストに保存するという単純な仕組みで作ります。
users: list[User] = [] @app.post("/user", status_code=201) def post_user(user_in: UserIn): user = User(id=len(users), **user_in.dict()) users.append(user) return {"user_id": user.id}
REST APIではPOSTで新しいリソースを作った場合、200 OKではなく201 CREATEDを返すのが適切なので、ルートの設定をする際にstatus_code
で指定しています。関数の引数のuser_in
の型としてUserIn
を指定していますが、これだけでFastAPIはPydanticの機能を利用してJSONをパースするだけでなく、値のバリデーションまで実施してくれます。実際に試してみましょう。
$ xh post ":8000/user" name=Taro age=25 favorite_language=python HTTP/1.1 201 Created Content-Length: 13 Content-Type: application/json { "user_id": 0 }
上記の通り、正しい種類の値を指定するとリクエストが成功しました。続いて失敗するパターンを試してみます。
$ xh post ":8000/user" name=Taro age=-1 favorite_language=python HTTP/1.1 422 Unprocessable Entity Content-Length: 150 Content-Type: application/json { "detail": [ { "loc": [ "body", "age" ], "msg": "ensure this value is greater than or equal to 0", "type": "value_error.number.not_ge", "ctx": { "limit_value": 0 } } ] } $ xh post ":8000/user" name=Taro age=25 favorite_language=ruby HTTP/1.1 422 Unprocessable Entity Content-Length: 274 Content-Type: application/json { "detail": [ { "loc": [ "body", "favorite_language" ], "msg": "value is not a valid enumeration member; permitted: 'c++', 'python', 'java', 'javascript', 'rust', 'go', 'other'", "type": "type_error.enum", "ctx": { "enum_values": [ "c++", "python", "java", "javascript", "rust", "go", "other" ] } } ] }
ageに負の値を使用したり、定義外のfavorite_languageの値を使用するとバリデーションに失敗してエラーが戻りました。
Pathパラメータの取得
REST APIでは GET /user/:id
のようにURLのpathでidなどの値を表現することが多いです。FastAPIではこれもQueryパラメータと同様に簡単に型付きで指定できます。先ほど作ったUser
を取得するAPIを作ってみましょう。
@app.get("/user/{id}", response_model=User) def get_user(id: int): if id >= len(users): raise HTTPException(status.HTTP_404_NOT_FOUND, f"user of id={id} not found") return users[id]
ルートの登録時に{id}
のように{}
で囲むことで、その部分が変数であると認識され、実際の関数の引数にマッピングされます。
$ xh ":8000/user/0" HTTP/1.1 200 OK Content-Length: 60 Content-Type: application/json { "name": "Taro", "age": 25, "favorite_language": "python", "id": 0 }
なお、この例ではデータはメモリ上にしかないのでプログラムを起動しなおすとデータが消えてしまいます。GET /user/{id}
のテストをする際には事前に先ほど作成したPOST /user
で値を登録しておく必要があります。
Fileのアップロード
POSTリクエストでmultipart/form-data
形式でファイルを受け取る場合、前述のpython-multipart
というライブラリを利用します。インストールするだけでFastAPIが必要な時に自動的に利用するので特にimportは不要です。
import shutil from pathlib import Path from fastapi import UploadFile DATA_DIR = Path("/tmp/data") @app.post("/upload") def upload_file(data: UploadFile): with DATA_DIR.joinpath(data.filename).open("wb") as f: shutil.copyfileobj(data.file, f) return {"filename": data.filename}
Formでアップロードされるファイルを受け取る際は、引数の型にUploadFile
型を指定します。また、引数の名前はContent-Disposition
のname
で指定されるkeyの値とそろえる必要があります。UploadFile
にはfilename
とfile
の2つの重要なプロパティがあり、前者は元のファイル名の文字列、後者はPythonのファイルオブジェクトになっていますのでread
メソッドでデータを読み取れます。上記の例では直接read
メソッドは使用せず、より効率的な標準ライブラリのshutil
の関数を利用しています。
$ echo test > test.txt $ xh post -f ":8000/upload" [email protected] HTTP/1.1 200 OK Content-Length: 23 Content-Type: application/json { "filename": "test.txt" }
xhコマンドでファイルをアップロードする場合は-f
オプションでJSONではなくForm形式であることを指定し、 key@filepath
の形式でファイルを指定します。
OpenAPIドキュメントの自動生成
この時点ですでにFastAPIは自動的にOpenAPIのドキュメントを自動生成しており、/docs
にアクセスすることで閲覧できます。
各ルートのドキュメントを見ると入力のparametersには正しく型まで反映されていることがわかります。一方でresponseの方はGET /user/{id}
以外はschemaが正しく反映されていません。FastAPIではresponseのschemaを正しく反映させるにはルートを設定する際にresponse_model
を指定する必要があります。GET /user/{id}
のルートのみこれを設定していましたのでschemaも正しく反映されています。しかし、GET /user/{id}
200と422以外にも存在しないidを指定したときに404を返しますがそれがまだ反映されていません。デフォルトの戻り以外についてはresponses
を指定することで対応できます。コードを以下のように修正します。
class ErrorMessageOut(pydantic.BaseModel): detail: str @app.get( "/user/{id}", response_model=User, responses={status.HTTP_404_NOT_FOUND: {"model": ErrorMessageOut}}, ) def get_user(id: int): if id >= len(users): raise HTTPException(status.HTTP_404_NOT_FOUND, f"user of id={id} not found") return users[id]
再度、FastAPIが自動生成するドキュメントのWebページを見ますと今度は404もresponseに追加されています。
Webページではなく、OpenAPIの定義ファイルが必要になるケースもありますが、その場合はFastAPI
オブジェクトのopenapi
メソッドを呼ぶことでPythonの辞書型でOpenAPIドキュメントのデータを取得することができます。あとはこれをJSONファイルやYAMLファイルとして書き出せば完了です。
import json from app import app if __name__ == "__main__": doc = app.openapi() json.dump(doc, open("openapi.json", "w"), indent=4)
入門編2:データベースとの連携
前章ではリクエストで受け取ったデータをメモリの中にそのまま持っていたので、プログラムを終了するとデータが全て消えてしまいました。実際のWebAPIの開発ではデータをデータベース(DB)に保存して永続化することが最も基本的な要件になります。本章では代表的なRDBMSであるPostgresとNoSQLのMongoDBを対象に、FastAPIのアプリとDBとの連携について説明していきます。なお、PostgresやMongoDBは別途Dockerなどを利用して読者自身で準備していただく必要があります。
この章では新たに以下のライブラリを使用しますのでPDMやpipでインストールしてください。
sqlalchemy
pg8000
pymongo
SQLAlchemyを利用したRDBMSとの連携
前章のUserと同じモデルを利用します。まずは以下のテーブルをPostgres内に作成してください。
create table users ( id serial primary key, name varchar(64), age integer, favorite_language varchar(16) );
次に前章と同様にサブディレクトリchapter2_1
を作成し、実行時もwebapp.chapter2_1:app
を指定してください。
/ src/ webapp/ __init__.py chapter1/ __init__.py chapter2_1/ __init__.py model.py schema.py
SQLAlchemyはPythonの代表的なORMです。FastAPIではPydanticのクラスとSQLAlchemyのクラスを相互に変換して利用します。HTTPのリクエストやレスポンスを扱う際にはPydantic、RDBMSとのデータのやり取りを扱う際にはSQLAlchemyのクラスを利用するのが、基本的な考え方です。FastAPIの公式ドキュメントではPydanticのクラスはschema
, SQLAlchemyのクラスはmodel
というモジュールで別々に管理する方法をとっていますので、本稿でもそれに倣います。まずchapter2_1/model.py
に以下の内容を書きます。
from sqlalchemy import Column, Integer, String from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) name = Column(String(64)) age = Column(Integer) favorite_language = Column(String(16))
事前に作成したusers
テーブルの構造をそのままSQLAlchemyで書いています。次にchapter2_1/schema.py
にはPydanticのクラスを書いていきますが、これは基本的には前章のPydanticで作ったUser
群と同じで、最後のUser
についてのみ以下のように追記します。
class User(UserBase): id: int = pydantic.Field(ge=0) class Config: orm_mode = True
変更点は最後のorm_mode
についての記載のみです。これを書くことにより、Pydanticはobj.a
のようなattribute方式でもインプットを解析しようとし、SQLAlchemyで作ったクラスに対応できるようになります。次に新規のUser
をDBに登録する関数と、指定したid
を持つUser
をDBから検索する関数を作成します。以下の内容をchapter2_1/__init__.py
に書きます。
from sqlalchemy import select from sqlalchemy.orm import Session from . import model, schema def _create_user(user: model.User, session: Session): session.add(user) session.commit() return user.id def _get_user(id: int, session: Session): stmt = select(model.User).where(model.User.id == id) user = session.execute(stmt).scalar_one_or_none() return user
最後にこれらを使用したハンドラー関数を作成し、ルートに登録して完成です。
from fastapi import FastAPI, status from sqlalchemy import create_engine DB_URI = "postgresql+pg8000://<user>:<password>@<host>:<port>/<dbname>" engine = create_engine(DB_URI) app = FastAPI() @app.post("/users") def create_user(user_in: schema.UserIn): user = model.User(**user_in.dict()) with Session(bind=engine) as session: user_id = _create_user(user, session) return {"user_id": user_id} @app.get("/users/{id}", response_model=schema.User) def get_user(id: int): with Session(bind=engine) as session: user = _get_user(id, session) if user is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"user of id={id} not found") return user
前述の通り、HTTPのリクエスト/レスポンスはPydantic、RDBMSとのやり取りはSQLAlchemyのクラスを用いている点に注目してください。POST /user
ではPydanticでリクエストを解析し、これをSQLAlchemyのクラスに変換してPostgresに入れています。一方でGET /users/{id}
ではPostgresからSQLAlchemyのクラスでデータを受け取り、Pydanticのクラスに変換してHTTPのレスポンスを作成しています。なお、ルートを登録するときにresponse_model
を指定することでSQLAlchemyのクラスをreturnしても自動でPydanticのクラスに変換してくれます。もし自分で手動変換したい場合はschema.User.from_orm
メソッドを利用します。
(注1) 本記事を執筆している時点(2022年12月)でSQLAlchemyの最新安定板は1.4系です。しかし、すでに2系の開発が進行中であり、本記事でもできるだけ2系でも使用可能な関数を利用して書きました。
(注2) SQLAlchemyのModelとPydanticのSchemaの両方を定義するというのは冗長であると感じる方も多いと思います。FastAPIの作者はこの問題を解決するためにSQLModelというライブラリも公開していますが、こちらは本稿執筆時点でまだバージョンが0.0.8でしたので今回は使用しませんでしたが、興味がある方はぜひ試してみてください。
PyMongoを利用したMongoDBとの連携
MongoDBの場合にはRDBMSにおけるSQLAlchemyのような定番のODM(MongoDBではORMではなくODM, Operation Document Mapperという)が存在せず、複数あるうちから開発者が選択することになります。例えばBeanieやODManticなどがPydanticをベースに利用されていてFastAPIとの親和性が高いですが、それらの説明は公式のチュートリアル譲り、ここではあまり情報がない、ODMを使用せずに公式の接続ドライバであるPyMongoのみを使った場合について説明します。MongoDB以外の新しいNoSQLを利用する場合には必ずしもPydanticを利用したODMが存在しないこともありますので、そういった場合には以下の説明が役に立つと思います。以下の内容はchapter2_2というサブディレクトリ内で開発していきます。
/ src/ webapp/ __init__.py ... chapter2_2/ __init__.py schema.py
PyMongoを利用する場合、MongoDBとはPythonのdictを用いてデータのやり取りします。よって、前節でSQLAlchemyのモデルを定義したようなことは行わず、model.py
は必要ありません。一方でPydanticのschemaについてはchapter2_1のschema.py
をベースに以下のように微修正をします。
def _new_object_id(): return str(ObjectId()) class User(UserBase): id: str = pydantic.Field(default_factory=_new_object_id, alias="_id")
MongoDBでは全てのドキュメントにプライマリーキーとして_id
という要素が必ず入っています。これは整数ではなくObjectIDという特殊な型になっています。ただ、これは文字列に変換できますのでUser.id
は文字列とし、デフォルトの値として都度ObjectIDの文字列を生成して入れることにします。また、alias
を利用することでクラスのメンバー変数の名前と実際ににエンコード/デコードするときのキーの名前を変更できます。今回の例ではMongoDBの_id
をUser
のid
に対応させます。あとは前節と同じように__init__.py
にCRUDの補助関数を作り、それをFastAPIのルートに登録するハンドラー関数の内部で呼びます。
補助関数は以下の通りです。SQLAlchemyのSession
ではなく、PyMongoのCollection
を引数にとります。
from typing import Optional from pymongo.collection import Collection from . import schema def _create_user(user: dict, col: Collection): col.insert_one(user) def _get_user(id: str, col: Collection) -> Optional[dict]: return col.find_one({"_id": id})
ルートに登録するハンドラー関数は以下の通りです。
import pymongo from fastapi import FastAPI, status from fastapi.exceptions import HTTPException from pymongo.database import Database db: Database = pymongo.MongoClient( "mongodb://<user>:<password>@<host>:<port>/<dbname>", connect=False ).get_default_database() col = db["users"] app = FastAPI() @app.post("/users") def create_user(user_in: schema.UserIn): user = schema.User(**user_in.dict()) _create_user(user.dict(by_alias=True), col) return {"user_id": user.id} @app.get("/users/{id}", response_model=schema.User, response_model_by_alias=False) def get_user(id: str): user = _get_user(id, col) if user is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"user of id={id} not found") return user
注意するポイントは2つです。まずPOST /users
の際にはPydanticのUser
でリクエストをパースしますが、これをPythonのdictに変換してPyMongoに渡します。この時id
は_id
としてエンコードする必要があるのでby_alias=True
を指定します。
一方でGET /users/{id}
の際にはPyMongoから受け取ったdictをresponse_model
を指定することでPydanticのUser
に変換しますが、FastAPIはデフォルトではHTTPのレスポンスへのエンコードにaliasを使用するのでレスポンスの中にはid
ではなく_id
と書かれてしまいます。これを防ぐためにさらにresponse_model_by_alias=False
を使用しています。
$ xh post ":8000/users" name=Taro age=30 favorite_language=python HTTP/1.1 200 OK Content-Length: 38 Content-Type: application/json { "user_id": "63967c8a27096a36b3cb32d5" } $ xh ":8000/users/63967c8a27096a36b3cb32d5" HTTP/1.1 200 OK Content-Length: 85 Content-Type: application/json { "name": "Taro", "age": 30, "favorite_language": "python", "id": "63967c8a27096a36b3cb32d5" }
コードの構造化に関するTips
この章ではSQLAlchemyやPyMongoの接続情報が入ったオブジェクトをトップレベルでグローバル変数として宣言し、ハンドラー関数がこのグローバル変数をそのまま利用するという方式で書いてきました。実はFastAPIの公式チュートリアルでもこの方式で書かれています。これは個人のスタイルの問題かもしれませんが、筆者はグローバル変数を共有する書き方はあまり好まず、初期化時により明示的に指定する方法がいいと考えています。ここでは筆者がよく利用する方法を紹介します。この方法はDBの接続オブジェクト以外にもAWS SDKのインスタンスや機械学習の学習済みモデルを含んだAPIを開発する際に幅広く応用できます。この節では以下のようなディレクトリ構成にします。
/ src/ webapp/ __init__.py chapter2_3/ __init__.py model.py schema.py routes/ __init__.py users.py
まずchapter2_1で作成したUser
の操作に関する処理をまとめてchapter2_3/routes/users.py
切り出します。
from fastapi import APIRouter, status from fastapi.exceptions import HTTPException from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from .. import model, schema def _create_user(user: model.User, session: Session): session.add(user) session.commit() return user.id def _get_user(id: int, session: Session): stmt = select(model.User).where(model.User.id == id) user = session.execute(stmt).scalar_one_or_none() return user
そしてポイントとなるのがこの後で、FastAPIオブジェクトやSQLAlchemyのEngineオブジェクトをグローバル変数を使用しないでハンドラー関数に渡す必要があります。筆者は以下のように引数で必要なDBオブジェクトなどを受け取る関数を作り、内部ではFastAPIオブジェクト自体ではなく、APIRouter
オブジェクトを作成し、ハンドラー関数を登録してAPIRouter
を返すようにします。
def create_router(engine: Engine): api = APIRouter() @api.post("") def create_user(user_in: schema.UserIn): user = model.User(**user_in.dict()) with Session(bind=engine) as session: user_id = _create_user(user, session) return {"user_id": user_id} @api.get("/{id}", response_model=schema.User) def get_user(id: int): with Session(bind=engine) as session: user = _get_user(id, session) if user is None: raise HTTPException( status.HTTP_404_NOT_FOUND, f"user of id={id} not found" ) return user return api
こうすることでchapter2_3/__init__.py
に書くべき内容は以下のようにまで簡略化されます。
from fastapi import FastAPI from sqlalchemy import create_engine from .routes import users DB_URI = "postgresql+pg8000://<user>:<password>@<host>:<port>/<dbname>" engine = create_engine(DB_URI) app = FastAPI() user_router = users.create_router(engine) app.include_router(user_router, prefix="/users")
SQLAlchemyのcreate_session
オブジェクトをusers.create_router
に明示的に渡してuser_router
オブジェクトを作成し、これを本体にapp.include_router
にprefixを指定して取り込むという書き方になりました。このような書き方のメリットとしてはusers
の機能がコンポーネント化され、再利用やテストがしやすくなるという点が挙げられます。users
以外にも複数のルートがある場合は以下のようにそれぞれ必要な情報を受け取ってrouterオブジェクトを作り、本体にapp.include_router
で取り込んでいきます。
from .routes import items, users user_router = users.create_router(create_session) app.include_router(user_router, prefix="/users") item_router = items.create_router(create_session, other_obj) app.include_router(user_router, prefix="/items")
本稿ではDB_URI
はソースコードに直接書いていますが、実際の開発ではこれをコマンドライン引数や環境変数などから取得してください。
入門編3:非同期処理
非同期処理は現代的なWebAPIを開発するうえで必要不可欠な技術です。Webサーバーの主な負荷はIO待ちであり、非同期処理を用いることで少ないリソースで同時にたくさんのリクエストを処理できるようになります。Pythonでも3.5からasync/await構文を用いた非同期処理が導入され、FastAPIはこれを使用して非同期WebAPIを開発することができます。
この章では新たに以下のライブラリを使用しますのでPDMやpipでインストールしてください。
asyncpg
非同期処理入門
まずは簡単な処理を用いてPythonの非同期関数開発を体験してみます。以下のソースコード中のsync_f
という関数はという関数は5秒待機を2回繰り返してOKと表示します。すなわち合計で10秒待たされます。
import time def sync_f(): time.sleep(5) time.sleep(5) print("OK") sync_f()
これを非同期処理にするには以下のように書き換えます。
import asyncio async def async_f(): co1 = asyncio.sleep(5) task1 = asyncio.create_task(co1) co2 = asyncio.sleep(5) task2 = asyncio.create_task(co2) await task1 await task2 print("OK") asyncio.run(async_f())
async def
で定義された関数はもともとの戻り値をCoroutine
という構造に入れて返します。これはJavascriptでいうPromiseに相当しますが、この時点ではまだ実行されておらず、asyncio.create_task
という関数でTask
に変換すると実行開始されます。そしてawait
というキーワードを使用して非同期関数が完了するまで待ち、Task
あるいはCoroutine
の内部の値を取り出します。ただ、asyncio.sleep
は中身がNone
のCoroutine
を返すのでここでは戻りとは受け取りません。async def
された関数は同じくasync def
された関数の中でしか実行できず、外から実行する場合にはasyncio.run
を使用します。この場合は自動的に完了するまで待たされます。上記のソースコードを実行すると今度はおよそ5秒で完了します。これは2つの「5秒待機」が同時に非同期に実行されるからです。
上記の例では5秒間sleepしましたが、これは実際のWebアプリにおけるDBサーバーとの通信待ちを模したものです。1つのスレッドで同期的に書くとIO待ちも直列で待ち時間が増えるということ体感できたかと思います。実際のWebAPIのサーバーには、数千のクライアントから同時にアクセスが来ることも珍しくありません。これを解決する方法としてマルチスレッド処理という方法がありますが、スレッドの生成にはコストがかかります。非同期処理は1つのスレッドで同時に複数のIOを処理する方法であり、IOに関してはマルチスレッドよりも高い効率が期待できます。もちろん両者を組み合わせることも可能です。
FastAPIでもasync def
を用いれば非同期のハンドラー関数をルートに登録することができます。
@app.get("/") async def hello_5(): await asyncio.sleep(5) return "hello"
RDBMSとの非同期連携
ここでは非同期処理を利用してRDBMSとデータをやり取りする方法を見ていきます。基本的には前節と同じで同期のIO処理関数を非同期ので置き換え、async/awaitなどのキーワードを足していくだけです。第2章の1節(chapter2_1)のコードのうち、schemaとmodelは全く変えずに再利用します。__init__.py
の内容を以下のように書き換えます。
from fastapi import FastAPI, status from fastapi.exceptions import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from . import model, schema DB_URI = "postgresql+asyncpg://<user>:<password>@<host>:<port>/<dbname>" engine = create_async_engine(DB_URI) async def _create_user(user: model.User, session: AsyncSession): session.add(user) await session.commit() await session.refresh(user) return user.id async def _get_user(id: int, session: AsyncSession): stmt = select(model.User).where(model.User.id == id) user = await session.execute(stmt) return user.scalar_one_or_none() app = FastAPI() @app.post("/users") async def create_user_(user_in: schema.UserIn): user = model.User(**user_in.dict()) async with AsyncSession(bind=engine) as session: user_id = await _create_user(user, session) return {"user_id": user_id} @app.get("/users/{id}", response_model=schema.User) async def get_user_(id: int): async with AsyncSession(bind=engine) as session: user = await _get_user(id, session) if user is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"user of id={id} not found") return user
変更点をまとめます。
- ドライバーは非同期に対応した
asyncpg
を利用する。 - SQLAlchemyの
Engine
とSession
をAsyncEngine
とAsyncSession
に置き換える。 - 非同期処理をする関数を
async def
で定義し、内部で別の非同期関数を呼ぶ場合は結果をawait
で待つ。
基本的にはこれだけです。筆者も以前は非同期処理に苦手意識を持っていましたが、実際にはasync/awaitなどのキーワードを理解して適切に使用するだけで書き方は同期的な場合とほとんど変わりません。
非同期処理に対応していないライブラリの場合
前節ではSQLAlchemyは非同期処理にも対応しているのでasync/awaitを利用して非同期関数に変更することができました。しかし同じく第2章で扱ったPyMongoは非同期に対応していません。こういった場合には非同期に対応したライブラリ、例えばMotorなどに置き換えることも考えられますが、もし非同期ライブラリが無い場合はどうすればいいのでしょうか?Pythonのasyncio
モジュールには同期関数を無理やり非同期関数として実行するrun_in_executor
という関数があります。この関数は内部でconcurrent.futures.ThreadPoolExecutor
で実行することで非同期関数の見せかけていますが実はマルチスレッドを利用しているという仕組みです。通常のPythonのプログラムを各場合には確かにこの方法でいいのですが、実はFastAPIは既に内部で独自のスレッドプールを持っていて、async def
されていないハンドラー関数はこのスレッドプール内で実行されます。したがって、非同期ライブラリが存在しない場合には無理に非同期関数にする必要はなく、FastAPIの標準のマルチスレッド処理に任せることが公式ドキュメントでも推奨されています。詳細についてはこちらに書かれていますので興味がある読者はこちらもチェックしてみてください。
実践編:画像への自動タグ付けサービスを実装してみる
本記事の総仕上げとしていくらか実践的なWebアプリを作ってみます。ここでは画像を自動でタグをつけて保存し、後でタグで検索して画像をダウンロードできるようなWebサービスを模したWebAPIを開発してみます。実装する機能は以下の3つです。
POST /
:- JPEGの画像ファイルを受け取る
- 画像ファイルから自動でタグを抽出する
- 画像ファイルをストレージに保存する
- 画像ファイルの元のファイル名、保存先、タグをDBに保存する
- 保存した画像のIDと検出されたタグのリストを返す
GET /search?tag=<tag>
- 指定されたtagを持っている画像の一覧をDBから取得し、IDのリストを返す
GET /{id}
- 指定されたIDの画像ファイルのファイル名と画像データを返す
この章では新たに以下のライブラリを使用しますのでPDMやpipでインストールしてください。
pillow
onnx
onnxruntime
全体のディレクトリ構成は以下のようにします。
webapp/ __init__.py chapter4/ data/ labels.txt yolov3-12-int8.onnx __init__.py model.py schema.py storage.py tagging.py
タグ抽出器の実装
ここでは物体検出技術を利用して画像に写っているものを探し、そのカテゴリをタグとします。物体検出にはYOLOv3という機械学習のモデルを使用します。また、YOLOv3のモデルの訓練自体は行わず、公開されている訓練済みのモデルを使用します。YOLOv3はニューラルネットワークのモデルなのでTensorFlowやPyTorch、DarkNetなど様々なフレームワークで実装されています。近年ではONNXというニューラルネットワークのモデルとパラメータをまとめたフォーマットが標準化され、これを利用することで対応しているフレームワークであれば統一的に推論を行うことができます。特にONNXの推論に特化したランタイムとしてMicrosoft社が開発しているOSSのONNX Runtimeを利用します。
タグ抽出の全体の流れは以下の通りです。
- JPEGの画像データを読み取り、RGBにデコード
- デコードした画像を既定のサイズにリサイズするなどの前処理を行う
- YOLOv3のONNXモデルをONNX Runtimeにロードする
- ロードしたモデルに前処理した画像データを入力し、検出されたカテゴリのリストを取得する
- カテゴリは数値になっているので事前に用意されている文字列との対応表から文字列に変換する
この流れを実装したのが以下のtagging.py
です。
from pathlib import Path import numpy as np from onnxruntime import InferenceSession from PIL import Image def letterbox_image(image: Image.Image, size: tuple[int, int]): """resize image with unchanged aspect ratio using padding""" iw, ih = image.size w, h = size scale = min(w / iw, h / ih) nw = int(iw * scale) nh = int(ih * scale) image = image.resize((nw, nh), Image.Resampling.BICUBIC) new_image = Image.new("RGB", size, (128, 128, 128)) new_image.paste(image, ((w - nw) // 2, (h - nh) // 2)) return new_image def preprocess(image: Image.Image): model_image_size = (416, 416) boxed_image = letterbox_image(image, model_image_size) image_data = np.array(boxed_image, dtype="float32") image_data /= 255.0 image_data = np.transpose(image_data, [2, 0, 1]) image_data = np.expand_dims(image_data, 0) image_size = np.array([image.size[1], image.size[0]], dtype=np.float32).reshape( 1, 2 ) return image_data, image_size def get_prediction( session: InferenceSession, image_data: np.ndarray, image_size: np.ndarray ): inname = [input.name for input in session.get_inputs()] outname = [output.name for output in session.get_outputs()] input = {inname[0]: image_data, inname[1]: image_size} boxes, scores, indices = session.run(outname, input) out_boxes, out_scores, out_classes = [], [], [] for idx_ in indices: out_classes.append(idx_[1]) out_scores.append(scores[tuple(idx_)]) idx_1 = (idx_[0], idx_[2]) out_boxes.append(boxes[idx_1]) return out_boxes, out_scores, out_classes def extract_tag(image: Image.Image, session: InferenceSession, labels: list[str]): image_data, image_size = preprocess(image) _, _, classes = get_prediction(session, image_data, image_size) return frozenset(labels[i] for i in classes) class YOLOv3Tagger: def __init__(self, model_path: Path, label_path: Path) -> None: self._session = InferenceSession(str(model_path)) self._labels = label_path.open("r").read().strip().splitlines() def __call__(self, image: Image.Image): return extract_tag(image, self._session, self._labels)
なお、YOLOv3のモデルファイルはONNXのGitHubのリポジトリからダウンロードできます。本記事ではこの中で最も軽いYOLOv3-12-int8というモデルを使用しています。また、同リポジトリには必要な前処理のコードも書かれており、上記のコードはそれをベースにしています。出力されるカテゴリはCOCOという物体検出のデータセットで使用される80種類の値になっており、その定義はchapter4/data/labels.txt
にまとめてあります。
ストレージの実装
まずは画像ファイルを保存するストレージを実装していきます。実際のサービスではAWSのS3などのオブジェクトストレージサービスを使用するべきですが、ここでは説明を簡略化するためにシンプルにローカルのディスクに保存する実装にします。以下がstorage.py
の内容です。
import uuid from pathlib import Path class FileStorage: def __init__(self, path: Path) -> None: self._path = path def save(self, jpg_data: bytes): file_name = uuid.uuid4().hex + ".jpg" with self._path.joinpath(file_name).open("wb") as f: f.write(jpg_data) return file_name def get(self, file_name: str): file_path = self._path.joinpath(file_name) if not file_path.exists(): return None return file_path
クラスは保存先のディレクトリを保持します。save
メソッドではファイルのデータを受け取ったらユニークな名前をつけて保存し、get
メソッドでは保存した名前を受け取って実際に保存されているパスを返します。
DBとモデル
続いてRDBMSに以下の2つのテーブルを作成します。
create table images ( id serial primary key, name varchar(64), saved_name varchar(256) ); create table tags ( id serial primary key, image_id int references images(id), tag varchar(32) );
images
テーブルには元の名前name
とストレージに保存されたユニークな名前のsaved_name
が書かれ、tags
には画像のIDごとにどのタグがついているかが保存されます。
これを利用するためのSQLAlchemyのモデルを以下のようにmodel.py
に書きます。
from sqlalchemy import Column, Integer, String from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() class Image(Base): __tablename__ = "images" id = Column(Integer, primary_key=True, index=True) name = Column(String(64), nullable=False) saved_name = Column(String(256), nullable=False) class Tag(Base): __tablename__ = "tags" id = Column(Integer, primary_key=True, index=True) image_id = Column(Integer, nullable=False) tag = Column(String(32), nullable=False)
スキーマの定義
リクエスト/レスポンスに使用するPydanticのスキーマをschema.py
に作ります。リクエストはいずれもシンプルな値自体であり、GET /{id}
は画像データ自体を返すのでPOST /
とGET /searcn
のレスポンスと汎用のエラーレスポンスだけです。
from typing import TypeAlias import pydantic class PostImageOut(pydantic.BaseModel): id: int tags: frozenset[str] class ImageInfo(pydantic.BaseModel): id: int name: str class Config: orm_mode = True SearchImageOut: TypeAlias = list[ImageInfo] class ErrorMessageOut(pydantic.BaseModel): detail: str
ハンドラー関数の作成
route.py
に上で準備したパーツを合わせて実際の処理を組み立てていきます。
import io import PIL.Image from fastapi import APIRouter, UploadFile, status from fastapi.exceptions import HTTPException from fastapi.responses import FileResponse from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from . import model, schema from .storage import FileStorage from .tagging import YOLOv3Tagger def save_image( jpg_data: bytes, file_name: str, tagger: YOLOv3Tagger, storage: FileStorage, session: Session, ): raw_data = PIL.Image.open(io.BytesIO(jpg_data)) tags = tagger(raw_data) saved_name = storage.save(jpg_data) image_model = model.Image(name=file_name, saved_name=saved_name) session.add(image_model) session.commit() tag_models = [model.Tag(image_id=image_model.id, tag=tag) for tag in tags] session.add_all(tag_models) session.commit() return image_model.id, tags def search_images(tag: str, session: Session): stmt = ( select(model.Image) .join(model.Tag, model.Image.id == model.Tag.image_id) .where(model.Tag.tag == tag) ) cursor = session.execute(stmt).scalars() return cursor.fetchall() def get_image(id: int, storage: FileStorage, session: Session): image_model = session.get(model.Image, id) if image_model is None: return None fp = storage.get(image_model.saved_name) # type: ignore if fp is None: raise FileNotFoundError() return fp, image_model.name def create_router(tagger: YOLOv3Tagger, storage: FileStorage, engine: Engine): router = APIRouter() @router.post("/", response_model=schema.PostImageOut) def post(file: UploadFile): file_name = file.filename jpg_data = file.file.read() with Session(bind=engine) as session: image_id, tags = save_image(jpg_data, file_name, tagger, storage, session) return {"id": image_id, "tags": tags} @router.get("/search", response_model=schema.SearchImageOut) def search(tag: str): with Session(bind=engine) as session: return search_images(tag, session) @router.get( "/{id}", response_class=FileResponse, responses={ status.HTTP_200_OK: {"content": {"image/jpg": {}}}, status.HTTP_404_NOT_FOUND: {"model": schema.ErrorMessageOut}, status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": schema.ErrorMessageOut}, }, ) def get(id: int): with Session(bind=engine) as session: try: res = get_image(id, storage, session) if res is None: raise HTTPException( status.HTTP_404_NOT_FOUND, f"image of {id=} not found" ) (image_path, image_name) = res return FileResponse( image_path, media_type="image/jpg", filename=image_name ) except FileNotFoundError: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, "speficied image doesn't exist in the storage", ) return router
save_image
関数はJPEGの画像データ、ファイル名、タグ抽出器、ストレージ、そしてSQLAlchemyのセッションを受け取り、以下の作業をします。
- 画像データをストレージに保存し、
images
テーブルにレコードを1行追加 - 画像からタグを抽出し、
tags
テーブルに抽出できた数だけレコードを追加 - 保存した画像のIDと抽出したタグのリストをdictでまとめて戻す
そしてこのsave_image
をPOST /
に対応するハンドラー関数内で呼び出しています。save_image
の戻り値はdictですが、response_model
を利用してこれをschema.PostImageOut
にマッピングしています。
search_images
関数はタグの文字列とSQLAlchemyのセッションを受け取り、SQLのJOINクエリを利用し、そのタグを持っているimages
のレコードを全て取得しています。そしてrouter
に登録するsearch
関数ではresponse_model
をschema.SearchImageOut
に指定して自動的にSQLAlchemyのオブジェクトをPydanticに変換しています。
get_image
は画像ID、ストレージ、SQLAlchemyのセッションを受け取り、指定したIDの画像を探します。そして、発見できた場合、その実際の保存パスと本来のファイル名を返します。そしてrouter
に登録するget
関数はFileResponse
というレスポンスを表すクラスのインスタンスを上記のパスとファイル名、そしてファイルタイプのimage/jpg
を指定して戻します。
FastAPIはデフォルトではJSONResponse
というその名の通りJSONを表すレスポンスのクラスを使用していますが、FileResponse
は指定されたパスのファイルを自動で読み取るだけでなく、filename
で指定したファイル名をContent-Disposition
ヘッダーにつけてくれます。こうすることでxhやwgetなどのクライアントは自動的に本来のファイル名で保存してくれるわけです。
アプリケーションの起動
__init__.py
に最後の仕上げを書きます。ここまで読んでくださった方であればもう説明は不要だと思います。
from pathlib import Path from fastapi import FastAPI from sqlalchemy import create_engine from .route import create_router from .storage import FileStorage from .tagging import YOLOv3Tagger storage = FileStorage(Path.home().joinpath(".image_storage")) tagger_data_dir = Path(__file__).parent.joinpath("data") tagger = YOLOv3Tagger( tagger_data_dir.joinpath("yolov3-12-int8.onnx"), tagger_data_dir.joinpath("labels.txt"), ) DB_URI = "postgresql+pg8000://<user>:<password>@<host>:<port>/<dbname>" engine = create_engine(DB_URI) app = FastAPI() router = create_router(tagger, storage, engine) app.include_router(router, prefix="")
以下のコマンドで起動します。
pdm run uvicorn webapp.chapter4:app --no-server-header --no-date-header
そして画像を2つ送信してみます。いずれも著作権フリーの画像です。
$ xh post -f :8000/ [email protected] HTTP/1.1 200 OK Content-Length: 23 Content-Type: application/json { "id": 1, "tags": [ "cat" ] } $ xh post -f :8000/ [email protected] HTTP/1.1 200 OK Content-Length: 45 Content-Type: application/json { "id": 2, "tags": [ "diningtable", "knife", "cup" ] }
一つ目の猫の画像からはcat
のタグが抽出され、2つ目のコーヒーカップとノートとペンの画像からは3つのタグが抽出されました。このうち、knife
は誤検出のようです。次にcat
を指定して検索を試します。
HTTP/1.1 200 OK Content-Length: 29 Content-Type: application/json [ { "id": 1, "name": "cat01.jpg" } ]
id=1のcat01.jpgが見つかりました。レスポンスは配列になっているので他にもcat
のタグを持つ画像があればそれらもここに現れます。最後にこのid=1の画像をダウンロードします。
$ xh -d ":8000/1" HTTP/1.1 200 OK Content-Disposition: attachment; filename="cat01.jpg" Content-Length: 243102 Content-Type: image/jpg Etag: b426965e853ed8dbe017986e54f7bcab Last-Modified: Sat, 17 Dec 2022 08:01:42 GMT Downloading 237.40KiB to "cat01.jpg" Done. 237.40KiB in 0.00149s (155.93MiB/s)
このようにContent-Type
がimage/jpg
になっており、Content-Disposition
ヘッダーにファイル名が正しく入っているので、xhはcat01.jpg
という名前で画像ファイルを保存することに成功しました。
まとめ
本記事ではFastAPIを利用したWebAPIの開発について説明しました。FastAPIは開発体験を重視した設計になっており、型情報を活用して安全性が高いWebAPIを非常に効率よく開発できますす。また、機械学習分野におけるPythonの豊富なエコシステムとの連携が容易であり、機械学習のモデルを利用したサービスを開発する際には最も有力な候補になると考えられます。読者の皆様が本記事を通してFastAPIに興味を持ち、ご自身の業務に活かせることができたら幸いです。
杜 世橋(Du Shiqiao) GitHub: lucidfrontier45
編集:はてな編集部