add option to read generation params for learning previews from txt2img
This commit is contained in:
parent
bb295f5478
commit
c344ba3b32
3 changed files with 51 additions and 15 deletions
|
@ -180,7 +180,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
|
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
|
@ -265,20 +265,31 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||||
|
|
||||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
prompt=preview_text,
|
|
||||||
steps=20,
|
|
||||||
do_not_save_grid=True,
|
do_not_save_grid=True,
|
||||||
do_not_save_samples=True,
|
do_not_save_samples=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_index = preview_sampler_index
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = entry.cond_text
|
||||||
|
p.steps = 20
|
||||||
|
|
||||||
|
preview_text = p.prompt
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0] if len(processed.images)>0 else None
|
image = processed.images[0] if len(processed.images)>0 else None
|
||||||
|
|
||||||
|
|
|
@ -172,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
@ -259,18 +259,29 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
prompt=preview_text,
|
|
||||||
steps=20,
|
|
||||||
height=training_height,
|
|
||||||
width=training_width,
|
|
||||||
do_not_save_grid=True,
|
do_not_save_grid=True,
|
||||||
do_not_save_samples=True,
|
do_not_save_samples=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_index = preview_sampler_index
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = entry.cond_text
|
||||||
|
p.steps = 20
|
||||||
|
p.width = training_width
|
||||||
|
p.height = training_height
|
||||||
|
|
||||||
|
preview_text = p.prompt
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0]
|
image = processed.images[0]
|
||||||
|
|
||||||
|
|
|
@ -711,6 +711,18 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
(firstphase_width, "First pass size-1"),
|
(firstphase_width, "First pass size-1"),
|
||||||
(firstphase_height, "First pass size-2"),
|
(firstphase_height, "First pass size-2"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
txt2img_preview_params = [
|
||||||
|
txt2img_prompt,
|
||||||
|
txt2img_negative_prompt,
|
||||||
|
steps,
|
||||||
|
sampler_index,
|
||||||
|
cfg_scale,
|
||||||
|
seed,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
]
|
||||||
|
|
||||||
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
|
@ -1162,7 +1174,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
interrupt_training = gr.Button(value="Interrupt")
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
@ -1240,7 +1252,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
save_image_with_stored_embedding,
|
save_image_with_stored_embedding,
|
||||||
preview_image_prompt,
|
preview_from_txt2img,
|
||||||
|
*txt2img_preview_params,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
ti_output,
|
ti_output,
|
||||||
|
@ -1260,7 +1273,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
preview_image_prompt,
|
preview_from_txt2img,
|
||||||
|
*txt2img_preview_params,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
ti_output,
|
ti_output,
|
||||||
|
|
Loading…
Reference in a new issue