Merge pull request #2143 from JC-Array/deepdanbooru_pre_process

deepbooru tags for textual inversion preproccessing
This commit is contained in:
AUTOMATIC1111 2022-10-12 08:35:27 +03:00 committed by GitHub
commit 2e2d45b281
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 27 deletions

View file

@ -1,21 +1,75 @@
import os.path import os.path
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context import multiprocessing
import time
def get_deepbooru_tags(pil_image):
"""
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
"""
from modules import shared # prevents circular reference
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(pil_image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
tags = shared.deepbooru_process_return["value"]
release_process()
return tags
def _load_tf_and_return_tags(pil_image, threshold): def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
def create_deepbooru_process(threshold, alpha_sort):
"""
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
the tags.
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_manager = multiprocessing.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort))
shared.deepbooru_process.start()
def release_process():
"""
Stops the deepbooru process to return used memory
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_queue.put("QUIT")
shared.deepbooru_process.join()
shared.deepbooru_process_queue = None
shared.deepbooru_process = None
shared.deepbooru_process_return = None
shared.deepbooru_process_manager = None
def get_deepbooru_tags_model():
import deepdanbooru as dd import deepdanbooru as dd
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
this_folder = os.path.dirname(__file__) this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')): if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time # there is no point importing these every time
import zipfile import zipfile
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", load_file_from_url(
model_path) r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path) zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
@ -24,7 +78,13 @@ def _load_tf_and_return_tags(pil_image, threshold):
model = dd.project.load_model_from_project( model = dd.project.load_model_from_project(
model_path, compile_model=True model_path, compile_model=True
) )
return model, tags
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
width = model.input_shape[2] width = model.input_shape[2]
height = model.input_shape[1] height = model.input_shape[1]
image = np.array(pil_image) image = np.array(pil_image)
@ -46,28 +106,27 @@ def _load_tf_and_return_tags(pil_image, threshold):
for i, tag in enumerate(tags): for i, tag in enumerate(tags):
result_dict[tag] = y[i] result_dict[tag] = y[i]
result_tags_out = []
unsorted_tags_in_theshold = []
result_tags_print = [] result_tags_print = []
for tag in tags: for tag in tags:
if result_dict[tag] >= threshold: if result_dict[tag] >= threshold:
if tag.startswith("rating:"): if tag.startswith("rating:"):
continue continue
result_tags_out.append(tag) unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}') result_tags_print.append(f'{result_dict[tag]} {tag}')
# sort tags
result_tags_out = []
sort_ndx = 0
if alpha_sort:
sort_ndx = 1
# sort by reverse by likelihood and normal for alpha
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
result_tags_out.append(tag)
print('\n'.join(sorted(result_tags_print, reverse=True))) print('\n'.join(sorted(result_tags_print, reverse=True)))
return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
def subprocess_init_no_cuda():
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
def get_deepbooru_tags(pil_image, threshold=0.5):
context = get_context('spawn')
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
ret = f.result() # will rethrow any exceptions
return ret

View file

@ -249,15 +249,20 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { interrogate_option_dictionary = {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)")
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), }
}))
if cmd_opts.deepdanbooru:
interrogate_option_dictionary["interrogate_deepbooru_score_threshold"] = OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01})
interrogate_option_dictionary["deepbooru_sort_alpha"] = OptionInfo(True, "Interrogate: deepbooru sort alphabetically", gr.Checkbox)
options_templates.update(options_section(('interrogate', "Interrogate Options"), interrogate_option_dictionary))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),

View file

@ -3,11 +3,14 @@ from PIL import Image, ImageOps
import platform import platform
import sys import sys
import tqdm import tqdm
import time
from modules import shared, images from modules import shared, images
from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
width = process_width width = process_width
height = process_height height = process_height
src = os.path.abspath(process_src) src = os.path.abspath(process_src)
@ -25,10 +28,21 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
if process_caption: if process_caption:
shared.interrogator.load() shared.interrogator.load()
if process_caption_deepbooru:
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
def save_pic_with_caption(image, index): def save_pic_with_caption(image, index):
if process_caption: if process_caption:
caption = "-" + shared.interrogator.generate_caption(image) caption = "-" + shared.interrogator.generate_caption(image)
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
elif process_caption_deepbooru:
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
caption = "-" + shared.deepbooru_process_return["value"]
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
shared.deepbooru_process_return["value"] = -1
else: else:
caption = filename caption = filename
caption = os.path.splitext(caption)[0] caption = os.path.splitext(caption)[0]
@ -83,6 +97,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
if process_caption: if process_caption:
shared.interrogator.send_blip_to_ram() shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
deepbooru.release_process()
def sanitize_caption(base_path, original_caption, suffix): def sanitize_caption(base_path, original_caption, suffix):
operating_system = platform.system().lower() operating_system = platform.system().lower()
if (operating_system == "windows"): if (operating_system == "windows"):

View file

@ -324,7 +324,7 @@ def interrogate(image):
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) prompt = get_deepbooru_tags(image)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
@ -1065,6 +1065,10 @@ def create_ui(wrap_gradio_gpu_call):
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two') process_split = gr.Checkbox(label='Split oversized images into two')
process_caption = gr.Checkbox(label='Use BLIP caption as filename') process_caption = gr.Checkbox(label='Use BLIP caption as filename')
if cmd_opts.deepdanbooru:
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
else:
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
@ -1142,6 +1146,7 @@ def create_ui(wrap_gradio_gpu_call):
process_flip, process_flip,
process_split, process_split,
process_caption, process_caption,
process_caption_deepbooru
], ],
outputs=[ outputs=[
ti_output, ti_output,