Browse Source

refactoring buffer masks dalle

feature/fix_masks
shonenkov 4 years ago
parent
commit
4933e62fce
6 changed files with 30 additions and 47 deletions
  1. +1
    -1
      README.md
  2. +1
    -1
      rudalle/__init__.py
  3. +2
    -4
      rudalle/dalle/__init__.py
  4. +2
    -35
      rudalle/dalle/model.py
  5. +22
    -4
      rudalle/dalle/transformer.py
  6. +2
    -2
      setup.py

+ 1
- 1
README.md View File

@@ -7,7 +7,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
```
pip install rudalle==0.0.1rc7
pip install rudalle==0.0.1rc8
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)


+ 1
- 1
rudalle/__init__.py View File

@@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc7'
__version__ = '0.0.1-rc8'

+ 2
- 4
rudalle/dalle/__init__.py View File

@@ -21,14 +21,13 @@ MODELS = {
attention_dropout_prob=0.1,
image_tokens_per_dim=32,
text_seq_length=128,
use_masks=True,
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
vocab_size=16384+128,
image_vocab_size=8192,
),
repo_id='sberbank-ai/rudalle-Malevich',
filename='pytorch_model.bin',
filename='pytorch_model_v2.bin',
full_description='', # TODO
),
'small': dict(
@@ -42,7 +41,6 @@ MODELS = {
attention_dropout_prob=0.1,
image_tokens_per_dim=32,
text_seq_length=128,
use_masks=True,
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
vocab_size=16384+128,
@@ -63,7 +61,7 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')
config = MODELS[name]
model = DalleModel(device=device, fp16=fp16, **config['model_params'])
model = DalleModel(device=device, **config['model_params'])
if pretrained:
cache_dir = os.path.join(cache_dir, name)
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])


+ 2
- 35
rudalle/dalle/model.py View File

@@ -4,7 +4,6 @@ import torch.nn.functional as F
from einops import rearrange
from .utils import exists, is_empty, init_method_normal
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
from .transformer import DalleTransformer
@@ -23,14 +22,11 @@ class DalleModel(torch.nn.Module):
image_tokens_per_dim=32,
image_vocab_size=16384,
loss_img_weight=7,
fp16=False,
use_masks=True,
cogview_sandwich_layernorm=False,
cogview_pb_relax=False):
super(DalleModel, self).__init__()
self.device = device
self.fp16 = fp16
self.image_tokens_per_dim = image_tokens_per_dim
self.image_seq_length = image_tokens_per_dim ** 2
self.text_seq_length = text_seq_length
@@ -39,13 +35,6 @@ class DalleModel(torch.nn.Module):
self.vocab_size = vocab_size
self.loss_img_weight = loss_img_weight
# TODO "to"
mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
if use_masks:
self._mask_map = mask_map
else:
self._mask_map = []
init_method = init_method_normal(std=0.02)
self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
@@ -74,35 +63,15 @@
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
text_seq_length=text_seq_length,
image_tokens_per_dim=image_tokens_per_dim,
cogview_sandwich_layernorm=cogview_sandwich_layernorm,
cogview_pb_relax=cogview_pb_relax,
)
self.transformer._mask_map = self._mask_map
def get_param(self, item):
return getattr(self, item)
def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
if self.fp16:
row_mask = row_mask.half()
col_mask = col_mask.half()
conv_mask = conv_mask.half()
self.register_buffer('row_mask', row_mask)
self.register_buffer('col_mask', col_mask)
self.register_buffer('conv_mask', conv_mask)
mask_map = []
for i in range(num_layers):
if ((i - 1) % 4 == 0):
mask_map.append(col_mask)
elif i != num_layers - 1:
mask_map.append(row_mask)
else:
mask_map.append(conv_mask)
return mask_map
def get_image_pos_embeddings(self, image_input_ids, past_length=0):
input_shape = image_input_ids.size()
row_ids = torch.arange(past_length, input_shape[-1] + past_length,
@@ -172,6 +141,4 @@ class DalleModel(torch.nn.Module):
def to(self, device, *args, **kwargs):
self.device = device
self._mask_map = [mask.to(device) for mask in self._mask_map]
self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
return super().to(device, *args, **kwargs)

+ 22
- 4
rudalle/dalle/transformer.py View File

@@ -5,6 +5,7 @@ import torch
from torch.nn import LayerNorm
from .utils import divide, split_tensor_along_last_dim
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
@torch.jit.script
@@ -45,9 +46,11 @@ class DalleTransformer(torch.nn.Module):
_mask_map = []
def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False):
text_seq_length, image_tokens_per_dim, layernorm_epsilon=1.0e-5,
cogview_sandwich_layernorm=False, cogview_pb_relax=False):
super(DalleTransformer, self).__init__()
self.num_layers = num_layers
# CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
self.cogview_pb_relax = cogview_pb_relax
@@ -64,15 +67,30 @@
) for _ in range(num_layers)
])
row_mask = get_row_mask(text_seq_length, image_tokens_per_dim)
col_mask = get_col_mask(text_seq_length, image_tokens_per_dim)
conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim)
self.register_buffer('row_mask', row_mask)
self.register_buffer('col_mask', col_mask)
self.register_buffer('conv_mask', conv_mask)
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def _get_layer_mask(self, layer_id):
if ((layer_id - 1) % 4 == 0):
layer_mask = self.col_mask
elif layer_id != self.num_layers - 1:
layer_mask = self.row_mask
else:
layer_mask = self.conv_mask
return layer_mask
def forward(self, hidden_states, attention_mask, has_cache, use_cache):
for i, layer in enumerate(self.layers):
mask = attention_mask
if len(self._mask_map):
layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)]
mask = torch.mul(attention_mask, layer_mask)
layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
mask = torch.mul(attention_mask, layer_mask)
hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
output = self.final_layernorm(hidden_states)
return output, present_has_cache


+ 2
- 2
setup.py View File

@@ -45,8 +45,8 @@ setup(
name='rudalle',
version=get_version(),
author='SberAI, SberDevices',
author_email='',
description='',
author_email='[email protected]',
description='ruDALL-E generate images from texts in Russian language',
packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'],
package_data={'rudalle/vae': ['*.yml']},
install_requires=get_requirements(),


Loading…
Cancel
Save