Fix logspam and live previews
This commit is contained in:
parent
1253199889
commit
21880eb9e5
3 changed files with 41 additions and 31 deletions
|
@ -19,9 +19,10 @@ class UniPCSampler(object):
|
||||||
attr = attr.to(torch.device("cuda"))
|
attr = attr.to(torch.device("cuda"))
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def set_hooks(self, before, after):
|
def set_hooks(self, before_sample, after_sample, after_update):
|
||||||
self.before_sample = before
|
self.before_sample = before_sample
|
||||||
self.after_sample = after
|
self.after_sample = after_sample
|
||||||
|
self.after_update = after_update
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(self,
|
def sample(self,
|
||||||
|
@ -50,9 +51,17 @@ class UniPCSampler(object):
|
||||||
):
|
):
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
if cbs != batch_size:
|
if cbs != batch_size:
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
elif isinstance(conditioning, list):
|
||||||
|
for ctmp in conditioning:
|
||||||
|
if ctmp.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if conditioning.shape[0] != batch_size:
|
if conditioning.shape[0] != batch_size:
|
||||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
@ -60,6 +69,7 @@ class UniPCSampler(object):
|
||||||
# sampling
|
# sampling
|
||||||
C, H, W = shape
|
C, H, W = shape
|
||||||
size = (batch_size, C, H, W)
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for UniPC sampling is {size}, eta {eta}')
|
||||||
|
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
|
@ -79,7 +89,7 @@ class UniPCSampler(object):
|
||||||
guidance_scale=unconditional_guidance_scale,
|
guidance_scale=unconditional_guidance_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
|
||||||
x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
||||||
|
|
||||||
return x.to(device), None
|
return x.to(device), None
|
||||||
|
|
|
@ -378,7 +378,8 @@ class UniPC:
|
||||||
condition=None,
|
condition=None,
|
||||||
unconditional_condition=None,
|
unconditional_condition=None,
|
||||||
before_sample=None,
|
before_sample=None,
|
||||||
after_sample=None
|
after_sample=None,
|
||||||
|
after_update=None
|
||||||
):
|
):
|
||||||
"""Construct a UniPC.
|
"""Construct a UniPC.
|
||||||
|
|
||||||
|
@ -394,6 +395,7 @@ class UniPC:
|
||||||
self.unconditional_condition = unconditional_condition
|
self.unconditional_condition = unconditional_condition
|
||||||
self.before_sample = before_sample
|
self.before_sample = before_sample
|
||||||
self.after_sample = after_sample
|
self.after_sample = after_sample
|
||||||
|
self.after_update = after_update
|
||||||
|
|
||||||
def dynamic_thresholding_fn(self, x0, t=None):
|
def dynamic_thresholding_fn(self, x0, t=None):
|
||||||
"""
|
"""
|
||||||
|
@ -434,15 +436,6 @@ class UniPC:
|
||||||
noise = self.noise_prediction_fn(x, t)
|
noise = self.noise_prediction_fn(x, t)
|
||||||
dims = x.dim()
|
dims = x.dim()
|
||||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
||||||
from pprint import pp
|
|
||||||
print("X:")
|
|
||||||
pp(x)
|
|
||||||
print("sigma_t:")
|
|
||||||
pp(sigma_t)
|
|
||||||
print("noise:")
|
|
||||||
pp(noise)
|
|
||||||
print("alpha_t:")
|
|
||||||
pp(alpha_t)
|
|
||||||
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
||||||
if self.thresholding:
|
if self.thresholding:
|
||||||
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
||||||
|
@ -524,7 +517,7 @@ class UniPC:
|
||||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||||
|
|
||||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||||
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||||
ns = self.noise_schedule
|
ns = self.noise_schedule
|
||||||
assert order <= len(model_prev_list)
|
assert order <= len(model_prev_list)
|
||||||
|
|
||||||
|
@ -568,7 +561,7 @@ class UniPC:
|
||||||
A_p = C_inv_p
|
A_p = C_inv_p
|
||||||
|
|
||||||
if use_corrector:
|
if use_corrector:
|
||||||
print('using corrector')
|
#print('using corrector')
|
||||||
C_inv = torch.linalg.inv(C)
|
C_inv = torch.linalg.inv(C)
|
||||||
A_c = C_inv
|
A_c = C_inv
|
||||||
|
|
||||||
|
@ -627,7 +620,7 @@ class UniPC:
|
||||||
return x_t, model_t
|
return x_t, model_t
|
||||||
|
|
||||||
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
||||||
print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
||||||
ns = self.noise_schedule
|
ns = self.noise_schedule
|
||||||
assert order <= len(model_prev_list)
|
assert order <= len(model_prev_list)
|
||||||
dims = x.dim()
|
dims = x.dim()
|
||||||
|
@ -695,7 +688,7 @@ class UniPC:
|
||||||
D1s = None
|
D1s = None
|
||||||
|
|
||||||
if use_corrector:
|
if use_corrector:
|
||||||
print('using corrector')
|
#print('using corrector')
|
||||||
# for order 1, we use a simplified version
|
# for order 1, we use a simplified version
|
||||||
if order == 1:
|
if order == 1:
|
||||||
rhos_c = torch.tensor([0.5], device=b.device)
|
rhos_c = torch.tensor([0.5], device=b.device)
|
||||||
|
@ -755,8 +748,9 @@ class UniPC:
|
||||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||||
device = x.device
|
device = x.device
|
||||||
if method == 'multistep':
|
if method == 'multistep':
|
||||||
assert steps >= order
|
assert steps >= order, "UniPC order must be < sampling steps"
|
||||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||||
|
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
|
||||||
assert timesteps.shape[0] - 1 == steps
|
assert timesteps.shape[0] - 1 == steps
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
|
@ -768,6 +762,8 @@ class UniPC:
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||||
if model_x is None:
|
if model_x is None:
|
||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
model_prev_list.append(model_x)
|
model_prev_list.append(model_x)
|
||||||
t_prev_list.append(vec_t)
|
t_prev_list.append(vec_t)
|
||||||
for step in range(order, steps + 1):
|
for step in range(order, steps + 1):
|
||||||
|
@ -776,13 +772,15 @@ class UniPC:
|
||||||
step_order = min(order, steps + 1 - step)
|
step_order = min(order, steps + 1 - step)
|
||||||
else:
|
else:
|
||||||
step_order = order
|
step_order = order
|
||||||
print('this step order:', step_order)
|
#print('this step order:', step_order)
|
||||||
if step == steps:
|
if step == steps:
|
||||||
print('do not run corrector at the last step')
|
#print('do not run corrector at the last step')
|
||||||
use_corrector = False
|
use_corrector = False
|
||||||
else:
|
else:
|
||||||
use_corrector = True
|
use_corrector = True
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
for i in range(order - 1):
|
for i in range(order - 1):
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
|
|
@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:
|
||||||
|
|
||||||
return x, ts, cond, unconditional_conditioning
|
return x, ts, cond, unconditional_conditioning
|
||||||
|
|
||||||
def after_sample(self, x, ts, cond, uncond, res):
|
def update_step(self, last_latent):
|
||||||
if self.is_unipc:
|
|
||||||
# unipc model_fn returns (pred_x0)
|
|
||||||
# p_sample_ddim returns (x_prev, pred_x0)
|
|
||||||
res = (None, res[0])
|
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
|
||||||
else:
|
else:
|
||||||
self.last_latent = res[1]
|
self.last_latent = last_latent
|
||||||
|
|
||||||
sd_samplers_common.store_latent(self.last_latent)
|
sd_samplers_common.store_latent(self.last_latent)
|
||||||
|
|
||||||
|
@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
|
||||||
state.sampling_step = self.step
|
state.sampling_step = self.step
|
||||||
shared.total_tqdm.update()
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
def after_sample(self, x, ts, cond, uncond, res):
|
||||||
|
if not self.is_unipc:
|
||||||
|
self.update_step(res[1])
|
||||||
|
|
||||||
return x, ts, cond, uncond, res
|
return x, ts, cond, uncond, res
|
||||||
|
|
||||||
|
def unipc_after_update(self, x, model_x):
|
||||||
|
self.update_step(x)
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||||
if self.eta != 0.0:
|
if self.eta != 0.0:
|
||||||
|
@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
|
||||||
if hasattr(self.sampler, fieldname):
|
if hasattr(self.sampler, fieldname):
|
||||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||||
if self.is_unipc:
|
if self.is_unipc:
|
||||||
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
|
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
|
||||||
|
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
Loading…
Reference in a new issue