feat(api): add override_settings_restore_afterwards

This commit is contained in:
Philpax 2022-12-20 20:36:49 +11:00
parent 685f9631b5
commit 22f1527fa7

View file

@ -77,7 +77,7 @@ class StableDiffusionProcessing():
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None: if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@ -118,6 +118,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False self.is_using_inpainting_conditioning = False
if not seed_enable_extras: if not seed_enable_extras:
@ -147,11 +148,11 @@ class StableDiffusionProcessing():
# The "masked-image" in this case will just be all zeros since the entire image is masked. # The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension. # Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype) image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning return image_conditioning
@ -199,7 +200,7 @@ class StableDiffusionProcessing():
source_image * (1.0 - conditioning_mask), source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
) )
# Encode the new masked image using first stage of network. # Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
@ -463,12 +464,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
res = process_images_inner(p) res = process_images_inner(p)
finally: # restore opts to original state finally:
for k, v in stored_opts.items(): # restore opts to original state
setattr(opts, k, v) if p.override_settings_restore_afterwards:
if k == 'sd_hypernetwork': shared.reload_hypernetworks() for k, v in stored_opts.items():
if k == 'sd_model_checkpoint': sd_models.reload_model_weights() setattr(opts, k, v)
if k == 'sd_vae': sd_vae.reload_vae_weights() if k == 'sd_hypernetwork': shared.reload_hypernetworks()
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
if k == 'sd_vae': sd_vae.reload_vae_weights()
return res return res
@ -537,7 +540,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
for n in range(p.n_iter): for n in range(p.n_iter):
if state.skipped: if state.skipped:
state.skipped = False state.skipped = False
if state.interrupted: if state.interrupted:
break break
@ -612,7 +615,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
del x_samples_ddim del x_samples_ddim
devices.torch_gc() devices.torch_gc()
@ -720,7 +723,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
# Avoid making the inpainting conditioning unless necessary as # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again. # this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0: if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples) image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)