actual support for share=True in gradio
This commit is contained in:
parent
f1aa1d6711
commit
54dc6f9307
1 changed files with 17 additions and 6 deletions
23
webui.py
23
webui.py
|
@ -37,6 +37,7 @@ from contextlib import nullcontext
|
||||||
import signal
|
import signal
|
||||||
import tqdm
|
import tqdm
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
|
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
@ -75,6 +76,7 @@ cpu = torch.device("cpu")
|
||||||
gpu = torch.device("cuda")
|
gpu = torch.device("cuda")
|
||||||
device = gpu if torch.cuda.is_available() else cpu
|
device = gpu if torch.cuda.is_available() else cpu
|
||||||
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
if not cmd_opts.share:
|
if not cmd_opts.share:
|
||||||
# fix gradio phoning home
|
# fix gradio phoning home
|
||||||
|
@ -643,10 +645,20 @@ def resize_image(resize_mode, im, width, height):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_gradio_gpu_call(func):
|
||||||
|
def f(*args, **kwargs):
|
||||||
|
with queue_lock:
|
||||||
|
res = func(*args, **kwargs)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func):
|
||||||
def f(*p1, **p2):
|
def f(*args, **kwargs):
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
res = list(func(*p1, **p2))
|
res = list(func(*args, **kwargs))
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
|
||||||
# last item is always HTML
|
# last item is always HTML
|
||||||
|
@ -1259,7 +1271,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=wrap_gradio_call(txt2img),
|
fn=wrap_gradio_gpu_call(txt2img),
|
||||||
inputs=[
|
inputs=[
|
||||||
prompt,
|
prompt,
|
||||||
negative_prompt,
|
negative_prompt,
|
||||||
|
@ -1657,7 +1669,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
)
|
)
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=wrap_gradio_call(img2img),
|
fn=wrap_gradio_gpu_call(img2img),
|
||||||
inputs=[
|
inputs=[
|
||||||
prompt,
|
prompt,
|
||||||
init_img,
|
init_img,
|
||||||
|
@ -1736,7 +1748,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
|
||||||
|
|
||||||
|
|
||||||
extras_interface = gr.Interface(
|
extras_interface = gr.Interface(
|
||||||
wrap_gradio_call(run_extras),
|
wrap_gradio_gpu_call(run_extras),
|
||||||
inputs=[
|
inputs=[
|
||||||
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
|
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan),
|
||||||
|
@ -1904,6 +1916,5 @@ def inject_gradio_html(javascript):
|
||||||
|
|
||||||
inject_gradio_html(javascript)
|
inject_gradio_html(javascript)
|
||||||
|
|
||||||
demo.queue(concurrency_count=1)
|
|
||||||
demo.launch(share=cmd_opts.share)
|
demo.launch(share=cmd_opts.share)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue