alt-diffusion integration
This commit is contained in:
parent
3f401cdb64
commit
f34c734172
6 changed files with 50 additions and 26 deletions
|
@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
|
|
||||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ def fix_checkpoint():
|
||||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
comments = []
|
||||||
|
@ -79,21 +80,22 @@ class StableDiffusionModelHijack:
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
|
|
||||||
if shared.text_model_name == "XLMR-Large":
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
apply_optimizations()
|
|
||||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
apply_optimizations()
|
|
||||||
|
apply_optimizations()
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
fix_checkpoint()
|
fix_checkpoint()
|
||||||
|
@ -109,7 +111,7 @@ class StableDiffusionModelHijack:
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
|
|
||||||
if shared.text_model_name == "XLMR-Large":
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
|
|
|
@ -4,7 +4,6 @@ import torch
|
||||||
|
|
||||||
from modules import prompt_parser, devices
|
from modules import prompt_parser, devices
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
def get_target_prompt_token_count(token_count):
|
def get_target_prompt_token_count(token_count):
|
||||||
return math.ceil(max(token_count, 1) / 75) * 75
|
return math.ceil(max(token_count, 1) / 75) * 75
|
||||||
|
@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
if shared.text_model_name == "XLMR-Large":
|
|
||||||
return self.wrapped.encode(text)
|
|
||||||
|
|
||||||
use_old = opts.use_old_emphasis_implementation
|
use_old = opts.use_old_emphasis_implementation
|
||||||
if use_old:
|
if use_old:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||||
|
@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__(wrapped, hijack)
|
super().__init__(wrapped, hijack)
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
if shared.text_model_name == "XLMR-Large":
|
|
||||||
self.comma_token = None
|
vocab = self.tokenizer.get_vocab()
|
||||||
else :
|
|
||||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
self.comma_token = vocab.get(',</w>', None)
|
||||||
|
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||||
for text, ident in tokens_with_parens:
|
for text, ident in tokens_with_parens:
|
||||||
mult = 1.0
|
mult = 1.0
|
||||||
for c in text:
|
for c in text:
|
||||||
|
|
34
modules/sd_hijack_xlmr.py
Normal file
34
modules/sd_hijack_xlmr.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
import open_clip.tokenizer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import sd_hijack_clip, devices
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__(wrapped, hijack)
|
||||||
|
|
||||||
|
self.id_start = wrapped.config.bos_token_id
|
||||||
|
self.id_end = wrapped.config.eos_token_id
|
||||||
|
self.id_pad = wrapped.config.pad_token_id
|
||||||
|
|
||||||
|
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
|
||||||
|
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
|
||||||
|
# layer to work with - you have to use the last
|
||||||
|
|
||||||
|
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
|
||||||
|
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
|
||||||
|
z = features['projection_state']
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
embedding_layer = self.wrapped.roberta.embeddings
|
||||||
|
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||||
|
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||||
|
|
||||||
|
return embedded
|
|
@ -23,7 +23,7 @@ demo = None
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
|
@ -108,14 +108,6 @@ restricted_opts = {
|
||||||
"outdir_txt2img_grids",
|
"outdir_txt2img_grids",
|
||||||
"outdir_save",
|
"outdir_save",
|
||||||
}
|
}
|
||||||
from omegaconf import OmegaConf
|
|
||||||
config = OmegaConf.load(f"{cmd_opts.config}")
|
|
||||||
# XLMR-Large
|
|
||||||
try:
|
|
||||||
text_model_name = config.model.params.cond_stage_config.params.name
|
|
||||||
|
|
||||||
except :
|
|
||||||
text_model_name = "stable_diffusion"
|
|
||||||
|
|
||||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue