diff --git a/javascript/ui.js b/javascript/ui.js index bfe02410..f94ed081 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -199,12 +199,18 @@ let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; -function submit_prompt(event, generate_button_id) { - if (event.altKey && event.keyCode === 13) { - event.preventDefault(); - gradioApp().getElementById(generate_button_id).click(); - return; - } +function update_txt2img_tokens(...args) { + update_token_counter("txt2img_token_button") + if (args.length == 2) + return args[0] + return args; +} + +function update_img2img_tokens(...args) { + update_token_counter("img2img_token_button") + if (args.length == 2) + return args[0] + return args; } function update_token_counter(button_id) { @@ -212,3 +218,10 @@ function update_token_counter(button_id) { clearTimeout(token_timeout); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); } +function submit_prompt(event, generate_button_id) { + if (event.altKey && event.keyCode === 13) { + event.preventDefault(); + gradioApp().getElementById(generate_button_id).click(); + return; + } +} \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index eca50df0..c8f5bb84 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,6 +11,7 @@ import time import traceback import platform import subprocess as sp +from functools import reduce import numpy as np import torch @@ -33,6 +34,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +from modules.prompt_parser import get_learned_conditioning_prompt_schedules from modules.images import apply_filename_pattern, get_next_sequence_number import modules.textual_inversion.ui @@ -384,8 +386,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) -def update_token_counter(text): - tokens, token_count, max_length = model_hijack.tokenize(text) +def update_token_counter(text, steps): + prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step,prompt_text in flat_prompts] + tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" @@ -403,8 +408,7 @@ def create_toprow(is_img2img): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter]) + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -435,7 +439,7 @@ def create_toprow(is_img2img): prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") save_style = gr.Button('Create style', elem_id="style_create") - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -464,7 +468,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.txt2img with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) with gr.Row(elem_id='txt2img_progress_row'): @@ -584,6 +588,7 @@ def create_ui(wrap_gradio_gpu_call): roll.click( fn=roll_artist, + _js="update_txt2img_tokens", inputs=[ txt2img_prompt, ], @@ -612,9 +617,10 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): with gr.Column(scale=1): @@ -788,6 +794,7 @@ def create_ui(wrap_gradio_gpu_call): roll.click( fn=roll_artist, + _js="update_img2img_tokens", inputs=[ img2img_prompt, ], @@ -798,6 +805,7 @@ def create_ui(wrap_gradio_gpu_call): prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): button.click( @@ -809,9 +817,10 @@ def create_ui(wrap_gradio_gpu_call): outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], ) - for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns): + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): button.click( fn=apply_styles, + _js=js_func, inputs=[prompt, negative_prompt, style1, style2], outputs=[prompt, negative_prompt, style1, style2], ) @@ -834,6 +843,7 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), ] modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False):