bring back short hashes to sd checkpoint selection
This commit is contained in:
parent
d1ea518dea
commit
c1928cdd61
2 changed files with 23 additions and 15 deletions
|
@ -41,14 +41,16 @@ class CheckpointInfo:
|
|||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
self.title = name
|
||||
self.name = name
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
|
@ -56,13 +58,15 @@ class CheckpointInfo:
|
|||
checkpoint_alisases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||
self.shorthash = self.sha256[0:10]
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256]
|
||||
self.register()
|
||||
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
|
||||
return self.shorthash
|
||||
|
||||
|
||||
|
@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||
title = checkpoint_info.title
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
if checkpoint_info.title != title:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
|
|
|
@ -439,7 +439,7 @@ def apply_setting(key, value):
|
|||
opts.data_labels[key].onchange()
|
||||
|
||||
opts.save(shared.config_filename)
|
||||
return value
|
||||
return getattr(opts, key)
|
||||
|
||||
|
||||
def update_generation_info(generation_info, html_info, img_index):
|
||||
|
@ -597,6 +597,16 @@ def ordered_ui_categories():
|
|||
yield category
|
||||
|
||||
|
||||
def get_value_for_setting(key):
|
||||
value = getattr(opts, key)
|
||||
|
||||
info = opts.data_labels[key]
|
||||
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
||||
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
||||
|
||||
return gr.update(value=value, **args)
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
@ -1600,7 +1610,7 @@ def create_ui():
|
|||
|
||||
opts.save(shared.config_filename)
|
||||
|
||||
return gr.update(value=value), opts.dumpjson()
|
||||
return get_value_for_setting(key), opts.dumpjson()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
||||
with gr.Row():
|
||||
|
@ -1771,15 +1781,6 @@ def create_ui():
|
|||
|
||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||
|
||||
def get_value_for_setting(key):
|
||||
value = getattr(opts, key)
|
||||
|
||||
info = opts.data_labels[key]
|
||||
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
||||
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
||||
|
||||
return gr.update(value=value, **args)
|
||||
|
||||
def get_settings_values():
|
||||
return [get_value_for_setting(key) for key in component_keys]
|
||||
|
||||
|
|
Loading…
Reference in a new issue