Merge pull request #7113 from vladmandic/interrogate
Add selector to interrogate categories
This commit is contained in:
commit
7ba7f4ed6e
2 changed files with 27 additions and 17 deletions
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
||||||
|
|
||||||
re_topn = re.compile(r"\.top(\d+)\.")
|
re_topn = re.compile(r"\.top(\d+)\.")
|
||||||
|
|
||||||
|
def category_types():
|
||||||
|
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||||
|
|
||||||
|
|
||||||
def download_default_clip_interrogate_categories(content_dir):
|
def download_default_clip_interrogate_categories(content_dir):
|
||||||
print("Downloading CLIP categories...")
|
print("Downloading CLIP categories...")
|
||||||
|
|
||||||
tmpdir = content_dir + "_tmp"
|
tmpdir = content_dir + "_tmp"
|
||||||
|
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(tmpdir)
|
os.makedirs(tmpdir)
|
||||||
|
for category_type in category_types:
|
||||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
|
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
|
|
||||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
|
|
||||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
|
|
||||||
|
|
||||||
os.rename(tmpdir, content_dir)
|
os.rename(tmpdir, content_dir)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -51,27 +53,32 @@ class InterrogateModels:
|
||||||
|
|
||||||
def __init__(self, content_dir):
|
def __init__(self, content_dir):
|
||||||
self.loaded_categories = None
|
self.loaded_categories = None
|
||||||
|
self.skip_categories = []
|
||||||
self.content_dir = content_dir
|
self.content_dir = content_dir
|
||||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||||
|
|
||||||
def categories(self):
|
def categories(self):
|
||||||
if self.loaded_categories is not None:
|
if not os.path.exists(self.content_dir):
|
||||||
|
download_default_clip_interrogate_categories(self.content_dir)
|
||||||
|
|
||||||
|
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
|
||||||
return self.loaded_categories
|
return self.loaded_categories
|
||||||
|
|
||||||
self.loaded_categories = []
|
self.loaded_categories = []
|
||||||
|
|
||||||
if not os.path.exists(self.content_dir):
|
|
||||||
download_default_clip_interrogate_categories(self.content_dir)
|
|
||||||
|
|
||||||
if os.path.exists(self.content_dir):
|
if os.path.exists(self.content_dir):
|
||||||
for filename in os.listdir(self.content_dir):
|
self.skip_categories = shared.opts.interrogate_clip_skip_categories
|
||||||
m = re_topn.search(filename)
|
category_types = []
|
||||||
|
for filename in Path(self.content_dir).glob('*.txt'):
|
||||||
|
category_types.append(filename.stem)
|
||||||
|
if filename.stem in self.skip_categories:
|
||||||
|
continue
|
||||||
|
m = re_topn.search(filename.stem)
|
||||||
topn = 1 if m is None else int(m.group(1))
|
topn = 1 if m is None else int(m.group(1))
|
||||||
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
|
|
||||||
lines = [x.strip() for x in file.readlines()]
|
lines = [x.strip() for x in file.readlines()]
|
||||||
|
|
||||||
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
|
self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
|
||||||
|
|
||||||
return self.loaded_categories
|
return self.loaded_categories
|
||||||
|
|
||||||
|
@ -139,6 +146,8 @@ class InterrogateModels:
|
||||||
def rank(self, image_features, text_array, top_count=1):
|
def rank(self, image_features, text_array, top_count=1):
|
||||||
import clip
|
import clip
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
if shared.opts.interrogate_clip_dict_limit != 0:
|
if shared.opts.interrogate_clip_dict_limit != 0:
|
||||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||||
|
|
||||||
|
|
|
@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
||||||
|
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||||
|
|
Loading…
Reference in a new issue