6d805b669e
remove random artist built-in extension (to be re-added as a normal extension on demand), remove artists.csv, make interrogate buttons show "Loading..." when you click them
210 lines · 8.2 KiB · Python
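
# Image interrogation for the web UI: BLIP produces a caption for the image, then
# CLIP ranks term lists (artists, flavors, mediums, movements) against it to build
# the final description.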
import os
import sys
import traceback
import shutil
from collections import namedtuple
import re

import torch
import torch.hub

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, lowvram, modelloader, errors

blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")
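

# Fetch the default category term lists from the pharmapsychotic/clip-interrogator
# repository into a temporary directory, then rename it into place once all files
# have downloaded successfully.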
def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")

    tmpdir = content_dir + "_tmp"
    try:
        os.makedirs(tmpdir)

        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))

        os.rename(tmpdir, content_dir)

    except Exception as e:
        errors.display(e, "downloading default CLIP interrogate categories")
    finally:
        if os.path.exists(tmpdir):
            # clean up the temporary directory (including any partially downloaded
            # files) if the rename above did not happen
            shutil.rmtree(tmpdir)
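

# Lazily loads BLIP and CLIP, keeps them on the interrogate device (in half precision
# unless disabled or running on CPU), and can offload them back to CPU RAM between runs.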
class InterrogateModels:
    blip_model = None
    clip_model = None
    clip_preprocess = None
    dtype = None
    running_on_cpu = None

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
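
    # Load and cache the category files from content_dir, downloading the defaults on
    # first use; a ".topN." infix in a filename sets how many matches that category
    # contributes to the result.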
    def categories(self):
        if self.loaded_categories is not None:
            return self.loaded_categories

        self.loaded_categories = []

        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if os.path.exists(self.content_dir):
            for filename in os.listdir(self.content_dir):
                m = re_topn.search(filename)
                topn = 1 if m is None else int(m.group(1))

                with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))

        return self.loaded_categories
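
    # Download the BLIP captioning checkpoint if needed and build the decoder model.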
    def load_blip_model(self):
        import models.blip

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "BLIP"),
            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
            ext_filter=[".pth"],
            download_name='model_base_caption_capfilt_large.pth',
        )

        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model
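
    # Load the ViT-L/14 CLIP model and its image preprocessing transform.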
    def load_clip_model(self):
        import clip

        if self.running_on_cpu:
            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
        else:
            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)

        model.eval()
        model = model.to(devices.device_interrogate)

        return model, preprocess
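
    # Make sure both models are loaded on the interrogate device, cast to half precision
    # when allowed, and remember CLIP's dtype for the tensors built later.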
    def load(self):
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(devices.device_interrogate)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(devices.device_interrogate)

        self.dtype = next(self.clip_model.parameters()).dtype
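
    # Unless the user chose to keep the models in memory, offload them to CPU RAM
    # between interrogations to free VRAM.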
    def send_clip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()
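
    # Score every term in text_array against the image features with CLIP and return
    # the top_count best (term, score-in-percent) pairs.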
    def rank(self, image_features, text_array, top_count=1):
        import clip

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
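
    # Produce a BLIP caption for the image using the beam search settings from the options.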
    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]
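
    # Full interrogation: caption the image with BLIP, then append the best CLIP match
    # (or matches) from each category; appends "<error>" to the result on failure.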
    def interrogate(self, pil_image):
        res = ""
        shared.state.begin()
        shared.state.job = 'interrogate'
        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)

                for name, topn, items in self.categories():
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        if shared.opts.interrogate_return_ranks:
                            res += f", ({match}:{score/100:.3f})"
                        else:
                            res += ", " + match

        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            res += "<error>"

        self.unload()
        shared.state.end()

        return res