add support for prompts, negative prompts, and sampler-by-name in text file script

This commit is contained in:
David Vorick 2022-12-13 12:03:16 -05:00
parent 685f9631b5
commit 27c0504bc4
No known key found for this signature in database
GPG key ID: 413D8291FBEBB2CC

View file

@ -9,6 +9,7 @@ import shlex
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import sd_samplers
from modules.processing import Processed, process_images from modules.processing import Processed, process_images
from PIL import Image from PIL import Image
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
@ -44,6 +45,7 @@ prompt_tags = {
"seed_resize_from_h": process_int_tag, "seed_resize_from_h": process_int_tag,
"seed_resize_from_w": process_int_tag, "seed_resize_from_w": process_int_tag,
"sampler_index": process_int_tag, "sampler_index": process_int_tag,
"sampler_name": process_string_tag,
"batch_size": process_int_tag, "batch_size": process_int_tag,
"n_iter": process_int_tag, "n_iter": process_int_tag,
"steps": process_int_tag, "steps": process_int_tag,
@ -66,14 +68,28 @@ def cmdargs(line):
arg = args[pos] arg = args[pos]
assert arg.startswith("--"), f'must start with "--": {arg}' assert arg.startswith("--"), f'must start with "--": {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
tag = arg[2:] tag = arg[2:]
if tag == "prompt" or tag == "negative_prompt":
pos += 1
prompt = args[pos]
pos += 1
while pos < len(args) and not args[pos].startswith("--"):
prompt += " "
prompt += args[pos]
pos += 1
res[tag] = prompt
continue
func = prompt_tags.get(tag, None) func = prompt_tags.get(tag, None)
assert func, f'unknown commandline option: {arg}' assert func, f'unknown commandline option: {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
val = args[pos+1] val = args[pos+1]
if tag == "sampler_name":
val = sd_samplers.samplers_map.get(val.lower(), None)
res[tag] = func(val) res[tag] = func(val)