Compares the decode results of the VAE and TAESD, Gradio version
# Got this working after a lot of back-and-forth with Claude 3 Opus
# Usage: python vae_vs_taesd_gradio.py --image_dir /path/to/image/directory
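# Assumed dependencies (not pinned in the original gist): torch, diffusers,
# gradio, pillow, and numpy, e.g. `pip install torch diffusers gradio pillow numpy`.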

import os
import argparse
import random

from PIL import Image
import torch
from diffusers import AutoencoderKL, AutoencoderTiny
import numpy as np
import gradio as gr

# Set up the command-line argument parser
parser = argparse.ArgumentParser(description="VAE and TAESD performance comparison")
parser.add_argument("--image_dir", type=str, required=True, help="Directory containing images")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the VAE and TAESD from Hugging Face
# vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
# taesd = AutoencoderKL.from_pretrained("Doggettx/sd-xlarge-taesd")
print("loading VAE...")
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
print("loading TAESD...")
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl")
vae.to(device)
taesd.to(device)
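
# Note: AutoencoderKL here is the full SDXL VAE, while madebyollin/taesdxl is
# TAESD for SDXL, a small distilled autoencoder that approximates the VAE and
# decodes much faster; the quiz below measures how visible the quality gap is.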

# Get the list of image files in the target directory
image_files = [f for f in os.listdir(args.image_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))]


def load_image(image_path):
    image = Image.open(image_path).convert("RGB")

    # If the image area exceeds 1024x1024, shrink it while keeping the aspect ratio
    width, height = image.size
    if width * height > 1024 * 1024:
        ratio = (1024 * 1024 / (width * height)) ** 0.5
        new_width = int(width * ratio)
        new_height = int(height * ratio)
        image = image.resize((new_width, new_height))

    # Center-crop so both width and height are divisible by 8
    # (both autoencoders downsample by a factor of 8)
    width, height = image.size
    crop_width = width // 8 * 8
    crop_height = height // 8 * 8
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = left + crop_width
    bottom = top + crop_height
    image = image.crop((left, top, right, bottom))
    return image


def encode_image(image):
    # Encode the image with the VAE into latents; pixels are mapped from
    # [0, 255] to the [-1, 1] range the VAE expects
    # image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
    image_tensor = torch.from_numpy(np.array(image)).float() / 127.5 - 1
    image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        latents = vae.encode(image_tensor).latent_dist.sample().squeeze(0)
    return latents
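
# Note: latent_dist.sample() returns raw (unscaled) latents, and diffusers'
# AutoencoderKL.decode() also expects unscaled latents, which is why only the
# TAESD path below multiplies by vae.config.scaling_factor.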


def decode_latents(latents, model):
    # Decode latents back to an image; the model output in [-1, 1] is mapped
    # to [0, 1] and then converted to an 8-bit PIL image
    with torch.no_grad():
        decoded_image = model.decode(latents.unsqueeze(0)).sample
    decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
    decoded_image = decoded_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    decoded_image = Image.fromarray((decoded_image * 255).astype(np.uint8))
    return decoded_image


def compare_vae_taesd(left_button=None, right_button=None):
    global current_index, correct_count, correct_side

    # Determine whether this call came from a button click (the initial
    # page load calls this function with no arguments)
    button_clicked = left_button is not None or right_button is not None
    if button_clicked:
        if (left_button and correct_side == "LEFT") or (right_button and correct_side == "RIGHT"):
            correct_count += 1
        current_index += 1

    if current_index < len(image_files):
        # Load the image, encode it with the VAE, and decode the latents
        # with both the VAE and TAESD
        image_path = os.path.join(args.image_dir, image_files[current_index])
        image = load_image(image_path)
        latents = encode_image(image)
        vae_decoded = decode_latents(latents, vae)
        # taesd_decoded = decode_latents(latents, taesd)
        taesd_decoded = decode_latents(latents * vae.config.scaling_factor, taesd)
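        # TAESD is trained on latents scaled by the VAE's scaling_factor
        # (0.13025 for the SDXL VAE), so the raw latents are scaled here
        # before being handed to taesd.decode().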

        # Randomly place the decoded images on the left or right
        if random.choice([True, False]):
            left_image = vae_decoded
            right_image = taesd_decoded
            correct_side = "LEFT"
        else:
            left_image = taesd_decoded
            right_image = vae_decoded
            correct_side = "RIGHT"
        return left_image, right_image, f"Which side is VAE? ({current_index+1}/{len(image_files)})"
    else:
        accuracy = correct_count / len(image_files)
        return None, None, f"Accuracy: {accuracy:.2f}"


# Show the first image on page load
def show_first_image():
    global current_index, correct_side
    current_index = 0
    left_image, right_image, question = compare_vae_taesd()
    return left_image, right_image, question


# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## VAE vs TAESD Comparison")
    with gr.Row():
        left_image = gr.Image()
        right_image = gr.Image()
    with gr.Row():
        left_button = gr.Button("LEFT")
        right_button = gr.Button("RIGHT")
    question_label = gr.Textbox(label="Question")

    # Passing both Button components as inputs would hand the handler both
    # labels, making it impossible to tell which button was clicked, so each
    # button is wired to a dedicated lambda instead
    left_button.click(lambda: compare_vae_taesd(left_button="LEFT"), outputs=[left_image, right_image, question_label])
    right_button.click(lambda: compare_vae_taesd(right_button="RIGHT"), outputs=[left_image, right_image, question_label])

# Quiz state
current_index = 0
correct_count = 0
correct_side = None

demo.load(show_first_image, outputs=[left_image, right_image, question_label])
demo.launch()
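
# To let others take the quiz, Gradio can expose a temporary public URL via
# demo.launch(share=True) (assuming a reasonably recent Gradio version).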