commit
c1512ef9ae
4 changed files with 132 additions and 28 deletions
|
@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru
|
from modules import sd_samplers, deepbooru, sd_hijack
|
||||||
from modules.api.models import *
|
from modules.api.models import *
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.extras import run_extras, run_pnginfo
|
from modules.extras import run_extras, run_pnginfo
|
||||||
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list
|
from modules.sd_models import checkpoints_list
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
|
from modules import devices
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
|
@ -97,6 +101,11 @@ class Api:
|
||||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||||
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
||||||
|
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
||||||
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||||
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||||
|
|
||||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||||
if shared.cmd_opts.api_auth:
|
if shared.cmd_opts.api_auth:
|
||||||
|
@ -326,6 +335,89 @@ class Api:
|
||||||
def refresh_checkpoints(self):
|
def refresh_checkpoints(self):
|
||||||
shared.refresh_checkpoints()
|
shared.refresh_checkpoints()
|
||||||
|
|
||||||
|
def create_embedding(self, args: dict):
|
||||||
|
try:
|
||||||
|
shared.state.begin()
|
||||||
|
filename = create_embedding(**args) # create empty embedding
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||||
|
shared.state.end()
|
||||||
|
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
||||||
|
except AssertionError as e:
|
||||||
|
shared.state.end()
|
||||||
|
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
||||||
|
|
||||||
|
def create_hypernetwork(self, args: dict):
|
||||||
|
try:
|
||||||
|
shared.state.begin()
|
||||||
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
|
shared.state.end()
|
||||||
|
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
||||||
|
except AssertionError as e:
|
||||||
|
shared.state.end()
|
||||||
|
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
||||||
|
|
||||||
|
def preprocess(self, args: dict):
|
||||||
|
try:
|
||||||
|
shared.state.begin()
|
||||||
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
|
shared.state.end()
|
||||||
|
return PreprocessResponse(info = 'preprocess complete')
|
||||||
|
except KeyError as e:
|
||||||
|
shared.state.end()
|
||||||
|
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
||||||
|
except AssertionError as e:
|
||||||
|
shared.state.end()
|
||||||
|
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
shared.state.end()
|
||||||
|
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
||||||
|
|
||||||
|
def train_embedding(self, args: dict):
|
||||||
|
try:
|
||||||
|
shared.state.begin()
|
||||||
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
|
error = None
|
||||||
|
filename = ''
|
||||||
|
if not apply_optimizations:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
try:
|
||||||
|
embedding, filename = train_embedding(**args) # can take a long time to complete
|
||||||
|
except Exception as e:
|
||||||
|
error = e
|
||||||
|
finally:
|
||||||
|
if not apply_optimizations:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
shared.state.end()
|
||||||
|
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||||
|
except AssertionError as msg:
|
||||||
|
shared.state.end()
|
||||||
|
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
||||||
|
|
||||||
|
def train_hypernetwork(self, args: dict):
|
||||||
|
try:
|
||||||
|
shared.state.begin()
|
||||||
|
initial_hypernetwork = shared.loaded_hypernetwork
|
||||||
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
|
error = None
|
||||||
|
filename = ''
|
||||||
|
if not apply_optimizations:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
try:
|
||||||
|
hypernetwork, filename = train_hypernetwork(*args)
|
||||||
|
except Exception as e:
|
||||||
|
error = e
|
||||||
|
finally:
|
||||||
|
shared.loaded_hypernetwork = initial_hypernetwork
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
if not apply_optimizations:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
shared.state.end()
|
||||||
|
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||||
|
except AssertionError as msg:
|
||||||
|
shared.state.end()
|
||||||
|
return TrainResponse(info = "train embedding error: {error}".format(error = error))
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port)
|
uvicorn.run(self.app, host=server_name, port=port)
|
||||||
|
|
|
@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
|
||||||
class InterrogateResponse(BaseModel):
|
class InterrogateResponse(BaseModel):
|
||||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||||
|
|
||||||
|
class TrainResponse(BaseModel):
|
||||||
|
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
||||||
|
|
||||||
|
class CreateResponse(BaseModel):
|
||||||
|
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
||||||
|
|
||||||
|
class PreprocessResponse(BaseModel):
|
||||||
|
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for key, metadata in opts.data_labels.items():
|
for key, metadata in opts.data_labels.items():
|
||||||
value = opts.data.get(key)
|
value = opts.data.get(key)
|
||||||
|
|
|
@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||||
|
# Remove illegal characters from name.
|
||||||
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
|
|
||||||
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
|
if type(layer_structure) == str:
|
||||||
|
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||||
|
|
||||||
|
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||||
|
name=name,
|
||||||
|
enable_sizes=[int(x) for x in enable_sizes],
|
||||||
|
layer_structure=layer_structure,
|
||||||
|
activation_func=activation_func,
|
||||||
|
weight_init=weight_init,
|
||||||
|
add_layer_norm=add_layer_norm,
|
||||||
|
use_dropout=use_dropout,
|
||||||
|
)
|
||||||
|
hypernet.save(fn)
|
||||||
|
|
||||||
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
|
|
|
@ -3,39 +3,16 @@ import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.textual_inversion.preprocess
|
import modules.hypernetworks.hypernetwork
|
||||||
import modules.textual_inversion.textual_inversion
|
|
||||||
from modules import devices, sd_hijack, shared
|
from modules import devices, sd_hijack, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
|
|
||||||
not_available = ["hardswish", "multiheadattention"]
|
not_available = ["hardswish", "multiheadattention"]
|
||||||
keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||||
# Remove illegal characters from name.
|
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
|
||||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
|
||||||
|
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
||||||
if not overwrite_old:
|
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
|
||||||
|
|
||||||
if type(layer_structure) == str:
|
|
||||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
|
||||||
|
|
||||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
|
||||||
name=name,
|
|
||||||
enable_sizes=[int(x) for x in enable_sizes],
|
|
||||||
layer_structure=layer_structure,
|
|
||||||
activation_func=activation_func,
|
|
||||||
weight_init=weight_init,
|
|
||||||
add_layer_norm=add_layer_norm,
|
|
||||||
use_dropout=use_dropout,
|
|
||||||
)
|
|
||||||
hypernet.save(fn)
|
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
|
||||||
|
|
||||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
|
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(*args):
|
def train_hypernetwork(*args):
|
||||||
|
|
Loading…
Reference in a new issue