memory optimization for CLIP interrogator

raised the maximum cfg_scale from 15 to 30 (the default stays at 7.0)
This commit is contained in:
AUTOMATIC 2022-09-12 11:55:27 +03:00
parent ab0a79cdf4
commit 9bb20be090
4 changed files with 36 additions and 7 deletions

View file

@ -11,7 +11,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared import modules.shared as shared
from modules import devices, paths from modules import devices, paths, lowvram
blip_image_eval_size = 384 blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
@ -75,19 +75,28 @@ class InterrogateModels:
self.dtype = next(self.clip_model.parameters()).dtype self.dtype = next(self.clip_model.parameters()).dtype
def unload(self): def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory: if not shared.opts.interrogate_keep_models_in_memory:
if self.clip_model is not None: if self.clip_model is not None:
self.clip_model = self.clip_model.to(devices.cpu) self.clip_model = self.clip_model.to(devices.cpu)
def send_blip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
if self.blip_model is not None: if self.blip_model is not None:
self.blip_model = self.blip_model.to(devices.cpu) self.blip_model = self.blip_model.to(devices.cpu)
devices.torch_gc() def unload(self):
self.send_clip_to_ram()
self.send_blip_to_ram()
devices.torch_gc()
def rank(self, image_features, text_array, top_count=1): def rank(self, image_features, text_array, top_count=1):
import clip import clip
if shared.opts.interrogate_clip_dict_limit != 0:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array)) top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).to(shared.device) text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
@ -117,16 +126,24 @@ class InterrogateModels:
res = None res = None
try: try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
self.load() self.load()
caption = self.generate_caption(pil_image) caption = self.generate_caption(pil_image)
self.send_blip_to_ram()
devices.torch_gc()
res = caption res = caption
images = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
image_features = self.clip_model.encode_image(images).type(self.dtype) image_features = self.clip_model.encode_image(cilp_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
@ -146,4 +163,5 @@ class InterrogateModels:
self.unload() self.unload()
res += "<error>"
return res return res

View file

@ -5,6 +5,16 @@ module_in_gpu = None
cpu = torch.device("cpu") cpu = torch.device("cpu")
device = gpu = get_optimal_device() device = gpu = get_optimal_device()
def send_everything_to_cpu():
    """Move the module recorded in ``module_in_gpu`` to the CPU and clear the record.

    Does nothing when no module is currently recorded.
    """
    global module_in_gpu

    # Nothing is tracked as GPU-resident, so there is nothing to move.
    if module_in_gpu is None:
        return

    module_in_gpu.to(cpu)
    module_in_gpu = None
def setup_for_low_vram(sd_model, use_medvram): def setup_for_low_vram(sd_model, use_medvram):
parents = {} parents = {}

View file

@ -132,6 +132,7 @@ class Options:
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
} }
def __init__(self): def __init__(self):

View file

@ -270,7 +270,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
with gr.Group(): with gr.Group():
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
@ -413,7 +413,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group(): with gr.Group():
cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, visible=False) denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, visible=False)