From e96631a867fcadcfaa52eecb20b1e42b88aa4386 Mon Sep 17 00:00:00 2001 From: Alex Shonenkov Date: Mon, 1 Nov 2021 20:07:11 +0300 Subject: [PATCH] initial commit --- .gitignore | 168 +++++++++++ .pre-commit-config.yaml | 31 ++ requirements-test.txt | 4 + requirements.txt | 8 + rudalle/__init__.py | 15 + rudalle/dalle/__init__.py | 76 +++++ rudalle/dalle/fp16.py | 60 ++++ rudalle/dalle/image_attention.py | 48 +++ rudalle/dalle/model.py | 171 +++++++++++ rudalle/dalle/transformer.py | 332 +++++++++++++++++++++ rudalle/dalle/utils.py | 54 ++++ rudalle/pipelines.py | 61 ++++ rudalle/tokenizer.py | 64 ++++ rudalle/utils.py | 36 +++ rudalle/vae/__init__.py | 24 ++ rudalle/vae/model.py | 100 +++++++ rudalle/vae/vqgan.gumbelf8-sber.config.yml | 34 +++ setup.cfg | 13 + tests/__init__.py | 0 tests/conftest.py | 39 +++ tests/test_dalle.py | 31 ++ tests/test_tokenizer.py | 17 ++ tests/test_vae.py | 49 +++ 23 files changed, 1435 insertions(+) create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 requirements-test.txt create mode 100644 requirements.txt create mode 100644 rudalle/__init__.py create mode 100644 rudalle/dalle/__init__.py create mode 100755 rudalle/dalle/fp16.py create mode 100644 rudalle/dalle/image_attention.py create mode 100644 rudalle/dalle/model.py create mode 100755 rudalle/dalle/transformer.py create mode 100644 rudalle/dalle/utils.py create mode 100644 rudalle/pipelines.py create mode 100644 rudalle/tokenizer.py create mode 100644 rudalle/utils.py create mode 100644 rudalle/vae/__init__.py create mode 100644 rudalle/vae/model.py create mode 100644 rudalle/vae/vqgan.gumbelf8-sber.config.yml create mode 100644 setup.cfg create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_dalle.py create mode 100644 tests/test_tokenizer.py create mode 100644 tests/test_vae.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4317731 --- /dev/null +++ b/.gitignore @@ -0,0 +1,168 @@ +# Created by .ignore support plugin (hsz.mobi) +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +settings/local.py +logs/*.log + +# User-specific stuff: +.idea/ + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.xml +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-debug/ + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other 
infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +/tests/load_tests/logs/* +/tests/.pytest_cache/ +ws_test.py +/.vscode/ + +.s3_cache/ +mlruns +*.pyc +*.swp +*.pt +*.bin +.vscode/ +runs/ +jupyters/custom_* + +*logs/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..d63bec2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.2.3 + hooks: + - id: check-docstring-first + stages: + - commit + - push + - id: check-merge-conflict + stages: + - push + - id: double-quote-string-fixer + stages: + - commit + - push + - id: fix-encoding-pragma + stages: + - commit + - push + - id: flake8 + args: ['--config=setup.cfg'] + stages: + - commit + - push +- repo: https://github.com/pre-commit/mirrors-autopep8 + rev: v1.4.4 + hooks: + - id: autopep8 + stages: + - commit + - push diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..13a0fd9 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +-r requirements.txt +pytest +pytest-cov +pre-commit diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..321b95d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +taming-transformers==0.0.1 +more_itertools==8.10.0 +transformers==4.10.2 +youtokentome==1.0.6 +einops==0.3.2 +torch +torchvision +matplotlib \ No newline at end of file diff --git a/rudalle/__init__.py b/rudalle/__init__.py new file mode 100644 index 0000000..430f18b --- /dev/null +++ b/rudalle/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +from .vae import get_vae +from .dalle import get_rudalle_model +from .tokenizer import get_tokenizer +from . 
import vae, dalle, tokenizer + + +__all__ = [ + 'get_vae', + 'get_rudalle_model', + 'get_tokenizer', + 'vae', + 'dalle', + 'tokenizer', +] diff --git a/rudalle/dalle/__init__.py b/rudalle/dalle/__init__.py new file mode 100644 index 0000000..26c7b68 --- /dev/null +++ b/rudalle/dalle/__init__.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +import os + +import torch +from huggingface_hub import hf_hub_url, cached_download + +from .model import DalleModel +from .fp16 import FP16Module + + +MODELS = { + 'Malevich': dict( + description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, ' + 'that uses Russian language and text+image multi-modality.', + model_params=dict( + num_layers=24, + hidden_size=2048, + num_attention_heads=16, + embedding_dropout_prob=0.1, + output_dropout_prob=0.1, + attention_dropout_prob=0.1, + image_tokens_per_dim=32, + text_seq_length=128, + use_masks=True, + cogview_sandwich_layernorm=True, + cogview_pb_relax=True, + vocab_size=16384+128, + image_vocab_size=8192, + ), + repo_id='sberbank-ai/rudalle-Malevich', + filename='pytorch_model.bin', + full_description='', # TODO + ), + 'small': dict( + description='', + model_params=dict( + num_layers=12, + hidden_size=768, + num_attention_heads=12, + embedding_dropout_prob=0.1, + output_dropout_prob=0.1, + attention_dropout_prob=0.1, + image_tokens_per_dim=32, + text_seq_length=128, + use_masks=True, + cogview_sandwich_layernorm=True, + cogview_pb_relax=True, + vocab_size=16384+128, + image_vocab_size=8192, + ), + repo_id='', + filename='', + full_description='', # TODO + ), +} + + +def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle'): + # TODO docstring + assert name in MODELS + + config = MODELS[name] + model = DalleModel(device=device, fp16=fp16, **config['model_params']) + if pretrained: + cache_dir = os.path.join(cache_dir, name) + config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) + cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) + checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu') + model.load_state_dict(checkpoint) + if fp16: + model = FP16Module(model) + model.eval() + model = model.to(device) + if config['description'] and pretrained: + print(config['description']) + return model diff --git a/rudalle/dalle/fp16.py b/rudalle/dalle/fp16.py new file mode 100755 index 0000000..a16fa8d --- /dev/null +++ b/rudalle/dalle/fp16.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + + +def conversion_helper(val, conversion): + """Apply conversion to val. 
Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + return conversion_helper(val, half_conversion) + + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + return conversion_helper(val, float_conversion) + + +class FP16Module(nn.Module): + def __init__(self, module): + super(FP16Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + def get_param(self, item): + return self.module.get_param(item) diff --git a/rudalle/dalle/image_attention.py b/rudalle/dalle/image_attention.py new file mode 100644 index 0000000..0df1a77 --- /dev/null +++ b/rudalle/dalle/image_attention.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +import torch + + +def _init_mask(text_tokens, image_tokens_per_dim): + attn_size = text_tokens + image_tokens_per_dim**2 + mask = torch.tril(torch.ones(attn_size, attn_size)) + return mask + + +def get_row_mask(text_tokens=256, image_tokens_per_dim=32): + mask = _init_mask(text_tokens, image_tokens_per_dim) + step = image_tokens_per_dim + 1 + for col in range(text_tokens, mask.size(1)): + mask[col + step:, col] = 0.0 + return mask + + +def get_col_mask(text_tokens=256, image_tokens_per_dim=32): + mask = _init_mask(text_tokens, image_tokens_per_dim) + step = image_tokens_per_dim - 1 + for col in range(text_tokens, mask.size(1)): + for i in range(1, mask.size(0), step+1): + mask[col + i: col + i + step, col] = 0.0 + return mask + + +def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): + mask = _init_mask(text_tokens, image_tokens_per_dim) + shift = kernel // 2 + for pos in range(text_tokens, mask.size(1)): + mask[pos+1:, pos] = 0.0 + img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim) + pixel_id = pos - text_tokens + row = pixel_id // image_tokens_per_dim + col = pixel_id % image_tokens_per_dim + for r in range(-shift, shift+1): + for c in range(-shift, shift+1): + c_abs = (c + col) % image_tokens_per_dim + r_abs = (r + row) % image_tokens_per_dim + img[r_abs, c_abs] = 0.2 + cell_id = r_abs * image_tokens_per_dim + c_abs + if text_tokens + cell_id > pos: + mask[text_tokens + cell_id, pos] = 1.0 + + img[row, col] = 1.0 + return mask diff --git a/rudalle/dalle/model.py b/rudalle/dalle/model.py new file mode 100644 index 0000000..be855f6 --- /dev/null +++ b/rudalle/dalle/model.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn.functional as F +from einops import rearrange + +from .utils import exists, is_empty, init_method_normal +from .image_attention import get_conv_mask, 
get_row_mask, get_col_mask + +from .transformer import DalleTransformer + + +class DalleModel(torch.nn.Module): + def __init__(self, + device, + num_layers, + vocab_size, + hidden_size, + num_attention_heads, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + text_seq_length=128, + image_tokens_per_dim=32, + image_vocab_size=16384, + loss_img_weight=7, + fp16=False, + use_masks=True, + cogview_sandwich_layernorm=False, + cogview_pb_relax=False): + + super(DalleModel, self).__init__() + self.device = device + self.fp16 = fp16 + self.image_tokens_per_dim = image_tokens_per_dim + self.image_seq_length = image_tokens_per_dim ** 2 + self.text_seq_length = text_seq_length + self.total_seq_length = self.text_seq_length + self.image_seq_length + self.total_vocab_size = vocab_size + image_vocab_size + self.vocab_size = vocab_size + self.loss_img_weight = loss_img_weight + + # TODO "to" + mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim) + if use_masks: + self._mask_map = mask_map + else: + self._mask_map = [] + + init_method = init_method_normal(std=0.02) + + self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size) + self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size) + + # Position embedding (serial). + self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size) + self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size) + self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size) + init_method(self.text_pos_embeddings.weight) + init_method(self.image_row_embeddings.weight) + init_method(self.image_col_embeddings.weight) + + self.to_logits = torch.nn.Sequential( + torch.nn.LayerNorm(hidden_size), + torch.nn.Linear(hidden_size, self.total_vocab_size), + ) + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + # Transformer + self.transformer = DalleTransformer( + num_layers, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + cogview_sandwich_layernorm=cogview_sandwich_layernorm, + cogview_pb_relax=cogview_pb_relax, + ) + self.transformer._mask_map = self._mask_map + + def get_param(self, item): + return getattr(self, item) + + def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim): + row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device) + col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device) + conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device) + if self.fp16: + row_mask = row_mask.half() + col_mask = col_mask.half() + conv_mask = conv_mask.half() + self.register_buffer('row_mask', row_mask) + self.register_buffer('col_mask', col_mask) + self.register_buffer('conv_mask', conv_mask) + mask_map = [] + for i in range(num_layers): + if ((i - 1) % 4 == 0): + mask_map.append(col_mask) + elif i != num_layers - 1: + mask_map.append(row_mask) + else: + mask_map.append(conv_mask) + return mask_map + + def get_image_pos_embeddings(self, image_input_ids, past_length=0): + input_shape = image_input_ids.size() + row_ids = torch.arange(past_length, input_shape[-1] + past_length, + dtype=torch.long, device=self.device) // self.image_tokens_per_dim + row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1]) + col_ids = torch.arange(past_length, input_shape[-1] + past_length, + dtype=torch.long, device=self.device) % self.image_tokens_per_dim + col_ids = 
col_ids.unsqueeze(0).view(-1, input_shape[-1]) + return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids) + + def forward( + self, + input_ids, + attention_mask, + return_loss=False, + has_cache=False, + use_cache=False, + ): + text = input_ids[:, :self.text_seq_length] + text_range = torch.arange(self.text_seq_length) + text_range += (self.vocab_size - self.text_seq_length) + text_range = text_range.to(self.device) + text = torch.where(text == 0, text_range, text) + # some hardcode :) + text = F.pad(text, (1, 0), value=2) + text_embeddings = self.text_embeddings(text) + \ + self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device)) + + image_input_ids = input_ids[:, self.text_seq_length:] + + if exists(image_input_ids) and not is_empty(image_input_ids): + image_embeddings = self.image_embeddings(image_input_ids) + \ + self.get_image_pos_embeddings(image_input_ids, past_length=0) + embeddings = torch.cat((text_embeddings, image_embeddings), dim=1) + else: + embeddings = text_embeddings + # some hardcode :) + if embeddings.shape[1] > self.total_seq_length: + embeddings = embeddings[:, :-1] + + alpha = 0.1 + embeddings = embeddings * alpha + embeddings.detach() * (1-alpha) + + attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]] + transformer_output, present_has_cache = self.transformer( + embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache) + + logits = self.to_logits(transformer_output) + if return_loss is False: + return logits, present_has_cache + + labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long() + logits = rearrange(logits, 'b n c -> b c n') + + text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float() + image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float() + + loss_text = F.cross_entropy( + text_logits, + labels[:, :self.text_seq_length]) + loss_img = F.cross_entropy( + image_logits, + labels[:, self.text_seq_length:]) + + loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1) + return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()} diff --git a/rudalle/dalle/transformer.py b/rudalle/dalle/transformer.py new file mode 100755 index 0000000..5a62aa0 --- /dev/null +++ b/rudalle/dalle/transformer.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- +import math + +import torch +from torch.nn import LayerNorm + +from .utils import divide, split_tensor_along_last_dim + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + +def gelu(x): + return gelu_impl(x) + + +class DalleTransformer(torch.nn.Module): + """ + This module takes input from embedding layer and it's output can + be used directly by a logit layer. It consists of L (num-layers) + blocks of: + layer norm + self attention + residual connection + layer norm + mlp + residual connection + followed by a final layer norm. + + Arguments: + num_layers: Number of transformer layers. + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. 
+ """ + _mask_map = [] + + def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, + layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False): + super(DalleTransformer, self).__init__() + + # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf + self.cogview_pb_relax = cogview_pb_relax + + # Transformer layers. + self.layers = torch.nn.ModuleList([ + DalleTransformerLayer( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + cogview_sandwich_layernorm=cogview_sandwich_layernorm, + cogview_pb_relax=cogview_pb_relax, + ) for _ in range(num_layers) + ]) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + def forward(self, hidden_states, attention_mask, has_cache, use_cache): + for i, layer in enumerate(self.layers): + mask = attention_mask + if len(self._mask_map): + layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)] + mask = torch.mul(attention_mask, layer_mask) + hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache) + output = self.final_layernorm(hidden_states) + return output, present_has_cache + + +class DalleTransformerLayer(torch.nn.Module): + """ + A single layer transformer. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + """ + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + cogview_sandwich_layernorm=False, + cogview_pb_relax=False): + super(DalleTransformerLayer, self).__init__() + + # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf + self.cogview_sandwich_layernorm = cogview_sandwich_layernorm + self.cogview_pb_relax = cogview_pb_relax + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + if self.cogview_sandwich_layernorm: + self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + # Self attention. + self.attention = DalleSelfAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + cogview_pb_relax=cogview_pb_relax + ) + + # Layernorm on the input data. + self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + # MLP + self.mlp = DalleMLP(hidden_size, output_dropout_prob) + + def forward(self, hidden_states, ltor_mask, has_cache, use_cache): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Layer norm at the begining of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. 
+ attention_output, has_cache = self.attention( + layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache) + + if self.cogview_sandwich_layernorm: + attention_output = self.before_first_addition_layernorm(attention_output) + + # Residual connection. + layernorm_input = hidden_states + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + if self.cogview_sandwich_layernorm: + mlp_output = self.before_second_addition_layernorm(mlp_output) + + # Second residual connection. + output = layernorm_input + mlp_output + + return output, has_cache + + +class DalleSelfAttention(torch.nn.Module): + """ + Self-attention layer takes input with size [b, s, h] where b is + the batch size, s is the sequence length, and h is the hidden size + and creates output of the same size. + Arguments: + hidden_size: total hidden size of the layer (h). + num_attention_heads: number of attention heads (n). Note that we + require n to be divisible by number of GPUs + used to parallelize the model. Also, we + require hidden size to be divisible by n. + attention_dropout_prob: dropout probability for the attention scores. + output_dropout_prob: dropout probability for the output. + We use the following notation: + h: hidden_size + n: num_attention_heads + p: number of partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + """ + + def __init__(self, hidden_size, num_attention_heads, + attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False): + super(DalleSelfAttention, self).__init__() + + # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf + self.cogview_pb_relax = cogview_pb_relax + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) + + self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size) + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + # Output. + self.dense = torch.nn.Linear(hidden_size, hidden_size) + self.output_dropout = torch.nn.Dropout(output_dropout_prob) + + def _transpose_for_scores(self, tensor): + """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """ + new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask): + key_t = key_layer.transpose(-1, -2) + if self.cogview_pb_relax: + attention_scores = torch.matmul( + query_layer / math.sqrt(self.hidden_size_per_attention_head), + key_t + ) + else: + attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head) + attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask) + if self.cogview_pb_relax: + # normalize attention scores. 
Should not affect resulting softmax value + alpha = 32 + attention_scores_scaled = attention_scores / alpha + attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view( + [attention_scores.size(0), attention_scores.size(1), -1] + ).max(dim=-1) # max per head per sample + attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand( + [-1, -1, attention_scores.size(2), attention_scores.size(3)] + ) # expand to [b, np, s, s] + attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha + return attention_scores + + def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + # Attention heads. [b, s, hp] + if has_cache and use_cache: + mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :]) + else: + mixed_x_layer = self.query_key_value(hidden_states) + + (mixed_query_layer, + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + if use_cache and has_cache: + value_layer = torch.cat((self.past_value, value_layer), dim=-2) + query_layer = torch.cat((self.past_query, query_layer), dim=-2) + key_layer = torch.cat((self.past_key, key_layer), dim=-2) + attention_scores = self._calculate_attention_scores( + query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask + ) + else: + attention_scores = self._calculate_attention_scores( + query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask + ) + + if use_cache: + self.past_query = query_layer + self.past_key = key_layer + self.past_value = value_layer + has_cache = True + else: + has_cache = False + + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + return output, has_cache + + +class DalleMLP(torch.nn.Module): + """ + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform gelu transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + Arguments: + hidden_size: The hidden size of the self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + """ + + def __init__(self, hidden_size, output_dropout_prob): + super(DalleMLP, self).__init__() + # Project to 4h. + self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size) + # Project back to h. 
+ self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size) + self.dropout = torch.nn.Dropout(output_dropout_prob) + + def forward(self, hidden_states): + # [b, s, 4hp] + x = self.dense_h_to_4h(hidden_states) + x = gelu(x) + # [b, s, h] + x = self.dense_4h_to_h(x) + output = self.dropout(x) + return output diff --git a/rudalle/dalle/utils.py b/rudalle/dalle/utils.py new file mode 100644 index 0000000..5eea270 --- /dev/null +++ b/rudalle/dalle/utils.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import torch + + +def exists(val): + return val is not None + + +def is_empty(t): + return t.nelement() == 0 + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): + """ + Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + return tensor_list + + +def init_method_normal(std=0.02): + """Init method based on normal distribution. + + This is only used for embeddings. The transformer has its + own initializer. + """ + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + return init_ diff --git a/rudalle/pipelines.py b/rudalle/pipelines.py new file mode 100644 index 0000000..448a547 --- /dev/null +++ b/rudalle/pipelines.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +import torch +import torchvision +import transformers +import more_itertools +import numpy as np +import matplotlib.pyplot as plt +from tqdm import tqdm + +from . 
import utils + + +def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None, + use_cache=True): + # TODO docstring + if seed is not None: + utils.seed_everything(seed) + + vocab_size = dalle.get_param('vocab_size') + text_seq_length = dalle.get_param('text_seq_length') + image_seq_length = dalle.get_param('image_seq_length') + total_seq_length = dalle.get_param('total_seq_length') + device = dalle.get_param('device') + + text = text.lower().strip() + input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length) + pil_images, scores = [], [] + for chunk in more_itertools.chunked(range(images_num), bs): + chunk_bs = len(chunk) + with torch.no_grad(): + attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device)) + out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device) + has_cache = False + sample_scores = [] + for _ in tqdm(range(out.shape[1], total_seq_length)): + logits, has_cache = dalle(out, attention_mask, + has_cache=has_cache, use_cache=use_cache, return_loss=False) + logits = logits[:, -1, vocab_size:] + logits /= temperature + filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + probs = torch.nn.functional.softmax(filtered_logits, dim=-1) + sample = torch.multinomial(probs, 1) + sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)]) + out = torch.cat((out, sample), dim=-1) + codebooks = out[:, -image_seq_length:] + images = vae.decode(codebooks) + pil_images += utils.torch_tensors_to_pil_list(images) + scores += torch.cat(sample_scores).sum(0).detach().cpu().numpy().tolist() + return pil_images, scores + + +def show(pil_images, nrow=4): + imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) + if not isinstance(imgs, list): + imgs = [imgs.cpu()] + fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14)) + for i, img in enumerate(imgs): + img = img.detach() + img = torchvision.transforms.functional.to_pil_image(img) + axs[0, i].imshow(np.asarray(img)) + axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) diff --git a/rudalle/tokenizer.py b/rudalle/tokenizer.py new file mode 100644 index 0000000..2ca1e7a --- /dev/null +++ b/rudalle/tokenizer.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +from os.path import join + +import torch +import numpy as np +import youtokentome as yttm +from huggingface_hub import hf_hub_url, cached_download + + +def get_tokenizer(path=None, cache_dir='/tmp/rudalle'): + # TODO docstring + if path is None: + repo_id = 'shonenkov/rudalle-utils' + filename = 'bpe.model' + cache_dir = join(cache_dir, 'tokenizer') + config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) + cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) + path = join(cache_dir, filename) + tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path)) + print('tokenizer --> ready') + return tokenizer + + +class YTTMTokenizerWrapper: + eos_id = 3 + bos_id = 2 + unk_id = 1 + pad_id = 0 + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __len__(self): + return self.vocab_size() + + def get_pad_token_id(self): + # TODO docstring + return self.tokenizer.subword_to_id('') + + def vocab_size(self): + # TODO docstring + return self.tokenizer.vocab_size() + + def encode_text(self, text, text_seq_length, bpe_dropout=0.0): + # TODO docstring + tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, 
dropout_prob=bpe_dropout)[0] + tokens = [self.bos_id] + tokens + [self.eos_id] + return self.prepare_tokens(tokens, text_seq_length) + + def decode_text(self, encoded): + # TODO docstring + return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[ + self.eos_id, self.bos_id, self.unk_id, self.pad_id + ])[0] + + @staticmethod + def prepare_tokens(tokens, text_seq_length): + # TODO docstring + empty_positions = text_seq_length - len(tokens) + if empty_positions > 0: + tokens = np.hstack((tokens, np.zeros(empty_positions))) # position tokens after text + if len(tokens) > text_seq_length: + tokens = tokens[:text_seq_length] + return torch.tensor(tokens).long() diff --git a/rudalle/utils.py b/rudalle/utils.py new file mode 100644 index 0000000..3ef7a54 --- /dev/null +++ b/rudalle/utils.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +import os +import random + +import torch +import torchvision +import numpy as np + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +def torch_tensors_to_pil_list(input_images): + out_images = [] + for in_image in input_images: + in_image = in_image.cpu().detach() + out_image = torchvision.transforms.functional.to_pil_image(in_image).convert('RGB') + out_images.append(out_image) + return out_images + + +def pil_list_to_torch_tensors(pil_images): + result = [] + for pil_image in pil_images: + image = np.array(pil_image, dtype=np.uint8) + image = torch.from_numpy(image) + image = image.permute(2, 0, 1).unsqueeze(0) + result.append(image) + return torch.cat(result, dim=0) diff --git a/rudalle/vae/__init__.py b/rudalle/vae/__init__.py new file mode 100644 index 0000000..fdda089 --- /dev/null +++ b/rudalle/vae/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +from os.path import dirname, abspath, join + +import torch +from huggingface_hub import hf_hub_url, cached_download +from omegaconf import OmegaConf + +from .model import VQGanGumbelVAE + + +def get_vae(pretrained=True, cache_dir='/tmp/rudalle'): + # TODO + config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')) + vae = VQGanGumbelVAE(config) + if pretrained: + repo_id = 'shonenkov/rudalle-utils' + filename = 'vqgan.gumbelf8-sber.model.ckpt' + cache_dir = join(cache_dir, 'vae') + config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) + cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) + checkpoint = torch.load(join(cache_dir, filename), map_location='cpu') + vae.model.load_state_dict(checkpoint['state_dict'], strict=False) + print('vae --> ready') + return vae diff --git a/rudalle/vae/model.py b/rudalle/vae/model.py new file mode 100644 index 0000000..2c850b4 --- /dev/null +++ b/rudalle/vae/model.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +from math import sqrt, log + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from einops import rearrange +from taming.modules.diffusionmodules.model import Encoder, Decoder + + +class VQGanGumbelVAE(torch.nn.Module): + + def __init__(self, config): + super().__init__() + model = GumbelVQ( + ddconfig=config.model.params.ddconfig, + n_embed=config.model.params.n_embed, + embed_dim=config.model.params.embed_dim, + kl_weight=config.model.params.kl_weight, + ) + self.model = model + self.num_layers = 
int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2)) + self.image_size = 256 + self.num_tokens = config.model.params.n_embed + + @torch.no_grad() + def get_codebook_indices(self, img): + img = (2 * img) - 1 + _, _, [_, _, indices] = self.model.encode(img) + return rearrange(indices, 'b h w -> b (h w)') + + def decode(self, img_seq): + b, n = img_seq.shape + one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).float() + z = (one_hot_indices @ self.model.quantize.embed.weight) + z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(n))) + img = self.model.decode(z) + img = (img.clamp(-1., 1.) + 1) * 0.5 + return img + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True): + super().__init__() + self.embedding_dim = embedding_dim + self.n_embed = n_embed + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(self.n_embed, self.embedding_dim) + self.use_vqinterface = use_vqinterface + + def forward(self, z, temp=None, return_logits=False): + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + logits = self.proj(z) + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + ind = soft_one_hot.argmax(dim=1) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + +class GumbelVQ(nn.Module): + + def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8): + super().__init__() + z_channels = ddconfig['z_channels'] + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0) + self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec diff --git a/rudalle/vae/vqgan.gumbelf8-sber.config.yml b/rudalle/vae/vqgan.gumbelf8-sber.config.yml new file mode 100644 index 0000000..cc03c6e --- /dev/null +++ b/rudalle/vae/vqgan.gumbelf8-sber.config.yml @@ -0,0 +1,34 @@ +model: + base_learning_rate: 4.5e-06 + target: taming.models.vqgan.GumbelVQ + params: + kl_weight: 1.0e-08 + embed_dim: 256 + n_embed: 8192 + monitor: val/rec_loss + temperature_scheduler_config: + target: taming.lr_scheduler.LambdaWarmUpCosineScheduler + params: + warm_up_steps: 0 + max_decay_steps: 1000001 + lr_start: 0.9 + lr_max: 0.9 + lr_min: 1.0e-06 + ddconfig: + double_z: false + z_channels: 256 + resolution: 256 + in_channels: 3 + out_ch: 
3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.DummyLoss diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..6eae07a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,13 @@ +[pep8] +max-line-length = 120 +exclude = .tox,*migrations*,.json + +[flake8] +max-line-length = 120 +exclude = .tox,*migrations*,.json + +[autopep8-wrapper] +exclude = .tox,*migrations*,.json + +[check-docstring-first] +exclude = .tox,*migrations*,.json diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0271408 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +import io +from os.path import abspath, dirname + +import PIL +import pytest +import requests + +from rudalle import get_tokenizer, get_rudalle_model, get_vae + + +TEST_ROOT = dirname(abspath(__file__)) + + +@pytest.fixture(scope='module') +def vae(): + vae = get_vae(pretrained=False) + yield vae + + +@pytest.fixture(scope='module') +def yttm_tokenizer(): + tokenizer = get_tokenizer() + yield tokenizer + + +@pytest.fixture(scope='module') +def sample_image(): + url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png' + resp = requests.get(url) + resp.raise_for_status() + image = PIL.Image.open(io.BytesIO(resp.content)) + yield image + + +@pytest.fixture(scope='module') +def small_dalle(): + model = get_rudalle_model('small', pretrained=False, fp16=False, device='cpu') + return model diff --git a/tests/test_dalle.py b/tests/test_dalle.py new file mode 100644 index 0000000..e3a54bf --- /dev/null +++ b/tests/test_dalle.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +import torch +import pytest + +from .test_vae import preprocess + + +@pytest.mark.parametrize('text', [ + 'мальчик играет с оленем', +]) +def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle): + bs = 4 + text_seq_length = small_dalle.get_param('text_seq_length') + total_seq_length = small_dalle.get_param('total_seq_length') + device = small_dalle.get_param('device') + + img = sample_image.copy() + img = preprocess(img, target_image_size=256) + images = img.repeat(bs, 1, 1, 1).to(device) + + text = text.lower().strip() + text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length) + text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device) + + attention_mask = torch.tril(torch.ones((bs, 1, total_seq_length, total_seq_length), device=device)) + with torch.no_grad(): + image_input_ids = vae.get_codebook_indices(images) + input_ids = torch.cat((text_input_ids, image_input_ids), dim=1) + loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True) + assert type(loss.data.detach().item()) == float + assert type(loss_values) == dict diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..fcd1509 --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +import pytest + + +@pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [ + ('hello, how are you?', 128, 0.1), + ('hello, how are you?', 128, 0.5), + ('hello, how are you?', 128, 1.0), + ('hello ... 
how are you ?', 256, 1.0), + ('a person standing at a table with bottles of win', 64, 0.5), + ('привет как дела???', 76, 0.0), + ('клип на русском языке :)', 76, 0.1), +]) +def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout): + tokens = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout) + decoded_text = yttm_tokenizer.decode_text(tokens) + assert text == decoded_text diff --git a/tests/test_vae.py b/tests/test_vae.py new file mode 100644 index 0000000..0955573 --- /dev/null +++ b/tests/test_vae.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import PIL +import pytest +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as TF + + +@pytest.mark.parametrize('target_image_size', [128, 192, 256]) +def test_decode_vae(vae, sample_image, target_image_size): + img = sample_image.copy() + img = preprocess(img, target_image_size=target_image_size) + with torch.no_grad(): + img_seq = vae.get_codebook_indices(img) + out_img = vae.decode(img_seq) + assert out_img.shape == (1, 3, target_image_size, target_image_size) + + +@pytest.mark.parametrize('target_image_size', [128, 192, 256]) +def test_reconstruct_vae(vae, sample_image, target_image_size): + img = sample_image.copy() + with torch.no_grad(): + x_vqgan = preprocess(img, target_image_size=target_image_size) + output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model) + assert output.shape == (1, 3, target_image_size, target_image_size) + + +def preprocess(img, target_image_size=256): + s = min(img.size) + if s < target_image_size: + raise ValueError(f'min dim for image {s} < {target_image_size}') + r = target_image_size / s + s = (round(r * img.size[1]), round(r * img.size[0])) + img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS) + img = TF.center_crop(img, output_size=2 * [target_image_size]) + img = torch.unsqueeze(T.ToTensor()(img), 0) + return img + + +def preprocess_vqgan(x): + x = 2.*x - 1. + return x + + +def reconstruct_with_vqgan(x, model): + z, _, [_, _, _] = model.encode(x) + print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}') + xrec = model.decode(z) + return xrec
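
Usage sketch: a minimal end-to-end example of how the pieces introduced in this commit fit together. The sampling values (top_k, top_p, images_num, seed) and device='cuda' are illustrative assumptions, not defaults fixed anywhere in this commit; the pretrained 'Malevich' weights, the BPE model and the VQGAN checkpoint are fetched from the Hugging Face Hub on first use.

# -*- coding: utf-8 -*-
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images, show

device = 'cuda'  # assumption: switch to 'cpu' (and fp16=False) if no GPU is available
tokenizer = get_tokenizer()                       # downloads bpe.model on first call
vae = get_vae(pretrained=True)                    # Gumbel-f8 VQGAN decoder
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)

text = 'мальчик играет с оленем'                  # "a boy is playing with a deer"
pil_images, scores = generate_images(
    text, tokenizer, dalle, vae,
    top_k=2048, top_p=0.995, images_num=4, seed=42,  # example sampling settings
)
show(pil_images, nrow=4)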
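
The cogview_pb_relax branch in DalleSelfAttention._calculate_attention_scores relies on softmax being invariant to subtracting a per-slice constant: scores are divided by alpha=32, the per-head maximum is subtracted, and the result is scaled back by alpha, which is meant to keep fp16 values in range without changing the resulting probabilities. A small standalone check of that identity (shapes and values are arbitrary, not taken from the model):

import torch

alpha = 32
scores = 50 * torch.randn(2, 4, 8, 8)             # arbitrary [b, np, s, s] attention scores

scaled = scores / alpha
maxes = scaled.view(2, 4, -1).max(dim=-1).values.view(2, 4, 1, 1)  # max per head per sample
relaxed = (scaled - maxes) * alpha                # what the pb_relax branch feeds to softmax

# Subtracting a constant per (batch, head) slice leaves the softmax unchanged.
assert torch.allclose(relaxed.softmax(dim=-1), scores.softmax(dim=-1), atol=1e-5)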