added interrupt button
added save button --always-batch-cond-uncond as a workaround for performance regression option for low memory users specify gradio version as 3.1.5 because of what looks like a bug
This commit is contained in:
parent
54dc6f9307
commit
a6adc22f07
3 changed files with 183 additions and 70 deletions
|
@ -1,6 +1,6 @@
|
||||||
basicsr
|
basicsr
|
||||||
gfpgan
|
gfpgan
|
||||||
gradio
|
gradio==3.1.5
|
||||||
numpy
|
numpy
|
||||||
Pillow
|
Pillow
|
||||||
realesrgan
|
realesrgan
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
console.log("running")
|
|
||||||
|
|
||||||
titles = {
|
titles = {
|
||||||
"Sampling steps": "How many times to imptove the generated image itratively; higher values take longer; very low values can produce bad results",
|
"Sampling steps": "How many times to imptove the generated image itratively; higher values take longer; very low values can produce bad results",
|
||||||
"Sampling method": "Which algorithm to use to produce the image",
|
"Sampling method": "Which algorithm to use to produce the image",
|
||||||
|
@ -29,6 +27,9 @@ titles = {
|
||||||
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
|
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
|
||||||
|
|
||||||
"Denoising Strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image.",
|
"Denoising Strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image.",
|
||||||
|
|
||||||
|
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
||||||
|
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
||||||
}
|
}
|
||||||
|
|
||||||
function gradioApp(){
|
function gradioApp(){
|
||||||
|
@ -36,7 +37,7 @@ function gradioApp(){
|
||||||
}
|
}
|
||||||
|
|
||||||
function addTitles(root){
|
function addTitles(root){
|
||||||
root.querySelectorAll('span').forEach(function(span){
|
root.querySelectorAll('span, button').forEach(function(span){
|
||||||
tooltip = titles[span.textContent];
|
tooltip = titles[span.textContent];
|
||||||
if(tooltip){
|
if(tooltip){
|
||||||
span.title = tooltip;
|
span.title = tooltip;
|
||||||
|
|
214
webui.py
214
webui.py
|
@ -68,6 +68,7 @@ parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="em
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage")
|
||||||
|
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed in you use --lowvram")
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
@ -75,9 +76,20 @@ cmd_opts = parser.parse_args()
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
gpu = torch.device("cuda")
|
gpu = torch.device("cuda")
|
||||||
device = gpu if torch.cuda.is_available() else cpu
|
device = gpu if torch.cuda.is_available() else cpu
|
||||||
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
interrupted = False
|
||||||
|
job = ""
|
||||||
|
|
||||||
|
def interrupt(self):
|
||||||
|
self.interrupted = True
|
||||||
|
|
||||||
|
|
||||||
|
state = State()
|
||||||
|
|
||||||
if not cmd_opts.share:
|
if not cmd_opts.share:
|
||||||
# fix gradio phoning home
|
# fix gradio phoning home
|
||||||
gradio.utils.version_check = lambda: None
|
gradio.utils.version_check = lambda: None
|
||||||
|
@ -198,6 +210,7 @@ class Options:
|
||||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output dictectory for img2img grids'),
|
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output dictectory for img2img grids'),
|
||||||
"save_to_dirs": OptionInfo(False, "When writing images/grids, create a directory with name derived from the prompt"),
|
"save_to_dirs": OptionInfo(False, "When writing images/grids, create a directory with name derived from the prompt"),
|
||||||
"save_to_dirs_prompt_len": OptionInfo(10, "When using above, how many words from prompt to put into directory name", gr.Slider, {"minimum": 1, "maximum": 32, "step": 1}),
|
"save_to_dirs_prompt_len": OptionInfo(10, "When using above, how many words from prompt to put into directory name", gr.Slider, {"minimum": 1, "maximum": 32, "step": 1}),
|
||||||
|
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button"),
|
||||||
"samples_save": OptionInfo(True, "Save indiviual samples"),
|
"samples_save": OptionInfo(True, "Save indiviual samples"),
|
||||||
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
|
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
|
||||||
"grid_save": OptionInfo(True, "Save image grids"),
|
"grid_save": OptionInfo(True, "Save image grids"),
|
||||||
|
@ -400,8 +413,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||||
image.save(f"{fullfn_without_extension}.jpg", quality=opts.jpeg_quality, pnginfo=pnginfo)
|
image.save(f"{fullfn_without_extension}.jpg", quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_filename_part(text):
|
def sanitize_filename_part(text):
|
||||||
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
|
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
|
||||||
|
|
||||||
|
@ -410,6 +421,7 @@ def plaintext_to_html(text):
|
||||||
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
|
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, batch_size=1, rows=None):
|
def image_grid(imgs, batch_size=1, rows=None):
|
||||||
if rows is None:
|
if rows is None:
|
||||||
if opts.n_rows > 0:
|
if opts.n_rows > 0:
|
||||||
|
@ -652,18 +664,29 @@ def wrap_gradio_gpu_call(func):
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return f
|
return wrap_gradio_call(f)
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
res = list(func(*args, **kwargs))
|
res = list(func(*args, **kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
print("Error completing request", file=sys.stderr)
|
||||||
|
print("Arguments:", args, kwargs, file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
res = [None, f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||||
|
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
|
||||||
# last item is always HTML
|
# last item is always HTML
|
||||||
res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
|
res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
|
||||||
|
|
||||||
|
state.interrupted = False
|
||||||
|
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
@ -883,7 +906,6 @@ class StableDiffusionProcessing:
|
||||||
self.extra_generation_params: dict = extra_generation_params
|
self.extra_generation_params: dict = extra_generation_params
|
||||||
self.overlay_images = overlay_images
|
self.overlay_images = overlay_images
|
||||||
self.paste_to = None
|
self.paste_to = None
|
||||||
self.progress_info = ""
|
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
pass
|
pass
|
||||||
|
@ -959,6 +981,15 @@ class CFGDenoiser(nn.Module):
|
||||||
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
|
def extended_trange(*args, **kwargs):
|
||||||
|
for x in tqdm.trange(*args, desc=state.job, **kwargs):
|
||||||
|
if state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield x
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
class KDiffusionSampler:
|
||||||
def __init__(self, funcname):
|
def __init__(self, funcname):
|
||||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
|
||||||
|
@ -980,7 +1011,7 @@ class KDiffusionSampler:
|
||||||
self.model_wrap_cfg.init_latent = p.init_latent
|
self.model_wrap_cfg.init_latent = p.init_latent
|
||||||
|
|
||||||
if hasattr(k_diffusion.sampling, 'trange'):
|
if hasattr(k_diffusion.sampling, 'trange'):
|
||||||
k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
|
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
||||||
|
|
||||||
|
@ -989,13 +1020,36 @@ class KDiffusionSampler:
|
||||||
x = x * sigmas[0]
|
x = x * sigmas[0]
|
||||||
|
|
||||||
if hasattr(k_diffusion.sampling, 'trange'):
|
if hasattr(k_diffusion.sampling, 'trange'):
|
||||||
k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
|
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
|
||||||
|
|
||||||
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
Processed = namedtuple('Processed', ['images', 'seed', 'info'])
|
class Processed:
|
||||||
|
def __init__(self, p: StableDiffusionProcessing, images, seed, info):
|
||||||
|
self.images = images
|
||||||
|
self.prompt = p.prompt
|
||||||
|
self.seed = seed
|
||||||
|
self.info = info
|
||||||
|
self.width = p.width
|
||||||
|
self.height = p.height
|
||||||
|
self.sampler = samplers[p.sampler_index].name
|
||||||
|
self.cfg_scale = p.cfg_scale
|
||||||
|
self.steps = p.steps
|
||||||
|
|
||||||
|
def js(self):
|
||||||
|
obj = {
|
||||||
|
"prompt": self.prompt,
|
||||||
|
"seed": int(self.seed),
|
||||||
|
"width": self.width,
|
||||||
|
"height": self.height,
|
||||||
|
"sampler": self.sampler,
|
||||||
|
"cfg_scale": self.cfg_scale,
|
||||||
|
"steps": self.steps,
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.dumps(obj)
|
||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
@ -1063,6 +1117,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
p.init()
|
p.init()
|
||||||
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
|
if state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
|
@ -1075,7 +1132,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
# we manually generate all input noises because each one should have a specific seed
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
||||||
|
|
||||||
p.progress_info = f"Batch {n+1} out of {p.n_iter}"
|
if p.n_iter > 0:
|
||||||
|
state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
|
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
|
@ -1137,7 +1196,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return Processed(output_images, seed, infotext())
|
return Processed(p, output_images, seed, infotext())
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
@ -1188,28 +1247,21 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
|
||||||
module.display = display
|
module.display = display
|
||||||
exec(compiled, module.__dict__)
|
exec(compiled, module.__dict__)
|
||||||
|
|
||||||
processed = Processed(*display_result_data)
|
processed = Processed(p, *display_result_data)
|
||||||
else:
|
else:
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
return processed.images, processed.seed, plaintext_to_html(processed.info)
|
return processed.images, processed.js(), plaintext_to_html(processed.info)
|
||||||
|
|
||||||
|
def save_files(js_data, images):
|
||||||
class Flagging(gr.FlaggingCallback):
|
|
||||||
|
|
||||||
def setup(self, components, flagging_dir: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
|
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
os.makedirs("log/images", exist_ok=True)
|
os.makedirs(opts.outdir_save, exist_ok=True)
|
||||||
|
|
||||||
# those must match the "txt2img" function
|
|
||||||
prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
|
|
||||||
|
|
||||||
filenames = []
|
filenames = []
|
||||||
|
|
||||||
|
data = json.loads(js_data)
|
||||||
|
|
||||||
with open("log/log.csv", "a", encoding="utf8", newline='') as file:
|
with open("log/log.csv", "a", encoding="utf8", newline='') as file:
|
||||||
import time
|
import time
|
||||||
import base64
|
import base64
|
||||||
|
@ -1217,23 +1269,25 @@ class Flagging(gr.FlaggingCallback):
|
||||||
at_start = file.tell() == 0
|
at_start = file.tell() == 0
|
||||||
writer = csv.writer(file)
|
writer = csv.writer(file)
|
||||||
if at_start:
|
if at_start:
|
||||||
writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
|
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename"])
|
||||||
|
|
||||||
filename_base = str(int(time.time() * 1000))
|
filename_base = str(int(time.time() * 1000))
|
||||||
for i, filedata in enumerate(images):
|
for i, filedata in enumerate(images):
|
||||||
filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
|
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
|
||||||
|
filepath = os.path.join(opts.outdir_save, filename)
|
||||||
|
|
||||||
if filedata.startswith("data:image/png;base64,"):
|
if filedata.startswith("data:image/png;base64,"):
|
||||||
filedata = filedata[len("data:image/png;base64,"):]
|
filedata = filedata[len("data:image/png;base64,"):]
|
||||||
|
|
||||||
with open(filename, "wb") as imgfile:
|
with open(filepath, "wb") as imgfile:
|
||||||
imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
|
imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
|
||||||
|
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
|
|
||||||
writer.writerow([prompt, seed, width, height, cfg_scale, steps, filenames[0]])
|
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0]])
|
||||||
|
|
||||||
|
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
print("Logged:", filenames[0])
|
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -1267,8 +1321,15 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gallery = gr.Gallery(label='Output')
|
gallery = gr.Gallery(label='Output')
|
||||||
output_seed = gr.Number(label='Seed', visible=False)
|
|
||||||
|
with gr.Group():
|
||||||
|
with gr.Row():
|
||||||
|
interrupt = gr.Button('Interrupt')
|
||||||
|
save = gr.Button('Save')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
|
generation_info = gr.Textbox(visible=False)
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(txt2img),
|
fn=wrap_gradio_gpu_call(txt2img),
|
||||||
|
@ -1289,7 +1350,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gallery,
|
gallery,
|
||||||
output_seed,
|
generation_info,
|
||||||
html_info
|
html_info
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -1297,6 +1358,25 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
prompt.submit(**txt2img_args)
|
prompt.submit(**txt2img_args)
|
||||||
submit.click(**txt2img_args)
|
submit.click(**txt2img_args)
|
||||||
|
|
||||||
|
interrupt.click(
|
||||||
|
fn=lambda: state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
save.click(
|
||||||
|
fn=wrap_gradio_call(save_files),
|
||||||
|
inputs=[
|
||||||
|
generation_info,
|
||||||
|
gallery,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
html_info,
|
||||||
|
html_info,
|
||||||
|
html_info,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_crop_region(mask, pad=0):
|
def get_crop_region(mask, pad=0):
|
||||||
h, w = mask.shape
|
h, w = mask.shape
|
||||||
|
@ -1508,6 +1588,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
p.do_not_save_grid = True
|
p.do_not_save_grid = True
|
||||||
|
|
||||||
|
state.job = f"Batch {i + 1} out of {n_iter}"
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
if initial_seed is None:
|
if initial_seed is None:
|
||||||
|
@ -1523,13 +1604,13 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
|
|
||||||
save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
|
save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
processed = Processed(history, initial_seed, initial_info)
|
processed = Processed(p, history, initial_seed, initial_info)
|
||||||
|
|
||||||
elif is_upscale:
|
elif is_upscale:
|
||||||
initial_seed = None
|
initial_seed = None
|
||||||
initial_info = None
|
initial_info = None
|
||||||
|
|
||||||
upscaler = sd_upscalers[upscaler_name]
|
upscaler = sd_upscalers.get(upscaler_name, next(iter(sd_upscalers.values())))
|
||||||
img = upscaler(init_img)
|
img = upscaler(init_img)
|
||||||
|
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
@ -1553,6 +1634,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
for i in range(batch_count):
|
for i in range(batch_count):
|
||||||
p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
|
p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
|
||||||
|
|
||||||
|
state.job = f"Batch {i + 1} out of {batch_count}"
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
if initial_seed is None:
|
if initial_seed is None:
|
||||||
|
@ -1565,19 +1647,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
image_index = 0
|
image_index = 0
|
||||||
for y, h, row in grid.tiles:
|
for y, h, row in grid.tiles:
|
||||||
for tiledata in row:
|
for tiledata in row:
|
||||||
tiledata[2] = work_results[image_index]
|
tiledata[2] = work_results[image_index] if image_index<len(work_results) else Image.new("RGB", (p.width, p.height))
|
||||||
image_index += 1
|
image_index += 1
|
||||||
|
|
||||||
combined_image = combine_grid(grid)
|
combined_image = combine_grid(grid)
|
||||||
|
|
||||||
save_image(combined_image, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
|
save_image(combined_image, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
processed = Processed([combined_image], initial_seed, initial_info)
|
processed = Processed(p, [combined_image], initial_seed, initial_info)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
return processed.images, processed.seed, plaintext_to_html(processed.info)
|
return processed.images, processed.js(), plaintext_to_html(processed.info)
|
||||||
|
|
||||||
|
|
||||||
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||||
|
@ -1609,8 +1691,8 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False)
|
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(sd_upscalers.keys()), value="RealESRGAN")
|
sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(sd_upscalers.keys()), value=list(sd_upscalers.keys())[0], visible=False)
|
||||||
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
|
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
||||||
|
@ -1629,8 +1711,15 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gallery = gr.Gallery(label='Output')
|
gallery = gr.Gallery(label='Output')
|
||||||
output_seed = gr.Number(label='Seed', visible=False)
|
|
||||||
|
with gr.Group():
|
||||||
|
with gr.Row():
|
||||||
|
interrupt = gr.Button('Interrupt')
|
||||||
|
save = gr.Button('Save')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
|
generation_info = gr.Textbox(visible=False)
|
||||||
|
|
||||||
def apply_mode(mode):
|
def apply_mode(mode):
|
||||||
is_classic = mode == 0
|
is_classic = mode == 0
|
||||||
|
@ -1647,7 +1736,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
batch_count: gr.update(visible=not is_upscale),
|
batch_count: gr.update(visible=not is_upscale),
|
||||||
batch_size: gr.update(visible=not is_loopback),
|
batch_size: gr.update(visible=not is_loopback),
|
||||||
sd_upscale_upscaler_name: gr.update(visible=is_upscale),
|
sd_upscale_upscaler_name: gr.update(visible=is_upscale),
|
||||||
sd_upscale_overlap: gr.update(visible=is_upscale),
|
sd_upscale_overlap: gr.Slider.update(visible=is_upscale),
|
||||||
inpaint_full_res: gr.update(visible=is_inpaint),
|
inpaint_full_res: gr.update(visible=is_inpaint),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1695,7 +1784,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gallery,
|
gallery,
|
||||||
output_seed,
|
generation_info,
|
||||||
html_info
|
html_info
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -1703,6 +1792,25 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
prompt.submit(**img2img_args)
|
prompt.submit(**img2img_args)
|
||||||
submit.click(**img2img_args)
|
submit.click(**img2img_args)
|
||||||
|
|
||||||
|
interrupt.click(
|
||||||
|
fn=lambda: state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
save.click(
|
||||||
|
fn=wrap_gradio_call(save_files),
|
||||||
|
inputs=[
|
||||||
|
generation_info,
|
||||||
|
gallery,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
html_info,
|
||||||
|
html_info,
|
||||||
|
html_info,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
|
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||||
info = realesrgan_models[RealESRGAN_model_index]
|
info = realesrgan_models[RealESRGAN_model_index]
|
||||||
|
@ -1744,7 +1852,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
|
||||||
|
|
||||||
save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True)
|
save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True)
|
||||||
|
|
||||||
return image, 0, ''
|
return image, '', ''
|
||||||
|
|
||||||
|
|
||||||
extras_interface = gr.Interface(
|
extras_interface = gr.Interface(
|
||||||
|
@ -1757,7 +1865,7 @@ extras_interface = gr.Interface(
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.Image(label="Result"),
|
gr.Image(label="Result"),
|
||||||
gr.Number(label='Seed', visible=False),
|
gr.HTML(),
|
||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
|
@ -1779,7 +1887,7 @@ def run_pnginfo(image):
|
||||||
message = "Nothing found in the image."
|
message = "Nothing found in the image."
|
||||||
info = f"<div><p>{message}<p></div>"
|
info = f"<div><p>{message}<p></div>"
|
||||||
|
|
||||||
return [info]
|
return '', '', info
|
||||||
|
|
||||||
|
|
||||||
pnginfo_interface = gr.Interface(
|
pnginfo_interface = gr.Interface(
|
||||||
|
@ -1789,6 +1897,8 @@ pnginfo_interface = gr.Interface(
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
|
gr.HTML(),
|
||||||
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
analytics_enabled=False,
|
analytics_enabled=False,
|
||||||
|
@ -1809,7 +1919,7 @@ def run_settings(*args):
|
||||||
|
|
||||||
opts.save(config_filename)
|
opts.save(config_filename)
|
||||||
|
|
||||||
return 'Settings saved.', ''
|
return 'Settings saved.', '', ''
|
||||||
|
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
|
@ -1839,6 +1949,7 @@ settings_interface = gr.Interface(
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.Textbox(label='Result'),
|
gr.Textbox(label='Result'),
|
||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
title=None,
|
title=None,
|
||||||
description=None,
|
description=None,
|
||||||
|
@ -1863,17 +1974,18 @@ try:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
sd_config = OmegaConf.load(cmd_opts.config)
|
if False:
|
||||||
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
sd_config = OmegaConf.load(cmd_opts.config)
|
||||||
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
|
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
||||||
|
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
|
||||||
|
|
||||||
if cmd_opts.lowvram or cmd_opts.medvram:
|
if cmd_opts.lowvram or cmd_opts.medvram:
|
||||||
setup_for_low_vram(sd_model)
|
setup_for_low_vram(sd_model)
|
||||||
else:
|
else:
|
||||||
sd_model = sd_model.to(device)
|
sd_model = sd_model.to(device)
|
||||||
|
|
||||||
model_hijack = StableDiffusionModelHijack()
|
model_hijack = StableDiffusionModelHijack()
|
||||||
model_hijack.hijack(sd_model)
|
model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
|
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
|
||||||
css = file.read()
|
css = file.read()
|
||||||
|
|
Loading…
Reference in a new issue