added prompt verification: if it's too long, a warning is returned in the text field along with the part of prompt that has been truncated
This commit is contained in:
parent
e996f3c118
commit
7fd0f31661
1 changed files with 34 additions and 0 deletions
34
webui.py
34
webui.py
|
@ -45,6 +45,7 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i
|
||||||
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default='./GFPGAN')
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default='./GFPGAN')
|
||||||
|
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
GFPGAN_dir = opt.gfpgan_dir
|
GFPGAN_dir = opt.gfpgan_dir
|
||||||
|
@ -231,6 +232,25 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def check_prompt_length(prompt, comments):
|
||||||
|
"""this function tests if prompt is too long, and if so, adds a message to comments"""
|
||||||
|
|
||||||
|
tokenizer = model.cond_stage_model.tokenizer
|
||||||
|
max_length = model.cond_stage_model.max_length
|
||||||
|
|
||||||
|
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
|
||||||
|
ovf = info['overflowing_tokens'][0]
|
||||||
|
overflowing_count = ovf.shape[0]
|
||||||
|
if overflowing_count == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
|
||||||
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
|
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
|
||||||
|
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
|
|
||||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
|
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
|
@ -248,6 +268,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||||
base_count = len(os.listdir(sample_path))
|
base_count = len(os.listdir(sample_path))
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
comments = []
|
||||||
|
|
||||||
prompt_matrix_parts = []
|
prompt_matrix_parts = []
|
||||||
if prompt_matrix:
|
if prompt_matrix:
|
||||||
all_prompts = []
|
all_prompts = []
|
||||||
|
@ -267,6 +289,15 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||||
|
|
||||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
if not opt.no_verify_input:
|
||||||
|
try:
|
||||||
|
check_prompt_length(prompt, comments)
|
||||||
|
except:
|
||||||
|
import traceback
|
||||||
|
print("Error verifying input:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
all_prompts = batch_size * n_iter * [prompt]
|
all_prompts = batch_size * n_iter * [prompt]
|
||||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||||
|
|
||||||
|
@ -333,6 +364,9 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
for comment in comments:
|
||||||
|
info += "\n\n" + comment
|
||||||
|
|
||||||
return output_images, seed, info
|
return output_images, seed, info
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue