CLIP interrogator
This commit is contained in:
parent
13008bab90
commit
f194457229
13 changed files with 204 additions and 13 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,3 +13,4 @@ __pycache__
|
|||
/embeddings
|
||||
/styles.csv
|
||||
/webui-user.bat
|
||||
/interrogate
|
||||
|
|
|
@ -40,6 +40,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
- Styles
|
||||
- Variations
|
||||
- Seed resizing
|
||||
- CLIP interrogator
|
||||
|
||||
## Installing and running
|
||||
|
||||
|
@ -289,5 +290,6 @@ After that follow the instructions in the `Manual instructions` section starting
|
|||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||
- Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion
|
||||
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
||||
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
import torch
|
||||
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
if has_mps:
|
||||
return torch.device("mps")
|
||||
return torch.device("cpu")
|
||||
|
||||
return cpu
|
||||
|
|
142
modules/interrogate.py
Normal file
142
modules/interrogate.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths
|
||||
|
||||
blip_image_eval_size = 384
|
||||
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
|
||||
class InterrogateModels:
|
||||
blip_model = None
|
||||
clip_model = None
|
||||
clip_preprocess = None
|
||||
categories = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.categories = []
|
||||
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
m = re_topn.search(filename)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
|
||||
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
||||
|
||||
def load_blip_model(self):
|
||||
import models.blip
|
||||
|
||||
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
||||
blip_model.eval()
|
||||
|
||||
return blip_model
|
||||
|
||||
def load_clip_model(self):
|
||||
import clip
|
||||
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
model.eval()
|
||||
model = model.to(shared.device)
|
||||
|
||||
return model, preprocess
|
||||
|
||||
def load(self):
|
||||
if self.blip_model is None:
|
||||
self.blip_model = self.load_blip_model()
|
||||
|
||||
self.blip_model = self.blip_model.to(shared.device)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
|
||||
self.clip_model = self.clip_model.to(shared.device)
|
||||
|
||||
def unload(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
if self.clip_model is not None:
|
||||
self.clip_model = self.clip_model.to(devices.cpu)
|
||||
|
||||
if self.blip_model is not None:
|
||||
self.blip_model = self.blip_model.to(devices.cpu)
|
||||
|
||||
|
||||
def rank(self, image_features, text_array, top_count=1):
|
||||
import clip
|
||||
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize([text for text in text_array]).cuda()
|
||||
with torch.no_grad():
|
||||
text_features = self.clip_model.encode_text(text_tokens).float()
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
||||
for i in range(image_features.shape[0]):
|
||||
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
||||
similarity /= image_features.shape[0]
|
||||
|
||||
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
||||
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
||||
|
||||
|
||||
def generate_caption(self, pil_image):
|
||||
gpu_image = transforms.Compose([
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).to(shared.device)
|
||||
|
||||
with torch.no_grad():
|
||||
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
||||
|
||||
return caption[0]
|
||||
|
||||
def interrogate(self, pil_image):
|
||||
res = None
|
||||
|
||||
try:
|
||||
self.load()
|
||||
|
||||
caption = self.generate_caption(pil_image)
|
||||
res = caption
|
||||
|
||||
images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_features = self.clip_model.encode_image(images).float()
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
if shared.opts.interrogate_use_builtin_artists:
|
||||
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
||||
|
||||
res += ", " + artist[0]
|
||||
|
||||
for name, topn, items in self.categories:
|
||||
matches = self.rank(image_features, items, top_count=topn)
|
||||
for match, score in matches:
|
||||
res += ", " + match
|
||||
|
||||
except Exception:
|
||||
print(f"Error interrogating", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
self.unload()
|
||||
|
||||
return res
|
|
@ -18,6 +18,7 @@ path_dirs = [
|
|||
(sd_path, 'ldm', 'Stable Diffusion'),
|
||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
|
||||
]
|
||||
|
||||
paths = {}
|
||||
|
|
|
@ -11,6 +11,7 @@ import modules.artists
|
|||
from modules.paths import script_path, sd_path
|
||||
from modules.devices import get_optimal_device
|
||||
import modules.styles
|
||||
import modules.interrogate
|
||||
|
||||
config_filename = "config.json"
|
||||
|
||||
|
@ -77,6 +78,8 @@ artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.c
|
|||
styles_filename = os.path.join(script_path, 'styles.csv')
|
||||
prompt_styles = modules.styles.load_styles(styles_filename)
|
||||
|
||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
|
||||
face_restorers = []
|
||||
|
||||
class Options:
|
||||
|
@ -123,6 +126,11 @@ class Options:
|
|||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
|
||||
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"interrogate_keep_models_in_memory": OptionInfo(True, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -242,9 +242,14 @@ def add_style(style_name, text):
|
|||
return [update, update]
|
||||
|
||||
|
||||
def interrogate(image):
|
||||
prompt = shared.interrogator.interrogate(image)
|
||||
|
||||
return gr_show(True) if prompt is None else prompt
|
||||
|
||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
with gr.Row():
|
||||
with gr.Row(elem_id="toprow"):
|
||||
txt2img_prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
|
||||
txt2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
|
||||
|
@ -365,10 +370,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||
)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
with gr.Row():
|
||||
with gr.Row(elem_id="toprow"):
|
||||
img2img_prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="img2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
|
||||
img2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
|
||||
img2img_interrogate = gr.Button('Interrogate', elem_id="img2img_interrogate", variant='primary')
|
||||
submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')
|
||||
check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
|
||||
|
||||
|
@ -461,6 +467,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||
inpaint_full_res: gr_show(is_inpaint),
|
||||
inpainting_mask_invert: gr_show(is_inpaint),
|
||||
denoising_strength_change_factor: gr_show(is_loopback),
|
||||
img2img_interrogate: gr_show(not is_inpaint),
|
||||
}
|
||||
|
||||
switch_mode.change(
|
||||
|
@ -480,6 +487,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||
inpaint_full_res,
|
||||
inpainting_mask_invert,
|
||||
denoising_strength_change_factor,
|
||||
img2img_interrogate,
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -540,6 +548,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=interrogate,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
check_progress.click(
|
||||
fn=check_progress_call,
|
||||
show_progress=False,
|
||||
|
|
|
@ -15,3 +15,5 @@ fonts
|
|||
font-roboto
|
||||
git+https://github.com/crowsonkb/k-diffusion.git
|
||||
git+https://github.com/TencentARC/GFPGAN.git
|
||||
timm==0.4.12
|
||||
fairscale==0.4.4
|
||||
|
|
|
@ -11,3 +11,5 @@ pytorch_lightning==1.7.2
|
|||
scikit-image==0.19.2
|
||||
fonts
|
||||
font-roboto
|
||||
timm==0.4.12
|
||||
fairscale==0.4.4
|
||||
|
|
|
@ -51,6 +51,8 @@ titles = {
|
|||
"Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
|
||||
"Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
|
||||
"Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
|
||||
|
||||
"Interrogate": "Reconstruct frompt from existing image and put it into the prompt field.",
|
||||
}
|
||||
|
||||
function gradioApp(){
|
||||
|
|
|
@ -5,6 +5,10 @@
|
|||
max-width: 13em;
|
||||
}
|
||||
|
||||
#img2img_interrogate{
|
||||
max-width: 10em;
|
||||
}
|
||||
|
||||
#subseed_show{
|
||||
min-width: 6em;
|
||||
max-width: 6em;
|
||||
|
@ -26,7 +30,7 @@
|
|||
padding-right: 0;
|
||||
}
|
||||
|
||||
#component-1 div{
|
||||
#toprow div{
|
||||
border: none;
|
||||
gap: 0;
|
||||
}
|
||||
|
|
13
webui.bat
13
webui.bat
|
@ -85,7 +85,7 @@ if %ERRORLEVEL% == 0 goto :install_reqs
|
|||
goto :show_stdout_stderr
|
||||
|
||||
:install_reqs
|
||||
%PYTHON% -c "import omegaconf; import fonts" >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
%PYTHON% -c "import omegaconf; import fonts; import timm" >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% == 0 goto :make_dirs
|
||||
echo Installing requirements...
|
||||
%PYTHON% -m pip install -r %REQS_FILE% --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
|
@ -117,12 +117,19 @@ goto :show_stdout_stderr
|
|||
|
||||
:install_codeformer_reqs
|
||||
%PYTHON% -c "import lpips" >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% == 0 goto :check_model
|
||||
if %ERRORLEVEL% == 0 goto :clone_blip
|
||||
echo Installing requirements for CodeFormer...
|
||||
%PYTHON% -m pip install -r repositories\CodeFormer\requirements.txt --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% == 0 goto :check_model
|
||||
if %ERRORLEVEL% == 0 goto :clone_blip
|
||||
goto :show_stdout_stderr
|
||||
|
||||
:clone_blip
|
||||
if exist repositories\BLIP goto :check_model
|
||||
echo Cloning BLIP repository...
|
||||
%GIT% clone https://github.com/salesforce/BLIP.git repositories\BLIP >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% NEQ 0 goto :show_stdout_stderr
|
||||
%GIT% -C repositories/BLIP checkout 48211a1594f1321b00f14c9f7a5b4813144b2fb9 >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% NEQ 0 goto :show_stdout_stderr
|
||||
|
||||
:check_model
|
||||
dir model.ckpt >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
|
|
2
webui.py
2
webui.py
|
@ -33,6 +33,7 @@ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
|||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
||||
realesrgan.setup_realesrgan()
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
|
@ -116,5 +117,6 @@ def webui():
|
|||
|
||||
demo.launch(share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None, server_port=cmd_opts.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
webui()
|
||||
|
|
Loading…
Reference in a new issue