Browse Source

Merge pull request #47 from sberbank-ai/feature/fix_masks

refactoring buffer masks dalle
pull/49/head
Alex 4 years ago
committed by GitHub
parent
commit
2eedbb2f85
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 30 additions and 47 deletions
  1. +1
    -1
      README.md
  2. +1
    -1
      rudalle/__init__.py
  3. +2
    -4
      rudalle/dalle/__init__.py
  4. +2
    -35
      rudalle/dalle/model.py
  5. +22
    -4
      rudalle/dalle/transformer.py
  6. +2
    -2
      setup.py

+ 1
- 1
README.md View File

@ -7,7 +7,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
``` ```
pip install rudalle==0.0.1rc7
pip install rudalle==0.0.1rc8
``` ```
### 🤗 HF Models: ### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)


+ 1
- 1
rudalle/__init__.py View File

@ -22,4 +22,4 @@ __all__ = [
'image_prompts', 'image_prompts',
] ]
__version__ = '0.0.1-rc7'
__version__ = '0.0.1-rc8'

+ 2
- 4
rudalle/dalle/__init__.py View File

@ -21,14 +21,13 @@ MODELS = {
attention_dropout_prob=0.1, attention_dropout_prob=0.1,
image_tokens_per_dim=32, image_tokens_per_dim=32,
text_seq_length=128, text_seq_length=128,
use_masks=True,
cogview_sandwich_layernorm=True, cogview_sandwich_layernorm=True,
cogview_pb_relax=True, cogview_pb_relax=True,
vocab_size=16384+128, vocab_size=16384+128,
image_vocab_size=8192, image_vocab_size=8192,
), ),
repo_id='sberbank-ai/rudalle-Malevich', repo_id='sberbank-ai/rudalle-Malevich',
filename='pytorch_model.bin',
filename='pytorch_model_v2.bin',
full_description='', # TODO full_description='', # TODO
), ),
'small': dict( 'small': dict(
@ -42,7 +41,6 @@ MODELS = {
attention_dropout_prob=0.1, attention_dropout_prob=0.1,
image_tokens_per_dim=32, image_tokens_per_dim=32,
text_seq_length=128, text_seq_length=128,
use_masks=True,
cogview_sandwich_layernorm=True, cogview_sandwich_layernorm=True,
cogview_pb_relax=True, cogview_pb_relax=True,
vocab_size=16384+128, vocab_size=16384+128,
@ -63,7 +61,7 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.') print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')
config = MODELS[name] config = MODELS[name]
model = DalleModel(device=device, fp16=fp16, **config['model_params'])
model = DalleModel(device=device, **config['model_params'])
if pretrained: if pretrained:
cache_dir = os.path.join(cache_dir, name) cache_dir = os.path.join(cache_dir, name)
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])


+ 2
- 35
rudalle/dalle/model.py View File

@ -4,7 +4,6 @@ import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from .utils import exists, is_empty, init_method_normal from .utils import exists, is_empty, init_method_normal
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
from .transformer import DalleTransformer from .transformer import DalleTransformer
@ -23,14 +22,11 @@ class DalleModel(torch.nn.Module):
image_tokens_per_dim=32, image_tokens_per_dim=32,
image_vocab_size=16384, image_vocab_size=16384,
loss_img_weight=7, loss_img_weight=7,
fp16=False,
use_masks=True,
cogview_sandwich_layernorm=False, cogview_sandwich_layernorm=False,
cogview_pb_relax=False): cogview_pb_relax=False):
super(DalleModel, self).__init__() super(DalleModel, self).__init__()
self.device = device self.device = device
self.fp16 = fp16
self.image_tokens_per_dim = image_tokens_per_dim self.image_tokens_per_dim = image_tokens_per_dim
self.image_seq_length = image_tokens_per_dim ** 2 self.image_seq_length = image_tokens_per_dim ** 2
self.text_seq_length = text_seq_length self.text_seq_length = text_seq_length
@ -39,13 +35,6 @@ class DalleModel(torch.nn.Module):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.loss_img_weight = loss_img_weight self.loss_img_weight = loss_img_weight
# TODO "to"
mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
if use_masks:
self._mask_map = mask_map
else:
self._mask_map = []
init_method = init_method_normal(std=0.02) init_method = init_method_normal(std=0.02)
self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size) self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
@ -74,35 +63,15 @@ class DalleModel(torch.nn.Module):
num_attention_heads, num_attention_heads,
attention_dropout_prob, attention_dropout_prob,
output_dropout_prob, output_dropout_prob,
text_seq_length=text_seq_length,
image_tokens_per_dim=image_tokens_per_dim,
cogview_sandwich_layernorm=cogview_sandwich_layernorm, cogview_sandwich_layernorm=cogview_sandwich_layernorm,
cogview_pb_relax=cogview_pb_relax, cogview_pb_relax=cogview_pb_relax,
) )
self.transformer._mask_map = self._mask_map
def get_param(self, item): def get_param(self, item):
return getattr(self, item) return getattr(self, item)
def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
if self.fp16:
row_mask = row_mask.half()
col_mask = col_mask.half()
conv_mask = conv_mask.half()
self.register_buffer('row_mask', row_mask)
self.register_buffer('col_mask', col_mask)
self.register_buffer('conv_mask', conv_mask)
mask_map = []
for i in range(num_layers):
if ((i - 1) % 4 == 0):
mask_map.append(col_mask)
elif i != num_layers - 1:
mask_map.append(row_mask)
else:
mask_map.append(conv_mask)
return mask_map
def get_image_pos_embeddings(self, image_input_ids, past_length=0): def get_image_pos_embeddings(self, image_input_ids, past_length=0):
input_shape = image_input_ids.size() input_shape = image_input_ids.size()
row_ids = torch.arange(past_length, input_shape[-1] + past_length, row_ids = torch.arange(past_length, input_shape[-1] + past_length,
@ -172,6 +141,4 @@ class DalleModel(torch.nn.Module):
def to(self, device, *args, **kwargs): def to(self, device, *args, **kwargs):
self.device = device self.device = device
self._mask_map = [mask.to(device) for mask in self._mask_map]
self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
return super().to(device, *args, **kwargs) return super().to(device, *args, **kwargs)

+ 22
- 4
rudalle/dalle/transformer.py View File

@ -5,6 +5,7 @@ import torch
from torch.nn import LayerNorm from torch.nn import LayerNorm
from .utils import divide, split_tensor_along_last_dim from .utils import divide, split_tensor_along_last_dim
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
@torch.jit.script @torch.jit.script
@ -45,9 +46,11 @@ class DalleTransformer(torch.nn.Module):
_mask_map = [] _mask_map = []
def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False):
text_seq_length, image_tokens_per_dim, layernorm_epsilon=1.0e-5,
cogview_sandwich_layernorm=False, cogview_pb_relax=False):
super(DalleTransformer, self).__init__() super(DalleTransformer, self).__init__()
self.num_layers = num_layers
# CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
self.cogview_pb_relax = cogview_pb_relax self.cogview_pb_relax = cogview_pb_relax
@ -64,15 +67,30 @@ class DalleTransformer(torch.nn.Module):
) for _ in range(num_layers) ) for _ in range(num_layers)
]) ])
row_mask = get_row_mask(text_seq_length, image_tokens_per_dim)
col_mask = get_col_mask(text_seq_length, image_tokens_per_dim)
conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim)
self.register_buffer('row_mask', row_mask)
self.register_buffer('col_mask', col_mask)
self.register_buffer('conv_mask', conv_mask)
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def _get_layer_mask(self, layer_id):
if ((layer_id - 1) % 4 == 0):
layer_mask = self.col_mask
elif layer_id != self.num_layers - 1:
layer_mask = self.row_mask
else:
layer_mask = self.conv_mask
return layer_mask
def forward(self, hidden_states, attention_mask, has_cache, use_cache): def forward(self, hidden_states, attention_mask, has_cache, use_cache):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
mask = attention_mask mask = attention_mask
if len(self._mask_map):
layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)]
mask = torch.mul(attention_mask, layer_mask)
layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
mask = torch.mul(attention_mask, layer_mask)
hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache) hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
return output, present_has_cache return output, present_has_cache


+ 2
- 2
setup.py View File

@ -45,8 +45,8 @@ setup(
name='rudalle', name='rudalle',
version=get_version(), version=get_version(),
author='SberAI, SberDevices', author='SberAI, SberDevices',
author_email='',
description='',
author_email='[email protected]',
description='ruDALL-E generate images from texts in Russian language',
packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'], packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'],
package_data={'rudalle/vae': ['*.yml']}, package_data={'rudalle/vae': ['*.yml']},
install_requires=get_requirements(), install_requires=get_requirements(),


Loading…
Cancel
Save