# Script that uses Dart v2 to create a prompt file for sd-scripts' gen_img.py
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rating tag: <|rating:sfw|>, <|rating:general|>, <|rating:sensitive|>, <|rating:nsfw|>, <|rating:questionable|>, <|rating:explicit|>
# Aspect ratio tag: <|aspect_ratio:ultra_wide|>, <|aspect_ratio:wide|>, <|aspect_ratio:square|>, <|aspect_ratio:tall|>, <|aspect_ratio:ultra_tall|>
# Length tag: <|length:very_short|>, <|length:short|>, <|length:medium|>, <|length:long|>, <|length:very_long|>
"""
prompt = (
    f"<|bos|>"
    f"<copyright>{copyright_tags_here}</copyright>"
    f"<character>{character_tags_here}</character>"
    f"<|rating:general|><|aspect_ratio:tall|><|length:long|>"
    f"<general>{general_tags_here}"
)
"""


def get_prompt(model, num_prompts, rating, aspect_ratio, length, first_tag):
    prompt = (
        f"<|bos|>"
        f"<copyright></copyright>"
        f"<character></character>"
        f"{rating}{aspect_ratio}{length}"
        f"<general>{first_tag}"
    )
    prompts = [prompt] * num_prompts

    inputs = tokenizer(prompts, return_tensors="pt").input_ids
    inputs = inputs.to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
        )

    # return ", ".join([tag for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True) if tag.strip() != ""])
    decoded = []
    for i in range(num_prompts):
        output = outputs[i].cpu()
        tags = tokenizer.batch_decode(output, skip_special_tokens=True)
        prompt = ", ".join([tag for tag in tags if tag.strip() != ""])
        decoded.append(prompt)
    return decoded


# Tags that are enumerated exhaustively
"""
1024 x 1024  1:1    Square
1152 x 896   9:7
896 x 1152   7:9
1216 x 832   19:13
832 x 1216   13:19
1344 x 768   7:4    Horizontal
768 x 1344   4:7    Vertical
1536 x 640   12:5   Horizontal
640 x 1536   5:12   Vertical
"""
DIMENSIONS = [(1024, 1024), (1152, 896), (896, 1152), (1216, 832), (832, 1216), (1344, 768), (768, 1344), (1536, 640), (640, 1536)]
ASPECT_RATIO_TAGS = [
    "<|aspect_ratio:square|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:ultra_wide|>",
    "<|aspect_ratio:ultra_tall|>",
]

RATING_MODIFIERS = ["safe", "sensitive"]  # , "nsfw", "explicit, nsfw"]
RATING_TAGS = ["<|rating:general|>", "<|rating:sensitive|>"]  # , "<|rating:questionable|>", "<|rating:explicit|>"]

FIRST_TAGS = [
    "no humans",
    "1girl", "2girls", "3girls", "4girls", "5girls", "6+girls",
    "1boy", "2boys", "3boys", "4boys", "5boys", "6+boys",
    "1other", "2others", "3others", "4others", "5others", "6+others",
]

# Tags that are selected randomly
LENGTH_TAGS = ["<|length:very_short|>", "<|length:short|>", "<|length:medium|>", "<|length:long|>", "<|length:very_long|>"]

"""
newest  2021 to 2024
recent  2018 to 2020
mid     2015 to 2017
early   2011 to 2014
oldest  2005 to 2010
"""
YEAR_MODIFIERS = [None, "newest", "recent", "mid"]  # , "early", "oldest"]

# randomly select 0 to 4 of these
QUALITY_MODIFIERS_AND_AESTHETIC = ["masterpiece", "best quality", "very aesthetic", "absurdres"]

# negative prompt
NEGATIVE_PROMPT = (
    "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, "
    + "oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"
)
# NEGATIVE_PROMPT = (
#     "nsfw, lowres, bad, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, "
#     + "oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, abstract"
# )

NUM_PROMPTS_PER_VARIATION = 8
BATCH_SIZE = 8  # would like this to be larger, but then every prompt in a batch gets the same length tag
assert NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS) % BATCH_SIZE == 0

MODEL_NAME = "p1atdev/dart-v2-base"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model.to("cuda")

# make prompts
PARTITION = "a"  # prefix for the output file
random.seed(42)

prompts = []
for rating_modifier, rating_tag in zip(RATING_MODIFIERS, RATING_TAGS):
    negative_prompt = NEGATIVE_PROMPT
    if "nsfw" in rating_modifier:
        negative_prompt = negative_prompt.replace("nsfw, ", "")

    for dimension, aspect_ratio_tag in zip(DIMENSIONS, ASPECT_RATIO_TAGS):
        for first_tag in FIRST_TAGS:
            print(f"rating: {rating_modifier}, aspect ratio: {dimension}, first tag: {first_tag}")

            # year_modifier is not an input to Dart v2, so generate all prompts for this variation at once to get a larger batch size
            dart_prompts = []
            for i in range(0, NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS), BATCH_SIZE):
                # every prompt within a batch gets the same length tag; would be nice to fix, but it seems difficult
                length = random.choice(LENGTH_TAGS)
                dart_prompts += get_prompt(model, BATCH_SIZE, rating_tag, aspect_ratio_tag, length, first_tag)

            num_prompts_for_each_year_modifier = NUM_PROMPTS_PER_VARIATION
            for j, year_modifier in enumerate(YEAR_MODIFIERS):
                for prompt in dart_prompts[j * num_prompts_for_each_year_modifier : (j + 1) * num_prompts_for_each_year_modifier]:
                    # escape `(` and `)`, like "star (symbol)" -> "star \(symbol\)"
                    prompt = prompt.replace("(", "\\(").replace(")", "\\)")

                    # select quality modifiers and aesthetic
                    quality_modifiers = random.sample(QUALITY_MODIFIERS_AND_AESTHETIC, random.randint(0, 4))
                    quality_modifiers = ", ".join(quality_modifiers)

                    # combine quality modifiers and aesthetic, year modifier and rating modifier
                    qm = f"{quality_modifiers}, " if quality_modifiers else ""
                    ym = f", {year_modifier}" if year_modifier else ""

                    # build final prompt
                    image_index = len(prompts)
                    width, height = dimension
                    rm_filename = rating_modifier.replace(", ", "_")  # "nsfw, explicit" -> "nsfw_explicit"
                    ym_filename = year_modifier if year_modifier else "none"
                    ft_filename = first_tag.replace("+", "")  # remove "+" from "6+girls" etc.
                    ft_filename = ft_filename.replace(" ", "")  # remove space from "no humans" etc.
                    image_filename = (
                        f"{PARTITION}{image_index:08d}_{rm_filename}_{width:04d}x{height:04d}_{ym_filename}_{ft_filename}.webp"
                    )
                    seed = random.randint(0, 2**32 - 1)
                    final_prompt = f"{qm}{prompt}, {rating_modifier}{ym} --n {negative_prompt} --w {width} --h {height} --ow {width} --oh {height} --d {seed} --f {image_filename}"
                    prompts.append(final_prompt)
                    # break  # test
            # break
        # break

# output to a file
with open(f"prompts_{PARTITION}.txt", "w") as f:
    f.write("\n".join(prompts))

print(f"Done. {len(prompts)} prompts are written to prompts_{PARTITION}.txt.")
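
# Rough sketch of what one line in prompts_a.txt looks like. The general tags after
# "1girl" and the seed/index values are placeholders; the real tags are whatever
# Dart v2 generates for that variation:
#
#   masterpiece, best quality, 1girl, solo, smile, outdoors, safe, newest --n nsfw, lowres, (bad), text, ... --w 832 --h 1216 --ow 832 --oh 1216 --d 1234567890 --f a00000042_safe_0832x1216_newest_1girl.webp
#
# The file is then fed to sd-scripts' gen_img.py via its prompt-file option (assumed
# here to be `--from_file prompts_a.txt`), which reads one prompt per line and applies
# the per-line options built above: --n negative prompt, --w/--h generation size,
# --ow/--oh (presumably original width/height conditioning), --d seed, --f output filename.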