add half() support for CLIP interrogation

AUTOMATIC 2022-09-11 23:24:24 +03:00
parent d97c6f221f
commit 8fb9c57ed6
6 changed files with 40 additions and 30 deletions

modules/devices.py

@@ -14,3 +14,9 @@ def get_optimal_device():
         return torch.device("mps")
 
     return cpu
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
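The helper that moved into modules/devices.py returns memory held by PyTorch's CUDA caching allocator for tensors that no longer have live references; on CPU or MPS it is a no-op. A minimal sketch of the intended call pattern (the function and tensor names are illustrative, not from the repo):

    import torch
    from modules import devices

    def process_large_batch(batch):
        result = (batch * 2).cpu()  # stand-in for real GPU work
        del batch                   # drop the last reference first...
        devices.torch_gc()          # ...then hand cached VRAM back to the driver
        return result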

modules/extras.py

@@ -1,7 +1,7 @@
 import numpy as np
 from PIL import Image
-from modules import processing, shared, images
+from modules import processing, shared, images, devices
 from modules.shared import opts
 import modules.gfpgan_model
 from modules.ui import plaintext_to_html
@@ -11,7 +11,7 @@ cached_images = {}
 def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
-    processing.torch_gc()
+    devices.torch_gc()
 
     image = image.convert("RGB")
 
     info = ""

modules/img2img.py

@@ -3,6 +3,7 @@ import cv2
 import numpy as np
 from PIL import Image, ImageOps, ImageChops
 
+from modules import devices
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared
@@ -131,7 +132,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
             upscaler = shared.sd_upscalers[upscaler_index]
             img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
 
-            processing.torch_gc()
+            devices.torch_gc()
 
             grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)

modules/interrogate.py

@@ -1,3 +1,4 @@
+import contextlib
 import os
 import sys
 import traceback
@@ -6,7 +7,6 @@ import re
 
 import torch
-from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -26,6 +26,7 @@ class InterrogateModels:
     clip_model = None
     clip_preprocess = None
     categories = None
+    dtype = None
 
     def __init__(self, content_dir):
         self.categories = []
@@ -60,14 +61,20 @@ class InterrogateModels:
     def load(self):
         if self.blip_model is None:
             self.blip_model = self.load_blip_model()
+            if not shared.cmd_opts.no_half:
+                self.blip_model = self.blip_model.half()
 
         self.blip_model = self.blip_model.to(shared.device)
 
         if self.clip_model is None:
             self.clip_model, self.clip_preprocess = self.load_clip_model()
+            if not shared.cmd_opts.no_half:
+                self.clip_model = self.clip_model.half()
 
         self.clip_model = self.clip_model.to(shared.device)
 
+        self.dtype = next(self.clip_model.parameters()).dtype
+
     def unload(self):
         if not shared.opts.interrogate_keep_models_in_memory:
             if self.clip_model is not None:
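Unless --no-half is set, both BLIP and CLIP are converted to fp16 before being moved to the device, and the resulting parameter dtype is recorded in self.dtype so every input tensor can later be cast to match. A standalone sketch of the same load pattern, with a toy model in place of CLIP (the no_half variable is illustrative):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    no_half = device == "cpu"  # fp16 matmul support on CPU is poor, so stay in fp32 there

    model = nn.Linear(512, 512)
    if not no_half:
        model = model.half()   # fp16 weights: roughly half the VRAM, faster on most GPUs
    model = model.to(device)

    # Record the parameter dtype once, as load() does with self.dtype,
    # so callers can cast inputs instead of hard-coding float16/float32.
    dtype = next(model.parameters()).dtype
    x = torch.randn(1, 512).to(device).type(dtype)
    y = model(x)  # dtypes agree, so no "expected Half but found Float" error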
@@ -76,14 +83,14 @@ class InterrogateModels:
             if self.blip_model is not None:
                 self.blip_model = self.blip_model.to(devices.cpu)
 
+            devices.torch_gc()
 
     def rank(self, image_features, text_array, top_count=1):
         import clip
 
         top_count = min(top_count, len(text_array))
-        text_tokens = clip.tokenize([text for text in text_array]).cuda()
-        with torch.no_grad():
-            text_features = self.clip_model.encode_text(text_tokens).float()
+        text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
+        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
 
         text_features /= text_features.norm(dim=-1, keepdim=True)
 
         similarity = torch.zeros((1, len(text_array))).to(shared.device)
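rank() embeds the candidate strings with CLIP's text encoder and scores them against the image embedding; since both feature sets are L2-normalized, the dot product is the cosine similarity. A self-contained sketch of that zero-shot ranking idea, using the same openai clip package the method imports lazily (the model variant, candidate strings, and the random stand-in image vector are all illustrative):

    import torch
    import clip  # the package rank() imports inside the method

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    texts = ["by Claude Monet", "by H.R. Giger", "by Katsushika Hokusai"]
    tokens = clip.tokenize(texts).to(device)

    with torch.no_grad():
        text_features = model.encode_text(tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # A random unit vector stands in for a real image embedding here.
        image_features = torch.randn(1, text_features.shape[-1], device=device)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Dot product of unit vectors == cosine similarity; softmax over the
        # scaled scores turns it into a ranking distribution.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    top_probs, top_labels = similarity.cpu().topk(2, dim=-1)
    for prob, idx in zip(top_probs[0], top_labels[0]):
        print(f"{texts[int(idx)]}: {float(prob) * 100:.1f}%")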
@@ -94,13 +101,12 @@ class InterrogateModels:
         top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
         return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
 
     def generate_caption(self, pil_image):
         gpu_image = transforms.Compose([
             transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])(pil_image).unsqueeze(0).to(shared.device)
+        ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
 
         with torch.no_grad():
             caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
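The new .type(self.dtype) on the pixel batch matters because a model converted with half() refuses float32 inputs at its first layer. A toy, CUDA-only reproduction of the failure and the fix (the conv layer is a stand-in for BLIP):

    import torch
    import torch.nn as nn

    if torch.cuda.is_available():
        conv = nn.Conv2d(3, 8, 3).half().cuda()   # fp16 weights, like BLIP after load()
        batch = torch.randn(1, 3, 32, 32).cuda()  # float32 input

        try:
            conv(batch)                           # input/weight dtype mismatch
        except RuntimeError as err:
            print("without the cast:", err)

        conv(batch.type(torch.float16))           # cast to the model dtype: works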
@@ -116,22 +122,23 @@ class InterrogateModels:
             caption = self.generate_caption(pil_image)
             res = caption
 
-            images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)
+            images = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
 
-            with torch.no_grad():
-                image_features = self.clip_model.encode_image(images).float()
+            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
+            with torch.no_grad(), precision_scope("cuda"):
+                image_features = self.clip_model.encode_image(images).type(self.dtype)
 
                 image_features /= image_features.norm(dim=-1, keepdim=True)
 
                 if shared.opts.interrogate_use_builtin_artists:
                     artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
 
                     res += ", " + artist[0]
 
                 for name, topn, items in self.categories:
                     matches = self.rank(image_features, items, top_count=topn)
                     for match, score in matches:
                         res += ", " + match
 
         except Exception:
             print(f"Error interrogating", file=sys.stderr)

modules/processing.py

@@ -10,6 +10,7 @@ from PIL import Image, ImageFilter, ImageOps
 import random
 
 import modules.sd_hijack
+from modules import devices
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.shared import opts, cmd_opts, state
@@ -23,11 +24,6 @@ opt_C = 4
 opt_f = 8
 
-def torch_gc():
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-
 
 class StableDiffusionProcessing:
     def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
@@ -157,7 +153,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
 
     assert p.prompt is not None
-    torch_gc()
+    devices.torch_gc()
 
     fix_seed(p)
@@ -258,7 +254,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             x_sample = x_sample.astype(np.uint8)
 
             if p.restore_faces:
-                torch_gc()
+                devices.torch_gc()
 
                 x_sample = modules.face_restoration.restore_faces(x_sample)
@@ -297,7 +293,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
         if opts.grid_save:
             images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
 
-    torch_gc()
+    devices.torch_gc()
 
     return Processed(p, output_images, all_seeds[0], infotext())

scripts/poor_mans_outpainting.py

@@ -4,7 +4,7 @@ import modules.scripts as scripts
 import gradio as gr
 from PIL import Image, ImageDraw
 
-from modules import images, processing
+from modules import images, processing, devices
 from modules.processing import Processed, process_images
 from modules.shared import opts, cmd_opts, state
@@ -77,7 +77,7 @@ class Script(scripts.Script):
             mask.height - down - (mask_blur//2 if down > 0 else 0)
         ), fill="black")
 
-        processing.torch_gc()
+        devices.torch_gc()
 
         grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
         grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)