do not unnecessarily run VAE one more time when saving intermediate image with hires fix
This commit is contained in:
parent
9c67408004
commit
eb5e82c7dd
4 changed files with 23 additions and 22 deletions
|
@ -199,7 +199,7 @@ class StableDiffusionProcessing():
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -521,11 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
# Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||||
if isinstance(p, StableDiffusionProcessingTxt2Img):
|
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n)
|
|
||||||
else:
|
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
|
||||||
|
|
||||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||||
|
@ -653,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
|
@ -666,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||||
|
|
||||||
|
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
|
||||||
|
def save_intermediate(image, index):
|
||||||
|
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
image = sd_samplers.sample_to_image(image, index)
|
||||||
|
|
||||||
|
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
||||||
|
|
||||||
if opts.use_scale_latent_for_hires_fix:
|
if opts.use_scale_latent_for_hires_fix:
|
||||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||||
|
|
||||||
|
for i in range(samples.shape[0]):
|
||||||
|
save_intermediate(samples, i)
|
||||||
else:
|
else:
|
||||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
@ -678,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
|
||||||
|
save_intermediate(image, i)
|
||||||
|
|
||||||
image = images.resize_image(0, image, self.width, self.height)
|
image = images.resize_image(0, image, self.width, self.height)
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = np.moveaxis(image, 2, 0)
|
image = np.moveaxis(image, 2, 0)
|
||||||
|
@ -689,15 +700,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
||||||
|
|
||||||
# Save a copy of the image/s before doing highres fix, if applicable.
|
|
||||||
if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix:
|
|
||||||
for i in range(self.batch_size):
|
|
||||||
# This batch's ith image.
|
|
||||||
img = sd_samplers.sample_to_image(samples, i)
|
|
||||||
# Index that accounts for both batch size and batch count.
|
|
||||||
ind = i + self.batch_size*n
|
|
||||||
images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix")
|
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||||
|
@ -844,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
|
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
|
||||||
|
|
||||||
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||||
|
@ -856,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
del x
|
del x
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
|
@ -96,6 +96,7 @@ def single_sample_to_image(sample):
|
||||||
def sample_to_image(samples, index=0):
|
def sample_to_image(samples, index=0):
|
||||||
return single_sample_to_image(samples[index])
|
return single_sample_to_image(samples[index])
|
||||||
|
|
||||||
|
|
||||||
def samples_to_image_grid(samples):
|
def samples_to_image_grid(samples):
|
||||||
return images.image_grid([single_sample_to_image(sample) for sample in samples])
|
return images.image_grid([single_sample_to_image(sample) for sample in samples])
|
||||||
|
|
||||||
|
|
|
@ -256,6 +256,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||||
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||||
|
|
||||||
|
@ -322,7 +323,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
|
|
|
@ -166,8 +166,7 @@ class Script(scripts.Script):
|
||||||
if override_strength:
|
if override_strength:
|
||||||
p.denoising_strength = 1.0
|
p.denoising_strength = 1.0
|
||||||
|
|
||||||
|
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
|
||||||
lat = (p.init_latent.cpu().numpy() * 10).astype(int)
|
lat = (p.init_latent.cpu().numpy() * 10).astype(int)
|
||||||
|
|
||||||
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
|
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
|
||||||
|
|
Loading…
Reference in a new issue