Merge remote-tracking branch 'baai-open-internal/master' into alt-diffusion
This commit is contained in:
commit
3f401cdb64
6 changed files with 310 additions and 8 deletions
72
configs/altdiffusion/ad-inference.yaml
Normal file
72
configs/altdiffusion/ad-inference.yaml
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: modules.xlmr.BertSeriesModelWithTransformation
|
||||||
|
params:
|
||||||
|
name: "XLMR-Large"
|
|
@ -78,17 +78,24 @@ class StableDiffusionModelHijack:
|
||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
|
||||||
|
if shared.text_model_name == "XLMR-Large":
|
||||||
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
|
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
|
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
apply_optimizations()
|
||||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
apply_optimizations()
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
apply_optimizations()
|
|
||||||
fix_checkpoint()
|
fix_checkpoint()
|
||||||
|
|
||||||
def flatten(el):
|
def flatten(el):
|
||||||
|
@ -101,7 +108,11 @@ class StableDiffusionModelHijack:
|
||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
|
||||||
|
if shared.text_model_name == "XLMR-Large":
|
||||||
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
|
@ -129,8 +140,8 @@ class StableDiffusionModelHijack:
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
|
||||||
|
|
||||||
|
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
|
|
||||||
from modules import prompt_parser, devices
|
from modules import prompt_parser, devices
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
def get_target_prompt_token_count(token_count):
|
def get_target_prompt_token_count(token_count):
|
||||||
return math.ceil(max(token_count, 1) / 75) * 75
|
return math.ceil(max(token_count, 1) / 75) * 75
|
||||||
|
@ -177,6 +177,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
|
if shared.text_model_name == "XLMR-Large":
|
||||||
|
return self.wrapped.encode(text)
|
||||||
|
|
||||||
use_old = opts.use_old_emphasis_implementation
|
use_old = opts.use_old_emphasis_implementation
|
||||||
if use_old:
|
if use_old:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||||
|
@ -254,6 +257,9 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__(wrapped, hijack)
|
super().__init__(wrapped, hijack)
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
|
if shared.text_model_name == "XLMR-Large":
|
||||||
|
self.comma_token = None
|
||||||
|
else :
|
||||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||||
|
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
|
|
|
@ -108,6 +108,14 @@ restricted_opts = {
|
||||||
"outdir_txt2img_grids",
|
"outdir_txt2img_grids",
|
||||||
"outdir_save",
|
"outdir_save",
|
||||||
}
|
}
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
config = OmegaConf.load(f"{cmd_opts.config}")
|
||||||
|
# XLMR-Large
|
||||||
|
try:
|
||||||
|
text_model_name = config.model.params.cond_stage_config.params.name
|
||||||
|
|
||||||
|
except :
|
||||||
|
text_model_name = "stable_diffusion"
|
||||||
|
|
||||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||||
|
|
||||||
|
|
137
modules/xlmr.py
Normal file
137
modules/xlmr.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
from transformers import BertPreTrainedModel,BertModel,BertConfig
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||||
|
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class BertSeriesConfig(BertConfig):
|
||||||
|
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||||
|
|
||||||
|
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||||
|
self.project_dim = project_dim
|
||||||
|
self.pooler_fn = pooler_fn
|
||||||
|
self.learn_encoder = learn_encoder
|
||||||
|
|
||||||
|
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||||
|
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
self.project_dim = project_dim
|
||||||
|
self.pooler_fn = pooler_fn
|
||||||
|
self.learn_encoder = learn_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
config_class = BertSeriesConfig
|
||||||
|
|
||||||
|
def __init__(self, config=None, **kargs):
|
||||||
|
# modify initialization for autoloading
|
||||||
|
if config is None:
|
||||||
|
config = XLMRobertaConfig()
|
||||||
|
config.attention_probs_dropout_prob= 0.1
|
||||||
|
config.bos_token_id=0
|
||||||
|
config.eos_token_id=2
|
||||||
|
config.hidden_act='gelu'
|
||||||
|
config.hidden_dropout_prob=0.1
|
||||||
|
config.hidden_size=1024
|
||||||
|
config.initializer_range=0.02
|
||||||
|
config.intermediate_size=4096
|
||||||
|
config.layer_norm_eps=1e-05
|
||||||
|
config.max_position_embeddings=514
|
||||||
|
|
||||||
|
config.num_attention_heads=16
|
||||||
|
config.num_hidden_layers=24
|
||||||
|
config.output_past=True
|
||||||
|
config.pad_token_id=1
|
||||||
|
config.position_embedding_type= "absolute"
|
||||||
|
|
||||||
|
config.type_vocab_size= 1
|
||||||
|
config.use_cache=True
|
||||||
|
config.vocab_size= 250002
|
||||||
|
config.project_dim = 768
|
||||||
|
config.learn_encoder = False
|
||||||
|
super().__init__(config)
|
||||||
|
self.roberta = XLMRobertaModel(config)
|
||||||
|
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||||
|
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||||
|
self.pooler = lambda x: x[:,0]
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def encode(self,c):
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
text = self.tokenizer(c,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_length=False,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt")
|
||||||
|
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||||
|
text["attention_mask"] = torch.tensor(
|
||||||
|
text['attention_mask']).to(device)
|
||||||
|
features = self(**text)
|
||||||
|
return features['projection_state']
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
) :
|
||||||
|
r"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
|
||||||
|
outputs = self.roberta(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# last module outputs
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
# project every module
|
||||||
|
sequence_output_ln = self.pre_LN(sequence_output)
|
||||||
|
|
||||||
|
# pooler
|
||||||
|
pooler_output = self.pooler(sequence_output_ln)
|
||||||
|
pooler_output = self.transformation(pooler_output)
|
||||||
|
projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'pooler_output':pooler_output,
|
||||||
|
'last_hidden_state':outputs.last_hidden_state,
|
||||||
|
'hidden_states':outputs.hidden_states,
|
||||||
|
'attentions':outputs.attentions,
|
||||||
|
'projection_state':projection_state,
|
||||||
|
'sequence_out': sequence_output
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||||
|
base_model_prefix = 'roberta'
|
||||||
|
config_class= RobertaSeriesConfig
|
68
v2-inference-v.yaml
Normal file
68
v2-inference-v.yaml
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-4
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
parameterization: "v"
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
Loading…
Reference in a new issue