Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created May 15, 2024 11:17
Show Gist options
  • Save kohya-ss/1711f17fe77def811fcaf82877b0bec2 to your computer and use it in GitHub Desktop.

Revisions

  1. kohya-ss created this gist May 15, 2024.
    193 changes: 193 additions & 0 deletions make_prompts_with_dartv2.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,193 @@
# Script that uses Dart v2 to build a prompt file for sd-scripts' gen_img.py.

import random
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Special tokens understood by the Dart v2 tokenizer:
# Rating tag: <|rating:sfw|>, <|rating:general|>, <|rating:sensitive|>, nsfw, <|rating:questionable|>, <|rating:explicit|>
# Aspect ratio tag: <|aspect_ratio:ultra_wide|>, <|aspect_ratio:wide|>, <|aspect_ratio:square|>, <|aspect_ratio:tall|>, <|aspect_ratio:ultra_tall|>
# Length tag: <|length:very_short|>, <|length:short|>, <|length:medium|>, <|length:long|>, <|length:very_long|>
# Example of the full conditioning-prompt layout (documentation only, not executed as code):
"""
prompt = (
f"<|bos|>"
f"<copyright>{copyright_tags_here}</copyright>"
f"<character>{character_tags_here}</character>"
f"<|rating:general|><|aspect_ratio:tall|><|length:long|>"
f"<general>{general_tags_here}"
)
"""


    def get_prompt(model, num_prompts, rating, aspect_ratio, length, first_tag):

    prompt = (
    f"<|bos|>" f"<copyright></copyright>" f"<character></character>" f"{rating}{aspect_ratio}{length}" f"<general>{first_tag}"
    )
    prompts = [prompt] * num_prompts
    inputs = tokenizer(prompts, return_tensors="pt").input_ids
    inputs = inputs.to("cuda")
    with torch.no_grad():
    outputs = model.generate(
    inputs,
    do_sample=True,
    temperature=1.0,
    top_p=1.0,
    top_k=100,
    max_new_tokens=128,
    num_beams=1,
    )
    # return ", ".join([tag for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True) if tag.strip() != ""])
    decoded = []
    for i in range(num_prompts):
    output = outputs[i].cpu()
    tags = tokenizer.batch_decode(output, skip_special_tokens=True)
    prompt = ", ".join([tag for tag in tags if tag.strip() != ""])
    decoded.append(prompt)
    return decoded


    # 網羅的に作るタグ類
    """
    1024 x 1024 1:1 Square
    1152 x 896 9:7
    896 x 1152 7:9
    1216 x 832 19:13
    832 x 1216 13:19
    1344 x 768 7:4 Horizontal
    768 x 1344 4:7 Vertical
    1536 x 640 12:5 Horizontal
    640 x 1536 5:12 Vertical
    """
    DIMENSIONS = [(1024, 1024), (1152, 896), (896, 1152), (1216, 832), (832, 1216), (1344, 768), (768, 1344), (1536, 640), (640, 1536)]
    ASPECT_RATIO_TAGS = [
    "<|aspect_ratio:square|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:ultra_wide|>",
    "<|aspect_ratio:ultra_tall|>",
    ]
    RATING_MODIFIERS = ["safe", "sensitive"] # , "nsfw", "explicit, nsfw"]
    RATING_TAGS = ["<|rating:general|>", "<|rating:sensitive|>"] # , "<|rating:questionable|>", "<|rating:explicit|>"]

    FIRST_TAGS = [
    "no humans",
    "1girl",
    "2girls",
    "3girls",
    "4girls",
    "5girls",
    "6+girls",
    "1boy",
    "2boys",
    "3boys",
    "4boys",
    "5boys",
    "6+boys",
    "1other",
    "2others",
    "3others",
    "4others",
    "5others",
    "6+others",
    ]


    # ランダムに選ぶタグ類
    LENGTH_TAGS = ["<|length:very_short|>", "<|length:short|>", "<|length:medium|>", "<|length:long|>", "<|length:very_long|>"]

    """
    newest 2021 to 2024
    recent 2018 to 2020
    mid 2015 to 2017
    early 2011 to 2014
    oldest 2005 to 2010
    """
    YEAR_MODIFIERS = [None, "newest", "recent", "mid"] # , "early", "oldest"]

    # ranomly select 0 to 4 of these
    QUALITY_MODIFIERS_AND_AESTHETIC = ["masterpiece", "best quality", "very aesthetic", "absurdres"]

    # negative prompt
    NEGATIVE_PROMPT = (
    "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, "
    + "oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"
    )
    # NEGATIVE_PROMPT = (
    # "nsfw, lowres, bad, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, "
    # + "oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, abstract"
    # )

    NUM_PROMPTS_PER_VARIATION = 8
    BATCH_SIZE = 8 # 大きくしたいが、バッチ内で length が同じになってしまう

    assert NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS) % BATCH_SIZE == 0

# Load the Dart v2 base model and its tokenizer from the Hugging Face Hub.
MODEL_NAME = "p1atdev/dart-v2-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# bfloat16 halves memory vs fp32; generation runs on GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model.to("cuda")

    # make prompts
    PARTITION = "a" # prefix for the output file
    random.seed(42)
    prompts = []
    for rating_modifier, rating_tag in zip(RATING_MODIFIERS, RATING_TAGS):
    negative_prompt = NEGATIVE_PROMPT
    if "nsfw" in rating_modifier:
    negative_prompt = negative_prompt.replace("nsfw, ", "")

    for dimension, aspect_ratio_tag in zip(DIMENSIONS, ASPECT_RATIO_TAGS):
    for first_tag in FIRST_TAGS:
    print(f"rating: {rating_modifier}, aspect ratio: {dimension}, first tag: {first_tag}")

    # year_modifier はDart v2の引数にならないので、ここでバッチを作ることでバッチサイズを稼ぐ
    dart_prompts = []
    for i in range(0, NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS), BATCH_SIZE):
    # ひとつのバッチの中で length が同じになってしまうのでどうにかしたいけど難しそう
    length = random.choice(LENGTH_TAGS)
    dart_prompts += get_prompt(model, BATCH_SIZE, rating_tag, aspect_ratio_tag, length, first_tag)

    num_prompts_for_each_year_modifier = NUM_PROMPTS_PER_VARIATION
    for j, year_modifier in enumerate(YEAR_MODIFIERS):

    for prompt in dart_prompts[j * num_prompts_for_each_year_modifier : (j + 1) * num_prompts_for_each_year_modifier]:
    # escape `(` and `)`, like "star (symbol)" -> "star \(symbol\)"
    prompt = prompt.replace("(", "\\(").replace(")", "\\)")

    # select quality modifiers and aesthetic
    quality_modifiers = random.sample(QUALITY_MODIFIERS_AND_AESTHETIC, random.randint(0, 4))
    quality_modifiers = ", ".join(quality_modifiers)

    # combine quality modifiers and aesthetic, year modifier and rating modifier
    qm = f"{quality_modifiers}, " if quality_modifiers else ""
    ym = f", {year_modifier}" if year_modifier else ""

    # build final prompt
    image_index = len(prompts)
    width, height = dimension

    rm_filename = rating_modifier.replace(", ", "_") # "nsfw, explicit" -> "nsfw_explicit"
    ym_filename = year_modifier if year_modifier else "none"
    ft_filename = first_tag.replace("+", "") # remove "+" from "6+girls" etc.
    ft_filename = ft_filename.replace(" ", "") # remove space from "no humans" etc.
    image_filename = (
    f"{PARTITION}{image_index:08d}_{rm_filename}_{width:04d}x{height:04d}_{ym_filename}_{ft_filename}.webp"
    )

    seed = random.randint(0, 2**32 - 1)

    final_prompt = f"{qm}{prompt}, {rating_modifier}{ym} --n {negative_prompt} --w {width} --h {height} --ow {width} --oh {height} --d {seed} --f {image_filename}"
    prompts.append(final_prompt)
    # break # test
    # break
    # break

    # output to a file
    with open(f"prompts_{PARTITION}.txt", "w") as f:
    f.write("\n".join(prompts))

    print(f"Done. {len(prompts)} prompts are written to prompts_{PARTITION}.txt.")