Merge pull request #1324 from liamkerr/token_updates
Fixing Bugs with Token Counter
This commit is contained in:
commit
0b94fc5033
2 changed files with 37 additions and 14 deletions
|
@ -199,12 +199,18 @@ let txt2img_textarea, img2img_textarea = undefined;
|
||||||
let wait_time = 800
|
let wait_time = 800
|
||||||
let token_timeout;
|
let token_timeout;
|
||||||
|
|
||||||
function submit_prompt(event, generate_button_id) {
|
function update_txt2img_tokens(...args) {
|
||||||
if (event.altKey && event.keyCode === 13) {
|
update_token_counter("txt2img_token_button")
|
||||||
event.preventDefault();
|
if (args.length == 2)
|
||||||
gradioApp().getElementById(generate_button_id).click();
|
return args[0]
|
||||||
return;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function update_img2img_tokens(...args) {
|
||||||
|
update_token_counter("img2img_token_button")
|
||||||
|
if (args.length == 2)
|
||||||
|
return args[0]
|
||||||
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
function update_token_counter(button_id) {
|
||||||
|
@ -212,3 +218,10 @@ function update_token_counter(button_id) {
|
||||||
clearTimeout(token_timeout);
|
clearTimeout(token_timeout);
|
||||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
}
|
}
|
||||||
|
function submit_prompt(event, generate_button_id) {
|
||||||
|
if (event.altKey && event.keyCode === 13) {
|
||||||
|
event.preventDefault();
|
||||||
|
gradioApp().getElementById(generate_button_id).click();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
import platform
|
import platform
|
||||||
import subprocess as sp
|
import subprocess as sp
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -33,6 +34,7 @@ import modules.gfpgan_model
|
||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.generation_parameters_copypaste
|
import modules.generation_parameters_copypaste
|
||||||
|
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
|
||||||
from modules.images import apply_filename_pattern, get_next_sequence_number
|
from modules.images import apply_filename_pattern, get_next_sequence_number
|
||||||
import modules.textual_inversion.ui
|
import modules.textual_inversion.ui
|
||||||
|
|
||||||
|
@ -384,8 +386,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||||
outputs=[seed, dummy_component]
|
outputs=[seed, dummy_component]
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_token_counter(text):
|
def update_token_counter(text, steps):
|
||||||
tokens, token_count, max_length = model_hijack.tokenize(text)
|
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
|
||||||
|
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
||||||
|
prompts = [prompt_text for step,prompt_text in flat_prompts]
|
||||||
|
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
|
||||||
style_class = ' class="red"' if (token_count > max_length) else ""
|
style_class = ' class="red"' if (token_count > max_length) else ""
|
||||||
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
||||||
|
|
||||||
|
@ -403,8 +408,7 @@ def create_toprow(is_img2img):
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
|
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||||
|
@ -435,7 +439,7 @@ def create_toprow(is_img2img):
|
||||||
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
||||||
save_style = gr.Button('Create style', elem_id="style_create")
|
save_style = gr.Button('Create style', elem_id="style_create")
|
||||||
|
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||||
|
@ -464,7 +468,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
with gr.Row(elem_id='txt2img_progress_row'):
|
with gr.Row(elem_id='txt2img_progress_row'):
|
||||||
|
@ -584,6 +588,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
|
|
||||||
roll.click(
|
roll.click(
|
||||||
fn=roll_artist,
|
fn=roll_artist,
|
||||||
|
_js="update_txt2img_tokens",
|
||||||
inputs=[
|
inputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
],
|
],
|
||||||
|
@ -612,9 +617,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||||
]
|
]
|
||||||
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
||||||
|
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True)
|
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||||
|
|
||||||
with gr.Row(elem_id='img2img_progress_row'):
|
with gr.Row(elem_id='img2img_progress_row'):
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
|
@ -788,6 +794,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
|
|
||||||
roll.click(
|
roll.click(
|
||||||
fn=roll_artist,
|
fn=roll_artist,
|
||||||
|
_js="update_img2img_tokens",
|
||||||
inputs=[
|
inputs=[
|
||||||
img2img_prompt,
|
img2img_prompt,
|
||||||
],
|
],
|
||||||
|
@ -798,6 +805,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
|
|
||||||
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
||||||
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
||||||
|
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
|
||||||
|
|
||||||
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
||||||
button.click(
|
button.click(
|
||||||
|
@ -809,9 +817,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
||||||
)
|
)
|
||||||
|
|
||||||
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
|
for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
|
||||||
button.click(
|
button.click(
|
||||||
fn=apply_styles,
|
fn=apply_styles,
|
||||||
|
_js=js_func,
|
||||||
inputs=[prompt, negative_prompt, style1, style2],
|
inputs=[prompt, negative_prompt, style1, style2],
|
||||||
outputs=[prompt, negative_prompt, style1, style2],
|
outputs=[prompt, negative_prompt, style1, style2],
|
||||||
)
|
)
|
||||||
|
@ -834,6 +843,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
]
|
]
|
||||||
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
||||||
|
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
|
|
Loading…
Reference in a new issue