bring back short hashes to sd checkpoint selection

This commit is contained in:
AUTOMATIC 2023-01-19 18:58:08 +03:00
parent d1ea518dea
commit c1928cdd61
2 changed files with 23 additions and 15 deletions

View file

@ -41,14 +41,16 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"): if name.startswith("\\") or name.startswith("/"):
name = name[1:] name = name[1:]
self.title = name self.name = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename) self.hash = model_hash(filename)
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.shorthash = self.sha256[0:10] if self.sha256 else None self.shorthash = self.sha256[0:10] if self.sha256 else None
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self): def register(self):
checkpoints_list[self.title] = self checkpoints_list[self.title] = self
@ -56,13 +58,15 @@ class CheckpointInfo:
checkpoint_alisases[id] = self checkpoint_alisases[id] = self
def calculate_shorthash(self): def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.shorthash = self.sha256[0:10] self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids: if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256] self.ids += [self.shorthash, self.sha256]
self.register() self.register()
self.title = f'{self.name} [{self.shorthash}]'
return self.shorthash return self.shorthash
@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
def load_model_weights(model, checkpoint_info: CheckpointInfo): def load_model_weights(model, checkpoint_info: CheckpointInfo):
title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash() sd_model_hash = checkpoint_info.calculate_shorthash()
if checkpoint_info.title != title:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
cache_enabled = shared.opts.sd_checkpoint_cache > 0 cache_enabled = shared.opts.sd_checkpoint_cache > 0

View file

@ -439,7 +439,7 @@ def apply_setting(key, value):
opts.data_labels[key].onchange() opts.data_labels[key].onchange()
opts.save(shared.config_filename) opts.save(shared.config_filename)
return value return getattr(opts, key)
def update_generation_info(generation_info, html_info, img_index): def update_generation_info(generation_info, html_info, img_index):
@ -597,6 +597,16 @@ def ordered_ui_categories():
yield category yield category
def get_value_for_setting(key):
value = getattr(opts, key)
info = opts.data_labels[key]
args = info.component_args() if callable(info.component_args) else info.component_args or {}
args = {k: v for k, v in args.items() if k not in {'precision'}}
return gr.update(value=value, **args)
def create_ui(): def create_ui():
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
@ -1600,7 +1610,7 @@ def create_ui():
opts.save(shared.config_filename) opts.save(shared.config_filename)
return gr.update(value=value), opts.dumpjson() return get_value_for_setting(key), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row(): with gr.Row():
@ -1771,15 +1781,6 @@ def create_ui():
component_keys = [k for k in opts.data_labels.keys() if k in component_dict] component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_value_for_setting(key):
value = getattr(opts, key)
info = opts.data_labels[key]
args = info.component_args() if callable(info.component_args) else info.component_args or {}
args = {k: v for k, v in args.items() if k not in {'precision'}}
return gr.update(value=value, **args)
def get_settings_values(): def get_settings_values():
return [get_value_for_setting(key) for key in component_keys] return [get_value_for_setting(key) for key in component_keys]