From 7fd0f3166111c552a2ed4ee1d221583ff5cc1124 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 24 Aug 2022 00:02:43 +0300
Subject: [PATCH] added prompt verification: if it's too long, a warning is
 returned in the text field along with the part of prompt that has been
 truncated

---
 webui.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/webui.py b/webui.py
index c8a62c4d..e1dc2c0f 100644
--- a/webui.py
+++ b/webui.py
@@ -45,6 +45,7 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i
 parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default='./GFPGAN')
+parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
 opt = parser.parse_args()
 
 GFPGAN_dir = opt.gfpgan_dir
@@ -231,6 +232,25 @@ def draw_prompt_matrix(im, width, height, all_prompts):
     return result
 
 
+def check_prompt_length(prompt, comments):
+    """this function tests if prompt is too long, and if so, adds a message to comments"""
+
+    tokenizer = model.cond_stage_model.tokenizer
+    max_length = model.cond_stage_model.max_length
+
+    info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
+    ovf = info['overflowing_tokens'][0]
+    overflowing_count = ovf.shape[0]
+    if overflowing_count == 0:
+        return
+
+    vocab = {v: k for k, v in tokenizer.get_vocab().items()}
+    overflowing_words = [vocab.get(int(x), "") for x in ovf]
+    overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+    comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+
 def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
 
@@ -248,6 +268,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
     base_count = len(os.listdir(sample_path))
     grid_count = len(os.listdir(outpath)) - 1
 
+    comments = []
+
     prompt_matrix_parts = []
     if prompt_matrix:
         all_prompts = []
@@ -267,6 +289,15 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
 
         print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
     else:
+
+        if not opt.no_verify_input:
+            try:
+                check_prompt_length(prompt, comments)
+            except:
+                import traceback
+                print("Error verifying input:", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
         all_prompts = batch_size * n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]
 
@@ -333,6 +364,9 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
 Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
 """.strip()
 
+    for comment in comments:
+        info += "\n\n" + comment
+
     return output_images, seed, info
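
Note: below is a minimal standalone sketch of the check this patch adds, runnable outside the webui. It assumes Hugging Face's slow CLIPTokenizer for "openai/clip-vit-large-patch14" (the tokenizer that Stable Diffusion v1's FrozenCLIPEmbedder wraps) and its 77-token context; in the patch itself, both the tokenizer and max_length come from model.cond_stage_model instead. The example prompt is made up purely to force an overflow.

    # Standalone sketch of check_prompt_length(); see assumptions above.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    max_length = tokenizer.model_max_length  # 77 for the v1 CLIP text encoder

    # A prompt long enough to exceed the context window.
    prompt = "a portrait of " + "a very, very, very elaborate " * 30 + "cat"

    # Same call the patch makes: truncate to max_length and ask the
    # tokenizer to also return the token ids that fell off the end.
    info = tokenizer([prompt], truncation=True, max_length=max_length,
                     return_overflowing_tokens=True, padding="max_length",
                     return_tensors="pt")
    ovf = info['overflowing_tokens'][0]

    if ovf.shape[0] > 0:
        # Invert the vocab to map the overflow ids back to token strings,
        # then render them as readable text for the warning message.
        vocab = {v: k for k, v in tokenizer.get_vocab().items()}
        overflowing_words = [vocab.get(int(x), "") for x in ovf]
        overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
        print(f"Warning: too many input tokens; some ({len(overflowing_words)}) "
              f"have been truncated:\n{overflowing_text}")

The overflow format shown here is the slow tokenizer's; the fast CLIPTokenizerFast reports overflow differently (as extra sequences plus an overflow_to_sample_mapping), so this sketch, like the patch, relies on the slow tokenizer being in use.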