| @ -0,0 +1,168 @@ | |||
| # Created by .ignore support plugin (hsz.mobi) | |||
| ### JetBrains template | |||
| # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm | |||
| # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | |||
| settings/local.py | |||
| logs/*.log | |||
| # User-specific stuff: | |||
| .idea/ | |||
| # Sensitive or high-churn files: | |||
| .idea/**/dataSources/ | |||
| .idea/**/dataSources.ids | |||
| .idea/**/dataSources.xml | |||
| .idea/**/dataSources.local.xml | |||
| .idea/**/sqlDataSources.xml | |||
| .idea/**/dynamic.xml | |||
| .idea/**/uiDesigner.xml | |||
| # Gradle: | |||
| .idea/**/gradle.xml | |||
| .idea/**/libraries | |||
| # CMake | |||
| cmake-build-debug/ | |||
| # Mongo Explorer plugin: | |||
| .idea/**/mongoSettings.xml | |||
| ## File-based project format: | |||
| *.iws | |||
| ## Plugin-specific files: | |||
| # IntelliJ | |||
| out/ | |||
| # mpeltonen/sbt-idea plugin | |||
| .idea_modules/ | |||
| # JIRA plugin | |||
| atlassian-ide-plugin.xml | |||
| # Cursive Clojure plugin | |||
| .idea/replstate.xml | |||
| # Crashlytics plugin (for Android Studio and IntelliJ) | |||
| com_crashlytics_export_strings.xml | |||
| crashlytics.properties | |||
| crashlytics-build.properties | |||
| fabric.properties | |||
| ### Python template | |||
| # Byte-compiled / optimized / DLL files | |||
| __pycache__/ | |||
| *.py[cod] | |||
| *$py.class | |||
| # C extensions | |||
| *.so | |||
| # Distribution / packaging | |||
| .Python | |||
| build/ | |||
| develop-eggs/ | |||
| dist/ | |||
| downloads/ | |||
| eggs/ | |||
| .eggs/ | |||
| lib/ | |||
| lib64/ | |||
| parts/ | |||
| sdist/ | |||
| var/ | |||
| wheels/ | |||
| *.egg-info/ | |||
| .installed.cfg | |||
| *.egg | |||
| # PyInstaller | |||
| # Usually these files are written by a python script from a template | |||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | |||
| *.manifest | |||
| *.spec | |||
| # Installer logs | |||
| pip-log.txt | |||
| pip-delete-this-directory.txt | |||
| # Unit test / coverage reports | |||
| htmlcov/ | |||
| .tox/ | |||
| .coverage | |||
| .coverage.* | |||
| .cache | |||
| nosetests.xml | |||
| coverage.xml | |||
| *.cover | |||
| .hypothesis/ | |||
| # Translations | |||
| *.mo | |||
| *.pot | |||
| # Django stuff: | |||
| *.log | |||
| local_settings.py | |||
| # Flask stuff: | |||
| instance/ | |||
| .webassets-cache | |||
| # Scrapy stuff: | |||
| .scrapy | |||
| # Sphinx documentation | |||
| docs/_build/ | |||
| # PyBuilder | |||
| target/ | |||
| # Jupyter Notebook | |||
| .ipynb_checkpoints | |||
| # pyenv | |||
| .python-version | |||
| # celery beat schedule file | |||
| celerybeat-schedule | |||
| # SageMath parsed files | |||
| *.sage.py | |||
| # Environments | |||
| .env | |||
| .venv | |||
| env/ | |||
| venv/ | |||
| ENV/ | |||
| # Spyder project settings | |||
| .spyderproject | |||
| .spyproject | |||
| # Rope project settings | |||
| .ropeproject | |||
| # mkdocs documentation | |||
| /site | |||
| # mypy | |||
| .mypy_cache/ | |||
| /tests/load_tests/logs/* | |||
| /tests/.pytest_cache/ | |||
| ws_test.py | |||
| /.vscode/ | |||
| .s3_cache/ | |||
| mlruns | |||
| *.pyc | |||
| *.swp | |||
| *.pt | |||
| *.bin | |||
| .vscode/ | |||
| runs/ | |||
| jupyters/custom_* | |||
| *logs/ | |||
| @ -0,0 +1,31 @@ | |||
| repos: | |||
| - repo: https://github.com/pre-commit/pre-commit-hooks | |||
| rev: v2.2.3 | |||
| hooks: | |||
| - id: check-docstring-first | |||
| stages: | |||
| - commit | |||
| - push | |||
| - id: check-merge-conflict | |||
| stages: | |||
| - push | |||
| - id: double-quote-string-fixer | |||
| stages: | |||
| - commit | |||
| - push | |||
| - id: fix-encoding-pragma | |||
| stages: | |||
| - commit | |||
| - push | |||
| - id: flake8 | |||
| args: ['--config=setup.cfg'] | |||
| stages: | |||
| - commit | |||
| - push | |||
| - repo: https://github.com/pre-commit/mirrors-autopep8 | |||
| rev: v1.4.4 | |||
| hooks: | |||
| - id: autopep8 | |||
| stages: | |||
| - commit | |||
| - push | |||
| @ -0,0 +1,4 @@ | |||
| -r requirements.txt | |||
| pytest | |||
| pytest-cov | |||
| pre-commit | |||
| @ -0,0 +1,8 @@ | |||
| taming-transformers==0.0.1 | |||
| more_itertools==8.10.0 | |||
| transformers==4.10.2 | |||
| youtokentome==1.0.6 | |||
| einops==0.3.2 | |||
| torch | |||
| torchvision | |||
| matplotlib | |||
| @ -0,0 +1,15 @@ | |||
# -*- coding: utf-8 -*-
# Public package API: factory functions for the VAE, the DALL-E model and the
# tokenizer, plus the submodules themselves.
from .vae import get_vae
from .dalle import get_rudalle_model
from .tokenizer import get_tokenizer
from . import vae, dalle, tokenizer

# Explicit public API of the package.
__all__ = [
    'get_vae',
    'get_rudalle_model',
    'get_tokenizer',
    'vae',
    'dalle',
    'tokenizer',
]
| @ -0,0 +1,76 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import torch | |||
| from huggingface_hub import hf_hub_url, cached_download | |||
| from .model import DalleModel | |||
| from .fp16 import FP16Module | |||
# Registry of available model configurations. Each entry holds:
#   description  - short text printed when the pretrained model is loaded,
#   model_params - kwargs forwarded verbatim to DalleModel(...),
#   repo_id/filename - Hugging Face Hub location of the pretrained checkpoint.
MODELS = {
    'Malevich': dict(
        description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, '
                    'that uses Russian language and text+image multi-modality.',
        model_params=dict(
            num_layers=24,
            hidden_size=2048,
            num_attention_heads=16,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,
            text_seq_length=128,
            use_masks=True,
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            # NOTE(review): presumably 16384 BPE tokens + 128 reserved ids —
            # confirm against the tokenizer.
            vocab_size=16384+128,
            image_vocab_size=8192,
        ),
        repo_id='sberbank-ai/rudalle-Malevich',
        filename='pytorch_model.bin',
        full_description='',  # TODO
    ),
    'small': dict(
        # Small configuration for experiments; repo_id/filename are empty, so
        # it cannot be loaded with pretrained=True yet.
        description='',
        model_params=dict(
            num_layers=12,
            hidden_size=768,
            num_attention_heads=12,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,
            text_seq_length=128,
            use_masks=True,
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            vocab_size=16384+128,
            image_vocab_size=8192,
        ),
        repo_id='',
        filename='',
        full_description='',  # TODO
    ),
}
def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle'):
    """Build a ruDALL-E model by name and optionally load pretrained weights.

    Args:
        name (str): key of ``MODELS`` (e.g. ``'Malevich'``).
        pretrained (bool): download the checkpoint from the Hugging Face Hub
            and load it into the model.
        fp16 (bool): run in half precision (the model is wrapped in FP16Module).
        device (str): torch device the model is moved to.
        cache_dir (str): local directory used to cache downloaded checkpoints.

    Returns:
        The model in eval mode on ``device`` (an FP16Module wrapper when fp16).

    Raises:
        ValueError: if ``name`` is not a known model configuration.
    """
    if name not in MODELS:
        # Explicit error instead of `assert`: assertions vanish under `python -O`.
        raise ValueError('Unknown model "{}"; available: {}'.format(name, ', '.join(sorted(MODELS))))
    config = MODELS[name]
    model = DalleModel(device=device, fp16=fp16, **config['model_params'])
    if pretrained:
        # Cache checkpoints per model name to avoid filename collisions.
        cache_dir = os.path.join(cache_dir, name)
        config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
        # Load on CPU first; the model is moved to `device` below.
        checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
        model.load_state_dict(checkpoint)
    if fp16:
        model = FP16Module(model)
    model.eval()
    model = model.to(device)
    if config['description'] and pretrained:
        print(config['description'])
    return model
| @ -0,0 +1,60 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import torch | |||
| from torch import nn | |||
| from torch.autograd import Variable | |||
| from torch.nn.parameter import Parameter | |||
# Tensor types that participate in fp32 <-> fp16 conversion.
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
    """Apply `conversion` to `val`, recursing through nested tuples/lists.

    The container shape is preserved: tuples stay tuples, lists stay lists.
    """
    if isinstance(val, (tuple, list)):
        converted = [conversion_helper(item, conversion) for item in val]
        return tuple(converted) if isinstance(val, tuple) else converted
    return conversion(val)


def fp32_to_fp16(val):
    """Convert fp32 tensors in (possibly nested) `val` to fp16."""
    def half_conversion(item):
        candidate = item
        # Unwrap Parameter/Variable to inspect the underlying tensor type.
        if isinstance(candidate, (Parameter, Variable)):
            candidate = item.data
        if isinstance(candidate, FLOAT_TYPES):
            item = item.half()
        return item
    return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
    """Convert fp16 tensors in (possibly nested) `val` to fp32."""
    def float_conversion(item):
        candidate = item
        # Unwrap Parameter/Variable to inspect the underlying tensor type.
        if isinstance(candidate, (Parameter, Variable)):
            candidate = item.data
        if isinstance(candidate, HALF_TYPES):
            item = item.float()
        return item
    return conversion_helper(val, float_conversion)
class FP16Module(nn.Module):
    """Wrapper that stores a module's weights in fp16.

    Inputs are downcast to fp16 before the wrapped forward and the outputs
    are upcast back to fp32, so callers keep working with fp32 tensors.
    """

    def __init__(self, module):
        super().__init__()
        # Convert the wrapped module's parameters to half precision in place.
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        converted = fp32_to_fp16(inputs)
        return fp16_to_fp32(self.module(*converted, **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Delegate so checkpoints round-trip without a 'module.' prefix.
        return self.module.state_dict(destination, prefix, keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        # Mirrors DalleModel.get_param for transparent attribute access.
        return self.module.get_param(item)
| @ -0,0 +1,48 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import torch | |||
| def _init_mask(text_tokens, image_tokens_per_dim): | |||
| attn_size = text_tokens + image_tokens_per_dim**2 | |||
| mask = torch.tril(torch.ones(attn_size, attn_size)) | |||
| return mask | |||
| def get_row_mask(text_tokens=256, image_tokens_per_dim=32): | |||
| mask = _init_mask(text_tokens, image_tokens_per_dim) | |||
| step = image_tokens_per_dim + 1 | |||
| for col in range(text_tokens, mask.size(1)): | |||
| mask[col + step:, col] = 0.0 | |||
| return mask | |||
| def get_col_mask(text_tokens=256, image_tokens_per_dim=32): | |||
| mask = _init_mask(text_tokens, image_tokens_per_dim) | |||
| step = image_tokens_per_dim - 1 | |||
| for col in range(text_tokens, mask.size(1)): | |||
| for i in range(1, mask.size(0), step+1): | |||
| mask[col + i: col + i + step, col] = 0.0 | |||
| return mask | |||
| def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): | |||
| mask = _init_mask(text_tokens, image_tokens_per_dim) | |||
| shift = kernel // 2 | |||
| for pos in range(text_tokens, mask.size(1)): | |||
| mask[pos+1:, pos] = 0.0 | |||
| img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim) | |||
| pixel_id = pos - text_tokens | |||
| row = pixel_id // image_tokens_per_dim | |||
| col = pixel_id % image_tokens_per_dim | |||
| for r in range(-shift, shift+1): | |||
| for c in range(-shift, shift+1): | |||
| c_abs = (c + col) % image_tokens_per_dim | |||
| r_abs = (r + row) % image_tokens_per_dim | |||
| img[r_abs, c_abs] = 0.2 | |||
| cell_id = r_abs * image_tokens_per_dim + c_abs | |||
| if text_tokens + cell_id > pos: | |||
| mask[text_tokens + cell_id, pos] = 1.0 | |||
| img[row, col] = 1.0 | |||
| return mask | |||
| @ -0,0 +1,171 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from einops import rearrange | |||
| from .utils import exists, is_empty, init_method_normal | |||
| from .image_attention import get_conv_mask, get_row_mask, get_col_mask | |||
| from .transformer import DalleTransformer | |||
class DalleModel(torch.nn.Module):
    """GPT-like autoregressive transformer over a joint [text | image] token
    sequence (ruDALL-E). Text tokens occupy the first `text_seq_length`
    positions, image tokens the remaining `image_tokens_per_dim ** 2`.
    """

    def __init__(self,
                 device,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 text_seq_length=128,
                 image_tokens_per_dim=32,
                 image_vocab_size=16384,
                 loss_img_weight=7,
                 fp16=False,
                 use_masks=True,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleModel, self).__init__()
        self.device = device
        self.fp16 = fp16
        self.image_tokens_per_dim = image_tokens_per_dim
        # Image tokens form a square grid.
        self.image_seq_length = image_tokens_per_dim ** 2
        self.text_seq_length = text_seq_length
        self.total_seq_length = self.text_seq_length + self.image_seq_length
        # Joint vocabulary: text ids first, image ids offset by vocab_size
        # (see the logits slicing in forward()).
        self.total_vocab_size = vocab_size + image_vocab_size
        self.vocab_size = vocab_size
        self.loss_img_weight = loss_img_weight
        # TODO "to"
        mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
        if use_masks:
            self._mask_map = mask_map
        else:
            # Empty map disables per-layer image attention masking.
            self._mask_map = []
        init_method = init_method_normal(std=0.02)
        self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size)
        # Position embedding (serial). The +1 accounts for the token
        # prepended to the text in forward().
        self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size)
        self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        init_method(self.text_pos_embeddings.weight)
        init_method(self.image_row_embeddings.weight)
        init_method(self.image_col_embeddings.weight)
        # Final projection from hidden states to the joint vocabulary.
        self.to_logits = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden_size),
            torch.nn.Linear(hidden_size, self.total_vocab_size),
        )
        # Embeddings dropout (constructed but not applied in forward() —
        # NOTE(review): confirm whether this is intentional).
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        # Transformer
        self.transformer = DalleTransformer(
            num_layers,
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_sandwich_layernorm=cogview_sandwich_layernorm,
            cogview_pb_relax=cogview_pb_relax,
        )
        # Hand the per-layer masks to the transformer (class-level attribute).
        self.transformer._mask_map = self._mask_map

    def get_param(self, item):
        """Expose attributes by name (mirrors FP16Module.get_param)."""
        return getattr(self, item)

    def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
        """Build the per-layer image attention masks.

        Schedule: layers with (i - 1) % 4 == 0 use the column mask, the last
        layer uses the conv mask, all other layers use the row mask.
        """
        row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        if self.fp16:
            row_mask = row_mask.half()
            col_mask = col_mask.half()
            conv_mask = conv_mask.half()
        # Buffers so the masks follow the module across devices / dtype moves.
        self.register_buffer('row_mask', row_mask)
        self.register_buffer('col_mask', col_mask)
        self.register_buffer('conv_mask', conv_mask)
        mask_map = []
        for i in range(num_layers):
            if ((i - 1) % 4 == 0):
                mask_map.append(col_mask)
            elif i != num_layers - 1:
                mask_map.append(row_mask)
            else:
                mask_map.append(conv_mask)
        return mask_map

    def get_image_pos_embeddings(self, image_input_ids, past_length=0):
        """2-D positional embeddings for image tokens: row + column embedding,
        derived from the flat position within the image grid."""
        input_shape = image_input_ids.size()
        row_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) // self.image_tokens_per_dim
        row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
        col_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) % self.image_tokens_per_dim
        col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
        return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids)

    def forward(
            self,
            input_ids,
            attention_mask,
            return_loss=False,
            has_cache=False,
            use_cache=False,
    ):
        """Run the model over concatenated [text | image] token ids.

        Args:
            input_ids: [b, <= total_seq_length] — text ids first, image ids after.
            attention_mask: 4-D causal mask; cropped below to the embedded length.
            return_loss: when True, return (loss, loss_parts) instead of logits.
            has_cache / use_cache: incremental-decoding flags forwarded to the
                transformer layers.

        Returns:
            (logits, has_cache) when return_loss is False, otherwise
            (loss, {'text': ..., 'image': ...}).
        """
        text = input_ids[:, :self.text_seq_length]
        # Replace padding zeros with unique ids taken from the top of the text
        # vocabulary, so every text position holds a distinct token.
        text_range = torch.arange(self.text_seq_length)
        text_range += (self.vocab_size - self.text_seq_length)
        text_range = text_range.to(self.device)
        text = torch.where(text == 0, text_range, text)
        # some hardcode :) — prepend token id 2 (presumably a BOS marker;
        # confirm with the tokenizer).
        text = F.pad(text, (1, 0), value=2)
        text_embeddings = self.text_embeddings(text) + \
            self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
        image_input_ids = input_ids[:, self.text_seq_length:]
        if exists(image_input_ids) and not is_empty(image_input_ids):
            image_embeddings = self.image_embeddings(image_input_ids) + \
                self.get_image_pos_embeddings(image_input_ids, past_length=0)
            embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
        else:
            embeddings = text_embeddings
        # some hardcode :) — the prepended token can make the sequence one
        # longer than total_seq_length; drop the trailing position.
        if embeddings.shape[1] > self.total_seq_length:
            embeddings = embeddings[:, :-1]
        # Gradient-scaling trick: x*a + x.detach()*(1-a) equals x in the
        # forward pass but scales the embedding gradients by alpha.
        alpha = 0.1
        embeddings = embeddings * alpha + embeddings.detach() * (1-alpha)
        # Crop the mask to the actual embedded sequence length.
        attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
        transformer_output, present_has_cache = self.transformer(
            embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
        logits = self.to_logits(transformer_output)
        if return_loss is False:
            return logits, present_has_cache
        # Next-token targets: text labels shifted by one (the prepended token
        # absorbs the shift), image labels aligned as-is.
        labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()
        logits = rearrange(logits, 'b n c -> b c n')
        # Slice the joint logits into the text and image sub-vocabularies.
        text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float()
        image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()
        loss_text = F.cross_entropy(
            text_logits,
            labels[:, :self.text_seq_length])
        loss_img = F.cross_entropy(
            image_logits,
            labels[:, self.text_seq_length:])
        # Weighted combination: the image loss dominates by loss_img_weight.
        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}
| @ -0,0 +1,332 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import math | |||
| import torch | |||
| from torch.nn import LayerNorm | |||
| from .utils import divide, split_tensor_along_last_dim | |||
@torch.jit.script
def gelu_impl(x):
    """OpenAI's tanh-based approximation of the GELU activation."""
    inner = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))


def gelu(x):
    """Apply the scripted GELU approximation."""
    return gelu_impl(x)
class DalleTransformer(torch.nn.Module):
    """
    This module takes input from embedding layer and it's output can
    be used directly by a logit layer. It consists of L (num-layers)
    blocks of:
        layer norm
        self attention
        residual connection
        layer norm
        mlp
        residual connection
    followed by a final layer norm.

    Arguments:
        num_layers: Number of transformer layers.
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention head in the self
            attention.
        attention_dropout_prob: dropout probability of the attention
            score in self attention.
        output_dropout_prob: dropout probability for the outputs
            after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
            division by zero.
    """

    # Per-layer image attention masks; assigned externally (DalleModel sets
    # `self.transformer._mask_map`). An empty list disables per-layer masking.
    _mask_map = []

    def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
                 layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False):
        super(DalleTransformer, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        # Transformer layers.
        self.layers = torch.nn.ModuleList([
            DalleTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                cogview_sandwich_layernorm=cogview_sandwich_layernorm,
                cogview_pb_relax=cogview_pb_relax,
            ) for _ in range(num_layers)
        ])
        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, attention_mask, has_cache, use_cache):
        # hidden_states: [b, s, h]; attention_mask: 4-D causal mask.
        for i, layer in enumerate(self.layers):
            mask = attention_mask
            if len(self._mask_map):
                # Combine the causal mask with this layer's image-attention
                # pattern (row/col/conv), cropped to the current length.
                layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)]
                mask = torch.mul(attention_mask, layer_mask)
            hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
        output = self.final_layernorm(hidden_states)
        return output, present_has_cache
class DalleTransformerLayer(torch.nn.Module):
    """
    A single layer transformer.

    We use the following notation:
        h: hidden size
        n: number of attention heads
        b: batch size
        s: sequence length
    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.

    Arguments:
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention head in the self
            attention.
        attention_dropout_prob: dropout probability of the attention
            score in self attention.
        output_dropout_prob: dropout probability for the outputs
            after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
            division by zero.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 output_dropout_prob,
                 layernorm_epsilon,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleTransformerLayer, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
        self.cogview_pb_relax = cogview_pb_relax
        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        if self.cogview_sandwich_layernorm:
            # Sandwich-LN: extra layernorms applied right before each
            # residual addition.
            self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
            self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # Self attention.
        self.attention = DalleSelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_pb_relax=cogview_pb_relax
        )
        # Layernorm after the attention block, before the MLP.
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # MLP
        self.mlp = DalleMLP(hidden_size, output_dropout_prob)

    def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Layer norm at the begining of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, has_cache = self.attention(
            layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
        if self.cogview_sandwich_layernorm:
            attention_output = self.before_first_addition_layernorm(attention_output)
        # Residual connection.
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)
        if self.cogview_sandwich_layernorm:
            mlp_output = self.before_second_addition_layernorm(mlp_output)
        # Second residual connection.
        output = layernorm_input + mlp_output
        return output, has_cache
class DalleSelfAttention(torch.nn.Module):
    """
    Self-attention layer takes input with size [b, s, h] where b is
    the batch size, s is the sequence length, and h is the hidden size
    and creates output of the same size.

    Arguments:
        hidden_size: total hidden size of the layer (h).
        num_attention_heads: number of attention heads (n). Note that we
            require n to be divisible by number of GPUs
            used to parallelize the model. Also, we
            require hidden size to be divisible by n.
        attention_dropout_prob: dropout probability for the attention scores.
        output_dropout_prob: dropout probability for the output.

    We use the following notation:
        h: hidden_size
        n: num_attention_heads
        p: number of partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
    """

    def __init__(self, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False):
        super(DalleSelfAttention, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        # Fused projection producing Q, K and V from one linear layer.
        self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size)
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
        # Output.
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
        new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
        """Masked, scaled attention scores [b, np, s, s]."""
        key_t = key_layer.transpose(-1, -2)
        if self.cogview_pb_relax:
            # PB-relax: scale the query BEFORE the matmul to avoid fp16
            # overflow in the QK^T product.
            attention_scores = torch.matmul(
                query_layer / math.sqrt(self.hidden_size_per_attention_head),
                key_t
            )
        else:
            attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
        # Masked positions get a large negative value before softmax.
        attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
        if self.cogview_pb_relax:
            # normalize attention scores. Should not affect resulting softmax value
            alpha = 32
            attention_scores_scaled = attention_scores / alpha
            attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view(
                [attention_scores.size(0), attention_scores.size(1), -1]
            ).max(dim=-1)  # max per head per sample
            attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand(
                [-1, -1, attention_scores.size(2), attention_scores.size(3)]
            )  # expand to [b, np, s, s]
            attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
        return attention_scores

    def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Attention heads. [b, s, hp]
        if has_cache and use_cache:
            # Incremental decoding: only the newest position needs the QKV
            # projection; previous projections are concatenated from cache.
            mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :])
        else:
            mixed_x_layer = self.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        if use_cache and has_cache:
            # The cache restores the full Q/K/V, so attention itself is still
            # computed over the whole sequence — only the projection is saved.
            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
            query_layer = torch.cat((self.past_query, query_layer), dim=-2)
            key_layer = torch.cat((self.past_key, key_layer), dim=-2)
            attention_scores = self._calculate_attention_scores(
                query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
            )
        else:
            attention_scores = self._calculate_attention_scores(
                query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
            )
        if use_cache:
            # Cache is stored as plain instance attributes: this module is
            # stateful across calls when use_cache is on.
            self.past_query = query_layer
            self.past_key = key_layer
            self.past_value = value_layer
            has_cache = True
        else:
            has_cache = False
        # Attention probabilities. [b, np, s, s]
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # Context layer.
        # [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)
        return output, has_cache
class DalleMLP(torch.nn.Module):
    """Position-wise feed-forward block: h -> 4h -> gelu -> h -> dropout.

    Arguments:
        hidden_size: width of the transformer hidden state.
        output_dropout_prob: dropout probability applied to the output.
    """

    def __init__(self, hidden_size, output_dropout_prob):
        super(DalleMLP, self).__init__()
        # Expand to 4x the hidden width, then project back.
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        # [b, s, h] -> [b, s, 4h] with GELU nonlinearity.
        expanded = gelu(self.dense_h_to_4h(hidden_states))
        # [b, s, 4h] -> [b, s, h]
        projected = self.dense_4h_to_h(expanded)
        return self.dropout(projected)
| @ -0,0 +1,54 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import torch | |||
def exists(val):
    """Return True when `val` is not None."""
    return val is not None
def is_empty(t):
    """Return True when tensor `t` holds zero elements."""
    return t.nelement() == 0
def ensure_divisibility(numerator, denominator):
    """Assert that `numerator` divides evenly by `denominator`."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator):
    """Checked integer division: asserts exact divisibility, then divides."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """
    Split a tensor into equal parts along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions (must divide the last dim).
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    """
    axis = tensor.dim() - 1
    chunk_size = divide(tensor.size()[axis], num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=axis)
    # torch.split returns views; copy only when explicitly requested.
    if contiguous_split_chunks:
        return tuple(piece.contiguous() for piece in chunks)
    return chunks
def init_method_normal(std=0.02):
    """Return an initializer drawing from N(0, std**2).

    This is only used for embeddings. The transformer has its
    own initializer.
    """
    def _normal_init(tensor):
        """Fill ``tensor`` in place and return it."""
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return _normal_init
| @ -0,0 +1,61 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import torch | |||
| import torchvision | |||
| import transformers | |||
| import more_itertools | |||
| import numpy as np | |||
| import matplotlib.pyplot as plt | |||
| from tqdm import tqdm | |||
| from . import utils | |||
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
                    use_cache=True):
    """Autoregressively sample image tokens conditioned on ``text`` and decode them.

    Args:
        text: prompt string (lower-cased and stripped before encoding).
        tokenizer: text tokenizer exposing ``encode_text``.
        dalle: autoregressive transformer; sizes and device are read via ``get_param``.
        vae: image decoder exposing ``decode`` over codebook indices.
        top_k: top-k filtering parameter for sampling.
        top_p: nucleus (top-p) filtering parameter for sampling.
        images_num: total number of images to generate.
        temperature: softmax temperature applied to the logits.
        bs: micro-batch size used to chunk generation.
        seed: optional seed for reproducible sampling.
        use_cache: reuse transformer key/value caches between decoding steps.

    Returns:
        (pil_images, scores): generated PIL images and one score per image —
        the sum over decoding steps of the probability of each chosen token.
    """
    if seed is not None:
        utils.seed_everything(seed)
    vocab_size = dalle.get_param('vocab_size')
    text_seq_length = dalle.get_param('text_seq_length')
    image_seq_length = dalle.get_param('image_seq_length')
    total_seq_length = dalle.get_param('total_seq_length')
    device = dalle.get_param('device')
    text = text.lower().strip()
    input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length)
    pil_images, scores = [], []
    for chunk in more_itertools.chunked(range(images_num), bs):
        chunk_bs = len(chunk)
        with torch.no_grad():
            # Causal (lower-triangular) mask over the full text+image sequence.
            attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device))
            out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
            has_cache = False
            sample_scores = []
            # Decode one image token per step until the sequence is complete.
            for _ in tqdm(range(out.shape[1], total_seq_length)):
                logits, has_cache = dalle(out, attention_mask,
                                          has_cache=has_cache, use_cache=use_cache, return_loss=False)
                # Keep only image-token logits of the last position
                # (image vocabulary starts at offset vocab_size).
                logits = logits[:, -1, vocab_size:]
                logits /= temperature
                filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                sample = torch.multinomial(probs, 1)
                # Probability of the token actually drawn — presumably shape
                # (1, chunk_bs) via broadcast indexing; TODO confirm.
                sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
                out = torch.cat((out, sample), dim=-1)
            # The trailing image_seq_length tokens are the VQ codebook indices.
            codebooks = out[:, -image_seq_length:]
            images = vae.decode(codebooks)
            pil_images += utils.torch_tensors_to_pil_list(images)
            scores += torch.cat(sample_scores).sum(0).detach().cpu().numpy().tolist()
    return pil_images, scores
def show(pil_images, nrow=4):
    """Render ``pil_images`` as an image grid with matplotlib."""
    grid = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    # make_grid normally returns a single tensor; wrap it in a list.
    imgs = grid if isinstance(grid, list) else [grid.cpu()]
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
    for i, img in enumerate(imgs):
        as_pil = torchvision.transforms.functional.to_pil_image(img.detach())
        axs[0, i].imshow(np.asarray(as_pil))
        # Hide all axis decorations.
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
| @ -0,0 +1,64 @@ | |||
| # -*- coding: utf-8 -*- | |||
| from os.path import join | |||
| import torch | |||
| import numpy as np | |||
| import youtokentome as yttm | |||
| from huggingface_hub import hf_hub_url, cached_download | |||
def get_tokenizer(path=None, cache_dir='/tmp/rudalle'):
    """Load the BPE tokenizer, downloading the default model when no path is given.

    Args:
        path: optional path to a local yttm BPE model file.
        cache_dir: root directory for downloaded tokenizer assets.
    """
    if path is None:
        # Fetch the pretrained BPE model from the HuggingFace hub cache.
        repo_id = 'shonenkov/rudalle-utils'
        filename = 'bpe.model'
        cache_dir = join(cache_dir, 'tokenizer')
        cached_download(hf_hub_url(repo_id=repo_id, filename=filename),
                        cache_dir=cache_dir, force_filename=filename)
        path = join(cache_dir, filename)
    wrapper = YTTMTokenizerWrapper(yttm.BPE(model=path))
    print('tokenizer --> ready')
    return wrapper
class YTTMTokenizerWrapper:
    """Thin wrapper around a youtokentome BPE model with fixed special-token ids."""

    # Special-token ids of the underlying BPE vocabulary.
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __len__(self):
        return self.vocab_size()

    def get_pad_token_id(self):
        """Return the id of the '<PAD>' subword in the vocabulary."""
        return self.tokenizer.subword_to_id('<PAD>')

    def vocab_size(self):
        """Return the size of the BPE vocabulary."""
        return self.tokenizer.vocab_size()

    def encode_text(self, text, text_seq_length, bpe_dropout=0.0):
        """Encode ``text`` to a fixed-length LongTensor framed by BOS/EOS."""
        ids = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0]
        framed = [self.bos_id] + ids + [self.eos_id]
        return self.prepare_tokens(framed, text_seq_length)

    def decode_text(self, encoded):
        """Decode a tensor of token ids back to text, dropping special tokens."""
        ignored = [self.eos_id, self.bos_id, self.unk_id, self.pad_id]
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=ignored)[0]

    @staticmethod
    def prepare_tokens(tokens, text_seq_length):
        """Right-pad with PAD (=0) or truncate to exactly ``text_seq_length`` tokens."""
        pad_count = text_seq_length - len(tokens)
        if pad_count > 0:
            # Pad positions come after the text tokens.
            tokens = np.hstack((tokens, np.zeros(pad_count)))
        elif pad_count < 0:
            tokens = tokens[:text_seq_length]
        return torch.tensor(tokens).long()
| @ -0,0 +1,36 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import random | |||
| import torch | |||
| import torchvision | |||
| import numpy as np | |||
def seed_everything(seed):
    """Seed every RNG (python, hash, numpy, torch CPU and CUDA) for reproducible runs.

    Args:
        seed: integer seed applied to all random sources.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Bug fix: benchmark mode lets cuDNN pick algorithms non-deterministically,
    # which defeats cudnn.deterministic — it must be OFF for reproducibility.
    torch.backends.cudnn.benchmark = False
def torch_tensors_to_pil_list(input_images):
    """Convert an iterable of image tensors into a list of RGB PIL images."""
    to_pil = torchvision.transforms.functional.to_pil_image
    # Move each tensor to CPU and detach before conversion.
    return [to_pil(img.cpu().detach()).convert('RGB') for img in input_images]
def pil_list_to_torch_tensors(pil_images):
    """Stack PIL images into a single uint8 tensor of shape (N, C, H, W)."""
    tensors = []
    for pil_image in pil_images:
        arr = np.array(pil_image, dtype=np.uint8)
        # HWC -> CHW, then add the batch axis.
        tensors.append(torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0))
    return torch.cat(tensors, dim=0)
| @ -0,0 +1,24 @@ | |||
| # -*- coding: utf-8 -*- | |||
| from os.path import dirname, abspath, join | |||
| import torch | |||
| from huggingface_hub import hf_hub_url, cached_download | |||
| from omegaconf import OmegaConf | |||
| from .model import VQGanGumbelVAE | |||
def get_vae(pretrained=True, cache_dir='/tmp/rudalle'):
    """Build the Gumbel VQGAN VAE, optionally loading pretrained weights.

    Args:
        pretrained: when True, download and load the published checkpoint.
        cache_dir: root directory for downloaded checkpoint files.
    """
    config_path = join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')
    vae = VQGanGumbelVAE(OmegaConf.load(config_path))
    if pretrained:
        # Download the checkpoint from the HuggingFace hub into the local cache.
        repo_id = 'shonenkov/rudalle-utils'
        filename = 'vqgan.gumbelf8-sber.model.ckpt'
        cache_dir = join(cache_dir, 'vae')
        cached_download(hf_hub_url(repo_id=repo_id, filename=filename),
                        cache_dir=cache_dir, force_filename=filename)
        checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
        # NOTE(review): strict=False silently skips mismatched keys — confirm intended.
        vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
    print('vae --> ready')
    return vae
| @ -0,0 +1,100 @@ | |||
| # -*- coding: utf-8 -*- | |||
| from math import sqrt, log | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from torch import einsum | |||
| from einops import rearrange | |||
| from taming.modules.diffusionmodules.model import Encoder, Decoder | |||
class VQGanGumbelVAE(torch.nn.Module):
    """Wrapper around GumbelVQ exposing encode-to-indices / decode-from-indices."""
    def __init__(self, config):
        super().__init__()
        model = GumbelVQ(
            ddconfig=config.model.params.ddconfig,
            n_embed=config.model.params.n_embed,
            embed_dim=config.model.params.embed_dim,
            kl_weight=config.model.params.kl_weight,
        )
        self.model = model
        # log2 of the attention resolution — presumably the downsampling depth;
        # TODO confirm against the taming-transformers architecture.
        self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
        self.image_size = 256
        # Codebook size: number of discrete image tokens.
        self.num_tokens = config.model.params.n_embed
    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Encode images to flattened codebook indices of shape (b, h*w)."""
        # Rescale pixel values to the [-1, 1] range the encoder expects.
        img = (2 * img) - 1
        _, _, [_, _, indices] = self.model.encode(img)
        return rearrange(indices, 'b h w -> b (h w)')
    def decode(self, img_seq):
        """Decode a (b, n) sequence of codebook indices back to images in [0, 1].

        Assumes n is a perfect square (square latent grid) — TODO confirm.
        """
        b, n = img_seq.shape
        # Look up embeddings via one-hot matmul instead of nn.Embedding indexing.
        one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).float()
        z = (one_hot_indices @ self.model.quantize.embed.weight)
        z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(n)))
        img = self.model.decode(z)
        # Map decoder output from [-1, 1] back to [0, 1].
        img = (img.clamp(-1., 1.) + 1) * 0.5
        return img
class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """
    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        # When True, use hard one-hot in the forward pass with soft gradients.
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # 1x1 conv producing per-position logits over the codebook.
        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(self.n_embed, self.embedding_dim)
        self.use_vqinterface = use_vqinterface
    def forward(self, z, temp=None, return_logits=False):
        """Quantize latents ``z`` (b, c, h, w) via Gumbel-softmax over the codebook.

        Returns (z_q, kl_loss, (None, None, indices)[, logits]) when
        use_vqinterface is True, otherwise (z_q, kl_loss, indices).
        """
        # Always hard (discrete) at eval time; straight_through controls training.
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp
        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        # Weighted sum of codebook embeddings per spatial position.
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
        # + kl divergence to the (uniform) prior loss; 1e-10 guards log(0).
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
        ind = soft_one_hot.argmax(dim=1)
        if self.use_vqinterface:
            if return_logits:
                return z_q, diff, (None, None, ind), logits
            return z_q, diff, (None, None, ind)
        return z_q, diff, ind
class GumbelVQ(nn.Module):
    """VQGAN with Gumbel-softmax quantization: encoder -> quantize -> decoder."""

    def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8):
        super().__init__()
        z_channels = ddconfig['z_channels']
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0)
        # 1x1 convs mapping between encoder latent space and embedding space.
        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
        """Encode images to (quantized latents, embedding loss, info)."""
        latent = self.quant_conv(self.encoder(x))
        return self.quantize(latent)

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        return self.decoder(self.post_quant_conv(quant))
| @ -0,0 +1,34 @@ | |||
# VQGAN (Gumbel-softmax) model configuration for taming-transformers.
model:
  base_learning_rate: 4.5e-06
  target: taming.models.vqgan.GumbelVQ
  params:
    kl_weight: 1.0e-08
    embed_dim: 256
    n_embed: 8192  # codebook size
    monitor: val/rec_loss
    # Gumbel-softmax temperature annealing schedule.
    temperature_scheduler_config:
      target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        warm_up_steps: 0
        max_decay_steps: 1000001
        lr_start: 0.9
        lr_max: 0.9
        lr_min: 1.0e-06
    # Encoder/decoder architecture settings (taming diffusionmodules).
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 32
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.DummyLoss
| @ -0,0 +1,13 @@ | |||
# Shared lint configuration: 120-character lines; skip tox, migration and json artifacts.
[pep8]
max-line-length = 120
exclude = .tox,*migrations*,.json
[flake8]
max-line-length = 120
exclude = .tox,*migrations*,.json
[autopep8-wrapper]
exclude = .tox,*migrations*,.json
[check-docstring-first]
exclude = .tox,*migrations*,.json
| @ -0,0 +1,39 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import io | |||
| from os.path import abspath, dirname | |||
| import PIL | |||
| import pytest | |||
| import requests | |||
| from rudalle import get_tokenizer, get_rudalle_model, get_vae | |||
| TEST_ROOT = dirname(abspath(__file__)) | |||
@pytest.fixture(scope='module')
def vae():
    """Module-scoped randomly-initialized VAE (no pretrained download)."""
    yield get_vae(pretrained=False)
@pytest.fixture(scope='module')
def yttm_tokenizer():
    """Module-scoped BPE tokenizer (downloads the default model once)."""
    yield get_tokenizer()
@pytest.fixture(scope='module')
def sample_image():
    """Module-scoped sample PIL image fetched over HTTP."""
    url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png'
    response = requests.get(url)
    response.raise_for_status()
    yield PIL.Image.open(io.BytesIO(response.content))
@pytest.fixture(scope='module')
def small_dalle():
    """Module-scoped small ruDALL-E model on CPU with random weights."""
    return get_rudalle_model('small', pretrained=False, fp16=False, device='cpu')
| @ -0,0 +1,31 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import torch | |||
| import pytest | |||
| from .test_vae import preprocess | |||
@pytest.mark.parametrize('text', [
    'мальчик играет с оленем',
])
def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle):
    """Run one full forward pass with loss computation on a small batch."""
    bs = 4
    text_seq_length = small_dalle.get_param('text_seq_length')
    total_seq_length = small_dalle.get_param('total_seq_length')
    device = small_dalle.get_param('device')
    # Build an image batch by repeating one preprocessed sample.
    img = sample_image.copy()
    img = preprocess(img, target_image_size=256)
    images = img.repeat(bs, 1, 1, 1).to(device)
    text = text.lower().strip()
    text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length)
    text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device)
    # Causal (lower-triangular) attention mask over the full sequence.
    attention_mask = torch.tril(torch.ones((bs, 1, total_seq_length, total_seq_length), device=device))
    with torch.no_grad():
        image_input_ids = vae.get_codebook_indices(images)
        input_ids = torch.cat((text_input_ids, image_input_ids), dim=1)
        loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True)
    # Fix: isinstance is the idiomatic type check (type(x) == T rejects subclasses).
    assert isinstance(loss.data.detach().item(), float)
    assert isinstance(loss_values, dict)
| @ -0,0 +1,17 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import pytest | |||
@pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [
    ('hello, how are you?', 128, 0.1),
    ('hello, how are you?', 128, 0.5),
    ('hello, how are you?', 128, 1.0),
    ('hello ... how are you ?', 256, 1.0),
    ('a person standing at a table with bottles of win', 64, 0.5),
    ('привет как дела???', 76, 0.0),
    ('клип на русском языке :)', 76, 0.1),
])
def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout):
    """Encoding then decoding must round-trip back to the original text."""
    encoded = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout)
    assert yttm_tokenizer.decode_text(encoded) == text
| @ -0,0 +1,49 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import PIL | |||
| import pytest | |||
| import torch | |||
| import torchvision.transforms as T | |||
| import torchvision.transforms.functional as TF | |||
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_decode_vae(vae, sample_image, target_image_size):
    """Encoding to codebook indices then decoding yields a batch of the requested size."""
    prepared = preprocess(sample_image.copy(), target_image_size=target_image_size)
    with torch.no_grad():
        indices = vae.get_codebook_indices(prepared)
        decoded = vae.decode(indices)
    assert decoded.shape == (1, 3, target_image_size, target_image_size)
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_reconstruct_vae(vae, sample_image, target_image_size):
    """Full encode/decode reconstruction preserves the input spatial size."""
    with torch.no_grad():
        prepared = preprocess(sample_image.copy(), target_image_size=target_image_size)
        reconstruction = reconstruct_with_vqgan(preprocess_vqgan(prepared), vae.model)
    assert reconstruction.shape == (1, 3, target_image_size, target_image_size)
def preprocess(img, target_image_size=256):
    """Shortest-side resize, center-crop to square, return a (1, C, H, W) float tensor.

    Raises ValueError when the image's shortest side is smaller than the target.
    """
    s = min(img.size)
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')
    scale = target_image_size / s
    # PIL size is (W, H); torchvision resize wants (H, W).
    resized_hw = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, resized_hw, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=[target_image_size, target_image_size])
    return torch.unsqueeze(T.ToTensor()(img), 0)
def preprocess_vqgan(x):
    """Rescale values from [0, 1] to the [-1, 1] range expected by the VQGAN."""
    return 2. * x - 1.
def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with the VQGAN and decode the quantized latent back to pixels."""
    z, _, (_, _, _) = model.encode(x)
    print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}')
    return model.decode(z)