Browse Source

initial commit

pull/3/head
Alex Shonenkov 5 years ago
commit
e96631a867
23 changed files with 1435 additions and 0 deletions
  1. +168
    -0
      .gitignore
  2. +31
    -0
      .pre-commit-config.yaml
  3. +4
    -0
      requirements-test.txt
  4. +8
    -0
      requirements.txt
  5. +15
    -0
      rudalle/__init__.py
  6. +76
    -0
      rudalle/dalle/__init__.py
  7. +60
    -0
      rudalle/dalle/fp16.py
  8. +48
    -0
      rudalle/dalle/image_attention.py
  9. +171
    -0
      rudalle/dalle/model.py
  10. +332
    -0
      rudalle/dalle/transformer.py
  11. +54
    -0
      rudalle/dalle/utils.py
  12. +61
    -0
      rudalle/pipelines.py
  13. +64
    -0
      rudalle/tokenizer.py
  14. +36
    -0
      rudalle/utils.py
  15. +24
    -0
      rudalle/vae/__init__.py
  16. +100
    -0
      rudalle/vae/model.py
  17. +34
    -0
      rudalle/vae/vqgan.gumbelf8-sber.config.yml
  18. +13
    -0
      setup.cfg
  19. +0
    -0
      tests/__init__.py
  20. +39
    -0
      tests/conftest.py
  21. +31
    -0
      tests/test_dalle.py
  22. +17
    -0
      tests/test_tokenizer.py
  23. +49
    -0
      tests/test_vae.py

+ 168
- 0
.gitignore View File

@ -0,0 +1,168 @@
# Created by .ignore support plugin (hsz.mobi)
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
settings/local.py
logs/*.log
# User-specific stuff:
.idea/
# Sensitive or high-churn files:
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.xml
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
# Gradle:
.idea/**/gradle.xml
.idea/**/libraries
# CMake
cmake-build-debug/
# Mongo Explorer plugin:
.idea/**/mongoSettings.xml
## File-based project format:
*.iws
## Plugin-specific files:
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
/tests/load_tests/logs/*
/tests/.pytest_cache/
ws_test.py
/.vscode/
.s3_cache/
mlruns
*.pyc
*.swp
*.pt
*.bin
.vscode/
runs/
jupyters/custom_*
*logs/

+ 31
- 0
.pre-commit-config.yaml View File

@ -0,0 +1,31 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.2.3
hooks:
- id: check-docstring-first
stages:
- commit
- push
- id: check-merge-conflict
stages:
- push
- id: double-quote-string-fixer
stages:
- commit
- push
- id: fix-encoding-pragma
stages:
- commit
- push
- id: flake8
args: ['--config=setup.cfg']
stages:
- commit
- push
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v1.4.4
hooks:
- id: autopep8
stages:
- commit
- push

+ 4
- 0
requirements-test.txt View File

@ -0,0 +1,4 @@
-r requirements.txt
pytest
pytest-cov
pre-commit

+ 8
- 0
requirements.txt View File

@ -0,0 +1,8 @@
taming-transformers==0.0.1
more_itertools==8.10.0
transformers==4.10.2
youtokentome==1.0.6
einops==0.3.2
torch
torchvision
matplotlib

+ 15
- 0
rudalle/__init__.py View File

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Public package API: factory functions for the three pipeline components
# (DALL-E model, VQGAN VAE, BPE tokenizer) plus the submodules themselves.
from .vae import get_vae
from .dalle import get_rudalle_model
from .tokenizer import get_tokenizer

from . import vae, dalle, tokenizer


__all__ = [
    'get_vae',
    'get_rudalle_model',
    'get_tokenizer',
    'vae',
    'dalle',
    'tokenizer',
]

+ 76
- 0
rudalle/dalle/__init__.py View File

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
import os
import torch
from huggingface_hub import hf_hub_url, cached_download
from .model import DalleModel
from .fp16 import FP16Module
# Registry of available ruDALL-E configurations: the DalleModel hyper-parameters
# plus the HuggingFace Hub location of the pretrained weights.
MODELS = {
    'Malevich': dict(
        description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, '
                    'that uses Russian language and text+image multi-modality.',
        model_params=dict(
            num_layers=24,
            hidden_size=2048,
            num_attention_heads=16,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,  # image is a 32x32 grid of codebook tokens
            text_seq_length=128,
            use_masks=True,  # enable the CogView row/col/conv sparse attention masks
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            vocab_size=16384+128,  # BPE vocab + extra slots used for per-position padding ids
            image_vocab_size=8192,
        ),
        repo_id='sberbank-ai/rudalle-Malevich',
        filename='pytorch_model.bin',
        full_description='',  # TODO
    ),
    'small': dict(
        # Smaller configuration; no pretrained checkpoint published yet
        # (empty repo_id/filename), so use it with pretrained=False.
        description='',
        model_params=dict(
            num_layers=12,
            hidden_size=768,
            num_attention_heads=12,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,
            text_seq_length=128,
            use_masks=True,
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            vocab_size=16384+128,
            image_vocab_size=8192,
        ),
        repo_id='',
        filename='',
        full_description='',  # TODO
    ),
}
def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle'):
    """Build a DalleModel by config name, optionally loading pretrained weights.

    Args:
        name: key into MODELS ('Malevich' or 'small').
        pretrained: when True, download the checkpoint from HuggingFace Hub
            (cached under ``cache_dir/<name>``) and load it.
        fp16: build in half precision and wrap the model in FP16Module.
        device: torch device string the model is moved to.
        cache_dir: root directory for downloaded checkpoints.

    Returns:
        The model in eval() mode, moved to ``device``.
    """
    assert name in MODELS
    config = MODELS[name]
    model = DalleModel(device=device, fp16=fp16, **config['model_params'])
    if pretrained:
        cache_dir = os.path.join(cache_dir, name)
        config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
        # force_filename keeps a stable file name so torch.load below finds it.
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
        checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
        model.load_state_dict(checkpoint)
    if fp16:
        model = FP16Module(model)
    model.eval()
    model = model.to(device)
    if config['description'] and pretrained:
        print(config['description'])
    return model

+ 60
- 0
rudalle/dalle/fp16.py View File

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
    """Apply `conversion` to `val`, recursing through nested tuples/lists
    while preserving each container's type."""
    if isinstance(val, (tuple, list)):
        converted = [conversion_helper(item, conversion) for item in val]
        return tuple(converted) if isinstance(val, tuple) else converted
    return conversion(val)
def fp32_to_fp16(val):
    """Convert fp32 `val` to fp16, recursing through nested tuples/lists."""
    def half_conversion(val):
        # Unwrap Parameter/Variable to type-check the underlying tensor data.
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        # Only float32 tensors are converted; other types pass through unchanged.
        if isinstance(val_typecheck, FLOAT_TYPES):
            val = val.half()
        return val
    return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
    """Convert fp16 `val` to fp32, recursing through nested tuples/lists."""
    def float_conversion(val):
        # Unwrap Parameter/Variable to type-check the underlying tensor data.
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        # Only float16 tensors are converted; other types pass through unchanged.
        if isinstance(val_typecheck, HALF_TYPES):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)
class FP16Module(nn.Module):
    """Wrapper that runs `module` in half precision.

    Positional inputs are cast fp32 -> fp16 before the forward pass and the
    outputs are cast back fp16 -> fp32, so callers keep working in fp32.
    """

    def __init__(self, module):
        super(FP16Module, self).__init__()
        # Store the wrapped module (converted to half) under the name 'module'.
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        # NOTE: only positional inputs are converted; kwargs pass through as-is.
        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Delegate so checkpoints are saved/loaded without the wrapper prefix.
        return self.module.state_dict(destination, prefix, keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        # Mirrors DalleModel.get_param so pipelines work on wrapped models too.
        return self.module.get_param(item)

+ 48
- 0
rudalle/dalle/image_attention.py View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
import torch
def _init_mask(text_tokens, image_tokens_per_dim):
attn_size = text_tokens + image_tokens_per_dim**2
mask = torch.tril(torch.ones(attn_size, attn_size))
return mask
def get_row_mask(text_tokens=256, image_tokens_per_dim=32):
    """Causal mask where each image key stops being visible to queries more
    than one image row (plus one position) after it."""
    attn_size = text_tokens + image_tokens_per_dim ** 2
    mask = torch.tril(torch.ones(attn_size, attn_size))
    step = image_tokens_per_dim + 1
    for col in range(text_tokens, attn_size):
        # queries further than `step` positions past `col` cannot attend to it
        mask[col + step:, col] = 0.0
    return mask
def get_col_mask(text_tokens=256, image_tokens_per_dim=32):
    """Causal mask where, for image keys, only every image_tokens_per_dim-th
    later query keeps visibility (column-sparse attention)."""
    attn_size = text_tokens + image_tokens_per_dim ** 2
    mask = torch.tril(torch.ones(attn_size, attn_size))
    step = image_tokens_per_dim - 1
    for col in range(text_tokens, attn_size):
        for start in range(1, attn_size, step + 1):
            # zero a run of `step` rows, leaving one row per image row visible
            mask[col + start: col + start + step, col] = 0.0
    return mask
def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11):
    """Causal mask restricting each image query, beyond the text prefix, to a
    kernel x kernel neighbourhood (with wrap-around) of its pixel.

    Args:
        text_tokens: number of text positions (always fully visible causally).
        image_tokens_per_dim: image grid side length.
        kernel: side of the square convolutional window.

    Returns:
        A float tensor [s, s] with 1.0 where attention is allowed.
    """
    attn_size = text_tokens + image_tokens_per_dim ** 2
    mask = torch.tril(torch.ones(attn_size, attn_size))
    shift = kernel // 2
    for pos in range(text_tokens, mask.size(1)):
        # First block everything after `pos` in this column...
        mask[pos + 1:, pos] = 0.0
        pixel_id = pos - text_tokens
        row = pixel_id // image_tokens_per_dim
        col = pixel_id % image_tokens_per_dim
        # ...then re-enable later pixels inside the wrapped conv window.
        # (BUGFIX/cleanup: the original also filled a local `img` tensor that
        # was never read — dead code, removed; the produced mask is unchanged.)
        for r in range(-shift, shift + 1):
            for c in range(-shift, shift + 1):
                c_abs = (c + col) % image_tokens_per_dim
                r_abs = (r + row) % image_tokens_per_dim
                cell_id = r_abs * image_tokens_per_dim + c_abs
                if text_tokens + cell_id > pos:
                    mask[text_tokens + cell_id, pos] = 1.0
    return mask

+ 171
- 0
rudalle/dalle/model.py View File

@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from einops import rearrange
from .utils import exists, is_empty, init_method_normal
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
from .transformer import DalleTransformer
class DalleModel(torch.nn.Module):
    """ruDALL-E: GPT-like autoregressive transformer over a joint sequence of
    text BPE tokens followed by image codebook tokens."""

    def __init__(self,
                 device,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 text_seq_length=128,
                 image_tokens_per_dim=32,
                 image_vocab_size=16384,
                 loss_img_weight=7,
                 fp16=False,
                 use_masks=True,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        """
        Args:
            device: device masks and generated index tensors are created on.
            num_layers: number of transformer layers.
            vocab_size: text vocabulary size.
            hidden_size: transformer hidden size.
            num_attention_heads: attention heads per layer.
            embedding_dropout_prob, attention_dropout_prob,
                output_dropout_prob: dropout probabilities.
            text_seq_length: number of text positions.
            image_tokens_per_dim: image is a square grid of this side length.
            image_vocab_size: VQ image codebook size.
            loss_img_weight: relative weight of the image loss term.
            fp16: keep the sparse attention masks in half precision.
            use_masks: enable CogView row/col/conv sparse attention masks.
            cogview_sandwich_layernorm, cogview_pb_relax: CogView training
                stabilization tricks (chapter 2.4,
                https://arxiv.org/pdf/2105.13290.pdf).
        """
        super(DalleModel, self).__init__()
        self.device = device
        self.fp16 = fp16
        self.image_tokens_per_dim = image_tokens_per_dim
        self.image_seq_length = image_tokens_per_dim ** 2
        self.text_seq_length = text_seq_length
        self.total_seq_length = self.text_seq_length + self.image_seq_length
        # Text and image tokens share one logit space of this size.
        self.total_vocab_size = vocab_size + image_vocab_size
        self.vocab_size = vocab_size
        self.loss_img_weight = loss_img_weight
        # TODO "to"
        mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
        if use_masks:
            self._mask_map = mask_map
        else:
            self._mask_map = []
        init_method = init_method_normal(std=0.02)
        self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size)
        # Position embedding (serial). +1 accounts for the extra token that
        # forward() prepends via F.pad.
        self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size)
        # Image positions are factorized into separate row and column embeddings.
        self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        init_method(self.text_pos_embeddings.weight)
        init_method(self.image_row_embeddings.weight)
        init_method(self.image_col_embeddings.weight)
        # Final projection from hidden states to the joint text+image vocab.
        self.to_logits = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden_size),
            torch.nn.Linear(hidden_size, self.total_vocab_size),
        )
        # Embeddings dropout.
        # NOTE(review): embedding_dropout is defined but never applied in
        # forward() — confirm whether that is intended.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        # Transformer
        self.transformer = DalleTransformer(
            num_layers,
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_sandwich_layernorm=cogview_sandwich_layernorm,
            cogview_pb_relax=cogview_pb_relax,
        )
        # Hand the per-layer sparse masks to the transformer.
        self.transformer._mask_map = self._mask_map

    def get_param(self, item):
        """Expose attributes by name; used by pipelines and FP16Module."""
        return getattr(self, item)

    def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
        """Build the CogView-style per-layer sparse attention masks.

        Layer pattern: column mask on layers with (i - 1) % 4 == 0, conv mask
        on the last layer, row mask everywhere else.
        """
        row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        if self.fp16:
            row_mask = row_mask.half()
            col_mask = col_mask.half()
            conv_mask = conv_mask.half()
        # Register as buffers so the masks follow the module across devices.
        self.register_buffer('row_mask', row_mask)
        self.register_buffer('col_mask', col_mask)
        self.register_buffer('conv_mask', conv_mask)
        mask_map = []
        for i in range(num_layers):
            if ((i - 1) % 4 == 0):
                mask_map.append(col_mask)
            elif i != num_layers - 1:
                mask_map.append(row_mask)
            else:
                mask_map.append(conv_mask)
        return mask_map

    def get_image_pos_embeddings(self, image_input_ids, past_length=0):
        """Positional embeddings for image tokens: row embedding + col embedding.

        `past_length` offsets positions for cached incremental decoding.
        """
        input_shape = image_input_ids.size()
        row_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) // self.image_tokens_per_dim
        row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
        col_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) % self.image_tokens_per_dim
        col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
        return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids)

    def forward(
            self,
            input_ids,
            attention_mask,
            return_loss=False,
            has_cache=False,
            use_cache=False,
    ):
        """Run the transformer over [text | image] token ids.

        Args:
            input_ids: [b, s] ids — text part first, then image codebook ids.
            attention_mask: [b, 1, s, s] causal mask.
            return_loss: when True return (loss, {'text': ..., 'image': ...})
                instead of (logits, has_cache).
            has_cache, use_cache: attention key/value cache flags.
        """
        text = input_ids[:, :self.text_seq_length]
        # Replace padding zeros with unique per-position ids taken from the
        # top of the text vocab, so each padded position is distinct.
        text_range = torch.arange(self.text_seq_length)
        text_range += (self.vocab_size - self.text_seq_length)
        text_range = text_range.to(self.device)
        text = torch.where(text == 0, text_range, text)
        # some hardcode :) prepend token id 2 — presumably the BOS id used by
        # the tokenizer; confirm against YTTMTokenizerWrapper.bos_id.
        text = F.pad(text, (1, 0), value=2)
        text_embeddings = self.text_embeddings(text) + \
            self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
        image_input_ids = input_ids[:, self.text_seq_length:]
        if exists(image_input_ids) and not is_empty(image_input_ids):
            image_embeddings = self.image_embeddings(image_input_ids) + \
                self.get_image_pos_embeddings(image_input_ids, past_length=0)
            embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
        else:
            embeddings = text_embeddings
        # some hardcode :) drop the last position if the prepended token pushed
        # the sequence past the model's maximum length.
        if embeddings.shape[1] > self.total_seq_length:
            embeddings = embeddings[:, :-1]
        # Gradient shrink: forward values are unchanged, but gradients through
        # the embeddings are scaled by alpha.
        alpha = 0.1
        embeddings = embeddings * alpha + embeddings.detach() * (1-alpha)
        attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
        transformer_output, present_has_cache = self.transformer(
            embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
        logits = self.to_logits(transformer_output)
        if return_loss is False:
            return logits, present_has_cache
        # Next-token labels: text shifted by one (prepended token dropped),
        # followed by the image ids.
        labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()
        logits = rearrange(logits, 'b n c -> b c n')
        # Text positions predict over the text vocab, image positions over the
        # image vocab.
        text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float()
        image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()
        loss_text = F.cross_entropy(
            text_logits,
            labels[:, :self.text_seq_length])
        loss_img = F.cross_entropy(
            image_logits,
            labels[:, self.text_seq_length:])
        # Weighted mix: the image term dominates by loss_img_weight : 1.
        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}

+ 332
- 0
rudalle/dalle/transformer.py View File

@ -0,0 +1,332 @@
# -*- coding: utf-8 -*-
import math
import torch
from torch.nn import LayerNorm
from .utils import divide, split_tensor_along_last_dim
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation (tanh approximation).

    The constant 0.7978845608028654 is sqrt(2 / pi).
    """
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def gelu(x):
    # Thin wrapper so callers don't depend on the jit-scripted symbol directly.
    return gelu_impl(x)
class DalleTransformer(torch.nn.Module):
    """
    This module takes input from embedding layer and it's output can
    be used directly by a logit layer. It consists of L (num-layers)
    blocks of:
        layer norm
        self attention
        residual connection
        layer norm
        mlp
        residual connection
    followed by a final layer norm.

    Arguments:
        num_layers: Number of transformer layers.
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention head in the self
            attention.
        attention_dropout_prob: dropout probability of the attention
            score in self attention.
        output_dropout_prob: dropout probability for the outputs
            after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
            division by zero.
    """

    # Per-layer sparse attention masks; assigned by DalleModel.__init__ after
    # construction. An empty list disables the per-layer masking.
    _mask_map = []

    def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
                 layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False):
        super(DalleTransformer, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        # Transformer layers.
        self.layers = torch.nn.ModuleList([
            DalleTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                cogview_sandwich_layernorm=cogview_sandwich_layernorm,
                cogview_pb_relax=cogview_pb_relax,
            ) for _ in range(num_layers)
        ])
        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, attention_mask, has_cache, use_cache):
        # hidden_states: [b, s, h]; attention_mask: [b, 1, s, s]
        for i, layer in enumerate(self.layers):
            mask = attention_mask
            if len(self._mask_map):
                # Combine the causal mask with this layer's sparse image mask,
                # cropped to the current sequence length.
                layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)]
                mask = torch.mul(attention_mask, layer_mask)
            hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
        output = self.final_layernorm(hidden_states)
        return output, present_has_cache
class DalleTransformerLayer(torch.nn.Module):
    """
    A single layer transformer.

    We use the following notation:
        h: hidden size
        n: number of attention heads
        b: batch size
        s: sequence length
    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.

    Arguments:
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention head in the self
            attention.
        attention_dropout_prob: dropout probability of the attention
            score in self attention.
        output_dropout_prob: dropout probability for the outputs
            after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
            division by zero.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 output_dropout_prob,
                 layernorm_epsilon,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleTransformerLayer, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
        self.cogview_pb_relax = cogview_pb_relax
        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        if self.cogview_sandwich_layernorm:
            # Extra "sandwich" layernorms applied right before each residual add.
            self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
            self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # Self attention.
        self.attention = DalleSelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_pb_relax=cogview_pb_relax
        )
        # Layernorm after the attention residual, before the MLP.
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # MLP
        self.mlp = DalleMLP(hidden_size, output_dropout_prob)

    def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, has_cache = self.attention(
            layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
        if self.cogview_sandwich_layernorm:
            attention_output = self.before_first_addition_layernorm(attention_output)
        # Residual connection.
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)
        if self.cogview_sandwich_layernorm:
            mlp_output = self.before_second_addition_layernorm(mlp_output)
        # Second residual connection.
        output = layernorm_input + mlp_output
        return output, has_cache
class DalleSelfAttention(torch.nn.Module):
    """
    Self-attention layer takes input with size [b, s, h] where b is
    the batch size, s is the sequence length, and h is the hidden size
    and creates output of the same size.

    Arguments:
        hidden_size: total hidden size of the layer (h).
        num_attention_heads: number of attention heads (n). Note that we
            require n to be divisible by number of GPUs used to parallelize
            the model. Also, we require hidden size to be divisible by n.
        attention_dropout_prob: dropout probability for the attention scores.
        output_dropout_prob: dropout probability for the output.

    We use the following notation:
        h: hidden_size
        n: num_attention_heads
        p: number of partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
    """

    def __init__(self, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False):
        super(DalleSelfAttention, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        # Fused projection producing query, key and value in one matmul.
        self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size)
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
        # Output.
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
        new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
        """Masked, scaled dot-product attention scores of shape [b, np, s, s]."""
        key_t = key_layer.transpose(-1, -2)
        if self.cogview_pb_relax:
            # PB-relax: divide the query before the matmul so the product
            # stays in a safe numeric range.
            attention_scores = torch.matmul(
                query_layer / math.sqrt(self.hidden_size_per_attention_head),
                key_t
            )
        else:
            attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
        # Masked positions get a large negative score so softmax zeroes them.
        attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
        if self.cogview_pb_relax:
            # normalize attention scores. Should not affect resulting softmax value
            alpha = 32
            attention_scores_scaled = attention_scores / alpha
            attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view(
                [attention_scores.size(0), attention_scores.size(1), -1]
            ).max(dim=-1)  # max per head per sample
            attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand(
                [-1, -1, attention_scores.size(2), attention_scores.size(3)]
            )  # expand to [b, np, s, s]
            attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
        return attention_scores

    def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Attention heads. [b, s, hp]
        if has_cache and use_cache:
            # Incremental decoding: only the newest token needs fresh q/k/v.
            mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :])
        else:
            mixed_x_layer = self.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        if use_cache and has_cache:
            # Prepend cached q/k/v from earlier steps along the sequence dim.
            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
            query_layer = torch.cat((self.past_query, query_layer), dim=-2)
            key_layer = torch.cat((self.past_key, key_layer), dim=-2)
            attention_scores = self._calculate_attention_scores(
                query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
            )
        else:
            attention_scores = self._calculate_attention_scores(
                query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
            )
        if use_cache:
            # Stash q/k/v on the module for the next decoding step.
            self.past_query = query_layer
            self.past_key = key_layer
            self.past_value = value_layer
            has_cache = True
        else:
            has_cache = False
        # Attention probabilities. [b, np, s, s]
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # Context layer.
        # [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)
        return output, has_cache
class DalleMLP(torch.nn.Module):
    """Position-wise feed-forward block: h -> 4h -> gelu -> h -> dropout.

    Arguments:
        hidden_size: model hidden size (h).
        output_dropout_prob: dropout probability applied to the output.
    """

    def __init__(self, hidden_size, output_dropout_prob):
        super(DalleMLP, self).__init__()
        # Expansion to the 4x wider intermediate dimension.
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size)
        # Projection back down to the model dimension.
        self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        # [b, s, h] -> [b, s, 4h] -> gelu -> [b, s, h] -> dropout
        intermediate = gelu(self.dense_h_to_4h(hidden_states))
        return self.dropout(self.dense_4h_to_h(intermediate))

+ 54
- 0
rudalle/dalle/utils.py View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
import torch
def exists(val):
    """Return True when `val` is not None (any other value counts, even falsy ones)."""
    return not (val is None)
def is_empty(t):
    """Return True when tensor `t` holds zero elements."""
    return t.numel() == 0
def ensure_divisibility(numerator, denominator):
    """Assert that `numerator` is a whole multiple of `denominator`."""
    message = '{} is not divisible by {}'.format(numerator, denominator)
    assert numerator % denominator == 0, message
def divide(numerator, denominator):
    """Integer-divide `numerator` by `denominator`, asserting exact divisibility."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)
    return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split `tensor` into `num_partitions` equal chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of equal pieces; must evenly divide the last dim.
        contiguous_split_chunks: if True, force each chunk to be contiguous.
    """
    dim = tensor.dim() - 1
    dim_size = tensor.size()[dim]
    assert dim_size % num_partitions == 0, '{} is not divisible by {}'.format(
        dim_size, num_partitions)
    chunks = torch.split(tensor, dim_size // num_partitions, dim=dim)
    # torch.split returns views sharing storage with `tensor`; copy if asked.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks
def init_method_normal(std=0.02):
    """Return an in-place initializer drawing from N(0, std).

    This is only used for embeddings. The transformer has its
    own initializer.
    """
    def initialize(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return initialize

+ 61
- 0
rudalle/pipelines.py View File

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from . import utils
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
                    use_cache=True):
    """Sample `images_num` images for a text prompt with the DALL-E + VAE pipeline.

    Args:
        text: prompt; lower-cased and stripped before tokenization.
        tokenizer: YTTM tokenizer wrapper (see rudalle.tokenizer).
        dalle: DalleModel (possibly wrapped in FP16Module).
        vae: VQGAN used to decode image codebook tokens into pixels.
        top_k, top_p: top-k / nucleus filtering parameters for sampling.
        images_num: total number of images to sample.
        temperature: logit temperature.
        bs: batch size per sampling chunk.
        seed: optional seed for reproducible sampling.
        use_cache: reuse the attention key/value cache between steps.

    Returns:
        (pil_images, scores): the generated PIL images and, per image, the sum
        of the sampled tokens' probabilities (used as a crude ranking score).
    """
    if seed is not None:
        utils.seed_everything(seed)
    vocab_size = dalle.get_param('vocab_size')
    text_seq_length = dalle.get_param('text_seq_length')
    image_seq_length = dalle.get_param('image_seq_length')
    total_seq_length = dalle.get_param('total_seq_length')
    device = dalle.get_param('device')
    text = text.lower().strip()
    input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length)
    pil_images, scores = [], []
    for chunk in more_itertools.chunked(range(images_num), bs):
        chunk_bs = len(chunk)
        with torch.no_grad():
            attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device))
            # Repeat the encoded prompt for the whole chunk.
            out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
            has_cache = False
            sample_scores = []
            # Autoregressively sample one image token per step.
            for _ in tqdm(range(out.shape[1], total_seq_length)):
                logits, has_cache = dalle(out, attention_mask,
                                          has_cache=has_cache, use_cache=use_cache, return_loss=False)
                # Keep only the image part of the joint vocab for the next token.
                logits = logits[:, -1, vocab_size:]
                logits /= temperature
                filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                sample = torch.multinomial(probs, 1)
                # Record the probability of each sampled token as its score.
                sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
                out = torch.cat((out, sample), dim=-1)
            # Decode the sampled codebook ids back into pixel images.
            codebooks = out[:, -image_seq_length:]
            images = vae.decode(codebooks)
            pil_images += utils.torch_tensors_to_pil_list(images)
            scores += torch.cat(sample_scores).sum(0).detach().cpu().numpy().tolist()
    return pil_images, scores
def show(pil_images, nrow=4):
    """Display a list of PIL images as a matplotlib grid with `nrow` per row."""
    imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    # make_grid returns a single tensor; normalize to a list for the loop below.
    if not isinstance(imgs, list):
        imgs = [imgs.cpu()]
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = torchvision.transforms.functional.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        # Hide ticks and labels — only the image grid should be visible.
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

+ 64
- 0
rudalle/tokenizer.py View File

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
from os.path import join
import torch
import numpy as np
import youtokentome as yttm
from huggingface_hub import hf_hub_url, cached_download
def get_tokenizer(path=None, cache_dir='/tmp/rudalle'):
    """Load the ruDALL-E BPE tokenizer.

    When `path` is None, the bpe.model file is downloaded from HuggingFace Hub
    and cached under ``cache_dir/tokenizer``; otherwise `path` is used directly.
    """
    if path is None:
        repo_id = 'shonenkov/rudalle-utils'
        filename = 'bpe.model'
        cache_dir = join(cache_dir, 'tokenizer')
        config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
        path = join(cache_dir, filename)
    tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path))
    print('tokenizer --> ready')
    return tokenizer
class YTTMTokenizerWrapper:
    """Thin wrapper around a youtokentome BPE model with DALL-E conventions."""

    # Special-token ids (fixed by the trained BPE model).
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __len__(self):
        return self.vocab_size()

    def get_pad_token_id(self):
        """Id of the '<PAD>' subword in the underlying BPE vocabulary."""
        return self.tokenizer.subword_to_id('<PAD>')

    def vocab_size(self):
        """Size of the BPE vocabulary."""
        return self.tokenizer.vocab_size()

    def encode_text(self, text, text_seq_length, bpe_dropout=0.0):
        """Encode `text` to a fixed-length LongTensor, wrapped in BOS/EOS."""
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens, text_seq_length)

    def decode_text(self, encoded):
        """Decode ids back to a string, dropping all special tokens."""
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    @staticmethod
    def prepare_tokens(tokens, text_seq_length):
        """Pad with zeros (after the text) or truncate to `text_seq_length`.

        NOTE(review): truncation keeps the first `text_seq_length` ids, which
        drops the trailing EOS token for long texts — confirm this is intended.
        """
        empty_positions = text_seq_length - len(tokens)
        if empty_positions > 0:
            tokens = np.hstack((tokens, np.zeros(empty_positions)))  # position tokens after text
        if len(tokens) > text_seq_length:
            tokens = tokens[:text_seq_length]
        return torch.tensor(tokens).long()

+ 36
- 0
rudalle/utils.py View File

@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
import os
import random
import torch
import torchvision
import numpy as np
def seed_everything(seed):
    """Seed every RNG used by the project for reproducible runs.

    Seeds Python's ``random``, ``PYTHONHASHSEED``, NumPy and PyTorch
    (CPU and CUDA) and switches cuDNN into its deterministic mode.

    :param int seed: seed value applied to all generators
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune convolution algorithms, which can
    # select non-deterministic kernels and defeat the seeding above, so it
    # must be disabled for reproducibility.
    torch.backends.cudnn.benchmark = False
def torch_tensors_to_pil_list(input_images):
    """Convert an iterable of image tensors into a list of RGB PIL images."""
    return [
        torchvision.transforms.functional.to_pil_image(image.cpu().detach()).convert('RGB')
        for image in input_images
    ]
def pil_list_to_torch_tensors(pil_images):
    """Stack a list of PIL images into one uint8 tensor of shape (N, C, H, W)."""
    tensors = []
    for pil_image in pil_images:
        arr = np.array(pil_image, dtype=np.uint8)
        # HWC -> CHW, then add the leading batch dimension.
        tensors.append(torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0))
    return torch.cat(tensors, dim=0)

+ 24
- 0
rudalle/vae/__init__.py View File

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
from os.path import dirname, abspath, join
import torch
from huggingface_hub import hf_hub_url, cached_download
from omegaconf import OmegaConf
from .model import VQGanGumbelVAE
def get_vae(pretrained=True, cache_dir='/tmp/rudalle'):
    """Construct the Gumbel VQGAN VAE, optionally loading pretrained weights.

    :param pretrained: when True, download and load the released checkpoint
    :param cache_dir: base directory used to cache the downloaded checkpoint
    :return: a :class:`VQGanGumbelVAE` instance
    """
    config_path = join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')
    vae = VQGanGumbelVAE(OmegaConf.load(config_path))
    if pretrained:
        repo_id, filename = 'shonenkov/rudalle-utils', 'vqgan.gumbelf8-sber.model.ckpt'
        cache_dir = join(cache_dir, 'vae')
        url = hf_hub_url(repo_id=repo_id, filename=filename)
        cached_download(url, cache_dir=cache_dir, force_filename=filename)
        checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
        # NOTE(review): strict=False silently skips mismatched keys — confirm intended.
        vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
    print('vae --> ready')
    return vae

+ 100
- 0
rudalle/vae/model.py View File

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
from math import sqrt, log
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
from taming.modules.diffusionmodules.model import Encoder, Decoder
class VQGanGumbelVAE(torch.nn.Module):
    """Wrapper around :class:`GumbelVQ` exposing DALL-E style encode/decode."""

    def __init__(self, config):
        super().__init__()
        params = config.model.params
        self.model = GumbelVQ(
            ddconfig=params.ddconfig,
            n_embed=params.n_embed,
            embed_dim=params.embed_dim,
            kl_weight=params.kl_weight,
        )
        # log2 of the attention resolution (presumably the number of
        # downsampling levels — TODO confirm against the VQGAN config).
        self.num_layers = int(log(params.ddconfig.attn_resolutions[0]) / log(2))
        self.image_size = 256
        self.num_tokens = params.n_embed

    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Map images in [0, 1] to flattened codebook indices."""
        scaled = (2 * img) - 1  # VQGAN expects inputs in [-1, 1]
        _, _, [_, _, indices] = self.model.encode(scaled)
        return rearrange(indices, 'b h w -> b (h w)')

    def decode(self, img_seq):
        """Decode a flat sequence of codebook indices back to images in [0, 1]."""
        _, seq_len = img_seq.shape
        one_hot = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).float()
        z = one_hot @ self.model.quantize.embed.weight
        z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(seq_len)))
        decoded = self.model.decode(z)
        return (decoded.clamp(-1., 1.) + 1) * 0.5
class GumbelQuantize(nn.Module):
    """Gumbel-softmax codebook quantizer.

    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """

    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # 1x1 conv producing per-position logits over the codebook.
        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(self.n_embed, self.embedding_dim)
        self.use_vqinterface = use_vqinterface

    def forward(self, z, temp=None, return_logits=False):
        """Quantize the feature map ``z`` via Gumbel-softmax over the codebook."""
        # During evaluation always take hard (one-hot) samples.
        hard = self.straight_through if self.training else True
        if temp is None:
            temp = self.temperature
        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
        # KL divergence of the code distribution against a uniform prior.
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
        ind = soft_one_hot.argmax(dim=1)
        if not self.use_vqinterface:
            return z_q, diff, ind
        if return_logits:
            return z_q, diff, (None, None, ind), logits
        return z_q, diff, (None, None, ind)
class GumbelVQ(nn.Module):
    """VQGAN autoencoder with a Gumbel-softmax quantizer between the halves."""

    def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8):
        super().__init__()
        z_channels = ddconfig['z_channels']
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0)
        # 1x1 convs adapting between the encoder latent and the embedding dim.
        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
        """Encode images to (quantized latents, embedding loss, info tuple)."""
        quant, emb_loss, info = self.quantize(self.quant_conv(self.encoder(x)))
        return quant, emb_loss, info

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        return self.decoder(self.post_quant_conv(quant))

+ 34
- 0
rudalle/vae/vqgan.gumbelf8-sber.config.yml View File

@ -0,0 +1,34 @@
model:
base_learning_rate: 4.5e-06
target: taming.models.vqgan.GumbelVQ
params:
kl_weight: 1.0e-08
embed_dim: 256
n_embed: 8192
monitor: val/rec_loss
temperature_scheduler_config:
target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
params:
warm_up_steps: 0
max_decay_steps: 1000001
lr_start: 0.9
lr_max: 0.9
lr_min: 1.0e-06
ddconfig:
double_z: false
z_channels: 256
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 1
- 2
- 4
num_res_blocks: 2
attn_resolutions:
- 32
dropout: 0.0
lossconfig:
target: taming.modules.losses.vqperceptual.DummyLoss

+ 13
- 0
setup.cfg View File

@ -0,0 +1,13 @@
[pep8]
max-line-length = 120
exclude = .tox,*migrations*,.json
[flake8]
max-line-length = 120
exclude = .tox,*migrations*,.json
[autopep8-wrapper]
exclude = .tox,*migrations*,.json
[check-docstring-first]
exclude = .tox,*migrations*,.json

+ 0
- 0
tests/__init__.py View File


+ 39
- 0
tests/conftest.py View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
import io
from os.path import abspath, dirname
import PIL
import pytest
import requests
from rudalle import get_tokenizer, get_rudalle_model, get_vae
TEST_ROOT = dirname(abspath(__file__))
@pytest.fixture(scope='module')
def vae():
    """Module-scoped randomly-initialised VAE (no pretrained weights)."""
    yield get_vae(pretrained=False)
@pytest.fixture(scope='module')
def yttm_tokenizer():
    """Module-scoped default BPE tokenizer."""
    yield get_tokenizer()
@pytest.fixture(scope='module')
def sample_image():
    """Download a fixed test image once per test module."""
    url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png'
    resp = requests.get(url)
    resp.raise_for_status()
    yield PIL.Image.open(io.BytesIO(resp.content))
@pytest.fixture(scope='module')
def small_dalle():
    """Module-scoped small ruDALL-E model on CPU, randomly initialised."""
    return get_rudalle_model('small', pretrained=False, fp16=False, device='cpu')

+ 31
- 0
tests/test_dalle.py View File

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
import torch
import pytest
from .test_vae import preprocess
@pytest.mark.parametrize('text', [
    'мальчик играет с оленем',
])
def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle):
    """Forward a (text, image) batch through the small model and check the loss types."""
    bs = 4
    text_seq_length = small_dalle.get_param('text_seq_length')
    total_seq_length = small_dalle.get_param('total_seq_length')
    device = small_dalle.get_param('device')
    img = sample_image.copy()
    img = preprocess(img, target_image_size=256)
    images = img.repeat(bs, 1, 1, 1).to(device)
    text = text.lower().strip()
    text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length)
    text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device)
    # Causal (lower-triangular) attention mask over the full text+image sequence.
    attention_mask = torch.tril(torch.ones((bs, 1, total_seq_length, total_seq_length), device=device))
    with torch.no_grad():
        image_input_ids = vae.get_codebook_indices(images)
        input_ids = torch.cat((text_input_ids, image_input_ids), dim=1)
        loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True)
    # isinstance instead of `type(...) ==` (handles subclasses); `.item()`
    # replaces the deprecated `.data.detach().item()` chain.
    assert isinstance(loss.item(), float)
    assert isinstance(loss_values, dict)

+ 17
- 0
tests/test_tokenizer.py View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
import pytest
@pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [
    ('hello, how are you?', 128, 0.1),
    ('hello, how are you?', 128, 0.5),
    ('hello, how are you?', 128, 1.0),
    ('hello ... how are you ?', 256, 1.0),
    ('a person standing at a table with bottles of win', 64, 0.5),
    ('привет как дела???', 76, 0.0),
    ('клип на русском языке :)', 76, 0.1),
])
def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout):
    """Encoding followed by decoding must reproduce the original text."""
    encoded = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout)
    assert yttm_tokenizer.decode_text(encoded) == text

+ 49
- 0
tests/test_vae.py View File

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import PIL
import pytest
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_decode_vae(vae, sample_image, target_image_size):
    """Encoding to codebook indices and decoding back must preserve shape."""
    img = preprocess(sample_image.copy(), target_image_size=target_image_size)
    with torch.no_grad():
        reconstructed = vae.decode(vae.get_codebook_indices(img))
    assert reconstructed.shape == (1, 3, target_image_size, target_image_size)
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_reconstruct_vae(vae, sample_image, target_image_size):
    """A full round-trip through the raw VQGAN must preserve image shape."""
    with torch.no_grad():
        x_vqgan = preprocess(sample_image.copy(), target_image_size=target_image_size)
        output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model)
    assert output.shape == (1, 3, target_image_size, target_image_size)
def preprocess(img, target_image_size=256):
    """Resize (short side), center-crop and tensorize a PIL image.

    :param img: input PIL image; its shorter side must be >= target_image_size
    :param target_image_size: side length of the square output
    :return: float tensor of shape (1, C, target_image_size, target_image_size)
    :raises ValueError: if the image is smaller than the target size
    """
    s = min(img.size)
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')
    scale = target_image_size / s
    # PIL size is (w, h); TF.resize expects (h, w), hence the swap.
    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=[target_image_size, target_image_size])
    return T.ToTensor()(img).unsqueeze(0)
def preprocess_vqgan(x):
    """Map values from [0, 1] to the [-1, 1] range VQGAN expects."""
    return 2. * x - 1.
def reconstruct_with_vqgan(x, model):
    """Encode then decode ``x`` with the given VQGAN, logging the latent shape."""
    z = model.encode(x)[0]
    print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}')
    return model.decode(z)

Loading…
Cancel
Save