add initial version of the extensions tab

fix broken Restart Gradio button
This commit is contained in:
AUTOMATIC 2022-10-31 17:36:45 +03:00
parent 9b384dfb5c
commit 910a097ae2
9 changed files with 333 additions and 30 deletions

24
javascript/extensions.js Normal file
View file

@ -0,0 +1,24 @@
function extensions_apply(_, _){
disable = []
update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7))
})
restart_reload()
return [JSON.stringify(disable), JSON.stringify(update)]
}
function extensions_check(){
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
return []
}

83
modules/extensions.py Normal file
View file

@ -0,0 +1,83 @@
import os
import sys
import traceback
import git
from modules import paths, shared
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
def active():
return [x for x in extensions if x.enabled]
class Extension:
def __init__(self, name, path, enabled=True):
self.name = name
self.path = path
self.enabled = enabled
self.status = ''
self.can_update = False
repo = None
try:
if os.path.exists(os.path.join(path, ".git")):
repo = git.Repo(path)
except Exception:
print(f"Error reading github repository info from {path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare:
self.remote = None
else:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
res = []
for filename in sorted(os.listdir(dirpath)):
res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return res
def check_updates(self):
repo = git.Repo(self.path)
for fetch in repo.remote().fetch("--dry-run"):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
return
self.can_update = False
self.status = "latest"
def pull(self):
repo = git.Repo(self.path)
repo.remotes.origin.pull()
def list_extensions():
extensions.clear()
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
path = os.path.join(extensions_dir, dirname)
if not os.path.isdir(path):
continue
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)

View file

@ -17,6 +17,11 @@ paste_fields = {}
bind_list = [] bind_list = []
def reset():
paste_fields.clear()
bind_list.clear()
def quote(text): def quote(text):
if ',' not in str(text): if ',' not in str(text):
return text return text

View file

@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks from modules import shared, paths, script_callbacks, extensions
AlwaysVisible = object() AlwaysVisible = object()
@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)): for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
extdir = os.path.join(paths.script_path, "extensions") for ext in extensions.active():
if os.path.exists(extdir): scripts_list += ext.list_files(scriptdirname, extension)
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
scriptdirpath = os.path.join(dirpath, scriptdirname)
if not os.path.isdir(scriptdirpath):
continue
for filename in sorted(os.listdir(scriptdirpath)):
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename): def list_files_with_name(filename):
res = [] res = []
dirs = [paths.script_path] dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
for dirpath in dirs: for dirpath in dirs:
if not os.path.isdir(dirpath): if not os.path.isdir(dirpath):

View file

@ -132,6 +132,7 @@ class State:
current_image = None current_image = None
current_image_sampling_step = 0 current_image_sampling_step = 0
textinfo = None textinfo = None
need_restart = False
def skip(self): def skip(self):
self.skipped = True self.skipped = True
@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
})) }))
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
}))
options_templates.update()
class Options: class Options:
data = None data = None
@ -365,8 +372,9 @@ class Options:
def __setattr__(self, key, value): def __setattr__(self, key, value):
if self.data is not None: if self.data is not None:
if key in self.data: if key in self.data or key in self.data_labels:
self.data[key] = value self.data[key] = value
return
return super(Options, self).__setattr__(key, value) return super(Options, self).__setattr__(key, value)

View file

@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import sd_hijack, sd_models, localization, script_callbacks from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None column = None
with gr.Row(elem_id="settings").style(equal_height=False): with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()): for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
if previous_section != item.section: if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None: if column is not None:
column.__exit__() column.__exit__()
@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings: if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item)) quicksettings_list.append((i, k, item))
components.append(dummy_component) components.append(dummy_component)
elif section_must_be_skipped:
components.append(dummy_component)
else: else:
component = create_setting_component(k) component = create_setting_component(k)
component_dict[k] = component component_dict[k] = component
@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call):
def request_restart(): def request_restart():
shared.state.interrupt() shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True shared.state.need_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=request_restart,
inputs=[], inputs=[],
outputs=[], outputs=[],
@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback() interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")] interfaces += [(settings_interface, "Settings", "settings")]
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
settings_interface.gradio_ref = demo
parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind() parameters_copypaste.run_bind()

162
modules/ui_extensions.py Normal file
View file

@ -0,0 +1,162 @@
import json
import os.path
import shutil
import sys
import time
import traceback
import git
import gradio as gr
import html
from modules import extensions, shared, paths
def apply_and_restart(disable_list, update_list):
disabled = json.loads(disable_list)
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
update = set(update)
for ext in extensions.extensions:
if ext.name not in update:
continue
try:
ext.pull()
except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
shared.opts.save(shared.config_filename)
shared.state.interrupt()
shared.state.need_restart = True
def check_updates():
for ext in extensions.extensions:
if ext.remote is None:
continue
try:
ext.check_updates()
except Exception:
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return extension_table()
def extension_table():
code = f"""<!-- {time.time()} -->
<table id="extensions">
<thead>
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
<tbody>
"""
for ext in extensions.extensions:
if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
else:
ext_status = ext.status
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
code += """
</tbody>
</table>
"""
return code
def install_extension_from_url(dirname, url):
assert url, 'No URL specified'
if dirname is None or dirname == "":
*parts, last_part = url.split('/')
last_part = last_part.replace(".git", "")
dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
try:
shutil.rmtree(tmpdir, True)
repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
os.rename(tmpdir, target_dir)
extensions.list_extensions()
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
finally:
shutil.rmtree(tmpdir, True)
def create_ui():
import modules.ui
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
with gr.Row():
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
extensions_table = gr.HTML(lambda: extension_table())
apply.click(
fn=apply_and_restart,
_js="extensions_apply",
inputs=[extensions_disabled_list, extensions_update_list],
outputs=[],
)
check.click(
fn=check_updates,
_js="extensions_check",
inputs=[],
outputs=[extensions_table],
)
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
intall_button = gr.Button(value="Install", variant="primary")
intall_result = gr.HTML(elem_id="extension_install_result")
intall_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url],
outputs=[extensions_table, intall_result],
)
return ui

View file

@ -530,6 +530,26 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
min-height: 480px !important; min-height: 480px !important;
} }
/* Extensions */
#extensions{
border-collapse: collapse;
}
#extensions td, #extensions th{
border: 1px solid #ccc;
padding: 0.25em 0.5em;
}
#extensions input[type="checkbox"]{
margin-right: 0.5em;
}
#tab_extensions button{
max-width: 16em;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running

View file

@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
from modules import devices, sd_samplers, upscaler from modules import devices, sd_samplers, upscaler, extensions
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
@ -60,6 +60,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def initialize(): def initialize():
extensions.list_extensions()
#for ext in extensions.extensions:
# print(ext.name, ext.path, ext.enabled, ext.remote)
#exit()
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts() modules.scripts.load_scripts()
@ -92,15 +97,18 @@ def create_api(app):
api = Api(app, queue_lock) api = Api(app, queue_lock)
return api return api
def wait_on_server(demo=None): def wait_on_server(demo=None):
while 1: while 1:
time.sleep(0.5) time.sleep(0.5)
if demo and getattr(demo, 'do_restart', False): if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5) time.sleep(0.5)
demo.close() demo.close()
time.sleep(0.5) time.sleep(0.5)
break break
def api_only(): def api_only():
initialize() initialize()
@ -132,14 +140,16 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if (launch_api): if launch_api:
create_api(app) create_api(app)
wait_on_server(demo) wait_on_server(demo)
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading extensions')
extensions.list_extensions()
print('Reloading custom scripts')
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
print('Reloading modules: modules.ui') print('Reloading modules: modules.ui')
importlib.reload(modules.ui) importlib.reload(modules.ui)
@ -148,8 +158,6 @@ def webui():
print('Restarting Gradio') print('Restarting Gradio')
task = []
if __name__ == "__main__": if __name__ == "__main__":
if cmd_opts.nowebui: if cmd_opts.nowebui:
api_only() api_only()