remove the need to place configs near models
Parent 7a14c8ab45 · Commit d2ac95fa7b
10 changed files with 361 additions and 152 deletions
configs/instruct-pix2pix.yaml (new file, 99 lines)
@@ -0,0 +1,99 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.

model:
  base_learning_rate: 1.0e-04
  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    # image_size: 64
    # image_size: 32
    image_size: 16
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: true
    load_ema: true

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 128
    num_workers: 1
    wrap: false
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        cache_dir: data/
        cache_name: data_10k
        split: val
        min_text_sim: 0.2
        min_image_sim: 0.75
        min_direction_sim: 0.2
        max_samples_per_prompt: 1
        min_resize_res: 512
        max_resize_res: 512
        crop_res: 512
        output_as_edit: False
        real_input: True
@@ -1,8 +1,7 @@
 model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
   params:
-    parameterization: "v"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -12,29 +11,36 @@ model:
     cond_stage_key: "txt"
     image_size: 64
     channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
 
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
       params:
-        use_checkpoint: True
-        use_fp16: True
         image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
         out_channels: 4
         model_channels: 320
         attention_resolutions: [ 4, 2, 1 ]
         num_res_blocks: 2
         channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
        use_spatial_transformer: True
-        use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
         legacy: False
 
     first_stage_config:
@@ -43,7 +49,6 @@ model:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
-          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
@@ -62,7 +67,4 @@ model:
          target: torch.nn.Identity
 
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
@@ -387,7 +388,7 @@ class Api:
        ]
 
    def get_sd_models(self):
-        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
 
    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -34,14 +34,18 @@ def get_cuda_device_string():
    return "cuda"
 
 
-def get_optimal_device():
+def get_optimal_device_name():
    if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()
 
    if has_mps():
-        return torch.device("mps")
+        return "mps"
 
-    return cpu
+    return "cpu"
+
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())
 
 
 def get_device_for(task):
@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
    return x_prev, pred_x0, e_t
 
 
-def should_hijack_inpainting(checkpoint_info):
-    from modules import sd_models
-
-    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
-    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
-    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
-
-
 def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
@@ -2,8 +2,6 @@ import collections
 import os.path
 import sys
 import gc
-import time
-from collections import namedtuple
 import torch
 import re
 import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles():
    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
 
 
-def find_checkpoint_config(info):
-    if info is None:
-        return shared.cmd_opts.config
-
-    config = os.path.splitext(info.filename)[0] + ".yaml"
-    if os.path.exists(config):
-        return config
-
-    return shared.cmd_opts.config
-
-
 def list_models():
    checkpoints_list.clear()
    checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location
-        if device is None:
-            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
    return sd
 
 
-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    if checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        return checkpoints_loaded[checkpoint_info]
+
+    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    res = read_state_dict(checkpoint_info.filename)
+    timer.record("load weights from disk")
+
+    return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    title = checkpoint_info.title
    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
 
    if checkpoint_info.title != title:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
-    cache_enabled = shared.opts.sd_checkpoint_cache > 0
-
-    if cache_enabled and checkpoint_info in checkpoints_loaded:
-        # use checkpoint cache
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-    else:
-        # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
-
-        sd = read_state_dict(checkpoint_info.filename)
-        model.load_state_dict(sd, strict=False)
-        del sd
-
-        if cache_enabled:
-            # cache newly loaded model
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+    if state_dict is None:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    model.load_state_dict(state_dict, strict=False)
+    del state_dict
+    timer.record("apply weights to model")
+
+    if shared.opts.sd_checkpoint_cache > 0:
+        # cache newly loaded model
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
 
    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
+        timer.record("apply channels_last")
 
    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)
 
        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None
 
        model.half()
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model
 
-    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
-    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
-    devices.dtype_unet = model.model.diffusion_model.dtype
-    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+        timer.record("apply half()")
 
-    model.first_stage_model.to(devices.dtype_vae)
+    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+    devices.dtype_unet = model.model.diffusion_model.dtype
+    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+
+    model.first_stage_model.to(devices.dtype_vae)
+    timer.record("apply dtype to VAE")
 
    # clean up cache if limit is reached
-    if cache_enabled:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
-            checkpoints_loaded.popitem(last=False)  # LRU
+    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        checkpoints_loaded.popitem(last=False)
 
    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
+    timer.record("load VAE")
 
 
 def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
    midas.api.load_model = load_model_wrapper
 
 
-class Timer:
-    def __init__(self):
-        self.start = time.time()
-
-    def elapsed(self):
-        end = time.time()
-        res = end - self.start
-        self.start = end
-        return res
+def repair_config(sd_config):
+
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
+
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True
 
 
-def load_model(checkpoint_info=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
-    checkpoint_config = find_checkpoint_config(checkpoint_info)
-
-    if checkpoint_config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_config}")
 
    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
    gc.collect()
    devices.torch_gc()
 
-    sd_config = OmegaConf.load(checkpoint_config)
-
-    if should_hijack_inpainting(checkpoint_info):
-        # Hardcoded config for now...
-        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.unet_config.params.in_channels = 9
-        sd_config.model.params.finetune_keys = None
-
-    if should_hijack_ip2p(checkpoint_info):
-        sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.first_stage_key = "edited"
-        sd_config.model.params.cond_stage_key = "edit"
-        sd_config.model.params.image_size = 16
-        sd_config.model.params.unet_config.params.in_channels = 8
-        sd_config.model.params.unet_config.params.out_channels = 4
-
-    if not hasattr(sd_config.model.params, "use_ema"):
-        sd_config.model.params.use_ema = False
-
    do_inpainting_hijack()
 
-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
-
    timer = Timer()
 
-    sd_model = None
+    if already_loaded_state_dict is not None:
+        state_dict = already_loaded_state_dict
+    else:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    sd_config = OmegaConf.load(checkpoint_config)
+    repair_config(sd_config)
+
+    timer.record("load config")
+
+    print(f"Creating model from config: {checkpoint_config}")
+
+    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization():
            sd_model = instantiate_from_config(sd_config.model)
@@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
        sd_model = instantiate_from_config(sd_config.model)
 
-    elapsed_create = timer.elapsed()
+    sd_model.used_config = checkpoint_config
 
-    load_model_weights(sd_model, checkpoint_info)
+    timer.record("create model")
 
-    elapsed_load_weights = timer.elapsed()
+    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)
 
+    timer.record("move model to device")
+
    sd_hijack.model_hijack.hijack(sd_model)
 
+    timer.record("hijack")
+
    sd_model.eval()
    shared.sd_model = sd_model
 
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
+    timer.record("load textual inversion embeddings")
+
    script_callbacks.model_loaded_callback(sd_model)
 
-    elapsed_the_rest = timer.elapsed()
+    timer.record("scripts callbacks")
 
-    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+    print(f"Model loaded in {timer.summary()}.")
 
    return sd_model
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
 
    if not sd_model:
        sd_model = shared.sd_model
+
    if sd_model is None: # previous model load failed
        current_checkpoint_info = None
    else:
@@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return
 
-    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
-    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
-        del sd_model
-        checkpoints_loaded.clear()
-        load_model(checkpoint_info)
-        return shared.sd_model
-
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
@@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
 
    timer = Timer()
 
+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    if sd_model is None or checkpoint_config != sd_model.used_config:
+        del sd_model
+        checkpoints_loaded.clear()
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+        return shared.sd_model
+
    try:
-        load_model_weights(sd_model, checkpoint_info)
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception as e:
        print("Failed to load checkpoint, restoring previous")
-        load_model_weights(sd_model, current_checkpoint_info)
+        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
+        timer.record("hijack")
 
        script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")
 
    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)
+        timer.record("move model to device")
 
-    elapsed = timer.elapsed()
-
-    print(f"Weights loaded in {elapsed:.1f}s.")
+    print(f"Weights loaded in {timer.summary()}.")
 
    return sd_model
modules/sd_models_config.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import re
import os

from modules import shared, paths

sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")


config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")

re_parametrization_v = re.compile(r'-v\b')


def guess_model_config_from_state_dict(sd, filename):
    fn = os.path.basename(filename)

    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)

    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
        if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
            return config_sd2v
        else:
            return config_sd2

    if diffusion_model_input is not None:
        if diffusion_model_input.shape[1] == 9:
            return config_inpainting
        if diffusion_model_input.shape[1] == 8:
            return config_instruct_pix2pix

    if roberta_weight is not None:
        return config_alt_diffusion

    return config_default


def find_checkpoint_config(state_dict, info):
    if info is None:
        return guess_model_config_from_state_dict(state_dict, "")

    config = find_checkpoint_config_near_filename(info)
    if config is not None:
        return config

    return guess_model_config_from_state_dict(state_dict, info.filename)


def find_checkpoint_config_near_filename(info):
    if info is None:
        return None

    config = os.path.splitext(info.filename)[0] + ".yaml"
    if os.path.exists(config):
        return config

    return None
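As a reading aid, here is a minimal, hypothetical sketch (not part of the commit) of how sd_models.py consumes this new module, per the load_model hunk above: the state dict is read first, and the config is then resolved from a neighbouring .yaml or guessed from tensor shapes. The helper name resolve_config_for is an assumption made for illustration; it assumes the webui root is on sys.path.

import sys
sys.path.insert(0, ".")  # assumption: run from the webui root

from modules import sd_models_config
from modules.timer import Timer


def resolve_config_for(checkpoint_info, state_dict):
    # hypothetical helper, not part of the commit
    timer = Timer()
    # Prefers a .yaml sitting next to the checkpoint file; otherwise guesses
    # SD2 / inpainting / instruct-pix2pix / alt-diffusion from state-dict shapes.
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    timer.record("find config")
    print(f"Config resolved in {timer.summary()}")
    return checkpoint_config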
@@ -13,13 +13,14 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules import localization, extensions, script_loading, errors, ui_components, shared_items
 from modules.paths import models_path, script_path
 
 
 demo = None
 
-sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
@@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
    "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
@@ -4,7 +4,20 @@ def realesrgan_models_names():
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
 
 
 def postprocessing_scripts():
    import modules.scripts
 
    return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+    import modules.sd_vae
+
+    return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+    import modules.sd_vae
+
+    return modules.sd_vae.refresh_vae_list
modules/timer.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import time


class Timer:
    def __init__(self):
        self.start = time.time()
        self.records = {}
        self.total = 0

    def elapsed(self):
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def record(self, category, extra_time=0):
        e = self.elapsed()
        if category not in self.records:
            self.records[category] = 0

        self.records[category] += e + extra_time
        self.total += e + extra_time

    def summary(self):
        res = f"{self.total:.1f}s"

        additions = [x for x in self.records.items() if x[1] >= 0.1]
        if not additions:
            return res

        res += " ("
        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
        res += ")"

        return res
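For context, a small usage sketch (not part of the commit) of the timing pattern the sd_models.py hunks adopt with this class; the category names mirror the diff, and the sleeps merely stand in for real work. It assumes the webui's modules package is importable.

import time

from modules.timer import Timer  # assumption: webui root is on sys.path

timer = Timer()
time.sleep(0.2)                         # stand-in for reading weights from disk
timer.record("load weights from disk")
time.sleep(0.1)                         # stand-in for applying them to the model
timer.record("apply weights to model")
print(f"Model loaded in {timer.summary()}.")
# prints something like:
# Model loaded in 0.3s (load weights from disk: 0.2s, apply weights to model: 0.1s).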