diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 919d5d31..f70872c4 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -275,7 +275,7 @@ re_attention = re.compile(r""" def parse_prompt_attention(text): """ - Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight. + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. Accepted tokens are: (abc) - increases attention to abc by a multiplier of 1.1 (abc:3.12) - increases attention to abc by a multiplier of 3.12 diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a3345bb9..98123fbf 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) def einsum_op(q, k, v): if q.device.type == 'cuda': diff --git a/modules/sd_models.py b/modules/sd_models.py index 7ad6d474..eae22e87 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -148,7 +148,10 @@ def get_state_dict_from_checkpoint(pl_sd): if new_key is not None: sd[new_key] = v - return sd + pl_sd.clear() + pl_sd.update(sd) + + return pl_sd def load_model_weights(model, checkpoint_info): diff --git a/modules/ui.py b/modules/ui.py index 18a2add0..d2e24880 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,7 +12,7 @@ import time import traceback import platform import subprocess as sp -from functools import reduce +from functools import partial, reduce import numpy as np import torch @@ -261,6 +261,19 @@ def wrap_gradio_call(func, extra_outputs=None): return f +def calc_time_left(progress, threshold, label, force_display): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and progress > 0.02) or force_display: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + else: + return "" + + def check_progress_call(id_part): if shared.state.job_count == 0: return "", gr_show(False), gr_show(False), gr_show(False) @@ -272,11 +285,15 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display ) + if time_left != "": + shared.state.time_left_force_display = True + progress = min(progress, 1) progressbar = "" if opts.show_progressbar: - progressbar = f"""