Merge branch 'master' into training-help-text

commit 0c5522ea21

21 changed files with 688 additions and 137 deletions
32  .github/ISSUE_TEMPLATE/bug_report.md  vendored

@@ -1,32 +0,0 @@
----
-name: Bug report
-about: Create a report to help us improve
-title: ''
-labels: bug-report
-assignees: ''
-
----
-
-**Describe the bug**
-A clear and concise description of what the bug is.
-
-**To Reproduce**
-Steps to reproduce the behavior:
-1. Go to '...'
-2. Click on '....'
-3. Scroll down to '....'
-4. See error
-
-**Expected behavior**
-A clear and concise description of what you expected to happen.
-
-**Screenshots**
-If applicable, add screenshots to help explain your problem.
-
-**Desktop (please complete the following information):**
- - OS: [e.g. Windows, Linux]
- - Browser [e.g. chrome, safari]
- - Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
-
-**Additional context**
-Add any other context about the problem here.
81  .github/ISSUE_TEMPLATE/bug_report.yml  vendored  Normal file

@@ -0,0 +1,81 @@
+name: Bug Report
+description: You think something is broken in the UI
+title: "[Bug]: "
+labels: ["bug-report"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible; don't forget to fill in "What OS..." and "What browsers", and provide screenshots if possible*
+  - type: textarea
+    id: what-did
+    attributes:
+      label: What happened?
+      description: Tell us what happened in a very clear and simple way
+    validations:
+      required: true
+  - type: textarea
+    id: steps
+    attributes:
+      label: Steps to reproduce the problem
+      description: Please provide us with precise step by step information on how to reproduce the bug
+      value: |
+        1. Go to ....
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: what-should
+    attributes:
+      label: What should have happened?
+      description: Tell us what you think the normal behavior should be
+    validations:
+      required: true
+  - type: input
+    id: commit
+    attributes:
+      label: Commit where the problem happens
+      description: Which commit are you running? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+  - type: dropdown
+    id: platforms
+    attributes:
+      label: What platforms do you use to access the UI?
+      multiple: true
+      options:
+        - Windows
+        - Linux
+        - MacOS
+        - iOS
+        - Android
+        - Other/Cloud
+  - type: dropdown
+    id: browsers
+    attributes:
+      label: What browsers do you use to access the UI?
+      multiple: true
+      options:
+        - Mozilla Firefox
+        - Google Chrome
+        - Brave
+        - Apple Safari
+        - Microsoft Edge
+  - type: textarea
+    id: cmdargs
+    attributes:
+      label: Command Line Arguments
+      description: Are you using any launching parameters/command line arguments (modified webui-user.py)? If yes, please write them below
+      render: Shell
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information, context and logs
+      description: Please provide us with any relevant additional info, context or log output.
20  .github/ISSUE_TEMPLATE/feature_request.md  vendored

@@ -1,20 +0,0 @@
----
-name: Feature request
-about: Suggest an idea for this project
-title: ''
-labels: 'suggestion'
-assignees: ''
-
----
-
-**Is your feature request related to a problem? Please describe.**
-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
-
-**Describe the solution you'd like**
-A clear and concise description of what you want to happen.
-
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
-**Additional context**
-Add any other context or screenshots about the feature request here.
40  .github/ISSUE_TEMPLATE/feature_request.yml  vendored  Normal file

@@ -0,0 +1,40 @@
+name: Feature request
+description: Suggest an idea for this project
+title: "[Feature Request]: "
+labels: ["suggestion"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
+  - type: textarea
+    id: feature
+    attributes:
+      label: What would your feature do?
+      description: Tell us about your feature in a very clear and simple way, and what problem it would solve
+    validations:
+      required: true
+  - type: textarea
+    id: workflow
+    attributes:
+      label: Proposed workflow
+      description: Please provide us with step by step information on how you'd like the feature to be accessed and used
+      value: |
+        1. Go to ....
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information
+      description: Add any other context or screenshots about the feature request here.
1  .gitignore  vendored

@@ -27,3 +27,4 @@ __pycache__
 notification.mp3
 /SwinIR
 /textual_inversion
+.vscode
README.md

@@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
 - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
 - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
 
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
javascript/hints.js

@@ -91,6 +91,8 @@ titles = {
 
     "Weighted sum": "Result = A * (1 - M) + B * M",
     "Add difference": "Result = A + (B - C) * M",
+
+    "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n  rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
 }
 
 
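The schedule syntax this hint describes is handled by the repo's `LearnRateScheduler` (imported in a later hunk from `modules/textual_inversion/learn_schedule.py`). As a rough, hypothetical illustration of what that syntax means — not the webui's actual parser — a schedule string can be split into `(rate, end_step)` pairs like this:

```python
# Hypothetical sketch of parsing the "rate:max_steps" syntax from the hint above.
def parse_learn_rate_schedule(text: str, total_steps: int):
    pairs = []
    for chunk in text.split(","):
        rate, _, end = chunk.strip().partition(":")
        # A missing ":max_steps" means "use this rate for all remaining steps".
        pairs.append((float(rate), int(end) if end else total_steps))
    return pairs

assert parse_learn_rate_schedule("0.005:100, 1e-3:1000, 1e-5", 3000) == [
    (0.005, 100), (0.001, 1000), (1e-05, 3000),
]
```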
10  launch.py

@@ -156,9 +156,15 @@ def prepare_enviroment():
     if not is_installed("clip"):
         run_pip(f"install {clip_package}", "clip")
 
-    if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
+    if (not is_installed("xformers") or reinstall_xformers) and xformers:
         if platform.system() == "Windows":
-            run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+            if platform.python_version().startswith("3.10"):
+                run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+            else:
+                print("Installation of xformers is not supported in this version of Python.")
+                print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
+                if not is_installed("xformers"):
+                    exit(0)
         elif platform.system() == "Linux":
             run_pip("install xformers", "xformers")
 
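For readers outside the repo: `is_installed` and `run_pip` are helpers defined elsewhere in launch.py. The sketch below is an assumption about their shape, not the file's actual code — it only shows the general pattern of probing for a module and shelling out to pip when it is missing:

```python
# Assumed sketch of the launch.py-style helpers used in the hunk above.
import importlib.util
import subprocess
import sys

def is_installed(package: str) -> bool:
    # find_spec returns None when the module cannot be imported.
    return importlib.util.find_spec(package) is not None

def run_pip(args: str, desc: str):
    print(f"Installing {desc}")
    subprocess.run(f'"{sys.executable}" -m pip {args}', shell=True, check=True)
```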
modules/extras.py

@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
 
         if input_dir == '':
             return outputs, "Please select an input directory.", ''
-        image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+        image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
         for img in image_list:
-            image = Image.open(img)
+            try:
+                image = Image.open(img)
+            except Exception:
+                continue
             imageArr.append(image)
             imageNameArr.append(img)
     else:

@@ -119,9 +122,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         while len(cached_images) > 2:
             del cached_images[next(iter(cached_images.keys()))]
 
-        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
-                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
-                          forced_filename=image_name if opts.use_original_name_batch else None)
+        if opts.use_original_name_batch and image_name != None:
+            basename = os.path.splitext(os.path.basename(image_name))[0]
+        else:
+            basename = ''
+
+        images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
 
         if opts.enable_pnginfo:
             image.info = existing_pnginfo
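The two fixes in the first hunk — deterministic ordering via `sorted()`, and skipping files Pillow cannot open instead of aborting the whole batch — can be seen in isolation in this small sketch (assuming Pillow is available):

```python
# Minimal illustration of the batch-loading behavior introduced above.
import os
from PIL import Image

def load_images(input_dir: str):
    paths = [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))]
    for path in (p for p in paths if os.path.isfile(p)):
        try:
            yield path, Image.open(path)
        except Exception:
            continue  # not an image (or unreadable); move on to the next file
```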
modules/generation_parameters_copypaste.py

@@ -45,11 +45,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         else:
             prompt += ("" if prompt == "" else "\n") + line
 
-    if len(prompt) > 0:
-        res["Prompt"] = prompt
-
-    if len(negative_prompt) > 0:
-        res["Negative prompt"] = negative_prompt
+    res["Prompt"] = prompt
+    res["Negative prompt"] = negative_prompt
 
     for k, v in re_param.findall(lastline):
         m = re_imagesize.match(v)
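For context, the `re_param.findall(lastline)` loop above consumes the final `key: value, key: value` line of an infotext block. A self-contained approximation of that parsing — the regex here is an assumption, simplified from the module's actual `re_param` (which also handles quoted values):

```python
# Simplified, assumed version of the "Steps: 20, Sampler: Euler a, ..." parsing above.
import re

re_param = re.compile(r'\s*([\w ]+):\s*([^,]+)(?:,|$)')

lastline = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512"
params = {k.strip(): v.strip() for k, v in re_param.findall(lastline)}
assert params["Sampler"] == "Euler a"
assert params["Size"] == "512x512"
```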
modules/hypernetworks/hypernetwork.py

@@ -22,40 +22,57 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         super().__init__()
-        if layer_structure is not None:
-            assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
-            assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-        else:
-            layer_structure = parse_layer_structure(dim, state_dict)
+
+        assert layer_structure is not None, "layer_structure must not be None"
+        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
 
         linears = []
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+            if activation_func == "relu":
+                linears.append(torch.nn.ReLU())
+            elif activation_func == "leakyrelu":
+                linears.append(torch.nn.LeakyReLU())
+            elif activation_func == 'linear' or activation_func is None:
+                pass
+            else:
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
         self.linear = torch.nn.Sequential(*linears)
 
         if state_dict is not None:
-            try:
-                self.load_state_dict(state_dict)
-            except RuntimeError:
-                self.try_load_previous(state_dict)
+            self.fix_old_state_dict(state_dict)
+            self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                layer.weight.data.normal_(mean = 0.0, std = 0.01)
-                layer.bias.data.zero_()
+                if type(layer) == torch.nn.Linear:
+                    layer.weight.data.normal_(mean=0.0, std=0.01)
+                    layer.bias.data.zero_()
 
         self.to(devices.device)
 
-    def try_load_previous(self, state_dict):
-        states = self.state_dict()
-        states['linear.0.bias'].copy_(state_dict['linear1.bias'])
-        states['linear.0.weight'].copy_(state_dict['linear1.weight'])
-        states['linear.1.bias'].copy_(state_dict['linear2.bias'])
-        states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+    def fix_old_state_dict(self, state_dict):
+        changes = {
+            'linear1.bias': 'linear.0.bias',
+            'linear1.weight': 'linear.0.weight',
+            'linear2.bias': 'linear.1.bias',
+            'linear2.weight': 'linear.1.weight',
+        }
+
+        for fr, to in changes.items():
+            x = state_dict.get(fr, None)
+            if x is None:
+                continue
+
+            del state_dict[fr]
+            state_dict[to] = x
 
     def forward(self, x):
         return x + self.linear(x) * self.multiplier
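Stripped of the diff context, the layer construction introduced above amounts to the following standalone pattern — a sketch, not the module itself: each adjacent pair of ratios in `layer_structure` becomes a `Linear`, optionally followed by an activation and a `LayerNorm`:

```python
# Standalone sketch of the layer stack the new __init__ builds (not the class itself).
import torch

def build_stack(dim: int, layer_structure, activation_func=None, add_layer_norm=False):
    activations = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU}
    layers = []
    for a, b in zip(layer_structure, layer_structure[1:]):
        layers.append(torch.nn.Linear(int(dim * a), int(dim * b)))
        if activation_func in activations:
            layers.append(activations[activation_func]())
        if add_layer_norm:
            layers.append(torch.nn.LayerNorm(int(dim * b)))
    return torch.nn.Sequential(*layers)

net = build_stack(768, [1, 2, 1], activation_func="relu", add_layer_norm=True)
out = net(torch.randn(4, 768))  # the real module wraps this residually: x + net(x) * multiplier
assert out.shape == (4, 768)
```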
@@ -63,7 +80,8 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            layer_structure += [layer.weight, layer.bias]
+            if type(layer) == torch.nn.Linear:
+                layer_structure += [layer.weight, layer.bias]
         return layer_structure
 

@@ -71,23 +89,11 @@ def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
 
 
-def parse_layer_structure(dim, state_dict):
-    i = 0
-    layer_structure = [1]
-
-    while (key := "linear.{}.weight".format(i)) in state_dict:
-        weight = state_dict[key]
-        layer_structure.append(len(weight) // dim)
-        i += 1
-
-    return layer_structure
-
-
 class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         self.filename = None
         self.name = name
         self.layers = {}

@@ -96,11 +102,12 @@ class Hypernetwork:
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
         self.add_layer_norm = add_layer_norm
+        self.activation_func = activation_func
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
             )
 
     def weights(self):

@@ -123,6 +130,7 @@ class Hypernetwork:
         state_dict['name'] = self.name
         state_dict['layer_structure'] = self.layer_structure
         state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['activation_func'] = self.activation_func
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

@@ -135,17 +143,19 @@ class Hypernetwork:
         state_dict = torch.load(filename, map_location='cpu')
 
+        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+        self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        self.activation_func = state_dict.get('activation_func', None)
+
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
-                    HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
                 )
 
         self.name = state_dict.get('name', self.name)
         self.step = state_dict.get('step', 0)
-        self.layer_structure = state_dict.get('layer_structure', None)
-        self.add_layer_norm = state_dict.get('is_layer_norm', False)
         self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)

@@ -244,7 +254,11 @@ def stack_conds(conds):
     return torch.stack(conds)
 
 
 def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+    from modules import images
+
     assert hypernetwork_name, 'hypernetwork not selected'
 
     path = shared.hypernetworks.get(hypernetwork_name, None)

@@ -287,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    forced_filename = "<none>"
 
     ititial_step = hypernetwork.step or 0
     if ititial_step > steps:

@@ -334,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         })
 
         if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
-            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+            forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+            last_saved_image = os.path.join(images_dir, forced_filename)
 
             optimizer.zero_grad()
             shared.sd_model.cond_stage_model.to(devices.device)

@@ -370,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
             if image is not None:
                 shared.state.current_image = image
-                image.save(last_saved_image)
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
                 last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
modules/hypernetworks/ui.py

@@ -10,19 +10,20 @@ from modules import sd_hijack, shared, devices
 from modules.hypernetworks import hypernetwork
 
 
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"
 
     if type(layer_structure) == str:
-        layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
+        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
 
     hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
         name=name,
         enable_sizes=[int(x) for x in enable_sizes],
         layer_structure=layer_structure,
         add_layer_norm=add_layer_norm,
+        activation_func=activation_func,
     )
     hypernet.save(fn)
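The one-line parsing change above matters more than it looks: the old digit-stripping approach silently mangles fractional (and multi-digit) entries, while the comma-split version preserves them. A quick demonstration:

```python
# Why the parsing change above matters: the old regex approach drops separators entirely.
import re

text = "1, 1.5, 1"
old = tuple(map(int, re.sub(r'\D', '', text)))     # -> (1, 1, 5, 1): "1.5" is destroyed
new = [float(x.strip()) for x in text.split(",")]  # -> [1.0, 1.5, 1.0]
assert new == [1.0, 1.5, 1.0]
```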
modules/interrogate.py

@@ -28,9 +28,11 @@ class InterrogateModels:
     clip_preprocess = None
     categories = None
     dtype = None
+    running_on_cpu = None
 
     def __init__(self, content_dir):
         self.categories = []
+        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
         if os.path.exists(content_dir):
             for filename in os.listdir(content_dir):

@@ -53,7 +55,11 @@ class InterrogateModels:
     def load_clip_model(self):
         import clip
 
-        model, preprocess = clip.load(clip_model_name)
+        if self.running_on_cpu:
+            model, preprocess = clip.load(clip_model_name, device="cpu")
+        else:
+            model, preprocess = clip.load(clip_model_name)
+
         model.eval()
         model = model.to(devices.device_interrogate)

@@ -62,14 +68,14 @@ class InterrogateModels:
     def load(self):
         if self.blip_model is None:
             self.blip_model = self.load_blip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.blip_model = self.blip_model.half()
 
         self.blip_model = self.blip_model.to(devices.device_interrogate)
 
         if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.clip_model = self.clip_model.half()
 
         self.clip_model = self.clip_model.to(devices.device_interrogate)
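The repeated `running_on_cpu` guard above follows a general rule: half-precision weights only pay off on the GPU, since many fp16 ops are unimplemented or slow on PyTorch's CPU backend. The same gating in isolation (a sketch; the helper name is assumed):

```python
# Hypothetical sketch of the device-aware precision rule used above.
import torch

def maybe_half(model: torch.nn.Module, device: torch.device, no_half: bool = False) -> torch.nn.Module:
    if not no_half and device.type != "cpu":
        model = model.half()  # fp16 saves memory on CUDA; skip it on CPU
    return model.to(device)
```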
modules/processing.py

@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
         "Batch size": (None if p.batch_size < 2 else p.batch_size),
         "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),

@@ -540,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
         self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
 
+    def create_dummy_mask(self, x, width=None, height=None):
+        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+            height = height or self.height
+            width = width or self.width
+
+            # The "masked-image" in this case will just be all zeros since the entire image is masked.
+            image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+            image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+            # Add the fake full 1s mask to the first dimension.
+            image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+            image_conditioning = image_conditioning.to(x.dtype)
+
+        else:
+            # Dummy zero conditioning if we're not using inpainting model.
+            # Still takes up a bit of memory, but no encoder call.
+            # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+            image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+        return image_conditioning
+
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
         self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
 
         if not self.enable_hr:
             x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
             return samples
 
         x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
 
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
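The `torch.nn.functional.pad` call in `create_dummy_mask` is the subtle step: a pad spec of `(0, 0, 0, 0, 1, 0)` pads width by (0, 0), height by (0, 0), and channels by (1, 0) — i.e. it prepends one channel filled with 1.0, turning a 4-channel latent into the 5-channel mask-plus-latent layout the inpainting UNet expects. Verifying the shape arithmetic in isolation (tensor sizes assumed for illustration):

```python
# Shape check for the pad trick used in create_dummy_mask above.
import torch

latent = torch.zeros(2, 4, 64, 64)  # a batch of 4-channel latents
conditioning = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
assert conditioning.shape == (2, 5, 64, 64)          # one extra leading channel
assert torch.all(conditioning[:, 0] == 1.0)          # the "everything is masked" mask
```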
@@ -587,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         x = None
         devices.torch_gc()
 
-        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
 
         return samples

@@ -613,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.inpainting_mask_invert = inpainting_mask_invert
         self.mask = None
         self.nmask = None
+        self.image_conditioning = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)

@@ -714,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         elif self.inpainting_fill == 3:
             self.init_latent = self.init_latent * self.mask
 
+        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+            if self.image_mask is not None:
+                conditioning_mask = np.array(self.image_mask.convert("L"))
+                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+            else:
+                conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+            # Create another latent image, this time with a masked version of the original input.
+            conditioning_mask = conditioning_mask.to(image.device)
+            conditioning_image = image * (1.0 - conditioning_mask)
+            conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+            # Create the concatenated conditioning tensor to be fed to `c_concat`
+            conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+            conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+            self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+            self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+        else:
+            self.image_conditioning = torch.zeros(
+                self.init_latent.shape[0], 5, 1, 1,
+                dtype=self.init_latent.dtype,
+                device=self.init_latent.device
+            )
+
+
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
 
-        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
         if self.mask is not None:
             samples = samples * self.nmask + self.init_latent * self.mask
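The img2img branch above builds its `c_concat` input in three moves: binarize the user's mask, encode the masked image through the VAE, and resample the mask to latent resolution before concatenation. The same sequence on dummy tensors — shapes assumed, with `encode` standing in for the VAE first-stage encoder (8x spatial downscale to 4 channels):

```python
# Dummy-tensor walk-through of the inpainting conditioning built above.
import torch

def encode(img: torch.Tensor) -> torch.Tensor:
    # Stand-in for the first-stage (VAE) encoder used in the real code.
    b, _, h, w = img.shape
    return torch.randn(b, 4, h // 8, w // 8)

image = torch.rand(1, 3, 512, 512)
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()   # already rounded to {0, 1}

conditioning_image = encode(image * (1.0 - mask))   # encoder sees only the unmasked region
mask_small = torch.nn.functional.interpolate(mask, size=conditioning_image.shape[-2:])
image_conditioning = torch.cat([mask_small, conditioning_image], dim=1)
assert image_conditioning.shape == (1, 5, 64, 64)   # 1 mask channel + 4 latent channels
```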
331  modules/sd_hijack_inpainting.py  Normal file

@@ -0,0 +1,331 @@
+import torch
+
+from einops import repeat
+from omegaconf import ListConfig
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+# =================================================================================================
+# Monkey patch DDIMSampler methods from RunwayML repo directly.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
+# =================================================================================================
+@torch.no_grad()
+def sample_ddim(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,
+                # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+                **kwargs
+                ):
+    if conditioning is not None:
+        if isinstance(conditioning, dict):
+            ctmp = conditioning[list(conditioning.keys())[0]]
+            while isinstance(ctmp, list):
+                ctmp = ctmp[0]
+            cbs = ctmp.shape[0]
+            if cbs != batch_size:
+                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+        else:
+            if conditioning.shape[0] != batch_size:
+                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+    # sampling
+    C, H, W = shape
+    size = (batch_size, C, H, W)
+    print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+    samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                callback=callback,
+                                                img_callback=img_callback,
+                                                quantize_denoised=quantize_x0,
+                                                mask=mask, x0=x0,
+                                                ddim_use_original_steps=False,
+                                                noise_dropout=noise_dropout,
+                                                temperature=temperature,
+                                                score_corrector=score_corrector,
+                                                corrector_kwargs=corrector_kwargs,
+                                                x_T=x_T,
+                                                log_every_t=log_every_t,
+                                                unconditional_guidance_scale=unconditional_guidance_scale,
+                                                unconditional_conditioning=unconditional_conditioning,
+                                                )
+    return samples, intermediates
+
+@torch.no_grad()
+def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                  unconditional_guidance_scale=1., unconditional_conditioning=None):
+    b, *_, device = *x.shape, x.device
+
+    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+        e_t = self.model.apply_model(x, t, c)
+    else:
+        x_in = torch.cat([x] * 2)
+        t_in = torch.cat([t] * 2)
+        if isinstance(c, dict):
+            assert isinstance(unconditional_conditioning, dict)
+            c_in = dict()
+            for k in c:
+                if isinstance(c[k], list):
+                    c_in[k] = [
+                        torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                        for i in range(len(c[k]))
+                    ]
+                else:
+                    c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+        else:
+            c_in = torch.cat([unconditional_conditioning, c])
+        e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+    if score_corrector is not None:
+        assert self.model.parameterization == "eps"
+        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+    # select parameters corresponding to the currently considered timestep
+    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+    # current prediction for x_0
+    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+    if quantize_denoised:
+        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+    # direction pointing to x_t
+    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+    if noise_dropout > 0.:
+        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+    return x_prev, pred_x0
+
+
+# =================================================================================================
+# Monkey patch PLMSSampler methods.
+# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
+# Adapted from:
+# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
+# =================================================================================================
+@torch.no_grad()
+def sample_plms(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,
+                # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+                **kwargs
+                ):
+    if conditioning is not None:
+        if isinstance(conditioning, dict):
+            ctmp = conditioning[list(conditioning.keys())[0]]
+            while isinstance(ctmp, list):
+                ctmp = ctmp[0]
+            cbs = ctmp.shape[0]
+            if cbs != batch_size:
+                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+        else:
+            if conditioning.shape[0] != batch_size:
+                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+    # sampling
+    C, H, W = shape
+    size = (batch_size, C, H, W)
+    print(f'Data shape for PLMS sampling is {size}')
+
+    samples, intermediates = self.plms_sampling(conditioning, size,
+                                                callback=callback,
+                                                img_callback=img_callback,
+                                                quantize_denoised=quantize_x0,
+                                                mask=mask, x0=x0,
+                                                ddim_use_original_steps=False,
+                                                noise_dropout=noise_dropout,
+                                                temperature=temperature,
+                                                score_corrector=score_corrector,
+                                                corrector_kwargs=corrector_kwargs,
+                                                x_T=x_T,
+                                                log_every_t=log_every_t,
+                                                unconditional_guidance_scale=unconditional_guidance_scale,
+                                                unconditional_conditioning=unconditional_conditioning,
+                                                )
+    return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+    b, *_, device = *x.shape, x.device
+
+    def get_model_output(x, t):
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            e_t = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    else:
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+
+            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps"
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        return e_t
+
+    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+    def get_x_prev_and_pred_x0(e_t, index):
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+        # current prediction for x_0
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    e_t = get_model_output(x, t)
+    if len(old_eps) == 0:
+        # Pseudo Improved Euler (2nd order)
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+        e_t_next = get_model_output(x_prev, t_next)
+        e_t_prime = (e_t + e_t_next) / 2
+    elif len(old_eps) == 1:
+        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+        e_t_prime = (3 * e_t - old_eps[-1]) / 2
+    elif len(old_eps) == 2:
+        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+    elif len(old_eps) >= 3:
+        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+    return x_prev, pred_x0, e_t
+
+# =================================================================================================
+# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
+# =================================================================================================
+
+@torch.no_grad()
+def get_unconditional_conditioning(self, batch_size, null_label=None):
+    if null_label is not None:
+        xc = null_label
+        if isinstance(xc, ListConfig):
+            xc = list(xc)
+        if isinstance(xc, dict) or isinstance(xc, list):
+            c = self.get_learned_conditioning(xc)
+        else:
+            if hasattr(xc, "to"):
+                xc = xc.to(self.device)
+            c = self.get_learned_conditioning(xc)
+    else:
+        # todo: get null label from cond_stage_model
+        raise NotImplementedError()
+    c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+    return c
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+
+def should_hijack_inpainting(checkpoint_info):
+    return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
+
+def do_inpainting_hijack():
+    ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
+    ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+
+    ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
+    ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+
+    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+    ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
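`do_inpainting_hijack` relies on nothing more exotic than attribute assignment: Python methods are just class attributes, so a module-level function whose first parameter is `self` can replace one at runtime. The pattern in miniature:

```python
# The monkey-patching pattern do_inpainting_hijack uses, reduced to a toy example.
class Sampler:
    def sample(self):
        return "original"

def patched_sample(self):  # free function; `self` is bound when accessed via the class
    return "patched"

Sampler.sample = patched_sample   # every existing and future instance now uses the patch
assert Sampler().sample() == "patched"
```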
@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader, devices
|
from modules import shared, modelloader, devices
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|

@@ -203,14 +204,26 @@ def load_model_weights(model, checkpoint_info):
     model.sd_checkpoint_info = checkpoint_info


-def load_model():
+def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
-    checkpoint_info = select_checkpoint()
+    checkpoint_info = checkpoint_info or select_checkpoint()

     if checkpoint_info.config != shared.cmd_opts.config:
         print(f"Loading config from: {checkpoint_info.config}")

     sd_config = OmegaConf.load(checkpoint_info.config)
+
+    if should_hijack_inpainting(checkpoint_info):
+        # Hardcoded config for now...
+        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        sd_config.model.params.use_ema = False
+        sd_config.model.params.conditioning_key = "hybrid"
+        sd_config.model.params.unet_config.params.in_channels = 9
+
+        # Create a "fake" config with a different name so that we know to unload it when switching models.
+        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
+    do_inpainting_hijack()
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
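
The hunk above mutates the loaded OmegaConf tree in place and then clones the checkpoint record under a fake config name. A minimal, self-contained sketch of both moves on toy data (the file names, the reduced CheckpointInfo fields, and the 4 latent + 4 masked-image + 1 mask = 9 channel breakdown are illustrative background on the runwayml inpainting UNet, not part of this diff):

    from collections import namedtuple
    from omegaconf import OmegaConf

    # Toy stand-ins for the real v1-inference.yaml tree and the CheckpointInfo
    # namedtuple from sd_models.py.
    CheckpointInfo = namedtuple("CheckpointInfo", ["filename", "config"])
    info = CheckpointInfo("sd-v1-5-inpainting.ckpt", "configs/v1-inference.yaml")

    sd_config = OmegaConf.create({
        "model": {
            "target": "ldm.models.diffusion.ddpm.LatentDiffusion",
            "params": {"conditioning_key": "crossattn",
                       "unet_config": {"params": {"in_channels": 4}}},
        }
    })

    # Same dotted-path mutation as in load_model() above; OmegaConf accepts
    # new keys such as use_ema on a plain dict-backed config.
    sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
    sd_config.model.params.use_ema = False
    sd_config.model.params.conditioning_key = "hybrid"
    sd_config.model.params.unet_config.params.in_channels = 9  # 4 latent + 4 masked-image + 1 mask

    # _replace returns a modified copy, so the "fake" -inpainting.yaml name
    # never touches the registry entry the original tuple came from.
    info = info._replace(config=info.config.replace(".yaml", "-inpainting.yaml"))
    print(info.config)  # configs/v1-inference-inpainting.yaml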

@@ -234,9 +247,9 @@ def reload_model_weights(sd_model, info=None):
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return

-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         checkpoints_loaded.clear()
-        shared.sd_model = load_model()
+        shared.sd_model = load_model(checkpoint_info)
         return shared.sd_model

     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
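
The widened condition matters when two checkpoints share a config path but sit on opposite sides of the inpainting hijack. A quick sketch with a toy CheckpointInfo (the real one is a namedtuple in sd_models.py), reusing should_hijack_inpainting() exactly as defined above:

    from collections import namedtuple

    CheckpointInfo = namedtuple("CheckpointInfo", ["filename", "config"])

    def should_hijack_inpainting(checkpoint_info):
        return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")

    current = CheckpointInfo("sd-v1-5.ckpt", "configs/v1-inference.yaml")
    target = CheckpointInfo("sd-v1-5-inpainting.ckpt", "configs/v1-inference.yaml")

    # The configs match, so the old test alone would have taken the fast
    # weight-swap path and left the wrong model class instantiated.
    print(should_hijack_inpainting(target) != should_hijack_inpainting(current))  # True -> full reload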

modules/sd_samplers.py

@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
         self.config = None
         self.last_latent = None

+        self.conditioning_key = sd_model.model.conditioning_key
+
     def number_of_needed_noises(self, p):
         return 0

@@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
         if self.stop_at is not None and self.step > self.stop_at:
             raise InterruptedException

+        # Have to unwrap the inpainting conditioning here to perform pre-processing
+        image_conditioning = None
+        if isinstance(cond, dict):
+            image_conditioning = cond["c_concat"][0]
+            cond = cond["c_crossattn"][0]
+            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
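
The unwrap assumes the hybrid-conditioning payload built for the inpainting model: a dict carrying the text conditioning under "c_crossattn" and the mask-plus-masked-image latent under "c_concat", each inside a one-element list. A shape-only sketch (dimensions assumed for illustration, not taken from the diff):

    import torch

    # Keys follow the CompVis convention used above; shapes are illustrative.
    cond = {
        "c_crossattn": [torch.zeros(2, 77, 768)],  # text embedding batch
        "c_concat": [torch.zeros(2, 5, 64, 64)],   # 1 mask channel + 4 masked-image latent channels
    }

    image_conditioning = cond["c_concat"][0]
    cond = cond["c_crossattn"][0]
    print(image_conditioning.shape, cond.shape)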

@@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
             img_orig = self.sampler.model.q_sample(self.init_latent, ts)
             x_dec = img_orig * self.mask + self.nmask * x_dec

+        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+        # Note that they need to be lists because it just concatenates them later.
+        if image_conditioning is not None:
+            cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
         res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)

         if self.mask is not None:

@@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
         self.mask = p.mask if hasattr(p, 'mask') else None
         self.nmask = p.nmask if hasattr(p, 'nmask') else None

-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = setup_img2img_steps(p, steps)

         self.initialize(p)

@@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler:
         x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

         self.init_latent = x
+        self.last_latent = x
         self.step = 0

+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        if image_conditioning is not None:
+            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
         samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))

         return samples

-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         self.initialize(p)

         self.init_latent = None
+        self.last_latent = x
         self.step = 0

         steps = steps or p.steps

+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        if image_conditioning is not None:
+            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
         # existing code fails with certain step counts, like 9
         try:
             samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

@@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
         self.init_latent = None
         self.step = 0

-    def forward(self, x, sigma, uncond, cond, cond_scale):
+    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
         if state.interrupted or state.skipped:
             raise InterruptedException

@@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
         repeats = [len(conds_list[i]) for i in range(batch_size)]

         x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
         sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])

         if tensor.shape[1] == uncond.shape[1]:
             cond_in = torch.cat([tensor, uncond])

             if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
             else:
                 x_out = torch.zeros_like(x_in)
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
             for batch_offset in range(0, tensor.shape[0], batch_size):
                 a = batch_offset
                 b = min(a + batch_size, tensor.shape[0])
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})

-        x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+        x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

         denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
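
image_cond_in follows the repeat-then-append layout already used for x_in and sigma_in, so every image-conditioning row stays aligned with its latent through the batched cond/uncond pass, and the unconditional tail can still be sliced off the end. A runnable toy of that layout:

    import torch

    # Two prompts; prompt 0 carries two sub-conditionings (e.g. "a AND b"),
    # prompt 1 carries one, mirroring conds_list in the code above.
    x = torch.arange(2.0).view(2, 1)
    repeats = [2, 1]

    x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
    print(x_in.shape)  # torch.Size([5, 1]): 3 conditioned rows, then 2 unconditional rows
    print(x_in[-2:])   # the uncond tail that x_out[-uncond.shape[0]:] slices back out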

@@ -306,6 +334,8 @@ class KDiffusionSampler:
         self.config = None
         self.last_latent = None

+        self.conditioning_key = sd_model.model.conditioning_key
+
     def callback_state(self, d):
         step = d['i']
         latent = d["denoised"]

@@ -361,7 +391,7 @@ class KDiffusionSampler:

         return extra_params_kwargs

-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = setup_img2img_steps(p, steps)

         if p.sampler_noise_scheduler_override:

@@ -388,12 +418,18 @@ class KDiffusionSampler:
             extra_params_kwargs['sigmas'] = sigma_sched

         self.model_wrap_cfg.init_latent = x
+        self.last_latent = x

-        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

         return samples

-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
         steps = steps or p.steps

         if p.sampler_noise_scheduler_override:

@@ -414,7 +450,13 @@ class KDiffusionSampler:
         else:
             extra_params_kwargs['sigmas'] = sigmas

-        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+        self.last_latent = x
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

         return samples
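
Both launch_sampling sites above route the new 'image_cond' entry through extra_args; k-diffusion samplers forward that dict to the denoiser as keyword arguments, which is how it reaches CFGDenoiser.forward(..., image_cond). A stand-in sketch of that plumbing (toy names, not the k-diffusion API):

    def toy_sample(model, x, extra_args, steps=3):
        # k-diffusion-style loop: the denoiser sees extra_args as **kwargs.
        for sigma in range(steps, 0, -1):
            x = model(x, sigma, **extra_args)
        return x

    def toy_denoiser(x, sigma, cond, uncond, cond_scale, image_cond):
        return x  # a real CFGDenoiser mixes cond/uncond predictions here

    out = toy_sample(toy_denoiser, 0.0, {
        'cond': "c", 'image_cond': "ic", 'uncond': "uc", 'cond_scale': 7.0,
    })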

modules/textual_inversion/dataset.py

@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):

             self.dataset.append(entry)

-        assert len(self.dataset) > 1, "No images have been found in the dataset."
+        assert len(self.dataset) > 0, "No images have been found in the dataset."
         self.length = len(self.dataset) * repeats // batch_size

         self.initial_indexes = np.arange(len(self.dataset))
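
The relaxed assertion is what makes single-image training runs possible: with the old `> 1` check, a dataset of exactly one image aborted even though the length arithmetic below it handles that case fine. A quick check:

    dataset = ["only_image.png"]          # a single training image
    repeats, batch_size = 10, 1

    assert len(dataset) > 0, "No images have been found in the dataset."
    length = len(dataset) * repeats // batch_size
    print(length)  # 10 -- one image seen ten times per epoch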

@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
         self.shuffle()

     def shuffle(self):
-        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]

     def create_text(self, filename_text):
         text = random.choice(self.lines)
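
The appended .numpy() hands NumPy an index array it is guaranteed to understand; indexing a NumPy array directly with a torch.Tensor is not reliably supported across versions, which is what this line avoids. A minimal sketch of the fixed shuffle:

    import numpy as np
    import torch

    initial_indexes = np.arange(5)
    perm = torch.randperm(initial_indexes.shape[0])

    indexes = initial_indexes[perm.numpy()]  # NumPy int array -> well-defined fancy indexing
    print(indexes)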

modules/ui.py

@@ -269,7 +269,12 @@ def calc_time_left(progress, threshold, label, force_display):
         eta = (time_since_start/progress)
         eta_relative = eta-time_since_start
         if (eta_relative > threshold and progress > 0.02) or force_display:
-            return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+            if eta_relative > 3600:
+                return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+            elif eta_relative > 60:
+                return label + time.strftime('%M:%S', time.gmtime(eta_relative))
+            else:
+                return label + time.strftime('%Ss', time.gmtime(eta_relative))
         else:
             return ""
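
The new branches trade precision for brevity as the ETA shrinks; since time.gmtime() treats its argument as seconds since the epoch, %H/%M/%S format correctly for any ETA under 24 hours. A quick demonstration of all three branches:

    import time

    for eta_relative in (5025, 95, 12):          # >1 hour, >1 minute, under a minute
        if eta_relative > 3600:
            print(time.strftime('%H:%M:%S', time.gmtime(eta_relative)))  # 01:23:45
        elif eta_relative > 60:
            print(time.strftime('%M:%S', time.gmtime(eta_relative)))     # 01:35
        else:
            print(time.strftime('%Ss', time.gmtime(eta_relative)))       # 12s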

@@ -285,7 +290,7 @@ def check_progress_call(id_part):
     if shared.state.sampling_steps > 0:
         progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

-    time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
+    time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
     if time_left != "":
         shared.state.time_left_force_display = True

@@ -293,7 +298,7 @@ def check_progress_call(id_part):

     progressbar = ""
     if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""
+        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""

     image = gr_show(False)
     preview_visibility = gr_show(False)

@@ -1221,6 +1226,7 @@ def create_ui(wrap_gradio_gpu_call):
             new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
             new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
             overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
+            new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])

             with gr.Row():
                 with gr.Column(scale=3):
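
The dropdown itself is an ordinary Gradio component; its value only takes effect because it is appended to the training call's inputs list in the following hunk. A hedged, stand-alone sketch of the same wiring (Gradio 3.x-style API as used by this repo; train_fn here is a dummy):

    import gradio as gr

    def train_fn(activation_func):
        return f"would build hypernetwork with activation: {activation_func}"

    with gr.Blocks() as demo:
        act = gr.Dropdown(value="relu", choices=["linear", "relu", "leakyrelu"],
                          label="Select activation function of hypernetwork")
        out = gr.Textbox()
        btn = gr.Button("Create")
        btn.click(train_fn, inputs=[act], outputs=[out])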

@@ -1311,6 +1317,7 @@ def create_ui(wrap_gradio_gpu_call):
                 overwrite_old_hypernetwork,
                 new_hypernetwork_layer_structure,
                 new_hypernetwork_add_layer_norm,
+                new_hypernetwork_activation_func,
             ],
             outputs=[
                 train_hypernetwork_name,

scripts/xy_grid.py

@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
     if info is None:
         raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)
+    p.sd_model = shared.sd_model


 def confirm_checkpoints(p, xs):

5 webui.py

@@ -118,7 +118,8 @@ def api_only():
     api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)


-def webui(launch_api=False):
+def webui():
+    launch_api = cmd_opts.api
     initialize()

     while 1:

@@ -158,4 +159,4 @@ if __name__ == "__main__":
     if cmd_opts.nowebui:
         api_only()
     else:
-        webui(cmd_opts.api)
+        webui()