make it possible for StableDiffusionProcessing to accept multiple different negative prompts in a batch
This commit is contained in:
parent
e35d8b493f
commit
617c5b486f
2 changed files with 24 additions and 33 deletions
|
@ -124,6 +124,7 @@ class StableDiffusionProcessing():
|
|||
self.scripts = None
|
||||
self.script_args = None
|
||||
self.all_prompts = None
|
||||
self.all_negative_prompts = None
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
|
||||
|
@ -202,7 +203,7 @@ class StableDiffusionProcessing():
|
|||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
self.images = images_list
|
||||
self.prompt = p.prompt
|
||||
self.negative_prompt = p.negative_prompt
|
||||
|
@ -241,16 +242,18 @@ class Processed:
|
|||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
||||
|
||||
self.all_prompts = all_prompts or [self.prompt]
|
||||
self.all_seeds = all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or [self.subseed]
|
||||
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
||||
self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
|
||||
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
||||
self.infotexts = infotexts or [info]
|
||||
|
||||
def js(self):
|
||||
obj = {
|
||||
"prompt": self.prompt,
|
||||
"prompt": self.all_prompts[0],
|
||||
"all_prompts": self.all_prompts,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"negative_prompt": self.all_negative_prompts[0],
|
||||
"all_negative_prompts": self.all_negative_prompts,
|
||||
"seed": self.seed,
|
||||
"all_seeds": self.all_seeds,
|
||||
"subseed": self.subseed,
|
||||
|
@ -411,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
|
@ -440,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
else:
|
||||
assert p.prompt is not None
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
seed = get_fixed_seed(p.seed)
|
||||
|
@ -453,15 +452,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
modules.sd_hijack.model_hijack.clear_comments()
|
||||
|
||||
comments = {}
|
||||
prompt_tmp = p.prompt
|
||||
negative_prompt_tmp = p.negative_prompt
|
||||
|
||||
shared.prompt_styles.apply_styles(p)
|
||||
|
||||
if type(p.prompt) == list:
|
||||
p.all_prompts = p.prompt
|
||||
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
|
||||
else:
|
||||
p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
||||
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
|
||||
|
||||
if type(p.negative_prompt) == list:
|
||||
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
|
||||
else:
|
||||
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
||||
|
||||
if type(seed) == list:
|
||||
p.all_seeds = seed
|
||||
|
@ -476,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
|
@ -500,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
break
|
||||
|
||||
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
|
@ -510,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
with devices.autocast():
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
|
||||
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
|
@ -596,14 +601,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
p.prompt = prompt_tmp
|
||||
p.negative_prompt = negative_prompt_tmp
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
@ -65,17 +65,6 @@ class StyleDatabase:
|
|||
def apply_negative_styles_to_prompt(self, prompt, styles):
|
||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
||||
|
||||
def apply_styles(self, p: StableDiffusionProcessing) -> None:
|
||||
if isinstance(p.prompt, list):
|
||||
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
|
||||
else:
|
||||
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
|
||||
|
||||
if isinstance(p.negative_prompt, list):
|
||||
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
|
||||
else:
|
||||
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
|
||||
|
||||
def save_styles(self, path: str) -> None:
|
||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
||||
fd, temp_path = tempfile.mkstemp(".csv")
|
||||
|
|
Loading…
Reference in a new issue