Stimulator

機械学習とか好きな技術話とかエンジニア的な話とかを書く

Pure Rustな近似最近傍探索ライブラリhoraを用いた画像検索を実装する

f:id:vaaaaaanquish:20210810063410p:plain

- はじめに -

本記事は、近似最近傍探索(ANN: Approximate Nearest Neighbor)による画像検索をRustを用いて実装した際のメモである。

画像からの特徴量抽出にTensorFlow Rust bindings、ANNのインデックス管理にRustライブラリであるhoraを利用した。

RustとANNの現状および、実装について触れる。

 

 

- RustとANN -

Rustの機械学習関連クレート、事例をまとめたリポジトリがある。

github.com

この中でも、ANN関連のクレートは充実している。利用する場合は以下のようなクレートが候補になる。

* Enet4/faiss-rs
* lerouxrgd/ngt-rs
* rust-cv/hnsw
* hora-search/hora
* InstantDomain/instant-distance
* granne/granne
* qdrant/qdrant

Pythonでもしばしば利用されるfacebook researchのfaiss、Yahoo!のNGTのrust bindingsは強く候補に上がる。C++からGPUが触れる点から利用だけならfaissが活用しやすいだろう。

 
他にPure Rustで機能が充実しているクレートにhoraがある。
github.com

horaには、PythonJavascriptJavaのbindingsがあるだけでなく、Pure Rustである事でWebAssembly化などもサポートしている。
また、インデキシングアルゴリズムとして多く利用されているHNSWIndex以外にグラフベースのSatellite System Graph*1、直積量子化を行うProduct Quantization Inverted File*2が実装されており、開発が継続されている数少ないクレートである。
一部SIMDによる高速化が図られている(https://github.com/rust-lang/packed_simdによるもの)。

(horaの由来は「小さな恋の歌」とREADMEに書いてあるが、どういう経路で知られたのかよくわからない)

今回は、画像検索のwasm化を目指し、horaを利用する。
画像検索がwasm化する事で、API経由で行われていた画像検索の一部がエッジデバイス上で処理できる可能性などの幅が出る事を期待する。
例えば、ネット環境を扱えないや工場やサーバセンター、病院であったり、個人情報の観点でスマフォやカメラの外に出せない画像をその場で類似画像検索にかける事ができる可能性である。

 
画像特徴を抽出する部分でもwasm化を目指すため、wasmの利用実績が多いTensorFlowを利用する。

TensorFlowにはRust bindingsが存在する。
github.com

今回はこちらを利用してモデルを作成し、wasm化する。
他にもDNNのライブラリはいくつかあるが、開発が活発でないか、PyTorchのRust bindingsは現在中間層の出力を受け取る方法がないなど、機能的に難しい場合が多かった。

(実験時に作成したPyTorchのRust bindingsでpretrain modelのpredictを実行するdockerなども公開している https://github.com/vaaaaanquish/tch-rs-pretrain-example-docker

 

- pretrainモデルによる特徴量化 -

TensorFlow 2.xでのRustとPythonの相互運用に関する以下の記事を参考にした。

TensorFlow 2.xでのRustとPython

import tensorflow as tf
from keras.models import Model
from tensorflow.python.framework.convert_to_constants import \
    convert_variables_to_constants_v2

# pretrainモデルの読み込み
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')

# 中間層の出力を得るモデルにする
embedding_model = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)

# tf.functionに変換しpbファイルとしてgraphを保存できる状態にする
resnet = tf.TensorSpec(embedding_model.input_shape, tf.float32, name="resnet")
concrete_function = tf.function(lambda x: embedding_model(x)).get_concrete_function(resnet)
frozen_model = convert_variables_to_constants_v2(concrete_function)

# fileをdumpする
tf.io.write_graph(frozen_model.graph, '/app/model', "model.pb", as_text=False)

Rustのbindingsから読み込み、画像ファイルを特徴量に変換する。

// モデルファイルを読み込み、セッションを作る
let mut graph = Graph::new();
let mut proto = Vec::new();
File::open("model/model.pb")?.read_to_end(&mut proto)?;
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
let session = Session::new(&SessionOptions::new(), &graph)?;

// 入力画像を読み込み、リサイズしてTensorに変換する
let img = ImageReader::open("./img/example.jpeg")?.decode()?;
let resized_img = img.resize_exact(224 as u32, 224 as u32, FilterType::Lanczos3);
let img_vec: Vec<f32> = resized_img.to_rgb8().to_vec().iter().map(|x| *x as f32).collect();
let x = Tensor::new(&[1, 224, 224, 3]).with_values(&img_vec)?;

// DNNに入力する
let mut args = SessionRunArgs::new();
args.add_feed(&graph.operation_by_name_required("resnet")?, 0, &x);
let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0);
session.run(&mut args)?;

// check result
let output_tensor: Tensor<f32> = args.fetch(output)?;
let output_array: Vec<f32> = output_tensor.iter().map(|x| x.clone()).collect();
println!("{:?}", output_array);

出力として、特徴量vectorが得られる。

 

- 画像特徴のインデックスと検索 -

horaを利用して画像検索を行う。

// init index
let mut index = hora::index::hnsw_idx::HNSWIndex::<f32, usize>::new(2048, &hora::index::hnsw_params::HNSWParams::<f32>::default(),);

// 特定ディレクトリの画像ファイルをインデックス
let paths = fs::read_dir("img")?;
let mut file_map = HashMap::new();
for (i, path) in paths.into_iter().enumerate() {
    let file_path = path?.path();
    let path_str = file_path.to_str();
    if path_str.is_some() {
        file_map.insert(i, path_str.unwrap().to_string().clone());  // ファイル一覧を作成
        let emb_vec = emb.convert_from_img(path_str.unwrap())?;     // 画像特徴を得るメソッド
        index.add(emb_vec.as_slice(), i)?;                          // インデックス
    }
}
index.build(hora::core::metrics::Metric::Euclidean).unwrap();

// 画像をqueryとして検索
let query_image = &file_map[&100]
let emb_vec_target = emb.convert_from_img(&query_image.to_string())?;  // 画像特徴を得るメソッド
let result = index.search(emb_vec_target.as_slice(), 10);              // 特徴量をqueryとし検索
println!("neighbor images by query: {:?}", query_image);
for r in result {
    println!("{:?}", &file_map[&r]);
}

これらのコードは以下に公開している。

また、上記にはfood-101データセットを用いたインデキシングのサンプルが配置してあるため、今回はそちらを利用して検索の動作確認を行った。

www.tensorflow.org

 

- 検索結果 -

query画像をランダムに選択してTop5の画像を目視でチェックする。

f:id:vaaaaaanquish:20210810060405p:plain
餃子queryとTop5
f:id:vaaaaaanquish:20210810061939p:plain
ラーメンqueryとTop5

餃子は1つだけ間違えて寿司を引いてきているが概ね良さそう。

カテゴリを利用した精度測定などが考えられるが今回はここまで。

- おわりに -

Rustによる画像検索を実装し、動作を確認できた。

エッジデバイスやスマフォ上での画像検索が出来るようになってくると、インデックスファイルを小さくしても精度が保てるモデルの研究が出てきたりするかもなと妄想することができた。

コードは以下に公開した。
github.com

wasm化した上での画像検索は出来てはいるので次はそちらを書く。