Get per-tag confidence scores with the WD14 Tagger. The script below downloads a WD14 tagger ONNX model and its selected_tags.csv from Hugging Face, runs it over a directory of images, and writes one CSV per image listing each tag with its confidence.
import argparse
import csv
import glob
import os

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
# from wd14 tagger
IMAGE_SIZE = 448
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-vit-tagger-v3"

def preprocess_image(image):
    image = np.array(image)
    image = image[:, :, ::-1]  # RGB->BGR

    # pad to square
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
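    # pick the interpolation based on whether we are shrinking or enlarging to IMAGE_SIZE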
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    image = image.astype(np.float32)
    return image

def main(args):
    print("Loading wd14 tagger from Hugging Face")
    repo_id = args.repo_id
    onnx_path = hf_hub_download(repo_id, "model.onnx")
    csv_path = hf_hub_download(repo_id, "selected_tags.csv")

    print("Running wd14 tagger with onnx")
    print(f"loading onnx model: {onnx_path}")

    model = onnx.load(onnx_path)
    input_name = model.graph.input[0].name
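    # the first input dimension is the batch axis: fixed-batch models expose a positive dim_value,
    # while dynamic-batch models use a symbolic dim_param (e.g. "N") instead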
    try:
        batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
    except Exception:
        batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param

    if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
        # some rebatched models use 'N' as a dynamic batch axis
        print(f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}")
        args.batch_size = batch_size

    del model
if "OpenVINOExecutionProvider" in ort.get_available_providers(): | |
# requires provider options for gpu support | |
# fp16 causes nonsense outputs | |
ort_sess = ort.InferenceSession( | |
onnx_path, | |
providers=(["OpenVINOExecutionProvider"]), | |
provider_options=[{"device_type": "GPU_FP32"}], | |
) | |
else: | |
ort_sess = ort.InferenceSession( | |
onnx_path, | |
providers=( | |
["CUDAExecutionProvider"] | |
if "CUDAExecutionProvider" in ort.get_available_providers() | |
else ( | |
["ROCMExecutionProvider"] | |
if "ROCMExecutionProvider" in ort.get_available_providers() | |
else ["CPUExecutionProvider"] | |
) | |
), | |
) | |
    # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
    # read the csv by hand, since we don't want to add another dependency
    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        lines = list(reader)
        header = lines[0]  # tag_id,name,category,count
        rows = lines[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
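    # selected_tags.csv category codes: 9 = rating, 0 = general, 4 = character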
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]

    # load images
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg"))
    image_paths += glob.glob(os.path.join(args.train_data_dir, "*.jpeg"))
    image_paths += glob.glob(os.path.join(args.train_data_dir, "*.png"))
    image_paths += glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    os.makedirs(args.output_dir, exist_ok=True)

    def run_batch(path_imgs):
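        # the model input is an NHWC float32 batch: each image is a (448, 448, 3) BGR array
        # with raw 0-255 values, exactly as produced by preprocess_image above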
        imgs = np.array([im for _, im in path_imgs])

        probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
        probs = probs[: len(path_imgs)]

        for (image_path, _), prob in zip(path_imgs, probs):
            tag_confidences = {}

            # rating tags
            for i in range(4):
                tag_confidences[rating_tags[i]] = prob[i]

            # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
            for i, p in enumerate(prob[4:]):
                if p >= args.thresh:
                    if i < len(general_tags):
                        tag_name = general_tags[i]
                    else:
                        tag_name = character_tags[i - len(general_tags)]
                    tag_confidences[tag_name] = p
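            # write one <image name>.csv into output_dir, with a (tag, confidence) row
            # for every rating tag and every tag above the threshold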
            caption_file = os.path.splitext(image_path)[0] + ".csv"
            caption_file = os.path.join(args.output_dir, os.path.basename(caption_file))
            with open(caption_file, "wt", encoding="utf-8") as f:
                writer = csv.writer(f, lineterminator="\n")
                writer.writerow(["tag", "confidence"])
                for tag, confidence in tag_confidences.items():
                    writer.writerow([tag, confidence])
    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        try:
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = preprocess_image(image)
        except Exception as e:
            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
            continue
        b_imgs.append((str(image_path), image))  # keep the path as a string next to the preprocessed image

        if len(b_imgs) >= args.batch_size:
            run_batch(b_imgs)
            b_imgs.clear()

    if len(b_imgs) > 0:
        run_batch(b_imgs)

    print("done!")

def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument(
        "--repo_id",
        type=str,
        default=DEFAULT_WD14_TAGGER_REPO,
        help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID, default: "
        + DEFAULT_WD14_TAGGER_REPO,
    )
    parser.add_argument("--batch_size", type=int, default=16, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="threshold of confidence to add a tag / タグを追加するか判定する閾値, default: 0.35",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="output directory for tag confidence csv files / タグの確信度のCSVファイルの出力ディレクトリ",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    main(args)
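
A minimal usage sketch, assuming the gist is saved locally as wd14_tag_confidence.py (the filename and the directory paths below are placeholders, not part of the gist): run the script over a folder of images, then read back one of the per-image confidence CSVs it writes.

# Hypothetical driver script; "wd14_tag_confidence.py", "./train_images" and "./tag_csv" are assumed names.
import csv
import subprocess

subprocess.run(
    [
        "python",
        "wd14_tag_confidence.py",  # save the gist above under this (assumed) filename
        "./train_images",          # positional train_data_dir
        "--output_dir", "./tag_csv",
        "--thresh", "0.35",
        "--batch_size", "4",
    ],
    check=True,
)

# Each input image produces <image name>.csv with a "tag,confidence" header row.
with open("./tag_csv/example.csv", encoding="utf-8") as f:
    reader = csv.reader(f)
    next(reader)  # skip the header
    for tag, confidence in reader:
        print(f"{tag}: {float(confidence):.3f}")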