自動でtokenをrefreshして再送するRustの非同期HTTP Client

半年ほど前から、BlueskyのAT ProtocolのRust版ライブラリを作っている。 memo.sugyan.com

github.com

その中で最近実装した機能の話。

API Agent

本家の atproto (TypeScript実装)に AtpAgent というものがある。

この AtpAgent は、(Blueskyだけに限らない) AT Protocol のための汎用的なエージェントとして提供されている。その機能の一つとしてtokenの管理機能がある。

AT Protocolの認証

少なくとも 2023/11 時点では HTTP Bearer auth でJWTを送信することで認証を行う方式となっている。

com.atproto.server.createSession でログイン成功すると accessJwt と refreshJwt などが含まれる認証情報が返ってくるので、その accessJwt をBearer tokenに使って各エンポイントにアクセスする。

また、 com.atproto.server.refreshSession というエンドポイントがあり、ここを refreshJwt をtokenに使って叩くことでtokenを更新することができる。

tokenの管理と自動更新機構

で、 @atproto/api で提供されている AtpAgent はそれらのtokenの管理を自動で行う機能を持っている。

export class AtpAgent {
  ...
  session?: AtpSessionData

  /**
   * Internal fetch handler which adds access-token management
   */
  private async _fetch(...): Promise<AtpAgentFetchHandlerResponse> {
    ...
    // wait for any active session-refreshes to finish
    await this._refreshSessionPromise
    // send the request
    let res = await AtpAgent.fetch(...)
    // handle session-refreshes as needed
    if (isErrorResponse(res, ['ExpiredToken']) && this.session?.refreshJwt) {
      // attempt refresh
      await this._refreshSession()
      // resend the request with the new access token
      res = await AtpAgent.fetch(...)
    }
    return res
  }
}

まず現時点でもっているsession情報を元にリクエスト処理を試み、そのレスポンスが expired のエラーだったときのみそれを検出してtokenをrefreshして同じ内容のリクエストを再送する。最初のリクエストが成功していた場合はそのままそのレスポンスを返す。 という動き。

ATriumでの実装

で、これと同様の仕組みを持つ AtpAgent をATriumでも実装しようと考えた。

XrpcClient trait

ATriumでは、AT ProtocolのXRPCリクエストを送るための XrpcClient というtraitを定義している。

#[async_trait]
pub trait HttpClient {
    async fn send_http(
        &self,
        request: Request<Vec<u8>>,
    ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>;
}

pub type XrpcResult<O, E> = Result<OutputDataOrBytes<O>, self::Error<E>>;

#[async_trait]
pub trait XrpcClient: HttpClient {
    fn base_uri(&self) -> String;
    #[allow(unused_variables)]
    async fn auth(&self, is_refresh: bool) -> Option<String> {
        None
    }
    async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
    where
        P: Serialize + Send + Sync,
        I: Serialize + Send + Sync,
        O: DeserializeOwned + Send + Sync,
        E: DeserializeOwned + Send + Sync,
    {
        ...
    }
}

XRPC は規約に沿ったHTTPリクエスト/レスポンスでしかないので、内部としてはHTTPの処理をすることになる。 が、RustではHTTPリクエスト/レスポンスを処理するための標準ライブラリのようなものはなく、特に非同期の場合は reqwest や Isahc 、Surf のような3rd partyのライブラリが多く使われている、と思う。

ATriumではこれらのライブラリをバックエンドとして開発者が選択できるよう、HTTP部分を抽象化する HttpClient というtraitを定義し、それを継承した XrpcClient を実装したインスタンスを内部に持つ AtpServiceClient が各XRPCを処理する形にしている。

XrpcClient::send_xrpc() はデフォルト実装を持っており、HttpClient::send_http() さえ実装されていれば、あとはそれを使ってリクエストに使う入力のserializeや返ってきたレスポンスJSONのdeserializeなどの処理を行うようになっている。

session管理するwrapper

認証については XrpcClient のメソッドとして async fn auth(&self, is_refresh: bool) -> Option<String> を定義しているだけなので、ここでどのようなtokenを返すかはtrait実装者に任される。

インメモリで管理する場合、内部で Arc<RwLock<Option<Session>>> のように持っておいて管理することで、マルチスレッドで共有されていても安全に扱えるようになる。

(参考: https://github.com/usagi/rust-memory-container-cs)

なので、内部で XrpcClient を実装したものを持つWrapperを作り、主な処理はそれに移譲して auth() だけを実装することで、保持しているsession情報からtokenを返す XrpcClient 実装を作ることができる。

use std::sync::Arc;
use tokio::sync::RwLock;

struct Wrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    inner: T,
    session: Arc<RwLock<Option<Session>>>,
}

#[async_trait]
impl<T> HttpClient for Wrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    async fn send_http(
        &self,
        request: Request<Vec<u8>>,
    ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.inner.send_http(request).await
    }
}

#[async_trait]
impl<T> XrpcClient for Wrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    fn base_uri(&self) -> String {
        self.inner.base_uri()
    }
    async fn auth(&self, is_refresh: bool) -> Option<String> {
        self.session.read().unwrap().as_ref().map(|session| {
            if is_refresh {
                session.refresh_jwt.clone()
            } else {
                session.access_jwt.clone()
            }
        })
    }
}

ここでは非同期の RwLock として tokio::sync を使っている。

これで、AtpAgent も同じsessionを共有し、例えばログイン成功時に取得したsession情報を書き込む、といったことをすれば、その後のリクエストでその値が使われるようになる。

struct AtpAgent<T>
where
    T: XrpcClient + Send + Sync,
{
    api: Service<Wrapper<T>>,
    session: Arc<RwLock<Option<Session>>>,
}

impl<T> AtpAgent<T>
where
    T: XrpcClient + Send + Sync,
{
    fn new(xrpc: T) -> Self {
        let session = Arc::new(RwLock::new(None));
        let api = Service::new(Wrapper {
            inner: xrpc,
            session: Arc::clone(&session)
        });
        Self { api, session }
    }
    async fn login(&self, ...) {
        // login process
        let session: Session = ...;
        self.session.write().await.replace(session);
    }
}

tokenの自動更新 (失敗例)

で、これを使ってtokenの自動更新を行うためには、XrpcClient::send_xrpc() をオーバーライドする形で実装すれば良さそう。

impl Wrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    async fn refresh_session() {
        // refresh process
        let session: Session = ...;
        self.session.write().await.replace(session);
    }
    fn is_expired<O, E>(result: &XrpcResult<O, E>) -> bool
    where
        O: DeserializeOwned + Send + Sync,
        E: DeserializeOwned + Send + Sync,
    {
        if let Err(Error::XrpcResponse(response)) = &result {
            if let Some(XrpcErrorKind::Undefined(body)) = &response.error {
                if let Some("ExpiredToken") = &body.error.as_deref() {
                    return true;
                }
            }
        }
        false
    }
}

impl<T> XrpcClient for Wrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    fn base_uri(&self) -> String {
        ...
    }
    async fn auth(&self, is_refresh: bool) -> Option<String> {
        ...
    }
    async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
    where
        P: Serialize + Send + Sync,
        I: Serialize + Send + Sync,
        O: DeserializeOwned + Send + Sync,
        E: DeserializeOwned + Send + Sync,
    {
        let result = self.inner.send_xrpc(request).await;
        // handle session-refreshes as needed
        if Self::is_expired(&result) {
            self.refresh_session().await;
            self.inner.send_xrpc(request).await
        } else {
            result
        }
    }
}

self.inner に send_xrpc() を移譲し、その結果を判定して Expired のエラーであった場合のみ self.refresh_session() を呼び出してtokenを更新し、再度同じ send_xrpc() を呼び出す、という形。

self.refresh_session() が成功して内部のsessionが書き変わっていれば、再度同じリクエストを送ったときにはtokenが更新されているので成功する。

…と思ったが、実際にはこれは思った通りには動かない。 Rustのtrait実装はインスタンスのメソッドをオーバーライドしているわけではなく、あくまでtraitのメソッドを実装しているだけなので、 self.inner に移譲したメソッドは self とは別のインスタンスとして扱われる。

つまり self.inner.send_xrpc() のデフォルト実装中で呼ばれる self.auth() はあくまで self.inner に実装されている auth() であり、Wrapper に実装されている auth() は呼ばれない。なので、いくら Wrapper の内部でsessionを更新しても self.inner には関係がない、ということになる。

2重のwrapperで解決

これをどうやって解決するかしばらく悩んだが、結局wrapperをもう一つ作って2重にsessionを共有することで想定した動きをするようになった。

主な処理を移譲された inner が同じsession情報を持って使ってくれていれば問題ないので、元のwrapperを RefreshWrapper として、同じように XrpcClient を実装する SessionWrapper を作り、 auth() で self.session を参照する機能だけをそちらに持たせるようにする。

struct RefreshWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    inner: T,
    session: Arc<RwLock<Option<Session>>>,
}

#[async_trait]
impl<T> HttpClient for RefreshWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    ... // (use inner)
}

#[async_trait]
impl<T> XrpcClient for RefreshWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
    where
        P: Serialize + Send + Sync,
        I: Serialize + Send + Sync,
        O: DeserializeOwned + Send + Sync,
        E: DeserializeOwned + Send + Sync,
    {
        let result = self.inner.send_xrpc(request).await;
        // handle session-refreshes as needed
        if Self::is_expired(&result) {
            self.refresh_session().await;
            self.inner.send_xrpc(request).await
        } else {
            result
        }
    }
}

struct SessionWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    inner: T,
    session: Arc<RwLock<Option<Session>>>,
}

#[async_trait]
impl<T> HttpClient for SessionWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    ... // (use inner)
}

#[async_trait]
impl<T> XrpcClient for SessionWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    ...

    async fn auth(&self, is_refresh: bool) -> Option<String> {
        self.session.read().await.as_ref().map(|session| {
            if is_refresh {
                session.refresh_jwt.clone()
            } else {
                session.access_jwt.clone()
            }
        })
    }
}

そして、 AtpAgent では XrpcClient を実装したものとして RefreshWrapper<SessionWrapper<T>> を使うようにする。両者間で session を共有するために Arc<RwLock<Option<Session>>> を渡している。

struct AtpAgent<T>
where
    T: XrpcClient + Send + Sync,
{
    api: Service<RefreshWrapper<SessionWrapper<T>>>,
    session: Arc<RwLock<Option<Session>>>,
}

impl<T> AtpAgent<T>
where
    T: XrpcClient + Send + Sync,
{
    fn new(xrpc: T) -> Self {
        let session = Arc::new(RwLock::new(None));
        let api = Service::new(Arc::new(RefreshWrapper {
            inner: SessionWrapper {
                inner: xrpc,
                session: Arc::clone(&session),
            },
            session: Arc::clone(&session),
        }));
        Self { api, session }
    }
}

send_xrpc() 内でのexpire検出と更新・再送の処理は RefreshWrapper で行われるが、実際の送信は内部の SessionWrapper に移譲されることになる。 SessionWrapper では共有されている self.session を auth() 内で参照する動作をするので、 RefreshWrapper (や、それを使う AtpAgent 本体) で更新されたsession情報がそのまま使われることになる。

これでtokenの自動更新が実現された。

並行処理での同時更新の問題

もう一つ、TypeScript実装の AtpAgent が持っている機能として、複数の refreshSession() が同時に来ても1回しか実行されないように制御する、というものがある。

  /**
   * Internal helper to refresh sessions
   * - Wraps the actual implementation in a promise-guard to ensure only
   *   one refresh is attempted at a time.
   */
  private async _refreshSession() {
    if (this._refreshSessionPromise) {
      return this._refreshSessionPromise
    }
    this._refreshSessionPromise = this._refreshSessionInner()
    try {
      await this._refreshSessionPromise
    } finally {
      this._refreshSessionPromise = undefined
    }
  }

内部状態として Promise<void> | undefined を持っており、 _refreshSession() が呼ばれた時点でそれが undefined であれば Promise<void> をセットした上で実際にその処理を await し、既に Promise<void> がセットされている場合はそれを返す、という動き。

もしAgentが並行して複数のAPIをほぼ同時に叩いたときに、tokenがexpiredだった場合はその結果がほぼ同時に返ってくることになる。その時点ではtokenがrefreshされていないので、それぞれがほぼ同時に自動refresh処理を行うことになるが、実際には1回だけrefreshされていれば良いので、そのように制御するための仕組み、と考えられる。

そもそも同じrefresh tokenを使って複数回refreshのリクエストすると、既に使用されたrefresh tokenが無効になり2回目以降のrefreshはエラーになる可能性もあり(ここはサーバ実装次第と思われる、現時点でのblueskyでは特に何も起こらず新しいtokenが発行されるだけ、のようだ)、非同期で並行処理され得る環境ではこういった制御は必要になる。

…しかしRustでは同じように実装しようとしても上手くいかなかった。 async なinnerを呼んだ返り値を保持するとなると Pin<Box<impl Future<Output = Result<...>> + 'static>> のような形になり、しかしこれを Mutex とかで保持しようとしてもこれは Send ではないので Mutex に入れられない、など… 何度もコンパイルエラーに悩まされて挫折した。 Rustの非同期に詳しい人なら上手く実装できるんだろうか…

Notify による制御実装

ChatGPTと長々と議論し、結局 tokio::sync にある Mutex と Notify を使うことで同様の動作を実現させた。

use tokio::sync::{Mutex, Notify};

struct RefreshWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    inner: T,
    session: Arc<RwLock<Option<Session>>>,
    is_refreshing: Mutex<bool>,
    notify: Notify,
}

impl<T> RefreshWrapper<T>
where
    T: XrpcClient + Send + Sync,
{
    ...

    async fn refresh_session(&self) {
        {
            let mut is_refreshing = self.is_refreshing.lock().await;
            if *is_refreshing {
                drop(is_refreshing);
                return self.notify.notified().await;
            }
            *is_refreshing = true;
        }
        self.refresh_session_inner().await;
        *self.is_refreshing.lock().await = false;
        self.notify.notify_waiters();
    }
}

まず self.is_refreshing で、refresh実行中のものがあるか否かを保持する。これは Mutex で管理されるので、 最初にロックを取得して true に変更したものが完了してまた false に戻すまでは他の処理からの読み取り結果は必ず true になる。 そして、その true に変更できたものだけが、後続の実際の更新処理である self.refresh_session_inner() を実行する。

self.is_refreshing が true になっている間は、 self.refresh_session_inner() が終わったあとに呼ばれる self.notify.notify_waiters() によって完了を知らされるまで待機するだけ、という形になる。 こういった他のスレッドからの通知を待つための仕組みとして tokio::sync::Notify があるようだ。この場合は完了したことを知りたいだけなので Notify だが、処理結果も知りたい場合は oneshot などを使うと良いのかもしれない。

ライブラリとしては特定の非同期ランタイムに依存するようにはしたくないので tokio などを使うのは避けたかったが、標準ライブラリや futures などには同様の仕組みがないようだったので仕方なく tokio を使うことにした。実際のところ sync featureで使うこれらのものは特にランタイム依存は無いようで、 async-std など別の非同期ランタイムで実行しても問題なく動作するようではあった。

まとめ

ライブラリの依存は増えてしまったが、どうにか AtpAgent として実装したい機能は実現できた。Rustむずかしい…。

他にもっと良い方法をご存知の方がいればpull-requestなどいただけると嬉しいです。