diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 84e7e350..68c8f26d 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -256,6 +256,9 @@ def stack_conds(conds): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency. + from modules import images + assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -298,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log last_saved_file = "" last_saved_image = "" + forced_filename = "" ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -345,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + forced_filename = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_image = os.path.join(images_dir, forced_filename) optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -381,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - image.save(last_saved_image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step