added cheap NN approximation for VAE
This commit is contained in:
parent
5927d3fa95
commit
56e557c6ff
5 changed files with 81 additions and 17 deletions
|
@ -97,7 +97,10 @@ titles = {
|
|||
|
||||
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||
|
||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc."
|
||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||
|
||||
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
|
||||
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
|
||||
}
|
||||
|
||||
|
||||
|
|
BIN
models/VAE-approx/model.pt
Normal file
BIN
models/VAE-approx/model.pt
Normal file
Binary file not shown.
|
@ -9,7 +9,7 @@ import k_diffusion.sampling
|
|||
import torchsde._brownian.brownian_interval
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser, devices, processing, images
|
||||
from modules import prompt_parser, devices, processing, images, sd_vae_approx
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
|
@ -106,28 +106,31 @@ def setup_img2img_steps(p, steps=None):
|
|||
return steps, t_enc
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=False):
|
||||
if approximation:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
coefs = torch.tensor(
|
||||
[[ 0.298, 0.207, 0.208],
|
||||
[ 0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473]]).to(sample.device)
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=False):
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=False):
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
|
@ -136,7 +139,7 @@ def store_latent(decoded):
|
|||
|
||||
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate)
|
||||
shared.state.current_image = sample_to_image(decoded)
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
|
|
58
modules/sd_vae_approx.py
Normal file
58
modules/sd_vae_approx.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from modules import devices, paths
|
||||
|
||||
sd_vae_approx_model = None
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
def __init__(self):
|
||||
super(VAEApprox, self).__init__()
|
||||
self.conv1 = nn.Conv2d(4, 8, (7, 7))
|
||||
self.conv2 = nn.Conv2d(8, 16, (5, 5))
|
||||
self.conv3 = nn.Conv2d(16, 32, (3, 3))
|
||||
self.conv4 = nn.Conv2d(32, 64, (3, 3))
|
||||
self.conv5 = nn.Conv2d(64, 32, (3, 3))
|
||||
self.conv6 = nn.Conv2d(32, 16, (3, 3))
|
||||
self.conv7 = nn.Conv2d(16, 8, (3, 3))
|
||||
self.conv8 = nn.Conv2d(8, 3, (3, 3))
|
||||
|
||||
def forward(self, x):
|
||||
extra = 11
|
||||
x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
|
||||
x = nn.functional.pad(x, (extra, extra, extra, extra))
|
||||
|
||||
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
|
||||
x = layer(x)
|
||||
x = nn.functional.leaky_relu(x, 0.1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_approx_model
|
||||
|
||||
if sd_vae_approx_model is None:
|
||||
sd_vae_approx_model = VAEApprox()
|
||||
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
|
||||
sd_vae_approx_model.eval()
|
||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||
|
||||
return sd_vae_approx_model
|
||||
|
||||
|
||||
def cheap_approximation(sample):
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
|
||||
coefs = torch.tensor([
|
||||
[0.298, 0.207, 0.208],
|
||||
[0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
|
||||
return x_sample
|
|
@ -212,9 +212,9 @@ class State:
|
|||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent, approximation=opts.show_progress_approximate)
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent, approximation=opts.show_progress_approximate)
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
|
@ -392,7 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_approximate": OptionInfo(False, "Calculate small previews using fast linear approximation instead of VAE"),
|
||||
"show_progress_type": OptionInfo("Full", "Image creation progress mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
|
|
Loading…
Reference in a new issue