Merge pull request #1324 from liamkerr/token_updates

Fixing Bugs with Token Counter
This commit is contained in:
AUTOMATIC1111 2022-10-02 21:20:05 +03:00 committed by GitHub
commit 0b94fc5033
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 14 deletions

View file

@ -199,12 +199,18 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800 let wait_time = 800
let token_timeout; let token_timeout;
function submit_prompt(event, generate_button_id) { function update_txt2img_tokens(...args) {
if (event.altKey && event.keyCode === 13) { update_token_counter("txt2img_token_button")
event.preventDefault(); if (args.length == 2)
gradioApp().getElementById(generate_button_id).click(); return args[0]
return; return args;
} }
function update_img2img_tokens(...args) {
update_token_counter("img2img_token_button")
if (args.length == 2)
return args[0]
return args;
} }
function update_token_counter(button_id) { function update_token_counter(button_id) {
@ -212,3 +218,10 @@ function update_token_counter(button_id) {
clearTimeout(token_timeout); clearTimeout(token_timeout);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}

View file

@ -11,6 +11,7 @@ import time
import traceback import traceback
import platform import platform
import subprocess as sp import subprocess as sp
from functools import reduce
import numpy as np import numpy as np
import torch import torch
@ -33,6 +34,7 @@ import modules.gfpgan_model
import modules.codeformer_model import modules.codeformer_model
import modules.styles import modules.styles
import modules.generation_parameters_copypaste import modules.generation_parameters_copypaste
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
from modules.images import apply_filename_pattern, get_next_sequence_number from modules.images import apply_filename_pattern, get_next_sequence_number
import modules.textual_inversion.ui import modules.textual_inversion.ui
@ -384,8 +386,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text): def update_token_counter(text, steps):
tokens, token_count, max_length = model_hijack.tokenize(text) prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step,prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else "" style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>" return f"<span {style_class}>{token_count}/{max_length}</span>"
@ -403,8 +408,7 @@ def create_toprow(is_img2img):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
with gr.Column(scale=10, elem_id="style_pos_col"): with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@ -435,7 +439,7 @@ def create_toprow(is_img2img):
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create") save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None): def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@ -464,7 +468,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.txt2img import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
with gr.Row(elem_id='txt2img_progress_row'): with gr.Row(elem_id='txt2img_progress_row'):
@ -584,6 +588,7 @@ def create_ui(wrap_gradio_gpu_call):
roll.click( roll.click(
fn=roll_artist, fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
], ],
@ -612,9 +617,10 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True) img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
with gr.Column(scale=1): with gr.Column(scale=1):
@ -788,6 +794,7 @@ def create_ui(wrap_gradio_gpu_call):
roll.click( roll.click(
fn=roll_artist, fn=roll_artist,
_js="update_img2img_tokens",
inputs=[ inputs=[
img2img_prompt, img2img_prompt,
], ],
@ -798,6 +805,7 @@ def create_ui(wrap_gradio_gpu_call):
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click( button.click(
@ -809,9 +817,10 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
) )
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns): for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click( button.click(
fn=apply_styles, fn=apply_styles,
_js=js_func,
inputs=[prompt, negative_prompt, style1, style2], inputs=[prompt, negative_prompt, style1, style2],
outputs=[prompt, negative_prompt, style1, style2], outputs=[prompt, negative_prompt, style1, style2],
) )
@ -834,6 +843,7 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):