diff --git a/README.md b/README.md index f383ddd..b63d801 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master) ``` -pip install rudalle==0.0.1rc7 +pip install rudalle==0.0.1rc8 ``` ### 🤗 HF Models: [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) diff --git a/rudalle/__init__.py b/rudalle/__init__.py index 47af0e6..e7e78eb 100644 --- a/rudalle/__init__.py +++ b/rudalle/__init__.py @@ -22,4 +22,4 @@ __all__ = [ 'image_prompts', ] -__version__ = '0.0.1-rc7' +__version__ = '0.0.1-rc8' diff --git a/rudalle/dalle/__init__.py b/rudalle/dalle/__init__.py index e9d82cc..c4ac465 100644 --- a/rudalle/dalle/__init__.py +++ b/rudalle/dalle/__init__.py @@ -21,14 +21,13 @@ MODELS = { attention_dropout_prob=0.1, image_tokens_per_dim=32, text_seq_length=128, - use_masks=True, cogview_sandwich_layernorm=True, cogview_pb_relax=True, vocab_size=16384+128, image_vocab_size=8192, ), repo_id='sberbank-ai/rudalle-Malevich', - filename='pytorch_model.bin', + filename='pytorch_model_v2.bin', full_description='', # TODO ), 'small': dict( @@ -42,7 +41,6 @@ MODELS = { attention_dropout_prob=0.1, image_tokens_per_dim=32, text_seq_length=128, - use_masks=True, cogview_sandwich_layernorm=True, cogview_pb_relax=True, vocab_size=16384+128, @@ -63,7 +61,7 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.') config = MODELS[name] - model = DalleModel(device=device, fp16=fp16, **config['model_params']) + model = DalleModel(device=device, **config['model_params']) if pretrained: cache_dir = os.path.join(cache_dir, name) config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) diff --git a/rudalle/dalle/model.py b/rudalle/dalle/model.py index 236683d..040e2ec 100644 --- a/rudalle/dalle/model.py +++ b/rudalle/dalle/model.py @@ -4,7 +4,6 @@ import torch.nn.functional as F from einops import rearrange from .utils import exists, is_empty, init_method_normal -from .image_attention import get_conv_mask, get_row_mask, get_col_mask from .transformer import DalleTransformer @@ -23,14 +22,11 @@ class DalleModel(torch.nn.Module): image_tokens_per_dim=32, image_vocab_size=16384, loss_img_weight=7, - fp16=False, - use_masks=True, cogview_sandwich_layernorm=False, cogview_pb_relax=False): super(DalleModel, self).__init__() self.device = device - self.fp16 = fp16 self.image_tokens_per_dim = image_tokens_per_dim self.image_seq_length = image_tokens_per_dim ** 2 self.text_seq_length = text_seq_length @@ -39,13 +35,6 @@ class DalleModel(torch.nn.Module): self.vocab_size = vocab_size self.loss_img_weight = loss_img_weight - # TODO "to" - mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim) - if use_masks: - self._mask_map = mask_map - else: - self._mask_map = [] - init_method = init_method_normal(std=0.02) self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size) @@ -74,35 +63,15 @@ class DalleModel(torch.nn.Module): num_attention_heads, attention_dropout_prob, output_dropout_prob, + text_seq_length=text_seq_length, + image_tokens_per_dim=image_tokens_per_dim, cogview_sandwich_layernorm=cogview_sandwich_layernorm, cogview_pb_relax=cogview_pb_relax, ) - self.transformer._mask_map = self._mask_map def get_param(self, item): return getattr(self, item) - def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim): - row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device) - col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device) - conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device) - if self.fp16: - row_mask = row_mask.half() - col_mask = col_mask.half() - conv_mask = conv_mask.half() - self.register_buffer('row_mask', row_mask) - self.register_buffer('col_mask', col_mask) - self.register_buffer('conv_mask', conv_mask) - mask_map = [] - for i in range(num_layers): - if ((i - 1) % 4 == 0): - mask_map.append(col_mask) - elif i != num_layers - 1: - mask_map.append(row_mask) - else: - mask_map.append(conv_mask) - return mask_map - def get_image_pos_embeddings(self, image_input_ids, past_length=0): input_shape = image_input_ids.size() row_ids = torch.arange(past_length, input_shape[-1] + past_length, @@ -172,6 +141,4 @@ class DalleModel(torch.nn.Module): def to(self, device, *args, **kwargs): self.device = device - self._mask_map = [mask.to(device) for mask in self._mask_map] - self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map] return super().to(device, *args, **kwargs) diff --git a/rudalle/dalle/transformer.py b/rudalle/dalle/transformer.py index 8dabffc..176b03e 100755 --- a/rudalle/dalle/transformer.py +++ b/rudalle/dalle/transformer.py @@ -5,6 +5,7 @@ import torch from torch.nn import LayerNorm from .utils import divide, split_tensor_along_last_dim +from .image_attention import get_conv_mask, get_row_mask, get_col_mask @torch.jit.script @@ -45,9 +46,11 @@ class DalleTransformer(torch.nn.Module): _mask_map = [] def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, - layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False): + text_seq_length, image_tokens_per_dim, layernorm_epsilon=1.0e-5, + cogview_sandwich_layernorm=False, cogview_pb_relax=False): super(DalleTransformer, self).__init__() + self.num_layers = num_layers # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf self.cogview_pb_relax = cogview_pb_relax @@ -64,15 +67,30 @@ class DalleTransformer(torch.nn.Module): ) for _ in range(num_layers) ]) + row_mask = get_row_mask(text_seq_length, image_tokens_per_dim) + col_mask = get_col_mask(text_seq_length, image_tokens_per_dim) + conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim) + self.register_buffer('row_mask', row_mask) + self.register_buffer('col_mask', col_mask) + self.register_buffer('conv_mask', conv_mask) + # Final layer norm before output. self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + def _get_layer_mask(self, layer_id): + if ((layer_id - 1) % 4 == 0): + layer_mask = self.col_mask + elif layer_id != self.num_layers - 1: + layer_mask = self.row_mask + else: + layer_mask = self.conv_mask + return layer_mask + def forward(self, hidden_states, attention_mask, has_cache, use_cache): for i, layer in enumerate(self.layers): mask = attention_mask - if len(self._mask_map): - layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)] - mask = torch.mul(attention_mask, layer_mask) + layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)] + mask = torch.mul(attention_mask, layer_mask) hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache) output = self.final_layernorm(hidden_states) return output, present_has_cache diff --git a/setup.py b/setup.py index 6fdb599..614cc1b 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,8 @@ setup( name='rudalle', version=get_version(), author='SberAI, SberDevices', - author_email='', - description='', + author_email='shonenkov@phystech.edu', + description='ruDALL-E generate images from texts in Russian language', packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'], package_data={'rudalle/vae': ['*.yml']}, install_requires=get_requirements(),