do not replace entire unet for the resolution hack
This commit is contained in:
parent
2641d1b83b
commit
7dbfd8a7d8
3 changed files with 33 additions and 30 deletions
|
@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion
|
||||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
|
||||||
|
|
||||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||||
|
|
||||||
|
@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
|
||||||
ldm.modules.attention.print = lambda *args: None
|
ldm.modules.attention.print = lambda *args: None
|
||||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
|
|
|
@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
|
||||||
return x + out
|
return x + out
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return cross_attention_attnblock_forward(self, x)
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
|
||||||
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
    """
    UNet forward pass that tolerates skip connections whose spatial size differs from the
    decoder activations: before every concatenation the activations are resized (nearest
    neighbour) to the last two dimensions of the skip tensor.

    x is the input batch, timesteps is turned into the timestep embedding, context is handed
    to every block, and y carries class labels (it must be given if and only if
    self.num_classes is set, with shape (x.shape[0],)). Extra keyword arguments are accepted
    and ignored.

    Returns self.id_predictor(...) when self.predict_codebook_ids is truthy, self.out(...)
    otherwise.
    """
    assert (y is not None) == (self.num_classes is not None), \
        "must specify y if and only if the model is class-conditional"

    embedding = self.time_embed(
        timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    )
    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        embedding = embedding + self.label_emb(y)

    # Encoder: keep each block's output so the decoder can concatenate it back in.
    skips = []
    hidden = x.type(self.dtype)
    for block in self.input_blocks:
        hidden = block(hidden, embedding, context)
        skips.append(hidden)

    hidden = self.middle_block(hidden, embedding, context)

    # Decoder: consume the stored outputs in reverse order.
    for block in self.output_blocks:
        skip = skips.pop()
        if hidden.shape[-2:] != skip.shape[-2:]:
            # Height/width mismatch: resize the activations to match the skip tensor.
            hidden = F.interpolate(hidden, skip.shape[-2:], mode="nearest")
        hidden = block(torch.cat([hidden, skip], dim=1), embedding, context)

    hidden = hidden.type(x.dtype)
    head = self.id_predictor if self.predict_codebook_ids else self.out
    return head(hidden)
|
|
||||||
|
|
30
modules/sd_hijack_unet.py
Normal file
30
modules/sd_hijack_unet.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class TorchHijackForUnet:
    """
    Stand-in for the torch module whose cat resizes the first of two tensors when their last
    two dimensions do not match; every other attribute is looked up on torch itself.
    This makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        """
        Resolve attributes that normal lookup did not find: 'cat' maps to this object's own
        cat, anything torch provides is returned from torch, and everything else raises
        AttributeError.
        """
        if item == 'cat':
            return self.cat

        # A unique sentinel tells "torch has no such attribute" apart from any real value.
        missing = object()
        found = getattr(torch, item, missing)
        if found is missing:
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

        return found

    def cat(self, tensors, *args, **kwargs):
        """
        Concatenate like torch.cat. When exactly two tensors are given and their last two
        dimensions differ, the first one is first resized (nearest neighbour) to the last two
        dimensions of the second. Remaining arguments are passed through to torch.cat.
        """
        if len(tensors) != 2:
            return torch.cat(tensors, *args, **kwargs)

        first, second = tensors
        if first.shape[-2:] != second.shape[-2:]:
            first = torch.nn.functional.interpolate(first, second.shape[-2:], mode="nearest")

        return torch.cat((first, second), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
th = TorchHijackForUnet()
|
Loading…
Reference in a new issue