add option to read generation params for learning previews from txt2img

This commit is contained in:
AUTOMATIC 2022-10-14 20:31:49 +03:00
parent bb295f5478
commit c344ba3b32
3 changed files with 51 additions and 15 deletions

View file

@ -180,7 +180,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out) return self.to_out(out)
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected' assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
@ -265,20 +265,31 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad() optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
prompt=preview_text,
steps=20,
do_not_save_grid=True, do_not_save_grid=True,
do_not_save_samples=True, do_not_save_samples=True,
) )
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entry.cond_text
p.steps = 20
preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None image = processed.images[0] if len(processed.images)>0 else None

View file

@ -172,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
@ -259,18 +259,29 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
prompt=preview_text,
steps=20,
height=training_height,
width=training_width,
do_not_save_grid=True, do_not_save_grid=True,
do_not_save_samples=True, do_not_save_samples=True,
) )
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entry.cond_text
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0]

View file

@ -711,6 +711,18 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_width, "First pass size-1"), (firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"), (firstphase_height, "First pass size-2"),
] ]
txt2img_preview_params = [
txt2img_prompt,
txt2img_negative_prompt,
steps,
sampler_index,
cfg_scale,
seed,
width,
height,
]
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
@ -1162,7 +1174,7 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_image_prompt = gr.Textbox(label='Preview prompt', value="") preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row(): with gr.Row():
interrupt_training = gr.Button(value="Interrupt") interrupt_training = gr.Button(value="Interrupt")
@ -1240,7 +1252,8 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every, save_embedding_every,
template_file, template_file,
save_image_with_stored_embedding, save_image_with_stored_embedding,
preview_image_prompt, preview_from_txt2img,
*txt2img_preview_params,
], ],
outputs=[ outputs=[
ti_output, ti_output,
@ -1260,7 +1273,8 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
preview_image_prompt, preview_from_txt2img,
*txt2img_preview_params,
], ],
outputs=[ outputs=[
ti_output, ti_output,