make CLIP interrogator download original text files if the directory does not exist
remove random artist built-in extension (to re-added as a normal extension on demand) remove artists.csv (but what does it mean????????????????????) make interrogate buttons show Loading... when you click them
This commit is contained in:
parent
40ff6db532
commit
6d805b669e
9 changed files with 46 additions and 3151 deletions
|
@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion.
|
||||||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||||
- Mouseover hints for most UI elements
|
- Mouseover hints for most UI elements
|
||||||
- Possible to change defaults/mix/max/step values for UI elements via text config
|
- Possible to change defaults/mix/max/step values for UI elements via text config
|
||||||
- Random artist button
|
|
||||||
- Tiling support, a checkbox to create images that can be tiled like textures
|
- Tiling support, a checkbox to create images that can be tiled like textures
|
||||||
- Progress bar and live image generation preview
|
- Progress bar and live image generation preview
|
||||||
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
|
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
|
||||||
|
|
3041
artists.csv
3041
artists.csv
File diff suppressed because it is too large
Load diff
|
@ -1,50 +0,0 @@
|
||||||
import random
|
|
||||||
|
|
||||||
from modules import script_callbacks, shared
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
art_symbol = '\U0001f3a8' # 🎨
|
|
||||||
global_prompt = None
|
|
||||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
|
||||||
|
|
||||||
|
|
||||||
def roll_artist(prompt):
|
|
||||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
|
||||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
|
||||||
|
|
||||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
|
||||||
|
|
||||||
|
|
||||||
def add_roll_button(prompt):
|
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
|
||||||
|
|
||||||
roll.click(
|
|
||||||
fn=roll_artist,
|
|
||||||
_js="update_txt2img_tokens",
|
|
||||||
inputs=[
|
|
||||||
prompt,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
prompt,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def after_component(component, **kwargs):
|
|
||||||
global global_prompt
|
|
||||||
|
|
||||||
elem_id = kwargs.get('elem_id', None)
|
|
||||||
if elem_id not in related_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
if elem_id == "txt2img_prompt":
|
|
||||||
global_prompt = component
|
|
||||||
elif elem_id == "txt2img_clear_prompt":
|
|
||||||
add_roll_button(global_prompt)
|
|
||||||
elif elem_id == "img2img_prompt":
|
|
||||||
global_prompt = component
|
|
||||||
elif elem_id == "img2img_clear_prompt":
|
|
||||||
add_roll_button(global_prompt)
|
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_after_component(after_component)
|
|
|
@ -14,7 +14,6 @@ titles = {
|
||||||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
||||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
|
||||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||||
"\u{1f4c2}": "Open images output directory",
|
"\u{1f4c2}": "Open images output directory",
|
||||||
"\u{1f4be}": "Save style",
|
"\u{1f4be}": "Save style",
|
||||||
|
|
|
@ -126,8 +126,6 @@ class Api:
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
|
||||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||||
|
@ -390,12 +388,6 @@ class Api:
|
||||||
|
|
||||||
return styleList
|
return styleList
|
||||||
|
|
||||||
def get_artists_categories(self):
|
|
||||||
return shared.artist_db.cats
|
|
||||||
|
|
||||||
def get_artists(self):
|
|
||||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
|
||||||
|
|
||||||
def get_embeddings(self):
|
def get_embeddings(self):
|
||||||
db = sd_hijack.model_hijack.embedding_db
|
db = sd_hijack.model_hijack.embedding_db
|
||||||
|
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
import os.path
|
|
||||||
import csv
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
|
|
||||||
|
|
||||||
|
|
||||||
class ArtistsDatabase:
|
|
||||||
def __init__(self, filename):
|
|
||||||
self.cats = set()
|
|
||||||
self.artists = []
|
|
||||||
|
|
||||||
if not os.path.exists(filename):
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(filename, "r", newline='', encoding="utf8") as file:
|
|
||||||
reader = csv.DictReader(file)
|
|
||||||
|
|
||||||
for row in reader:
|
|
||||||
artist = Artist(row["artist"], float(row["score"]), row["category"])
|
|
||||||
self.artists.append(artist)
|
|
||||||
self.cats.add(artist.category)
|
|
||||||
|
|
||||||
def categories(self):
|
|
||||||
return sorted(self.cats)
|
|
|
@ -5,12 +5,13 @@ from collections import namedtuple
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.hub
|
||||||
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import devices, paths, lowvram, modelloader
|
from modules import devices, paths, lowvram, modelloader, errors
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
clip_model_name = 'ViT-L/14'
|
clip_model_name = 'ViT-L/14'
|
||||||
|
@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
||||||
re_topn = re.compile(r"\.top(\d+)\.")
|
re_topn = re.compile(r"\.top(\d+)\.")
|
||||||
|
|
||||||
|
|
||||||
|
def download_default_clip_interrogate_categories(content_dir):
|
||||||
|
print("Downloading CLIP categories...")
|
||||||
|
|
||||||
|
tmpdir = content_dir + "_tmp"
|
||||||
|
try:
|
||||||
|
os.makedirs(tmpdir)
|
||||||
|
|
||||||
|
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
|
||||||
|
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
|
||||||
|
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
|
||||||
|
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
|
||||||
|
|
||||||
|
os.rename(tmpdir, content_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "downloading default CLIP interrogate categories")
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmpdir):
|
||||||
|
os.remove(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
class InterrogateModels:
|
class InterrogateModels:
|
||||||
blip_model = None
|
blip_model = None
|
||||||
clip_model = None
|
clip_model = None
|
||||||
clip_preprocess = None
|
clip_preprocess = None
|
||||||
categories = None
|
|
||||||
dtype = None
|
dtype = None
|
||||||
running_on_cpu = None
|
running_on_cpu = None
|
||||||
|
|
||||||
def __init__(self, content_dir):
|
def __init__(self, content_dir):
|
||||||
self.categories = []
|
self.loaded_categories = None
|
||||||
|
self.content_dir = content_dir
|
||||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||||
|
|
||||||
if os.path.exists(content_dir):
|
def categories(self):
|
||||||
for filename in os.listdir(content_dir):
|
if self.loaded_categories is not None:
|
||||||
|
return self.loaded_categories
|
||||||
|
|
||||||
|
self.loaded_categories = []
|
||||||
|
|
||||||
|
if not os.path.exists(self.content_dir):
|
||||||
|
download_default_clip_interrogate_categories(self.content_dir)
|
||||||
|
|
||||||
|
if os.path.exists(self.content_dir):
|
||||||
|
for filename in os.listdir(self.content_dir):
|
||||||
m = re_topn.search(filename)
|
m = re_topn.search(filename)
|
||||||
topn = 1 if m is None else int(m.group(1))
|
topn = 1 if m is None else int(m.group(1))
|
||||||
|
|
||||||
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
|
||||||
lines = [x.strip() for x in file.readlines()]
|
lines = [x.strip() for x in file.readlines()]
|
||||||
|
|
||||||
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
|
||||||
|
|
||||||
|
return self.loaded_categories
|
||||||
|
|
||||||
def load_blip_model(self):
|
def load_blip_model(self):
|
||||||
import models.blip
|
import models.blip
|
||||||
|
@ -139,7 +172,6 @@ class InterrogateModels:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'interrogate'
|
shared.state.job = 'interrogate'
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
@ -159,12 +191,7 @@ class InterrogateModels:
|
||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
if shared.opts.interrogate_use_builtin_artists:
|
for name, topn, items in self.categories():
|
||||||
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
|
||||||
|
|
||||||
res += ", " + artist[0]
|
|
||||||
|
|
||||||
for name, topn, items in self.categories:
|
|
||||||
matches = self.rank(image_features, items, top_count=topn)
|
matches = self.rank(image_features, items, top_count=topn)
|
||||||
for match, score in matches:
|
for match, score in matches:
|
||||||
if shared.opts.interrogate_return_ranks:
|
if shared.opts.interrogate_return_ranks:
|
||||||
|
|
|
@ -9,7 +9,6 @@ from PIL import Image
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
import modules.artists
|
|
||||||
import modules.interrogate
|
import modules.interrogate
|
||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.styles
|
import modules.styles
|
||||||
|
@ -254,8 +253,6 @@ class State:
|
||||||
state = State()
|
state = State()
|
||||||
state.server_start = time.time()
|
state.server_start = time.time()
|
||||||
|
|
||||||
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
|
||||||
|
|
||||||
styles_filename = cmd_opts.styles_file
|
styles_filename = cmd_opts.styles_file
|
||||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
||||||
|
|
||||||
|
@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
|
@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
|
||||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
||||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
|
|
|
@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
|
||||||
left, _ = os.path.splitext(filename)
|
left, _ = os.path.splitext(filename)
|
||||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
|
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
|
||||||
|
|
||||||
return [gr_show(True), None]
|
return [gr.update(), None]
|
||||||
|
|
||||||
|
|
||||||
def interrogate(image):
|
def interrogate(image):
|
||||||
prompt = shared.interrogator.interrogate(image.convert("RGB"))
|
prompt = shared.interrogator.interrogate(image.convert("RGB"))
|
||||||
return gr_show(True) if prompt is None else prompt
|
return gr.update() if prompt is None else prompt
|
||||||
|
|
||||||
|
|
||||||
def interrogate_deepbooru(image):
|
def interrogate_deepbooru(image):
|
||||||
prompt = deepbooru.model.tag(image)
|
prompt = deepbooru.model.tag(image)
|
||||||
return gr_show(True) if prompt is None else prompt
|
return gr.update() if prompt is None else prompt
|
||||||
|
|
||||||
|
|
||||||
def create_seed_inputs(target_interface):
|
def create_seed_inputs(target_interface):
|
||||||
|
@ -1039,7 +1039,6 @@ def create_ui():
|
||||||
init_img_inpaint,
|
init_img_inpaint,
|
||||||
],
|
],
|
||||||
outputs=[img2img_prompt, dummy_component],
|
outputs=[img2img_prompt, dummy_component],
|
||||||
show_progress=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
img2img_prompt.submit(**img2img_args)
|
img2img_prompt.submit(**img2img_args)
|
||||||
|
|
Loading…
Reference in a new issue