Merge branch 'AUTOMATIC1111:master' into img2img-api-scripts
commit 50e2536279
24 changed files with 813 additions and 265 deletions
README.md (10 changes)
@@ -1,9 +1,7 @@
 # Stable Diffusion web UI
 A browser interface based on Gradio library for Stable Diffusion.

-![](txt2img_Screenshot.png)
+![](screenshot.png)

-Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
-
 ## Features
 [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):

@@ -97,9 +95,8 @@ Alternatively, use online services (like Google Colab):
 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
 2. Install [git](https://git-scm.com/download/win).
 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
-4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
+4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
+5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.

 ### Automatic Installation on Linux
 1. Install the dependencies:

@@ -141,6 +138,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
 - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
 - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
 - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot

@@ -184,7 +184,7 @@ SOFTWARE.
 </pre>

 <h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
-<small>Code added by contirubtors, most likely copied from this repository.</small>
+<small>Code added by contributors, most likely copied from this repository.</small>

 <pre>
 Apache License

@@ -390,3 +390,30 @@ SOFTWARE.
 limitations under the License.
 </pre>

+<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
+<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+

@@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel):
     gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
     codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
     codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
-    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
     upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
     upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
     upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
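The only change in this hunk is the upper bound of `upscaling_resize` moving from `le=4` to `le=8`, so the extras API now accepts upscale factors up to 8. A minimal pydantic sketch of how that bound behaves; `UpscaleRequest` is a made-up stand-in for illustration, not the real `ExtrasBaseRequest`:

    # Hypothetical stand-in to illustrate the ge=1/le=8 validation; the real model
    # in the webui API layer carries many more fields.
    from pydantic import BaseModel, Field, ValidationError

    class UpscaleRequest(BaseModel):
        upscaling_resize: float = Field(default=2, ge=1, le=8, description="Upscaling factor, only used when resize_mode=0.")

    UpscaleRequest(upscaling_resize=8)       # accepted after this change (old bound was le=4)
    try:
        UpscaleRequest(upscaling_resize=12)  # still rejected: above the new upper bound
    except ValidationError as e:
        print(e)
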
@@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
     return orig_tensor_numpy(self, *args, **kwargs)


-# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
-    torch.Tensor.to = tensor_to_fix
-    torch.nn.functional.layer_norm = layer_norm_fix
-    torch.Tensor.numpy = numpy_fix
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+orig_cumsum = torch.cumsum
+orig_Tensor_cumsum = torch.Tensor.cumsum
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+    if input.device.type == 'mps':
+        output_dtype = kwargs.get('dtype', input.dtype)
+        if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+    return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps():
+    if version.parse(torch.__version__) < version.parse("1.13"):
+        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+        torch.Tensor.to = tensor_to_fix
+        torch.nn.functional.layer_norm = layer_norm_fix
+        torch.Tensor.numpy = numpy_fix
+    elif version.parse(torch.__version__) > version.parse("1.13.1"):
+        if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
+            torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+            torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+        orig_narrow = torch.narrow
+        torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
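The probe in the `elif` branch above decides at startup whether the cumsum patch is still needed: it runs a tiny `cumsum` on the MPS device and compares it against the known-correct answer. A standalone sketch of the same idea, guarded so it also runs on machines without MPS; `mps_cumsum_is_broken` is a hypothetical helper name, not part of the webui:

    import torch

    def mps_cumsum_is_broken() -> bool:
        """Return True if cumsum on the MPS backend gives wrong results for small int dtypes."""
        if not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
            return False  # no MPS device to test against
        x = torch.Tensor([1, 1]).to(torch.device("mps"))
        # cumsum of [1, 1] must be [1, 2]; the bug makes low-precision int dtypes come out wrong
        return not torch.Tensor([1, 2]).to(torch.device("mps")).equal(x.cumsum(0, dtype=torch.int16))

    if mps_cumsum_is_broken():
        print("cumsum workaround required on this torch build")
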
@@ -13,7 +13,7 @@ import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
 from modules import devices, processing, sd_models, shared, sd_samplers
-from modules.textual_inversion import textual_inversion
+from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
 from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

@@ -458,6 +458,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)

+    if shared.opts.save_training_settings_to_txt:
+        saved_params = dict(
+            model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+            **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
+        )
+        logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
+
     latent_sampling_method = ds.latent_sampling_method

     dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
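The `{**saved_params, **locals()}` call hands the new logging helper everything in the training function's scope; a helper like this typically keeps only simple scalar values so the saved settings stay readable. A rough sketch of that idea, illustrative only and not the actual implementation of the `modules/textual_inversion/logging.py` module added by this merge:

    import datetime
    import json
    import os

    # Illustrative stand-in for a "save training settings" helper: keep only simple
    # scalar values (the real helper may whitelist specific keys instead).
    def save_settings_to_file_sketch(log_directory, all_params):
        keep = {k: v for k, v in all_params.items() if isinstance(v, (int, float, str, bool)) or v is None}
        filename = f'settings-{datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.txt'
        with open(os.path.join(log_directory, filename), "w") as file:
            json.dump(keep, file, indent=4)
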
@@ -711,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.truncate_x = 0
         self.truncate_y = 0

-
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
             if self.hr_resize_x == 0 and self.hr_resize_y == 0:
@@ -71,6 +71,7 @@ callback_map = dict(
     callbacks_before_component=[],
     callbacks_after_component=[],
     callbacks_image_grid=[],
+    callbacks_script_unloaded=[],
 )


@@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams):
         report_exception(c, 'image_grid')


+def script_unloaded_callback():
+    for c in reversed(callback_map['callbacks_script_unloaded']):
+        try:
+            c.callback()
+        except Exception:
+            report_exception(c, 'script_unloaded')
+
+
 def add_callback(callbacks, fun):
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if len(stack) > 0 else 'unknown file'

@@ -202,7 +211,7 @@ def on_app_started(callback):

 def on_model_loaded(callback):
     """register a function to be called when the stable diffusion model is created; the model is
-    passed as an argument"""
+    passed as an argument; this function is also called when the script is reloaded. """
     add_callback(callback_map['callbacks_model_loaded'], callback)


@@ -279,3 +288,10 @@ def on_image_grid(callback):
     - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
     """
     add_callback(callback_map['callbacks_image_grid'], callback)
+
+
+def on_script_unloaded(callback):
+    """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
+    the script did should be reverted here"""
+
+    add_callback(callback_map['callbacks_script_unloaded'], callback)
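The new `on_script_unloaded` hook gives extensions a place to undo their monkey-patching when the webui reloads its scripts. A minimal sketch of how an extension script might use it; the patched forward wrapper here is made up purely for illustration:

    import ldm.modules.attention
    from modules import script_callbacks

    # Hypothetical extension that wraps CrossAttention.forward and restores the
    # original implementation when scripts are unloaded/reloaded.
    original_forward = ldm.modules.attention.CrossAttention.forward

    def patched_forward(self, x, context=None, mask=None):
        # extension-specific behaviour would go here
        return original_forward(self, x, context=context, mask=mask)

    ldm.modules.attention.CrossAttention.forward = patched_forward

    def undo_hijack():
        ldm.modules.attention.CrossAttention.forward = original_forward

    script_callbacks.on_script_unloaded(undo_hijack)
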
@@ -290,7 +290,6 @@ class ScriptRunner:
             script.group = group

         dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
-        dropdown.save_to_config = True
         inputs[0] = dropdown

         for script in self.selectable_scripts:
@@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 import ldm.modules.diffusionmodules.openaimodel

@@ -43,20 +41,19 @@ def apply_optimizations():
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
         optimization_method = 'xformers'
+    elif cmd_opts.opt_sub_quad_attention:
+        print("Applying sub-quadratic cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+        optimization_method = 'sub-quadratic'
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
         optimization_method = 'V1'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
-        if not invokeAI_mps_available and shared.device.type == 'mps':
-            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
-            print("Applying v1 cross attention optimization.")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-            optimization_method = 'V1'
-        else:
-            print("Applying cross attention optimization (InvokeAI).")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
-            optimization_method = 'InvokeAI'
+    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
+        print("Applying cross attention optimization (InvokeAI).")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+        optimization_method = 'InvokeAI'
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward

@@ -150,10 +147,10 @@ class StableDiffusionModelHijack:
     def clear_comments(self):
         self.comments = []

-    def tokenize(self, text):
-        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+    def get_prompt_lengths(self, text):
+        _, token_count = self.clip.process_texts([text])

-        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+        return token_count, self.clip.get_target_prompt_token_count(token_count)


 class EmbeddingsWithFixes(torch.nn.Module):
@@ -1,30 +1,89 @@
 import math
+from collections import namedtuple

 import torch

-from modules import prompt_parser, devices
+from modules import prompt_parser, devices, sd_hijack
 from modules.shared import opts

-def get_target_prompt_token_count(token_count):
-    return math.ceil(max(token_count, 1) / 75) * 75
+
+class PromptChunk:
+    """
+    This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
+    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
+    Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
+    so just 75 tokens from prompt.
+    """
+
+    def __init__(self):
+        self.tokens = []
+        self.multipliers = []
+        self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
+chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


 class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+    have unlimited prompt length and assign weights to tokens in prompt.
+    """
+
     def __init__(self, wrapped, hijack):
         super().__init__()
+
         self.wrapped = wrapped
-        self.hijack = hijack
+        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+        depending on model."""
+
+        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+        self.chunk_length = 75
+
+    def empty_chunk(self):
+        """creates an empty PromptChunk and returns it"""
+
+        chunk = PromptChunk()
+        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+        chunk.multipliers = [1.0] * (self.chunk_length + 2)
+        return chunk
+
+    def get_target_prompt_token_count(self, token_count):
+        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

     def tokenize(self, texts):
+        """Converts a batch of texts into a batch of token ids"""
+
         raise NotImplementedError

     def encode_with_transformers(self, tokens):
+        """
+        converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
+        All python lists with tokens are assumed to have same length, usually 77.
+        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
+        model - can be 768 and 1024.
+        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
+        """
+
         raise NotImplementedError

     def encode_embedding_init_text(self, init_text, nvpt):
+        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+        transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
+
         raise NotImplementedError

-    def tokenize_line(self, line, used_custom_terms, hijack_comments):
+    def tokenize_line(self, line):
+        """
+        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+        represent the prompt.
+        Returns the list and the total number of tokens in the prompt.
+        """
+
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(line)
         else:
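A quick worked example of the chunk arithmetic introduced above: `chunk_length` is 75, and `get_target_prompt_token_count` rounds a prompt's token count up to the next multiple of it, so a 10-token prompt targets 75 tokens (one chunk) and a 100-token prompt targets 150 (two chunks). With the start and end ids added by `empty_chunk`/`next_chunk`, each stored chunk is 77 ids long. Pure arithmetic, mirroring the method:

    import math

    chunk_length = 75

    def get_target_prompt_token_count(token_count):
        # same rounding rule as the method above
        return math.ceil(max(token_count, 1) / chunk_length) * chunk_length

    assert get_target_prompt_token_count(10) == 75    # fits in one chunk
    assert get_target_prompt_token_count(75) == 75    # exactly one chunk
    assert get_target_prompt_token_count(76) == 150   # spills into a second chunk
    assert get_target_prompt_token_count(100) == 150

    # an empty chunk: 1 start token + 75 prompt positions + 1 end token = 77 ids
    print(1 + chunk_length + 1)  # 77
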
@@ -32,205 +91,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):

         tokenized = self.tokenize([text for text, _ in parsed])

-        fixes = []
-        remade_tokens = []
-        multipliers = []
+        chunks = []
+        chunk = PromptChunk()
+        token_count = 0
         last_comma = -1

-        for tokens, (text, weight) in zip(tokenized, parsed):
-            i = 0
-            while i < len(tokens):
-                token = tokens[i]
-
-                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+        def next_chunk():
+            """puts current chunk into the list of results and produces the next one - empty"""
+            nonlocal token_count
+            nonlocal last_comma
+            nonlocal chunk
+
+            token_count += len(chunk.tokens)
+            to_add = self.chunk_length - len(chunk.tokens)
+            if to_add > 0:
+                chunk.tokens += [self.id_end] * to_add
+                chunk.multipliers += [1.0] * to_add
+
+            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+            last_comma = -1
+            chunks.append(chunk)
+            chunk = PromptChunk()
+
+        for tokens, (text, weight) in zip(tokenized, parsed):
+            position = 0
+            while position < len(tokens):
+                token = tokens[position]

                 if token == self.comma_token:
-                    last_comma = len(remade_tokens)
-                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
-                    last_comma += 1
-                    reloc_tokens = remade_tokens[last_comma:]
-                    reloc_mults = multipliers[last_comma:]
-
-                    remade_tokens = remade_tokens[:last_comma]
-                    length = len(remade_tokens)
-
-                    rem = int(math.ceil(length / 75)) * 75 - length
-                    remade_tokens += [self.id_end] * rem + reloc_tokens
-                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+                    last_comma = len(chunk.tokens)
+
+                # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+                    break_location = last_comma + 1
+
+                    reloc_tokens = chunk.tokens[break_location:]
+                    reloc_mults = chunk.multipliers[break_location:]
+
+                    chunk.tokens = chunk.tokens[:break_location]
+                    chunk.multipliers = chunk.multipliers[:break_location]
+
+                    next_chunk()
+                    chunk.tokens = reloc_tokens
+                    chunk.multipliers = reloc_mults
+
+                if len(chunk.tokens) == self.chunk_length:
+                    next_chunk()
+
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                 if embedding is None:
-                    remade_tokens.append(token)
-                    multipliers.append(weight)
-                    i += 1
-                else:
-                    emb_len = int(embedding.vec.shape[0])
-                    iteration = len(remade_tokens) // 75
-                    if (len(remade_tokens) + emb_len) // 75 != iteration:
-                        rem = (75 * (iteration + 1) - len(remade_tokens))
-                        remade_tokens += [self.id_end] * rem
-                        multipliers += [1.0] * rem
-                        iteration += 1
-                    fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
-                    remade_tokens += [0] * emb_len
-                    multipliers += [weight] * emb_len
-                    used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += embedding_length_in_tokens
-
-        token_count = len(remade_tokens)
-        prompt_target_length = get_target_prompt_token_count(token_count)
-        tokens_to_add = prompt_target_length - len(remade_tokens)
-
-        remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
-        multipliers = multipliers + [1.0] * tokens_to_add
-
-        return remade_tokens, fixes, multipliers, token_count
+                    chunk.tokens.append(token)
+                    chunk.multipliers.append(weight)
+                    position += 1
+                    continue
+
+                emb_len = int(embedding.vec.shape[0])
+                if len(chunk.tokens) + emb_len > self.chunk_length:
+                    next_chunk()
+
+                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+                chunk.tokens += [0] * emb_len
+                chunk.multipliers += [weight] * emb_len
+                position += embedding_length_in_tokens
+
+        if len(chunk.tokens) > 0 or len(chunks) == 0:
+            next_chunk()
+
+        return chunks, token_count
+
+    def process_texts(self, texts):
+        """
+        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+        length, in tokens, of all texts.
+        """

-    def process_text(self, texts):
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
         token_count = 0

         cache = {}
-        batch_multipliers = []
+        batch_chunks = []
         for line in texts:
             if line in cache:
-                remade_tokens, fixes, multipliers = cache[line]
+                chunks = cache[line]
             else:
-                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+                chunks, current_token_count = self.tokenize_line(line)
                 token_count = max(current_token_count, token_count)

-                cache[line] = (remade_tokens, fixes, multipliers)
+                cache[line] = chunks

-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
+            batch_chunks.append(chunks)

-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+        return batch_chunks, token_count

-    def process_text_old(self, texts):
-        id_start = self.id_start
-        id_end = self.id_end
-        maxlen = self.wrapped.max_length # you get to stay at 77
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
-        token_count = 0
-
-        cache = {}
-        batch_tokens = self.tokenize(texts)
-        batch_multipliers = []
-        for tokens in batch_tokens:
-            tuple_tokens = tuple(tokens)
-
-            if tuple_tokens in cache:
-                remade_tokens, fixes, multipliers = cache[tuple_tokens]
-            else:
-                fixes = []
-                remade_tokens = []
-                multipliers = []
-                mult = 1.0
-
-                i = 0
-                while i < len(tokens):
-                    token = tokens[i]
-
-                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
-                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
-                    if mult_change is not None:
-                        mult *= mult_change
-                        i += 1
-                    elif embedding is None:
-                        remade_tokens.append(token)
-                        multipliers.append(mult)
-                        i += 1
-                    else:
-                        emb_len = int(embedding.vec.shape[0])
-                        fixes.append((len(remade_tokens), embedding))
-                        remade_tokens += [0] * emb_len
-                        multipliers += [mult] * emb_len
-                        used_custom_terms.append((embedding.name, embedding.checksum()))
-                        i += embedding_length_in_tokens
-
-                if len(remade_tokens) > maxlen - 2:
-                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
-                    ovf = remade_tokens[maxlen - 2:]
-                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
-                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-                token_count = len(remade_tokens)
-                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
-                remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
-                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
-            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
-            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-    def forward(self, text):
-        use_old = opts.use_old_emphasis_implementation
-        if use_old:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
-        else:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
-        self.hijack.comments += hijack_comments
-
-        if len(used_custom_terms) > 0:
-            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
-        if use_old:
-            self.hijack.fixes = hijack_fixes
-            return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
-        z = None
-        i = 0
-        while max(map(len, remade_batch_tokens)) != 0:
-            rem_tokens = [x[75:] for x in remade_batch_tokens]
-            rem_multipliers = [x[75:] for x in batch_multipliers]
-
-            self.hijack.fixes = []
-            for unfiltered in hijack_fixes:
-                fixes = []
-                for fix in unfiltered:
-                    if fix[0] == i:
-                        fixes.append(fix[1])
-                self.hijack.fixes.append(fixes)
-
-            tokens = []
-            multipliers = []
-            for j in range(len(remade_batch_tokens)):
-                if len(remade_batch_tokens[j]) > 0:
-                    tokens.append(remade_batch_tokens[j][:75])
-                    multipliers.append(batch_multipliers[j][:75])
-                else:
-                    tokens.append([self.id_end] * 75)
-                    multipliers.append([1.0] * 75)
-
-            z1 = self.process_tokens(tokens, multipliers)
-            z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
-            remade_batch_tokens = rem_tokens
-            batch_multipliers = rem_multipliers
-            i += 1
-
-        return z
+    def forward(self, texts):
+        """
+        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        An example shape returned by this function can be: (2, 77, 768).
+        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
+        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+        """
+
+        if opts.use_old_emphasis_implementation:
+            import modules.sd_hijack_clip_old
+            return modules.sd_hijack_clip_old.forward_old(self, texts)
+
+        batch_chunks, token_count = self.process_texts(texts)
+
+        used_embeddings = {}
+        chunk_count = max([len(x) for x in batch_chunks])
+
+        zs = []
+        for i in range(chunk_count):
+            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+            tokens = [x.tokens for x in batch_chunk]
+            multipliers = [x.multipliers for x in batch_chunk]
+            self.hijack.fixes = [x.fixes for x in batch_chunk]
+
+            for fixes in self.hijack.fixes:
+                for position, embedding in fixes:
+                    used_embeddings[embedding.name] = embedding
+
+            z = self.process_tokens(tokens, multipliers)
+            zs.append(z)
+
+        if len(used_embeddings) > 0:
+            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+        return torch.hstack(zs)

     def process_tokens(self, remade_batch_tokens, batch_multipliers):
-        if not opts.use_old_emphasis_implementation:
-            remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
-            batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+        """
+        sends one single prompt chunk to be encoded by transformers neural network.
+        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+        corresponds to one token.
+        """
         tokens = torch.asarray(remade_batch_tokens).to(devices.device)

+        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
         if self.id_end != self.id_pad:
             for batch_pos in range(len(remade_batch_tokens)):
                 index = remade_batch_tokens[batch_pos].index(self.id_end)
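To make the shape contract of the new `forward` concrete: every prompt becomes a whole number of 77-token chunks, each chunk is encoded separately through `process_tokens`, and the per-chunk tensors are stacked along the token axis by `torch.hstack`. A small illustration of the bookkeeping, pure arithmetic with no model involved; `chunks_needed` and `output_shape` are illustrative helpers, not webui functions:

    import math

    chunk_length = 75  # prompt tokens per chunk; +2 for the start/end ids makes 77

    def chunks_needed(token_count: int) -> int:
        return max(1, math.ceil(max(token_count, 1) / chunk_length))

    def output_shape(batch_size: int, token_count: int, hidden_dim: int = 768):
        # hidden_dim is 768 for SD1-style CLIP, 1024 for SD2's OpenCLIP (per the docstring above)
        n = chunks_needed(token_count)
        return (batch_size, n * 77, hidden_dim)

    print(output_shape(1, 20))         # (1, 77, 768)   - short prompt, one chunk
    print(output_shape(1, 100))        # (1, 154, 768)  - spills into a second chunk
    print(output_shape(2, 100, 1024))  # (2, 154, 1024) - prompt-editing batch on SD2
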
@@ -239,8 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         z = self.encode_with_transformers(tokens)

         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
-        batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
-        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         original_mean = z.mean()
         z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()

modules/sd_hijack_clip_old.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+    id_start = self.id_start
+    id_end = self.id_end
+    maxlen = self.wrapped.max_length # you get to stay at 77
+    used_custom_terms = []
+    remade_batch_tokens = []
+    hijack_comments = []
+    hijack_fixes = []
+    token_count = 0
+
+    cache = {}
+    batch_tokens = self.tokenize(texts)
+    batch_multipliers = []
+    for tokens in batch_tokens:
+        tuple_tokens = tuple(tokens)
+
+        if tuple_tokens in cache:
+            remade_tokens, fixes, multipliers = cache[tuple_tokens]
+        else:
+            fixes = []
+            remade_tokens = []
+            multipliers = []
+            mult = 1.0
+
+            i = 0
+            while i < len(tokens):
+                token = tokens[i]
+
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+                if mult_change is not None:
+                    mult *= mult_change
+                    i += 1
+                elif embedding is None:
+                    remade_tokens.append(token)
+                    multipliers.append(mult)
+                    i += 1
+                else:
+                    emb_len = int(embedding.vec.shape[0])
+                    fixes.append((len(remade_tokens), embedding))
+                    remade_tokens += [0] * emb_len
+                    multipliers += [mult] * emb_len
+                    used_custom_terms.append((embedding.name, embedding.checksum()))
+                    i += embedding_length_in_tokens
+
+            if len(remade_tokens) > maxlen - 2:
+                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                ovf = remade_tokens[maxlen - 2:]
+                overflowing_words = [vocab.get(int(x), "") for x in ovf]
+                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+            token_count = len(remade_tokens)
+            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+        remade_batch_tokens.append(remade_tokens)
+        hijack_fixes.append(fixes)
+        batch_multipliers.append(multipliers)
+    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+    self.hijack.comments += hijack_comments
+
+    if len(used_custom_terms) > 0:
+        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+    self.hijack.fixes = hijack_fixes
+    return self.process_tokens(remade_batch_tokens, batch_multipliers)
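The `token_mults` lookup in `process_text_old` is what implements the old emphasis syntax: tokens that open a parenthesis scale the running multiplier up, brackets scale it down, and the multiplier is recorded for every prompt token emitted while it is active. A toy character-level illustration of that accumulation; the 1.1 and 0.9 factors are the conventional webui emphasis strengths, stated here as an assumption rather than read from this diff:

    # Toy model of emphasis accumulation: '(' multiplies emphasis by 1.1,
    # ')' divides it back, '[' and ']' do the same with 0.9 (assumed factors).
    def emphasis_multipliers(prompt: str):
        mult = 1.0
        result = []
        for ch in prompt:
            if ch == '(':
                mult *= 1.1
            elif ch == ')':
                mult /= 1.1
            elif ch == '[':
                mult *= 0.9
            elif ch == ']':
                mult /= 0.9
            else:
                result.append((ch, round(mult, 3)))
        return result

    print(emphasis_multipliers("a ((cat)) in a [box]"))
    # characters inside '((...))' carry ~1.21, inside '[...]' ~0.9, the rest 1.0
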
@@ -1,7 +1,7 @@
 import math
 import sys
 import traceback
-import importlib
+import psutil

 import torch
 from torch import einsum

@@ -12,6 +12,8 @@ from einops import rearrange
 from modules import shared
 from modules.hypernetworks import hypernetwork

+from .sub_quadratic_attention import efficient_dot_product_attention
+

 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:

@@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
         print(traceback.format_exc(), file=sys.stderr)


+def get_available_vram():
+    if shared.device.type == 'cuda':
+        stats = torch.cuda.memory_stats(shared.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        return mem_free_total
+    else:
+        return psutil.virtual_memory().available
+
+
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     h = self.heads

@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()

     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()

@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     return self.to_out(r2)


-def check_for_psutil():
-    try:
-        spec = importlib.util.find_spec('psutil')
-        return spec is not None
-    except ModuleNotFoundError:
-        return False
-
-invokeAI_mps_available = check_for_psutil()
-
 # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-if invokeAI_mps_available:
-    import psutil
-    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)

 def einsum_op_compvis(q, k, v):
     s = einsum('b i d, b j d -> b i j', q, k)

@@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

 # -- End of code from https://github.com/invoke-ai/InvokeAI --

+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+    out_proj, dropout = self.to_out
+    x = out_proj(x)
+    x = dropout(x)
+
+    return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+    bytes_per_token = torch.finfo(q.dtype).bits//8
+    batch_x_heads, q_tokens, _ = q.shape
+    _, k_tokens, _ = k.shape
+    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+    if chunk_threshold is None:
+        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+    elif chunk_threshold == 0:
+        chunk_threshold_bytes = None
+    else:
+        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+    elif kv_chunk_size_min == 0:
+        kv_chunk_size_min = None
+
+    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+        # the big matmul fits into our memory limit; do everything in 1 chunk,
+        # i.e. send it down the unchunked fast-path
+        query_chunk_size = q_tokens
+        kv_chunk_size = k_tokens
+
+    return efficient_dot_product_attention(
+        q,
+        k,
+        v,
+        query_chunk_size=q_chunk_size,
+        kv_chunk_size=kv_chunk_size,
+        kv_chunk_size_min = kv_chunk_size_min,
+        use_checkpoint=use_checkpoint,
+    )
+
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)

@@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):

     h_ = torch.zeros_like(k, device=q.device)

-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()

     tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
     mem_required = tensor_size * 2.5

@@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
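To see how the chunking decision in `sub_quad_attention` plays out numerically, take a fairly typical case: fp16 tensors (2 bytes per element), 8 heads at batch size 2 (`batch_x_heads = 16`), and 4096 query and key tokens for a 512x512 latent. A rough back-of-the-envelope check of the fast-path condition; the free-VRAM figure is an assumed example, not a measured value:

    # Rough arithmetic behind the "does the full QK^T matmul fit?" check above.
    bytes_per_token = 2                      # fp16
    batch_x_heads = 16                       # batch 2 * 8 heads (example values)
    q_tokens = k_tokens = 64 * 64            # 4096 tokens for a 512x512 image latent

    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
    print(qk_matmul_size_bytes / 1024**2)    # ~512 MiB for the attention score matrix

    free_vram = 8 * 1024**3                  # assume 8 GiB reported free
    chunk_threshold_bytes = int(free_vram * 0.7)          # default CUDA threshold in the code above
    print(qk_matmul_size_bytes <= chunk_threshold_bytes)  # True -> unchunked fast path
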
@@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
 parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
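Aside, not part of the diff: the four new flags surface on the parsed options with dashes mapped to underscores. A throwaway sketch with a standalone parser (flag names copied from the diff; sample values invented):

```python
# Illustrative only: a scratch parser mirroring the four new sub-quad flags.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--opt-sub-quad-attention", action='store_true')
parser.add_argument("--sub-quad-q-chunk-size", type=int, default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, default=None)

args = parser.parse_args("--opt-sub-quad-attention --sub-quad-chunk-threshold 80".split())
print(args.opt_sub_quad_attention)    # True
print(args.sub_quad_q_chunk_size)     # 1024 (default query chunk size)
print(args.sub_quad_chunk_threshold)  # 80 (the web UI treats this as a VRAM percentage)
```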
@@ -362,6 +366,7 @@ options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
     "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
     "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
+    "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),

@@ -429,7 +434,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
     "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
     'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
-    'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
+    'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
     'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
 }))

@@ -576,6 +581,7 @@ latent_upscale_modes = {
     "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
     "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
     "Latent (nearest)": {"mode": "nearest", "antialias": False},
+    "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
 }

 sd_upscalers = []
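Aside, not part of the diff: the new table entry corresponds to torch.nn.functional.interpolate's "nearest-exact" mode, available in recent PyTorch releases. A minimal sketch with an arbitrary latent shape:

```python
# Illustrative only: upscaling a latent with the "nearest-exact" resampling mode.
import torch
import torch.nn.functional as F

latent = torch.randn(1, 4, 64, 64)       # latent for a 512x512 image (hypothetical shape)
upscaled = F.interpolate(latent, scale_factor=2, mode="nearest-exact")
print(upscaled.shape)                    # torch.Size([1, 4, 128, 128])
```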
modules/sub_quadratic_attention.py (new file, 205 lines)

@@ -0,0 +1,205 @@
+# original source:
+#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+#   MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
+# credit:
+#   Amin Rezaei (original author)
+#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+#   brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
+# implementation of:
+#   "Self-attention Does Not Need O(n2) Memory":
+#   https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+from typing import Optional, NamedTuple, Protocol, List
+
+
+def narrow_trunc(
+    input: Tensor,
+    dim: int,
+    start: int,
+    length: int
+) -> Tensor:
+    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
+
+class AttnChunk(NamedTuple):
+    exp_values: Tensor
+    exp_weights_sum: Tensor
+    max_score: Tensor
+
+
+class SummarizeChunk(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+    ) -> AttnChunk: ...
+
+
+class ComputeQueryChunkAttn(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+    ) -> Tensor: ...
+
+
+def _summarize_chunk(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    scale: float,
+) -> AttnChunk:
+    attn_weights = torch.baddbmm(
+        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        query,
+        key.transpose(1,2),
+        alpha=scale,
+        beta=0,
+    )
+    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+    max_score = max_score.detach()
+    exp_weights = torch.exp(attn_weights - max_score)
+    exp_values = torch.bmm(exp_weights, value)
+    max_score = max_score.squeeze(-1)
+    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
+
+def _query_chunk_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    summarize_chunk: SummarizeChunk,
+    kv_chunk_size: int,
+) -> Tensor:
+    batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    _, _, v_channels_per_head = value.shape
+
+    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+        key_chunk = narrow_trunc(
+            key,
+            1,
+            chunk_idx,
+            kv_chunk_size
+        )
+        value_chunk = narrow_trunc(
+            value,
+            1,
+            chunk_idx,
+            kv_chunk_size
+        )
+        return summarize_chunk(query, key_chunk, value_chunk)
+
+    chunks: List[AttnChunk] = [
+        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+    ]
+    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+    chunk_values, chunk_weights, chunk_max = acc_chunk
+
+    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+    max_diffs = torch.exp(chunk_max - global_max)
+    chunk_values *= torch.unsqueeze(max_diffs, -1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(dim=0)
+    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+    return all_values / all_weights
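Aside, not part of the diff: the global_max / max_diffs rescaling in _query_chunk_attention implements the exact chunk-combination identity from "Self-attention Does Not Need O(n2) Memory" (arXiv:2112.05682), sketched here in the paper's notation:

```latex
% For one query row q and key/value chunk C_j, with scores s_i = \text{scale}\cdot q^\top k_i:
m_j = \max_{i \in C_j} s_i, \qquad
v_j = \sum_{i \in C_j} e^{s_i - m_j} V_i, \qquad
w_j = \sum_{i \in C_j} e^{s_i - m_j}.
% With the global maximum m = \max_j m_j, the chunks combine into the exact softmax attention:
\mathrm{Attn}(q, K, V) = \frac{\sum_j e^{m_j - m}\, v_j}{\sum_j e^{m_j - m}\, w_j},
% so each chunk only needs memory proportional to its own size, never the full n x n score matrix.
```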
+# TODO: refactor CrossAttention#get_attention_scores to share code with this
+def _get_attention_scores_no_kv_chunking(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    scale: float,
+) -> Tensor:
+    attn_scores = torch.baddbmm(
+        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        query,
+        key.transpose(1,2),
+        alpha=scale,
+        beta=0,
+    )
+    attn_probs = attn_scores.softmax(dim=-1)
+    del attn_scores
+    hidden_states_slice = torch.bmm(attn_probs, value)
+    return hidden_states_slice
+
+
+class ScannedChunk(NamedTuple):
+    chunk_idx: int
+    attn_chunk: AttnChunk
+
+
+def efficient_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    query_chunk_size=1024,
+    kv_chunk_size: Optional[int] = None,
+    kv_chunk_size_min: Optional[int] = None,
+    use_checkpoint=True,
+):
+    """Computes efficient dot-product attention given query, key, and value.
+      This is efficient version of attention presented in
+      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
+      Args:
+        query: queries for calculating attention with shape of
+          `[batch * num_heads, tokens, channels_per_head]`.
+        key: keys for calculating attention with shape of
+          `[batch * num_heads, tokens, channels_per_head]`.
+        value: values to be used in attention with shape of
+          `[batch * num_heads, tokens, channels_per_head]`.
+        query_chunk_size: int: query chunks size
+        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
+        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+      Returns:
+        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+      """
+    batch_x_heads, q_tokens, q_channels_per_head = query.shape
+    _, k_tokens, _ = key.shape
+    scale = q_channels_per_head ** -0.5
+
+    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+    if kv_chunk_size_min is not None:
+        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
+
+    def get_query_chunk(chunk_idx: int) -> Tensor:
+        return narrow_trunc(
+            query,
+            1,
+            chunk_idx,
+            min(query_chunk_size, q_tokens)
+        )
+
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
+        _get_attention_scores_no_kv_chunking,
+        scale=scale
+    ) if k_tokens <= kv_chunk_size else (
+        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+        partial(
+            _query_chunk_attention,
+            kv_chunk_size=kv_chunk_size,
+            summarize_chunk=summarize_chunk,
+        )
+    )
+
+    if q_tokens <= query_chunk_size:
+        # fast-path for when there's just 1 query chunk
+        return compute_query_chunk_attn(
+            query=query,
+            key=key,
+            value=value,
+        )
+
+    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
+    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
+    res = torch.cat([
+        compute_query_chunk_attn(
+            query=get_query_chunk(i * query_chunk_size),
+            key=key,
+            value=value,
+        ) for i in range(math.ceil(q_tokens / query_chunk_size))
+    ], dim=1)
+    return res
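Aside, not part of the diff: a standalone usage sketch of the function above; shapes and chunk sizes are arbitrary, and the import assumes the web UI repository is on sys.path:

```python
# Illustrative only: chunked attention over [batch * heads, tokens, channels_per_head] tensors.
import torch
from modules.sub_quadratic_attention import efficient_dot_product_attention

q = torch.randn(16, 4096, 40)   # batch*heads, query tokens, channels per head
k = torch.randn(16, 4096, 40)
v = torch.randn(16, 4096, 40)

out = efficient_dot_product_attention(
    q, k, v,
    query_chunk_size=1024,      # 1024 queries per chunk
    kv_chunk_size=None,         # default: roughly sqrt(key tokens) per key/value chunk
    use_checkpoint=False,       # inference; True trades recomputation for memory during training
)
print(out.shape)                # torch.Size([16, 4096, 40])
```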
modules/textual_inversion/logging.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+import datetime
+import json
+import os
+
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
+saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
+saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+
+
+def save_settings_to_file(log_directory, all_params):
+    now = datetime.datetime.now()
+    params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
+
+    keys = saved_params_all
+    if all_params.get('preview_from_txt2img'):
+        keys = keys | saved_params_previews
+
+    params.update({k: v for k, v in all_params.items() if k in keys})
+
+    filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
+    with open(os.path.join(log_directory, filename), "w") as file:
+        json.dump(params, file, indent=4)
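Aside, not part of the diff: a quick standalone sketch of the helper above; the directory and parameter values are invented, and the import assumes the web UI repository is importable:

```python
# Illustrative only: writes settings-<timestamp>.json into log_dir, keeping only whitelisted keys.
import json
import os
import tempfile

from modules.textual_inversion.logging import save_settings_to_file

log_dir = tempfile.mkdtemp()
save_settings_to_file(log_dir, {
    "model_name": "example-checkpoint",   # whitelisted -> saved
    "learn_rate": 0.005,                  # whitelisted -> saved
    "some_local_variable": object(),      # not whitelisted -> silently dropped
})

saved = os.path.join(log_dir, os.listdir(log_dir)[0])
print(json.load(open(saved)))             # {'datetime': '...', 'model_name': 'example-checkpoint', 'learn_rate': 0.005}
```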
@@ -1,6 +1,7 @@
 import os
 import sys
 import traceback
+import inspect

 import torch
 import tqdm

@@ -17,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
                                                        insert_image_data_embed, extract_image_data_embed,
                                                        caption_image_overlay)
+from modules.textual_inversion.logging import save_settings_to_file
+

 class Embedding:
     def __init__(self, vec, name, step=None):

@@ -76,7 +79,6 @@ class EmbeddingDatabase:
         self.word_embeddings[embedding.name] = embedding

-        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
         ids = model.cond_stage_model.tokenize([embedding.name])[0]

         first_id = ids[0]

@@ -149,19 +151,20 @@ class EmbeddingDatabase:
             else:
                 self.skipped_embeddings[name] = embedding

-        for fn in os.listdir(self.embeddings_dir):
-            try:
-                fullfn = os.path.join(self.embeddings_dir, fn)
-
-                if os.stat(fullfn).st_size == 0:
-                    continue
-
-                process_file(fullfn, fn)
-            except Exception:
-                print(f"Error loading embedding {fn}:", file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
-                continue
+        for root, dirs, fns in os.walk(self.embeddings_dir):
+            for fn in fns:
+                try:
+                    fullfn = os.path.join(root, fn)
+
+                    if os.stat(fullfn).st_size == 0:
+                        continue
+
+                    process_file(fullfn, fn)
+                except Exception:
+                    print(f"Error loading embedding {fn}:", file=sys.stderr)
+                    print(traceback.format_exc(), file=sys.stderr)
+                    continue

         print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
         if len(self.skipped_embeddings) > 0:
             print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
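Aside, not part of the diff: the switch from os.listdir to os.walk means embeddings in subfolders of the embeddings directory are now picked up. A tiny standalone sketch with a made-up layout:

```python
# Illustrative only: os.listdir sees only the top level, os.walk descends into subfolders.
import os
import tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "styles"))
open(os.path.join(root, "cat.pt"), "w").close()
open(os.path.join(root, "styles", "oilpaint.pt"), "w").close()

print(sorted(os.listdir(root)))                              # ['cat.pt', 'styles'] -- nested file missed
found = [f for _, _, files in os.walk(root) for f in files]
print(sorted(found))                                         # ['cat.pt', 'oilpaint.pt']
```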
@@ -229,6 +232,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         **values,
     })


 def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
     assert model_name, f"{name} not selected"
     assert learn_rate, "Learning rate is empty or 0"

@@ -292,8 +296,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     if initial_step >= steps:
         shared.state.textinfo = "Model has already been trained beyond specified max steps"
         return embedding, filename
-    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

+    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
     clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
         torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
         None

@@ -307,6 +311,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_

     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+
+    if shared.opts.save_training_settings_to_txt:
+        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})

     latent_sampling_method = ds.latent_sampling_method

     dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

@@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

 from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
-from modules.ui_components import FormRow, FormGroup, ToolButton
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path

 from modules.shared import opts, cmd_opts, restricted_opts

@@ -256,6 +256,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
     return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]


+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+    from modules import processing, devices
+
+    if not enable:
+        return ""
+
+    p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+    with devices.autocast():
+        p.init([""], [0], [0])
+
+    return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+
+
 def apply_styles(prompt, prompt_neg, style1_name, style2_name):
     prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
     prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
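Aside, not part of the diff: calc_resolution_hires reuses the real processing object to obtain p.hr_upscale_to_x/y. As a simplified stand-in for the common case where only the "Upscale by" factor is set (hr_resize_x and hr_resize_y both 0), the target is roughly width and height scaled by hr_scale; the real implementation also handles explicit resize targets and rounding:

```python
# Illustrative only: simplified approximation of the hires-fix target resolution.
def approx_hires_target(width: int, height: int, hr_scale: float) -> tuple:
    return int(width * hr_scale), int(height * hr_scale)

print(approx_hires_target(512, 512, 2.0))   # (1024, 1024) -> shown as "resize: from 512x512 to 1024x1024"
```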
@@ -368,7 +382,7 @@ def update_token_counter(text, steps):

     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
     style_class = ' class="red"' if (token_count > max_length) else ""
     return f"<span {style_class}>{token_count}/{max_length}</span>"

@@ -435,11 +449,9 @@ def create_toprow(is_img2img):
             with gr.Row():
                 with gr.Column(scale=1, elem_id="style_pos_col"):
                     prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-                    prompt_style.save_to_config = True

                 with gr.Column(scale=1, elem_id="style_neg_col"):
                     prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-                    prompt_style2.save_to_config = True

     return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button

@@ -550,6 +562,8 @@ Requested path was: {f}
         os.startfile(path)
     elif platform.system() == "Darwin":
         sp.Popen(["open", path])
+    elif "microsoft-standard-WSL2" in platform.uname().release:
+        sp.Popen(["wsl-open", path])
     else:
         sp.Popen(["xdg-open", path])

@@ -636,7 +650,6 @@ def create_sampler_and_steps_selection(choices, tabname):
     if opts.samplers_in_dropdown:
         with FormRow(elem_id=f"sampler_selection_{tabname}"):
             sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-            sampler_index.save_to_config = True
             steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
     else:
         with FormGroup(elem_id=f"sampler_selection_{tabname}"):

@@ -707,6 +720,7 @@ def create_ui():
                         restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
                         tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
                         enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+                        hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)

                 elif category == "hires_fix":
                     with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:

@@ -730,6 +744,17 @@ def create_ui():
             with FormGroup(elem_id="txt2img_script_container"):
                 custom_inputs = modules.scripts.scripts_txt2img.setup_ui()

+            hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+            hr_resolution_preview_args = dict(
+                fn=calc_resolution_hires,
+                inputs=hr_resolution_preview_inputs,
+                outputs=[hr_final_resolution],
+                show_progress=False
+            )
+
+            for input in hr_resolution_preview_inputs:
+                input.change(**hr_resolution_preview_args)
+
         txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
         parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
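Aside, not part of the diff: the block above wires every hires-related input to the same Gradio change event so the preview label refreshes live. A self-contained sketch of that pattern in a plain Gradio 3.x app (component names invented):

```python
# Illustrative only: several inputs sharing one .change() handler that updates an HTML preview.
import gradio as gr

def preview(width, height):
    return f"<span>{int(width)}x{int(height)}</span>"

with gr.Blocks() as demo:
    width = gr.Slider(64, 2048, value=512, label="Width")
    height = gr.Slider(64, 2048, value=512, label="Height")
    out = gr.HTML()
    for component in (width, height):
        component.change(fn=preview, inputs=[width, height], outputs=[out], show_progress=False)

# demo.launch()  # uncomment to try it locally
```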
@@ -791,6 +816,7 @@ def create_ui():
             fn=lambda x: gr_show(x),
             inputs=[enable_hr],
             outputs=[hr_options],
+            show_progress = False,
         )

         txt2img_paste_fields = [

@@ -1792,7 +1818,7 @@ def create_ui():
            if init_field is not None:
                init_field(saved_value)

-        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
+        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
            apply_field(x, 'visible')

        if type(x) == gr.Slider:

@@ -1813,11 +1839,8 @@ def create_ui():
        if type(x) == gr.Number:
            apply_field(x, 'value')

-       # Since there are many dropdowns that shouldn't be saved,
-       # we only mark dropdowns that should be saved.
-       if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
+       if type(x) == gr.Dropdown:
            apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
-           apply_field(x, 'visible')

    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")

@@ -23,3 +23,11 @@ class FormGroup(gr.Group, gr.components.FormComponent):

     def get_block_name(self):
         return "group"
+
+
+class FormHTML(gr.HTML, gr.components.FormComponent):
+    """Same as gr.HTML but fits inside gradio forms"""
+
+    def get_block_name(self):
+        return "html"
+

@@ -162,15 +162,15 @@ def install_extension_from_url(dirname, url):
         shutil.rmtree(tmpdir, True)


-def install_extension_from_index(url, hide_tags):
+def install_extension_from_index(url, hide_tags, sort_column):
     ext_table, message = install_extension_from_url(None, url)

-    code, _ = refresh_available_extensions_from_data(hide_tags)
+    code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)

     return code, ext_table, message


-def refresh_available_extensions(url, hide_tags):
+def refresh_available_extensions(url, hide_tags, sort_column):
     global available_extensions

     import urllib.request

@@ -179,18 +179,28 @@ def refresh_available_extensions(url, hide_tags):

     available_extensions = json.loads(text)

-    code, tags = refresh_available_extensions_from_data(hide_tags)
+    code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)

     return url, code, gr.CheckboxGroup.update(choices=tags), ''


-def refresh_available_extensions_for_tags(hide_tags):
-    code, _ = refresh_available_extensions_from_data(hide_tags)
+def refresh_available_extensions_for_tags(hide_tags, sort_column):
+    code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)

     return code, ''


-def refresh_available_extensions_from_data(hide_tags):
+sort_ordering = [
+    # (reverse, order_by_function)
+    (True, lambda x: x.get('added', 'z')),
+    (False, lambda x: x.get('added', 'z')),
+    (False, lambda x: x.get('name', 'z')),
+    (True, lambda x: x.get('name', 'z')),
+    (False, lambda x: 'z'),
+]
+
+
+def refresh_available_extensions_from_data(hide_tags, sort_column):
     extlist = available_extensions["extensions"]
     installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
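Aside, not part of the diff: each entry in sort_ordering is a (reverse, key) pair fed straight to sorted(). A standalone sketch with made-up extension entries showing the "newest first" ordering:

```python
# Illustrative only: how one (reverse, key) pair from sort_ordering orders the listing.
extlist = [
    {"name": "alpha-extension", "added": "2022-11-04"},
    {"name": "beta-extension", "added": "2023-01-02"},
    {"name": "undated-extension"},                     # no 'added' field -> sorts as 'z'
]

sort_reverse, sort_function = True, lambda x: x.get('added', 'z')   # "newest first"
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
    print(ext["name"])
# undated-extension   ('z' outranks any date string when reversed)
# beta-extension
# alpha-extension
```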
@@ -210,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags):
     <tbody>
 """

-    for ext in extlist:
+    sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
+
+    for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
         name = ext.get("name", "noname")
+        added = ext.get('added', 'unknown')
         url = ext.get("url", None)
         description = ext.get("description", "")
         extension_tags = ext.get("tags", [])

@@ -233,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags):
         code += f"""
             <tr>
                 <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
-                <td>{html.escape(description)}</td>
+                <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
                 <td>{install_code}</td>
             </tr>

@@ -291,25 +304,32 @@ def create_ui():

             with gr.Row():
                 hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
+                sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")

             install_result = gr.HTML()
             available_extensions_table = gr.HTML()

             refresh_available_extensions_button.click(
                 fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
-                inputs=[available_extensions_index, hide_tags],
+                inputs=[available_extensions_index, hide_tags, sort_column],
                 outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
             )

             install_extension_button.click(
                 fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
-                inputs=[extension_to_install, hide_tags],
+                inputs=[extension_to_install, hide_tags, sort_column],
                 outputs=[available_extensions_table, extensions_table, install_result],
             )

             hide_tags.change(
                 fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
-                inputs=[hide_tags],
+                inputs=[hide_tags, sort_column],
+                outputs=[available_extensions_table, install_result]
+            )
+
+            sort_column.change(
+                fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+                inputs=[hide_tags, sort_column],
                 outputs=[available_extensions_table, install_result]
             )

@@ -30,4 +30,4 @@ inflection
 GitPython
 torchsde
 safetensors
-psutil; sys_platform == 'darwin'
+psutil

screenshot.png (BIN): binary file not shown. Before: 513 KiB, after: 411 KiB.
style.css (28 changed lines)

@@ -555,7 +555,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h

 /* Extensions */

 #tab_extensions table{
     border-collapse: collapse;
 }

@@ -581,6 +581,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
     font-size: 95%;
 }

+#available_extensions .info{
+    margin: 0;
+}
+
+#available_extensions .date_added{
+    opacity: 0.85;
+    font-size: 90%;
+}
+
 #image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
     min-width: auto;
     padding-left: 0.5em;

@@ -633,6 +642,23 @@ footer {
     opacity: 0.85;
 }

+#txtimg_hr_finalres{
+    min-height: 0 !important;
+    padding: .625rem .75rem;
+    margin-left: -0.75em
+
+}
+
+#txtimg_hr_finalres .resolution{
+    font-weight: bold;
+}
+
+#txt2img_checkboxes > div > div{
+    flex: 0;
+    white-space: nowrap;
+    min-width: auto;
+}
+
 /* The following handles localization for right-to-left (RTL) languages like Arabic.
 The rtl media type will only be activated by the logic in javascript/localization.js.
 If you change anything above, you need to make sure it is RTL compliant by just running
Binary file not shown. Before: 329 KiB (no after version).

webui.py (9 changed lines)

@@ -4,7 +4,7 @@ import threading
 import time
 import importlib
 import signal
-import threading
+import re
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware

@@ -13,6 +13,11 @@ from modules import import_hook, errors
 from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path

+import torch
+# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+if ".dev" in torch.__version__ or "+git" in torch.__version__:
+    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
+
 from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
 import modules.codeformer_model as codeformer
 import modules.extras
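Aside, not part of the diff: a quick check of what that regular expression keeps, with invented version strings:

```python
# Illustrative only: the regex keeps the leading dotted-numeric part of a PyTorch version string.
import re

for version in ("1.14.0.dev20221215+cu117", "2.0.0a0+gitabcdef"):
    print(re.search(r'[\d.]+[\d]', version).group(0))
# 1.14.0
# 2.0.0
```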
@@ -182,12 +187,14 @@ def webui():

         sd_samplers.set_samplers()

+        modules.script_callbacks.script_unloaded_callback()
         extensions.list_extensions()

         localization.list_localizations(cmd_opts.localizations_dir)

         modelloader.forbid_loaded_nonbuiltin_upscalers()
         modules.scripts.reload_scripts()
+        modules.script_callbacks.model_loaded_callback(shared.sd_model)
         modelloader.load_upscalers()

         for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: