Merge pull request #6017 from hitomi/master
Add memory cache for VAE weights
This commit is contained in:
commit
3d8256e40c
2 changed files with 26 additions and 6 deletions
|
@ -1,5 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
import collections
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from modules import shared, devices, script_callbacks
|
from modules import shared, devices, script_callbacks
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
@ -30,6 +31,7 @@ base_vae = None
|
||||||
loaded_vae_file = None
|
loaded_vae_file = None
|
||||||
checkpoint_info = None
|
checkpoint_info = None
|
||||||
|
|
||||||
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||||
|
@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
|
||||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
|
||||||
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
|
|
||||||
if vae_file:
|
if vae_file:
|
||||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
if cache_enabled and vae_file in checkpoints_loaded:
|
||||||
print(f"Loading VAE weights from: {vae_file}")
|
# use vae checkpoint cache
|
||||||
store_base_vae(model)
|
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
store_base_vae(model)
|
||||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||||
_load_vae_dict(model, vae_dict_1)
|
else:
|
||||||
|
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||||
|
print(f"Loading VAE weights from: {vae_file}")
|
||||||
|
store_base_vae(model)
|
||||||
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||||
|
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||||
|
_load_vae_dict(model, vae_dict_1)
|
||||||
|
|
||||||
|
if cache_enabled:
|
||||||
|
# cache newly loaded vae
|
||||||
|
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||||
|
|
||||||
|
# clean up cache if limit is reached
|
||||||
|
if cache_enabled:
|
||||||
|
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||||
|
checkpoints_loaded.popitem(last=False) # LRU
|
||||||
|
|
||||||
# If vae used is not in dict, update it
|
# If vae used is not in dict, update it
|
||||||
# It will be removed on refresh though
|
# It will be removed on refresh though
|
||||||
|
|
|
@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
|
|
Loading…
Reference in a new issue