added prompt matrix feature
all images in batches now have proper seeds, not just the first one added code to remove bad characters from filenames added code to flag output which writes it to csv and saves images renamed some fields in UI for clarity
This commit is contained in:
parent
b63d0726cd
commit
3395c29127
1 changed files with 126 additions and 40 deletions
122
webui.py
122
webui.py
|
@ -8,12 +8,12 @@ from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torchvision.utils import make_grid
|
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
|
import csv
|
||||||
|
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js')
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
|
|
||||||
|
invalid_filename_chars = '<>:"/\|?*'
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
|
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
|
||||||
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
|
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
|
||||||
|
@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
||||||
model = model.half().to(device)
|
model = model.half().to(device)
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, batch_size):
|
def image_grid(imgs, batch_size, round_down=False):
|
||||||
if opt.n_rows > 0:
|
if opt.n_rows > 0:
|
||||||
rows = opt.n_rows
|
rows = opt.n_rows
|
||||||
elif opt.n_rows == 0:
|
elif opt.n_rows == 0:
|
||||||
rows = batch_size
|
rows = batch_size
|
||||||
else:
|
else:
|
||||||
rows = round(math.sqrt(len(imgs)))
|
rows = math.sqrt(len(imgs))
|
||||||
|
rows = int(rows) if round_down else round(rows)
|
||||||
|
|
||||||
cols = math.ceil(len(imgs) / rows)
|
cols = math.ceil(len(imgs) / rows)
|
||||||
|
|
||||||
|
@ -146,7 +149,7 @@ def image_grid(imgs, batch_size):
|
||||||
return grid
|
return grid
|
||||||
|
|
||||||
|
|
||||||
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
outpath = opt.outdir or "outputs/txt2img-samples"
|
outpath = opt.outdir or "outputs/txt2img-samples"
|
||||||
|
@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
|
||||||
seed = random.randrange(4294967294)
|
seed = random.randrange(4294967294)
|
||||||
|
|
||||||
seed = int(seed)
|
seed = int(seed)
|
||||||
|
keep_same_seed = False
|
||||||
|
|
||||||
is_PLMS = sampler_name == 'PLMS'
|
is_PLMS = sampler_name == 'PLMS'
|
||||||
is_DDIM = sampler_name == 'DDIM'
|
is_DDIM = sampler_name == 'DDIM'
|
||||||
|
@ -177,43 +181,78 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
|
||||||
batch_size = n_samples
|
batch_size = n_samples
|
||||||
|
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
data = [batch_size * [prompt]]
|
prompts = batch_size * [prompt]
|
||||||
|
|
||||||
sample_path = os.path.join(outpath, "samples")
|
sample_path = os.path.join(outpath, "samples")
|
||||||
os.makedirs(sample_path, exist_ok=True)
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
base_count = len(os.listdir(sample_path))
|
base_count = len(os.listdir(sample_path))
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
prompt_matrix_prompts = []
|
||||||
|
comment = ""
|
||||||
|
if prompt_matrix:
|
||||||
|
keep_same_seed = True
|
||||||
|
comment = "Image prompts:\n\n"
|
||||||
|
|
||||||
|
items = prompt.split("|")
|
||||||
|
combination_count = 2 ** (len(items)-1)
|
||||||
|
for combination_num in range(combination_count):
|
||||||
|
current = items[0]
|
||||||
|
label = 'A'
|
||||||
|
|
||||||
|
for n, text in enumerate(items[1:]):
|
||||||
|
if combination_num & (2**n) > 0:
|
||||||
|
current += ("" if text.strip().startswith(",") else ", ") + text
|
||||||
|
label += chr(ord('B') + n)
|
||||||
|
|
||||||
|
comment += " - " + label + "\n"
|
||||||
|
|
||||||
|
prompt_matrix_prompts.append(current)
|
||||||
|
n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
|
||||||
|
|
||||||
|
comment += "\nwhere:\n"
|
||||||
|
for n, text in enumerate(items):
|
||||||
|
comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n"
|
||||||
|
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
output_images = []
|
output_images = []
|
||||||
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
||||||
for n in range(n_iter):
|
for n in range(n_iter):
|
||||||
for batch_index, prompts in enumerate(data):
|
if prompt_matrix:
|
||||||
|
prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size]
|
||||||
|
|
||||||
uc = None
|
uc = None
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(batch_size * [""])
|
uc = model.get_learned_conditioning(len(prompts) * [""])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
shape = [opt_C, height // opt_f, width // opt_f]
|
shape = [opt_C, height // opt_f, width // opt_f]
|
||||||
|
|
||||||
current_seed = seed + n * len(data) + batch_index
|
batch_seed = seed if keep_same_seed else seed + n * len(prompts)
|
||||||
|
|
||||||
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
|
xs = []
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
current_seed = seed if keep_same_seed else batch_seed + i
|
||||||
torch.manual_seed(current_seed)
|
torch.manual_seed(current_seed)
|
||||||
|
xs.append(torch.randn(shape, device=device))
|
||||||
|
x = torch.stack(xs)
|
||||||
|
|
||||||
if is_Kdif:
|
if is_Kdif:
|
||||||
sigmas = model_wrap.get_sigmas(ddim_steps)
|
sigmas = model_wrap.get_sigmas(ddim_steps)
|
||||||
x = torch.randn([n_samples, *shape], device=device) * sigmas[0] # for GPU draw
|
x = x * sigmas[0]
|
||||||
model_wrap_cfg = CFGDenoiser(model_wrap)
|
model_wrap_cfg = CFGDenoiser(model_wrap)
|
||||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
|
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
|
||||||
|
|
||||||
elif sampler is not None:
|
elif sampler is not None:
|
||||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
|
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
if not opt.skip_save or not opt.skip_grid:
|
if not opt.skip_save or not opt.skip_grid:
|
||||||
for x_sample in x_samples_ddim:
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
|
@ -222,14 +261,19 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
|
||||||
x_sample = restored_img
|
x_sample = restored_img
|
||||||
|
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
|
||||||
|
|
||||||
|
image.save(os.path.join(sample_path, filename))
|
||||||
|
|
||||||
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
|
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
base_count += 1
|
base_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if not opt.skip_grid:
|
if not opt.skip_grid:
|
||||||
# additionally, save as grid
|
# additionally, save as grid
|
||||||
grid = image_grid(output_images, batch_size)
|
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
|
||||||
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
|
@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
|
||||||
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
if len(comment) > 0:
|
||||||
|
info += "\n\n" + comment
|
||||||
|
|
||||||
return output_images, seed, info
|
return output_images, seed, info
|
||||||
|
|
||||||
|
class Flagging(gr.FlaggingCallback):
|
||||||
|
|
||||||
|
def setup(self, components, flagging_dir: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
|
||||||
|
os.makedirs("log/images", exist_ok=True)
|
||||||
|
|
||||||
|
# those must match the "dream" function
|
||||||
|
prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
|
||||||
|
|
||||||
|
filenames = []
|
||||||
|
|
||||||
|
with open("log/log.csv", "a", encoding="utf8", newline='') as file:
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
|
||||||
|
at_start = file.tell() == 0
|
||||||
|
writer = csv.writer(file)
|
||||||
|
if at_start:
|
||||||
|
writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
|
||||||
|
|
||||||
|
filename_base = str(int(time.time() * 1000))
|
||||||
|
for i, filedata in enumerate(images):
|
||||||
|
filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
|
||||||
|
|
||||||
|
if filedata.startswith("data:image/png;base64,"):
|
||||||
|
filedata = filedata[len("data:image/png;base64,"):]
|
||||||
|
|
||||||
|
with open(filename, "wb") as imgfile:
|
||||||
|
imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
|
||||||
|
|
||||||
|
filenames.append(filename)
|
||||||
|
|
||||||
|
writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
|
||||||
|
|
||||||
|
print("Logged:", filenames[0])
|
||||||
|
|
||||||
|
|
||||||
dream_interface = gr.Interface(
|
dream_interface = gr.Interface(
|
||||||
dream,
|
dream,
|
||||||
|
@ -252,10 +337,11 @@ dream_interface = gr.Interface(
|
||||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||||
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
|
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
|
||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
|
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||||
gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
|
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||||
gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1),
|
gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
|
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly should the image follow the prompt)', value=7.0),
|
||||||
gr.Number(label='Seed', value=-1),
|
gr.Number(label='Seed', value=-1),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
||||||
|
@ -267,7 +353,7 @@ dream_interface = gr.Interface(
|
||||||
],
|
],
|
||||||
title="Stable Diffusion Text-to-Image K",
|
title="Stable Diffusion Text-to-Image K",
|
||||||
description="Generate images from text with Stable Diffusion (using K-LMS)",
|
description="Generate images from text with Stable Diffusion (using K-LMS)",
|
||||||
allow_flagging="never"
|
flagging_callback=Flagging()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
|
||||||
x_sample = restored_img
|
x_sample = restored_img
|
||||||
|
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
|
||||||
|
|
||||||
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
|
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
base_count += 1
|
base_count += 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue