Merge branch 'master' into cpu-cmdline-opt

commit e9e2a7ec9a
Author: brkirch (committed by GitHub)
Date: 2022-10-04 07:42:53 -04:00
8 changed files with 109 additions and 80 deletions

modules/codeformer_model.py

@@ -69,10 +69,14 @@ def setup_model(dirname):
             self.net = net
             self.face_helper = face_helper
-            self.net.to(devices.device_codeformer)
 
             return net, face_helper
 
+        def send_model_to(self, device):
+            self.net.to(device)
+            self.face_helper.face_det.to(device)
+            self.face_helper.face_parse.to(device)
+
         def restore(self, np_image, w=None):
             np_image = np_image[:, :, ::-1]
 
@@ -82,6 +86,8 @@ def setup_model(dirname):
             if self.net is None or self.face_helper is None:
                 return np_image
 
+            self.send_model_to(devices.device_codeformer)
+
             self.face_helper.clean_all()
             self.face_helper.read_image(np_image)
             self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
 
@@ -113,8 +119,10 @@ def setup_model(dirname):
             if original_resolution != restored_img.shape[0:2]:
                 restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
 
+            self.face_helper.clean_all()
+
             if shared.opts.face_restoration_unload:
-                self.net.to(devices.cpu)
+                self.send_model_to(devices.cpu)
 
             return restored_img
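Note: the three hunks above implement one pattern. The CodeFormer weights no longer claim the compute device at load time; they are moved to devices.device_codeformer only when a restore is requested, and sent back to the CPU afterwards when face_restoration_unload is enabled. A minimal standalone sketch of that park-on-CPU pattern (the function name and default device are illustrative, not from this codebase):

    import torch

    def run_parked(model: torch.nn.Module, x: torch.Tensor, device: str = "cuda") -> torch.Tensor:
        model.to(device)                      # borrow VRAM only for this call
        try:
            with torch.no_grad():
                return model(x.to(device)).cpu()
        finally:
            model.to("cpu")                   # park the weights again, freeing VRAM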

modules/devices.py

@@ -1,3 +1,5 @@
+import contextlib
+
 import torch
 
 from modules import errors
 
@@ -56,3 +58,11 @@ def randn_without_seed(shape):
 
     return torch.randn(shape, device=device)
 
+
+def autocast():
+    from modules import shared
+
+    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+        return contextlib.nullcontext()
+
+    return torch.autocast("cuda")
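The new devices.autocast() centralizes the precision decision: callers get a no-op context under full precision and a CUDA autocast context otherwise. A rough standalone equivalent, with the module globals turned into parameters for the sketch (the dtype default and precision flag are assumptions here):

    import contextlib
    import torch

    def autocast_sketch(dtype: torch.dtype = torch.float16, precision: str = "autocast"):
        if dtype == torch.float32 or precision == "full":
            return contextlib.nullcontext()   # full precision: run ops unchanged
        return torch.autocast("cuda")         # otherwise let PyTorch select half-precision kernels

    # usage mirrors the call sites added in processing.py below:
    # with autocast_sketch():
    #     samples = sample(...)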

modules/gfpgan_model.py

@@ -36,23 +36,33 @@ def gfpgann():
     else:
         print("Unable to load gfpgan model!")
         return None
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
+    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+    model.gfpgan.to(devices.device_gfpgan)
     loaded_gfpgan_model = model
 
     return model
 
 
+def send_model_to(model, device):
+    model.gfpgan.to(device)
+    model.face_helper.face_det.to(device)
+    model.face_helper.face_parse.to(device)
+
+
 def gfpgan_fix_faces(np_image):
     model = gfpgann()
     if model is None:
         return np_image
 
+    send_model_to(model, devices.device_gfpgan)
+
     np_image_bgr = np_image[:, :, ::-1]
     cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
     np_image = gfpgan_output_bgr[:, :, ::-1]
 
+    model.face_helper.clean_all()
+
     if shared.opts.face_restoration_unload:
-        model.gfpgan.to(devices.cpu)
+        send_model_to(model, devices.cpu)
 
     return np_image
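Note that send_model_to moves three modules, not one: as the hunk shows, the face-detection and face-parsing networks hang off model.face_helper rather than living inside model.gfpgan, so a bare model.gfpgan.to(device) would leave them behind on the old device. The CodeFormer change above handles its face_helper the same way.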

modules/processing.py

@@ -1,4 +1,3 @@
-import contextlib
 import json
 import math
 import os
 
@@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
-    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
 
-    with torch.no_grad(), precision_scope("cuda"), ema_scope():
+    with torch.no_grad():
         p.init(all_prompts, all_seeds, all_subseeds)
 
         if state.job_count == -1:
 
@@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
-            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
-            c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+            with devices.autocast():
+                uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
+                c = prompt_parser.get_learned_conditioning(prompts, p.steps)
 
             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
 
@@ -361,13 +360,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+            with devices.autocast():
+                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+
             if state.interrupted:
 
                 # if we are interruped, sample returns just noise
                 # use the image collected previously in sampler loop
                 samples_ddim = shared.state.current_latent
 
+            samples_ddim = samples_ddim.to(devices.dtype)
+
             x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
@@ -386,6 +389,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                     devices.torch_gc()
 
                     x_sample = modules.face_restoration.restore_faces(x_sample)
+                    devices.torch_gc()
 
                     image = Image.fromarray(x_sample)
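The new explicit cast matters because sampling now runs inside devices.autocast() while decode_first_stage runs outside it: samples_ddim can arrive as float16 or float32 depending on the path taken (the interrupted branch substitutes shared.state.current_latent). Casting to devices.dtype normalizes the tensor before it reaches the first-stage model. A tiny illustration of the invariant being enforced (the dtype value is an assumption):

    import torch

    model_dtype = torch.float16               # assumption: half-precision weights
    latent = torch.randn(1, 4, 64, 64)        # float32, e.g. a latent captured outside autocast

    latent = latent.to(model_dtype)           # mirrors samples_ddim.to(devices.dtype)
    assert latent.dtype == model_dtype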

modules/prompt_parser.py

@@ -1,20 +1,11 @@
 import re
 from collections import namedtuple
 import torch
+from lark import Lark, Transformer, Visitor
+import functools
 
 import modules.shared as shared
 
-re_prompt = re.compile(r'''
-(.*?)
-\[
-([^]:]+):
-(?:([^]:]*):)?
-([0-9]*\.?[0-9]+)
-]
-|
-(.+)
-''', re.X)
-
 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
 # will be represented with prompt_schedule like this (assuming steps=100):
 # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@@ -25,61 +16,57 @@ re_prompt = re.compile(r'''
 
 def get_learned_conditioning_prompt_schedules(prompts, steps):
-    res = []
-    cache = {}
-
-    for prompt in prompts:
-        prompt_schedule: list[list[str | int]] = [[steps, ""]]
-
-        cached = cache.get(prompt, None)
-        if cached is not None:
-            res.append(cached)
-            continue
-
-        for m in re_prompt.finditer(prompt):
-            plaintext = m.group(1) if m.group(5) is None else m.group(5)
-            concept_from = m.group(2)
-            concept_to = m.group(3)
-            if concept_to is None:
-                concept_to = concept_from
-                concept_from = ""
-            swap_position = float(m.group(4)) if m.group(4) is not None else None
-
-            if swap_position is not None:
-                if swap_position < 1:
-                    swap_position = swap_position * steps
-                swap_position = int(min(swap_position, steps))
-
-            swap_index = None
-            found_exact_index = False
-            for i in range(len(prompt_schedule)):
-                end_step = prompt_schedule[i][0]
-                prompt_schedule[i][1] += plaintext
-
-                if swap_position is not None and swap_index is None:
-                    if swap_position == end_step:
-                        swap_index = i
-                        found_exact_index = True
-
-                    if swap_position < end_step:
-                        swap_index = i
-
-            if swap_index is not None:
-                if not found_exact_index:
-                    prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
-
-                for i in range(len(prompt_schedule)):
-                    end_step = prompt_schedule[i][0]
-                    must_replace = swap_position < end_step
-
-                    prompt_schedule[i][1] += concept_to if must_replace else concept_from
-
-        res.append(prompt_schedule)
-        cache[prompt] = prompt_schedule
-
-        #for t in prompt_schedule:
-        #    print(t)
-
-    return res
+    grammar = r"""
+start: prompt
+prompt: (emphasized | scheduled | weighted | plain)*
+!emphasized: "(" prompt ")"
+        | "(" prompt ":" prompt ")"
+        | "[" prompt "]"
+scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
+!weighted: "{" weighted_item ("|" weighted_item)* "}"
+!weighted_item: prompt (":" prompt)?
+plain: /([^\\\[\](){}:|]|\\.)+/
+%import common.SIGNED_NUMBER -> NUMBER
+"""
+    parser = Lark(grammar, parser='lalr')
+
+    def collect_steps(steps, tree):
+        l = [steps]
+        class CollectSteps(Visitor):
+            def scheduled(self, tree):
+                tree.children[-1] = float(tree.children[-1])
+                if tree.children[-1] < 1:
+                    tree.children[-1] *= steps
+                tree.children[-1] = min(steps, int(tree.children[-1]))
+                l.append(tree.children[-1])
+        CollectSteps().visit(tree)
+        return sorted(set(l))
+
+    def at_step(step, tree):
+        class AtStep(Transformer):
+            def scheduled(self, args):
+                if len(args) == 2:
+                    before, after, when = (), *args
+                else:
+                    before, after, when = args
+                yield before if step <= when else after
+            def start(self, args):
+                def flatten(x):
+                    if type(x) == str:
+                        yield x
+                    else:
+                        for gen in x:
+                            yield from flatten(gen)
+                return ''.join(flatten(args[0]))
+            def plain(self, args):
+                yield args[0].value
+            def __default__(self, data, children, meta):
+                for child in children:
+                    yield from child
+        return AtStep().transform(tree)
+
+    @functools.cache
+    def get_schedule(prompt):
+        tree = parser.parse(prompt)
+        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
+
+    return [get_schedule(prompt) for prompt in prompts]
 
 
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
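The worked example in the comments above still describes the behavior under the Lark parser: collect_steps gathers every switch point, and at_step renders the prompt text for each one. A quick sanity check, assuming lark is installed:

    schedules = get_learned_conditioning_prompt_schedules(["a [mountain:lake:0.25]"], 100)
    print(schedules)
    # expected, following the comment above: [[[25, 'a mountain'], [100, 'a lake']]]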

modules/ui.py

@@ -386,14 +386,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
         outputs=[seed, dummy_component]
     )
 
 
 def update_token_counter(text, steps):
-    prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+    try:
+        prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+    except Exception:
+        # a parsing error can happen here during typing, and we don't want to bother the user with
+        # messages related to it in console
+        prompt_schedules = [[[steps, text]]]
+
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
-    prompts = [prompt_text for step,prompt_text in flat_prompts]
+    prompts = [prompt_text for step, prompt_text in flat_prompts]
     tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
     style_class = ' class="red"' if (token_count > max_length) else ""
     return f"<span {style_class}>{token_count}/{max_length}</span>"
 
 
 def create_toprow(is_img2img):
     id_part = "img2img" if is_img2img else "txt2img"
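The fallback value in the except branch is deliberately shaped like the parser's output: one schedule per prompt, each schedule a list of [end_at_step, prompt_text] pairs, so the reduce/tokenize code below it runs unchanged whether or not parsing succeeded. For instance, with a half-typed prompt:

    steps, text = 20, "a [mountain:lake:"     # mid-edit text that fails to parse
    prompt_schedules = [[[steps, text]]]      # one prompt, one schedule covering all steps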

requirements.txt

@@ -22,3 +22,4 @@ clean-fid
 resize-right
 torchdiffeq
 kornia
+lark

requirements_versions.txt

@@ -21,3 +21,4 @@ clean-fid==0.1.29
 resize-right==0.0.2
 torchdiffeq==0.2.3
 kornia==0.6.7
+lark==1.1.2